In [3]:
import sys
import os
from argparse import ArgumentParser

sys.path.extend([".", "..", "../.."])

import transformers
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import torch
from sfm.models.scigpt.scigpt import ScigptModel
from sfm.models.scigpt.config import ScigptConfig
from sfm.utils import arg_utils
from transformers import LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig
from sfm.utils.science_tokens import SCIENCE_TAG_TOKENS

from sfm.models.llama2.llama_modules_3dmp_te import TELlamaModel


  from .autonotebook import tqdm as notebook_tqdm


[2024-05-09 18:59:15,320] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[[32m2024-05-09 18:59:15.682[0m][[36mINFO[0m]: apex is installed, using FusedAdam with fp16 optimizer states
[[32m2024-05-09 18:59:16.236[0m][[36mINFO[0m]: Using TEColumnParallelLinear and TERowParallelLinear in tensor parallel


In [4]:
IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"

def get_args_and_tokenizer(use_llama=False):
    parser = ArgumentParser()
    cfg_classes = [ScigptConfig]
    parser = arg_utils.add_dataclass_to_parser(cfg_classes, parser)
    args = parser.parse_args(args=[])
    args.load_ckpt = False
    args.strategy = "DDP"
    args.encoder_layers = 33
    args.encoder_embed_dim = 1280
    args.encoder_ffn_embed_dim = 5120
    args.encoder_attention_heads = 20
    args.infer = True
    args.bf16 = True
    
    tokenizer = AutoTokenizer.from_pretrained("/data/peiran/blob/hai1data/sfm/llama/Meta-Llama-3-8B/original")
    args.save_dir = "/data/peiran/expresult/llama3_8B_stageB/global_step16999"
    args.llm_model_name_or_path = "/data/peiran/blob/hai1data/sfm/llama/Meta-Llama-3-8B/original"

    special_tokens_dict = dict()
    if tokenizer.pad_token is None:
        special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
    if tokenizer.eos_token is None:
        special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
    if tokenizer.bos_token is None:
        special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
    if tokenizer.unk_token is None:
        special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN

    special_tokens_dict["additional_special_tokens"] = SCIENCE_TAG_TOKENS
    tokenizer.add_special_tokens(special_tokens_dict)
        

    return args, tokenizer

args, tokenizer = get_args_and_tokenizer()
print(type(tokenizer), len(tokenizer))

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


<class 'transformers.tokenization_utils_fast.PreTrainedTokenizerFast'> 128384


In [10]:
ckpt_dict = {}
llama_config = LlamaConfig.from_pretrained("/data/peiran/blob/hai1data/sfm/llama/Meta-Llama-3-8B/original")
model = TELlamaModel(args, llama_config)




In [14]:
print(model.layers)

[TELlamaDecoderLayer(
  (self_attention): MultiheadAttention(
    (layernorm_qkv): LayerNormLinear()
    (core_attention): DotProductAttention(
      (flash_attention): FlashAttention()
      (fused_attention): FusedAttention()
      (unfused_attention): UnfusedDotProductAttention(
        (scale_mask_softmax): FusedScaleMaskSoftmax()
        (attention_dropout): Dropout(p=0, inplace=False)
      )
    )
    (proj): Linear()
  )
  (layernorm_mlp): LayerNormMLP()
), TELlamaDecoderLayer(
  (self_attention): MultiheadAttention(
    (layernorm_qkv): LayerNormLinear()
    (core_attention): DotProductAttention(
      (flash_attention): FlashAttention()
      (fused_attention): FusedAttention()
      (unfused_attention): UnfusedDotProductAttention(
        (scale_mask_softmax): FusedScaleMaskSoftmax()
        (attention_dropout): Dropout(p=0, inplace=False)
      )
    )
    (proj): Linear()
  )
  (layernorm_mlp): LayerNormMLP()
), TELlamaDecoderLayer(
  (self_attention): MultiheadAttention(


In [12]:
model_dict = model.state_dict()
print(model_dict.keys())

# model_dict["model.embed_tokens.weight"].shape

odict_keys(['dummy', 'word_embeddings.weight', 'norm.weight', 'lm_head.weight'])


In [29]:
ckpt_dict = {}

layer0 = torch.load(os.path.join(args.save_dir, "layer_00-model_00-model_states.pt"), map_location=torch.device("cpu"))
layer1 = torch.load(os.path.join(args.save_dir, "layer_00-model_01-model_states.pt"), map_location=torch.device("cpu"))

for k, _ in layer0.items():
    new_k = "model.embed_tokens.weight"
    v = torch.cat([layer0[k], layer1[k]], dim=0)
    ckpt_dict[new_k] = v

del layer0, layer1

for l in range(0, 32):
    l_index = str(l + 1).zfill(2)
    layer0 = torch.load(os.path.join(args.save_dir, f"layer_{l_index}-model_00-model_states.pt"), map_location=torch.device("cpu"))
    layer1 = torch.load(os.path.join(args.save_dir, f"layer_{l_index}-model_01-model_states.pt"), map_location=torch.device("cpu"))

    for k in layer0:
        if k.find("norm") != -1:
            v = layer0[k]
        elif layer0[k].shape[0] == 4096:
            v = torch.cat([layer0[k], layer1[k]], dim=1)
        elif layer0[k].shape[1] == 4096:
            v = torch.cat([layer0[k], layer1[k]], dim=0)

        ckpt_dict[f"model.layers.{l}.{k}"] = v

    del layer0, layer1

layer0 = torch.load(os.path.join(args.save_dir, "layer_33-model_00-model_states.pt"), map_location=torch.device("cpu"))
layer1 = torch.load(os.path.join(args.save_dir, "layer_33-model_01-model_states.pt"), map_location=torch.device("cpu"))

for k, _ in layer0.items():
    new_k = "model." + k
    v = torch.cat([layer0[k], layer1[k]], dim=0)
    ckpt_dict[new_k] = v

del layer0, layer1

layer0 = torch.load(os.path.join(args.save_dir, "layer_34-model_00-model_states.pt"), map_location=torch.device("cpu"))
layer1 = torch.load(os.path.join(args.save_dir, "layer_34-model_01-model_states.pt"), map_location=torch.device("cpu"))

for k, _ in layer0.items():
    new_k = "model." + k
    v = torch.cat([layer0[k], layer1[k]], dim=0)
    ckpt_dict[new_k] = v

del layer0, layer1

print(f"ckpt_dict: {ckpt_dict.keys()}")
model_dict.update(ckpt_dict)
model.load_state_dict(model_dict)

ckpt_dict: dict_keys(['model.embed_tokens.weight', 'model.layers.0.self_attention.layernorm_qkv.layer_norm_weight', 'model.layers.0.self_attention.layernorm_qkv.query_weight', 'model.layers.0.self_attention.layernorm_qkv.key_weight', 'model.layers.0.self_attention.layernorm_qkv.value_weight', 'model.layers.0.self_attention.proj.weight', 'model.layers.0.layernorm_mlp.layer_norm_weight', 'model.layers.0.layernorm_mlp.fc1_weight', 'model.layers.0.layernorm_mlp.fc2_weight', 'model.layers.1.self_attention.layernorm_qkv.layer_norm_weight', 'model.layers.1.self_attention.layernorm_qkv.query_weight', 'model.layers.1.self_attention.layernorm_qkv.key_weight', 'model.layers.1.self_attention.layernorm_qkv.value_weight', 'model.layers.1.self_attention.proj.weight', 'model.layers.1.layernorm_mlp.layer_norm_weight', 'model.layers.1.layernorm_mlp.fc1_weight', 'model.layers.1.layernorm_mlp.fc2_weight', 'model.layers.2.self_attention.layernorm_qkv.layer_norm_weight', 'model.layers.2.self_attention.layer

RuntimeError: Error(s) in loading state_dict for LlamaForCausalLM:
	Unexpected key(s) in state_dict: "model.word_embeddings.weight", "model.dummy.weight", "model.dummy.bias", "model.lm_head.weight", "model.num_head.fc1.weight", "model.num_head.fc1.bias", "model.num_head.fc2.weight", "model.num_head.fc2.bias", "model.layers.0.self_attention.layernorm_qkv.layer_norm_weight", "model.layers.0.self_attention.layernorm_qkv.query_weight", "model.layers.0.self_attention.layernorm_qkv.key_weight", "model.layers.0.self_attention.layernorm_qkv.value_weight", "model.layers.0.self_attention.proj.weight", "model.layers.0.layernorm_mlp.layer_norm_weight", "model.layers.0.layernorm_mlp.fc1_weight", "model.layers.0.layernorm_mlp.fc2_weight", "model.layers.1.self_attention.layernorm_qkv.layer_norm_weight", "model.layers.1.self_attention.layernorm_qkv.query_weight", "model.layers.1.self_attention.layernorm_qkv.key_weight", "model.layers.1.self_attention.layernorm_qkv.value_weight", "model.layers.1.self_attention.proj.weight", "model.layers.1.layernorm_mlp.layer_norm_weight", "model.layers.1.layernorm_mlp.fc1_weight", "model.layers.1.layernorm_mlp.fc2_weight", "model.layers.2.self_attention.layernorm_qkv.layer_norm_weight", "model.layers.2.self_attention.layernorm_qkv.query_weight", "model.layers.2.self_attention.layernorm_qkv.key_weight", "model.layers.2.self_attention.layernorm_qkv.value_weight", "model.layers.2.self_attention.proj.weight", "model.layers.2.layernorm_mlp.layer_norm_weight", "model.layers.2.layernorm_mlp.fc1_weight", "model.layers.2.layernorm_mlp.fc2_weight", "model.layers.3.self_attention.layernorm_qkv.layer_norm_weight", "model.layers.3.self_attention.layernorm_qkv.query_weight", "model.layers.3.self_attention.layernorm_qkv.key_weight", "model.layers.3.self_attention.layernorm_qkv.value_weight", "model.layers.3.self_attention.proj.weight", "model.layers.3.layernorm_mlp.layer_norm_weight", "model.layers.3.layernorm_mlp.fc1_weight", "model.layers.3.layernorm_mlp.fc2_weight", "model.layers.4.self_attention.layernorm_qkv.layer_norm_weight", "model.layers.4.self_attention.layernorm_qkv.query_weight", "model.layers.4.self_attention.layernorm_qkv.key_weight", "model.layers.4.self_attention.layernorm_qkv.value_weight", "model.layers.4.self_attention.proj.weight", "model.layers.4.layernorm_mlp.layer_norm_weight", "model.layers.4.layernorm_mlp.fc1_weight", "model.layers.4.layernorm_mlp.fc2_weight", "model.layers.5.self_attention.layernorm_qkv.layer_norm_weight", "model.layers.5.self_attention.layernorm_qkv.query_weight", "model.layers.5.self_attention.layernorm_qkv.key_weight", "model.layers.5.self_attention.layernorm_qkv.value_weight", "model.layers.5.self_attention.proj.weight", "model.layers.5.layernorm_mlp.layer_norm_weight", "model.layers.5.layernorm_mlp.fc1_weight", "model.layers.5.layernorm_mlp.fc2_weight", "model.layers.6.self_attention.layernorm_qkv.layer_norm_weight", "model.layers.6.self_attention.layernorm_qkv.query_weight", "model.layers.6.self_attention.layernorm_qkv.key_weight", "model.layers.6.self_attention.layernorm_qkv.value_weight", "model.layers.6.self_attention.proj.weight", "model.layers.6.layernorm_mlp.layer_norm_weight", "model.layers.6.layernorm_mlp.fc1_weight", "model.layers.6.layernorm_mlp.fc2_weight", "model.layers.7.self_attention.layernorm_qkv.layer_norm_weight", "model.layers.7.self_attention.layernorm_qkv.query_weight", "model.layers.7.self_attention.layernorm_qkv.key_weight", "model.layers.7.self_attention.layernorm_qkv.value_weight", "model.layers.7.self_attention.proj.weight", "model.layers.7.layernorm_mlp.layer_norm_weight", "model.layers.7.layernorm_mlp.fc1_weight", "model.layers.7.layernorm_mlp.fc2_weight", "model.layers.8.self_attention.layernorm_qkv.layer_norm_weight", "model.layers.8.self_attention.layernorm_qkv.query_weight", "model.layers.8.self_attention.layernorm_qkv.key_weight", "model.layers.8.self_attention.layernorm_qkv.value_weight", "model.layers.8.self_attention.proj.weight", "model.layers.8.layernorm_mlp.layer_norm_weight", "model.layers.8.layernorm_mlp.fc1_weight", "model.layers.8.layernorm_mlp.fc2_weight", "model.layers.9.self_attention.layernorm_qkv.layer_norm_weight", "model.layers.9.self_attention.layernorm_qkv.query_weight", "model.layers.9.self_attention.layernorm_qkv.key_weight", "model.layers.9.self_attention.layernorm_qkv.value_weight", "model.layers.9.self_attention.proj.weight", "model.layers.9.layernorm_mlp.layer_norm_weight", "model.layers.9.layernorm_mlp.fc1_weight", "model.layers.9.layernorm_mlp.fc2_weight", "model.layers.10.self_attention.layernorm_qkv.layer_norm_weight", "model.layers.10.self_attention.layernorm_qkv.query_weight", "model.layers.10.self_attention.layernorm_qkv.key_weight", "model.layers.10.self_attention.layernorm_qkv.value_weight", "model.layers.10.self_attention.proj.weight", "model.layers.10.layernorm_mlp.layer_norm_weight", "model.layers.10.layernorm_mlp.fc1_weight", "model.layers.10.layernorm_mlp.fc2_weight", "model.layers.11.self_attention.layernorm_qkv.layer_norm_weight", "model.layers.11.self_attention.layernorm_qkv.query_weight", "model.layers.11.self_attention.layernorm_qkv.key_weight", "model.layers.11.self_attention.layernorm_qkv.value_weight", "model.layers.11.self_attention.proj.weight", "model.layers.11.layernorm_mlp.layer_norm_weight", "model.layers.11.layernorm_mlp.fc1_weight", "model.layers.11.layernorm_mlp.fc2_weight", "model.layers.12.self_attention.layernorm_qkv.layer_norm_weight", "model.layers.12.self_attention.layernorm_qkv.query_weight", "model.layers.12.self_attention.layernorm_qkv.key_weight", "model.layers.12.self_attention.layernorm_qkv.value_weight", "model.layers.12.self_attention.proj.weight", "model.layers.12.layernorm_mlp.layer_norm_weight", "model.layers.12.layernorm_mlp.fc1_weight", "model.layers.12.layernorm_mlp.fc2_weight", "model.layers.13.self_attention.layernorm_qkv.layer_norm_weight", "model.layers.13.self_attention.layernorm_qkv.query_weight", "model.layers.13.self_attention.layernorm_qkv.key_weight", "model.layers.13.self_attention.layernorm_qkv.value_weight", "model.layers.13.self_attention.proj.weight", "model.layers.13.layernorm_mlp.layer_norm_weight", "model.layers.13.layernorm_mlp.fc1_weight", "model.layers.13.layernorm_mlp.fc2_weight", "model.layers.14.self_attention.layernorm_qkv.layer_norm_weight", "model.layers.14.self_attention.layernorm_qkv.query_weight", "model.layers.14.self_attention.layernorm_qkv.key_weight", "model.layers.14.self_attention.layernorm_qkv.value_weight", "model.layers.14.self_attention.proj.weight", "model.layers.14.layernorm_mlp.layer_norm_weight", "model.layers.14.layernorm_mlp.fc1_weight", "model.layers.14.layernorm_mlp.fc2_weight", "model.layers.15.self_attention.layernorm_qkv.layer_norm_weight", "model.layers.15.self_attention.layernorm_qkv.query_weight", "model.layers.15.self_attention.layernorm_qkv.key_weight", "model.layers.15.self_attention.layernorm_qkv.value_weight", "model.layers.15.self_attention.proj.weight", "model.layers.15.layernorm_mlp.layer_norm_weight", "model.layers.15.layernorm_mlp.fc1_weight", "model.layers.15.layernorm_mlp.fc2_weight", "model.layers.16.self_attention.layernorm_qkv.layer_norm_weight", "model.layers.16.self_attention.layernorm_qkv.query_weight", "model.layers.16.self_attention.layernorm_qkv.key_weight", "model.layers.16.self_attention.layernorm_qkv.value_weight", "model.layers.16.self_attention.proj.weight", "model.layers.16.layernorm_mlp.layer_norm_weight", "model.layers.16.layernorm_mlp.fc1_weight", "model.layers.16.layernorm_mlp.fc2_weight", "model.layers.17.self_attention.layernorm_qkv.layer_norm_weight", "model.layers.17.self_attention.layernorm_qkv.query_weight", "model.layers.17.self_attention.layernorm_qkv.key_weight", "model.layers.17.self_attention.layernorm_qkv.value_weight", "model.layers.17.self_attention.proj.weight", "model.layers.17.layernorm_mlp.layer_norm_weight", "model.layers.17.layernorm_mlp.fc1_weight", "model.layers.17.layernorm_mlp.fc2_weight", "model.layers.18.self_attention.layernorm_qkv.layer_norm_weight", "model.layers.18.self_attention.layernorm_qkv.query_weight", "model.layers.18.self_attention.layernorm_qkv.key_weight", "model.layers.18.self_attention.layernorm_qkv.value_weight", "model.layers.18.self_attention.proj.weight", "model.layers.18.layernorm_mlp.layer_norm_weight", "model.layers.18.layernorm_mlp.fc1_weight", "model.layers.18.layernorm_mlp.fc2_weight", "model.layers.19.self_attention.layernorm_qkv.layer_norm_weight", "model.layers.19.self_attention.layernorm_qkv.query_weight", "model.layers.19.self_attention.layernorm_qkv.key_weight", "model.layers.19.self_attention.layernorm_qkv.value_weight", "model.layers.19.self_attention.proj.weight", "model.layers.19.layernorm_mlp.layer_norm_weight", "model.layers.19.layernorm_mlp.fc1_weight", "model.layers.19.layernorm_mlp.fc2_weight", "model.layers.20.self_attention.layernorm_qkv.layer_norm_weight", "model.layers.20.self_attention.layernorm_qkv.query_weight", "model.layers.20.self_attention.layernorm_qkv.key_weight", "model.layers.20.self_attention.layernorm_qkv.value_weight", "model.layers.20.self_attention.proj.weight", "model.layers.20.layernorm_mlp.layer_norm_weight", "model.layers.20.layernorm_mlp.fc1_weight", "model.layers.20.layernorm_mlp.fc2_weight", "model.layers.21.self_attention.layernorm_qkv.layer_norm_weight", "model.layers.21.self_attention.layernorm_qkv.query_weight", "model.layers.21.self_attention.layernorm_qkv.key_weight", "model.layers.21.self_attention.layernorm_qkv.value_weight", "model.layers.21.self_attention.proj.weight", "model.layers.21.layernorm_mlp.layer_norm_weight", "model.layers.21.layernorm_mlp.fc1_weight", "model.layers.21.layernorm_mlp.fc2_weight", "model.layers.22.self_attention.layernorm_qkv.layer_norm_weight", "model.layers.22.self_attention.layernorm_qkv.query_weight", "model.layers.22.self_attention.layernorm_qkv.key_weight", "model.layers.22.self_attention.layernorm_qkv.value_weight", "model.layers.22.self_attention.proj.weight", "model.layers.22.layernorm_mlp.layer_norm_weight", "model.layers.22.layernorm_mlp.fc1_weight", "model.layers.22.layernorm_mlp.fc2_weight", "model.layers.23.self_attention.layernorm_qkv.layer_norm_weight", "model.layers.23.self_attention.layernorm_qkv.query_weight", "model.layers.23.self_attention.layernorm_qkv.key_weight", "model.layers.23.self_attention.layernorm_qkv.value_weight", "model.layers.23.self_attention.proj.weight", "model.layers.23.layernorm_mlp.layer_norm_weight", "model.layers.23.layernorm_mlp.fc1_weight", "model.layers.23.layernorm_mlp.fc2_weight", "model.layers.24.self_attention.layernorm_qkv.layer_norm_weight", "model.layers.24.self_attention.layernorm_qkv.query_weight", "model.layers.24.self_attention.layernorm_qkv.key_weight", "model.layers.24.self_attention.layernorm_qkv.value_weight", "model.layers.24.self_attention.proj.weight", "model.layers.24.layernorm_mlp.layer_norm_weight", "model.layers.24.layernorm_mlp.fc1_weight", "model.layers.24.layernorm_mlp.fc2_weight", "model.layers.25.self_attention.layernorm_qkv.layer_norm_weight", "model.layers.25.self_attention.layernorm_qkv.query_weight", "model.layers.25.self_attention.layernorm_qkv.key_weight", "model.layers.25.self_attention.layernorm_qkv.value_weight", "model.layers.25.self_attention.proj.weight", "model.layers.25.layernorm_mlp.layer_norm_weight", "model.layers.25.layernorm_mlp.fc1_weight", "model.layers.25.layernorm_mlp.fc2_weight", "model.layers.26.self_attention.layernorm_qkv.layer_norm_weight", "model.layers.26.self_attention.layernorm_qkv.query_weight", "model.layers.26.self_attention.layernorm_qkv.key_weight", "model.layers.26.self_attention.layernorm_qkv.value_weight", "model.layers.26.self_attention.proj.weight", "model.layers.26.layernorm_mlp.layer_norm_weight", "model.layers.26.layernorm_mlp.fc1_weight", "model.layers.26.layernorm_mlp.fc2_weight", "model.layers.27.self_attention.layernorm_qkv.layer_norm_weight", "model.layers.27.self_attention.layernorm_qkv.query_weight", "model.layers.27.self_attention.layernorm_qkv.key_weight", "model.layers.27.self_attention.layernorm_qkv.value_weight", "model.layers.27.self_attention.proj.weight", "model.layers.27.layernorm_mlp.layer_norm_weight", "model.layers.27.layernorm_mlp.fc1_weight", "model.layers.27.layernorm_mlp.fc2_weight", "model.layers.28.self_attention.layernorm_qkv.layer_norm_weight", "model.layers.28.self_attention.layernorm_qkv.query_weight", "model.layers.28.self_attention.layernorm_qkv.key_weight", "model.layers.28.self_attention.layernorm_qkv.value_weight", "model.layers.28.self_attention.proj.weight", "model.layers.28.layernorm_mlp.layer_norm_weight", "model.layers.28.layernorm_mlp.fc1_weight", "model.layers.28.layernorm_mlp.fc2_weight", "model.layers.29.self_attention.layernorm_qkv.layer_norm_weight", "model.layers.29.self_attention.layernorm_qkv.query_weight", "model.layers.29.self_attention.layernorm_qkv.key_weight", "model.layers.29.self_attention.layernorm_qkv.value_weight", "model.layers.29.self_attention.proj.weight", "model.layers.29.layernorm_mlp.layer_norm_weight", "model.layers.29.layernorm_mlp.fc1_weight", "model.layers.29.layernorm_mlp.fc2_weight", "model.layers.30.self_attention.layernorm_qkv.layer_norm_weight", "model.layers.30.self_attention.layernorm_qkv.query_weight", "model.layers.30.self_attention.layernorm_qkv.key_weight", "model.layers.30.self_attention.layernorm_qkv.value_weight", "model.layers.30.self_attention.proj.weight", "model.layers.30.layernorm_mlp.layer_norm_weight", "model.layers.30.layernorm_mlp.fc1_weight", "model.layers.30.layernorm_mlp.fc2_weight", "model.layers.31.self_attention.layernorm_qkv.layer_norm_weight", "model.layers.31.self_attention.layernorm_qkv.query_weight", "model.layers.31.self_attention.layernorm_qkv.key_weight", "model.layers.31.self_attention.layernorm_qkv.value_weight", "model.layers.31.self_attention.proj.weight", "model.layers.31.layernorm_mlp.layer_norm_weight", "model.layers.31.layernorm_mlp.fc1_weight", "model.layers.31.layernorm_mlp.fc2_weight". 
	size mismatch for model.norm.weight: copying a param with shape torch.Size([8192]) from checkpoint, the shape in current model is torch.Size([4096]).

In [None]:
device = torch.device("cuda")
model.decoder.resize_token_embeddings(len(tokenizer))
model = model.to(torch.bfloat16).to(device)

model.eval()

# print(f"input: {text},\n output: {res}")

# # output = model.generate(
# #     input_ids=batched_data['input_ids'],
# #     num_return_sequences=10,
# #     num_beams=20,
# # )
# for i in range(10):
#     print(tokenizer.decode(output[i]))


In [None]:
output = model.decoder.generate(
    input_ids=torch.tensor(tokenizer.encode("Football is a ", return_tensors="pt")).to(device),
    num_beams=5,
    max_new_tokens=512,
    num_return_sequences=1,
    return_dict_in_generate=True,
    output_scores=True,
    do_sample=True,
    top_p=0.95,
    repetition_penalty=1.5,
)
res = tokenizer.decode(output.sequences[0], skip_special_tokens=False)
print(res)

In [1]:
import torch
ckpt1 = torch.load("/data/peiran/blob/hai1data/sfm/llama/Meta-Llama-3-8B/original/layer_34-model_states.pt", map_location=torch.device("cpu"))

In [None]:
ckpt2 = torch.load("/data/peiran/blob/hai1data/sfm/nlm/output/llama3_stageA_tp2/backup/global_step1999/layer_34-model_00-model_states.pt", map_location=torch.device("cpu"))

In [None]:
print(ckpt1.keys(), ckpt2.keys())
# print(ckpt1["embed_tokens.weight"].shape)
print(ckpt2["word_embeddings.weight"].shape)

In [None]:

print(torch.sum(torch.abs(ckpt1["embed_tokens.weight"][65152:,:]-ckpt2["word_embeddings.weight"][:128256-65152, :])))

In [None]:
torch.set_printoptions(threshold=10000000)
print(torch.mean(torch.abs(ckpt1["embed_tokens.weight"][:65152,:].to(torch.bfloat16)-ckpt2["word_embeddings.weight"][:65152, :].to(torch.bfloat16))))

In [None]:
key1 = "self_attn.k_proj.weight"
key2 = "self_attention.layernorm_qkv.key_weight"
print(ckpt1[key1].shape, ckpt2[key2].shape)
print(torch.sum(torch.abs(ckpt1["self_attn.k_proj.weight"][:512,:]-ckpt2["self_attention.layernorm_qkv.key_weight"][:512,:])))

In [None]:
print(torch.sum(torch.abs(ckpt1["lm_head.weight"][:65152, :]-ckpt2["lm_head.weight"][:65152, :])))

In [14]:
ckpt3 = ckpt2
# print(ckpt1["lm_head.weight"].shape, ckpt1["lm_head.weight"].dtype)
ckpt3["lm_head.weight"][:, :] = ckpt1["lm_head.weight"][:65152, :]
print(ckpt3["lm_head.weight"][:128256-65152, :].shape, ckpt1["lm_head.weight"][65152:, :].shape)
ckpt3["lm_head.weight"][:128256-65152, :] = ckpt1["lm_head.weight"][65152:, :]
ckpt3["lm_head.weight"][128256-65152:, :] = ckpt2["lm_head.weight"][128256-65152:, :]
torch.save(ckpt3, "/data/peiran/blob/hai1data/sfm/nlm/output/llama3_stageA_tp2/backup/global_step1999/layer_34-model_01-model_states.pt")

torch.Size([63104, 4096]) torch.Size([63104, 4096])
