In [1]:
from pathlib import Path
from typing import Any

import equinox as eqx
import hydra
import jax
import matplotlib.pyplot as plt
import orbax.checkpoint as ocp
from jaxtyping import PyTree
from omegaconf import DictConfig, OmegaConf


# Double precision has to be switched on before any JAX arrays are created,
# so it is configured immediately after the imports.
jax.config.update("jax_enable_x64", True)

# Dotted style name, presumably shipped by a `matplotlib_utils` package — verify.
plt.style.use("matplotlib_utils.styles.dash_gridded")

# Data directory, resolved relative to the current working directory.
datadir = Path("..", "..", "data")

In [2]:
# Hydra-style description of the model: each `_target_` names the callable that
# `hydra.utils.instantiate` invokes, and the sibling keys become its arguments.
config_dict = dict(
    _target_="dynamics_discovery.models.NeuralODE",
    dim=3,
    width=32,
    depth=3,
    activation=dict(_target_="hydra.utils.get_method", path="jax.nn.gelu"),
    solver=dict(_target_="diffrax.Tsit5"),
    rtol=1e-4,
    atol=1e-6,
    dt0=None,
    key=0,
)

config = OmegaConf.create(config_dict)
model = hydra.utils.instantiate(config)


def change_model(tree: PyTree[Any, " T"], scale: float = 3) -> PyTree[Any, " T"]:
    """Return a copy of `tree` whose inexact (floating/complex) arrays are scaled.

    Leaves selected by `eqx.is_inexact_array` are multiplied by `scale`; every
    other leaf is carried over unchanged via `eqx.combine`.

    tree: Any PyTree, e.g. an Equinox module.
    scale: Multiplicative factor applied to the inexact-array leaves (default 3).

    Returns a PyTree with the same structure as `tree`.
    """
    arrays, rest = eqx.partition(tree, eqx.is_inexact_array)
    scaled = jax.tree.map(lambda leaf: leaf * scale, arrays)
    return eqx.combine(scaled, rest)


# Rescale the model's floating-point arrays so the saved model differs from the
# freshly initialised one; the comparison below should display False.
model_new = change_model(tree=model)
eqx.tree_equal(model, model_new)

Array(False, dtype=bool)

In [None]:
def save_model(
    model: PyTree,
    config: dict | DictConfig,
    savedir: str | Path,
    step_number: int = 0,
    options: ocp.CheckpointManagerOptions | None = None,
) -> None:
    """Save the array leaves of `model` as an Orbax checkpoint.

    The config is stored as the checkpoint manager's custom metadata so that the
    model skeleton can later be rebuilt with `hydra.utils.instantiate` before the
    arrays are restored into it.

    model: PyTree (e.g. an Equinox module); only leaves passing `eqx.is_array`
        are written.
    config: Config used to build `model`. A `DictConfig` is converted to a plain,
        resolved container before being stored as metadata.
    savedir: Checkpoint directory.
    step_number: Step index under which the checkpoint is stored.
    options: Orbax `CheckpointManagerOptions`; defaults are used when `None`.

    `CheckpointManager.save` returns `False` when Orbax decides not to save
    (e.g. a checkpoint for this or a later step already exists in `savedir`).
    That case is reported with a `UserWarning` instead of being silently ignored.
    """
    import warnings

    if isinstance(config, DictConfig):
        config = OmegaConf.to_container(config, resolve=True)
    if options is None:
        options = ocp.CheckpointManagerOptions()

    with ocp.CheckpointManager(
        Path(savedir).resolve(), options=options, metadata=config
    ) as mngr:
        saved = mngr.save(
            step_number, args=ocp.args.StandardSave(eqx.filter(model, eqx.is_array))
        )

    if not saved:
        # Without this, re-running the notebook keeps the old checkpoint and the
        # subsequent load would silently return stale weights.
        warnings.warn(
            f"Checkpoint for step {step_number} was not written to {savedir!s}; "
            "Orbax skipped the save (a checkpoint for this or a later step may "
            "already exist).",
            stacklevel=2,
        )


# Persist the rescaled model together with its config (default step 0).
save_model(model_new, config, savedir="./test")

In [4]:
def load_model(
    loaddir: str | Path,
    step_number: int | None = 0,
    options: ocp.CheckpointManagerOptions | None = None,
) -> PyTree:
    """Load a model from a checkpoint created by the `save_model` function.

    The model skeleton is rebuilt from the config stored in the checkpoint's
    custom metadata (via `hydra.utils.instantiate`), and its array leaves are then
    replaced by the arrays restored from disk.

    loaddir: Directory containing the checkpoints.
    step_number: Step of the checkpoint to restore; `None` restores the latest
        available step.
    options: Orbax `CheckpointManagerOptions`; defaults are used when `None`.

    Returns the rebuilt model carrying the restored weights.

    Raises `ValueError` if no config is stored in the checkpoint metadata, or if
    `step_number` is `None` and the directory holds no checkpoints.
    """
    options = ocp.CheckpointManagerOptions() if options is None else options

    with ocp.CheckpointManager(Path(loaddir).resolve(), options=options) as mngr_load:
        model_config = mngr_load.metadata().custom_metadata
        if not model_config:
            raise ValueError(
                f"No model config found in the checkpoint metadata of {loaddir!s}; "
                "was it written by `save_model`?"
            )
        if step_number is None:
            step_number = mngr_load.latest_step()
            if step_number is None:
                raise ValueError(f"No checkpoints found in {loaddir!s}.")

        model_backbone = hydra.utils.instantiate(OmegaConf.create(model_config))

        # Must use the same `eqx.is_array` filter as `save_model`: a broader
        # predicate such as `eqx.is_array_like` also selects Python scalar leaves,
        # which were never saved, so the restore target would not match on disk.
        weights_backbone, rest = eqx.partition(model_backbone, eqx.is_array)
        weights_load = mngr_load.restore(
            step_number,
            args=ocp.args.StandardRestore(weights_backbone),
        )
    return eqx.combine(weights_load, rest)


In [5]:
# Round-trip check: the model restored from disk should equal the one saved.
model_loaded = load_model("./test")
eqx.tree_equal(model_new, model_loaded)

Array(True, dtype=bool)