In [7]:
from sfm.data.sci_data.SFMDecTokenizer import SFMDecTokenizer
import torch
import os
import re
from transformers import AutoTokenizer, AutoModelForCausalLM
from copy import deepcopy

  from .autonotebook import tqdm as notebook_tqdm


[2024-05-07 09:15:16,872] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [8]:
def show_ckpt(name, ckpt):
    """Print one line per non-dummy checkpoint entry: ``name``, the key, and its shape.

    Args:
        name: Label printed first on every line (e.g. a layer index).
        ckpt: Mapping of parameter name -> tensor-like object exposing ``.shape``.
            Keys containing the substring ``'dummy'`` are skipped.
    """
    real_entries = ((key, tensor) for key, tensor in ckpt.items() if 'dummy' not in key)
    for key, tensor in real_entries:
        print(name, key, tensor.shape)

def process_protein(output):
    """Extract the first ``<protein>...</protein>`` payload from decoded model text.

    ``<a>`` markers and all spaces are removed from the payload, then surrounding
    whitespace is stripped. ``.`` in the pattern does not match newlines, so a payload
    that spans lines is not found.

    Args:
        output: Decoded text that may contain a protein span.

    Returns:
        The cleaned sequence string, or ``None`` when no closing tag / complete span exists.
    """
    if '</protein>' in output:
        match = re.search(r'<protein>(.*?)</protein>', output)
        if match is not None:
            return match.group(1).replace('<a>', '').replace(' ', '').strip()
    return None

In [9]:
# SFM decoder tokenizer: the LLaMA-2 tokenizer directory plus separate protein / DNA / RNA
# tokenizer models. from_pretrained warns that the checkpoint class is LlamaTokenizer; that is expected here.
tokenizer_home = '/hai1/mfm/ds_dataset/llama2/llama-2-7b'
spm_paths = dict(
    prot_spm_path='/blob/shufxi/data/scigpt/ur50bpe/bpe',
    dna_spm_path='/blob/shufxi/data/scigpt/dnabpe/bpe',
    rna_spm_path='/blob/shufxi/data/scigpt/rnabpe/bpe',
)
tokenizer = SFMDecTokenizer.from_pretrained(tokenizer_home, **spm_paths)
print(len(tokenizer))

# Sanity check: tokenize a tagged protein sequence and add BOS/EOS tokens by hand.
seq = "<protein>MAQVAFGRILGVDNAANPRAGRPQPGSAGDAEDQILDLILGVVI</protein>"
tokens = ['<s>', *tokenizer._tokenize(seq), '</s>']
print(tokens)
print(tokenizer.convert_tokens_to_ids(tokens))


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'LlamaTokenizer'. 
The class this function is called from is 'SFMDecTokenizer'.


40014
['<protein>', '▁A', 'AB', 'B', 'CC', 'DD', '</protein>']
32000


In [10]:
# Stock LLaMA-2 tokenizer from the same directory; used below as the baseline for token-count comparisons.
llama_tokenizer = AutoTokenizer.from_pretrained(tokenizer_home)
print(len(llama_tokenizer))

Loading checkpoint shards: 100%|██████████| 2/2 [00:29<00:00, 14.82s/it]


In [5]:
# Base LLaMA-2-7B weights (32000-row vocabulary). Deep-copied into `test_model` later and
# also kept as the un-finetuned reference for weight diffs and baseline generation.
model = AutoModelForCausalLM.from_pretrained(tokenizer_home)

Loading checkpoint shards: 100%|██████████| 2/2 [00:28<00:00, 14.38s/it]


In [11]:
# Checkpoint directory to evaluate. It holds one layer_XX-model_states.pt file per pipeline
# layer (presumably a DeepSpeed pipeline checkpoint). Previously evaluated alternatives:
#   /hai1/mfm/shufxi/scigpt/7bv3/stageA_200k/global_step140999/
#   /hai1/mfm/shufxi/scigpt/7bv3/stageA_200k/global_step999/
#   /hai1/mfm/shufxi/scigpt/7bv3/stageA_prot_e10_bs256/global_step19999/   (full finetune)
# Current choice: embedding-only finetune, initialised from the LLaMA embeddings.
ckpt_home = "/hai1/mfm/shufxi/scigpt/7bv3/stageA_prot_e10_bs512_emb_8xG8H100/global_step5781/"

In [12]:
# Rebuild a HF LlamaForCausalLM state dict from the per-layer checkpoint files in `ckpt_home`.
# File layout as read below: layer_00 = token embedding, layer_01..layer_N = the N decoder
# layers, layer_{N+1} = final norm, layer_{N+2} = lm_head.
test_model = deepcopy(model)  # keep `model` untouched as the base-LLaMA reference
model_dict = test_model.state_dict()
ckpt_dict = {}

# Derive the decoder depth from the config instead of hard-coding 32 / 33 / 34.
num_layers = test_model.config.num_hidden_layers
cpu = torch.device("cpu")


def _load_pipeline_file(idx):
    """Load ``layer_{idx:02d}-model_states.pt`` from ``ckpt_home`` onto the CPU."""
    return torch.load(os.path.join(ckpt_home, f"layer_{idx:02d}-model_states.pt"), map_location=cpu)


embed_states = _load_pipeline_file(0)
show_ckpt('layer0', embed_states)
# Full extended-vocab embedding; slice [:32000] here to keep only the original LLaMA rows.
ckpt_dict['model.embed_tokens.weight'] = embed_states['embed_tokens.weight']

for layer_idx in range(num_layers):
    file_idx = layer_idx + 1  # file 00 is the embedding, so decoder layer i lives in file i+1
    layer_states = _load_pipeline_file(file_idx)
    show_ckpt(f"{file_idx:02d}", layer_states)
    for key, tensor in layer_states.items():
        # Drop placeholder and rotary-embedding entries so only the model's own weights are
        # merged (NOTE(review): presumably rotary buffers are rebuilt by the HF model).
        if "dummy" in key or 'rotary_emb' in key:
            continue
        ckpt_dict[f"model.layers.{layer_idx}.{key}"] = tensor

norm_states = _load_pipeline_file(num_layers + 1)
show_ckpt(num_layers + 1, norm_states)
ckpt_dict["model.norm.weight"] = norm_states["norm.weight"]

head_states = _load_pipeline_file(num_layers + 2)
show_ckpt(num_layers + 2, head_states)
# Full extended-vocab output head; slice [:32000] here to keep only the original LLaMA rows.
ckpt_dict["lm_head.weight"] = head_states["lm_head.weight"]
model_dict.update(ckpt_dict)

# Grow embed_tokens / lm_head to the SFM vocabulary size so the larger checkpoint tensors fit.
test_model.resize_token_embeddings(len(tokenizer))
test_model.load_state_dict(model_dict)
test_model = test_model.cuda()
test_model.eval()

layer0 embed_tokens.weight torch.Size([40014, 4096])
01 self_attn.q_proj.weight torch.Size([4096, 4096])
01 self_attn.k_proj.weight torch.Size([4096, 4096])
01 self_attn.v_proj.weight torch.Size([4096, 4096])
01 self_attn.o_proj.weight torch.Size([4096, 4096])
01 mlp.gate_proj.weight torch.Size([11008, 4096])
01 mlp.up_proj.weight torch.Size([11008, 4096])
01 mlp.down_proj.weight torch.Size([4096, 11008])
01 input_layernorm.weight torch.Size([4096])
01 post_attention_layernorm.weight torch.Size([4096])
02 self_attn.q_proj.weight torch.Size([4096, 4096])
02 self_attn.k_proj.weight torch.Size([4096, 4096])
02 self_attn.v_proj.weight torch.Size([4096, 4096])
02 self_attn.o_proj.weight torch.Size([4096, 4096])
02 mlp.gate_proj.weight torch.Size([11008, 4096])
02 mlp.up_proj.weight torch.Size([11008, 4096])
02 mlp.down_proj.weight torch.Size([4096, 11008])
02 input_layernorm.weight torch.Size([4096])
02 post_attention_layernorm.weight torch.Size([4096])
03 self_attn.q_proj.weight torch.Size

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(40014, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_

In [13]:
# Sum of absolute differences between base LLaMA weights and the loaded checkpoint.
# Vocab-dependent tensors are cut to their first 32000 rows (the base model's vocabulary).
base_state = model.state_dict()
loaded_state = test_model.state_dict()
for param_name, rows in (
    ('model.embed_tokens.weight', slice(0, 32000)),
    ('model.layers.10.self_attn.k_proj.weight', slice(None)),
    ('lm_head.weight', slice(0, 32000)),
):
    print(torch.sum(torch.abs(base_state[param_name] - loaded_state[param_name][rows].cpu())))

tensor(0.0063)
tensor(6.8878e-05)
tensor(0.0007)


In [14]:
# Shapes of the vocab-dependent tensors in, in order: the base model, the raw checkpoint
# dict, and the resized test model.
for state in (model.state_dict(), ckpt_dict, test_model.state_dict()):
    for param_name in ('model.embed_tokens.weight', 'lm_head.weight'):
        print(state[param_name].shape)

torch.Size([32000, 4096])
torch.Size([32000, 4096])
torch.Size([40014, 4096])
torch.Size([40014, 4096])
torch.Size([40014, 4096])
torch.Size([40014, 4096])


In [None]:
# Perplexity of the loaded checkpoint on two short English sentences, padded into one batch.
encodings = tokenizer(["An apple a day", "An apple a day keeps the doctor away."], padding=True, return_tensors='pt')
input_ids = encodings.input_ids.cuda()
# Pass the padding mask so pad positions are not attended to.
attention_mask = encodings.get("attention_mask")
if attention_mask is not None:
    attention_mask = attention_mask.cuda()
# Labels of -100 are ignored by the HF loss, so padding does not count toward the NLL.
# (NOTE(review): if pad_token_id coincides with a real token id, those are masked too.)
target_ids = input_ids.clone()
target_ids[target_ids == tokenizer.pad_token_id] = -100
with torch.no_grad():
    # Previously labels=input_ids was passed, which silently scored the pad tokens.
    outputs = test_model(input_ids, attention_mask=attention_mask, labels=target_ids)
    neg_log_likelihood = outputs.loss
    perplexity = torch.exp(neg_log_likelihood)
print(perplexity.item())

In [None]:
# Beam-search continuation of a plain-English prompt with the loaded checkpoint;
# prints every one of the 4 returned beams.
prompt_ids = tokenizer('An apple a day', return_tensors="pt").input_ids.cuda()
beam_output = test_model.generate(
    prompt_ids,
    num_beams=4,
    max_new_tokens=100,
    num_return_sequences=4,
    return_dict_in_generate=True,
)

for sequence in beam_output.sequences:
    print(tokenizer.decode(sequence))
    # print(s, output.sequences_scores[i].item())

In [None]:
# Baseline: the same beam search with the original (non-finetuned) LLaMA-2 model.
model.eval()
model = model.cuda()

input_ids = tokenizer('An apple a day', return_tensors="pt").input_ids.cuda()
# NOTE: `model` keeps the 32000-row vocabulary. Prompts that tokenize to SFM-only ids
# (presumably e.g. '<protein>...') would index past its embedding table.
# LlamaForCausalLM has no `.decoder` child; generate() is a method of the model itself.
output = model.generate(
    input_ids,
    num_beams=4,
    max_new_tokens=100,
    num_return_sequences=4,
    return_dict_in_generate=True,
    do_sample=False,
)

# Print all requested beams (consistent with the finetuned-model cell).
for sequence in output.sequences:
    print(tokenizer.decode(sequence, skip_special_tokens=False))

In [None]:
# Held-out UniRef90 validation sequences, one per line, with surrounding whitespace stripped.
with open("/blob/renqian/data/sfm/ur90/valid.uniref90.shuf.10k", "r") as valid_file:
    lines = [raw_line.strip() for raw_line in valid_file]

In [None]:
# Token counts per validation sequence: stock LLaMA tokenizer vs the SFM tokenizer.
# Counts include any special tokens (e.g. BOS) the tokenizer adds by default.
def _num_tokens(tok, text):
    """Return how many ids ``tok`` emits for ``text`` as a single-row batch."""
    return tok(text, return_tensors="pt").input_ids.shape[1]


llama_lengths = [_num_tokens(llama_tokenizer, line) for line in lines]
sfm_lengths = [_num_tokens(tokenizer, line) for line in lines]
print(f"llama: max {max(llama_lengths)}, min {min(llama_lengths)}, avg {sum(llama_lengths) / len(llama_lengths)}")
print(f"sfm: max {max(sfm_lengths)}, min {min(sfm_lengths)}, avg {sum(sfm_lengths) / len(sfm_lengths)}")

In [21]:
# Sample a continuation from the loaded checkpoint given a text prompt plus a protein prefix.
# Seed the RNG (torch.manual_seed seeds all devices) so the sampled output is reproducible
# on re-run; change the seed to draw a different sample.
torch.manual_seed(0)
input_ids = tokenizer('Generate a protein: <protein>AA', return_tensors="pt").input_ids.cuda()
output = test_model.generate(
    input_ids,
    num_beams=1,
    max_new_tokens=100,
    num_return_sequences=1,
    return_dict_in_generate=True,
    do_sample=True,
    temperature=0.7,
)

res = tokenizer.decode(output.sequences[0], skip_special_tokens=False)
print(res)

<s>Generate a protein: <protein> AAARA:


### Step 1:


### Step 2:


### Step 3:


### Step 4:


### Step 5:


### Step 6:


### Step 7:


### Step 8:


### Step 9:


### Step 10:


### Step
