# Learn quantization hands-on

## Asymmetric Quantization
非对称量化
  ![Asymmetric quantization diagram](resources/asym_quant.png "Asymmetric quantization")

In [1]:
import torch
from torch import Tensor

In [2]:
def linear_q_with_scale_and_zero_points(
    tensor: Tensor, scale, zero_point, dtype=torch.int8
) -> Tensor:
    """
    Maps a real-valued tensor onto the integer grid of ``dtype``.

    Every element is computed as ``round(tensor / scale + zero_point)`` and is
    then saturated to the representable range of ``dtype``.

    Args:
        tensor (Tensor): The input tensor to quantize.
        scale (float): The step size between two adjacent quantized values.
        zero_point (int): The quantized value that represents real 0.
        dtype (torch.dtype, optional): Integer type of the result. Defaults to torch.int8.

    Returns:
        Tensor: A tensor of type ``dtype`` holding the quantized values.
    """
    # Project onto the integer grid first; the dtype limits are looked up afterwards.
    on_grid = torch.round(tensor / scale + zero_point)

    limits = torch.iinfo(dtype)
    # Saturate out-of-range values before the cast to the integer dtype.
    return torch.clamp(on_grid, limits.min, limits.max).to(dtype)


def get_q_scale_and_zero_point(tensor: Tensor, dtype=torch.int8) -> tuple[float, int]:
    """
    Computes the scale and zero point for asymmetric linear quantization.

    The observed value range is widened to include 0.0 before the parameters
    are derived. This keeps real zero exactly representable and guarantees that
    the zero point lies inside the integer range of ``dtype``; without it, a
    tensor whose values are all positive (or all negative) would get a clamped
    zero point and most of its values would saturate to a single code.

    Args:
        tensor (Tensor): The input tensor to compute scale and zero point.
        dtype (torch.dtype, optional): The desired data type of the output tensor. Defaults to torch.int8.

    Returns:
        tuple: ``(scale, zero_point)`` where ``scale`` is a positive float and
        ``zero_point`` is an int within the range of ``dtype``.
    """
    qmin = torch.iinfo(dtype).min
    qmax = torch.iinfo(dtype).max

    # Widen the range so that it always contains zero.
    min_val = min(tensor.min().item(), 0.0)
    max_val = max(tensor.max().item(), 0.0)
    if min_val == max_val:
        # Only reachable when every element is zero. Any positive scale
        # represents that exactly; 1.0 avoids a later division by zero.
        return 1.0, 0

    scale = (max_val - min_val) / (qmax - qmin)
    zero_point = int(round(qmin - min_val / scale))
    # Defensive clamp against floating-point drift at the range boundaries.
    zero_point = max(qmin, min(zero_point, qmax))

    return scale, zero_point


def linear_quantization(tensor: Tensor, dtype=torch.int8):
    """
    Quantizes a tensor with parameters derived from the tensor itself.

    The scale and zero point are obtained from ``get_q_scale_and_zero_point``
    and then applied with ``linear_q_with_scale_and_zero_points``.

    Args:
        tensor (Tensor): The input tensor to quantize.
        dtype (torch.dtype, optional): The desired data type of the output tensor. Defaults to torch.int8.

    Returns:
        tuple: ``(quantized_tensor, scale, zero_point)`` - the quantized tensor
        together with the parameters needed to dequantize it.
    """
    scale, zero_point = get_q_scale_and_zero_point(tensor, dtype)
    quantized = linear_q_with_scale_and_zero_points(tensor, scale, zero_point, dtype)
    return quantized, scale, zero_point


def linear_dequantization(quantized_tensor: Tensor, scale, zero_point) -> Tensor:
    """
    Maps a quantized tensor back to floating-point values.

    Computes ``scale * (quantized_tensor - zero_point)`` after converting the
    quantized tensor with ``.float()``.

    Args:
        quantized_tensor (Tensor): The quantized tensor.
        scale (float): The scale that was used for quantization.
        zero_point (int): The zero point that was used for quantization.

    Returns:
        Tensor: The dequantized floating-point tensor.
    """
    # Convert to float before subtracting the zero point.
    shifted = quantized_tensor.float() - zero_point
    return scale * shifted

In [11]:
# Round-trip demo: quantize a random 3x3 tensor, dequantize it, and report the
# reconstruction error. The tensor is drawn without a fixed seed, so the
# printed values differ from run to run.
test_tensor = torch.randn(3, 3) * 10
quatization_type = torch.int8
quantized_tensor, scale, zero_point = linear_quantization(
    test_tensor, dtype=quatization_type
)
dequantized_tensor = linear_dequantization(quantized_tensor, scale, zero_point)

# Element-wise difference between the reconstruction and the original.
quantization_error = dequantized_tensor - test_tensor

print(f"Scale: {scale}, Zero Point: {zero_point}")
print("Original Tensor:\n", test_tensor)
print("Quantized Tensor:\n", quantized_tensor)
print("Dequantized Tensor:\n", dequantized_tensor)
print("Error", quantization_error)
# The small epsilon keeps the denominator away from an exact zero.
print("Relative error:\n", quantization_error / (test_tensor + 1e-9))
print("MSE:\n", quantization_error.square().mean())

Scale: 0.07328061122520298, Zero Point: 31
Original Tensor:
 tensor([[ -5.5340,  -6.4383,  -8.9854],
        [ -0.9105, -11.6636,   7.0229],
        [ -6.8562,  -3.9043,   1.1225]])
Quantized Tensor:
 tensor([[ -45,  -57,  -92],
        [  19, -128,  127],
        [ -63,  -22,   46]], dtype=torch.int8)
Dequantized Tensor:
 tensor([[ -5.5693,  -6.4487,  -9.0135],
        [ -0.8794, -11.6516,   7.0349],
        [ -6.8884,  -3.8839,   1.0992]])
Error tensor([[-0.0353, -0.0104, -0.0281],
        [ 0.0311,  0.0120,  0.0120],
        [-0.0321,  0.0204, -0.0233]])
Relative error:
 tensor([[ 0.0064,  0.0016,  0.0031],
        [-0.0342, -0.0010,  0.0017],
        [ 0.0047, -0.0052, -0.0207]])
MSE:
 tensor(0.0006)
