In [None]:
# Install required libraries.
# `!` is an IPython/Colab shell escape, so this line only works inside a notebook.
# NOTE(review): TTS is installed but not imported anywhere in the code shown
# here — confirm it is used elsewhere before keeping this heavy dependency.
!pip install unsloth datasets langchain sentence-transformers gradio chromadb langchain-community TTS -q

# Import necessary libraries
import torch
import gradio as gr
from datasets import load_dataset
from unsloth import FastLanguageModel, FastVisionModel
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.llms.base import LLM
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from typing import List, Any, Optional
from pydantic import PrivateAttr

# Mount Google Drive for the models.
# Colab-only: both checkpoints below are read from the mounted Drive, so the
# mount has to happen before any from_pretrained call. force_remount=True
# re-mounts even if an earlier run of this cell already mounted the drive.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# Paths to models (both under the mounted "My Drive" folder).
# Agent 1: Phi-4 LoRA checkpoint (presumably MedQA-tuned, going by the folder name).
phi4_model_path = "/content/drive/My Drive/MedQA-Phi4_LoRA_Model/lora_model"
# Agent 2: Llama 3.2 Vision checkpoint (presumably radiology-tuned, going by the folder name).
xray_model_path = "/content/drive/My Drive/Llama3.2-Vision-radiology"

# Load Phi4 Model (Agent 1), quantized to 4 bits.
# dtype=None leaves the dtype choice to Unsloth; max_seq_length=2048 is the
# same limit the LangChain wrapper uses when truncating prompts.
phi4_model, phi4_tokenizer = FastLanguageModel.from_pretrained(
    model_name=phi4_model_path,
    max_seq_length=2048,
    dtype=None,
    load_in_4bit=True,
    device_map="auto"
)
# Put the model into Unsloth's inference mode before it is used for generation.
FastLanguageModel.for_inference(phi4_model)

from langchain.llms.base import LLM
from pydantic import PrivateAttr

class LangChainPhi4LLM(LLM):
    """
    Custom wrapper to make the Phi4 model compatible with LangChain.

    LangChain chains (e.g. ``RetrievalQA``) call ``_call`` with the fully
    rendered prompt. The wrapper tokenizes it, runs ``generate`` on the
    wrapped model and returns only the newly generated text, cut at the first
    stop sequence when LangChain supplies any.

    Args:
        model: The loaded causal LM (anything exposing ``generate``).
        tokenizer: Tokenizer matching ``model``.
        max_new_tokens: Maximum number of tokens generated per call.
        device: Device the tokenized prompt is moved to before generation.
    """

    # The model/tokenizer are not pydantic types, so they are kept as private
    # attributes rather than validated model fields.
    _model: Any = PrivateAttr()
    _tokenizer: Any = PrivateAttr()
    max_new_tokens: int
    device: str

    def __init__(self, model, tokenizer, max_new_tokens=256, device="cuda"):
        super().__init__(max_new_tokens=max_new_tokens, device=device)
        self._model = model
        self._tokenizer = tokenizer

    @property
    def _llm_type(self) -> str:
        return "custom_phi4_llm"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Any = None,
        **kwargs: Any,
    ) -> str:
        """
        Generate a completion for ``prompt`` using the Phi4 model.

        Args:
            prompt: Fully rendered prompt text.
            stop: Optional stop sequences; the completion is truncated at the
                earliest occurrence of any of them (the stop text is dropped).
                Empty strings are ignored.
            run_manager: LangChain callback manager; accepted for interface
                compatibility, not used.
            **kwargs: Extra LangChain call options; ignored.

        Returns:
            Only the generated continuation, without the echoed prompt.
        """
        # NOTE(review): with the tokenizer's default right-side truncation an
        # over-long prompt loses its tail (the question and "Answer:" cue).
        # A single prompt needs no padding; not requesting it also avoids an
        # error on tokenizers that have no pad token.
        inputs = self._tokenizer(
            [prompt], return_tensors="pt", truncation=True, max_length=2048
        ).to(self.device)
        # generate() returns prompt + continuation for a decoder-only model;
        # remember where the continuation starts so the prompt is not echoed.
        prompt_len = inputs["input_ids"].shape[1]
        with torch.no_grad():
            outputs = self._model.generate(
                **inputs, max_new_tokens=self.max_new_tokens, use_cache=True
            )
        text = self._tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        if stop:
            hits = [pos for pos in (text.find(s) for s in stop if s) if pos != -1]
            if hits:
                text = text[: min(hits)]
        return text

# Expose Agent 1 to LangChain through the adapter, keeping its defaults
# (max_new_tokens=256, device="cuda").
wrapped_phi4_llm = LangChainPhi4LLM(phi4_model, phi4_tokenizer)

# Load Vision Model (Agent 2), quantized to 4 bits.
# Despite its name, the second return value is used as a processor in
# infer_xray: it is called with an image plus text and also applies the chat template.
# NOTE(review): use_gradient_checkpointing is a training-time memory option and
# presumably has no effect here, since this model is only used for inference.
xray_model, xray_tokenizer = FastVisionModel.from_pretrained(
    xray_model_path,
    load_in_4bit=True,
    use_gradient_checkpointing="unsloth"
)
# Put the model into Unsloth's inference mode before it is used for generation.
FastVisionModel.for_inference(xray_model)

# Load PubMedQA dataset and create vector store.
# Each labeled example's LONG_ANSWER becomes one document (missing answers map
# to ""), which is then split into overlapping chunks of up to 1000 characters.
dataset = load_dataset("bigbio/pubmed_qa", name="pubmed_qa_labeled_fold0_source", split="train", trust_remote_code=True)
docs = [{"id": f"pubmed_qa_train_{i}", "text": ex["LONG_ANSWER"] or ""} for i, ex in enumerate(dataset)]
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
chunks = [{"id": f"{d['id']}_chunk_{idx}", "text": chunk} for d in docs for idx, chunk in enumerate(text_splitter.split_text(d["text"]))]
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
texts = [x["text"] for x in chunks]
metadatas = [{"source": x["id"]} for x in chunks]
# Use the deterministic chunk ids as the Chroma document ids. The collection is
# persisted to "chroma_store"; without explicit ids, every re-run of this cell
# appended another copy of every chunk under auto-generated ids, and the
# retriever could return the same passage several times. With stable ids,
# re-runs overwrite the existing entries instead.
chunk_ids = [x["id"] for x in chunks]
db = Chroma.from_texts(texts=texts, embedding=embeddings, metadatas=metadatas, ids=chunk_ids, persist_directory="chroma_store")

# Agent 1: retrieval-augmented QA over the PubMedQA vector store.
# The three most similar chunks are retrieved for every question.
retriever = db.as_retriever(search_kwargs={"k": 3}, search_type="similarity")

# Prompt for the "stuff" chain: the retrieved chunks fill {context} and the
# user's question fills {question}.
prompt_template = """You are a medical QA system.
Use the following context to answer the question concisely without repeating the context or any lines in the final answer:
{context}

Question: {question}

Answer:
"""
prompt = PromptTemplate(input_variables=["context", "question"], template=prompt_template)

# RetrievalQA glues the retriever, the custom prompt and the wrapped Phi4 model together.
phi4_rag = RetrievalQA.from_chain_type(
    retriever=retriever,
    llm=wrapped_phi4_llm,
    chain_type_kwargs={"prompt": prompt},
    chain_type="stuff",
)

# Define Vision Agent (Agent 2)
def infer_xray(image):
    """Describe a radiograph with the vision model (Agent 2).

    Args:
        image: The uploaded image, passed unchanged both into the chat message
            and to ``xray_tokenizer`` (the Gradio upload component is set to
            ``type="numpy"``).

    Returns:
        The model's generated description, with the echoed chat prompt removed
        and surrounding whitespace stripped.
    """
    instruction = "You are an expert radiographer. Describe accurately and with details what you see in this image."
    messages = [
        {"role": "user", "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": instruction}
        ]}
    ]
    input_text = xray_tokenizer.apply_chat_template(messages, add_generation_prompt=True)
    inputs = xray_tokenizer(
        image,
        input_text,
        add_special_tokens=False,
        return_tensors="pt",
    ).to("cuda")
    # generate() returns the prompt tokens followed by the new tokens. Without
    # skipping the prompt, the role headers and the instruction leaked into the
    # description, which is both shown to the user and used as the RAG query.
    prompt_len = inputs["input_ids"].shape[1]
    with torch.no_grad():
        output = xray_model.generate(
            **inputs, max_new_tokens=256, use_cache=True, eos_token_id=xray_tokenizer.eos_token_id
        )
    # Decode only the generated continuation.
    xray_text = xray_tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
    return xray_text.strip()

# Define Multi-Agent Logic
def _rag_answer(query):
    """Run the Phi4 RAG chain (Agent 1) on ``query`` and return the answer text.

    Everything up to and including the last "Answer:" marker in the chain's
    ``result`` is discarded; if the marker is absent the whole result is kept.
    Surrounding whitespace is stripped.
    """
    raw_answer = phi4_rag.invoke({"query": query})["result"]
    # rpartition puts the whole string in slot 2 when the marker is missing,
    # so one expression covers both cases.
    return raw_answer.rpartition("Answer:")[2].strip()


def multi_agent_system(input_type, query=None, image=None):
    """Route a Gradio request to the matching agent(s).

    Args:
        input_type: "Chat" or "X-ray", from the radio selector.
        query: Question text for "Chat"; empty or whitespace-only text counts
            as missing.
        image: Uploaded image for "X-ray"; ``None`` counts as missing.

    Returns:
        The formatted response text. If the selected mode's input is missing
        or ``input_type`` is unrecognised, a message asking for valid input.
    """
    if input_type == "Chat" and query and query.strip():
        # Agent 1 handles text-based queries directly.
        final_answer = _rag_answer(query.strip())
        return f"Chat Answer:\n{final_answer}"
    if input_type == "X-ray" and image is not None:
        # Agent 2 describes the image; that description becomes Agent 1's query.
        xray_text = infer_xray(image)
        final_answer = _rag_answer(xray_text)
        return f"X-ray Analysis:\n{xray_text}\n\nGenerated Answer:\n{final_answer}"
    # Missing input for the selected mode, or no/unknown mode selected.
    return "Please provide a valid query or image based on the selected input type."

# Gradio front end: the radio button chooses the agent, the textbox feeds the
# chat agent and the image upload feeds the X-ray agent.
agent_inputs = [
    gr.Radio(["Chat", "X-ray"], label="Select Input Type"),
    gr.Textbox(label="Enter your query (for Chat)", placeholder="Type your query here..."),
    gr.Image(label="Upload X-ray Image (for X-ray)", type="numpy"),
]
interface = gr.Interface(
    fn=multi_agent_system,
    inputs=agent_inputs,
    outputs=gr.Textbox(label="Response"),
    title="Multi-Agent Medical System",
    description="A system where one agent handles text-based medical QA and another processes X-ray images to generate insights.",
)

# Launch the app.
# NOTE(review): share=True asks Gradio for a public share link; confirm that is
# acceptable before users upload medical images.
interface.launch(share=True)
