Skip to content

Commit

Permalink
Ephemeral checkpoints (#397)
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh committed Dec 13, 2023
1 parent 6f2abfb commit e2d77c4
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 6 deletions.
19 changes: 16 additions & 3 deletions olmo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,7 @@ class FSDPConfig(BaseConfig):
class CheckpointType(StrEnum):
sharded = "sharded"
unsharded = "unsharded"
sharded_ephemeral = "sharded_ephemeral"


class ShardedCheckpointerType(StrEnum):
Expand Down Expand Up @@ -757,19 +758,31 @@ class TrainConfig(BaseConfig):

save_interval: int = 1000
"""
How often (in terms of batches) to save training state checkpoints that can be used for restarts.
How often (in terms of steps) to save sharded training state checkpoints.
"""

save_interval_unsharded: Optional[int] = None
"""
How often (if at all) to save the unsharded state to a single file.
How often (if at all) to save an unsharded training state checkpoint.
For large models it can be costly to save these, so it usually makes sense to save
these less often than regular (sharded) training checkpoints.
"""

save_interval_ephemeral: Optional[int] = None
"""
How often (if at all) to save ephemeral sharded checkpoints. These checkpoints are the same
as those saved every `save_interval` except that only the most recent one of them is kept.
This is useful when you want to checkpoint often for restarts in case of failures, but don't
want to keep the majority of these checkpoints.
For example, suppose you want to keep a checkpoint every 1000 steps, but you also want to save
a temporary checkpoint every 100 steps in case your job fails. In that case you would
set `save_interval=1000` and `save_interval_ephemeral=100`.
"""

save_num_checkpoints_to_keep: int = -1
"""
How many checkpoints to keep.
How many sharded checkpoints to keep.
"""

save_num_unsharded_checkpoints_to_keep: int = -1
Expand Down
46 changes: 43 additions & 3 deletions olmo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ class Trainer:
"""Tracks the global total number of tokens trained on."""
checkpoints: List[Path] = field(default_factory=list)
unsharded_checkpoints: List[Path] = field(default_factory=list)
ephemeral_checkpoints: List[Path] = field(default_factory=list)
min_train_loss: float = float("inf")
cur_train_loss: float = float("inf")
indices_file: Optional[TextIO] = None
Expand Down Expand Up @@ -142,6 +143,7 @@ def trainer_state_dict(self) -> Dict[str, Any]:
"world_size": get_world_size(),
"checkpoints": self.checkpoints,
"unsharded_checkpoints": self.unsharded_checkpoints,
"ephemeral_checkpoints": self.ephemeral_checkpoints,
"rng": {
"python": random.getstate(),
"numpy": np.random.get_state(),
Expand All @@ -162,6 +164,11 @@ def load_trainer_state_dict(self, state_dict: Dict[str, Any]) -> None:
for path in state_dict["unsharded_checkpoints"]
if path.is_dir() and path.resolve().parent == Path(self.cfg.save_folder).resolve()
]
self.ephemeral_checkpoints = [
path
for path in state_dict["ephemeral_checkpoints"]
if path.is_dir() and path.resolve().parent == Path(self.cfg.save_folder).resolve()
]

# Dataset / dataloader position.
checkpoint_epoch = state_dict.get("epoch", 0)
Expand Down Expand Up @@ -245,6 +252,11 @@ def _save_checkpoint(
current_checkpoints = self.unsharded_checkpoints
link_latest = get_global_rank() == 0
num_checkpoints_to_keep = self.cfg.save_num_unsharded_checkpoints_to_keep
elif checkpoint_type == CheckpointType.sharded_ephemeral:
suffix = ""
current_checkpoints = self.ephemeral_checkpoints
link_latest = get_fs_local_rank() == 0
num_checkpoints_to_keep = 1
else:
raise NotImplementedError(checkpoint_type)

Expand Down Expand Up @@ -305,8 +317,12 @@ def save_sharded_checkpoint(self) -> Tuple[PathOrStr, Optional[PathOrStr]]:
checkpointer = build_sharded_checkpointer(self.cfg)
return self._save_checkpoint(checkpointer, CheckpointType.sharded)

def remove_sharded_checkpoint(self, idx: int = 0):
oldest_checkpoint = self.checkpoints.pop(idx)
def save_ephemeral_checkpoint(self) -> Tuple[PathOrStr, Optional[PathOrStr]]:
    """Save an ephemeral sharded checkpoint.

    Builds the configured sharded checkpointer and saves with
    ``CheckpointType.sharded_ephemeral``, so only the most recent
    ephemeral checkpoint is retained.
    """
    return self._save_checkpoint(
        build_sharded_checkpointer(self.cfg), CheckpointType.sharded_ephemeral
    )

def _remove_sharded_checkpoint(self, idx: int, checkpoints: List[Path]):
oldest_checkpoint = checkpoints.pop(idx)
barrier()
if get_fs_local_rank() == 0 and oldest_checkpoint.is_dir():
shutil.rmtree(oldest_checkpoint, ignore_errors=True)
Expand All @@ -315,6 +331,12 @@ def remove_sharded_checkpoint(self, idx: int = 0):
latest_path.unlink()
barrier()

# Remove the sharded checkpoint at position `idx` of `self.checkpoints`
# (default: the oldest) and delete its directory on disk.
def remove_sharded_checkpoint(self, idx: int = 0):
    self._remove_sharded_checkpoint(idx, self.checkpoints)

# Remove the ephemeral checkpoint at position `idx` of
# `self.ephemeral_checkpoints` (default: the oldest) and delete its
# directory on disk.
def remove_ephemeral_checkpoint(self, idx: int = 0):
    self._remove_sharded_checkpoint(idx, self.ephemeral_checkpoints)

def restore_sharded_checkpoint(
self,
load_path: PathOrStr,
Expand Down Expand Up @@ -373,6 +395,8 @@ def save_checkpoint(
return self.save_sharded_checkpoint()
elif checkpoint_type == CheckpointType.unsharded:
return self.save_unsharded_checkpoint()
elif checkpoint_type == CheckpointType.sharded_ephemeral:
return self.save_ephemeral_checkpoint()
else:
raise NotImplementedError(checkpoint_type)

Expand Down Expand Up @@ -406,6 +430,8 @@ def remove_checkpoint(self, idx: int = 0, checkpoint_type: CheckpointType = Chec
self.remove_sharded_checkpoint(idx=idx)
elif checkpoint_type == CheckpointType.unsharded:
self.remove_unsharded_checkpoint(idx=idx)
elif checkpoint_type == CheckpointType.sharded_ephemeral:
self.remove_ephemeral_checkpoint(idx=idx)
else:
raise NotImplementedError(checkpoint_type)

Expand Down Expand Up @@ -850,14 +876,28 @@ def on_trace_ready(p):
elif self.global_step % self.cfg.canceled_check_interval == 0:
canceled = self.check_if_cancelled()

# Maybe save sharded checkpoint.
# Maybe save sharded or ephemeral sharded checkpoint.
if canceled or (
self.global_step % self.cfg.save_interval == 0 and self.cfg.save_num_checkpoints_to_keep != 0
):
log.info("Saving checkpoint...")
checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded)
log.info(f"Checkpoint saved to {checkpoint_path}")

# Remove any ephemeral checkpoints.
while self.ephemeral_checkpoints:
self.remove_ephemeral_checkpoint()

# Reset speed monitor so that we don't count the time taken to save checkpoints.
speed_monitor.reset()
elif (
self.cfg.save_interval_ephemeral is not None
and self.global_step % self.cfg.save_interval_ephemeral == 0
):
log.info("Saving ephemeral checkpoint...")
checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded_ephemeral)
log.info(f"Checkpoint saved to {checkpoint_path}")

# Reset speed monitor so that we don't count the time taken to save checkpoints.
speed_monitor.reset()

Expand Down

0 comments on commit e2d77c4

Please sign in to comment.