In [4]:
import torch

In [5]:
def linear_q_with_scale_and_zero_point(
    tensor, scale, zero_point, dtype = torch.int8):

    scaled_and_shifted_tensor = tensor / scale + zero_point

    rounded_tensor = torch.round(scaled_and_shifted_tensor)

    q_min = torch.iinfo(dtype).min
    q_max = torch.iinfo(dtype).max

    q_tensor = rounded_tensor.clamp(q_min,q_max).to(dtype)
    
    return q_tensor

In [6]:
test_tensor=torch.tensor(
    [[191.6, -13.5, 728.6],
     [92.14, 295.5,  -184],
     [0,     684.6, 245.5]]
)

In [7]:
scale = 3.5
zero_point = -70

quantized_tensor = linear_q_with_scale_and_zero_point(
    test_tensor, scale, zero_point)

quantized_tensor

tensor([[ -15,  -74,  127],
        [ -44,   14, -123],
        [ -70,  126,    0]], dtype=torch.int8)

In [9]:
def linear_dequantization(quantized_tensor, scale, zero_point):
    return scale * (quantized_tensor.float() - zero_point)

In [10]:
dequantized_tensor = linear_dequantization(
    quantized_tensor, scale, zero_point)

dequantized_tensor

tensor([[ 192.5000,  -14.0000,  689.5000],
        [  91.0000,  294.0000, -185.5000],
        [   0.0000,  686.0000,  245.0000]])

In [11]:
(dequantized_tensor - test_tensor).square().mean()

tensor(170.8753)

In [13]:
q_min = torch.iinfo(torch.int8).min
q_max = torch.iinfo(torch.int8).max

print( q_min,q_max)

-128 127


In [14]:
r_max = test_tensor.max().item()
r_min=test_tensor.min().item()

print(r_max,r_min)

728.5999755859375 -184.0


In [15]:
scale = (r_max - r_min) / (q_max - q_min)

scale

3.578823433670343

In [17]:
zero_point = int(round(q_min - (r_min / scale)))

zero_point

-77

In [18]:
def get_q_scale_and_zero_point(tensor, dtype=torch.int8):
    
    q_min, q_max = torch.iinfo(dtype).min, torch.iinfo(dtype).max
    r_min, r_max = tensor.min().item(), tensor.max().item()

    scale = (r_max - r_min) / (q_max - q_min)

    zero_point = q_min - (r_min / scale)

    # clip the zero_point to fall in [quantized_min, quantized_max]
    if zero_point < q_min:
        zero_point = q_min
    elif zero_point > q_max:
        zero_point = q_max
    else:
        # round and cast to int
        zero_point = int(round(zero_point))
    
    return scale, zero_point

In [19]:
new_scale, new_zero_point = get_q_scale_and_zero_point(
    test_tensor)

print(new_scale,new_zero_point)

3.578823433670343 -77


In [22]:
quantized_tensor = linear_q_with_scale_and_zero_point(
    test_tensor, new_scale, new_zero_point)


dequantized_tensor = linear_dequantization(quantized_tensor,
                                           new_scale, new_zero_point)

In [23]:
(dequantized_tensor-test_tensor).square().mean()

tensor(1.5730)

In [24]:
def linear_quantization(tensor, dtype=torch.int8):
    scale, zero_point = get_q_scale_and_zero_point(tensor, 
                                                   dtype=dtype)
    
    quantized_tensor = linear_q_with_scale_and_zero_point(tensor,
                                                          scale, 
                                                          zero_point, 
                                                          dtype=dtype)
    
    return quantized_tensor, scale , zero_point

In [25]:
r_tensor = torch.randn((4, 4))
r_tensor

tensor([[-0.6258,  0.5426, -0.1811, -0.3141],
        [ 2.3216, -0.2641,  0.6738, -0.5274],
        [-1.4387, -2.0974,  1.2796,  0.7958],
        [-0.2209, -0.2677,  0.0214,  1.4747]])

In [26]:
quantized_tensor, scale, zero_point = linear_quantization(r_tensor)

In [27]:
quantized_tensor,scale,zero_point

(tensor([[ -43,   24,  -17,  -25],
         [ 127,  -22,   32,  -37],
         [ -90, -128,   67,   39],
         [ -20,  -22,   -6,   78]], dtype=torch.int8),
 0.01732910567638921,
 -7)

In [29]:
dequantized_tensor = linear_dequantization(quantized_tensor,
                                           scale, zero_point)

dequantized_tensor

tensor([[-0.6238,  0.5372, -0.1733, -0.3119],
        [ 2.3221, -0.2599,  0.6758, -0.5199],
        [-1.4383, -2.0968,  1.2824,  0.7971],
        [-0.2253, -0.2599,  0.0173,  1.4730]])

In [30]:
(dequantized_tensor-r_tensor).square().mean()

tensor(1.7941e-05)

In [31]:
def get_q_scale_symmetric(tensor, dtype=torch.int8):
    r_max = tensor.abs().max().item()
    q_max = torch.iinfo(dtype).max

    # return the scale
    return r_max/q_max

In [32]:
test_tensor = torch.randn((4, 4))

test_tensor

tensor([[-0.9488,  1.7684,  0.3275, -1.6473],
        [-0.6891,  1.6827,  0.0492, -1.3059],
        [-0.8931,  0.0564,  0.2176,  0.7586],
        [ 0.8153, -1.8107, -0.3588, -1.4923]])

In [33]:
get_q_scale_symmetric(test_tensor)

0.014257444171454962

In [35]:
def linear_q_symmetric(tensor, dtype=torch.int8):
    scale = get_q_scale_symmetric(tensor)
    
    quantized_tensor = linear_q_with_scale_and_zero_point(tensor,
                                                     scale=scale,
                   # in symmetric quantization zero point is = 0    
                                                    zero_point=0,
                                                      dtype=dtype)
    
    return quantized_tensor, scale

In [36]:
quantized_tensor, scale = linear_q_symmetric(test_tensor)

quantized_tensor

tensor([[ -67,  124,   23, -116],
        [ -48,  118,    3,  -92],
        [ -63,    4,   15,   53],
        [  57, -127,  -25, -105]], dtype=torch.int8)

In [37]:
dequantized_tensor = linear_dequantization(quantized_tensor,scale,0)

dequantized_tensor

tensor([[-0.9552,  1.7679,  0.3279, -1.6539],
        [-0.6844,  1.6824,  0.0428, -1.3117],
        [-0.8982,  0.0570,  0.2139,  0.7556],
        [ 0.8127, -1.8107, -0.3564, -1.4970]])

In [40]:
(dequantized_tensor-test_tensor).square().mean()

tensor(1.6668e-05)

In [41]:
test_tensor=torch.tensor(
    [[191.6, -13.5, 728.6],
     [92.14, 295.5,  -184],
     [0,     684.6, 245.5]]
)

In [43]:
quantized_tensor, scale = linear_q_symmetric(test_tensor)

quantized_tensor

tensor([[ 33,  -2, 127],
        [ 16,  52, -32],
        [  0, 119,  43]], dtype=torch.int8)

In [44]:
dequantized_tensor = linear_dequantization(quantized_tensor, scale, 0)

dequantized_tensor

tensor([[ 189.3213,  -11.4740,  728.6000],
        [  91.7921,  298.3244, -183.5842],
        [   0.0000,  682.7039,  246.6913]])

In [45]:
(dequantized_tensor-test_tensor).square().mean()

tensor(2.5092)

In [46]:
def linear_q_symmetric_per_channel(tensor,dim,dtype=torch.int8):



    return quantized_tensor, scale

In [47]:
test_tensor=torch.tensor(
    [[191.6, -13.5, 728.6],
     [92.14, 295.5,  -184],
     [0,     684.6, 245.5]]
)

In [48]:
dim=0
output_dim = test_tensor.shape[dim]

output_dim

3

In [49]:
scale = torch.zeros(output_dim)

scale

tensor([0., 0., 0.])

In [51]:
for index in range(output_dim):
    sub_tensor = test_tensor.select(dim,index)
    # print(sub_tensor)
    scale[index] = get_q_scale_symmetric(sub_tensor)

scale

tensor([5.7370, 2.3268, 5.3906])

In [53]:
scale_shape = [1] * test_tensor.dim()

scale_shape

[1, 1]

In [55]:
scale_shape[dim] = -1

scale_shape

[-1, 1]

In [57]:
scale = scale.view(scale_shape)

copy_scale = scale

scale

tensor([[5.7370],
        [2.3268],
        [5.3906]])

In [58]:
m = torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
s = torch.tensor([1,5,10])

In [59]:
s.view(-1,1).shape

torch.Size([3, 1])

In [61]:
scale = torch.tensor([[1], [5], [10]])

scale.shape

torch.Size([3, 1])

In [62]:
m/scale

tensor([[1.0000, 2.0000, 3.0000],
        [0.8000, 1.0000, 1.2000],
        [0.7000, 0.8000, 0.9000]])

In [64]:
scale = torch.tensor([[1, 5, 10]])

scale.shape

torch.Size([1, 3])

In [65]:
m/scale

tensor([[1.0000, 0.4000, 0.3000],
        [4.0000, 1.0000, 0.6000],
        [7.0000, 1.6000, 0.9000]])

In [66]:
scale = copy_scale

scale,scale.shape

(tensor([[5.7370],
         [2.3268],
         [5.3906]]),
 torch.Size([3, 1]))

In [67]:
quantized_tensor = linear_q_with_scale_and_zero_point(
    test_tensor, scale=scale, zero_point=0)

quantized_tensor

tensor([[ 33,  -2, 127],
        [ 40, 127, -79],
        [  0, 127,  46]], dtype=torch.int8)

In [68]:
def linear_q_symmetric_per_channel(r_tensor, dim, dtype=torch.int8):
    
    output_dim = r_tensor.shape[dim]
    # store the scales
    scale = torch.zeros(output_dim)

    for index in range(output_dim):
        sub_tensor = r_tensor.select(dim, index)
        scale[index] = get_q_scale_symmetric(sub_tensor, dtype=dtype)

    # reshape the scale
    scale_shape = [1] * r_tensor.dim()
    scale_shape[dim] = -1
    scale = scale.view(scale_shape)
    quantized_tensor = linear_q_with_scale_and_zero_point(
        r_tensor, scale=scale, zero_point=0, dtype=dtype)
   
    return quantized_tensor, scale

In [69]:
test_tensor=torch.tensor(
    [[191.6, -13.5, 728.6],
     [92.14, 295.5,  -184],
     [0,     684.6, 245.5]]
)

In [70]:
### along the rows (dim = 0)
quantized_tensor_0, scale_0 = linear_q_symmetric_per_channel(
    test_tensor, dim=0)

### along the columns (dim = 1)
quantized_tensor_1, scale_1 = linear_q_symmetric_per_channel(
    test_tensor, dim=1)

In [72]:
dequantized_tensor_0 = linear_dequantization(
    quantized_tensor_0, scale_0, 0)

dequantized_tensor

tensor([[ 189.3213,  -11.4740,  728.6000],
        [  91.7921,  298.3244, -183.5842],
        [   0.0000,  682.7039,  246.6913]])

In [74]:
(dequantized_tensor_0-test_tensor).square().mean()

tensor(1.8084)

In [75]:
dequantized_tensor_1 = linear_dequantization(
    quantized_tensor_1, scale_1, 0)
 
dequantized_tensor_1

tensor([[ 191.6000,  -16.1717,  728.6000],
        [  92.0284,  296.4803, -183.5842],
        [   0.0000,  684.6000,  246.6913]])

In [76]:
(dequantized_tensor_1-test_tensor).square().mean()

tensor(1.0781)

In [77]:
def linear_q_symmetric_per_group(tensor, group_size,
                                 dtype=torch.int8):
    
    t_shape = tensor.shape
    assert t_shape[1] % group_size == 0
    assert tensor.dim() == 2
    
    tensor = tensor.view(-1, group_size)
    
    quantized_tensor, scale = linear_q_symmetric_per_channel(
                                tensor, dim=0, dtype=dtype)
    
    quantized_tensor = quantized_tensor.view(t_shape)
    
    return quantized_tensor, scale

In [78]:
def linear_dequantization_per_group(quantized_tensor, scale, 
                                    group_size):
    
    q_shape = quantized_tensor.shape
    quantized_tensor = quantized_tensor.view(-1, group_size)
    
    dequantized_tensor = linear_dequantization(quantized_tensor, 
                                               scale, 0)
    
    dequantized_tensor = dequantized_tensor.view(q_shape)
    
    return dequantized_tensor

In [80]:
test_tensor = torch.rand((6, 6))
test_tensor

tensor([[0.8597, 0.4113, 0.0419, 0.7496, 0.7314, 0.7832],
        [0.9978, 0.5952, 0.3123, 0.5696, 0.5188, 0.3179],
        [0.3016, 0.0920, 0.8069, 0.8225, 0.8090, 0.1946],
        [0.4224, 0.9490, 0.3638, 0.9033, 0.6693, 0.4078],
        [0.8337, 0.5264, 0.2299, 0.9628, 0.4466, 0.1192],
        [0.5687, 0.9650, 0.4831, 0.2920, 0.2206, 0.0101]])

In [81]:
group_size = 3
quantized_tensor, scale = linear_q_symmetric_per_group(
    test_tensor, group_size=group_size)

dequantized_tensor = linear_dequantization_per_group(
    quantized_tensor, scale, group_size=group_size)

In [82]:
(dequantized_tensor-test_tensor).square().mean()

tensor(2.4995e-06)

In [83]:
def quantized_linear_W8A32_without_bias(input, q_w, s_w, z_w):
    assert input.dtype == torch.float32
    assert q_w.dtype == torch.int8

    dequantized_weight = q_w.to(torch.float32) * s_w + z_w
    output = torch.nn.functional.linear(input, dequantized_weight)
    
    return output

In [84]:
input = torch.tensor([1, 2, 3], dtype=torch.float32)
weight = torch.tensor([[-2,   -1.13, 0.42],
                       [-1.51, 0.25, 1.62],
                       [0.23,  1.35, 2.15]])

In [85]:
q_w, s_w  = linear_q_symmetric(weight)

q_w,s_w

(tensor([[-118,  -67,   25],
         [ -89,   15,   96],
         [  14,   80,  127]], dtype=torch.int8),
 0.016929134609192376)

In [87]:
output = quantized_linear_W8A32_without_bias(input,
                                             q_w,
                                             s_w,
                                             0)

output

tensor([-2.9965,  3.8768,  9.3957])

In [88]:
fp32_output = torch.nn.functional.linear(input, weight)
fp32_output

tensor([-3.0000,  3.8500,  9.3800])