
fix: enable ulysses sharding for custom kernels and improve scaling precision#396

Merged
copybara-service[bot] merged 1 commit into main from fix-ulysses-custom on May 4, 2026

Conversation

Collaborator

@Perseus14 Perseus14 commented May 3, 2026

Description

This PR introduces three small but important fixes to the Ulysses attention implementation (illustrative sketches follow the list):

  1. Enable Ulysses sharding for custom kernels: In src/maxdiffusion/pyconfig.py, the check for Ulysses attention is changed from attention == "ulysses" to "ulysses" in attention. This ensures that custom attention implementations that include "ulysses" in their identifier (e.g., custom_ulysses) will correctly trigger Ulysses sequence sharding instead of falling back to default sharding strategies.
  2. Ensure correct padding: In src/maxdiffusion/models/attention_flax.py, padding was applied to the query variable, but the unpadded query_scaled was then used in the attention calculation. This fix ensures the padded variable is the one passed to the attention calculation.
  3. Improve scaling precision: In src/maxdiffusion/models/attention_flax.py, the hardcoded constant 1.44269504 used to scale queries for base-2 exponentiation is replaced with math.log2(math.e). This provides better precision and makes the intent of the code clearer.

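For illustration only, here is a minimal Python sketch of the shape of all three changes. It is not the actual maxdiffusion code; the helper pad_to_multiple, the block_size argument, and the variable names are assumptions made for this example.

```python
import math

import jax.numpy as jnp


# (1) Substring check, so custom kernels whose identifier contains "ulysses"
#     (e.g. "custom_ulysses") also trigger Ulysses sequence sharding.
#     Previously the config check required attention == "ulysses".
def uses_ulysses_sharding(attention: str) -> bool:
    return "ulysses" in attention


# (3) Base-2 scaling constant expressed exactly instead of the hardcoded
#     1.44269504 literal: exp(x) == 2 ** (x * log2(e)).
LOG2_E = math.log2(math.e)


# Hypothetical padding helper, assumed for this sketch.
def pad_to_multiple(x: jnp.ndarray, multiple: int, axis: int) -> jnp.ndarray:
    """Zero-pad `x` along `axis` up to the next multiple of `multiple`."""
    pad = (-x.shape[axis]) % multiple
    if pad == 0:
        return x
    widths = [(0, 0)] * x.ndim
    widths[axis] = (0, pad)
    return jnp.pad(x, widths)


# (2) Scale the query, then pad it, and return the *padded* tensor, so the
#     attention kernel never receives the unpadded query_scaled by mistake.
def scale_and_pad_query(query: jnp.ndarray, head_dim: int, block_size: int) -> jnp.ndarray:
    query_scaled = query * (LOG2_E / math.sqrt(head_dim))
    return pad_to_multiple(query_scaled, block_size, axis=-2)
```

The log2(e) factor is the usual base-2 exponentiation trick: folding it into the query scale lets the kernel compute softmax with exp2 instead of exp, which is cheaper on most accelerators.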
Generation Time

  • Main: ~420s
  • Branch: ~140s

@Perseus14 Perseus14 requested a review from entrpn as a code owner May 3, 2026 20:27
@Perseus14 Perseus14 changed the title from "Fix: Ulysses custom attention now uses ulysses sharding" to "fix: enable ulysses sharding for custom kernels and improve scaling precision" May 4, 2026
@Perseus14 Perseus14 requested a review from eltsai May 4, 2026 06:56
Collaborator

eltsai commented May 4, 2026

Thanks for fixing this @Perseus14! I tested on v6e (I don't have v7x for now due to capacity) and the generation time is 204 sec:

==================================================
  TIMING SUMMARY
==================================================
  Load (checkpoint):      97.3s
  Compile:               229.8s
  ────────────────────────────────────────
  Inference:             204.0s
  Conditioning:           11.8s
  Denoise Total:         183.6s
  VAE Decode:              8.6s
==================================================

I think the generation time is expected, because on v7x we are seeing a 28% speed boost (140 vs 194.4). From go/wan-dashboard the e2e gen time is 322 sec, so the speed gain is about 36%.
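For reference, a quick back-of-the-envelope check of those percentages (simple arithmetic on the numbers quoted above; this assumes the 204s v6e inference time is what is being compared against the 322s dashboard figure):

```python
# v7x denoise/generation: 194.4s on main vs 140s on this branch.
v7x_gain = (194.4 - 140) / 194.4   # ~0.28 -> ~28%

# e2e vs the go/wan-dashboard baseline of 322s, using the 204s measured above.
e2e_gain = (322 - 204) / 322       # ~0.37 -> ~36-37%

print(f"v7x: {v7x_gain:.1%}, e2e vs dashboard: {e2e_gain:.1%}")
```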

copybara-service[bot] merged commit 71b4138 into main May 4, 2026
13 checks passed
