Merge branch 'fix_fp8_pp_hang_bug' into 'main'
Fix pipeline parallel hang under FP8

See merge request ADLR/megatron-lm!769
jaredcasper committed Sep 8, 2023
2 parents 5385e53 + 9a14c4c commit dc21350
Showing 1 changed file with 4 additions and 1 deletion.
megatron/core/transformer/transformer_block.py (4 additions, 1 deletion)

@@ -228,8 +228,11 @@ def forward(self, hidden_states, attention_mask, inference_params=None, rotary_p
                 amax_history_len=self.config.fp8_amax_history_len,
                 override_linear_precision=(False, False, not self.config.fp8_wgrad),
             )
+            fp8_group = None
+            if parallel_state.model_parallel_is_initialized():
+                fp8_group = parallel_state.get_amax_reduction_group()
             fp8_context = transformer_engine.pytorch.fp8_autocast(
-                enabled=True, fp8_recipe=fp8_recipe
+                enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group
             )
         else:
             fp8_context = nullcontext()
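
For context: Transformer Engine's fp8_autocast reduces the FP8 amax statistics across ranks when the autocast region exits, and its fp8_group argument controls which process group that all-reduce runs over. The sketch below illustrates the pattern the patch adopts; it is not the Megatron-LM code, and the process-group membership, recipe values, and tensor shapes are illustrative assumptions.

# Minimal sketch (assumptions, not the Megatron-LM code): pass an explicit
# process group to fp8_autocast so the FP8 amax all-reduce runs over a
# well-defined set of ranks instead of the default (world) group.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Launched via torchrun, so rank and world size come from the environment.
torch.distributed.init_process_group(backend="nccl")

fp8_recipe = recipe.DelayedScaling(
    margin=0,
    fp8_format=recipe.Format.HYBRID,  # E4M3 forward, E5M2 backward
    amax_history_len=16,
    amax_compute_algo="max",
)

# Illustrative two-rank group; Megatron-LM instead uses its own
# parallel_state.get_amax_reduction_group(), as in the diff above.
amax_group = torch.distributed.new_group(ranks=[0, 1])

layer = te.Linear(1024, 1024).cuda()
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=amax_group):
    out = layer(torch.randn(16, 1024, device="cuda"))

Scoping the reduction to a fixed group means every rank in that group enters the same collectives in the same order, which is presumably what resolves the hang this commit fixes: with no explicit group, ranks on different pipeline stages can reach the default-group all-reduce at different points and deadlock.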
