# Linear Quantization

This notebook covers quantization of linear layers. For more background on quantization, refer to the [previous notebook](../notebooks/9-dealing-with-few-to-no-labels.ipynb).

## Quantization

$f = \left(\frac{{f_{\text{max}} - f_{\text{min}}}}{{q_{\text{max}} - q_{\text{min}}}}\right) (q - Z) = S(q - Z)$

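Here $f$ is the original floating-point value, $q$ is the quantized value, $S$ is the scale factor, and $Z$ is the zero point.

Solving for $q$ and rounding to the nearest integer gives the quantization step:

$q = \text{int}\left(\text{round}\left(\frac{f}{S} + Z\right)\right)$

The result is then clamped to $[q_{\text{min}}, q_{\text{max}}]$ so that it fits the target integer type.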
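
As a hedged aside, $S$ and $Z$ can be derived by requiring the endpoints of the float range to map onto the endpoints of the integer range. This is a sketch of the common min/max (asymmetric) scheme; rounding $Z$ to an integer is an assumption here, not something stated in this notebook.

$f_{\text{min}} = S(q_{\text{min}} - Z), \quad f_{\text{max}} = S(q_{\text{max}} - Z)$

Subtracting the first relation from the second gives the scale, and solving the first for $Z$ gives the zero point:

$S = \frac{f_{\text{max}} - f_{\text{min}}}{q_{\text{max}} - q_{\text{min}}}, \quad Z = \text{int}\left(\text{round}\left(q_{\text{min}} - \frac{f_{\text{min}}}{S}\right)\right)$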

Let's implement the quantization equation next.

In [5]:
import torch

def linear_q_with_scale_and_zerp_point(
    tensor, scale, zero_point, dtype=torch.int8
):
    """Quantize `tensor` as q = round(tensor / scale + zero_point), clamped to `dtype`."""
    # Map the float values onto the integer grid
    scaled_and_shifted_tensor = tensor / scale + zero_point
    rounded_tensor = torch.round(scaled_and_shifted_tensor)
    # Clamp to the representable range of the target dtype
    q_min = torch.iinfo(dtype).min
    q_max = torch.iinfo(dtype).max
    clamped_tensor = torch.clamp(rounded_tensor, q_min, q_max)
    # Cast to the target dtype
    return clamped_tensor.to(dtype)

In [6]:
# Define a (3, 3) FP32 test tensor and pick a scale and zero point by hand to test the quantization function above
test_tensor = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
scale = 0.1
zero_point = 0

quantized_tensor = linear_q_with_scale_and_zerp_point(test_tensor, scale, zero_point)

In [7]:
quantized_tensor

tensor([[10, 20, 30],
        [40, 50, 60],
        [70, 80, 90]], dtype=torch.int8)

In [8]:
1.0 / 0.1

10.0
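
The scale and zero point above were picked by hand. As a minimal sketch, they could instead be derived from the tensor's own value range using the scale definition from the Quantization section. The helper name `get_scale_and_zero_point` and the example tensor are hypothetical and not part of this notebook, and clamping the zero point into the integer range is an assumption about how out-of-range values should be handled.

```python
import torch

def get_scale_and_zero_point(tensor, dtype=torch.int8):
    # Hypothetical helper (not defined in the notebook): min/max calibration.
    # Assumes tensor.max() > tensor.min(); a constant tensor would divide by zero.
    q_min, q_max = torch.iinfo(dtype).min, torch.iinfo(dtype).max
    f_min, f_max = tensor.min().item(), tensor.max().item()

    # S = (f_max - f_min) / (q_max - q_min)
    scale = (f_max - f_min) / (q_max - q_min)

    # Z is chosen so that f_min maps to q_min, then rounded to an integer
    zero_point = int(round(q_min - f_min / scale))

    # Assumption: keep Z inside the representable integer range
    zero_point = max(q_min, min(q_max, zero_point))
    return scale, zero_point

# Example tensor (an assumption for illustration) spanning negative and positive values
example = torch.tensor([-2.0, 0.0, 6.0])
scale, zero_point = get_scale_and_zero_point(example)
print(scale, zero_point)
```

When the value range does not include zero, the unclamped zero point can fall outside the integer range; clamping it then costs accuracy at the far end of the range, so a symmetric scheme or a wider dtype may be a better fit in that case.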

In [11]:
def linear_dequantization(quantized_tensor, scale, zero_point):
    return scale * (quantized_tensor.float() - zero_point)

In [12]:
dequantized_tensor = linear_dequantization(quantized_tensor, scale, zero_point)
dequantized_tensor

tensor([[1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]])

In [None]:
# Quantization error: mean squared error between the dequantized tensor and the original tensor
(dequantized_tensor - test_tensor).square().mean()

tensor(0.)
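
The test tensor above consists of exact multiples of the scale, so it survives the round trip unchanged. The sketch below repeats the round trip on values that are not multiples of the scale, to show what rounding error looks like in general. The helper names `quantize` and `dequantize` and the random input are assumptions made to keep the block self-contained; they mirror the functions defined earlier rather than replace them.

```python
import torch

def quantize(tensor, scale, zero_point, dtype=torch.int8):
    # Same steps as the quantization function above: scale, shift, round, clamp, cast
    info = torch.iinfo(dtype)
    q = torch.round(tensor / scale + zero_point)
    return torch.clamp(q, info.min, info.max).to(dtype)

def dequantize(q, scale, zero_point):
    # Inverse mapping: f = S * (q - Z)
    return scale * (q.float() - zero_point)

torch.manual_seed(0)
original = torch.rand(3, 3) * 9.0  # values in [0, 9), generally not multiples of the scale
scale, zero_point = 0.1, 0         # same hand-picked parameters as above

restored = dequantize(quantize(original, scale, zero_point), scale, zero_point)

# Error is measured against the original float tensor, not the integer tensor
mse = (restored - original).square().mean()
max_abs_err = (restored - original).abs().max()
print(mse, max_abs_err)
```

Because every value here stays inside the int8 range after scaling, nothing is clipped and the only error comes from rounding, so each element is off by at most about half the scale (0.05).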