fix mm slicer #1104
Conversation
Summary of Changes (from gemini-code-assist): Hello @hiworldwzj, I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed: this pull request addresses a bug in the handling of AWQ (Activation-aware Weight Quantization) weights within the tensor-parallel slicing logic.
Code Review
This pull request refactors the tensor-parallel slicing logic, particularly to add support for AWQ quantized weights which have a different memory layout. The changes simplify bias handling by removing the bias_div_world_size flag, which is a good improvement. However, I've found two critical issues where the new AWQ slicer classes are swapped between the column-parallel and row-parallel weight implementations, which will cause the model to fail. Additionally, the naming convention for all slicer classes is counter-intuitive and is likely the root cause of these bugs; I've suggested a refactoring to improve clarity and maintainability.
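To make the layout issue concrete before diving into the diffs, here is a minimal sketch (shapes and variable names are illustrative, not code from this PR) of how the slicing dimension flips between the standard `out x in` layout and AWQ's `in x out` layout for a column-parallel shard:

```python
import torch

tp_rank, tp_world_size = 0, 2
out_features, in_features = 8, 4

# Standard layout: weight is (out_features, in_features).
# Column-parallelism splits the out dimension -> slice dim 0.
w_std = torch.randn(out_features, in_features)
shard = out_features // tp_world_size
w_std_col = w_std[shard * tp_rank : shard * (tp_rank + 1), :]   # shape (4, 4)

# AWQ layout: the weight is stored transposed, (in_features, out_features).
# The same column-parallel split now lands on dim 1.
w_awq = torch.randn(in_features, out_features)
w_awq_col = w_awq[:, shard * tp_rank : shard * (tp_rank + 1)]   # shape (4, 4)
```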
In `colmm_weight.py`:

```diff
-        self.param_slicer = QuantizedRowSliceMixin(
-            tp_rank=tp_rank, tp_world_size=tp_world_size, bias_div_world_size=True
-        )
+        self.param_slicer = AwqQuantizedColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size)
```
There appears to be a mix-up in the slicer class used here. `AWQCOLMMWeight` implements a column-parallel linear layer. For AWQ's `in x out` weight layout, this requires slicing along dimension 1 (the out dimension). The correct slicer for this is `AwqQuantizedRowSliceMixin`, which correctly handles slicing for weights, scales, zero-points, and bias for column-parallelism with AWQ weights.

However, `AwqQuantizedColSliceMixin` is used, which is designed for row-parallelism. This will result in incorrect tensor slicing and will break the model. You'll also need to update the import on line 11 to bring in `AwqQuantizedRowSliceMixin`.
Suggested change:

```diff
-        self.param_slicer = AwqQuantizedColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size)
+        self.param_slicer = AwqQuantizedRowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size)
```
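For reference, a hedged sketch of what a column-parallel shard looks like across all three AWQ tensors. The `pack_factor` and `group_size` values and the exact packed shapes follow the common AWQ GEMM convention and are assumptions, not taken from this repository:

```python
import torch

tp_rank, tp_world_size = 1, 2
in_features, out_features = 64, 128
group_size, pack_factor = 32, 8  # assumed AWQ GEMM packing parameters

# Assumed AWQ GEMM shapes: all three tensors carry the out dimension along dim 1.
qweight = torch.zeros(in_features, out_features // pack_factor, dtype=torch.int32)
scales = torch.zeros(in_features // group_size, out_features, dtype=torch.float16)
qzeros = torch.zeros(in_features // group_size, out_features // pack_factor, dtype=torch.int32)

def dim1_shard(t: torch.Tensor) -> torch.Tensor:
    # Column-parallelism splits the out dimension, which is dim 1 in AWQ's layout.
    shard = t.shape[1] // tp_world_size
    return t[:, shard * tp_rank : shard * (tp_rank + 1)]

qweight_tp, scales_tp, qzeros_tp = map(dim1_shard, (qweight, scales, qzeros))
# Per-rank shapes: (64, 8), (2, 64), (2, 8)
```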
In `rowmm_weight.py`:

```diff
-            tp_rank=tp_rank, tp_world_size=tp_world_size, bias_div_world_size=False
-        )
+        self.param_slicer = AwqQuantizedRowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size)
```
There appears to be a mix-up in the slicer class used here. `AWQROWMMWeight` implements a row-parallel linear layer. For AWQ's `in x out` weight layout, this requires slicing along dimension 0 (the in dimension). The correct slicer for this is `AwqQuantizedColSliceMixin`, which correctly handles slicing for weights, scales, zero-points, and bias for row-parallelism with AWQ weights.

However, `AwqQuantizedRowSliceMixin` is used, which is designed for column-parallelism. This will result in incorrect tensor slicing and will break the model. You'll also need to update the import on line 12 to bring in `AwqQuantizedColSliceMixin`.
Suggested change:

```diff
-        self.param_slicer = AwqQuantizedRowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size)
+        self.param_slicer = AwqQuantizedColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size)
```
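To see why the dim-0 slice (together with the divided bias in `_slice_bias` in the mixin code below) is the right combination for row-parallelism, here is a self-contained sketch using float stand-ins for the dequantized weight. All names, shapes, and values here are illustrative, not from the PR:

```python
import torch

torch.manual_seed(0)
tp_world_size = 2
in_features, out_features = 8, 6

# Hypothetical float stand-ins for a dequantized AWQ weight (in x out layout).
x = torch.randn(3, in_features)
w = torch.randn(in_features, out_features)
bias = torch.randn(out_features)

full = x @ w + bias

# Row-parallelism: shard the in dimension (dim 0 of w, dim 1 of x).
shard = in_features // tp_world_size
partials = []
for rank in range(tp_world_size):
    x_r = x[:, shard * rank : shard * (rank + 1)]
    w_r = w[shard * rank : shard * (rank + 1), :]
    # Each rank adds bias / world_size so the all-reduce sum adds it exactly once.
    partials.append(x_r @ w_r + bias / tp_world_size)

reduced = torch.stack(partials).sum(dim=0)  # stands in for the all-reduce
assert torch.allclose(reduced, full, atol=1e-5)
```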
```python
# AWQ quantized weights are stored in an in x out layout, so slicing needs a custom implementation.
class AwqQuantizedRowSliceMixin(QuantizedColSliceMixin):
    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
        super().__init__(tp_rank, tp_world_size)

    def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        assert bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}"
        tp_size = bias.shape[0] // self.tp_world_size_
        return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]


class AwqQuantizedColSliceMixin(QuantizedRowSliceMixin):
    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
        super().__init__(tp_rank, tp_world_size)

    def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        return bias / self.tp_world_size_
```
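One subtlety worth calling out in the code above: `AwqQuantizedColSliceMixin._slice_bias` divides the bias by the world size instead of slicing it. That is consistent with row-parallelism, where every rank produces a partial sum over its shard of the in dimension and the outputs are then all-reduced; if each rank adds `bias / tp_world_size_`, the reduced result carries the full bias exactly once, as the verification sketch earlier in this review demonstrates. This reading is inferred from the code shown here, not stated in the PR.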
The naming of these new AWQ slicer classes, and the base slicer classes they inherit from, is counter-intuitive and has likely led to the bugs in `colmm_weight.py` and `rowmm_weight.py`.

- `AwqQuantizedRowSliceMixin` is for column-parallelism.
- `AwqQuantizedColSliceMixin` is for row-parallelism.

This is because they are named based on the dimension they slice for a standard `out x in` weight, but for AWQ's `in x out` weights, the slicing dimension is swapped for a given parallelism type. This makes the code very difficult to reason about.

To improve clarity and prevent future bugs, I strongly recommend renaming all slicer classes to reflect the parallelism strategy they implement, rather than the dimension they slice. For example (see the sketch after this list):

- `RowSliceMixin` -> `ColumnParallelSliceMixin`
- `ColSliceMixin` -> `RowParallelSliceMixin`
- `AwqQuantizedRowSliceMixin` -> `AwqColumnParallelSliceMixin`
- `AwqQuantizedColSliceMixin` -> `AwqRowParallelSliceMixin`
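A minimal, names-only sketch of how the suggested renaming might read (bodies elided; this is an illustration of the naming scheme, not existing code in the repository):

```python
class ColumnParallelSliceMixin:     # was: RowSliceMixin (slices dim 0 of an out x in weight)
    ...

class RowParallelSliceMixin:        # was: ColSliceMixin (slices dim 1 of an out x in weight)
    ...

class AwqColumnParallelSliceMixin:  # was: AwqQuantizedRowSliceMixin (slices dim 1 of an in x out weight)
    ...

class AwqRowParallelSliceMixin:     # was: AwqQuantizedColSliceMixin (slices dim 0 of an in x out weight)
    ...
```

With these names, the class used at each call site can be checked against the parallelism of the surrounding layer at a glance, rather than by reasoning about which dimension happens to be sliced in which storage layout.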