Problem:

- We observe that the $R^2$ on datasets other than Churchland-Shenoy (CS) is low compared with what is achievable with POSSM.
- Despite this, the overall $R^2$ remains high.
- This leads to the hypothesis that training is dominated by the CS dataset. To find where the model is failing, we conduct the following segmentation analysis:

Slice by category:
1. Within trial vs. outside of trial
2. Movement phase
3. Task

Metrics:
1. Total duration
2. Variance
3. MSE / $R^2$

What happens if we normalize the loss by the target variance (i.e. divide each sample's loss by its variance)?
1. Samples with lower variance receive larger penalties, and vice versa.

In [1]:
import multiprocessing as mp
import warnings

import pandas as pd
from omegaconf import OmegaConf
from tqdm import tqdm

from foundational_ssm.constants import DATA_ROOT, parse_session_id
from foundational_ssm.loaders import get_brainset_train_val_loaders
from foundational_ssm.utils import load_model_and_state_wandb

%load_ext autoreload
%autoreload 2

In [2]:
# Load the pretraining config, build the datasets it describes, and fetch the model.
# NOTE(review): absolute, user-specific path — consider a configurable/relative path.
config_path = "/cs/student/projects1/ml/2024/mlaimon/foundational_ssm/configs/test_pretrain.yaml"
cfg = OmegaConf.load(config_path) 
# Only the 1st and 3rd return values (train/val datasets) are kept; the other two
# are discarded (presumably the corresponding loaders — TODO confirm).
train_dataset, _, val_dataset, _ = get_brainset_train_val_loaders(
    cfg.train_loader,
    cfg.val_loader,
    data_root='../'+DATA_ROOT
)
# Download the '..._best_model:latest' artifact from Weights & Biases and restore
# both the model and its state.
model_artifact_id = 'melinajingting-ucl/foundational_ssm_pretrain_decoding/possm_dataset_l1_d128_best_model:latest'
model, state = load_model_and_state_wandb(model_artifact_id)

# Per-recording sampling intervals, keyed by recording id (iterated with .items() below).
train_sampling_intervals = train_dataset.get_sampling_intervals()
val_sampling_intervals = val_dataset.get_sampling_intervals()

# Use the 'spawn' start method for any worker processes (force=True overrides a
# previously set method). NOTE(review): the evaluation DataLoader uses
# num_workers=0, so this likely has no effect there.
mp.set_start_method("spawn", force=True)
# Evaluation windowing: window_length * sampling_rate is the padded sequence length
# passed to pad_collate (units presumably seconds and Hz — TODO confirm).
window_length = 1
sampling_rate = 200

# Accumulator: one metrics dict per recording, filled by the evaluation loop.
recording_metrics_list = []

[34m[1mwandb[0m:   1 of 1 files downloaded.  


Problem:

- We observe that the $R^2$ on datasets other than Churchland-Shenoy (CS) is low compared with what is achievable with POSSM.
- Despite this, the overall $R^2$ remains high.
- This leads to the hypothesis that training is dominated by the CS dataset. To find where the model is failing, we conduct the following segmentation analysis:

Slice by category:
1. Within trial vs. outside of trial
2. Movement phase
3. Task

Metrics:
1. Total duration
2. Variance
3. MSE / $R^2$

What happens if we normalize the loss by the target variance (i.e. divide each sample's loss by its variance)?
1. Samples with lower variance receive larger penalties, and vice versa.

In [3]:
import jax
from jax import random as jr
from jax import numpy as jnp
import numpy as np
from foundational_ssm.samplers import SequentialFixedWindowSampler
from foundational_ssm.collate import pad_collate
from functools import partial 
from torch.utils.data import DataLoader

def compute_variance_mse_r2(preds, targets):
    """Summarise how well ``preds`` match ``targets``.

    Both arrays are flattened to ``(num_samples, num_dims)``, treating the last
    axis as the output dimension, so e.g. ``(batch, time, dims)`` inputs work.

    Args:
        preds: Predicted values, same shape as ``targets``.
        targets: Ground-truth values.

    Returns:
        Tuple ``(var, mse, r2)`` of scalars:
          - ``var``: variance of all target entries pooled together (across
            samples *and* output dimensions).
          - ``mse``: mean squared error pooled over all entries (a single
            scalar, not per-dimension).
          - ``r2``: coefficient of determination computed per output dimension,
            then averaged across dimensions.
    """
    preds = preds.reshape(-1, preds.shape[-1])
    targets = targets.reshape(-1, targets.shape[-1])

    # Squared residuals feed both the MSE and the R^2 numerator.
    sq_err = (targets - preds) ** 2
    mse = jnp.mean(sq_err)     # pooled scalar MSE over every entry
    var = jnp.var(targets)     # pooled scalar variance over every entry

    ss_res = jnp.sum(sq_err, axis=0)
    ss_tot = jnp.sum((targets - jnp.mean(targets, axis=0)) ** 2, axis=0)
    # Epsilon avoids division by zero for constant target dimensions; such a
    # dimension yields a very large negative R^2 unless predicted exactly.
    r2_per_dim = 1 - ss_res / (ss_tot + 1e-8)
    return var, mse, jnp.mean(r2_per_dim)

def compute_interval_metrics(recording_data, sampling_intervals, model, state, window_length, sampling_rate, prefix='', rng_key = jr.PRNGKey(0)):
    """Run ``model`` over ``sampling_intervals`` and summarise duration and fit.

    Windows are drawn with ``SequentialFixedWindowSampler`` (``drop_short=True``,
    presumably discarding windows shorter than ``window_length``), padded to
    ``int(window_length * sampling_rate)`` steps by ``pad_collate`` and decoded in
    batches of 256. Only timesteps where the collated ``mask`` is true contribute
    to the metrics.

    Args:
        recording_data: Dataset the ``DataLoader`` draws samples from.
        sampling_intervals: Mapping ``{recording_id: interval}``; each interval
            exposes ``start`` and ``end`` arrays. Durations are aggregated over
            every entry of the mapping.
        model: Callable applied per example through ``jax.vmap`` as
            ``model(inputs, state, rng_key, dataset_group_idx)``.
        state: Model state, shared across the batch (not vmapped).
        window_length: Window length handed to the sampler.
        sampling_rate: Multiplied by ``window_length`` to get the padded length.
        prefix: String prepended to every metric key, e.g. ``'in_trial_'``.
        rng_key: Fixed PRNG key passed to the model. The default is created once
            when the function is defined, so every call uses the same key.

    Returns:
        dict with keys ``{prefix}total_duration``, ``{prefix}mean_duration``,
        ``{prefix}var``, ``{prefix}mse`` and ``{prefix}r2``. ``mean_duration`` is
        NaN when there are no intervals. ``var``/``mse``/``r2`` are Python floats,
        or NaN when there is nothing to evaluate or evaluation fails (a failure
        also emits a warning instead of being silently swallowed).
    """
    # Deterministic random key for the dropout. TODO: expose inference variable
    # in S5Block all the way up to SSMFoundationalDecoder.
    metrics = {}

    # Durations of every interval, across all entries of the mapping.
    durations = np.concatenate(
        [np.ravel(np.asarray(iv.end) - np.asarray(iv.start)) for iv in sampling_intervals.values()]
        or [np.empty(0)]
    )
    metrics[f'{prefix}total_duration'] = np.sum(durations)
    # Guard avoids numpy's "Mean of empty slice" RuntimeWarning on empty slices.
    metrics[f'{prefix}mean_duration'] = np.mean(durations) if durations.size else np.nan

    var = mse = r2 = np.nan
    if durations.size:
        try:
            sampler = SequentialFixedWindowSampler(sampling_intervals=sampling_intervals, window_length=window_length, drop_short=True)
            loader = DataLoader(
                dataset=recording_data,
                sampler=sampler,
                batch_size=256,
                collate_fn=partial(pad_collate, fixed_seq_len=int(window_length * sampling_rate)),
                num_workers=0,
                pin_memory=True,
                persistent_workers=False,
            )
            batched_model = jax.vmap(model, axis_name="batch", in_axes=(0, None, None, 0), out_axes=(0, None))
            valid_preds, valid_targets = [], []
            for batch in loader:
                batch = {k: jax.device_put(np.array(v)) for k, v in batch.items()}
                # NOTE(review): positional unpacking relies on the collate's key order,
                # presumably ('neural_input', 'behavior_input', 'mask', 'dataset_group_idx').
                inputs, targets, mask, dataset_group_idxs = batch.values()
                preds, state = batched_model(inputs, state, rng_key, dataset_group_idxs)
                # Keep only masked-in timesteps so zero padding does not enter the
                # metrics. Assumes mask.shape == preds.shape[:-1] — TODO confirm.
                keep = np.asarray(mask).astype(bool).reshape(-1)
                valid_preds.append(np.asarray(preds).reshape(-1, preds.shape[-1])[keep])
                valid_targets.append(np.asarray(targets).reshape(-1, targets.shape[-1])[keep])

            # Metrics are computed over every batch, not only the last one.
            if valid_preds:
                all_preds = np.concatenate(valid_preds, axis=0)
                all_targets = np.concatenate(valid_targets, axis=0)
                if all_preds.shape[0] > 0:
                    var, mse, r2 = (float(x) for x in compute_variance_mse_r2(all_preds, all_targets))
        except Exception as err:
            # Best-effort sweep: keep going with NaNs, but surface the failure
            # (and, unlike a bare except, let KeyboardInterrupt through).
            warnings.warn(f"compute_interval_metrics failed (prefix={prefix!r}) for {list(sampling_intervals)}: {err!r}")
            var = mse = r2 = np.nan
    metrics.update({
        f'{prefix}var': var,
        f'{prefix}mse': mse,
        f'{prefix}r2': r2,
    })
    return metrics
    

In [4]:
# Evaluate the model on every training recording, sliced by trial structure and
# movement phase; one row of metrics per recording.
# Reset the accumulator here so re-running this cell does not append duplicate rows.
recording_metrics_list = []
for recording_id, train_interval in tqdm(train_sampling_intervals.items(), total=len(train_sampling_intervals), desc="recordings"):
    recording_data = train_dataset.get_recording_data(recording_id)
    intervals_dict = {'overall': train_interval}

    if hasattr(recording_data, 'trials'):
        # Training time covered by trials, and the remaining (between-trial) time.
        intervals_dict['in_trial'] = train_interval & recording_data.trials
        intervals_dict['ex_trial'] = train_interval.difference(intervals_dict['in_trial'])

        if hasattr(recording_data.trials, 'is_valid'):
            # 'ex_valid_trial' is in-trial time NOT covered by valid trials.
            valid_trials = recording_data.trials.select_by_mask(recording_data.trials.is_valid)
            intervals_dict['valid_trial'] = train_interval & valid_trials
            intervals_dict['ex_valid_trial'] = intervals_dict['in_trial'].difference(intervals_dict['valid_trial'])

    if hasattr(recording_data, 'movement_phases'):
        for phase in recording_data.movement_phases.keys():
            intervals_dict[phase] = train_interval & getattr(recording_data.movement_phases, phase)

    dataset, subject, task = parse_session_id(recording_id)
    recording_metrics_dict = {
        'dataset': dataset,
        'session_id': recording_id.split('/')[1],
        'subject': subject,
        'task': task,
    }

    for interval_name, interval in intervals_dict.items():
        interval_metrics = compute_interval_metrics(train_dataset, {recording_id: interval}, model, state, window_length, sampling_rate, prefix=interval_name + '_')
        recording_metrics_dict.update(interval_metrics)
    recording_metrics_list.append(recording_metrics_dict)

    # Rewrite the CSV after every recording so results so far are on disk even
    # if the loop is interrupted.
    recording_metrics_df = pd.DataFrame(recording_metrics_list)
    recording_metrics_df.to_csv('recording_metrics.csv')

0it [00:00, ?it/s]

  batch = {k: jax.device_put(np.array(v)) for k, v in batch.items()}
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
2025-07-26 13:23:56.833119: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.
2025-07-26 13:31:35.306148: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.
2025-07-26 13:31:49.749123: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a go

In [18]:
# Build the metrics table from the accumulated per-recording dicts.
# NOTE(review): `[3:]` silently drops the first three entries. This looks like a
# manual workaround for stale/duplicate rows left in `recording_metrics_list` by
# earlier runs of the evaluation loop; on a clean Restart & Run All it drops the
# first three recordings' rows. TODO confirm intent and remove the slice.
recording_metrics_df = pd.DataFrame(recording_metrics_list[3:])

In [19]:
# Persist the per-recording metrics table, overwriting the CSV the evaluation loop wrote.
recording_metrics_df.to_csv(path_or_buf='recording_metrics.csv')

In [None]:

# Scratch / work-in-progress draft of the per-recording segmentation analysis.
# NOTE(review): this cell does not run on a fresh kernel as-is. It references
# `transform_brainsets_regular_time_series_smoothed` and
# `compute_train_variance_mse_r2`, neither of which is defined in any visible
# cell (presumably from a deleted cell or another module — TODO confirm or remove).
# NOTE(review): duplicate of the `partial` import in the metric-helper cell.
from functools import partial


# Plan: a pandas dataframe of sessions holding the following columns:
# Info: 1. dataset
# Stats: per ('within-trial', 'inter-trial'); movement_phases ('hold_period', 'invalid', 'random_period', 'reach_period', 'return_period')
#   - duration
#   - variance
#   - mse/r2. This can be done by getting predictions over the entire duration using SequentialFixedWindowSampler.
# NOTE(review): assigns a new `transform` on the shared `train_dataset` object, so it
# persists in the kernel for every later cell that uses `train_dataset`.
train_dataset.transform = transform_brainsets_regular_time_series_smoothed
train_window_length = 1 
val_window_length = 6
sampling_rate = 200



for recording_id, sampling_interval in train_sampling_intervals.items():
    recording_data = train_dataset.get_recording_data(recording_id) 
    variance, mse, r2 = compute_train_variance_mse_r2(train_dataset,
                                                    sampling_intervals = {recording_id: sampling_interval},
                                                    model=model,
                                                    state=state
                                                    )
    if hasattr(recording_data, "trials"):
        # Fall back to all trials when the recording has no `is_valid` flag.
        if hasattr(recording_data.trials, "is_valid"):
            valid_trial_intervals = recording_data.trials.select_by_mask(recording_data.trials.is_valid)
        else:
            valid_trial_intervals = recording_data.trials
        # Sampling-interval time not covered by valid trials (between-trial time
        # plus, when flagged, invalid trials).
        ex_valid_trial_intervals = sampling_interval.difference(valid_trial_intervals)
        
        # NOTE(review): computed but not yet used or stored anywhere.
        trial_duration = (valid_trial_intervals.end - valid_trial_intervals.start).sum()
        ex_trial_duration = (ex_valid_trial_intervals.end - ex_valid_trial_intervals.start).sum()
        
        
        
    
    # TODO: split by within trial and outside of trial
    # recording_data.select_by_interval ()
    # get total durations
    # get variance
    # get prediction errors.
    
    
    
    # TODO: split by movement phases
    
    # get total durations for each, variances, and prediction errors.
    
    # Stops after the first recording.
    break

  batch = {k: jax.device_put(np.array(v)) for k, v in batch.items()}


dict_keys(['neural_input', 'behavior_input', 'mask', 'dataset_group_idx'])
dict_keys(['neural_input', 'behavior_input', 'mask', 'dataset_group_idx'])


In [27]:
# NOTE(review): displays a leftover kernel global. The only visible module-level
# assignment of `mse` is in the scratch loop (via compute_train_variance_mse_r2),
# so this output depends on execution history. Consider removing this cell.
mse

Array([662.4436, 450.7957], dtype=float32)

In [None]:
# NOTE(review): `content` is not defined in any visible cell (leftover kernel state).
# The saved output is a UnicodeDecodeError on byte 0x93 at position 0. That byte
# is the first byte of the NumPy .npy magic string (b'\x93NUMPY'), so `content` may
# be raw .npy bytes rather than text. If so, load it with np.load(io.BytesIO(content))
# instead of decoding it. TODO confirm, or remove this debugging cell.
print(content.decode())

UnicodeDecodeError: 'utf-8' codec can't decode byte 0x93 in position 0: invalid start byte

In [None]:
# Show the movement-phase names available on the current `recording_data`
# (the last recording visited by a loop above); these are the phase slices
# evaluated per recording.
list(recording_data.movement_phases.keys())

['hold_period', 'invalid', 'random_period', 'reach_period', 'return_period']