diff --git a/nemo/collections/common/callbacks/ema.py b/nemo/collections/common/callbacks/ema.py index 58d6a1668ab2..56db703aab28 100644 --- a/nemo/collections/common/callbacks/ema.py +++ b/nemo/collections/common/callbacks/ema.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,25 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os.path -import warnings -from typing import Any, Dict, List, Optional +import contextlib +import copy +import logging +import os +import threading +from typing import Any, Dict, Iterable import pytorch_lightning as pl import torch from pytorch_lightning import Callback -from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.types import STEP_OUTPUT - -from nemo.utils import logging - -try: - import amp_C - - apex_available = True -except Exception: - apex_available = False class EMA(Callback): @@ -42,88 +34,83 @@ class EMA(Callback): Args: decay: The exponential decay used when calculating the moving average. Has to be between 0-1. - apply_ema_every_n_steps: Apply EMA every n global steps. - start_step: Start applying EMA from ``start_step`` global step onwards. - evaluate_ema_weights_instead: Validate the EMA weights instead of the original weights. - Note this means that when saving the model, the validation metrics are calculated with the EMA weights. - save_ema_weights_in_callback_state: Enable saving ema weights in callback state. - This is not required when using NeMo as the experiment manager handles saving weights. + validate_original_weights: Validate the original weights, as apposed to the EMA weights. + every_n_steps: Apply EMA every N steps. + cpu_offload: Offload weights to CPU. """ def __init__( - self, - decay: float, - apply_ema_every_n_steps: int = 1, - start_step: int = 0, - save_ema_weights_in_callback_state: bool = False, - evaluate_ema_weights_instead: bool = False, + self, decay: float, validate_original_weights: bool = False, every_n_steps: int = 1, cpu_offload: bool = False, ): - if not apex_available: - rank_zero_warn( - "EMA has better performance when Apex is installed: https://github.com/NVIDIA/apex#installation." - ) if not (0 <= decay <= 1): raise MisconfigurationException("EMA decay value must be between 0 and 1") - self._ema_model_weights: Optional[List[torch.Tensor]] = None - self._overflow_buf: Optional[torch.Tensor] = None - self._cur_step: Optional[int] = None - self._weights_buffer: Optional[List[torch.Tensor]] = None - self.apply_ema_every_n_steps = apply_ema_every_n_steps - self.start_step = start_step - self.save_ema_weights_in_callback_state = save_ema_weights_in_callback_state - self.evaluate_ema_weights_instead = evaluate_ema_weights_instead self.decay = decay + self.validate_original_weights = validate_original_weights + self.every_n_steps = every_n_steps + self.cpu_offload = cpu_offload - def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - logging.info('Creating EMA weights copy.') - if self._ema_model_weights is None: - self._ema_model_weights = [p.detach().clone() for p in pl_module.state_dict().values()] - # ensure that all the weights are on the correct device - self._ema_model_weights = [p.to(pl_module.device) for p in self._ema_model_weights] - self._overflow_buf = torch.IntTensor([0]).to(pl_module.device) - - def ema(self, pl_module: "pl.LightningModule") -> None: - if apex_available and pl_module.device.type == "cuda": - return self.apply_multi_tensor_ema(pl_module) - return self.apply_ema(pl_module) - - def apply_multi_tensor_ema(self, pl_module: "pl.LightningModule") -> None: - model_weights = list(pl_module.state_dict().values()) - amp_C.multi_tensor_axpby( - 65536, # todo (sean): chunk size, should we expose? - self._overflow_buf, - [self._ema_model_weights, model_weights, self._ema_model_weights], - self.decay, - 1 - self.decay, - -1, - ) - - def apply_ema(self, pl_module: "pl.LightningModule") -> None: - for orig_weight, ema_weight in zip(list(pl_module.state_dict().values()), self._ema_model_weights): - diff = ema_weight.data - orig_weight.data - diff.mul_(1.0 - self.decay) - ema_weight.sub_(diff) - - def should_apply_ema(self, step: int) -> bool: - return step != self._cur_step and step >= self.start_step and step % self.apply_ema_every_n_steps == 0 - - def on_train_batch_end( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int - ) -> None: - if self.should_apply_ema(trainer.global_step): - self._cur_step = trainer.global_step - self.ema(pl_module) + def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + device = pl_module.device if not self.cpu_offload else torch.device('cpu') + trainer.optimizers = [ + EMAOptimizer( + optim, + device=device, + decay=self.decay, + every_n_steps=self.every_n_steps, + current_step=trainer.global_step, + ) + for optim in trainer.optimizers + if not isinstance(optim, EMAOptimizer) + ] - def state_dict(self) -> Dict[str, Any]: - if self.save_ema_weights_in_callback_state: - return dict(cur_step=self._cur_step, ema_weights=self._ema_model_weights) - return dict(cur_step=self._cur_step) + def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if self._should_validate_ema_weights(trainer): + self.swap_model_weights(trainer) - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - self._cur_step = state_dict['cur_step'] - # when loading using NeMo, ema weights will be loaded by the experiment manager separately. - if self._ema_model_weights is None: - self._ema_model_weights = state_dict.get('ema_weights') + def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if self._should_validate_ema_weights(trainer): + self.swap_model_weights(trainer) + + def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if self._should_validate_ema_weights(trainer): + self.swap_model_weights(trainer) + + def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if self._should_validate_ema_weights(trainer): + self.swap_model_weights(trainer) + + def _should_validate_ema_weights(self, trainer: "pl.Trainer") -> bool: + return not self.validate_original_weights and self._ema_initialized(trainer) + + def _ema_initialized(self, trainer: "pl.Trainer") -> bool: + return any(isinstance(optimizer, EMAOptimizer) for optimizer in trainer.optimizers) + + def swap_model_weights(self, trainer: "pl.Trainer", saving_ema_model: bool = False): + for optimizer in trainer.optimizers: + assert isinstance(optimizer, EMAOptimizer) + optimizer.switch_main_parameter_weights(saving_ema_model) + + @contextlib.contextmanager + def save_ema_model(self, trainer: "pl.Trainer"): + """ + Saves an EMA copy of the model + EMA optimizer states for resume. + """ + self.swap_model_weights(trainer, saving_ema_model=True) + try: + yield + finally: + self.swap_model_weights(trainer, saving_ema_model=False) + + @contextlib.contextmanager + def save_original_optimizer_state(self, trainer: "pl.Trainer"): + for optimizer in trainer.optimizers: + assert isinstance(optimizer, EMAOptimizer) + optimizer.save_original_optimizer_state = True + try: + yield + finally: + for optimizer in trainer.optimizers: + optimizer.save_original_optimizer_state = False def on_load_checkpoint( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] @@ -142,43 +129,219 @@ def on_load_checkpoint( ema_path = trainer.ckpt_path.replace(ext, f'-EMA{ext}') if os.path.exists(ema_path): ema_state_dict = torch.load(ema_path, map_location=torch.device('cpu')) - self._ema_model_weights = ema_state_dict['state_dict'].values() + + # this is wrong, basically when we save the EMA weights, optimizer_states actually contains the model parameters + # as we swapped the model parameters with the state dict parameters. + # we could enforce that if you trained with EMA and want to continue training + checkpoint['optimizer_states'] = ema_state_dict['optimizer_states'] del ema_state_dict logging.info("EMA weights have been loaded successfully. Continuing training with saved EMA weights.") else: - warnings.warn( - "we were unable to find the associated EMA weights when re-loading, " - "training will start with new EMA weights.", - UserWarning, + raise MisconfigurationException( + "Unable to find the associated EMA weights when re-loading, " + f"training will start with new EMA weights. Expected them to be at: {ema_path}", ) - def replace_model_weights(self, pl_module: "pl.LightningModule") -> None: - self._weights_buffer = [p.detach().clone().to('cpu') for p in pl_module.state_dict().values()] - new_state_dict = {k: v for k, v in zip(pl_module.state_dict().keys(), self._ema_model_weights)} - pl_module.load_state_dict(new_state_dict) - def restore_original_weights(self, pl_module: "pl.LightningModule") -> None: - state_dict = pl_module.state_dict() - new_state_dict = {k: v for k, v in zip(state_dict.keys(), self._weights_buffer)} - pl_module.load_state_dict(new_state_dict) - del self._weights_buffer +@torch.no_grad() +def ema_update(ema_model_tuple, current_model_tuple, decay): + torch._foreach_mul_(ema_model_tuple, decay) + torch._foreach_add_( + ema_model_tuple, current_model_tuple, alpha=(1.0 - decay), + ) - @property - def ema_initialized(self) -> bool: - return self._ema_model_weights is not None - def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if self.ema_initialized and self.evaluate_ema_weights_instead: - self.replace_model_weights(pl_module) +def run_ema_update_cpu(ema_model_tuple, current_model_tuple, decay, pre_sync_stream=None): + if pre_sync_stream is not None: + pre_sync_stream.synchronize() - def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if self.ema_initialized and self.evaluate_ema_weights_instead: - self.restore_original_weights(pl_module) + ema_update(ema_model_tuple, current_model_tuple, decay) - def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if self.ema_initialized and self.evaluate_ema_weights_instead: - self.replace_model_weights(pl_module) - def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if self.ema_initialized and self.evaluate_ema_weights_instead: - self.restore_original_weights(pl_module) +class EMAOptimizer(torch.optim.Optimizer): + r""" + EMAOptimizer is a wrapper for torch.optim.Optimizer that computes + Exponential Moving Average of parameters registered in the optimizer. + + EMA parameters are automatically updated after every step of the optimizer + with the following formula: + + ema_weight = decay * ema_weight + (1 - decay) * training_weight + + To access EMA parameters, use ``swap_ema_weights()`` context manager to + perform a temporary in-place swap of regular parameters with EMA + parameters. + + Notes: + - EMAOptimizer is not compatible with APEX AMP O2. + + Args: + optimizer (torch.optim.Optimizer): optimizer to wrap + device (torch.device): device for EMA parameters + decay (float): decay factor + + Returns: + returns an instance of torch.optim.Optimizer that computes EMA of + parameters + + Example: + model = Model().to(device) + opt = torch.optim.Adam(model.parameters()) + + opt = EMAOptimizer(opt, device, 0.9999) + + for epoch in range(epochs): + training_loop(model, opt) + + regular_eval_accuracy = evaluate(model) + + with opt.swap_ema_weights(): + ema_eval_accuracy = evaluate(model) + """ + + def __init__( + self, + optimizer: torch.optim.Optimizer, + device: torch.device, + decay: float = 0.9999, + every_n_steps: int = 1, + current_step: int = 0, + ): + self.optimizer = optimizer + self.decay = decay + self.device = device + self.current_step = current_step + self.every_n_steps = every_n_steps + self.save_original_optimizer_state = False + + self.first_iteration = True + self.rebuild_ema_params = True + self.stream = None + self.thread = None + + self.ema_params = () + self.in_saving_ema_model_context = False + + def all_parameters(self) -> Iterable[torch.Tensor]: + return (param for group in self.param_groups for param in group['params']) + + def step(self, closure=None, **kwargs): + self.join() + + if self.first_iteration: + if any(p.is_cuda for p in self.all_parameters()): + self.stream = torch.cuda.Stream() + + self.first_iteration = False + + if self.rebuild_ema_params: + opt_params = list(self.all_parameters()) + + self.ema_params += tuple( + copy.deepcopy(param.data.detach()).to(self.device) for param in opt_params[len(self.ema_params) :] + ) + self.rebuild_ema_params = False + + loss = self.optimizer.step(closure) + + if self._should_update_at_step(): + self.update() + self.current_step += 1 + return loss + + def _should_update_at_step(self) -> bool: + return self.current_step % self.every_n_steps == 0 + + @torch.no_grad() + def update(self): + if self.stream is not None: + self.stream.wait_stream(torch.cuda.current_stream()) + + with torch.cuda.stream(self.stream): + current_model_state = tuple( + param.data.to(self.device, non_blocking=True) for param in self.all_parameters() + ) + + if self.device.type == 'cuda': + ema_update(self.ema_params, current_model_state, self.decay) + + if self.device.type == 'cpu': + self.thread = threading.Thread( + target=run_ema_update_cpu, args=(self.ema_params, current_model_state, self.decay, self.stream,), + ) + self.thread.start() + + def swap_tensors(self, tensor1, tensor2): + tmp = torch.empty_like(tensor1) + tmp.copy_(tensor1) + tensor1.copy_(tensor2) + tensor2.copy_(tmp) + + def switch_main_parameter_weights(self, saving_ema_model: bool = False): + self.join() + self.in_saving_ema_model_context = saving_ema_model + for param, ema_param in zip(self.all_parameters(), self.ema_params): + self.swap_tensors(param.data, ema_param) + + @contextlib.contextmanager + def swap_ema_weights(self, enabled: bool = True): + r""" + A context manager to in-place swap regular parameters with EMA + parameters. + It swaps back to the original regular parameters on context manager + exit. + + Args: + enabled (bool): whether the swap should be performed + """ + + if enabled: + self.switch_main_parameter_weights() + try: + yield + finally: + if enabled: + self.switch_main_parameter_weights() + + def __getattr__(self, name): + return getattr(self.optimizer, name) + + def join(self): + if self.stream is not None: + self.stream.synchronize() + + if self.thread is not None: + self.thread.join() + + def state_dict(self): + self.join() + + if self.save_original_optimizer_state: + return self.optimizer.state_dict() + + # if we are in the context of saving an EMA model, the EMA weights are in the modules' actual weights + ema_params = self.ema_params if not self.in_saving_ema_model_context else list(self.all_parameters()) + state_dict = { + 'opt': self.optimizer.state_dict(), + 'ema': ema_params, + 'current_step': self.current_step, + 'decay': self.decay, + 'every_n_steps': self.every_n_steps, + 'device': self.device, + } + return state_dict + + def load_state_dict(self, state_dict): + self.join() + + self.optimizer.load_state_dict(state_dict['opt']) + self.ema_params = tuple(param.to(self.device) for param in copy.deepcopy(state_dict['ema'])) + self.current_step = state_dict['current_step'] + self.decay = state_dict['decay'] + self.device = state_dict['device'] + self.every_n_steps = state_dict['every_n_steps'] + self.rebuild_ema_params = False + + def add_param_group(self, param_group): + self.optimizer.add_param_group(param_group) + self.rebuild_ema_params = True diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py index 2e467ef594e8..2406d343e0f5 100644 --- a/nemo/utils/exp_manager.py +++ b/nemo/utils/exp_manager.py @@ -114,10 +114,10 @@ class StepTimingParams: @dataclass class EMAParams: enable: Optional[bool] = False - evaluate_ema_weights_instead: Optional[bool] = False decay: Optional[float] = 0.999 - apply_ema_every_n_steps: Optional[int] = 1 - start_step: Optional[int] = 0 + cpu_offload: Optional[bool] = False + validate_original_weights: Optional[bool] = False + every_n_steps: int = 1 @dataclass @@ -387,9 +387,9 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo if cfg.ema.enable: ema_callback = EMA( decay=cfg.ema.decay, - apply_ema_every_n_steps=cfg.ema.apply_ema_every_n_steps, - start_step=cfg.ema.start_step, - evaluate_ema_weights_instead=cfg.ema.evaluate_ema_weights_instead, + validate_original_weights=cfg.ema.validate_original_weights, + cpu_offload=cfg.ema.cpu_offload, + every_n_steps=cfg.ema.every_n_steps, ) trainer.callbacks.append(ema_callback) @@ -914,24 +914,27 @@ def _del_model_without_trainer(self, filepath: str) -> None: except: logging.info(f"Tried to remove checkpoint: {filepath} but failed.") - def _get_ema_callback(self, trainer) -> Optional[EMA]: + def _ema_callback(self, trainer: 'pytorch_lightning.Trainer') -> Optional[EMA]: ema_callback = None for callback in trainer.callbacks: if isinstance(callback, EMA): ema_callback = callback return ema_callback - def _save_checkpoint(self, trainer, filepath: str) -> None: - super()._save_checkpoint(trainer, filepath) - ema_callback = self._get_ema_callback(trainer) + def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) -> None: + ema_callback = self._ema_callback(trainer) if ema_callback is not None: + with ema_callback.save_original_optimizer_state(trainer): + super()._save_checkpoint(trainer, filepath) + # save EMA copy of the model as well. - ema_callback.replace_model_weights(trainer.lightning_module) - filepath = self._ema_format_filepath(filepath) - if self.verbose: - rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}") + with ema_callback.save_ema_model(trainer): + filepath = self._ema_format_filepath(filepath) + if self.verbose: + rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}") + super()._save_checkpoint(trainer, filepath) + else: super()._save_checkpoint(trainer, filepath) - ema_callback.restore_original_weights(trainer.lightning_module) def _ema_format_filepath(self, filepath: str) -> str: return filepath.replace(self.FILE_EXTENSION, f'-EMA{self.FILE_EXTENSION}') diff --git a/tests/collections/common/test_ema.py b/tests/collections/common/test_ema.py index ae4f40d52d51..16de01e0fbe4 100644 --- a/tests/collections/common/test_ema.py +++ b/tests/collections/common/test_ema.py @@ -13,9 +13,7 @@ # limitations under the License. import os.path -from copy import deepcopy from typing import Any, Dict, Union -from unittest import mock import pytest import pytorch_lightning as pl @@ -26,40 +24,53 @@ from pytorch_lightning.utilities.types import STEP_OUTPUT from nemo.collections.common.callbacks import EMA +from nemo.collections.common.callbacks.ema import EMAOptimizer from nemo.core import ModelPT from nemo.utils.exp_manager import exp_manager from tests.collections.nlp.test_gpt_model import DEVICE_CAPABILITY -class OnesDataset(torch.utils.data.Dataset): - def __init__(self, dataset_len): - super().__init__() - self.__dataset_len = dataset_len +def extract_ema_weights(pl_module, trainer): + ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0] + ema_callback.swap_model_weights(trainer) + weights = extract_weights(pl_module) + ema_callback.swap_model_weights(trainer) + return weights - def __getitem__(self, *args): - return torch.ones(2) + +def extract_weights(pl_module): + return [w.detach().clone() for w in pl_module.parameters()] + + +class RandomDataset(torch.utils.data.Dataset): + def __init__(self, size, length): + self.len = length + self.data = torch.randn(length, size) + + def __getitem__(self, index): + return self.data[index] def __len__(self): - return self.__dataset_len + return self.len class ExampleModel(ModelPT): def __init__(self, *args, **kwargs): cfg = OmegaConf.structured({}) super().__init__(cfg) - self.l1 = torch.nn.modules.Linear(in_features=2, out_features=1) + self.l1 = torch.nn.modules.Linear(in_features=32, out_features=32) + self.bn = torch.nn.BatchNorm1d(32) def train_dataloader(self): - dataset = OnesDataset(16) + dataset = RandomDataset(32, 16) return torch.utils.data.DataLoader(dataset, batch_size=2) def val_dataloader(self): - dataset = OnesDataset(10) + dataset = RandomDataset(32, 16) return torch.utils.data.DataLoader(dataset, batch_size=2) def forward(self, batch): - output = self.l1(batch) - return torch.nn.functional.l1_loss(output, torch.zeros(output.size()).to(output.device)) + return self.l1(self.bn(batch)).sum() def validation_step(self, batch, batch_idx): return self(batch) @@ -68,7 +79,7 @@ def training_step(self, batch, batch_idx): return self(batch) def configure_optimizers(self): - return torch.optim.Adam(self.parameters(), lr=0.1) + return torch.optim.SGD(self.parameters(), lr=1e-3) def list_available_models(self): pass @@ -89,11 +100,6 @@ def test_ema_value(self): with pytest.raises(MisconfigurationException, match="between 0 and 1"): EMA(decay=2) - @mock.patch('nemo.collections.common.callbacks.ema.apex_available', False) - def test_ema_apex_unavailable(self): - with pytest.warns(UserWarning, match="EMA has better performance when Apex is installed"): - EMA(decay=0.999) - @pytest.mark.unit @pytest.mark.run_only_on('GPU') def test_ema_saved_state(self, tmpdir, caplog): @@ -102,9 +108,8 @@ def test_ema_saved_state(self, tmpdir, caplog): class TerminateCallback(Callback): def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0] - self.saved_ema_weights = ema_callback._ema_model_weights - self.pl_module_weights = list(pl_module.state_dict().values()) + self.saved_ema_weights = extract_ema_weights(pl_module, trainer) + self.pl_module_weights = extract_weights(pl_module) raise SystemExit model = ExampleModel() @@ -124,7 +129,7 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu exp_manager( trainer, { - "ema": {"enable": True, "evaluate_ema_weights_instead": True}, + "ema": {"enable": True}, "explicit_log_dir": str(temp_path), "checkpoint_callback_params": {"filename": f"{{epoch}}-{{step}}"}, }, @@ -137,13 +142,16 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu class CheckStateCallback(Callback): def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0] - weights = list(pl_module.state_dict().values()) + weights = extract_weights(pl_module) for x, y in zip(weights, terminate_callback.pl_module_weights): assert torch.allclose(x.cpu(), y.cpu()) - for x, y in zip(ema_callback._ema_model_weights, terminate_callback.saved_ema_weights): + current_ema_weights = extract_ema_weights(pl_module, trainer) + for x, y in zip(current_ema_weights, terminate_callback.saved_ema_weights): assert torch.allclose(x.cpu(), y.cpu()) - assert ema_callback._cur_step == 8 + + for optimizer in trainer.optimizers: + assert isinstance(optimizer, EMAOptimizer) + assert optimizer.current_step == 8 trainer = Trainer( max_epochs=2, @@ -157,7 +165,7 @@ def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") exp_manager( trainer, { - "ema": {"enable": True, "evaluate_ema_weights_instead": True}, + "ema": {"enable": True}, "explicit_log_dir": str(temp_path), "checkpoint_callback_params": {"filename": f"{{epoch}}-{{step}}"}, }, @@ -181,7 +189,7 @@ def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") exp_manager( trainer, { - "ema": {"enable": True, "evaluate_ema_weights_instead": True}, + "ema": {"enable": True}, "explicit_log_dir": str(temp_path), "checkpoint_callback_params": {"filename": f"{{epoch}}-{{step}}"}, }, @@ -203,12 +211,14 @@ def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") exp_manager( trainer, { - "ema": {"enable": True, "evaluate_ema_weights_instead": True}, + "ema": {"enable": True, "validate_original_weights": True}, "explicit_log_dir": str(temp_path), "checkpoint_callback_params": {"filename": f"{{epoch}}-{{step}}"}, }, ) - with pytest.warns(UserWarning, match="we were unable to find the associated EMA weights when re-loading"): + with pytest.raises( + MisconfigurationException, match="Unable to find the associated EMA weights when re-loading" + ): trainer.fit(model, ckpt_path=resume_path) @pytest.mark.unit @@ -221,70 +231,23 @@ def test_exp_manager_ema_weights(self, tmpdir): exp_manager( trainer, { - "ema": {"enable": True, "evaluate_ema_weights_instead": True}, + "ema": {"enable": True, "validate_original_weights": True}, "explicit_log_dir": str(tmp_path), "checkpoint_callback_params": {"filename": f"{{epoch}}-{{step}}"}, }, ) assert any(isinstance(callback, EMA) for callback in trainer.callbacks) trainer.fit(model) + ema_weights = extract_ema_weights(model, trainer) assert os.path.exists(tmp_path / "checkpoints/epoch=0-step=8.ckpt") ema_path = tmp_path / "checkpoints/epoch=0-step=8-EMA.ckpt" assert os.path.exists(ema_path) duplicate_model = ExampleModel.load_from_checkpoint(str(ema_path)) - ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0] - for saved_weight, ema_weight in zip(duplicate_model.state_dict().values(), ema_callback._ema_model_weights): + for saved_weight, ema_weight in zip(duplicate_model.state_dict().values(), ema_weights): assert torch.allclose(saved_weight.cpu(), ema_weight.cpu()) - @pytest.mark.unit - @pytest.mark.run_only_on('GPU') - def test_ema_save_in_callback(self, tmpdir): - """Test to ensure when `save_ema_weights_in_callback_state` is enabled, we save to the callback state.""" - temp_path = os.path.join(tmpdir, 'saved_state') - - model = ExampleModel() - - trainer = Trainer( - max_epochs=2, - limit_val_batches=1, - limit_train_batches=16, - logger=False, - val_check_interval=0.5, - enable_checkpointing=False, - accelerator='gpu', - devices=1, - callbacks=[EMA(decay=0.999, save_ema_weights_in_callback_state=True, evaluate_ema_weights_instead=True)], - ) - exp_manager( - trainer, - {"explicit_log_dir": str(temp_path), "checkpoint_callback_params": {"filename": f"{{epoch}}-{{step}}"},}, - ) - trainer.fit(model=model) - - resume_path = os.path.join(temp_path, "checkpoints/epoch=0-step=8.ckpt") - callback = EMA(decay=0.999, save_ema_weights_in_callback_state=True) - - class AssertCallback(Callback): - def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - assert callback._ema_model_weights is not None - - model = ExampleModel() - - trainer = Trainer( - max_epochs=2, - limit_val_batches=1, - limit_train_batches=16, - logger=False, - val_check_interval=0.5, - enable_checkpointing=False, - accelerator='gpu', - devices=1, - callbacks=[callback, AssertCallback()], - ) - trainer.fit(model, ckpt_path=resume_path) - class TestEMATrain: @pytest.mark.unit @@ -303,41 +266,32 @@ class TestEMATrain: ], ) @pytest.mark.parametrize("accumulate_grad_batches", [1, 2]) - @pytest.mark.parametrize("evaluate_ema_weights_instead", [True, False]) - @pytest.mark.parametrize("apex_available_mock", [True, False]) + @pytest.mark.parametrize("validate_original_weights", [True, False]) @pytest.mark.run_only_on('GPU') def test_ema_run_cuda( - self, - test_data_dir, - precision, - accumulate_grad_batches, - evaluate_ema_weights_instead, - apex_available_mock, - tmpdir, + self, test_data_dir, precision, accumulate_grad_batches, validate_original_weights, tmpdir, ): - with mock.patch('nemo.collections.common.callbacks.ema.apex_available', apex_available_mock): - self.run_training_test( - accumulate_grad_batches=accumulate_grad_batches, - evaluate_ema_weights_instead=evaluate_ema_weights_instead, - accelerator='gpu', - precision=precision, - tmpdir=tmpdir, - ) + self.run_training_test( + accumulate_grad_batches=accumulate_grad_batches, + validate_original_weights=validate_original_weights, + accelerator='gpu', + precision=precision, + tmpdir=tmpdir, + ) @pytest.mark.unit @pytest.mark.parametrize("accumulate_grad_batches", [1, 2]) - @pytest.mark.parametrize("evaluate_ema_weights_instead", [True, False]) - @pytest.mark.run_only_on('GPU') - def test_ema_run_cpu(self, test_data_dir, accumulate_grad_batches, evaluate_ema_weights_instead, tmpdir): + @pytest.mark.parametrize("validate_original_weights", [True, False]) + def test_ema_run_cpu(self, test_data_dir, accumulate_grad_batches, validate_original_weights, tmpdir): self.run_training_test( accumulate_grad_batches=accumulate_grad_batches, - evaluate_ema_weights_instead=evaluate_ema_weights_instead, + validate_original_weights=validate_original_weights, accelerator='cpu', precision=32, tmpdir=tmpdir, ) - def run_training_test(self, accumulate_grad_batches, evaluate_ema_weights_instead, accelerator, precision, tmpdir): + def run_training_test(self, accumulate_grad_batches, validate_original_weights, accelerator, precision, tmpdir): pl.seed_everything(123) model = ExampleModel() trainer = Trainer( @@ -356,33 +310,24 @@ def run_training_test(self, accumulate_grad_batches, evaluate_ema_weights_instea exp_manager( trainer, { - "ema": {"enable": True, "evaluate_ema_weights_instead": evaluate_ema_weights_instead}, + "ema": {"enable": True, "validate_original_weights": validate_original_weights, "decay": 0.999}, "explicit_log_dir": str(tmpdir), "checkpoint_callback_params": {"filename": f"{{epoch}}-{{step}}"}, }, ) # add the check callback after the exp manager has made modifications. trainer.callbacks.append(EMAAssertCallback()) + trainer.callbacks.insert(0, EMAValidationAssertCallback()) trainer.fit(model=model, val_dataloaders=model.train_dataloader()) class EMAAssertCallback(Callback): - def __init__(self): - self._before_calc_ema_weights = None - def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - model_weights = list(pl_module.state_dict().values()) - ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0] - for x, y in zip(model_weights, ema_callback._ema_model_weights): + model_weights = extract_weights(pl_module) + self.ema_weights = extract_ema_weights(pl_module, trainer) + for x, y in zip(model_weights, self.ema_weights): assert torch.allclose(x, y) - def on_train_batch_start( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int - ) -> None: - ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0] - # saved for manual calculation of ema to compare against implementation - self._before_calc_ema_weights = deepcopy(ema_callback._ema_model_weights) - def on_train_batch_end( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int ) -> None: @@ -392,28 +337,36 @@ def on_train_batch_end( ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0] decay = ema_callback.decay expected_ema_weights = [] - for orig_weight, ema_weight in zip(list(pl_module.state_dict().values()), self._before_calc_ema_weights): - expected_ema_weight = orig_weight * (1 - decay) + ema_weight * decay - expected_ema_weights.append(expected_ema_weight) - for actual_ema_weight, expected_ema_weight in zip(ema_callback._ema_model_weights, expected_ema_weights): + new_weights = extract_weights(pl_module) + + for ema_weight, new_weight in zip(self.ema_weights, new_weights): + expected_ema_weight = ema_weight * decay + expected_ema_weight += new_weight * (1 - decay) + expected_ema_weights.append(expected_ema_weight) + ema_weights = extract_ema_weights(pl_module, trainer) + for actual_ema_weight, expected_ema_weight in zip(ema_weights, expected_ema_weights): assert torch.allclose(actual_ema_weight, expected_ema_weight) + self.ema_weights = expected_ema_weights + +class EMAValidationAssertCallback(Callback): def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0] - if ema_callback.evaluate_ema_weights_instead: - # todo (sean): shouldn't use the weights buffer to check original weights - self._original_weights = list(x.detach().clone() for x in ema_callback._weights_buffer) - if ema_callback.ema_initialized: - for ema_weights, module_weights in zip( - ema_callback._ema_model_weights, pl_module.state_dict().values() - ): + self._original_weights = extract_weights(pl_module) + self._ema_weights = extract_ema_weights(pl_module, trainer) + # call original EMA function + super().on_validation_start(trainer, pl_module) + if not ema_callback.validate_original_weights: + if ema_callback._ema_initialized: + # check model weights are now EMA weights + for ema_weights, module_weights in zip(self._ema_weights, extract_weights(pl_module)): torch.allclose(ema_weights, module_weights) def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0] - if ema_callback.evaluate_ema_weights_instead: - model_weights = list(pl_module.state_dict().values()) - if ema_callback.ema_initialized: + if not ema_callback.validate_original_weights: + model_weights = extract_weights(pl_module) + if ema_callback._ema_initialized: for orig_weights, module_weights in zip(self._original_weights, model_weights): - torch.allclose(orig_weights, module_weights.cpu()) + torch.allclose(orig_weights.cpu(), module_weights.cpu())