add sage3 sparse and fp4 communication #917
Conversation
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request primarily introduces and integrates Sparse SAGE Attention (spas_sage_attn3) and FP4 communication support. These enhancements aim to improve the efficiency and flexibility of distributed attention mechanisms, particularly for models that use sequence parallelism. The changes add new attention types, extend the communication protocol to handle FP4 quantization, and update the relevant model configurations and inference logic to take advantage of these new capabilities.
Code Review
This pull request introduces sparse attention (spas_sage_attn3) and FP4 communication capabilities, primarily affecting attention mechanisms and their configurations. The changes involve updating configuration files, adding new parameters to attention functions, and implementing the logic for FP4 quantization and dequantization within the ulysses_attn.py module. The addition of sageattn3_sparse_blackwell and related FP4 quantization/dequantization functions from sageattn3_sparse aligns with the PR's objective to enhance communication efficiency. However, some areas could benefit from improved clarity and consistency, particularly regarding the handling of FP4 communication in RingAttn and the use of magic numbers/strings in ulysses_attn.py and sage_attn.py.
```python
        torch.Tensor: the computed attention result
    """
    assert not enable_head_parallel, "RingAttn can't support head parallel mode."
    assert not use_fp4_comm, "RingAttn don't support use_fp4_comm now."
```
The assertion assert not use_fp4_comm, "RingAttn don't support use_fp4_comm now." indicates that RingAttn does not support FP4 communication. However, use_fp4_comm was added as a parameter in this PR. This creates a confusing situation where a feature is introduced but immediately disabled. If RingAttn truly cannot support FP4 communication, consider removing the use_fp4_comm parameter from its apply method to avoid misleading usage or potential dead code. Alternatively, if support is planned, this assert should be removed once implemented.
Ring and Ulysses need to use the same calling parameters.
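The author's point above (keeping RingAttn and UlyssesAttn call-compatible while one of them rejects FP4) can be sketched as follows. This is a minimal, hypothetical illustration; the class and method names mirror the PR's discussion but are not the actual implementation.

```python
# Hypothetical sketch: two attention backends share one call signature so
# callers can swap them freely; unsupported options fail fast with an
# assertion rather than being silently ignored.

class UlyssesAttn:
    @staticmethod
    def apply(q, k, v, use_fp4_comm=False, enable_head_parallel=False):
        # The Ulysses path could quantize q/k/v to FP4 before the
        # all-to-all exchange here (omitted in this sketch).
        return "ulysses", use_fp4_comm


class RingAttn:
    @staticmethod
    def apply(q, k, v, use_fp4_comm=False, enable_head_parallel=False):
        # Same signature as UlyssesAttn, but unsupported features are
        # rejected explicitly instead of producing wrong results.
        assert not enable_head_parallel, "RingAttn can't support head parallel mode."
        assert not use_fp4_comm, "RingAttn doesn't support use_fp4_comm yet."
        return "ring", use_fp4_comm


# A caller can dispatch on backend without branching on keyword sets:
out = RingAttn.apply(None, None, None)  # FP4 off: accepted
```

The trade-off is exactly what the review notes: the parameter exists but is dead for RingAttn until support lands, which the shared signature makes an intentional, documented choice rather than an accident.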
```python
        q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
    elif len(q.shape) == 4:
        bs = q.shape[0]
        x = (
```
The parameters per_block_mean = False and topk=0.2 are magic numbers/boolean literals. It would improve readability and maintainability to define these as named constants or configurable parameters, especially topk which represents a specific threshold.
Suggested change:

```diff
-        x = (
+        x = sageattn3_sparse_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=False, is_sparse=True, per_block_mean=False, topk=0.2).transpose(1, 2).reshape(bs * max_seqlen_q, -1)
```
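One way to act on the magic-number comment above is to lift the literals into named module-level constants behind a thin wrapper. This is an illustrative sketch only; the constant names, the wrapper, and passing the kernel as a parameter are all hypothetical, not part of the PR.

```python
# Illustrative refactor: name the sparse-attention tuning literals once so
# every call site reads and overrides them in one place.

SPARSE_ATTN_TOPK = 0.2          # fraction of key/value blocks the sparse kernel keeps
SPARSE_PER_BLOCK_MEAN = False   # whether block means are used when scoring blocks


def run_sparse_attention(q, k, v, attn_fn, *,
                         topk=SPARSE_ATTN_TOPK,
                         per_block_mean=SPARSE_PER_BLOCK_MEAN):
    """Thin wrapper so the tuning values appear exactly once.

    attn_fn is the sparse attention kernel (e.g. sageattn3_sparse_blackwell);
    injecting it keeps this sketch self-contained and testable.
    """
    return attn_fn(q, k, v, is_causal=False, is_sparse=True,
                   per_block_mean=per_block_mean, topk=topk)
```

A caller that needs a denser mask can then pass `topk=0.5` explicitly instead of editing a buried literal.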
```python
        img_qkv_quant = img_qkv_quant.reshape(world_size, img_qkv_len, shard_heads, 3, hidden_dims)
        img_qkv_scale = img_qkv_scale.reshape(world_size, img_qkv_len, shard_heads, 3, 1)
    else:
        img_qkv_quant, img_qkv_scale = quant_fp4_sage3(img_qkv.reshape(world_size, -1, shard_heads, hidden_dims))
```
The quant_fp4_sage3 call here is missing the in_tensor_layout and out_tensor_layout arguments, which were present in other quant_fp4_sage3 calls (e.g., lines 129 and 166-168). This inconsistency could lead to unexpected behavior or make the code harder to understand and maintain. Please ensure all calls to quant_fp4_sage3 are consistent with their arguments, or provide a clear reason for the difference.
Suggested change:

```diff
-img_qkv_quant, img_qkv_scale = quant_fp4_sage3(img_qkv.reshape(world_size, -1, shard_heads, hidden_dims))
+img_qkv_quant, img_qkv_scale = quant_fp4_sage3(img_qkv.reshape(world_size, -1, shard_heads, hidden_dims), in_tensor_layout="HND", out_tensor_layout="HND")
```
The default values of in_tensor_layout and out_tensor_layout are NHD.
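The author's reply resolves the reviewer's concern: omitting the layout arguments is equivalent to passing "NHD" explicitly, so the calls are consistent after all. A tiny stub makes the default-argument behavior concrete. Note this stub only mimics the argument handling described in the thread; it is not the real quantization kernel.

```python
# Stub mimicking quant_fp4_sage3's layout-argument defaults as described in
# the review thread ("NHD" by default). It echoes the layouts in effect
# instead of quantizing, purely to illustrate the default behavior.

def quant_fp4_sage3_stub(x, in_tensor_layout="NHD", out_tensor_layout="NHD"):
    # A real implementation would return (fp4_quantized_tensor, scale_tensor);
    # here we return the layouts so the defaults are visible.
    return (in_tensor_layout, out_tensor_layout)


# Omitting the arguments uses the NHD defaults...
assert quant_fp4_sage3_stub(None) == ("NHD", "NHD")
# ...so an explicit layout is only needed when a non-default one is wanted:
assert quant_fp4_sage3_stub(None, in_tensor_layout="HND", out_tensor_layout="HND") == ("HND", "HND")
```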
```python
        img_q_quant, img_q_scale = quant_fp4_sage3(img_q)
        img_k_quant, img_k_scale = quant_fp4_sage3(img_k)
        img_v_quant, img_v_scale = quant_fp4_sage3(img_v)
```
Similar to the previous comment, these quant_fp4_sage3 calls are missing the in_tensor_layout and out_tensor_layout arguments. Please ensure consistency across all quant_fp4_sage3 calls or document the reason for any intentional differences.
Suggested change:

```diff
-img_q_quant, img_q_scale = quant_fp4_sage3(img_q)
-img_k_quant, img_k_scale = quant_fp4_sage3(img_k)
-img_v_quant, img_v_scale = quant_fp4_sage3(img_v)
+img_q_quant, img_q_scale = quant_fp4_sage3(img_q, in_tensor_layout="HND", out_tensor_layout="HND")
+img_k_quant, img_k_scale = quant_fp4_sage3(img_k, in_tensor_layout="HND", out_tensor_layout="HND")
+img_v_quant, img_v_scale = quant_fp4_sage3(img_v, in_tensor_layout="HND", out_tensor_layout="HND")
```
The default values of in_tensor_layout and out_tensor_layout are NHD.
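Throughout this thread, quant_fp4_sage3 returns a (quantized, scale) pair, and both tensors are reshaped and communicated together. A generic toy sketch shows why the scale must travel alongside the quantized values: low-bit codes are meaningless without it. This toy quantizes to a small signed-integer grid, not real FP4, and everything here is illustrative rather than the PR's implementation.

```python
# Toy scale-accompanied quantization: map floats onto a small integer grid
# plus one shared scale, then recover approximate values. Illustrates the
# (quant, scale) pairing used by FP4 communication, not the real FP4 format.

def quantize_with_scale(values, levels=7):
    """Map floats onto integers in [-levels, levels] plus one shared scale."""
    scale = max(abs(v) for v in values) / levels or 1.0
    quant = [round(v / scale) for v in values]
    return quant, scale


def dequantize(quant, scale):
    # Without the matching scale, the integer codes cannot be interpreted,
    # which is why quant and scale tensors are communicated as a pair.
    return [q * scale for q in quant]


vals = [0.5, -1.0, 2.0, -0.25]
q, s = quantize_with_scale(vals)
recovered = dequantize(q, s)
# The roundtrip is lossy, but the per-element error is bounded by scale / 2.
```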
No description provided.