In [1]:
import torch

In [2]:
def linear_q_with_scale_and_zero_point(
    tensor,
    scale,
    zero_point,
    dtype=torch.int8
):
    """Linearly quantize ``tensor`` using a given scale and zero point.

    Computes ``round(tensor / scale + zero_point)``, clamps the result to the
    representable range of ``dtype`` and casts it to ``dtype``.

    Args:
        tensor: Tensor to quantize.
        scale: Value the tensor is divided by before shifting.
        zero_point: Offset added after the division.
        dtype: Integer dtype of the result; it must be accepted by
            ``torch.iinfo``. Defaults to ``torch.int8``.

    Returns:
        The quantized tensor, with dtype ``dtype``.
    """
    # Map into the quantized domain, then snap to the nearest integer value.
    shifted = tensor / scale + zero_point
    snapped = torch.round(shifted)

    # Keep every value inside the target dtype's range before casting.
    limits = torch.iinfo(dtype)
    saturated = torch.clamp(snapped, min=limits.min, max=limits.max)
    return saturated.to(dtype)

In [3]:
# 3x3 example tensor used as the input for the quantization cells below.
test_tensor = torch.tensor(
    [[191.6, -13.5, 728.6],
     [92.14, 295.5, -184],
     [0,     684.6, 245.5]]
)

In [4]:
# Hard-coded scale and zero point, written as literals for this example.
scale = 3.5
zero_point = -70

In [5]:
# Quantize the example tensor; dtype is left at the function's default
# (torch.int8).
quantized_tensor = linear_q_with_scale_and_zero_point(test_tensor,
    scale, zero_point)

In [6]:
# Show the int8 result. 728.6 / 3.5 - 70 is about 138, which is above the int8
# maximum, so that entry is clamped to 127.
quantized_tensor

tensor([[ -15,  -74,  127],
        [ -44,   14, -123],
        [ -70,  126,    0]], dtype=torch.int8)

In [7]:
# Dequantize: cast to float before subtracting the zero point, then rescale.
dequantized_tensor = scale * (quantized_tensor.float() - zero_point)

In [8]:
# Show the reconstructed values. The entry that was clamped to 127 comes back
# as 689.5 rather than the original 728.6.
dequantized_tensor

tensor([[ 192.5000,  -14.0000,  689.5000],
        [  91.0000,  294.0000, -185.5000],
        [   0.0000,  686.0000,  245.0000]])

In [9]:
# Problem if we do not cast to float first: quantized_tensor is int8, so the
# subtraction overflows for large entries (e.g. 127 - (-70) = 197 wraps around
# to -59, giving -206.5 instead of 689.5).
scale * (quantized_tensor - zero_point)

tensor([[ 192.5000,  -14.0000, -206.5000],
        [  91.0000,  294.0000, -185.5000],
        [   0.0000, -210.0000,  245.0000]])

In [10]:
def linear_dequantization(q, scale, zero_point):
    """Map a quantized tensor back to floating point.

    Args:
        q: Quantized tensor.
        scale: Factor the shifted values are multiplied by.
        zero_point: Offset subtracted from ``q`` before scaling.

    Returns:
        ``scale * (q - zero_point)``, with ``q`` cast to float32 first.
    """
    # Cast before subtracting so the arithmetic is not done in q's own dtype.
    centered = q.to(torch.float32) - zero_point
    return scale * centered

In [11]:
# Quantization error: mean squared difference between the dequantized tensor
# and the original test_tensor.
(dequantized_tensor - test_tensor).square().mean()

tensor(170.8753)