Merge pull request #347 from allenai/epwalsh/block-groups-load-fix
epwalsh/block groups load fix
epwalsh committed Nov 2, 2023
2 parents: 62fc2fe + 740294f; commit 026793e
Showing 3 changed files with 116 additions and 21 deletions.
65 changes: 58 additions & 7 deletions olmo/checkpoint.py
@@ -6,6 +6,7 @@
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass, field, replace
from functools import reduce
from pathlib import Path
@@ -609,19 +610,19 @@ def restore_checkpoint(
):
# Load model state.
log.info("Loading model state...")
- fsdp_model.load_state_dict(
- fsdp_model._fsdp_wrapped_module._make_state_dict_compatible(
- load_state_dict(load_path, "model.pt", local_cache=local_cache, map_location="cpu")
- )
+ state_dict_to_load, og_keys_to_new = fsdp_model._fsdp_wrapped_module._make_state_dict_compatible(
+ load_state_dict(load_path, "model.pt", local_cache=local_cache, map_location="cpu")
+ )
+ fsdp_model.load_state_dict(state_dict_to_load)

# Load optimizer state.
if load_optimizer_state:
log.info("Loading optimizer state...")
- optim_state_dict = load_state_dict(
- load_path, "optim.pt", local_cache=local_cache, map_location="cpu"
+ optim_state_dict_to_load = self._make_optim_state_dict_compatible(
+ load_state_dict(load_path, "optim.pt", local_cache=local_cache, map_location="cpu"),
+ og_keys_to_new,
)
- load_fsdp_optim_state(fsdp_model, optim, optim_state_dict)
+ load_fsdp_optim_state(fsdp_model, optim, optim_state_dict_to_load)

# Load other state.
try:
@@ -632,6 +633,56 @@ def restore_checkpoint(
barrier()
return trainer_state

def _make_optim_state_dict_compatible(
self, optim_state_dict: Dict[str, Any], og_keys_to_new: Dict[str, Set[str]]
) -> Dict[str, Any]:
# This state dict comes in two forms: one where the state keys are integers and one where the
# keys are fully qualified parameter names. The latter case is easier to deal with here so we
# first transform the integer key form into the FQN key form.
if isinstance(next(iter(optim_state_dict.keys())), int):
id_to_fqn: Dict[int, str] = {}
for group in optim_state_dict["param_groups"]:
new_param_names = []
for fqn, id in zip(group["param_names"], group["params"]):
fqn = fqn.replace("_fsdp_wrapped_module.", "")
id_to_fqn[id] = fqn
new_param_names.append(fqn)
group["param_names"] = new_param_names
group["params"] = new_param_names
for id in list(optim_state_dict["state"].keys()):
optim_state_dict["state"][id_to_fqn[id]] = optim_state_dict["state"].pop(id)
else:
# Otherwise we still want to clean up the param names to remove the "_fsdp_wrapped_module." prefix.
for group in optim_state_dict["param_groups"]:
group["param_names"] = [fqn.replace("_fsdp_wrapped_module.", "") for fqn in group["param_names"]]
group["params"] = [fqn.replace("_fsdp_wrapped_module.", "") for fqn in group["params"]]
assert group["param_names"] == group["params"]
for key in list(optim_state_dict["state"].keys()):
optim_state_dict["state"][key.replace("_fsdp_wrapped_module.", "")] = optim_state_dict[
"state"
].pop(key)

# Now we can transform the state dict by renaming parameters according to `og_keys_to_new`.
# First fix param names in the state.
for og_key, new_keys in og_keys_to_new.items():
og_state = optim_state_dict["state"].pop(og_key)
for i, new_key in enumerate(new_keys):
if i == len(new_keys) - 1:
optim_state_dict["state"][new_key] = og_state
else:
optim_state_dict["state"][new_key] = deepcopy(og_state)
# Now fix param names in the param groups.
for group in optim_state_dict["param_groups"]:
og_names = group["params"]
new_names = []
for og_key in og_names:
for new_key in og_keys_to_new[og_key]:
new_names.append(new_key)
group["params"] = new_names
group["param_names"] = new_names

return optim_state_dict

def load_checkpoint(
self,
load_path: PathOrStr,
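The comment inside _make_optim_state_dict_compatible notes that an optimizer state dict arrives in one of two layouts (integer parameter ids or fully qualified parameter names) and that a single original key may fan out to several new keys, in which case the state is deep-copied for all but one of them. The following minimal sketch reproduces that fan-out step on plain dictionaries; the parameter names, the tiny state entries, and the helper itself are illustrative only, not code from this repository.

from copy import deepcopy
from typing import Any, Dict, Set


def remap_optim_state(optim_state: Dict[str, Any], og_keys_to_new: Dict[str, Set[str]]) -> Dict[str, Any]:
    # Assumes the state is already keyed by cleaned-up parameter names (the integer-id
    # form would first be rewritten to names using the ids listed in the param groups).
    for og_key, new_keys in og_keys_to_new.items():
        og_entry = optim_state["state"].pop(og_key)
        for i, new_key in enumerate(new_keys):
            # Only one of the duplicates may reuse the original tensors; the rest get a deep copy.
            optim_state["state"][new_key] = og_entry if i == len(new_keys) - 1 else deepcopy(og_entry)
    for group in optim_state["param_groups"]:
        group["params"] = [new for og in group["params"] for new in og_keys_to_new[og]]
        group["param_names"] = list(group["params"])
    return optim_state


# Hypothetical example: a legacy "norm.weight" entry split into attn_norm / ff_norm.
optim_state = {
    "state": {"transformer.blocks.0.norm.weight": {"exp_avg": [0.0, 0.0]}},
    "param_groups": [
        {
            "params": ["transformer.blocks.0.norm.weight"],
            "param_names": ["transformer.blocks.0.norm.weight"],
        }
    ],
}
mapping = {
    "transformer.blocks.0.norm.weight": {
        "transformer.blocks.0.attn_norm.weight",
        "transformer.blocks.0.ff_norm.weight",
    }
}
print(sorted(remap_optim_state(optim_state, mapping)["state"]))
# ['transformer.blocks.0.attn_norm.weight', 'transformer.blocks.0.ff_norm.weight']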
55 changes: 45 additions & 10 deletions olmo/model.py
@@ -9,6 +9,7 @@
import logging
import math
from abc import abstractmethod
from collections import defaultdict
from collections.abc import MutableMapping
from functools import partial
from typing import (
@@ -19,6 +20,7 @@
NamedTuple,
Optional,
Sequence,
Set,
Tuple,
cast,
)
@@ -1530,7 +1532,7 @@ def from_checkpoint(
# Load state dict directly to target device.
state_dict_path = resource_path(checkpoint_dir, "model.pt")
state_dict = torch.load(state_dict_path, map_location="cpu")
- model.load_state_dict(model._make_state_dict_compatible(state_dict))
+ model.load_state_dict(model._make_state_dict_compatible(state_dict)[0])
model = model.to(torch.device(device))
else:
from .checkpoint import load_model_state
@@ -1545,27 +1547,46 @@ def from_checkpoint(

return model.eval()

- def _make_state_dict_compatible(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ def _make_state_dict_compatible(
+ self, state_dict: Dict[str, torch.Tensor]
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Set[str]]]:
"""
Handles some cases where the state dict is valid yet may need to be transformed in order to
be loaded.
This modifies the state dict in-place and also returns it, along with a mapping of original key
names to new key names in cases where the keys were simply renamed. That mapping can be used
to make a corresponding optimizer state dict compatible as well.
"""
import re
from fnmatch import fnmatch

new_keys_to_og_keys: Dict[str, str] = {}

# Remove "_fsdp_wrapped_module." prefix from all keys. We don't want this prefix when the model is
# not wrapped in FSDP. And when the model is wrapped in FSDP, loading this state dict will still work
# fine without the prefixes. This also simplifies the other steps below.
for key in list(state_dict.keys()):
- state_dict[key.replace("_fsdp_wrapped_module.", "")] = state_dict.pop(key)
+ state_dict[(new_key := key.replace("_fsdp_wrapped_module.", ""))] = state_dict.pop(key)
+ new_keys_to_og_keys[new_key] = key

# For backwards compatibility prior to fixing https://github.com/allenai/LLM/issues/222
if self.config.block_type == BlockType.sequential:
for key in list(state_dict.keys()):
if fnmatch(key, "transformer.*.norm.weight"):
tensor = state_dict.pop(key)
- state_dict[key.replace("norm.weight", "attn_norm.weight")] = tensor
- state_dict[key.replace("norm.weight", "ff_norm.weight")] = tensor.clone()
+ state_dict[(new_key := key.replace("norm.weight", "attn_norm.weight"))] = tensor
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
+ state_dict[(new_key := key.replace("norm.weight", "ff_norm.weight"))] = tensor.clone()
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
+ del new_keys_to_og_keys[key]
elif fnmatch(key, "transformer.*.norm.bias"):
tensor = state_dict.pop(key)
- state_dict[key.replace("norm.bias", "attn_norm.bias")] = tensor
- state_dict[key.replace("norm.bias", "ff_norm.bias")] = tensor.clone()
+ state_dict[(new_key := key.replace("norm.bias", "attn_norm.bias"))] = tensor
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
+ state_dict[(new_key := key.replace("norm.bias", "ff_norm.bias"))] = tensor.clone()
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
+ del new_keys_to_og_keys[key]

# For loading a state dict that was saved with a different `block_group_size`.
if "transformer.block_groups.0.0.attn_out.weight" in state_dict.keys():
@@ -1587,8 +1608,13 @@ def _make_state_dict_compatible(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
group_idx, group_block_idx = int(m.group(1)), int(m.group(2))
block_idx = (group_idx * state_dict_block_group_size) + group_block_idx
state_dict[
- key.replace(f"block_groups.{group_idx}.{group_block_idx}.", f"blocks.{block_idx}.")
+ (
+ new_key := key.replace(
+ f"block_groups.{group_idx}.{group_block_idx}.", f"blocks.{block_idx}."
+ )
+ )
] = state_dict.pop(key)
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys.pop(key)

if self.config.block_group_size > 1:
# Group the state dict blocks into the right block size.
@@ -1600,7 +1626,16 @@ def _make_state_dict_compatible(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
block_idx % self.config.block_group_size,
)
state_dict[
- key.replace(f"blocks.{block_idx}.", f"block_groups.{group_idx}.{group_block_idx}.")
+ (
+ new_key := key.replace(
+ f"blocks.{block_idx}.", f"block_groups.{group_idx}.{group_block_idx}."
+ )
+ )
] = state_dict.pop(key)
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys.pop(key)

og_keys_to_new: Dict[str, Set[str]] = defaultdict(set)
for new_key, og_key in new_keys_to_og_keys.items():
og_keys_to_new[og_key].add(new_key)

- return state_dict
+ return state_dict, og_keys_to_new
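The block-group handling above rests on one index identity: a flat layer index block_idx corresponds to group block_idx // block_group_size and slot block_idx % block_group_size within that group, and conversely block_idx = group_idx * block_group_size + group_block_idx. Below is a small illustrative helper (not part of the repository) that applies the grouping direction to key names shaped like the ones in the diff.

import re
from typing import Dict


def group_block_keys(state_dict: Dict[str, object], block_group_size: int) -> Dict[str, object]:
    """Rename flat 'transformer.blocks.{i}.' keys into
    'transformer.block_groups.{g}.{j}.' keys, where i = g * block_group_size + j."""
    out: Dict[str, object] = {}
    for key, value in state_dict.items():
        m = re.match(r"transformer\.blocks\.(\d+)\.(.+)", key)
        if m is None:
            out[key] = value  # keys outside the block stack pass through unchanged
            continue
        block_idx, rest = int(m.group(1)), m.group(2)
        group_idx, group_block_idx = divmod(block_idx, block_group_size)
        out[f"transformer.block_groups.{group_idx}.{group_block_idx}.{rest}"] = value
    return out


# With block_group_size=3, flat block 7 lands in group 2, slot 1 (7 = 2 * 3 + 1).
print(group_block_keys({"transformer.blocks.7.attn_out.weight": None}, block_group_size=3))
# {'transformer.block_groups.2.1.attn_out.weight': None}

The ungrouping direction handled earlier in the diff is simply the inverse rename, recovering block_idx as group_idx * state_dict_block_group_size + group_block_idx.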
17 changes: 13 additions & 4 deletions tests/model_test.py
@@ -439,12 +439,21 @@ def test_block_groups():
model_without_block_groups = Olmo(ModelConfig(d_model=128, n_heads=2, n_layers=9, block_group_size=1)).eval()

# We should be able to load the state dict from one model into the other, and vice-versa.
- model_with_block_groups.load_state_dict(
- model_with_block_groups._make_state_dict_compatible(model_without_block_groups.state_dict())
+ state_dict_to_load, og_keys_to_new_keys = model_with_block_groups._make_state_dict_compatible(
+ model_without_block_groups.state_dict()
)
- model_without_block_groups.load_state_dict(
- model_without_block_groups._make_state_dict_compatible(model_with_block_groups.state_dict())
+ assert og_keys_to_new_keys["transformer.blocks.0.attn_out.weight"] == {
+ "transformer.block_groups.0.0.attn_out.weight"
+ }
+ model_with_block_groups.load_state_dict(state_dict_to_load)
+
+ state_dict_to_load, og_keys_to_new_keys = model_without_block_groups._make_state_dict_compatible(
+ model_with_block_groups.state_dict()
+ )
+ assert og_keys_to_new_keys["transformer.block_groups.0.0.attn_out.weight"] == {
+ "transformer.blocks.0.attn_out.weight"
+ }
+ model_without_block_groups.load_state_dict(state_dict_to_load)

# Check that output is exactly the same.
input_ids = torch.randint(0, model_with_block_groups.config.vocab_size, (2, 16))
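Taken together, the three files implement one flow: _make_state_dict_compatible now reports how it renamed keys, and restore_checkpoint reuses that report to rewrite the optimizer state before loading it, so the optimizer entries stay aligned with the renamed parameters (for example when resuming with a different block_group_size). The sketch below shows only that ordering; model, optim, checkpointer, and the state-dict arguments are placeholders, and the real restore_checkpoint loads the optimizer through load_fsdp_optim_state rather than a plain load_state_dict call.

from typing import Any, Dict


def load_with_key_remap(model, optim, checkpointer, model_sd: Dict[str, Any], optim_sd: Dict[str, Any]) -> None:
    # 1. Rewrite the model state dict and keep the original-key -> new-keys mapping.
    state_dict_to_load, og_keys_to_new = model._make_state_dict_compatible(model_sd)
    model.load_state_dict(state_dict_to_load)

    # 2. Apply the same renaming to the optimizer state so its entries match the
    #    parameter names the model now exposes, then load it.
    optim_sd_to_load = checkpointer._make_optim_state_dict_compatible(optim_sd, og_keys_to_new)
    optim.load_state_dict(optim_sd_to_load)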
