From 740294f78b513331d858c1793dccf743bf972789 Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Fri, 27 Oct 2023 13:32:20 -0700
Subject: [PATCH] Fix loading unsharded optim state with different block group size

---
 olmo/checkpoint.py  | 65 ++++++++++++++++++++++++++++++++++++++++-----
 olmo/model.py       | 55 +++++++++++++++++++++++++++++++-------
 tests/model_test.py | 17 +++++++++---
 3 files changed, 116 insertions(+), 21 deletions(-)

diff --git a/olmo/checkpoint.py b/olmo/checkpoint.py
index 73231f822..fb0e912ea 100644
--- a/olmo/checkpoint.py
+++ b/olmo/checkpoint.py
@@ -6,6 +6,7 @@
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from contextlib import contextmanager
+from copy import deepcopy
 from dataclasses import dataclass, field, replace
 from functools import reduce
 from pathlib import Path
@@ -609,19 +610,19 @@ def restore_checkpoint(
     ):
         # Load model state.
         log.info("Loading model state...")
-        fsdp_model.load_state_dict(
-            fsdp_model._fsdp_wrapped_module._make_state_dict_compatible(
-                load_state_dict(load_path, "model.pt", local_cache=local_cache, map_location="cpu")
-            )
+        state_dict_to_load, og_keys_to_new = fsdp_model._fsdp_wrapped_module._make_state_dict_compatible(
+            load_state_dict(load_path, "model.pt", local_cache=local_cache, map_location="cpu")
         )
+        fsdp_model.load_state_dict(state_dict_to_load)

         # Load optimizer state.
         if load_optimizer_state:
             log.info("Loading optimizer state...")
-            optim_state_dict = load_state_dict(
-                load_path, "optim.pt", local_cache=local_cache, map_location="cpu"
+            optim_state_dict_to_load = self._make_optim_state_dict_compatible(
+                load_state_dict(load_path, "optim.pt", local_cache=local_cache, map_location="cpu"),
+                og_keys_to_new,
             )
-            load_fsdp_optim_state(fsdp_model, optim, optim_state_dict)
+            load_fsdp_optim_state(fsdp_model, optim, optim_state_dict_to_load)

         # Load other state.
         try:
@@ -632,6 +633,56 @@ def restore_checkpoint(
         barrier()
         return trainer_state

+    def _make_optim_state_dict_compatible(
+        self, optim_state_dict: Dict[str, Any], og_keys_to_new: Dict[str, Set[str]]
+    ) -> Dict[str, Any]:
+        # This state dict comes in two forms: one where the state keys are integers and one where the
+        # keys are fully qualified parameter names. The latter case is easier to deal with here so we
+        # first transform the integer key form into the FQN key form.
+        if isinstance(next(iter(optim_state_dict["state"].keys())), int):
+            id_to_fqn: Dict[int, str] = {}
+            for group in optim_state_dict["param_groups"]:
+                new_param_names = []
+                for fqn, id in zip(group["param_names"], group["params"]):
+                    fqn = fqn.replace("_fsdp_wrapped_module.", "")
+                    id_to_fqn[id] = fqn
+                    new_param_names.append(fqn)
+                group["param_names"] = new_param_names
+                group["params"] = new_param_names
+            for id in list(optim_state_dict["state"].keys()):
+                optim_state_dict["state"][id_to_fqn[id]] = optim_state_dict["state"].pop(id)
+        else:
+            # Otherwise we still want to clean up the param names to remove the "_fsdp_wrapped_module." prefix.
+            for group in optim_state_dict["param_groups"]:
+                group["param_names"] = [fqn.replace("_fsdp_wrapped_module.", "") for fqn in group["param_names"]]
+                group["params"] = [fqn.replace("_fsdp_wrapped_module.", "") for fqn in group["params"]]
+                assert group["param_names"] == group["params"]
+            for key in list(optim_state_dict["state"].keys()):
+                optim_state_dict["state"][key.replace("_fsdp_wrapped_module.", "")] = optim_state_dict[
+                    "state"
+                ].pop(key)
+
+        # Now we can transform the state dict by renaming parameters according to `og_keys_to_new`.
+        # First fix param names in the state.
+        for og_key, new_keys in og_keys_to_new.items():
+            og_state = optim_state_dict["state"].pop(og_key)
+            for i, new_key in enumerate(new_keys):
+                if i == len(new_keys) - 1:
+                    optim_state_dict["state"][new_key] = og_state
+                else:
+                    optim_state_dict["state"][new_key] = deepcopy(og_state)
+        # Now fix param names in the param groups.
+        for group in optim_state_dict["param_groups"]:
+            og_names = group["params"]
+            new_names = []
+            for og_key in og_names:
+                for new_key in og_keys_to_new[og_key]:
+                    new_names.append(new_key)
+            group["params"] = new_names
+            group["param_names"] = new_names
+
+        return optim_state_dict
+
     def load_checkpoint(
         self,
         load_path: PathOrStr,
diff --git a/olmo/model.py b/olmo/model.py
index be80cefd8..77d638e7c 100644
--- a/olmo/model.py
+++ b/olmo/model.py
@@ -9,6 +9,7 @@
 import logging
 import math
 from abc import abstractmethod
+from collections import defaultdict
 from collections.abc import MutableMapping
 from functools import partial
 from typing import (
@@ -19,6 +20,7 @@
     NamedTuple,
     Optional,
     Sequence,
+    Set,
     Tuple,
     cast,
 )
@@ -1297,7 +1299,7 @@ def from_checkpoint(
             # Load state dict directly to target device.
             state_dict_path = resource_path(checkpoint_dir, "model.pt")
             state_dict = torch.load(state_dict_path, map_location="cpu")
-            model.load_state_dict(model._make_state_dict_compatible(state_dict))
+            model.load_state_dict(model._make_state_dict_compatible(state_dict)[0])
             model = model.to(torch.device(device))
         else:
             from .checkpoint import load_model_state
@@ -1312,27 +1314,46 @@

         return model.eval()

-    def _make_state_dict_compatible(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+    def _make_state_dict_compatible(
+        self, state_dict: Dict[str, torch.Tensor]
+    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Set[str]]]:
+        """
+        Handles some cases where the state dict is valid yet may need to be transformed in order to
+        be loaded.
+
+        This modifies the state dict in-place and also returns it, along with a mapping of original key
+        names to new key names in cases where the keys were simply renamed. That mapping can be used
+        to make a corresponding optimizer state dict compatible as well.
+        """
         import re
         from fnmatch import fnmatch

+        new_keys_to_og_keys: Dict[str, str] = {}
+
         # Remove "_fsdp_wrapped_module." prefix from all keys. We don't want this prefix when the model is
         # not wrapped in FSDP. And when the model is wrapped in FSDP, loading this state dict will still work
         # fine without the prefixes. This also simplifies the other steps below.
         for key in list(state_dict.keys()):
-            state_dict[key.replace("_fsdp_wrapped_module.", "")] = state_dict.pop(key)
+            state_dict[(new_key := key.replace("_fsdp_wrapped_module.", ""))] = state_dict.pop(key)
+            new_keys_to_og_keys[new_key] = key

         # For backwards compatibility prior to fixing https://github.com/allenai/LLM/issues/222
         if self.config.block_type == BlockType.sequential:
             for key in list(state_dict.keys()):
                 if fnmatch(key, "transformer.*.norm.weight"):
                     tensor = state_dict.pop(key)
-                    state_dict[key.replace("norm.weight", "attn_norm.weight")] = tensor
-                    state_dict[key.replace("norm.weight", "ff_norm.weight")] = tensor.clone()
+                    state_dict[(new_key := key.replace("norm.weight", "attn_norm.weight"))] = tensor
+                    new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
+                    state_dict[(new_key := key.replace("norm.weight", "ff_norm.weight"))] = tensor.clone()
+                    new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
+                    del new_keys_to_og_keys[key]
                 elif fnmatch(key, "transformer.*.norm.bias"):
                     tensor = state_dict.pop(key)
-                    state_dict[key.replace("norm.bias", "attn_norm.bias")] = tensor
-                    state_dict[key.replace("norm.bias", "ff_norm.bias")] = tensor.clone()
+                    state_dict[(new_key := key.replace("norm.bias", "attn_norm.bias"))] = tensor
+                    new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
+                    state_dict[(new_key := key.replace("norm.bias", "ff_norm.bias"))] = tensor.clone()
+                    new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
+                    del new_keys_to_og_keys[key]

         # For loading a state dict that was saved with a different `block_group_size`.
         if "transformer.block_groups.0.0.attn_out.weight" in state_dict.keys():
@@ -1354,8 +1375,13 @@ def _make_state_dict_compatible(self, state_dict: Dict[str, torch.Tensor]) -> Di
                         group_idx, group_block_idx = int(m.group(1)), int(m.group(2))
                         block_idx = (group_idx * state_dict_block_group_size) + group_block_idx
                         state_dict[
-                            key.replace(f"block_groups.{group_idx}.{group_block_idx}.", f"blocks.{block_idx}.")
+                            (
+                                new_key := key.replace(
+                                    f"block_groups.{group_idx}.{group_block_idx}.", f"blocks.{block_idx}."
+                                )
+                            )
                         ] = state_dict.pop(key)
+                        new_keys_to_og_keys[new_key] = new_keys_to_og_keys.pop(key)

             if self.config.block_group_size > 1:
                 # Group the state dict blocks into the right block size.
@@ -1367,7 +1393,16 @@ def _make_state_dict_compatible(self, state_dict: Dict[str, torch.Tensor]) -> Di
                             block_idx % self.config.block_group_size,
                         )
                         state_dict[
-                            key.replace(f"blocks.{block_idx}.", f"block_groups.{group_idx}.{group_block_idx}.")
+                            (
+                                new_key := key.replace(
+                                    f"blocks.{block_idx}.", f"block_groups.{group_idx}.{group_block_idx}."
+                                )
+                            )
                         ] = state_dict.pop(key)
+                        new_keys_to_og_keys[new_key] = new_keys_to_og_keys.pop(key)
+
+        og_keys_to_new: Dict[str, Set[str]] = defaultdict(set)
+        for new_key, og_key in new_keys_to_og_keys.items():
+            og_keys_to_new[og_key].add(new_key)

-        return state_dict
+        return state_dict, og_keys_to_new
diff --git a/tests/model_test.py b/tests/model_test.py
index 591278a89..18dd5401f 100644
--- a/tests/model_test.py
+++ b/tests/model_test.py
@@ -439,12 +439,21 @@ def test_block_groups():
     model_without_block_groups = Olmo(ModelConfig(d_model=128, n_heads=2, n_layers=9, block_group_size=1)).eval()

     # We should be able to load the state dict from one model into the other, and vice-versa.
-    model_with_block_groups.load_state_dict(
-        model_with_block_groups._make_state_dict_compatible(model_without_block_groups.state_dict())
+    state_dict_to_load, og_keys_to_new_keys = model_with_block_groups._make_state_dict_compatible(
+        model_without_block_groups.state_dict()
     )
-    model_without_block_groups.load_state_dict(
-        model_without_block_groups._make_state_dict_compatible(model_with_block_groups.state_dict())
+    assert og_keys_to_new_keys["transformer.blocks.0.attn_out.weight"] == {
+        "transformer.block_groups.0.0.attn_out.weight"
+    }
+    model_with_block_groups.load_state_dict(state_dict_to_load)
+
+    state_dict_to_load, og_keys_to_new_keys = model_without_block_groups._make_state_dict_compatible(
+        model_with_block_groups.state_dict()
     )
+    assert og_keys_to_new_keys["transformer.block_groups.0.0.attn_out.weight"] == {
+        "transformer.blocks.0.attn_out.weight"
+    }
+    model_without_block_groups.load_state_dict(state_dict_to_load)

     # Check that output is exactly the same.
     input_ids = torch.randint(0, model_with_block_groups.config.vocab_size, (2, 16))
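
A rough standalone sketch (not the patched method itself) of the first transformation `_make_optim_state_dict_compatible` performs: an optimizer state dict whose `state` is keyed by integer parameter ids is re-keyed by fully qualified parameter names, stripping the FSDP wrapper prefix. The helper name and the toy parameter names below are hypothetical, and the sketch assumes each param group carries a `param_names` list alongside `params`, as OLMo's checkpoints do.

from typing import Any, Dict

FSDP_PREFIX = "_fsdp_wrapped_module."


def int_keys_to_fqns(optim_state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Re-key an optimizer state dict from integer param ids to parameter FQNs (in place)."""
    id_to_fqn: Dict[int, str] = {}
    for group in optim_state_dict["param_groups"]:
        fqns = [name.replace(FSDP_PREFIX, "") for name in group["param_names"]]
        for param_id, fqn in zip(group["params"], fqns):
            id_to_fqn[param_id] = fqn
        group["param_names"] = fqns
        group["params"] = fqns
    for param_id in list(optim_state_dict["state"].keys()):
        optim_state_dict["state"][id_to_fqn[param_id]] = optim_state_dict["state"].pop(param_id)
    return optim_state_dict


# Toy example with two parameters saved under the FSDP-wrapped names.
toy = {
    "state": {0: {"step": 10}, 1: {"step": 10}},
    "param_groups": [
        {
            "lr": 1e-4,
            "params": [0, 1],
            "param_names": ["_fsdp_wrapped_module.wte.weight", "_fsdp_wrapped_module.ff_out.weight"],
        }
    ],
}
int_keys_to_fqns(toy)
assert set(toy["state"].keys()) == {"wte.weight", "ff_out.weight"}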
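
The core of the fix is that whenever `_make_state_dict_compatible` renames a model key, the matching optimizer state has to be renamed the same way, and one original key can fan out to several new keys (for example an old fused `norm.weight` becoming both `attn_norm.weight` and `ff_norm.weight`). Below is a simplified sketch of that renaming with hypothetical keys; all but one of the fan-out targets get a `deepcopy` so the renamed entries do not share tensors.

from copy import deepcopy
from typing import Any, Dict, Set

import torch


def rename_optim_state(state: Dict[str, Any], og_keys_to_new: Dict[str, Set[str]]) -> Dict[str, Any]:
    """Move each original key's optimizer state to its new key(s), copying when a key fans out."""
    for og_key, new_keys in og_keys_to_new.items():
        og_state = state.pop(og_key)
        for i, new_key in enumerate(new_keys):
            # The last target can take ownership of the original tensors; earlier ones get copies.
            state[new_key] = og_state if i == len(new_keys) - 1 else deepcopy(og_state)
    return state


# Hypothetical example: a fused norm weight that was split into two parameters.
og_keys_to_new = {
    "transformer.blocks.0.norm.weight": {
        "transformer.blocks.0.attn_norm.weight",
        "transformer.blocks.0.ff_norm.weight",
    }
}
state = {"transformer.blocks.0.norm.weight": {"exp_avg": torch.zeros(8), "exp_avg_sq": torch.zeros(8)}}
rename_optim_state(state, og_keys_to_new)
assert len(state) == 2 and state["transformer.blocks.0.attn_norm.weight"]["exp_avg"].shape == (8,)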
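
The block-group handling in `_make_state_dict_compatible` comes down to index arithmetic between the flat layout `transformer.blocks.{i}` and the grouped layout `transformer.block_groups.{g}.{j}`, where `i = g * block_group_size + j`, i.e. `g, j = divmod(i, block_group_size)`. A small illustrative sketch, independent of the model code, that regroups flat block keys:

import re
from typing import Dict, Iterable


def regroup_block_keys(keys: Iterable[str], block_group_size: int) -> Dict[str, str]:
    """Map flat `transformer.blocks.{i}.*` keys to grouped `transformer.block_groups.{g}.{j}.*` keys."""
    mapping: Dict[str, str] = {}
    for key in keys:
        if (m := re.match(r"transformer\.blocks\.(\d+)\.(.*)", key)) is not None:
            block_idx, rest = int(m.group(1)), m.group(2)
            group_idx, group_block_idx = divmod(block_idx, block_group_size)
            mapping[key] = f"transformer.block_groups.{group_idx}.{group_block_idx}.{rest}"
        else:
            mapping[key] = key  # non-block keys (embeddings, final norm, ...) are left alone
    return mapping


# With a group size of 4, flat block 9 lands in group 2, position 1.
mapping = regroup_block_keys(["transformer.blocks.9.attn_out.weight"], block_group_size=4)
assert mapping == {"transformer.blocks.9.attn_out.weight": "transformer.block_groups.2.1.attn_out.weight"}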