Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix new bucket when param require new bucket #762

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions megatron/core/distributed/param_and_grad_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,6 @@ def _create_new_bucket(data_end_index: int) -> int:
# and skip parameters that don't require gradients.
if not param.requires_grad:
continue
this_numel = param.data.nelement()
data_end_index = data_start_index + this_numel

def _does_param_require_new_bucket(param):
"""
Expand All @@ -301,11 +299,13 @@ def _does_param_require_new_bucket(param):
if _does_param_require_new_bucket(param) and len(bucket_params) > 0:
# We are creating a bucket for the already accumulated parameters, whose params
# end at the current data_start_index.
data_start_index = _create_new_bucket(data_start_index)
if use_distributed_optimizer:
# data_start_index should already be padded.
assert data_start_index % self.data_parallel_world_size == 0
_create_new_bucket(data_start_index)

this_numel = param.data.nelement()
data_end_index = data_start_index + this_numel
self.param_index_map[param] = (
data_start_index,
data_end_index,
Expand Down