Add Phi-3-mini-4k-instruct checkpoint #1341

Merged — 36 commits, merged Jul 1, 2024

Commits (36)
93f3024
Add phi-3 checkpoint
rasbt Apr 23, 2024
1012eaf
progress
rasbt Apr 23, 2024
581e27f
weight loading works
rasbt Apr 23, 2024
40fe01f
Convert Phi3 qkv into an interleaved one
Andrei-Aksionov Apr 25, 2024
0322ecd
Config: Phi3 doesn't use parallel residual
Andrei-Aksionov Apr 25, 2024
7f33850
Fix layer shapes in Phi3MLP
Andrei-Aksionov Apr 25, 2024
1b217ba
Config: update vocab size
Andrei-Aksionov Apr 25, 2024
6fc4a7c
Add prompt
Andrei-Aksionov Apr 25, 2024
ba1c930
Merge branch 'main' into phi-3-checkpoint
Andrei-Aksionov Apr 25, 2024
2ee1e0d
Add test for Phi3 model
Andrei-Aksionov Apr 25, 2024
29760ab
Update litgpt/prompts.py
rasbt May 3, 2024
6c4cd25
Merge branch 'main' into phi-3-checkpoint
rasbt May 3, 2024
fbc45b4
Merge branch 'main' into phi-3-checkpoint
Andrei-Aksionov Jun 26, 2024
efb8388
Fix prompt
Andrei-Aksionov Jun 26, 2024
7f092fa
The prompt has been changed. Update it
Andrei-Aksionov Jun 26, 2024
a2acd37
A workaround for a Phi-3 tokenizer
Andrei-Aksionov Jun 27, 2024
ef21d37
Convert in copy_weihght_phi without Phi3MLP
Andrei-Aksionov Jun 27, 2024
aa184e7
Config: Phi3MLP -> LlaMAMLP
Andrei-Aksionov Jun 27, 2024
4f941bb
test_model.py: add test for Phi-3
Andrei-Aksionov Jun 27, 2024
3bd0692
model.py: drop Phi3MLP
Andrei-Aksionov Jun 27, 2024
9583cd7
convert_hf: copy_weight_llama without Phi3 related code
Andrei-Aksionov Jun 27, 2024
1c661be
Merge branch 'main' into phi-3-checkpoint
Andrei-Aksionov Jun 27, 2024
c25e533
test_model.py: update test for Phi-3
Andrei-Aksionov Jun 27, 2024
6e484eb
test_covert_hf: add test for qkv_reassemble
Andrei-Aksionov Jun 27, 2024
c8a1e03
Update test_tokenzer to match AutoTokenizers
Andrei-Aksionov Jun 27, 2024
39614e7
Merge branch 'main' into phi-3-checkpoint
Andrei-Aksionov Jun 28, 2024
81e56b6
convert_lit: add Phi-3 code
Andrei-Aksionov Jun 28, 2024
b483ca2
test_convert_lit: prettify test for qkv_split
Andrei-Aksionov Jun 28, 2024
82b4124
Update test_prompts.py
Andrei-Aksionov Jun 28, 2024
d252505
Add Phi-3-mini to the list of supported models
Andrei-Aksionov Jun 28, 2024
0eb288e
Update README.md
rasbt Jun 28, 2024
a2683d7
Merge branch 'main' into phi-3-checkpoint
rasbt Jul 1, 2024
a923c5e
Update tutorials/download_model_weights.md
rasbt Jul 1, 2024
bcbddec
Update tutorials/download_model_weights.md
rasbt Jul 1, 2024
9dc0330
Apply Sebastian's suggestion: model_name.lower()...
Andrei-Aksionov Jul 1, 2024
ceae946
Merge branch 'main' into phi-3-checkpoint
rasbt Jul 1, 2024
Files changed (changes from all commits)
3 changes: 2 additions & 1 deletion README.md
@@ -129,7 +129,8 @@ Every model is written from scratch to maximize performance and remove layers of
| Mistral | 7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/announcing-mistral-7b/) |
| Nous-Hermes | 7B, 13B, 70B | NousResearch | [Org page](https://huggingface.co/NousResearch) |
| OpenLLaMA | 3B, 7B, 13B | OpenLM Research | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) |
| Phi | 1.3B, 2.7B | Microsoft Research | [Li et al. 2023](https://arxiv.org/abs/2309.05463) |
| Phi 1.5 & 2 | 1.3B, 2.7B | Microsoft Research | [Li et al. 2023](https://arxiv.org/abs/2309.05463) |
| Phi 3 | 3.8B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2404.14219) |
| Platypus | 7B, 13B, 70B | Lee et al. | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317) |
| Pythia | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | EleutherAI | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) |
| RedPajama-INCITE | 3B, 7B | Together | [Together 2023](https://together.ai/blog/redpajama-models-v1) |
16 changes: 16 additions & 0 deletions litgpt/config.py
@@ -1444,6 +1444,22 @@ def norm_class(self) -> Type:
lm_head_bias=True,
gelu_approximate="tanh",
),
# https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json
dict(
name="Phi-3-mini-4k-instruct",
hf_config=dict(org="microsoft", name="Phi-3-mini-4k-instruct"),
vocab_size=32000,
padded_vocab_size=32064,
block_size=4096,
n_embd=3072,
n_layer=32,
rotary_percentage=1.0,
bias=False,
norm_class_name="RMSNorm",
intermediate_size=8192,
mlp_class_name="LLaMAMLP",
parallel_residual=False,
),
]
configs.extend(phi)

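A minimal sketch for sanity-checking the new entry (it assumes `Config.from_name` resolves configs added to `litgpt/config.py`, as it does for the other models in the registry):

```python
from litgpt import Config

# Illustrative check that the Phi-3 entry resolves by name and carries the
# values added above.
config = Config.from_name("Phi-3-mini-4k-instruct")
print(config.padded_vocab_size)  # 32064
print(config.mlp_class_name)     # LLaMAMLP
print(config.parallel_residual)  # False
```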
9 changes: 9 additions & 0 deletions litgpt/prompts.py
@@ -309,6 +309,12 @@ def apply(self, prompt: str, **kwargs: str) -> str:
return f"Instruct: {prompt}\nOutput:"


class Phi3(PromptStyle):
def apply(self, prompt: str, **kwargs: str) -> str:
return f'<s><|user|>\n{prompt}<|end|>\n<|assistant|>\n'



class TinyLlama(PromptStyle):
def apply(self, prompt: str, **kwargs: str) -> str:
return (
@@ -352,6 +358,7 @@ def apply(self, prompt: str, **kwargs: str) -> str:
"codellama": CodeLlama,
"phi-1": Phi1,
"phi-2": Phi2,
"phi-3": Phi3,
"tinyllama": TinyLlama,
"gemma": Gemma,
"h2oai": H2Oai,
@@ -392,6 +399,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
return Phi1()
if re.search("phi-2", model_name):
return Phi2()
if re.search("Phi-3", model_name):
return Phi3()
if re.search(r"tiny-llama.*chat", model_name):
return TinyLlama()
if re.search(r"(Code)?Gemma.*-it", model_name):
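To see the new prompt style in action, a short sketch using the `model_name_to_prompt_style` helper touched above (the sample prompt is illustrative):

```python
from litgpt.prompts import Phi3, model_name_to_prompt_style

style = model_name_to_prompt_style("Phi-3-mini-4k-instruct")
assert isinstance(style, Phi3)
print(style.apply("What is the capital of France?"))
# <s><|user|>
# What is the capital of France?<|end|>
# <|assistant|>
```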
56 changes: 46 additions & 10 deletions litgpt/scripts/convert_hf_checkpoint.py
@@ -12,12 +12,7 @@
from lightning.fabric.utilities.load import _NotYetLoadedTensor as NotYetLoadedTensor

from litgpt import Config
from litgpt.utils import (
extend_checkpoint_dir,
lazy_load,
incremental_save,
save_config
)
from litgpt.utils import extend_checkpoint_dir, incremental_save, lazy_load, save_config


def copy_weights_gpt_neox(
@@ -235,13 +230,36 @@ def copy_weights_phi(
"lm_head.bias": "lm_head.bias",
}

if config.name.startswith("Phi-3"):
weight_map.update(
{
"model.layers.{}.self_attn.qkv_proj.weight": "transformer.h.{}.attn.attn.weight",
"model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight",
'model.layers.{}.post_attention_layernorm.weight': "transformer.h.{}.norm_2.weight",
"model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight",
"model.norm.weight": "transformer.ln_f.weight",
}
)

for name, param in hf_weights.items():
if name.startswith("model.layers."):
from_name, l = layer_template(name, 2)
qkv = qkv_weights.setdefault(l, defaultdict(dict))
if "qkv_proj" in from_name:
weight = load_param(param, f"layer {l} qkv", dtype)
weight = qkv_reassemble(weight, config)
to_name = weight_map[from_name].format(l)
state_dict[to_name] = weight
continue
if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")):
weight_name, weight_type = from_name.split(".")[-2:]
qkv[weight_type][weight_name] = param
elif from_name.endswith("gate_up_proj.weight"):
weight = load_param(param, f"layer {l} gate_up_proj", dtype)
fc_1, fc_2 = weight.chunk(2, dim=0)
state_dict[f"transformer.h.{l}.mlp.fc_1.weight"] = fc_1
state_dict[f"transformer.h.{l}.mlp.fc_2.weight"] = fc_2
continue
to_name = weight_map[from_name]
if to_name is None:
continue
@@ -272,6 +290,24 @@ def copy_weights_phi(
del qkv_weights[i][weight_type]


def qkv_reassemble(param: Union[torch.Tensor, NotYetLoadedTensor], config: Config) -> torch.Tensor:
"""Reassemble from a normal to an interleaved placement in a QKV matrix.
[Q, Q, ..., K, K, ..., V, V, ...] --> [Q, K, V, Q, K, V, ...]
"""
q, k, v = param.split(
(
config.n_head * config.head_size,
config.n_query_groups * config.head_size,
config.n_query_groups * config.head_size,
)
)
qs = q.split(config.n_head // config.n_query_groups * config.head_size)
ks = k.split(config.head_size)
vs = v.split(config.head_size)
interleaved = [t for group in zip(qs, ks, vs) for t in group]
return torch.cat(interleaved)


def layer_template(layer_name: str, idx: int) -> Tuple[str, int]:
split = layer_name.split(".")
number = int(split[idx])
@@ -321,14 +357,14 @@ def convert_hf_checkpoint(

if "falcon" in model_name:
copy_fn = partial(copy_weights_falcon, model_name)
elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"):
elif model_name.lower().startswith("phi"):
# holder to reconstitute the split q, k, v
qkv_weights = {}
copy_fn = partial(copy_weights_hf_llama, config, qkv_weights)
elif "phi" in model_name:
copy_fn = partial(copy_weights_phi, config, qkv_weights)
elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"):
# holder to reconstitute the split q, k, v
qkv_weights = {}
copy_fn = partial(copy_weights_phi, config, qkv_weights)
copy_fn = partial(copy_weights_hf_llama, config, qkv_weights)
else:
copy_fn = copy_weights_gpt_neox

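Since `convert_lit_checkpoint.py` (below) carries the matching `qkv_split`, the two layouts should round-trip. A small sketch under that assumption, reusing the tiny GQA shape from the tests:

```python
import torch
from litgpt import Config
from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble
from litgpt.scripts.convert_lit_checkpoint import qkv_split

# 4 heads, 2 query groups, head_size 1: the HF layout stacks [q, q, q, q, k, k, v, v].
config = Config(n_embd=4, n_head=4, n_query_groups=2)
hf_qkv = torch.arange(8 * 4).reshape(8, 4)

lit_qkv = qkv_reassemble(hf_qkv, config)  # HF layout -> interleaved litgpt layout
q, k, v = qkv_split(lit_qkv, config)      # litgpt layout -> separate q, k, v
assert torch.equal(torch.cat((q, k, v)), hf_qkv)
```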
65 changes: 48 additions & 17 deletions litgpt/scripts/convert_lit_checkpoint.py
@@ -1,6 +1,7 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import gc
from collections import defaultdict
from functools import partial
from pathlib import Path
from pprint import pprint
@@ -11,11 +12,7 @@

from litgpt import Config
from litgpt.scripts.convert_hf_checkpoint import layer_template, load_param
from litgpt.utils import (
extend_checkpoint_dir,
incremental_save,
lazy_load
)
from litgpt.utils import extend_checkpoint_dir, incremental_save, lazy_load


def copy_weights_falcon(
@@ -192,31 +189,65 @@ def copy_weights_phi(
"lm_head.bias": "lm_head.bias",
}

if config.name.startswith("Phi-3"):
weight_map.update(
{
"transformer.h.{}.attn.attn.weight": "model.layers.{}.self_attn.qkv_proj.weight",
"transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight",
"transformer.h.{}.norm_2.weight": 'model.layers.{}.post_attention_layernorm.weight',
"transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight",
"transformer.ln_f.weight": "model.norm.weight",
}
)
gate_up_proj_weights = defaultdict(dict)


for name, param in lit_weights.items():
if name.endswith((".attn.attn.weight", ".attn.attn.bias")):
from_name, l = layer_template(name, 2)
weight_type = name.split(".")[-1] # weight or bias
q = f"model.layers.{l}.self_attn.q_proj.{weight_type}"
k = f"model.layers.{l}.self_attn.k_proj.{weight_type}"
v = f"model.layers.{l}.self_attn.v_proj.{weight_type}"
from_name, l_idx = layer_template(name, 2)
qkv = load_param(param, name, None)
qp, kp, vp = qkv_split(qkv, config)
for to_name, param in zip((q, k, v), (qp, kp, vp)):
if config.name.startswith("Phi-3"):
qkv_reassembled = torch.concat([qp, kp, vp], dim=0)
to_name = weight_map[from_name].format(l_idx)
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param
qkv_reassembled = saver.store_early(qkv_reassembled)
state_dict[to_name] = qkv_reassembled
else:
weight_type = name.split(".")[-1] # weight or bias
q = f"model.layers.{l_idx}.self_attn.q_proj.{weight_type}"
k = f"model.layers.{l_idx}.self_attn.k_proj.{weight_type}"
v = f"model.layers.{l_idx}.self_attn.v_proj.{weight_type}"
for to_name, param in zip((q, k, v), (qp, kp, vp)):
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param
elif name.endswith((".fc_1.weight", ".fc_2.weight")):
from_name, l_idx = layer_template(name, 2)
weight = load_param(param, name, None)
weight_name = name.split(".")[-2]
gate_up_proj_weights[l_idx][weight_name] = weight
else:
if "transformer.h" in name:
from_name, l = layer_template(name, 2)
from_name, l_idx = layer_template(name, 2)
to_name = weight_map[from_name]
to_name = to_name.format(l)
to_name = to_name.format(l_idx)
else:
to_name = weight_map[name]
param = load_param(param, name, None)
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param

if config.name.startswith("Phi-3"):
for i in list(gate_up_proj_weights):
fc_1_weight = gate_up_proj_weights[i]["fc_1"]
fc_2_weight = gate_up_proj_weights[i]["fc_2"]
weight = torch.concat([fc_1_weight, fc_2_weight], dim=0)
layer_name = f"model.layers.{i}.mlp.gate_up_proj.weight"
state_dict[layer_name] = weight
del gate_up_proj_weights[i]


def qkv_split(
param: Union[torch.Tensor, NotYetLoadedTensor], config: Config
@@ -256,11 +287,11 @@ def convert_lit_checkpoint(checkpoint_dir: Path, output_dir: Path) -> None:

if "falcon" in config.name:
copy_fn = partial(copy_weights_falcon, config.name)
elif config.name.lower().startswith("phi"):
copy_fn = partial(copy_weights_phi, config)
elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"):
untie_weights = "Gemma" in config.name
copy_fn = partial(copy_weights_llama, config, untie_weights=untie_weights)
elif "phi" in config.name:
copy_fn = partial(copy_weights_phi, config)
else:
copy_fn = copy_weights_gpt_neox

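The MLP weights follow the same pattern in reverse: `copy_weights_phi` in `convert_hf_checkpoint.py` chunks the fused `gate_up_proj` matrix into `fc_1` and `fc_2`, and the code above concatenates them back. A minimal sketch with illustrative shapes (for Phi-3-mini the real fused matrix is `(2 * 8192, 3072)`):

```python
import torch

gate_up_proj = torch.randn(2 * 8, 4)  # small stand-in for the fused HF matrix

# HF -> litgpt: split the fused matrix into the fc_1 and fc_2 halves.
fc_1, fc_2 = gate_up_proj.chunk(2, dim=0)

# litgpt -> HF: concatenate them back in the same order.
restored = torch.cat((fc_1, fc_2), dim=0)
assert torch.equal(restored, gate_up_proj)
```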
7 changes: 7 additions & 0 deletions litgpt/tokenizer.py
@@ -13,6 +13,7 @@ def __init__(self, checkpoint_dir: Union[Path, str]) -> None:
if not checkpoint_dir.exists():
raise NotADirectoryError(f"The checkpoint directory does not exist: {str(checkpoint_dir)}")

self.model_name = checkpoint_dir.stem
self.use_bos = self.check_if_bos_token_used(checkpoint_dir)
self.bos_id = None
self.eos_id = None
@@ -114,4 +115,10 @@ def encode(

def decode(self, tensor: torch.Tensor) -> str:
tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist()
# The Phi-3 tokenizer strips leading spaces when decoding a single token at a time.
# https://github.com/huggingface/transformers/issues/31643
if self.model_name.startswith("Phi-3") and len(tokens) == 1:
dummy_token_id = 33 # \x1e
dummy_token = self.processor.decode([dummy_token_id])
return self.processor.decode([dummy_token_id] + tokens).replace(dummy_token, "")
return self.processor.decode(tokens)
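A standalone illustration of why the dummy-token trick works; the fake processor and vocabulary below are made up, only the decode pattern mirrors the change above:

```python
class FakeProcessor:
    """Stand-in for a tokenizer that drops the leading space when a single
    token is decoded on its own (the behaviour reported for Phi-3)."""

    VOCAB = {33: "\x1e", 100: " world"}

    def decode(self, ids):
        text = "".join(self.VOCAB[i] for i in ids)
        return text.lstrip(" ") if len(ids) == 1 else text


proc = FakeProcessor()
print(repr(proc.decode([100])))  # 'world'  -- leading space lost

dummy = proc.decode([33])
print(repr(proc.decode([33, 100]).replace(dummy, "")))  # ' world' -- space preserved
```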
102 changes: 102 additions & 0 deletions tests/test_convert_hf_checkpoint.py
@@ -116,3 +116,105 @@ def test_convert_hf_checkpoint(tmp_path):
# ensure that the config dict can be loaded
config = Config.from_file(tmp_path / "model_config.yaml")
assert isinstance(config, Config)


def test_qkv_reassemble():
from litgpt import Config
from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble

# MHA
config = Config(n_embd=4, n_head=4)
qkv = torch.tensor(
[
[0, 1, 2, 3], # query
[4, 5, 6, 7], # query
[8, 9, 10, 11], # query
[12, 13, 14, 15], # query
[16, 17, 18, 19], # key
[20, 21, 22, 23], # key
[24, 25, 26, 27], # key
[28, 29, 30, 31], # key
[32, 33, 34, 35], # value
[36, 37, 38, 39], # value
[40, 41, 42, 43], # value
[44, 45, 46, 47], # value
]
)
qkv_interleaved = qkv_reassemble(qkv, config)
torch.testing.assert_close(
qkv_interleaved,
torch.tensor(
[
[0, 1, 2, 3], # query
[16, 17, 18, 19], # key
[32, 33, 34, 35], # value
[4, 5, 6, 7], # query
[20, 21, 22, 23], # key
[36, 37, 38, 39], # value
[8, 9, 10, 11], # query
[24, 25, 26, 27], # key
[40, 41, 42, 43], # value
[12, 13, 14, 15], # query
[28, 29, 30, 31], # key
[44, 45, 46, 47], # value
]
),
)

# GQA
config = Config(n_embd=4, n_head=4, n_query_groups=2)
qkv = torch.tensor(
[
[0, 1, 2, 3], # query
[4, 5, 6, 7], # query
[8, 9, 10, 11], # query
[12, 13, 14, 15], # query
[16, 17, 18, 19], # key
[20, 21, 22, 23], # key
[24, 25, 26, 27], # value
[28, 29, 30, 31], # value
]
)
qkv_interleaved = qkv_reassemble(qkv, config)
torch.testing.assert_close(
qkv_interleaved,
torch.tensor(
[
[0, 1, 2, 3], # query
[4, 5, 6, 7], # query
[16, 17, 18, 19], # key
[24, 25, 26, 27], # value
[8, 9, 10, 11], # query
[12, 13, 14, 15], # query
[20, 21, 22, 23], # key
[28, 29, 30, 31], # value
]
),
)

# MQA
config = Config(n_embd=4, n_head=4, n_query_groups=1)
qkv = torch.tensor(
[
[0, 1, 2, 3], # query
[4, 5, 6, 7], # query
[8, 9, 10, 11], # query
[12, 13, 14, 15], # query
[16, 17, 18, 19], # key
[20, 21, 22, 23], # value
]
)
qkv_interleaved = qkv_reassemble(qkv, config)
torch.testing.assert_close(
qkv_interleaved,
torch.tensor(
[
[0, 1, 2, 3], # query
[4, 5, 6, 7], # query
[8, 9, 10, 11], # query
[12, 13, 14, 15], # query
[16, 17, 18, 19], # key
[20, 21, 22, 23], # value
]
),
)