In [None]:
import operator
import warnings
from typing import Annotated, TypedDict

import torch
from dotenv import load_dotenv
from IPython.display import Image
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_openai import ChatOpenAI
from transformers import logging

from medmax.agent import *
from medmax.tools import *
from medmax.utils import *

# Location (relative to the working directory) of the file holding the named
# system prompts loaded further down.
PROMPT_FILE = "medmax/docs/system_prompts.txt"

# Keep notebook output readable: silence Python warnings and lower the
# transformers logger to errors only. Then pull environment variables from a
# local .env file (presumably the API keys the LLM/tools need — confirm).
warnings.filterwarnings("ignore")
logging.set_verbosity_error()
_ = load_dotenv()

"""
1. Change the LLM to public LLM
3. Explore code generation
4. SFT or example database, prompt enhancing to make tool call better
5. Add segmentation tool
6. How to incorporate new tools easily?
7. Implement gradio interface
8. New tool: https://www.nature.com/articles/s41551-024-01246-y

Concrete task with dataset and benchmark + compute
Multimodal chest X-ray report generation is still a good task to explore -> still unsolved + great potential for commercialization | main problem is hallucination
"""

In [None]:
report_tool = ChestXRayReportGeneratorTool()
xray_classification_tool = ChestXRayClassifierTool()
medical_visual_qa_tool = MedicalVisualQATool()

# Sanity-check each tool: show its class and the name it is registered under
# (presumably the name the LLM sees when choosing a tool).
for tool in (report_tool, xray_classification_tool, medical_visual_qa_tool):
    print(f"Type: {type(tool)}, Name: {tool.name}")

In [None]:
# Read all named prompts from PROMPT_FILE (presumably a name -> text mapping),
# then keep and display the MEDICAL_ASSISTANT entry used as the agent's
# system prompt.
prompts = load_prompts_from_file(PROMPT_FILE)
print(prompt := prompts["MEDICAL_ASSISTANT"])

In [18]:
checkpointer = MemorySaver()
# Pinned model snapshot, temperature 0, nucleus sampling capped at 0.95.
model = ChatOpenAI(model="gpt-4o-2024-08-06", temperature=0, top_p=0.95)
agent = Agent(
    model,
    system_prompt=prompt,
    tools=[xray_classification_tool, medical_visual_qa_tool, report_tool],
    checkpointer=checkpointer,
    log_tools=True,
    log_dir="logs",
)
# Config passed with every stream call below; the thread_id presumably selects
# which conversation the checkpointer remembers.
thread = {"configurable": {"thread_id": "1"}}

# Optional graph visualisation (left disabled):
# Image(agent.graph.get_graph().draw_png())

In [13]:
# Other example queries to try (swap one in as the `content` below):
#   "Does the image `demo/chest/pneumonia4.jpg` show pneumonia?"
#   "Provide a radiology report for the given image `demo/chest/pneumonia2.jpg`."
#   "Describe the given image."
#   "What was the first question I asked?"   (presumably probes conversation memory)
#   "What is the probability of pneumonia in the image `demo/chest/normal1.jpg`?"
messages = [
    HumanMessage(
        content="Does the patient with chest xray given here need to go see doctor? `demo/chest/normal1.jpg`"
    )
]

In [None]:
# Run the agent on `messages` within `thread` and print every streamed event.
# Each event is a dict (presumably node name -> that node's update); only its
# values are printed, one per paragraph.
for update in agent.workflow.stream({"messages": messages}, thread):
    for node_output in update.values():
        print(node_output, '\n')

In [None]:
import gradio as gr
from pathlib import Path
import time
import shutil
from typing import List, Dict, Any, Optional
from langchain_core.messages import HumanMessage, AIMessage

class ChatInterface:
    """Glue between the Gradio UI and the agent.

    Uploaded images are copied into a local ``uploaded_images/`` directory.
    Every chat turn is sent through ``agent.workflow.stream`` under the
    conversation id held in ``current_thread_id``.
    """

    def __init__(self, agent):
        self.agent = agent
        self.upload_dir = Path("uploaded_images")
        self.upload_dir.mkdir(exist_ok=True)
        # Set lazily on the first message; the UI's "New Thread" button may reset it.
        self.current_thread_id = None

    def handle_upload(self, image_path: Optional[str]) -> Optional[str]:
        """Copy an uploaded image into ``upload_dir`` and return the copy's path.

        Copies are named ``upload_<unix-seconds><suffix>``. If that name is
        already taken (e.g. two uploads within the same second), ``_1``,
        ``_2``, ... is appended so an earlier image is never overwritten.

        Args:
            image_path: path of the file handed over by the UI; may be empty.

        Returns:
            The destination path as a string, or ``None`` if ``image_path``
            is falsy.
        """
        if not image_path:
            return None

        source = Path(image_path)
        stem = f"upload_{int(time.time())}"
        dest = self.upload_dir / f"{stem}{source.suffix}"
        # Earlier chat turns reference their image by path, so a name
        # collision must not silently replace that file.
        counter = 1
        while dest.exists():
            dest = self.upload_dir / f"{stem}_{counter}{source.suffix}"
            counter += 1
        shutil.copy2(source, dest)
        return str(dest)

    def _format_chat_message(self, message: Any) -> Optional[str]:
        """Render one AI or tool message; return ``None`` for anything else."""
        if isinstance(message, dict):
            if "tool_call_id" in message:
                return f"🔧 Using {message['name']}: {message['content']}"
            return None
        # Tool results (e.g. ToolMessage objects) carry a tool_call_id.
        if getattr(message, "tool_call_id", None) is not None:
            tool_name = getattr(message, "name", None) or "tool"
            return f"🔧 Using {tool_name}: {message.content}"
        if isinstance(message, AIMessage):
            content = message.content
            # The chatbot needs text; non-string content is stringified.
            return content if isinstance(content, str) else str(content)
        return None

    def format_message(self, msg: Any) -> str:
        """Format one streamed agent event for display in the chatbot.

        Accepted shapes:
          * ``{"messages": [...]}``: the first AI or tool message is rendered.
          * A tool-message dict with ``tool_call_id`` / ``name`` / ``content``.
          * A node-keyed event ``{node_name: {...}}``, the shape the earlier
            ``for v in event.values()`` cell iterates. Each dict payload is
            formatted and the non-empty results are joined.
          * A single message object.

        Returns:
            The display text. An empty string means nothing displayable
            (e.g. an AI message with empty content); the caller skips those.
            Unrecognised input falls back to ``str(msg)``.
        """
        if isinstance(msg, dict):
            if "messages" in msg:
                for message in msg["messages"]:
                    rendered = self._format_chat_message(message)
                    if rendered is not None:
                        return rendered
                return ""
            if "tool_call_id" in msg:
                return f"🔧 Using {msg['name']}: {msg['content']}"
            node_payloads = [value for value in msg.values() if isinstance(value, dict)]
            if node_payloads:
                parts = (self.format_message(payload) for payload in node_payloads)
                return "\n\n".join(part for part in parts if part)
            return str(msg)
        rendered = self._format_chat_message(msg)
        return str(msg) if rendered is None else rendered

    def process_message(self,
                        message: str,
                        image: Optional[str],
                        history: List[List[str]]) -> List[List[str]]:
        """Send one user turn to the agent and stream its reply into ``history``.

        This is a generator; each yielded ``history`` is a UI update. The user
        turn is shown immediately. Its reply slot is then overwritten with
        the latest non-empty formatted event as events arrive.

        If an image is attached, the path of its saved copy is appended to the
        message in backticks. That matches the example queries above and is
        presumably how the agent's tools find the file. Exceptions raised
        while streaming are shown in the chat instead of propagating.
        """
        history = history or []
        message = message or ""

        # Nothing to send: leave the chat as is instead of posting an empty turn.
        if not message.strip() and not image:
            yield history
            return

        if image:
            saved_image = self.handle_upload(image)
            message = f"{message} `{saved_image}`"

        if not self.current_thread_id:
            self.current_thread_id = str(time.time())

        history.append([message, None])
        # Show the user's turn right away; the agent may take a while to reply.
        yield history

        try:
            messages = [HumanMessage(content=message)]
            config = {"configurable": {"thread_id": self.current_thread_id}}
            for event in self.agent.workflow.stream({"messages": messages}, config):
                formatted_response = self.format_message(event)
                if formatted_response:
                    history[-1][1] = formatted_response
                    yield history
                time.sleep(0.05)  # brief pause between UI updates
        except Exception as e:
            history[-1][1] = f"❌ Error: {str(e)}"
            yield history

def create_demo(agent):
    """Assemble (but do not launch) the Gradio app for ``agent``.

    Layout: a chat column (chatbot, question box, upload button) next to an
    image preview, with "Clear Chat" and "New Thread" buttons below.
    Submitting the question box streams ``ChatInterface.process_message``
    into the chatbot, then empties the box and the preview.

    Args:
        agent: the agent object wrapped by ``ChatInterface``.

    Returns:
        The ``gr.Blocks`` app.
    """
    interface = ChatInterface(agent)

    # Reset handlers. Both empty the chatbot and the preview image. Only
    # "New Thread" also switches to a fresh thread id, so the agent presumably
    # starts without prior context.
    def clear_chat():
        return [], None

    def new_thread():
        interface.current_thread_id = str(time.time())
        return clear_chat()

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        with gr.Column():
            gr.Markdown("""
            # 🏥 Medical X-Ray Analysis Assistant
            Upload an X-ray image and ask questions about it.
            """)

            with gr.Row():
                # Left: conversation and inputs.
                with gr.Column(scale=3):
                    chatbot = gr.Chatbot(
                        value=[],
                        height=600,
                        container=True,
                        show_label=False,
                        elem_classes="chat-box",
                    )
                    with gr.Row():
                        with gr.Column(scale=4):
                            txt = gr.Textbox(
                                placeholder="Ask about the X-ray...",
                                show_label=False,
                                container=False,
                            )
                        with gr.Column(scale=1):
                            upload_button = gr.UploadButton(
                                "📎 Upload X-Ray",
                                file_types=["image"],
                            )

                # Right: preview of the image the next submitted message attaches.
                with gr.Column(scale=2):
                    image_display = gr.Image(
                        type="filepath",
                        label="Current X-Ray",
                        height=400,
                        container=True,
                    )

            with gr.Row():
                clear_btn = gr.Button("Clear Chat")
                new_thread_btn = gr.Button("New Thread")

        # Question submitted: stream the reply, then reset the box and preview.
        submit_event = txt.submit(
            fn=interface.process_message,
            inputs=[txt, image_display, chatbot],
            outputs=chatbot,
        )
        submit_event.then(
            fn=lambda: ("", None),
            inputs=None,
            outputs=[txt, image_display],
        )

        # Uploaded file is shown unchanged in the preview.
        upload_button.upload(
            fn=lambda file: file,
            inputs=upload_button,
            outputs=image_display,
        )

        clear_btn.click(clear_chat, outputs=[chatbot, image_display])
        new_thread_btn.click(new_thread, outputs=[chatbot, image_display])

    return demo



demo = create_demo(agent)
# Replies stream from a generator (ChatInterface.process_message). Gradio 3.x
# only runs generator handlers when the queue is enabled; 4.x enables it by
# default. Calling queue() explicitly works on both.
demo.queue()
# NOTE(review): share=True exposes the app through a publicly reachable share
# link. Think twice before uploading real patient images through it.
demo.launch(share=True)