
Commit a4b3e53

Fixes to tests
Signed-off-by: SeanNaren <snarenthiran@nvidia.com>
SeanNaren committed Nov 29, 2022
1 parent e03c83d commit a4b3e53
Showing 2 changed files with 62 additions and 108 deletions.
18 changes: 14 additions & 4 deletions nemo/collections/common/callbacks/ema.py
@@ -63,6 +63,7 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
                 current_step=trainer.global_step,
             )
             for optim in trainer.optimizers
+            if not isinstance(optim, EMAOptimizer)
         ]
 
     def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
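Note on the new guard: it makes the optimizer wrapping in `on_fit_start` idempotent, so a second `fit()` call on the same trainer does not wrap an `EMAOptimizer` in another `EMAOptimizer`. A runnable toy of the pattern (stand-in classes, not NeMo code; the real `EMAOptimizer` takes more arguments than the `current_step` visible in this hunk):

class FakeOptimizer:
    """Stand-in for a plain torch optimizer."""

class FakeEMAOptimizer(FakeOptimizer):
    """Stand-in for the EMAOptimizer wrapper."""
    def __init__(self, optim, current_step=0):
        self.optim = optim
        self.current_step = current_step

optimizers = [FakeOptimizer(), FakeOptimizer()]
# Mirrors the comprehension above: plain optimizers are wrapped once, and
# anything already wrapped is filtered out rather than wrapped again.
optimizers = [
    FakeEMAOptimizer(optim, current_step=0)
    for optim in optimizers
    if not isinstance(optim, FakeEMAOptimizer)
]
assert all(isinstance(o, FakeEMAOptimizer) for o in optimizers)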
@@ -121,13 +122,14 @@ def on_load_checkpoint(
         if os.path.exists(ema_path):
             ema_state_dict = torch.load(ema_path, map_location=torch.device('cpu'))
             checkpoint['optimizer_states'] = ema_state_dict['optimizer_states']
+            for optimizer_state in checkpoint['optimizer_states']:
+                optimizer_state['ema'] = list(ema_state_dict['state_dict'].values())
             del ema_state_dict
             logging.info("EMA weights have been loaded successfully. Continuing training with saved EMA weights.")
         else:
-            warnings.warn(
-                "we were unable to find the associated EMA weights when re-loading, "
-                "training will start with new EMA weights.",
-                UserWarning,
+            raise MisconfigurationException(
+                "Unable to find the associated EMA weights when re-loading, "
+                f"training will start with new EMA weights. Expected them to be at: {ema_path}",
             )
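For readers unfamiliar with the layout: the EMA weights live in a sibling checkpoint whose name carries an `-EMA` suffix (see `epoch=0-step=8-EMA.ckpt` in the tests below). A standalone sketch of what the new loading path does, with illustrative paths and assuming both files exist on disk:

import os
import torch

ckpt_path = "checkpoints/epoch=0-step=8.ckpt"        # illustrative path
ema_path = ckpt_path.replace(".ckpt", "-EMA.ckpt")   # suffix used by the tests

checkpoint = torch.load(ckpt_path, map_location="cpu")
if os.path.exists(ema_path):
    ema_state_dict = torch.load(ema_path, map_location="cpu")
    # Reuse the EMA checkpoint's optimizer states and attach the averaged
    # parameters under the 'ema' key that each EMAOptimizer restores.
    checkpoint["optimizer_states"] = ema_state_dict["optimizer_states"]
    for optimizer_state in checkpoint["optimizer_states"]:
        optimizer_state["ema"] = list(ema_state_dict["state_dict"].values())
else:
    # Mirrors the new behavior: missing EMA weights are now a hard error.
    raise RuntimeError(f"Expected EMA weights at: {ema_path}")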


@@ -320,6 +322,10 @@ def state_dict(self):
         state_dict = {
             'opt': self.optimizer.state_dict(),
             'ema': self.ema_params,
+            'current_step': self.current_step,
+            'decay': self.decay,
+            'every_n_steps': self.every_n_steps,
+            'device': self.device,
         }
         return state_dict

@@ -328,6 +334,10 @@ def load_state_dict(self, state_dict):
 
         self.optimizer.load_state_dict(state_dict['opt'])
         self.ema_params = tuple(param.to(self.device) for param in copy.deepcopy(state_dict['ema']))
+        self.current_step = state_dict['current_step']
+        self.decay = state_dict['decay']
+        self.device = state_dict['device']
+        self.every_n_steps = state_dict['every_n_steps']
         self.rebuild_ema_params = False
 
     def add_param_group(self, param_group):
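Why the four extra entries in `state_dict`/`load_state_dict` matter: without them a resumed `EMAOptimizer` would fall back to its constructor defaults, in particular restarting the `current_step` counter from zero. A small round-trip sketch with toy values (the real state dict also carries the wrapped optimizer under 'opt' and the averaged parameters under 'ema'):

import torch

# Persist the scheduling state alongside the weights...
state = {
    "current_step": 100,
    "decay": 0.999,
    "every_n_steps": 1,
    "device": torch.device("cpu"),
}
torch.save(state, "ema_optimizer_state.pt")

# ...and restore it, mirroring the new load_state_dict assignments.
restored = torch.load("ema_optimizer_state.pt")
assert restored["current_step"] == 100
assert restored["decay"] == 0.999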
152 changes: 48 additions & 104 deletions tests/collections/common/test_ema.py
@@ -32,6 +32,14 @@
 from tests.collections.nlp.test_gpt_model import DEVICE_CAPABILITY
 
 
+def extract_ema_weights(pl_module, trainer):
+    ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0]
+    ema_callback.swap_model_weights(trainer)
+    weights = [w.detach().clone() for w in pl_module.state_dict().values()]
+    ema_callback.swap_model_weights(trainer)
+    return weights
+
+
 class OnesDataset(torch.utils.data.Dataset):
     def __init__(self, dataset_len):
         super().__init__()
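The new module-level `extract_ema_weights` replaces both the nested helper removed below and the direct reads of the old `_ema_model_weights` buffer. It leans on `EMA.swap_model_weights` exchanging the live and averaged parameters in place, so calling it a second time restores the original state. A toy illustration of that swap/copy/swap idiom (plain tensors, not NeMo code):

import torch

model_w = [torch.ones(2)]        # stand-in for the module parameters
ema_w = [torch.full((2,), 0.5)]  # stand-in for the EMA shadow copies

def swap():
    for i in range(len(model_w)):
        model_w[i], ema_w[i] = ema_w[i], model_w[i]

swap()                                            # EMA weights become "live"
captured = [w.detach().clone() for w in model_w]  # snapshot them
swap()                                            # original weights restored
assert torch.equal(model_w[0], torch.ones(2))
assert torch.equal(captured[0], torch.full((2,), 0.5))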
@@ -90,27 +98,15 @@ def test_ema_value(self):
         with pytest.raises(MisconfigurationException, match="between 0 and 1"):
             EMA(decay=2)
 
-    @mock.patch('nemo.collections.common.callbacks.ema.apex_available', False)
-    def test_ema_apex_unavailable(self):
-        with pytest.warns(UserWarning, match="EMA has better performance when Apex is installed"):
-            EMA(decay=0.999)
-
     @pytest.mark.unit
     @pytest.mark.run_only_on('GPU')
     def test_ema_saved_state(self, tmpdir, caplog):
         """Test to ensure that when we re-load the EMA callback, it loads the EMA weights correctly."""
         temp_path = os.path.join(tmpdir, 'saved_state')
 
-        def extract_ema_weights(ema_callback, pl_module, trainer):
-            ema_callback.swap_model_weights(trainer)
-            weights = [w.detach().clone() for w in pl_module.state_dict().values()]
-            ema_callback.swap_model_weights(trainer)
-            return weights
-
         class TerminateCallback(Callback):
             def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-                ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0]
-                self.saved_ema_weights = extract_ema_weights(ema_callback, pl_module, trainer)
+                self.saved_ema_weights = extract_ema_weights(pl_module, trainer)
                 self.pl_module_weights = list(pl_module.state_dict().values())
                 raise SystemExit
 
@@ -144,11 +140,10 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
 
         class CheckStateCallback(Callback):
             def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-                ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0]
                 weights = list(pl_module.state_dict().values())
                 for x, y in zip(weights, terminate_callback.pl_module_weights):
                     assert torch.allclose(x.cpu(), y.cpu())
-                current_ema_weights = extract_ema_weights(ema_callback, pl_module, trainer)
+                current_ema_weights = extract_ema_weights(pl_module, trainer)
                 for x, y in zip(current_ema_weights, terminate_callback.saved_ema_weights):
                     assert torch.allclose(x.cpu(), y.cpu())

@@ -214,12 +209,14 @@ def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         exp_manager(
             trainer,
             {
-                "ema": {"enable": True, "evaluate_ema_weights_instead": True},
+                "ema": {"enable": True, "validate_original_weights": True},
                 "explicit_log_dir": str(temp_path),
                 "checkpoint_callback_params": {"filename": f"{{epoch}}-{{step}}"},
             },
         )
-        with pytest.warns(UserWarning, match="we were unable to find the associated EMA weights when re-loading"):
+        with pytest.raises(
+            MisconfigurationException, match="Unable to find the associated EMA weights when re-loading"
+        ):
             trainer.fit(model, ckpt_path=resume_path)
 
     @pytest.mark.unit
@@ -232,70 +229,23 @@ def test_exp_manager_ema_weights(self, tmpdir):
         exp_manager(
             trainer,
             {
-                "ema": {"enable": True, "evaluate_ema_weights_instead": True},
+                "ema": {"enable": True, "validate_original_weights": True},
                 "explicit_log_dir": str(tmp_path),
                 "checkpoint_callback_params": {"filename": f"{{epoch}}-{{step}}"},
             },
         )
         assert any(isinstance(callback, EMA) for callback in trainer.callbacks)
         trainer.fit(model)
+        ema_weights = extract_ema_weights(model, trainer)
 
         assert os.path.exists(tmp_path / "checkpoints/epoch=0-step=8.ckpt")
         ema_path = tmp_path / "checkpoints/epoch=0-step=8-EMA.ckpt"
         assert os.path.exists(ema_path)
 
         duplicate_model = ExampleModel.load_from_checkpoint(str(ema_path))
-        ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0]
-        for saved_weight, ema_weight in zip(duplicate_model.state_dict().values(), ema_callback._ema_model_weights):
+        for saved_weight, ema_weight in zip(duplicate_model.state_dict().values(), ema_weights):
             assert torch.allclose(saved_weight.cpu(), ema_weight.cpu())
 
-    @pytest.mark.unit
-    @pytest.mark.run_only_on('GPU')
-    def test_ema_save_in_callback(self, tmpdir):
-        """Test to ensure when `save_ema_weights_in_callback_state` is enabled, we save to the callback state."""
-        temp_path = os.path.join(tmpdir, 'saved_state')
-
-        model = ExampleModel()
-
-        trainer = Trainer(
-            max_epochs=2,
-            limit_val_batches=1,
-            limit_train_batches=16,
-            logger=False,
-            val_check_interval=0.5,
-            enable_checkpointing=False,
-            accelerator='gpu',
-            devices=1,
-            callbacks=[EMA(decay=0.999, save_ema_weights_in_callback_state=True, evaluate_ema_weights_instead=True)],
-        )
-        exp_manager(
-            trainer,
-            {"explicit_log_dir": str(temp_path), "checkpoint_callback_params": {"filename": f"{{epoch}}-{{step}}"},},
-        )
-        trainer.fit(model=model)
-
-        resume_path = os.path.join(temp_path, "checkpoints/epoch=0-step=8.ckpt")
-        callback = EMA(decay=0.999, save_ema_weights_in_callback_state=True)
-
-        class AssertCallback(Callback):
-            def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-                assert callback._ema_model_weights is not None
-
-        model = ExampleModel()
-
-        trainer = Trainer(
-            max_epochs=2,
-            limit_val_batches=1,
-            limit_train_batches=16,
-            logger=False,
-            val_check_interval=0.5,
-            enable_checkpointing=False,
-            accelerator='gpu',
-            devices=1,
-            callbacks=[callback, AssertCallback()],
-        )
-        trainer.fit(model, ckpt_path=resume_path)
-
 
 class TestEMATrain:
     @pytest.mark.unit
@@ -314,41 +264,31 @@ class TestEMATrain:
         ],
     )
     @pytest.mark.parametrize("accumulate_grad_batches", [1, 2])
-    @pytest.mark.parametrize("evaluate_ema_weights_instead", [True, False])
-    @pytest.mark.parametrize("apex_available_mock", [True, False])
     @pytest.mark.run_only_on('GPU')
+    @pytest.mark.parametrize("validate_original_weights", [True, False])
     def test_ema_run_cuda(
-        self,
-        test_data_dir,
-        precision,
-        accumulate_grad_batches,
-        evaluate_ema_weights_instead,
-        apex_available_mock,
-        tmpdir,
+        self, test_data_dir, precision, accumulate_grad_batches, validate_original_weights, tmpdir,
     ):
-        with mock.patch('nemo.collections.common.callbacks.ema.apex_available', apex_available_mock):
-            self.run_training_test(
-                accumulate_grad_batches=accumulate_grad_batches,
-                evaluate_ema_weights_instead=evaluate_ema_weights_instead,
-                accelerator='gpu',
-                precision=precision,
-                tmpdir=tmpdir,
-            )
+        self.run_training_test(
+            accumulate_grad_batches=accumulate_grad_batches,
+            validate_original_weights=validate_original_weights,
+            accelerator='gpu',
+            precision=precision,
+            tmpdir=tmpdir,
+        )
 
     @pytest.mark.unit
     @pytest.mark.parametrize("accumulate_grad_batches", [1, 2])
-    @pytest.mark.parametrize("evaluate_ema_weights_instead", [True, False])
     @pytest.mark.run_only_on('GPU')
-    def test_ema_run_cpu(self, test_data_dir, accumulate_grad_batches, evaluate_ema_weights_instead, tmpdir):
+    @pytest.mark.parametrize("validate_original_weights", [True, False])
+    def test_ema_run_cpu(self, test_data_dir, accumulate_grad_batches, validate_original_weights, tmpdir):
         self.run_training_test(
             accumulate_grad_batches=accumulate_grad_batches,
-            evaluate_ema_weights_instead=evaluate_ema_weights_instead,
+            validate_original_weights=validate_original_weights,
             accelerator='cpu',
             precision=32,
             tmpdir=tmpdir,
         )
 
-    def run_training_test(self, accumulate_grad_batches, evaluate_ema_weights_instead, accelerator, precision, tmpdir):
+    def run_training_test(self, accumulate_grad_batches, validate_original_weights, accelerator, precision, tmpdir):
         pl.seed_everything(123)
         model = ExampleModel()
         trainer = Trainer(
@@ -367,13 +307,14 @@ def run_training_test(self, accumulate_grad_batches, evaluate_ema_weights_instead, accelerator, precision, tmpdir):
         exp_manager(
             trainer,
             {
-                "ema": {"enable": True, "evaluate_ema_weights_instead": evaluate_ema_weights_instead},
+                "ema": {"enable": True, "validate_original_weights": validate_original_weights},
                 "explicit_log_dir": str(tmpdir),
                 "checkpoint_callback_params": {"filename": f"{{epoch}}-{{step}}"},
             },
         )
         # add the check callback after the exp manager has made modifications.
         trainer.callbacks.append(EMAAssertCallback())
+        trainer.callbacks.insert(0, EMAValidationAssertCallback())
         trainer.fit(model=model, val_dataloaders=model.train_dataloader())


@@ -383,16 +324,15 @@ def __init__(self):
 
     def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         model_weights = list(pl_module.state_dict().values())
-        ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0]
-        for x, y in zip(model_weights, ema_callback._ema_model_weights):
+        ema_weights = extract_ema_weights(pl_module, trainer)
+        for x, y in zip(model_weights, ema_weights):
             assert torch.allclose(x, y)
 
     def on_train_batch_start(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int
     ) -> None:
-        ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0]
         # saved for manual calculation of ema to compare against implementation
-        self._before_calc_ema_weights = deepcopy(ema_callback._ema_model_weights)
+        self._before_calc_ema_weights = extract_ema_weights(pl_module, trainer)
 
     def on_train_batch_end(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
@@ -402,29 +342,33 @@ def on_train_batch_end(
             return
         ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0]
         decay = ema_callback.decay
+        ema_weights = extract_ema_weights(pl_module, trainer)
         expected_ema_weights = []
         for orig_weight, ema_weight in zip(list(pl_module.state_dict().values()), self._before_calc_ema_weights):
             expected_ema_weight = orig_weight * (1 - decay) + ema_weight * decay
             expected_ema_weights.append(expected_ema_weight)
 
-        for actual_ema_weight, expected_ema_weight in zip(ema_callback._ema_model_weights, expected_ema_weights):
+        for actual_ema_weight, expected_ema_weight in zip(ema_weights, expected_ema_weights):
             assert torch.allclose(actual_ema_weight, expected_ema_weight)
 
 
 class EMAValidationAssertCallback(Callback):
     def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0]
-        if ema_callback.evaluate_ema_weights_instead:
-            # todo (sean): shouldn't use the weights buffer to check original weights
-            self._original_weights = list(x.detach().clone() for x in ema_callback._weights_buffer)
-            # call original EMA function
-            super().on_validation_start(trainer, pl_module)
+        self._original_weights = list(pl_module.state_dict().values())
+        self._ema_weights = extract_ema_weights(pl_module, trainer)
+        if not ema_callback.validate_original_weights:
             if ema_callback._ema_initialized:
-                for ema_weights, module_weights in zip(
-                    ema_callback._ema_model_weights, pl_module.state_dict().values()
-                ):
+                # check model weights are now EMA weights
+                for ema_weights, module_weights in zip(self._ema_weights, pl_module.state_dict().values()):
                     torch.allclose(ema_weights, module_weights)
 
     def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0]
-        if ema_callback.evaluate_ema_weights_instead:
+        if not ema_callback.validate_original_weights:
             model_weights = list(pl_module.state_dict().values())
             if ema_callback._ema_initialized:
                 for orig_weights, module_weights in zip(self._original_weights, model_weights):
-                    torch.allclose(orig_weights, module_weights.cpu())
+                    torch.allclose(orig_weights.cpu(), module_weights.cpu())
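For reference, the manual check in `EMAAssertCallback` applies the standard EMA update, ema_new = orig * (1 - decay) + ema * decay. A worked one-step example (decay=0.999 is the value the removed tests constructed the callback with; the exp_manager-driven tests rely on the callback's default):

import torch

decay = 0.999
orig_weight = torch.tensor([1.0])  # current model parameter
ema_weight = torch.tensor([0.0])   # running average before the step
expected = orig_weight * (1 - decay) + ema_weight * decay
assert torch.allclose(expected, torch.tensor([0.001]))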
