# **Lab 3: Quantization**

# Setup

In [None]:
# install the newest version of torch, torchvision, and timm
# NOTE(review): these installs are unpinned, so "newest" changes over time.
# Later cells rely on torch._export.capture_pre_autograd_graph and on torch.load
# unpickling a whole nn.Module; both behaviours may differ across torch
# releases. TODO: pin the torch/torchvision versions this lab was validated with.
# `%pip` would also target the running kernel's environment more reliably than
# `!pip3`.
!pip3 uninstall --yes torch torchaudio torchvision torchtext torchdata timm
!pip3 install torch torchaudio torchvision torchtext torchdata timm

In [1]:
import copy
import math
import random
from collections import OrderedDict, defaultdict

from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap
import numpy as np
from tqdm.auto import tqdm

import torch
from torch import nn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Seed every RNG the notebook uses (Python, NumPy, PyTorch) so runs are reproducible.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
# The trailing ';' hides the Generator object torch.manual_seed returns.
torch.manual_seed(SEED);

<torch._C.Generator at 0x7fb877f94130>

Test Functions **(DO NOT MODIFY!!)**

# Part 2: Quantize MobileNetV2 and Export

The steps below show how to quantize and convert the model. For more details, refer to [Quantization-Aware Training](https://pytorch.org/tutorials/prototype/pt2e_quant_qat.html) and [Post-Training Quantization](https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html). You may need to run this on your own machine.

***The code blocks below do not need to be executed when you submit this file.***

$$
\text{Score} = 10 \times \operatorname{Step}(\text{Accuracy} - 0.88) + 20 \times \dfrac{\text{Accuracy} - 0.88}{0.96 - 0.88}
$$

1. Load the **mobilenet_v2** checkpoint that reaches 96.3% accuracy on CIFAR-10. (The link to the model is given in the Lab 3 spec.)

In [3]:
from torch._export import capture_pre_autograd_graph

from torch.ao.quantization.quantize_pt2e import (
  prepare_pt2e,
  convert_pt2e,
  prepare_qat_pt2e
)

from torch.ao.quantization.quantizer.xnnpack_quantizer import (
  XNNPACKQuantizer,
  get_symmetric_quantization_config,
)

In [4]:
import torch
from torchvision.models import mobilenet_v2
# from torchvision.models.quantization import mobilenet_v2

# The checkpoint holds a fully pickled nn.Module (it is used directly as a model
# below), not just a state_dict, so it has to be unpickled with
# weights_only=False; recent torch releases default to weights_only=True and
# refuse such files. SECURITY: unpickling can execute arbitrary code, so only
# load checkpoints from a trusted source.
# map_location="cpu" lets a GPU-saved checkpoint load on CPU-only machines;
# later cells move the model to `device` explicitly.
model = torch.load('./mobilenetv2_0.963.pth', map_location="cpu", weights_only=False)

2. Quantize the model using XNNPACKQuantizer. You can choose either Post-Training Quantization or Quantization-Aware Training.

In [5]:
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision import transforms
def prepare_data(batch_size, root='./data', drop_last_test=True):
    """Build CIFAR-10 train/test DataLoaders preprocessed for MobileNetV2.

    Args:
        batch_size: number of images per batch for both loaders.
        root: directory where CIFAR-10 is stored (downloaded if missing).
        drop_last_test: drop the final partial batch of the test loader.
            Defaults to True because the quantized model is later exported
            with a full ``batch_size`` example batch taken from this loader,
            which presumably fixes its batch dimension (TODO confirm), so
            every test batch must hold exactly ``batch_size`` images. Note
            that this leaves ``len(test_set) % batch_size`` test images out
            of the reported accuracy; pass False for a full-test-set
            evaluation of eager models.

    Returns:
        ``(train_loader, test_loader)``. The training loader is shuffled and
        the test loader is not.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # Resize images to match MobileNet input size
        transforms.ToTensor(),
        # Per-channel mean/std normalization (ImageNet statistics); presumably
        # the preprocessing the checkpoint was trained with.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    train_set = CIFAR10(root=root, train=True, download=True, transform=transform)
    test_set = CIFAR10(root=root, train=False, download=True, transform=transform)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False,
                             drop_last=drop_last_test)
    return train_loader, test_loader

In [6]:
import os
def evaluate_model(model, data_loader, device):
    """Compute and print the top-1 accuracy (in percent) of ``model``.

    Args:
        model: callable mapping a batch of inputs to per-class scores; it is
            moved to ``device`` via ``model.to(device)`` before evaluation.
        data_loader: iterable of ``(images, labels)`` batches.
        device: device that the model and every batch are moved to.

    Returns:
        Accuracy as a percentage (``100 * correct / total``), also printed.

    Raises:
        ValueError: if ``data_loader`` yields no samples, because accuracy is
            undefined in that case.
    """
    model.to(device)
    correct = 0
    total = 0
    # Evaluation never needs gradients; disabling autograd saves memory/time.
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # Predicted class = index of the highest score along dim 1.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("evaluate_model: data_loader yielded no samples")
    accuracy = 100 * correct / total
    print(f'Accuracy of the model on the test images: {accuracy}%')
    return accuracy

def train_one_epoch(model, criterion, optimizer, data_loader, device):
    """Run one epoch of gradient-descent training over ``data_loader``.

    The model is neither moved to ``device`` nor switched to train mode here;
    the caller is expected to have done both.

    Args:
        model: module whose parameters ``optimizer`` updates.
        criterion: loss function, called as ``criterion(output, target)``.
        optimizer: optimizer that is stepped once per batch.
        data_loader: iterable of ``(image, target)`` batches.
        device: device each batch is moved to before the forward pass.
    """
    for image, target in data_loader:
        image, target = image.to(device), target.to(device)
        output = model(image)
        loss = criterion(output, target)
        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def print_size_of_model(model):
    """Print (and return) the serialized size of ``model.state_dict()`` in MB.

    The state dict is written with ``torch.save`` to a uniquely named temporary
    file. That file is measured and then always deleted, even if saving fails.

    Args:
        model: any object exposing ``state_dict()``.

    Returns:
        The size in megabytes (1 MB = 1e6 bytes).
    """
    import tempfile

    # A unique temp file avoids overwriting an existing "temp.p" in the
    # working directory; the descriptor is closed so torch.save can reopen it.
    fd, path = tempfile.mkstemp(suffix=".p")
    os.close(fd)
    try:
        torch.save(model.state_dict(), path)
        size_mb = os.path.getsize(path) / 1e6
    finally:
        os.remove(path)
    print('Size (MB):', size_mb)
    return size_mb

In [7]:
# Data: CIFAR-10 loaders with 128 images per batch.
batch_size = 128
train_loader, test_loader = prepare_data(batch_size)

# Use the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"device: {device}")

# Baseline numbers for the float (unquantized) model: size and test accuracy.
model.eval()
print_size_of_model(model)
evaluate_model(model, data_loader=test_loader, device=device)

Files already downloaded and verified
Files already downloaded and verified
device: cuda:0
Size (MB): 9.169412
Accuracy of the model on the test images: 96.30408653846153%


96.30408653846153

In [31]:
def quantize_ptq_model(model: nn.Module, device="cuda:0", calib_loader=None,
                       num_calib_batches=None) -> torch.fx.GraphModule:
    """Post-training-quantize ``model`` with the PT2E flow and XNNPACKQuantizer.

    Pipeline: capture the model as a graph, annotate it with
    ``get_symmetric_quantization_config()``, insert observers
    (``prepare_pt2e``), calibrate the observers on real data, then convert to
    a quantized graph (``convert_pt2e``).

    Args:
        model: float model to quantize. It is moved to ``device`` and set to
            eval mode in place before capture.
        device: device used for capture and calibration.
        calib_loader: iterable of ``(images, labels)`` batches used for
            calibration. Defaults to the notebook-level ``train_loader``.
        num_calib_batches: optional positive cap on the number of calibration
            batches. ``None`` (default) uses every batch of ``calib_loader``.

    Returns:
        The converted (quantized) graph module.

    Raises:
        ValueError: if ``num_calib_batches`` is given but is not positive,
            because uncalibrated observers would produce meaningless
            quantization parameters.
    """
    ############### YOUR CODE STARTS HERE ###############
    if num_calib_batches is not None and num_calib_batches < 1:
        raise ValueError("num_calib_batches must be a positive integer or None")
    if calib_loader is None:
        # Backward-compatible default: the global loader built earlier in the notebook.
        calib_loader = train_loader

    # Step 1. program capture
    # Capture an eval-mode model (as the PT2E PTQ tutorial does) so BatchNorm
    # uses its running statistics rather than batch statistics.
    model.to(device)
    model.eval()
    example_input = (torch.randn(1, 3, 224, 224).to(device), )
    ptq_model = capture_pre_autograd_graph(model, example_input)

    # Step 2. set quantizer (symmetric config applied globally)
    quantizer = XNNPACKQuantizer()
    quantizer.set_global(get_symmetric_quantization_config())

    # Step 3. prepare pt2e
    # prepare_pt2e folds BatchNorm operators into preceding Conv2d operators, and inserts observers in appropriate places in the model.
    ptq_model = prepare_pt2e(ptq_model, quantizer)

    # Calibration: feed representative data through the observers so they
    # record activation ranges; gradients are not needed.
    with torch.no_grad():
        for batch_idx, (images, _) in enumerate(calib_loader):
            if num_calib_batches is not None and batch_idx >= num_calib_batches:
                break
            ptq_model(images.to(device))

    # Step 4. convert model
    ptq_model = convert_pt2e(ptq_model)

    ############### YOUR CODE ENDS HERE #################
    return ptq_model


# Quantize the float model, then put the converted graph into eval mode using
# the PT2E-specific helper before measuring it.
ptq_model = quantize_ptq_model(model, device=device)
torch.ao.quantization.move_exported_model_to_eval(ptq_model)

# Report the quantized model's on-disk size and test accuracy.
print_size_of_model(ptq_model)
evaluate_model(ptq_model, data_loader=test_loader, device=device)


Accuracy of the model on the test images: 95.22235576923077%


95.22235576923077

In [33]:
# Save PTQ model as a torch.export ExportedProgram.
# NOTE(review): the example batch is the first full batch of test_loader
# (batch_size images). torch.export likely bakes that batch size into the saved
# program, which would explain why the test loader drops its last partial
# batch. Confirm before running the reloaded model on other batch sizes.
file_path = './mobilenetv2_ptq.pth'
example_inputs = tuple([next(iter(test_loader))[0].to(device)])
quantized_ep = torch.export.export(ptq_model, args=example_inputs)
torch.export.save(quantized_ep, f=file_path)

In [9]:
# Load PTQ model: read the saved ExportedProgram back, unwrap it into a
# callable module, and check it reproduces the in-memory PTQ accuracy.
loaded_quantized_ep = torch.export.load(
    './mobilenetv2_ptq.pth'
)
loaded_quantized_model = loaded_quantized_ep.module()

evaluate_model(loaded_quantized_model, data_loader=test_loader, device=device)

Accuracy of the model on the test images: 95.22235576923077%


95.22235576923077