<a href="https://colab.research.google.com/github/Abhijit-2592/flyai/blob/main/notebooks/lec01_understanding_quantization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction to Neural Network Quantization
Neural network quantization is the process of reducing the precision of the numerical values used in a neural network, typically the weights and activations, to make the model more efficient in terms of memory usage, computational speed, and energy consumption. Instead of using 32-bit floating-point (FP) numbers, quantization reduces them to lower precision formats like 8-bit integers. This can lead to faster inference and reduced power consumption. However, quantization may result in a small loss of accuracy, though techniques like quantization-aware training help minimize this effect.

This notebook covers two common quantization schemes:
1. Symmetric Quantization
2. Affine/Asymmetric Quantization


## Symmetric Quantization
In symmetric quantization, the range of the original FP values is mapped to a symmetric range around zero in the quantized space. Thus 0 in FP is mapped to 0 in the quantized space.

![Symmetric Quantization Image](https://substackcdn.com/image/fetch/f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F730bbb8a-3a44-47f6-aefe-f652b117ae22_1124x600.png)

Image/text Courtesy: [A Visual Guide To Quantization](https://newsletter.maartengrootendorst.com/p/a-visual-guide-to-quantization) by [Maarten Grootendorst](https://www.linkedin.com/in/mgrootendorst)
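
The mapping can be written compactly as follows. This is a sketch that mirrors the `SymmetricQuantizer` class defined in this notebook. The symbols are notation assumed here, not taken from the source: $x$ is the FP tensor, $b$ the number of bits, $s$ the scale, $q$ the quantized tensor and $\hat{x}$ the reconstructed tensor.

$$
s = \frac{\max_i |x_i|}{2^{b-1} - 1}, \qquad
q_i = \mathrm{clamp}\!\left(\mathrm{round}\!\left(\frac{x_i}{s}\right),\; -(2^{b-1} - 1),\; 2^{b-1} - 1\right), \qquad
\hat{x}_i = s \, q_i
$$

Because no offset is added, $x_i = 0$ always gives $q_i = 0$.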


In [None]:
import torch
import math
from dataclasses import dataclass

In [None]:
_ = torch.manual_seed(555)

In [None]:
@dataclass
class QuantizedTensor:
    qtensor: torch.Tensor
    scale: torch.Tensor
    orig_dtype: torch.dtype
    zero_point: torch.Tensor = torch.tensor(0.0)

In [None]:
class SymmetricQuantizer:
    def __init__(self, nbits:int = 8):
        self.nbits = nbits
        self.q_min = -math.pow(2, nbits - 1)  # Unused: the symmetric range is restricted to [-q_max, q_max]
        self.q_max = math.pow(2, nbits - 1) - 1

    @property
    def q_range(self):
        return f"[-{self.q_max}, {self.q_max}]"

    def _calculate_scales(self, tensor: torch.Tensor) -> torch.Tensor:
        scale = torch.amax(torch.abs(tensor))/self.q_max
        return scale.to(torch.float32)

    def quantize(self, tensor:torch.Tensor, q_dtype:torch.dtype = torch.int8) -> QuantizedTensor:
        scale = self._calculate_scales(tensor)
        qtensor = torch.round(tensor/scale)
        qtensor = torch.clamp(qtensor, min=-self.q_max, max=self.q_max)
        return QuantizedTensor(qtensor=qtensor.to(q_dtype), scale=scale, orig_dtype=tensor.dtype)

    def dequantize(self, qtensor:QuantizedTensor):
        tensor = qtensor.qtensor * qtensor.scale
        return tensor.to(qtensor.orig_dtype)

    def calculate_quantization_mae(self, tensor:torch.Tensor, q_dtype:torch.dtype = torch.int8):
        reconstructed_tensor = self.dequantize(self.quantize(tensor, q_dtype=q_dtype))
        print("Reconstructed Tensor")
        print(reconstructed_tensor)
        return torch.mean(torch.abs(tensor-reconstructed_tensor))

In [None]:
# Generate a random uniform tensor of shape (5, 5) with values in (-5, 10]
# torch.rand samples from [0, 1), so (r1 - r2) * Uniform(0, 1) + r2 with r1 = -5, r2 = 10 gives (r1, r2]
a = -15*torch.rand(5, 5) + 10
# Manually set one element to 0.0 so that we can see how the SymmetricQuantizer maps 0
a[2][3] = 0.0
print(a)

tensor([[-1.4829,  3.1988,  6.2363,  7.0474, -1.7724],
        [ 9.5538, -0.6993, -2.4891, -1.8361,  3.5922],
        [ 2.1141,  3.3869,  7.6704,  0.0000,  9.4361],
        [ 7.8068, -2.4722, -2.1322, -4.3650, -2.4265],
        [ 4.2931,  0.1434, -2.4976,  2.3279,  9.4024]])


In [None]:
print("Symmetric Quantization")
print("Tensor")
print(a)
symmetric_quantizer = SymmetricQuantizer(nbits=8)
qtensor = symmetric_quantizer.quantize(a)
print("Quantized Tensor")
print(qtensor.qtensor)
print(f"Scale: {qtensor.scale}")
print(f"Quantization Range: {symmetric_quantizer.q_range}")
print(f"Quantization Mean Absolute Error: {symmetric_quantizer.calculate_quantization_mae(a)}")

Tensor
tensor([[-1.4829,  3.1988,  6.2363,  7.0474, -1.7724],
        [ 9.5538, -0.6993, -2.4891, -1.8361,  3.5922],
        [ 2.1141,  3.3869,  7.6704,  0.0000,  9.4361],
        [ 7.8068, -2.4722, -2.1322, -4.3650, -2.4265],
        [ 4.2931,  0.1434, -2.4976,  2.3279,  9.4024]])
Quantized Tensor
tensor([[-20,  43,  83,  94, -24],
        [127,  -9, -33, -24,  48],
        [ 28,  45, 102,   0, 125],
        [104, -33, -28, -58, -32],
        [ 57,   2, -33,  31, 125]], dtype=torch.int8)
Scale: 0.07522675395011902
Quantization Range: [-127.0, 127.0]
Quantization Mean Absolute Error: 0.014071819372475147


As you can clearly see:
- 0 is mapped to 0.
- The highest value is mapped to +127.
- The most negative value would map to -127 only if its magnitude equalled the largest absolute value in the tensor, because the scale is computed from `torch.amax(torch.abs(tensor))`. Here the minimum (-4.3650) maps to -58.
- The quantization range is therefore underused: no value lands in [-127, -59].
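
The observations above can be checked by hand. The following is a minimal sketch in plain Python (no PyTorch), using the maximum and minimum values copied from the printed tensor. Because those values are rounded to 4 decimals, the scale matches the printed one only approximately. The variable names are illustrative.

```python
# Values copied from the printed tensor above (rounded to 4 decimals)
max_abs = 9.5538   # largest absolute value in the tensor
min_val = -4.3650  # most negative value in the tensor
q_max = 2 ** (8 - 1) - 1  # 127 for 8 bits

# Same formula as SymmetricQuantizer._calculate_scales
scale = max_abs / q_max
q_of_max = round(max_abs / scale)
q_of_min = round(min_val / scale)

# Integer levels spanned by the data vs. levels available in [-127, 127]
used_levels = q_of_max - q_of_min + 1
available_levels = 2 * q_max + 1

print(f"scale ~ {scale:.5f}")
print(f"max -> {q_of_max}, min -> {q_of_min}")
print(f"levels spanned: {used_levels} of {available_levels}")
```

The data spans only the levels from -58 to 127, which is the unused portion of the range described in the last bullet.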

## Asymmetric Quantization
Asymmetric quantization, in contrast, is not symmetric around zero. Instead, it maps the minimum (β) and maximum (α) values from the float range to the minimum and maximum values of the quantized range.

![asymmetric quantization](https://substackcdn.com/image/fetch/f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F8ffa0c54-88bf-45c1-8636-bdb097bb8e6b_1172x848.png)

Image/text courtesy: [A Visual Guide To Quantization](https://newsletter.maartengrootendorst.com/p/a-visual-guide-to-quantization) by [Maarten Grootendorst](https://www.linkedin.com/in/mgrootendorst)
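
In formulas, and as a sketch that mirrors the `AsymmetricQuantizer` class defined in this notebook, the scale $s$ and zero point $z$ are derived from the float range $[\beta, \alpha]$ and the quantized range $[q_{\min}, q_{\max}]$. The symbols $s$, $z$, $q$ and $\hat{x}$ are notation assumed here.

$$
s = \frac{\alpha - \beta}{q_{\max} - q_{\min}}, \qquad
z = \mathrm{round}\!\left(q_{\min} - \frac{\beta}{s}\right)
$$

$$
q_i = \mathrm{clamp}\!\left(\mathrm{round}\!\left(\frac{x_i}{s}\right) + z,\; q_{\min},\; q_{\max}\right), \qquad
\hat{x}_i = s \, (q_i - z)
$$

Setting $x_i = 0$ gives $q_i = z$, which is why the zero point is the integer that represents 0.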

In [None]:
class AsymmetricQuantizer:
    def __init__(self, nbits: int):
        self.nbits = nbits
        self.qmin = -math.pow(2, self.nbits - 1)
        self.qmax = math.pow(2, self.nbits - 1) - 1

    @property
    def q_range(self):
        return f"[{self.qmin}, {self.qmax}]"

    def _calculate_scale(self, tensor: torch.Tensor) -> torch.Tensor:
        rmin, rmax = torch.amin(tensor), torch.amax(tensor)
        return (rmax-rmin)/(self.qmax-self.qmin)

    def _calculate_zero_point(self, tensor: torch.Tensor, scale: torch.Tensor):
        rmin = torch.amin(tensor)
        return torch.round(self.qmin - rmin/scale)

    def quantize(self, tensor:torch.Tensor, q_dtype:torch.dtype = torch.int8) -> QuantizedTensor:
        scale = self._calculate_scale(tensor)
        zero_point = self._calculate_zero_point(tensor, scale)
        qtensor = torch.clamp(torch.round(tensor/scale) + zero_point, min=self.qmin, max=self.qmax)
        return QuantizedTensor(
            qtensor = qtensor.to(q_dtype),
            scale=scale,
            orig_dtype=tensor.dtype,
            zero_point=zero_point.to(q_dtype)
            )

    def dequantize(self, qtensor:QuantizedTensor):
        tensor = (qtensor.qtensor.to(qtensor.orig_dtype) - qtensor.zero_point.to(qtensor.orig_dtype)) * qtensor.scale
        return tensor.to(qtensor.orig_dtype)

    def calculate_quantization_mae(self, tensor:torch.Tensor, q_dtype:torch.dtype = torch.int8):
        reconstructed_tensor = self.dequantize(self.quantize(tensor, q_dtype=q_dtype))
        print("Reconstructed Tensor")
        print(reconstructed_tensor)
        return torch.mean(torch.abs(tensor-reconstructed_tensor))

In [None]:
print("Asymmetric Quantization")
print("Tensor")
print(a)
asymmetric_quantizer = AsymmetricQuantizer(nbits=8)
qtensor = asymmetric_quantizer.quantize(a)
print("Quantized Tensor")
print(qtensor.qtensor)
print(f"Scale: {qtensor.scale}")
print(f"Zero Point: {qtensor.zero_point}")
print(f"Quantization Range: {asymmetric_quantizer.q_range}")
print(f"Quantization Mean Absolute Error: {asymmetric_quantizer.calculate_quantization_mae(a)}")

Tensor
tensor([[-1.4829,  3.1988,  6.2363,  7.0474, -1.7724],
        [ 9.5538, -0.6993, -2.4891, -1.8361,  3.5922],
        [ 2.1141,  3.3869,  7.6704,  0.0000,  9.4361],
        [ 7.8068, -2.4722, -2.1322, -4.3650, -2.4265],
        [ 4.2931,  0.1434, -2.4976,  2.3279,  9.4024]])
Quantized Tensor
tensor([[ -75,   11,   66,   81,  -80],
        [ 127,  -61,  -94,  -82,   18],
        [  -9,   14,   93,  -48,  125],
        [  95,  -93,  -87, -128,  -92],
        [  31,  -45,  -94,   -5,  124]], dtype=torch.int8)
Scale: 0.05458369106054306
Zero Point: -48
Quantization Range: [-128.0, 127.0]
Quantization Mean Absolute Error: 0.012920784763991833


As you can clearly see:

- 0 is mapped to the zero_point.
- The highest value is mapped to +127.
- The lowest value is mapped to -128.
- The quantization MAE (0.0129) is smaller than with symmetric quantization (0.0141) because the entire quantization range is used.
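
These mappings can be reproduced by hand. The sketch below uses plain Python with the minimum and maximum copied from the printed tensor (rounded to 4 decimals, so the scale is approximate). The `quantize` helper is a hypothetical scalar stand-in for `AsymmetricQuantizer.quantize`.

```python
# Range of the FP tensor, copied from the printed output above
r_min, r_max = -4.3650, 9.5538
# Signed 8-bit integer range
q_min, q_max = -(2 ** 7), 2 ** 7 - 1

# Same formulas as AsymmetricQuantizer._calculate_scale / _calculate_zero_point
scale = (r_max - r_min) / (q_max - q_min)
zero_point = round(q_min - r_min / scale)


def quantize(x: float) -> int:
    """Quantize a single float: round, shift by the zero point, then clamp."""
    q = round(x / scale) + zero_point
    return max(q_min, min(q_max, q))


print(f"scale ~ {scale:.5f}, zero point = {zero_point}")
print(f"0.0 -> {quantize(0.0)}")
print(f"max -> {quantize(r_max)}, min -> {quantize(r_min)}")
```

Python's built-in `round` and `torch.round` both round halves to the nearest even integer, so the scalar version follows the same rounding rule as the class.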


## Additional Notes

What we have done in this notebook is the simplest form of quantization, called **per-tensor quantization** (one scale and one zero point for the entire tensor). As you can probably guess, we can reduce the quantization error by making the scales and zero points more fine-grained:
- one scale/zero point per channel: **per-channel quantization**. This is popular for quantizing **CNN** weights and inputs/activations.
- one scale/zero point per token: **per-token quantization**. This is popular for quantizing **LLM** inputs/activations.
- one scale/zero point per group of, say, 128 values: **group quantization**. This is popular for quantizing **LLM** weights.
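
To illustrate why finer granularity helps, here is a minimal sketch in plain Python that compares per-tensor and per-channel symmetric quantization. The weight matrix is hypothetical, and `symmetric_mae` is an illustrative helper that is not part of the notebook's classes. The two rows have deliberately different magnitudes.

```python
def symmetric_mae(rows, scales, q_max=127):
    """MAE after symmetric quantize/dequantize; scales[i] is the scale used for rows[i]."""
    total, count = 0.0, 0
    for row, scale in zip(rows, scales):
        for x in row:
            q = max(-q_max, min(q_max, round(x / scale)))
            total += abs(x - q * scale)
            count += 1
    return total / count


# Hypothetical weights: two channels (rows) with very different magnitudes
weights = [
    [0.01, -0.02, 0.03],
    [10.0, -20.0, 30.0],
]
q_max = 127

# Per tensor: a single scale shared by every row
tensor_scale = max(abs(x) for row in weights for x in row) / q_max
per_tensor_mae = symmetric_mae(weights, [tensor_scale] * len(weights))

# Per channel: one scale per row
row_scales = [max(abs(x) for x in row) / q_max for row in weights]
per_channel_mae = symmetric_mae(weights, row_scales)

print(f"per-tensor MAE:  {per_tensor_mae:.5f}")
print(f"per-channel MAE: {per_channel_mae:.5f}")
```

With a shared scale, every value in the small-magnitude row rounds to 0. A per-row scale preserves those values, which lowers the overall error for this example.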

### Tradeoff
- As we increase the quantization granularity, the quantization error decreases, but we also have to store more scales and zero points, which reduces the memory savings.
- Even though quantization adds extra operations (dequantization etc.), we still gain compute efficiency because most modern GPUs support hardware acceleration for integer matmuls. For example, [Nvidia A100s](https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf) deliver **19.5 TFLOPS** for **FP32 matmuls** compared to **624 TOPS** for **INT8 Tensor Core matmuls**.
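
The storage side of this tradeoff can be estimated with simple arithmetic. The sketch below assumes 8-bit weights, a 16-bit scale and an 8-bit zero point per group. These storage widths and the helper name `effective_bits_per_weight` are illustrative assumptions, not values taken from a particular library.

```python
def effective_bits_per_weight(weight_bits, group_size, scale_bits=16, zero_point_bits=8):
    """Average bits stored per weight once per-group metadata is included."""
    return weight_bits + (scale_bits + zero_point_bits) / group_size


# Smaller groups mean finer granularity but more metadata per weight
for group_size in (32, 128, 4096):
    bits = effective_bits_per_weight(weight_bits=8, group_size=group_size)
    print(f"group size {group_size:>4}: {bits:.4f} bits per weight")
```

Under these assumptions, a group size of 128 adds 24/128 = 0.1875 bits per weight, while a group size of 32 adds 0.75 bits per weight.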