In [None]:
import ast
import operator
import uuid
import warnings
from typing import Annotated, AsyncIterator, TypedDict

import torch
from dotenv import load_dotenv
from IPython.display import Image
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_openai import ChatOpenAI
from transformers import logging

from medmax.agent import *
from medmax.tools import *
from medmax.utils import *

# Quiet the notebook: drop every Python warning and raise the transformers
# logger threshold so only errors are printed.
warnings.simplefilter("ignore")
logging.set_verbosity(logging.ERROR)

# Load variables from a local .env file into os.environ (presumably the API key
# used by ChatOpenAI further down). The return flag is intentionally discarded.
_ = load_dotenv()

# Text file holding the named system prompts; one of them configures the agent.
PROMPT_FILE = "medmax/docs/system_prompts.txt"

"""
1. Change the LLM to public LLM
3. Explore code generation
4. SFT or an example database, plus prompt enhancement, to improve tool calling
5. Add segmentation tool
6. How to incorporate new tools easily?
7. Implement gradio interface
8. New tool: https://www.nature.com/articles/s41551-024-01246-y

Concrete task with dataset and benchmark + compute
Multimodal still X-ray report generation is a good task to explore -> still unsolved + great potential for commercialization | main problem is hallucination
"""

In [None]:
# Build the three tools that are handed to the agent below, then print each
# tool's Python type and registered name as a quick sanity check.
report_tool = ChestXRayReportGeneratorTool()
xray_classification_tool = ChestXRayClassifierTool()
medical_visual_qa_tool = MedicalVisualQATool()

print(
    *(
        f"Type: {type(t)}, Name: {t.name}"
        for t in (report_tool, xray_classification_tool, medical_visual_qa_tool)
    ),
    sep="\n",
)

In [None]:
# Load the named system prompts and select the one that configures the agent.
prompts = load_prompts_from_file(PROMPT_FILE)
if "MEDICAL_ASSISTANT" not in prompts:
    # Fail with an actionable message instead of a bare KeyError('MEDICAL_ASSISTANT').
    raise KeyError(
        f"'MEDICAL_ASSISTANT' not found in {PROMPT_FILE}; "
        f"available prompts: {list(prompts)}"
    )
prompt = prompts["MEDICAL_ASSISTANT"]
print(prompt)

In [36]:
# In-process checkpoint store; the conversation is selected by the `thread_id`
# passed in the stream config.
checkpointer = MemorySaver()
model = ChatOpenAI(model="gpt-4o-2024-08-06", temperature=0, top_p=0.95)

agent = Agent(
    model,
    system_prompt=prompt,
    tools=[xray_classification_tool, medical_visual_qa_tool, report_tool],
    checkpointer=checkpointer,
    log_tools=True,
    log_dir="logs",
)

# Config for conversation thread "1"; passed as the second argument to
# `agent.workflow.stream(...)` in the commented-out example below.
thread = dict(configurable=dict(thread_id="1"))

# To render the agent graph: Image(agent.graph.get_graph().draw_png())

# Image(agent.graph.get_graph().draw_png())

In [37]:
# messages = [HumanMessage(content="Does the image `demo/chest/pneumonia4.jpg` show pneumonia?")]
# messages = [HumanMessage(content="Provide a radiology report for the given image `demo/chest/pneumonia2.jpg`.")]
# messages = [HumanMessage(content="Describe the given image.")]
# messages = [HumanMessage(content="What was the first question I asked?")]
# messages = [HumanMessage(content="What is the probability of pneumonia in the image `demo/chest/normal1.jpg`?")]
# messages = [HumanMessage(content="Does the patient with chest xray given here need to go see doctor? `demo/chest/normal1.jpg`")]

In [38]:
# for event in agent.workflow.stream({"messages": messages}, thread):
#     for v in event.values():
#         print(v, '\n')

In [54]:
import gradio as gr
from pathlib import Path
import time
import shutil
from typing import List, Optional
from gradio import ChatMessage

class ChatInterface:
    """Bridge between the Gradio chat UI and the agent's streaming workflow.

    Uploaded images are copied into a local ``uploaded_images`` directory, and
    their path is appended to the user message in backticks (the same
    convention used by the example prompts in this notebook).

    NOTE(review): ``current_thread_id`` lives on this single instance, so all
    sessions served by one demo share the same agent thread. Per-session
    threads would need ``gr.State``.
    """

    def __init__(self, agent):
        """Store the agent and make sure the upload directory exists.

        Args:
            agent: Object exposing ``workflow.stream(inputs, config)``.
        """
        self.agent = agent
        self.upload_dir = Path("uploaded_images")
        self.upload_dir.mkdir(exist_ok=True)
        self.current_thread_id = None

    def handle_upload(self, image_path: Optional[str]) -> Optional[str]:
        """Copy an uploaded image into ``upload_dir`` under a unique name.

        Args:
            image_path: Path of the file provided by Gradio; falsy means no image.

        Returns:
            The destination path as a string, or ``None`` when no image was given.
        """
        if not image_path:
            return None

        source = Path(image_path)
        # The timestamp keeps names roughly chronological. The random suffix stops
        # two uploads in the same second from overwriting each other.
        dest = self.upload_dir / (
            f"upload_{int(time.time())}_{uuid.uuid4().hex[:8]}{source.suffix}"
        )
        shutil.copy2(source, dest)
        return str(dest)

    @staticmethod
    def _parse_tool_content(content):
        """Extract the displayable result from a tool message's ``content``.

        The content is expected to be the ``repr`` of a sequence whose first
        item is the tool output (hedged: inferred from the original ``[0]``
        indexing). ``ast.literal_eval`` is used instead of ``eval`` so that
        tool/LLM-produced text can never execute code. If the text is not a
        plain Python literal (e.g. it contains object reprs), the raw content
        is returned for display instead of failing the whole turn.
        """
        try:
            parsed = ast.literal_eval(content)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            return content
        try:
            return parsed[0]
        except (TypeError, IndexError, KeyError):
            return parsed

    async def process_message(
        self,
        message: str,
        image: Optional[str],
        chat_history: List[ChatMessage],
    ) -> AsyncIterator[List[ChatMessage]]:
        """Send one user turn to the agent and stream the chat history back.

        Args:
            message: Text typed by the user.
            image: Optional path of the image currently shown in the UI.
            chat_history: Messages already displayed (``None`` is treated as empty).

        Yields:
            The updated ``chat_history`` list after the user message is added and
            after each agent/tool step that produced displayable output.
        """
        chat_history = chat_history or []
        message = message or ""

        # Reference the image by its saved path, wrapped in backticks.
        if image:
            saved_image = self.handle_upload(image)
            message = f"{message} `{saved_image}`".strip()

        # Nothing to send: skip the agent call rather than submitting an empty turn.
        if not message.strip():
            yield chat_history
            return

        if not self.current_thread_id:
            self.current_thread_id = str(time.time())

        chat_history.append(ChatMessage(role="user", content=message))
        yield chat_history

        try:
            # NOTE(review): `workflow.stream` is synchronous, so it blocks the event
            # loop while the agent runs.
            for event in self.agent.workflow.stream(
                {"messages": [{"role": "user", "content": message}]},
                {"configurable": {"thread_id": self.current_thread_id}},
            ):
                if not isinstance(event, dict):
                    continue

                if "process" in event:
                    # Output of the 'process' node: show its latest message if non-empty.
                    content = event["process"]["messages"][-1].content
                    if content:
                        chat_history.append(ChatMessage(role="assistant", content=content))
                        yield chat_history
                elif "execute" in event:
                    # Output of the 'execute' node: show the tool result with a titled bubble.
                    tool_message = event["execute"]["messages"][-1]
                    tool_result = self._parse_tool_content(tool_message.content)
                    if tool_result:
                        chat_history.append(ChatMessage(
                            role="assistant",
                            content=str(tool_result),
                            metadata={"title": f"🔧 Using {tool_message.name}"},
                        ))
                        yield chat_history

        except Exception as e:
            # Surface failures in the chat instead of crashing the Gradio handler.
            chat_history.append(ChatMessage(
                role="assistant",
                content=f"❌ Error: {e}",
                metadata={"title": "Error"},
            ))
            yield chat_history

def create_demo(agent):
    """Build the Gradio Blocks app that wraps ``agent`` in a ``ChatInterface``.

    Layout: a chat column (chatbot and message textbox) on the left, and an
    image column (current X-ray preview, upload button, "Clear Chat" and
    "New Thread" buttons) on the right.

    Args:
        agent: The agent handed to ``ChatInterface``.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    # NOTE(review): one ChatInterface (and therefore one thread id) is shared by
    # every session of the returned app.
    interface = ChatInterface(agent)
    
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        with gr.Column():
            gr.Markdown("""
            # 🏥 Medical X-Ray Analysis Assistant
            Upload an X-ray image and ask questions about it.
            """)
            
            with gr.Row():
                # Left: conversation.
                with gr.Column(scale=3):
                    # NOTE(review): Gradio documents avatar_images as file paths/URLs;
                    # confirm that an emoji string renders as intended.
                    chatbot = gr.Chatbot(
                        [],
                        height=900,
                        container=True,
                        show_label=False,
                        elem_classes="chat-box",
                        type="messages",
                        avatar_images=(None, "🏥")
                    )
                    with gr.Row():
                        with gr.Column(scale=3):
                            txt = gr.Textbox(
                                show_label=False,
                                placeholder="Ask about the X-ray...",
                                container=False
                            )
                            
                
                # Right: image preview and controls. The preview yields a file path,
                # which process_message receives as its `image` argument.
                with gr.Column(scale=2):
                    image_display = gr.Image(
                        label="Current X-Ray",
                        type="filepath",
                        height=600,
                        container=True
                    )
                    upload_button = gr.UploadButton(
                        "📎 Upload X-Ray",
                        file_types=["image"],
                    )
                    clear_btn = gr.Button("Clear Chat")
                    new_thread_btn = gr.Button("New Thread")

        # Event handlers
        # Clears only the visible chat and image. The thread id is kept, so the
        # agent's checkpointed conversation may still include earlier turns.
        def clear_chat():
            return [], None

        # Switches to a fresh thread id (a new agent conversation) and clears the UI.
        def new_thread():
            interface.current_thread_id = str(time.time())
            return [], None

        # On Enter: stream the agent's reply into the chatbot, then reset the
        # textbox and the image preview.
        txt.submit(
            interface.process_message,
            inputs=[txt, image_display, chatbot],
            outputs=chatbot
        ).then(
            lambda: ("", None),
            None,
            [txt, image_display]
        )

        # Mirror the uploaded file into the preview component.
        upload_button.upload(
            lambda x: x,
            inputs=upload_button,
            outputs=image_display
        )

        clear_btn.click(clear_chat, outputs=[chatbot, image_display])
        new_thread_btn.click(new_thread, outputs=[chatbot, image_display])

    return demo

# A share link publishes this medical demo at a public *.gradio.live URL with no
# authentication configured in `launch`. Set to False to serve on the local URL only.
SHARE_PUBLIC_LINK = True

demo = create_demo(agent)
demo.launch(share=SHARE_PUBLIC_LINK)

* Running on local URL:  http://127.0.0.1:7887
* Running on public URL: https://f56e08736b7c47aae3.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


