# Orbax Saving Tests

In [1]:
import jax
import jax.numpy as jnp
import numpy as np

In [2]:
# Code by Wassim


def save_model(state, model_path):
    """
    Save the Flax model using orbax.

    The state is wrapped as ``{'model': state}`` and written to the ``model``
    subdirectory of ``model_path`` (i.e. ``f"{model_path}/model"``). An
    existing checkpoint at that location is overwritten.

    Parameters:
    - state: Flax train_state to be saved.
    - model_path: Parent directory; the checkpoint is written to
      ``f"{model_path}/model"``.
    """
    # The notebook only binds `ocp` (``import orbax.checkpoint as ocp``), so
    # bring `orbax` and `os` into scope here explicitly.
    import os
    import orbax.checkpoint

    # Single source of truth for the pytree and its destination directory.
    target = {'model': state}
    target_dir = f"{model_path}/model"

    orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    # NOTE(review): `orbax_utils` is presumably `flax.training.orbax_utils`;
    # it is not imported anywhere in the visible notebook — confirm.
    save_args = orbax_utils.save_args_from_target(target)
    os.makedirs(target_dir, exist_ok=True)
    # force=True allows overwriting a checkpoint that already exists at target_dir.
    orbax_checkpointer.save(target_dir, target, save_args=save_args, force=True)

def load_model(model_path, *, learning_rate=1e-3, total_steps=10):
    """
    Load the Flax model using orbax.

    Restores the checkpoint stored directly at ``model_path`` and rebuilds a
    train state from its ``'model' -> 'params'`` and
    ``'model' -> 'model_config'`` entries.

    Note: ``save_model(state, path)`` writes to ``f"{path}/model"``. To read
    that checkpoint back, pass ``f"{path}/model"`` here, not ``path``.

    Parameters:
    - model_path: Directory containing the orbax checkpoint to restore.
    - learning_rate: Learning rate forwarded to ``load_train_state``
      (default 1e-3).
    - total_steps: Total step count forwarded to ``load_train_state``
      (default 10).

    Returns:
    - state: Loaded Flax train_state, as built by ``load_train_state``.
    """
    # The notebook only binds `ocp`, so bring `orbax` into scope here.
    import orbax.checkpoint

    orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    restored_data = orbax_checkpointer.restore(model_path)
    # The restored pytree is indexed as nested dicts:
    # {'model': {'params': ..., 'model_config': ...}}.
    saved_state = restored_data['model']
    params = saved_state['params']
    model_config = saved_state['model_config']

    # NOTE(review): `load_train_state` is not defined in the visible notebook;
    # presumably it builds a fresh train state (optimizer/schedule) around
    # `params` — confirm its signature accepts these keyword arguments.
    return load_train_state(
        params,
        model_config,
        learning_rate=learning_rate,
        total_steps=total_steps,
    )

# A few tests

In [3]:
import orbax.checkpoint as ocp

In [4]:
# NOTE(review): machine-specific absolute path; create_empty presumably
# clears/creates it so each run starts without old checkpoints — confirm.
path = ocp.test_utils.create_empty(
    '/Users/mag/Documents/PHD1Y/Space_Work/Pixel_non_P2D/MICMAC/test_playground/test_Orbax/save_orbax/'
)
# Two independent single-element zero arrays, keyed 'a' and 'b'.
state = {key: np.zeros(1) for key in ('a', 'b')}
# Small JSON-serialisable list, saved alongside the state at every step.
extra_params = [42, 43]

In [5]:
# Manager policy: keep at most the 2 newest checkpoints on disk, and treat
# every step as eligible for saving (interval of 1 step).
options = ocp.CheckpointManagerOptions(
    max_to_keep=2,
    save_interval_steps=1,
)

In [6]:
# One manager handling two named items per step: a pytree state and a JSON list.
mngr = ocp.CheckpointManager(
    path, options=options, item_names=('state', 'extra_params')
)

# Save steps 0..num_steps-1. Each step stores a small pytree whose values
# depend on the step, plus the `extra_params` list as JSON.
num_steps = 11
for step in range(num_steps):  # [0, 1, ..., 10]
    mngr.save(
        step,
        args=ocp.args.Composite(
            state=ocp.args.StandardSave(
                {'a': np.zeros(step + 1) + step, 'b': step}
            ),
            extra_params=ocp.args.JsonSave(extra_params),
        ),
    )
# Saves may complete asynchronously; block until they are done before restoring.
mngr.wait_until_finished()
# Restore the last step written by the loop (derived from num_steps rather
# than a literal, so changing the loop bound cannot break the restore).
restored = mngr.restore(num_steps - 1)
restored_state, restored_extra_params = restored.state, restored.extra_params

In [8]:
# Steps still on disk; with max_to_keep=2 only the two newest survive.
mngr.all_steps()

[9, 10]

In [9]:
# Most recent step that has a checkpoint.
mngr.latest_step()

10

In [10]:
# Evaluates to False here (see output). Step 10 was already saved, which is
# presumably why the manager declines to save it again.
mngr.should_save(10)

False

In [11]:
# Restore whichever step is newest on disk (10 in this run).
latest = mngr.latest_step()
restored = mngr.restore(latest)


In [12]:
# Display the restore result: a CompositeArgs holding the 'state' and
# 'extra_params' items.
restored

CompositeArgs({'extra_params': [42, 43], 'state': {'a': array([10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10.]), 'b': 10}})