# Post-Training Quantization

Let's start by importing the libraries we need for this tutorial.

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torch.optim as optim
from torchao.quantization import quantize_, Int8WeightOnlyConfig
from torchao.utils import get_model_size_in_bytes
import argparse
import os
import copy

ModuleNotFoundError: No module named 'torch.hub'

If you're running into issues importing these libraries, check which Python environment the Jupyter kernel is using!

In [None]:
import sys
print(sys.executable)
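As a follow-up check, the sketch below asks the current interpreter whether each package is visible to it, without actually importing anything. It relies only on the standard library's `importlib.util.find_spec`; the names `required` and `availability` are chosen here for illustration and are not part of the original notebook.

```python
import importlib.util

# Top-level packages this tutorial imports.
required = ["torch", "torchvision", "torchao"]

# find_spec returns None when a package is not visible to this interpreter.
availability = {name: importlib.util.find_spec(name) is not None for name in required}

for name, found in availability.items():
    print(f"{name}: {'found' if found else 'MISSING'}")
```

If a package shows up as missing here but is installed elsewhere, the kernel is most likely pointing at a different environment than the one you installed into.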

### Formatting the Training Dataset

We will start by normalizing the MNIST images with the dataset's precomputed mean (0.1307) and standard deviation (0.3081).

In [None]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

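# root="./data" is an assumed download location; change it to suit your setup.
train_dataset = datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
test_dataset = datasets.MNIST(
    root="./data", train=False, download=True, transform=transform
)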
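To make the normalization step concrete, here is a minimal plain-Python sketch of the per-pixel formula that `transforms.Normalize` applies, `(x - mean) / std`. The helper `normalize` and the variables `black` and `white` are illustrative names, not part of the notebook.

```python
MEAN, STD = 0.1307, 0.3081  # MNIST statistics used in the transform above

def normalize(pixel):
    """Apply the same per-pixel formula as transforms.Normalize: (x - mean) / std."""
    return (pixel - MEAN) / STD

black = normalize(0.0)  # darkest pixel after ToTensor scales values into [0, 1]
white = normalize(1.0)  # brightest pixel
print(f"black -> {black:.4f}, white -> {white:.4f}")
```

Because most MNIST pixels are background, the mean is close to 0, so background pixels end up slightly negative and ink pixels end up well above zero.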

Then we'll wrap both datasets in data loaders so we can iterate over them in batches later on.

In [None]:
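batch_size = 64  # must exist before the loaders are built; set again with the other hyperparameters below
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)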

### Neural Network Class

Let's define our neural network class now! We'll keep it super simple with two fully connected layers. Each image is 28x28 pixels, so we flatten it into an input vector of 28*28 = 784 values. We want to predict which digit the image shows, and since there are 10 possible digits, the output has dimension 10.

In [None]:
class Network(torch.nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.fc1 = torch.nn.Linear(28*28, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        return x
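It helps to know how many parameters this network holds, since that number drives the model size measured later. The sketch below counts them by hand from the layer shapes in the class above; `linear_params` is a hypothetical helper written for this illustration.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of one fully connected layer."""
    return in_features * out_features + out_features

fc1_params = linear_params(28 * 28, 128)  # 784 inputs -> 128 hidden units
fc2_params = linear_params(128, 10)       # 128 hidden units -> 10 digit scores
total_params = fc1_params + fc2_params
print(fc1_params, fc2_params, total_params)
```

Nearly all of the parameters sit in the first layer's weight matrix, which is why quantizing the weights alone already accounts for most of the size change.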

### Training the Model

Now that we have defined our model, let's train it on the dataset and look at the type and format of the weights. (We expect 32-bit floating point numbers.)

In [None]:
batch_size = 64
learning_rate = 1e-4
epochs = 200

In [None]:
model_fp32 = Network()
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model_fp32.parameters(), lr=learning_rate)

In [None]:
for epoch in range(epochs):
    model_fp32.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model_fp32(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

    if epoch % 10 == 0 or epoch == epochs - 1:
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")

In [None]:
print(model_fp32)

Now, let's quantize the model to INT8 and see what happens to the weights.

In [None]:
model_int8 = copy.deepcopy(model_fp32)
quantize_(model_int8, Int8WeightOnlyConfig())
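To build some intuition for what weight-only INT8 quantization does, here is a simplified plain-Python sketch of symmetric quantization applied to a single row of weights. It is an illustration under stated assumptions, not torchao's implementation: the library's actual granularity, rounding, and storage format may differ, and `quantize_row` and `dequantize_row` are hypothetical helpers.

```python
def quantize_row(row):
    """Symmetric int8 quantization: map the largest |weight| in the row to 127."""
    max_abs = max(abs(w) for w in row)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    # Round to the nearest integer step and clamp into the int8 range.
    q = [max(-128, min(127, round(w / scale))) for w in row]
    return q, scale

def dequantize_row(q, scale):
    """Recover approximate float weights from the integers and the scale."""
    return [v * scale for v in q]

weights = [0.50, -0.25, 0.10, -0.02]  # made-up example values
q, scale = quantize_row(weights)
restored = dequantize_row(q, scale)
max_error = max(abs(a - b) for a, b in zip(weights, restored))
print(q, scale, max_error)
```

Each weight is stored as a one-byte integer plus a shared float scale, and the round trip loses at most half a quantization step per weight, which is why accuracy usually changes only a little.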

In [None]:
print(model_int8)

### Evaluating the Model

So far, we have trained our model using 32-bit floating point weights/biases and created a quantized version using INT8. Let's see what the accuracy change is between the two and how much the model size has changed as well!

In [None]:
model_fp32.eval()

fp32_correct = 0
fp32_total = 0

with torch.no_grad():
    for data, target in test_loader:
        output = model_fp32(data)
        _, predicted = torch.max(output.data, 1)
        fp32_total += target.size(0)
        fp32_correct += (predicted == target).sum().item()

In [None]:
print(f"Test Accuracy FP32: {100 * fp32_correct / fp32_total:.2f}%")
print(f"FP32 SIZE: {get_model_size_in_bytes(model_fp32) / 1e6:.2f} MB")

In [None]:
model_int8.eval()

int8_correct = 0
int8_total = 0

with torch.no_grad():
    for data, target in test_loader:
        output = model_int8(data)
        _, predicted = torch.max(output.data, 1)
        int8_total += target.size(0)
        int8_correct += (predicted == target).sum().item()

In [None]:
print(f"Test Accuracy INT8: {100 * int8_correct / int8_total:.2f}%")
print(f"INT8 SIZE: {get_model_size_in_bytes(model_int8) / 1e6:.2f} MB")
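The two evaluation loops above differ only in the model they run, so they could be folded into one helper. The sketch below is one possible way to do that; `evaluate` is a hypothetical function, not part of the original notebook, and the toy model and random batches exist only so the snippet runs on its own.

```python
import torch

def evaluate(model, loader):
    """Return top-1 accuracy (between 0 and 1) of `model` over the batches in `loader`."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in loader:
            predicted = model(data).argmax(dim=1)  # index of the highest score per sample
            correct += (predicted == target).sum().item()
            total += target.size(0)
    return correct / total

# Smoke test on random data: a list of (data, target) tuples stands in for a DataLoader.
torch.manual_seed(0)
toy_model = torch.nn.Linear(4, 3)
toy_batches = [(torch.randn(8, 4), torch.randint(0, 3, (8,))) for _ in range(2)]
toy_accuracy = evaluate(toy_model, toy_batches)
print(f"toy accuracy: {toy_accuracy:.2f}")
```

In the notebook, the same helper could be called as `evaluate(model_fp32, test_loader)` and `evaluate(model_int8, test_loader)`.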

### Results

From the comparison, we see that the accuracy drops slightly while the model size shrinks considerably. This is because the weights went from FP32 (4 bytes each) to INT8 (1 byte each), a reduction of roughly 4x in bytes. Our model sizes reflect this.
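The 4x figure can be sanity-checked with a little arithmetic on the layer shapes. The estimate below assumes that only the weight matrices are stored as one byte per value, that biases stay in FP32, and that the storage for quantization scales is small enough to ignore; the sizes reported by `get_model_size_in_bytes` may therefore differ somewhat.

```python
weight_count = 28 * 28 * 128 + 128 * 10  # entries in the two weight matrices
bias_count = 128 + 10                    # entries in the two bias vectors

fp32_bytes = (weight_count + bias_count) * 4
# Assumption: weights take 1 byte each, biases stay in FP32, scale storage is ignored.
int8_bytes_estimate = weight_count * 1 + bias_count * 4

ratio = fp32_bytes / int8_bytes_estimate
print(f"FP32: {fp32_bytes / 1e6:.2f} MB, INT8 estimate: {int8_bytes_estimate / 1e6:.2f} MB, ratio: {ratio:.2f}x")
```

The ratio lands just under 4x because the biases are not compressed in this estimate.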