# MNIST Model Testing and Visualization

In [None]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

## Model Definition

In [None]:
class SimpleNN(nn.Module):
    """Three-layer fully connected network producing 10 logits per sample.

    The input is flattened to 28*28 features, passed through two
    512-unit hidden layers with ReLU activations, and projected to 10
    outputs. No softmax is applied; ``forward`` returns raw logits.
    """

    def __init__(self):
        super().__init__()
        input_size = 28 * 28
        hidden_size = 512
        num_outputs = 10

        # The attribute names and the position of each layer inside the
        # Sequential container determine the state_dict keys, so they must
        # stay stable for saved checkpoints to keep loading.
        self.flatten = nn.Flatten()
        stack = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_outputs),
        ]
        self.linear_relu_stack = nn.Sequential(*stack)

    def forward(self, x):
        """Return the logits for the batch ``x`` after flattening it."""
        return self.linear_relu_stack(self.flatten(x))

## Setup and Data Loading

In [None]:
# Set up device: use the MPS backend when it is available, otherwise the CPU.
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Load the model
model = SimpleNN().to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved on a different device (e.g. CUDA) still loads on this machine.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (or pass weights_only=True if the installed PyTorch supports it).
model.load_state_dict(torch.load('mnist_model.pth', map_location=device))
# Inference mode for the layers; the rest of the notebook only evaluates.
model.eval()

# Define the transformation: convert to a tensor, then normalize the single
# channel with the given mean and std.
# NOTE(review): presumably these must match the values used at training time - confirm.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Load the test dataset (downloaded into ./data if it is not already there).
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# One shuffled sample per batch.
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)

## Testing and Visualization Functions

In [None]:
def test_and_visualize(num_images=5):
    """Plot a few test images with their predictions and print the accuracy.

    Takes up to ``num_images`` batches from the module-level ``test_loader``,
    classifies the first sample of each with ``model`` on ``device``, and
    shows the images side by side, titled with the predicted and true labels.

    Args:
        num_images: Number of images to display; must be at least 1.

    Raises:
        ValueError: If ``num_images`` is less than 1.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be a positive integer, got {num_images}")

    # squeeze=False always returns a 2-D array of axes; without it a single
    # subplot comes back as a bare Axes object and indexing it fails.
    _, axes_grid = plt.subplots(1, num_images, figsize=(15, 3), squeeze=False)
    axes = axes_grid[0]
    correct = 0
    shown = 0

    model.eval()
    with torch.no_grad():
        # Iterate the loader once instead of building a fresh iterator per
        # image: cheaper, never repeats a sample within one call, and still
        # advances through the data when the loader is not shuffled.
        # zip stops as soon as the axes run out, or the loader does.
        for ax, (data, target) in zip(axes, test_loader):
            # Keep only the first sample so this works for any batch size.
            data, target = data[:1].to(device), target[:1].to(device)

            output = model(data)
            predicted = output.argmax(dim=1)

            image = data.cpu().squeeze().numpy()
            ax.imshow(image, cmap='gray')
            ax.set_title(f'Pred: {predicted.item()}\nTrue: {target.item()}')

            correct += int(predicted.item() == target.item())
            shown += 1

    # Hide the frames on every subplot, including any left empty because the
    # loader held fewer samples than requested.
    for ax in axes:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

    if shown == 0:
        print("No test images available.")
        return

    print(f"Accuracy on {shown} images: {100 * correct / shown:.2f}%")

def test_on_entire_set():
    """Evaluate ``model`` on every batch of ``test_loader`` and print accuracy.

    Uses the module-level ``model``, ``test_loader`` and ``device``. The
    model is switched to eval mode and gradients are disabled for the pass.
    """
    model.eval()
    num_seen = 0
    num_right = 0

    with torch.no_grad():
        for batch, labels in test_loader:
            batch = batch.to(device)
            labels = labels.to(device)

            # The index of the largest logit in each row is the predicted class.
            guesses = model(batch).max(dim=1).indices

            num_right += (guesses == labels).sum().item()
            num_seen += labels.size(0)

    accuracy = 100 * num_right / num_seen
    print(f'Accuracy on the entire test set: {accuracy:.2f}%')

## Run Tests

In [None]:
# Classify 5 images drawn from the shuffled test loader, plot each one with
# its predicted and true label, and print the accuracy on those 5 images.
test_and_visualize(5)

In [None]:
# Run the model over every batch in test_loader and print the overall accuracy.
test_on_entire_set()