In [1]:
import copy

import torch
import torch.quantization
from torch import nn

from src.data_loader import get_cifar10_loader
from src.evaluate import count_total_parameters, evaluate, measure_inference_time
from src.model import BasicBlock, ResNet, resnet110
from src.train import train_model
from src.utils import (
    load_model,
    load_quantized_model,
    quantize_model,
    save_model,
    save_quantized_model,
)


In [2]:
# Run configuration: every value a reader may want to tune lives in this cell.
batch_size = 128  # batch size handed to the CIFAR-10 validation loaders

backend = "qnnpack"  # quantized engine name; also used to look up the default qconfig
model_path = "models/resnet110_baseline_30_mps.pth"  # checkpoint that gets quantized
device = torch.device("cpu")  # device passed to model loading, evaluation and timing


In [3]:
# Restore the checkpoint referenced by `model_path` onto the configured device.
model = load_model(model_path, device=device)

# Two views of the validation split:
#   val_loader        - the full split, used for the reported accuracy and timings
#   val_loader_subset - 1000 samples, used for the calibration pass of static quantization
val_loader = get_cifar10_loader(
    'val',
    batch_size=batch_size,
)
val_loader_subset = get_cifar10_loader(
    'val',
    batch_size=batch_size,
    subset_size=1000,
)

In [4]:
# Post-training static quantization: fuse -> attach qconfig -> insert observers
# -> calibrate -> convert.
torch.backends.quantized.engine = backend

# Keep `model` as the reference float network: move it to the device and switch it
# to eval mode (as before), but do the fusion and qconfig tagging on a deep copy.
# Previously `model_fp32` was just another name for `model`, so `fuse_model()` and
# the qconfig assignment rewrote `model` in place; the later "float model" timing and
# the dynamic-quantization cell then silently operated on the fused network, and
# re-running this cell called `fuse_model()` on an already-fused model.
model.to(device=device)
model.eval()

model_fp32 = copy.deepcopy(model)  # inherits device placement and eval mode from `model`
model_fp32.fuse_model()
model_fp32.qconfig = torch.quantization.get_default_qconfig(backend)

# Prepares the model for the next step i.e. calibration.
# Inserts observers in the model that will observe the activation tensors during calibration
model_fp32_prepared = torch.quantization.prepare(model_fp32, inplace=False)

# Calibration pass: running the observed model over the validation subset lets the
# observers record activation statistics. The accuracy printed here is a by-product.
evaluate(model_fp32_prepared, val_loader_subset, device)

# Replace the observed modules with their quantized counterparts (returns a new model).
model_quantized = torch.quantization.convert(model_fp32_prepared, inplace=False)

# model_quantized = quantize_model(model, val_loader_subset, device, backend=backend)

Validation Accuracy: 87.00%, Avg Loss: 0.4502, Time: 6.55s


In [5]:
# Score the statically quantized model on the full validation loader.
# The call is the cell's last expression, so its return value is displayed as well;
# judging by the printed summary it appears to be (accuracy %, average loss, elapsed seconds).
evaluate(model_quantized, val_loader, device)

Validation Accuracy: 86.65%, Avg Loss: 0.4277, Time: 16.25s


(86.65, 0.42767993969917295, 16.24708914756775)

In [6]:
# Time the float network first, then the quantized one, both over the full validation loader.
time_float, time_quant = (
    measure_inference_time(candidate, val_loader, device=device)
    for candidate in (model, model_quantized)
)

# Report both per-batch averages in a single write (same two lines on stdout).
print(
    f"Average inference time per batch (float model): {time_float:.4f} seconds\n"
    f"Average inference time per batch (quantized model): {time_quant:.4f} seconds"
)


Average inference time per batch (float model): 0.1903 seconds
Average inference time per batch (quantized model): 0.1688 seconds


In [5]:
# Persist the quantized model through the project helper; the path is relative to the
# kernel's working directory.
# NOTE(review): the on-disk format is decided inside save_quantized_model - confirm there.
save_quantized_model(model_quantized, "models/quantized_resnet110_baseline_30_cpu.pt")

In [10]:
# Dynamic quantization variant. Only nn.Linear modules are targeted, which is the
# only layer type of this ResNet that the dynamic path supports.
torch.backends.quantized.engine = backend

model.to(device=device)
model.eval()

# NOTE: this rebinds `model_quantized`, replacing whatever quantized model an
# earlier cell stored under that name.
model_quantized = torch.quantization.quantize_dynamic(
    model,
    qconfig_spec={torch.nn.Linear},
    dtype=torch.qint8,
)


In [9]:
# Load a previously saved quantized model from disk (path relative to the working directory).
# NOTE(review): this reads a "qat_..." file under notebooks/, not the file written by the
# save_quantized_model cell in this notebook - confirm this is the intended artifact.
test23412= load_quantized_model("notebooks/qat_resnet110_baseline_30_cpu.pt")

In [10]:
# Dump the module hierarchy of the loaded model as a sanity check of what was restored.
# The output is long (one entry per submodule); consider clearing it before sharing.
print(test23412)

RecursiveScriptModule(
  original_name=ResNet
  (conv1): RecursiveScriptModule(original_name=Conv2d)
  (bn1): RecursiveScriptModule(original_name=Identity)
  (layer1): RecursiveScriptModule(
    original_name=Sequential
    (0): RecursiveScriptModule(
      original_name=BasicBlock
      (ff): RecursiveScriptModule(
        original_name=QFunctional
        (activation_post_process): RecursiveScriptModule(original_name=Identity)
      )
      (conv1): RecursiveScriptModule(original_name=Conv2d)
      (bn1): RecursiveScriptModule(original_name=Identity)
      (conv2): RecursiveScriptModule(original_name=Conv2d)
      (bn2): RecursiveScriptModule(original_name=Identity)
      (shortcut): RecursiveScriptModule(original_name=Sequential)
    )
    (1): RecursiveScriptModule(
      original_name=BasicBlock
      (ff): RecursiveScriptModule(
        original_name=QFunctional
        (activation_post_process): RecursiveScriptModule(original_name=Identity)
      )
      (conv1): RecursiveScript