Train a few steps after time limit reached (#362)
Co-authored-by: Dirk Groeneveld <dirkg@allenai.org>
epwalsh and dirkgr committed Jan 4, 2024
1 parent ac1aee1 commit 23eb949
Showing 3 changed files with 60 additions and 23 deletions.
7 changes: 7 additions & 0 deletions olmo/config.py
@@ -941,6 +941,13 @@ class TrainConfig(BaseConfig):
     to write out a final checkpoint.
     """

+    extra_steps_after_cancel: int = 10
+    """
+    Under certain conditions when a run is canceled we train for a few extra steps after saving
+    the final checkpoint so that when the run is restarted from the latest checkpoint we have some
+    overlap in metrics.
+    """
+
     early_stopping_factor: Optional[float] = None

     save_data_indices: bool = True
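
For context: the default of 10 extra steps interacts with stop_at as shown in the Trainer.fit() changes below. A minimal sketch of that interaction, assuming an illustrative helper that is not part of the codebase:

from typing import Optional


def effective_stop_at(global_step: int, extra_steps_after_cancel: int, stop_at: Optional[int]) -> Optional[int]:
    # Illustrative helper (not in the repo): once cancellation is initiated, train for
    # `extra_steps_after_cancel` more steps, but never past an explicit `stop_at`.
    proposed = global_step + extra_steps_after_cancel
    return proposed if stop_at is None else min(proposed, stop_at)


# Example: cancellation detected at step 1000 with the default of 10 extra steps.
assert effective_stop_at(1000, 10, None) == 1010
assert effective_stop_at(1000, 10, stop_at=1005) == 1005
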
19 changes: 13 additions & 6 deletions olmo/torch_util.py
@@ -116,10 +116,17 @@ def peak_gpu_memory(reset: bool = False) -> Optional[float]:
     return peak_mb


-def synchronize_flag(flag: bool, device: torch.device) -> bool:
-    if is_distributed():
-        flag_tensor = torch.tensor(flag, device=device)
-        dist.broadcast(flag_tensor, 0)
-        return flag_tensor.item()  # type: ignore
+V = TypeVar("V", bool, int, float)
+
+
+def synchronize_value(value: V, device: torch.device) -> V:
+    if dist.is_available() and dist.is_initialized():
+        value_tensor = torch.tensor(value, device=device)
+        dist.broadcast(value_tensor, 0)
+        return value_tensor.item()  # type: ignore
     else:
-        return flag
+        return value
+
+
+def synchronize_flag(flag: bool, device: torch.device) -> bool:
+    return synchronize_value(flag, device)
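
For reference, the new helper generalizes synchronize_flag from bools to any bool/int/float scalar that rank 0 broadcasts to the other ranks; under an initialized process group, dist.broadcast overwrites every rank's tensor with rank 0's copy, so all ranks agree on the cancellation decision and the extra-step count. A self-contained sketch, assuming only that torch is installed (in the repo these functions live in olmo/torch_util.py):

from typing import TypeVar

import torch
import torch.distributed as dist

V = TypeVar("V", bool, int, float)


def synchronize_value(value: V, device: torch.device) -> V:
    # With an initialized process group, every rank ends up with rank 0's value.
    if dist.is_available() and dist.is_initialized():
        value_tensor = torch.tensor(value, device=device)
        dist.broadcast(value_tensor, 0)
        return value_tensor.item()  # type: ignore
    else:
        # Single-process fallback: nothing to synchronize.
        return value


def synchronize_flag(flag: bool, device: torch.device) -> bool:
    return synchronize_value(flag, device)


# Without torch.distributed initialized, values pass through unchanged.
assert synchronize_value(10, torch.device("cpu")) == 10
assert synchronize_flag(True, torch.device("cpu")) is True
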
57 changes: 40 additions & 17 deletions olmo/train.py
Expand Up @@ -43,6 +43,7 @@
move_to_device,
peak_gpu_memory,
synchronize_flag,
synchronize_value,
)
from .util import upload

@@ -711,14 +712,16 @@ def eval(self) -> Dict[str, Any]:

         return eval_metrics

-    def check_if_cancelled(self) -> bool:
+    def check_if_cancelled(self) -> Tuple[bool, int]:
         should_cancel = False
         cancel_reason: Optional[str] = None
+        extra_steps = 0
         if get_global_rank() == 0:
             if self.cfg.time_limit is not None and time.time() - self._start_time >= self.cfg.time_limit:
                 # First check if we've reached the training time limit.
                 should_cancel = True
                 cancel_reason = "time limit reached"
+                extra_steps = self.cfg.extra_steps_after_cancel
             elif (
                 self.cfg.early_stopping_factor is not None
                 and self.global_step > self.cfg.scheduler.t_warmup
@@ -739,14 +742,20 @@ def check_if_cancelled(self) -> bool:
                         if tag.lower() in {"cancel", "canceled", "cancelled"}:
                             should_cancel = True
                             cancel_reason = "Weights & Biases tag"
+                            extra_steps = self.cfg.extra_steps_after_cancel
                             break
                 except RequestException:
                     pass

         run_canceled = synchronize_flag(should_cancel, self.device)
         if run_canceled and cancel_reason is not None:
-            log.warning(f"Run canceled due to {cancel_reason}")
-        return run_canceled
+            extra_steps = synchronize_value(extra_steps, self.device)
+            if extra_steps > 0:
+                log.warning(f"Run canceled due to {cancel_reason}, stopping in {extra_steps} more steps...")
+            else:
+                log.warning(f"Run canceled due to {cancel_reason}")
+
+        return run_canceled, extra_steps

     def fit(self):
         self._start_time = time.time()
@@ -818,7 +827,9 @@ def on_trace_ready(p):

         # Train.
         first_batch: bool = True
-        canceled: bool = False
+        cancel_initiated: bool = False
+        stop_at: Optional[int] = self.cfg.stop_at
+        save_checkpoints: bool = True

         with torch_profiler as p:
             for batch in self.train_loader:
@@ -870,15 +881,23 @@ def on_trace_ready(p):
                 ):
                     wandb.log(metrics, step=self.global_step)

-                # Check if run should be canceled.
-                if self.cfg.stop_at is not None and self.global_step >= self.cfg.stop_at:
-                    canceled = True
-                elif self.global_step % self.cfg.canceled_check_interval == 0:
-                    canceled = self.check_if_cancelled()
-
-                # Maybe save sharded or ephemeral sharded checkpoint.
-                if canceled or (
-                    self.global_step % self.cfg.save_interval == 0 and self.cfg.save_num_checkpoints_to_keep != 0
+                # Check if/when run should be canceled.
+                if not cancel_initiated and self.global_step % self.cfg.canceled_check_interval == 0:
+                    cancel_initiated, extra_steps = self.check_if_cancelled()
+                    if cancel_initiated:
+                        stop_at = (
+                            self.global_step + extra_steps
+                            if stop_at is None
+                            else min(self.global_step + extra_steps, stop_at)
+                        )
+
+                # Maybe save sharded checkpoint.
+                if save_checkpoints and (
+                    cancel_initiated
+                    or (
+                        self.global_step % self.cfg.save_interval == 0
+                        and self.cfg.save_num_checkpoints_to_keep != 0
+                    )
                 ):
                     log.info("Saving checkpoint...")
                     checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded)
@@ -890,6 +909,10 @@ def on_trace_ready(p):

                     # Reset speed monitor so that we don't count the time taken to save checkpoints.
                     speed_monitor.reset()
+
+                    # If the run was just canceled this will be the final checkpoint.
+                    if cancel_initiated:
+                        save_checkpoints = False
                 elif (
                     self.cfg.save_interval_ephemeral is not None
                     and self.global_step % self.cfg.save_interval_ephemeral == 0
@@ -903,7 +926,7 @@ def on_trace_ready(p):

                 # Maybe save unsharded checkpoint.
                 if (
-                    not canceled  # we already save a sharded checkpoint when canceled
+                    save_checkpoints
                     and self.cfg.save_interval_unsharded is not None
                     and self.global_step % self.cfg.save_interval_unsharded == 0
                     and self.cfg.save_num_unsharded_checkpoints_to_keep != 0
@@ -916,7 +939,7 @@ def on_trace_ready(p):
                     speed_monitor.reset()

                 # Maybe run evaluations.
-                if not canceled and self.global_step % self.cfg.eval_interval == 0:
+                if not cancel_initiated and self.global_step % self.cfg.eval_interval == 0:
                     eval_metrics = self.eval()

                     # Log metrics to W&B.
@@ -934,7 +957,7 @@ def on_trace_ready(p):
                 if p is not None:
                     p.step()

-                if canceled:
+                if stop_at is not None and self.global_step >= stop_at:
                     break

                 # Python Profiler stuff
@@ -950,7 +973,7 @@ def on_trace_ready(p):
                 log.info("Training loop complete")

         # Save final checkpoint.
-        if not canceled:
+        if save_checkpoints:
             if self.cfg.save_interval_unsharded is not None:
                 log.info("Saving final unsharded model checkpoint...")
                 checkpoint_path, _ = self.save_checkpoint(CheckpointType.unsharded)
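
To summarize the control flow these changes introduce, here is a simplified, trainer-free sketch; the stub functions and constants are placeholders for Trainer.check_if_cancelled(), Trainer.save_checkpoint(), and the corresponding config options, and regular interval-based checkpointing is omitted for brevity:

from typing import Optional, Tuple

CANCELED_CHECK_INTERVAL = 5
EXTRA_STEPS_AFTER_CANCEL = 10


def check_if_cancelled_stub(step: int) -> Tuple[bool, int]:
    # Pretend the time limit is hit at step 20.
    if step >= 20:
        return True, EXTRA_STEPS_AFTER_CANCEL
    return False, 0


def save_checkpoint_stub(step: int) -> None:
    print(f"saved checkpoint at step {step}")


def train_loop(max_steps: int = 100, stop_at: Optional[int] = None) -> int:
    cancel_initiated = False
    save_checkpoints = True
    global_step = 0
    for _ in range(max_steps):
        global_step += 1  # one "training step"

        # Periodically check whether the run should be canceled.
        if not cancel_initiated and global_step % CANCELED_CHECK_INTERVAL == 0:
            cancel_initiated, extra_steps = check_if_cancelled_stub(global_step)
            if cancel_initiated:
                stop_at = (
                    global_step + extra_steps
                    if stop_at is None
                    else min(global_step + extra_steps, stop_at)
                )

        # Save one final checkpoint when cancellation is first detected, then keep
        # training (without saving) until stop_at so metrics overlap after a restart.
        if save_checkpoints and cancel_initiated:
            save_checkpoint_stub(global_step)
            save_checkpoints = False

        if stop_at is not None and global_step >= stop_at:
            break
    return global_step


# The run is canceled at step 20, saves its final checkpoint there, and stops at step 30,
# so a restart from that checkpoint re-covers steps 21-30 and the logged metrics overlap.
assert train_loop() == 30
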
