tilelang Path C: review-driven fix-wave + MLX wiring + 5 review/fix loop iterations #1
Closed
apstenku123 wants to merge 30 commits into main
…ecmat)
Two MLX threads lowering the Path C TileLang kernels concurrently could race the first-time populate of the per-module ``_PASS_CONFIGS_CACHE`` dicts in ``topk_selector`` and ``fp8_vecmat_path_c``. The probe loop runs ``tvm.transform.PassContext`` per candidate key and is not safe to enter twice. Guard the read-or-populate region with a per-module ``threading.Lock()`` (separate locks so the kernels don't serialize each other's first build).
- Fuse two T.Parallel loops in amax kernel: load+abs+cast in one pass. Removes the need for the shared-memory barrier that was missing (a data race between the X_shared write and read) and the redundant global->shared->fragment hop flagged by the perf review (~30-60% amax speedup).
- Drop globals().update() side-effect in both make_fp8_amax_kernel / make_fp8_quantize_kernel; the PrimFunc only ever read local N/BLOCK/FP8_MAX, so the module-level mutation was dead code.
- Add empty-tensor fast paths to fp8_quantize_tilelang and fp8_pack_tilelang (clamp=True path); n_elements>0 was previously an implicit assumption.
- Honor the public out= contract in fp8_quantize_tilelang: when the caller-supplied out is non-contiguous, write to a contiguous scratch and copy_back so the returned tensor is always the user's buffer.
Apply concrete fixes from the grok-4 review of integration #06 (DSA splitK indexer-loss kernel port to TileLang):
* Remove the globals() mutation hack in both stage-1 and stage-2 builders -- BLOCK_*/threads were already passed explicitly through the lru_cache-keyed entry points, so the global side-effect was redundant and risked stale values across concurrent JIT compilation.
* Guard IndexMask / IndexScores reads against OOB on boundary tiles (ASq % BLOCK_SQ != 0 or Sk % BLOCK_SK != 0) -- the original Triton kernels use tl.load(..., mask=...) for the equivalent guard but the TileLang port did unconditional loads, which would read garbage on any non-aligned shape.
* Tighten Metal block overrides from 64x64x32 to 32x32x16. The earlier sizing only accounted for fp16 shared Q/K (16 KB) and missed the fp32 register fragments: stage 2 alone needed 64 KB of fragments, causing register spilling on M-series (32 KB threadgroup budget). 32x32x16 keeps total footprint well under the limit.
* Replace the AB*ASq*Sk*4-byte zeros() allocation in the non-sparse path with empty(); the kernel never reads the tensor when SPARSE is False (constexpr-eliminated, plus the new bounds-guard ensures even boundary tiles skip the load).
* Skip redundant .contiguous() / .to(fp32) copies when tensors are already in the canonical layout, avoiding device-to-device copies on every forward pass.
…amic BLOCK_SIZE, shim type consistency
Applies the four wave-1 deferred grok review items on fp8_amax.py:
1. BLOCK_SIZE-vs-threads correctness: builders now refuse to compile when block_size % threads != 0 with a precise diagnostic; the strided T.Parallel(BLOCK) inner loop requires a clean stride or the last block leaves a partial tail uncovered.
2. JIT-cache thrashing: amax dispatch buckets by next-pow2 of n_elements (via _bucket_n), so 5 close call shapes (4097..8192) now compile a single shared kernel. Tail is zero-padded; zero is the identity for a max over absolute values, so the result is unchanged. lru_cache size bumped to 256 for both amax and quantize. Quantize stays per-shape because output sizing matters per-element. Added precompile_amax_kernel() warm-up API.
3. Dynamic BLOCK_SIZE picker: _pick_block_size(target, n) replaces the hard-coded 1024/128 with a per-target table -- cuda(1024,128), hip(1024,256), metal(256,64) -- documented at the top of the file. Tiny inputs shrink block to next pow2 >= threads.
4. tilelang_supports diagnostics: added tilelang_supports_with_reason() returning (bool, str) on every path; tilelang_supports() is now a thin wrapper that drops the reason for backward compat with cppmega/megatron/fp8_activations.py callers.
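Items 1-3 can be sketched in plain Python. The table values are the ones quoted above; the function bodies are assumptions about how ``_bucket_n`` and ``_pick_block_size`` might behave, in particular the reading of "shrink block to next pow2 >= threads" as ``max(threads, next_pow2(n))``, and ``padded_amax`` is a host-side illustration rather than the kernel.

```python
# Per-target (block_size, threads) defaults quoted in the commit message.
_BLOCK_TABLE = {"cuda": (1024, 128), "hip": (1024, 256), "metal": (256, 64)}


def bucket_n(n: int) -> int:
    """Round n up to the next power of two (assumed behaviour of _bucket_n)."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def pick_block_size(target: str, n: int) -> tuple[int, int]:
    """Return (block_size, threads); tiny inputs shrink the block (assumed rule)."""
    block, threads = _BLOCK_TABLE[target]
    block = min(block, max(threads, bucket_n(n)))
    if block % threads != 0:
        # Mirrors item 1: refuse shapes whose strided loop would leave a tail.
        raise ValueError(f"block_size={block} is not a multiple of threads={threads}")
    return block, threads


def padded_amax(values: list[float]) -> float:
    """amax over a zero-padded bucket; padding cannot change a max of |x|."""
    n = len(values)
    padded = [abs(v) for v in values] + [0.0] * (bucket_n(n) - n)
    return max(padded)
```

Every n in 4097..8192 maps to the bucket 8192, which is why those call shapes can share one compiled kernel.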
- Stage 1: hoist Q load to Q_full=(BLOCK_SQ, AD) shared once per (sq_block, h); inner d_tile copies from Q_full instead of re-reading Q from HBM per sk_tile. Saves SK_TILES-1 redundant HBM reloads of Q.
- Stage 2: pre-load M[b, h, sq] / D[b, h, sq] into M_pre/D_pre fragments of shape (AH, BLOCK_SQ) once per sq_block; inner sk_tile->h loop reads from registers instead of HBM. Stage-2 Q hoist is harder (h iterates inside sk_tile due to softmax_attn accumulator) and is left as a TODO with a concrete path forward.
- Both stages: trim sk_tile loop bound to min(SK_TILES, min(sq_block_id*BLOCK_SQ + BLOCK_SQ - 1, ASq-1) // BLOCK_SK + 1) so causal-mask-zero tiles are skipped entirely. On the last Q-block the trim is a no-op; for Q-block 0 with ASq>=BLOCK_SQ it cuts to a single sk_tile.
Followups still tracked:
- Stage 2 Q hoist (TODO in code).
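A small worked version of the causal trim, assuming the inner term is the smaller of the block's last row and ``ASq-1`` (the only reading under which Q-block 0 is cut to a single tile and the last Q-block is left untouched). ``active_sk_tiles`` is a hypothetical helper name, not the kernel's own code.

```python
def active_sk_tiles(sq_block_id: int, block_sq: int, block_sk: int,
                    asq: int, sk_tiles: int) -> int:
    """Number of sk tiles a causal Q-block can attend to (sketch)."""
    # Last query row this block covers, capped at the final valid row.
    last_row = min(sq_block_id * block_sq + block_sq - 1, asq - 1)
    # Tiles starting past last_row are fully masked under causality.
    return min(sk_tiles, last_row // block_sk + 1)


# ASq = 256 with 64-wide tiles in both dimensions gives 4 sk tiles in total.
per_block = [active_sk_tiles(b, 64, 64, 256, 4) for b in range(4)]
```

With these sizes the four Q-blocks visit 1, 2, 3 and 4 tiles respectively, so the trim skips 6 of the 16 tile visits.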
Wave-3 self-audit (wave-2 grok review came back empty after retries). Highest-impact item addressed in cppmega.mlx:
- fp8_pack_tilelang now raises FloatingPointError when the host-side amax_val is NaN/Inf instead of falling through to inv_scale = fp8_max / amax, which would silently produce 0/NaN scales and poison every FP8 output element. Caller gets a clear diagnostic instead of garbage weights downstream.
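The guard can be illustrated on the host side as follows. This is a sketch, not the body of ``fp8_pack_tilelang``: the function name is hypothetical, 448.0 assumes the float8 e4m3fn format, and the zero-amax branch is an assumption the commit message does not cover.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite float8 e4m3fn value; format is assumed


def inv_scale_from_amax(amax_val: float, fp8_max: float = FP8_E4M3_MAX) -> float:
    """Compute fp8_max / amax, refusing non-finite amax values up front."""
    if math.isnan(amax_val) or math.isinf(amax_val):
        # Without this check: fp8_max / inf == 0.0 and fp8_max / nan == nan,
        # either of which would poison every quantized element.
        raise FloatingPointError(f"amax is not finite: {amax_val!r}")
    if amax_val == 0.0:
        return 1.0  # assumption: an all-zero tensor keeps a unit scale
    return fp8_max / amax_val
```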
…validation
Wave-2 grok re-review came back empty for this packet (no response x 5 aspects); this self-audit pass picks 3 high-impact items off the wave-1 review and adds them on top of wave-2 commit 2b9310e.
1. Clamp _active_sk_tiles to >= 1 in both stages. When ASq=0 or sq_block_id*BLOCK_SQ exceeds ASq (last Q-block in non-divisible shapes, pathological zero-rows), _max_useful_sk goes negative and the trim collapses to 0, skipping the loop and leaving accumulators uninitialised. The clamp guarantees the loop body runs once; the per-position guards already produce no writes for OOB sq_idx, so output remains correct.
2. Wrapper validation for topk_indices: enforce shape (AB, ASq, *), dtype int32|int64 with int32→int64 promotion (PyTorch scatter requires int64), device parity, and contiguity. Surfaces these mismatches at the wrapper boundary instead of an unhelpful "expected scalar type Long" crash deep in the C++ scatter implementation.
3. Note: the partial Q hoist in stage 2 (wave-2 TODO) is already present in the committed file at line 689 -- per-(sk_tile, h) full-AD slab into Q_full shared, d_tile reads from shared. Full out-of-sk_tile hoist in stage 2 still TODO (would need accumulator restructure).
Wave-2 attribution note: the wave-2 commit 2b9310e accidentally swept in ~179 lines of unrelated fp8_vecmat_path_c.py (belongs to integration-05). That sliver remains in 2b9310e history -- not addressed here.
Wave-1b meta findings:
- HIGH perf (M_pre/D_pre register spill): the per-(sq_block) prefetch fragments are AH*BLOCK_SQ fp32 each, totalling 8*AH*BLOCK_SQ bytes per thread block. At AH=128, BLOCK_SQ=128 (CUDA worst case) that reaches 128 KB combined, far above per-block register budgets, and causes spill-to-local-mem (-30/-40% on 2k seqlen). Gate the prefetch behind a 32 KB combined budget; over-budget compiles fall back to per-(sk_tile, h) HBM reads (the original pre-Wave-2 path).
- MED sec (scatter_ bounds + NaN guard): topk_indices were fed into scatter_ without an upper-bound check. PyTorch's CUDA scatter_ does NOT validate upper bound in release builds, so an OOB index would silently corrupt adjacent memory. Validate range [0, Sk) up-front, raise ValueError on violation. Additionally, when a row has zero in-range indices, every IndexMask slot stays -inf and the downstream softmax produces NaN that poisons the loss. Detect fully-masked rows and patch slot 0 to a safe sentinel.
Verified citations (cite line numbers were off but findings real):
- M_pre/D_pre fragment alloc actually at lines 638-639 (cite said 575-576).
- scatter_ actually in dsa_splitk_indexer_loss.py:1065 (cite said topk_*.py).
NOTE: Path C agent (a1f7b2a8) is concurrently touching topk_selector / sparse_mla / fp8_vecmat; this fix is scoped to dsa_splitk_indexer_loss.py to avoid conflict.
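The budget arithmetic behind the HIGH perf finding, as a short sketch. The helper names are hypothetical, and treating a footprint exactly equal to the budget as acceptable is an assumption; the 8*AH*BLOCK_SQ figure and the 32 KB budget are from the message above.

```python
PREFETCH_BUDGET_BYTES = 32 * 1024  # combined M_pre + D_pre budget from the commit


def prefetch_bytes(ah: int, block_sq: int) -> int:
    """Two fp32 fragments of shape (AH, BLOCK_SQ): 2 * 4 * AH * BLOCK_SQ bytes."""
    return 2 * 4 * ah * block_sq


def use_prefetch(ah: int, block_sq: int) -> bool:
    """True when the fragments fit; otherwise fall back to per-tile HBM reads."""
    return prefetch_bytes(ah, block_sq) <= PREFETCH_BUDGET_BYTES


worst_case = prefetch_bytes(128, 128)  # 131072 bytes, i.e. 128 KB
```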
…tal)
- _msl_transform.py: preload libz3.dylib on Darwin so dev-build
libtilelang.dylib's bare-basename dlopen succeeds (otherwise every
Path C kernel silently no-ops via its dispatch try/except).
- sparse_mla_path_c.py + topk_selector.py: replace -T.infinity('float32')
with T.float32(-1.0e38) sentinel; Metal MSL codegen emits inf as
HUGE_VALF/INFINITY which trips the canonicalizer/algebraic rewrites.
- bench/tilelang_ports/topk_selector.json: refreshed receipt under the
fixed dispatch path.
…base hoists
z3-final's CSE/algebraic-simplify passes hoist address-base computations
(int h, b, g, q_row_base, kv_b_base, idx_base, out_row, d_out_row,
dkv_partial_base) to the top of the kernel body — *above* the
'float sm_scale = sm_scale_buf[0];' marker that the canonicalize passes
key off when injecting 'int gid = int(blockIdx.x);'. The hoisted
expressions then reference an undeclared 'gid' and collide (with
conflicting int vs uint types) with the unsigned address bases that
_canonicalize_fwd/bwd_base_indexing re-emit after the marker.
This patch:
- Adds _strip_z3_hoisted_address_decls() and runs it at the top of
_canonicalize_fwd_lane_indexing and _canonicalize_bwd_lane_indexing
to delete the hoisted decls before injection. The canonicalization
passes already re-emit equivalent uint versions further down.
- Drops tautological lane-loop aliases ('int k = k;', 'int kd = ...',
'int d_N = d;', 'int kd = ((kd - tid) + tid);') that the lane-loop
rewrite leaves behind once the inner-body index expressions collapse.
- Promotes bare 'gather_idx = indices[...]' assignments to declarations
in the fwd output-projection / bwd dq/output paths so the symbol is
in scope after we strip the original 'thread int gather_idx[1]'.
- Normalizes the '-1.000000e+38f' sentinel back to '-INFINITY' at the
very top of _postprocess_lowered_msl. The TIR uses the finite
sentinel because '-T.infinity' tripped z3-final's canonicalizer; the
rest of the canonicalization regexes (and the all-masked fast-return)
match against the historical '-INFINITY' token.
Result: sparse_mla_path_c forward/backward parity + dispatch tests go
from 23 failed / 23 passed to 9 failed / 37 passed (14 recovered, 0
regressions). The 9 remaining failures are pre-existing dump_lowered_msl
text-inspection tests that assert kernel-signature strings the dump
helper doesn't emit; unrelated to this fix.
After the hoist-aware canonicalize fix in 7555124 the Path C kernel dispatches and compiles for fwd+bwd across all bench shapes. Numbers captured at warmup=3, iters=10:
| shape | fwd_b | fwd_c | fwd_ratio | bwd_b | bwd_c | bwd_ratio |
|---|---|---|---|---|---|---|
| B2_S128_H8_D64 | 0.209 | 0.200 | 0.96 | 0.297 | 0.355 | 1.20 |
| B4_S512_H8_D64 | 0.332 | 0.349 | 1.05 | 0.959 | 0.939 | 0.98 |
| B4_S1024_H8_D64 | 0.624 | 0.646 | 1.04 | 3.136 | 3.172 | 1.01 |
Forward Path C is slightly faster than Path B on the smallest shape (0.96x) and within 5% on the larger shapes; backward Path C is at parity (0.98x / 1.01x) on the larger shapes and ~20% slower on the small shape (where dispatch overhead dominates the kernel time).
…ion)
The Path C topk_selector local-heap insertion sort carried an
``elif (K<=8) or (K>=64): break`` heuristic that was a tuning artifact:
it *skipped* the early break for K in {16, 32}, so once the loop walked
past the insertion point it kept shifting still-larger neighbours, leaving
the local top-K list corrupted and producing wrong indices for K=16/32.
Replaces the heuristic with the standard insertion-sort early exit
``else: break`` (works for all K -- once value <= local_vals[p] the rest
of the ascending list is >= value and no further shifts are needed).
Side effect: K=32 at (B=1, T=512) is now ~1.3x Path B (the prior tuning
win was driven by the bug). Drops (1, 512, 32, "float32") from
``_PATH_C_AUTO_PROFITABLE_RECEIPTS`` so AUTO routes via Path B for that
shape; explicit Path C still works for correctness coverage. Updates the
``test_path_c_lowering_keeps_k32_insertion_scan_unbroken_for_perf`` test
(it was pinning the buggy lowering) and relaxes the K=32 bench smoke
to ``max_ratio=1.5``.
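The early-exit rule is easier to see in a host-side sketch. This is an illustration in plain Python, not the TileLang kernel: it keeps the local list in descending order (the kernel's own orientation may differ), and ``insert_topk`` / ``topk`` are hypothetical names.

```python
def insert_topk(vals: list, idxs: list, value, index: int) -> None:
    """Insert (value, index) into a fixed-size list kept in descending order."""
    if value <= vals[-1]:
        return  # not better than the current K-th best
    p = len(vals) - 1
    while p > 0:
        if vals[p - 1] < value:
            # Neighbour is smaller: shift it one slot down and keep walking.
            vals[p], idxs[p] = vals[p - 1], idxs[p - 1]
            p -= 1
        else:
            # Standard early exit: everything above p is already >= value.
            break
    vals[p], idxs[p] = value, index


def topk(scores: list, k: int) -> tuple[list, list]:
    vals, idxs = [float("-inf")] * k, [-1] * k
    for i, s in enumerate(scores):
        insert_topk(vals, idxs, s, i)
    return vals, idxs


# 101 distinct scores in scrambled order (37 is coprime with 101).
scores = [(i * 37) % 101 for i in range(101)]
vals16, idxs16 = topk(scores, 16)
```

Because the ``break`` is unconditional on K, the same loop is correct for K=16 and K=32 as well as for the sizes the old heuristic happened to cover.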
This rewrite is regex/string surgery on TileLang's emitted MSL body, so any change in TileLang's emission shape (whitespace, names, brace style, barrier order) silently no-ops the rewrite. We degrade to TileLang's conservative lowering rather than miscompile, but we lose the perf win without warning. Documents three durable fixes (TIR-level pass, regex hit-or-fail assertion, or dropping the rewrite once ``tl.intra_warp_barrier_elision`` covers the same ground) and links the upstream blockers tracker. No behaviour change.
_preload_libz3_for_dev_tilelang only set _done on success. On failure it left both flags unset, so every Path C dispatch re-walked the candidate list and re-stat'd every path, eating cycles when libz3 is genuinely missing. Add a _failed_attempts counter that bails after 3 full sweeps; _done stays gated on actual dlopen success so the env can still recover if it's pointed at a valid lib later (within the cap).
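A compact sketch of the capped-retry bookkeeping. The cap of 3 and the ``_done`` / ``_failed_attempts`` semantics follow the description above; the ``PreloadState`` class and the injectable ``sweep`` callable are hypothetical, chosen so the sketch runs without touching ``dlopen``.

```python
from typing import Callable


class PreloadState:
    """Tracks preload success and gives up after a fixed number of failed sweeps."""

    MAX_FAILED_SWEEPS = 3

    def __init__(self, sweep: Callable[[], bool]) -> None:
        self._sweep = sweep          # walks the candidate list once; True on success
        self.done = False            # set only when a load actually succeeded
        self.failed_attempts = 0

    def ensure(self) -> bool:
        if self.done:
            return True
        if self.failed_attempts >= self.MAX_FAILED_SWEEPS:
            return False             # stop re-walking the candidates on every dispatch
        if self._sweep():
            self.done = True
        else:
            self.failed_attempts += 1
        return self.done
```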
Previously a no-match silently returned the original lowering, masking TileLang emission drift with no perf regression signal. Replace the short-circuit with an assert so a future TileLang version that changes whitespace / variable names / barrier ordering forces a deliberate regex update.
…BZ3 env
The libz3 preload helper used `/tmp/tl_apache_tvm_swap/build/lib/libz3.dylib` as an unconditional last-resort candidate. /tmp is world-writable on Unix, so an attacker who can write there could plant a malicious libz3.dylib that gets dlopen'd into any process running cppmega_mlx (HIGH security: ARE). Gate that single candidate behind ``CPPMEGA_ALLOW_UNSAFE_LIBZ3=1`` (default off) and emit a warnings.warn when the opt-in is exercised. Production code inherits the secure default; the env-rooted (TILELANG_DEV_BUILD_ROOT, TILELANG_ROOT) and Homebrew (/opt/homebrew, root-owned) candidates remain unconditional.
The dev test workflow has no TILELANG_ROOT in scope (conftest.py strips it for hermetic env isolation), so set the opt-in at conftest module-import time -- before pytest collects any test that imports the _msl_transform module and triggers the eager Darwin preload. Also re-set via monkeypatch in the autouse fixture so test bodies that re-scrub env keep dispatch working.
Also documents the exists()-then-dlopen TOCTOU race as accepted (finding 4): the gated /tmp path is the only world-writable candidate; a real fix would need fd-based dlopen, which macOS does not offer.
The `except OSError: continue` block in `_preload_libz3_for_dev_tilelang` swallowed every dlopen failure equally -- a missing file (FileNotFoundError, which is a subclass of OSError) and a *broken* libz3 (wrong arch, corrupt dylib, missing transitive dependency) were silently retried as if both meant "try the next candidate" (HIGH correctness). Split the handler:
- ``FileNotFoundError`` is the benign TOCTOU case (file vanished between exists() and dlopen) -- continue silently.
- ``OSError`` proper (everything else: arch mismatch, signature failure, missing dependency) -- log a warning naming the candidate path and the underlying error so a developer staring at "Path C didn't dispatch" has a real signal instead of a silent retry.
No behavioural change for the success path or for the missing-file case.
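A sketch of the split handler. The ``loader`` argument is a hypothetical injection point standing in for the real dlopen call so the control flow can be exercised without real dylibs; ``preload_first`` and the logger name are likewise illustrative.

```python
import logging
from typing import Callable, Iterable, Optional

log = logging.getLogger("libz3_preload_sketch")


def preload_first(candidates: Iterable[str],
                  loader: Callable[[str], object]) -> Optional[str]:
    """Return the first candidate that loads, warning on broken (not missing) ones."""
    for path in candidates:
        try:
            loader(path)
        except FileNotFoundError:
            # Benign: the file is simply absent (or vanished after a check).
            continue
        except OSError as exc:
            # Present but unloadable: surface it instead of retrying silently.
            log.warning("libz3 candidate %s failed to load: %s", path, exc)
            continue
        return path
    return None
```

The order of the ``except`` clauses matters: ``FileNotFoundError`` has to come first because it is a subclass of ``OSError``.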
… injection point
Wave-5 audit flagged that the prior CPPMEGA_ALLOW_UNSAFE_LIBZ3 env-var gate was inverted by tests/conftest.py (which set it via os.environ.setdefault), so production processes that inherited the env from a parent (a stray shell export, a CI parent job, etc.) would silently dlopen a world-writable /tmp dylib.
Round-7 changes:
- Remove the /tmp/tl_apache_tvm_swap candidate from production code entirely. No env var can re-enable it.
- Expose `_LIBZ3_DEV_CANDIDATES` (empty in production) for tests to inject path candidates directly. The conftest companion change registers /tmp via this seam instead of through env.
- Add a defensive sys.platform != "darwin" early return inside the preload helper (the module-level call site is already platform-gated; this catches a future direct caller, e.g. a unit test that resets the helper on Linux).
Companion to the _msl_transform change: tests now register the in-tree /tmp/tl_apache_tvm_swap libz3 path through the new ``_LIBZ3_DEV_CANDIDATES`` injection seam instead of forcing ``CPPMEGA_ALLOW_UNSAFE_LIBZ3=1`` via os.environ.setdefault. The env-based opt-in was a security hole: any process that inherited the env from a parent picked up the unsafe /tmp candidate, even production. The new seam is process-local Python state that production code never touches. Conftest also resets the helper's idempotency flag and re-runs the preload after injecting the candidate, since the module-level preload fires at import time (before our injection can take effect).
…r_fp8_vecmat_msl
The bench script had a stale import name for the MSL source emitter; the public API is lower_fp8_vecmat_msl. Aliasing in _require_bench_deps keeps the rest of the bench (which calls fp8_vecmat_runtime_msl_source) unchanged.
- bench/tilelang_ports/{fp8_path_c_vs_path_b,mamba3,mamba3_path_c,
sparse_mla,topk_selector}.json: regenerated by recent bench runs
on top of fix-rounds 2-7 + libz3 hardening + sparse_mla Path C
hoist-aware canonicalize.
- docs/tilelang_ports/mamba3_path_{b_vs_c.diff,c_lowered.metal}:
refreshed with latest lowering output post-canonicalize.
- reports/grok/2026050{6,7}T*/ and reports/chatgpt/2026050{6,7}T*/:
review artifacts from waves 1-2 review→fix loops (grok-4 ext +
meta + gpt-5.5-pro extended). Audit trail of findings and
applied fixes across fp8_vecmat_path_c.py + dsa_splitk_indexer_loss.py.
- uv.lock: project lock file.
Apply grok-4 Wave 3 findings (correctness P1/P2 + performance P2).
cppmega_mlx/nn/_tilelang/_msl_transform.py
- TOCTOU + lazy preload (:43-50, :121-152, :157-191): replaced candidate.exists()-then-CDLL with direct CDLL + OSError-message heuristic for missing-vs-broken; module-level eager preload replaced with _LIBZ3_PRELOAD_ATTEMPTED guard + _maybe_preload_libz3 invoked lazily from lower_tilelang_to_msl_inline. Public ensure_libz3_preloaded() exposed for bench eager use. Compatible with conftest mutation flow.
- Narrow exceptions (:617-627, :639-655, :879-890): tilelang import catch is (ImportError, ModuleNotFoundError); tl_lower wraps non-MSLDispatchUnsupported errors with `from exc` chain; intrinsic registration auto-call narrows to (ImportError, ModuleNotFoundError, AttributeError) and warns instead of silent pass.
- Input/buffer count validation in dispatch() (:248-265): asserts len(inputs) == len(buffer_param_names) - len(output_dtypes); raises MSLDispatchUnsupported with full diagnostic.
- _as_metal_target lru_cache (:830-845): split into uncached/cached variants; non-string targets bypass cache.
- Lowering result cache (:573-608, :626-636, :691-696): manual dict keyed on (id(prim_func), frozenset(pass_configs.items()), target_str) with keepalive list pinning prim_func refs; unhashable pass_configs values fall through to uncached path.
scripts/bench_tilelang_fp8_path_c.py
- _IMPORT_ENV_LOCK + double-checked locking on _IMPORT_ENV_READY (:24, :113, :302-349): existing single-bool guard made thread-safe with lock-free fast path.
- reset_env_preparation() helper (:352-368): test-only; clears flag + cached resolved roots.
Skipped per memory rule "verify before acting":
- metal_grid_for_lowering math grok flagged as P0. Wave 1 fix agent already investigated and confirmed it's correct: TileLang emits grid in threadgroup extents, MLX wants total thread grid, multiplication is right (docstring at :488-491 documents the contract). False alarm, not changed.
Verify: pytest test_tilelang_msl_transform.py 13 passed / 5 xfailed (unchanged from baseline). py_compile passes both files.
… dsa_splitk (gated by CPPMEGA_MLX_TILELANG_ENGINE)
* New module cppmega_mlx/nn/_tilelang/_engine_dispatch.py: dispatch_lower(prim, target) routes through tilelang.compile() (engine path) or _msl_transform.lower_tilelang_to_msl_inline() (legacy MSL shim) based on the CPPMEGA_MLX_TILELANG_ENGINE env flag (auto / engine / shim).
* fp8_amax._amax_kernel_for + _quantize_kernel_for: route through dispatch_lower.
* dsa_splitk_indexer_loss._stage1_kernel_for + _stage2_kernel_for: same.
* __init__.py: re-export dispatch_lower + tilelang_engine_mode + _engine_dispatch.
Engine results carry _tilelang_engine_target (set by _engine_compile); shim results are TileLangMSLLowering with msl_text. auto mode falls back to shim on ImportError/ModuleNotFoundError with a one-shot UserWarning; other engine errors propagate so real bugs are not masked. Set CPPMEGA_MLX_TILELANG_ENGINE explicitly to engine or shim to bypass auto.
Wave-3 commits 76e187a (fp8_amax NaN guard) and a2ffcc1 (dsa_splitk SK clamp) preserved.
…uard
Apply grok-4 Wave 4 findings (correctness P1/P2 + performance P2).
cppmega_mlx/nn/_tilelang/_msl_transform.py
- Robust _parse_buffer_param_names (~:484-690): replaced fragile
single-regex with helper-based parsing — _split_signature_decls
(top-level comma split respecting () and [[...]] nesting),
_strip_attribute_markers (balanced [[...]] removal at any depth),
_extract_param_identifier (drops *, &, trailing [N], takes last
identifier). Buffer recognition matches \bdevice\b OR \bconstant\b
(Metal address spaces); skips \bthreadgroup\b. Excludes
_METAL_BUILTIN_PARAM_NAMES = {blockIdx, threadIdx, gridDim, blockDim}.
Test corpus _TEST_PARSE_SIGNATURES with 12 cases. TVM prim_func
introspection rejected — TileLang transforms params before emit,
emitted MSL is the only authoritative source.
- dispatch() None-kernel guard (:286-296): if kernel is None raise
MSLDispatchUnsupported with clear message instead of TypeError on
None call. Type hint updated to MetalKernel | None.
- Robust _lowering_cache_key (:670-720): _freeze_for_hash recursive
handler for nested dicts/lists/sets; repr() fallback for
unhashable leaves. Cache key always succeeds — no more silent
cache bypass on Z3 nested configs.
- Bounded _LOWERING_CACHE (:660-740): OrderedDict for both cache
and keepalive, LRU eviction honoring CPPMEGA_LOWERING_CACHE_SIZE
env (default 128). Cache hits move_to_end. Public
clear_lowering_cache() added to __all__.
Skipped per CRITICAL note: metal_grid_for_lowering math (verified
correct twice — TileLang grid is threadgroup extents, MLX wants
total threads; multiplication is right).
Verify: pytest 13 passed / 5 xfailed (unchanged from baseline).
Smoke test confirmed parser handles _TEST_PARSE_SIGNATURES + real
mamba3 fwd_kernel signature.
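A simplified sketch of the helper-based parsing described in the first bullet. It is not the committed code: the function names drop the leading underscore, attribute stripping here handles only non-nested ``[[...]]``, and the example signature is illustrative.

```python
import re

_METAL_BUILTIN_PARAM_NAMES = {"blockIdx", "threadIdx", "gridDim", "blockDim"}


def split_signature_decls(sig: str) -> list[str]:
    """Split on commas that sit outside both (...) and [[...]] nesting."""
    parts, start, paren, attr, i = [], 0, 0, 0, 0
    while i < len(sig):
        if sig.startswith("[[", i):
            attr += 1
            i += 2
            continue
        if attr and sig.startswith("]]", i):
            attr -= 1
            i += 2
            continue
        ch = sig[i]
        if ch == "(":
            paren += 1
        elif ch == ")":
            paren -= 1
        elif ch == "," and paren == 0 and attr == 0:
            parts.append(sig[start:i].strip())
            start = i + 1
        i += 1
    tail = sig[start:].strip()
    return parts + [tail] if tail else parts


def buffer_param_names(sig: str) -> list[str]:
    """Names of device/constant parameters, skipping threadgroup and builtins."""
    names = []
    for decl in split_signature_decls(sig):
        bare = re.sub(r"\[\[.*?\]\]", "", decl)          # drop attributes (non-nested)
        bare = re.sub(r"\[\d*\]\s*$", "", bare.strip())  # drop a trailing [N]
        if not re.search(r"\b(device|constant)\b", bare):
            continue
        idents = re.findall(r"[A-Za-z_]\w*", bare)       # '*' and '&' fall away here
        if idents and idents[-1] not in _METAL_BUILTIN_PARAM_NAMES:
            names.append(idents[-1])
    return names


example_sig = (
    "device half* A [[buffer(0)]], "
    "constant uint* meta [[buffer(1), raster_order_group(0)]], "
    "threadgroup float* scratch [[threadgroup(0)]], "
    "uint3 blockIdx [[threadgroup_position_in_grid]]"
)
```

The comma inside the second parameter's attribute list is the case a naive ``sig.split(",")`` gets wrong, which is why the splitter tracks both kinds of nesting.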
…bz3 reset
Final wave grok-4 findings on _msl_transform.py.
cppmega_mlx/nn/_tilelang/_msl_transform.py
- Multi-param test corpus (:583-672): _TEST_PARSE_FULL_SIGNATURES
with 5 realistic full signatures (mamba3 4-buffer, FP8 with
constant uint*, const device ordering with __restrict, comma-
inside-[[...]] attribute, single-buffer degenerate).
_validate_test_full_signatures() self-test + env-gated import-
time validation (CPPMEGA_VALIDATE_PARSE_TESTS=1, :1320-1334).
- libz3 preload retry: reset_libz3_preload_state() (:201-218)
clears _LIBZ3_PRELOAD_ATTEMPTED + _done/_failed_attempts; added
to __all__ (:1366). clear_lowering_cache() now cascades to it
(:890-915) — clearing lowering cache implies "reset all cached".
- Stable _freeze_for_hash (:918-1003): _FREEZE_UNSTABLE sentinel;
rewritten as recurse → hashable → _serialize_to_str()/to_json()
TVM idiom → json.dumps(sort_keys=True, default=str) →
_FREEZE_UNSTABLE. _lowering_cache_key returns None when unstable.
lower_tilelang_to_msl_inline skips lookup/store on None key
(:1062-1080). Trade-off documented: correctness > hit rate.
- dispatch None-guard with reason (:319-326, :257-292):
_LAST_MAKE_METAL_KERNEL_REASON module global captured on each
None-return path of make_metal_kernel (default_device,
metal_available, "constructor unavailable on this MLX build");
dispatch surfaces "metal_kernel is None — {reason}".
- import json at module top (:42) for json.dumps fallback.
Skipped per CRITICAL note + verification:
- metal_grid_for_lowering math (verified correct 3 times).
- _parse_buffer_param_names hot-path cost: VERIFIED grok wrong —
buffer_param_names already cached on TileLangMSLLowering:446,
dispatch reads cached at :321, no re-parse.
- LRU weakref refactor (P3 nice-to-have).
- _strip_attribute_markers edge cases (corpus passes; mark TODO
if future signature exposes a real failure).
Verify: pytest 13 passed / 5 xfailed (unchanged baseline).
_validate_test_full_signatures() returns [] (all 5 pass).
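The freeze-then-key flow from the "Stable _freeze_for_hash" bullet, as a reduced sketch. It keeps the recurse, hashable and ``to_json()`` steps plus the ``_FREEZE_UNSTABLE`` sentinel, and deliberately omits the ``_serialize_to_str()`` and ``json.dumps`` fallbacks; ``lowering_cache_key`` here takes a plain ``prim_id`` argument, which is an assumption about the real key layout.

```python
_FREEZE_UNSTABLE = object()


def _any_unstable(items) -> bool:
    return any(item is _FREEZE_UNSTABLE for item in items)


def freeze_for_hash(value):
    """Turn nested config values into a hashable, order-independent form."""
    if isinstance(value, dict):
        keys = [freeze_for_hash(k) for k in value]
        vals = [freeze_for_hash(v) for v in value.values()]
        if _any_unstable(keys) or _any_unstable(vals):
            return _FREEZE_UNSTABLE
        return ("dict", frozenset(zip(keys, vals)))
    if isinstance(value, (list, tuple)):
        items = [freeze_for_hash(v) for v in value]
        return _FREEZE_UNSTABLE if _any_unstable(items) else ("seq", tuple(items))
    if isinstance(value, (set, frozenset)):
        items = [freeze_for_hash(v) for v in value]
        return _FREEZE_UNSTABLE if _any_unstable(items) else ("set", frozenset(items))
    try:
        hash(value)
        return value
    except TypeError:
        pass
    to_json = getattr(value, "to_json", None)
    if callable(to_json):
        return ("json", to_json())
    return _FREEZE_UNSTABLE  # no stable representation: caller must skip the cache


def lowering_cache_key(prim_id: int, pass_configs: dict, target: str):
    frozen = freeze_for_hash(pass_configs)
    return None if frozen is _FREEZE_UNSTABLE else (prim_id, frozen, target)
```

Returning ``None`` instead of a best-effort key is the "correctness > hit rate" trade-off: an unstable value costs a cache miss rather than risking a wrong hit.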
… to dispatch_lower
- _msl_transform.lower_tilelang_to_msl_inline now emits a one-shot DeprecationWarning pointing callers at cppmega_mlx.nn._tilelang.dispatch_lower(prim, target). Behaviour unchanged. Module-level _DEPRECATION_WARNED flag gates the warn.
- __init__.py docstring updated: _msl_transform marked DEPRECATED; _engine_dispatch documented as the preferred entry point. dispatch_lower was already re-exported (predecessor 3bccf84 from phase-1.C).
Caller flip status: NONE flipped this round. All 17 callers consume TileLangMSLLowering.body/.header strings for mx.fast.metal_kernel; the engine path returns a tilelang.compile artifact (runtime callable), not an MSL string. A flip needs an adapter that either (a) extracts MSL out of the engine artifact for the metal_kernel path, or (b) replaces the metal_kernel call with the artifact's __call__. Tracked as a Phase-3 blocker; see updated docstring in __init__.py.
Predecessors preserved: 76e187a (fp8_amax NaN guard), a2ffcc1 (dsa SK clamp), 3bccf84 (phase-1.C dispatcher). Sibling-fork dirty paths (sparse_mla_blockscaled_path_c.py, topk_selector.py) left untouched.
Route _path_c_kernel_for through dispatch_lower so Path-C topk lowering honors CPPMEGA_MLX_TILELANG_ENGINE. Default (auto/shim) keeps the existing TileLangMSLLowering + mx.fast.metal_kernel flow unchanged; engine mode returns the tilelang.compile artifact directly and the caller invokes it with (scores, starts, ends, indices). Note: the docstring's reference to "T.sync_threads(3, RADIX) partial barriers" describes the upstream CUDA radix-select algorithm, NOT the current Metal-friendly Path-C kernel — Path-C uses only full T.sync_threads(). No partial-barrier replacement is needed here; the new tl.sync_threads_partial primitive (Phase-1.A) remains available for kernels that do mirror the upstream CUDA layout.
Direct merge instead of PR review.
Pull request overview
This PR continues the TileLang “Path C” integration by tightening correctness/perf guardrails in tests, hardening the dev-time libz3 preload behavior in pytest, and introducing a phase-1 “engine vs shim” lowering dispatcher that can route kernel lowering via a unified TileLang engine path or the legacy MSL-inline shim.
Changes:
- Updated TopK Path C tests to assert the corrected insertion-sort early-exit behavior and relaxed the perf smoke threshold post-correctness fix.
- Made the FP8 Path C bench harness environment preparation once-per-process and thread-safe; adjusted vecmat lowering import to lower_fp8_vecmat_msl.
- Added _engine_dispatch.dispatch_lower() and wired some lowering helpers to route through it; added audit-trail review reports and refreshed bench JSON artifacts.
Reviewed changes
Copilot reviewed 91 out of 92 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| tests/test_tilelang_topk.py | Pins K=32 lowering correctness (break;) and relaxes K=32 perf smoke threshold to reflect correctness-first behavior. |
| tests/conftest.py | Injects a dev libz3 candidate for tests and scrubs unsafe preload env var. |
| scripts/bench_tilelang_fp8_path_c.py | Adds a lock + once-per-process guard for import-environment prep; switches vecmat lowering import to lower_fp8_vecmat_msl. |
| cppmega_mlx/nn/_tilelang/_engine_dispatch.py | Introduces env-driven dispatcher for engine vs shim lowering. |
| cppmega_mlx/nn/_tilelang/sparse_mla_blockscaled_path_c.py | Routes lowering-through-source helpers via dispatch_lower; adds engine artifact helpers/cache. |
| cppmega_mlx/nn/_tilelang/__init__.py | Exposes dispatch_lower/tilelang_engine_mode and updates package docs/exports. |
| docs/tilelang_ports/mamba3_path_b_vs_c.diff | Updates Path B vs Path C source comparison artifact. |
| bench/tilelang_ports/mamba3.json | Refreshes mamba3 benchmark results/metadata. |
| bench/tilelang_ports/mamba3_path_c.json | Refreshes mamba3 Path C benchmark results/metadata. |
| reports/meta/20260507T040203/meta__performance__20260507T040439.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T040203/meta__design__20260507T040339.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T040203/meta__correctness__20260507T040250.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T033015/meta__security__20260507T033244.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T033015/meta__performance__20260507T033204.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T033015/meta__design__20260507T033114.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T033015/meta__correctness__20260507T033058.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T030544/meta__performance__20260507T031049.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T030544/meta__design__20260507T030946.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T024602/meta__security__20260507T025003.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T024602/meta__design__20260507T024805.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T022229/meta__design__20260507T022402.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T022229/meta__correctness__20260507T022258.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T021404/meta__security__20260507T021410.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T021404/meta__performance__20260507T021409.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T021404/meta__design__20260507T021407.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T021404/meta__correctness__20260507T021405.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T021321/meta__tests__20260507T021328.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T021321/meta__security__20260507T021323.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T021321/meta__performance__20260507T021324.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T021321/meta__design__20260507T021327.md | Adds/updates review report (audit trail). |
| reports/meta/20260507T021321/meta__correctness__20260507T021325.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T091652/grok__performance__20260507T091851.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T091652/grok__correctness__20260507T091803.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T091041/grok__performance__20260507T091155.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T091041/grok__correctness__20260507T091126.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T090316/grok__performance__20260507T090505.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T090316/grok__correctness__20260507T090411.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T040203/grok__security__20260507T040501.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T040203/grok__performance__20260507T040425.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T040203/grok__design__20260507T040358.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T040203/grok__correctness__20260507T040245.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T033015/grok__security__20260507T033309.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T033015/grok__performance__20260507T033235.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T033015/grok__correctness__20260507T033054.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T030544/grok__security__20260507T030852.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T030544/grok__correctness__20260507T030619.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T024602/grok__security__20260507T024828.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T024602/grok__performance__20260507T024804.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T024602/grok__design__20260507T024739.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T022009/grok__performance__20260507T022150.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T021404/grok__performance__20260507T021517.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T021404/grok__correctness__20260507T021422.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T012047/grok__security__20260507T012139.md | Adds/updates review report (audit trail). |
| reports/grok/20260507T012047/grok__design__20260507T012618.md | Adds/updates review report (audit trail). |
| reports/chatgpt/20260507T012045/chatgpt__security__20260507T012223.md | Adds/updates review report (audit trail). |
| reports/chatgpt/20260507T011618/chatgpt__security__20260507T012245.md | Adds/updates review report (audit trail). |
| reports/chatgpt/20260507T011618/chatgpt__performance__20260507T012614.md | Adds/updates review report (audit trail). |
| reports/chatgpt/20260507T011618/chatgpt__correctness__20260507T011950.md | Adds/updates review report (audit trail). |
Comment on lines +70 to +80

        """Lower via the legacy MSL-string shim. Always targets metal."""

        from cppmega_mlx.nn._tilelang._msl_transform import lower_tilelang_to_msl_inline

        if target != "metal":
            warnings.warn(
                f"_engine_dispatch: shim mode is metal-only; ignoring target={target!r}.",
                UserWarning,
                stacklevel=2,
            )
        return lower_tilelang_to_msl_inline(prim_func, target="metal")
Comment on lines +54 to +67

    def _engine_compile(prim_func: Any, target: str) -> Any:
        """Run ``tilelang.compile`` and stamp the result with the target tag."""

        import tilelang  # noqa: F401 - intentional eager import for ImportError surfacing

        artifact = tilelang.compile(prim_func, target=target, out_idx=None)
        try:
            setattr(artifact, "_tilelang_engine_target", target)
        except (AttributeError, TypeError):
            # Some builds wrap the artifact in a frozen / __slots__ object; preserve
            # the artifact unchanged if we cannot stamp it.
            pass
        return artifact
Comment on lines +39 to +52

    # The module-level preload at the bottom of _msl_transform.py runs at
    # import time -- BEFORE we set the candidate list above. So the first
    # preload attempt only saw the default empty list (plus the brew
    # fallback). Reset the idempotency flag and re-run the preload now that
    # the in-tree dev candidate is registered, so libtilelang.dylib's
    # basename libz3 reference resolves to the matching dev-build z3.
    try:
        if hasattr(_msl._preload_libz3_for_dev_tilelang, "_done"):
            delattr(_msl._preload_libz3_for_dev_tilelang, "_done")
        if hasattr(_msl._preload_libz3_for_dev_tilelang, "_failed_attempts"):
            delattr(_msl._preload_libz3_for_dev_tilelang, "_failed_attempts")
        _msl._preload_libz3_for_dev_tilelang()
    except Exception:  # pragma: no cover - best-effort
        pass
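The hunk above resets an idempotency flag stored as a function attribute and then re-runs the preload. A minimal standalone sketch of that pattern follows; `preload_once`, `CANDIDATES`, `LOADED`, and the library path are hypothetical stand-ins, not the project's `_preload_libz3_for_dev_tilelang` or its real search list.

```python
CANDIDATES = []  # hypothetical search list, populated after import time
LOADED = []      # records what a run "loaded" (stand-in for dlopen attempts)


def preload_once():
    """Process the candidates once; later calls are no-ops until the flag is cleared."""
    if getattr(preload_once, "_done", False):
        return False
    LOADED.extend(CANDIDATES)
    preload_once._done = True
    return True


def reset_and_rerun():
    """Clear the idempotency flag, then run the preload again."""
    if hasattr(preload_once, "_done"):
        delattr(preload_once, "_done")
    return preload_once()


first = preload_once()                      # runs while the list is still empty
CANDIDATES.append("/opt/dev/libz3.dylib")   # hypothetical path registered later
skipped = preload_once()                    # no-op: the flag is already set
rerun = reset_and_rerun()                   # sees the late-registered candidate
```

Without the reset, the second call would be skipped and the late-registered candidate would never be tried, which is the situation the comment in the hunk describes.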
Comment on lines +9 to +18

    - _msl_transform.py: legacy TileLang->MSL inline lowering shim. **DEPRECATED.**
      New callers should route through ``dispatch_lower(prim, target)`` which
      prefers ``tilelang.engine.lower(target=...)`` (engine path) and only falls
      back to this shim when the engine is unavailable. See
      ``_engine_dispatch.py`` and ``MIGRATION_PLAN.md``. Existing callers that
      consume ``TileLangMSLLowering.body``/``.header`` strings for
      ``mx.fast.metal_kernel`` need an adapter layer (Phase 3) before they can
      flip — the engine artifact is a runtime callable, not an MSL string.
    - _engine_dispatch.py: ``dispatch_lower(prim, target)`` — phase-1 dispatcher
      that flips between engine and shim based on ``$CPPMEGA_MLX_TILELANG_ENGINE``.
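The docstring above describes a dispatcher that chooses between the engine path and the legacy shim from an environment variable. The sketch below shows one way such a switch could look. It is not the code in `_engine_dispatch.py`: the function name `dispatch_lower_sketch`, the injected `engine_lower` / `shim_lower` callables, the `"1"` opt-in value, and the `ImportError` fallback are all assumptions made for illustration.

```python
import os

ENV_VAR = "CPPMEGA_MLX_TILELANG_ENGINE"  # name taken from the docstring above


def dispatch_lower_sketch(prim, target, engine_lower, shim_lower, env=None):
    """Return (path_name, artifact); prefer the engine when opted in and importable."""
    env = os.environ if env is None else env
    if env.get(ENV_VAR, "0") == "1":  # assumed opt-in value
        try:
            return "engine", engine_lower(prim, target)
        except ImportError:
            pass  # engine unavailable: fall through to the shim
    return "shim", shim_lower(prim, target)


# Stand-in lowerings so the sketch runs without tilelang installed.
def fake_engine(prim, target):
    return f"callable<{prim}@{target}>"


def missing_engine(prim, target):
    raise ImportError("tilelang not installed")


def fake_shim(prim, target):
    return f"msl<{prim}>"


on = dispatch_lower_sketch("k", "metal", fake_engine, fake_shim, env={ENV_VAR: "1"})
off = dispatch_lower_sketch("k", "metal", fake_engine, fake_shim, env={})
fallback = dispatch_lower_sketch("k", "metal", missing_engine, fake_shim, env={ENV_VAR: "1"})
```

Returning the path name alongside the artifact makes it visible to callers whether they received a runtime callable or an MSL string, which matters for the adapter layer the docstring mentions.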
apstenku123 added a commit that referenced this pull request May 7, 2026

… hoist, head-0 alloc gate

Apply the 3 HIGH-severity hot-path findings from grok wave-3 review of the DSA split-K indexer-loss port:

1. sparse_loss path GPU->CPU sync (issue #1): the topk_idx64.max()/min()/.item() bounds check + (index_mask==0).all() NaN-poisoning detect forced full GPU->CPU syncs + extra reduction kernels on every sparse forward pass. Gate behind CPPMEGA_MLX_DSA_DEBUG env var. Production training paths now skip both validation passes; CI / first-run regressions enable via the env var.
2. Stage-2 partial Q hoist rationale (issue #2): document why the current per-(sk_tile, h) hoist is the local optimum and what the structural blocker is for full hoist (heads-summed softmax_attn accumulator -> persisting it across heads costs 2 MB shared, well over Metal/CUDA budgets; alternatives include online cross-head softmax recurrence which is the right wave-5 fix).
3. Stage-1 fragment over-allocation (issue #3): split the kernel into two specialised variants:
   - make_dsa_splitk_stage1_kernel(compute_index_path: bool = True): when False, idx_scores_f / m1_i / d1_i / m1_i_prev shrink to (1,) stubs and the `if h == 0` index-softmax block is Python-guarded out, eliminating register pressure on AH-1 attn-only blocks.
   - make_dsa_splitk_stage1_idx_kernel: NEW dedicated kernel that computes only M1/D1 from IndexScores+IndexMask. Grid is (AB, NUM_SQ_BLOCKS) -- no AH axis. No Q/K/matmul; just the per-(b, sq) online softmax over IndexScores.

Wrapper splits the stage-1 launch when AH > 1: attn-only kernel for all AH heads + idx kernel for the single h=0-equivalent. AH==1 still uses the unified kernel (backward compat for single-head models).
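Item 1 of the commit above gates expensive validation behind `CPPMEGA_MLX_DSA_DEBUG`. The sketch below illustrates that gating with plain Python lists in place of device tensors. The helper names and the rule that any value other than empty or `"0"` enables the checks are assumptions, not the project's code.

```python
import os

DEBUG_ENV = "CPPMEGA_MLX_DSA_DEBUG"  # env var named in the commit message


def debug_enabled(env=None):
    """Assumed rule: any value other than '' or '0' turns the checks on."""
    env = os.environ if env is None else env
    return env.get(DEBUG_ENV, "0") not in ("", "0")


def check_topk_bounds(topk_idx, sk, env=None):
    """Validate indices only in debug mode; return whether the check actually ran."""
    if not debug_enabled(env):
        return False  # production path: no reduction, no host sync
    lo, hi = min(topk_idx), max(topk_idx)
    if lo < 0 or hi >= sk:
        raise ValueError(f"topk index out of range [0, {sk}): min={lo}, max={hi}")
    return True


skipped = check_topk_bounds([0, 99], sk=10, env={})            # bad data, check skipped
checked = check_topk_bounds([0, 9], sk=10, env={DEBUG_ENV: "1"})
```

The trade-off is visible in the first call: with the gate off, out-of-range indices pass silently, which is why the commit keeps the checks enabled in CI and first-run regressions.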
apstenku123 added a commit that referenced this pull request May 7, 2026

…kernel_for

Wave-2 (44f4f88) introduced N=n_elements / BLOCK=block_size local rebinds inside make_fp8_amax_kernel + make_fp8_quantize_kernel so the inner @T.prim_func annotations resolve at decoration time, but missed the same treatment for in_dtype. Under some tilelang parser paths CPython does not materialise a closure cell for an outer-scope name that is read but never written in the enclosing function, so 'T.Tensor((N,), in_dtype)' raised NameError at parse time and burned down the 6 fp8_amax tests reported in docs/research/runtime_test_matrix.md.

Adds 'DTYPE = in_dtype' next to the existing 'N = n_elements' rebind in both kernel builders and switches the annotation to use DTYPE. Wave-3 NaN guard (76e187a) preserved untouched; py_compile clean.

Fixes runtime_test_matrix Bug #1.
apstenku123 added a commit that referenced this pull request May 7, 2026

…als__

Wave-7 #1 (a439df0) restored DTYPE/N/BLOCK as closure rebinds, but typing.get_type_hints(), which tilelang's @T.prim_func parser invokes to resolve the T.Tensor((N,), DTYPE) annotations, only inspects __globals__ and an explicit localns; it never walks __closure__. Test #4 (commit 20398296) found this empirically.

Adds _expose_to_globals helper that mutates the factory's module __globals__ in place to expose DTYPE/N/BLOCK/FP8_MAX before @T.prim_func runs. Per-shape kernel construction is synchronous and @lru_cache-wrapped so successive calls overwrite each other's values without races.

Empirical smoke (M4 Max, MLX venv): make_fp8_amax_kernel(128, fp16) + make_fp8_quantize_kernel(128, fp16) + make_fp8_amax_kernel(256, bf16) all return PrimFunc; was raising NameError before this fix.
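The `typing.get_type_hints()` behaviour this commit relies on can be reproduced without tilelang. The sketch below assumes the annotation reaches `get_type_hints` as a string (as it would under `from __future__ import annotations`); `make_kernel` and `kernel` are illustrative names, and the last step only imitates what a helper like `_expose_to_globals` is described as doing.

```python
import typing


def make_kernel(dtype):
    DTYPE = dtype  # closure-local rebind, as in the earlier fix

    def kernel(x: "DTYPE"):
        return DTYPE  # referencing DTYPE makes it a real closure cell

    return kernel


kernel = make_kernel(float)

# 1) The closure is not consulted: the string annotation cannot be resolved.
try:
    typing.get_type_hints(kernel)
    resolved_via_closure = True
except NameError:
    resolved_via_closure = False

# 2) Supplying an explicit localns works.
via_localns = typing.get_type_hints(kernel, localns={"DTYPE": float})

# 3) So does injecting the name into the function's module globals.
kernel.__globals__["DTYPE"] = float
try:
    via_globals = typing.get_type_hints(kernel)
finally:
    del kernel.__globals__["DTYPE"]  # undo the module-level side effect
```

Step 3 mutates shared module state, so two factories building kernels concurrently with different values would interfere; the commit message's reliance on synchronous construction is what keeps that safe.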
Summary

Brings `mlx-z3-wiring` into `main` after 5 waves of multi-provider AI review→fix loops (Grok-4 ext + Meta + GPT-5.5-pro extended) plus parallel MLX-wiring PassConfig opt-ins.

29 commits ahead of main: initial 12-fix wave (`ebbb84b`) + 27 follow-up fixes/migrations on top.

Key changes

Initial fix-wave (`ebbb84b`)

12 fixes across `_msl_transform.py`, `fp8_vecmat_path_c.py`, kernel docstrings, routing doc, tests, bench harness, `_experimental.py`. Addressed Grok/Meta/GPT findings: intrinsic registration P0 (kPure→kReadState for position-reads, fixed kOpaque misnaming), shape contract `(1,N)` macro vs `(N,)` fallback, routing-doc honesty (PROBE-ONLY/REDUCERS-ONLY/BROKEN labels), Path C↔B parity tests, schema gate v2, env hygiene autouse fixture, snapshot rot archive.

MLX-wiring (Z3-roadmap PassConfig opt-ins)

- `9a668ea`: `tl.dot4_legality` opt-in in fp8_vecmat_path_c
- `9202646`: `tl.intra_warp_barrier_elision` opt-in in sparse_mla + topk_selector
- `9242368`: thread-safety lock on `_PASS_CONFIGS_CACHE`
- `60f0005`: K>0 guard + narrow exception in fp8_vecmat_path_c

Wave 1-5 review→fix loop (grok-4 extended, correctness+performance)

- `2b9310e` (DSA) + integrated: `idx_scores_f` `-INF` priming (P0), Stage-2 partial Q hoist, Metal block AH-aware, fp8_vecmat output_shape canonicalization, lock unification, vectorized_loads cache key, scale-resolve fast path
- `ab17d21`: `_as_metal_target` lru_cache, lowering result cache, bench env-prep idempotency
- `c66ce56`: `_parse_buffer_param_names` (split/strip/extract helpers), dispatch None-kernel guard, recursive `_freeze_for_hash`, bounded LRU `_LOWERING_CACHE` (`CPPMEGA_LOWERING_CACHE_SIZE`)
- `741fc74`: `reset_libz3_preload_state()`, stable `_freeze_for_hash` (no `repr()`), dispatch None-guard with reason

Migration phase 1+2 (`3bccf84` / `e3340b7` / `f1ca48b`)

Unified `tilelang.engine.lower` path for fp8_amax + dsa_splitk (gated by `CPPMEGA_MLX_TILELANG_ENGINE`), deprecate `_msl_transform` shim, topk_selector engine path via `dispatch_lower`.

Misc

- … `/tmp` from prod candidates, gate behind env, distinguish OSError vs FileNotFoundError, retry caps).
- `reports/grok/`, `reports/chatgpt/`, `reports/meta/`.

False-positive findings rejected (per memory rule "verify before relaying")

- `metal_grid_for_lowering` multiplication: grok flagged 3 times as P0; verified correct each time (TileLang grid = threadgroup extents, MLX wants total threads).
- `SCHEMA_VERSION` "missing from `__all__`": verified present at bench script:58.
- `MTL_DEBUG_LAYER`/etc. "not covered by conftest": verified covered by `MTL_` prefix.
- `_parse_buffer_param_names` "hot-path cost": verified `buffer_param_names` already cached on `TileLangMSLLowering:446`, dispatch reads cached at `:321`.

Test plan

- `pytest tests/test_tilelang_msl_transform.py`: should be 13 passed / 5 xfailed (xfails strict, will flip green when transforms land upstream)
- `pytest tests/test_tilelang_path_c_vs_b_parity.py`: Path C↔B numeric parity (xfail-gated where Path C apply not yet exposed)
- `pytest tests/test_tilelang_fp8_vecmat_path_c.py::test_fp8_e4m3_dot4_intrinsic_is_registered`: intrinsic registration trip-wire
- `python scripts/bench_tilelang_fp8_path_c.py --shapes vecmat_4096 --iters 50`: confirm no regression vs Path B
- `CPPMEGA_VALIDATE_PARSE_TESTS=1 python -c "from cppmega_mlx.nn._tilelang import _msl_transform"`: multi-param signature parser self-test
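The false-positive note on `metal_grid_for_lowering` rests on a unit conversion: a grid expressed in threadgroups has to be multiplied by the threadgroup size to get a grid expressed in total threads. A small worked sketch of that arithmetic follows; `total_threads_grid` and the example shapes are hypothetical and do not reproduce the project's function.

```python
def total_threads_grid(threadgroup_counts, threads_per_group):
    """Convert a grid counted in threadgroups into one counted in threads, per axis."""
    return tuple(g * t for g, t in zip(threadgroup_counts, threads_per_group))


# Example: 8 x 4 x 1 threadgroups of 128 x 1 x 1 threads each.
groups = (8, 4, 1)
threads = (128, 1, 1)
grid = total_threads_grid(groups, threads)  # (1024, 4, 1)
```

Dropping the multiplication would launch only 8 x 4 x 1 threads in total, which is why the multiplication looks suspicious in isolation but is consistent with the explanation given in the note.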