Skip to content

Commit

Permalink
Allow swapping of weights be optional
Browse files Browse the repository at this point in the history
Signed-off-by: SeanNaren <snarenthiran@nvidia.com>
  • Loading branch information
SeanNaren committed Oct 19, 2022
1 parent 979c51a commit f8a361e
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions nemo/collections/common/callbacks/ema.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,16 @@ class EMA(Callback):
Args:
decay: The exponential decay used when calculating the moving average. Has to be between 0-1.
apply_ema_every_n_steps: Apply EMA every N steps.
validate_original_weights: Validate the original weights, as apposed to the EMA weights.
"""

def __init__(self, decay: float):
def __init__(self, decay: float, apply_ema_every_n_steps: int = 1, validate_original_weights: bool = False):
if not (0 <= decay <= 1):
raise MisconfigurationException("EMA decay value must be between 0 and 1")
self.decay = decay
self.apply_ema_every_n_steps = apply_ema_every_n_steps
self.validate_original_weights = validate_original_weights

def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
trainer.optimizers = [
Expand All @@ -49,16 +53,20 @@ def swap_model_weights(self, trainer: "pl.Trainer"):
optimizer.switch_main_parameter_weights()

def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self.swap_model_weights(trainer)
if not self.validate_original_weights:
self.swap_model_weights(trainer)

def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self.swap_model_weights(trainer)
if not self.validate_original_weights:
self.swap_model_weights(trainer)

def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self.swap_model_weights(trainer)
if not self.validate_original_weights:
self.swap_model_weights(trainer)

def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self.swap_model_weights(trainer)
if not self.validate_original_weights:
self.swap_model_weights(trainer)


@torch.no_grad()
Expand Down

0 comments on commit f8a361e

Please sign in to comment.