In [1]:
import torch

In [103]:
def find_scale_and_zero_point(tensor: torch.Tensor):
    """Compute asymmetric (affine) int8 quantization parameters for ``tensor``.

    The observed range ``[min, max]`` is first widened to contain 0. This keeps
    real zero exactly representable and guarantees the zero point lands inside
    the int8 range. Without the widening, an all-positive (or all-negative)
    tensor produced an out-of-range zero point that had to be clamped, after
    which the affine map no longer covered the data and the top of the range
    saturated on quantization.

    Args:
        tensor: Floating-point tensor to be quantized.

    Returns:
        ``(scale, zero_point)``: ``scale`` is a positive Python float and
        ``zero_point`` is a Python int in ``[-128, 127]``.
    """
    q_max, q_min = torch.iinfo(torch.int8).max, torch.iinfo(torch.int8).min

    # Widen the observed range so that it always contains 0.
    r_min = min(tensor.min().item(), 0.0)
    r_max = max(tensor.max().item(), 0.0)

    scale = (r_max - r_min) / (q_max - q_min)
    if scale == 0:
        # All-zero tensor: any positive scale maps it exactly onto zero_point;
        # 1.0 avoids dividing by zero below.
        scale = 1.0

    zero_point = q_min - (r_min / scale)
    # Because r_min <= 0 <= r_max, zero_point is already within [q_min, q_max];
    # the clamp only guards against floating-point drift at the edges.
    zero_point = int(min(max(round(zero_point), q_min), q_max))

    return scale, zero_point

def find_scale_symmetric(tensor: torch.Tensor):
    """Compute the symmetric int8 scale for ``tensor`` (zero point implicitly 0).

    The largest absolute value is mapped to ``q_max`` (127), so the real
    interval ``[-max|x|, max|x|]`` covers the integer interval ``[-127, 127]``.

    Args:
        tensor: Floating-point tensor (or a slice of one) to be quantized.

    Returns:
        The scale as a positive Python float. An all-zero input returns 1.0.
        Returning 0.0 here would make ``tensor / scale`` evaluate to NaN (0/0)
        downstream, and NaN would then be cast to int8.
    """
    q_max = torch.iinfo(torch.int8).max
    r_max = tensor.abs().max().item()

    if r_max == 0:
        # Any positive scale represents an all-zero tensor exactly.
        return 1.0

    return r_max / q_max

def linear_quantization_with_scale_and_zero_point(tensor: torch.Tensor, scale, zero_point = 0):
    """Quantize ``tensor`` to int8 as ``clamp(round(x / scale + zero_point))``.

    Args:
        tensor: Floating-point tensor to quantize.
        scale: Python float, or a tensor broadcastable against ``tensor``
            (e.g. a per-channel scale).
        zero_point: Integer offset added after scaling (0 for symmetric).

    Returns:
        An int8 tensor shaped like ``tensor``. Values outside the int8 range
        saturate at -128 / 127.
    """
    int8_info = torch.iinfo(torch.int8)
    shifted = tensor / scale + zero_point
    return shifted.round().clamp(int8_info.min, int8_info.max).to(torch.int8)

def linear_quantization(tensor: torch.Tensor):
    """Asymmetric int8 quantization of ``tensor`` using its own value range.

    Returns:
        ``(q_tensor, scale, zero_point)``: the int8 tensor together with the
        parameters that ``dequantize_tensor`` needs to map it back.
    """
    scale, zero_point = find_scale_and_zero_point(tensor)
    q_tensor = linear_quantization_with_scale_and_zero_point(tensor, scale, zero_point)
    return q_tensor, scale, zero_point

def linear_symmetric_quantization(tensor: torch.Tensor):
    """Symmetric int8 quantization of ``tensor`` (zero point fixed at 0).

    Returns:
        ``(q_tensor, scale)``: the int8 tensor and the single scale used.
    """
    sym_scale = find_scale_symmetric(tensor)
    q_tensor = linear_quantization_with_scale_and_zero_point(tensor, sym_scale, 0)
    return q_tensor, sym_scale

def linear_symmetric_quantization_per_channel(tensor: torch.Tensor, dim=0):
    """Symmetric int8 quantization with one scale per slice along ``dim``.

    Args:
        tensor: Floating-point tensor to quantize.
        dim: Axis whose every index gets its own scale (default 0).

    Returns:
        ``(quantized_tensor, scale)``: an int8 tensor shaped like ``tensor``,
        and a default-dtype float tensor of scales. The scale tensor has size 1
        on every axis except ``dim``, so it broadcasts against ``tensor``.
    """
    n_channels = tensor.shape[dim]

    # One symmetric scale per slice along `dim` (Python floats -> default dtype).
    per_channel_scales = torch.tensor(
        [find_scale_symmetric(tensor.select(dim, idx)) for idx in range(n_channels)]
    )

    # Broadcastable layout: 1 on every axis, -1 (all channels) on `dim`.
    broadcast_shape = [1] * tensor.dim()
    broadcast_shape[dim] = -1
    scale = per_channel_scales.view(broadcast_shape)

    return linear_quantization_with_scale_and_zero_point(tensor, scale, 0), scale

def dequantize_tensor(q_tensor: torch.Tensor, scale, zero_point = 0):
    """Map quantized integers back to real values as ``scale * (q - zero_point)``.

    Args:
        q_tensor: Quantized (int8) tensor.
        scale: Python float, or a tensor broadcastable against ``q_tensor``
            (e.g. a per-channel scale).
        zero_point: Integer offset used at quantization time (0 for symmetric).

    Returns:
        The reconstructed tensor (floating point when ``scale`` is a float).
    """
    # Widen before subtracting: q - zero_point can span [-255, 255], beyond int8.
    centered = q_tensor.to(torch.int16) - zero_point
    return scale * centered

In [45]:
# 3x3 float32 example matrix, quantized per row further below.
# `torch.tensor` with an explicit dtype replaces the legacy `torch.Tensor(...)` constructor.
a = torch.tensor([[1, 2, 3],
                  [2, 3, 1],
                  [2, 1, 3]], dtype=torch.float32)
a

tensor([[1., 2., 3.],
        [2., 3., 1.],
        [2., 1., 3.]])

In [24]:
# NOTE(review): `test` is not referenced by any of the cells shown here —
# either run it through the quantizers or drop this cell.
# `torch.tensor` with an explicit dtype replaces the legacy `torch.Tensor(...)` constructor.
test = torch.tensor([[191.6, -13.5, 728.6],
                     [92.14, 295.5, -184.0],
                     [0.0,   684.6, 245.5]], dtype=torch.float32)
test

tensor([[ 191.6000,  -13.5000,  728.6000],
        [  92.1400,  295.5000, -184.0000],
        [   0.0000,  684.6000,  245.5000]])

In [59]:
# Per-row (dim=0) symmetric int8 quantization of `a`: one scale per row.
quant_a, scale = linear_symmetric_quantization_per_channel(a, dim=0)
(quant_a, scale)

(tensor([[ 42,  85, 127],
         [ 85, 127,  42],
         [ 85,  42, 127]], dtype=torch.int8),
 tensor([[0.0236],
         [0.0236],
         [0.0236]]))

In [55]:
# Reconstruct `a` from its per-row int8 representation (symmetric => zero_point 0).
# NOTE(review): the output below is stale. It comes from In [55], which ran
# before the quantization cell In [59]. With the quant_a/scale shown above,
# 42 * 0.0236 ≈ 0.9921, not 1.0079. Restart & Run All to refresh it.
dequant_a = dequantize_tensor(quant_a, scale, zero_point=0)
dequant_a

tensor([[1.0079, 2.0079, 3.0000],
        [2.0000, 3.0000, 0.9921],
        [2.0000, 0.9921, 3.0000]])

In [56]:
# Mean squared reconstruction error of the per-row round trip.
# NOTE(review): stale output — In [56] ran before the quantization cell In [59]; re-run to refresh.
(dequant_a - a).pow(2).mean()

tensor(2.7556e-05)

In [94]:
# 1-D float32 example whose range is not symmetric around 0 (|min| != |max|),
# used to compare symmetric and asymmetric quantization.
# `torch.tensor` with an explicit dtype replaces the legacy `torch.Tensor(...)` constructor.
b = torch.tensor([-200, 20, 300], dtype=torch.float32)
b

tensor([-200.,   20.,  300.])

In [95]:
# Symmetric quantization of `b`: a single scale, zero point fixed at 0.
quant_s_b, scale_s_b = linear_symmetric_quantization(tensor=b)
(quant_s_b, scale_s_b)

(tensor([-85,   8, 127], dtype=torch.int8), 2.3622047244094486)

In [96]:
# Reconstruct `b` from the symmetric int8 representation.
dequant_s_b = dequantize_tensor(quant_s_b, scale_s_b, zero_point=0)
dequant_s_b

tensor([-200.7874,   18.8976,  300.0000])

In [97]:
# Asymmetric quantization of `b`: scale and zero point derived from b's own range.
quant_b, scale_b, zero_point = linear_quantization(tensor=b)
(quant_b, scale_b, zero_point)

(tensor([-128,  -16,  127], dtype=torch.int8), 1.9607843137254901, -26)

In [104]:
# Reconstruct `b` from the asymmetric int8 representation.
dequant_b = dequantize_tensor(quant_b, scale_b, zero_point=zero_point)
dequant_b

tensor([-200.0000,   19.6078,  300.0000])

In [99]:
# Mean squared reconstruction error, symmetric scheme.
(dequant_s_b - b).pow(2).mean()

tensor(0.6117)

In [105]:
# Mean squared reconstruction error, asymmetric scheme (lower than symmetric for this range).
((dequant_b - b) ** 2).mean()

tensor(0.0513)