Set tp world size to 1 in ckpt load, if MPU is not provided (microsoft#5243)

If MPU is not provided, set the tp world size to 1 when loading the
(universal) ckpt.
samadejacobs committed Mar 8, 2024
1 parent 74910a9 commit 535a908
Showing 1 changed file with 5 additions and 1 deletion.
deepspeed/runtime/zero/stage_1_and_2.py
@@ -2303,7 +2303,11 @@ def _load_hp_checkpoint_state(self, checkpoint_dir):
         self._load_global_state(optim_sd)
 
         tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
-        tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
+        if self.mpu is None:
+            logger.warn("MPU is not provided, setting tp size to 1 in checkpoint loading.")
+            tp_world_size = 1
+        else:
+            tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
             else self.mpu.get_tensor_model_parallel_world_size()
 
         for i, _ in enumerate(self.optimizer.param_groups):
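
For readers not following the full file, the snippet below is a minimal, standalone Python sketch of the fallback this commit introduces; it is not DeepSpeed's actual implementation. The names resolve_tp_world_size and FakeMPU are hypothetical and exist only for illustration; the MPU accessor methods and the warning text are taken from the diff above.

    import logging

    logger = logging.getLogger(__name__)


    def resolve_tp_world_size(mpu):
        # Mirrors the patched logic in _load_hp_checkpoint_state: with no MPU,
        # fall back to a tensor-parallel world size of 1 instead of failing on
        # attribute access against None.
        if mpu is None:
            logger.warning("MPU is not provided, setting tp size to 1 in checkpoint loading.")
            return 1
        # Prefer the slice-parallel accessor when the MPU exposes it, otherwise
        # use the tensor-model-parallel world size, as in the diff above.
        return mpu.get_slice_parallel_world_size() if hasattr(mpu, "get_slice_parallel_world_size") \
            else mpu.get_tensor_model_parallel_world_size()


    class FakeMPU:
        # Hypothetical MPU stand-in that exposes only the tensor-model-parallel
        # accessor, used here purely to exercise the fallback chain.
        def get_tensor_model_parallel_world_size(self):
            return 4


    print(resolve_tp_world_size(None))       # -> 1 (new behavior when no MPU is provided)
    print(resolve_tp_world_size(FakeMPU()))  # -> 4 (unchanged behavior with an MPU)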
