In [1]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path
import os
from torchvision.datasets import MNIST
from torch.utils.data import random_split
from torch.utils.data import DataLoader
import torch.nn.functional as F


import torch.ao.quantization as quant

In [2]:
print(torch.__version__)
torch.backends.quantized.engine = 'fbgemm'
print(torch.backends.quantized.supported_engines)

2.6.0+cu126
['none', 'onednn', 'x86', 'fbgemm']


In [3]:
torch.manual_seed(42)
batch_size = 32
learning_rate = 1e-3

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
mnist_trainset = MNIST(root = 'data/', train = True, transform = transform)
mnist_testset = MNIST(root = 'data/', train = False, transform = transform)

train_loader = DataLoader(mnist_trainset, batch_size = batch_size, shuffle = True)
test_loader = DataLoader(mnist_testset, batch_size = batch_size, shuffle = True)

In [4]:
device = 'cpu'

In [5]:
import torch.ao.quantization


class TrainingMNISTModel(nn.Module):
    def __init__(self, neuron_1 = 64, neuron_2 = 64):
        super(TrainingMNISTModel, self).__init__()
        self.quant = quant.QuantStub()
        self.linear1 = nn.Linear(28 * 28, neuron_1)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(neuron_1, neuron_2)
        self.relu2 = nn.ReLU()
        self.linear3 = nn.Linear(neuron_2, 10)
        self.dequant = quant.DeQuantStub()

    def forward(self, img):
        x = img.reshape(-1, 28 * 28)
        x = self.quant(x)
        x = self.relu1(self.linear1(x))
        x = self.relu2(self.linear2(x))
        x = self.linear3(x)
        x = self.dequant(x)
        return x
    
    def fuse_model(self):
        # Fuse layers for quantization
        torch.ao.quantization.fuse_modules(self, [['linear1', 'relu1'], ['linear2', 'relu2']], inplace = True)
    


In [6]:
qat_model = TrainingMNISTModel().to(device)

optimizer = torch.optim.Adam(qat_model.parameters(), lr = learning_rate, weight_decay= 1e-4) # L2-regularization

In [7]:
def train(train_loader, model, epochs = 5, total_iterations_limit = None):
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)

    total_iterations = 0

    for epoch in range(epochs):
        model.train()

        loss_sum = 0
        num_iterations = 0
        
        data_iterator = tqdm(train_loader, desc = f'Epoch {epoch + 1}') # desc = description
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit

        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)

            output = model(x)
            loss = loss_fn(output, y)
            loss_sum += loss
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss = avg_loss) # post_fix for tqdm

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return
            
def print_size_of_model(model):
    torch.save(model.state_dict(), 'temp_delme.p')
    print('Size (KB)', os.path.getsize('temp_delme.p') / 1e3)
    os.remove('temp_delme.p')

In [8]:
MODEL_FILENAME = 'mnistmodel_qat.pt'

if Path(MODEL_FILENAME).exists():
    qat_model.load_state_dict(torch.load(MODEL_FILENAME))
    print('Loaded model from disk')
else:
    train(train_loader, qat_model, epochs = 5)
    # Save the model to disk
    torch.save(qat_model.state_dict(), MODEL_FILENAME)

Loaded model from disk


In [9]:
def test(model, total_iterations = None):
    model.eval()
    model.to('cpu')

    correct = 0
    total = 0
    iterations = 0

    with torch.inference_mode():
        for data in tqdm(test_loader, desc = 'Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)

            output = model(x)

            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct += 1
                total += 1
            iterations += 1

            if total_iterations is not None and iterations >= total_iterations:
                break

    print(f'Accuracy: {round(correct / total, 3)}')

In [10]:
# Print the weights matrix of the model before quantization
print('Weights before QAT')
print(qat_model.linear1.weight)
print(qat_model.linear1.weight.dtype)

print('Size of the model before QAT')
print_size_of_model(qat_model)

print(f'Accuracy of the model before QAT: ')
test(qat_model)

Weights before QAT
Parameter containing:
tensor([[ 0.0394,  0.0418,  0.0037,  ..., -0.0020,  0.0214,  0.0256],
        [-0.0038, -0.0204,  0.0337,  ...,  0.0044,  0.0149,  0.0265],
        [ 0.0006,  0.0030,  0.0059,  ..., -0.0113, -0.0175,  0.0072],
        ...,
        [-0.0320, -0.0422, -0.0065,  ...,  0.0089, -0.0336,  0.0155],
        [ 0.0100, -0.0117,  0.0249,  ...,  0.0118,  0.0005, -0.0186],
        [ 0.0407, -0.0163,  0.0370,  ...,  0.0396,  0.0056,  0.0238]],
       requires_grad=True)
torch.float32
Size of the model before QAT
Size (KB) 222.886
Accuracy of the model before QAT: 


Testing:   0%|          | 0/313 [00:00<?, ?it/s]

Testing: 100%|██████████| 313/313 [00:02<00:00, 138.86it/s]

Accuracy: 0.971





In [11]:
print(type(qat_model.linear1))

<class 'torch.nn.modules.linear.Linear'>


### Finished Pretrain, Start Quantization
Start to do the quantization

In [12]:
qat_model.train()
qat_model.fuse_model()
qat_model.qconfig = quant.get_default_qat_qconfig('fbgemm')
print(qat_model.qconfig)

QConfig(activation=functools.partial(<class 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'>, observer=<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>, quant_min=0, quant_max=255, reduce_range=True){}, weight=functools.partial(<class 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'>, observer=<class 'torch.ao.quantization.observer.MovingAveragePerChannelMinMaxObserver'>, quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_channel_symmetric){})


In [13]:
# prepare QAT
torch.quantization.prepare_qat(qat_model, inplace = True)

# Fine-tune the model with QAT
print("Fine-tuning the model with QAT...")
train(train_loader, qat_model, epochs = 5)



Fine-tuning the model with QAT...


Epoch 1:   0%|          | 0/1875 [00:00<?, ?it/s, loss=tensor(0.0041, grad_fn=<DivBackward0>)]

Epoch 1: 100%|██████████| 1875/1875 [00:25<00:00, 72.65it/s, loss=tensor(0.0564, grad_fn=<DivBackward0>)]
Epoch 2: 100%|██████████| 1875/1875 [00:25<00:00, 72.25it/s, loss=tensor(0.0476, grad_fn=<DivBackward0>)]
Epoch 3: 100%|██████████| 1875/1875 [00:26<00:00, 71.09it/s, loss=tensor(0.0423, grad_fn=<DivBackward0>)]
Epoch 4: 100%|██████████| 1875/1875 [00:30<00:00, 61.85it/s, loss=tensor(0.0374, grad_fn=<DivBackward0>)]
Epoch 5: 100%|██████████| 1875/1875 [00:35<00:00, 52.83it/s, loss=tensor(0.0352, grad_fn=<DivBackward0>)]


In [14]:
# Convert the model to a quantized version
qat_model.eval()  # Switch to evaluation mode before conversion
quant.convert(qat_model, inplace = True)
print("Model converted to quantized version.")

print('\n Inverted Residual Block: After fusion and QAT, note fused modules: \n\n',qat_model.linear1)

Model converted to quantized version.

 Inverted Residual Block: After fusion and QAT, note fused modules: 

 QuantizedLinearReLU(in_features=784, out_features=64, scale=0.25556251406669617, zero_point=0, qscheme=torch.per_channel_affine)


In [15]:
print("Size of model after QAT")
print_size_of_model(qat_model)

test(qat_model)

Size of model after QAT
Size (KB) 64.354


Testing: 100%|██████████| 313/313 [00:02<00:00, 136.32it/s]

Accuracy: 0.974





In [16]:
print(torch.dequantize(qat_model.linear1.weight()))

tensor([[ 0.0527,  0.0527,  0.0158,  ...,  0.0105,  0.0316,  0.0369],
        [-0.0036, -0.0202,  0.0337,  ...,  0.0047,  0.0150,  0.0265],
        [ 0.0174,  0.0174,  0.0209,  ...,  0.0035, -0.0035,  0.0244],
        ...,
        [-0.0411, -0.0513, -0.0154,  ...,  0.0000, -0.0462,  0.0051],
        [ 0.0175, -0.0035,  0.0350,  ...,  0.0210,  0.0105, -0.0105],
        [ 0.0434, -0.0109,  0.0434,  ...,  0.0434,  0.0109,  0.0271]])


In [17]:
print("----------------------------------------------------------------------------")
print(qat_model.linear1.scale)
print(qat_model.linear1.zero_point)

----------------------------------------------------------------------------
0.25556251406669617
0


In [18]:
weight_q = qat_model.linear1.weight()
weight_int8 = weight_q.int_repr()
print(weight_int8)

tensor([[ 10,  10,   3,  ...,   2,   6,   7],
        [ -7, -39,  65,  ...,   9,  29,  51],
        [  5,   5,   6,  ...,   1,  -1,   7],
        ...,
        [ -8, -10,  -3,  ...,   0,  -9,   1],
        [  5,  -1,  10,  ...,   6,   3,  -3],
        [  8,  -2,   8,  ...,   8,   2,   5]], dtype=torch.int8)
