In [135]:
import os
import sys
import time
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets
import torchvision.transforms as transforms

# Set up warnings
import warnings
warnings.filterwarnings(
    action='ignore',
    category=DeprecationWarning,
    module=r'.*'
)
warnings.filterwarnings(
    action='default',
    module=r'torch.ao.quantization'
)

# Specify random seed for repeatable results
torch.manual_seed(191009)

<torch._C.Generator at 0x18b4838d0f0>

In [136]:
from torch.nn.utils import prune
from collections import OrderedDict
from models.googlenet import GoogleNet

model = GoogleNet()
model.load_state_dict(torch.load('googlenet-196-best.pth', map_location=torch.device('cpu')))
print(model)


GoogleNet(
  (prelayer): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (7): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
  )
  (a3): Inception(
    (b1): Sequential(
      (0): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (b2): Sequential(
      (0): Conv2d(192, 96, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(96, eps=1

In [137]:
model.state_dict().keys()

odict_keys(['prelayer.0.weight', 'prelayer.1.weight', 'prelayer.1.bias', 'prelayer.1.running_mean', 'prelayer.1.running_var', 'prelayer.1.num_batches_tracked', 'prelayer.3.weight', 'prelayer.4.weight', 'prelayer.4.bias', 'prelayer.4.running_mean', 'prelayer.4.running_var', 'prelayer.4.num_batches_tracked', 'prelayer.6.weight', 'prelayer.7.weight', 'prelayer.7.bias', 'prelayer.7.running_mean', 'prelayer.7.running_var', 'prelayer.7.num_batches_tracked', 'a3.b1.0.weight', 'a3.b1.0.bias', 'a3.b1.1.weight', 'a3.b1.1.bias', 'a3.b1.1.running_mean', 'a3.b1.1.running_var', 'a3.b1.1.num_batches_tracked', 'a3.b2.0.weight', 'a3.b2.0.bias', 'a3.b2.1.weight', 'a3.b2.1.bias', 'a3.b2.1.running_mean', 'a3.b2.1.running_var', 'a3.b2.1.num_batches_tracked', 'a3.b2.3.weight', 'a3.b2.3.bias', 'a3.b2.4.weight', 'a3.b2.4.bias', 'a3.b2.4.running_mean', 'a3.b2.4.running_var', 'a3.b2.4.num_batches_tracked', 'a3.b3.0.weight', 'a3.b3.0.bias', 'a3.b3.1.weight', 'a3.b3.1.bias', 'a3.b3.1.running_mean', 'a3.b3.1.runni

In [138]:
model.named_modules

<bound method Module.named_modules of GoogleNet(
  (prelayer): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (7): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
  )
  (a3): Inception(
    (b1): Sequential(
      (0): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (b2): Sequential(
      (0): Conv2d(192, 96, kernel_size=(1, 1), stride=(1

In [139]:
'''for name, module in model.named_modules():
    for mini_name, mini_module in module.named_modules():
        # prune 20% of connections in all 2D-conv layers
        if isinstance(mini_module, torch.nn.Conv2d):
            print(f"Sparsity in {mini_name} before: {100. * float(torch.sum(mini_module.weight == 0))/ float(mini_module.weight.nelement()):.2f}%")
            prune.l1_unstructured(mini_module, name='weight', amount=0.4)
            prune.remove(mini_module, 'weight')
            print(f"Sparsity in {mini_name} after: {100. * float(torch.sum(mini_module.weight == 0))/ float(mini_module.weight.nelement()):.2f}%")
        # prune 40% of connections in all linear layers

print(dict(model.named_buffers()).keys())  # to verify that all masks exist'''

'for name, module in model.named_modules():\n    for mini_name, mini_module in module.named_modules():\n        # prune 20% of connections in all 2D-conv layers\n        if isinstance(mini_module, torch.nn.Conv2d):\n            print(f"Sparsity in {mini_name} before: {100. * float(torch.sum(mini_module.weight == 0))/ float(mini_module.weight.nelement()):.2f}%")\n            prune.l1_unstructured(mini_module, name=\'weight\', amount=0.4)\n            prune.remove(mini_module, \'weight\')\n            print(f"Sparsity in {mini_name} after: {100. * float(torch.sum(mini_module.weight == 0))/ float(mini_module.weight.nelement()):.2f}%")\n        # prune 40% of connections in all linear layers\n\nprint(dict(model.named_buffers()).keys())  # to verify that all masks exist'

## Quantization using Static Quantization

In [147]:
from models.quantized_googlenet import Quantized_Googlenet

quantized_model = Quantized_Googlenet(model)

quantized_model.eval()

quantized_model.qconfig = torch.ao.quantization.get_default_qconfig('x86')

quantized_model = torch.ao.quantization.prepare(quantized_model)

model_int8 = torch.ao.quantization.convert(quantized_model)




In [148]:
model_int8

Quantized_Googlenet(
  (model): GoogleNet(
    (prelayer): Sequential(
      (0): QuantizedConv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), bias=False)
      (1): QuantizedBatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), bias=False)
      (4): QuantizedBatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): QuantizedConv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), bias=False)
      (7): QuantizedBatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): ReLU(inplace=True)
    )
    (a3): Inception(
      (b1): Sequential(
        (0): QuantizedConv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
        (1): Quantiz

In [None]:
torch.save(model_int8.state_dict(), 'quantized_googlenet.pth')