In [None]:
import operator
import warnings
from typing import Annotated, TypedDict

import torch
from dotenv import load_dotenv
from IPython.display import Image
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_openai import ChatOpenAI
from transformers import logging

from medixar.agent import *
from medixar.tools import *

# Quiet the notebook output: drop every Python warning, and restrict the
# Hugging Face `transformers` logger (imported as `logging` above, not the
# stdlib module) to error-level messages only.
warnings.filterwarnings(action="ignore")
logging.set_verbosity_error()

# Load variables from a local .env file into the process environment
# (presumably the credentials ChatOpenAI needs later — TODO confirm).
# Assigning the boolean result keeps the cell from echoing it.
_ = load_dotenv()

In [None]:
report_tool = RadiologyReportGeneratorTool()
organ_size_tool = OrganSizeMeasurementTool()

# Sanity-check both tools: show each one's class followed by the `name`
# attribute it exposes, in the same order they were created above.
for _tool in (report_tool, organ_size_tool):
    print(type(_tool), _tool.name, sep="\n")

In [None]:
# System prompt handed to the agent. The first four sentences form a single
# line; the formatting instruction sits on its own line, and the whole
# prompt ends with a trailing newline.
prompt = (
    "You are a smart medical assistant. Use the tools available to answer questions. "
    "You are allowed to make multiple calls (either together or in sequence). "
    "Only look up information when you are sure of what you want. "
    "If you need to look up some information before asking a follow up question, you are allowed to do that!\n"
    "Do not use formatting in your response.\n"
)

# Build the agent: the OpenAI chat model, the two medical tools, the system
# prompt, and an in-memory checkpointer. Later cells stream on a
# `thread_id`, which presumably relies on this checkpointer to retain
# conversation history between calls — TODO confirm in medixar.agent.Agent.
checkpointer = MemorySaver()
agent_tools = [report_tool, organ_size_tool]
model = ChatOpenAI(model="gpt-4o-2024-08-06")
agent = Agent(
    model,
    agent_tools,
    system=prompt,
    checkpointer=checkpointer,
)

# Render the agent's graph as a PNG (langchain_core's draw_png relies on the
# optional pygraphviz package) and display it as the cell's output.
graph_png = agent.graph.get_graph().draw_png()
Image(graph_png)

In [None]:
# Conversation thread "1" is shared by the following cells. This first turn
# asks for the report of `image.png` and prints, for every graph node update,
# the full list of messages that node emitted.
thread = {"configurable": {"thread_id": "1"}}
messages = [HumanMessage(content="What is the report of the radiology image `image.png`")]

for step in agent.graph.stream({"messages": messages}, thread):
    for node_output in step.values():
        print(node_output["messages"])

In [None]:
# Follow-up on the same thread: ask for the heart volume, printing only the
# content of the most recent message in each node update.
messages = [HumanMessage(content="What is the volume of the heart (mL)?")]

for step in agent.graph.stream({"messages": messages}, thread):
    for node_output in step.values():
        latest = node_output["messages"][-1]
        print(latest.content)

In [None]:
# Ask the agent to recall an earlier question on the same thread, as a check
# of conversation memory; print the latest message content per node update.
messages = [HumanMessage(content="What was the question I asked you?")]

for step in agent.graph.stream({"messages": messages}, thread):
    for node_output in step.values():
        latest = node_output["messages"][-1]
        print(latest.content)

In [None]:
############################################################# LEGACY CODE #############################################################
# web_search_tool = WebSearchTool()

# agent = Agent(
#     model="meta-llama/Llama-3.2-1B-Instruct",
#     tools={"web_search": web_search_tool},
#     tools_json_path="medixar/docs/tools.json",
#     system_prompts_file="medixar/docs/system_prompts.txt",
#     system_prompt_type="MEDICAL_ASSISTANT",
#     device="auto",
#     torch_dtype=torch.float16,
#     max_new_tokens=250,
#     temperature=0.7,
#     top_p=0.95
# )

# response = agent.generate("What is all you need? Respond using web search tool.")
# agent.messages
# agent.generate("what was the answer to what is all you need?")