In [1]:
import numpy as np
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F

from torch import optim

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#### Data loading
Downloading the MNIST dataset and dividing it into train, test and validation data loaders.

In [3]:
from torchvision import datasets
from torchvision.transforms import ToTensor

dataset = datasets.MNIST(
    root = 'data',
    train = True,
    transform = ToTensor(),
    download = True,
)

train_data, val_data = torch.utils.data.random_split(dataset, [50000, 10000])

test_data = datasets.MNIST(
    root = 'data',
    train = False,
    transform = ToTensor()
)

In [4]:
from torch.utils.data import DataLoader

loaders = {
    'train' : torch.utils.data.DataLoader(train_data,
                                          batch_size=100,
                                          shuffle=True,
                                          num_workers=1),

    'test'  : torch.utils.data.DataLoader(test_data,
                                          batch_size=100,
                                          shuffle=True,
                                          num_workers=1),

    'valid' : torch.utils.data.DataLoader(val_data,
                                          batch_size=200,
                                          shuffle=False)
}

#### Model
Implementation of ResNet with 2 blocks - Conv2d-BN-ReLU.

In [None]:
from torchvision.models.quantization.resnet import QuantizableBasicBlock

In [6]:
class BasicBlock(nn.Module):

    """
    Iniialize a residual block with two convolutions followed by batchnorm layers
    """
    def __init__(self, in_size:int, hidden_size:int, out_size:int, pad:int):
        super().__init__()
        self.add_relu = nn.quantized.FloatFunctional()

        self.conv1 = nn.Conv2d(in_size, hidden_size, kernel_size=3, stride=2, padding=pad)
        self.bn1 = nn.BatchNorm2d(hidden_size)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(hidden_size, out_size, kernel_size=3, stride=2, padding=pad)
        self.bn2 = nn.BatchNorm2d(out_size)

    def convblock(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out = self.add_relu.add_relu(out, identity)

        return out

    def forward(self, x):
        return F.relu(x + self.convblock(x))

In [7]:
class ResNet(nn.Module):

    def __init__(self, n_classes=10):
        super().__init__()
        self.res1 = BasicBlock(1, 8, 16, 15)
        self.res2 = BasicBlock(16, 32, 16, 15)
        self.conv = nn.Conv2d(16, n_classes, kernel_size=3)
        self.batchnorm = nn.BatchNorm2d(n_classes)
        self.maxpool = nn.AdaptiveMaxPool2d(1)

    def forward(self, x):
        x = x.view(-1, 1, 28, 28)
        x = self.res1(x)
        x = self.res2(x)
        x = self.maxpool(self.batchnorm(self.conv(x)))
        return x.view(x.size(0), -1)

#### Model training

In [8]:
def loss_batch(model, loss_func, xb, yb, opt=None, scheduler=None):
    loss = loss_func(model(xb), yb)
    acc = accuracy(model(xb), yb)
    if opt is not None:
        loss.backward()
        if scheduler is not None:
            scheduler.step()
        opt.step()
        opt.zero_grad()
    return acc, loss.item(), len(xb)


def accuracy(out, yb):
    preds = torch.argmax(out, dim=1)
    return (preds == yb).float().mean()


def get_model():
    model = ResNet()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    return model, optimizer

In [9]:
def fit(epochs, model, loss_func, opt, train_dl, valid_dl, scheduler=None):
    for epoch in range(epochs):
        model.train()
        # iterate over data loader object (generator)
        for xb, yb in train_dl:
            loss_batch(model, loss_func, xb, yb, opt, scheduler)

        model.eval()
        # no gradient computation for evaluation mode
        with torch.no_grad():
            accs, losses, nums = zip(
                *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
            )

        #NOTE: important to multiply with batch size and sum over values
        #      to account for varying batch sizes
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
        val_acc = np.sum(np.multiply(accs, nums)) / np.sum(nums)

        print("Epoch:", epoch+1)
        print("Loss: ", val_loss)
        print("Accuracy: ", val_acc)
        print()

In [10]:
bs=64 #128
lr=0.01
n_epochs = 5
loss_func = F.cross_entropy

In [11]:
# get model and optimizer
model, opt = get_model()

In [12]:
# train
fit(n_epochs, model, loss_func, opt, loaders['train'], loaders['valid'])

Epoch: 1
Loss:  0.12025187499821186
Accuracy:  0.9626999950408935

Epoch: 2
Loss:  0.09123557083308696
Accuracy:  0.9728000104427338

Epoch: 3
Loss:  0.08082188189029693
Accuracy:  0.9752000105381012

Epoch: 4
Loss:  0.07578884046524763
Accuracy:  0.9766000139713288

Epoch: 5
Loss:  0.07517584595829248
Accuracy:  0.9769000101089478



In [15]:
torch.save(model.state_dict(), 'model/mnist.pt')

In [34]:
from torch.quantization.observer import MinMaxObserver

my_qconfig = torch.quantization.QConfig(
    activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_symmetric, dtype=torch.qint8),
    weight=MinMaxObserver.with_args(qscheme=torch.per_tensor_symmetric, dtype=torch.qint8)
)

In [35]:
model.eval()

model_with_stubs = nn.Sequential(
    torch.quantization.QuantStub(),
    model,
    torch.quantization.DeQuantStub()
)

model_with_stubs.qconfig = my_qconfig
qmodel = torch.quantization.prepare(model_with_stubs)

In [36]:
with torch.inference_mode():
    for xb, yb in loaders['valid']:
        qmodel(xb)

In [38]:
print(qmodel)

Sequential(
  (0): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=0.0, max_val=1.0)
  )
  (1): ResNet(
    (res1): ResBlock(
      (conv1): Conv2d(
        1, 8, kernel_size=(3, 3), stride=(2, 2), padding=(15, 15)
        (activation_post_process): MinMaxObserver(min_val=-1.8065515756607056, max_val=1.6178522109985352)
      )
      (conv2): Conv2d(
        8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(15, 15)
        (activation_post_process): MinMaxObserver(min_val=-15.812700271606445, max_val=16.616336822509766)
      )
      (batchnorm1): BatchNorm2d(
        8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
        (activation_post_process): MinMaxObserver(min_val=-12.264169692993164, max_val=10.694961547851562)
      )
      (batchnorm2): BatchNorm2d(
        16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
        (activation_post_process): MinMaxObserver(min_val=-27.85474967956543, max_val=26.247802734375)
      )
    )
 

In [37]:
torch.quantization.convert(qmodel, inplace=True)

AssertionError: Weight observer must have a dtype of qint8