# Quantization for Neural Networks

Following the small asymmetric quantization example, this notebook shows how to quantize a neural network (NN).

## Post Training Quantization (PTQ)

PTQ involves training a regular model and then quantizing it.

To do so, we will use observers to determine alpha, beta, the scale and the zero point while simply running inference, just as we did in the f32 to int8 vector quantization example.

This will be done using PyTorch only.
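
As a reminder of what the observers will compute for us, here is a minimal sketch of asymmetric quantization parameters. It assumes that alpha and beta denote the minimum and maximum observed values and that the target is signed int8; the helper names (`asymmetric_qparams`, `quantize`, `dequantize`) are made up for this illustration and are not part of PyTorch.

```python
import torch

def asymmetric_qparams(x, qmin=-128, qmax=127):
    """Derive scale and zero point from the observed range [alpha, beta] of x."""
    alpha, beta = x.min().item(), x.max().item()
    # Widen the range to include 0 so that 0.0 maps exactly onto an integer
    alpha, beta = min(alpha, 0.0), max(beta, 0.0)
    scale = (beta - alpha) / (qmax - qmin)
    zero_point = int(round(qmin - alpha / scale))
    return scale, zero_point

def quantize(x, scale, zero_point, qmin=-128, qmax=127):
    # Scale, shift by the zero point, then clamp to the int8 range
    q = torch.round(x / scale) + zero_point
    return torch.clamp(q, qmin, qmax).to(torch.int8)

def dequantize(q, scale, zero_point):
    # Inverse mapping back to f32 (up to rounding error)
    return (q.to(torch.float32) - zero_point) * scale

x = torch.tensor([-1.5, -0.2, 0.0, 0.7, 2.3])
scale, zero_point = asymmetric_qparams(x)
q = quantize(x, scale, zero_point)
x_hat = dequantize(q, scale, zero_point)

print("scale:", scale, "zero point:", zero_point)
print("int8 values:", q)
print("dequantized:", x_hat)
```

The dequantized values should differ from the originals by at most roughly one quantization step (`scale`).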

## Quantization Aware Training (QAT)

For this, you will have to wait until the **next lecture**, where we will use Brevitas, a quantization library built on top of PyTorch, to do QAT.

### Side note on PyTorch vs Brevitas for Quantization

Note that frameworks other than Brevitas exist for QAT, but FINN (a very important tool for later in the course) was built to work with Brevitas.

*Meaning*: even though we use PyTorch here because it is the easiest option for PTQ, we will quickly transition to Brevitas for QAT. Treat this notebook as learning material showing that you can also use quantization in simpler AI use cases to save on inference costs.

In [1]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
_ = torch.manual_seed(0)

In [2]:
# IMPORT THE DATA
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Data preparation
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)

train_loader = DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=100, shuffle=False)

In [3]:
# DEFINE THE MODEL
# This example is covered in more detail in the second lecture, alongside a full QAT example in Brevitas

class SimpleClassifier(nn.Module):
    def __init__ (self):
        super(SimpleClassifier, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(28*28, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        return self.model(x)

In [4]:
# DECLARE THE MODEL AND OPTIMIZATION PARAMETERS
import torch.optim as optim

model = SimpleClassifier()
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
# TRAIN THE MODEL
for epoch in range(5):
    for i, (images, labels) in enumerate(train_loader):
        # Flatten the image
        images = images.reshape(-1, 28*28)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    
    print(f'Epoch [{epoch+1}/5], Loss: {loss.item():.4f}')

print("Training finished!")

In [None]:
# Testing loop
import torch
model.eval()
correct = 0
with torch.no_grad():
    for data, target in test_loader:
        # Flatten the image
        data = data.reshape(-1, 28*28)
        output = model(data)
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

accuracy = 100. * correct / len(test_loader.dataset)
print(f'Test Accuracy: {accuracy:.2f}%')

old_accuracy = accuracy #save for later

## NOW LET'S ANALYSE !

In [None]:
import os

# GET MODEL SIZE
def get_size(model):
    # Serialize the state dict to a temporary file and measure its size on disk
    torch.save(model.state_dict(), "tmp_model.p")
    size = os.path.getsize("tmp_model.p") / 1e3  # bytes -> KB
    os.remove("tmp_model.p")
    return size

old_size = get_size(model)
print("size of the model before PTQ : ", old_size, "KB")

# POST TRAINING QUANT

In the quantization example, we saw that the min and max values are used to compute adequate quantization parameters.

Here, the inputs change all the time! We have to run inference to gather data in order to determine the best parameters.

To do this, we will simply use "observers".
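
To make the idea concrete, here is a small standalone sketch of an observer on its own, outside of any model. It uses `MinMaxObserver` from `torch.ao.quantization` with random batches standing in for real inputs (the random data and the batch shape are assumptions chosen for illustration only).

```python
import torch
from torch.ao.quantization import MinMaxObserver

torch.manual_seed(0)

# An observer that tracks the running min/max of everything it sees
observer = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)

# Feed a few batches; the observer returns its input unchanged
# and only updates its internal statistics
for _ in range(10):
    batch = torch.randn(100, 28 * 28)  # stand-in for flattened images
    observer(batch)

print("observed min:", observer.min_val.item())
print("observed max:", observer.max_val.item())

# Turn the observed range into quantization parameters
scale, zero_point = observer.calculate_qparams()
print("scale:", scale.item(), "zero point:", zero_point.item())
```

The observers inserted into the model later in this notebook work on the same principle: they watch the tensors that flow through during inference and keep the statistics needed to compute the scale and zero point.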

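## Create a model with observers

First, we add quant and dequant [stubs](https://pytorch.org/docs/stable/generated/torch.ao.quantization.QuantStub.html), which mark where the model's inputs get quantized and its outputs get dequantized. The quant stub also serves as the observer for the model's input.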

In [8]:
# DEFINE THE MODEL

class SimpleQuantClassifier(nn.Module):
    def __init__ (self):
        super(SimpleQuantClassifier, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.model = nn.Sequential(
            nn.Linear(28*28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        out = self.quant(x)
        out = self.model(out)
        out = self.dequant(out)
        return out

We then add observers to the intermediate layers too.

In [None]:
import torch.ao.quantization


quant_model = SimpleQuantClassifier()
quant_model.load_state_dict(model.state_dict()) # load pre-trained weights into the quant model
quant_model.eval()

quant_model.qconfig = torch.ao.quantization.default_qconfig
quant_model = torch.ao.quantization.prepare(quant_model) # insert observers
quant_model

## Run inference on the new model

This allows the observers to gather data.

In [None]:
import torch
quant_model.eval()
correct = 0
with torch.no_grad():
    for data, target in test_loader:
        # Flatten the image
        data = data.reshape(-1, 28*28)
        output = quant_model(data)
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

accuracy = 100. * correct / len(test_loader.dataset)
print(f'Test Accuracy: {accuracy:.2f}%')

Let's check our model again: the observers now carry the statistics they gathered. Good!

In [None]:
quant_model
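
If the printed module tree is hard to read, the observed ranges can also be listed programmatically. The sketch below does this on a tiny hypothetical model (`TinyQuantNet`, not the classifier from this notebook) calibrated on random data; the same loop over `named_modules()` should apply to `quant_model` after calibration, assuming its observers expose `min_val` and `max_val` as `MinMaxObserver` does.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import QuantStub, DeQuantStub, default_qconfig, prepare
from torch.ao.quantization.observer import ObserverBase

class TinyQuantNet(nn.Module):
    """Hypothetical toy model, used only to illustrate observer inspection."""
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.fc = nn.Linear(4, 2)
        self.dequant = DeQuantStub()

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))

torch.manual_seed(0)
net = TinyQuantNet().eval()
net.qconfig = default_qconfig
prepared = prepare(net)  # insert observers

# Calibration pass on random data (stand-in for real inputs)
with torch.no_grad():
    prepared(torch.randn(8, 4))

# Collect the range recorded by every observer in the model
observed = {}
for name, module in prepared.named_modules():
    if isinstance(module, ObserverBase) and hasattr(module, "min_val"):
        observed[name] = (module.min_val.item(), module.max_val.item())

for name, (lo, hi) in observed.items():
    print(f"{name}: min={lo:.4f}, max={hi:.4f}")
```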

## Quantize the model

We can now simply use the PyTorch API to turn the gathered data into quantization parameters!

We then visualize our weights: they are now INT8!

In [12]:
import torch.ao.quantization

quant_model = torch.ao.quantization.convert(quant_model)

In [None]:
print(torch.int_repr(quant_model.model[0].weight()))

We can also compare the original weights with the dequantized weights: a small error has been introduced.

In [None]:
print(model.model[0].weight)            # original weights
print(quant_model.model[0].weight())    # dequant weights
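
To put a number on that error, here is a standalone sketch that quantizes a random weight matrix with `torch.quantize_per_tensor` and measures the round-trip error. The random tensor is only a stand-in for trained weights, and the scale and zero point are derived from its min and max by hand, so the exact values will differ from what the converted model uses.

```python
import torch

torch.manual_seed(0)

# Stand-in for a trained f32 weight matrix (same shape as the first layer)
w = torch.randn(128, 28 * 28) * 0.05

# Per-tensor affine parameters from the min/max range, for signed int8
w_min, w_max = w.min().item(), w.max().item()
scale = (w_max - w_min) / 255
zero_point = int(round(-128 - w_min / scale))
zero_point = max(-128, min(127, zero_point))

# Quantize to int8, then dequantize back to f32
w_q = torch.quantize_per_tensor(w, scale=scale, zero_point=zero_point, dtype=torch.qint8)
w_dq = w_q.dequantize()

max_err = (w - w_dq).abs().max().item()
print("scale (one quantization step):", scale)
print("max absolute error:", max_err)
print("stored dtype:", torch.int_repr(w_q).dtype)
```

The maximum error should stay on the order of one quantization step, which is why accuracy usually drops only slightly.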

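# Let's compare again!

We will now compare the accuracy and the size of the two models. In theory, the size should be divided by about 4, since each f32 weight (4 bytes) becomes an int8 weight (1 byte).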
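
The factor of 4 can be checked with some quick arithmetic on the 784 -> 128 -> 10 classifier. This back-of-the-envelope sketch assumes that only the weights are stored as int8 while the biases stay in f32, and it ignores the serialization overhead and the stored scales and zero points, so the file sizes measured on disk will not match exactly.

```python
# Parameter counts for the two Linear layers
n_weights = 28 * 28 * 128 + 128 * 10  # weights of both layers
n_biases = 128 + 10                   # biases of both layers

# f32 model: every parameter takes 4 bytes
fp32_bytes = 4 * (n_weights + n_biases)

# Quantized model (assumption: int8 weights, f32 biases)
int8_bytes = 1 * n_weights + 4 * n_biases

ratio = fp32_bytes / int8_bytes
print(f"f32 parameters : {fp32_bytes / 1e3:.1f} KB")
print(f"int8 parameters: {int8_bytes / 1e3:.1f} KB")
print(f"ratio          : {ratio:.2f}x")
```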

In [None]:
import torch
quant_model.eval()
correct = 0
with torch.no_grad():
    for data, target in test_loader:
        # Flatten the image
        data = data.reshape(-1, 28*28)
        output = quant_model(data)
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

accuracy = 100. * correct / len(test_loader.dataset)
print(f'Test Accuracy (ORIGINAL): {old_accuracy:.2f}%')
print(f'Test Accuracy (Quantized): {accuracy:.2f}%')

In [None]:
print("size of the model before PTQ : ", old_size, "KB")
print("size of the model after PTQ : ", get_size(quant_model), "KB")