Skip to content

Commit

Permalink
gather on rank 0 to avoid oom
Browse files Browse the repository at this point in the history
  • Loading branch information
KaiLv69 committed Aug 23, 2023
1 parent f434afa commit ac6eed4
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions collie/controller/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,10 +556,14 @@ def save_model(self, path: str, process_exclusion: bool = False,
self._checkpoint_prologue()
for name, param in self.engine.module.named_parameters():
with deepspeed.zero.GatheredParameters(param):
state_dict[name] = param.detach().cpu()
if env.dp_rank == 0:
state_dict[name] = param.detach().cpu()
self._checkpoint_epilogue()
else:
state_dict = self.engine.module.state_dict()
if env.dp_rank == 0:
state_dict = self.engine.module.state_dict()
else:
state_dict = {}
self.engine.module.save_parallel_state_dict(
state_dict=state_dict,
path=path,
Expand Down

0 comments on commit ac6eed4

Please sign in to comment.