Store tile shapes in row-major order #142

Merged: maleadt merged 5 commits into main from tb/storage_order on Mar 25, 2026

@maleadt maleadt commented Mar 25, 2026

Tile IR is natively row-major (Python cuTile passes shapes through verbatim). We were passing Julia's column-major shapes as-is, then compensating per-operation: reshape emitted a permute-reshape-permute sandwich, and batched matmul emitted 4 permute ops to convert trailing batch dims to MmaFOp's leading-batch convention.

This PR proposes to reverse all shapes once at the Julia-Tile IR boundary: Julia (M, K, B) becomes Tile IR (B, K, M).
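The reversal works because a column-major array of shape (M, K, B) and a row-major array of shape (B, K, M) address memory identically once the index is reversed too. A minimal pure-Python sketch of that identity follows; the helper names are hypothetical and are not functions from this PR, and indices are 0-based for simplicity.

```python
from itertools import product

def colmajor_offset(index, shape):
    """Linear offset with the first dimension varying fastest (Julia convention)."""
    offset, stride = 0, 1
    for i, n in zip(index, shape):
        offset += i * stride
        stride *= n
    return offset

def rowmajor_offset(index, shape):
    """Linear offset with the last dimension varying fastest (Tile IR convention)."""
    offset = 0
    for i, n in zip(index, shape):
        offset = offset * n + i
    return offset

def to_tileir(dims):
    """Hypothetical boundary conversion: reverse the dimension order."""
    return tuple(reversed(dims))

julia_shape = (32, 16, 4)               # (M, K, B)
tileir_shape = to_tileir(julia_shape)   # (B, K, M)

# Every Julia index maps to the same linear offset as its reversed Tile IR index.
all_match = all(
    colmajor_offset(idx, julia_shape) == rowmajor_offset(to_tileir(idx), tileir_shape)
    for idx in product(*(range(n) for n in julia_shape))
)
print(tileir_shape, all_match)
```

Since the offsets agree for every element, the conversion only relabels dimensions and never moves data.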

What changes

Reshape — the double-permute hack is gone. Was:

permute [1, 0]  →  reshape  →  permute [1, 0]

Now:

reshape
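To see why the sandwich was needed before and is redundant now, the sketch below models tiles as nested Python lists. It assumes that Tile IR's reshape preserves row-major element order and that Julia's reshape preserves column-major order; the helper names are illustrative only.

```python
def rowmajor_reshape(tile, shape):
    """Reshape a 2D nested list with the last index varying fastest."""
    flat = [x for row in tile for x in row]
    rows, cols = shape
    assert rows * cols == len(flat)
    return [flat[r * cols:(r + 1) * cols] for r in range(rows)]

def transpose(tile):
    """2D permute [1, 0]."""
    return [list(col) for col in zip(*tile)]

def julia_reshape(a, shape):
    """Reference column-major reshape (first index varying fastest)."""
    flat = [a[i][j] for j in range(len(a[0])) for i in range(len(a))]
    rows, cols = shape
    return [[flat[i + rows * j] for j in range(cols)] for i in range(rows)]

a = [[1, 2, 3], [4, 5, 6]]        # Julia array of shape (2, 3)
want = julia_reshape(a, (3, 2))   # what Julia's reshape(a, 3, 2) should yield

# Old scheme: the tile keeps Julia's shape, so the row-major reshape
# has to be wrapped in two permutes.
old = transpose(rowmajor_reshape(transpose(a), (2, 3)))

# New scheme: the tile already has the reversed shape (3, 2), so a direct
# reshape to the reversed target shape (2, 3) is enough.
tile = transpose(a)                   # Tile IR view of a
new = rowmajor_reshape(tile, (2, 3))  # Tile IR view of the result
```

Here `old` equals `want`, and `new` is the reversed-shape view of `want`, which is exactly what the rest of the pipeline expects under the new convention.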

Batched matmul — the 4 permute ops per MmaFOp are eliminated. Operands are swapped instead (mmaf(b, a, acc)). This produces the correct result: Julia a (M,K,B) becomes Tile IR (B,K,M) and Julia b (K,N,B) becomes Tile IR (B,N,K), so (B,N,K) @ (B,K,M) = (B,N,M), which maps back to Julia (M,N,B). Was:

permute [2, 0, 1]   # a: trailing → leading batch
permute [2, 0, 1]   # b: trailing → leading batch
permute [2, 0, 1]   # acc: trailing → leading batch
mmaf a, b, acc
permute [1, 2, 0]   # result: leading → trailing batch

Now:

mmaf b, a, acc
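The operand swap is the identity (A·B)ᵀ = Bᵀ·Aᵀ applied to reinterpreted buffers. The following sketch checks it for a single unbatched matmul on flat buffers; it is an illustration with made-up helper names, not the code generated by this PR.

```python
def colmajor_matmul(a, a_shape, b, b_shape):
    """Reference: (M, K) @ (K, N) on column-major flat buffers (Julia view)."""
    m, k = a_shape
    k2, n = b_shape
    assert k == k2
    out = [0] * (m * n)
    for i in range(m):
        for j in range(n):
            out[i + m * j] = sum(a[i + m * t] * b[t + k * j] for t in range(k))
    return out

def rowmajor_matmul(x, x_shape, y, y_shape):
    """(P, Q) @ (Q, R) on row-major flat buffers (Tile IR view)."""
    p, q = x_shape
    q2, r = y_shape
    assert q == q2
    out = [0] * (p * r)
    for i in range(p):
        for j in range(r):
            out[i * r + j] = sum(x[i * q + t] * y[t * r + j] for t in range(q))
    return out

M, K, N = 2, 3, 4
a = list(range(1, M * K + 1))   # Julia A, shape (M, K)
b = list(range(1, K * N + 1))   # Julia B, shape (K, N)

expected = colmajor_matmul(a, (M, K), b, (K, N))   # Julia C, shape (M, N)

# Same buffers, reversed shapes, operands swapped: (N, K) @ (K, M) = (N, M).
swapped = rowmajor_matmul(b, (N, K), a, (K, M))
```

The `swapped` buffer matches `expected` element for element, so the row-major (N, M) result can be read back as Julia's column-major (M, N) without any permute.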

Axes for reduce/scan/cat are flipped (tileir_axis = ndim - 1 - julia_axis). Permutation indices are transformed. Load/store indices and tensor view sizes/strides are reversed.
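A rough sketch of the axis flip and one way the permutation transform can be derived. The axis formula is taken from the description above; the permutation formula is my own derivation under the assumption that out_shape[d] = in_shape[perm[d]] in both conventions, and may differ from what the PR implements. Axes are 0-based here.

```python
def flip_axis(julia_axis, ndim):
    """tileir_axis = ndim - 1 - julia_axis (0-based axes)."""
    return ndim - 1 - julia_axis

def flip_perm(julia_perm):
    """Derived transform: reverse the positions and flip each entry."""
    n = len(julia_perm)
    return [n - 1 - julia_perm[n - 1 - e] for e in range(n)]

def permute_shape(shape, perm):
    """Shape after a permute, assuming out[d] = in[perm[d]]."""
    return tuple(shape[p] for p in perm)

julia_shape = (32, 16, 4)
tileir_shape = tuple(reversed(julia_shape))            # (4, 16, 32)

# Reducing over Julia axis 0 (size 32) targets Tile IR axis 2 (also size 32).
tileir_axis = flip_axis(0, len(julia_shape))

julia_perm = [2, 0, 1]
julia_out = permute_shape(julia_shape, julia_perm)     # (4, 32, 16)
tileir_perm = flip_perm(julia_perm)                    # [1, 2, 0]
tileir_out = permute_shape(tileir_shape, tileir_perm)  # (16, 32, 4)
```

The transformed permutation yields the reverse of the Julia output shape, which is the invariant every converted operation has to maintain.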

Tile IR comparison: 3D batched matmul (32,16,4) * (16,32,4)

Before (7 ops between load and store):

%cst      = constant <f32: 0> : tile<32x32x4xf32>
%4        = permute %tile [2, 0, 1]     : tile<32x16x4xf32> -> tile<4x32x16xf32>
%5        = permute %tile_25 [2, 0, 1]  : tile<16x32x4xf32> -> tile<4x16x32xf32>
%6        = permute %cst [2, 0, 1]      : tile<32x32x4xf32> -> tile<4x32x32xf32>
%7        = mmaf %4, %5, %6             : tile<4x32x16xf32>, tile<4x16x32xf32>, tile<4x32x32xf32>
%8        = permute %7 [1, 2, 0]        : tile<4x32x32xf32> -> tile<32x32x4xf32>

After (2 ops):

%cst      = constant <f32: 0> : tile<4x32x32xf32>
%4        = mmaf %tile_25, %tile, %cst   : tile<4x32x16xf32>, tile<4x16x32xf32>, tile<4x32x32xf32>

Benchmarks (RTX 5080, min of 10 runs)

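| Benchmark     | Before (ms) | After (ms) | Change |
|---------------|-------------|------------|--------|
| batchmatmul   | 0.696       | 0.602      | -13.5% |
| matmul        | 4.177       | 3.744      | -10.4% |
| layernorm fwd | 0.986       | 0.967      | -1.9%  |
| layernorm bwd | 2.000       | 2.015      | +0.8%  |
| FFT           | 0.537       | 0.535      | -0.4%  |
| transpose     | 0.669       | 0.669      | 0%     |
| vadd          | 1.917       | 1.918      | 0%     |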

maleadt and others added 3 commits March 25, 2026 10:01
Tile IR is natively row-major (Python cuTile passes shapes verbatim).
We were passing Julia's column-major shapes through as-is, then
compensating with per-operation fixups: reshape emitted a
permute-reshape-permute sandwich, and batched matmul emitted 4
permute ops to convert trailing batch dims to MmaFOp's leading
batch convention.

Instead, reverse all shapes at the Julia↔Tile IR boundary:
Julia (M, K, B) → Tile IR (B, K, M). CGVal.shape now stores
Tile IR (row-major) order. Conversion happens in three functions:
_tile_type_for_julia!, tile_type_and_shape_for_julia!, and
extract_tile_shape.

This eliminates the reshape double-permute (now a direct ReshapeOp)
and all matmul permutes (operands swapped: mmaf(b, a, acc) computes
(N,K)@(K,M)=(N,M) → Julia (M,N), which is correct). Axes for
reduce/scan/cat are flipped (tileir_axis = ndim-1-julia_axis),
indices for load/store are reversed, and tensor view sizes/strides
are reversed to match.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@AntonOresten (Contributor) commented:

> Operands are swapped instead (mmaf(b, a, acc))

This is a great solution. I was previously confused by the operand ordering and shapes and why it was that they could be passed directly to mmaf, but this resolves the confusion.

@maleadt maleadt marked this pull request as ready for review March 25, 2026 13:18
@maleadt maleadt merged commit ac0ecb8 into main Mar 25, 2026
9 checks passed
@maleadt maleadt deleted the tb/storage_order branch March 25, 2026 14:29
@maleadt maleadt mentioned this pull request Mar 26, 2026