Conversation
fuheaven commented on Dec 15, 2025
- rename dcu to hygon_dcu
- fix flash attention bug
Summary of Changes (Gemini Code Assist): This pull request primarily focuses on enhancing the platform's support for Hygon DCU devices. It systematically renames the platform from `dcu` to `hygon_dcu` and updates the flash attention implementation.
Code Review
This pull request primarily renames the dcu platform to hygon_dcu and updates the flash attention implementation for it. While the renaming changes are applied consistently, I've identified several critical issues with the new flash attention logic. The implementation incorrectly handles variable-length sequences, which will cause runtime errors in both the primary execution path and the SDPA fallback. Additionally, it appears that support for matrix multiplication (mm) operations for the hygon_dcu platform may have been unintentionally removed during the refactoring. These issues should be addressed before merging.
```python
# Reshape to [B*max_seqlen_q, num_heads * head_dim]
bs = cu_seqlens_q.shape[0] - 1
```
The output tensor from flash_attn_varlen_func is being reshaped incorrectly. flash_attn_varlen_func returns a packed tensor of shape (total_tokens, num_heads, head_dim), where total_tokens is the sum of sequence lengths in the batch. Reshaping it to (bs * max_seqlen_q, -1) is only correct if all sequences have length max_seqlen_q (i.e., no padding). For variable length sequences, this will raise a RuntimeError because the number of elements will not match.
The output should be reshaped based on the total number of tokens, which is the first dimension of the output tensor.
Suggested change:

```diff
- bs = cu_seqlens_q.shape[0] - 1
+ output = output.reshape(-1, output.shape[-2] * output.shape[-1])
```
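The shape mismatch is easy to demonstrate. A minimal sketch (the sequence lengths, head counts, and tensor names here are illustrative, not taken from the PR):

```python
import torch

# Hypothetical variable-length batch: two sequences of length 3 and 5.
cu_seqlens_q = torch.tensor([0, 3, 8])
max_seqlen_q = 5
num_heads, head_dim = 4, 16

total_tokens = int(cu_seqlens_q[-1])        # 8 packed tokens
output = torch.randn(total_tokens, num_heads, head_dim)

bs = cu_seqlens_q.shape[0] - 1              # 2 sequences
# bs * max_seqlen_q == 10 != total_tokens == 8, so this would raise a
# RuntimeError:
#   output.reshape(bs * max_seqlen_q, -1)
# Deriving the leading dimension from the tensor itself works for any packing:
flat = output.reshape(-1, output.shape[-2] * output.shape[-1])
assert flat.shape == (total_tokens, num_heads * head_dim)
```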
```python
# Reshape q, k, v to [B, L, Nq, C]
q = q.reshape(bs, max_seqlen_q, q.shape[-2], q.shape[-1])
k = k.reshape(bs, max_seqlen_q, k.shape[-2], k.shape[-1])
```
The reshape operation here is incorrect for variable-length sequences. The input tensors q, k, and v are packed (flattened), with a shape like (total_tokens, num_heads, head_dim). Reshaping them to (bs, max_seqlen_q, ...) will fail if there's any padding in the batch, because total_tokens will not be equal to bs * max_seqlen_q.
To use torch.nn.functional.scaled_dot_product_attention, you must first convert the packed tensors into padded, batched tensors. This requires more than a simple reshape. After the attention call, the resulting padded tensor must then be converted back to a packed format. The current implementation of this fallback function is broken for variable-length inputs.
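A correct fallback needs an explicit pad/unpad step around the SDPA call. The sketch below shows one way to do this, assuming packed `(total_tokens, num_heads, head_dim)` inputs and a `(bs + 1,)` cumulative-lengths tensor; the function name and signature are illustrative, not the PR's actual API:

```python
import torch
import torch.nn.functional as F

def sdpa_varlen_fallback(q, k, v, cu_seqlens, max_seqlen):
    """Hypothetical SDPA fallback for packed variable-length inputs."""
    bs = cu_seqlens.shape[0] - 1
    num_heads, head_dim = q.shape[-2], q.shape[-1]
    lens = cu_seqlens[1:] - cu_seqlens[:-1]              # per-sequence lengths (B,)

    # 1) Unpack: scatter each sequence into a zero-padded (B, H, L, D) tensor.
    def pad(x):
        out = x.new_zeros(bs, max_seqlen, num_heads, head_dim)
        for i in range(bs):
            s, e = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
            out[i, : e - s] = x[s:e]
        return out.transpose(1, 2)                       # (B, H, L, D)

    qp, kp, vp = pad(q), pad(k), pad(v)

    # 2) Mask padded key positions so they get zero attention weight.
    key_valid = torch.arange(max_seqlen, device=lens.device)[None, :] < lens[:, None]
    attn_mask = key_valid[:, None, None, :]              # (B, 1, 1, L), broadcasts

    out = F.scaled_dot_product_attention(qp, kp, vp, attn_mask=attn_mask)

    # 3) Re-pack: keep only the valid query rows, back to (total_tokens, H, D).
    out = out.transpose(1, 2)                            # (B, L, H, D)
    return torch.cat([out[i, : int(lens[i])] for i in range(bs)], dim=0)
```

The Python loop over the batch keeps the sketch readable; a production version would likely use an index-based scatter/gather instead.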
```diff
- from .mm.dcu import *
+ # Check if running on Hygon DCU platform
+ if os.getenv("PLATFORM") == "hygon_dcu":
+     from .attn.hygon_dcu import *
```
The refactoring from dcu to hygon_dcu appears to have removed the import for mm (matrix multiplication) operations. The previous dcu platform implementation imported from .mm.dcu, but this has been omitted for hygon_dcu.
If hygon_dcu is intended to support these mm operations, this change constitutes a feature regression. Was this removal intentional? If not, please ensure the necessary mm operations are implemented and imported for the hygon_dcu platform.
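If the removal was unintentional, the wiring could look like the fragment below (a sketch only; the `.mm.hygon_dcu` module path is an assumption — the actual renamed mm module, if any, may differ):

```python
import os

if os.getenv("PLATFORM") == "hygon_dcu":
    from .attn.hygon_dcu import *  # flash attention kernels
    from .mm.hygon_dcu import *    # mm kernels (hypothetical renamed path)
```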