Skip to content

Commit

Permalink
delay sharded_state_dict in save_to
Browse files Browse the repository at this point in the history
  • Loading branch information
Ryan Li committed Jun 7, 2024
1 parent 791a037 commit 77f5191
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions nemo/collections/nlp/parts/nlp_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,8 +943,6 @@ def save_to(self, model, save_path: str):
if dist_ckpt:
# model weights is a directory
dist_ckpt_dir = ckpt_to_dir(os.path.join(dir_name, self.model_weights_ckpt))

sharded_state_dict = model.sharded_state_dict()
# dist checkpoint needs torch.distributed to save the checkpoint
if not parallel_state.is_initialized():

Expand All @@ -954,6 +952,7 @@ def dummy():
if model.trainer.strategy.launcher is not None:
model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
model.trainer.strategy.setup_environment()
sharded_state_dict = model.sharded_state_dict()
checkpoint_io = DistributedCheckpointIO(model.cfg.get('dist_ckpt_format', 'zarr'))
checkpoint_io.save_checkpoint(sharded_state_dict, dist_ckpt_dir)

Expand Down

0 comments on commit 77f5191

Please sign in to comment.