[Feat] Add TensorAdaptor.mark_static for shape-aware JIT cache keys#554
Merged
Conversation
When a kernel's compiled IR bakes concrete shape values into the
generated code (e.g. buffer-resource num_records derived from the
static layout cosize), reusing a single compiled artifact across
different tensor shapes leads to silent OOB or zero-padded results.
This adds an opt-in mark_static(dims=None) API that pins the listed
dims' shape/stride values into the JIT cache key, so different shapes
get distinct cache entries. The default cache key remains lightweight
(dtype + align + 32-bit-stride flag + rank) so existing kernels keep
their single-compilation behavior.
Usage:
flyc.from_dlpack(t).mark_static() # all dims static
flyc.from_dlpack(t).mark_static(dims=[1]) # only dim 1 static
Cache key format (when mark_static is used):
(dtype, align, use_32bit, rank, ((dim, shape, stride), ...))
raw_cache_signature is updated to match the new default __cache_signature__
format so raw torch.Tensor inputs and from_dlpack-wrapped inputs share
cache entries.
Test: tests/unit/test_tensor_cache_signature.py (12 cases)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
4 tasks
4393426 to
4f05c62
Compare
Root cause of the PR #551 fp8gemm silent-OOB regression: ``fx.rocdl.make_buffer_tensor(t, max_size=False)`` in ``python/flydsl/expr/rocdl/universal.py`` auto-derived ``num_records_bytes = cosize(layout) * elem_bytes`` from the static memref shape and baked the result into IR. The JIT cache key does not include shape by default, so a kernel compiled for shape A was silently reused for shape B with A's ``num_records`` baked in -- truncating OOB reads (returns 0) and dropping OOB writes. mark_static (PR #554) is a useful opt-in for callers who *want* shape in the cache key, but it only fixes the silent OOB if every caller remembers to use it. The robust fix is to make the unsafe pattern impossible at compile time: require an explicit ``num_records_bytes`` whenever ``max_size=False``. * Add ``num_records_bytes`` parameter to ``make_buffer_tensor``. When ``max_size=False`` and ``num_records_bytes`` is not supplied, raise ``ValueError`` with an actionable message pointing callers to either pass an explicit byte count (computed from runtime tensor extents) or use ``max_size=True`` for the safe coarse path. * Update all in-tree callers to be explicit: - ``examples/04-preshuffle_gemm.py`` (3 callsites, fp16 GEMM) - ``kernels/fp8_gemm_utils.py`` ``make_fp8_buffer_tensor`` helper and ``StoreC`` class (4 callsites, fp8 GEMM + scales) - ``kernels/fp8_gemm_8wave.py`` / ``fp8_gemm_4wave.py`` callers of ``make_fp8_buffer_tensor`` (4 callsites) * Cherry-pick of the parallel ``buffer_ops.create_buffer_resource`` fix from commit 202793d (``[Fix] Align buffer-resource num_records with logical tensor sizes``) covering the legacy ``create_buffer_resource (max_size=False)`` path in 4 MoE/GEMM files (18 callsites). * New test ``tests/unit/test_make_buffer_tensor_guard.py`` (3 cases) validates the guard at trace time under ``COMPILE_ONLY=1``. Verification: * tests/unit/test_tensor_cache_signature.py 12 passed * tests/unit/test_make_buffer_tensor_guard.py 3 passed * tests/unit/ 464 passed * tests/system/ 20 passed * tests/kernels/test_preshuffle_gemm.py + test_moe_gemm.py + test_moe_blockscale.py 456 passed * bash scripts/check_python_style.sh clean Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Demonstrates the safe-by-construction dynamic-shape pattern: a single
compiled kernel reused across different M values, with no silent OOB.
* ``gemm_kernel`` and the ``preshuffle_gemm`` JIT launcher both take
``M: fx.Int32`` as a runtime parameter. ``num_records_bytes = M * K * 2``
is now a runtime expression that adapts the buffer-resource descriptor
to the actual tensor extent.
* N and K stay compile-time: layout B's preshuffle factors depend on N
and K (used in ``fx.make_layout(..., N // 16, ..., K // 32, ...)``)
and the K-loop is constexpr-unrolled in ``run_pipeline_stage``.
* ``make_buffer_tensor`` now always materialises ``num_records_bytes``
as Int64 (previously fx.Int32 expressions slipped through and tripped
the ROCDL op verifier which requires i64 num_records -- found by the
refactor). ``num_records_bytes.to(Int64)`` covers fx-typed inputs;
Python ints continue through the existing ``Int64(...)`` wrap.
Verification:
* Direct run produces correct result (max_violation < 0)
* Cache reuse confirmed: M in {2048, 4096, 8192} -> 1 compiled artifact
* tests/unit/test_tensor_cache_signature.py 12 passed
* tests/unit/test_make_buffer_tensor_guard.py 3 passed
* tests/unit/ + tests/system/ 484 passed, 0 failed
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Re-enable the historical ``cosize(layout) * elem_bytes`` auto-derive path in ``make_buffer_tensor(max_size=False)`` -- but only when the tensor's layout has dynamic shape. ``cosize(layout)`` is then a runtime expression that adapts to the actual tensor extent, so the descriptor stays correct across calls with different shapes. For static-shape layouts the auto-derive would still bake the first call's shape into IR and silently OOB on cache reuse (the PR #551 fp8gemm regression mode), so it remains an error with the same actionable message pointing callers to mark_layout_dynamic / num_records_bytes / max_size=True. This matches the user's mental model: "if static you can auto-derive, dynamic raises" -- inverted to "if dynamic you can auto-derive, static raises", which is the safe formulation (static shape == unsafe to bake). Existing kernels that already call mark_layout_dynamic on their tensor adaptors keep their ``make_buffer_tensor(max_size=False)`` callsites unchanged. ``test_max_size_false_with_dynamic_layout_auto_derives`` locks in this behavior. Also add an ``FLYDSL_RUNTIME_ENABLE_CACHE=0`` fixture so the guard re-fires across test runs (the JIT disk cache would otherwise short-circuit re-trace and mask the guard on repeated runs). Verification: * tests/unit/test_make_buffer_tensor_guard.py 4 passed (incl. new) * tests/unit/ + tests/system/ 485 passed, 0 failed * tests/kernels/test_preshuffle_gemm.py + test_moe_gemm.py + test_moe_blockscale.py 456 passed, 0 failed * examples/04-preshuffle_gemm.py correct result Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Make TensorAdaptor default to layout-dynamic memref in __init__ so the generated IR is shape-independent and a single compiled kernel serves calls with different sizes by construction. This eliminates the silent- OOB failure mode that affected PR #551 fp8gemm (static-shape cosize baked into IR + shape-agnostic JIT cache key reusing across shapes), and makes the workarounds added in PR #554 (mark_static API, make_buffer_tensor guard, explicit num_records_bytes on 18+ callsites) all unnecessary -- the same kernel code is now safe by default. Extends the per-call CallState fast path to support the dynamic-memref ABI which carries an extra layout-buffer struct after the data pointer (packed ``[i32 * rank | i64 * (rank-1)]`` matching the C++ ``buildMemRefDesc`` layout). ``_reusable_slot_spec`` may now return a list of slot tuples; each slot is either: * a scalar slot ``(ctype, extract_returning_value)`` -- the existing protocol (e.g. data pointer); * a buffer slot ``(ctype_array, extract_in_place(arg, storage))`` -- new protocol for fixed-size byte buffers, mutated in place via ``struct.Struct.pack_into`` on a ``memoryview`` of the storage. CallState detects the protocol per slot at init time (``hasattr(storage, 'value')``) and dispatches accordingly on the hot path. No allocation per call; the layout buffer is packed with shape / stride values read directly from torch.Tensor's ``.shape`` / ``.stride()`` (no DLPack export). Verification: * tests/unit/ + tests/system/ 469 passed, 0 failed * tests/kernels/test_vec_add + test_preshuffle_gemm + test_moe_gemm + test_moe_blockscale 457 passed, 0 failed * Cache reuse confirmed: examples/04 unchanged kernel runs across M in {2048, 4096, 8192} with a single compiled artifact Follow-up: PR #554's workarounds (mark_static API, make_buffer_tensor guard, kernel-side explicit num_records_bytes) can now be reverted as a separate cleanup since this commit makes them redundant. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The previous commit makes TensorAdaptor default to layout-dynamic memref, so ``cosize(layout)`` is always a runtime expression that adapts to the actual tensor extent. The trace-time ``raise ValueError`` for "static layout + max_size=False + no num_records_bytes" can no longer fire from any real call site (the static-layout branch is unreachable through the default TensorAdaptor entry point), so drop it. Kernels keep their explicit ``num_records_bytes`` callsites: when the exact size is known at compile time (e.g. ``M * N * elem_bytes`` from constexpr extents in the outer compile-function closure), passing it explicitly folds the descriptor to a constant in IR and avoids the runtime ``cosize`` multiplication. When omitted, the runtime cosize path is correct and safe. Files: - python/flydsl/expr/rocdl/universal.py : simplify make_buffer_tensor; drop the raise + the LayoutType static-shape check; docstring updated to recommend explicit num_records_bytes only as a perf hint, not a safety requirement. - tests/unit/test_make_buffer_tensor_guard.py : removed (test premise no longer reachable). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This reverts commit e5f41ee.
With default-dynamic TensorAdaptor + multi-slot fast path, the ``make_buffer_tensor(max_size=False)`` calls auto-derive ``num_records`` from runtime ``cosize(layout) * elem_bytes`` -- which adapts to the actual tensor extent and is safe across cache reuse. The explicit ``num_records_bytes = M * K * 2`` style was needed by an earlier interim version of this PR where the auto-derive path could silently bake compile-time shape; that path no longer exists. Revert to the simpler form (matches main) so the example shows the cleanest default-dynamic pattern. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Net -46 lines of prose, no behavior change. Tests unchanged: 481 passed. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
C++ markLayoutDynamic accumulates dynamic flags across calls without resetting -- a 2nd call with a different leading_dim leaves the previous call's stride[leading] dynamic, and the Python-cached _dyn_leading_dim (used by _reusable_slot_spec to lay out the layout buffer) diverges from the C++ ABI, silently corrupting the layout buffer. Raise NotImplementedError on the 2nd call with a different leading_dim than the one auto-detected in __init__. Same leading_dim (or None, which means auto-detect = same) is still allowed so callers can update divisibility. Real fix path (TODO): C++ should reset all shape/stride dynamic flags at the start of markLayoutDynamic before re-marking per the new leading_dim. All in-tree call sites today pass leading_dim equal to the auto-detected one (contiguous tensors), so the guard never fires. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
When a JitFunction is called first with a raw ``torch.Tensor`` and then with a ``flyc.from_dlpack(t)`` wrapper of the same tensor (same cache key), the cached CallState reuses ``_extract_data_ptr`` which previously assumed the arg is always a raw torch.Tensor. Calling ``adaptor.data_ptr()`` on a TensorAdaptor raised AttributeError. Walk to the keepalive tensor when the arg is a TensorAdaptor. Both wrappings now hit the same fast-path entry correctly. Also narrow the __init__ ``try/except`` from bare ``Exception`` to ``(RuntimeError, StopIteration)`` -- the two real failure modes for the layout-dynamic auto-detect (C++ binding throws RuntimeError for ambiguous / non-unit leading stride; ``next(...)`` raises StopIteration when no stride-1 dim exists). Avoids swallowing bugs. Verified: raw torch.Tensor x4 shapes + from_dlpack wrapper x2 shapes all hit ONE compiled artifact (compiles=1, call_states=1) on preshuffle_gemm. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The previous ``hasattr(num_records_bytes, "ir_value")`` check passed
through any fx Numeric type unchanged -- including ``fx.Int32`` and
``fx.Int32 * int`` arithmetic results. ROCDL ``make.buffer.rsrc``
requires an i64 ``num_records`` operand, so feeding it an i32 fx value
fails the op verifier with:
operand #2 must be 64-bit signless integer, but got 'i32'
Always wrap in ``Int64(...)``, which is idempotent for already-Int64
inputs and handles Python int / other Integer / raw ir.Value (i32,
index, float) by emitting the appropriate extension / cast (same
logic as ``Integer.__init__``).
In-tree callers all pass Python ints today so the bug is latent, but
the API allows fx-typed inputs; this prevents silent invalid IR for
future runtime-shape kernels.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Make TensorAdaptor default to layout-dynamic memref so the generated IR
is shape-independent and a single compiled kernel serves calls with
different sizes by construction — eliminating the silent-OOB failure mode
that affected PR #551 fp8gemm (static-shape
cosizebaked into IR + ashape-agnostic JIT cache key reusing across shapes).
🤖 Generated with Claude Code