# Load trained models and generated trajectories from Wandb
You will need to: 
1. Enroll in Wandb project `foundational_ssm_nlb`
2. Check runs in the project for all trained models. 
3. Download the original dataset from huggingface `https://huggingface.co/datasets/MelinaLaimon/nlb_processed/tree/main`, 

In [2]:
import wandb 
from foundational_ssm.models import SSMFoundational
import matplotlib.pyplot as plt
import torch
import os
import h5py
import pandas as pd
import numpy as np
from foundational_ssm.utils import load_model_wandb

# Load Model

In [16]:
api = wandb.Api()
wandb_account = "melinajingting-ucl"
project = "foundational_ssm_pretrain_decoding"
run_name = "perich_miller_population_2018_l3_d64" # Change this to the desired run name
version = "v3"

model_artifact_full_name = f"{wandb_account}/{project}/{run_name}_best_model:{version}"
model_artifact = api.artifact(model_artifact_full_name, type="model")
model_artifact_dir = model_artifact.download()

model_filename = os.path.join(model_artifact_dir, 'best_model.pt')

[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [1]:
import json
from foundational_ssm.models import SSMFoundational
import equinox as eqx
import jax.random as jr

with open(model_filename, "rb") as f:
    hyperparams = json.loads(f.readline().decode())
    model = SSMFoundational(**hyperparams, key=jr.PRNGKey(0))
    model = eqx.tree_deserialise_leaves(f, model)

ModuleNotFoundError: No module named 'foundational_ssm.models.s4d'

In [5]:
import optax
from foundational_ssm.utils.wandb_utils_jax import load_model_and_state_wandb
from foundational_ssm.utils.training import get_filter_spec
import equinox as eqx

model, state = load_model_and_state_wandb(wandb_pretrained_model_id="melinajingting-ucl/foundational_ssm_pretrain_decoding/train_batch-1024_sub-cmtj_l1_d128_best_model:v3")
opt = optax.adamw(learning_rate=0.001, weight_decay=0.001)
filter_spec = get_filter_spec(
        model,
        freeze_ssm=False,
        freeze_mlp=False
    )
opt_state = opt.init(eqx.filter(model, filter_spec))

[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [None]:
import json
import os

def save_checkpoint_wandb(model, state, opt_state, epoch, step, run_name):
    """Save model, optimizer state, epoch, and step to a checkpoint file."""
    with open('checkpoint.ckpt', 'wb') as f:
        # Write metadata as JSON in the first line
        meta = json.dumps({'epoch': epoch, 'step': step})
        f.write((meta + '\n').encode())
        eqx.tree_serialise_leaves(f, model)
        eqx.tree_serialise_leaves(f, state)
        eqx.tree_serialise_leaves(f, opt_state)
    artifact = wandb.Artifact(
        name=f'{run_name}_checkpoint',  # Name for the artifact
        type="checkpoint",                # Artifact type (can be "model", "checkpoint", etc.)
        description=f"Checkpoint at epoch {epoch}",
    )
    wandb.log_artifact(artifact)
    

def load_checkpoint_wandb(path, model_template, state_template, opt_state_template, wandb_run_name, wandb_project, wandb_entity):
    """Load model, optimizer state, epoch, and step from a checkpoint file."""
    api = wandb.Api()
    artifact_full_name = f"{wandb_entity}/{wandb_project}/{wandb_run_name}_checkpoint:latest"
    artifact_save_path = f"{wandb_run_name}"
    artifact = api.artifact(artifact_full_name, type="checkpoint")
    dir = artifact.download(artifact_save_path)
    path = os.path.join(dir, 'checkpoint.ckpt')
    with open(path, 'rb') as f:
        meta = json.loads(f.readline().decode())
        model = eqx.tree_deserialise_leaves(f, model_template)
        state = eqx.tree_deserialise_leaves(f, state_template)
        opt_state = eqx.tree_deserialise_leaves(f, opt_state_template)
    return model, state, opt_state, meta['epoch'], meta['step'], meta



In [11]:
import wandb

# Set your entity, project, and run ID
entity = "melinajingting-ucl"
project = "foundational_ssm_pretrain_decoding"
run_id = "cr6zuzfw"  # The run you want to attach the artifact to
run_name = 'train_batch-1024_sub-cmtj_l4_d64'
epoch = 60
step = 19124

# Path to your checkpoint file
ckpt_path = "checkpoint.ckpt"  # or "best_model" if that's your file

# Initialize existing run to upload checkpoint
# api = wandb.Api()
# run = api.run(f"{entity}/{project}/{run_id}")
# wandb.init(entity=entity, project=project, id=run_id, resume="allow")
# save_checkpoint_wandb(model, state, opt_state, epoch, step, run_name)
# wandb.finish()

model, state, opt_state, epoch, step, meta = load_checkpoint_wandb(path=ckpt_path, model_template=model, state_template=state, opt_state_template=opt_state, wandb_run_name=run_name, wandb_project='foundational_ssm_pretrain_decoding', wandb_entity='melinajingting-ucl')

[34m[1mwandb[0m:   1 of 1 files downloaded.  


RuntimeError: Deserialised leaf at path (GetAttrKey(name='encoders'), SequenceKey(idx=0), GetAttrKey(name='weight')) has changed shape from (120, 353) in `like` to (56, 353) on disk.