In [None]:
from sfm.data.sci_data.SFMDecTokenizer import SFMDecTokenizer
import torch
import os
import re
from transformers import AutoTokenizer, AutoModelForCausalLM
from copy import deepcopy

In [None]:
def show_ckpt(name, ckpt):
    """Print ``name``, key and ``.shape`` for every entry of ``ckpt`` whose key lacks 'dummy'.

    Args:
        name: Label printed in front of each entry (any printable object).
        ckpt: Mapping from parameter name to an object exposing ``.shape``.
    """
    for key in ckpt:
        # Entries whose key contains 'dummy' are left out of the listing.
        if 'dummy' in key:
            continue
        print(name, key, ckpt[key].shape)

def process_protein(output):
    """Extract the first ``<protein>...</protein>`` payload from ``output``.

    Every ``<a>`` marker and every space character is removed from the payload,
    and the result is stripped of leading/trailing whitespace.

    Returns:
        The cleaned sequence string, or ``None`` when ``output`` contains no
        closing tag or no complete tag pair on a single line ('.' does not
        match a newline in the pattern).
    """
    if '</protein>' in output:
        match = re.search(r'<protein>(.*?)</protein>', output)
        if match is not None:
            residues = match.group(1).replace('<a>', '').replace(' ', '')
            return residues.strip()
    return None

In [None]:
# Build the SFM tokenizer (LLaMA tokenizer files plus protein/DNA/RNA SentencePiece
# models) and the plain LLaMA tokenizer from the same directory, printing the size
# of each so the two vocabularies can be compared.
tokenizer_home = '/hai1/mfm/ds_dataset/llama2/llama-2-7b'
spm_paths = {
    'prot_spm_path': '/blob/shufxi/data/scigpt/ur50bpe/bpe',
    'dna_spm_path': '/blob/shufxi/data/scigpt/dnabpe/bpe',
    'rna_spm_path': '/blob/shufxi/data/scigpt/rnabpe/bpe',
}
tokenizer = SFMDecTokenizer.from_pretrained(tokenizer_home, **spm_paths)
print(len(tokenizer))
# Quick look at how a tagged protein string is split.
print(tokenizer.tokenize('<protein>AABBCCDD</protein>'))

llama_tokenizer = AutoTokenizer.from_pretrained(tokenizer_home)
print(len(llama_tokenizer))

In [None]:
# Load the reference causal-LM weights from `tokenizer_home` (the same directory the
# tokenizers are read from). Other cells deep-copy this model before overwriting its
# weights and compare the converted weights against this model's state dict.
model = AutoModelForCausalLM.from_pretrained(tokenizer_home)

In [None]:
# Directory holding the per-layer `layer_XX-model_states.pt` files that the
# conversion cell reads. Exactly one `ckpt_home` assignment should be active;
# the commented-out lines are other runs kept for quick switching.
# ckpt_home = r'/hai1/mfm/shufxi/scigpt/7bv3/stageA_200k/global_step140999/'
# ckpt_home = r'/hai1/mfm/shufxi/scigpt/7bv3/stageA_200k/global_step999/'
# ckpt_home = r'/hai1/mfm/shufxi/scigpt/7bv3/stageA_prot_e10_bs256/global_step19999/' # full finetune
ckpt_home = r"/hai1/mfm/shufxi/scigpt/7bv3/stageA_prot_e10_bs512_emb_8xG8H100/global_step5781/" # emb finetune, load llama emb

In [None]:
# Convert the per-layer checkpoint under `ckpt_home` into a Hugging Face state dict
# and load it into a copy of the reference model.
#
# File layout read here: file 00 holds the token embedding, files 01..N hold the
# N decoder layers, file N+1 the final norm and file N+2 the LM head. N is taken
# from the model config instead of being hard-coded (32 for LLaMA-2 7B).


def _load_layer_states(layer_index):
    """Load ``layer_<NN>-model_states.pt`` (NN = zero-padded index) from ``ckpt_home`` onto the CPU."""
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
    # only point `ckpt_home` at checkpoints you trust.
    path = os.path.join(ckpt_home, f"layer_{str(layer_index).zfill(2)}-model_states.pt")
    return torch.load(path, map_location=torch.device("cpu"))


# Work on a copy so the reference `model` stays untouched for later comparisons.
test_model = deepcopy(model)
model_dict = test_model.state_dict()
ckpt_dict = {}

num_layers = test_model.config.num_hidden_layers
norm_file_index = num_layers + 1
head_file_index = num_layers + 2

# Token embedding. Append [:32000] to the tensor to keep only the first 32000 rows.
layer0 = _load_layer_states(0)
ckpt_dict['model.embed_tokens.weight'] = layer0['embed_tokens.weight']
show_ckpt('layer0', layer0)

# Decoder layers: checkpoint file i+1 maps onto `model.layers.i.*`.
for layer_idx in range(num_layers):
    file_tag = str(layer_idx + 1).zfill(2)
    layer = _load_layer_states(layer_idx + 1)
    show_ckpt(file_tag, layer)
    for key, value in layer.items():
        # 'dummy' and rotary-embedding entries are not copied into the HF state dict.
        if "dummy" in key or 'rotary_emb' in key:
            continue
        ckpt_dict[f"model.layers.{layer_idx}.{key}"] = value

# Final norm.
layer = _load_layer_states(norm_file_index)
show_ckpt(norm_file_index, layer)
ckpt_dict["model.norm.weight"] = layer["norm.weight"]

# LM head. Append [:32000] to the tensor to keep only the first 32000 rows.
layer = _load_layer_states(head_file_index)
show_ckpt(head_file_index, layer)
ckpt_dict["lm_head.weight"] = layer["lm_head.weight"]

# Overlay the checkpoint tensors on the model's own entries so keys absent from the
# checkpoint keep their current values.
model_dict.update(ckpt_dict)

# Resize the embedding / LM head to the tokenizer's vocabulary BEFORE loading, so the
# module shapes match the tensors placed in `model_dict`.
test_model.resize_token_embeddings(len(tokenizer))
test_model.load_state_dict(model_dict)
test_model = test_model.cuda()
test_model.eval()

In [None]:
# Sum of absolute differences between the reference weights and the converted ones.
# The embedding and LM head are compared on the first 32000 rows only; the attention
# projection is compared in full.
base_state = model.state_dict()
converted_state = test_model.state_dict()
for key, rows in (
    ('model.embed_tokens.weight', slice(0, 32000)),
    ('model.layers.10.self_attn.k_proj.weight', slice(None)),
    ('lm_head.weight', slice(0, 32000)),
):
    print(torch.sum(torch.abs(base_state[key] - converted_state[key][rows].cpu())))

In [None]:
# Embedding and LM-head shapes, printed in order for: the reference model, the raw
# checkpoint tensors, and the converted model.
for state in (model.state_dict(), ckpt_dict, test_model.state_dict()):
    for key in ('model.embed_tokens.weight', 'lm_head.weight'):
        print(state[key].shape)

In [None]:
# Perplexity of the converted model on a small padded batch.
# To score the reference model instead, use: test_model = model.cuda()
encodings = tokenizer(["An apple a day", "An apple a day keeps the doctor away."], padding=True, return_tensors='pt')
input_ids = encodings.input_ids.cuda()
attention_mask = encodings.attention_mask.cuda()

# Labels are the inputs with padding positions set to -100, the index the
# Hugging Face causal-LM loss ignores, so padding does not contribute to the loss.
target_ids = input_ids.clone()
target_ids[target_ids == tokenizer.pad_token_id] = -100

with torch.no_grad():
    # Pass the masked labels (not the raw input_ids) and forward the attention mask
    # so padded positions are neither scored nor attended to.
    outputs = test_model(input_ids, attention_mask=attention_mask, labels=target_ids)
    neg_log_likelihood = outputs.loss
    perplexity = torch.exp(neg_log_likelihood)
print(perplexity.item())

In [None]:
# Beam-search continuation of a short English prompt with the converted model.
input_ids = tokenizer('An apple a day', return_tensors="pt").input_ids.cuda()
# Alternative prompt: tokenizer('<protein>', return_tensors="pt").input_ids.cuda()

beam_search_kwargs = dict(
    num_beams=4,
    max_new_tokens=100,
    num_return_sequences=4,
    return_dict_in_generate=True,
)
# Other settings tried previously (add to the dict above to re-enable):
#   output_scores=True, do_sample=True, top_p=0.95,
#   repetition_penalty=1.2 or 1.5, num_beams=5, max_new_tokens=512,
#   num_return_sequences=1
output = test_model.generate(input_ids, **beam_search_kwargs)

# Decode and print every returned sequence.
# With output_scores=True, output.sequences_scores[i].item() can be printed alongside.
for sequence in output.sequences:
    decoded = tokenizer.decode(sequence)
    print(decoded)

In [None]:
# Build a ScigptModel directly from an in-notebook argument namespace.
from argparse import Namespace

from sfm.data.sci_data.dataset import ProcessedSciDataset
from sfm.data.sci_data.SFMDecTokenizer import SFMDecTokenizer
from sfm.logging import logger
from sfm.models.scigpt.config import (
    ScigptConfig,
    scigpt_7b_config,
    scigpt_13b_config,
    scigpt_350m_config,
    scigpt_shallow_config,
    scigpt_tiny_config,
)
from sfm.models.scigpt.scigpt import ScigptModel
from sfm.pipeline.accelerator.trainer import Trainer
from sfm.utils import arg_utils
from sfm.utils.cli_utils import cli

# Maps a `model_type` string to the function that applies that preset to a config.
config_registry = {
    "scigpt_tiny": scigpt_tiny_config,
    "scigpt_shallow": scigpt_shallow_config,
    "scigpt_350m": scigpt_350m_config,
    "scigpt": scigpt_shallow_config,
    "scigpt_7b": scigpt_7b_config,
    "scigpt_13b": scigpt_13b_config,
}

# Single source for the path used as pretrained checkpoint and as LLM name/path.
llama_home = "/hai1/mfm/ds_dataset/llama2/llama-2-7b/"

# A plain Namespace is the attribute bag argparse itself produces; an ArgumentParser
# instance carries unrelated parser attributes and is not meant to hold argument values.
args = Namespace(
    # Must be a key of `config_registry`. The previous value "scigpt_7b_config" is the
    # preset function's name, not a registry key, so the lookup silently fell back to
    # the tiny preset.
    model_type="scigpt_7b",
    vocab_size=40014,
    pad_token_id=32000,
    max_position_embeddings=4096,
    bf16=True,
    strategy="Pipeline",
    pipeline_model_parallel_size=1,
    pp_partition_layer_name="LlamaDecoderLayerPP",
    load_ckpt=True,
    pretrained_ckpt_path=llama_home,
    unfreeze_param_list="lm_head.weight,embed_tokens.weight",
    learnable_cutoff=32000,
    infer=True,
    llm_model_name_or_path=llama_home,
)

config = arg_utils.from_args(args, ScigptConfig)

# Keep the tiny preset as a fallback for unknown model types, but make the fallback
# visible instead of silent.
apply_preset = config_registry.get(config.model_type)
if apply_preset is None:
    print(f"WARNING: unknown model_type {config.model_type!r}; falling back to scigpt_tiny_config")
    apply_preset = scigpt_tiny_config
config = apply_preset(config)
config.llm_model_name_or_path = llama_home

# NOTE: this rebinds `model`, replacing the Hugging Face reference model loaded in an
# earlier cell; cells that compare against that reference must run before this one.
model = ScigptModel(config)

In [None]:
# Rebuild the SFM tokenizer (LLaMA tokenizer files plus protein/DNA/RNA SentencePiece
# models) and print its size along with two sample tokenizations.
tokenizer_home = '/hai1/mfm/ds_dataset/llama2/llama-2-7b'
spm_paths = {
    'prot_spm_path': '/blob/shufxi/data/scigpt/ur50bpe/bpe',
    'dna_spm_path': '/blob/shufxi/data/scigpt/dnabpe/bpe',
    'rna_spm_path': '/blob/shufxi/data/scigpt/rnabpe/bpe',
}
tokenizer = SFMDecTokenizer.from_pretrained(tokenizer_home, **spm_paths)
print(len(tokenizer))
for sample_text in (
    '<protein>AABBCCDD</protein>',
    # Plain text with a double space, single newlines and a blank line.
    'An apple  a day\nYes\nThis is\n\nan applenade\n',
):
    print(tokenizer.tokenize(sample_text))

In [None]:
# Beam-search generation through the SciGPT model's decoder.
model.eval()
model = model.cuda()

prompt_ids = tokenizer('An apple a day', return_tensors="pt")
input_ids = prompt_ids.input_ids.cuda()
# Alternative prompt: tokenizer('<protein>AA', return_tensors="pt").input_ids.cuda()

generate_kwargs = dict(
    num_beams=4,
    max_new_tokens=100,
    num_return_sequences=4,
    return_dict_in_generate=True,
    do_sample=False,
)
output = model.decoder.generate(input_ids, **generate_kwargs)

# Only the first returned sequence is shown, with special tokens kept.
first_sequence = output.sequences[0]
res = tokenizer.decode(first_sequence, skip_special_tokens=False)
print(res)

In [None]:
# Read the validation file into a list of whitespace-stripped lines.
with open("/blob/renqian/data/sfm/ur90/valid.uniref90.shuf.10k", "r") as f:
    lines = [raw_line.strip() for raw_line in f]

In [None]:
# Compare tokenized sequence lengths: plain LLaMA tokenizer vs. SFM tokenizer.
llama_lengths, sfm_lengths = [], []
for line in lines:
    # Same order per line as before: LLaMA tokenizer first, then the SFM tokenizer.
    for tok, lengths in ((llama_tokenizer, llama_lengths), (tokenizer, sfm_lengths)):
        input_ids = tok(line, return_tensors="pt").input_ids
        # Dimension 1 of the returned batch is the token count.
        lengths.append(input_ids.shape[1])
print(f"llama: max {max(llama_lengths)}, min {min(llama_lengths)}, avg {sum(llama_lengths) / len(llama_lengths)}")
print(f"sfm: max {max(sfm_lengths)}, min {min(sfm_lengths)}, avg {sum(sfm_lengths) / len(sfm_lengths)}")