Conversation
Summary of Changes

Hello @WANDY666, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed: this pull request significantly enhances the tensor parallelism implementation within the …
Activity
Code Review
This pull request introduces support for repeating key and value heads in QKV weight loading, which is essential for models using Grouped-Query Attention (GQA) with tensor parallelism. The implementation correctly adjusts the logic for sharding heads. My review focuses on improving code maintainability. I've pointed out areas with significant code duplication that could be refactored, a method with non-obvious side effects, and a docstring that should be translated for consistency. Addressing these points will make the code cleaner and easier to maintain.
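To make the GQA sharding constraint concrete, here is a hedged sketch of KV-head repetition under tensor parallelism. The function name, shapes, and helper below are illustrative assumptions for this review, not lightllm's actual API.

```python
# Illustrative sketch of KV-head repetition for GQA under tensor parallelism.
# Names and shapes here are hypothetical, not lightllm's actual API.
import torch


def repeat_kv_heads(kv_weight: torch.Tensor, head_num: int, tp_world_size: int) -> torch.Tensor:
    """kv_weight: [head_num * head_dim, hidden]; returns the tensor with each
    KV head repeated so it can be sharded evenly across tp_world_size ranks."""
    if head_num % tp_world_size == 0:
        return kv_weight  # enough heads: shard directly
    assert tp_world_size % head_num == 0, "head counts must divide one another"
    repeat = tp_world_size // head_num
    head_dim = kv_weight.shape[0] // head_num
    heads = kv_weight.view(head_num, head_dim, -1)
    # repeat_interleave keeps copies of the same head adjacent, so each rank
    # can still take one contiguous head-sized slice after repetition.
    return heads.repeat_interleave(repeat, dim=0).reshape(head_num * repeat * head_dim, -1)


# e.g. 2 KV heads on 8 ranks -> each head copied 4 times, 1 head per rank
w = torch.randn(2 * 64, 1024)
w_rep = repeat_kv_heads(w, head_num=2, tp_world_size=8)
assert w_rep.shape[0] == 8 * 64
```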
lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py
```python
def _load_weight(
    self, param_name: Union[str, List[str]], weights: Dict[str, torch.Tensor], sub_child_index: int
) -> None:
    # Reuse the base-class logic, but select the slicer that matches sub_child_index.
    quanted_param_name = self.quanted_weight_names[sub_child_index]
    if quanted_param_name in weights:
        param_name = quanted_param_name
    if param_name in weights:
        slicer = self._get_param_slicer(sub_child_index)
        weight = slicer._slice_weight(weights[param_name])
        self.quant_method.load_weight(weight, self.mm_param_list[sub_child_index])
        return

def _load_bias(
    self, param_name: Union[str, List[str]], weights: Dict[str, torch.Tensor], sub_child_index: int
) -> None:
    if param_name in weights:
        slicer = self._get_param_slicer(sub_child_index)
        bias = slicer._slice_bias(weights[param_name])
        self.bias_list[sub_child_index].copy_(bias)
        self.bias_list[sub_child_index].load_ok = True
        return

def _load_weight_scale(
    self, param_name: Union[str, List[str]], weights: Dict[str, torch.Tensor], sub_child_index: int
) -> None:
    if param_name in weights:
        slicer = self._get_param_slicer(sub_child_index)
        weight_scale = slicer._slice_weight_scale(weights[param_name])
        self.quant_method.load_weight_scale(weight_scale, self.mm_param_list[sub_child_index])
        return

def _load_weight_zero_point(
    self, param_name: Union[str, List[str]], weights: Dict[str, torch.Tensor], sub_child_index: int
) -> None:
    if param_name in weights:
        slicer = self._get_param_slicer(sub_child_index)
        weight_zero_point = slicer._slice_weight_zero_point(weights[param_name])
        self.quant_method.load_weight_zero_point(weight_zero_point, self.mm_param_list[sub_child_index])
        return
```
These _load_* methods (_load_weight, _load_bias, _load_weight_scale, _load_weight_zero_point) are almost identical to their counterparts in the base class MMWeightTpl, with the only difference being how the slicer is obtained. This introduces significant code duplication and makes the code harder to maintain.
A better approach would be to modify the base class to use a method like _get_param_slicer, which this class can then override. This would follow the Template Method design pattern and avoid overriding all four loading methods.
If modifying the base class is not an option, the duplication within this class should still be addressed, perhaps with a private helper method.
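For illustration, a minimal, self-contained sketch of that Template Method refactor follows. The Slicer and PerChildSlicerWeight classes are simplified stand-ins invented for this example; only the names MMWeightTpl and _get_param_slicer come from the code under review.

```python
from typing import Dict, List

import torch


class Slicer:
    """Hypothetical stand-in for the real slicer classes: slices along dim 0."""

    def __init__(self, start: int, end: int):
        self.start, self.end = start, end

    def _slice_weight(self, t: torch.Tensor) -> torch.Tensor:
        return t[self.start : self.end]


class MMWeightTpl:
    """Base class owns the loading logic and asks a hook for the slicer."""

    def __init__(self, slicer: Slicer):
        self.param_slicer = slicer
        self.loaded: Dict[int, torch.Tensor] = {}

    def _get_param_slicer(self, sub_child_index: int) -> Slicer:
        # Template Method hook: the one point of variation. Subclasses
        # override this instead of duplicating every _load_* method.
        return self.param_slicer

    def _load_weight(self, param_name: str, weights: Dict[str, torch.Tensor], sub_child_index: int) -> None:
        if param_name in weights:
            slicer = self._get_param_slicer(sub_child_index)
            self.loaded[sub_child_index] = slicer._slice_weight(weights[param_name])


class PerChildSlicerWeight(MMWeightTpl):
    """Overrides only the hook; _load_weight (and friends) are inherited."""

    def __init__(self, slicers: List[Slicer]):
        super().__init__(slicers[0])
        self.param_slicers = slicers

    def _get_param_slicer(self, sub_child_index: int) -> Slicer:
        return self.param_slicers[sub_child_index]


weights = {"qkv.weight": torch.arange(8.0)}
w = PerChildSlicerWeight([Slicer(0, 4), Slicer(4, 8)])
w._load_weight("qkv.weight", weights, sub_child_index=1)  # uses the second slicer
assert torch.equal(w.loaded[1], torch.arange(4.0, 8.0))
```

Because the base class then owns all four _load_* bodies, a future fix to the loading logic lands in one place.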
```python
def _get_tp_padded_head_num(self, head_num: int):
    if head_num % self.tp_world_size_ == 0:
        return head_num // self.tp_world_size_
    elif self.tp_world_size_ % head_num == 0:
        self.kv_repeat_times = self.tp_world_size_ // head_num
        return self.kv_repeat_times * head_num // self.tp_world_size_
    else:
        raise ValueError(
            f"head_num must be divisible by tp_world_size_ or "
            f"tp_world_size_ must be divisible by head_num, "
            f"but found: {head_num} % {self.tp_world_size_}"
        )
```
This method has a side effect: it modifies self.kv_repeat_times on line 205. Methods with get in their name are generally expected to be free of side effects, so this can surprise future maintainers. Consider refactoring to make the side effect explicit: for example, move this logic into the __init__ method (as sketched below), or rename the method to something like _setup_kv_sharding_and_get_heads to make its behavior clearer.
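As a concrete illustration of the __init__ option, here is a minimal sketch under stated assumptions: the class name and the tp_padded_head_num attribute are hypothetical; only kv_repeat_times, tp_world_size_, and the divisibility checks mirror the code under review.

```python
class RowMMWeightSketch:
    """Hypothetical class showing the side effect made explicit in __init__."""

    def __init__(self, head_num: int, tp_world_size: int):
        self.tp_world_size_ = tp_world_size
        # kv_repeat_times is assigned once, visibly, rather than as a hidden
        # side effect of a _get_* method.
        self.kv_repeat_times = self._compute_kv_repeat_times(head_num)
        # Same result as the original _get_tp_padded_head_num in both branches.
        self.tp_padded_head_num = self.kv_repeat_times * head_num // self.tp_world_size_

    def _compute_kv_repeat_times(self, head_num: int) -> int:
        if head_num % self.tp_world_size_ == 0:
            return 1
        if self.tp_world_size_ % head_num == 0:
            return self.tp_world_size_ // head_num
        raise ValueError(
            f"head_num must be divisible by tp_world_size_ or vice versa, "
            f"but got head_num={head_num}, tp_world_size_={self.tp_world_size_}"
        )
```

With head_num=2 and tp_world_size=8 this yields kv_repeat_times=4 and one padded head per rank, matching the original method's return value.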
20f2e63 to c77c837