In [None]:
import warnings
warnings.filterwarnings("ignore")

import os

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms


In [None]:

# ToTensor maps pixels to [0, 1]; Normalize((0.5,), (0.5,)) then maps them to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

BATCH_SIZE = 64
# Never request more worker processes than this machine has CPUs.
NUM_WORKERS = min(16, os.cpu_count() or 1)
# Pinned host memory only speeds up host->GPU copies; this notebook trains,
# calibrates and evaluates on the CPU, so only pin when a GPU is present.
PIN_MEMORY = torch.cuda.is_available()

trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=NUM_WORKERS,
                                          pin_memory=PIN_MEMORY)

testset = torchvision.datasets.MNIST(root='./data', train=False,
                                     download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=NUM_WORKERS,
                                         pin_memory=PIN_MEMORY)

In [None]:
class Net(nn.Module):
    """Small CNN classifier producing 10 class logits for single-channel images.

    Two conv -> batch-norm -> ReLU -> 2x2 max-pool stages are followed by two
    fully connected layers. ``fc1`` expects 7*7*64 features, which matches
    28x28 inputs (MNIST) after two 2x2 pools.

    Attribute names (conv1/bn1/relu1, conv2/bn2/relu2, fc1/relu3, fc2) are
    referenced by ``fuse_modules`` lists and by saved state dicts, so they
    must stay stable.
    """

    def __init__(self):
        super().__init__()

        # Stage 1: 1 -> 32 channels; padding=1 preserves spatial size for 3x3 kernels.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU(inplace=True)

        # Stage 2: 32 -> 64 channels.
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)

        # Parameter-free, so one instance is shared by both stages.
        self.maxpool = nn.MaxPool2d(2, 2)

        # Classifier head.
        self.fc1 = nn.Linear(7 * 7 * 64, 512)
        self.relu3 = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return unnormalised class scores of shape (batch, 10)."""
        x = self.maxpool(self.relu1(self.bn1(self.conv1(x))))
        x = self.maxpool(self.relu2(self.bn2(self.conv2(x))))
        # Flatten everything except the batch dimension.
        x = x.reshape(x.size(0), -1)
        return self.fc2(self.relu3(self.fc1(x)))


In [None]:
class QuantizedModel(torch.nn.Module):
    """Wrap a float model between QuantStub and DeQuantStub for eager-mode quantization.

    Before ``prepare``/``convert`` both stubs pass tensors through unchanged.
    After ``convert`` they quantize the float input and dequantize the output,
    so callers keep feeding and receiving float tensors.

    Args:
        model: the (optionally fused) float model to quantize.
    """

    def __init__(self, model):
        super(QuantizedModel, self).__init__()
        # Registration order (model_fp32, quant, dequant) is kept stable so
        # module traversal and repr stay the same.
        self.model_fp32 = model
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        """float input -> quantize -> wrapped model -> dequantize -> float output."""
        return self.dequant(self.model_fp32(self.quant(x)))


In [None]:
from train_helpers import ClassifierTrainer, save_plots

# Seed torch's global RNG so weight initialisation and DataLoader shuffling
# (which draws its seeds from that RNG) are reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Float (unquantized) baseline model, trained on the CPU.
unqant_model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(unqant_model.parameters(), lr=0.01)

NUM_EPOCHS = 4
# NOTE(review): the next cell loads "outputs/best_model.pth"; presumably
# ClassifierTrainer writes it during train() — confirm in train_helpers.
trainer = ClassifierTrainer(
    model=unqant_model,
    optimizer=optimizer,
    criterion=criterion,
    train_loader=trainloader,
    val_loader=testloader,
    num_epochs=NUM_EPOCHS,
    cuda=False,
)
trainer.train()

save_plots(
    train_acc=trainer.train_accs,
    train_loss=trainer.train_losses,
    valid_acc=trainer.val_accs,
    valid_loss=trainer.val_losses,
)


In [None]:
# Reload the saved checkpoint into a fresh float model.
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# machine; everything below (quantization, comparison) runs on the CPU.
# NOTE: torch.load unpickles the file — only load checkpoints you produced.
state = torch.load("outputs/best_model.pth", map_location="cpu")
unqant_model = Net()

# The checkpoint is a dict; the weights live under 'model_state_dict'.
unqant_model.load_state_dict(state['model_state_dict'])

In [None]:
unqant_model.train(False)  # same as .eval(): freezes BatchNorm statistics for inference

In [None]:
import copy

# Independent deep copy: fusion mutates modules in place, and the original
# float model is still needed later as the comparison baseline.
source_model = unqant_model
unqant_model_copy = copy.deepcopy(source_model)

In [None]:
unqant_model_copy.train(mode=False)  # Conv+BN fusion below requires eval mode

In [None]:


# Fuse each Conv+BN+ReLU group, and the Linear+ReLU pair, into a single module
# so each group becomes one quantized op after convert(). Conv/BN fusion folds
# the BN statistics into the conv weights and needs the model in eval mode.
fused_layers = [
    ['conv1', 'bn1', 'relu1'],
    ['conv2', 'bn2', 'relu2'],
    ['fc1', 'relu3'],
]
fused_model = torch.quantization.fuse_modules(unqant_model_copy, fused_layers, inplace=True)

# Wrap with QuantStub/DeQuantStub so the model accepts and returns float tensors.
# This wrapper does NOT insert observers; torch.quantization.prepare does that.
quantized_model = QuantizedModel(model=fused_model)



In [None]:
# The fused model was already wrapped in QuantizedModel (QuantStub -> model -> DeQuantStub)
# in the fusion cell. Wrapping does not insert observers; that happens later in
# torch.quantization.prepare. Leftover line kept for reference:
# quantized_model = QuantizedModel(model=fused_model)



In [None]:
# Select quantization schemes from
# https://pytorch.org/docs/stable/quantization-support.html
# fbgemm targets x86 CPUs and qnnpack targets ARM, so pick one this machine supports.
quant_backend = "fbgemm" if "fbgemm" in torch.backends.quantized.supported_engines else "qnnpack"
# Keep the runtime quantized engine consistent with the qconfig's backend.
torch.backends.quantized.engine = quant_backend

quantization_config = torch.quantization.get_default_qconfig(quant_backend)
quantized_model.qconfig = quantization_config

# Print quantization configurations
print(quantized_model.qconfig)

In [None]:
# Insert observers that record activation ranges during calibration.
torch.quantization.prepare(model=quantized_model, inplace=True)

In [None]:
def calibrate_model(model, loader, device=torch.device("cpu"), num_batches=None):
    """Feed calibration data through a prepared model so its observers record ranges.

    Only the forward passes matter: observers update their statistics as data
    flows through, and the outputs are discarded.

    Args:
        model: a model after ``torch.quantization.prepare`` (any nn.Module works).
        loader: an iterable of ``(inputs, labels)`` batches; labels are ignored.
        device: device (or device string) to run calibration on.
        num_batches: if given, stop after this many batches; ``None`` uses all.
    """
    from itertools import islice

    model.to(device)
    model.eval()

    # No gradients are needed for calibration; disabling autograd avoids
    # building graphs for every batch.
    with torch.no_grad():
        for inputs, _labels in islice(loader, num_batches):
            model(inputs.to(device))

# Calibrate on the training set so activation ranges come from training data, not the held-out test set.
calibrate_model(quantized_model, trainloader, device=torch.device("cpu"))

In [None]:
# Swap observed float modules for their quantized implementations.
quantized_model = torch.quantization.convert(module=quantized_model, inplace=True)

In [None]:
quantized_model.train(False)  # equivalent to .eval()

In [None]:
from utils import ModelCompare

# Compare the converted (quantized) model against the float baseline.
model_compare = ModelCompare(
    model1=quantized_model,
    model1_info="Quantized Model",
    model2=unqant_model,
    model2_info="Unquantized Model",
    cuda=False,
)

In [None]:
# Size, accuracy and latency comparison; `dataloder` is ModelCompare's own parameter name.
divider = "=" * 50
print(divider)
model_compare.compare_size()
print(divider)
model_compare.compare_accuracy(dataloder=testloader)
print(divider)
model_compare.compare_inference_time(N=2, dataloder=testloader)

In [None]:
# Script the quantized model, save it, and reload it to check the round trip.
jit_path = "JIT_MODEL.jit"
scripted_quantized_model = torch.jit.script(quantized_model)
torch.jit.save(scripted_quantized_model, jit_path)
module = torch.jit.load(jit_path)



In [None]:
from utils import ModelCompare

# `module` is the TorchScript export of the *quantized* model (reloaded from
# JIT_MODEL.jit), so both sides here are quantized: this measures scripted vs eager.
model_compare = ModelCompare(
    model1=module,
    model1_info="Quantized Model (TorchScript)",
    model2=quantized_model,
    model2_info="Quantized Model (eager)",
    cuda=False,
)

print("=" * 50)
model_compare.compare_size()
print("=" * 50)
model_compare.compare_accuracy(dataloder=testloader)
print("=" * 50)
model_compare.compare_inference_time(N=2, dataloder=testloader)