Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions src/maxtext/common/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def create_orbax_checkpoint_manager(
enable_single_controller: bool = False,
colocated_python_checkpointing: bool = False,
enable_single_replica_ckpt_restoring: bool = False,
enable_autocheckpoint: bool = False,
):
"""Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled."""
if not enable_checkpointing:
Expand Down Expand Up @@ -248,11 +249,21 @@ def create_orbax_checkpoint_manager(
# local storage checkpoint needs parent directory created
p = gcs_utils.mkdir_and_check_permissions(checkpoint_dir)
if enable_continuous_checkpointing:
max_logging.log("Enabling policy for continuous checkpointing.")
save_decision_policy = save_decision_policy_lib.ContinuousCheckpointingPolicy()
preservation_policy = preservation_policy_lib.LatestN(max_num_checkpoints_to_keep)
elif enable_autocheckpoint:
max_logging.log("Enabling policy for autocheckpoint.")
save_decision_policy = save_decision_policy_lib.AnySavePolicy(
[
save_decision_policy_lib.PreemptionCheckpointingPolicy(),
save_decision_policy_lib.FixedIntervalPolicy(save_interval_steps),
]
)
else:
max_logging.log("Enabling policy for fixed interval checkpointing.")
save_decision_policy = save_decision_policy_lib.FixedIntervalPolicy(interval=save_interval_steps)
preservation_policy = preservation_policy_lib.LatestN(max_num_checkpoints_to_keep)
preservation_policy = preservation_policy_lib.LatestN(max_num_checkpoints_to_keep)

async_options = None
if enable_continuous_checkpointing:
async_options = ocp.AsyncOptions(
Expand Down Expand Up @@ -752,6 +763,7 @@ def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator=
or (step % config.checkpoint_period == 0 and not config.enable_continuous_checkpointing)
or (step % config.checkpoint_period == 0)
or (config.enable_emergency_checkpoint and step % config.local_checkpoint_period == 0)
or (config.enable_autocheckpoint and checkpoint_manager.reached_preemption(step))
Comment thread
abhinavclemson marked this conversation as resolved.
):
blocking_until_ready_start = time.time()
max_logging.log(f"Waiting for step {step} to finish before checkpoint...")
Expand Down
3 changes: 3 additions & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ source_checkpoint_layout: "orbax"

# Only applicable to Single Controller/Pathways on Cloud. Experimental feature, under testing
colocated_python_checkpointing: False

# Enables autocheckpoint: saves a checkpoint when a preemption is reached, in addition to the regular fixed-interval saves.
enable_autocheckpoint: False
############################### end checkpointing ##################################


Expand Down
3 changes: 3 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,9 @@ class Checkpointing(BaseModel):
False,
description="If True, enables checkpointing from remote TPU VMs instead of head node on pathways.",
)
enable_autocheckpoint: bool = Field(
False, description="If True, enables autocheckpoint or preemption induced checkpointing."
)


class OrbaxStorage(BaseModel):
Expand Down
3 changes: 2 additions & 1 deletion src/maxtext/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

# pylint: disable=bare-except, consider-using-generator
""" Utils that are only interesting for training in MaxText. """
"""Utils that are only interesting for training in MaxText."""

import os
import jax
Expand Down Expand Up @@ -82,6 +82,7 @@ def create_training_tools(config, model, mesh):
config.enable_single_controller,
config.colocated_python_checkpointing,
config.enable_single_replica_ckpt_restoring,
config.enable_autocheckpoint,
)

return init_rng, checkpoint_manager, learning_rate_schedule, tx
Expand Down
Loading