In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from Utils import SimpleCNN
from Utils import get_device, train, test, get_predictions

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Preprocessing: ToTensor converts each image to a float tensor, then Normalize
# standardizes the single grayscale channel with the commonly quoted MNIST
# mean/std, so inputs are roughly zero-centered with unit variance.
MNIST_MEAN, MNIST_STD = (0.1307,), (0.3081,)
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(MNIST_MEAN, MNIST_STD),
    ]
)

# Both splits share the same cache directory, download flag and transform;
# only the `train` flag differs. Files are fetched into ./data if missing.
mnist_kwargs = dict(root='./data', download=True, transform=transform)
train_dataset = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
test_dataset = torchvision.datasets.MNIST(train=False, **mnist_kwargs)


In [3]:

# Reproducibility: seed torch's global RNG so model parameter initialization is
# repeatable (assuming SimpleCNN draws from torch's default RNG), and give the
# training loader its own seeded generator so the shuffle order is repeatable
# independently of how many random numbers the model init consumed.
# NOTE(review): some GPU/MPS kernels can still be nondeterministic.
SEED = 42
torch.manual_seed(SEED)

# Set up data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
# Fixed order for evaluation so predictions line up with labels across runs.
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

# Initialize the model, loss function, and optimizer
device = get_device()
print(f"Using device: {device}")
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())  # Adam with its default hyperparameters

Using device: mps


In [7]:

# Train for a fixed number of epochs, evaluating on the test split after each
# one. The model/optimizer come from the setup cell and are NOT re-created
# here, so re-running this cell continues training the same weights.
# NOTE(review): the logged test "Average loss" prints as 0.0000 while the
# train loss is ~1e-2; Utils.test likely divides summed per-batch *mean* losses
# by the dataset size. Verify in Utils before trusting that number.
num_epochs = 10
epochs = range(1, num_epochs + 1)  # 1-based so logs read "Train Epoch: 1" .. 10
for epoch in epochs:
    train(model, device, train_loader, optimizer, epoch, criterion)
    test(model, device, test_loader, criterion)
print("Training complete.")

Train Epoch: 1 	Loss: 0.003516

Test set: Average loss: 0.0001, Accuracy: 9793/10000 (97.93%)

Train Epoch: 2 	Loss: 0.001264

Test set: Average loss: 0.0001, Accuracy: 9839/10000 (98.39%)

Train Epoch: 3 	Loss: 0.013092

Test set: Average loss: 0.0000, Accuracy: 9870/10000 (98.70%)

Train Epoch: 4 	Loss: 0.002386

Test set: Average loss: 0.0000, Accuracy: 9876/10000 (98.76%)

Train Epoch: 5 	Loss: 0.000574

Test set: Average loss: 0.0000, Accuracy: 9875/10000 (98.75%)

Train Epoch: 6 	Loss: 0.014095

Test set: Average loss: 0.0000, Accuracy: 9877/10000 (98.77%)

Train Epoch: 7 	Loss: 0.009588

Test set: Average loss: 0.0000, Accuracy: 9881/10000 (98.81%)

Train Epoch: 8 	Loss: 0.070889

Test set: Average loss: 0.0000, Accuracy: 9883/10000 (98.83%)

Train Epoch: 9 	Loss: 0.034634

Test set: Average loss: 0.0000, Accuracy: 9902/10000 (99.02%)

Train Epoch: 10 	Loss: 0.000974

Test set: Average loss: 0.0000, Accuracy: 9881/10000 (98.81%)

Training complete.


In [8]:
# Persist only the learned parameters (the state_dict), not the module object,
# so the checkpoint can later be loaded into a freshly constructed SimpleCNN.
torch.save(obj=model.state_dict(), f='mnist_cnn.pth')
print("Model saved as 'mnist_cnn.pth'")

Model saved as 'mnist_cnn.pth'


In [9]:
# Reload the saved checkpoint and collect test-set predictions.
# - map_location=device remaps the stored tensors onto the current device, so a
#   checkpoint written on one backend (e.g. "mps") loads on another machine.
# - weights_only=True limits unpickling to tensors/primitive containers, which
#   is all a state_dict needs and avoids executing arbitrary pickled code.
state_dict = torch.load('mnist_cnn.pth', map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # inference mode (matters if SimpleCNN has dropout/batch-norm)
predictions, true_labels = get_predictions(model, test_loader, device)


In [None]:
# Confusion matrix
from sklearn.metrics import confusion_matrix
import seaborn as sns

# Rows are true digits, columns are predicted digits. Passing `labels`
# explicitly keeps the matrix 10x10, and the axes aligned with digits 0-9, even
# if some digit is absent from both `true_labels` and `predictions`.
# Assumes get_predictions returns integer class indices 0-9 (confirm in Utils).
digit_labels = list(range(10))
cm = confusion_matrix(true_labels, predictions, labels=digit_labels)

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=digit_labels,
    yticklabels=digit_labels,
    ax=ax,
)
ax.set(title="Confusion Matrix", xlabel="Predicted Label", ylabel="True Label")
plt.show()