optimize ulysses-4090: balanced a2a and merged qkv #590
Conversation
Summary of Changes

Hello @STwangyingrui, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request focuses on enhancing the performance of the Ulysses-4090 attention mechanism in distributed settings. It introduces a more efficient communication strategy by consolidating QKV tensors and attention outputs into single communication calls, leveraging a newly implemented load-balanced all-to-all primitive. These changes are complemented by configuration updates that activate the optimized attention type and enable CPU offloading for various model sub-components, aiming for improved overall resource utilization and speed.
Code Review
This pull request introduces significant optimizations to the Ulysses attention mechanism, specifically for the ulysses-4090 variant. The key changes include implementing a balanced all-to-all communication strategy using a round-robin schedule, and merging Q, K, and V tensors (and their FP8 scales) to reduce the number of distributed communication calls. These changes should improve performance and reduce communication overhead.
My review includes a suggestion to further optimize the partner lookup within the new all-to-all implementation for better performance and a minor suggestion to improve code clarity by renaming a variable. Overall, the changes are well-implemented and a great improvement.
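For context, the round-robin schedule referred to above can be sketched with the classic circle method, in which one rank stays fixed and the rest rotate each round. This is a hypothetical standalone sketch of the scheduling idea — `generate_round_robin_rounds` and its signature are illustrative, not the PR's actual API (it assumes an even `world_size`):

```python
def generate_round_robin_rounds(world_size):
    """Circle-method round-robin: produces world_size - 1 rounds in which
    every rank is paired with exactly one distinct partner per round, so
    no rank sits idle while others communicate (assumes even world_size)."""
    teams = list(range(world_size))
    rounds = []
    for _ in range(world_size - 1):
        pairs = []
        for i in range(world_size // 2):
            a, b = teams[i], teams[world_size - 1 - i]
            pairs.append((min(a, b), max(a, b)))
        rounds.append(pairs)
        # Rotate every position except the first (the "fixed" team).
        teams = [teams[0]] + [teams[-1]] + teams[1:-1]
    return rounds

rounds = generate_round_robin_rounds(4)
# 3 rounds; across them, all 6 unordered pairs of {0, 1, 2, 3} appear exactly once.
```

Because every pair of ranks meets exactly once and every rank is busy in every round, the point-to-point traffic is evenly spread, which is what makes the all-to-all "balanced".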
```python
partner = None
is_smaller_in_pair = False

for smaller, larger in round_pairs:
    if cur_rank == smaller:
        partner = larger
        is_smaller_in_pair = True
        break
    elif cur_rank == larger:
        partner = smaller
        is_smaller_in_pair = False
        break
```
This block finds the communication partner by iterating through round_pairs, which has a time complexity of O(world_size). You can optimize this to an O(1) lookup by changing the data structure of self.rounds.
I recommend modifying generate_round_robin_pairs to store the schedule as a list of dictionaries. Each dictionary would map a rank to its (partner, is_smaller) role for that round.
Example change in generate_round_robin_pairs:
```python
# In generate_round_robin_pairs
...
self.rounds = []
teams = list(range(world_size))
for _ in range(world_size - 1):
    round_schedule = {}
    for i in range(world_size // 2):
        team1, team2 = teams[i], teams[world_size - 1 - i]
        smaller, larger = min(team1, team2), max(team1, team2)
        round_schedule[smaller] = (larger, True)
        round_schedule[larger] = (smaller, False)
    self.rounds.append(round_schedule)
    teams = [teams[0]] + [teams[-1]] + teams[1:-1]
```

With self.rounds being a list of dictionaries, this block can be simplified to a direct and more efficient lookup. Note that the print_pairing_schedule method would also need to be updated to work with the new dictionary structure.
```python
round_schedule = round_pairs
if cur_rank in round_schedule:
    partner, is_smaller_in_pair = round_schedule[cur_rank]
else:
    partner = None
    is_smaller_in_pair = False
```
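As a small usage sketch of the dictionary-based schedule (the schedule contents here are hand-built for illustration, not taken from the PR), the if/else lookup can even collapse into a single `dict.get` with a default:

```python
# Hand-built schedule for one round with world_size = 4 (illustrative only):
# each rank maps to (partner, is_smaller_in_pair).
round_schedule = {0: (3, True), 3: (0, False), 1: (2, True), 2: (1, False)}

cur_rank = 2
# dict.get gives the O(1) lookup and the None/False fallback in one step.
partner, is_smaller_in_pair = round_schedule.get(cur_rank, (None, False))
# For rank 2 in this round: partner is 1, is_smaller_in_pair is False.
```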
updated in new commit.
```python
recvq_req.wait()
recvk_req.wait()
recvv_req.wait()
gathered_qkv_byte_tensors = self.load_balanced_all_to_all(qkv_shards, seq_p_group)
```
In this else block (for non-FP8 communication), the name gathered_qkv_byte_tensors is misleading, as the tensors it holds are of qkv_dtype, not byte tensors. Renaming it to gathered_qkv_shards would make the code easier to understand and maintain. You would also need to update its usage on line 332.
```python
gathered_qkv_shards = self.load_balanced_all_to_all(qkv_shards, seq_p_group)
```
gathered_qkv_byte_tensors holds uint8 bytes rather than qkv_dtype tensors, so it should not be renamed here.
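The byte-level merging under discussion can be illustrated with a small stdlib-only sketch — the helper names `pack_as_bytes` and `unpack_bytes` are hypothetical, not the PR's code. The idea is that heterogeneous payloads (e.g. FP8 data plus their scales) can be flattened into one uint8-style buffer so a single communication call suffices, and the receiver splits and reinterprets the bytes afterwards:

```python
import struct

def pack_as_bytes(tensors):
    """Flatten several float32 'tensors' (plain lists of floats here) into one
    raw byte buffer, recording element counts so the receiver can split it."""
    lengths = [len(t) for t in tensors]
    payload = b"".join(struct.pack(f"{len(t)}f", *t) for t in tensors)
    return lengths, payload

def unpack_bytes(lengths, payload):
    """Inverse of pack_as_bytes: slice the buffer back into float lists."""
    out, offset = [], 0
    for n in lengths:
        out.append(list(struct.unpack_from(f"{n}f", payload, offset)))
        offset += n * 4  # 4 bytes per float32 element
    return out

q, k, v = [1.0, 2.0], [3.0], [4.0, 5.0, 6.0]
lengths, buf = pack_as_bytes([q, k, v])   # one buffer -> one collective call
restored = unpack_bytes(lengths, buf)     # split back into q, k, v on arrival
```

Since the merged buffer is raw bytes, naming it along the lines of `*_byte_tensors` reflects what it actually contains, which is the point the reply above is making.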
Co-authored-by: root <root@pt-de4c35727a1b4d1b9f27f422f06026ec-worker-0.pt-de4c35727a1b4d1b9f27f422f06026ec.ns-devsft-3460edd0.svc.cluster.local>
Co-authored-by: root <root@pt-9b2035a55fe647eeb007584b238e5077-worker-0.pt-9b2035a55fe647eeb007584b238e5077.ns-devsft-3460edd0.svc.cluster.local>
No description provided.