In [17]:
from foundational_ssm.data_utils.loaders import get_nlb_train_val_loaders
from foundational_ssm.models import SSMFoundational
from omegaconf import OmegaConf
import jax
import equinox as eqx
import jax.random as jr
import matplotlib.pyplot as plt 
import optax
import jax.numpy as jnp
from jax.tree_util import tree_map
import wandb
import json
import os
from collections import defaultdict
from foundational_ssm.constants import DATASET_IDX_TO_GROUP_SHORT
from foundational_ssm.metrics import compute_r2_standard
from foundational_ssm.utils import save_model_wandb


@eqx.filter_jit
def predict_batch(model, state, inputs, key, dataset_group_idx):
    """Predict on a batch of inputs using JAX's vmap"""
    batch_keys = jr.split(key, inputs.shape[0])
    preds, _ = jax.vmap(model, axis_name="batch", in_axes=(0, None, 0, None))(inputs, state, batch_keys, dataset_group_idx)
    return preds

@eqx.filter_jit
@eqx.filter_value_and_grad(has_aux=True)
def mse_loss(model_params, model_static, state, inputs, targets, dataset_group_idx, key):
    model = eqx.combine(model_params, model_static)
    batch_keys = jr.split(key, inputs.shape[0])
    preds, state = jax.vmap(model, axis_name="batch", in_axes=(0, None, 0, None), out_axes=(0, None))(inputs, state, batch_keys, dataset_group_idx)
    mse = jnp.mean((preds - targets) ** 2)
    return (mse, state)

@eqx.filter_jit
def make_step(model, state, filter_spec, inputs, targets, dataset_group_idx, loss_fn, opt, opt_state, key):
    model_params, model_static = eqx.partition(model, filter_spec)
    (value, state), grads = loss_fn(model_params, model_static, state, inputs, targets, dataset_group_idx, key)
    updates, opt_state = opt.update(grads, opt_state, eqx.filter(model, eqx.is_array))
    model = eqx.apply_updates(model, updates)
    return model, state, opt_state, value, grads

def load_model_and_state(wandb_pretrained_model_id, hyperparams):
    """
    either loads a model from wandb or creates a new model from hyperparams
    Args:
        wandb_pretrained_model_id: wandb artifact id of the model to load
        hyperparams: dict of hyperparams to create a new model
    Returns:
        model (SSMFoundational): Loaded model or None if not specified.
    """
    if wandb_pretrained_model_id is not None:
        api = wandb.Api()
        model_artifact = api.artifact(wandb_pretrained_model_id, type="model")
        model_artifact_dir = model_artifact.download()
        model_filename = os.path.join(model_artifact_dir, 'best_model.pt')
        with open(model_filename, "rb") as f:
            hyperparams = json.loads(f.readline().decode())
            model = SSMFoundational(**hyperparams)
            model = eqx.tree_deserialise_leaves(f, model)
            state = eqx.nn.State(model)
        return model, state
    else:
        model = SSMFoundational(**hyperparams)
        state = eqx.nn.State(model)
        return model, state

In [18]:
cfg = OmegaConf.load('../configs/finetune.yaml')
train_dataset, train_loader, val_dataset, val_loader = get_nlb_train_val_loaders()
model, state = load_model_and_state(cfg.wandb_pretrained_model_id, cfg.model)

key = jr.PRNGKey(cfg.rng_seed)
train_key, val_key = jr.split(key, 2)

In [19]:
filter_spec = tree_map(eqx.is_inexact_array, model)
lr_scheduler = lambda step: cfg.optimizer.lr
# Load JAX optimizer with scheduler
opt = optax.chain(
    optax.adamw(learning_rate=lr_scheduler, weight_decay=cfg.optimizer.weight_decay)
)
opt_state = opt.init(eqx.filter(model, filter_spec))

loss_fn = mse_loss

run_name = f'{cfg.finetune_mode}_holdout-{cfg.train_dataset.holdout_angles}'
config_dict = OmegaConf.to_container(cfg, resolve=True)
wandb.init(project=cfg.wandb.project, name=run_name, config=config_dict)  # type: ignore

# Define metrics with custom x-axis
wandb.define_metric("epoch", step_metric="epoch")
wandb.define_metric("val/*", step_metric="epoch")
wandb.define_metric("epoch_train_loss", step_metric="epoch")

[34m[1mwandb[0m: Currently logged in as: [33mmelinajingting[0m ([33mmelinajingting-ucl[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


<wandb.sdk.wandb_metric.Metric at 0x7feaf23f9590>

In [22]:
best_r2_score = 0
for epoch in range(cfg.training.epochs):
    epoch_loss = 0
    for batch in train_loader:
        inputs = batch["neural_input"]
        targets = batch["behavior_input"]
        dataset_group_idx = batch["dataset_group_idx"][0]
        key, subkey = jr.split(train_key)
        model, state, opt_state, loss_value, grads = make_step(
            model,
            state,
            filter_spec,
            inputs,
            targets,
            dataset_group_idx,
            loss_fn,
            opt,
            opt_state,
            subkey)

        # Get current learning rate from scheduler
        epoch_loss += loss_value
        
        wandb.log({
            "train/loss": loss_value,
        })
        
    if epoch % cfg.training.log_every == 0:
        # Log the epoch value so wandb can use it as x-axis for validation metrics
        wandb.log({"epoch": epoch})
        wandb.log({"train/epoch_loss": epoch_loss})
        
        total_r2_score = 0
        group_preds = defaultdict(list)
        group_targets = defaultdict(list)
        for batch in val_loader:
            inputs = batch["neural_input"]
            targets = batch["behavior_input"]
            dataset_group_idx = batch["dataset_group_idx"][0]
            dataset_group_key = DATASET_IDX_TO_GROUP_SHORT[dataset_group_idx]
            
            key, subkey = jr.split(val_key)
            batch_keys = jr.split(subkey, inputs.shape[0])
            preds, state = jax.vmap(model, axis_name="batch", in_axes=(0, None, 0, None), out_axes=(0, None))(inputs, state, batch_keys, dataset_group_idx)
            group_preds[dataset_group_key].append(preds)
            group_targets[dataset_group_key].append(targets)
            
        for group_key, preds in group_preds.items():
            preds = jnp.concatenate(preds, axis=0)
            targets = jnp.concatenate(group_targets[group_key], axis=0)
            r2_score = compute_r2_standard(preds, targets)
            wandb.log({f"val/r2_{group_key}": r2_score})
            total_r2_score += r2_score
        avg_r2_score = total_r2_score / len(group_preds)
    
        if avg_r2_score > best_r2_score:
            best_r2_score = avg_r2_score
            # save_model_wandb(model, run_name, OmegaConf.to_container(cfg.model), wandb.run)
        
        print(f"Epoch {epoch}/{cfg.training.epochs}, Loss: {epoch_loss:.4f}")

wandb.finish()
    

Epoch 0/2000, Loss: 817233.1875
Epoch 50/2000, Loss: 73276.0078
Epoch 100/2000, Loss: 31008.3047
Epoch 150/2000, Loss: 17932.3867
Epoch 200/2000, Loss: 13768.1348
Epoch 250/2000, Loss: 11209.4170
Epoch 300/2000, Loss: 9466.0146
Epoch 350/2000, Loss: 8810.9336
Epoch 400/2000, Loss: 6529.4321
Epoch 450/2000, Loss: 6865.3301
Epoch 500/2000, Loss: 5781.3315
Epoch 550/2000, Loss: 4819.2446
Epoch 600/2000, Loss: 4290.4917
Epoch 650/2000, Loss: 3444.5544
Epoch 700/2000, Loss: 3620.8601
Epoch 750/2000, Loss: 3223.9553
Epoch 800/2000, Loss: 3053.0203
Epoch 850/2000, Loss: 2951.3818
Epoch 900/2000, Loss: 2416.7649
Epoch 950/2000, Loss: 2646.6418
Epoch 1000/2000, Loss: 2861.0142
Epoch 1050/2000, Loss: 2381.5173
Epoch 1100/2000, Loss: 2225.2947
Epoch 1150/2000, Loss: 2028.1149
Epoch 1200/2000, Loss: 2255.2869
Epoch 1250/2000, Loss: 2211.0710
Epoch 1300/2000, Loss: 2136.7700
Epoch 1350/2000, Loss: 2190.7129
Epoch 1400/2000, Loss: 2576.8721
Epoch 1450/2000, Loss: 1841.8878
Epoch 1500/2000, Loss: 158

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▁▁▁▁▂▂▂▂▂▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇███
train/epoch_loss,█▅▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/loss,█▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/r2_pm_c_co,▁▄██████████████████████████████████████

0,1
epoch,1950.0
train/epoch_loss,1648.63562
train/loss,57.03786
val/r2_pm_c_co,0.85623
