In [2]:
from torchvision import datasets
from torchvision.transforms import ToTensor

# Fetch MNIST into ./data when it is not already present on disk.
# Both splits share the same storage root and tensor conversion.
mnist_options = {"root": "data", "download": True, "transform": ToTensor()}
train = datasets.MNIST(train=True, **mnist_options)
test = datasets.MNIST(train=False, **mnist_options)

# Sanity check: report the number of samples in each split.
print(len(train), len(test))

60000 10000


In [6]:
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split

# Seed the global RNG so that weight initialisation, dropout masks and the
# shuffle order of the training loader are repeatable on a fresh kernel.
# Previously only the train/validation split was seeded.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cpu")
print("device:", device)

batch_size = 128
val_size = 5000
train_size = len(train) - val_size
# A dedicated generator keeps the train/validation split fixed regardless of
# how much the global RNG has been consumed before this point.
generator = torch.Generator().manual_seed(SEED)
train_subset, val_subset = random_split(train, [train_size, val_size], generator=generator)

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test, batch_size=batch_size, shuffle=False)

class MLPClassifier(nn.Module):
    """Multi-layer perceptron that maps flattened inputs to class logits.

    Each hidden layer is ``Linear -> ReLU -> Dropout``; a final ``Linear``
    produces ``num_classes`` raw logits (no softmax is applied).

    Args:
        input_dim: Number of features per sample after flattening.
        hidden_dims: Width of each hidden layer, in order. May be empty, in
            which case the network is a single linear layer.
        num_classes: Number of output logits.
        dropout: Drop probability used after every hidden activation.
    """

    def __init__(self, input_dim=28*28, hidden_dims=(256, 128), num_classes=10, dropout=0.2):
        super().__init__()
        layers = []
        prev = input_dim
        for h in hidden_dims:
            layers.append(nn.Linear(prev, h))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout))
            prev = h
        layers.append(nn.Linear(prev, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten everything after the batch dimension and return logits."""
        # flatten() copies when the memory layout requires it, whereas view()
        # raises on non-contiguous inputs (e.g. after a transpose).
        x = x.flatten(start_dim=1)
        return self.net(x)

# Instantiate the network on the selected device, then pair it with a
# cross-entropy objective and an Adam optimiser (learning rate 1e-3).
model = MLPClassifier()
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
print(model)

device: cpu
MLPClassifier(
  (net): Sequential(
    (0): Linear(in_features=784, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=256, out_features=128, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.2, inplace=False)
    (6): Linear(in_features=128, out_features=10, bias=True)
  )
)


In [7]:
import time

def train_epoch(model, loader, criterion, optimizer, device=None):
    """Run one optimisation pass over ``loader`` and return averaged metrics.

    Args:
        model: Module to train; it is switched to training mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Callable mapping ``(outputs, labels)`` to a scalar loss.
            The per-batch loss is weighted by batch size, which assumes the
            criterion averages over the batch.
        optimizer: Optimizer whose ``zero_grad``/``step`` apply the update.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU when the model has no parameters),
            so the function does not depend on a notebook-level global.

    Returns:
        Tuple ``(mean_loss, accuracy)`` computed over every sample seen.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Weight by batch size so a smaller final batch is not over-counted.
        running_loss += loss.item() * labels.size(0)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    if total == 0:
        raise ValueError("loader yielded no samples; cannot compute metrics")
    return running_loss / total, correct / total

@torch.no_grad()
def evaluate(model, loader, criterion, device=None):
    """Score ``model`` on ``loader`` without tracking gradients.

    Args:
        model: Module to evaluate; it is switched to evaluation mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Callable mapping ``(outputs, labels)`` to a scalar loss.
            The per-batch loss is weighted by batch size, which assumes the
            criterion averages over the batch.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU when the model has no parameters),
            so the function does not depend on a notebook-level global.

    Returns:
        Tuple ``(mean_loss, accuracy)`` computed over every sample seen.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        # Weight by batch size so a smaller final batch is not over-counted.
        running_loss += loss.item() * labels.size(0)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    if total == 0:
        raise ValueError("loader yielded no samples; cannot compute metrics")
    return running_loss / total, correct / total

# Train for a fixed number of epochs; after each one, score the validation
# split and report the wall-clock time covering both passes.
epochs = 5
for epoch in range(1, epochs + 1):
    start_time = time.perf_counter()
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer)
    val_loss, val_acc = evaluate(model, val_loader, criterion)
    epoch_time = time.perf_counter() - start_time

    train_report = f"Epoch {epoch}: train loss {train_loss:.4f}, train acc {train_acc:.4f} | "
    val_report = f"val loss {val_loss:.4f}, val acc {val_acc:.4f} | time {epoch_time:.2f}s"
    print(train_report + val_report)

Epoch 1: train loss 0.4065, train acc 0.8846 | val loss 0.1872, val acc 0.9438 | time 6.23s
Epoch 2: train loss 0.1615, train acc 0.9513 | val loss 0.1275, val acc 0.9610 | time 6.17s
Epoch 3: train loss 0.1142, train acc 0.9652 | val loss 0.1070, val acc 0.9676 | time 6.16s
Epoch 4: train loss 0.0921, train acc 0.9715 | val loss 0.0916, val acc 0.9724 | time 7.07s
Epoch 5: train loss 0.0743, train acc 0.9773 | val loss 0.0821, val acc 0.9746 | time 7.16s


In [10]:
# Serialise the trained network to ONNX, leaving the batch axis dynamic.
onnx_path = "mnist_mlp.onnx"
model.eval()  # evaluation mode before the export call

# The export needs one example input; a single random 1x28x28 image is used.
dummy_input = torch.randn(1, 1, 28, 28, device=device)
export_options = {
    "input_names": ["input"],
    "output_names": ["logits"],
    # Axis 0 of both the input and the output is named "batch".
    "dynamic_axes": {"input": {0: "batch"}, "logits": {0: "batch"}},
    "opset_version": 12,
    "do_constant_folding": True,
}
torch.onnx.export(model, dummy_input, onnx_path, **export_options)
print(f"Saved ONNX model to {onnx_path}")

Saved ONNX model to mnist_mlp.onnx


In [11]:
# AIMET ONNX PTQ (calibration-based)
import numpy as np
import onnx
import aimet_onnx
from aimet_onnx import QuantizationSimModel
from aimet_onnx.common.defs import QuantScheme

# Load exported ONNX model
model_onnx = onnx.load(onnx_path)

# Optional: simplify ONNX graph.
# onnxsim.simplify returns (model, check); the simplified graph is adopted
# only when the check flag is truthy. On any failure the model loaded above
# is kept as-is.
try:
    import onnxsim
    simplified_model, simplification_ok = onnxsim.simplify(model_onnx)
except Exception as e:
    # Best-effort step: a missing package or a simplifier error is reported
    # but does not stop the notebook.
    print("onnxsim not available or simplification failed:", e)
else:
    if simplification_ok:
        model_onnx = simplified_model
        print("ONNX model simplified")
    else:
        print("onnxsim could not validate the simplified model; keeping the original graph")

# Calibration budget, expressed in whole batches (always at least one).
NUM_CALIBRATION_SAMPLES = 1024
num_batches = max(1, NUM_CALIBRATION_SAMPLES // batch_size)

# Build the quantization simulation (W8A16): int8 parameters, int16
# activations, min-max scheme, default config, running on the CPU provider.
providers = ["CPUExecutionProvider"]
quantsim_options = dict(
    param_type=aimet_onnx.int8,
    activation_type=aimet_onnx.int16,
    quant_scheme=QuantScheme.min_max,
    config_file="default",
    providers=providers,
)
sim = QuantizationSimModel(model_onnx, **quantsim_options)

# Name of the first input of the simulation session, used as the feed key.
input_name = sim.session.get_inputs()[0].name

def onnx_data_generator(num_batches):
    """Yield up to ``num_batches`` feed dicts drawn from ``train_loader``.

    Each item maps the module-level ``input_name`` to one batch of inputs
    converted to a NumPy array; labels are discarded.
    """
    produced = 0
    for batch, _labels in train_loader:
        if produced >= num_batches:
            return
        yield {input_name: batch.numpy()}
        produced += 1

# Calibrate the quantizer encodings on batches from the training subset only.
calibration_batches = onnx_data_generator(num_batches)
sim.compute_encodings(calibration_batches)

def eval_onnx_session(session, loader):
    """Compute top-1 accuracy of an ONNX Runtime-style session over ``loader``.

    Args:
        session: Object exposing ``get_inputs()`` (whose first entry has a
            ``name``) and ``run(output_names, feed_dict)``. The first output
            is treated as per-class scores with classes on axis 1.
        loader: Iterable yielding ``(inputs, labels)`` batches of tensors
            that support ``.numpy()``.

    Returns:
        Fraction of samples whose arg-max prediction equals the label.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    # Ask the session itself for its input name rather than relying on a
    # module-level variable that may belong to a different session.
    feed_name = session.get_inputs()[0].name
    correct = 0
    total = 0
    for inputs, labels in loader:
        outputs = session.run(None, {feed_name: inputs.numpy()})[0]
        preds = outputs.argmax(axis=1)
        # int() keeps the running count a plain Python integer.
        correct += int((preds == labels.numpy()).sum())
        total += labels.shape[0]
    if total == 0:
        raise ValueError("loader yielded no samples; cannot compute accuracy")
    return correct / total

# Destination directory and filename prefix for the exported artefacts.
export_path = "."
export_prefix = "mnist_mlp_w8a16"

# Accuracy of the quantization-simulation session on the validation split.
val_acc_quant = eval_onnx_session(sim.session, val_loader)
print(f"Quantized (W8A16) validation accuracy: {val_acc_quant:.4f}")

# Write the model and its encodings under the chosen prefix.
sim.export(export_path, export_prefix, export_model=True)
export_message = f"Exported quantized model to {export_prefix}.onnx and encodings file"
print(export_message)

ModuleNotFoundError: No module named 'aimet_onnx'

In [None]:
# Final test on quantized ONNX model (CPU timing)
import onnxruntime as ort
import time

# Open the exported model in a plain ONNX Runtime session on the CPU provider.
# NOTE(review): presumably the exported .onnx embeds the quantization; if the
# export writes a float graph plus a separate encodings file, this measures
# the unquantized model. Confirm against the AIMET export documentation.
quant_onnx_path = "mnist_mlp_w8a16.onnx"
cpu_providers = ["CPUExecutionProvider"]
sess = ort.InferenceSession(quant_onnx_path, providers=cpu_providers)
# Rebind `input_name` to the first input of this session.
input_name = sess.get_inputs()[0].name

# Time one full pass over the test loader.
start_time = time.perf_counter()
test_acc_quant = eval_onnx_session(sess, test_loader)
test_time_quant = time.perf_counter() - start_time
quant_report = f"Quantized ONNX test acc: {test_acc_quant:.4f} | time {test_time_quant:.2f}s"
print(quant_report)

In [8]:
# Time one full pass of the PyTorch model over the test loader and report it.
start_time = time.perf_counter()
test_metrics = evaluate(model, test_loader, criterion)
test_time = time.perf_counter() - start_time
test_loss, test_acc = test_metrics
final_report = f"Final test: loss {test_loss:.4f}, acc {test_acc:.4f} | time {test_time:.2f}s"
print(final_report)

Final test: loss 0.0720, acc 0.9763 | time 0.47s
