From 4fc5385c89ec9bc14cd06ff0e794976eaf471800 Mon Sep 17 00:00:00 2001 From: Jocelyn Date: Thu, 13 Oct 2022 19:04:58 -0700 Subject: [PATCH] Fix loading checkpoints from dirpath if it's specified (#5155) * Fix loading checkpoints from dirpath if it's specified Signed-off-by: Jocelyn Huang * Add unit tests for loading with dirpath Signed-off-by: Jocelyn Huang Signed-off-by: Jocelyn Huang Co-authored-by: Eric Harper --- nemo/utils/exp_manager.py | 20 ++++++++++++++++---- tests/core/test_exp_manager.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py index 30ea789f538f..4e15943b5e2e 100644 --- a/nemo/utils/exp_manager.py +++ b/nemo/utils/exp_manager.py @@ -281,7 +281,17 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo ) if cfg.resume_if_exists: - check_resume(trainer, log_dir, cfg.resume_past_end, cfg.resume_ignore_no_checkpoint) + # Check for existing checkpoints in `dirpath` if it's specified, use /checkpoints otherwise + if cfg.checkpoint_callback_params.dirpath: + check_resume( + trainer, + log_dir, + cfg.resume_past_end, + cfg.resume_ignore_no_checkpoint, + cfg.checkpoint_callback_params.dirpath, + ) + else: + check_resume(trainer, log_dir, cfg.resume_past_end, cfg.resume_ignore_no_checkpoint) checkpoint_name = name # If name returned from get_log_dir is "", use cfg.name for checkpointing @@ -426,13 +436,14 @@ def check_resume( log_dir: str, resume_past_end: bool = False, resume_ignore_no_checkpoint: bool = False, + dirpath: str = None, ): """Checks that resume=True was used correctly with the arguments pass to exp_manager. Sets trainer._checkpoint_connector.resume_from_checkpoint_fit_path as necessary. 
Returns: - log_dir (Path): the log_dir - exp_dir (str): the base exp_dir without name nor version + log_dir (Path): The log_dir + exp_dir (str): The base exp_dir without name nor version name (str): The name of the experiment version (str): The version of the experiment @@ -444,7 +455,8 @@ def check_resume( if not log_dir: raise ValueError(f"Resuming requires the log_dir {log_dir} to be passed to exp_manager") - checkpoint_dir = Path(Path(log_dir) / "checkpoints") + # Use /checkpoints/ unless `dirpath` is set + checkpoint_dir = Path(dirpath) if dirpath else Path(Path(log_dir) / "checkpoints") checkpoint = None end_checkpoints = list(checkpoint_dir.rglob("*end.ckpt")) diff --git a/tests/core/test_exp_manager.py b/tests/core/test_exp_manager.py index 6d84337958b0..02053122c64c 100644 --- a/tests/core/test_exp_manager.py +++ b/tests/core/test_exp_manager.py @@ -345,6 +345,38 @@ def test_resume(self, tmp_path): prev_log = Path(tmp_path / "test_resume" / "default" / "version_0" / "run_0" / "lightning_logs.txt") assert prev_log.exists() + # Error because `dirpath` specified and has no checkpoint + test_trainer = pl.Trainer(accelerator='cpu', enable_checkpointing=False, logger=False) + dirpath_checkpoint_dir = Path(tmp_path / "test_resume" / "dirpath_test" / "ckpts") + dirpath_checkpoint_dir.mkdir(parents=True) + with pytest.raises(NotFoundError): + exp_manager( + test_trainer, + { + "resume_if_exists": True, + "checkpoint_callback_params": {"dirpath": str(dirpath_checkpoint_dir)}, + "explicit_log_dir": str(log_dir), + }, + ) + + # Check that model loads from `dirpath` and not /checkpoints + dirpath_log_dir = Path(tmp_path / "test_resume" / "dirpath_test" / "logs") + dirpath_log_dir.mkdir(parents=True) + dirpath_checkpoint = Path(dirpath_checkpoint_dir / "mymodel--last.ckpt") + dirpath_checkpoint.touch() + exp_manager( + test_trainer, + { + "resume_if_exists": True, + "checkpoint_callback_params": {"dirpath": str(dirpath_checkpoint_dir)}, + "explicit_log_dir": 
str(dirpath_log_dir), + }, + ) + assert ( + Path(test_trainer._checkpoint_connector.resume_from_checkpoint_fit_path).resolve() + == dirpath_checkpoint.resolve() + ) + @pytest.mark.unit def test_nemo_checkpoint_save_best_model_1(self, tmp_path): test_trainer = pl.Trainer(accelerator='cpu', enable_checkpointing=False, logger=False, max_epochs=4)