
feat(cpu): implement RadixAttnSwaSink with sliding window attention support#516

Merged
chenghuaWang merged 2 commits into UbiquitousLearning:v2 from chenghuaWang:v2
Nov 12, 2025

Conversation

Collaborator

@chenghuaWang chenghuaWang commented Nov 12, 2025

  • Add new CPU operator RadixAttnSwaSink for radix attention with sink tokens
  • Implement forward kernel fwd_bshd for BSHD tensor layout with architecture-specific optimizations
  • Support both ARM64 and x86 architectures with SIMD instructions
  • Integrate with prefix cache system for efficient KV cache management
  • Add comprehensive test cases covering prefill, decode and append modes
  • Fix tensor rank validation and output shape calculation
  • Update library loading logic to support Linux platform correctly

Summary by CodeRabbit

  • New Features

    • Added sliding-window attention with sink token integration for efficient KV sequence handling.
    • Introduced prefix cache system for CPU-based key-value memory management supporting full and sliding-window modes.
  • Tests

    • Added comprehensive test suite for sliding-window attention with various cache configurations.
  • Improvements

    • Optimized architecture-specific kernels for ARM64 and X86 platforms.
    • Improved platform detection for CPU extension loading on Linux systems.

chenghuaWang and others added 2 commits November 12, 2025 19:45
…upport

- Add new CPU operator RadixAttnSwaSink for radix attention with sink tokens
- Implement forward kernel `fwd_bshd` for BSHD tensor layout with architecture-specific optimizations
- Support both ARM64 and x86 architectures with SIMD instructions
- Integrate with prefix cache system for efficient KV cache management
- Add comprehensive test cases covering prefill, decode and append modes
- Fix tensor rank validation and output shape calculation
- Update library loading logic to support Linux platform correctly
Contributor

coderabbitai Bot commented Nov 12, 2025

Walkthrough

Introduces a sliding-window radix attention implementation (RadixAttnSwaSink) with architecture-specific forward kernels (fwd_bshd), a CPU-based prefix cache system for KV management, test infrastructure, and updates kernel function naming conventions and platform-specific extension loading for Linux.

Changes

Cohort / File(s) Summary
Radix Attention Kernel Implementation
mllm-ext-opset/cpu/radix_swa_sink/radix_swa_sink_fwd_bshd.hpp
New forward pass kernel for batched head-varied sliding-window attention with per-head KV token mapping, dot-product softmax scaling, and sink token integration via s_aux.
Radix Attention Operation Integration
mllm-ext-opset/cpu/radix_swa_sink/RadixAttnSwaSink.cpp, mllm-ext-opset/cpu/radix_swa_sink/tests/main.cpp
Replaces TODOs with architecture-specific fwd_bshd kernel calls, updates input validation for rank-1 K/V tensors, refines output tensor shape allocation, and adds comprehensive multi-scenario test suite with reference eager implementation.
Radix Attention Kernel Dispatch
mllm/backends/cpu/kernels/common/radix_attn/fwd_bshd.hpp, mllm/backends/cpu/ops/RadixAttnOp.cpp
Renames fwd_bhsd to fwd_bshd and updates dispatcher calls across ARM64/X86 architectures.
CPU Prefix Cache Framework
mllm/nn/lmcache/PrefixCache.hpp, mllm/nn/lmcache/PrefixCache.cpp
Introduces PrefixCache base class and CpuPrefixCache implementation with options-based configuration, per-layer memory allocation, validation hooks, and dual-mode initialization for full-attention and sliding-window scenarios.
Platform Extension Handling
mllm/mllm.cpp
Updates loadExtensionOpset platform check to consolidate Android and Linux under ".so" extension loading.

Sequence Diagram

sequenceDiagram
    participant Test as Test Suite
    participant RadixOp as RadixAttnSwaSink Op
    participant Kernel as fwd_bshd Kernel
    participant Cache as CPU Prefix Cache
    
    Test->>Cache: Initialize (full-attention or sliding-window mode)
    activate Cache
    Cache->>Cache: Setup per-layer allocators<br/>(HiCPUAllocator)
    deactivate Cache
    
    Test->>Test: Generate Q, K_idx, V_idx, s_aux tensors
    Test->>Cache: Allocate K/V storage per layer
    Test->>Cache: Store K/V via physical addresses
    
    Test->>RadixOp: Submit RadixAttnSwaSink operation<br/>(Q, K_idx, V_idx, s_aux, options)
    activate RadixOp
    RadixOp->>RadixOp: Validate input tensor ranks<br/>(Q: rank-4, K/V: rank-1)
    RadixOp->>Kernel: fwd_bshd(B, H_Q, H_KV, S_Q, S_KV,<br/>D_QK, D_V, window, seq_len,<br/>Q, K_ptrs, V_ptrs, s_aux, output)
    activate Kernel
    Kernel->>Kernel: Per Q-head parallelization
    Kernel->>Kernel: Map KV/V tokens via indices
    Kernel->>Kernel: Compute Q @ K with sliding window
    Kernel->>Kernel: Apply per-head softmax scaling
    Kernel->>Kernel: Integrate sink token from s_aux
    Kernel->>Kernel: Accumulate to output
    deactivate Kernel
    RadixOp->>RadixOp: Allocate output with<br/>shape [B, H_Q, D_V]
    deactivate RadixOp
    
    Test->>Test: Compare output vs. reference<br/>(allClose validation)
    alt Match
        Test->>Test: Pass
    else Mismatch
        Test->>Test: Print diagnostics
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

  • fwd_bshd kernel logic (radix_swa_sink_fwd_bshd.hpp): Dense parallelization, per-head token mapping, softmax scaling, and sink token integration require careful validation of correctness and numerical stability.
  • PrefixCache memory management (PrefixCache.cpp/hpp): Dual-mode initialization (full-attention vs. sliding-window), per-layer allocator setup, and address translation logic need thorough review for consistency and correctness.
  • Test infrastructure complexity: Multiple test scenarios with dynamic mask construction, cache-based K/V handling, and cross-validation against reference eager implementation.
  • Input validation refinements: Rank constraints on K/V inputs and output shape computation merit attention for edge cases.

Possibly related PRs

Suggested reviewers

  • yirongjie
  • oreomaker
  • liang1232018

Poem

🐰 A radix tree hops through cached tokens bright,
Sliding windows catching queries mid-flight,
Per-head attention sinks with graceful care,
Prefix caches blooming everywhere,
All kernels aligned—left and right! 🌟

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 0.00%, below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed — Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed — The title clearly and concisely summarizes the main feature: implementing RadixAttnSwaSink with sliding window attention support, which aligns with the PR's core objective.


@chenghuaWang
Collaborator Author

@coderabbitai review

Contributor

coderabbitai Bot commented Nov 12, 2025

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@chenghuaWang chenghuaWang merged commit 0a97dfd into UbiquitousLearning:v2 Nov 12, 2025
3 checks passed
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
mllm-ext-opset/cpu/radix_swa_sink/RadixAttnSwaSink.cpp (2)

23-60: Guard optional S_AUX before indexing inputs[3].

forward unconditionally reads inputs[3], but reshape explicitly allows configurations with only three tensors whenever options_.s_aux_enable is false. In that scenario we dereference past the end of the inputs vector, leading to UB/crash before the kernel launches. Please either require the op to run with s_aux_enable == true (and assert that a fourth tensor is present) or handle the disabled case explicitly before dereferencing. Example fix to enforce the invariant:

+  MLLM_RT_ASSERT(options_.s_aux_enable);
+  MLLM_RT_ASSERT_GE(inputs.size(), 4);
   auto& S_AUX = inputs[3];

39-87: Fix K/V shape usage after switching them to rank-1.

reshape now enforces k.rank() == 1 and v.rank() == 1, yet forward still indexes K.shape()[1] / V.shape()[1] and even reads V.shape()[3]. That pulls garbage (or faults) for rank-1 tensors, breaking the sliding-window kernel immediately. Use the 0th dimension for the sequence length and drop the nonexistent head/dimension lookups—options_ already carries D_QK/D_V. For instance:

-  auto D_V = V.shape()[3];
-  MLLM_RT_ASSERT_EQ(H_Q, options_.q_head);
-  auto S_KV = K.shape()[1];
-  MLLM_RT_ASSERT_EQ(S_KV, V.shape()[1]);
+  MLLM_RT_ASSERT_EQ(H_Q, options_.q_head);
+  MLLM_RT_ASSERT_EQ(K.rank(), 1);
+  MLLM_RT_ASSERT_EQ(V.rank(), 1);
+  auto S_KV = K.shape()[0];
+  MLLM_RT_ASSERT_EQ(S_KV, V.shape()[0]);

Without this fix the op miscomputes S_KV (often reading random memory) and the kernel walks the wrong number of cache slots.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7a16b55 and ad13d07.

📒 Files selected for processing (8)
  • mllm-ext-opset/cpu/radix_swa_sink/RadixAttnSwaSink.cpp (3 hunks)
  • mllm-ext-opset/cpu/radix_swa_sink/radix_swa_sink_fwd_bshd.hpp (1 hunks)
  • mllm-ext-opset/cpu/radix_swa_sink/tests/main.cpp (1 hunks)
  • mllm/backends/cpu/kernels/common/radix_attn/fwd_bshd.hpp (1 hunks)
  • mllm/backends/cpu/ops/RadixAttnOp.cpp (1 hunks)
  • mllm/mllm.cpp (1 hunks)
  • mllm/nn/lmcache/PrefixCache.cpp (1 hunks)
  • mllm/nn/lmcache/PrefixCache.hpp (1 hunks)
🧰 Additional context used
🪛 Clang (14.0.6)
mllm-ext-opset/cpu/radix_swa_sink/radix_swa_sink_fwd_bshd.hpp

[error] 6-6: 'cmath' file not found

(clang-diagnostic-error)


[error] 29-29: declaration uses identifier '__ArchTag', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 29-29: declaration uses identifier '__QDType', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 29-29: declaration uses identifier '__KDType', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 29-29: declaration uses identifier '__VDType', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 29-29: declaration uses identifier '__ODType', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 29-29: declaration uses identifier '__AccDType', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 31-31: 9 adjacent parameters of 'fwd_bshd' of similar type ('int') are easily swapped by mistake

(bugprone-easily-swappable-parameters,-warnings-as-errors)


[error] 31-31: parameter name 'B' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 32-32: declaration uses identifier '__q', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 32-32: declaration uses identifier '__k', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 32-32: declaration uses identifier '__v', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 33-33: declaration uses identifier '__s_aux', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 33-33: declaration uses identifier '__out', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)


[error] 34-34: variable 'head_repeat_times' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)

mllm/backends/cpu/kernels/common/radix_attn/fwd_bshd.hpp

[error] 33-33: 6 adjacent parameters of 'fwd_bshd' of similar type ('int') are easily swapped by mistake

(bugprone-easily-swappable-parameters,-warnings-as-errors)


[error] 33-33: parameter name 'B' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 33-33: parameter name 'D' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 33-33: declaration uses identifier '__q', which is a reserved identifier

(bugprone-reserved-identifier,-warnings-as-errors)

mllm-ext-opset/cpu/radix_swa_sink/tests/main.cpp

[error] 1-1: 'cmath' file not found

(clang-diagnostic-error)


[error] 8-8: 2 adjacent parameters of 'radixAttnSWAwSink' of similar type ('const int &') are easily swapped by mistake

(bugprone-easily-swappable-parameters,-warnings-as-errors)


[error] 8-8: 9 adjacent parameters of 'radixAttnSWAwSink' of similar type are easily swapped by mistake

(bugprone-easily-swappable-parameters,-warnings-as-errors)


[error] 12-12: variable name 'id' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 27-27: 5 adjacent parameters of 'eagerSWAwSink' of similar type are easily swapped by mistake

(bugprone-easily-swappable-parameters,-warnings-as-errors)


[error] 30-30: variable name 'B' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 30-30: invalid case style for variable 'B'

(readability-identifier-naming,-warnings-as-errors)


[error] 31-31: invalid case style for variable 'S_Q'

(readability-identifier-naming,-warnings-as-errors)


[error] 32-32: invalid case style for variable 'H_Q'

(readability-identifier-naming,-warnings-as-errors)


[error] 33-33: invalid case style for variable 'S_KV'

(readability-identifier-naming,-warnings-as-errors)


[error] 34-34: invalid case style for variable 'H_KV'

(readability-identifier-naming,-warnings-as-errors)


[error] 35-35: invalid case style for variable 'D_QK'

(readability-identifier-naming,-warnings-as-errors)


[error] 47-47: variable 'scale' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 66-66: 7 adjacent parameters of 'testCase' of similar type ('int') are easily swapped by mistake

(bugprone-easily-swappable-parameters,-warnings-as-errors)


[error] 66-66: 2 adjacent parameters of 'testCase' of similar type ('int') are easily swapped by mistake

(bugprone-easily-swappable-parameters,-warnings-as-errors)


[error] 73-73: variable name 'Q' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 73-73: invalid case style for variable 'Q'

(readability-identifier-naming,-warnings-as-errors)


[error] 74-74: variable name 'K' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 74-74: invalid case style for variable 'K'

(readability-identifier-naming,-warnings-as-errors)


[error] 75-75: variable name 'V' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 75-75: invalid case style for variable 'V'

(readability-identifier-naming,-warnings-as-errors)


[error] 104-104: invalid case style for variable 'O_ref'

(readability-identifier-naming,-warnings-as-errors)


[error] 134-134: variable name 'O' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 134-134: invalid case style for variable 'O'

(readability-identifier-naming,-warnings-as-errors)


[error] 140-140: variable 'ret' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 141-141: variable 'O_ref' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 141-141: invalid case style for variable 'O_ref'

(readability-identifier-naming,-warnings-as-errors)


[error] 142-142: variable 'O' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 142-142: variable name 'O' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 142-142: invalid case style for variable 'O'

(readability-identifier-naming,-warnings-as-errors)


[error] 148-148: variable 'MLLM_MAIN' is non-const and globally accessible, consider making it const

(cppcoreguidelines-avoid-non-const-global-variables,-warnings-as-errors)


[error] 148-148: invalid case style for variable 'MLLM_MAIN'

(readability-identifier-naming,-warnings-as-errors)

mllm/nn/lmcache/PrefixCache.hpp

[error] 21-21: 'memory' file not found

(clang-diagnostic-error)


[error] 59-59: class 'PrefixCache' defines a non-default destructor but does not define a copy constructor, a copy assignment operator, a move constructor or a move assignment operator

(cppcoreguidelines-special-member-functions,-warnings-as-errors)


[error] 68-68: parameter 2 is const-qualified in the function declaration; const-qualification of parameters only has an effect in function definitions

(readability-avoid-const-params-in-decls,-warnings-as-errors)


[error] 99-99: parameter name 'fp' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 102-102: member variable 'options_' has protected visibility

(cppcoreguidelines-non-private-member-variables-in-classes,-warnings-as-errors)


[error] 103-103: member variable 'tree_' has protected visibility

(cppcoreguidelines-non-private-member-variables-in-classes,-warnings-as-errors)


[error] 111-111: parameter 2 is const-qualified in the function declaration; const-qualification of parameters only has an effect in function definitions

(readability-avoid-const-params-in-decls,-warnings-as-errors)

mllm/nn/lmcache/PrefixCache.cpp

[error] 4-4: 'cmath' file not found

(clang-diagnostic-error)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-macos
  • GitHub Check: build-android
🔇 Additional comments (1)
mllm/mllm.cpp (1)

109-109: LGTM! Correctly extends .so extension support to Linux.

The change properly adds Linux and UnknownPlatform to the .so extension path, which is correct since Linux uses .so for shared libraries. The UnknownPlatform fallback is also reasonable.

Comment on lines +57 to +96
        // Compute should be done in [full_attention_start, full_attention_end]
        auto full_attention_start = std::max(0, s_q_idx + cur_kv_seq_len - S_Q - left_sliding_window);
        auto full_attention_end = std::min(cur_kv_seq_len, s_q_idx + cur_kv_seq_len - S_Q);

        // We then map full attention to local attention position;
        auto local_attention_start = 0;

        auto local_attention_end = full_attention_end - full_attention_start;
        if (local_attention_end < left_sliding_window) local_attention_end++;

        for (int s_kv_idx = local_attention_start; s_kv_idx < local_attention_end; ++s_kv_idx) {
          // k_token and v_token shape is [B, 1, H, D]
          __KDType* k_token = __k[s_kv_idx];
          __VDType* v_token = __v[s_kv_idx];

          // Offset to one head.
          // k_token and v_token shape is [D]
          k_token = k_token + b_idx * H_KV * D_QK + h_kv_id * D_QK;
          v_token = v_token + b_idx * H_KV * D_V + h_kv_id * D_V;

          // 1. MMA0. Q @ K -> A_i
          __AccDType acc_s;
          mllm::cpu::radix_attn::details::VectorDotProduct<__ArchTag, __QDType, __KDType, __AccDType>::run(
              q_token, k_token, &acc_s, D_QK);

          // 2. Do softmax stuff.
          scores_max_prev = scores_max;
          scores_max = std::max(scores_max_prev, acc_s);
          scores_scale = std::exp2(scores_max_prev * scale - scores_max * scale);
          acc_s = std::exp2(acc_s * scale - scores_max * scale);
          scores_sum = acc_s;
          logsum = logsum * scores_scale + scores_sum;

          // 3. Scale
          mllm::cpu::radix_attn::details::MulFromConst<__ArchTag, __AccDType, __AccDType>::run(acc_o, scores_scale, D_V);

          // 4. MMA1.
          mllm::cpu::radix_attn::details::FMAConstArray<__ArchTag, __AccDType, __AccDType, __AccDType>::run(
              acc_o, acc_s, v_token, D_V);
        }

⚠️ Potential issue | 🔴 Critical

Apply the computed sliding-window offset when reading KV tokens.

full_attention_start is calculated for each query, but the inner loop still uses __k[s_kv_idx] / __v[s_kv_idx]. As soon as the window slides (i.e., full_attention_start > 0), we start from slot 0 instead of the correct global index, so the kernel attends to stale tokens. Offset the lookup by full_attention_start (and reuse cur_kv_seq_len bounds) before the dot-product:

-        for (int s_kv_idx = local_attention_start; s_kv_idx < local_attention_end; ++s_kv_idx) {
-          __KDType* k_token = __k[s_kv_idx];
-          __VDType* v_token = __v[s_kv_idx];
+        for (int s_kv_idx = local_attention_start; s_kv_idx < local_attention_end; ++s_kv_idx) {
+          const int kv_global_idx = full_attention_start + s_kv_idx;
+          __KDType* k_token = __k[kv_global_idx];
+          __VDType* v_token = __v[kv_global_idx];

Without this adjustment the sink operator returns incorrect outputs for any nontrivial sliding-window step.

Comment on lines +134 to +136
auto O =
    radixAttnSWAwSink(Q, k_cache_indices, v_cache_indices, s_aux, 1, H_Q, H_KV, D_QK, D_V, left_sliding_window, cur_seq_len);


⚠️ Potential issue | 🟠 Major

Use the requested batch size when invoking the op.

testCase accepts B, but the radixAttnSWAwSink call always hardcodes batch = 1. Any scenario that exercises B > 1 will configure the kernel incorrectly and drop work for the extra batch entries. Please pass B through to the options.

-  auto O =
-      radixAttnSWAwSink(Q, k_cache_indices, v_cache_indices, s_aux, 1, H_Q, H_KV, D_QK, D_V, left_sliding_window, cur_seq_len);
+  auto O = radixAttnSWAwSink(Q, k_cache_indices, v_cache_indices, s_aux, B, H_Q, H_KV, D_QK, D_V, left_sliding_window,
+                             cur_seq_len);

🤖 Prompt for AI Agents
In mllm-ext-opset/cpu/radix_swa_sink/tests/main.cpp around lines 134 to 136, the
call to radixAttnSWAwSink hardcodes batch=1 which ignores the testCase parameter
B; update the call to pass the actual batch size B (or variable representing B)
into the radixAttnSWAwSink options so the kernel is configured with the
requested batch size and will process all batch entries correctly.

Comment on lines +21 to +27
#include <memory>
#include <vector>
#include <optional>
#include "mllm/core/Tensor.hpp"
#include "mllm/engine/prefix_cache/TLB.hpp"
#include "mllm/engine/prefix_cache/Cache.hpp" // IWYU pragma: export


⚠️ Potential issue | 🔴 Critical

Add the missing <string> include.

PrefixCacheOptions exposes a std::string field, but this header never includes <string>. Any translation unit that includes PrefixCache.hpp before another header that happens to pull in <string> will fail to compile. Relying on transitive includes is explicitly discouraged—always include the headers you depend on.(stackoverflow.com)

 #include <memory>
 #include <vector>
+#include <string>
 #include <optional>

🤖 Prompt for AI Agents
In mllm/nn/lmcache/PrefixCache.hpp around lines 21 to 27, the header declares
PrefixCacheOptions with a std::string field but never includes <string>; add
#include <string> (place it with the other standard includes near the top of the
file) so the header directly satisfies its dependency and no longer relies on
transitive includes.

Comment on lines +106 to +144
class CpuPrefixCache final : protected PrefixCache {
 public:
  explicit CpuPrefixCache(const PrefixCacheOptions& options);

  void promote(const std::vector<int64_t>& token_ids,
               const std::vector<std::vector<prefix_cache::vp_addr_t>>& key_cache_addresses,
               const std::vector<std::vector<prefix_cache::vp_addr_t>>& value_cache_addresses, int64_t extra_key) override;

  prefix_cache::vp_addr_t allocKey(int layer_idx) override;

  prefix_cache::vp_addr_t allocValue(int layer_idx) override;

  void freeKey(int layer_idx, prefix_cache::vp_addr_t addr) override;

  void freeValue(int layer_idx, prefix_cache::vp_addr_t addr) override;

  char* physicalAddrKey(int layer_idx, prefix_cache::vp_addr_t addr) override;

  char* physicalAddrValue(int layer_idx, prefix_cache::vp_addr_t addr) override;

  void prefetchKey(int layer_idx, prefix_cache::vp_addr_t addr) override;

  void prefetchValue(int layer_idx, prefix_cache::vp_addr_t addr) override;

  void purgeKey(int layer_idx, prefix_cache::vp_addr_t addr) override;

  void purgeValue(int layer_idx, prefix_cache::vp_addr_t addr) override;

  void _initFullAttention() override;

  void _initSlidingWindowAttention() override;

  void _validateKeyTokenShape(Tensor& key) override;

  void _validateValueTokenShape(Tensor& value) override;

 private:
  // Each Layer each allocator for memory contiguous on CPU platform.
  std::vector<std::pair<prefix_cache::_AllocatorImpl::ptr_t, prefix_cache::_AllocatorImpl::ptr_t>> caches_;

⚠️ Potential issue | 🟠 Major

Switch to public inheritance for CpuPrefixCache.

CpuPrefixCache currently derives protected from PrefixCache, which blocks implicit upcasts to the base type. That defeats the polymorphic interface you just introduced—callers that work with PrefixCache& or PrefixCache* cannot use CpuPrefixCache at all. Make the inheritance public so the concrete implementation can substitute for the base.

-class CpuPrefixCache final : protected PrefixCache {
+class CpuPrefixCache final : public PrefixCache {

🤖 Prompt for AI Agents
In mllm/nn/lmcache/PrefixCache.hpp around lines 106 to 144, CpuPrefixCache
currently inherits from PrefixCache using protected inheritance which prevents
implicit upcasts to PrefixCache and breaks polymorphism; change the inheritance
specifier from protected to public (i.e., make CpuPrefixCache publicly inherit
PrefixCache) so callers can use CpuPrefixCache where a PrefixCache* or
PrefixCache& is expected, then rebuild and run tests to ensure no access
violations or visibility issues occur.

