<a href="https://colab.research.google.com/github/IyadSultan/IyadSultan/blob/main/Multiagent_ICD10_extractor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Step 1: Environment Setup

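First, we need to install the necessary libraries: `langchain`, `langgraph`, `langchain-openai`, and `langchain-community`, plus `faiss-cpu` for the FAISS vector store and `rank_bm25`, which `BM25Retriever` uses for keyword search.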

In [18]:
%pip install -qU langchain langgraph langchain-openai langchain-community faiss-cpu rank_bm25

Next, we'll set up the API keys and environment variables. You'll need to add your OpenAI API key to the Colab secrets manager under the name `OPENAI_API_KEY`.

In [19]:
import os
from google.colab import userdata

os.environ["OPENAI_API_KEY"] = userdata.get("OPENAI_API_KEY")

Now, let's import the necessary modules for state management, agents, and visualization.

In [20]:
from typing import Annotated, Sequence, TypedDict

from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.tools import tool

from langgraph.graph import END, StateGraph

from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.retrievers import BM25Retriever # Import BM25Retriever

Finally, we'll define the default model (`gpt-4o-mini`) and create a placeholder vector store and retriever. Both are replaced by the ICD-10 index built in Step 4.

In [21]:
llm = ChatOpenAI(model="gpt-4o-mini")
embeddings = OpenAIEmbeddings()
vectorstore = FAISS.from_texts(["", ""], embeddings)  # Placeholder; replaced by the ICD-10 index in Step 4
retriever = vectorstore.as_retriever()

### Step 2: State Graph Schema Definition

Now, we'll define the state of our graph. This `TypedDict` will be the schema for the state that is passed between the nodes in the graph.

In [22]:
from typing import List, Annotated, TypedDict, Literal
from langchain_core.messages import BaseMessage
from langgraph.graph import StateGraph

# Define the reducer function for messages
def add_messages(left: List[BaseMessage], right: List[BaseMessage]) -> List[BaseMessage]:
    """Reducer for concatenating lists of messages."""
    return left + right

class AgentState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        messages: Messages exchanged in the conversation (concatenated by add_messages).
        user_input: The initial input from the user, e.g. "What is the weather like today?"
        agent_response: The raw response from the most recent agent, e.g. "The weather is sunny."
        tool_calls: Tool calls made by agents, e.g. [{'tool_name': 'weather_tool', 'tool_args': {'location': 'London'}}]
        final_output: The compiled final response, e.g. "The final answer is: The weather in London is sunny."
        debug_info: Debugging information, e.g. {'level': 'info', 'message': 'Processing user input.'}
        patient_note: The patient's medical note.
        gpt4_1_icd10: Codes extracted by the GPT-4.1 agent, e.g. [{'code': 'J45.909', 'description': 'Unspecified asthma, uncomplicated'}]
        gpt4o_mini_icd10: Codes extracted by the GPT-4o-mini agent, e.g. [{'code': 'I10', 'description': 'Essential (primary) hypertension'}]
        gpt4_1_mini_icd10: Codes extracted by the GPT-4.1-mini agent, e.g. [{'code': 'E11.9', 'description': 'Type 2 diabetes mellitus without complications'}]
        rag_icd10: Codes extracted by the RAG agent, e.g. [{'code': 'N18.6', 'description': 'End stage renal disease'}]
        rag_confidence: One confidence score per RAG code, e.g. [0.95, 0.88]
        bm25_icd10: Codes extracted by the BM25 agent, e.g. [{'code': 'J44.9', 'description': 'Chronic obstructive pulmonary disease, unspecified'}]
        retries_per_node: Number of times each node has run, keyed by agent role.
        tokens_per_node: Token count (or estimate) per node, keyed by agent role.
    """
    messages: Annotated[List[BaseMessage], add_messages]
    user_input: str
    agent_response: str
    tool_calls: List[dict]
    final_output: str
    debug_info: dict
    patient_note: str
    gpt4_1_icd10: List[dict]
    gpt4o_mini_icd10: List[dict]
    gpt4_1_mini_icd10: List[dict]
    rag_icd10: List[dict]
    rag_confidence: List[float]
    bm25_icd10: List[dict]
    retries_per_node: dict
    tokens_per_node: dict
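
LangGraph merges each node's returned dictionary into the state: keys annotated with a reducer (here `messages`, via `add_messages`) are combined with the existing value, while all other keys are simply overwritten. The sketch below imitates that merge in plain Python so the behaviour can be checked without running a graph; `apply_update` is a hypothetical helper, not part of LangGraph's API.

```python
from typing import Any, Callable, Dict, List


def add_messages(left: List[Any], right: List[Any]) -> List[Any]:
    """Same concatenating reducer as the one used in AgentState."""
    return left + right


def apply_update(
    state: Dict[str, Any],
    update: Dict[str, Any],
    reducers: Dict[str, Callable[[Any, Any], Any]],
) -> Dict[str, Any]:
    """Hypothetical helper: merge a node's update into a copy of the state.

    Keys with a reducer are combined with the current value (if any);
    all other keys are replaced, as LangGraph does for un-annotated keys.
    """
    new_state = dict(state)
    for key, value in update.items():
        if key in reducers and key in new_state:
            new_state[key] = reducers[key](new_state[key], value)
        else:
            new_state[key] = value
    return new_state


state = {"messages": ["note received"], "bm25_icd10": []}
update = {
    "messages": ["BM25 done"],
    "bm25_icd10": [{"code": "J44.9", "description": "Chronic obstructive pulmonary disease, unspecified"}],
}
new_state = apply_update(state, update, reducers={"messages": add_messages})
print(new_state["messages"])
print(new_state["bm25_icd10"])
```

Because un-annotated keys are overwritten, each agent node in this notebook copies `retries_per_node` and `tokens_per_node` from the incoming state before adding its own entry.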



### Step 3: Define the Start Agent and Synthetic Data

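Now, let's load the patient note and define the starting agent. The note is read from `/content/patient_note.txt`, which must be uploaded to the Colab session first; an inline synthetic example is kept in comments for reference. The start node is the entry point of the graph: it places the note in the state and initializes the retry and token counters.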

In [73]:
# Example synthetic patient note (approximately 200 words), kept for reference:
# synthetic_patient_note = """
# Patient is a 68-year-old male presenting with a chief complaint of increasing shortness of breath over the past two weeks. Symptoms are worse with exertion and improve slightly with rest. He reports a history of chronic obstructive pulmonary disease (COPD), diagnosed 10 years ago, managed with inhaled bronchodilators as needed. He denies fever, chills, or cough with sputum production. He has a history of hypertension, controlled with lisinopril. No known allergies. Social history includes a 40 pack-year smoking history, quit 5 years ago. He lives with his wife and is retired. Physical examination reveals a thin male in mild respiratory distress. Vital signs: BP 140/85, HR 98, RR 22, Temp 98.6 F, SpO2 90% on room air. Auscultation of the lungs reveals diminished breath sounds bilaterally with scattered expiratory wheezes. Cardiac exam is regular rate and rhythm with no murmurs, rubs, or gallops. Extremities show no edema or clubbing. Assessment: Acute exacerbation of COPD. Plan: Administer nebulized albuterol and ipratropium. Start oral prednisone 40mg daily for 5 days. Obtain a chest X-ray and arterial blood gas. Continue home inhalers. Follow up in 1 week or sooner if symptoms worsen. Educate patient on symptom management and when to seek urgent care.
# """
# Load the patient note from a file uploaded to the Colab session
with open('/content/patient_note.txt', 'r') as f:
    synthetic_patient_note = f.read()

# Define the start node, which returns a dictionary of state updates
def start_node(state: AgentState) -> dict:
    """
    The starting node of the graph. Receives the patient note.
    Returns a dictionary of state updates.
    """
    print("---START NODE---")
    # Use the note from the initial state if one was provided; otherwise fall back
    # to the note loaded from file. Initialize the retry and token counters if absent.
    initial_state_updates = {
        "patient_note": state.get("patient_note", synthetic_patient_note),
        "retries_per_node": state.get("retries_per_node", {}),
        "tokens_per_node": state.get("tokens_per_node", {})
    }
    return initial_state_updates

### Step 4: Build Agents

Now, we'll create parameter dictionaries for each agent. These dictionaries will hold all the necessary configurations for each agent, such as the model to use, system prompts, tool assignments, and other settings. We'll start with the `gpt4_1_icd10` agent.

In [74]:
# Agent parameters for the gpt4_1_icd10 agent
gpt4_1_icd10_agent_params = {
    "model": "gpt-4.1",  # Recorded for reference: create_agent_node runs every agent on the shared llm (gpt-4o-mini), so GPT-4.1 is simulated
    "system_prompt": """You are an expert medical coder. Your task is to extract all relevant ICD-10 codes and their descriptions from the provided patient note.

    Present the extracted information as a list of dictionaries, where each dictionary has 'code' and 'description' keys.
    """,
    "tools": [],  # Add relevant tools if this agent needs to use any
    "temperature": 0.1,
    "role": "ICD-10 Extractor (GPT-4)",
    "debug_logging": "info", # Set logging level (e.g., 'info', 'debug')
    # Include vector database configuration if needed for this agent (e.g., for RAG)
    # "vectorstore": vectorstore,
    # "retriever": retriever,
}

Now, let's define the parameters for the `gpt4_1_mini_icd10` agent.

In [75]:
# Agent parameters for the gpt4.1mini_icd10 agent
gpt4_1_mini_icd10_agent_params = {
    "model": "gpt-4o-mini",  # GPT-4.1-mini is simulated with the default model
    "system_prompt": """You are an expert medical coder. Your task is to extract all relevant ICD-10 codes and their descriptions from the provided patient note.

    Present the extracted information as a list of dictionaries, where each dictionary has 'code' and 'description' keys.
    """,
    "tools": [],  # Add relevant tools if this agent needs to use any
    "temperature": 0.1,
    "role": "ICD-10 Extractor (GPT-4o-mini)",
    "debug_logging": "info", # Set logging level (e.g., 'info', 'debug')
    # Include vector database configuration if needed for this agent (e.g., for RAG)
    # "vectorstore": vectorstore,
    # "retriever": retriever,
}

Next, let's define the parameters for the `gpt4omini_icd10` agent.

In [76]:
# Agent parameters for the gpt4omini_icd10 agent
gpt4omini_icd10_agent_params = {
    "model": "gpt-4o-mini",  # Using the default model
    "system_prompt": """You are an expert medical coder. Your task is to extract all relevant ICD-10 codes and their descriptions from the provided patient note.

    Present the extracted information as a list of dictionaries, where each dictionary has 'code' and 'description' keys.
    """,
    "tools": [],  # Add relevant tools if this agent needs to use any
    "temperature": 0.1,
    "role": "ICD-10 Extractor (GPT-4o-mini - separate instance)",
    "debug_logging": "info", # Set logging level (e.g., 'info', 'debug')
    # Include vector database configuration if needed for this agent (e.g., for RAG)
    # "vectorstore": vectorstore,
    # "retriever": retriever,
}
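
The three dictionaries above differ only in `model` and `role`. A small factory plus a key check could remove the duplicated prompt text and catch missing settings early. `make_extractor_params`, `validate_agent_params`, `EXTRACTOR_PROMPT`, and `REQUIRED_KEYS` are hypothetical names; the 0 to 2 temperature range is the one the OpenAI API accepts.

```python
from typing import Any, Dict

# Shared system prompt for the plain LLM extractors (mirrors the prompt used above)
EXTRACTOR_PROMPT = """You are an expert medical coder. Your task is to extract all relevant ICD-10 codes and their descriptions from the provided patient note.

Present the extracted information as a list of dictionaries, where each dictionary has 'code' and 'description' keys.
"""

REQUIRED_KEYS = {"model", "system_prompt", "tools", "temperature", "role", "debug_logging"}


def make_extractor_params(model: str, role: str, temperature: float = 0.1) -> Dict[str, Any]:
    """Hypothetical factory returning an agent parameter dictionary."""
    return {
        "model": model,
        "system_prompt": EXTRACTOR_PROMPT,
        "tools": [],
        "temperature": temperature,
        "role": role,
        "debug_logging": "info",
    }


def validate_agent_params(params: Dict[str, Any]) -> None:
    """Raise ValueError if a parameter dictionary is incomplete or out of range."""
    missing = REQUIRED_KEYS - set(params)
    if missing:
        raise ValueError(f"Missing agent parameters: {sorted(missing)}")
    if not 0.0 <= params["temperature"] <= 2.0:
        raise ValueError("temperature must be between 0 and 2")


params = make_extractor_params("gpt-4.1", "ICD-10 Extractor (GPT-4)")
validate_agent_params(params)
print(sorted(params))
```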

In [86]:
import pandas as pd
import os
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from langchain_core.documents import Document

# Initialize OpenAI embeddings first (needed both to create and to load the index)
embeddings = OpenAIEmbeddings()

# Check whether a saved FAISS database already exists
faiss_db_path = "icd10_faiss_db"
vectorstore = None
retriever = None

if os.path.exists(faiss_db_path):
    print(f"Found existing FAISS database at '{faiss_db_path}'")
    try:
        # Try to load existing FAISS database
        vectorstore = FAISS.load_local(faiss_db_path, embeddings, allow_dangerous_deserialization=True)
        print("✅ Successfully loaded existing FAISS database!")

        # Define the retriever
        retriever = vectorstore.as_retriever(
            search_type="similarity",
            search_kwargs={"k": 10}
        )

        # Quick test to verify it works
        test_results = retriever.invoke("diabetes")
        print(f"✅ Database test successful - found {len(test_results)} results")

    except Exception as load_error:
        print(f"❌ Failed to load existing database: {load_error}")
        print("Will create a new database...")
        vectorstore = None

# CREATE NEW DATABASE if loading failed or doesn't exist
if vectorstore is None:
    print("Creating new FAISS database...")

    # Load the ICD-10 data from the CSV file
    try:
        icd10_df = pd.read_csv('/content/icd10_2019.csv')
        # Display the head of the DataFrame to understand the structure
        print("Head of icd10_2019.csv:")
        display(icd10_df.head())
        print(f"\nTotal rows: {len(icd10_df)}")

        # Create a separate Document for each row
        documents = []
        for index, row in icd10_df.iterrows():
            # Create meaningful content for each ICD-10 code
            code = row.get('sub-code', '')
            definition = row.get('definition', '')

            # Create rich content for better semantic search
            page_content = f"ICD-10 Code: {code}\nDescription: {definition}"

            # Add metadata for filtering and retrieval
            metadata = {
                'code': code,
                'definition': definition,
                'row_index': index
            }

            # Create Document object for this specific ICD-10 code
            doc = Document(
                page_content=page_content,
                metadata=metadata
            )
            documents.append(doc)

        print(f"Created {len(documents)} individual documents")

        # Create a FAISS vector store from all documents
        # Each document will get its own embedding
        vectorstore = FAISS.from_documents(documents, embeddings)
        print("\nFAISS vector database created successfully!")
        print("Each ICD-10 code is now a separate searchable chunk.")

        # Save the vector store for future use
        try:
            vectorstore.save_local(faiss_db_path)
            print(f"\n✅ Vector store saved to '{faiss_db_path}' directory")
        except Exception as save_error:
            print(f"❌ Could not save vector store: {save_error}")

        # Define the retriever with customizable parameters
        retriever = vectorstore.as_retriever(
            search_type="similarity",
            search_kwargs={"k": 10}
        )

    except FileNotFoundError:
        print("Error: icd10_2019.csv not found. Please upload the file.")
    except Exception as e:
        print(f"An error occurred: {e}")

# CONTINUE WITH TESTING (only if vectorstore and retriever are available)
if vectorstore is not None and retriever is not None:
    # Test the retriever with a sample query
    print("\n" + "="*50)
    print("Testing retriever with sample query...")
    test_query = "diabetes mellitus"
    results = retriever.invoke(test_query)

    print(f"\nTop {len(results)} results for '{test_query}':")
    for i, doc in enumerate(results, 1):
        print(f"\n{i}. {doc.metadata['code']}: {doc.metadata['definition']}")

    # Optional: Function to search ICD-10 codes
    def search_icd10_codes(query, k=5):
        """Search for ICD-10 codes based on a text query"""
        results = retriever.invoke(query)

        print(f"\nSearch results for: '{query}'")
        print("-" * 40)
        for i, doc in enumerate(results[:k], 1):
            code = doc.metadata['code']
            definition = doc.metadata['definition']
            print(f"{i}. Code: {code}")
            print(f"   Description: {definition}")
            print()

    # Example usage
    print("\n" + "="*50)
    print("Example searches:")
    search_icd10_codes("heart attack", 3)
    search_icd10_codes("broken bone", 3)
else:
    print("❌ FAISS database not available. Cannot perform searches.")
    retriever = None

Creating new FAISS database...
Head of icd10_2019.csv:


Unnamed: 0.1,Unnamed: 0,url,chapter,domain,sub-code,definition
0,1,https://icd.who.int/browse10/2019/en#/A00-A09,Chapter I\r\nCertain infectious and parasitic ...,Intestinal infectious diseases\r\n(A00-A09),A00,Cholera
1,2,https://icd.who.int/browse10/2019/en#/A00-A09,Chapter I\r\nCertain infectious and parasitic ...,Intestinal infectious diseases\r\n(A00-A09),A00.0,"Cholera due to Vibrio cholerae 01, biovar chol..."
2,3,https://icd.who.int/browse10/2019/en#/A00-A09,Chapter I\r\nCertain infectious and parasitic ...,Intestinal infectious diseases\r\n(A00-A09),A00.1,"Cholera due to Vibrio cholerae 01, biovar eltor"
3,4,https://icd.who.int/browse10/2019/en#/A00-A09,Chapter I\r\nCertain infectious and parasitic ...,Intestinal infectious diseases\r\n(A00-A09),A00.9,"Cholera, unspecified"
4,5,https://icd.who.int/browse10/2019/en#/A00-A09,Chapter I\r\nCertain infectious and parasitic ...,Intestinal infectious diseases\r\n(A00-A09),A01,Typhoid and paratyphoid fevers



Total rows: 11243
Created 11243 individual documents

FAISS vector database created successfully!
Each ICD-10 code is now a separate searchable chunk.

✅ Vector store saved to 'icd10_faiss_db' directory

Testing retriever with sample query...

Top 10 results for 'diabetes mellitus':

1. E12: Malnutrition-related diabetes mellitus

2. E14: Unspecified diabetes mellitus

3. O24.3: Pre-existing diabetes mellitus, unspecified

4. E11: Type 2 diabetes mellitus

5. E13: Other specified diabetes mellitus

6. O24.1: Pre-existing type 2 diabetes mellitus

7. E10: Type 1 diabetes mellitus

8. Z83.3: Family history of diabetes mellitus

9. O24.2: Pre-existing malnutrition-related diabetes mellitus

10. O24.9: Diabetes mellitus in pregnancy, unspecified

Example searches:

Search results for: 'heart attack'
----------------------------------------
1. Code: I21
   Description: Acute myocardial infarction

2. Code: I25.2
   Description: Old myocardial infarction

3. Code: I25.1
   Description: Athe
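
By default, LangChain's `FAISS.from_documents` builds an exact (flat) index with Euclidean distance, so a similarity search ranks every stored embedding by its distance to the query embedding. The brute-force sketch below reproduces that ranking on toy 3-dimensional vectors standing in for OpenAI embeddings, which have far more dimensions; `knn_l2` and the vector values are illustrative assumptions.

```python
import math
from typing import Dict, List, Tuple


def knn_l2(query: List[float], index: Dict[str, List[float]], k: int) -> List[Tuple[str, float]]:
    """Return the k entries closest to the query by Euclidean (L2) distance."""
    scored = [(key, math.dist(query, vector)) for key, vector in index.items()]
    scored.sort(key=lambda pair: pair[1])  # smaller distance = more similar
    return scored[:k]


# Toy "embeddings" keyed by ICD-10 code (illustrative values only)
toy_index = {
    "E11": [0.9, 0.1, 0.0],  # Type 2 diabetes mellitus
    "E10": [0.8, 0.2, 0.1],  # Type 1 diabetes mellitus
    "I21": [0.0, 0.9, 0.3],  # Acute myocardial infarction
}
query_vector = [0.88, 0.12, 0.02]  # stands in for the embedding of "diabetes mellitus"

for code, distance in knn_l2(query_vector, toy_index, k=2):
    print(code, round(distance, 3))
```

A flat index is exact but scans every vector; with about 11,000 ICD-10 entries that is still fast, which is why no approximate index is needed here.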

### Step 4 (Continued): Build RAG Agent - Agent Parameters

Now, let's define the parameters for the RAG agent that will use the FAISS vector database we created to extract ICD-10 codes.

In [87]:
# Agent parameters for the RAG agent
rag_icd10_agent_params = {
    "model": "gpt-4o-mini",  # Using the default model for the RAG agent
    "system_prompt": """You are an expert medical coder assisting with ICD-10 code extraction.
    Use the provided context from the ICD-10 database to identify the most relevant codes and descriptions for the patient note.
    The note may describe multiple problems; consider each problem separately and match it against the retrieved context.
    Report every medical problem, even ones that seem trivial.
    Present the extracted information as a list of dictionaries with 'code', 'description', and 'confidence' keys, where 'confidence' is a number between 0 and 1 reflecting how well the retrieved context supports the code.
    """,
    "tools": [],  # Add relevant tools if this agent needs to use any
    "temperature": 0.1,
    "role": "ICD-10 Extractor (RAG)",
    "debug_logging": "info", # Set logging level (e.g., 'info', 'debug')
    # Include the retriever for this agent to use the vector database
    "retriever": retriever,
}
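
The prompt asks the model to consider each problem separately, but the node retrieves once using the entire note, so a note with several problems may surface codes mostly for the dominant one. One alternative, sketched below, retrieves per problem and merges the hits without duplicates. It assumes a list of problem phrases is already available (for example from an extra LLM call); `retrieve_per_problem`, `toy_retrieve`, and `TOY_CODES` are hypothetical stand-ins for the real retriever.

```python
from typing import Callable, Dict, List


def retrieve_per_problem(
    problems: List[str],
    retrieve: Callable[[str], List[Dict[str, str]]],
    k_per_problem: int = 3,
) -> List[Dict[str, str]]:
    """Run one retrieval per problem and keep the first occurrence of each code."""
    seen = set()
    merged = []
    for problem in problems:
        for hit in retrieve(problem)[:k_per_problem]:
            if hit["code"] not in seen:
                seen.add(hit["code"])
                merged.append({**hit, "problem": problem})
    return merged


# Hypothetical keyword retriever over a tiny code table (stands in for FAISS or BM25)
TOY_CODES = [
    {"code": "J44.1", "description": "Chronic obstructive pulmonary disease with acute exacerbation, unspecified"},
    {"code": "J44.9", "description": "Chronic obstructive pulmonary disease, unspecified"},
    {"code": "I10", "description": "Essential (primary) hypertension"},
]


def toy_retrieve(query: str) -> List[Dict[str, str]]:
    """Return rows whose description shares at least one lower-cased word with the query."""
    words = set(query.lower().split())
    return [row for row in TOY_CODES if words & set(row["description"].lower().split())]


hits = retrieve_per_problem(["acute exacerbation of COPD", "hypertension"], toy_retrieve)
for hit in hits:
    print(hit["problem"], "->", hit["code"])
```

With the FAISS retriever defined above, `retrieve` could be a function that calls `retriever.invoke(problem)` and maps each document to `{"code": doc.metadata["code"], "description": doc.metadata["definition"]}`.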

### Step 4 (Continued): Build BM25 Agent - Retriever Creation

Now, we'll create a BM25 retriever from the loaded ICD-10 data.

In [88]:
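import pandas as pd
from langchain_community.retrievers import BM25Retriever

# Reload the CSV: icd10_df is only defined in the FAISS cell when a new index is
# built, not when an existing index is loaded from disk.
icd10_df = pd.read_csv('/content/icd10_2019.csv')

# Build one BM25 document per ICD-10 code from the 'sub-code' and 'definition' columns.
# k=10 matches create_agent_node's default retriever_k (BM25Retriever defaults to k=4).
bm25_retriever = BM25Retriever.from_texts(
    icd10_df.apply(lambda row: f"Code: {row.get('sub-code', '')}, Description: {row.get('definition', '')}", axis=1).tolist(),
    k=10,
)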

print("BM25 retriever created successfully.")

BM25 retriever created successfully.
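
BM25 ranks documents by term overlap with the query, weighting rare terms more heavily and damping the effect of repeated terms and long documents. The sketch below scores documents with the Okapi BM25 formula using `k1=1.5` and `b=0.75` (the `BM25Okapi` defaults in `rank_bm25`). It uses the smoothed IDF `log(1 + (N - n + 0.5) / (n + 0.5))`, which differs from `rank_bm25`'s IDF handling, so scores will not match `BM25Retriever` numerically. It also lower-cases tokens, whereas `BM25Retriever`'s default preprocessing only splits on whitespace. `bm25_scores` is a hypothetical helper.

```python
import math
from collections import Counter
from typing import List


def bm25_scores(query: str, docs: List[str], k1: float = 1.5, b: float = 0.75) -> List[float]:
    """Okapi BM25 score of each document for the query (smoothed, always-positive IDF)."""
    tokenized = [doc.lower().split() for doc in docs]
    n_docs = len(tokenized)
    avg_len = sum(len(tokens) for tokens in tokenized) / n_docs
    # Document frequency: number of documents containing each term
    doc_freq = Counter(term for tokens in tokenized for term in set(tokens))

    scores = []
    for tokens in tokenized:
        term_freq = Counter(tokens)
        score = 0.0
        for term in query.lower().split():
            if term not in term_freq:
                continue
            n = doc_freq[term]
            idf = math.log(1 + (n_docs - n + 0.5) / (n + 0.5))
            tf = term_freq[term]
            # Saturating term frequency, normalised by document length
            score += idf * tf * (k1 + 1) / (tf + k1 * (1 - b + b * len(tokens) / avg_len))
        scores.append(score)
    return scores


docs = [
    "Code: J44.9, Description: Chronic obstructive pulmonary disease, unspecified",
    "Code: I10, Description: Essential (primary) hypertension",
    "Code: E11, Description: Type 2 diabetes mellitus",
]
scores = bm25_scores("history of hypertension", docs)
best = max(range(len(docs)), key=scores.__getitem__)
print(docs[best])
```

Because BM25 matches surface tokens, the case-sensitive default tokenizer will not match "Hypertension" in a note against "hypertension" in a definition; passing a lower-casing `preprocess_func` to `BM25Retriever.from_texts` is one way to address that.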


# Task
Create a comprehensive Jupyter notebook that implements a multi-agent system using LangGraph and LangChain 1.0. The system should extract ICD-10 codes and descriptions from a synthetic patient note using multiple methods: gpt-4o-mini, gpt-4.1 (simulated), gpt-4.1-mini (simulated), RAG with a provided CSV file ("icd10_2019.csv") using FAISS and OpenAI embeddings, and BM25 retrieval. The notebook should include: environment setup, state graph schema definition with detailed examples and validation, agent parameter definitions for each extraction method, a reusable function to build LangGraph nodes with logging and error handling, node creation for each agent, graph construction with defined edges, and invocation of the graph with a synthetic patient note. The state should track the patient note, extracted codes and descriptions from each method, RAG confidence scores, number of retries per node, and token counts per node. The notebook should include markdown descriptions before each code chunk, and after defining each node or set of connected nodes, the current state of the graph should be invoked with the previous output, the current output printed, and the output saved to a variable named `[agent_name]_output`. The notebook should use the `gpt-4o-mini` model by default unless specified otherwise for a particular agent. Ensure all necessary libraries are installed and API keys are set up. The RAG agent should use the "sub-code" and "definition" columns from the "icd10_2019.csv" file for embeddings and build a single chunk for ingestion. The BM25 node should also be added to the graph.

## Define bm25 agent parameters

### Subtask:
Create the parameter dictionary for the BM25 agent, similar to the other agents, including a system prompt and assigning the `bm25_retriever`.


**Reasoning**:
Define the parameters for the BM25 agent including the model, system prompt, tools, temperature, role, debug logging, and assign the BM25 retriever.



In [89]:
# Agent parameters for the BM25 agent
bm25_icd10_agent_params = {
    "model": "gpt-4o-mini",  # Using the default model for the BM25 agent
    "system_prompt": """You are an expert medical coder assisting with ICD-10 code extraction. Use the provided context from the BM25 retriever to identify the most relevant codes and descriptions for the patient note.

    Present the extracted information as a list of dictionaries, where each dictionary has 'code' and 'description' keys.
    """,
    "tools": [],  # This agent won't be using external tools beyond the retriever
    "temperature": 0.1,
    "role": "ICD-10 Extractor (BM25)",
    "debug_logging": "info", # Set logging level (e.g., 'info', 'debug')
    # Include the BM25 retriever for this agent
    "retriever": bm25_retriever,
}


## Build node creation function

### Subtask:
Create a reusable Python function that takes agent parameters (like those defined in Step 4) and constructs a LangGraph node (a runnable or a function) with built-in logging, error handling, and state updates (including token counts and retries).

**Reasoning**:
Define a reusable function to create LangGraph nodes with logging, error handling, retry counting, and token counting.

In [90]:
import logging
import time
import json
import ast  # Parses Python-literal output (e.g. single-quoted dicts) that is not valid JSON
from langchain_core.prompts import ChatPromptTemplate

# Set up basic logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Build a LangGraph node function from an agent's parameter dictionary.
# Retrieval-backed agents (RAG, BM25) fetch context before the LLM is called.
def create_agent_node(agent_params: dict, llm):
    """
    Creates a LangGraph node function from agent parameters.

    Args:
        agent_params: A dictionary containing agent configuration (model, system_prompt, etc.).
        llm: The language model object.

    Returns:
        A function that acts as a LangGraph node, returning a dictionary of state updates.
    """
    model_name = agent_params.get("model", "gpt-4o-mini")  # Informational only: the node always calls the llm passed in
    system_prompt_template = agent_params.get("system_prompt", "")
    tools = agent_params.get("tools", [])
    temperature = agent_params.get("temperature", 0.1)
    role = agent_params.get("role", "Unnamed Agent")
    debug_logging_level = agent_params.get("debug_logging", "info").upper()
    retriever = agent_params.get("retriever", None)
    retriever_k = agent_params.get("retriever_k", 10) # Default to retrieving top 10 documents

    agent_logger = logging.getLogger(role)
    agent_logger.setLevel(getattr(logging, debug_logging_level, logging.INFO))

    # Append a strict output-format instruction. Curly braces are doubled so that
    # ChatPromptTemplate does not treat them as template variables.
    system_prompt_template_json = system_prompt_template + (
        '\n\nReturn ONLY a valid JSON list of objects, like this: '
        '[{{"code": "ABC.123", "description": "Example Description"}}]. '
        'Include any additional keys requested above. Do not include any other text or formatting.'
    )


    def agent_node(state: AgentState) -> dict:
        """
        LangGraph node function for an agent.
        Processes AgentState and returns a dictionary of state updates.
        """
        agent_logger.info(f"---Executing {role}---")
        start_time = time.time()
        current_state = state.copy()

        state_updates = {}

        # Count how many times this node has run (no automatic retry is performed)
        retries = current_state.get("retries_per_node", {}).get(role, 0) + 1
        retries_per_node_update = current_state.get("retries_per_node", {}).copy()
        retries_per_node_update[role] = retries
        state_updates["retries_per_node"] = retries_per_node_update
        agent_logger.debug(f"Attempt number {retries} for {role}")

        try:
            context_for_formatting = ""
            if retriever:
                try:
                    # Retrieve documents for the whole note before invoking the chain
                    retrieved_docs = retriever.invoke(current_state["patient_note"])
                    # Keep at most retriever_k documents
                    retrieved_docs = retrieved_docs[:retriever_k]
                    context_for_formatting = "\n".join([doc.page_content for doc in retrieved_docs])
                    agent_logger.debug(f"Retrieved {len(retrieved_docs)} documents for {role}.")
                except Exception as retrieve_error:
                    agent_logger.warning(f"Could not retrieve docs for {role}: {retrieve_error}")
                    context_for_formatting = "Retrieval failed."


            # Construct the prompt; retrieval-backed agents also receive the retrieved context
            prompt = ChatPromptTemplate.from_messages([
                ("system", system_prompt_template_json),  # Prompt with the JSON-output instruction
                ("human", "{patient_note}" + ("\nContext: {context}" if retriever else "")),
            ])

            # Prepare the input dictionary for the chain
            chain_input = {"patient_note": current_state["patient_note"]}
            if retriever:
                chain_input["context"] = context_for_formatting # Pass the formatted context


            # Build the runnable chain (simplified since context is prepared before)
            chain = prompt | llm.bind(temperature=temperature)


            # Invoke the chain with the prepared input dictionary
            response = chain.invoke(chain_input)

            # Token usage: prefer the counts reported by the API (usage_metadata);
            # fall back to a rough whitespace word count if they are unavailable.
            usage = getattr(response, "usage_metadata", None)
            if usage:
                total_tokens = usage["total_tokens"]
            else:
                formatted_prompt = prompt.format(**chain_input)
                total_tokens = len(formatted_prompt.split()) + len(response.content.split())

            tokens_per_node_update = current_state.get("tokens_per_node", {}).copy()
            tokens_per_node_update[role] = tokens_per_node_update.get(role, 0) + total_tokens
            state_updates["tokens_per_node"] = tokens_per_node_update


            # Update the state with the agent's response
            try:
                def safe_parse_response(content):
                    """Parse a list of dicts from JSON or Python-literal syntax, ignoring Markdown code fences."""
                    content = content.strip()

                    # Strip a Markdown code fence (three backticks, optionally tagged "json") if present
                    fence = "`" * 3
                    if content.startswith(fence):
                        content = content[len(fence):]
                        if content.lower().startswith("json"):
                            content = content[len("json"):]
                        content = content.rsplit(fence, 1)[0].strip()

                    # Method 1: ast.literal_eval handles Python dict syntax with single quotes
                    try:
                        if content.startswith('[') and content.endswith(']'):
                            return ast.literal_eval(content)
                    except (ValueError, SyntaxError):
                        pass

                    # Method 2: standard JSON parsing
                    try:
                        return json.loads(content)
                    except json.JSONDecodeError:
                        pass

                    # Method 3: convert single quotes to double quotes and retry
                    # (breaks on apostrophes inside descriptions, so it is a last resort)
                    try:
                        return json.loads(content.replace("'", '"'))
                    except json.JSONDecodeError:
                        pass

                    raise ValueError(f"Could not parse response: {content[:100]}...")

                extracted_data = safe_parse_response(response.content)

                if not isinstance(extracted_data, list):
                    raise ValueError("Expected a JSON list.")
                for item in extracted_data:
                    if not isinstance(item, dict) or 'code' not in item or 'description' not in item:
                        raise ValueError("Expected list of dictionaries with 'code' and 'description'.")

                # Add extracted data to the state field for this agent.
                # The gpt4_1_mini agent's role is "ICD-10 Extractor (GPT-4o-mini)" and the gpt4omini
                # agent's role is "ICD-10 Extractor (GPT-4o-mini - separate instance)", so map them accordingly.
                if role == "ICD-10 Extractor (GPT-4)":
                    state_updates["gpt4_1_icd10"] = extracted_data
                elif role == "ICD-10 Extractor (GPT-4o-mini)":
                    state_updates["gpt4_1_mini_icd10"] = extracted_data
                elif role == "ICD-10 Extractor (GPT-4o-mini - separate instance)":
                    state_updates["gpt4o_mini_icd10"] = extracted_data
                elif role == "ICD-10 Extractor (RAG)":
                    state_updates["rag_icd10"] = extracted_data
                    # Use the model's confidence scores if present, otherwise default to 1.0
                    if extracted_data and isinstance(extracted_data[0], dict) and 'confidence' in extracted_data[0]:
                        state_updates["rag_confidence"] = [item.get('confidence', 1.0) for item in extracted_data]
                    else:
                        state_updates["rag_confidence"] = [1.0] * len(extracted_data)  # Placeholder confidence
                elif role == "ICD-10 Extractor (BM25)":
                    state_updates["bm25_icd10"] = extracted_data

                state_updates["agent_response"] = response.content # Store raw response too
                agent_logger.info(f"Successfully extracted data for {role}.")

            except json.JSONDecodeError as json_error:
                agent_logger.error(f"JSON parsing error for {role}: {json_error}. Content: {response.content.strip()}")
                state_updates["debug_info"] = {"level": "error", "message": f"JSON parsing error in {role}: {json_error}"}
            except Exception as parse_error:
                agent_logger.error(f"Error parsing response for {role}: {parse_error}. Content: {response.content.strip()}")
                state_updates["debug_info"] = {"level": "error", "message": f"Parsing error in {role}: {parse_error}"}


        except Exception as e:
            agent_logger.error(f"Error executing {role}: {e}")
            state_updates["debug_info"] = {"level": "error", "message": f"Execution error in {role}: {e}"}

        end_time = time.time()
        duration = end_time - start_time
        agent_logger.info(f"---Finished {role} in {duration:.2f} seconds---")
        return state_updates

    return agent_node
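
`agent_node` increments `retries_per_node` each time it runs, but it does not retry a failed LLM call by itself. If automatic retries are wanted, the call could be wrapped in a helper such as the hypothetical `call_with_retries` below, which retries with exponential backoff and reports how many attempts were used. LangChain runnables also provide `.with_retry()` for the same purpose.

```python
import time
from typing import Callable, Tuple, Type, TypeVar

T = TypeVar("T")


def call_with_retries(
    fn: Callable[[], T],
    max_attempts: int = 3,
    base_delay: float = 1.0,
    retry_on: Tuple[Type[BaseException], ...] = (Exception,),
) -> Tuple[T, int]:
    """Call fn until it succeeds or max_attempts is reached.

    Returns (result, attempts_used). Waits base_delay * 2**(attempt - 1)
    seconds between attempts and re-raises the last exception on failure.
    """
    for attempt in range(1, max_attempts + 1):
        try:
            return fn(), attempt
        except retry_on:
            if attempt == max_attempts:
                raise
            time.sleep(base_delay * 2 ** (attempt - 1))
    raise ValueError("max_attempts must be at least 1")


# Example: a simulated call that fails twice, then succeeds
calls = {"count": 0}


def flaky_llm_call() -> str:
    calls["count"] += 1
    if calls["count"] < 3:
        raise TimeoutError("simulated API timeout")
    return '[{"code": "I10", "description": "Essential (primary) hypertension"}]'


result, attempts = call_with_retries(flaky_llm_call, max_attempts=3, base_delay=0.0)
print(attempts, result)
```

Inside `agent_node`, this could take the form `response, attempts = call_with_retries(lambda: chain.invoke(chain_input))`, with `attempts` recorded in `retries_per_node` instead of the run counter.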

## Create agent nodes

### Subtask:
Use the node creation function to create a node for each agent (`start_node`, `gpt4_1_icd10`, `gpt4_1_mini_icd10`, `gpt4omini_icd10`, `rag_icd10`, `bm25_icd10`).

**Reasoning**:
Use the `create_agent_node` function to create LangGraph nodes for each defined agent and assign them to the specified variables.

In [91]:
# Create nodes for each agent using the create_agent_node function.
# Every node receives the same `llm` instance here; unless create_agent_node selects a
# model from each agent's params, all five agents run on that one model.
gpt4_1_icd10_node = create_agent_node(gpt4_1_icd10_agent_params, llm)
gpt4_1_mini_icd10_node = create_agent_node(gpt4_1_mini_icd10_agent_params, llm)
gpt4omini_icd10_node = create_agent_node(gpt4omini_icd10_agent_params, llm)
rag_icd10_node = create_agent_node(rag_icd10_agent_params, llm)
bm25_icd10_node = create_agent_node(bm25_icd10_agent_params, llm)

print("Agent nodes created successfully.")

Agent nodes created successfully.


## Build and Compile the LangGraph

### Subtask:
Initialize the `StateGraph`, add all the created agent nodes, define the entry point, add the edges to define the workflow, and compile the graph.

**Reasoning**:
Construct the LangGraph by adding the defined nodes and connecting them with edges so the agents run one after another, then compile the graph into a runnable application.
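
Because every node has exactly one outgoing edge, the graph runs as a fixed pipeline: `start` → `gpt4_1_icd10` → `gpt4_1_mini_icd10` → `gpt4omini_icd10` → `rag_icd10` → `bm25_icd10` → `END`. The pure-Python sketch below imitates how such a chain threads state. Each node returns a partial update. Keys with a reducer (here list concatenation, similar in spirit to the `add_messages` reducer defined earlier) are combined, and all other keys are overwritten. This illustrates the semantics under those assumptions; it is not LangGraph's implementation, and the node functions are hypothetical stubs.

```python
from operator import add
from typing import Callable, Dict, List

State = Dict[str, object]
Node = Callable[[State], State]


def run_linear_pipeline(order: List[str], nodes: Dict[str, Node], state: State,
                        reducers: Dict[str, Callable]) -> State:
    """Run nodes in a fixed order, merging each partial update into the state."""
    # Consecutive pairs are the edges a linear graph declares with add_edge.
    edges = list(zip(order, order[1:]))
    print("edges:", edges)
    state = dict(state)  # don't mutate the caller's dict
    for name in order:
        update = nodes[name](state)
        for key, value in update.items():
            if key in reducers:
                # Reducer keys combine old and new values (here: list concatenation).
                state[key] = reducers[key](state.get(key, []), value)
            else:
                # Keys without a reducer are overwritten: last writer wins.
                state[key] = value
    return state


nodes = {
    "start": lambda s: {"messages": ["start"]},
    "gpt4_1_icd10": lambda s: {"gpt4_1_icd10": [{"code": "J44.1"}], "agent_response": "gpt4.1", "messages": ["gpt4.1"]},
    "bm25_icd10": lambda s: {"bm25_icd10": [{"code": "J18.9"}], "agent_response": "bm25", "messages": ["bm25"]},
}
final = run_linear_pipeline(["start", "gpt4_1_icd10", "bm25_icd10"], nodes, {"messages": []}, {"messages": add})
print(final["messages"], final["agent_response"])
```

Under these assumptions, if `AgentState` declares keys such as `agent_response` and `debug_info` without a reducer, the final state keeps only the last value written to them (for `agent_response`, the BM25 node's raw reply).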

In [92]:
from langgraph.graph import StateGraph, END

# Initialize the StateGraph with the defined AgentState
workflow = StateGraph(AgentState)

# Add all the agent nodes to the workflow
workflow.add_node("start", start_node)
workflow.add_node("gpt4_1_icd10", gpt4_1_icd10_node)
workflow.add_node("gpt4_1_mini_icd10", gpt4_1_mini_icd10_node)
workflow.add_node("gpt4omini_icd10", gpt4omini_icd10_node)
workflow.add_node("rag_icd10", rag_icd10_node)
workflow.add_node("bm25_icd10", bm25_icd10_node)

# Set the entry point of the graph
workflow.set_entry_point("start")

# Define the edges (transitions) between the nodes
# This defines the sequence in which the agents will be executed
workflow.add_edge("start", "gpt4_1_icd10")
workflow.add_edge("gpt4_1_icd10", "gpt4_1_mini_icd10")
workflow.add_edge("gpt4_1_mini_icd10", "gpt4omini_icd10")
workflow.add_edge("gpt4omini_icd10", "rag_icd10")
workflow.add_edge("rag_icd10", "bm25_icd10")

# Define the end point of the graph
workflow.add_edge("bm25_icd10", END)

# Compile the workflow into a runnable application
app = workflow.compile()

print("LangGraph workflow built and compiled successfully.")

LangGraph workflow built and compiled successfully.


## Invoke the LangGraph and Display Results

### Subtask:
Define the initial state with the synthetic patient note and invoke the compiled LangGraph. Print and save the final state output.

**Reasoning**:
Execute the complete multi-agent workflow by invoking the compiled graph with the initial input and capture the final state containing results from all agents.
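
The initial state below is written out key by key, so every key must be kept in sync with `AgentState` by hand. Below is a minimal sketch of deriving the empty defaults from a `TypedDict` schema instead. `ExampleState`, `_DEFAULTS`, and `make_initial_state` are hypothetical names, and `ExampleState` is only an illustrative subset, not the notebook's real `AgentState`.

```python
from typing import Annotated, Dict, List, TypedDict, get_origin, get_type_hints


# Hypothetical subset of the notebook's AgentState, for illustration only.
class ExampleState(TypedDict):
    patient_note: str
    messages: Annotated[List[str], "reducer metadata"]
    retries_per_node: Dict[str, int]
    gpt4_1_icd10: List[dict]
    rag_confidence: List[float]


# Empty default per base type; extend this for any other field types in the schema.
_DEFAULTS = {list: list, dict: dict, str: str, int: int, float: float}


def make_initial_state(schema: type, **overrides) -> dict:
    """Build a state with an empty default for every schema key, then apply overrides."""
    hints = get_type_hints(schema)  # Annotated[...] metadata is stripped by default
    unknown = set(overrides) - set(hints)
    if unknown:
        raise KeyError(f"Keys not in schema: {sorted(unknown)}")
    state = {}
    for key, hint in hints.items():
        base = get_origin(hint) or hint  # List[dict] -> list, Dict[str, int] -> dict, str -> str
        state[key] = _DEFAULTS[base]()
    state.update(overrides)
    return state


print(make_initial_state(ExampleState, patient_note="Synthetic patient note text"))
```

With the notebook's schema, a call like `make_initial_state(AgentState, patient_note=synthetic_patient_note)` would build one default per `AgentState` field. This assumes every field type maps to an entry in `_DEFAULTS`; an `Optional[...]` field, for example, would need its own rule. A mistyped key in the overrides raises `KeyError` instead of silently adding a stray key.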

In [93]:
from pprint import pprint

# Define the initial state for graph execution
initial_state = {
    "patient_note": synthetic_patient_note, # Use the synthetic patient note defined earlier
    "retries_per_node": {},
    "tokens_per_node": {},
    "messages": [], # Initialize with empty lists/dicts as per AgentState definition
    "user_input": "",
    "agent_response": "",
    "tool_calls": [],
    "final_output": "",
    "debug_info": {},
    "gpt4_1_icd10": [],
    "gpt4o_mini_icd10": [],
    "gpt4_1_mini_icd10": [],
    "rag_icd10": [],
    "rag_confidence": [],
    "bm25_icd10": [],
}

# Invoke the compiled workflow with the initial state
final_graph_output = app.invoke(initial_state)

# Print the final output state
print("\n---Final Graph Output State---")
pprint(final_graph_output)

# final_graph_output now holds the complete state after graph execution.

INFO:ICD-10 Extractor (GPT-4):---Executing ICD-10 Extractor (GPT-4)---


---START NODE---


INFO:ICD-10 Extractor (GPT-4):Successfully extracted data for ICD-10 Extractor (GPT-4).
INFO:ICD-10 Extractor (GPT-4):---Finished ICD-10 Extractor (GPT-4) in 4.65 seconds---
INFO:ICD-10 Extractor (GPT-4o-mini):---Executing ICD-10 Extractor (GPT-4o-mini)---
INFO:ICD-10 Extractor (GPT-4o-mini):Successfully extracted data for ICD-10 Extractor (GPT-4o-mini).
INFO:ICD-10 Extractor (GPT-4o-mini):---Finished ICD-10 Extractor (GPT-4o-mini) in 4.40 seconds---
INFO:ICD-10 Extractor (GPT-4o-mini - separate instance):---Executing ICD-10 Extractor (GPT-4o-mini - separate instance)---
INFO:ICD-10 Extractor (GPT-4o-mini - separate instance):Successfully extracted data for ICD-10 Extractor (GPT-4o-mini - separate instance).
INFO:ICD-10 Extractor (GPT-4o-mini - separate instance):---Finished ICD-10 Extractor (GPT-4o-mini - separate instance) in 5.32 seconds---
INFO:ICD-10 Extractor (RAG):---Executing ICD-10 Extractor (RAG)---
INFO:ICD-10 Extractor (RAG):Successfully extracted data for ICD-10 Extractor 


---Final Graph Output State---
{'agent_response': '[\n'
                   '    {"code": "J44.1", "description": "Chronic obstructive '
                   'pulmonary disease with (acute) exacerbation"},\n'
                   '    {"code": "J18.9", "description": "Pneumonia, '
                   'unspecified organism"},\n'
                   '    {"code": "I50.33", "description": "Acute on chronic '
                   'heart failure, unspecified"},\n'
                   '    {"code": "I48.91", "description": "Unspecified atrial '
                   'fibrillation"},\n'
                   '    {"code": "N18.3", "description": "Chronic kidney '
                   'disease, stage 3"},\n'
                   '    {"code": "E11.9", "description": "Type 2 diabetes '
                   'mellitus without complications"},\n'
                   '    {"code": "E87.1", "description": "Hypo-osmolarity and '
                   'hyponatremia"},\n'
                   '    {"code": "L97.521", "descriptio

## Export Results to CSV

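Let's define two helpers. `flatten_results_for_csv` turns the final graph state into a single-level dictionary (one CSV row). `export_to_csv` writes that row to a CSV file, creating the file on the first run and appending one row per graph execution afterwards.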

In [109]:
import pandas as pd
import os

def flatten_results_for_csv(graph_output: dict) -> dict:
    """
    Flattens relevant data from the graph output dictionary into a single-level dictionary
    suitable for a CSV row.
    """
    flattened_data = {}

    # Add performance metrics
    retries = graph_output.get("retries_per_node", {})
    tokens = graph_output.get("tokens_per_node", {})
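    # Sort the agent names so the column order is stable across runs (set iteration
    # order for strings can change between Python sessions).
    for agent_name in sorted(set(retries) | set(tokens)):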
        flattened_data[f"{agent_name}_retries"] = retries.get(agent_name, 0)
        flattened_data[f"{agent_name}_tokens"] = tokens.get(agent_name, 0)

    # Add the extracted ICD-10 codes and descriptions from each agent.
    # Each agent's state key holds a list of {"code": ..., "description": ...} dicts;
    # for simplicity, join them into "; "-separated strings (one codes column and one
    # descriptions column per agent). A more structured alternative would be multiple
    # columns or a separate CSV for codes.
    agent_code_fields = [
        "gpt4_1_icd10",
        "gpt4o_mini_icd10",
        "gpt4_1_mini_icd10",
        "rag_icd10",
        "bm25_icd10",
    ]

    for state_key in agent_code_fields:
        codes_list = graph_output.get(state_key) or []  # treat a missing/None value as no codes
        codes_str = "; ".join(item.get("code") or "" for item in codes_list)
        descriptions_str = "; ".join(item.get("description") or "" for item in codes_list)
        flattened_data[f"{state_key}_codes"] = codes_str
        flattened_data[f"{state_key}_descriptions"] = descriptions_str

    # Add RAG confidence scores (joined as a string)
    rag_confidence_list = graph_output.get("rag_confidence", [])
    confidence_str = "; ".join(map(str, rag_confidence_list))
    flattened_data["rag_confidence_scores"] = confidence_str

    # Add other relevant fields if needed
    # flattened_data["patient_note_snippet"] = graph_output.get("patient_note", "")[:100] + "..." # Example

    return flattened_data

def export_to_csv(data: dict, filename: str):
    """
    Exports a single row of data to a CSV file, creating the file if it doesn't exist
    or appending to it if it does.
    """
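    df = pd.DataFrame([data])
    if not os.path.exists(filename):
        df.to_csv(filename, index=False, mode='w')
        print(f"Created and wrote to new CSV file: {filename}")
    else:
        # Appending with header=False writes values by position, so align the new row
        # to the existing header first; columns this run lacks are left empty.
        existing_columns = pd.read_csv(filename, nrows=0).columns
        dropped = [col for col in df.columns if col not in existing_columns]
        if dropped:
            print(f"Warning: columns not in the existing header were dropped: {dropped}")
        df = df.reindex(columns=existing_columns)
        df.to_csv(filename, index=False, mode='a', header=False)
        print(f"Appended data to existing CSV file: {filename}")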

# Example usage (after graph invocation):
# Assuming final_graph_output contains the results of one run
# csv_filename = "multi_agent_results.csv"
# flattened_output = flatten_results_for_csv(final_graph_output)
# export_to_csv(flattened_output, csv_filename)

print("CSV export functions defined.")

CSV export functions defined.
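
The comment in `flatten_results_for_csv` mentions a separate CSV for codes as a more structured alternative to joined strings. Below is a minimal sketch of that long format, with one row per (run, agent, code), using only the standard library. `to_long_rows`, `write_long_csv`, `run_id`, and the column names are hypothetical choices, not part of the notebook.

```python
import csv
import io

AGENT_KEYS = ["gpt4_1_icd10", "gpt4o_mini_icd10", "gpt4_1_mini_icd10", "rag_icd10", "bm25_icd10"]
FIELDNAMES = ["run_id", "agent", "code", "description"]


def to_long_rows(graph_output: dict, run_id: str) -> list:
    """One row per (run, agent, code) instead of one joined string per agent."""
    rows = []
    for agent in AGENT_KEYS:
        for item in graph_output.get(agent) or []:
            rows.append({
                "run_id": run_id,
                "agent": agent,
                "code": item.get("code", ""),
                "description": item.get("description", ""),
            })
    return rows


def write_long_csv(rows: list, stream, write_header: bool) -> None:
    """Write rows with a fixed column order; csv quotes fields that contain commas."""
    writer = csv.DictWriter(stream, fieldnames=FIELDNAMES)
    if write_header:
        writer.writeheader()
    writer.writerows(rows)


example_output = {
    "gpt4_1_icd10": [{"code": "J44.1", "description": "Chronic obstructive pulmonary disease with (acute) exacerbation"}],
    "bm25_icd10": [
        {"code": "J44.1", "description": "Chronic obstructive pulmonary disease with (acute) exacerbation"},
        {"code": "J18.9", "description": "Pneumonia, unspecified organism"},
    ],
}
buffer = io.StringIO()
write_long_csv(to_long_rows(example_output, run_id="run-001"), buffer, write_header=True)
print(buffer.getvalue())
```

For a real file, open it with `newline=""` (as the `csv` module recommends) in append mode, and pass `write_header=True` only when the file is new. Because the column set is fixed, rows from different runs always line up.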


## Export the Final Graph Output to CSV

Now, we will use the defined functions to export the data from the `final_graph_output` to a CSV file.

In [111]:
# Define the filename for the CSV
csv_filename = "multi_agent_results.csv"

# Flatten the final graph output
flattened_output = flatten_results_for_csv(final_graph_output)

# Export the flattened data to the CSV file
export_to_csv(flattened_output, csv_filename)

print(f"Final graph output exported to {csv_filename}")

Appended data to existing CSV file: multi_agent_results.csv
Final graph output exported to multi_agent_results.csv
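
To analyze the exported runs later, the `"; "`-joined `*_codes` columns need to be split back into lists. Below is a minimal sketch using the standard `csv` module. `split_joined` and `read_agent_codes` are hypothetical helpers. The round-trip is safe for ICD-10 codes, but a description that itself contains a semicolon would be split incorrectly, so this sketch reads only the codes columns.

```python
import csv
import io


def split_joined(value: str) -> list:
    """Invert the "; ".join(...) used when flattening ('' -> [])."""
    return [part.strip() for part in value.split(";") if part.strip()]


def read_agent_codes(stream, agent_keys: list) -> list:
    """Return, for each CSV row (one graph run), a mapping agent -> list of codes."""
    runs = []
    for row in csv.DictReader(stream):
        # A missing column or short row yields None, which is treated as no codes.
        runs.append({agent: split_joined(row.get(f"{agent}_codes") or "") for agent in agent_keys})
    return runs


sample_csv = "gpt4_1_icd10_codes,bm25_icd10_codes\nJ44.1; J18.9,J44.1\n"
runs = read_agent_codes(io.StringIO(sample_csv), ["gpt4_1_icd10", "bm25_icd10"])
print(runs)
```

With the real file, `with open(csv_filename, newline="") as f: read_agent_codes(f, [...])` would return one dictionary per appended run.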
