### Quantize and Dequantize a Vector

In [1]:
import torch

In [2]:
def linear_q_with_scale_and_zero_point(
    tensor,
    scale,
    zero_point,
    dtype=torch.int8
):
    """Linearly quantize `tensor` with a given scale and zero point.

    Computes round(tensor / scale + zero_point), saturates the result to the
    representable range of the integer `dtype`, and casts to that dtype.

    Args:
        tensor: real-valued tensor to quantize.
        scale: divisor applied to `tensor` (a number, or a tensor that
            broadcasts against `tensor`).
        zero_point: offset added after scaling.
        dtype: integer torch dtype of the result (default: torch.int8).

    Returns:
        Tensor of dtype `dtype` with the same shape as the broadcast result.
    """
    # Project the real values onto the integer grid.
    on_grid = torch.round(tensor / scale + zero_point)

    # Saturate to what `dtype` can hold so the cast below cannot wrap around.
    limits = torch.iinfo(dtype)
    saturated = torch.clamp(on_grid, limits.min, limits.max)

    return saturated.to(dtype)

In [3]:
# Small 3x3 float tensor used as the running example for quantization.
test_tensor = torch.tensor(
    [[191.6, -13.5, 728.6],
     [92.14, 295.5, -184],
     [0,     684.6, 245.5]]
)

In [4]:
# Hardcoded scale and zero point (hand-picked, not derived from the data);
# data-driven values are computed further down in the notebook.
scale = 3.5
zero_point = -70

In [5]:
# Quantize the example tensor to int8 (the function's default dtype)
# using the hardcoded parameters.
quantized_tensor = linear_q_with_scale_and_zero_point(test_tensor,
    scale, zero_point)

In [6]:
# Inspect the int8 result; values that fell outside the int8 range were clamped.
quantized_tensor

tensor([[ -15,  -74,  127],
        [ -44,   14, -123],
        [ -70,  126,    0]], dtype=torch.int8)

In [7]:
# Dequantize: cast to float first, then undo the shift and the scaling,
# i.e. r = scale * (q - zero_point).
dequantized_tensor = scale * (quantized_tensor.float() - zero_point)

In [8]:
# Inspect the reconstruction; compare against test_tensor to see the rounding/clamping loss.
dequantized_tensor

tensor([[ 192.5000,  -14.0000,  689.5000],
        [  91.0000,  294.0000, -185.5000],
        [   0.0000,  686.0000,  245.0000]])

In [9]:
# Problem if not casting to float: the subtraction is carried out in int8 and
# wraps around (e.g. 127 - (-70) = 197 does not fit in int8), so the entries
# quantized to 127 and 126 come out with wrong values.
scale * (quantized_tensor - zero_point)

tensor([[ 192.5000,  -14.0000, -206.5000],
        [  91.0000,  294.0000, -185.5000],
        [   0.0000, -210.0000,  245.0000]])

In [10]:
def linear_dequantization(q, scale, zero_point):
    """Map a quantized tensor back to real values.

    Args:
        q: quantized tensor.
        scale: scale used during quantization (a number, or a tensor that
            broadcasts against `q`).
        zero_point: zero point used during quantization.

    Returns:
        scale * (q - zero_point), computed in float32.
    """
    # Cast before subtracting so the shift cannot overflow the integer type.
    shifted = q.to(torch.float32) - zero_point
    return scale * shifted

In [11]:
# Quantization error: mean squared difference between the dequantized
# tensor and the original one.
(dequantized_tensor - test_tensor).square().mean()

tensor(170.8753)

### Calculate scale and zero point

In [12]:
# Smallest and largest values representable in int8.
q_min = torch.iinfo(torch.int8).min
q_max = torch.iinfo(torch.int8).max

print(q_min, q_max)

-128 127


In [13]:
# Smallest real value in the tensor, as a Python number.
r_min = test_tensor.min().item()
r_min

-184.0

In [14]:
# Largest real value in the tensor, as a Python number.
r_max = test_tensor.max().item()
r_max

728.5999755859375

In [15]:
# Scale: width of the real range divided by the width of the integer range.
scale = (r_max-r_min)/(q_max-q_min)
scale

3.578823433670343

In [16]:
# Zero point: chosen so that r_min maps to q_min, rounded to an integer.
zero_point = int(round(q_min - (r_min/scale)))
zero_point

-77

In [17]:
def get_q_scale_and_zero_point(tensor, dtype=torch.int8):
    """Derive asymmetric linear-quantization parameters from a tensor's range.

    The scale maps the tensor's [min, max] range onto the full range of the
    integer `dtype`; the zero point is the integer that the real value
    `tensor.min()` is mapped to being `q_min`.

    Args:
        tensor: non-empty real-valued tensor whose value range is quantized.
        dtype: target integer torch dtype (default: torch.int8).

    Returns:
        (scale, zero_point): `scale` is a Python float and `zero_point` is a
        Python int inside the representable range of `dtype`.
    """
    # Use the requested dtype's range (previously int8 was hard-coded here,
    # which silently ignored the `dtype` argument).
    q_info = torch.iinfo(dtype)
    q_min, q_max = q_info.min, q_info.max
    r_min = tensor.min().item()
    r_max = tensor.max().item()

    if r_max == r_min:
        # Degenerate (constant) tensor: the range has zero width, so the scale
        # would be 0 and the zero-point division would fail. Widen the range
        # to include 0 so the single value is still representable.
        r_min = min(r_min, 0.0)
        r_max = max(r_max, 0.0)
        if r_max == r_min:
            # All-zero tensor: any positive scale works; 0 maps to 0.
            return 1.0, 0

    scale = (r_max - r_min) / (q_max - q_min)
    zero_point = q_min - (r_min / scale)

    # Keep the zero point representable in `dtype`.
    if zero_point < q_min:
        zero_point = q_min
    elif zero_point > q_max:
        zero_point = q_max
    else:
        zero_point = int(round(zero_point))

    return scale, zero_point

In [18]:
# Same computation as the step-by-step cells above, now through the helper.
new_scale, new_zero_point = get_q_scale_and_zero_point(test_tensor)
print(new_scale, new_zero_point)

3.578823433670343 -77


In [19]:
# Re-quantize the example tensor with the data-derived parameters.
quantized_tensor = linear_q_with_scale_and_zero_point(test_tensor,
        new_scale, new_zero_point)

In [20]:
# Reconstruct the real values using the same data-derived parameters.
dequantized_tensor = linear_dequantization(quantized_tensor, new_scale,
                                            new_zero_point)

In [21]:
# Mean squared quantization error with the data-derived parameters
# (compare with the error obtained from the hardcoded ones earlier).
(dequantized_tensor - test_tensor).square().mean()

tensor(1.5730)

In [22]:
def linear_quantization(tensor, dtype=torch.int8):
    """Quantize `tensor` with parameters derived from its own value range.

    Args:
        tensor: real-valued tensor to quantize.
        dtype: target integer torch dtype (default: torch.int8).

    Returns:
        (quantized_tensor, scale, zero_point), where scale and zero_point are
        the values produced by get_q_scale_and_zero_point.
    """
    q_params = get_q_scale_and_zero_point(tensor, dtype=dtype)
    scale, zero_point = q_params
    return (
        linear_q_with_scale_and_zero_point(tensor, scale, zero_point, dtype=dtype),
        scale,
        zero_point,
    )

In [23]:
# Random 4x4 tensor to try the full pipeline on.
# NOTE: no seed is set in the visible cells, so the values (and every output
# derived from them below) will differ on each run.
r_tensor = torch.randn((4, 4))
r_tensor

tensor([[-1.5573,  1.0115,  0.0716, -0.0499],
        [ 0.1786,  0.6451, -0.6927,  1.6553],
        [-1.5536, -0.7909,  0.8260, -0.4221],
        [-0.5926, -0.4059, -0.8752,  0.3946]])

In [24]:
# Quantize the random tensor; this overwrites the notebook-level
# quantized_tensor, scale and zero_point from the earlier cells.
quantized_tensor, scale, zero_point = linear_quantization(r_tensor)

In [25]:
# Inspect the quantized values together with the derived parameters.
quantized_tensor, scale, zero_point

(tensor([[-128,   76,    2,   -8],
         [  10,   47,  -59,  127],
         [-127,  -67,   62,  -38],
         [ -51,  -36,  -73,   27]], dtype=torch.int8),
 0.012598263048658184,
 -4)

In [26]:
# Reconstruct the random tensor from its quantized form.
dequantized_tensor = linear_dequantization(quantized_tensor, scale, zero_point)
dequantized_tensor

tensor([[-1.5622,  1.0079,  0.0756, -0.0504],
        [ 0.1764,  0.6425, -0.6929,  1.6504],
        [-1.5496, -0.7937,  0.8315, -0.4283],
        [-0.5921, -0.4031, -0.8693,  0.3905]])

In [27]:
# Mean squared quantization error for the random tensor.
(dequantized_tensor - r_tensor).square().mean()

tensor(1.5065e-05)

### Symmetric mode

In [28]:
def get_q_scale_symmetric(tensor, dtype=torch.int8):
    """Compute the scale for symmetric (zero_point = 0) quantization.

    Args:
        tensor: tensor whose largest absolute value defines the real range.
        dtype: target integer torch dtype (default: torch.int8).

    Returns:
        The largest absolute value in `tensor` divided by the maximum value
        representable in `dtype`.
    """
    largest_magnitude = torch.max(torch.abs(tensor)).item()
    return largest_magnitude / torch.iinfo(dtype).max

In [29]:
# Replace the example tensor with a random 4x4 one for the symmetric-mode demo.
# NOTE: no seed is set in the visible cells, so values differ on each run.
test_tensor = torch.randn((4, 4))
test_tensor

tensor([[-1.6478, -0.7517, -0.3615, -1.4731],
        [-1.4333, -2.3147, -0.3521,  0.0091],
        [-0.2756,  0.3796,  0.4729,  0.9226],
        [-0.9986,  1.0786, -0.2269,  0.1351]])

In [30]:
# Symmetric int8 scale for the random tensor (largest |value| / 127).
get_q_scale_symmetric(test_tensor)

0.018226180489607682

In [31]:
def linear_q_symmetric(tensor, dtype=torch.int8):
    """Quantize `tensor` symmetrically (zero point fixed at 0).

    Args:
        tensor: real-valued tensor to quantize.
        dtype: target integer torch dtype (default: torch.int8).

    Returns:
        (quantized_tensor, scale): the quantized tensor of dtype `dtype` and
        the scale that was used.
    """
    # Forward dtype so the scale matches the range of the requested type
    # (it was previously always computed for the int8 default).
    scale = get_q_scale_symmetric(tensor, dtype=dtype)
    quantized_tensor = linear_q_with_scale_and_zero_point(
        tensor, scale, zero_point=0, dtype=dtype
    )
    return quantized_tensor, scale

In [32]:
# Symmetric quantization of the random tensor; overwrites the notebook-level
# quantized_tensor and scale.
quantized_tensor, scale = linear_q_symmetric(test_tensor)

In [34]:
# Symmetric mode uses a zero point of 0, so dequantization is just scale * q.
dequantized_tensor = linear_dequantization(quantized_tensor, scale, 0)

In [36]:
# Mean squared quantization error for symmetric mode, as a Python float.
(dequantized_tensor - test_tensor).square().mean().item()

2.9107723094057292e-05

### Per channel quantization with symmetric mode

In [37]:
# Back to the fixed 3x3 example values used at the top of the notebook
# (test_tensor was overwritten by a random tensor in the symmetric-mode section).
test_tensor = torch.tensor([
    [191.6, -13.5, 728.6],
    [92.14, 295.5, -184],
    [0,     684.6, 245.5]
])

In [40]:
# dim = 0: one scale per index along the first dimension (per row here).
dim = 0
# Number of scales needed = size of the tensor along that dimension.
output_dim = test_tensor.shape[dim]
output_dim

3

In [42]:
# Placeholder holding one scale per slice along `dim`; filled in below.
scale = torch.zeros(output_dim)
scale

tensor([0., 0., 0.])

In [44]:
# Compute a symmetric scale independently for each slice along `dim`.
for index in range(output_dim):
    # select() drops `dim` and returns the slice at position `index`.
    sub_tensor = test_tensor.select(dim, index)
    scale[index] = get_q_scale_symmetric(sub_tensor)

In [45]:
# Inspect the per-slice scales.
scale

tensor([5.7370, 2.3268, 5.3906])

In [47]:
# Target shape for the scales: start with a 1 for every tensor dimension.
scale_shape = [1] * test_tensor.dim()
scale_shape

[1, 1]

In [48]:
# -1 at position `dim` lets view() infer that dimension's size from the data.
scale_shape[dim] = -1
scale_shape

[-1, 1]

In [50]:
# Reshape so that the scales broadcast against test_tensor along `dim`.
scale = scale.view(scale_shape)
scale

tensor([[5.7370],
        [2.3268],
        [5.3906]])

In [51]:
# Confirm the reshaped scale has one entry per row and a singleton column dimension.
scale.shape

torch.Size([3, 1])

In [53]:
# Quantize with the per-row scales; the division broadcasts the (3, 1) scale
# across each row. Zero point is 0 because this is symmetric mode.
quantized_tensor = linear_q_with_scale_and_zero_point(test_tensor,
                            scale=scale, zero_point=0)
quantized_tensor

tensor([[ 33,  -2, 127],
        [ 40, 127, -79],
        [  0, 127,  46]], dtype=torch.int8)

Merge all of the steps above into a single function:

In [54]:
def linear_q_symmetric_per_channel(tensor, dim, dtype=torch.int8):
    """Symmetric quantization with a separate scale for each slice along `dim`.

    Args:
        tensor: real-valued tensor to quantize.
        dim: dimension whose indices each get their own scale.
        dtype: target integer torch dtype (default: torch.int8).

    Returns:
        (quantized_tensor, scale): the quantized tensor of dtype `dtype`, and
        the scales reshaped to size 1 in every dimension except `dim`, so they
        broadcast against `tensor`.
    """
    output_dim = tensor.shape[dim]
    # Allocate the scales on the input's device so the division inside the
    # quantizer does not mix devices.
    scale = torch.zeros(output_dim, device=tensor.device)

    for index in range(output_dim):
        sub_tensor = tensor.select(dim, index)
        # Forward dtype: the scale must be derived from the target type's
        # range (it was previously always computed for the int8 default).
        scale[index] = get_q_scale_symmetric(sub_tensor, dtype=dtype)

    # Reshape the scales so they broadcast along `dim`:
    # all ones, with -1 (inferred size) at position `dim`.
    scale_shape = [1] * tensor.dim()
    scale_shape[dim] = -1
    scale = scale.view(scale_shape)

    quantized_tensor = linear_q_with_scale_and_zero_point(
        tensor, scale=scale, zero_point=0, dtype=dtype
    )

    return quantized_tensor, scale

In [55]:
# Per-channel quantization along dim 0 (one scale per row) ...
quantized_tensor_0, scale_0 = linear_q_symmetric_per_channel(
    test_tensor, dim=0
)
# ... and along dim 1 (one scale per column), for comparison.
quantized_tensor_1, scale_1 = linear_q_symmetric_per_channel(
    test_tensor, dim=1
)

In [57]:
# Reconstruction and mean squared error for the dim=0 (per-row) scales.
dequantized_tensor_0 = linear_dequantization(quantized_tensor_0, scale_0, 0)
(dequantized_tensor_0 - test_tensor).square().mean().item()

1.8084441423416138

In [58]:
# Reconstruction and mean squared error for the dim=1 (per-column) scales.
dequantized_tensor_1 = linear_dequantization(quantized_tensor_1, scale_1, 0)
(dequantized_tensor_1 - test_tensor).square().mean().item()

1.0781488418579102

### Per group quantization

In [60]:
def linear_q_symmetric_per_group(tensor, group_size, dtype=torch.int8):
    """Symmetric quantization with one scale per group of `group_size` values.

    The 2-D input is regrouped into rows of `group_size` consecutive elements
    (taken within each original row), and each such group is quantized with
    its own scale.

    Args:
        tensor: 2-D real-valued tensor; its second dimension must be a
            multiple of `group_size`.
        group_size: number of consecutive elements sharing one scale.
        dtype: target integer torch dtype (default: torch.int8).

    Returns:
        (quantized_tensor, scale): the quantized tensor with the original
        shape, and the scales with shape (number_of_groups, 1).

    Raises:
        AssertionError: if `tensor` is not 2-D or its second dimension is not
            divisible by `group_size`.
    """
    t_shape = tensor.shape
    # Check the rank first so that indexing shape[1] below cannot raise
    # IndexError on a 1-D input.
    assert tensor.dim() == 2, "per-group quantization expects a 2-D tensor"
    assert t_shape[1] % group_size == 0, \
        "tensor.shape[1] must be divisible by group_size"

    # reshape (rather than view) also accepts non-contiguous inputs such as
    # transposed tensors.
    tensor = tensor.reshape(-1, group_size)
    quantized_tensor, scale = linear_q_symmetric_per_channel(
        tensor, dim=0, dtype=dtype
    )
    quantized_tensor = quantized_tensor.view(t_shape)
    return quantized_tensor, scale

In [61]:
def linear_dequantization_per_group(
    quantized_tensor, scale, group_size
):
    """Dequantize a tensor that was quantized with one scale per group.

    Args:
        quantized_tensor: quantized tensor in its original shape.
        scale: per-group scales that broadcast against the tensor regrouped
            into rows of `group_size` elements (e.g. shape
            (number_of_groups, 1)).
        group_size: number of consecutive elements sharing one scale; must
            match the value used for quantization.

    Returns:
        Float tensor with the same shape as `quantized_tensor`.
    """
    q_shape = quantized_tensor.shape
    # Regroup into rows of `group_size` so each row lines up with its scale.
    # reshape (rather than view) also accepts non-contiguous inputs.
    quantized_tensor = quantized_tensor.reshape(-1, group_size)
    # Symmetric mode: zero point is 0.
    dequantized_tensor = linear_dequantization(quantized_tensor, scale, 0)
    dequantized_tensor = dequantized_tensor.reshape(q_shape)

    return dequantized_tensor

In [62]:
# Random 6x6 tensor (no seed set in the visible cells, so values differ per run).
test_tensor = torch.rand((6, 6))
# Each row of 6 values is split into two groups of 3, each with its own scale.
group_size = 3
quantized_tensor, scale = linear_q_symmetric_per_group(
    test_tensor, group_size=group_size
)
dequantized_tensor = linear_dequantization_per_group(quantized_tensor, scale, group_size)

In [63]:
# Mean squared quantization error for per-group quantization.
(dequantized_tensor - test_tensor).square().mean().item()

1.7965680854103994e-06