# L3-B - Linear Quantization II: Finer Granularity for more Precision

In [1]:
import torch

## 一些必要的函数

In [2]:
def linear_q_symmetric(tensor, dtype = torch.int8):
    scale = get_q_scale_symmetric(tensor, dtype)
    quantized_tensor = linear_q_with_scale_and_zero_point(tensor, scale, zero_point = 0, dtype=dtype)
    return quantized_tensor, scale

In [3]:
def get_q_scale_symmetric(tensor, dtype = torch.int8):
    r_max = tensor.abs().max().item()
    q_max = torch.iinfo(dtype).max
    return r_max/q_max

In [9]:
def linear_dequantization(quantized_tensor, scale, zero_point):
    return scale * (quantized_tensor.float() - zero_point)

In [10]:
def linear_q_with_scale_and_zero_point(
    tensor, scale, zero_point, dtype=torch.int8):
    scaled_and_shifted_tensor = tensor / scale + zero_point
    
    rounded_tensor = torch.round(scaled_and_shifted_tensor)
    #最后一步是，确保我们的舍入张量在最小量化值和最大量化值之间
    q_min = torch.iinfo(dtype).min#iInfo方法获取最小值和最大值
    q_max = torch.iinfo(dtype).max

    q_tensor = rounded_tensor.clamp(q_min, q_max).to(dtype)#定义量化张量（使用to()函数转换为我们想要的量化数据类型）
    return q_tensor

In [11]:
def quantization_error(original_tensor, dequantized_tensor, error_type="mse"):
    """
    计算量化误差的通用函数（支持 PyTorch 和 NumPy 张量）
    
    参数：
        original_tensor      : 原始浮点张量 (torch.Tensor/np.ndarray)
        dequantized_tensor   : 反量化后的重建张量 (与原始张量同类型同形状)
        error_type          : 误差计算方式，可选 "mse"(默认) 或 "mae"
    
    返回：
        量化误差值 (float)
    """
    # 校验输入类型一致性
    if type(original_tensor) != type(dequantized_tensor):
        raise TypeError("原始张量与反量化张量类型必须一致")

    # 校验形状一致性
    if original_tensor.shape != dequantized_tensor.shape:
        raise ValueError("张量形状不匹配")

    # 自动检测计算框架
    if isinstance(original_tensor, torch.Tensor):
        lib = torch
    elif isinstance(original_tensor, np.ndarray):
        lib = np
    else:
        raise TypeError("仅支持 PyTorch 或 NumPy 张量")

    # 计算误差
    diff = original_tensor - dequantized_tensor
    if error_type == "mse":
        error = lib.mean(diff ** 2)
    elif error_type == "mae":
        error = lib.mean(lib.abs(diff))
    else:
        raise ValueError("error_type 必须为 'mse' 或 'mae'")

    # 返回标量值
    return error.item() if lib == torch else float(error)

## Per Tensor
- Perform `Per Tensor` Symmetric Quantization.

In [12]:
# test tensor
test_tensor=torch.tensor(
    [[191.6, -13.5, 728.6],
     [92.14, 295.5,  -184],
     [0,     684.6, 245.5]]
)

In [15]:
quantized_tensor, scale = linear_q_symmetric(test_tensor)

In [16]:
quantized_tensor

tensor([[ 33,  -2, 127],
        [ 16,  52, -32],
        [  0, 119,  43]], dtype=torch.int8)

In [17]:
scale

5.737007681779035

In [14]:
dequantized_tensor = linear_dequantization(quantized_tensor, scale, 0)

In [18]:
dequantized_tensor

tensor([[ 189.3213,  -11.4740,  728.6000],
        [  91.7921,  298.3244, -183.5842],
        [   0.0000,  682.7039,  246.6913]])

In [21]:
# 未实现的函数（TODO）
plot_quantization_errors(test_tensor, quantized_tensor,
                         dequantized_tensor)

NameError: name 'plot_quantization_errors' is not defined

In [22]:
print(f"""Quantization Error : \
{quantization_error(test_tensor, dequantized_tensor)}""")

Quantization Error : 2.5091912746429443


## Per Channel
- Implement `Per Channel` Symmetric Quantization
- `dim` parameter decides if it needs to be along the rows or columns

In [23]:
def linear_q_symmetric_per_channel(tensor,dim,dtype=torch.int8):



    return quantized_tensor, scale

In [24]:
test_tensor=torch.tensor(
    [[191.6, -13.5, 728.6],
     [92.14, 295.5,  -184],
     [0,     684.6, 245.5]]
)

- `dim = 0`, along the rows
- `dim = 1`, along the columns

In [25]:
dim=0
output_dim = test_tensor.shape[dim]

In [26]:
output_dim

3

In [27]:
scale = torch.zeros(output_dim)

In [28]:
scale

tensor([0., 0., 0.])

- Iterate through each row to calculate its `scale`.

In [30]:
for index in range(output_dim):
    sub_tensor = test_tensor.select(dim,index)
    print(sub_tensor)
    scale[index] = get_q_scale_symmetric(sub_tensor)

tensor([191.6000, -13.5000, 728.6000])
tensor([  92.1400,  295.5000, -184.0000])
tensor([  0.0000, 684.6000, 245.5000])


In [31]:
scale

tensor([5.7370, 2.3268, 5.3906])

In [32]:
scale_shape = [1] * test_tensor.dim()

In [33]:
scale_shape

[1, 1]

In [34]:
scale_shape[dim] = -1

In [35]:
scale_shape

[-1, 1]

In [36]:
scale = scale.view(scale_shape)

In [37]:
# copy to be used later
copy_scale = scale

scale

tensor([[5.7370],
        [2.3268],
        [5.3906]])

#### Understanding tensor by tensor division using `view` function

In [38]:
m = torch.tensor([[1,2,3],[4,5,6],[7,8,9]])

In [39]:
m

tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

In [40]:
s = torch.tensor([1,5,10])

In [41]:
s

tensor([ 1,  5, 10])

In [42]:
s.shape

torch.Size([3])

In [43]:
s.view(1, 3).shape

torch.Size([1, 3])

In [44]:
# alternate way
s.view(1, -1).shape

torch.Size([1, 3])

In [45]:
s.view(-1,1).shape

torch.Size([3, 1])

##### Along the row division

In [46]:
scale = torch.tensor([[1], [5], [10]])

In [47]:
scale.shape

torch.Size([3, 1])

In [48]:
m / scale

tensor([[1.0000, 2.0000, 3.0000],
        [0.8000, 1.0000, 1.2000],
        [0.7000, 0.8000, 0.9000]])

##### Along the column division

In [49]:
scale = torch.tensor([[1, 5, 10]])

In [50]:
scale.shape

torch.Size([1, 3])

In [51]:
m / scale

tensor([[1.0000, 0.4000, 0.3000],
        [4.0000, 1.0000, 0.6000],
        [7.0000, 1.6000, 0.9000]])

#### Coming back to quantizing the tensor

In [52]:
# the scale you got earlier
scale = copy_scale

scale

tensor([[5.7370],
        [2.3268],
        [5.3906]])

In [53]:
scale.shape

torch.Size([3, 1])

In [54]:
quantized_tensor = linear_q_with_scale_and_zero_point(
    test_tensor, scale=scale, zero_point=0)

In [55]:
quantized_tensor

tensor([[ 33,  -2, 127],
        [ 40, 127, -79],
        [  0, 127,  46]], dtype=torch.int8)

- Now, put all this in `linear_q_symmetric_per_channel` function defined earlier.

In [56]:
def linear_q_symmetric_per_channel(r_tensor, dim, dtype=torch.int8):
    
    output_dim = r_tensor.shape[dim]
    # store the scales
    scale = torch.zeros(output_dim)

    for index in range(output_dim):
        sub_tensor = r_tensor.select(dim, index)
        scale[index] = get_q_scale_symmetric(sub_tensor, dtype=dtype)

    # reshape the scale
    scale_shape = [1] * r_tensor.dim()
    scale_shape[dim] = -1
    scale = scale.view(scale_shape)
    quantized_tensor = linear_q_with_scale_and_zero_point(
        r_tensor, scale=scale, zero_point=0, dtype=dtype)
   
    return quantized_tensor, scale

In [57]:
test_tensor=torch.tensor(
    [[191.6, -13.5, 728.6],
     [92.14, 295.5,  -184],
     [0,     684.6, 245.5]]
)

In [58]:
### along the rows (dim = 0)
quantized_tensor_0, scale_0 = linear_q_symmetric_per_channel(
    test_tensor, dim=0)

### along the columns (dim = 1)
quantized_tensor_1, scale_1 = linear_q_symmetric_per_channel(
    test_tensor, dim=1)

- Plot the quantization error for along the rows.

In [60]:
dequantized_tensor_0 = linear_dequantization(
    quantized_tensor_0, scale_0, 0)

#plot_quantization_errors(test_tensor, quantized_tensor_0, dequantized_tensor_0)

In [61]:
print(f"""Quantization Error : \
{quantization_error(test_tensor, dequantized_tensor_0)}""")

Quantization Error : 1.8084441423416138


- Plot the quantization error for along the columns.

In [64]:
dequantized_tensor_1 = linear_dequantization(
    quantized_tensor_1, scale_1, 0)

#plot_quantization_errors(test_tensor, quantized_tensor_1, dequantized_tensor_1, n_bits=8)

print(f"""Quantization Error : \
{quantization_error(test_tensor, dequantized_tensor_1)}""")

Quantization Error : 1.0781488418579102
