In [None]:
import operator
import warnings
from typing import Annotated, TypedDict

import torch
from dotenv import load_dotenv
from IPython.display import Image
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_openai import ChatOpenAI
from transformers import logging

from medixar.agent import *
from medixar.tools import *
from medixar.utils import *

warnings.filterwarnings("ignore")
logging.set_verbosity_error()
_ = load_dotenv()

PROMPT_FILE = "medixar/docs/system_prompts.txt"

"""
1. Change the LLM to public LLM
2. Add other tools + integrate tool baackend
4. Choose agent name
5. Explore code generation
6. SFT or example database, prompt enhancing to make tool call better
"""

In [None]:
report_tool = RadiologyReportGeneratorTool()
organ_size_tool = OrganSizeMeasurementTool()
chest_xray_tool = ChestXRayClassifierTool()

print(f"Type: {type(report_tool)}, Name: {report_tool.name}")
print(f"Type: {type(organ_size_tool)}, Name: {organ_size_tool.name}")
print(f"Type: {type(chest_xray_tool)}, Name: {chest_xray_tool.name}")

In [None]:
prompts = load_prompts_from_file(PROMPT_FILE)
prompt = prompts["MEDICAL_ASSISTANT"]
print(prompt)

In [None]:
checkpointer = MemorySaver()
model = ChatOpenAI(model="gpt-4o-mini", temperature=0, top_p=0.95)
agent = Agent(model, [report_tool, organ_size_tool, chest_xray_tool], system=prompt, checkpointer=checkpointer)
thread = {"configurable": {"thread_id": "1"}}

Image(agent.graph.get_graph().draw_png())

In [6]:
# messages = [HumanMessage(content="Does the image `demo/chest/pneumonia4.jpg` show pneumonia?")]
messages = [HumanMessage(content="Describe the radiology image `demo/chest/pneumonia4.jpg` in detail.")]
# messages = [HumanMessage(content="What is the volume of the heart (mL)?")]
# messages = [HumanMessage(content="What was the first question I asked?")]
# messages = [HumanMessage(content="What is the name and size of first organ mention in radiology report of `image.png`")]

In [None]:
async for event in agent.graph.astream_events({"messages": messages}, thread, version="v1"):
        kind = event["event"]
        if kind == "on_chat_model_stream":
            content = event["data"]["chunk"].content
            if content:
                print(content, end="")

In [None]:
import torch
from PIL import Image
import requests
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, process_images
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN

disable_torch_init()

model_path = "microsoft/llava-med-v1.5-mistral-7b"
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_path)

In [3]:
image = Image.open("demo/chest/normal1.jpg")
question = "Describe the image in detail."

if model.config.mm_use_im_start_end:
    question = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + question
else:
    question = DEFAULT_IMAGE_TOKEN + '\n' + question

conv = conv_templates["vicuna_v1"].copy()
conv.append_message(conv.roles[0], question)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

image_tensor = process_images([image], image_processor, model.config)[0]

with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=image_tensor.unsqueeze(0).half().cuda(),
        do_sample=True,
        temperature=0.2,
        max_new_tokens=1000,
        use_cache=True,
    )

output = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
print(output)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


The image is a chest X-ray, which is a common diagnostic imaging technique used to visualize the structures within the chest, including the lungs, heart, and bones of the chest and spine. In this particular image, there are no visible abnormalities, which means that the structures within the chest appear to be normal.
