# Day 5
## PyTorch
* Build an MLP in PyTorch
* Train MNIST/Fashion-MNIST (CPU)
* Add weight decay, run with SGD vs Adam, add LR scheduler. 

### Check: 
* MNIST: test accuracy > 97% 
* MLP < 10 epochs (Adam helps)

### Interview drill 
Differences between weight decay and L2 in Adam (decoupled vs classical)

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import random, os

In [3]:
# TODO: compute mean / std of the dataset and use it for normalization

In [12]:
# ---- Configuration ----
# Run-wide settings read by the cells below.
DATASET = datasets.FashionMNIST  # torchvision dataset class to train on
BATCH_SIZE = 256                 # samples per DataLoader batch
EPOCHS = 10                      # also used as T_max of the cosine schedule
LR = 2e-3                        # initial learning rate passed to the optimizer
WD = 1e-4                        # weight decay passed to the optimizer
H = 42                           # NOTE(review): not referenced in the visible cells
COSINE_LR = True # Use cosine learning rate schedule
SEED = 42
NUM_WORKERS = 2                  # DataLoader worker processes
# torch.backends.mps.is_available() is the documented availability check;
# torch.mps.is_available() is not present in every PyTorch release that
# ships MPS support, so the backends API is the portable spelling.
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"

print(f"[DEVICE] Using device: {DEVICE}")

# ---- Reproducibility ----
def set_seed(seed=42):
    """Seed the Python and PyTorch random number generators.

    Args:
        seed (int): Seed value applied to every generator.

    Note:
        The MPS backend exposes no ``deterministic`` / ``benchmark`` switches
        (those are cuDNN settings), so none are set here; assigning them on
        ``torch.backends.mps`` would only attach unused attributes.
    """
    random.seed(seed)
    # Exported for child processes; the running interpreter fixed its hash
    # seed at startup, so this does not change hashing in this process.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    # Seed the MPS generator explicitly only when an MPS device exists, so
    # the function is safe to call on machines without one.
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)
set_seed(SEED)


[DEVICE] Using device: mps


## Fashion MNIST data preprocessing

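**Fashion MNIST** stores pixel values as unsigned 8-bit integers in the range 0...255. Dividing by 255 converts them to floats in $[0,1]$. This step is necessary before computing the mean/std used for normalization, because `transforms.ToTensor()` applies the same scaling at load time, so the statistics must be computed on the same scale.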

In [13]:
import pandas as pd

# Load the raw (untransformed) training split of the configured dataset.
# Using DATASET instead of a hard-coded class keeps the statistics in sync
# with whatever dataset the configuration cell selects.
raw_train = DATASET(root="./data", train=True,  download=True)
# Scale pixels to [0, 1] so the statistics match what ToTensor() produces.
data = raw_train.data.float().div_(255)
# Single global mean/std over every pixel of the training split.
mean = data.mean().item()
std = data.std().item()

print(f"[Data] Mean: {mean}, std: {std}")

# ToTensor() converts images to float tensors in [0, 1]; Normalize then
# standardizes them with the training-split statistics computed above.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((mean,), (std,))
])

[Data] Mean: 0.28604060411453247, std: 0.3530242443084717


In [14]:
# --- Reload datasets with transforms (standard) ---
train_ds = datasets.FashionMNIST(root="./data", train=True, download=True, transform=transform)
test_ds  = datasets.FashionMNIST(root="./data", train=False, download=True, transform=transform)

print("[Data] Train size:", len(train_ds))
print("[Data] Test size:", len(test_ds))

[Data] Train size: 60000
[Data] Test size: 10000


In [15]:
# Pinned host memory only helps host-to-CUDA copies. On "mps"/"cpu" it is
# unsupported and DataLoader emits a warning, so request it only for CUDA.
PIN_MEMORY = str(DEVICE).startswith("cuda")

# Shuffle the training split every epoch; keep the test split in fixed order.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

## MLP Implementation

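*view* - returns a tensor with a new shape that shares memory with the original; it never copies, and raises an error if the requested shape is incompatible with the tensor's memory layout (e.g. some non-contiguous tensors). `reshape` behaves the same but falls back to a copy when a view is not possible. Pass the new shape; a dimension given as -1 is inferred automatically. 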

In [16]:
class MLP(nn.Module):
    def __init__(self,
                 input_dimension: int = 28*28,
                 hidden_dimensions: tuple = (256,128),
                 output_dimension: int = 10,
                 droupout_probability: float = 0.1,
                 use_batchnorm: bool = False):
        """Initialize a multi-layer perceptron with configurable hidden layers.

        Each hidden layer is Linear -> (optional BatchNorm1d) -> ReLU ->
        (optional Dropout); a final Linear layer produces the class logits.

        Args:
            input_dimension (int): Number of features per flattened sample (D).
            hidden_dimensions (tuple of int): Width of each hidden layer, in order.
            output_dimension (int): Number of classes (C).
            droupout_probability (float): Dropout probability, between 0 and 1;
                dropout layers are omitted when it is 0. (The misspelled name
                is kept because callers pass it as a keyword.)
            use_batchnorm (bool): Whether to use batch normalization after each hidden layer.

        """
        super().__init__()
        layers = []

        last_dimension = input_dimension
        for hidden_dimension in hidden_dimensions:
            layers.append(nn.Linear(last_dimension, hidden_dimension))
            if use_batchnorm:
                layers.append(nn.BatchNorm1d(hidden_dimension))
            layers.append(nn.ReLU())
            if droupout_probability > 0:
                layers.append(nn.Dropout(droupout_probability))
            last_dimension = hidden_dimension

        # Output layer: raw logits, no activation.
        layers.append(nn.Linear(last_dimension, output_dimension))
        self.net = nn.Sequential(*layers)


    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Flatten each sample and return the class logits.

        Args:
            X (torch.Tensor): Batch whose first dimension is the batch size
                and whose remaining dimensions multiply to ``input_dimension``
                (e.g. (N, 1, 28, 28) or (N, 28, 28)).

        Returns:
            torch.Tensor: Logits of shape (N, output_dimension).
        """
        N = X.shape[0]
        # reshape (not view) so non-contiguous inputs are accepted: it returns
        # a view when possible and copies only when the layout requires it.
        X = X.reshape(N, -1) # (N, ...) -> (N, D)
        return self.net(X)

# Build the network (light dropout, no batch norm) and move it to the
# configured device before printing its layer layout.
model = MLP(use_batchnorm=False, droupout_probability=0.1)
model = model.to(DEVICE)

print(model)

MLP(
  (net): Sequential(
    (0): Linear(in_features=784, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=256, out_features=128, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.1, inplace=False)
    (6): Linear(in_features=128, out_features=10, bias=True)
  )
)


In [17]:
import torch.optim as optim

# ---- Optimizer/Loss ----
# Cross-entropy on raw logits, optimized with AdamW.
loss = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)

# Optional cosine annealing over EPOCHS scheduler steps.
if COSINE_LR:
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
else:
    scheduler = None

# ---- Train/eval ----
def train_epoch(model:nn.Module,
                loader:DataLoader,
                optimizer:optim.Optimizer,
                loss:nn.Module,
                device=None) -> tuple:
    """Trains one epoch.

    Args:
        model (nn.Module): Model to train.
        loader (DataLoader): DataLoader for training data, yielding (inputs, targets).
        optimizer (torch.optim.Optimizer): Optimizer to use.
        loss (callable): Loss function taking (logits, targets).
        device (torch.device or str, optional): Device the batches are moved
            to. Defaults to the device of the model's first parameter, so the
            function does not depend on a module-level device constant.

    Returns:
        tuple: (average loss, accuracy) over all samples seen. Both are 0.0
        when the loader yields no samples.

    """
    if device is None:
        # Follow the model: batches must live where the parameters live.
        first_parameter = next(model.parameters(), None)
        device = first_parameter.device if first_parameter is not None else None

    model.train()
    total_loss, total, correct = 0.0, 0, 0
    for (x, y) in loader:
        if device is not None:
            x, y = x.to(device), y.to(device)
        batch_size = y.size(0)

        optimizer.zero_grad(set_to_none=True)
        logits = model(x)
        batch_loss = loss(logits, y)
        batch_loss.backward()
        optimizer.step()

        # The loss is a per-batch mean; weight it by the batch size so the
        # epoch average stays exact when the last batch is smaller.
        total_loss += batch_loss.item() * batch_size
        preds = logits.argmax(dim=1)
        correct += (preds == y).sum().item()
        total += batch_size

    if total == 0:
        # Empty loader: avoid dividing by zero.
        return 0.0, 0.0
    return total_loss / total, correct / total


In [18]:
@torch.no_grad()
def evaluate(model:nn.Module,
             loader:DataLoader,
             loss:nn.Module,
             device=None) -> tuple:
    """Evaluates the model on a dataset without updating it.

    Args:
        model (nn.Module): Model to evaluate; it is switched to eval mode.
        loader (DataLoader): DataLoader yielding (inputs, targets).
        loss (callable): Loss function taking (logits, targets).
        device (torch.device or str, optional): Device the batches are moved
            to. Defaults to the device of the model's first parameter, so the
            function does not depend on a module-level device constant.

    Returns:
        tuple: (average loss, accuracy) over all samples seen. Both are 0.0
        when the loader yields no samples.

    """
    if device is None:
        # Follow the model: batches must live where the parameters live.
        first_parameter = next(model.parameters(), None)
        device = first_parameter.device if first_parameter is not None else None

    model.eval()
    total_loss, total, correct = 0.0, 0, 0

    for x, y in loader:
        if device is not None:
            x, y = x.to(device), y.to(device)
        batch_size = y.size(0)

        logits = model(x)
        loss_value = loss(logits, y) # Note: loss takes logits/targets, not probabilities
        # Weight the per-batch mean by the batch size to get an exact average.
        total_loss += loss_value.item() * batch_size
        preds = logits.argmax(dim=1)
        correct += (preds == y).sum().item()
        total += batch_size

    if total == 0:
        # Empty loader: avoid dividing by zero.
        return 0.0, 0.0
    return total_loss / total, correct / total

In [19]:
# --- Training loop ---
best_accuracy = 0.0
epochs = 10

for epoch in range(1, epochs+1):
    train_loss, train_accuracy = train_epoch(model, train_loader, optimizer, loss)
    test_loss, test_accuracy = evaluate(model, test_loader, loss)

    if scheduler: scheduler.step()

    print(f"[Training] Epoch {epoch:2d}/{epochs} | "
          f"Train loss: {train_loss:.4f}, accuracy: {train_accuracy:.4f} | "
          f"Test loss: {test_loss:.4f}, accuracy: {test_accuracy:.4f} | "
          f"Learning rate: {optimizer.param_groups[0]['lr']:.6f}")

    if test_accuracy > best_accuracy:
        best_accuracy = test_accuracy
        torch.save(model.state_dict(), "fashion_mnist_mlp_best.pth")
        print(f"[Model] New best model saved with accuracy: {best_accuracy:.4f}")

print(f"[Model] Best accuracy: {best_accuracy:.4f}")



[Training] Epoch  1/10 | Train loss: 0.5146, accuracy: 0.8151 | Test loss: 0.4252, accuracy: 0.8441 | Learning rate: 0.001951
[Model] New best model saved with accuracy: 0.8441
[Training] Epoch  2/10 | Train loss: 0.3681, accuracy: 0.8651 | Test loss: 0.3719, accuracy: 0.8652 | Learning rate: 0.001809
[Model] New best model saved with accuracy: 0.8652
[Training] Epoch  3/10 | Train loss: 0.3269, accuracy: 0.8801 | Test loss: 0.3754, accuracy: 0.8661 | Learning rate: 0.001588
[Model] New best model saved with accuracy: 0.8661
[Training] Epoch  4/10 | Train loss: 0.3010, accuracy: 0.8879 | Test loss: 0.3516, accuracy: 0.8726 | Learning rate: 0.001309
[Model] New best model saved with accuracy: 0.8726
[Training] Epoch  5/10 | Train loss: 0.2803, accuracy: 0.8954 | Test loss: 0.3373, accuracy: 0.8773 | Learning rate: 0.001000
[Model] New best model saved with accuracy: 0.8773
[Training] Epoch  6/10 | Train loss: 0.2569, accuracy: 0.9047 | Test loss: 0.3233, accuracy: 0.8841 | Learning rate