### "https://go.drugbank.com/structures/small_molecule_drugs/{drugbank_id}.smiles"

In [1]:
import pdb
import os
import json
import pickle
import torch
import pandas as pd

from src.model import BindingModel
from src.inference import BridgeInference
output_dir = "./data/generation_data"

In [2]:
df = pd.read_csv("./data/DrugBank/drugbank.csv")

  df = pd.read_csv("./data/DrugBank/drugbank.csv")


In [3]:
mol_df = df[df["type"] == "SmallMoleculeDrug"]
df_mol = pd.read_csv("/home/ec2-user/data/Processed/drug.csv")
mol_df = mol_df[~mol_df["moldb_smiles"].isin(df_mol["smiles"])]

In [4]:
mol_df = mol_df.drop_duplicates(subset=["moldb_smiles"]).reset_index(drop=True)

In [5]:
{
    "node_type": {
        "biological_process": 0,
        "gene/protein": 1,
        "disease": 2,
        "effect/phenotype": 3,
        "anatomy": 4,
        "molecular_function": 5,
        "drug": 6,
        "cellular_component": 7,
        "pathway": 8,
        "exposure": 9
    },
    "relation_type": {
        "expression present": 0,
        "synergistic interaction": 1,
        "interacts with": 2,
        "ppi": 3,
        "phenotype present": 4,
        "parent-child": 5,
        "associated with": 6,
        "side effect": 7,
        "contraindication": 8,
        "expression absent": 9,
        "target": 10,
        "indication": 11,
        "enzyme": 12,
        "transporter": 13,
        "off-label use": 14,
        "linked to": 15,
        "phenotype absent": 16,
        "carrier": 17
    },
    "emb_dim": {
        "molecular_function": 768,
        "gene/protein": 2560,
        "disease": 768,
        "cellular_component": 768,
        "drug": 512,
        "biological_process": 768
    }
}

{'node_type': {'biological_process': 0,
  'gene/protein': 1,
  'disease': 2,
  'effect/phenotype': 3,
  'anatomy': 4,
  'molecular_function': 5,
  'drug': 6,
  'cellular_component': 7,
  'pathway': 8,
  'exposure': 9},
 'relation_type': {'expression present': 0,
  'synergistic interaction': 1,
  'interacts with': 2,
  'ppi': 3,
  'phenotype present': 4,
  'parent-child': 5,
  'associated with': 6,
  'side effect': 7,
  'contraindication': 8,
  'expression absent': 9,
  'target': 10,
  'indication': 11,
  'enzyme': 12,
  'transporter': 13,
  'off-label use': 14,
  'linked to': 15,
  'phenotype absent': 16,
  'carrier': 17},
 'emb_dim': {'molecular_function': 768,
  'gene/protein': 2560,
  'disease': 768,
  'cellular_component': 768,
  'drug': 512,
  'biological_process': 768}}

## Mol smiles input for Q&A

In [6]:
# encode it using a protein encoder
from src.drug_encoder import load_molecule_model, inference
mol_model, tokenizer = load_molecule_model()

2023-08-30 23:39:22 | unimol_tools/models/unimol.py | 114 | INFO | Uni-Mol(QSAR) | Loading pretrained weights from /home/ec2-user/miniconda3/envs/py39/lib/python3.9/site-packages/unimol_tools-1.0.0-py3.9.egg/unimol_tools/weights/mol_pre_all_h_220816.pt


In [7]:
# load biobridge model
checkpoint_dir = "./checkpoints/bind-openke-benchmark-6-layer-unimol"
with open(os.path.join(checkpoint_dir, "model_config.json"), "r") as f:
    model_config = json.load(f)
model = BindingModel(**model_config)
model.load_state_dict(torch.load(os.path.join(checkpoint_dir, "pytorch_model.bin")))
model = BridgeInference(model)

In [8]:
# load pre-encoded disease embeddings and project for retrieval
with open("./data/embeddings/esm2b_unimo_pubmedbert/disease.pkl", "rb") as f:
    dis_raw = pickle.load(f)
dis_emb = torch.tensor(dis_raw["embedding"], dtype=torch.float32)
dis_emb = model.project(
    x = dis_emb,
    src_type = 2,
)
dis_raw.keys()
dis_idx = torch.tensor(dis_raw["node_index"])
# load disease node index to its descriptions
df_dis = pd.read_csv("/home/ec2-user/data/Processed/disease.csv")

In [9]:
# load pre-encoded protein embeddings and project for retrieval
with open("./data/embeddings/esm2b_unimo_pubmedbert/protein.pkl", "rb") as f:
    pro_raw = pickle.load(f)
pro_emb = torch.tensor(pro_raw["embedding"], dtype=torch.float32)
pro_emb = model.project(
    x = pro_emb,
    src_type = 1,
)
pro_idx = torch.tensor(pro_raw["node_index"])
df_pro = pd.read_csv("/home/ec2-user/data/Processed/protein.csv")

# start encoding and retrieval

In [10]:
# select a sample
mol = mol_df.iloc[15]
name, smiles = mol["title"], mol["moldb_smiles"]
print(name)
print(smiles)
print(mol["drugbank_id"])

Ranitidine
CN\C(NCCSCC1=CC=C(CN(C)C)O1)=C/[N+]([O-])=O
DB00863


In [11]:
mol_raw_emb = inference(mol_model, smiles)
mol_raw_emb = torch.tensor(mol_raw_emb, dtype=torch.float32)
mol_raw_emb[0][:10]

2023-08-30 23:39:33 | unimol_tools/data/conformer.py | 62 | INFO | Uni-Mol(QSAR) | Start generating conformers...
1it [00:00, 13.55it/s]
2023-08-30 23:39:33 | unimol_tools/data/conformer.py | 66 | INFO | Uni-Mol(QSAR) | Failed to generate conformers for 0.00% of molecules.
2023-08-30 23:39:33 | unimol_tools/data/conformer.py | 68 | INFO | Uni-Mol(QSAR) | Failed to generate 3d conformers for 0.00% of molecules.
                                                                                                                                                                                         

tensor([-0.0657, -0.9442, -0.7864, -0.2025, -0.6477, -1.4753,  1.1976,  0.6430,
        -0.4422,  1.5850])

In [12]:
# transform raw seq embeddings to the disease space
tr_dis_emb = model.transform(
    x = mol_raw_emb,
    src_type = 6, # drug
    tgt_type = 2, # disease
    rel_type = 11, # indication
)

# transform raw seq embeddings to the protein space
tr_pro_emb = model.transform(
    x = mol_raw_emb,
    src_type = 6, # drug
    tgt_type = 1, # protein
    rel_type = 10, # target
)

In [13]:
tr_dis_emb[0][:10]

tensor([-0.2112,  0.0892,  0.2597, -1.3143, -0.0588, -0.2926,  0.0403,  1.0763,
        -1.2507,  0.4479])

In [14]:
def retrieve_topk_disease(tgt, topk=10):
    """Args:
    tgt: emb with [1, dim]
    """
    cossim = torch.cosine_similarity(tgt, dis_emb, dim=1)
    top_k = torch.topk(cossim, k=topk, dim=0)
    retrieved_dis = df_dis.set_index('node_index').loc[dis_idx[top_k.indices].numpy()]
    return retrieved_dis


def retrieve_topk_protein(tgt, topk=10):
    """Args:
    tgt: emb with [1, dim]
    """
    cossim = torch.cosine_similarity(tgt, pro_emb, dim=1)
    top_k = torch.topk(cossim, k=topk, dim=0)
    retrieved = df_pro.set_index('node_index').loc[pro_idx[top_k.indices].numpy()]
    return retrieved

In [15]:
res_dis = retrieve_topk_disease(tr_dis_emb, 10)
for v in res_dis["mondo_name"].unique():
    print(v)

conjunctivitis (disease)
duodenal ulcer (disease)
pulmonary embolism (disease)
gastroenteritis
chronic duodenal ileus
peptic ulcer disease
intralobar congenital pulmonary sequestration
Echovirus infectious disease
trachoma
lower respiratory tract disease


In [16]:
res_pro = retrieve_topk_protein(tr_pro_emb, 10)
for v in res_pro["node_name"]:
    print(v)
for v in res_pro["sequence"]:
    print(v)

NOS2
NOS3
GSTP1
GSTA2
GSTA1
GSTM2
GSTM1
GSTA5
GSTM4
PTGS2
MACPWKFLFKTKFHQYAMNGEKDINNNVEKAPCATSSPVTQDDLQYHNLSKQQNESPQPLVETGKKSPESLVKLDATPLSSPRHVRIKNWGSGMTFQDTLHHKAKGILTCRSKSCLGSIMTPKSLTRGPRDKPTPPDELLPQAIEFVNQYYGSFKEAKIEEHLARVEAVTKEIETTGTYQLTGDELIFATKQAWRNAPRCIGRIQWSNLQVFDARSCSTAREMFEHICRHVRYSTNNGNIRSAITVFPQRSDGKHDFRVWNAQLIRYAGYQMPDGSIRGDPANVEFTQLCIDLGWKPKYGRFDVVPLVLQANGRDPELFEIPPDLVLEVAMEHPKYEWFRELELKWYALPAVANMLLEVGGLEFPGCPFNGWYMGTEIGVRDFCDVQRYNILEEVGRRMGLETHKLASLWKDQAVVEINIAVLHSFQKQNVTIMDHHSAAESFMKYMQNEYRSRGGCPADWIWLVPPMSGSITPVFHQEMLNYVLSPFYYYQVEAWKTHVWQDEKRRPKRREIPLKVLVKAVLFACMLMRKTMASRVRVTILFATETGKSEALAWDLGALFSCAFNPKVVCMDKYRLSCLEEERLLLVVTSTFGNGDCPGNGEKLKKSLFMLKELNNKFRYAVFGLGSSMYPRFCAFAHDIDQKLSHLGASQLTPMGEGDELSGQEDAFRSWAVQTFKAACETFDVRGKQHIQIPKLYTSNVTWDPHHYRLVQDSQPLDLSKALSSMHAKNVFTMRLKSRQNLQSPTSSRATILVELSCEDGQGLNYLPGEHLGVCPGNQPALVQGILERVVDGPTPHQTVRLEALDESGSYWVSDKRLPPCSLSQALTYFLDITTPPTQLLLQKLAQVATEEPERQRLEALCQPSEYSKWKFTNSPTFLEVLEEFPSLRVSAGFLLSQLPILKPRFYSISSSRDHTPTEIHLTVAVVTYHTRDGQGPLHHG

## Galactica for generating the answer

In [17]:
import torch
from transformers import AutoTokenizer, OPTForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM
# tokenizer = AutoTokenizer.from_pretrained("facebook/galactica-120b")
# model = OPTForCausalLM.from_pretrained("facebook/galactica-120b", device_map="auto", load_in_8bit=True, cache_dir="/home/ec2-user/checkpoints")


Welcome to bitsandbytes. For bug reports, please run

python -m bitsandbytes

 and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
CUDA SETUP: CUDA runtime path found: /usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.11.0
CUDA SETUP: Highest compute capability among GPUs detected: 8.6
CUDA SETUP: Detected CUDA version 118
CUDA SETUP: Loading binary /home/ec2-user/miniconda3/envs/py39/lib/python3.9/site-packages/bitsandbytes/libbitsandbytes_cuda118.so...


In [18]:
tokenizer = AutoTokenizer.from_pretrained("GeorgiaTechResearchInstitute/galactica-30b-evol-instruct-70k")
model = AutoModelForCausalLM.from_pretrained("GeorgiaTechResearchInstitute/galactica-30b-evol-instruct-70k", device_map="auto", torch_dtype=torch.bfloat16, cache_dir="/home/ec2-user/checkpoints")

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

In [19]:
def generate_answer(input_text, max_new_tokens=128, temperature=0.7, top_k=40):
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")
    outputs = model.generate(input_ids, max_new_tokens=max_new_tokens,
                            top_k=top_k,
                            do_sample=True,
                            temperature=temperature,
                            repetition_penalty=1.0,
    )

    # contrastive search
    # outputs = model.generate(
    #     input_ids,
    #     max_new_tokens=max_new_tokens,
    #     penalty_alpha=0.6, top_k=4,
    # )
    return tokenizer.decode(outputs[0])

In [29]:
prompt_template = """### Instruction:
{instruction}

### Response:"""

prompt = prompt_template.format_map(
    {"instruction":f"""Drug molecule structure: [START_I_SMILES] {smiles} [END_I_SMILES]

Target proteins:
    nitric oxide synthase 2
    nitric oxide synthase 3
    glutathione S-transferase pi 1
    glutathione S-transferase alpha 2
    glutathione S-transferase alpha 1
    glutathione S-transferase mu 2
    glutathione S-transferase mu 1
    glutathione S-transferase alpha 5
    glutathione S-transferase mu 4
    prostaglandin-endoperoxide synthase 2

Associated diseases:
    conjunctivitis (disease)
    pulmonary embolism (disease)
    gastroenteritis
    chronic duodenal ileus
    peptic ulcer disease
    Echovirus infectious disease

Consider the associated diseases and the proteins this molecule targets, what are the main possible associated conditions and main pharmacodynamics of this small molecule drug? """,}
)

print(prompt)


### Instruction:
Drug molecule structure: [START_I_SMILES] CN\C(NCCSCC1=CC=C(CN(C)C)O1)=C/[N+]([O-])=O [END_I_SMILES]

Target proteins:
    nitric oxide synthase 2
    nitric oxide synthase 3
    glutathione S-transferase pi 1
    glutathione S-transferase alpha 2
    glutathione S-transferase alpha 1
    glutathione S-transferase mu 2
    glutathione S-transferase mu 1
    glutathione S-transferase alpha 5
    glutathione S-transferase mu 4
    prostaglandin-endoperoxide synthase 2

Associated diseases:
    conjunctivitis (disease)
    pulmonary embolism (disease)
    gastroenteritis
    chronic duodenal ileus
    peptic ulcer disease
    Echovirus infectious disease

Consider the associated diseases and the proteins this molecule targets, what are the main possible associated conditions and main pharmacodynamics of this small molecule drug? 

### Response:


In [30]:
outputs = generate_answer(prompt, max_new_tokens=256, temperature=1.0, top_k=50)
print(outputs)

### Instruction:
Drug molecule structure: [START_I_SMILES] CN\C(NCCSCC1=CC=C(CN(C)C)O1)=C/[N+]([O-])=O [END_I_SMILES]

Target proteins:
    nitric oxide synthase 2
    nitric oxide synthase 3
    glutathione S-transferase pi 1
    glutathione S-transferase alpha 2
    glutathione S-transferase alpha 1
    glutathione S-transferase mu 2
    glutathione S-transferase mu 1
    glutathione S-transferase alpha 5
    glutathione S-transferase mu 4
    prostaglandin-endoperoxide synthase 2

Associated diseases:
    conjunctivitis (disease)
    pulmonary embolism (disease)
    gastroenteritis
    chronic duodenal ileus
    peptic ulcer disease
    Echovirus infectious disease

Consider the associated diseases and the proteins this molecule targets, what are the main possible associated conditions and main pharmacodynamics of this small molecule drug? 

### Response:Based on the target proteins and associated diseases listed, the main possible associated conditions of this small molecule drug a