-
Notifications
You must be signed in to change notification settings - Fork 281
Fix error illegal memory access when max_total_token_num is too large #998
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -91,7 +91,7 @@ def _fwd_kernel( | |
| Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n), | ||
| mask=(start_n + offs_n) < block_end_loc, | ||
| other=0, | ||
| ) | ||
| ).to(tl.int64) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| off_kv = kv_loc[None, :] * stride_kv_bs + cur_kv_head * stride_kv_h + offs_d[:, None] * stride_kv_d | ||
| off_kv_rope = ( | ||
| kv_loc[None, :] * stride_kv_rope_bs | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -96,7 +96,7 @@ def _fwd_kernel_fp8( | |
| Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n), | ||
| mask=(start_n + offs_n) < block_end_loc, | ||
| other=0, | ||
| ) | ||
| ).to(tl.int64) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| off_kv = kv_loc[None, :] * stride_kv_bs + cur_kv_head * stride_kv_h + offs_d[:, None] * stride_kv_d | ||
| off_kv_rope = ( | ||
| kv_loc[None, :] * stride_kv_rope_bs | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,7 +34,7 @@ def _fwd_kernel_destindex_copy_kv( | |
| offs_d_nope = tl.arange(0, BLOCK_DMODEL_NOPE) | ||
| offs_d_rope = tl.arange(0, BLOCK_DMODEL_ROPE) | ||
|
|
||
| dest_index = tl.load(Dest_loc + cur_index) | ||
| dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| kv_nope_ptrs = KV_nope + cur_index * stride_kv_nope_bs + stride_kv_nope_d * offs_d_nope[None, :] | ||
| kv_rope_ptrs = KV_rope + cur_index * stride_kv_rope_bs + stride_kv_rope_d * offs_d_rope[None, :] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -36,7 +36,7 @@ def _fwd_kernel_destindex_copy_kv_fp8( | |
| offs_d_nope = tl.arange(0, BLOCK_DMODEL_NOPE) | ||
| offs_d_rope = tl.arange(0, BLOCK_DMODEL_ROPE) | ||
|
|
||
| dest_index = tl.load(Dest_loc + cur_index) | ||
| dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| kv_nope_ptrs = KV_nope + cur_index * stride_kv_nope_bs + stride_kv_nope_d * offs_d_nope[None, :] | ||
| kv_rope_ptrs = KV_rope + cur_index * stride_kv_rope_bs + stride_kv_rope_d * offs_d_rope[None, :] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -105,7 +105,7 @@ def _fwd_kernel_flash_decode_stage1_padding( | |
| req_to_tokens_ptr + offs_n_new, | ||
| mask=seq_n_mask, | ||
| other=0, | ||
| ) | ||
| ).to(tl.int64) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| off_kv = kv_loc[None, :] * stride_kv_bs + offs_d[:, None] | ||
| kv = tl.load(KV_nope + off_kv, mask=seq_n_mask[None, :], other=0.0) | ||
| att_value = tl.dot(q, kv) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -108,7 +108,7 @@ def _fwd_kernel_flash_decode_stage1_padding_fp8( | |
| req_to_tokens_ptr + offs_n_new, | ||
| mask=seq_n_mask, | ||
| other=0, | ||
| ) | ||
| ).to(tl.int64) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| off_kv = kv_loc[None, :] * stride_kv_bs + offs_d[:, None] | ||
| kv = tl.load(KV_nope + off_kv, mask=seq_n_mask[None, :], other=0.0) | ||
| off_rope_kv = kv_loc[None, :] * stride_kv_rope_bs + offs_rope_d[:, None] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -44,7 +44,7 @@ def _sample_kv_kernel( | |
| Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_m, | ||
| mask=offs_m < block_end_loc, | ||
| other=0, | ||
| ) | ||
| ).to(tl.int64) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| off_kv_nope = kv_loc[:, None] * stride_input_dim + offs_nope_d[None, :] | ||
| off_kv_rope = kv_loc[:, None] * stride_input_dim + (offs_rope_d + BLOCK_DMODEL)[None, :] | ||
| kv_nope = tl.load(KV_input + off_kv_nope, mask=offs_m[:, None] < block_end_loc, other=0.0) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -84,7 +84,7 @@ def _fwd_kernel( | |
| Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n), | ||
| mask=(start_n + offs_n) < block_end_loc, | ||
| other=0, | ||
| ) | ||
| ).to(tl.int64) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd | ||
| k = tl.load(K + off_k, mask=(start_n + offs_n[None, :]) < block_end_loc, other=0.0) | ||
| qk = tl.dot(q, k) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -70,7 +70,7 @@ def _fwd_kernel( | |
| Req_to_tokens + cur_batch_req_idx * stride_req_to_tokens_b + start_n + offs_n, | ||
| mask=(start_n + offs_n) < cur_batch_seq_len, | ||
| other=0, | ||
| ) | ||
| ).to(tl.int64) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| k = tl.load( | ||
| k_ptrs + kv_loc[None, :] * stride_kbs, mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0 | ||
| ) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -198,7 +198,7 @@ def _kernel_gqa_token_decode_attention_flash_decoding_vsm_stage1( | |
| + cur_chunk_range * stride_req_to_token_seq, | ||
| mask=cur_chunk_mask, | ||
| other=0.0, | ||
| ) | ||
| ).to(tl.int64) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| k_off = ( | ||
| cur_kv_loc[None, :] * stride_k_bs + cur_kv_head_idx * stride_k_h + d_off[:, None] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -136,7 +136,7 @@ def _fwd_kernel_destindex_copy_dequantize_kv( | |
|
|
||
| kv_loc = tl.load( | ||
| req_to_token_indexs + cur_batch_req_idx * stride_req_to_tokens_b + offs_kv_loc, mask=offs_kv_loc < cur_seq_len | ||
| ) | ||
| ).to(tl.int64) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| offs_kv = kv_loc[:, None] * stride_kv_b + cur_head * stride_kv_h + cur_group * stride_kv_g + offs_d[None, :] | ||
|
|
||
| src_data = tl.load( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -80,7 +80,7 @@ def _fwd_kernel( | |
| Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n), | ||
| mask=(start_n + offs_n) < block_end_loc, | ||
| other=0, | ||
| ) | ||
| ).to(tl.int64) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd | ||
| k = tl.load( | ||
| K + off_k, mask=((start_n + offs_n[None, :]) < block_end_loc) & (offs_d[:, None] < head_dim), other=0.0 | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Casting to
tl.int64prevents potential integer overflows whenmax_total_token_numis large, ensuring correct memory access.