In [None]:
import operator
import warnings
from typing import *
import traceback

import torch
from dotenv import load_dotenv
from IPython.display import Image
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_openai import ChatOpenAI
from transformers import logging
import matplotlib.pyplot as plt
import numpy as np

from medmax.agent import *
from medmax.tools import *
from medmax.utils import *

# Keep the demo transcript readable: hide Python warnings and non-error transformers logs.
warnings.filterwarnings("ignore")
logging.set_verbosity_error()
# Load environment variables from a local .env file (presumably OPENAI_API_KEY for ChatOpenAI).
_ = load_dotenv()

PROMPT_FILE = "medmax/docs/system_prompts.txt"

# TODO:
# 0. Fix gradio issues and image visualization tools
# 1. New tool: https://www.nature.com/articles/s41551-024-01246-y
# 2. New tool: https://huggingface.co/microsoft/maira-2
# 3. Change the LLM to a public LLM
# 4. Explore code generation
# 5. SFT or an example database, plus prompt enhancement, to improve tool calling

In [6]:
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Type

import matplotlib.pyplot as plt
import skimage.io
from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field


class ImageVisualizerInput(BaseModel):
    """Input schema for the Image Visualizer Tool.

    Used as ``ImageVisualizerTool.args_schema``, so these field descriptions are what the
    LLM sees when deciding how to call the tool; special values are spelled out here.
    """

    image_path: str = Field(..., description="Path to the image file to display")
    title: Optional[str] = Field(None, description="Optional title to display above the image")
    description: Optional[str] = Field(
        None, description="Optional description to display below the image"
    )
    figsize: Optional[tuple] = Field(
        (10, 10), description="Optional figure size as (width, height) in inches"
    )
    # "rgb" is a sentinel, not a matplotlib colormap: it means "show the image's own colors".
    cmap: Optional[str] = Field(
        "rgb",
        description=(
            "Optional colormap to use for displaying the image. "
            "'rgb' (the default) shows the image with its own colors; any other value "
            "(e.g. 'gray') is used as a matplotlib colormap name and only the first "
            "channel of a multi-channel image is shown."
        ),
    )


class ImageVisualizerTool(BaseTool):
    """Tool for displaying medical images to users with annotations.

    The image is rendered with matplotlib (``plt.show()``). The tool returns an
    ``(output, metadata)`` tuple describing the request and whether it succeeded.
    """

    name: str = "image_visualizer"
    description: str = (
        "Displays images to users with optional titles and descriptions. "
        "Input: Path to image file and optional display parameters. "
        "Output: Dict with image path and metadata."
    )
    args_schema: Type[BaseModel] = ImageVisualizerInput

    def _display_image(
        self,
        image_path: str,
        title: Optional[str] = None,
        description: Optional[str] = None,
        figsize: tuple = (10, 10),
        cmap: str = "rgb",
    ) -> None:
        """Render an image with an optional title and description using matplotlib.

        Args:
            image_path: Path to an image file readable by ``skimage.io.imread``.
            title: Text drawn above the image, if given.
            description: Text drawn below the image, if given.
            figsize: Figure size as (width, height) in inches.
            cmap: ``"rgb"`` (or ``None``/empty) shows the image with its own colors.
                Any other value is passed to ``plt.imshow`` as a colormap; for
                multi-channel images only the first channel is shown.

        Raises:
            Whatever ``skimage.io.imread`` or matplotlib raise for unreadable files or
            invalid colormaps; the figure is closed before the error propagates.
        """
        # Read before creating the figure so an unreadable file leaves no empty figure open.
        img = skimage.io.imread(image_path)

        # A missing/empty colormap behaves like the "rgb" default.
        cmap = cmap or "rgb"
        if img.ndim > 2 and cmap != "rgb":
            # Colormaps apply to single-channel data, so keep only the first channel.
            img = img[..., 0]

        fig = plt.figure(figsize=figsize)
        try:
            plt.imshow(img, cmap=None if cmap == "rgb" else cmap)
            plt.axis("off")

            if title:
                plt.title(title, pad=15, fontsize=12)

            # Add description if provided
            if description:
                plt.figtext(
                    0.5, 0.01, description, wrap=True, horizontalalignment="center", fontsize=10
                )

            # Adjust margins to minimize whitespace while preventing overlap
            plt.subplots_adjust(top=0.95, bottom=0.05, left=0.05, right=0.95)
        except Exception:
            plt.close(fig)  # don't leave a half-drawn figure open when rendering fails
            raise

        plt.show()
        # In non-interactive mode (e.g. the Agg backend in a script) show() does not release
        # the figure, so repeated tool calls would accumulate open figures. Interactive
        # sessions (Jupyter inline, GUI backends with ion()) manage figures themselves.
        if not plt.isinteractive():
            plt.close(fig)

    def _run(
        self,
        image_path: str,
        title: Optional[str] = None,
        description: Optional[str] = None,
        figsize: tuple = (10, 10),
        cmap: str = "rgb",
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """
        Display an image to the user with optional annotations.

        Args:
            image_path: Path to the image file
            title: Optional title to display above image
            description: Optional description to display below image
            figsize: Optional figure size as (width, height)
            cmap: Optional colormap to use for displaying the image ("rgb" = image's own colors)
            run_manager: Optional callback manager (unused)

        Returns:
            Tuple ``(output, metadata)``.
            On success, ``output`` is ``{"image_path": ...}`` and ``metadata`` echoes the
            display parameters with ``analysis_status="completed"``.
            On any error, ``output`` is ``{"error": <message>}`` and ``metadata`` has
            ``analysis_status`` and ``visualization_status`` set to ``"failed"``.
        """
        try:
            # Verify image path
            if not Path(image_path).is_file():
                raise FileNotFoundError(f"Image file not found: {image_path}")

            # Display image
            self._display_image(image_path, title, description, figsize, cmap)

            output = {"image_path": image_path}
            metadata = {
                "image_path": image_path,
                "title": bool(title),
                "description": bool(description),
                "figsize": figsize,
                "cmap": cmap,
                "analysis_status": "completed",
            }
            return output, metadata

        except Exception as e:
            # Errors are returned as data rather than raised.
            # The failure metadata uses the same status key as the success path.
            return (
                {"error": str(e)},
                {
                    "image_path": image_path,
                    "analysis_status": "failed",
                    "visualization_status": "failed",
                    "note": "An error occurred during image visualization",
                },
            )

    async def _arun(
        self,
        image_path: str,
        title: Optional[str] = None,
        description: Optional[str] = None,
        figsize: tuple = (10, 10),
        cmap: str = "rgb",
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Async version of _run; delegates to the synchronous implementation."""
        return self._run(image_path, title, description, figsize, cmap)


In [None]:
# Instantiate the agent's tools. The model-backed tools are only built when they do not already
# exist in this kernel, so re-running the cell reuses them (presumably each constructor loads
# model weights — TODO confirm), while a fresh Restart & Run All still defines every name.
if "report_tool" not in globals():
    report_tool = ChestXRayReportGeneratorTool()
if "xray_classification_tool" not in globals():
    xray_classification_tool = ChestXRayClassifierTool()
if "medical_visual_qa_tool" not in globals():
    medical_visual_qa_tool = MedicalVisualQATool()
if "segmentation_tool" not in globals():
    segmentation_tool = ChestXRaySegmentationTool()
# Always rebuilt so it reflects the latest ImageVisualizerTool definition above.
image_visualizer_tool = ImageVisualizerTool()

for agent_tool in (
    report_tool,
    xray_classification_tool,
    medical_visual_qa_tool,
    segmentation_tool,
    image_visualizer_tool,
):
    print(f"Type: {type(agent_tool)}, Name: {agent_tool.name}")

In [23]:
# Build the agent: load its system prompt, attach the tools, and keep conversation state in memory.
prompts = load_prompts_from_file(PROMPT_FILE)
prompt = prompts["MEDICAL_ASSISTANT"]

model = ChatOpenAI(
    model="gpt-4o-2024-08-06",
    temperature=0,
    top_p=0.95,
)
checkpointer = MemorySaver()
agent = Agent(
    model,
    tools=[
        xray_classification_tool,
        medical_visual_qa_tool,
        report_tool,
        segmentation_tool,
        image_visualizer_tool,
    ],
    system_prompt=prompt,
    checkpointer=checkpointer,
    log_tools=True,
    log_dir="logs",
)

# Run config for the streaming cell; presumably the checkpointer keys saved state by thread_id.
thread = {"configurable": {"thread_id": "1"}}

# Optional: render the agent graph with Image(agent.graph.get_graph().draw_png())

In [14]:
# Example queries for exercising the agent; choose one via QUERY_INDEX.
EXAMPLE_QUERIES = [
    "Does the image `demo/chest/pneumonia4.jpg` show pneumonia?",
    "Provide a radiology report for the given image `demo/chest/pneumonia2.jpg`.",
    "Describe the given image.",
    "What was the first question I asked?",
    "What is the probability of pneumonia in the image `demo/chest/normal1.jpg`?",
    "Does the patient with chest xray given here need to go see doctor? `demo/chest/normal1.jpg`",
    "What is the size of heart?",
    "Use visualizer tool. Display the image `demo/chest/normal1.jpg`.",
]
QUERY_INDEX = -1  # the visualizer demo query
messages = [HumanMessage(content=EXAMPLE_QUERIES[QUERY_INDEX])]

In [None]:
# Run the agent on `messages` within `thread`, printing each node's update as it streams in.
for update in agent.workflow.stream({"messages": messages}, thread):
    for node_output in update.values():
        print(node_output, "\n")

In [None]:
import gradio as gr
from pathlib import Path
import time
import shutil
from typing import List, Optional
from gradio import ChatMessage

class ChatInterface:
    """Gradio-facing adapter that streams agent responses into a chat history.

    Uploaded images are copied into ``upload_dir`` and their paths are appended to the
    user's message so the agent's tools can find them. A single agent thread id
    (``current_thread_id``) is kept per instance and created lazily on the first message.
    """

    def __init__(self, agent):
        self.agent = agent
        self.upload_dir = Path("temp")
        self.upload_dir.mkdir(exist_ok=True)
        self.current_thread_id = None

    def handle_upload(self, image_path: str) -> Optional[str]:
        """Copy an image into ``upload_dir`` and return the copy's path.

        Args:
            image_path: Path of the image to store (e.g. a Gradio temp file).

        Returns:
            The stored copy's path as a string, or ``None`` if ``image_path`` is empty.
        """
        if not image_path:
            return None

        source = Path(image_path)
        # Nanosecond timestamp so two uploads within the same second do not overwrite each other.
        dest = self.upload_dir / f"upload_{time.time_ns()}{source.suffix}"
        shutil.copy2(source, dest)
        return str(dest)

    @staticmethod
    def _parse_tool_content(content):
        """Extract the user-facing result from a tool message's content.

        The content is assumed to be the ``str()`` of an ``(output, metadata)`` tuple
        (the shape ``ImageVisualizerTool._run`` returns); in that case ``output`` is returned.
        ``ast.literal_eval`` is used instead of ``eval`` so tool text is never executed
        as code. Content that is not a Python literal is returned unchanged.
        """
        import ast

        if not isinstance(content, str):
            return content
        try:
            parsed = ast.literal_eval(content)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return content
        if isinstance(parsed, (tuple, list)) and parsed:
            return parsed[0]
        return parsed

    async def process_message(self,
                            message: str,
                            image: Optional[str],
                            chat_history: List[ChatMessage]):
        """Send ``message`` (plus the current image path, if any) to the agent and stream the chat.

        Args:
            message: Text typed by the user.
            image: Filepath currently shown in the image component, or ``None``.
            chat_history: Current chatbot messages (may be empty/``None``).

        Yields:
            ``(chat_history, current_image)``. This happens once after the user message
            is added, then after every assistant reply or tool result.
        """
        chat_history = chat_history or []
        current_image = image  # Track current image

        if image:
            saved_image = self.handle_upload(image)
            message = f"{message} `{saved_image}`"

        if not self.current_thread_id:
            self.current_thread_id = str(time.time())

        chat_history.append(ChatMessage(role="user", content=message))
        yield chat_history, current_image  # Initial yield with both values

        try:
            # NOTE(review): workflow.stream is synchronous and blocks the event loop while the
            # agent runs; consider a sync generator (Gradio runs those in a worker thread).
            for event in self.agent.workflow.stream(
                {"messages": [{"role": "user", "content": message}]},
                {"configurable": {"thread_id": self.current_thread_id}}
            ):
                if not isinstance(event, dict):
                    continue

                if 'process' in event:
                    content = event['process']['messages'][-1].content
                    if content:
                        chat_history.append(ChatMessage(
                            role="assistant",
                            content=content
                        ))
                        yield chat_history, current_image

                elif 'execute' in event:
                    # Distinct loop name so the user's `message` is not shadowed.
                    for tool_message in event['execute']['messages']:
                        tool_name = tool_message.name
                        tool_result = self._parse_tool_content(tool_message.content)

                        # ToolMessage has no `args` attribute. The visualizer instead reports
                        # the displayed path in its output dict ({"image_path": ...}).
                        if tool_name == "image_visualizer" and isinstance(tool_result, dict):
                            current_image = tool_result.get("image_path", current_image)

                        if tool_result:
                            formatted_result = ' '.join(line.strip() for line in str(tool_result).splitlines()).strip()
                            chat_history.append(ChatMessage(
                                role="assistant",
                                content=formatted_result,
                                metadata={"title": f"🔧 Using tool: {tool_name}"},
                            ))
                            yield chat_history, current_image

        except Exception as e:
            chat_history.append(ChatMessage(
                role="assistant",
                content=f"❌ Error: {str(e)}",
                metadata={"title": "Error"}
            ))
            yield chat_history, current_image

def create_demo(agent):
    """Build the MedMAX Gradio UI around ``agent`` and return the (unlaunched) Blocks app.

    Layout: a chat column (Chatbot plus message box) beside an image column (image viewer,
    upload button, and Clear Chat / New Thread buttons).
    """
    interface = ChatInterface(agent)

    # Event handlers. Their names are kept as-is because Gradio derives API endpoint names from them.
    def clear_chat():
        # Empty the chatbot and the image; the agent thread id is left unchanged.
        return [], None

    def new_thread():
        # Switch to a fresh agent thread id, then empty the chatbot and the image.
        interface.current_thread_id = str(time.time())
        return [], None

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        with gr.Column():
            gr.Markdown("""
            # 🏥 MedMAX
            Multimodal Medical Agent for Chest X-rays
            """)

            with gr.Row():
                # Left column: the conversation.
                with gr.Column(scale=3):
                    chatbot = gr.Chatbot(
                        value=[],
                        type="messages",
                        label="Agent",
                        show_label=True,
                        height=800,
                        container=True,
                        elem_classes="chat-box",
                        avatar_images=(None, "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png"),
                    )
                    with gr.Row():
                        with gr.Column(scale=3):
                            txt = gr.Textbox(
                                placeholder="Ask about the X-ray...",
                                show_label=False,
                                container=False,
                            )

                # Right column: image viewer and controls.
                with gr.Column(scale=3):
                    image_display = gr.Image(
                        type="filepath",
                        label="Image",
                        height=700,
                        container=True,
                    )
                    upload_button = gr.UploadButton("📎 Upload X-Ray", file_types=["image"])
                    with gr.Row():
                        clear_btn = gr.Button("Clear Chat")
                        new_thread_btn = gr.Button("New Thread")

        # Submitting streams the agent's reply into the chat, then empties the textbox.
        submit_event = txt.submit(
            fn=interface.process_message,
            inputs=[txt, image_display, chatbot],
            outputs=[chatbot, image_display],
        )
        submit_event.then(fn=lambda: "", inputs=None, outputs=[txt])

        # An uploaded file is shown as-is in the image viewer.
        upload_button.upload(fn=lambda x: x, inputs=upload_button, outputs=image_display)

        clear_btn.click(fn=clear_chat, outputs=[chatbot, image_display])
        new_thread_btn.click(fn=new_thread, outputs=[chatbot, image_display])

    return demo

# share=True also publishes the app on a public gradio.live URL. Anyone with the link can
# upload images and drive the agent (and its ChatOpenAI usage). Set to False for local-only access.
SHARE_DEMO = True

demo = create_demo(agent)
demo.launch(share=SHARE_DEMO)