# 4 工具函数

定义量化相关的工具函数（量化参数计算、量化与反量化），方便后续复用。

In [4]:
import torch
from dataclasses import dataclass

In [None]:
# Container for the parameters produced by the quantization-parameter helpers
@dataclass
class QuantizationParams:
    """Parameters describing one quantization mapping.

    The quantize/dequantize helpers apply them as
    ``q = clip(round(x / scale + zero_point), q_min, q_max)`` and
    ``x_hat = (q - zero_point) * scale``; ``scale`` and ``zero_point`` are
    broadcast against the tensor being (de)quantized.
    """

    scale: torch.Tensor       # step size between adjacent quantized levels
    zero_point: torch.Tensor  # offset added before rounding; same shape as scale
    q_min: int                # smallest allowed quantized value
    q_max: int                # largest allowed quantized value

In [10]:
# Create a random 3x3 float32 tensor; the fixed seed makes the outputs below reproducible
torch.manual_seed(0)
x = torch.randn(3, 3, dtype=torch.float32)
x

tensor([[ 1.5410, -0.2934, -2.1788],
        [ 0.5684, -1.0845, -1.3986],
        [ 0.4033,  0.8380, -0.7193]])

## 4.1 对称量化

定义对称量化函数用于返回对称量化过程中的QuantizationParams。

In [12]:
# Symmetric quantization parameters: take a tensor, return a QuantizationParams,
# optionally computed per channel
def get_symmetric_qparams(x: torch.Tensor,
                          per_channel: bool = False,
                          channel_dim: int = 0,
                          dtype = torch.int8,
                          eps: float = 1e-8) -> "QuantizationParams":
    """Compute symmetric quantization parameters for ``x``.

    The quantization range is symmetric around zero:
    ``scale = max(|x|) / q_max`` and ``zero_point`` is always 0.

    Args:
        x: input floating-point tensor.
        per_channel: if True, compute one scale per slice along
            ``channel_dim``. The reduced axes are kept with size 1, so the
            result broadcasts against ``x``. Otherwise a single 0-d scale is
            returned.
        channel_dim: channel axis for per-channel mode. Negative values count
            from the end, as in PyTorch.
        dtype: integer dtype whose ``torch.iinfo`` range gives ``q_min``/``q_max``.
        eps: lower bound for ``scale``, so an all-zero input does not yield a
            zero scale (and a later division by zero).

    Returns:
        QuantizationParams holding ``scale``, an all-zero ``zero_point`` of
        the same shape, and the ``q_min``/``q_max`` of ``dtype``.

    Raises:
        IndexError: if ``per_channel`` is True and ``channel_dim`` is out of
            range for ``x``.
    """
    info = torch.iinfo(dtype)
    q_max, q_min = info.max, info.min

    if per_channel:
        ndim = x.dim()
        if not -ndim <= channel_dim < ndim:
            raise IndexError(
                f"channel_dim {channel_dim} is out of range for a {ndim}-d tensor")
        # Normalize negative axes: otherwise `d != channel_dim` never matches and
        # every axis is reduced, silently degrading to per-tensor parameters.
        channel_dim %= ndim
        reduce_dims = tuple(d for d in range(ndim) if d != channel_dim)
        abs_x = x.abs()
        # A 1-d tensor has no other axis to reduce: each element is its own channel.
        max_val = abs_x.amax(dim=reduce_dims, keepdim=True) if reduce_dims else abs_x
    else:
        max_val = x.abs().max()

    # max(q_max, 1) guards the divisor; clipping to eps avoids a zero scale
    # when x is all zeros.
    scale = torch.clip(max_val / max(q_max, 1), eps)

    # Symmetric quantization: zero_point is always 0.
    zero_point = torch.zeros_like(scale)
    return QuantizationParams(scale=scale, zero_point=zero_point, q_min=q_min, q_max=q_max)

In [17]:
# Per-tensor symmetric quantization parameters for x (x.shape == (3, 3))
x_sym_params = get_symmetric_qparams(x)
x_sym_params

QuantizationParams(scale=tensor(0.0172), zero_point=tensor(0.), q_min=-128, q_max=127)

In [18]:
# Per-channel symmetric parameters, using dim 0 as the channel axis: each row of x
# gets its own scale and zero_point (x.shape == (3, 3)).
# NOTE: this rebinds x_sym_params, so later cells that read it use these per-channel params.
x_sym_params = get_symmetric_qparams(x,True,0)
x_sym_params

QuantizationParams(scale=tensor([[0.0172],
        [0.0110],
        [0.0066]]), zero_point=tensor([[0.],
        [0.],
        [0.]]), q_min=-128, q_max=127)

## 4.2 非对称量化
定义非对称量化函数，用于返回非对称量化过程中的QuantizationParams。

In [19]:
# Asymmetric quantization parameters: take a tensor, return a QuantizationParams,
# optionally computed per channel
def get_asymmetric_qparams(x: torch.Tensor,
                            per_channel: bool = False,
                            channel_dim: int = 0,
                            dtype = torch.int8,
                            eps: float = 1e-8) -> QuantizationParams:
    """Compute asymmetric (affine) quantization parameters for ``x``.

    The observed range ``[x_min, x_max]`` is first extended to include 0.
    It is then mapped onto ``[q_min, q_max]``::

        scale      = (x_max - x_min) / (q_max - q_min)
        zero_point = clip(q_min - round(x_min / scale), q_min, q_max)

    Including 0 keeps ``zero_point`` inside ``[q_min, q_max]`` without
    clipping and makes 0.0 exactly representable. Without this step, an
    all-positive (or all-negative) tensor gets a clipped zero_point, and every
    value saturates to ``q_max`` (or ``q_min``) when quantized.

    Args:
        x: input floating-point tensor.
        per_channel: if True, compute one (scale, zero_point) pair per slice
            along ``channel_dim``. The reduced axes are kept with size 1, so
            the results broadcast against ``x``. Otherwise 0-d tensors are
            returned.
        channel_dim: channel axis for per-channel mode. Negative values count
            from the end, as in PyTorch.
        dtype: integer dtype whose ``torch.iinfo`` range gives ``q_min``/``q_max``.
        eps: lower bound for ``scale``; it prevents a zero scale when
            ``x_max == x_min``.

    Returns:
        QuantizationParams with ``scale``, a rounded and clipped
        ``zero_point`` of the same shape, and the ``q_min``/``q_max`` of
        ``dtype``.

    Raises:
        IndexError: if ``per_channel`` is True and ``channel_dim`` is out of
            range for ``x``.
    """
    info = torch.iinfo(dtype)
    q_max, q_min = info.max, info.min

    if per_channel:
        ndim = x.dim()
        if not -ndim <= channel_dim < ndim:
            raise IndexError(
                f"channel_dim {channel_dim} is out of range for a {ndim}-d tensor")
        # Normalize negative axes: otherwise `d != channel_dim` never matches and
        # every axis is reduced, silently degrading to per-tensor parameters.
        channel_dim %= ndim
        reduce_dims = tuple(d for d in range(ndim) if d != channel_dim)
        if reduce_dims:
            x_min = x.amin(dim=reduce_dims, keepdim=True)
            x_max = x.amax(dim=reduce_dims, keepdim=True)
        else:
            # 1-d tensor: each element is its own channel.
            x_min = x_max = x
    else:
        x_min = x.min()
        x_max = x.max()

    # Make sure the range contains 0 (see docstring) so zero_point is in range.
    x_min = torch.clamp(x_min, max=0)
    x_max = torch.clamp(x_max, min=0)

    # Clip to eps: the range is zero when x is all zeros.
    scale = (x_max - x_min) / max(q_max - q_min, 1)
    scale = torch.clip(scale, eps)

    # zero_point is the integer that 0.0 maps to; round, then keep it in range.
    zero_point = q_min - torch.round(x_min / scale)
    zero_point = torch.clip(zero_point, q_min, q_max)

    return QuantizationParams(scale=scale, zero_point=zero_point, q_min=q_min, q_max=q_max)

In [20]:
# Per-tensor asymmetric quantization parameters for x (x.shape == (3, 3))
x_asym_params = get_asymmetric_qparams(x)
x_asym_params

QuantizationParams(scale=tensor(0.0146), zero_point=tensor(21.), q_min=-128, q_max=127)

In [21]:
# Per-channel asymmetric parameters, using dim 0 as the channel axis: each row of x
# gets its own scale and zero_point (x.shape == (3, 3)).
# NOTE: this rebinds x_asym_params, so later cells that read it use these per-channel params.
x_asym_params = get_asymmetric_qparams(x,True,0)
x_asym_params

QuantizationParams(scale=tensor([[0.0146],
        [0.0077],
        [0.0061]]), zero_point=tensor([[ 21.],
        [ 53.],
        [-10.]]), q_min=-128, q_max=127)

## 4.3 量化和反量化
获取QuantizationParams之后，对x进行量化，再通过反量化得到近似值x_hat。

In [23]:
# Quantize x with the given parameters; returns an integer tensor
# (int8 for the default int8 parameters)
def _int_storage_dtype(q_min: int, q_max: int) -> torch.dtype:
    """Return the narrowest integer dtype whose range covers ``[q_min, q_max]``."""
    for candidate in (torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64):
        info = torch.iinfo(candidate)
        if info.min <= q_min and q_max <= info.max:
            return candidate
    raise ValueError(f"no integer dtype can hold the range [{q_min}, {q_max}]")


def quantize_tensor(x: torch.Tensor, qparams: QuantizationParams) -> torch.Tensor:
    """Quantize ``x`` with ``qparams``.

    Computes ``q = clip(round(x / scale + zero_point), q_min, q_max)``.
    Per-tensor and per-channel parameters are both supported through
    broadcasting. Per-channel ``scale``/``zero_point`` must have been
    computed with ``keepdim=True`` so their shape lines up with ``x``.

    Args:
        x: floating-point tensor to quantize.
        qparams: scale, zero_point and integer range to use.

    Returns:
        Integer tensor with the same shape as ``x``. Its dtype is the
        narrowest integer type covering ``[q_min, q_max]``: int8 for
        int8-range params, uint8 for ``[0, 255]``, int16 for int16, and so on.

    Raises:
        ValueError: if no integer dtype can hold ``[q_min, q_max]``.
    """
    q = torch.round(x / qparams.scale + qparams.zero_point)
    q = torch.clip(q, qparams.q_min, qparams.q_max)
    # Cast to a dtype that can actually hold [q_min, q_max]; a fixed int8 cast
    # would wrap values above 127 (e.g. uint8 params) into negative numbers.
    return q.to(_int_storage_dtype(qparams.q_min, qparams.q_max))

In [33]:
# Symmetrically quantize x with x_sym_params (the per-channel params as last assigned)
# and print x alongside the quantized tensor q_sym
q_sym = quantize_tensor(x,x_sym_params)
print(f"x: \n {x},{x.dtype} \n \n q_sym: \n {q_sym}")

x: 
 tensor([[ 1.5410, -0.2934, -2.1788],
        [ 0.5684, -1.0845, -1.3986],
        [ 0.4033,  0.8380, -0.7193]]),torch.float32 
 
 q_sym: 
 tensor([[  90,  -17, -127],
        [  52,  -98, -127],
        [  61,  127, -109]], dtype=torch.int8)


In [34]:
# Asymmetrically quantize x with x_asym_params (the per-channel params as last assigned)
# and print x alongside the quantized tensor q_asym
q_asym = quantize_tensor(x,x_asym_params)
print(f"x: \n {x},{x.dtype} \n \n q_asym: \n {q_asym}")

x: 
 tensor([[ 1.5410, -0.2934, -2.1788],
        [ 0.5684, -1.0845, -1.3986],
        [ 0.4033,  0.8380, -0.7193]]),torch.float32 
 
 q_asym: 
 tensor([[ 127,    1, -128],
        [ 127,  -88, -128],
        [  56,  127, -128]], dtype=torch.int8)


In [35]:
# Dequantize q back to floating point, producing the approximation x_hat
def dequantize_tensor(q: torch.Tensor, qparams: QuantizationParams) -> torch.Tensor:
    """Map quantized values back to floating point.

    Computes ``x_hat = (q - zero_point) * scale``. ``q`` is converted to
    float32 first. ``zero_point`` and ``scale`` then broadcast against it, so
    per-tensor and per-channel (keepdim) parameters both work.
    """
    as_float = q.to(torch.float32)
    shifted = as_float - qparams.zero_point
    return shifted * qparams.scale

In [36]:
# Dequantize the symmetrically quantized tensor back to floating point
x_hat_sym = dequantize_tensor(q_sym,x_sym_params)
x_hat_sym

tensor([[ 1.5440, -0.2916, -2.1788],
        [ 0.5727, -1.0792, -1.3986],
        [ 0.4025,  0.8380, -0.7193]])

In [37]:
# Dequantize the asymmetrically quantized tensor back to floating point
x_hat_asym = dequantize_tensor(q_asym,x_asym_params)
x_hat_asym

tensor([[ 1.5463, -0.2917, -2.1735],
        [ 0.5708, -1.0877, -1.3962],
        [ 0.4031,  0.8367, -0.7206]])

In [39]:
# Show the reconstruction error (x_hat - x) for symmetric and asymmetric quantization
err_x_sym = x_hat_sym - x
err_x_asym = x_hat_asym -x 
print(f"err_x_sym: \n {err_x_sym} \n \n err_x_asym: \n {err_x_asym}")

err_x_sym: 
 tensor([[ 3.0279e-03,  1.7799e-03,  0.0000e+00],
        [ 4.2220e-03,  5.2912e-03,  0.0000e+00],
        [-8.3023e-04,  0.0000e+00,  6.6161e-06]]) 
 
 err_x_asym: 
 tensor([[ 0.0053,  0.0017,  0.0053],
        [ 0.0024, -0.0031,  0.0024],
        [-0.0003, -0.0014, -0.0014]])


## 4.4 封装量化函数
将量化函数进行封装，保留一个量化入口即可。

In [40]:
# Single entry point: quantize and then dequantize in one call
def quantize_dequant(
    x: torch.Tensor,
    per_channel: bool = False,
    channel_dim: int = 0,
    dtype=torch.int8,
    sym: bool = True,
):
    """Quantize ``x`` and immediately dequantize it.

    Args:
        x: input floating-point tensor.
        per_channel: compute parameters per channel along ``channel_dim``
            instead of for the whole tensor.
        channel_dim: channel axis used when ``per_channel`` is True.
        dtype: integer dtype whose range defines ``q_min``/``q_max``.
        sym: True selects symmetric quantization (``get_symmetric_qparams``,
            zero_point fixed at 0). False selects asymmetric quantization
            (``get_asymmetric_qparams``).

    Returns:
        A tuple ``(q, x_hat, qparams)``:
            q: the quantized integer tensor,
            x_hat: the dequantized approximation of ``x``,
            qparams: the QuantizationParams that were used (reusable).
    """
    get_qparams = get_symmetric_qparams if sym else get_asymmetric_qparams

    qparams = get_qparams(
        x,
        dtype=dtype,
        per_channel=per_channel,
        channel_dim=channel_dim,
    )

    q = quantize_tensor(x, qparams)
    x_hat = dequantize_tensor(q, qparams)
    return q, x_hat, qparams

In [None]:
# Symmetric, per-tensor quantize + dequantize through the wrapper (default arguments);
# this rebinds q_sym, x_hat_sym and x_sym_params
q_sym,x_hat_sym,x_sym_params = quantize_dequant(x)
print(f"x: \n {x}")
print(f"q_sym: \n {q_sym}")
print(f"x_hat_sym: \n {x_hat_sym}")
print(f"x_sym_params: \n {x_sym_params}")

x: 
 tensor([[ 1.5410, -0.2934, -2.1788],
        [ 0.5684, -1.0845, -1.3986],
        [ 0.4033,  0.8380, -0.7193]])
q_sym: 
 tensor([[  90,  -17, -127],
        [  33,  -63,  -82],
        [  24,   49,  -42]], dtype=torch.int8)
x_hat_sym: 
 tensor([[ 1.5440, -0.2916, -2.1788],
        [ 0.5661, -1.0808, -1.4068],
        [ 0.4117,  0.8406, -0.7205]])
x_sym_params: 
 QuantizationParams(scale=tensor(0.0172), zero_point=tensor(0.), q_min=-128, q_max=127)


In [42]:
# Asymmetric, per-tensor quantize + dequantize through the wrapper (sym=False);
# this rebinds q_asym, x_hat_asym and x_asym_params
q_asym,x_hat_asym,x_asym_params = quantize_dequant(x,sym=False)
print(f"x: \n {x}")
print(f"q_asym: \n {q_asym}")
print(f"x_hat_asym: \n {x_hat_asym}")
print(f"x_asym_params: \n {x_asym_params}")

x: 
 tensor([[ 1.5410, -0.2934, -2.1788],
        [ 0.5684, -1.0845, -1.3986],
        [ 0.4033,  0.8380, -0.7193]])
q_asym: 
 tensor([[ 127,    1, -128],
        [  60,  -53,  -75],
        [  49,   78,  -28]], dtype=torch.int8)
x_hat_asym: 
 tensor([[ 1.5463, -0.2917, -2.1735],
        [ 0.5689, -1.0795, -1.4004],
        [ 0.4084,  0.8315, -0.7148]])
x_asym_params: 
 QuantizationParams(scale=tensor(0.0146), zero_point=tensor(21.), q_min=-128, q_max=127)


**将函数工具写入 util 目录下的 quant_tool.py 中**