In [None]:
import sys
import os

# Put the SFM package directory (".../sfm") and then its "tasks" subdirectory
# (".../sfm/tasks") on the import path, derived from the fine-tuning script's location.
# NOTE(review): absolute, user-specific path — adjust for your checkout. Neither entry
# is the *parent* of the `sfm` package, so `from sfm...` imports below presumably rely
# on the notebook's working directory — confirm.
_ft_script_path = os.path.abspath("/home/shiyu/git/SFM_framework/sfm/tasks/ft_graphormer.py")
_ft_tasks_dir = os.path.dirname(_ft_script_path)
sys.path.append(os.path.dirname(_ft_tasks_dir))
sys.path.append(_ft_tasks_dir)

import transformers
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
from tqdm import tqdm
import torch
from sfm.models.generalist import GraphormerLlamaModel

In [None]:
from sfm.utils import add_argument

# Start from the project's argument defaults, then override the fields this
# inference run needs. Overrides are applied in the listed order.
args = add_argument.add_argument()

_arg_overrides = {
    # encoder_* sizes and output head
    "num_classes": 1,
    "encoder_attention_heads": 32,
    "encoder_layers": 24,
    "encoder_ffn_embed_dim": 768,
    "encoder_embed_dim": 768,
    # dropout / weight decay / layer-norm placement
    "droppath_prob": 0.0,
    "attn_dropout": 0.1,
    "act_dropout": 0.1,
    "dropout": 0.0,
    "weight_decay": 0.0,
    "sandwich_ln": True,
    # data and output locations (absolute, user-specific paths)
    "dataset_names": 'mol-instruction-mol-desc',
    "data_path": '/mnt/shiyu/dataset/chemical-copilot',
    "output_path": '/mnt/shiyu/models/converted/output',
    # run / model-assembly settings
    "pipeline_parallelism": 0,
    "seed": 666667,
    "ft": True,
    "d_tilde": 1,
    "num_pred_attn_layer": 4,
    "pool_mode": 'full',
    "embedding_length": 1,
    "llm_model_name_or_path": '/mnt/shiyu/models/converted/ft_100MMFM_70Bllama2_full_mix1/global_step2000',
}
for _arg_name, _arg_value in _arg_overrides.items():
    setattr(args, _arg_name, _arg_value)

# Build only the model skeleton: under init_empty_weights parameters are created on
# the meta device, so no real memory is allocated. The actual weights are loaded
# later by load_checkpoint_and_dispatch.
with init_empty_weights():
    model = GraphormerLlamaModel(
        args,
        # NOTE(review): presumably the vocabulary size
        # (32000 Llama-2 tokens + [PAD] + 10 chemical tags) — confirm.
        32011,
    )

In [None]:
from accelerate import load_checkpoint_and_dispatch

# Explicit module -> GPU placement for an 8-GPU node.
# - Graph encoder, token embeddings, adaptor and LM head go on GPU 0.
# - The 80 decoder layers are split into contiguous runs of 10 per GPU
#   (layers 0-9 -> GPU 0, ..., layers 70-79 -> GPU 7).
# - The final norm goes on GPU 7, next to the last run of layers.
# Key order matches the original construction order.
device_map = {
    "graphormer_encoder": 0,
    "decoder.model.embed_tokens": 0,
    "adaptor": 0,
    **{f'decoder.model.layers.{layer}': layer // 10 for layer in range(8 * 10)},
    "decoder.model.norm": 7,
    "decoder.lm_head": 0,
}

# Load the fine-tuned weights into the empty skeleton and place every submodule on
# the GPU given by device_map. `no_split_module_classes` tells accelerate never to
# split a LlamaDecoderLayer across devices.
checkpoint_dir = "/mnt/shiyu/models/converted/ft_100MMFM_70Bllama2_full_mix1/global_step2000/"
model = load_checkpoint_and_dispatch(
    model,
    checkpoint_dir,
    device_map=device_map,
    no_split_module_classes=["LlamaDecoderLayer"],
)

In [None]:
# Presumably the label id ignored by the loss (matches torch CrossEntropyLoss's
# default ignore_index); not used in this notebook.
IGNORE_INDEX = -100

# Fallbacks for core special tokens the base tokenizer might not define.
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"

# Opening/closing tags that delimit an entity span in prompts, registered as
# additional special tokens. Flattened as open, close, open, close, ...
_CHEMICAL_TAG_PAIRS = (
    ("<mol>", "</mol>"),
    ("<material>", "</material>"),
    ("<protein>", "</protein>"),
    ("<dna>", "</dna>"),
    ("<rna>", "</rna>"),
)
CHEMICAL_TOKENS = [tag for tag_pair in _CHEMICAL_TAG_PAIRS for tag in tag_pair]

# Load the base Llama-2-70B tokenizer (slow implementation) and register the
# special tokens the fine-tuned model expects.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "/mnt/shiyu/models/converted/llama-2-70b/",
    # cache_dir takes a path or None (None = default Hugging Face cache).
    # The previous value False is not a valid setting.
    cache_dir=None,
    model_max_length=512,
    padding_side="right",
    use_fast=False,
)

# Add only the core special tokens the tokenizer does not already define,
# in the order pad, eos, bos, unk, then the chemical span tags.
_default_special_tokens = {
    "pad_token": DEFAULT_PAD_TOKEN,
    "eos_token": DEFAULT_EOS_TOKEN,
    "bos_token": DEFAULT_BOS_TOKEN,
    "unk_token": DEFAULT_UNK_TOKEN,
}
special_tokens_dict = {
    attr: token
    for attr, token in _default_special_tokens.items()
    if getattr(tokenizer, attr) is None
}
special_tokens_dict["additional_special_tokens"] = CHEMICAL_TOKENS

# add_special_tokens returns how many new ids were appended to the vocabulary.
# NOTE(review): the model skeleton was built with 32011 as its second argument
# (presumably the vocab size); len(tokenizer) should equal it — confirm.
num_added_tokens = tokenizer.add_special_tokens(special_tokens_dict)
num_added_tokens

In [None]:
import re
from rdkit import Chem
from rdkit.Chem.rdmolops import RemoveHs
from sfm.data.mol_data.moltext_dataset import smiles2graph_removeh

# SMILES strings to query the model about; results for molecule i go to mol_{i}.txt.
molecules = [
    'C1=CC=CC=C1',
    'C1=CC(=C(C=C1F)F)N',
    'CC(C)OP(=O)(C)OC(C)C',
    'N[C@@H](Cc1ccc(O)c(O)c1)C(=O)O',
    'C(CC=O)CC=O',
    'CC(N)O',
    'C(C(Br)Br)(Br)Br',
    'CC(C)(C=NOC(=O)NC)SC',
    'CCOCC=C',
    'CCn1cc(C(=O)O)c(=O)c2cnc(N3CCNCC3)nc21',
    'O=C(O)Cc1ccccc1',
]

# Questions asked about every molecule.
QUESTIONS = [
    "Please give me some details about this molecule.",
    "Is this molecule toxic and why?",
    "Is the molecule easily soluble in water and why?",
    "Does the molecule has good oral bioavailability and why?",
    "Can the molecule pass the blood-brain barrier and why?",
    "Explain whether the molecule satisfy the Lipinski's rule of five.",
]

OUTPUT_DIR = "results_70b_ft"
SEPARATOR = "\n\n\n\n\n\n\n===============================================================================\n\n\n\n\n\n\n"


def build_prompt(question: str, num_atoms: int) -> str:
    """Return the instruction prompt for `question` about a molecule with `num_atoms` atoms.

    The molecule is written as one `<unk>` placeholder per atom, wrapped in the
    `<mol>` ... `</mol>` tags registered in CHEMICAL_TOKENS.
    """
    atom_placeholders = "<unk>" * num_atoms
    return (
        "Below is an instruction that describes a question, paired with an input that provides further context. Answer the question as detailed as possible.\n\n"
        f"### Instruction:\n{question}\n\n### Input:\n<mol>{atom_placeholders}</mol>\n\n### Response:\n"
    )


# Create the output directory so open() below does not fail on a fresh checkout.
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Seed sampling (do_sample=True) so runs are reproducible, up to nondeterministic
# CUDA kernels.
torch.manual_seed(args.seed)

for i, molecule in enumerate(molecules):
    # Assumes mol['x'] is a tensor whose first dimension is the atom count — confirm.
    mol = smiles2graph_removeh(molecule)
    num_atoms = mol['x'].size()[0]
    # Append mode: repeated runs accumulate additional samples in the same file.
    with open(os.path.join(OUTPUT_DIR, f"mol_{i}.txt"), "a") as out_file:
        out_file.write(molecule + '\n')
        out_file.write(SEPARATOR)
        for question in QUESTIONS:
            tokenized = tokenizer(
                build_prompt(question, num_atoms),
                return_tensors="pt",
                padding="longest",
                max_length=512,
                truncation=True,
            )
            input_ids = tokenized.input_ids[0]
            # Mark atom placeholder positions with -1. Presumably generate_with_smiles
            # substitutes graph embeddings at negative ids — confirm.
            input_ids[input_ids == tokenizer.unk_token_id] = -1
            input_ids = input_ids.to('cuda')
            with torch.no_grad():
                res = model.generate_with_smiles(
                    input_ids.unsqueeze(0),
                    do_sample=True,
                    temperature=0.7,
                    max_new_tokens=128,
                    output_scores=True,
                    return_dict_in_generate=True,
                    smiles=[molecule],
                )
            seq = res.sequences[0]
            # Map the -1 placeholders back to <unk> so the sequence can be decoded.
            seq[seq < 0] = tokenizer.unk_token_id
            out_file.write(tokenizer.decode(seq, skip_special_tokens=False) + "\n")
            out_file.write(SEPARATOR)
            out_file.flush()