In [12]:
from flax import nnx
import orbax.checkpoint as ocp
import jax
from jax import numpy as jnp
import numpy as np

ckpt_dir = ocp.test_utils.erase_and_create_empty("/tmp/my-checkpoints/")


class ThreeLayerMLP(nnx.Module):
    def __init__(self, dim, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(dim, dim, rngs=rngs, use_bias=False)
        self.linear2 = nnx.Linear(dim, dim, rngs=rngs, use_bias=False)
        self.linear3 = nnx.Linear(dim, dim, rngs=rngs, use_bias=False)
        self.dropout1 = nnx.Dropout(rate=0.4, rngs=rngs)
        self.dropout2 = nnx.Dropout(rate=0.4, rngs=rngs)

    def __call__(self, x):
        x = self.linear1(x)
        x = self.dropout1(x)
        x = self.linear2(x)
        x = self.dropout2(x)
        return self.linear3(x)


# Instantiate the model and show we can run it.
model = ThreeLayerMLP(4, rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.key(42), (3, 4))

model.eval()
assert model(x).shape == (3, 4)

_, state = nnx.split(model)
nnx.display(state)

checkpointer = ocp.StandardCheckpointer()
checkpointer.save(ckpt_dir / "state", state)

State({
  'dropout1': {
    'rngs': {
      'default': {
        'count': VariableState(
          type=RngCount,
          value=Array(3, dtype=uint32),
          tag='default'
        ),
        'key': VariableState(
          type=RngKey,
          value=Array((), dtype=key<fry>) overlaying:
          [0 0],
          tag='default'
        )
      }
    }
  },
  'linear1': {
    'bias': VariableState(
      type=Param,
      value=None
    ),
    'kernel': VariableState(
      type=Param,
      value=Array([[-0.80345297, -0.34071913, -0.9408296 ,  0.01005968],
             [ 0.26146442,  1.1247735 ,  0.54563737, -0.374164  ],
             [ 1.0281805 , -0.6798804 , -0.1488401 ,  0.05694951],
             [-0.44308168, -0.60587114,  0.434087  , -0.40541083]],      dtype=float32)
    )
  },
  'linear2': {
    'bias': VariableState(
      type=Param,
      value=None
    ),
    'kernel': VariableState(
      type=Param,
      value=Array([[-0.7430909 , -0.8467984 ,  0.3140029 , -0.2883

ERROR:absl:[process=0] Failed to run 1 Handler Commit operations or the Commit callback in background save thread, directory: /tmp/my-checkpoints/state
Traceback (most recent call last):
  File "/Users/kailukowiak/Hephaestus/.venv/lib/python3.11/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 132, in _thread_func
    future.result()
  File "/Users/kailukowiak/Hephaestus/.venv/lib/python3.11/site-packages/orbax/checkpoint/future.py", line 78, in result
    f.result(timeout=time_remaining)
  File "/Users/kailukowiak/Hephaestus/.venv/lib/python3.11/site-packages/orbax/checkpoint/_src/serialization/type_handlers.py", line 250, in result
    return self._t.join(timeout=timeout)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/kailukowiak/Hephaestus/.venv/lib/python3.11/site-packages/orbax/checkpoint/future.py", line 62, in join
    raise self._exception
  File "/Users/kailukowiak/Hephaestus/.venv/lib/python3.11/site-packages/orbax/checkpoint/future.py",