# POC QAT for ResNet110
This is just proof of concept implementation of Quantization Aware Training (QAT). As it trained really slow (25 min. per Epoch), I did not further investigate it.

## Imports

In [None]:
from src.evaluate import evaluate, count_total_parameters
from src.data_loader import get_cifar10_loader
from src.utils import load_model
import torch
import torch.nn as nn
import torch.optim as optim
import torch.quantization
from torchvision.models.resnet import resnet18  # Replace with your custom ResNet-110
import copy
from src.evaluate import measure_inference_time

## Define Parameters

In [None]:
num_epochs = 1
learning_rate = 1e-4
backend = 'qnnpack'
device = 'cpu'
batch_size = 128

model_path = "models/resnet110_baseline_30_mps.pth"

## Load model, data and set optimizer and criterion

In [None]:
model = load_model(model_path, device=device)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

train_loader = get_cifar10_loader('train', batch_size=batch_size)
val_loader = get_cifar10_loader('val', batch_size=batch_size)
val_loader_subset = get_cifar10_loader('val', batch_size=batch_size, subset_size=1000)

## Prepare QAT

In [None]:
model.eval()
model.fuse_model()

torch.backends.quantized.engine = backend

model.to(device)
model.train()

model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
torch.quantization.prepare_qat(model, inplace=True)

## Training

In [None]:
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    model.train()
    running_loss = 0.0

    for images, targets in train_loader:
        images, targets = images.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    print(f"Training loss: {running_loss:.4f}")

    model.eval()
    evaluate(model, val_loader_subset, device=device)

model.eval()
model_quantized = torch.quantization.convert(model, inplace=False)

## Evaluation

In [2]:
print("Evaluating quantized model...")
evaluate(model_quantized, val_loader_subset, device=device)

Epoch 1/1
Training loss: 83.9658
Validation Accuracy: 87.40%, Avg Loss: 0.4148, Time: 9.46s
Evaluating quantized model...
Validation Accuracy: 87.50%, Avg Loss: 0.4372, Time: 3.71s


(87.5, 0.437206241607666)

In [13]:
time_float = measure_inference_time(model, val_loader, device=device)
time_quant = measure_inference_time(model_quantized, val_loader, device=device)

print(f"Average inference time per batch (float model): {time_float:.4f} seconds")
print(f"Average inference time per batch (quantized model): {time_quant:.4f} seconds")


Average inference time per batch (float model): 0.6658 seconds
Average inference time per batch (quantized model): 0.3194 seconds
