In [1]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import Subset

In [2]:
#data loader adopted from training
def data_loader(data_dir,
                    batch_size,
                    random_seed=42,
                    valid_size=0.1,
                    shuffle=True,
                    test=False):

        normalize = transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2023, 0.1994, 0.2010],
        )

        # define transforms
        transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        download = not os.path.exists(os.path.join(data_dir, "cifar-10-batches-py"))




        if test:
          dataset = datasets.CIFAR10(
            root=data_dir, train=False,
            download=download, transform=transform,
          )
          indices = list(range(0, len(dataset)))
          np.random.seed(42)
          np.random.shuffle(indices)
          data = Subset(dataset,indices)
        else:
          dataset = datasets.CIFAR10(
            root=data_dir, train=True,
            download=download, transform=transform,
          )
          indices = list(range(0, len(dataset)))
          np.random.seed(42)
          np.random.shuffle(indices)
          data = Subset(dataset,indices)

        data_loader = torch.utils.data.DataLoader(
            data, batch_size=batch_size, shuffle=shuffle
        )



        return data_loader


calib_loader = data_loader(data_dir='/content/rn',
                                  batch_size=32,
                                  test=False)
test_loader = data_loader(data_dir='/content/rn',
                                  batch_size=32,
                                  test=True)

100%|██████████| 170M/170M [00:03<00:00, 46.8MB/s]


Architecture adopted from inceptionnet.ipynb

In [4]:
import torch
import torch.nn as nn

class InceptionNet(nn.Module):
    def __init__(self, num_classes=10):
        super(InceptionNet, self).__init__()
        self.in_channels = 3


        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1, stride=1)
        self.bn1_1 = nn.BatchNorm2d(64)
        self.relu1_1 = nn.ReLU(inplace=True)


        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding='same')
        self.bn1_2 = nn.BatchNorm2d(64)
        self.relu1_2 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)


        #ib1

        #1X1
        #for 3x3
        self.conv2_1 = nn.Conv2d(64, 32, kernel_size=1, padding='same')
        self.bn2_1 = nn.BatchNorm2d(32)
        self.relu2_1 = nn.ReLU(inplace=True)

        #for 5x5
        self.conv2_2 = nn.Conv2d(64, 8, kernel_size=1, padding='same')
        self.bn2_2 = nn.BatchNorm2d(8)
        self.relu2_2 = nn.ReLU(inplace=True)



        #maxpool
        self.pool2_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)

        #3x3
        self.conv2_4 = nn.Conv2d(32, 64, kernel_size=3, padding='same')
        self.bn2_4 = nn.BatchNorm2d(64)
        self.relu2_4 = nn.ReLU(inplace=True)

        #5x5
        self.conv2_5 = nn.Conv2d(8, 16, kernel_size=5, padding='same')
        self.bn2_5 = nn.BatchNorm2d(16)
        self.relu2_5 = nn.ReLU(inplace=True)

        #for 1x1
        self.conv2_3 = nn.Conv2d(64, 32, kernel_size=1, padding='same')
        self.bn2_3 = nn.BatchNorm2d(32)
        self.relu2_3 = nn.ReLU(inplace=True)

        #after maxpool
        self.conv2_6 = nn.Conv2d(64, 16, kernel_size=1, padding='same')
        self.bn2_6 = nn.BatchNorm2d(16)
        self.relu2_6 = nn.ReLU(inplace=True)


        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        #ib2

        #1X1
        #for 3x3
        self.conv3_1 = nn.Conv2d(128, 64, kernel_size=1, padding='same')
        self.bn3_1 = nn.BatchNorm2d(64)
        self.relu3_1 = nn.ReLU(inplace=True)

        #for 5x5
        self.conv3_2 = nn.Conv2d(128, 16, kernel_size=1, padding='same')
        self.bn3_2 = nn.BatchNorm2d(16)
        self.relu3_2 = nn.ReLU(inplace=True)

        #for 1x1
        self.conv3_3 = nn.Conv2d(128, 64, kernel_size=1, padding='same')
        self.bn3_3 = nn.BatchNorm2d(64)
        self.relu3_3 = nn.ReLU(inplace=True)

        #maxpool
        self.pool3_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)

        #3x3
        self.conv3_4 = nn.Conv2d(64, 128, kernel_size=3, padding='same')
        self.bn3_4 = nn.BatchNorm2d(128)
        self.relu3_4 = nn.ReLU(inplace=True)

        #5x5
        self.conv3_5 = nn.Conv2d(16, 32, kernel_size=5, padding='same')
        self.bn3_5 = nn.BatchNorm2d(32)
        self.relu3_5 = nn.ReLU(inplace=True)

        #after maxpool
        self.conv3_6 = nn.Conv2d(128, 32, kernel_size=1, padding='same')
        self.bn3_6 = nn.BatchNorm2d(32)
        self.relu3_6 = nn.ReLU(inplace=True)


        #final averagePool
        self.avg_pool = nn.AvgPool2d(kernel_size=8, stride=2)

        self.fc1 = nn.Linear(256, 2048)
        self.relu_fc1 = nn.ReLU(inplace=True)
        # self.fc2 = nn.Linear(4096, 4096)
        # self.relu_fc2 = nn.ReLU(inplace=True)
        self.fc3 = nn.Linear(2048, num_classes)


        self._initialize_weights()

    def forward(self, x):
      x = self.relu1_1(self.bn1_1(self.conv1_1(x)))
      x = self.relu1_2(self.bn1_2(self.conv1_2(x)))
      x = self.pool1(x)

      x_3x3 = self.relu2_1(self.bn2_1(self.conv2_1(x)))
      x_5x5 = self.relu2_2(self.bn2_2(self.conv2_2(x)))
      x_1x1 = self.relu2_3(self.bn2_3(self.conv2_3(x)))
      x_pool = self.pool2_1(x)

      x_3x3 = self.relu2_4(self.bn2_4(self.conv2_4(x_3x3)))
      x_5x5 = self.relu2_5(self.bn2_5(self.conv2_5(x_5x5)))
      x_pool = self.relu2_6(self.bn2_6(self.conv2_6(x_pool)))

      x = torch.cat([x_1x1, x_3x3, x_5x5, x_pool], dim=1)
      x = self.maxpool(x)

      x_3x3 = self.relu3_1(self.bn3_1(self.conv3_1(x)))
      x_5x5 = self.relu3_2(self.bn3_2(self.conv3_2(x)))
      x_1x1 = self.relu3_3(self.bn3_3(self.conv3_3(x)))
      x_pool = self.pool3_1(x)

      x_3x3 = self.relu3_4(self.bn3_4(self.conv3_4(x_3x3)))
      x_5x5 = self.relu3_5(self.bn3_5(self.conv3_5(x_5x5)))
      x_pool = self.relu3_6(self.bn3_6(self.conv3_6(x_pool)))

      x = torch.cat([x_1x1, x_3x3, x_5x5, x_pool], dim=1)
      x = self.avg_pool(x)

      x = x.view(x.size(0), -1)
      x = self.relu_fc1(self.fc1(x))
      x = self.fc3(x)

      return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


Quantized version

In [14]:
import torch
import torch.nn as nn

#padding='same' doesnt work for pytorch quantization processes
# so we need to set up same convolutions manually
def get_same_padding(i, k, s):
  return (s*(i-1) + k-i)//2


class QuantInceptionNet(nn.Module):
    def __init__(self, num_classes=10):
        super(QuantInceptionNet, self).__init__()
        self.in_channels = 3
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1, stride=1)
        self.bn1_1 = nn.BatchNorm2d(64)
        self.relu1_1 = nn.ReLU(inplace=True)


        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=get_same_padding(32, 3, 1))
        self.bn1_2 = nn.BatchNorm2d(64)
        self.relu1_2 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)


        #block1

        #1X1
        #for 3x3
        self.conv2_1 = nn.Conv2d(64, 32, kernel_size=1, padding=get_same_padding(16, 1, 1))
        self.bn2_1 = nn.BatchNorm2d(32)
        self.relu2_1 = nn.ReLU(inplace=True)

        #for 5x5
        self.conv2_2 = nn.Conv2d(64, 8, kernel_size=1, padding=get_same_padding(16, 1, 1))
        self.bn2_2 = nn.BatchNorm2d(8)
        self.relu2_2 = nn.ReLU(inplace=True)



        #maxpool
        self.pool2_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)

        #3x3
        self.conv2_4 = nn.Conv2d(32, 64, kernel_size=3, padding=get_same_padding(16, 3, 1))
        self.bn2_4 = nn.BatchNorm2d(64)
        self.relu2_4 = nn.ReLU(inplace=True)

        #5x5
        self.conv2_5 = nn.Conv2d(8, 16, kernel_size=5, padding=get_same_padding(16, 5, 1))
        self.bn2_5 = nn.BatchNorm2d(16)
        self.relu2_5 = nn.ReLU(inplace=True)

        #for 1x1
        self.conv2_3 = nn.Conv2d(64, 32, kernel_size=1, padding=get_same_padding(16, 1, 1))
        self.bn2_3 = nn.BatchNorm2d(32)
        self.relu2_3 = nn.ReLU(inplace=True)

        #after maxpool
        self.conv2_6 = nn.Conv2d(64, 16, kernel_size=1, padding=get_same_padding(16, 1, 1))
        self.bn2_6 = nn.BatchNorm2d(16)
        self.relu2_6 = nn.ReLU(inplace=True)


        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        #block2

        #1X1
        #for 3x3
        self.conv3_1 = nn.Conv2d(128, 64, kernel_size=1, padding=get_same_padding(8, 1, 1))
        self.bn3_1 = nn.BatchNorm2d(64)
        self.relu3_1 = nn.ReLU(inplace=True)

        #for 5x5
        self.conv3_2 = nn.Conv2d(128, 16, kernel_size=1, padding=get_same_padding(8, 1, 1))
        self.bn3_2 = nn.BatchNorm2d(16)
        self.relu3_2 = nn.ReLU(inplace=True)

        #for 1x1
        self.conv3_3 = nn.Conv2d(128, 64, kernel_size=1, padding=get_same_padding(8, 1, 1))
        self.bn3_3 = nn.BatchNorm2d(64)
        self.relu3_3 = nn.ReLU(inplace=True)

        #maxpool
        self.pool3_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)

        #3x3
        self.conv3_4 = nn.Conv2d(64, 128, kernel_size=3, padding=get_same_padding(8, 3, 1))
        self.bn3_4 = nn.BatchNorm2d(128)
        self.relu3_4 = nn.ReLU(inplace=True)

        #5x5
        self.conv3_5 = nn.Conv2d(16, 32, kernel_size=5, padding=get_same_padding(8, 5, 1))
        self.bn3_5 = nn.BatchNorm2d(32)
        self.relu3_5 = nn.ReLU(inplace=True)

        #after maxpool
        self.conv3_6 = nn.Conv2d(128, 32, kernel_size=1, padding=get_same_padding(8, 1, 1))
        self.bn3_6 = nn.BatchNorm2d(32)
        self.relu3_6 = nn.ReLU(inplace=True)


        #final averagePool
        self.avg_pool = nn.AvgPool2d(kernel_size=8, stride=2)

        self.fc1 = nn.Linear(256, 2048)
        self.relu_fc1 = nn.ReLU(inplace=True)
        # self.fc2 = nn.Linear(4096, 4096)
        # self.relu_fc2 = nn.ReLU(inplace=True)
        self.fc3 = nn.Linear(2048, num_classes)


        self._initialize_weights()

    def forward(self, x):
      x = self.quant(x)
      x = self.relu1_1(self.bn1_1(self.conv1_1(x)))
      x = self.relu1_2(self.bn1_2(self.conv1_2(x)))
      x = self.pool1(x)

      x_3x3 = self.relu2_1(self.bn2_1(self.conv2_1(x)))
      x_5x5 = self.relu2_2(self.bn2_2(self.conv2_2(x)))
      x_1x1 = self.relu2_3(self.bn2_3(self.conv2_3(x)))
      x_pool = self.pool2_1(x)

      x_3x3 = self.relu2_4(self.bn2_4(self.conv2_4(x_3x3)))
      x_5x5 = self.relu2_5(self.bn2_5(self.conv2_5(x_5x5)))
      x_pool = self.relu2_6(self.bn2_6(self.conv2_6(x_pool)))

      x = torch.cat([x_1x1, x_3x3, x_5x5, x_pool], dim=1)
      x = self.maxpool(x)

      x_3x3 = self.relu3_1(self.bn3_1(self.conv3_1(x)))
      x_5x5 = self.relu3_2(self.bn3_2(self.conv3_2(x)))
      x_1x1 = self.relu3_3(self.bn3_3(self.conv3_3(x)))
      x_pool = self.pool3_1(x)

      x_3x3 = self.relu3_4(self.bn3_4(self.conv3_4(x_3x3)))
      x_5x5 = self.relu3_5(self.bn3_5(self.conv3_5(x_5x5)))
      x_pool = self.relu3_6(self.bn3_6(self.conv3_6(x_pool)))

      x = torch.cat([x_1x1, x_3x3, x_5x5, x_pool], dim=1)
      x = self.avg_pool(x)

      x = x.view(x.size(0), -1)
      x = self.relu_fc1(self.fc1(x))
      x = self.fc3(x)
      x = self.dequant(x)

      return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)





In [16]:
model = QuantInceptionNet()
model.eval()
checkpoint = torch.load("[PATH HERE]", map_location="cpu")
model.load_state_dict(checkpoint['model_state'])


model.qconfig = torch.quantization.QConfig(
    activation=torch.quantization.MinMaxObserver.with_args(dtype=torch.quint8),
    weight=torch.quantization.MinMaxObserver.with_args(dtype=torch.qint8)
)

modules = [
    ['conv1_1', 'bn1_1', 'relu1_1'],
    ['conv1_2', 'bn1_2', 'relu1_2'],
    ['conv2_1', 'bn2_1', 'relu2_1'],
    ['conv2_2', 'bn2_2', 'relu2_2'],
    ['conv2_3', 'bn2_3', 'relu2_3'],
    ['conv2_4', 'bn2_4', 'relu2_4'],
    ['conv2_5', 'bn2_5', 'relu2_5'],
    ['conv2_6', 'bn2_6', 'relu2_6'],
    ['conv3_1', 'bn3_1', 'relu3_1'],
    ['conv3_2', 'bn3_2', 'relu3_2'],
    ['conv3_3', 'bn3_3', 'relu3_3'],
    ['conv3_4', 'bn3_4', 'relu3_4'],
    ['conv3_5', 'bn3_5', 'relu3_5'],
    ['conv3_6', 'bn3_6', 'relu3_6'],
    ['fc1', 'relu_fc1']
]

model = torch.quantization.fuse_modules(model, modules)

model.eval()
model_prepared = torch.quantization.prepare(model)
print("Number of observers:", len([m for m in model_prepared.modules() if 'Observer' in str(type(m))]))
model_prepared.eval()
with torch.no_grad():
    for batch, _ in calib_loader:
        batch = batch.to('cpu')
        model_prepared(batch)

quantized_model = torch.quantization.convert(model_prepared)


Number of observers: 17


In [17]:
print(model_prepared)

QuantInceptionNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-2.429065704345703, max_val=2.7537312507629395)
  )
  (dequant): DeQuantStub()
  (conv1_1): ConvReLU2d(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (activation_post_process): MinMaxObserver(min_val=0.0, max_val=8.551780700683594)
  )
  (bn1_1): Identity()
  (relu1_1): Identity()
  (conv1_2): ConvReLU2d(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (activation_post_process): MinMaxObserver(min_val=0.0, max_val=4.8796067237854)
  )
  (bn1_2): Identity()
  (relu1_2): Identity()
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2_1): ConvReLU2d(
    (0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
    (1): ReLU(inplace=True)
    (activation_post_process): MinMaxObserver(min_val=0.0, max_val=3.6504292488098145)
  )
  (bn2_1): I

In [19]:
modelfp32 = InceptionNet()
checkpoint = torch.load("[PATH HERE]", map_location="cpu")
modelfp32.load_state_dict(checkpoint['model_state'])

<All keys matched successfully>

In [18]:
import time
total=0
correct=0
total_time = 0
quantized_model.eval()
with torch.no_grad():
    for data in test_loader:

        inputs, labels = data
        inputs, labels = inputs.to('cpu'), labels.to('cpu')
        t0 = time.time()
        outputs = quantized_model(inputs)
        t1 = time.time()
        total_time += (t1-t0)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
acc = correct/total
t_av = total_time/total #average inference time per image
print("Quantized ACC: ", acc*100, "Average Inference on Quantized: ", t_av)

  x = torch.cat([x_1x1, x_3x3, x_5x5, x_pool], dim=1)
  x = torch.cat([x_1x1, x_3x3, x_5x5, x_pool], dim=1)


Quantized ACC:  80.11 Average Inference on Quantized:  0.002119420337677002


In [20]:

import time
total=0
correct=0
total_time = 0
modelfp32.eval()
with torch.no_grad():
    for data in test_loader:

        inputs, labels = data
        inputs, labels = inputs.to('cpu'), labels.to('cpu')
        t0 = time.time()
        outputs = modelfp32(inputs)
        t1 = time.time()
        total_time += (t1-t0)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
acc = correct/total
t_av = total_time/total #average inference time per image
print("fp32 ACC: ", acc*100, "Average Inference on fp32: ", t_av)

fp32 ACC:  80.36999999999999 Average Inference on fp32:  0.004519141173362732


# Our Scores
Accuracy Retention = (Quantized Accuracy / Original Accuracy) × 100 = 99.68%

Inference Time=( Original Inference Time/ Quantized Inference Time) = 2.133x


