<a href="https://colab.research.google.com/github/ashraqat03/Mini-Drug-Repurposing-Agent-Planner-Executor-Workflow-using-LangGraph/blob/main/Drug_repurposing_project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Testing individual parts

In [None]:
import pandas as pd

# Load the gene-disease association table once; find_targets below filters this
# module-level frame by its 'disease_name' column.
gene_disease_df = pd.read_csv('gene_disease.csv')

def find_targets(disease_query):
    """Return the gene-disease rows whose disease_name equals the query exactly.

    Filters the module-level ``gene_disease_df`` frame and returns the matching
    rows as a list of dicts (one dict per row). A query with no exact match
    yields an empty list.
    """
    is_match = gene_disease_df['disease_name'].eq(disease_query)
    return gene_disease_df.loc[is_match].to_dict('records')

# Smoke test: exact-match lookup for a single disease name.
targets = find_targets("Parkinson's disease")
print("Found targets:", targets)

Found targets: [{'gene_symbol': 'LRRK2', 'disease_name': "Parkinson's disease", 'association_score': 0.95}, {'gene_symbol': 'SNCA', 'disease_name': "Parkinson's disease", 'association_score': 0.89}, {'gene_symbol': 'PINK1', 'disease_name': "Parkinson's disease", 'association_score': 0.87}]


In [None]:
# Load the drug-target table once; find_compounds below filters this module-level
# frame by its 'target_gene' column.
drug_target_df = pd.read_csv('drug_target.csv')

def find_compounds(target_genes_list):
    """Return the drug-target rows whose target_gene is one of the given genes.

    Filters the module-level ``drug_target_df`` frame and returns the matching
    rows as a list of dicts; no match yields an empty list.
    """
    gene_mask = drug_target_df['target_gene'].isin(target_genes_list)
    hits = drug_target_df.loc[gene_mask]
    return hits.to_dict('records')

# Smoke test: look up drugs for two gene symbols at once.
compounds = find_compounds(["LRRK2", "SNCA"])
print("Found compounds:", compounds)

Found compounds: [{'drug_name': 'Rapamycin', 'target_gene': 'LRRK2', 'mechanism': 'inhibitor'}, {'drug_name': 'Nilotinib', 'target_gene': 'SNCA', 'mechanism': 'inhibitor'}]


PROJECT TRIAL 1

In [None]:
!pip install -qU langgraph langchain-google-genai google-generativeai pandas scikit-learn requests streamlit

In [None]:
import pandas as pd
import requests
import google.generativeai as genai
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
import joblib
import re

The librarian (finder)

In [None]:
def find_targets(disease_name):
    """
    Looks up a disease in the gene_disease.csv file and returns associated targets.

    The query is matched as a literal, case-insensitive substring of the
    'disease_name' column (regex metacharacters such as '(' or '.' have no
    special meaning), so "parkinson" finds "Parkinson's disease".

    Args:
        disease_name: Disease name or fragment to search for. Surrounding
            whitespace is ignored.

    Returns:
        {"targets": [row_dict, ...]} sorted by 'association_score' (highest
        first), or {"error": message} when the query is blank / not a string,
        nothing matches, or the CSV cannot be read.
    """
    # A blank query is a substring of every name and would return the whole table.
    if not isinstance(disease_name, str) or not disease_name.strip():
        return {"error": "Disease name must be a non-empty string."}
    query = disease_name.strip()

    try:
        df = pd.read_csv('gene_disease.csv')
        # regex=False: user text must not be interpreted as a regular expression.
        is_match = df['disease_name'].str.contains(query, case=False, na=False, regex=False)
        results_df = df[is_match]

        if results_df.empty:
            return {"error": f"No targets found for disease: '{disease_name}'."}

        return {"targets": results_df.sort_values('association_score', ascending=False).to_dict('records')}

    except Exception as e:
        return {"error": f"Failed to read gene-disease data: {str(e)}"}

# Smoke test: a partial disease name should still match via substring search.
test_result = find_targets("Parkinson's")

print("Test find_targets result:", test_result)

Test find_targets result: {'targets': [{'gene_symbol': 'LRRK2', 'disease_name': "Parkinson's disease", 'association_score': 0.95}, {'gene_symbol': 'SNCA', 'disease_name': "Parkinson's disease", 'association_score': 0.89}, {'gene_symbol': 'PINK1', 'disease_name': "Parkinson's disease", 'association_score': 0.87}]}


The pharmacist

In [None]:
def find_compounds(target_genes):
  """
  Looks up one or more target genes in the drug_target.csv file and returns associated compounds.

  Each matching drug row is annotated with a 'SMILES' column taken from the
  module-level ``compounds_df_corrected`` frame (joined on
  drug_name == compound_id). The join is a left join, so a drug without an
  entry in that frame keeps a missing SMILES value.

  Args:
    target_genes: A single gene symbol (str) or a list of gene symbols.

  Returns:
    {"compounds": [row_dict, ...]} on success, otherwise {"error": message}.
  """
  try:
    df_data = pd.read_csv('drug_target.csv')
  except Exception as e:
    return {"error": f"Failed to read drug-target data: {str(e)}"}

  try:
    if isinstance(target_genes, str):
      target_genes = [target_genes]

    results_df = df_data[df_data['target_gene'].isin(target_genes)]

    if results_df.empty:
      return {"error": f"No compounds found for target genes: {target_genes}"}

    # The SMILES table is built in a different notebook cell; report its absence
    # explicitly instead of disguising the NameError as a CSV read failure.
    try:
      smiles_lookup = compounds_df_corrected[['compound_id', 'SMILES']]
    except NameError:
      return {"error": "SMILES table 'compounds_df_corrected' is not defined; "
                       "run the cell that creates it before calling find_compounds."}

    # Attach SMILES, then drop the join key that duplicates 'drug_name'.
    compounds_with_smiles = results_df.merge(smiles_lookup,
                                              left_on='drug_name',
                                              right_on='compound_id',
                                              how='left')
    compounds_with_smiles = compounds_with_smiles.drop('compound_id', axis=1)

    return {"compounds": compounds_with_smiles.to_dict('records')}

  except Exception as e:
    return {"error": f"Compound lookup failed: {str(e)}"}

# Manual tests for find_compounds
# Case 1: a list of gene symbols
print("test1: list input:", find_compounds(["LRRK2","SNCA"]))

# Case 2: a single gene symbol passed as a plain string
print("test2: string input:", find_compounds("EGFR"))

# Case 3: unknown gene - just checks that the error path returns a message
print("test3: error test:", find_compounds("blablabla"))

test1: list input: {'compounds': [{'drug_name': 'Rapamycin', 'target_gene': 'LRRK2', 'mechanism': 'inhibitor', 'SMILES': 'C1CCCCC1'}, {'drug_name': 'Nilotinib', 'target_gene': 'SNCA', 'mechanism': 'inhibitor', 'SMILES': 'CC1=CC=CC(=C1)NC(=O)C2=CN=C3N2C=CC(=C3)NC(=O)C4=CC=C(C=C4)N5CCN(CC5)C'}]}
test2: string input: {'compounds': [{'drug_name': 'Gefitinib', 'target_gene': 'EGFR', 'mechanism': 'inhibitor', 'SMILES': 'COC1=CC2=C(C=C1OCCCN3CCOCC3)N=CN=C2NC4=CC=C(C=C4)Cl'}, {'drug_name': 'Erlotinib', 'target_gene': 'EGFR', 'mechanism': 'inhibitor', 'SMILES': 'COC1=CC2=C(C=C1OC)N=CN=C2NC3=CC=C(C=C3)OCCOC'}]}
test3: error test: {'error': "No compounds found for target genes: ['blablabla']"}


QSAR trial

In [None]:
!pip install --force-reinstall rdkit

Collecting rdkit
  Downloading rdkit-2025.3.6-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (4.1 kB)
Collecting numpy (from rdkit)
  Downloading numpy-2.3.3-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (62 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.1/62.1 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting Pillow (from rdkit)
  Downloading pillow-11.3.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (9.0 kB)
Downloading rdkit-2025.3.6-cp312-cp312-manylinux_2_28_x86_64.whl (36.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m36.1/36.1 MB[0m [31m30.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading numpy-2.3.3-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (16.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.6/16.6 MB[0m [31m62.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pillow-11.3.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.w

In [None]:
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.ML.Descriptors import MoleculeDescriptors
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import joblib
import numpy as np

In [None]:
# Names of the descriptors computed for every molecule. The list is handed to
# MolecularDescriptorCalculator below, so its order fixes the order of the
# feature vector used by the QSAR model.
descriptor_list = [
    'MolWt',
    'MolLogP',
    'NumHDonors',
    'NumHAcceptors',
    'NumRotatableBonds',
    'TPSA', # Topological Polar Surface Area (critical for permeability)
    'HeavyAtomCount',
    'RingCount',
    'FractionCSP3' # Measures carbon saturation (related to compound quality)
]

In [None]:
# Toy training set for the QSAR model: compound name, SMILES string and a
# binary activity label. The frame is kept in memory (find_compounds joins
# against it) and also written to compounds.csv for train_qsar_model.
# NOTE(review): the "Valid ..." labels below are the author's; they have not
# been checked against a reference structure database - confirm before relying
# on these structures.
compounds_data_corrected = {
    'compound_id': ['Rapamycin', 'Nilotinib', 'Gefitinib', 'Erlotinib', 'Olaparib', 'Trametinib', 'Metformin', 'Aspirin', 'Ibuprofen'],
    'SMILES': [
        'C1CCCCC1', # Simple cyclohexane stand-in for Rapamycin
        'CC1=CC=CC(=C1)NC(=O)C2=CN=C3N2C=CC(=C3)NC(=O)C4=CC=C(C=C4)N5CCN(CC5)C', # Valid Nilotinib
        'COC1=CC2=C(C=C1OCCCN3CCOCC3)N=CN=C2NC4=CC=C(C=C4)Cl', # Valid Gefitinib
        'COC1=CC2=C(C=C1OC)N=CN=C2NC3=CC=C(C=C3)OCCOC', # Valid Erlotinib
        'O=C(C1CCCCN1)NC2=CC=CC3=C2N=CN=C3N4CCOCC4', # Valid Olaparib
        'CNC(=O)C1=CC2=C(C=C1C3=CC=CC=C3)S(=O)(=O)C4=CC=C(C=C4)N5CCN(CCO)CC5', # Trametinib - NOT parseable: RDKit reports "unclosed ring" (ring bond 2 is never closed), so this row is skipped in training
        'CN(C)C(=N)N=C(N)N', # Metformin
        'CC(=O)OC1=CC=CC=C1C(=O)O', # Aspirin
        'CC(C)CC1=CC=C(C=C1)C(C)C(=O)O' # Ibuprofen
    ],
    'activity_label': [1, 1, 1, 1, 1, 1, 0, 0, 0]
}
compounds_df_corrected = pd.DataFrame(compounds_data_corrected)
compounds_df_corrected.to_csv('compounds.csv', index=False)
print("Created NEW compounds.csv with VALID SMILES strings.")

Created NEW compounds.csv with VALID SMILES strings.


In [None]:
# Shared descriptor calculator configured with the names in descriptor_list;
# compute_descriptors below calls it for each molecule.
calculator = MoleculeDescriptors.MolecularDescriptorCalculator(descriptor_list)

def compute_descriptors(smiles):
    """
    Calculates the physicochemical descriptors named in descriptor_list for a molecule.

    Args:
        smiles: SMILES string of the molecule.

    Returns:
        A 1-D numpy array with one value per descriptor (in the order used by
        the module-level ``calculator``), or None when ``smiles`` is not a
        non-blank string or RDKit cannot parse it.
    """
    # Missing values (None, or NaN produced by a left merge) and blank strings
    # are not molecules; treat them like an unparseable SMILES instead of
    # passing them to RDKit.
    if not isinstance(smiles, str) or not smiles.strip():
        return None

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    return np.array(calculator.CalcDescriptors(mol))


In [None]:
def train_qsar_model():
  """
  Trains a logistic-regression QSAR model on compounds.csv and saves it to disk.

  Reads the 'SMILES' and 'activity_label' columns, converts every parseable
  SMILES into a descriptor vector with compute_descriptors (rows that cannot
  be parsed are skipped), standardises the features and fits a class-balanced
  LogisticRegression.

  Side effects:
    Writes 'scaler.joblib' and 'qsar_model.joblib' to the working directory
    and prints a {"message": ...} dict on success or an {"error": ...} dict
    on failure. Nothing is returned and no exception propagates.
  """
  try:
    qsar_df = pd.read_csv("compounds.csv")

    X = []
    y = []
    for smiles, label in zip(qsar_df['SMILES'], qsar_df['activity_label']):
      # compute_descriptors parses the SMILES itself, so parsing it here as
      # well would only duplicate RDKit's error output for bad rows.
      desc_vector = compute_descriptors(smiles)
      if desc_vector is not None:
        X.append(desc_vector)
        y.append(label)

    X = np.array(X)
    y = np.array(y)

    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)

    model = LogisticRegression(random_state=42, class_weight='balanced')
    model.fit(X_scaled, y)

    # Persist both artefacts only after fitting succeeded, so the scaler on
    # disk always belongs to the model on disk.
    joblib.dump(scaler, 'scaler.joblib')
    joblib.dump(model, 'qsar_model.joblib')
    print({"message": f"Model trained successfully on {X.shape[0]} molecules with {X.shape[1]} descriptors."})

  except Exception as e:
    print({"error": f"Failed to train QSAR model: {str(e)}"})

In [None]:
def predict_activity(smiles_list):
    """
    Scores each SMILES string with the saved QSAR model.

    Loads 'qsar_model.joblib' and 'scaler.joblib', then returns
    {"predictions": [p, ...]} where each p is the class-1 probability rounded
    to two decimals. A SMILES that cannot be turned into descriptors gets 0.0
    and a printed warning. Any exception yields {"error": message} instead.
    """
    try:
        classifier = joblib.load('qsar_model.joblib')
        feature_scaler = joblib.load('scaler.joblib')

        def _score(smiles):
            # One probability per molecule; unparseable input falls back to 0.0
            features = compute_descriptors(smiles)
            if features is None:
                print(f"Warning: Could not parse SMILES '{smiles}'. Setting probability to 0.")
                return 0.0
            scaled_row = feature_scaler.transform(features.reshape(1, -1))
            active_proba = classifier.predict_proba(scaled_row)[0][1]
            return float(round(active_proba, 2))

        return {"predictions": [_score(smiles) for smiles in smiles_list]}
    except Exception as exc:
        return {"error": f"Prediction failed: {str(exc)}"}

In [None]:
# Train on compounds.csv, then score a few of the training molecules as a sanity check.
train_qsar_model()

# Metformin and Aspirin (labelled 0 in the training table)
test_smiles_inactive = ["CN(C)C(=N)N=C(N)N", "CC(=O)OC1=CC=CC=C1C(=O)O"]
# The Nilotinib entry (labelled 1 in the training table)
test_smiles_active = ["CC1=CC=CC(=C1)NC(=O)C2=CN=C3N2C=CC(=C3)NC(=O)C4=CC=C(C=C4)N5CCN(CC5)C"]

print("Test Inactive Compounds:", predict_activity(test_smiles_inactive))
print("Test Active Compound:", predict_activity(test_smiles_active))

{'message': 'Model trained successfully on 8 molecules with 9.'}
Test Inactive Compounds: {'predictions': [0.05, 0.14]}
Test Active Compound: {'predictions': [0.9]}


[18:48:24] SMILES Parse Error: unclosed ring for input: 'CNC(=O)C1=CC2=C(C=C1C3=CC=CC=C3)S(=O)(=O)C4=CC=C(C=C4)N5CCN(CCO)CC5'
[18:48:24] SMILES Parse Error: unclosed ring for input: 'CNC(=O)C1=CC2=C(C=C1C3=CC=CC=C3)S(=O)(=O)C4=CC=C(C=C4)N5CCN(CCO)CC5'


**Integrating Gemini**

In [None]:
# Gemini chat model shared by the planner node; the API key is read from the
# secret named 'API'.
# NOTE(review): neither ChatGoogleGenerativeAI nor userdata is imported in any
# cell visible in this notebook (the import cell only brings in
# google.generativeai as genai). Presumably these come from
# langchain_google_genai and google.colab respectively - confirm and add the
# imports so the cell survives Restart & Run All.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash",
    google_api_key=userdata.get('API')
)

Planner Node

In [None]:
# --- DEFINE THE PLANNER NODE ---
def planner_node(state):
    """
    This is the Planner Agent. It uses the Gemini LLM to analyze the user input
    and decide the first step in the workflow.
    It returns a decision that LangGraph will use to route the workflow.

    Args:
        state: Mapping with at least the key 'input' (the user's query).

    Returns:
        {"next_step": step} where step is always exactly 'find_targets' or
        'find_compounds'. Any reply that does not name exactly one of the two
        steps falls back to 'find_targets', mirroring rule 3 of the prompt.
    """
    print(f"Planner analyzing input: '{state['input']}'")

    # The system prompt
    planner_prompt = """
    You are an expert planner for a drug discovery AI. Your only task is to analyze the user's input and decide the first step.

    The rules are strict:
    1. If the user asks about a *disease* (e.g., "Parkinson's", "cancer", "Alzheimer's"), output: 'find_targets'
    2. If the user asks about a *protein*, *gene*, or *target* (e.g., "EGFR", "BRCA1", "LRRK2"), output: 'find_compounds'
    3. If you are unsure, output 'find_targets'

    Do not output anything else. No explanations. Just one word: either 'find_targets' or 'find_compounds'.
    """

    # Create the message for the LLM
    messages = [
        {"role": "user", "content": planner_prompt},
        {"role": "user", "content": f"User input: {state['input']}"}
    ]

    # Call the Gemini API
    response = llm.invoke(messages)
    raw_decision = response.content.strip().lower()

    # The router only has edges for these two step names, so free-form model
    # output (quotes, trailing punctuation, an explanation) must be reduced to
    # one of them before it is stored in the state.
    valid_steps = ("find_targets", "find_compounds")
    mentioned = [step for step in valid_steps if step in raw_decision]
    decision = mentioned[0] if len(mentioned) == 1 else "find_targets"
    if decision != raw_decision:
        print(f"Planner reply '{raw_decision}' normalised to '{decision}'")

    print(f"Planner decision: '{decision}'")

    # Return the decision to the LangGraph state
    return {"next_step": decision}

# Manual test of the planner on its own, without building the full graph.
# Each call sends a live request through `llm`.
# Case 1: a disease name
test_state = {"input": "Parkinson's disease"}
result = planner_node(test_state)
print("Test Planner Output:", result)

# Case 2: a gene symbol
test_state2 = {"input": "EGFR"}
result2 = planner_node(test_state2)
print("Test Planner Output 2:", result2)

Planner analyzing input: 'Parkinson's disease'
Planner decision: 'find_targets'
Test Planner Output: {'next_step': 'find_targets'}
Planner analyzing input: 'EGFR'
Planner decision: 'find_compounds'
Test Planner Output 2: {'next_step': 'find_compounds'}


LangGraph workflow

In [None]:
from langgraph.graph import StateGraph, END
from typing import TypedDict, List, Annotated
import operator

In [None]:
# Define the structure
class AgentState(TypedDict):
    """Shared state that the workflow nodes read from and write to.

    Only 'input' is supplied by the caller; each other key is filled in by
    the node named in its comment once that node has run.
    """
    # The user's original query (never modified by the nodes)
    input: str

    # Decision from the planner node - selects the first tool to run
    next_step: str

    # Gene-disease rows written by the find_targets node
    targets: List[dict]

    # Drug rows (including a 'SMILES' field) written by the find_compounds node
    compounds: List[dict]

    # Activity probabilities written by the predict_activity node, one float
    # per compound
    predictions: List[float]

    # Final report - reserved; no node in this notebook writes it yet
    report: str

    # Message set by a tool node when its lookup or prediction fails. The nodes
    # return {"error": ...}; declaring the key gives that update a slot in the
    # state (NOTE(review): undeclared keys appear to be dropped by the graph -
    # confirm against the installed LangGraph version).
    error: str

In [None]:
# Create an empty workflow graph whose state follows the AgentState schema.
# Re-run this cell before re-running any of the wiring cells below: they all
# mutate this same object.
workflow = StateGraph(AgentState)
print("initialized")

initialized


In [None]:
# find_targets
def find_targets_node(state):
    """Graph node wrapping find_targets: looks up targets for state['input'].

    Returns {"targets": [...]} on success or {"error": message} when the
    lookup reports an error.
    """
    print("Librarian finding targets...")
    lookup = find_targets(state['input'])
    if "error" not in lookup:
        return {"targets": lookup["targets"]}
    return {"error": lookup["error"]}

# find_compounds Node
def find_compounds_node(state):
    """Graph node wrapping find_compounds.

    Uses the gene symbols of state['targets'] when targets are present;
    otherwise treats state['input'] itself as a gene symbol. Returns
    {"compounds": [...]} on success or {"error": message} on failure.
    """
    print("Pharmacist finding compounds...")
    known_targets = state.get('targets')
    if known_targets:
        genes = [entry['gene_symbol'] for entry in known_targets]
    else:
        # No targets in the state: the user query was probably a gene already
        genes = [state['input']]

    lookup = find_compounds(genes)
    if "error" not in lookup:
        return {"compounds": lookup["compounds"]}
    return {"error": lookup["error"]}

# predict_activity Node
def predict_activity_node(state):
    """Graph node wrapping predict_activity.

    Reads the 'SMILES' field of every entry in state['compounds'] and scores
    them with predict_activity.

    Returns:
        {"predictions": [...]} on success, or {"error": message} when the
        state holds no compounds (e.g. the previous node failed) or the
        prediction itself reports an error.
    """
    print("Chemist predicting activity...")
    # The edge into this node is unconditional, so it also runs after a failed
    # compound lookup, when 'compounds' was never written to the state.
    compounds = state.get('compounds')
    if not compounds:
        return {"error": "No compounds available for activity prediction."}

    smiles_list = [c['SMILES'] for c in compounds]
    result = predict_activity(smiles_list)
    if "error" in result:
        return {"error": result["error"]}
    return {"predictions": result["predictions"]}  # Update state with predictions

print("All node functions defined")

All node functions defined


In [None]:
# Register every node under the name that the edges and the planner's
# routing decision refer to.
workflow.add_node("planner", planner_node)
workflow.add_node("find_targets", find_targets_node)
workflow.add_node("find_compounds", find_compounds_node)
workflow.add_node("predict_activity", predict_activity_node)

print("All nodes added to graph")
print("Current nodes:", workflow.nodes)

All nodes added to graph
Current nodes: {'planner': StateNodeSpec(runnable=planner(tags=None, recurse=True, explode_args=False, func_accepts={}), metadata=None, input_schema=<class '__main__.AgentState'>, retry_policy=None, cache_policy=None, ends=(), defer=False), 'find_targets': StateNodeSpec(runnable=find_targets(tags=None, recurse=True, explode_args=False, func_accepts={}), metadata=None, input_schema=<class '__main__.AgentState'>, retry_policy=None, cache_policy=None, ends=(), defer=False), 'find_compounds': StateNodeSpec(runnable=find_compounds(tags=None, recurse=True, explode_args=False, func_accepts={}), metadata=None, input_schema=<class '__main__.AgentState'>, retry_policy=None, cache_policy=None, ends=(), defer=False), 'predict_activity': StateNodeSpec(runnable=predict_activity(tags=None, recurse=True, explode_args=False, func_accepts={}), metadata=None, input_schema=<class '__main__.AgentState'>, retry_policy=None, cache_policy=None, ends=(), defer=False)}


In [None]:
def decide_next_step(state):
    """Return the planner's stored decision so the graph can follow the matching edge."""
    chosen_route = state['next_step']
    print(f"Routing to: {chosen_route}")
    return chosen_route

# Set the starting point
workflow.set_entry_point("planner")

# Add conditional routing after planner: decide_next_step returns the value of
# state['next_step'], which is looked up in the mapping below to pick the node.
# NOTE: this cell is not idempotent - running it a second time on the same
# `workflow` object raises "ValueError: Branch with name `decide_next_step`
# already exists for node `planner`". Re-create the workflow first.
workflow.add_conditional_edges(
    "planner",
    decide_next_step,
    {
        "find_targets": "find_targets",
        "find_compounds": "find_compounds"
    }
)

ValueError: Branch with name `decide_next_step` already exists for node `planner`

In [None]:
# Connect the nodes in sequence: targets -> compounds -> predictions -> END.
# These edges are unconditional, so each node runs even when the previous one
# returned only an error.
workflow.add_edge("find_targets", "find_compounds")
workflow.add_edge("find_compounds", "predict_activity")
workflow.add_edge("predict_activity", END)


<langgraph.graph.state.StateGraph at 0x785d68447710>

In [None]:
# Compile the wired graph into the runnable object invoked in the test below.
app = workflow.compile()

In [None]:
# Test the complete workflow!
# End-to-end run for a disease query (planner -> targets -> compounds -> predictions).
# NOTE(review): the summary below indexes 'targets', 'compounds' and
# 'predictions' directly; presumably a key is absent from the result when its
# node did not run or returned an error, which would raise KeyError here - confirm.
print("Testing complete workflow for Parkinson's disease...")
result = app.invoke({"input": "Parkinson's disease"})
print("\n=== FINAL RESULTS ===")
print("Targets found:", len(result['targets']))
print("Compounds found:", len(result['compounds']))
print("Predictions made:", len(result['predictions']))
print("\nSample prediction:", result['predictions'][0] if result['predictions'] else "None")

Testing complete workflow for Parkinson's disease...
Planner analyzing input: 'Parkinson's disease'
Planner decision: 'find_targets'
🔄 Routing to: find_targets
Librarian finding targets...
Pharmacist finding compounds...
Chemist predicting activity...

=== FINAL RESULTS ===
Targets found: 3
Compounds found: 2
Predictions made: 2

Sample prediction: 0.61
