
Conversation

@sufubao (Collaborator) commented Aug 4, 2025

No description provided.

@gemini-code-assist (bot) left a comment

Summary of Changes

Hello @sufubao, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical issue causing illegal memory access errors when the total number of tokens (max_total_token_num) becomes very large. The changes ensure that memory indexing operations within various Triton kernels correctly handle large token counts, improving the stability and reliability of the system under high-load conditions.

Highlights

  • Memory Safety Enhancement: Implemented explicit casting to tl.int64 for token indices loaded within various Triton kernels. This prevents integer-overflow issues that could lead to incorrect memory addressing and illegal memory access errors when max_total_token_num is very large (see the sketch after this list).
  • Broad Kernel Coverage: The fix has been applied consistently across multiple Triton kernel files responsible for context flash attention, KV cache management, and decoding stages in various model architectures, including Bloom, Deepseek2, Llama, and Phi3, ensuring robustness for diverse large language models.
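To make the pattern concrete, below is a minimal, hypothetical sketch of the kind of change this PR applies. The function and parameter names (_load_k_block, Req_to_tokens, stride_kbs, block_end_loc) are illustrative stand-ins echoing the diff hunks quoted in the review below, not the repository's actual signatures.

```python
# Sketch only: a Triton device function showing the int64-cast pattern.
import triton
import triton.language as tl


@triton.jit
def _load_k_block(
    K,                # KV buffer viewed as [total_tokens, head_dim]
    Req_to_tokens,    # per-request table of token locations (int32)
    stride_req_b,
    stride_kbs,
    cur_batch_req_idx,
    start_n,
    block_end_loc,
    BLOCK_N: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_D)
    # Load a block of token locations, then promote them to 64-bit. Without
    # the cast, kv_loc * stride_kbs below is computed in int32 and wraps once
    # the KV buffer exceeds 2**31 elements, producing illegal memory accesses.
    kv_loc = tl.load(
        Req_to_tokens + cur_batch_req_idx * stride_req_b + start_n + offs_n,
        mask=(start_n + offs_n) < block_end_loc,
        other=0,
    ).to(tl.int64)
    k = tl.load(
        K + kv_loc[:, None] * stride_kbs + offs_d[None, :],
        mask=((start_n + offs_n) < block_end_loc)[:, None],
        other=0.0,
    )
    return k
```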

@gemini-code-assist (bot) left a comment

Code Review

This pull request addresses a critical bug that could cause illegal memory access when max_total_token_num is large. The fix involves consistently casting loaded token indices to tl.int64 across various Triton kernels for Bloom, DeepSeek2, Llama, and Phi3 models. This prevents integer overflow, ensuring correct memory offset calculations and improving the stability of the models with large inputs.
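The failure mode is easy to quantify with a back-of-the-envelope check (the stride values here are assumed for illustration, not taken from the PR):

```python
# An offset into the KV buffer is token_index * per_token_stride. If both
# operands are 32-bit, the product wraps past 2**31 - 1. Assume a per-token
# stride of head_num * head_dim = 8 * 128 = 1024 elements:
stride = 8 * 128
max_safe_index = (2**31 - 1) // stride
print(max_safe_index)  # 2097151 -- indices beyond this wrap the 32-bit offset
```

So a max_total_token_num in the low millions is already enough for 32-bit offset arithmetic to wrap, which matches the illegal-memory-access symptom; promoting the loaded index to tl.int64 before the multiply pushes the limit far out of reach.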

```diff
     mask=(start_n + offs_n) < block_end_loc,
     other=0,
-)
+).to(tl.int64)
```


critical

Casting to tl.int64 prevents potential integer overflows when max_total_token_num is large, ensuring correct memory access.

```diff
     mask=(start_n + offs_n) < block_end_loc,
     other=0,
-)
+).to(tl.int64)
```


critical

Explicitly casting to tl.int64 prevents potential integer overflows, mitigating a critical bug related to large max_total_token_num.

```diff
     mask=(start_n + offs_n) < block_end_loc,
     other=0,
-)
+).to(tl.int64)
```


critical

Casting to tl.int64 is essential for handling large token spaces, preventing integer overflow and ensuring correct memory access.

```diff
 offs_d_rope = tl.arange(0, BLOCK_DMODEL_ROPE)

-dest_index = tl.load(Dest_loc + cur_index)
+dest_index = tl.load(Dest_loc + cur_index).to(tl.int64)
```


critical

Casting dest_index to tl.int64 prevents potential integer overflows, ensuring correct memory write locations.

```diff
 offs_d_rope = tl.arange(0, BLOCK_DMODEL_ROPE)

-dest_index = tl.load(Dest_loc + cur_index)
+dest_index = tl.load(Dest_loc + cur_index).to(tl.int64)
```


critical

Casting dest_index to tl.int64 avoids potential integer overflows with large token buffers, ensuring correct memory writes.
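For context, here is a minimal runnable sketch of this destination-index copy pattern; the kernel name, shapes, and strides are assumptions for illustration (the real kernel also handles a rope slice via offs_d_rope), not the repository's actual code:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _copy_kv_kernel(K, Dest_loc, Out, stride_k_bs, stride_o_bs, BLOCK_D: tl.constexpr):
    cur_index = tl.program_id(0)
    offs_d = tl.arange(0, BLOCK_D)
    # The reviewed change: cast the scattered destination index to int64 so
    # that dest_index * stride_o_bs cannot wrap for very large KV buffers.
    dest_index = tl.load(Dest_loc + cur_index).to(tl.int64)
    k = tl.load(K + cur_index * stride_k_bs + offs_d)
    tl.store(Out + dest_index * stride_o_bs + offs_d, k)


# Usage: scatter 16 freshly computed token vectors into a larger KV buffer.
k_new = torch.randn(16, 128, device="cuda")
kv_buffer = torch.zeros(1024, 128, device="cuda")
dest = torch.randperm(1024, device="cuda")[:16].to(torch.int32)
_copy_kv_kernel[(16,)](k_new, dest, kv_buffer, k_new.stride(0), kv_buffer.stride(0), BLOCK_D=128)
```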

```diff
     mask=(start_n + offs_n) < block_end_loc,
     other=0,
-)
+).to(tl.int64)
```


critical

Casting to tl.int64 prevents integer overflow, ensuring correct calculation and preventing potential illegal memory access.

```diff
     mask=(start_n + offs_n) < cur_batch_seq_len,
     other=0,
-)
+).to(tl.int64)
```


critical

This cast to tl.int64 prevents overflow and ensures the subsequent memory access for k is valid.

```diff
     mask=cur_chunk_mask,
     other=0.0,
-)
+).to(tl.int64)
```


critical

Casting cur_kv_loc to tl.int64 avoids illegal memory access with large token buffers, ensuring correct calculations.

```diff
 kv_loc = tl.load(
     req_to_token_indexs + cur_batch_req_idx * stride_req_to_tokens_b + offs_kv_loc, mask=offs_kv_loc < cur_seq_len
-)
+).to(tl.int64)
```


critical

Casting kv_loc to tl.int64 handles large token index values, preventing potential integer overflow and ensuring correct memory access.

```diff
     mask=(start_n + offs_n) < block_end_loc,
     other=0,
-)
+).to(tl.int64)
```


critical

This cast to tl.int64 prevents illegal memory access when max_total_token_num is large, ensuring memory safety.

@shihaobai merged commit 5b3e319 into main on Aug 12, 2025 (1 check passed).
@shihaobai deleted the fix_quant branch on August 12, 2025 at 08:50.