In [120]:
from flax import nnx
import jax
import jax.numpy as jnp
import optax
import orbax.checkpoint as ocp
import os

In [121]:
class MyModel(nnx.Module):
  def __init__(self, inSize: int, outSize: int, *, rngs: nnx.Rngs):
    intermediateSize = 64
    key = rngs.params()
    self.linear1 = nnx.Linear(inSize, intermediateSize, rngs=rngs)
    self.linear2 = nnx.Linear(intermediateSize, outSize, rngs=rngs)

  def __call__(self, x):
    x = self.linear1(x)
    x = jax.nn.relu(x)
    x = self.linear2(x)
    return x

In [122]:
rngs = nnx.Rngs(jax.random.key(0))
model = MyModel(4, 1, rngs=rngs)
tx = optax.adam(1e-3)
optimizerState = nnx.Optimizer(model, tx)

In [123]:
def lossFn(model, x, y):
  yHat = model(x)
  return jnp.mean((y - yHat)**2)

In [135]:
valueAndGrad = nnx.value_and_grad(lossFn)
input = jnp.array([1., 2., 3., 4.])
output = jnp.array([5.])
print(model(input))
value, grad = valueAndGrad(model, input, output)

optimizerState.update(grad)
print(model(input))

print(f'Optimizer:')
nnx.display(optimizerState)
print(f'Actual optimizer state:')
nnx.display(optimizerState.opt_state)

graph, state = nnx.split(model)

abstract_model = nnx.eval_shape(lambda: MyModel(4, 1, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)

with ocp.StandardCheckpointer() as ckptr:
  path = os.getcwd()
  ckptr.save(path+'/standard-ckpt-1', state, force=True)
  ckptr.save(path+'/standard-ckpt-1_opt', nnx.state(optimizerState.opt_state), force=True)
  ckptr.wait_until_finished()
  result = ckptr.restore(path+'/standard-ckpt-1', abstract_state)
  optimizerState.opt_state = ckptr.restore(path+'/standard-ckpt-1_opt', nnx.state(optimizerState.opt_state))
  print(result)
  newModel = nnx.merge(graphdef, result)
  print(newModel(input))

  nnx.display(optimizerState.opt_state)

[-1.3875928]
[-1.2977124]
Optimizer:


Actual optimizer state:


[38;2;79;201;177mState[0m[38;2;255;213;3m({[0m[38;2;105;105;105m[0m
  [38;2;156;220;254m'linear1'[0m[38;2;212;212;212m: [0m[38;2;255;213;3m{[0m[38;2;105;105;105m[0m
    [38;2;156;220;254m'bias'[0m[38;2;212;212;212m: [0m[38;2;79;201;177mVariableState[0m[38;2;255;213;3m([0m[38;2;105;105;105m # 64 (256 B)[0m
      [38;2;156;220;254mtype[0m[38;2;212;212;212m=[0m[38;2;79;201;177mParam[0m,
      [38;2;156;220;254mvalue[0m[38;2;212;212;212m=[0mArray([-0.01184409,  0.01198999,  0.        ,  0.        ,  0.        ,
              0.        ,  0.        , -0.01184019,  0.        ,  0.        ,
             -0.01189545,  0.        ,  0.        , -0.01187441,  0.01194797,
              0.        ,  0.        ,  0.        ,  0.        ,  0.01199964,
              0.        , -0.01178404,  0.        , -0.01182538,  0.        ,
              0.        ,  0.        ,  0.        , -0.01182441,  0.        ,
             -0.01187265, -0.01189789,  0.        ,  0.01195829