
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RichardJPovinelli/Neural_Networks_Course/blob/main/Dynamic_MLP_MNIST.ipynb)

In [235]:
# Imports
import time
import random
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import pandas as pd

try:
    import ipywidgets as widgets
    from IPython.display import display

    WIDGETS_AVAILABLE = True
except ImportError:
    WIDGETS_AVAILABLE = False

# Cooperative cancellation flag; the Abort button's handler (on_abort) sets it to True.
STOP_REQUESTED = False

# Prefer the GPU whenever CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)

Device: cuda


In [236]:
# Reproducibility
def set_seed(seed=42):
    """Seed Python's and PyTorch's RNGs and pin cuDNN to deterministic kernels.

    Args:
        seed: integer seed applied to ``random``, the CPU torch generator and
            all CUDA device generators.
    """
    # Stdlib RNG first, then torch on the CPU and on every visible GPU.
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic kernels and no autotuning: repeatable, at some speed cost.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed()

In [237]:
# Data Loaders (MNIST): ToTensor only, no normalization.
batch_size_default = 128
transform = transforms.Compose([transforms.ToTensor()])
# Official train split first, then the test split (downloaded if absent).
train_ds, test_ds = (
    datasets.MNIST(root="./data", train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)
# Carve a fixed, seeded validation subset out of the official training split.
val_size = 5000
train_size = len(train_ds) - val_size
train_ds, val_ds = torch.utils.data.random_split(
    train_ds,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)


def make_loaders(batch_size):
    """Return ``(train, val, test)`` DataLoaders over the module-level splits.

    Only the training loader shuffles; validation and test iterate in a fixed
    order so their metrics are comparable across runs.
    """
    plan = ((train_ds, True), (val_ds, False), (test_ds, False))
    return tuple(
        DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
        for dataset, shuffle in plan
    )


# Default loaders used throughout the notebook, plus a split-size report.
train_loader, val_loader, test_loader = make_loaders(
    batch_size=batch_size_default
)
print("Train: {}, Val: {}, Test: {}".format(train_size, val_size, len(test_ds)))

Train: 55000, Val: 5000, Test: 10000


In [238]:
# Dynamic MLP Model with optional BatchNorm, Dropout, Residual Blocks
class DynamicMLP(nn.Module):
    """Configurable fully connected classifier.

    Every hidden block is ``Linear -> [BatchNorm1d] -> activation -> [Dropout]``.
    With ``residual=True`` the block input is added to the block output whenever
    the block preserves the feature dimension (``in_features == out_features``).

    Args:
        input_dim: number of features after flattening each sample (default 28*28).
        num_classes: number of output logits.
        hidden_width: width of every hidden layer when ``widths_list`` is not given.
        depth: number of hidden layers when ``widths_list`` is not given.
        activation: ``relu``, ``gelu``, ``tanh``, ``sigmoid``, ``leakyrelu``
            (negative slope 0.2) or ``elu``; case-insensitive.
        dropout: dropout probability after each activation (0 disables dropout).
        batchnorm: insert ``BatchNorm1d`` after each hidden linear layer.
        residual: enable identity skips on shape-preserving blocks.
        widths_list: explicit hidden widths; when non-empty it overrides
            ``hidden_width`` and ``depth``.

    Raises:
        KeyError: if ``activation`` is not one of the supported names.
    """

    # Factories rather than instances, so each hidden block owns its own
    # activation module instead of all blocks sharing one object.
    _ACTIVATIONS = {
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "tanh": nn.Tanh,
        "sigmoid": nn.Sigmoid,
        "leakyrelu": lambda: nn.LeakyReLU(0.2),
        "elu": nn.ELU,
    }

    def __init__(
        self,
        input_dim=28 * 28,
        num_classes=10,
        hidden_width=256,
        depth=4,
        activation="relu",
        dropout=0.0,
        batchnorm=False,
        residual=False,
        widths_list=None,
    ):
        super().__init__()
        key = activation.lower()
        if key not in self._ACTIVATIONS:
            raise KeyError(
                f"Unknown activation {activation!r}; expected one of {sorted(self._ACTIVATIONS)}"
            )
        make_activation = self._ACTIVATIONS[key]
        # Kept as an attribute for callers that inspect ``model.activation``.
        self.activation = make_activation()
        self.dropout_p = dropout
        self.batchnorm = batchnorm
        self.residual = residual
        if widths_list is not None and len(widths_list) > 0:
            widths = list(widths_list)
        else:
            widths = [hidden_width] * depth
        self.input_dim = input_dim
        self.num_classes = num_classes

        layers = []
        # (in_features, out_features) per block; forward() uses it to decide skips.
        self.skip_dims = []
        prev = input_dim
        for w in widths:
            block = [nn.Linear(prev, w)]
            if batchnorm:
                block.append(nn.BatchNorm1d(w))
            block.append(make_activation())
            if self.dropout_p > 0:
                block.append(nn.Dropout(self.dropout_p))
            layers.append(nn.Sequential(*block))
            self.skip_dims.append((prev, w))
            prev = w
        self.layers = nn.ModuleList(layers)
        self.out = nn.Linear(prev, num_classes)

    def forward(self, x):
        """Return class logits; every dimension after the batch dimension is flattened."""
        # torch.flatten (unlike .view) also accepts non-contiguous inputs such as
        # transposed batches, and empty batches.
        out = torch.flatten(x, 1)
        for layer, (in_dim, out_dim) in zip(self.layers, self.skip_dims):
            hidden = layer(out)
            if self.residual and in_dim == out_dim:
                hidden = hidden + out
            out = hidden
        return self.out(out)

    def count_params(self):
        """Return the number of trainable (``requires_grad``) parameter elements."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


# Quick sanity test
model_test = DynamicMLP(depth=3, hidden_width=128, activation="relu", residual=True)
print("Param count test model:", model_test.count_params())

Param count test model: 134794


In [239]:
def plot_history(history: dict, title: str = "Training History"):
    """Plot per-epoch loss (left panel) and error-rate (right panel) curves.

    Args:
        history: mapping with an ``"epoch"`` list plus optional per-epoch lists
            ``"train_loss"``, ``"val_loss"``, ``"train_error"`` and ``"val_error"``.
        title: figure super-title.

    A curve whose key is missing, or whose length differs from ``"epoch"``
    (e.g. a partially recorded history), is skipped instead of making
    matplotlib raise a shape-mismatch error.
    """
    epochs = history.get("epoch") or []
    if not epochs:
        print("No history data available to plot.")
        return

    # (axis title, y label, ((history key, legend label), ...)) for each panel.
    panels = (
        ("Loss", "Loss", (("train_loss", "Train Loss"), ("val_loss", "Val Loss"))),
        ("Error", "Error Rate", (("train_error", "Train Error"), ("val_error", "Val Error"))),
    )

    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    for ax, (panel_title, ylabel, series) in zip(axes, panels):
        for key, label in series:
            values = history.get(key)
            if values is None:
                continue
            if len(values) != len(epochs):
                print(f"Skipping '{key}': {len(values)} values for {len(epochs)} epochs.")
                continue
            ax.plot(epochs, values, label=label)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(ylabel)
        ax.set_title(panel_title)
        # Only draw a legend when something was plotted (avoids an empty-legend warning).
        if ax.get_legend_handles_labels()[0]:
            ax.legend()

    fig.suptitle(title)
    fig.tight_layout()
    plt.show()

def _run_description_lines(df, metric, wrap_width=100, max_lines=12):
    """Return wrapped per-run description lines for the run-log figure.

    Each row of ``df`` becomes ``"Run <n>: depth=..., ..., <metric>=<value>"``,
    wrapped to ``wrap_width`` characters. Runs are kept whole: once the next
    run's lines would push the total past ``max_lines``, the remaining runs are
    summarised by a final ``"... and <k> more runs (see EXPERIMENT_LOG)"`` line.
    The first run is always included, even if it alone exceeds the budget.
    """
    import textwrap

    def describe(row, run_no):
        widths = row.get("widths_list")
        if isinstance(widths, (list, tuple)) and widths:
            width_desc = f"widths={widths}"
        else:
            hidden = row.get("hidden_width", 0)
            # Keys absent from some log entries show up as NaN/None in the DataFrame.
            hidden = int(hidden) if pd.notna(hidden) else 0
            width_desc = f"hidden={hidden}"
        return (
            f"Run {run_no}: depth={row.get('depth')}, {width_desc}, act={row.get('activation')}, "
            f"opt={row.get('optimizer')}, epochs={row.get('epochs')}, batch={row.get('batch_size')}, "
            f"{metric}={row.get(metric)}"
        )

    lines = []
    shown = 0
    for idx, row in df.iterrows():
        text = describe(row, int(idx) + 1)
        wrapped = textwrap.wrap(text, wrap_width) or [text]
        if shown and len(lines) + len(wrapped) > max_lines:
            break
        lines.extend(wrapped)
        shown += 1
    remaining = len(df) - shown
    if remaining > 0:
        lines.append(f"... and {remaining} more runs (see EXPERIMENT_LOG)")
    return lines


def plot_run_log(metric: str = "test_error", max_lines: int = 12, wrap_width: int = 100):
    """Bar-chart ``metric`` for every run in ``EXPERIMENT_LOG``, with run descriptions below.

    The top axis shows one bar per run (x ticks are run numbers). The bottom
    axis lists each run's configuration, wrapped to ``wrap_width`` characters
    and capped at ``max_lines`` lines. Only whole runs are shown, and any
    overflow is summarised in a final line.

    Args:
        metric: ``EXPERIMENT_LOG`` column to plot.
        max_lines: maximum number of description lines under the chart.
        wrap_width: wrap width, in characters, for each run description.
    """
    if not EXPERIMENT_LOG:
        print("No experiment log entries to plot yet.")
        return

    df = pd.DataFrame(EXPERIMENT_LOG)
    if metric not in df.columns:
        print(f"Metric '{metric}' not found in experiment log columns: {sorted(df.columns.tolist())}")
        return

    runs = df.index + 1
    values = df[metric]
    # Run numbers only on the x axis; full descriptions go in the text panel.
    xtick_labels = [str(int(r)) for r in runs]
    wrapped_lines = _run_description_lines(df, metric, wrap_width=wrap_width, max_lines=max_lines)

    # Figure height grows with the number of description lines.
    top_height = 3
    bottom_height = 0.25 * max(len(wrapped_lines), 1)
    fig_height = top_height + bottom_height

    import matplotlib.gridspec as gridspec

    fig = plt.figure(figsize=(max(8, len(runs) * 0.8), fig_height))
    gs = gridspec.GridSpec(2, 1, height_ratios=[top_height, bottom_height], hspace=0.4)
    ax = fig.add_subplot(gs[0])
    ax_desc = fig.add_subplot(gs[1])

    ax.bar(runs, values, color="#1f77b4")
    ax.set_xticks(runs)
    ax.set_xticklabels(xtick_labels, rotation=0, ha='center', fontsize=9)
    ax.set_xlabel("Run #")
    ax.set_ylabel(metric)
    ax.set_title(f"Experiment Log — {metric}")
    ax.grid(True, axis="y", alpha=0.3)

    ax_desc.axis('off')
    text_block = "\n".join(wrapped_lines)
    ax_desc.text(0.01, 0.98, text_block, va='top', ha='left', family='monospace', fontsize=9, transform=ax_desc.transAxes)

    # Explicit margins instead of tight_layout, which copes poorly with the text-only axis.
    fig.subplots_adjust(left=0.06, right=0.98, top=0.95, bottom=0.05)
    plt.show()

In [240]:
def plot_example_predictions(model, loader, device=torch.device('cpu'), n_correct=4, n_incorrect=4):
    """Plot up to ``n_correct`` correct and ``n_incorrect`` incorrect predictions.

    The top row shows correctly classified samples and the bottom row
    misclassified ones, each titled ``t:<true> p:<pred>``. Batches are read from
    ``loader`` in order until both quotas are filled or the loader is exhausted.
    Unfilled slots are left blank.

    Args:
        model: module returning logits (class scores along dim 1). It is switched
            to eval mode for the scan, and its previous train/eval mode is restored.
        loader: iterable of ``(inputs, labels)`` batches.
        device: device the inputs are moved to; should match the model's device.
        n_correct: number of correct examples to show.
        n_incorrect: number of incorrect examples to show.
    """
    if model is None:
        print("No model provided; skipping example predictions.")
        return
    if loader is None:
        print("No test_loader available; skipping example predictions.")
        return
    ncols = max(n_correct, n_incorrect)
    if ncols <= 0:
        print("Nothing to plot: n_correct and n_incorrect are both zero.")
        return

    def as_image(t):
        # (1, H, W) -> (H, W); (3|4, H, W) -> (H, W, C) for imshow; else unchanged.
        if t.ndim == 3 and t.shape[0] == 1:
            return t.squeeze(0).numpy()
        if t.ndim == 3 and t.shape[0] in (3, 4):
            return t.permute(1, 2, 0).numpy()
        return t.numpy()

    def quotas_met():
        return len(correct_ex) >= n_correct and len(incorrect_ex) >= n_incorrect

    correct_ex = []
    incorrect_ex = []
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for xb, yb in loader:
                preds = model(xb.to(device)).argmax(dim=1).cpu()
                # Move each batch to the CPU once rather than once per sample.
                for img_t, true_t, pred_t in zip(xb.cpu(), yb.cpu(), preds):
                    true, pred = int(true_t), int(pred_t)
                    if pred == true:
                        bucket, quota = correct_ex, n_correct
                    else:
                        bucket, quota = incorrect_ex, n_incorrect
                    if len(bucket) < quota:
                        bucket.append((as_image(img_t), true, pred))
                    if quotas_met():
                        break
                if quotas_met():
                    break
    finally:
        # Do not leave a model that was training stuck in eval mode.
        model.train(was_training)

    # squeeze=False keeps a 2-D axes grid even when ncols == 1.
    fig, axes = plt.subplots(2, ncols, figsize=(ncols * 1.6, 4), squeeze=False)
    for row_axes, examples in zip(axes, (correct_ex, incorrect_ex)):
        for c, ax in enumerate(row_axes):
            if c < len(examples):
                img, t, p = examples[c]
                ax.imshow(img, cmap='gray')
                ax.set_title(f't:{t} p:{p}')
            else:
                ax.axis('off')
            ax.set_xticks([])
            ax.set_yticks([])
    plt.suptitle('Sample Predictions — Top: correct | Bottom: incorrect')
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()


### Set up GUI

In [241]:
# Interactive Widgets (if available)
from ipywidgets import Layout, GridBox
from contextlib import redirect_stdout

# Defensive cleanup: when this cell is re-run, close the widgets created by the
# previous execution so stale front-end/back-end widget state is released.
for _n in (
    "depth_w",
    "width_w",
    "widths_text",
    "act_w",
    "optimizer_w",
    "lr_w",
    "dropout_w",
    "batchnorm_w",
    "residual_w",
    "scheduler_w",
    "step_size_w",
    "epochs_w",
    "batch_w",
    "metric_dropdown",
    "run_button",
    "abort_button",
    "log_area",
    "figure_out",
    "grid",
    "button_row",
    "container",
):
    try:
        # An undefined global, or one with no .close(), resolves to a no-op callable.
        getattr(globals().get(_n), "close", lambda: None)()
    except Exception:
        # Best effort only: a failing close must not stop the cell.
        pass


def on_abort(b):
    """Abort-button callback: request that the active run stop.

    Sets the module-level ``STOP_REQUESTED`` flag, disables the Abort button so
    it cannot be clicked twice, and appends a note to the log textarea.

    Args:
        b: the clicked button (unused).
    """
    global STOP_REQUESTED
    STOP_REQUESTED = True
    abort_button.disabled = True
    notice = "Abort requested — the current run will stop shortly.\n"
    log_area.value += notice


def _parse_widths_text(text):
    """Parse a comma-separated list of hidden-layer widths.

    Returns ``None`` for blank input (or input with no entries, e.g. ``", ,"``).
    Raises ``ValueError`` if any entry is not an integer or is not positive,
    since non-positive widths cannot form a usable layer. Empty entries such as
    a trailing comma are ignored.
    """
    if not text.strip():
        return None
    widths = [int(part.strip()) for part in text.split(",") if part.strip()]
    if any(w <= 0 for w in widths):
        raise ValueError(f"widths must be positive integers, got {widths}")
    return widths or None


def _coerce_lr(raw, lo, hi):
    """Convert ``raw`` to a learning rate clamped to ``[lo, hi]``.

    Returns ``(lr, messages)``, where ``messages`` are the log lines describing
    any correction. Unparsable or NaN input falls back to 1e-3.
    """
    import math

    messages = []
    try:
        lr = float(raw)
        if math.isnan(lr):
            raise ValueError("learning rate is NaN")
    except (TypeError, ValueError):
        lr = 1e-3
        messages.append("Learning rate unparsable; resetting to 1e-3.\n")
    if lr < lo:
        lr = lo
        messages.append(f"Learning rate too small; clamped to {lo:.1e}.\n")
    elif lr > hi:
        lr = hi
        messages.append(f"Learning rate too large; clamped to {hi:.1e}.\n")
    return lr, messages


class _TextAreaWriter:
    """Minimal file-like object that appends written text to a Textarea widget."""

    def __init__(self, textarea):
        self.textarea = textarea

    def write(self, text):
        if text:
            try:
                self.textarea.value += text
            except Exception:
                # Best effort: a logging failure must not abort training.
                pass
        return len(text) if text else 0

    def flush(self):
        """No-op: each write is applied to the widget immediately."""


def _run_from_widgets():
    """Build a RunConfig from the controls, run it with stdout sent to the log, then plot."""
    # Fresh log and figure area for this run.
    log_area.value = ""
    figure_out.clear_output()

    try:
        wl = _parse_widths_text(widths_text.value)
    except ValueError:
        log_area.value += "Invalid widths list; using uniform width.\n"
        wl = None

    lr_val, lr_notes = _coerce_lr(lr_w.value, lr_min, lr_max)
    if lr_notes:
        log_area.value += "".join(lr_notes)
    lr_w.value = lr_val  # reflect any correction back into the widget

    cfg = RunConfig(
        depth=depth_w.value,
        hidden_width=width_w.value,
        activation=act_w.value,
        dropout=dropout_w.value,
        batchnorm=batchnorm_w.value,
        residual=residual_w.value,
        optimizer=optimizer_w.value,
        lr=lr_val,
        epochs=epochs_w.value,
        batch_size=batch_w.value,
        scheduler=scheduler_w.value,
        step_size=step_size_w.value,
        widths_list=wl,
    )
    if cfg.optimizer.lower() == "lm" and cfg.scheduler:
        log_area.value += "Scheduler request ignored: LM optimizer does not support StepLR.\n"
    log_area.value += f"Running with config: {cfg}\n"

    try:
        with redirect_stdout(_TextAreaWriter(log_area)):
            model, history = run_experiment(cfg)
    except Exception:
        # Report the failure in the GUI log; otherwise it would not appear in this Textarea.
        import traceback

        log_area.value += "Run failed:\n" + traceback.format_exc()
        return

    plt.close("all")
    opt_label = history.get("optimizer_used", cfg.optimizer)
    with figure_out:
        from IPython.display import clear_output

        clear_output(wait=True)
        plot_history(history, title=f"Error Curves (opt={opt_label}, depth={cfg.depth})")
        plot_run_log(metric_dropdown.value)
        # Show sample predictions: top row correct, bottom row incorrect (4 each).
        try:
            plot_example_predictions(model, test_loader, device=device, n_correct=4, n_incorrect=4)
        except Exception as ex:
            print("Error plotting example predictions:", ex)


def on_run(b):
    """Run-button callback: train one configuration read from the widgets.

    Disables Run and enables Abort for the duration of the run. The buttons are
    always restored afterwards, even if parsing, training or plotting raises,
    so the GUI never gets stuck with Run disabled.

    Args:
        b: the clicked button (unused; required by ``Button.on_click``).
    """
    global STOP_REQUESTED
    # Clear an abort left over from a previous run so it cannot cancel this one.
    # NOTE(review): assumes run_experiment (defined elsewhere) polls STOP_REQUESTED — confirm.
    STOP_REQUESTED = False
    run_button.disabled = True
    abort_button.disabled = False
    try:
        _run_from_widgets()
    finally:
        run_button.disabled = False
        abort_button.disabled = True


# Shared widget layouts.
# Each control grows/shrinks from a 280px basis but never below 220px, so inputs
# wrap onto new lines instead of overlapping on narrow displays.
default_ctrl_layout = Layout(flex="1 1 280px", min_width="220px", width="auto")
# Button row: a left-aligned flex row that wraps when it runs out of width.
hbox_layout = Layout(
    display="flex",
    flex_flow="row wrap",
    justify_content="flex-start",
    align_items="flex-start",
    width="100%",
)
vbox_layout = Layout(width="100%")

# Learning-rate bounds shared by the lr widget and the clamp in on_run.
lr_min = 1e-7
lr_max = 1e-1

# Core widgets
# Labeled inputs use an "initial" description width so long labels such as
# "Number of Layers" or "Scheduler Step Size" render in full instead of being
# cut off at ipywidgets' fixed default label width.
_label_style = {"description_width": "initial"}

# Architecture controls
depth_w = widgets.BoundedIntText(
    value=4,
    min=1,
    max=12,
    step=1,
    description="Number of Layers",
    layout=default_ctrl_layout,
    style=_label_style,
)
width_w = widgets.BoundedIntText(
    value=256,
    min=16,
    max=1024,
    step=16,
    description="Neurons per layer",
    layout=default_ctrl_layout,
    style=_label_style,
)
widths_text = widgets.Text(
    value="",
    description="Widths List",
    placeholder="Comma-separated custom widths (optional)",
    layout=default_ctrl_layout,
    style=_label_style,
)
act_w = widgets.Dropdown(
    options=["relu", "gelu", "tanh", "sigmoid", "leakyrelu", "elu"],
    value="relu",
    description="Activation",
    layout=default_ctrl_layout,
    style=_label_style,
)

# Optimization controls
optimizer_w = widgets.Dropdown(
    options=["adam", "adamw", "sgd", "lm"],
    value="adam",
    description="Optimizer",
    layout=default_ctrl_layout,
    style=_label_style,
)
lr_w = widgets.BoundedFloatText(
    value=1e-3,
    min=lr_min,
    max=lr_max,
    step=1e-4,
    description="Learning Rate",
    layout=default_ctrl_layout,
    style=_label_style,
)
dropout_w = widgets.BoundedFloatText(
    value=0.1,
    min=0.0,
    max=0.7,
    step=0.05,
    description="Dropout",
    layout=default_ctrl_layout,
    style=_label_style,
)

# Boolean toggles (their label sits on the button itself, so no label style is needed)
batchnorm_w = widgets.ToggleButton(value=True, description="BatchNorm", layout=Layout(width="130px"))
residual_w = widgets.ToggleButton(value=True, description="Residual", layout=Layout(width="130px"))
scheduler_w = widgets.ToggleButton(value=False, description="Scheduler", layout=Layout(width="130px"))

step_size_w = widgets.BoundedIntText(
    value=3,
    min=1,
    max=10,
    step=1,
    description="Scheduler Step Size",
    layout=default_ctrl_layout,
    style=_label_style,
)
epochs_w = widgets.BoundedIntText(
    value=5,
    min=1,
    max=50,
    step=1,
    description="Epochs",
    layout=default_ctrl_layout,
    style=_label_style,
)
batch_w = widgets.BoundedIntText(
    value=128,
    min=16,
    max=512,
    step=16,
    description="Batch Size",
    layout=default_ctrl_layout,
    style=_label_style,
)
# Metric plotted by plot_run_log after each run
metric_dropdown = widgets.Dropdown(
    options=["test_error", "best_val_error", "test_loss"],
    value="test_error",
    description="Metric",
    layout=default_ctrl_layout,
    style=_label_style,
)

# Actions and output areas
run_button = widgets.Button(description="Run Experiment", button_style="success")
abort_button = widgets.Button(description="Abort", button_style="warning", disabled=True)
log_area = widgets.Textarea(
    value="",
    placeholder="Training output appears here...",
    layout=Layout(width="100%", height="260px"),
)
figure_out = widgets.Output(layout=Layout(border="1px solid #ccc", width="100%"))

# Layout composition
# Responsive grid of labeled inputs: CSS auto-fit creates as many columns of at
# least 260px as the available width allows.
grid = GridBox(
    children=(
        depth_w,
        width_w,
        widths_text,
        epochs_w,
        batch_w,
        dropout_w,
        step_size_w,
        lr_w,
        optimizer_w,
        act_w,
        metric_dropdown,
    ),
    layout=Layout(
        align_items="flex-start",
        grid_gap="8px",
        grid_template_columns="repeat(auto-fit, minmax(260px, 1fr))",
        width="100%",
    ),
)
# Action buttons and the boolean toggles share one wrapping row.
button_row = widgets.HBox(
    children=[run_button, abort_button, residual_w, batchnorm_w, scheduler_w],
    layout=hbox_layout,
)
# Top-level GUI, top to bottom: inputs, buttons, streamed training log, plot output area.
container = widgets.VBox(children=[grid, button_row, log_area, figure_out], layout=vbox_layout)


### Run Configuration Parameters
- `Number of Layers (depth)`: number of hidden layers (ignored when `widths_list` is set; default 4).
- `Neurons per layer (hidden_width)`: neurons per hidden layer when using uniform widths (default 256).
- `Widths List (widths_list)`: optional comma-separated list of hidden-layer widths (e.g. `512, 256, 128`) that overrides `hidden_width`/`depth`.
- `Activation (activation)`: activation function name (`relu`, `gelu`, `tanh`, `sigmoid`, `leakyrelu`, `elu`).
- `Dropout (dropout)`: dropout probability applied after activations (0 disables).
- `Optimizer (optimizer)`: optimizer choice (`adam`, `adamw`, `sgd`, `lm`; `lm` uses an LBFGS-based Levenberg–Marquardt-like step and ignores schedulers).
- `Learning Rate (lr)`: learning rate supplied to the optimizer (default 1e-3).
- `BatchNorm (batchnorm)`: insert batch normalization after each hidden linear layer (`True`/`False`).
- `Residual (residual)`: add identity skip connections to hidden layers whose input and output widths match (`True`/`False`).
- `Scheduler (scheduler)`: toggle StepLR scheduler usage (`True`/`False`).
- `Scheduler Step Size (step_size)`: StepLR step interval in epochs when scheduler is active (default 3).
- `Epochs (epochs)`: total training epochs (default 5).
- `Batch Size (batch_size)`: mini-batch size used by data loaders (default 128).
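- `Metric (metric)`: experiment-log metric plotted as a bar chart after each run (`test_error`, `best_val_error`, or `test_loss`; default `test_error`).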

In [242]:
if WIDGETS_AVAILABLE:
    try:
        # Attach each handler at most once per widget instance (idempotent on re-run).
        if 'run_button' in globals() and not getattr(run_button, '_on_run_attached', False):
            run_button.on_click(on_run)
            run_button._on_run_attached = True
        # Wire the Abort button too: without an on_click(on_abort) registration
        # its callback never runs and STOP_REQUESTED is never set.
        if 'abort_button' in globals() and not getattr(abort_button, '_on_abort_attached', False):
            abort_button.on_click(on_abort)
            abort_button._on_abort_attached = True
        display(container)
    except Exception as e:
        print('Error while attaching run handler / displaying GUI:', e)
else:
    print("ipywidgets not available. Use manual config in next cell.")


VBox(children=(GridBox(children=(BoundedIntText(value=4, description='Number of Layers', layout=Layout(flex='1…

In [243]:
# Re-display GUI and attach handlers if needed (safe to run multiple times)
if WIDGETS_AVAILABLE:
    try:
        # Attach each handler only once so clicks never fire duplicate callbacks.
        if 'run_button' in globals() and not getattr(run_button, '_on_run_attached', False):
            run_button.on_click(on_run)
            run_button._on_run_attached = True
        # Wire the Abort button too: without an on_click(on_abort) registration
        # its callback never runs and STOP_REQUESTED is never set.
        if 'abort_button' in globals() and not getattr(abort_button, '_on_abort_attached', False):
            abort_button.on_click(on_abort)
            abort_button._on_abort_attached = True
        # Re-display the main container in case the frontend cleared it.
        if 'container' in globals():
            display(container)
        else:
            print('Container not available; re-run GUI setup cell to (re)create widgets.')
    except Exception as e:
        print('Error while (re)attaching GUI:', e)
else:
    print('ipywidgets unavailable; cannot display GUI.')

VBox(children=(GridBox(children=(BoundedIntText(value=4, description='Number of Layers', layout=Layout(flex='1…

### Without the GUI: configure a run in code

In [244]:
# Example training run on MNIST, using the same values as the GUI defaults.
example_cfg = RunConfig(
    # Architecture
    depth=4,
    hidden_width=256,
    activation="relu",
    batchnorm=True,
    residual=True,
    dropout=0.1,
    # Optimization
    optimizer="adam",
    lr=1e-3,
    scheduler=False,
    epochs=5,
    batch_size=128,
)
# To train this configuration without the GUI, uncomment the lines below:
# print('Launching example configuration:', example_cfg)
# example_model, example_history = run_experiment(example_cfg)
# plot_history(
#     example_history,
#     title=f"Example Run (activation={example_cfg.activation}, depth={example_cfg.depth}, width={example_cfg.hidden_width})"
# )
# plot_run_log('test_error')