Skip to content

[Feat] Add TensorAdaptor.mark_static for shape-aware JIT cache keys#554

Merged
coderfeli merged 12 commits into
mainfrom
feat/tensor-mark-static
May 24, 2026
Merged

[Feat] Add TensorAdaptor.mark_static for shape-aware JIT cache keys#554
coderfeli merged 12 commits into
mainfrom
feat/tensor-mark-static

Conversation

@coderfeli
Copy link
Copy Markdown
Collaborator

@coderfeli coderfeli commented May 23, 2026

Summary

Make TensorAdaptor default to layout-dynamic memref so the generated IR
is shape-independent and a single compiled kernel serves calls with
different sizes by construction — eliminating the silent-OOB failure mode
that affected PR #551 fp8gemm (static-shape cosize baked into IR + a
shape-agnostic JIT cache key reusing across shapes).

🤖 Generated with Claude Code

When a kernel's compiled IR bakes concrete shape values into the
generated code (e.g. buffer-resource num_records derived from the
static layout cosize), reusing a single compiled artifact across
different tensor shapes leads to silent OOB or zero-padded results.

This adds an opt-in mark_static(dims=None) API that pins the listed
dims' shape/stride values into the JIT cache key, so different shapes
get distinct cache entries.  The default cache key remains lightweight
(dtype + align + 32-bit-stride flag + rank) so existing kernels keep
their single-compilation behavior.

Usage:
    flyc.from_dlpack(t).mark_static()          # all dims static
    flyc.from_dlpack(t).mark_static(dims=[1])  # only dim 1 static

Cache key format (when mark_static is used):
    (dtype, align, use_32bit, rank, ((dim, shape, stride), ...))

raw_cache_signature is updated to match the new default __cache_signature__
format so raw torch.Tensor inputs and from_dlpack-wrapped inputs share
cache entries.

Test: tests/unit/test_tensor_cache_signature.py (12 cases)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
coderfeli and others added 11 commits May 24, 2026 04:01
Root cause of the PR #551 fp8gemm silent-OOB regression:
``fx.rocdl.make_buffer_tensor(t, max_size=False)`` in
``python/flydsl/expr/rocdl/universal.py`` auto-derived ``num_records_bytes
= cosize(layout) * elem_bytes`` from the static memref shape and baked the
result into IR.  The JIT cache key does not include shape by default, so a
kernel compiled for shape A was silently reused for shape B with A's
``num_records`` baked in -- truncating OOB reads (returns 0) and dropping
OOB writes.

mark_static (PR #554) is a useful opt-in for callers who *want* shape in
the cache key, but it only fixes the silent OOB if every caller remembers
to use it.  The robust fix is to make the unsafe pattern impossible at
compile time: require an explicit ``num_records_bytes`` whenever
``max_size=False``.

* Add ``num_records_bytes`` parameter to ``make_buffer_tensor``.  When
  ``max_size=False`` and ``num_records_bytes`` is not supplied, raise
  ``ValueError`` with an actionable message pointing callers to either
  pass an explicit byte count (computed from runtime tensor extents) or
  use ``max_size=True`` for the safe coarse path.

* Update all in-tree callers to be explicit:
  - ``examples/04-preshuffle_gemm.py`` (3 callsites, fp16 GEMM)
  - ``kernels/fp8_gemm_utils.py`` ``make_fp8_buffer_tensor`` helper and
    ``StoreC`` class (4 callsites, fp8 GEMM + scales)
  - ``kernels/fp8_gemm_8wave.py`` / ``fp8_gemm_4wave.py`` callers of
    ``make_fp8_buffer_tensor`` (4 callsites)

* Cherry-pick of the parallel ``buffer_ops.create_buffer_resource`` fix
  from commit 202793d (``[Fix] Align buffer-resource num_records with
  logical tensor sizes``) covering the legacy ``create_buffer_resource
  (max_size=False)`` path in 4 MoE/GEMM files (18 callsites).

* New test ``tests/unit/test_make_buffer_tensor_guard.py`` (3 cases)
  validates the guard at trace time under ``COMPILE_ONLY=1``.

Verification:
* tests/unit/test_tensor_cache_signature.py     12 passed
* tests/unit/test_make_buffer_tensor_guard.py    3 passed
* tests/unit/                                  464 passed
* tests/system/                                 20 passed
* tests/kernels/test_preshuffle_gemm.py + test_moe_gemm.py +
  test_moe_blockscale.py                       456 passed
* bash scripts/check_python_style.sh           clean

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Demonstrates the safe-by-construction dynamic-shape pattern: a single
compiled kernel reused across different M values, with no silent OOB.

* ``gemm_kernel`` and the ``preshuffle_gemm`` JIT launcher both take
  ``M: fx.Int32`` as a runtime parameter.  ``num_records_bytes = M * K * 2``
  is now a runtime expression that adapts the buffer-resource descriptor
  to the actual tensor extent.

* N and K stay compile-time: layout B's preshuffle factors depend on N
  and K (used in ``fx.make_layout(..., N // 16, ..., K // 32, ...)``)
  and the K-loop is constexpr-unrolled in ``run_pipeline_stage``.

* ``make_buffer_tensor`` now always materialises ``num_records_bytes``
  as Int64 (previously fx.Int32 expressions slipped through and tripped
  the ROCDL op verifier which requires i64 num_records -- found by the
  refactor).  ``num_records_bytes.to(Int64)`` covers fx-typed inputs;
  Python ints continue through the existing ``Int64(...)`` wrap.

Verification:
* Direct run produces correct result (max_violation < 0)
* Cache reuse confirmed: M in {2048, 4096, 8192} -> 1 compiled artifact
* tests/unit/test_tensor_cache_signature.py     12 passed
* tests/unit/test_make_buffer_tensor_guard.py    3 passed
* tests/unit/ + tests/system/                  484 passed, 0 failed

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Re-enable the historical ``cosize(layout) * elem_bytes`` auto-derive path
in ``make_buffer_tensor(max_size=False)`` -- but only when the tensor's
layout has dynamic shape.  ``cosize(layout)`` is then a runtime expression
that adapts to the actual tensor extent, so the descriptor stays correct
across calls with different shapes.

For static-shape layouts the auto-derive would still bake the first
call's shape into IR and silently OOB on cache reuse (the PR #551 fp8gemm
regression mode), so it remains an error with the same actionable message
pointing callers to mark_layout_dynamic / num_records_bytes / max_size=True.

This matches the user's mental model: "if static you can auto-derive,
dynamic raises" -- inverted to "if dynamic you can auto-derive, static
raises", which is the safe formulation (static shape == unsafe to bake).

Existing kernels that already call mark_layout_dynamic on their tensor
adaptors keep their ``make_buffer_tensor(max_size=False)`` callsites
unchanged.  ``test_max_size_false_with_dynamic_layout_auto_derives``
locks in this behavior.

Also add an ``FLYDSL_RUNTIME_ENABLE_CACHE=0`` fixture so the guard
re-fires across test runs (the JIT disk cache would otherwise
short-circuit re-trace and mask the guard on repeated runs).

Verification:
* tests/unit/test_make_buffer_tensor_guard.py        4 passed (incl. new)
* tests/unit/ + tests/system/                      485 passed, 0 failed
* tests/kernels/test_preshuffle_gemm.py + test_moe_gemm.py +
  test_moe_blockscale.py                          456 passed, 0 failed
* examples/04-preshuffle_gemm.py                   correct result

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Make TensorAdaptor default to layout-dynamic memref in __init__ so the
generated IR is shape-independent and a single compiled kernel serves
calls with different sizes by construction.  This eliminates the silent-
OOB failure mode that affected PR #551 fp8gemm (static-shape cosize
baked into IR + shape-agnostic JIT cache key reusing across shapes), and
makes the workarounds added in PR #554 (mark_static API,
make_buffer_tensor guard, explicit num_records_bytes on 18+ callsites)
all unnecessary -- the same kernel code is now safe by default.

Extends the per-call CallState fast path to support the dynamic-memref
ABI which carries an extra layout-buffer struct after the data pointer
(packed ``[i32 * rank | i64 * (rank-1)]`` matching the C++
``buildMemRefDesc`` layout).  ``_reusable_slot_spec`` may now return a
list of slot tuples; each slot is either:

  * a scalar slot ``(ctype, extract_returning_value)`` -- the existing
    protocol (e.g. data pointer);
  * a buffer slot ``(ctype_array, extract_in_place(arg, storage))`` --
    new protocol for fixed-size byte buffers, mutated in place via
    ``struct.Struct.pack_into`` on a ``memoryview`` of the storage.

CallState detects the protocol per slot at init time
(``hasattr(storage, 'value')``) and dispatches accordingly on the hot
path.  No allocation per call; the layout buffer is packed with shape /
stride values read directly from torch.Tensor's ``.shape`` /
``.stride()`` (no DLPack export).

Verification:
* tests/unit/ + tests/system/         469 passed, 0 failed
* tests/kernels/test_vec_add + test_preshuffle_gemm + test_moe_gemm +
  test_moe_blockscale                 457 passed, 0 failed
* Cache reuse confirmed: examples/04 unchanged kernel runs across
  M in {2048, 4096, 8192} with a single compiled artifact

Follow-up: PR #554's workarounds (mark_static API, make_buffer_tensor
guard, kernel-side explicit num_records_bytes) can now be reverted as a
separate cleanup since this commit makes them redundant.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The previous commit makes TensorAdaptor default to layout-dynamic memref,
so ``cosize(layout)`` is always a runtime expression that adapts to the
actual tensor extent.  The trace-time ``raise ValueError`` for "static
layout + max_size=False + no num_records_bytes" can no longer fire from
any real call site (the static-layout branch is unreachable through the
default TensorAdaptor entry point), so drop it.

Kernels keep their explicit ``num_records_bytes`` callsites: when the
exact size is known at compile time (e.g. ``M * N * elem_bytes`` from
constexpr extents in the outer compile-function closure), passing it
explicitly folds the descriptor to a constant in IR and avoids the
runtime ``cosize`` multiplication.  When omitted, the runtime cosize
path is correct and safe.

Files:
- python/flydsl/expr/rocdl/universal.py : simplify make_buffer_tensor;
  drop the raise + the LayoutType static-shape check; docstring updated
  to recommend explicit num_records_bytes only as a perf hint, not a
  safety requirement.
- tests/unit/test_make_buffer_tensor_guard.py : removed (test premise
  no longer reachable).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
With default-dynamic TensorAdaptor + multi-slot fast path, the
``make_buffer_tensor(max_size=False)`` calls auto-derive
``num_records`` from runtime ``cosize(layout) * elem_bytes`` -- which
adapts to the actual tensor extent and is safe across cache reuse.

The explicit ``num_records_bytes = M * K * 2`` style was needed by an
earlier interim version of this PR where the auto-derive path could
silently bake compile-time shape; that path no longer exists.  Revert
to the simpler form (matches main) so the example shows the cleanest
default-dynamic pattern.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Net -46 lines of prose, no behavior change.  Tests unchanged: 481 passed.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
C++ markLayoutDynamic accumulates dynamic flags across calls without
resetting -- a 2nd call with a different leading_dim leaves the previous
call's stride[leading] dynamic, and the Python-cached _dyn_leading_dim
(used by _reusable_slot_spec to lay out the layout buffer) diverges from
the C++ ABI, silently corrupting the layout buffer.

Raise NotImplementedError on the 2nd call with a different leading_dim
than the one auto-detected in __init__.  Same leading_dim (or None,
which means auto-detect = same) is still allowed so callers can update
divisibility.

Real fix path (TODO): C++ should reset all shape/stride dynamic flags
at the start of markLayoutDynamic before re-marking per the new
leading_dim.  All in-tree call sites today pass leading_dim equal to
the auto-detected one (contiguous tensors), so the guard never fires.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
When a JitFunction is called first with a raw ``torch.Tensor`` and then
with a ``flyc.from_dlpack(t)`` wrapper of the same tensor (same cache
key), the cached CallState reuses ``_extract_data_ptr`` which previously
assumed the arg is always a raw torch.Tensor.  Calling
``adaptor.data_ptr()`` on a TensorAdaptor raised AttributeError.

Walk to the keepalive tensor when the arg is a TensorAdaptor.  Both
wrappings now hit the same fast-path entry correctly.

Also narrow the __init__ ``try/except`` from bare ``Exception`` to
``(RuntimeError, StopIteration)`` -- the two real failure modes for
the layout-dynamic auto-detect (C++ binding throws RuntimeError for
ambiguous / non-unit leading stride; ``next(...)`` raises
StopIteration when no stride-1 dim exists).  Avoids swallowing bugs.

Verified: raw torch.Tensor x4 shapes + from_dlpack wrapper x2 shapes
all hit ONE compiled artifact (compiles=1, call_states=1) on
preshuffle_gemm.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The previous ``hasattr(num_records_bytes, "ir_value")`` check passed
through any fx Numeric type unchanged -- including ``fx.Int32`` and
``fx.Int32 * int`` arithmetic results.  ROCDL ``make.buffer.rsrc``
requires an i64 ``num_records`` operand, so feeding it an i32 fx value
fails the op verifier with:

    operand #2 must be 64-bit signless integer, but got 'i32'

Always wrap in ``Int64(...)``, which is idempotent for already-Int64
inputs and handles Python int / other Integer / raw ir.Value (i32,
index, float) by emitting the appropriate extension / cast (same
logic as ``Integer.__init__``).

In-tree callers all pass Python ints today so the bug is latent, but
the API allows fx-typed inputs; this prevents silent invalid IR for
future runtime-shape kernels.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@coderfeli coderfeli merged commit 250c11d into main May 24, 2026
13 checks passed
@coderfeli coderfeli deleted the feat/tensor-mark-static branch May 24, 2026 14:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant