[LTX2] resolve flash attention block size mismatch and missing config overrides#382

Merged
copybara-service[bot] merged 1 commit into main from ltx2_bugfix on Apr 21, 2026
Conversation


@mbohlool mbohlool commented Apr 17, 2026

What does this PR do?
This PR fixes a crash occurring in LTX-2 when using specific frame counts (e.g., num_frames=121) with Flash Attention, and fixes a pipeline bug that prevented users from manually overriding the cross-attention kernels.

The Root Causes & Fixes:

  1. Ignored Kernel Config Overrides (ltx2_pipeline.py)

Bug: Passing a2v_attention_kernel=dot_product via CLI or YAML had no effect. The pipeline was only mapping the main attention config, dropping the cross-attention kernel parameters before initializing the transformer.

Fix: Added mapping for a2v_attention_kernel and v2a_attention_kernel into ltx2_config inside create_sharded_logical_transformer so user overrides are respected.
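The mapping described above can be sketched as follows. This is an illustrative reconstruction, not the actual `create_sharded_logical_transformer` code: the helper name `build_ltx2_config`, the dict shape, and the `"flash"` fallback default are assumptions based on the PR description; only the key names `a2v_attention_kernel` and `v2a_attention_kernel` come from the PR.

```python
# Hypothetical sketch of the fix: propagate the cross-attention kernel
# overrides from the user config into ltx2_config instead of letting
# them silently fall back to the hardcoded default.
def build_ltx2_config(user_config: dict) -> dict:
    ltx2_config = {
        # The main attention kernel was already mapped before the fix.
        "attention_kernel": user_config.get("attention_kernel", "flash"),
    }
    # The fix: also map the audio<->video cross-attention kernels so
    # CLI/YAML overrides like a2v_attention_kernel=dot_product take effect.
    for key in ("a2v_attention_kernel", "v2a_attention_kernel"):
        ltx2_config[key] = user_config.get(key, "flash")
    return ltx2_config

cfg = build_ltx2_config({"a2v_attention_kernel": "dot_product"})
# cfg["a2v_attention_kernel"] is now "dot_product" instead of the
# previous behavior, where the override was dropped.
```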

  2. Flash Attention Block Size Mismatch (attention_flax.py)

Bug: Generating 121 frames results in 126 audio latent tokens. A previous PR correctly padded this sequence from 126 to 128 to satisfy shard_map context chunking requirements. However, _tpu_flash_attention was calling _select_flash_block_sizes before the padding occurred. Because of the min() bounds used for cross-attention optimization, the block size was calculated as 126. This resulted in passing a padded sequence of 128 to the Splash Attention kernel but telling it to use a block size of 126, crashing because 128 % 126 != 0.

Fix: Swapped the order of operations in _tpu_flash_attention. Sequences are now padded by _reshape_data_for_flash before block sizes are calculated. This ensures _select_flash_block_sizes sees the padded shape, correctly calculating a divisible block size without removing the memory optimizations needed for the cross-attention unit tests to pass.
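The ordering bug can be demonstrated with a minimal sketch. This is not the maxdiffusion implementation; `pad_to_multiple`, `select_block_size`, and the `max_block` bound are simplified stand-ins for `_reshape_data_for_flash` and `_select_flash_block_sizes`, chosen only to reproduce the 126-vs-128 divisibility failure from the PR description.

```python
# Illustrative sketch of why padding must happen before block-size selection.
def pad_to_multiple(seq_len: int, multiple: int = 128) -> int:
    # Round the sequence length up to the next multiple (shard_map
    # context chunking requires this).
    return ((seq_len + multiple - 1) // multiple) * multiple

def select_block_size(seq_len: int, max_block: int = 512) -> int:
    # Mirrors the min() bound used for the cross-attention memory
    # optimization: never pick a block larger than the sequence itself.
    return min(max_block, seq_len)

kv_len = 126  # 121 frames -> 126 audio latent tokens

# Buggy order: block size computed from the unpadded length.
buggy_block = select_block_size(kv_len)   # 126
padded_len = pad_to_multiple(kv_len)      # 126 padded up to 128
assert padded_len % buggy_block != 0      # 128 % 126 != 0 -> kernel crash

# Fixed order: pad first, then select the block size from the padded shape.
fixed_block = select_block_size(padded_len)  # 128
assert padded_len % fixed_block == 0         # divisible, kernel accepts it
```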

Testing
Verified that running generation with num_frames=121 executes cleanly on TPU with Flash Attention enabled.

Verified pytest src/maxdiffusion/tests/attention_test.py passes.

@mbohlool mbohlool requested a review from entrpn as a code owner April 17, 2026 21:08

…nfig overrides

This commit addresses two issues in the LTX-2 pipeline:

1. Pipeline Config Overrides:
Fixed a bug in `ltx2_pipeline.py` where `a2v_attention_kernel` and `v2a_attention_kernel` configurations were ignored. The model previously hardcoded a fallback to "flash" because these values were not mapped from the user config to `ltx2_config`.

2. Flash Attention Padding Mismatch:
Fixed a `ValueError` (e.g., `kv_block_size=126 should divide kv_seq_len=128`) in `attention_flax.py` that occurred for specific video frame counts. A previous fix padded sequences to satisfy `shard_map` context dimension requirements, but `_select_flash_block_sizes` was calculating block sizes based on the unpadded length. Moved the block size calculation to occur *after* `_reshape_data_for_flash` so that the dynamic `min()` bounds correctly align with the newly padded sequence lengths, keeping cross-attention optimizations intact and unit tests passing.
@copybara-service copybara-service Bot merged commit 0b6410b into main Apr 21, 2026
18 checks passed
