
[draft] support async param gather in layer-wise optimizer#2787

Draft
FDecaYed wants to merge 1 commit intoNVIDIA:devfrom
FDecaYed:deyuf/layerwise_async_gather

Conversation


@FDecaYed (Contributor) commented Dec 31, 2025

What does this PR do ?

This is draft code, so it likely doesn't run. It's meant to demonstrate how the feature could be supported before a proper implementation.

Current Architecture

DistributedOptimizer implements async param gather with these components:

  • Bucket-based param organization - Parameters are grouped into buckets within _ParamAndGradBucketGroup
  • start_param_sync() - Launches async all-gather using torch.distributed.all_gather with async_op=True
  • finish_param_sync() - Waits for completion and dispatches the next bucket's all-gather
  • Forward pre-hooks - Registered on modules to call finish_param_sync() when a module needs its params
  • Even sharding is assumed, and the all-gather writes directly into a tensor (the param buffer)
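The pre-hook wiring described above can be sketched roughly as follows. The `Bucket` class and the `module_to_bucket` mapping here are toy stand-ins for `_ParamAndGradBucketGroup` and Megatron-LM's real bookkeeping; this is a sketch of the pattern, not the actual implementation:

```python
import torch
import torch.nn as nn

class Bucket:
    """Toy stand-in for _ParamAndGradBucketGroup: tracks one in-flight gather."""
    def __init__(self, name):
        self.name = name
        self.handle = None  # in the real code, an async all-gather work handle

    def start_param_sync(self):
        # Real code: torch.distributed.all_gather(..., async_op=True)
        self.handle = ("gather", self.name)

    def finish_param_sync(self):
        # Real code: self.handle.wait(), then dispatch the next bucket's gather
        self.handle = None

def register_param_sync_hooks(module_to_bucket):
    """Forward pre-hooks so a module blocks until its bucket's params arrive."""
    for module, bucket in module_to_bucket.items():
        module.register_forward_pre_hook(
            lambda mod, args, b=bucket: b.finish_param_sync()
        )
```

With this wiring, `start_param_sync()` can be launched well ahead of time and the wait is deferred to the moment the owning module actually runs its forward.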

Changes needed for layer-wise

  • Let each bucket know which param is on which rank, and group them into lists (lw_params_list, in addition to the params_list each bucket already holds; think of it as per-bucket layer-wise sharding)
  • Change all_gather_into_tensor to all_gather(_v)
  • Copy from the all_gather_v output into each param inside finish_param_sync()
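The three steps above might look roughly like this. It is a single-process gloo sketch so the collective actually runs; `lw_params_list` is a hypothetical list of per-rank parameter lists, and the real change would run with world_size > 1 and uneven shard sizes via an all_gather_v-style collective:

```python
import os
import torch
import torch.distributed as dist

# Single-process group just so the collective runs; real use is multi-rank.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
if not dist.is_initialized():
    dist.init_process_group("gloo", rank=0, world_size=1)

def start_bucket_gather(lw_params_list, rank):
    """Launch an async all_gather of this rank's flattened shard.

    lw_params_list[r] holds the params owned by rank r, so each output
    tensor can be sized per rank (the all_gather_v part of the change).
    """
    my_flat = torch.cat([p.detach().reshape(-1) for p in lw_params_list[rank]])
    outputs = [torch.empty(sum(p.numel() for p in params))
               for params in lw_params_list]
    handle = dist.all_gather(outputs, my_flat, async_op=True)
    return handle, outputs

def finish_bucket_gather(handle, outputs, lw_params_list):
    """Wait for the gather, then copy the flat outputs back into each param
    (the copy step that finish_param_sync() would gain)."""
    handle.wait()
    for params, flat in zip(lw_params_list, outputs):
        offset = 0
        for p in params:
            n = p.numel()
            p.data.copy_(flat[offset:offset + n].view_as(p))
            offset += n
```

Since each rank's output tensor is sized from that rank's own param list, the even-sharding assumption of `all_gather_into_tensor` is no longer needed.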

@FDecaYed FDecaYed requested review from a team as code owners December 31, 2025 07:17

copy-pr-bot bot commented Dec 31, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@FDecaYed FDecaYed marked this pull request as draft December 31, 2025 07:17
@github-actions github-actions bot requested a review from Phlip79 December 31, 2025 07:17
        async_op=async_op,
    )
else:
    assert async_op, "Layer-wise optimizer requires overlap_param_gather=True"
A reviewer (Contributor) commented:
Does async allgather with layer-wise still require use-distributed-optimizer? I think yes, to let DDP make the buckets, right?

@FDecaYed (Contributor, Author) replied Jan 4, 2026:
To keep the demo simple, I didn't touch that part. Currently, if use-distributed-optimizer is off, all async-related functionality is turned off and the code errors out.

But technically it is not required; we just need to change those checks from `if use-distributed-optimizer` to `if (use-distributed-optimizer or use-layer-wise)`.
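In other words, the gating might change along these lines (the attribute names are illustrative, not the actual Megatron-LM config fields):

```python
from types import SimpleNamespace

def async_machinery_enabled(config):
    # Before: async param gather is tied to the distributed optimizer.
    #   return config.use_distributed_optimizer
    # After: the layer-wise optimizer qualifies as well.
    return config.use_distributed_optimizer or config.use_layer_wise
```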

@mkhona-nvidia (Contributor) left a comment:

Another feature is that the current layerwise allows reshardable EP sizes in the middle of training. AFAIK, using adam's distopt does not allow this because of the way the DDP buckets are constructed. Can we avoid this and continue to allow reshardable EP sizes in the middle of training with layerwise?

pg_collection: Optional[ProcessGroupCollection] = None,
init_state_fn_list: Optional[List[Callable]] = None,
model_chunks: Optional[List] = None,
async_allgather: Optional[bool] = False,
A reviewer (Contributor) commented:
Reuse "overlap_param_gather" if the flag indicates the same thing, no need to introduce new names.


FDecaYed commented Jan 4, 2026

> Another feature is that the current layerwise allows reshardable EP sizes in the middle of training. AFAIK, using adam's distopt does not allow this because of the way the DDP buckets are constructed. Can we avoid this and continue to allow reshardable EP sizes in the middle of training with layerwise?

My rough feeling is that it should still be supported. The async feature just moves the param allgather from within optimizer.step into a forward pre-hook that calls the bucket function. Checkpointing should not be affected.

@mkhona-nvidia (Contributor) replied:

> Another feature is that the current layerwise allows reshardable EP sizes in the middle of training. AFAIK, using adam's distopt does not allow this because of the way the DDP buckets are constructed. Can we avoid this and continue to allow reshardable EP sizes in the middle of training with layerwise?

> my rough feeling is it should still be supported. the async feature is just moving param allgather from within optimizer.step into forward pre-hook calling bucket function. Checkpointing should not be affected

I guess the issue is that we don't have a good understanding of why AdamW distopt cannot allow resharding of EP in the middle of training. My feeling is that it is because of the way the DDP buckets are made (a separate bucketing scheme for EP=true and a separate scheme for the rest of the network). If the layerwise distopt overlap_param_gather uses the same scheme again, I think we will run into the same issue.


FDecaYed commented Jan 5, 2026

> Another feature is that the current layerwise allows reshardable EP sizes in the middle of training. AFAIK, using adam's distopt does not allow this because of the way the DDP buckets are constructed. Can we avoid this and continue to allow reshardable EP sizes in the middle of training with layerwise?

> my rough feeling is it should still be supported. the async feature is just moving param allgather from within optimizer.step into forward pre-hook calling bucket function. Checkpointing should not be affected

> I guess the issue is that we don't have a good understanding of why adamW distopt cannot allow resharding of EP in the middle of training. My feeling is that it is because of the way the DDP buckets are made (separate bucketing scheme for EP=true and separate bucketing scheme for rest of the network). If the layerwise distopt overlap_param_gather uses the same scheme again, I think we will run into the same issue

I think the EP buckets are separate from the regular DP weights whether dist-opt is on or not, and async-gather doesn't affect that either.
Here is my guess at why dist-opt doesn't support changing EP (I haven't checked the code recently, so this is from memory):
There are two views of the param buffer: the buffer view (everything flattened) and the param view (a list of params with their .data backed by the contiguous buffer). The way dist-opt supports torch_dist checkpointing is to save the 'buffer view' as-is, which is tightly coupled with EP and cannot be changed. But layerwise saves by each parameter (the param view), so it is not affected.
And the reason dist-opt has to save it that way is that each master parameter/state can be sharded across multiple ranks; this is not a problem for layerwise.
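The two views can be illustrated with a toy snippet (the shapes and buffer size are made up; the point is only that the param view aliases the flat buffer, which is why checkpointing the buffer bakes in the layout while checkpointing per parameter does not):

```python
import math
import torch

# "Buffer view": one flat, contiguous tensor per bucket.
buffer = torch.zeros(10)

# "Param view": parameters whose .data are slices of that same buffer.
shapes = [(2, 3), (4,)]
params, offset = [], 0
for shape in shapes:
    n = math.prod(shape)
    params.append(torch.nn.Parameter(buffer[offset:offset + n].view(shape)))
    offset += n

# Writes through the param view land in the flat buffer: dist-opt checkpoints
# the flat buffer as-is (layout-coupled), while layer-wise saves each
# parameter separately and so tolerates a different EP layout on restore.
params[0].data.fill_(1.0)
```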

mchrzanowski added a commit to mchrzanowski/Megatron-LM that referenced this pull request Feb 20, 2026
Integrate async param all-gather from upstream PR NVIDIA#2787 so that
dist_muon/dist_mop can overlap parameter all-gather with forward
compute via DDP's existing bucket and forward-pre-hook infrastructure.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>