In [1]:
import os

import torch
import torchvision
from torch.quantization import prepare, convert, default_qconfig, quantize

In [2]:
cifar_dataset = torchvision.datasets.CIFAR10('.', download=True)
cifar_dataset

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


100%|██████████████████████████████████████████████████████████████████████████████████████████████| 170498071/170498071 [02:09<00:00, 1316889.83it/s]


Extracting ./cifar-10-python.tar.gz to .


Dataset CIFAR10
    Number of datapoints: 50000
    Root location: .
    Split: Train

In [3]:
cifar_dataloader = torch.utils.data.DataLoader(
    cifar_dataset,
    batch_size=4
)
cifar_dataloader

<torch.utils.data.dataloader.DataLoader at 0x12c1a5790>

In [4]:
model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [5]:
def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

In [6]:
print_size_of_model(model)

Size (MB): 46.827865


# 1. Specify quantization configuration

In [7]:
model.qconfig = torch.ao.quantization.default_qconfig
model.qconfig

QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, quant_min=0, quant_max=127){}, weight=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric){})

In [8]:
model = prepare(model, inplace=False)

# 2. Convert to quantized model

In [9]:
torch.backends.quantized.engine = 'qnnpack'

In [10]:
model = torch.ao.quantization.convert(model, inplace=False)



In [11]:
print_size_of_model(model)

Size (MB): 11.838747
