In [None]:
import warnings
warnings.filterwarnings("ignore")

import os

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
        

class Net(nn.Module):
    """Small CNN classifier: two conv/BN/ReLU/max-pool stages plus a two-layer head.

    The first convolution takes a single input channel and the last linear
    layer produces 10 values per sample. ``fc1`` expects ``7 * 7 * 64``
    flattened features, i.e. a 7x7 map with 64 channels after the two 2x2
    poolings (which is what a 28x28 input yields).

    Attribute names and their creation order are kept stable because the
    state-dict keys and the fusion lists used elsewhere refer to them by name.
    """

    def __init__(self):
        super().__init__()

        # Stage 1: 1 -> 32 channels. A 3x3 kernel with padding 1 keeps H and W.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32,
                               kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=32)
        self.relu1 = nn.ReLU(inplace=True)

        # Stage 2: 32 -> 64 channels, same spatial-preserving geometry.
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64,
                               kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=64)
        self.relu2 = nn.ReLU(inplace=True)

        # A single parameter-free 2x2 pooling module, applied after each stage.
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Classifier head: flattened features -> 512 -> 10.
        self.fc1 = nn.Linear(in_features=7 * 7 * 64, out_features=512)
        self.relu3 = nn.ReLU(inplace=True)

        self.fc2 = nn.Linear(in_features=512, out_features=10)

    def forward(self, x):
        """Run the network on ``x`` and return the output of ``fc2``."""
        # Each stage: convolution -> batch norm -> ReLU -> 2x2 max pool.
        stage1 = self.maxpool(self.relu1(self.bn1(self.conv1(x))))
        stage2 = self.maxpool(self.relu2(self.bn2(self.conv2(stage1))))

        # Collapse everything except the batch dimension before the linear head.
        flat = stage2.reshape(stage2.shape[0], -1)
        hidden = self.relu3(self.fc1(flat))

        return self.fc2(hidden)


In [None]:

# Preprocessing shared by both splits: convert each image to a tensor, then
# normalise its single channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST training split (downloaded into ./data when missing), served in
# shuffled batches of 64.
trainset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
trainloader = DataLoader(
    trainset, batch_size=64, shuffle=True, num_workers=16, pin_memory=True
)

# MNIST test split, same batch size, served in a fixed order.
testset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)
testloader = DataLoader(
    testset, batch_size=64, shuffle=False, num_workers=16, pin_memory=True
)



In [None]:
# Load the saved checkpoint. map_location="cpu" lets a checkpoint written from
# a GPU be read on a CPU-only machine; every later step in this notebook runs
# on the CPU (cuda=False / .to("cpu")).
state = torch.load("outputs/best_model.pth", map_location="cpu")

# Float baseline network. It is only used for inference/comparison below, so
# switch it to eval mode: BatchNorm then uses its stored running statistics
# instead of per-batch statistics.
model = Net()
model.eval()

# Restore the trained weights stored under the 'model_state_dict' key.
model.load_state_dict(state['model_state_dict'])


Quantization-aware training (QAT) in PyTorch.

In [None]:


# Network that will be fine-tuned with quantization-aware training.
# It starts from the trained float weights loaded into `model` rather than
# from a random initialisation, so QAT adapts the existing model and the final
# quantized-vs-float comparison is between the same underlying network.
quant_network = Net()
quant_network.load_state_dict(model.state_dict())
# print_size_of_model(quant_network , "Before Qat Traning size of mode: ")
# calculate_accuracy(quant_network,testloader, print_c=True)



In [None]:
# Groups of sub-module names to merge: each convolution together with the
# batch-norm and ReLU that directly follow it in Net.forward.
fused_layers = [
    ['conv1', 'bn1', 'relu1'],
    ['conv2', 'bn2', 'relu2'],
]

# Fusion is carried out with the network in evaluation mode.
quant_network.eval()

# Fuse in place; the call returns the (now fused) module.
torch.quantization.fuse_modules(
    quant_network,
    modules_to_fuse=fused_layers,
    inplace=True,
)



In [None]:
class QuantizedModel(torch.nn.Module):
    """Wrap a float model between a ``QuantStub`` and a ``DeQuantStub``.

    Args:
        model: the module to wrap; stored as ``model_fp32``.

    The forward pass sends the input through ``quant``, then the wrapped
    model, then ``dequant``, and returns the result.
    """

    def __init__(self, model):
        super().__init__()
        # Wrapped network first, then the entry/exit stubs.
        self.model_fp32 = model
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        """Apply quant stub -> wrapped model -> dequant stub to ``x``."""
        stub_in = self.quant(x)
        model_out = self.model_fp32(stub_in)
        return self.dequant(model_out)

In [None]:
# Wrap the network so that torch.quantization.QuantStub() is applied to its
# input and torch.quantization.DeQuantStub() to its output.
# The isinstance guard keeps this cell idempotent: re-running it would
# otherwise wrap an already-wrapped model a second time and nest the stubs.
if not isinstance(quant_network, QuantizedModel):
    quant_network = QuantizedModel(quant_network)

In [None]:
# Select the quantization configuration for the "fbgemm" backend.
# Quantization-aware training needs the QAT variant of the default config:
# it supplies fake-quantize modules, so quantization effects are simulated
# during training. The plain get_default_qconfig() is the post-training
# configuration and only supplies observers.
quantization_config = torch.quantization.get_default_qat_qconfig("fbgemm")

# Attach the configuration at the top-level module.
quant_network.qconfig = quantization_config

# Print quantization configurations
print(quant_network.qconfig)

In [None]:
# Prepare for QAT.
# The inner network was switched to eval mode for fusion; prepare_qat is meant
# to be run on a model in training mode, so put the whole wrapped model
# (wrapper and sub-modules) back into train mode first.
quant_network.train()
torch.quantization.prepare_qat(quant_network, inplace=True)

In [None]:
from train_helpers import ClassifierTrainer

# Cross-entropy loss, and plain SGD (lr=0.01) over all parameters of the
# prepared network.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=quant_network.parameters(), lr=0.01)

# ClassifierTrainer is a project helper; the meaning of its arguments is taken
# from their names (presumably 4 epochs on CPU, validating on the test
# loader) - confirm against train_helpers.
trainer = ClassifierTrainer(
    model=quant_network,
    optimizer=optimizer,
    criterion=criterion,
    train_loader=trainloader,
    val_loader=testloader,
    cuda=False,
    num_epochs=4,
)

# Run the training loop with save_model=False.
trainer.train(save_model=False)




In [None]:
# Conversion to the quantized model is done on the CPU.
quant_network.to("cpu")

# Put the model in eval mode before converting, as the quantization workflow
# requires; the mode the trainer leaves it in is not visible from here.
quant_network.eval()

# Convert in place: quantized_model and quant_network refer to the same object.
quantized_model = torch.quantization.convert(quant_network, inplace=True)

In [None]:
from utils import ModelCompare

# Comparison helper (project-local) set up with the converted network as
# model 1 and the float network loaded from the checkpoint as model 2.
model_compare = ModelCompare(
    model1=quantized_model, model1_info="Quantized Model",
    model2=model, model2_info="UnQuantized Model",
    cuda=False,
)

In [None]:
# Three comparison reports, each preceded by a 50-character ruler:
# model size, accuracy on the test loader, and inference time (N=10).
# NOTE(review): the keyword is spelled `dataloder` here; presumably that
# matches ModelCompare's signature - confirm against utils.
print("=" * 50)
model_compare.compare_size()

print("=" * 50)
model_compare.compare_accuracy(dataloder=testloader)

print("=" * 50)
model_compare.compare_inference_time(N=10, dataloder=testloader)