[None][feat] Checkpointing variant of replay for MTP for mamba models #14203

Open
hnover-nv wants to merge 78 commits into NVIDIA:main from hnover-nv:mamba_checkpointing_submit


@hnover-nv
Collaborator

@hnover-nv hnover-nv commented May 16, 2026

Description

Summary of Changes
Quick summary; see Ideas below for more detail.

We replace old per-step replay with a "checkpointing" replay, where we don't write out mamba state every step. Main changes:

  • Families of tunable kernels for the actual state update, tuned for Nemo v3 Super on B200.
  • Framework adjustments to support the new conditions on mamba cache management reflecting when we write state. Unfortunate switch to double-buffered old_x as Triton can't tell two derived memory addresses are actually the same, leading to a read hazard.
  • New fields on MambaMetadata that track the number of requests in a batch that will want to write checkpoints this step, plus a tensor that stores per-request data for the batch, sorted into writes and nowrites.
  • Removal of the benchmark script. It got insanely complicated to support efficient tuning runs. A version may come back some day.

Timing results
These are the times from the start of conv1d to the end of the last replay kernel. The replay state update is broken up into 2 or 3 kernels, and conv1d PDLs into the first one. This is a kernel microbenchmark that tries to simulate hot and cold inputs based on what is coming out of in_proj in the mamba2_mixer and what is not.

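| batch | kernel | fp32/RN | fp16/SR | fp8/SR | int8/SR |
|---|---|---|---|---|---|
| 1 | replay | 5.85 | 6.43 | 6.21 | 6.79 |
| | old replay | 7.00 | 7.32 | | |
| | reduction vs. old | +16.4% | +12.2% | | |
| 2 | replay | 6.80 | 6.93 | 7.00 | 7.22 |
| | old replay | 7.35 | 7.36 | | |
| | reduction vs. old | +7.5% | +5.8% | | |
| 4 | replay | 7.23 | 7.02 | 7.08 | 7.72 |
| | old replay | 7.82 | 7.69 | | |
| | reduction vs. old | +7.5% | +8.6% | | |
| 8 | replay | 8.09 | 7.76 | 7.56 | 8.29 |
| | old replay | 9.12 | 8.88 | | |
| | reduction vs. old | +11.3% | +12.6% | | |
| 16 | replay | 9.59 | 9.17 | 8.71 | 10.02 |
| | old replay | 12.67 | 11.78 | | |
| | reduction vs. old | +24.3% | +22.2% | | |
| 32 | replay | 13.39 | 12.50 | 10.83 | 13.03 |
| | old replay | 22.15 | 17.59 | | |
| | reduction vs. old | +39.5% | +28.9% | | |
| 64 | replay | 19.16 | 17.43 | 15.76 | 17.84 |
| | old replay | 26.69 | 25.20 | | |
| | reduction vs. old | +28.2% | +30.8% | | |
| 128 | replay | 29.65 | 25.50 | 23.67 | 27.58 |
| | old replay | 44.49 | 40.80 | | |
| | reduction vs. old | +33.3% | +37.5% | | |
| 256 | replay | 49.64 | 42.44 | 38.77 | 45.11 |
| | old replay | 84.85 | 76.35 | | |
| | reduction vs. old | +41.5% | +44.4% | | |
| 512 | replay | 88.46 | 72.10 | 68.62 | 78.06 |
| | old replay | 162.78 | 146.12 | | |
| | reduction vs. old | +45.7% | +50.7% | | |
| 1024 | replay | 168.68 | 131.89 | 123.81 | 142.73 |
| | old replay | 320.07 | 285.84 | | |
| | reduction vs. old | +47.3% | +53.9% | | |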

Ideas

Building on #13453, here we take the replay idea a step further, to a "checkpointing" replay. This replaces the old replay version in our code.

This section walks through the ideas, skip below to the tl;dr summary.

To summarize "old" replay, we would save a mamba state every step. It just lagged what would be the true state by one, because we don't know how many to accept. Compared to the prior method of saving every possible state at each step, we saved substantially on memory traffic in the kernel and at step-end when we'd otherwise copy the "winning" state over. And needing to only materialize one state let us save further by skipping a lot of computation, and accelerating what remained with tensor cores.

But there's no reason we have to write even one state every step. We can have an ssm state checkpoint arbitrarily far back, store all the needed inputs since, and reconstruct the state. When we don't accept tokens, we just overwrite the unaccepted ones with the next step's new ones. Of course, in the limit, that just becomes linear-ish attention where you kept the whole KV cache.

How often should we checkpoint? One could try to balance the memory traffic of the replay window with the state. But we're actually not just memory bound. The easiest choice is that our tensor core usage requires us to pad all our dimension T (draft length + 1) operations to at least 16, so that's a natural history size. In the prior PR you can see that the old kernels' runtimes were quite flat with respect to T up to 16, and that had larger T for both state replay and output generation. Here we replay up to 16 from history but only output T.

The algorithm, then, is: every step, if the number of previously accepted tokens (PNAT) + T > window size, we can't store the T new tokens, so we need to write a checkpoint and restart the history buffer. Otherwise we have room to append.
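To make the write/nowrite rule concrete, here is a minimal host-side sketch of the decision as I read it from the paragraph above. The function name `plan_steps`, its arguments, and the assumption that PNAT starts at 0 are all illustrative and not taken from the PR's code; the real logic lives in the cache manager and kernels.

```python
def plan_steps(accepted_per_step, T, window=16):
    """Return (pnat_before_step, writes_checkpoint) for each decode step.

    accepted_per_step[i] is how many of step i's T tokens end up accepted.
    This is a hypothetical input standing in for the verifier's result.
    """
    pnat = 0  # accepted tokens currently held in the history buffer (assumed start)
    plan = []
    for accepted in accepted_per_step:
        if not 0 <= accepted <= T:
            raise ValueError("accepted must be in [0, T]")
        # Checkpoint exactly when the T new tokens no longer fit in the window.
        write = pnat + T > window
        plan.append((pnat, write))
        if write:
            # History is folded into the checkpoint; new tokens restart the buffer.
            pnat = accepted
        else:
            # New tokens are appended after the accepted history.
            pnat += accepted
    return plan
```

With T = 4 and every token accepted, `plan_steps([4, 4, 4, 4, 4, 2], T=4)` appends for four steps, writes a checkpoint on the fifth (PNAT = 16), and then starts a fresh history.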

Already this approach saves time. We don't issue instructions for saving the state, and for quantized, stochastically rounded, or block scaled SSM dstates, we also save on that pre-writing computation. Also, by quantizing less often, we accumulate less quantization error, which means we could quantize more aggressively. For example, perhaps int8 w/ stochastic rounding is viable if it only happens every 8 tokens, but not every token. Speculative decoding already gives some of that benefit, but it depends on acceptance rate. Here, we guarantee 16 - T tokens between checkpoints, and it can be as high as 16.
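As a reminder of why stochastic rounding is attractive for state that is re-quantized repeatedly, here is a small scalar sketch. It rounds to the nearest integers rather than to an int8/fp8 grid, and the helper names are illustrative only; the kernels in this PR do this on-device with their own RNG.

```python
import math
import random


def stochastic_round(x, rng):
    """Round x down or up at random, with P(round up) equal to the fractional part."""
    lo = math.floor(x)
    return lo + (1 if rng.random() < x - lo else 0)


def mean_rounded(x, samples=10_000, seed=0):
    """Average of many stochastic roundings of x; approaches x itself."""
    rng = random.Random(seed)
    return sum(stochastic_round(x, rng) for _ in range(samples)) / samples
```

Round-to-nearest would map 0.3 to 0 every time, whereas the average of stochastic roundings stays close to 0.3. Each individual rounding still adds noise, which is why rounding less often (once per checkpoint rather than once per step) should help.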

This approach also unlocks a number of additional speedups:

The Rectangle
Ignoring dA/dt scaling factors, the original replay algorithm is:
last_step_state = very_old_state + (old_X^t @ old_B)
outputs = new_C @ last_step_state^t + (new_C @ new_B^t) @ new_X  # bypass all the draft states, straight to output tokens
We call that C @ B^t the CB matrix; it's T x T square lower triangular, padded to 16x16.
Here old_* are last step's new_*, except we only take the first k of them, where k is the number of tokens accepted.

All the matrix multiplies have some dimension at T, padded to 16.
But if we aren't writing, then just as we go straight from last_step_state to output tokens, we can actually go straight from very_old_state. If you just plug last_step_state in and expand, you get:
outputs = new_C @ very_old_state^t + (new_C @ (old_B cat new_B)^t) @ (old_X cat new_X)
Now the CB matrix is a rectangular lower triangular (PNAT + T) x T, and this fits in our window size (16) exactly when PNAT + T <= window_size. Oh look, our non-checkpointing condition!
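The identity can be checked numerically. The sketch below is a NumPy reference under the same simplification (no dA/dt scaling), with shapes assumed as X: (tokens, dim), B and C: (tokens, dstate), state: (dim, dstate). Function names are illustrative, "cat" is concatenation along the token axis, and the explicit mask is my reading of "rectangular lower triangular": every new token sees all accepted old tokens plus the new tokens up to itself.

```python
import numpy as np


def outputs_two_step(state0, old_x, old_b, new_x, new_b, new_c):
    """Original replay: materialize last step's state, then go to outputs."""
    last_step_state = state0 + old_x.T @ old_b       # (dim, dstate)
    cb = np.tril(new_c @ new_b.T)                    # (T, T), causal
    return new_c @ last_step_state.T + cb @ new_x    # (T, dim)


def outputs_rectangle(state0, old_x, old_b, new_x, new_b, new_c):
    """Rectangle variant: go straight from the old checkpoint to outputs."""
    k, t = old_x.shape[0], new_x.shape[0]
    b_all = np.concatenate([old_b, new_b], axis=0)   # (k + T, dstate)
    x_all = np.concatenate([old_x, new_x], axis=0)   # (k + T, dim)
    # Old tokens are visible to every new token; new tokens are causal.
    mask = np.concatenate([np.ones((t, k)), np.tril(np.ones((t, t)))], axis=1)
    cb = (new_c @ b_all.T) * mask                    # (T, k + T)
    return new_c @ state0.T + cb @ x_all             # (T, dim)


def random_case(k=5, t=3, dim=8, dstate=4, seed=0):
    """Random inputs: state0, old_x, old_b, new_x, new_b, new_c."""
    rng = np.random.default_rng(seed)
    shapes = [(dim, dstate), (k, dim), (k, dstate), (t, dim), (t, dstate), (t, dstate)]
    return [rng.standard_normal(s) for s in shapes]
```

Calling both functions on `random_case()` should agree to floating-point tolerance; the rectangle version never forms `last_step_state`.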

The old state computation is one of the largest matmuls, so this is a substantial savings. But due to implementation details it is not always faster at small batch sizes.

constexpr-specialized Persistent Kernels
For a given request, sometimes we now write a checkpoint and sometimes we don't. And when we don't, we may even use a different algorithm. We can specialize the nowrite algorithm choice as a constexpr we tune based on batch/dtype/rounding mode, but not write vs nowrite. Or can we? If we do specialize them, the compiler can take a lot of advantage. Nowrite should need less live state, even more so when it uses rectangle or when writing uses stochastic rounding or quantizing. The compiler generates very different code. We also build on our prior PR where we tune parameters like M (slice of a head_dim done by a block), number of warps, etc. Separate kernels let us tune these knobs for each case.

Ideally we'd launch one correctly sized grid for all the write requests, and one for the nowrite. But we're using CUDA graphs in PyTorch, so we can't vary that. Instead, we need to launch enough CTAs to cover all requests as write, and all as nowrite. Although the kernel can early out on requests that aren't its type, that overhead made specialized kernels uncompetitive.

Except, we can use persistent kernels. As long as we have enough work to fill all the SMs a few times over, then there's no real overhead. And if we add some secondary data structures to present the requests partitioned by write/nowrite, the kernels can just loop over their portion. If we pack other needed data into that structure, it's not even added latency.
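A host-side sketch of the partitioned work list and the per-CTA loop may help picture this. Everything here is illustrative (function names, plain lists instead of the metadata tensor, and a simple strided loop as one possible way for a persistent CTA to walk its portion); it only shows how a fixed-size grid can cover a variable write/nowrite split.

```python
def build_work_items(pnat_per_request, T, window=16):
    """Partition request indices into writes followed by nowrites.

    Returns (items, num_writes): items[:num_writes] need a checkpoint this
    step, items[num_writes:] do not.
    """
    writes, nowrites = [], []
    for req, pnat in enumerate(pnat_per_request):
        (writes if pnat + T > window else nowrites).append(req)
    return writes + nowrites, len(writes)


def persistent_slice(items, begin, end, cta_id, num_ctas):
    """Requests one persistent CTA visits: a strided loop over items[begin:end]."""
    return [items[i] for i in range(begin + cta_id, end, num_ctas)]
```

The write kernel would loop over `[0, num_writes)` and the nowrite kernel over `[num_writes, len(items))`, so the grid size can stay constant under CUDA graphs while the split changes every step.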

This makes specialized kernels a huge win at medium and large batches. At smaller batches a "dynamic" (i.e. unspecialized, non-constexpr) kernel still wins, probably because we can't easily get optimal PDL in PyTorch between the precomputation kernel and the two main kernels, and we didn't have time to explore emulating this with atomics. But the dynamic kernel also wins by being persistent. Similarly constrained by time, the precompute kernel is non-persistent dynamic, whereas at least one of persistent or specialized would surely help.

Other tweaks

  • Optionally support TMA loads and stores of mamba state, as a specializable knob per kernel and scenario.
  • Use of tensor-ization in Triton to replace some for loops, notably in the heads per block part of the precompute kernel. This both removes some undefined behavior and improves performance.
  • Sadly, we had to double-buffer old_x. In the main kernel, we read old x from the history buffer and new x from inputs from conv1d. Then on nowrite requests we append new x to old x, but on write requests we put new x at the start of the history buffer, on top of old x. This is the same block, and should even be the same threads doing the work, so it should be safe, but Triton does not realize the addresses are the same, so we have to double buffer. In replay, where old_x and new_x are the same size, it was presumably clearer to the compiler. Or a latent bug, hard to know! If space is an issue, we could replace this with a circular buffer of size window + T, instead of today's 2 * window.
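The buffer-selection convention for old_x can be sketched in plain Python. The class and method names are hypothetical; the only parts taken from the commit message are "read from active_buf, write to write_buf (= 1 - active_buf when is_write; = active_buf otherwise)" and the write offsets (0 on write, PNAT on nowrite).

```python
class DoubleBufferedOldX:
    """Toy model of a double-buffered history of x, one slot per token."""

    def __init__(self, window):
        self.window = window
        self.bufs = [[None] * window, [None] * window]
        self.active = 0  # index of the buffer holding the valid history

    def step(self, pnat, new_x, is_write):
        """Read the accepted history, store new_x, and return what was read."""
        old_x = self.bufs[self.active][:pnat]
        # On write, the store goes to the other buffer, so it can never
        # land on addresses the load above is still reading.
        write_buf = 1 - self.active if is_write else self.active
        offset = 0 if is_write else pnat
        assert offset + len(new_x) <= self.window, "nowrite must fit the window"
        self.bufs[write_buf][offset:offset + len(new_x)] = new_x
        self.active = write_buf
        return old_x
```

Python list slices are copies, so this model cannot reproduce the hazard itself; it only shows why the load range [0, PNAT) and the store range [0, T) stop overlapping once writes flip buffers.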

Tunings

We implemented a huge number of tunable knobs, and can pick which knobs to use based on batch size, dtype, and whether we are doing stochastic rounding. But how to pick the best? We have a mode choice of a single dynamic or a pair of specialized main kernels. Then each of these offers many tuning knobs. When we pick the specialized kernels, there are a total of 17 active knobs (more are in the code, but a few either don't compile or weren't tuned as they seemed dead), a mixture of boolean and integer. This resulted in over 61 billion possible combinations. While we'd hoped to be able to drop a lot of them, exhaustive searching of select subspaces showed none of the 17 were always dead. Obviously we can't exhaustively search the whole space for even one input configuration, let alone for all supported batch/dtype/rounding combinations. So we implemented a multi-process block coordinate descent optimization, with some tweaks to deal with cross-GPU performance drift. We also had to invest considerable effort into the actual benchmark runner: in-process CUPTI timing vs CUDA events, multi-process for compilation and CUPTI analysis, optimization of CUDA graph instantiation vs launch overhead due to GIL contention.
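For readers unfamiliar with block coordinate descent, here is a single-process toy version. It is not the optimizer used for this PR (which is multi-process and is not being submitted); the knob names, value ranges, and cost function are made up for illustration.

```python
import itertools


def block_coordinate_descent(space, blocks, cost, start, max_sweeps=10):
    """Minimize cost(config) by exhaustively searching one block of knobs at a time.

    space:  dict knob -> list of allowed values
    blocks: list of tuples of knob names searched jointly
    cost:   callable taking a config dict and returning a number (e.g. a timing)
    """
    best = dict(start)
    best_cost = cost(best)
    for _ in range(max_sweeps):
        improved = False
        for block in blocks:
            # All other knobs stay fixed at their current best values.
            for values in itertools.product(*(space[k] for k in block)):
                candidate = dict(best)
                candidate.update(zip(block, values))
                c = cost(candidate)
                if c < best_cost:
                    best, best_cost, improved = candidate, c, True
        if not improved:
            break  # a full sweep found nothing better
    return best, best_cost


# Hypothetical knobs and a synthetic cost standing in for a benchmark run.
SPACE = {"num_warps": [1, 2, 4, 8], "m": [16, 32, 64], "use_tma": [False, True]}
BLOCKS = [("num_warps", "m"), ("use_tma",)]


def toy_cost(cfg):
    return abs(cfg["num_warps"] - 4) + abs(cfg["m"] - 32) / 16 + (0.0 if cfg["use_tma"] else 0.5)
```

Each sweep evaluates 4 * 3 + 2 = 14 configurations here instead of the full 24; with 17 knobs the gap between per-block search and the full product is what makes the search tractable, at the cost of possibly stopping in a local optimum.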

These are honestly giant vibe coded messes, so we're removing the benchmarker from the repo and not attempting to submit the optimizer. No other human should have to review such things. But there's no arguing with the actual timing results.

Test Coverage

I've updated the replay unit tests to the new paradigm and added some more coverage scenarios, including better coverage of the rounding. I also added unit tests to the cache manager for the modified PNAT/double-buffer behavior, and to the mamba metadata to test its new functionality of generating the replay work items needed for efficient persistent specialized kernel pairs.


Summary by CodeRabbit

Release Notes

  • New Features

    • Enhanced replay state update support for improved speculative decoding performance.
    • Added replay work indexing and metadata tracking for state management.
  • Tests

    • Expanded test coverage for replay state updates and quantization-aware validation.
    • Added comprehensive testing for stochastic rounding correctness and multi-head scenarios.
  • Chores

    • Removed benchmark utility script.


@hnover-nv hnover-nv force-pushed the mamba_checkpointing_submit branch from c1be203 to 32ebf3f Compare May 16, 2026 05:43
@hnover-nv hnover-nv changed the title Mamba checkpointing submit @coderabbit May 16, 2026
@hnover-nv hnover-nv changed the title @coderabbit @coderabbitai May 16, 2026
@hnover-nv
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #48673 [ run ] triggered by Bot. Commit: 32ebf3f Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #48673 [ run ] completed with state FAILURE. Commit: 32ebf3f
/LLM/main/L0_MergeRequest_PR pipeline #38453 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@hnover-nv hnover-nv force-pushed the mamba_checkpointing_submit branch from 32ebf3f to 1eae71e Compare May 17, 2026 01:44
@hnover-nv hnover-nv changed the title @coderabbitai [None][feat] Checkpointing variant of replay for MTP for mamba models May 17, 2026
@hnover-nv hnover-nv force-pushed the mamba_checkpointing_submit branch 2 times, most recently from a1f2532 to bab26f2 Compare May 17, 2026 21:09
@hnover-nv
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #48787 [ run ] triggered by Bot. Commit: bab26f2 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #48787 [ run ] completed with state SUCCESS. Commit: bab26f2
/LLM/main/L0_MergeRequest_PR pipeline #38551 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

CI Report

Link to invocation

@hnover-nv hnover-nv force-pushed the mamba_checkpointing_submit branch 4 times, most recently from 8f1d0aa to 24f3dda Compare May 21, 2026 02:06
@hnover-nv
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #49541 [ run ] triggered by Bot. Commit: 24f3dda Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #49541 [ run ] completed with state SUCCESS. Commit: 24f3dda
/LLM/main/L0_MergeRequest_PR pipeline #39171 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@hnover-nv hnover-nv force-pushed the mamba_checkpointing_submit branch 4 times, most recently from 724fdf8 to 5c44835 Compare May 21, 2026 18:10
@hnover-nv
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #49775 [ run ] triggered by Bot. Commit: 5c44835 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #49775 [ run ] completed with state SUCCESS. Commit: 5c44835
/LLM/main/L0_MergeRequest_PR pipeline #39373 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@hnover-nv hnover-nv force-pushed the mamba_checkpointing_submit branch from 5c44835 to 5394fc9 Compare May 22, 2026 04:17
@hnover-nv
Collaborator Author

/bot run

hnover-nv added 22 commits May 27, 2026 16:37
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
The replay-write code path had a same-thread write-then-read race on the
old_x cache.  Inside _persistent_main_impl, the kernel loads old_x[0..PNAT)
into registers (line ~1276) for the state += dot(old_x_all, dB_scaled)
recurrence, then later stores the fresh x_all into old_x[write_offset..+T)
(line ~1466).  In the write_checkpoint=True case, write_offset=0 so the
store range [0, T) overlaps the load range [0, PNAT) on the SAME buffer
(old_x was single-buffered, unlike old_dt/old_dA_cumsum/old_B).  Triton's
alias analysis does not unify the offset expressions (offs_window vs
(write_offset + offs_t)), so the compiler is free to issue the store before
the load is fully consumed, corrupting the dot's input and producing
per-head cumulative output errors from some t onward — flaky failure rate
~40-45% on persistent_main-write-T55-fp16-16-64-128-1 in tight loops.

Fix: double-buffer old_x, matching the existing pattern for old_B / old_dt /
old_dA_cumsum.  Read from active_buf, write to write_buf (= 1 - active_buf
when is_write; = active_buf otherwise — same convention as the others).
Eliminates the same-address race entirely.

Changes:
  - Add stride_old_x_dbuf to all 3 kernel signatures (_persistent_main_impl,
    _persistent_rectangle_impl, _persistent_main_kernel)
  - Split old_x_base into old_x_read_base / old_x_write_base in both main
    impls, indexed by active_buf and write_buf respectively
  - Wrapper: pass 5-stride tuple to all dispatcher sites; update old_x
    shape assertion and docstring (cache, max_window, nheads, dim) -> (cache,
    2, max_window, nheads, dim); max_window now derived from shape[2]
  - Tests: update old_x allocation in all 5 test functions to add dbuf dim;
    update per-slot fills to write to the active buffer (matching old_B/
    old_dt/old_dA_cumsum); update inspection slicing in
    test_replay_selective_state_update and test_replay_heads_per_block_multistep
    to assert against the per-buffer expected values
  - Bench: same alloc shape update

Tests: 1215/1215 pass for the full mamba slim test suite; 1200/1200 pass
for the previously-flaky persistent_main-write-T55-fp16 cell across 200
reps x 6 dtypes (vs ~40-45% failure rate without the fix).

Perf (fp16/SR, all batches 1..1024, best-(HPB,pW) per cell):
  - Write path (prev_k=16): consistent +0.4 to +1.8% slower (max +1.79%
    at b=1024) - the real DB cost from doubled old_x memory footprint and
    different read/write buffers.
  - Nowrite path (prev_k=10): mixed direction, mostly ~0% with a few
    -4 to +3% outliers driven by autotune winners shifting around stride
    changes; not a real cost.

Compared to the alternative false-math-dependency workaround tried earlier,
this is roughly half the write-path cost (~+1% vs ~+2%) and is robust to
future Triton compiler changes.

Cross-module follow-up: mamba_cache_manager.py:1603 still allocates old_x
without a dbuf dim - that needs to gain a '2' axis in a coordinated update
with this commit (production caller will break otherwise).


Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Replace triton helper softplus with numerically stable version.

Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
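The commit above does not show the new helper, so the following is only one common numerically stable formulation of softplus, written in plain Python for illustration rather than as the Triton code from this PR.

```python
import math


def softplus_naive(x):
    """log(1 + exp(x)); math.exp overflows once x is large (around 710 for doubles)."""
    return math.log(1.0 + math.exp(x))


def softplus_stable(x):
    """max(x, 0) + log1p(exp(-|x|)); the exponent is never positive, so no overflow."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))
```

Both agree for moderate inputs, but `softplus_naive(1000.0)` raises `OverflowError` while `softplus_stable(1000.0)` returns 1000.0.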
@hnover-nv hnover-nv force-pushed the mamba_checkpointing_submit branch from a4f552c to 3a55712 Compare May 28, 2026 20:11
@hnover-nv
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #50880 [ run ] triggered by Bot. Commit: 3a55712 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #50880 [ run ] completed with state SUCCESS. Commit: 3a55712
/LLM/main/L0_MergeRequest_PR pipeline #40347 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

CI Report

Link to invocation

aleozlx pushed a commit to flashinfer-ai/flashinfer that referenced this pull request May 29, 2026
## 📌 Description

Fixes a per-launch non-determinism bug in the CUDA `checkpointing_ssu`
kernel's `must_checkpoint=True` code path, caused by a missing
cross-warp barrier between `load_state_per_warp` (Phase 0) and
`replay_state_mma` (Phase 1).

**Root cause.** `load_state_per_warp` partitions M=DIM across warps —
warp W loads rows `[W*D_PER_CTA/4, (W+1)*D_PER_CTA/4)` of `smem.state`
via cp.async. `replay_state_mma` uses a `Layout<_1, _4>` tiled MMA (1
warp on M, 4 warps on N), so every warp reads the **full M=DIM extent**
of `smem.state` when forming its `frag_h` initial value. The `load_data`
tail used only `__syncwarp()` + `__pipeline_wait_prior(0)`, which
establishes visibility *within* a single warp but not *across* warps.
Result: warp 0's replay reads rows that warps 1/2/3 may not have
committed yet, picking up partial/stale smem and producing different
output every launch.

**Symptom (pre-fix).** Hashing the post-kernel `state`, `state_scale`,
and `out` across 5 launches with bit-identical inputs gave 5 distinct
hashes per config. `state_diff` (max abs delta between launches) up to
~13 in fp16 at `batch=99`, scaling roughly with how much `smem.state`
each warp's MMA touched. Sub-ULP for small batches/heads where the race
rarely fires before the consumer thread arrives — i.e. genuine race
timing, not arithmetic noise.

**Fix.** Hoist a single `__syncthreads()` from inside `ssu_nocheckpoint`
to the dispatch site in `checkpointing_ssu_kernel`, right before the `if
(must_checkpoint)` branch. One barrier now covers cross-warp visibility
for everything both branches consume:
- (a) load_data's per-warp-partitioned `smem.state`
- (b) `smem.x` (warp 2-loaded), `smem.z` (warp 3-loaded)
- (c) `compute_CB_scaled_2warp` writes (warps 0,1)
- (d) `compute_CB_old_2warp` writes (warps 2,3, no-checkpoint path)

Net barrier-count delta is **zero** — the `__syncthreads()` previously
inside `ssu_nocheckpoint` is just relocated, and `ssu_checkpoint` was
missing one. No new compilation flags, no smem layout changes, no
perf-relevant code path touched.

**Scope.** Only the generic kernel (`kernel_checkpointing_ssu.cuh`) is
affected. The 8-bit kernel (`kernel_checkpointing_ssu_8bit.cuh`, used
for `int8`/`fp8_e4m3fn` state) goes through a separate code path and was
already deterministic — verified empirically across `mw ∈ {8, 16}, np ∈
{8, 16}`, `philox ∈ {0, 5}`, all `prev_k` values.

## 🔍 Related Issues

Discovered while investigating intermittent failures in the batch-sweep
parity test added in #3431. The race is in the existing kernel — not
introduced by that PR — but the new sweep stresses it across enough
`(batch, heads_per_group)` configurations to expose it reliably. Landing
this hotfix should let #3431's CUDA-vs-Triton parity test go green.

Adjacent prior art: NVIDIA/TensorRT-LLM#14203 fixes a structurally
similar (but mechanically distinct) bug in the Triton replay kernel —
Triton alias analysis reordering writes ahead of reads on a
single-buffered `old_x`. Our CUDA kernel doesn't have that issue because
cp.async + explicit barriers preserve in-thread ordering; our bug is
purely cross-warp visibility.

## 🚀 Pull Request Checklist

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

## 🧪 Tests

- [x] Tests have been added:
`tests/mamba/test_checkpointing_ssu.py::test_checkpointing_ssu_determinism_across_launches`
— runs the kernel 5× with bit-identical inputs and asserts that the
`state`, `state_scale` (quantized path only), and `out` tensors hash to
a single value across all launches. Four parametrizations:
  - `fp16-no_checkpoint` (`prev_k=4`, no state write)
  - `fp16-checkpoint` (`prev_k=12`, state writeback)
  - `fp8-no_checkpoint`
  - `fp8-checkpoint`
- [x] All tests are passing (`unittest`, etc.). *(Local: full
`tests/mamba/test_checkpointing_ssu.py` — please trigger CI.)*
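The hash-equality idea can be sketched without a GPU. This is a hypothetical stand-in, not the test from `tests/mamba/test_checkpointing_ssu.py`: `kernel` is any Python callable returning a list of floats, and outputs are hashed as raw bytes so that even a one-bit difference between launches is caught.

```python
import hashlib
import struct


def output_digest(values):
    """SHA-256 over the raw float64 bytes of an output vector."""
    return hashlib.sha256(struct.pack(f"<{len(values)}d", *values)).hexdigest()


def assert_deterministic(kernel, inputs, launches=5):
    """Run kernel repeatedly on identical inputs; all outputs must hash equal."""
    digests = {output_digest(kernel(list(inputs))) for _ in range(launches)}
    assert len(digests) == 1, f"{len(digests)} distinct outputs across {launches} launches"
    return digests.pop()
```

A tolerance-based comparison would accept a launch-to-launch drift of 1e-3; the set of digests does not, which is the point of the test described above.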

## Reviewer Notes

- **Why hash-equality instead of `assert_close(rtol, atol)`?** For
*self*-determinism (one kernel against itself, same inputs, same
compiled binary) there should be zero numerical noise — any difference
indicates a real race / uninitialized read / nondeterministic atomic. A
tolerance-based check would silently swallow sub-ULP variance that still
represents a correctness bug (some pre-fix configs had `state_diff ≈
1e-3` — small in magnitude but the symptom of the same race). See the
test docstring for details.
- **Performance impact**: a single extra `__syncthreads()` per CTA on
the dispatch path. Branch divergence is unchanged — `must_checkpoint` is
CTA-uniform (derived from broadcast `prev_k` + compile-time `NPREDICTED`
+ `MAX_WINDOW`).
- **Follow-up opportunity (not in this PR)**: the redundant per-warp
cp.async loads of `old_x` / `old_B` / scalar `old_dt` / `old_cumAdt` in
`load_data` were a defensive measure for the missing-sync regime. With
the dispatch-site barrier in place, these could be partitioned across
all 128 threads to reclaim the 3× redundant cp.async issue cost. Out of
scope for the hotfix.


## Summary by CodeRabbit

* **Bug Fixes**
* Unified GPU kernel synchronization to ensure correct cross-warp
shared-memory visibility, preventing inconsistent outcomes between
checkpointing and non-checkpointing executions.

* **Tests**
* Added a determinism regression test that validates bit-exact,
launch-to-launch consistency across multiple runs, data types, and
checkpointing configurations.
