Reset optimizer state #314

Merged · 5 commits · Oct 5, 2023
7 changes: 7 additions & 0 deletions olmo/config.py
@@ -797,6 +797,13 @@ class TrainConfig(BaseConfig):
Whether to run the PyTorch profiler on batches 6, 7, and 8.
"""

reset_optimizer_state: bool = False
"""
When this is set, we restore the model from a checkpoint (if given) but leave the optimizer state
uninitialized. We also set a new learning rate schedule that does a fresh warmup, intercepts the original
learning rate curve (as determined by the current schedule settings) at the end of that warmup, and follows
the original schedule from there.
"""

@property
def autocast_precision(self) -> torch.dtype:
if self.precision == "amp_bf16":
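As a usage sketch (not part of this diff), the new field sits on `TrainConfig` like any other and pairs with `load_path`. This assumes the usual `TrainConfig.load(path)` helper; the config file and checkpoint paths are hypothetical:

from olmo.config import TrainConfig

cfg = TrainConfig.load("configs/my-run.yaml")  # hypothetical config file
cfg.load_path = "/checkpoints/step5000"        # restore model weights from here
cfg.reset_optimizer_state = True               # but drop the optimizer state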
17 changes: 17 additions & 0 deletions olmo/optim.py
@@ -22,6 +22,7 @@
"LinearWithWarmup",
"InvSqrtWithWarmup",
"MaxScheduler",
"BoltOnWarmupScheduler",
"build_optimizer",
"build_scheduler",
]
@@ -435,6 +436,22 @@ def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
)


@dataclass
class BoltOnWarmupScheduler(Scheduler):
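"""Wraps another scheduler and bolts a fresh linear warmup onto it: the learning rate is zero
before `warmup_start`, ramps linearly up to the inner schedule's value at `warmup_end`, and
follows the inner schedule from then on."""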
inner: Scheduler
warmup_start: int
warmup_end: int

def get_lr(self, initial_lr: float, step: int, max_steps: int) -> float:
if step < self.warmup_start:
return 0.0
if step < self.warmup_end:
lr_at_intercept = self.inner.get_lr(initial_lr, self.warmup_end, max_steps)
return lr_at_intercept * (step - self.warmup_start) / (self.warmup_end - self.warmup_start)
else:
return self.inner.get_lr(initial_lr, step, max_steps)


PARAM_GROUP_FIELDS = ("sharded", "max_grad_norm", "max_grad_norm_ratio", "param_names")


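To make the bolt-on behavior concrete, here is a small sketch (not part of the diff) of resuming at step 5000 with a 1000-step warmup; the `LinearWithWarmup` arguments mirror the test added below:

from olmo.optim import BoltOnWarmupScheduler, LinearWithWarmup

inner = LinearWithWarmup(warmup_steps=1000, alpha_f=0.1)
sched = BoltOnWarmupScheduler(inner, warmup_start=5000, warmup_end=6000)

# Held at zero until the new warmup begins.
assert sched.get_lr(1.0, 4999, 11_000) == 0.0
# Ramps linearly toward the inner schedule's value at warmup_end (0.55 here),
# so the midpoint of the warmup sits at half of that: 0.275.
assert abs(sched.get_lr(1.0, 5500, 11_000) - 0.275) < 1e-9
# From warmup_end onward the wrapper defers to the inner schedule.
assert sched.get_lr(1.0, 7000, 11_000) == inner.get_lr(1.0, 7000, 11_000)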
99 changes: 59 additions & 40 deletions olmo/train.py
@@ -339,7 +339,9 @@ def remove_sharded_checkpoint(self, idx: int = 0):
latest_path.unlink()
barrier()

def restore_sharded_checkpoint(self, load_path: PathOrStr, local_cache: Optional[PathOrStr] = None):
def restore_sharded_checkpoint(
self, load_path: PathOrStr, local_cache: Optional[PathOrStr] = None, *, load_optimizer_state: bool = True
):
# Zero out gradients to avoid gathering them.
self.optim.zero_grad(set_to_none=True)

@@ -365,21 +367,23 @@ def restore_sharded_checkpoint(self, load_path: PathOrStr, local_cache: Optional
self.fsdp_model.load_state_dict(model_state["model"])

# Load optim state dict in place.
log.info("Loading optimizer state...")
optim_state = load_sharded_optimizer_state_dict(
model_state_dict=model_state["model"],
optimizer_key="optim",
storage_reader=RemoteFileSystemReader(
f"{load_path}/model_and_optim",
local_cache=None if local_cache is None else local_cache / "model_and_optim",
),
)
if version.parse(torch.__version__) < version.parse("2.1.0"):
flattened_osd = FSDP.optim_state_dict_to_load(optim_state["optim"], self.fsdp_model, self.optim) # type: ignore
else:
flattened_osd = FSDP.optim_state_dict_to_load(self.fsdp_model, self.optim, optim_state["optim"]) # type: ignore
self.optim.load_state_dict(fix_optim_state_dict(self.optim, flattened_osd))
del model_state, optim_state, flattened_osd
if load_optimizer_state:
log.info("Loading optimizer state...")
optim_state = load_sharded_optimizer_state_dict(
model_state_dict=model_state["model"],
optimizer_key="optim",
storage_reader=RemoteFileSystemReader(
f"{load_path}/model_and_optim",
local_cache=None if local_cache is None else local_cache / "model_and_optim",
),
)
if version.parse(torch.__version__) < version.parse("2.1.0"):
flattened_osd = FSDP.optim_state_dict_to_load(optim_state["optim"], self.fsdp_model, self.optim) # type: ignore
else:
flattened_osd = FSDP.optim_state_dict_to_load(self.fsdp_model, self.optim, optim_state["optim"]) # type: ignore
self.optim.load_state_dict(fix_optim_state_dict(self.optim, flattened_osd))
del optim_state, flattened_osd
del model_state

# Load trainer state dict.
log.info("Loading trainer state...")
@@ -408,7 +412,9 @@ def restore_sharded_checkpoint(self, load_path: PathOrStr, local_cache: Optional

barrier()

def restore_legacy_sharded_checkpoint(self, load_path: PathOrStr, local_cache: Optional[PathOrStr] = None):
def restore_legacy_sharded_checkpoint(
self, load_path: PathOrStr, local_cache: Optional[PathOrStr] = None, *, load_optimizer_state: bool = True
):
# Zero out gradients to avoid gathering them.
self.optim.zero_grad(set_to_none=True)

@@ -426,18 +432,20 @@ def restore_legacy_sharded_checkpoint(self, load_path: PathOrStr, local_cache: O
# Load model and optimizer state.
log.info("Loading model state...")
self.fsdp_model.load_state_dict(state_dict["model"])
log.info("Loading optimizer state...")
if version.parse(torch.__version__) < version.parse("2.1.0"):
flattened_osd = FSDP.optim_state_dict_to_load(state_dict["optim"], self.fsdp_model, self.optim) # type: ignore
else:
flattened_osd = FSDP.optim_state_dict_to_load(self.fsdp_model, self.optim, state_dict["optim"]) # type: ignore
self.optim.load_state_dict(fix_optim_state_dict(self.optim, flattened_osd))
if load_optimizer_state:
log.info("Loading optimizer state...")
if version.parse(torch.__version__) < version.parse("2.1.0"):
flattened_osd = FSDP.optim_state_dict_to_load(state_dict["optim"], self.fsdp_model, self.optim) # type: ignore
else:
flattened_osd = FSDP.optim_state_dict_to_load(self.fsdp_model, self.optim, state_dict["optim"]) # type: ignore
self.optim.load_state_dict(fix_optim_state_dict(self.optim, flattened_osd))
del flattened_osd

# Load trainer state dict.
log.info("Loading trainer state...")
self.load_trainer_state_dict(state_dict)

del state_dict, flattened_osd
del state_dict

barrier()

@@ -537,7 +545,9 @@ def remove_unsharded_checkpoint(self, idx: int = 0):
latest_path.unlink()
barrier()

def restore_unsharded_checkpoint(self, load_path: PathOrStr, local_cache: Optional[PathOrStr] = None):
def restore_unsharded_checkpoint(
self, load_path: PathOrStr, local_cache: Optional[PathOrStr] = None, *, load_optimizer_state: bool = True
):
# Zero out gradients to avoid gathering them.
self.optim.zero_grad(set_to_none=True)

@@ -556,18 +566,19 @@ def restore_unsharded_checkpoint(self, load_path: PathOrStr, local_cache: Option
)

# Load optimizer state.
log.info("Loading optimizer state...")
optim_state_dict = torch.load(resource_path(load_path, "optim.pt", local_cache=local_cache))
# NOTE: careful, the order of these arguments has changed since the 2.0 release.
if version.parse(torch.__version__) < version.parse("2.1.0"):
# flattened_osd = FSDP.optim_state_dict_to_load(optim_state["optim"], self.fsdp_model, self.optim) # type: ignore
flattened_osd = FSDP.optim_state_dict_to_load(optim_state_dict, self.fsdp_model, self.optim) # type: ignore
else:
# flattened_osd = FSDP.optim_state_dict_to_load(self.fsdp_model, self.optim, optim_state["optim"]) # type: ignore
flattened_osd = FSDP.optim_state_dict_to_load(self.fsdp_model, self.optim, optim_state_dict) # type: ignore
del optim_state_dict
self.optim.load_state_dict(fix_optim_state_dict(self.optim, flattened_osd))
del flattened_osd
if load_optimizer_state:
log.info("Loading optimizer state...")
optim_state_dict = torch.load(resource_path(load_path, "optim.pt", local_cache=local_cache))
# NOTE: careful, the order of these arguments has changed since the 2.0 release.
if version.parse(torch.__version__) < version.parse("2.1.0"):
# flattened_osd = FSDP.optim_state_dict_to_load(optim_state["optim"], self.fsdp_model, self.optim) # type: ignore
flattened_osd = FSDP.optim_state_dict_to_load(optim_state_dict, self.fsdp_model, self.optim) # type: ignore
else:
# flattened_osd = FSDP.optim_state_dict_to_load(self.fsdp_model, self.optim, optim_state["optim"]) # type: ignore
flattened_osd = FSDP.optim_state_dict_to_load(self.fsdp_model, self.optim, optim_state_dict) # type: ignore
del optim_state_dict
self.optim.load_state_dict(fix_optim_state_dict(self.optim, flattened_osd))
del flattened_osd

# Load other state.
try:
@@ -596,20 +607,28 @@ def restore_checkpoint(
checkpoint_type: Optional[CheckpointType] = None,
local_cache: Optional[PathOrStr] = None,
legacy_mode: bool = False,
*,
load_optimizer_state: bool = True,
):
if checkpoint_type == CheckpointType.unsharded or (
checkpoint_type is None and str(load_path).endswith("-unsharded")
):
self.restore_unsharded_checkpoint(load_path, local_cache=local_cache)
self.restore_unsharded_checkpoint(
load_path, local_cache=local_cache, load_optimizer_state=load_optimizer_state
)
elif checkpoint_type == CheckpointType.sharded or checkpoint_type is None:
try:
legacy_mode = resource_path(load_path, f"rank{get_global_rank()}.pt").is_file()
except FileNotFoundError:
legacy_mode = False
if legacy_mode:
self.restore_legacy_sharded_checkpoint(load_path, local_cache=local_cache)
self.restore_legacy_sharded_checkpoint(
load_path, local_cache=local_cache, load_optimizer_state=load_optimizer_state
)
else:
self.restore_sharded_checkpoint(load_path, local_cache=local_cache)
self.restore_sharded_checkpoint(
load_path, local_cache=local_cache, load_optimizer_state=load_optimizer_state
)
elif checkpoint_type is not None:
raise NotImplementedError(checkpoint_type)

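With `load_optimizer_state` threaded through all three restore paths (and keyword-only, so call sites must name it explicitly), a caller can restore weights and trainer state while skipping the optimizer. A minimal sketch, assuming a constructed trainer and a hypothetical checkpoint path:

# Restore model weights and trainer state, but leave the optimizer's
# moment estimates freshly initialized.
trainer.restore_checkpoint("/checkpoints/step5000", load_optimizer_state=False)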
17 changes: 15 additions & 2 deletions scripts/train.py
@@ -20,7 +20,7 @@
from olmo.eval import build_evaluators
from olmo.exceptions import OlmoCliError, OlmoConfigurationError
from olmo.model import Olmo
from olmo.optim import build_optimizer, build_scheduler
from olmo.optim import BoltOnWarmupScheduler, build_optimizer, build_scheduler
from olmo.train import Trainer
from olmo.util import (
barrier,
@@ -43,6 +43,13 @@ def main(cfg: TrainConfig) -> None:
cfg.run_name = os.environ.get("COMPOSER_RUN_NAME", "train-llm")
log_extra_field("run_name", cfg.run_name)

# Sanity check
if cfg.reset_optimizer_state and cfg.load_path is None:
log.warning(
"You want to reset the optimizer state, but we're not loading from the checkpoint. The"
"setting has no effect."
)

# Initialize process group and set device.
dist.init_process_group(backend="nccl")
barrier()
@@ -210,9 +217,15 @@ def dummy_init_fn(module: torch.nn.Module) -> None:

if cfg.load_path is not None:
log.info(f"Loading checkpoint from {cfg.load_path}...")
trainer.restore_checkpoint(cfg.load_path)
trainer.restore_checkpoint(cfg.load_path, load_optimizer_state=not cfg.reset_optimizer_state)
log.info("Checkpoint successfully loaded")

# If requested, bolt a new warmup schedule onto the restored one:
if cfg.reset_optimizer_state:
trainer.scheduler = BoltOnWarmupScheduler(
trainer.scheduler, trainer.global_step, trainer.global_step + cfg.scheduler.t_warmup
)

if cfg.force_save_unsharded:
log.info("Saving unsharded checkpoint...")
checkpoint_path, _ = trainer.save_unsharded_checkpoint()
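End to end, a resume run with a fresh optimizer state could then be launched as below; the paths and config name are hypothetical, and the override syntax assumes the script's usual dotlist-style config overrides:

torchrun --nproc_per_node=8 scripts/train.py configs/my-run.yaml --load_path=/checkpoints/step5000 --reset_optimizer_state=true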
18 changes: 17 additions & 1 deletion tests/optim_test.py
@@ -1,4 +1,6 @@
from olmo.optim import LinearWithWarmup
import pytest

from olmo.optim import BoltOnWarmupScheduler, LinearWithWarmup


def test_linear_with_warmup_scheduler():
@@ -9,3 +11,17 @@ def test_linear_with_warmup_scheduler():
assert scheduler.get_lr(initial_lr, 2000, max_steps) == 1.0
assert scheduler.get_lr(initial_lr, 10_000, max_steps) == 0.1
assert scheduler.get_lr(initial_lr, 3_000, max_steps) > scheduler.get_lr(initial_lr, 5_000, max_steps)


def test_bolt_on_warmup_scheduler():
initial_lr = 1.0
max_steps = 11_000
alpha_f = 0.1
scheduler = LinearWithWarmup(warmup_steps=1000, alpha_f=alpha_f)
scheduler2 = BoltOnWarmupScheduler(scheduler, 5000, 6000)
assert scheduler.get_lr(initial_lr, 100, max_steps) > 0.0
assert scheduler2.get_lr(initial_lr, 100, max_steps) == 0.0
assert scheduler2.get_lr(initial_lr, 5000, max_steps) == 0.0
assert scheduler2.get_lr(initial_lr, 5500, max_steps) == pytest.approx(0.25 * (1 + alpha_f))
assert scheduler2.get_lr(initial_lr, 6000, max_steps) == pytest.approx(0.5 * (1 + alpha_f))
assert scheduler2.get_lr(initial_lr, 7000, max_steps) == scheduler.get_lr(initial_lr, 7000, max_steps)
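For reference, the expected constants follow directly from the schedule definitions: the inner schedule decays linearly from 1.0 at step 1000 to alpha_f = 0.1 at step 11_000, so at the intercept step 6000 it is halfway down, i.e. 1.0 - 0.5 * (1.0 - 0.1) = 0.55 = 0.5 * (1 + alpha_f); the bolt-on warmup then scales that intercept value linearly, giving 0.5 * 0.55 = 0.275 = 0.25 * (1 + alpha_f) at the midpoint step 5500.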