In [13]:
import os
from typing import TypedDict, Annotated, Sequence
from langgraph.graph import Graph, StateGraph
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.tools import Tool
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain.agents import AgentExecutor, create_react_agent
from langchain_core.messages import HumanMessage
from langchain.prompts import PromptTemplate
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_google_genai import GoogleGenerativeAIEmbeddings
# import google.generativeai as genai

# from graph_visualizer import save_graph_visualization

import json
import csv
from dotenv import load_dotenv

In [14]:
# Load environment variables from .env file
load_dotenv()

# genai.configure(api_key=os.getenv("GEMINI_API_KEY")) # Loads API key

# Initialize Google Generative AI
llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", api_key="AIzaSyC8R-EfvfAMKItpj7isfvgBUmRZFSeLu90")

# Initialize Tavily Search
tavily_tool = TavilySearchResults(api_key=os.getenv("TAVILY_API_KEY"))

In [15]:
# Create a vector store
def create_vector_store():
    """
    Create a vector store from PDF documents.
    
    This function loads PDF documents, splits the text into chunks, and creates embeddings using the OpenAI API.
    
    Returns:
        FAISS: A vector store containing the text and corresponding embeddings.
    """
    
    # Loading documents
    documents = []
    for i in range(1, 4):
        loader = PyPDFLoader(f"malraia_1.pdf")
        documents.extend(loader.load())
    
    # Splitting documents and creating text representations
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100) # 2000 = 3 chunks => 1st 0-1000 (0),2nd 900-1900, 3rd 1800-2000
    texts = text_splitter.split_documents(documents)
    print(f"Texts: {texts}")
    
    # Print attributes of a single document to verify attribute names
    if texts:
        print(f"attributes of first doc: {vars(texts[0])}") # To help determine the correct attribute name
    
    # Create embeddings using a Google Generative AI model
    embeddings = GoogleGenerativeAIEmbeddings(model="models/gemini-embedding-exp-03-07", google_api_key="AIzaSyC8R-EfvfAMKItpj7isfvgBUmRZFSeLu90") # TODO added models/
    
    # print(embeddings)
    
    print(f"Number of text chunks: {len(texts)}")
    print(f"embeddings object: {embeddings}") #TODO Remove len()
    # print("Example embedding shape:", dir(embeddings[0][0]))
    
    # Create a vector store using FAISS from the provided text chunks and embeddings
    vector_store = FAISS.from_documents(texts, embedding=embeddings) #TODO uncomment to create new embeddings
    # vector_store = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)
    
    # Save the vector store locally with the name "faiss_index"
    vector_store.save_local("faiss_index") # TODO uncomment to replace embeddings
    return vector_store

vector_store = create_vector_store()

Texts: [Document(metadata={'source': 'malraia_1.pdf', 'page': 1}, page_content='WHO guidelines for malaria - 30 November 2024 This document is a PDF generated from the WHO guidelines for malaria hosted on the MAGICapp online \nplatform: https://app.magicapp.org/#/guideline/LwRMXj. Each time the content of the platform is updated, a new PDF version of the \nGuidelines will be downloadable on the WHO Global Malaria Programme website to facilitate access where the Internet is not \navailable. Users should note the downloaded PDFs of the Guidelines may be outdated and not contain the latest recommendations. \nPlease consult with the website for the most up-to-date version of the Guidelines (https://www.who.int/teams/global-malaria-\nprogramme). \nContact \nWHO Global Malaria Programme \nAppia Avenue 20, 1202 Geneva, Switzerland \ngmpfeedback@who.int \nhttps://www.who.int/teams/global-malaria-programme \nSponsors/Funding \nFunding for the development and publication of the Guidelines was gr

In [16]:
# Create a search tool for the vector store
def search_vector_store(query: str) -> str:
    results = vector_store.similarity_search(query, k=1)
    return results[0].page_content

vector_search_tool = Tool(
    name="VectorSearch",
    func=search_vector_store,
    description="Searches the vector store for relevant information about diseases."
)

# Create an agent
tools = [vector_search_tool, tavily_tool]

In [17]:
# prompt = PromptTemplate.from_template(
#     """You are a medical information retrieval agent. Your task is to find information about diseases and their ICD codes.
    
#     Tools you can use:
#     {tools}
    
#     Tool name: {tool_names}
    
#     Human: {human_input}
#     Agent: Let's approach this step-by-step:
#     1) First, I'll search for the disease in our vector store.
#     2) Then, I'll use the web search tool to find the ICD code for the disease.
#     3) Finally, I'll compile the information and return it.
    
#     Agent Scratchpad:
#     {agent_scratchpad}
#     """
# )

prompt = PromptTemplate.from_template(
    """You are a medical information retrieval agent. Your task is to find information about diseases and their ICD codes.
    
    Tools you can use:
    {tools}
    
    Tool name: {tool_names}
    
    Human: {human_input}
    Agent: Let's approach this step-by-step:
    1) First, I'll search for the disease in our vector store.
    2) Then, I'll use the web search tool to find the ICD code for the disease.
    3) Finally, I'll compile the information and return it in the following JSON format:
    
    {{
        "disease": "<disease_name>",
        "description": "<description>",
        "icd_codes": ["<code1>", "<code2>", ...]
    }}
    
    Agent Scratchpad:
    {agent_scratchpad}
    """
)

agent = create_react_agent(llm, tools, prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, handle_parsing_errors=True)

In [18]:
# Define the state
class AgentState(TypedDict):
    human_input: str
    messages: Sequence[HumanMessage]
    results: dict

# Define the nodes
def agent_node(state: AgentState):
    # Invoke the agent and capture the raw response
    raw_result = agent_executor.invoke(state, handle_parsing_errors=True)
    
    # Log the raw response for debugging
    print("Raw LLM Response:", raw_result)
    
    # Process the response
    return {
        "messages": raw_result["intermediate_steps"],
        "results": raw_result["output"]
    }

# Create the graph
workflow = StateGraph(AgentState)

workflow.add_node("agent", agent_node)

workflow.set_entry_point("agent")
workflow.set_finish_point("agent")

graph = workflow.compile()


In [19]:
# Run the graph
def run_graph(query: str):
    inputs = {
        "human_input": query,
        "messages": [],
        "results": {}
    }
    result = graph.invoke(inputs)
    return result["results"]

# # Generate output in JSON and CSV formats
# def generate_output(query: str, result: str):
#     # Parse the result
#     lines = result.split('\n')
#     document = lines[0].split(': ')[1]
#     disease = lines[1].split(': ')[1]
#     icd_code = lines[2].split(': ')[1]
    
#     # Generate JSON output
#     json_output = {
#         "document": {
#             "disease": icd_code
#         }
#     }
    
#     # Generate CSV output
#     csv_output = [["Document name", "Disease name", "ICD code"]]
#     csv_output.append([document, disease, icd_code])
    
#     return json_output, csv_output

def generate_output(query: str, result: str):
    try:
        # Parse the result as JSON
        output = json.loads(result)
        
        # Extract relevant fields
        disease = output.get("disease", "Unknown")
        description = output.get("description", "No description available")
        icd_codes = ", ".join(output.get("icd_codes", []))
        
        # Generate JSON output
        json_output = {
            "disease": disease,
            "description": description,
            "icd_codes": icd_codes
        }
        
        # Generate CSV output
        csv_output = [["Disease", "Description", "ICD Codes"]]
        csv_output.append([disease, description, icd_codes])
        
        return json_output, csv_output
    except json.JSONDecodeError:
        raise ValueError("The agent's output is not in the expected JSON format.")

In [21]:
# Main function
def main(query: str):
    result = run_graph(query)
    json_output, csv_output = generate_output(query, result)
    
    # Save JSON output
    with open('output.json', 'w') as f:
        json.dump(json_output, f, indent=2)
    
    # Save CSV output
    with open('output.csv', 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerows(csv_output)
    
    print("JSON output:", json_output)
    print("CSV output has been saved to output.csv")

# Example usage
main("Malaria")
# TODO only search vector_search
# res = search_vector_store("How many types of Cancer are there?")
# print(res)

Raw LLM Response: {'human_input': 'Malaria', 'messages': [], 'results': {}, 'output': 'Agent stopped due to iteration limit or time limit.'}


KeyError: 'intermediate_steps'