In [None]:
# !pip install langchain openai
# !pip install python-arango
# !pip install langchain-community
# !pip install langchain-openai
# !pip install --upgrade langchain langchain-community langchain-openai langgraph
# !pip install langgraph
# !pip install biomart

# !pip install DeepPurpose 
# !pip install torch torchvision torchaudio

# !pip install git+https://github.com/bp-kelley/descriptastorus
# !pip install pandas-flavor

In [1]:
import os
import sys
import requests
import ast
import json
import hashlib
from datetime import datetime
from glob import glob
from io import StringIO

import pandas as pd
import numpy as np

from dotenv import load_dotenv
from arango import ArangoClient
from biomart import BiomartServer

from transformers import AutoTokenizer, AutoModel
import torch

from langgraph.prebuilt import create_react_agent
from langgraph.checkpoint.memory import MemorySaver
from langchain_openai import ChatOpenAI
from langchain_community.graphs import ArangoGraph
from langchain_community.chains.graph_qa.arangodb import ArangoGraphQAChain
from langchain_core.tools import tool

from DeepPurpose import utils
from DeepPurpose import DTI as models

from rdkit import Chem, DataStructs
from rdkit.Chem import MACCSkeys
from rdkit.Chem import Draw, AllChem

from Bio.PDB import MMCIFParser

import faiss

In [2]:
# Load environment variables (e.g. OPENAI_API_KEY) from a local .env file, if one exists.
load_dotenv()
# NOTE(review): nothing else in the visible notebook reads this variable — presumably
# ChatOpenAI picks the key up from the environment on its own; confirm before removing.
openai_api_key = os.getenv("OPENAI_API_KEY")

In [3]:
# Connection settings are read from the environment (which load_dotenv() has
# already populated from .env) so that credentials do not have to live in the
# notebook. The fallbacks reproduce the original local-development values, so
# the cell behaves exactly as before when no ARANGO_* variable is set.
# NOTE(review): move the password fallback out of the notebook once a .env is in place.
db = ArangoClient(hosts=os.getenv("ARANGO_HOST", "http://localhost:8529")).db(
    os.getenv("ARANGO_DB", "NeuThera"),
    username=os.getenv("ARANGO_USERNAME", "root"),
    password=os.getenv("ARANGO_PASSWORD", "openSesame"),
)
# LangChain wrapper around the same database, used by the text_to_aql tool.
arango_graph = ArangoGraph(db)

# Vertex collection of drugs and the edge collection linking drugs to proteins.
drug_collection = db.collection('drug')
link_collection = db.collection('drug-protein')

In [39]:
# Pull only the key and the embedding of every drug.
# Documents without an embedding are excluded in the query itself: the RETURN
# projection always contains both attributes (null when the field is missing),
# so a Python-side `"embedding" in doc` test can never filter them out and a
# single null would break the float32 conversion below.
cursor = db.aql.execute(
    "FOR doc IN drug "
    "FILTER doc.embedding != null "
    "RETURN {key: doc._key, embedding: doc.embedding}"
)

drug_keys = []
embedding_rows = []

for doc in cursor:
    # Defensive re-check on the values (not just on key presence).
    if doc and doc.get("key") is not None and doc.get("embedding") is not None:
        drug_keys.append(doc["key"])
        embedding_rows.append(doc["embedding"])

# Row i of `embeddings` belongs to drug_keys[i]; find_similar_drugs relies on
# the two staying aligned.
embeddings = np.array(embedding_rows, dtype=np.float32)
assert embeddings.shape[0] == len(drug_keys)

print("Embeddings shape:", embeddings.shape)
print("Number of compounds:", len(drug_keys))

Embeddings shape: (9010, 768)
Number of compounds: 9010


## Tooling

In [5]:
@tool
def text_to_aql(query: str):
    """Execute a Natural Language Query in ArangoDB, and return the result as text."""
    # The docstring above doubles as the tool description shown to the agent,
    # so it is intentionally left untouched.

    # A fresh deterministic (temperature 0) QA chain is built on every call;
    # it relies on the module-level `arango_graph` having been initialised.
    qa_chain = ArangoGraphQAChain.from_llm(
        llm=ChatOpenAI(temperature=0, model_name="gpt-4o"),
        graph=arango_graph,
        verbose=True,
        allow_dangerous_requests=True,
    )

    # Only the "result" entry of the chain output is handed back, as a string.
    answer = qa_chain.invoke(query)
    return str(answer["result"])

In [6]:
@tool
def predict_binding_affinity(X_drug, X_target, y=[7.635]):
    """
    Predicts the binding affinity for given drug and target sequences.

    Parameters:
    X_drug (list): List containing the SMILES representation of the drug.
        A single SMILES string is also accepted and treated as a one-item list.
    X_target (list): List containing the amino acid sequence of the protein target.
        A single sequence string is also accepted and treated as a one-item list.
    y (list): Placeholder label(s) passed through to the DeepPurpose data
        pipeline; callers normally leave the default.

    Returns:
    float: Predicted binding affinity (log(Kd) or log(Ki)) of the first drug/target pair.

    Raises:
    ValueError: If no drug or no target is supplied.
    """

    # Agents frequently pass a bare string instead of a one-element list.
    # NOTE(review): data_process presumably pairs the inputs element-wise, in
    # which case a raw string would be consumed character by character — confirm.
    if isinstance(X_drug, str):
        X_drug = [X_drug]
    if isinstance(X_target, str):
        X_target = [X_target]

    if not X_drug or not X_target:
        raise ValueError("predict_binding_affinity needs at least one drug SMILES and one target sequence")

    print("Predicting binding affinity: ", X_drug, X_target)

    # Loaded from the local 'DTI_model' directory on every call.
    model = models.model_pretrained(path_dir='DTI_model')

    # Pass a copy of `y` so the shared default list can never be mutated downstream.
    X_pred = utils.data_process(X_drug, X_target, list(y),
                                drug_encoding='CNN',
                                target_encoding='CNN',
                                split_method='no_split')

    predictions = model.predict(X_pred)

    # Only the first prediction is part of this function's contract.
    return predictions[0]


In [7]:
@tool
def get_amino_acid_sequence_from_pdb(pdb_id):
    """
    Extracts amino acid sequences from a given PDB structure file in CIF format.

    Args:
        pdb_id (str): pdb id of the protein.

    Returns:
        dict: A dictionary where keys are chain IDs and values are amino acid
              sequences in one-letter code ('X' marks a residue that cannot be mapped).
    """
    # Local import keeps this fix self-contained; Biopython is already a
    # dependency of the notebook (Bio.PDB is imported at the top).
    from Bio.SeqUtils import seq1

    print("Getting Amino Acid sequence for ", pdb_id)

    cif_file_path = f"./database/PDBlib/{pdb_id.lower()}.cif"

    parser = MMCIFParser(QUIET=True)
    structure = parser.get_structure("protein", cif_file_path)

    sequences = {}
    for model in structure:
        for chain in model:
            letters = []
            for residue in chain:
                # A blank hetero flag marks a standard residue; skip waters and ligands.
                if residue.id[0] != " ":
                    continue
                # residue.resname is a residue *name* such as "ALA". Joining the
                # names verbatim yields "ALAGLY...", not a one-letter sequence,
                # so each name is translated individually.
                name = residue.resname.strip()
                letters.append(seq1(name) if len(name) == 3 else "X")
            # If the file holds several models, the last model's chain wins
            # (same as before).
            sequences[chain.id] = "".join(letters)

    return sequences

In [8]:
# Put the local ./TamGen directory on the import path.
# NOTE(review): presumably needed so `TamGen_custom` (imported in a later cell) resolves — confirm.
# Re-running this cell appends the same path again, which is harmless.
sys.path.append(os.path.abspath("./TamGen"))

In [None]:
# Helper Functions for TamGen

# ChemBERTa tokenizer and encoder shared by get_chemberta_embedding below.
# NOTE(review): `model` is a very generic module-level name; predict_binding_affinity
# and get_amino_acid_sequence_from_pdb reuse it as a *local* name, which is safe but
# easy to confuse when reading the notebook.
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

def get_chemberta_embedding(smiles):
    """
    Embed one molecule, given as a SMILES string, with the ChemBERTa encoder.

    The embedding is the mean of the encoder's last hidden state over the
    token axis, computed without gradient tracking.

    Args:
        smiles (str): SMILES representation of the molecule.

    Returns:
        List[float] or None: The pooled embedding as a flat list of floats, or
                             None when `smiles` is not a non-blank string.
    """

    print("Getting vector embedding")

    # Reject anything that is not a string with visible content.
    has_content = isinstance(smiles, str) and bool(smiles.strip())
    if not has_content:
        return None

    encoded = tokenizer(smiles, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        hidden_states = model(**encoded).last_hidden_state
        pooled = hidden_states.mean(dim=1)

    # A single input means a batch of one; return that sole row.
    return pooled.tolist()[0]

def generate_key(smiles):
    """Build a deterministic document _key: 'GEN:' plus the first 8 hex digits of the SMILES' SHA-256."""
    digest = hashlib.sha256(smiles.encode()).hexdigest()
    return "GEN:" + digest[:8]

In [10]:
# NOTE(review): imported here rather than in the top import cell, presumably because
# the module lives in ./TamGen, which is only added to sys.path in an earlier cell.
from TamGen_custom import TamGenCustom

# Expose only GPU 0 to CUDA-based libraries.
# NOTE(review): torch is already imported at this point; the variable only takes
# effect if CUDA has not been initialised yet — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Shared TamGen generator instance; generate_compounds below reloads its data per PDB id.
worker = TamGenCustom(
    data="./TamGen_Demo_Data",
    ckpt="checkpoints/crossdock_pdb_A10/checkpoint_best.pt",
    use_conditional=True
)

@tool
def prepare_pdb_data(pdb_id):
    """
    Checks if the PDB data for the given PDB ID is available.  
    If not, downloads and processes the data.

    ALWAYS RUN THIS FUNCTION BEFORE WORKING WITH PDB

    Args:
        pdb_id (str): PDB ID of the target structure.

    """
    # Local stdlib imports keep this block self-contained.
    import subprocess
    import tempfile

    DemoDataFolder = "TamGen_Demo_Data"
    ligand_inchi = None
    thr = 10

    out_split = pdb_id.lower()

    # Anything in the data folder whose path mentions this split counts as "already prepared".
    for ff in glob(f"{DemoDataFolder}/*"):
        if f"gen_{out_split}" in ff:
            print(f"{pdb_id} is downloaded")
            return

    os.makedirs(DemoDataFolder, exist_ok=True)

    # A uniquely named temp file replaces the fixed "tmp_pdb.csv", so parallel or
    # interrupted runs cannot clobber each other's input.
    tmp_csv = tempfile.NamedTemporaryFile(
        mode="w", suffix=".csv", prefix="tmp_pdb_", delete=False
    )
    try:
        with tmp_csv as fw:
            if ligand_inchi is None:
                print("pdb_id", file=fw)
                print(f"{pdb_id}", file=fw)
            else:
                print("pdb_id,ligand_inchi", file=fw)
                print(f"{pdb_id},{ligand_inchi}", file=fw)

        script_path = os.path.abspath("TamGen/scripts/build_data/prepare_pdb_ids.py")

        # pdb_id comes from the agent (untrusted text): pass arguments as a list so
        # nothing is interpreted by a shell, and run the script with the kernel's
        # own interpreter rather than whatever `python` is first on PATH.
        completed = subprocess.run(
            [
                sys.executable,
                script_path,
                tmp_csv.name,
                f"gen_{out_split}",
                "-o",
                DemoDataFolder,
                "-t",
                str(thr),
            ],
            check=False,
        )
        # Still best-effort (no exception), but a failure is no longer silent.
        if completed.returncode != 0:
            print(f"Preparing data for {pdb_id} failed with exit code {completed.returncode}")
    finally:
        # Clean up even if writing the CSV or launching the script raised.
        if os.path.exists(tmp_csv.name):
            os.remove(tmp_csv.name)

@tool
def generate_compounds(pdb_id, num_samples=10, max_seed=30):
    """
    Generates and sorts compounds based on similarity to a reference molecule, 
    all generated compounds are added back to the database for further inference.

    Parameters:
    - pdb_id (str): The PDB ID of the target protein.
    - num_samples (int): Number of compounds to generate. (DEFAULT=10)
    - max_seed (int): Maximum seed variations. (DEFAULT=30)

    Returns:
    - dict: {
        'generated': [list of rdkit Mol objects],
        'reference': rdkit Mol object (falsy when no reference is available),
        'reference_smile': SMILE string of the reference compound (None when no reference is available),
        'generated_smiles': [list of SMILES strings, sorted by similarity to reference]
      }
      If anything fails, {'error': <message>} is returned instead.
    """

    print("Generating Compounds for PDB ", pdb_id)
    try:
        # Ensure the required PDB data is prepared
        # prepare_pdb_data(pdb_id)

        worker.reload_data(subset=f"gen_{pdb_id.lower()}")

        print(f"Generating {num_samples} compounds...")
        generated_mols, reference_mol = worker.sample(
            m_sample=num_samples,
            maxseed=max_seed
        )

        # The sampler may hand back SMILES strings or Mol objects; normalise both
        # the reference and the candidates to Mol objects up front so that every
        # later step (with or without a reference) sees the same types.
        if reference_mol and isinstance(reference_mol, str):
            reference_mol = Chem.MolFromSmiles(reference_mol)

        candidates = []
        for mol in generated_mols:
            if isinstance(mol, str):  # Convert string SMILES to Mol
                mol = Chem.MolFromSmiles(mol)
            if mol:  # Drop entries that could not be parsed
                candidates.append(mol)

        if reference_mol:
            # Rank candidates by MACCS-key Tanimoto similarity to the reference, best first.
            fp_ref = MACCSkeys.GenMACCSKeys(reference_mol)
            scored = [
                (
                    mol,
                    DataStructs.FingerprintSimilarity(
                        fp_ref,
                        MACCSkeys.GenMACCSKeys(mol),
                        metric=DataStructs.TanimotoSimilarity,
                    ),
                )
                for mol in candidates
            ]
            sorted_mols = [mol for mol, _ in sorted(scored, key=lambda e: e[1], reverse=True)]
            reference_smile = Chem.MolToSmiles(reference_mol)
        else:
            # No usable reference: keep generation order and do not try to
            # serialise a missing molecule (this used to raise and discard the run).
            sorted_mols = candidates
            reference_smile = None

        generated_smiles = [Chem.MolToSmiles(mol) for mol in sorted_mols]

        print("Inserting to ArangoDB...")
        protein_id = f"protein/{pdb_id}"
        for smiles in generated_smiles:
            _key = generate_key(smiles)
            drug_id = f"drug/{_key}"

            # Store the compound only once...
            if not drug_collection.has(_key):
                embedding = get_chemberta_embedding(smiles)
                doc = {
                    "_key": _key,
                    "_id": drug_id,
                    "accession": "NaN",
                    "drug_name": "NaN",
                    "cas": "NaN",
                    "unii": "NaN",
                    "synonym": "NaN",
                    "key": "NaN",
                    "chembl": "NaN",
                    "smiles": smiles,
                    "inchi": "NaN",
                    "generated": True,
                    "embedding": embedding
                }
                drug_collection.insert(doc)

            # ...but always make sure it is linked to *this* protein. Previously an
            # already-stored compound skipped this step entirely, so a molecule
            # generated again for a different target never got its new edge.
            # Bind variables are used instead of string interpolation because
            # pdb_id originates from the agent.
            existing_links = list(db.aql.execute(
                '''
                FOR link IN `drug-protein`
                FILTER link._from == @drug_id AND link._to == @protein_id
                LIMIT 1
                RETURN link
                ''',
                bind_vars={"drug_id": drug_id, "protein_id": protein_id},
            ))

            if not existing_links:
                link_doc = {
                    "_from": drug_id,
                    "_to": protein_id,
                    "generated": True
                }
                link_collection.insert(link_doc)

        return {
            "generated": sorted_mols,
            "reference": reference_mol,
            "reference_smile": reference_smile,
            "generated_smiles": generated_smiles
        }

    except Exception as e:
        # Deliberately broad: the agent receives the message instead of a crashed tool call.
        print(f"Error in compound generation: {str(e)}")
        return {"error": str(e)}

Namespace(no_progress_bar=False, log_interval=1000, log_format=None, tensorboard_logdir='', tbmf_wrapper=False, seed=1, cpu=False, fp16=False, memory_efficient_fp16=False, fp16_init_scale=128, fp16_scale_window=None, fp16_scale_tolerance=0.0, min_loss_scale=0.0001, threshold_loss_scale=None, user_dir=None, criterion='cross_entropy', tokenizer=None, bpe=None, optimizer='nag', lr_scheduler='fixed', task='translation_coord', num_workers=1, skip_invalid_size_inputs_valid_test=False, max_tokens=1024, max_sentences=None, required_batch_size_multiple=8, dataset_impl=None, gen_subset='gen_8fln', num_shards=1, shard_id=0, path='checkpoints/crossdock_pdb_A10/checkpoint_best.pt', remove_bpe=None, quiet=False, model_overrides='{}', results_path=None, beam=20, nbest=20, max_len_a=0, max_len_b=200, min_len=1, match_source_len=False, no_early_stop=False, unnormalized=False, no_beamable_mm=False, lenpen=1, unkpen=0, replace_unk=None, sacrebleu=False, score_reference=False, prefix_size=0, prefix_string

In [11]:
@tool
def generate_report(columns, rows):
    """
    Generate a report in CSV format with a timestamped filename. This function uses pandas to create a CSV.
    
    Parameters:
    columns (list): List of column names.
    rows (list of lists): Data rows corresponding to the columns.
    
    Returns:
    str: Path of the generated CSV report.
    """
    # The docstring above is what the agent sees as the tool description, so it
    # is kept verbatim.

    # File name carries the current local time at one-second resolution.
    filename = "report_{}.csv".format(datetime.now().strftime("%Y%m%d_%H%M%S"))

    # Written to the current working directory, without the index column.
    pd.DataFrame(rows, columns=columns).to_csv(filename, index=False)

    return filename

In [None]:
# Tools for vector embeddings

def find_similar_drugs(embedding, top_k=5):
    """
    Finds the top K most similar drugs based on given vector embeddings of query molecule.

    The search runs over the module-level `embeddings` matrix, whose rows are
    aligned with `drug_keys`.

    Args:
        embedding (List[float]): The ChemBERTa embedding of the query molecule.
        top_k (int, optional): Number of most similar drugs to retrieve. Default is 5.
            Values larger than the number of indexed drugs are clamped.

    Returns:
        List[Tuple[str, float]]: (drug_key, score) pairs, nearest first. The score is
            `1 - d`, where `d` is the distance reported by the FAISS L2 index; higher
            means closer, but the value is not bounded to [0, 1]. Empty when there is
            nothing to search or top_k <= 0.
    """
    print("Finding similar drugs...")

    query = np.array(embedding, dtype=np.float32).reshape(1, -1)

    # Never ask FAISS for more neighbours than exist: it pads the result with
    # index -1, which would silently map to drug_keys[-1] (the last drug).
    n_indexed = min(len(drug_keys), len(embeddings))
    k = min(int(top_k), n_indexed)
    if k <= 0:
        return []

    # Initialize FAISS index (exact L2 search, rebuilt on every call)
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(embeddings)

    distances, indices = index.search(query, k)

    top_similar_drugs = [
        (drug_keys[i], float(1 - dist))
        for dist, i in zip(distances[0], indices[0])
        if i >= 0  # belt and braces against padded results
    ]

    return top_similar_drugs

### Agentic RAG

In [None]:
# Tools handed to the ReAct agent in query_graph.
# NOTE(review): find_similar_drugs is defined above but is neither decorated with
# @tool nor listed here, so the agent cannot call it — presumably why the
# "most similar drugs" example further down ends in an error. Registering it would
# also need a way to obtain the query embedding — TODO.
tools = [ text_to_aql, get_amino_acid_sequence_from_pdb, prepare_pdb_data, generate_compounds, predict_binding_affinity, generate_report ]

def query_graph(query):
    """
    Run one user request through a freshly built ReAct agent.

    The request is embedded in a fixed instruction prompt and sent as a single
    user message; the agent may call any of the module-level `tools`.

    Args:
        query (str): The user's natural-language request.

    Returns:
        The final state returned by the agent's `invoke` call (its "messages"
        entry holds the conversation, last message last).
    """

    prompt = f"""
    USER INPUT: {query}

    You are an advanced drug discovery assistant with multiple tools.

    - Use your tools as needed to assist in end-to-end drug discovery and answer user queries.
    - Always structure your output as valid JSON string so it can be parsed in python.
    - If possible, always try to generate reports for whatever output you get. Don't generate reports for errors.
    - Do not add explanations or any extra text.
    - When working with multiple outputs, run functions one by one for everything unless stated otherwise by the user.
    - When working with drugs, ignore embeddings.
    """

    # Deterministic model (temperature 0); the agent is rebuilt on every call
    # and keeps no memory between calls.
    agent = create_react_agent(ChatOpenAI(temperature=0, model_name="gpt-4o"), tools)
    request = {"messages": [{"role": "user", "content": prompt}]}
    return agent.invoke(request)

In [57]:
    # query_template = f"""
    # user input: {query}

    # You are an advanced drug discovery assistant with multiple tools.

    # - Use your tools as needed to assist in end-to-end drug discovery and answer user queries.
    # - Always structure your output as valid JSON string so it can be parsed in python.
    # - If possible, always try to generate reports for whatever output you get. Don't generate reports for errors
    # - Do not add explanations or any extra text.

    # !WARNING!
    # YOU RUN FOR DEMO PURPOSES ONLY

    # - DO NOT RUN TOOLS, ESPECIALLY generate_compounds AND predict_binding_affinity, SIMULTANEOUSLY
    # - IF MULTIPLE OUTPUTS ARE FOUND FOR PROCESSING, JUST PICK THE FIRST ONE AND RUN THE PROCEEDING FUNCTIONS WITH THAT.
    # """

In [63]:
# Example prompts — enable one at a time. Each call runs the full agent loop
# (LLM calls plus whatever tools the agent decides to use).
# output = query_graph("What proteins can you find me related to mitochondrial ribosomal protein L36? Generate some compounds and test their binding affinity")
# output = query_graph("Given pdb 5ool, generate some compounds")
# output = query_graph("Find proteins related to disease Anaphylaxis, generate compounds and test their binding affinity")
output = query_graph("Take a random drug from the database and find top 10 most similar drugs to it")

# Display the raw final agent state.
output



[1m> Entering new ArangoGraphQAChain chain...[0m
AQL Query (1):[32;1m[1;3m
WITH drug
FOR d IN drug
RETURN d
LIMIT 1
[0m
AQL Query Execution Error: 
[33;1m[1;3msyntax error, unexpected LIMIT declaration, expecting end of query string near 'LIMIT 1
' at position 5:1[0m

AQL Query (2):[32;1m[1;3m
WITH drug
FOR d IN drug
LIMIT 1
RETURN d
[0m
AQL Result:
[32;1m[1;3m[{'_key': 'DB00014', '_id': 'drug/DB00014', '_rev': '_jUuy9Iy---', 'accession': 'BTD00113 | BIOD00113', 'drug_name': 'Goserelin', 'cas': '65807-02-5', 'unii': '0F65R8P09N', 'synonym': 'Goserelin | Goserelina', 'key': 'BLCLNMBMMGCOAS-URPVMXJPSA-N', 'chembl': 'CHEMBL1201247', 'smiles': 'CC(C)C[C@H](NC(=O)[C@@H](COC(C)(C)C)NC(=O)[C@H](Cc1ccc(O)cc1)NC(=O)[C@H](CO)NC(=O)[C@H](Cc1c[nH]c2ccccc12)NC(=O)[C@H](Cc1cnc[nH]1)NC(=O)[C@@H]1CCC(=O)N1)C(=O)N[C@@H](CCCNC(=N)N)C(=O)N1CCC[C@H]1C(=O)NNC(N)=O', 'inchi': 'InChI=1S/C59H84N18O14/c1-31(2)22-40(49(82)68-39(12-8-20-64-57(60)61)56(89)77-21-9-13-46(77)55(88)75-76-58(62)90)69-54

{'messages': [HumanMessage(content="\n    USER INPUT: Take a random drug from the database and find top 10 most similar drugs to it\n\n    You are an advanced drug discovery assistant with multiple tools.\n\n    - Use your tools as needed to assist in end-to-end drug discovery and answer user queries.\n    - Always structure your output as valid JSON string so it can be parsed in python.\n    - If possible, always try to generate reports for whatever output you get. Don't generate reports for errors.\n    - Do not add explanations or any extra text.\n    - When working with multiple outputs, run functions one by one for everything unless stated otherwise by the user.\n    - When working with drugs, ignore embeddings.\n    ", additional_kwargs={}, response_metadata={}, id='180b2202-0128-47cd-9d0e-48b83c65966a'),
  AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_NLJ2WFtdkHq49wCCjjQ9uAbi', 'function': {'arguments': '{"query":"FOR d IN drugs RETURN d LIMIT 1"}', 'name'

In [143]:
# The prompt asks the agent to answer with a JSON string, so the last message is parsed as JSON.
# NOTE(review): json.loads raises if the model wraps the JSON in prose or code fences — consider guarding.
message = json.loads(output["messages"][-1].content)
message

{'error': 'An error occurred while finding similar drugs. Please try again later.'}

In [19]:
# Smoke test of the embedding helper on a single SMILES string.
some = get_chemberta_embedding('CC(C)(COP(=O)(O)OP(=O)(O)OC[C@H]1O[C@@H](n2ccc(=N)nc2O)[C@H](O)[C@@H]1OP(=O)(O)O)[C@@H](O)C(O)=NCCC(O)=NCCS')

Getting vector embedding
