In [1]:
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import onnx
import onnxruntime as ort
from onnxruntime.quantization import quantize_dynamic, QuantType

# Fix the global PyTorch and NumPy RNG seeds so repeated runs start from the same random state.
torch.manual_seed(42)
np.random.seed(42)


In [2]:
def get_data_loaders(dataset_name, batch_size=128):
    """Build train/test DataLoaders for one of the supported torchvision datasets.

    Args:
        dataset_name: 'mnist' or 'cifar'. Any other value raises KeyError
            from the dataset lookup.
        batch_size: Batch size used for both loaders.

    Returns:
        A tuple (train_loader, test_loader, input_shape, num_classes) where
        input_shape is (1, 28, 28) for 'mnist' and (3, 32, 32) otherwise, and
        num_classes is always 10.
    """
    print(f"Loading {dataset_name}...")

    dataset_classes = {
        'mnist': datasets.MNIST,
        'cifar': datasets.CIFAR10
    }
    dataset_cls = dataset_classes[dataset_name]

    # ToTensor() is the only transform applied.
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
    ])

    # Only the training split asks for a download; the test split is read
    # from the same './data' root.
    train_split = dataset_cls('./data', train=True, download=True, transform=to_tensor)
    test_split = dataset_cls('./data', train=False, transform=to_tensor)

    # Training batches are shuffled, evaluation batches keep dataset order.
    loaders = (
        DataLoader(train_split, batch_size=batch_size, shuffle=True),
        DataLoader(test_split, batch_size=batch_size, shuffle=False),
    )

    if dataset_name == 'mnist':
        sample_shape = (1, 28, 28)
    else:
        sample_shape = (3, 32, 32)

    return loaders[0], loaders[1], sample_shape, 10


In [4]:
class MnistNet(nn.Module):
    """Small CNN mapping 1x28x28 inputs to 10 raw class scores (logits).

    Two unpadded 3x3 convolutions, each followed by ReLU and 2x2 max-pooling,
    reduce a 28x28 image to 64 feature maps of 5x5, which feed a two-layer
    classifier head. No softmax is applied to the output.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in a fixed order so that seeded initialisation
        # and state_dict key order stay stable.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(in_features=64 * 5 * 5, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return logits of shape (batch, 10) for a batch of images `x`."""
        # Convolutional feature extractor: (conv -> ReLU -> 2x2 max-pool) twice.
        features = F.max_pool2d(F.relu(self.conv1(x)), 2)
        features = F.max_pool2d(F.relu(self.conv2(features)), 2)

        # Classifier head with dropout before each linear layer.
        flat = self.dropout1(torch.flatten(features, 1))
        hidden = self.dropout2(F.relu(self.fc1(flat)))
        return self.fc2(hidden)

class CifarNet(nn.Module):
    """CNN mapping 3x32x32 inputs to 10 raw class scores (logits).

    Two convolutional stages (padded 3x3 conv, unpadded 3x3 conv, 2x2
    max-pool, dropout) reduce a 32x32 image to 64 feature maps of 6x6, which
    feed a two-layer classifier head. No softmax is applied to the output.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 3 -> 32 -> 32 channels.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)  # shared by both stages
        self.dropout1 = nn.Dropout(p=0.25)

        # Stage 2: 32 -> 64 -> 64 channels.
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3)
        self.dropout2 = nn.Dropout(p=0.25)

        # Classifier head.
        self.fc1 = nn.Linear(in_features=64 * 6 * 6, out_features=512)
        self.dropout3 = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(in_features=512, out_features=10)

    def forward(self, x):
        """Return logits of shape (batch, 10) for a batch of images `x`."""
        stage1 = F.relu(self.conv2(F.relu(self.conv1(x))))
        stage1 = self.dropout1(self.pool(stage1))

        stage2 = F.relu(self.conv4(F.relu(self.conv3(stage1))))
        stage2 = self.dropout2(self.pool(stage2))

        hidden = F.relu(self.fc1(torch.flatten(stage2, 1)))
        return self.fc2(self.dropout3(hidden))



In [5]:
def train(model, device, train_loader, optimizer, epoch):
    """Train `model` for one pass over `train_loader` with cross-entropy loss.

    Args:
        model: Module producing class logits; switched to training mode.
        device: Device that each batch is moved to before the forward pass.
        train_loader: Iterable of (inputs, labels) batches. Its length and its
            `dataset` attribute are read for the progress line.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number, used only in the progress output.
    """
    model.train()
    loss_fn = nn.CrossEntropyLoss()

    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        batch_loss = loss_fn(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()

        # Report progress on every 100th batch only (including the first).
        if step % 100 != 0:
            continue
        seen = step * len(inputs)
        pct = 100. * step / len(train_loader)
        print(f'Train Epoch: {epoch} [{seen}/{len(train_loader.dataset)} '
              f'({pct:.0f}%)]\tLoss: {batch_loss.item():.6f}')

def build_and_train(dataset_name, train_loader, epochs=5):
    """Return a CPU model for `dataset_name`, loading cached weights if present.

    If `model_torch_{dataset_name}.pth` exists in the working directory its
    state dict is loaded and no training happens. Otherwise the model is
    trained with Adam (lr=0.001) for `epochs` epochs and its state dict is
    written to that path.

    Args:
        dataset_name: 'mnist' selects MnistNet; any other value selects
            CifarNet.
        train_loader: Loader of (inputs, labels) batches; only used when
            training is required.
        epochs: Number of training epochs when no cached weights exist.

    Returns:
        The model, located on the CPU.
    """
    device = torch.device("cpu") # Force CPU for consistency

    if dataset_name == 'mnist':
        model = MnistNet().to(device)
    else:
        model = CifarNet().to(device)

    model_path = f"model_torch_{dataset_name}.pth"
    if os.path.exists(model_path):
        print(f"Loading existing model: {model_path}")
        # map_location keeps the load on the CPU even if the checkpoint was
        # written from tensors that lived on another device.
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # this notebook produced itself.
        model.load_state_dict(torch.load(model_path, map_location=device))
    else:
        print("Starting training...")
        # The optimizer is only needed on the training path.
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        for epoch in range(1, epochs + 1):
            train(model, device, train_loader, optimizer, epoch)
        torch.save(model.state_dict(), model_path)
        print("Model saved.")

    return model


In [6]:
def convert_to_onnx(model, dataset_name, input_shape):
    """Export `model` to `model_torch_{dataset_name}.onnx` and return that path.

    The model is put in eval mode and exported with a single random sample of
    shape (1, *input_shape). The graph input is named 'input' and the output
    'logits'; axis 0 of both is declared dynamic under the name 'batch_size'.

    Args:
        model: The torch module to export.
        dataset_name: Used only to build the output file name.
        input_shape: Per-sample input shape, without the batch dimension.

    Returns:
        The path of the written ONNX file.
    """
    model.eval()

    onnx_file = f"model_torch_{dataset_name}.onnx"
    example = torch.randn(1, *input_shape)

    print(f"Exporting to ONNX: {onnx_file}")
    export_options = dict(
        verbose=False,
        input_names=['input'],
        output_names=['logits'], # Renamed to 'logits' to be clear
        dynamic_axes={
            'input': {0: 'batch_size'},
            'logits': {0: 'batch_size'},
        },
    )
    torch.onnx.export(model, example, onnx_file, **export_options)
    return onnx_file


In [7]:
def benchmark_models(model, onnx_path, test_loader, input_shape, batch_sizes=[1, 8, 32, 128]):
    """Time CPU inference of the PyTorch model against its ONNX export.

    For every batch size one warm-up call is made, then the mean latency of
    10 consecutive calls is measured for both runtimes and printed as a table
    row.

    Args:
        model: The PyTorch module; moved to CPU and put in eval mode.
        onnx_path: Path of the exported ONNX model.
        test_loader: Loader whose first batches supply the benchmark inputs.
        input_shape: Per-sample input shape, used to zero-pad a batch when
            fewer samples than the batch size were collected.
        batch_sizes: Batch sizes to benchmark. The list is only read.

    Returns:
        A list of (batch_size, torch_ms, onnx_ms) tuples, one per batch size.
    """
    print(f"\nBenchmarking PyTorch vs ONNX ({onnx_path})")

    # ONNX Session
    sess_options = ort.SessionOptions()
    session = ort.InferenceSession(onnx_path, sess_options, providers=['CPUExecutionProvider'])
    input_name = session.get_inputs()[0].name

    # Grab chunk of data: stop as soon as more than 200 samples are collected.
    all_data = []
    for data, _ in test_loader:
        all_data.append(data)
        if len(all_data) * test_loader.batch_size > 200:
            break
    x_test = torch.cat(all_data, dim=0)

    results = []
    print(f"{ 'Batch Size':<12} | { 'PyTorch (ms)':<15} | { 'ONNX (ms)':<15}")
    print("-" * 50)

    device = torch.device("cpu")
    model.to(device)
    model.eval()

    n_runs = 10  # timed repetitions per measurement

    for bs in batch_sizes:
        batch_data = x_test[:bs]
        if batch_data.shape[0] < bs:
            # Not enough real samples: fill the rest of the batch with zeros.
            padding = torch.zeros(bs - batch_data.shape[0], *input_shape)
            batch_data = torch.cat([batch_data, padding], dim=0)

        # 1. PyTorch Benchmark
        with torch.no_grad():
            # Warmup
            _ = model(batch_data)
            # perf_counter is a monotonic, high-resolution clock; time.time()
            # can be too coarse for sub-millisecond intervals and can jump
            # when the system clock is adjusted.
            start = time.perf_counter()
            for _ in range(n_runs):
                _ = model(batch_data)
            end = time.perf_counter()
            torch_time = (end - start) / n_runs * 1000

        # 2. ONNX Benchmark
        numpy_data = batch_data.numpy()
        # Warmup
        session.run(None, {input_name: numpy_data})
        start = time.perf_counter()
        for _ in range(n_runs):
            session.run(None, {input_name: numpy_data})
        end = time.perf_counter()
        onnx_time = (end - start) / n_runs * 1000

        print(f"{bs:<12} | {torch_time:<15.4f} | {onnx_time:<15.4f}")
        results.append((bs, torch_time, onnx_time))

    return results


In [8]:
def compare_accuracy_and_soft(model, onnx_path, test_loader, num_samples=5):
    """Compare PyTorch and ONNX predictions over the whole test loader.

    Prints, for the first batch, the top-3 classes and softmax probabilities
    of both runtimes for up to `num_samples` samples, and afterwards the
    accuracy of each runtime and the largest absolute logit difference seen.

    Args:
        model: The PyTorch module; put in eval mode.
        onnx_path: Path of the ONNX model to compare against.
        test_loader: Loader of (inputs, labels) batches.
        num_samples: Number of first-batch samples to print in detail. It is
            capped at the size of the first batch.
    """
    print(f"\nComparing Accuracy and Predictions for {onnx_path}")

    session = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
    input_name = session.get_inputs()[0].name

    model.eval()
    correct_torch = 0
    correct_onnx = 0
    total = 0

    max_diff_logits = 0.0

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            # PyTorch (Output is Logits)
            logits_torch = model(data)
            _, predicted_torch = torch.max(logits_torch, 1)
            correct_torch += (predicted_torch == target).sum().item()

            # ONNX (Output is Logits)
            data_np = data.numpy()
            logits_onnx = session.run(None, {input_name: data_np})[0]
            predicted_onnx = np.argmax(logits_onnx, axis=1)
            correct_onnx += (predicted_onnx == target.numpy()).sum().item()

            # Diff on logits
            diff = np.abs(logits_torch.numpy() - logits_onnx)
            max_diff_logits = max(max_diff_logits, np.max(diff))

            total += target.size(0)

            # Compare Soft predictions (Probabilities) for the first batch
            if batch_idx == 0:
                # Apply Softmax to convert Logits -> Probs for display
                probs_torch = F.softmax(logits_torch, dim=1).numpy()
                # The ONNX logits go through torch's softmax as well so both
                # sides use the same implementation.
                probs_onnx = F.softmax(torch.from_numpy(logits_onnx), dim=1).numpy()

                # Never index past the end of the first batch.
                shown = min(num_samples, probs_torch.shape[0])
                print(f"Comparing soft predictions (Probabilities) for first {shown} samples:")

                for i in range(shown):
                    p_t = probs_torch[i]
                    p_o = probs_onnx[i]

                    # Indices of the three largest probabilities, highest first.
                    top3_t = np.argsort(p_t)[-3:][::-1]
                    top3_o = np.argsort(p_o)[-3:][::-1]

                    print(f"Sample {i}:")
                    print(f"  PyTorch Top-3: {top3_t} Probs: {p_t[top3_t]}")
                    print(f"  ONNX    Top-3: {top3_o} Probs: {p_o[top3_o]}")
                    print(f"  Max Diff (Probs): {np.max(np.abs(p_t - p_o)):.8f}")

    if total == 0:
        # An empty loader would otherwise end in a division by zero below.
        print("\nNo test samples available; skipping accuracy comparison.")
        return

    acc_torch = 100 * correct_torch / total
    acc_onnx = 100 * correct_onnx / total

    print(f"\nPyTorch Accuracy: {acc_torch:.2f}%")
    print(f"ONNX Accuracy:    {acc_onnx:.2f}%")
    print(f"Max Absolute Difference (Logits): {max_diff_logits:.8f}")


In [9]:
def quantize_and_analyze(original_onnx_path, dataset_name, test_loader, input_shape, batch_sizes):
    """Dynamically quantize an ONNX model, report its size and time it on CPU.

    The quantized model is written to `model_torch_{dataset_name}.quant.onnx`
    using QuantType.QUInt8 weights. File sizes of both models and their ratio
    are printed, followed by the mean latency of 10 runs (after one warm-up
    run) for every requested batch size.

    Args:
        original_onnx_path: Path of the float ONNX model to quantize.
        dataset_name: Used only to build the output file name.
        test_loader: Loader whose first batches supply the benchmark inputs.
        input_shape: Per-sample input shape, used to zero-pad a batch when
            fewer samples than the batch size were collected.
        batch_sizes: Batch sizes to benchmark.

    Returns:
        The path of the quantized ONNX model.
    """
    print(f"\nQuantizing model: {original_onnx_path}")
    quantized_model_path = f"model_torch_{dataset_name}.quant.onnx"

    quantize_dynamic(
        model_input=original_onnx_path,
        model_output=quantized_model_path,
        weight_type=QuantType.QUInt8
    )

    orig_size = os.path.getsize(original_onnx_path)
    quant_size = os.path.getsize(quantized_model_path)
    print(f"Original Model Size: {orig_size/1024/1024:.2f} MB")
    print(f"Quantized Model Size: {quant_size/1024/1024:.2f} MB")
    print(f"Compression Ratio: {orig_size/quant_size:.2f}x")

    print("\nBenchmarking Quantized Model:")
    sess_options = ort.SessionOptions()
    session = ort.InferenceSession(quantized_model_path, sess_options, providers=['CPUExecutionProvider'])
    input_name = session.get_inputs()[0].name

    # Collect batches until more than 200 samples are available.
    all_data = []
    for data, _ in test_loader:
        all_data.append(data)
        if len(all_data) * test_loader.batch_size > 200:
            break
    x_test = torch.cat(all_data, dim=0).numpy()

    print(f"{ 'Batch Size':<12} | { 'Quantized (ms)':<15}")
    print("-" * 30)

    n_runs = 10  # timed repetitions per measurement

    for bs in batch_sizes:
        batch_data = x_test[:bs]
        if batch_data.shape[0] < bs:
            # Not enough real samples: fill the rest of the batch with zeros.
            padding = np.zeros((bs - batch_data.shape[0], *input_shape), dtype=np.float32)
            batch_data = np.concatenate([batch_data, padding], axis=0)

        # Warmup
        session.run(None, {input_name: batch_data})
        # perf_counter is a monotonic, high-resolution clock; time.time() can
        # be too coarse for sub-millisecond intervals and can jump when the
        # system clock is adjusted.
        start = time.perf_counter()
        for _ in range(n_runs):
            session.run(None, {input_name: batch_data})
        end = time.perf_counter()
        quant_time = (end - start) / n_runs * 1000
        print(f"{bs:<12} | {quant_time:<15.4f}")

    return quantized_model_path


In [10]:
def run_pipeline(dataset_name):
    """Run the complete experiment for one dataset.

    Steps: load the data, train (or load) the PyTorch model, export it to
    ONNX, benchmark and compare PyTorch against ONNX, quantize the ONNX model,
    and finally print the quantized model's accuracy on the test loader.

    Args:
        dataset_name: Dataset key forwarded to the data-loading, training,
            export and quantization helpers.
    """
    banner = '=' * 20
    print(f"\n{banner} Running Pipeline for {dataset_name} {banner}")

    # The class count returned by the loader helper is not needed here.
    train_loader, test_loader, input_shape, _ = get_data_loaders(dataset_name)
    model = build_and_train(dataset_name, train_loader, epochs=5)

    onnx_path = convert_to_onnx(model, dataset_name, input_shape)

    benchmark_models(model, onnx_path, test_loader, input_shape)
    compare_accuracy_and_soft(model, onnx_path, test_loader)

    quant_path = quantize_and_analyze(
        onnx_path, dataset_name, test_loader, input_shape, [1, 8, 32, 128])

    print("\nVerifying Quantized Model Accuracy:")
    quant_session = ort.InferenceSession(quant_path, providers=['CPUExecutionProvider'])
    feed_name = quant_session.get_inputs()[0].name

    hits = 0
    seen = 0
    for images, labels in test_loader:
        scores = quant_session.run(None, {feed_name: images.numpy()})[0]
        guesses = np.argmax(scores, axis=1)
        hits += (guesses == labels.numpy()).sum().item()
        seen += labels.size(0)
    print(f"Quantized ONNX Accuracy: {100 * hits / seen:.2f}%")


In [11]:
# Run the full pipeline (load data, train or load the model, export, benchmark, quantize) on MNIST.
run_pipeline('mnist')


Loading mnist...


100.0%
100.0%
100.0%
100.0%


Starting training...


KeyboardInterrupt: 