# Pipeline for Multiple Model Compression Techniques
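This notebook demonstrates pruning, quantization, and knowledge distillation on a ResNet-110 for CIFAR-10, both individually and in combination. It ends with a comparison of the resulting models in terms of accuracy, parameter count, inference time, memory footprint, and file size.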

## Imports

In [1]:
import copy

import torch
from torch import nn
import torch_pruning as tp
import pandas as pd

from src.model import ResNet, BasicBlock, resnet110
from src.utils import load_model, save_model, iterative_pruner, load_quantized_model
from src.data_loader import get_cifar10_loader
from src.train import train_model, train_model_kd, loss_fn_kd, KDParams
from src.evaluate import evaluate, measure_inference_time, count_total_parameters, evaluate_model_all_metrics, estimate_model_memory_footprint_from_bits


## Base Parameters

In [2]:
# Prefer the Apple-Silicon GPU (MPS); fall back to CUDA or the CPU so the notebook
# also runs on machines without MPS.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

# Trained float baseline that every compression technique below starts from.
model_path = "models/resnet110_baseline_30_mps.pth"

batch_size = 128


## Pruning

In [3]:
# Load the trained baseline model (its parameter count is printed for reference).
model = load_model(model_path, device)
# Load a second, independent instance to prune. Pruning modifies the module it is given
# in place, so aliasing (`pruned_model = model`) would silently prune `model` as well.
pruned_model = load_model(model_path, device)

count_total_parameters(model)

# Channel sparsity (pruning ratio) -> parameters remaining in the pruned ResNet-110:
# ch_sparsity = 0.15 -> 1228878  (~30% of the parameters removed)
# ch_sparsity = 0.29 ->  848388  (~50% removed)
# ch_sparsity = 0.45 ->  509972  (~70% removed)
# ch_sparsity = 0.95 ->    6765  (~99.6% removed, only ~0.39% remain)
ch_sparsity = 0.45
# Passed to the pruner as `iterative_steps` and to `iterative_pruner`.
iterative_pruning_steps = 5

Total number of parameters in the model: 1730714


In [4]:
# Pruning runs on the CPU. torch_pruning traces the network with a dummy
# CIFAR-sized input (batch of 1, 3 channels, 32x32).
pruned_model.to("cpu")
example_inputs = torch.randn(1, 3, 32, 32)

# Channel importance criterion.
# NOTE(review): TaylorImportance scores channels from weight gradients, and this cell
# never runs a backward pass. Presumably `iterative_pruner` computes a loss and calls
# backward() before each step. Confirm this; otherwise MagnitudeImportance would match
# the MagnitudePruner.
imp = tp.importance.TaylorImportance()

pruner_options = {
    "importance": imp,
    "iterative_steps": iterative_pruning_steps,
    "ch_sparsity": ch_sparsity,
}
pruner = tp.pruner.MagnitudePruner(pruned_model, example_inputs, **pruner_options)

# Prune step by step until the target channel sparsity is reached.
iterative_pruner(pruner, iterative_pruning_steps)




In [5]:
# Parameters left after pruning: the helper prints the count and returns it as the cell output.
remaining_param_count = count_total_parameters(pruned_model)
remaining_param_count

Total number of parameters in the model: 509972


509972

## Quantization

In [6]:
# Eager-mode post-training quantization has to run on the CPU.
device = torch.device("cpu")

# Quantized-kernel backend used for int8 inference. 'qnnpack' targets ARM CPUs
# (e.g. Apple Silicon or mobile devices running PyTorch Mobile); PyTorch also
# provides 'fbgemm' for x86 CPUs.
backend = 'qnnpack'


In [7]:
# Reload the trained float baseline, this time directly on the CPU.
model = load_model(model_path, device=device)

# Full validation loader for evaluation, plus a smaller subset (subset_size=1000)
# that is used below to calibrate the quantization observers.
# NOTE(review): calibration draws from the same validation split used to report accuracy;
# a training-split subset would keep the evaluation data unseen. Confirm this is intended.
val_loader, val_loader_subset = (
    get_cifar10_loader('val', batch_size=batch_size),
    get_cifar10_loader('val', batch_size=batch_size, subset_size=1000),
)

In [8]:
# Eager-mode post-training static quantization of the float baseline.
model.to(device=device).eval()

# Select the quantized-kernel backend.
torch.backends.quantized.engine = backend

# Quantize a copy. fuse_model() rewrites the module in place, and `model` must stay the
# unmodified float reference that is benchmarked as the "float model" further below.
model_fp32 = copy.deepcopy(model)

# Fuse modules (e.g. Conv+BN+ReLU) for better quantization accuracy.
model_fp32.fuse_model()

# Attach the backend's default quantization config.
model_fp32.qconfig = torch.ao.quantization.get_default_qconfig(backend)

# Insert observers on a prepared copy (inplace=False) for calibration.
model_fp32_prepared = torch.ao.quantization.prepare(model_fp32, inplace=False)

# Calibrate: run the calibration subset through the model so the observers record
# activation ranges.
evaluate(model_fp32_prepared, val_loader_subset, device)

# Convert the calibrated model into its quantized version.
model_quantized = torch.ao.quantization.convert(model_fp32_prepared)

Validation Accuracy: 87.00%, Avg Loss: 0.4502, Time: 5.57s


In [9]:
# Accuracy of the int8 model on the full validation set; the cell output is the
# tuple (accuracy, average loss, evaluation time in seconds).
quantized_eval_result = evaluate(model_quantized, val_loader, device)
quantized_eval_result

Validation Accuracy: 86.08%, Avg Loss: 0.4575, Time: 27.88s


(86.08, 0.45750905919075013, 27.88299798965454)

In [10]:
# Benchmark the float and the quantized model on the same device and data.
time_float = measure_inference_time(model, val_loader, device=device)
time_quant = measure_inference_time(model_quantized, val_loader, device=device)

for label, seconds_per_batch in (("float model", time_float), ("quantized model", time_quant)):
    print(f"Average inference time per batch ({label}): {seconds_per_batch:.4f} seconds")


Average inference time per batch (float model): 0.4558 seconds
Average inference time per batch (quantized model): 0.3012 seconds


## Knowledge Distillation

In [21]:
# Train on the Apple-Silicon GPU (MPS) when available; otherwise fall back to CUDA or the CPU.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
# Teacher: the float baseline. Student: a heavily pruned model (ch_sparsity 0.95).
teacher_model_path = "models/resnet110_baseline_30_mps.pth"
student_model_path = "models/pruned_95-30_resnet110_mps.pth"

learning_rate = 0.001
num_epochs = 1
# Distillation hyperparameters, passed to KDParams(alpha=..., temperature=...).
kd_alpha = 0.7
kd_temperature = 4.0


In [22]:
# Load the pretrained teacher (float baseline) and the pruned student.
teacher_model = load_model(teacher_model_path, device=device)
student_model = load_model(student_model_path, device=device)

# The teacher only supplies soft targets and must not change during distillation.
# eval() makes BatchNorm use its running statistics instead of updating them, and
# requires_grad_(False) stops autograd from tracking the teacher's weights.
teacher_model.eval()
teacher_model.requires_grad_(False)

# Only the student's parameters are optimized.
optimizer = torch.optim.Adam(student_model.parameters(), lr=learning_rate)
# Plain cross-entropy criterion; the KD training call below uses loss_fn_kd instead.
criterion = nn.CrossEntropyLoss()

# Load data
train_loader = get_cifar10_loader('train', batch_size=batch_size)
val_loader = get_cifar10_loader('val', batch_size=batch_size)

# Set alpha and temperature parameters for Knowledge Distillation
kd_params = KDParams(alpha=kd_alpha, temperature=kd_temperature)


In [23]:
# Student performance before distillation; the cell output is
# (accuracy, average loss, evaluation time in seconds).
student_metrics_before = evaluate(student_model, val_loader, device)
student_metrics_before

Validation Accuracy: 42.63%, Avg Loss: 1.5419, Time: 1.79s


(42.63, 1.5419314538955688, 1.7944419384002686)

In [19]:
# Distil the teacher into the student for `num_epochs` epochs, using the KD loss
# `loss_fn_kd` configured by `kd_params` (alpha, temperature).
train_model_kd(
    teacher_model=teacher_model,
    student_model=student_model,
    loss_fn_kd=loss_fn_kd,
    kd_params=kd_params,
    optimizer=optimizer,
    train_loader=train_loader,
    num_epochs=num_epochs,
    device=device,
)


                                                                                     

In [20]:
# Student performance after distillation; the cell output is
# (accuracy, average loss, evaluation time in seconds).
student_metrics_after = evaluate(student_model, val_loader, device)
student_metrics_after

Validation Accuracy: 40.70%, Avg Loss: 1.6044, Time: 2.64s


(40.7, 1.604372562789917, 2.637606143951416)

## Comparison of Model Compression Techniques


In [8]:
# Float (unquantized) models run on the Apple-Silicon GPU (MPS) when available,
# otherwise on CUDA or the CPU.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

base_model_path = "models/resnet110_baseline_30_mps.pth"
pruned_model_50_path = "models/pruned_29-30_resnet110_mps.pth"
pruned_model_70_path = "models/pruned_45-30_resnet110_mps.pth"
base_quantized_model_path = "models/quantized_resnet110_baseline_30_cpu.pt"
pruned_kd_model_70_path = "models/pruned_45-30_kd_10_resnet110_mps.pth"
pruned_kd_quantized_model_70_path = "models/quantized_pruned_45-30_kd_10_resnet110_cpu.pt"

# Recreate the validation loader so this section does not depend on a loader
# left over from an earlier section.
val_loader = get_cifar10_loader('val', batch_size=batch_size)

# Load every model from this section's own paths (not from the KD section's variables).
base_model = load_model(base_model_path, device=device)
pruned_model_50 = load_model(pruned_model_50_path, device=device)
pruned_model_70 = load_model(pruned_model_70_path, device=device)
pruned_kd_model_30 = load_model(pruned_kd_model_70_path, device=device)

results = {}

results["Base Model"] = evaluate_model_all_metrics(base_model, val_loader, device, base_model_path)
results["Pruned 50%"] = evaluate_model_all_metrics(pruned_model_50, val_loader, device, pruned_model_50_path)
results["Pruned 70%"] = evaluate_model_all_metrics(pruned_model_70, val_loader, device, pruned_model_70_path)
results["Pruned + KD 70%"] = evaluate_model_all_metrics(pruned_kd_model_30, val_loader, device, pruned_kd_model_70_path)

# Quantized model only works on cpu
device = torch.device("cpu")
base_quantized_model_30 = load_quantized_model(base_quantized_model_path)
pruned_kd_quantized_model_30 = load_quantized_model(pruned_kd_quantized_model_70_path)


def _fix_scripted_quantized_metrics(quantized_name, float_name):
    """Patch the parameter count and memory footprint of a scripted quantized model.

    Quantized models are loaded as "scripted" modules (not regular nn.Modules), so their
    parameters cannot be counted correctly; the counts reported for them are far too small.
    Instead, reuse the count measured for the float model they were quantized from, and
    estimate the memory footprint assuming 8-bit weights.
    """
    n_params = results[float_name]["parameters"]
    results[quantized_name]["parameters"] = n_params
    results[quantized_name]["memory_footprint"] = estimate_model_memory_footprint_from_bits(n_params, bits=8)


results["Base Quantized"] = evaluate_model_all_metrics(base_quantized_model_30, val_loader, device, base_quantized_model_path)
_fix_scripted_quantized_metrics("Base Quantized", float_name="Base Model")

results["Pruned + KD + Quantized 70%"] = evaluate_model_all_metrics(pruned_kd_quantized_model_30, val_loader, device, pruned_kd_quantized_model_70_path)
_fix_scripted_quantized_metrics("Pruned + KD + Quantized 70%", float_name="Pruned + KD 70%")

Validation Accuracy: 86.57%, Avg Loss: 0.4323, Time: 5.15s
Total number of parameters in the model: 1730714
Validation Accuracy: 87.88%, Avg Loss: 0.4254, Time: 4.34s
Total number of parameters in the model: 848388
Validation Accuracy: 84.10%, Avg Loss: 0.5485, Time: 3.08s
Total number of parameters in the model: 509972
Validation Accuracy: 86.62%, Avg Loss: 0.4247, Time: 2.88s
Total number of parameters in the model: 509972
Validation Accuracy: 86.08%, Avg Loss: 0.4575, Time: 35.52s
Total number of parameters in the model: 192
Validation Accuracy: 86.65%, Avg Loss: 0.4277, Time: 20.25s
Total number of parameters in the model: 104


In [17]:
# Aggregate the per-model metrics into one table, indexed by model name.
# Values stay numeric (rounded to a readable precision) rather than pre-formatted strings.
# That way the table can be sorted or plotted, and the CSV export keeps real numbers;
# formatting here would turn e.g. the parameter count into text such as "1,730,714".
results_table = pd.DataFrame(
    [
        {
            "Model": model_name,
            "Accuracy (%)": round(float(metrics["accuracy"]), 2),
            "Parameters": int(metrics["parameters"]),
            "Inference Time (s)": round(float(metrics["inference_time"]), 4),
            "Memory Footprint (MB)": round(float(metrics["memory_footprint"]), 5),
            "File Size (MB)": round(float(metrics["file_size"]), 2),
        }
        for model_name, metrics in results.items()
    ]
).set_index("Model")

# Display the table
display(results_table)


Unnamed: 0_level_0,Accuracy (%),Parameters,Inference Time (s),Memory Footprint (MB),File Size (MB)
Model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
Base Model,86.57,1730714,0.5238,6.60215,6.92
Pruned 50%,87.88,848388,0.3654,3.23634,3.54
Pruned 70%,84.1,509972,0.2565,1.94539,2.24
Pruned + KD 70%,86.62,509972,0.2523,1.94539,2.25
Base Quantized,86.08,1730714,0.3384,1.65054,1.77
Pruned + KD + Quantized 70%,86.65,509972,0.1889,0.48635,0.61


In [25]:
# Persist the comparison table; the "Model" index is written as the first column.
# NOTE(review): this path is resolved against the kernel's working directory, while the
# `src` imports and "models/..." paths suggest the repository root. Confirm it points
# at the intended notebooks/data folder.
results_table.to_csv(path_or_buf="../notebooks/data/model_metrics.csv")


In [26]:
# Read the saved table back, restoring "Model" as the index, to check the export.
results_table = pd.read_csv(
    filepath_or_buffer="../notebooks/data/model_metrics.csv",
    index_col="Model",
)
display(results_table)

Unnamed: 0_level_0,Accuracy (%),Parameters,Inference Time (s),Memory Footprint (MB),File Size (MB)
Model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
Base Model,86.57,1730714,0.5238,6.60215,6.92
Pruned 50%,87.88,848388,0.3654,3.23634,3.54
Pruned 70%,84.1,509972,0.2565,1.94539,2.24
Pruned + KD 70%,86.62,509972,0.2523,1.94539,2.25
Base Quantized,86.08,1730714,0.3384,1.65054,1.77
Pruned + KD + Quantized 70%,86.65,509972,0.1889,0.48635,0.61
