Quantization

Ref: https://github.com/hkproj/quantization-notes/tree/main

In [1]:
import torch
x = torch.randn(3, 4)
x

tensor([[-2.1104,  1.0370, -0.1524,  1.5283],
        [-0.5803,  0.4747, -0.0349, -0.4250],
        [ 1.0839,  0.2665, -0.3881,  1.3272]])

In [2]:
# quantization Types
# 1. Symmetric Quantization
# 2. Asymmetric Quantization

def clamp(x, min_val, max_val):
    x[x<min_val] = min_val  
    x[x>max_val] = max_val
    return x

def asymmetric_quantization(x, num_bits):
    # min-max quantization
    alpha = torch.max(torch.abs(x))
    beta = 0
    scale = (alpha - beta) / (2**num_bits - 1)
    zero_point = -1 * torch.round(beta / scale)
    quantized = clamp(torch.round(x / scale), -2**(num_bits-1), 2**(num_bits-1) - 1)
    return quantized, scale, zero_point

def asymmetric_dequantization(x, scale, zero_point):
    return x * scale + zero_point

In [3]:
print('x:', x)
x_q, scale, zero_point = asymmetric_quantization(x, 8)
print('x_quantized:', x_q)
x_dequantized = asymmetric_dequantization(x_q, scale, zero_point)
print('x_dequantized:', x_dequantized)
# error SSE
error = torch.sum((x - x_dequantized)**2)
print('asymmetric quantization error: ', error)


x: tensor([[-2.1104,  1.0370, -0.1524,  1.5283],
        [-0.5803,  0.4747, -0.0349, -0.4250],
        [ 1.0839,  0.2665, -0.3881,  1.3272]])
x_quantized: tensor([[-128.,  125.,  -18.,  127.],
        [ -70.,   57.,   -4.,  -51.],
        [ 127.,   32.,  -47.,  127.]])
x_dequantized: tensor([[-1.0593,  1.0345, -0.1490,  1.0510],
        [-0.5793,  0.4717, -0.0331, -0.4221],
        [ 1.0510,  0.2648, -0.3890,  1.0510]])
asymmetric quantization error:  tensor(1.4099)


In [4]:
def symmetric_quantization(x, num_bits):
    alpha = torch.max(torch.abs(x))
    scale = alpha / (2**(num_bits-1) - 1)
    quantized = torch.round(x / scale)
    return quantized, scale

def symmetric_dequantization(x, scale):
    return x * scale


In [5]:
print('x:', x)
x_q, scale = symmetric_quantization(x, 8)
print('x_quantized:', x_q)
x_dequantized = symmetric_dequantization(x_q, scale)
print('x_dequantized:', x_dequantized)
# error SSE
error = torch.sum((x - x_dequantized)**2)
print('symmetric quantization error: ', error)

x: tensor([[-2.1104,  1.0370, -0.1524,  1.5283],
        [-0.5803,  0.4747, -0.0349, -0.4250],
        [ 1.0839,  0.2665, -0.3881,  1.3272]])
x_quantized: tensor([[-127.,   62.,   -9.,   92.],
        [ -35.,   29.,   -2.,  -26.],
        [  65.,   16.,  -23.,   80.]])
x_dequantized: tensor([[-2.1104,  1.0303, -0.1496,  1.5288],
        [-0.5816,  0.4819, -0.0332, -0.4320],
        [ 1.0801,  0.2659, -0.3822,  1.3294]])
symmetric quantization error:  tensor(0.0002)


In [24]:
import numpy as np
def quantize_percentile(x, num_bits, percentile=99.9):
    alpha = np.percentile(x.numpy(), percentile)
    scale = alpha / (2**(num_bits-1) - 1)
    quantized = torch.round(x / scale)
    return quantized, scale

def dequantize_percentile(x, scale):
    return x * scale

In [25]:
print('x:', x)
x_q, scale = quantize_percentile(x, 8)
print('x_quantized:', x_q)
x_dequantized = dequantize_percentile(x_q, scale)
print('x_dequantized:', x_dequantized)
# error SSE
error = torch.sum((x - x_dequantized)**2)
print('symmetric quantization error: ', error)

x: tensor([[-2.1104,  1.0370, -0.1524,  1.5283],
        [-0.5803,  0.4747, -0.0349, -0.4250],
        [ 1.0839,  0.2665, -0.3881,  1.3272]])
x_quantized: tensor([[-176.,   86.,  -13.,  127.],
        [ -48.,   40.,   -3.,  -35.],
        [  90.,   22.,  -32.,  110.]])
x_dequantized: tensor([[-2.1149,  1.0334, -0.1562,  1.5261],
        [-0.5768,  0.4807, -0.0360, -0.4206],
        [ 1.0815,  0.2644, -0.3845,  1.3218]])
symmetric quantization error:  tensor(0.0002)


In [26]:
# get SSE 
def comapre_quantization_methods(x, num_bits):
    x_q, scale, zero_point = asymmetric_quantization(x, num_bits)
    x_dequantized = asymmetric_dequantization(x_q, scale, zero_point)
    asyn_error = torch.sum((x - x_dequantized)**2)
    print('asymmetric quantization error: ', asyn_error)
    
    x_q, scale = symmetric_quantization(x, num_bits)
    x_dequantized = symmetric_dequantization(x_q, scale)
    sym_error = torch.sum((x - x_dequantized)**2)
    print('symmetric quantization error: ', sym_error)
    
    x_q, scale = quantize_percentile(x, num_bits)
    x_dequantized = dequantize_percentile(x_q, scale)
    perc_error = torch.sum((x - x_dequantized)**2)
    print('percentile quantization error: ', perc_error)    
    return asyn_error, sym_error, perc_error



In [27]:
x
comapre_quantization_methods(x, 8)

asymmetric quantization error:  tensor(1.4099)
symmetric quantization error:  tensor(0.0002)
percentile quantization error:  tensor(0.0002)


(tensor(1.4099), tensor(0.0002), tensor(0.0002))