
Add partial-width tile shapes to x.all_gather tests#435

Merged
mawad-amd merged 2 commits into muhaawad/hints from copilot/sub-pr-432-again
Mar 7, 2026
Conversation

Contributor

Copilot AI commented Mar 7, 2026

PR #434 fixed x.all_gather to use hint=(1, tile.block_n) instead of hint=(tile.block_m, tile.block_n). The old hint asserted cross-row contiguity that does not hold when BLOCK_N < N (i.e. stride_m > BLOCK_N), which caused scalar buffer_store_short writes to wrong addresses for 16-bit dtypes.

Test coverage additions (tests/x/test_all_gather.py)

  • (128, 128, 64, 32) — BLOCK_N < N/world_size; exercises the multi-block partial-width path (2 tiles per rank in the N direction)
  • (256, 128, 64, 16) — Minimum BLOCK_N=16; directly stresses the 16-bit vectorization path that emitted incorrect scalar stores under the old hint

Both shapes were added to test_all_gather (all dtypes × both gather dims) and to test_all_gather_ctx_api. The original failing shape (128, 64, 64, 32) was already present; the new shapes extend coverage to multi-block and narrow-block configurations.
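The shape tuples above can be read as (M, N, BLOCK_M, BLOCK_N); a small Python sketch (helper names are illustrative, not part of the test suite) shows why each new shape exercises the partial-width path:

```python
# Hypothetical helper: a tile is "partial-width" when BLOCK_N < N, i.e. rows
# of the tile are not contiguous in the row-major tensor (stride_m = N > BLOCK_N).
# Tuple order (M, N, BLOCK_M, BLOCK_N) is assumed from the test parametrization.
def is_partial_width(M, N, block_m, block_n):
    return block_n < N

def tiles_per_rank_n(N, block_n, world_size=2):
    # How many tiles each rank covers along N (world_size=2 is an assumption)
    return (N // world_size) // block_n

shapes = [(128, 128, 64, 32), (256, 128, 64, 16), (128, 64, 64, 32)]
for M, N, bm, bn in shapes:
    print((M, N, bm, bn), is_partial_width(M, N, bm, bn), tiles_per_rank_n(N, bn))
```

All three shapes are partial-width, and (128, 128, 64, 32) covers 2 tiles per rank along N, matching the "multi-block partial-width path" described above.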



Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
Copilot AI changed the title from "[WIP] Add hint parameters to IRIS collective communication operations" to "Add partial-width tile shapes to x.all_gather tests" on Mar 7, 2026
@mawad-amd
Collaborator

@copilot can you investigate this bug a bit more. Is this a Triton bug we should report, or is our code incorrect?

Contributor Author

Copilot AI commented Mar 7, 2026

> @copilot can you investigate this bug a bit more. Is this a Triton bug we should report, or is our code incorrect?

This is our code's bug, not a Triton bug. Here's the full root-cause analysis:

How hint works internally:

iris.store(..., hint=(H0, H1)) calls tl.max_contiguous(tl.multiple_of(ptr, (H0, H1)), (H0, H1)), which sets the tt.contiguity attribute on the translated pointer to [H0, H1]. Triton's coalescing pass calls getOrderFromContiguity([H0, H1]), which sorts dimensions by descending contiguity to pick the vectorization axis. If H0 > H1, dim 0 (rows) is chosen for vectorization. These are developer assertions that Triton trusts unconditionally — the source even says // override with hint. Triton has no way to validate them against actual runtime strides.
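The axis-selection behavior described above can be sketched in a few lines of Python (this is an illustrative model of getOrderFromContiguity, not the C++ implementation):

```python
# Sketch of Triton's getOrderFromContiguity: dimensions are sorted by
# descending contiguity, and the first dimension in the resulting order is
# the one chosen for vectorization. Tie-breaking here (stable sort) is an
# assumption and may differ from the real compiler.
def order_from_contiguity(contiguity):
    dims = list(range(len(contiguity)))
    dims.sort(key=lambda d: contiguity[d], reverse=True)
    return dims

# Old hint (block_m, block_n) = (64, 32): dim 0 (rows) wins vectorization
print(order_from_contiguity([64, 32]))  # [0, 1]
# Fixed hint (1, block_n) = (1, 32): dim 1 (columns) wins
print(order_from_contiguity([1, 32]))   # [1, 0]
```

With H0 > H1 the pass picks dim 0, exactly the failure mode described: Triton trusts the asserted contiguity and never checks it against the runtime strides.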

Why hint=(block_m, block_n) was wrong for partial-width tiles:

The semantic of contiguity[d] = C is: "C consecutive elements along dimension d are contiguous in memory" (from AxisInfo.h). For contiguity[0] = block_m to be true on a row-major tensor, rows must be contiguous — i.e., stride_m == block_n. That only holds when the tile spans the full width (block_n == N).

For the failing shape M=128, N=64, block_m=64, block_n=32: stride_m = N = 64 ≠ block_n = 32. So asserting contiguity[0] = 64 was a lie. Triton's coalescing pass trusted it, chose dim 0 for vectorization, and emitted stores striding by stride_m * sizeof(fp16) = 128 bytes instead of striding by 2 bytes (consecutive fp16 values). This produced writes to completely wrong addresses → max_diff = 1.0.
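The address arithmetic for the failing shape is plain integer math and can be checked directly (no Triton needed; sizeof(fp16) = 2 is the only assumed constant):

```python
# Reproducing the stride arithmetic for the failing shape:
# M=128, N=64, block_m=64, block_n=32, fp16 (2 bytes per element).
M, N, block_m, block_n = 128, 64, 64, 32
elem_size = 2  # sizeof(fp16)

stride_m = N  # row-major: one row is N elements
assert stride_m != block_n  # the tile is partial-width: rows NOT contiguous

# With the bad hint, stores step along dim 0 (rows):
bad_step = stride_m * elem_size    # 64 * 2 = 128 bytes between "consecutive" writes
# With the correct hint, stores step along dim 1 (columns):
good_step = 1 * elem_size          # 2 bytes, consecutive fp16 values
print(bad_step, good_step)  # 128 2
```

The 128-byte stride lands every scalar store on a different row, which is why the output was corrupted rather than merely slow.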

Why hint=(1, block_n) is always correct:

Within any row of a row-major tensor, block_n elements are always contiguous regardless of tile width vs. tensor width. Setting contiguity = [1, block_n] makes getOrderFromContiguity choose dim 1 (columns) for vectorization, which generates the correct contiguous loads/stores within each row.
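The safety condition above can be written as a small validator (a hypothetical helper for illustration; it is not part of the iris API):

```python
# Hypothetical check of the hint contract described above: asserting
# contiguity along dim 0 (H0 > 1) is only safe when rows of the tile are
# actually contiguous, i.e. stride_m == block_n (full-width tile). Within-row
# contiguity (H1) is always safe up to block_n in a row-major tensor.
def hint_is_safe(hint, stride_m, block_n):
    h0, h1 = hint
    if h0 > 1 and stride_m != block_n:
        return False  # cross-row contiguity asserted, but rows are not contiguous
    return h1 <= block_n

# Failing configuration: N=64 (so stride_m=64), block_n=32
print(hint_is_safe((64, 32), stride_m=64, block_n=32))  # False: the old hint
print(hint_is_safe((1, 32), stride_m=64, block_n=32))   # True: the fixed hint
```

hint=(1, block_n) passes for any tile width, which is why it is the correct default for partial-width tiles.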

In summary: Triton's hint mechanism is a low-level developer contract — the compiler trusts you. Passing a hint that asserts cross-row contiguity when the tile doesn't span the full tensor width is a correctness bug in the caller. No Triton fix is needed; the correct call pattern is hint=(1, block_n) for any partial-width tile in a row-major tensor.

@mawad-amd
Collaborator

@copilot so are we losing out on some potential vectorization here, or no?

Contributor Author

Copilot AI commented Mar 7, 2026

> @copilot so are we losing out on some potential vectorization here, or no?

No vectorization is lost. Here's why:

Vector width is hardware-capped at 128 bits (from CoalesceUtils.cpp: "each thread handles at most 128 bits, which is the widest available vectorized store op"). For fp16 that's 8 elements, for fp32 it's 4. This cap applies regardless of which axis the hint selects.
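The cap translates into a simple elements-per-store calculation:

```python
# Elements per vectorized store under the 128-bit per-thread cap quoted above.
CAP_BITS = 128

def max_vector_elems(dtype_bits):
    return CAP_BITS // dtype_bits

print(max_vector_elems(16))  # 8 fp16 elements per store
print(max_vector_elems(32))  # 4 fp32 elements per store
```

Since the cap is per-thread and axis-independent, no hint choice can exceed these widths.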

hint=(1, block_n) gives correct vectorization for both tile configurations:

  • Partial-width tiles (block_n < N): Columns within a row are contiguous; rows are NOT (stride_m = N > block_n). Vectorizing along dim 1 with width min(block_n, 8) fp16 elements is the correct and maximum possible vectorization. The old hint=(block_m, block_n) would have selected dim 0 (rows) — which are non-contiguous — producing wrong addresses, not wider stores.

  • Full-width tiles (block_n == N): The entire tile is one contiguous block. Both hints ultimately produce the same 128-bit per-thread store width since that's the hardware cap. Triton may lay out threads slightly differently but peak throughput is identical.

The old hint was never giving wider stores — it was just choosing the wrong axis for vectorization, which is a correctness issue on partial-width tiles. On full-width tiles it happened to work because stride_m == block_n made the contiguity assertion accidentally true.
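Putting the two cases together, the effective fp16 vector width under the fixed hint is min(block_n, 8) in both configurations (illustrative arithmetic only):

```python
# Effective fp16 vector width per thread: limited by the row extent of the
# tile (block_n) and the 128-bit hardware cap (8 fp16 elements), per the
# reasoning above.
CAP_ELEMS_FP16 = 128 // 16  # 8

def effective_vec_width(block_n):
    return min(block_n, CAP_ELEMS_FP16)

print(effective_vec_width(32))  # partial-width tile: 8 (cap-limited)
print(effective_vec_width(64))  # full-width tile: 8 (same peak width)
print(effective_vec_width(4))   # very narrow tile: 4 (row-limited)
```

Partial-width and full-width tiles both hit the 8-element cap, confirming that the fix sacrifices no store width.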

@mawad-amd mawad-amd marked this pull request as ready for review March 7, 2026 16:05
@mawad-amd mawad-amd merged commit 4324247 into muhaawad/hints Mar 7, 2026
2 checks passed
@mawad-amd mawad-amd deleted the copilot/sub-pr-432-again branch March 7, 2026 16:05