In [82]:
import os
import sys
import time
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets
import torchvision.transforms as transforms

# Set up warnings
import warnings
warnings.filterwarnings(
    action='ignore',
    category=DeprecationWarning,
    module=r'.*'
)
warnings.filterwarnings(
    action='default',
    module=r'torch.ao.quantization'
)

# Specify random seed for repeatable results
torch.manual_seed(191009)

<torch._C.Generator at 0x151fa869170>

In [83]:
from torch.nn.utils import prune
from collections import OrderedDict
from models.googlenet import GoogleNet

model = GoogleNet()
model.load_state_dict(torch.load('googlenet-196-best.pth', map_location=torch.device('cpu')))
print(model)


GoogleNet(
  (prelayer): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (7): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
  )
  (a3): Inception(
    (b1): Sequential(
      (0): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (b2): Sequential(
      (0): Conv2d(192, 96, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(96, eps=1

In [84]:
model.state_dict().keys()

odict_keys(['prelayer.0.weight', 'prelayer.1.weight', 'prelayer.1.bias', 'prelayer.1.running_mean', 'prelayer.1.running_var', 'prelayer.1.num_batches_tracked', 'prelayer.3.weight', 'prelayer.4.weight', 'prelayer.4.bias', 'prelayer.4.running_mean', 'prelayer.4.running_var', 'prelayer.4.num_batches_tracked', 'prelayer.6.weight', 'prelayer.7.weight', 'prelayer.7.bias', 'prelayer.7.running_mean', 'prelayer.7.running_var', 'prelayer.7.num_batches_tracked', 'a3.b1.0.weight', 'a3.b1.0.bias', 'a3.b1.1.weight', 'a3.b1.1.bias', 'a3.b1.1.running_mean', 'a3.b1.1.running_var', 'a3.b1.1.num_batches_tracked', 'a3.b2.0.weight', 'a3.b2.0.bias', 'a3.b2.1.weight', 'a3.b2.1.bias', 'a3.b2.1.running_mean', 'a3.b2.1.running_var', 'a3.b2.1.num_batches_tracked', 'a3.b2.3.weight', 'a3.b2.3.bias', 'a3.b2.4.weight', 'a3.b2.4.bias', 'a3.b2.4.running_mean', 'a3.b2.4.running_var', 'a3.b2.4.num_batches_tracked', 'a3.b3.0.weight', 'a3.b3.0.bias', 'a3.b3.1.weight', 'a3.b3.1.bias', 'a3.b3.1.running_mean', 'a3.b3.1.runni

In [85]:
model.named_modules

<bound method Module.named_modules of GoogleNet(
  (prelayer): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (7): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
  )
  (a3): Inception(
    (b1): Sequential(
      (0): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (b2): Sequential(
      (0): Conv2d(192, 96, kernel_size=(1, 1), stride=(1

In [86]:
'''for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        print(f"Sparsity in {name} before: {100. * float(torch.sum(module.weight == 0))/ float(module.weight.nelement()):.2f}%")
        prune.l1_unstructured(module, name='weight', amount=0.5)
        prune.remove(module, 'weight')
        print(f"Sparsity in {name} after: {100. * float(torch.sum(module.weight == 0))/ float(module.weight.nelement()):.2f}%")
    # prune 40% of connections in all linear layers

print(dict(model.named_buffers()).keys())  # to verify that all masks exist'''

'for name, module in model.named_modules():\n    if isinstance(module, torch.nn.Conv2d):\n        print(f"Sparsity in {name} before: {100. * float(torch.sum(module.weight == 0))/ float(module.weight.nelement()):.2f}%")\n        prune.l1_unstructured(module, name=\'weight\', amount=0.5)\n        prune.remove(module, \'weight\')\n        print(f"Sparsity in {name} after: {100. * float(torch.sum(module.weight == 0))/ float(module.weight.nelement()):.2f}%")\n    # prune 40% of connections in all linear layers\n\nprint(dict(model.named_buffers()).keys())  # to verify that all masks exist'

In [87]:
torch.save(model.state_dict(), 'pruned_googlenet.pth')

## Structured Pruning

In [88]:
import torch
import torchvision
import torch_pruning as tp

In [89]:
imp = tp.importance.MagnitudeImportance(p=2)

ignored_layers = []
for m in model.modules():
  if isinstance(m, torch.nn.Linear) and m.out_features == 1000: # ignore the classifier
    ignored_layers.append(m)

round_to = None
if isinstance( model, torchvision.models.vision_transformer.VisionTransformer):
  round_to = model.encoder.layers[0].num_heads # model-specific restriction

pruner = tp.pruner.MagnitudePruner(
    model = model,
    example_inputs = torch.randn(1,3,224,224),
    importance = imp,     # Importance Estimator
    global_pruning=False, # Please refer to Page 9 of https://www.cs.princeton.edu/courses/archive/spring21/cos598D/lectures/pruning.pdf
    ch_sparsity = 0.5,    # global sparsity for all layers
    #ch_sparsity_dict = {model.conv1: 0.2}, # manually set the sparsity of model.conv1
    iterative_steps = 1,  # number of steps to achieve the target ch_sparsity.
    ignored_layers = ignored_layers,        # ignore some layers such as the finall linear classifier
    round_to = round_to,  # round channels
    #unwrapped_parameters=[ (model.features[1][1].layer_scale, 0), (model.features[5][4].layer_scale, 0) ],
)

In [90]:
# Model size before pruning
base_macs, base_nparams = tp.utils.count_ops_and_params(model, torch.randn(1,3,32,32))
pruner.step()

# modify some inferece-related attributes if necessary
if isinstance(
    model, torchvision.models.vision_transformer.VisionTransformer
):  # ViT relies on the hidden_dim attribute for forwarding, so we have to modify this variable after pruning
    model.hidden_dim = model.conv_proj.out_channels

# Parameter & MACs Counter
pruned_macs, pruned_nparams = tp.utils.count_ops_and_params(model, torch.randn(1,3,32,32))
print("The pruned model:")
print(model)
print("Summary:")
print("Params: {:.2f} M => {:.2f} M".format(base_nparams/1e6, pruned_nparams/1e6))
print("MACs: {:.2f} G => {:.2f} G".format(base_macs/1e9, pruned_macs/1e9))

# Test Forward
output = model(torch.randn(1,3,32,32))
print("Output.shape: ", output.shape)

# Test Backward
loss = torch.nn.functional.cross_entropy(output, torch.randint(1, 10, (1,)))
loss.backward()

The pruned model:
GoogleNet(
  (prelayer): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): Conv2d(32, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (7): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
  )
  (a3): Inception(
    (b1): Sequential(
      (0): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (b2): Sequential(
      (0): Conv2d(96, 48, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNo

In [92]:
torch.save(model.state_dict(), 'structured_pruned_googlenet.pth')

## Quantization using Static Quantization

In [16]:
modules_to_fuse = [
    ['prelayer.0', 'prelayer.1', 'prelayer.2'],
    ['prelayer.3', 'prelayer.4', 'prelayer.5'],
    ['prelayer.6', 'prelayer.7', 'prelayer.8'],
    ]

In [21]:
import torchvision.models as models
from torch.profiler import profile, record_function, ProfilerActivity
import argparse

from matplotlib import pyplot as plt

import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from conf import settings
from utils import get_network, get_test_dataloader

cifar100_test_loader = get_test_dataloader(
        settings.CIFAR100_TRAIN_MEAN,
        settings.CIFAR100_TRAIN_STD,
        #settings.CIFAR100_PATH,
        num_workers=4,
        batch_size=16,
    )

Files already downloaded and verified


In [22]:
from models.quantized_googlenet import Quantized_Googlenet

quantized_model = Quantized_Googlenet()
quantized_model.load_state_dict(torch.load('googlenet-196-best.pth', map_location=torch.device('cpu')))
quantized_model.eval()
torch.quantization.fuse_modules(quantized_model, modules_to_fuse, inplace=True)

# Prepare
quantized_model.train()
quantized_model.prelayer.qconfig = torch.ao.quantization.get_default_qconfig('x86')
torch.quantization.prepare_qat(quantized_model, inplace=True)



"""Training Loop"""
n_epochs = 10
opt = torch.optim.SGD(quantized_model.parameters(), lr=0.1)
loss_fn = lambda out, tgt: torch.pow(tgt-out, 2).mean()
for epoch in range(n_epochs):
    for n_iter, (image, label) in enumerate(cifar100_test_loader):
        out = quantized_model(image)
        loss = loss_fn(out, torch.rand_like(out))
        opt.zero_grad()
        loss.backward()
        opt.step()

# Convert
quantized_model.eval()
torch.quantization.convert(quantized_model, inplace=True)




KeyboardInterrupt: 

In [18]:
quantized_model

Quantized_Googlenet(
  (prelayer): Sequential(
    (0): QuantizedConvReLU2d(3, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.009797699749469757, zero_point=0, padding=(1, 1))
    (1): Identity()
    (2): Identity()
    (3): QuantizedConvReLU2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.01079961284995079, zero_point=0, padding=(1, 1))
    (4): Identity()
    (5): Identity()
    (6): QuantizedConvReLU2d(64, 192, kernel_size=(3, 3), stride=(1, 1), scale=0.006309532560408115, zero_point=0, padding=(1, 1))
    (7): Identity()
    (8): Identity()
  )
  (a3): Inception(
    (b1): Sequential(
      (0): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (b2): Sequential(
      (0): Conv2d(192, 96, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
     

In [19]:
torch.save(quantized_model.state_dict(), 'quantized_googlenet1.pth')

In [38]:





correct_1 = 0.0
correct_5 = 0.0
total = 0
total_time = 0

with torch.no_grad():
    for n_iter, (image, label) in enumerate(cifar100_test_loader):
        print("iteration: {}\ttotal {} iterations".format(n_iter + 1, len(cifar100_test_loader)))
        print(image.shape)

        # Measure time
        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True) as prof:
            with record_function("model_inference"):
                output = model(image)
        
        #torch.cuda.synchronize()

        _, pred = output.topk(5, 1, largest=True, sorted=True)

        label = label.view(label.size(0), -1).expand_as(pred)
        correct = pred.eq(label).float()

        #compute top 5
        correct_5 += correct[:, :5].sum()

        #compute top1
        correct_1 += correct[:, :1].sum()
print()
print("Top 1 err: ", 1 - correct_1 / len(cifar100_test_loader.dataset))
print("Top 5 err: ", 1 - correct_5 / len(cifar100_test_loader.dataset))
print(f"Average inference time per image:miliseconds")
print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))

iteration: 1	total 625 iterations
torch.Size([16, 3, 32, 32])


  warn("CUDA is not available, disabling CUDA profiling")


iteration: 2	total 625 iterations
torch.Size([16, 3, 32, 32])
iteration: 3	total 625 iterations
torch.Size([16, 3, 32, 32])
iteration: 4	total 625 iterations
torch.Size([16, 3, 32, 32])
iteration: 5	total 625 iterations
torch.Size([16, 3, 32, 32])


KeyboardInterrupt: 

Test code