Add Context parallelism to Wan 2.1 #200

Merged
entrpn merged 24 commits into main from wan_context_parallelism_inference on Jul 15, 2025

Add Context parallelism to Wan 2.1#200
entrpn merged 24 commits intomainfrom
wan_context_parallelism_inference

Conversation

entrpn (Collaborator) commented on Jul 9, 2025

  • Adds context parallelism to the flash attention fn.
  • Adds better sharding constraints when reshaping and padding to reduce all-gathers (AGs); a rough sketch of this idea follows below.
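
A minimal JAX sketch of the sharding-constraint idea above (not the PR's actual code; the mesh axis names, tensor shapes, and helper function are illustrative assumptions):

```python
# Sketch: keep attention inputs sharded along the sequence axis across a
# "context" mesh axis so the compiler does not all-gather the full sequence.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import mesh_utils

devices = mesh_utils.create_device_mesh((1, len(jax.devices())))
mesh = Mesh(devices, axis_names=("data", "context"))

def shard_for_context_parallel_attention(q, k, v):
  # Constrain [batch, seq, heads, head_dim] so seq is split on "context".
  spec = NamedSharding(mesh, P("data", "context", None, None))
  return tuple(jax.lax.with_sharding_constraint(x, spec) for x in (q, k, v))

batch, seq, heads, head_dim = 2, 1024, 8, 64
q = k = v = jnp.zeros((batch, seq, heads, head_dim))
q, k, v = jax.jit(shard_for_context_parallel_attention)(q, k, v)
```

A full context-parallel flash attention additionally has to exchange K/V blocks across the "context" axis (e.g. ring-style); the constraint above only keeps activations sequence-sharded so reshape and pad steps do not trigger all-gathers.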

entrpn requested a review from susanbao on Jul 9, 2025 20:29
Comment thread on src/maxdiffusion/max_utils.py (outdated):
```diff
 if multi_slice_env:
   dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, "DCN")
-  mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices)
+  mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices, allow_split_physical_axes=config.allow_split_physical_axes)
```
A collaborator commented:

I also discovered that Cloud TPUs don't have a "slice_index" attribute, which is used in create_hybrid_device_mesh and also a couple of lines up to determine whether we're on DCN or not. Maybe we should pass process_is_granule=True to create_hybrid_device_mesh.
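A minimal sketch of that suggestion (illustrative parallelism shapes, not maxdiffusion's actual config): create_hybrid_device_mesh accepts process_is_granule=True, which groups devices by host process instead of by the slice_index attribute, and the allow_split_physical_axes flag from the diff above can be passed alongside it:

```python
# Sketch: build a hybrid ICI/DCN mesh on hosts whose devices lack
# slice_index, treating each host process as the DCN granule.
import jax
from jax.experimental import mesh_utils

ici_parallelism = [1, len(jax.local_devices()), 1]  # within-granule shape (illustrative)
dcn_parallelism = [jax.process_count(), 1, 1]       # one granule per process

devices = mesh_utils.create_hybrid_device_mesh(
    ici_parallelism,
    dcn_parallelism,
    jax.devices(),
    process_is_granule=True,  # group by process id instead of slice_index
    allow_split_physical_axes=True,
)
```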

Comment thread on src/maxdiffusion/models/attention_flax.py
coolkp (Collaborator) commented on Jul 14, 2025

Awesome PR! Excited for the perf gains!

@entrpn entrpn merged commit 4a6f807 into main Jul 15, 2025
2 of 3 checks passed