diff --git a/.claude/skills/improve-cutile-kernel-perf/SKILL.md b/.claude/skills/improve-cutile-kernel-perf/SKILL.md
new file mode 100644
index 0000000..636e48c
--- /dev/null
+++ b/.claude/skills/improve-cutile-kernel-perf/SKILL.md
@@ -0,0 +1,147 @@
+---
+name: improve-cutile-kernel-perf
+description: Iteratively optimize cuTile kernel performance through systematic profiling, bottleneck analysis, IR comparison, and targeted tuning. Covers tile sizes, occupancy, autotune configs, TMA, latency hints, persistent scheduling, num_ctas, flush_to_zero, and IR-level debugging. Use when asked to "optimize cutile kernel", "improve kernel perf", "tune cutile performance", "make kernel faster", or iteratively benchmark and refine a cuTile GPU kernel in the TileGym project.
+version: 2026.04.11-alpha
+environment:
+  IDE:
+  - Claude Code
+  - Cursor (Agent mode)
+  model:
+  - Opus 4.6
+requires:
+- GPU node (Blackwell, Hopper, or Ampere) for benchmarking
+license: MIT. Complete terms in LICENSE.
+---
+
+# Iterative cuTile Kernel Performance Optimization
+Systematically profile, diagnose bottlenecks, and iteratively tune a cuTile kernel's performance in the TileGym repository.
+
+## Setup
+Work with the user to prepare the optimization environment:
+1. Create a fresh git branch: Propose a branch name, e.g., `cutile-perf--`, from the current branch, then check it out with `git checkout -b `
+2. Locate the target kernel:
+   - cuTile kernels live under `src/tilegym/suites//cutile/` or `src/tilegym/ops/cutile/`
+   - Read the kernel file and identify: the `@ct.kernel` decorated function(s), the launch wrapper (`ct.launch()` or `ct_experimental.autotune_launch()`), the `@register_impl` registration, and current autotune configs (if any)
+3. Classify the kernel:
+   - Arithmetic Intensity < 10 -> Memory-bound (primary metric: GB/s)
+   - Arithmetic Intensity 10-50 -> Balanced (track both GB/s and TFLOPS)
+   - Arithmetic Intensity > 50 -> Compute-bound (primary metric: TFLOPS)
+4. Check the GPU environment:
+   - Ensure a GPU node (B200/H100/H200) is available
+   - All subsequent benchmark commands should run on the GPU node
+   - Check that the `ncu` CLI is available for deep profiling
+5. Study related references:
+   - `references/optimization-playbook.md`: Step-by-step recipes for each optimization (A through J) with before/after code examples
+   - `references/perf-knobs-catalog.md`: Complete catalog of all tunable parameters (TMA, persistent scheduling, occupancy, tile sizes, latency hints, etc.)
+   - `references/cutile-api-reference.md`: cuTile API reference and 18 critical rules
+   - `references/performance-model.md`: Roofline/performance model, bottleneck diagnosis, autotuning
+   - `references/ir-dump-guide.md`: IR dump, analysis, and error diagnosis
+   - `references/cutile-patterns-reference.md`: Common cuTile patterns and conversion quick-reference
+6. Create @sandbox/perf_results.md to track progress. The first run will write a baseline
+7. Confirm and go: Once you get confirmation, kick off the experimentation
+
+## Experimentation
+Every experiment iteration applies ONE optimization to the target kernel, verifies correctness, re-benchmarks, and records results. Each iteration must finish within 10 minutes.
+
+### The goal
+- Improve the **core metric**: reduce `SM Active Cycles` (see the `ncu` sketch below)
+- Subject to the **core constraint**: Correctness shall not regress — every optimization MUST preserve numerical correctness. `SM Active Cycles` shall not regress > 2% compared to baseline.
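+
+For example, `SM Active Cycles` can be collected with `ncu` on the perf test (a minimal sketch; the metric name `sm__cycles_active.avg` and the pytest selector may need adjusting for your Nsight Compute version and suite layout):
+
+```bash
+ncu --metrics sm__cycles_active.avg --target-processes all \
+  python -m pytest tests/suites/.../test_.py -k "test_perf and cutile" --print-record -v
+```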
+
+### What you can change
+- The target kernel file under `src/tilegym/suites//cutile/` or `src/tilegym/ops/cutile/`: kernel body, tile sizes, occupancy, num_ctas, TMA usage, latency hints, flush_to_zero, autotune configs, persistent scheduling, and other cuTile-specific parameters
+- The kernel's launch wrapper: grid computation, autotune config space
+- @sandbox/: Feel free to add new files or modify files created by you, but don't check them in to git
+
+### What you can NOT change
+- Kernel functional semantics (inputs, outputs, and numerical behavior within tolerance)
+- Test infrastructure and benchmark harness
+- Anything not listed above
+
+### What to expect from experiment outputs
+
+#### Correctness test:
+```bash
+python -m pytest tests/suites/.../test_.py -k "test_op and cutile" -v
+```
+
+#### Performance benchmark:
+For each iteration:
+1. Run pytest benchmark: `python -m pytest ... --print-record` → extract latency (ms)
+2. Run ncu profiling: `ncu [command]` → extract GB/s (memory-bound), TFLOPS (compute-bound), and `SM Active Cycles`
+3. Record both metrics in perf_results.md
+
+Benchmark cmdlines:
+```bash
+python -m pytest tests/suites/.../test_.py -k "test_perf and cutile" --print-record -v
+```
+
+Latency sample:
+```
+Cutile: {'forward': {'mean': 3.7903138461538455, 'std': 0.0016941310873207053, 'rel_std': 0.044696327430505396, 'median': 3.789880999999999, 'min': 3.7883389999999992, 'max': 3.7941230000000004, 'nrep': 13, 'peak_mem_mb': 913}} ms
+```
+
+### Track experiment progress
+Use @sandbox/perf_results.md to record each iteration's results. It should only contain a Markdown table with 7 columns:
+- `iteration`: iteration number, starting from 0 (baseline)
+- `optimization`: what was applied (e.g., "baseline", "TMA replace gather", "persistent scheduling")
+- `metric`: primary metric value (GB/s or TFLOPS)
+- `latency_ms`: kernel latency in milliseconds, six decimal places
+- `SM Active Cycles`: cuTile backend `SM Active Cycles`
+- `correctness`: PASS or FAIL
+- `status`: whether this iteration was `keep`, `revert`, or `crash`
+
+Example content:
+
+```markdown
+| iteration | optimization | metric | latency_ms | SM Active Cycles | correctness | status |
+|----------:|:-------------|-------:|-----------:|------------------:|:------------|-------:|
+| 0 | baseline | 245.30 | 0.82 | 1,342,117 | PASS | keep |
+| 1 | TMA replace gather | 512.60 | 0.39 | 1,161,237 | PASS | keep |
+```
+
+Create the table header if the file is empty. Append one line for each iteration.
+
+### The baseline
+The first iteration (iteration 0) will not change any code and simply run the correctness test and performance benchmark. Results will be listed in the first row as the baseline.
+
+## The experiment loop
+The core methodology is to apply ONE optimization per iteration from the playbook, verify correctness, benchmark, and decide whether to keep or revert. Try one optimization at a time, and keep clean experiment records.
+
+LOOP:
+1. Check git status: the current git branch/commit we're on
+2. Profile and classify the bottleneck using quick code inspection:
+
+   | Pattern in Code | Likely Bottleneck | Optimization |
+   |----------------|-------------------|--------------|
+   | `ct.gather`/`ct.scatter` where TMA possible | TMA fallback | A (TMA) |
+   | No `for ... in range(bid, n, num_programs)` | Missing persistent | B (Persistent) |
+   | `@ct.kernel` with no `occupancy=` AND no autotune | Untuned occupancy | C (Autotune) |
+   | `ct.mma(a, b, acc)` without tf32 guard | Missing TF32 | D (TF32) |
+   | No `latency=` hints on `ct.load`/`ct.store` | Missing latency hints | E (Latency) |
+   | `ct.store()` without `allow_tma=False` | Suboptimal store path | F (Store TMA) |
+   | Small fixed tile sizes | Tile size mismatch | G (Tile Size) |
+   | All A–J exhausted or inapplicable | Unknown / kernel-specific | K (Customized Creative Optimization Plan) |
+
+3. Select and apply ONE optimization from `references/optimization-playbook.md`:
+   - **Memory-bound priority**: A (TMA) -> B (Persistent) -> C (Autotune) -> F (Store TMA) -> G (Tile Size) -> E (Latency) -> K (Creative Optimization Plan)
+   - **Compute-bound priority**: D (TF32) -> G (Tile Size) -> C (Autotune + num_ctas) -> I (Swizzle) -> B (Persistent) -> K (Creative Optimization Plan)
+4. Verify correctness — if it fails, **revert immediately**. Common causes: `flush_to_zero`/`rounding_mode=APPROX` changed results, tile size OOB, `allow_tma=False` semantics, persistent loop bound error
+5. Re-benchmark and compare against the current baseline
+6. Git commit
+7. Record results to @sandbox/perf_results.md
+8. Decision rules:
+
+   | Outcome | Action |
+   |---------|--------|
+   | Improvement (`SM Active Cycles`) >= 5% | Accept as new baseline, continue |
+   | Improvement 2-5% | Accept, lower priority for next iteration |
+   | Improvement < 2% | Accept but stop unless user wants more |
+   | Regression on any config | Revert immediately, try next optimization |
+   | No improvement after 2 consecutive iterations | Stop |
+   | Root cause is `scheduling` or `unknown` | Escalate to user |
+
+9. If keeping, advance the baseline numbers and continue the loop
+10. If reverting, git reset back to where you started and try the next optimization in priority order
+UNTIL: all attempts are finished, or more than 20 iterations have occurred, or the user interrupts
+
+*Be autonomous*: Ask the user for clarifications during the setup phase. Once in the experiment loop, do not pause for user feedback: use your best judgement for decision making, consult the optimization playbook and perf knobs catalog promptly, and think harder if stuck.
diff --git a/.claude/skills/improve-cutile-kernel-perf/references/cutile-api-reference.md b/.claude/skills/improve-cutile-kernel-perf/references/cutile-api-reference.md new file mode 100644 index 0000000..5e09bdb --- /dev/null +++ b/.claude/skills/improve-cutile-kernel-perf/references/cutile-api-reference.md @@ -0,0 +1,856 @@ + + + + +# cuTile API Reference + +## Contents + [Quick Lookup: Most Common Mistakes](#quick-lookup-most-common-mistakes) + [Import & Decorator](#import--decorator) + [Indexing](#indexing) + [Memory Operations](#memory-operations) + [Tensor Creation](#tensor-creation) + [Reductions](#reductions) + [Scan Operations](#scan-operations) + [Matrix Operations](#matrix-operations) + [Type & Shape Operations](#type--shape-operations) + [Slicing & Extraction](#slicing--extraction) + [Math Functions](#math-functions) + [Comparison Operations](#comparison-operations) + [Bitwise Operations](#bitwise-operations) + [Atomic Operations](#atomic-operations) + [Debug & Utility Functions](#debug--utility-functions) + [Host Functions](#host-functions) + [Data Types](#data-types) + [Enums: PaddingMode, RoundingMode, MemoryOrder, MemoryScope](#enums) + [Launch Pattern](#launch-pattern) + [Kernel Compilation Hints](#kernel-compilation-hints) + [Critical Rules (The 18 Rules)](#critical-rules-the-18-rules) + +> **For patterns, debug tables, and conversion reference:** See [cutile-patterns-reference.md](cutile-patterns-reference.md) + +--- + +## Quick Lookup: Most Common Mistakes + +| What You Wrote | What's Wrong | Correct Form | +|----------------|--------------|--------------| +| `import cutile as ct` | Wrong module name | `import cuda.tile as ct` | +| `ct.add(bid, offset)` | Promotes to float | `bid + offset` (Python op) | +| `x.to(ct.float32)` | No `.to()` method | `ct.astype(x, ct.float32)` | +| `grid = lambda: (n,)` | No lambda grid | `grid = (n, 1, 1)` | +| `ct.launch(..., None)` | No None allowed | Use dummy tensor + flag | + +## Import & Decorator + +```python +import cuda.tile as ct # NOT import cutile as ct! + +@ct.kernel +def kernel(X, Y, BLOCK: ct.Constant[int]): + ... + +ConstInt = ct.Constant[int] # Type alias for cleaner signatures +``` + +## Indexing + +| Function | Description | Example | +|----------|-------------|---------| +| `ct.bid(axis)` | Get block ID (axis: 0, 1, 2) | `bid = ct.bid(0)` | +| `ct.num_blocks(axis)` | Get grid size along axis | `n = ct.num_blocks(0)` | +| `ct.arange(size, dtype=)` | Create range [0, size) — starts at 0! | `offs = ct.arange(256, dtype=ct.int32)` | +| `ct.num_tiles(arr, axis, shape)` | Number of tiles in tile space along axis | `n = ct.num_tiles(A, 0, shape=(64, 64))` | + +**Persistent scheduling pattern** (kernel processes multiple blocks): +```python +@ct.kernel +def persistent_kernel(X, Y, BLOCK: ConstInt): + num_blks = ct.num_blocks(0) # total blocks in grid + for bid in range(ct.bid(0), total_tiles, num_blks): + x = ct.load(X, index=(bid,), shape=(BLOCK,)) + ct.store(Y, index=(bid,), tile=x) +``` + +## Memory Operations + +### ⚠️ TMA-FIRST STRATEGY + +**ALWAYS try TMA (`ct.load`/`ct.store`) FIRST!** TMA is 2-4x faster than gather/scatter due to hardware acceleration. 
+ +### TMA Load/Store (PREFERRED - Block-aligned) + +| Function | Signature | +|----------|-----------| +| `ct.load(arr, index, shape, *, order='C', padding_mode=PaddingMode.UNDETERMINED, latency=None, allow_tma=None, memory_order=MemoryOrder.WEAK, memory_scope=MemoryScope.NONE)` | TMA load | +| `ct.store(arr, index, tile, *, order='C', latency=None, allow_tma=None, memory_order=MemoryOrder.WEAK, memory_scope=MemoryScope.NONE)` | TMA store | + +**Parameters:** +- `order` — `'C'` (default, no permutation), `'F'` (reversed axes), or tuple of ints for custom axis permutation +- `padding_mode` — What value to use for out-of-bounds reads (see [PaddingMode](#enums)) +- `latency` — Hint for DRAM traffic intensity, int 1 (low) to 10 (high), or None (auto) +- `allow_tma` — If `False`, disables TMA for this load/store. Default `None` (TMA allowed) +- `memory_order` — Memory ordering for non-TMA load/store. Default `MemoryOrder.WEAK` (see [MemoryOrder](#enums)) +- `memory_scope` — Memory scope for non-TMA load/store. Default `MemoryScope.NONE` (see [MemoryScope](#enums)) + +**⚠️ CRITICAL: `index` and `shape` must have the SAME number of dimensions as the source tensor!** + +**⚠️ CRITICAL: `index` is BLOCK INDEX (which block), NOT element offset!** + +```python +# CORRECT: index=(bid,) means "load block number `bid`" +bid = ct.bid(0) +x = ct.load(X, index=(bid,), shape=(BLOCK,)) # Loads elements [bid*BLOCK : (bid+1)*BLOCK] + +# WRONG: Do NOT multiply bid by BLOCK_SIZE +# x = ct.load(X, index=(bid * BLOCK,), shape=(BLOCK,)) # WRONG! + +# Example: Loading 2D tile from 4D tensor [batch, head, seq, dim] +# CORRECT: index and shape both have 4 elements, then reshape +q = ct.load( + Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D) +).reshape((TILE_M, TILE_D)) + +# WRONG: mismatched dimensions +# q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0), shape=(TILE_M, TILE_D)) # ERROR! + +# Load with transpose +tile = ct.load(array2d, (0, 0), shape=(4, 2), order='F') + +# Load a single element as 0d tile +tile = ct.load(array3d, (0, 0, 0), shape=()) +``` + +### Gather/Scatter (FALLBACK - Arbitrary offset) + +**Use ONLY when TMA truly fails** (truly sparse random access). Most "paged" or "ragged" patterns CAN use TMA - see patterns below! + +| Function | Signature | +|----------|-----------| +| `ct.gather(arr, indices, *, mask=None, padding_value=0, check_bounds=True, latency=None)` | Gather load | +| `ct.scatter(arr, indices, value, *, mask=None, check_bounds=True, latency=None)` | Scatter store | + +**gather parameters:** +- `indices` — Tuple of integer tiles (length = array rank), or single tile for 1D arrays +- `mask` — Boolean tile; where `False`, returns `padding_value` instead of loading +- `padding_value` — Value for masked/OOB elements (default: 0) +- `check_bounds` — If `True` (default), OOB indices return `padding_value`. If `False`, OOB is undefined behavior +- `latency` — DRAM traffic hint (1-10), or None (auto) + +**scatter parameters:** +- `indices` — Same as gather +- `value` — Tile or scalar to store +- `mask` — Boolean tile; where `False`, no store occurs +- `check_bounds` — If `True` (default), OOB indices are skipped. If `False`, OOB is undefined behavior +- `latency` — DRAM traffic hint (1-10), or None (auto) + +**Note:** When both `mask` and `check_bounds=True` are provided, the effective mask is the logical AND of both. + +### TMA with Runtime Index (ct.gather().item() Pattern) - CRITICAL! 
+ +**⚠️ TMA works with RUNTIME indices!** For paged attention or indirect access: + +```python +# ⚠️ WRONG (78x slower!): Using gather for all loads +page_id_tile = ct.gather(block_tables, (idx,)) +k_indices = compute_flat_indices(page_id_tile, ...) +k_tile = ct.gather(k_cache.view(-1), k_indices) # NO TMA! + +# ✅ CORRECT: Extract scalar with .item(), then use ct.load(allow_tma=True) +page_id = ct.gather(block_tables, (idx,), padding_value=0).item() +k_tile = ct.load(k_cache, index=(page_id, ...), shape=(...), allow_tma=True) +``` + +| Pattern | Use | Performance | +|---------|-----|-------------| +| `ct.gather` for all loads | NO TMA | 78x slower | +| `ct.gather().item()` + `ct.load(allow_tma=True)` | TMA enabled | Baseline | + +## Tensor Creation + +| Function | Description | +|----------|-------------| +| `ct.zeros(shape, dtype)` | Create zero-filled tile | +| `ct.ones(shape, dtype)` | Create one-filled tile | +| `ct.full(shape, fill_value, dtype)` | Create tile filled with given value | + +**⚠️ `shape` must be compile-time constants (literals or `ct.Constant` params), NOT `X.shape`.** + +## Reductions + +| Function | Description | Optional Params | +|----------|-------------|-----------------| +| `ct.sum(x, axis=None, *, keepdims=False)` | Sum reduction | `rounding_mode=`, `flush_to_zero=` | +| `ct.max(x, axis=None, *, keepdims=False)` | Max reduction | `flush_to_zero=` | +| `ct.min(x, axis=None, *, keepdims=False)` | Min reduction | `flush_to_zero=` | +| `ct.prod(x, axis=None, *, keepdims=False)` | Product reduction | `rounding_mode=`, `flush_to_zero=` | +| `ct.argmax(x, axis=None, *, keepdims=False)` | Index of max value | — | +| `ct.argmin(x, axis=None, *, keepdims=False)` | Index of min value | — | +| `ct.reduce(x, axis, func, identity, *, keepdims=False)` | Custom reduction | — | + +**`axis`**: `None` (reduce all), `int`, or `tuple[int, ...]`. + +**`ct.reduce` example:** +```python +# Custom sum via reduce +result = ct.reduce(x, axis=0, func=lambda a, b: a + b, identity=0) + +# Multi-tile reduce (x is a tuple of tiles) +# func takes 2N args and returns N combined tiles +``` + +## Scan Operations + +| Function | Description | Optional Params | +|----------|-------------|-----------------| +| `ct.cumsum(x, axis=0, *, reverse=False)` | Cumulative sum | `rounding_mode=`, `flush_to_zero=` | +| `ct.cumprod(x, axis=0, *, reverse=False)` | Cumulative product | `rounding_mode=`, `flush_to_zero=` | +| `ct.scan(x, axis, func, identity, *, reverse=False)` | Custom scan (inclusive prefix) | — | + +**`ct.scan` example:** +```python +# Custom cumsum via scan +result = ct.scan(x, axis=0, func=lambda a, b: a + b, identity=0) +``` + +## Matrix Operations + +| Function | Description | +|----------|-------------| +| `ct.matmul(a, b)` or `a @ b` | Matrix multiply (1D/2D/3D). Auto-promotes dtypes. | +| `ct.mma(a, b, acc)` | MMA with accumulator — preserves acc dtype. | + +**`ct.mma` signature:** `def mma(x, y, /, acc) -> Tile` + +`acc` is a **positional** parameter (not keyword-only). 
Both forms work: +```python +acc = ct.mma(a, b, acc) # positional — OK +acc = ct.mma(a, b, acc=acc) # keyword — also OK +``` + +**Supported mma dtypes:** + +| Input | Acc/Output | +|-------|------------| +| f16 | f16 or f32 | +| bf16 | f32 | +| f32 | f32 | +| f64 | f64 | +| tf32 | f32 | +| f8e4m3fn | f16 or f32 | +| f8e5m2 | f16 or f32 | +| [u\|i]8 | i32 | + +**⚠️ `ct.mma` does NOT auto-cast f32→tf32.** You must manually cast: +```python +a_tf32 = ct.astype(a, ct.tfloat32) +b_tf32 = ct.astype(b, ct.tfloat32) +acc = ct.mma(a_tf32, b_tf32, acc) +``` + +### Block-Scaled MMA + +> **⚠️ Note:** `mma_scaled` is defined in `_stub.py` but is **not yet exported** from `cuda.tile.__init__`. The datatypes `float8_e8m0fnu` and `float4_e2m1fn` required by this API are also not yet exported. Confirm with the cuTile team before using. + +`ct.mma_scaled(x, x_scale, y, y_scale, /, acc)` — block-scaled matrix multiply-accumulate for microscaling (MX) formats. + +Computes: `result[i,j] = sum_k (x[i,k] * x_scale[i,k/V]) * (y[k,j] * y_scale[k/V,j]) + acc[i,j]` + +| Input (x/y) | Scale | Acc/Out | Block Factor V | +|-------------|-------|---------|---------------| +| f8e4m3fn, f8e5m2 | f8e8m0fnu | f32 | 32 | +| f4e2m1fn | f8e8m0fnu | f32 | 16, 32 | +| f4e2m1fn | f8e4m3fn | f32 | 16 | + +```python +tx = ct.full((16, 32), 1, dtype=ct.float8_e4m3fn) +sx = ct.full((16, 1), 1, dtype=ct.float8_e8m0fnu) # scale shape: [M, K_s] +ty = ct.full((32, 16), 1, dtype=ct.float8_e4m3fn) +sy = ct.full((1, 16), 1, dtype=ct.float8_e8m0fnu) # scale shape: [K_s, N] +acc = ct.full((16, 16), 0, dtype=ct.float32) +result = ct.mma_scaled(tx, sx, ty, sy, acc) +``` + +## Type & Shape Operations + +| Function | Description | +|----------|-------------| +| `ct.astype(x, dtype)` | Type cast — **NO .to() method!** | +| `ct.bitcast(x, dtype)` | Reinterpret bits as different dtype (no conversion) | +| `ct.transpose(x, axis0=None, axis1=None)` | Transpose two axes (2D: auto, >2D: must specify) | +| `ct.permute(x, axes)` | Permute dimensions | +| `ct.reshape(x, shape)` | Reshape tile (supports -1 for auto-infer) | +| `ct.expand_dims(x, axis)` | Insert size-1 axis. Also: `x[:, None]`, `x[None, :]` | +| `ct.cat(tiles, axis)` | Concatenate two same-shape tiles along axis | +| `ct.broadcast_to(x, shape)` | Broadcast tile to target shape (NumPy rules) | +| `ct.pack_to_bytes(x)` | Flatten tile and reinterpret raw bytes as 1D uint8 tile. ⚠️ **Not yet exported** from `cuda.tile.__init__` | +| `ct.unpack_from_bytes(x, dtype)` | Reinterpret 1D uint8 tile as 1D tile of target dtype (inverse of `pack_to_bytes`). 
⚠️ **Not yet exported** from `cuda.tile.__init__` | + +**Tile properties:** `tile.dtype`, `tile.shape`, `tile.ndim` +**Tile methods:** `tile.item()` (reshape to 0D scalar), `tile.reshape(shape)`, `tile.permute(axes)`, `tile.transpose(axis0, axis1)`, `tile.astype(dtype)`, `tile.extract(index, shape)` + +**Array properties:** `array.dtype`, `array.shape`, `array.strides`, `array.ndim` +**Array methods:** `array.slice(axis, start, stop)` — creates a view with restricted range along one axis + +## Slicing & Extraction + +| Function | Description | +|----------|-------------| +| `ct.extract(tile, index, shape)` | Extract sub-tile (like ct.load but on a tile) | +| `array.slice(axis, start, stop)` | Slice array along axis (view, no copy) | + +```python +# ct.extract: Extract a sub-tile from a larger tile +a_reshaped = ct.reshape(a_interleaved, (TILE_M, TILE_N, 2)) + +# Extract first slice along dim 2 +gelu_part = ct.reshape( + ct.extract(a_reshaped, index=(0, 0, 0), shape=(TILE_M, TILE_N, 1)), + (TILE_M, TILE_N) +) +# Extract second slice along dim 2 +linear_part = ct.reshape( + ct.extract(a_reshaped, index=(0, 0, 1), shape=(TILE_M, TILE_N, 1)), + (TILE_M, TILE_N) +) + +# array.slice: Create a view of an array with restricted range +segment = A.slice(axis=1, start=offset, stop=offset + length) +tile = ct.load(segment, (0, 0), shape=(TILE_M, TILE_N)) +``` + +## Math Functions + +### Unary Math + +| Function | Description | Optional Params | +|----------|-------------|-----------------| +| `ct.exp(x)` | Exponential | — | +| `ct.exp2(x)` | Base-2 exponential | `flush_to_zero=` | +| `ct.log(x)` | Natural log | — | +| `ct.log2(x)` | Base-2 log | — | +| `ct.sqrt(x)` | Square root | `rounding_mode=`, `flush_to_zero=` | +| `ct.rsqrt(x)` | Reciprocal sqrt (1/√x) | `flush_to_zero=` | +| `ct.sin(x)` | Sine | — | +| `ct.cos(x)` | Cosine | — | +| `ct.tan(x)` | Tangent | — | +| `ct.sinh(x)` | Hyperbolic sine | — | +| `ct.cosh(x)` | Hyperbolic cosine | — | +| `ct.tanh(x)` | Hyperbolic tangent | `rounding_mode=` (supports `FULL`, `APPROX`) | +| `ct.floor(x)` | Floor | — | +| `ct.ceil(x)` | Ceiling | — | +| `ct.abs(x)` | Absolute value | — | +| `ct.negative(x)` or `-x` | Negation | — | +| `ct.isnan(x)` | Check for NaN (returns bool tile) | — | + +**`flush_to_zero`** (bool): If `True`, flushes subnormal inputs/results to sign-preserving zero. Default `False`. + +**`rounding_mode`** (RoundingMode): Controls rounding behavior for float ops. See [RoundingMode enum](#enums). + +### Binary Math + +| Function | Python Operator | Optional Params | +|----------|-----------------|-----------------| +| `ct.add(x, y)` | `x + y` | `rounding_mode=`, `flush_to_zero=` | +| `ct.sub(x, y)` | `x - y` | `rounding_mode=`, `flush_to_zero=` | +| `ct.mul(x, y)` | `x * y` | `rounding_mode=`, `flush_to_zero=` | +| `ct.truediv(x, y)` | `x / y` | `rounding_mode=`, `flush_to_zero=` | +| `ct.floordiv(x, y)` | `x // y` | — | +| `ct.mod(x, y)` | `x % y` | — | +| `ct.pow(x, y)` | `x ** y` | — | +| `ct.maximum(x, y)` | `max(x, y)` | `flush_to_zero=` | +| `ct.minimum(x, y)` | `min(x, y)` | `flush_to_zero=` | +| `ct.atan2(x1, x2)` | — | — | +| `ct.cdiv(x, y)` | — | — (ceil division, works on host too) | + +**Recommended**: Use Python `+, -, *, /, //, %, **` operators for all arithmetic on both tiles and scalars. +Use `ct.add`/`ct.mul`/`ct.sub`/`ct.truediv` only when you need `flush_to_zero=` or `rounding_mode=` parameters (e.g., `ct.truediv(x, y, rounding_mode=RoundingMode.APPROX)`). 
The `ct.*` forms may also promote int32 to float — another reason to prefer Python operators for general use. + +### Conditional + +| Function | Description | +|----------|-------------| +| `ct.where(cond, x, y)` | Select elements: `x` where `cond` is True, `y` otherwise | + +### Missing Functions (Must Implement Manually) + +| Function | Implementation | +|----------|----------------| +| `softmax(x)` | `exp_x = ct.exp(x - ct.max(x, axis=...)); exp_x / ct.sum(exp_x, axis=...)` | +| `sigmoid(x)` | `1.0 / (1.0 + ct.exp(-x))` | +| `sign(x)` | `ct.where(x > 0, 1, 0) + ct.where(x < 0, -1, 0)` | +| `flip(x, dim)` | Use manual indexing with reversed indices | +| `norm(x)` | `ct.sqrt(ct.sum(x * x))` | +| `fma(a, b, c)` | `a * b + c` (no `ct.fma` API — compiler auto-fuses to FMA instruction) | +| `clamp(x, min, max)` | `ct.minimum(ct.maximum(x, min_val), max_val)` | +| `square(x)` | `x * x` | + +## Comparison Operations + +All comparisons return boolean tiles and support broadcasting + dtype promotion. + +| Function | Python Operator | +|----------|-----------------| +| `ct.greater(x, y)` | `x > y` | +| `ct.greater_equal(x, y)` | `x >= y` | +| `ct.less(x, y)` | `x < y` | +| `ct.less_equal(x, y)` | `x <= y` | +| `ct.equal(x, y)` | `x == y` | +| `ct.not_equal(x, y)` | `x != y` | + +## Bitwise Operations + +| Function | Python Operator | +|----------|-----------------| +| `ct.bitwise_and(x, y)` | `x & y` | +| `ct.bitwise_or(x, y)` | `x \| y` | +| `ct.bitwise_xor(x, y)` | `x ^ y` | +| `ct.bitwise_lshift(x, y)` | `x << y` | +| `ct.bitwise_rshift(x, y)` | `x >> y` | +| `ct.bitwise_not(x)` | `~x` | + +## Atomic Operations + +All atomic operations follow the same index convention as `ct.gather`/`ct.scatter`. + +| Function | Description | +|----------|-------------| +| `ct.atomic_add(arr, indices, update, *, check_bounds=True, memory_order=ACQ_REL, memory_scope=DEVICE)` | Atomic add, returns old value | +| `ct.atomic_max(arr, indices, update, *, ...)` | Atomic max, returns old value | +| `ct.atomic_min(arr, indices, update, *, ...)` | Atomic min, returns old value | +| `ct.atomic_and(arr, indices, update, *, ...)` | Atomic bitwise AND, returns old value | +| `ct.atomic_or(arr, indices, update, *, ...)` | Atomic bitwise OR, returns old value | +| `ct.atomic_xor(arr, indices, update, *, ...)` | Atomic bitwise XOR, returns old value | +| `ct.atomic_xchg(arr, indices, update, *, ...)` | Atomic exchange, returns old value | +| `ct.atomic_cas(arr, indices, expected, desired, *, check_bounds=True, memory_order=ACQ_REL, memory_scope=DEVICE)` | Compare-and-swap, returns old value | + +**Common parameters:** +- `memory_order` — `MemoryOrder.RELAXED`, `.ACQUIRE`, `.RELEASE`, `.ACQ_REL` (default) +- `memory_scope` — `MemoryScope.BLOCK`, `.DEVICE` (default), `.SYS` +- `check_bounds` — If `True` (default), OOB indices are skipped + +## Debug & Utility Functions + +| Function | Description | +|----------|-------------| +| `ct.printf(format, *args)` | C-printf style device print (tiles only). **Debug only — significant overhead.** | +| `ct.print(*args, sep=' ', end='\n')` | Python-style device print. Supports f-strings and positional args. **Debug only — significant overhead.** | +| `ct.assert_(cond, message=None)` | Assert all elements are True. **Debug only — significant overhead.** | +| `ct.static_eval(expr)` | Evaluate Python expression at compile time | +| `ct.static_assert(condition, message=None)` | Compile-time assertion | +| `ct.static_iter(iterable)` | Compile-time iteration (use in `for ... 
in ct.static_iter(...)`) | + +```python +# printf example (C-style format strings) +ct.printf("value: %d", tile) +ct.printf("two tiles: %d, %f", tile_a, tile_b) + +# print example (Python-style, supports f-strings) +ct.print(f"tile={tile}") +ct.print(f"x={tile:.5f}", end='') +ct.print("tile:", tile, sep='=') + +# static_eval example — select tile based on compile-time condition +x_or_y = ct.static_eval(x if N % 2 == 0 else y) + +# static_assert example +ct.static_assert(x.dtype == y.dtype, f"Expected {x} and {y} to have same dtype.") + +# static_iter example — compile-time unrolled loop +for i in ct.static_iter(range(4)): + ... +``` + +## Host Functions + +| Function | Description | +|----------|-------------| +| `ct.cdiv(a, b)` | Ceiling division — works on **both host and kernel** | +| `ct.num_tiles(arr, axis, shape)` | Get number of tiles in tile space along axis | + +```python +# Prefer Python arithmetic on host (simpler, no ct import needed) +grid = ((N + BLOCK - 1) // BLOCK, 1, 1) + +# ct.cdiv also valid on host, but Python arithmetic is preferred +# grid = (ct.cdiv(N, BLOCK), 1, 1) + +# ct.cdiv in kernel code (operates on tiles) +num_iters = ct.cdiv(K, BLOCK_K) +``` + +### Power-of-2 Utility +```python +def next_power_of_2(x: int) -> int: + """Round up to nearest power of 2 (required for tile shapes)""" + return 1 << (x - 1).bit_length() +``` + +## Data Types + +``` +ct.float16, ct.float32, ct.float64, ct.bfloat16 +ct.tfloat32 +ct.float8_e4m3fn, ct.float8_e5m2 +ct.float8_e8m0fnu # 8-bit exponent-only (scale factor for mma_scaled) ⚠️ Not yet exported from cuda.tile.__init__ +ct.float4_e2m1fn # 4-bit MX format (for mma_scaled) ⚠️ Not yet exported from cuda.tile.__init__ +ct.int8, ct.int16, ct.int32, ct.int64 +ct.uint8, ct.uint16, ct.uint32, ct.uint64 +ct.bool_ +``` + +## Enums + +### PaddingMode (for `ct.load`) + +| Value | Description | +|-------|-------------| +| `PaddingMode.UNDETERMINED` | Padding value is not determined (default) | +| `PaddingMode.ZERO` | Pad with zero | +| `PaddingMode.NEG_ZERO` | Pad with negative zero | +| `PaddingMode.NAN` | Pad with NaN | +| `PaddingMode.POS_INF` | Pad with positive infinity | +| `PaddingMode.NEG_INF` | Pad with negative infinity | + +### RoundingMode (for math ops) + +| Value | Description | +|-------|-------------| +| `RoundingMode.RN` | Round to nearest, ties to even (default) | +| `RoundingMode.RZ` | Round towards zero (truncate) | +| `RoundingMode.RM` | Round towards negative infinity | +| `RoundingMode.RP` | Round towards positive infinity | +| `RoundingMode.FULL` | Full precision | +| `RoundingMode.APPROX` | Approximate (e.g., for `ct.tanh`) | +| `RoundingMode.RZI` | Round towards zero to nearest integer | + +### MemoryOrder (for load/store and atomics) + +| Value | Description | +|-------|-------------| +| `MemoryOrder.WEAK` | Weak (non-atomic) ordering (default for `ct.load`/`ct.store`) | +| `MemoryOrder.RELAXED` | No ordering guarantees | +| `MemoryOrder.ACQUIRE` | Acquire semantics | +| `MemoryOrder.RELEASE` | Release semantics | +| `MemoryOrder.ACQ_REL` | Combined acquire + release (default for atomics) | + +### MemoryScope (for load/store and atomics) + +| Value | Description | +|-------|-------------| +| `MemoryScope.NONE` | No memory scope; used with `MemoryOrder.WEAK` (default for `ct.load`/`ct.store`) | +| `MemoryScope.BLOCK` | Ordering within same block | +| `MemoryScope.DEVICE` | Ordering across all threads on GPU (default for atomics) | +| `MemoryScope.SYS` | Ordering across entire system (multi-GPU + host) | + +### 
ByTarget (for kernel hints) + +```python +from cuda.tile import ByTarget + +# Different values per GPU architecture +@ct.kernel(num_ctas=ByTarget(sm_100=8, sm_120=4, default=2)) +def kernel_fn(x): + ... +``` + +## Launch Pattern + +```python +# Grid can be 1-tuple, 2-tuple, or 3-tuple +grid = ((N + BLOCK - 1) // BLOCK,) # 1D grid — OK +grid = (grid_m, grid_n) # 2D grid — OK +grid = (grid_m, grid_n, 1) # 3D grid — OK + +ct.launch(torch.cuda.current_stream(), grid, kernel, (x, y, BLOCK, n)) +``` + +**`ct.launch` signature:** `launch(stream, grid, kernel, kernel_args)` +- `stream` — CUDA stream (e.g., `torch.cuda.current_stream()`) +- `grid` — Tuple of 1, 2, or 3 ints +- `kernel` — Function decorated with `@ct.kernel` +- `kernel_args` — Tuple of arguments to pass to the kernel + + +## Kernel Compilation Hints + +`ct.kernel` accepts optional hints that affect compilation and scheduling: + +```python +@ct.kernel(num_ctas=2, occupancy=4) +def kernel(X, Y, BLOCK: ct.Constant[int]): + ... + +# Or with ByTarget for architecture-specific values: +@ct.kernel(num_ctas=ct.ByTarget(sm_100=2), occupancy=ct.ByTarget(sm_100=4)) +def kernel(X, Y, BLOCK: ct.Constant[int]): + ... +``` + +| Hint | Description | Default | Range | +|------|-------------|---------|-------| +| `num_ctas` | Number of CTAs in a CGA | None (auto) | Power of 2, 1–16 | +| `occupancy` | Expected active CTAs per SM | None (auto) | 1–32 | +| `opt_level` | Optimization level | 3 | 0–3 | + +**Note:** `occupancy` CAN be passed directly to `@ct.kernel`, but for production code with autotuning, passing it via `hints_fn` in `autotune_launch` is the recommended approach: +```python +# Direct (simple cases): +@ct.kernel(occupancy=4) +def kernel(...): ... + +# Via autotune (production): +ct_experimental.autotune_launch( + stream, grid_fn=..., kernel=kernel, args_fn=..., + hints_fn=lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy}, + search_space=configs, +) +``` + +--- + +## Critical Rules (The 18 Rules) + +### Rule 1: Import Statement +```python +import cuda.tile as ct # NOT import cutile as ct! 
+``` + +### Rule 2: Index = Block Index, NOT Element Offset +```python +# cuTile uses block index for TMA, or computed indices for gather +x = ct.load(X, index=(bid,), shape=(BLOCK,)) +# OR +indices = bid * BLOCK + ct.arange(BLOCK, dtype=ct.int32) +x = ct.gather(X, indices, check_bounds=True) +``` + +### Rule 3: Python Operators for Index Math +```python +# WRONG — ct.add/ct.mul promote int32 to float +indices = ct.add(ct.mul(bid, BLOCK), ct.arange(BLOCK, dtype=ct.int32)) + +# CORRECT — use Python +, *, / +indices = bid * BLOCK + ct.arange(BLOCK, dtype=ct.int32) +``` + +### Rule 4: ct.mma — acc is Positional +```python +# Both forms are correct: +acc = ct.mma(a, b, acc) # positional — OK +acc = ct.mma(a, b, acc=acc) # keyword — also OK +``` + +### Rule 5: No None in ct.launch() +```python +# WRONG +ct.launch(stream, grid, kernel, (x, None, n)) + +# CORRECT +dummy = torch.zeros(1, device=x.device) +ct.launch(stream, grid, kernel, (x, dummy, n)) +``` + +### Rule 6: Prefer Python Arithmetic on Host; Use ct.cdiv() in Kernel +```python +# Host — prefer Python arithmetic: +grid = ((N + BLOCK - 1) // BLOCK, 1, 1) # preferred +# grid = (ct.cdiv(N, BLOCK), 1, 1) # also valid, but Python is simpler + +# Kernel — ct.cdiv() operates on tiles: +num_iters = ct.cdiv(K, BLOCK_K) +``` + +### Rule 7: ct.astype(), Not .to() or .cast() +```python +# WRONG +y = x.to(ct.float32) + +# CORRECT — function form +y = ct.astype(x, ct.float32) +# CORRECT — method form (preferred for chaining) +y = x.astype(ct.float32) +# CORRECT — chained on load +tile = ct.load(X, index=(bid,), shape=(BLOCK,)).astype(ct.float32) +``` + +### Rule 8: Helper Functions - No @ct.kernel +```python +# WRONG +@ct.kernel +def helper(x): return ct.exp(x) + +# CORRECT - plain function +def helper(x): return ct.exp(x) + +@ct.kernel +def main_kernel(X, Y, N: ConstInt): + y = helper(x) +``` + +### Rule 9: Pre-define Variables Before Branches +```python +# WRONG — Variable only defined in one branch +if condition: + result = ct.zeros((M,), dtype=ct.float32) + result = ct.load(X, ...) +else: + # result is undefined here! + pass +output = result # ERROR: result may not exist + +# CORRECT — Pre-define ALL variables used across branches +result = ct.zeros((M,), dtype=ct.float32) # Pre-define before branch +if condition: + result = ct.load(X, ...) +else: + result = ct.zeros((M,), dtype=ct.float32) +output = result # OK: always defined +``` + +### Rule 10: No break/continue in Loops +```python +# WRONG +for i in range(N): + if condition: break + +# CORRECT - use conditionals +for i in range(N): + if not condition: + # loop body +``` + +### Rule 11: Grid Must Be Tuple (1, 2, or 3 elements) +```python +# WRONG +grid = N // BLOCK # bare int +grid = [N // BLOCK, 1, 1] # list + +# CORRECT — tuple of 1, 2, or 3 ints +grid = ((N + BLOCK - 1) // BLOCK,) # 1-tuple +grid = (grid_m, grid_n) # 2-tuple +grid = (grid_m, grid_n, 1) # 3-tuple +``` + +### Rule 12: ct.arange Starts at 0 +```python +# ct.arange(N) produces [0, 1, ..., N-1] — always starts at 0, no start param +offs = ct.arange(BLOCK, dtype=ct.int32) +``` + +### Rule 13: NHWC Tensors - Use tensor.stride() +```python +# WRONG: Assumes NCHW layout +offset = n * C * H * W + c * H * W + h * W + w # WRONG for NHWC! + +# CORRECT: Use actual strides from tensor +stride_n, stride_c, stride_h, stride_w = tensor.stride() +offset = n * stride_n + c * stride_c + h * stride_h + w * stride_w + +# CRITICAL: tensor.view(-1) MAY REORDER DATA for non-contiguous! +# WRONG +flat = nhwc_tensor.view(-1) # May silently reorder! 
+ +# CORRECT - Use torch.as_strided() +flat = torch.as_strided(tensor, (tensor.numel(),), (1,), storage_offset=tensor.storage_offset()) +``` + +### Rule 14: Block > Dim Masking - Apply ct.where AFTER gather +```python +# When BLOCK_SIZE > actual dimension size +# WRONG - No mask applied +offsets = ct.arange(BLOCK_C, dtype=ct.int32) +data = ct.gather(input, base + offsets) +sum_val = ct.sum(data, axis=0) # WRONG: includes padding! + +# CORRECT - Use gather's mask parameter +offsets = ct.arange(BLOCK_C, dtype=ct.int32) +mask = offsets < actual_C +data = ct.gather(input, base + offsets, mask=mask, padding_value=0) +sum_val = ct.sum(data, axis=0) # Correct! + +# Alternative - Mask AFTER gather with ct.where +data = ct.gather(input, base + offsets) +data = ct.where(mask, data, ct.zeros((BLOCK_C,), dtype=data.dtype)) +sum_val = ct.sum(data, axis=0) # Correct! + +# CRITICAL: Divide by actual_size, NOT BLOCK_SIZE +mean = sum_val / actual_C # Correct +mean = sum_val / BLOCK_C # WRONG! +``` + +### Rule 15: Masked Scatter — Use mask= or Out-of-Bounds Offsets +```python +# ct.scatter now supports mask= parameter! + +# PREFERRED: Use scatter's mask parameter directly +offsets = ct.arange(BLOCK, dtype=ct.int32) +mask = offsets < actual_size +ct.scatter(Y, offsets, data, mask=mask) # Masked elements are skipped + +# ALTERNATIVE: Out-of-bounds offsets (ct.scatter skips OOB indices when check_bounds=True) +ARRAY_SIZE = Y.numel() # Pass as kernel arg +oob_offset = ct.full((BLOCK,), ARRAY_SIZE, dtype=ct.int32) +offsets_masked = ct.where(mask, offsets, oob_offset) +ct.scatter(Y, offsets_masked, data) # OOB positions skipped! +``` + +### Rule 16: Constant Types — No Strings +```python +# ct.Constant works with int, float, bool — but NOT str +# WRONG +@ct.kernel +def kernel(X, MODE: ct.Constant[str]): # ERROR: str not supported! + if MODE == "relu": + ... + +# CORRECT — Use integer enum +RELU = 0 +GELU = 1 +@ct.kernel +def kernel(X, MODE: ct.Constant[int]): + if MODE == RELU: + ... + +# float and bool constants are also fine: +@ct.kernel +def kernel(X, SCALE: ct.Constant[float], USE_BIAS: ct.Constant[bool]): + ... +``` + +### Rule 17: Shape Args to ct.full/ct.zeros/ct.ones Must Be Static +```python +# ct.full / ct.zeros / ct.ones shape arguments must be compile-time constants. +# WRONG — X.shape is dynamic, cannot be used as shape arg to ct.full +@ct.kernel +def kernel(X, N: ct.Constant[int]): + result = ct.full(X.shape, 0.0, dtype=ct.float32) # ERROR! + +# CORRECT — Use compile-time constant +@ct.kernel +def kernel(X, N: ct.Constant[int], BLOCK: ct.Constant[int]): + result = ct.full((BLOCK,), 0.0, dtype=ct.float32) # OK: BLOCK is constexpr + +# NOTE: X.shape IS fine for arithmetic, loop bounds, and comparisons: +@ct.kernel +def kernel(X, BLOCK: ct.Constant[int]): + mask = ct.arange(BLOCK, dtype=ct.int32) < X.shape[0] # OK! + num_iters = ct.cdiv(X.shape[0], BLOCK) # OK! +``` + +### Rule 18: No Dead Code +```python +# cuTile compiles ALL parameters. Unused params waste registers and may cause errors. +# WRONG +@ct.kernel +def kernel(X, Y, Z, UNUSED: ct.Constant[int]): # UNUSED wastes a register + x = ct.load(X, ...) + ct.store(Y, ...) + # Z and UNUSED are never used! + +# CORRECT — Remove unused parameters +@ct.kernel +def kernel(X, Y): + x = ct.load(X, ...) + ct.store(Y, ...) 
+``` diff --git a/.claude/skills/improve-cutile-kernel-perf/references/cutile-patterns-reference.md b/.claude/skills/improve-cutile-kernel-perf/references/cutile-patterns-reference.md new file mode 100644 index 0000000..8cba370 --- /dev/null +++ b/.claude/skills/improve-cutile-kernel-perf/references/cutile-patterns-reference.md @@ -0,0 +1,173 @@ + + + + +# cuTile Patterns Quick-Reference Card + +**Quick-lookup tables, unique patterns, and debug reference for cuTile kernels.** + +> For core API (functions, types, 18 rules): See [cutile-api-reference.md](cutile-api-reference.md) +> For advanced conversion patterns (NHWC, masking, TMA decisions, ragged tensors): See [advanced-patterns.md](../translations/advanced-patterns.md) + +## Contents +- [Unique Patterns](#unique-patterns) +- [Quick Debug Reference Table](#quick-debug-reference-table) +- [Appendix: Block vs Tile Terminology](#appendix-block-vs-tile-terminology) + +--- + +## Unique Patterns + +### Scalar Extraction from Tensor + +Load a single element as a scalar tile for use in multi-dim indexing: + +```python +# Load single element, reshape to scalar +idx_tile = ct.load(input_ids, index=(row,), shape=(1,)) +scalar_idx = ct.reshape(idx_tile, ()) # (1,) → () + +# Use scalar in multi-dim gather +embedding = ct.gather(weight_2d, (scalar_idx, col_offsets)) +``` + +### Scalar Load (0D Tile) + +```python +# Load single element as 0D tile (scalar) +scalar_val = ct.load(X, index=(0,), shape=()) # 1D array +scalar_val = ct.load(X, index=(0, 0, 0), shape=()) # 3D array +# Note: index tuple length must match source array rank +``` + +### Batched MMA (3D Tiles) + +`ct.mma` supports 2D and 3D tiles natively. For batched matmul, load 3D tiles +and call `ct.mma` directly — no reshape needed: + +```python +@ct.kernel +def matmul_batched(A, B, C, B_DIM: ConstInt, M: ConstInt, N: ConstInt, K: ConstInt, + BLOCK_B: ConstInt, BLOCK_M: ConstInt, BLOCK_N: ConstInt): + bid_b, bid_m, bid_n = ct.bid(0), ct.bid(1), ct.bid(2) + + # Load 3D tiles: batch × rows/cols × contraction dim + a_tile = ct.load(A, index=(bid_b, bid_m, 0), shape=(BLOCK_B, BLOCK_M, K)) + b_tile = ct.load(B, index=(bid_b, 0, bid_n), shape=(BLOCK_B, K, BLOCK_N)) + + # mma supports 3D directly — batch dims are broadcast + acc = ct.zeros((BLOCK_B, BLOCK_M, BLOCK_N), dtype=ct.float32) + acc = ct.mma(a_tile, b_tile, acc=acc) + + ct.store(C, index=(bid_b, bid_m, bid_n), tile=acc) +``` + +For true 4D tensors (e.g. 
shape `(B, H, M, K)`), reshape to 3D before `ct.mma`: + +```python +@ct.kernel +def matmul_4d(A, B, C, BATCH: ConstInt, HEADS: ConstInt, M: ConstInt, N: ConstInt, K: ConstInt, + BLOCK_M: ConstInt, BLOCK_N: ConstInt): + bid_bh, bid_m, bid_n = ct.bid(0), ct.bid(1), ct.bid(2) + + # Load 4D tiles (batch and head merged into one grid dim) + # bid_bh indexes the flattened (BATCH * HEADS) dimension + b_idx = bid_bh // HEADS + h_idx = bid_bh % HEADS + + a_tile = ct.load(A, index=(b_idx, h_idx, bid_m, 0), + shape=(1, 1, BLOCK_M, K)) # 4D: (1, 1, BLOCK_M, K) + b_tile = ct.load(B, index=(b_idx, h_idx, 0, bid_n), + shape=(1, 1, K, BLOCK_N)) # 4D: (1, 1, K, BLOCK_N) + + # Reshape 4D → 2D for mma + a_2d = ct.reshape(a_tile, (BLOCK_M, K)) # (BLOCK_M, K) + b_2d = ct.reshape(b_tile, (K, BLOCK_N)) # (K, BLOCK_N) + + acc = ct.zeros((BLOCK_M, BLOCK_N), dtype=ct.float32) + acc = ct.mma(a_2d, b_2d, acc=acc) + + # Reshape back to 4D for store + result = ct.reshape(acc, (1, 1, BLOCK_M, BLOCK_N)) + ct.store(C, index=(b_idx, h_idx, bid_m, bid_n), tile=result) +``` + +### Multi-dimensional Index with Reshape (4D → 2D) + +```python +@ct.kernel +def attention_pattern(Q, K, V, Out, + batch_idx: ConstInt, head_idx: ConstInt, + TILE_M: ConstInt, TILE_N: ConstInt, TILE_D: ConstInt): + bid_m = ct.bid(0) + + # Load 4D slice, reshape to 2D for computation + q = ct.load(Q, index=(batch_idx, head_idx, bid_m, 0), + shape=(1, 1, TILE_M, TILE_D)).reshape((TILE_M, TILE_D)) + + # ... compute attention ... + + # Store back: reshape to 4D + ct.store(Out, index=(batch_idx, head_idx, bid_m, 0), + tile=result.reshape((1, 1, TILE_M, TILE_D))) +``` + +### Cross-Reference: Advanced Patterns + +For detailed coverage of these patterns, see the corresponding documents linked from the SKILL.md [Reference Documents table](../SKILL.md#reference-documents): + +| Pattern | Primary Source | +|---------|---------------| +| Multi-dim gather/scatter, Array.slice, paged attention TMA | `translations/advanced-patterns.md` | +| NHWC layout, block masking, masked scatter | `translations/advanced-patterns.md` + rules 13-15 in `references/cutile-api-reference.md` | +| Element-wise kernel example | `examples/01_vector_add/` | +| GEMM with TMA example | `examples/04_matmul/` | + +--- + +## Quick Debug Reference Table + +| Error Pattern | Likely Cause | Quick Fix | +|---------------|--------------|-----------| +| Only False-False passes | Missing `ct.permute()` | Add explicit permute after ct.load | +| TileSyntaxError: break | break in for loop | Use `if i < n:` wrapper | +| TileTypeError: shapes mismatch | Wrong `shape` param | `shape` = OUTPUT, not input | +| Numerical error (27%+ mismatch) | Wrong transpose logic | Use `ct.permute()`, not `order` | +| Compile error at ct.load | Element offset as index | Use `bid_m` not `bid_m*TILE_M` | +| TileTypeError: float16 padding | `padding_value=0.0` | Omit padding_value (defaults to 0) | +| AttributeError: 'cast' | Using `.cast()` | Use `ct.astype(x, dtype)` or `x.astype(dtype)` | +| TypeError: NoneType | None in ct.launch | Replace with dummy tensor | +| ModuleNotFoundError: cutile | Wrong import | Use `import cuda.tile as ct` | +| Numerical error on NHWC tensor | Wrong stride assumption | Use `tensor.stride()`, not hardcoded | +| Mean/sum off by small factor | BLOCK > actual size, no mask | Apply `ct.where(mask,...)` after gather | +| TileTypeError: mask param | ct.scatter mask syntax error | Use `ct.scatter(arr, idx, val, mask=mask)` or out-of-bounds offsets | +| Silent wrong results NHWC | `tensor.view(-1)` 
reorders data | Use `torch.as_strided()` instead | +| ~30% wrong values, pattern in groups | BLOCK > dim, invalid offsets overwrite adjacent | Use `ct.where(mask, offsets, oob_offset)` | +| Only first channels correct per group | Partial block scatter overwrites next block | Set invalid offsets to ARRAY_SIZE (out-of-bounds) | +| NaN in output | Division by zero or log(0) | Add numerical guards: `ct.where(x > 0, ct.log(x), 0)` | +| Large numerical errors (~1e-2) | Accumulation order differs | Use float32 accumulator: `acc = ct.zeros(..., dtype=ct.float32)` | +| Numerical mismatch with fp32 mma | CuTile `ct.mma` does not auto-cast fp32→tf32 | Guard: `a = ct.astype(a, ct.tfloat32) if a.dtype == ct.float32 else a` | +| CuTile unexpectedly slow, same algorithm | Unnecessary token dependency chains in CuTile IR | Try `CUDA_TILE_TESTING_DISABLE_TOKEN_ORDER=1`, verify correctness | +| Extremely slow (paged attn) | Using ct.gather for all loads | Use `ct.gather().item()` + `ct.load(allow_tma=True)` | +| load_pointer_tko in IR | ct.gather generating per-element loads | Extract scalar with `.item()`, use `ct.load` with runtime index | + +--- + +## Appendix: Block vs Tile Terminology + +TileGym uses mixed terminology: + +| Term | Context | Meaning | +|------|---------|---------| +| `BLOCK_SIZE` / `BLOCK_M` | Legacy convention | Tile dimension size | +| `TILE_SIZE` / `TILE_M` | cuTile convention | Same as BLOCK_M | +| `ct.bid(axis)` | cuTile API | Block ID = which tile in the grid | +| `ct.num_blocks(axis)` | cuTile API | Grid size = total number of tiles | +| `ct.num_tiles(arr, axis, shape)` | cuTile API | Dynamic tile count for sliced arrays | +| `CTA` | Hardware | Cooperative Thread Array ≈ thread block | +| `num_ctas` | ct.kernel kwarg | CTAs per SM (multi-CTA kernels) | + +**Convention in TileGym cuTile code:** +- Prefer `TILE_M`, `TILE_N`, `TILE_K` over `BLOCK_M`, `BLOCK_N`, `BLOCK_K` +- Both are accepted in `kernel_configs` dicts +- `ct.bid(0)` returns the tile index, despite "block" in the name diff --git a/.claude/skills/improve-cutile-kernel-perf/references/ir-dump-guide.md b/.claude/skills/improve-cutile-kernel-perf/references/ir-dump-guide.md new file mode 100644 index 0000000..4e1bc25 --- /dev/null +++ b/.claude/skills/improve-cutile-kernel-perf/references/ir-dump-guide.md @@ -0,0 +1,256 @@ + + + + +# IR Analysis Guide + +## Overview + +This guide covers how to dump and analyze MLIR IR for cuTile kernels. +cuTile compiles through the tileir backend: TileIR → Bytecode → cubin (PTX → SASS). +By examining IR and SASS you can pinpoint performance bottlenecks. + +--- + +## Compilation Path + +### cuTile + +``` +Python (@ct.kernel) + │ + ├──▶ Bytecode (.tileirbc) ← CUDA_TILE_DUMP_BYTECODE + │ │ + │ ▼ tileiras --arch sm_120 + │ cubin → SASS ← ACTUAL runtime path + │ + └──▶ TileIR MLIR (.tileir) ← CUDA_TILE_DUMP_TILEIR +``` + +- **`tileiras`** is the real compiler. It reads bytecode directly. + +### Which Level to Analyze? + +| Question | Analyze at | +|----------|-----------| +| Are the frontends generating the same high-level ops? | **TileIR** | +| How many HW instructions? Which MUFU ops? | **SASS** | +| What is the scheduling / loop throughput? 
| **tileiras --remarks** | + +--- + +## Prerequisites + +```bash +source /workspace/entrypoint.sh + +# Install cuda-tile +pip install cuda-tile[tileiras] + +# Verify tools +which tileiras +``` + +--- + +## Environment Variables + +| Variable | Purpose | Example | +|----------|---------|---------| +| `CUDA_TILE_DUMP_TILEIR` | cuTile TileIR MLIR dump | `/tmp/cutile_tileir` | +| `CUDA_TILE_DUMP_BYTECODE` | cuTile bytecode dump | `/tmp/cutile_bytecode` | +| `CUDA_TILE_LOGS` | cuTile compilation logs | `CUTILEIR` | +| `DISABLE_CUTILE_TUNE` | Force first autotune config (TileGym convention, not a cuTile env var) | `1` | +| `CUDA_TILE_ENABLE_CRASH_DUMP` | Crash dump on failure | `1` | +| `CUDA_TILE_TESTING_DISABLE_TOKEN_ORDER` | Disable token ordering in CuTile | `1` | + +--- + +## How to Dump IR + +### cuTile + +```bash +# Clean +rm -rf /tmp/cutile_tileir /tmp/cutile_bytecode +mkdir -p /tmp/cutile_tileir /tmp/cutile_bytecode + +# Dump TileIR MLIR + bytecode (requires cuda-tile) +# WARNING: autotune overwrites per config. Use DISABLE_CUTILE_TUNE=1. +CUDA_TILE_DUMP_TILEIR=/tmp/cutile_tileir \ +CUDA_TILE_DUMP_BYTECODE=/tmp/cutile_bytecode \ +DISABLE_CUTILE_TUNE=1 \ + pytest {test_path} -k "test_op and cutile and {config}" --timeout=120 + +# Compile bytecode → cubin +tileiras --arch sm_120 -o /tmp/cutile.cubin /tmp/cutile_bytecode/*.tileirbc + +# Dump SASS +/usr/local/cuda/bin/cuobjdump --dump-sass /tmp/cutile.cubin +``` + +--- + +## How to Analyze + +### SASS Level: Instruction Counts + +```bash +# MUFU instruction breakdown +/usr/local/cuda/bin/cuobjdump --dump-sass /tmp/cutile.cubin | \ + grep "MUFU" | sort | uniq -c | sort -rn + +# Total instruction count +/usr/local/cuda/bin/cuobjdump --dump-sass /tmp/cutile.cubin | grep -c ";" + +# Cubin size +ls -la /tmp/cutile.cubin +``` + +MUFU instruction mapping: + +| MUFU | HW operation | cuTile API | +|------|-------------|------------| +| `MUFU.TANH` | Hardware tanh (1 cycle) | `ct.tanh(x, rounding_mode=RoundingMode.APPROX)` (since CTK 13.2) | +| `MUFU.EX2` | Hardware exp2 (1 cycle) | `ct.exp()` lowers to mul + EX2 | +| `MUFU.RCP` | Hardware reciprocal (1 cycle) | `ct.truediv(x, y, rounding_mode=RoundingMode.APPROX)` | +| `MUFU.RSQ` | Hardware rsqrt (1 cycle) | `ct.rsqrt()` | + +### tileiras Scheduling Remarks + +```bash +tileiras --arch sm_120 \ + --remarks=all --remark-format=command-line \ + -o /dev/null /tmp/cutile_bytecode/*.tileirbc +``` + +Outputs: +- **II (Initiation Interval)**: loop throughput — lower is better +- **NumOps**: operations per loop body +- **Gantt chart**: visual timeline — check if loads overlap with compute +- **TMA Load shapes**: should match your tile sizes +- **Tensor-core shapes**: confirms MMA instruction selection + +What to look for: +- **High II** (>1000) → register pressure or long dependency chains +- **Gantt overlaps** (loads start while compute still running) → good pipelining +- **Sequential Gantt** (load → wait → compute → load) → no pipelining + +--- + +## Performance Debugging Techniques + +### Technique 1: Isolation Experiment + +When cuTile performance is unexpectedly poor, the gap may come from multiple sources +(activation function, memory access, compiler scheduling). To decompose: + +1. Replace the suspect operation with a trivial one (e.g., `activation_fn(x)` → `x * constant`) +2. Re-benchmark +3. 
If performance improves significantly, the suspect operation is the bottleneck + +### Technique 2: Register Pressure Diagnosis + +```bash +tileiras --arch sm_120 --remarks=schedule --remark-format=command-line \ + -o /dev/null /tmp/cutile_bytecode/*.tileirbc +``` + +If II is very high, try simplifying the inner loop body (e.g., remove activation, reduce tile size) +and check if II drops. If it does → original code has register pressure. + +### Technique 3: cuTile API Introspection + +Check what parameters a cuTile math function actually supports: + +```python +import cuda.tile as ct +import inspect + +for name in ['tanh', 'exp', 'exp2', 'rsqrt', 'truediv']: + fn = getattr(ct, name, None) + if fn: + sig = inspect.signature(fn) + print(f'ct.{name}: {sig}') +``` + +Check bytecode encoding to see if a parameter is even representable: + +```python +import cuda.tile._bytecode as bc +import inspect +print(inspect.getsource(bc.encode_TanHOp)) +``` + +--- + +## Known cuTile Limitations + +| Limitation | Impact | Workaround | +|-----------|--------|------------| +| `ct.tanh()` APPROX mode (since CTK 13.2) | Use `ct.tanh(x, rounding_mode=RoundingMode.APPROX)` to emit single MUFU.TANH | Prior to CTK 13.2, precise tanh emits many EX2+RCP; upgrade to 13.2+ and use APPROX | +| `ct.exp()` rounding_mode hardcoded to FULL | Cannot force fast exp — rounding_mode is not exposed in the API (TODO in source) | Compiler does its own lowering; no user workaround | +| `ct.mma` no auto float32→tf32 | cuTile does not auto-cast fp32→tf32 | Guard: `a = ct.astype(a, ct.tfloat32) if a.dtype == ct.float32 else a` before `ct.mma` | +| Unnecessary token dependencies | cuTile compiler may insert unnecessary token ordering dependencies, causing pipeline stalls | Set `CUDA_TILE_TESTING_DISABLE_TOKEN_ORDER=1` (see § Token Dependency Analysis below) | +| `tileiras` scheduling quality | May produce suboptimal II for some kernels | No user-facing workaround | + +--- + +## Token Dependency Analysis + +CuTile may insert **token dependencies** (ordering constraints) that serialize operations which should run in parallel. + +### Detect + +Dump IR and check for token operations: + +```bash +grep -i "token" /tmp/cutile_tileir/*.tileir +``` + +If cuTile has excessive token chains → likely unnecessary. + +### Mitigate + +```bash +CUDA_TILE_TESTING_DISABLE_TOKEN_ORDER=1 \ + pytest {test_path} -k "test_op and cutile" --timeout=120 +``` + +**IMPORTANT**: Always verify correctness after disabling tokens — re-run the pytest correctness test (e.g., `pytest {test_path} -k "test_op and cutile and {config}" --timeout=120`) and confirm all assertions pass. If correctness fails, the tokens are required for that kernel and this flag must not be used. 
+ +--- + +## Full Compiler Pass Dump (Alternative to Per-Level Extraction) + +For a comprehensive view of all compiler passes in a single dump: + +```bash +# Dump ALL passes for cuTile +tileiras --arch {SM_ARCH} --mlir-print-ir-after-all -o /dev/null \ + /tmp/cutile_bytecode/*.tileirbc 2>&1 > /tmp/cutile_full_dump.txt + +# List available passes +grep "IR Dump After" /tmp/cutile_full_dump.txt | head -30 + +# Extract a specific pass by name +awk '/IR Dump After /{found=1; next} /IR Dump After/{if(found) exit} found' \ + /tmp/cutile_full_dump.txt | grep -v "^into " > /tmp/cutile_pass_output.mlir +``` + +**When to use full dump:** +- When you need to investigate pass ordering, or find where a transformation happens + +--- + +## When to Use IR Analysis + +**Use when:** +- cuTile performance is unexpectedly poor and you need to understand why +- Numerical results are correct but performance is poor +- Filing a feature request for the cuTile team (need concrete evidence) + +**Don't use when:** +- Kernel doesn't compile (fix syntax/type errors first) +- Numerical results are wrong (fix correctness first) +- Performance difference <5% (likely noise or autotune variance) diff --git a/.claude/skills/improve-cutile-kernel-perf/references/optimization-playbook.md b/.claude/skills/improve-cutile-kernel-perf/references/optimization-playbook.md new file mode 100644 index 0000000..3ce6462 --- /dev/null +++ b/.claude/skills/improve-cutile-kernel-perf/references/optimization-playbook.md @@ -0,0 +1,351 @@ +# Optimization Playbook + +SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: MIT + +Step-by-step recipes for each performance optimization. Apply ONE per iteration. + +--- + +## Optimization A: Replace Gather/Scatter with TMA + +**Impact**: 2-78x +**When**: Kernel uses `ct.gather`/`ct.scatter` for contiguous or block-aligned access patterns. + +TMA (`ct.load`/`ct.store`) uses the Tensor Memory Accelerator hardware unit and is dramatically faster than software-computed gather/scatter for regular access. + +### Before (gather — slow) +```python +@ct.kernel +def kernel(X, Y, BLOCK: ct.Constant[int]): + bid = ct.bid(0) + indices = bid * BLOCK + ct.arange(BLOCK, dtype=ct.int32) + x = ct.gather(X, indices, check_bounds=True) + result = compute(x) + ct.scatter(Y, indices, result, check_bounds=True) +``` + +### After option 1: Direct TMA (block-aligned access) +```python +@ct.kernel +def kernel(X, Y, BLOCK: ct.Constant[int]): + bid = ct.bid(0) + x = ct.load(X, index=(bid,), shape=(BLOCK,), padding_mode=ct.PaddingMode.ZERO) # index = BLOCK index, NOT element offset + result = compute(x) + ct.store(Y, index=(bid,), tile=result) +``` + +### After option 2: Array.slice for ragged/variable-length +```python +@ct.kernel +def kernel(X, Y, start: int, length: int, BLOCK: ct.Constant[int]): + bid = ct.bid(0) + seg = X.slice(axis=0, start=start, stop=start + length) + x = ct.load(seg, index=(bid,), shape=(BLOCK,), padding_mode=ct.PaddingMode.ZERO) + result = compute(x) + seg_out = Y.slice(axis=0, start=start, stop=start + length) + ct.store(seg_out, index=(bid,), tile=result) +``` + +### After option 3: ct.gather().item() + TMA for paged/indirect access +```python +@ct.kernel +def kernel(X, block_table, Y, BLOCK: ct.Constant[int]): + bid = ct.bid(0) + # Extract scalar page ID, then use TMA + page_id = ct.gather(block_table, (bid,), padding_value=0).item() + x = ct.load(X, index=(page_id, 0), shape=(1, BLOCK), allow_tma=True) + # ... 
compute and store +``` + +**Decision**: Use TMA whenever data is contiguous or block-aligned. Use gather only for truly sparse random access. + +--- + +## Optimization B: Add Persistent Scheduling + +**Impact**: +50-300% +**When**: Kernel processes many independent work items (rows, tiles) with `grid = (n_items,)`. + +### Before (one block per work item) +```python +@ct.kernel +def kernel(input, output, N: ct.Constant[int]): + row = ct.bid(0) + data = ct.load(input, index=(row, 0), shape=(1, N)) + result = compute(data) + ct.store(output, index=(row, 0), tile=result) + +# Launch +grid = (n_rows, 1, 1) +ct.launch(stream, grid, kernel, (input, output, N)) +``` + +### After (persistent — fewer blocks, each processes multiple rows) +```python +@ct.kernel +def kernel(input, output, n_rows: ct.Constant[int], N: ct.Constant[int]): + pid = ct.bid(0) + num_programs = ct.num_blocks(0) + for row_idx in range(pid, n_rows, num_programs): + data = ct.load(input, index=(row_idx, 0), shape=(1, N)) + result = compute(data) + ct.store(output, index=(row_idx, 0), tile=result) + +# Launch +NUM_SM = torch.cuda.get_device_properties(device).multi_processor_count +occupancy = 4 # or from autotune cfg.occupancy +num_programs = min(NUM_SM * occupancy, n_rows) +grid = (num_programs, 1, 1) +ct.launch(stream, grid, kernel, (input, output, n_rows, N)) +``` + +**Heuristic**: Use persistent scheduling when `n_work_items > NUM_SM * 2`. + +--- + +## Optimization C: Add Autotune with Wide Config Space + +**Impact**: +10-50% +**When**: Kernel uses fixed occupancy/num_ctas/tile sizes, or has no autotune at all. + +### Template (Recommended: `ct.tune.exhaustive_search`) +```python +from types import SimpleNamespace +import cuda.tile as ct + +def _my_kernel_autotune_configs(): + """Generate autotune search space — be generous with range.""" + gpu_cap = torch.cuda.get_device_capability() + + if gpu_cap >= (10, 0): # Blackwell datacenter (sm100+) and consumer (sm120) + tile_sizes = [128, 256, 512, 1024] + occupancies = [1, 2, 4, 8, 16] + num_ctas_list = [1, 2, 4] + elif gpu_cap >= (9, 0): # Hopper (H100 / H200) + tile_sizes = [64, 128, 256, 512] + occupancies = [1, 2, 4, 8] + num_ctas_list = [1] + else: # Ampere (A100) and earlier + tile_sizes = [64, 128, 256] + occupancies = [1, 2, 4] + num_ctas_list = [1] + + configs = [] + for tile in tile_sizes: + for occ in occupancies: + for ncta in num_ctas_list: + configs.append(SimpleNamespace( + TILE_SIZE=tile, occupancy=occ, num_ctas=ncta + )) + return configs + +def launch_my_kernel(stream, input, output, N): + NUM_SM = torch.cuda.get_device_properties(input.device).multi_processor_count + + result = ct.tune.exhaustive_search( + search_space=_my_kernel_autotune_configs(), # must be a Sequence (list), not a generator + stream=stream, + grid_fn=lambda cfg: (min(NUM_SM * cfg.occupancy, N), 1, 1), + kernel=my_kernel, + args_fn=lambda cfg: (input, output, cfg.TILE_SIZE, N), + hints_fn=lambda cfg: { + "num_ctas": cfg.num_ctas, + "occupancy": cfg.occupancy, + }, + ) + # result.best_config, result.best_time_us, result.timings available +``` + +> **Note**: The legacy API `ct_experimental.autotune_launch()` still works but emits a `DeprecationWarning`. +> Key differences: `ct.tune.exhaustive_search` takes `search_space` as a `Sequence` (first positional arg), +> not an `Iterable | Callable` keyword arg. Convert generators to lists. + +**Key**: Do NOT hardcode `occupancy=N` in `@ct.kernel()` when using autotune — pass it via `hints_fn`. 
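+
+Because `ct.tune.exhaustive_search` only returns a `TuningResult` — it does not launch the kernel and has no built-in caching — wrapping it with a small per-shape cache avoids re-tuning on every call. The sketch below is illustrative: `launch_my_kernel_cached` and `_BEST_CFG` are made-up names, the argument list follows the template above (assuming `torch` and `ct` are imported as in that template), and the `replace_hints` call mirrors the modern-API example in `performance-model.md` — verify both against your cuTile version.
+
+```python
+_BEST_CFG = {}  # (device_index, N) -> best config found by exhaustive_search
+
+def launch_my_kernel_cached(stream, input, output, N):
+    NUM_SM = torch.cuda.get_device_properties(input.device).multi_processor_count
+    key = (input.device.index, N)
+    if key not in _BEST_CFG:
+        result = ct.tune.exhaustive_search(
+            search_space=_my_kernel_autotune_configs(),
+            stream=stream,
+            grid_fn=lambda cfg: (min(NUM_SM * cfg.occupancy, N), 1, 1),
+            kernel=my_kernel,
+            args_fn=lambda cfg: (input, output, cfg.TILE_SIZE, N),
+            hints_fn=lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy},
+        )
+        _BEST_CFG[key] = result.best_config
+    cfg = _BEST_CFG[key]
+    # exhaustive_search does not launch — launch manually with the cached config
+    tuned = my_kernel.replace_hints(num_ctas=cfg.num_ctas, occupancy=cfg.occupancy)
+    grid = (min(NUM_SM * cfg.occupancy, N), 1, 1)
+    ct.launch(stream, grid, tuned, (input, output, cfg.TILE_SIZE, N))
+```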
+ +--- + +## Optimization D: Add TF32 Dtype Guard for MMA + +**Impact**: ~2x for FP32 MMA operations +**When**: Kernel uses `ct.mma()` with FP32 inputs without casting to TF32 first. + +cuTile's `ct.mma` does NOT auto-cast FP32 to TF32. You must explicitly cast. + +### Before +```python +a = ct.load(A, index=(bid_m, k), shape=(TILE_M, TILE_K)) +b = ct.load(B, index=(k, bid_n), shape=(TILE_K, TILE_N)) +acc = ct.mma(a, b, acc=acc) +``` + +### After +```python +a = ct.load(A, index=(bid_m, k), shape=(TILE_M, TILE_K)) +b = ct.load(B, index=(k, bid_n), shape=(TILE_K, TILE_N)) + +# Cast FP32 → TF32 for tensor core utilization +dtype = ct.tfloat32 if a.dtype == ct.float32 else a.dtype +a = ct.astype(a, dtype) +b = ct.astype(b, dtype) + +acc = ct.mma(a, b, acc=acc) # Now uses tensor cores +``` + +--- + +## Optimization E: Add Latency Hints + +**Impact**: +2-5% +**When**: Kernel has `ct.load`/`ct.store` calls without `latency=` parameter. + +Latency hints tell the compiler about expected DRAM traffic intensity, enabling better prefetching. + +### Recipe +```python +# On ct.load — higher values = more aggressive prefetch +ct.load(X, index=(bid, 0), shape=(M, N), latency=10) # +2% in rms_norm + +# On ct.store — moderate values +ct.store(Y, index=(bid, 0), tile=y, latency=3) # +3% in rms_norm + +# On ct.gather/ct.scatter +ct.gather(x, (row, offs), latency=1) +ct.scatter(out, (row, offs), yj, latency=1) +``` + +**Sweep strategy**: Try latency values {1, 2, 3, 6, 10} on the hottest loads. Benchmark each. + +--- + +## Optimization F: Disable TMA on Store + +**Impact**: +10-30% +**When**: Kernel uses `ct.store()` without `allow_tma=False`. + +For some kernels, disabling TMA on the store path gives a significant boost. This was discovered in rms_norm (+30%). + +### Recipe +```python +# Before +ct.store(Y, index=(bid, 0), tile=result) + +# After — try both and benchmark +ct.store(Y, index=(bid, 0), tile=result, allow_tma=False) # +30% in rms_norm! +``` + +**Caution**: Does NOT always help. Must benchmark to verify. + +--- + +## Optimization G: Tile Size Tuning + +**Impact**: +5-50% depending on mismatch +**When**: Current tile sizes are suboptimal for the workload or GPU architecture. + +--- + +## Optimization H: Numerical Shortcuts + +**Impact**: +1-5% +**When**: Kernel has many `ct.exp2`, `ct.truediv`, or similar math ops, and slight precision loss is acceptable. + +> **Note**: `ct.exp()` does NOT accept `flush_to_zero`. Only `ct.exp2`, `ct.rsqrt`, and `ct.truediv` support it. + +### flush_to_zero +```python +# Skip denormal number handling +# ct.exp() does NOT support flush_to_zero — use ct.exp2() instead +ct.exp2(qk, flush_to_zero=True) +ct.rsqrt(variance, flush_to_zero=True) +``` + +### Approximate division +```python +ct.truediv(1.0, denom, flush_to_zero=True, rounding_mode=ct.RoundingMode.APPROX) +``` + +**Caution**: May cause correctness failures with tight tolerances. Loosen atol/rtol if needed, but only after confirming the precision loss is acceptable for the use case. + +--- + +## Optimization I: GROUP_SIZE_M (2D Block Swizzling) + +**Impact**: +5-15% for large 2D tiled kernels +**When**: Kernel uses 2D tile grid (matmul, attention, bmm) without block swizzling. 
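+
+For contrast with the recipe below, this is the typical unswizzled mapping it replaces (illustrative sketch; `N` and `TILE_SIZE_N` are the same names used in the recipe):
+
+```python
+# Unswizzled row-major mapping: consecutive bids walk a full row of output
+# tiles, so the tiles of one operand see little L2 reuse when N is large.
+bid = ct.bid(0)
+num_bid_n = ct.cdiv(N, TILE_SIZE_N)
+bid_m = bid // num_bid_n
+bid_n = bid % num_bid_n
+```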
+ +### Recipe +```python +def swizzle_2d(M, N, TILE_SIZE_M, TILE_SIZE_N, GROUP_SIZE_M): + bid = ct.bid(0) + num_bid_m = ct.cdiv(M, TILE_SIZE_M) + num_bid_n = ct.cdiv(N, TILE_SIZE_N) + num_bid_in_group = GROUP_SIZE_M * num_bid_n + group_id = bid // num_bid_in_group + first_bid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_bid_m - first_bid_m, GROUP_SIZE_M) + bid_m = first_bid_m + (bid % group_size_m) + bid_n = (bid % num_bid_in_group) // group_size_m + return bid_m, bid_n +``` + +Try GROUP_SIZE_M in {4, 8, 16}. The optimal value depends on matrix shape and L2 cache size. + +--- + +## Optimization J: Token Dependency Mitigation + +**Impact**: Variable (sometimes significant) +**When**: IR analysis shows cuTile has unnecessary token chains. + +### Detect +dump cuTile bytecode (`CUDA_TILE_DUMP_BYTECODE=/tmp/cutile_bytecode`) and TileIR (`CUDA_TILE_DUMP_TILEIR=/tmp/cutile_tileir`) +```bash +# Check token operations in cuTile IR +grep -i "token" /tmp/cutile_tileir/*.cuda_tile.mlir +``` + +### Mitigate +```bash +CUDA_TILE_TESTING_DISABLE_TOKEN_ORDER=1 \ + python -m pytest tests/suites//test_.py -k "test_op and cutile" --timeout=120 +``` + +**CRITICAL**: Always verify correctness after disabling tokens. If correctness fails, the tokens are required. + +--- + +## Optimization K: Customized Creative Optimization Plan (Last Resort) + +**Impact**: Variable — depends on kernel characteristics +**When**: All standard optimizations (A–J) have been exhausted or are inapplicable, and further performance gains are still desired. This is a last-resort creative pass. + +### Recipe + +Carefully inspect the kernel code, its access patterns, computation graph, and profiling data (`ncu` / `nsys`). Then **generate a custom optimization plan** with ~5 items tailored to the specific kernel. Each item should be a concrete, actionable change. + +**Step 1: Deep analysis** +- Re-read the kernel source and all profiling results collected so far +- Identify any remaining inefficiencies: redundant loads, suboptimal memory access patterns, unnecessary synchronization, under-utilized hardware features, suboptimal data types, etc. + +**Step 2: Generate the plan** + +Produce a numbered list of ~5 optimization items. Examples of what items might look like (these are illustrative — your plan should be kernel-specific): + +1. Fuse adjacent elementwise ops into the main loop body to reduce memory round-trips +2. Reorder loop dimensions to improve L2 cache hit rate for the dominant access pattern +3. Replace scalar reductions with warp-shuffle-based tree reductions +4. Pre-compute invariant expressions outside the inner loop +5. 
Split the kernel into two specialized variants for small-N vs large-N cases + +**Step 3: Execute iteratively** + +Apply each item ONE at a time, following the same experiment loop protocol: +- Apply change → verify correctness → benchmark → commit → record → decide keep/revert + +### Guidelines + +- Each item must be self-contained and independently testable +- Prioritize items by expected impact (highest first) +- If an item fails correctness or regresses performance, revert and move to the next +- Document the rationale for each item in the commit message and perf_results.md diff --git a/.claude/skills/improve-cutile-kernel-perf/references/perf-knobs-catalog.md b/.claude/skills/improve-cutile-kernel-perf/references/perf-knobs-catalog.md new file mode 100644 index 0000000..90addb3 --- /dev/null +++ b/.claude/skills/improve-cutile-kernel-perf/references/perf-knobs-catalog.md @@ -0,0 +1,193 @@ +# cuTile Performance Knobs Catalog + +SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: MIT + +Comprehensive reference for all performance tuning parameters available in cuTile kernels. +For API details, see [`references/cutile-api-reference.md`](cutile-api-reference.md). + +--- + +## 1. TMA vs Gather/Scatter + +**The single most impactful choice.** TMA uses hardware-accelerated memory copies (2-78x faster). + +| Feature | TMA (`ct.load/ct.store`) | Gather/Scatter | +|---------|-------------------------|----------------| +| Access pattern | Block-aligned, contiguous tiles | Arbitrary element indices | +| Performance | Hardware-accelerated | Software-computed | +| Padding | `padding_mode=ct.PaddingMode.*` | `padding_value=`, `check_bounds=True`, `mask=` | +| HW limit | ~16K elements per load | No limit | +| Index semantics | Block index (which tile) | Element offset | + +**Rule**: Always TMA-first. Fall back to gather only for truly sparse/random access. + +**Special pattern**: `ct.gather().item()` + `ct.load(allow_tma=True)` for indirect/paged access. + +--- + +## 2. Persistent Scheduling + +**What**: Launch fewer blocks than work items; each block processes multiple items via grid-stride loop. + +| Aspect | Simple Grid | Persistent | +|--------|-------------|------------| +| Grid size | `(n_items,)` | `(NUM_SM * occupancy,)` | +| Kernel pattern | `bid = ct.bid(0)` | `for i in range(bid, n_items, ct.num_blocks(0))` | +| SM utilization | Poor if n_items >> NUM_SM | Optimal | +| Best for | n_items < NUM_SM | n_items > NUM_SM * 2 | + +**Expected gain**: +50-300% for memory-bound ops with many work items. + +--- + +## 3. Occupancy + +**What**: Number of concurrent thread blocks per SM. + +The occupancy hint accepts an integer N from 1 to 32, indicating that the programmer expects N active thread blocks to run simultaneously per SM. This hint is 1 by default and is worth tuning for many SIMT compute-intensive kernels. + +--- + +## 4. num_ctas (Cooperative Thread Arrays) + +**What**: Setting num_ctas=2 is critical for dense dot-related workloads on specific hardware, for example, it enables 2CTA mode MMA on Blackwell architecture. + +--- + +## 5. Tile Sizes +**What**: The tile size parameters (e.g., `TILE_M`, `TILE_N`, `TILE_K`, or similar) determine the size of each program's work assignment—how much of the input/output tensor each thread block processes. Adjusting tile sizes is the primary way to tune data granularity, register/SR memory utilization, and memory transaction efficiency. 
+ +- Larger tile sizes usually increase per-block work, raising register pressure but reducing launch overhead and sometimes improving memory coalescing. +- Smaller tile sizes allow for more blocks in parallel, reducing per-block resource usage but potentially increasing overall launch overhead. + +**Tuning rule**: Always benchmark several plausible tile/block sizes. Optimal values are hardware- and kernel-specific. On Blackwell, try tile shapes covering a range from 16x16 up to 128x128 for 2D problems. + +**Where**: As kernel template parameters, function arguments, or autotune config values: +```python +@ct.kernel +def my_kernel(..., TILE_M: ct.constexpr, TILE_N: ct.constexpr): + ... +``` +or via `ct.tune.exhaustive_search()` to autotune tile sizes: +```python +search_space = { + "TILE_M": [32, 64, 128], + "TILE_N": [32, 64, 128], +} +result = ct.tune.exhaustive_search(search_space, kernel_fn, ...) +``` +**Impact**: This is often the most powerful lever for both performance and resource tuning in cuTile kernels. + +**The most versatile tuning knob.** Determines data per block, register usage, and memory transaction granularity. + +--- + +## 6. Latency Hints + +**What**: Compiler hints for expected DRAM traffic intensity, enabling better prefetch scheduling. + +**Where**: `latency=N` on `ct.load()`, `ct.store()`, `ct.gather()`, `ct.scatter()`. + +| Value | Meaning | Typical Use | +|-------|---------|-------------| +| 1 | Low traffic | gather/scatter with few elements | +| 2-3 | Moderate | Standard loads, stores | +| 6 | Above average | Attention key/value loads | +| 10 | High traffic | Main input tensor loads | + +--- + +## 7. allow_tma on Store + +**What**: `ct.store(..., allow_tma=False)` disables TMA for the store operation. + +**Impact**: +10-30% for some kernels (measured +30% in rms_norm). + +**Why**: The TMA store path has overhead for certain access patterns. Disabling it falls back to a faster non-TMA store. + +**Rule**: Benchmark both `allow_tma=True` (default) and `allow_tma=False`. Keep whichever is faster. + +--- + +## 8. Flush to Zero & Approximate Math + +**What**: Trade precision for speed on math operations. + +| Parameter | Where | Effect | +|-----------|-------|--------| +| `flush_to_zero=True` | `ct.exp2`, `ct.rsqrt`, `ct.truediv`, `ct.sqrt`, `ct.add`, `ct.sub`, `ct.mul` | Skip denormal number handling | +| `rounding_mode=RoundingMode.APPROX` | `ct.truediv`, `ct.tanh` | Use HW approximation | + +**Impact**: +1-5% for math-heavy kernels (softmax, attention). + +**Caution**: May fail tight numerical tolerances. + +--- + +## 9. TF32 Guard for MMA + +**What**: Cast FP32 inputs to TF32 before `ct.mma()` to use tensor cores. + +```python +dtype = ct.tfloat32 if a.dtype == ct.float32 else a.dtype +a = ct.astype(a, dtype) +b = ct.astype(b, dtype) +acc = ct.mma(a, b, acc=acc) # Uses tensor cores instead of FP32 CUDA cores +``` + +**Impact**: ~2x for FP32 MMA operations. + +**Note**: cuTile requires explicit cast to tf32 before `ct.mma()`. + +--- + +## 10. GROUP_SIZE_M (2D Swizzling) + +**What**: Controls how 2D tiles are grouped for L2 cache locality. + +**Impact**: +5-15% for large 2D tiled kernels. + +| GROUP_SIZE_M | When to Try | +|-------------|-------------| +| 4 | Small matrices, few M tiles | +| 8 | Default — good general choice | +| 16 | Large matrices, many M tiles | + +--- + +## 11. Padding Mode + +**What**: How out-of-bounds reads are handled. 
+ +| Mode | Value | Use Case | +|------|-------|----------| +| `ZERO` | 0 | Most ops (default) | +| `NEG_ZERO` | -0 | Signed-zero-sensitive ops | +| `NEG_INF` | -inf | Softmax max reduction | +| `POS_INF` | +inf | Min reduction | +| `NAN` | NaN | Debug: detect unintended OOB | +| `UNDETERMINED` | — | Default (let compiler decide) | + +**Note**: Using `ZERO` explicitly instead of `UNDETERMINED` can avoid unnecessary masking code. + +--- + +## Optimization Priority Summary + +### Memory-bound kernel priority: +1. TMA (2-78x) +2. Persistent scheduling (+50-300%) +3. Autotune (+10-50%) +4. allow_tma=False on store (+10-30%) +5. Tile size tuning (+5-20%) +6. Latency hints (+2-5%) +7. Flush to zero (+1-5%) + +### Compute-bound (MMA) kernel priority: +1. TF32 guard (~2x) +2. Tile size (M/N/K) tuning (+10-50%) +3. Autotune (num_ctas + occupancy) (+10-30%) +4. GROUP_SIZE_M swizzling (+5-15%) +5. Persistent scheduling (+20-100%) +6. Latency hints (+2-5%) diff --git a/.claude/skills/improve-cutile-kernel-perf/references/performance-model.md b/.claude/skills/improve-cutile-kernel-perf/references/performance-model.md new file mode 100644 index 0000000..eaf0c82 --- /dev/null +++ b/.claude/skills/improve-cutile-kernel-perf/references/performance-model.md @@ -0,0 +1,642 @@ + + + + +# GPU Performance Model + +A guide to GPU performance fundamentals for cuTile kernel optimization. + +## Contents +- [The Three Pillars](#the-three-pillars) +- [Arithmetic Intensity](#arithmetic-intensity) +- [Framework Comparison](#framework-comparison) +- [Autotune Examples](#autotune-examples) +- [Common Bottleneck Diagnosis](#common-bottleneck-diagnosis) +- [Profiling Guidance](#profiling-guidance) +- [Benchmark Template](#benchmark-template) +- [Performance Checklist](#performance-checklist) +- [Summary: Optimization Strategy](#summary-optimization-strategy) +- [cuTile Performance Optimization (Advanced)](#cutile-performance-optimization-advanced) + +## The Three Pillars + +Every GPU kernel's performance is governed by: **Memory Bandwidth**, **Compute Throughput**, and **Latency Hiding**. + +**Most ML kernels are memory-bound.** Optimize memory access first, then compute, then latency. + +--- + +## Arithmetic Intensity + +``` +AI = FLOPs / Bytes Transferred +``` + +| AI < 10 = Memory-bound (element-wise, reductions) | AI > 50 = Compute-bound (GEMM, attention) | + +--- + + +## Framework Comparison + +| Aspect | CUDA | cuTile | PyTorch | +|--------|------|--------|---------| +| **Paradigm** | Thread-based | Tile-based | Automatic | +| **Tuning** | Manual | Autotune (occupancy, num_ctas, tile sizes) | Automatic | +| **Tensor Cores** | WMMA API | `ct.mma` | Automatic | +| **Shared Memory** | Explicit | Automatic | Automatic | +| **Profiling** | Nsight | Nsight | PyTorch Profiler | +| **Control** | Maximum | High | Minimal | + +--- + +## Autotune Examples + +### cuTile Autotune + +cuTile uses **autotune** to find optimal occupancy, num_ctas, and tile sizes at runtime. +Do NOT hardcode `occupancy=` in `@ct.kernel()` — instead, let the autotuner search over it. + +```python +@ct.kernel +def optimized_kernel(input, output, n_items: ct.Constant[int], ...): + bid = ct.bid(0) + num_programs = ct.num_blocks(0) + for item_idx in range(bid, n_items, num_programs): + data = ct.load(input, index=(item_idx, 0), ...) + result = compute(data) + ct.store(output, index=(item_idx, 0), tile=result) +``` + +**cuTile Occupancy (via Autotune):** + +Occupancy controls how many thread blocks can run concurrently per SM. 
+The autotuner searches over occupancy values to find the best one: + +| Occupancy Range | Best For | Example Kernels | +|-----------------|----------|-----------------| +| 1-4 | Compute-bound (heavy math) | Complex transforms | +| 4-8 | Balanced (GEMM, TMA) | Matrix multiply | +| 8-16 | Memory-bound (reductions) | Softmax, LayerNorm | +| 16-32 | Very light (copies, casts) | Type conversions | +**Grid Size Calculation (with autotune):** +```python +NUM_SM = torch.cuda.get_device_properties(device).multi_processor_count +# occupancy comes from autotune config, e.g., cfg.occupancy +num_programs = min(NUM_SM * cfg.occupancy, n_items) +grid = (num_programs, 1, 1) +``` + +--- + +## Common Bottleneck Diagnosis + +### Memory-Bound Symptoms + +**Indicators:** +- Low compute utilization (<50%) +- High memory throughput (>80%) +- Nsight shows "Memory Bound" classification + +**Fixes by Framework:** + +| Framework | Solution | +|-----------|----------| +| **CUDA** | Vectorized loads (`float4`), coalesced access, shared memory tiling | +| **cuTile** | `ct.load` for aligned access (compiler uses TMA automatically), `ct.gather`/`ct.scatter` for arbitrary offsets | + +```python +# cuTile: Block-aligned access — compiler will use TMA automatically +data = ct.load(input, index=(bid, 0), shape=(TILE_M, TILE_K)) +``` + +### Compute-Bound Symptoms + +**Indicators:** +- High compute utilization (>80%) +- Low memory throughput +- Nsight shows "Compute Bound" classification + +**Fixes by Framework:** + +| Framework | Solution | +|-----------|----------| +| **CUDA** | Tensor cores (`wmma::mma_sync`), fast math intrinsics, reduced precision | +| **cuTile** | `ct.mma` with proper accumulator, mixed precision | + +```python +# cuTile: Explicit MMA +acc = ct.mma(a_tile, b_tile, acc=acc) # acc= is REQUIRED +``` + +### Latency-Bound Symptoms + +**Indicators:** +- Achieved occupancy <25% +- High register usage per thread +- Many stalls in Nsight + +**Fixes by Framework:** + +| Framework | Solution | +|-----------|----------| +| **CUDA** | `__launch_bounds__`, `--maxrregcount`, smaller tiles | +| **cuTile** | Tune occupancy via autotune, persistent scheduling | + +```python +# CUDA: Limit register usage +__global__ __launch_bounds__(256, 2) // Max threads, min blocks per SM +void kernel(...) { ... } + +# cuTile: Persistent scheduling + autotune occupancy +@ct.kernel +def kernel(...): + for item in range(bid, n_items, num_programs): # Work sharing + ... 
+``` + +--- + +## Profiling Guidance + +### Nsight Compute (All Frameworks) + +```bash +# Full profiling +ncu --set full -o profile_output ./my_app + +# cuTile kernel profiling +ncu --set full python my_cutile_script.py +``` + +**Key Metrics to Check:** + +| Metric | Target | Indicates | +|--------|--------|-----------| +| SM Throughput | >80% | Good compute utilization | +| Memory Throughput | >80% | Good bandwidth utilization | +| Achieved Occupancy | >50% | Adequate latency hiding | +| L1 Hit Rate | >80% | Good cache utilization | + +### cuTile-Specific Profiling + +```python +# Manual timing +torch.cuda.synchronize() +start = time.time() +ct.launch(stream, grid, kernel, args) +torch.cuda.synchronize() +elapsed = time.time() - start +print(f"Kernel time: {elapsed * 1000:.2f} ms") +``` + +**Environment Variables (cuTile framework):** +```bash +CUDA_TILE_LOGS=CUTILEIR # Show compilation IR +CUDA_TILE_ENABLE_CRASH_DUMP=1 # Enable crash dump +``` + +**Environment Variables (TileGym project convention — NOT part of cuTile):** +```bash +DISABLE_CUTILE_TUNE=1 # Disable autotuning (use fixed configs) + # This is an TileGym-specific convention used in tilegym kernels, + # not a cuTile framework feature. +``` + +--- + +## Benchmark Template + +Benchmark cuTile kernel performance: + +```python +import torch +import time + +def benchmark_cutile(fn, x, n_warmup=10, n_rep=100): + """Simple benchmark for cuTile kernels.""" + # Warmup + for _ in range(n_warmup): + fn(x) + torch.cuda.synchronize() + + # Benchmark + times = [] + for _ in range(n_rep): + torch.cuda.synchronize() + start = time.perf_counter() + fn(x) + torch.cuda.synchronize() + elapsed = time.perf_counter() - start + times.append(elapsed * 1000) # ms + + ms = sum(times) / len(times) + + # Calculate bandwidth (read + write) + bytes_transferred = 2 * x.numel() * x.element_size() + bandwidth_gbps = bytes_transferred / ms * 1e-6 + print(f"Kernel time: {ms:.3f} ms, Bandwidth: {bandwidth_gbps:.1f} GB/s") + return ms +``` + +--- + +## Performance Checklist + +When a translated kernel is slower than expected: + +### Priority 1: Algorithmic Issues (10-100x Impact) + +- [ ] Is persistent scheduling used? (cuTile) +- [ ] Is grid size reasonable (NUM_SM * occupancy from autotune)? +- [ ] Is work distribution balanced? +- [ ] Are you using the right memory access pattern (`ct.load` vs `ct.gather`)? + +### Priority 2: Memory Access (2-10x Impact) + +- [ ] Are accesses coalesced? +- [ ] Are block sizes aligned to memory transaction sizes? +- [ ] Is shared memory used effectively? + +### Priority 3: Occupancy (1.2-2x Impact) +- [ ] Is autotune configured with a wide range of occupancy values? +- [ ] Is occupancy appropriate for workload type (see Occupancy Range table)? +- [ ] Are there register spills? + +### Priority 4: Microoptimizations (1.05-1.2x Impact) + +- [ ] Minimize type conversions +- [ ] Hoist invariants out of loops +- [ ] Avoid redundant tensor creations + +--- + +## Summary: Optimization Strategy + +``` +1. PROFILE FIRST + - Identify bottleneck (memory, compute, latency) + - Use Nsight Compute for detailed analysis + +2. OPTIMIZE THE BOTTLENECK + +-- Memory-bound -> Improve access patterns, increase reuse + +-- Compute-bound -> Use tensor cores, reduce precision + +-- Latency-bound -> Increase occupancy, add prefetching + +3. USE CUTILE FEATURES + +-- autotune (occupancy, num_ctas, tile sizes) + persistent scheduling + +4. 
VERIFY CORRECTNESS + - Always check numerical accuracy after optimization + - Use appropriate tolerances (1e-3 for FP32, 1e-2 for FP16) + +5. ITERATE + - Profile again after each optimization + - New bottleneck may emerge +``` + +**Key Takeaways:** +- Most kernels are memory-bound - optimize memory access first +- cuTile's autotune handles many optimizations automatically +- Profile before optimizing - don't guess at bottlenecks +- Use tensor cores (`ct.mma`) whenever possible for matrix operations + +--- + +## cuTile Performance Optimization (Advanced) + +This section covers advanced cuTile-specific optimizations discovered through production kernel development. + +### Static Persistent Scheduling (HIGHEST IMPACT) + +**Problem**: Naive 1:1 block-to-work mapping severely underutilizes GPU. + +**Bad Pattern (Poor GPU Utilization):** +```python +@ct.kernel +def naive_kernel(input, output, ...): + bid = ct.bid(0) # Each block processes ONE work item + + # Process single item + data = ct.load(input, index=(bid, 0), ...) + result = compute(data) + ct.store(output, index=(bid, 0), tile=result) + +# Launch: grid = (n_items, 1, 1) +# Problem: If n_items >> NUM_SM, thousands of blocks sit idle in queue +``` + +**Good Pattern (Static Persistent Scheduling):** +```python +@ct.kernel +def optimized_kernel(input, output, n_items: ct.Constant[int], ...): + bid = ct.bid(0) + num_programs = ct.num_blocks(0) + + # Each block processes MULTIPLE items + for item_idx in range(bid, n_items, num_programs): + data = ct.load(input, index=(item_idx, 0), ...) + result = compute(data) + ct.store(output, index=(item_idx, 0), tile=result) + +# Launch: grid = (NUM_SM * cfg.occupancy, 1, 1) +# Benefit: Fixed number of blocks, each processes ~(n_items / grid_size) items +``` + +**Grid Size Calculation:** +```python +NUM_SM = torch.cuda.get_device_properties(device).multi_processor_count +# occupancy comes from autotune config (cfg.occupancy), NOT hardcoded in @ct.kernel +occupancy = 4 # Example default; in practice, use cfg.occupancy from autotune +num_programs = min(NUM_SM * occupancy, total_work_items) +grid = (num_programs, 1, 1) +``` + +**Expected Performance Gain:** +- Softmax: **+50-300%** (2-4x faster) +- Workloads with n_items > 1000: Typically **+100-200%** +- Best for row-wise/independent operations + +**When to Use:** +- Row-wise operations (softmax, layer_norm, etc.) +- Independent work items (matmul tiles, attention blocks) +- When work_items >> NUM_SM +- NOT when work_items < NUM_SM (just use grid=(work_items,)) + +--- + +### cuTile Autotune Template + +**Step 1: Define Config Generator** + +```python +from types import SimpleNamespace +import torch + +def _my_kernel_autotune_configs(): + """ + Autotune config generator. + + IMPORTANT: Cover a WIDE RANGE of configurations! + - The autotuner will find the best combination + - Don't pre-optimize by narrowing the search space + """ + # Tile sizes: Cover from smallest expected input to largest + tile_sizes = [64, 128, 256, 512, 1024] + + # Occupancy: Range is [1, 32] + occupancies = [1, 2, 4, 8, 16] + + # num_ctas: Valid values are 1, 2, 4, 8, 16 + num_ctas_options = [1, 2, 4] + + # Generate all combinations + for tile in tile_sizes: + for occ in occupancies: + for num_ctas in num_ctas_options: + yield SimpleNamespace( + TILE_SIZE=tile, + num_ctas=num_ctas, + occupancy=occ, + ) +``` + +**Step 2: Autotune Launch Function** + +> **Note:** The recommended autotune API is `ct.tune.exhaustive_search()` (see +> [Modern API](#modern-autotune-api-recommended) below). 
The legacy +> `ct_experimental.autotune_launch()` shown here is **deprecated** but still +> used in existing TileGym kernels. New code should prefer `exhaustive_search`. + +```python +# --- Legacy API (deprecated, still used in TileGym) --- +import cuda.tile_experimental as ct_experimental + +def _my_kernel_autotune_base(stream, input, output, N, C): + """Autotuned kernel launch with dynamic grid and args.""" + NUM_SM = torch.cuda.get_device_properties(input.device).multi_processor_count + + def args_fn(cfg): + tile_size = min(cfg.TILE_SIZE, _next_power_of_2(C)) + return (input, output, tile_size, N) + + def grid_fn(cfg): + num_programs = min(NUM_SM * cfg.occupancy, N) + return (num_programs, 1, 1) + + ct_experimental.autotune_launch( + stream, + grid_fn=grid_fn, + kernel=_my_kernel, + args_fn=args_fn, + hints_fn=lambda cfg: { + "num_ctas": cfg.num_ctas, + "occupancy": cfg.occupancy, + }, + search_space=_my_kernel_autotune_configs, + ) +``` + +#### Modern Autotune API (Recommended) + +`ct.tune.exhaustive_search()` is the replacement for the deprecated +`autotune_launch`. Key differences: +- `search_space` must be a `Sequence` (e.g. `list`), **not** a generator or `Callable`. +- Returns a `TuningResult` with `best_config` / `best_time_us`; does **not** + launch the kernel — you call `ct.launch` yourself with the tuned config. +- No built-in caching; manage your own cache if needed. + +```python +import cuda.tile as ct + +def _my_kernel_autotune_modern(stream, input, output, N, C): + """Autotuned kernel launch using the modern ct.tune API.""" + NUM_SM = torch.cuda.get_device_properties(input.device).multi_processor_count + + # search_space must be a list (Sequence), not a generator + configs = list(_my_kernel_autotune_configs()) + + def args_fn(cfg): + tile_size = min(cfg.TILE_SIZE, _next_power_of_2(C)) + return (input, output, tile_size, N) + + def grid_fn(cfg): + num_programs = min(NUM_SM * cfg.occupancy, N) + return (num_programs, 1, 1) + + result = ct.tune.exhaustive_search( + search_space=configs, + stream=stream, + grid_fn=grid_fn, + kernel=_my_kernel, + args_fn=args_fn, + hints_fn=lambda cfg: { + "num_ctas": cfg.num_ctas, + "occupancy": cfg.occupancy, + }, + ) + + # exhaustive_search does NOT launch — launch manually with best config + best = result.best_config + kernel = _my_kernel.replace_hints( + num_ctas=best.num_ctas, occupancy=best.occupancy + ) + ct.launch(stream, grid_fn(best), kernel, args_fn(best)) +``` + +**Step 3: Conditional Autotune in Forward Pass** + +```python +import os + +class MyOpcuTile(torch.autograd.Function): + @staticmethod + def forward(ctx, x, ...): + enable_autotune = os.environ.get("DISABLE_CUTILE_TUNE", "0") != "1" + + if enable_autotune: + _my_kernel_autotune_base( + torch.cuda.current_stream(), x, output, N, C + ) + else: + # Use fixed default configs + configs = {"TILE_SIZE": 256, "num_ctas": 1, "occupancy": 4} + # ... launch with fixed configs + + return output +``` + +**Autotune Parameter Ranges:** + +| Parameter | Valid Range | Description | +|-----------|-------------|-------------| +| **occupancy** | 1 - 32 | Active warps per SM | +| **num_ctas** | 1, 2, 4, 8, 16 | CTAs to fuse (powers of 2) | +| **TILE_SIZE** | Powers of 2 | Tile dimension size | + +--- + +### `ct.load` vs `ct.gather`/`ct.scatter` Selection + +> **How TMA works in cuTile:** TMA is **not** an explicit API — the cuTile +> compiler decides whether to use TMA hardware automatically when you call +> `ct.load`/`ct.store`. 
The `allow_tma` parameter (default `True`) is the +> only user-facing control. Your job is to choose the right API: +> **`ct.load`** for block-aligned tile access, **`ct.gather`** for arbitrary +> element offsets. + +**CRITICAL RULE**: `ct.load` works with block-aligned tile-space indices. +Use `ct.gather`/`ct.scatter` for arbitrary element offsets. + +**`ct.load` — Block-Aligned Access (compiler may use TMA):** +```python +@ct.kernel +def gemm_kernel(...): + bid_m, bid_n = ct.bid(0), ct.bid(1) + + # Block-aligned tile-space indices — compiler will use TMA when possible + a = ct.load(a_tensor, index=(bid_m, k), shape=(TILE_M, TILE_K)) +``` + +**`ct.load` Fails for Non-Aligned Ragged Access:** +```python +# Segment starts: [0, 5504, 10656, 14424] <- 10656 % 128 = 32 (NOT aligned!) + +@ct.kernel +def ragged_kernel(...): + # m_start = 10656 (not aligned to TILE_M=128) + # ct.load tile-space indexing cannot express arbitrary byte offsets! +``` + +**Solution: Use `ct.gather`/`ct.scatter`:** +```python +@ct.kernel +def ragged_kernel(...): + # Calculate exact element indices + m_indices = m_start + bid_m * TILE_M + ct.arange(TILE_M, dtype=ct.int32) + # m_indices = [10656, 10657, ..., 10783] <- Exact rows needed! + + # Gather supports arbitrary element offsets (padding defaults to 0) + a_tile = ct.gather(a, (m_indices_2d, k_indices_2d)) +``` + +**Decision Tree:** +``` +Is data access pattern block-aligned? +├─ YES -> Use ct.load/ct.store (compiler uses TMA automatically) +│ Example: Regular GEMM, batch operations +│ +└─ NO -> Use ct.gather/ct.scatter (element-level indexing, no TMA) + Examples: Ragged BMM, paged attention, sparse ops + +Special case: Mixed approach +- Use ct.load for aligned dimensions (e.g., B matrix in ragged BMM) +- Use ct.gather/ct.scatter for ragged dimensions (e.g., A, C matrices) +``` + +--- + +### Performance Anti-Patterns + +**Anti-Pattern 1: Excessive Type Conversions** +```python +# BAD: Convert for every row in loop +for row in range(...): + row_fp32 = ct.astype(row, ct.float32) + result = compute(row_fp32) + row_fp16 = ct.astype(result, ct.float16) + +# Better: Keep in fp32 longer, batch conversions +``` + +**Anti-Pattern 2: Redundant Tensor Creation** +```python +# BAD: Create mask inside loop +for i in range(n): + mask = ct.full((tm,), True, dtype=ct.bool_) # Recreated every iteration! + +# GOOD: Create once outside loop +mask = ct.full((tm,), True, dtype=ct.bool_) +for i in range(n): + # Use mask +``` + +**Anti-Pattern 3: Column Loops for Row-Wise Ops** +```python +# BAD: Softmax with column loop +for col_tile in range(num_col_tiles): + partial = ct.load(..., index=(row, col_tile), ...) + # Partial softmax on tile -> WRONG! Need full row + +# GOOD: Load entire row +row = ct.load(..., index=(row, 0), shape=(1, TILE_SIZE_COVERS_ALL_COLS)) +``` + +--- + +### Quick Performance Fix Template + +**Add Persistent Scheduling** (30 seconds): +```python +# In kernel: change from bid to loop +- bid = ct.bid(0) ++ bid = ct.bid(0) ++ num_programs = ct.num_blocks(0) + for work_id in range(bid, total_work, num_programs): + +# In launch: change grid +- grid = (n_items, 1, 1) ++ NUM_SM = torch.cuda.get_device_properties(device).multi_processor_count ++ grid = (NUM_SM * 4, 1, 1) + +# In kernel signature: add total_work +- def kernel(input, output, ...): ++ def kernel(input, output, total_work: ct.Constant[int], ...): +``` + +**Fix Slow Kernel** (2 minutes): +1. Use `@ct.kernel` +2. Add persistent loop +3. Set up autotune with occupancy in search space +4. 
Update grid to use `NUM_SM * cfg.occupancy` +5. Test -> Usually 2-3x faster diff --git a/.github/scripts/check_spdx_headers.py b/.github/scripts/check_spdx_headers.py index 9b1be17..987607e 100755 --- a/.github/scripts/check_spdx_headers.py +++ b/.github/scripts/check_spdx_headers.py @@ -13,7 +13,9 @@ import argparse import os +import re import sys +from datetime import datetime from pathlib import Path from typing import Dict from typing import Iterator @@ -21,10 +23,21 @@ from typing import Optional from typing import Tuple -# SPDX header content -SPDX_COPYRIGHT = "SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved." +# Year constants for copyright validation +MIN_COPYRIGHT_YEAR = 2025 # TileGym project inception year +CURRENT_YEAR = datetime.now().year + +# SPDX header content — uses the current year for newly added headers +SPDX_COPYRIGHT = ( + f"SPDX-FileCopyrightText: Copyright (c) {CURRENT_YEAR} NVIDIA CORPORATION & AFFILIATES. All rights reserved." +) SPDX_LICENSE = "SPDX-License-Identifier: MIT" +# Regex pattern to validate SPDX copyright lines with any valid year or year range +SPDX_COPYRIGHT_PATTERN = re.compile( + r"SPDX-FileCopyrightText: Copyright \(c\) (\d{4})(?:-(\d{4}))? NVIDIA CORPORATION & AFFILIATES\. All rights reserved\." +) + # Comment styles for different file types COMMENT_STYLES: Dict[str, Tuple[str, str, str]] = { @@ -156,10 +169,26 @@ def create_header(prefix: str, middle: str, suffix: str) -> List[str]: def has_spdx_header(content: str) -> bool: - """Check if content already has SPDX headers.""" - # Check for both required strings within the first 10 lines + """Check if content already has SPDX headers. + + Validates that: + - An SPDX copyright line exists in the first 10 lines + - The copyright year (or year range) is between MIN_COPYRIGHT_YEAR and CURRENT_YEAR + - An SPDX license identifier line exists in the first 10 lines + """ first_lines = "\n".join(content.split("\n")[:10]) - return SPDX_COPYRIGHT in first_lines and SPDX_LICENSE in first_lines + if SPDX_LICENSE not in first_lines: + return False + match = SPDX_COPYRIGHT_PATTERN.search(first_lines) + if not match: + return False + start_year = int(match.group(1)) + end_year = int(match.group(2)) if match.group(2) else start_year + return ( + MIN_COPYRIGHT_YEAR <= start_year <= CURRENT_YEAR + and MIN_COPYRIGHT_YEAR <= end_year <= CURRENT_YEAR + and start_year <= end_year + ) def add_header_to_file(file_path: Path, comment_style: Tuple[str, str, str]) -> bool: diff --git a/ROADMAP.md b/ROADMAP.md index a5e2414..b549b92 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -62,10 +62,12 @@ The following table tracks the support status for various models. | LLaMA-3.1-8B | ✅ Available | Tested on B200 | | DeepSeek-V2-Lite-Chat | ✅ Available | Tested on B200 | | Qwen2-7B | ✅ Available | Tested on B200 | +| Qwen3.5-7B | ✅ Available | Tested on B200 | | Gemma-3-4B-IT | ✅ Available | Tested on B200 | | GPT-OSS | ✅ Available | Tested on B200 | | Mistral-7B-Instruct-v0.3 | ✅ Available | Tested on B200 | | Phi-3-mini-4k-instruct | ✅ Available | Tested on B200 | +| OLMo-3-1025-7B | ✅ Available | Tested on B200 | | More LLM models | 🙋 Help Wanted | | ### 1.3 Kernel Library Support diff --git a/modeling/transformers/bench_olmo3.sh b/modeling/transformers/bench_olmo3.sh new file mode 100755 index 0000000..266d63a --- /dev/null +++ b/modeling/transformers/bench_olmo3.sh @@ -0,0 +1,74 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: MIT + +# Benchmark script for OLMo-3-1025-7B model +# Compares PyTorch baseline vs TileGym CUTILE backend + +set -e + +MODEL_ID="allenai/Olmo-3-1025-7B" +INPUT_FILE="sample_inputs/input_prompt_32K.txt" +OUTPUT_LENGTH=50 +LOG_DIR="${LOG_DIR:-/logs}" +SUMMARY_FILE="${LOG_DIR}/olmo3_benchmark_summary.txt" + +echo "========================================" +echo " OLMo-3-1025-7B Performance Benchmark" +echo "========================================" +echo "" +echo "Model: ${MODEL_ID}" +echo "Input: ${INPUT_FILE}" +echo "Output length: ${OUTPUT_LENGTH} tokens" +echo "" + +# Clean previous results +rm -f ${SUMMARY_FILE} + +echo "Running PyTorch baseline..." +python infer.py \ + --model_id ${MODEL_ID} \ + --profile \ + --sentence_file ${INPUT_FILE} \ + --output_length ${OUTPUT_LENGTH} \ + --log_dir ${LOG_DIR} \ + --summary_file ${SUMMARY_FILE} + +echo "" +echo "Running TileGym CUTILE backend..." +python infer.py \ + --model_id ${MODEL_ID} \ + --use_tilegym \ + --use_cutile \ + --use_attn \ + --profile \ + --sentence_file ${INPUT_FILE} \ + --output_length ${OUTPUT_LENGTH} \ + --log_dir ${LOG_DIR} \ + --summary_file ${SUMMARY_FILE} + +echo "" +echo "========================================" +echo " Benchmark Results" +echo "========================================" +if [ -f ${SUMMARY_FILE} ]; then + cat ${SUMMARY_FILE} +else + echo "Summary file not found." +fi +echo "========================================" + +echo "" +echo "========================================" +echo " TileGym Kernel Coverage" +echo "========================================" +python infer.py \ + --model_id ${MODEL_ID} \ + --use_tilegym \ + --use_cutile \ + --use_attn \ + --report_kernel_coverage \ + --sentence_file ${INPUT_FILE} \ + --output_length ${OUTPUT_LENGTH} \ + --log_dir ${LOG_DIR} +echo "========================================" diff --git a/modeling/transformers/infer.py b/modeling/transformers/infer.py index 5278e2f..27095dd 100644 --- a/modeling/transformers/infer.py +++ b/modeling/transformers/infer.py @@ -27,6 +27,7 @@ from tilegym.transformers import apply_tilegym_kernel_to_gpt_oss from tilegym.transformers import apply_tilegym_kernel_to_llama from tilegym.transformers import apply_tilegym_kernel_to_mistral +from tilegym.transformers import apply_tilegym_kernel_to_olmo3 from tilegym.transformers import apply_tilegym_kernel_to_phi3 from tilegym.transformers import apply_tilegym_kernel_to_qwen2 from tilegym.transformers import apply_tilegym_kernel_to_qwen3 @@ -278,6 +279,8 @@ def apply_tilegym_patch(model_id, use_attn=False, use_cutile=False): apply_tilegym_kernel_to_gemma3(rope=True, rms_norm=True, mlp=True, attn=use_attn, use_cutile=use_cutile) elif "phi-3" in model_name or "phi3" in model_name: apply_tilegym_kernel_to_phi3(rope=True, rms_norm=True, swiglu=True, attn=use_attn, use_cutile=use_cutile) + elif "olmo-3" in model_name or "olmo3" in model_name: + apply_tilegym_kernel_to_olmo3(rope=True, rms_norm=True, swiglu=True, attn=use_attn, use_cutile=use_cutile) else: print(f"Warning: Model {model_id} is not supported in tilegym patch. 
No optimizations will be applied.") @@ -324,6 +327,9 @@ def __init__(self): "_silu_and_mul_separate_kernel", "_causal_conv1d_prefill_silu_kernel", "_residual_add_rms_norm_kernel", + # Fused OLMo-3 cuTile kernels + "_rms_norm_residual_add_kernel", + "_dual_rms_norm_kernel", # Reduce kernels "splitk_reduce_kernel", # GEMM kernels diff --git a/src/tilegym/transformers/__init__.py b/src/tilegym/transformers/__init__.py index 1b378bf..0a2d183 100644 --- a/src/tilegym/transformers/__init__.py +++ b/src/tilegym/transformers/__init__.py @@ -7,6 +7,7 @@ from tilegym.transformers.monkey_patch import apply_tilegym_kernel_to_gpt_oss from tilegym.transformers.monkey_patch import apply_tilegym_kernel_to_llama from tilegym.transformers.monkey_patch import apply_tilegym_kernel_to_mistral +from tilegym.transformers.monkey_patch import apply_tilegym_kernel_to_olmo3 from tilegym.transformers.monkey_patch import apply_tilegym_kernel_to_phi3 from tilegym.transformers.monkey_patch import apply_tilegym_kernel_to_qwen2 from tilegym.transformers.monkey_patch import apply_tilegym_kernel_to_qwen3 diff --git a/src/tilegym/transformers/monkey_patch.py b/src/tilegym/transformers/monkey_patch.py index bf19db0..0a61caa 100644 --- a/src/tilegym/transformers/monkey_patch.py +++ b/src/tilegym/transformers/monkey_patch.py @@ -434,6 +434,66 @@ def apply_tilegym_kernel_to_phi3( ALL_ATTENTION_FUNCTIONS["sdpa"] = get_fmha_phi3_interface() +def apply_tilegym_kernel_to_olmo3( + rope: bool = True, + rms_norm: bool = True, + swiglu: bool = True, + attn: bool = True, + model: PreTrainedModel = None, + use_cutile: bool = False, +) -> None: + """ + Apply TileGym kernels to replace original implementation in HuggingFace OLMo-3 models. + + OLMo-3 uses a Llama-like architecture with: + - Post-normalization (RMSNorm after residual addition) + - Q/K normalization (RMSNorm on projected Q and K before attention) + - SwiGLU MLP (separate gate/up/down projections with SiLU) + - YARN RoPE for extended context + - Mixed sliding window + full attention (3:1 pattern) + + Args: + rope (bool): Whether to apply TileGym's rotary position embedding. Default is True. + rms_norm (bool): Whether to apply TileGym's RMSNorm. Default is True. + swiglu (bool): Whether to apply TileGym's SwiGLU MLP. Default is True. + attn (bool): Whether to apply TileGym's attention. Default is True. + model (PreTrainedModel): The model instance to apply TileGym kernels to, if the model has already been + loaded. Default is None. + use_cutile (bool): Whether to apply using cutile. Default is False. 
+ """ + logger.info("--------------------------------") + logger.info("apply_tilegym_kernel_to_olmo3") + logger.info("--------------------------------") + from transformers.models.olmo3 import modeling_olmo3 + + if use_cutile: + set_backend("cutile") + + if rope: + modeling_olmo3.apply_rotary_pos_emb = get_apply_rope_func(model="llama") + if rms_norm: + modeling_olmo3.Olmo3RMSNorm = get_rms_norm_module() + if swiglu: + modeling_olmo3.Olmo3MLP = get_fused_swiglu_module() + if attn: + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + ALL_ATTENTION_FUNCTIONS["sdpa"] = get_fmha_interface() + + if use_cutile: + from tilegym.transformers.olmo3.modeling_olmo3 import FusedOlmo3MLP + from tilegym.transformers.olmo3.modeling_olmo3 import _attention_forward_tilegym + from tilegym.transformers.olmo3.modeling_olmo3 import _decoder_layer_forward_tilegym + + if swiglu: + modeling_olmo3.Olmo3MLP = FusedOlmo3MLP + logger.info("Replaced Olmo3MLP with FusedOlmo3MLP (linear_gluact_linear)") + modeling_olmo3.Olmo3Attention.forward = _attention_forward_tilegym + logger.info("Patched Olmo3Attention.forward with fused dual Q/K RMSNorm") + modeling_olmo3.Olmo3DecoderLayer.forward = _decoder_layer_forward_tilegym + logger.info("Patched Olmo3DecoderLayer.forward with fused residual_add+RMSNorm") + + MODEL_TYPE_TO_APPLY_TILEGYM_FN = { "llama": apply_tilegym_kernel_to_llama, "deepseek_v2": apply_tilegym_kernel_to_deepseek_v2, @@ -443,6 +503,7 @@ def apply_tilegym_kernel_to_phi3( "qwen3_5": apply_tilegym_kernel_to_qwen3, "gemma3": apply_tilegym_kernel_to_gemma3, "phi3": apply_tilegym_kernel_to_phi3, + "olmo3": apply_tilegym_kernel_to_olmo3, } diff --git a/src/tilegym/transformers/olmo3/__init__.py b/src/tilegym/transformers/olmo3/__init__.py new file mode 100644 index 0000000..8476bc5 --- /dev/null +++ b/src/tilegym/transformers/olmo3/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT diff --git a/src/tilegym/transformers/olmo3/modeling_olmo3.py b/src/tilegym/transformers/olmo3/modeling_olmo3.py new file mode 100644 index 0000000..537ecba --- /dev/null +++ b/src/tilegym/transformers/olmo3/modeling_olmo3.py @@ -0,0 +1,258 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +"""OLMo-3-specific cuTile kernel wrappers and patched decoder layer forward.""" + +import cuda.tile as ct +import torch +import torch.nn as nn + +ConstInt = ct.Constant[int] + + +@ct.kernel +def _dual_rms_norm_kernel( + q, # (N, D) — projected Q, normalized in-place + k, # (N, D) — projected K, normalized in-place + q_weight, # (D,) + k_weight, # (D,) + eps: float, + D: ConstInt, + TILE_D: ConstInt, +): + """Fused in-place RMSNorm for Q and K in a single kernel launch. + + Grid: (N,). Each block normalizes both q[bid] and k[bid] in-place. + No race conditions since each block owns a unique row pair. 
+ """ + PAD = ct.PaddingMode.ZERO + bid = ct.bid(0) + + # ---- Normalize Q row ---- + q_h = ct.load(q, index=(bid, 0), shape=(1, TILE_D), padding_mode=PAD).reshape((TILE_D,)).astype(ct.float32) + q_w = ct.load(q_weight, index=(0,), shape=(TILE_D,), padding_mode=PAD).astype(ct.float32) + q_var = ct.sum(q_h * q_h) * ct.truediv(1.0, D) + q_normed = q_h * ct.rsqrt(q_var + eps) * q_w + ct.store(q, index=(bid, 0), tile=q_normed.reshape((1, TILE_D)).astype(q.dtype)) + + # ---- Normalize K row ---- + k_h = ct.load(k, index=(bid, 0), shape=(1, TILE_D), padding_mode=PAD).reshape((TILE_D,)).astype(ct.float32) + k_w = ct.load(k_weight, index=(0,), shape=(TILE_D,), padding_mode=PAD).astype(ct.float32) + k_var = ct.sum(k_h * k_h) * ct.truediv(1.0, D) + k_normed = k_h * ct.rsqrt(k_var + eps) * k_w + ct.store(k, index=(bid, 0), tile=k_normed.reshape((1, TILE_D)).astype(k.dtype)) + + +def dual_rms_norm_cutile( + q: torch.Tensor, + k: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + eps: float, +) -> tuple[torch.Tensor, torch.Tensor]: + """Fused in-place RMSNorm for Q and K in a single kernel launch.""" + D = q.shape[-1] + q_flat = q.contiguous().view(-1, D) + k_flat = k.contiguous().view(-1, D) + N = q_flat.shape[0] + TILE_D = 1 << (D - 1).bit_length() + ct.launch( + torch.cuda.current_stream(), + (N,), + _dual_rms_norm_kernel, + (q_flat, k_flat, q_weight, k_weight, eps, D, TILE_D), + ) + return q, k + + +@ct.kernel +def _rms_norm_residual_add_kernel( + x, # (N, D) — branch output (attn or mlp) + residual, # (N, D) — residual from before the branch + weight, # (D,) — RMSNorm weight + out, # (N, D) — residual + rms_norm(x) + eps: float, + D: ConstInt, + TILE_D: ConstInt, +): + """Fused RMSNorm + residual add for OLMo-3 post-normalization. + + Computes: out = residual + weight * x * rsqrt(mean(x^2) + eps) + + OLMo-3 uses post-norm (norm on branch output, then add residual), + so this fuses what would be two separate ops into one kernel. + """ + bid = ct.bid(0) + offs = ct.arange(TILE_D, dtype=ct.int32) + + # Load in float32 for numerical stability + h = ct.astype(ct.gather(x, (bid, offs), padding_value=0.0, check_bounds=True), ct.float32) + r = ct.astype(ct.gather(residual, (bid, offs), padding_value=0.0, check_bounds=True), ct.float32) + w = ct.astype(ct.gather(weight, (offs,), padding_value=0.0, check_bounds=True), ct.float32) + + # RMSNorm: weight * x * rsqrt(mean(x^2) + eps) + variance = ct.sum(h * h) * ct.truediv(1.0, D) + normed = h * ct.rsqrt(variance + eps) * w + + # Residual add + result = r + normed + + ct.scatter(out, (bid, offs), ct.astype(result, out.dtype), check_bounds=True) + + +def rms_norm_residual_add_cutile( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float, +) -> torch.Tensor: + """Fused RMSNorm + residual add. Returns residual + rms_norm(x).""" + D = x.shape[-1] + x_flat = x.contiguous().view(-1, D) + r_flat = residual.contiguous().view(-1, D) + N = x_flat.shape[0] + out = torch.empty_like(x_flat) + TILE_D = 1 << (D - 1).bit_length() + ct.launch( + torch.cuda.current_stream(), + (N,), + _rms_norm_residual_add_kernel, + (x_flat, r_flat, weight, out, eps, D, TILE_D), + ) + return out.view(x.shape) + + +class FusedOlmo3MLP(nn.Module): + """Fully fused SwiGLU MLP using linear_gluact_linear (single kernel). + + Replaces PartiallyFusedSwiGLUMLP's 3-kernel pattern (matmul + silu_and_mul + matmul) + with a single fused kernel: silu(x @ W_gate^T) * (x @ W_up^T) @ W_down^T. 
+ """ + + def __init__(self, config, hidden_size=None, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + + def forward(self, x): + from tilegym.ops import linear_gluact_linear + + return linear_gluact_linear( + input=x, + weight_act=self.gate_proj.weight, + weight_noact=self.up_proj.weight, + weight2=self.down_proj.weight, + act_type="silu", + ) + + +def _attention_forward_tilegym( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple, + attention_mask=None, + past_key_values=None, + cache_position=None, + **kwargs, +): + """Patched Olmo3Attention.forward with fused dual Q/K RMSNorm.""" + from transformers.models.olmo3.modeling_olmo3 import ALL_ATTENTION_FUNCTIONS + from transformers.models.olmo3.modeling_olmo3 import apply_rotary_pos_emb + from transformers.models.olmo3.modeling_olmo3 import eager_attention_forward + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Fused dual Q/K RMSNorm (single kernel instead of two) + q_norm_eps = getattr(self.q_norm, "variance_epsilon", getattr(self.q_norm, "eps", 1e-6)) + query_states, key_states = dual_rms_norm_cutile( + query_states, + key_states, + self.q_norm.weight, + self.k_norm.weight, + q_norm_eps, + ) + + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +def _decoder_layer_forward_tilegym( + self, + hidden_states: torch.Tensor, + attention_mask=None, + position_ids=None, + past_key_values=None, + use_cache=None, + cache_position=None, + position_embeddings=None, + **kwargs, +) -> torch.Tensor: + """Patched Olmo3DecoderLayer.forward with fused RMSNorm + residual add.""" + from transformers.models.olmo3.modeling_olmo3 import apply_rotary_pos_emb + + # ---- Self-attention ---- + residual = hidden_states + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + 
cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + # Fused: post_attention_layernorm(hidden_states) + residual + norm = self.post_attention_layernorm + eps = getattr(norm, "variance_epsilon", getattr(norm, "eps", 1e-6)) + hidden_states = rms_norm_residual_add_cutile(hidden_states, residual, norm.weight, eps) + + # ---- MLP ---- + residual = hidden_states + hidden_states = self.mlp(hidden_states) + + # Fused: post_feedforward_layernorm(hidden_states) + residual + norm = self.post_feedforward_layernorm + eps = getattr(norm, "variance_epsilon", getattr(norm, "eps", 1e-6)) + hidden_states = rms_norm_residual_add_cutile(hidden_states, residual, norm.weight, eps) + + return hidden_states