# Customizing Checkpointing Behavior

Orbax allows users to specify their own logic for saving and loading custom
objects. Customization can occur at two levels: the level of a
"checkpointable" and the level of a "PyTree leaf".

## Custom Checkpointables

First, ensure that you are familiar with the documentation on "checkpointables".
To recap, a "checkpointable" is a distinct unit of an entire checkpoint. For
example, the model state is a checkpointable distinct from the dataset iterator.
Embeddings, if used, may also be represented as a separate checkpointable.

Consider a toy example. Suppose that in addition to our PyTree state
(represented as a dictionary of arrays, containing the parameters and optimizer
state) and our dataset iterator (represented using PyGrain), we also have an
object called `Point`, which has integer properties `x` and `y`. (Since `Point`
is a dataclass, it would be easy to convert it to a PyTree and save it in the
same way as the primary model state. So this example is a bit contrived, but it
demonstrates the point well enough.)

Our `Point` class is defined as follows.

In [None]:
import dataclasses
import json
from typing import Any, Awaitable
import aiofiles
import jax
import numpy as np
import orbax.checkpoint.experimental.v1 as ocp


@dataclasses.dataclass
class Point:
  x: int
  y: int


model_state = {
    'params': np.arange(16),
    'opt_state': np.ones(16),
}

If we just try to save the `Point` (along with our other checkpointables), it
will fail because the object type is not recognized.

In [None]:
try:
  ocp.save_checkpointables(
      '/tmp/ckpt1',
      dict(model_state=model_state, point=Point(1, 2)),
  )
except Exception as e:  # Saving an unrecognized type raises an error.
  print(e)

There are two possible approaches for implementing support for `Point` in Orbax.
We will start with the simpler of the two.

### Implementing `Point` as a `StatefulCheckpointable`

The `Point` class must implement the methods of the `StatefulCheckpointable`
protocol. Specifically, we need to implement `save` and `load` methods so that
Orbax knows how to handle `Point` objects.

In [None]:
from __future__ import annotations

del Point


@dataclasses.dataclass
class Point(ocp.StatefulCheckpointable):

  x: int
  y: int

  async def save(
      self, directory: ocp.path.PathAwaitingCreation
  ) -> Awaitable[None]:
    return self._background_save(
        directory,
        # If the object could be modified by the main thread while being
        # written, it is important to make a copy to prevent race conditions.
        dataclasses.asdict(self),
    )

  async def load(self, directory: ocp.path.Path) -> Awaitable[None]:
    return self._background_load(directory)

  async def _background_save(
      self,
      directory: ocp.path.PathAwaitingCreation,
      value: dict[str, int],
  ):
    # In a multiprocess setting, prevent multiple processes from writing the
    # same thing.
    if jax.process_index() == 0:
      directory = await directory.await_creation()
      async with aiofiles.open(directory / 'point.txt', 'w') as f:
        contents = json.dumps(value)
        await f.write(contents)

  async def _background_load(
      self,
      directory: ocp.path.Path,
  ):
    async with aiofiles.open(directory / 'point.txt', 'r') as f:
      contents = json.loads(await f.read())
      self.x = contents['x']
      self.y = contents['y']

Let's break this down.

Both `save` and `load` consist of two phases: blocking and non-blocking.
Blocking operations must execute *now*, before control returns to the caller.
Non-blocking operations may run in a background thread; they are represented by
an `Awaitable` (here, a coroutine) that is returned to the caller without being
executed (yet).
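To make the two phases concrete, here is a minimal sketch in plain `asyncio`, independent of Orbax. The names `save`, `main`, and `events` are hypothetical; the sketch assumes only what the paragraph above describes, namely that `save` finishes its blocking work before returning and hands back a coroutine that has not run yet. Here the caller simply awaits that coroutine, whereas Orbax may run it in the background.

```python
import asyncio
from typing import Awaitable

events: list[str] = []


async def save(value: dict[str, int]) -> Awaitable[None]:
  # Blocking phase: runs right away, before control returns to the caller.
  snapshot = dict(value)
  events.append('blocking phase done')

  async def _background() -> None:
    # Non-blocking phase: runs only when someone awaits the returned coroutine.
    events.append(f'background write of {snapshot}')

  # Hand the coroutine back without awaiting it.
  return _background()


async def main() -> None:
  pending = await save({'x': 1, 'y': 2})
  events.append('caller regains control')
  await pending


asyncio.run(main())  # In a notebook cell, use `await main()` instead.
print(events)
```

The log shows the caller regaining control between the blocking phase and the background write.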

When saving `Point`, we first copy its properties (via `dataclasses.asdict`) so
that the main thread cannot modify them while the background thread is writing
them. For a `jax.Array`, we would similarly need to transfer the data from
device memory to host memory. Once these blocking operations complete, we
return an awaitable that writes the values to a file. Note also that the
awaitable must wait for the directory to be created (`await_creation`), because
upper layers of Orbax schedule directory creation asynchronously.
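The sketch below shows why the copy matters. It uses plain `asyncio` and a hypothetical `ToyPoint` class (named so as not to shadow the notebook's `Point`), and it simulates a mutation happening between the two phases rather than using real threads. The version that snapshots during the blocking phase writes the values as they were when saving was requested; the version that defers reading the object picks up the later mutation.

```python
import asyncio
import dataclasses
import json


@dataclasses.dataclass
class ToyPoint:
  x: int
  y: int


written: list[str] = []


async def save_with_copy(point: ToyPoint):
  # Blocking phase: snapshot the fields before returning to the caller.
  value = dataclasses.asdict(point)

  async def _background() -> None:
    written.append(json.dumps(value))

  return _background()


async def save_without_copy(point: ToyPoint):
  async def _background() -> None:
    # Reads the live object only when the background work finally runs.
    written.append(json.dumps(dataclasses.asdict(point)))

  return _background()


async def main() -> None:
  point = ToyPoint(1, 2)
  pending_copy = await save_with_copy(point)
  pending_live = await save_without_copy(point)
  point.x = 100  # The main thread keeps mutating the object.
  await pending_copy
  await pending_live


asyncio.run(main())  # In a notebook cell, use `await main()` instead.
print(written)  # ['{"x": 1, "y": 2}', '{"x": 100, "y": 2}']
```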

Loading is similar, though typically fewer operations need to happen
synchronously, since the caller knows it cannot use the object until it is
fully loaded. The awaitable that runs in the background should return nothing;
instead, it sets the relevant properties on `self` after reading them from
disk.
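A minimal sketch of this stateful loading pattern follows, assuming a hypothetical `ToyPoint` class and a `point.txt` file in the same JSON format as above. To stay within the standard library it reads the file with blocking `pathlib` calls inside the coroutine, whereas the notebook uses `aiofiles`. The object keeps its placeholder values until the returned coroutine is awaited.

```python
import asyncio
import dataclasses
import json
import pathlib
import tempfile


@dataclasses.dataclass
class ToyPoint:
  x: int
  y: int

  async def load(self, directory: pathlib.Path):
    # Blocking phase: nothing is needed beyond capturing `directory`.
    async def _background() -> None:
      contents = json.loads((directory / 'point.txt').read_text())
      # Return nothing; update this existing object in place instead.
      self.x = contents['x']
      self.y = contents['y']

    return _background()


async def main(directory: pathlib.Path) -> ToyPoint:
  (directory / 'point.txt').write_text(json.dumps({'x': 1, 'y': 2}))
  point = ToyPoint(0, 0)  # "Uninitialized" placeholder values.
  pending = await point.load(directory)
  print('before awaiting:', point)  # Still ToyPoint(x=0, y=0).
  await pending
  return point


with tempfile.TemporaryDirectory() as tmp:
  # In a notebook cell, use `await main(...)` instead of `asyncio.run`.
  restored = asyncio.run(main(pathlib.Path(tmp)))
print(restored)  # ToyPoint(x=1, y=2)
```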

Now we can successfully save the `Point`.

In [None]:
ocp.save_checkpointables(
    '/tmp/ckpt1',
    dict(model_state=model_state, point=Point(1, 2)),
)

It is important to note that because `Point` is a stateful checkpointable, we
have to provide a `Point` object in order to restore it. Typically, we
construct a `Point` with "uninitialized" placeholder values. Calling
`load_checkpointables` then updates the provided object as a side effect (and
also returns it).

In [None]:
uninitialized_point = Point(0, 0)
ocp.load_checkpointables(
    '/tmp/ckpt1',
    dict(point=uninitialized_point),
)
uninitialized_point

### Supporting `Point` with `CheckpointableHandler`

While `StatefulCheckpointable` has a simple and powerful interface, it may not
be the right fit in every case. `StatefulCheckpointable` may be insufficient in
cases such as:

*   `Point` may be defined in some third-party library that we cannot easily
    modify, so we could not add `save` and `load` methods directly to the
    class itself.
*   When loading, users might need to customize loading behavior in a more
    dynamic way. For a `jax.Array`, resharding, casting, and reshaping are
    common operations. For a `Point`, users might want to cast `x` and `y`
    between `int` and `float` more dynamically.
*   We may have multiple different ways to save and load `Point` that users want
    to enable in different contexts. In such cases, placing all that different
    logic within the single `Point` class may add too much complexity.

For such cases (and others), Orbax provides an interface called
`CheckpointableHandler`.

First, let's redefine our `Point` class and also introduce an `AbstractPoint`
class. This allows us to specify the type of `x` or `y` that should be used for
loading.

In [None]:
del Point
import asyncio
from typing import Type

Scalar = int | float


@dataclasses.dataclass
class Point:
  x: Scalar
  y: Scalar


@dataclasses.dataclass
class AbstractPoint:
  x: Type[Scalar]
  y: Type[Scalar]

In [None]:
async def _write_point(
    directory: ocp.path.Path, checkpointable: dict[str, Scalar]
):
  async with aiofiles.open(directory / 'point.txt', 'w') as f:
    contents = json.dumps(checkpointable)
    await f.write(contents)


async def _write_point_metadata(
    directory: ocp.path.Path, checkpointable: dict[str, Scalar]
):
  async with aiofiles.open(directory / 'point_metadata.txt', 'w') as f:
    contents = json.dumps(
        {k: type(v).__name__ for k, v in checkpointable.items()}
    )
    await f.write(contents)


class PointHandler(ocp.CheckpointableHandler[Point, AbstractPoint]):

  async def _background_save(
      self,
      directory: ocp.path.PathAwaitingCreation,
      checkpointable: dict[str, Scalar],
  ):
    if jax.process_index() == 0:
      directory = await directory.await_creation()
      await asyncio.gather(
          _write_point(directory, checkpointable),
          _write_point_metadata(directory, checkpointable),
      )

  async def _background_load(
      self,
      directory: ocp.path.Path,
      abstract_checkpointable: AbstractPoint | None = None,
  ) -> Point:
    async with aiofiles.open(directory / 'point.txt', 'r') as f:
      contents = json.loads(await f.read())
      if abstract_checkpointable is None:
        return Point(**contents)
      else:
        return Point(
            abstract_checkpointable.x(contents['x']),
            abstract_checkpointable.y(contents['y']),
        )

  async def save(
      self,
      directory: ocp.path.PathAwaitingCreation,
      checkpointable: Point,
  ) -> Awaitable[None]:
    return self._background_save(directory, dataclasses.asdict(checkpointable))

  async def load(
      self,
      directory: ocp.path.Path,
      abstract_checkpointable: AbstractPoint | None = None,
  ) -> Awaitable[Point]:
    return self._background_load(directory, abstract_checkpointable)

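  async def metadata(self, directory: ocp.path.Path) -> AbstractPoint:
    # Map stored type names back to types with an explicit mapping. This is
    # more robust than `getattr(__builtins__, ...)`, since `__builtins__` is a
    # module in notebooks but a dict inside imported modules.
    types_by_name = {'int': int, 'float': float}
    async with aiofiles.open(directory / 'point_metadata.txt', 'r') as f:
      contents = json.loads(await f.read())
      return AbstractPoint(
          **{k: types_by_name[v] for k, v in contents.items()}
      )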

  def is_handleable(self, checkpointable: Any) -> bool:
    return isinstance(checkpointable, Point)

  def is_abstract_handleable(self, abstract_checkpointable: Any) -> bool:
    return isinstance(abstract_checkpointable, AbstractPoint)

This class is associated with two types, the `Checkpointable` and the
`AbstractCheckpointable` (`Point` and `AbstractPoint` in this case). `Point` is
the input for saving, and `AbstractPoint` (or `None`) is the input for loading.
Both methods also receive the directory to write to or read from.

Saving logic in this class is essentially the same as in our
`StatefulCheckpointable` definition above.

Loading is different because it is no longer stateful: it accepts an optional
`AbstractPoint` and returns a newly constructed `Point`. Passing `None`
indicates that the object should be restored exactly as it was saved. (For some
objects this may not be possible, and the handler may need to raise an error if
input from the user is required to know how to load.) Otherwise, the provided
`AbstractCheckpointable` serves as a guide describing how the concrete loaded
object (`Point` in this case) should be constructed.

We can also define a `metadata` method in this class. `Point` is already very
lightweight, but in real use cases the checkpoint may be expensive to load
fully, so metadata describing important properties that can be loaded cheaply
is essential. The `metadata` method should return an instance of the
`AbstractCheckpointable` type (`AbstractPoint` here).

Finally, two additional methods, `is_handleable` and `is_abstract_handleable`,
should be defined. These methods accept any object and decide whether it is an
acceptable input for saving or loading, respectively. In most cases, a simple
`isinstance` check suffices, but more generic constructs, like PyTrees, require
more involved logic.
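To show what these predicates make possible, the sketch below picks a handler for an object by asking each candidate whether it can handle it. The first-match policy and all names (`ToyHandler`, `ToyPointHandler`, `ToyDictHandler`, `resolve_handler`) are hypothetical; this is not Orbax's actual resolution logic, only an illustration of why a handler must be able to answer "is this mine?" for arbitrary inputs.

```python
import dataclasses
from typing import Any, Protocol


class ToyHandler(Protocol):
  def is_handleable(self, checkpointable: Any) -> bool: ...


@dataclasses.dataclass
class ToyPoint:
  x: int
  y: int


class ToyPointHandler:
  def is_handleable(self, checkpointable: Any) -> bool:
    return isinstance(checkpointable, ToyPoint)


class ToyDictHandler:
  def is_handleable(self, checkpointable: Any) -> bool:
    return isinstance(checkpointable, dict)


def resolve_handler(
    handlers: list[ToyHandler], checkpointable: Any
) -> ToyHandler:
  """Returns the first handler accepting the object (hypothetical policy)."""
  for handler in handlers:
    if handler.is_handleable(checkpointable):
      return handler
  raise ValueError(f'No handler for {type(checkpointable).__name__}')


handlers = [ToyPointHandler(), ToyDictHandler()]
print(type(resolve_handler(handlers, ToyPoint(1, 2))).__name__)  # ToyPointHandler
print(type(resolve_handler(handlers, {'params': [1, 2]})).__name__)  # ToyDictHandler
```

An object that no handler accepts raises an error, which is analogous to the failure we saw when saving the original `Point` before any support was added.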

We can now register `PointHandler` in order to deal with `Point` objects.

In [None]:
ocp.handlers.register_handler(PointHandler)

In [None]:
ocp.save_checkpointables(
    '/tmp/ckpt2',
    dict(model_state=model_state, point=Point(1, 2.4)),
)

Since the `AbstractPoint` is optional, we do not need to specify any arguments
to load everything successfully.

In [None]:
ocp.load_checkpointables('/tmp/ckpt2')

However, if desired, we can specify an abstract checkpointable to customize the
types of the restored values.

In [None]:
ocp.load_checkpointables(
    '/tmp/ckpt2', dict(point=AbstractPoint(x=float, y=int))
)

We can use `checkpointables_metadata` to load the metadata, in the form of an
`AbstractPoint`.

In [None]:
ocp.checkpointables_metadata('/tmp/ckpt2').metadata['point']

## Custom Leaf Handler

This is an advanced topic. Make sure you are familiar with [the guide on checkpointing PyTrees](checkpointing_pytrees.ipynb) before reading this section.

PyTrees are a common tree structure used to represent training states. LeafHandlers are responsible for serializing and deserializing each leaf node, and different leaf types require specific LeafHandlers. Orbax includes standard LeafHandlers for common types, including `jax.Array`, `np.ndarray`, `int`, `float`, and `str`. Before creating a custom LeafHandler, always check the options available in `ocp.options.PyTreeOptions` and `ocp.options.ArrayOptions` to make sure none of them already meets your needs.

One common reason to write a custom LeafHandler is to support a type that Orbax does not support out of the box. We will use the `Point` class from above as an example. Suppose you need to checkpoint many `Point` objects in a nested tree structure. It may make sense to store them within a PyTree along with your training state. In that case, you need to write a `PointLeafHandler` and register it with a `LeafHandlerRegistry`.

In [None]:
import dataclasses
import json
from typing import Awaitable, Type
from etils import epath
import numpy as np
from orbax.checkpoint import multihost
import orbax.checkpoint.experimental.v1 as ocp
from orbax.checkpoint.experimental.v1 import serialization


@dataclasses.dataclass
class Point:
  x: int | float
  y: int | float

For a LeafHandler, we also need to define an `AbstractPoint` class. This is required for two reasons:

1. During restoration, `AbstractPoint` indicates the type a leaf should be restored as.
2. Leaf metadata is returned as an `AbstractPoint`, avoiding the need to restore the actual leaf object.

In the following example, `AbstractPoint` simply records the types of the data members, without their values.

In [None]:
@dataclasses.dataclass
class AbstractPoint:
  x: Type[int|float]
  y: Type[int|float]

  @classmethod
  def from_point(cls, point):
    return cls(x=type(point.x), y=type(point.y))


Next, we define the actual `PointLeafHandler`. The docstrings below explain which methods are required.

In [None]:
from typing import Sequence
import asyncio
import aiofiles

In [None]:
class PointLeafHandler(serialization.LeafHandler[Point, AbstractPoint]):
  """A custom leaf handler for testing."""

  def __init__(self, context: ocp.Context | None = None):
    """Required Initializer.

    This initializer is initialized lazily during checkpoint operations.  If the
    signature is not matched, an exception will be raised during initialization.

    Args:
      context: The context for the leaf handler.  The leaf handler can
        initialize and operate according to the context.  In this example, we do
        not utilize it though.  For more examples, see ArrayLeafHandler.
    """
    del context

  async def serialize(
      self,
      params: Sequence[serialization.SerializationParam[Point]],
      serialization_context: serialization.SerializationContext,
  ) -> Awaitable[None]:
    """Required Serialize function.

    This function writes the specified leaves of a checkpointable to a storage
    location.  A couple of notes here:
    1. This function is called on all hosts, but in this example, only the
    primary host will write.
    2. we use `await await_creation()` to ensure the parent directory is created
    before writing.
    """

    async def _background_serialize(params, serialization_context):
      # make sure the parent directory is created
      await serialization_context.parent_dir.await_creation()

      # only the primary host writes
      if multihost.is_primary_host(0):
        for param in params:
          # save the value
          async with aiofiles.open(
              serialization_context.parent_dir.path / f'{param.name}.txt',
              'w',
          ) as f:
            await f.write(json.dumps(dataclasses.asdict(param.value)))

          # save the metadata
          async with aiofiles.open(
              serialization_context.parent_dir.path
              / f'{param.name}.metadata.txt',
              'w',
          ) as abstract_f:
            contents = json.dumps({
                k: type(v).__name__
                for k, v in dataclasses.asdict(param.value).items()
            })
            await abstract_f.write(contents)

    return _background_serialize(params, serialization_context)

  async def deserialize(
      self,
      params: Sequence[serialization.DeserializationParam[AbstractPoint]],
      deserialization_context: serialization.DeserializationContext,
  ) -> Awaitable[Sequence[Point]]:
    """Required Deserialize function.

    Returns sequence of leaves from a stored checkpointable location. Note that
    we use asyncio.to_thread to ensure the deserialization is performed in a
    background thread immediately before returning this call.
    """

    async def _deserialize_impl():
      ret = []
      for param in params:
        async with aiofiles.open(
            deserialization_context.parent_dir / f'{param.name}.txt',
            'r',
        ) as f:
          ret.append(Point(**json.loads(await f.read())))

      return ret

    return _deserialize_impl()

  async def metadata(
      self,
      params: Sequence[serialization.DeserializationParam[None]],
      deserialization_context: serialization.DeserializationContext,
  ) -> Sequence[AbstractPoint]:
    """Required Metadata function.

    Returns a sequence of metadata that helps to describe the available leaves
    in this checkpoint location.
    """

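    # Map stored type names back to types with an explicit mapping, which is
    # more robust than `getattr(__builtins__, ...)`.
    types_by_name = {'int': int, 'float': float}
    ret = []
    for param in params:
      async with aiofiles.open(
          deserialization_context.parent_dir / f'{param.name}.metadata.txt', 'r'
      ) as f:
        contents = json.loads(await f.read())
        ret.append(
            AbstractPoint(**{k: types_by_name[v] for k, v in contents.items()})
        )
    return ret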

Next, we define a `train_state` for demonstration purposes. It contains some common leaf types as well as several `Point` objects nested inside the PyTree.

In [None]:
# define a PyTree Train State

train_state = {
    'a': np.arange(16),
    'b': np.ones(16),
    'scalar': 123.0,
    'mixed': {
        'a': np.arange(16),
        'b': np.ones(16),
        'scalar': 123.0,
        'Point': Point(0, 0.5),
    },
    'Points': {
        'level1': {
            'point_int': Point(1, 2),
            'point_float': Point(3.0, 4.0),
            'level2': {
                'point_mixed1': Point(5, 6.0),
                'point_mixed2': Point(7.0, 8),
                'point_int': Point(9, 10),
                'point_float': Point(11.0, 12.0),
            },
        }
    },
}

Next, we prepare a `LeafHandlerRegistry`, which maps each leaf type and its abstract type to a LeafHandler. In the following example, we first create a `StandardLeafHandlerRegistry`, which is the same as the registry used by default. We then add `PointLeafHandler` along with its type `Point` and abstract type `AbstractPoint`. Note that only the `PointLeafHandler` class is registered, not a handler instance; the instance is created lazily during checkpoint operations.

In [None]:
# Create LeafHandlerRegistry
registry = serialization.StandardLeafHandlerRegistry()  # with standard handlers
registry.add(Point, AbstractPoint, PointLeafHandler)  # add the custom handler

In [None]:
# prepare the checkpoint directory
path = epath.Path('/tmp/with_points')
path.rmtree(missing_ok=True)

Now we are ready to save `train_state`. To pass the custom registry, we wrap the call in an `ocp.Context` that sets `leaf_handler_registry` in `ocp.options.PyTreeOptions`.

In [None]:
with ocp.Context(
    pytree_options=ocp.options.PyTreeOptions(leaf_handler_registry=registry)
):
  ocp.save_pytree(path, train_state)

After saving, let's load the checkpoint to check that we get back the expected `Point` objects. We again create an `ocp.Context` with our custom registry.

In [None]:
with ocp.Context(
    pytree_options=ocp.options.PyTreeOptions(leaf_handler_registry=registry)
):
  restored_train_state = ocp.load_pytree(path)

In [None]:
import pprint
pprint.pprint(restored_train_state)

We can see that `restored_train_state` matches the original `train_state`.

Finally, let's check that we can read the expected metadata. Again, we use `ocp.Context` to supply our registry with the custom `PointLeafHandler`.

In [None]:
with ocp.Context(
    pytree_options=ocp.options.PyTreeOptions(leaf_handler_registry=registry)
):
  restored_metadata = ocp.pytree_metadata(path)

We can see that `AbstractPoint` objects are returned for the `Point` leaves.

In [None]:
pprint.pprint(restored_metadata.metadata)