Conversation


@erictang000 erictang000 commented Oct 15, 2025

Blocked on #453.

After the Megatron-Bridge migration, we can now gather parameters in buckets with precomputed size metadata, instead of iterating through parameters one by one. We then flatten each bucket into a single tensor, send over the metadata, and recover the weights/shapes on the inference engine side.
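The pack/unpack scheme described above can be sketched as follows. This is a hypothetical, framework-free illustration, not the PR's actual implementation: plain Python lists stand in for CUDA tensors, and the function names (`pack_bucket`, `unpack_bucket`) are assumptions.

```python
# Hypothetical sketch of the bucketed pack/unpack scheme. Plain Python
# lists stand in for CUDA tensors; in the real code the flat buffer would
# be a single torch tensor shared with the inference engine over CUDA IPC.
from math import prod

def pack_bucket(named_params):
    """Flatten a bucket of {name: (shape, flat_data)} into one buffer
    plus the metadata needed to recover each weight on the other side."""
    names, shapes, sizes, flat = [], [], [], []
    for name, (shape, data) in named_params.items():
        assert len(data) == prod(shape)
        names.append(name)
        shapes.append(shape)
        sizes.append(prod(shape))
        flat.extend(data)
    return flat, {"names": names, "shapes": shapes, "sizes": sizes, "packed": True}

def unpack_bucket(flat, meta):
    """Recover (name, shape, data) triples by slicing the flat buffer
    at the precomputed offsets."""
    weights, offset = [], 0
    for name, shape, size in zip(meta["names"], meta["shapes"], meta["sizes"]):
        weights.append((name, shape, flat[offset : offset + size]))
        offset += size
    assert offset == len(flat)  # every element accounted for
    return weights
```

A round trip through `pack_bucket` and `unpack_bucket` recovers every weight exactly; the win in the PR is that only one buffer (plus small metadata) crosses the process boundary instead of one transfer per parameter.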

For Qwen3-30B-A3B, tp=2, ep=8, etp=1, 8xh100:

Before (48s):
[screenshot]

After bucketed gather (45s):
[screenshot]

After flattening CUDA IPC (and removing redundant torch.device() calls) - 6s:
[screenshot]

@erictang000 erictang000 changed the title [megatron] improving weight syncing part 1 - bucketed parameter gathering [megatron] improving weight syncing - bucketed param gather + cuda ipc flattening Oct 16, 2025
@erictang000 erictang000 marked this pull request as ready for review November 21, 2025 18:28
Comment on lines 34 to 36
sizes: List[int]
extras: Optional[List[Dict[str, Any]]]
packed: bool
Member:

The current PR will break compatibility for FSDP and DeepSpeed since they don't send these arguments.

You should probably make "sizes" a NotRequired field, and update the FSDP and DeepSpeed worker files to send packed=False.

Collaborator Author:

good catch, updated and tested both

request["names"],
request["dtypes"],
request["shapes"],
request["sizes"],
Member:

This should be an optional entry. Use the safer .get.

Collaborator Author:

done
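For illustration, a hedged sketch of the safer access pattern being asked for (the field names mirror the snippet above; the function name and defaults are assumptions):

```python
# Hypothetical sketch: use .get so requests from FSDP/DeepSpeed, which
# don't send "sizes" or "packed", don't raise KeyError.
def read_request(request):
    names = request["names"]              # always present
    dtypes = request["dtypes"]
    shapes = request["shapes"]
    sizes = request.get("sizes")          # None when the sender didn't pack
    packed = request.get("packed", False) # default to the unpacked path
    return names, dtypes, shapes, sizes, packed
```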

names: List[str]
dtypes: List[str]
shapes: List[List[int]]
sizes: Optional[List[int]]
Member:
Nit: this should be NotRequired. Optional means the field exists but may be None.

from typing import NotRequired

shapes: List[List[int]]
sizes: Optional[List[int]]
extras: Optional[List[Dict[str, Any]]]
packed: Optional[bool]
Member:

Same for this


offset = 0
for name, shape, size in zip(names, shapes, sizes):
weight_list.append((name, packed_tensor[offset : offset + size].view(*shape)))
Member:

dtype assert for this?
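A sketch of the kind of guard being asked for, using plain lists in place of tensors; in the real code the checks would compare `packed_tensor.dtype` and `packed_tensor.numel()` against the metadata. The function name and signature here are illustrative, not from the PR.

```python
# Hypothetical guard for the unpacking loop: verify the metadata is
# consistent with the flat buffer before slicing, so a dtype or size
# mismatch fails loudly instead of silently mis-slicing weights.
from math import prod

def unpack_checked(packed, names, shapes, sizes, dtype, expected_dtype):
    # In torch this would be: assert packed_tensor.dtype == expected_dtype
    assert dtype == expected_dtype, f"dtype mismatch: {dtype} != {expected_dtype}"
    # And: assert packed_tensor.numel() == sum(sizes)
    assert len(packed) == sum(sizes), "metadata sizes don't cover the buffer"
    weight_list, offset = [], 0
    for name, shape, size in zip(names, shapes, sizes):
        assert prod(shape) == size, f"shape/size mismatch for {name}"
        weight_list.append((name, packed[offset : offset + size]))
        offset += size
    return weight_list
```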

@SumanthRH (Member) left a comment:

I've left some minor comments. Please take a look.

@erictang000 erictang000 merged commit 8c7ba7e into NovaSky-AI:main Nov 22, 2025
3 checks passed
@erictang000 erictang000 deleted the weight_syncing branch November 22, 2025 00:06
li-boxuan pushed a commit to li-boxuan/SkyRL that referenced this pull request Nov 23, 2025
…c flattening (NovaSky-AI#487)