In [None]:
import sys
import os
from argparse import ArgumentParser

import transformers
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import torch
from sfm.models.scigpt.scigpt import ScigptModel
from sfm.models.scigpt.config import ScigptConfig
from sfm.utils import arg_utils
from sfm.utils.science_tokens import SCIENCE_TAG_TOKENS



In [None]:
IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"

# Directory the tokenizer is loaded from. get_args_and_tokenizer also stores it
# in args.save_dir and args.llm_model_name_or_path.
LLAMA_CKPT_DIR = "/data/peiran/blob/hai1data/sfm/llama/Meta-Llama-3-8B/original"


def get_args_and_tokenizer(use_llama=False, ckpt_dir=LLAMA_CKPT_DIR):
    """Build inference-time ScigptConfig args and a tokenizer with the needed special tokens.

    Args:
        use_llama: Currently unused; kept so existing calls keep working.
        ckpt_dir: Directory passed to ``AutoTokenizer.from_pretrained``. It is
            also written to ``args.save_dir`` and ``args.llm_model_name_or_path``.

    Returns:
        ``(args, tokenizer)``: the parsed ScigptConfig namespace with the
        overrides below applied, and the tokenizer after any missing
        pad/eos/bos/unk token has been registered.
    """
    parser = ArgumentParser()
    parser = arg_utils.add_dataclass_to_parser([ScigptConfig], parser)
    # Parse an explicit empty argv so sys.argv (inside a notebook, the kernel's
    # own launch flags) is never handed to the parser.
    args = parser.parse_args(args=[])

    # Inference-time overrides of the ScigptConfig defaults.
    args.load_ckpt = False
    args.strategy = "DDP"
    args.encoder_layers = 33
    args.encoder_embed_dim = 1280
    args.encoder_ffn_embed_dim = 5120
    args.encoder_attention_heads = 20
    args.infer = True
    args.bf16 = True

    tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
    args.save_dir = ckpt_dir
    args.llm_model_name_or_path = ckpt_dir

    # Register only the special tokens the tokenizer does not already define.
    # Any token added here increases len(tokenizer), so the model's embedding
    # matrix must be resized to match before use.
    fallback_tokens = {
        "pad_token": DEFAULT_PAD_TOKEN,
        "eos_token": DEFAULT_EOS_TOKEN,
        "bos_token": DEFAULT_BOS_TOKEN,
        "unk_token": DEFAULT_UNK_TOKEN,
    }
    special_tokens_dict = {
        name: token
        for name, token in fallback_tokens.items()
        if getattr(tokenizer, name) is None
    }
    # To also register the science tag tokens, uncomment:
    # special_tokens_dict["additional_special_tokens"] = SCIENCE_TAG_TOKENS
    tokenizer.add_special_tokens(special_tokens_dict)

    return args, tokenizer

# Build the inference config and the tokenizer (missing special tokens added),
# then report which concrete tokenizer class AutoTokenizer resolved to.
args, tokenizer = get_args_and_tokenizer(use_llama=False)
print(tokenizer.__class__)

In [None]:
# Assemble a ScigptModel state dict from per-layer checkpoint files in args.save_dir:
#   layer_00                 -> decoder.model.* (keys prefixed as-is)
#   layer_01 .. layer_{N}    -> decoder.model.layers.{0..N-1}.*
#   layer_{N+1}              -> decoder.model.norm.weight
#   layer_{N+2}              -> decoder.lm_head.weight
# An alternative layout (embed_tokens in layer_01, an adaptor in layer_02) is
# not handled by this cell.
NUM_DECODER_LAYERS = 32  # decoder blocks stored in layer_01 .. layer_32


def _load_shard(shard_idx):
    """Load ``layer_{shard_idx:02d}-model_states.pt`` from ``args.save_dir`` onto the CPU."""
    shard_path = os.path.join(args.save_dir, f"layer_{shard_idx:02d}-model_states.pt")
    return torch.load(shard_path, map_location=torch.device("cpu"))


ckpt_dict = {}

model = ScigptModel(args)
model_dict = model.state_dict()
print(f"model_dict: {len(model_dict)} keys")

# Shard 0: top-level decoder.model.* tensors (presumably embed_tokens — verify).
for k, v in _load_shard(0).items():
    ckpt_dict["decoder.model." + k] = v

for layer_idx in range(NUM_DECODER_LAYERS):
    shard = _load_shard(layer_idx + 1)
    for k, v in shard.items():
        # Skip placeholder and rotary-embedding entries; presumably they have
        # no counterpart in the model's state_dict (strict load would reject them).
        if "dummy" in k or "rotary_emb" in k:
            continue
        ckpt_dict[f"decoder.model.layers.{layer_idx}.{k}"] = v
    del shard

ckpt_dict["decoder.model.norm.weight"] = _load_shard(NUM_DECODER_LAYERS + 1)["norm.weight"]
ckpt_dict["decoder.lm_head.weight"] = _load_shard(NUM_DECODER_LAYERS + 2)["lm_head.weight"]

print(f"ckpt_dict: {len(ckpt_dict)} keys")

# After model_dict.update(), any model key the checkpoint does not provide
# silently keeps the value the model was constructed with, so surface them.
missing_keys = sorted(set(model_dict) - set(ckpt_dict))
if missing_keys:
    print(f"WARNING: {len(missing_keys)} model keys not in checkpoint (left at init): {missing_keys}")

model_dict.update(ckpt_dict)
# Strict load: checkpoint keys the model does not define still raise here.
model.load_state_dict(model_dict)



In [None]:
# Put the converted model on the GPU (CPU fallback) in bf16 for inference.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# get_args_and_tokenizer may have added special tokens (e.g. a pad token), so
# grow the embedding matrix to the tokenizer's current length.
model.decoder.resize_token_embeddings(len(tokenizer))
model = model.to(device=device, dtype=torch.bfloat16)

# Trailing ';' suppresses the very long module repr Jupyter would otherwise
# render for the returned model.
model.eval();


In [None]:
# Sample one continuation of a fixed prompt. num_beams=5 with do_sample=True
# selects beam-sample decoding in Hugging Face `generate`.
GENERATION_SEED = 0
torch.manual_seed(GENERATION_SEED)  # make the sampled text reproducible on re-run

# Calling the tokenizer returns input_ids *and* attention_mask as tensors.
# Wrapping an existing tensor in torch.tensor(...) only made a redundant copy
# and raised a copy-construct UserWarning.
prompt_inputs = tokenizer("Football is a ", return_tensors="pt").to(device)
output = model.decoder.generate(
    input_ids=prompt_inputs["input_ids"],
    attention_mask=prompt_inputs["attention_mask"],
    num_beams=5,
    max_new_tokens=512,
    num_return_sequences=1,
    return_dict_in_generate=True,  # needed for output.sequences below
    # output_scores is omitted: the per-step (batch*num_beams, vocab_size)
    # score tensors were never read.
    do_sample=True,
    top_p=0.95,
    repetition_penalty=1.5,
)
res = tokenizer.decode(output.sequences[0], skip_special_tokens=False)
print(res)

In [None]:
# Inspect one shard of a stageB training checkpoint. map_location keeps the
# tensors on the CPU (as in the conversion cell) instead of restoring them onto
# whichever GPU they were saved from.
ckpt = torch.load(
    "/data/peiran/blob/hai1data/sfm/pfmexp/output/stageB/global_step12386/layer_01-model_states.pt",
    map_location=torch.device("cpu"),
)

In [None]:
import numpy as np

# Memory-map both training shards read-only; contents are read from disk on
# access rather than loaded eagerly.
data, data2 = (
    np.load(npy_path, mmap_mode="r")
    for npy_path in (
        "/data/peiran/v5_train/train1.npy",
        "/data/peiran/v5_train/train2.npy",
    )
)


In [None]:
print(*(arr.shape for arr in (data, data2)))  # shapes of the two training shards


In [None]:
# Concatenate the two training shards along the first axis.
# NOTE(review): this rebinds `data`, so re-running the cell appends data2 a
# second time. np.concatenate also copies both memory-mapped shards into a new
# in-memory array. Consider binding the result to a new name instead.
data = np.concatenate((data, data2))
print(data.shape)