From 6245b03c56fa2cc275b7d151dbc6f99a6622558a Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Wed, 8 Nov 2023 11:20:06 -0800
Subject: [PATCH 1/4] fix loading optim state w/ adaptive clipping

---
 olmo/checkpoint.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/olmo/checkpoint.py b/olmo/checkpoint.py
index fdb621cee..bb06b9fda 100644
--- a/olmo/checkpoint.py
+++ b/olmo/checkpoint.py
@@ -1204,6 +1204,16 @@ def restore_checkpoint(
             optim_state = load_state_dict(
                 load_path, f"optim/rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
             )
+            # HACK/TODO (epwalsh): When we use adaptive clipping we track the 'grad_norm_exp_avg' for every param
+            # in every rank, and keep this in the optimizer state. But this causes issues when loading the
+            # state since torch sees the state is non-empty for some params which would normally be empty,
+            # and then assumes it should have all of the other state tensors for that param, which it doesn't.
+            # So for now we just remove 'grad_norm_exp_avg' everywhere from the state, which resets that metric.
+            # Not the end of the world but there's probably a better way around this without resetting
+            # the metric.
+            for state in optim_state["state"].values():
+                if "grad_norm_exp_avg" in state:
+                    del state["grad_norm_exp_avg]"]
             optim.load_state_dict(optim_state)
             del optim_state

From ff76a564190068e0ab4a9f6d90edac77b46c7aa6 Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Wed, 8 Nov 2023 11:24:53 -0800
Subject: [PATCH 2/4] Copilot can't be trusted

---
 olmo/checkpoint.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/olmo/checkpoint.py b/olmo/checkpoint.py
index bb06b9fda..6301e3e94 100644
--- a/olmo/checkpoint.py
+++ b/olmo/checkpoint.py
@@ -1213,7 +1213,7 @@ def restore_checkpoint(
             # the metric.
             for state in optim_state["state"].values():
                 if "grad_norm_exp_avg" in state:
-                    del state["grad_norm_exp_avg]"]
+                    del state["grad_norm_exp_avg"]
             optim.load_state_dict(optim_state)
             del optim_state

From 9bf13019731c340078b517067a868670234d9ebf Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Wed, 8 Nov 2023 11:30:50 -0800
Subject: [PATCH 3/4] Remove empty state

---
 olmo/checkpoint.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/olmo/checkpoint.py b/olmo/checkpoint.py
index 6301e3e94..73e01b435 100644
--- a/olmo/checkpoint.py
+++ b/olmo/checkpoint.py
@@ -1211,9 +1211,12 @@ def restore_checkpoint(
             # So for now we just remove 'grad_norm_exp_avg' everywhere from the state, which resets that metric.
             # Not the end of the world but there's probably a better way around this without resetting
             # the metric.
-            for state in optim_state["state"].values():
+            for param_id in list(optim_state["state"].keys()):
+                state = optim_state["state"][param_id]
                 if "grad_norm_exp_avg" in state:
                     del state["grad_norm_exp_avg"]
+                if len(state) == 0:
+                    del optim_state["state"][param_id]
             optim.load_state_dict(optim_state)
             del optim_state

From 5459c197b9866d41a25e6ad71dc461d1d1f9728f Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Wed, 8 Nov 2023 12:41:16 -0800
Subject: [PATCH 4/4] Add comment about @2015aroras's suggestion

---
 olmo/optim.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/olmo/optim.py b/olmo/optim.py
index 573937caf..7f119eb78 100644
--- a/olmo/optim.py
+++ b/olmo/optim.py
@@ -251,6 +251,12 @@ def _do_adaptive_clipping(
                 continue
             # Get or initialize the exponential average of grad norm.
+            # TODO: The way we have it right now, every rank tracks the `grad_norm_exp_avg` of every parameter,
+            # even parameters for which the corresponding local shard is empty. This has the potential to
+            # cause some issues with the optimizer, as we ran into with https://github.com/allenai/LLM/pull/372.
+            # So we should consider changing how we do this at some point so that we don't add any state
+            # to parameters for which the local shard is empty. That would probably add extra distributed
+            # communication, at least on steps where we have to log (i.e. when `collect_param_metrics=True`).
             state = self.state[p]
             grad_norm_exp_avg = state.get("grad_norm_exp_avg")
             if grad_norm_exp_avg is None:
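
As a side note, the state-dict cleanup that patches 1-3 converge on can be exercised on its own. The sketch below is illustrative only: the toy model, the optimizer, and the hand-injected `grad_norm_exp_avg` tensor are assumptions for demonstration, not code from the OLMo repo. It mirrors the final loop from patch 3, run before calling `Optimizer.load_state_dict()`:

    import torch

    def strip_grad_norm_exp_avg(optim_state: dict) -> dict:
        # Drop the 'grad_norm_exp_avg' metric from every param's state, and
        # then drop any param whose state is left empty, so torch doesn't
        # expect the full set of state tensors for it (mirrors patch 3).
        for param_id in list(optim_state["state"].keys()):
            state = optim_state["state"][param_id]
            state.pop("grad_norm_exp_avg", None)
            if len(state) == 0:
                del optim_state["state"][param_id]
        return optim_state

    # Toy setup (assumption, not OLMo code): take one optimizer step so
    # real state (exp_avg, exp_avg_sq, step) exists for every param.
    model = torch.nn.Linear(4, 4)
    optim = torch.optim.AdamW(model.parameters())
    model(torch.randn(2, 4)).sum().backward()
    optim.step()

    saved = optim.state_dict()
    # Simulate the extra metric that adaptive clipping stores per param.
    for state in saved["state"].values():
        state["grad_norm_exp_avg"] = torch.tensor(0.5)

    optim.load_state_dict(strip_grad_norm_exp_avg(saved))

Note that this toy case would load even without the cleanup; the failure described in patch 1 is specific to sharded optimizer state, where a param with an empty local shard ends up with `grad_norm_exp_avg` as its only state entry. The sketch just demonstrates what the cleanup does to the state dict.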