# From scratch

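Run this section from the top: it builds and saves the files (`data/index_to_row.pkl` and `data/layer_cache.pkl`) that the "Use cached layers" section below loads.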

In [8]:
!pip install ipywidgets

Collecting ipywidgets
  Downloading ipywidgets-8.1.8-py3-none-any.whl.metadata (2.4 kB)
Collecting widgetsnbextension~=4.0.14 (from ipywidgets)
  Downloading widgetsnbextension-4.0.15-py3-none-any.whl.metadata (1.6 kB)
Collecting jupyterlab_widgets~=3.0.15 (from ipywidgets)
  Downloading jupyterlab_widgets-3.0.16-py3-none-any.whl.metadata (20 kB)
Downloading ipywidgets-8.1.8-py3-none-any.whl (139 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m139.8/139.8 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading jupyterlab_widgets-3.0.16-py3-none-any.whl (914 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m914.9/914.9 kB[0m [31m15.2 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading widgetsnbextension-4.0.15-py3-none-any.whl (2.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m35.0 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: widgets

In [1]:
# NOTE(review): the output recorded for this cell ends with
# "Jupyter command `jupyter-nbextension` not found", so the command did not
# enable anything in this environment — confirm whether the cell is still needed.
!jupyter nbextension enable --py widgetsnbextension

usage: jupyter [-h] [--version] [--config-dir] [--data-dir] [--runtime-dir]
               [--paths] [--json] [--debug]
               [subcommand]

Jupyter: Interactive Computing

positional arguments:
  subcommand     the subcommand to launch

optional arguments:
  -h, --help     show this help message and exit
  --version      show the versions of core jupyter packages and exit
  --config-dir   show Jupyter config dir
  --data-dir     show Jupyter data dir
  --runtime-dir  show Jupyter runtime dir
  --paths        show all Jupyter paths. Add --json for machine-readable
                 format.
  --json         output paths as machine-readable json
  --debug        output debug information about paths

Available subcommands: kernel kernelspec migrate run troubleshoot

Jupyter command `jupyter-nbextension` not found.


In [1]:
import os
import sys
sys.path.append(os.path.abspath(".."))

from NNModel import MultiLayerNN
import numpy as np
import pickle
import torch
import torch.nn as nn
from prep_data import prep_data
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

In [2]:
train_loader, test_loader = prep_data()
device = "cuda" if torch.cuda.is_available() else "cpu"
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only machine.
state_dict = torch.load("data/model.pt", weights_only=True, map_location=device)

# A compiled model saves its weights as "_orig_mod.LAYERNAME" instead of
# "LAYERNAME". Strip that prefix for compatibility, but only where it is
# actually present, so an uncompiled checkpoint is not mangled.
COMPILED_PREFIX = "_orig_mod."
new_dict = {
    (k[len(COMPILED_PREFIX):] if k.startswith(COMPILED_PREFIX) else k): v
    for k, v in state_dict.items()
}

model = MultiLayerNN(latent_size=128, num_layers=3)
model.load_state_dict(new_dict)
model.to(device)
# The model is only used for analysis here; eval mode keeps layers such as
# dropout / batch norm (if MultiLayerNN has any) deterministic.
model.eval()

# Flatten every batch to (batch, 28*28) and stack train + test into one tensor.
# The row order of `images` is the order in which the loaders yield samples.
images = []
labels = []
for loader in [train_loader, test_loader]:
    for image, label, _idxs in loader:
        images.append(image.squeeze().reshape(-1, 28*28))
        labels.append(label)
images = torch.cat(images).to(device)
labels = torch.cat(labels).to(device)

In [3]:
# Layer name -> tensor captured by that layer's forward hook
activations = dict()
# Layer index -> layer name, kept for plot titles
layer_names = dict()

def get_activation(name):
    """Build a forward hook that records a layer's output under `name`."""
    def hook(_module, _inputs, layer_output):
        # Keep a detached tensor so the computation graph is not held alive.
        activations[name] = layer_output.detach()
    return hook

# Register a forward hook on every linear layer.
# Each hook captures the output of the Linear layer *before* it goes into ReLU.
hook_handles = []
layer_id = 0

# Modules to hook, in forward order: the first linear layer, every Linear
# inside hidden_layers, then the output layer.
hooked_modules = [(model.dim_reduction, "Dim Reduction")]
hooked_modules += [
    (layer, "Hidden Linear")
    for layer in model.hidden_layers
    if isinstance(layer, nn.Linear)
]
hooked_modules.append((model.output, "Output"))

try:
    for layer_id, (module, description) in enumerate(hooked_modules):
        layer_name = f"Layer {layer_id}: {description}"
        handle = module.register_forward_hook(get_activation(layer_name))
        hook_handles.append(handle)
        layer_names[layer_id] = layer_name

    # Run a forward pass to trigger the hooks and populate the 'activations' dict
    with torch.no_grad():
        output = model(images)
finally:
    # Always detach the hooks, even if registration or the forward pass fails,
    # so a later forward pass does not silently overwrite 'activations'.
    for handle in hook_handles:
        handle.remove()

In [4]:
def predict_from_layer(start_layer_index, activation_tensor):
    """
    Push a hooked (pre-ReLU) activation through the rest of the model.

    Layer numbering follows the hook registration order: 0 is the output of
    `model.dim_reduction`, 1..N are the outputs of the N `nn.Linear` modules
    found in `model.hidden_layers`, and N + 1 is the output of `model.output`.

    NOTE(review): this assumes the model's forward pass is
    dim_reduction -> ReLU -> hidden_layers (applied in order) -> output;
    confirm against NNModel.MultiLayerNN.

    Args:
        start_layer_index: index of the layer `activation_tensor` came from.
        activation_tensor: activations with one row per sample.

    Returns:
        The output-layer values for every row. When `start_layer_index`
        already refers to the output layer, the tensor is returned unchanged.

    Raises:
        ValueError: if `start_layer_index` is outside 0..N + 1.
    """
    hidden = model.hidden_layers
    # Position inside `hidden` of each Linear module; entry k belongs to layer k + 1.
    linear_positions = [
        pos for pos, layer in enumerate(hidden) if isinstance(layer, nn.Linear)
    ]
    output_layer_index = len(linear_positions) + 1

    if not 0 <= start_layer_index <= output_layer_index:
        raise ValueError(
            f"start_layer_index must be in 0..{output_layer_index}, "
            f"got {start_layer_index}"
        )

    # Output-layer activations are already the final logits: applying ReLU to
    # them would zero every negative logit and corrupt the argmax.
    if start_layer_index == output_layer_index:
        return activation_tensor

    # Apply ReLU for the starting layer (since we hooked pre-ReLU).
    x = nn.functional.relu(activation_tensor)

    if start_layer_index == 0:
        # dim_reduction output: every hidden layer is still ahead.
        resume_at = 0
    else:
        # Continue with the module that follows this layer's Linear.
        resume_at = linear_positions[start_layer_index - 1] + 1

    for pos in range(resume_at, len(hidden)):
        x = hidden[pos](x)

    return model.output(x)


# # Now, loop through the captured activations and plot
# for layer_index, (name, data) in enumerate(activations.items()):
#     print(f"Generating boundary for: {name}")

#     # 1. Fit PCA on this layer's activations
#     pca = PCA(n_components=2)
#     features_2d = pca.fit_transform(data.cpu().numpy())

#     # 2. Create mesh grid
#     x_min, x_max = features_2d[:, 0].min()*1.1, features_2d[:, 0].max() *1.1
#     y_min, y_max = features_2d[:, 1].min() *1.1, features_2d[:, 1].max() *1.1
#     xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.2), np.arange(y_min, y_max, 0.2))

#     # 3. Inverse transform grid points and predict
#     grid_points = np.c_[xx.ravel(), yy.ravel()]
#     grid_points_high_dim = pca.inverse_transform(grid_points)
#     grid_tensor = torch.tensor(grid_points_high_dim, dtype=torch.float32).to(device)

#     with torch.no_grad():
#         # Use the helper function to predict from this intermediate layer
#         outputs = predict_from_layer(layer_index, grid_tensor)
#         _, Z = torch.max(outputs, 1)
#         Z = Z.cpu().numpy().reshape(xx.shape)

#     # 4. Plotting
#     plt.figure(figsize=(10, 8))
#     plt.contourf(xx, yy, Z, alpha=0.4, cmap=plt.cm.tab10)
#     scatter = plt.scatter(features_2d[:, 0], features_2d[:, 1], c=labels.cpu(), s=1, cmap=plt.cm.tab10)

#     plt.title(f"Decision Boundary at {name}")
#     plt.xlabel("Principal Component 1")
#     plt.ylabel("Principal Component 2")
#     plt.legend(handles=scatter.legend_elements()[0], labels=list(range(10)))
#     plt.show()

In [5]:
# load incorrect indices & build index mapping
from collections import OrderedDict

# [(idx, yhat, y), ...]
with open("data/incorrect_preds.pkl", "rb") as f:
    incorrect_triplets = pickle.load(f)

incorrect_indices = [t[0] for t in incorrect_triplets]

# Second pass over the loaders, this time also collecting the dataset index
# that each loader yields alongside every sample.
index_batches = []
images_list, labels_list = [], []
for loader in [train_loader, test_loader]:
    for image, label, idxs in loader:
        images_list.append(image.squeeze().reshape(-1, 28*28))
        labels_list.append(label)
        index_batches.append(idxs)

images_second_pass = torch.cat(images_list).to(device)
labels_second_pass = torch.cat(labels_list).to(device)

# The activations were captured from the earlier pass, so rows collected here
# only line up with them if both passes yielded the samples in the same order.
# Compare against the labels of the earlier pass before overwriting anything:
# a shuffling loader would otherwise produce a silently wrong index mapping.
if not torch.equal(labels_second_pass, labels):
    raise RuntimeError(
        "Sample order differs between the two passes over the data loaders "
        "(is a loader shuffling?). The index mapping would not match the "
        "captured activations; make the loaders deterministic and re-run."
    )

images = images_second_pass
labels = labels_second_pass
all_indices = torch.cat(index_batches).cpu().numpy()

# create mapping from dataset index to row in activations
# NOTE(review): setdefault keeps the FIRST row seen for a repeated index. If the
# train and test loaders number their samples independently, a test index would
# resolve to a train row — confirm against prep_data.
index_to_row = {}
for row, idx in enumerate(all_indices):
    index_to_row.setdefault(int(idx), row)

# Save index mapping for later use
with open("data/index_to_row.pkl", "wb") as f:
    pickle.dump(index_to_row, f)

In [6]:
# For every hooked layer: a 2-D PCA view of its activations, plus the class
# predicted at each point of a grid laid over that view.
layer_cache = OrderedDict()
labels_per_row = labels.detach().cpu().numpy()

for layer_index, name in enumerate(activations):
    data = activations[name]
    print(f"[cache] {name}")

    # Project this layer's activations onto two principal components.
    pca = PCA(n_components=2, random_state=42)
    feats2d = pca.fit_transform(data.cpu().numpy())

    # Grid over the projected points: each extreme scaled by 1.1, step 0.2.
    pc1, pc2 = feats2d[:, 0], feats2d[:, 1]
    x_min, x_max = pc1.min()*1.1, pc1.max()*1.1
    y_min, y_max = pc2.min()*1.1, pc2.max()*1.1
    x_ticks = np.arange(x_min, x_max, 0.2)
    y_ticks = np.arange(y_min, y_max, 0.2)
    xx, yy = np.meshgrid(x_ticks, y_ticks)

    # Map every grid point back into the layer's activation space.
    grid_points = np.column_stack((xx.ravel(), yy.ravel()))
    grid_points_high_dim = pca.inverse_transform(grid_points)
    grid_tensor = torch.tensor(grid_points_high_dim, dtype=torch.float32).to(device)

    # Predicted class at each grid point, reshaped back onto the grid.
    with torch.no_grad():
        outputs = predict_from_layer(layer_index, grid_tensor)
        _, Z = torch.max(outputs, 1)
        Z = Z.cpu().numpy().reshape(xx.shape)

    layer_cache[name] = dict(
        index=layer_index,
        pca=pca,
        feats2d=feats2d,
        labels=labels_per_row,
        xx=xx, yy=yy, Z=Z,
    )

# Persist the cache so the plots can be redrawn without the model.
with open("data/layer_cache.pkl", "wb") as f:
    pickle.dump(layer_cache, f)


[cache] Layer 0: Dim Reduction
[cache] Layer 1: Hidden Linear
[cache] Layer 2: Hidden Linear
[cache] Layer 3: Hidden Linear
[cache] Layer 4: Output


In [7]:
# plotting function with highlight

# dataset index -> (predicted label, true label) for every misclassified sample
incorrect_dict = {idx: (yhat, ytrue) for idx, yhat, ytrue in incorrect_triplets}

def plot_all_layers_with_highlight(dataset_idx: int, point_size=10):
    """
    Draw one figure per cached layer and highlight a single sample in each.

    Every figure shows the cached decision regions, all samples in the layer's
    2-D PCA projection coloured by label, and a ring plus filled marker on the
    sample whose dataset index is `dataset_idx`.

    Args:
        dataset_idx: dataset index of the sample to highlight. If it is not a
            key of `index_to_row`, a warning is printed and nothing is drawn.
        point_size: marker size of the background scatter.
    """
    if dataset_idx not in index_to_row:
        print(f"[WARN] Index {dataset_idx} is not in the dataset.")
        return

    row = index_to_row[dataset_idx]
    pred_label, recorded_true_label = incorrect_dict.get(dataset_idx, (None, None))

    # Band edges at the half-integers -0.5 .. 9.5 give each class 0..9 its own
    # filled region; the default levels would merge neighbouring classes.
    class_levels = np.arange(11) - 0.5

    # Plot for each layer
    for name, blob in layer_cache.items():
        feats2d = blob["feats2d"]
        row_labels = blob["labels"]
        xx, yy, Z = blob["xx"], blob["yy"], blob["Z"]

        # A sample that is not in the misclassified list has no recorded true
        # label; use the cached per-row label so it can still be coloured.
        true_label = recorded_true_label
        if true_label is None:
            true_label = int(row_labels[row])

        plt.figure(figsize=(9, 7))
        plt.contourf(xx, yy, Z, levels=class_levels, alpha=0.35, cmap=plt.cm.tab10)
        # vmin/vmax pin the colour scale to 0..9 so the points use the same
        # colours as the highlight marker even if some class is absent.
        sc = plt.scatter(feats2d[:, 0], feats2d[:, 1],
                         c=row_labels, s=point_size, cmap=plt.cm.tab10,
                         vmin=0, vmax=9)

        # Highlight the specific point
        hx, hy = feats2d[row, 0], feats2d[row, 1]
        plt.scatter([hx], [hy], s=160, facecolors='none', edgecolors='k', linewidths=2.2, marker='o')
        plt.scatter([hx], [hy], s=40, c=np.array([true_label]),
                    cmap=plt.cm.tab10, vmin=0, vmax=9)  # fill with true label color

        plt.title(
            f"Decision Boundary at {name}\n"
            f"highlight idx={dataset_idx} | y={true_label}, ŷ={pred_label}"
        )
        plt.xlabel("PC1")
        plt.ylabel("PC2")
        plt.legend(handles=sc.legend_elements()[0], labels=list(range(10)), title="True label", loc="best")
        plt.show()


In [8]:
# ---- Interactive selection (fixed) ----
import ipywidgets as widgets
from ipywidgets import Layout
from IPython.display import display

# Unique misclassified dataset indices, ascending, offered in the dropdown.
incorrect_options = sorted(set(incorrect_indices))
if incorrect_options:
    dd = widgets.Dropdown(
        options=incorrect_options,
        description='Misclassified index:',
        layout={'width': '300px'},
        style={'description_width': 'initial'},
    )
    btn = widgets.Button(description='Plot', button_style='primary')
    out = widgets.Output()

    def on_click(_):
        """Redraw the per-layer plots for the index currently selected."""
        with out:
            out.clear_output(wait=True)
            plot_all_layers_with_highlight(int(dd.value))

    btn.on_click(on_click)

    controls = widgets.HBox([dd, btn])
    display(controls, out)

    # Draw once up front so the output area is not empty before the first click.
    with out:
        plot_all_layers_with_highlight(int(dd.value))
else:
    print("No misclassified samples found.")


HBox(children=(Dropdown(description='Misclassified index:', layout=Layout(width='300px'), options=(24, 72, 80,…

Output()

In [37]:
# Spot-check one entry: incorrect_dict maps a dataset index to its (yhat, ytrue) pair.
incorrect_dict.get(233)

(3, 8)

# Use cached layers

In [1]:
import pickle
import numpy as np
import matplotlib.pyplot as plt

In [2]:
def _load_pickle(path):
    """Return the object unpickled from the file at `path`."""
    with open(path, "rb") as fh:
        return pickle.load(fh)

# [(idx, yhat, y), ...]
incorrect_triplets = _load_pickle("data/incorrect_preds.pkl")

# dataset index -> row in the cached per-layer arrays
index_to_row = _load_pickle("data/index_to_row.pkl")

# layer name -> cached PCA projection, labels and decision-region grid
layer_cache = _load_pickle("data/layer_cache.pkl")


In [3]:
# plotting function with highlight

# dataset index -> (predicted label, true label) for every misclassified sample
incorrect_dict = {idx: (yhat, ytrue) for idx, yhat, ytrue in incorrect_triplets}

def plot_all_layers_with_highlight(dataset_idx: int, point_size=10):
    """
    Draw one figure per cached layer and highlight a single sample in each.

    Every figure shows the cached decision regions, all samples in the layer's
    2-D PCA projection coloured by label, and a ring plus filled marker on the
    sample whose dataset index is `dataset_idx`.

    Args:
        dataset_idx: dataset index of the sample to highlight. If it is not a
            key of `index_to_row`, a warning is printed and nothing is drawn.
        point_size: marker size of the background scatter.
    """
    if dataset_idx not in index_to_row:
        print(f"[WARN] Index {dataset_idx} is not in the dataset.")
        return

    row = index_to_row[dataset_idx]
    pred_label, recorded_true_label = incorrect_dict.get(dataset_idx, (None, None))

    # Band edges at the half-integers -0.5 .. 9.5 give each class 0..9 its own
    # filled region; the default levels would merge neighbouring classes.
    class_levels = np.arange(11) - 0.5

    # Plot for each layer
    for name, blob in layer_cache.items():
        feats2d = blob["feats2d"]
        row_labels = blob["labels"]
        xx, yy, Z = blob["xx"], blob["yy"], blob["Z"]

        # A sample that is not in the misclassified list has no recorded true
        # label; use the cached per-row label so it can still be coloured.
        true_label = recorded_true_label
        if true_label is None:
            true_label = int(row_labels[row])

        plt.figure(figsize=(9, 7))
        plt.contourf(xx, yy, Z, levels=class_levels, alpha=0.35, cmap=plt.cm.tab10)
        # vmin/vmax pin the colour scale to 0..9 so the points use the same
        # colours as the highlight marker even if some class is absent.
        sc = plt.scatter(feats2d[:, 0], feats2d[:, 1],
                         c=row_labels, s=point_size, cmap=plt.cm.tab10,
                         vmin=0, vmax=9)

        # Highlight the specific point
        hx, hy = feats2d[row, 0], feats2d[row, 1]
        plt.scatter([hx], [hy], s=160, facecolors='none', edgecolors='k', linewidths=2.2, marker='o')
        plt.scatter([hx], [hy], s=40, c=np.array([true_label]),
                    cmap=plt.cm.tab10, vmin=0, vmax=9)  # fill with true label color

        plt.title(
            f"Decision Boundary at {name}\n"
            f"highlight idx={dataset_idx} | y={true_label}, ŷ={pred_label}"
        )
        plt.xlabel("PC1")
        plt.ylabel("PC2")
        plt.legend(handles=sc.legend_elements()[0], labels=list(range(10)), title="True label", loc="best")
        plt.show()


In [4]:
# ---- Interactive selection (fixed) ----
import ipywidgets as widgets
from ipywidgets import Layout
from IPython.display import display


# First element of every (idx, yhat, y) triplet is the dataset index.
incorrect_indices = [triplet[0] for triplet in incorrect_triplets]
incorrect_options = sorted(set(incorrect_indices))
if incorrect_options:
    dd = widgets.Dropdown(
        options=incorrect_options,
        description='Misclassified index:',
        layout={'width': '300px'},
        style={'description_width': 'initial'},
    )
    btn = widgets.Button(description='Plot', button_style='primary')
    out = widgets.Output()

    def on_click(_):
        """Redraw the per-layer plots for the index currently selected."""
        with out:
            out.clear_output(wait=True)
            plot_all_layers_with_highlight(int(dd.value))

    btn.on_click(on_click)

    controls = widgets.HBox([dd, btn])
    display(controls, out)

    # Draw once up front so the output area is not empty before the first click.
    with out:
        plot_all_layers_with_highlight(int(dd.value))
else:
    print("No misclassified samples found.")


HBox(children=(Dropdown(description='Misclassified index:', layout=Layout(width='300px'), options=(24, 72, 80,…

Output()