In [1]:
import numpy as np
import torch
import plotly.graph_objects as go
from pathlib import Path
import re

def plot_neuron_trajectories_3d(
    run_id: str,
    layer_idx: int,
    axes: tuple = (0, 1),
    neuron_indices: str = 'auto',
    num_neurons: int = 10,
    task_range: tuple = None,
    subsample_freq: int = 1,
    checkpoint_base_dir: str = "model_checkpoints"
):
    """Plot how hidden neurons move through (w[ax0], w[ax1], bias) space during training.

    Checkpoints are read from ``<checkpoint_base_dir>/<run_id>/`` and must be named
    ``task<T>_step<S>_model.pt``. Each one is loaded with ``torch.load`` and indexed as
    a state dict with keys ``layers.<i>.weight`` / ``layers.<i>.bias`` and, after the
    last hidden layer, ``output_layer.weight``. Every selected neuron becomes one 3-D
    line, with one point per checkpoint in (task, step) order. Lines are coloured by
    the neuron's outgoing weight a_k: the single outgoing weight when the next layer
    has one unit, otherwise the L2 norm of the neuron's outgoing weight column.

    Args:
        run_id: sub-directory of ``checkpoint_base_dir`` holding the run's checkpoints.
        layer_idx: hidden layer to plot; negative values count from the last layer.
        axes: the two input-weight columns used as the x and y coordinates.
        neuron_indices: ``'auto'`` picks the ``num_neurons`` neurons with the largest
            |a_k| at the last plotted checkpoint. ``'random'`` picks ``num_neurons``
            neurons with a fixed seed (42). Otherwise pass an explicit sequence
            (list, tuple or numpy array) of neuron indices.
        num_neurons: how many neurons ``'auto'``/``'random'`` pick (capped at layer width).
        task_range: inclusive ``(first_task, last_task)`` filter, or None for all tasks.
        subsample_freq: keep every ``subsample_freq``-th checkpoint after filtering.
            The final checkpoint may therefore be skipped.
        checkpoint_base_dir: root directory containing the per-run folders.

    Raises:
        ValueError: if the directory or its checkpoints are missing, or no checkpoint
            falls in ``task_range``. Also raised for an out-of-range ``layer_idx`` or
            ``axes``, an unknown ``neuron_indices`` mode, or an empty neuron selection.
    """
    checkpoint_dir = Path(checkpoint_base_dir) / run_id
    if not checkpoint_dir.exists():
        raise ValueError(f"Checkpoint directory not found: {checkpoint_dir}")

    model_files = sorted(checkpoint_dir.glob("task*_step*_model.pt"))
    if not model_files:
        raise ValueError(f"No model checkpoints found in {checkpoint_dir}")

    # The glob also matches non-numeric task/step parts; keep only names that parse fully.
    pattern = re.compile(r"task(\d+)_step(\d+)_model\.pt")
    parsed = []
    for f in model_files:
        m = pattern.fullmatch(f.name)
        if m:
            parsed.append((int(m.group(1)), int(m.group(2)), f))
    if not parsed:
        raise ValueError(f"No checkpoints named task<N>_step<N>_model.pt found in {checkpoint_dir}")
    # Numeric (task, step) order: 'task10' must follow 'task9', unlike the path sort.
    parsed.sort(key=lambda x: (x[0], x[1]))

    if task_range is not None:
        t_start, t_end = task_range
        parsed = [(t, s, f) for t, s, f in parsed if t_start <= t <= t_end]

    if subsample_freq > 1:
        parsed = parsed[::subsample_freq]

    if not parsed:
        raise ValueError("No checkpoints match the specified task range")

    first_state = torch.load(parsed[0][2], map_location='cpu')
    # Hidden layers are the keys of the exact form layers.<i>.weight.
    hidden_weight_key = re.compile(r"layers\.\d+\.weight")
    num_layers = sum(1 for k in first_state.keys() if hidden_weight_key.fullmatch(k))
    if layer_idx < 0:
        layer_idx = num_layers + layer_idx
    if layer_idx < 0 or layer_idx >= num_layers:
        raise ValueError(f"Invalid layer_idx {layer_idx}. Model has {num_layers} hidden layers.")

    weight_key = f'layers.{layer_idx}.weight'
    bias_key = f'layers.{layer_idx}.bias'
    # a_k is read from the weights leaving this layer: the next hidden layer or the output head.
    if layer_idx < num_layers - 1:
        next_weight_key = f'layers.{layer_idx + 1}.weight'
    else:
        next_weight_key = 'output_layer.weight'

    num_neurons_total = first_state[weight_key].shape[0]
    input_dim = first_state[weight_key].shape[1]
    ax0, ax1 = axes
    # Negative axes are valid numpy indices; anything outside [-input_dim, input_dim) is not.
    if not all(-input_dim <= a < input_dim for a in (ax0, ax1)):
        raise ValueError(f"Axes {axes} out of bounds for weight dim {input_dim}")

    num_snapshots = len(parsed)
    w_k_0 = np.zeros((num_snapshots, num_neurons_total))
    w_k_1 = np.zeros((num_snapshots, num_neurons_total))
    biases = np.zeros((num_snapshots, num_neurons_total))
    a_k_values = np.zeros((num_snapshots, num_neurons_total))
    task_ids = np.zeros(num_snapshots, dtype=int)
    step_ids = np.zeros(num_snapshots, dtype=int)

    for idx, (task, step, fpath) in enumerate(parsed):
        # The first checkpoint is already in memory; don't read it from disk twice.
        state = first_state if idx == 0 else torch.load(fpath, map_location='cpu')
        weights = state[weight_key].detach().numpy()
        bias = state[bias_key].detach().numpy()
        next_weights = state[next_weight_key].detach().numpy()

        w_k_0[idx] = weights[:, ax0]
        w_k_1[idx] = weights[:, ax1]
        biases[idx] = bias

        if next_weights.shape[0] == 1:
            # Single downstream unit: keep the signed outgoing weight.
            a_k_values[idx] = next_weights[0, :]
        else:
            # Several downstream units: magnitude of the neuron's outgoing column.
            a_k_values[idx] = np.linalg.norm(next_weights, axis=0)

        task_ids[idx] = task
        step_ids[idx] = step

    if isinstance(neuron_indices, str):
        # Clamp so a negative or oversized num_neurons cannot wrap the slice below.
        num_to_plot = max(0, min(num_neurons, num_neurons_total))
        if neuron_indices == 'auto':
            order = np.argsort(np.abs(a_k_values[-1]))
            # Explicit start index: order[-0:] would return every neuron when num_to_plot == 0.
            selected = order[len(order) - num_to_plot:].tolist()
        elif neuron_indices == 'random':
            rng = np.random.default_rng(42)
            selected = rng.choice(num_neurons_total, size=num_to_plot, replace=False).tolist()
        else:
            raise ValueError(
                f"Unknown neuron_indices mode {neuron_indices!r}; "
                "use 'auto', 'random' or a sequence of indices"
            )
    else:
        # Explicit indices; numpy arrays are accepted too (they cannot be compared to 'auto' with ==).
        selected = [int(i) for i in neuron_indices]
    if not selected:
        raise ValueError("No neurons selected to plot")

    # One colour range shared by all traces so colours are comparable between neurons.
    a_k_min = a_k_values[:, selected].min()
    a_k_max = a_k_values[:, selected].max()
    # Every trace has the same (task, step) hover data; build it once.
    hover_data = np.column_stack([task_ids, step_ids])

    traces = []
    for i, neuron_idx in enumerate(selected):
        w0_traj = w_k_0[:, neuron_idx]
        w1_traj = w_k_1[:, neuron_idx]
        bias_traj = biases[:, neuron_idx]
        a_k_traj = a_k_values[:, neuron_idx]

        show_colorbar = (i == 0)  # a single colorbar for the whole figure

        trace = go.Scatter3d(
            x=w0_traj,
            y=w1_traj,
            z=bias_traj,
            mode='lines+markers',
            line=dict(
                color=a_k_traj,
                colorscale='Viridis',
                width=4,
                showscale=show_colorbar,
                colorbar=dict(title="a_k", thickness=15, len=0.7, x=1.02) if show_colorbar else None,
                cmin=a_k_min,
                cmax=a_k_max
            ),
            marker=dict(
                size=2,
                color=a_k_traj,
                colorscale='Viridis',
                showscale=False,
                opacity=0.6,
                cmin=a_k_min,
                cmax=a_k_max
            ),
            hovertemplate=f'<b>Neuron {neuron_idx}</b><br>w[{ax0}]: %{{x:.3f}}<br>w[{ax1}]: %{{y:.3f}}<br>b: %{{z:.3f}}<br>a_k: %{{marker.color:.3f}}<br>task: %{{customdata[0]}}, step: %{{customdata[1]}}<extra></extra>',
            customdata=hover_data,
            showlegend=False
        )
        traces.append(trace)

    fig = go.Figure(data=traces)
    fig.update_layout(
        title=f'Layer {layer_idx} Neuron Trajectories (w[{ax0}], w[{ax1}], bias)',
        scene=dict(
            xaxis_title=f'w[{ax0}]',
            yaxis_title=f'w[{ax1}]',
            zaxis_title='bias',
            camera=dict(eye=dict(x=1.5, y=1.5, z=1.3))
        ),
        width=1000,
        height=800
    )
    fig.show()

In [None]:
# Checkpoint folders for the two runs compared below; the names differ in optimizer (SGD vs Muon).
run_id1, run_id2 = (
    "run_SGD_lr0.05_h100_L3_tanh_tasks100_steps500_20251204_225131_297694",
    "run_Muon_lr0.05_h100_L3_tanh_tasks100_steps500_20251204_225227_96b448",
)

In [3]:
# SGD run: layer-0 trajectories in (w[0], w[1], bias) space for up to 100 neurons, tasks 0-100.
plot_neuron_trajectories_3d(
    run_id1, 0,
    task_range=(0, 100),
    num_neurons=100,
    neuron_indices='auto',
    axes=(0, 1),
    subsample_freq=1,
)

In [5]:
# Muon run: same view as the SGD cell above so the two optimizers can be compared directly.
plot_neuron_trajectories_3d(
    run_id2, 0,
    task_range=(0, 100),
    num_neurons=100,
    neuron_indices='auto',
    axes=(0, 1),
    subsample_freq=1,
)

In [6]:
def plot_neuron_scatter_grid(
    run_id: str,
    task_range: tuple = (0, 9),
    grid_rows: int = 2,
    grid_cols: int = 5,
    checkpoint_base_dir: str = "model_checkpoints",
    layer_idx: int = 0,
    axes: tuple = (0, 1),
):
    """Show one hidden layer's neurons as 2-D scatter plots, one subplot per task.

    For each task in the inclusive ``task_range``, the checkpoint with the largest
    step number is read from ``<checkpoint_base_dir>/<run_id>/``; files are named
    ``task<T>_step<S>_model.pt``. Each row of ``layers.<layer_idx>.weight`` becomes
    a point at (w[ax0], w[ax1]), coloured by the matching entry of
    ``layers.<layer_idx>.bias``. All subplots share one colour range, so colours are
    comparable across tasks. Tasks without a checkpoint leave their subplot empty.

    Args:
        run_id: sub-directory of ``checkpoint_base_dir`` holding the run's checkpoints.
        task_range: inclusive ``(first_task, last_task)`` to plot.
        grid_rows: number of subplot rows.
        grid_cols: number of subplot columns. Tasks beyond ``grid_rows * grid_cols``
            are not plotted, and a warning is emitted.
        checkpoint_base_dir: root directory containing the per-run folders.
        layer_idx: non-negative index of the hidden layer to plot (default: first layer).
        axes: pair of weight columns used as the x and y coordinates.

    Raises:
        ValueError: if the run directory or its checkpoints are missing, or if
            ``layer_idx``/``axes`` do not fit the stored weights.
    """
    import warnings
    from plotly.subplots import make_subplots

    checkpoint_dir = Path(checkpoint_base_dir) / run_id
    if not checkpoint_dir.exists():
        raise ValueError(f"Checkpoint directory not found: {checkpoint_dir}")
    model_files = sorted(checkpoint_dir.glob("task*_step*_model.pt"))
    if not model_files:
        raise ValueError(f"No model checkpoints found in {checkpoint_dir}")

    pattern = re.compile(r"task(\d+)_step(\d+)_model\.pt")
    parsed = []
    for f in model_files:
        m = pattern.fullmatch(f.name)
        if m:
            parsed.append((int(m.group(1)), int(m.group(2)), f))
    parsed.sort(key=lambda x: (x[0], x[1]))

    t_start, t_end = task_range
    tasks_to_plot = list(range(t_start, t_end + 1))
    capacity = grid_rows * grid_cols
    if len(tasks_to_plot) > capacity:
        warnings.warn(
            f"task_range {task_range} covers {len(tasks_to_plot)} tasks but the "
            f"{grid_rows}x{grid_cols} grid has only {capacity} panels; "
            f"only the first {capacity} tasks are plotted."
        )
        tasks_to_plot = tasks_to_plot[:capacity]

    wanted = set(tasks_to_plot)
    # parsed is in (task, step) order, so the last assignment per task is its latest step.
    last_checkpoint_per_task = {}
    for task, _step, fpath in parsed:
        if task in wanted:
            last_checkpoint_per_task[task] = fpath

    weight_key = f'layers.{layer_idx}.weight'
    bias_key = f'layers.{layer_idx}.bias'
    ax0, ax1 = axes

    # Load each final checkpoint once, keeping only this layer's weights and biases.
    layer_params = {}
    for task in tasks_to_plot:
        fpath = last_checkpoint_per_task.get(task)
        if fpath is None:
            continue
        state = torch.load(fpath, map_location='cpu')
        if weight_key not in state or bias_key not in state:
            raise ValueError(f"{weight_key}/{bias_key} not found in {fpath}")
        weights = state[weight_key].numpy()
        in_dim = weights.shape[1]
        if not all(-in_dim <= a < in_dim for a in (ax0, ax1)):
            raise ValueError(f"Axes {axes} out of bounds for weight dim {in_dim}")
        layer_params[task] = (weights, state[bias_key].numpy())

    # Shared colour range across all plotted tasks.
    if layer_params:
        all_biases = np.concatenate([b for _, b in layer_params.values()])
        bias_min, bias_max = all_biases.min(), all_biases.max()
    else:
        bias_min, bias_max = 0, 1

    fig = make_subplots(
        rows=grid_rows, cols=grid_cols,
        subplot_titles=[f"Task {t}" for t in tasks_to_plot]
    )
    hovertemplate = (
        f'w[{ax0}]: %{{x:.3f}}<br>w[{ax1}]: %{{y:.3f}}'
        '<br>bias: %{marker.color:.3f}<extra></extra>'
    )

    colorbar_added = False
    for i, task in enumerate(tasks_to_plot):
        if task not in layer_params:
            continue
        weights, biases = layer_params[task]
        # One shared colorbar, attached to the first subplot that actually has data.
        show_scale = not colorbar_added
        colorbar_added = True
        fig.add_trace(
            go.Scatter(
                x=weights[:, ax0], y=weights[:, ax1],
                mode='markers',
                marker=dict(
                    color=biases,
                    colorscale='Viridis',
                    cmin=bias_min, cmax=bias_max,
                    showscale=show_scale,
                    colorbar=dict(title="bias", thickness=10, len=0.3, x=1.02, y=0.5) if show_scale else None
                ),
                hovertemplate=hovertemplate,
                showlegend=False
            ),
            row=i // grid_cols + 1, col=i % grid_cols + 1
        )

    title = 'First Layer Neurons per Task' if layer_idx == 0 else f'Layer {layer_idx} Neurons per Task'
    fig.update_layout(
        title=title,
        width=250 * grid_cols,
        height=250 * grid_rows
    )
    fig.show()

In [8]:
# Muon run: layer-0 neurons at the last checkpoint of each task, one panel per task.
# The 20x5 grid holds exactly 100 panels, so the range ends at 99; task 100 could never be drawn.
plot_neuron_scatter_grid(
    run_id=run_id2,
    task_range=(0, 99),
    grid_rows=20,
    grid_cols=5
)