## Load the pretrained model with LlamaForCausalLM (files are in the ml-la container)

#### 100m 6mer 160k
pretrain ckpt  `/ml-la/v-zekunguo/gene/checkpoints/100m6kmer160k`
config path `/ml-la/v-zekunguo/gene/checkpoints/config/config_1b`


#### 1b 6mer 16k
pretrain ckpt  `/ml-la/v-zekunguo/gene/checkpoints/real_1b6kmer16k`
config path `/ml-la/v-zekunguo/gene/checkpoints/config/config_100m`

In [None]:
import os
import torch
from transformers import LlamaConfig, LlamaForCausalLM

In [None]:
def load_model100m(config_path, ckpt_path):
    """Build a ``LlamaForCausalLM`` and fill it from a per-layer checkpoint.

    The checkpoint directory holds one ``.pt`` file per layer index: index 0
    provides the token embedding, indices ``1..num_hidden_layers`` provide the
    transformer blocks, and the two following indices provide the final norm
    and the LM head. Block parameters are stored under checkpoint-specific
    names and are renamed here to the ``LlamaForCausalLM`` state-dict names.

    Args:
        config_path: Path to the JSON file read by
            ``LlamaConfig.from_json_file``.
        ckpt_path: Directory containing the per-layer ``.pt`` files.

    Returns:
        The model after ``load_state_dict``. Parameters with no counterpart in
        the checkpoint keep the values they had right after construction.
    """
    config = LlamaConfig.from_json_file(config_path)
    model = LlamaForCausalLM(config)
    model_dict = model.state_dict()
    print(model_dict.keys())

    # Two file-name layouts are supported: ``layer_XX-model_states.pt`` and
    # ``layer_XX-model_00-model_states.pt``. Probe layer 00 to pick one.
    flag = ""
    if not os.path.exists(os.path.join(ckpt_path, "layer_00-model_states.pt")):
        flag = "_00-model"

    def load_layer(index):
        # Every layer file uses a two-digit, zero-padded index, including the
        # trailing norm / lm_head files.
        return torch.load(
            os.path.join(ckpt_path, f"layer_{index:02d}-model{flag}_states.pt"),
            map_location=torch.device("cpu"),
        )

    # Checkpoint key -> parameter name relative to ``model.layers.<i>.``.
    direct_renames = {
        "self_attention.layernorm_qkv.query_weight": "self_attn.q_proj.weight",
        "self_attention.layernorm_qkv.key_weight": "self_attn.k_proj.weight",
        "self_attention.layernorm_qkv.value_weight": "self_attn.v_proj.weight",
        "self_attention.proj.weight": "self_attn.o_proj.weight",
        "self_attention.layernorm_qkv.layer_norm_weight": "input_layernorm.weight",
        "layernorm_mlp.layer_norm_weight": "post_attention_layernorm.weight",
        "layernorm_mlp.fc2_weight": "mlp.down_proj.weight",
    }

    layer0 = load_layer(0)
    print(layer0.keys())
    ckpt_dict = {"model.embed_tokens.weight": layer0["word_embeddings.weight"]}

    for layer_idx in range(config.num_hidden_layers):
        # Block ``layer_idx`` lives in file ``layer_idx + 1`` (file 0 is the embedding).
        block = load_layer(layer_idx + 1)
        prefix = f"model.layers.{layer_idx}."
        for key in block:
            if key in direct_renames:
                ckpt_dict[prefix + direct_renames[key]] = block[key]
            elif key == "layernorm_mlp.fc1_weight":
                # fc1 stacks two projections along dim 0: the first half is
                # loaded as gate_proj, the second half as up_proj.
                fc1 = block[key]
                splits = torch.split(fc1, fc1.size(0) // 2)
                ckpt_dict[prefix + "mlp.gate_proj.weight"] = splits[0]
                ckpt_dict[prefix + "mlp.up_proj.weight"] = splits[1]
            # Any other key (e.g. "dummy" or "rotary_emb" entries) is ignored.

    final_norm = load_layer(config.num_hidden_layers + 1)
    ckpt_dict["model.norm.weight"] = final_norm["norm.weight"]

    head = load_layer(config.num_hidden_layers + 2)
    ckpt_dict["lm_head.weight"] = head["lm_head.weight"]

    # Overlay the checkpoint tensors on the model's own state dict so that
    # load_state_dict sees a complete set of keys.
    model_dict.update(ckpt_dict)
    model.load_state_dict(model_dict)
    return model

In [None]:
def load_model1b(config_path, ckpt_path):
    """Build a ``LlamaForCausalLM`` and fill it from a per-layer checkpoint.

    The checkpoint directory holds one ``.pt`` file per layer index: index 0
    provides the token embedding, indices ``1..num_hidden_layers`` provide the
    transformer blocks, and the two following indices provide the final norm
    and the LM head. Block parameter names in this checkpoint are used as-is,
    only prefixed with ``model.layers.<i>.``.

    Args:
        config_path: Path to the JSON file read by
            ``LlamaConfig.from_json_file``.
        ckpt_path: Directory containing the per-layer ``.pt`` files.

    Returns:
        The model after ``load_state_dict``. Parameters with no counterpart in
        the checkpoint keep the values they had right after construction.
    """
    config = LlamaConfig.from_json_file(config_path)
    model = LlamaForCausalLM(config)
    model_dict = model.state_dict()

    # Two file-name layouts are supported: ``layer_XX-model_states.pt`` and
    # ``layer_XX-model_00-model_states.pt``. Probe layer 00 to pick one.
    flag = ""
    if not os.path.exists(os.path.join(ckpt_path, "layer_00-model_states.pt")):
        flag = "_00-model"

    def load_layer(index):
        # Every layer file uses a two-digit, zero-padded index, including the
        # trailing norm / lm_head files.
        return torch.load(
            os.path.join(ckpt_path, f"layer_{index:02d}-model{flag}_states.pt"),
            map_location=torch.device("cpu"),
        )

    layer0 = load_layer(0)
    print(layer0.keys())
    ckpt_dict = {"model.embed_tokens.weight": layer0["embed_tokens.weight"]}

    for layer_idx in range(config.num_hidden_layers):
        # Block ``layer_idx`` lives in file ``layer_idx + 1`` (file 0 is the embedding).
        block = load_layer(layer_idx + 1)
        for key in block:
            # Skip placeholder and rotary-embedding entries; they are not
            # copied into the model's state dict.
            if "dummy" in key or "rotary_emb" in key:
                continue
            ckpt_dict[f"model.layers.{layer_idx}.{key}"] = block[key]

    final_norm = load_layer(config.num_hidden_layers + 1)
    ckpt_dict["model.norm.weight"] = final_norm["norm.weight"]

    head = load_layer(config.num_hidden_layers + 2)
    ckpt_dict["lm_head.weight"] = head["lm_head.weight"]

    # Overlay the checkpoint tensors on the model's own state dict so that
    # load_state_dict sees a complete set of keys.
    model_dict.update(ckpt_dict)
    model.load_state_dict(model_dict)
    return model
