In [1]:
import wandb
import equinox as eqx
import os 

# Foundational SSM imports
from omegaconf import OmegaConf
import tempfile 
from foundational_ssm.models import SSMDownstreamDecoder, SSMFoundationalDecoder
from foundational_ssm.utils import h5_to_dict
from foundational_ssm.transform import smooth_spikes
import jax
import jax.numpy as jnp
import numpy as np
from typing import Any, BinaryIO


%load_ext autoreload
%autoreload 2

def default_deserialise_filter_spec(f: BinaryIO, x: Any) -> Any:
    """Best-effort variant of Equinox's default leaf deserialiser.

    Reads the next saved value from `f` for the template leaf `x`. Unlike a strict
    loader, a leaf that cannot be read is reported on stdout and the template leaf
    is returned unchanged instead of raising.

    **Arguments**

    -   `f`: file-like object
    -   `x`: The leaf for which the data needs to be loaded.

    **Returns**

    The new value for datatype `x`:

    -   `jax.Array` / `jax.ShapeDtypeStruct` leaves are read with `jnp.load`;
    -   `np.ndarray` leaves are read with `np.load`;
    -   other array-like leaves (per `eqx.is_array_like`) are read with `np.load`
        and converted back to `type(x)`;
    -   any other leaf is returned as-is without touching `f`;
    -   if reading fails with an `Exception`, `x` itself is returned.

    !!! info

        This function can be extended to customise the deserialisation behaviour for
        leaves.

    !!! example

        Skipping loading of jax.Array.

        ```python
        import jax
        import jax.numpy as jnp
        import equinox as eqx

        tree = (jnp.array([4,5,6]), [1,2,3])
        new_filter_spec = lambda f,x: (
            x if isinstance(x, jax.Array) else eqx.default_deserialise_filter_spec(f, x)
        )
        new_tree = eqx.tree_deserialise_leaves("some_filename.eqx", tree, filter_spec=new_filter_spec)
        ```
    """  # noqa: E501
    try:
        if isinstance(x, (jax.Array, jax.ShapeDtypeStruct)):
            return jnp.load(f)
        if isinstance(x, np.ndarray):
            # Important to use `np` here to avoid promoting NumPy arrays to JAX.
            return np.load(f)
        if eqx.is_array_like(x):
            # np.generic gets deserialised directly as an array, so convert back to a scalar
            # type here.
            # See also https://github.com/google/jax/issues/17858
            out = np.load(f)
            if isinstance(x, jax.dtypes.bfloat16):
                out = out.view(jax.dtypes.bfloat16)
            if np.size(out) != 1:
                # A scalar template leaf cannot hold a multi-element value. Route this
                # through the fallback below instead of implicitly returning None,
                # which would silently put None into the deserialised tree.
                raise ValueError("stored value is not a scalar")
            return type(x)(out.item())
        return x
    except Exception:
        # Deliberately best-effort: keep the template leaf so the rest of the tree can
        # still be loaded. `Exception` (not a bare `except`) lets KeyboardInterrupt and
        # SystemExit propagate.
        # NOTE(review): the position of `f` is not restored after a failed read —
        # confirm that later leaves still line up in that case.
        print("Failed to load data for leaf with shape/ value:", x.shape if hasattr(x, 'shape') else x)
        return x

def load_model_and_state_from_checkpoint_wandb(artifact_full_name, model_cls=SSMDownstreamDecoder, model_cfg=None):
    """Download a W&B checkpoint artifact and restore the model and its state.

    **Arguments**

    -   `artifact_full_name`: name passed to `wandb.Api().artifact(...)`; the artifact
        is looked up with `type="checkpoint"`.
    -   `model_cls`: class handed to `eqx.nn.make_with_state` to build the templates
        that the checkpoint is loaded into.
    -   `model_cfg`: mapping of keyword arguments for `model_cls`. When `None`, the
        `model` entry of the config of the run that logged the artifact is used.

    **Returns**

    A tuple `(model, state, meta)` where `model` and `state` are deserialised from
    `model.ckpt` and `state.ckpt` inside the artifact, and `meta` is the artifact's
    metadata.

    **Raises**

    -   `FileNotFoundError`: if the artifact lookup fails for any reason (the
        underlying error is attached as the cause).
    -   `ValueError`: if `model_cfg` is `None` and the artifact has no logging run
        to read the config from.
    """
    api = wandb.Api()
    try:
        artifact = api.artifact(artifact_full_name, type="checkpoint")
    except Exception as e:
        # Chain the original error so auth/network problems remain visible.
        raise FileNotFoundError(f"Could not find checkpoint artifact: {artifact_full_name}") from e

    if model_cfg is None:
        run = artifact.logged_by()
        if run is None:
            # Without this check the failure is an opaque AttributeError on `run.config`.
            raise ValueError(
                f"Artifact {artifact_full_name} has no logging run to read the model "
                "config from; pass model_cfg explicitly."
            )
        run_cfg = OmegaConf.create(run.config)
        print(run_cfg)
        model_cfg = OmegaConf.create(run_cfg.model)

    model_template, state_template = eqx.nn.make_with_state(model_cls)(**model_cfg)

    # The files are only needed while deserialising, so a temporary directory suffices.
    with tempfile.TemporaryDirectory() as temp_dir:
        artifact.download(temp_dir)
        model = eqx.tree_deserialise_leaves(os.path.join(temp_dir, "model.ckpt"), model_template, default_deserialise_filter_spec)
        state = eqx.tree_deserialise_leaves(os.path.join(temp_dir, "state.ckpt"), state_template, default_deserialise_filter_spec)

    meta = artifact.metadata
    return model, state, meta

# Downstream Model

In [2]:
# Select which downstream checkpoint to fetch from W&B.
layer = "2"
pretrain_mode = "scratch"
train_mode = "all"
# Artifact alias: latest / best / epoch_{any value in range(0,1000,100)}.
# epoch 0 now stores a fresh model.
alias = "best"

artifact_full_name = (
    "melinajingting-ucl/foundational_ssm_rtt/"
    f"l{layer}_{pretrain_mode}_{train_mode}_checkpoint:{alias}"
)
model, state, meta = load_model_and_state_from_checkpoint_wandb(artifact_full_name)

{'model': {'dt_max': 0.01, 'dt_min': 0.001, 'ssm_dim': 32, 'rng_seed': 42, 'dropout_p': 0.03, 'input_dim': 130, 'output_dim': 2, 'ssm_io_dim': 256, 'ssm_dropout_p': 0.03, 'ssm_num_layers': 2, 'ssm_init_diag_blocks': 4}, 'wandb': {'tags': ['neural', 'behavior', 'downstream', 'decoding', 'rtt'], 'entity': 'melinajingting-ucl', 'project': 'foundational_ssm_rtt'}, 'device': 'cuda', 'dataset': {'test': '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_rtt_trialized_val.h5', 'phase': 'test', 'train': '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_rtt_trialized_train.h5', 'batch_size': 64, 'skip_timesteps': 56}, 'rng_seed': 42, 'training': {'epochs': 1500, 'log_val_every': 100, 'checkpoint_every': 100, 'save_activations': False, 'save_checkpoints': True, 'log_pred_and_activations_every': 999}, 'optimizer': {'lr': 0.002, 'mode': 'all', 'weight_decay': 0.01}}


[34m[1mwandb[0m:   3 of 3 files downloaded.  


## Calling with activations (Downstream)

In [3]:
# Download mc_rtt_trialized from https://huggingface.co/datasets/MelinaLaimon/nlb_processed/tree/main
# Edit dataset_dir to your directory
dataset_dir = "../../data/foundational_ssm/processed/nlb"
dataset_path = os.path.join(dataset_dir, "mc_rtt_trialized.h5")
data = h5_to_dict(dataset_path)
data["neural_input"] = smooth_spikes(data["neural_input"], kern_sd_ms=20, bin_size_ms=5, time_axis=1)
# Named `neural_input` rather than `input` so the builtin `input()` is not shadowed
# for the rest of the kernel session.
neural_input = data["neural_input"]
target_vel = data["behavior_input"]

# Specify the layers you want to generate the activations of.
# ["post_encoder", "ssm_pre_activation", "ssm_post_activation"]
layer_keys = ["ssm_pre_activation"]
inf_model = eqx.nn.inference_mode(model)  # Switches off dropout
# Map over the leading axis of `neural_input` only; `state` and `layer_keys` are
# passed unchanged to every call (in_axes=None).
pred_vel, _, activations = jax.vmap(
    inf_model.call_with_activations, axis_name="batch", in_axes=(0, None, None)
)(neural_input, state, layer_keys)

## Example: Plotting Output

In [None]:
import pandas as pd
from foundational_ssm.plotting import aggregate_bin_label_results, plot_pred_vs_targets_by_angle_bin

# This cell reuses `dataset_dir`, `target_vel` and `pred_vel` from the
# "Calling with activations (Downstream)" cell, so run that cell first. Reloading
# the data and re-running inference here would only duplicate that work.
trial_info = pd.read_csv(os.path.join(dataset_dir, "mc_rtt_trialized.csv"))

results_df = aggregate_bin_label_results(trial_info, target_vel, pred_vel)
fig = plot_pred_vs_targets_by_angle_bin(results_df)
fig.show()

# Foundational Model

In [2]:

# Select which pretrained (foundational) checkpoint to fetch from W&B.
# `model_name` / `dataset_name` are kept distinct from `model` and `dataset`, which
# hold the loaded model object (below) and the dataset object (in the loading cell).
model_name = "l2"
dataset_name = "reaching"
alias = "best"

artifact_full_name = f"melinajingting-ucl/foundational_ssm_pretrain/{model_name}_{dataset_name}_checkpoint:{alias}"
model, state, meta = load_model_and_state_from_checkpoint_wandb(artifact_full_name, model_cls=SSMFoundationalDecoder)

{'model': {'dt_max': 0.01, 'dt_min': 0.001, 'ssm_dim': 128, 'rng_seed': 42, 'dropout_p': 0.03, 'output_dim': 2, 'ssm_io_dim': 256, 'ssm_dropout_p': 0.01, 'ssm_num_layers': 2, 'ssm_init_diag_blocks': 4}, 'wandb': {'tags': ['neural', 'behavior', 'masking'], 'entity': 'melinajingting-ucl', 'project': 'foundational_ssm_pretrain', 'resume_run_id': None}, 'rng_seed': 42, 'training': {'epochs': 1001, 'log_val_every': 50, 'checkpoint_every': 1}, 'model_cfg': 'configs/model/l2.yaml', 'optimizer': {'lr': 0.002, 'mode': 'all', 'weight_decay': 0.01}, 'val_loader': {'sampler': 'SequentialFixedWindowSampler', 'dataset_args': {'lazy': True, 'split': 'val', 'keep_files_open': False}, 'sampler_args': {'drop_short': True, 'window_length': 3.279}, 'sampling_rate': 200, 'dataloader_args': {'batch_size': 1024, 'num_workers': 0, 'persistent_workers': False}}, 'dataset_cfg': 'configs/dataset/reaching.yaml', 'train_loader': {'sampler': 'RandomFixedWindowSampler', 'dataset_args': {'lazy': True, 'split': 'train

[34m[1mwandb[0m:   3 of 3 files downloaded.  


## Loading the dataset

In [None]:
import multiprocessing as mp

# Foundational SSM core imports
from foundational_ssm.loaders import get_brainset_data_loader
from foundational_ssm.constants import DATA_ROOT
from foundational_ssm.samplers import TrialSampler
import os

# Force the "spawn" start method; per the original note, anything else deadlocks with jax.
mp.set_start_method("spawn", force=True)

# Locations to adapt to your checkout.
data_root = '../' + DATA_ROOT  # folder holding the brainsets
config_dir = '../configs/dataset'

# Dataset options; `split` may be 'val' or 'train'.
dataset_args = dict(keep_files_open=False, lazy=True, split='val')

# Loader options; adjust `batch_size` to your system capacity.
dataloader_args = dict(batch_size=128, num_workers=4, persistent_workers=False)

# Sampler name and its options.
sampler = 'SequentialFixedWindowSampler'
sampler_args = dict(window_length=3.279, drop_short=True)

dataset, data_loader = get_brainset_data_loader(
    dataset_args=dataset_args,
    sampler=sampler,
    sampler_args=sampler_args,
    dataloader_args=dataloader_args,
    sampling_rate=200,
    dataset_cfg=os.path.join(config_dir, 'reaching_analysis.yaml'),
    data_root=data_root,
)

# Session ids in the dataset, and the sampling intervals for each session.
sessions = dataset.get_session_ids()
sampling_intervals = dataset.get_sampling_intervals()

## Validation

In [None]:
from foundational_ssm.utils.pretrain_utils import validate_one_epoch

# skip_timesteps (original note): only needed when computing R2; kept here for analysis.
metrics = validate_one_epoch(data_loader, model, state, skip_timesteps=56)
metrics



{'val/r2_pm_c_co': 0.9161773920059204,
 'val/r2_pm_c_rt': 0.8242394924163818,
 'val/r2_pm_m_rt': 0.8003662824630737,
 'val/r2_pm_m_co': 0.8416915535926819,
 'val/r2_os_i_rt': 0.7526258230209351,
 'val/r2_os_l_rt': 0.5253342390060425,
 'val/r2_cs_j_co': 0.90641850233078,
 'val/r2_cs_n_co': 0.9437413215637207,
 'val/r2_avg': 0.813824325799942,
 'val/r2_all': 0.9150385856628418,
 'val/time': 26.957586765289307}