In [None]:
## Context Retrieval
# Code for DDXPlus only
# All classification context (DDXPlus, SymptomsDisease, Symptom2Disease) already retrieved into context.csv file
# DO NOT RUN CODE BELOW

In [None]:
## Env setup
import getpass
import os
from langchain_community.graphs import Neo4jGraph
from langchain_openai import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from dotenv import load_dotenv

# LLM
load_dotenv()
# Fail early with a clear message. The previous
# `os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')` was a no-op when
# the key was set and raised an opaque TypeError (None assigned into
# os.environ) when it was missing.
if not os.getenv("OPENAI_API_KEY"):
    raise EnvironmentError("OPENAI_API_KEY is not set; export it or add it to the .env file.")
#os.environ["OPENAI_BASE_URL"] = os.getenv('OPENAI_BASE_URL')

llm = ChatOpenAI(model="gpt-4o-mini")

# Neo4j env (lmkg)
# Connection settings are taken from the environment (.env) when present.
# SECURITY: the literals below are legacy fallbacks kept only so the notebook
# still runs unchanged; rotate this password and delete the fallbacks once
# NEO4J_URI / NEO4J_USERNAME / NEO4J_PASSWORD are configured.
url = os.getenv("NEO4J_URI", "neo4j://43.140.200.9:7687")
username = os.getenv("NEO4J_USERNAME", "neo4j")
password = os.getenv("NEO4J_PASSWORD", "20230408")
graph = Neo4jGraph(
    url=url,
    username=username,
    password=password,
    refresh_schema=False
)

# CSV env
import sys
from langchain_community.document_loaders.csv_loader import CSVLoader
from pathlib import Path
from langchain_openai import ChatOpenAI,OpenAIEmbeddings
from dotenv import load_dotenv
import pandas as pd

# Make the local rag_techniques helpers importable. The folder can be overridden
# with the RAG_TECHNIQUES_DIR environment variable; the default is the original
# machine-specific location. The membership check keeps cell re-runs from
# appending the same entry again.
_rag_techniques_dir = os.getenv("RAG_TECHNIQUES_DIR", r'C:\Users\Sin Yee\Desktop\rag_techniques')
if _rag_techniques_dir not in sys.path:
    sys.path.append(_rag_techniques_dir)

In [None]:
## Query & Knowledge Agent Graph
## Helper utilities
from typing import List, Optional, Literal, TypedDict, Union
from langchain_core.language_models.chat_models import BaseChatModel
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.types import Command
from langchain_core.messages import HumanMessage, trim_messages

# Query agent
class QState(MessagesState):
    """State schema for the query-agent graph (planner and multi-step decomposition)."""
    # Declared for routing; none of the query-agent nodes visible in this section write it.
    next: str
    # Set to True by planner_node after the planner has classified the question.
    planner_executed: bool = False

# Knowledge agent
class State(MessagesState):
    """State schema for the knowledge-agent graph (supervisor plus retrieval worker nodes)."""
    # NOTE(review): nodes read these keys with state.get(key, default) instead of relying on
    # the class-level defaults below; whether the schema applies those defaults at runtime is
    # not visible here — confirm. If it does, the {} / [] literals would be shared objects.

    # Agent name chosen by the supervisor ("FINISH" or the END sentinel when routing stops).
    next: str
    # The user's question, captured by the supervisor from the first human message.
    original_question: str = ""
    # Maps agent name -> retrieved reference (a string, a list of strings, or for lmkg
    # possibly a dict with "generated_cypher" / "retrieved_result").
    references: dict = {}
    # Names of agents that have already stored a reference.
    completed_agents: list = []
    # Names of agents the supervisor dropped in favour of the llmself fallback.
    excluded_agents: list = []

def _is_low_quality_reference(ref) -> bool:
    """Return True when a stored reference carries no usable information.

    Two shapes count as low quality:
    * the literal string "No information retrieved";
    * a cypher reference (a dict containing "generated_cypher") whose
      "retrieved_result" is falsy or a list made only of blank entries.
    """
    if ref == "No information retrieved":
        return True
    if isinstance(ref, dict) and "generated_cypher" in ref:
        retrieved = ref.get("retrieved_result")
        if not retrieved:
            return True
        if isinstance(retrieved, list) and not any(str(r).strip() for r in retrieved):
            return True
    return False


def make_supervisor_node(llm: BaseChatModel, members: list[str], agent_scopes: Optional[dict] = None, max_agents: int = 6) -> callable:
    """
    Create supervisor node that only routes between agents for reference retrieval

    Args:
        llm: Chat model used (with structured output) to pick the next agent.
        members: Names of the worker nodes the supervisor may route to.
        agent_scopes: Optional mapping of agent name -> one-line specialty shown
            to the LLM. Agents missing from the mapping are described as
            "General purpose agent". Defaults to the built-in medical scopes.
        max_agents: Number of completed agents after which routing stops
            (default 6, the previously hard-coded target).

    Returns:
        A node function taking the graph state and returning a Command that
        either routes to one member or to END.
    """
    if agent_scopes is None:
        agent_scopes = {
            "lmkg": "General medical queries using knowledge graph: diseases, exams, indicators, symptoms, complications etc.",
            "hkg": "For questions on symptom.",
            "ds": "For questions on symptoms or treatments.",
            "primekg": "To retrieve drugs indicated for diseases.",
            "drugreviews": "To retrieve drugs along with patient reviews.",
            "wiki": "For disease definitions and overviews.",
            "mayoclinic": "For clinical info: causes, treatments, symptoms, complications.",
            "llmself": "Fallback only if others are irrelevant."
        }

    # Declared once per supervisor instead of on every node invocation.
    class Router(TypedDict):
        """Worker to route to next. Select the most relevant agent for reference retrieval."""
        next: str

    def supervisor_node(state: State) -> Command:
        """Routes to agents for reference retrieval only."""
        # Get the original question (first human message on the first pass).
        original_question = state.get("original_question") or next(
            (msg.content for msg in state["messages"] if msg.type == "human"),
            None,
        )
        if original_question is None:
            raise ValueError("Supervisor received no human message to use as the question.")

        # Work on copies so the objects held by the graph state are never
        # mutated in place; every change is returned explicitly via `update`.
        completed_agents = list(state.get("completed_agents", []))
        references = dict(state.get("references", {}))
        excluded_agents = list(state.get("excluded_agents", []))

        # Fallback REPLACEMENT logic: the first completed agent whose reference
        # is empty is swapped for the llmself fallback (at most once, because
        # llmself then appears in completed_agents).
        if "llmself" in members and "llmself" not in completed_agents:
            low_quality_agent = next(
                (agent for agent in completed_agents
                 if _is_low_quality_reference(references.get(agent, ""))),
                None,
            )
            if low_quality_agent is not None:
                print(f"⚠️ Replacing low-quality agent '{low_quality_agent}' with fallback agent 'llmself'.")

                completed_agents.remove(low_quality_agent)
                references.pop(low_quality_agent, None)
                excluded_agents.append(low_quality_agent)  # Mark as excluded

                return Command(goto="llmself", update={
                    "next": "llmself",
                    "original_question": original_question,
                    "completed_agents": completed_agents,
                    # Returned explicitly so the removed reference is really dropped.
                    "references": references,
                    "excluded_agents": excluded_agents
                })

        # Enough references collected - finish.
        if len(completed_agents) >= max_agents:
            return Command(goto=END, update={
                "next": "FINISH",
                "original_question": original_question
            })

        remaining_agents = [agent for agent in members
                            if agent not in completed_agents and agent not in excluded_agents]

        # Every member has either completed or been excluded.
        if not remaining_agents:
            return Command(goto=END, update={
                "next": "FINISH",
                "original_question": original_question
            })

        # Only the remaining agents are offered to the LLM.
        options = ["FINISH"] + remaining_agents

        # System prompt for agent selection
        system_prompt_updated = (
            "You are a supervisor selecting agents for medical reference retrieval. "
            "Choose one agent at a time based on query relevance.\n\n"
            "Agent options and specialties:\n"
        )

        for agent in remaining_agents:
            scope = agent_scopes.get(agent, "General purpose agent")
            system_prompt_updated += f"- {agent}: {scope}\n"

        system_prompt_updated += (
            f"\nCompleted agents: {completed_agents} (Target: {max_agents} agents)\n"
            f"Available agents: {remaining_agents}\n\n"
            f"GOAL: Select up to {max_agents} MOST RELEVANT agents.\n"
            f"Progress: {len(completed_agents)}/{max_agents} done\n\n"
            f"Instructions:\n"
            f"1. If {max_agents} agents are done, select FINISH\n"
            f"2. Choose the most relevant agent from remaining\n"
            f"3. Prioritize by question type:\n"
            f"   - Symptoms: hkg > ds > mayoclinic > wiki > lmkg > drugreviews\n"
            f"   - Drugs/Medications: drugreviews > primekg \n"
            f"   - Complications: lmkg > mayoclinic\n"
            f"   - Fallback (only if others unsuitable): llmself\n"
            f"4. Pick the next best from: {remaining_agents}\n\n"
            f"DO NOT select any agents not listed in 'Valid responses' below.\n\n"
            f"Valid responses: {options}\n"
            f"Use {max_agents} best agents, then FINISH."
        )

        messages = [
            {"role": "system", "content": system_prompt_updated},
            {"role": "user", "content": f"Question: {original_question}\n\nSelect the most relevant agent:"}
        ]

        response = llm.with_structured_output(Router).invoke(messages)
        # A missing/empty structured response is treated like an invalid
        # selection below instead of crashing on the subscript.
        goto = (response or {}).get("next")

        print(f"Question: {original_question}")
        print(f"Remaining agents: {remaining_agents}")
        print(f"LLM selected: {goto}")
        print(f"Valid options: {options}")

        # Validate that the response is in our options
        if goto not in options:
            print(f"Warning: Invalid selection '{goto}'. Defaulting to FINISH.")
            goto = "FINISH"

        if goto == "FINISH":
            goto = END

        return Command(goto=goto, update={
            "next": goto,
            "original_question": original_question
        })

    return supervisor_node

## Tools (remain the same for retrieval)
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate 
from langchain_core.prompts import ChatPromptTemplate

def planner_tool():
    """Build a callable that labels a question as "single-step" or "multi-step".

    The returned function sends the question through an LLMChain built on the
    module-level `llm` and returns the chain's 'text' output.
    """
    system_text = """You are a query planning agent for a medical knowledge system. 
        Determine if the user's question is:
        
        1. "single-step" - Direct factual questions
           e.g., "What is diabetes?", "What are symptoms of flu?", "Which drug treats malaria?"
        
        2. "multi-step" - Complex questions needing multiple facts or reasoning
           e.g., "Compare treatment options for diabetes vs hypertension", 
           "What's the best treatment for hypertension in a diabetic patient with renal impairment?"
        
        Respond with either "single-step" or "multi-step"."""
    prompt = ChatPromptTemplate.from_messages([
        ("system", system_text),
        ("human", "{input}"),
    ])
    chain = LLMChain(llm=llm, prompt=prompt)

    def planner_agent(question: str):
        """Classify one question and return the raw LLM text."""
        outcome = chain.invoke({"input": question})
        return outcome['text']

    return planner_agent

## Tools for reference retrieval only
from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings
from langchain.chains import RetrievalQA # vector
from langchain.chains import GraphCypherQAChain # cypher
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate 

def lmkg_retrieval_tool():
    """Build the LMKG agent that can query the knowledge graph two ways.

    Two tools are registered on an OpenAI-functions agent:
    * ``med_vector_qa``  - FAISS similarity search over the "lmkg_faiss" index;
    * ``med_cypher_qa``  - LLM-generated Cypher anchored at the node whose
      embedding is closest to the question, executed via GraphCypherQAChain.

    Uses the module-level ``llm`` and ``graph``.

    Returns:
        (AgentExecutor, FAISS): the agent executor (configured to return
        intermediate steps) and the loaded vector store.
    """
    from typing import Any, Dict

    from langchain import hub
    from langchain.agents import AgentExecutor, Tool, create_openai_functions_agent

    # A) Tool's helper function
    def build_schema(graph):
        """Render node labels/properties and relationship types as schema text."""
        node_info = graph.query("""
            MATCH (n)
            WITH labels(n) AS label_list, keys(n) AS props
            UNWIND label_list AS label
            RETURN DISTINCT label, collect(DISTINCT props) AS property_keys
        """)

        rel_info = graph.query("""
            MATCH ()-[r]->()
            RETURN DISTINCT type(r) AS rel_type, keys(r) AS property_keys
        """)

        # Build string
        schema = "Node properties are the following:\n"
        for row in node_info:
            # property_keys is a list of key lists (one per distinct key set).
            # Merge all of them, first-seen order, so properties that only
            # appear on some nodes of a label are not dropped.
            merged_props = dict.fromkeys(
                key for key_set in (row['property_keys'] or []) for key in key_set
            )
            schema += f"{row['label']} {{{', '.join(merged_props)}}}\n"

        schema += "\nThe relationships are the following:\n"
        # The query is DISTINCT on (type, keys), so a type can come back more
        # than once; list each relationship type a single time.
        for rel_type in dict.fromkeys(row['rel_type'] for row in rel_info):
            schema += f"(:X)-[:{rel_type}]->(:Y)\n"

        return schema

    # B) Vector search
    # NOTE: allow_dangerous_deserialization opts in to loading a pickled
    # docstore; only load index folders you created yourself.
    lmkg_vector = FAISS.load_local(
        "lmkg_faiss",
        OpenAIEmbeddings(model='text-embedding-3-small'),
        allow_dangerous_deserialization=True)

    def vector_search(input_data: Union[str, Dict[str, Any]]) -> dict:
        """Returns the question and the page contents of the 4 nearest documents."""
        question = input_data if isinstance(input_data, str) else input_data.get("input", "")
        if not question:
            return {"error": "No question provided"}

        # Retrieve documents
        retrieved_docs = lmkg_vector.similarity_search(question, k=4)
        retrieved_context = [doc.page_content for doc in retrieved_docs]

        return {
            "query": question,
            "retrieved_context": retrieved_context
        }

    # C) Cypher search
    CYPHER_GENERATION_TEMPLATE = """
    Task: Generate precise Cypher query to answer the question:
    {question}

    Requirements:
    1. MUST start from node ID: {node_id} using WHERE id(n) = {node_id}
    2. MUST wrap labels/relationships with spaces in backticks.  
    E.g., [:`Active Ingredient`], [:Complication]
    3. Use only ONE relationship type
    4. Return ONLY what's needed to answer the question
    5. Use this schema:

    Schema:
    {schema}

    Return ONLY the executable Cypher query with no additional text.
    """

    CYPHER_GENERATION_PROMPT = PromptTemplate(
        input_variables=["schema", "question", "node_id"], template=CYPHER_GENERATION_TEMPLATE
    )

    cypher_chain = GraphCypherQAChain.from_llm(
        cypher_llm = llm,
        qa_llm = llm, graph=graph, verbose=True,
        cypher_prompt=CYPHER_GENERATION_PROMPT,
        return_direct=True, # bypass qa_llm
        return_intermediate_steps=True, # return cypher query & retrieved context
        allow_dangerous_requests=True
    )

    # Build the schema once (each build scans the whole graph twice) and reuse
    # it both as the chain's schema and as the explicit prompt input.
    schema = build_schema(graph)
    cypher_chain.graph_schema = schema

    # D) LMKG Agent
    # Use vector search & cypher query to explore KG

    # Wrap cypher chain in a function with proper input handling
    def cypher_search(input_data: Union[str, Dict[str, Any]]) -> dict:
        """Executes generated Cypher and returns the query, anchor node id and raw result.

        Any exception is reported as {"error": "<message>"} so the agent can
        continue with its other tool.
        """
        try:
            question = input_data if isinstance(input_data, str) else input_data.get("input", "")
            if not question:
                return {"error": "No question provided"}

            similar_nodes = lmkg_vector.similarity_search(question, k=1)
            if not similar_nodes:
                return {"error": "No matching nodes found"}

            node_id = similar_nodes[0].metadata['node_id']
            print(f"Retrieved node id: {node_id}")

            # Inject into GraphCypherQAChain
            cypher_response = cypher_chain.invoke({
                "query": question,
                "schema": schema,
                "node_id": node_id
            })

            print("cypher_response:")
            print(cypher_response)

            intermediate = cypher_response.get("intermediate_steps", [])

            generated_cypher = next((step["query"] for step in intermediate if "query" in step), None)
            retrieved_context = cypher_response.get("result", [])

            return {
                "query": question,
                "node_id": node_id,
                "generated_cypher": generated_cypher,
                "retrieved_result": retrieved_context
            }

        except Exception as e:
            return {"error": str(e)}

    tools = [
        Tool(
            name="med_vector_qa",
            func=vector_search,
            description="""Use for open-ended, semantic, or fuzzy medical questions requiring
            contextual understanding from vector-based retrieval.
            Examples:
            "What is ADHD?"
            "Explain the mechanism of diabetes."
            "Long-term effects of COVID-19?"
            Only use this tool once.
            """,
        ),
        Tool(
            name="med_cypher_qa",
            func=cypher_search,
            # Fixed: the literal previously opened with four quotes, which put
            # a stray '"' at the start of the description sent to the model.
            description="""Use for specific, factual medical queries (about diseases, symptoms, complications, 
            affected organs, treatments, drugs etc.) using structured graph data. 
            Examples:
            "What is the symptom of ADHF?"
            "Organs affected by HIV?"
            "Complications of diabetes?"
            "Drug for tuberculosis?"
            Only use this tool once.
            """,
        ),
    ]

    prompt = hub.pull("hwchase17/openai-functions-agent")
    agent = create_openai_functions_agent(llm, tools, prompt)

    # Create agent executor
    lmkg_agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, return_intermediate_steps=True)
    return lmkg_agent_executor, lmkg_vector

def _load_faiss_index(index_dir: str, embedding_model: str = 'text-embedding-ada-002'):
    """Load the FAISS index stored in `index_dir`, embedding queries with `embedding_model`."""
    # allow_dangerous_deserialization opts in to loading a pickled docstore;
    # only point this at index folders you created yourself.
    return FAISS.load_local(
        index_dir,
        OpenAIEmbeddings(model=embedding_model),
        allow_dangerous_deserialization=True)

def hkg_retrieval_tool():
    """HKG tool for reference retrieval only: returns the "faiss_hkg" vector store."""
    return _load_faiss_index("faiss_hkg")

def ds_retrieval_tool():
    """DS tool for reference retrieval only: returns the "diseases_symptoms_faiss" vector store."""
    return _load_faiss_index("diseases_symptoms_faiss")

def primekg_retrieval_tool():
    """PrimeKG tool for reference retrieval only: returns the "primekg_faiss" vector store."""
    return _load_faiss_index("primekg_faiss")

def drugreviews_retrieval_tool():
    """Drug-reviews tool for reference retrieval only: returns the "drug_reviews_faiss" vector store."""
    return _load_faiss_index("drug_reviews_faiss")

def _make_entity_extractor(underscore_spaces: bool):
    """Build a function that asks the DeepSeek chat model for a query's core entity.

    Args:
        underscore_spaces: When True the returned entity has spaces replaced by
            underscores; otherwise it is returned as stripped text.

    Raises:
        KeyError: if DEEPSEEK_API_KEY or DEEPSEEK_BASE_URL is not set.
    """
    from openai import OpenAI

    client = OpenAI(api_key=os.environ["DEEPSEEK_API_KEY"], base_url=os.environ["DEEPSEEK_BASE_URL"])

    def extract_entity_from_query(query):
        """Return the main entity named in `query` as reported by the model."""
        response = client.chat.completions.create(
            model="deepseek-chat",
            messages=[{"role": "user", "content": f"What is the main entity or concept in this query: '{query}'? Extract only the core term (e.g., 'hyponatremia', not 'causes of hyponatremia'). Return only the entity name in one line."}]
        )
        # message.content is optional in the OpenAI SDK; treat a missing
        # completion as an empty entity instead of failing on None.strip().
        entity = (response.choices[0].message.content or "").strip()
        print(f'Extracted entity: {entity}')
        return entity.replace(" ", "_") if underscore_spaces else entity

    return extract_entity_from_query

def wiki_retrieval_tool():
    """Wiki tool retrieval: entity extractor whose result uses underscores instead of spaces."""
    return _make_entity_extractor(underscore_spaces=True)

def mayoclinic_retrieval_tool():
    """MayoClinic tool retrieval: entity extractor returning the entity text unchanged."""
    return _make_entity_extractor(underscore_spaces=False)

def llmself_tool():
    """Build a callable that answers a question from the LLM's own knowledge.

    The returned function runs the question through an LLMChain on the
    module-level `llm` and returns the chain's 'text' output.
    """
    system_instruction = "You are a knowledgeable AI. Answer concisely and precisely using only your internal knowledge. No external context."
    prompt = ChatPromptTemplate.from_messages([
        ("system", system_instruction),
        ("human", "{input}"),
    ])
    chain = LLMChain(llm=llm, prompt=prompt)

    def knowledge_agent(question: str):
        """Agent that answers using only its internal knowledge"""
        answer = chain.invoke({"input": question})
        return answer['text']

    return knowledge_agent

## Worker nodes (Query Agent)
def planner_node(state: QState):
    """Classify the user's question as single-step or multi-step.

    NOTE(review): `planner_node` is defined a second time further down, just
    before the query graph is built; that later definition replaces this one
    and is the one registered on the graph. This copy is kept only for
    interface compatibility and is a candidate for removal.

    Returns a state update with the planner's verdict as a message named
    "planner" and `planner_executed` set to True.
    """
    # Most recent human message that carries no name, i.e. user input rather
    # than a node-generated message. LangChain messages always expose a `name`
    # attribute (None when unset), so the earlier `not hasattr(msg, 'name')`
    # test could never match; check the value instead.
    question = next(
        (msg.content for msg in reversed(state["messages"])
         if msg.type == "human" and not getattr(msg, "name", None)),
        None
    )

    if question is None:
        # Fall back to the first human message of any kind.
        question = next(
            (msg.content for msg in state["messages"] if msg.type == "human"),
            "No question found"
        )

    planner_agent_executor = planner_tool()
    result = planner_agent_executor(question)
    return {
        "messages": [
            HumanMessage(content=result, name="planner")
        ],
        "planner_executed": True
    }

def multi_step_node(state: QState):
    """Split the user's question into exactly two simpler sub-queries.

    Returns a state update containing the LLM's numbered sub-queries as a
    message named "subquery".
    """
    # First human message without a name, i.e. the user's own question rather
    # than a node-generated message. LangChain messages always expose a `name`
    # attribute (None when unset), so the earlier `not hasattr(msg, 'name')`
    # test could never match; check the value instead.
    original_question = next(
        (msg.content for msg in state["messages"]
         if msg.type == "human" and not getattr(msg, "name", None)),
        None
    )

    if not original_question:
        # Fall back to the first human message of any kind.
        original_question = next(
            (msg.content for msg in state["messages"] if msg.type == "human"),
            None
        )

    if not original_question:
        original_question = "No question found"

    subquery_prompt = f"""
    Original query: {original_question}

    \nBreak the medical query into exactly two sub-queries. 
    Each sub-query should: 
    Be SIMPLE and SPECIFIC; 
    Focus on a different, narrow aspect of the original query;
    Avoid long or reasoning-based formulations. 
    
    \nFormat your response as: 
    1. [sub-query 1] 
    2. [sub-query 2]
    """

    result = llm.invoke(subquery_prompt)

    return {
        "messages": [
            HumanMessage(content=result.content, name="subquery")
        ]
    }

def query_agent_supervisor_node(state: QState):
    """Pass-through node: returns the state's current messages as its update."""
    current_messages = state["messages"]
    return {"messages": current_messages}

## Worker nodes (Knowledge Agent) - Modified for reference retrieval only
def _extract_lmkg_reference(intermediate_steps):
    """Pick the reference to keep from the LMKG agent's tool calls.

    Each step is expected to be an (action, tool_result) tuple; only dict
    results are inspected. Priority:
      1. both tools used and cypher has results -> cypher reference
      2. both tools used but cypher is empty    -> vector reference
      3. only cypher used                       -> cypher reference (even if empty)
      4. only vector used                       -> vector reference
      5. nothing usable                         -> "No information retrieved"
    """
    vector_reference = None
    cypher_reference = None
    cypher_has_results = False

    for step in intermediate_steps:
        if not (isinstance(step, tuple) and len(step) >= 2):
            continue
        action, tool_result = step[0], step[1]
        if not isinstance(tool_result, dict):
            continue
        tool_name = (action.tool if hasattr(action, 'tool') else str(action)).lower()

        # Vector tool: keep the contexts when at least one is non-blank.
        if 'vector' in tool_name:
            retrieved_context = tool_result.get("retrieved_context", [])
            if retrieved_context and any(str(ctx).strip() for ctx in retrieved_context):
                vector_reference = retrieved_context

        # Cypher tool: remember the query and its result.
        if 'cypher' in tool_name:
            retrieved_result = tool_result.get("retrieved_result", "")
            as_items = retrieved_result if isinstance(retrieved_result, list) else [retrieved_result]
            step_has_results = bool(retrieved_result) and any(str(item).strip() for item in as_items)

            # Store the cypher information even when empty, but never let a
            # later empty call overwrite an earlier call that found something.
            if step_has_results or not cypher_has_results:
                cypher_reference = {
                    "generated_cypher": tool_result.get("generated_cypher", "") or "N/A",
                    "retrieved_result": retrieved_result
                }
                cypher_has_results = step_has_results

    if cypher_reference and vector_reference:
        return cypher_reference if cypher_has_results else vector_reference
    if cypher_reference:
        return cypher_reference
    if vector_reference:
        return vector_reference
    return "No information retrieved"


def lmkg_node(state: State):
    """Run the LMKG agent for the question and store its reference under "lmkg".

    Returns a state update with a status message, the updated
    `completed_agents` list and the updated `references` mapping.
    """
    question = state.get("original_question") or next(
        msg.content for msg in state["messages"]
        if msg.type == "human"
    )
    lmkg_agent_executor, _lmkg_vector = lmkg_retrieval_tool()
    response = lmkg_agent_executor.invoke({"input": question})

    # Only the tool calls matter here; the agent's final answer is not used.
    reference = _extract_lmkg_reference(response.get("intermediate_steps", []))

    # Copy before editing so the objects held in the graph state are not
    # mutated in place; the new values are handed back through the update.
    references = dict(state.get("references", {}))
    references["lmkg"] = reference

    completed_agents = list(state.get("completed_agents", []))
    if "lmkg" not in completed_agents:
        completed_agents.append("lmkg")

    return {
        "messages": [
            HumanMessage(content="LMKG reference retrieved", name="lmkg")
        ],
        "completed_agents": completed_agents,
        "references": references
    }

def _top_hit_reference_node(state, agent_name: str, load_retriever, done_message: str) -> dict:
    """Shared body of the single-vector-store worker nodes.

    Looks up the question, runs a k=1 similarity search on the store returned
    by `load_retriever()`, and records the hit's page content (or
    "No reference found") under `agent_name`. No LLM processing is involved.

    Args:
        state: Current knowledge-agent graph state.
        agent_name: Key used in `references`, `completed_agents` and as the
            status message's name.
        load_retriever: Zero-argument callable returning the vector store.
        done_message: Content of the status message.
    """
    question = state.get("original_question") or next(
        msg.content for msg in state["messages"]
        if msg.type == "human"
    )

    # The store is loaded only after the question is known, as before.
    hits = load_retriever().similarity_search(question, k=1)
    reference = hits[0].page_content if hits else "No reference found"

    # Copy before editing so the objects held in the graph state are not
    # mutated in place; the new values are handed back through the update.
    references = dict(state.get("references", {}))
    references[agent_name] = reference

    completed_agents = list(state.get("completed_agents", []))
    if agent_name not in completed_agents:
        completed_agents.append(agent_name)

    return {
        "messages": [
            HumanMessage(content=done_message, name=agent_name)
        ],
        "completed_agents": completed_agents,
        "references": references
    }

def hkg_node(state: State):
    """Store the top HKG vector hit for the question under "hkg"."""
    return _top_hit_reference_node(state, "hkg", hkg_retrieval_tool, "HKG reference retrieved")

def ds_node(state: State):
    """Store the top diseases/symptoms vector hit for the question under "ds"."""
    return _top_hit_reference_node(state, "ds", ds_retrieval_tool, "DS reference retrieved")

def primekg_node(state: State):
    """Store the top PrimeKG vector hit for the question under "primekg"."""
    return _top_hit_reference_node(state, "primekg", primekg_retrieval_tool, "PrimeKG reference retrieved")

def drugreviews_node(state: State):
    """Store the top drug-reviews vector hit for the question under "drugreviews"."""
    return _top_hit_reference_node(state, "drugreviews", drugreviews_retrieval_tool, "drugreviews reference retrieved")

def wiki_node(state: State):
    """Crawl Wikipedia for the question's main entity and store it under "wiki".

    The entity is extracted by the wiki retrieval tool, one article is
    crawled, and the first 1500 characters are kept as the reference.
    """
    from wiki_crawler import crawl_wikipedia_entity
    question = state.get("original_question") or next(
        msg.content for msg in state["messages"]
        if msg.type == "human"
    )

    # Wiki data retrieval
    query_entity = wiki_retrieval_tool()
    entity = query_entity(question)
    wiki_text = crawl_wikipedia_entity(entity, articles_limit=1)

    # Copy before editing so the objects held in the graph state are not
    # mutated in place; the new values are handed back through the update.
    references = dict(state.get("references", {}))
    references["wiki"] = wiki_text[:1500]

    completed_agents = list(state.get("completed_agents", []))
    if "wiki" not in completed_agents:
        completed_agents.append("wiki")

    return {
        "messages": [
            HumanMessage(content="Wiki reference retrieved", name="wiki")
        ],
        "completed_agents": completed_agents,
        "references": references
    }

def mayoclinic_node(state: State):
    """Crawl Mayo Clinic for the question's main entity and store it under "mayoclinic".

    The entity is extracted by the Mayo Clinic retrieval tool and the crawler's
    result is stored unmodified as the reference.
    """
    from mayoclinic_symptom_crawler import crawl_mayoclinic_entity
    question = state.get("original_question") or next(
        msg.content for msg in state["messages"]
        if msg.type == "human"
    )

    # MayoClinic data retrieval
    query_entity = mayoclinic_retrieval_tool()
    entity = query_entity(question)
    mayoclinic_text = crawl_mayoclinic_entity(entity)

    # Copy before editing so the objects held in the graph state are not
    # mutated in place; the new values are handed back through the update.
    references = dict(state.get("references", {}))
    references["mayoclinic"] = mayoclinic_text

    completed_agents = list(state.get("completed_agents", []))
    if "mayoclinic" not in completed_agents:
        completed_agents.append("mayoclinic")

    return {
        "messages": [
            HumanMessage(content="Mayo Clinic reference retrieved", name="mayoclinic")
        ],
        "completed_agents": completed_agents,
        "references": references
    }

def llmself_node(state: State):
    """Answer the question from the LLM's internal knowledge and store it under "llmself".

    Unlike the retrieval nodes, the stored reference is the generated answer
    itself.
    """
    question = state.get("original_question") or next(
        msg.content for msg in state["messages"]
        if msg.type == "human"
    )

    # Generate LLM answer using internal knowledge
    llmself_agent_executor = llmself_tool()
    answer = llmself_agent_executor(question)

    # Copy before editing so the objects held in the graph state are not
    # mutated in place; the new values are handed back through the update.
    references = dict(state.get("references", {}))
    references["llmself"] = answer  # Store the actual answer as reference

    completed_agents = list(state.get("completed_agents", []))
    if "llmself" not in completed_agents:
        completed_agents.append("llmself")

    return {
        "messages": [
            HumanMessage(content="LLM Self Agent completed", name="llmself")
        ],
        "completed_agents": completed_agents,
        "references": references
    }

# Create the knowledge-agent supervisor node (no verifier node is created here).
# The member list names every worker node the supervisor may route to.
# NOTE(review): the agent_scopes passed below repeat, entry for entry, the
# defaults built into make_supervisor_node; keep the two in sync or drop one.
knowledge_agent_supervisor_node = make_supervisor_node(
    llm, 
    ["lmkg", "hkg", "ds", "primekg", "drugreviews", "wiki", "mayoclinic", "llmself"],
    agent_scopes={
        "lmkg": "General medical queries using knowledge graph: diseases, exams, indicators, symptoms, complications etc.",
        "hkg": "For questions on symptom.",
        "ds": "For questions on symptoms or treatments.",
        "primekg": "To retrieve drugs indicated for diseases.",
        "drugreviews": "To retrieve drugs along with patient reviews.",
        "wiki": "For disease definitions and overviews.",
        "mayoclinic": "For clinical info: causes, treatments, symptoms, complications.",
        "llmself": "Fallback only if others are irrelevant."
    }
)

# Create graph (Query Agent) - Modified to bypass supervisor
# Modified planner_node to handle single-step directly
def planner_node(state: QState):
    """Run the planner on the user's question and record its decision.

    Args:
        state: Query-agent graph state; only ``messages`` is read.

    Returns:
        A state update with ``planner_executed`` set to True and the planner's
        output as a message named ``"planner"``. When that output contains
        "single-step" (case-insensitive), the original question is appended as
        a second message named ``"single_step"`` so it stays the last message.
    """
    messages = state["messages"]

    # Prefer the most recent human message that no node has tagged with a name.
    # LangChain messages always expose a `name` attribute (None when unset), so
    # the check must look at its value: hasattr() would be True for every message.
    question = next(
        (msg.content for msg in reversed(messages)
         if msg.type == "human" and not getattr(msg, "name", None)),
        None
    )

    if question is None:
        # Fall back to the first human message of any kind.
        question = next(
            (msg.content for msg in messages if msg.type == "human"),
            "No question found"
        )

    planner_agent_executor = planner_tool()
    result = planner_agent_executor(question)

    planner_messages = [HumanMessage(content=result, name="planner")]
    if "single-step" in result.lower():
        # Preserve original question
        planner_messages.append(HumanMessage(content=question, name="single_step"))

    return {
        "messages": planner_messages,
        "planner_executed": True
    }

query_builder = StateGraph(QState)
for node_name, node_fn in (("planner", planner_node), ("multi_step", multi_step_node)):
    query_builder.add_node(node_name, node_fn)

# The planner is the entry point of the query graph
query_builder.add_edge(START, "planner")

# Planner verdict -> next hop: single-step questions finish here,
# multi-step questions go through the multi_step node first.
planner_routes = {
    "single_step": END,
    "multi_step": "multi_step"
}
query_builder.add_conditional_edges(
    "planner",
    # The lambda defers the name lookup: determine_planner_decision is
    # defined after this statement runs.
    lambda state: determine_planner_decision(state),
    planner_routes
)

query_builder.add_edge("multi_step", END)

def determine_planner_decision(state: QState):
    """Map the latest planner message to a routing key.

    Scans ``state["messages"]`` from newest to oldest for a message named
    "planner". Returns "multi_step" only when that message mentions
    "multi-step" and not "single-step" (case-insensitive); every other
    case, including a missing planner message, returns "single_step".
    """
    latest_planner = next(
        (
            candidate
            for candidate in reversed(state["messages"])
            if hasattr(candidate, 'name') and candidate.name == "planner"
        ),
        None,
    )

    if latest_planner:
        verdict = latest_planner.content.strip().lower()
        # "single-step" wins whenever both phrases appear
        if "single-step" not in verdict and "multi-step" in verdict:
            return "multi_step"

    # Default fallback
    return "single_step"

# Compile the query-agent wiring built above into the graph object that rag() streams from.
query_agent_graph = query_builder.compile()

## Create graph (Knowledge Agent)
knowledge_builder = StateGraph(State)
knowledge_builder.add_node("knowledge_supervisor", knowledge_agent_supervisor_node)

# Retrieval workers, listed in registration order
knowledge_workers = {
    "lmkg": lmkg_node,
    "hkg": hkg_node,
    "ds": ds_node,
    "primekg": primekg_node,
    "drugreviews": drugreviews_node,
    "wiki": wiki_node,
    "mayoclinic": mayoclinic_node,
    "llmself": llmself_node,
}
for worker_name, worker_node in knowledge_workers.items():
    knowledge_builder.add_node(worker_name, worker_node)

# Start with supervisor
knowledge_builder.add_edge(START, "knowledge_supervisor")

# Supervisor -> worker routing keyed on the state's "next" field; every worker
# maps to the node of the same name, and both "FINISH" and END stop the graph.
supervisor_routes = {
    target: target
    for target in ("lmkg", "hkg", "primekg", "wiki", "mayoclinic", "llmself", "ds", "drugreviews")
}
supervisor_routes["FINISH"] = END
supervisor_routes[END] = END

knowledge_builder.add_conditional_edges(
    "knowledge_supervisor",
    lambda state: state["next"],
    supervisor_routes
)

# Every worker hands control straight back to the supervisor
for worker_name in knowledge_workers:
    knowledge_builder.add_edge(worker_name, "knowledge_supervisor")

knowledge_agent_graph = knowledge_builder.compile()

In [None]:
## RAG Chain
## Helper function (linker)
def process_final_message(final_message):
    """Split the query agent's final message into a list of sub-queries.

    The message is stripped first. Without a newline it is returned as a
    one-element list (even when empty); otherwise each line is stripped and
    blank lines are dropped.
    """
    text = final_message.strip()

    if '\n' not in text:
        return [text]

    sub_queries = []
    for raw_line in text.split('\n'):
        cleaned = raw_line.strip()
        if cleaned:
            sub_queries.append(cleaned)
    return sub_queries

## RAG Function
def rag(test_query):
    """Run the query agent, then the knowledge agent, and collect references.

    The query agent is streamed on ``test_query``; the content of the last
    message it emits is split into sub-queries with process_final_message().
    The knowledge agent is then streamed once per sub-query, and every update
    that carries a 'references' key is merged in arrival order.

    Returns:
        dict mapping each sub-query to the 'references' value of its merged
        updates; sub-queries that produced no such update are absent.
    """
    print("="*50)
    print(f"[Test query]: {test_query}")

    # Query Agent
    message_contents = []
    query_stream = query_agent_graph.stream(
        {"messages": [("user", test_query)]},
        {"recursion_limit": 100},
    )
    for chunk in query_stream:
        if 'messages' not in str(chunk):
            continue
        for update in chunk.values():
            if not (isinstance(update, dict) and 'messages' in update):
                continue
            for emitted in update['messages']:
                message_contents.append(emitted.content)

    sub_queries = process_final_message(message_contents[-1])

    # Knowledge Agent
    retrieved_references = {}

    for sub_query in sub_queries:
        merged_update = None

        knowledge_stream = knowledge_agent_graph.stream(
            {"messages": [("user", sub_query)]},
            {"recursion_limit": 100})
        for chunk in knowledge_stream:
            for update in chunk.values():
                if isinstance(update, dict) and 'references' in update:
                    if merged_update is None:
                        merged_update = {}
                    # Later updates overwrite earlier keys
                    merged_update.update(update)

        # Record results only when some node reported references
        if merged_update and 'references' in merged_update:
            retrieved_references[sub_query] = merged_update['references']

    print(f"\n[Retrieved reference]: {retrieved_references}")
    return retrieved_references

In [None]:
## format_query
import pandas as pd

# Load the DDXPlus context CSV and take the row labelled 1 as a sample input for format_query below.
df = pd.read_csv("../new_data/ddxplus_context.csv")
ques = df.loc[1]
def format_query(ques):
    """Build the symptom question for one dataset row.

    Args:
        ques: Mapping-like row (e.g. a pandas Series) with a 'disease' entry.

    Returns:
        The string "What are the symptoms of '<disease>'?", or None when the
        disease value is missing (None/NaN). Returning None instead of a
        question about 'nan' lets callers skip the row with ``pd.notna(query)``.
    """
    disease = ques['disease']
    if pd.isna(disease):
        return None
    return f"What are the symptoms of '{disease}'?"

# Quick check: build the question for the sample row and show it.
query = format_query(ques)
print(query)

In [None]:
## Retrieve contexts for 10 disease types
import json

df = pd.read_csv("../new_data/ddxplus_context.csv")
for i in range(len(df)):
    try:
        query = format_query(df.loc[i])
        if not pd.notna(query):
            continue

        result = rag(query)
        # convert dict to string so it fits in a single CSV cell
        df.loc[i, "reference"] = json.dumps(result, ensure_ascii=False)
        # Rewrite the file after every row so finished rows survive a later failure
        df.to_csv("../new_data/ddxplus_context.csv", index=False, encoding="utf-8-sig")
    except Exception as e:
        # Best effort: report the failing row and move on to the next one
        print(f"Error at row {i}: {e}")

In [None]:
## Create knowledge base for 10 disease types (sources separated, filtered, no llmself)
## FAISS vector store
import pandas as pd
import json
from langchain.docstore.document import Document

# Step 1: Load CSV
df = pd.read_csv("../new_data/ddxplus_context.csv")

documents = []

# Step 2: Process each row
for _, row in df.iterrows():
    diagnosis = row['disease']
    reference_json_str = row['reference_no_llm']

    # An empty CSV cell is read back as float NaN, which json.loads rejects with
    # TypeError rather than JSONDecodeError, so both are treated as "unparseable".
    try:
        reference_dict = json.loads(reference_json_str)
    except (TypeError, json.JSONDecodeError):
        print(f"Failed to parse JSON for diagnosis: {diagnosis}")
        continue

    # Valid JSON that is not an object (e.g. null or a list) has no question -> sources mapping.
    if not isinstance(reference_dict, dict):
        print(f"Failed to parse JSON for diagnosis: {diagnosis}")
        continue

    # Step 3: Extract nested sources
    for question, sources in reference_dict.items():
        if not isinstance(sources, dict):
            print(f"Unknown content type for {question} in {diagnosis}")
            continue

        for source_name, content in sources.items():
            # Normalise every supported shape to a list of page texts:
            # a string is one page, a list is one page per item, and a dict is
            # flattened into a single "key: value" page.
            if isinstance(content, str):
                page_texts = [content]
            elif isinstance(content, list):
                page_texts = content
            elif isinstance(content, dict):
                page_texts = ["\n".join([f"{k}: {v}" for k, v in content.items()])]
            else:
                print(f"Unknown content type for {source_name} in {diagnosis}")
                continue

            for page_text in page_texts:
                documents.append(Document(
                    page_content=page_text,
                    # A fresh metadata dict per document, so documents never share one
                    metadata={"diagnosis": diagnosis, "source": source_name}
                ))

## Create knowledge base for 22 disease types (NOTE: this cell's header says 10 disease types; one of the two counts is stale, confirm which)
# initiate faiss vector store and openai embedding
import faiss
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings

# One embeddings client, shared by the dimension probe and the vector store
# (previously `embeddings` was unused and two more clients were constructed).
embeddings = OpenAIEmbeddings()

# Embed a throwaway string to learn the vector length the flat L2 index must have.
index = faiss.IndexFlatL2(len(embeddings.embed_query(" ")))
vector_store = FAISS(
    embedding_function=embeddings,
    index=index,
    docstore=InMemoryDocstore(),
    index_to_docstore_id={}
)

# Embed and index the documents collected above, then persist the store to disk.
vector_store.add_documents(documents)
save_path = "./ddxplus_faiss"
vector_store.save_local(save_path)

In [None]:
## Load FAISS (sources separated, filtered, no llmself)
import faiss
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS
# old ver: ddxplus_test_2_faiss
faiss_index_dir = "ddxplus_faiss"

# NOTE(review): allow_dangerous_deserialization opts in to loading serialized
# objects from disk; only point this at an index directory you created yourself.
vector_store = FAISS.load_local(
    faiss_index_dir,
    OpenAIEmbeddings(),
    allow_dangerous_deserialization=True,
)

# Count entries through the docstore's backing dict
num_documents = len(vector_store.docstore._dict)
print(f"FAISS index loaded with {num_documents} documents")

In [None]:
# Function for Ref + LLMSelf (gpt)
import os
from openai import OpenAI

# Module-level OpenAI client used by reference_llmself below. It is created with
# no explicit arguments; the env-setup cell puts OPENAI_API_KEY into os.environ.
client = OpenAI()
def reference_llmself(
    query: str,
    vector_store,
    llm_client=None,
    model: str = "gpt-4o-mini",
    options: str = "Anemia, Boerhaave, Cluster headache, GERD, Influenza, Myocarditis, Panic attack, Pericarditis, Pneumonia, Sarcoidosis",
):
    """Pair the top knowledge-base hit with the LLM's own diagnosis (GPT variant).

    Args:
        query: Patient symptom description; also the key of the returned dict.
        vector_store: Object exposing ``similarity_search(query, k=...)`` whose
            results have ``metadata`` and ``page_content``.
        llm_client: Chat client exposing ``chat.completions.create``. Defaults
            to the module-level ``client``.
        model: Chat model name sent with the request.
        options: Comma-separated candidate diagnoses shown to the model.

    Returns:
        ``{query: {<source>: <reference text>, "llmself": <model answer>}}``,
        or ``{query: {"llmself": "No relevant information found."}}`` when the
        similarity search returns nothing (the model is not called then).
    """
    # 1 top similar ref from knowledge base + 1 llmself ans
    # Step 1: Run vector similarity search using only the symptom query
    results = vector_store.similarity_search(query, k=1)
    if not results:
        return {query: {"llmself": "No relevant information found."}}

    top_doc = results[0]
    source = top_doc.metadata.get("source", "unknown")
    content = top_doc.page_content.strip()

    # Step 2: Run LLM using query + options
    llm_prompt = (
        "You are a Medical Diagnosis AI. Your task is to determine the most likely diagnosis based on the patient's symptom description.\n\n"
        f"options: {options}\n\n"
        f"Patient Description: {query}\n\n"
        "Select the most appropriate diagnosis from the options provided, and briefly explain your reasoning based on the described symptoms.\n"
        "Be concise, medically accurate, and focus only on the symptoms mentioned."
    )

    print("=" * 50)
    print(f'[Test query]: {query}')

    # Resolve the client only now, so the module-level default is looked up at
    # call time and only when no client was passed in.
    chat_client = client if llm_client is None else llm_client
    response = chat_client.chat.completions.create(
        model=model,
        messages=[{
            "role": "user",
            "content": llm_prompt
        }]
    )
    # The API may return a message without text content; treat that as an empty answer.
    llm_result = (response.choices[0].message.content or "").strip()

    # Step 3: Combine and return
    reference_dict = {
        source: content,
        "llmself": llm_result
    }

    return {query: reference_dict}

In [None]:
# run ref + llmself (gpt)
import json

df = pd.read_csv("../new_data/ddxplus_400_result.csv")
for i in range(len(df)):
    try:
        query = df.loc[i, "EVIDENCES"]
        if not pd.notna(query):
            continue

        result = reference_llmself(query, vector_store)
        # Store the result dict as a JSON string in the row's gpt_reference cell
        df.loc[i, "gpt_reference"] = json.dumps(result, ensure_ascii=False)
        # Rewrite the file after every row so finished rows survive a later failure
        df.to_csv("../new_data/ddxplus_400_result.csv", index=False, encoding="utf-8-sig")
    except Exception as e:
        # Best effort: report the failing row and move on to the next one
        print(f"Error at row {i}: {e}")

In [None]:
# Function for Ref + LLMSelf (ds)
import os
from openai import OpenAI

# Initialize DeepSeek client
# Read the DeepSeek settings once and fail with a readable message when one is
# missing. Writing os.getenv(...) back into os.environ changed nothing when the
# variable was set and raised an opaque TypeError when it was not.
deepseek_api_key = os.getenv('DEEPSEEK_API_KEY')
deepseek_base_url = os.getenv('DEEPSEEK_BASE_URL')
missing_settings = [
    name
    for name, value in (("DEEPSEEK_API_KEY", deepseek_api_key), ("DEEPSEEK_BASE_URL", deepseek_base_url))
    if value is None
]
if missing_settings:
    raise EnvironmentError(f"Missing environment variable(s): {', '.join(missing_settings)}")

# NOTE: this rebinds the module-level `client` that the GPT cell also assigns,
# so whichever of the two cells ran last decides which service `client` talks to.
client = OpenAI(api_key=deepseek_api_key, base_url=deepseek_base_url)

def reference_llmself(
    query: str,
    vector_store,
    llm_client=None,
    model: str = "deepseek-chat",
    options: str = "Anemia, Boerhaave, Cluster headache, GERD, Influenza, Myocarditis, Panic attack, Pericarditis, Pneumonia, Sarcoidosis",
):
    """Pair the top knowledge-base hit with the LLM's own diagnosis (DeepSeek variant).

    Args:
        query: Patient symptom description; also the key of the returned dict.
        vector_store: Object exposing ``similarity_search(query, k=...)`` whose
            results have ``metadata`` and ``page_content``.
        llm_client: Chat client exposing ``chat.completions.create``. Defaults
            to the module-level ``client``.
        model: Chat model name sent with the request.
        options: Comma-separated candidate diagnoses shown to the model.

    Returns:
        ``{query: {<source>: <reference text>, "llmself": <model answer>}}``,
        or ``{query: {"llmself": "No relevant information found."}}`` when the
        similarity search returns nothing (the model is not called then).
    """
    # 1 top similar ref from knowledge base + 1 llmself ans
    # Step 1: Run vector similarity search using only the symptom query
    results = vector_store.similarity_search(query, k=1)
    if not results:
        return {query: {"llmself": "No relevant information found."}}

    top_doc = results[0]
    source = top_doc.metadata.get("source", "unknown")
    content = top_doc.page_content.strip()

    # Step 2: Run LLM using query + options
    llm_prompt = (
        "You are a Medical Diagnosis AI, answer with your internal knowledge.\n\n"
        "Select the most suitable diagnosis based on the symptoms, and briefly explain your reasoning in max 3 sentences.\n\n"
        f"Question: {query}\n\n"
        f"Options: {options}\n\n"
    )

    print("=" * 50)
    print(f'[Test query]: {query}')

    # Resolve the client only now, so the module-level default is looked up at
    # call time and only when no client was passed in.
    chat_client = client if llm_client is None else llm_client
    response = chat_client.chat.completions.create(
        model=model,
        messages=[{
            "role": "user",
            "content": llm_prompt
        }]
    )
    # The API may return a message without text content; treat that as an empty answer.
    llm_result = (response.choices[0].message.content or "").strip()

    # Step 3: Combine and return
    reference_dict = {
        source: content,
        "llmself": llm_result
    }

    return {query: reference_dict}

In [None]:
# run ref + llmself (ds)
import json

df = pd.read_csv("../new_data/ddxplus_400_result.csv")
for i in range(len(df)):
    try:
        query = df.loc[i, "EVIDENCES"]
        if not pd.notna(query):
            continue

        result = reference_llmself(query, vector_store)
        # Store the result dict as a JSON string in the row's ds_reference cell
        df.loc[i, "ds_reference"] = json.dumps(result, ensure_ascii=False)
        # Rewrite the file after every row so finished rows survive a later failure
        df.to_csv("../new_data/ddxplus_400_result.csv", index=False, encoding="utf-8-sig")
    except Exception as e:
        # Best effort: report the failing row and move on to the next one
        print(f"Error at row {i}: {e}")