## Clean Neural Network

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Every image is converted to a tensor as it is read from the dataset.
transform = transforms.ToTensor()

# The two MNIST splits share the same folder, download flag and transform;
# only the `train` flag differs between them.
mnist_kwargs = dict(root='./data', download=True, transform=transform)
train_data = datasets.MNIST(train=True, **mnist_kwargs)
test_data = datasets.MNIST(train=False, **mnist_kwargs)

# Same batch size for both loaders; only the training split is shuffled.
loader_kwargs = dict(batch_size=64)
train_loader = DataLoader(train_data, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_data, shuffle=False, **loader_kwargs)

class MNISTClassifier(nn.Module):
    """Fully connected classifier: 784 inputs -> 128 -> 64 -> 10 output scores.

    The input is flattened before the first linear layer, and a ReLU sits
    between consecutive linear layers. The final layer's output is returned
    as-is (no softmax is applied inside the model).
    """

    def __init__(self) -> None:
        super().__init__()
        # Widths of the activations flowing through the linear layers.
        widths = (784, 128, 64, 10)

        # Build Flatten, Linear, ReLU, Linear, ReLU, Linear in that order so
        # the module indices inside `net` stay 0..5.
        layers = [nn.Flatten()]
        for position, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            if position > 0:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(fan_in, fan_out))

        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run `x` through the layer stack and return the 10 scores per row."""
        scores = self.net(x)
        return scores

# Seed torch's global random generator before any weights are created, so the
# initial parameters (and anything else that later draws from the default
# generator) are the same on every Restart & Run All.
# NOTE: this makes runs repeatable on one machine/setup; it does not promise
# bit-identical numbers across different hardware or PyTorch versions.
SEED = 42
torch.manual_seed(SEED)

model = MNISTClassifier()
# The loss is given the model's raw output scores together with integer labels.
loss_fn = nn.CrossEntropyLoss()
# Adam over all model parameters with a fixed learning rate of 1e-3.
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train for a fixed number of passes over `train_loader`, printing the average
# per-sample training loss after each pass.
epochs = 5
for epoch in range(epochs):
    model.train()
    total_loss = 0.0   # sum of per-sample losses seen in this epoch
    samples_seen = 0   # number of samples those losses came from
    for images, labels in train_loader:
        # Clear gradients left over from the previous step before computing
        # new ones.
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        # `loss` is the mean over the batch; weight it by the batch size so a
        # smaller final batch does not count as much as a full one.
        batch_count = labels.size(0)
        total_loss += loss.item() * batch_count
        samples_seen += batch_count
    avg_loss = total_loss / samples_seen
    print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

Epoch 1/5, Loss: 0.3421
Epoch 2/5, Loss: 0.1494
Epoch 3/5, Loss: 0.1004
Epoch 4/5, Loss: 0.0755
Epoch 5/5, Loss: 0.0584


## Model Evaluation

In [2]:
import torch.nn.functional as F

# Evaluate the trained model on `test_loader`: print a handful of individual
# predictions with their softmax probability, then the overall accuracy.

# How many individual predictions to print before the aggregate result.
NUM_SAMPLES_TO_SHOW = 10

model.eval()
correct = 0
total = 0

samples_shown = 0

# Gradients are not needed for evaluation.
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        probs = F.softmax(outputs, dim=1)
        # Index of the highest score per row. `argmax` avoids assigning to
        # `_`, which IPython uses for the last cell output.
        preds = outputs.argmax(dim=1)

        # Print only as many rows as are still wanted AND actually present in
        # this batch, so a short batch cannot cause an out-of-range index and
        # later batches skip the loop entirely.
        rows_to_print = min(NUM_SAMPLES_TO_SHOW - samples_shown, labels.size(0))
        for i in range(rows_to_print):
            label = labels[i].item()
            pred = preds[i].item()
            prob = probs[i][pred].item()
            print(f"Sample {samples_shown+1} | True Label: {label} | Predicted: {pred} | Probability: {prob:.4f}")
            samples_shown += 1

        correct += (preds == labels).sum().item()
        total += labels.size(0)

# An empty loader yields NaN instead of raising ZeroDivisionError.
accuracy = correct / total if total else float("nan")
print(f"\nFinal Test Accuracy: {accuracy * 100:.2f}%")

Sample 1 | True Label: 7 | Predicted: 7 | Probability: 0.9994
Sample 2 | True Label: 2 | Predicted: 2 | Probability: 0.9999
Sample 3 | True Label: 1 | Predicted: 1 | Probability: 0.9871
Sample 4 | True Label: 0 | Predicted: 0 | Probability: 1.0000
Sample 5 | True Label: 4 | Predicted: 4 | Probability: 0.9983
Sample 6 | True Label: 1 | Predicted: 1 | Probability: 0.9983
Sample 7 | True Label: 4 | Predicted: 4 | Probability: 0.9993
Sample 8 | True Label: 9 | Predicted: 9 | Probability: 0.9663
Sample 9 | True Label: 5 | Predicted: 5 | Probability: 0.9568
Sample 10 | True Label: 9 | Predicted: 9 | Probability: 0.9998

Final Test Accuracy: 97.48%
