Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ephemeral checkpoints #397

Merged
merged 3 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions olmo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,7 @@ class FSDPConfig(BaseConfig):
class CheckpointType(StrEnum):
sharded = "sharded"
unsharded = "unsharded"
sharded_ephemeral = "sharded_ephemeral"


class ShardedCheckpointerType(StrEnum):
Expand Down Expand Up @@ -757,19 +758,31 @@ class TrainConfig(BaseConfig):

save_interval: int = 1000
"""
How often (in terms of batches) to save training state checkpoints that can be used for restarts.
How often (in terms of steps) to save sharded training state checkpoints.
"""

save_interval_unsharded: Optional[int] = None
"""
How often (if at all) to save the unsharded state to a single file.
How often (if at all) to save unsharded training state checkpoint.
For large models it can be costly to save these, so it usually makes sense to save
these less often than regular (sharded) training checkpoints.
"""

save_interval_ephemeral: Optional[int] = None
"""
How often (if at all) to save ephemeral sharded checkpoints. These are identical to the
checkpoints saved every `save_interval` steps, except that only the most recent one is kept.
This is useful when you want to checkpoint often for restarts in case of failures, but don't
want to keep the majority of these checkpoints.

For example, suppose you want to keep a permanent checkpoint every 1000 steps, but you also
want to save a temporary checkpoint every 100 steps in case your job fails. In that case you
would set `save_interval=1000` and `save_interval_ephemeral=100`.
"""

save_num_checkpoints_to_keep: int = -1
"""
How many checkpoints to keep.
How many sharded checkpoints to keep.
"""

save_num_unsharded_checkpoints_to_keep: int = -1
Expand Down
40 changes: 37 additions & 3 deletions olmo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ class Trainer:
"""Tracks the global total number of tokens trained on."""
checkpoints: List[Path] = field(default_factory=list)
unsharded_checkpoints: List[Path] = field(default_factory=list)
ephemeral_checkpoints: List[Path] = field(default_factory=list)
min_train_loss: float = float("inf")
cur_train_loss: float = float("inf")
indices_file: Optional[TextIO] = None
Expand Down Expand Up @@ -142,6 +143,7 @@ def trainer_state_dict(self) -> Dict[str, Any]:
"world_size": get_world_size(),
"checkpoints": self.checkpoints,
"unsharded_checkpoints": self.unsharded_checkpoints,
"ephemeral_checkpoints": self.ephemeral_checkpoints,
"rng": {
"python": random.getstate(),
"numpy": np.random.get_state(),
Expand All @@ -162,6 +164,11 @@ def load_trainer_state_dict(self, state_dict: Dict[str, Any]) -> None:
for path in state_dict["unsharded_checkpoints"]
if path.is_dir() and path.resolve().parent == Path(self.cfg.save_folder).resolve()
]
self.ephemeral_checkpoints = [
path
for path in state_dict["ephemeral_checkpoints"]
if path.is_dir() and path.resolve().parent == Path(self.cfg.save_folder).resolve()
]

# Dataset / dataloader position.
checkpoint_epoch = state_dict.get("epoch", 0)
Expand Down Expand Up @@ -245,6 +252,11 @@ def _save_checkpoint(
current_checkpoints = self.unsharded_checkpoints
link_latest = get_global_rank() == 0
num_checkpoints_to_keep = self.cfg.save_num_unsharded_checkpoints_to_keep
elif checkpoint_type == CheckpointType.sharded_ephemeral:
suffix = ""
current_checkpoints = self.ephemeral_checkpoints
link_latest = get_fs_local_rank() == 0
num_checkpoints_to_keep = 1
else:
raise NotImplementedError(checkpoint_type)

Expand Down Expand Up @@ -305,8 +317,8 @@ def save_sharded_checkpoint(self) -> Tuple[PathOrStr, Optional[PathOrStr]]:
checkpointer = build_sharded_checkpointer(self.cfg)
return self._save_checkpoint(checkpointer, CheckpointType.sharded)

def remove_sharded_checkpoint(self, idx: int = 0):
oldest_checkpoint = self.checkpoints.pop(idx)
def _remove_sharded_checkpoint(self, idx: int, checkpoints: List[Path]):
oldest_checkpoint = checkpoints.pop(idx)
barrier()
if get_fs_local_rank() == 0 and oldest_checkpoint.is_dir():
shutil.rmtree(oldest_checkpoint, ignore_errors=True)
Expand All @@ -315,6 +327,12 @@ def remove_sharded_checkpoint(self, idx: int = 0):
latest_path.unlink()
barrier()

def remove_sharded_checkpoint(self, idx: int = 0):
self._remove_sharded_checkpoint(idx, self.checkpoints)

def remove_ephemeral_checkpoint(self, idx: int = 0):
self._remove_sharded_checkpoint(idx, self.ephemeral_checkpoints)

def restore_sharded_checkpoint(
self,
load_path: PathOrStr,
Expand Down Expand Up @@ -406,6 +424,8 @@ def remove_checkpoint(self, idx: int = 0, checkpoint_type: CheckpointType = Chec
self.remove_sharded_checkpoint(idx=idx)
elif checkpoint_type == CheckpointType.unsharded:
self.remove_unsharded_checkpoint(idx=idx)
elif checkpoint_type == CheckpointType.sharded_ephemeral:
self.remove_ephemeral_checkpoint(idx=idx)
else:
raise NotImplementedError(checkpoint_type)

Expand Down Expand Up @@ -850,14 +870,28 @@ def on_trace_ready(p):
elif self.global_step % self.cfg.canceled_check_interval == 0:
canceled = self.check_if_cancelled()

# Maybe save sharded checkpoint.
# Maybe save sharded or ephemeral sharded checkpoint.
if canceled or (
self.global_step % self.cfg.save_interval == 0 and self.cfg.save_num_checkpoints_to_keep != 0
):
log.info("Saving checkpoint...")
checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded)
log.info(f"Checkpoint saved to {checkpoint_path}")

# Remove any ephemeral checkpoints.
while self.ephemeral_checkpoints:
self.remove_ephemeral_checkpoint()

# Reset speed monitor so that we don't count the time taken to save checkpoints.
speed_monitor.reset()
elif (
self.cfg.save_interval_ephemeral is not None
and self.global_step % self.cfg.save_interval_ephemeral == 0
):
log.info("Saving ephemeral checkpoint...")
checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded_ephemeral)
log.info(f"Checkpoint saved to {checkpoint_path}")

# Reset speed monitor so that we don't count the time taken to save checkpoints.
speed_monitor.reset()

Expand Down
Loading