In [174]:
import os
import sys
import time
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets
import torchvision.transforms as transforms

# Set up warnings
import warnings
warnings.filterwarnings(
    action='ignore',
    category=DeprecationWarning,
    module=r'.*'
)
warnings.filterwarnings(
    action='default',
    module=r'torch.ao.quantization'
)

# Specify random seed for repeatable results
torch.manual_seed(191009)

<torch._C.Generator at 0x18b4838d0f0>

In [175]:
from torch.nn.utils import prune
from collections import OrderedDict
from models.googlenet import GoogleNet
from models.quantized_googlenet import quantized_googlenet

model = GoogleNet()
model.load_state_dict(torch.load('googlenet-196-best.pth', map_location=torch.device('cpu')))
print(model)


GoogleNet(
  (prelayer): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (7): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
  )
  (a3): Inception(
    (b1): Sequential(
      (0): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (b2): Sequential(
      (0): Conv2d(192, 96, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(96, eps=1

In [176]:
model.state_dict().keys()

odict_keys(['prelayer.0.weight', 'prelayer.1.weight', 'prelayer.1.bias', 'prelayer.1.running_mean', 'prelayer.1.running_var', 'prelayer.1.num_batches_tracked', 'prelayer.3.weight', 'prelayer.4.weight', 'prelayer.4.bias', 'prelayer.4.running_mean', 'prelayer.4.running_var', 'prelayer.4.num_batches_tracked', 'prelayer.6.weight', 'prelayer.7.weight', 'prelayer.7.bias', 'prelayer.7.running_mean', 'prelayer.7.running_var', 'prelayer.7.num_batches_tracked', 'a3.b1.0.weight', 'a3.b1.0.bias', 'a3.b1.1.weight', 'a3.b1.1.bias', 'a3.b1.1.running_mean', 'a3.b1.1.running_var', 'a3.b1.1.num_batches_tracked', 'a3.b2.0.weight', 'a3.b2.0.bias', 'a3.b2.1.weight', 'a3.b2.1.bias', 'a3.b2.1.running_mean', 'a3.b2.1.running_var', 'a3.b2.1.num_batches_tracked', 'a3.b2.3.weight', 'a3.b2.3.bias', 'a3.b2.4.weight', 'a3.b2.4.bias', 'a3.b2.4.running_mean', 'a3.b2.4.running_var', 'a3.b2.4.num_batches_tracked', 'a3.b3.0.weight', 'a3.b3.0.bias', 'a3.b3.1.weight', 'a3.b3.1.bias', 'a3.b3.1.running_mean', 'a3.b3.1.runni

In [177]:
model.named_modules

<bound method Module.named_modules of GoogleNet(
  (prelayer): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (7): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
  )
  (a3): Inception(
    (b1): Sequential(
      (0): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (b2): Sequential(
      (0): Conv2d(192, 96, kernel_size=(1, 1), stride=(1

In [178]:
'''for name, module in model.named_modules():
    for mini_name, mini_module in module.named_modules():
        # prune 20% of connections in all 2D-conv layers
        if isinstance(mini_module, torch.nn.Conv2d):
            print(f"Sparsity in {mini_name} before: {100. * float(torch.sum(mini_module.weight == 0))/ float(mini_module.weight.nelement()):.2f}%")
            prune.l1_unstructured(mini_module, name='weight', amount=0.4)
            prune.remove(mini_module, 'weight')
            print(f"Sparsity in {mini_name} after: {100. * float(torch.sum(mini_module.weight == 0))/ float(mini_module.weight.nelement()):.2f}%")
        # prune 40% of connections in all linear layers

print(dict(model.named_buffers()).keys())  # to verify that all masks exist'''

'for name, module in model.named_modules():\n    for mini_name, mini_module in module.named_modules():\n        # prune 20% of connections in all 2D-conv layers\n        if isinstance(mini_module, torch.nn.Conv2d):\n            print(f"Sparsity in {mini_name} before: {100. * float(torch.sum(mini_module.weight == 0))/ float(mini_module.weight.nelement()):.2f}%")\n            prune.l1_unstructured(mini_module, name=\'weight\', amount=0.4)\n            prune.remove(mini_module, \'weight\')\n            print(f"Sparsity in {mini_name} after: {100. * float(torch.sum(mini_module.weight == 0))/ float(mini_module.weight.nelement()):.2f}%")\n        # prune 40% of connections in all linear layers\n\nprint(dict(model.named_buffers()).keys())  # to verify that all masks exist'

In [179]:

modules_to_fuse = [
    ['model.prelayer.0', 'model.prelayer.1', 'model.prelayer.2'],
    ['model.prelayer.3', 'model.prelayer.4', 'model.prelayer.5'],
    ['model.prelayer.6', 'model.prelayer.7', 'model.prelayer.8'],

    ['model.a3.b1.0', 'model.a3.b1.1', 'model.a3.b1.2'],
    ['model.a3.b2.0', 'model.a3.b2.1', 'model.a3.b2.2'],
    ['model.a3.b2.3', 'model.a3.b2.4', 'model.a3.b2.5'],
    ['model.a3.b3.0', 'model.a3.b3.1', 'model.a3.b3.2'],
    ['model.a3.b3.3', 'model.a3.b3.4', 'model.a3.b3.5'],
    ['model.a3.b3.6', 'model.a3.b3.7', 'model.a3.b3.8'],
    ['model.a3.b4.1', 'model.a3.b4.2', 'model.a3.b4.3'],

    ['model.b3.b1.0', 'model.b3.b1.1', 'model.b3.b1.2'],
    ['model.b3.b2.0', 'model.b3.b2.1', 'model.b3.b2.2'],
    ['model.b3.b2.3', 'model.b3.b2.4', 'model.b3.b2.5'],
    ['model.b3.b3.0', 'model.b3.b3.1', 'model.b3.b3.2'],
    ['model.b3.b3.3', 'model.b3.b3.4', 'model.b3.b3.5'],
    ['model.b3.b3.6', 'model.b3.b3.7', 'model.b3.b3.8'],
    ['model.b3.b4.1', 'model.b3.b4.2', 'model.b3.b4.3'],

    ['model.a4.b1.0', 'model.a4.b1.1', 'model.a4.b1.2'],
    ['model.a4.b2.0', 'model.a4.b2.1', 'model.a4.b2.2'],
    ['model.a4.b2.3', 'model.a4.b2.4', 'model.a4.b2.5'],
    ['model.a4.b3.0', 'model.a4.b3.1', 'model.a4.b3.2'],
    ['model.a4.b3.3', 'model.a4.b3.4', 'model.a4.b3.5'],
    ['model.a4.b3.6', 'model.a4.b3.7', 'model.a4.b3.8'],
    ['model.a4.b4.1', 'model.a4.b4.2', 'model.a4.b4.3'],

    ['model.b4.b1.0', 'model.b4.b1.1', 'model.b4.b1.2'],
    ['model.b4.b2.0', 'model.b4.b2.1', 'model.b4.b2.2'],
    ['model.b4.b2.3', 'model.b4.b2.4', 'model.b4.b2.5'],
    ['model.b4.b3.0', 'model.b4.b3.1', 'model.b4.b3.2'],
    ['model.b4.b3.3', 'model.b4.b3.4', 'model.b4.b3.5'],
    ['model.b4.b3.6', 'model.b4.b3.7', 'model.b4.b3.8'],
    ['model.b4.b4.1', 'model.b4.b4.2', 'model.b4.b4.3'],

    ['model.c4.b1.0', 'model.c4.b1.1', 'model.c4.b1.2'],
    ['model.c4.b2.0', 'model.c4.b2.1', 'model.c4.b2.2'],
    ['model.c4.b2.3', 'model.c4.b2.4', 'model.c4.b2.5'],
    ['model.c4.b3.0', 'model.c4.b3.1', 'model.c4.b3.2'],
    ['model.c4.b3.3', 'model.c4.b3.4', 'model.c4.b3.5'],
    ['model.c4.b3.6', 'model.c4.b3.7', 'model.c4.b3.8'],
    ['model.c4.b4.1', 'model.c4.b4.2', 'model.c4.b4.3'],

    ['model.d4.b1.0', 'model.d4.b1.1', 'model.d4.b1.2'],
    ['model.d4.b2.0', 'model.d4.b2.1', 'model.d4.b2.2'],
    ['model.d4.b2.3', 'model.d4.b2.4', 'model.d4.b2.5'],
    ['model.d4.b3.0', 'model.d4.b3.1', 'model.d4.b3.2'],
    ['model.d4.b3.3', 'model.d4.b3.4', 'model.d4.b3.5'],
    ['model.d4.b3.6', 'model.d4.b3.7', 'model.d4.b3.8'],
    ['model.d4.b4.1', 'model.d4.b4.2', 'model.d4.b4.3'],

    ['model.e4.b1.0', 'model.e4.b1.1', 'model.e4.b1.2'],
    ['model.e4.b2.0', 'model.e4.b2.1', 'model.e4.b2.2'],
    ['model.e4.b2.3', 'model.e4.b2.4', 'model.e4.b2.5'],
    ['model.e4.b3.0', 'model.e4.b3.1', 'model.e4.b3.2'],
    ['model.e4.b3.3', 'model.e4.b3.4', 'model.e4.b3.5'],
    ['model.e4.b3.6', 'model.e4.b3.7', 'model.e4.b3.8'],
    ['model.e4.b4.1', 'model.e4.b4.2', 'model.e4.b4.3'],

    ['model.a5.b1.0', 'model.a5.b1.1', 'model.a5.b1.2'],
    ['model.a5.b2.0', 'model.a5.b2.1', 'model.a5.b2.2'],
    ['model.a5.b2.3', 'model.a5.b2.4', 'model.a5.b2.5'],
    ['model.a5.b3.0', 'model.a5.b3.1', 'model.a5.b3.2'],
    ['model.a5.b3.3', 'model.a5.b3.4', 'model.a5.b3.5'],
    ['model.a5.b3.6', 'model.a5.b3.7', 'model.a5.b3.8'],
    ['model.a5.b4.1', 'model.a5.b4.2', 'model.a5.b4.3'],

    ['model.b5.b1.0', 'model.b5.b1.1', 'model.b5.b1.2'],
    ['model.b5.b2.0', 'model.b5.b2.1', 'model.b5.b2.2'],
    ['model.b5.b2.3', 'model.b5.b2.4', 'model.b5.b2.5'],
    ['model.b5.b3.0', 'model.b5.b3.1', 'model.b5.b3.2'],
    ['model.b5.b3.3', 'model.b5.b3.4', 'model.b5.b3.5'],
    ['model.b5.b3.6', 'model.b5.b3.7', 'model.b5.b3.8'],
    ['model.b5.b4.1', 'model.b5.b4.2', 'model.b5.b4.3']
]


## Quantization using Static Quantization

In [180]:
from models.quantized_googlenet import Quantized_Googlenet

quantized_model = Quantized_Googlenet(model)

quantized_model.eval()

quantized_model.qconfig = torch.ao.quantization.get_default_qconfig('x86')

quantized_model_fused = torch.ao.quantization.fuse_modules(quantized_model, modules_to_fuse)

quantized_model = torch.ao.quantization.prepare(quantized_model_fused)

model_int8 = torch.ao.quantization.convert(quantized_model)


AttributeError: 'QConfigMapping' object has no attribute '_fields'

In [None]:
model_int8

Quantized_Googlenet(
  (model): GoogleNet(
    (prelayer): Sequential(
      (0): QuantizedConvReLU2d(3, 64, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1))
      (1): Identity()
      (2): Identity()
      (3): QuantizedConvReLU2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1))
      (4): Identity()
      (5): Identity()
      (6): QuantizedConvReLU2d(64, 192, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1))
      (7): Identity()
      (8): Identity()
    )
    (a3): Inception(
      (b1): Sequential(
        (0): QuantizedConvReLU2d(192, 64, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
        (1): Identity()
        (2): Identity()
      )
      (b2): Sequential(
        (0): QuantizedConvReLU2d(192, 96, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
        (1): Identity()
        (2): Identity()
        (3): QuantizedConvReLU2d(96, 128, kernel_size=(3, 3), stride=

In [None]:
torch.save(model_int8.state_dict(), 'quantized_googlenet.pth')