In [2]:
import re
from typing import Dict

import torch

In [3]:
# Absolute path to the PyTorch Lightning checkpoint inspected in this notebook;
# adjust it for your machine before running.
MODEL_CHECKPOINT_PATH = "/home/pranav-pc/projects/OpenTransformer/multiformer/blm-1024/checkpoints/last.ckpt"
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted checkpoints.
model_dict = torch.load(MODEL_CHECKPOINT_PATH, map_location=None)

In [4]:
# Show the top-level entries of the loaded checkpoint dict (output below).
model_dict.keys()

dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'hparams_name', 'hyper_parameters'])

In [5]:
# print(model_dict['hyper_parameters'])

In [6]:
# List every parameter name stored in the checkpoint's state_dict, one per line.
for param_name in model_dict["state_dict"].keys():
    print(param_name)

tok_embd.weight
layers.0.norms.w
layers.0.attention.wq.weight
layers.0.attention.wk.weight
layers.0.attention.wv.weight
layers.0.attention.wo.weight
layers.0.mlp.linear1.weight
layers.0.mlp.linear2.weight
layers.0.mlp.linear3.weight
layers.1.norms.w
layers.1.attention.wq.weight
layers.1.attention.wk.weight
layers.1.attention.wv.weight
layers.1.attention.wo.weight
layers.1.mlp.linear1.weight
layers.1.mlp.linear2.weight
layers.1.mlp.linear3.weight
layers.2.norms.w
layers.2.attention.wq.weight
layers.2.attention.wk.weight
layers.2.attention.wv.weight
layers.2.attention.wo.weight
layers.2.mlp.linear1.weight
layers.2.mlp.linear2.weight
layers.2.mlp.linear3.weight
layers.3.norms.w
layers.3.attention.wq.weight
layers.3.attention.wk.weight
layers.3.attention.wv.weight
layers.3.attention.wo.weight
layers.3.mlp.linear1.weight
layers.3.mlp.linear2.weight
layers.3.mlp.linear3.weight
norm.w
output.weight


In [7]:
_FROM_HF = {
    "model.embed_tokens.weight": "tok_embd.weight",
    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
    "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
    "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
    "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
    "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
    "model.layers.{}.mlp.gate_proj.weight": "layers.{}.mlp.linear1.weight",
    "model.layers.{}.mlp.up_proj.weight": "layers.{}.mlp.linear3.weight",
    "model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.linear2.weight",
    "model.layers.{}.input_layernorm.weight": "layers.{}.norms.w2",
    "model.layers.{}.post_attention_layernorm.weight": "layers.{}.norms.w",
    "model.norm.weight": "norm.w",
    "lm_head.weight": "output.weight",
}

In [8]:
_FROM_META = {
    "tok_embeddings.weight": "tok_embd.weight",
    "norm.weight": "norm.w",
    "output.weight": "output.weight",
    "layers.{}.attention.wk.weight": "layers.{}.attention.wq.weight",
    "layers.{}.attention.wq.weight": "layers.{}.attention.wk.weight",
    "layers.{}.attention.wv.weight": "layers.{}.attention.wv.weight",
    "layers.{}.attention.wo.weight": "layers.{}.attention.wo.weight",
    "layers.{}.attention_norm.weight": "layers.{}.norms.w2",
    "layers.{}.ffn_norm.weight": "layers.{}.norms.w",
    "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.linear1.weight",
    "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.linear2.weight",
    "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.linear3.weight",
}

In [9]:
def _get_mapped_key(key: str, mapping_dict: Dict[str, str]) -> str:
    try:
        if "layers" in key:
            # Replace layer number with "{}" to create key for lookup
            abstract_key = re.sub(r"(\.\d+)", ".{}", key)
            layer_num = re.search(r"\d+", key).group(0)

            new_key = mapping_dict[abstract_key]

            new_key = new_key.format(layer_num)
        else:
            new_key = mapping_dict[key]
    except KeyError as e:
        raise Exception(
            f'Error converting the state dict. Found unexpected key: "{key}". '
            "Please make sure you're loading a checkpoint with the right format. "
        ) from e

    return new_key

In [10]:
def blm_to_hf(
    state_dict: Dict[str, torch.Tensor],
    num_heads: int = 12,
    num_kv_heads: int = 12,
    dim: int = 768,
    head_dim: int = None,
) -> Dict[str, torch.Tensor]:
    """Convert a blm state dict to Hugging Face Llama parameter names and layout.

    Every key is renamed through the inverse of ``_FROM_HF``. The query and key
    projection weights are also row-permuted per head, as the inverse of the
    permutation applied in ``hf_to_blm``.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in blm's format.
        num_heads (int): Number of heads in the model.
        num_kv_heads (int): Number of heads in the key/value projection layers.
        dim (int): Dimension of the model.
        head_dim (int): Dimension of each head. If not provided, it is
            calculated as dim // num_heads.

    Returns:
        Dict[str, torch.Tensor]: State dict in Hugging Face's format.

    Raises:
        Exception: If a key has no entry in the inverted mapping.
    """
    converted_state_dict = {}
    inverted_mapping_dict = {v: k for k, v in _FROM_HF.items()}
    if head_dim is None:
        head_dim = dim // num_heads

    def _permute(t, n_heads):
        # Regroup each head's rows from (head_dim // 2, 2) pairs into
        # (2, head_dim // 2) halves -- the inverse of hf_to_blm's _permute.
        return (
            t.view(n_heads, head_dim // 2, 2, dim)
            .transpose(1, 2)
            .reshape((head_dim * n_heads), dim)
        )

    for key, value in state_dict.items():
        new_key = _get_mapped_key(key, inverted_mapping_dict)
        # blm names these projections "wq"/"wk", so test the converted HF name.
        # Testing the blm key for "q_proj"/"k_proj" never matched, and the
        # permutation was silently skipped.
        if "q_proj" in new_key:
            value = _permute(value, num_heads)
        elif "k_proj" in new_key:
            value = _permute(value, num_kv_heads)
        converted_state_dict[new_key] = value

    return converted_state_dict

In [11]:
def meta_to_blm(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Rename a Meta-format state dict to blm's parameter names.

    The "rope.freqs" entry is dropped. Every other key is translated through
    _FROM_META, and its tensor is passed through unchanged.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in Meta's format.

    Returns:
        Dict[str, torch.Tensor]: State dict in blm's format.
    """
    skipped_keys = ("rope.freqs",)  # position-embedding entry, not loaded
    return {
        _get_mapped_key(name, _FROM_META): tensor
        for name, tensor in state_dict.items()
        if name not in skipped_keys
    }

In [12]:
# llama_model = torch.load('/home/pranav-pc/projects/OpenTransformer/checkpoints/llama-2-7b/consolidated.00.pth')
# llama_model = meta_to_blm(llama_model)

In [13]:
# PATH = "/home/pranav-pc/projects/OpenTransformer/multiformer/blm-1024/checkpoints/llama/blm-llama-7b.pth"

# llama_model['pytorch-lightning_version'] = '2.3.0.dev20240318'
# llama_model['hparams_name'] = 'kwargs'
# from src.models.blm.config import ModelArgs
# llama_model['hyper_parameters'] = {'args': ModelArgs(vocab_size=32000, embedding_dim=4096, max_seq_len=4096, embedding_dropout=0.0, rms_norm_eps=1e-05, rope_scaling=1.0, rope_theta=10000.0, attention_bias=False, attention_dropout=0.0, num_attention_heads=32, num_key_value_heads=32, use_cache=True, use_sliding_window=True, residual_dropout=0.1, mlp_hidden_size=11008, mlp_dropout=0.0, num_layers=32, device='cpu', padding_idx=2),
#  'is_causal': True,
#  'attn_mask': None,
#  'lr': 0.0005,
#  'cosine_t_max': 1000}
# torch.save(llama_model, PATH)

In [14]:
def blm_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Rename a blm state dict to Meta's parameter names.

    Each key is translated through the inverse of _FROM_META. Tensors are
    passed through unchanged.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in blm's format.

    Returns:
        Dict[str, torch.Tensor]: State dict in Meta's format.
    """
    # Reverse _FROM_META so blm names become the lookup keys.
    meta_name_for = dict(zip(_FROM_META.values(), _FROM_META.keys()))

    converted = {}
    for blm_name, tensor in state_dict.items():
        converted[_get_mapped_key(blm_name, meta_name_for)] = tensor
    return converted

In [15]:
# blm_to_meta(model_dict['state_dict'])

In [24]:
def hf_to_blm(
    state_dict: Dict[str, torch.Tensor],
    num_heads: int = 12,
    num_kv_heads: int = 12,
    dim: int = 768,
    head_dim: int = None,
) -> Dict[str, torch.Tensor]:
    """
    Convert a Hugging Face Llama state dict to blm's parameter names and layout.

    Keys containing "rotary_emb.inv_freq" are dropped. Every other key is
    renamed through _FROM_HF. The q_proj / k_proj weights are also
    row-permuted per head; blm_to_hf applies the inverse permutation.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in Hugging Face's format.
        num_heads (int): Number of heads in the model.
        num_kv_heads (int): Number of heads in the key/value projection layers.
        dim (int): Dimension of the model.
        head_dim (int): Dimension of the head. If not provided, it will be calculated
            as dim // num_heads.

    Returns:
        Dict[str, torch.Tensor]: State dict in blm's format.

    Raises:
        Exception: If a (non-skipped) key has no entry in _FROM_HF.
    """
    converted_state_dict = {}
    if head_dim is None:
        head_dim = dim // num_heads

    def _permute(t, n_heads):
        # Regroup each head's rows from (2, head_dim // 2) halves into
        # interleaved (head_dim // 2, 2) pairs.
        # Assumes t has n_heads * head_dim rows and dim columns.
        return (
            t.view(n_heads, 2, head_dim // 2, dim)
            .transpose(1, 2)
            .reshape((head_dim * n_heads), dim)
        )

    for key, value in state_dict.items():
        if "rotary_emb.inv_freq" not in key:  # Skip loading the position embeddings
            new_key = _get_mapped_key(key, _FROM_HF)
            # HF keys carry "q_proj"/"k_proj", so the source key can be tested directly.
            if "q_proj" in key:
                value = _permute(value, num_heads)
            elif "k_proj" in key:
                value = _permute(value, num_kv_heads)
            converted_state_dict[new_key] = value
    return converted_state_dict

tensor([[-5.2307e-02, -5.5769e-02, -1.3177e-01,  ...,  2.2935e-01,
         -5.4016e-02,  1.1902e-01],
        [ 3.2817e-02,  1.1855e-04,  3.0871e-02,  ..., -4.8728e-02,
         -4.3831e-02, -1.0355e-01],
        [-1.2030e-03, -2.2414e-01,  5.3518e-02,  ..., -1.0539e-01,
         -1.3322e-01, -8.1382e-02],
        ...,
        [-6.4690e-02,  2.9343e-02, -1.5258e-04,  ..., -5.8976e-02,
          1.8575e-01,  4.4744e-02],
        [ 1.7963e-01,  1.1678e-01, -5.4585e-02,  ..., -5.7763e-02,
         -2.2228e-01,  6.4764e-02],
        [ 7.1337e-02, -2.0702e-01,  2.8848e-02,  ...,  9.6218e-02,
         -2.9840e-02, -1.2775e-01]], device='cuda:0') 12
tensor([[ 0.0006, -0.1292,  0.0835,  ..., -0.1895,  0.0720, -0.1217],
        [ 0.0056, -0.0291, -0.3082,  ...,  0.0387, -0.0865, -0.1162],
        [-0.0627, -0.1507, -0.0728,  ...,  0.2257,  0.1081,  0.2148],
        ...,
        [ 0.1453,  0.0486, -0.1470,  ..., -0.0579,  0.0439,  0.0829],
        [ 0.0662, -0.1418,  0.0414,  ..., -0.0160,  0.1

{'tok_embd.weight': tensor([[-0.0407, -0.0110, -0.0283,  ...,  0.0344, -0.0348, -0.0187],
         [-0.0024,  0.0108, -0.0182,  ...,  0.0005, -0.0436,  0.0270],
         [-0.0827, -0.0160, -0.0518,  ..., -0.1533, -0.0490,  0.0909],
         ...,
         [-0.0407, -0.0110, -0.0283,  ...,  0.0344, -0.0348, -0.0188],
         [-0.0407, -0.0110, -0.0283,  ...,  0.0344, -0.0348, -0.0187],
         [-0.0406, -0.0110, -0.0283,  ...,  0.0344, -0.0347, -0.0188]],
        device='cuda:0'),
 'layers.0.norms.w': tensor([0.1570, 0.1649, 0.1578, 0.1534, 0.1473, 0.1549, 0.1634, 0.2792, 0.1523,
         0.1406, 0.0617, 0.1445, 0.1376, 0.1397, 0.1580, 0.1514, 0.1633, 0.1397,
         0.1655, 0.1390, 0.1695, 0.1402, 0.1481, 0.1264, 0.1438, 0.1410, 0.1371,
         0.2102, 0.1409, 0.1850, 0.1454, 0.1394, 0.1555, 0.1386, 0.1472, 0.1209,
         0.2077, 0.1555, 0.1558, 0.1519, 0.1542, 0.1612, 0.1654, 0.1531, 0.1833,
         0.1591, 0.1632, 0.1340, 0.1568, 0.1453, 0.0869, 0.1625, 0.1303, 0.1675,
        