In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from tqdm.notebook import tqdm
import time
from PIL import Image

In [2]:
# Device-agnostic setup and reproducibility
# Seed PyTorch's random number generators before any model or DataLoader is
# created, so weight initialisation and shuffling repeat across
# Restart & Run All (some GPU kernels may still be non-deterministic).
SEED = 42
torch.manual_seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
# Preprocessing: convert to a tensor, then normalise with mean 0.5 / std 0.5
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)

# Arguments shared by the training and test splits
mnist_kwargs = dict(root="./data", download=True, transform=transform)

# Training split, reshuffled every epoch
trainset = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=128, shuffle=True)

# Test split, iterated in a fixed order
testset = torchvision.datasets.MNIST(train=False, **mnist_kwargs)
testloader = torch.utils.data.DataLoader(dataset=testset, batch_size=256, shuffle=False)

In [4]:
# CNN model definition
class Net(nn.Module):
    """Small convolutional classifier: 1x28x28 input -> 10 raw class scores.

    Two 3x3 convolutions (padding=1, so the spatial size is preserved) are
    each followed by ReLU and 2x2 max-pooling (28 -> 14 -> 7). The resulting
    64*7*7 features go through two fully connected layers.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64*7*7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Compute unnormalised class scores.

        Args:
            x: tensor of shape (batch, 1, 28, 28); a single unbatched
                (1, 28, 28) image is also accepted.

        Returns:
            Tensor of shape (batch, 10), or (1, 10) for an unbatched image.

        Raises:
            RuntimeError: if the spatial size does not reduce to 7x7 after the
                two pooling stages (the features then do not match ``fc1``).
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # An unbatched (C, H, W) input gets a leading batch dimension so the
        # result keeps a (1, 10) shape.
        if x.dim() == 3:
            x = x.unsqueeze(0)
        # Flatten everything except the batch dimension. Unlike
        # view(-1, 64*7*7), this never folds surplus features into the batch
        # dimension: a wrongly sized input fails at fc1 instead of silently
        # returning extra rows.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# Build the network, then place its parameters on the selected device
model = Net()
model = model.to(device)

In [5]:
# Training objective (cross-entropy on the model's raw outputs) and optimiser
LEARNING_RATE = 1e-3
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [6]:
# Train for three epochs; after each epoch, evaluate on the test loader.
# Losses are weighted by batch size so the epoch value is a per-sample mean.
for epoch in tqdm(range(3), desc="Epochs"):
    # ---- training pass ----
    model.train()
    train_loss_sum, train_hits, train_seen = 0.0, 0, 0
    for batch_x, batch_y in tqdm(trainloader, desc=f"Epoch {epoch+1} Training"):
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        logits = model(batch_x)
        batch_loss = loss_fn(logits, batch_y)
        batch_loss.backward()
        optimizer.step()

        train_loss_sum += batch_loss.item() * batch_x.size(0)
        train_hits += (logits.argmax(1) == batch_y).sum().item()
        train_seen += batch_y.size(0)

    epoch_train_loss = train_loss_sum / train_seen
    epoch_train_accuracy = 100 * train_hits / train_seen

    # ---- evaluation pass (no gradient tracking) ----
    model.eval()
    test_loss_sum, test_hits, test_seen = 0.0, 0, 0
    with torch.no_grad():
        for batch_x, batch_y in tqdm(testloader, desc=f"Epoch {epoch+1} Testing"):
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            logits = model(batch_x)
            batch_loss = loss_fn(logits, batch_y)

            test_loss_sum += batch_loss.item() * batch_x.size(0)
            test_hits += (logits.argmax(1) == batch_y).sum().item()
            test_seen += batch_y.size(0)

    epoch_test_loss = test_loss_sum / test_seen
    epoch_test_accuracy = 100 * test_hits / test_seen

    # ---- per-epoch report ----
    print(f"Epoch {epoch+1}:")
    print(f"  Training Loss: {epoch_train_loss:.4f}, Training Accuracy: {epoch_train_accuracy:.2f}%")
    print(f"  Test Loss: {epoch_test_loss:.4f}, Test Accuracy: {epoch_test_accuracy:.2f}%")

Epochs:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 1 Training:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 1 Testing:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 1:
  Training Loss: 0.2163, Training Accuracy: 93.39%
  Test Loss: 0.0630, Test Accuracy: 98.01%


Epoch 2 Training:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 2 Testing:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 2:
  Training Loss: 0.0531, Training Accuracy: 98.36%
  Test Loss: 0.0489, Test Accuracy: 98.30%


Epoch 3 Training:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 3 Testing:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 3:
  Training Loss: 0.0363, Training Accuracy: 98.85%
  Test Loss: 0.0397, Test Accuracy: 98.64%


In [7]:
# Persist only the learned parameters (state dict), not the whole module
state = model.state_dict()
torch.save(state, "mnist_cnn.pth")
print("Saved PyTorch weights → mnist_cnn.pth")

Saved PyTorch weights → mnist_cnn.pth


In [8]:
# Export TorchScript
# Tracing is done on the CPU. nn.Module.cpu() moves the parameters of the
# shared `model` in place, so `model` stays on the CPU after this cell.
model.cpu()
model.eval()  # trace the inference-mode behaviour
# Single CPU example input that drives the trace. Previously an unused tensor
# was also allocated on `device` and a second one was passed to the tracer.
example = torch.randn(1, 1, 28, 28)
with torch.no_grad():
    traced = torch.jit.trace(model, example)
traced.save("../cpp/mnist_cnn.pt")
print("Saved TorchScript → ../cpp/mnist_cnn.pt")

Saved TorchScript → ../cpp/mnist_cnn.pt


In [9]:
# Export ONNX from a CPU model, using a single 1x1x28x28 dummy input
# (model.cpu() moves the shared `model` to the CPU in place)
cpu_model = model.cpu()
dummy = torch.randn(1, 1, 28, 28)
torch.onnx.export(
    cpu_model,
    dummy,
    "../web/mnist_cnn.onnx",
    input_names=["input"],
    output_names=["output"],
    opset_version=11,
)
print("Saved ONNX → ../web/mnist_cnn.onnx")

Saved ONNX → ../web/mnist_cnn.onnx


In [10]:
# Inference benchmark on a random tensor
# Put the model in inference mode and build the input on whatever device the
# model currently lives on, so this cell does not depend on an earlier cell
# having moved the model to the CPU.
model.eval()
bench_device = next(model.parameters()).device
x = torch.randn(1, 1, 28, 28, device=bench_device)

iters = 100
# Gradient tracking is disabled so the timing measures the forward pass only.
with torch.no_grad():
    # warmup
    for _ in range(10):
        model(x)
    # CUDA kernels run asynchronously; wait for pending work before and after
    # the timed region so the wall-clock time covers the actual computation.
    if bench_device.type == "cuda":
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        model(x)
    if bench_device.type == "cuda":
        torch.cuda.synchronize()
    t1 = time.perf_counter()
print(f"Python inference avg: {(t1-t0)*1000/iters:.2f} ms")


Python inference avg: 0.61 ms


In [11]:
# Extract one digit image for benchmarking
# NOTE: this rebinds `testset` to a copy of the MNIST test split that is only
# converted to a tensor (no normalisation).
to_tensor = transforms.ToTensor()
testset = torchvision.datasets.MNIST(
    root="./data", train=False, download=True, transform=to_tensor
)

# First test sample: image tensor and its integer label
digit_tensor, label = testset[0]

# Convert back to a PIL image and write it to disk
to_pil = transforms.ToPILImage()
img = to_pil(digit_tensor)
img.save("digit.png")
print("Saved digit.png (label =", label, ")")

Saved digit.png (label = 7 )
