Merge pull request #385 from allenai/epwalsh/max-duration-tokens
Allow specifying max_duration in terms of tokens
epwalsh committed Nov 27, 2023
2 parents e16e606 + c46a5d8 commit ff883e5
Showing 3 changed files with 29 additions and 10 deletions.
8 changes: 6 additions & 2 deletions olmo/config.py
@@ -806,9 +806,13 @@ class TrainConfig(BaseConfig):
Deprecated. Use ``sharded_checkpointer`` instead.
"""

- max_duration: int = 10000
+ max_duration: Union[int, str] = 10000
"""
- Maximum number of batches to train for.
+ How long to train for.
+
+ If specified without a unit (the default), the units are assumed to be steps.
+ You can also specify this in terms of tokens, for example: `max_duration="2e12T"` means train until
+ 2 trillion tokens.
"""

global_train_batch_size: int = 512
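Note (illustration, not part of this PR): to make the new field's semantics concrete, the sketch below mirrors how the trainer interprets a max_duration value (see the max_steps property added in olmo/train.py further down). The helper name interpret_max_duration is hypothetical.

from typing import Tuple, Union

def interpret_max_duration(max_duration: Union[int, str]) -> Tuple[str, int]:
    """Return ("steps", n) or ("tokens", n). Mirrors Trainer.max_steps from this PR."""
    if isinstance(max_duration, int):
        return "steps", max_duration
    elif isinstance(max_duration, str):
        if max_duration.endswith("T"):
            # Convert to float first so scientific notation like "2e12" parses.
            return "tokens", int(float(max_duration[:-1].strip()))
        else:
            return "steps", int(float(max_duration))
    else:
        raise TypeError(f"expected int or str for 'max_duration', found {type(max_duration)}")

print(interpret_max_duration(10000))    # ('steps', 10000)
print(interpret_max_duration("1e5"))    # ('steps', 100000)
print(interpret_max_duration("2e12T"))  # ('tokens', 2000000000000)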
1 change: 0 additions & 1 deletion olmo/data/__init__.py
@@ -96,7 +96,6 @@ def build_train_dataloader(train_config: TrainConfig) -> DataLoader:
seed=train_config.seed + train_config.epoch,
shuffle=True,
drop_last=train_config.data.drop_last,
- max_examples=train_config.global_train_batch_size * train_config.max_duration,
work_dir=work_dir,
),
batch_size=train_config.device_train_batch_size,
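Note (editorial reading, not stated in the PR): the removed max_examples argument above was computed as global_train_batch_size * max_duration, which is only meaningful while max_duration is a step count; with a token string such as "2e12T" the product is no longer an example count at all (in Python, int * str just repeats the string), so the stopping condition presumably moves to the trainer's step budget shown below.

# Hypothetical values, for illustration only.
global_train_batch_size = 512

max_duration = 10_000                             # old-style step count
print(global_train_batch_size * max_duration)     # 5120000 -- a sensible example cap

max_duration = "2e12T"                            # now-allowed token string
capped = global_train_batch_size * max_duration   # int * str repeats the string
print(type(capped), len(capped))                  # <class 'str'> 2560 -- not an example count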
30 changes: 23 additions & 7 deletions olmo/train.py
@@ -115,6 +115,24 @@ class Trainer:
indices_file: Optional[TextIO] = None
_start_time: float = 0.0

+ @property
+ def max_steps(self) -> int:
+     if isinstance(self.cfg.max_duration, int):
+         return self.cfg.max_duration
+     elif isinstance(self.cfg.max_duration, str):
+         if self.cfg.max_duration.endswith("T"):
+             # convert to float *first* to handle scientific notation
+             max_tokens = int(float(self.cfg.max_duration[:-1].strip()))
+             tokens_remaining = max_tokens - self.global_train_tokens_seen
+             tokens_per_batch = self.cfg.global_train_batch_size * self.cfg.model.max_sequence_length
+             steps_remaining = tokens_remaining // tokens_per_batch
+             return self.global_step + steps_remaining
+         else:
+             # convert to float *first* to handle scientific notation
+             return int(float(self.cfg.max_duration))
+     else:
+         raise TypeError(f"expected int or str for 'max_duration', found {type(self.cfg.max_duration)}")

def trainer_state_dict(self) -> Dict[str, Any]:
return {
"epoch": self.epoch,
@@ -188,7 +206,7 @@ def load_trainer_state_dict(self, state_dict: Dict[str, Any]) -> None:
# Reset learning rate and weight decay to the values from the config, not the checkpoint.
log.info("Resetting learning rate...")
new_learning_rate = self.scheduler.get_lr(
- self.cfg.optimizer.learning_rate, self.global_step, self.cfg.max_duration
+ self.cfg.optimizer.learning_rate, self.global_step, self.max_steps
)
for group in self.optim.param_groups:
group["lr"] = new_learning_rate
@@ -495,14 +513,12 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->
# TODO (epwalsh): if we want to enable different LRs or gradient clipping settings per group
# we should pass `group["initial_lr"]` or `group["initial_max_grad_norm"]` here instead of
# the corresponding values from `self.cfg`.
group["lr"] = self.scheduler.get_lr(
self.cfg.optimizer.learning_rate, self.global_step, self.cfg.max_duration
)
group["lr"] = self.scheduler.get_lr(self.cfg.optimizer.learning_rate, self.global_step, self.max_steps)
group["max_grad_norm"] = self.scheduler.get_max_grad_norm(
- self.cfg.max_grad_norm, self.global_step, self.cfg.max_duration
+ self.cfg.max_grad_norm, self.global_step, self.max_steps
)
group["max_grad_norm_ratio"] = self.scheduler.get_max_grad_norm(
- self.cfg.max_grad_norm_ratio, self.global_step, self.cfg.max_duration
+ self.cfg.max_grad_norm_ratio, self.global_step, self.max_steps
)

# Optimizer step.
@@ -818,7 +834,7 @@ def on_trace_ready(p):

# Log metrics to console.
if self.global_step % self.cfg.console_log_interval == 0:
- self.log_metrics_to_console(f"[step={self.global_step}/{self.cfg.max_duration}]", metrics)
+ self.log_metrics_to_console(f"[step={self.global_step}/{self.max_steps}]", metrics)

# Log metrics to W&B.
if (
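Note (worked example, not part of this PR): the scheduler and console-logging calls above now receive max_steps, so learning-rate decay and progress reporting track the token budget rather than the raw config value. Below is the arithmetic from the new max_steps property with hypothetical numbers; global_train_batch_size=512 is the default shown in the config diff, while max_sequence_length=2048 and the progress counters are assumptions.

# Assumed values: only global_train_batch_size's default (512) comes from this diff.
global_train_batch_size = 512
max_sequence_length = 2048
global_step = 10_000                               # steps completed so far
global_train_tokens_seen = global_step * global_train_batch_size * max_sequence_length

max_tokens = int(float("2e12T"[:-1]))              # 2_000_000_000_000
tokens_per_batch = global_train_batch_size * max_sequence_length
tokens_remaining = max_tokens - global_train_tokens_seen
max_steps = global_step + tokens_remaining // tokens_per_batch

print(tokens_per_batch)   # 1048576 tokens consumed per step
print(max_steps)          # 1907348 -- "2e12T" is roughly 1.9M steps at this batch/sequence size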
