In [13]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune
import torch.optim as optim
import common_utils as utils
import wandb
from collections import defaultdict
import torch_pruning as tp
import gc
import matplotlib.pyplot as plt
import pickle
import torch.quantization
import copy
import os
print(utils.device)
utils.device = "cpu" # For quantization we use cpu
print(utils.device)

cpu
cpu


In [6]:
transform = transforms.Compose(
    [
    transforms.Resize((96, 96)),
    transforms.ToTensor(),
     transforms.Normalize((0.5), (0.5))
     ])


trainset = torchvision.datasets.FER2013(root='./', split="train",
                                        transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=16,
                                          shuffle=True, num_workers=8)

testset = torchvision.datasets.FER2013(root='./', split="test",
                                       transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=16,
                                         shuffle=False, num_workers=8)

In [7]:
with open("pruned_model/checkpoint_1.pth", "rb") as file:
    pruned_model = torch.load(file, weights_only=False, map_location=utils.device)
pruned_model.to(utils.device)

base_model = utils.BaseModel()
base_model.load_state_dict(torch.load("base_model/checkpoint_6.pth"))
base_model.to(utils.device)

  base_model.load_state_dict(torch.load("base_model/checkpoint_6.pth"))


BaseModel(
  (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(256, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv5): Conv2d(1024, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv6): Conv2d(2048, 4192, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=4192, out_features=2048, bias=True)
  (fc2): Linear(in_features=2048, out_features=1024, bias=True)
  (fc3): Linear(in_features=1024, out_features=256, bias=True)
  (fc4): Linear(in_features=256, out_features=7, bias=True)
)

In [8]:
backend = "fbgemm"
quantized_base_model= utils.QuantizedModelWrapper(copy.deepcopy(base_model), None)
quantized_base_model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.quantization.prepare(quantized_base_model, inplace=True)
utils.test(quantized_base_model, testloader, 500)
torch.quantization.convert(quantized_base_model, inplace=True)



QuantizedModelWrapper(
  (model): BaseModel(
    (conv1): QuantizedConv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), scale=0.020121701061725616, zero_point=67, padding=(1, 1))
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv2): QuantizedConv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.011837604455649853, zero_point=76, padding=(1, 1))
    (conv3): QuantizedConv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), scale=0.008719135075807571, zero_point=89, padding=(1, 1))
    (conv4): QuantizedConv2d(256, 1024, kernel_size=(3, 3), stride=(1, 1), scale=0.004681930877268314, zero_point=93, padding=(1, 1))
    (conv5): QuantizedConv2d(1024, 2048, kernel_size=(3, 3), stride=(1, 1), scale=0.004713323898613453, zero_point=90, padding=(1, 1))
    (conv6): QuantizedConv2d(2048, 4192, kernel_size=(3, 3), stride=(1, 1), scale=0.011218554340302944, zero_point=69)
    (fc1): QuantizedLinear(in_features=4192, out_features=2048, scale=0.029412638396024

In [9]:
backend = "fbgemm"
quantized_pruned_model= utils.QuantizedModelWrapper(copy.deepcopy(pruned_model), None)
quantized_pruned_model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.quantization.prepare(quantized_pruned_model, inplace=True)
utils.test(quantized_pruned_model, testloader, 500)
torch.quantization.convert(quantized_pruned_model, inplace=True)



QuantizedModelWrapper(
  (model): BaseModel(
    (conv1): QuantizedConv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), scale=0.020843572914600372, zero_point=67, padding=(1, 1))
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv2): QuantizedConv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), scale=0.013500155881047249, zero_point=73, padding=(1, 1))
    (conv3): QuantizedConv2d(32, 102, kernel_size=(3, 3), stride=(1, 1), scale=0.00780163798481226, zero_point=83, padding=(1, 1))
    (conv4): QuantizedConv2d(102, 409, kernel_size=(3, 3), stride=(1, 1), scale=0.004067199304699898, zero_point=88, padding=(1, 1))
    (conv5): QuantizedConv2d(409, 409, kernel_size=(3, 3), stride=(1, 1), scale=0.003621872281655669, zero_point=81, padding=(1, 1))
    (conv6): QuantizedConv2d(409, 838, kernel_size=(3, 3), stride=(1, 1), scale=0.00809921883046627, zero_point=48)
    (fc1): QuantizedLinear(in_features=838, out_features=614, scale=0.024973023682832718, zero

In [14]:
# utils.device = "cuda"
utils.benchmark_model(base_model, testloader, 100)
utils.benchmark_model(quantized_base_model, testloader, 100)
utils.benchmark_model(pruned_model, testloader, 100)
utils.benchmark_model(quantized_pruned_model, testloader, 100)

Total Images Processed :  112
Total Time :  1.8826067447662354 seconds
Average Time:  0.016808988792555674 seconds
Average FPS:  59.49197850871777
Total Images Processed :  112
Total Time :  0.7001886367797852 seconds
Average Time:  0.006251684256962368 seconds
Average FPS:  159.9568946378444
Total Images Processed :  112
Total Time :  0.30153679847717285 seconds
Average Time:  0.0026922928435461862 seconds
Average FPS:  371.43061996288554
Total Images Processed :  112
Total Time :  0.21832847595214844 seconds
Average Time:  0.0019493613924298967 seconds
Average FPS:  512.9885119728831


{'Total Images': 112,
 'tot_time': 0.21832847595214844,
 'mean_time': 0.0019493613924298967,
 'mean_fps': 512.9885119728831}

In [12]:
print(utils.test(base_model, testloader))
print(utils.test(quantized_base_model, testloader))
print(utils.test(pruned_model, testloader))
print(utils.test(quantized_pruned_model, testloader))

(0.08092272355550509, 57.32794650320424)
(0.0814350373019545, 56.96572861521315)
(0.07692547515761258, 56.85427695736974)
(0.07708757794411543, 56.32488158261354)


In [15]:
torch.save(quantized_base_model, os.path.join("base_model", "quantized.pth"))
torch.save(quantized_pruned_model, os.path.join("pruned_model", "quantized.pth"))
