Train a few steps after time limit reached #362

Merged: 7 commits, Jan 4, 2024
Changes from all commits
7 changes: 7 additions & 0 deletions olmo/config.py
@@ -941,6 +941,13 @@ class TrainConfig(BaseConfig):
to write out a final checkpoint.
"""

extra_steps_after_cancel: int = 10
"""
Under certain conditions when a run is canceled we train for a few extra steps after saving
the final checkpoint so that when the run is restarted from the latest checkpoint we have some
overlap in metrics.
"""

early_stopping_factor: Optional[float] = None

save_data_indices: bool = True
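For intuition, here is a small standalone sketch of the overlap this new option produces; the step numbers are assumed for illustration and are not from the PR. The canceled run writes its final checkpoint, keeps training for extra_steps_after_cancel more steps, and exits; the restarted run resumes from that checkpoint and re-trains the same steps, so both runs log metrics over a shared window.

# Sketch with assumed step numbers: the metric overlap bought by
# extra_steps_after_cancel after a time-limit cancellation.
checkpoint_step = 12_000          # final checkpoint saved at this step
extra_steps_after_cancel = 10     # default of the new config field

old_run_last_step = checkpoint_step + extra_steps_after_cancel   # canceled run trains to 12010
new_run_first_step = checkpoint_step + 1                         # restarted run resumes at 12001

overlap = range(new_run_first_step, old_run_last_step + 1)
print(f"metrics overlap on steps {overlap.start}..{overlap.stop - 1}")  # 12001..12010
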
19 changes: 13 additions & 6 deletions olmo/torch_util.py
@@ -116,10 +116,17 @@ def peak_gpu_memory(reset: bool = False) -> Optional[float]:
return peak_mb


def synchronize_flag(flag: bool, device: torch.device) -> bool:
if is_distributed():
flag_tensor = torch.tensor(flag, device=device)
dist.broadcast(flag_tensor, 0)
return flag_tensor.item() # type: ignore
V = TypeVar("V", bool, int, float)


def synchronize_value(value: V, device: torch.device) -> V:
if dist.is_available() and dist.is_initialized():
value_tensor = torch.tensor(value, device=device)
dist.broadcast(value_tensor, 0)
return value_tensor.item() # type: ignore
else:
return flag
return value


def synchronize_flag(flag: bool, device: torch.device) -> bool:
return synchronize_value(flag, device)
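
For reference, here are the two helpers from this diff collected into a standalone, runnable sketch with a single-process usage example. Outside a distributed run (torch.distributed not initialized) both helpers return their input unchanged; in a distributed run, rank 0's value is broadcast to every rank.

from typing import TypeVar

import torch
import torch.distributed as dist

V = TypeVar("V", bool, int, float)


def synchronize_value(value: V, device: torch.device) -> V:
    # Broadcast rank 0's value to all ranks; no-op outside a distributed run.
    if dist.is_available() and dist.is_initialized():
        value_tensor = torch.tensor(value, device=device)
        dist.broadcast(value_tensor, 0)
        return value_tensor.item()  # type: ignore
    else:
        return value


def synchronize_flag(flag: bool, device: torch.device) -> bool:
    # The old bool-only helper is now a thin wrapper around synchronize_value.
    return synchronize_value(flag, device)


if __name__ == "__main__":
    device = torch.device("cpu")
    print(synchronize_value(10, device))   # -> 10 (single process, returned as-is)
    print(synchronize_flag(True, device))  # -> True
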
57 changes: 40 additions & 17 deletions olmo/train.py
Expand Up @@ -43,6 +43,7 @@
move_to_device,
peak_gpu_memory,
synchronize_flag,
synchronize_value,
)
from .util import upload

@@ -711,14 +712,16 @@ def eval(self) -> Dict[str, Any]:

return eval_metrics

def check_if_cancelled(self) -> bool:
def check_if_cancelled(self) -> Tuple[bool, int]:
should_cancel = False
cancel_reason: Optional[str] = None
extra_steps = 0
if get_global_rank() == 0:
if self.cfg.time_limit is not None and time.time() - self._start_time >= self.cfg.time_limit:
# First check if we've reached the training time limit.
should_cancel = True
cancel_reason = "time limit reached"
extra_steps = self.cfg.extra_steps_after_cancel
elif (
self.cfg.early_stopping_factor is not None
and self.global_step > self.cfg.scheduler.t_warmup
@@ -739,14 +742,20 @@ def check_if_cancelled(self) -> bool:
if tag.lower() in {"cancel", "canceled", "cancelled"}:
should_cancel = True
cancel_reason = "Weights & Biases tag"
extra_steps = self.cfg.extra_steps_after_cancel
break
except RequestException:
pass

run_canceled = synchronize_flag(should_cancel, self.device)
if run_canceled and cancel_reason is not None:
log.warning(f"Run canceled due to {cancel_reason}")
return run_canceled
extra_steps = synchronize_value(extra_steps, self.device)
if extra_steps > 0:
log.warning(f"Run canceled due to {cancel_reason}, stopping in {extra_steps} more steps...")
else:
log.warning(f"Run canceled due to {cancel_reason}")

return run_canceled, extra_steps

def fit(self):
self._start_time = time.time()
@@ -818,7 +827,9 @@ def on_trace_ready(p):

# Train.
first_batch: bool = True
canceled: bool = False
cancel_initiated: bool = False
stop_at: Optional[int] = self.cfg.stop_at
save_checkpoints: bool = True

with torch_profiler as p:
for batch in self.train_loader:
@@ -870,15 +881,23 @@ def on_trace_ready(p):
):
wandb.log(metrics, step=self.global_step)

# Check if run should be canceled.
if self.cfg.stop_at is not None and self.global_step >= self.cfg.stop_at:
canceled = True
elif self.global_step % self.cfg.canceled_check_interval == 0:
canceled = self.check_if_cancelled()

# Maybe save sharded or ephemeral sharded checkpoint.
if canceled or (
self.global_step % self.cfg.save_interval == 0 and self.cfg.save_num_checkpoints_to_keep != 0
# Check if/when run should be canceled.
if not cancel_initiated and self.global_step % self.cfg.canceled_check_interval == 0:
cancel_initiated, extra_steps = self.check_if_cancelled()
if cancel_initiated:
stop_at = (
self.global_step + extra_steps
if stop_at is None
else min(self.global_step + extra_steps, stop_at)
)

# Maybe save sharded checkpoint.
if save_checkpoints and (
cancel_initiated
or (
self.global_step % self.cfg.save_interval == 0
and self.cfg.save_num_checkpoints_to_keep != 0
)
):
log.info("Saving checkpoint...")
checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded)
@@ -890,6 +909,10 @@ def on_trace_ready(p):

# Reset speed monitor so that we don't count the time taken to save checkpoints.
speed_monitor.reset()

# If the run was just canceled this will be the final checkpoint.
if cancel_initiated:
save_checkpoints = False
elif (
self.cfg.save_interval_ephemeral is not None
and self.global_step % self.cfg.save_interval_ephemeral == 0
@@ -903,7 +926,7 @@ def on_trace_ready(p):

# Maybe save unsharded checkpoint.
if (
not canceled # we already save a sharded checkpoint when canceled
save_checkpoints
and self.cfg.save_interval_unsharded is not None
and self.global_step % self.cfg.save_interval_unsharded == 0
and self.cfg.save_num_unsharded_checkpoints_to_keep != 0
Expand All @@ -916,7 +939,7 @@ def on_trace_ready(p):
speed_monitor.reset()

# Maybe run evaluations.
if not canceled and self.global_step % self.cfg.eval_interval == 0:
if not cancel_initiated and self.global_step % self.cfg.eval_interval == 0:
Contributor:


To be clear, you don't want eval metrics if they happen in those extra steps?

Member Author:


Right... though it's debatable. I think when we cancel we want to stop ASAP, and the eval loop adds time.

Member:


Yeah, no eval loops. This is a sanity check.

eval_metrics = self.eval()

# Log metrics to W&B.
@@ -934,7 +957,7 @@ def on_trace_ready(p):
if p is not None:
p.step()

if canceled:
if stop_at is not None and self.global_step >= stop_at:
break

# Python Profiler stuff
@@ -950,7 +973,7 @@ def on_trace_ready(p):
log.info("Training loop complete")

# Save final checkpoint.
if not canceled:
if save_checkpoints:
if self.cfg.save_interval_unsharded is not None:
log.info("Saving final unsharded model checkpoint...")
checkpoint_path, _ = self.save_checkpoint(CheckpointType.unsharded)
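
To make the new stop_at bookkeeping concrete, here is a small illustrative helper (resolve_stop_at and the step values below are not part of the PR): once a cancel is initiated, the run stops at whichever comes first, the configured cfg.stop_at or the current global_step plus extra_steps_after_cancel.

# Illustrative helper mirroring the stop_at update in fit(); not part of the PR.
def resolve_stop_at(global_step: int, extra_steps: int, configured_stop_at=None) -> int:
    return (
        global_step + extra_steps
        if configured_stop_at is None
        else min(global_step + extra_steps, configured_stop_at)
    )

print(resolve_stop_at(5_000, 10))          # -> 5010: train 10 extra steps, then break
print(resolve_stop_at(5_000, 10, 5_005))   # -> 5005: an earlier cfg.stop_at still wins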