In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import common_utils as utils
import os
# Show the device currently configured in common_utils, then override it.
print(utils.device)
# Pin the shared device setting to CPU for everything that follows.
# NOTE(review): this rebinds a module attribute; anything inside common_utils
# that captured the previous value at import time would not see the change — confirm.
utils.device = "cpu" # For quantization we use cpu
# Print again so the notebook output records the override.
print(utils.device)

In [None]:
# Preprocessing shared by both splits: resize to 96x96 (the spatial size the
# models are traced with at the end of this notebook), convert to a float
# tensor, then normalize with mean 0.5 / std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((96, 96)),
        transforms.ToTensor(),
        # One-element tuples: `(0.5)` is just the float 0.5, not a sequence.
        # Normalize documents per-channel sequences; a single entry matches the
        # single-channel input shape (1, 1, 96, 96) used for tracing below.
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Training split: shuffled mini-batches of 16.
trainset = torchvision.datasets.FER2013(
    root="./", split="train", transform=transform
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=16, shuffle=True, num_workers=8
)

# Test split: fixed order so repeated evaluation/benchmark runs see the same
# batches in the same sequence.
testset = torchvision.datasets.FER2013(
    root="./", split="test", transform=transform
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=16, shuffle=False, num_workers=8
)

In [None]:
# The pruned checkpoint is used directly as the model object, so it has to be
# fully unpickled (weights_only=False).
# NOTE(security): full unpickling can execute arbitrary code embedded in the
# file — only load checkpoints from a trusted source.
pruned_model = torch.load(
    "pruned_model/checkpoint_1.pth",
    weights_only=False,
    map_location=utils.device,
)
pruned_model.to(utils.device)

# The base checkpoint is fed to load_state_dict, i.e. it is a plain state
# dict. Request weights_only=True explicitly so the restricted unpickler is
# used regardless of the installed torch version's default.
base_model = utils.BaseModel()
base_model.load_state_dict(
    torch.load(
        "base_model/checkpoint_6.pth",
        weights_only=True,
        map_location=utils.device,
    )
)
# Kept as the final bare expression so the notebook still displays the model.
base_model.to(utils.device)

In [None]:
# Build a quantized counterpart of each float model, base first and pruned
# second. The test loader is handed to utils.quantize_model for both calls
# (presumably as calibration data — confirm in common_utils).
quantized_base_model, quantized_pruned_model = (
    utils.quantize_model(float_model, testloader)
    for float_model in (base_model, pruned_model)
)

In [None]:
# Benchmark every variant (float/quantized x base/pruned) with identical
# arguments so the results are directly comparable.
# The third positional argument (100) is presumably the number of
# iterations/batches to time — TODO confirm against common_utils.benchmark_model.
# Output appears in call order: base, quantized base, pruned, quantized pruned.
utils.benchmark_model(base_model, testloader, 100)
utils.benchmark_model(quantized_base_model, testloader, 100)
utils.benchmark_model(pruned_model, testloader, 100)
# Final bare expression: the notebook displays its return value if it is not None.
utils.benchmark_model(quantized_pruned_model, testloader, 100)

In [None]:
# Evaluate every variant on the test loader. Each result is printed with a
# label so the four outputs can be told apart in the cell output.
print("base model:", utils.test(base_model, testloader))
print("quantized base model:", utils.test(quantized_base_model, testloader))
print("pruned model:", utils.test(pruned_model, testloader))
print("quantized pruned model:", utils.test(quantized_pruned_model, testloader))

In [None]:
# Example input used only to record the traces: batch of 1, 1 channel, 96x96,
# matching the Resize((96, 96)) applied by the data transform.
trace_input = torch.randn(1, 1, 96, 96)

# torch.jit.trace freezes whichever train/eval behaviour is active while
# tracing (e.g. dropout, batch-norm), so put each model in eval mode first.
# nn.Module.eval() returns the module itself, so it can be passed straight on.
traced_base_model = torch.jit.trace(base_model.eval(), trace_input)
# NOTE: this name lacks the "traced_" prefix the others use; it is kept as-is
# in case other cells refer to it.
trace_pruned_model = torch.jit.trace(pruned_model.eval(), trace_input)
traced_quantized_base_model = torch.jit.trace(quantized_base_model.eval(), trace_input)
traced_quantized_pruned_model = torch.jit.trace(quantized_pruned_model.eval(), trace_input)

# Save each traced module into the directory its checkpoint was loaded from.
torch.jit.save(traced_base_model, os.path.join("base_model", "jit_traced.pth"))
torch.jit.save(trace_pruned_model, os.path.join("pruned_model", "jit_traced.pth"))
torch.jit.save(traced_quantized_base_model, os.path.join("base_model", "jit_traced_quantized.pth"))
torch.jit.save(traced_quantized_pruned_model, os.path.join("pruned_model", "jit_traced_quantized.pth"))
