From 7306122e931fa4e1f0a91739dcc182663730a32d Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 4 Aug 2025 17:44:42 +0800 Subject: [PATCH] Fix illegal memory access error when max_total_token_num is too large --- .../models/bloom/triton_kernel/context_flashattention_nopad.py | 2 +- .../deepseek2/triton_kernel/context_flashattention_nopad.py | 2 +- .../deepseek2/triton_kernel/context_flashattention_nopad_fp8.py | 2 +- lightllm/models/deepseek2/triton_kernel/destindex_copy_kv.py | 2 +- .../models/deepseek2/triton_kernel/destindex_copy_kv_fp8.py | 2 +- .../models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py | 2 +- .../deepseek2/triton_kernel/gqa_flash_decoding_stage1_fp8.py | 2 +- lightllm/models/deepseek2/triton_kernel/sample_kv.py | 2 +- .../models/llama/triton_kernel/context_flashattention_nopad.py | 2 +- .../llama/triton_kernel/gqa_decode_flashattention_nopad.py | 2 +- lightllm/models/llama/triton_kernel/gqa_flash_decoding_vsm.py | 2 +- lightllm/models/llama/triton_kernel/ppl_quant_copy_kv.py | 2 +- .../models/phi3/triton_kernel/context_flashattention_nopad.py | 2 +- 13 files changed, 13 insertions(+), 13 deletions(-) diff --git a/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py b/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py index a5dafe327..03ed86447 100644 --- a/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py +++ b/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py @@ -76,7 +76,7 @@ def _fwd_kernel( Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n), mask=(start_n + offs_n) < block_end_loc, other=0, - ) + ).to(tl.int64) off_k = kv_loc[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd k = tl.load(K + off_k, mask=(start_n + offs_n[None, :]) < block_end_loc, other=0.0) diff --git a/lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad.py 
b/lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad.py index 6a42d913f..fdecfc72f 100644 --- a/lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad.py +++ b/lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad.py @@ -91,7 +91,7 @@ def _fwd_kernel( Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n), mask=(start_n + offs_n) < block_end_loc, other=0, - ) + ).to(tl.int64) off_kv = kv_loc[None, :] * stride_kv_bs + cur_kv_head * stride_kv_h + offs_d[:, None] * stride_kv_d off_kv_rope = ( kv_loc[None, :] * stride_kv_rope_bs diff --git a/lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad_fp8.py b/lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad_fp8.py index 700706a2d..499c14084 100644 --- a/lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad_fp8.py +++ b/lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad_fp8.py @@ -96,7 +96,7 @@ def _fwd_kernel_fp8( Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n), mask=(start_n + offs_n) < block_end_loc, other=0, - ) + ).to(tl.int64) off_kv = kv_loc[None, :] * stride_kv_bs + cur_kv_head * stride_kv_h + offs_d[:, None] * stride_kv_d off_kv_rope = ( kv_loc[None, :] * stride_kv_rope_bs diff --git a/lightllm/models/deepseek2/triton_kernel/destindex_copy_kv.py b/lightllm/models/deepseek2/triton_kernel/destindex_copy_kv.py index 5b922604d..39deb1b6f 100644 --- a/lightllm/models/deepseek2/triton_kernel/destindex_copy_kv.py +++ b/lightllm/models/deepseek2/triton_kernel/destindex_copy_kv.py @@ -34,7 +34,7 @@ def _fwd_kernel_destindex_copy_kv( offs_d_nope = tl.arange(0, BLOCK_DMODEL_NOPE) offs_d_rope = tl.arange(0, BLOCK_DMODEL_ROPE) - dest_index = tl.load(Dest_loc + cur_index) + dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) kv_nope_ptrs = KV_nope + cur_index * stride_kv_nope_bs + stride_kv_nope_d * 
offs_d_nope[None, :] kv_rope_ptrs = KV_rope + cur_index * stride_kv_rope_bs + stride_kv_rope_d * offs_d_rope[None, :] diff --git a/lightllm/models/deepseek2/triton_kernel/destindex_copy_kv_fp8.py b/lightllm/models/deepseek2/triton_kernel/destindex_copy_kv_fp8.py index d0f676a93..b765cc4f8 100644 --- a/lightllm/models/deepseek2/triton_kernel/destindex_copy_kv_fp8.py +++ b/lightllm/models/deepseek2/triton_kernel/destindex_copy_kv_fp8.py @@ -36,7 +36,7 @@ def _fwd_kernel_destindex_copy_kv_fp8( offs_d_nope = tl.arange(0, BLOCK_DMODEL_NOPE) offs_d_rope = tl.arange(0, BLOCK_DMODEL_ROPE) - dest_index = tl.load(Dest_loc + cur_index) + dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) kv_nope_ptrs = KV_nope + cur_index * stride_kv_nope_bs + stride_kv_nope_d * offs_d_nope[None, :] kv_rope_ptrs = KV_rope + cur_index * stride_kv_rope_bs + stride_kv_rope_d * offs_d_rope[None, :] diff --git a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py b/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py index 9191de6d1..f5909fffd 100644 --- a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py +++ b/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py @@ -105,7 +105,7 @@ def _fwd_kernel_flash_decode_stage1_padding( req_to_tokens_ptr + offs_n_new, mask=seq_n_mask, other=0, - ) + ).to(tl.int64) off_kv = kv_loc[None, :] * stride_kv_bs + offs_d[:, None] kv = tl.load(KV_nope + off_kv, mask=seq_n_mask[None, :], other=0.0) att_value = tl.dot(q, kv) diff --git a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1_fp8.py b/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1_fp8.py index 262dabec6..6ec4e09b2 100644 --- a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1_fp8.py +++ b/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1_fp8.py @@ -108,7 +108,7 @@ def _fwd_kernel_flash_decode_stage1_padding_fp8( req_to_tokens_ptr + offs_n_new, mask=seq_n_mask, other=0, - 
) + ).to(tl.int64) off_kv = kv_loc[None, :] * stride_kv_bs + offs_d[:, None] kv = tl.load(KV_nope + off_kv, mask=seq_n_mask[None, :], other=0.0) off_rope_kv = kv_loc[None, :] * stride_kv_rope_bs + offs_rope_d[:, None] diff --git a/lightllm/models/deepseek2/triton_kernel/sample_kv.py b/lightllm/models/deepseek2/triton_kernel/sample_kv.py index 6259c3ccd..af0aaa2f6 100644 --- a/lightllm/models/deepseek2/triton_kernel/sample_kv.py +++ b/lightllm/models/deepseek2/triton_kernel/sample_kv.py @@ -44,7 +44,7 @@ def _sample_kv_kernel( Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_m, mask=offs_m < block_end_loc, other=0, - ) + ).to(tl.int64) off_kv_nope = kv_loc[:, None] * stride_input_dim + offs_nope_d[None, :] off_kv_rope = kv_loc[:, None] * stride_input_dim + (offs_rope_d + BLOCK_DMODEL)[None, :] kv_nope = tl.load(KV_input + off_kv_nope, mask=offs_m[:, None] < block_end_loc, other=0.0) diff --git a/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py index f2a45a67f..e36c51b39 100644 --- a/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py +++ b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py @@ -84,7 +84,7 @@ def _fwd_kernel( Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n), mask=(start_n + offs_n) < block_end_loc, other=0, - ) + ).to(tl.int64) off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd k = tl.load(K + off_k, mask=(start_n + offs_n[None, :]) < block_end_loc, other=0.0) qk = tl.dot(q, k) diff --git a/lightllm/models/llama/triton_kernel/gqa_decode_flashattention_nopad.py b/lightllm/models/llama/triton_kernel/gqa_decode_flashattention_nopad.py index 95bd232a2..afc67a84e 100644 --- a/lightllm/models/llama/triton_kernel/gqa_decode_flashattention_nopad.py +++ b/lightllm/models/llama/triton_kernel/gqa_decode_flashattention_nopad.py @@ 
-70,7 +70,7 @@ def _fwd_kernel( Req_to_tokens + cur_batch_req_idx * stride_req_to_tokens_b + start_n + offs_n, mask=(start_n + offs_n) < cur_batch_seq_len, other=0, - ) + ).to(tl.int64) k = tl.load( k_ptrs + kv_loc[None, :] * stride_kbs, mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0 ) diff --git a/lightllm/models/llama/triton_kernel/gqa_flash_decoding_vsm.py b/lightllm/models/llama/triton_kernel/gqa_flash_decoding_vsm.py index dbee931a9..850d4185c 100644 --- a/lightllm/models/llama/triton_kernel/gqa_flash_decoding_vsm.py +++ b/lightllm/models/llama/triton_kernel/gqa_flash_decoding_vsm.py @@ -198,7 +198,7 @@ def _kernel_gqa_token_decode_attention_flash_decoding_vsm_stage1( + cur_chunk_range * stride_req_to_token_seq, mask=cur_chunk_mask, other=0.0, - ) + ).to(tl.int64) k_off = ( cur_kv_loc[None, :] * stride_k_bs + cur_kv_head_idx * stride_k_h + d_off[:, None] diff --git a/lightllm/models/llama/triton_kernel/ppl_quant_copy_kv.py b/lightllm/models/llama/triton_kernel/ppl_quant_copy_kv.py index 7126e1d74..a285004b5 100644 --- a/lightllm/models/llama/triton_kernel/ppl_quant_copy_kv.py +++ b/lightllm/models/llama/triton_kernel/ppl_quant_copy_kv.py @@ -136,7 +136,7 @@ def _fwd_kernel_destindex_copy_dequantize_kv( kv_loc = tl.load( req_to_token_indexs + cur_batch_req_idx * stride_req_to_tokens_b + offs_kv_loc, mask=offs_kv_loc < cur_seq_len - ) + ).to(tl.int64) offs_kv = kv_loc[:, None] * stride_kv_b + cur_head * stride_kv_h + cur_group * stride_kv_g + offs_d[None, :] src_data = tl.load( diff --git a/lightllm/models/phi3/triton_kernel/context_flashattention_nopad.py b/lightllm/models/phi3/triton_kernel/context_flashattention_nopad.py index 5281bfdc2..ee04c3367 100644 --- a/lightllm/models/phi3/triton_kernel/context_flashattention_nopad.py +++ b/lightllm/models/phi3/triton_kernel/context_flashattention_nopad.py @@ -80,7 +80,7 @@ def _fwd_kernel( Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n), 
mask=(start_n + offs_n) < block_end_loc, other=0, - ) + ).to(tl.int64) off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd k = tl.load( K + off_k, mask=((start_n + offs_n[None, :]) < block_end_loc) & (offs_d[:, None] < head_dim), other=0.0