# Week 13 Live Demo — Mini‑MNIST with a Small MLP (scikit‑learn)
**Goal:** train a small neural network (a multilayer perceptron, or MLP) to classify handwritten digits using scikit‑learn, with clear explanations and visuals.

**Dataset:** scikit‑learn's built‑in 8×8 **digits** dataset: 1,797 low-resolution digit images, similar in spirit to MNIST, that load fast and train quickly.


## 1) Load and inspect the dataset

In [None]:
# We use scikit-learn's built-in 'digits' dataset.
# - images: 8x8 grayscale pictures of digits 0..9
# - data (X): flattened 64-length vectors, one per image
# - target (y): the ground-truth digit label 0..9

from sklearn.datasets import load_digits
digits = load_digits()

X_full = digits.data            # shape = (n_samples, 64)
images_full = digits.images     # shape = (n_samples, 8, 8)
y_full = digits.target          # shape = (n_samples,)

print('Total samples:', X_full.shape[0])
print('Input dimension per sample:', X_full.shape[1])
print('Image shape:', images_full.shape[1:], '| Unique labels:', sorted(set(y_full.tolist())))  # .tolist() prints plain ints rather than NumPy scalars
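
As a quick sanity check on the comments above, the following sketch confirms that `data` is just `images` flattened row by row, and reports the pixel range and the number of samples per digit. The variable names (`flattened`, `same_pixels`, `class_counts`) are illustrative and not part of the demo.

```python
import numpy as np
from sklearn.datasets import load_digits

digits = load_digits()

# Each row of `data` should be the corresponding 8x8 image flattened row by row.
flattened = digits.images.reshape(len(digits.images), -1)
same_pixels = np.array_equal(flattened, digits.data)
print('data == flattened images:', same_pixels)

# Pixel intensities are small integers stored as floats (0 = background).
pixel_min, pixel_max = digits.data.min(), digits.data.max()
print('Pixel range:', pixel_min, 'to', pixel_max)

# Samples per class; if the counts are close, plain accuracy is a reasonable metric.
class_counts = np.bincount(digits.target)
print('Samples per digit:', class_counts.tolist())
```

If the counts were strongly unbalanced, per-class metrics would be more informative than a single accuracy number.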


## 2) Visualize the dataset (quick overview)

In [None]:
# Seeing raw data helps build intuition.
# We draw a 10x10 grid of random samples and put the label as the small title above each image.

import numpy as np, matplotlib.pyplot as plt
rng = np.random.default_rng(7)

sel = rng.choice(len(X_full), size=100, replace=False)  # pick 100 random indices
fig, axes = plt.subplots(10, 10, figsize=(8,8))
for ax, idx in zip(axes.flatten(), sel):
    ax.imshow(images_full[idx], cmap='gray')
    ax.set_title(int(y_full[idx]), fontsize=8)
    ax.axis('off')
# plt.suptitle('Random samples from the digits dataset (label shown above each image)', y=0.92)
plt.tight_layout()
plt.show()


## 3) Train/test split and safe feature scaling

In [None]:
# WHY SPLIT?  We need an *honest* test set that the model never sees during training.
# WHY STRATIFY?  Keeps class proportions similar in train and test.
# WHY SCALE?  MLPs train better when each feature has similar scale (mean 0, std 1).
#   Important: fit the scaler on *train only*, then apply to both train and test (avoid data leakage).

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

X_train, X_test, y_train, y_test, images_train, images_test = train_test_split(
    X_full, y_full, images_full, test_size=0.25, stratify=y_full, random_state=42
)

scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)   # Learn mean/std from *train only*; every non-constant pixel ends up with mean 0, std 1 over the training set
X_test_scaled  = scaler.transform(X_test)        # Apply the *same* transform to test

print('Train set:', X_train_scaled.shape, ' Test set:', X_test_scaled.shape)
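
To make the "fit on train only" rule concrete, here is a minimal sketch of what `StandardScaler` computes, reproduced by hand with NumPy. It assumes the same split settings as above; the names `mu`, `sigma`, and `X_te_manual` are illustrative.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

X, y = load_digits(return_X_y=True)
X_tr, X_te, y_tr, y_te = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=42
)

scaler = StandardScaler().fit(X_tr)

# Manual version: the statistics come from the training rows only.
mu = X_tr.mean(axis=0)
sigma = X_tr.std(axis=0)      # population std (ddof=0), which is what StandardScaler uses
sigma[sigma == 0] = 1.0       # constant pixels stay at 0 instead of dividing by zero

# The test set is shifted and scaled with the *training* statistics.
X_te_manual = (X_te - mu) / sigma
matches = np.allclose(X_te_manual, scaler.transform(X_te))
print('Manual scaling matches StandardScaler:', matches)
```

Because the test rows never contribute to `mu` or `sigma`, the scaled test features will not have exactly mean 0 and std 1, which is expected.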


## 4) Build a small MLP and train it over several epochs

In [None]:
# We will train a *small* MLP classifier with one hidden layer of 32 neurons.
# To collect a *testing curve* (accuracy on test vs epochs), we use incremental training via partial_fit.
# Steps per epoch:
#   1) Shuffle training data
#   2) Loop over mini-batches
#   3) Call partial_fit on each batch (first batch must pass 'classes=list_of_labels')
# After each epoch, we record:
#   - current training loss (mlp.loss_)
#   - train accuracy (on X_train_scaled)
#   - test accuracy  (on X_test_scaled)

import numpy as np
from sklearn.neural_network import MLPClassifier

classes = np.unique(y_train)  # array([0,1,2,3,4,5,6,7,8,9])

mlp = MLPClassifier(hidden_layer_sizes=(32,),    # small network, fast to train
                    activation='relu',           # try 'tanh' if you want smoother weight patterns
                    solver='sgd',                # SGD supports partial_fit for incremental learning
                    learning_rate_init=0.01,
                    alpha=1e-4,                  # L2 regularization (helps generalization)
                    max_iter=1,                  # has no effect here: partial_fit always runs a single pass over the batch it is given
                    warm_start=False,            # not needed when using partial_fit
                    random_state=0)

# Hyperparameters for our simple training loop
epochs = 25
batch_size = 64

train_loss = []
train_acc  = []
test_acc   = []

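n = X_train_scaled.shape[0]
indices = np.arange(n)
shuffle_rng = np.random.default_rng(0)   # seeded so the batch order, and hence the curves, are reproducible

for epoch in range(1, epochs+1):
    # Shuffle each epoch
    shuffle_rng.shuffle(indices)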
    X_shuf = X_train_scaled[indices]
    y_shuf = y_train[indices]

    # Mini-batch loop
    for start in range(0, n, batch_size):
        end = start + batch_size
        X_batch = X_shuf[start:end]
        y_batch = y_shuf[start:end]
        if start == 0 and epoch == 1:
            # First call must include 'classes'
            mlp.partial_fit(X_batch, y_batch, classes=classes)
        else:
            mlp.partial_fit(X_batch, y_batch)

    # Record metrics after this epoch
    # Note: 'loss_' is the training loss on the *last batch* seen (which may be smaller than batch_size), so it is noisy; look for an overall downward trend across epochs.
    train_loss.append(mlp.loss_)
    train_acc.append(mlp.score(X_train_scaled, y_train))
    test_acc.append(mlp.score(X_test_scaled,  y_test))

    print(f"Epoch {epoch:02d}: loss={train_loss[-1]:.4f}  train_acc={train_acc[-1]:.3f}  test_acc={test_acc[-1]:.3f}")
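
To show what a one-hidden-layer MLP like the one trained above computes at prediction time, the sketch below rebuilds the forward pass by hand from the learned weights and compares it with `predict_proba`. It trains a separate throwaway model (`demo_mlp`, a hypothetical name) for a few iterations so the block runs on its own; it is an illustration, not a replacement for the training loop.

```python
import warnings
import numpy as np
from sklearn.datasets import load_digits
from sklearn.exceptions import ConvergenceWarning
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import StandardScaler

X, y = load_digits(return_X_y=True)
X = StandardScaler().fit_transform(X)   # scaling all rows is acceptable here: nothing is evaluated

demo_mlp = MLPClassifier(hidden_layer_sizes=(32,), activation='relu',
                         solver='sgd', learning_rate_init=0.01,
                         max_iter=5, random_state=0)
with warnings.catch_warnings():
    warnings.simplefilter('ignore', ConvergenceWarning)  # 5 iterations will not converge
    demo_mlp.fit(X, y)

W1, W2 = demo_mlp.coefs_          # shapes (64, 32) and (32, 10)
b1, b2 = demo_mlp.intercepts_     # shapes (32,) and (10,)

hidden = np.maximum(0.0, X @ W1 + b1)          # ReLU hidden layer
logits = hidden @ W2 + b2                      # one score per digit class
logits -= logits.max(axis=1, keepdims=True)    # stabilise the exponentials
probs = np.exp(logits)
probs /= probs.sum(axis=1, keepdims=True)      # softmax over the 10 classes

matches = np.allclose(probs, demo_mlp.predict_proba(X))
print('Manual forward pass matches predict_proba:', matches)
```

The predicted digit is simply the column with the largest probability in each row.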


## 5) Training and testing curves

In [None]:
# We plot training loss (should decrease) and train/test accuracy (should rise, then level off).
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1,2, figsize=(10,3.6))

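epoch_axis = range(1, len(train_loss) + 1)   # match the 1-based epoch numbers printed during training

ax[0].plot(epoch_axis, train_loss, marker='o', ms=3)
ax[0].set_xlabel('epoch'); ax[0].set_ylabel('training loss'); ax[0].set_title('Loss vs epoch'); ax[0].grid(alpha=0.3)

ax[1].plot(epoch_axis, train_acc, label='train acc')
ax[1].plot(epoch_axis, test_acc, label='test acc')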
ax[1].set_xlabel('epoch'); ax[1].set_ylabel('accuracy'); ax[1].set_title('Accuracy vs epoch'); ax[1].grid(alpha=0.3)
ax[1].legend()

plt.tight_layout(); plt.show()
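
Because `mlp.loss_` reflects only the last mini-batch of each epoch, the loss curve can look jagged. One possible alternative, sketched here on a separate throwaway model, is to compute the cross-entropy over the whole training set after each epoch with `sklearn.metrics.log_loss`. The names `demo_mlp`, `shuffle_rng`, and `full_loss` and the choice of 5 epochs are assumptions made for illustration; unlike `loss_`, this value omits the L2 penalty term.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.metrics import log_loss
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import StandardScaler

X, y = load_digits(return_X_y=True)
X = StandardScaler().fit_transform(X)   # no held-out evaluation in this sketch
classes = np.unique(y)

demo_mlp = MLPClassifier(hidden_layer_sizes=(32,), solver='sgd',
                         learning_rate_init=0.01, random_state=0)
shuffle_rng = np.random.default_rng(0)  # hypothetical seed, chosen for repeatability
batch_size = 64
full_loss = []

for epoch in range(5):
    order = shuffle_rng.permutation(len(X))
    for start in range(0, len(X), batch_size):
        batch = order[start:start + batch_size]
        # Passing the same `classes` on every call is allowed and keeps the loop simple.
        demo_mlp.partial_fit(X[batch], y[batch], classes=classes)
    # Cross-entropy over *all* training rows, not just the final mini-batch.
    full_loss.append(log_loss(y, demo_mlp.predict_proba(X), labels=classes))

print([round(v, 4) for v in full_loss])
```

The trade-off is one extra forward pass over the training set per epoch, which is cheap for a dataset this small.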


## 6) Confusion matrix (which digits are confused)

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

y_pred = mlp.predict(X_test_scaled)
cm = confusion_matrix(y_test, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=np.arange(10))
fig, ax = plt.subplots(figsize=(5.2,4.2)); disp.plot(ax=ax, cmap='Blues', colorbar=True, values_format='d')
ax.set_title('Confusion matrix (test set)'); plt.tight_layout(); plt.show()
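
To read a confusion matrix numerically rather than by eye, per-class recall and precision can be derived from its rows and columns. The sketch below uses a tiny made-up three-class example (`y_true_demo` and `y_pred_demo` are hypothetical, not output from the model above); the same lines should apply to `cm` from the cell above.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Tiny hypothetical example with three classes.
y_true_demo = np.array([0, 0, 0, 1, 1, 2, 2, 2, 2])
y_pred_demo = np.array([0, 0, 1, 1, 1, 2, 2, 0, 2])

cm_demo = confusion_matrix(y_true_demo, y_pred_demo)   # rows = true label, columns = predicted label

recall = cm_demo.diagonal() / cm_demo.sum(axis=1)      # of the true k's, the fraction predicted as k
precision = cm_demo.diagonal() / cm_demo.sum(axis=0)   # of the predicted k's, the fraction that are truly k

# The largest off-diagonal cell marks the most frequent confusion.
off_diag = cm_demo.copy()
np.fill_diagonal(off_diag, 0)
true_k, pred_k = np.unravel_index(off_diag.argmax(), off_diag.shape)

print('recall:', recall.round(2), '| precision:', precision.round(2))
print(f'Most common confusion: true {true_k} predicted as {pred_k}')
```

If a class is never predicted, its column sum is zero and the precision division yields `nan` with a warning, so real code may want to guard against that case.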


## 7) Look at misclassified test images (learn from mistakes)

In [None]:
# Seeing mistakes helps diagnose what patterns the model struggles with.
import numpy as np, matplotlib.pyplot as plt

wrong_idx = np.where(y_pred != y_test)[0]
print('Misclassified samples:', len(wrong_idx))

# Show up to 25 misclassified images with true/pred labels
show = wrong_idx[:25]
if len(show) > 0:
    fig, axes = plt.subplots(5, 5, figsize=(6,6))
    for ax, i in zip(axes.flatten(), show):
        ax.imshow(images_test[i], cmap='gray')
        ax.set_title(f'true={y_test[i]}  pred={y_pred[i]}', fontsize=8)
        ax.axis('off')
    plt.suptitle('Misclassified test digits')
    plt.tight_layout()
    plt.show()
else:
    print('No misclassifications found in this split; try changing random_state for a different split.')
