In [77]:
import multiprocessing as mp 

from omegaconf import OmegaConf
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np
from jax.tree_util import tree_flatten_with_path, GetAttrKey, DictKey, SequenceKey

import matplotlib.pyplot as plt
import pandas as pd
from foundational_ssm.models import SSMFoundationalDecoder, SSMDownstreamDecoder, discretize_zoh
from foundational_ssm.utils import load_model_and_state_from_checkpoint_wandb

from foundational_ssm.samplers import RandomVariableWindowSampler 
from foundational_ssm.constants import DATA_ROOT 
from foundational_ssm.loaders import get_brainset_data_loader
from foundational_ssm.utils.training_utils import create_optimizer_and_state

mp.set_start_method("spawn", force=True) # otherwise causes deadlock on jax.

%load_ext autoreload
%autoreload 2


def mse_single_sample_foundation(model, state, input, target, mask, dataset_group_idxs, key, dataset_group_weights=None, skip_timesteps=0):
    """MSE loss for foundational model (takes dataset_group_idx and mask)"""
    pred, state = model(input, state, dataset_group_idxs, key)

    # Only evaluate loss on timesteps > skip_timesteps
    pred = pred[skip_timesteps:, :]  # Shape: (batch, seq_len - skip_timesteps, output_dim)
    target = target[skip_timesteps:, :]  # Shape: (batch, seq_len - skip_timesteps, output_dim)
    mask = mask[skip_timesteps:]  # Shape: (batch, seq_len - skip_timesteps)

    # Only compute loss on unmasked elements
    squared_error = (pred - target) ** 2
    mask = mask[..., None]
    masked_squared_error = jnp.where(mask, squared_error, 0.0)
    
    # dataset_group_weights = dataset_group_weights[..., None, None]  # shape (batch, 1, 1) to broadcast
    weighted_squared_error = squared_error #* dataset_group_weights
        
    masked_squared_error = jnp.where(mask, weighted_squared_error, 0.0)
    mse = masked_squared_error.sum() / mask.sum()
    return mse


def _path_to_str(path):
    parts = []
    for entry in path:
        if isinstance(entry, GetAttrKey):
            parts.append(str(entry.name))
        elif isinstance(entry, DictKey):
            parts.append(str(entry.key))
        elif isinstance(entry, SequenceKey):
            parts.append(f"[{entry.idx}]")
        else:
            parts.append(str(entry))
    return ".".join(parts)

def _stack_weighted_grads_to_matrix(grads_batch_tree, lr_leaves):
    """Return a dict of per-leaf weighted, flattened gradients keyed by leaf path.

    grads_batch_tree: PyTree of per-sample gradients; each array leaf has shape (N, ...).
    lr_leaves: list of scalars aligned with the flattened leaf order.
    Returns: dict[str, jnp.ndarray] mapping leaf path -> (N, leaf_size) matrix, weighted by sqrt(lr).
    """
    # Filter non-array leaves, keeping structure for alignment
    filtered = eqx.filter(grads_batch_tree, eqx.is_array)
    path_leaves, _ = tree_flatten_with_path(filtered)
    out = {}
    i = 0
    for path, leaf in path_leaves:
        if leaf is None:
            continue
        g2d = jnp.reshape(leaf, (leaf.shape[0], -1))
        lr_scalar = lr_leaves[i]
        out[_path_to_str(path)] = g2d * jnp.sqrt(lr_scalar)
        i += 1
    return out

def _dicts_to_matrix(dict_list):
    """Concatenate per-leaf dicts into a single matrix.

    - Uses a stable sorted key order from the first dict.
    - Horizontally concatenates leaves per sample, then vertically stacks across dicts.
    """
    if not dict_list:
        return jnp.zeros((0, 0))
    keys = sorted(dict_list[0].keys())
    row_blocks = []
    for d in dict_list:
        if not keys:
            return jnp.zeros((d[next(iter(d))].shape[0], 0)) if d else jnp.zeros((0, 0))
        row_blocks.append(jnp.concatenate([d[k] for k in keys], axis=1))
    return jnp.concatenate(row_blocks, axis=0)

def stack_dicts(dict_list):
    """Stack a list of dicts (per-batch) into a single dict of stacked arrays keyed by leaf path.

    Args:
        dict_list (list): list of dicts each containing array leaves with leading batch dim.

    Returns:
        dict mapping canonical leaf path (str) -> array of shape (sum_batches, leaf_size...)
    """
    if not dict_list:
        return {}
    keys = sorted(dict_list[0].keys())
    stacked = {k: jnp.concatenate([d[k] for d in dict_list], axis=0) for k in keys}
    return stacked

def stack_trees_to_dict(tree_list, flatten=True):
    """Stack a list of PyTrees (per-batch) into a dict of stacked arrays keyed by leaf path.

    Args:
        tree_list (list): list of PyTrees each containing array leaves with leading batch dim.
        flatten (bool): whether to reshape leaves to (batch, -1) before stacking.

    Returns:
        dict mapping canonical leaf path (str) -> array of shape (sum_batches, leaf_size...)

    Notes: Requires that all trees have the same leaf paths; raises ValueError otherwise.
    """
    if not tree_list:
        return {}
    dicts = []
    for tree in tree_list:
        filtered = eqx.filter(tree, eqx.is_array)
        path_leaves, _ = tree_flatten_with_path(filtered)
        d = {}
        for path, leaf in path_leaves:
            if leaf is None:
                continue
            if flatten:
                d[_path_to_str(path)] = jnp.reshape(leaf, (leaf.shape[0], -1))
            else:
                d[_path_to_str(path)] = leaf
        dicts.append(d)
    return stack_dicts(dicts)


def stack_trees_to_matrix(tree_list, flatten=True):
    """Convert a list of PyTrees into a single matrix by stacking leaves horizontally then samples vertically.

    This is useful for building G matrices for NTK or activation similarity computations.
    """
    stacked = stack_trees_to_dict(tree_list, flatten=flatten)
    if not stacked:
        return jnp.zeros((0, 0))
    keys = sorted(stacked.keys())
    return jnp.concatenate([stacked[k] for k in keys], axis=1)

def compute_kernel_similarity(K_1, K_2):
    """Compute similarity between two kernel matrices using Frobenius inner product."""
    frob_inner_product = np.sum(K_1 * K_2)
    frob_norm_1 = np.sqrt(np.sum(K_1 ** 2))
    frob_norm_2 = np.sqrt(np.sum(K_2 ** 2))
    similarity = 1 - frob_inner_product / (frob_norm_1 * frob_norm_2)
    return similarity

def compute_per_leaf_kernel_similarity(dict_a, dict_b):
    """Given two stacked dicts (key -> (N, D) array), compute per-leaf kernel similarity.

    Returns a dict key->scalar where scalar is the same metric used by compute_kernel_similarity
    applied to K_a = A @ A.T and K_b = B @ B.T.
    """
    keys = sorted(set(dict_a.keys()) & set(dict_b.keys()))
    out = {}
    for k in keys:
        A = dict_a[k]
        B = dict_b[k]
        # Ensure same shapes
        if A.shape[0] != B.shape[0]:
            raise ValueError(f"Batch sizes differ for leaf {k}: {A.shape[0]} vs {B.shape[0]}")
        K_a = A @ A.T
        K_b = B @ B.T
        out[k] = compute_kernel_similarity(K_a, K_b)
    return out

def compute_per_leaf_cosine_similarity(dict_a, dict_b):
    """Compute per-leaf cosine similarity by vectorizing each stacked array and computing cosine(v_a, v_b).

    Returns dict key -> cosine scalar in [-1,1]. If vectors are zero, returns 0.
    """
    keys = sorted(set(dict_a.keys()) & set(dict_b.keys()))
    out = {}
    for k in keys:
        A = dict_a[k]
        B = dict_b[k]
        if A.shape != B.shape:
            raise ValueError(f"Shapes differ for leaf {k}: {A.shape} vs {B.shape}")
        va = A.ravel()
        vb = B.ravel()
        na = np.linalg.norm(va)
        nb = np.linalg.norm(vb)
        if na == 0 or nb == 0:
            out[k] = 0.0
        else:
            inner = np.vdot(va, vb)
            cos_val = np.real(inner / (na * nb))  # use real part; or use np.abs(...) for magnitude
            out[k] = 1 - cos_val
    return out

def get_per_leaf_metadata_from_trees(tree_list):
    """Return dict key -> { 'shape': leaf_shape_without_batch, 'dtype': dtype } from the first tree in the list."""
    if not tree_list:
        return {}
    filtered = eqx.filter(tree_list[0], eqx.is_array)
    path_leaves, _ = tree_flatten_with_path(filtered)
    out = {}
    for path, leaf in path_leaves:
        if leaf is None:
            continue
        out[_path_to_str(path)] = {
            'shape': tuple(leaf.shape[1:]),
            'dtype': str(leaf.dtype),
        }
    return out

def mse_single_sample_downstream(model, state, input, target, mask, key, dataset_group_weights=None, skip_timesteps=0):
    """MSE loss for foundational model (takes dataset_group_idx and mask)"""
    pred, state = model(input, state, key)

    # Only evaluate loss on timesteps > skip_timesteps
    pred = pred[skip_timesteps:, :]  # Shape: (batch, seq_len - skip_timesteps, output_dim)
    target = target[skip_timesteps:, :]  # Shape: (batch, seq_len - skip_timesteps, output_dim)
    mask = mask[skip_timesteps:]  # Shape: (batch, seq_len - skip_timesteps)

    # Only compute loss on unmasked elements
    squared_error = (pred - target) ** 2
    mask = mask[..., None]
    masked_squared_error = jnp.where(mask, squared_error, 0.0)
    
    # dataset_group_weights = dataset_group_weights[..., None, None]  # shape (batch, 1, 1) to broadcast
    weighted_squared_error = squared_error #* dataset_group_weights
        
    masked_squared_error = jnp.where(mask, weighted_squared_error, 0.0)
    mse = masked_squared_error.sum() / mask.sum()
    print("Downstream MSE:", mse)
    return mse

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
dataset, loader, max_neural_input = get_brainset_data_loader(    
    dataset_args = {
        'keep_files_open': False,
        'lazy': True,
        'config': '/cs/student/projects1/ml/2024/mlaimon/foundational_ssm/configs/dataset/reaching_analysis.yaml'
    },
    dataloader_args={
        'batch_size': 16,
        'num_workers': 2,
        'persistent_workers': True,
    },
    sampler = 'TrialSampler',
    sampler_args = {
        'max_window_length': 5.0
    },
    data_root = '../' + DATA_ROOT,
    prepend_history = 0.3,
    sampling_rate = 200,
    split = 'train_trial_subsample'
 )


In [38]:
# Collect gradients and activations eqx.nn.inference_mode 
from tqdm import tqdm

layers = [2]
skip_timesteps = 0.3 * 200
model_cls = SSMFoundationalDecoder
models = {}
results = []
G_list = [] 
A_list = []
loss_fn = mse_single_sample_foundation 
epochs = list(range(0,501,100))
layer_keys = ["post_encoder","ssm_x", "ssm_y", "ssm_post_gelu", "ssm_post_glu", "pre_decoder"] 

for l in layers:
    checkpoint_name = f'melinajingting-ucl/foundational_ssm_pretrain/l{l}_reaching_normalized_checkpoint'
    model_cfg = OmegaConf.load(f'/cs/student/projects1/ml/2024/mlaimon/foundational_ssm/configs/model/l{l}.yaml')
    for epoch in epochs:
        artifact_full_name = f'{checkpoint_name}:epoch_{epoch}'
        model, state, meta = load_model_and_state_from_checkpoint_wandb(artifact_full_name, model_cls=model_cls)
        models.update({
            f'epoch {epoch}':model
        })  
        
        inf_model = eqx.nn.inference_mode(model)
        grads_list = []
        activations_list = []
        for batch in loader:
            batch = {k: jax.device_put(np.array(v)) for k, v in batch.items()}
            inputs = batch["neural_input"]
            targets = batch["behavior_input"]
            mask = batch["mask"]
            dataset_group_idxs = batch["dataset_group_idx"]

            grads = jax.vmap(eqx.filter_grad(loss_fn), in_axes=(None, None, 0, 0, 0, 0, None, None), axis_name='batch')(model, state, inputs, targets, mask, dataset_group_idxs, jr.PRNGKey(0), skip_timesteps)
            grads_list.append(grads)

            _, _, activations = jax.vmap(inf_model.call_with_activations, axis_name="batch", in_axes=(0, None, 0, None))(inputs, state, dataset_group_idxs, layer_keys)
            activations_list.append(activations)

        opt, opt_state, lr_scheduler, lr_tree = create_optimizer_and_state(model, optimizer_cfg=model_cfg.optimizer, model_cfg=model_cfg.model, return_lr_tree=True)
        lr_leaves = [leaf for leaf in jax.tree_util.tree_leaves(eqx.filter(lr_tree, eqx.is_array)) if leaf is not None]
        
        G_dict = stack_trees_to_dict(grads_list, flatten=True)
        G_dict = jax.device_get(G_dict)           # moves whole PyTree to host
        G_dict = {k: np.asarray(v) for k, v in G_dict.items()}

        A_dict = stack_dicts(activations_list)
        A_dict = jax.device_get(A_dict)
        A_dict = {k: np.asarray(v) for k, v in A_dict.items()}

        if epoch == 0:
            G_dict_0 = G_dict
            A_dict_0 = A_dict
        G_list.append(G_dict)
        A_list.append(A_dict)
        
        kernel_similarity = compute_per_leaf_kernel_similarity(G_dict, G_dict_0)
        activation_similarity = compute_per_leaf_cosine_similarity(A_dict, A_dict_0)
        _results = {
            'l': l,
            'epoch': epoch,
            'checkpoint_name': checkpoint_name.split('/')[-1]
        }
        _results.update(kernel_similarity)
        _results.update(activation_similarity)
        results.append(_results)

{'model': {'dt_max': 0.01, 'dt_min': 0.001, 'ssm_dim': 128, 'rng_seed': 42, 'dropout_p': 0.03, 'output_dim': 2, 'ssm_io_dim': 256, 'ssm_dropout_p': 0.01, 'ssm_num_layers': 2, 'ssm_init_diag_blocks': 4}, 'wandb': {'tags': ['neural', 'behavior', 'masking'], 'entity': 'melinajingting-ucl', 'project': 'foundational_ssm_pretrain', 'resume_run_id': None, 'run_name_postfix': '_normalized'}, 'rng_seed': 42, 'training': {'epochs': 1001, 'log_val_every': 50, 'checkpoint_every': 1}, 'model_cfg': 'configs/model/l2.yaml', 'optimizer': {'lr': 0.002, 'mode': 'all', 'weight_decay': 0.01}, 'val_loader': {'sampler': 'SequentialFixedWindowSampler', 'dataset_args': {'lazy': True, 'split': 'val', 'keep_files_open': False}, 'sampler_args': {'drop_short': False, 'window_length': 3.28, 'min_window_length': 0.88}, 'sampling_rate': 200, 'dataloader_args': {'batch_size': 1024, 'num_workers': 0, 'persistent_workers': False}}, 'dataset_cfg': 'configs/dataset/reaching.yaml', 'train_loader': {'sampler': 'RandomFixed

[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   3 of 3 files downloaded.  
  batch = {k: jax.device_put(np.array(v)) for k, v in batch.items()}
  batch = {k: jax.device_put(np.array(v)) for k, v in batch.items()}
  x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)
  x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)
  similarity = 1 - frob_inner_product / (frob_norm_1 * frob_norm_2)
  similarity = 1 - frob_inner_product / (frob_norm_1 * frob_norm_2)


{'model': {'dt_max': 0.01, 'dt_min': 0.001, 'ssm_dim': 128, 'rng_seed': 42, 'dropout_p': 0.03, 'output_dim': 2, 'ssm_io_dim': 256, 'ssm_dropout_p': 0.01, 'ssm_num_layers': 2, 'ssm_init_diag_blocks': 4}, 'wandb': {'tags': ['neural', 'behavior', 'masking'], 'entity': 'melinajingting-ucl', 'project': 'foundational_ssm_pretrain', 'resume_run_id': None, 'run_name_postfix': '_normalized'}, 'rng_seed': 42, 'training': {'epochs': 1001, 'log_val_every': 50, 'checkpoint_every': 1}, 'model_cfg': 'configs/model/l2.yaml', 'optimizer': {'lr': 0.002, 'mode': 'all', 'weight_decay': 0.01}, 'val_loader': {'sampler': 'SequentialFixedWindowSampler', 'dataset_args': {'lazy': True, 'split': 'val', 'keep_files_open': False}, 'sampler_args': {'drop_short': False, 'window_length': 3.28, 'min_window_length': 0.88}, 'sampling_rate': 200, 'dataloader_args': {'batch_size': 1024, 'num_workers': 0, 'persistent_workers': False}}, 'dataset_cfg': 'configs/dataset/reaching.yaml', 'train_loader': {'sampler': 'RandomFixed

[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   3 of 3 files downloaded.  


{'model': {'dt_max': 0.01, 'dt_min': 0.001, 'ssm_dim': 128, 'rng_seed': 42, 'dropout_p': 0.03, 'output_dim': 2, 'ssm_io_dim': 256, 'ssm_dropout_p': 0.01, 'ssm_num_layers': 2, 'ssm_init_diag_blocks': 4}, 'wandb': {'tags': ['neural', 'behavior', 'masking'], 'entity': 'melinajingting-ucl', 'project': 'foundational_ssm_pretrain', 'resume_run_id': None, 'run_name_postfix': '_normalized'}, 'rng_seed': 42, 'training': {'epochs': 1001, 'log_val_every': 50, 'checkpoint_every': 1}, 'model_cfg': 'configs/model/l2.yaml', 'optimizer': {'lr': 0.002, 'mode': 'all', 'weight_decay': 0.01}, 'val_loader': {'sampler': 'SequentialFixedWindowSampler', 'dataset_args': {'lazy': True, 'split': 'val', 'keep_files_open': False}, 'sampler_args': {'drop_short': False, 'window_length': 3.28, 'min_window_length': 0.88}, 'sampling_rate': 200, 'dataloader_args': {'batch_size': 1024, 'num_workers': 0, 'persistent_workers': False}}, 'dataset_cfg': 'configs/dataset/reaching.yaml', 'train_loader': {'sampler': 'RandomFixed

[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   3 of 3 files downloaded.  


{'model': {'dt_max': 0.01, 'dt_min': 0.001, 'ssm_dim': 128, 'rng_seed': 42, 'dropout_p': 0.03, 'output_dim': 2, 'ssm_io_dim': 256, 'ssm_dropout_p': 0.01, 'ssm_num_layers': 2, 'ssm_init_diag_blocks': 4}, 'wandb': {'tags': ['neural', 'behavior', 'masking'], 'entity': 'melinajingting-ucl', 'project': 'foundational_ssm_pretrain', 'resume_run_id': None, 'run_name_postfix': '_normalized'}, 'rng_seed': 42, 'training': {'epochs': 1001, 'log_val_every': 50, 'checkpoint_every': 1}, 'model_cfg': 'configs/model/l2.yaml', 'optimizer': {'lr': 0.002, 'mode': 'all', 'weight_decay': 0.01}, 'val_loader': {'sampler': 'SequentialFixedWindowSampler', 'dataset_args': {'lazy': True, 'split': 'val', 'keep_files_open': False}, 'sampler_args': {'drop_short': False, 'window_length': 3.28, 'min_window_length': 0.88}, 'sampling_rate': 200, 'dataloader_args': {'batch_size': 1024, 'num_workers': 0, 'persistent_workers': False}}, 'dataset_cfg': 'configs/dataset/reaching.yaml', 'train_loader': {'sampler': 'RandomFixed

[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   3 of 3 files downloaded.  


{'model': {'dt_max': 0.01, 'dt_min': 0.001, 'ssm_dim': 128, 'rng_seed': 42, 'dropout_p': 0.03, 'output_dim': 2, 'ssm_io_dim': 256, 'ssm_dropout_p': 0.01, 'ssm_num_layers': 2, 'ssm_init_diag_blocks': 4}, 'wandb': {'tags': ['neural', 'behavior', 'masking'], 'entity': 'melinajingting-ucl', 'project': 'foundational_ssm_pretrain', 'resume_run_id': None, 'run_name_postfix': '_normalized'}, 'rng_seed': 42, 'training': {'epochs': 1001, 'log_val_every': 50, 'checkpoint_every': 1}, 'model_cfg': 'configs/model/l2.yaml', 'optimizer': {'lr': 0.002, 'mode': 'all', 'weight_decay': 0.01}, 'val_loader': {'sampler': 'SequentialFixedWindowSampler', 'dataset_args': {'lazy': True, 'split': 'val', 'keep_files_open': False}, 'sampler_args': {'drop_short': False, 'window_length': 3.28, 'min_window_length': 0.88}, 'sampling_rate': 200, 'dataloader_args': {'batch_size': 1024, 'num_workers': 0, 'persistent_workers': False}}, 'dataset_cfg': 'configs/dataset/reaching.yaml', 'train_loader': {'sampler': 'RandomFixed

[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   3 of 3 files downloaded.  


{'model': {'dt_max': 0.01, 'dt_min': 0.001, 'ssm_dim': 128, 'rng_seed': 42, 'dropout_p': 0.03, 'output_dim': 2, 'ssm_io_dim': 256, 'ssm_dropout_p': 0.01, 'ssm_num_layers': 2, 'ssm_init_diag_blocks': 4}, 'wandb': {'tags': ['neural', 'behavior', 'masking'], 'entity': 'melinajingting-ucl', 'project': 'foundational_ssm_pretrain', 'resume_run_id': None, 'run_name_postfix': '_normalized'}, 'rng_seed': 42, 'training': {'epochs': 1001, 'log_val_every': 50, 'checkpoint_every': 1}, 'model_cfg': 'configs/model/l2.yaml', 'optimizer': {'lr': 0.002, 'mode': 'all', 'weight_decay': 0.01}, 'val_loader': {'sampler': 'SequentialFixedWindowSampler', 'dataset_args': {'lazy': True, 'split': 'val', 'keep_files_open': False}, 'sampler_args': {'drop_short': False, 'window_length': 3.28, 'min_window_length': 0.88}, 'sampling_rate': 200, 'dataloader_args': {'batch_size': 1024, 'num_workers': 0, 'persistent_workers': False}}, 'dataset_cfg': 'configs/dataset/reaching.yaml', 'train_loader': {'sampler': 'RandomFixed

[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   3 of 3 files downloaded.  


# Downstream task, from scratch

In [92]:
# Collect gradients and activations eqx.nn.inference_mode 
from tqdm import tqdm
from foundational_ssm.utils.downstream_utils import get_rtt_datasets
cfg = OmegaConf.create({
                        'model':{
                            'input_dim': 72,
                            'context_dim': 4,
                            'ssm_io_dim': 256,
                            'ssm_dim': 128,
                            'ssm_init_diag_blocks': 4,
                            'ssm_num_layers': 2,
                            'output_dim': 2,
                            'rng_seed': 42,
                            'dt_min': 0.001,
                            'dt_max': 0.01,
                            'dropout_p': 0.03,
                            'ssm_dropout_p': 0.01
                        },
                        'optimizer':{
                            'lr': 0.002,
                            'weight_decay': 0.01,
                            'mode': 'all'
                        }
                    })

dataset_cfg = OmegaConf.create({
    'train': '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_rtt_trialized_train.h5',
    'test': '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_rtt_trialized_val.h5',
    'batch_size': 64,
    'phase': 'test',
    'skip_timesteps': 60
})
train_data, val_data, data = get_rtt_datasets(dataset_cfg, jr.PRNGKey(0))
loss_fn = mse_single_sample_downstream 

model_cls = SSMDownstreamDecoder
checkpoint_name = 'melinajingting-ucl/foundational_ssm_rtt/l2_scratch_all_checkpoint'
models = {}
epochs = list(range(0,501,100))
train_trial_subset = {k: v[1:65] for k, v in train_data.items()}

for epoch in epochs:
    artifact_full_name = f'{checkpoint_name}:epoch_{epoch}'
    model, state, meta = load_model_and_state_from_checkpoint_wandb(artifact_full_name, model_cls=model_cls)
    models.update({
        f'epoch {epoch}':model
    })
    
    inf_model = eqx.nn.inference_mode(model)
    batch = train_trial_subset
    inputs = batch["neural_input"]
    targets = batch["behavior_input"]
    mask = batch["mask"]
    dataset_group_idxs = batch["dataset_group_idx"]
    grads = jax.vmap(eqx.filter_grad(loss_fn), in_axes=(None, None, 0, 0, 0, None, None), axis_name='batch')(model, state, inputs, targets, mask, jr.PRNGKey(0), skip_timesteps)
    opt, opt_state, lr_scheduler, lr_tree = create_optimizer_and_state(model, optimizer_cfg=cfg.optimizer, model_cfg=cfg.model, return_lr_tree=True)
    init_out, _, activations = jax.vmap(inf_model.call_with_activations, axis_name="batch", in_axes=(0, None, None))(inputs, state, layer_keys)
    
    G_dict = stack_trees_to_dict([grads], flatten=True)
    G_dict = jax.device_get(G_dict)           # moves whole PyTree to host
    G_dict = {k: np.asarray(v) for k, v in G_dict.items()}

    A_dict = jax.device_get(activations)
    A_dict = {k: np.asarray(v) for k, v in A_dict.items()}
    
    if epoch == 0:
        G_dict_0 = G_dict
        A_dict_0 = A_dict
    kernel_similarity = compute_per_leaf_kernel_similarity(G_dict, G_dict_0)
    activation_similarity = compute_per_leaf_cosine_similarity(A_dict, A_dict_0)
    _results = {
        'l': l,
        'epoch': epoch,
        'checkpoint_name': checkpoint_name.split('/')[-1]
    }
    _results.update(kernel_similarity)
    _results.update(activation_similarity)
    results.append(_results)


{'model': {'dt_max': 0.01, 'dt_min': 0.001, 'ssm_dim': 128, 'rng_seed': 42, 'dropout_p': 0.03, 'input_dim': 130, 'output_dim': 2, 'ssm_io_dim': 256, 'ssm_dropout_p': 0.01, 'ssm_num_layers': 2, 'ssm_init_diag_blocks': 4}, 'wandb': {'tags': ['neural', 'behavior', 'downstream', 'decoding', 'rtt'], 'entity': 'melinajingting-ucl', 'project': 'foundational_ssm_rtt'}, 'device': 'cuda', 'dataset': {'test': '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_rtt_trialized_val.h5', 'phase': 'test', 'train': '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_rtt_trialized_train.h5', 'batch_size': 64, 'skip_timesteps': 56}, 'rng_seed': 42, 'training': {'epochs': 1001, 'from_scratch': True, 'log_val_every': 100, 'checkpoint_every': 100, 'save_activations': False, 'save_checkpoints': True, 'log_pred_and_activations_every': 999}, 'optimizer': {'lr': 0.002, 'mode': 'all', 'weight_decay': 0.01}, 'optimizer.mode': 'all', 'training.from_scratch': True}

[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   3 of 3 files downloaded.  


Downstream MSE: LinearizeTracer<float32[]>


  x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)


{'model': {'dt_max': 0.01, 'dt_min': 0.001, 'ssm_dim': 128, 'rng_seed': 42, 'dropout_p': 0.03, 'input_dim': 130, 'output_dim': 2, 'ssm_io_dim': 256, 'ssm_dropout_p': 0.01, 'ssm_num_layers': 2, 'ssm_init_diag_blocks': 4}, 'wandb': {'tags': ['neural', 'behavior', 'downstream', 'decoding', 'rtt'], 'entity': 'melinajingting-ucl', 'project': 'foundational_ssm_rtt'}, 'device': 'cuda', 'dataset': {'test': '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_rtt_trialized_val.h5', 'phase': 'test', 'train': '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_rtt_trialized_train.h5', 'batch_size': 64, 'skip_timesteps': 56}, 'rng_seed': 42, 'training': {'epochs': 1001, 'from_scratch': True, 'log_val_every': 100, 'checkpoint_every': 100, 'save_activations': False, 'save_checkpoints': True, 'log_pred_and_activations_every': 999}, 'optimizer': {'lr': 0.002, 'mode': 'all', 'weight_decay': 0.01}, 'optimizer.mode': 'all', 'model.checkpoint': 'melinajin

[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   3 of 3 files downloaded.  


Downstream MSE: LinearizeTracer<float32[]>
{'model': {'dt_max': 0.01, 'dt_min': 0.001, 'ssm_dim': 128, 'rng_seed': 42, 'dropout_p': 0.03, 'input_dim': 130, 'output_dim': 2, 'ssm_io_dim': 256, 'ssm_dropout_p': 0.01, 'ssm_num_layers': 2, 'ssm_init_diag_blocks': 4}, 'wandb': {'tags': ['neural', 'behavior', 'downstream', 'decoding', 'rtt'], 'entity': 'melinajingting-ucl', 'project': 'foundational_ssm_rtt'}, 'device': 'cuda', 'dataset': {'test': '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_rtt_trialized_val.h5', 'phase': 'test', 'train': '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_rtt_trialized_train.h5', 'batch_size': 64, 'skip_timesteps': 56}, 'rng_seed': 42, 'training': {'epochs': 1001, 'from_scratch': True, 'log_val_every': 100, 'checkpoint_every': 100, 'save_activations': False, 'save_checkpoints': True, 'log_pred_and_activations_every': 999}, 'optimizer': {'lr': 0.002, 'mode': 'all', 'weight_decay': 0.01}, 'optimizer.m

[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   3 of 3 files downloaded.  


Downstream MSE: LinearizeTracer<float32[]>
{'model': {'dt_max': 0.01, 'dt_min': 0.001, 'ssm_dim': 128, 'rng_seed': 42, 'dropout_p': 0.03, 'input_dim': 130, 'output_dim': 2, 'ssm_io_dim': 256, 'ssm_dropout_p': 0.01, 'ssm_num_layers': 2, 'ssm_init_diag_blocks': 4}, 'wandb': {'tags': ['neural', 'behavior', 'downstream', 'decoding', 'rtt'], 'entity': 'melinajingting-ucl', 'project': 'foundational_ssm_rtt'}, 'device': 'cuda', 'dataset': {'test': '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_rtt_trialized_val.h5', 'phase': 'test', 'train': '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_rtt_trialized_train.h5', 'batch_size': 64, 'skip_timesteps': 56}, 'rng_seed': 42, 'training': {'epochs': 1001, 'from_scratch': True, 'log_val_every': 100, 'checkpoint_every': 100, 'save_activations': False, 'save_checkpoints': True, 'log_pred_and_activations_every': 999}, 'optimizer': {'lr': 0.002, 'mode': 'all', 'weight_decay': 0.01}, 'optimizer.m

[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   3 of 3 files downloaded.  


Downstream MSE: LinearizeTracer<float32[]>
{'model': {'dt_max': 0.01, 'dt_min': 0.001, 'ssm_dim': 128, 'rng_seed': 42, 'dropout_p': 0.03, 'input_dim': 130, 'output_dim': 2, 'ssm_io_dim': 256, 'ssm_dropout_p': 0.01, 'ssm_num_layers': 2, 'ssm_init_diag_blocks': 4}, 'wandb': {'tags': ['neural', 'behavior', 'downstream', 'decoding', 'rtt'], 'entity': 'melinajingting-ucl', 'project': 'foundational_ssm_rtt'}, 'device': 'cuda', 'dataset': {'test': '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_rtt_trialized_val.h5', 'phase': 'test', 'train': '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_rtt_trialized_train.h5', 'batch_size': 64, 'skip_timesteps': 56}, 'rng_seed': 42, 'training': {'epochs': 1001, 'from_scratch': True, 'log_val_every': 100, 'checkpoint_every': 100, 'save_activations': False, 'save_checkpoints': True, 'log_pred_and_activations_every': 999}, 'optimizer': {'lr': 0.002, 'mode': 'all', 'weight_decay': 0.01}, 'optimizer.m

[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   3 of 3 files downloaded.  


Downstream MSE: LinearizeTracer<float32[]>


  frob_norm_1 = np.sqrt(np.sum(K_1 ** 2))


{'model': {'dt_max': 0.01, 'dt_min': 0.001, 'ssm_dim': 128, 'rng_seed': 42, 'dropout_p': 0.03, 'input_dim': 130, 'output_dim': 2, 'ssm_io_dim': 256, 'ssm_dropout_p': 0.01, 'ssm_num_layers': 2, 'ssm_init_diag_blocks': 4}, 'wandb': {'tags': ['neural', 'behavior', 'downstream', 'decoding', 'rtt'], 'entity': 'melinajingting-ucl', 'project': 'foundational_ssm_rtt'}, 'device': 'cuda', 'dataset': {'test': '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_rtt_trialized_val.h5', 'phase': 'test', 'train': '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_rtt_trialized_train.h5', 'batch_size': 64, 'skip_timesteps': 56}, 'rng_seed': 42, 'training': {'epochs': 1001, 'from_scratch': True, 'log_val_every': 100, 'checkpoint_every': 100, 'save_activations': False, 'save_checkpoints': True, 'log_pred_and_activations_every': 999}, 'optimizer': {'lr': 0.002, 'mode': 'all', 'weight_decay': 0.01}, 'optimizer.mode': 'all', 'model.checkpoint': 'melinajin

[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   3 of 3 files downloaded.  


Downstream MSE: LinearizeTracer<float32[]>


# Downstream Task, From Finetuning

In [93]:
checkpoint_name = 'melinajingting-ucl/foundational_ssm_rtt/l2_reaching_normalized_all_checkpoint'
train_trial_subset = {k: v[1:65] for k, v in train_data.items()}
for epoch in epochs:
    artifact_full_name = f'{checkpoint_name}:epoch_{epoch}'
    model, state, meta = load_model_and_state_from_checkpoint_wandb(artifact_full_name, model_cls=model_cls)
    batch = train_trial_subset
    inputs = batch["neural_input"]
    targets = batch["behavior_input"]
    mask = batch["mask"]
    dataset_group_idxs = batch["dataset_group_idx"]
    grads = jax.vmap(eqx.filter_grad(loss_fn), in_axes=(None, None, 0, 0, 0, None, None), axis_name='batch')(model, state, inputs, targets, mask, jr.PRNGKey(0), skip_timesteps)
    opt, opt_state, lr_scheduler, lr_tree = create_optimizer_and_state(model, optimizer_cfg=cfg.optimizer, model_cfg=cfg.model, return_lr_tree=True)
    inf_model = eqx.nn.inference_mode(model)
    init_out, _, activations = jax.vmap(inf_model.call_with_activations, axis_name="batch", in_axes=(0, None, None))(inputs, state, layer_keys)
    
    G_dict = stack_trees_to_dict([grads], flatten=True)
    G_dict = jax.device_get(G_dict)           # moves whole PyTree to host
    G_dict = {k: np.asarray(v) for k, v in G_dict.items()}

    A_dict = jax.device_get(activations)
    A_dict = {k: np.asarray(v) for k, v in A_dict.items()}
    
    if epoch == 0:
        G_dict_0 = G_dict
        A_dict_0 = A_dict
    G_list.append(G_dict)
    A_list.append(A_dict)
    kernel_similarity = compute_per_leaf_kernel_similarity(G_dict, G_dict_0)
    activation_similarity = compute_per_leaf_cosine_similarity(A_dict, A_dict_0)
    _results = {
        'l': l,
        'epoch': epoch,
        'checkpoint_name': checkpoint_name.split('/')[-1]
    }
    _results.update(kernel_similarity)
    _results.update(activation_similarity)
    results.append(_results)

{'model': {'dt_max': 0.01, 'dt_min': 0.001, 'ssm_dim': 128, 'rng_seed': 42, 'dropout_p': 0.03, 'input_dim': 130, 'output_dim': 2, 'ssm_io_dim': 256, 'ssm_dropout_p': 0.01, 'ssm_num_layers': 2, 'ssm_init_diag_blocks': 4}, 'wandb': {'tags': ['neural', 'behavior', 'downstream', 'decoding', 'rtt'], 'entity': 'melinajingting-ucl', 'project': 'foundational_ssm_rtt'}, 'device': 'cuda', 'dataset': {'test': '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_rtt_trialized_val.h5', 'phase': 'test', 'train': '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_rtt_trialized_train.h5', 'batch_size': 64, 'skip_timesteps': 56}, 'rng_seed': 42, 'training': {'epochs': 1001, 'from_scratch': False, 'log_val_every': 100, 'checkpoint_every': 100, 'save_activations': False, 'save_checkpoints': True, 'log_pred_and_activations_every': 999}, 'optimizer': {'lr': 0.002, 'mode': 'all', 'weight_decay': 0.01}}


[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   3 of 3 files downloaded.  


Downstream MSE: LinearizeTracer<float32[]>
{'model': {'dt_max': 0.01, 'dt_min': 0.001, 'ssm_dim': 128, 'rng_seed': 42, 'dropout_p': 0.03, 'input_dim': 130, 'output_dim': 2, 'ssm_io_dim': 256, 'ssm_dropout_p': 0.01, 'ssm_num_layers': 2, 'ssm_init_diag_blocks': 4}, 'wandb': {'tags': ['neural', 'behavior', 'downstream', 'decoding', 'rtt'], 'entity': 'melinajingting-ucl', 'project': 'foundational_ssm_rtt'}, 'device': 'cuda', 'dataset': {'test': '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_rtt_trialized_val.h5', 'phase': 'test', 'train': '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_rtt_trialized_train.h5', 'batch_size': 64, 'skip_timesteps': 56}, 'rng_seed': 42, 'training': {'epochs': 1001, 'from_scratch': False, 'log_val_every': 100, 'checkpoint_every': 100, 'save_activations': False, 'save_checkpoints': True, 'log_pred_and_activations_every': 999}, 'optimizer': {'lr': 0.002, 'mode': 'all', 'weight_decay': 0.01}, 'optimizer.

[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   3 of 3 files downloaded.  


Downstream MSE: LinearizeTracer<float32[]>


  frob_norm_1 = np.sqrt(np.sum(K_1 ** 2))


{'model': {'dt_max': 0.01, 'dt_min': 0.001, 'ssm_dim': 128, 'rng_seed': 42, 'dropout_p': 0.03, 'input_dim': 130, 'output_dim': 2, 'ssm_io_dim': 256, 'ssm_dropout_p': 0.01, 'ssm_num_layers': 2, 'ssm_init_diag_blocks': 4}, 'wandb': {'tags': ['neural', 'behavior', 'downstream', 'decoding', 'rtt'], 'entity': 'melinajingting-ucl', 'project': 'foundational_ssm_rtt'}, 'device': 'cuda', 'dataset': {'test': '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_rtt_trialized_val.h5', 'phase': 'test', 'train': '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_rtt_trialized_train.h5', 'batch_size': 64, 'skip_timesteps': 56}, 'rng_seed': 42, 'training': {'epochs': 1001, 'from_scratch': False, 'log_val_every': 100, 'checkpoint_every': 100, 'save_activations': False, 'save_checkpoints': True, 'log_pred_and_activations_every': 999}, 'optimizer': {'lr': 0.002, 'mode': 'all', 'weight_decay': 0.01}, 'optimizer.mode': 'all', 'model.checkpoint': 'melinaji

[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   3 of 3 files downloaded.  


Downstream MSE: LinearizeTracer<float32[]>
{'model': {'dt_max': 0.01, 'dt_min': 0.001, 'ssm_dim': 128, 'rng_seed': 42, 'dropout_p': 0.03, 'input_dim': 130, 'output_dim': 2, 'ssm_io_dim': 256, 'ssm_dropout_p': 0.01, 'ssm_num_layers': 2, 'ssm_init_diag_blocks': 4}, 'wandb': {'tags': ['neural', 'behavior', 'downstream', 'decoding', 'rtt'], 'entity': 'melinajingting-ucl', 'project': 'foundational_ssm_rtt'}, 'device': 'cuda', 'dataset': {'test': '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_rtt_trialized_val.h5', 'phase': 'test', 'train': '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_rtt_trialized_train.h5', 'batch_size': 64, 'skip_timesteps': 56}, 'rng_seed': 42, 'training': {'epochs': 1001, 'from_scratch': False, 'log_val_every': 100, 'checkpoint_every': 100, 'save_activations': False, 'save_checkpoints': True, 'log_pred_and_activations_every': 999}, 'optimizer': {'lr': 0.002, 'mode': 'all', 'weight_decay': 0.01}, 'optimizer.

[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   3 of 3 files downloaded.  


Downstream MSE: LinearizeTracer<float32[]>
{'model': {'dt_max': 0.01, 'dt_min': 0.001, 'ssm_dim': 128, 'rng_seed': 42, 'dropout_p': 0.03, 'input_dim': 130, 'output_dim': 2, 'ssm_io_dim': 256, 'ssm_dropout_p': 0.01, 'ssm_num_layers': 2, 'ssm_init_diag_blocks': 4}, 'wandb': {'tags': ['neural', 'behavior', 'downstream', 'decoding', 'rtt'], 'entity': 'melinajingting-ucl', 'project': 'foundational_ssm_rtt'}, 'device': 'cuda', 'dataset': {'test': '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_rtt_trialized_val.h5', 'phase': 'test', 'train': '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_rtt_trialized_train.h5', 'batch_size': 64, 'skip_timesteps': 56}, 'rng_seed': 42, 'training': {'epochs': 1001, 'from_scratch': False, 'log_val_every': 100, 'checkpoint_every': 100, 'save_activations': False, 'save_checkpoints': True, 'log_pred_and_activations_every': 999}, 'optimizer': {'lr': 0.002, 'mode': 'all', 'weight_decay': 0.01}, 'optimizer.

[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   3 of 3 files downloaded.  


Downstream MSE: LinearizeTracer<float32[]>
{'model': {'dt_max': 0.01, 'dt_min': 0.001, 'ssm_dim': 128, 'rng_seed': 42, 'dropout_p': 0.03, 'input_dim': 130, 'output_dim': 2, 'ssm_io_dim': 256, 'ssm_dropout_p': 0.01, 'ssm_num_layers': 2, 'ssm_init_diag_blocks': 4}, 'wandb': {'tags': ['neural', 'behavior', 'downstream', 'decoding', 'rtt'], 'entity': 'melinajingting-ucl', 'project': 'foundational_ssm_rtt'}, 'device': 'cuda', 'dataset': {'test': '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_rtt_trialized_val.h5', 'phase': 'test', 'train': '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_rtt_trialized_train.h5', 'batch_size': 64, 'skip_timesteps': 56}, 'rng_seed': 42, 'training': {'epochs': 1001, 'from_scratch': False, 'log_val_every': 100, 'checkpoint_every': 100, 'save_activations': False, 'save_checkpoints': True, 'log_pred_and_activations_every': 999}, 'optimizer': {'lr': 0.002, 'mode': 'all', 'weight_decay': 0.01}, 'optimizer.

[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   3 of 3 files downloaded.  


Downstream MSE: LinearizeTracer<float32[]>


In [None]:
results_df = pd.DataFrame(results) 
# results_df = results_df[results_df['checkpoint_name'] == 'l2_reaching_normalized_checkpoint']

In [None]:
# Heatmaps: NTK (kernel) and Activations (cosine) without subplots, with publication-friendly labels and encoder index averaging
import re
import numpy as np
import matplotlib.pyplot as plt

df = results_df.copy()

# Identify columns
meta_cols = {'l','epoch'}
all_cols = [c for c in df.columns if c not in meta_cols]
kernel_cols = [c for c in all_cols if ('.' in c) and ('bias' not in c) and ('encoders.[9]' not in c)]
activation_cols = [c for c in all_cols if '.' not in c]

epochs_sorted = sorted(df['epoch'].unique())
layers_sorted = sorted(df['l'].unique())

def _split_path_segments(col: str):
    # Split on '.' but keep bracket indices as separate tokens
    # e.g., 'encoders.[3].weight' -> ['encoders', '[3]', 'weight']
    parts = []
    for token in re.split(r'(\[\d+\])|\.', col):
        if token and token != '.':
            parts.append(token)
    return parts

def _collapse_indices(parts):
    # Collapse bracket tokens into previous token: ['encoders', '[3]', 'weight'] -> ['encoders[3]', 'weight']
    out = []
    for p in parts:
        if re.fullmatch(r'\[\d+\]', p):
            if out:
                out[-1] = f"{out[-1]}{p}"
            else:
                out.append(p)
        else:
            out.append(p)
    return out

def pretty_kernel_label(col: str, max_depth: int = 3):
    parts = _collapse_indices(_split_path_segments(col))
    # Keep only the last `max_depth` segments for brevity
    if 'glu' in parts and len(parts) > max_depth:
        parts = parts[-max_depth+1:]
    elif len(parts) > max_depth:
        parts = parts[-max_depth:]
    # Optional light abbreviations
    parts = [p.replace('weight', 'W').replace('kernel', 'K') for p in parts]
    return ' · '.join(parts)

def _title_case_preserve_acronyms(s: str):
    # Title-case tokens but keep common acronyms uppercased
    acronyms = {'ssm','io','gelu','glu','rtt','nlb'}
    toks = re.split(r'[_\s]+', s)
    out = []
    for t in toks:
        if t.lower() in acronyms:
            out.append(t.upper())
        else:
            # keep single-letter variables (x,y) lower
            out.append(t if len(t) == 1 else t.capitalize())
    return ' '.join(out)

def pretty_activation_label(col: str):
    m = re.match(r'^(.*)_(\d+)$', col)
    if m:
        base, idx = m.group(1), m.group(2)
        base_pretty = _title_case_preserve_acronyms(base)
        return f"{base_pretty} [idx={idx}]"
    # Otherwise just prettify underscores/case
    return _title_case_preserve_acronyms(col)

def format_row_label(col: str, lval: int, kind: str):
    if kind == 'kernel':
        comp = pretty_kernel_label(col)
    else:
        comp = pretty_activation_label(col)
    return f"{comp}"

def _encoder_base_key(col: str) -> str:
    # If this is an encoder param, strip numeric indices: encoders.[3].weight -> encoders.[].weight
    if re.match(r'^encoders\.', col):
        return re.sub(r'\[\d+\]', '[]', col)
    return col

def build_heatmap_data(df, columns, kind='kernel'):
    rows = []  # list of (row_label, lval, group_cols)
    cols_set = set(df.columns)
    if kind == 'kernel':
        # Group encoder columns by base (index removed) and keep others 1:1
        groups = {}  # base_key -> list of original columns
        for c in columns:
            base = _encoder_base_key(c)
            groups.setdefault(base, []).append(c)
        for base in sorted(groups.keys()):
            present_cols = [c for c in groups[base] if c in cols_set]
            if not present_cols:
                continue
            for lval in layers_sorted:
                rows.append((format_row_label(base, lval, kind), lval, present_cols))
    else:  # activation
        for c in columns:
            if c not in cols_set:
                continue
            for lval in layers_sorted:
                rows.append((format_row_label(c, lval, kind), lval, [c]))
    # Build matrix
    row_labels = [r[0] for r in rows]
    M = np.full((len(rows), len(epochs_sorted)), np.nan, dtype=float)
    for i, (label, lval, group_cols) in enumerate(rows):
        # Subset for the layer once (safe: all group_cols are present)
        sub = df[df['l'] == lval][['epoch'] + group_cols]
        if sub.empty:
            continue
        for j, ep in enumerate(epochs_sorted):
            se = sub[sub['epoch'] == ep]
            if se.empty:
                continue
            vals = se[group_cols].to_numpy().astype(float).ravel()
            if np.isnan(vals).all():
                continue
            M[i, j] = np.nanmean(vals)
    return row_labels, M

def plot_heatmap(row_labels, M, title, cmap='viridis', vmin=None, vmax=None, save_path=None):
    # Auto-size figure based on number of rows and columns
    nrows = len(row_labels)
    ncols = len(epochs_sorted)
    height = max(4.0, min(1.0 + 0.35 * nrows, 24.0))
    width = max(6.0, min(1.0 + 0.6 * ncols, 24.0))
    fig, ax = plt.subplots(figsize=(width, height), dpi=200)
    im = ax.imshow(M, aspect='auto', cmap=cmap, vmin=vmin, vmax=vmax)
    ax.set_title(title)
    ax.set_xlabel('epoch')
    ax.set_ylabel('component')
    # Ticks
    ax.set_xticks(np.arange(len(epochs_sorted)))
    ax.set_xticklabels([str(e) for e in epochs_sorted], rotation=45, ha='right')
    ax.set_yticks(np.arange(len(row_labels)))
    ax.set_yticklabels(row_labels, fontsize=7)
    ax.grid(False)
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label('value')
    # Adjust left margin based on max label length for readability
    max_len = max((len(s) for s in row_labels), default=0)
    left_margin = min(0.55, 0.12 + 0.0065 * max_len)
    fig.subplots_adjust(left=left_margin)
    fig.tight_layout()
    if save_path:
        plt.savefig(save_path)
    plt.show()


In [104]:
import os
df = results_df[results_df['checkpoint_name'] == 'l2_reaching_normalized_checkpoint'] \
    .drop(columns=['checkpoint_name']) \
    .dropna(axis=1, how='any')
figures_directory = '/cs/student/projects1/ml/2024/mlaimon/UCL-ML-Thesis/Writeup/figures'
# Build and plot NTK heatmap
kernel_row_labels, kernel_M = build_heatmap_data(df, kernel_cols, kind='kernel')
plot_heatmap(kernel_row_labels, kernel_M, title='Evolution of NTK Matrix similarity during pretraining', save_path=os.path.join(figures_directory, 'ntk_heatmap_pretrain.pdf'))

# Build and plot activation heatmap
act_row_labels, act_M = build_heatmap_data(df, activation_cols, kind='activation')
plot_heatmap(act_row_labels, act_M, title="Evolution of (1 - cosine similarity of activations) during pretraining", save_path=os.path.join(figures_directory, 'activation_heatmap_pretrain.pdf'))

KeyError: "None of [Index(['[0].context_embedding.weight'], dtype='object')] are in the [columns]"

In [102]:
df.keys()

Index(['l', 'epoch', 'context_embedding.weight', 'decoder.bias',
       'decoder.weight', 'encoders.[0].bias', 'encoders.[0].weight',
       'encoders.[1].bias', 'encoders.[1].weight', 'encoders.[2].bias',
       'encoders.[2].weight', 'encoders.[3].bias', 'encoders.[3].weight',
       'encoders.[4].bias', 'encoders.[4].weight', 'encoders.[5].bias',
       'encoders.[5].weight', 'encoders.[6].bias', 'encoders.[6].weight',
       'encoders.[7].bias', 'encoders.[7].weight', 'encoders.[8].bias',
       'encoders.[8].weight', 'ssm_blocks.[0].glu.w1.bias',
       'ssm_blocks.[0].glu.w1.weight', 'ssm_blocks.[0].glu.w2.bias',
       'ssm_blocks.[0].glu.w2.weight', 'ssm_blocks.[0].ssm.B',
       'ssm_blocks.[0].ssm.C', 'ssm_blocks.[0].ssm.D',
       'ssm_blocks.[0].ssm.Lambda_im', 'ssm_blocks.[0].ssm.Lambda_re',
       'ssm_blocks.[0].ssm.log_step', 'ssm_blocks.[1].glu.w1.bias',
       'ssm_blocks.[1].glu.w1.weight', 'ssm_blocks.[1].glu.w2.bias',
       'ssm_blocks.[1].glu.w2.weight', 'ssm_blo

In [None]:
df = results_df[results_df['checkpoint_name'] == 'l2_scra']
# Build and plot NTK heatmap
kernel_row_labels, kernel_M = build_heatmap_data(df, kernel_cols, kind='kernel')
plot_heatmap(kernel_row_labels, kernel_M, title='Evolution of NTK Matrix similarity during pretraining')

# Build and plot activation heatmap
act_row_labels, act_M = build_heatmap_data(df, activation_cols, kind='activation')
plot_heatmap(act_row_labels, act_M, title="Evolution of (1 - cosine similarity of activations) during pretraining")