Skip to content

Commit 3bdcbbb

Browse files
shjwudpko3n1g
authored and committed
ADLR/megatron-lm!1923 - Fix DDP scaling factor with Context Parallel
Co-authored-by: Jianbin Chang <shjwudp@gmail.com>
1 parent ce2b519 commit 3bdcbbb

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

megatron/core/distributed/distributed_data_parallel.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import logging
44
from contextlib import contextmanager
5-
from typing import Dict, Optional
5+
from typing import Dict
66

77
import torch
88

@@ -114,7 +114,9 @@ def allocate_buffers_for_parameters(
114114
param_and_grad_dtype_to_params[(param_dtype, grad_dtype)] = params
115115

116116
if not config.calculate_per_token_loss:
117-
target_gradient_scaling_factor = 1.0 / parallel_state.get_data_parallel_world_size()
117+
target_gradient_scaling_factor = 1.0 / parallel_state.get_data_parallel_world_size(
118+
with_context_parallel=True
119+
)
118120
if self.ddp_config.average_in_collective:
119121
# Collective is averaging gradients in collective with data_parallel_group.
120122
assert (
@@ -155,7 +157,9 @@ def allocate_buffers_for_parameters(
155157
1.0 / parallel_state.get_expert_model_parallel_world_size()
156158
)
157159
else:
158-
data_parallel_world_size = parallel_state.get_data_parallel_world_size()
160+
data_parallel_world_size = parallel_state.get_data_parallel_world_size(
161+
with_context_parallel=True
162+
)
159163
gradient_scaling_factor = 1.0 / data_parallel_world_size
160164
expert_gradient_scaling_factor = 1.0 / data_parallel_world_size
161165

0 commit comments

Comments (0)