# MNIST

In [None]:
# Install or upgrade the notebook's dependencies (IPython %pip magic).
# NOTE(review): if this upgrades a package that is already imported in the
# running kernel (e.g. torch), a kernel restart is likely needed before the
# new version is used.
%pip install torch numpy matplotlib tqdm torchvision ipywidgets --upgrade

In [None]:
import torch
import matplotlib.pyplot as plt
from torchvision import transforms, datasets

# Pick the compute device: Apple MPS if available, otherwise CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"{device}")

# Random augmentation, applied to the training subset only.
# NOTE(review): RandomResizedCrop's default scale=(0.08, 1.0) can keep as little
# as 8% of a 28x28 digit; a tighter range such as scale=(0.8, 1.0) may be intended.
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(28),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
])
# Deterministic transform for validation and test images.
transform_eval = transforms.ToTensor()

train_data = datasets.MNIST(root='.', train=True, download=True, transform=transform_train)
# A second view of the same training split WITHOUT augmentation. Validation must
# not reuse `train_data`, otherwise every validation image is randomly cropped and
# rotated each epoch and validation accuracy becomes noisy and biased.
train_data_eval = datasets.MNIST(root='.', train=True, download=True, transform=transform_eval)
test_data = datasets.MNIST(root='.', train=False, download=True, transform=transform_eval)

# Indices 0-699 of the training split are used for training, 700-999 for validation.
train_dataset = torch.utils.data.Subset(train_data, range(700))
validation_dataset = torch.utils.data.Subset(train_data_eval, range(700, 1000))

batch_size = 64
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)

# MLP Design

In [None]:
import torch.nn as nn

class MLP(nn.Module):
    """Fully connected classifier for 28x28 single-channel images.

    Four hidden blocks of ``Linear -> ReLU -> BatchNorm1d`` (width 100) feed a
    final ``Linear`` layer that produces 10 unnormalised logits.
    """

    def __init__(self):
        super().__init__()
        n_in, width, n_classes = 28 * 28, 100, 10
        # Registration order fixes both parameter order (state_dict / optimizer)
        # and the order in which layers draw from the RNG during initialisation.
        self.affine1 = nn.Linear(n_in, width)
        self.batchnorm1 = nn.BatchNorm1d(width)
        self.affine2 = nn.Linear(width, width)
        self.batchnorm2 = nn.BatchNorm1d(width)
        self.affine3 = nn.Linear(width, width)
        self.batchnorm3 = nn.BatchNorm1d(width)
        self.affine4 = nn.Linear(width, width)
        self.batchnorm4 = nn.BatchNorm1d(width)
        self.affine5 = nn.Linear(width, n_classes)
        self.act = nn.ReLU()

    def forward(self, x: torch.Tensor):
        """Flatten ``x`` to rows of 784 features and return per-row logits."""
        h = x.view(-1, 28 * 28)
        hidden_blocks = (
            (self.affine1, self.batchnorm1),
            (self.affine2, self.batchnorm2),
            (self.affine3, self.batchnorm3),
            (self.affine4, self.batchnorm4),
        )
        for linear, norm in hidden_blocks:
            # Normalisation is applied after the activation.
            h = norm(self.act(linear(h)))
        return self.affine5(h)

# Build the network on the selected device and report its structure and size.
model = MLP().to(device)
print(model)

num_params = sum(param.numel() for param in model.parameters())
print(num_params)

# Adam with a small weight-decay term; cross-entropy on raw logits.
lr = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-6)
criterion = nn.CrossEntropyLoss()

In [None]:
from tqdm.notebook import tqdm

num_epochs = 50
num_iterations = 0  # total optimizer steps (named so the builtin `iter` is not shadowed)
best_validation_acc = 0.0
best_state = None  # copy of the weights that achieved best_validation_acc

for epoch in tqdm(range(num_epochs), desc="Epoch", position=1):
    # --- Training pass ---
    model.train()
    for image, label in tqdm(train_loader, desc="Batch", position=0, leave=True):
        image, label = image.to(device), label.to(device)  # Send data to the target device

        optimizer.zero_grad()  # Clear previous gradients
        output = model(image)  # Forward pass
        loss = criterion(output, label)  # Calculate the loss

        loss.backward()  # Backward pass
        optimizer.step()  # Update model weights

        num_iterations += 1

    # --- Validation pass (eval mode: BatchNorm uses its running statistics) ---
    num_correct = 0
    model.eval()

    with torch.no_grad():
        for image, label in tqdm(validation_loader, desc="Val", position=0, leave=True):
            image, label = image.to(device), label.to(device)
            output = model(image)
            pred = output.argmax(dim=1)
            num_correct += (pred == label).sum().item()  # plain int, not a device tensor
    validation_accuracy = num_correct / len(validation_dataset) * 100

    if validation_accuracy >= best_validation_acc:
        best_validation_acc = validation_accuracy
        # Snapshot the weights so the best model, not the last one, is kept.
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
    elif validation_accuracy < best_validation_acc - 10:
        # Stop once accuracy has dropped more than 10 points below the best seen.
        print(f"Early stopped at epoch {epoch + 1}")
        break

# Restore the best-validation weights before any later evaluation.
if best_state is not None:
    model.load_state_dict(best_state)

print(f"Best validation accuracy: {best_validation_acc:.2f}%")
print(f"Total iterations: {num_iterations}")

In [None]:
# Evaluate on the held-out test split. Set eval mode explicitly so BatchNorm
# uses running statistics regardless of what earlier cells left the model in.
model.eval()
num_correct = 0

with torch.no_grad():
    for image, label in tqdm(test_loader, desc="Test"):
        image, label = image.to(device), label.to(device)
        output = model(image)
        pred = output.argmax(dim=1)
        num_correct += (pred == label).sum().item()  # accumulate as a Python int

print(f"Accuracy : {num_correct / len(test_data) * 100:.2f}%")