### Quantize and Dequantize a Vector

In [1]:
import torch

In [2]:
def linear_q_with_scale_and_zero_point(
    tensor,
    scale,
    zero_point,
    dtype=torch.int8
):
    """Quantize ``tensor`` linearly with a given scale and zero point.

    Computes ``round(tensor / scale + zero_point)``, saturates the result to
    the representable range of ``dtype`` and casts it to ``dtype``.

    Args:
        tensor: Tensor of real values to quantize.
        scale: Quantization step that ``tensor`` is divided by.
        zero_point: Offset added after scaling.
        dtype: Integer dtype of the result (must be accepted by ``torch.iinfo``).

    Returns:
        A tensor of dtype ``dtype`` holding the quantized values.
    """
    limits = torch.iinfo(dtype)

    shifted = tensor / scale + zero_point
    # Saturate to [min, max] of the target dtype before the cast.
    saturated = torch.clamp(torch.round(shifted), limits.min, limits.max)
    return saturated.to(dtype)

In [3]:
test_tensor = torch.tensor(
    [[191.6, -13.5, 728.6],
     [92.14, 295.5, -184],
     [0,     684.6, 245.5]]
)

In [4]:
# hardcoded scale and zero point
scale = 3.5
zero_point = -70

In [5]:
quantized_tensor = linear_q_with_scale_and_zero_point(test_tensor,
    scale, zero_point)

In [6]:
quantized_tensor

tensor([[ -15,  -74,  127],
        [ -44,   14, -123],
        [ -70,  126,    0]], dtype=torch.int8)

In [7]:
# dequantize: r ~= scale * (q - zero_point); q is cast to float before the subtraction
dequantized_tensor = scale * (quantized_tensor.float() - zero_point)

In [8]:
dequantized_tensor

tensor([[ 192.5000,  -14.0000,  689.5000],
        [  91.0000,  294.0000, -185.5000],
        [   0.0000,  686.0000,  245.0000]])

In [9]:
# Problem if q is not cast to float first: the subtraction stays in int8 and
# wraps around. E.g. 127 - (-70) = 197 does not fit in int8 and becomes -59,
# which dequantizes to -206.5 instead of the 689.5 obtained above.
scale * (quantized_tensor - zero_point)

tensor([[ 192.5000,  -14.0000, -206.5000],
        [  91.0000,  294.0000, -185.5000],
        [   0.0000, -210.0000,  245.0000]])

In [10]:
def linear_dequantization(q, scale, zero_point):
    """Map quantized values back to real values: ``scale * (q - zero_point)``.

    ``q`` is cast to float32 before the zero point is subtracted, so the
    subtraction is not carried out in the integer dtype of ``q``.
    """
    centered = q.to(torch.float32) - zero_point
    return scale * centered

In [11]:
# quantization error: mean squared difference between the dequantized
# tensor and the original tensor
(dequantized_tensor - test_tensor).square().mean()

tensor(170.8753)

### Calculate scale and zero point

In [12]:
q_min = torch.iinfo(torch.int8).min
q_max = torch.iinfo(torch.int8).max

print(q_min, q_max)

-128 127


In [13]:
r_min = test_tensor.min().item()
r_min

-184.0

In [14]:
r_max = test_tensor.max().item()
r_max

728.5999755859375

In [15]:
# scale: width of the real range divided by the width of the integer range,
# i.e. how many real units one quantization step covers
scale = (r_max-r_min)/(q_max-q_min)
scale

3.578823433670343

In [16]:
# zero point: chosen so that r_min maps to q_min under q = r/scale + zero_point,
# then rounded to an integer
zero_point = int(round(q_min - (r_min/scale)))
zero_point

-77

In [17]:
def get_q_scale_and_zero_point(tensor, dtype=torch.int8):
    """Derive asymmetric linear-quantization parameters from a tensor's range.

    The scale maps the real range ``[tensor.min(), tensor.max()]`` onto the
    integer range of ``dtype``; the zero point is picked so that the real
    minimum lands on the integer minimum, then clamped to the integer range
    and rounded.

    Args:
        tensor: Non-empty tensor whose min/max define the real range.
        dtype: Target integer dtype (must be accepted by ``torch.iinfo``).

    Returns:
        Tuple ``(scale, zero_point)`` with ``scale`` a Python float and
        ``zero_point`` a Python int inside the range of ``dtype``.
    """
    # Use the requested dtype's limits (previously hard-coded to int8, which
    # silently ignored the ``dtype`` argument).
    q_info = torch.iinfo(dtype)
    q_min, q_max = q_info.min, q_info.max
    r_min = tensor.min().item()
    r_max = tensor.max().item()

    if r_max == r_min:
        # Constant tensor: the range has zero width and the scale would be 0,
        # causing a division by zero below. Widen the range to include 0.0 so
        # the single value stays exactly representable.
        r_min = min(r_min, 0.0)
        r_max = max(r_max, 0.0)
        if r_max == r_min:
            # All zeros: any positive scale works; zero maps to itself.
            return 1.0, 0

    scale = (r_max - r_min) / (q_max - q_min)
    zero_point = q_min - (r_min / scale)

    # Keep the zero point inside the integer range of ``dtype``.
    if zero_point < q_min:
        zero_point = q_min
    elif zero_point > q_max:
        zero_point = q_max
    else:
        zero_point = int(round(zero_point))

    return scale, zero_point

In [18]:
new_scale, new_zero_point = get_q_scale_and_zero_point(test_tensor)
print(new_scale, new_zero_point)

3.578823433670343 -77


In [19]:
quantized_tensor = linear_q_with_scale_and_zero_point(test_tensor,
        new_scale, new_zero_point)

In [20]:
dequantized_tensor = linear_dequantization(quantized_tensor, new_scale,
                                            new_zero_point)

In [21]:
(dequantized_tensor - test_tensor).square().mean()

tensor(1.5730)

In [22]:
def linear_quantization(tensor, dtype=torch.int8):
    """Quantize ``tensor`` asymmetrically, deriving the parameters from it.

    Returns:
        Tuple ``(quantized_tensor, scale, zero_point)``.
    """
    q_params = get_q_scale_and_zero_point(tensor, dtype=dtype)
    q_values = linear_q_with_scale_and_zero_point(tensor, *q_params, dtype=dtype)
    return (q_values, *q_params)

In [23]:
# NOTE(review): no random seed is set in the cells shown, so the values (and
# every output derived from them below) will differ on each run.
r_tensor = torch.randn((4, 4))
r_tensor

tensor([[-0.8507, -0.2863, -0.0977, -0.6233],
        [-0.4456,  0.1131,  1.0762,  0.6507],
        [ 0.3937, -0.5725,  0.2385, -1.1075],
        [-1.4498,  1.4015,  0.9169,  1.2282]])

In [24]:
quantized_tensor, scale, zero_point = linear_quantization(r_tensor)

In [25]:
quantized_tensor, scale, zero_point

(tensor([[ -74,  -24,   -7,  -54],
         [ -38,   12,   98,   60],
         [  37,  -49,   23,  -97],
         [-128,  127,   84,  112]], dtype=torch.int8),
 0.011181414828580968,
 2)

In [26]:
dequantized_tensor = linear_dequantization(quantized_tensor, scale, zero_point)
dequantized_tensor

tensor([[-0.8498, -0.2907, -0.1006, -0.6262],
        [-0.4473,  0.1118,  1.0734,  0.6485],
        [ 0.3913, -0.5703,  0.2348, -1.1070],
        [-1.4536,  1.3977,  0.9169,  1.2300]])

In [27]:
(dequantized_tensor - r_tensor).square().mean()

tensor(6.9450e-06)

### Symmetric mode

In [28]:
def get_q_scale_symmetric(tensor, dtype=torch.int8):
    """Compute the scale for symmetric linear quantization of ``tensor``.

    The scale maps the largest absolute value in ``tensor`` onto the largest
    positive integer of ``dtype``.

    Args:
        tensor: Non-empty tensor to derive the scale from.
        dtype: Target integer dtype (must be accepted by ``torch.iinfo``).

    Returns:
        The scale as a Python float. For an all-zero tensor ``1.0`` is
        returned instead of ``0.0`` so that dividing by the scale stays
        well defined (0 / 0 would otherwise yield NaN).
    """
    r_max = tensor.abs().max().item()
    if r_max == 0:
        # Every value is zero; any positive scale quantizes it exactly.
        return 1.0

    q_max = torch.iinfo(dtype).max
    return r_max / q_max

In [29]:
# NOTE(review): no random seed is set in the cells shown, so the values (and
# the scale / error printed below) will differ on each run.
test_tensor = torch.randn((4, 4))
test_tensor

tensor([[ 1.7235,  0.4838,  0.1597,  1.2745],
        [-0.3974,  0.0629, -1.5674, -0.4455],
        [ 1.3539,  1.3741,  0.8218, -1.8537],
        [ 0.3220, -1.9573,  0.9014,  0.7699]])

In [30]:
get_q_scale_symmetric(test_tensor)

0.015412120368537001

In [31]:
def linear_q_symmetric(tensor, dtype=torch.int8):
    """Quantize ``tensor`` symmetrically (zero point fixed at 0).

    Args:
        tensor: Tensor of real values to quantize.
        dtype: Target integer dtype.

    Returns:
        Tuple ``(quantized_tensor, scale)``.
    """
    # Pass ``dtype`` on so the scale is computed for the same integer range
    # the values are clamped to (it was previously always computed for int8).
    scale = get_q_scale_symmetric(tensor, dtype=dtype)
    quantized_tensor = linear_q_with_scale_and_zero_point(tensor, scale,
        zero_point=0, dtype=dtype)
    return quantized_tensor, scale

In [32]:
quantized_tensor, scale = linear_q_symmetric(test_tensor)

In [33]:
dequantized_tensor = linear_dequantization(quantized_tensor, scale, 0)

In [34]:
(dequantized_tensor - test_tensor).square().mean().item()

1.5442274161614478e-05

### Per channel quantization with symmetric mode

In [35]:
test_tensor = torch.tensor([
    [191.6, -13.5, 728.6],
    [92.14, 295.5, -184],
    [0,     684.6, 245.5]
])

In [36]:
dim = 0
output_dim = test_tensor.shape[dim]
output_dim

3

In [37]:
scale = torch.zeros(output_dim)
scale

tensor([0., 0., 0.])

In [38]:
# one symmetric scale per slice along `dim` (with dim = 0, one per row)
for index in range(output_dim):
    sub_tensor = test_tensor.select(dim, index)
    scale[index] = get_q_scale_symmetric(sub_tensor)

In [39]:
scale

tensor([5.7370, 2.3268, 5.3906])

In [40]:
scale_shape = [1] * test_tensor.dim()
scale_shape

[1, 1]

In [41]:
scale_shape[dim] = -1
scale_shape

[-1, 1]

In [42]:
scale = scale.view(scale_shape)
scale

tensor([[5.7370],
        [2.3268],
        [5.3906]])

In [43]:
scale.shape

torch.Size([3, 1])

In [44]:
quantized_tensor = linear_q_with_scale_and_zero_point(test_tensor,
                            scale=scale, zero_point=0)
quantized_tensor

tensor([[ 33,  -2, 127],
        [ 40, 127, -79],
        [  0, 127,  46]], dtype=torch.int8)

Merge all of the steps above into a single function.

In [45]:
def linear_q_symmetric_per_channel(tensor, dim, dtype=torch.int8):
    """Quantize ``tensor`` symmetrically with one scale per slice along ``dim``.

    Args:
        tensor: Tensor of real values to quantize.
        dim: Dimension along which each index gets its own scale.
        dtype: Target integer dtype.

    Returns:
        Tuple ``(quantized_tensor, scale)``. ``scale`` has size 1 in every
        dimension except ``dim``, so it broadcasts against ``tensor``.
    """
    output_dim = tensor.shape[dim]
    # Allocate on the input's device so the division below does not mix devices.
    scale = torch.zeros(output_dim, device=tensor.device)

    for index in range(output_dim):
        sub_tensor = tensor.select(dim, index)
        # Pass ``dtype`` on so each scale matches the integer range used for
        # clamping (it was previously always computed for int8).
        scale[index] = get_q_scale_symmetric(sub_tensor, dtype=dtype)

    # Reshape to [1, ..., -1, ..., 1] with -1 at ``dim`` for broadcasting.
    scale_shape = [1] * tensor.dim()
    scale_shape[dim] = -1
    scale = scale.view(scale_shape)

    quantized_tensor = linear_q_with_scale_and_zero_point(
        tensor, scale=scale, zero_point=0, dtype=dtype
    )

    return quantized_tensor, scale

In [46]:
quantized_tensor_0, scale_0 = linear_q_symmetric_per_channel(
    test_tensor, dim=0
)
quantized_tensor_1, scale_1 = linear_q_symmetric_per_channel(
    test_tensor, dim=1
)

In [47]:
dequantized_tensor_0 = linear_dequantization(quantized_tensor_0, scale_0, 0)
(dequantized_tensor_0 - test_tensor).square().mean().item()

1.8084441423416138

In [48]:
dequantized_tensor_1 = linear_dequantization(quantized_tensor_1, scale_1, 0)
(dequantized_tensor_1 - test_tensor).square().mean().item()

1.0781488418579102

### Per group quantization

In [49]:
def linear_q_symmetric_per_group(tensor, group_size, dtype=torch.int8):
    """Quantize a 2-D tensor symmetrically with one scale per group.

    Each row is split into consecutive groups of ``group_size`` elements and
    every group receives its own scale.

    Args:
        tensor: 2-D tensor whose second dimension is a multiple of
            ``group_size``.
        group_size: Number of consecutive elements sharing one scale.
        dtype: Target integer dtype.

    Returns:
        Tuple ``(quantized_tensor, scale)``: the quantized tensor in the
        original shape and one scale per group with shape ``(n_groups, 1)``.

    Raises:
        AssertionError: If ``tensor`` is not 2-D or its second dimension is
            not divisible by ``group_size``.
    """
    t_shape = tensor.shape
    # Check the rank first: indexing t_shape[1] on a 1-D tensor would raise
    # IndexError before the intended assertion could fire.
    assert tensor.dim() == 2, "tensor must be 2-D"
    assert t_shape[1] % group_size == 0, \
        "second dimension must be divisible by group_size"

    # reshape (not view) so non-contiguous inputs, e.g. a transposed weight,
    # are accepted; for contiguous inputs this is still a view.
    grouped = tensor.reshape(-1, group_size)
    quantized_tensor, scale = linear_q_symmetric_per_channel(
        grouped, dim=0, dtype=dtype
    )
    quantized_tensor = quantized_tensor.view(t_shape)
    return quantized_tensor, scale

In [50]:
def linear_dequantization_per_group(
    quantized_tensor, scale, group_size
):
    """Dequantize a tensor quantized group-wise with a zero point of 0.

    The tensor is viewed as rows of ``group_size`` elements, dequantized with
    ``scale`` and then returned in its original shape.
    """
    original_shape = quantized_tensor.shape
    grouped = quantized_tensor.view(-1, group_size)
    return linear_dequantization(grouped, scale, 0).view(original_shape)

In [51]:
# Round trip through per-group quantization: with a 6x6 tensor and
# group_size = 3, each row is split into two groups of three values.
# NOTE(review): no random seed is set in the cells shown, so the error below
# will differ on each run.
test_tensor = torch.rand((6, 6))
group_size = 3
quantized_tensor, scale = linear_q_symmetric_per_group(
    test_tensor, group_size=group_size
)
dequantized_tensor = linear_dequantization_per_group(quantized_tensor, scale, group_size)

In [52]:
(dequantized_tensor - test_tensor).square().mean().item()

2.218512008766993e-06

### Inference Linear Quantization

Weights are quantized to 8 bits; activations remain in 32-bit floating point (W8A32).

In [53]:
def quantized_linear_W8A32_without_bias(input, q_w, s_w, z_w):
    """Linear layer (no bias) with int8 weights and float32 activations.

    The weights are dequantized to float32 and then applied with
    ``torch.nn.functional.linear``.

    Args:
        input: float32 activation tensor.
        q_w: int8 quantized weight tensor.
        s_w: Scale of the quantized weights.
        z_w: Zero point of the quantized weights (0 for symmetric mode).

    Returns:
        The float32 output of ``linear(input, dequantized_weight)``.

    Raises:
        AssertionError: If ``input`` is not float32 or ``q_w`` is not int8.
    """
    assert input.dtype == torch.float32
    assert q_w.dtype == torch.int8

    # Dequantize as scale * (q - zero_point). The previous form
    # ``q * s_w + z_w`` is only correct when z_w == 0.
    dequantized_weight = s_w * (q_w.to(torch.float32) - z_w)
    output = torch.nn.functional.linear(input, dequantized_weight)
    return output

In [54]:
input = torch.tensor([1, 2, 3], dtype=torch.float32)

In [55]:
weight = torch.tensor([
    [-2, -1.13, 0.42],
    [-1.51, 0.25, 1.62],
    [0.23, 1.35, 2.15]
])

In [57]:
q_w, s_w = linear_q_symmetric(weight)
q_w, s_w

(tensor([[-118,  -67,   25],
         [ -89,   15,   96],
         [  14,   80,  127]], dtype=torch.int8),
 0.016929134609192376)

In [58]:
# the weights were quantized in symmetric mode, so the zero point is 0
output = quantized_linear_W8A32_without_bias(input, q_w, s_w, 0)
output

tensor([-2.9965,  3.8768,  9.3957])

In [59]:
f32_output = torch.nn.functional.linear(input, weight)
f32_output

tensor([-3.0000,  3.8500,  9.3800])