In [68]:
import copy
import os

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch import optim
from torch import Tensor

In [69]:
# Prefer the GPU when one is available.
# NOTE(review): get_model() never moves the model to this device, and models
# quantized for the 'fbgemm' backend run on CPU kernels — confirm which cells
# actually rely on this value.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#### Data loading
Download MNIST, hold out 10,000 of the 60,000 training images as a validation set, and wrap the train, validation and test sets in data loaders.

In [70]:
from torchvision import datasets
from torchvision.transforms import ToTensor

dataset = datasets.MNIST(
    root = 'data',
    train = True,
    transform = ToTensor(),
    download = True,
)

# Hold out N_VALID of the 60,000 training images for validation (50,000 / 10,000).
# The split is seeded so the same images land in train/validation on every run.
N_VALID = 10000
split_generator = torch.Generator().manual_seed(0)
train_data, val_data = torch.utils.data.random_split(
    dataset, [len(dataset) - N_VALID, N_VALID], generator=split_generator
)

test_data = datasets.MNIST(
    root = 'data',
    train = False,
    transform = ToTensor(),
    # Harmless if the files are already present; makes this cell self-sufficient.
    download = True,
)

In [71]:
from torch.utils.data import DataLoader

# Only the training loader is shuffled. Evaluation metrics over a full pass do
# not depend on order, and a fixed order keeps evaluation runs reproducible.
loaders = {
    'train': DataLoader(train_data,
                        batch_size=100,
                        shuffle=True,
                        num_workers=1),

    'test': DataLoader(test_data,
                       batch_size=100,
                       shuffle=False,
                       num_workers=1),

    'valid': DataLoader(val_data,
                        batch_size=200,
                        shuffle=False),
}

#### Model
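A small ResNet-style classifier. It has two residual blocks, each computing Conv2d-BN-ReLU → Conv2d-BN, adding the block input back and applying ReLU. These are followed by a Conv2d-BN head and global max pooling, which produce one score per class.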

In [72]:
class QuantizableBasicBlock(nn.Module):
    """Residual block: (Conv3x3-BN-ReLU) -> (Conv3x3-BN), added to the input, then ReLU.

    Both convolutions use stride 2 and zero-padding `pad`. `pad` has to be chosen
    so that the spatial size is preserved (e.g. pad=15 maps 28x28 -> 28x28), or
    the residual addition will not line up. The shortcut has no projection, so
    the input must either have `out_size` channels or a single channel that is
    broadcast across them.

    The residual add+ReLU goes through `FloatFunctional` so that quantization can
    replace it with a quantized kernel.

    Args:
        in_size: number of input channels.
        hidden_size: channels produced by the first convolution.
        out_size: channels produced by the second convolution.
        pad: zero-padding used by both convolutions.
    """
    def __init__(self, in_size:int, hidden_size:int, out_size:int, pad:int):
        super().__init__()
        # Registered first; attribute names below must stay stable because the
        # fusion lists and the saved checkpoint refer to them.
        self.add_relu = nn.quantized.FloatFunctional()

        conv_kwargs = dict(kernel_size=3, stride=2, padding=pad)
        self.conv1 = nn.Conv2d(in_size, hidden_size, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(hidden_size)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(hidden_size, out_size, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(out_size)

    def forward(self, x):
        branch = self.relu1(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        # Shortcut: the untouched input is added back before the final ReLU.
        return self.add_relu.add_relu(branch, x)

In [73]:
class QuantizableResNet(nn.Module):
    """Two residual blocks, a Conv3x3-BN head and global max pooling.

    Inputs are reshaped to (N, 1, 28, 28); the output is an (N, n_classes)
    tensor of unnormalised class scores.

    Args:
        n_classes: number of output classes (channels of the head convolution).
    """

    def __init__(self, n_classes=10):
        super().__init__()
        # Registration order is kept as-is so parameter initialisation and the
        # state_dict layout are unchanged.
        self.res1 = QuantizableBasicBlock(1, 8, 16, 15)
        self.res2 = QuantizableBasicBlock(16, 32, 16, 15)
        self.conv = nn.Conv2d(16, n_classes, kernel_size=3)
        self.bn = nn.BatchNorm2d(n_classes)
        self.maxpool = nn.AdaptiveMaxPool2d(1)

    def forward(self, x):
        feats = x.view(-1, 1, 28, 28)
        feats = self.res2(self.res1(feats))
        # Head conv shrinks 28x28 -> 26x26; the pool then reduces it to 1x1.
        scores = self.maxpool(self.bn(self.conv(feats)))
        return scores.view(scores.size(0), -1)

#### Model training

In [74]:
def loss_batch(model, loss_func, xb, yb, opt=None, scheduler=None):
    """Evaluate one batch and, if `opt` is given, take one optimisation step.

    Args:
        model: network producing per-class scores for `xb`.
        loss_func: callable `(scores, targets) -> scalar loss tensor`.
        xb: input batch.
        yb: target labels for `xb`.
        opt: optimizer; when None, no backward pass or update happens.
        scheduler: optional LR scheduler stepped once per optimizer step.

    Returns:
        (accuracy, loss, batch_size): accuracy as a 0-d tensor from `accuracy`,
        loss as a Python float, and `len(xb)`.
    """
    # One forward pass only. A second pass would waste compute and, in train
    # mode, update the BatchNorm running statistics twice per batch.
    out = model(xb)
    loss = loss_func(out, yb)
    acc = accuracy(out.detach(), yb)
    if opt is not None:
        loss.backward()
        # PyTorch expects optimizer.step() before lr_scheduler.step().
        opt.step()
        if scheduler is not None:
            scheduler.step()
        opt.zero_grad()
    return acc, loss.item(), len(xb)


def accuracy(out, yb):
    """Return the fraction of rows whose highest-scoring column equals the label.

    Args:
        out: score tensor; the class axis is dim 1.
        yb: integer labels, one per row of `out`.

    Returns:
        0-d float tensor in [0, 1].
    """
    hits = out.argmax(dim=1).eq(yb)
    return hits.float().mean()


def get_model(learning_rate=None, momentum=0.9):
    """Create a freshly initialised QuantizableResNet and an SGD optimizer for it.

    Args:
        learning_rate: SGD learning rate. When None (the default), the
            notebook-level `lr` from the configuration cell is used; it is
            looked up at call time.
        momentum: SGD momentum.

    Returns:
        (model, optimizer) tuple.
    """
    if learning_rate is None:
        learning_rate = lr
    model = QuantizableResNet()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
    return model, optimizer


def fit(epochs, model, loss_func, opt, train_dl, valid_dl, scheduler=None):
    """Train `model` and print validation loss/accuracy after every epoch.

    Args:
        epochs: number of passes over `train_dl`.
        model: network to train (updated in place).
        loss_func: callable `(scores, targets) -> scalar loss tensor`.
        opt: optimizer over `model`'s parameters.
        train_dl: iterable of (inputs, labels) training batches.
        valid_dl: iterable of (inputs, labels) validation batches.
        scheduler: optional LR scheduler, stepped once per training batch.

    Raises:
        ValueError: if `valid_dl` yields no samples.
    """
    for epoch in range(epochs):
        model.train()
        for xb, yb in train_dl:
            loss_batch(model, loss_func, xb, yb, opt, scheduler)

        model.eval()
        total_acc = 0.0
        total_loss = 0.0
        total_n = 0
        # No gradients are needed for evaluation.
        with torch.no_grad():
            for xb, yb in valid_dl:
                acc, loss, n = loss_batch(model, loss_func, xb, yb)
                # float() accepts Python numbers and 0-d tensors (CPU or GPU),
                # whereas numpy cannot convert CUDA tensors.
                total_acc += float(acc) * n
                total_loss += float(loss) * n
                total_n += n

        if total_n == 0:
            raise ValueError("valid_dl yielded no samples")

        # Sample-weighted means, so a smaller final batch is not over-counted.
        val_loss = total_loss / total_n
        val_acc = total_acc / total_n

        print("Epoch:", epoch+1)
        print("Loss: ", val_loss)
        print("Accuracy: ", val_acc)
        print()


def check_accuracy(model, data_loader, device=None):
    """Print the top-1 accuracy of `model` over every batch in `data_loader`.

    The model is evaluated in eval mode without autograd. Its original
    train/eval mode is restored afterwards, even if evaluation raises.

    Args:
        model: classifier whose output has the class axis at dim 1.
        data_loader: iterable of (inputs, labels) batches.
        device: device to move each batch to. Defaults to the device of the
            model's first parameter (or buffer). If the model has neither, as
            with a converted quantized model whose weights are packed, CPU is
            used.

    Raises:
        ValueError: if `data_loader` yields no samples.
    """
    if device is None:
        # Inputs must live where the model lives. The model is not necessarily
        # on the notebook-wide `device`.
        first = next(model.parameters(), None)
        if first is None:
            first = next(model.buffers(), None)
        device = first.device if first is not None else torch.device('cpu')

    was_training = model.training
    num_correct = 0
    num_samples = 0
    model.eval()
    try:
        with torch.no_grad():
            for x, y in data_loader:
                x = x.to(device=device)
                y = y.to(device=device)

                scores = model(x)
                _, predictions = scores.max(1)
                num_correct += int((predictions == y).sum().item())
                num_samples += predictions.size(0)
    finally:
        model.train(was_training)

    if num_samples == 0:
        raise ValueError("data_loader yielded no samples")
    print(f'Got {num_correct} / {num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}')

In [78]:
lr = 0.01          # SGD learning rate; get_model() reads this global
n_epochs = 5
loss_func = F.cross_entropy

# Seed torch's global RNG so weight initialisation and the training loader's
# shuffling are reproducible across kernel restarts.
SEED = 0
torch.manual_seed(SEED)

In [79]:
# Build a freshly initialised model and its SGD optimizer (learning rate taken from `lr`).
model, opt = get_model()

In [80]:
# Train for `n_epochs` epochs, printing validation loss and accuracy after each epoch.
fit(n_epochs, model, loss_func, opt, loaders['train'], loaders['valid'])

Epoch: 1
Loss:  0.19892429441213608
Accuracy:  0.9405999970436096

Epoch: 2
Loss:  0.14958409801125527
Accuracy:  0.9561999905109405

Epoch: 3
Loss:  0.11511386513710022
Accuracy:  0.9667000019550324

Epoch: 4
Loss:  0.10888380914926529
Accuracy:  0.9694000065326691

Epoch: 5
Loss:  0.09176633350551128
Accuracy:  0.9731000077724457



In [81]:
# torch.save does not create missing directories, so ensure 'model/' exists first.
# The whole module is pickled, so loading it later requires the class definitions above.
os.makedirs('model', exist_ok=True)
torch.save(model, 'model/mnist.pt')

# Quantization

In [107]:
from torch.quantization.observer import MinMaxObserver, MovingAverageMinMaxObserver, HistogramObserver

def calibrate(model, data_loader):
    """Feed every input batch of `data_loader` through `model` in eval mode, without autograd.

    Outputs and labels are discarded. This notebook calls it on quantization-
    prepared models so that their observers see representative activations.

    Args:
        model: module to run (switched to eval mode and left there).
        data_loader: iterable of (inputs, labels) pairs.
    """
    model.eval()
    with torch.no_grad():
        for images, _ in data_loader:
            model(images)

# PyTorch's stock post-training qconfig for the x86 'fbgemm' backend.
# NOTE(review): none of the cells shown read this; the PTQ experiments below
# build their own QConfig objects instead.
default_qconfig = torch.quantization.qconfig.get_default_qconfig('fbgemm')

## Post Training Quantization (PTQ)

In [153]:
# Eager-mode PTQ attempt, kept inside a string literal so the code is never run.
# NOTE(review): a likely cause of the failure is that QuantizableResNet has no
# QuantStub/DeQuantStub. Eager-mode convert() then leaves the input as a float
# tensor, which is fed straight into quantized convolutions. Wrapping the model
# as nn.Sequential(QuantStub(), model, DeQuantStub()) is the usual fix — TODO
# confirm before reviving this code.
'''
NOT WORKING

def ptq_eager_mode(_model, _qconfig, _dataloader) -> nn.Module:
    # copy model
    m = copy.deepcopy(_model)
    m.eval()

    # Fuse modules
    torch.quantization.fuse_modules(m.res1, [['conv1', 'bn1', 'relu1'], ['conv2', 'bn2']], inplace=True)
    torch.quantization.fuse_modules(m.res2, [['conv1', 'bn1', 'relu1'], ['conv2', 'bn2']], inplace=True)
    torch.quantization.fuse_modules(m, [['conv', 'bn']], inplace=True)

    # Adding qconfig
    m.qconfig = _qconfig
    torch.quantization.prepare(m, inplace=True)

    # Calibration
    calibrate(m, _dataloader)

    # Convert
    torch.quantization.convert(m, inplace=True)

    return m
'''

In [84]:
from torch.quantization import quantize_fx
from torch.ao.quantization import QConfigMapping

def ptq_fx_graph_mode(_model, _qconfig, _dataloader) -> nn.Module:
    """Post-training static quantization of a copy of `_model` with FX graph mode.

    Args:
        _model: float model to quantize. It is deep-copied, so the original is
            not modified.
        _qconfig: QConfig applied globally to every quantizable op.
        _dataloader: iterable of (inputs, labels) batches. The inputs of its
            first batch are the example inputs for tracing, and every batch is
            used for calibration.

    Returns:
        The converted, quantized GraphModule.
    """
    # Use the argument rather than whatever notebook-global `qconfig` exists.
    qconfig_mapping = QConfigMapping().set_global(_qconfig)

    m = copy.deepcopy(_model)
    m.eval()

    # prepare_fx expects a *tuple* of positional example inputs. They are taken
    # from the calibration loader, not from a global (test) loader.
    example_inputs = (next(iter(_dataloader))[0],)
    model_prepared = quantize_fx.prepare_fx(m, qconfig_mapping, example_inputs)

    calibrate(model_prepared, _dataloader)

    return quantize_fx.convert_fx(model_prepared)

### Symmetric layer-wise static PTQ

In [114]:
# The checkpoint is a whole pickled nn.Module, so it must be loaded with
# weights_only=False (the default became True in PyTorch 2.6). Only load files
# you created yourself: unpickling can execute arbitrary code.
model = torch.load('model/mnist.pt', weights_only=False)

# Symmetric per-tensor scales: quint8 activations, qint8 weights.
# NOTE(review): PyTorch's default 'fbgemm' qconfig observes activations with
# reduce_range=True to avoid overflow on x86 CPUs without VNNI; these observers
# use the full 8-bit range — consider matching that if results look saturated.
qconfig = torch.quantization.QConfig(
    activation=HistogramObserver.with_args(qscheme=torch.per_tensor_symmetric, dtype=torch.quint8),
    weight=MinMaxObserver.with_args(qscheme=torch.per_tensor_symmetric, dtype=torch.qint8)
)

ptq_symmetric_model = ptq_fx_graph_mode(model, qconfig, loaders['valid'])
print(ptq_symmetric_model)

GraphModule(
  (res1): Module(
    (conv1): QuantizedConvReLU2d(1, 8, kernel_size=(3, 3), stride=(2, 2), scale=0.09093476831912994, zero_point=128, padding=(15, 15))
    (conv2): QuantizedConv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), scale=0.14593075215816498, zero_point=128, padding=(15, 15))
  )
  (res2): Module(
    (conv1): QuantizedConvReLU2d(16, 32, kernel_size=(3, 3), stride=(2, 2), scale=0.22077931463718414, zero_point=128, padding=(15, 15))
    (conv2): QuantizedConv2d(32, 16, kernel_size=(3, 3), stride=(2, 2), scale=0.3543640971183777, zero_point=128, padding=(15, 15))
  )
  (conv): QuantizedConv2d(16, 10, kernel_size=(3, 3), stride=(1, 1), scale=0.29815182089805603, zero_point=128)
  (maxpool): AdaptiveMaxPool2d(output_size=1)
)



def forward(self, x):
    _input_scale_0 = self._input_scale_0
    _input_zero_point_0 = self._input_zero_point_0
    quantize_per_tensor = torch.quantize_per_tensor(x, _input_scale_0, _input_zero_point_0, torch.quint8);  x = _input_scale_0 = _

### Asymmetric layer-wise static PTQ

In [112]:
# The checkpoint is a whole pickled nn.Module, so it must be loaded with
# weights_only=False (the default became True in PyTorch 2.6). Only load files
# you created yourself: unpickling can execute arbitrary code.
model = torch.load('model/mnist.pt', weights_only=False)

# Affine (asymmetric) per-tensor scales with zero points: quint8 activations,
# qint8 weights.
qconfig = torch.quantization.QConfig(
    activation=HistogramObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
    weight=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.qint8)
)

ptq_asymmetric_model = ptq_fx_graph_mode(model, qconfig, loaders['valid'])
print(ptq_asymmetric_model)

GraphModule(
  (res1): Module(
    (conv1): QuantizedConvReLU2d(1, 8, kernel_size=(3, 3), stride=(2, 2), scale=0.04546738415956497, zero_point=0, padding=(15, 15))
    (conv2): QuantizedConv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), scale=0.14294318854808807, zero_point=125, padding=(15, 15))
  )
  (res2): Module(
    (conv1): QuantizedConvReLU2d(16, 32, kernel_size=(3, 3), stride=(2, 2), scale=0.11038965731859207, zero_point=0, padding=(15, 15))
    (conv2): QuantizedConv2d(32, 16, kernel_size=(3, 3), stride=(2, 2), scale=0.3355857729911804, zero_point=135, padding=(15, 15))
  )
  (conv): QuantizedConv2d(16, 10, kernel_size=(3, 3), stride=(1, 1), scale=0.20874696969985962, zero_point=73)
  (maxpool): AdaptiveMaxPool2d(output_size=1)
)



def forward(self, x):
    _input_scale_0 = self._input_scale_0
    _input_zero_point_0 = self._input_zero_point_0
    quantize_per_tensor = torch.quantize_per_tensor(x, _input_scale_0, _input_zero_point_0, torch.quint8);  x = _input_scale_0 = _input

In [116]:
# Reload the float checkpoint so the baseline is the exact saved model.
# It is a whole pickled nn.Module, hence weights_only=False (the default became
# True in PyTorch 2.6). Only load files you created yourself.
model = torch.load('model/mnist.pt', weights_only=False)

print('Without quantization')
check_accuracy(model, loaders['test'])
print('\nSymmetric layer-wise static PTQ')
check_accuracy(ptq_symmetric_model, loaders['test'])
print('\nAsymmetric layer-wise static PTQ')
check_accuracy(ptq_asymmetric_model, loaders['test'])

Without quantization
Got 9774 / 10000 with accuracy 97.74

Symmetric layer-wise static PTQ
Got 9755 / 10000 with accuracy 97.55

Asymmetric layer-wise static PTQ
Got 9685 / 10000 with accuracy 96.85
