In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import common_utils as utils
import os
print(utils.device)
utils.device = "cpu" # For quantization we use cpu
print(utils.device)

cpu
cpu


In [2]:
transform = transforms.Compose(
    [
    transforms.Resize((96, 96)),
    transforms.ToTensor(),
     transforms.Normalize((0.5), (0.5))
     ])


trainset = torchvision.datasets.FER2013(root='./', split="train",
                                        transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=16,
                                          shuffle=True, num_workers=8)

testset = torchvision.datasets.FER2013(root='./', split="test",
                                       transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=16,
                                         shuffle=False, num_workers=8)

In [2]:
# with open("pruned_model/checkpoint_1.pth", "rb") as file:
pruned_model = torch.load("pruned_model/checkpoint_1.pth", weights_only=False, map_location=utils.device)
pruned_model.to(utils.device)

base_model = utils.BaseModel()
base_model.load_state_dict(torch.load("base_model/checkpoint_6.pth", map_location=utils.device))
base_model.to(utils.device)

BaseModel(
  (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(256, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv5): Conv2d(1024, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv6): Conv2d(2048, 4192, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=4192, out_features=2048, bias=True)
  (fc2): Linear(in_features=2048, out_features=1024, bias=True)
  (fc3): Linear(in_features=1024, out_features=256, bias=True)
  (fc4): Linear(in_features=256, out_features=7, bias=True)
)

In [4]:
quantized_base_model = utils.quantize_model(base_model, testloader)
quantized_pruned_model = utils.quantize_model(pruned_model, testloader)



In [5]:
# utils.device = "cuda"
utils.benchmark_model(base_model, testloader, 100)
utils.benchmark_model(quantized_base_model, testloader, 100)
utils.benchmark_model(pruned_model, testloader, 100)
utils.benchmark_model(quantized_pruned_model, testloader, 100)

Total Images Processed :  112
Total Time :  4.387948274612427 seconds
Average Time:  0.03917810959475381 seconds
Average FPS:  25.5244576714826
Total Images Processed :  112
Total Time :  1.053718090057373 seconds
Average Time:  0.009408197232655116 seconds
Average FPS:  106.29028869942036
Total Images Processed :  112
Total Time :  0.49348998069763184 seconds
Average Time:  0.0044061605419431415 seconds
Average FPS:  226.9549623716149
Total Images Processed :  112
Total Time :  0.28965330123901367 seconds
Average Time:  0.0025861901896340506 seconds
Average FPS:  386.66916455262765


{'Total Images': 112,
 'tot_time': 0.28965330123901367,
 'mean_time': 0.0025861901896340506,
 'mean_fps': 386.66916455262765}

In [12]:
print(utils.test(base_model, testloader))
print(utils.test(quantized_base_model, testloader))
print(utils.test(pruned_model, testloader))
print(utils.test(quantized_pruned_model, testloader))

(0.08092272355550509, 57.32794650320424)
(0.0814350373019545, 56.96572861521315)
(0.07692547515761258, 56.85427695736974)
(0.07708757794411543, 56.32488158261354)


In [None]:
trace_input = torch.randn(1, 1, 96, 96)
traced_base_model = torch.jit.trace(base_model, trace_input)
trace_pruned_model = torch.jit.trace(pruned_model, trace_input)
traced_quantized_base_model = torch.jit.trace(quantized_base_model, trace_input)
traced_quantized_pruned_model = torch.jit.trace(quantized_pruned_model, trace_input)
torch.jit.save(traced_base_model, os.path.join("base_model", "jit_traced.pth"))
torch.jit.save(trace_pruned_model, os.path.join("pruned_model", "jit_traced.pth"))
torch.jit.save(traced_quantized_base_model, os.path.join("base_model", "jit_traced_quantized.pth"))
torch.jit.save(traced_quantized_pruned_model, os.path.join("pruned_model", "jit_traced_quantized.pth"))
