megatron/core/distributed: 1 file changed, +7 −3 lines

@@ -2,7 +2,7 @@
 
 import logging
 from contextlib import contextmanager
-from typing import Dict, Optional
+from typing import Dict
 
 import torch
 
@@ -114,7 +114,9 @@ def allocate_buffers_for_parameters(
                 param_and_grad_dtype_to_params[(param_dtype, grad_dtype)] = params
 
             if not config.calculate_per_token_loss:
-                target_gradient_scaling_factor = 1.0 / parallel_state.get_data_parallel_world_size()
+                target_gradient_scaling_factor = 1.0 / parallel_state.get_data_parallel_world_size(
+                    with_context_parallel=True
+                )
                 if self.ddp_config.average_in_collective:
                     # Collective is averaging gradients in collective with data_parallel_group.
                     assert (
@@ -155,7 +157,9 @@ def allocate_buffers_for_parameters(
                     1.0 / parallel_state.get_expert_model_parallel_world_size()
                 )
             else:
-                data_parallel_world_size = parallel_state.get_data_parallel_world_size()
+                data_parallel_world_size = parallel_state.get_data_parallel_world_size(
+                    with_context_parallel=True
+                )
                 gradient_scaling_factor = 1.0 / data_parallel_world_size
                 expert_gradient_scaling_factor = 1.0 / data_parallel_world_size
 
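
For context on the change: passing with_context_parallel=True makes the gradient scaling factor divide by the size of the data-parallel group including context-parallel ranks, rather than by the data-parallel size alone. A minimal arithmetic sketch of the before/after factors, using hypothetical sizes (dp=4, cp=2 are illustrative only, not taken from the patch):

# Sketch of the gradient scaling factor before and after this patch.
dp_size = 4  # hypothetical data-parallel world size, context parallel excluded
cp_size = 2  # hypothetical context-parallel world size

# Before: get_data_parallel_world_size() ignored context-parallel ranks.
old_factor = 1.0 / dp_size              # 0.25

# After: get_data_parallel_world_size(with_context_parallel=True) counts
# ranks in the combined data-parallel x context-parallel group.
new_factor = 1.0 / (dp_size * cp_size)  # 0.125

print(f"old={old_factor}, new={new_factor}")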