# Introduction to Checkpointing with Orbax

The Orbax library provides multiple loosely related packages geared towards JAX
model persistence; **checkpointing** is a core Orbax component.
You can install the checkpointing package with:

```
pip install orbax-checkpoint
```

Be sure to check out our [PyPI page](https://pypi.org/project/orbax-checkpoint/)
and [GitHub page](https://github.com/google/orbax) for more information.

This tutorial (and others in the Orbax documentation) generally assumes a basic level of familiarity with the [JAX](https://docs.jax.dev/en/latest/index.html) library.

Now, let's get started with some usage examples. First, we need to set up a
simple PyTree containing JAX arrays. This represents our JAX model.

In [None]:
### Setup ###
import itertools
from etils import epath
import jax
import numpy as np

directory = epath.Path('/tmp/my-checkpoints')
pytree = {
    'a': np.arange(64).reshape((8, 8)),
    'b': np.arange(16),
    'c': np.asarray(4.5),
}
mesh = jax.sharding.Mesh(jax.devices(), ('x',))
shardings = {
    'a': jax.sharding.NamedSharding(
        mesh, jax.sharding.PartitionSpec('x', None)
    ),
    'b': jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()),
    'c': jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()),
}
pytree = jax.tree.map(
    lambda arr, sharding: jax.make_array_from_callback(
        arr.shape,
        sharding,
        lambda idx: arr[idx],
    ),
    pytree,
    shardings,
)

_checkpoint_name = itertools.count()


def next_checkpoint_name() -> str:
  return f'ckpt{next(_checkpoint_name)}'
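
To see what this kind of setup produces, the sketch below builds a single sharded array in the same way and inspects its sharding and per-device shards. It assumes the leading dimension (8) is divisible by the number of available devices; on a single device there is exactly one shard holding the whole array. The names `host_array` and `sharded` are illustrative only.

```python
import jax
import numpy as np

# One-dimensional mesh over all available devices, as in the setup above.
mesh = jax.sharding.Mesh(jax.devices(), ('x',))
sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec('x', None)
)

# int32 keeps the example independent of whether 64-bit mode is enabled.
host_array = np.arange(64, dtype=np.int32).reshape((8, 8))
# Each device asks the callback for the slice of the global array it owns.
sharded = jax.make_array_from_callback(
    host_array.shape, sharding, lambda idx: host_array[idx]
)

print(sharded.sharding)
for shard in sharded.addressable_shards:
  # `shard.index` is the tuple of slices into the global array.
  print(shard.device, shard.index, shard.data.shape)
```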

## Reading and Writing

First, import the checkpointing package. For v1, it is crucial to use the exact import statement shown below; an incorrect import can lead to errors or unexpected behavior.

In [None]:
from orbax.checkpoint import v1 as ocp

Using the tree of `jax.Array` created above, let's save a checkpoint.

In [None]:
checkpoint_name = next_checkpoint_name()
ocp.save_pytree(directory / checkpoint_name, pytree)

Loading yields the original PyTree of arrays.

In [None]:
ocp.load_pytree(directory / checkpoint_name)
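
One way to confirm a round trip is to compare the loaded tree against the original, leaf by leaf. The helper below is a minimal sketch: `trees_equal` is a hypothetical name, not part of Orbax, and it is demonstrated on in-memory trees so that it runs without a checkpoint on disk.

```python
import jax
import numpy as np


def trees_equal(expected, actual) -> bool:
  """Returns True if both trees share a structure and have equal leaves."""
  if jax.tree.structure(expected) != jax.tree.structure(actual):
    return False
  return all(
      np.array_equal(np.asarray(e), np.asarray(a))
      for e, a in zip(jax.tree.leaves(expected), jax.tree.leaves(actual))
  )


original = {'a': np.arange(4), 'b': np.asarray(4.5)}
same = {'a': np.arange(4), 'b': np.asarray(4.5)}
changed = {'a': np.arange(4) + 1, 'b': np.asarray(4.5)}

print(trees_equal(original, same))     # True
print(trees_equal(original, changed))  # False
```

With the checkpoint saved earlier, the same check would read `trees_equal(pytree, ocp.load_pytree(directory / checkpoint_name))`.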

We can inspect the tree structure and array properties using `pytree_metadata`.

In [None]:
ocp.pytree_metadata(directory / checkpoint_name).metadata

Note that we are accessing the `metadata` property: `pytree_metadata(...).metadata`. It holds the metadata specific to the PyTree itself. Other properties, such as timestamps, apply to the checkpoint as a whole.
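
For comparison, a similar shape-and-dtype summary can be built for an in-memory tree with `jax.ShapeDtypeStruct`. This sketch only mirrors the kind of information involved; it does not reproduce the object that `pytree_metadata` returns, and the `tree` below is a stand-in for the one defined in the setup.

```python
import jax
import numpy as np

tree = {
    'a': np.arange(64, dtype=np.int32).reshape((8, 8)),
    'b': np.arange(16, dtype=np.int32),
    'c': np.asarray(4.5, dtype=np.float32),
}

# Replace every leaf with a lightweight description of its shape and dtype.
abstract_tree = jax.tree.map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), tree
)
for name, leaf in abstract_tree.items():
  print(name, leaf.shape, leaf.dtype)
```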

Be sure to check out the additional documentation on **Working with PyTrees**.

## Checkpointing in a Training Loop

When training an ML model, checkpoints are commonly used to record progress for later recovery in case of failure, to perform evaluations, or to distribute the model to downstream consumers after the experiment completes. Typically, a checkpoint is saved every `n` steps.

In [None]:
@jax.jit
def train_step(state):
  """Fake train step. This applies a function to `state` in some way."""
  return jax.tree.map(lambda x: x + 1, state)


def initialize_state():
  """Initializes the state, typically given some random number generator."""
  return {'step': 0, **pytree}


def init_or_restore(
    source_checkpoint_path: str | epath.Path | None,
):
  # If provided, restore initial checkpoint (e.g. for fine-tuning).
  # This can be referred to as a "source" checkpoint. Note the distinction drawn
  # between this "source checkpoint" and the "latest checkpoint". The source
  # checkpoint comes from a different experiment entirely, and is just used
  # to initialize the current experiment. The latest checkpoint comes from this
  # experiment, and allows us to resume after interruption.
  if source_checkpoint_path:
    return ocp.load_pytree(source_checkpoint_path)
  # Otherwise, init from scratch
  else:
    return initialize_state()

In [None]:
def train():
  total_steps = 10
  with ocp.training.Checkpointer(directory / 'experiment') as ckptr:
    # If checkpoints exist in the root directory, we are recovering after a
    # restart, and should resume from the latest checkpoint.
    # Otherwise, init from scratch or load the source checkpoint.
    if ckptr.latest is None:
      train_state = init_or_restore(directory / checkpoint_name)
      start_step = 0
    else:
      train_state = ckptr.load_pytree()
      # The latest step is already saved, so resume from the one after it.
      start_step = ckptr.latest.step + 1

    for step in range(start_step, total_steps):
      train_state = train_step(train_state)
      ckptr.save_pytree(step, train_state)

In [None]:
train()

In [None]:
!ls {directory / 'experiment'}

To summarize, a typical training workflow (from a checkpoint-focused perspective) consists of the following steps:

*   Identify the latest checkpoint, if any.
*   If no latest checkpoint is found:
    *   Restore from the source checkpoint, if provided; or
    *   Initialize the model from scratch.
*   If a latest checkpoint is found, restore it and resume training from the step after it.
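
This decision flow can be sketched without Orbax at all. The example below is a toy stand-in that stores each saved step as a JSON file; `latest_step`, `run`, and the `step_<n>.json` naming are hypothetical and do not reflect how Orbax lays out checkpoints. It also saves only every `save_interval` steps, matching the "every `n` steps" pattern described earlier.

```python
import json
import pathlib
import tempfile
from typing import Optional


def latest_step(root: pathlib.Path) -> Optional[int]:
  """Returns the highest saved step under `root`, or None if there is none."""
  steps = [int(p.stem.split('_')[1]) for p in root.glob('step_*.json')]
  return max(steps, default=None)


def run(root, total_steps, save_interval, source_state=None):
  """Trains up to `total_steps`, resuming from `root` if possible."""
  root.mkdir(parents=True, exist_ok=True)
  latest = latest_step(root)
  if latest is None:
    # New experiment: use the source state if given, else start from scratch.
    state = dict(source_state) if source_state is not None else {'value': 0}
    start_step = 0
  else:
    # Restart: pick up right after the last step that was saved.
    state = json.loads((root / f'step_{latest}.json').read_text())
    start_step = latest + 1

  for step in range(start_step, total_steps):
    state = {'value': state['value'] + 1}  # Stand-in for a real train step.
    if (step + 1) % save_interval == 0:
      (root / f'step_{step}.json').write_text(json.dumps(state))
  return state


root = pathlib.Path(tempfile.mkdtemp()) / 'experiment'
first = run(root, total_steps=6, save_interval=2)    # Covers steps 0-5.
second = run(root, total_steps=10, save_interval=2)  # Resumes at step 6.
print(first, second, latest_step(root))
```

Because the second call finds `step_5.json`, it restores that state and continues from step 6 instead of starting over.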



## What's Next?

So far, we have seen some simple and common patterns of Orbax usage. This represents just the tip of the checkpointing iceberg. We encourage the reader to explore additional topics.

PyTrees of arrays are a fundamental representation of ML models in JAX. **Working with PyTrees** examines PyTree checkpointing in greater detail, showing how to reshard, cast, and manipulate other array properties. It also demonstrates multiple mechanisms for partially restoring a PyTree. Further advanced options for saving and restoring PyTrees and arrays are also shown.

Compute efficiency is crucial for training ML models. **Async checkpointing** shows how to save and load in a background thread, minimizing the performance impact of checkpointing on the training job.
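
The general idea can be illustrated with the standard library alone. In this sketch, which is not Orbax's implementation, the training thread takes a snapshot of the state and hands the slow file write to a background thread. The names `save_async` and `_write`, and the use of pickle as a file format, are hypothetical stand-ins.

```python
import concurrent.futures
import copy
import pathlib
import pickle
import tempfile


def _write(path: pathlib.Path, snapshot) -> pathlib.Path:
  """Blocking write; runs on the background thread."""
  tmp = path.with_suffix('.tmp')
  tmp.write_bytes(pickle.dumps(snapshot))
  tmp.replace(path)  # Publish the file only once it is fully written.
  return path


def save_async(executor, path, state) -> concurrent.futures.Future:
  # Copy on the calling thread so later updates to `state` cannot leak into
  # the checkpoint while it is being written.
  snapshot = copy.deepcopy(state)
  return executor.submit(_write, path, snapshot)


tmp_dir = pathlib.Path(tempfile.mkdtemp())
state = {'step': 0, 'weights': [0.0, 0.0]}
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
  future = save_async(executor, tmp_dir / 'ckpt.pkl', state)
  state['step'] += 1  # Training continues while the write is in flight.
  saved_path = future.result()  # Block only when the result is needed.

restored = pickle.loads(saved_path.read_bytes())
print(restored)
```

The snapshot is the only part that blocks the caller, which is why the checkpoint still records step 0 even though training has already moved on.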

PyTrees of arrays are not the only type of object that needs to be checkpointed. Orbax introduces the concept of a **Checkpointable** to represent other objects, like dataset iterators or special metadata, that must be saved alongside the main model. Mechanisms for supporting user-customized objects are also shown.

The step-based training loop is a common concept across many ML workflows. We expand on the **training module** provided by Orbax, which offers `Checkpointer` as the primary entry point.

Interacting directly with the **file format** of the checkpoint on disk is useful in a variety of circumstances. We provide details on the file format, contributing to a deeper grasp of Orbax concepts, debugging strategies, and advanced options.