
tilelang Path C: review-driven fix-wave + MLX wiring + 5 review/fix loop iterations#1

Closed
apstenku123 wants to merge 30 commits into main from mlx-z3-wiring

Conversation

@apstenku123
Collaborator

Summary

Brings mlx-z3-wiring into main after 5 waves of multi-provider AI review→fix loops (Grok-4 ext + Meta + GPT-5.5-pro extended) plus parallel MLX-wiring PassConfig opt-ins.

29 commits ahead of main: initial 12-fix wave (ebbb84b) + 27 follow-up fixes/migrations on top.

Key changes

Initial fix-wave (ebbb84b)

12 fixes across _msl_transform.py, fp8_vecmat_path_c.py, kernel docstrings, routing doc, tests, bench harness, _experimental.py. Addressed Grok/Meta/GPT findings: intrinsic registration P0 (kPure→kReadState for position-reads, fixed kOpaque misnaming), shape contract (1,N) macro vs (N,) fallback, routing-doc honesty (PROBE-ONLY/REDUCERS-ONLY/BROKEN labels), Path C↔B parity tests, schema gate v2, env hygiene autouse fixture, snapshot rot archive.

MLX-wiring (Z3-roadmap PassConfig opt-ins)

  • 9a668ea tl.dot4_legality opt-in in fp8_vecmat_path_c
  • 9202646 tl.intra_warp_barrier_elision opt-in in sparse_mla + topk_selector
  • 9242368 thread-safety lock on _PASS_CONFIGS_CACHE
  • 60f0005 K>0 guard + narrow exception in fp8_vecmat_path_c

Wave 1-5 review→fix loop (grok-4 extended, correctness+performance)

| Wave | Commit | Findings |
|---|---|---|
| 1 | fix 2b9310e (DSA) + integrated | warning logs on silent None, apply_simplify warn, intrinsic registration cached, vectorized_loads probe documented, _grid_for_lowering delegated to canonical helper |
| 2 | fix integrated + DSA wave-2 commits | idx_scores_f -INF priming (P0), Stage-2 partial Q hoist, Metal block AH-aware, fp8_vecmat output_shape canonicalization, lock unification, vectorized_loads cache key, scale-resolve fast path |
| 3 | fix ab17d21 | TOCTOU+lazy libz3 preload, narrow exceptions in tl_lower path, dispatch input/buffer count validation, _as_metal_target lru_cache, lowering result cache, bench env-prep idempotency |
| 4 | fix c66ce56 | robust _parse_buffer_param_names (split/strip/extract helpers), dispatch None-kernel guard, recursive _freeze_for_hash, bounded LRU _LOWERING_CACHE (CPPMEGA_LOWERING_CACHE_SIZE) |
| 5 | fix 741fc74 | multi-param test corpus + validation, reset_libz3_preload_state(), stable _freeze_for_hash (no repr()), dispatch None-guard with reason |

Migration phase 1+2 (3bccf84/e3340b7/f1ca48b)

Unified tilelang.engine.lower path for fp8_amax + dsa_splitk (gated by CPPMEGA_MLX_TILELANG_ENGINE), deprecate _msl_transform shim, topk_selector engine path via dispatch_lower.

Misc

  • Round 2-7 fixes from grok review feedback (libz3 hardening security: drop /tmp from prod candidates, gate behind env, distinguish OSError vs FileNotFoundError, retry caps).
  • DSA splitK + FP8 amax integrations + their wave-2/3 self-audits.
  • 17 review reports preserved as audit trail under reports/grok/, reports/chatgpt/, reports/meta/.

False-positive findings rejected (per memory rule "verify before relaying")

  • metal_grid_for_lowering multiplication: grok flagged 3 times as P0; verified correct each time (TileLang grid = threadgroup extents, MLX wants total threads).
  • SCHEMA_VERSION "missing from __all__": verified present at bench script :58.
  • MTL_DEBUG_LAYER/etc. "not covered by conftest": verified covered by MTL_ prefix.
  • _parse_buffer_param_names "hot-path cost": verified buffer_param_names already cached on TileLangMSLLowering:446, dispatch reads cached at :321.

Test plan

  • pytest tests/test_tilelang_msl_transform.py — should be 13 passed / 5 xfailed (xfails strict, will flip green when transforms land upstream)
  • pytest tests/test_tilelang_path_c_vs_b_parity.py — Path C↔B numeric parity (xfail-gated where Path C apply not yet exposed)
  • pytest tests/test_tilelang_fp8_vecmat_path_c.py::test_fp8_e4m3_dot4_intrinsic_is_registered — intrinsic registration trip-wire
  • Bench: python scripts/bench_tilelang_fp8_path_c.py --shapes vecmat_4096 --iters 50 — confirm no regression vs Path B
  • Manual: CPPMEGA_VALIDATE_PARSE_TESTS=1 python -c "from cppmega_mlx.nn._tilelang import _msl_transform" — multi-param signature parser self-test

apstenku123 added 30 commits May 7, 2026 03:20
…ecmat)

Two MLX threads lowering the Path C TileLang kernels concurrently could race
the first-time populate of the per-module ``_PASS_CONFIGS_CACHE`` dicts in
``topk_selector`` and ``fp8_vecmat_path_c``. The probe loop runs
``tvm.transform.PassContext`` per candidate key and is not safe to enter
twice. Guard the read-or-populate region with a per-module
``threading.Lock()`` (separate locks so the kernels don't serialize each
other's first build).
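A minimal sketch of the guarded read-or-populate region described above. The names `pass_configs_for` and `probe` are hypothetical; the real probe loop enters ``tvm.transform.PassContext``, which this sketch replaces with an injected callable.

```python
import threading

# One cache and one lock per module, so unrelated kernels do not
# serialize each other's first build.
_PASS_CONFIGS_CACHE = {}
_PASS_CONFIGS_LOCK = threading.Lock()


def pass_configs_for(key, probe):
    """Return the cached pass configs for `key`, probing at most once.

    `probe` is a hypothetical stand-in for the real probe loop.
    """
    with _PASS_CONFIGS_LOCK:
        if key not in _PASS_CONFIGS_CACHE:
            # Only one thread ever runs the probe for a given key.
            _PASS_CONFIGS_CACHE[key] = probe(key)
        return _PASS_CONFIGS_CACHE[key]
```

Holding the lock across the probe keeps the sketch simple; it trades a little first-build latency for the guarantee that the probe is never entered twice.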
- Fuse two T.Parallel loops in amax kernel: load+abs+cast in one pass.
  Eliminates the missing shared-memory barrier (data race between
  X_shared write and read) and the redundant global->shared->fragment
  hop flagged by the perf review (~30-60% amax speedup).
- Drop globals().update() side-effect in both make_fp8_amax_kernel /
  make_fp8_quantize_kernel; the PrimFunc only ever read local
  N/BLOCK/FP8_MAX, so the module-level mutation was dead code.
- Add empty-tensor fast paths to fp8_quantize_tilelang and
  fp8_pack_tilelang (clamp=True path) — n_elements>0 was implicit.
- Honor the public out= contract in fp8_quantize_tilelang: when the
  caller-supplied out is non-contiguous, write to a contiguous scratch
  and copy_back so the returned tensor is always the user's buffer.
Apply concrete fixes from the grok-4 review of integration #06 (DSA
splitK indexer-loss kernel port to TileLang):

* Remove the globals() mutation hack in both stage-1 and stage-2 builders
  -- BLOCK_*/threads were already passed explicitly through the
  lru_cache-keyed entry points, so the global side-effect was redundant
  and risked stale values across concurrent JIT compilation.
* Guard IndexMask / IndexScores reads against OOB on boundary tiles
  (ASq % BLOCK_SQ != 0 or Sk % BLOCK_SK != 0) -- the original Triton
  kernels use tl.load(..., mask=...) for the equivalent guard but the
  TileLang port did unconditional loads, which would read garbage on
  any non-aligned shape.
* Tighten Metal block overrides from 64x64x32 to 32x32x16. The earlier
  sizing only accounted for fp16 shared Q/K (16 KB) and missed the
  fp32 register fragments: stage 2 alone needed 64 KB of fragments,
  causing register spilling on M-series (32 KB threadgroup budget).
  32x32x16 keeps total footprint well under the limit.
* Replace the AB*ASq*Sk*4-byte zeros() allocation in the non-sparse
  path with empty(); the kernel never reads the tensor when SPARSE is
  False (constexpr-eliminated, plus the new bounds-guard ensures even
  boundary tiles skip the load).
* Skip redundant .contiguous() / .to(fp32) copies when tensors are
  already in the canonical layout, avoiding device-to-device copies
  on every forward pass.
…amic BLOCK_SIZE, shim type consistency

Applies the four wave-1 deferred grok review items on fp8_amax.py:

1. BLOCK_SIZE-vs-threads correctness: builders now refuse to compile when
   block_size % threads != 0 with a precise diagnostic; the strided
   T.Parallel(BLOCK) inner loop requires a clean stride or the last block
   leaves a partial tail uncovered.

2. JIT-cache thrashing: amax dispatch buckets by next-pow2 of n_elements
   (via _bucket_n), so 5 close call shapes (4097..8192) now compile a
   single shared kernel. Tail is zero-padded; amax reduces absolute values,
   so 0 is the identity for the max and the result is unchanged. lru_cache size bumped to 256 for
   both amax and quantize. Quantize stays per-shape because output sizing
   matters per-element. Added precompile_amax_kernel() warm-up API.

3. Dynamic BLOCK_SIZE picker: _pick_block_size(target, n) replaces the
   hard-coded 1024/128 with a per-target table -- cuda(1024,128),
   hip(1024,256), metal(256,64) -- documented at the top of the file.
   Tiny inputs shrink block to next pow2 >= threads.

4. tilelang_supports diagnostics: added tilelang_supports_with_reason()
   returning (bool, str) on every path; tilelang_supports() is now a
   thin wrapper that drops the reason for backward compat with
   cppmega/megatron/fp8_activations.py callers.
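Items 1-3 can be sketched roughly as follows. The per-target table values come from the commit message; the function names and the exact shrink rule for tiny inputs are assumptions made for illustration, not the code in fp8_amax.py.

```python
# (block_size, threads) per target, as listed in the commit message.
_BLOCK_TABLE = {"cuda": (1024, 128), "hip": (1024, 256), "metal": (256, 64)}


def next_pow2(n):
    """Smallest power of two >= n (1 for n <= 1)."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def bucket_n(n_elements):
    """Bucket a call shape so nearby sizes share one compiled kernel."""
    return next_pow2(n_elements)


def pick_block_size(target, n_elements):
    """Pick (block_size, threads); tiny inputs shrink the block.

    Assumption: "shrink" means the next power of two covering the input,
    but never below the thread count.
    """
    block, threads = _BLOCK_TABLE[target]
    if n_elements < block:
        block = max(threads, next_pow2(n_elements))
    if block % threads != 0:
        # Item 1: a partial stride would leave a tail uncovered.
        raise ValueError(f"block_size {block} is not a multiple of threads {threads}")
    return block, threads
```

With this bucketing, every size from 4097 through 8192 maps to the same bucket of 8192.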
- Stage 1: hoist Q load to Q_full=(BLOCK_SQ, AD) shared once per (sq_block, h);
  inner d_tile copies from Q_full instead of re-reading Q from HBM per sk_tile.
  Saves SK_TILES-1 redundant HBM reloads of Q.
- Stage 2: pre-load M[b, h, sq] / D[b, h, sq] into M_pre/D_pre fragments of
  shape (AH, BLOCK_SQ) once per sq_block; inner sk_tile->h loop reads from
  registers instead of HBM. Stage-2 Q hoist is harder (h iterates inside
  sk_tile due to softmax_attn accumulator) — left as a TODO with concrete
  path forward.
- Both stages: trim sk_tile loop bound to
  min(SK_TILES, min(sq_block_id*BLOCK_SQ + BLOCK_SQ - 1, ASq-1) // BLOCK_SK + 1)
  so causal-mask-zero tiles are skipped entirely. On the last Q-block the
  trim is a no-op; for Q-block 0 with ASq>=BLOCK_SQ it cuts to a single
  sk_tile.
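A small host-side sketch of the trimmed loop bound, assuming the inner clamp takes the smaller of the Q-block's last row and ASq-1 (the reading under which Q-block 0 collapses to one tile and the last Q-block is untouched). The function name is hypothetical.

```python
def active_sk_tiles(sq_block_id, block_sq, block_sk, asq, sk_tiles):
    """Number of sk tiles a Q-block must visit under a causal mask."""
    # Last query row covered by this block, clamped to the sequence end.
    last_row = min(sq_block_id * block_sq + block_sq - 1, asq - 1)
    # Tiles strictly past `last_row` are fully masked and can be skipped.
    return min(sk_tiles, last_row // block_sk + 1)
```

For BLOCK_SQ = BLOCK_SK = 32, ASq = 128 and SK_TILES = 4, block 0 visits one tile and block 3 visits all four. With ASq = 0, Python's floor division yields zero tiles.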

Followups still tracked:
- Stage 2 Q hoist (TODO in code).
Wave-3 self-audit (wave-2 grok review came back empty after retries).
Highest-impact item addressed in cppmega.mlx:

- fp8_pack_tilelang now raises FloatingPointError when the host-side
  amax_val is NaN/Inf instead of falling through to inv_scale = fp8_max
  / amax which would silently produce 0/NaN scales and poison every FP8
  output element. Caller gets a clear diagnostic instead of garbage
  weights downstream.
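A minimal sketch of the guard, assuming a host-side Python float for amax_val. The helper name and the default `fp8_max=448.0` (the E4M3 maximum) are illustrative assumptions.

```python
import math


def inv_scale_from_amax(amax_val, fp8_max=448.0):
    """Compute fp8_max / amax, refusing non-finite amax values."""
    if not math.isfinite(amax_val):
        # Fail loudly instead of propagating a 0 or NaN scale.
        raise FloatingPointError(f"amax is not finite: {amax_val!r}")
    # amax_val == 0 is not covered by the commit message; this sketch
    # leaves that case to the caller.
    return fp8_max / amax_val
```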
…validation

Wave-2 grok re-review came back empty for this packet (no response x 5
aspects); this self-audit pass picks 3 high-impact items off the wave-1
review and adds them on top of wave-2 commit 2b9310e.

1. Clamp _active_sk_tiles to >= 1 in both stages. When ASq=0 or
   sq_block_id*BLOCK_SQ exceeds ASq (last Q-block in non-divisible shapes,
   pathological zero-rows), _max_useful_sk goes negative and the trim
   collapses to 0, skipping the loop and leaving accumulators uninitialised.
   The clamp guarantees the loop body runs once; the per-position guards
   already produce no writes for OOB sq_idx, so output remains correct.

2. Wrapper validation for topk_indices: enforce shape (AB, ASq, *),
   dtype int32|int64 with int32→int64 promotion (PyTorch scatter requires
   int64), device parity, and contiguity. Surfaces these mismatches at the
   wrapper boundary instead of an unhelpful "expected scalar type Long"
   crash deep in the C++ scatter implementation.

3. Note: the partial Q hoist in stage 2 (wave-2 TODO) is already present
   in the committed file at line 689 -- per-(sk_tile, h) full-AD slab into
   Q_full shared, d_tile reads from shared. Full out-of-sk_tile hoist in
   stage 2 still TODO (would need accumulator restructure).

Wave-2 attribution note: the wave-2 commit 2b9310e accidentally swept in
~179 lines of unrelated fp8_vecmat_path_c.py (belongs to integration-05).
That sliver remains in 2b9310e history -- not addressed here.
Wave-1b meta findings:

- HIGH perf (M_pre/D_pre register spill): the per-(sq_block) prefetch
  fragments are AH*BLOCK_SQ fp32 each, totalling 8*AH*BLOCK_SQ bytes
  per thread block. At AH=128, BLOCK_SQ=128 (CUDA worst case) that
  reaches 128 KB combined — far above per-block register budgets and
  causes spill-to-local-mem (-30/-40% on 2k seqlen). Gate the
  prefetch behind a 32 KB combined budget; over-budget compiles fall
  back to per-(sk_tile, h) HBM reads (the original pre-Wave-2 path).

- MED sec (scatter_ bounds + NaN guard): topk_indices were fed into
  scatter_ without an upper-bound check. PyTorch's CUDA scatter_ does
  NOT validate upper bound in release builds, so an OOB index would
  silently corrupt adjacent memory. Validate range [0, Sk) up-front,
  raise ValueError on violation. Additionally, when a row has zero
  in-range indices, every IndexMask slot stays -inf and the
  downstream softmax produces NaN that poisons the loss. Detect
  fully-masked rows and patch slot 0 to a safe sentinel.
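The budget arithmetic in the first finding can be checked with a short sketch. The helper names are hypothetical, and treating the 32 KB limit as inclusive is an assumption.

```python
FP32_BYTES = 4
PREFETCH_BUDGET_BYTES = 32 * 1024


def prefetch_bytes(ah, block_sq):
    """Combined size of the M_pre and D_pre fp32 fragments."""
    # Two fragments of shape (AH, BLOCK_SQ), 4 bytes per element.
    return 2 * FP32_BYTES * ah * block_sq


def use_prefetch(ah, block_sq, budget=PREFETCH_BUDGET_BYTES):
    """True when the prefetch fragments fit the combined budget."""
    return prefetch_bytes(ah, block_sq) <= budget
```

At AH=128 and BLOCK_SQ=128 this gives 131072 bytes (128 KB), so the gate falls back to per-(sk_tile, h) reads.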

Verified citations (cite line numbers were off but findings real):
- M_pre/D_pre fragment alloc actually at lines 638-639 (cite said
  575-576).
- scatter_ actually in dsa_splitk_indexer_loss.py:1065 (cite said
  topk_*.py).

NOTE: Path C agent (a1f7b2a8) is concurrently touching topk_selector
/ sparse_mla / fp8_vecmat — this fix is scoped to
dsa_splitk_indexer_loss.py to avoid conflict.
…tal)

- _msl_transform.py: preload libz3.dylib on Darwin so dev-build
  libtilelang.dylib's bare-basename dlopen succeeds (otherwise every
  Path C kernel silently no-ops via its dispatch try/except).
- sparse_mla_path_c.py + topk_selector.py: replace -T.infinity('float32')
  with T.float32(-1.0e38) sentinel; Metal MSL codegen emits inf as
  HUGE_VALF/INFINITY which trips the canonicalizer/algebraic rewrites.
- bench/tilelang_ports/topk_selector.json: refreshed receipt under the
  fixed dispatch path.
…base hoists

z3-final's CSE/algebraic-simplify passes hoist address-base computations
(int h, b, g, q_row_base, kv_b_base, idx_base, out_row, d_out_row,
dkv_partial_base) to the top of the kernel body — *above* the
'float sm_scale = sm_scale_buf[0];' marker that the canonicalize passes
key off when injecting 'int gid = int(blockIdx.x);'. The hoisted
expressions then reference an undeclared 'gid' and collide (with
conflicting int vs uint types) with the unsigned address bases that
_canonicalize_fwd/bwd_base_indexing re-emit after the marker.

This patch:
- Adds _strip_z3_hoisted_address_decls() and runs it at the top of
  _canonicalize_fwd_lane_indexing and _canonicalize_bwd_lane_indexing
  to delete the hoisted decls before injection. The canonicalization
  passes already re-emit equivalent uint versions further down.
- Drops tautological lane-loop aliases ('int k = k;', 'int kd = ...',
  'int d_N = d;', 'int kd = ((kd - tid) + tid);') that the lane-loop
  rewrite leaves behind once the inner-body index expressions collapse.
- Promotes bare 'gather_idx = indices[...]' assignments to declarations
  in the fwd output-projection / bwd dq/output paths so the symbol is
  in scope after we strip the original 'thread int gather_idx[1]'.
- Normalizes the '-1.000000e+38f' sentinel back to '-INFINITY' at the
  very top of _postprocess_lowered_msl. The TIR uses the finite
  sentinel because '-T.infinity' tripped z3-final's canonicalizer; the
  rest of the canonicalization regexes (and the all-masked fast-return)
  match against the historical '-INFINITY' token.

Result: sparse_mla_path_c forward/backward parity + dispatch tests go
from 23 failed / 23 passed to 9 failed / 37 passed (14 recovered, 0
regressions). The 9 remaining failures are pre-existing dump_lowered_msl
text-inspection tests that assert kernel-signature strings the dump
helper doesn't emit; unrelated to this fix.
After the hoist-aware canonicalize fix in 7555124 the Path C kernel
dispatches and compiles for fwd+bwd across all bench shapes. Numbers
captured at warmup=3, iters=10:

  shape                fwd_b    fwd_c   fwd_ratio   bwd_b    bwd_c   bwd_ratio
  B2_S128_H8_D64       0.209    0.200      0.96     0.297    0.355      1.20
  B4_S512_H8_D64       0.332    0.349      1.05     0.959    0.939      0.98
  B4_S1024_H8_D64      0.624    0.646      1.04     3.136    3.172      1.01

Forward Path C is ~4% faster than Path B on the smallest shape (0.96x) and
within 5% on the larger shapes (1.05x / 1.04x); backward Path C is at parity
(0.98x / 1.01x) on the larger shapes and ~20% slower on the small shape (where
dispatch overhead dominates the kernel time).
…ion)

The Path C topk_selector local-heap insertion sort carried an
``elif (K<=8) or (K>=64): break`` heuristic that was a tuning artifact:
it *skipped* the early break for K in {16, 32}, so once the loop walked
past the insertion point it kept shifting still-larger neighbours, leaving
the local top-K list corrupted and producing wrong indices for K=16/32.

Replaces the heuristic with the standard insertion-sort early exit
``else: break`` (works for all K -- once value <= local_vals[p] the rest
of the ascending list is >= value and no further shifts are needed).
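The standard early exit can be illustrated with a plain-Python top-K insertion. This is a sketch, not the kernel: it keeps the local list in descending order (an assumption; the kernel's orientation may differ) and the function names are hypothetical.

```python
def insert_topk(vals, idxs, value, index):
    """Insert (value, index) into fixed-size lists sorted descending."""
    k = len(vals)
    if value <= vals[k - 1]:
        return  # does not beat the current K-th best
    p = k - 1
    while p > 0:
        if vals[p - 1] < value:
            # Shift the smaller neighbour down one slot.
            vals[p] = vals[p - 1]
            idxs[p] = idxs[p - 1]
            p -= 1
        else:
            # Everything before p is >= value: stop shifting, for all K.
            break
    vals[p] = value
    idxs[p] = index


def topk_indices(scores, k):
    """Indices of the k largest scores, best first."""
    vals = [float("-inf")] * k
    idxs = [-1] * k
    for i, s in enumerate(scores):
        insert_topk(vals, idxs, s, i)
    return idxs
```

The `else: break` branch carries no dependence on K, which is the point of the fix: there is no K range for which the exit is skipped.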

Side effect: K=32 at (B=1, T=512) is now ~1.3x Path B (the prior tuning
win was driven by the bug). Drops (1, 512, 32, "float32") from
``_PATH_C_AUTO_PROFITABLE_RECEIPTS`` so AUTO routes via Path B for that
shape; explicit Path C still works for correctness coverage. Updates the
``test_path_c_lowering_keeps_k32_insertion_scan_unbroken_for_perf`` test
(it was pinning the buggy lowering) and relaxes the K=32 bench smoke
to ``max_ratio=1.5``.
This rewrite is regex/string surgery on TileLang's emitted MSL body, so
any change in TileLang's emission shape (whitespace, names, brace style,
barrier order) silently no-ops the rewrite. We degrade to TileLang's
conservative lowering rather than miscompile, but we lose the perf win
without warning. Documents three durable fixes (TIR-level pass, regex
hit-or-fail assertion, or dropping the rewrite once
``tl.intra_warp_barrier_elision`` covers the same ground) and links the
upstream blockers tracker.

No behaviour change.
_preload_libz3_for_dev_tilelang only set _done on success. On failure it
left both flags unset, so every Path C dispatch re-walked the candidate
list and re-stat'd every path, eating cycles when libz3 is genuinely
missing. Add a _failed_attempts counter that bails after 3 full sweeps;
_done stays gated on actual dlopen success so the env can still recover
if it's pointed at a valid lib later (within the cap).
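A rough sketch of the capped-retry bookkeeping, assuming the dlopen call is injected as `loader`. The class and attribute names are hypothetical stand-ins for the module-level flags.

```python
class PreloadState:
    """Tracks libz3 preload success and bounded failed sweeps."""

    MAX_FAILED_SWEEPS = 3

    def __init__(self):
        self.done = False
        self.failed_attempts = 0

    def preload(self, candidates, loader):
        """Try each candidate once; give up after MAX_FAILED_SWEEPS sweeps."""
        if self.done:
            return True
        if self.failed_attempts >= self.MAX_FAILED_SWEEPS:
            return False  # stop re-walking the candidate list
        for path in candidates:
            try:
                loader(path)
            except OSError:
                continue
            self.done = True  # set only on a real load
            return True
        self.failed_attempts += 1
        return False
```

Because `done` is set only on success, a later call can still recover if a valid library appears before the cap is reached.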
Previously a no-match silently returned the original lowering, masking
TileLang emission drift with no perf regression signal. Replace the
short-circuit with an assert so a future TileLang version that changes
whitespace / variable names / barrier ordering forces a deliberate regex
update.
…BZ3 env

The libz3 preload helper used `/tmp/tl_apache_tvm_swap/build/lib/libz3.dylib`
as an unconditional last-resort candidate. /tmp is world-writable on Unix,
so an attacker who can write there could plant a malicious libz3.dylib that
gets dlopen'd into any process running cppmega_mlx (HIGH security: arbitrary code execution).

Gate that single candidate behind ``CPPMEGA_ALLOW_UNSAFE_LIBZ3=1`` (default
off) and emit a warnings.warn when the opt-in is exercised. Production code
inherits the secure default; the env-rooted (TILELANG_DEV_BUILD_ROOT,
TILELANG_ROOT) and Homebrew (/opt/homebrew, root-owned) candidates remain
unconditional.

The dev test workflow has no TILELANG_ROOT in scope (conftest.py strips it
for hermetic env isolation), so set the opt-in at conftest module-import
time -- before pytest collects any test that imports the _msl_transform
module and triggers the eager Darwin preload. Also re-set via monkeypatch
in the autouse fixture so test bodies that re-scrub env keep dispatch
working.

Also documents the exists()-then-dlopen TOCTOU race as accepted (finding 4):
the gated /tmp path is the only world-writable candidate; a real fix would
need fd-based dlopen, which macOS does not offer.
The `except OSError: continue` block in `_preload_libz3_for_dev_tilelang`
swallowed every dlopen failure equally -- a missing file (FileNotFoundError,
which is a subclass of OSError) and a *broken* libz3 (wrong arch, corrupt
dylib, missing transitive dependency) were silently retried as if both
meant "try the next candidate" (HIGH correctness).

Split the handler:
- ``FileNotFoundError`` is the benign TOCTOU case (file vanished between
  exists() and dlopen) -- continue silently.
- ``OSError`` proper (everything else: arch mismatch, signature failure,
  missing dependency) -- log a warning naming the candidate path and the
  underlying error so a developer staring at "Path C didn't dispatch" has
  a real signal instead of a silent retry.

No behavioural change for the success path or for the missing-file case.
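The split handler might look roughly like this. The sketch assumes the loader is injected and that a missing file surfaces as ``FileNotFoundError``; `first_loadable` and `warn` are hypothetical names.

```python
import logging

_log = logging.getLogger(__name__)


def first_loadable(candidates, loader, warn=_log.warning):
    """Return the first successfully loaded candidate, or None."""
    for path in candidates:
        try:
            return loader(path)
        except FileNotFoundError:
            # Benign: the file vanished or never existed; try the next one.
            continue
        except OSError as exc:
            # Broken library: name the path and the error, then move on.
            warn("libz3 candidate %s failed to load: %s", path, exc)
            continue
    return None
```

The order of the `except` clauses matters: ``FileNotFoundError`` is a subclass of ``OSError``, so it has to be listed first.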
… injection point

Wave-5 audit flagged that the prior CPPMEGA_ALLOW_UNSAFE_LIBZ3 env-var
gate was inverted by tests/conftest.py (which set it via os.environ
.setdefault), so production processes that inherited the env from a
parent (a stray shell export, a CI parent job, etc.) would silently
dlopen a world-writable /tmp dylib.

Round-7 changes:
- Remove the /tmp/tl_apache_tvm_swap candidate from production code
  entirely. No env var can re-enable it.
- Expose `_LIBZ3_DEV_CANDIDATES` (empty in production) for tests to
  inject path candidates directly. The conftest companion change
  registers /tmp via this seam instead of through env.
- Add a defensive sys.platform != "darwin" early return inside the
  preload helper (the module-level call site is already platform-
  gated; this catches a future direct caller, e.g. a unit test that
  resets the helper on Linux).
Companion to the _msl_transform change: tests now register the in-tree
/tmp/tl_apache_tvm_swap libz3 path through the new
``_LIBZ3_DEV_CANDIDATES`` injection seam instead of forcing
``CPPMEGA_ALLOW_UNSAFE_LIBZ3=1`` via os.environ.setdefault.

The env-based opt-in was a security hole: any process that inherited
the env from a parent picked up the unsafe /tmp candidate, even
production. The new seam is process-local Python state that production
code never touches.

Conftest also resets the helper's idempotency flag and re-runs the
preload after injecting the candidate, since the module-level preload
fires at import time (before our injection can take effect).
…r_fp8_vecmat_msl

The bench script had a stale import name for the MSL source emitter; the
public API is lower_fp8_vecmat_msl. Aliasing in _require_bench_deps keeps
the rest of the bench (which calls fp8_vecmat_runtime_msl_source) unchanged.
- bench/tilelang_ports/{fp8_path_c_vs_path_b,mamba3,mamba3_path_c,
  sparse_mla,topk_selector}.json: regenerated by recent bench runs
  on top of fix-rounds 2-7 + libz3 hardening + sparse_mla Path C
  hoist-aware canonicalize.
- docs/tilelang_ports/mamba3_path_{b_vs_c.diff,c_lowered.metal}:
  refreshed with latest lowering output post-canonicalize.
- reports/grok/2026050{6,7}T*/ and reports/chatgpt/2026050{6,7}T*/:
  review artifacts from waves 1-2 review→fix loops (grok-4 ext +
  meta + gpt-5.5-pro extended). Audit trail of findings and
  applied fixes across fp8_vecmat_path_c.py + dsa_splitk_indexer_loss.py.
- uv.lock: project lock file.
Apply grok-4 Wave 3 findings (correctness P1/P2 + performance P2).

cppmega_mlx/nn/_tilelang/_msl_transform.py
- TOCTOU + lazy preload (:43-50, :121-152, :157-191): replaced
  candidate.exists()-then-CDLL with direct CDLL + OSError-message
  heuristic for missing-vs-broken; module-level eager preload
  replaced with _LIBZ3_PRELOAD_ATTEMPTED guard + _maybe_preload_libz3
  invoked lazily from lower_tilelang_to_msl_inline. Public
  ensure_libz3_preloaded() exposed for bench eager use. Compatible
  with conftest mutation flow.
- Narrow exceptions (:617-627, :639-655, :879-890): tilelang import
  catch is (ImportError, ModuleNotFoundError); tl_lower wraps non-
  MSLDispatchUnsupported errors with `from exc` chain; intrinsic
  registration auto-call narrows to (ImportError, ModuleNotFoundError,
  AttributeError) and warns instead of silent pass.
- Input/buffer count validation in dispatch() (:248-265): asserts
  len(inputs) == len(buffer_param_names) - len(output_dtypes); raises
  MSLDispatchUnsupported with full diagnostic.
- _as_metal_target lru_cache (:830-845): split into uncached/cached
  variants; non-string targets bypass cache.
- Lowering result cache (:573-608, :626-636, :691-696): manual dict
  keyed on (id(prim_func), frozenset(pass_configs.items()),
  target_str) with keepalive list pinning prim_func refs; unhashable
  pass_configs values fall through to uncached path.

scripts/bench_tilelang_fp8_path_c.py
- _IMPORT_ENV_LOCK + double-checked locking on _IMPORT_ENV_READY
  (:24, :113, :302-349): existing single-bool guard made thread-
  safe with lock-free fast path.
- reset_env_preparation() helper (:352-368): test-only; clears
  flag + cached resolved roots.

Skipped per memory rule "verify before acting":
- metal_grid_for_lowering math grok flagged as P0. Wave 1 fix
  agent already investigated and confirmed it's correct: TileLang
  emits grid in threadgroup extents, MLX wants total thread grid,
  multiplication is right (docstring at :488-491 documents the
  contract). False alarm, not changed.

Verify: pytest test_tilelang_msl_transform.py 13 passed / 5 xfailed
(unchanged from baseline). py_compile passes both files.
… dsa_splitk (gated by CPPMEGA_MLX_TILELANG_ENGINE)

* New module cppmega_mlx/nn/_tilelang/_engine_dispatch.py: dispatch_lower(prim, target)
  routes through tilelang.compile() (engine path) or
  _msl_transform.lower_tilelang_to_msl_inline() (legacy MSL shim) based on the
  CPPMEGA_MLX_TILELANG_ENGINE env flag (auto / engine / shim).
* fp8_amax._amax_kernel_for + _quantize_kernel_for: route through dispatch_lower.
* dsa_splitk_indexer_loss._stage1_kernel_for + _stage2_kernel_for: same.
* __init__.py: re-export dispatch_lower + tilelang_engine_mode + _engine_dispatch.

Engine results carry _tilelang_engine_target (set by _engine_compile); shim
results are TileLangMSLLowering with msl_text. auto mode falls back to shim
on ImportError/ModuleNotFoundError with a one-shot UserWarning; other engine
errors propagate so real bugs are not masked. Set CPPMEGA_MLX_TILELANG_ENGINE
explicitly to engine or shim to bypass auto.
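A minimal sketch of the mode switch and the auto fallback. The engine and shim lowerings are injected as hypothetical callables, the `_sketch` suffix marks this as illustration rather than the real `dispatch_lower(prim, target)`, and rejecting unknown flag values is an assumption.

```python
import os
import warnings

_FALLBACK_WARNED = False


def engine_mode(env=None):
    """Read CPPMEGA_MLX_TILELANG_ENGINE; unset is treated as 'auto'."""
    env = os.environ if env is None else env
    mode = env.get("CPPMEGA_MLX_TILELANG_ENGINE", "auto").strip().lower()
    if mode not in {"auto", "engine", "shim"}:
        raise ValueError(f"unknown engine mode: {mode!r}")
    return mode


def dispatch_lower_sketch(prim, target, *, engine_lower, shim_lower, mode=None):
    """Route to the engine or shim lowering according to the mode."""
    global _FALLBACK_WARNED
    mode = engine_mode() if mode is None else mode
    if mode == "shim":
        return shim_lower(prim, target)
    if mode == "engine":
        return engine_lower(prim, target)  # every error propagates
    try:
        return engine_lower(prim, target)
    except ImportError:  # also covers ModuleNotFoundError
        if not _FALLBACK_WARNED:
            warnings.warn("engine unavailable; falling back to shim", UserWarning)
            _FALLBACK_WARNED = True
        return shim_lower(prim, target)
```

Only import failures trigger the fallback; any other exception raised by the engine path propagates in auto mode as well.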

Wave-3 commits 76e187a (fp8_amax NaN guard) and a2ffcc1 (dsa_splitk SK clamp)
preserved.
…uard

Apply grok-4 Wave 4 findings (correctness P1/P2 + performance P2).

cppmega_mlx/nn/_tilelang/_msl_transform.py
- Robust _parse_buffer_param_names (~:484-690): replaced fragile
  single-regex with helper-based parsing — _split_signature_decls
  (top-level comma split respecting () and [[...]] nesting),
  _strip_attribute_markers (balanced [[...]] removal at any depth),
  _extract_param_identifier (drops *, &, trailing [N], takes last
  identifier). Buffer recognition matches \bdevice\b OR \bconstant\b
  (Metal address spaces); skips \bthreadgroup\b. Excludes
  _METAL_BUILTIN_PARAM_NAMES = {blockIdx, threadIdx, gridDim, blockDim}.
  Test corpus _TEST_PARSE_SIGNATURES with 12 cases. TVM prim_func
  introspection rejected — TileLang transforms params before emit,
  emitted MSL is the only authoritative source.
- dispatch() None-kernel guard (:286-296): if kernel is None raise
  MSLDispatchUnsupported with clear message instead of TypeError on
  None call. Type hint updated to MetalKernel | None.
- Robust _lowering_cache_key (:670-720): _freeze_for_hash recursive
  handler for nested dicts/lists/sets; repr() fallback for
  unhashable leaves. Cache key always succeeds — no more silent
  cache bypass on Z3 nested configs.
- Bounded _LOWERING_CACHE (:660-740): OrderedDict for both cache
  and keepalive, LRU eviction honoring CPPMEGA_LOWERING_CACHE_SIZE
  env (default 128). Cache hits move_to_end. Public
  clear_lowering_cache() added to __all__.
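The bounded cache in the last bullet can be sketched with an ``OrderedDict``. The class name is hypothetical; only the env variable name and the default of 128 come from the commit message, and the keepalive dict is omitted.

```python
import os
from collections import OrderedDict


class BoundedLoweringCache:
    """LRU cache: hits move to the end, the oldest entry is evicted."""

    def __init__(self, max_size=None):
        if max_size is None:
            max_size = int(os.environ.get("CPPMEGA_LOWERING_CACHE_SIZE", "128"))
        self._max_size = max(1, max_size)
        self._data = OrderedDict()

    def get(self, key):
        try:
            value = self._data[key]
        except KeyError:
            return None
        self._data.move_to_end(key)  # mark as most recently used
        return value

    def put(self, key, value):
        self._data[key] = value
        self._data.move_to_end(key)
        while len(self._data) > self._max_size:
            self._data.popitem(last=False)  # drop least recently used

    def clear(self):
        self._data.clear()
```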

Skipped per CRITICAL note: metal_grid_for_lowering math (verified
correct twice — TileLang grid is threadgroup extents, MLX wants
total threads; multiplication is right).

Verify: pytest 13 passed / 5 xfailed (unchanged from baseline).
Smoke test confirmed parser handles _TEST_PARSE_SIGNATURES + real
mamba3 fwd_kernel signature.
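The helper-based parsing described in this commit might be sketched as below. All function names are hypothetical, the built-in name set is taken from the commit message, and the sketch only covers well-formed signatures.

```python
import re

_BUILTINS = {"blockIdx", "threadIdx", "gridDim", "blockDim"}


def split_top_level(signature):
    """Split on commas that are not nested inside () or []."""
    parts, current, depth = [], [], 0
    for ch in signature:
        if ch in "([":
            depth += 1
        elif ch in ")]":
            depth -= 1
        if ch == "," and depth == 0:
            parts.append("".join(current).strip())
            current = []
        else:
            current.append(ch)
    tail = "".join(current).strip()
    if tail:
        parts.append(tail)
    return parts


def strip_attributes(decl):
    """Remove balanced [[...]] attribute markers."""
    out, depth, i = [], 0, 0
    while i < len(decl):
        if decl.startswith("[[", i):
            depth, i = depth + 1, i + 2
        elif depth > 0 and decl.startswith("]]", i):
            depth, i = depth - 1, i + 2
        else:
            if depth == 0:
                out.append(decl[i])
            i += 1
    return "".join(out).strip()


def buffer_param_names(signature):
    """Names of device/constant buffer parameters, in declaration order."""
    names = []
    for decl in split_top_level(signature):
        bare = strip_attributes(decl)
        if re.search(r"\bthreadgroup\b", bare):
            continue
        if not re.search(r"\b(?:device|constant)\b", bare):
            continue
        bare = re.sub(r"\[\s*\w*\s*\]\s*$", "", bare)  # drop trailing [N]
        idents = re.findall(r"[A-Za-z_]\w*", bare)
        if idents and idents[-1] not in _BUILTINS:
            names.append(idents[-1])
    return names
```

Splitting before stripping attributes keeps commas inside `[[...]]` from breaking a declaration apart.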
…bz3 reset

Final wave grok-4 findings on _msl_transform.py.

cppmega_mlx/nn/_tilelang/_msl_transform.py
- Multi-param test corpus (:583-672): _TEST_PARSE_FULL_SIGNATURES
  with 5 realistic full signatures (mamba3 4-buffer, FP8 with
  constant uint*, const device ordering with __restrict, comma-
  inside-[[...]] attribute, single-buffer degenerate).
  _validate_test_full_signatures() self-test + env-gated import-
  time validation (CPPMEGA_VALIDATE_PARSE_TESTS=1, :1320-1334).
- libz3 preload retry: reset_libz3_preload_state() (:201-218)
  clears _LIBZ3_PRELOAD_ATTEMPTED + _done/_failed_attempts; added
  to __all__ (:1366). clear_lowering_cache() now cascades to it
  (:890-915) — clearing lowering cache implies "reset all cached".
- Stable _freeze_for_hash (:918-1003): _FREEZE_UNSTABLE sentinel;
  rewritten as recurse → hashable → _serialize_to_str()/to_json()
  TVM idiom → json.dumps(sort_keys=True, default=str) →
  _FREEZE_UNSTABLE. _lowering_cache_key returns None when unstable.
  lower_tilelang_to_msl_inline skips lookup/store on None key
  (:1062-1080). Trade-off documented: correctness > hit rate.
- dispatch None-guard with reason (:319-326, :257-292):
  _LAST_MAKE_METAL_KERNEL_REASON module global captured on each
  None-return path of make_metal_kernel (default_device,
  metal_available, "constructor unavailable on this MLX build");
  dispatch surfaces "metal_kernel is None — {reason}".
- import json at module top (:42) for json.dumps fallback.
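A reduced sketch of the stable-freeze idea: recurse into containers, accept hashable leaves, and otherwise report instability so the caller skips the cache. It deliberately omits the TVM serialization and json.dumps fallbacks named above, and the function names are hypothetical.

```python
_FREEZE_UNSTABLE = object()  # sentinel: no stable hashable form


def freeze_for_hash(value):
    """Convert nested dicts/lists/sets into a hashable, ordered form."""
    if isinstance(value, dict):
        items = []
        for key, item in value.items():
            frozen = freeze_for_hash(item)
            if frozen is _FREEZE_UNSTABLE:
                return _FREEZE_UNSTABLE
            items.append((key, frozen))
        try:
            items.sort(key=lambda pair: pair[0])
        except TypeError:  # keys of unorderable mixed types
            return _FREEZE_UNSTABLE
        return ("dict", tuple(items))
    if isinstance(value, (list, tuple, set, frozenset)):
        frozen = [freeze_for_hash(item) for item in value]
        if any(f is _FREEZE_UNSTABLE for f in frozen):
            return _FREEZE_UNSTABLE
        if isinstance(value, (set, frozenset)):
            return ("set", frozenset(frozen))
        return ("seq", tuple(frozen))
    try:
        hash(value)
    except TypeError:
        return _FREEZE_UNSTABLE
    return value


def lowering_cache_key(prim_func_id, pass_configs, target_str):
    """Return a cache key, or None when the configs cannot be frozen."""
    frozen = freeze_for_hash(pass_configs)
    if frozen is _FREEZE_UNSTABLE:
        return None  # caller skips lookup and store
    return (prim_func_id, frozen, target_str)
```

Returning None instead of falling back to repr() follows the stated trade-off: a missed cache hit is preferred over a key that could collide or drift.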

Skipped per CRITICAL note + verification:
- metal_grid_for_lowering math (verified correct 3 times).
- _parse_buffer_param_names hot-path cost: VERIFIED grok wrong —
  buffer_param_names already cached on TileLangMSLLowering:446,
  dispatch reads cached at :321, no re-parse.
- LRU weakref refactor (P3 nice-to-have).
- _strip_attribute_markers edge cases (corpus passes; mark TODO
  if future signature exposes a real failure).

Verify: pytest 13 passed / 5 xfailed (unchanged baseline).
_validate_test_full_signatures() returns [] (all 5 pass).
… to dispatch_lower

- _msl_transform.lower_tilelang_to_msl_inline now emits a one-shot
  DeprecationWarning pointing callers at
  cppmega_mlx.nn._tilelang.dispatch_lower(prim, target). Behaviour
  unchanged. Module-level _DEPRECATION_WARNED flag gates the warn.
- __init__.py docstring updated: _msl_transform marked DEPRECATED;
  _engine_dispatch documented as the preferred entry point. dispatch_lower
  was already re-exported (predecessor 3bccf84 from phase-1.C).

Caller flip status: NONE flipped this round. All 17 callers consume
TileLangMSLLowering.body/.header strings for mx.fast.metal_kernel — the
engine path returns a tilelang.compile artifact (runtime callable), not
an MSL string. A flip needs an adapter that either (a) extracts MSL out
of the engine artifact for the metal_kernel path, or (b) replaces the
metal_kernel call with the artifact's __call__. Tracked as a Phase-3
blocker; see updated docstring in __init__.py.

Predecessors preserved: 76e187a (fp8_amax NaN guard), a2ffcc1 (dsa SK
clamp), 3bccf84 (phase-1.C dispatcher). Sibling-fork dirty paths
(sparse_mla_blockscaled_path_c.py, topk_selector.py) left untouched.
Route _path_c_kernel_for through dispatch_lower so Path-C topk lowering
honors CPPMEGA_MLX_TILELANG_ENGINE. Default (auto/shim) keeps the
existing TileLangMSLLowering + mx.fast.metal_kernel flow unchanged;
engine mode returns the tilelang.compile artifact directly and the
caller invokes it with (scores, starts, ends, indices).

Note: the docstring's reference to "T.sync_threads(3, RADIX) partial
barriers" describes the upstream CUDA radix-select algorithm, NOT the
current Metal-friendly Path-C kernel — Path-C uses only full
T.sync_threads(). No partial-barrier replacement is needed here; the
new tl.sync_threads_partial primitive (Phase-1.A) remains available
for kernels that do mirror the upstream CUDA layout.
Copilot AI review requested due to automatic review settings May 7, 2026 09:29
@apstenku123
Collaborator Author

Direct merge instead of PR review


Copilot AI left a comment


Pull request overview

This PR continues the TileLang “Path C” integration by tightening correctness/perf guardrails in tests, hardening the dev-time libz3 preload behavior in pytest, and introducing a phase-1 “engine vs shim” lowering dispatcher that can route kernel lowering via a unified TileLang engine path or the legacy MSL-inline shim.

Changes:

  • Updated TopK Path C tests to assert the corrected insertion-sort early-exit behavior and relaxed the perf smoke threshold post-correctness fix.
  • Made the FP8 Path C bench harness environment preparation once-per-process and thread-safe; adjusted vecmat lowering import to lower_fp8_vecmat_msl.
  • Added _engine_dispatch.dispatch_lower() and wired some lowering helpers to route through it; added audit-trail review reports and refreshed bench JSON artifacts.
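
The once-per-process, thread-safe environment preparation mentioned in the second bullet can be sketched roughly as follows. This is an illustration of the general lock-plus-flag pattern, not the code in `scripts/bench_tilelang_fp8_path_c.py`; `OnceGuard` and its `run` method are hypothetical names.

```python
import threading
from typing import Callable, List


class OnceGuard:
    """Run a preparation callable at most once per process (hypothetical helper)."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._done = False

    def run(self, prepare: Callable[[], None]) -> bool:
        """Return True only for the single call that actually ran `prepare`."""
        if self._done:  # fast path: skip the lock once preparation finished
            return False
        with self._lock:
            if self._done:  # re-check: another thread may have won the race
                return False
            prepare()
            self._done = True  # set only after prepare() returned without raising
            return True


guard = OnceGuard()
calls: List[int] = []

# Eight threads race to prepare the environment; only one should do the work.
threads = [
    threading.Thread(target=guard.run, args=(lambda: calls.append(1),))
    for _ in range(8)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Setting the flag after `prepare()` returns means a failed preparation can be retried by a later caller, which is one possible design choice among several.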

Reviewed changes

Copilot reviewed 91 out of 92 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| tests/test_tilelang_topk.py | Pins K=32 lowering correctness (break;) and relaxes K=32 perf smoke threshold to reflect correctness-first behavior. |
| tests/conftest.py | Injects a dev libz3 candidate for tests and scrubs unsafe preload env var. |
| scripts/bench_tilelang_fp8_path_c.py | Adds a lock + once-per-process guard for import-environment prep; switches vecmat lowering import to lower_fp8_vecmat_msl. |
| cppmega_mlx/nn/_tilelang/_engine_dispatch.py | Introduces env-driven dispatcher for engine vs shim lowering. |
| cppmega_mlx/nn/_tilelang/sparse_mla_blockscaled_path_c.py | Routes lowering-through-source helpers via dispatch_lower; adds engine artifact helpers/cache. |
| cppmega_mlx/nn/_tilelang/__init__.py | Exposes dispatch_lower/tilelang_engine_mode and updates package docs/exports. |
| docs/tilelang_ports/mamba3_path_b_vs_c.diff | Updates Path B vs Path C source comparison artifact. |
| bench/tilelang_ports/mamba3.json | Refreshes mamba3 benchmark results/metadata. |
| bench/tilelang_ports/mamba3_path_c.json | Refreshes mamba3 Path C benchmark results/metadata. |
| reports/meta/20260507T040203/meta__performance__20260507T040439.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T040203/meta__design__20260507T040339.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T040203/meta__correctness__20260507T040250.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T033015/meta__security__20260507T033244.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T033015/meta__performance__20260507T033204.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T033015/meta__design__20260507T033114.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T033015/meta__correctness__20260507T033058.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T030544/meta__performance__20260507T031049.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T030544/meta__design__20260507T030946.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T024602/meta__security__20260507T025003.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T024602/meta__design__20260507T024805.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T022229/meta__design__20260507T022402.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T022229/meta__correctness__20260507T022258.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T021404/meta__security__20260507T021410.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T021404/meta__performance__20260507T021409.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T021404/meta__design__20260507T021407.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T021404/meta__correctness__20260507T021405.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T021321/meta__tests__20260507T021328.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T021321/meta__security__20260507T021323.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T021321/meta__performance__20260507T021324.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T021321/meta__design__20260507T021327.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T021321/meta__correctness__20260507T021325.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T091652/grok__performance__20260507T091851.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T091652/grok__correctness__20260507T091803.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T091041/grok__performance__20260507T091155.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T091041/grok__correctness__20260507T091126.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T090316/grok__performance__20260507T090505.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T090316/grok__correctness__20260507T090411.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T040203/grok__security__20260507T040501.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T040203/grok__performance__20260507T040425.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T040203/grok__design__20260507T040358.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T040203/grok__correctness__20260507T040245.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T033015/grok__security__20260507T033309.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T033015/grok__performance__20260507T033235.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T033015/grok__correctness__20260507T033054.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T030544/grok__security__20260507T030852.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T030544/grok__correctness__20260507T030619.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T024602/grok__security__20260507T024828.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T024602/grok__performance__20260507T024804.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T024602/grok__design__20260507T024739.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T022009/grok__performance__20260507T022150.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T021404/grok__performance__20260507T021517.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T021404/grok__correctness__20260507T021422.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T012047/grok__security__20260507T012139.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T012047/grok__design__20260507T012618.md | Adds/updates review report (audit trail). |
| reports/chatgpt/20260507T012045/chatgpt__security__20260507T012223.md | Adds/updates review report (audit trail). |
| reports/chatgpt/20260507T011618/chatgpt__security__20260507T012245.md | Adds/updates review report (audit trail). |
| reports/chatgpt/20260507T011618/chatgpt__performance__20260507T012614.md | Adds/updates review report (audit trail). |
| reports/chatgpt/20260507T011618/chatgpt__correctness__20260507T011950.md | Adds/updates review report (audit trail). |


Comment on lines +70 to +80
    """Lower via the legacy MSL-string shim. Always targets metal."""

    from cppmega_mlx.nn._tilelang._msl_transform import lower_tilelang_to_msl_inline

    if target != "metal":
        warnings.warn(
            f"_engine_dispatch: shim mode is metal-only; ignoring target={target!r}.",
            UserWarning,
            stacklevel=2,
        )
    return lower_tilelang_to_msl_inline(prim_func, target="metal")

Comment on lines +54 to +67
def _engine_compile(prim_func: Any, target: str) -> Any:
    """Run ``tilelang.compile`` and stamp the result with the target tag."""

    import tilelang  # noqa: F401 - intentional eager import for ImportError surfacing

    artifact = tilelang.compile(prim_func, target=target, out_idx=None)
    try:
        setattr(artifact, "_tilelang_engine_target", target)
    except (AttributeError, TypeError):
        # Some builds wrap the artifact in a frozen / __slots__ object; preserve
        # the artifact unchanged if we cannot stamp it.
        pass
    return artifact

Comment thread tests/conftest.py
Comment on lines +39 to +52
# The module-level preload at the bottom of _msl_transform.py runs at
# import time -- BEFORE we set the candidate list above. So the first
# preload attempt only saw the default empty list (plus the brew
# fallback). Reset the idempotency flag and re-run the preload now that
# the in-tree dev candidate is registered, so libtilelang.dylib's
# basename libz3 reference resolves to the matching dev-build z3.
try:
    if hasattr(_msl._preload_libz3_for_dev_tilelang, "_done"):
        delattr(_msl._preload_libz3_for_dev_tilelang, "_done")
    if hasattr(_msl._preload_libz3_for_dev_tilelang, "_failed_attempts"):
        delattr(_msl._preload_libz3_for_dev_tilelang, "_failed_attempts")
    _msl._preload_libz3_for_dev_tilelang()
except Exception:  # pragma: no cover - best-effort
    pass

Comment on lines +9 to +18
- _msl_transform.py: legacy TileLang->MSL inline lowering shim. **DEPRECATED.**
  New callers should route through ``dispatch_lower(prim, target)`` which
  prefers ``tilelang.engine.lower(target=...)`` (engine path) and only falls
  back to this shim when the engine is unavailable. See
  ``_engine_dispatch.py`` and ``MIGRATION_PLAN.md``. Existing callers that
  consume ``TileLangMSLLowering.body``/``.header`` strings for
  ``mx.fast.metal_kernel`` need an adapter layer (Phase 3) before they can
  flip — the engine artifact is a runtime callable, not an MSL string.
- _engine_dispatch.py: ``dispatch_lower(prim, target)`` — phase-1 dispatcher
  that flips between engine and shim based on ``$CPPMEGA_MLX_TILELANG_ENGINE``.

apstenku123 added a commit that referenced this pull request May 7, 2026
… hoist, head-0 alloc gate

Apply the 3 HIGH-severity hot-path findings from grok wave-3 review of
the DSA split-K indexer-loss port:

1. sparse_loss path GPU->CPU sync (issue #1): the topk_idx64.max()/min()
   /.item() bounds check + (index_mask==0).all() NaN-poisoning detect
   forced full GPU->CPU syncs + extra reduction kernels on every sparse
   forward pass. Gate behind CPPMEGA_MLX_DSA_DEBUG env var. Production
   training paths now skip both validation passes; CI / first-run
   regressions enable via the env var.

2. Stage-2 partial Q hoist rationale (issue #2): document why the
   current per-(sk_tile, h) hoist is the local optimum and what the
   structural blocker is for full hoist (heads-summed softmax_attn
   accumulator -> persisting it across heads costs 2 MB shared, well
   over Metal/CUDA budgets; alternatives include online cross-head
   softmax recurrence which is the right wave-5 fix).

3. Stage-1 fragment over-allocation (issue #3): split the kernel into
   two specialised variants:
   - make_dsa_splitk_stage1_kernel(compute_index_path: bool = True):
     when False, idx_scores_f / m1_i / d1_i / m1_i_prev shrink to (1,)
     stubs and the `if h == 0` index-softmax block is Python-guarded
     out, eliminating register pressure on AH-1 attn-only blocks.
   - make_dsa_splitk_stage1_idx_kernel: NEW dedicated kernel that
     computes only M1/D1 from IndexScores+IndexMask. Grid is (AB,
     NUM_SQ_BLOCKS) -- no AH axis. No Q/K/matmul; just the per-(b,
     sq) online softmax over IndexScores.

   Wrapper splits the stage-1 launch when AH > 1: attn-only kernel for
   all AH heads + idx kernel for the single h=0-equivalent. AH==1 still
   uses the unified kernel (backward compat for single-head models).
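
The per-(b, sq) online softmax that the dedicated idx kernel performs can be illustrated in plain Python. This is a sketch under stated assumptions, not the kernel code: it assumes M1 is the running maximum and D1 the running softmax denominator, that masked positions arrive as -inf scores, and `online_softmax_stats` is a hypothetical name.

```python
import math
from typing import Iterable, Sequence, Tuple

NEG_INF = float("-inf")


def online_softmax_stats(tiles: Iterable[Sequence[float]]) -> Tuple[float, float]:
    """Return (running max, running denominator) after streaming over score tiles."""
    m = NEG_INF  # running max (assumed to correspond to M1)
    d = 0.0      # running sum of exp(score - m) (assumed to correspond to D1)
    for tile in tiles:
        if not tile:
            continue
        m_new = max(m, max(tile))
        if m_new == NEG_INF:
            continue  # everything seen so far is masked; nothing to accumulate
        # Rescale the old denominator to the new max before adding this tile.
        scale = math.exp(m - m_new) if m != NEG_INF else 0.0
        d = d * scale + sum(math.exp(x - m_new) for x in tile)
        m = m_new
    return m, d


scores = [1.0, 3.0, NEG_INF, 2.0, 0.5, NEG_INF]  # -inf marks masked positions
tiles = [scores[i:i + 2] for i in range(0, len(scores), 2)]
m1, d1 = online_softmax_stats(tiles)

# Reference: one-shot computation over the unmasked scores.
expected_d = sum(math.exp(x - 3.0) for x in scores if x != NEG_INF)
```

Starting the running max at -inf matters for fully masked tiles: they contribute nothing and leave the accumulators untouched instead of producing NaN from `-inf - (-inf)`.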
apstenku123 added a commit that referenced this pull request May 7, 2026
…kernel_for

Wave-2 (44f4f88) introduced N=n_elements / BLOCK=block_size local rebinds
inside make_fp8_amax_kernel + make_fp8_quantize_kernel so the inner
@T.prim_func annotations resolve at decoration time, but missed the same
treatment for in_dtype. Under some tilelang parser paths CPython does not
materialise a closure cell for an outer-scope name that is read but never
written in the enclosing function, so 'T.Tensor((N,), in_dtype)' raised
NameError at parse time and burned down the 6 fp8_amax tests reported in
docs/research/runtime_test_matrix.md.

Adds 'DTYPE = in_dtype' next to the existing 'N = n_elements' rebind in
both kernel builders and switches the annotation to use DTYPE. Wave-3 NaN
guard (76e187a) preserved untouched; py_compile clean.

Fixes runtime_test_matrix Bug #1.
apstenku123 added a commit that referenced this pull request May 7, 2026
…als__

Wave-7 #1 (a439df0) restored DTYPE/N/BLOCK as closure rebinds, but
typing.get_type_hints() — which tilelang's @T.prim_func parser invokes
to resolve the T.Tensor((N,), DTYPE) annotations — only inspects
__globals__ and an explicit localns; it never walks __closure__.
Test #4 (commit 20398296) found this empirically.

Adds _expose_to_globals helper that mutates the factory's module
__globals__ in place to expose DTYPE/N/BLOCK/FP8_MAX before
@T.prim_func runs. Per-shape kernel construction is synchronous and
@lru_cache-wrapped so successive calls overwrite each other's values
without races.

Empirical smoke (M4 Max, MLX venv): make_fp8_amax_kernel(128, fp16)
+ make_fp8_quantize_kernel(128, fp16) + make_fp8_amax_kernel(256, bf16)
all return PrimFunc; was raising NameError before this fix.
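
The name-resolution behaviour described above can be reproduced without tilelang. The sketch below uses a string annotation so that `typing.get_type_hints()` has to evaluate it; `make_kernel`, `kernel` and `resolve` are hypothetical names, and the example only illustrates the standard-library behaviour, not the `_expose_to_globals` helper itself.

```python
import typing


def make_kernel():
    DTYPE = int  # lives in the factory's local scope only

    def kernel(x: "DTYPE") -> None:  # string annotation, resolved lazily
        return None

    return kernel


kernel = make_kernel()


def resolve(fn, localns=None):
    """Return the resolved hints, or the NameError raised while resolving."""
    try:
        return typing.get_type_hints(fn, localns=localns)
    except NameError as exc:
        return exc


# Only kernel.__globals__ is consulted by default, so DTYPE is not found.
without_ns = resolve(kernel)

# Supplying the name explicitly (or placing it in the module globals) resolves it.
with_ns = resolve(kernel, {"DTYPE": int})
```

An explicit `localns` is the non-mutating alternative to writing into module globals; whether tilelang's parser exposes a way to pass one is not stated here.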

2 participants