Merged
@@ -76,7 +76,7 @@ def _fwd_kernel(
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n),
mask=(start_n + offs_n) < block_end_loc,
other=0,
- )
+ ).to(tl.int64)

critical

Casting kv_loc to tl.int64 keeps the kv_loc[None, :] * stride_kbs offset from overflowing 32-bit arithmetic when max_total_token_num is large, so the load from K reads the intended address.
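
The overflow this comment refers to can be illustrated outside Triton. The sketch below is plain Python that emulates signed 32-bit wraparound. The helper wrap_int32, the slot index, and the stride value are hypothetical and chosen only so that the product exceeds the int32 range. It assumes the loaded index and the stride are both 32-bit before the cast.

```python
INT32_BITS = 32

def wrap_int32(x: int) -> int:
    """Emulate two's-complement wraparound of a signed 32-bit integer."""
    x &= (1 << INT32_BITS) - 1
    return x - (1 << INT32_BITS) if x >= (1 << (INT32_BITS - 1)) else x

# Hypothetical values: a token slot index and a per-token stride in elements.
kv_loc = 3_000_000
stride_kbs = 1024

offset_int32 = wrap_int32(kv_loc * stride_kbs)  # product evaluated in 32 bits
offset_int64 = kv_loc * stride_kbs              # product evaluated in 64 bits

print(offset_int32)
print(offset_int64)
```

With these values the 32-bit product wraps to a negative offset, while the 64-bit product is the intended one. Widening the result after the multiplication would not help, since the product has already wrapped; the cast in the diff is applied to the loaded index before it is multiplied.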

off_k = kv_loc[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
k = tl.load(K + off_k, mask=(start_n + offs_n[None, :]) < block_end_loc, other=0.0)

@@ -91,7 +91,7 @@ def _fwd_kernel(
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n),
mask=(start_n + offs_n) < block_end_loc,
other=0,
- )
+ ).to(tl.int64)

critical

Casting explicitly to tl.int64 prevents integer overflow in the off_kv and off_kv_rope offsets when max_total_token_num is large.

off_kv = kv_loc[None, :] * stride_kv_bs + cur_kv_head * stride_kv_h + offs_d[:, None] * stride_kv_d
off_kv_rope = (
kv_loc[None, :] * stride_kv_rope_bs
@@ -96,7 +96,7 @@ def _fwd_kernel_fp8(
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n),
mask=(start_n + offs_n) < block_end_loc,
other=0,
- )
+ ).to(tl.int64)

critical

Casting to tl.int64 is essential for handling large token spaces, preventing integer overflow and ensuring correct memory access.

off_kv = kv_loc[None, :] * stride_kv_bs + cur_kv_head * stride_kv_h + offs_d[:, None] * stride_kv_d
off_kv_rope = (
kv_loc[None, :] * stride_kv_rope_bs
@@ -34,7 +34,7 @@ def _fwd_kernel_destindex_copy_kv(
offs_d_nope = tl.arange(0, BLOCK_DMODEL_NOPE)
offs_d_rope = tl.arange(0, BLOCK_DMODEL_ROPE)

- dest_index = tl.load(Dest_loc + cur_index)
+ dest_index = tl.load(Dest_loc + cur_index).to(tl.int64)

critical

Casting dest_index to tl.int64 prevents potential integer overflows, ensuring correct memory write locations.


kv_nope_ptrs = KV_nope + cur_index * stride_kv_nope_bs + stride_kv_nope_d * offs_d_nope[None, :]
kv_rope_ptrs = KV_rope + cur_index * stride_kv_rope_bs + stride_kv_rope_d * offs_d_rope[None, :]
@@ -36,7 +36,7 @@ def _fwd_kernel_destindex_copy_kv_fp8(
offs_d_nope = tl.arange(0, BLOCK_DMODEL_NOPE)
offs_d_rope = tl.arange(0, BLOCK_DMODEL_ROPE)

- dest_index = tl.load(Dest_loc + cur_index)
+ dest_index = tl.load(Dest_loc + cur_index).to(tl.int64)

critical

Casting dest_index to tl.int64 avoids potential integer overflows with large token buffers, ensuring correct memory writes.


kv_nope_ptrs = KV_nope + cur_index * stride_kv_nope_bs + stride_kv_nope_d * offs_d_nope[None, :]
kv_rope_ptrs = KV_rope + cur_index * stride_kv_rope_bs + stride_kv_rope_d * offs_d_rope[None, :]
@@ -105,7 +105,7 @@ def _fwd_kernel_flash_decode_stage1_padding(
req_to_tokens_ptr + offs_n_new,
mask=seq_n_mask,
other=0,
- )
+ ).to(tl.int64)

critical

The cast to tl.int64 prevents integer overflow when the token indices are multiplied by stride_kv_bs, keeping the address of the KV_nope load valid.

off_kv = kv_loc[None, :] * stride_kv_bs + offs_d[:, None]
kv = tl.load(KV_nope + off_kv, mask=seq_n_mask[None, :], other=0.0)
att_value = tl.dot(q, kv)
@@ -108,7 +108,7 @@ def _fwd_kernel_flash_decode_stage1_padding_fp8(
req_to_tokens_ptr + offs_n_new,
mask=seq_n_mask,
other=0,
- )
+ ).to(tl.int64)

critical

Casting to tl.int64 prevents integer overflow in the off_kv and off_rope_kv offset calculations.

off_kv = kv_loc[None, :] * stride_kv_bs + offs_d[:, None]
kv = tl.load(KV_nope + off_kv, mask=seq_n_mask[None, :], other=0.0)
off_rope_kv = kv_loc[None, :] * stride_kv_rope_bs + offs_rope_d[:, None]
2 changes: 1 addition & 1 deletion lightllm/models/deepseek2/triton_kernel/sample_kv.py
@@ -44,7 +44,7 @@ def _sample_kv_kernel(
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_m,
mask=offs_m < block_end_loc,
other=0,
- )
+ ).to(tl.int64)

critical

The explicit cast to tl.int64 prevents integer overflow in kv_loc[:, None] * stride_input_dim, keeping the offset calculations correct and avoiding illegal memory access.

off_kv_nope = kv_loc[:, None] * stride_input_dim + offs_nope_d[None, :]
off_kv_rope = kv_loc[:, None] * stride_input_dim + (offs_rope_d + BLOCK_DMODEL)[None, :]
kv_nope = tl.load(KV_input + off_kv_nope, mask=offs_m[:, None] < block_end_loc, other=0.0)
@@ -84,7 +84,7 @@ def _fwd_kernel(
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n),
mask=(start_n + offs_n) < block_end_loc,
other=0,
- )
+ ).to(tl.int64)

critical

Casting to tl.int64 prevents integer overflow in the off_k calculation, avoiding a potential illegal memory access.

off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd
k = tl.load(K + off_k, mask=(start_n + offs_n[None, :]) < block_end_loc, other=0.0)
qk = tl.dot(q, k)
@@ -70,7 +70,7 @@ def _fwd_kernel(
Req_to_tokens + cur_batch_req_idx * stride_req_to_tokens_b + start_n + offs_n,
mask=(start_n + offs_n) < cur_batch_seq_len,
other=0,
- )
+ ).to(tl.int64)

critical

This cast to tl.int64 prevents overflow in kv_loc[None, :] * stride_kbs, so the subsequent load of k accesses valid memory.

k = tl.load(
k_ptrs + kv_loc[None, :] * stride_kbs, mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0
)
@@ -198,7 +198,7 @@ def _kernel_gqa_token_decode_attention_flash_decoding_vsm_stage1(
+ cur_chunk_range * stride_req_to_token_seq,
mask=cur_chunk_mask,
other=0.0,
- )
+ ).to(tl.int64)

critical

Casting cur_kv_loc to tl.int64 keeps the k_off calculation from overflowing with large token buffers, avoiding illegal memory access.


k_off = (
cur_kv_loc[None, :] * stride_k_bs + cur_kv_head_idx * stride_k_h + d_off[:, None]
2 changes: 1 addition & 1 deletion lightllm/models/llama/triton_kernel/ppl_quant_copy_kv.py
@@ -136,7 +136,7 @@ def _fwd_kernel_destindex_copy_dequantize_kv(

kv_loc = tl.load(
req_to_token_indexs + cur_batch_req_idx * stride_req_to_tokens_b + offs_kv_loc, mask=offs_kv_loc < cur_seq_len
- )
+ ).to(tl.int64)

critical

Casting kv_loc to tl.int64 handles large token index values, preventing potential integer overflow and ensuring correct memory access.

offs_kv = kv_loc[:, None] * stride_kv_b + cur_head * stride_kv_h + cur_group * stride_kv_g + offs_d[None, :]

src_data = tl.load(
@@ -80,7 +80,7 @@ def _fwd_kernel(
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n),
mask=(start_n + offs_n) < block_end_loc,
other=0,
- )
+ ).to(tl.int64)

critical

This cast to tl.int64 prevents the off_k calculation from overflowing when max_total_token_num is large, avoiding illegal memory access.
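
To make "large" concrete, the following sketch estimates the largest token count for which 32-bit offsets still suffice. It assumes a contiguous KV buffer whose per-token stride covers all heads and dimensions. The function max_tokens_for_int32 and the example stride are hypothetical, not values taken from this PR.

```python
INT32_MAX = 2**31 - 1

def max_tokens_for_int32(stride_per_token: int) -> int:
    """Largest token count N whose last element offset, N * stride - 1, fits in int32."""
    if stride_per_token <= 0:
        raise ValueError("stride_per_token must be positive")
    return (INT32_MAX + 1) // stride_per_token

# Hypothetical layout: 8 KV heads * 128 dims = 1024 elements per token slot.
limit = max_tokens_for_int32(8 * 128)
print(limit)
```

Under these assumptions, a buffer with more token slots than this limit produces element offsets that no longer fit in int32, which is the regime the cast targets.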

off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd
k = tl.load(
K + off_k, mask=((start_n + offs_n[None, :]) < block_end_loc) & (offs_d[:, None] < head_dim), other=0.0