[None][feat] Add Attention2D sequence parallelism for visual-gen models#12943
NVShreyas merged 15 commits into NVIDIA:main from
Conversation
|
Hi @NVShreyas, can you please help review this PR if possible?
📝 Walkthrough
This pull request introduces Attention2D-based sequence parallelism as an alternative to Ulysses sequence parallelism for visual generation models. It adds CLI arguments for row/column sizing and implements the Attention2DAttention backend.
Changes
Sequence Diagram(s)
sequenceDiagram
participant Rank0
participant Rank1
participant Rank2
participant Rank3
Note over Rank0,Rank3: 2x2 Attention2D Mesh (row_size=2, col_size=2)
Rank0->>Rank0: Q, K, V input (local)
Rank1->>Rank1: Q, K, V input (local)
Rank2->>Rank2: Q, K, V input (local)
Rank3->>Rank3: Q, K, V input (local)
par Row Gather (Q)
Rank0->>Rank1: Broadcast Q
Rank1->>Rank0: Broadcast Q
Rank2->>Rank3: Broadcast Q
Rank3->>Rank2: Broadcast Q
end
par Col Gather (K,V)
Rank0->>Rank2: Broadcast K,V
Rank2->>Rank0: Broadcast K,V
Rank1->>Rank3: Broadcast K,V
Rank3->>Rank1: Broadcast K,V
end
par Inner Backend Forward (with LSE)
Rank0->>Rank0: forward_with_lse(Q_full, K_full, V_full) → (out, lse)
Rank1->>Rank1: forward_with_lse(Q_full, K_full, V_full) → (out, lse)
Rank2->>Rank2: forward_with_lse(Q_full, K_full, V_full) → (out, lse)
Rank3->>Rank3: forward_with_lse(Q_full, K_full, V_full) → (out, lse)
end
par LSE-based Combine (reduce-scatter)
Rank0->>Rank0: flash_attn_combine(out, lse) → output_0
Rank1->>Rank1: flash_attn_combine(out, lse) → output_1
Rank2->>Rank2: flash_attn_combine(out, lse) → output_2
Rank3->>Rank3: flash_attn_combine(out, lse) → output_3
end
Note over Rank0,Rank3: Each rank has reconstructed local output
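The combine step is where the partial softmaxes become exact: each rank's output is reweighted by its log-sum-exp before summation. A minimal sketch of the math that flash_attn_combine implements, assuming [B, H, S, D] outputs and [B, H, S] LSEs (the function name here is illustrative, not the PR's API):

```python
import torch

def lse_combine(outs: list, lses: list) -> torch.Tensor:
    # outs[i]: partial attention over KV shard i, shape [B, H, S, D]
    # lses[i]: log-sum-exp of that shard's attention scores, shape [B, H, S]
    lse = torch.stack(lses)                    # [N, B, H, S]
    total = torch.logsumexp(lse, dim=0)        # normalizer of the full softmax
    w = torch.exp(lse - total).unsqueeze(-1)   # per-shard weights; sum to 1
    return (w * torch.stack(outs)).sum(dim=0)  # exact full-attention output
```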
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~55 minutes
Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py (1)
1186-1193: ⚠️ Potential issue | 🟡 Minor
Update the CFG log to reflect seq_parallel_size.
The grouping and gather offset now use seq_parallel_size, but the log still prints ulysses_size. On Attention2D runs this will report ulysses_size=1, which is misleading when debugging distributed layout.
Suggested fix:
  if do_cfg_parallel_mm and self.rank == 0:
      logger.info(
          f"CFG parallel (multi-modal guidance): cfg_size={cfg_size}, "
-         f"ulysses_size={ulysses_size}"
+         f"seq_parallel_size={seq_parallel_size}"
      )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py` around lines 1186 - 1193, The CFG log message still prints ulysses_size which is misleading because grouping/gather offset now use seq_parallel_size; update the logger.info call in pipeline_ltx2.py (the block that checks do_cfg_parallel_mm and self.rank == 0) to include seq_parallel_size (and/or both seq_parallel_size and ulysses_size) instead of only ulysses_size so the distributed layout is accurately reported; reference symbols: seq_parallel_size, cfg_group, do_cfg_parallel_mm, ulysses_size, and the logger.info call.tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py (1)
1140-1163: ⚠️ Potential issue | 🔴 Critical
Confirmed bug: set_ulysses_enabled references undefined attributes that will cause AttributeError at runtime.
This method references self.ulysses_size (line 1149) and assigns to self.use_ulysses (line 1152), but the __init__ method defines self.use_seq_parallel and self.seq_parallel_size instead. Since the method is called in pipeline_ltx2_two_stages.py (lines 697, 713), this will fail at runtime.
Update the method to use the correct attribute names: self.seq_parallel_size and self.use_seq_parallel. Also fix the block attribute assignments (block._use_ulysses should be block._use_seq_parallel).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py` around lines 1140 - 1163, The method set_ulysses_enabled uses wrong attribute names and will raise AttributeError; update all references from ulysses to the seq-parallel names used in __init__: replace self.ulysses_size with self.seq_parallel_size, self.use_ulysses with self.use_seq_parallel, and per-block assignment block._use_ulysses with block._use_seq_parallel; keep the existing logic for disabling audio sharding (self._audio_is_sharded / block._audio_is_sharded) and leave attn1/audio_attn1 calls as-is unless their APIs also differ.
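Sketched out, the attribute fix looks roughly like this; only the attribute names come from the review comment, while the guard and the block loop are assumptions about the surrounding body:

```python
# Hypothetical sketch of the corrected method; transformer_blocks and the
# early-return guard are illustrative, not the file's actual code.
def set_ulysses_enabled(self, enabled: bool) -> None:
    if self.seq_parallel_size <= 1:        # was: self.ulysses_size
        return
    self.use_seq_parallel = enabled        # was: self.use_ulysses
    for block in self.transformer_blocks:  # assumed container name
        block._use_seq_parallel = enabled  # was: block._use_ulysses
```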
🧹 Nitpick comments (3)
examples/visual_gen/visual_gen_flux.py (1)
266-279: Consider adding mutual exclusivity validation for consistency.
Unlike visual_gen_wan_t2v.py (lines 197-201), this file doesn't validate that --ulysses_size and --attn2d_row_size/--attn2d_col_size aren't both set before calling VisualGen. While parallelism.py will catch this later, early CLI-level validation provides clearer user feedback. Consider adding:
  if attn2d_size > 1 and args.ulysses_size > 1:
      raise ValueError(
          "--ulysses_size and --attn2d_row_size/--attn2d_col_size are mutually exclusive."
      )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/visual_gen/visual_gen_flux.py` around lines 266 - 279, Add a CLI-level mutual exclusivity check before building diffusion_args/initializing VisualGen: compute attn2d_size (already present) and if both attn2d_size > 1 and args.ulysses_size > 1 raise a ValueError with a clear message that "--ulysses_size and --attn2d_row_size/--attn2d_col_size are mutually exclusive." Place this check near where attn2d_size is computed (before calling build_diffusion_args and before logger.info/VisualGen initialization) so users get immediate, clear feedback.tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py (1)
295-300: Consider renaming iteration variables for clarity.
Using k and v as dict comprehension iteration variables when k and v are also tensor variables in this function is confusing to readers. While Python 3 scopes comprehension variables locally (no shadowing occurs), this pattern can mislead reviewers into thinking the tensor values are being overwritten.
♻️ Suggested rename for clarity:
- inner_kwargs = {k: v for k, v in kwargs.items() if k not in ("batch_size", "seq_len")}
+ inner_kwargs = {key: val for key, val in kwargs.items() if key not in ("batch_size", "seq_len")}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py` around lines 295 - 300, The dict comprehension that builds inner_kwargs uses iteration variables named k and v which collide with the tensor variables q, k, v in this scope and can confuse readers; change the comprehension to use distinct names (e.g., key, val or kk, vv) so it reads: inner_kwargs = {key: val for key, val in kwargs.items() if key not in ("batch_size", "seq_len")} and then call inner_backend.forward_with_lse(q=q, k=k, v=v, batch_size=B, seq_len=seq_len, **inner_kwargs) to avoid any perceived shadowing while keeping the same logic.tests/unittest/_torch/visual_gen/multi_gpu/test_attn2d_attention.py (1)
158-181: Consider explicit process group cleanup for completeness.
The row and column process groups created with dist.new_group are not explicitly destroyed. While this is not a practical issue since the spawned processes terminate after each test, explicitly destroying these groups in _cleanup would be more rigorous.
♻️ Optional: Add explicit group cleanup
You could return the groups from the test logic and destroy them explicitly, or use a context manager pattern. However, since processes terminate after each test, this is purely a cleanliness improvement and not required.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/visual_gen/multi_gpu/test_attn2d_attention.py` around lines 158 - 181, The helper _make_process_groups creates row and col groups via dist.new_group but never exposes them for teardown; modify _make_process_groups to return the created group handles (row_pg and col_pg) to callers and ensure test teardown/_cleanup calls torch.distributed.destroy_process_group(row_pg) and destroy_process_group(col_pg) (guarding for None) to explicitly destroy the groups created by dist.new_group; alternatively implement a context-manager wrapper around _make_process_groups that destroys row_pg and col_pg in its __exit__ to guarantee cleanup.
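A sketch of the context-manager pattern suggested above; the helper name and rank-list arguments are illustrative, and it assumes the default process group is already initialized:

```python
import contextlib
import torch.distributed as dist

@contextlib.contextmanager
def mesh_process_groups(row_ranks, col_ranks):
    """Create row/col groups and guarantee they are destroyed on exit."""
    row_pg = dist.new_group(ranks=row_ranks)
    col_pg = dist.new_group(ranks=col_ranks)
    try:
        yield row_pg, col_pg
    finally:
        for pg in (col_pg, row_pg):
            if pg is not None:
                dist.destroy_process_group(pg)
```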
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/visual_gen/visual_gen_ltx2.py`:
- Around line 170-186: The help text for the CLI args --attn2d_row_size and
--attn2d_col_size is too strict; update the parser.add_argument help strings for
--attn2d_row_size and --attn2d_col_size to remove "Must be used together" and
instead state that they can be set independently (asymmetric meshes like 1x4 or
4x1 are valid) and that the total sequence parallelism degree equals
attn2d_row_size * attn2d_col_size; keep the note about mutual exclusivity with
--ulysses_size.
In `@examples/visual_gen/visual_gen_wan_i2v.py`:
- Around line 139-155: The help text for the CLI args --attn2d_row_size and
--attn2d_col_size incorrectly states they "must be used together," which
discourages valid asymmetric Attention2D meshes (1xN or Nx1); update the
parser.add_argument help strings for attn2d_row_size and attn2d_col_size to
indicate the two values can be set independently and that a value of 1 is
allowed to create asymmetric meshes (e.g., "Values may be set independently; a
value of 1 is allowed to form asymmetric 1xN or Nx1 meshes. Mutually exclusive
with --ulysses_size.").
In `@examples/visual_gen/visual_gen_wan_t2v.py`:
- Around line 136-138: Replace the Unicode multiplication sign in the string
literal "2 CFG groups × 2 Ulysses ranks = 4 GPUs total. " with the ASCII letter
x so it becomes "2 CFG groups x 2 Ulysses ranks = 4 GPUs total. "; update the
same occurrence that appears alongside the help/description for the command-line
flag (the multi-line string shown) to avoid encoding/ambiguity issues.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 0897897f-3ccc-4acc-88d4-67be8dbc7222
📒 Files selected for processing (18)
examples/visual_gen/visual_gen_flux.py
examples/visual_gen/visual_gen_ltx2.py
examples/visual_gen/visual_gen_wan_i2v.py
examples/visual_gen/visual_gen_wan_t2v.py
tensorrt_llm/_torch/visual_gen/attention_backend/__init__.py
tensorrt_llm/_torch/visual_gen/attention_backend/flash_attn4.py
tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py
tensorrt_llm/_torch/visual_gen/config.py
tensorrt_llm/_torch/visual_gen/models/flux/transformer_flux.py
tensorrt_llm/_torch/visual_gen/models/flux/transformer_flux2.py
tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py
tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py
tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py
tensorrt_llm/_torch/visual_gen/modules/attention.py
tensorrt_llm/_torch/visual_gen/parallelism.py
tensorrt_llm/_torch/visual_gen/pipeline.py
tests/unittest/_torch/visual_gen/multi_gpu/test_attn2d_attention.py
tests/unittest/_torch/visual_gen/test_visual_gen_args.py
|
could you add a test similar to this for one of the models? |
Good point. Added a new test |
|
could you rebase on main so we can run CI? |
Implements Attention2DAttention, a 2D-mesh parallel attention backend
for FLUX, Wan, and LTX2 visual-gen models. The attention computation
is distributed across a row x col process group mesh:
- Q tokens are sharded and all-gathered across the row group
- K/V tokens are fused into a single all-gather across the col group
- Partial outputs (with LSE) are combined via flash_attn_combine
across the row group
New parallel config fields dit_attn2d_row_size and dit_attn2d_col_size
control the mesh dimensions (e.g. 2x4, 8x8). Attention2D is mutually
exclusive with Ulysses sequence parallelism.
- attention_backend/parallel.py: Attention2DAttention class
- attention_backend/flash_attn4.py: forward_with_lse for FA4 backend
- config.py: dit_attn2d_row_size / dit_attn2d_col_size config fields
- parallelism.py: setup_sequence_parallelism support for Attention2D
- transformer_{flux,flux2,wan,ltx2}.py: wire up Attention2D per-model
- examples: --attn2d_row_size / --attn2d_col_size CLI args in all scripts
- tests: multi-GPU unit tests for 2x2, 1x4, 4x1 meshes and FA4 backend
Signed-off-by: Venmugil Elango <498703+venmugil@users.noreply.github.com>
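A minimal sketch of the communication pattern this commit describes, assuming [B, S_local, H, D] shards; the helper names, the attn_with_lse callable, and the gather-then-slice combine are illustrative simplifications of the flash_attn_combine reduce-scatter the PR actually uses:

```python
import torch
import torch.distributed as dist

def _all_gather_cat(x, group, dim):
    # Gather every rank's shard in `group`, concatenate along `dim`.
    shards = [torch.empty_like(x) for _ in range(dist.get_world_size(group))]
    dist.all_gather(shards, x, group=group)
    return torch.cat(shards, dim=dim)

def attn2d_forward(q, k, v, row_group, col_group, attn_with_lse):
    # q, k, v: this rank's sequence shard, [B, S_local, H, D]
    q_full = _all_gather_cat(q, row_group, dim=1)                      # row-gather Q
    kv = _all_gather_cat(torch.cat([k, v], dim=-1), col_group, dim=1)  # one fused col-gather for K and V
    k_full, v_full = kv.chunk(2, dim=-1)
    # Local attention plus the per-query log-sum-exp needed for merging.
    out, lse = attn_with_lse(q_full, k_full, v_full)  # out: [B, S_row, H, D], lse: [B, H, S_row]
    # Ranks in one row attended the same queries against disjoint KV
    # subsets: merge the partials with LSE weights, then keep only this
    # rank's Q shard of the result.
    outs = _all_gather_cat(out.unsqueeze(0), row_group, dim=0)         # [N, B, S_row, H, D]
    lses = _all_gather_cat(lse.unsqueeze(0), row_group, dim=0)         # [N, B, H, S_row]
    total = torch.logsumexp(lses, dim=0)
    w = torch.exp(lses - total).permute(0, 1, 3, 2).unsqueeze(-1)      # [N, B, S_row, H, 1]
    merged = (w * outs).sum(dim=0)                                     # [B, S_row, H, D]
    r = dist.get_rank(group=row_group)
    return merged.chunk(dist.get_world_size(group=row_group), dim=1)[r]
```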
- Fix head-count check in LTXModel to only apply to Ulysses (not Attention2D)
- Fix BasicAVTransformerBlock to support Attention2D by generalizing Ulysses-specific state to strategy-agnostic _seq_parallel_* fields
- Store attn2d_mesh_process_group in DiffusionModelConfig and _setup_attn2d so v2a blocks can gather full video sequence across all mesh ranks
- Update module docstring in parallel.py to reflect both wrapper classes
- Improve test_attn2d_attention.py: add numerical correctness checks for asymmetric meshes, replace FA4 shape-only test with correctness test, add init guard tests for invalid inner backends
- Add seq_parallel_size and attn2d config tests to test_visual_gen_args.py

Signed-off-by: Venmugil Elango <498703+venmugil@users.noreply.github.com>
- Fuse K and V all-gathers into a single NCCL call to reduce launch overhead
- Cache inner backend layout in __init__ to avoid repeated attribute lookups and validate layout at construction time rather than on every forward call

Signed-off-by: Venmugil Elango <498703+venmugil@users.noreply.github.com>
The upstream rebase changed the forward signature to **kwargs style, removing explicit batch_size/seq_len/seq_len_kv parameters, but the call to forward_with_lse still referenced them as undefined variables. Derive from tensor shapes instead.

Add TestFlashAttn4Forward.test_fa4_forward_returns_correct_shape to exercise FlashAttn4Attention.forward directly, catching regressions that bypass forward_with_lse.

Signed-off-by: Venmugil Elango <498703+venmugil@users.noreply.github.com>
Merge the separate Ulysses/Attention2D log lines into a single 'Initializing VisualGen: seq_parallel=...' line for all four example scripts, making the active parallelism strategy immediately clear at a glance.

Signed-off-by: Venmugil Elango <498703+venmugil@users.noreply.github.com>
- parallel.py: rename dict comprehension loop vars k/v -> key/val to avoid shadowing the tensor parameters k and v in the enclosing scope
- pipeline_ltx2.py: fix CFG log message to use seq_parallel_size instead of ulysses_size (correct for both Ulysses and Attention2D); remove now-unused ulysses_size local variable
- transformer_ltx2.py: fix set_ulysses_enabled to use correct attribute names (seq_parallel_size, use_seq_parallel, _use_seq_parallel)
- visual_gen_flux.py: add mutual exclusivity check between --ulysses_size and --attn2d_row_size/--attn2d_col_size
- visual_gen_*.py: clarify --attn2d_row_size/--attn2d_col_size help text: remove "Must be used together", note asymmetric meshes (1x4, 4x1) are valid
- visual_gen_wan_t2v.py: replace Unicode multiplication sign with ASCII x

Signed-off-by: Venmugil Elango <498703+venmugil@users.noreply.github.com>
- flash_attn4.py: remove unused batch_size/seq_len/seq_len_kv params from forward_with_lse(); forward() now calls it without threading those values through
- parallel.py: remove now-unnecessary batch_size/seq_len stripping in Attention2DAttention.forward_with_lse call; pass **kwargs directly
- config.py: fix total_parallel_size to not include dit_cp_size and dit_dp_size (spurious addition from conflict resolution); restore to upstream formula with seq_parallel_size replacing dit_ulysses_size

Signed-off-by: Venmugil Elango <498703+venmugil@users.noreply.github.com>
… cp_size
Attention2D shards the sequence/context axis (like ring attention), not heads
(like Ulysses). Correct the architecture and terminology throughout:
- VisualGenMapping: rename ring_size -> cp_size; rename "ring" device mesh
dim -> "cp" (aligns with DeviceMeshTopologyImpl convention); add
attn2d_row_size/attn2d_col_size params; move Attention2D group setup into
_build_attn2d_groups() called from build_mesh(); drop separate mesh group
creation since the "cp" mesh dim already provides it (attn2d_mesh_group now
reads from _group("cp")); rename ring_rank/ring_group -> cp_rank/cp_group;
fix mutual exclusivity: attn2d XOR ring (both CP), attn2d+ulysses raises
NotImplementedError (orthogonal but not yet implemented)
- parallelism.py: deleted — group setup now lives in VisualGenMapping, state
reading is inlined in each transformer; no production callers remained
- pipeline_loader.py: compute cp_size = attn2d_size or ring_size before
constructing VisualGenMapping; read attn2d groups from vgm properties
- transformer_ltx2.py: replace setup_sequence_parallelism() call with inline
vgm-based pattern matching flux/wan; fix BasicAVTransformerBlock which read
a nonexistent config attribute for the Ulysses process group
- config.py, examples, model docstrings: update terminology — Attention2D is
context parallelism; Ulysses is head-sharding; "mutually exclusive" ->
"cannot be combined (not yet implemented)" for attn2d+ulysses
- Tests: update test_visual_gen_mapping.py for cp_rank/cp_group/cp_size API
and "cp" dim names; replace TestAttn2DSetupParallelism (tested deleted
parallelism.py) with TestAttn2DValidation testing VisualGenMapping directly
Signed-off-by: Venmugil Elango <498703+venmugil@users.noreply.github.com>
DiffusionModelConfig.parallel (ParallelConfig) duplicated information
already available through model_config.visual_gen_mapping. This field
was not present in origin/main and required keeping two sources of
truth in sync.
All callers now read parallelism dimensions directly from vgm:
- attn2d_row/col_size → vgm.attn2d_row/col_size
- ulysses_size → vgm.ulysses_size
- cfg_size → vgm.cfg_size
- seq_parallel_size → inlined (attn2d_size if >1 else ulysses_size)
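For example, the inlined read looks roughly like this, a sketch built only from the property names listed above, not the exact code (the helper name is hypothetical):

```python
def seq_parallel_size_from(vgm) -> int:
    # Attention2D takes precedence when active; otherwise fall back to Ulysses.
    attn2d_size = vgm.attn2d_row_size * vgm.attn2d_col_size
    return attn2d_size if attn2d_size > 1 else vgm.ulysses_size
```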
Also removes the three attn2d_*_process_group fields from
DiffusionModelConfig and their pipeline_loader assignments; model
constructors now fetch these directly from vgm.attn2d_{row,col,mesh}_group.
Signed-off-by: Venmugil Elango <498703+venmugil@users.noreply.github.com>
- Remove redundant `if vgm else None/0` guards inside `use_attn2d` blocks (vgm is always non-None when attn2d_size > 1)
- Replace dist.get_rank(attn2d_mesh_group) with vgm.attn2d_mesh_rank
- Remove unused `import torch.distributed as dist` from flux/flux2/wan transformers
- Remove _vgm alias in attention.py (vgm already in scope)
- Remove dead cfg_size entry from _setup_cfg_config return dict
- Fix gather index bug: unconditional CFG result is always at index 1, not seq_parallel_size (which would be out-of-bounds with active parallelism)
- Reorder dit_attn2d_row/col_size fields in ParallelConfig to sit alongside other CP-related sizes; drop redundant comment block
- Update parallel.py module docstring: "Parallelism Wrappers" (plural, covers both Ulysses and Attention2D)
- Restore primary-heads divisibility check for Ulysses in LTX2 transformer
- Align BasicAVTransformerBlock parallelism setup with origin/main style (vgm extracted once, guarded via `vgm is not None`)

Tests (test_visual_gen_mapping.py):
- TestConstruction: add test_stores_attn2d_sizes
- TestSingleGPURanksAndGroups: add attn2d_mesh_rank to test_ranks_are_zero, test_attn2d_mesh_rank_aliases_cp_rank, test_attn2d_mesh_group_aliases_cp_group, test_attn2d_row_col_groups_none_without_attn2d
- TestValidation (new): test_ulysses_and_attn2d_raises, test_attn2d_cp_size_mismatch_raises, test_attn2d_and_ring_are_mutually_exclusive

Signed-off-by: Venmugil Elango <498703+venmugil@users.noreply.github.com>
- Add FA4 asymmetric mesh tests (1x4 and 4x1) to cover asymmetric Q/KV sequence lengths and flash_attn_combine with N=4 fan-in
- Move VisualGenMapping validation tests from TestAttn2DValidation (test_attn2d_attention.py) into TestConstruction (test_visual_gen_mapping.py) where other constructor raises tests already live; remove duplicate class
- Add test_attn2d_and_ring_are_mutually_exclusive to TestConstruction

Signed-off-by: Venmugil Elango <498703+venmugil@users.noreply.github.com>
|
/bot run --disable-fail-fast |
|
PR_Github #45778 [ run ] triggered by Bot. Commit: |
|
PR_Github #45778 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #45789 [ run ] triggered by Bot. Commit: |
|
PR_Github #45789 [ run ] completed with state |
|
/bot run --add-multi-gpu-test --only-multi-gpu-test |
|
PR_Github #45961 [ run ] triggered by Bot. Commit: |
|
PR_Github #45961 [ run ] completed with state
|
|
/bot run --add-multi-gpu-test --only-multi-gpu-test |
|
PR_Github #46150 [ run ] triggered by Bot. Commit: |
|
PR_Github #46150 [ run ] completed with state |
|
/bot run |
|
PR_Github #46248 [ run ] triggered by Bot. Commit: |
|
PR_Github #46248 [ run ] completed with state |
chang-l left a comment
Can we also update examples/visual_gen/README.md: add --attn2d_row_size/--attn2d_col_size to the Common Arguments table; add an "Attention2D (2D context parallelism)" subsection alongside Ulysses in Multi-GPU Parallelism section; rewrite the "ulysses_size must divide head count" troubleshooting line to point to Attention2D as the alternative?
- Expand Attention2DAttention docstring with purpose, motivation vs. Ulysses and Ring attention (O(N/√P) scaling), mesh layout ASCII diagram, 4-step architecture breakdown, supported backends, and constraints
- Update examples/visual_gen/README.md: add Attention2D subsection to Multi-GPU Parallelism, add --attn2d_row_size/--attn2d_col_size to Common Arguments table, rewrite Ulysses head-count troubleshooting line to point to Attention2D, add Attention2D Errors section

Signed-off-by: Venmugil Elango <498703+venmugil@users.noreply.github.com>
|
/bot run |
|
PR_Github #46448 [ run ] triggered by Bot. Commit: |
|
PR_Github #46448 [ run ] completed with state
|
|
/bot run |
|
PR_Github #46504 [ run ] triggered by Bot. Commit: |
chang-l left a comment
Thanks for the contribution!
|
PR_Github #46504 [ run ] completed with state
|
|
/bot run |
|
PR_Github #46602 [ run ] triggered by Bot. Commit: |
|
PR_Github #46602 [ run ] completed with state |
Summary by CodeRabbit
Release Notes
New Features
Improvements
Tests
Description
Adds Attention2DAttention, a 2D-mesh sequence parallelism strategy for visual-gen diffusion models (FLUX, Wan, LTX2), based on the approach described in arxiv.org/pdf/2503.15758.
Unlike Ulysses, which requires the number of attention heads to be divisible by the parallelism degree, Attention2D distributes the sequence across a row_size × col_size process group mesh with no head-divisibility constraint. It also has asymptotically lower communication cost: as the number of processors increases, each GPU handles a smaller sequence shard, so the per-GPU communication volume scales as O(1/P).
Communication pattern:
Design:
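A back-of-envelope check on these scaling claims, assuming an R x C mesh over sequence length N so that each of the P = R·C ranks holds N/P tokens: the row all-gather of Q moves C shards and the fused column all-gather of K/V moves 2R shards, giving a per-GPU volume of roughly

$$\frac{N}{P}\,(C + 2R) \;=\; O\!\left(\frac{N}{\sqrt{P}}\right) \quad \text{for } R = C = \sqrt{P},$$

consistent with the O(N/√P) figure cited in the docstring commit above, while each individual shard shrinks as O(N/P).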
Test Coverage
- tests/unittest/_torch/visual_gen/multi_gpu/test_attn2d_attention.py
- FlashAttn4Attention.forward
- tests/unittest/_torch/visual_gen/test_visual_gen_args.py
- tests/unittest/_torch/visual_gen/multi_gpu/test_visual_gen_mapping.py
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment /bot help.