In [14]:

import torch
import copy
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.ao.quantization.quantize_fx as quantize_fx
from src.vgg import *
from src.util import *
import os

def quantize_model(model):
    quantize_model = copy.deepcopy(model)
    quantize_model.qconfig = torch.quantization.get_default_qconfig('x86')
    quantize_model = torch.quantization.QuantWrapper(quantize_model)
    torch.quantization.prepare(quantize_model, inplace=True)
    torch.quantization.convert(quantize_model, inplace=True)
    return quantize_model

def get_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    return os.path.getsize("temp.p")/1e6
    

model_path = "./models"

vgg11 = vgg('vgg11'); vgg11.load_state_dict(torch.load(f'{model_path}/vgg11.pth')); vgg11.eval()
vgg13 = vgg('vgg13'); vgg13.load_state_dict(torch.load(f'{model_path}/vgg13.pth')); vgg13.eval()
vgg16 = vgg('vgg16'); vgg16.load_state_dict(torch.load(f'{model_path}/vgg16.pth')); vgg16.eval()
vgg19 = vgg('vgg19'); vgg19.load_state_dict(torch.load(f'{model_path}/vgg19.pth')); vgg19.eval()

qat_vgg11 = quantize_model(vgg11); qat_vgg11.load_state_dict(torch.load(f'{model_path}/qat_vgg11.pth'))
qat_vgg13 = quantize_model(vgg13); qat_vgg13.load_state_dict(torch.load(f'{model_path}/qat_vgg13.pth'))
qat_vgg16 = quantize_model(vgg16); qat_vgg16.load_state_dict(torch.load(f'{model_path}/qat_vgg16.pth'))
qat_vgg19 = quantize_model(vgg19); qat_vgg19.load_state_dict(torch.load(f'{model_path}/qat_vgg19.pth'))

ptq_vgg11 = quantize_model(vgg11); qat_vgg11.load_state_dict(torch.load(f'{model_path}/ptq_static_vgg11.pth'))
# x_ptq_vgg11 = quantize_model(vgg11); qat_vgg11.load_state_dict(torch.load(f'{model_path}/x_ptq_static_vgg11.pth'))
# ptq_vgg13 = quantize_model(vgg13); qat_vgg11.load_state_dict(torch.load(f'{model_path}/ptq_static_vgg13.pth'))
# ptq_vgg13 = quantize_model(vgg13); qat_vgg11.load_state_dict(torch.load(f'{model_path}/ptq_static_vgg13.pth'))
# ptq_vgg16 = quantize_model(vgg16); qat_vgg11.load_state_dict(torch.load(f'{model_path}/ptq_static_vgg16.pth'))
# ptq_vgg19 = quantize_model(vgg19); qat_vgg11.load_state_dict(torch.load(f'{model_path}/ptq_static_vgg19.pth'))

models = {
    # 'vgg11': vgg11,
    # 'vgg13': vgg13,
    # 'vgg16': vgg16,
    # 'vgg19': vgg19,
    
    'ptq_vgg11': ptq_vgg11,
    # 'x_ptq_vgg11': x_ptq_vgg11,
    # 'ptq_vgg13': ptq_vgg13,
    # 'ptq_vgg16': ptq_vgg16,
    # 'ptq_vgg19': ptq_vgg19,
    
    # 'qat_vgg11': qat_vgg11,
    # 'qat_vgg13': qat_vgg13,
    # 'qat_vgg16': qat_vgg16,
    # 'qat_vgg19': qat_vgg19,
    
}

In [15]:
# measure model size
scale = 1024**2
for model_name, model in models.items():
    model_size = get_size_of_model(model)
    print(f'{model_name} model size:\t {model_size} MB')

ptq_vgg11 model size:	 9.351926 MB


In [16]:
# measure inference latency
for model_name, model in models.items():
    elapsed_time = measure_inference_latency(model, 'cpu')
    print(f'{model_name} elapsed time:\t {elapsed_time:4f}')

ptq_vgg11 elapsed time:	 0.000932


In [17]:
# measure accuracy

data_path = "/workspace/shared/data"
test_dataset = datasets.CIFAR10(root=data_path, train=False, transform=transforms.ToTensor(), download=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

for model_name, model in models.items():
    accuracy = measure_accuracy(model, test_loader, 'cpu')
    print(f'{model_name} accuracy:\t {accuracy:.4f}')

Files already downloaded and verified
ptq_vgg11 accuracy:	 0.1062
