<a href="https://colab.research.google.com/github/Abhishek315-a/machine-larning-models/blob/main/Image_Classification_with_Quantization_and_ONNX_Optimization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Image Classification with Quantization and ONNX Optimization
# ------------------------------------------------------------
# Requirements:
# pip install torch torchvision onnx onnxruntime numpy pandas matplotlib

In [None]:
!pip install torch torchvision onnx onnxruntime numpy pandas matplotlib

Collecting onnx
  Downloading onnx-1.19.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (7.0 kB)
Collecting onnxruntime
  Downloading onnxruntime-1.23.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.9 kB)
Collecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)
Downloading onnx-1.19.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (18.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.2/18.2 MB[0m [31m125.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading onnxruntime-1.23.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (17.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.3/17.3 MB[0m [31m107.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading coloredlogs-15.0.1-py2.py3-none-any.whl (4

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import onnx
import onnxruntime as ort
import numpy as np
import time
import os
import matplotlib.pyplot as plt

# ------------------------------------------------------------
# 1. Data Preparation
# ------------------------------------------------------------

In [None]:
# Convert each image to a tensor, then normalise with mean 0.5 / std 0.5.
# The single-element tuples are applied to every colour channel.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# CIFAR-10 train / test splits; download=True fetches the data into ./data.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)

# Only the training loader shuffles; the test loader keeps dataset order.
trainloader = DataLoader(trainset, batch_size=128, num_workers=2, shuffle=True)
testloader = DataLoader(testset, batch_size=128, num_workers=2, shuffle=False)

# Class names as exposed by the dataset object.
classes = trainset.classes

100%|██████████| 170M/170M [00:13<00:00, 12.5MB/s]


# ------------------------------------------------------------
# 2. Define Model (ResNet18)
# ------------------------------------------------------------

In [None]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Seed PyTorch's global RNG before the weights are created so the random
# initialisation (and the shuffling order drawn later by the training
# DataLoader) is repeatable between runs. GPU kernels may still introduce
# small run-to-run differences.
torch.manual_seed(42)

# ResNet-18 trained from scratch with a 10-way classification head.
# `weights=None` is the current spelling of the deprecated `pretrained=False`,
# which emits a deprecation warning on torchvision >= 0.13.
model = torchvision.models.resnet18(weights=None, num_classes=10)
model = model.to(device)

Using device: cuda




# ------------------------------------------------------------
# 3. Training
# ------------------------------------------------------------

In [None]:
# Loss function and optimiser shared by the training loop below.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)


def train_model(epochs=2):  # keep small epochs for demo
    """Train the global ``model`` on ``trainloader`` for ``epochs`` passes.

    Relies on the module-level ``criterion``, ``optimizer`` and ``device``.
    After every epoch the mean per-batch loss is printed.
    """
    for epoch in range(epochs):
        model.train()
        loss_sum = 0.0
        for batch_images, batch_labels in trainloader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            # One optimisation step: clear gradients, forward, backward, update.
            optimizer.zero_grad()
            loss = criterion(model(batch_images), batch_labels)
            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss_sum/len(trainloader):.4f}")


train_model(epochs=2)  # change to higher epochs for better accuracy

Epoch [1/2], Loss: 1.3656
Epoch [2/2], Loss: 0.9594


# ------------------------------------------------------------
# 4. Evaluate PyTorch Model
# ------------------------------------------------------------

In [None]:
def evaluate_pytorch():
    """Measure top-1 accuracy of the global ``model`` over ``testloader``.

    Prints the accuracy and returns it as a percentage.
    """
    model.eval()
    num_correct = 0
    num_seen = 0
    with torch.no_grad():
        for batch_images, batch_labels in testloader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            # Predicted class = index of the largest output in each row.
            predicted = model(batch_images).max(dim=1).indices
            num_seen += batch_labels.shape[0]
            num_correct += (predicted == batch_labels).sum().item()
    acc = 100 * num_correct / num_seen
    print(f"PyTorch FP32 Accuracy: {acc:.2f}%")
    return acc


pytorch_acc = evaluate_pytorch()

PyTorch FP32 Accuracy: 64.15%


# ------------------------------------------------------------
# 5. Export Model to ONNX
# ------------------------------------------------------------

In [None]:
onnx_model_path = "resnet18_cifar10.onnx"

# Put the model in inference mode explicitly instead of relying on an earlier
# cell having called model.eval(); the exported graph should not contain
# training-mode behaviour (e.g. batch-statistics normalisation).
model.eval()

# A single random 3x32x32 input is enough to trace the graph. Axis 0 of the
# input and output is declared dynamic so the exported model accepts any
# batch size.
dummy_input = torch.randn(1, 3, 32, 32, device=device)
torch.onnx.export(model, dummy_input, onnx_model_path,
                  input_names=['input'], output_names=['output'],
                  dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})

# Fail fast if the file that was written is not a structurally valid ONNX model.
onnx.checker.check_model(onnx.load(onnx_model_path))

print("ONNX model exported:", onnx_model_path)

  torch.onnx.export(model, dummy_input, onnx_model_path,


ONNX model exported: resnet18_cifar10.onnx


# ------------------------------------------------------------
# 6. Inference with ONNX Runtime (FP32)
# ------------------------------------------------------------

In [None]:
def to_numpy(tensor):
    """Return ``tensor`` as a NumPy array on the CPU.

    Tensors that track gradients are detached first, because ``.numpy()``
    cannot be called on a tensor that requires grad.
    """
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()

ort_session = ort.InferenceSession(onnx_model_path)

def evaluate_onnx(session, quantized=False):
    """Evaluate an ONNX Runtime session on the global ``testloader``.

    Args:
        session: Object exposing ``get_inputs()`` and ``run()`` (an
            ``onnxruntime.InferenceSession``) for the exported classifier.
        quantized: Only selects the label used in the printed line
            ("INT8" when true, "FP32" otherwise).

    Returns:
        ``(acc, runtime)``: top-1 accuracy in percent and the total number of
        seconds spent inside ``session.run`` across all test batches.
    """
    # Ask the session for its input name instead of hard-coding "input", so
    # the function also works for models exported with a different name.
    input_name = session.get_inputs()[0].name

    correct, total = 0, 0
    runtime = 0.0
    for images, labels in testloader:
        ort_inputs = {input_name: to_numpy(images)}

        # Time only the inference call. Wrapping the whole loop would also
        # charge DataLoader iteration and tensor conversion to the model,
        # which skews the FP32 vs INT8 comparison. perf_counter() is a
        # monotonic clock intended for measuring intervals.
        start = time.perf_counter()
        ort_outs = session.run(None, ort_inputs)
        runtime += time.perf_counter() - start

        preds = np.argmax(ort_outs[0], axis=1)
        total += labels.size(0)
        correct += (preds == labels.numpy()).sum().item()

    acc = 100 * correct / total
    mode = "INT8" if quantized else "FP32"
    print(f"ONNX {mode} Accuracy: {acc:.2f}%, Inference Time: {runtime:.2f}s")
    return acc, runtime

onnx_acc_fp32, onnx_time_fp32 = evaluate_onnx(ort_session)

ONNX FP32 Accuracy: 64.15%, Inference Time: 12.06s


# ------------------------------------------------------------
# 7. Quantization (INT8)
# ------------------------------------------------------------

In [None]:
from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType
import onnxruntime as ort

# Custom DataReader for calibration
class CIFAR10DataReader(CalibrationDataReader):
    """Calibration data source that serves the first batches of a loader.

    Each call to ``get_next`` hands back one ``{"input": ndarray}`` feed and
    returns ``None`` once the collected batches have all been served.
    """

    def __init__(self, dataloader, num_batches=10):
        self.dataloader = dataloader
        self.num_batches = num_batches
        # Iterator over the prepared feeds; built on the first get_next() call.
        self.enum_data = None

    def get_next(self):
        """Return the next calibration feed dict, or ``None`` when exhausted."""
        if self.enum_data is None:
            self.enum_data = iter(self._collect_feeds())
        return next(self.enum_data, None)

    def _collect_feeds(self):
        """Read at most ``num_batches`` batches and wrap each as a feed dict."""
        feeds = []
        for batch_idx, (images, _) in enumerate(self.dataloader):
            if batch_idx >= self.num_batches:
                break
            feeds.append({"input": images.numpy()})
        return feeds

# Calibration
# Use training images for calibration: deriving the activation ranges from
# the test set and then scoring the quantized model on that same test set
# leaks evaluation data and makes the INT8 accuracy optimistic.
calibration_reader = CIFAR10DataReader(trainloader, num_batches=10)

quantized_model_path = "resnet18_cifar10_int8.onnx"
quantize_static(
    model_input=onnx_model_path,
    model_output=quantized_model_path,
    calibration_data_reader=calibration_reader,
    weight_type=QuantType.QInt8
)

# Load quantized model
ort_session_int8 = ort.InferenceSession(quantized_model_path)
onnx_acc_int8, onnx_time_int8 = evaluate_onnx(ort_session_int8, quantized=True)




ONNX INT8 Accuracy: 63.91%, Inference Time: 22.03s


# ------------------------------------------------------------
# 8. Model Size Comparison
# ------------------------------------------------------------

In [None]:
# On-disk size of both ONNX files, converted from bytes to MB (1024 * 1024 bytes).
bytes_per_mb = 1024 * 1024
fp32_size = os.path.getsize(onnx_model_path) / bytes_per_mb
int8_size = os.path.getsize(quantized_model_path) / bytes_per_mb

# Percentage by which the INT8 file is smaller than the FP32 file.
size_reduction_pct = (1 - int8_size/fp32_size)*100

print(f"Model Size FP32: {fp32_size:.2f} MB")
print(f"Model Size INT8: {int8_size:.2f} MB")
print(f"Size Reduction: {size_reduction_pct:.1f}%")


Model Size FP32: 42.65 MB
Model Size INT8: 10.72 MB
Size Reduction: 74.9%


# ------------------------------------------------------------
# 9. Benchmark Summary
# ------------------------------------------------------------

In [None]:
# Gather every headline number into one mapping, in display order.
# "FP32 Accuracy" is the PyTorch result; both timing entries and both size
# entries come from the ONNX models.
summary = dict([
    ("FP32 Accuracy", pytorch_acc),
    ("ONNX FP32 Accuracy", onnx_acc_fp32),
    ("ONNX INT8 Accuracy", onnx_acc_int8),
    ("FP32 Inference Time (s)", onnx_time_fp32),
    ("INT8 Inference Time (s)", onnx_time_int8),
    ("FP32 Model Size (MB)", fp32_size),
    ("INT8 Model Size (MB)", int8_size),
])

In [None]:
# Print a heading, then each metric on its own line as "name: value".
print("\nBenchmark Summary:")
for metric_name, metric_value in summary.items():
    print(f"{metric_name}: {metric_value}")


Benchmark Summary:
FP32 Accuracy: 64.15
ONNX FP32 Accuracy: 64.15
ONNX INT8 Accuracy: 63.91
FP32 Inference Time (s): 12.057161331176758
INT8 Inference Time (s): 22.029340028762817
FP32 Model Size (MB): 42.64553356170654
INT8 Model Size (MB): 10.715697288513184
