# Quantized ONNX Inference (AIMET Encodings)
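This notebook loads a quantized ONNX model together with its AIMET `.encodings` file, then measures the model's top-1 accuracy on the MNIST test set with ONNX Runtime. The encodings file is only loaded for inspection (its top-level keys are printed); inference uses the ONNX model alone. If the MNIST dataset is not found locally, it is downloaded automatically.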

In [2]:
import os
import json
import numpy as np
import onnxruntime as ort
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

In [3]:
# Paths to the quantized artifacts: the ONNX graph and its AIMET encodings JSON
onnx_path = "mnist_mlp_w8a8.onnx"
encodings_path = "mnist_mlp_w8a8.encodings"

# Fail fast with a clear message. isfile (rather than exists) also rejects a
# directory sitting at the expected path.
if not os.path.isfile(onnx_path):
    raise FileNotFoundError(f"Missing ONNX file: {onnx_path}")
if not os.path.isfile(encodings_path):
    raise FileNotFoundError(f"Missing encodings file: {encodings_path}")

# Load the encodings JSON for verification/logging only; the inference code in
# this notebook does not consume it. JSON text is UTF-8 by spec, so decode it
# explicitly instead of relying on the platform's default locale encoding.
with open(encodings_path, "r", encoding="utf-8") as f:
    encodings = json.load(f)

print("Loaded encodings keys:", list(encodings.keys()))

Loaded encodings keys: ['activation_encodings', 'param_encodings', 'quantizer_args', 'version']


In [4]:
# MNIST test split; download=True fetches it under data_root on first run
data_root = "data"
batch_size = 128

test = datasets.MNIST(
    data_root,
    train=False,
    transform=ToTensor(),
    download=True,
)
# Keep the original sample order so results are reproducible across runs
test_loader = DataLoader(test, shuffle=False, batch_size=batch_size)

print(f"Test samples: {len(test)}")

Test samples: 10000


In [5]:
# Build a CPU-only ONNX Runtime session for the quantized model and record the
# name of its first declared input (used as the feed key at inference time).
sess = ort.InferenceSession(
    onnx_path,
    providers=["CPUExecutionProvider"],
)
input_name = sess.get_inputs()[0].name

def eval_onnx_session(session, loader):
    """Compute the top-1 classification accuracy of an ONNX Runtime session.

    Each batch from ``loader`` is fed to the session's first declared input.
    The argmax over axis 1 of the session's first output is compared against
    the batch labels.

    Args:
        session: An ``onnxruntime.InferenceSession``, or any object exposing
            ``get_inputs()`` (items with a ``.name``) and
            ``run(output_names, input_feed)``.
        loader: Iterable of ``(inputs, labels)`` batches whose items provide
            ``.numpy()`` (e.g. torch tensors yielded by a ``DataLoader``).

    Returns:
        float: Fraction of correctly classified samples, in [0, 1].

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    # Resolve the feed name from *this* session rather than from a module-level
    # ``input_name`` that may belong to a different session.
    feed_name = session.get_inputs()[0].name
    correct = 0
    total = 0
    for inputs, labels in loader:
        logits = session.run(None, {feed_name: inputs.numpy()})[0]
        labels_np = labels.numpy()
        # assumes logits are (batch, num_classes); argmax over axis 1 picks the class
        preds = logits.argmax(axis=1)
        correct += int((preds == labels_np).sum())
        total += int(labels_np.shape[0])
    if total == 0:
        raise ValueError("loader yielded no samples; cannot compute accuracy")
    return correct / total

# Evaluate the quantized model on the full MNIST test split and report accuracy
acc = eval_onnx_session(sess, test_loader)
print("Quantized ONNX test accuracy: {:.4f}".format(acc))

Quantized ONNX test accuracy: 0.9771
