In [1]:
import os

# Ask XLA to expose 8 virtual CPU devices so multi-device examples can run on a
# single host. This must be set before `jax` is imported (next cell) to take effect.
os.environ.update(XLA_FLAGS='--xla_force_host_platform_device_count=8')

In [2]:
from typing import Optional, Any
import shutil

import numpy as np
import jax
from jax import random, numpy as jnp

import flax
from flax import linen as nn
from flax.training import checkpoints, orbax_utils, train_state
from flax import struct, serialization
import orbax.checkpoint

import optax

2024-09-29 18:10:08.268015: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-29 18:10:08.268092: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-29 18:10:08.302040: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
# Root directory for every checkpoint this notebook writes.
ckpt_dir = '/tmp/flax_ckpt'

# Start each run from an empty directory: delete whatever a previous run of
# the notebook left behind.
if os.path.exists(ckpt_dir):
    shutil.rmtree(ckpt_dir)

In [4]:
# --- A toy model to checkpoint: one Dense layer mapping 5 inputs to 3 outputs. ---
key1, key2 = random.split(random.key(0))
x1 = random.normal(key1, (5,))  # sample input; used for init and stored in the checkpoint
model = nn.Dense(features=3)
variables = model.init(key2, x1)

# TrainState is a pytree dataclass, so it can be checkpointed as-is.
# Custom state classes become checkpoint-compatible by decorating them
# with `@flax.struct.dataclass`.
tx = optax.sgd(learning_rate=0.001)
state = train_state.TrainState.create(apply_fn=model.apply,
                                      params=variables['params'],
                                      tx=tx)

# Take one update with all-ones gradients, as a training step would, so the
# saved state is not just the initialization (its `step` becomes 1).
state = state.apply_gradients(
    grads=jax.tree_util.tree_map(jnp.ones_like, state.params),
)

# Arbitrary nested pytrees (dicts, lists, NumPy arrays) can be saved alongside it.
config = dict(dimensions=np.array([5, 3]))

# The checkpoint is a single pytree bundling train state, config and raw data.
ckpt = dict(model=state, config=config, data=[x1])
ckpt



{'model': TrainState(step=1, apply_fn=<bound method Module.apply of Dense(
     # attributes
     features = 3
     use_bias = True
     dtype = None
     param_dtype = float32
     precision = None
     kernel_init = init
     bias_init = zeros
     dot_general = None
     dot_general_cls = None
 )>, params={'bias': Array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': Array([[ 0.26048955, -0.61399287, -0.23458514],
        [ 0.11050402, -0.8765793 ,  0.9800635 ],
        [ 0.36260957,  0.18276349, -0.6856061 ],
        [-0.8519373 , -0.6416717 , -0.4818122 ],
        [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x7cddf35a6ca0>, update=<function chain.<locals>.update_fn at 0x7cddf35667a0>), opt_state=(EmptyState(), EmptyState())),
 'config': {'dimensions': array([5, 3])},
 'data': [Array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],      dtype=float32)]}

In [6]:
from flax.training import orbax_utils

orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(ckpt)
orbax_checkpointer.save('/home/research/github/one_policy_to_run_them_all/notebooks/tmp', ckpt, save_args=save_args)

In [8]:
raw_restored = orbax_checkpointer.restore('/home/research/github/one_policy_to_run_them_all/notebooks/tmp')
raw_restored



{'config': {'dimensions': array([5, 3])},
 'data': [Array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],      dtype=float32)],
 'model': {'opt_state': [None, None],
  'params': {'bias': Array([-0.001, -0.001, -0.001], dtype=float32),
   'kernel': Array([[ 0.26048955, -0.61399287, -0.23458514],
          [ 0.11050402, -0.8765793 ,  0.9800635 ],
          [ 0.36260957,  0.18276349, -0.6856061 ],
          [-0.8519373 , -0.6416717 , -0.4818122 ],
          [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)},
  'step': 1}}