From fe3d2002189128f87c7aa468f820b3668a9c9390 Mon Sep 17 00:00:00 2001
From: Xuefeng Gu
Date: Tue, 10 Feb 2026 00:49:41 +0000
Subject: [PATCH] Add the option to avoid loading the checkpoint twice

---
 src/MaxText/rl/train_rl.py            | 28 +++++++++++++++++++++++----
 src/maxtext/configs/post_train/rl.yml |  2 +-
 src/maxtext/configs/types.py          |  1 +
 3 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/src/MaxText/rl/train_rl.py b/src/MaxText/rl/train_rl.py
index e5584b3094..70a76d9ce3 100644
--- a/src/MaxText/rl/train_rl.py
+++ b/src/MaxText/rl/train_rl.py
@@ -358,10 +358,18 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
     )
 
   # TODO: @mazumdera: change this to use lora
-  # TODO: @xfgu: instead of restoring a second time from GCS, can we just copy reference_model
-  # Load policy model
-  max_logging.log("Creating policy model with same config as reference model on trainer mesh")
-  actor_model, actor_mesh = get_maxtext_model(trainer_config, trainer_devices)
+  if trainer_config.load_checkpoint_only_once:
+    max_logging.log("Creating policy model by copying reference model instead of restoring from checkpoint again.")
+    with reference_mesh:
+      actor_base_model = nnx.clone(reference_model.base)
+    use_no_op_mappings = "maxtext_config" in trainer_config.vllm_additional_config
+    actor_model = TunixMaxTextAdapter(base_model=actor_base_model, use_no_op_mappings=use_no_op_mappings)
+    actor_model.config = None
+    actor_mesh = reference_mesh
+  else:
+    max_logging.log("Creating policy model with same config as reference model on trainer mesh")
+    actor_model, actor_mesh = get_maxtext_model(trainer_config, trainer_devices)
+
   if trainer_config.debug.rl:
     max_logging.log("Policy Model initialized successfully")
 
@@ -530,11 +538,23 @@
 
   # Start training
+  if trainer_config.load_checkpoint_only_once:
+    max_logging.log("Capturing reference model state before training.")
+    ref_state_before = nnx.to_pure_dict(nnx.state(reference_model.base, nnx.Param))
+
   max_logging.warning("Starting RL training...")
   with reference_mesh, nn_partitioning.axis_rules(trainer_config.logical_axis_rules):
     rl_trainer.train(train_dataset)
 
+  if trainer_config.load_checkpoint_only_once:
+    max_logging.log("Checking if reference model state changed during training.")
+    ref_state_after = nnx.to_pure_dict(nnx.state(reference_model.base, nnx.Param))
+    check = jax.tree_util.tree_map(jax.numpy.array_equal, ref_state_before, ref_state_after)
+    if not jax.tree_util.tree_all(check):
+      raise ValueError("Reference model parameters changed during training!")
+    max_logging.log("Reference model parameters verified to be unchanged during training.")
+
   max_logging.warning("RL Training Completed Successfully!")
 
   # Let's evaluate our model!
 
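Note on the mechanism above: the hunks rely on nnx.clone producing an independent deep copy of the reference model, plus a before/after parameter snapshot to prove the reference weights are never touched during training. Below is a minimal, self-contained sketch of that clone-and-verify pattern; the toy nnx.Linear and the names reference, actor, before, and after are illustrative stand-ins, not part of the patch.

import jax
import jax.numpy as jnp
from flax import nnx

# Stand-in for the reference model that is restored once from the checkpoint.
reference = nnx.Linear(4, 4, rngs=nnx.Rngs(0))

# Snapshot the reference parameters before training, as the patch does.
before = nnx.to_pure_dict(nnx.state(reference, nnx.Param))

# nnx.clone yields an independent copy: training the actor must not
# alias, and thereby silently mutate, the reference parameters.
actor = nnx.clone(reference)
actor.kernel.value = actor.kernel.value + 1.0  # simulate a training update

# After "training", verify the reference is bitwise unchanged.
after = nnx.to_pure_dict(nnx.state(reference, nnx.Param))
check = jax.tree_util.tree_map(jnp.array_equal, before, after)
if not jax.tree_util.tree_all(check):
  raise ValueError("Reference model parameters changed during training!")
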
diff --git a/src/maxtext/configs/post_train/rl.yml b/src/maxtext/configs/post_train/rl.yml
index 5a8f57f664..9d741e7a8c 100644
--- a/src/maxtext/configs/post_train/rl.yml
+++ b/src/maxtext/configs/post_train/rl.yml
@@ -15,7 +15,7 @@
 # RL Configuration
 # This config consolidates common parameters for RL training across different model sizes
 
-base_config: "base.yml"
+base_config: "../base.yml"
 
 # ====== Hardware =====
 trainer_devices_fraction: 0.5
diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py
index 1043dbc3ed..91ef50bd26 100644
--- a/src/maxtext/configs/types.py
+++ b/src/maxtext/configs/types.py
@@ -288,6 +288,7 @@ class Checkpointing(BaseModel):
   lora_input_adapters_path: PathStr = Field("", description="Input GCS path for LoRA adapters.")
   load_full_state_path: PathStr = Field("", description="Loads the complete training state from a checkpoint path.")
   enable_checkpointing: bool = Field(True, description="If True, enables saving checkpoints during training.")
+  load_checkpoint_only_once: bool = Field(False, description="If True, deep-copies the reference model into the actor model instead of restoring the checkpoint a second time.")
   async_checkpointing: bool = Field(True, description="If True, uses an asynchronous checkpointer for performance.")
   checkpoint_period: int = Field(10_000, description="The frequency (in steps) at which to save checkpoints.")
   max_num_checkpoints_to_keep: int | None = Field(None, description="Maximum number of checkpoints to keep.")
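Note on the config change above: the new field defaults to False, so existing configs keep the old restore-twice behavior unless they opt in. A minimal sketch of the opt-in semantics follows; the standalone Checkpointing model mirrors only the two relevant fields from src/maxtext/configs/types.py, and the exact launch-time override syntax for rl.yml is not shown here.

from pydantic import BaseModel, Field

# Reduced stand-in for the Checkpointing config model in types.py.
class Checkpointing(BaseModel):
  enable_checkpointing: bool = Field(True, description="If True, enables saving checkpoints during training.")
  load_checkpoint_only_once: bool = Field(False, description="If True, deep-copies the reference model into the actor model.")

default_cfg = Checkpointing()                               # restore twice (old behavior)
opt_in_cfg = Checkpointing(load_checkpoint_only_once=True)  # restore once, then clone
print(default_cfg.load_checkpoint_only_once, opt_in_cfg.load_checkpoint_only_once)  # False True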