In [2]:
import os

# Point both Hugging Face cache variables at the same directory. This is done
# before `transformers` is imported below.
hf_cache_dir = os.path.join("/home/", "huggingface")
for cache_var in ("TRANSFORMERS_CACHE", "HF_HUB_CACHE"):
    os.environ[cache_var] = hf_cache_dir

# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")


  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 2/2 [00:06<00:00,  3.12s/it]


In [4]:
import torch

# Architecture hyper-parameters of the checkpoint being converted.
params = {
    "dim": 4096,
    "n_layers": 32,
    "head_dim": 128,
    "hidden_dim": 14336,
    "n_heads": 32,
    "n_kv_heads": 8,
    "norm_eps": 1e-05,
    "sliding_window": 4096,
    "vocab_size": 32000
}

# Derived sizes. With a single shard the per-shard counts equal the totals.
num_shards = 1
n_layers = params["n_layers"]
n_heads = params["n_heads"]
n_heads_per_shard = n_heads // num_shards
dim = params["dim"]
dims_per_head = dim // n_heads

# `params` has no "rope_theta" entry, so the fallback is what is actually used.
# It must be 10000.0 (the Mistral-7B-v0.1 value); the previous fallback of
# 100000.0 was off by a factor of ten.
base = params.get("rope_theta", 10000.0)
# One inverse frequency per pair of head dimensions.
inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
max_position_embeddings = 4096 * 8

if "n_kv_heads" in params:
    num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA
    num_local_key_value_heads = num_key_value_heads // num_shards
    # Row count of the key/value projection held by this shard.
    key_value_dim = dims_per_head * num_local_key_value_heads
else:
    # No separate KV head count: keys/values use the same head layout as queries.
    num_key_value_heads = n_heads
    num_local_key_value_heads = n_heads_per_shard
    key_value_dim = dim

def permute(w, n_heads, dim1, dim2):
    """Re-order the rows of a ``(dim1, dim2)`` projection weight, head by head.

    Within each of the ``n_heads`` row groups, the rows are viewed as
    ``(dim1 // n_heads // 2, 2)`` and those two axes are swapped, so rows that
    were adjacent pairs end up split into a first half and a second half of the
    head. ``reverse_permute`` with the same arguments undoes this.

    All arguments are required: the previous signature placed a defaulted
    ``n_heads`` before the non-default ``dim1``/``dim2``, which is a
    SyntaxError and prevented the cell from running at all.

    Args:
        w: tensor that can be viewed as ``(dim1, dim2)``.
        n_heads: number of heads the ``dim1`` rows are split into.
        dim1: number of rows of ``w``; expected to be divisible by ``2 * n_heads``.
        dim2: number of columns of ``w``.

    Returns:
        Tensor of shape ``(dim1, dim2)`` with the rows re-ordered.
    """
    return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

def reverse_permute(w, n_heads, dim1, dim2):
    """Re-order the rows of a ``(dim1, dim2)`` weight from half-split to interleaved.

    Each head's rows are viewed as two halves of ``dim1 // n_heads // 2`` rows;
    the result alternates one row from the first half with the matching row
    from the second half.

    Args:
        w: tensor that can be viewed as ``(dim1, dim2)``.
        n_heads: number of heads the ``dim1`` rows are split into.
        dim1: number of rows of ``w``.
        dim2: number of columns of ``w``.

    Returns:
        Tensor of shape ``(dim1, dim2)`` with the rows re-ordered.
    """
    rows_per_half = dim1 // n_heads // 2
    # (head, half, row-in-half, col)
    by_half = w.view(n_heads, 2, rows_per_half, dim2)
    # (head, row-in-half, half, col): the two halves become adjacent rows.
    interleaved = by_half.transpose(1, 2)
    return interleaved.reshape(dim1, dim2)


In [7]:
# Substring renames from the Hugging Face parameter names to the target
# checkpoint's names. They are applied in this order to every key.
replacements = {
    ".embed_tokens.": ".tok_embeddings.",
    ".self_attn.": ".attention.",
    ".q_proj.": ".wq.",
    ".k_proj.": ".wk.",
    ".v_proj.": ".wv.",
    ".o_proj.": ".wo.",
    ".mlp.": ".feed_forward.",
    ".gate_proj.": ".w1.",
    ".up_proj.": ".w3.",
    ".down_proj.": ".w2.",
    ".input_layernorm.": ".attention_norm.",
    ".post_attention_layernorm.": ".ffn_norm.",
    ".lm_head.": ".output.",
    ".norm.": ".norm."
}

converted_weights = {}
for hf_name, tensor in model.state_dict().items():
    # Only the key and query projections need their rows re-ordered; the
    # arguments are (number of heads, rows, columns) for that projection.
    if "k_proj" in hf_name:
        head_layout = (num_key_value_heads, key_value_dim, dim)
    elif "q_proj" in hf_name:
        head_layout = (n_heads, dim, dim)
    else:
        head_layout = None
    if head_layout is not None:
        tensor = reverse_permute(tensor, *head_layout)

    # Strip the HF "model." prefix, normalise the layer container name, then
    # add the "llma." prefix that the verification cell looks keys up under.
    target_name = "llma." + hf_name.replace("model.", "").replace("blocks.", "layers.")
    for hf_part, target_part in replacements.items():
        target_name = target_name.replace(hf_part, target_part)

    converted_weights[target_name] = tensor

# The checkpoint file stores all weights under a single top-level "model" entry.
new_state_dict = {"model": converted_weights}

torch.save(new_state_dict, "consolidated.00-of-01.model.pth")

In [8]:
# Reload both checkpoints from disk so the check runs against what was actually
# written, not against the in-memory tensors.
reference_checkpoint_path = "/home/mistral-src-main/mistral-7B-v0.1/consolidated.00.pth"
converted_checkpoint_path = "consolidated.00-of-01.model.pth"

model_original = torch.load(reference_checkpoint_path)
model_converted = torch.load(converted_checkpoint_path)

In [10]:
# Check every reference weight against its converted counterpart.
# A key that is missing, has a different shape, or has different values is
# printed and recorded; the cell fails at the end if anything was recorded, so
# a bad conversion cannot pass a "Run All" unnoticed.
converted_weights = model_converted['model']
problem_keys = []
for k, v in model_original.items():
    converted_key = 'llma.' + k
    if converted_key not in converted_weights:
        # Report the gap instead of aborting the loop with a bare KeyError.
        print(f"********************{k} is missing********************")
        problem_keys.append(k)
        continue

    # Compare in the reference tensor's own dtype and on its device
    # (previously the cast was hard-coded to bfloat16).
    candidate = converted_weights[converted_key].to(device=v.device, dtype=v.dtype)
    # Check the shape explicitly so a mismatch is reported as "different"
    # rather than being left to allclose.
    if candidate.shape == v.shape and torch.allclose(v, candidate):
        print(f"{k} is the same")
    else:
        print(f"********************{k} is different********************")
        problem_keys.append(k)

assert not problem_keys, (
    f"{len(problem_keys)} weight(s) missing or different: {problem_keys}"
)

tok_embeddings.weight is the same
norm.weight is the same
output.weight is the same
layers.0.attention.wq.weight is the same
layers.0.attention.wk.weight is the same
layers.0.attention.wv.weight is the same
layers.0.attention.wo.weight is the same
layers.0.feed_forward.w1.weight is the same
layers.0.feed_forward.w2.weight is the same
layers.0.feed_forward.w3.weight is the same
layers.0.attention_norm.weight is the same
layers.0.ffn_norm.weight is the same
layers.1.attention.wq.weight is the same
layers.1.attention.wk.weight is the same
layers.1.attention.wv.weight is the same
layers.1.attention.wo.weight is the same
layers.1.feed_forward.w1.weight is the same
layers.1.feed_forward.w2.weight is the same
layers.1.feed_forward.w3.weight is the same
layers.1.attention_norm.weight is the same
layers.1.ffn_norm.weight is the same
layers.2.attention.wq.weight is the same
layers.2.attention.wk.weight is the same
layers.2.attention.wv.weight is the same
layers.2.attention.wo.weight is the same
