Add support for Llama 3 (and Llama-2-70b-hf) #549

Merged · 7 commits merged on Apr 24, 2024
2 changes: 1 addition & 1 deletion transformer_lens/components.py
@@ -961,7 +961,7 @@ def __init__(
attn_type: str = "global",
layer_id: Union[int, None] = None,
):
- """Grouped Query Attention Block - see https://arxiv.org/abs/2305.13245v2 for details.
+ """Grouped Query Attention Block - see https://arxiv.org/abs/2305.13245 for details.
Similar to regular attention, W_Q, W_K, and W_V all have shape [head_index, d_model, d_head] and W_O has shape [head_index, d_head, d_model].
However, under the hood the key and value weights _W_K and _W_V are stored with shape [n_key_value_heads, d_model, d_head] and are expanded when the corresponding properties' getter is called.
Similarly, during a forward pass, initially K and V are kept in shapes [batch, pos, n_key_value_heads, d_head] and will only be expanded to shapes [batch, pos, n_heads, d_head]
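For readers new to Grouped Query Attention, the sketch below illustrates the expansion the docstring describes: keys and values are stored once per key/value head and broadcast up to the full number of query heads. This is a minimal illustration, assuming a repeat_interleave-style expansion and head counts taken from the Llama 3 configs added in this PR; the helper name expand_kv is hypothetical and not part of the TransformerLens API.

import torch

# Head counts from the Llama-3-8B config added below: 32 query heads, 8 KV heads.
n_heads, n_key_value_heads, d_model, d_head = 32, 8, 4096, 128
repeat = n_heads // n_key_value_heads  # each KV head is shared by 4 query heads

def expand_kv(w: torch.Tensor) -> torch.Tensor:
    """Expand [n_key_value_heads, d_model, d_head] -> [n_heads, d_model, d_head]."""
    return torch.repeat_interleave(w, repeat, dim=0)

_W_K = torch.randn(n_key_value_heads, d_model, d_head)  # compact storage
W_K = expand_kv(_W_K)                                   # expanded view, shape [32, 4096, 128]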
46 changes: 44 additions & 2 deletions transformer_lens/loading_from_pretrained.py
@@ -121,7 +121,10 @@
"CodeLlama-7b-hf",
"CodeLlama-7b-Python-hf",
"CodeLlama-7b-Instruct-hf",
- # TODO Llama-2-70b-hf requires Grouped-Query Attention, see the paper https://arxiv.org/pdf/2307.09288.pdf
+ "meta-llama/Meta-Llama-3-8B",
+ "meta-llama/Meta-Llama-3-8B-Instruct",
+ "meta-llama/Meta-Llama-3-70B",
+ "meta-llama/Meta-Llama-3-70B-Instruct",
"Baidicoot/Othello-GPT-Transformer-Lens",
"bert-base-cased",
"roneneldan/TinyStories-1M",
@@ -601,7 +604,7 @@
"llama-30b-hf",
"llama-65b-hf",
]
- """Official model names for models that not hosted on HuggingFace."""
+ """Official model names for models not hosted on HuggingFace."""

# Sets a default model alias, by convention the first one in the model alias table, else the official name if it has no aliases
DEFAULT_MODEL_ALIASES = [
@@ -665,6 +668,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
**kwargs,
)
architecture = hf_config.architectures[0]

if official_model_name.startswith(
("llama-7b", "meta-llama/Llama-2-7b")
): # same architecture for LLaMA and Llama-2
@@ -781,6 +785,44 @@ def convert_hf_model_config(model_name: str, **kwargs):
"final_rms": True,
"gated_mlp": True,
}
+ elif "Meta-Llama-3-8B" in official_model_name:
+ cfg_dict = {
+ "d_model": 4096,
+ "d_head": 128,
+ "n_heads": 32,
+ "d_mlp": 14336,
+ "n_layers": 32,
+ "n_ctx": 8192,
+ "eps": 1e-5,
+ "d_vocab": 128256,
+ "act_fn": "silu",
+ "n_key_value_heads": 8,
+ "normalization_type": "RMS",
+ "positional_embedding_type": "rotary",
+ "rotary_adjacent_pairs": False,
+ "rotary_dim": 128,
+ "final_rms": True,
+ "gated_mlp": True,
+ }
+ elif "Meta-Llama-3-70B" in official_model_name:
+ cfg_dict = {
+ "d_model": 8192,
+ "d_head": 128,
+ "n_heads": 64,
+ "d_mlp": 28672,
+ "n_layers": 80,
+ "n_ctx": 8192,
+ "eps": 1e-5,
+ "d_vocab": 128256,
+ "act_fn": "silu",
+ "n_key_value_heads": 8,
+ "normalization_type": "RMS",
+ "positional_embedding_type": "rotary",
+ "rotary_adjacent_pairs": False,
+ "rotary_dim": 128,
+ "final_rms": True,
+ "gated_mlp": True,
+ }
elif architecture == "GPTNeoForCausalLM":
cfg_dict = {
"d_model": hf_config.hidden_size,
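With the config entries above registered, the new checkpoints load by their HuggingFace names. A usage sketch, assuming access to the gated meta-llama repositories has been granted and the local environment is authenticated (e.g. via huggingface-cli login); the prompt string is arbitrary.

from transformer_lens import HookedTransformer

# The Llama 3 weights are gated on HuggingFace; accept the license for
# meta-llama/Meta-Llama-3-8B and authenticate before loading.
model = HookedTransformer.from_pretrained("meta-llama/Meta-Llama-3-8B")
logits, cache = model.run_with_cache("The capital of France is")
print(logits.shape)  # [batch, pos, d_vocab], with d_vocab = 128256 per the config above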