From 722e0e94a7f989461e33bad45d324cd7a534fb13 Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Thu, 13 Nov 2025 02:32:10 +0000
Subject: [PATCH] fix awq slicer

---
 .../meta_weights/mm_weight/colmm_weight.py |  6 +--
 .../meta_weights/mm_weight/mm_slicer.py    | 54 +++++++++++--------
 .../meta_weights/mm_weight/rowmm_weight.py |  8 ++-
 3 files changed, 38 insertions(+), 30 deletions(-)

diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py
index 09a08c687..281f30f02 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py
@@ -8,7 +8,7 @@
 from lightllm.utils.dist_utils import get_current_device_id
 from lightllm.common.quantization.quantize_method import QuantizationMethod
 from typing import Dict, List, Optional, Union
-from .mm_slicer import ColSliceMixin, QuantizedRowSliceMixin, QuantizedColSliceMixin
+from .mm_slicer import ColSliceMixin, QuantizedColSliceMixin, AwqQuantizedColSliceMixin
 
 
 class StandardCOLMMWeight(MMWeightTpl):
@@ -72,9 +72,7 @@ def __init__(
             tp_world_size=tp_world_size,
         )
         # Note: this is not a mistake, because awq weights are stored as in x out
-        self.param_slicer = QuantizedRowSliceMixin(
-            tp_rank=tp_rank, tp_world_size=tp_world_size, bias_div_world_size=True
-        )
+        self.param_slicer = AwqQuantizedColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size)
 
 
 class AWQMARLINCOLMMWeight(AWQCOLMMWeight):
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py
index 94f494d29..e3ef5b0ea 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py
@@ -7,10 +7,9 @@
 class SliceMixinBase(ABC):
     """Base mixin class for slicing operations"""
 
-    def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = False):
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
         self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp()
         self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size()
-        self.bias_div_world_size_ = bias_div_world_size
 
     @abstractmethod
     def _slice_weight(self, weight: torch.Tensor):
@@ -22,8 +21,8 @@ def _slice_bias(self, bias):
 
 
 class SliceMixinTpl(SliceMixinBase):
-    def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = False):
-        super().__init__(tp_rank, tp_world_size, bias_div_world_size)
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
+        super().__init__(tp_rank, tp_world_size)
 
     def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError("slice_weight must implement this method")
@@ -41,27 +40,25 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Ten
 # By default the weight shape is out x in, which is currently the most common convention.
 # So row-wise slicing splits along dim=0, and col-wise slicing splits along dim=1.
 class RowSliceMixin(SliceMixinTpl):
-    def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = False):
-        super().__init__(tp_rank, tp_world_size, bias_div_world_size)
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
+        super().__init__(tp_rank, tp_world_size)
 
     def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
         assert weight.shape[0] % self.tp_world_size_ == 0, f"tp slice error {weight.shape[0]} % {self.tp_world_size_}"
         tp_size = weight.shape[0] // self.tp_world_size_
         return weight[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]
 
-    def _slice_bias(self, bias) -> torch.Tensor:
+    def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
         assert bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}"
         tp_size = bias.shape[0] // self.tp_world_size_
-        if self.bias_div_world_size_:
-            return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] / self.tp_world_size_
         return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]
 
 
 # The default quantized slicing assumes group-wise quantization, so weight_scale and weight_zero_point have the same ndims as weight.
 # Per-tensor and per-channel quantization can be added later as needed.
 class QuantizedRowSliceMixin(RowSliceMixin):
-    def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = False):
-        super().__init__(tp_rank, tp_world_size, bias_div_world_size)
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
+        super().__init__(tp_rank, tp_world_size)
 
     def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor:
         assert (
@@ -83,25 +80,21 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Ten
 
 
 class ColSliceMixin(SliceMixinTpl):
-    def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = True):
-        super().__init__(tp_rank, tp_world_size, bias_div_world_size)
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
+        super().__init__(tp_rank, tp_world_size)
 
     def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
         assert weight.shape[1] % self.tp_world_size_ == 0, f"tp slice error {weight.shape[1]} % {self.tp_world_size_}"
         tp_size = weight.shape[1] // self.tp_world_size_
         return weight[:, tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]
 
-    def _slice_bias(self, bias) -> torch.Tensor:
-        assert bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}"
-        tp_size = bias.shape[0] // self.tp_world_size_
-        if self.bias_div_world_size_:
-            return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] / self.tp_world_size_
-        return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]
+    def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
+        return bias / self.tp_world_size_
 
 
 class QuantizedColSliceMixin(ColSliceMixin):
-    def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = True):
-        super().__init__(tp_rank, tp_world_size, bias_div_world_size)
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
+        super().__init__(tp_rank, tp_world_size)
 
     def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor:
         assert (
@@ -120,3 +113,22 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Ten
         zero_point_start = tp_size * self.tp_rank_
         zero_point_end = tp_size * (self.tp_rank_ + 1)
         return weight_zero_point[:, zero_point_start:zero_point_end]
+
+
+# AWQ quantized weights are stored as in x out, so a customized implementation is needed.
+class AwqQuantizedRowSliceMixin(QuantizedColSliceMixin):
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
+        super().__init__(tp_rank, tp_world_size)
+
+    def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
+        assert bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}"
+        tp_size = bias.shape[0] // self.tp_world_size_
+        return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]
+
+
+class AwqQuantizedColSliceMixin(QuantizedRowSliceMixin):
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
+        super().__init__(tp_rank, tp_world_size)
+
+    def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
+        return bias / self.tp_world_size_
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py
index 070b5105f..0eebdc74d 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py
@@ -9,7 +9,7 @@
 from lightllm.utils.dist_utils import get_current_device_id
 from lightllm.common.quantization.quantize_method import QuantizationMethod
 from typing import Dict, List, Optional, Union
-from .mm_slicer import RowSliceMixin, QuantizedRowSliceMixin, QuantizedColSliceMixin
+from .mm_slicer import RowSliceMixin, QuantizedRowSliceMixin, AwqQuantizedRowSliceMixin
 
 
 class StandardROWMMWeight(MMWeightTpl):
@@ -94,10 +94,8 @@ def __init__(
             tp_rank=tp_rank,
             tp_world_size=tp_world_size,
         )
-        # Note: this is not a mistake, because awq weights are stored as in x out
-        self.param_slicer = QuantizedColSliceMixin(
-            tp_rank=tp_rank, tp_world_size=tp_world_size, bias_div_world_size=False
-        )
+
+        self.param_slicer = AwqQuantizedRowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size)
 
 
 class AWQMARLINROWMMWeight(AWQROWMMWeight):
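
Below is a minimal, illustrative sketch (not part of the patch) of how the new AWQ slicers behave. It assumes the mixins are importable under the module path shown in the diff header, constructs them with explicit tp_rank/tp_world_size so no distributed setup should be needed, and uses made-up dummy shapes; the same tensors are reused for both cases purely to show the resulting shapes. Since AWQ stores qweight as (in, out) rather than the default (out, in), the row-parallel case must cut dim=1 and the column-parallel case dim=0, while bias handling follows the logical output dimension.

import torch

from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_slicer import (
    AwqQuantizedColSliceMixin,
    AwqQuantizedRowSliceMixin,
)

# Hypothetical dummy sizes: AWQ qweight is stored as (in, out), transposed
# w.r.t. the default (out, in) layout.
in_features, out_features = 128, 64
qweight = torch.randint(0, 16, (in_features, out_features), dtype=torch.int32)
bias = torch.randn(out_features)

# Row-parallel case (output dim split): the AWQ "row" slicer cuts dim=1 of the
# (in, out) qweight and slices the bias instead of dividing it.
row_slicer = AwqQuantizedRowSliceMixin(tp_rank=0, tp_world_size=2)
print(row_slicer._slice_weight(qweight).shape)  # torch.Size([128, 32])
print(row_slicer._slice_bias(bias).shape)       # torch.Size([32])

# Column-parallel case (input dim split): the AWQ "col" slicer cuts dim=0 of the
# (in, out) qweight and divides the bias by tp_world_size so the sum after the
# all-reduce stays correct.
col_slicer = AwqQuantizedColSliceMixin(tp_rank=0, tp_world_size=2)
print(col_slicer._slice_weight(qweight).shape)  # torch.Size([64, 64])
print(col_slicer._slice_bias(bias).shape)       # torch.Size([64])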