diff --git a/olmo/config.py b/olmo/config.py
index c8258c54d..c635185f9 100644
--- a/olmo/config.py
+++ b/olmo/config.py
@@ -797,6 +797,13 @@ class TrainConfig(BaseConfig):
     Whether to run the PyTorch profiler on batches 6, 7, and 8.
     """

+    reset_optimizer_state: bool = False
+    """
+    When this is set, we restore the model from a checkpoint (if given), but we leave the optimizer uninitialized.
+    We also set a new learning rate schedule that does a fresh warmup, then intercepts the original learning rate
+    curve (according to the current learning rate schedule settings) and continues from there.
+    """
+
     @property
     def autocast_precision(self) -> torch.dtype:
         if self.precision == "amp_bf16":
diff --git a/olmo/optim.py b/olmo/optim.py
index f34ef11a4..b34c837ae 100644
--- a/olmo/optim.py
+++ b/olmo/optim.py
@@ -22,6 +22,7 @@
     "LinearWithWarmup",
     "InvSqrtWithWarmup",
     "MaxScheduler",
+    "BoltOnWarmupScheduler",
     "build_optimizer",
     "build_scheduler",
 ]
@@ -435,6 +436,22 @@ def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
         )


+@dataclass
+class BoltOnWarmupScheduler(Scheduler):
+    inner: Scheduler
+    warmup_start: int
+    warmup_end: int
+
+    def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
+        if step < self.warmup_start:
+            return 0.0
+        if step < self.warmup_end:
+            lr_at_intercept = self.inner.get_lr(initial_lr, self.warmup_end, max_steps)
+            return lr_at_intercept * (step - self.warmup_start) / (self.warmup_end - self.warmup_start)
+        else:
+            return self.inner.get_lr(initial_lr, step, max_steps)
+
+
 PARAM_GROUP_FIELDS = ("sharded", "max_grad_norm", "max_grad_norm_ratio", "param_names")
diff --git a/olmo/train.py b/olmo/train.py
index bb796529e..db756b8d1 100644
--- a/olmo/train.py
+++ b/olmo/train.py
@@ -339,7 +339,9 @@ def remove_sharded_checkpoint(self, idx: int = 0):
             latest_path.unlink()
         barrier()

-    def restore_sharded_checkpoint(self, load_path: PathOrStr, local_cache: Optional[PathOrStr] = None):
+    def restore_sharded_checkpoint(
+        self, load_path: PathOrStr, local_cache: Optional[PathOrStr] = None, *, load_optimizer_state: bool = True
+    ):
         # Zero-gradients to avoid gathering them.
         self.optim.zero_grad(set_to_none=True)

@@ -365,21 +367,23 @@ def restore_sharded_checkpoint(self, load_path: PathOrStr, local_cache: Optional
         self.fsdp_model.load_state_dict(model_state["model"])

         # Load optim state dict in place.
-        log.info("Loading optimizer state...")
-        optim_state = load_sharded_optimizer_state_dict(
-            model_state_dict=model_state["model"],
-            optimizer_key="optim",
-            storage_reader=RemoteFileSystemReader(
-                f"{load_path}/model_and_optim",
-                local_cache=None if local_cache is None else local_cache / "model_and_optim",
-            ),
-        )
-        if version.parse(torch.__version__) < version.parse("2.1.0"):
-            flattened_osd = FSDP.optim_state_dict_to_load(optim_state["optim"], self.fsdp_model, self.optim)  # type: ignore
-        else:
-            flattened_osd = FSDP.optim_state_dict_to_load(self.fsdp_model, self.optim, optim_state["optim"])  # type: ignore
-        self.optim.load_state_dict(fix_optim_state_dict(self.optim, flattened_osd))
-        del model_state, optim_state, flattened_osd
+        if load_optimizer_state:
+            log.info("Loading optimizer state...")
+            optim_state = load_sharded_optimizer_state_dict(
+                model_state_dict=model_state["model"],
+                optimizer_key="optim",
+                storage_reader=RemoteFileSystemReader(
+                    f"{load_path}/model_and_optim",
+                    local_cache=None if local_cache is None else local_cache / "model_and_optim",
+                ),
+            )
+            if version.parse(torch.__version__) < version.parse("2.1.0"):
+                flattened_osd = FSDP.optim_state_dict_to_load(optim_state["optim"], self.fsdp_model, self.optim)  # type: ignore
+            else:
+                flattened_osd = FSDP.optim_state_dict_to_load(self.fsdp_model, self.optim, optim_state["optim"])  # type: ignore
+            self.optim.load_state_dict(fix_optim_state_dict(self.optim, flattened_osd))
+            del optim_state, flattened_osd
+        del model_state

         # Load trainer state dict.
         log.info("Loading trainer state...")
@@ -408,7 +412,9 @@ def restore_sharded_checkpoint(self, load_path: PathOrStr, local_cache: Optional
         barrier()

-    def restore_legacy_sharded_checkpoint(self, load_path: PathOrStr, local_cache: Optional[PathOrStr] = None):
+    def restore_legacy_sharded_checkpoint(
+        self, load_path: PathOrStr, local_cache: Optional[PathOrStr] = None, *, load_optimizer_state: bool = True
+    ):
         # Zero-gradients to avoid gathering them.
         self.optim.zero_grad(set_to_none=True)

@@ -426,18 +432,20 @@ def restore_legacy_sharded_checkpoint(self, load_path: PathOrStr, local_cache: O
         # Load model and optimizer state.
         log.info("Loading model state...")
         self.fsdp_model.load_state_dict(state_dict["model"])
-        log.info("Loading optimizer state...")
-        if version.parse(torch.__version__) < version.parse("2.1.0"):
-            flattened_osd = FSDP.optim_state_dict_to_load(state_dict["optim"], self.fsdp_model, self.optim)  # type: ignore
-        else:
-            flattened_osd = FSDP.optim_state_dict_to_load(self.fsdp_model, self.optim, state_dict["optim"])  # type: ignore
-        self.optim.load_state_dict(fix_optim_state_dict(self.optim, flattened_osd))
+        if load_optimizer_state:
+            log.info("Loading optimizer state...")
+            if version.parse(torch.__version__) < version.parse("2.1.0"):
+                flattened_osd = FSDP.optim_state_dict_to_load(state_dict["optim"], self.fsdp_model, self.optim)  # type: ignore
+            else:
+                flattened_osd = FSDP.optim_state_dict_to_load(self.fsdp_model, self.optim, state_dict["optim"])  # type: ignore
+            self.optim.load_state_dict(fix_optim_state_dict(self.optim, flattened_osd))
+            del flattened_osd

         # Load trainer state dict.
         log.info("Loading trainer state...")
         self.load_trainer_state_dict(state_dict)
-        del state_dict, flattened_osd
+        del state_dict

         barrier()

@@ -537,7 +545,9 @@ def remove_unsharded_checkpoint(self, idx: int = 0):
             latest_path.unlink()
         barrier()

-    def restore_unsharded_checkpoint(self, load_path: PathOrStr, local_cache: Optional[PathOrStr] = None):
+    def restore_unsharded_checkpoint(
+        self, load_path: PathOrStr, local_cache: Optional[PathOrStr] = None, *, load_optimizer_state: bool = True
+    ):
         # Zero-gradients to avoid gathering them.
         self.optim.zero_grad(set_to_none=True)

@@ -556,18 +566,19 @@ def restore_unsharded_checkpoint(self, load_path: PathOrStr, local_cache: Option
         )

         # Load optimizer state.
-        log.info("Loading optimizer state...")
-        optim_state_dict = torch.load(resource_path(load_path, "optim.pt", local_cache=local_cache))
-        # NOTE: careful, the order of these arguments has changed since the 2.0 release.
-        if version.parse(torch.__version__) < version.parse("2.1.0"):
-            # flattened_osd = FSDP.optim_state_dict_to_load(optim_state["optim"], self.fsdp_model, self.optim)  # type: ignore
-            flattened_osd = FSDP.optim_state_dict_to_load(optim_state_dict, self.fsdp_model, self.optim)  # type: ignore
-        else:
-            # flattened_osd = FSDP.optim_state_dict_to_load(self.fsdp_model, self.optim, optim_state["optim"])  # type: ignore
-            flattened_osd = FSDP.optim_state_dict_to_load(self.fsdp_model, self.optim, optim_state_dict)  # type: ignore
-        del optim_state_dict
-        self.optim.load_state_dict(fix_optim_state_dict(self.optim, flattened_osd))
-        del flattened_osd
+        if load_optimizer_state:
+            log.info("Loading optimizer state...")
+            optim_state_dict = torch.load(resource_path(load_path, "optim.pt", local_cache=local_cache))
+            # NOTE: careful, the order of these arguments has changed since the 2.0 release.
+            if version.parse(torch.__version__) < version.parse("2.1.0"):
+                # flattened_osd = FSDP.optim_state_dict_to_load(optim_state["optim"], self.fsdp_model, self.optim)  # type: ignore
+                flattened_osd = FSDP.optim_state_dict_to_load(optim_state_dict, self.fsdp_model, self.optim)  # type: ignore
+            else:
+                # flattened_osd = FSDP.optim_state_dict_to_load(self.fsdp_model, self.optim, optim_state["optim"])  # type: ignore
+                flattened_osd = FSDP.optim_state_dict_to_load(self.fsdp_model, self.optim, optim_state_dict)  # type: ignore
+            del optim_state_dict
+            self.optim.load_state_dict(fix_optim_state_dict(self.optim, flattened_osd))
+            del flattened_osd

         # Load other state.
         try:
@@ -596,20 +607,28 @@ def restore_checkpoint(
         checkpoint_type: Optional[CheckpointType] = None,
         local_cache: Optional[PathOrStr] = None,
         legacy_mode: bool = False,
+        *,
+        load_optimizer_state: bool = True,
     ):
         if checkpoint_type == CheckpointType.unsharded or (
            checkpoint_type is None and str(load_path).endswith("-unsharded")
         ):
-            self.restore_unsharded_checkpoint(load_path, local_cache=local_cache)
+            self.restore_unsharded_checkpoint(
+                load_path, local_cache=local_cache, load_optimizer_state=load_optimizer_state
+            )
         elif checkpoint_type == CheckpointType.sharded or checkpoint_type is None:
             try:
                 legacy_mode = resource_path(load_path, f"rank{get_global_rank()}.pt").is_file()
             except FileNotFoundError:
                 legacy_mode = False
             if legacy_mode:
-                self.restore_legacy_sharded_checkpoint(load_path, local_cache=local_cache)
+                self.restore_legacy_sharded_checkpoint(
+                    load_path, local_cache=local_cache, load_optimizer_state=load_optimizer_state
+                )
             else:
-                self.restore_sharded_checkpoint(load_path, local_cache=local_cache)
+                self.restore_sharded_checkpoint(
+                    load_path, local_cache=local_cache, load_optimizer_state=load_optimizer_state
+                )
         elif checkpoint_type is not None:
             raise NotImplementedError(checkpoint_type)
diff --git a/scripts/train.py b/scripts/train.py
index 24364ed9c..39e264736 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -20,7 +20,7 @@
 from olmo.eval import build_evaluators
 from olmo.exceptions import OlmoCliError, OlmoConfigurationError
 from olmo.model import Olmo
-from olmo.optim import build_optimizer, build_scheduler
+from olmo.optim import BoltOnWarmupScheduler, build_optimizer, build_scheduler
 from olmo.train import Trainer
 from olmo.util import (
     barrier,
@@ -43,6 +43,13 @@ def main(cfg: TrainConfig) -> None:
         cfg.run_name = os.environ.get("COMPOSER_RUN_NAME", "train-llm")
     log_extra_field("run_name", cfg.run_name)

+    # Sanity check
+    if cfg.reset_optimizer_state and cfg.load_path is None:
+        log.warning(
+            "You want to reset the optimizer state, but we're not loading from a checkpoint. The "
+            "setting has no effect."
+        )
+
     # Initialize process group and set device.
     dist.init_process_group(backend="nccl")
     barrier()
@@ -210,9 +217,15 @@ def dummy_init_fn(module: torch.nn.Module) -> None:

     if cfg.load_path is not None:
         log.info(f"Loading checkpoint from {cfg.load_path}...")
-        trainer.restore_checkpoint(cfg.load_path)
+        trainer.restore_checkpoint(cfg.load_path, load_optimizer_state=not cfg.reset_optimizer_state)
         log.info("Checkpoint successfully loaded")

+        # If we reset the optimizer state, bolt a new warmup onto the existing scheduler:
+        if cfg.reset_optimizer_state:
+            trainer.scheduler = BoltOnWarmupScheduler(
+                trainer.scheduler, trainer.global_step, trainer.global_step + cfg.scheduler.t_warmup
+            )
+
     if cfg.force_save_unsharded:
         log.info("Saving unsharded checkpoint...")
         checkpoint_path, _ = trainer.save_unsharded_checkpoint()
diff --git a/tests/optim_test.py b/tests/optim_test.py
index 73945e28d..38928dbe8 100644
--- a/tests/optim_test.py
+++ b/tests/optim_test.py
@@ -1,4 +1,6 @@
-from olmo.optim import LinearWithWarmup
+import pytest
+
+from olmo.optim import BoltOnWarmupScheduler, LinearWithWarmup


 def test_linear_with_warmup_scheduler():
@@ -9,3 +11,17 @@ def test_linear_with_warmup_scheduler():
     assert scheduler.get_lr(initial_lr, 2000, max_steps) == 1.0
     assert scheduler.get_lr(initial_lr, 10_000, max_steps) == 0.1
     assert scheduler.get_lr(initial_lr, 3_000, max_steps) > scheduler.get_lr(initial_lr, 5_000, max_steps)
+
+
+def test_bolt_on_warmup_scheduler():
+    initial_lr = 1.0
+    max_steps = 11_000
+    alpha_f = 0.1
+    scheduler = LinearWithWarmup(warmup_steps=1000, alpha_f=alpha_f)
+    scheduler2 = BoltOnWarmupScheduler(scheduler, 5000, 6000)
+    assert scheduler.get_lr(initial_lr, 100, max_steps) > 0.0
+    assert scheduler2.get_lr(initial_lr, 100, max_steps) == 0.0
+    assert scheduler2.get_lr(initial_lr, 5000, max_steps) == 0.0
+    assert scheduler2.get_lr(initial_lr, 5500, max_steps) == pytest.approx(0.25 * (1 + alpha_f))
+    assert scheduler2.get_lr(initial_lr, 6000, max_steps) == pytest.approx(0.5 * (1 + alpha_f))
+    assert scheduler2.get_lr(initial_lr, 7000, max_steps) == scheduler.get_lr(initial_lr, 7000, max_steps)
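
Reviewer note (not part of the patch): BoltOnWarmupScheduler grafts a fresh linear warmup onto an existing
schedule and then hands control back to it. Below is a minimal sketch of the resulting learning-rate curve,
using the same numbers as the new unit test; the variable names (inner, spliced) are illustrative, and only the
constructor arguments and get_lr signature shown in this diff are assumed:

    # Splice a 1000-step warmup onto a LinearWithWarmup schedule, as if resuming at step 5000.
    from olmo.optim import BoltOnWarmupScheduler, LinearWithWarmup

    initial_lr, max_steps = 1.0, 11_000
    inner = LinearWithWarmup(warmup_steps=1000, alpha_f=0.1)
    spliced = BoltOnWarmupScheduler(inner, 5000, 6000)  # inner scheduler, warmup_start, warmup_end

    for step in (5000, 5500, 6000, 7000):
        print(step, spliced.get_lr(initial_lr, step, max_steps))
    # 5000 -> 0.0    (the bolted-on warmup starts from zero)
    # 5500 -> 0.275  (halfway toward the inner schedule's value of 0.55 at step 6000)
    # 6000 -> 0.55   (intercepts the inner schedule)
    # 7000 -> whatever the inner schedule returns; the original curve continues unchanged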
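
A further note on the resume path (not part of the patch): with reset_optimizer_state set,
restore_checkpoint(..., load_optimizer_state=False) still restores the model weights and the trainer state,
including global_step, so training resumes at the checkpoint's step count and only the optimizer state starts
fresh. Warming up to the inner schedule's value at warmup_end, rather than to the peak learning rate, means the
learning rate rejoins the original curve exactly where it would have been, which is what the config docstring
means by intercepting the original schedule.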