Fix loading checkpoints from dirpath if it's specified (#5155)
* Fix loading checkpoints from dirpath if it's specified

Signed-off-by: Jocelyn Huang <jocelynh@nvidia.com>

* Add unit tests for loading with dirpath

Signed-off-by: Jocelyn Huang <jocelynh@nvidia.com>

Signed-off-by: Jocelyn Huang <jocelynh@nvidia.com>
Co-authored-by: Eric Harper <complex451@gmail.com>
redoctopus and ericharper committed Oct 14, 2022
1 parent 1ec21dd commit 4fc5385
Showing 2 changed files with 48 additions and 4 deletions.
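For context, resuming from a custom checkpoint directory is driven entirely by the config dict passed to exp_manager. A minimal sketch of the intended usage after this fix (the paths are placeholders; the config shape mirrors the new unit tests below):

import pytorch_lightning as pl
from nemo.utils.exp_manager import exp_manager

trainer = pl.Trainer(accelerator='cpu', enable_checkpointing=False, logger=False)
exp_manager(
    trainer,
    {
        "resume_if_exists": True,
        # Checkpoints are searched for under `dirpath` rather than <log_dir>/checkpoints
        "checkpoint_callback_params": {"dirpath": "/path/to/custom_ckpts"},
        "explicit_log_dir": "/path/to/logs",
    },
)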
20 changes: 16 additions & 4 deletions nemo/utils/exp_manager.py
@@ -281,7 +281,17 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo
     )
 
     if cfg.resume_if_exists:
-        check_resume(trainer, log_dir, cfg.resume_past_end, cfg.resume_ignore_no_checkpoint)
+        # Check for existing checkpoints in `dirpath` if it's specified; use <log_dir>/checkpoints otherwise
+        if cfg.checkpoint_callback_params.dirpath:
+            check_resume(
+                trainer,
+                log_dir,
+                cfg.resume_past_end,
+                cfg.resume_ignore_no_checkpoint,
+                cfg.checkpoint_callback_params.dirpath,
+            )
+        else:
+            check_resume(trainer, log_dir, cfg.resume_past_end, cfg.resume_ignore_no_checkpoint)
 
     checkpoint_name = name
     # If name returned from get_log_dir is "", use cfg.name for checkpointing
@@ -426,13 +436,14 @@ def check_resume(
     log_dir: str,
     resume_past_end: bool = False,
     resume_ignore_no_checkpoint: bool = False,
+    dirpath: str = None,
 ):
     """Checks that resume=True was used correctly with the arguments passed to exp_manager. Sets
     trainer._checkpoint_connector.resume_from_checkpoint_fit_path as necessary.
     Returns:
-        log_dir (Path): the log_dir
-        exp_dir (str): the base exp_dir without name nor version
+        log_dir (Path): The log_dir
+        exp_dir (str): The base exp_dir without name nor version
         name (str): The name of the experiment
         version (str): The version of the experiment
@@ -444,7 +455,8 @@ def check_resume(
     if not log_dir:
         raise ValueError(f"Resuming requires the log_dir {log_dir} to be passed to exp_manager")
 
-    checkpoint_dir = Path(Path(log_dir) / "checkpoints")
+    # Use <log_dir>/checkpoints/ unless `dirpath` is set
+    checkpoint_dir = Path(dirpath) if dirpath else Path(Path(log_dir) / "checkpoints")
 
     checkpoint = None
     end_checkpoints = list(checkpoint_dir.rglob("*end.ckpt"))
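The core of the change is the one-line fallback above. As a standalone sketch of that resolution logic (resolve_checkpoint_dir is a hypothetical helper for illustration, not part of the NeMo API):

from pathlib import Path
from typing import Optional

def resolve_checkpoint_dir(log_dir: str, dirpath: Optional[str] = None) -> Path:
    # An explicit `dirpath` takes precedence; otherwise fall back to
    # the conventional <log_dir>/checkpoints location.
    return Path(dirpath) if dirpath else Path(log_dir) / "checkpoints"

assert resolve_checkpoint_dir("/exp/run0") == Path("/exp/run0/checkpoints")
assert resolve_checkpoint_dir("/exp/run0", "/ckpts") == Path("/ckpts")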
32 changes: 32 additions & 0 deletions tests/core/test_exp_manager.py
@@ -345,6 +345,38 @@ def test_resume(self, tmp_path):
         prev_log = Path(tmp_path / "test_resume" / "default" / "version_0" / "run_0" / "lightning_logs.txt")
         assert prev_log.exists()
 
+        # Error because `dirpath` is specified and contains no checkpoint
+        test_trainer = pl.Trainer(accelerator='cpu', enable_checkpointing=False, logger=False)
+        dirpath_checkpoint_dir = Path(tmp_path / "test_resume" / "dirpath_test" / "ckpts")
+        dirpath_checkpoint_dir.mkdir(parents=True)
+        with pytest.raises(NotFoundError):
+            exp_manager(
+                test_trainer,
+                {
+                    "resume_if_exists": True,
+                    "checkpoint_callback_params": {"dirpath": str(dirpath_checkpoint_dir)},
+                    "explicit_log_dir": str(log_dir),
+                },
+            )
+
+        # Check that the model loads from `dirpath` and not <log_dir>/checkpoints
+        dirpath_log_dir = Path(tmp_path / "test_resume" / "dirpath_test" / "logs")
+        dirpath_log_dir.mkdir(parents=True)
+        dirpath_checkpoint = Path(dirpath_checkpoint_dir / "mymodel--last.ckpt")
+        dirpath_checkpoint.touch()
+        exp_manager(
+            test_trainer,
+            {
+                "resume_if_exists": True,
+                "checkpoint_callback_params": {"dirpath": str(dirpath_checkpoint_dir)},
+                "explicit_log_dir": str(dirpath_log_dir),
+            },
+        )
+        assert (
+            Path(test_trainer._checkpoint_connector.resume_from_checkpoint_fit_path).resolve()
+            == dirpath_checkpoint.resolve()
+        )
+
     @pytest.mark.unit
     def test_nemo_checkpoint_save_best_model_1(self, tmp_path):
         test_trainer = pl.Trainer(accelerator='cpu', enable_checkpointing=False, logger=False, max_epochs=4)
