Skip to content

Commit

Permalink
add _del_model_without_trainer
Browse files Browse the repository at this point in the history
Signed-off-by: ericharper <complex451@gmail.com>
  • Loading branch information
ericharper committed Oct 11, 2021
1 parent d88241a commit 4406475
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
4 changes: 2 additions & 2 deletions examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ model:
# TODO: add validation generation config, size of context, max sequence length
# needed to initialize megatron
micro_batch_size: 8
tensor_model_parallel_size: 1
tensor_model_parallel_size: 2
max_position_embeddings: 1024
encoder_seq_length: 1024
num_layers: 24
Expand Down Expand Up @@ -66,7 +66,7 @@ model:
optim:
name: adam
lr: 0.0001
weight_decay: 0.01
#weight_decay: 0.01
betas:
- 0.9
- 0.98
Expand Down
18 changes: 16 additions & 2 deletions nemo/utils/exp_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ def nemo_topk_check_previous_run(self):
for _ in range(models_to_delete):
model = best_k_models.pop(-1)
self.best_k_models.pop(model)
self._del_model(model)
self._del_model_without_trainer(model)
logging.debug(f"Removed checkpoint: {model}")

self.kth_best_model_path = best_k_models[-1]
Expand Down Expand Up @@ -776,6 +776,20 @@ def _del_model(self, trainer: "pl.Trainer", filepath: str) -> None:
else:
return super()._del_model(trainer, filepath)

def _del_model_without_trainer(self, filepath: str) -> None:
    """Remove a checkpoint file from the filesystem without needing a Trainer.

    When model parallelism is active, the path is redirected into this
    rank's ``mp_rank_XX`` subdirectory so every model-parallel rank
    deletes its own shard. The removal itself is performed only by
    data-parallel rank 0 (or when data parallelism is not configured).

    Args:
        filepath: Path of the checkpoint to delete.
    """
    app_state = AppState()

    if app_state.model_parallel_size is not None:
        # Redirect to this model-parallel rank's shard of the checkpoint.
        parent, fname = os.path.split(filepath)
        filepath = f'{parent}/mp_rank_{app_state.model_parallel_rank:02d}/{fname}'

    # Guard so the file is removed exactly once per model-parallel rank:
    # either there is no data parallelism, or we are data-parallel rank 0.
    is_dp_zero = app_state.data_parallel_rank is None or app_state.data_parallel_rank == 0
    if is_dp_zero and self._fs.exists(filepath):
        self._fs.rm(filepath)
        logging.info(f"Removed checkpoint: {filepath}")

def _save_last_checkpoint(self, trainer: 'pl.Trainer', monitor_candidates: Dict[str, _METRIC]) -> None:
""" Overrides PTL method to account for model parallel checkpoints.
Checks for data parallel rank 0 rather than global rank 0.
Expand Down Expand Up @@ -817,7 +831,7 @@ def _save_none_monitor_checkpoint(self, trainer: 'pl.Trainer', monitor_candidate
and self.best_model_path != filepath
and app_state.data_parallel_rank == 0
):
self._del_model(self.best_model_path)
self._del_model(trainer, self.best_model_path)

self.best_model_path = filepath
else:
Expand Down

0 comments on commit 4406475

Please sign in to comment.