diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py
index 8ca010e59f70..aa21f85c5b1e 100644
--- a/nemo/collections/nlp/parts/nlp_overrides.py
+++ b/nemo/collections/nlp/parts/nlp_overrides.py
@@ -943,8 +943,6 @@ def save_to(self, model, save_path: str):
         if dist_ckpt:
             # model weights is a directory
             dist_ckpt_dir = ckpt_to_dir(os.path.join(dir_name, self.model_weights_ckpt))
-
-            sharded_state_dict = model.sharded_state_dict()
             # dist checkpoint needs torch.distributed to save the checkpoint
             if not parallel_state.is_initialized():
 
@@ -954,6 +952,7 @@ def dummy():
                 if model.trainer.strategy.launcher is not None:
                     model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
                 model.trainer.strategy.setup_environment()
+            sharded_state_dict = model.sharded_state_dict()
             checkpoint_io = DistributedCheckpointIO(model.cfg.get('dist_ckpt_format', 'zarr'))
             checkpoint_io.save_checkpoint(sharded_state_dict, dist_ckpt_dir)
 
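
Note (not part of the patch): the change defers the model.sharded_state_dict() call until after the launcher / setup_environment() branch has run, so that torch.distributed (and the Megatron parallel state) is initialized before the sharded state dict is built and saved. Below is a minimal, self-contained sketch of the same ordering; ToyShardedModel and save_sharded are hypothetical stand-ins, and a plain single-process torch.distributed init replaces the NeMo trainer-strategy setup used in the real code.

import os
import torch.distributed as dist

# Hypothetical stand-in for a NeMo model: its sharded_state_dict() is only
# meaningful once the distributed environment exists.
class ToyShardedModel:
    def sharded_state_dict(self):
        assert dist.is_initialized(), "requires torch.distributed to be set up"
        return {"weight": [1.0, 2.0, 3.0]}  # placeholder for sharded tensors

def save_sharded(model, dist_ckpt_dir: str):
    # Mirror of the patched ordering: set up the process group first ...
    if not dist.is_initialized():
        # Single-process fallback so the sketch runs standalone; NeMo instead
        # launches a dummy job and calls trainer.strategy.setup_environment().
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group(backend="gloo", rank=0, world_size=1)
    # ... and only then build the sharded state dict and hand it to the
    # checkpoint writer (DistributedCheckpointIO in the real save path).
    sharded_state_dict = model.sharded_state_dict()
    print(f"would save {list(sharded_state_dict)} to {dist_ckpt_dir}")

if __name__ == "__main__":
    save_sharded(ToyShardedModel(), "/tmp/model_weights")

In the patched save_to, the same guarantee comes from running the launcher/setup_environment() block before the call, rather than from the explicit init_process_group used in this sketch.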