In [1]:
import sys
import os
from argparse import ArgumentParser

import transformers
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import torch
from sfm.models.progpt.progpt import ProGPTModel
from sfm.models.progpt.progpt_config import ProGPTConfig
from sfm.models.pfm.pfm_config import PFMConfig
from sfm.data.sci_data.SFMDecTokenizer import SFMDecTokenizer
from sfm.utils import arg_utils
from sfm.utils.science_tokens import SCIENCE_TAG_TOKENS



  from .autonotebook import tqdm as notebook_tqdm


[2024-04-20 07:29:51,771] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[[32m2024-04-20 07:29:54.037[0m][[36mINFO[0m]: flash_attn not installed, use default attn


In [2]:
# Presumably the label id that the training loss skips; it is not referenced
# by any of the cells shown in this notebook.
IGNORE_INDEX = -100
# Fallback special tokens. get_args_and_tokenizer() registers each of these
# only when the loaded tokenizer does not already define the matching token.
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"

def get_args_and_tokenizer(use_llama=False):
    """Create the argument namespace and the tokenizer used by this notebook.

    The namespace is built from the defaults of ``PFMConfig`` and
    ``ProGPTConfig`` (no command-line arguments are parsed) and then patched
    with the encoder size and checkpoint locations used here.

    Args:
        use_llama: when True, load the tokenizer through ``AutoTokenizer`` and
            register any missing default special tokens plus
            ``SCIENCE_TAG_TOKENS``. When False, load an ``SFMDecTokenizer``
            and copy its vocabulary size and pad id onto ``args``.

    Returns:
        ``(args, tokenizer)``.
    """
    parser = arg_utils.add_dataclass_to_parser(
        [PFMConfig, ProGPTConfig], ArgumentParser()
    )
    args = parser.parse_args(args=[])

    # Settings shared by both tokenizer flavours.
    overrides = {
        "load_ckpt": False,
        "strategy": "DDP",
        "encoder_layers": 33,
        "encoder_embed_dim": 1280,
        "encoder_ffn_embed_dim": 5120,
        "encoder_attention_heads": 20,
    }
    for attr_name, attr_value in overrides.items():
        setattr(args, attr_name, attr_value)
    # args.fp16 = True

    mount_dir = "/data/peiran/blob/msralaphilly2/ml-la"
    # Both branches read the LLM weights and tokenizer assets from the same place.
    args.llm_model_name_or_path = mount_dir + "/v-kehanwu/SFM/scigpt/stageB.prot/global_step224655"
    args.tokenizer_path = mount_dir + "/shufxi/data/scigpt"
    # Alternative checkpoint location used during development:
    # args.save_dir = "/fastdata/peiran/nlm/checkpoints/stageB.prot/global_step1"

    if use_llama:
        args.save_dir = mount_dir + '/v-kehanwu/nlm/checkpoints/bfm_llama/global_step11499'

        tokenizer = AutoTokenizer.from_pretrained(
            args.llm_model_name_or_path,
            model_max_length=args.model_max_length,
            padding_side="right",
            use_fast=False,
        )

        # Fill in only the special tokens the tokenizer lacks. The order of
        # this table is the insertion order of the resulting dict.
        fallback_tokens = (
            ("pad_token", DEFAULT_PAD_TOKEN),
            ("eos_token", DEFAULT_EOS_TOKEN),
            ("bos_token", DEFAULT_BOS_TOKEN),
            ("unk_token", DEFAULT_UNK_TOKEN),
        )
        special_tokens_dict = {
            token_name: token_value
            for token_name, token_value in fallback_tokens
            if getattr(tokenizer, token_name) is None
        }
        special_tokens_dict["additional_special_tokens"] = SCIENCE_TAG_TOKENS
        tokenizer.add_special_tokens(special_tokens_dict)
        # NOTE(review): unlike the other branch, args.vocab_size and
        # args.pad_token_id are not refreshed here - confirm that is intended.
    else:
        args.save_dir = mount_dir + '/v-kehanwu/nlm/checkpoints/bfm_scigpt_prot/global_step11499'

        spm_paths = {
            "prot_spm_path": os.path.join(args.tokenizer_path, "ur50bpe/bpe"),
            "dna_spm_path": os.path.join(args.tokenizer_path, "dnabpe/bpe"),
            "rna_spm_path": os.path.join(args.tokenizer_path, "rnabpe/bpe"),
        }
        tokenizer = SFMDecTokenizer.from_pretrained(
            args.llm_model_name_or_path, **spm_paths
        )
        args.vocab_size = len(tokenizer)  # now we have new tokens
        args.pad_token_id = tokenizer.pad_token_id

    return args, tokenizer

# Build the notebook-wide args/tokenizer pair using the AutoTokenizer branch.
args, tokenizer = get_args_and_tokenizer(use_llama=True)

[[32m2024-04-20 07:29:54.601[0m][[36mINFO[0m]: Trainer args: Namespace(num_classes=1, encoder_attention_heads=20, encoder_ffn_embed_dim=5120, encoder_embed_dim=1280, encoder_layers=33, num_pred_attn_layer=4, num_3d_bias_kernel=128, max_length=1024, pbc_expanded_token_cutoff=512, pbc_expanded_num_cell_per_direction=10, multi_hop_max_dist=20, droppath_prob=0.0, act_dropout=0.0, attn_dropout=0.0, dropout=0.0, sandwich_ln=True, noise_scale=0.2, mask_ratio=0.5, d_tilde=1.0, pbc_cutoff=40.0, data_path='', dataset_names='', loadcheck_path='', add_3d=False, no_2d=False, ft=False, infer=False, use_pbc=False, transformer_m_pretrain=True, mode_prob='0.6,0.2,0.2', num_timesteps=1000, ddpm_beta_start=0.0001, ddpm_beta_end=0.02, ddpm_schedule='linear', noise_mode='const', num_edges=1536, num_atom_features=5120, task_name='', data_basepath='', output_dim=1024, add_rope=True, flash_attn=False, stack_seq=False, num_residues=32, max_num_aa=1024, task='mae', mask_prob=0.15, train_data_path='', valid_

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [3]:
def _load_shard(index):
    """Read the per-layer checkpoint file number ``index`` from args.save_dir onto the CPU."""
    shard_path = os.path.join(args.save_dir, f"layer_{index:02d}-model_states.pt")
    # torch.load unpickles the file; only point save_dir at trusted checkpoints.
    return torch.load(shard_path, map_location=torch.device("cpu"))


model = ProGPTModel(args, len(tokenizer))
model_dict = model.state_dict()

# Re-key every shard so that it matches the parameter names of ProGPTModel.
ckpt_dict = {}

# Shard 0 -> protein encoder.
ckpt_dict.update(
    {"pfm_encoder." + name: tensor for name, tensor in _load_shard(0).items()}
)

# Shard 1 -> decoder token embedding.
ckpt_dict['decoder.model.embed_tokens.weight'] = _load_shard(1)['embed_tokens.weight']

# Shard 2 -> adaptor.
ckpt_dict.update(
    {"adaptor." + name: tensor for name, tensor in _load_shard(2).items()}
)

# Shards 3..34 -> decoder layers 0..31.
for block_idx in range(32):
    shard = _load_shard(block_idx + 3)
    for name, tensor in shard.items():
        if "dummy" in name or 'rotary_emb' in name:
            continue
        ckpt_dict[f"decoder.model.layers.{block_idx}.{name}"] = tensor

# Shards 35 and 36 -> final norm and LM head.
ckpt_dict["decoder.model.norm.weight"] = _load_shard(35)["norm.weight"]
ckpt_dict["decoder.lm_head.weight"] = _load_shard(36)["lm_head.weight"]

# Overlay the checkpoint tensors on the freshly initialised state dict, resize
# the decoder embeddings to the tokenizer size, then load everything.
model_dict.update(ckpt_dict)
model.decoder.resize_token_embeddings(len(tokenizer))
model.load_state_dict(model_dict)


<All keys matched successfully>

In [4]:
# Single-letter residue code -> integer id; protein_process() uses this table
# to produce the "bpe" ids of a protein.
# NOTE(review): 'O' maps to 0 while every other residue has an id above 33000 -
# confirm this is an intentional fallback.
scigpt_vacab = {'L': 33874, 'A': 33875, 'G': 33878, 'V': 33877, 'S': 33876, 'E': 33879, 'R': 33880, 'T': 33881, 'I': 33882, 'D': 33884, 'P': 33886, 'K': 33883, 'Q': 33885, 'N': 33887, 'F': 33888, 'Y': 33890, 'M': 33873, 'H': 33889, 'W': 33891, 'C': 33892, 'X': 34276, 'B': 37965, 'U': 37967, 'Z': 37966, 'O': 0}

# Token -> id table (ids 0-31, special tokens included); protein_process() uses
# it for the residue ids, and collator() wraps a protein with ids 0 and 2,
# which are '<cls>' and '<eos>' here.
vocab = {'<cls>': 0, '<pad>': 1, '<eos>': 2, '<unk>': 3, 'L': 4, 'A': 5, 'G': 6, 'V': 7, 'S': 8, 'E': 9, 'R': 10, 'T': 11, 'I': 12, 'D': 13, 'P': 14, 'K': 15, 'Q': 16, 'N': 17, 'F': 18, 'Y': 19, 'M': 20, 'H': 21, 'W': 22, 'C': 23, 'X': 24, 'B': 25, 'U': 26, 'Z': 27, 'O': 28, '.': 29, '-': 30, '<mask>': 31}

def protein_process(protein):
    """Look up every residue of ``protein`` in both id tables.

    Args:
        protein: iterable of single-character residue codes (e.g. a string).

    Returns:
        ``(residue_ids, bpe_ids)``: ids taken from ``vocab`` and from
        ``scigpt_vacab`` respectively, in input order.

    Raises:
        KeyError: if a residue is missing from the table being read.
    """
    # Two separate passes: the whole sequence is resolved against ``vocab``
    # before ``scigpt_vacab`` is consulted.
    residue_ids = list(map(vocab.__getitem__, protein))
    bpe_ids = list(map(scigpt_vacab.__getitem__, protein))
    return residue_ids, bpe_ids

def process(text):
    """Separate ``text`` into plain-text segments and protein sequences.

    Each ``<protein>...</protein>`` span contributes its inner sequence to
    ``protein``; the text before the first span and the text following each
    span make up ``res``, so ``len(res) == len(protein) + 1``.

    Only the first ``</protein>`` after an opening tag closes the span. Any
    later stray ``</protein>`` stays in the trailing text instead of silently
    cutting it off.

    Args:
        text: prompt string, optionally containing protein spans.

    Returns:
        ``(res, protein)``: list of text segments and list of protein
        sequences, both in order of appearance.

    Raises:
        ValueError: if a ``<protein>`` tag has no matching ``</protein>``.
    """
    leading, *tagged_chunks = text.split("<protein>")
    res = [leading]
    protein = []
    for chunk in tagged_chunks:
        sequence, closing_tag, trailing = chunk.partition("</protein>")
        if not closing_tag:
            raise ValueError(
                "'<protein>' tag without a matching '</protein>' in: " + repr(text)
            )
        protein.append(sequence)
        res.append(trailing)

    return res, protein

def tokenize(text):
    """Tokenize a prompt that may embed ``<protein>...</protein>`` spans.

    Text without any protein span is passed straight to ``tokenizer.encode``.
    Otherwise each protein is replaced by a single ``-1`` placeholder sitting
    between the encoded ``<protein>`` and ``</protein>`` tags, and the
    sequences themselves are returned as id lists.

    Args:
        text: prompt string.

    Returns:
        ``(input_ids, protein_id_list, protein_bpe_id_list)``: token ids with
        ``-1`` placeholders, plus one residue-id list and one bpe-id list per
        protein (both empty when the prompt has no protein).
    """
    segments, sequences = process(text)

    if not sequences:
        return tokenizer.encode(text), [], []

    id_pairs = [protein_process(sequence) for sequence in sequences]
    protein_id_list = [pair[0] for pair in id_pairs]
    protein_bpe_id_list = [pair[1] for pair in id_pairs]

    last_idx = len(segments) - 1
    input_ids = []
    for idx, segment in enumerate(segments):
        if idx == 0:
            # Leading text keeps whatever the tokenizer puts at the front.
            input_ids.extend(tokenizer.encode(segment + " <protein>"))
            continue

        # One placeholder per protein, then the text that follows it. The
        # first encoded id is dropped (presumably the BOS token - confirm).
        input_ids.append(-1)
        suffix = "" if idx == last_idx else " <protein>"
        input_ids.extend(tokenizer.encode("</protein> " + segment + suffix)[1:])

    return input_ids, protein_id_list, protein_bpe_id_list


def collator(input_ids, protein_id_list, protein_bpe_id_list, device):
    """Turn one tokenized sample into the batch dict passed to ``model.generate``.

    Every negative entry of ``input_ids`` is treated as a protein placeholder.
    The n-th placeholder is replaced by ``len(protein_id_list[n])`` copies of
    itself; all other tokens are kept in their original order.

    Args:
        input_ids: 1-D sequence of token ids; negative values mark placeholders.
        protein_id_list: residue-id sequences, one per placeholder.
        protein_bpe_id_list: accepted for interface compatibility only. It is
            not read, and it is no longer converted to tensors in place.
        device: torch device that the returned tensors are moved to.

    Returns:
        dict with
        ``input_ids``: int64 tensor of shape (1, expanded_len);
        ``proteins``: int64 tensor of shape (1, n + 2) holding the first
        protein wrapped in ids 0 and 2 ('<cls>' / '<eos>' in ``vocab``), or
        ``[[0, 2]]`` when there is no protein;
        ``llm_mask``: bool tensor, True where ``input_ids`` differs from
        ``tokenizer.pad_token_id``.

    Raises:
        IndexError: if ``input_ids`` has more placeholders than
            ``protein_id_list`` has entries.
    """
    input_ids = torch.tensor(input_ids, dtype=torch.int64)
    placeholder_positions = torch.nonzero(input_ids < 0).squeeze(-1).tolist()

    # Assemble the expanded sequence from int64 slices. This avoids building a
    # Python list of 0-d tensors and the float placeholder run that
    # ``torch.ones(...) * id`` used to create.
    pieces = []
    cursor = 0
    for protein_idx, pos in enumerate(placeholder_positions):
        pieces.append(input_ids[cursor:pos])
        len_protein = len(protein_id_list[protein_idx])
        # NOTE(review): a protein of length 0 or 1 drops its placeholder
        # entirely (behaviour kept from the original) - confirm this is intended.
        if len_protein > 1:
            pieces.append(
                torch.full((len_protein,), input_ids[pos].item(), dtype=torch.int64)
            )
        cursor = pos + 1
    pieces.append(input_ids[cursor:])
    input_ids = torch.cat(pieces).unsqueeze(0)

    # NOTE(review): only the first protein is forwarded to the encoder, even
    # when the sample contains several - confirm against the model's contract.
    if len(protein_id_list) == 0:
        protein = torch.tensor([0, 2], dtype=torch.int64)
    else:
        first_protein = torch.tensor(protein_id_list[0]).to(dtype=torch.int64)
        protein = torch.cat(
            [
                torch.tensor([0], dtype=torch.int64),
                first_protein,
                torch.tensor([2], dtype=torch.int64),
            ]
        )
    protein = protein.unsqueeze(0)

    return dict(
        input_ids=input_ids.to(device),
        proteins=protein.long().to(device),
        llm_mask=input_ids.ne(tokenizer.pad_token_id).to(device),
    )

In [5]:
import lmdb
from sfm.data.prot_data.util import bstr2obj
import pickle as pkl
# Run on the GPU when one is visible, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# device = torch.device("cpu")

# Open the validation LMDB read-only and keep one read transaction for the
# cells below.
data_path = '/fastdata/peiran/nlm/progpt_valid_bpe.lmdb/'
env = lmdb.open(data_path, subdir=True, readonly=True, lock=False, readahead=False)
txn = env.begin(write=False)

# The "metadata" record carries the dataset size and the list of sample keys.
metadata = bstr2obj(txn.get(b"metadata"))
size = metadata["size"]
keys = metadata["keys"]


In [6]:
# Read the last sample listed in the LMDB metadata.
key = keys[-1]
value = txn.get(str(key).encode())
# The record is a pickled (input_ids, proteins, proteins_bpeid) triple;
# unpickling is only safe because this LMDB is a trusted local file.
input_ids, proteins, proteins_bpeid = pkl.loads(value)

# Expand the protein placeholders and move the sample to the target device.
batched_data = collator(input_ids, proteins, proteins_bpeid, device)

print("input_ids: {}".format(batched_data['input_ids'].shape))
# batched_data['input_ids']

input_ids: torch.Size([1, 364])


# Manual smoke test: build a batch from a hand-written prompt instead of an
# LMDB sample.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
text = "Describe the <protein>AAAGSGAGU</protein> ."
# text = "Hello, what to eat tonight?"

# tokenize() emits a -1 placeholder for each protein, and collator() finds the
# protein positions through those negative ids. The ids must therefore not be
# re-encoded with tokenizer.encode(text) afterwards: that would discard the
# placeholders while the protein list is still passed along.
input_ids, protein_id_list, protein_bpe_id_list = tokenize(text)
batched_data = collator(input_ids, protein_id_list, protein_bpe_id_list, device)

batched_data

In [7]:
# Sampling is enabled below (do_sample=True), so fix the RNG state to make the
# generated text reproducible across Restart & Run All.
torch.manual_seed(0)

# Half precision for inference; eval() switches the model to inference mode.
model = model.to(torch.float16).to(device)
print(f"shape of input_ids: {batched_data['input_ids'].shape}")
model.eval()

# Gradients are never needed for generation.
with torch.no_grad():
    output = model.generate(
        batched_data,
        num_beams=4,
        max_new_tokens=300,
        num_return_sequences=1,
        return_dict_in_generate=True,
        output_scores=True,
        do_sample=True,
        top_p=0.95,
        repetition_penalty=1.5
    )

# Decode the single returned sequence back to text.
res = tokenizer.decode(output.sequences[0])
print(res)

# # output = model.generate(
# #     input_ids=batched_data['input_ids'],
# #     num_return_sequences=10,
# #     num_beams=20,
# # )
# for i in range(10):
#     print(tokenizer.decode(output[i]))


shape of input_ids: torch.Size([1, 364])
torch.Size([220, 1, 1280]) torch.Size([1, 220]) torch.Size([1, 364, 4096]) torch.Size([1, 364]) torch.Size([1, 364])
torch.Size([1, 364, 4096])
The similarity of this protein : Belongs to the muscleblind family. The sequence caution of this protein : Sequence=AACAAGAAGAAGAAGGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAA
