Merge pull request #314 from allenai/ResetOptimizerState

Reset optimizer state

dirkgr committed Oct 5, 2023
2 parents e8bd122 + 434dc94 commit 0b5f68d
Showing 5 changed files with 115 additions and 43 deletions.
7 changes: 7 additions & 0 deletions olmo/config.py
@@ -797,6 +797,13 @@ class TrainConfig(BaseConfig):
Whether to run the PyTorch profiler on batches 6, 7, and 8.
"""

reset_optimizer_state: bool = False
"""
When this is set, we restore the model from a checkpoint (if given), but we leave the optimizer uninitialized.
We also set a new learning rate schedule that does a new warmup, such that it intercepts the original learning
curve (according to the current learning rate schedule settings), and continues from there.
"""

@property
def autocast_precision(self) -> torch.dtype:
if self.precision == "amp_bf16":
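For orientation, the new flag is consumed in scripts/train.py later in this diff; condensed (a sketch, not the verbatim code), the resume path becomes:

trainer.restore_checkpoint(cfg.load_path, load_optimizer_state=not cfg.reset_optimizer_state)
if cfg.reset_optimizer_state:
    # Re-warm the learning rate from zero over t_warmup steps, then rejoin the original schedule.
    trainer.scheduler = BoltOnWarmupScheduler(
        trainer.scheduler, trainer.global_step, trainer.global_step + cfg.scheduler.t_warmup
    )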
17 changes: 17 additions & 0 deletions olmo/optim.py
@@ -22,6 +22,7 @@
"LinearWithWarmup",
"InvSqrtWithWarmup",
"MaxScheduler",
"BoltOnWarmupScheduler",
"build_optimizer",
"build_scheduler",
]
@@ -435,6 +436,22 @@ def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
)


@dataclass
class BoltOnWarmupScheduler(Scheduler):
inner: Scheduler
warmup_start: int
warmup_end: int

def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
if step < self.warmup_start:
return 0.0
if step < self.warmup_end:
lr_at_intercept = self.inner.get_lr(initial_lr, self.warmup_end, max_steps)
return lr_at_intercept * (step - self.warmup_start) / (self.warmup_end - self.warmup_start)
else:
return self.inner.get_lr(initial_lr, step, max_steps)


PARAM_GROUP_FIELDS = ("sharded", "max_grad_norm", "max_grad_norm_ratio", "param_names")


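A minimal usage sketch of the new scheduler (not part of this commit; it mirrors the new test in tests/optim_test.py below and is runnable on its own):

from olmo.optim import BoltOnWarmupScheduler, LinearWithWarmup

# Original schedule: 1000-step warmup, then linear decay toward alpha_f * initial_lr.
inner = LinearWithWarmup(warmup_steps=1000, alpha_f=0.1)

# Resume at step 5000 with a fresh optimizer: ramp the LR from zero so that it meets the
# original curve at step 6000, then delegate to the inner schedule unchanged.
scheduler = BoltOnWarmupScheduler(inner, warmup_start=5000, warmup_end=6000)

initial_lr, max_steps = 1.0, 11_000
print(scheduler.get_lr(initial_lr, 4_999, max_steps))  # 0.0 -- before the bolted-on warmup
print(scheduler.get_lr(initial_lr, 5_500, max_steps))  # halfway up the ramp to the intercept value
print(scheduler.get_lr(initial_lr, 7_000, max_steps))  # identical to inner.get_lr(initial_lr, 7_000, max_steps)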
99 changes: 59 additions & 40 deletions olmo/train.py
@@ -339,7 +339,9 @@ def remove_sharded_checkpoint(self, idx: int = 0):
latest_path.unlink()
barrier()

def restore_sharded_checkpoint(self, load_path: PathOrStr, local_cache: Optional[PathOrStr] = None):
def restore_sharded_checkpoint(
self, load_path: PathOrStr, local_cache: Optional[PathOrStr] = None, *, load_optimizer_state: bool = True
):
# Zero-gradients to avoid gathering them.
self.optim.zero_grad(set_to_none=True)

@@ -365,21 +367,23 @@ def restore_sharded_checkpoint(self, load_path: PathOrStr, local_cache: Optional
self.fsdp_model.load_state_dict(model_state["model"])

# Load optim state dict in place.
log.info("Loading optimizer state...")
optim_state = load_sharded_optimizer_state_dict(
model_state_dict=model_state["model"],
optimizer_key="optim",
storage_reader=RemoteFileSystemReader(
f"{load_path}/model_and_optim",
local_cache=None if local_cache is None else local_cache / "model_and_optim",
),
)
if version.parse(torch.__version__) < version.parse("2.1.0"):
flattened_osd = FSDP.optim_state_dict_to_load(optim_state["optim"], self.fsdp_model, self.optim) # type: ignore
else:
flattened_osd = FSDP.optim_state_dict_to_load(self.fsdp_model, self.optim, optim_state["optim"]) # type: ignore
self.optim.load_state_dict(fix_optim_state_dict(self.optim, flattened_osd))
del model_state, optim_state, flattened_osd
if load_optimizer_state:
log.info("Loading optimizer state...")
optim_state = load_sharded_optimizer_state_dict(
model_state_dict=model_state["model"],
optimizer_key="optim",
storage_reader=RemoteFileSystemReader(
f"{load_path}/model_and_optim",
local_cache=None if local_cache is None else local_cache / "model_and_optim",
),
)
if version.parse(torch.__version__) < version.parse("2.1.0"):
flattened_osd = FSDP.optim_state_dict_to_load(optim_state["optim"], self.fsdp_model, self.optim) # type: ignore
else:
flattened_osd = FSDP.optim_state_dict_to_load(self.fsdp_model, self.optim, optim_state["optim"]) # type: ignore
self.optim.load_state_dict(fix_optim_state_dict(self.optim, flattened_osd))
del optim_state, flattened_osd
del model_state

# Load trainer state dict.
log.info("Loading trainer state...")
@@ -408,7 +412,9 @@ def restore_sharded_checkpoint(self, load_path: PathOrStr, local_cache: Optional

barrier()

def restore_legacy_sharded_checkpoint(self, load_path: PathOrStr, local_cache: Optional[PathOrStr] = None):
def restore_legacy_sharded_checkpoint(
self, load_path: PathOrStr, local_cache: Optional[PathOrStr] = None, *, load_optimizer_state: bool = True
):
# Zero-gradients to avoid gathering them.
self.optim.zero_grad(set_to_none=True)

@@ -426,18 +432,20 @@ def restore_legacy_sharded_checkpoint(self, load_path: PathOrStr, local_cache: O
# Load model and optimizer state.
log.info("Loading model state...")
self.fsdp_model.load_state_dict(state_dict["model"])
log.info("Loading optimizer state...")
if version.parse(torch.__version__) < version.parse("2.1.0"):
flattened_osd = FSDP.optim_state_dict_to_load(state_dict["optim"], self.fsdp_model, self.optim) # type: ignore
else:
flattened_osd = FSDP.optim_state_dict_to_load(self.fsdp_model, self.optim, state_dict["optim"]) # type: ignore
self.optim.load_state_dict(fix_optim_state_dict(self.optim, flattened_osd))
if load_optimizer_state:
log.info("Loading optimizer state...")
if version.parse(torch.__version__) < version.parse("2.1.0"):
flattened_osd = FSDP.optim_state_dict_to_load(state_dict["optim"], self.fsdp_model, self.optim) # type: ignore
else:
flattened_osd = FSDP.optim_state_dict_to_load(self.fsdp_model, self.optim, state_dict["optim"]) # type: ignore
self.optim.load_state_dict(fix_optim_state_dict(self.optim, flattened_osd))
del flattened_osd

# Load trainer state dict.
log.info("Loading trainer state...")
self.load_trainer_state_dict(state_dict)

del state_dict, flattened_osd
del state_dict

barrier()

@@ -537,7 +545,9 @@ def remove_unsharded_checkpoint(self, idx: int = 0):
latest_path.unlink()
barrier()

def restore_unsharded_checkpoint(self, load_path: PathOrStr, local_cache: Optional[PathOrStr] = None):
def restore_unsharded_checkpoint(
self, load_path: PathOrStr, local_cache: Optional[PathOrStr] = None, *, load_optimizer_state: bool = True
):
# Zero-gradients to avoid gathering them.
self.optim.zero_grad(set_to_none=True)

@@ -556,18 +566,19 @@ def restore_unsharded_checkpoint(self, load_path: PathOrStr, local_cache: Option
)

# Load optimizer state.
log.info("Loading optimizer state...")
optim_state_dict = torch.load(resource_path(load_path, "optim.pt", local_cache=local_cache))
# NOTE: careful, the order of these arguments has changed since the 2.0 release.
if version.parse(torch.__version__) < version.parse("2.1.0"):
# flattened_osd = FSDP.optim_state_dict_to_load(optim_state["optim"], self.fsdp_model, self.optim) # type: ignore
flattened_osd = FSDP.optim_state_dict_to_load(optim_state_dict, self.fsdp_model, self.optim) # type: ignore
else:
# flattened_osd = FSDP.optim_state_dict_to_load(self.fsdp_model, self.optim, optim_state["optim"]) # type: ignore
flattened_osd = FSDP.optim_state_dict_to_load(self.fsdp_model, self.optim, optim_state_dict) # type: ignore
del optim_state_dict
self.optim.load_state_dict(fix_optim_state_dict(self.optim, flattened_osd))
del flattened_osd
if load_optimizer_state:
log.info("Loading optimizer state...")
optim_state_dict = torch.load(resource_path(load_path, "optim.pt", local_cache=local_cache))
# NOTE: careful, the order of these arguments has changed since the 2.0 release.
if version.parse(torch.__version__) < version.parse("2.1.0"):
# flattened_osd = FSDP.optim_state_dict_to_load(optim_state["optim"], self.fsdp_model, self.optim) # type: ignore
flattened_osd = FSDP.optim_state_dict_to_load(optim_state_dict, self.fsdp_model, self.optim) # type: ignore
else:
# flattened_osd = FSDP.optim_state_dict_to_load(self.fsdp_model, self.optim, optim_state["optim"]) # type: ignore
flattened_osd = FSDP.optim_state_dict_to_load(self.fsdp_model, self.optim, optim_state_dict) # type: ignore
del optim_state_dict
self.optim.load_state_dict(fix_optim_state_dict(self.optim, flattened_osd))
del flattened_osd

# Load other state.
try:
@@ -596,20 +607,28 @@ def restore_checkpoint(
checkpoint_type: Optional[CheckpointType] = None,
local_cache: Optional[PathOrStr] = None,
legacy_mode: bool = False,
*,
load_optimizer_state: bool = True,
):
if checkpoint_type == CheckpointType.unsharded or (
checkpoint_type is None and str(load_path).endswith("-unsharded")
):
self.restore_unsharded_checkpoint(load_path, local_cache=local_cache)
self.restore_unsharded_checkpoint(
load_path, local_cache=local_cache, load_optimizer_state=load_optimizer_state
)
elif checkpoint_type == CheckpointType.sharded or checkpoint_type is None:
try:
legacy_mode = resource_path(load_path, f"rank{get_global_rank()}.pt").is_file()
except FileNotFoundError:
legacy_mode = False
if legacy_mode:
self.restore_legacy_sharded_checkpoint(load_path, local_cache=local_cache)
self.restore_legacy_sharded_checkpoint(
load_path, local_cache=local_cache, load_optimizer_state=load_optimizer_state
)
else:
self.restore_sharded_checkpoint(load_path, local_cache=local_cache)
self.restore_sharded_checkpoint(
load_path, local_cache=local_cache, load_optimizer_state=load_optimizer_state
)
elif checkpoint_type is not None:
raise NotImplementedError(checkpoint_type)

17 changes: 15 additions & 2 deletions scripts/train.py
@@ -20,7 +20,7 @@
from olmo.eval import build_evaluators
from olmo.exceptions import OlmoCliError, OlmoConfigurationError
from olmo.model import Olmo
from olmo.optim import build_optimizer, build_scheduler
from olmo.optim import BoltOnWarmupScheduler, build_optimizer, build_scheduler
from olmo.train import Trainer
from olmo.util import (
barrier,
@@ -43,6 +43,13 @@ def main(cfg: TrainConfig) -> None:
cfg.run_name = os.environ.get("COMPOSER_RUN_NAME", "train-llm")
log_extra_field("run_name", cfg.run_name)

# Sanity check
if cfg.reset_optimizer_state and cfg.load_path is None:
log.warning(
"You want to reset the optimizer state, but we're not loading from the checkpoint. The"
"setting has no effect."
)

# Initialize process group and set device.
dist.init_process_group(backend="nccl")
barrier()
@@ -210,9 +217,15 @@ def dummy_init_fn(module: torch.nn.Module) -> None:

if cfg.load_path is not None:
log.info(f"Loading checkpoint from {cfg.load_path}...")
trainer.restore_checkpoint(cfg.load_path)
trainer.restore_checkpoint(cfg.load_path, load_optimizer_state=not cfg.reset_optimizer_state)
log.info("Checkpoint successfully loaded")

# If we reset the optimizer state, bolt a new warmup onto the existing schedule:
if cfg.reset_optimizer_state:
trainer.scheduler = BoltOnWarmupScheduler(
trainer.scheduler, trainer.global_step, trainer.global_step + cfg.scheduler.t_warmup
)

if cfg.force_save_unsharded:
log.info("Saving unsharded checkpoint...")
checkpoint_path, _ = trainer.save_unsharded_checkpoint()
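Note that the new warmup window is [trainer.global_step, trainer.global_step + cfg.scheduler.t_warmup]: the ramp starts at the step of the restored checkpoint and, per BoltOnWarmupScheduler.get_lr above, reaches exactly the value of the wrapped schedule at the end of the window, after which it simply delegates to it. The run therefore rejoins the original learning-rate curve once the fresh optimizer has warmed up.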
18 changes: 17 additions & 1 deletion tests/optim_test.py
@@ -1,4 +1,6 @@
from olmo.optim import LinearWithWarmup
import pytest

from olmo.optim import BoltOnWarmupScheduler, LinearWithWarmup


def test_linear_with_warmup_scheduler():
@@ -9,3 +11,17 @@ def test_linear_with_warmup_scheduler():
assert scheduler.get_lr(initial_lr, 2000, max_steps) == 1.0
assert scheduler.get_lr(initial_lr, 10_000, max_steps) == 0.1
assert scheduler.get_lr(initial_lr, 3_000, max_steps) > scheduler.get_lr(initial_lr, 5_000, max_steps)


def test_bolt_on_warmup_scheduler():
initial_lr = 1.0
max_steps = 11_000
alpha_f = 0.1
scheduler = LinearWithWarmup(warmup_steps=1000, alpha_f=alpha_f)
scheduler2 = BoltOnWarmupScheduler(scheduler, 5000, 6000)
assert scheduler.get_lr(initial_lr, 100, max_steps) > 0.0
assert scheduler2.get_lr(initial_lr, 100, max_steps) == 0.0
assert scheduler2.get_lr(initial_lr, 5000, max_steps) == 0.0
assert scheduler2.get_lr(initial_lr, 5500, max_steps) == pytest.approx(0.25 * (1 + alpha_f))
assert scheduler2.get_lr(initial_lr, 6000, max_steps) == pytest.approx(0.5 * (1 + alpha_f))
assert scheduler2.get_lr(initial_lr, 7000, max_steps) == scheduler.get_lr(initial_lr, 7000, max_steps)
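The expected values are consistent with LinearWithWarmup decaying linearly from initial_lr at step 1,000 to alpha_f * initial_lr at step 11,000: at step 6,000 the inner schedule gives 0.55 = 0.5 * (1 + alpha_f), and the bolt-on ramp at step 5,500 is halfway toward that intercept, 0.275 = 0.25 * (1 + alpha_f).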
