In [None]:
import os
# Placeholder credentials for this session; paste real keys here (and never
# commit them) before running the cells below.
os.environ.update({
    "GEMINI_API_KEY": "",
    "TAVILY_API_KEY": "tvly-dev-",
})

In [6]:
# %pip (not !pip) installs into the environment of the running kernel.
# The SDK that provides `from tavily import TavilyClient` is published on PyPI
# as `tavily-python`.
%pip install langgraph langchain-core langchain-google-genai tavily-python requests neo4j

Collecting neo4j
  Downloading neo4j-6.0.2-py3-none-any.whl.metadata (5.2 kB)
Collecting pytz (from neo4j)
  Downloading pytz-2025.2-py2.py3-none-any.whl.metadata (22 kB)
Downloading neo4j-6.0.2-py3-none-any.whl (325 kB)
Downloading pytz-2025.2-py2.py3-none-any.whl (509 kB)
Installing collected packages: pytz, neo4j
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2/2[0m [neo4j]32m1/2[0m [neo4j]
[1A[2KSuccessfully installed neo4j-6.0.2 pytz-2025.2


In [20]:
import os
import re
import requests
from typing import TypedDict, List, Dict, Optional
from langgraph.graph import StateGraph, END
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import HumanMessage, SystemMessage
from tavily import TavilyClient

# Config: API credentials come from the environment; an unset variable yields "".
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY", "")

# State
class DrugState(TypedDict, total=True):
    """State dictionary threaded through the normalization graph."""

    # Drug name being normalized.
    drug_name: str
    # "Compound::..." identifier once a lookup succeeds; None until then.
    normalized_id: Optional[str]
    # Label of the lookup that produced normalized_id; None until then.
    source: Optional[str]

# RxNorm API (FREE - replaces DrugBank)
def search_rxnorm(drug_name: str) -> Optional[str]:
    """Resolve a drug name to a ``Compound::`` identifier via the RxNorm REST API.

    The name is matched with ``approximateTerm``; the RxCUI of the first
    candidate is then looked up in ``allProperties`` for a ``DRUGBANK`` property.

    Args:
        drug_name: Free-text drug name to look up.

    Returns:
        ``"Compound::DBxxxxx"`` when a ``DRUGBANK`` property whose value starts
        with ``"DB"`` is present; ``"Compound::RXCUI<rxcui>"`` when a concept was
        matched but no such property was found (or the property request was not
        HTTP 200); ``None`` when nothing matched, the first request was not
        HTTP 200, or an exception occurred (the exception is printed, not raised).
    """
    try:
        # Step 1: match the free-text name against RxNorm concepts.
        match_response = requests.get(
            "https://rxnav.nlm.nih.gov/REST/approximateTerm.json",
            params={"term": drug_name, "maxEntries": 5},
            timeout=10
        )
        if match_response.status_code != 200:
            return None

        matches = match_response.json().get("approximateGroup", {}).get("candidate", [])
        if not matches:
            return None

        # Accept either a list of candidates or a single candidate object.
        top_match = matches[0] if isinstance(matches, list) else matches
        rxcui = top_match.get("rxcui")
        if not rxcui:
            return None

        # Step 2: look for a DrugBank cross-reference among the concept's properties.
        detail_response = requests.get(
            f"https://rxnav.nlm.nih.gov/REST/rxcui/{rxcui}/allProperties.json",
            params={"prop": "all"},
            timeout=10
        )
        if detail_response.status_code == 200:
            concepts = detail_response.json().get("propConceptGroup", {}).get("propConcept", [])
            drugbank_values = (
                concept.get("propValue")
                for concept in concepts
                if concept.get("propName") == "DRUGBANK"
            )
            drugbank_id = next(
                (value for value in drugbank_values if value and value.startswith("DB")),
                None,
            )
            if drugbank_id:
                return f"Compound::{drugbank_id}"

        # No DrugBank ID available: fall back to an RxCUI-based identifier.
        return f"Compound::RXCUI{rxcui}"

    except Exception as err:
        print(f"RxNorm error: {err}")
    return None

# ChEMBL API (FREE)
def search_chembl(drug_name: str) -> Optional[str]:
    """Resolve a drug name to a ``Compound::CHEMBL...`` identifier via the ChEMBL API.

    Two endpoints are tried in order: the molecule free-text search, then the
    drug listing filtered with ``pref_name__icontains``. The first record of
    the first endpoint that yields a ``CHEMBL``-prefixed ``molecule_chembl_id``
    wins.

    Args:
        drug_name: Free-text drug name; surrounding whitespace is ignored.

    Returns:
        ``"Compound::CHEMBLxxxx"`` on a hit; ``None`` when the name is blank,
        neither endpoint produced a usable record, or an exception occurred
        (the exception is printed, not raised).
    """
    query = drug_name.strip() if isinstance(drug_name, str) else ""
    if not query:
        # An empty "contains" filter would presumably match every drug and
        # return an arbitrary identifier, so a blank name is never sent.
        return None

    # (endpoint URL, query parameters, key holding the result list)
    lookups = (
        (
            "https://www.ebi.ac.uk/chembl/api/data/molecule/search.json",
            {"q": query, "limit": 5},
            "molecules",
        ),
        (
            "https://www.ebi.ac.uk/chembl/api/data/drug.json",
            {"pref_name__icontains": query, "limit": 5},
            "drugs",
        ),
    )

    try:
        for url, params, result_key in lookups:
            response = requests.get(url, params=params, timeout=10)
            if response.status_code != 200:
                continue

            records = response.json().get(result_key, [])
            if not records:
                continue

            chembl_id = records[0].get("molecule_chembl_id")
            if chembl_id and chembl_id.startswith("CHEMBL"):
                return f"Compound::{chembl_id}"

    except Exception as e:
        print(f"ChEMBL error: {e}")
    return None

# Tavily + LLM (Fallback)
def search_tavily(drug_name: str) -> Optional[str]:
    """Fallback lookup: web-search the drug with Tavily and let Gemini extract an ID.

    Args:
        drug_name: Free-text drug name to look up.

    Returns:
        A bare ``"Compound::DBxxxxx"`` or ``"Compound::CHEMBLxxxx"`` identifier.
        ``None`` when either API key is missing, the search returned no
        results, the model answered ``NOT_FOUND`` or produced no well-formed
        identifier, or an exception occurred (printed, not raised).
    """
    if not TAVILY_API_KEY or not GEMINI_API_KEY:
        return None

    try:
        tavily = TavilyClient(api_key=TAVILY_API_KEY)
        llm = ChatGoogleGenerativeAI(
            model="gemini-2.0-flash-exp",
            temperature=0,
            api_key=GEMINI_API_KEY
        )

        # Search for drug info
        response = tavily.search(
            f"{drug_name} DrugBank ID CHEMBL ID drug database",
            search_depth="advanced",
            max_results=5
        )

        # A hit lacking one of the fields is still usable, so missing keys
        # become empty strings instead of aborting the whole lookup.
        results = [
            f"Title: {r.get('title', '')}\nContent: {r.get('content', '')}\nURL: {r.get('url', '')}\n"
            for r in response.get("results", [])
        ]

        if not results:
            return None

        context = "\n---\n".join(results)

        prompt = f"""Extract the official drug database ID for: "{drug_name}"

SEARCH RESULTS:
{context}

RULES:
- DrugBank IDs: "DB" + exactly 5 digits (e.g., DB00945)
- CHEMBL IDs: "CHEMBL" + digits (e.g., CHEMBL25)
- Prefer DrugBank over CHEMBL
- Output format: "Compound::DBxxxxx" or "Compound::CHEMBLxxxx"
- If not found: "NOT_FOUND"

OUTPUT (one line only):"""

        result = llm.invoke([
            SystemMessage(content="You extract drug database identifiers."),
            HumanMessage(content=prompt)
        ])

        reply = result.content if isinstance(result.content, str) else str(result.content)

        # An explicit NOT_FOUND wins even if the reply also mentions some ID.
        if "NOT_FOUND" in reply:
            return None

        # Return only the identifier itself: the model may wrap it in quotes
        # or append an explanation. The trailing \b rejects over-long digit
        # runs such as a six-digit "DB" number.
        found = re.search(r"Compound::(?:DB\d{5}|CHEMBL\d+)\b", reply)
        if found:
            return found.group(0)

    except Exception as e:
        print(f"Tavily error: {e}")
    return None

# LangGraph Nodes
def try_rxnorm(state: DrugState) -> DrugState:
    """LangGraph node: attempt normalization through RxNorm.

    On a hit, ``normalized_id`` and ``source`` are written into ``state``; on a
    miss the state is left as it was. The same state object is returned.
    """
    print("  → Trying RxNorm...")
    compound_id = search_rxnorm(state["drug_name"])
    if not compound_id:
        print("    ✗ Not found")
        return state

    state.update(normalized_id=compound_id, source="RxNorm")
    print(f"    ✓ Found: {compound_id}")
    return state

def try_chembl(state: DrugState) -> DrugState:
    """LangGraph node: attempt normalization through ChEMBL.

    On a hit, ``normalized_id`` and ``source`` are written into ``state``; on a
    miss the state is left as it was. The same state object is returned.
    """
    print("  → Trying ChEMBL...")
    compound_id = search_chembl(state["drug_name"])
    if not compound_id:
        print("    ✗ Not found")
        return state

    state.update(normalized_id=compound_id, source="ChEMBL")
    print(f"    ✓ Found: {compound_id}")
    return state

def try_tavily(state: DrugState) -> DrugState:
    """LangGraph node: attempt normalization through Tavily search plus the LLM.

    On a hit, ``normalized_id`` and ``source`` are written into ``state``; on a
    miss the state is left as it was. The same state object is returned.
    """
    print("  → Trying Tavily + LLM...")
    compound_id = search_tavily(state["drug_name"])
    if not compound_id:
        print("    ✗ Not found")
        return state

    state.update(normalized_id=compound_id, source="Tavily+LLM")
    print(f"    ✓ Found: {compound_id}")
    return state

def finalize(state: DrugState) -> DrugState:
    """Terminal node: report the outcome of normalization and pass the state through."""
    resolved_id = state.get("normalized_id")
    if not resolved_id:
        print(f"❌ Failed to normalize: {state['drug_name']}")
        return state

    print(f"✅ Normalized: {state['drug_name']} → {resolved_id} (via {state['source']})")
    return state

# Routing Logic
def route_after_rxnorm(state: DrugState) -> str:
    """Router after the RxNorm node: finish on a hit, otherwise try ChEMBL."""
    if state.get("normalized_id"):
        return "finalize"
    return "try_chembl"

def route_after_chembl(state: DrugState) -> str:
    """Router after the ChEMBL node: finish on a hit, otherwise try Tavily."""
    if state.get("normalized_id"):
        return "finalize"
    return "try_tavily"

def route_after_tavily(state: DrugState) -> str:
    """Router after the Tavily node: always continue to ``finalize``.

    The state is not inspected; the same target is returned for hit and miss.
    """
    next_node = "finalize"
    return next_node

# Build LangGraph
def build_normalization_graph():
    """Assemble and compile the drug-normalization workflow.

    The graph starts at ``try_rxnorm``; each lookup node is followed by its
    router, and ``finalize`` leads to ``END``.

    Returns:
        The result of ``StateGraph.compile()``.
    """
    # Lookup node name -> (node function, router run after that node).
    lookup_nodes = {
        "try_rxnorm": (try_rxnorm, route_after_rxnorm),
        "try_chembl": (try_chembl, route_after_chembl),
        "try_tavily": (try_tavily, route_after_tavily),
    }

    graph = StateGraph(DrugState)

    for node_name, (node_fn, _) in lookup_nodes.items():
        graph.add_node(node_name, node_fn)
    graph.add_node("finalize", finalize)

    graph.set_entry_point("try_rxnorm")

    for node_name, (_, router) in lookup_nodes.items():
        graph.add_conditional_edges(node_name, router)
    graph.add_edge("finalize", END)

    return graph.compile()

# Main API
def normalize_drug(drug_name: str) -> Dict:
    """Normalize a single drug name by running it through the lookup graph.

    Args:
        drug_name: Free-text drug name; surrounding whitespace is stripped
            before lookup.

    Returns:
        A dict with ``drug_name`` (the input as given), ``normalized_id``,
        ``source`` and ``success`` (True when an identifier was found). A blank
        or whitespace-only name yields a failed result without running the graph.
    """
    print(f"\n{'='*60}")
    print(f"Normalizing: {drug_name}")
    print(f"{'='*60}")

    cleaned_name = drug_name.strip()
    if not cleaned_name:
        # An empty search term cannot identify a drug, and whatever a remote
        # lookup returned for it would be meaningless, so skip the lookups.
        print("❌ Failed to normalize: empty drug name")
        return {
            "drug_name": drug_name,
            "normalized_id": None,
            "source": None,
            "success": False
        }

    graph = build_normalization_graph()
    final_state = graph.invoke({
        "drug_name": cleaned_name,
        "normalized_id": None,
        "source": None
    })

    normalized_id = final_state.get("normalized_id")
    return {
        "drug_name": drug_name,
        "normalized_id": normalized_id,
        "source": final_state.get("source"),
        "success": normalized_id is not None
    }

def normalize_multiple(drug_names: List[str]) -> Dict[str, Dict]:
    """Normalize several drug names.

    Repeated names are looked up only once. The result is keyed by name, so a
    repeat could never produce a second entry anyway; de-duplicating up front
    avoids redundant lookups and keeps the printed success ratio consistent
    with the returned dict.

    Args:
        drug_names: Drug names to normalize; order of first occurrence is kept.

    Returns:
        Mapping of each distinct input name to its ``normalize_drug`` result.
    """
    unique_names = list(dict.fromkeys(drug_names))

    print(f"\n{'='*60}")
    print(f"Normalizing {len(unique_names)} drugs")
    print(f"{'='*60}")

    results = {name: normalize_drug(name) for name in unique_names}

    successful = sum(1 for r in results.values() if r["success"])
    print(f"\n{'='*60}")
    print(f"✅ Success: {successful}/{len(results)}")
    print(f"{'='*60}\n")

    return results

In [21]:
from neo4j import GraphDatabase

# Neo4j Config
# Each setting can be overridden through an environment variable of the same
# name, so real credentials do not have to be stored in the notebook. The
# fallbacks are the local development values.
NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "12345678")

# DRKG Relationships
# Relationship labels treated as drug-drug interactions.
DDI_RELS = ["DRUGBANK::ddi-interactor-in::Compound:Compound"]
# Relationship labels treated as side effects.
SIDE_EFFECT_RELS = ["Hetionet::CcSE::Compound:Side Effect", "GNBR::Sa::Compound:Disease"]

def query_graph(drug_ids: List[str]) -> Dict:
    """Query Neo4j for the interactions and side effects of each drug.

    Args:
        drug_ids: Entity identifiers to look up (matched against the ``Entity``
            node property).

    Returns:
        ``{drug_id: {"interactions": [...], "side_effects": [...]}}``.
        Interaction partners that are themselves in ``drug_ids`` are fetched by
        a dedicated query and listed first, so the 200-row cap on the general
        partner query cannot hide an interaction between two requested drugs.

    Raises:
        Whatever the Neo4j driver raises; the driver is closed in that case too.
    """
    # Interactions restricted to the other requested drugs (no LIMIT needed:
    # at most len(drug_ids) - 1 rows).
    mutual_query = """
    MATCH (drug {Entity: $drug})-[r]-(partner)
    WHERE r.Relationship IN $ddi_rels AND partner.Entity IN $others
    RETURN DISTINCT partner.Entity AS interacting_drug
    """
    # General partner list, capped to keep the result manageable.
    ddi_query = """
    MATCH (drug {Entity: $drug})-[r]-(partner)
    WHERE r.Relationship IN $ddi_rels
    RETURN DISTINCT partner.Entity AS interacting_drug
    LIMIT 200
    """
    se_query = """
    MATCH (drug {Entity: $drug})-[r]-(effect)
    WHERE r.Relationship IN $se_rels
    RETURN DISTINCT effect.Entity AS side_effect
    LIMIT 200
    """

    results = {}

    # Context managers guarantee the driver and session are closed even when
    # a query fails part-way through.
    with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD)) as driver:
        with driver.session() as session:
            for drug in drug_ids:
                print(f"Querying graph for: {drug}")

                # Drug-Drug Interactions: requested partners first, then the
                # capped general list without repeating them.
                others = [other for other in drug_ids if other != drug]
                interactions = []
                if others:
                    mutual_records = session.run(
                        mutual_query, drug=drug, ddi_rels=DDI_RELS, others=others
                    )
                    interactions = [rec['interacting_drug'] for rec in mutual_records]

                already_listed = set(interactions)
                for rec in session.run(ddi_query, drug=drug, ddi_rels=DDI_RELS):
                    partner = rec['interacting_drug']
                    if partner not in already_listed:
                        already_listed.add(partner)
                        interactions.append(partner)

                print(f"  ✓ Found {len(interactions)} interactions")

                # Side Effects
                se_records = session.run(se_query, drug=drug, se_rels=SIDE_EFFECT_RELS)
                side_effects = [rec['side_effect'] for rec in se_records]

                print(f"  ✓ Found {len(side_effects)} side effects")

                results[drug] = {'interactions': interactions, 'side_effects': side_effects}

    return results

def _format_graph_context(graph_data: Dict, drug_name_mapping: Dict[str, str]) -> str:
    """Render each drug's interactions (first 10) and side effects (first 15) as markdown."""
    parts = []
    for drug_id, data in graph_data.items():
        drug_name = drug_name_mapping.get(drug_id, drug_id)
        parts.append(f"\n## {drug_name} ({drug_id})\n")

        if data['interactions']:
            parts.append("### Drug-Drug Interactions:\n")
            parts.extend(
                f"{i}. {interaction}\n"
                for i, interaction in enumerate(data['interactions'][:10], 1)
            )
        else:
            parts.append("### Drug-Drug Interactions: None found\n")

        if data['side_effects']:
            parts.append("### Side Effects:\n")
            parts.extend(
                f"{i}. {effect}\n"
                for i, effect in enumerate(data['side_effects'][:15], 1)
            )
        else:
            parts.append("### Side Effects: None found\n")
    return "".join(parts)


def _find_mutual_interactions(graph_data: Dict, drug_name_mapping: Dict[str, str]) -> List[str]:
    """List each pair of input drugs that interact, one line per unordered pair."""
    drug_codes = list(graph_data.keys())
    mutual = []
    for position, d1 in enumerate(drug_codes):
        for d2 in drug_codes[position + 1:]:
            # A pair counts when either drug lists the other, so an
            # interaction recorded on only one side is still reported, and a
            # symmetric one is reported once rather than twice.
            if d2 in graph_data[d1]['interactions'] or d1 in graph_data[d2]['interactions']:
                name1 = drug_name_mapping.get(d1, d1)
                name2 = drug_name_mapping.get(d2, d2)
                mutual.append(f"{name1} ({d1}) ↔ {name2} ({d2})")
    return mutual


def analyze_with_llm(graph_data: Dict, drug_name_mapping: Dict[str, str]) -> str:
    """
    Analyze drug combination using LLM with proper drug names

    Args:
        graph_data: Graph query results, keyed by normalized drug ID; each
            value holds ``'interactions'`` and ``'side_effects'`` lists
        drug_name_mapping: Dict mapping normalized_id -> original drug name

    Returns:
        The ``content`` of the LLM response.
    """
    llm = ChatGoogleGenerativeAI(
        model="gemini-2.0-flash-exp",
        temperature=0.1,
        api_key=GEMINI_API_KEY
    )

    # Create drug name reference for LLM
    drug_names_section = "\n## DRUG NAME REFERENCE:\n" + "".join(
        f"- {drug_id} = **{drug_name}**\n"
        for drug_id, drug_name in drug_name_mapping.items()
    )

    context = _format_graph_context(graph_data, drug_name_mapping)

    drug_codes = list(graph_data.keys())
    combination_line = ', '.join(
        f"{drug_name_mapping.get(code, code)} ({code})" for code in drug_codes
    )

    mutual = _find_mutual_interactions(graph_data, drug_name_mapping)
    if mutual:
        mutual_section = "\n".join(mutual)
    else:
        mutual_section = "No direct interactions found between the input drugs"

    prompt = f"""You are a pharmaceutical expert. Analyze this drug combination for patient safety.

{drug_names_section}

DRUGS IN COMBINATION:
{combination_line}

KNOWLEDGE GRAPH DATA FROM DRKG:
{context}

DIRECT INTERACTIONS DETECTED BETWEEN INPUT DRUGS:
{mutual_section}

IMPORTANT INSTRUCTIONS:
1. Always use the actual drug names (from DRUG NAME REFERENCE above)
2. Convert all DrugBank IDs to drug names when discussing them
3. Explain interactions in plain language
4. For MESH disease codes (e.g., Disease::MESH:D001986), provide the condition name if you know it

Provide your analysis in this format:

## 1. Drug Names & Overview
List each drug with its common name and therapeutic class

## 2. Overall Safety Assessment
Risk level: [Low/Moderate/High/Critical]
Brief explanation (2-3 sentences)

## 3. Direct Drug-Drug Interactions
Explain interactions between the input drugs (if any)
- Which drugs interact
- Mechanism of interaction
- Clinical significance

## 4. Interactions with Other Medications
Explain notable interactions with commonly prescribed drugs found in the data
- Mention the drug classes involved
- Clinical relevance

## 5. Side Effects Profile
Common side effects for each drug
- Convert MESH codes to condition names when possible
- Note overlapping side effects
- Indicate severity

## 6. Clinical Recommendations
- Can these drugs be taken together safely?
- Required monitoring (e.g., INR for anticoagulants, renal function)
- Timing considerations
- Contraindications

ANALYSIS:"""

    print("\nGenerating clinical analysis with LLM...")
    response = llm.invoke(prompt)
    return response.content

def analyze_prescription(medicines: str) -> Dict:
    """Complete end-to-end analysis pipeline with proper drug name tracking

    Args:
        medicines: Comma-separated drug names; blank entries are ignored.

    Returns:
        On success, a dict with ``input``, ``extracted_drugs``,
        ``normalization_results``, ``normalized_ids`` (distinct IDs, first
        occurrence order), ``drug_name_mapping`` (ID -> name; names sharing an
        ID are joined with " / "), ``unresolved_drugs``, ``graph_data`` and
        ``analysis``. When no drug could be normalized, a dict with ``error``,
        ``normalization_results`` and ``unresolved_drugs``.
    """
    print(f"\n{'='*70}")
    print("PRESCRIPTION ANALYSIS PIPELINE")
    print(f"{'='*70}")

    # Step 1: Extract medicines
    drug_list = [m.strip() for m in medicines.split(',') if m.strip()]
    print(f"\n📋 Extracted {len(drug_list)} medicines: {drug_list}")

    # Step 2: Normalize
    print(f"\n🔄 Step 1: Normalizing drug names...")
    norm_results = normalize_multiple(drug_list)

    # Step 3: Group input names by normalized ID. Two names may resolve to the
    # same compound; it is then queried once and keeps both names.
    names_by_id = {}
    unresolved_drugs = []
    for original_name, result in norm_results.items():
        if result["success"]:
            names_by_id.setdefault(result["normalized_id"], []).append(original_name)
        else:
            unresolved_drugs.append(original_name)

    normalized_ids = list(names_by_id)
    drug_name_mapping = {
        compound_id: " / ".join(names) for compound_id, names in names_by_id.items()
    }

    if unresolved_drugs:
        # Drugs that could not be identified are absent from the graph query
        # and the LLM analysis; make that explicit rather than silent.
        print(f"\n⚠️ Could not normalize: {unresolved_drugs} (NOT covered by the analysis)")

    if not normalized_ids:
        return {
            "error": "No drugs could be normalized",
            "normalization_results": norm_results,
            "unresolved_drugs": unresolved_drugs
        }

    print(f"\n✅ Successfully normalized {len(normalized_ids)} drugs")
    print(f"   Mapping: {drug_name_mapping}")

    # Step 4: Query graph
    print(f"\n📊 Step 2: Querying knowledge graph...")
    graph_data = query_graph(normalized_ids)

    # Step 5: Analyze with LLM (pass drug name mapping)
    print(f"\n🤖 Step 3: Generating clinical analysis...")
    analysis = analyze_with_llm(graph_data, drug_name_mapping)

    print(f"\n{'='*70}")
    print("CLINICAL ANALYSIS")
    print(f"{'='*70}\n")
    print(analysis)
    print(f"\n{'='*70}\n")

    return {
        "input": medicines,
        "extracted_drugs": drug_list,
        "normalization_results": norm_results,
        "normalized_ids": normalized_ids,
        "drug_name_mapping": drug_name_mapping,
        "unresolved_drugs": unresolved_drugs,
        "graph_data": graph_data,
        "analysis": analysis
    }

In [22]:
# Demo: run the full pipeline on a three-drug prescription.
prescription_text = "Aspirin, Ibuprofen, Warfarin"
result = analyze_prescription(prescription_text)


PRESCRIPTION ANALYSIS PIPELINE

📋 Extracted 3 medicines: ['Aspirin', 'Ibuprofen', 'Warfarin']

🔄 Step 1: Normalizing drug names...

Normalizing 3 drugs

Normalizing: Aspirin
  → Trying RxNorm...
    ✓ Found: Compound::DB00945
✅ Normalized: Aspirin → Compound::DB00945 (via RxNorm)

Normalizing: Ibuprofen
  → Trying RxNorm...
    ✓ Found: Compound::DB01050
✅ Normalized: Ibuprofen → Compound::DB01050 (via RxNorm)

Normalizing: Warfarin
  → Trying RxNorm...
    ✓ Found: Compound::DB00682
✅ Normalized: Warfarin → Compound::DB00682 (via RxNorm)

✅ Success: 3/3


✅ Successfully normalized 3 drugs
   Mapping: {'Compound::DB00945': 'Aspirin', 'Compound::DB01050': 'Ibuprofen', 'Compound::DB00682': 'Warfarin'}

📊 Step 2: Querying knowledge graph...
Querying graph for: Compound::DB00945
  ✓ Found 200 interactions
  ✓ Found 123 side effects
Querying graph for: Compound::DB01050
  ✓ Found 200 interactions
  ✓ Found 200 side effects
Querying graph for: Compound::DB00682
  ✓ Found 200 interactions
  