Skip to content

Commit

Permalink
update tests
Browse files — browse the repository at this point in the history
  • Loading branch information
awaelchli committed Apr 23, 2021
1 parent 807f223 commit 37244b5
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
3 changes: 2 additions & 1 deletion tests/callbacks/test_early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
# the checkpoint saves "epoch + 1"
early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1]
assert 4 == len(early_stop_callback.saved_states)
assert checkpoint["callbacks"]["EarlyStoppingTestRestore"] == early_stop_callback_state
print(checkpoint["callbacks"])
assert checkpoint["callbacks"]["EarlyStoppingTestRestore[monitor=train_loss]"] == early_stop_callback_state

# ensure state is reloaded properly (assertion in the callback)
early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state, monitor='train_loss')
Expand Down
2 changes: 1 addition & 1 deletion tests/callbacks/test_lambda_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def on_train_epoch_start(self):
raise KeyboardInterrupt

checker = set()
hooks = [m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)]
hooks = [m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction) if not m.startswith("_")]
hooks_args = {h: (lambda x: lambda *_: checker.add(x))(h) for h in hooks}
hooks_args["on_save_checkpoint"] = (lambda x: lambda *_: [checker.add(x)])("on_save_checkpoint")

Expand Down
10 changes: 5 additions & 5 deletions tests/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def configure_optimizers(self):
assert chk['epoch'] == epoch + 1
assert chk['global_step'] == limit_train_batches * (epoch + 1)

mc_specific_data = chk['callbacks']["ModelCheckpoint"]
mc_specific_data = chk['callbacks'][f"ModelCheckpoint[monitor={monitor}]"]
assert mc_specific_data['dirpath'] == checkpoint.dirpath
assert mc_specific_data['monitor'] == monitor
assert mc_specific_data['current_score'] == score
Expand Down Expand Up @@ -269,7 +269,7 @@ def _make_assertions(epoch, ix, add=''):
expected_global_step = per_epoch_steps * (global_ix + 1) + (left_over_steps * epoch_num)
assert chk['global_step'] == expected_global_step

mc_specific_data = chk['callbacks']["ModelCheckpoint"]
mc_specific_data = chk['callbacks'][f"ModelCheckpoint[monitor={monitor}]"]
assert mc_specific_data['dirpath'] == checkpoint.dirpath
assert mc_specific_data['monitor'] == monitor
assert mc_specific_data['current_score'] == score
Expand Down Expand Up @@ -870,8 +870,8 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
ckpt_last = torch.load(path_last)
assert all(ckpt_last_epoch[k] == ckpt_last[k] for k in ("epoch", "global_step"))

ch_type = "ModelCheckpoint"
assert ckpt_last["callbacks"][ch_type] == ckpt_last_epoch["callbacks"][ch_type]
ckpt_id = "ModelCheckpoint[monitor=early_stop_on]"
assert ckpt_last["callbacks"][ckpt_id] == ckpt_last_epoch["callbacks"][ckpt_id]

# it is easier to load the model objects than to iterate over the raw dict of tensors
model_last_epoch = LogInTwoMethods.load_from_checkpoint(path_last_epoch)
Expand Down Expand Up @@ -1128,7 +1128,7 @@ def training_step(self, *args):
trainer.fit(TestModel())
assert model_checkpoint.current_score == 0.3
ckpts = [torch.load(str(ckpt)) for ckpt in tmpdir.listdir()]
ckpts = [ckpt["callbacks"]["ModelCheckpoint"] for ckpt in ckpts]
ckpts = [ckpt["callbacks"]["ModelCheckpoint[monitor=foo]"] for ckpt in ckpts]
assert sorted(ckpt["current_score"] for ckpt in ckpts) == [0.1, 0.2, 0.3]


Expand Down

0 comments on commit 37244b5

Please sign in to comment.