Merge pull request #411 from allenai/epwalsh/lr-schedule-tokens
allow specifying LR schedule in terms of tokens
epwalsh committed Jan 18, 2024
2 parents 45ed078 + 9477cfa commit dcae8e8
Showing 4 changed files with 89 additions and 24 deletions.
12 changes: 9 additions & 3 deletions olmo/config.py
@@ -480,14 +480,20 @@ class SchedulerType(StrEnum):
     constant = "constant"
 
 
+class SchedulerUnits(StrEnum):
+    steps = "steps"
+    tokens = "tokens"
+
+
 @dataclass
 class SchedulerConfig(BaseConfig):
     name: SchedulerType = SchedulerType.cosine_with_warmup
-    t_warmup: int = 100
-    t_max: Optional[int] = None
+    units: SchedulerUnits = SchedulerUnits.steps
+    t_warmup: Union[int, float] = 100
+    t_max: Optional[Union[int, float]] = None
     alpha_f: float = 0.1
 
-    grad_clip_warmup_steps: Optional[int] = None
+    grad_clip_warmup_steps: Optional[Union[int, float]] = None
     """
     The warmup period for which the max grad norm (or norm ratio) will be set to its
     warmup value of `max_grad_norm * grad_clip_warmup_factor`.
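With the new `units` field, a schedule can now be expressed directly in tokens instead of optimizer steps, which is also why `t_warmup` and `t_max` accept floats (token counts are usually written in scientific notation). A minimal sketch of a token-based config, with illustrative values only:

    from olmo.config import SchedulerConfig, SchedulerType, SchedulerUnits

    # Hypothetical token-based schedule: warm up over ~2B tokens, then cosine-decay.
    # 2e9 parses as a float, hence the Union[int, float] annotations above.
    sched_cfg = SchedulerConfig(
        name=SchedulerType.cosine_with_warmup,
        units=SchedulerUnits.tokens,
        t_warmup=2e9,
        t_max=None,  # None -> decay over the trainer's scheduler_max (max_tokens)
        alpha_f=0.1,
    )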
30 changes: 20 additions & 10 deletions olmo/optim.py
@@ -720,36 +720,46 @@ def build_scheduler(cfg: TrainConfig, sched_cfg: Optional[SchedulerConfig] = Non
     sched_cfg = sched_cfg if sched_cfg is not None else cfg.scheduler
     if sched_cfg.name == SchedulerType.cosine_with_warmup:
         return CosWithWarmup(
-            grad_clip_warmup_steps=sched_cfg.grad_clip_warmup_steps,
+            grad_clip_warmup_steps=None
+            if sched_cfg.grad_clip_warmup_steps is None
+            else int(sched_cfg.grad_clip_warmup_steps),
             grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
-            warmup_steps=sched_cfg.t_warmup,
+            warmup_steps=int(sched_cfg.t_warmup),
             alpha_f=sched_cfg.alpha_f,
-            t_max=sched_cfg.t_max,
+            t_max=None if sched_cfg.t_max is None else int(sched_cfg.t_max),
         )
     elif sched_cfg.name == SchedulerType.linear_with_warmup:
         return LinearWithWarmup(
-            grad_clip_warmup_steps=sched_cfg.grad_clip_warmup_steps,
+            grad_clip_warmup_steps=None
+            if sched_cfg.grad_clip_warmup_steps is None
+            else int(sched_cfg.grad_clip_warmup_steps),
             grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
-            warmup_steps=sched_cfg.t_warmup,
+            warmup_steps=int(sched_cfg.t_warmup),
             alpha_f=sched_cfg.alpha_f,
-            t_max=sched_cfg.t_max,
+            t_max=None if sched_cfg.t_max is None else int(sched_cfg.t_max),
         )
     elif sched_cfg.name == SchedulerType.inverse_sqrt_with_warmup:
         return InvSqrtWithWarmup(
-            grad_clip_warmup_steps=sched_cfg.grad_clip_warmup_steps,
+            grad_clip_warmup_steps=None
+            if sched_cfg.grad_clip_warmup_steps is None
+            else int(sched_cfg.grad_clip_warmup_steps),
             grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
-            warmup_steps=sched_cfg.t_warmup,
+            warmup_steps=int(sched_cfg.t_warmup),
         )
     elif sched_cfg.name == SchedulerType.max_scheduler:
         return MaxScheduler(
-            grad_clip_warmup_steps=sched_cfg.grad_clip_warmup_steps,
+            grad_clip_warmup_steps=None
+            if sched_cfg.grad_clip_warmup_steps is None
+            else int(sched_cfg.grad_clip_warmup_steps),
             grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
             sched1=build_scheduler(cfg, replace(sched_cfg, name=SchedulerType.cosine_with_warmup)),
             sched2=build_scheduler(cfg, replace(sched_cfg, name=SchedulerType.inverse_sqrt_with_warmup)),
         )
     elif sched_cfg.name == SchedulerType.constant:
         return ConstantScheduler(
-            grad_clip_warmup_steps=sched_cfg.grad_clip_warmup_steps,
+            grad_clip_warmup_steps=None
+            if sched_cfg.grad_clip_warmup_steps is None
+            else int(sched_cfg.grad_clip_warmup_steps),
             grad_clip_warmup_factor=sched_cfg.grad_clip_warmup_factor,
         )
     else:
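The only change in build_scheduler is coercion: the config fields may now hold floats, so every value is cast back to an int before it reaches a scheduler, with None passed through untouched. The repeated inline conditional is equivalent to a small helper like the one sketched here (not part of the PR, shown only to make the pattern explicit):

    from typing import Optional, Union

    def _to_int_or_none(value: Optional[Union[int, float]]) -> Optional[int]:
        # None-preserving int coercion, mirroring the inline expressions above.
        return None if value is None else int(value)

    # e.g. grad_clip_warmup_steps=_to_int_or_none(sched_cfg.grad_clip_warmup_steps)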
69 changes: 59 additions & 10 deletions olmo/train.py
@@ -26,6 +26,7 @@
 from .checkpoint import Checkpointer, FullCheckpointer, build_sharded_checkpointer
 from .config import (
     CheckpointType,
+    SchedulerUnits,
     ShardedCheckpointerType,
     SpeedMonitorConfig,
     TrainConfig,
@@ -122,6 +122,14 @@ def dataset(self) -> IterableDataset:
         assert isinstance(self.train_loader.dataset, IterableDataset)
         return self.train_loader.dataset
 
+    @property
+    def tokens_per_batch(self) -> int:
+        return self.cfg.global_train_batch_size * self.cfg.model.max_sequence_length
+
+    @property
+    def batches_per_epoch(self) -> int:
+        return self.dataset.total_size // self.cfg.global_train_batch_size
+
     @property
     def max_epochs(self) -> int:
         if isinstance(self.cfg.max_duration, str) and self.cfg.max_duration.endswith("ep"):
@@ -137,21 +146,59 @@ def max_steps(self) -> int:
             if self.cfg.max_duration.endswith("T"):
                 # convert to float *first* to handle scientific notation
                 max_tokens = int(float(self.cfg.max_duration[:-1].strip()))
-                tokens_remaining = max_tokens - self.global_train_tokens_seen
-                tokens_per_batch = self.cfg.global_train_batch_size * self.cfg.model.max_sequence_length
-                steps_remaining = tokens_remaining // tokens_per_batch
+                tokens_remaining = max(max_tokens - self.global_train_tokens_seen, 0)
+                steps_remaining = tokens_remaining // self.tokens_per_batch
                 return self.global_step + steps_remaining
             elif self.cfg.max_duration.endswith("ep"):
                 max_epochs = int(self.cfg.max_duration[:-2].strip())
-                examples_per_epoch = self.dataset.total_size
-                steps_per_epoch = examples_per_epoch // self.cfg.global_train_batch_size
-                return max_epochs * steps_per_epoch
+                return max_epochs * self.batches_per_epoch
             else:
                 # convert to float *first* to handle scientific notation
                 return int(float(self.cfg.max_duration))
         else:
             raise TypeError(f"expected int or str for 'max_duration', found {type(self.cfg.max_duration)}")
 
+    @property
+    def max_tokens(self) -> int:
+        if isinstance(self.cfg.max_duration, int):
+            return (
+                self.global_train_tokens_seen
+                + max(self.cfg.max_duration - self.global_step, 0) * self.tokens_per_batch
+            )
+        elif isinstance(self.cfg.max_duration, str):
+            if self.cfg.max_duration.endswith("T"):
+                # convert to float *first* to handle scientific notation
+                return int(float(self.cfg.max_duration[:-1].strip()))
+            elif self.cfg.max_duration.endswith("ep"):
+                max_epochs = int(self.cfg.max_duration[:-2].strip())
+                return max_epochs * self.batches_per_epoch * self.tokens_per_batch
+            else:
+                # convert to float *first* to handle scientific notation
+                return (
+                    self.global_train_tokens_seen
+                    + max(int(float(self.cfg.max_duration)) - self.global_step, 0) * self.tokens_per_batch
+                )
+        else:
+            raise TypeError(f"expected int or str for 'max_duration', found {type(self.cfg.max_duration)}")
+
+    @property
+    def scheduler_current(self) -> int:
+        if self.cfg.scheduler.units == SchedulerUnits.steps:
+            return self.global_step
+        elif self.cfg.scheduler.units == SchedulerUnits.tokens:
+            return self.global_train_tokens_seen
+        else:
+            raise NotImplementedError(self.cfg.scheduler.units)
+
+    @property
+    def scheduler_max(self) -> int:
+        if self.cfg.scheduler.units == SchedulerUnits.steps:
+            return self.max_steps
+        elif self.cfg.scheduler.units == SchedulerUnits.tokens:
+            return self.max_tokens
+        else:
+            raise NotImplementedError(self.cfg.scheduler.units)
+
     def trainer_state_dict(self) -> Dict[str, Any]:
         return {
             "epoch": self.epoch,
@@ -233,7 +280,7 @@ def load_trainer_state_dict(self, state_dict: Dict[str, Any]) -> None:
             # Reset learning rate and weight decay to the values from the config, not the checkpoint.
             log.info("Resetting learning rate...")
             new_learning_rate = self.scheduler.get_lr(
-                self.cfg.optimizer.learning_rate, self.global_step, self.max_steps
+                self.cfg.optimizer.learning_rate, self.scheduler_current, self.scheduler_max
             )
             for group in self.optim.param_groups:
                 group["lr"] = new_learning_rate
@@ -572,12 +619,14 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->
             # TODO (epwalsh): if we want to enable different LRs or gradient clipping settings per group
             # we should pass `group["initial_lr"]` or `group["initial_max_grad_norm"]` here instead of
             # the corresponding values from `self.cfg`.
-            group["lr"] = self.scheduler.get_lr(self.cfg.optimizer.learning_rate, self.global_step, self.max_steps)
+            group["lr"] = self.scheduler.get_lr(
+                self.cfg.optimizer.learning_rate, self.scheduler_current, self.scheduler_max
+            )
             group["max_grad_norm"] = self.scheduler.get_max_grad_norm(
-                self.cfg.max_grad_norm, self.global_step, self.max_steps
+                self.cfg.max_grad_norm, self.scheduler_current, self.scheduler_max
             )
             group["max_grad_norm_ratio"] = self.scheduler.get_max_grad_norm(
-                self.cfg.max_grad_norm_ratio, self.global_step, self.max_steps
+                self.cfg.max_grad_norm_ratio, self.scheduler_current, self.scheduler_max
             )
 
         # Optimizer step.
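Net effect on the learning-rate query: get_lr is still called as get_lr(initial_lr, current, max), but with scheduler.units set to tokens the trainer now passes (global_train_tokens_seen, max_tokens) instead of (global_step, max_steps). A rough sketch of what that looks like, assuming the CosWithWarmup constructor arguments and get_lr signature shown in the diff above; the concrete numbers are illustrative:

    from olmo.optim import CosWithWarmup

    # Warmup expressed in tokens; build_scheduler has already coerced it to int.
    scheduler = CosWithWarmup(
        grad_clip_warmup_steps=None,
        grad_clip_warmup_factor=None,
        warmup_steps=int(2e9),  # "steps" here are really tokens
        alpha_f=0.1,
        t_max=None,
    )

    initial_lr = 4e-4
    tokens_seen = 1_000_000_000       # scheduler_current when units == tokens
    max_tokens = 3_000_000_000_000    # scheduler_max when units == tokens
    lr = scheduler.get_lr(initial_lr, tokens_seen, max_tokens)  # LR partway through the token-based warmup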
2 changes: 1 addition & 1 deletion scripts/train.py
@@ -212,7 +212,7 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
         trainer.scheduler = BoltOnWarmupScheduler.wrap(
             trainer.scheduler,
             trainer.global_step,
-            trainer.global_step + cfg.scheduler.t_warmup,
+            int(trainer.global_step + cfg.scheduler.t_warmup),
         )
 
     if cfg.force_save_unsharded: