28 changes: 24 additions & 4 deletions src/MaxText/rl/train_rl.py
@@ -358,10 +358,18 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
)

# TODO: @mazumdera: change this to use lora
# TODO: @xfgu: instead of restoring a second time from GCS, can we just copy reference_model
# Load policy model
max_logging.log("Creating policy model with same config as reference model on trainer mesh")
actor_model, actor_mesh = get_maxtext_model(trainer_config, trainer_devices)
if trainer_config.load_checkpoint_only_once:
max_logging.log("Creating policy model by copying reference model instead of restoring from checkpoint again.")
with reference_mesh:
actor_base_model = nnx.clone(reference_model.base)
use_no_op_mappings = "maxtext_config" in trainer_config.vllm_additional_config
actor_model = TunixMaxTextAdapter(base_model=actor_base_model, use_no_op_mappings=use_no_op_mappings)
actor_model.config = None
actor_mesh = reference_mesh
else:
max_logging.log("Creating policy model with same config as reference model on trainer mesh")
actor_model, actor_mesh = get_maxtext_model(trainer_config, trainer_devices)


if trainer_config.debug.rl:
max_logging.log("Policy Model initialized successfully")
@@ -530,11 +538,23 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):

# Start training

if trainer_config.load_checkpoint_only_once:
max_logging.log("Capturing reference model state before training.")
ref_state_before = nnx.to_pure_dict(nnx.state(reference_model.base, nnx.Param))

max_logging.warning("Starting RL training...")

with reference_mesh, nn_partitioning.axis_rules(trainer_config.logical_axis_rules):
rl_trainer.train(train_dataset)

if trainer_config.load_checkpoint_only_once:
max_logging.log("Checking if reference model state changed during training.")
ref_state_after = nnx.to_pure_dict(nnx.state(reference_model.base, nnx.Param))
check = jax.tree_util.tree_map(jax.numpy.array_equal, ref_state_before, ref_state_after)
if not jax.tree_util.tree_all(check):
raise ValueError("Reference model parameters changed during training!")
max_logging.log("Reference model parameters verified to be unchanged during training.")

max_logging.warning("RL Training Completed Successfully!")

# Let's evaluate our model!
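For context, the change above swaps a second GCS restore for an in-memory copy: nnx.clone produces an independent copy of the reference model's base, the copy is wrapped as the actor model, and after training the reference parameters are compared leaf by leaf to prove the clone did not alias the original weights. Below is a minimal standalone sketch of that copy-and-verify pattern; the toy MLP module, its sizes, and the omitted training step are assumptions for illustration, not MaxText code.

import jax
import jax.numpy as jnp
from flax import nnx

class MLP(nnx.Module):
  def __init__(self, din, dout, *, rngs: nnx.Rngs):
    self.linear = nnx.Linear(din, dout, rngs=rngs)

  def __call__(self, x):
    return self.linear(x)

reference = MLP(4, 4, rngs=nnx.Rngs(0))

# Clone once instead of restoring the checkpoint a second time; the clone is an
# independent copy, so training the actor must not mutate the reference.
actor = nnx.clone(reference)

# Snapshot the reference parameters before training.
before = nnx.to_pure_dict(nnx.state(reference, nnx.Param))

# ... train `actor` here (optimizer step omitted in this sketch) ...

# Verify element for element that the reference parameters are unchanged.
after = nnx.to_pure_dict(nnx.state(reference, nnx.Param))
unchanged = jax.tree_util.tree_all(jax.tree_util.tree_map(jnp.array_equal, before, after))
if not unchanged:
  raise ValueError("Reference model parameters changed during training!")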
2 changes: 1 addition & 1 deletion src/maxtext/configs/post_train/rl.yml
@@ -15,7 +15,7 @@
# RL Configuration
# This config consolidates common parameters for RL training across different model sizes

base_config: "base.yml"
base_config: "../base.yml"

# ====== Hardware =====
trainer_devices_fraction: 0.5
1 change: 1 addition & 0 deletions src/maxtext/configs/types.py
@@ -288,6 +288,7 @@ class Checkpointing(BaseModel):
lora_input_adapters_path: PathStr = Field("", description="Input GCS path for LoRA adapters.")
load_full_state_path: PathStr = Field("", description="Loads the complete training state from a checkpoint path.")
enable_checkpointing: bool = Field(True, description="If True, enables saving checkpoints during training.")
load_checkpoint_only_once: bool = Field(False, description="If True, restores the checkpoint once and deep-copies the reference model to create the actor model instead of restoring again.")
async_checkpointing: bool = Field(True, description="If True, uses an asynchronous checkpointer for performance.")
checkpoint_period: int = Field(10_000, description="The frequency (in steps) at which to save checkpoints.")
max_num_checkpoints_to_keep: int | None = Field(None, description="Maximum number of checkpoints to keep.")
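The new flag is opt-in: it defaults to False, so the second checkpoint restore remains the default behavior. A hedged sketch of enabling it in an RL YAML config (the placement is an assumption, not shown in this PR):

# in an RL config such as one layered on rl.yml (hypothetical)
load_checkpoint_only_once: true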