Merge pull request #372 from allenai/epwalsh/optim-state-fix
Fix loading local sharded optim state
AkshitaB committed Nov 8, 2023
2 parents c205912 + 5459c19 commit 38be6a7
Showing 2 changed files with 19 additions and 0 deletions.
13 changes: 13 additions & 0 deletions olmo/checkpoint.py
@@ -1204,6 +1204,19 @@ def restore_checkpoint(
optim_state = load_state_dict(
    load_path, f"optim/rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
)
# HACK/TODO (epwalsh): When we use adaptive clipping we track the 'grad_norm_exp_avg' for every param
# in every rank, and keep this in the optimizer state. But this causes issues when loading the
# state since torch sees the state is non-empty for some params which would normally be empty,
# and then assumes it should have all of the other state tensors for that param, which it doesn't.
# So for now we just remove 'grad_norm_exp_avg' everywhere from the state, which resets that metric.
# Not the end of the world, but there's probably a better way around this without resetting
# the metric.
for param_id in list(optim_state["state"].keys()):
    state = optim_state["state"][param_id]
    if "grad_norm_exp_avg" in state:
        del state["grad_norm_exp_avg"]
    if len(state) == 0:
        del optim_state["state"][param_id]
optim.load_state_dict(optim_state)
del optim_state
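
As an aside, the stripping logic above can also be written as a small standalone helper. The sketch below is not part of this commit; the function name strip_optim_state_key and the toy state dict are illustrative assumptions, but the behavior mirrors the loop added to restore_checkpoint:

import torch


def strip_optim_state_key(optim_state, key):
    """Remove `key` from every per-parameter state entry, dropping entries that
    become empty so the optimizer treats those params as uninitialized."""
    for param_id in list(optim_state["state"].keys()):
        state = optim_state["state"][param_id]
        state.pop(key, None)
        if len(state) == 0:
            del optim_state["state"][param_id]
    return optim_state


# Illustrative only: param 1 stands in for a param whose local shard is empty, so the
# only thing in its state is 'grad_norm_exp_avg'; stripping the key drops its entry entirely.
optim_state = {
    "state": {
        0: {
            "exp_avg": torch.zeros(4),
            "exp_avg_sq": torch.zeros(4),
            "grad_norm_exp_avg": torch.tensor(0.1),
        },
        1: {"grad_norm_exp_avg": torch.tensor(0.2)},
    },
    "param_groups": [{"params": [0, 1]}],
}
strip_optim_state_key(optim_state, "grad_norm_exp_avg")
assert list(optim_state["state"].keys()) == [0]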

6 changes: 6 additions & 0 deletions olmo/optim.py
@@ -251,6 +251,12 @@ def _do_adaptive_clipping(
continue

# Get or initialize the exponential average of grad norm.
# TODO: The way we have it right now, every rank tracks the `grad_norm_exp_avg` of every parameter,
# even parameters for which the corresponding local shard is empty. This has the potential to
# cause some issues with the optimizer, as we ran into in https://github.com/allenai/LLM/pull/372.
# So we should consider changing how we do this at some point so that we don't add any state
# to parameters for which the local shard is empty. That would probably add extra distributed
# communication, at least on steps where we have to log (i.e. when `collect_param_metrics=True`).
state = self.state[p]
grad_norm_exp_avg = state.get("grad_norm_exp_avg")
if grad_norm_exp_avg is None:
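
For context on the TODO above: grad_norm_exp_avg is an exponential moving average of a parameter's gradient norm, which the adaptive clipping step maintains per parameter. A minimal sketch of such an EMA update follows; the function name and the beta default are assumptions for illustration, not taken from the OLMo code:

from typing import Optional

import torch


def update_grad_norm_exp_avg(
    grad_norm: torch.Tensor,
    grad_norm_exp_avg: Optional[torch.Tensor],
    beta: float = 0.99,
) -> torch.Tensor:
    # On the first step there is no running average yet, so initialize it from
    # the current norm; afterwards blend the old average and the new norm by `beta`.
    if grad_norm_exp_avg is None:
        return grad_norm.clone()
    return beta * grad_norm_exp_avg + (1 - beta) * grad_norm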
