In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np

# ---- Dataset Setup ----

def get_mnist(n_per_class=5000, n_test_per_class=200, seed=42, batch_size=100, root="./data"):
    """Build class-balanced MNIST train/test DataLoaders.

    Draws ``n_per_class`` training images and ``n_test_per_class`` test images
    of every digit 0-9, without replacement, and wraps each subset in a
    ``DataLoader``.

    Args:
        n_per_class: Images sampled per digit from the MNIST training split.
        n_test_per_class: Images sampled per digit from the MNIST test split.
        seed: Seeds the global torch RNG and the global NumPy RNG; the NumPy
            RNG drives the per-digit subsampling below.
        batch_size: Batch size for both loaders (default 100, the previously
            hard-coded value).
        root: Directory MNIST is downloaded to / read from.

    Returns:
        ``(train_loader, test_loader)``. The train loader shuffles; the test
        loader iterates in subset order (grouped by digit).

    Raises:
        ValueError: If a requested per-digit count exceeds the number of
            images of that digit in the split.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    transform = transforms.ToTensor()
    train_dataset_full = datasets.MNIST(root=root, train=True, download=True, transform=transform)
    test_dataset_full = datasets.MNIST(root=root, train=False, download=True, transform=transform)

    def subsample(dataset, n):
        # Stratified draw: exactly n indices per digit, sampled without
        # replacement from the (seeded) global NumPy RNG.
        labels = np.array(dataset.targets)
        indices = []
        for digit in range(10):
            digit_indices = np.where(labels == digit)[0]
            if n > len(digit_indices):
                raise ValueError(
                    f"requested {n} samples of digit {digit}, but only "
                    f"{len(digit_indices)} are available"
                )
            chosen = np.random.choice(digit_indices, n, replace=False)
            # Store plain Python ints rather than np.int64 scalars.
            indices.extend(chosen.tolist())
        return Subset(dataset, indices)

    train_subset = subsample(train_dataset_full, n_per_class)
    test_subset = subsample(test_dataset_full, n_test_per_class)

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

# ---- Model ----

class BaselineNN(nn.Module):
    """Minimal MLP classifier.

    Pipeline: flatten each sample -> Linear(input_size, hidden_size) -> ReLU
    -> Linear(hidden_size, output_size). ``forward`` returns raw logits
    (no softmax is applied).
    """

    def __init__(self, input_size=784, hidden_size=4, output_size=10):
        super().__init__()
        # The attribute names become state_dict keys, and the creation order
        # fixes which RNG draws initialise each layer -- keep both stable.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc_out = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # Collapse every non-batch dimension into one feature axis.
        flat = x.view(x.shape[0], -1)
        return self.fc_out(F.relu(self.fc1(flat)))

# ---- Training ----

def train_baseline(model, train_loader, epochs=10, lr=1e-3):
    """Train ``model`` with Adam and cross-entropy, printing the loss per epoch.

    Args:
        model: Module whose output on a batch is passed to
            ``nn.CrossEntropyLoss`` as logits.
        train_loader: Iterable of ``(inputs, targets)`` batches.
        epochs: Number of full passes over ``train_loader``.
        lr: Adam learning rate.

    Returns:
        List of per-epoch mean training losses, one float per epoch. Each
        epoch's mean is taken over samples (each batch weighted by its size),
        so a smaller final batch is not over-weighted. Earlier versions
        returned ``None``, so callers ignoring the result are unaffected.
    """
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    history = []

    for epoch in range(epochs):
        loss_sum, n_samples = 0.0, 0
        for xb, yb in train_loader:
            optimizer.zero_grad()
            logits = model(xb)
            loss = loss_fn(logits, yb)
            loss.backward()
            optimizer.step()

            # CrossEntropyLoss (default reduction="mean") averages within the
            # batch; re-weight by batch size so the epoch figure is a true
            # per-sample mean even when the last batch is smaller.
            batch_n = yb.size(0)
            loss_sum += loss.item() * batch_n
            n_samples += batch_n

        avg_loss = loss_sum / n_samples
        history.append(avg_loss)
        print(f"Epoch {epoch + 1}, Loss: {avg_loss:.4f}")

    return history

# ---- Evaluation ----

def evaluate_baseline(model, test_loader):
    """Compute, print, and return the top-1 accuracy of ``model``.

    Switches ``model`` to eval mode (and leaves it there) and runs without
    gradient tracking.

    Args:
        model: Module producing logits with classes along dim 1.
        test_loader: Iterable of ``(inputs, targets)`` batches.

    Returns:
        Accuracy as a fraction in ``[0, 1]``. Earlier versions returned
        ``None``, so callers ignoring the result are unaffected.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for xb, yb in test_loader:
            preds = torch.argmax(model(xb), dim=1)
            correct += (preds == yb).sum().item()
            total += yb.size(0)
    acc = correct / total
    print(f"Test Accuracy: {acc * 100:.2f}%")
    return acc

# ---- Run All ----

# Build the class-balanced MNIST loaders, train the baseline MLP, then report
# test accuracy. NOTE(review): these module-level names (train_loader,
# test_loader, model) may be used by later notebook cells -- keep them stable.
train_loader, test_loader = get_mnist()  # defaults: 5000/digit train, 200/digit test, seed=42
model = BaselineNN()  # defaults: 784 -> 4 (ReLU) -> 10
train_baseline(model, train_loader)  # defaults: 10 epochs of Adam at lr=1e-3
evaluate_baseline(model, test_loader)  # prints "Test Accuracy: ..."


Epoch 1, Loss: 1.4329
Epoch 2, Loss: 0.8274
Epoch 3, Loss: 0.6697
Epoch 4, Loss: 0.5924
Epoch 5, Loss: 0.5481
Epoch 6, Loss: 0.5200
Epoch 7, Loss: 0.5016
Epoch 8, Loss: 0.4888
Epoch 9, Loss: 0.4783
Epoch 10, Loss: 0.4713
Test Accuracy: 86.95%
