19 changes: 12 additions & 7 deletions custom_ops/gpu_ops/sparse_indexer/indexer_topk.cuh
@@ -2061,22 +2061,28 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
int batch_id, length;
const IdType* block_table_pre_batch;

IdType* dst;

if (seq_len_decoder != nullptr) { // decode
batch_id = batch_id_per_token[bid / q_num_heads];
// batch_id = batch_id_per_token[bid / q_num_heads];
batch_id = bid / q_num_heads;
if (batch_id == -1) return;
length = (seq_len_decoder[batch_id]); // for pack q k
if (length == 0) return;
if (block_tables != nullptr) {
block_table_pre_batch = block_tables + batch_id * max_block_num;
}
dst = output + aux_input[batch_id] * top_k;

} else { // prefill
// length = (lengths != nullptr) ? lengths[bid] : static_cast<int>(max_len);
length = (lengths != nullptr) ? lengths[bid / q_num_heads]
: static_cast<int>(max_len);
dst = output + bid * top_k;
}

const DType* score = input + bid * max_len;
IdType* dst = output + bid * top_k;
// IdType* dst = output + bid * top_k;

// Mode-specific setup
[[maybe_unused]] const IdType* src_page_entry = nullptr;
@@ -2110,8 +2116,8 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
? static_cast<IdType>(block_ids * 64 + block_offset)
: static_cast<IdType>(-1);
} else {
dst[i] =
(i < length) ? static_cast<IdType>(i) : static_cast<IdType>(-1);
dst[i] = (i < length) ? static_cast<IdType>(i) + offset_val
: static_cast<IdType>(-1);
}
} else { // Plain
if (i < length) {
@@ -2337,10 +2343,9 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
block_idx = idx / 64;
block_ids = block_table_pre_batch[block_idx];
block_offset = idx % 64;
dst[base] =
static_cast<IdType>(block_ids * 64 + block_offset); // + offset_val
dst[base] = static_cast<IdType>(block_ids * 64 + block_offset);
} else {
dst[base] = static_cast<IdType>(idx); //+ offset_val;
dst[base] = static_cast<IdType>(idx) + offset_val;
}

} else { // Plain
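Note on the kernel changes above: in decode mode the row's batch id is now derived directly as bid / q_num_heads and the row's output is written at aux_input[batch_id] * top_k, while in the plain (non-block-table) path the emitted top-k indices are now shifted by the per-row base offset (offset_val) instead of being 0-based; the block-table path still remaps positions through block_ids * 64 + block_offset. A minimal Python sketch of the intended per-row output semantics, not the CUDA kernel itself (the helper name and the NumPy formulation are illustrative assumptions; tie-breaking order may differ from the kernel):

import numpy as np

def expected_topk_indices(scores, length, top_k, offset_val, block_table=None):
    # Reference semantics for one output row of the filtered top-k kernel.
    out = np.full(top_k, -1, dtype=np.int64)
    if length == 0:
        return out
    k = min(top_k, length)
    # Positions of the k largest scores within the valid window (0-based).
    idx = np.argsort(-scores[:length])[:k]
    if block_table is not None:
        # Paged path: remap through the block table, 64 entries per block.
        out[:k] = block_table[idx // 64] * 64 + idx % 64
    else:
        # Plain path: shift by the per-row base offset (the new behavior).
        out[:k] = idx + offset_val
    return out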
28 changes: 10 additions & 18 deletions fastdeploy/model_executor/models/deepseek_v3.py
@@ -674,14 +674,13 @@ def forward(
self.indexer_cache, k_fp8_cache, k_scale_cache, forward_meta.block_tables, forward_meta.cu_seqlens_k
)

k_scale_cache = k_scale_cache.flatten()[: k.shape[0]]
k_cache = k_fp8_cache.view(paddle.float8_e4m3fn), k_scale_cache
k_scale_cache_real = k_scale_cache.flatten()[: k.shape[0]].contiguous()
k_cache = k_fp8_cache.view(paddle.float8_e4m3fn), k_scale_cache_real

# TODO(changwenbin): Constructed using maskoffset
# ks,ke = forward_meta.attn_mask_offsets[::2].contiguous(),forward_meta.attn_mask_offsets[1::2].contiguous()
num_tokens = q_fp8.shape[0]
ks = paddle.zeros(num_tokens, dtype=paddle.int32)
ks_topk = paddle.zeros(num_tokens, dtype=paddle.int32)
ke = paddle.zeros(num_tokens, dtype=paddle.int32)

bsz = forward_meta.seq_lens_this_time.shape[0]
@@ -696,20 +695,13 @@ def forward(

logits = deep_gemm.fp8_mqa_logits(
q_fp8, k_cache, weights, ks, ke, max_seqlen_k=max_seqlen_k, clean_logits=False
)

# To save GPU global memory usage
assert logits.size() == (num_tokens, max_seqlen_k)
tmp = paddle.full((num_tokens, num_tokens), float("-inf"))
for i in range(num_tokens):
tmp[i, ks[i] : ke[i]] = logits[i, : ke[i] - ks[i]]
logits = tmp
).contiguous()

radix_topk_ragged_transform(
logits.contiguous(),
logits,
indexer_top_k,
ks_topk, # self.offsets,
ke - ks + 1, # mask.contiguous(),#self.lengths,
ks, # self.offsets: initial offset along the K dimension
ke - ks, # self.lengths: length of the k span the current q attends to
None, # forward_meta.seq_lens_decoder,
None, # forward_meta.batch_id_per_token,
None,
@@ -740,20 +732,20 @@ def forward(
schedule_metadata,
self.max_model_len,
clean_logits=True,
)
).contiguous()

radix_topk_ragged_transform(
logits.contiguous(),
logits,
indexer_top_k,
self.offsets, # unused
forward_meta.cu_seqlens_q,
self.lengths, # unused
cache_seqlens,
forward_meta.batch_id_per_token,
forward_meta.block_tables,
None, # self.buffer
forward_meta.block_tables.shape[1],
self.index_topk,
1, # q_head
1, # kv_head
)

return indexer_top_k
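Note on the forward-path change above: the dense (num_tokens, num_tokens) scratch matrix filled with -inf is gone; the ragged logits from deep_gemm.fp8_mqa_logits are handed to radix_topk_ragged_transform directly, together with the per-row window start ks and window length ke - ks. A minimal paddle sketch of the per-row result this is expected to match, assuming (consistent with the updated test) that output indices are shifted by the row's ks (the helper name is illustrative, not part of the PR):

import paddle

def ragged_topk_reference(logits, ks, ke, top_k):
    # Per row i: top-k over logits[i, :ke[i]-ks[i]], indices shifted by ks[i].
    num_rows = logits.shape[0]
    out = paddle.full([num_rows, top_k], -1, dtype="int32")
    for i in range(num_rows):
        length = int(ke[i]) - int(ks[i])
        if length <= 0:
            continue
        k = min(top_k, length)
        _, inds = paddle.topk(logits[i, :length], k)
        out[i, :k] = inds.astype("int32") + int(ks[i])
    return out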
57 changes: 33 additions & 24 deletions tests/operators/test_radix_topk_accuracy.py
@@ -36,38 +36,40 @@ def setUp(self):
def get_reference_topk(self, input_pd, lengths_pd, offsets_pd, top_k, q_num_heads):
"""
Generate reference results using paddle.topk
Note: the operator outputs 0-based relative indices (offset not included)
Note: the operator outputs indices relative to offsets (0-based relative indices)

Args:
input_pd: (num_rows, max_len)
lengths_pd: (batch_size,) - length of each batch
offsets_pd: (num_rows,) - per-row offset base (unused; kept only for parameter compatibility)
offsets_pd: (num_rows,) - per-row offset base
top_k: the k value
q_num_heads: number of query heads

Returns:
ref_indices: (num_rows, top_k) - reference indices (0-based relative), padded with -1 where the length falls short
ref_indices: (num_rows, top_k) - reference indices (relative to offset), padded with -1 where the length falls short
"""
num_rows = input_pd.shape[0]
ref_indices = paddle.full([num_rows, top_k], -1, dtype="int32")
offsets = offsets_pd.numpy()

for row_idx in range(num_rows):
batch_idx = row_idx // q_num_heads
length = lengths_pd[batch_idx].item()
offset = offsets[row_idx]

if length == 0:
continue

row_data = input_pd[row_idx, :length]

if length <= top_k:
# Length below top_k: return all indices in order (0-based)
ref_indices[row_idx, :length] = paddle.arange(0, length, dtype="int32")
# Length below top_k: return all indices in order (relative to offset)
ref_indices[row_idx, :length] = paddle.arange(offset, offset + length, dtype="int32")
else:
# Length is sufficient: use paddle.topk to get the indices of the top_k largest values
topk_vals, topk_inds = paddle.topk(row_data, top_k)
# Use the indices returned by topk directly (0-based)
ref_indices[row_idx, :top_k] = topk_inds
# Add offset as the base
ref_indices[row_idx, :top_k] = topk_inds + offset

return ref_indices

@@ -171,41 +173,48 @@ def test_decode_mode(self):
paddle.seed(2025)

batch_size = 2
q_num_heads = 4
num_rows = batch_size * q_num_heads
kv_head = 1 # in decode mode, each batch has only one new token
num_rows = batch_size * kv_head # = batch_size
max_len = 1024
top_k = 8

# Construct the data with paddle
input_pd = paddle.randn([num_rows, max_len], dtype="float32")
offsets_pd = paddle.arange(num_rows, dtype="int32")
lengths_pd = paddle.full([num_rows], 0, dtype="int32")
seq_len_decoder_pd = paddle.randint(16, 128, [batch_size], dtype="int32")

# Generate batch_id_per_token
batch_id_per_token_pd = paddle.arange(num_rows, dtype="int32") // q_num_heads
# Generate cu_seqlens_q: each batch's offset into the flattened queries
# In decode mode, each batch has only one new token, so cu_seqlens_q = [0, 1, 2, ..., batch_size]
cu_seqlens_q_pd = paddle.concat(
[
paddle.zeros([1], dtype="int32"),
paddle.cumsum(paddle.ones([batch_size], dtype="int32")).astype("int32"),
],
axis=0,
)

# Invoke the operator
lengths_pd = paddle.full([num_rows], 0, dtype="int32") # unused
seq_len_decoder_pd = paddle.randint(16, 128, [batch_size], dtype="int32")

# Invoke the operator (without block_tables, so it works with prefill-like logic)
output_indices = paddle.full([num_rows, top_k], -1, dtype="int32")
radix_topk_ragged_transform(
input_pd,
output_indices,
offsets_pd,
lengths_pd,
cu_seqlens_q_pd,
lengths_pd, # unused
seq_len_decoder_pd,
batch_id_per_token_pd,
None,
None,
0,
None, # batch_id_per_token
None, # block_tables
None, # buffer
0, # max_block_per_seq
top_k,
q_num_heads,
kv_head,
)

# In decode mode, length = seq_len_decoder + 1
decode_lengths = seq_len_decoder_pd + 1

# Get the reference result
ref_indices = self.get_reference_topk(input_pd, decode_lengths, offsets_pd, top_k, q_num_heads)
# Get the reference result (note: num_rows = batch_size * kv_head)
ref_indices = self.get_reference_topk(input_pd, decode_lengths, cu_seqlens_q_pd, top_k, kv_head)

# Compare the results
result = self.compare_indices(output_indices, ref_indices)
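Side note on the test setup above: since each decode step contributes exactly one token per batch, the cu_seqlens_q built via concat + cumsum is just the integer ramp [0, 1, ..., batch_size]. A one-line equivalent, assuming int32 is the dtype the operator expects:

import paddle

batch_size = 2
# Equivalent to concat([zeros(1), cumsum(ones(batch_size))]): [0, 1, ..., batch_size]
cu_seqlens_q_pd = paddle.arange(batch_size + 1, dtype="int32")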