From 4fb1fc6d5824fa3472c051e35a7fe183a84e2eec Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 14 Oct 2022 17:43:13 +0100 Subject: [PATCH 01/19] Move to optimizer Signed-off-by: SeanNaren --- nemo/collections/common/callbacks/ema.py | 330 +++++++++++++---------- nemo/utils/exp_manager.py | 10 +- 2 files changed, 183 insertions(+), 157 deletions(-) diff --git a/nemo/collections/common/callbacks/ema.py b/nemo/collections/common/callbacks/ema.py index 58d6a1668ab2..ca4669ccd1fb 100644 --- a/nemo/collections/common/callbacks/ema.py +++ b/nemo/collections/common/callbacks/ema.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,25 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os.path -import warnings -from typing import Any, Dict, List, Optional +import contextlib +import copy +import threading import pytorch_lightning as pl import torch from pytorch_lightning import Callback -from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.types import STEP_OUTPUT - -from nemo.utils import logging - -try: - import amp_C - - apex_available = True -except Exception: - apex_available = False class EMA(Callback): @@ -42,143 +31,188 @@ class EMA(Callback): Args: decay: The exponential decay used when calculating the moving average. Has to be between 0-1. - apply_ema_every_n_steps: Apply EMA every n global steps. - start_step: Start applying EMA from ``start_step`` global step onwards. - evaluate_ema_weights_instead: Validate the EMA weights instead of the original weights. - Note this means that when saving the model, the validation metrics are calculated with the EMA weights. - save_ema_weights_in_callback_state: Enable saving ema weights in callback state. - This is not required when using NeMo as the experiment manager handles saving weights. """ - def __init__( - self, - decay: float, - apply_ema_every_n_steps: int = 1, - start_step: int = 0, - save_ema_weights_in_callback_state: bool = False, - evaluate_ema_weights_instead: bool = False, - ): - if not apex_available: - rank_zero_warn( - "EMA has better performance when Apex is installed: https://github.com/NVIDIA/apex#installation." 
- ) + def __init__(self, decay: float): if not (0 <= decay <= 1): raise MisconfigurationException("EMA decay value must be between 0 and 1") - self._ema_model_weights: Optional[List[torch.Tensor]] = None - self._overflow_buf: Optional[torch.Tensor] = None - self._cur_step: Optional[int] = None - self._weights_buffer: Optional[List[torch.Tensor]] = None - self.apply_ema_every_n_steps = apply_ema_every_n_steps - self.start_step = start_step - self.save_ema_weights_in_callback_state = save_ema_weights_in_callback_state - self.evaluate_ema_weights_instead = evaluate_ema_weights_instead self.decay = decay def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - logging.info('Creating EMA weights copy.') - if self._ema_model_weights is None: - self._ema_model_weights = [p.detach().clone() for p in pl_module.state_dict().values()] - # ensure that all the weights are on the correct device - self._ema_model_weights = [p.to(pl_module.device) for p in self._ema_model_weights] - self._overflow_buf = torch.IntTensor([0]).to(pl_module.device) - - def ema(self, pl_module: "pl.LightningModule") -> None: - if apex_available and pl_module.device.type == "cuda": - return self.apply_multi_tensor_ema(pl_module) - return self.apply_ema(pl_module) - - def apply_multi_tensor_ema(self, pl_module: "pl.LightningModule") -> None: - model_weights = list(pl_module.state_dict().values()) - amp_C.multi_tensor_axpby( - 65536, # todo (sean): chunk size, should we expose? - self._overflow_buf, - [self._ema_model_weights, model_weights, self._ema_model_weights], - self.decay, - 1 - self.decay, - -1, - ) - - def apply_ema(self, pl_module: "pl.LightningModule") -> None: - for orig_weight, ema_weight in zip(list(pl_module.state_dict().values()), self._ema_model_weights): - diff = ema_weight.data - orig_weight.data - diff.mul_(1.0 - self.decay) - ema_weight.sub_(diff) - - def should_apply_ema(self, step: int) -> bool: - return step != self._cur_step and step >= self.start_step and step % self.apply_ema_every_n_steps == 0 - - def on_train_batch_end( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int - ) -> None: - if self.should_apply_ema(trainer.global_step): - self._cur_step = trainer.global_step - self.ema(pl_module) - - def state_dict(self) -> Dict[str, Any]: - if self.save_ema_weights_in_callback_state: - return dict(cur_step=self._cur_step, ema_weights=self._ema_model_weights) - return dict(cur_step=self._cur_step) - - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - self._cur_step = state_dict['cur_step'] - # when loading using NeMo, ema weights will be loaded by the experiment manager separately. - if self._ema_model_weights is None: - self._ema_model_weights = state_dict.get('ema_weights') - - def on_load_checkpoint( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] - ) -> None: - checkpoint_callback = trainer.checkpoint_callback - - if trainer.ckpt_path and checkpoint_callback is not None and 'NeMo' in type(checkpoint_callback).__name__: - ext = checkpoint_callback.FILE_EXTENSION - if trainer.ckpt_path.endswith(f'-EMA{ext}'): - logging.info( - "loading EMA based weights. " - "The callback will treat the loaded EMA weights as the main weights" - " and create a new EMA copy when training." 
- ) - return - ema_path = trainer.ckpt_path.replace(ext, f'-EMA{ext}') - if os.path.exists(ema_path): - ema_state_dict = torch.load(ema_path, map_location=torch.device('cpu')) - self._ema_model_weights = ema_state_dict['state_dict'].values() - del ema_state_dict - logging.info("EMA weights have been loaded successfully. Continuing training with saved EMA weights.") - else: - warnings.warn( - "we were unable to find the associated EMA weights when re-loading, " - "training will start with new EMA weights.", - UserWarning, - ) - - def replace_model_weights(self, pl_module: "pl.LightningModule") -> None: - self._weights_buffer = [p.detach().clone().to('cpu') for p in pl_module.state_dict().values()] - new_state_dict = {k: v for k, v in zip(pl_module.state_dict().keys(), self._ema_model_weights)} - pl_module.load_state_dict(new_state_dict) - - def restore_original_weights(self, pl_module: "pl.LightningModule") -> None: - state_dict = pl_module.state_dict() - new_state_dict = {k: v for k, v in zip(state_dict.keys(), self._weights_buffer)} - pl_module.load_state_dict(new_state_dict) - del self._weights_buffer - - @property - def ema_initialized(self) -> bool: - return self._ema_model_weights is not None - - def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if self.ema_initialized and self.evaluate_ema_weights_instead: - self.replace_model_weights(pl_module) - - def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if self.ema_initialized and self.evaluate_ema_weights_instead: - self.restore_original_weights(pl_module) - - def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if self.ema_initialized and self.evaluate_ema_weights_instead: - self.replace_model_weights(pl_module) - - def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if self.ema_initialized and self.evaluate_ema_weights_instead: - self.restore_original_weights(pl_module) + trainer.optimizers = [ + EMAOptimizer(optim, device=pl_module.device, decay=self.decay) for optim in trainer.optimizers + ] + + +@torch.no_grad() +def ema_update(ema_model_tuple, current_model_tuple, decay): + torch._foreach_mul_(ema_model_tuple, decay) + torch._foreach_add_( + ema_model_tuple, current_model_tuple, alpha=(1.0 - decay), + ) + + +def run_ema_update_cpu(ema_model_tuple, current_model_tuple, decay, pre_sync_stream=None): + if pre_sync_stream is not None: + pre_sync_stream.synchronize() + + ema_update(ema_model_tuple, current_model_tuple, decay) + + +class EMAOptimizer(torch.optim.Optimizer): + r""" + EMAOptimizer is a wrapper for torch.optim.Optimizer that computes + Exponential Moving Average of parameters registered in the optimizer. + + EMA parameters are automatically updated after every step of the optimizer + with the following formula: + + ema_weight = decay * ema_weight + (1 - decay) * training_weight + + To access EMA parameters, use ``swap_ema_weights()`` context manager to + perform a temporary in-place swap of regular parameters with EMA + parameters. + + Notes: + - EMAOptimizer is not compatible with APEX AMP O2. 
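A minimal, self-contained sketch (not part of the patch) of what the fused `_foreach` update above computes: the list-wise in-place calls are equivalent to applying ema = decay * ema + (1 - decay) * weight to each tensor individually, just across the whole parameter list at once.

    import torch

    decay = 0.9999
    weights = [torch.randn(3, 3), torch.randn(5)]          # stand-ins for model parameters
    ema = [w.clone() for w in weights]
    new_weights = [w + 0.1 for w in weights]               # pretend one training step happened

    # per-tensor reference result
    expected = [decay * e + (1 - decay) * w for e, w in zip(ema, new_weights)]

    # fused, in-place update over the whole list, as in ema_update() above
    torch._foreach_mul_(ema, decay)
    torch._foreach_add_(ema, new_weights, alpha=1.0 - decay)

    assert all(torch.allclose(a, b) for a, b in zip(ema, expected))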
+ + Args: + optimizer (torch.optim.Optimizer): optimizer to wrap + device (torch.device): device for EMA parameters + decay (float): decay factor + + Returns: + returns an instance of torch.optim.Optimizer that computes EMA of + parameters + + Example: + model = Model().to(device) + opt = torch.optim.Adam(model.parameters()) + + opt = EMAOptimizer(opt, device, 0.9999) + + for epoch in range(epochs): + training_loop(model, opt) + + regular_eval_accuracy = evaluate(model) + + with opt.swap_ema_weights(): + ema_eval_accuracy = evaluate(model) + """ + + def __init__( + self, optimizer: torch.optim.Optimizer, device: torch.device, decay: float = 0.9999, + ): + self.optimizer = optimizer + self.decay = decay + self.device = device + + self.first_iteration = True + self.rebuild_ema_params = True + self.stream = None + self.thread = None + + self.ema_params = () + + def all_parameters(self): + return (param for group in self.param_groups for param in group['params']) + + def step(self, closure=None): + self.join() + + if self.first_iteration: + if any(p.is_cuda for p in self.all_parameters()): + self.stream = torch.cuda.Stream() + + self.first_iteration = False + + if self.rebuild_ema_params: + opt_params = list(self.all_parameters()) + + self.ema_params += tuple( + copy.deepcopy(param.data.detach()).to(self.device) for param in opt_params[len(self.ema_params) :] + ) + self.rebuild_ema_params = False + + loss = self.optimizer.step(closure) + + self.update() + return loss + + @torch.no_grad() + def update(self): + if self.stream is not None: + self.stream.wait_stream(torch.cuda.current_stream()) + + with torch.cuda.stream(self.stream): + current_model_state = tuple( + param.data.to(self.device, non_blocking=True) for param in self.all_parameters() + ) + + if self.device.type == 'cuda': + ema_update(self.ema_params, current_model_state, self.decay) + + if self.device.type == 'cpu': + self.thread = threading.Thread( + target=run_ema_update_cpu, args=(self.ema_params, current_model_state, self.decay, self.stream,), + ) + self.thread.start() + + @contextlib.contextmanager + def swap_ema_weights(self, enabled: bool = True): + r""" + A context manager to in-place swap regular parameters with EMA + parameters. + It swaps back to the original regular parameters on context manager + exit. 
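A short sketch (outside the patch) of the copy-based swap that `swap_ema_weights()` builds on. Swapping values in place with `copy_`, rather than rebinding tensors, means every existing reference held by the module or the wrapped optimizer keeps pointing at the same storage, so nothing else needs updating after the swap.

    import torch

    def swap_tensors(tensor1, tensor2):
        # same three-copy swap as in the context manager above
        tmp = torch.empty_like(tensor1)
        tmp.copy_(tensor1)
        tensor1.copy_(tensor2)
        tensor2.copy_(tmp)

    a, b = torch.zeros(3), torch.ones(3)
    alias_of_a = a                      # simulates another reference, e.g. inside an optimizer
    swap_tensors(a, b)
    assert torch.equal(alias_of_a, torch.ones(3))   # old references now see the swapped values
    assert torch.equal(b, torch.zeros(3))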
+ + Args: + enabled (bool): whether the swap should be performed + """ + + def swap_tensors(tensor1, tensor2): + tmp = torch.empty_like(tensor1) + tmp.copy_(tensor1) + tensor1.copy_(tensor2) + tensor2.copy_(tmp) + + if enabled: + self.join() + + for param, ema_param in zip(self.all_parameters(), self.ema_params): + swap_tensors(param.data, ema_param) + try: + yield + finally: + if enabled: + for param, ema_param in zip(self.all_parameters(), self.ema_params): + swap_tensors(param.data, ema_param) + + def __getattr__(self, name): + return getattr(self.optimizer, name) + + def join(self): + if self.stream is not None: + self.stream.synchronize() + + if self.thread is not None: + self.thread.join() + + def state_dict(self): + self.join() + + state_dict = { + 'opt': self.optimizer.state_dict(), + 'ema': self.ema_params, + } + return state_dict + + def load_state_dict(self, state_dict): + self.join() + + self.optimizer.load_state_dict(state_dict['opt']) + self.ema_params = tuple(param.to(self.device) for param in copy.deepcopy(state_dict['ema'])) + + def add_param_group(self, param_group): + self.optimizer.add_param_group(param_group) + self.rebuild_ema_params = True diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py index 4e15943b5e2e..42872c24dcc9 100644 --- a/nemo/utils/exp_manager.py +++ b/nemo/utils/exp_manager.py @@ -100,10 +100,7 @@ class StepTimingParams: @dataclass class EMAParams: enable: Optional[bool] = False - evaluate_ema_weights_instead: Optional[bool] = False decay: Optional[float] = 0.999 - apply_ema_every_n_steps: Optional[int] = 1 - start_step: Optional[int] = 0 @dataclass @@ -357,12 +354,7 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo trainer.callbacks.insert(0, timing_callback) if cfg.ema.enable: - ema_callback = EMA( - decay=cfg.ema.decay, - apply_ema_every_n_steps=cfg.ema.apply_ema_every_n_steps, - start_step=cfg.ema.start_step, - evaluate_ema_weights_instead=cfg.ema.evaluate_ema_weights_instead, - ) + ema_callback = EMA(decay=cfg.ema.decay,) trainer.callbacks.append(ema_callback) if cfg.create_checkpoint_callback: From 979c51a066fb1eb96f57d3dfcd23e6efb294e9a8 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 19 Oct 2022 12:10:53 +0100 Subject: [PATCH 02/19] Fix replacing weights Signed-off-by: SeanNaren --- nemo/collections/common/callbacks/ema.py | 43 +++++++++++++++++------- nemo/utils/exp_manager.py | 15 +++------ 2 files changed, 35 insertions(+), 23 deletions(-) diff --git a/nemo/collections/common/callbacks/ema.py b/nemo/collections/common/callbacks/ema.py index ca4669ccd1fb..fd65785053b1 100644 --- a/nemo/collections/common/callbacks/ema.py +++ b/nemo/collections/common/callbacks/ema.py @@ -43,6 +43,23 @@ def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") EMAOptimizer(optim, device=pl_module.device, decay=self.decay) for optim in trainer.optimizers ] + def swap_model_weights(self, trainer: "pl.Trainer"): + for optimizer in trainer.optimizers: + assert isinstance(optimizer, EMAOptimizer) + optimizer.switch_main_parameter_weights() + + def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self.swap_model_weights(trainer) + + def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self.swap_model_weights(trainer) + + def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self.swap_model_weights(trainer) + + def on_test_end(self, trainer: "pl.Trainer", pl_module: 
"pl.LightningModule") -> None: + self.swap_model_weights(trainer) + @torch.no_grad() def ema_update(ema_model_tuple, current_model_tuple, decay): @@ -158,6 +175,18 @@ def update(self): ) self.thread.start() + def swap_tensors(self, tensor1, tensor2): + tmp = torch.empty_like(tensor1) + tmp.copy_(tensor1) + tensor1.copy_(tensor2) + tensor2.copy_(tmp) + + def switch_main_parameter_weights(self): + self.join() + + for param, ema_param in zip(self.all_parameters(), self.ema_params): + self.swap_tensors(param.data, ema_param) + @contextlib.contextmanager def swap_ema_weights(self, enabled: bool = True): r""" @@ -170,23 +199,13 @@ def swap_ema_weights(self, enabled: bool = True): enabled (bool): whether the swap should be performed """ - def swap_tensors(tensor1, tensor2): - tmp = torch.empty_like(tensor1) - tmp.copy_(tensor1) - tensor1.copy_(tensor2) - tensor2.copy_(tmp) - if enabled: - self.join() - - for param, ema_param in zip(self.all_parameters(), self.ema_params): - swap_tensors(param.data, ema_param) + self.switch_main_parameter_weights() try: yield finally: if enabled: - for param, ema_param in zip(self.all_parameters(), self.ema_params): - swap_tensors(param.data, ema_param) + self.switch_main_parameter_weights() def __getattr__(self, name): return getattr(self.optimizer, name) diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py index 42872c24dcc9..46622d21bcec 100644 --- a/nemo/utils/exp_manager.py +++ b/nemo/utils/exp_manager.py @@ -897,19 +897,12 @@ def _get_ema_callback(self, trainer) -> Optional[EMA]: return ema_callback def _save_checkpoint(self, trainer, filepath: str) -> None: - super()._save_checkpoint(trainer, filepath) ema_callback = self._get_ema_callback(trainer) if ema_callback is not None: - # save EMA copy of the model as well. - ema_callback.replace_model_weights(trainer.lightning_module) - filepath = self._ema_format_filepath(filepath) - if self.verbose: - rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}") - super()._save_checkpoint(trainer, filepath) - ema_callback.restore_original_weights(trainer.lightning_module) - - def _ema_format_filepath(self, filepath: str) -> str: - return filepath.replace(self.FILE_EXTENSION, f'-EMA{self.FILE_EXTENSION}') + ema_callback.swap_model_weights(trainer) + super()._save_checkpoint(trainer, filepath) + if ema_callback is not None: + ema_callback.swap_model_weights(trainer) def configure_checkpointing( From f8a361e8927fcd58ce24611c62ac5b2b85b3f569 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 19 Oct 2022 16:23:50 +0100 Subject: [PATCH 03/19] Allow swapping of weights be optional Signed-off-by: SeanNaren --- nemo/collections/common/callbacks/ema.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/nemo/collections/common/callbacks/ema.py b/nemo/collections/common/callbacks/ema.py index fd65785053b1..a80ce0d483c7 100644 --- a/nemo/collections/common/callbacks/ema.py +++ b/nemo/collections/common/callbacks/ema.py @@ -31,12 +31,16 @@ class EMA(Callback): Args: decay: The exponential decay used when calculating the moving average. Has to be between 0-1. + apply_ema_every_n_steps: Apply EMA every N steps. + validate_original_weights: Validate the original weights, as apposed to the EMA weights. 
""" - def __init__(self, decay: float): + def __init__(self, decay: float, apply_ema_every_n_steps: int = 1, validate_original_weights: bool = False): if not (0 <= decay <= 1): raise MisconfigurationException("EMA decay value must be between 0 and 1") self.decay = decay + self.apply_ema_every_n_steps = apply_ema_every_n_steps + self.validate_original_weights = validate_original_weights def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: trainer.optimizers = [ @@ -49,16 +53,20 @@ def swap_model_weights(self, trainer: "pl.Trainer"): optimizer.switch_main_parameter_weights() def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - self.swap_model_weights(trainer) + if not self.validate_original_weights: + self.swap_model_weights(trainer) def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - self.swap_model_weights(trainer) + if not self.validate_original_weights: + self.swap_model_weights(trainer) def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - self.swap_model_weights(trainer) + if not self.validate_original_weights: + self.swap_model_weights(trainer) def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - self.swap_model_weights(trainer) + if not self.validate_original_weights: + self.swap_model_weights(trainer) @torch.no_grad() From 0623ddc19c90cc30b9a83b602c4b51fac108f50b Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 19 Oct 2022 17:10:12 +0100 Subject: [PATCH 04/19] Save 2 models Signed-off-by: SeanNaren --- nemo/utils/exp_manager.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py index 46622d21bcec..0b28ecfb5ce4 100644 --- a/nemo/utils/exp_manager.py +++ b/nemo/utils/exp_manager.py @@ -897,13 +897,20 @@ def _get_ema_callback(self, trainer) -> Optional[EMA]: return ema_callback def _save_checkpoint(self, trainer, filepath: str) -> None: + super()._save_checkpoint(trainer, filepath) ema_callback = self._get_ema_callback(trainer) if ema_callback is not None: + # save EMA copy of the model as well. 
ema_callback.swap_model_weights(trainer) - super()._save_checkpoint(trainer, filepath) - if ema_callback is not None: + filepath = self._ema_format_filepath(filepath) + if self.verbose: + rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}") + super()._save_checkpoint(trainer, filepath) ema_callback.swap_model_weights(trainer) + def _ema_format_filepath(self, filepath: str) -> str: + return filepath.replace(self.FILE_EXTENSION, f'-EMA{self.FILE_EXTENSION}') + def configure_checkpointing( trainer: 'pytorch_lightning.Trainer', log_dir: Path, name: str, resume: bool, params: 'DictConfig', From dcca548b53cc16e5b2641723e287d5569ee16b59 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 19 Oct 2022 19:13:48 +0100 Subject: [PATCH 05/19] Use different hook Signed-off-by: SeanNaren --- nemo/collections/common/callbacks/ema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/common/callbacks/ema.py b/nemo/collections/common/callbacks/ema.py index a80ce0d483c7..23de2fabb220 100644 --- a/nemo/collections/common/callbacks/ema.py +++ b/nemo/collections/common/callbacks/ema.py @@ -42,7 +42,7 @@ def __init__(self, decay: float, apply_ema_every_n_steps: int = 1, validate_orig self.apply_ema_every_n_steps = apply_ema_every_n_steps self.validate_original_weights = validate_original_weights - def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: trainer.optimizers = [ EMAOptimizer(optim, device=pl_module.device, decay=self.decay) for optim in trainer.optimizers ] From 74636758d4d6d124c7d4860fced13e7f49e7b952 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 20 Oct 2022 10:41:29 +0100 Subject: [PATCH 06/19] Expose cpu device Signed-off-by: SeanNaren --- nemo/collections/common/callbacks/ema.py | 11 +++++------ nemo/utils/exp_manager.py | 8 +++++++- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/nemo/collections/common/callbacks/ema.py b/nemo/collections/common/callbacks/ema.py index 23de2fabb220..ccaeaefc08fa 100644 --- a/nemo/collections/common/callbacks/ema.py +++ b/nemo/collections/common/callbacks/ema.py @@ -31,21 +31,20 @@ class EMA(Callback): Args: decay: The exponential decay used when calculating the moving average. Has to be between 0-1. - apply_ema_every_n_steps: Apply EMA every N steps. validate_original_weights: Validate the original weights, as apposed to the EMA weights. + cpu_offload: Offload weights to CPU. 
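A hedged usage sketch showing the two ways the callback can be enabled at this point in the series: attaching `EMA` to the Trainer directly, or letting `exp_manager` build it from the `EMAParams` config. Field names follow the exp_manager changes in this commit (`every_n_steps` only appears in a later commit), and the log directory below is a hypothetical placeholder; `model` is assumed to be any LightningModule / NeMo ModelPT.

    import pytorch_lightning as pl

    from nemo.collections.common.callbacks import EMA
    from nemo.utils.exp_manager import exp_manager

    # option 1: plain Lightning, add the callback yourself
    trainer = pl.Trainer(max_epochs=1, callbacks=[EMA(decay=0.999, validate_original_weights=False)])

    # option 2: NeMo exp_manager, driven by the "ema" section of the config
    trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
    exp_manager(
        trainer,
        {
            "explicit_log_dir": "/tmp/ema_run",     # hypothetical output directory
            "ema": {
                "enable": True,
                "decay": 0.999,
                "validate_original_weights": False,
                "cpu_offload": False,
            },
        },
    )
    # trainer.fit(model)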
""" - def __init__(self, decay: float, apply_ema_every_n_steps: int = 1, validate_original_weights: bool = False): + def __init__(self, decay: float, validate_original_weights: bool = False, cpu_offload: bool = False): if not (0 <= decay <= 1): raise MisconfigurationException("EMA decay value must be between 0 and 1") self.decay = decay - self.apply_ema_every_n_steps = apply_ema_every_n_steps self.validate_original_weights = validate_original_weights + self.cpu_offload = cpu_offload def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - trainer.optimizers = [ - EMAOptimizer(optim, device=pl_module.device, decay=self.decay) for optim in trainer.optimizers - ] + device = pl_module.device if not self.cpu_offload else torch.device('cpu') + trainer.optimizers = [EMAOptimizer(optim, device=device, decay=self.decay) for optim in trainer.optimizers] def swap_model_weights(self, trainer: "pl.Trainer"): for optimizer in trainer.optimizers: diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py index 0b28ecfb5ce4..96b142d61aad 100644 --- a/nemo/utils/exp_manager.py +++ b/nemo/utils/exp_manager.py @@ -101,6 +101,8 @@ class StepTimingParams: class EMAParams: enable: Optional[bool] = False decay: Optional[float] = 0.999 + cpu_offload: Optional[bool] = False + validate_original_weights: Optional[bool] = False @dataclass @@ -354,7 +356,11 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo trainer.callbacks.insert(0, timing_callback) if cfg.ema.enable: - ema_callback = EMA(decay=cfg.ema.decay,) + ema_callback = EMA( + decay=cfg.ema.decay, + validate_original_weights=cfg.ema.validate_original_weights, + cpu_offload=cfg.ema.cpu_offload, + ) trainer.callbacks.append(ema_callback) if cfg.create_checkpoint_callback: From def549e5f613a71428b75fa4bb2b19f410b85c4d Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 20 Oct 2022 10:49:11 +0100 Subject: [PATCH 07/19] Add clause to see if this fixes issue with O2 optimizer Signed-off-by: SeanNaren --- nemo/collections/common/callbacks/ema.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/nemo/collections/common/callbacks/ema.py b/nemo/collections/common/callbacks/ema.py index ccaeaefc08fa..1cdd92198dca 100644 --- a/nemo/collections/common/callbacks/ema.py +++ b/nemo/collections/common/callbacks/ema.py @@ -20,6 +20,8 @@ from pytorch_lightning import Callback from pytorch_lightning.utilities.exceptions import MisconfigurationException +from nemo.core.optim.optimizer_with_main_params import MainParamsOptimizerWrapper + class EMA(Callback): """ @@ -139,6 +141,8 @@ def __init__( self.ema_params = () def all_parameters(self): + if isinstance(self.optimizer, MainParamsOptimizerWrapper): + return (param for group in self.optimizer.float16_groups for param in group['params']) return (param for group in self.param_groups for param in group['params']) def step(self, closure=None): From cf280ed8afae157222cdf5e05ff66b152efb40ea Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 24 Oct 2022 15:30:43 +0100 Subject: [PATCH 08/19] Try to get O2 working Signed-off-by: SeanNaren --- nemo/collections/common/callbacks/ema.py | 15 ++++++++++++--- nemo/collections/nlp/parts/nlp_overrides.py | 3 ++- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/nemo/collections/common/callbacks/ema.py b/nemo/collections/common/callbacks/ema.py index 1cdd92198dca..cc9277b7cc48 100644 --- a/nemo/collections/common/callbacks/ema.py +++ b/nemo/collections/common/callbacks/ema.py @@ -129,6 +129,12 @@ class 
EMAOptimizer(torch.optim.Optimizer): def __init__( self, optimizer: torch.optim.Optimizer, device: torch.device, decay: float = 0.9999, ): + + # copy most of the `Optimizer` methods into this instance. `__del__` is skipped in case the optimizer has + # implemented custom logic which we would not want to call on destruction of the `EMAOptimizer`. + # this allows us to use the underlying optimizer without having to go through the EMAOptimizer wrapper. + self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("step", "__del__")} + self.optimizer = optimizer self.decay = decay self.device = device @@ -142,10 +148,10 @@ def __init__( def all_parameters(self): if isinstance(self.optimizer, MainParamsOptimizerWrapper): - return (param for group in self.optimizer.float16_groups for param in group['params']) + return (param for group in self.optimizer.float16_groups for param in group) return (param for group in self.param_groups for param in group['params']) - def step(self, closure=None): + def step(self, closure=None, **kwargs): self.join() if self.first_iteration: @@ -162,7 +168,10 @@ def step(self, closure=None): ) self.rebuild_ema_params = False - loss = self.optimizer.step(closure) + if isinstance(self.optimizer, MainParamsOptimizerWrapper): + loss = self.optimizer.step(**kwargs) + else: + loss = self.optimizer.step(closure) self.update() return loss diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 2e52be81ce34..401da6b06044 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -35,6 +35,7 @@ from torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks import noop_hook from torch.nn.parallel import DistributedDataParallel +from nemo.collections.common.callbacks.ema import EMAOptimizer from nemo.collections.nlp.modules.common.megatron.module import Float16Module from nemo.core.connectors.save_restore_connector import SaveRestoreConnector from nemo.core.optim import MainParamsOptimizerWrapper @@ -597,7 +598,7 @@ def optimizer_step( **kwargs: Any, ) -> None: assert isinstance( - optimizer, MainParamsOptimizerWrapper + optimizer, (MainParamsOptimizerWrapper, EMAOptimizer,) ), "MegatronHalfPrecisionPlugin supports only the optimizer with master parameters" if self.scaler is None: From a9b88a374103cc0d2e2dbdd6aacadbb20abe7af1 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 7 Nov 2022 14:36:57 +0000 Subject: [PATCH 09/19] WIP Signed-off-by: SeanNaren --- nemo/collections/common/callbacks/ema.py | 62 ++++++++++++++++++------ nemo/utils/exp_manager.py | 2 + tests/collections/common/test_ema.py | 27 ++++++++--- 3 files changed, 68 insertions(+), 23 deletions(-) diff --git a/nemo/collections/common/callbacks/ema.py b/nemo/collections/common/callbacks/ema.py index cc9277b7cc48..03d12af9ebde 100644 --- a/nemo/collections/common/callbacks/ema.py +++ b/nemo/collections/common/callbacks/ema.py @@ -14,6 +14,7 @@ import contextlib import copy import threading +from typing import Iterable import pytorch_lightning as pl import torch @@ -34,41 +35,60 @@ class EMA(Callback): Args: decay: The exponential decay used when calculating the moving average. Has to be between 0-1. validate_original_weights: Validate the original weights, as apposed to the EMA weights. + every_n_steps: Apply EMA every N steps. cpu_offload: Offload weights to CPU. 
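A tiny illustration (not part of the patch) of how the `every_n_steps` gating added in this commit behaves: `current_step` is seeded from `trainer.global_step`, and the modulo check runs before the counter is incremented, so with `every_n_steps=2` the EMA update fires on steps 0, 2, 4, and so on.

    current_step, every_n_steps = 0, 2       # current_step would come from trainer.global_step
    updated_at = []
    for _ in range(6):                       # six optimizer steps
        if current_step % every_n_steps == 0:
            updated_at.append(current_step)  # EMAOptimizer.update() would run here
        current_step += 1
    print(updated_at)                        # [0, 2, 4]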
""" - def __init__(self, decay: float, validate_original_weights: bool = False, cpu_offload: bool = False): + def __init__( + self, decay: float, validate_original_weights: bool = False, every_n_steps: int = 1, cpu_offload: bool = False, + ): if not (0 <= decay <= 1): raise MisconfigurationException("EMA decay value must be between 0 and 1") self.decay = decay self.validate_original_weights = validate_original_weights + self.every_n_steps = every_n_steps self.cpu_offload = cpu_offload - def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: device = pl_module.device if not self.cpu_offload else torch.device('cpu') - trainer.optimizers = [EMAOptimizer(optim, device=device, decay=self.decay) for optim in trainer.optimizers] - - def swap_model_weights(self, trainer: "pl.Trainer"): - for optimizer in trainer.optimizers: - assert isinstance(optimizer, EMAOptimizer) - optimizer.switch_main_parameter_weights() + trainer.optimizers = [ + EMAOptimizer( + optim, + device=device, + decay=self.decay, + every_n_steps=self.every_n_steps, + current_step=trainer.global_step, + ) + for optim in trainer.optimizers + ] def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if not self.validate_original_weights: + if self._should_validate_ema_weights(trainer): self.swap_model_weights(trainer) def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if not self.validate_original_weights: + if self._should_validate_ema_weights(trainer): self.swap_model_weights(trainer) def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if not self.validate_original_weights: + if self._should_validate_ema_weights(trainer): self.swap_model_weights(trainer) def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if not self.validate_original_weights: + if self._should_validate_ema_weights(trainer): self.swap_model_weights(trainer) + def _should_validate_ema_weights(self, trainer: "pl.Trainer") -> bool: + return not self.validate_original_weights and self._ema_initialized(trainer) + + def _ema_initialized(self, trainer: "pl.Trainer") -> bool: + return any(isinstance(optimizer, EMAOptimizer) for optimizer in trainer.optimizers) + + def swap_model_weights(self, trainer: "pl.Trainer"): + for optimizer in trainer.optimizers: + assert isinstance(optimizer, EMAOptimizer) + optimizer.switch_main_parameter_weights() + @torch.no_grad() def ema_update(ema_model_tuple, current_model_tuple, decay): @@ -127,7 +147,12 @@ class EMAOptimizer(torch.optim.Optimizer): """ def __init__( - self, optimizer: torch.optim.Optimizer, device: torch.device, decay: float = 0.9999, + self, + optimizer: torch.optim.Optimizer, + device: torch.device, + decay: float = 0.9999, + every_n_steps: int = 1, + current_step: int = 0, ): # copy most of the `Optimizer` methods into this instance. 
`__del__` is skipped in case the optimizer has @@ -138,6 +163,8 @@ def __init__( self.optimizer = optimizer self.decay = decay self.device = device + self.current_step = current_step + self.every_n_steps = every_n_steps self.first_iteration = True self.rebuild_ema_params = True @@ -146,7 +173,7 @@ def __init__( self.ema_params = () - def all_parameters(self): + def all_parameters(self) -> Iterable[torch.Tensor]: if isinstance(self.optimizer, MainParamsOptimizerWrapper): return (param for group in self.optimizer.float16_groups for param in group) return (param for group in self.param_groups for param in group['params']) @@ -173,9 +200,14 @@ def step(self, closure=None, **kwargs): else: loss = self.optimizer.step(closure) - self.update() + if self._should_update_at_step(): + self.update() + self.current_step += 1 return loss + def _should_update_at_step(self) -> bool: + return self.current_step % self.every_n_steps == 0 + @torch.no_grad() def update(self): if self.stream is not None: diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py index 96b142d61aad..b403be2d70bb 100644 --- a/nemo/utils/exp_manager.py +++ b/nemo/utils/exp_manager.py @@ -103,6 +103,7 @@ class EMAParams: decay: Optional[float] = 0.999 cpu_offload: Optional[bool] = False validate_original_weights: Optional[bool] = False + every_n_steps: int = 1 @dataclass @@ -360,6 +361,7 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo decay=cfg.ema.decay, validate_original_weights=cfg.ema.validate_original_weights, cpu_offload=cfg.ema.cpu_offload, + every_n_steps=cfg.ema.every_n_steps, ) trainer.callbacks.append(ema_callback) diff --git a/tests/collections/common/test_ema.py b/tests/collections/common/test_ema.py index ae4f40d52d51..e9cc727e97e7 100644 --- a/tests/collections/common/test_ema.py +++ b/tests/collections/common/test_ema.py @@ -26,6 +26,7 @@ from pytorch_lightning.utilities.types import STEP_OUTPUT from nemo.collections.common.callbacks import EMA +from nemo.collections.common.callbacks.ema import EMAOptimizer from nemo.core import ModelPT from nemo.utils.exp_manager import exp_manager from tests.collections.nlp.test_gpt_model import DEVICE_CAPABILITY @@ -100,10 +101,16 @@ def test_ema_saved_state(self, tmpdir, caplog): """Test to ensure that when we re-load the EMA callback, it loads the EMA weights correctly.""" temp_path = os.path.join(tmpdir, 'saved_state') + def extract_ema_weights(ema_callback, pl_module, trainer): + ema_callback.swap_model_weights(trainer) + weights = [w.detach().clone() for w in pl_module.state_dict().values()] + ema_callback.swap_model_weights(trainer) + return weights + class TerminateCallback(Callback): def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0] - self.saved_ema_weights = ema_callback._ema_model_weights + self.saved_ema_weights = extract_ema_weights(ema_callback, pl_module, trainer) self.pl_module_weights = list(pl_module.state_dict().values()) raise SystemExit @@ -124,7 +131,7 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu exp_manager( trainer, { - "ema": {"enable": True, "evaluate_ema_weights_instead": True}, + "ema": {"enable": True}, "explicit_log_dir": str(temp_path), "checkpoint_callback_params": {"filename": f"{{epoch}}-{{step}}"}, }, @@ -141,9 +148,13 @@ def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") weights = list(pl_module.state_dict().values()) for 
x, y in zip(weights, terminate_callback.pl_module_weights): assert torch.allclose(x.cpu(), y.cpu()) - for x, y in zip(ema_callback._ema_model_weights, terminate_callback.saved_ema_weights): + current_ema_weights = extract_ema_weights(ema_callback, pl_module, trainer) + for x, y in zip(current_ema_weights, terminate_callback.saved_ema_weights): assert torch.allclose(x.cpu(), y.cpu()) - assert ema_callback._cur_step == 8 + + for optimizer in trainer.optimizers: + assert isinstance(optimizer, EMAOptimizer) + assert optimizer.current_step == 8 trainer = Trainer( max_epochs=2, @@ -157,7 +168,7 @@ def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") exp_manager( trainer, { - "ema": {"enable": True, "evaluate_ema_weights_instead": True}, + "ema": {"enable": True}, "explicit_log_dir": str(temp_path), "checkpoint_callback_params": {"filename": f"{{epoch}}-{{step}}"}, }, @@ -181,7 +192,7 @@ def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") exp_manager( trainer, { - "ema": {"enable": True, "evaluate_ema_weights_instead": True}, + "ema": {"enable": True}, "explicit_log_dir": str(temp_path), "checkpoint_callback_params": {"filename": f"{{epoch}}-{{step}}"}, }, @@ -404,7 +415,7 @@ def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMod if ema_callback.evaluate_ema_weights_instead: # todo (sean): shouldn't use the weights buffer to check original weights self._original_weights = list(x.detach().clone() for x in ema_callback._weights_buffer) - if ema_callback.ema_initialized: + if ema_callback._ema_initialized: for ema_weights, module_weights in zip( ema_callback._ema_model_weights, pl_module.state_dict().values() ): @@ -414,6 +425,6 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0] if ema_callback.evaluate_ema_weights_instead: model_weights = list(pl_module.state_dict().values()) - if ema_callback.ema_initialized: + if ema_callback._ema_initialized: for orig_weights, module_weights in zip(self._original_weights, model_weights): torch.allclose(orig_weights, module_weights.cpu()) From e47d31481d370a6ee5bd823b5a5fa6840574c65c Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 8 Nov 2022 15:24:40 +0000 Subject: [PATCH 10/19] Fixes Signed-off-by: SeanNaren --- nemo/collections/common/callbacks/ema.py | 50 ++++++++++++++++++++- nemo/collections/nlp/parts/nlp_overrides.py | 4 +- nemo/utils/exp_manager.py | 26 ++--------- 3 files changed, 53 insertions(+), 27 deletions(-) diff --git a/nemo/collections/common/callbacks/ema.py b/nemo/collections/common/callbacks/ema.py index 03d12af9ebde..f5bcd69782e4 100644 --- a/nemo/collections/common/callbacks/ema.py +++ b/nemo/collections/common/callbacks/ema.py @@ -13,8 +13,11 @@ # limitations under the License. 
import contextlib import copy +import logging +import os import threading -from typing import Iterable +import warnings +from typing import Any, Dict, Iterable import pytorch_lightning as pl import torch @@ -49,7 +52,7 @@ def __init__( self.every_n_steps = every_n_steps self.cpu_offload = cpu_offload - def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: device = pl_module.device if not self.cpu_offload else torch.device('cpu') trainer.optimizers = [ EMAOptimizer( @@ -89,6 +92,44 @@ def swap_model_weights(self, trainer: "pl.Trainer"): assert isinstance(optimizer, EMAOptimizer) optimizer.switch_main_parameter_weights() + @contextlib.contextmanager + def save_original_optimizer_state(self, trainer: "pl.Trainer"): + for optimizer in trainer.optimizers: + assert isinstance(optimizer, EMAOptimizer) + optimizer.save_original_optimizer_state = True + try: + yield + finally: + for optimizer in trainer.optimizers: + optimizer.save_original_optimizer_state = False + + def on_load_checkpoint( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] + ) -> None: + checkpoint_callback = trainer.checkpoint_callback + + if trainer.ckpt_path and checkpoint_callback is not None and 'NeMo' in type(checkpoint_callback).__name__: + ext = checkpoint_callback.FILE_EXTENSION + if trainer.ckpt_path.endswith(f'-EMA{ext}'): + logging.info( + "loading EMA based weights. " + "The callback will treat the loaded EMA weights as the main weights" + " and create a new EMA copy when training." + ) + return + ema_path = trainer.ckpt_path.replace(ext, f'-EMA{ext}') + if os.path.exists(ema_path): + ema_state_dict = torch.load(ema_path, map_location=torch.device('cpu')) + checkpoint['optimizer_states'] = ema_state_dict['optimizer_states'] + del ema_state_dict + logging.info("EMA weights have been loaded successfully. 
Continuing training with saved EMA weights.") + else: + warnings.warn( + "we were unable to find the associated EMA weights when re-loading, " + "training will start with new EMA weights.", + UserWarning, + ) + @torch.no_grad() def ema_update(ema_model_tuple, current_model_tuple, decay): @@ -165,6 +206,7 @@ def __init__( self.device = device self.current_step = current_step self.every_n_steps = every_n_steps + self.save_original_optimizer_state = False self.first_iteration = True self.rebuild_ema_params = True @@ -272,6 +314,9 @@ def join(self): def state_dict(self): self.join() + if self.save_original_optimizer_state: + return self.optimizer.state_dict() + state_dict = { 'opt': self.optimizer.state_dict(), 'ema': self.ema_params, @@ -283,6 +328,7 @@ def load_state_dict(self, state_dict): self.optimizer.load_state_dict(state_dict['opt']) self.ema_params = tuple(param.to(self.device) for param in copy.deepcopy(state_dict['ema'])) + self.rebuild_ema_params = False def add_param_group(self, param_group): self.optimizer.add_param_group(param_group) diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 401da6b06044..9bf2acf5809f 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -22,16 +22,16 @@ import pytorch_lightning as pl import torch +from lightning_lite.plugins import ClusterEnvironment +from lightning_lite.utilities.types import _PATH from omegaconf import OmegaConf from pytorch_lightning.overrides import LightningDistributedModule -from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin from pytorch_lightning.strategies.ddp import DDPStrategy from pytorch_lightning.trainer.trainer import Trainer from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.fetching import DataFetcher -from pytorch_lightning.utilities.types import _PATH from torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks import noop_hook from torch.nn.parallel import DistributedDataParallel diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py index b403be2d70bb..ed38e58fafdc 100644 --- a/nemo/utils/exp_manager.py +++ b/nemo/utils/exp_manager.py @@ -32,7 +32,6 @@ from omegaconf import DictConfig, OmegaConf, open_dict from pytorch_lightning.callbacks import Callback, ModelCheckpoint from pytorch_lightning.callbacks.timer import Interval, Timer -from pytorch_lightning.loggers import LoggerCollection as _LoggerCollection from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger from pytorch_lightning.loops import TrainingEpochLoop from pytorch_lightning.strategies.ddp import DDPStrategy @@ -658,24 +657,6 @@ def get_git_diff(): return "{}\n".format(err.output.decode("utf-8")) -class LoggerList(_LoggerCollection): - """ A thin wrapper on Lightning's LoggerCollection such that name and version are better aligned with exp_manager - """ - - def __init__(self, _logger_iterable, nemo_name=None, nemo_version=""): - super().__init__(_logger_iterable) - self._nemo_name = nemo_name - self._nemo_version = nemo_version - - @property - def name(self) -> str: - return self._nemo_name - - @property - def version(self) -> str: - return self._nemo_version - - def configure_loggers( trainer: 'pytorch_lightning.Trainer', exp_dir: [Path, str], @@ -718,9 +699,6 
@@ def configure_loggers( logger_list.append(wandb_logger) logging.info("WandBLogger has been set up") - logger_list = ( - LoggerList(logger_list, nemo_name=name, nemo_version=version) if len(logger_list) > 1 else logger_list[0] - ) trainer._logger_connector.configure_logger(logger_list) @@ -905,8 +883,10 @@ def _get_ema_callback(self, trainer) -> Optional[EMA]: return ema_callback def _save_checkpoint(self, trainer, filepath: str) -> None: - super()._save_checkpoint(trainer, filepath) ema_callback = self._get_ema_callback(trainer) + with ema_callback.save_original_optimizer_state(trainer): + super()._save_checkpoint(trainer, filepath) + if ema_callback is not None: # save EMA copy of the model as well. ema_callback.swap_model_weights(trainer) From a4b3e53e93031e17c258ad0fad8ef930541c659d Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 29 Nov 2022 14:21:12 +0000 Subject: [PATCH 11/19] Fixes to tests Signed-off-by: SeanNaren --- nemo/collections/common/callbacks/ema.py | 18 ++- tests/collections/common/test_ema.py | 152 +++++++---------------- 2 files changed, 62 insertions(+), 108 deletions(-) diff --git a/nemo/collections/common/callbacks/ema.py b/nemo/collections/common/callbacks/ema.py index f5bcd69782e4..066a23e8dd22 100644 --- a/nemo/collections/common/callbacks/ema.py +++ b/nemo/collections/common/callbacks/ema.py @@ -63,6 +63,7 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - current_step=trainer.global_step, ) for optim in trainer.optimizers + if not isinstance(optim, EMAOptimizer) ] def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: @@ -121,13 +122,14 @@ def on_load_checkpoint( if os.path.exists(ema_path): ema_state_dict = torch.load(ema_path, map_location=torch.device('cpu')) checkpoint['optimizer_states'] = ema_state_dict['optimizer_states'] + for optimizer_state in checkpoint['optimizer_states']: + optimizer_state['ema'] = list(ema_state_dict['state_dict'].values()) del ema_state_dict logging.info("EMA weights have been loaded successfully. Continuing training with saved EMA weights.") else: - warnings.warn( - "we were unable to find the associated EMA weights when re-loading, " - "training will start with new EMA weights.", - UserWarning, + raise MisconfigurationException( + "Unable to find the associated EMA weights when re-loading, " + f"training will start with new EMA weights. 
Expected them to be at: {ema_path}", ) @@ -320,6 +322,10 @@ def state_dict(self): state_dict = { 'opt': self.optimizer.state_dict(), 'ema': self.ema_params, + 'current_step': self.current_step, + 'decay': self.decay, + 'every_n_steps': self.every_n_steps, + 'device': self.device, } return state_dict @@ -328,6 +334,10 @@ def load_state_dict(self, state_dict): self.optimizer.load_state_dict(state_dict['opt']) self.ema_params = tuple(param.to(self.device) for param in copy.deepcopy(state_dict['ema'])) + self.current_step = state_dict['current_step'] + self.decay = state_dict['decay'] + self.device = state_dict['device'] + self.every_n_steps = state_dict['every_n_steps'] self.rebuild_ema_params = False def add_param_group(self, param_group): diff --git a/tests/collections/common/test_ema.py b/tests/collections/common/test_ema.py index e9cc727e97e7..3929679fc4fb 100644 --- a/tests/collections/common/test_ema.py +++ b/tests/collections/common/test_ema.py @@ -32,6 +32,14 @@ from tests.collections.nlp.test_gpt_model import DEVICE_CAPABILITY +def extract_ema_weights(pl_module, trainer): + ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0] + ema_callback.swap_model_weights(trainer) + weights = [w.detach().clone() for w in pl_module.state_dict().values()] + ema_callback.swap_model_weights(trainer) + return weights + + class OnesDataset(torch.utils.data.Dataset): def __init__(self, dataset_len): super().__init__() @@ -90,27 +98,15 @@ def test_ema_value(self): with pytest.raises(MisconfigurationException, match="between 0 and 1"): EMA(decay=2) - @mock.patch('nemo.collections.common.callbacks.ema.apex_available', False) - def test_ema_apex_unavailable(self): - with pytest.warns(UserWarning, match="EMA has better performance when Apex is installed"): - EMA(decay=0.999) - @pytest.mark.unit @pytest.mark.run_only_on('GPU') def test_ema_saved_state(self, tmpdir, caplog): """Test to ensure that when we re-load the EMA callback, it loads the EMA weights correctly.""" temp_path = os.path.join(tmpdir, 'saved_state') - def extract_ema_weights(ema_callback, pl_module, trainer): - ema_callback.swap_model_weights(trainer) - weights = [w.detach().clone() for w in pl_module.state_dict().values()] - ema_callback.swap_model_weights(trainer) - return weights - class TerminateCallback(Callback): def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0] - self.saved_ema_weights = extract_ema_weights(ema_callback, pl_module, trainer) + self.saved_ema_weights = extract_ema_weights(pl_module, trainer) self.pl_module_weights = list(pl_module.state_dict().values()) raise SystemExit @@ -144,11 +140,10 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu class CheckStateCallback(Callback): def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0] weights = list(pl_module.state_dict().values()) for x, y in zip(weights, terminate_callback.pl_module_weights): assert torch.allclose(x.cpu(), y.cpu()) - current_ema_weights = extract_ema_weights(ema_callback, pl_module, trainer) + current_ema_weights = extract_ema_weights(pl_module, trainer) for x, y in zip(current_ema_weights, terminate_callback.saved_ema_weights): assert torch.allclose(x.cpu(), y.cpu()) @@ -214,12 +209,14 @@ def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") exp_manager( 
trainer, { - "ema": {"enable": True, "evaluate_ema_weights_instead": True}, + "ema": {"enable": True, "validate_original_weights": True}, "explicit_log_dir": str(temp_path), "checkpoint_callback_params": {"filename": f"{{epoch}}-{{step}}"}, }, ) - with pytest.warns(UserWarning, match="we were unable to find the associated EMA weights when re-loading"): + with pytest.raises( + MisconfigurationException, match="Unable to find the associated EMA weights when re-loading" + ): trainer.fit(model, ckpt_path=resume_path) @pytest.mark.unit @@ -232,70 +229,23 @@ def test_exp_manager_ema_weights(self, tmpdir): exp_manager( trainer, { - "ema": {"enable": True, "evaluate_ema_weights_instead": True}, + "ema": {"enable": True, "validate_original_weights": True}, "explicit_log_dir": str(tmp_path), "checkpoint_callback_params": {"filename": f"{{epoch}}-{{step}}"}, }, ) assert any(isinstance(callback, EMA) for callback in trainer.callbacks) trainer.fit(model) + ema_weights = extract_ema_weights(model, trainer) assert os.path.exists(tmp_path / "checkpoints/epoch=0-step=8.ckpt") ema_path = tmp_path / "checkpoints/epoch=0-step=8-EMA.ckpt" assert os.path.exists(ema_path) duplicate_model = ExampleModel.load_from_checkpoint(str(ema_path)) - ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0] - for saved_weight, ema_weight in zip(duplicate_model.state_dict().values(), ema_callback._ema_model_weights): + for saved_weight, ema_weight in zip(duplicate_model.state_dict().values(), ema_weights): assert torch.allclose(saved_weight.cpu(), ema_weight.cpu()) - @pytest.mark.unit - @pytest.mark.run_only_on('GPU') - def test_ema_save_in_callback(self, tmpdir): - """Test to ensure when `save_ema_weights_in_callback_state` is enabled, we save to the callback state.""" - temp_path = os.path.join(tmpdir, 'saved_state') - - model = ExampleModel() - - trainer = Trainer( - max_epochs=2, - limit_val_batches=1, - limit_train_batches=16, - logger=False, - val_check_interval=0.5, - enable_checkpointing=False, - accelerator='gpu', - devices=1, - callbacks=[EMA(decay=0.999, save_ema_weights_in_callback_state=True, evaluate_ema_weights_instead=True)], - ) - exp_manager( - trainer, - {"explicit_log_dir": str(temp_path), "checkpoint_callback_params": {"filename": f"{{epoch}}-{{step}}"},}, - ) - trainer.fit(model=model) - - resume_path = os.path.join(temp_path, "checkpoints/epoch=0-step=8.ckpt") - callback = EMA(decay=0.999, save_ema_weights_in_callback_state=True) - - class AssertCallback(Callback): - def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - assert callback._ema_model_weights is not None - - model = ExampleModel() - - trainer = Trainer( - max_epochs=2, - limit_val_batches=1, - limit_train_batches=16, - logger=False, - val_check_interval=0.5, - enable_checkpointing=False, - accelerator='gpu', - devices=1, - callbacks=[callback, AssertCallback()], - ) - trainer.fit(model, ckpt_path=resume_path) - class TestEMATrain: @pytest.mark.unit @@ -314,41 +264,31 @@ class TestEMATrain: ], ) @pytest.mark.parametrize("accumulate_grad_batches", [1, 2]) - @pytest.mark.parametrize("evaluate_ema_weights_instead", [True, False]) - @pytest.mark.parametrize("apex_available_mock", [True, False]) - @pytest.mark.run_only_on('GPU') + @pytest.mark.parametrize("validate_original_weights", [True, False]) def test_ema_run_cuda( - self, - test_data_dir, - precision, - accumulate_grad_batches, - evaluate_ema_weights_instead, - apex_available_mock, - tmpdir, + self, test_data_dir, precision, 
accumulate_grad_batches, validate_original_weights, tmpdir, ): - with mock.patch('nemo.collections.common.callbacks.ema.apex_available', apex_available_mock): - self.run_training_test( - accumulate_grad_batches=accumulate_grad_batches, - evaluate_ema_weights_instead=evaluate_ema_weights_instead, - accelerator='gpu', - precision=precision, - tmpdir=tmpdir, - ) + self.run_training_test( + accumulate_grad_batches=accumulate_grad_batches, + validate_original_weights=validate_original_weights, + accelerator='gpu', + precision=precision, + tmpdir=tmpdir, + ) @pytest.mark.unit @pytest.mark.parametrize("accumulate_grad_batches", [1, 2]) - @pytest.mark.parametrize("evaluate_ema_weights_instead", [True, False]) - @pytest.mark.run_only_on('GPU') - def test_ema_run_cpu(self, test_data_dir, accumulate_grad_batches, evaluate_ema_weights_instead, tmpdir): + @pytest.mark.parametrize("validate_original_weights", [True, False]) + def test_ema_run_cpu(self, test_data_dir, accumulate_grad_batches, validate_original_weights, tmpdir): self.run_training_test( accumulate_grad_batches=accumulate_grad_batches, - evaluate_ema_weights_instead=evaluate_ema_weights_instead, + validate_original_weights=validate_original_weights, accelerator='cpu', precision=32, tmpdir=tmpdir, ) - def run_training_test(self, accumulate_grad_batches, evaluate_ema_weights_instead, accelerator, precision, tmpdir): + def run_training_test(self, accumulate_grad_batches, validate_original_weights, accelerator, precision, tmpdir): pl.seed_everything(123) model = ExampleModel() trainer = Trainer( @@ -367,13 +307,14 @@ def run_training_test(self, accumulate_grad_batches, evaluate_ema_weights_instea exp_manager( trainer, { - "ema": {"enable": True, "evaluate_ema_weights_instead": evaluate_ema_weights_instead}, + "ema": {"enable": True, "validate_original_weights": validate_original_weights}, "explicit_log_dir": str(tmpdir), "checkpoint_callback_params": {"filename": f"{{epoch}}-{{step}}"}, }, ) # add the check callback after the exp manager has made modifications. 
trainer.callbacks.append(EMAAssertCallback()) + trainer.callbacks.insert(0, EMAValidationAssertCallback()) trainer.fit(model=model, val_dataloaders=model.train_dataloader()) @@ -383,16 +324,15 @@ def __init__(self): def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: model_weights = list(pl_module.state_dict().values()) - ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0] - for x, y in zip(model_weights, ema_callback._ema_model_weights): + ema_weights = extract_ema_weights(pl_module, trainer) + for x, y in zip(model_weights, ema_weights): assert torch.allclose(x, y) def on_train_batch_start( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int ) -> None: - ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0] # saved for manual calculation of ema to compare against implementation - self._before_calc_ema_weights = deepcopy(ema_callback._ema_model_weights) + self._before_calc_ema_weights = extract_ema_weights(pl_module, trainer) def on_train_batch_end( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int @@ -402,29 +342,33 @@ def on_train_batch_end( return ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0] decay = ema_callback.decay + ema_weights = extract_ema_weights(pl_module, trainer) expected_ema_weights = [] for orig_weight, ema_weight in zip(list(pl_module.state_dict().values()), self._before_calc_ema_weights): expected_ema_weight = orig_weight * (1 - decay) + ema_weight * decay expected_ema_weights.append(expected_ema_weight) - for actual_ema_weight, expected_ema_weight in zip(ema_callback._ema_model_weights, expected_ema_weights): + for actual_ema_weight, expected_ema_weight in zip(ema_weights, expected_ema_weights): assert torch.allclose(actual_ema_weight, expected_ema_weight) + +class EMAValidationAssertCallback(Callback): def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0] - if ema_callback.evaluate_ema_weights_instead: - # todo (sean): shouldn't use the weights buffer to check original weights - self._original_weights = list(x.detach().clone() for x in ema_callback._weights_buffer) + self._original_weights = list(pl_module.state_dict().values()) + self._ema_weights = extract_ema_weights(pl_module, trainer) + # call original EMA function + super().on_validation_start(trainer, pl_module) + if not ema_callback.validate_original_weights: if ema_callback._ema_initialized: - for ema_weights, module_weights in zip( - ema_callback._ema_model_weights, pl_module.state_dict().values() - ): + # check model weights are now EMA weights + for ema_weights, module_weights in zip(self._ema_weights, pl_module.state_dict().values()): torch.allclose(ema_weights, module_weights) def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0] - if ema_callback.evaluate_ema_weights_instead: + if not ema_callback.validate_original_weights: model_weights = list(pl_module.state_dict().values()) if ema_callback._ema_initialized: for orig_weights, module_weights in zip(self._original_weights, model_weights): - torch.allclose(orig_weights, module_weights.cpu()) + torch.allclose(orig_weights.cpu(), module_weights.cpu()) From 087d21f2b94466cdae22219b42380e75bf7bcf75 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: 
Tue, 29 Nov 2022 14:48:25 +0000 Subject: [PATCH 12/19] Add guard Signed-off-by: SeanNaren --- nemo/utils/exp_manager.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py index f81d96bbf791..8670b22c5567 100644 --- a/nemo/utils/exp_manager.py +++ b/nemo/utils/exp_manager.py @@ -916,19 +916,19 @@ def _del_model_without_trainer(self, filepath: str) -> None: except: logging.info(f"Tried to remove checkpoint: {filepath} but failed.") - def _get_ema_callback(self, trainer) -> Optional[EMA]: + def _ema_callback(self, trainer: 'pytorch_lightning.Trainer') -> Optional[EMA]: ema_callback = None for callback in trainer.callbacks: if isinstance(callback, EMA): ema_callback = callback return ema_callback - def _save_checkpoint(self, trainer, filepath: str) -> None: - ema_callback = self._get_ema_callback(trainer) - with ema_callback.save_original_optimizer_state(trainer): - super()._save_checkpoint(trainer, filepath) - + def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) -> None: + ema_callback = self._ema_callback(trainer) if ema_callback is not None: + with ema_callback.save_original_optimizer_state(trainer): + super()._save_checkpoint(trainer, filepath) + # save EMA copy of the model as well. ema_callback.swap_model_weights(trainer) filepath = self._ema_format_filepath(filepath) @@ -936,6 +936,8 @@ def _save_checkpoint(self, trainer, filepath: str) -> None: rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}") super()._save_checkpoint(trainer, filepath) ema_callback.swap_model_weights(trainer) + else: + super()._save_checkpoint(trainer, filepath) def _ema_format_filepath(self, filepath: str) -> str: return filepath.replace(self.FILE_EXTENSION, f'-EMA{self.FILE_EXTENSION}') From 3e31eff80bf67b85509705897a3f804ec5893904 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 29 Nov 2022 14:52:41 +0000 Subject: [PATCH 13/19] Remove import Signed-off-by: SeanNaren --- nemo/collections/common/callbacks/ema.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nemo/collections/common/callbacks/ema.py b/nemo/collections/common/callbacks/ema.py index 066a23e8dd22..9573c3e8ccd5 100644 --- a/nemo/collections/common/callbacks/ema.py +++ b/nemo/collections/common/callbacks/ema.py @@ -16,7 +16,6 @@ import logging import os import threading -import warnings from typing import Any, Dict, Iterable import pytorch_lightning as pl From ecb4415e319f71dd27898847803151ecd4cc9837 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 29 Nov 2022 15:57:39 +0000 Subject: [PATCH 14/19] Add guard Signed-off-by: SeanNaren --- tests/collections/common/test_ema.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/collections/common/test_ema.py b/tests/collections/common/test_ema.py index 3929679fc4fb..79d3473b37d2 100644 --- a/tests/collections/common/test_ema.py +++ b/tests/collections/common/test_ema.py @@ -265,6 +265,7 @@ class TestEMATrain: ) @pytest.mark.parametrize("accumulate_grad_batches", [1, 2]) @pytest.mark.parametrize("validate_original_weights", [True, False]) + @pytest.mark.run_only_on('GPU') def test_ema_run_cuda( self, test_data_dir, precision, accumulate_grad_batches, validate_original_weights, tmpdir, ): From 84582e5a87fb8b24300b7e2b8ad4e0c92fc9c564 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 29 Nov 2022 16:33:22 +0000 Subject: [PATCH 15/19] Add comment Signed-off-by: SeanNaren --- nemo/collections/common/callbacks/ema.py | 1 + 1 file changed, 1 insertion(+) diff --git 
a/nemo/collections/common/callbacks/ema.py b/nemo/collections/common/callbacks/ema.py index 9573c3e8ccd5..53eafc56e3a7 100644 --- a/nemo/collections/common/callbacks/ema.py +++ b/nemo/collections/common/callbacks/ema.py @@ -122,6 +122,7 @@ def on_load_checkpoint( ema_state_dict = torch.load(ema_path, map_location=torch.device('cpu')) checkpoint['optimizer_states'] = ema_state_dict['optimizer_states'] for optimizer_state in checkpoint['optimizer_states']: + # update the ema state with the EMA saved model weights optimizer_state['ema'] = list(ema_state_dict['state_dict'].values()) del ema_state_dict logging.info("EMA weights have been loaded successfully. Continuing training with saved EMA weights.") From 1a508a3219f7f7ee6d8f262be44ef8de257b02fc Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 30 Nov 2022 14:41:20 +0000 Subject: [PATCH 16/19] Remove overwrite Signed-off-by: SeanNaren --- nemo/collections/common/callbacks/ema.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/nemo/collections/common/callbacks/ema.py b/nemo/collections/common/callbacks/ema.py index 53eafc56e3a7..4bf901317a05 100644 --- a/nemo/collections/common/callbacks/ema.py +++ b/nemo/collections/common/callbacks/ema.py @@ -197,12 +197,6 @@ def __init__( every_n_steps: int = 1, current_step: int = 0, ): - - # copy most of the `Optimizer` methods into this instance. `__del__` is skipped in case the optimizer has - # implemented custom logic which we would not want to call on destruction of the `EMAOptimizer`. - # this allows us to use the underlying optimizer without having to go through the EMAOptimizer wrapper. - self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("step", "__del__")} - self.optimizer = optimizer self.decay = decay self.device = device From 24968c896f661488c6ffb1eba853b6514c742450 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 30 Nov 2022 16:27:03 +0000 Subject: [PATCH 17/19] Add BatchNorm, currently tests fail Signed-off-by: SeanNaren --- tests/collections/common/test_ema.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/collections/common/test_ema.py b/tests/collections/common/test_ema.py index 79d3473b37d2..e63bfaa09941 100644 --- a/tests/collections/common/test_ema.py +++ b/tests/collections/common/test_ema.py @@ -57,6 +57,7 @@ def __init__(self, *args, **kwargs): cfg = OmegaConf.structured({}) super().__init__(cfg) self.l1 = torch.nn.modules.Linear(in_features=2, out_features=1) + self.bn = torch.nn.BatchNorm1d(1) def train_dataloader(self): dataset = OnesDataset(16) @@ -67,7 +68,7 @@ def val_dataloader(self): return torch.utils.data.DataLoader(dataset, batch_size=2) def forward(self, batch): - output = self.l1(batch) + output = self.bn(self.l1(batch)) return torch.nn.functional.l1_loss(output, torch.zeros(output.size()).to(output.device)) def validation_step(self, batch, batch_idx): From 69bd5025f828a0d9901cc34df201454561b0bf98 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 9 Dec 2022 15:59:18 +0000 Subject: [PATCH 18/19] Fix tests/functionality for batch norm Signed-off-by: SeanNaren --- nemo/collections/common/callbacks/ema.py | 40 ++++++++----- nemo/utils/exp_manager.py | 11 ++-- tests/collections/common/test_ema.py | 74 +++++++++++------------- 3 files changed, 64 insertions(+), 61 deletions(-) diff --git a/nemo/collections/common/callbacks/ema.py b/nemo/collections/common/callbacks/ema.py index 4bf901317a05..56db703aab28 100644 --- a/nemo/collections/common/callbacks/ema.py +++ b/nemo/collections/common/callbacks/ema.py @@ 
-23,8 +23,6 @@ from pytorch_lightning import Callback from pytorch_lightning.utilities.exceptions import MisconfigurationException -from nemo.core.optim.optimizer_with_main_params import MainParamsOptimizerWrapper - class EMA(Callback): """ @@ -87,10 +85,21 @@ def _should_validate_ema_weights(self, trainer: "pl.Trainer") -> bool: def _ema_initialized(self, trainer: "pl.Trainer") -> bool: return any(isinstance(optimizer, EMAOptimizer) for optimizer in trainer.optimizers) - def swap_model_weights(self, trainer: "pl.Trainer"): + def swap_model_weights(self, trainer: "pl.Trainer", saving_ema_model: bool = False): for optimizer in trainer.optimizers: assert isinstance(optimizer, EMAOptimizer) - optimizer.switch_main_parameter_weights() + optimizer.switch_main_parameter_weights(saving_ema_model) + + @contextlib.contextmanager + def save_ema_model(self, trainer: "pl.Trainer"): + """ + Saves an EMA copy of the model + EMA optimizer states for resume. + """ + self.swap_model_weights(trainer, saving_ema_model=True) + try: + yield + finally: + self.swap_model_weights(trainer, saving_ema_model=False) @contextlib.contextmanager def save_original_optimizer_state(self, trainer: "pl.Trainer"): @@ -120,10 +129,11 @@ def on_load_checkpoint( ema_path = trainer.ckpt_path.replace(ext, f'-EMA{ext}') if os.path.exists(ema_path): ema_state_dict = torch.load(ema_path, map_location=torch.device('cpu')) + + # this is wrong, basically when we save the EMA weights, optimizer_states actually contains the model parameters + # as we swapped the model parameters with the state dict parameters. + # we could enforce that if you trained with EMA and want to continue training checkpoint['optimizer_states'] = ema_state_dict['optimizer_states'] - for optimizer_state in checkpoint['optimizer_states']: - # update the ema state with the EMA saved model weights - optimizer_state['ema'] = list(ema_state_dict['state_dict'].values()) del ema_state_dict logging.info("EMA weights have been loaded successfully. 
Continuing training with saved EMA weights.") else: @@ -210,10 +220,9 @@ def __init__( self.thread = None self.ema_params = () + self.in_saving_ema_model_context = False def all_parameters(self) -> Iterable[torch.Tensor]: - if isinstance(self.optimizer, MainParamsOptimizerWrapper): - return (param for group in self.optimizer.float16_groups for param in group) return (param for group in self.param_groups for param in group['params']) def step(self, closure=None, **kwargs): @@ -233,10 +242,7 @@ def step(self, closure=None, **kwargs): ) self.rebuild_ema_params = False - if isinstance(self.optimizer, MainParamsOptimizerWrapper): - loss = self.optimizer.step(**kwargs) - else: - loss = self.optimizer.step(closure) + loss = self.optimizer.step(closure) if self._should_update_at_step(): self.update() @@ -271,9 +277,9 @@ def swap_tensors(self, tensor1, tensor2): tensor1.copy_(tensor2) tensor2.copy_(tmp) - def switch_main_parameter_weights(self): + def switch_main_parameter_weights(self, saving_ema_model: bool = False): self.join() - + self.in_saving_ema_model_context = saving_ema_model for param, ema_param in zip(self.all_parameters(), self.ema_params): self.swap_tensors(param.data, ema_param) @@ -313,9 +319,11 @@ def state_dict(self): if self.save_original_optimizer_state: return self.optimizer.state_dict() + # if we are in the context of saving an EMA model, the EMA weights are in the modules' actual weights + ema_params = self.ema_params if not self.in_saving_ema_model_context else list(self.all_parameters()) state_dict = { 'opt': self.optimizer.state_dict(), - 'ema': self.ema_params, + 'ema': ema_params, 'current_step': self.current_step, 'decay': self.decay, 'every_n_steps': self.every_n_steps, diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py index 8670b22c5567..00046fec08de 100644 --- a/nemo/utils/exp_manager.py +++ b/nemo/utils/exp_manager.py @@ -930,12 +930,11 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) super()._save_checkpoint(trainer, filepath) # save EMA copy of the model as well. - ema_callback.swap_model_weights(trainer) - filepath = self._ema_format_filepath(filepath) - if self.verbose: - rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}") - super()._save_checkpoint(trainer, filepath) - ema_callback.swap_model_weights(trainer) + with ema_callback.save_ema_model(trainer): + filepath = self._ema_format_filepath(filepath) + if self.verbose: + rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}") + super()._save_checkpoint(trainer, filepath) else: super()._save_checkpoint(trainer, filepath) diff --git a/tests/collections/common/test_ema.py b/tests/collections/common/test_ema.py index e63bfaa09941..16de01e0fbe4 100644 --- a/tests/collections/common/test_ema.py +++ b/tests/collections/common/test_ema.py @@ -13,9 +13,7 @@ # limitations under the License. 
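The _save_checkpoint override above now writes two files per save: the regular checkpoint, and a -EMA sibling produced inside save_ema_model, which swaps the EMA tensors into the live parameters and swaps them back afterwards. Below is the same in-place swap pattern in isolation, stripped of the Lightning/NeMo plumbing; the helper names are made up for this sketch and are not the NeMo classes themselves:

import contextlib
import torch


def swap_tensors(t1: torch.Tensor, t2: torch.Tensor) -> None:
    # Exchange the contents of two tensors in place via a single temporary buffer.
    tmp = torch.empty_like(t1)
    tmp.copy_(t1)
    t1.copy_(t2)
    t2.copy_(tmp)


@contextlib.contextmanager
def ema_weights_active(params, ema_params):
    # Temporarily place the EMA values in the live parameters; restore the originals on exit.
    for p, e in zip(params, ema_params):
        swap_tensors(p.data, e)
    try:
        yield
    finally:
        for p, e in zip(params, ema_params):
            swap_tensors(p.data, e)


# Usage: write an EMA checkpoint without keeping a second full copy of the model.
model = torch.nn.Linear(4, 2)
ema_params = [p.detach().clone() for p in model.parameters()]
with ema_weights_active(list(model.parameters()), ema_params):
    torch.save(model.state_dict(), "/tmp/model-EMA.pt")  # illustrative path
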
import os.path -from copy import deepcopy from typing import Any, Dict, Union -from unittest import mock import pytest import pytorch_lightning as pl @@ -35,41 +33,44 @@ def extract_ema_weights(pl_module, trainer): ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0] ema_callback.swap_model_weights(trainer) - weights = [w.detach().clone() for w in pl_module.state_dict().values()] + weights = extract_weights(pl_module) ema_callback.swap_model_weights(trainer) return weights -class OnesDataset(torch.utils.data.Dataset): - def __init__(self, dataset_len): - super().__init__() - self.__dataset_len = dataset_len +def extract_weights(pl_module): + return [w.detach().clone() for w in pl_module.parameters()] - def __getitem__(self, *args): - return torch.ones(2) + +class RandomDataset(torch.utils.data.Dataset): + def __init__(self, size, length): + self.len = length + self.data = torch.randn(length, size) + + def __getitem__(self, index): + return self.data[index] def __len__(self): - return self.__dataset_len + return self.len class ExampleModel(ModelPT): def __init__(self, *args, **kwargs): cfg = OmegaConf.structured({}) super().__init__(cfg) - self.l1 = torch.nn.modules.Linear(in_features=2, out_features=1) - self.bn = torch.nn.BatchNorm1d(1) + self.l1 = torch.nn.modules.Linear(in_features=32, out_features=32) + self.bn = torch.nn.BatchNorm1d(32) def train_dataloader(self): - dataset = OnesDataset(16) + dataset = RandomDataset(32, 16) return torch.utils.data.DataLoader(dataset, batch_size=2) def val_dataloader(self): - dataset = OnesDataset(10) + dataset = RandomDataset(32, 16) return torch.utils.data.DataLoader(dataset, batch_size=2) def forward(self, batch): - output = self.bn(self.l1(batch)) - return torch.nn.functional.l1_loss(output, torch.zeros(output.size()).to(output.device)) + return self.l1(self.bn(batch)).sum() def validation_step(self, batch, batch_idx): return self(batch) @@ -78,7 +79,7 @@ def training_step(self, batch, batch_idx): return self(batch) def configure_optimizers(self): - return torch.optim.Adam(self.parameters(), lr=0.1) + return torch.optim.SGD(self.parameters(), lr=1e-3) def list_available_models(self): pass @@ -108,7 +109,7 @@ def test_ema_saved_state(self, tmpdir, caplog): class TerminateCallback(Callback): def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self.saved_ema_weights = extract_ema_weights(pl_module, trainer) - self.pl_module_weights = list(pl_module.state_dict().values()) + self.pl_module_weights = extract_weights(pl_module) raise SystemExit model = ExampleModel() @@ -141,7 +142,7 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu class CheckStateCallback(Callback): def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - weights = list(pl_module.state_dict().values()) + weights = extract_weights(pl_module) for x, y in zip(weights, terminate_callback.pl_module_weights): assert torch.allclose(x.cpu(), y.cpu()) current_ema_weights = extract_ema_weights(pl_module, trainer) @@ -309,7 +310,7 @@ def run_training_test(self, accumulate_grad_batches, validate_original_weights, exp_manager( trainer, { - "ema": {"enable": True, "validate_original_weights": validate_original_weights}, + "ema": {"enable": True, "validate_original_weights": validate_original_weights, "decay": 0.999}, "explicit_log_dir": str(tmpdir), "checkpoint_callback_params": {"filename": f"{{epoch}}-{{step}}"}, }, @@ -321,21 +322,12 @@ def run_training_test(self, 
accumulate_grad_batches, validate_original_weights, class EMAAssertCallback(Callback): - def __init__(self): - self._before_calc_ema_weights = None - def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - model_weights = list(pl_module.state_dict().values()) - ema_weights = extract_ema_weights(pl_module, trainer) - for x, y in zip(model_weights, ema_weights): + model_weights = extract_weights(pl_module) + self.ema_weights = extract_ema_weights(pl_module, trainer) + for x, y in zip(model_weights, self.ema_weights): assert torch.allclose(x, y) - def on_train_batch_start( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int - ) -> None: - # saved for manual calculation of ema to compare against implementation - self._before_calc_ema_weights = extract_ema_weights(pl_module, trainer) - def on_train_batch_end( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int ) -> None: @@ -344,33 +336,37 @@ def on_train_batch_end( return ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0] decay = ema_callback.decay - ema_weights = extract_ema_weights(pl_module, trainer) expected_ema_weights = [] - for orig_weight, ema_weight in zip(list(pl_module.state_dict().values()), self._before_calc_ema_weights): - expected_ema_weight = orig_weight * (1 - decay) + ema_weight * decay - expected_ema_weights.append(expected_ema_weight) + new_weights = extract_weights(pl_module) + + for ema_weight, new_weight in zip(self.ema_weights, new_weights): + expected_ema_weight = ema_weight * decay + expected_ema_weight += new_weight * (1 - decay) + expected_ema_weights.append(expected_ema_weight) + ema_weights = extract_ema_weights(pl_module, trainer) for actual_ema_weight, expected_ema_weight in zip(ema_weights, expected_ema_weights): assert torch.allclose(actual_ema_weight, expected_ema_weight) + self.ema_weights = expected_ema_weights class EMAValidationAssertCallback(Callback): def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0] - self._original_weights = list(pl_module.state_dict().values()) + self._original_weights = extract_weights(pl_module) self._ema_weights = extract_ema_weights(pl_module, trainer) # call original EMA function super().on_validation_start(trainer, pl_module) if not ema_callback.validate_original_weights: if ema_callback._ema_initialized: # check model weights are now EMA weights - for ema_weights, module_weights in zip(self._ema_weights, pl_module.state_dict().values()): + for ema_weights, module_weights in zip(self._ema_weights, extract_weights(pl_module)): torch.allclose(ema_weights, module_weights) def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0] if not ema_callback.validate_original_weights: - model_weights = list(pl_module.state_dict().values()) + model_weights = extract_weights(pl_module) if ema_callback._ema_initialized: for orig_weights, module_weights in zip(self._original_weights, model_weights): torch.allclose(orig_weights.cpu(), module_weights.cpu()) From f5df83cc68e9b16547d6c4ccd321ee1fc9d2e715 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 9 Dec 2022 16:00:50 +0000 Subject: [PATCH 19/19] Get rid of NLP changes Signed-off-by: SeanNaren --- nemo/collections/nlp/parts/nlp_overrides.py | 3 +-- 1 file changed, 1 
insertion(+), 2 deletions(-) diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 6e165ac11735..0149af26db21 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -35,7 +35,6 @@ from torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks import noop_hook from torch.nn.parallel import DistributedDataParallel -from nemo.collections.common.callbacks.ema import EMAOptimizer from nemo.collections.nlp.modules.common.megatron.module import Float16Module from nemo.core.connectors.save_restore_connector import SaveRestoreConnector from nemo.core.optim import MainParamsOptimizerWrapper @@ -598,7 +597,7 @@ def optimizer_step( **kwargs: Any, ) -> None: assert isinstance( - optimizer, (MainParamsOptimizerWrapper, EMAOptimizer,) + optimizer, MainParamsOptimizerWrapper ), "MegatronHalfPrecisionPlugin supports only the optimizer with master parameters" if self.scaler is None:
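
Across the series, the rewritten test assertions reduce to checking the standard EMA recurrence, ema = decay * ema + (1 - decay) * weight, after each optimizer step, and they now compare parameters() rather than the full state_dict(), which keeps non-parameter buffers such as BatchNorm running statistics out of the comparison. A standalone sketch of that bookkeeping in plain PyTorch; the model and hyperparameters only loosely mirror the test:

import torch

decay = 0.999
model = torch.nn.Sequential(torch.nn.BatchNorm1d(32), torch.nn.Linear(32, 32))
opt = torch.optim.SGD(model.parameters(), lr=1e-3)

# EMA starts as an exact copy of the parameters, as the callback asserts in on_train_start.
ema = [p.detach().clone() for p in model.parameters()]

for _ in range(5):
    loss = model(torch.randn(2, 32)).sum()
    opt.zero_grad()
    loss.backward()
    opt.step()
    with torch.no_grad():
        # ema <- decay * ema + (1 - decay) * new_weight, applied after the optimizer step;
        # this is the same expected_ema_weight computation the EMAAssertCallback recomputes by hand.
        for e, p in zip(ema, model.parameters()):
            e.mul_(decay).add_(p, alpha=1 - decay)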