In [7]:
import operator
import warnings
from typing import Annotated, TypedDict

import torch
from dotenv import load_dotenv
from IPython.display import Image
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_openai import ChatOpenAI
from transformers import logging

from medmax.agent import *
from medmax.tools import *
from medmax.utils import *

# Silence all Python warnings and reduce transformers logging to errors only,
# so that library noise does not flood the cell outputs below.
warnings.filterwarnings("ignore")
logging.set_verbosity_error()
# Load environment variables via python-dotenv; the return value is discarded.
_ = load_dotenv()

# Path of the text file from which the named system prompts are loaded.
PROMPT_FILE = "medmax/docs/system_prompts.txt"

# The string below is a TODO list kept as a bare expression; because it is the
# last expression in the cell, the notebook echoes it as the cell's output.
# NOTE(review): the numbering skips item 2 -- confirm whether an item was dropped.
"""
1. Change the LLM to public LLM
3. Explore code generation
4. SFT or example database, prompt enhancing to make tool call better
5. Add segmentation tool, radiology report + get the tool call json
6. How to incorporate new tools easily?
"""

'\n1. Change the LLM to public LLM\n3. Explore code generation\n4. SFT or example database, prompt enhancing to make tool call better\n5. Add segmentation tool, radiology report + get the tool call json\n6. How to incorporate new tools easily?\n'

In [2]:
# Instantiate each medical tool once. Construction may load model weights
# (a checkpoint-shard progress bar appears when this cell runs).
report_tool = RadiologyReportGeneratorTool()
organ_size_tool = OrganSizeMeasurementTool()
xray_classification_tool = ChestXRayClassifierTool()
medical_visual_qa_tool = MedicalVisualQATool()

# Report the concrete class and the registered name of every tool,
# in the same order in which they were created.
for _tool in (
    report_tool,
    organ_size_tool,
    xray_classification_tool,
    medical_visual_qa_tool,
):
    print(f"Type: {type(_tool)}, Name: {_tool.name}")

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Type: <class 'medmax.tools.tools.RadiologyReportGeneratorTool'>, Name: radiology_report_generator
Type: <class 'medmax.tools.tools.OrganSizeMeasurementTool'>, Name: organ_size_measurement
Type: <class 'medmax.tools.tools.ChestXRayClassifierTool'>, Name: chest_xray_classifier
Type: <class 'medmax.tools.tools.MedicalVisualQATool'>, Name: medical_visual_qa


In [3]:
# Load the named system prompts from PROMPT_FILE and pick the one used by the
# medical assistant agent; print it so the reader can see the exact text.
# NOTE(review): the printed prompt contains the typo "reposense" ("response");
# the fix belongs in the prompt file itself, not in this notebook.
prompts = load_prompts_from_file(PROMPT_FILE)
prompt = prompts["MEDICAL_ASSISTANT"]
print(prompt)

You are a smart medical assistant. Use the tools available to answer questions.
You are allowed to make multiple calls (either together or in sequence).
For best reposense to user, use multiple tools in sequences if needed.
If you need to look up some information before asking a follow up question, you are allowed to do that!


In [10]:
# Checkpointer instance handed to the agent below.
checkpointer = MemorySaver()

# Chat model used by the agent (temperature 0, top_p 0.95).
model = ChatOpenAI(
    model="gpt-4o-2024-08-06",
    temperature=0,
    top_p=0.95,
)

# Build the agent with the classifier and visual-QA tools only.
# NOTE(review): report_tool and organ_size_tool are instantiated earlier but
# are not registered here -- confirm whether that is intentional.
agent = Agent(
    model,
    [xray_classification_tool, medical_visual_qa_tool],
    system_prompt=prompt,
    checkpointer=checkpointer,
)

# Run configuration passed to the agent's stream call; it carries the thread id.
thread = {"configurable": {"thread_id": "1"}}

# NOTE(review): the streaming cell accesses `agent.workflow`, while the line
# below refers to `agent.graph` -- it may be stale; confirm before enabling.
# Image(agent.graph.get_graph().draw_png())

In [11]:
# Candidate user queries; exactly one `messages` assignment should be active.
# The commented-out lines are alternative prompts kept for manual experiments.
# messages = [HumanMessage(content="Does the image `demo/chest/pneumonia4.jpg` show pneumonia?")]
# messages = [HumanMessage(content="Provide a radiology report for the given image `demo/chest/pneumonia2.jpg`.")]
# messages = [HumanMessage(content="Describe the given image.")]
# messages = [HumanMessage(content="What is the volume of the heart (mL)?")]
# messages = [HumanMessage(content="What was the first question I asked?")]
# messages = [HumanMessage(content="What is the name and size of first organ mention in radiology report of `image.png`")]
# messages = [HumanMessage(content="What is the probability of pneumonia in the image?")]
# Active query: a single human message that references a chest X-ray image path.
messages = [HumanMessage(content="Does the patient with chest xray given here need to go see doctor? `demo/chest/normal1.jpg`")]

In [13]:
# Stream the agent run for the prepared messages, using the `thread` config,
# and print each update as it arrives.
for event in agent.workflow.stream({"messages": messages}, thread):
    # Each streamed event is treated as a mapping; only its values are printed.
    for v in event.values():
        print(v)

{'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_gdhXo0OjrXoI4TwxihoV4zZ8', 'function': {'arguments': '{"image_path":"demo/chest/normal1.jpg"}', 'name': 'chest_xray_classifier'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 24, 'prompt_tokens': 364, 'total_tokens': 388, 'completion_tokens_details': {'audio_tokens': None, 'reasoning_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': None, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-2024-08-06', 'system_fingerprint': 'fp_90354628f2', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-94f65e32-cd23-4a1a-836c-87c43f68752a-0', tool_calls=[{'name': 'chest_xray_classifier', 'args': {'image_path': 'demo/chest/normal1.jpg'}, 'id': 'call_gdhXo0OjrXoI4TwxihoV4zZ8', 'type': 'tool_call'}], usage_metadata={'input_tokens': 364, 'output_tokens': 24, 'total_tokens': 388, 'input_token_details': {'cache_read': 0}, 'output_token_details': {'reasoning': 0}})]}
E