In [1]:
import operator
import warnings
from typing import *
import traceback

import torch
from dotenv import load_dotenv
from IPython.display import Image
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_openai import ChatOpenAI
from transformers import logging
import matplotlib.pyplot as plt
import numpy as np

from medmax.agent import *
from medmax.tools import *
from medmax.utils import *

# Notebook-wide setup: quiet noisy output, load environment variables, define paths.
# Suppress every Python warning for the rest of the session.
warnings.filterwarnings("ignore")
# `logging` is the name imported from `transformers` above; only errors are reported.
# NOTE(review): a later `from medmax... import *` could rebind `logging` — confirm.
logging.set_verbosity_error()
# Result of load_dotenv() is discarded; presumably this loads API keys from a .env file — confirm.
_ = load_dotenv()

# File passed to load_prompts_from_file() when the agent is built.
PROMPT_FILE = "medmax/docs/system_prompts.txt"
# Used below as the grounding tool's cache_dir and as the parent of the "roentgen" model folder.
# NOTE(review): user-specific absolute path; the notebook only runs unchanged on that machine.
SCRATCH_DIR = "/scratch/ssd004/scratch/afallah"

# Project TODO notes kept as a bare string literal. Because it is the last
# expression in the cell, the notebook echoes its repr as the cell output.
"""
1. SFT or example database, prompt enhancing to make tool call better
2. show segmentation result overlay on the image
3. use ppx instead of cm2, support dicom
4. look agent benchmark + build benchmark
"""

'\n1. SFT or example database, prompt enhancing to make tool call better\n2. show segmentation result overlay on the image\n3. use ppx instead of cm2, support dicom\n4. look agent benchmark + build benchmark\n'

In [2]:
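# One-time environment setup. Run manually when needed; `%pip` installs into this kernel's environment.
# %pip install -U git+https://github.com/huggingface/transformers.git@88d960937c81a32bfb63356a2e8ecf7999619681
# DICOM decompression plugins for pydicom (GDCM is published on PyPI as `python-gdcm`;
# pylibjpeg additionally needs `pylibjpeg-libjpeg` for JPEG Lossless data):
# %pip install -U pydicom python-gdcm pylibjpeg pylibjpeg-libjpeg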

In [2]:
# Build every tool once, up front. The captured cell output shows checkpoint /
# pipeline loading at this point, so expect this cell to be slow.
report_tool = ChestXRayReportGeneratorTool()
xray_classification_tool = ChestXRayClassifierTool()
# medical_visual_qa_tool = MedicalVisualQATool()  # currently disabled
segmentation_tool = ChestXRaySegmentationTool()
image_visualizer_tool = ImageVisualizerTool()
grounding_tool = XRayPhraseGroundingTool(cache_dir=SCRATCH_DIR, temp_dir="temp")
generation_tool = ChestXRayGeneratorTool(
    model_path=f"{SCRATCH_DIR}/roentgen",
    temp_dir="temp",
)
dicom_tool = DicomProcessorTool(temp_dir="temp")

# One "Type: ..., Name: ..." line per instantiated tool, in creation order.
# The generator variable stays local, so no extra name leaks into the kernel.
print(
    "\n".join(
        f"Type: {type(loaded)}, Name: {loaded.name}"
        for loaded in (
            report_tool,
            xray_classification_tool,
            # medical_visual_qa_tool,
            segmentation_tool,
            image_visualizer_tool,
            grounding_tool,
            generation_tool,
            dicom_tool,
        )
    )
)

Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

An error occurred while trying to fetch /scratch/ssd004/scratch/afallah/roentgen/unet: Error no file named diffusion_pytorch_model.safetensors found in directory /scratch/ssd004/scratch/afallah/roentgen/unet.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
An error occurred while trying to fetch /scratch/ssd004/scratch/afallah/roentgen/vae: Error no file named diffusion_pytorch_model.safetensors found in directory /scratch/ssd004/scratch/afallah/roentgen/vae.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.


Type: <class 'medmax.tools.report_generation.ChestXRayReportGeneratorTool'>, Name: chest_xray_report_generator
Type: <class 'medmax.tools.classification.ChestXRayClassifierTool'>, Name: chest_xray_classifier
Type: <class 'medmax.tools.segmentation.ChestXRaySegmentationTool'>, Name: chest_xray_segmentation
Type: <class 'medmax.tools.utils.ImageVisualizerTool'>, Name: image_visualizer
Type: <class 'medmax.tools.grounding.XRayPhraseGroundingTool'>, Name: xray_phrase_grounding
Type: <class 'medmax.tools.generation.ChestXRayGeneratorTool'>, Name: chest_xray_generator
Type: <class 'medmax.tools.dicom.DicomProcessorTool'>, Name: dicom_processor


In [3]:
# System prompt: looked up by key in the prompts file.
prompts = load_prompts_from_file(PROMPT_FILE)
prompt = prompts["MEDICAL_ASSISTANT"]

# Checkpointer handed to the agent, plus the chat model and its sampling settings.
checkpointer = MemorySaver()
model_settings = {
    "model": "gpt-4o-2024-08-06",
    "temperature": 0,
    "top_p": 0.95,
}
model = ChatOpenAI(**model_settings)

# Tools exposed to the agent. `dicom_tool` is deliberately not part of this list
# in the original cell, so it is left out here as well.
agent_tools = [
    xray_classification_tool,
    report_tool,
    segmentation_tool,
    image_visualizer_tool,
    grounding_tool,
    generation_tool,
]

agent = Agent(
    model,
    tools=agent_tools,
    log_tools=True,
    log_dir="logs",
    system_prompt=prompt,
    checkpointer=checkpointer,
)

# Config passed as the second argument to `agent.workflow.stream(...)` below.
thread = {
    "configurable": {
        "thread_id": "1",
    },
}

# Image(agent.graph.get_graph().draw_png())

In [6]:
# Other example queries; paste one into `query` to try it:
#   "Does the image `demo/chest/pneumonia4.jpg` show pneumonia?"
#   "Provide a radiology report for the given image `demo/chest/pneumonia2.jpg`."
#   "Describe the given image."
#   "What was the first question I asked?"
#   "What is the probability of pneumonia in the image `demo/chest/normal1.jpg`?"
#   "Does the patient with chest xray given here need to go see doctor? `demo/chest/normal1.jpg`"
#   "What is the size of heart?"
query = "Ground pleural effusion in the image `demo/chest/effusion1.png`."
messages = [HumanMessage(content=query)]

In [None]:
# Stream the workflow for the prepared messages and print each value of every
# streamed event, followed by a blank line.
agent_inputs = {"messages": messages}
for event in agent.workflow.stream(agent_inputs, thread):
    for node_output in event.values():
        print(node_output, '\n')

In [33]:
# Manual check of the DICOM tool on a previously uploaded file. `_run` returns a
# pair; only the first element is printed.
dicom_path = "temp/upload_1732291135.dcm"
output, _unused = dicom_tool._run(dicom_path)
print(output)

{'error': "Unable to decompress 'JPEG Lossless, Non-Hierarchical, First-Order Prediction (Process 14 [Selection Value 1])' pixel data because all plugins are missing dependencies:\n\tgdcm - requires gdcm>=3.0.10\n\tpylibjpeg - requires pylibjpeg>=2.0 and pylibjpeg-libjpeg>=2.1"}


In [35]:
import shutil
import time
from pathlib import Path
from typing import AsyncIterator, List, Optional, Tuple

import gradio as gr
from gradio import ChatMessage

class ChatInterface:
    """Bridge between the Gradio UI and the agent's streaming workflow.

    Per-session state kept on the instance:
      * ``current_thread_id``     - thread id passed to the agent's workflow config.
      * ``uploaded_image``        - path appended to every prompt sent to the agent.
      * ``current_display_image`` - path of the image currently shown in the UI.
    """

    def __init__(self, agent):
        """Store the agent, create the upload directory and reset session state.

        Args:
            agent: Object whose ``workflow.stream(inputs, config)`` yields events.
        """
        self.agent = agent
        self.upload_dir = Path("temp")
        self.upload_dir.mkdir(exist_ok=True)
        self.current_thread_id = None
        self.uploaded_image = None
        self.current_display_image = None
        self.dicom_tool = DicomProcessorTool(temp_dir="temp")

    @staticmethod
    def _parse_tool_result(content):
        """Return the first element of a tool message payload, or the raw content.

        The payload is the string form of the tool's return value, so it is
        evaluated back into a Python object. If that fails (or the result
        cannot be indexed) the raw content is returned so one odd tool result
        does not abort the whole response.
        """
        try:
            # SECURITY NOTE(review): eval() executes arbitrary expressions. This is
            # only acceptable while tool output is fully trusted; prefer a
            # structured format (e.g. JSON) if tools can ever echo external text.
            return eval(content)[0]
        except Exception:
            return content

    def handle_upload(self, file_path: str) -> Tuple[Optional[str], Optional[str]]:
        """Copy an uploaded file into the upload directory.

        DICOM files (``.dcm``) are additionally converted through the DICOM tool
        so the UI has something it can display.

        Args:
            file_path: Path of the uploaded file; empty/None means "nothing uploaded".

        Returns:
            ``(display_path, prompt_path)``. Both are None when no file was given.
            ``display_path`` is None when a DICOM file could not be converted;
            ``prompt_path`` still points at the stored copy in that case.
        """
        if not file_path:
            return None, None

        source = Path(file_path)
        timestamp = int(time.time())
        prompt_path = self.upload_dir / f"upload_{timestamp}{source.suffix}"
        # Two uploads within the same second used to map to the same name, which
        # overwrote the earlier file (or raised SameFileError when re-uploading a
        # stored file). Add a counter until the name is free.
        attempt = 1
        while prompt_path.exists():
            prompt_path = self.upload_dir / f"upload_{timestamp}_{attempt}{source.suffix}"
            attempt += 1
        shutil.copy2(source, prompt_path)

        # If DICOM file, convert to PNG for display
        if source.suffix.lower() == '.dcm':
            print(f"Processing DICOM file: {prompt_path}")
            output, _ = self.dicom_tool._run(str(prompt_path))
            display_path = output.get('image_path') if isinstance(output, dict) else None
            if not display_path:
                # The tool reports failures as {'error': ...}; indexing
                # 'image_path' directly raised KeyError here.
                reason = output.get('error', output) if isinstance(output, dict) else output
                print(f"DICOM conversion failed for {prompt_path}: {reason}")
                return None, str(prompt_path)
            return display_path, str(prompt_path)

        return str(prompt_path), str(prompt_path)

    # Annotations mentioning ChatMessage are quoted so they are not evaluated at
    # definition time.
    async def process_message(self,
                              message: str,
                              image: Optional[str],
                              chat_history: "List[ChatMessage]",
                              ) -> "AsyncIterator[Tuple[List[ChatMessage], Optional[str]]]":
        """Send one user message to the agent and stream the conversation back.

        Args:
            message: Text typed by the user.
            image: Path currently held by the UI image component, if any.
            chat_history: Messages shown so far (None is treated as empty).

        Yields:
            ``(chat_history, current_display_image)`` after the user message is
            added and after every assistant or tool message that has content.
            Exceptions raised while streaming are reported as a chat message.

        NOTE(review): the agent is consumed with a synchronous ``stream`` call
        inside this coroutine, so the event loop is blocked while it runs.
        """
        chat_history = chat_history or []

        # Handle new image upload: anything different from what is displayed.
        if image and image != self.current_display_image:
            display_path, prompt_path = self.handle_upload(image)
            self.current_display_image = display_path
            self.uploaded_image = prompt_path

        # Append image path to message if exists
        if self.uploaded_image:
            message = f"{message} `{self.uploaded_image}`"

        # Initialize thread if needed
        if not self.current_thread_id:
            self.current_thread_id = str(time.time())

        # Add user message to history
        chat_history.append(ChatMessage(role="user", content=message))
        yield chat_history, self.current_display_image

        try:
            for event in self.agent.workflow.stream(
                {"messages": [{"role": "user", "content": message}]},
                {"configurable": {"thread_id": self.current_thread_id}}
            ):
                if not isinstance(event, dict):
                    continue

                if 'process' in event:
                    content = event['process']['messages'][-1].content
                    if content:
                        chat_history.append(ChatMessage(
                            role="assistant",
                            content=content
                        ))
                        yield chat_history, self.current_display_image

                elif 'execute' in event:
                    # Named `tool_message` so the `message` parameter is not shadowed.
                    for tool_message in event['execute']['messages']:
                        tool_name = tool_message.name
                        tool_result = self._parse_tool_result(tool_message.content)

                        if tool_name == "image_visualizer":
                            # Assumes the agent attaches the tool call's arguments to
                            # the message as `.args` — TODO confirm.
                            self.current_display_image = tool_message.args['image_path']

                        if tool_result:
                            # Collapse multi-line results onto a single line.
                            formatted_result = ' '.join(
                                line.strip() for line in str(tool_result).splitlines()
                            ).strip()
                            chat_history.append(ChatMessage(
                                role="assistant",
                                content=formatted_result,
                                metadata={"title": f"🔧 Using tool: {tool_name}"},
                            ))
                            yield chat_history, self.current_display_image

        except Exception as e:
            chat_history.append(ChatMessage(
                role="assistant",
                content=f"❌ Error: {str(e)}",
                metadata={"title": "Error"}
            ))
            yield chat_history, self.current_display_image

def create_demo(agent):
    """Assemble the Gradio Blocks application for chatting with `agent`.

    Layout: a chat panel with a text box on the left; an image viewer with the
    upload buttons and the "Clear Chat" / "New Thread" buttons on the right.

    Args:
        agent: Agent instance handed to ChatInterface.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` object.
    """
    chat_ui = ChatInterface(agent)
    bot_avatar = "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png"
    header_markdown = """
            # 🏥 MedMAX
            Multimodal Medical Agent for Chest X-rays
            """

    with gr.Blocks(theme=gr.themes.Soft()) as app:
        with gr.Column():
            gr.Markdown(header_markdown)

            with gr.Row():
                # Left column: conversation and prompt entry.
                with gr.Column(scale=3):
                    chat_panel = gr.Chatbot(
                        [],
                        label="Agent",
                        type="messages",
                        height=800,
                        container=True,
                        show_label=True,
                        elem_classes="chat-box",
                        avatar_images=(None, bot_avatar),
                    )
                    with gr.Row():
                        with gr.Column(scale=3):
                            prompt_box = gr.Textbox(
                                placeholder="Ask about the X-ray...",
                                show_label=False,
                                container=False,
                            )

                # Right column: image viewer and controls.
                with gr.Column(scale=3):
                    viewer = gr.Image(
                        type="filepath",
                        label="Image",
                        height=700,
                        container=True,
                    )
                    with gr.Row():
                        image_button = gr.UploadButton(
                            "📎 Upload X-Ray",
                            file_types=["image"],
                        )
                        dicom_button = gr.UploadButton(
                            "📄 Upload DICOM",
                            file_types=["file"],
                        )
                    with gr.Row():
                        clear_button = gr.Button("Clear Chat")
                        thread_button = gr.Button("New Thread")

        # Event handlers
        def clear_chat():
            """Forget the current image and empty the chat and the viewer."""
            chat_ui.uploaded_image = chat_ui.current_display_image = None
            return [], None

        def new_thread():
            """Start a fresh thread id, then reset exactly like clear_chat."""
            chat_ui.current_thread_id = str(time.time())
            return clear_chat()

        def handle_file_upload(file):
            """Store the uploaded file and return the path to show in the viewer."""
            shown_path, _prompt_path = chat_ui.handle_upload(file.name)
            return shown_path

        # Submit: stream the agent's answer, then clear the text box.
        submit_event = prompt_box.submit(
            chat_ui.process_message,
            inputs=[prompt_box, viewer, chat_panel],
            outputs=[chat_panel, viewer],
        )
        submit_event.then(lambda: "", None, [prompt_box])

        # Both upload buttons share the same handler and target the viewer.
        for button in (image_button, dicom_button):
            button.upload(
                handle_file_upload,
                inputs=button,
                outputs=viewer,
            )

        reset_targets = [chat_panel, viewer]
        clear_button.click(clear_chat, outputs=reset_targets)
        thread_button.click(new_thread, outputs=reset_targets)

    return app

# Build the Gradio app around the agent created earlier in the notebook.
demo = create_demo(agent)
# NOTE(review): share=True also publishes the app on a public gradio.live URL
# (see the cell output) — confirm that exposing this app publicly is intended.
demo.launch(share=True)

* Running on local URL:  http://127.0.0.1:7867
* Running on public URL: https://4d8bea6d35b828e066.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)




Processing DICOM file: temp/upload_1732291334.dcm
