In [None]:
import sys
import os
from argparse import ArgumentParser

import transformers
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
from tqdm import tqdm
import torch
from sfm.models.progpt.progpt import ProGPTModel
from sfm.models.progpt.progpt_config import ProGPTConfig
from sfm.models.pfm.pfm_config import PFMConfig
from sfm.data.sci_data.SFMDecTokenizer import SFMDecTokenizer
from sfm.utils import arg_utils


In [None]:
def get_args_and_tokenizer(
    llm_model_name_or_path="/fastdata/peiran/scigpt/ckpt/stageB.prot/global_step224655",
    tokenizer_path="/fastdata/peiran/scigpt",
    save_dir="/fastdata/peiran/nlm/checkpoints/stageB/global_step12386/",
):
    """Build the PFM/ProGPT argument namespace and load the matching tokenizer.

    The defaults reproduce the original hard-coded cluster paths; pass other
    values to point the notebook at a different checkpoint without editing the
    function body.

    Args:
        llm_model_name_or_path: directory passed to
            ``SFMDecTokenizer.from_pretrained`` and stored on ``args``.
        tokenizer_path: root directory containing the ``ur50bpe/bpe``,
            ``dnabpe/bpe`` and ``rnabpe/bpe`` tokenizer models used for
            protein, DNA and RNA respectively.
        save_dir: directory holding the ``layer_XX-model_states.pt`` shards
            that are loaded later in the notebook.

    Returns:
        ``(args, tokenizer)``; ``args.vocab_size`` and ``args.pad_token_id``
        are filled in from the tokenizer.
    """
    parser = ArgumentParser()
    parser = arg_utils.add_dataclass_to_parser([PFMConfig, ProGPTConfig], parser)
    # Parse an explicit empty argv: inside a notebook, sys.argv holds the
    # kernel's own launch flags, which must not be fed to this parser.
    args = parser.parse_args(args=[])

    args.llm_model_name_or_path = llm_model_name_or_path
    args.tokenizer_path = tokenizer_path
    args.save_dir = save_dir
    args.load_ckpt = False
    args.strategy = "DDP"

    # Protein-encoder architecture overrides.
    args.encoder_layers = 33
    args.encoder_embed_dim = 1280
    args.encoder_ffn_embed_dim = 5120
    args.encoder_attention_heads = 20

    tokenizer = SFMDecTokenizer.from_pretrained(
        args.llm_model_name_or_path,
        prot_spm_path=os.path.join(args.tokenizer_path, "ur50bpe/bpe"),
        dna_spm_path=os.path.join(args.tokenizer_path, "dnabpe/bpe"),
        rna_spm_path=os.path.join(args.tokenizer_path, "rnabpe/bpe"),
    )
    # The tokenizer carries extra tokens beyond the base vocabulary, so the
    # vocabulary size must be taken from it rather than from the config.
    args.vocab_size = len(tokenizer)
    args.pad_token_id = tokenizer.pad_token_id

    return args, tokenizer

In [None]:
# Build the config namespace and tokenizer, then instantiate ProGPT with the
# full tokenizer size as its vocabulary size (presumably used to size the
# token embeddings — confirm in ProGPTModel).
args, tokenizer = get_args_and_tokenizer()
vocab_size = len(tokenizer)
model = ProGPTModel(args, vocab_size)


In [None]:
# Re-assemble a checkpoint stored as one `layer_XX-model_states.pt` file per
# stage (presumably DeepSpeed pipeline-parallel output — confirm) into the flat
# key namespace of `model`, then load it.
#
# Shard layout as consumed below (N = NUM_DECODER_LAYERS):
#   00          -> protein encoder, re-keyed under "pfm_encoder."
#   01          -> decoder token embedding
#   02          -> keys copied verbatim (assumed already fully qualified — verify)
#   03 .. 02+N  -> decoder transformer layers 0 .. N-1
#   03+N        -> final decoder norm
#   04+N        -> LM head
NUM_DECODER_LAYERS = 32


def _load_layer(idx):
    """Load shard ``layer_{idx:02d}-model_states.pt`` from ``args.save_dir`` onto CPU."""
    path = os.path.join(args.save_dir, f"layer_{idx:02d}-model_states.pt")
    return torch.load(path, map_location=torch.device("cpu"))


ckpt_dict = {}
model_dict = model.state_dict()

for k, v in _load_layer(0).items():
    ckpt_dict["pfm_encoder." + k] = v

ckpt_dict["decoder.model.embed_tokens.weight"] = _load_layer(1)["embed_tokens.weight"]

ckpt_dict.update(_load_layer(2))

for l in range(NUM_DECODER_LAYERS):
    layer = _load_layer(l + 3)
    for k, v in layer.items():
        # Skip "dummy" and rotary-embedding entries (presumably buffers with no
        # counterpart in the model's state dict — verify).
        if "dummy" in k or "rotary_emb" in k:
            continue
        ckpt_dict[f"decoder.model.layers.{l}.{k}"] = v
    del layer  # release each shard once its tensors are referenced by ckpt_dict

ckpt_dict["decoder.model.norm.weight"] = _load_layer(NUM_DECODER_LAYERS + 3)["norm.weight"]
ckpt_dict["decoder.lm_head.weight"] = _load_layer(NUM_DECODER_LAYERS + 4)["lm_head.weight"]

# Merge the checkpoint into the model's own state dict. Model keys the
# checkpoint lacks keep their initial values, and checkpoint keys the model
# lacks are dropped. Both cases are reported so a renaming mistake is visible.
# Loading `model_dict` without merging would just reload the fresh init.
model_keys = set(model_dict)
unexpected = sorted(k for k in ckpt_dict if k not in model_keys)
missing = sorted(k for k in model_keys if k not in ckpt_dict)
if unexpected:
    print(f"[ckpt] {len(unexpected)} checkpoint keys not in model (ignored), e.g. {unexpected[:5]}")
if missing:
    print(f"[ckpt] {len(missing)} model keys not in checkpoint (left at init), e.g. {missing[:5]}")

model_dict.update({k: v for k, v in ckpt_dict.items() if k in model_keys})
model.load_state_dict(model_dict)


In [None]:
# Residue letter -> token id in the SciGPT (decoder) tokenizer.
# ("vacab" is a typo for "vocab"; the name is kept because other cells use it.)
# NOTE(review): 'O' maps to 0, unlike every other residue — confirm this is an
# intentional placeholder rather than a missing id.
scigpt_vacab = {'L': 33874, 'A': 33875, 'G': 33878, 'V': 33877, 'S': 33876, 'E': 33879, 'R': 33880, 'T': 33881, 'I': 33882, 'D': 33884, 'P': 33886, 'K': 33883, 'Q': 33885, 'N': 33887, 'F': 33888, 'Y': 33890, 'M': 33873, 'H': 33889, 'W': 33891, 'C': 33892, 'X': 34276, 'B': 37965, 'U': 37967, 'Z': 37966, 'O': 0}

# Residue letter / special token -> token id in the protein-encoder vocabulary.
vocab = {'<cls>': 0, '<pad>': 1, '<eos>': 2, '<unk>': 3, 'L': 4, 'A': 5, 'G': 6, 'V': 7, 'S': 8, 'E': 9, 'R': 10, 'T': 11, 'I': 12, 'D': 13, 'P': 14, 'K': 15, 'Q': 16, 'N': 17, 'F': 18, 'Y': 19, 'M': 20, 'H': 21, 'W': 22, 'C': 23, 'X': 24, 'B': 25, 'U': 26, 'Z': 27, 'O': 28, '.': 29, '-': 30, '<mask>': 31}

def protein_process(protein):
    """Map every character of ``protein`` to its id in both vocabularies.

    Returns ``(protein_id, protein_bpe_id)``: per-character ids from ``vocab``
    and from ``scigpt_vacab`` respectively. Raises ``KeyError`` for any
    character absent from a table (e.g. lowercase letters or whitespace).
    """
    encoder_ids = list(map(vocab.__getitem__, protein))
    decoder_ids = list(map(scigpt_vacab.__getitem__, protein))
    return encoder_ids, decoder_ids

def process(text):
    """Split ``text`` into the prose around ``<protein>...</protein>`` spans and the spans.

    Returns ``(res, protein)``. ``protein[i]`` is the content of the i-th tagged
    span. ``res`` holds the ``len(protein) + 1`` pieces of surrounding text:
    ``res[0]`` precedes the first ``<protein>`` tag and ``res[-1]`` follows the
    last ``</protein>``. Text after a closing tag is kept in full, including any
    stray extra ``</protein>``.

    Raises:
        ValueError: if a ``<protein>`` tag has no matching ``</protein>``.
    """
    head, *tagged = text.split("<protein>")
    res = [head]
    protein = []
    for chunk in tagged:
        # partition (not split) so only the first closing tag is consumed and
        # nothing after a second "</protein>" is silently dropped.
        seq, sep, rest = chunk.partition("</protein>")
        if not sep:
            raise ValueError(f"unmatched <protein> tag in text: {text!r}")
        protein.append(seq)
        res.append(rest)

    return res, protein

def tokenize(text):
    """Encode ``text``, replacing each ``<protein>...</protein>`` body with a -1 placeholder.

    Returns ``(input_ids, protein_id_list, protein_bpe_id_list)``:

    * ``input_ids`` -- tokenizer ids of the prose. Each protein body becomes a
      single ``-1`` that ``collator`` later expands. With no protein span this
      is just ``tokenizer.encode(text)``.
    * ``protein_id_list`` -- per protein, ``vocab`` ids wrapped as
      ``[0] + ids + [2]`` (``<cls>`` ... ``<eos>`` in ``vocab``).
    * ``protein_bpe_id_list`` -- per protein, ``scigpt_vacab`` ids.
    """
    segments, proteins = process(text)

    if not proteins:
        return tokenizer.encode(text), [], []

    protein_id_list, protein_bpe_id_list = [], []
    for seq in proteins:
        residue_ids, bpe_ids = protein_process(seq)
        protein_id_list.append([0, *residue_ids, 2])
        protein_bpe_id_list.append(bpe_ids)

    last = len(segments) - 1
    input_ids = list(tokenizer.encode(segments[0] + " <protein>"))
    for idx in range(1, len(segments)):
        input_ids.append(-1)
        piece = "</protein> " + segments[idx]
        if idx != last:
            piece = piece + " <protein>"
        # Drop the first token of every follow-up encode (presumably the BOS
        # the tokenizer prepends on each call — confirm).
        input_ids.extend(tokenizer.encode(piece)[1:])

    return input_ids, protein_id_list, protein_bpe_id_list


def collator(input_ids, protein_id_list, protein_bpe_id_list, device, pad_token_id=None):
    """Expand protein placeholders in ``input_ids`` and build a model input dict.

    Every negative id in ``input_ids`` marks where the i-th protein was cut out
    by ``tokenize``. It is replaced by ``len(protein_id_list[i])`` copies of
    itself, or removed when that protein has length <= 1.

    Args:
        input_ids: flat sequence of token ids containing negative placeholders.
        protein_id_list: one residue-id list per placeholder, in order.
        protein_bpe_id_list: unused; kept for interface compatibility. It is no
            longer mutated in place.
        device: target device for all returned tensors.
        pad_token_id: id treated as padding in ``llm_mask``. Defaults to the
            notebook's global ``tokenizer.pad_token_id``.

    Returns:
        dict with ``input_ids`` (1-D int64), ``proteins`` (shape ``(1, L)``
        int64) and ``llm_mask`` (bool, ``input_ids != pad_token_id``).

    Raises:
        IndexError: if there are more placeholders than entries in
            ``protein_id_list``.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id

    ids = torch.as_tensor(input_ids, dtype=torch.int64)
    placeholder_pos = torch.nonzero(ids < 0).squeeze(-1).tolist()

    pieces = []
    start = 0
    for i, pos in enumerate(placeholder_pos):
        pieces.append(ids[start:pos])
        n_residues = len(protein_id_list[i])
        if n_residues > 1:
            pieces.append(torch.full((n_residues,), int(ids[pos]), dtype=torch.int64))
        start = pos + 1
    pieces.append(ids[start:])
    expanded = torch.cat(pieces)

    # TODO(review): only the first protein is forwarded, although input_ids may
    # contain placeholders for several. Confirm the model expects this.
    if len(protein_id_list) == 0:
        protein = torch.tensor([0, 2], dtype=torch.int64).unsqueeze(0)
    else:
        protein = torch.tensor(protein_id_list[0], dtype=torch.int64).unsqueeze(0)

    return dict(
        input_ids=expanded.to(device),
        proteins=protein.to(device),
        llm_mask=expanded.ne(pad_token_id).to(device),
    )

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
text = "Describe the <protein>AAAGSGAGU</protein> ."

# Tokenize the prompt (protein span -> placeholder + residue ids) and build the
# model inputs on the target device.
input_ids, protein_id_list, protein_bpe_id_list = tokenize(text)
batched_data = collator(input_ids, protein_id_list, protein_bpe_id_list, device)

model = model.to(device)
model.eval()
# Inference only: disable autograd so no graph or activation memory is kept.
# This is harmless if generate() already disables gradients internally.
with torch.no_grad():
    generated = model.generate(batched_data)
generated
