In [None]:
import operator
import warnings
from typing import Annotated, TypedDict

import torch
from dotenv import load_dotenv
from IPython.display import Image
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_openai import ChatOpenAI
from transformers import logging

from medixar.agent import *
from medixar.tools import *
from medixar.utils import *

# Silence Python warnings and `transformers` logging below ERROR so cell output
# stays readable, then load environment variables from a local .env file
# (presumably the API key(s) the chat model needs — confirm which keys are expected).
warnings.filterwarnings("ignore")
logging.set_verbosity_error()
_ = load_dotenv()

# Text file of named system prompts, read later with load_prompts_from_file.
PROMPT_FILE = "medixar/docs/system_prompts.txt"

# TODO (kept as comments rather than a bare string: a string as the last
# expression of a cell is echoed by IPython as cell output):
# 1. Change the LLM to a public LLM.
# 2. Add other tools and integrate the tool backend.
# 3. Choose an agent name.
# 4. Explore code generation.
# 5. Use SFT or an example database, plus prompt enhancement, to make tool calls better.

In [None]:
# Instantiate the three tools handed to the agent, then show each one's class
# and the `name` attribute it exposes.
report_tool = RadiologyReportGeneratorTool()
organ_size_tool = OrganSizeMeasurementTool()
chest_xray_tool = ChestXRayClassifierTool()

for medixar_tool in (report_tool, organ_size_tool, chest_xray_tool):
    print(f"Type: {type(medixar_tool)}, Name: {medixar_tool.name}")

In [None]:
# Read every named system prompt from PROMPT_FILE and select the medical-assistant
# one; it is passed to the agent as its system prompt.
prompt_name = "MEDICAL_ASSISTANT"
prompts = load_prompts_from_file(PROMPT_FILE)
prompt = prompts[prompt_name]
print(prompt)

In [None]:
# Assemble the agent from the pinned GPT-4o snapshot, the tools created earlier
# and the selected system prompt. MemorySaver is LangGraph's in-memory
# checkpointer; combined with the fixed thread_id in `thread`, it lets the
# following cells continue a single conversation.
agent_tools = [report_tool, organ_size_tool, chest_xray_tool]
checkpointer = MemorySaver()
model = ChatOpenAI(model="gpt-4o-2024-08-06", temperature=0, top_p=0.95)
agent = Agent(model, agent_tools, system=prompt, checkpointer=checkpointer)
thread = {"configurable": {"thread_id": "1"}}

# Render the compiled graph as a PNG (LangGraph's draw_png needs pygraphviz installed).
Image(agent.graph.get_graph().draw_png())

In [None]:
# Turn 1: ask about a specific chest X-ray and print the 'messages' list of every
# node update as the graph streams.
question = "Does the image `demo/chest/pneumonia4.jpg` show pneumonia?"
messages = [HumanMessage(content=question)]

updates = (
    update
    for event in agent.graph.stream({"messages": messages}, thread)
    for update in event.values()
)
for update in updates:
    print(update['messages'])

In [None]:
# Turn 2: a follow-up that names no image. It relies on the earlier turn being
# available through the shared `thread` config and checkpointer.
messages = [HumanMessage(content="Describe the radiology image in detail.")]
graph_input = {"messages": messages}

for step in agent.graph.stream(graph_input, thread):
    for node_output in step.values():
        print(node_output['messages'])

In [None]:
# Turn 3: organ-volume question. Only the newest message of each node update
# is printed, which keeps the output short.
messages = [HumanMessage(content="What is the volume of the heart (mL)?")]

for chunk in agent.graph.stream({"messages": messages}, thread):
    for state_delta in chunk.values():
        latest = state_delta['messages'][-1]
        print(latest.content)

In [None]:
# Memory check: the question can only be answered from the conversation history
# kept for this thread, so it exercises the checkpointer.
messages = [HumanMessage(content="What was the first question I asked?")]
response_stream = agent.graph.stream({"messages": messages}, thread)

for event in response_stream:
    for update in event.values():
        print(update['messages'][-1].content)

In [None]:
# Nested tool calls: presumably the agent must first generate the report, then
# measure the organ it mentions, chaining two tools in one turn.
request = HumanMessage(content="What is the name and size of first organ mention in radiology report of `image.png`")
messages = [request]

for event in agent.graph.stream({"messages": messages}, thread):
    for payload in event.values():
        print(payload['messages'])

In [None]:
############################################################# LEGACY CODE #############################################################
# NOTE(review): every line below is commented out, so running this cell does nothing.
# It records a different Agent call shape than the one used to build `agent` earlier:
# a model-name string, a `tools` dict, a tools JSON path, prompt-file options and
# generation settings. It is presumably obsolete — confirm, then delete this cell.
# `import torch` at the top of the notebook is referenced only here (torch.float16).
# web_search_tool = WebSearchTool()

# agent = Agent(
#     model="meta-llama/Llama-3.2-1B-Instruct",
#     tools={"web_search": web_search_tool},
#     tools_json_path="medixar/docs/tools.json",
#     system_prompts_file="medixar/docs/system_prompts.txt",
#     system_prompt_type="MEDICAL_ASSISTANT",
#     device="auto",
#     torch_dtype=torch.float16,
#     max_new_tokens=250,
#     temperature=0.7,
#     top_p=0.95
# )

# response = agent.generate("What is all you need? Respond using web search tool.")
# agent.messages
# agent.generate("what was the answer to what is all you need?")