Skip to content

Support dynamic shapes in FP8 GEMMs#551

Merged
coderfeli merged 2 commits into
ROCm:mainfrom
amd-cgilli:fp8gemm_dyn_shapes
May 25, 2026
Merged

Support dynamic shapes in FP8 GEMMs#551
coderfeli merged 2 commits into
ROCm:mainfrom
amd-cgilli:fp8gemm_dyn_shapes

Conversation

@amd-cgilli
Copy link
Copy Markdown
Contributor

Pass M/N as runtime parameters, keep only K as compile-time.

coderfeli
coderfeli previously approved these changes May 20, 2026
coderfeli added a commit that referenced this pull request May 24, 2026
The previous mark_static API was opt-in: callers had to remember to call
.mark_static() on every TensorAdaptor whose shape the compiled IR depends
on, otherwise the JIT cache silently reused a stale artifact across
different shapes (the original PR #551 fp8gemm OOB bug).  Forgetting the
call brought back the silent-OOB failure mode -- a footgun.

Invert the invariant so the safe behavior is the default:

  * __cache_signature__ now includes the full (shape, stride) of every
    static dim by default.  Two calls with different shapes get distinct
    cache entries unless the caller explicitly opts out.

  * mark_dynamic(dims=None) is the new opt-out: declare that the kernel's
    IR is shape-independent on the listed dims, dropping them from the
    cache key.

  * mark_layout_dynamic auto-mirrors all shape dims into the dynamic set,
    preserving the invariant:
        dim dynamic in memref type  <=>  dim excluded from cache key
    so kernels that legitimately compile-once-run-many-shapes (those that
    already use mark_layout_dynamic) still reuse one compiled artifact.

  * mark_static is retained as a backward-compatible no-op / inverse of
    mark_dynamic so existing call sites keep working.

  * raw_cache_signature is aligned with the new default so raw
    torch.Tensor inputs and from_dlpack-wrapped inputs share entries.

Test: tests/unit/test_tensor_cache_signature.py (21 cases) covers
default per-shape behavior, mark_dynamic per-dim / accumulate / OOB,
mark_static back-compat undo semantics, and the mark_layout_dynamic
invariant.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
coderfeli added a commit that referenced this pull request May 24, 2026
Root cause of the PR #551 fp8gemm silent-OOB regression:
``fx.rocdl.make_buffer_tensor(t, max_size=False)`` in
``python/flydsl/expr/rocdl/universal.py`` auto-derived ``num_records_bytes
= cosize(layout) * elem_bytes`` from the static memref shape and baked the
result into IR.  The JIT cache key does not include shape by default, so a
kernel compiled for shape A was silently reused for shape B with A's
``num_records`` baked in -- truncating OOB reads (returns 0) and dropping
OOB writes.

mark_static (PR #554) is a useful opt-in for callers who *want* shape in
the cache key, but it only fixes the silent OOB if every caller remembers
to use it.  The robust fix is to make the unsafe pattern impossible at
compile time: require an explicit ``num_records_bytes`` whenever
``max_size=False``.

* Add ``num_records_bytes`` parameter to ``make_buffer_tensor``.  When
  ``max_size=False`` and ``num_records_bytes`` is not supplied, raise
  ``ValueError`` with an actionable message pointing callers to either
  pass an explicit byte count (computed from runtime tensor extents) or
  use ``max_size=True`` for the safe coarse path.

* Update all in-tree callers to be explicit:
  - ``examples/04-preshuffle_gemm.py`` (3 callsites, fp16 GEMM)
  - ``kernels/fp8_gemm_utils.py`` ``make_fp8_buffer_tensor`` helper and
    ``StoreC`` class (4 callsites, fp8 GEMM + scales)
  - ``kernels/fp8_gemm_8wave.py`` / ``fp8_gemm_4wave.py`` callers of
    ``make_fp8_buffer_tensor`` (4 callsites)

* Cherry-pick of the parallel ``buffer_ops.create_buffer_resource`` fix
  from commit 202793d (``[Fix] Align buffer-resource num_records with
  logical tensor sizes``) covering the legacy ``create_buffer_resource
  (max_size=False)`` path in 4 MoE/GEMM files (18 callsites).

* New test ``tests/unit/test_make_buffer_tensor_guard.py`` (3 cases)
  validates the guard at trace time under ``COMPILE_ONLY=1``.

Verification:
* tests/unit/test_tensor_cache_signature.py     12 passed
* tests/unit/test_make_buffer_tensor_guard.py    3 passed
* tests/unit/                                  464 passed
* tests/system/                                 20 passed
* tests/kernels/test_preshuffle_gemm.py + test_moe_gemm.py +
  test_moe_blockscale.py                       456 passed
* bash scripts/check_python_style.sh           clean

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
coderfeli added a commit that referenced this pull request May 24, 2026
Re-enable the historical ``cosize(layout) * elem_bytes`` auto-derive path
in ``make_buffer_tensor(max_size=False)`` -- but only when the tensor's
layout has dynamic shape.  ``cosize(layout)`` is then a runtime expression
that adapts to the actual tensor extent, so the descriptor stays correct
across calls with different shapes.

For static-shape layouts the auto-derive would still bake the first
call's shape into IR and silently OOB on cache reuse (the PR #551 fp8gemm
regression mode), so it remains an error with the same actionable message
pointing callers to mark_layout_dynamic / num_records_bytes / max_size=True.

This matches the user's mental model: "if static you can auto-derive,
dynamic raises" -- inverted to "if dynamic you can auto-derive, static
raises", which is the safe formulation (static shape == unsafe to bake).

Existing kernels that already call mark_layout_dynamic on their tensor
adaptors keep their ``make_buffer_tensor(max_size=False)`` callsites
unchanged.  ``test_max_size_false_with_dynamic_layout_auto_derives``
locks in this behavior.

Also add an ``FLYDSL_RUNTIME_ENABLE_CACHE=0`` fixture so the guard
re-fires across test runs (the JIT disk cache would otherwise
short-circuit re-trace and mask the guard on repeated runs).

Verification:
* tests/unit/test_make_buffer_tensor_guard.py        4 passed (incl. new)
* tests/unit/ + tests/system/                      485 passed, 0 failed
* tests/kernels/test_preshuffle_gemm.py + test_moe_gemm.py +
  test_moe_blockscale.py                          456 passed, 0 failed
* examples/04-preshuffle_gemm.py                   correct result

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
coderfeli added a commit that referenced this pull request May 24, 2026
Make TensorAdaptor default to layout-dynamic memref in __init__ so the
generated IR is shape-independent and a single compiled kernel serves
calls with different sizes by construction.  This eliminates the silent-
OOB failure mode that affected PR #551 fp8gemm (static-shape cosize
baked into IR + shape-agnostic JIT cache key reusing across shapes), and
makes the workarounds added in PR #554 (mark_static API,
make_buffer_tensor guard, explicit num_records_bytes on 18+ callsites)
all unnecessary -- the same kernel code is now safe by default.

Extends the per-call CallState fast path to support the dynamic-memref
ABI which carries an extra layout-buffer struct after the data pointer
(packed ``[i32 * rank | i64 * (rank-1)]`` matching the C++
``buildMemRefDesc`` layout).  ``_reusable_slot_spec`` may now return a
list of slot tuples; each slot is either:

  * a scalar slot ``(ctype, extract_returning_value)`` -- the existing
    protocol (e.g. data pointer);
  * a buffer slot ``(ctype_array, extract_in_place(arg, storage))`` --
    new protocol for fixed-size byte buffers, mutated in place via
    ``struct.Struct.pack_into`` on a ``memoryview`` of the storage.

CallState detects the protocol per slot at init time
(``hasattr(storage, 'value')``) and dispatches accordingly on the hot
path.  No allocation per call; the layout buffer is packed with shape /
stride values read directly from torch.Tensor's ``.shape`` /
``.stride()`` (no DLPack export).

Verification:
* tests/unit/ + tests/system/         469 passed, 0 failed
* tests/kernels/test_vec_add + test_preshuffle_gemm + test_moe_gemm +
  test_moe_blockscale                 457 passed, 0 failed
* Cache reuse confirmed: examples/04 unchanged kernel runs across
  M in {2048, 4096, 8192} with a single compiled artifact

Follow-up: PR #554's workarounds (mark_static API, make_buffer_tensor
guard, kernel-side explicit num_records_bytes) can now be reverted as a
separate cleanup since this commit makes them redundant.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@coderfeli coderfeli merged commit 816490e into ROCm:main May 25, 2026
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants