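## MNIST CNN (simple handwritten-digit classifier)

Define a small CNN in PyTorch, trace it with TorchScript, convert it to TVM Relay, and compile it to a shared library for an x86-64 CPU.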


In [1]:

import torch
from torchvision import transforms
import numpy as np
from PIL import Image
from tvm import relay
import onnx

class MNISTCNN(torch.nn.Module):
    """Small two-conv-layer CNN for single-channel square images (MNIST-style).

    Architecture:
        conv(1->32, 3x3) -> ReLU -> conv(32->64, 3x3) -> ReLU
        -> flatten -> linear(->128) -> ReLU -> linear(->num_classes)

    There is no pooling or padding, so each 3x3 conv shrinks the spatial size
    by 2 (28x28 -> 26x26 -> 24x24 for the default input).

    Args:
        input_size: Height/width of the square input image. Defaults to 28
            (MNIST). fc1's input width is derived from it.
        num_classes: Number of output logits. Defaults to 10.

    Raises:
        ValueError: if input_size is too small to survive two 3x3 convolutions.
    """

    def __init__(self, input_size: int = 28, num_classes: int = 10):
        super().__init__()
        # Two unpadded 3x3 convs need at least a 5x5 input to leave a 1x1 map.
        if input_size < 5:
            raise ValueError(f"input_size must be >= 5, got {input_size}")

        self.conv1 = torch.nn.Conv2d(1, 32, 3, 1)
        self.conv2 = torch.nn.Conv2d(32, 64, 3, 1)

        # Infer fc1's input width by pushing a dummy image through the same
        # feature extractor forward() uses, so the two can never disagree.
        # The probe only needs shapes, so skip building an autograd graph.
        # For the default 28x28 input: 64 * 24 * 24 = 36864.
        with torch.no_grad():
            dummy = torch.zeros(1, 1, input_size, input_size)
            flatten_size = self._features(dummy).shape[1]

        self.fc1 = torch.nn.Linear(flatten_size, 128)
        self.fc2 = torch.nn.Linear(128, num_classes)

    def _features(self, x):
        """Apply both conv+ReLU stages and flatten to shape (N, C*H*W)."""
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        return torch.flatten(x, 1)

    def forward(self, x):
        """Map images of shape (N, 1, input_size, input_size) to (N, num_classes) raw logits."""
        x = self._features(x)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# Seed the RNG *before* building the model so both the randomly initialised
# weights and the random example input are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Instantiate and smoke-test: one forward pass on a random 1x1x28x28 input.
mnist_model = MNISTCNN().eval()
x_mnist = torch.randn(1, 1, 28, 28)
with torch.no_grad():  # inference only; no autograd graph needed
    y_mnist = mnist_model(x_mnist)
print("MNIST CNN output shape:", y_mnist.shape)
assert y_mnist.shape == (1, 10), "expected one row of 10 class logits"

# Trace to TorchScript: the ops executed for the example input x_mnist are
# recorded into a graph that Relay's PyTorch frontend can import.
scripted_mnist = torch.jit.trace(mnist_model, x_mnist).eval()

# Convert to TVM Relay. The frontend takes (input_name, shape) pairs naming
# each graph input; "input0" is the name the Relay module will use for it.
input_name_mnist = "input0"
shape_list_mnist = [(input_name_mnist, x_mnist.shape)]
mod_mnist, params_mnist = relay.frontend.from_pytorch(scripted_mnist, shape_list_mnist)
print("✅ MNIST CNN successfully converted to Relay")


MNIST CNN output shape: torch.Size([1, 10])
✅ MNIST CNN successfully converted to Relay


In [2]:
import tvm 
from tvm.contrib import graph_executor

# Choose the target architecture. This example compiles for x86-64 Linux with
# -mcpu=haswell, so the generated code may use Haswell ISA extensions
# (e.g. AVX2) and may not run on older CPUs.
target_str = "llvm -mtriple=x86_64-linux-gnu -mcpu=haswell"
target = tvm.target.Target(target_str)
dev = tvm.cpu(0)  # CPU device handle; not used in this cell

# One name drives both the log messages and the exported file name.
model_name_mnist = "MNIST_CNN"
lib_name = f"{model_name_mnist}_tvm.so"

print(f"Compiling {model_name_mnist}...")
# relay.build may warn that operators have not been tuned; that is expected
# when no tuning logs are supplied, and does not affect correctness.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod_mnist, target=target, params=params_mnist)

# Save the compiled module as a shared library in the working directory.
lib.export_library(lib_name)
print(f"✅ {model_name_mnist} compiled and exported as {lib_name}")


One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.


Compiling MNIST_CNN...
✅ MNIST_CNN compiled and exported as MNIST_CNN_tvm.so
