2020from absl import flags
2121import datetime
2222from etils import epath
23+ from flax import nnx
2324from flax .training import train_state
2425import jax
2526from maxtext .utils .globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE
@@ -536,7 +537,7 @@ def load_state_if_possible(
536537 load_parameters_from_path : str ,
537538 load_full_state_from_path : str ,
538539 checkpoint_storage_concurrent_gb : int ,
539- abstract_unboxed_pre_state : train_state .TrainState ,
540+ abstract_unboxed_pre_state : train_state .TrainState | nnx . State ,
540541 enable_single_replica_ckpt_restoring : bool | None = False ,
541542 dataset_type : str | None = "tfds" ,
542543 step : int = - 1 , # -1 means latest
@@ -604,9 +605,14 @@ def map_to_pspec(data):
604605 )
605606 ocp .type_handlers .register_type_handler (jax .Array , array_handler , override = True )
606607
607- restore_args = jax .tree_util .tree_map (map_to_pspec , abstract_unboxed_pre_state )
608+ # Convert nnx.State to pure dict to match how checkpoints are saved for NNX
609+ restore_target = abstract_unboxed_pre_state
610+ if isinstance (abstract_unboxed_pre_state , nnx .State ):
611+ restore_target = abstract_unboxed_pre_state .to_pure_dict ()
612+
613+ restore_args = jax .tree_util .tree_map (map_to_pspec , restore_target )
608614 checkpoint_args = ocp .args .PyTreeRestore (
609- item = abstract_unboxed_pre_state ,
615+ item = restore_target ,
610616 restore_args = restore_args ,
611617 partial_restore = True ,
612618 )
@@ -620,9 +626,7 @@ def map_to_pspec(data):
620626 (EmergencyCheckpointManager , EmergencyReplicatorCheckpointManager ),
621627 ):
622628 return (
623- checkpoint_manager .restore (
624- step , args = Composite (state = checkpoint_args )
625- ).state ,
629+ checkpoint_manager .restore (step , args = Composite (state = checkpoint_args )).state ,
626630 None ,
627631 )
628632 # Case 2: Matches if dataset type is "grain" and the data iterator is not a
@@ -647,9 +651,14 @@ def map_to_pspec(data):
647651 return (checkpoint_manager .restore (step , args = Composite (items = checkpoint_args )), None )
648652
649653 if load_parameters_from_path != "" :
654+ if isinstance (abstract_unboxed_pre_state , nnx .State ):
655+ _ , params , _ = nnx .split (abstract_unboxed_pre_state .model , nnx .Param , ...)
656+ else :
657+ params = abstract_unboxed_pre_state .params
658+
650659 restored_params = load_params_from_path (
651660 load_parameters_from_path ,
652- abstract_unboxed_pre_state . params ,
661+ params ,
653662 checkpoint_storage_concurrent_gb ,
654663 use_ocdbt = use_ocdbt ,
655664 use_zarr3 = use_zarr3 ,
@@ -741,7 +750,18 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step
741750 # Determine the effective step for saving a checkpoint.
742751 # If 'step' is not provided, this call is for a potential final checkpoint
743752 # and use the last completed step from the state.
744- actual_step = (int (state .step ) - 1 ) if step is None else int (step )
753+ if step is not None :
754+ actual_step = int (step )
755+ else :
756+ if config .pure_nnx :
757+ actual_step = int (state .optimizer .step ) - 1
758+ else :
759+ # Linen TrainState has .step attribute
760+ actual_step = int (state .step ) - 1
761+
762+ if config .pure_nnx :
763+ # Convert nnx.State to dict.
764+ state = state .to_pure_dict ()
745765
746766 # Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic.
747767 # This occurs if this function was called:
0 commit comments