Skip to content

Commit

Permalink
Fix async_request attribute
Browse files Browse the repository at this point in the history
Signed-off-by: Mikołaj Błaż <mblaz@nvidia.com>
  • Loading branch information
mikolajblaz committed Apr 24, 2024
1 parent 9e16390 commit 99d515a
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions nemo/utils/callbacks/torch_dist_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class TorchDistAsyncSaveShardedStrategy(TorchDistSaveShardedStrategy):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.save_and_finalize_callbacks = None
self.async_request = None

def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
""" Translates MCore ShardedTensors to PyT ShardedTensors and saves in PyT Distributed format.
Expand All @@ -69,8 +69,8 @@ def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
None,
planner=MCoreSavePlanner(dedup_replicated_tensors=not self.keep_only_main_replica),
)
self.save_and_finalize_callbacks = self._get_save_and_finalize_callbacks(writer, save_state_dict_ret)
return self.save_and_finalize_callbacks
self.async_request = self._get_save_and_finalize_callbacks(writer, save_state_dict_ret)
return self.async_request

def _get_save_and_finalize_callbacks(self, writer, save_state_dict_ret):
save_fn_args = writer.get_save_function_and_args()
Expand Down

0 comments on commit 99d515a

Please sign in to comment.