In [None]:
import operator
import warnings
from typing import Annotated, TypedDict

import torch
from dotenv import load_dotenv
from IPython.display import Image
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_openai import ChatOpenAI
from transformers import logging

from medmax.agent import *
from medmax.tools import *
from medmax.utils import *

# Session-wide setup for the notebook.
# Ignore every Python warning from here on so that third-party warnings do not
# clutter the cell outputs.
warnings.filterwarnings("ignore")
# NOTE: `logging` is `transformers.logging` (see the imports above), not the
# stdlib module; this limits Transformers log output to errors.
logging.set_verbosity_error()
# Load variables from a `.env` file into the environment; the return value is
# bound to `_` and never read.
# NOTE(review): presumably this provides the API key(s) used by ChatOpenAI — confirm.
_ = load_dotenv()

# Path of the text file containing the system prompts; a later cell passes it
# to `load_prompts_from_file`.
PROMPT_FILE = "medmax/docs/system_prompts.txt"

"""
1. Change the LLM to a public LLM
3. Explore code generation
4. SFT or an example database, plus prompt enhancement, to improve tool calling
5. Add a segmentation tool
6. How to incorporate new tools easily?
"""

In [None]:
# Instantiate the three medical-imaging tools that are handed to the agent in a
# later cell.
report_tool = ChestXRayReportGeneratorTool()
xray_classification_tool = ChestXRayClassifierTool()
medical_visual_qa_tool = MedicalVisualQATool()

# Sanity check: print each tool's concrete class and its `name` attribute,
# in the same order as the tools were created above.
for loaded_tool in (report_tool, xray_classification_tool, medical_visual_qa_tool):
    print(f"Type: {type(loaded_tool)}, Name: {loaded_tool.name}")

In [None]:
# Read the prompt collection from PROMPT_FILE and pick the entry stored under
# the "MEDICAL_ASSISTANT" key; `prompt` is later passed to the agent as its
# system prompt.
# NOTE(review): `load_prompts_from_file` arrives via a medmax star-import;
# presumably it returns a dict keyed by prompt name — confirm.
prompts = load_prompts_from_file(PROMPT_FILE)
prompt = prompts["MEDICAL_ASSISTANT"]
# Echo the selected prompt so the exact instructions are visible in the notebook.
print(prompt)

In [146]:
# Run configuration passed alongside every agent invocation.
# NOTE(review): presumably the checkpointer keys saved conversation state by
# this thread id, so reusing "1" continues the same conversation — confirm.
thread = {"configurable": {"thread_id": "1"}}

# Chat model that drives the agent.
model = ChatOpenAI(model="gpt-4o-2024-08-06", temperature=0, top_p=0.95)

# Checkpointer from `langgraph.checkpoint.memory`, handed to the agent below.
checkpointer = MemorySaver()

# Assemble the agent from the model, the three tools, the system prompt and the
# checkpointer; tool logging is switched on and pointed at the "logs" directory.
agent = Agent(
    model,
    tools=[xray_classification_tool, medical_visual_qa_tool, report_tool],
    log_tools=True,
    log_dir="logs",
    system_prompt=prompt,
    checkpointer=checkpointer,
)

# Optional: render the agent graph inline.
# NOTE(review): this refers to `agent.graph`, while the streaming cell uses
# `agent.workflow` — check which attribute the Agent class exposes before
# uncommenting.
# Image(agent.graph.get_graph().draw_png())

In [147]:
# Query sent to the agent by the streaming cell below. Keep exactly one
# `messages` assignment active; the commented-out lines are alternative
# queries to swap in (image questions, follow-ups without an image path, and a
# question about the conversation history).
# messages = [HumanMessage(content="Does the image `demo/chest/pneumonia4.jpg` show pneumonia?")]
messages = [HumanMessage(content="Provide a radiology report for the given image `demo/chest/pneumonia2.jpg`.")]
# messages = [HumanMessage(content="Describe the given image.")]
# messages = [HumanMessage(content="What is the volume of the heart (mL)?")]
# messages = [HumanMessage(content="What was the first question I asked?")]
# messages = [HumanMessage(content="What is the probability of pneumonia in the image `demo/chest/normal1.jpg`?")]
# messages = [HumanMessage(content="Does the patient with chest xray given here need to go see doctor? `demo/chest/normal1.jpg`")]

In [None]:
# Send the selected query to the agent under the configured thread and print
# every value carried by each streamed event as it arrives.
agent_input = {"messages": messages}
for step in agent.workflow.stream(agent_input, thread):
    for node_output in step.values():
        print(node_output)