In [2]:
import functools as ft
import pathlib
from collections.abc import Callable
from contextlib import contextmanager
from typing import Any, BinaryIO

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
from jaxtyping import PyTree
import equinox as eqx



class TreePathError(RuntimeError):
    """Error raised when applying a function to a PyTree leaf fails.

    `_ordered_tree_map` raises this in place of (and chained from) whatever
    exception the mapped function raised, so that the failing leaf can be located.
    """
    # Key path of the leaf at which the failure happened. Nothing in this class
    # sets it: the code raising the error assigns it after construction.
    path: tuple


def _ordered_tree_map(
    f: Callable[..., Any],
    tree: Any,
    *rest: Any,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Any:
    """Like jax.tree_util.tree_map, but guaranteed to iterate over the tree
    in fixed order. (Namely depth-first left-to-right.)
    """
    # Discussion: https://github.com/patrick-kidger/equinox/issues/136
    paths_and_leaves, treedef = jtu.tree_flatten_with_path(tree, is_leaf)
    all_leaves = list(zip(*paths_and_leaves)) + [treedef.flatten_up_to(r) for r in rest]

    @ft.wraps(f)
    def _f(path, *xs):
        try:
            return f(*xs)
        except TreePathError as e:
            combo_path = path + e.path
            exc = TreePathError(f"Error at leaf with path {combo_path}")
            exc.path = combo_path
            raise exc from e
        except Exception as e:
            exc = TreePathError(f"Error at leaf with path {path}")
            exc.path = path
            raise exc from e

    return treedef.unflatten(_f(*xs) for xs in zip(*all_leaves))


def _with_suffix(path):
    path = pathlib.Path(path)
    if path.suffix == "":
        return path.with_suffix(".eqx")
    else:
        return path
    
def default_deserialise_filter_spec(f: BinaryIO, x: Any) -> Any:
    """Default filter specification for deserialising saved data.

    **Arguments**

    -   `f`: file-like object
    -   `x`: The leaf for which the data needs to be loaded.

    **Returns**

    The new value for datatype `x`:

    -   JAX arrays and `jax.ShapeDtypeStruct`s are read with `jnp.load`.
    -   NumPy arrays are read with `np.load`.
    -   Any other array-like leaf (as judged by `eqx.is_array_like`) is read with
        `np.load` and, if exactly one element was stored, converted back to
        `type(x)`. If the stored array does not hold exactly one element it is
        returned as-is rather than being dropped.
    -   Every other leaf is returned unchanged, without reading from `f`.

    !!! info

        This function can be extended to customise the deserialisation behaviour for
        leaves.

    !!! example

        Skipping loading of jax.Array.

        ```python
        import jax
        import jax.numpy as jnp
        import equinox as eqx

        tree = (jnp.array([4,5,6]), [1,2,3])
        new_filter_spec = lambda f,x: (
            x if isinstance(x, jax.Array) else eqx.default_deserialise_filter_spec(f, x)
        )
        new_tree = eqx.tree_deserialise_leaves("some_filename.eqx", tree, filter_spec=new_filter_spec)
        ```
    """  # noqa: E501
    if isinstance(x, (jax.Array, jax.ShapeDtypeStruct)):
        return jnp.load(f)
    if isinstance(x, np.ndarray):
        # Important to use `np` here to avoid promoting NumPy arrays to JAX.
        return np.load(f)
    if eqx.is_array_like(x):
        # np.generic gets deserialised directly as an array, so convert back to a scalar
        # type here.
        # See also https://github.com/google/jax/issues/17858
        out = np.load(f)
        if isinstance(x, jax.dtypes.bfloat16):
            out = out.view(jax.dtypes.bfloat16)
        if np.size(out) == 1:
            return type(x)(out.item())
        # The record on disk is not a scalar. Hand the loaded array back instead of
        # implicitly returning `None`, so the data is not silently lost and any
        # later type check can report what was actually stored.
        return out
    return x


@contextmanager
def _maybe_open(path_or_file: str | pathlib.Path | BinaryIO, mode: str):
    """A function that unifies handling of file objects and path-like objects
    by opening the latter."""
    if isinstance(path_or_file, (str, pathlib.Path)):
        file = open(_with_suffix(path_or_file), mode)
        try:
            yield file
        finally:
            file.close()
    else:  # file-like object
        yield path_or_file


def _assert_same(array_impl_type):
    def _assert_same_impl(path, new, old):
        typenew = type(new)
        typeold = type(old)
        if typeold is jax.ShapeDtypeStruct:
            typeold = array_impl_type
        if typenew is not typeold:
            raise RuntimeError(
                f"Deserialised leaf at path '{jtu.keystr(path)}' has changed type from "
                f"{type(old)} in `like` to {type(new)} on disk."
            )
        if isinstance(new, (np.ndarray, jax.Array)):
            if new.shape != old.shape:
                raise RuntimeError(
                    f"Deserialised leaf at path {path} has changed shape from "
                    f"{old.shape} in `like` to {new.shape} on disk."
                )
            if new.dtype != old.dtype:
                raise RuntimeError(
                    f"Deserialised leaf at path {path} has changed dtype from "
                    f"{old.dtype} in `like` to {new.dtype} on disk."
                )

    return _assert_same_impl

def tree_deserialise_leaves(
    path_or_file: str | pathlib.Path | BinaryIO,
    like: PyTree,
    filter_spec=default_deserialise_filter_spec,
    is_leaf: Callable[[Any], bool] | None = None,
) -> PyTree:
    """Load the leaves of a PyTree from a file and validate them against `like`.

    **Arguments:**

    - `path_or_file`: The file location to load values from or a binary file-like
        object.
    - `like`: A PyTree of same structure, and with leaves of the same type, as the
        PyTree being loaded. Those leaves which are loaded will replace the
        corresponding leaves of `like`.
    - `filter_spec`: Specifies how to load each kind of leaf. By default all JAX
        arrays, NumPy arrays, Python bool/int/float/complexes are loaded, and
        all other leaf types are not loaded, and will retain their
        value from `like`. (See [`equinox.default_deserialise_filter_spec`][].)
    - `is_leaf`: Called on every node of `like`; if `True` then this node will be
        treated as a leaf.

    **Returns:**

    The loaded PyTree, formed by iterating over `like` and replacing some of its leaves
    with the leaves saved in `path_or_file`.

    **Raises:**

    `RuntimeError` if a loaded leaf has a different type than the corresponding
    leaf of `like` or, for arrays, a different shape or dtype.

    !!! example

        This can be used to load a model from file.

        ```python
        import equinox as eqx
        import jax.random as jr

        model_original = eqx.nn.MLP(2, 2, 2, 2, key=jr.PRNGKey(0))
        eqx.tree_serialise_leaves("some_filename.eqx", model_original)
        model_loaded = eqx.tree_deserialise_leaves("some_filename.eqx", model_original)

        # To partially load weights, do model surgery. In this case load everything
        # except the final layer.
        model_partial = eqx.tree_at(lambda mlp: mlp.layers[-1], model_loaded, model_original)
        ```

    !!! example

        A common pattern is the following:

        ```python
        def run(..., load_path=None):
            if load_path is None:
                model = Model(...hyperparameters...)
            else:
                model = eqx.filter_eval_shape(Model, ...hyperparameters...)
                model = eqx.tree_deserialise_leaves(load_path, model)
        ```
        in which either a model is created directly (e.g. at the start of training), or
        a suitable `like` is constructed (e.g. when resuming training), where
        [`equinox.filter_eval_shape`][] is used to avoid creating spurious short-lived
        arrays taking up memory.

    !!! info

        `filter_spec` should typically be a function `(File, Any) -> Any`, which takes
        a file handle and a leaf from `like`, and either returns the corresponding
        loaded leaf, or returns the leaf from `like` unchanged.

        It can also be a PyTree of such functions, in which case the PyTree structure
        should be a prefix of `like`, and each function will be mapped over the
        corresponding sub-PyTree of `like`.
    """  # noqa: E501

    def _load_subtree(spec, subtree):
        # `f` is the file handle bound by the `with` statement below; it is only
        # looked up when this closure is called, i.e. while the file is open.
        def _load_leaf(leaf):
            return spec(f, leaf)

        return _ordered_tree_map(_load_leaf, subtree, is_leaf=is_leaf)

    with _maybe_open(path_or_file, "rb") as f:
        # Outer map: pair every function in `filter_spec` with its sub-tree of
        # `like`. Inner map: read that sub-tree's leaves in order.
        loaded = _ordered_tree_map(_load_subtree, filter_spec, like)

    with jax.ensure_compile_time_eval():
        # ArrayImpl isn't a public type, so this is how we get access to it instead.
        # `ensure_compile_time_eval` just in case someone is doing deserialisation
        # inside JIT. Which would be weird, but still.
        array_impl_type = type(jnp.array(0))
    leaf_checker = _assert_same(array_impl_type)
    jtu.tree_map_with_path(leaf_checker, loaded, like, is_leaf=is_leaf)
    return loaded


def file_to_tree(
    path_or_file: str | pathlib.Path | BinaryIO,
    like: PyTree,
    filter_spec=default_deserialise_filter_spec,
    is_leaf: Callable[[Any], bool] | None = None,
) -> PyTree:
    """Load the leaves of a PyTree from a file without validating them.

    Performs the same reading steps as `tree_deserialise_leaves`, but does not
    compare the loaded leaves against `like` afterwards. Whatever `filter_spec`
    returns for a leaf ends up in the result.

    **Arguments:**

    - `path_or_file`: The file location to load values from or a binary file-like
        object.
    - `like`: A PyTree whose structure determines which leaves are read and in which
        order.
    - `filter_spec`: A function `(file, leaf) -> new_leaf`, or a PyTree of such
        functions that is mapped against `like`.
    - `is_leaf`: Optional predicate; nodes of `like` for which it returns `True` are
        treated as leaves.

    **Returns:**

    A PyTree holding the value returned by `filter_spec` for every leaf of `like`.
    """  # noqa: E501

    def _load_subtree(spec, subtree):
        # `f` is bound by the `with` statement below and resolved at call time.
        return _ordered_tree_map(lambda leaf: spec(f, leaf), subtree, is_leaf=is_leaf)

    with _maybe_open(path_or_file, "rb") as f:
        return _ordered_tree_map(_load_subtree, filter_spec, like)

In [3]:
import wandb
import equinox as eqx
import os 

# Foundational SSM imports
from omegaconf import OmegaConf
import tempfile 
from foundational_ssm.models import SSMDownstreamDecoder, SSMFoundationalDecoder
from foundational_ssm.utils import h5_to_dict
from foundational_ssm.transform import smooth_spikes
import jax
import jax.numpy as jnp
import numpy as np
from typing import Any, BinaryIO


def load_model_and_state_from_checkpoint_wandb(artifact_full_name, model_cls=SSMFoundationalDecoder, model_cfg=None):
    """Load a model and its state from a W&B checkpoint artifact.

    **Arguments:**

    - `artifact_full_name`: Name of the W&B artifact, passed to `wandb.Api().artifact`
        with `type="checkpoint"`.
    - `model_cls`: Model class used to build the templates the checkpoint is loaded
        into.
    - `model_cfg`: Keyword arguments for `model_cls`. If `None`, the `model` entry of
        the config of the run that logged the artifact is used.

    **Returns:**

    A tuple `(model, state, meta)`: the model loaded from `model.ckpt`, the state
    loaded from `state.ckpt`, and the artifact's `metadata`.
    NOTE(review): `meta` presumably holds training bookkeeping such as epoch/step;
    its contents are not visible here, confirm against the code logging the artifact.

    **Raises:**

    `FileNotFoundError` if the artifact cannot be fetched (chained from the
    underlying W&B error).
    """
    api = wandb.Api()
    try:
        artifact = api.artifact(artifact_full_name, type="checkpoint")
    except Exception as e:
        # Chain the original error so that its real cause is not hidden behind the
        # generic "not found" message.
        raise FileNotFoundError(f"Could not find checkpoint artifact: {artifact_full_name}") from e

    if model_cfg is None:
        run = artifact.logged_by()
        run_cfg = OmegaConf.create(run.config)
        print(run_cfg)
        model_cfg = OmegaConf.create(run_cfg.model)

    # Build model and state only as templates; their leaves are replaced below.
    model_template, state_template = eqx.nn.make_with_state(model_cls)(
        **model_cfg
    )
    model_template = eqx.nn.inference_mode(model_template, False)

    # The checkpoint files are only needed while loading, so download them into a
    # directory that is removed again afterwards.
    with tempfile.TemporaryDirectory() as temp_dir:
        artifact.download(temp_dir)
        model = eqx.tree_deserialise_leaves(os.path.join(temp_dir, "model.ckpt"), model_template, default_deserialise_filter_spec)
        state = eqx.tree_deserialise_leaves(os.path.join(temp_dir, "state.ckpt"), state_template, default_deserialise_filter_spec)

    meta = artifact.metadata
    return model, state, meta

In [4]:
# Fetch the "best" pretraining checkpoint from W&B and load its model leaves into a
# freshly built template with `file_to_tree`, i.e. without the type/shape/dtype
# validation done by `tree_deserialise_leaves`.
artifact_full_name = "melinajingting-ucl/foundational_ssm_pretrain/l4_reaching_normalized_checkpoint:best"

api = wandb.Api()
try:
    artifact = api.artifact(artifact_full_name, type="checkpoint")
except Exception as e:
    # Chain the original error so that its real cause stays visible.
    raise FileNotFoundError(f"Could not find checkpoint artifact: {artifact_full_name}") from e

# Rebuild the model with the configuration of the run that logged the artifact.
run = artifact.logged_by()
run_cfg = OmegaConf.create(run.config)
print(run_cfg)
model_cfg = OmegaConf.create(run_cfg.model)

model_template, state_template = eqx.nn.make_with_state(SSMFoundationalDecoder)(
    **model_cfg
)
model_template = eqx.nn.inference_mode(model_template, False)

with tempfile.TemporaryDirectory() as temp_dir:
    artifact.download(temp_dir)
    out = file_to_tree(os.path.join(temp_dir, "model.ckpt"), model_template)

{'model': {'dt_max': 0.01, 'dt_min': 0.001, 'ssm_dim': 128, 'rng_seed': 42, 'dropout_p': 0.01, 'output_dim': 2, 'ssm_io_dim': 512, 'ssm_num_layers': 4, 'ssm_init_diag_blocks': 4}, 'wandb': {'tags': ['neural', 'behavior', 'masking'], 'entity': 'melinajingting-ucl', 'project': 'foundational_ssm_pretrain', 'resume_run_id': None, 'run_name_postfix': '_normalized'}, 'rng_seed': 42, 'training': {'epochs': 1001, 'log_val_every': 50, 'checkpoint_every': 1}, 'model_cfg': 'configs/model/l4.yaml', 'optimizer': {'lr': 0.001, 'mode': 'all', 'weight_decay': 0.01}, 'val_loader': {'sampler': 'SequentialFixedWindowSampler', 'dataset_args': {'lazy': True, 'split': 'val', 'keep_files_open': False}, 'sampler_args': {'drop_short': False, 'window_length': 3.28, 'min_window_length': 0.88}, 'sampling_rate': 200, 'dataloader_args': {'batch_size': 1024, 'num_workers': 0, 'persistent_workers': False}}, 'dataset_cfg': 'configs/dataset/reaching.yaml', 'train_loader': {'sampler': 'RandomFixedWindowSampler', 'datase

[34m[1mwandb[0m: Downloading large artifact l4_reaching_normalized_checkpoint:best, 66.53MB. 3 files... 
[34m[1mwandb[0m:   3 of 3 files downloaded.  
Done. 0:0:0.4 (181.4MB/s)


In [7]:
out

SSMFoundationalDecoder(
  context_embedding=Embedding(
    num_embeddings=12, embedding_size=4, weight=f32[10,4]
  ),
  encoders=[
    Linear(
      weight=f32[508,625],
      bias=f32[508],
      in_features=625,
      out_features=508,
      use_bias=True
    ),
    Linear(
      weight=f32[508,625],
      bias=f32[508],
      in_features=625,
      out_features=508,
      use_bias=True
    ),
    Linear(
      weight=f32[508,625],
      bias=f32[508],
      in_features=625,
      out_features=508,
      use_bias=True
    ),
    Linear(
      weight=f32[508,625],
      bias=f32[508],
      in_features=625,
      out_features=508,
      use_bias=True
    ),
    Linear(
      weight=f32[508,625],
      bias=f32[508],
      in_features=625,
      out_features=508,
      use_bias=True
    ),
    Linear(
      weight=f32[508,625],
      bias=f32[508],
      in_features=625,
      out_features=508,
      use_bias=True
    ),
    Linear(
      weight=f32[508,625],
      bias=f32[508],
     