In [1]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import Subset

Adopted DataLoader from Training file. This will help us to set up calibration data for quantization.

In [2]:
#data loader adopted from training
def data_loader(data_dir,
                    batch_size,
                    random_seed=42,
                    valid_size=0.1,
                    shuffle=True,
                    test=False):

        normalize = transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2023, 0.1994, 0.2010],
        )

        # define transforms
        transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        download = not os.path.exists(os.path.join(data_dir, "cifar-10-batches-py"))




        if test:
          dataset = datasets.CIFAR10(
            root=data_dir, train=False,
            download=download, transform=transform,
          )
          indices = list(range(0, len(dataset)))
          np.random.seed(42)
          np.random.shuffle(indices)
          data = Subset(dataset,indices)
        else:
          dataset = datasets.CIFAR10(
            root=data_dir, train=True,
            download=download, transform=transform,
          )
          indices = list(range(0, len(dataset)))
          np.random.seed(42)
          np.random.shuffle(indices)
          data = Subset(dataset,indices)

        data_loader = torch.utils.data.DataLoader(
            data, batch_size=batch_size, shuffle=shuffle
        )



        return data_loader


calib_loader = data_loader(data_dir='/content/rn',
                                  batch_size=64,
                                  test=False)
test_loader = data_loader(data_dir='/content/rn',
                                  batch_size=64,
                                  test=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /content/rn/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:08<00:00, 20.5MB/s]


Extracting /content/rn/cifar-10-python.tar.gz to /content/rn


Quantized version of the same ResNet20 Architecture. I added nn.quantized.FloatFunctional instead of add(). I added QuantStub and DeQuantStub for conversion to Int8. We should also fuse conv, bn for better results(fuse_modules).

In [3]:
import torch
import torch.nn as nn
import torch.nn.quantized as nnq
import torch.quantization
import torch.nn.init as init
class QuantResidualBlock(nn.Module):
        def __init__(self, in_channels, out_channels, stride = 1, downsample = None):
            super(QuantResidualBlock, self).__init__()
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = stride, padding = 1)
            self.bn1 = nn.BatchNorm2d(out_channels)
            self.relu=nn.ReLU()
            self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1)
            self.bn2 = nn.BatchNorm2d(out_channels)
            self.downsample = downsample
            self.out_channels = out_channels
            self.skip_add = nn.quantized.FloatFunctional()
        def forward(self, x):
            residual = x
            y = self.conv1(x)
            y = self.bn1(y)
            y = self.relu(y)
            y = self.conv2(y)
            y = self.bn2(y)
            if self.downsample:
                residual = self.downsample(x)
            out = self.skip_add.add(y, residual)
            out = self.relu(out)
            return out

class QuantResNet(nn.Module):
        def __init__(self, block, layers, num_classes = 10):
            super(QuantResNet, self).__init__()
            self.inplanes = 16
            self.quant = torch.quantization.QuantStub()
            self.conv1 = nn.Conv2d(3, 16, kernel_size = 3, stride = 1, padding = 1)
            self.bn1 = nn.BatchNorm2d(16)
            self.relu=nn.ReLU()
            #self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
            self.layer0 = self._make_layer(block, 16, layers[0], stride = 1)
            self.layer1 = self._make_layer(block, 32, layers[1], stride = 2)
            self.layer2 = self._make_layer(block, 64, layers[2], stride = 2)
            self.avgpool = nn.AdaptiveAvgPool2d(1)
            self.fc = nn.Linear(64, num_classes)
            self._initialize_weights()
            self.dequant = torch.quantization.DeQuantStub()
        def _initialize_weights(self):
            for m in self.modules():
                if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                    init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                    if m.bias is not None:
                        init.zeros_(m.bias)

        def _make_layer(self, block, planes, blocks, stride=1):
            downsample = None
            if stride != 1 or self.inplanes != planes:

                downsample = nn.Sequential(
                    nn.Conv2d(self.inplanes, planes, kernel_size=1, stride=stride),
                    nn.BatchNorm2d(planes),
                )
            layers = []
            layers.append(block(self.inplanes, planes, stride, downsample))
            self.inplanes = planes
            for i in range(1, blocks):
                layers.append(block(self.inplanes, planes))

            return nn.Sequential(*layers)

        def forward(self, x):
            x = self.quant(x)
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            #x = self.maxpool(x)
            x = self.layer0(x)
            x = self.layer1(x)
            x = self.layer2(x)

            x = self.avgpool(x)
            x = x.view(x.size(0), -1)
            x = self.fc(x)
            x = self.dequant(x)
            return x

In [None]:
model = QuantResNet(QuantResidualBlock, [3, 3, 3])



We use the quantized architectu but use the same weights from orginal ResNet. We don't need to train again. This is called PTQ(Post Training Quantization)

In [None]:

model.eval()
checkpoint = torch.load("/content/resnet20_fast.pth", map_location="cpu")
model.load_state_dict(checkpoint)

#model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
#model.qconfig = torch.quantization.default_qconfig
model.qconfig = torch.quantization.QConfig(
    activation=torch.quantization.MinMaxObserver.with_args(dtype=torch.quint8),
    weight=torch.quantization.MinMaxObserver.with_args(dtype=torch.qint8)
)


modules = [
    ['conv1', 'bn1', 'relu'],
    ['layer0.0.conv1', 'layer0.0.bn1'],
    ['layer0.0.conv2', 'layer0.0.bn2'],
    ['layer0.1.conv1', 'layer0.1.bn1'],
    ['layer0.1.conv2', 'layer0.1.bn2'],
    ['layer0.2.conv1', 'layer0.2.bn1'],
    ['layer0.2.conv2', 'layer0.2.bn2'],
    ['layer1.0.conv1', 'layer1.0.bn1'],
    ['layer1.0.conv2', 'layer1.0.bn2'],
    ['layer1.0.downsample.0', 'layer1.0.downsample.1'],
    ['layer1.1.conv1', 'layer1.1.bn1'],
    ['layer1.1.conv2', 'layer1.1.bn2'],
    ['layer1.2.conv1', 'layer1.2.bn1'],
    ['layer1.2.conv2', 'layer1.2.bn2'],
    ['layer2.0.conv1', 'layer2.0.bn1'],
    ['layer2.0.conv2', 'layer2.0.bn2'],
    ['layer2.0.downsample.0', 'layer2.0.downsample.1'],
    ['layer2.1.conv1', 'layer2.1.bn1'],
    ['layer2.1.conv2', 'layer2.1.bn2'],
    ['layer2.2.conv1', 'layer2.2.bn1'],
    ['layer2.2.conv2', 'layer2.2.bn2'],
]

model = torch.quantization.fuse_modules(model, modules)
model.eval()
model_prepared = torch.quantization.prepare(model)
print("Number of observers:", len([m for m in model_prepared.modules() if 'Observer' in str(type(m))]))
model_prepared.eval()
with torch.no_grad():
    for batch, _ in calib_loader:
        batch = batch.to('cpu')
        model_prepared(batch)

quantized_model = torch.quantization.convert(model_prepared)


In [6]:
print(quantized_model)

QuantResNet(
  (quant): Quantize(scale=tensor([0.0203]), zero_point=tensor([120]), dtype=torch.quint8)
  (conv1): QuantizedConvReLU2d(3, 16, kernel_size=(3, 3), stride=(1, 1), scale=0.03401070460677147, zero_point=0, padding=(1, 1))
  (bn1): Identity()
  (relu): Identity()
  (layer0): Sequential(
    (0): QuantResidualBlock(
      (conv1): QuantizedConv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), scale=0.054778438061475754, zero_point=146, padding=(1, 1))
      (bn1): Identity()
      (relu): ReLU()
      (conv2): QuantizedConv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), scale=0.07266166806221008, zero_point=132, padding=(1, 1))
      (bn2): Identity()
      (skip_add): QFunctional(
        scale=0.0802651196718216, zero_point=113
        (activation_post_process): Identity()
      )
    )
    (1): QuantResidualBlock(
      (conv1): QuantizedConv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), scale=0.049193125218153, zero_point=131, padding=(1, 1))
      (bn1): Identity()
      (relu):

This model is now tested on 10000 test images. We will check inference speed and accuracy.

In [7]:
import time
total=0
correct=0
total_time = 0
quantized_model.eval()
with torch.no_grad():
    for data in test_loader:

        inputs, labels = data
        inputs, labels = inputs.to('cpu'), labels.to('cpu')
        t0 = time.time()
        outputs = quantized_model(inputs)
        t1 = time.time()
        total_time += (t1-t0)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
acc = correct/total
t_av = total_time/total #average inference time per image
print("Quantized ACC: ", acc*100, "Average Inference on Quantized: ", t_av)

Quantized ACC:  81.08999999999999 Average Inference on Quantized:  0.0019118833780288696


In [8]:
import torch.nn.init as init
class ResidualBlock(nn.Module):
        def __init__(self, in_channels, out_channels, stride = 1, downsample = None):
            super(ResidualBlock, self).__init__()
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = stride, padding = 1)
            self.bn1 = nn.BatchNorm2d(out_channels)
            self.relu=nn.ReLU()
            self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1)
            self.bn2 = nn.BatchNorm2d(out_channels)
            self.downsample = downsample
            self.out_channels = out_channels
            #self.skip_add = nn.quantized.FloatFunctional()
        def forward(self, x):
            residual = x
            y = self.conv1(x)
            y = self.bn1(y)
            y = self.relu(y)
            y = self.conv2(y)
            y = self.bn2(y)
            if self.downsample:
                residual = self.downsample(x)
            y += residual
            y = self.relu(y)
            return y
class ResNet(nn.Module):
        def __init__(self, block, layers, num_classes = 10):
            super(ResNet, self).__init__()
            self.inplanes = 16
            #self.quant = torch.quantization.QuantStub()
            self.conv1 = nn.Conv2d(3, 16, kernel_size = 3, stride = 1, padding = 1)
            self.bn1 = nn.BatchNorm2d(16)
            self.relu=nn.ReLU()
            #self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
            self.layer0 = self._make_layer(block, 16, layers[0], stride = 1)
            self.layer1 = self._make_layer(block, 32, layers[1], stride = 2)
            self.layer2 = self._make_layer(block, 64, layers[2], stride = 2)
            self.avgpool = nn.AdaptiveAvgPool2d(1)
            self.fc = nn.Linear(64, num_classes)
            self._initialize_weights()  # Apply He initialization
            #self.dequant = torch.quantization.DeQuantStub()
        def _initialize_weights(self):
            for m in self.modules():
                if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                    init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                    if m.bias is not None:
                        init.zeros_(m.bias)

        def _make_layer(self, block, planes, blocks, stride=1):
            downsample = None
            if stride != 1 or self.inplanes != planes:

                downsample = nn.Sequential(
                    nn.Conv2d(self.inplanes, planes, kernel_size=1, stride=stride),
                    nn.BatchNorm2d(planes),
                )
            layers = []
            layers.append(block(self.inplanes, planes, stride, downsample))
            self.inplanes = planes
            for i in range(1, blocks):
                layers.append(block(self.inplanes, planes))

            return nn.Sequential(*layers)

        def forward(self, x):
            #x = self.quant(x)
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            #x = self.maxpool(x)
            x = self.layer0(x)
            x = self.layer1(x)
            x = self.layer2(x)

            x = self.avgpool(x)
            x = x.view(x.size(0), -1)
            x = self.fc(x)
            #x = self.dequant(x)
            return x

In [None]:
modelq = ResNet(ResidualBlock, [3,3,3])
checkpoint = torch.load("/content/resnet20_fast.pth", map_location="cpu")
modelq.load_state_dict(checkpoint)


In [10]:
import time
total=0
correct=0
total_time = 0
modelq.eval()
with torch.no_grad():
    for data in test_loader:

        inputs, labels = data
        inputs, labels = inputs.to('cpu'), labels.to('cpu')
        t0 = time.time()
        outputs = modelq(inputs)
        t1 = time.time()
        total_time += (t1-t0)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
acc = correct/total
t_av = total_time/total #average inference time per image
print("FP32 orginal model ACC: ", acc*100, "Average Inference on FP32: ", t_av)

FP32 orginal model ACC:  81.3 Average Inference on FP32:  0.0026718903303146364


# Our Scores

Accuracy Retention = (Quantized Accuracy / Original Accuracy) × 100
                   = *99.74%*

Inference Time=( Original Inference Time/ Quantized Inference Time)
              = *1.397x*

Overall = Accuracy Retention/100 + Inference Speedup/5 = 1.2768

Lets look at the others!!