# PyTorch Demo Walkthrough

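This notebook demonstrates loading a saved model (or building one), running inference on a few images from the MNIST test split, and visualizing each image alongside its ground-truth and predicted labels. Note that a freshly built model which has not had trained weights loaded will produce essentially arbitrary predictions.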

In [ ]:
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from models.simple_cnn import SimpleCNN

# --- Configuration ----------------------------------------------------------
# Number of test images to classify and plot.
NUM_SAMPLES = 5
# Optional path to saved weights. When left as None the model is used exactly
# as constructed, so no weights are loaded.
# NOTE(review): assumes the file holds a state_dict compatible with SimpleCNN
# -- confirm against however the training code saves its checkpoints.
CHECKPOINT_PATH = None

# Convert each image to a tensor, then standardize it with the given mean/std
# (0.1307 / 0.3081 are the commonly quoted MNIST statistics).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
test = datasets.MNIST(root='data', train=False, download=True, transform=transform)

model = SimpleCNN(in_channels=1, num_classes=10)
if CHECKPOINT_PATH is not None:
    # The model is never moved off the CPU in this cell, so map the saved
    # tensors to CPU regardless of the device they were saved from.
    model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location='cpu'))
# Put the model in evaluation mode before running inference.
model.eval()

# Fetch the first NUM_SAMPLES (image, label) pairs once, and classify them in
# a single batched forward pass instead of one model call per image.
samples = [test[i] for i in range(NUM_SAMPLES)]
batch = torch.stack([img for img, _ in samples])
with torch.no_grad():  # inference only; no gradients needed
    preds = model(batch).argmax(dim=1).tolist()

# squeeze=False keeps `axes` two-dimensional even when NUM_SAMPLES == 1, so
# the row can always be iterated.
fig, axes = plt.subplots(1, NUM_SAMPLES, figsize=(12, 3), squeeze=False)
for ax, (img, label), pred in zip(axes[0], samples, preds):
    # Drop the channel dimension so imshow receives a 2-D array.
    ax.imshow(img.squeeze(), cmap='gray')
    ax.set_title(f'gt:{label} pred:{pred}')
    ax.axis('off')
plt.show()
