Add rewrite rule to drop contiguous-axis stride in scatter/gather offsets + unified AssumeOp injection#213

Merged
maleadt merged 6 commits into main from tb/contiguous_gather on Apr 30, 2026

@maleadt
Member

@maleadt maleadt commented Apr 30, 2026

For a TileArray with ArraySpec{contiguous=true}, Julia's column-major convention makes stride[1] == 1 statically known, and the constant analysis propagates that 1 through the broadcast/reshape/from_scalar chain that feeds the gather/scatter offset computation. This PR adds the matching muli(x, 1) → x rewrite so the contiguous-axis stride multiply collapses out of the offset.
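
As a rough illustration of the rewrite described above, here is a minimal Python sketch over a toy expression IR. The classes `Const`, `Var`, `Broadcast`, `MulI` and the functions `constant_of` and `simplify` are hypothetical names invented for this example; they are not cuTile.jl's IR or its rewrite engine, and the sketch assumes the multiplied operand already has the result shape.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass(frozen=True)
class Const:
    value: int


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Broadcast:
    # Shape-changing op; a broadcast of a literal is still that literal per lane.
    arg: "Expr"


@dataclass(frozen=True)
class MulI:
    lhs: "Expr"
    rhs: "Expr"


Expr = Union[Const, Var, Broadcast, MulI]


def constant_of(e: Expr) -> Optional[int]:
    """Toy constant analysis: look through Broadcast down to a literal."""
    if isinstance(e, Const):
        return e.value
    if isinstance(e, Broadcast):
        return constant_of(e.arg)
    return None


def simplify(e: Expr) -> Expr:
    """Bottom-up rewrite applying muli(x, 1) -> x."""
    if isinstance(e, MulI):
        lhs, rhs = simplify(e.lhs), simplify(e.rhs)
        if constant_of(rhs) == 1:
            return lhs  # the multiply by a known 1 disappears
        return MulI(lhs, rhs)
    if isinstance(e, Broadcast):
        return Broadcast(simplify(e.arg))
    return e


idx = Var("idx")
contiguous = MulI(idx, Broadcast(Const(1)))      # stride known to be 1
strided = MulI(idx, Broadcast(Var("stride2")))   # runtime stride, left alone
```

In this model `simplify(contiguous)` yields `idx` directly, while `simplify(strided)` is unchanged because the stride is not a known constant.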

maleadt added 4 commits April 30, 2026 12:37
…gather.

For a `TileArray` with `ArraySpec` `contiguous=true`, the constant
analysis already recognises `getfield(getfield(arg, :strides), 1)` as
the literal `1`, propagating through `broadcast`/`reshape`/`from_scalar`.
Adding the matching algebra rewrite drops the `muli(idx, broadcast(1))`
in scatter/gather offset chains.

The fold mirrors Python cuTile's `_gather_scatter_pointer_and_mask`
which uses a structural skip (`if static_stride == 1: offset_delta =
ind`). Without it, the contiguous-axis stride is a runtime broadcast
and consecutive lanes' addresses differ by an unknown scalar, forcing
tileiras to fall back to scalar stores (`STG.E.U16`) instead of wide
vector (`STG.E.128`) stores in 2-D scatter kernels (MoE down-projection).
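
To make the structural skip concrete, the following plain-Python sketch computes per-lane element offsets the way the paragraph above describes. `gather_offsets` and its parameters are hypothetical names for illustration only; this is not the body of `_gather_scatter_pointer_and_mask`, only a model of the `if static_stride == 1` branch it is said to contain.

```python
def gather_offsets(indices, strides, static_strides):
    """Per-lane linear element offsets for an N-D gather/scatter.

    indices:        one list of per-lane indices per axis
    strides:        runtime stride per axis, in elements
    static_strides: compile-time stride per axis, or None when unknown
    Returns (offsets, multiplied_axes); the second item lists the axes
    that still needed a stride multiply.
    """
    n_lanes = len(indices[0])
    offsets = [0] * n_lanes
    multiplied_axes = []
    for axis, (ind, stride, static) in enumerate(
        zip(indices, strides, static_strides)
    ):
        if static == 1:
            delta = ind  # structural skip: the index is the offset delta
        else:
            delta = [i * stride for i in ind]
            multiplied_axes.append(axis)
        offsets = [o + d for o, d in zip(offsets, delta)]
    return offsets, multiplied_axes


# Column-major 2-D array with 64 rows: strides are (1, 64) elements.
# Eight lanes address rows 0..7 of column 3.
rows = list(range(8))
cols = [3] * 8
offsets, multiplied_axes = gather_offsets(
    [rows, cols], strides=[1, 64], static_strides=[1, None]
)
```

Here `offsets` is `[192, 193, ..., 199]`: consecutive lanes differ by exactly one element, and only axis 1 needed a multiply. With 16-bit elements, eight such lanes span 16 bytes, which is the width of a single 128-bit store.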

Standalone, this rewrite triggers a tileiras crash at -O1+ on the MoE
kernel — tileiras's auto-vectorizer enters a code path that needs
alignment proofs which cuTile.jl doesn't currently emit for
scatter/gather pointer/size/stride args. The follow-up commit extends
the assume pass to inject those.

Drops the per-`make_tensor_view` `AssumeInfo` / `MTVPredicates` sidecar
in favor of on-demand chain derivation at consumer sites. The
divisibility / bounds dataflow results live on `CGCtx`; consumer
codegen calls `op_predicates(divby, bounds, op, kind, spec_div)` to
derive each operand's `AssumePredicate` chain, and `wrap_for` consults
a per-`Value` cache (`ctx.assume_wrapped`) so a `Value` reused across
consumers — e.g. a kernel-arg pointer threaded through both an MTV and
a gather — is wrapped exactly once. Mirrors the role of cuTile Python's
`var_map` in `_passes/propagate_divby.py`.

Extends the consumer set from `{make_tensor_view}` to
`{make_tensor_view, load_ptr_tko, store_ptr_tko}` (Python's
`_OPS_NEED_ASSUME`), and adds an entry-time spec-derived wrap on each
`TileArray` kernel-arg flat slot (`apply_arg_assume_predicates!` +
`arg_chain`). The entry wrap is what carries `spec.alignment` to the
base pointer of gather/scatter chains: the post-offset operand at
`load_ptr_tko` only sees the lane-stride alignment, but the assumed
base `Value` flows through `reshape` → `broadcast` → `offset` so
tileiras's vectorizer has both alignments — the proof its STG.E.128 /
LDG.E.128 lowering needs on the MoE down-projection scatter.

`current_block` is tracked on `CGCtx` so consumer-op codegen can run
parent-walking queries (`tuple_element_source` for tuple-typed
sizes/strides operands) starting from the right scope.

@maleadt maleadt changed the title Add rewrite rule to drop contiguous-axis stride in scatter/gather offsets Add rewrite rule to drop contiguous-axis stride in scatter/gather offsets + unified AssumeOp injection Apr 30, 2026
@maleadt
Member Author

maleadt commented Apr 30, 2026

Also adding a refactor I was planning for a subsequent PR, which adds assumptions for non-TileView operands (since gather/scatter take raw pointers), because without it I trigger NVIDIA/cuda-tile#19:

Unified AssumeOp injection

The fold above triggers a tileiras crash at -O1+ standalone — its
auto-vectorizer enters a code path that needs alignment proofs we
weren't emitting for gather/scatter pointer args. This commit
generalizes the existing per-MTV AssumeOp injection to cover all
of _OPS_NEED_ASSUME = {make_tensor_view, load_ptr_tko, store_ptr_tko}
(matching cuTile Python), and adds an entry-time spec-derived wrap on
each TileArray kernel-arg flat slot.

Architecturally it's also a simplification: the per-make_tensor_view
AssumeInfo / MTVPredicates sidecar is gone; the divby/bounds
dataflow results live on CGCtx and consumer codegen calls
op_predicates(divby, bounds, op, kind, spec_div) to derive each
operand's chain on demand. A per-Value cache (ctx.assume_wrapped)
plays the role of Python's var_map, ensuring a Value reused across
consumers — e.g. a kernel-arg pointer threaded through both an MTV and
a gather — is wrapped exactly once.
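
The wrap-once behaviour described above can be modelled in a few lines of Python. `Ctx`, `Value`, and the string-based assume ops below are hypothetical stand-ins; only the names `wrap_for` and `assume_wrapped` come from the PR text, and the real signatures in cuTile.jl may differ.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class Value:
    name: str


@dataclass
class Ctx:
    # Hypothetical stand-in for CGCtx: original Value -> wrapped Value.
    assume_wrapped: dict = field(default_factory=dict)
    # Log of assume ops emitted so far, to make duplicates visible.
    emitted: list = field(default_factory=list)


def wrap_for(ctx: Ctx, value: Value, chain: tuple) -> Value:
    """Wrap `value` in one assume op per predicate, at most once per Value."""
    cached = ctx.assume_wrapped.get(value)
    if cached is not None:
        return cached  # later consumers reuse the first wrap
    wrapped = value
    for pred in chain:
        wrapped = Value(f"assume({wrapped.name}, {pred})")
        ctx.emitted.append(wrapped.name)
    ctx.assume_wrapped[value] = wrapped
    return wrapped


ctx = Ctx()
ptr = Value("ptr")
first = wrap_for(ctx, ptr, ("div_by<16>",))   # e.g. the entry-time wrap
second = wrap_for(ctx, ptr, ("div_by<4>",))   # e.g. a later gather consumer
```

Both calls return the same wrapped value and only one assume op is emitted. Note that the second chain is ignored because the cache is keyed on the `Value` alone, so the model is only sound when the first chain is at least as tight as any later one.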

The entry wrap is what carries spec.alignment to the base pointer:
the post-offset operand at load_ptr_tko only sees the lane-stride
alignment, but the assumed base Value flows through
reshape → broadcast → offset so tileiras's vectorizer has both
alignments.
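
Why both alignments matter can be seen from plain divisibility arithmetic: if the base address is a multiple of A bytes and the byte offset is a multiple of B bytes, the sum is guaranteed to be a multiple of gcd(A, B). The Python sketch below illustrates only that arithmetic; `address_alignment` is a hypothetical helper, and nothing here describes how tileiras's analysis is actually implemented.

```python
from math import gcd


def address_alignment(base_align_bytes, offset_div_elems, elem_bytes):
    """Guaranteed byte alignment of base + offset * elem_bytes.

    base_align_bytes: known divisor of the base address
    offset_div_elems: known divisor of the element offset
    elem_bytes:       element size in bytes
    """
    return gcd(base_align_bytes, offset_div_elems * elem_bytes)


# 16-bit elements, a group of 8 lanes whose first offset is a multiple of 8.
with_base_assume = address_alignment(16, 8, 2)    # base known 16-byte aligned
without_base_assume = address_alignment(2, 8, 2)  # base only element-aligned
```

With the base assumption the result is 16 bytes, enough for a 128-bit access; without it the provable alignment drops to 2 bytes even though the offset itself is well aligned.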

maleadt added 2 commits April 30, 2026 15:21
The two helpers were parallel: both pulled spec.alignment /
shape_div_by / stride_div_by and applied the same structural priors
(`Bounded(0,?)` for sizes/strides, `DivBy(d)` when `d > 1`). Replace
arg_chain's body with a path-keyed dispatch over op_predicates with
`nothing` dataflow inputs: the kernel-arg slot is the dataflow anchor,
so there's nothing upstream to refine against. The contiguous-axis
stride skip stays in the dispatcher since it has the path context.

Drops ~15 lines and removes a pair of trivially-equivalent code paths.
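
A path-keyed dispatch of this shape might look like the Python sketch below. It is a guess at the structure from the commit message alone: the `Spec` fields, the path encoding, the 0-based axis numbering, the simplified `op_predicates` signature, and the choice to emit an empty chain for the contiguous-axis stride are all assumptions, not cuTile.jl's actual code.

```python
from dataclasses import dataclass
from math import lcm
from typing import Optional


@dataclass(frozen=True)
class DivBy:
    divisor: int


@dataclass(frozen=True)
class Bounded:
    lo: Optional[int]
    hi: Optional[int]  # None models the open upper bound in Bounded(0,?)


@dataclass(frozen=True)
class Spec:  # hypothetical subset of ArraySpec
    alignment: int
    contiguous: bool
    shape_div_by: tuple
    stride_div_by: tuple


def op_predicates(divby, bounds, kind, spec_div):
    """Structural prior plus divisor; dataflow inputs may be None."""
    chain = []
    if kind in ("size", "stride"):
        chain.append(bounds if bounds is not None else Bounded(0, None))
    d = spec_div if divby is None else lcm(spec_div, divby)
    if d > 1:
        chain.append(DivBy(d))
    return tuple(chain)


def arg_chain(path, spec):
    """Dispatch on the kernel-arg flat-slot path; no dataflow to refine with."""
    slot = path[0]
    if slot == "ptr":
        return op_predicates(None, None, "ptr", spec.alignment)
    axis = path[1]
    if slot == "sizes":
        return op_predicates(None, None, "size", spec.shape_div_by[axis])
    if slot == "strides":
        if spec.contiguous and axis == 0:
            return ()  # assumed skip: this stride is the literal 1
        return op_predicates(None, None, "stride", spec.stride_div_by[axis])
    raise ValueError(f"unknown kernel-arg path: {path!r}")


spec = Spec(alignment=16, contiguous=True,
            shape_div_by=(1, 8), stride_div_by=(1, 4))
```

Under these assumptions `arg_chain(("ptr",), spec)` gives `(DivBy(16),)`, `arg_chain(("strides", 1), spec)` gives `(Bounded(0, None), DivBy(4))`, and the contiguous axis `("strides", 0)` gets no chain at all.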

The per-Value cache keys only on Value, not on chain contents — sound
because the pipeline arranges that the first-seen chain on a given
Value is an upper bound on what any later consumer could derive
(kernel-arg-entry wrap seeds the spec-tightest chain; structural prior
is tile-type-determined so per-Value consistent).

Promote the comment to a docstring spelling out both reasons plus the
failure mode if a future consumer ever derives a tighter chain on a
pre-wrapped Value.
@maleadt maleadt merged commit e58e475 into main Apr 30, 2026
13 checks passed
@maleadt maleadt deleted the tb/contiguous_gather branch April 30, 2026 13:48