In [1]:
import os

import h5py
import numpy as np
import pandas as pd

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, default_collate
from jax.tree_util import tree_map


import wandb
from omegaconf import OmegaConf


from foundational_ssm.models import S4DNeuroModel, S5
from foundational_ssm.utils import h5_to_dict, generate_and_save_activations_wandb
from foundational_ssm.trainer import train_decoding
from foundational_ssm.data_preprocessing import smooth_spikes


# ========== Helper Functions, will be moved to utils.py ==========
def numpy_collate(batch):
  """
  Collate function specifies how to combine a list of data samples into a batch.
  default_collate creates pytorch tensors, then tree_map converts them into numpy arrays.
  """
  return tree_map(np.asarray, default_collate(batch))


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processed_data_folder = '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb' 
dataset_name = "mc_maze"
processed_data_path = os.path.join(processed_data_folder,dataset_name + ".h5")
trial_info_path = os.path.join(processed_data_folder,dataset_name + ".csv")

conf = {
    'task':'decoding',
    'dataset': {
        'dataset': dataset_name
    },
    'model': {
        'input_dim': 182,
        'output_dim': 2,
        'd_state': 64,
        'num_layers': 2,
        'hidden_dim': 64,
        'dropout': 0.1,
        'ssm_core':'s5'
    },
    'optimizer': {
        'lr': 0.0005,
        'weight_decay': 0.01  # Added common parameter
    },
    'training': {
        'batch_size': 64,
        'epochs': 2000
    },
    'device': 'cuda',
    'framework': 'jax'
}

args = OmegaConf.create(conf)

with h5py.File(processed_data_path, 'r') as h5file:
    dataset_dict = h5_to_dict(h5file)

trial_info = pd.read_csv(trial_info_path)
trial_info = trial_info[trial_info['split'].isin(['train','val'])]
min_idx = trial_info['trial_id'].min()
trial_info['trial_id'] = trial_info['trial_id'] - min_idx

train_ids = trial_info[trial_info['split']=='train']['trial_id'].tolist()
val_ids = trial_info[trial_info['split']=='val']['trial_id'].tolist()

# Concatenate both heldin and heldout spikes since we're using spikes to predict behavior
spikes = np.concat([
    dataset_dict['train_spikes_heldin'], 
    dataset_dict['train_spikes_heldout']],axis=2) 
smoothed_spikes = smooth_spikes(spikes, kern_sd_ms=40, bin_width=5)
behavior = dataset_dict['train_behavior']

input_dim = smoothed_spikes.shape[2]
output_dim = behavior.shape[2]

# Split train and val based on splits from nlb
# train_dataset = TensorDataset(smoothed_spikes[train_ids], behavior[train_ids])
# val_dataset = TensorDataset(smoothed_spikes[val_ids], behavior[val_ids])
# full_dataset = TensorDataset(smoothed_spikes, behavior)

run_name = f"nlb_{args.task}_{args.model.ssm_core}_l{args.model.num_layers}_d{args.model.d_state}"


In [2]:
import equinox as eqx 
from jax import random
import jax
import jax.numpy as jnp
model_key = random.PRNGKey(0)
model = S5(
    key= model_key,
    num_blocks=args.model.num_layers,
    N=args.model.input_dim,
    ssm_size=args.model.d_state,
    ssm_blocks=1,
    H=args.model.hidden_dim,
    output_dim=args.model.output_dim,
)

In [3]:
# Initialize the state
state = eqx.nn.State(model)

# Generate a batch of keys
batch_size = smoothed_spikes.shape[0]
keys = random.split(random.PRNGKey(0), batch_size)

# Run vmap with the batch of keys
output, state = jax.vmap(
    model, 
    axis_name="batch", 
    in_axes=(0, None, 0),  # Map over inputs and keys
    out_axes=(0)
)(jnp.array(smoothed_spikes), state, keys)

In [4]:
import jax
import jax.numpy as jnp
import optax
import equinox as eqx
from functools import partial
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Define loss function (MSE for regression)
def mse_loss(model, state, inputs, targets, key):
    preds, _ = model(inputs, state, key)
    return jnp.mean((preds - targets) ** 2)

# Create training step function
@partial(jax.jit, static_argnums=(0,))
def train_step(model, state, opt_state, inputs, targets, key):
    # Define loss function with current parameters
    def loss_fn(params):
        # Recreate model with updated parameters
        model = eqx.combine(params, eqx.filter(model_fn, eqx.is_array, invert=True))
        preds, _ = model(inputs, state, key)
        return jnp.mean((preds - targets) ** 2)
    
    # Get loss value and gradients
    loss_val, grads = eqx.filter_value_and_grad(loss_fn)(eqx.filter(model_fn, eqx.is_array))
    
    # Update parameters
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_model = eqx.apply_updates(model_fn, updates)
    
    return new_model, new_opt_state, loss_val

# Create evaluation function
@partial(jax.jit, static_argnums=(0,))
def evaluate(model_fn, state, inputs, targets, key):
    preds, _ = model_fn(inputs, state, key)
    loss = jnp.mean((preds - targets) ** 2)
    return loss, preds

# Prepare data: convert to jax arrays and add batch dimension if needed
def prepare_batch(inputs, targets, batch_indices):
    batch_x = jnp.array(inputs[batch_indices])
    batch_y = jnp.array(targets[batch_indices])
    return batch_x, batch_y


In [6]:
# Training configuration
epochs = 100  # Start with fewer epochs for testing
batch_size = 16  # Smaller batch size since we have sequential data
learning_rate = 1e-3
weight_decay = 1e-5
key = random.PRNGKey(42)

# Set up optimizer with learning rate schedule
schedule = optax.exponential_decay(
    init_value=learning_rate,
    transition_steps=100,
    decay_rate=0.9,
    end_value=1e-4
)
optimizer = optax.adamw(learning_rate=schedule, weight_decay=weight_decay)

# Initialize optimizer state
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

# Initialize state
state = eqx.nn.State(model)

# Prepare training and validation data
X_train = smoothed_spikes[train_ids]
y_train = behavior[train_ids]
X_val = smoothed_spikes[val_ids]
y_val = behavior[val_ids]

# Keep track of metrics
train_losses = []
val_losses = []


In [8]:
# Create batch indices
num_train = len(train_ids)
indices = np.arange(num_train)

# Training loop with fixed JIT compilation
print("Starting training...")
for epoch in tqdm(range(epochs)):
    # Shuffle train data
    np.random.shuffle(indices)
    
    # Train batches
    epoch_loss = 0.0
    num_batches = 0
    
    for i in range(0, num_train, batch_size):
        batch_indices = indices[i:i + batch_size]
        batch_x, batch_y = prepare_batch(X_train, y_train, batch_indices)
        
        # Generate new key for this batch
        key, subkey = random.split(key)
        
        # Train on batch
        model, opt_state, batch_loss = train_step(model, state, opt_state, batch_x, batch_y, subkey)
        
        epoch_loss += batch_loss
        num_batches += 1
    
    # Calculate average loss for epoch
    avg_train_loss = epoch_loss / num_batches
    train_losses.append(float(avg_train_loss))  # Convert from JAX array to float
    
    # Evaluate on validation set
    val_key = random.split(key)[0]
    val_loss, val_preds = evaluate(model, state, X_val, y_val, val_key)
    val_losses.append(float(val_loss))  # Convert from JAX array to float
    
    # Log every 10 epochs
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}")

Starting training...


  0%|          | 0/100 [00:00<?, ?it/s]


ValueError: Non-hashable static arguments are not supported. An error occurred while trying to hash an object of type <class 'foundational_ssm.models.s5.S5'>, S5(
  linear_encoder=Linear(
    weight=f32[64,182],
    bias=f32[64],
    in_features=182,
    out_features=64,
    use_bias=True
  ),
  blocks=[
    S5Block(
      norm=BatchNorm(
        weight=None,
        bias=None,
        first_time_index=StateIndex(
          marker=<object object at 0x7f408899db10>, init=bool[]
        ),
        state_index=StateIndex(
          marker=<object object at 0x7f408899d840>, init=(f32[64], f32[64])
        ),
        axis_name='batch',
        inference=False,
        input_size=64,
        eps=1e-05,
        channelwise_affine=False,
        momentum=0.99
      ),
      ssm=S5Layer(
        Lambda_re=f32[32],
        Lambda_im=f32[32],
        B=f32[32,64,2],
        C=f32[64,32,2],
        D=f32[64],
        log_step=f32[32,1],
        H=64,
        P=32,
        conj_sym=True,
        discretisation='zoh'
      ),
      glu=GLU(
        w1=Linear(
          weight=f32[64,64],
          bias=f32[64],
          in_features=64,
          out_features=64,
          use_bias=True
        ),
        w2=Linear(
          weight=f32[64,64],
          bias=f32[64],
          in_features=64,
          out_features=64,
          use_bias=True
        )
      ),
      drop=Dropout(p=0.05, inference=False)
    ),
    S5Block(
      norm=BatchNorm(
        weight=None,
        bias=None,
        first_time_index=StateIndex(
          marker=<object object at 0x7f40246d2310>, init=bool[]
        ),
        state_index=StateIndex(
          marker=<object object at 0x7f40246d2810>, init=(f32[64], f32[64])
        ),
        axis_name='batch',
        inference=False,
        input_size=64,
        eps=1e-05,
        channelwise_affine=False,
        momentum=0.99
      ),
      ssm=S5Layer(
        Lambda_re=f32[32],
        Lambda_im=f32[32],
        B=f32[32,64,2],
        C=f32[64,32,2],
        D=f32[64],
        log_step=f32[32,1],
        H=64,
        P=32,
        conj_sym=True,
        discretisation='zoh'
      ),
      glu=GLU(
        w1=Linear(
          weight=f32[64,64],
          bias=f32[64],
          in_features=64,
          out_features=64,
          use_bias=True
        ),
        w2=Linear(
          weight=f32[64,64],
          bias=f32[64],
          in_features=64,
          out_features=64,
          use_bias=True
        )
      ),
      drop=Dropout(p=0.05, inference=False)
    )
  ],
  linear_layer=Linear(
    weight=f32[2,64], bias=f32[2], in_features=64, out_features=2, use_bias=True
  )
). The error was:
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/cs/student/projects1/ml/2024/mlaimon/anaconda3/envs/foundational_ssm/lib/python3.13/site-packages/ipykernel_launcher.py", line 18, in <module>
  File "/cs/student/projects1/ml/2024/mlaimon/anaconda3/envs/foundational_ssm/lib/python3.13/site-packages/traitlets/config/application.py", line 1075, in launch_instance
  File "/cs/student/projects1/ml/2024/mlaimon/anaconda3/envs/foundational_ssm/lib/python3.13/site-packages/ipykernel/kernelapp.py", line 739, in start
  File "/cs/student/projects1/ml/2024/mlaimon/anaconda3/envs/foundational_ssm/lib/python3.13/site-packages/tornado/platform/asyncio.py", line 211, in start
  File "/cs/student/projects1/ml/2024/mlaimon/anaconda3/envs/foundational_ssm/lib/python3.13/asyncio/base_events.py", line 683, in run_forever
  File "/cs/student/projects1/ml/2024/mlaimon/anaconda3/envs/foundational_ssm/lib/python3.13/asyncio/base_events.py", line 2042, in _run_once
  File "/cs/student/projects1/ml/2024/mlaimon/anaconda3/envs/foundational_ssm/lib/python3.13/asyncio/events.py", line 89, in _run
  File "/cs/student/projects1/ml/2024/mlaimon/anaconda3/envs/foundational_ssm/lib/python3.13/site-packages/ipykernel/kernelbase.py", line 545, in dispatch_queue
  File "/cs/student/projects1/ml/2024/mlaimon/anaconda3/envs/foundational_ssm/lib/python3.13/site-packages/ipykernel/kernelbase.py", line 534, in process_one
  File "/cs/student/projects1/ml/2024/mlaimon/anaconda3/envs/foundational_ssm/lib/python3.13/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell
  File "/cs/student/projects1/ml/2024/mlaimon/anaconda3/envs/foundational_ssm/lib/python3.13/site-packages/ipykernel/ipkernel.py", line 362, in execute_request
  File "/cs/student/projects1/ml/2024/mlaimon/anaconda3/envs/foundational_ssm/lib/python3.13/site-packages/ipykernel/kernelbase.py", line 778, in execute_request
  File "/cs/student/projects1/ml/2024/mlaimon/anaconda3/envs/foundational_ssm/lib/python3.13/site-packages/ipykernel/ipkernel.py", line 449, in do_execute
  File "/cs/student/projects1/ml/2024/mlaimon/anaconda3/envs/foundational_ssm/lib/python3.13/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
  File "/cs/student/projects1/ml/2024/mlaimon/anaconda3/envs/foundational_ssm/lib/python3.13/site-packages/IPython/core/interactiveshell.py", line 3100, in run_cell
  File "/cs/student/projects1/ml/2024/mlaimon/anaconda3/envs/foundational_ssm/lib/python3.13/site-packages/IPython/core/interactiveshell.py", line 3155, in _run_cell
  File "/cs/student/projects1/ml/2024/mlaimon/anaconda3/envs/foundational_ssm/lib/python3.13/site-packages/IPython/core/async_helpers.py", line 128, in _pseudo_sync_runner
  File "/cs/student/projects1/ml/2024/mlaimon/anaconda3/envs/foundational_ssm/lib/python3.13/site-packages/IPython/core/interactiveshell.py", line 3367, in run_cell_async
  File "/cs/student/projects1/ml/2024/mlaimon/anaconda3/envs/foundational_ssm/lib/python3.13/site-packages/IPython/core/interactiveshell.py", line 3612, in run_ast_nodes
  File "/cs/student/projects1/ml/2024/mlaimon/anaconda3/envs/foundational_ssm/lib/python3.13/site-packages/IPython/core/interactiveshell.py", line 3672, in run_code
  File "/tmp/ipykernel_119725/3001736501.py", line 23, in <module>
  File "/cs/student/projects1/ml/2024/mlaimon/anaconda3/envs/foundational_ssm/lib/python3.13/site-packages/equinox/_module.py", line 1037, in __hash__
TypeError: unhashable type: 'jaxlib._jax.ArrayImpl'
