# ✅ Checkpointing

This example demonstrates how to save and restore a simple model with the [orbax](https://orbax.readthedocs.io/en/latest/index.html) library.

In [1]:
# `%pip` installs into the environment of the running kernel (`!pip` may target a different one).
# NOTE(review): versions are unpinned; the `CheckpointManager(checkpointers=...)` call below uses
# orbax's older dict-of-checkpointers API — consider pinning `orbax-checkpoint` to a release that supports it.
%pip install --quiet git+https://github.com/ASEM000/serket
%pip install --quiet orbax-checkpoint
%pip install --quiet optax

#### Basic usage

In [1]:
import os
import shutil

import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import optax
import orbax.checkpoint as ocp
import serket as sk

net = sk.Sequential(
    sk.nn.Linear(1, 128, key=jr.PRNGKey(0)),
    jax.nn.relu,
    sk.nn.Linear(128, 1, key=jr.PRNGKey(1)),
)

# exclude non-parameters (here the `jax.nn.relu` function) so only parameters are trained/saved
net = sk.tree_mask(net)

# checkpoint location, resolved to an absolute path (orbax's documentation uses absolute paths)
CKPT1_DIR = os.path.abspath("ckpt1")

# 1) get flat parameters and the tree structure
flat_net, treedef = jtu.tree_flatten(net)

# 2) define a checkpointer and save the parameters.
# `force=True` overwrites a checkpoint left behind by a previous run of this notebook;
# without it `save` raises because the destination directory already exists.
checkpointer = ocp.PyTreeCheckpointer()
checkpointer.save(CKPT1_DIR, flat_net, force=True)

# 3) load the flat parameters
restored_flat_net = checkpointer.restore(CKPT1_DIR)

# sanity check: the round trip must reproduce every saved leaf, in the same order,
# otherwise `tree_unflatten` below would silently build a scrambled network
assert len(restored_flat_net) == len(flat_net)
assert all(jnp.array_equal(saved, loaded) for saved, loaded in zip(flat_net, restored_flat_net))

# 4) reconstruct the tree using the loaded flat parameters and the tree structure
loaded_net = jtu.tree_unflatten(treedef, restored_flat_net)



#### Managing checkpoints

To save checkpoints over the course of training, `orbax` provides `CheckpointManager`, which accepts a set of options that configure how often checkpoints are saved and how many are kept.

For the full guide, see the [orbax documentation](https://orbax.readthedocs.io/en/latest/index.html).

In [2]:
# Start from an empty directory. On a re-run, checkpoints left by a previous run would be
# picked up by the manager, which then refuses to save any step that is not newer than its
# latest saved step — so nothing new would be written and stale checkpoints would remain.
CKPT2_DIR = os.path.abspath("ckpt2")
shutil.rmtree(CKPT2_DIR, ignore_errors=True)

manager = ocp.CheckpointManager(
    directory=CKPT2_DIR,
    # lets assume we want to save neural network parameters and optimizer state
    # then we need to define a checkpointers dict with the keys "net" and "state"
    checkpointers=dict(net=ocp.PyTreeCheckpointer(), state=ocp.PyTreeCheckpointer()),
    # save checkpoints every 2 steps and keep the last 3 checkpoints
    options=ocp.CheckpointManagerOptions(max_to_keep=3, save_interval_steps=2),
)

#### Define a train loop

In [3]:
def loss_func(net, x, y):
    """Mean squared error between the network's predictions on `x` and the targets `y`.

    `net` is the *masked* network; it is unmasked only inside the loss, so the masked
    form is what `train_step` differentiates and the optimizer updates.
    """
    model = sk.tree_unmask(net)
    predictions = jax.vmap(model)(x)
    squared_error = (predictions - y) ** 2
    return jnp.mean(squared_error)


LEARNING_RATE = 1e-3

optim = optax.adam(LEARNING_RATE)
# the optimizer state is built from the masked network that `train_step` updates
optim_state = optim.init(net)
# remember the state's tree structure so a flat list of restored leaves can be
# turned back into an optimizer state later
optim_state_treedef = jtu.tree_structure(optim_state)


@jax.jit
def train_step(net, optim_state: optax.OptState, x: jax.Array, y: jax.Array):
    """Apply one Adam update to `net`.

    Returns `(new_net, new_optim_state, loss)`, where `loss` is evaluated on the
    parameters *before* the update.
    """
    compute_loss_and_grads = jax.value_and_grad(loss_func)
    loss, grads = compute_loss_and_grads(net, x, y)
    updates, next_optim_state = optim.update(grads, optim_state)
    next_net = optax.apply_updates(net, updates)
    return next_net, next_optim_state, loss


# Use independent keys for the inputs and the noise. Drawing both from `PRNGKey(0)`
# reuses the same random bits, so the "noise" would not be independent of `x`.
data_key, noise_key = jr.split(jr.PRNGKey(0))
x = jr.uniform(data_key, (100, 1))
y = jnp.sin(x) + jr.normal(noise_key, (100, 1)) * 0.1

# with `save_interval_steps=2` the manager saves steps [0, 2, 4, 6, 8];
# with `max_to_keep=3` only the last three survive, namely [4, 6, 8]
for step in range(10):
    net, optim_state, loss = train_step(net, optim_state, x, y)
    # note that we save the *flat* parameters and the *flat* optimizer state; the tree
    # structures (`treedef`, `optim_state_treedef`) are kept in memory to rebuild them
    flat_net = jtu.tree_leaves(net)
    flat_optim_state = jtu.tree_leaves(optim_state)
    manager.save(step, dict(net=flat_net, state=flat_optim_state))

In [4]:
# check all the checkpoints
manager.all_steps()

[4, 6, 8]

In [None]:
# load checkpoint at step 6; the result is keyed by the item names used when saving
# ("net" and "state"), each holding the restored flat leaves
restored = manager.restore(6)

loaded_flat_net = restored["net"]
loaded_optim_flat_state = restored["state"]

# reconstruct the trees from the loaded leaves and the structures recorded earlier
loaded_net = jtu.tree_unflatten(treedef, loaded_flat_net)
loaded_optim_state = jtu.tree_unflatten(optim_state_treedef, loaded_optim_flat_state)

# the restored network is usable as-is: show the loss of the parameters saved at step 6
loss_func(loaded_net, x, y)