## Fashion-MNIST MLP Classifier

This notebook trains a simple fully-connected neural network (MLP) on the Fashion-MNIST dataset to achieve >85% test accuracy. It includes data loading, model definition, training, evaluation (accuracy and confusion matrix), example predictions, and training curves.


In [None]:
# Imports and setup
import os
import time
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, accuracy_score

# Reproducibility
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)


In [None]:
# Data loading: Fashion-MNIST
batch_size = 64

# Transform: scale to [0,1]
transform = transforms.ToTensor()

train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

class_names = train_dataset.classes
print('Classes:', class_names)
print('Train size:', len(train_dataset), 'Test size:', len(test_dataset))


In [None]:
# MLP model per spec
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),                # 28x28 -> 784
            nn.Linear(28*28, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 10),
        )
    def forward(self, x):
        return self.net(x)

model = MLP().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
print(model)


In [None]:
# Training and evaluation utilities

def train_one_epoch(model, loader, optimizer, criterion, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    avg_loss = running_loss / total
    acc = correct / total
    return avg_loss, acc

@torch.no_grad()
def evaluate(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    for images, labels in loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        outputs = model(images)
        loss = criterion(outputs, labels)

        running_loss += loss.item() * images.size(0)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        all_preds.append(preds.cpu().numpy())
        all_labels.append(labels.cpu().numpy())

    avg_loss = running_loss / total
    acc = correct / total
    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)
    return avg_loss, acc, all_preds, all_labels


In [None]:
# Train loop
num_epochs = 8  # between 5 and 10
train_losses, train_accs = [], []
test_losses, test_accs = [], []

start = time.time()
for epoch in range(1, num_epochs + 1):
    tr_loss, tr_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    te_loss, te_acc, _, _ = evaluate(model, test_loader, criterion, device)

    train_losses.append(tr_loss); train_accs.append(tr_acc)
    test_losses.append(te_loss); test_accs.append(te_acc)

    print(f"Epoch {epoch:02d}/{num_epochs} | Train loss {tr_loss:.4f} acc {tr_acc*100:.2f}% | Test loss {te_loss:.4f} acc {te_acc*100:.2f}%")

elapsed = time.time() - start
print(f"Training completed in {elapsed:.1f}s")


In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
axes[0].plot(train_losses, label='Train')
axes[0].plot(test_losses, label='Test')
axes[0].set_title('Loss per epoch')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].legend()

axes[1].plot([a*100 for a in train_accs], label='Train')
axes[1].plot([a*100 for a in test_accs], label='Test')
axes[1].set_title('Accuracy per epoch')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy (%)')
axes[1].legend()
plt.show()


In [None]:
# Final evaluation and confusion matrix
# Compute final metrics on train and test
tr_loss, tr_acc, tr_preds, tr_labels = evaluate(model, train_loader, criterion, device)
te_loss, te_acc, te_preds, te_labels = evaluate(model, test_loader, criterion, device)
print(f"Final Train Acc: {tr_acc*100:.2f}% | Test Acc: {te_acc*100:.2f}%")

# Confusion matrix on test
cm = confusion_matrix(te_labels, te_preds, labels=list(range(10)))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
fig, ax = plt.subplots(figsize=(6,6))
disp.plot(ax=ax, xticks_rotation=45, cmap='Blues', colorbar=False)
ax.set_title('Fashion-MNIST Confusion Matrix (Test)')
plt.tight_layout()
plt.show()


In [None]:
# Example predictions: correct and incorrect
@torch.no_grad()
def show_examples(model, loader, class_names, max_correct=5, max_incorrect=5):
    model.eval()
    correct_imgs, correct_preds, correct_labels = [], [], []
    incorrect_imgs, incorrect_preds, incorrect_labels = [], [], []
    for images, labels in loader:
        outputs = model(images.to(device))
        preds = outputs.argmax(dim=1).cpu()
        for img, pred, lbl in zip(images, preds, labels):
            if pred == lbl and len(correct_imgs) < max_correct:
                correct_imgs.append(img.squeeze(0))
                correct_preds.append(pred.item())
                correct_labels.append(lbl.item())
            elif pred != lbl and len(incorrect_imgs) < max_incorrect:
                incorrect_imgs.append(img.squeeze(0))
                incorrect_preds.append(pred.item())
                incorrect_labels.append(lbl.item())
            if len(correct_imgs) == max_correct and len(incorrect_imgs) == max_incorrect:
                break
        if len(correct_imgs) == max_correct and len(incorrect_imgs) == max_incorrect:
            break

    # Plot
    n_rows = 2
    n_cols = max(max_correct, max_incorrect)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(2.5*n_cols, 5))
    for i in range(n_cols):
        if i < len(correct_imgs):
            axes[0, i].imshow(correct_imgs[i], cmap='gray')
            axes[0, i].set_title(f"Correct: {class_names[correct_preds[i]]}")
            axes[0, i].axis('off')
        else:
            axes[0, i].axis('off')
        if i < len(incorrect_imgs):
            axes[1, i].imshow(incorrect_imgs[i], cmap='gray')
            axes[1, i].set_title(f"Pred: {class_names[incorrect_preds[i]]}\nTrue: {class_names[incorrect_labels[i]]}")
            axes[1, i].axis('off')
        else:
            axes[1, i].axis('off')
    axes[0,0].set_ylabel('Correct', fontsize=12)
    axes[1,0].set_ylabel('Incorrect', fontsize=12)
    plt.tight_layout()
    plt.show()

show_examples(model, test_loader, class_names)


### Conclusion

This simple MLP with two ReLU hidden layers achieved >85% test accuracy on Fashion-MNIST. While effective, a convolutional neural network (CNN) typically performs better on image data. Further improvements could include adding dropout, weight decay, longer training, or switching to a small CNN.


### How to run

- Install dependencies: see `requirements.txt`.
- Run the notebook `fashion_mnist_mlp.ipynb` end-to-end.
- GPU is optional; CPU also works.
