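# 3️⃣ Structured Pruning + Fine-Tuning (ONNX export)
Apply L2-norm structured pruning (20% of the filters in each conv layer) to the base model, fine-tune it for 5 epochs on CIFAR-10, and export the result to ONNX as `model_pruned.onnx`. Note: INT8 post-training quantization with ONNX Runtime (`model_quant.onnx`) is not performed in these cells.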

In [None]:
# Mount Google Drive so datasets, checkpoints and ONNX exports under `root` persist
# across Colab sessions.
from google.colab import drive

MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)
root = f'{MOUNT_POINT}/MyDrive/hardware_aware_optimization'

Mounted at /content/drive


In [None]:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import os

# Seed torch's RNG so DataLoader shuffling, the random augmentations and dropout repeat
# across Restart & Run All (cuDNN kernels may still add small run-to-run noise).
SEED = 42
torch.manual_seed(SEED)
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
class SimpleCNN(nn.Module):
    """Small three-stage CNN mapping 3x32x32 images to 10 class logits.

    The layer order (and therefore every ``state_dict`` key, e.g. ``features.7.weight``
    or ``classifier.1.weight``) is fixed so saved checkpoints load without remapping.
    """

    @staticmethod
    def _conv_block(in_ch, out_ch):
        """Return [3x3 same-padding Conv2d, BatchNorm2d, ReLU] as a list of layers."""
        return [nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU()]

    def __init__(self):
        super().__init__()
        # For 32x32 inputs: 32 -> 32 -> 16 (pool) -> 16 -> 8 (pool), ending with 128 channels.
        self.features = nn.Sequential(
            *self._conv_block(3, 32),
            *self._conv_block(32, 64),
            nn.MaxPool2d(2, 2),
            *self._conv_block(64, 128),
            nn.MaxPool2d(2, 2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(128 * 8 * 8, 256),
            nn.ReLU(),
            nn.Linear(256, 10),
        )

    def forward(self, x):
        feats = self.features(x)
        # Flatten everything but the batch dimension before the fully connected head.
        return self.classifier(feats.view(feats.size(0), -1))

# Instantiate the architecture on the selected device and load the pre-trained base weights.
model = SimpleCNN().to(device)
# weights_only=True restricts unpickling to tensors and plain containers (all a state_dict
# needs) instead of arbitrary Python objects; it is also the default from PyTorch 2.6, so
# this pins the same behaviour on older versions.
model.load_state_dict(torch.load(f"{root}/models/base_model.pt", map_location=device, weights_only=True))


<All keys matched successfully>

In [None]:
# L2-norm structured pruning: in every Conv2d of the feature extractor, zero the 20% of
# output filters (dim=0) with the smallest L2 norm (n=2).
conv_layers = [layer for layer in model.features if isinstance(layer, nn.Conv2d)]
for layer in conv_layers:
    prune.ln_structured(layer, name="weight", amount=0.2, n=2, dim=0)

# Make the pruning permanent. prune.remove drops the weight_orig / weight_mask
# re-parametrisation and leaves the masked values in `weight`.
# NOTE: the zeros are then ordinary parameter values, so a later optimizer step can make
# them non-zero again unless a mask is re-applied.
for layer in conv_layers:
    prune.remove(layer, 'weight')


In [None]:
# CIFAR-10 normalisation statistics, shared by the train and test pipelines.
# NOTE(review): these must match the normalisation used to train base_model.pt — confirm
# before changing them.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)
cifar_dir = f'{root}/data/cifar10'

# Training: random 32x32 crop from the image padded by 4 pixels, plus random horizontal flip.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])
# Evaluation: no augmentation, only tensor conversion and normalisation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

trainset = torchvision.datasets.CIFAR10(root=cifar_dir, train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root=cifar_dir, train=False, download=True, transform=transform_test)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

# Fine-tuning objective and optimizer (small learning rate: the model starts from trained weights).
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0005)

def train(model, loader, criterion, optimizer):
    """Run one training epoch of `model` over every batch yielded by `loader`.

    Each (inputs, labels) batch is moved to the module-level `device`, then a standard
    zero_grad / forward / loss / backward / step sequence is performed. Returns None;
    `model` (left in train mode) and `optimizer` are updated in place.
    """
    model.train()
    for batch_inputs, batch_labels in loader:
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)
        optimizer.zero_grad()
        criterion(model(batch_inputs), batch_labels).backward()
        optimizer.step()

def test(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total

# Fine-tune the pruned network, then save it as the pruned checkpoint.
#
# The pruning cell calls prune.remove(), which leaves the pruned filters as plain zeros in
# `weight` with no mask attached. Without a mask, the optimizer regrows those filters during
# fine-tuning and the saved "pruned" model would be dense again. So: re-attach the current
# zero-filter pattern as a fixed mask while training, and make it permanent with
# prune.remove() only afterwards, so the state_dict keys still match a plain SimpleCNN
# (no `weight_orig` / `weight_mask` entries).
conv_layers = [m for m in model.features if isinstance(m, nn.Conv2d)]
for conv in conv_layers:
    if not prune.is_pruned(conv):
        # One flag per output filter (dim=0): keep it if any of its weights is non-zero.
        keep = conv.weight.detach().flatten(1).ne(0).any(dim=1)
        mask = keep.view(-1, 1, 1, 1).expand_as(conv.weight).to(conv.weight.dtype)
        # custom_from_mask re-registers the existing Parameter as `weight_orig`, so the
        # optimizer built earlier keeps updating it; the effective weight stays masked.
        prune.custom_from_mask(conv, name="weight", mask=mask)

for epoch in range(1, 6):
    train(model, trainloader, criterion, optimizer)
    acc = test(model, testloader)
    print(f"Epoch {epoch} | Test Acc: {acc:.2f}%")

# Bake the masks into `weight` so the checkpoint loads into SimpleCNN unchanged.
for conv in conv_layers:
    prune.remove(conv, 'weight')

torch.save(model.state_dict(), f"{root}/models/model_pruned.pt")
print("Pruned model saved.")

Epoch 1 | Test Acc: 80.60%
Epoch 2 | Test Acc: 80.89%
Epoch 3 | Test Acc: 81.05%
Epoch 4 | Test Acc: 81.37%
Epoch 5 | Test Acc: 81.63%
Pruned model saved.


In [None]:
!pip3 install onnx
import torch.onnx

model.eval()
dummy_input = torch.randn(1, 3, 32, 32, device=device)
onnx_pruned_path = f"{root}/models/model_pruned.onnx"
torch.onnx.export(
    model,
    dummy_input,
    onnx_pruned_path,
    input_names=['input'],
    output_names=['output'],
    opset_version=13,
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
)
print(f"Exported pruned model to {onnx_pruned_path}")

Collecting onnx
  Downloading onnx-1.18.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.9 kB)
Downloading onnx-1.18.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.6/17.6 MB[0m [31m118.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: onnx
Successfully installed onnx-1.18.0
Exported pruned model to /content/drive/MyDrive/hardware_aware_optimization/models/model_pruned.onnx


In [None]:
def _size_mib(path):
    """Return the on-disk size of `path` in MiB (bytes / 1024 / 1024)."""
    return os.path.getsize(path) / 1024 / 1024

# Compare the pruned export against the FP32 ONNX model.
# NOTE(review): model.onnx is not written by any cell here — presumably the unpruned FP32
# export from an earlier notebook; confirm it exists before running. Pruning only zeroes
# filters and keeps every tensor's shape, so equal dense file sizes are expected.
size_pruned_onnx = _size_mib(onnx_pruned_path)
size_fp32_onnx = _size_mib(f"{root}/models/model.onnx")
print(f"ONNX FP32: {size_fp32_onnx:.2f} MB | Pruned ONNX: {size_pruned_onnx:.2f} MB")

ONNX FP32: 8.37 MB | Pruned ONNX: 8.37 MB
