Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from lightllm.utils.dist_utils import get_current_device_id
from lightllm.common.quantization.quantize_method import QuantizationMethod
from typing import Dict, List, Optional, Union
from .mm_slicer import ColSliceMixin, QuantizedRowSliceMixin, QuantizedColSliceMixin
from .mm_slicer import ColSliceMixin, QuantizedColSliceMixin, AwqQuantizedColSliceMixin


class StandardCOLMMWeight(MMWeightTpl):
Expand Down Expand Up @@ -72,9 +72,7 @@ def __init__(
tp_world_size=tp_world_size,
)
# 注意这里不是错误,因为awq的weight是按inxout存的
self.param_slicer = QuantizedRowSliceMixin(
tp_rank=tp_rank, tp_world_size=tp_world_size, bias_div_world_size=True
)
self.param_slicer = AwqQuantizedColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

There appears to be a mix-up in the slicer class used here. AWQCOLMMWeight implements a column-parallel linear layer. For AWQ's in x out weight layout, this requires slicing along dimension 1 (the out dimension). The correct slicer for this is AwqQuantizedRowSliceMixin, which correctly handles slicing for weights, scales, zero-points, and bias for column-parallelism with AWQ weights.

However, AwqQuantizedColSliceMixin is used, which is designed for row-parallelism. This will result in incorrect tensor slicing and will break the model. You'll also need to update the import on line 11 to bring in AwqQuantizedRowSliceMixin.

Suggested change
self.param_slicer = AwqQuantizedColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size)
self.param_slicer = AwqQuantizedRowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size)



class AWQMARLINCOLMMWeight(AWQCOLMMWeight):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@
class SliceMixinBase(ABC):
"""切片操作的Mixin基类"""

def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = False):
def __init__(self, tp_rank: int = None, tp_world_size: int = None):
self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp()
self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size()
self.bias_div_world_size_ = bias_div_world_size

@abstractmethod
def _slice_weight(self, weight: torch.Tensor):
Expand All @@ -22,8 +21,8 @@ def _slice_bias(self, bias):


class SliceMixinTpl(SliceMixinBase):
def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = False):
super().__init__(tp_rank, tp_world_size, bias_div_world_size)
def __init__(self, tp_rank: int = None, tp_world_size: int = None):
super().__init__(tp_rank, tp_world_size)

def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
raise NotImplementedError("slice_weight must implement this method")
Expand All @@ -41,27 +40,25 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Ten
# 默认weight 的shape是 outxin,这也是目前最通用的约定。
# 所以row-wise是沿着dim=0进行切分,col-wise是沿着dim=1进行切分。
class RowSliceMixin(SliceMixinTpl):
def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = False):
super().__init__(tp_rank, tp_world_size, bias_div_world_size)
def __init__(self, tp_rank: int = None, tp_world_size: int = None):
super().__init__(tp_rank, tp_world_size)

def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
assert weight.shape[0] % self.tp_world_size_ == 0, f"tp slice error {weight.shape[0]} % {self.tp_world_size_}"
tp_size = weight.shape[0] // self.tp_world_size_
return weight[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]

def _slice_bias(self, bias) -> torch.Tensor:
def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
assert bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}"
tp_size = bias.shape[0] // self.tp_world_size_
if self.bias_div_world_size_:
return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] / self.tp_world_size_
return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]


# 量化切片默认实现方式是group-wise的量化,所以weight_scale 和weight_zero_point ndims跟weight一样。
# 后续按需要,扩展per-tensor、per-channel的量化方式。
class QuantizedRowSliceMixin(RowSliceMixin):
def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = False):
super().__init__(tp_rank, tp_world_size, bias_div_world_size)
def __init__(self, tp_rank: int = None, tp_world_size: int = None):
super().__init__(tp_rank, tp_world_size)

def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor:
assert (
Expand All @@ -83,25 +80,21 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Ten


class ColSliceMixin(SliceMixinTpl):
def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = True):
super().__init__(tp_rank, tp_world_size, bias_div_world_size)
def __init__(self, tp_rank: int = None, tp_world_size: int = None):
super().__init__(tp_rank, tp_world_size)

def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
assert weight.shape[1] % self.tp_world_size_ == 0, f"tp slice error {weight.shape[1]} % {self.tp_world_size_}"
tp_size = weight.shape[1] // self.tp_world_size_
return weight[:, tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]

def _slice_bias(self, bias) -> torch.Tensor:
assert bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}"
tp_size = bias.shape[0] // self.tp_world_size_
if self.bias_div_world_size_:
return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] / self.tp_world_size_
return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]
def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
return bias / self.tp_world_size_


class QuantizedColSliceMixin(ColSliceMixin):
def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = True):
super().__init__(tp_rank, tp_world_size, bias_div_world_size)
def __init__(self, tp_rank: int = None, tp_world_size: int = None):
super().__init__(tp_rank, tp_world_size)

def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor:
assert (
Expand All @@ -120,3 +113,22 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Ten
zero_point_start = tp_size * self.tp_rank_
zero_point_end = tp_size * (self.tp_rank_ + 1)
return weight_zero_point[:, zero_point_start:zero_point_end]


# AWQ stores its quantized weights in an (in x out) layout — the transpose of
# the usual (out x in) convention — so dedicated slicer subclasses are needed.
class AwqQuantizedRowSliceMixin(QuantizedColSliceMixin):
    """Slicer for AWQ (in x out) weights that shards the output dimension.

    Weight/scale/zero-point slicing along dim=1 is inherited unchanged from
    ``QuantizedColSliceMixin``; for AWQ's transposed layout dim=1 is the
    output dimension. Only bias handling is overridden: each rank keeps its
    own contiguous shard of the bias instead of a scaled full copy.
    """

    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
        super().__init__(tp_rank, tp_world_size)

    def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        # Partition the bias evenly across ranks along dim 0; each rank owns
        # the slice matching its shard of output features.
        assert bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}"
        shard = bias.shape[0] // self.tp_world_size_
        begin = shard * self.tp_rank_
        return bias[begin : begin + shard]


class AwqQuantizedColSliceMixin(QuantizedRowSliceMixin):
    """Slicer for AWQ (in x out) weights that shards the input dimension.

    Weight/scale/zero-point slicing along dim=0 is inherited unchanged from
    ``QuantizedRowSliceMixin``; for AWQ's transposed layout dim=0 is the
    input dimension. The bias is not partitioned: every rank holds the full
    bias divided by the world size (presumably so a subsequent reduce over
    partial outputs restores the bias exactly once — confirm with caller).
    """

    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
        super().__init__(tp_rank, tp_world_size)

    def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        # Scale rather than slice: each rank contributes 1/world_size of it.
        divisor = self.tp_world_size_
        return bias / divisor
Comment on lines +118 to +134
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The naming of these new AWQ slicer classes, and the base slicer classes they inherit from, is counter-intuitive and has likely led to the bugs in colmm_weight.py and rowmm_weight.py.

  • AwqQuantizedRowSliceMixin is for column-parallelism.
  • AwqQuantizedColSliceMixin is for row-parallelism.

This is because they are named based on the dimension they slice for a standard out x in weight, but for AWQ's in x out weights, the slicing dimension is swapped for a given parallelism type. This makes the code very difficult to reason about.

To improve clarity and prevent future bugs, I strongly recommend renaming all slicer classes to reflect the parallelism strategy they implement, rather than the dimension they slice. For example:

  • RowSliceMixin -> ColumnParallelSliceMixin
  • ColSliceMixin -> RowParallelSliceMixin
  • AwqQuantizedRowSliceMixin -> AwqColumnParallelSliceMixin
  • AwqQuantizedColSliceMixin -> AwqRowParallelSliceMixin

Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from lightllm.utils.dist_utils import get_current_device_id
from lightllm.common.quantization.quantize_method import QuantizationMethod
from typing import Dict, List, Optional, Union
from .mm_slicer import RowSliceMixin, QuantizedRowSliceMixin, QuantizedColSliceMixin
from .mm_slicer import RowSliceMixin, QuantizedRowSliceMixin, AwqQuantizedRowSliceMixin


class StandardROWMMWeight(MMWeightTpl):
Expand Down Expand Up @@ -94,10 +94,8 @@ def __init__(
tp_rank=tp_rank,
tp_world_size=tp_world_size,
)
# 注意这里不是错误,因为awq的weight是按inxout存的
self.param_slicer = QuantizedColSliceMixin(
tp_rank=tp_rank, tp_world_size=tp_world_size, bias_div_world_size=False
)

self.param_slicer = AwqQuantizedRowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

There appears to be a mix-up in the slicer class used here. AWQROWMMWeight implements a row-parallel linear layer. For AWQ's in x out weight layout, this requires slicing along dimension 0 (the in dimension). The correct slicer for this is AwqQuantizedColSliceMixin, which correctly handles slicing for weights, scales, zero-points, and bias for row-parallelism with AWQ weights.

However, AwqQuantizedRowSliceMixin is used, which is designed for column-parallelism. This will result in incorrect tensor slicing and will break the model. You'll also need to update the import on line 12 to bring in AwqQuantizedColSliceMixin.

Suggested change
self.param_slicer = AwqQuantizedRowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size)
self.param_slicer = AwqQuantizedColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size)



class AWQMARLINROWMMWeight(AWQROWMMWeight):
Expand Down