# Learn Quantization Hands-On

## Asymmetric Quantization
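非对称量化是一种神经网络量化方法：将浮点权重或激活值映射到整数范围时，引入零点（zero-point）偏移，即浮点值 0 不一定映射到整数 0，而是映射到整数零点 z（量化：q = clamp(round(r / s) + z, q_min, q_max)；反量化：r ≈ s · (q − z)）。这使得量化区间更灵活，可以贴合数据实际的取值范围 [r_min, r_max]，能更好适应非对称分布的数据（如正负取值范围不对称的数据），从而减少量化误差，提升模型精度。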

  ![Asymmetric quantization](resources/asym_quant.png "Asymmetric quantization")

In [1]:
import torch
from torch import Tensor
import torch.nn as nn

In [2]:
def linear_q_with_scale_and_zero_points(
    tensor: Tensor, scale, zero_point, dtype=torch.int8
) -> Tensor:
    """
    Map a floating-point tensor onto the integer grid of ``dtype``.

    Computes ``round(tensor / scale + zero_point)`` and clamps the result to
    the representable range of ``dtype`` before casting to it.

    Args:
        tensor (Tensor): The values to quantize.
        scale: The quantization step size; a Python number or a tensor that
            broadcasts against ``tensor``.
        zero_point: The offset added after scaling; a Python number or a
            tensor that broadcasts against ``tensor``.
        dtype (torch.dtype, optional): Integer dtype of the result. Defaults to torch.int8.

    Returns:
        Tensor: The quantized tensor, with dtype ``dtype``.
    """
    on_grid = torch.round(tensor / scale + zero_point)
    limits = torch.iinfo(dtype)
    # Clamp to [min, max] of the target dtype before the cast.
    return torch.clamp(on_grid, limits.min, limits.max).to(dtype)


def get_q_scale_and_zero_point(tensor: Tensor, dtype=torch.int8) -> tuple[float, int]:
    """
    Computes the scale and zero point for asymmetric linear quantization.

    The observed range ``[tensor.min(), tensor.max()]`` is mapped linearly
    onto the full integer range of ``dtype``.

    Args:
        tensor (Tensor): The input tensor to compute scale and zero point.
        dtype (torch.dtype, optional): The desired integer data type of the
            quantized tensor. Defaults to torch.int8.

    Returns:
        tuple[float, int]: ``(scale, zero_point)``. The zero point is clamped
        to the representable range of ``dtype``. For a constant tensor the
        result is ``(value, 0)``, or ``(1.0, 0)`` when that value is zero.
    """
    qmin = torch.iinfo(dtype).min
    qmax = torch.iinfo(dtype).max
    min_val = tensor.min().item()
    max_val = tensor.max().item()
    if min_val == max_val:
        if min_val == 0:
            # A scale of 0 would turn the later ``tensor / scale`` into 0 / 0
            # (NaN); any non-zero scale represents an all-zero tensor exactly.
            return 1.0, 0
        # Non-zero constant c: with scale = c and zero point 0 every element
        # quantizes to 1 and dequantizes back to exactly c.
        return min_val, 0

    scale = (max_val - min_val) / (qmax - qmin)
    # Choose the zero point so that min_val lands on qmin.
    zero_point = int(round(qmin - min_val / scale))
    # Rounding can push the zero point just outside the dtype's range.
    zero_point = max(qmin, min(zero_point, qmax))

    return scale, zero_point


def linear_quantization(tensor: Tensor, dtype=torch.int8):
    """
    Quantize a tensor with asymmetric linear quantization.

    The scale and zero point are first derived from the tensor itself and are
    then used to quantize it.

    Args:
        tensor (Tensor): The input tensor to quantize.
        dtype (torch.dtype, optional): Integer dtype of the quantized tensor. Defaults to torch.int8.

    Returns:
        tuple: ``(quantized_tensor, scale, zero_point)``.
    """
    scale, zero_point = get_q_scale_and_zero_point(tensor, dtype)
    quantized = linear_q_with_scale_and_zero_points(
        tensor, scale=scale, zero_point=zero_point, dtype=dtype
    )
    return quantized, scale, zero_point


def linear_dequantization(quantized_tensor: Tensor, scale, zero_point) -> Tensor:
    """
    Map a quantized tensor back to floating point.

    Computes ``scale * (quantized_tensor - zero_point)`` after converting the
    quantized values to float.

    Args:
        quantized_tensor (Tensor): The quantized values.
        scale: The scale used for quantization; a Python number or a tensor
            that broadcasts against ``quantized_tensor``.
        zero_point: The zero point used for quantization.

    Returns:
        Tensor: The dequantized floating-point tensor.
    """
    shifted = quantized_tensor.float() - zero_point
    return scale * shifted

In [3]:
# Round-trip a random 3x3 tensor through asymmetric quantization and print
# the quantization parameters together with the reconstruction errors.
test_tensor = torch.randn(3, 3) * 10
quatization_type = torch.int8
quantized_tensor, scale, zero_point = linear_quantization(
    test_tensor, dtype=quatization_type
)
print(f"Scale: {scale}, Zero Point: {zero_point}")
print("Original Tensor:\n", test_tensor)
print("Quantized Tensor:\n", quantized_tensor)
# Map the integers back to floats to see what quantization lost.
dequantized_tensor = linear_dequantization(quantized_tensor, scale, zero_point)
print("Dequantized Tensor:\n", dequantized_tensor)
print("Error", dequantized_tensor - test_tensor)
# The 1e-9 in the denominator is presumably a guard against dividing by an
# exact zero element.
print("Relative error:\n", (dequantized_tensor - test_tensor) / (test_tensor + 1e-9))
print("MSE:\n", (dequantized_tensor - test_tensor).square().mean())

Scale: 0.12882560281192554, Zero Point: 3
Original Tensor:
 tensor([[ -6.4771,   0.7486,  -5.5430],
        [-16.8492,  14.8348,  10.8760],
        [  2.0805,   5.5973,  16.0013]])
Quantized Tensor:
 tensor([[ -47,    9,  -40],
        [-128,  118,   87],
        [  19,   46,  127]], dtype=torch.int8)
Dequantized Tensor:
 tensor([[ -6.4413,   0.7730,  -5.5395],
        [-16.8762,  14.8149,  10.8214],
        [  2.0612,   5.5395,  15.9744]])
Error tensor([[ 0.0359,  0.0243,  0.0035],
        [-0.0269, -0.0199, -0.0546],
        [-0.0193, -0.0578, -0.0269]])
Relative error:
 tensor([[-0.0055,  0.0325, -0.0006],
        [ 0.0016, -0.0013, -0.0050],
        [-0.0093, -0.0103, -0.0017]])
MSE:
 tensor(0.0012)


## Symmetric Quantization
对称量化是一种神经网络量化方法，将浮点值对称地映射到整数范围，要求量化区间的最小值和最大值关于零对称（如[-a, a]），且浮点值 0 恰好映射到整数 0（即零点固定为 0）。它不引入零点偏移，计算更简单高效，但适用于数据分布接近对称的情况，常用于权重量化。

  ![Symmetric quantization](resources/sym_quant.png "Symmetric quantization")

In [4]:
def get_q_scale_symmetric(tensor: Tensor, dtype=torch.int8) -> float:
    """
    Compute the scale for symmetric quantization.

    The scale maps the largest absolute value in ``tensor`` onto the largest
    positive integer representable by ``dtype``.

    Args:
        tensor (Tensor): The tensor whose values determine the scale.
        dtype (torch.dtype, optional): Integer dtype of the quantized tensor. Defaults to torch.int8.

    Returns:
        float: The symmetric scale, or 1.0 when every element is zero.
    """
    abs_peak = tensor.abs().max().item()
    int_peak = torch.iinfo(dtype).max
    if abs_peak == 0:
        # A zero scale could not be divided by, so fall back to 1.0.
        return 1.0
    return abs_peak / int_peak


def linear_q_symmetric(tensor: Tensor, dtype=torch.int8):
    """
    Quantize a tensor with symmetric linear quantization.

    A single scale is derived from the whole tensor and the zero point is
    fixed at 0.

    Args:
        tensor (Tensor): The input tensor to quantize.
        dtype (torch.dtype, optional): Integer dtype of the quantized tensor. Defaults to torch.int8.

    Returns:
        tuple: ``(quantized_tensor, scale)``.
    """
    symmetric_scale = get_q_scale_symmetric(tensor, dtype)
    # Symmetric quantization uses no offset, hence the zero point of 0.
    return (
        linear_q_with_scale_and_zero_points(
            tensor, scale=symmetric_scale, zero_point=0, dtype=dtype
        ),
        symmetric_scale,
    )

In [5]:
# Round-trip a random 4x4 tensor through symmetric int8 quantization and
# print the scale together with the reconstruction errors.
test_tensor = torch.randn(4, 4) * 10
quantized_tensor, scale = linear_q_symmetric(test_tensor, dtype=torch.int8)
print(f"Symmetric Scale: {scale}")
print("Original Tensor:\n", test_tensor)
print("Quantized Tensor:\n", quantized_tensor)
# Symmetric quantization has no offset, so dequantize with zero point 0.
dequantized_tensor = linear_dequantization(quantized_tensor, scale, 0)
print("Dequantized Tensor:\n", dequantized_tensor)
print("Error", dequantized_tensor - test_tensor)
# The 1e-9 in the denominator is presumably a guard against dividing by an
# exact zero element.
print("Relative error:\n", (dequantized_tensor - test_tensor) / (test_tensor + 1e-9))
print("MSE:\n", (dequantized_tensor - test_tensor).square().mean())
print("Max absolute error:", (dequantized_tensor - test_tensor).abs().max())

Symmetric Scale: 0.17191350741649236
Original Tensor:
 tensor([[  4.3260,  -6.3765,  -5.5382,  -4.1744],
        [  3.3907,   2.1706,   3.6630,  -9.7095],
        [ 21.2197,  -3.7216, -14.5955, -12.1297],
        [-21.8330,  -3.5494, -14.7311, -14.0608]])
Quantized Tensor:
 tensor([[  25,  -37,  -32,  -24],
        [  20,   13,   21,  -56],
        [ 123,  -22,  -85,  -71],
        [-127,  -21,  -86,  -82]], dtype=torch.int8)
Dequantized Tensor:
 tensor([[  4.2978,  -6.3608,  -5.5012,  -4.1259],
        [  3.4383,   2.2349,   3.6102,  -9.6272],
        [ 21.1454,  -3.7821, -14.6126, -12.2059],
        [-21.8330,  -3.6102, -14.7846, -14.0969]])
Error tensor([[-0.0281,  0.0157,  0.0369,  0.0485],
        [ 0.0476,  0.0643, -0.0528,  0.0824],
        [-0.0744, -0.0605, -0.0171, -0.0762],
        [ 0.0000, -0.0608, -0.0535, -0.0361]])
Relative error:
 tensor([[-0.0065, -0.0025, -0.0067, -0.0116],
        [ 0.0140,  0.0296, -0.0144, -0.0085],
        [-0.0035,  0.0163,  0.0012,  0.0063],
  

## Per Channel Quantization

In [6]:
from timm.layers import linear


def linear_q_symmetric_per_channel(tensor: Tensor, dim: int, dtype=torch.int8):
    """Do symmetric quantization per channel.

    Every slice of ``tensor`` along ``dim`` gets its own symmetric scale, and
    the whole tensor is then quantized in one broadcasted operation.

    Args:
        tensor: input Tensor.
        dim: Dimension whose slices are the channels; one scale is computed
            per index along this dimension.
        dtype: Integer dtype of the quantized tensor. Defaults to torch.int8.

    Returns:
        tuple: ``(quantized_tensor, scales)``. ``scales`` is a float32 tensor
        on the same device as ``tensor`` with size ``tensor.shape[dim]`` along
        ``dim`` and size 1 elsewhere, so it broadcasts against ``tensor``.
    """
    output_dim = tensor.shape[dim]
    # Build the scales on the input's device; a CPU-only scale tensor would
    # fail to broadcast against a tensor that lives on another device.
    scales = torch.tensor(
        [
            get_q_scale_symmetric(tensor.select(dim, index), dtype)
            for index in range(output_dim)
        ],
        dtype=torch.float32,
        device=tensor.device,
    )
    # Reshape to (1, ..., output_dim, ..., 1) so each scale lines up with its
    # own channel when dividing.
    scale_shape = [1] * tensor.dim()
    scale_shape[dim] = output_dim
    scales = scales.view(scale_shape)
    quantized_tensor = linear_q_with_scale_and_zero_points(
        tensor, scales, zero_point=0, dtype=dtype
    )
    return quantized_tensor, scales

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
# Per-channel quantization along dim=0: each row of this 3x3 tensor gets its
# own scale.
test_tensor = torch.tensor(
    [
        [191.6, -13.5, 728.6],
        [92.14, 295.5, -184],
        [0, 684.6, 245.5],
    ]
)
quantized_tensor, scale = linear_q_symmetric_per_channel(
    test_tensor, dim=0, dtype=torch.int8
)
print(f"Symmetric Per Channel Scale: {scale}")
print("Original Tensor:\n", test_tensor)
print("Quantized Tensor:\n", quantized_tensor)
# The per-channel scale tensor broadcasts against the quantized tensor.
dequantized_tensor = linear_dequantization(quantized_tensor, scale, 0)
print("Dequantized Tensor:\n", dequantized_tensor)
print("Error", dequantized_tensor - test_tensor)
# The 1e-9 in the denominator is presumably a guard against dividing by the
# exact zero element.
print("Relative error:\n", (dequantized_tensor - test_tensor) / (test_tensor + 1e-9))
print("MSE:\n", (dequantized_tensor - test_tensor).square().mean())
print("Max absolute error:", (dequantized_tensor - test_tensor).abs().max())

Symmetric Per Channel Scale: tensor([[5.7370],
        [2.3268],
        [5.3906]])
Original Tensor:
 tensor([[ 191.6000,  -13.5000,  728.6000],
        [  92.1400,  295.5000, -184.0000],
        [   0.0000,  684.6000,  245.5000]])
Quantized Tensor:
 tensor([[ 33,  -2, 127],
        [ 40, 127, -79],
        [  0, 127,  46]], dtype=torch.int8)
Dequantized Tensor:
 tensor([[ 189.3213,  -11.4740,  728.6000],
        [  93.0709,  295.5000, -183.8150],
        [   0.0000,  684.6000,  247.9653]])
Error tensor([[-2.2787,  2.0260,  0.0000],
        [ 0.9309,  0.0000,  0.1850],
        [ 0.0000,  0.0000,  2.4653]])
Relative error:
 tensor([[-0.0119, -0.1501,  0.0000],
        [ 0.0101,  0.0000, -0.0010],
        [ 0.0000,  0.0000,  0.0100]])
MSE:
 tensor(1.8084)
Max absolute error: tensor(2.4653)


In [8]:
# Same tensor as the previous cell, now quantized along dim=1: each column
# gets its own scale.
quantized_tensor, scale = linear_q_symmetric_per_channel(
    test_tensor, dim=1, dtype=torch.int8
)
print(f"Symmetric Per Channel Scale: {scale}")
print("Original Tensor:\n", test_tensor)
print("Quantized Tensor:\n", quantized_tensor)
dequantized_tensor = linear_dequantization(quantized_tensor, scale, 0)
print("Dequantized Tensor:\n", dequantized_tensor)
print("Error", dequantized_tensor - test_tensor)
# The 1e-9 in the denominator is presumably a guard against dividing by the
# exact zero element.
print("Relative error:\n", (dequantized_tensor - test_tensor) / (test_tensor + 1e-9))
print("MSE:\n", (dequantized_tensor - test_tensor).square().mean())
print("Max absolute error:", (dequantized_tensor - test_tensor).abs().max())

Symmetric Per Channel Scale: tensor([[1.5087, 5.3906, 5.7370]])
Original Tensor:
 tensor([[ 191.6000,  -13.5000,  728.6000],
        [  92.1400,  295.5000, -184.0000],
        [   0.0000,  684.6000,  245.5000]])
Quantized Tensor:
 tensor([[127,  -3, 127],
        [ 61,  55, -32],
        [  0, 127,  43]], dtype=torch.int8)
Dequantized Tensor:
 tensor([[ 191.6000,  -16.1717,  728.6000],
        [  92.0284,  296.4803, -183.5842],
        [   0.0000,  684.6000,  246.6913]])
Error tensor([[ 0.0000, -2.6717,  0.0000],
        [-0.1116,  0.9803,  0.4158],
        [ 0.0000,  0.0000,  1.1913]])
Relative error:
 tensor([[ 0.0000,  0.1979,  0.0000],
        [-0.0012,  0.0033, -0.0023],
        [ 0.0000,  0.0000,  0.0049]])
MSE:
 tensor(1.0781)
Max absolute error: tensor(2.6717)


## Per Group Quantization

In [9]:
from click import group


def linear_q_symmetric_per_group(tensor: Tensor, group_size: int, dtype=torch.int8):
    """Do symmetric quantization per group.

    The 2D tensor is split into groups of ``group_size`` consecutive elements
    within each row, and every group is quantized with its own symmetric
    scale.

    Args:
        tensor: input Tensor; must be 2D.
        group_size: Size of the group to quantize over; must divide
            ``tensor.shape[1]``.
        dtype: Integer dtype of the quantized tensor. Defaults to torch.int8.

    Returns:
        tuple: ``(quantized_tensor, scale)``. ``quantized_tensor`` has the
        shape of the input; ``scale`` holds one scale per group with shape
        ``(num_groups, 1)``.

    Raises:
        AssertionError: If the tensor is not 2D or its second dimension is
            not divisible by ``group_size``.
    """
    # Check the rank first: indexing shape[1] on a tensor with fewer than two
    # dimensions would raise IndexError before the intended message is shown.
    assert tensor.dim() == 2, "Tensor must be 2D for group quantization"
    t_shape = tensor.shape
    assert t_shape[1] % group_size == 0, "Tensor size must be divisible by group_size"

    # reshape rather than view so that non-contiguous inputs (for example a
    # transposed weight) are accepted as well.
    grouped = tensor.reshape(-1, group_size)
    # Each row of ``grouped`` is one group, so per-channel over dim 0 yields
    # one scale per group.
    quantized_tensor, scale = linear_q_symmetric_per_channel(grouped, dim=0, dtype=dtype)
    quantized_tensor = quantized_tensor.reshape(t_shape)

    return quantized_tensor, scale


def linear_dequantization_per_group(
    quantized_tensor: Tensor, scale, group_size: int
) -> Tensor:
    """Dequantizes a tensor that was quantized per group.

    Args:
        quantized_tensor: The quantized tensor to dequantize; must be 2D.
        scale: The scale used for quantization; it must broadcast against the
            tensor once reshaped to ``(-1, group_size)``.
        group_size: Size of the group used for quantization; must divide
            ``quantized_tensor.shape[1]``.

    Returns:
        Tensor: The dequantized float tensor with the shape of
        ``quantized_tensor``.

    Raises:
        AssertionError: If the tensor is not 2D or its second dimension is
            not divisible by ``group_size``.
    """
    # Check the rank first: indexing shape[1] on a tensor with fewer than two
    # dimensions would raise IndexError before the intended message is shown.
    assert quantized_tensor.dim() == 2, "Tensor must be 2D for group dequantization"
    t_shape = quantized_tensor.shape
    assert t_shape[1] % group_size == 0, "Tensor size must be divisible by group_size"

    # reshape rather than view so that non-contiguous inputs are accepted.
    grouped = quantized_tensor.reshape(-1, group_size)
    # Symmetric quantization uses a zero point of 0.
    dequantized_tensor = linear_dequantization(grouped, scale, 0)
    dequantized_tensor = dequantized_tensor.reshape(t_shape)

    return dequantized_tensor

In [10]:
# Round-trip a random 6x6 tensor through per-group symmetric quantization.
# With group_size = 3 the 36 elements form 12 groups, each with its own scale.
test_tensor = torch.randn(6, 6) * 10
group_size = 3
quantized_tensor, scale = linear_q_symmetric_per_group(
    test_tensor, group_size=group_size, dtype=torch.int8
)
print(f"Symmetric Per Group Scale: {scale}")
print("Original Tensor:\n", test_tensor)
print("Quantized Tensor:\n", quantized_tensor)
# Dequantize with the same group_size so the scales line up with their groups.
dequantized_tensor = linear_dequantization_per_group(
    quantized_tensor, scale, group_size
)
print("Dequantized Tensor:\n", dequantized_tensor)
print("Error", dequantized_tensor - test_tensor)
# The 1e-9 in the denominator is presumably a guard against dividing by an
# exact zero element.
print("Relative error:\n", (dequantized_tensor - test_tensor) / (test_tensor + 1e-9))
print("MSE:\n", (dequantized_tensor - test_tensor).square().mean())
print("Max absolute error:", (dequantized_tensor - test_tensor).abs().max())

Symmetric Per Group Scale: tensor([[0.1412],
        [0.1113],
        [0.0864],
        [0.0295],
        [0.1109],
        [0.1820],
        [0.1905],
        [0.0663],
        [0.1533],
        [0.0806],
        [0.0367],
        [0.1105]])
Original Tensor:
 tensor([[ 17.4175,  17.3003, -17.9289, -10.8549,  -3.3364,  14.1369],
        [ 10.9709,   8.9511,   3.1999,   3.7452,   3.1249,  -0.2184],
        [-12.7733,  10.0035,  14.0848,  -3.9805, -23.1181,  11.8203],
        [  1.1131, -24.1891,  -9.8270,  -7.5528,   8.4201,  -6.2769],
        [-19.4720, -11.6513,  -3.5473, -10.2308,  -2.6917,  -0.8038],
        [  0.3371,   4.6565,   1.7423,  -8.6143,   5.0310, -14.0398]])
Quantized Tensor:
 tensor([[ 123,  123, -127,  -98,  -30,  127],
        [ 127,  104,   37,  127,  106,   -7],
        [-115,   90,  127,  -22, -127,   65],
        [   6, -127,  -52, -114,  127,  -95],
        [-127,  -76,  -23, -127,  -33,  -10],
        [   9,  127,   48,  -78,   46, -127]], dtype=torch.int8)
Deq

## Inference Linear Quantization W8A32

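Weight-only quantization: the weights are stored as int8 and dequantized on the fly, while the activations stay in float32 and the linear operation itself is computed in float32.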

In [11]:
def quantized_linear_W8A32_without_bias(input: Tensor, q_w, s_w, z_w) -> Tensor:
    """
    Apply a bias-free linear layer whose weights are stored quantized.

    The int8 weights are dequantized to float first, and the linear operation
    is then computed on the float32 input.

    Args:
        input (Tensor): The float32 input tensor.
        q_w (Tensor): The int8 quantized weight tensor.
        s_w: The scale used to quantize the weights.
        z_w: The zero point used to quantize the weights.

    Returns:
        Tensor: The result of the linear operation with the dequantized weights.
    """
    assert input.dtype == torch.float32, (
        "Input tensor must be float32 for linear operation"
    )
    assert q_w.dtype == torch.int8, "Quantized weights must be int8"

    # Recover float weights, then run an ordinary linear operation without bias.
    float_weight = linear_dequantization(q_w, s_w, z_w)
    return torch.nn.functional.linear(input, float_weight)


# Compare a linear operation computed with symmetric int8-quantized weights
# against the same operation computed with the original float weights.
input = torch.tensor([1, 2, 3], dtype=torch.float32)
weight = torch.tensor([[-2, -1.13, 0.42], [-1.51, 0.25, 1.62], [0.23, 1.35, 2.15]])
q_w, s_w = linear_q_symmetric(weight)
# Symmetric quantization uses a zero point of 0.
output_quant = quantized_linear_W8A32_without_bias(input, q_w, s_w, 0)
output_origin = torch.nn.functional.linear(input, weight, bias=None)
print("Output with quantized weights:\n", output_quant)
print("Output with original weights:\n", output_origin)
print("Difference:\n", output_quant - output_origin)
# The 1e-9 in the denominator is presumably a guard against dividing by an
# exact zero output.
print("Relative error:\n", (output_quant - output_origin) / (output_origin + 1e-9))
print("MSE:\n", (output_quant - output_origin).square().mean())
print("Max absolute error:", (output_quant - output_origin).abs().max())

Output with quantized weights:
 tensor([-2.9965,  3.8768,  9.3957])
Output with original weights:
 tensor([-3.0000,  3.8500,  9.3800])
Difference:
 tensor([0.0035, 0.0268, 0.0157])
Relative error:
 tensor([-0.0012,  0.0070,  0.0017])
MSE:
 tensor(0.0003)
Max absolute error: tensor(0.0268)
