
[None][feat] Add Attention2D sequence parallelism for visual-gen models #12943

Merged

NVShreyas merged 15 commits into NVIDIA:main from venmugil:attn2d on May 4, 2026

Conversation

@venmugil (Collaborator) commented Apr 11, 2026:

Summary by CodeRabbit

Release Notes

  • New Features

    • Added Attention2DAttention, a 2D-mesh context parallelism strategy for visual-gen diffusion models (FLUX, Wan, LTX2), based on https://arxiv.org/pdf/2503.15758
    • New dit_attn2d_row_size / dit_attn2d_col_size config options, exposed as --attn2d_row_size / --attn2d_col_size in the example CLIs; default 1 (disabled)
    • Enhanced log-sum-exp output in the Flash Attention 4 backend to support partial-output combining across the row process group
  • Improvements

    • Unified Attention2D under cp_size alongside ring attention — both are context parallelism strategies, mutually exclusive, sharing the same cp_size slot in VisualGenMapping
    • Validated mutual exclusivity: Ulysses + Attention2D raises NotImplementedError; Attention2D + ring raises ValueError
  • Tests

    • Added comprehensive distributed multi-GPU tests for Attention2D functionality and mesh variants

Description

Adds Attention2DAttention, a 2D-mesh sequence parallelism strategy for visual-gen diffusion models (FLUX, Wan, LTX2), based on the approach described in arxiv.org/pdf/2503.15758.

Unlike Ulysses, which requires the number of attention heads to be divisible by the parallelism degree, Attention2D distributes the sequence across a row_size × col_size process group mesh with no head-divisibility constraint. It also has asymptotically lower communication cost than ring attention: each GPU holds only a 1/P shard of the sequence, and on a balanced mesh the per-GPU communication volume scales as O(N/√P) rather than O(N).
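
A quick back-of-the-envelope count (our arithmetic, not a figure quoted in the PR): with sequence length N sharded across P = row_size × col_size ranks, the Q all-gather moves roughly N·(row_size − 1)/P ≈ N/col_size tokens per GPU, the fused K/V all-gather roughly 2N/row_size, and the output combine another O(N/col_size). Minimizing the total subject to row_size · col_size = P lands at row_size ≈ col_size ≈ √P, hence the O(N/√P) per-GPU volume; ring attention, by contrast, streams all N tokens of K/V through every GPU.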

Communication pattern (a code sketch follows the list):

  • Q tokens: all-gathered across the row group
  • K/V tokens: fused into a single all-gather across the col group
  • Partial outputs: combined via flash_attn_combine (LSE-weighted reduction) across the row group
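
For concreteness, below is a minimal PyTorch sketch of this pattern, assuming a [B, H, S, D] tensor layout, pre-built row/col process groups, and a mesh arrangement in which the key/value blocks seen across a row group partition the full sequence. It is illustrative only: the PR's Attention2DAttention delegates to an inner backend's forward_with_lse and to flash_attn_combine, which are stood in for here by plain matmul attention and a manual LSE-weighted merge.

import torch
import torch.distributed as dist


def attn2d_forward(q, k, v, row_group, col_group):
    """Sketch of 2D-mesh attention; q/k/v are local shards [B, H, S, D]."""
    r = dist.get_world_size(group=row_group)
    c = dist.get_world_size(group=col_group)

    # 1) All-gather Q shards across the row group.
    q_parts = [torch.empty_like(q) for _ in range(r)]
    dist.all_gather(q_parts, q.contiguous(), group=row_group)
    q_full = torch.cat(q_parts, dim=2)                  # [B, H, r*S, D]

    # 2) Fuse K and V so a single all-gather (one NCCL call) moves both.
    kv = torch.cat([k, v], dim=2).contiguous()
    kv_parts = [torch.empty_like(kv) for _ in range(c)]
    dist.all_gather(kv_parts, kv, group=col_group)
    ks, vs = zip(*(p.chunk(2, dim=2) for p in kv_parts))
    k_full, v_full = torch.cat(ks, dim=2), torch.cat(vs, dim=2)

    # 3) Local attention over the gathered shards, keeping the LSE.
    scores = (q_full @ k_full.transpose(-2, -1)) * q.shape[-1] ** -0.5
    lse = torch.logsumexp(scores, dim=-1)               # [B, H, r*S]
    partial = scores.softmax(dim=-1) @ v_full           # partial output

    # 4) LSE-weighted combine across the row group: every member holds
    #    the same gathered queries but attended a different K/V block.
    outs = [torch.empty_like(partial) for _ in range(r)]
    lses = [torch.empty_like(lse) for _ in range(r)]
    dist.all_gather(outs, partial.contiguous(), group=row_group)
    dist.all_gather(lses, lse.contiguous(), group=row_group)
    lse_total = torch.logsumexp(torch.stack(lses), dim=0)
    out = sum((l - lse_total).exp().unsqueeze(-1) * o
              for l, o in zip(lses, outs))

    # Keep only this rank's original query shard (reduce-scatter effect).
    i, s = dist.get_rank(group=row_group), q.shape[2]
    return out[:, :, i * s:(i + 1) * s]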

Design:

  • Attention2D is unified under cp_size in VisualGenMapping, the same slot used by ring attention — the two are mutually exclusive context parallelism strategies
  • Controlled by dit_attn2d_row_size / dit_attn2d_col_size (default 1, i.e., disabled); the backend-wrapping logic is sketched after this list
  • Composable with CFG parallelism
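
For how this plugs into the attention module (see also the Attention Module Integration summary further down this page), here is a minimal selection sketch. The function name wrap_attention_backend and the constructor signatures are assumptions for illustration; only the selection condition and the support_lse() requirement come from the PR.

# Hypothetical wiring sketch: constructor signatures are assumed, not
# taken from the PR; only the selection rule mirrors its description.
from tensorrt_llm._torch.visual_gen.attention_backend import (
    Attention2DAttention, UlyssesAttention)


def wrap_attention_backend(inner, cfg, row_pg, col_pg, ulysses_pg):
    attn2d_size = cfg.dit_attn2d_row_size * cfg.dit_attn2d_col_size
    if attn2d_size > 1:
        # Attention2D requires an inner backend that can return LSE values.
        assert inner.support_lse(), "inner backend must support LSE"
        return Attention2DAttention(inner, row_pg, col_pg)
    if cfg.dit_ulysses_size > 1:
        return UlyssesAttention(inner, ulysses_pg)
    return inner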

Test Coverage

  • New file: tests/unittest/_torch/visual_gen/multi_gpu/test_attn2d_attention.py
    • Multi-GPU tests covering 2×2, 1×4, and 4×1 mesh configurations with both vanilla and FlashAttention4 inner backends
    • Single-GPU smoke test for FlashAttn4Attention.forward
  • Extended: tests/unittest/_torch/visual_gen/test_visual_gen_args.py
    • ParallelConfig attn2d fields, seq_parallel_size computation for attn2d and Ulysses modes
  • Extended: tests/unittest/_torch/visual_gen/multi_gpu/test_visual_gen_mapping.py
    • attn2d size/rank/group construction tests; constructor validation tests (ulysses+attn2d, cp_size mismatch, ring+attn2d exclusivity); multi-GPU CP collective test verifying attn2d_mesh_rank == cp_rank across all 4 ranks

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

@chang-l (Collaborator) commented Apr 11, 2026:

Hi @NVShreyas , can you please help review this PR if possible?

@svc-trtllm-gh-bot added the "Community want to contribute" label (PRs initiated from Community) on Apr 11, 2026
@venmugil marked this pull request as ready for review on April 11, 2026 04:47
@venmugil requested review from a team as code owners on April 11, 2026 04:47
@venmugil requested review from arysef and kaiyux on April 11, 2026 04:47
@coderabbitai (Bot) commented Apr 11, 2026:

📝 Walkthrough

This pull request introduces Attention2D-based sequence parallelism as an alternative to Ulysses sequence parallelism for visual generation models. It adds CLI arguments for row/column sizing, implements the Attention2DAttention backend, updates configuration and distributed setup logic, refactors existing models to use generic sequence-parallel naming, and includes comprehensive distributed tests.

Changes

• CLI Argument Additions (examples/visual_gen/visual_gen_flux.py, examples/visual_gen/visual_gen_ltx2.py, examples/visual_gen/visual_gen_wan_*.py):
  Added --attn2d_row_size and --attn2d_col_size CLI arguments with mutual-exclusivity messaging vs --ulysses_size. Updated startup logging to report the derived seq_parallel string (Ulysses/Attention2D/None) instead of the raw ulysses_size; runtime validation was added only in the wan variants.
• Attention Backend Exports (tensorrt_llm/_torch/visual_gen/attention_backend/__init__.py):
  Exported Attention2DAttention from the .parallel submodule; added it to __all__ alongside the existing backends.
• FlashAttention4 Support for LSE (tensorrt_llm/_torch/visual_gen/attention_backend/flash_attn4.py):
  Refactored to compute log-sum-exp (LSE) values; _fwd now returns an (output, lse) tuple. Added a _prepare_inputs helper for validation and dtype casting. Introduced a forward_with_lse() method and a support_lse classmethod (returns True).
• Attention2D Implementation (tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py):
  Introduced the new Attention2DAttention backend supporting a 2D process-group mesh (row/column). It gathers Q across the row group and K/V across the column group, then combines partial outputs via LSE-based combination with a FlashAttention kernel. Requires an inner backend with support_lse() and head metadata.
• Parallelism Configuration (tensorrt_llm/_torch/visual_gen/config.py, tensorrt_llm/_torch/visual_gen/parallelism.py):
  Added dit_attn2d_row_size/dit_attn2d_col_size fields to ParallelConfig; a new seq_parallel_size property selects the active parallelism degree (sketched after this table). Added attn2d_*_process_group fields to DiffusionModelConfig. Refactored setup_sequence_parallelism with _setup_ulysses/_setup_attn2d helpers for mutually exclusive strategy setup.
• Model Sequence-Parallel Refactoring (tensorrt_llm/_torch/visual_gen/models/flux/transformer_flux*.py, tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py, tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py, tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py):
  Replaced Ulysses-specific state (use_ulysses, ulysses_size, ulysses_pg, ulysses_rank) with generic use_seq_parallel/seq_parallel_* equivalents. Updated forward passes, gather/all-gather operations, and divisibility checks; updated comments to use "sequence parallelism" terminology.
• Attention Module Integration (tensorrt_llm/_torch/visual_gen/modules/attention.py, tensorrt_llm/_torch/visual_gen/pipeline.py):
  Integrated dual strategy detection: the backend is wrapped with Attention2DAttention when dit_attn2d_row_size * dit_attn2d_col_size > 1, otherwise with UlyssesAttention if dit_ulysses_size > 1. Updated CFG parallel denoising to use seq_parallel_size for gather indexing.
• Distributed Test Suite (tests/unittest/_torch/visual_gen/multi_gpu/test_attn2d_attention.py, tests/unittest/_torch/visual_gen/test_visual_gen_args.py):
  Added multi-GPU distributed tests for Attention2DAttention correctness, mesh variants (1x4, 4x1, 2x2), invalid mask rejection, and integration with FlashAttn4Attention. Added unit tests for ParallelConfig mesh sizing, validation, and the seq_parallel_size property across various configurations.
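
To pin down the configuration surface described in the Parallelism Configuration row above, here is a hedged sketch of the fields and selection property. Field and property names follow this summary; note that later commits in this PR fold some of this state into VisualGenMapping, so treat it as an illustration of the selection rule rather than the final file contents.

from dataclasses import dataclass


@dataclass
class ParallelConfig:
    dit_ulysses_size: int = 1
    dit_attn2d_row_size: int = 1  # 1 x 1 mesh means Attention2D is disabled
    dit_attn2d_col_size: int = 1

    @property
    def seq_parallel_size(self) -> int:
        # Active sequence-parallel degree: the attn2d mesh size when the
        # mesh is larger than 1x1, otherwise the Ulysses size.
        attn2d = self.dit_attn2d_row_size * self.dit_attn2d_col_size
        if attn2d > 1 and self.dit_ulysses_size > 1:
            raise NotImplementedError(
                "Attention2D cannot be combined with Ulysses")
        return attn2d if attn2d > 1 else self.dit_ulysses_size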

Sequence Diagram(s)

sequenceDiagram
    participant Rank0
    participant Rank1
    participant Rank2
    participant Rank3
    
    Note over Rank0,Rank3: 2x2 Attention2D Mesh (row_size=2, col_size=2)
    
    Rank0->>Rank0: Q, K, V input (local)
    Rank1->>Rank1: Q, K, V input (local)
    Rank2->>Rank2: Q, K, V input (local)
    Rank3->>Rank3: Q, K, V input (local)
    
    par Row Gather (Q)
        Rank0->>Rank1: Broadcast Q
        Rank1->>Rank0: Broadcast Q
        Rank2->>Rank3: Broadcast Q
        Rank3->>Rank2: Broadcast Q
    end
    
    par Col Gather (K,V)
        Rank0->>Rank2: Broadcast K,V
        Rank2->>Rank0: Broadcast K,V
        Rank1->>Rank3: Broadcast K,V
        Rank3->>Rank1: Broadcast K,V
    end
    
    par Inner Backend Forward (with LSE)
        Rank0->>Rank0: forward_with_lse(Q_full, K_full, V_full) → (out, lse)
        Rank1->>Rank1: forward_with_lse(Q_full, K_full, V_full) → (out, lse)
        Rank2->>Rank2: forward_with_lse(Q_full, K_full, V_full) → (out, lse)
        Rank3->>Rank3: forward_with_lse(Q_full, K_full, V_full) → (out, lse)
    end
    
    par LSE-based Combine (reduce-scatter)
        Rank0->>Rank0: flash_attn_combine(out, lse) → output_0
        Rank1->>Rank1: flash_attn_combine(out, lse) → output_1
        Rank2->>Rank2: flash_attn_combine(out, lse) → output_2
        Rank3->>Rank3: flash_attn_combine(out, lse) → output_3
    end
    
    Note over Rank0,Rank3: Each rank has reconstructed local output
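
For reference, the flash_attn_combine step in the diagram is the standard exact merge of softmax attention computed over disjoint key blocks (the textbook identity, not code lifted from the PR). Given partial outputs $O_i$ and log-sum-exps $L_i$ for key block $i$:

$$ L \;=\; \log \sum_i e^{L_i}, \qquad O \;=\; \sum_i e^{\,L_i - L}\, O_i $$

Each $O_i$ is already the softmax average over its own block and $L_i$ its log-normalizer, so the factors $e^{L_i - L}$ rescale every partial result to the global softmax; the combine is exact, not an approximation.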

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~55 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

• Docstring Coverage: ⚠️ Warning. Docstring coverage is 52.58%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

• Title check: ✅ Passed. The title clearly indicates the main change (adding Attention2D sequence parallelism for visual-generation models) and aligns with the substantial changes across config, backends, models, and CLI.
• Description check: ✅ Passed. The PR description follows the template, with a Description explaining the feature (with paper reference), Test Coverage listing the test files, and the PR Checklist items reviewed; the title is missing a [type] tag.


@coderabbitai (Bot) left a comment:

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py (1)

1186-1193: ⚠️ Potential issue | 🟡 Minor

Update the CFG log to reflect seq_parallel_size.

The grouping and gather offset now use seq_parallel_size, but the log still prints ulysses_size. On Attention2D runs this will report ulysses_size=1, which is misleading when debugging distributed layout.

Suggested fix
         if do_cfg_parallel_mm and self.rank == 0:
             logger.info(
                 f"CFG parallel (multi-modal guidance): cfg_size={cfg_size}, "
-                f"ulysses_size={ulysses_size}"
+                f"seq_parallel_size={seq_parallel_size}"
             )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py` around lines
1186 - 1193, The CFG log message still prints ulysses_size which is misleading
because grouping/gather offset now use seq_parallel_size; update the logger.info
call in pipeline_ltx2.py (the block that checks do_cfg_parallel_mm and self.rank
== 0) to include seq_parallel_size (and/or both seq_parallel_size and
ulysses_size) instead of only ulysses_size so the distributed layout is
accurately reported; reference symbols: seq_parallel_size, cfg_group,
do_cfg_parallel_mm, ulysses_size, and the logger.info call.
tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py (1)

1140-1163: ⚠️ Potential issue | 🔴 Critical

Confirmed bug: set_ulysses_enabled references undefined attributes that will cause AttributeError at runtime.

This method references self.ulysses_size (line 1149) and assigns to self.use_ulysses (line 1152), but the __init__ method defines self.use_seq_parallel and self.seq_parallel_size instead. Since the method is called in pipeline_ltx2_two_stages.py (lines 697, 713), this will fail at runtime.

Update the method to use the correct attribute names: self.seq_parallel_size and self.use_seq_parallel. Also fix the block attribute assignments (block._use_ulysses should be block._use_seq_parallel).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py` around lines
1140 - 1163, The method set_ulysses_enabled uses wrong attribute names and will
raise AttributeError; update all references from ulysses to the seq-parallel
names used in __init__: replace self.ulysses_size with self.seq_parallel_size,
self.use_ulysses with self.use_seq_parallel, and per-block assignment
block._use_ulysses with block._use_seq_parallel; keep the existing logic for
disabling audio sharding (self._audio_is_sharded / block._audio_is_sharded) and
leave attn1/audio_attn1 calls as-is unless their APIs also differ.
🧹 Nitpick comments (3)
examples/visual_gen/visual_gen_flux.py (1)

266-279: Consider adding mutual exclusivity validation for consistency.

Unlike visual_gen_wan_t2v.py (lines 197-201), this file doesn't validate that --ulysses_size and --attn2d_row_size/--attn2d_col_size aren't both set before calling VisualGen. While parallelism.py will catch this later, early CLI-level validation provides clearer user feedback. Consider adding:

if attn2d_size > 1 and args.ulysses_size > 1:
    raise ValueError(
        "--ulysses_size and --attn2d_row_size/--attn2d_col_size are mutually exclusive."
    )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/visual_gen/visual_gen_flux.py` around lines 266 - 279, Add a
CLI-level mutual exclusivity check before building diffusion_args/initializing
VisualGen: compute attn2d_size (already present) and if both attn2d_size > 1 and
args.ulysses_size > 1 raise a ValueError with a clear message that
"--ulysses_size and --attn2d_row_size/--attn2d_col_size are mutually exclusive."
Place this check near where attn2d_size is computed (before calling
build_diffusion_args and before logger.info/VisualGen initialization) so users
get immediate, clear feedback.
tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py (1)

295-300: Consider renaming iteration variables for clarity.

Using k and v as dict comprehension iteration variables when k and v are also tensor variables in this function is confusing to readers. While Python 3 scopes comprehension variables locally (no shadowing occurs), this pattern can mislead reviewers into thinking the tensor values are being overwritten.

♻️ Suggested rename for clarity
-        inner_kwargs = {k: v for k, v in kwargs.items() if k not in ("batch_size", "seq_len")}
+        inner_kwargs = {key: val for key, val in kwargs.items() if key not in ("batch_size", "seq_len")}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py` around lines
295 - 300, The dict comprehension that builds inner_kwargs uses iteration
variables named k and v which collide with the tensor variables q, k, v in this
scope and can confuse readers; change the comprehension to use distinct names
(e.g., key, val or kk, vv) so it reads: inner_kwargs = {key: val for key, val in
kwargs.items() if key not in ("batch_size", "seq_len")} and then call
inner_backend.forward_with_lse(q=q, k=k, v=v, batch_size=B, seq_len=seq_len,
**inner_kwargs) to avoid any perceived shadowing while keeping the same logic.
tests/unittest/_torch/visual_gen/multi_gpu/test_attn2d_attention.py (1)

158-181: Consider explicit process group cleanup for completeness.

The row and column process groups created with dist.new_group are not explicitly destroyed. While this is not a practical issue since the spawned processes terminate after each test, explicitly destroying these groups in _cleanup would be more rigorous.

♻️ Optional: Add explicit group cleanup

You could return the groups from the test logic and destroy them explicitly, or use a context manager pattern. However, since processes terminate after each test, this is purely a cleanliness improvement and not required.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unittest/_torch/visual_gen/multi_gpu/test_attn2d_attention.py` around
lines 158 - 181, The helper _make_process_groups creates row and col groups via
dist.new_group but never exposes them for teardown; modify _make_process_groups
to return the created group handles (row_pg and col_pg) to callers and ensure
test teardown/_cleanup calls torch.distributed.destroy_process_group(row_pg) and
destroy_process_group(col_pg) (guarding for None) to explicitly destroy the
groups created by dist.new_group; alternatively implement a context-manager
wrapper around _make_process_groups that destroys row_pg and col_pg in its
__exit__ to guarantee cleanup.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/visual_gen/visual_gen_ltx2.py`:
- Around line 170-186: The help text for the CLI args --attn2d_row_size and
--attn2d_col_size is too strict; update the parser.add_argument help strings for
--attn2d_row_size and --attn2d_col_size to remove "Must be used together" and
instead state that they can be set independently (asymmetric meshes like 1x4 or
4x1 are valid) and that the total sequence parallelism degree equals
attn2d_row_size * attn2d_col_size; keep the note about mutual exclusivity with
--ulysses_size.

In `@examples/visual_gen/visual_gen_wan_i2v.py`:
- Around line 139-155: The help text for the CLI args --attn2d_row_size and
--attn2d_col_size incorrectly states they "must be used together," which
discourages valid asymmetric Attention2D meshes (1xN or Nx1); update the
parser.add_argument help strings for attn2d_row_size and attn2d_col_size to
indicate the two values can be set independently and that a value of 1 is
allowed to create asymmetric meshes (e.g., "Values may be set independently; a
value of 1 is allowed to form asymmetric 1xN or Nx1 meshes. Mutually exclusive
with --ulysses_size.").

In `@examples/visual_gen/visual_gen_wan_t2v.py`:
- Around line 136-138: Replace the Unicode multiplication sign in the string
literal "2 CFG groups × 2 Ulysses ranks = 4 GPUs total. " with the ASCII letter
x so it becomes "2 CFG groups x 2 Ulysses ranks = 4 GPUs total. "; update the
same occurrence that appears alongside the help/description for the command-line
flag (the multi-line string shown) to avoid encoding/ambiguity issues.



📥 Commits

Reviewing files that changed from the base of the PR and between 5e1a98e and fe7ffb1.

📒 Files selected for processing (18)
  • examples/visual_gen/visual_gen_flux.py
  • examples/visual_gen/visual_gen_ltx2.py
  • examples/visual_gen/visual_gen_wan_i2v.py
  • examples/visual_gen/visual_gen_wan_t2v.py
  • tensorrt_llm/_torch/visual_gen/attention_backend/__init__.py
  • tensorrt_llm/_torch/visual_gen/attention_backend/flash_attn4.py
  • tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py
  • tensorrt_llm/_torch/visual_gen/config.py
  • tensorrt_llm/_torch/visual_gen/models/flux/transformer_flux.py
  • tensorrt_llm/_torch/visual_gen/models/flux/transformer_flux2.py
  • tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py
  • tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py
  • tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py
  • tensorrt_llm/_torch/visual_gen/modules/attention.py
  • tensorrt_llm/_torch/visual_gen/parallelism.py
  • tensorrt_llm/_torch/visual_gen/pipeline.py
  • tests/unittest/_torch/visual_gen/multi_gpu/test_attn2d_attention.py
  • tests/unittest/_torch/visual_gen/test_visual_gen_args.py

@chang-l added the VisualGen label and removed the "Community want to contribute" label on Apr 11, 2026
@NVShreyas (Collaborator):

could you add a test similar to this for one of the models?
https://github.com/NVIDIA/TensorRT-LLM/blob/main/tests/unittest/_torch/visual_gen/multi_gpu/test_flux_ulysses.py

@venmugil (Collaborator, Author):

> could you add a test similar to this for one of the models? https://github.com/NVIDIA/TensorRT-LLM/blob/main/tests/unittest/_torch/visual_gen/multi_gpu/test_flux_ulysses.py

Good point. Added a new test tests/unittest/_torch/visual_gen/multi_gpu/test_wan_attn2d.py

@NVShreyas (Collaborator):

could you rebase on main so we can run CI?

venmugil added 11 commits April 27, 2026 09:52
Implements Attention2DAttention, a 2D-mesh parallel attention backend
for FLUX, Wan, and LTX2 visual-gen models. The attention computation
is distributed across a row x col process group mesh:

- Q tokens are sharded and all-gathered across the row group
- K/V tokens are fused into a single all-gather across the col group
- Partial outputs (with LSE) are combined via flash_attn_combine
  across the row group

New parallel config fields dit_attn2d_row_size and dit_attn2d_col_size
control the mesh dimensions (e.g. 2x4, 8x8). Attention2D is mutually
exclusive with Ulysses sequence parallelism.

- attention_backend/parallel.py: Attention2DAttention class
- attention_backend/flash_attn4.py: forward_with_lse for FA4 backend
- config.py: dit_attn2d_row_size / dit_attn2d_col_size config fields
- parallelism.py: setup_sequence_parallelism support for Attention2D
- transformer_{flux,flux2,wan,ltx2}.py: wire up Attention2D per-model
- examples: --attn2d_row_size / --attn2d_col_size CLI args in all scripts
- tests: multi-GPU unit tests for 2x2, 1x4, 4x1 meshes and FA4 backend

Signed-off-by: Venmugil Elango <498703+venmugil@users.noreply.github.com>
- Fix head-count check in LTXModel to only apply to Ulysses (not Attention2D)
- Fix BasicAVTransformerBlock to support Attention2D by generalizing
  Ulysses-specific state to strategy-agnostic _seq_parallel_* fields
- Store attn2d_mesh_process_group in DiffusionModelConfig and _setup_attn2d
  so v2a blocks can gather full video sequence across all mesh ranks
- Update module docstring in parallel.py to reflect both wrapper classes
- Improve test_attn2d_attention.py: add numerical correctness checks for
  asymmetric meshes, replace FA4 shape-only test with correctness test,
  add init guard tests for invalid inner backends
- Add seq_parallel_size and attn2d config tests to test_visual_gen_args.py

Signed-off-by: Venmugil Elango <498703+venmugil@users.noreply.github.com>
- Fuse K and V all-gathers into a single NCCL call to reduce launch overhead
- Cache inner backend layout in __init__ to avoid repeated attribute lookups
  and validate layout at construction time rather than on every forward call

Signed-off-by: Venmugil Elango <498703+venmugil@users.noreply.github.com>
The upstream rebase changed the forward signature to **kwargs style,
removing explicit batch_size/seq_len/seq_len_kv parameters, but the
call to forward_with_lse still referenced them as undefined variables.
Derive from tensor shapes instead.

Add TestFlashAttn4Forward.test_fa4_forward_returns_correct_shape to
exercise FlashAttn4Attention.forward directly, catching regressions
that bypass forward_with_lse.

Signed-off-by: Venmugil Elango <498703+venmugil@users.noreply.github.com>
Merge the separate Ulysses/Attention2D log lines into a single
'Initializing VisualGen: seq_parallel=...' line for all four
example scripts, making the active parallelism strategy immediately
clear at a glance.

Signed-off-by: Venmugil Elango <498703+venmugil@users.noreply.github.com>
- parallel.py: rename dict comprehension loop vars k/v -> key/val to
  avoid shadowing the tensor parameters k and v in the enclosing scope
- pipeline_ltx2.py: fix CFG log message to use seq_parallel_size
  instead of ulysses_size (correct for both Ulysses and Attention2D);
  remove now-unused ulysses_size local variable
- transformer_ltx2.py: fix set_ulysses_enabled to use correct attribute
  names (seq_parallel_size, use_seq_parallel, _use_seq_parallel)
- visual_gen_flux.py: add mutual exclusivity check between --ulysses_size
  and --attn2d_row_size/--attn2d_col_size
- visual_gen_*.py: clarify --attn2d_row_size/--attn2d_col_size help text:
  remove "Must be used together", note asymmetric meshes (1x4, 4x1) are valid
- visual_gen_wan_t2v.py: replace Unicode multiplication sign with ASCII x

Signed-off-by: Venmugil Elango <498703+venmugil@users.noreply.github.com>
- flash_attn4.py: remove unused batch_size/seq_len/seq_len_kv params
  from forward_with_lse(); forward() now calls it without threading
  those values through
- parallel.py: remove now-unnecessary batch_size/seq_len stripping in
  Attention2DAttention.forward_with_lse call; pass **kwargs directly
- config.py: fix total_parallel_size to not include dit_cp_size and
  dit_dp_size (spurious addition from conflict resolution); restore to
  upstream formula with seq_parallel_size replacing dit_ulysses_size

Signed-off-by: Venmugil Elango <498703+venmugil@users.noreply.github.com>
… cp_size

Attention2D shards the sequence/context axis (like ring attention), not heads
(like Ulysses). Correct the architecture and terminology throughout:

- VisualGenMapping: rename ring_size -> cp_size; rename "ring" device mesh
  dim -> "cp" (aligns with DeviceMeshTopologyImpl convention); add
  attn2d_row_size/attn2d_col_size params; move Attention2D group setup into
  _build_attn2d_groups() called from build_mesh(); drop separate mesh group
  creation since the "cp" mesh dim already provides it (attn2d_mesh_group now
  reads from _group("cp")); rename ring_rank/ring_group -> cp_rank/cp_group;
  fix mutual exclusivity: attn2d XOR ring (both CP), attn2d+ulysses raises
  NotImplementedError (orthogonal but not yet implemented)

- parallelism.py: deleted — group setup now lives in VisualGenMapping, state
  reading is inlined in each transformer; no production callers remained

- pipeline_loader.py: compute cp_size = attn2d_size or ring_size before
  constructing VisualGenMapping; read attn2d groups from vgm properties

- transformer_ltx2.py: replace setup_sequence_parallelism() call with inline
  vgm-based pattern matching flux/wan; fix BasicAVTransformerBlock which read
  a nonexistent config attribute for the Ulysses process group

- config.py, examples, model docstrings: update terminology — Attention2D is
  context parallelism; Ulysses is head-sharding; "mutually exclusive" ->
  "cannot be combined (not yet implemented)" for attn2d+ulysses

- Tests: update test_visual_gen_mapping.py for cp_rank/cp_group/cp_size API
  and "cp" dim names; replace TestAttn2DSetupParallelism (tested deleted
  parallelism.py) with TestAttn2DValidation testing VisualGenMapping directly

Signed-off-by: Venmugil Elango <498703+venmugil@users.noreply.github.com>
DiffusionModelConfig.parallel (ParallelConfig) duplicated information
already available through model_config.visual_gen_mapping. This field
was not present in origin/main and required keeping two sources of
truth in sync.

All callers now read parallelism dimensions directly from vgm:
  - attn2d_row/col_size  → vgm.attn2d_row/col_size
  - ulysses_size         → vgm.ulysses_size
  - cfg_size             → vgm.cfg_size
  - seq_parallel_size    → inlined (attn2d_size if >1 else ulysses_size)

Also removes the three attn2d_*_process_group fields from
DiffusionModelConfig and their pipeline_loader assignments; model
constructors now fetch these directly from vgm.attn2d_{row,col,mesh}_group.

Signed-off-by: Venmugil Elango <498703+venmugil@users.noreply.github.com>
- Remove redundant `if vgm else None/0` guards inside `use_attn2d` blocks
  (vgm is always non-None when attn2d_size > 1)
- Replace dist.get_rank(attn2d_mesh_group) with vgm.attn2d_mesh_rank
- Remove unused `import torch.distributed as dist` from flux/flux2/wan transformers
- Remove _vgm alias in attention.py (vgm already in scope)
- Remove dead cfg_size entry from _setup_cfg_config return dict
- Fix gather index bug: unconditional CFG result is always at index 1,
  not seq_parallel_size (which would be out-of-bounds with active parallelism)
- Reorder dit_attn2d_row/col_size fields in ParallelConfig to sit
  alongside other CP-related sizes; drop redundant comment block
- Update parallel.py module docstring: "Parallelism Wrappers" (plural,
  covers both Ulysses and Attention2D)
- Restore primary-heads divisibility check for Ulysses in LTX2 transformer
- Align BasicAVTransformerBlock parallelism setup with origin/main style
  (vgm extracted once, guarded via `vgm is not None`)

Tests (test_visual_gen_mapping.py):
- TestConstruction: add test_stores_attn2d_sizes
- TestSingleGPURanksAndGroups: add attn2d_mesh_rank to test_ranks_are_zero,
  test_attn2d_mesh_rank_aliases_cp_rank, test_attn2d_mesh_group_aliases_cp_group,
  test_attn2d_row_col_groups_none_without_attn2d
- TestValidation (new): test_ulysses_and_attn2d_raises,
  test_attn2d_cp_size_mismatch_raises, test_attn2d_and_ring_are_mutually_exclusive

Signed-off-by: Venmugil Elango <498703+venmugil@users.noreply.github.com>
- Add FA4 asymmetric mesh tests (1x4 and 4x1) to cover asymmetric Q/KV
  sequence lengths and flash_attn_combine with N=4 fan-in
- Move VisualGenMapping validation tests from TestAttn2DValidation
  (test_attn2d_attention.py) into TestConstruction (test_visual_gen_mapping.py)
  where other constructor raises tests already live; remove duplicate class
- Add test_attn2d_and_ring_are_mutually_exclusive to TestConstruction

Signed-off-by: Venmugil Elango <498703+venmugil@users.noreply.github.com>
@NVShreyas (Collaborator):

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator):

PR_Github #45778 [ run ] triggered by Bot. Commit: 640ab84 Link to invocation

@tensorrt-cicd (Collaborator):

PR_Github #45778 [ run ] completed with state FAILURE. Commit: 640ab84
/LLM/main/L0_MergeRequest_PR pipeline #35968 completed with status: 'ABORTED'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@NVShreyas (Collaborator):

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator):

PR_Github #45789 [ run ] triggered by Bot. Commit: 640ab84 Link to invocation

@tensorrt-cicd (Collaborator):

PR_Github #45789 [ run ] completed with state SUCCESS. Commit: 640ab84
/LLM/main/L0_MergeRequest_PR pipeline #35979 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

CI Report

Link to invocation

@NVShreyas (Collaborator):

/bot run --add-multi-gpu-test --only-multi-gpu-test

@tensorrt-cicd (Collaborator):

PR_Github #45961 [ run ] triggered by Bot. Commit: 640ab84 Link to invocation

@tensorrt-cicd (Collaborator):

PR_Github #45961 [ run ] completed with state SUCCESS. Commit: 640ab84
/LLM/main/L0_MergeRequest_PR pipeline #36115 (Partly Tested) completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@NVShreyas (Collaborator):

/bot run --add-multi-gpu-test --only-multi-gpu-test

@tensorrt-cicd (Collaborator):

PR_Github #46150 [ run ] triggered by Bot. Commit: 640ab84 Link to invocation

@tensorrt-cicd (Collaborator):

PR_Github #46150 [ run ] completed with state SUCCESS. Commit: 640ab84
/LLM/main/L0_MergeRequest_PR pipeline #36276 (Partly Tested) completed with status: 'SUCCESS'

CI Report

Link to invocation

@NVShreyas (Collaborator):

/bot run

@tensorrt-cicd (Collaborator):

PR_Github #46248 [ run ] triggered by Bot. Commit: 640ab84 Link to invocation

@tensorrt-cicd (Collaborator):

PR_Github #46248 [ run ] completed with state SUCCESS. Commit: 640ab84
/LLM/main/L0_MergeRequest_PR pipeline #36356 completed with status: 'SUCCESS'

CI Report

Link to invocation

@chang-l (Collaborator) left a comment:
Can we also update examples/visual_gen/README.md: add --attn2d_row_size/--attn2d_col_size to the Common Arguments table; add an "Attention2D (2D context parallelism)" subsection alongside Ulysses in Multi-GPU Parallelism section; rewrite the "ulysses_size must divide head count" troubleshooting line to point to Attention2D as the alternative?

- Expand Attention2DAttention docstring with purpose, motivation vs.
  Ulysses and Ring attention (O(N/√P) scaling), mesh layout ASCII
  diagram, 4-step architecture breakdown, supported backends, and
  constraints
- Update examples/visual_gen/README.md: add Attention2D subsection to
  Multi-GPU Parallelism, add --attn2d_row_size/--attn2d_col_size to
  Common Arguments table, rewrite Ulysses head-count troubleshooting
  line to point to Attention2D, add Attention2D Errors section

Signed-off-by: Venmugil Elango <498703+venmugil@users.noreply.github.com>
@venmugil (Collaborator, Author):

/bot run

@tensorrt-cicd (Collaborator):

PR_Github #46448 [ run ] triggered by Bot. Commit: 5efda90 Link to invocation

@tensorrt-cicd (Collaborator):

PR_Github #46448 [ run ] completed with state SUCCESS. Commit: 5efda90
/LLM/main/L0_MergeRequest_PR pipeline #36518 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@venmugil (Collaborator, Author) commented May 1, 2026:

/bot run

@tensorrt-cicd (Collaborator):

PR_Github #46504 [ run ] triggered by Bot. Commit: 5efda90 Link to invocation

@chang-l (Collaborator) left a comment:
Thanks for the contribution!

@tensorrt-cicd (Collaborator):

PR_Github #46504 [ run ] completed with state SUCCESS. Commit: 5efda90
/LLM/main/L0_MergeRequest_PR pipeline #36565 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@venmugil (Collaborator, Author) commented May 3, 2026:

/bot run

@tensorrt-cicd (Collaborator):

PR_Github #46602 [ run ] triggered by Bot. Commit: 5efda90 Link to invocation

@tensorrt-cicd (Collaborator):

PR_Github #46602 [ run ] completed with state SUCCESS. Commit: 5efda90
/LLM/main/L0_MergeRequest_PR pipeline #36649 completed with status: 'SUCCESS'

CI Report

Link to invocation

@NVShreyas merged commit a33b885 into NVIDIA:main on May 4, 2026
6 checks passed
@venmugil deleted the attn2d branch on May 4, 2026 15:16