# MNIST



## Offline phase

Create MNISTNet model

In [5]:
import torch

# PyTorch module that is exported to ONNX to generate the training graphs.
class MNISTNet(torch.nn.Module):
    """Fully connected classifier: Linear -> ReLU -> Linear.

    Args:
        input_size: Number of features in each (flattened) input sample.
        hidden_size: Width of the hidden layer.
        num_classes: Number of logits produced per sample.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()

        # The attribute names become part of the parameter names
        # ("fc1.weight", "fc2.bias", ...), so they are kept stable.
        self.fc1 = torch.nn.Linear(in_features=input_size, out_features=hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(in_features=hidden_size, out_features=num_classes)

    def forward(self, model_input):
        """Return the raw class logits computed for ``model_input``."""
        hidden = self.relu(self.fc1(model_input))
        return self.fc2(hidden)

# Build the MNISTNet instance on the CPU: 784 input features,
# 500 hidden units and 10 output classes.
device = "cpu"
batch_size = 64
input_size = 784
hidden_size = 500
output_size = 10
pt_model = MNISTNet(input_size, hidden_size, output_size)
pt_model = pt_model.to(device)

Convert the torch model to ONNX and save it

In [10]:
import torch
import io
import onnx

# Destination file for the exported forward graph.
onnx_file_path = "mnist_net.onnx"

# Example inputs handed to the exporter: one random batch.
model_inputs = (torch.randn(batch_size, input_size, device=device),)

# Run the model once on the example inputs and normalise the result to a list.
model_outputs = pt_model(*model_inputs)
if isinstance(model_outputs, torch.Tensor):
    model_outputs = [model_outputs]

# Names given to the graph input/output; axis 0 of both is exported as a
# symbolic "batch_size" dimension instead of a fixed size.
input_names = ["input"]
output_names = ["output"]
dynamic_axes = {tensor_name: {0: "batch_size"} for tensor_name in ("input", "output")}


torch.onnx.export(
    pt_model,
    model_inputs,
    onnx_file_path,
    # Export in training mode, with weights stored inside the file and
    # constant folding switched off.
    training=torch.onnx.TrainingMode.TRAINING,
    export_params=True,
    keep_initializers_as_inputs=False,
    do_constant_folding=False,
    opset_version=14,
    # Graph signature.
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,
)


Load model and generate artifacts

In [13]:
from onnxruntime.training import artifacts

# Reload the exported forward graph from disk.
onnx_file_path = "mnist_net.onnx"
onnx_model = onnx.load(onnx_file_path)

# Record, per parameter name, whether PyTorch marks it as trainable, then
# split the names into the two lists passed to generate_artifacts.
trainable_flags = {name: param.requires_grad for name, param in pt_model.named_parameters()}
requires_grad = [name for name, trainable in trainable_flags.items() if trainable]
frozen_params = [name for name, trainable in trainable_flags.items() if not trainable]

# Write the artifacts into the "data" directory; "output" is requested as an
# additional graph output next to the loss.
artifacts.generate_artifacts(
    onnx_model,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
    additional_output_names=["output"],
    artifact_directory="data",
)

2025-01-24 13:26:15.178361813 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer ConstantSharing modified: 0 with status: OK
2025-01-24 13:26:15.178411794 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer ShapeInputMerge modified: 0 with status: OK
2025-01-24 13:26:15.178426480 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer LayerNormFusion modified: 0 with status: OK
2025-01-24 13:26:15.178459776 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer CommonSubexpressionElimination modified: 0 with status: OK
2025-01-24 13:26:15.178472646 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer GeluFusion modified: 0 with status: OK
2025-01-24 13:26:15.178484554 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer SimplifiedLayerNormFusion modified: 0 with status: OK
2025-01-24 13:26:15.178496628 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer Fa

## Online phase

Prepare dataset

In [15]:
from torchvision import datasets, transforms

# Loader settings: batches of 64 for training, batches of 1000 for testing.
batch_size = 64
test_batch_size = 1000
train_kwargs = dict(batch_size=batch_size)
test_kwargs = dict(batch_size=test_batch_size)

# Convert each image to a tensor, then normalise it with mean 0.1307 and
# standard deviation 0.3081.
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize((0.1307,), (0.3081,))
transform = transforms.Compose([to_tensor, normalize])

# Only the training split asks for a download; both splits live under 'data'.
dataset1 = datasets.MNIST('data', train=True, download=True, transform=transform)
dataset2 = datasets.MNIST('data', train=False, transform=transform)

train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 111] Connection refused>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:03<00:00, 2.89MB/s]


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 111] Connection refused>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 294kB/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 111] Connection refused>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 2.42MB/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 111] Connection refused>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 2.83MB/s]


Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw



Train the model

In [None]:
from onnxruntime.training.api import CheckpointState, Module, Optimizer
import numpy as np
import evaluate

# Number of full train/test passes run by the loop at the end of this cell.
num_epochs = 5

# Load the checkpoint state from the "data" directory, which is the
# artifact_directory passed to generate_artifacts in the offline phase
# (presumably where this checkpoint was written -- verify).
state = CheckpointState.load_checkpoint("data/checkpoint")

# Create the module from the training graph, the loaded state and the eval
# graph; model.train() / model.eval() are called on it in the loops below.
model = Module("data/training_model.onnx", state, "data/eval_model.onnx")

# Create the optimizer from the optimizer graph and bind it to the module;
# optimizer.step() is called once per training batch.
optimizer = Optimizer("data/optimizer_model.onnx", model)

# Helper that turns a batch of logits into predicted class indices.
def get_pred(logits):
    """Return, for every row of ``logits``, the index of its largest value (axis 1)."""
    predicted_classes = np.argmax(logits, axis=1)
    return predicted_classes

# Training loop:
def train(epoch):
    """Run one optimisation pass over ``train_loader`` and print the mean loss.

    For every batch the module is called with the flattened images and the
    int64 labels, the optimizer takes a step and the gradients are reset.

    Args:
        epoch: Zero-based epoch index; only used in the printed summary.
    """
    model.train()
    losses = []
    # Iterate the loader directly (the enumerate index was never used) so that
    # tqdm can read the loader's length and display a total / ETA.
    for data, target in tqdm(train_loader):
        # Flatten every image to a 784-element row; labels are passed as int64.
        forward_inputs = [data.reshape(len(data), 784).numpy(), target.numpy().astype(np.int64)]
        train_loss, _ = model(*forward_inputs)
        optimizer.step()
        model.lazy_reset_grad()
        losses.append(train_loss)

    # An empty loader reports NaN instead of raising ZeroDivisionError.
    mean_loss = sum(losses) / len(losses) if losses else float("nan")
    print(f'Epoch: {epoch+1},Train Loss: {mean_loss:.4f}')

# Test loop:
def test(epoch):
    """Evaluate the module on ``test_loader`` and print mean loss and accuracy.

    Args:
        epoch: Zero-based epoch index; only used in the printed summary.
    """
    model.eval()
    losses = []
    metric = evaluate.load('accuracy')

    # Must iterate test_loader (not train_loader) so the reported loss and
    # accuracy are measured on the test split. Iterating the loader directly
    # also lets tqdm read its length and display a total / ETA.
    for data, target in tqdm(test_loader):
        # Flatten every image to a 784-element row; labels are passed as int64.
        forward_inputs = [data.reshape(len(data), 784).numpy(), target.numpy().astype(np.int64)]
        test_loss, logits = model(*forward_inputs)
        metric.add_batch(references=target, predictions=get_pred(logits))
        losses.append(test_loss)

    metrics = metric.compute()
    # An empty loader reports NaN instead of raising ZeroDivisionError.
    mean_loss = sum(losses) / len(losses) if losses else float("nan")
    print(f'Epoch: {epoch+1}, Test Loss: {mean_loss:.4f}, Accuracy : {metrics["accuracy"]:.2f}')

from tqdm import tqdm
# Each epoch runs a training pass followed by an evaluation pass, announcing
# the phase name before it starts.
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}")
    for phase_label, run_phase in (("train", train), ("test", test)):
        print(phase_label)
        run_phase(epoch)

# Export the trained module for inference, keeping "output" as the graph output.
model.export_model_for_inferencing("data/inference_model.onnx", ["output"])

  0%|          | 0/5 [00:12<?, ?it/s]


KeyboardInterrupt: 

Inferencing

In [None]:
import onnxruntime 
import matplotlib.pyplot as plt

# Open the exported inference model with the CPU execution provider.
session = onnxruntime.InferenceSession(
    'data/inference_model.onnx',
    providers=['CPUExecutionProvider'],
)

# Use the first image of the first test batch as the sample to classify.
first_batch = next(iter(test_loader))
data = first_batch[0][0]

# Feed the flattened sample to the first graph input and fetch the first graph output.
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
flat_sample = data.reshape(1, 784).numpy()
output = session.run([output_name], {input_name: flat_sample})

# Display the image that was classified.
plt.imshow(data[0], cmap='gray')
plt.show()

print("Predicted Label : ",get_pred(output[0]))