[None][feat] Checkpointing variant of replay for MTP for mamba models#14203
[None][feat] Checkpointing variant of replay for MTP for mamba models#14203hnover-nv wants to merge 78 commits into
Conversation
c1be203 to
32ebf3f
Compare
|
/bot run |
|
PR_Github #48673 [ run ] triggered by Bot. Commit: |
|
PR_Github #48673 [ run ] completed with state
|
32ebf3f to
1eae71e
Compare
a1f2532 to
bab26f2
Compare
|
/bot run |
|
PR_Github #48787 [ run ] triggered by Bot. Commit: |
|
PR_Github #48787 [ run ] completed with state |
8f1d0aa to
24f3dda
Compare
|
/bot run |
|
PR_Github #49541 [ run ] triggered by Bot. Commit: |
|
PR_Github #49541 [ run ] completed with state
|
724fdf8 to
5c44835
Compare
|
/bot run |
|
PR_Github #49775 [ run ] triggered by Bot. Commit: |
|
PR_Github #49775 [ run ] completed with state
|
5c44835 to
5394fc9
Compare
|
/bot run |
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
The replay-write code path had a same-thread write-then-read race on the
old_x cache. Inside _persistent_main_impl, the kernel loads old_x[0..PNAT)
into registers (line ~1276) for the state += dot(old_x_all, dB_scaled)
recurrence, then later stores the fresh x_all into old_x[write_offset..+T)
(line ~1466). In the write_checkpoint=True case, write_offset=0 so the
store range [0, T) overlaps the load range [0, PNAT) on the SAME buffer
(old_x was single-buffered, unlike old_dt/old_dA_cumsum/old_B). Triton's
alias analysis does not unify the offset expressions (offs_window vs
(write_offset + offs_t)), so the compiler is free to issue the store before
the load is fully consumed, corrupting the dot's input and producing
per-head cumulative output errors from some t onward — flaky failure rate
~40-45% on persistent_main-write-T55-fp16-16-64-128-1 in tight loops.
Fix: double-buffer old_x, matching the existing pattern for old_B / old_dt /
old_dA_cumsum. Read from active_buf, write to write_buf (= 1 - active_buf
when is_write; = active_buf otherwise — same convention as the others).
Eliminates the same-address race entirely.
Changes:
- Add stride_old_x_dbuf to all 3 kernel signatures (_persistent_main_impl,
_persistent_rectangle_impl, _persistent_main_kernel)
- Split old_x_base into old_x_read_base / old_x_write_base in both main
impls, indexed by active_buf and write_buf respectively
- Wrapper: pass 5-stride tuple to all dispatcher sites; update old_x
shape assertion and docstring (cache, max_window, nheads, dim) -> (cache,
2, max_window, nheads, dim); max_window now derived from shape[2]
- Tests: update old_x allocation in all 5 test functions to add dbuf dim;
update per-slot fills to write to the active buffer (matching old_B/
old_dt/old_dA_cumsum); update inspection slicing in
test_replay_selective_state_update and test_replay_heads_per_block_multistep
to assert against the per-buffer expected values
- Bench: same alloc shape update
Tests: 1215/1215 pass for the full mamba slim test suite; 1200/1200 pass
for the previously-flaky persistent_main-write-T55-fp16 cell across 200
reps x 6 dtypes (vs ~40-45% failure rate without the fix).
Perf (fp16/SR, all batches 1..1024, best-(HPB,pW) per cell):
- Write path (prev_k=16): consistent +0.4 to +1.8% slower (max +1.79%
at b=1024) - the real DB cost from doubled old_x memory footprint and
different read/write buffers.
- Nowrite path (prev_k=10): mixed direction, mostly ~0% with a few
-4 to +3% outliers driven by autotune winners shifting around stride
changes; not a real cost.
Compared to the alternative false-math-dependency workaround tried earlier,
this is roughly half the write-path cost (~+1% vs ~+2%) and is robust to
future Triton compiler changes.
Cross-module follow-up: mamba_cache_manager.py:1603 still allocates old_x
without a dbuf dim - that needs to gain a '2' axis in a coordinated update
with this commit (production caller will break otherwise).
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Replace triton helper softplus with numerically stable version. Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
a4f552c to
3a55712
Compare
|
/bot run |
|
PR_Github #50880 [ run ] triggered by Bot. Commit: |
|
PR_Github #50880 [ run ] completed with state |
<!-- .github/pull_request_template.md -->
## 📌 Description
Fixes a per-launch non-determinism bug in the CUDA `checkpointing_ssu`
kernel's `must_checkpoint=True` code path, caused by a missing
cross-warp barrier between `load_state_per_warp` (Phase 0) and
`replay_state_mma` (Phase 1).
**Root cause.** `load_state_per_warp` partitions M=DIM across warps —
warp W loads rows `[W*D_PER_CTA/4, (W+1)*D_PER_CTA/4)` of `smem.state`
via cp.async. `replay_state_mma` uses a `Layout<_1, _4>` tiled MMA (1
warp on M, 4 warps on N), so every warp reads the **full M=DIM extent**
of `smem.state` when forming its `frag_h` initial value. The `load_data`
tail used only `__syncwarp()` + `__pipeline_wait_prior(0)`, which
establishes visibility *within* a single warp but not *across* warps.
Result: warp 0's replay reads rows that warps 1/2/3 may not have
committed yet, picking up partial/stale smem and producing different
output every launch.
**Symptom (pre-fix).** Hashing the post-kernel `state`, `state_scale`,
and `out` across 5 launches with bit-identical inputs gave 5 distinct
hashes per config. `state_diff` (max abs delta between launches) up to
~13 in fp16 at `batch=99`, scaling roughly with how much `smem.state`
each warp's MMA touched. Sub-ULP for small batches/heads where the race
rarely fires before the consumer thread arrives — i.e. genuine race
timing, not arithmetic noise.
**Fix.** Hoist a single `__syncthreads()` from inside `ssu_nocheckpoint`
to the dispatch site in `checkpointing_ssu_kernel`, right before the `if
(must_checkpoint)` branch. One barrier now covers cross-warp visibility
for everything both branches consume:
- (a) load_data's per-warp-partitioned `smem.state`
- (b) `smem.x` (warp 2-loaded), `smem.z` (warp 3-loaded)
- (c) `compute_CB_scaled_2warp` writes (warps 0,1)
- (d) `compute_CB_old_2warp` writes (warps 2,3, no-checkpoint path)
Net barrier-count delta is **zero** — the `__syncthreads()` previously
inside `ssu_nocheckpoint` is just relocated, and `ssu_checkpoint` was
missing one. No new compilation flags, no smem layout changes, no
perf-relevant code path touched.
**Scope.** Only the generic kernel (`kernel_checkpointing_ssu.cuh`) is
affected. The 8-bit kernel (`kernel_checkpointing_ssu_8bit.cuh`, used
for `int8`/`fp8_e4m3fn` state) goes through a separate code path and was
already deterministic — verified empirically across `mw ∈ {8, 16}, np ∈
{8, 16}`, `philox ∈ {0, 5}`, all `prev_k` values.
## 🔍 Related Issues
Discovered while investigating intermittent failures in the batch-sweep
parity test added in #3431. The race is in the existing kernel — not
introduced by that PR — but the new sweep stresses it across enough
`(batch, heads_per_group)` configurations to expose it reliably. Landing
this hotfix should let #3431's CUDA-vs-Triton parity test go green.
Adjacent prior art: NVIDIA/TensorRT-LLM#14203 fixes a structurally
similar (but mechanically distinct) bug in the Triton replay kernel —
Triton alias analysis reordering writes ahead of reads on a
single-buffered `old_x`. Our CUDA kernel doesn't have that issue because
cp.async + explicit barriers preserve in-thread ordering; our bug is
purely cross-warp visibility.
## 🚀 Pull Request Checklist
### ✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.
## 🧪 Tests
- [x] Tests have been added:
`tests/mamba/test_checkpointing_ssu.py::test_checkpointing_ssu_determinism_across_launches`
— runs the kernel 5× with bit-identical inputs and asserts that the
`state`, `state_scale` (quantized path only), and `out` tensors hash to
a single value across all launches. Four parametrizations:
- `fp16-no_checkpoint` (`prev_k=4`, no state write)
- `fp16-checkpoint` (`prev_k=12`, state writeback)
- `fp8-no_checkpoint`
- `fp8-checkpoint`
- [x] All tests are passing (`unittest`, etc.). *(Local: full
`tests/mamba/test_checkpointing_ssu.py` — please trigger CI.)*
## Reviewer Notes
- **Why hash-equality instead of `assert_close(rtol, atol)`?** For
*self*-determinism (one kernel against itself, same inputs, same
compiled binary) there should be zero numerical noise — any difference
indicates a real race / uninitialized read / nondeterministic atomic. A
tolerance-based check would silently swallow sub-ULP variance that still
represents a correctness bug (some pre-fix configs had `state_diff ≈
1e-3` — small in magnitude but the symptom of the same race). See the
test docstring for details.
- **Performance impact**: a single extra `__syncthreads()` per CTA on
the dispatch path. Branch divergence is unchanged — `must_checkpoint` is
CTA-uniform (derived from broadcast `prev_k` + compile-time `NPREDICTED`
+ `MAX_WINDOW`).
- **Follow-up opportunity (not in this PR)**: the redundant per-warp
cp.async loads of `old_x` / `old_B` / scalar `old_dt` / `old_cumAdt` in
`load_data` were a defensive measure for the missing-sync regime. With
the dispatch-site barrier in place, these could be partitioned across
all 128 threads to reclaim the 3× redundant cp.async issue cost. Out of
scope for the hotfix.
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **Bug Fixes**
* Unified GPU kernel synchronization to ensure correct cross-warp
shared-memory visibility, preventing inconsistent outcomes between
checkpointing and non-checkpointing executions.
* **Tests**
* Added a determinism regression test that validates bit-exact,
launch-to-launch consistency across multiple runs, data types, and
checkpointing configurations.
<!-- review_stack_entry_start -->
[](https://app.coderabbit.ai/change-stack/flashinfer-ai/flashinfer/pull/3439?utm_source=github_walkthrough&utm_medium=github&utm_campaign=change_stack)
<!-- review_stack_entry_end -->
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
<!-- .github/pull_request_template.md -->
## 📌 Description
Fixes a per-launch non-determinism bug in the CUDA `checkpointing_ssu`
kernel's `must_checkpoint=True` code path, caused by a missing
cross-warp barrier between `load_state_per_warp` (Phase 0) and
`replay_state_mma` (Phase 1).
**Root cause.** `load_state_per_warp` partitions M=DIM across warps —
warp W loads rows `[W*D_PER_CTA/4, (W+1)*D_PER_CTA/4)` of `smem.state`
via cp.async. `replay_state_mma` uses a `Layout<_1, _4>` tiled MMA (1
warp on M, 4 warps on N), so every warp reads the **full M=DIM extent**
of `smem.state` when forming its `frag_h` initial value. The `load_data`
tail used only `__syncwarp()` + `__pipeline_wait_prior(0)`, which
establishes visibility *within* a single warp but not *across* warps.
Result: warp 0's replay reads rows that warps 1/2/3 may not have
committed yet, picking up partial/stale smem and producing different
output every launch.
**Symptom (pre-fix).** Hashing the post-kernel `state`, `state_scale`,
and `out` across 5 launches with bit-identical inputs gave 5 distinct
hashes per config. `state_diff` (max abs delta between launches) up to
~13 in fp16 at `batch=99`, scaling roughly with how much `smem.state`
each warp's MMA touched. Sub-ULP for small batches/heads where the race
rarely fires before the consumer thread arrives — i.e. genuine race
timing, not arithmetic noise.
**Fix.** Hoist a single `__syncthreads()` from inside `ssu_nocheckpoint`
to the dispatch site in `checkpointing_ssu_kernel`, right before the `if
(must_checkpoint)` branch. One barrier now covers cross-warp visibility
for everything both branches consume:
- (a) load_data's per-warp-partitioned `smem.state`
- (b) `smem.x` (warp 2-loaded), `smem.z` (warp 3-loaded)
- (c) `compute_CB_scaled_2warp` writes (warps 0,1)
- (d) `compute_CB_old_2warp` writes (warps 2,3, no-checkpoint path)
Net barrier-count delta is **zero** — the `__syncthreads()` previously
inside `ssu_nocheckpoint` is just relocated, and `ssu_checkpoint` was
missing one. No new compilation flags, no smem layout changes, no
perf-relevant code path touched.
**Scope.** Only the generic kernel (`kernel_checkpointing_ssu.cuh`) is
affected. The 8-bit kernel (`kernel_checkpointing_ssu_8bit.cuh`, used
for `int8`/`fp8_e4m3fn` state) goes through a separate code path and was
already deterministic — verified empirically across `mw ∈ {8, 16}, np ∈
{8, 16}`, `philox ∈ {0, 5}`, all `prev_k` values.
## 🔍 Related Issues
Discovered while investigating intermittent failures in the batch-sweep
parity test added in #3431. The race is in the existing kernel — not
introduced by that PR — but the new sweep stresses it across enough
`(batch, heads_per_group)` configurations to expose it reliably. Landing
this hotfix should let #3431's CUDA-vs-Triton parity test go green.
Adjacent prior art: NVIDIA/TensorRT-LLM#14203 fixes a structurally
similar (but mechanically distinct) bug in the Triton replay kernel —
Triton alias analysis reordering writes ahead of reads on a
single-buffered `old_x`. Our CUDA kernel doesn't have that issue because
cp.async + explicit barriers preserve in-thread ordering; our bug is
purely cross-warp visibility.
## 🚀 Pull Request Checklist
### ✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.
## 🧪 Tests
- [x] Tests have been added:
`tests/mamba/test_checkpointing_ssu.py::test_checkpointing_ssu_determinism_across_launches`
— runs the kernel 5× with bit-identical inputs and asserts that the
`state`, `state_scale` (quantized path only), and `out` tensors hash to
a single value across all launches. Four parametrizations:
- `fp16-no_checkpoint` (`prev_k=4`, no state write)
- `fp16-checkpoint` (`prev_k=12`, state writeback)
- `fp8-no_checkpoint`
- `fp8-checkpoint`
- [x] All tests are passing (`unittest`, etc.). *(Local: full
`tests/mamba/test_checkpointing_ssu.py` — please trigger CI.)*
## Reviewer Notes
- **Why hash-equality instead of `assert_close(rtol, atol)`?** For
*self*-determinism (one kernel against itself, same inputs, same
compiled binary) there should be zero numerical noise — any difference
indicates a real race / uninitialized read / nondeterministic atomic. A
tolerance-based check would silently swallow sub-ULP variance that still
represents a correctness bug (some pre-fix configs had `state_diff ≈
1e-3` — small in magnitude but the symptom of the same race). See the
test docstring for details.
- **Performance impact**: a single extra `__syncthreads()` per CTA on
the dispatch path. Branch divergence is unchanged — `must_checkpoint` is
CTA-uniform (derived from broadcast `prev_k` + compile-time `NPREDICTED`
+ `MAX_WINDOW`).
- **Follow-up opportunity (not in this PR)**: the redundant per-warp
cp.async loads of `old_x` / `old_B` / scalar `old_dt` / `old_cumAdt` in
`load_data` were a defensive measure for the missing-sync regime. With
the dispatch-site barrier in place, these could be partitioned across
all 128 threads to reclaim the 3× redundant cp.async issue cost. Out of
scope for the hotfix.
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **Bug Fixes**
* Unified GPU kernel synchronization to ensure correct cross-warp
shared-memory visibility, preventing inconsistent outcomes between
checkpointing and non-checkpointing executions.
* **Tests**
* Added a determinism regression test that validates bit-exact,
launch-to-launch consistency across multiple runs, data types, and
checkpointing configurations.
<!-- review_stack_entry_start -->
[](https://app.coderabbit.ai/change-stack/flashinfer-ai/flashinfer/pull/3439?utm_source=github_walkthrough&utm_medium=github&utm_campaign=change_stack)
<!-- review_stack_entry_end -->
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
Description
Summary of Changes
Quick summary, see ideas below for more detail.
We replace old per-step replay with a "checkpointing" replay, where we don't write out mamba state every step. Main changes:
Timing results
These are the time from the start of conv1d to the end of the last replay kernel. Replay state update is broken up into 2 or 3 kernels, and conv1d PDLs into the first one. It's a kernel microbenchmark that tries to simulate hot and cold inputs based on what is coming out of in_proj in the mamba2_mixer and what is not.
Ideas
Building on #13453, here we take the replay idea a step further, to a "checkpointing" replay. This replaces the old replay version in our code.
This section walks through the ideas, skip below to the tl;dr summary.
To summarize "old" replay, we would save a mamba state every step. It just lagged what would be the true state by one, because we don't know how many to accept. Compared to the prior method of saving every possible state at each step, we saved substantially on memory traffic in the kernel and at step-end when we'd otherwise copy the "winning" state over. And needing to only materialize one state let us save further by skipping a lot of computation, and accelerating what remained with tensor cores.
But there's no reason we have to write even one state every step. We can have an ssm state checkpoint arbitrarily far back, store all the needed inputs since, and reconstruct the state. When we don't accept tokens, we just overwrite the unaccepted ones with the next step's new ones. Of course, in the limit, that just becomes linear-ish attention where you kept the whole KV cache.
How often should we checkpoint? One could try to balance the memory traffic of the replay window with the state. But we're actually not just memory bound. The easiest choice is that our tensor core usage requires us to pad all our dimension T (draft length + 1) operations to at least 16, so that's a natural history size. In the prior PR you can see that the old kernels' runtimes were quite flat with respect to T up to 16, and that had larger T for both state replay and output generation. Here we replay up to 16 from history but only output T.
The algorithm, then is every step, if the number of previously accepted tokens (PNAT)+ T > window size, we can't store the T new tokens so we need to write a checkpoint and restart the history buffer. Otherwise we have room to append.
Already this approach saves time. We don't issue instructions for saving the state, and for quantized, stochastically rounded, or block scaled SSM dstates, we also save on that pre-writing computation. Also, by quantizing less often, we accumulate less quantization error, which means we could quantize more aggressively. For example, perhaps int8 w/ stochastic rounding is viable if it only happens every 8 tokens, but not every token. Speculative decoding already gives some of that benefit, but it depends on acceptance rate. Here, we guarantee 16 - T tokens between checkpoints, and it can be as high as 16.
This approach also unlocks a number of additional speedups:
The Rectangle
Ignoring dA/dt scaling factors, the original replay algorithm is:
last_step_state = very_old_state + (old_X^t @ old_B)
outputs = new_C @ one_back_state^t + (new_C @ new_B^t) @ new_X. # bypass all the draft states, straight to output tokens. That C @ B^t we call the CB matrix, it's TxT square lower triangular, padded into 16x16.
Here old_* are last step's new_*, except we only take the first k of them, the acceptance rate.
All the matrix multiplies have some dimension at T, padded to 16.
But if we aren't writing, then just as we go straight from one_back_state to output tokens, we can actually go straight from very_old_state. If you just plug last_step_state in and expand, you get:
outputs = new_C @ very_old_state^t + ( new_C @ (old_B cat new_B)^t) @ (old_X + new_X)
Now the CB matrix is a rectangular lower triangular (PNAT + T) x T, and this fits in our window size (16) exactly when PNAT + T <= window_size. Oh look, our non-checkpointing condition!
The old state computation is one of the largest matmuls, so this is a substantial savings. But due to implementation details it is not always faster at small batch sizes.
constexpr-specialized Persistent Kernels
For a given request, sometimes we now write a checkpoint and sometimes don't. And when we don't we may even use a different algorithm. We can specialize nowrite algorithm choice as a constexpr we tune based on batch/dtype/rounding mode, but not write vs nowrite. Or can we? If we do specialize them, the compiler can take a lot of advantage. Nowrite should need less live state, even more so when it uses rectangle or writing uses stochastic rounding or quantizing. The compiler generates very different code. We also build on our prior PR where we tune parameters like M (slice of a head_dim done by a block), number of warps, etc. Separate kernels lets us tune these knobs for each case.
Ideally we'd launch one correctly sized grid for all the write requests, and one for the nowrite. But we're using cuda graphs in pytorch, so we can't vary that. Instead, we need to launch enough CTAs to cover all requests as write, and all as nowrite. Although the kernel can early out on requests that aren't its type, that overhead made specialized kernels uncompetative.
Except, we can use persistent kernels. As long as we have enough work to fill all the SMs a few times over, then there's no real overhead. And if we add some secondary data structures to present the requests partitioned by write/nowrite, the kernels can just loop over their portion. If we pack other needed data into that structure, it's not even added latency.
This makes specialized kernels a huge win at medium and large batches. At smaller batches a "dynamic"(i.e. unspecialized, non-constexpr) kernel still wins, probably because we can't easily optimally PDL in PyTorch between the precomputation kernel and the two main kernels. And we didn't have time to explore emulating this with atomics. But the dynamic kernel also wins by being persistent. Similarly constrained by time, the precompute kernel is non-persistent dynamic, whereas at least one of persistent or specialized would surely help.
Other tweaks
Tunings
We implemented a huge number of tunable knobs, and can pick which knobs to use based on batch size, dtype, and whether we are doing stochastic rounding. But how to pick the best? We have a mode choice of a single dynamic or a pair of specialized main kernels. Then each of these offers many tuning knobs. When we pick the specialized kernels, ther are a total of of 17 active knobs (more are in the code, but a few either don't compile or weren't tuned as they seemed dead), a mixture of boolean and integer. This resulted in over 61 billion possible combinations. While we'd hoped to be able to drop a lot of them, exhaustive searching of select subspaces showed none of the 17 were always dead. Obviously we can't exhaustively search the whole space for even one input cofiguration, let alone for all supported batch/dtype/rounding combinations. So we implemented a multi-process block coordinate descent optimization, with some tweaks to deal with cross-GPU performance drift. We also had to invest considerable effort into the actual benchmark runner: in-process CUPTI timing vs cuda events, multi-process for compilation and cupti analysis, optimization of cuda graph instantiation vs lauch overhead due to GIL contention.
These are honestly giant vibe coded messes, so we're removing the benchmarker from the repo and not attempting to submit the optimizer. No other human should have to review such things. But there's no arguing with the actual timing results.
Test Coverage
I've updated the replay unit tests to the new paradigm and added some more coverage scenarios, including better coverage of the rounding. Also added unit tests to cache manager for modified PNAT/double buffer behavior, and mamba metadata to test its new functionality of geneating the replay work items, needed for efficient persistent specialized kernel pairs.
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
If PR introduces API changes, an appropriate PR label is added - either
api-compatibleorapi-breaking. Forapi-breaking, includeBREAKINGin the PR title.Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.Summary by CodeRabbit
Release Notes
New Features
Tests
Chores