In [10]:
from sfm.data.sci_data.SFMDecTokenizer import SFMDecTokenizer
import torch
import os
import re
from transformers import AutoTokenizer, AutoModelForCausalLM
from copy import deepcopy

In [12]:
def show_ckpt(name, ckpt):
    """Print one line per entry of ``ckpt``: ``name``, the key, and the value's ``.shape``.

    Entries whose key contains the substring 'dummy' are skipped.

    Args:
        name: Label printed at the start of every line (any printable object).
        ckpt: Mapping from parameter names to objects exposing ``.shape``.
    """
    visible = ((key, tensor) for key, tensor in ckpt.items() if 'dummy' not in key)
    for key, tensor in visible:
        print(name, key, tensor.shape)

def process_protein(output):
    """Extract the first ``<protein>...</protein>`` span from ``output``.

    Every ``<a>`` marker and every space character is removed from the
    captured text, and the result is stripped of surrounding whitespace.

    Args:
        output: Text to search.

    Returns:
        The cleaned sequence string, or ``None`` when no complete
        ``<protein>...</protein>`` pair is found. The pattern's ``.`` does not
        match newlines, so a pair with a line break in between is not found.
    """
    if '</protein>' not in output:
        return None

    found = re.search(r'<protein>(.*?)</protein>', output)
    if not found:
        return None

    sequence = found.group(1)
    for marker in ('<a>', ' '):
        sequence = sequence.replace(marker, '')
    return sequence.strip()

In [14]:
# Load two tokenizers from the same base directory and print their sizes:
# the project's SFMDecTokenizer (given protein / DNA / RNA spm paths) and the
# tokenizer that AutoTokenizer resolves for that directory.
tokenizer_home = '/hai1/mfm/ds_dataset/llama2/llama-2-7b'
spm_paths = dict(
    prot_spm_path='/blob/shufxi/data/scigpt/ur50bpe/bpe',
    dna_spm_path='/blob/shufxi/data/scigpt/dnabpe/bpe',
    rna_spm_path='/blob/shufxi/data/scigpt/rnabpe/bpe',
)
tokenizer = SFMDecTokenizer.from_pretrained(tokenizer_home, **spm_paths)
print(len(tokenizer))
# Quick check of how a tagged protein string is split.
print(tokenizer.tokenize('<protein>AABBCCDD</protein>'))

llama_tokenizer = AutoTokenizer.from_pretrained(tokenizer_home)
print(len(llama_tokenizer))

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'LlamaTokenizer'. 
The class this function is called from is 'SFMDecTokenizer'.


40014
['<protein>', '▁A', 'AB', 'B', 'CC', 'DD', '</protein>']
32000


In [4]:
# Load the base causal-LM from `tokenizer_home`, the same directory the
# tokenizers are loaded from. Later cells copy this model and compare the
# checkpoint weights against it.
model = AutoModelForCausalLM.from_pretrained(tokenizer_home)

Loading checkpoint shards: 100%|██████████| 2/2 [00:28<00:00, 14.36s/it]


In [15]:
# Directory holding the per-layer checkpoint files (`layer_XX-model_states.pt`)
# that the loading cell reads. Exactly one assignment should be active; the
# commented-out lines are other checkpoints that were tried.
# ckpt_home = '/hai1/shufxi/scigpt/7bv2/stageB/global_step26999/'
# ckpt_home = r'/home/yinxia/hai1/shufxi/scigpt/7bv3/stageB/global_step32999'
# ckpt_home = r'/home/yinxia/blob1.v2/shufxi/scigpt/7bv3/inst/20240227121523/global_step3585'
# ckpt_home = r'/home/yinxia//blob1.v2/shufxi/scigpt/7bv3/inst/20240301131100/global_step4753'
# ckpt_home = r'/home/yinxia/hai1/shufxi/scigpt/7bv3/stageB.prot/global_step224655'

# ckpt_home = r'/hai1/mfm/shufxi/scigpt/7bv3/stageA_200k/global_step140999/'
# ckpt_home = r'/hai1/mfm/shufxi/scigpt/7bv3/stageA_200k/global_step999/'
ckpt_home = r'/hai1/mfm/shufxi/scigpt/7bv3/stageA_prot_e10_bs256_emb/global_step29999/'

In [16]:
# Rebuild a state dict, keyed the way `test_model.state_dict()` is keyed, from
# the per-layer checkpoint files in `ckpt_home`, and load it into a copy of the
# base model. File layout read below:
#   layer_00            -> token embeddings
#   layer_01 .. layer_N -> decoder layers 0 .. N-1
#   layer_(N+1)         -> final norm
#   layer_(N+2)         -> lm_head
test_model = deepcopy(model)
model_dict = test_model.state_dict()
ckpt_dict = {}

# Take the sizes from the target model instead of hard-coding 32 layers and
# 32000 vocabulary rows, so the cell follows whatever base model is loaded.
num_layers = test_model.config.num_hidden_layers
embed_rows = model_dict['model.embed_tokens.weight'].shape[0]
lm_head_rows = model_dict['lm_head.weight'].shape[0]


def load_layer_file(index):
    """Load ``layer_<index>-model_states.pt`` (index zero-padded to two digits) onto the CPU.

    The file is looked up inside the notebook-level ``ckpt_home`` directory.

    NOTE(review): ``torch.load`` unpickles the file, so only point
    ``ckpt_home`` at checkpoints from a trusted source.
    """
    path = os.path.join(ckpt_home, f"layer_{index:02d}-model_states.pt")
    return torch.load(path, map_location=torch.device("cpu"))


layer0 = load_layer_file(0)
# The checkpoint embedding table can have more rows than the target model
# (40014 vs 32000 in the recorded output); keep only the rows that fit.
ckpt_dict['model.embed_tokens.weight'] = layer0['embed_tokens.weight'][:embed_rows]
show_ckpt('layer0', layer0)

for layer_idx in range(num_layers):
    # Decoder layer `layer_idx` is stored in file number `layer_idx + 1`.
    file_tag = str(layer_idx + 1).zfill(2)
    layer = load_layer_file(layer_idx + 1)
    show_ckpt(file_tag, layer)
    for k in layer:
        # Skip placeholder entries and rotary-embedding entries.
        if "dummy" in k or 'rotary_emb' in k:
            continue
        ckpt_dict[f"model.layers.{layer_idx}.{k}"] = layer[k]

layer = load_layer_file(num_layers + 1)
show_ckpt(num_layers + 1, layer)
ckpt_dict["model.norm.weight"] = layer["norm.weight"]

layer = load_layer_file(num_layers + 2)
show_ckpt(num_layers + 2, layer)
ckpt_dict["lm_head.weight"] = layer["lm_head.weight"][:lm_head_rows]
model_dict.update(ckpt_dict)

# NOTE(review): with the embeddings truncated to the base vocabulary, token ids
# that `tokenizer` (40014 entries) produces beyond that range presumably cannot
# be embedded by `test_model`. The resize below would be needed for such
# prompts -- confirm before generating from '<protein>...' inputs.
#test_model.resize_token_embeddings(len(tokenizer))
test_model.load_state_dict(model_dict)
test_model = test_model.cuda()
test_model.eval()

layer0 embed_tokens.weight torch.Size([40014, 4096])
01 self_attn.q_proj.weight torch.Size([4096, 4096])
01 self_attn.k_proj.weight torch.Size([4096, 4096])
01 self_attn.v_proj.weight torch.Size([4096, 4096])
01 self_attn.o_proj.weight torch.Size([4096, 4096])
01 mlp.gate_proj.weight torch.Size([11008, 4096])
01 mlp.up_proj.weight torch.Size([11008, 4096])
01 mlp.down_proj.weight torch.Size([4096, 11008])
01 input_layernorm.weight torch.Size([4096])
01 post_attention_layernorm.weight torch.Size([4096])
02 self_attn.q_proj.weight torch.Size([4096, 4096])
02 self_attn.k_proj.weight torch.Size([4096, 4096])
02 self_attn.v_proj.weight torch.Size([4096, 4096])
02 self_attn.o_proj.weight torch.Size([4096, 4096])
02 mlp.gate_proj.weight torch.Size([11008, 4096])
02 mlp.up_proj.weight torch.Size([11008, 4096])
02 mlp.down_proj.weight torch.Size([4096, 11008])
02 input_layernorm.weight torch.Size([4096])
02 post_attention_layernorm.weight torch.Size([4096])
03 self_attn.q_proj.weight torch.Size

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNo

In [18]:
# Print the summed absolute difference between the base model and the reloaded
# model for three tensors. `rows` is the leading slice applied to the reloaded
# tensor (None means the full tensor).
base_state = model.state_dict()
reloaded_state = test_model.state_dict()
for key, rows in (
    ('model.embed_tokens.weight', 32000),
    ('model.layers.10.self_attn.k_proj.weight', None),
    ('lm_head.weight', 32000),
):
    reloaded = reloaded_state[key][:rows].cpu()
    print(torch.sum(torch.abs(base_state[key] - reloaded)))

tensor(0.0063)
tensor(6.8878e-05)
tensor(0.0007)


In [40]:
# Reset the merged state dict to the checkpoint weights, overwrite the first
# `emb.shape[0]` embedding rows with the base model's embeddings, and load the
# result into `test_model`.
# NOTE(review): the recorded output shows 40014-row tensors, which does not
# match a loading cell that truncates to 32000 rows; presumably it was produced
# with a different version of that cell -- rerun from a fresh kernel to confirm.
model_dict.update(ckpt_dict)
print(model_dict['model.embed_tokens.weight'].shape, model_dict['lm_head.weight'].shape)
emb = model.state_dict()['model.embed_tokens.weight']
lm_head = model.state_dict()['lm_head.weight']
print(emb.shape, lm_head.shape)

# Write into a copy: after `update`, model_dict holds the very same tensor
# object as ckpt_dict, so an in-place slice assignment would also overwrite the
# checkpoint embeddings kept in ckpt_dict and make this cell non-idempotent.
merged_emb = model_dict['model.embed_tokens.weight'].clone()
merged_emb[:emb.shape[0]] = emb
model_dict['model.embed_tokens.weight'] = merged_emb

# Alternatives that were tried (kept for reference):
#model_dict['lm_head.weight'][:lm_head.shape[0]] = lm_head[:]
# model_dict['model.embed_tokens.weight'] = emb[:]
# model_dict['lm_head.weight'] = lm_head[:]

#test_model.resize_token_embeddings(emb.shape[0])
#test_model.resize_token_embeddings(len(tokenizer))
test_model.load_state_dict(model_dict)

print(model_dict['model.embed_tokens.weight'].shape, model_dict['lm_head.weight'].shape)

torch.Size([40014, 4096]) torch.Size([40014, 4096])
torch.Size([32000, 4096]) torch.Size([32000, 4096])


torch.Size([40014, 4096]) torch.Size([40014, 4096])


In [20]:
# Print embedding and lm_head shapes for the base model, the checkpoint dict
# and the reloaded model, in that order.
for state in (model.state_dict(), ckpt_dict, test_model.state_dict()):
    for key in ('model.embed_tokens.weight', 'lm_head.weight'):
        print(state[key].shape)

torch.Size([32000, 4096])
torch.Size([32000, 4096])
torch.Size([32000, 4096])
torch.Size([32000, 4096])
torch.Size([32000, 4096])
torch.Size([32000, 4096])


In [21]:
# Prompt the reloaded model and print the first returned sequence.
# Alternative prompts that were tried:
#   input_ids = llama_tokenizer('An apple a day', return_tensors="pt").input_ids.cuda()
#   input_ids = tokenizer('<protein>AA', return_tensors="pt").input_ids.cuda()
prompt = 'An apple a day'
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()

# Beam search without sampling; four sequences are requested.
generation_kwargs = dict(
    num_beams=4,
    max_new_tokens=100,
    num_return_sequences=4,
    return_dict_in_generate=True,
    do_sample=False,
)
# Other settings that were tried (kept for reference):
#   output_scores=True, repetition_penalty=1.2
#   num_beams=5, max_new_tokens=512, num_return_sequences=1,
#   return_dict_in_generate=True, output_scores=True,
#   do_sample=True, top_p=0.95, repetition_penalty=1.5
output = test_model.generate(input_ids, **generation_kwargs)

# Only the first of the returned sequences is decoded; special tokens are kept.
res = tokenizer.decode(output.sequences[0], skip_special_tokens=False)
print(res)
# Earlier loop over all returned sequences (kept for reference):
# for i in range(len(output)):
#     s = llama_tokenizer.decode(output[i])
#     print(s)
#     # print(s, output.sequences_scores[i].item())
#     # s = tokenizer.decode(output.sequences[i])
#     # print(s, output.sequences_scores[i].item())

<s>An apple a day keeps the doctor away, right? Well, that’s what we’ve always been told, but is it really true?
Apples are one of the most popular fruits in the world, and for good reason. Not only are they delicious, but they’re also packed with nutrients that can benefit your health in a variety of ways. In this blog post, we’ll take a closer look at the health benefits of apples, and why you should consider


Bad pipe message: %s [b'\xee\\\x856[K\xe3D\x01%L\x03\xf9w\xa7\xaf\xff\x80 x2\xed\x0c4\xa849\xcf\xfb\x8b\x05\x84\xe6\x88\xa1j\xe7\xbf\x8f\xe7\xad\xae\xa6\xb0\xed\xb8\x8d\xca\xe4\x1e\x0f\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x0c\x00\n\x00\x1d\x00\x17\x00\x1e\x00\x19\x00\x18\x00#\x00\x00\x00\x16\x00\x00\x00\x17\x00\x00\x00\r']
Bad pipe message: %s [b'?\xe1E\x85\x1b2\x19)\x91\xb3\x82\xea\x9b\xe4S\xc2\x81\x83 \xcb\x8ev\x88N@4!\xefhe\xc2\xd6\x8e\xcf\xb1\x1e\x90JMtt\xb6v\xa2\xe4a\xb5\x0e+\x9d\x13\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x0c\x00\n\x00\x1d\x00\x17\x00\x1e\x00\x19\x00\x18\x00#\x00\x00\x00\x16\x00\x00\x00\x17\x00\x00\x00\r\x00\x1e\x00\x1c\x04\x03\x05\x03\x06\x03\x08\x07\x08\x08\x08\t\x08\n\x08\x0b\x08\x04\x08\x05\x08\x06\x04\x01\x05\x01\x06\x01\x00+\x00\x03\x02\x03\x04\x00-\x00\x0

In [33]:
# Show how each tokenizer splits text containing double spaces and newlines.
# NOTE(review): the two inputs are not identical -- the second has an extra
# blank line before 'an apple' -- so the outputs are not a like-for-like
# comparison; confirm whether that is intended.
llama_sample = 'An apple  a day\nYes\nThis is\nan apple\n'
sfm_sample = 'An apple  a day\nYes\nThis is\n\nan apple\n'

print(type(llama_tokenizer), type(tokenizer))
print(llama_tokenizer.tokenize(llama_sample))
print(tokenizer.tokenize(sfm_sample))

<class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'> <class 'sfm.data.sci_data.SFMDecTokenizer.SFMDecTokenizer'>
['▁An', '▁apple', '▁', '▁a', '▁day', '<0x0A>', 'Yes', '<0x0A>', 'This', '▁is', '<0x0A>', 'an', '▁apple', '<0x0A>']
['▁An', '▁apple', '▁', '▁a', '▁day', '<0x0A>', 'Yes', '<0x0A>', 'This', '▁is', '<0x0A>', '<0x0A>', 'an', '▁apple']
