In [19]:
import numpy as np
import orbax.checkpoint as ocp
import jax

In [20]:
# Checkpoint root for the first examples. Per its name, the helper erases any
# existing directory and recreates it empty (confirm in ocp.test_utils), so the
# save cells below always start from a clean location.
path = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')

In [21]:
# Small nested PyTree mixing numpy arrays with a plain Python int.
my_tree = dict(
    a=np.arange(8),
    b=dict(c=42, d=np.arange(16)),
)
# Map every leaf through Orbax's shape/dtype helper to get an "abstract" tree
# that is later handed to restore().
abstract_my_tree = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, my_tree)

In [22]:
# Display the abstract tree: array leaves show up as ShapeDtypeStruct, while the
# int leaf 'c' is shown unchanged in the recorded output.
abstract_my_tree

{'a': ShapeDtypeStruct(shape=(8,), dtype=int64),
 'b': {'c': 42, 'd': ShapeDtypeStruct(shape=(16,), dtype=int64)}}

In [23]:
# Write `my_tree` to <path>/checkpoint_name with Orbax's StandardCheckpointer.
checkpointer = ocp.StandardCheckpointer()
# 'checkpoint_name' must not already exist.
checkpointer.save(path / 'checkpoint_name', my_tree)


In [24]:
# Read the checkpoint back, passing the abstract tree built earlier as the
# second argument (the restore target). The restored tree is the cell output.
checkpointer.restore(path / 'checkpoint_name/', abstract_my_tree)


{'a': array([0, 1, 2, 3, 4, 5, 6, 7]),
 'b': {'c': 42,
  'd': array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])}}

In [25]:
# Show the checkpoint's stored metadata: per-leaf name, shape, dtype, sharding
# and storage layout (see the recorded output).
checkpointer.metadata(path / 'checkpoint_name')

TreeMetadata(
  custom_metadata=None
  tree={'a': ArrayMetadata :  name=a,  directory=/tmp/my-checkpoints/checkpoint_name,  shape=(8,),  sharding=None,  dtype=int64,  storage=StorageMetadata(chunk_shape=(8,), write_shape=None),, 'b': {'c': ScalarMetadata(name='b.c', directory=PosixGPath('/tmp/my-checkpoints/checkpoint_name'), shape=(), sharding=None, dtype=dtype('int64'), storage=None), 'd': ArrayMetadata :  name=b.d,  directory=/tmp/my-checkpoints/checkpoint_name,  shape=(16,),  sharding=None,  dtype=int64,  storage=StorageMetadata(chunk_shape=(16,), write_shape=None),}}
)

In [28]:
# Plain-dict payload that the next cell stores through JsonSave alongside the
# array state.
metadata = dict(version="1.0", lang="en")
# Rebind `checkpointer` to a Checkpointer backed by a CompositeCheckpointHandler
# so that several named items can be written into one checkpoint.
checkpointer = ocp.Checkpointer(ocp.CompositeCheckpointHandler())


In [29]:
# Save two named items into a single checkpoint: the array PyTree under 'state'
# (StandardSave) and the plain dict under 'metadata' (JsonSave). Each item name
# appears as its own entry in the checkpoint directory (listed further down).
checkpointer.save(
    path / 'composite_checkpoint',
    args=ocp.args.Composite(
        state=ocp.args.StandardSave(my_tree),
        metadata=ocp.args.JsonSave(metadata),
    ),
)

In [30]:
# Restore the composite checkpoint without passing restore args; the saved
# items are then reachable as attributes of the result (e.g. `restored.state`).
restored = checkpointer.restore(path / "composite_checkpoint")




In [31]:
# The array PyTree that was saved under the 'state' item.
restored.state

{'a': array([0, 1, 2, 3, 4, 5, 6, 7]),
 'b': {'c': 42,
  'd': array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])}}

In [32]:
# List the checkpoint directory: per the recorded output it holds one entry per
# saved item ('state', 'metadata') plus a _CHECKPOINT_METADATA entry.
list((path / 'composite_checkpoint').iterdir())

[PosixGPath('/tmp/my-checkpoints/composite_checkpoint/state'),
 PosixGPath('/tmp/my-checkpoints/composite_checkpoint/_CHECKPOINT_METADATA'),
 PosixGPath('/tmp/my-checkpoints/composite_checkpoint/metadata')]

In [36]:
# Fresh root directory for the CheckpointManager example (rebinds `path`).
path = ocp.test_utils.erase_and_create_empty('/tmp/checkpoint_manager')
# Array state to checkpoint at every step.
state = dict(a=np.arange(8), b=np.arange(16))
# Extra JSON-serializable payload saved next to the state.
extra_params = [42, 43]

In [37]:
# Manager policy: per the option names, save on every 2nd step and retain up to
# 10 checkpoints (see the CheckpointManagerOptions docs for exact semantics).
options = ocp.CheckpointManagerOptions(
    save_interval_steps=2,
    max_to_keep=10,
)
# The manager handles two named items per step: 'state' and 'extra_params'.
mngr = ocp.CheckpointManager(
    path,
    item_names=('state', 'extra_params'),
    options=options,
)

In [38]:
# Request a save on every step 0..9. The manager was configured with
# save_interval_steps=2 — NOTE(review): presumably only every 2nd step is
# actually written; confirm against the CheckpointManager docs.
for step in range(10):
    mngr.save(step, args=ocp.args.Composite(
        state=ocp.args.StandardSave(state),
        extra_params=ocp.args.JsonSave(extra_params),
    ))

# Wait for outstanding saves to finish before the next cell restores.
mngr.wait_until_finished()

In [39]:
# Restore step 8. With range(10) and save_interval_steps=2 this should be the
# last step written (TODO confirm); `mngr.latest_step()` would express the same
# thing without hard-coding the number.
restored = mngr.restore(8)



In [18]:
import os

# XLA_FLAGS only takes effect if it is set before JAX initializes its backend,
# so set it ahead of the jax import. If an earlier cell in this kernel has
# already touched JAX devices, restart the kernel and run this cell first;
# otherwise the 8 simulated host devices will not appear.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax
import jax.numpy as jnp
import orbax.checkpoint as ocp

# Fresh root directory for the sharded CheckpointManager example.
path = ocp.test_utils.erase_and_create_empty('/tmp/checkpoint_manager_sharded')

# 1-D device mesh over every visible device with a single axis named 'model'.
# PartitionSpec('model') splits an array's leading dimension across that axis.
sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('model',)),
    jax.sharding.PartitionSpec('model'),
)

In [19]:
def create_sharded_array(x):
    """Place `x` on devices using the notebook-level `sharding`.

    Args:
        x: A value accepted by `jax.device_put` (an array leaf when this
            helper is mapped over a PyTree).

    Returns:
        The result of `jax.device_put(x, sharding)`.
    """
    # `sharding` is looked up at call time from the notebook's globals.
    return jax.device_put(x, sharding)

In [20]:
# Build the toy train state (an int range and a float vector of ones, both of
# length 16) and place every leaf on the device mesh in one step.
train_state = jax.tree.map(
    create_sharded_array,
    dict(a=jnp.arange(16), b=jnp.ones(16)),
)

In [21]:
# Show the sharding attached to each leaf to confirm the placement took effect.
jax.tree.map(lambda x: x.sharding,train_state)

{'a': NamedSharding(mesh=Mesh('model': 8, axis_types=(Auto,)), spec=PartitionSpec('model',), memory_kind=unpinned_host),
 'b': NamedSharding(mesh=Mesh('model': 8, axis_types=(Auto,)), spec=PartitionSpec('model',), memory_kind=unpinned_host)}

In [22]:
# Number of toy training steps run by the loop below.
num_steps = 10
# Per the option names: save on every 2nd step and retain up to 3 checkpoints.
options = ocp.CheckpointManagerOptions(
    save_interval_steps=2,
    max_to_keep=3,
)
# New manager rooted at the sharded example's directory (rebinds `mngr`).
mngr = ocp.CheckpointManager(path, options=options)

@jax.jit
def train_fn(state):
    """Toy training step: return `state` with 1 added to every leaf."""
    def _increment(leaf):
        return leaf + 1

    return jax.tree.map(_increment, state)


# Run the toy training loop: each iteration replaces `train_state` with the
# incremented state and asks the manager to save it under the step number.
# Re-running this cell keeps incrementing the same `train_state`.
for step in range(num_steps):
    train_state = train_fn(train_state)
    mngr.save(step, args=ocp.args.StandardSave(train_state))

# Wait for outstanding saves to finish before restoring.
mngr.wait_until_finished()


In [24]:
# Restore the most recent step known to the manager (no restore target is
# supplied here); the restored tree is the cell output.
mngr.restore(mngr.latest_step())



{'a': Array([ 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],      dtype=int32),
 'b': Array([10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10.,
        10., 10., 10.], dtype=float32)}

In [27]:
import jax
import jax.numpy as jnp
# Zero-filled stand-in with the same tree structure, shapes and dtypes as the
# saved train state: build the template leaves and zero them in one pass.
train_state = jax.tree.map(
    jnp.zeros_like,
    dict(a=jnp.arange(16), b=jnp.ones(16)),
)

In [28]:
# Same sharding as defined earlier in the notebook, rebuilt here: a 1-D mesh
# over all visible devices with one axis named 'model', and a PartitionSpec
# that partitions along that axis.
sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('model',)),
    jax.sharding.PartitionSpec('model'),
)

In [29]:
def create_sharded_array(x):
    """Place `x` on devices using the notebook-level `sharding`.

    Args:
        x: A value accepted by `jax.device_put` (an array leaf when this
            helper is mapped over a PyTree).

    Returns:
        The result of `jax.device_put(x, sharding)`.
    """
    # `sharding` is looked up at call time from the notebook's globals.
    return jax.device_put(x, sharding)

In [30]:
# Place every leaf of the zero-filled state on the device mesh (rebinds
# `train_state` to the sharded version).
train_state = jax.tree.map(create_sharded_array,train_state)

In [31]:
# Abstract description of the sharded, zero-filled state, meant to be used as a
# restore target. NOTE(review): presumably each resulting struct carries the
# leaf's shape, dtype and sharding — confirm in ocp.utils.to_shape_dtype_struct.
abstract_train_state = jax.tree.map(ocp.utils.to_shape_dtype_struct,train_state)

In [None]:
# Restore the latest step into the layout described by `abstract_train_state`.
# Passing the abstract tree as the StandardRestore target tells Orbax which
# shapes, dtypes and shardings to restore into; without it the target prepared
# in the previous cells would go unused.
restored = mngr.restore(
    mngr.latest_step(),
    args=ocp.args.StandardRestore(abstract_train_state),
)


In [37]:
# Show the sharding of each restored leaf to check how the arrays were placed.
jax.tree.map(lambda x:x.sharding,restored)

{'a': NamedSharding(mesh=Mesh('model': 8, axis_types=(Auto,)), spec=PartitionSpec('model',), memory_kind=unpinned_host),
 'b': NamedSharding(mesh=Mesh('model': 8, axis_types=(Auto,)), spec=PartitionSpec('model',), memory_kind=unpinned_host)}