In [1]:
import torch
from torch import nn

from src.utils import load_model
from src.data_loader import get_cifar10_loader
from src.train import train_model
from src.model import ResNet, BasicBlock, resnet110
from src.evaluate import evaluate
from src.utils import count_total_parameters, quantize_model
import torch
from src.utils import measure_inference_time


In [2]:
# --- Parameters ---
# Eager-mode int8 quantized kernels in PyTorch run on CPU, so everything stays on CPU.
device = torch.device("cpu")

# Float checkpoint to quantize (alternative: the unpruned baseline below).
# model_path = "resnet110_pretrained.pth"
model_path = "pruned_34-30_resnet110_mps.pth"

# Quantized engine: 'qnnpack' targets ARM CPUs; 'fbgemm'/'x86' target x86 CPUs.
backend = 'qnnpack'

# Data-loader batch size and (optional) fine-tuning hyperparameters.
batch_size = 128
learning_rate = 0.001
num_epochs = 1


In [3]:
# Float model that will be quantized below.
model = load_model(model_path, device=device)

# Training utilities for optional fine-tuning; not used by any of the cells shown here.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# CIFAR-10 validation data: the full split for evaluation, a 1000-sample subset for calibration.
val_loader = get_cifar10_loader('val', batch_size=batch_size)
val_loader_subset = get_cifar10_loader('val', batch_size=batch_size, subset_size=1000)

In [4]:
# Eager-mode post-training static quantization:
# fuse -> attach qconfig -> insert observers -> calibrate -> convert to int8.
import copy

model.to(device=device)

torch.backends.quantized.engine = backend

# Quantize a deep copy rather than an alias: fuse_model() mutates its module in place, so
# aliasing would silently turn `model` (the float baseline) into the fused model, and
# re-running this cell would try to fuse an already-fused model.
model_fp32 = copy.deepcopy(model)
# Eval mode is required for post-training fusion (BN folded into conv rather than QAT modules).
model_fp32.eval()

model_fp32.fuse_model()

# Default observer/dtype config for the selected engine.
# 'qnnpack' targets ARM CPUs; use 'fbgemm'/'x86' on x86 hosts.
model_fp32.qconfig = torch.quantization.get_default_qconfig(backend)

# Insert observers that record activation ranges during calibration.
# inplace=False returns a prepared copy and leaves model_fp32 untouched.
model_fp32_prepared = torch.quantization.prepare(model_fp32, inplace=False)

# Calibration: `evaluate` is used only to drive forward passes through the observers.
# Its printed accuracy is for the prepared (still float) model.
# NOTE(review): the calibration subset comes from the same validation split used for the
# final evaluation; a held-out/train subset would avoid leaking eval statistics.
evaluate(model_fp32_prepared, val_loader_subset, device)

# Replace observed float modules with their quantized int8 implementations.
model_quantized = torch.quantization.convert(model_fp32_prepared)

Validation Accuracy: 85.90%, Avg Loss: 0.4480, Time: 3.38s


In [5]:
# Accuracy / average loss of the int8 model on the full validation split.
quantized_metrics = evaluate(model_quantized, val_loader, device)
quantized_metrics

Validation Accuracy: 85.61%, Avg Loss: 0.4643, Time: 19.17s


(85.61, 0.4642594927787781)

In [7]:
# Per-batch inference latency: float model first, then the quantized model, on the same loader.
time_float, time_quant = (
    measure_inference_time(candidate, val_loader, device=device)
    for candidate in (model, model_quantized)
)

print(f"Average inference time per batch (float model): {time_float:.4f} seconds")
print(f"Average inference time per batch (quantized model): {time_quant:.4f} seconds")


Average inference time per batch (float model): 0.2235 seconds
Average inference time per batch (quantized model): 0.2024 seconds
