In [1]:
import torch
from torch import nn
from src.evaluate import measure_inference_time
from src.utils import load_model, save_quantized_model, load_quantized_model
from src.data_loader import get_cifar10_loader
from src.train import train_model
from src.model import ResNet, BasicBlock, resnet110
from src.evaluate import evaluate, count_total_parameters
from src.utils import quantize_model, save_model
import torch
import torch.quantization


In [2]:
# Parameters
# Mini-batch size shared by every data loader below.
batch_size = 128

# Kernel backend assigned to torch.backends.quantized.engine in later cells.
backend = 'qnnpack'

# All inference in this notebook is pinned to the CPU.
device = torch.device("cpu")

# Float checkpoint to quantize (per its filename: pruned + KD ResNet-110 trained on MPS).
model_path = "models/pruned_45-30_kd_10_resnet110_mps.pth"


In [3]:
# Float checkpoint that the static-quantization cell below starts from.
model = load_model(model_path, device=device)

# Load data: the full validation split (accuracy / timing) and a smaller
# subset of it that the static-quantization cell uses for calibration
# (presumably 1000 samples, per the subset_size argument).
calibration_subset_size = 1000
val_loader = get_cifar10_loader('val', batch_size=batch_size)
val_loader_subset = get_cifar10_loader(
    'val',
    batch_size=batch_size,
    subset_size=calibration_subset_size,
)

# Static Quantization

In [4]:
import copy

model.to(device=device)

torch.backends.quantized.engine = backend

# Work on a deep copy rather than an alias: fuse_model() rewrites the network
# in place. With a plain alias, `model` itself would end up fused and tagged
# with a qconfig, and re-running this cell would call fuse_model() on an
# already-fused network (presumably an error). The copy keeps `model` as the
# untouched float baseline for the timing and dynamic-quantization cells.
model_fp32 = copy.deepcopy(model)
model_fp32.eval()

# Project-defined fusion (presumably conv/bn/relu groups) required before prepare().
model_fp32.fuse_model()

# Default qconfig for the chosen backend (activations observed with HistogramObserver).
model_fp32.qconfig = torch.quantization.get_default_qconfig(backend)

# Prepares the model for the next step i.e. calibration.
# Inserts observers in the model that will observe the activation tensors during calibration
model_fp32_prepared = torch.quantization.prepare(model_fp32, inplace=False)

# Calibration: push representative data through the observers. evaluate() is
# reused only as a forward-pass driver. The accuracy it prints is that of the
# observer-instrumented float model on the calibration subset. no_grad avoids
# building autograd graphs (harmless if evaluate() already disables grads).
with torch.no_grad():
    evaluate(model_fp32_prepared, val_loader_subset, device)

# Swap the observed float modules for their int8 quantized counterparts.
model_quantized = torch.quantization.convert(model_fp32_prepared, inplace=False)

# model_quantized = quantize_model(model, val_loader_subset, device, backend=backend)

Validation Accuracy: 87.30%, Avg Loss: 0.6305, Time: 3.24s


In [5]:
# Accuracy / loss / wall time of the int8 model on the full validation split;
# the returned tuple is echoed as the cell output.
results_quantized = evaluate(model_quantized, val_loader, device)
results_quantized

Validation Accuracy: 88.38%, Avg Loss: 0.6145, Time: 16.17s


(88.38, 0.6145401290893555, 16.17400074005127)

In [6]:
# Time both variants on the same loader; the float model is measured first.
timings = {
    "float": measure_inference_time(model, val_loader, device=device),
    "quantized": measure_inference_time(model_quantized, val_loader, device=device),
}
time_float = timings["float"]
time_quant = timings["quantized"]

for label, seconds in timings.items():
    print(f"Average inference time per batch ({label} model): {seconds:.4f} seconds")


Average inference time per batch (float model): 0.1912 seconds
Average inference time per batch (quantized model): 0.1717 seconds


In [7]:
import os

# Derive the output filename from `model_path` so it cannot drift when the
# checkpoint changes: ".../<name>_mps.pth" -> "quantized_<name>_cpu.pt".
# For the current model_path this gives
# "quantized_pruned_45-30_kd_10_resnet110_cpu.pt", the previously hard-coded name.
quantized_stem = os.path.splitext(os.path.basename(model_path))[0]
if quantized_stem.endswith("_mps"):
    quantized_stem = quantized_stem[: -len("_mps")]
quantized_model_path = f"quantized_{quantized_stem}_cpu.pt"
save_quantized_model(model_quantized, quantized_model_path)

In [6]:
from src.utils import load_model

# Load the float baseline checkpoint onto the configured device, consistent
# with the other load_model(...) calls in this notebook. Without an explicit
# device, the MPS-trained checkpoint's placement depends on load_model's
# default (not visible here).
test = load_model("resnet110_baseline_120_mps.pth", device=device)


ResNet(
  (quant): QuantStub()
  (dequant): DeQuantStub()
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (ff): FloatFunctional(
        (activation_post_process): Identity()
      )
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (ff): FloatFunctional(
        (activation_post_process): Identity()
      )
      (conv1): Conv2d(16, 1

In [12]:
model.to(device=device)

torch.backends.quantized.engine = backend

# Location of the TorchScript artifact written and re-read below.
scripted_model_path = 'model_scripted.pt'

# NOTE: despite the variable name, this is the *float* baseline checkpoint,
# not a quantized model.
quantized_model = load_model("resnet110_baseline_120_mps.pth", device=device)
# Switch BatchNorm layers to inference mode before scripting; the module's
# training flag is stored in the TorchScript artifact.
quantized_model.eval()

# Save the model (full model) as TorchScript, then reload it from disk.
scripted = torch.jit.script(quantized_model)
scripted.save(scripted_model_path)
quantized_model = torch.jit.load(scripted_model_path)

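Trace memory allocation while loading the model with `tracemalloc`. Note that `tracemalloc` only sees allocations made through Python's memory allocator, so most tensor storage allocated by PyTorch's C++ backend is not counted.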

In [13]:
import tracemalloc
import time

torch.backends.quantized.engine = backend

# Caveat: tracemalloc only tracks memory obtained through Python's allocator.
# Tensor storage allocated by PyTorch's C++ backend is largely invisible to
# it, so the peak below understates the model's real footprint (compare the
# psutil RSS measurement further down).
tracemalloc.start()
# perf_counter is monotonic and high-resolution; time.time() is neither.
start_time = time.perf_counter()
try:
    quantized_model = torch.jit.load('model_scripted.pt')
    load_time = time.perf_counter() - start_time
    current, peak = tracemalloc.get_traced_memory()
finally:
    # Stop tracing even if the load fails, so tracemalloc's per-allocation
    # overhead does not linger into later cells.
    tracemalloc.stop()

print(f"Model load time: {load_time:.4f} seconds")
print(f"Model load memory usage (peak): {peak / 1024 / 1024:.2f} MB")


Model load time: 0.1480 seconds
Model load memory usage (peak): 1.37 MB


Trace the total change in process memory (RSS) when loading the model, using `psutil`. The result is only an estimate and can fluctuate, depending on the state of the machine (e.g. idle vs. having just started the IDE).

In [53]:
import gc
import os
import time

import psutil


def get_memory_mb():
    """Return the resident set size (RSS) of this Python process in MiB.

    RSS covers all resident memory of the process, including tensor storage
    allocated outside Python's allocator.
    """
    process = psutil.Process(os.getpid())
    mem_bytes = process.memory_info().rss  # in bytes
    return mem_bytes / (1024 * 1024)

# Release any model loaded by earlier cells *before* the baseline reading.
# Otherwise it is freed only when `quantized_model` is reassigned below,
# which may shrink RSS mid-measurement and understate the delta.
quantized_model = None
# Collect garbage up front so leftovers from earlier cells are not freed
# between the two readings.
gc.collect()

# Measure baseline memory
print("Measuring baseline memory...")
time.sleep(1)
baseline_mem = get_memory_mb()
print(f"Baseline: {baseline_mem:.2f} MB")

# Load the model (alternative loaders kept for comparison runs)
print("Loading model...")
#quantized_model = load_quantized_model("quantized_resnet110_baseline_120_cpu.pt")
quantized_model = load_model("resnet110_baseline_120_mps.pth", device=device)
#quantized_model = torch.jit.load('model_scripted.pt')
quantized_model.eval()

# Wait a bit to let memory settle
time.sleep(1)
post_load_mem = get_memory_mb()
print(f"After model load: {post_load_mem:.2f} MB")

# Calculate delta
model_static_mem = post_load_mem - baseline_mem
print(f"Static memory used by model (just sitting in RAM): {model_static_mem:.2f} MB")


Measuring baseline memory...
Baseline: 79.22 MB
Loading model...
After model load: 122.72 MB
Static memory used by model (just sitting in RAM): 43.50 MB


CPU time and memory allocation during one forward pass (after a warm-up), measured with `torch.profiler`.

In [55]:
from torch.profiler import profile, ProfilerActivity
import torch

quantized_model.eval()
example_inputs = torch.randn(1, 3, 32, 32)
n_warmup = 5

# Warm-up: a few unprofiled passes so one-off costs (e.g. first-call
# initialisation) are kept out of the profiled pass.
with torch.no_grad():
    for _ in range(n_warmup):
        _ = quantized_model(example_inputs)

# Profiling: a single forward pass with memory, shape and FLOP accounting.
profiler_kwargs = dict(
    activities=[ProfilerActivity.CPU],
    profile_memory=True,
    record_shapes=True,
    with_stack=False,
    with_flops=True,
)
with profile(**profiler_kwargs) as prof:
    with torch.no_grad():
        _ = quantized_model(example_inputs)

summary = prof.key_averages()
print(summary.table(sort_by="self_cpu_memory_usage"))

--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                            Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  Total KFLOPs  
--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                     aten::empty         4.56%     378.177us         4.56%     378.177us       0.426us      32.99 Mb      32.99 Mb           888            --  
                   aten::resize_         1.28%     106.073us         1.28%     106.073us       0.956us       2.72 Mb       2.72 Mb           111            --  
                       aten::add         1.90%     157.294us         1.90%     157.294us       2.913us       1.97 Mb       1.97 Mb            54       516.096  
                aten::empty_like  

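# Dynamic Quantization
PyTorch's dynamic quantization only supports a few layer types (e.g. `nn.Linear` and recurrent layers), not convolutions. It is therefore not very effective for a convolution-heavy ResNet: only the `nn.Linear` layer(s) get quantized.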

In [10]:
# NOTE(review): this rebinds `model_quantized`, replacing the statically
# quantized model from the section above.
model.eval()
model.to(device=device)

torch.backends.quantized.engine = backend

# quantize_dynamic only swaps the module types listed in qconfig_spec; for this
# ResNet that is just nn.Linear, so all convolutions remain float.
dynamic_targets = {torch.nn.Linear}
model_quantized = torch.quantization.quantize_dynamic(
    model,
    qconfig_spec=dynamic_targets,
    dtype=torch.qint8,
)
