diff --git a/olmo/checkpoint.py b/olmo/checkpoint.py
index fdb621cee..73e01b435 100644
--- a/olmo/checkpoint.py
+++ b/olmo/checkpoint.py
@@ -1204,6 +1204,19 @@ def restore_checkpoint(
                 optim_state = load_state_dict(
                     load_path, f"optim/rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
                 )
+                # HACK/TODO (epwalsh): When we use adaptive clipping we track the 'grad_norm_exp_avg' for every param
+                # in every rank, and keep this in the optimizer state. But this causes issues when loading the
+                # state since torch sees the state is non-empty for some params which would normally be empty,
+                # and then assumes it should have all of the other state tensors for that param, which it doesn't.
+                # So for now we just remove 'grad_norm_exp_avg' everywhere from the state, which resets that metric.
+                # Not the end of the world, but there's probably a better way around this that doesn't reset
+                # the metric.
+                for param_id in list(optim_state["state"].keys()):
+                    state = optim_state["state"][param_id]
+                    if "grad_norm_exp_avg" in state:
+                        del state["grad_norm_exp_avg"]
+                    if len(state) == 0:
+                        del optim_state["state"][param_id]
                 optim.load_state_dict(optim_state)
                 del optim_state
 
diff --git a/olmo/optim.py b/olmo/optim.py
index 573937caf..7f119eb78 100644
--- a/olmo/optim.py
+++ b/olmo/optim.py
@@ -251,6 +251,12 @@ def _do_adaptive_clipping(
                 continue
 
             # Get or initialize the exponential average of grad norm.
+            # TODO: The way we have it right now, every rank tracks the `grad_norm_exp_avg` of every parameter,
+            # even parameters for which the corresponding local shard is empty. This has the potential to
+            # cause some issues with the optimizer, as we ran into with https://github.com/allenai/LLM/pull/372.
+            # So we should consider changing how we do this at some point so that we don't add any state
+            # to parameters for which the local shard is empty. That would probably add extra distributed
+            # communication, at least on steps where we have to log (i.e. when `collect_param_metrics=True`).
             state = self.state[p]
             grad_norm_exp_avg = state.get("grad_norm_exp_avg")
             if grad_norm_exp_avg is None:
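
For reference, the scrub in the `olmo/checkpoint.py` hunk can be exercised in isolation. The sketch below is illustrative only: the `strip_grad_norm_exp_avg` helper and the toy state dict are not part of the PR; they just mirror the `{"state": {param_id: {...}}, "param_groups": [...]}` layout that `torch.optim.Optimizer.state_dict()` produces.

```python
from typing import Any, Dict

import torch


def strip_grad_norm_exp_avg(optim_state: Dict[str, Any]) -> Dict[str, Any]:
    """Drop the adaptive-clipping metric from a torch-style optimizer state dict.

    Per-param entries left empty afterwards are removed entirely, so the optimizer
    treats those params as having no restored state at all.
    """
    for param_id in list(optim_state["state"].keys()):
        state = optim_state["state"][param_id]
        state.pop("grad_norm_exp_avg", None)
        if not state:
            del optim_state["state"][param_id]
    return optim_state


if __name__ == "__main__":
    # Toy example: param 1 carries only the clipping metric (e.g. its local shard
    # was empty), so its whole entry disappears after the scrub.
    toy_state = {
        "state": {
            0: {
                "step": torch.tensor(10),
                "exp_avg": torch.zeros(4),
                "exp_avg_sq": torch.zeros(4),
                "grad_norm_exp_avg": torch.tensor(0.5),
            },
            1: {"grad_norm_exp_avg": torch.tensor(0.25)},
        },
        "param_groups": [{"params": [0, 1]}],
    }
    strip_grad_norm_exp_avg(toy_state)
    assert "grad_norm_exp_avg" not in toy_state["state"][0]
    assert 1 not in toy_state["state"]
```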
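
Similarly, a minimal sketch of the bookkeeping the `olmo/optim.py` TODO refers to: every param whose gradient norm is available gets a `grad_norm_exp_avg` entry, whether or not the rank owns any elements of that param's shard. The `update_grad_norm_exp_avg` function, the name-keyed state dict, and the `beta` default are assumptions for illustration, not the actual `_do_adaptive_clipping` implementation.

```python
from typing import Dict

import torch


def update_grad_norm_exp_avg(
    state: Dict[str, Dict[str, torch.Tensor]],
    grad_norms: Dict[str, torch.Tensor],
    beta: float = 0.99,
) -> None:
    """Simplified illustration of per-param grad-norm EMA tracking.

    Every param that appears in `grad_norms` ends up with a 'grad_norm_exp_avg'
    entry in `state`, even if the rank's local shard of that param is empty --
    which is exactly the extra optimizer state the checkpoint.py hack has to
    strip back out on load.
    """
    for name, grad_norm in grad_norms.items():
        param_state = state.setdefault(name, {})
        exp_avg = param_state.get("grad_norm_exp_avg")
        if exp_avg is None:
            # First step for this param: seed the average with the current norm.
            param_state["grad_norm_exp_avg"] = grad_norm.clone()
        else:
            exp_avg.mul_(beta).add_(grad_norm, alpha=1 - beta)


if __name__ == "__main__":
    state: Dict[str, Dict[str, torch.Tensor]] = {}
    # Norms are reduced across ranks, so even a rank holding an empty shard of
    # "ff_out.weight" sees a norm for it and ends up tracking the average.
    grad_norms = {"att_proj.weight": torch.tensor(1.2), "ff_out.weight": torch.tensor(0.8)}
    for _ in range(3):
        update_grad_norm_exp_avg(state, grad_norms)
    print(state["ff_out.weight"]["grad_norm_exp_avg"])
```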