fix mm slicer #1104
Changes to `mm_slicer.py`:

```diff
@@ -7,10 +7,9 @@
 class SliceMixinBase(ABC):
     """Base mixin class for slicing operations."""
 
-    def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = False):
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
         self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp()
         self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size()
-        self.bias_div_world_size_ = bias_div_world_size
 
     @abstractmethod
     def _slice_weight(self, weight: torch.Tensor):
@@ -22,8 +21,8 @@ def _slice_bias(self, bias):
 
 
 class SliceMixinTpl(SliceMixinBase):
-    def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = False):
-        super().__init__(tp_rank, tp_world_size, bias_div_world_size)
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
+        super().__init__(tp_rank, tp_world_size)
 
     def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError("slice_weight must implement this method")
@@ -41,27 +40,25 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor
 # By default the weight shape is out x in, which is currently the most common convention.
 # So row-wise slicing is along dim=0 and col-wise slicing is along dim=1.
 class RowSliceMixin(SliceMixinTpl):
-    def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = False):
-        super().__init__(tp_rank, tp_world_size, bias_div_world_size)
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
+        super().__init__(tp_rank, tp_world_size)
 
     def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
         assert weight.shape[0] % self.tp_world_size_ == 0, f"tp slice error {weight.shape[0]} % {self.tp_world_size_}"
         tp_size = weight.shape[0] // self.tp_world_size_
         return weight[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]
 
-    def _slice_bias(self, bias) -> torch.Tensor:
+    def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
         assert bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}"
         tp_size = bias.shape[0] // self.tp_world_size_
-        if self.bias_div_world_size_:
-            return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] / self.tp_world_size_
         return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]
 
 
 # Quantized slicing assumes group-wise quantization by default, so weight_scale and weight_zero_point have the same ndims as weight.
 # Per-tensor and per-channel quantization can be added later as needed.
 class QuantizedRowSliceMixin(RowSliceMixin):
-    def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = False):
-        super().__init__(tp_rank, tp_world_size, bias_div_world_size)
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
+        super().__init__(tp_rank, tp_world_size)
 
     def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor:
         assert (
@@ -83,25 +80,21 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor
 
 
 class ColSliceMixin(SliceMixinTpl):
-    def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = True):
-        super().__init__(tp_rank, tp_world_size, bias_div_world_size)
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
+        super().__init__(tp_rank, tp_world_size)
 
     def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
         assert weight.shape[1] % self.tp_world_size_ == 0, f"tp slice error {weight.shape[1]} % {self.tp_world_size_}"
         tp_size = weight.shape[1] // self.tp_world_size_
         return weight[:, tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]
 
-    def _slice_bias(self, bias) -> torch.Tensor:
-        assert bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}"
-        tp_size = bias.shape[0] // self.tp_world_size_
-        if self.bias_div_world_size_:
-            return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] / self.tp_world_size_
-        return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]
+    def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
+        return bias / self.tp_world_size_
 
 
 class QuantizedColSliceMixin(ColSliceMixin):
-    def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = True):
-        super().__init__(tp_rank, tp_world_size, bias_div_world_size)
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
+        super().__init__(tp_rank, tp_world_size)
 
     def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor:
         assert (
@@ -120,3 +113,22 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor
         zero_point_start = tp_size * self.tp_rank_
         zero_point_end = tp_size * (self.tp_rank_ + 1)
         return weight_zero_point[:, zero_point_start:zero_point_end]
+
+
+# AWQ quantized weights are stored in in x out format, so a custom implementation is needed.
+class AwqQuantizedRowSliceMixin(QuantizedColSliceMixin):
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
+        super().__init__(tp_rank, tp_world_size)
+
+    def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
+        assert bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}"
+        tp_size = bias.shape[0] // self.tp_world_size_
+        return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]
+
+
+class AwqQuantizedColSliceMixin(QuantizedRowSliceMixin):
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
+        super().__init__(tp_rank, tp_world_size)
+
+    def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
+        return bias / self.tp_world_size_
```
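The behavioral core of this hunk is that row-parallel layers (`ColSliceMixin`) now always divide the bias by the TP world size instead of gating it behind the removed `bias_div_world_size` flag. Below is a minimal sketch, not part of the PR, of why that division is correct; all shapes and names are illustrative and nothing is imported from lightllm:

```python
import torch

torch.manual_seed(0)
tp_world_size = 4
x = torch.randn(2, 8)        # full activations, in_features = 8
weight = torch.randn(16, 8)  # standard out x in layout
bias = torch.randn(16)

# Reference: the unsharded result.
ref = x @ weight.t() + bias

# Row-parallel: each rank holds a dim=1 slice of the weight and the matching
# slice of the input. Per-rank outputs are partial sums that get all-reduced,
# so a bias pre-divided by world size (ColSliceMixin._slice_bias) is added
# exactly once in the reduced result.
tp = weight.shape[1] // tp_world_size
partial_sum = sum(
    x[:, r * tp : (r + 1) * tp] @ weight[:, r * tp : (r + 1) * tp].t()
    + bias / tp_world_size
    for r in range(tp_world_size)
)
assert torch.allclose(ref, partial_sum, atol=1e-4)
```

By contrast, `RowSliceMixin` (column parallelism under this layout) keeps a plain chunk of the bias, since each rank owns a disjoint slice of the output features and no reduction happens.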
Comment on lines +118 to +134

Contributor:

The naming of these new AWQ slicer classes, and the base slicer classes they inherit from, is counter-intuitive and has likely led to these bugs. The classes are named after the dimension they slice for a standard `out x in` weight, not after the parallelism strategy they implement, which is why `AwqQuantizedRowSliceMixin` ends up inheriting from `QuantizedColSliceMixin` and vice versa. To improve clarity and prevent future bugs, I strongly recommend renaming all slicer classes to reflect the parallelism strategy they implement (for example, column-parallel vs. row-parallel) rather than the dimension they slice.
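To make the reviewer's point concrete, here is a hypothetical side-by-side of the two layouts. The class names in the comments refer to the mixins above; the tensors and sizes are made up for illustration:

```python
import torch

tp_world_size = 4
out_features, in_features = 64, 128

std_weight = torch.randn(out_features, in_features)  # standard: out x in
awq_weight = torch.randn(in_features, out_features)  # AWQ: in x out

# Column parallelism splits the output features:
#   standard layout -> slice dim 0 -> handled by "RowSliceMixin"
#   AWQ layout      -> slice dim 1 -> handled by "QuantizedColSliceMixin"
#                      (hence AwqQuantizedRowSliceMixin(QuantizedColSliceMixin))
assert std_weight.shape[0] == awq_weight.shape[1] == out_features

# Row parallelism splits the input features:
#   standard layout -> slice dim 1 -> handled by "ColSliceMixin"
#   AWQ layout      -> slice dim 0 -> handled by "QuantizedRowSliceMixin"
#                      (hence AwqQuantizedColSliceMixin(QuantizedRowSliceMixin))
assert std_weight.shape[1] == awq_weight.shape[0] == in_features
```

The same parallelism strategy thus maps to opposite "Row"/"Col" names depending on storage layout, which is exactly the confusion the reviewer wants the renaming to remove.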
And in the MM weight module that consumes the slicers (imported via `from .mm_slicer import ...`):

```diff
@@ -9,7 +9,7 @@
 from lightllm.utils.dist_utils import get_current_device_id
 from lightllm.common.quantization.quantize_method import QuantizationMethod
 from typing import Dict, List, Optional, Union
-from .mm_slicer import RowSliceMixin, QuantizedRowSliceMixin, QuantizedColSliceMixin
+from .mm_slicer import RowSliceMixin, QuantizedRowSliceMixin, AwqQuantizedRowSliceMixin
 
 
 class StandardROWMMWeight(MMWeightTpl):
@@ -94,10 +94,8 @@ def __init__(
             tp_rank=tp_rank,
             tp_world_size=tp_world_size,
         )
-        # Note: this is not a mistake, because AWQ weights are stored as in x out.
-        self.param_slicer = QuantizedColSliceMixin(
-            tp_rank=tp_rank, tp_world_size=tp_world_size, bias_div_world_size=False
-        )
+
+        self.param_slicer = AwqQuantizedRowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size)
 
 
 class AWQMARLINROWMMWeight(AWQROWMMWeight):
```
Contributor:

There appears to be a mix-up in the slicer class used here. `AWQCOLMMWeight` implements a column-parallel linear layer. For AWQ's `in x out` weight layout, this requires slicing along dimension 1 (the `out` dimension). The correct slicer for this is `AwqQuantizedRowSliceMixin`, which correctly handles slicing for weights, scales, zero-points, and bias for column-parallelism with AWQ weights. However, `AwqQuantizedColSliceMixin` is used, which is designed for row-parallelism. This will result in incorrect tensor slicing and will break the model. You'll also need to update the import on line 11 to bring in `AwqQuantizedRowSliceMixin`.
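A small hypothetical shape check of the mix-up the reviewer describes; the tensors are illustrative only, and the mixin names in the comments refer to `mm_slicer.py`:

```python
import torch

tp_world_size, tp_rank = 4, 0
in_features, out_features = 128, 64
awq_weight = torch.randn(in_features, out_features)  # AWQ layout: in x out

# Column-parallel (what AWQCOLMMWeight needs): slice the out dimension, dim 1.
# This is what AwqQuantizedRowSliceMixin inherits from QuantizedColSliceMixin.
shard = out_features // tp_world_size
ok = awq_weight[:, shard * tp_rank : shard * (tp_rank + 1)]  # shape (128, 16)

# Row-parallel slicing (AwqQuantizedColSliceMixin) would instead split dim 0,
# the in dimension, yielding a (32, 64) shard. A rank holding the full input
# activations of shape (batch, 128) cannot multiply against that shard, so the
# layer fails (or silently computes the wrong thing with a divided bias).
shard_in = in_features // tp_world_size
wrong = awq_weight[shard_in * tp_rank : shard_in * (tp_rank + 1)]  # (32, 64)
```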