Fix replacing weights
Signed-off-by: SeanNaren <snarenthiran@nvidia.com>
SeanNaren committed on Oct 19, 2022 · 1 parent dc96ff9 · commit 979c51a
Showing 2 changed files with 35 additions and 23 deletions.
nemo/collections/common/callbacks/ema.py (43 changes: 31 additions & 12 deletions)
@@ -43,6 +43,23 @@ def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule")
             EMAOptimizer(optim, device=pl_module.device, decay=self.decay) for optim in trainer.optimizers
         ]
 
+    def swap_model_weights(self, trainer: "pl.Trainer"):
+        for optimizer in trainer.optimizers:
+            assert isinstance(optimizer, EMAOptimizer)
+            optimizer.switch_main_parameter_weights()
+
+    def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        self.swap_model_weights(trainer)
+
+    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        self.swap_model_weights(trainer)
+
+    def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        self.swap_model_weights(trainer)
+
+    def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        self.swap_model_weights(trainer)
+
 
 @torch.no_grad()
 def ema_update(ema_model_tuple, current_model_tuple, decay):
@@ -158,6 +175,18 @@ def update(self):
                 )
                 self.thread.start()
 
+    def swap_tensors(self, tensor1, tensor2):
+        tmp = torch.empty_like(tensor1)
+        tmp.copy_(tensor1)
+        tensor1.copy_(tensor2)
+        tensor2.copy_(tmp)
+
+    def switch_main_parameter_weights(self):
+        self.join()
+
+        for param, ema_param in zip(self.all_parameters(), self.ema_params):
+            self.swap_tensors(param.data, ema_param)
+
     @contextlib.contextmanager
     def swap_ema_weights(self, enabled: bool = True):
         r"""
@@ -170,23 +199,13 @@ def swap_ema_weights(self, enabled: bool = True):
             enabled (bool): whether the swap should be performed
         """
 
-        def swap_tensors(tensor1, tensor2):
-            tmp = torch.empty_like(tensor1)
-            tmp.copy_(tensor1)
-            tensor1.copy_(tensor2)
-            tensor2.copy_(tmp)
-
         if enabled:
-            self.join()
-
-            for param, ema_param in zip(self.all_parameters(), self.ema_params):
-                swap_tensors(param.data, ema_param)
+            self.switch_main_parameter_weights()
         try:
            yield
        finally:
            if enabled:
-                for param, ema_param in zip(self.all_parameters(), self.ema_params):
-                    swap_tensors(param.data, ema_param)
+                self.switch_main_parameter_weights()
 
     def __getattr__(self, name):
         return getattr(self.optimizer, name)
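The in-place swap added above is symmetric: applying switch_main_parameter_weights twice restores the original weights, which is why the callback can swap at on_validation_start/on_test_start and simply swap again at on_validation_end/on_test_end. A minimal standalone sketch of the same tensor-swap pattern (plain PyTorch, outside NeMo; the param/ema_param values are purely illustrative):

import torch

def swap_tensors(tensor1, tensor2):
    # In-place exchange: afterwards each tensor holds the other's previous values.
    tmp = torch.empty_like(tensor1)
    tmp.copy_(tensor1)
    tensor1.copy_(tensor2)
    tensor2.copy_(tmp)

param = torch.tensor([1.0, 2.0, 3.0])       # stands in for a model parameter
ema_param = torch.tensor([0.9, 1.8, 2.7])   # stands in for its EMA copy

swap_tensors(param, ema_param)              # evaluate with EMA weights
assert torch.equal(param, torch.tensor([0.9, 1.8, 2.7]))

swap_tensors(param, ema_param)              # second swap restores the originals
assert torch.equal(param, torch.tensor([1.0, 2.0, 3.0]))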
nemo/utils/exp_manager.py (15 changes: 4 additions & 11 deletions)
@@ -897,19 +897,12 @@ def _get_ema_callback(self, trainer) -> Optional[EMA]:
         return ema_callback
 
     def _save_checkpoint(self, trainer, filepath: str) -> None:
-        super()._save_checkpoint(trainer, filepath)
         ema_callback = self._get_ema_callback(trainer)
         if ema_callback is not None:
-            # save EMA copy of the model as well.
-            ema_callback.replace_model_weights(trainer.lightning_module)
-            filepath = self._ema_format_filepath(filepath)
-            if self.verbose:
-                rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
-            super()._save_checkpoint(trainer, filepath)
-            ema_callback.restore_original_weights(trainer.lightning_module)
-
-    def _ema_format_filepath(self, filepath: str) -> str:
-        return filepath.replace(self.FILE_EXTENSION, f'-EMA{self.FILE_EXTENSION}')
+            ema_callback.swap_model_weights(trainer)
+        super()._save_checkpoint(trainer, filepath)
+        if ema_callback is not None:
+            ema_callback.swap_model_weights(trainer)
 
 
 def configure_checkpointing(
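With the hooks above, the checkpoint is written while the EMA weights are swapped into the model and then swapped back out, replacing the earlier approach of saving a separate -EMA checkpoint file. A rough, framework-free sketch of the swap, save, swap-back pattern (the module, EMA dict, and file path are illustrative, not NeMo or Lightning API):

import torch

def swap_with_ema(module: torch.nn.Module, ema_state: dict) -> None:
    # Swap each parameter with its EMA counterpart in place.
    with torch.no_grad():
        for name, param in module.named_parameters():
            tmp = param.detach().clone()
            param.copy_(ema_state[name])
            ema_state[name].copy_(tmp)

model = torch.nn.Linear(2, 2)
ema_state = {name: p.detach().clone() for name, p in model.named_parameters()}  # pretend EMA copy

swap_with_ema(model, ema_state)                      # checkpoint will contain the EMA weights
torch.save(model.state_dict(), "checkpoint.ckpt")    # illustrative path
swap_with_ema(model, ema_state)                      # training continues with the original weights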
