# 5 线性层量化
在前面的过程中我们已经学习了对称量化和非对称量化，现在我们要尝试对模型中的线性层进行量化。

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import util.quant_tool as quant_tool

## 5.1 自定义量化模块

In [None]:
# Weight-only quantized drop-in replacement for nn.Linear (int8 weight, FP32 bias).
class QuantLinear(nn.Module):
    """
    Simplified weight-only quantized linear layer.

    - Only ``weight`` is quantized and stored as int8; ``bias`` stays FP32.
    - ``scale`` and ``zero_point`` are buffers of shape ``[out_features, 1]``, so
      they can hold per-output-channel values or one broadcast per-tensor value.
    - ``forward`` dequantizes the weight on the fly,
      ``w_hat = (q - zero_point) * scale``, and then calls ``F.linear``.

    Use :meth:`from_linear` to build an instance from an existing ``nn.Linear``;
    a freshly constructed layer has an uninitialized ``qweight``.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        """
        Allocate the quantized-weight buffers and the optional FP32 bias.

        Args:
            in_features: size of each input sample.
            out_features: size of each output sample.
            bias: when False the layer has no bias (``self.bias`` is None).
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # qweight: int8, shape [out_features, in_features]; contents are
        # undefined until from_linear() (or the caller) fills them in.
        self.register_buffer(
            "qweight",
            torch.empty(out_features, in_features, dtype=torch.int8),
        )
        # scale: shape [out_features, 1]
        self.register_buffer(
            "scale",
            torch.ones(out_features, 1, dtype=torch.float32),
        )
        # zero_point: shape [out_features, 1]; stays 0 for symmetric quantization.
        self.register_buffer(
            "zero_point",
            torch.zeros(out_features, 1, dtype=torch.float32),
        )

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features, dtype=torch.float32))
        else:
            # Register the missing bias explicitly, the way nn.Linear does.
            self.register_parameter("bias", None)

    @staticmethod
    def _as_column(value) -> torch.Tensor:
        """Return ``value`` as an FP32 tensor of shape ``[N, 1]``.

        A scalar becomes ``[1, 1]`` and broadcasts into the ``[out_features, 1]``
        buffers; a 1-D per-channel vector becomes ``[out_features, 1]``.
        """
        return torch.as_tensor(value, dtype=torch.float32).reshape(-1, 1)

    @classmethod
    def from_linear(cls, linear: nn.Linear, per_channel: bool=False, is_symmetric: bool=True) -> "QuantLinear":
        """
        Build a QuantLinear from an ``nn.Linear`` and quantize its weight.

        Args:
            linear: source layer; it is read, not modified.
            per_channel: forwarded to the ``quant_tool`` qparams helper.
            is_symmetric: True uses ``quant_tool.get_symmetric_qparams``,
                False uses ``quant_tool.get_asymmetric_qparams``.

        Returns:
            A new QuantLinear holding the quantized weight, its scale and
            zero point, and an FP32 copy of the bias (if any). All tensors
            are created on the CPU.
        """
        qlinear = cls(
            in_features=linear.in_features,
            out_features=linear.out_features,
            bias=linear.bias is not None,
        )

        with torch.no_grad():
            # Quantization parameters are computed on an FP32 CPU copy of the weight.
            weight = linear.weight.detach().float().cpu()

            if is_symmetric:
                get_qparams = quant_tool.get_symmetric_qparams
            else:
                get_qparams = quant_tool.get_asymmetric_qparams
            qparams = get_qparams(weight, per_channel)
            # NOTE(review): the qweight buffer is int8; this assumes
            # quant_tool.quantize_tensor returns values in the int8 range for
            # the asymmetric mode as well -- confirm against quant_tool.
            qweight = quant_tool.quantize_tensor(weight, qparams)

            # Store the quantized weight together with its quantization parameters.
            qlinear.qweight.copy_(qweight)
            qlinear.scale.copy_(cls._as_column(qparams.scale))
            qlinear.zero_point.copy_(cls._as_column(qparams.zero_point))

            # The bias is copied as-is (FP32), never quantized.
            if linear.bias is not None:
                qlinear.bias.copy_(linear.bias.detach().float().cpu())
        return qlinear

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Dequantize the weight and apply the linear transform to ``x``."""
        # w_hat = (q - z) * scale; qweight: [out, in], scale / zero_point: [out, 1].
        # Subtracting zero_point is a no-op for symmetric quantization (z == 0)
        # and is required for the asymmetric mode to reconstruct the weight.
        w_hat = (self.qweight.float() - self.zero_point) * self.scale
        return F.linear(x, w_hat, self.bias)

## 5.2 定义模型量化入口
输入一个模型，将其中的 nn.Linear 层原地替换为 QuantLinear，并返回量化后的模型（即传入的同一个对象）。

In [24]:
from typing import Optional, List
# Entry point for weight-only quantization of a whole model.
def quantize_model_weights(
    model: nn.Module,
    modules_to_exclude: Optional[List[str]] = None,  # optional: layers to leave unquantized
) -> nn.Module:
    """
    Recursively replace every ``nn.Linear`` in ``model`` with a ``QuantLinear``.

    The model is modified in place (layers are swapped with ``setattr``) and
    the same object is returned.

    Args:
        model: root module to walk.
        modules_to_exclude: names of ``nn.Linear`` layers to keep unquantized.
            A name matches either the layer's attribute name on its direct
            parent (e.g. ``"fc"``, matched at any depth) or its dotted path
            from ``model`` (e.g. ``"block.fc"``).

    Returns:
        ``model`` itself, with the non-excluded linear layers replaced.
    """
    if modules_to_exclude is None:
        excluded = frozenset()
    elif isinstance(modules_to_exclude, str):
        # A bare string would otherwise be treated as a collection of characters.
        excluded = frozenset((modules_to_exclude,))
    else:
        excluded = frozenset(modules_to_exclude)

    _replace_linear_children(model, excluded, prefix="")
    return model


def _replace_linear_children(module: nn.Module, excluded: frozenset, prefix: str) -> None:
    """Swap the non-excluded nn.Linear children of ``module``, recursing into the rest."""
    # Snapshot the children first: setattr below mutates the module while we iterate.
    for name, child in list(module.named_children()):
        qualified_name = f"{prefix}.{name}" if prefix else name

        if isinstance(child, nn.Linear):
            if name in excluded or qualified_name in excluded:
                continue
            setattr(module, name, QuantLinear.from_linear(child))
        else:
            _replace_linear_children(child, excluded, qualified_name)

## 5.3 对模型进行量化

In [44]:
# A small stack of linear layers used to try out the quantization entry point.
class FourLayerModel(nn.Module):
    """Four ``nn.Linear`` layers (``layer1`` .. ``layer4``) with ReLU after the first three."""

    def __init__(self, input_size=64, hidden_size1=64, hidden_size2=128, hidden_size3=128, output_size=256):
        super().__init__()
        # Consecutive width pairs give each layer's (in_features, out_features);
        # layers are created in order so they register as layer1 .. layer4.
        widths = (input_size, hidden_size1, hidden_size2, hidden_size3, output_size)
        for index, (fan_in, fan_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"layer{index}", nn.Linear(fan_in, fan_out))

    def forward(self, x):
        hidden = x
        # The layers are looked up on every call, so replaced (e.g. quantized)
        # layers are picked up automatically.
        for stage in (self.layer1, self.layer2, self.layer3):
            hidden = torch.relu(stage(hidden))
        # No activation after the final layer.
        return self.layer4(hidden)

In [45]:
# Build the FP32 reference model with the default layer sizes.
model = FourLayerModel()
# Bare expression: the notebook displays the module structure as the cell output.
model

FourLayerModel(
  (layer1): Linear(in_features=64, out_features=64, bias=True)
  (layer2): Linear(in_features=64, out_features=128, bias=True)
  (layer3): Linear(in_features=128, out_features=128, bias=True)
  (layer4): Linear(in_features=128, out_features=256, bias=True)
)

In [48]:
import copy
# Work on a deep copy: quantize_model_weights swaps layers on the module it is given.
base_model = copy.deepcopy(model)
# Quantize the model without excluding any layer.
q_model = quantize_model_weights(base_model)
q_model

FourLayerModel(
  (layer1): QuantLinear()
  (layer2): QuantLinear()
  (layer3): QuantLinear()
  (layer4): QuantLinear()
)

In [49]:
# Quantize a fresh copy of the model, leaving 'layer1' as a plain nn.Linear.
base_model = copy.deepcopy(model)
q_model = quantize_model_weights(model=base_model,modules_to_exclude=['layer1'])
q_model

FourLayerModel(
  (layer1): Linear(in_features=64, out_features=64, bias=True)
  (layer2): QuantLinear()
  (layer3): QuantLinear()
  (layer4): QuantLinear()
)

## 5.4 加载模型

In [16]:
# Install pinned versions of transformers, accelerate and seaborn.
!pip install transformers==4.35.0 accelerate==0.26.1 seaborn==0.13.1


Collecting accelerate==0.26.1
  Downloading accelerate-0.26.1-py3-none-any.whl.metadata (18 kB)
Collecting seaborn==0.13.1
  Downloading seaborn-0.13.1-py3-none-any.whl.metadata (5.4 kB)
Collecting pandas>=1.2 (from seaborn==0.13.1)
  Using cached pandas-2.3.3-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (91 kB)
Collecting matplotlib!=3.6.1,>=3.4 (from seaborn==0.13.1)
  Using cached matplotlib-3.10.7-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (11 kB)
Collecting contourpy>=1.0.1 (from matplotlib!=3.6.1,>=3.4->seaborn==0.13.1)
  Using cached contourpy-1.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.5 kB)
Collecting cycler>=0.10 (from matplotlib!=3.6.1,>=3.4->seaborn==0.13.1)
  Using cached cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)
Collecting fonttools>=4.22.0 (from matplotlib!=3.6.1,>=3.4->seaborn==0.13.1)
  Using cached fonttools-4.61.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadat

In [17]:
import transformers

ImportError: cannot import name 'is_jieba_available' from 'transformers.utils.import_utils' (/home/lihao/.conda/envs/fquant/lib/python3.10/site-packages/transformers/utils/import_utils.py)