Import necessary libraries

In [3]:
%pip install langgraph langchain langchain-google-genai langchain-community faiss-cpu python-dotenv pypdf

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 24.3.1 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


In [21]:
import os
from typing import TypedDict, Annotated, Sequence
from langgraph.graph import Graph, StateGraph
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.tools import Tool
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain.agents import AgentExecutor, create_react_agent
from langchain_core.messages import HumanMessage
from langchain.prompts import PromptTemplate
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_google_genai import GoogleGenerativeAIEmbeddings
import json
import csv
from dotenv import load_dotenv
from langgraph.graph import StateGraph, END
from graph_visualizer import save_graph_visualization

# Load environment variables from .env file
load_dotenv()  

True

Initialize LLM and tavily

In [5]:
# Initialize Google Generative AI
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", api_key=os.getenv("GEMINI_API_KEY"))
# llm = ChatGoogleGenerativeAI(model="gemini-pro", google_api_key=google_api_key)

# Initialize Tavily Search
tavily_tool = TavilySearchResults(api_key=os.getenv("TAVILY_API_KEY"))

In [6]:
# Create a vector store
def create_vector_store():
    """
    Create a vector store from PDF documents.
    
    This function loads PDF documents, splits the text into chunks, and creates embeddings using Google Generative AI.
    
    Returns:
        FAISS: A vector store containing the text and corresponding embeddings.
    """
    
    try:
        # Try to load the existing vector store first
        embeddings = GoogleGenerativeAIEmbeddings(
            model="models/gemini-embedding-exp-03-07", 
            google_api_key=os.getenv("GEMINI_API_KEY")
        )
        vector_store = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)
        print("Loaded existing vector store.")
        return vector_store
    except Exception as e:
        print(f"Could not load existing vector store: {e}. Creating new vector store...")
    
    # Loading documents
    documents = []
    try:
        for i in range(1, 4):
            try:
                loader = PyPDFLoader(f"document{i}.pdf")
                documents.extend(loader.load())
            except FileNotFoundError:
                print(f"File document{i}.pdf not found, skipping.")
    except Exception as e:
        print(f"Error loading documents: {e}")
        # Fallback to loading just the first document
        try:
            loader = PyPDFLoader("document1.pdf")
            documents.extend(loader.load())
        except FileNotFoundError:
            print("Critical: No documents found")
            return None
    
    # Splitting documents and creating text representations
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    texts = text_splitter.split_documents(documents)
    print(texts)
    print(f"Number of text chunks: {len(texts)}")
    
    # Create embeddings using a Google Generative AI model
    embeddings = GoogleGenerativeAIEmbeddings(
        model="models/gemini-embedding-exp-03-07", 
        google_api_key=os.getenv("GEMINI_API_KEY")
    )
    
    # Create a vector store using FAISS from the provided text chunks and embeddings
    vector_store = FAISS.from_documents(texts, embedding=embeddings)
    
    # Save the vector store locally with the name "faiss_index"
    vector_store.save_local("faiss_index")
    return vector_store


# Initialize vector store
try:
    vector_store = create_vector_store()
    if vector_store is None:
        raise ValueError("Failed to create vector store")
except Exception as e:
    print(f"Error creating vector store: {e}")

Loaded existing vector store.


In [8]:
# Create a search function for the vector store
def search_vector_store(query: str) -> str:
    try:
        print(f"Starting vector search for query: {query}")
        results = vector_store.similarity_search(query, k=1)
        print(f"Vector search results: {results}")
        return results[0].page_content
    except Exception as e:
        print(f"Error searching vector store: {e}")
        return "No information found in the vector store."
    
# Create a Tavily search function
def execute_tavily_tool(query: str):
    try:
        print(f"Starting Tavily search for query: {query}")
        results = tavily_tool.run(query)
        print(f"Tavily search results: {results}")
        return results
    except Exception as e:
        print(f"Error in Tavily search: {e}")
        return "No information found using Tavily."
    
# Create a tool for the vector store
vector_search_tool = Tool(
    name="VectorSearch",
    func=search_vector_store,
    description="Searches the vector store for relevant information about diseases."
)

# Create a tool for Tavily search
tavily_web_search_tool =  Tool(
        name="TavilySearch",
        func=execute_tavily_tool,
        description="Searches the web for relevant information about diseases."
    )

Create the Agent

In [13]:
# Store tools in a list
tools = [vector_search_tool, tavily_web_search_tool]

In [27]:
# Prepare the prompt
prompt = PromptTemplate.from_template(
    """You are a medical information retrieval agent. Your task is to find information about diseases and their ICD codes.

    Tools you can use:
    {tools}

    Tool name: {tool_names}

    Human: {human_input}
    You must respond in the following format:
        Thought: <Your thought process>
        Action: <The action you will take>
        Action Input: <The input for the action>
        Observation: <The result of the action>
        ... (this Thought/Action/Action Input/Observation can repeat N times)
        Thought: I have gathered all the necessary information and will now provide the final answer.
        Final Answer: <The final answer in JSON format>

    Always use the following steps:
    1. **Retrieve Disease Description:** Use VectorSearch to retrieve the disease description.
    2. **Retrieve ICD Codes:**
        a. First, attempt to use VectorSearch to find the ICD codes for the disease.
        b. If VectorSearch does not provide the ICD codes, then use TavilySearch to retrieve them.
    3. Compile the results into the specified JSON format.
    4. Once you have compiled the JSON response, you MUST output it using the "Final Answer:" prefix.

    The final answer MUST be a JSON string with the following structure:
    {{
        "disease": "<disease_name>",
        "description": "<description>",
        "icd_codes": ["<code1>", "<code2>", ...]
    }}
    If VectorSearch does not provide a description, clearly state that in the "description" field of the final JSON.
    If VectorSearch does not provide ICD codes, clearly state that in the "icd_codes" field (e.g., ["Not found in vector store, retrieved from web search"]).

    {agent_scratchpad}
    """
)

In [28]:
# Create agent and agent_executor 
agent = create_react_agent(llm, tools, prompt)
agent_executor = AgentExecutor(
        agent=agent,
        tools=tools,
        handle_parsing_errors=True,
        max_iterations=7, # Increased max_iterations to allow for more steps (VectorSearch for desc, VectorSearch for ICD, then Tavily for ICD if needed)
        verbose=True,
        return_intermediate_steps=True,
        max_execution_time=120.0)

In [None]:
# Define the state
class AgentState(TypedDict):
    human_input: str
    messages: Sequence[HumanMessage]
    results: dict


In [None]:
# Define the nodes
def agent_node(state: AgentState):
    """Process the agent's response"""
    try:
        # Execute the agent
        result = agent_executor.invoke({
            "human_input": state["human_input"]
        })
        
        print("Raw result from agent:", result)
        
        # Check if the result contains the final JSON output
        if isinstance(result, dict) and "disease" in result and "description" in result and "icd_codes" in result:
            print("Final result obtained. Terminating chain.")
            return {
                "messages": state.get("messages", []),  # Keep existing messages
                "results": result  # Return the final result
            }
        
        # Continue processing if the result is incomplete
        return {
            "messages": state.get("messages", []),
            "results": result.get("output", "No output from agent")
        }
    except Exception as e:
        print(f"Error in agent_node: {e}")
        return {
            "messages": state.get("messages", []),
            "results": f"Error processing request: {str(e)}"
        }

In [31]:
# Define vector search node
def vector_search_node(state: AgentState):
    query = state["human_input"]
    result = vector_search_tool.func(query)
    # Assuming the agent node can handle this result as part of its state progression
    state["results"]["vector_search_result"] = result
    return state

In [32]:
#Define tavily search node
def tavily_search_node(state: AgentState):
    query = state["human_input"]
    result = tavily_web_search_tool.run(query)
    # Assuming the agent node can handle this result as part of its state progression
    state["results"]["tavily_search_result"] = result
    return state

In [33]:
# Create the graph
workflow = StateGraph(AgentState)

workflow.add_node("agent", agent_node)

workflow.set_entry_point("agent")
workflow.set_finish_point("agent")

graph = workflow.compile()
print(graph.input_schema)


<class 'langchain_core.utils.pydantic.LangGraphInput'>


In [34]:
# Save visualization
save_graph_visualization(graph)

Output directory checked or created.
Graph object retrieved.
Attempting to draw PNG...
PNG drawn successfully.


In [35]:
# Run the graph
def run_graph(query: str):
    """Run the agent graph with the given query"""
    inputs = {
        "human_input": query,
        "messages": [],
        "results": {}
    }
    try:
        result = graph.invoke(inputs)
        
        # Check if the result contains the final JSON output
        if isinstance(result, dict) and "disease" in result["results"] and "description" in result["results"] and "icd_codes" in result["results"]:
            print("Final result obtained. Stopping execution.")
            return result["results"]
        
        return result["results"]
    except Exception as e:
        print(f"Error running graph: {e}")
        return f"Error: {str(e)}"


In [36]:
# function to generate output in json and csv formats
def generate_output(query: str, result: str):
    """
    Parses the result from a graph query to generate JSON and CSV outputs.

    Args:
        query (str): The original query string used to obtain the result.
        result (str): The output string from the graph query to be parsed.

    Returns:
        tuple: A tuple containing the JSON and CSV outputs based on the parsed result.
               Returns (None, None) if an error occurs.
    """
    try:
        # Check the input isn't empty and contains expected separator
        if not result:
            raise ValueError("Invalid input format: Result is empty")
        
        # print("Result: ", result)

        # Using more reliable methods to parse and extract information
        lines = result.strip().split('.')
        if len(lines) < 2:
            raise ValueError("Insufficient information in the result")

        # Assumed formats from logs
        disease = lines[0].strip()
        description = lines[1].strip()
        icd_line = [line for line in lines if "ICD-10" in line]
        icd_code = icd_line[0].strip().replace("The ICD-10 code is ", "") if icd_line else "N/A"
        print(disease)
        print(description)
        print(icd_line)
        print(icd_code)

        # Prepare JSON and CSV output
        json_output = {
            query: {
                "disease": disease,
                "description": description,
                "icd_code": icd_code
            }
        }
        print(json_output)

        csv_output = [["Query", "Disease", "Description", "ICD code"]]
        csv_output.append([query, disease, description, icd_code])

        return json_output, csv_output

    except Exception as e:
        print(f"Error processing output: {str(e)}")
        return None, None

In [37]:
# main function for execution
def main(query: str):

    """
    Executes query and generates outputs in JSON and CSV formats.

    Args:
        query (str): The query string to be executed by the graph processing function.

    Returns:
        None

    Raises:
        Exception: If an error occurs during query execution or output generation.
    """
    try:
        print(f"Processing query: {query}")
        result = run_graph(query)
        print(f"Result from graph: {result}")
        
        json_output, csv_output = generate_output(query, result)
        
        # Save JSON output
        with open('output.json', 'w') as f:
            json.dump(json_output, f, indent=2)
        
        # Save CSV output
        with open('output.csv', 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerows(csv_output)
        
        print("JSON output:", json_output)
        print("CSV output has been saved to output.csv")
        return json_output

    except Exception as e:
        print(f"Error in main function: {str(e)}")

In [40]:
# main execution
if __name__ == "__main__":
    main("fever")

Processing query: fever


[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: VectorSearch
Action Input: fever description[0mStarting vector search for query: fever description
Vector search results: [Document(id='720cc71a-4af5-46fb-9f55-855653856ed6', metadata={'source': 'document1.pdf', 'page': 0, 'page_label': '1'}, page_content="Disease Information - Document1\nDisease Name: Malaria\nDescription: Malaria is a serious and sometimes fatal disease caused by a parasite that commonly\ninfects a certain type of mosquito which feeds on humans. People who get malaria are typically very\nsick with high fevers, shaking chills, and flu-like illness.\nICD Code: B54\nDisease Name: Tuberculosis\nDescription: Tuberculosis (TB) is a potentially serious infectious disease that mainly affects your\nlungs. The bacteria that cause tuberculosis are spread from person to person through tiny droplets\nreleased into the air via coughs and sneezes.\nICD Code: A15.0\nDisease Name: Pneumonia\