In [32]:
from torchvision import datasets
from torchvision.transforms import ToTensor

# Fetch both MNIST splits as tensor datasets; the files are downloaded into
# ./data when they are not already present there.
mnist_options = {"root": "data", "download": True, "transform": ToTensor()}
train = datasets.MNIST(train=True, **mnist_options)
test = datasets.MNIST(train=False, **mnist_options)

# Sanity check: number of samples in each split.
print(len(train), len(test))

60000 10000


In [33]:
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split

# Seed PyTorch's global RNG so that weight initialisation and dropout masks are
# repeatable when the notebook is re-run from a fresh kernel (subject to any
# backend-level nondeterminism).
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

batch_size = 128
val_size = 5000
train_size = len(train) - val_size

# Dedicated generator for the split, so train/validation membership is fixed.
generator = torch.Generator().manual_seed(SEED)
train_subset, val_subset = random_split(train, [train_size, val_size], generator=generator)

# The shuffling loader gets its own seeded generator: the sequence of batch
# orders is then the same on every fresh run instead of depending on whatever
# the global RNG state happens to be when the loader is iterated.
shuffle_generator = torch.Generator().manual_seed(SEED)
train_loader = DataLoader(
    train_subset,
    batch_size=batch_size,
    shuffle=True,
    generator=shuffle_generator,
)
# Evaluation loaders keep dataset order.
val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test, batch_size=batch_size, shuffle=False)

class MLPClassifier(nn.Module):
    """Fully connected classifier: repeated Linear -> ReLU -> Dropout, then a Linear head.

    Args:
        input_dim: Number of features per sample after flattening.
        hidden_dims: Width of each hidden layer, in order.
        num_classes: Size of the output (logit) vector.
        dropout: Dropout probability applied after every hidden activation.
            Defaults to 0.2, the value that was previously hard-coded.
    """

    def __init__(self, input_dim=28*28, hidden_dims=(256, 128), num_classes=10, dropout=0.2):
        super().__init__()
        layers = []
        prev = input_dim
        for h in hidden_dims:
            layers.append(nn.Linear(prev, h))
            layers.append(nn.ReLU())
            # Always add the Dropout module (even for p=0) so layer indices,
            # and therefore state_dict keys, do not depend on the rate.
            layers.append(nn.Dropout(dropout))
            prev = h
        layers.append(nn.Linear(prev, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten everything after the first (batch) dimension and return logits."""
        # reshape (rather than view) also accepts non-contiguous inputs, for
        # which view raises a RuntimeError.
        x = x.reshape(x.size(0), -1)
        return self.net(x)

# Build the network with its default sizes and move it to the selected device,
# then create the loss and the Adam optimiser over all of its parameters.
model = MLPClassifier()
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
print(model)

device: cuda
MLPClassifier(
  (net): Sequential(
    (0): Linear(in_features=784, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=256, out_features=128, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.2, inplace=False)
    (6): Linear(in_features=128, out_features=10, bias=True)
  )
)


In [3]:
import time

def train_epoch(model, loader, criterion, optimizer, device=None):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Module to train; switched to training mode.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimiser stepped once per batch.
        device: Device the batches are moved to. When ``None`` (default) it is
            taken from the model's first parameter, so the function does not
            depend on a notebook-level ``device`` variable.

    Returns:
        Tuple ``(mean_loss, accuracy)`` over all samples seen.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # The batch loss is re-weighted by the batch size so that a smaller
        # final batch does not skew the epoch average.
        running_loss += loss.item() * labels.size(0)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    if total == 0:
        raise ValueError("train_epoch received an empty loader")
    return running_loss / total, correct / total

@torch.no_grad()
def evaluate(model, loader, criterion, device=None):
    """Score ``model`` on ``loader`` without tracking gradients.

    Args:
        model: Module to evaluate; switched to eval mode.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        device: Device the batches are moved to. When ``None`` (default) it is
            taken from the model's first parameter, so the function does not
            depend on a notebook-level ``device`` variable.

    Returns:
        Tuple ``(mean_loss, accuracy)`` over all samples seen.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        # Re-weight the batch loss by batch size to get a per-sample average.
        running_loss += loss.item() * labels.size(0)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    if total == 0:
        raise ValueError("evaluate received an empty loader")
    return running_loss / total, correct / total

# Train for a fixed number of epochs; each epoch is validated and timed
# (the timing covers both the training pass and the validation pass).
epochs = 5
for epoch in range(1, epochs + 1):
    start_time = time.perf_counter()
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer)
    val_loss, val_acc = evaluate(model, val_loader, criterion)
    epoch_time = time.perf_counter() - start_time

    # Assemble the one-line report from its training and validation halves.
    train_part = f"Epoch {epoch}: train loss {train_loss:.4f}, train acc {train_acc:.4f} | "
    val_part = f"val loss {val_loss:.4f}, val acc {val_acc:.4f} | time {epoch_time:.2f}s"
    print(train_part + val_part)

Epoch 1: train loss 0.4031, train acc 0.8844 | val loss 0.1926, val acc 0.9398 | time 6.48s
Epoch 2: train loss 0.1607, train acc 0.9526 | val loss 0.1251, val acc 0.9614 | time 4.90s
Epoch 3: train loss 0.1110, train acc 0.9655 | val loss 0.0972, val acc 0.9718 | time 4.83s
Epoch 4: train loss 0.0890, train acc 0.9720 | val loss 0.0930, val acc 0.9714 | time 5.06s
Epoch 5: train loss 0.0728, train acc 0.9775 | val loss 0.0865, val acc 0.9732 | time 4.96s


In [5]:
# Serialise the trained network to an ONNX file, with the model in eval mode.
onnx_path = "mnist_mlp.onnx"
model.eval()

# Example input handed to the exporter: one random 1x28x28 sample on `device`.
dummy_input = torch.randn(1, 1, 28, 28, device=device)

# Axis 0 of both the input and the output is declared dynamic and named "batch".
export_options = {
    "input_names": ["input"],
    "output_names": ["logits"],
    "dynamic_axes": {"input": {0: "batch"}, "logits": {0: "batch"}},
    "opset_version": 12,
    "do_constant_folding": True,
}
torch.onnx.export(model, dummy_input, onnx_path, **export_options)
print(f"Saved ONNX model to {onnx_path}")

verbose: False, log level: Level.ERROR

Saved ONNX model to mnist_mlp.onnx


In [4]:
# Same file name as the one used by the ONNX export cell.
# NOTE(review): presumably re-declared so later cells can run without re-exporting — confirm.
onnx_path = "mnist_mlp.onnx"

In [35]:
import os
import ctypes
import importlib

# Directory that is prepended to LD_LIBRARY_PATH below.
lib_path = "/usr/lib/x86_64-linux-gnu"
# NOTE(review): the dynamic loader normally reads LD_LIBRARY_PATH at process
# start, so changing it here may only affect child processes — confirm this
# line is actually needed.
os.environ["LD_LIBRARY_PATH"] = f"{lib_path}:" + os.environ.get("LD_LIBRARY_PATH", "")

# Try to load libpython explicitly and report the outcome either way; a
# failure is printed rather than raised, so the imports below still run.
try:
    ctypes.CDLL("libpython3.10.so.1.0")
    print("Loaded libpython3.10.so.1.0")
except OSError as e:
    print("Failed to load libpython3.10.so.1.0:", e)

# Import the AIMET ONNX modules by dotted name and keep a handle to each one.
py_libpymo = importlib.import_module("aimet_onnx.common.py_libpymo")
libpymo = importlib.import_module("aimet_onnx.common.libpymo")
defs = importlib.import_module("aimet_onnx.common.defs")
encoding = importlib.import_module("aimet_onnx._encoding")
qc_quantize_op = importlib.import_module("aimet_onnx.qc_quantize_op")
quantsim = importlib.import_module("aimet_onnx.quantsim")
aimet_onnx = importlib.import_module("aimet_onnx")

# Reload every module in the same order it was imported above. The final
# reload is the cell's last expression, so its return value (the aimet_onnx
# module) is what the cell displays.
importlib.reload(py_libpymo)
importlib.reload(libpymo)
importlib.reload(defs)
importlib.reload(encoding)
importlib.reload(qc_quantize_op)
importlib.reload(quantsim)
importlib.reload(aimet_onnx)

Loaded libpython3.10.so.1.0


<module 'aimet_onnx' from '/home/jovyan/video-proj-storage/mushfiq_files/python3.10/lib/python3.10/site-packages/aimet_onnx/__init__.py'>

In [12]:
import traceback
# Import probe: print the libpymo module on success; on any exception print
# its repr followed by the full traceback instead of raising.
try:
    from aimet_onnx.common import libpymo
    print("libpymo loaded:", libpymo)
except Exception as e:
    print("libpymo import error:", repr(e))
    traceback.print_exc()

libpymo loaded: <module 'aimet_onnx.common.libpymo' from '/home/jovyan/video-proj-storage/mushfiq_files/python3.10/lib/python3.10/site-packages/aimet_onnx/common/libpymo.py'>


In [13]:
import importlib
import traceback
# Import probe for the aimet_onnx.common._libpymo module, imported by dotted
# name: print the module on success; on any exception print its repr followed
# by the full traceback instead of raising.
try:
    _libpymo = importlib.import_module("aimet_onnx.common._libpymo")
    print("_libpymo loaded:", _libpymo)
except Exception as e:
    print("_libpymo import error:", repr(e))
    traceback.print_exc()

_libpymo loaded: <module 'aimet_onnx.common._libpymo' from '/home/jovyan/video-proj-storage/mushfiq_files/python3.10/lib/python3.10/site-packages/aimet_onnx/common/_libpymo.cpython-310-x86_64-linux-gnu.so'>


In [16]:
from aimet_onnx.common import _libpymo
# Smoke-test the BlockTensorQuantizer constructor with two different first
# arguments (an empty list and [1]); failures are printed, never raised.
_probe_cases = (
    ([], "BlockTensorQuantizer OK with empty list", "BlockTensorQuantizer error with empty list:"),
    ([1], "BlockTensorQuantizer OK with [1]", "BlockTensorQuantizer error with [1]:"),
)
for _first_arg, _ok_message, _error_message in _probe_cases:
    try:
        _ = _libpymo.BlockTensorQuantizer(_first_arg, 8, _libpymo.QuantizationMode.QUANTIZATION_TF)
        print(_ok_message)
    except Exception as e:
        print(_error_message, repr(e))

BlockTensorQuantizer OK with empty list
BlockTensorQuantizer OK with [1]


In [17]:
from aimet_onnx.common import libpymo as libpymo_py
from aimet_onnx.common import _libpymo as libpymo_bin
# Compare the two imported modules: first print whether both expose the very
# same BlockTensorQuantizer object, then print each module.
print("Same BlockTensorQuantizer:", libpymo_py.BlockTensorQuantizer is libpymo_bin.BlockTensorQuantizer)
print("libpymo_py:", libpymo_py)
print("libpymo_bin:", libpymo_bin)

Same BlockTensorQuantizer: True
libpymo_py: <module 'aimet_onnx.common.libpymo' from '/home/jovyan/video-proj-storage/mushfiq_files/python3.10/lib/python3.10/site-packages/aimet_onnx/common/libpymo.py'>
libpymo_bin: <module 'aimet_onnx.common._libpymo' from '/home/jovyan/video-proj-storage/mushfiq_files/python3.10/lib/python3.10/site-packages/aimet_onnx/common/_libpymo.cpython-310-x86_64-linux-gnu.so'>


In [28]:
from aimet_onnx.common import defs as defs_mod
from aimet_onnx.common import libquant_info
from aimet_onnx.qc_quantize_op import QcQuantizeOp

# Check that a QcQuantizeOp can be constructed from a default QcQuantizeInfo
# and the post_training_tf scheme; any exception is printed, not raised.
QuantScheme = defs_mod.QuantScheme
try:
    _qi = libquant_info.QcQuantizeInfo()
    _op = QcQuantizeOp(_qi, QuantScheme.post_training_tf)
    print("QcQuantizeOp init OK")
except Exception as e:
    print("QcQuantizeOp init error:", repr(e))

QcQuantizeOp init OK


In [26]:
from aimet_onnx.common import libquant_info
# Report whether libquant_info has a TensorQuantizerOpMode attribute and, when
# it does, print that attribute.
print("TensorQuantizerOpMode in libquant_info:", hasattr(libquant_info, "TensorQuantizerOpMode"))
if hasattr(libquant_info, "TensorQuantizerOpMode"):
    print("libquant_info.TensorQuantizerOpMode:", libquant_info.TensorQuantizerOpMode)

TensorQuantizerOpMode in libquant_info: False


In [8]:
import aimet_onnx

In [None]:
# AIMET ONNX PTQ (calibration-based)
import numpy as np
import onnx
import aimet_onnx
from aimet_onnx import QuantizationSimModel
from aimet_onnx.common.defs import QuantScheme
from enum import Enum

# Load exported ONNX model
model_onnx = onnx.load(onnx_path)

# Optional: simplify ONNX graph. onnxsim.simplify returns the simplified model
# together with a validation flag; the simplified graph is adopted only when
# that flag is set, otherwise the originally loaded model is kept.
try:
    import onnxsim
    simplified_model, simplify_ok = onnxsim.simplify(model_onnx)
    if simplify_ok:
        model_onnx = simplified_model
        print("ONNX model simplified")
    else:
        print("onnxsim could not validate the simplified model; using the original graph")
except Exception as e:
    # Best-effort step: a missing package or a failed simplification must not
    # stop the quantization flow.
    print("onnxsim not available or simplification failed:", e)

# Create QuantSim model (W8A8): int8 parameters and int8 activations.
providers = ["CPUExecutionProvider"]
sim = QuantizationSimModel(
    model_onnx,
    param_type=aimet_onnx.int8,
    activation_type=aimet_onnx.int8,
    quant_scheme=QuantScheme.post_training_tf,
    config_file="default",
    providers=providers,
)

# Name of the session's first input; used as the feed-dict key below.
input_name = sim.session.get_inputs()[0].name
NUM_CALIBRATION_SAMPLES = 1024
# Number of loader batches covering the calibration budget (at least one).
num_batches = max(1, NUM_CALIBRATION_SAMPLES // batch_size)

def onnx_data_generator(num_batches):
    """Yield calibration feeds taken from the front of ``train_loader``.

    Every yielded item is a dict that maps the module-level ``input_name`` to
    the batch's image tensor converted with ``.numpy()``; the labels are
    dropped. Iteration ends once ``num_batches`` feeds have been produced or
    the loader runs out, whichever comes first.
    """
    for batch_index, batch in enumerate(train_loader):
        if batch_index >= num_batches:
            return
        images, _unused_labels = batch
        yield {input_name: images.numpy()}

# Compute encodings using calibration data (train subset only).
# The generator yields at most num_batches feed dicts drawn from train_loader.
sim.compute_encodings(onnx_data_generator(num_batches))

def eval_onnx_session(session, loader):
    """Return the top-1 accuracy of an ONNX Runtime style ``session`` on ``loader``.

    Args:
        session: Object exposing ``get_inputs()`` and ``run(output_names, feed)``.
            Only its first input is fed.
        loader: Iterable of ``(inputs, labels)`` batches whose elements support
            ``.numpy()``.

    Returns:
        Fraction of samples whose arg-max over axis 1 of the first output
        equals the label.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    # Take the feed name from the session being evaluated rather than from a
    # notebook-level variable that may have been bound from a different session.
    feed_name = session.get_inputs()[0].name
    correct = 0
    total = 0
    for inputs, labels in loader:
        outputs = session.run(None, {feed_name: inputs.numpy()})[0]
        preds = outputs.argmax(axis=1)
        correct += int((preds == labels.numpy()).sum())
        total += labels.shape[0]
    if total == 0:
        raise ValueError("eval_onnx_session received an empty loader")
    return correct / total

# Accuracy of the simulation session (sim.session) on the validation loader.
val_acc_quant = eval_onnx_session(sim.session, val_loader)
print(f"Quantized (W8A8) validation accuracy: {val_acc_quant:.4f}")

def _jsonify(obj):
    if isinstance(obj, Enum):
        return obj.name
    if isinstance(obj, dict):
        return {k: _jsonify(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_jsonify(v) for v in obj]
    return obj

# Replace sim.quant_args with a copy in which Enum members are swapped for
# their names (see _jsonify).
# NOTE(review): presumably required so that export can serialise these
# arguments — confirm against the installed AIMET version.
sim.quant_args = _jsonify(sim.quant_args)

# Export quantized model and encodings into the current directory, using
# export_prefix as the base file name.
export_path = "."
export_prefix = "mnist_mlp_w8a8"
sim.export(export_path, export_prefix, export_model=True)
print(f"Exported quantized model to {export_prefix}.onnx and encodings file")

ONNX model simplified
2026-02-09 04:18:40,150 - Quant - INFO - Selecting DefaultOpInstanceConfigGenerator to compute the specialized config. hw_version:None


[1;31m2026-02-09 04:18:40.155690261 [E:onnxruntime:Default, env.cc:234 ThreadMain] pthread_setaffinity_np failed for thread: 14041, index: 0, mask: {64, }, error code: 22 error msg: Invalid argument. Specify the number of threads explicitly so the affinity is not set.[m
[1;31m2026-02-09 04:18:40.157894411 [E:onnxruntime:Default, env.cc:234 ThreadMain] pthread_setaffinity_np failed for thread: 14048, index: 7, mask: {4, }, error code: 22 error msg: Invalid argument. Specify the number of threads explicitly so the affinity is not set.[m
[1;31m2026-02-09 04:18:40.157997601 [E:onnxruntime:Default, env.cc:234 ThreadMain] pthread_setaffinity_np failed for thread: 14049, index: 8, mask: {68, }, error code: 22 error msg: Invalid argument. Specify the number of threads explicitly so the affinity is not set.[m
[1;31m2026-02-09 04:18:40.158146340 [E:onnxruntime:Default, env.cc:234 ThreadMain] pthread_setaffinity_np failed for thread: 14050, index: 9, mask: {5, }, error code: 22 error msg: 

In [None]:
# Final test on quantized ONNX model (CPU timing)
import onnxruntime as ort
import time

# Build the model path from the export settings so that it always names the
# file written by sim.export() ("mnist_mlp_w8a8.onnx" with the current prefix)
# instead of a separately hard-coded file name that can drift out of sync.
quant_onnx_path = f"{export_path}/{export_prefix}.onnx"
# NOTE(review): confirm that the exported .onnx actually contains the
# quantization ops; if the encodings live only in the side file, this cell
# measures the un-quantized graph.
sess = ort.InferenceSession(quant_onnx_path, providers=["CPUExecutionProvider"])
input_name = sess.get_inputs()[0].name

# Time one full pass over the test loader on the CPU execution provider.
start_time = time.perf_counter()
test_acc_quant = eval_onnx_session(sess, test_loader)
test_time_quant = time.perf_counter() - start_time
print(f"Quantized ONNX test acc: {test_acc_quant:.4f} | time {test_time_quant:.2f}s")

In [None]:
# Baseline: time and score the PyTorch model on the test loader.
start_time = time.perf_counter()
test_loss, test_acc = evaluate(model, test_loader, criterion)
test_time = time.perf_counter() - start_time
print(f"Final test: loss {test_loss:.4f}, acc {test_acc:.4f} | time {test_time:.2f}s")

Final test: loss 0.0720, acc 0.9763 | time 0.47s
