In [None]:
from typing import Literal

checkpoint_path = "out/phoneme-baseline-xl/best-val_bal_acc-baseline-xl-hpo-2-epoch=08-val_f1_macro=0.7019.ckpt"
# checkpoint_path = "out/phoneme-megt/best-val_bal_acc-megt-s-hpo-0-epoch=09-val_f1_macro=0.6605.ckpt"

split: Literal["train", "val"] = "train"

In [2]:
import os
from pathlib import Path
from libribrain_experiments.models.configurable_modules.classification_module import ClassificationModule
from pnpl.datasets.libribrain2025 import constants_utils

constants_utils.set_remote_constants_url(
    f"{(Path(os.getcwd()) / 'constants.json').as_uri()}")
constants_utils.refresh_constants()
model = ClassificationModule.load_from_checkpoint(checkpoint_path)

In [3]:
from lightning.pytorch.utilities.model_summary import ModelSummary

ModelSummary(model, max_depth=-1)

   | Name                         | Type                | Params | Mode 
------------------------------------------------------------------------------
0  | modules_list                 | ModuleList          | 6.1 M  | train
1  | modules_list.0               | Conv1d              | 548 K  | train
2  | modules_list.1               | ResnetBlock         | 262 K  | train
3  | modules_list.1.module_list   | ModuleList          | 262 K  | train
4  | modules_list.1.module_list.0 | ELU                 | 0      | train
5  | modules_list.1.module_list.1 | Conv1d              | 196 K  | train
6  | modules_list.1.module_list.2 | ELU                 | 0      | train
7  | modules_list.1.module_list.3 | Conv1d              | 65.8 K | train
8  | modules_list.2               | ELU                 | 0      | train
9  | modules_list.3               | Conv1d              | 196 K  | train
10 | modules_list.4               | ResnetBlock         | 262 K  | train
11 | modules_list.4.module_list   | ModuleLis

In [None]:
from pathlib import Path
from pnpl.datasets import LibriBrainPhoneme
from libribrain_experiments.grouped_dataset import MyGroupedDatasetV3

raw_visualization_dataset = LibriBrainPhoneme(
    data_path="./data/",
    tmin=0.0,
    tmax=0.5,
    standardize=True,
    partition="train" if split == "train" else "validation",
)
visualization_dataset = MyGroupedDatasetV3(
    raw_visualization_dataset,
    grouped_samples=100,
    drop_remaining=False,
    average_grouped_samples=True,
    state_cache_path=Path(f"./data_preprocessed/groupedv3/{split}_grouped_100.pt"),
    balance=True,
    # repeat=20,
    shuffle=True,
    # augment=True,  # Set to True if you want to apply data augmentation
)

# raw_visualization_dataset = LibriBrainPhoneme(
#     data_path="./data/",
#     tmin=0.0,
#     tmax=0.5,
#     standardize=True,
#     partition="validation",
# )
# visualization_dataset = MyGroupedDatasetV3(
#     raw_visualization_dataset,
#     grouped_samples=100,
#     drop_remaining=False,
#     average_grouped_samples=True,
#     state_cache_path=Path("./data_preprocessed/groupedv3/val_grouped_100.pt"),
#     balance=True,
#     # repeat=20,
#     shuffle=True,
#     # augment=True,  # Set to True if you want to apply data augmentation
# )

In [8]:
import plotly.graph_objs as go
from IPython.display import display
import ipywidgets as widgets
import numpy as np
import torch


# Sample & collect activations

activations = {}

sample_id = np.random.randint(0, len(visualization_dataset))
sample = visualization_dataset[sample_id][0].unsqueeze(0)
sample_label = visualization_dataset[sample_id][1].item()
sample_label_phoneme = raw_visualization_dataset.id_to_phoneme[sample_label]


def get_activation(name):
    def hook(model, input, output):
        activations[name] = output.detach().cpu()
    return hook


# Register hooks for each layer in modules_list
for idx, layer in enumerate(model.modules_list):
    layer.register_forward_hook(get_activation(f'layer_{idx}'))


model.eval()
with torch.no_grad():
    _ = model.forward(sample.to(model.device))


activations_per_row = 32

# Register hooks for each layer in modules_list
for idx, layer in enumerate(model.modules_list):
    layer.register_forward_hook(get_activation(f'layer_{idx}'))

plot_fig: go.FigureWidget = go.FigureWidget()
plot_fig_input: go.FigureWidget = go.FigureWidget()


# Plotting functions

def plot_input(sample):
    sample_np = sample.squeeze().cpu().numpy()
    plot_fig_input.data = []  # Clear previous traces
    if sample_np.ndim == 2:  # (channels, time)
        trace = go.Heatmap(z=sample_np,
                           colorscale='Viridis', name='Input',
                           colorbar=dict(title='Amplitude'))
        plot_fig_input.add_trace(trace)
        plot_fig_input.update_layout(
            title=f'Input Signal #{sample_id} - /{sample_label_phoneme}/ ({sample_label})',
            xaxis_title='Time', yaxis_title='Channel Index',
            # Set figure size to match plot
            width=5 * sample_np.shape[1], height=2.5 * sample_np.shape[0]
        )
    else:
        plot_fig_input.update_layout(title='Input (Unsupported shape)')
    plot_fig_input.show()


def plot_activation(layer_name):
    act = activations[layer_name]
    act_np = act.squeeze().cpu().numpy()
    plot_fig.data = []  # Clear previous traces
    if act_np.ndim == 2:  # (channels, time)
        plot_height = max(300, 2.5 * act_np.shape[0])
        plot_width = min(800, max(500, 5 * act_np.shape[1]))

        trace = go.Heatmap(z=act_np,
                           colorscale='Viridis', name=f'Convolutional Layer',
                           colorbar=dict(title='Activation'))

        plot_fig.add_trace(trace)
        plot_fig.update_layout(
            title=f'{layer_name} - Convolutional Layer',
            yaxis=dict(title='Channel Index', tickmode='linear',
                       dtick=min(50, act_np.shape[0] // 10)),
            xaxis=dict(title='Time', tickmode='linear',
                       dtick=min(20, act_np.shape[1] // 10)),
            width=plot_width, height=plot_height
        )
    elif act_np.ndim == 1:  # (features,)
        is_final_layer = act_np.shape[0] == 39

        # Pad act_np if not a multiple of activations_per_row
        pad_len = (-act_np.size % activations_per_row)
        if pad_len != 0:
            act_np = np.copy(act_np)  # Ensure we don't modify the original
            act_np = np.pad(act_np, (0, pad_len), constant_values=np.nan)
        act_np = act_np.reshape(-1, activations_per_row)

        plot_height = max(400, 25 * act_np.shape[0])
        plot_width = 25 * activations_per_row

        trace = go.Heatmap(z=act_np,
                           colorscale='Viridis', name='Features')
        plot_fig.add_trace(trace)

        if is_final_layer:
            # fill a square grid indexed by predicted_label
            predicted_label = np.nanargmax(act_np)
            predicted_label_x = predicted_label % activations_per_row
            predicted_label_y = predicted_label // activations_per_row
            trace = go.Scatter(
                x=[predicted_label_x - 0.5, predicted_label_x + 0.5, predicted_label_x +
                    0.5, predicted_label_x - 0.5, predicted_label_x - 0.5],
                y=[predicted_label_y - 0.5, predicted_label_y - 0.5, predicted_label_y +
                    0.5, predicted_label_y + 0.5, predicted_label_y - 0.5],
                mode='lines',
                line=dict(color='red', width=0),
                fill='toself',
                fillpattern=dict(
                    shape='x',
                    solidity=0.4,
                    size=6,
                    fgcolor='red',
                    bgcolor='rgba(0,0,0,0.1)'  # Semi-transparent background
                ),
                name='Predicted Label Index',
                zorder=10
            )
            plot_fig.add_trace(trace)

            # draw a red square outline around the grid indexed by sample_label
            sample_label_x = sample_label % activations_per_row
            sample_label_y = sample_label // activations_per_row
            trace = go.Scatter(
                x=[sample_label_x - 0.5, sample_label_x + 0.5, sample_label_x +
                    0.5, sample_label_x - 0.5, sample_label_x - 0.5],
                y=[sample_label_y - 0.5, sample_label_y - 0.5, sample_label_y +
                    0.5, sample_label_y + 0.5, sample_label_y - 0.5],
                mode='lines',
                line=dict(color='red', width=2),
                name='Sample Label Index',
                zorder=10
            )
            plot_fig.add_trace(trace)

        plot_fig.update_layout(
            title=f'{layer_name} - Dense' +
            f' (Final, {predicted_label}?{sample_label}!)' if is_final_layer else '',
            xaxis=dict(title='Feature Index', tickmode='linear', dtick=4),
            yaxis=dict(title='', tickmode="array",
                       tickvals=np.arange(0, act_np.shape[0], 4),
                       ticktext=np.arange(0, act_np.shape[0], 4) * activations_per_row),
            xaxis_tickangle=0,
            width=plot_width, height=plot_height)
    else:
        plot_fig.update_layout(title=f'{layer_name} (Unsupported shape)')
    plot_fig.show()


layer_selector = widgets.ToggleButtons(
    options=list(activations.keys()),
    value=list(activations.keys())[-1],
    description='Layer:'
)


def on_change(change):
    plot_activation(layer_selector.value)


layer_selector.observe(on_change, names='value')


display(widgets.VBox([
    widgets.HBox([layer_selector, ]),
    widgets.HBox([plot_fig_input, plot_fig])
]))
plot_input(sample)
plot_activation(layer_selector.value)

VBox(children=(HBox(children=(ToggleButtons(description='Layer:', index=14, options=('layer_0', 'layer_1', 'la…


Message serialization failed with:
Out of range float values are not JSON compliant: nan
Supporting this message is deprecated in jupyter-client 7, please make sure your message is JSON-compliant

