# Orbax Saving Tests

In [1]:
import jax
import jax.lax as jlax
import jax.numpy as jnp
import numpy as np
from jax import config

In [2]:
config.update('jax_enable_x64', True)

# Few tests

In [3]:
import orbax.checkpoint as ocp

In [4]:
path = ocp.test_utils.create_empty('/Users/mag/Documents/PHD1Y/Space_Work/Pixel_non_P2D/MICMAC/test_playground/test_Orbax/save_orbax/')
state = {
    'r': 0,
    'B_f': np.zeros((4,2)),
}

In [5]:
num_steps_save = 5
num_steps = 15

In [6]:
save_interval_steps = 5
options = ocp.CheckpointManagerOptions(max_to_keep=None, save_interval_steps=save_interval_steps)

In [8]:
mngr = ocp.CheckpointManager(
    path, options=options, item_names=('state')
)

mngr = ocp.type_handlers.ArrayHandler('array_save')

In [9]:
save_array = jnp.zeros((num_steps_save, 9))

def func_update_r_B_f(carry, iteration):
    # r, B_f, save_array = carry
    pytree = carry
    new_r = pytree['r']+0.01
    new_B_f = pytree['B_f']+0.01

    new_save_array = jnp.copy(pytree['save_array'])

    new_save_array = new_save_array.at[iteration%num_steps_save,:-1].set(new_B_f.ravel())
    new_save_array = new_save_array.at[iteration%num_steps_save,-1].set(new_r)
    
    # mngr.save(
    #     iteration,
    #     args=ocp.args.Composite(
    #         state=ocp.args.StandardSave({'r':new_save_array[:,-1],'B_f':new_save_array[:,:-1]}),
    #         custom_metadata=args.,
    #     ),
    # )
    mngr.save(
        iteration,
        state=ocp.args.StandardSave({'save_array':new_save_array}),
    )
    orbax.checkpoint.array_checkpoint_handler.ArraySaveArgs
    # new_pytree = jax.tree_util.tree_flatten()
    return {'r':new_r,'B_f':new_B_f,'save_array':new_save_array}, {'r':new_r,'B_f':new_B_f}


dict_init_params = state.copy()
dict_init_params['save_array'] = save_array
pytree_input, treedef = jax.tree_util.tree_flatten(dict_init_params)

carry, result = jlax.scan(func_update_r_B_f, dict_init_params, jnp.arange(0, num_steps))

mngr.wait_until_finished()
restored = mngr.restore(15)
restored_state, restored_extra_params = restored.state, restored.extra_params

AttributeError: 'ArrayHandler' object has no attribute 'save'

In [None]:
carry

{'B_f': Array([[0.15, 0.15],
        [0.15, 0.15],
        [0.15, 0.15],
        [0.15, 0.15]], dtype=float64),
 'r': Array(0.15, dtype=float64),
 'save_array': Array([[0.11, 0.11, 0.11, 0.11, 0.11, 0.11, 0.11, 0.11, 0.11],
        [0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12],
        [0.13, 0.13, 0.13, 0.13, 0.13, 0.13, 0.13, 0.13, 0.13],
        [0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14],
        [0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15]],      dtype=float64)}

In [None]:
result

{'B_f': Array([[[0.01, 0.01],
         [0.01, 0.01],
         [0.01, 0.01],
         [0.01, 0.01]],
 
        [[0.02, 0.02],
         [0.02, 0.02],
         [0.02, 0.02],
         [0.02, 0.02]],
 
        [[0.03, 0.03],
         [0.03, 0.03],
         [0.03, 0.03],
         [0.03, 0.03]],
 
        [[0.04, 0.04],
         [0.04, 0.04],
         [0.04, 0.04],
         [0.04, 0.04]],
 
        [[0.05, 0.05],
         [0.05, 0.05],
         [0.05, 0.05],
         [0.05, 0.05]],
 
        [[0.06, 0.06],
         [0.06, 0.06],
         [0.06, 0.06],
         [0.06, 0.06]],
 
        [[0.07, 0.07],
         [0.07, 0.07],
         [0.07, 0.07],
         [0.07, 0.07]],
 
        [[0.08, 0.08],
         [0.08, 0.08],
         [0.08, 0.08],
         [0.08, 0.08]],
 
        [[0.09, 0.09],
         [0.09, 0.09],
         [0.09, 0.09],
         [0.09, 0.09]],
 
        [[0.1 , 0.1 ],
         [0.1 , 0.1 ],
         [0.1 , 0.1 ],
         [0.1 , 0.1 ]],
 
        [[0.11, 0.11],
         [0.11, 0.11

In [None]:
mngr.all_steps()

[9, 10]

In [None]:
mngr.latest_step()

10

In [None]:
mngr.should_save(10)

False

In [None]:
restored = mngr.restore(
    mngr.latest_step(),
)


In [None]:
restored

CompositeArgs({'extra_params': [42, 43], 'state': {'a': array([10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10.]), 'b': 10}})