[FSDP] Support Context parallelism for FSDP using ring-flash-attn#467
[FSDP] Support Context parallelism for FSDP using ring-flash-attn#467zhuzilin merged 18 commits intoTHUDM:mainfrom
Conversation
a8ad2ec to
4948643
Compare
88afbd4 to
f32b57f
Compare
653578b to
6d01709
Compare
35c2578 to
346abce
Compare
|
Really great process and hope you learned a lot from this process. Shall we post a blog on the journey of CP in awesome-ml-sys? Also, your job opportunity shall always be the first. Really glad to see your resolution and great improvement. Hope for the best. |
|
🐂🍺 |
a964ed1 to
a41f7c1
Compare
slime/backends/fsdp_utils/actor.py
Outdated
| world_size = dist.get_world_size() | ||
| rank = dist.get_rank() | ||
|
|
||
| if self.args.enable_cp: |
There was a problem hiding this comment.
we can use the self.args.context_parallel_size directly. And we don't need to separate the mesh init for cp size > 1.
slime/backends/fsdp_utils/actor.py
Outdated
| ) | ||
| logits = self.model(**model_args).logits.squeeze(0) | ||
| if self.args.enable_cp: | ||
| log_probs_result, entropy_result = get_chunked_logp_and_entropy( |
There was a problem hiding this comment.
please merge the with and without cp implemtation into one.
slime/backends/fsdp_utils/actor.py
Outdated
| rank = dist.get_rank() | ||
|
|
||
| rollout_data = process_rollout_data(self.args, rollout_data_ref, rank, world_size) | ||
| dp_rank = self.dp_rank if self.args.enable_cp else rank |
There was a problem hiding this comment.
we can always use dp_rank.
slime/backends/fsdp_utils/actor.py
Outdated
| ).logits.squeeze(0) | ||
|
|
||
| # Gather logits from all CP ranks if CP is enabled (with gradient support) | ||
| if self.args.enable_cp: |
There was a problem hiding this comment.
similar as the comment above, please merge the 2 branch.
| if tokens[idx].item() == 0: | ||
| pad_length += 1 | ||
| else: | ||
| break |
There was a problem hiding this comment.
I think that we can re-calculate the pad length instead of a for loop
|
Thank you so much for this! |
|
牛逼! |
…UDM#467) Co-authored-by: Hecate0821 <hec4te0821@gmail.com>
…UDM#467) Co-authored-by: Hecate0821 <hec4te0821@gmail.com>

Try to solve #294 using ring-flash-attn with datapacking
How to use
Detailed Development Tracking
Result
Compare to main branch


Almost match with main branch
script changed: