# MedGemma client & runner

In [1]:
!pip install -U transformers
!pip install -U huggingface_hub transformers
!pip install transformers psutil torch datasets

Collecting huggingface_hub
  Downloading huggingface_hub-1.1.4-py3-none-any.whl.metadata (13 kB)


In [2]:
import os
import sys

# True only on consumer Google Colab; Colab Enterprise / Vertex set VERTEX_PRODUCT.
google_colab = "google.colab" in sys.modules and not os.environ.get("VERTEX_PRODUCT")

if not google_colab:
    # Colab Enterprise: keep Hugging Face caches under /content.
    if os.environ.get("VERTEX_PRODUCT") == "COLAB_ENTERPRISE":
        os.environ["HF_HOME"] = "/content/hf"
    # Ask for an interactive login only when no token is already available.
    from huggingface_hub import get_token

    if get_token() is None:
        from huggingface_hub import notebook_login

        notebook_login()
else:
    # Consumer Colab: take the token from the notebook's stored secrets.
    from google.colab import userdata

    os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")

In [3]:
!pip install --upgrade --quiet accelerate bitsandbytes transformers qdrant-client fastembed qdrant-client[fastembed]

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m11.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m328.6/328.6 kB[0m [31m19.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m105.3/105.3 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.6/61.6 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m103.3/103.3 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.4/17.4 MB[0m [31m48.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m324.8/324.8 kB[0m [31m21.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [4]:
from transformers import BitsAndBytesConfig
import torch

model_variant = "4b-it"  # @param ["4b-it", "4b-pt"]
model_id = f"google/medgemma-{model_variant}"

use_quantization = True  # @param {type: "boolean"}


# If running a 27B variant in Google Colab, check if the runtime satisfies
# memory requirements. Check for a GPU first: torch.cuda.get_device_name(0)
# raises on a CPU-only runtime, which would hide the explanatory message below.
if "27b" in model_variant and google_colab:
    has_a100 = torch.cuda.is_available() and "A100" in torch.cuda.get_device_name(0)
    if not (has_a100 and use_quantization):
        raise ValueError(
            "Runtime has insufficient memory to run a 27B variant. "
            "Please select an A100 GPU and use 4-bit quantization."
        )

# Loading options for `from_pretrained`-style calls: bf16 weights, automatic
# device placement and, optionally, 4-bit bitsandbytes quantization.
# NOTE(review): the `pipeline(...)` cells later in this notebook hard-code their
# own dtype/device and do not pass these kwargs — confirm which is intended.
model_kwargs = dict(
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

if use_quantization:
    model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)

In [7]:
# @title
from dataclasses import dataclass
from functools import lru_cache
from typing import Any, Dict, List, Optional
from pathlib import Path
from qdrant_client import QdrantClient, models

from fastembed import TextEmbedding


@dataclass
class RagSettings:
    qdrant_host: str = os.getenv("QDRANT_HOST", "165.22.56.15")
    qdrant_port: int = int(os.getenv("QDRANT_PORT", 6333))
    qdrant_api_key: str | None = os.getenv("QDRANT_API_KEY")
    qdrant_collection: str = os.getenv("QDRANT_COLLECTION", "case_text_2")
    vector_name: str = os.getenv("QDRANT_VECTOR_NAME", "text")
    embed_model: str = os.getenv("EMBED_MODEL", "qdrant/fastembed-multilingual-v3")
    embed_fallbacks: str = os.getenv("EMBED_MODEL_FALLBACKS", "")
    rag_top_k: int = int(os.getenv("RAG_TOP_K", 4))
    context_prefix: str = os.getenv(
        "RAG_CONTEXT_PREFIX",
        "Use the following reference material when reasoning about the case:",
    )


rag_settings = RagSettings()


In [8]:
# @title
@lru_cache(maxsize=1)
def get_qdrant_client() -> QdrantClient:
    """Return the shared Qdrant client built from ``rag_settings``.

    The client is created on the first call and cached for later calls.
    """
    connection = dict(
        host=rag_settings.qdrant_host,
        port=rag_settings.qdrant_port,
        api_key=rag_settings.qdrant_api_key,
    )
    # REST transport (no gRPC) with a generous timeout for slow remote searches.
    return QdrantClient(**connection, prefer_grpc=False, timeout=120.0)


def _candidate_embed_models() -> List[str]:
    """Return the embedding model names to try, in priority order.

    The configured ``embed_model`` comes first, then the comma-separated
    ``embed_fallbacks`` (whitespace-trimmed), then a built-in list. Empty
    names are dropped, and only the first occurrence of each name is kept.
    """
    configured_fallbacks = (
        [name.strip() for name in rag_settings.embed_fallbacks.split(",") if name.strip()]
        if rag_settings.embed_fallbacks
        else []
    )
    builtin_fallbacks = [
        "qdrant/fastembed-multilingual-v1",
        "qdrant/fastembed-multilingual",
        "BAAI/bge-m3",
        "BAAI/bge-base-en-v1.5",
    ]
    ordered = [rag_settings.embed_model, *configured_fallbacks, *builtin_fallbacks]
    # dict keys keep insertion order, so fromkeys() de-duplicates and keeps the first occurrence.
    return list(dict.fromkeys(name for name in ordered if name))


@lru_cache(maxsize=1)
def get_embedder() -> TextEmbedding:
    """Return the first candidate embedding model that can be initialised.

    Candidates come from ``_candidate_embed_models()`` and are tried in order.
    Every failure is recorded so that the final error explains why each model
    was rejected. A successful result is cached, so the model loads once.

    Raises:
        RuntimeError: if no candidate could be initialised. It is chained to the
            last underlying exception when there was one.
    """
    errors: List[str] = []
    last_exc: Optional[BaseException] = None
    for model_name in _candidate_embed_models():
        try:
            return TextEmbedding(model_name=model_name)
        except Exception as exc:  # e.g. unsupported model name (ValueError), download failure
            errors.append(f"{model_name}: {exc}")
            last_exc = exc
    error_msg = "Failed to initialise embedding model. "
    error_msg += "; ".join(errors) if errors else "No candidate models configured."
    raise RuntimeError(error_msg) from last_exc


def embed_query(query: str, target_dim: int = 3072) -> List[float]:
    """Embed ``query`` and tile the vector to exactly ``target_dim`` values.

    The embedding is repeated end-to-end and then truncated until it has
    ``target_dim`` entries, so it matches the size expected by the collection.
    NOTE(review): tiling only fixes the length. It does not make the vector
    comparable to vectors produced by a different, larger model. Retrieval
    quality depends on the collection having been indexed the same way;
    confirm this with the indexing code.

    Args:
        query: Text to embed.
        target_dim: Required output length. Defaults to 3072, the previously
            hard-coded value.

    Returns:
        ``target_dim`` plain Python floats.

    Raises:
        ValueError: if the embedder returns no vector or an empty vector.
    """
    from itertools import cycle, islice

    embedder = get_embedder()
    vectors = list(embedder.embed([query]))
    if not vectors:
        raise ValueError("Failed to compute embedding for query.")

    # Convert numpy scalars to plain floats so the payload serialises cleanly.
    base_vector = [float(value) for value in vectors[0]]
    if not base_vector:
        # An empty vector can never be tiled up to target_dim.
        raise ValueError("Embedding model returned an empty vector.")

    # Repeat the vector until it reaches target_dim, then cut to the exact length.
    return list(islice(cycle(base_vector), target_dim))

STRUCTURED_DATA_ROOT = Path("/content/drive/MyDrive/Project-AI/structured-data-json")

def load_structured_json(case_id: str, root: Optional[Path] = None) -> Dict[str, Any]:
    """Load the full structured JSON record for a case.

    Reads ``<root>/<case_id>.json``; ``root`` defaults to ``STRUCTURED_DATA_ROOT``.
    The lookup is best-effort and returns ``{}`` when:

    * the id is not a plain file name (for example ``"../x"`` or ``"a/b"``),
    * the file is missing, unreadable or not valid JSON,
    * the file does not contain a JSON object.

    The file is parsed with the stdlib ``json`` module, imported locally. The
    previous ``yaml.safe_load`` depended on ``yaml`` having been imported by a
    later cell. If this ran first, the NameError was swallowed by the broad
    ``except`` and the function quietly returned ``{}``.
    """
    import json

    # case_id comes from the vector-store payload (possibly an int); refuse
    # anything that would resolve outside the data directory.
    name = str(case_id)
    if not name or Path(name).name != name:
        return {}
    json_path = (root or STRUCTURED_DATA_ROOT) / f"{name}.json"
    if not json_path.exists():
        return {}
    try:
        data = json.loads(json_path.read_text(encoding="utf-8"))
    except (OSError, ValueError):  # ValueError covers JSONDecodeError and UnicodeDecodeError
        return {}
    return data if isinstance(data, dict) else {}

def retrieve_references(query: str, top_k: Optional[int] = None) -> List[Dict[str, Any]]:
    """Search Qdrant for the points closest to ``query``.

    The query is embedded with ``embed_query`` and sent as a plain (unnamed)
    vector through whichever API the installed qdrant-client exposes:
    ``search``, then ``search_points``, then ``query_points``.

    Args:
        query: Free-text query.
        top_k: Number of hits. Falsy values fall back to ``rag_settings.rag_top_k``.

    Returns:
        One dict per hit with ``id``, ``score`` and ``payload`` (``{}`` when the
        hit has no payload). If the search call raises, the error is printed
        and ``[]`` is returned. Previously ``hits`` was left unbound in that
        case and an UnboundLocalError escaped.
    """
    top_k = top_k or rag_settings.rag_top_k
    vector = embed_query(query)
    client = get_qdrant_client()
    collection = rag_settings.qdrant_collection

    hits: List[Any] = []
    try:
        # Plain vector queries; NOTE(review): rag_settings.vector_name is not used here.
        if hasattr(client, "search"):
            hits = client.search(
                collection_name=collection,
                query_vector=vector,
                limit=top_k,
            )
        elif hasattr(client, "search_points"):
            hits = client.search_points(
                collection_name=collection,
                query_vector=vector,
                limit=top_k,
                with_payload=True,
            ).points
        else:
            hits = client.query_points(
                collection_name=collection,
                query=vector,
                limit=top_k,
                with_payload=True,
            ).points
    except Exception as e:
        print(f"Search error: {e}")
        hits = []
    return [
        {
            "id": hit.id,
            "score": hit.score,
            "payload": hit.payload or {},
        }
        for hit in hits
    ]



def format_references(references: List[Dict[str, Any]]) -> str:
    """Render retrieved references as Markdown sections for the system prompt.

    Each reference becomes a ``## Case <id>`` section, or ``## Reference <n>``
    when the payload has no ``case_id``. The heading gets a ``(score=…)`` suffix
    when a score is present. The body lists the case's structured JSON fields as
    ``**Title Cased Key**: value`` lines, or a placeholder when no JSON exists.
    Sections are separated by blank lines; an empty input gives ``""``.
    """
    if not references:
        return ""

    sections: List[str] = []
    for position, reference in enumerate(references, start=1):
        payload = reference.get("payload", {})
        case_id = payload.get("case_id")
        score = reference.get("score")

        case_json = load_structured_json(case_id) if case_id else {}
        if case_json:
            body = "\n".join(
                f"**{key.replace('_', ' ').title()}**: {value}" for key, value in case_json.items()
            )
        else:
            body = "No JSON data available."

        title = f"Case {case_id}" if case_id else f"Reference {position}"
        score_suffix = "" if score is None else f" (score={score:.3f})"
        sections.append(f"## {title}{score_suffix}\n{body}")

    return "\n\n".join(sections)



def gather_rag_context(query: Optional[str] = None, top_k: Optional[int] = None) -> tuple[str, List[Dict[str, Any]]]:
    """Retrieve references for ``query`` and return ``(formatted_text, raw_references)``.

    An empty or missing query skips retrieval and returns ``("", [])``.
    """
    if query:
        references = retrieve_references(query, top_k=top_k)
        return format_references(references), references
    return "", []


In [9]:
# @title
def build_system_content(base_instruction: str, *, context: str | None = None) -> List[Dict[str, str]]:
    content: List[Dict[str, str]] = [{"type": "text", "text": base_instruction}]
    if context:
        content.append(
            {
                "type": "text",
                "text": f"{rag_settings.context_prefix}\n\n{context}",
            }
        )
    return content


def build_user_content(prompt: str, *, image: Any | None = None) -> List[Dict[str, Any]]:
    content: List[Dict[str, Any]] = [{"type": "text", "text": prompt}]
    if image is not None:
        content.append({"type": "image", "image": image})
    return content


In [10]:
# NOTE(review): repeats the install from the first cell (transformers, psutil,
# torch, datasets); it is redundant when that cell has already run.
!pip install transformers psutil torch datasets



In [None]:
# Optional: interactively upload one image and open it with PIL.
# NOTE(review): this binds the notebook-global names `uploaded`, `image_path` and
# `image`; confirm no later cell picks up `image` unintentionally (e.g. via globals()).
from google.colab import files
from PIL import Image
uploaded = files.upload()  # Upload the image file from local system

# Get the uploaded file name (raises StopIteration if nothing was uploaded)
image_path = next(iter(uploaded))
image = Image.open(image_path)


# Prompt templates

In [11]:
# Install required packages
!pip install pyyaml pytest pytest-snapshot

import yaml
import os
from pathlib import Path

# Create directory structure
Path("prompts/medgemma").mkdir(parents=True, exist_ok=True)
# medgemma_prompt.yaml

# Template 1: General_Explain
general_explain = """
name: "general_explain"
description: "Template for general medical explanations"
system_prompt: |
  You are a medical educator providing explanatory information.
  - Explain medical concepts in accessible language. cite sources from the provided context using markers like [1], [2] at the end of relevant paragraphs.
  - Use analogies and examples when helpful
  - Cite sources for factual information, Provide accurate source citations in [number] format.
  - Ensure all references include clickable URLs
  - Distinguish between established facts and areas of uncertainty

  ### Format your response as:
  - **Question**: {question}
  - **Answer**: [Direct answer to the question, answer with citations format [number] at the end of cited/relevant paragraphs.]
  - **References**: [list all references and Another relevant source at the end, include available URL link, example format: [1] source.[https://example.com](https://example.com)

prompt_template: |
  <start_of_turn>user
  Explain the concept: {question}<end_of_turn>
  """
# Template 2: Clinical Q&A
clinical_qa = """
name: "clinical_qa"
description: "Answer a clinical question strictly using retrieved/attached sources."
system_prompt: |
  You are an expert in {specialty} communicating with {audience}. Read the clinical case and provide the single most likely diagnosis.
  ### Format your response as:
  **Predicted disease:**[Respond with most accurate diagnosis name]
  **Reason:**[Short explanation]

prompt_template: |
  You are diagnosing tropical diseases for a patient:
  Question: {user_prompt}
  Patient age: {age}
  Primary symptoms: {symptoms}
"""
# Template 3: Drug Info

drug_info= """
name: "drug_info"
description: "Summarize drug information grounded in the supplied sources."
system_prompt: |
  You are a pharmacology information specialist providing drug-related information.
  - Provide information only from verified sources, 1. Provide accurate inline citations in [number] format.
  - Include mechanism, indications, dosing, interactions, warnings
  - State if information is incomplete or from limited sources
  - Do not make clinical recommendations beyond documented guidelines


  ### Format your response as:
  -**Mechanism**: [How the drug works at molecular/physiological level]
  -**Approved Indications**: [FDA/regulatory approved uses]
  -**Dosing Guidance**: [Recommended dosing, adjustments, administration]
  -**Contraindications**: [Absolute and relative contraindications]
  -**References**: [list all References at the end]

prompt_template: |
  <start_of_turn>user
    Provide drug information for:{question}<end_of_turn>
"""

# Save all templates
templates = {
    "general_explain.yaml": general_explain,
    "drug_info.yaml": drug_info,
    "clinical_qa.yaml": clinical_qa
}

for filename, content in templates.items():
    with open(f"prompts/medgemma/{filename}", "w") as f:
        f.write(content)
    print(f"Created: prompts/medgemma/{filename}")

print("All templates created successfully!")

Collecting pytest-snapshot
  Downloading pytest_snapshot-0.9.0-py3-none-any.whl.metadata (9.5 kB)
Downloading pytest_snapshot-0.9.0-py3-none-any.whl (10 kB)
Installing collected packages: pytest-snapshot
Successfully installed pytest-snapshot-0.9.0
Created: prompts/medgemma/general_explain.yaml
Created: prompts/medgemma/drug_info.yaml
Created: prompts/medgemma/clinical_qa.yaml
All templates created successfully!


# MedGemma without RAG

In [12]:
# Mount Google Drive; the test cases and structured JSON used below live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [24]:
import os
import re
import time
import warnings
from functools import lru_cache
from typing import Any, Dict, List, Optional
import requests
import torch
from transformers import pipeline
from PIL import Image
from io import BytesIO
from pathlib import Path

# Hugging Face checkpoint served by this runner.
MEDGEMMA_MODEL_ID = "google/medgemma-4b-it"

# Initialize the MedGemma 4B pipeline in bfloat16, on the GPU when one is available.
# NOTE: no quantization is applied here — the `model_kwargs` / 4-bit config built
# in the configuration cell are not passed to this pipeline.
pipe = pipeline(
    "image-text-to-text",
    model=MEDGEMMA_MODEL_ID,
    dtype=torch.bfloat16,
    device="cuda" if torch.cuda.is_available() else "cpu",
)
# Configure dataset path and inference options
TEST_DATA_ROOT = Path("/content/drive/MyDrive/Data/Test-Data/extracted-data")
# NOTE(review): retrieval is not implemented by the no-RAG runner below; this
# flag is only forwarded as its `use_rag` default.
USE_RAG = False  # Set True to enable retrieval with the original case text
MAX_CASES = None  # Set to an integer to cap processed cases

# Small seq2seq model, and the default instruction it receives, used to extract
# keywords / age / symptoms. Both can be overridden through environment variables.
KEYWORD_MODEL_DEFAULT = os.getenv("RAG_KEYWORD_MODEL", "google/flan-t5-small")
KEYWORD_PROMPT_DEFAULT = os.getenv(
    "RAG_KEYWORD_PROMPT",
    (
        "Extract up to {max_keywords} concise clinical keywords or short phrases "
        "from the following request. Return only a comma-separated list without numbering.\n"
        "Request: {prompt}"
    ),
)

@lru_cache(maxsize=4)
def get_keyword_extractor(model_id: str = KEYWORD_MODEL_DEFAULT):
    """Return a cached ``text2text-generation`` pipeline for ``model_id``.

    Up to four different models stay cached. The pipeline runs on GPU 0 when
    CUDA is available, and on the CPU (device -1) otherwise.
    """
    return pipeline(
        "text2text-generation",
        model=model_id,
        device=0 if torch.cuda.is_available() else -1,
    )

def extract_keywords_with_llm(
    prompt_text: str,
    *,
    max_keywords: int = 6,
    model_id: Optional[str] = None,
    instruction_template: Optional[str] = None,
    **generate_kwargs: Any,
) -> List[str]:
    """Use a small seq2seq model to extract search-oriented keywords.

    The instruction template is formatted with ``max_keywords``, ``prompt`` and
    ``user_prompt``, where both prompt names hold the stripped text. The
    generated text is processed in three steps:

    1. split on newlines, commas, semicolons and the word "and";
    2. strip surrounding bullet/punctuation characters from each item;
    3. de-duplicate case-insensitively, keeping the first spelling.

    Args:
        prompt_text: Text to analyse. Blank or missing text yields ``[]``.
        max_keywords: Maximum number of keywords returned. ``<= 0`` yields
            ``[]``; previously one keyword was still returned.
        model_id: Seq2seq model to use (defaults to ``KEYWORD_MODEL_DEFAULT``).
        instruction_template: ``str.format`` template for the instruction
            (defaults to ``KEYWORD_PROMPT_DEFAULT``).
        **generate_kwargs: Extra generation options. They override the defaults
            ``max_new_tokens=64`` and ``num_return_sequences=1``; previously,
            passing either one raised a duplicate-keyword TypeError.

    Returns:
        At most ``max_keywords`` keywords, in generation order.
    """
    if max_keywords <= 0 or not prompt_text or not prompt_text.strip():
        return []
    cleaned_prompt = prompt_text.strip()
    instruction = (instruction_template or KEYWORD_PROMPT_DEFAULT).format(
        max_keywords=max_keywords,
        prompt=cleaned_prompt,
        user_prompt=cleaned_prompt,
    )
    extractor = get_keyword_extractor(model_id or KEYWORD_MODEL_DEFAULT)
    # Caller-supplied generation options take precedence over these defaults.
    generation_options: Dict[str, Any] = {
        "max_new_tokens": 64,
        "num_return_sequences": 1,
        **generate_kwargs,
    }
    response = extractor(instruction, **generation_options)

    generated = ""
    if response and isinstance(response, list) and isinstance(response[0], dict):
        candidate = response[0]
        generated = candidate.get("generated_text") or candidate.get("text") or ""

    raw_tokens = re.split(r"[\n,;]|\band\b", generated)
    seen: set[str] = set()
    keywords: List[str] = []
    for token in raw_tokens:
        # Drop whitespace/punctuation and common bullet glyphs (•, ‣, private-use bullet).
        cleaned = token.strip("\t :.-\u2022\u2023\uf0b7")
        if not cleaned:
            continue
        lowered = cleaned.lower()
        if lowered in seen:
            continue
        seen.add(lowered)
        keywords.append(cleaned)
        if len(keywords) >= max_keywords:
            break
    return keywords

def render_text_template(
    base_text: Optional[str],
    *,
    template: Optional[str] = None,
    variables: Optional[Dict[str, Any]] = None,
    field_name: str = "prompt",
    defaults: Optional[Dict[str, Any]] = None,
) -> str:
    """Produce the final text for a prompt-like field.

    With no ``template``, the plain ``base_text`` is returned. Otherwise the
    template is ``str.format``-ed with ``prompt``/``user_prompt`` (both set to
    ``base_text``), then ``defaults``, then ``variables``; later sources override
    earlier ones. A whitespace-only rendering falls back to ``base_text``.

    Raises:
        ValueError: if the template references an unknown variable, or if both
            the rendering and ``base_text`` are blank.
    """
    text_value = base_text or ""
    if not template:
        if text_value.strip():
            return text_value
        raise ValueError(f"Provide a {field_name} or {field_name} template.")

    format_context: Dict[str, Any] = {
        "prompt": text_value,
        "user_prompt": text_value,
        **(defaults or {}),
        **(variables or {}),
    }
    try:
        rendered = template.format(**format_context)
    except KeyError as err:
        unknown = err.args[0] if err.args else "unknown"
        raise ValueError(f"{field_name} template expects variable '{unknown}'.") from err

    for candidate in (rendered, text_value):
        if candidate.strip():
            return candidate
    raise ValueError(f"{field_name.title()} template rendered an empty string.")

def load_image(image_url: Optional[str] = None, image_path: Optional[str] = None, *, timeout: float = 30.0):
    """Load an image from a URL or from disk.

    The URL takes precedence when both are given. Each image is integrity-checked
    with ``verify()`` and then re-opened, because Pillow requires a re-open after
    ``verify()``.

    Args:
        image_url: HTTP(S) URL to download.
        image_path: Local file path (str or path-like).
        timeout: Seconds to wait for the HTTP request. This is a new option;
            previously the request could block indefinitely.

    Returns:
        A PIL image on success; ``{"error": "Failed to load image", "message": ...}``
        on any failure; ``None`` when neither source is given.
    """
    if image_url:
        try:
            response = requests.get(image_url, timeout=timeout)
            response.raise_for_status()
            payload = response.content
            Image.open(BytesIO(payload)).verify()  # quick integrity check
            return Image.open(BytesIO(payload))  # reopen after verify
        except Exception as exc:
            return {"error": "Failed to load image", "message": str(exc)}
    if image_path:
        try:
            # The `with` block closes the probe's file handle after verification.
            with Image.open(image_path) as probe:
                probe.verify()
            image_obj = Image.open(image_path)
            # Decode now; for single-frame images Pillow then releases the file handle.
            image_obj.load()
            return image_obj
        except Exception as exc:
            return {"error": "Failed to load image", "message": str(exc)}
    return None

def find_case_images(case_dir: Path) -> list[Path]:
    """List image files in ``case_dir`` and in its optional ``images/`` subfolder.

    Extensions are matched case-insensitively, so ``.jpg``, ``.JPG`` and ``.Jpg``
    all count, and only regular files are returned. Each directory is scanned
    once. Globbing ``*.jpg`` and ``*.JPG`` separately produced duplicate entries
    on case-insensitive filesystems and missed mixed-case suffixes.

    Returns:
        Sorted, duplicate-free list of image paths. It is empty when the
        directory is missing or contains no images.
    """
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif', '.webp'}
    found: set[Path] = set()
    # Main case folder first, then the optional 'images' folder.
    for directory in (case_dir, case_dir / "images"):
        if not directory.is_dir():
            continue
        for entry in directory.iterdir():
            if entry.is_file() and entry.suffix.lower() in image_extensions:
                found.add(entry)
    return sorted(found)


def load_case_images(case_dir: Path, max_images: Optional[int] = None) -> list[Image.Image]:
    """Open and decode the images that ``find_case_images`` finds for one case.

    Each file is first checked with ``verify()`` inside a ``with`` block, which
    closes the probe's handle. It is then re-opened and decoded eagerly, so the
    returned single-frame images no longer hold their files open. Files that
    fail either step are reported and skipped.

    Args:
        case_dir: Case folder to scan.
        max_images: Stop after this many images have loaded successfully
            (``None`` loads all). Callers that only use the first image can
            pass 1 to avoid decoding the rest.

    Returns:
        Loaded PIL images, in ``find_case_images`` order.
    """
    loaded_images: list[Image.Image] = []
    for img_path in find_case_images(case_dir):
        if max_images is not None and len(loaded_images) >= max_images:
            break
        try:
            with Image.open(img_path) as probe:
                probe.verify()  # integrity check; the probe cannot be used afterwards
            image = Image.open(img_path)
            image.load()
            loaded_images.append(image)
        except Exception as e:
            print(f"Failed to load image {img_path.name}: {e}")
    return loaded_images

def build_system_content(system_instruction: str, context: str = "") -> List[Dict[str, Any]]:
    """Build the system-turn content for MedGemma.

    Returns the instruction as a text part. When ``context`` is non-empty, a
    second part ``"Context: <context>"`` is added.
    NOTE(review): this redefines the earlier ``build_system_content``, which took
    a keyword-only ``context`` and used the RAG context prefix.
    """
    parts: List[Dict[str, Any]] = [{"type": "text", "text": system_instruction}]
    parts += [{"type": "text", "text": f"Context: {context}"}] if context else []
    return parts

def build_user_content(prompt: str, image: Optional[Image.Image] = None) -> List[Dict[str, Any]]:
    """Build the user-turn content for MedGemma: the prompt text plus the image, if any.

    The image is attached whenever it is not ``None``. The earlier ``if image:``
    relied on object truthiness, which is ambiguous for array-like inputs and
    inconsistent with the first ``build_user_content`` definition.
    """
    content: List[Dict[str, Any]] = [{"type": "text", "text": prompt}]
    if image is not None:
        content.append({"type": "image", "image": image})
    return content

def _content_to_text(content: Any) -> str:
    """Flatten chat-message content (a string or a list of typed parts) into plain text."""
    if content is None:
        return "No content found."
    if isinstance(content, list):
        texts = [
            str(part.get("text", ""))
            for part in content
            if isinstance(part, dict) and part.get("type") == "text"
        ]
        return "\n".join(text for text in texts if text)
    return content if isinstance(content, str) else str(content)


def extract_answer_from_response(response) -> str:
    """Extract the assistant's answer text from a MedGemma pipeline response.

    Supported shapes:

    * ``[{"generated_text": [...chat messages...]}]`` gives the content of the
      last ``assistant`` message, or else of the last message with ``content``;
    * ``[{"generated_text": "text"}]`` gives the text;
    * ``["text"]`` gives the text.

    Message content given as a list of ``{"type": "text", "text": ...}`` parts
    is joined with newlines. An empty or ``None`` response gives
    ``"No answer."``. Anything else, including an empty ``generated_text`` list
    (which previously raised IndexError), falls back to ``str(response)``.
    """
    if not response:
        return "No answer."

    if isinstance(response, list):
        first = response[0]
        if isinstance(first, str):
            return first
        if isinstance(first, dict) and "generated_text" in first:
            generated_text = first["generated_text"]
            if isinstance(generated_text, str):
                return generated_text
            if isinstance(generated_text, list) and generated_text:
                # Prefer the most recent assistant turn.
                for msg in reversed(generated_text):
                    if isinstance(msg, dict) and msg.get("role") == "assistant":
                        return _content_to_text(msg.get("content"))
                last_msg = generated_text[-1]
                if isinstance(last_msg, dict) and "content" in last_msg:
                    return _content_to_text(last_msg["content"])

    # Fallback: convert to string
    return str(response)

def _extract_predicted_disease(answer: str) -> str:
    if not answer:
        return ""
    match = re.search(r"Predicted disease:\s*(.+)", answer, flags=re.IGNORECASE)
    if match:
        return match.group(1).strip()
    return answer.strip()

def run_medgemma(
    prompt: Optional[str] = None,
    *,
    prompt_template: Optional[str] = None,
    prompt_variables: Optional[Dict[str, Any]] = None,
    image: Optional[Image.Image] = None,
    system_instruction: str = "You are an expert in tropical diseases.",
    system_instruction_template: Optional[str] = None,
    system_instruction_variables: Optional[Dict[str, Any]] = None,
    max_new_tokens: int = 2000,
    temperature: float = 0.0,
    **generate_kwargs: Any,
) -> Dict[str, Any]:
    """Generate an answer with MedGemma without RAG context.

    Both the prompt and the system instruction are rendered with
    ``render_text_template``; the system template also sees ``final_prompt``.
    They are sent, plus ``image`` when it is a PIL image, to the global ``pipe``.

    Returns:
        Dict with these keys:

        * ``answer``: the pipeline's full ``generated_text`` (chat transcript);
        * ``model_response``: the assistant's reply text;
        * ``predicted``: the diagnosis parsed from the reply;
        * ``latency``: wall-clock seconds;
        * ``error``: inference error message, or ``None``;
        * ``resolved_prompt`` and ``resolved_system_instruction``;
        * ``system_content`` and ``user_content``: the message parts sent.

    Raises:
        ValueError: from template rendering (unknown variable or empty prompt).
            Inference failures are reported in ``error`` instead of raised.
    """
    start_time = time.time()
    output: Dict[str, Any] = {
        "answer": None,
        "model_response": None,
        "predicted": None,
        "latency": None,
        "error": None,
        "resolved_prompt": None,
        "resolved_system_instruction": None,
        "system_content": None,
        "user_content": None,
    }

    # Render final prompt
    final_prompt = render_text_template(
        prompt,
        template=prompt_template,
        variables=prompt_variables,
        field_name="prompt",
    )
    output["resolved_prompt"] = final_prompt

    # Render system instruction
    final_system_instruction = render_text_template(
        system_instruction,
        template=system_instruction_template,
        variables=system_instruction_variables,
        field_name="system instruction",
        defaults={"final_prompt": final_prompt},
    )
    output["resolved_system_instruction"] = final_system_instruction

    # Build messages for MedGemma
    system_content = build_system_content(final_system_instruction)
    user_content = build_user_content(
        final_prompt,
        image=image if isinstance(image, Image.Image) else None,
    )
    messages = [
        {"role": "system", "content": system_content},
        {"role": "user", "content": user_content},
    ]

    try:
        response = pipe(
            messages,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            **generate_kwargs,
        )
        # Assistant reply text, extracted defensively from the chat output.
        model_response = extract_answer_from_response(response)
        output["model_response"] = model_response
        # Full transcript (rendered by pretty_print_result).
        output["answer"] = response[0].get("generated_text") if response else None
        output["predicted"] = _extract_predicted_disease(model_response)
        output["system_content"] = system_content
        # Assign, not look up: a bare `output["user_content"]` raised KeyError and
        # flagged every successful run as an inference error.
        output["user_content"] = user_content
    except Exception as exc:
        output["error"] = f"Model inference error: {exc}"
    finally:
        output["latency"] = time.time() - start_time

    return output

def _format_chat_like_answer(answer: Any) -> str:
    """Convert chat-style answers into a readable text block."""
    if answer is None:
        return ""
    if isinstance(answer, str):
        return answer
    if not isinstance(answer, list):
        return str(answer)

    formatted_lines: List[str] = []
    for message in answer:
        if not isinstance(message, dict):
            formatted_lines.append(str(message))
            continue
        role = str(message.get("role", "assistant")).strip() or "assistant"
        content = message.get("content")
        segments: List[str] = []
        if isinstance(content, list):
            for chunk in content:
                if not isinstance(chunk, dict):
                    segments.append(str(chunk))
                    continue
                chunk_type = chunk.get("type")
                if chunk_type == "text":
                    text = chunk.get("text")
                    if text:
                        segments.append(str(text))
                elif chunk_type == "image":
                    segments.append("[Image provided]")
        elif content is not None:
            segments.append(str(content))

        body = "\n".join(seg for seg in segments if seg).strip()
        if body:
            header = role.title()
            formatted_lines.append(f"{header}:\n{body}")
    return "\n\n".join(formatted_lines).strip()

def pretty_print_result(result: Dict[str, Any] | List[Dict[str, Any]]) -> None:
    """Print a single result dict, or each dict of a list under a ``Case ID`` banner."""
    if not isinstance(result, list):
        _print_single_result(result)
        return

    banner = "=" * 50
    for case_result in result:
        print(f"\n{banner}")
        print(f"Case ID: {case_result.get('case_id', 'Unknown')}")
        print(banner)
        _print_single_result(case_result)

def _print_single_result(result: Dict[str, Any]) -> None:
    """Print a single case result."""
    if result.get("error"):
        print("Error:", result["error"])
        return
    if result.get("resolved_prompt"):
        print("Prompt (resolved):\n", result["resolved_prompt"])
    answer_text = _format_chat_like_answer(result.get("answer"))
    print("\nAnswer:\n", answer_text)
    print(f"\nLatency: {result.get('latency', 0.0):.2f}s")
    if result.get("predicted"):
        print(f"\nPredicted disease: {result.get('predicted')}")


# Explicit age phrases only: "53-year-old", "53 years old", "53yo", "53 y/o",
# "aged 53", "age: 53". Bare durations such as "3 years of cough" do not match,
# so they no longer become the patient's age.
_AGE_PATTERN = re.compile(
    r"(\d{1,3})\s*-?\s*(?:years?|yrs?)\s*-?\s*old"
    r"|(\d{1,3})\s*(?:yo|y/o|y\.o\.)(?![a-z])"
    r"|\baged?\s*:?\s*(\d{1,3})\b"
)


def derive_prompt_variables(
    *,
    user_prompt: str,
    symptom_limit: int = 12,
    age_model_id: Optional[str] = None,
    symptom_model_id: Optional[str] = None,
) -> Dict[str, Any]:
    """Derive the ``age`` and ``symptoms`` template variables from a case description.

    Age: read from an explicit age phrase (see ``_AGE_PATTERN``). If none is
    found, the keyword LLM is asked for the age and the first 1-3 digit number
    in its reply is used.
    Symptoms: up to ``symptom_limit`` items extracted by the keyword LLM, joined
    with ``", "``.

    Args:
        user_prompt: Free-text case description.
        symptom_limit: Maximum number of symptoms to extract.
        age_model_id: Optional model override for the age fallback.
        symptom_model_id: Optional model override for symptom extraction.

    Returns:
        ``{}`` for a blank prompt. Otherwise a dict with ``age`` (an int, or
        ``"unknown"``) and ``symptoms`` (a string, or ``"unspecified"``).
    """
    if not user_prompt or not user_prompt.strip():
        return {}
    variables: Dict[str, Any] = {}

    # Approach 1: regex for an explicit age phrase.
    age_match = _AGE_PATTERN.search(user_prompt.lower())
    if age_match:
        # Exactly one alternative's group participated in the match.
        variables["age"] = int(next(group for group in age_match.groups() if group))
    else:
        # Approach 2: fall back to the LLM when the regex finds nothing.
        age_instruction = (
            "Extract the patient's age from the statement. "
            "Respond with a single number if the age is given, otherwise respond with NONE.\n"
            "Patient statement: {prompt}"
        )
        age_candidates = extract_keywords_with_llm(
            user_prompt,
            max_keywords=1,
            model_id=age_model_id,
            instruction_template=age_instruction,
        )
        for token in age_candidates:
            digit_match = re.search(r"\d{1,3}", token)
            if digit_match:
                variables["age"] = int(digit_match.group())
                break

    symptom_instruction = (
        "List up to {max_keywords} concise clinical symptoms or findings mentioned by the patient. "
        "Return a comma-separated list without numbering.\n"
        "Patient statement: {prompt}"
    )
    symptoms = extract_keywords_with_llm(
        user_prompt,
        max_keywords=symptom_limit,
        model_id=symptom_model_id,
        instruction_template=symptom_instruction,
    )
    if symptoms:
        variables["symptoms"] = ", ".join(symptoms)
    # Placeholders keep the prompt template renderable when extraction fails.
    variables.setdefault("age", "unknown")
    variables.setdefault("symptoms", "unspecified")
    return variables

def run_diagnosis_testcase(*, use_rag: bool = USE_RAG, max_cases: int | None = MAX_CASES):
    """Run MedGemma, without retrieval, over every case folder in ``TEST_DATA_ROOT``.

    A case is a sub-directory that contains ``query.txt``; ``ground_truth.txt``
    and image files are optional. The ``clinical_qa`` template is loaded once
    from ``/content/prompts/medgemma/clinical_qa.yaml``. The first loadable
    image of each case is attached, and no image is attached when the case has
    none.

    Args:
        use_rag: Retrieval is not implemented in this runner. ``True`` only
            emits a warning.
        max_cases: Stop after this many processed cases (``None`` processes all).

    Returns:
        One dict per case with ``case_id``, ``predicted_disease``,
        ``model_response``, ``answer``, ``ground_truth`` and ``error``
        (the inference error message, or ``None``).

    Raises:
        FileNotFoundError: if ``TEST_DATA_ROOT`` does not exist.
    """
    import yaml  # local import: the module-level import lives in another cell

    if not TEST_DATA_ROOT.exists():
        raise FileNotFoundError(f"Test data directory not found: {TEST_DATA_ROOT}")
    if use_rag:
        warnings.warn(
            "use_rag=True is not implemented by run_diagnosis_testcase; running without retrieval.",
            stacklevel=2,
        )

    # The template is the same for every case, so it is loaded once.
    prompt_type = "clinical_qa"
    with open(f'/content/prompts/medgemma/{prompt_type}.yaml', 'r', encoding="utf-8") as file:
        config = yaml.safe_load(file)
    prompt_template = config['prompt_template']
    system_instruction = config['system_prompt']

    results: list[dict[str, Any]] = []
    case_counter = 0

    print(f"Scanning for cases in: {TEST_DATA_ROOT}")

    for case_dir in sorted(TEST_DATA_ROOT.iterdir()):
        if max_cases is not None and case_counter >= max_cases:
            break
        if not case_dir.is_dir():
            continue

        query_path = case_dir / "query.txt"
        if not query_path.exists():
            continue

        # Create prompt variables (age, symptoms) from the case text.
        user_prompt = query_path.read_text(encoding="utf-8").strip()
        prompt_variables = derive_prompt_variables(
            user_prompt=user_prompt,
            symptom_limit=12,
        )

        # load_case_images already returns PIL images, so the first one is
        # attached directly; passing it on to load_image (which expects a URL or
        # path) always failed. Only this case's own images are used; a
        # notebook-global `image` would attach the same picture to every case.
        case_images = load_case_images(case_dir)
        case_image = case_images[0] if case_images else None

        # Run MedGemma without RAG
        response = run_medgemma(
            prompt=user_prompt,
            prompt_template=prompt_template,
            prompt_variables=prompt_variables,
            image=case_image,
            system_instruction_template=system_instruction,
            system_instruction_variables={
                "specialty": "tropical pulmonary diseases",
                "audience": "clinicians",
                "test_id": case_dir.name,
            },
        )

        # Read ground truth
        ground_truth_path = case_dir / "ground_truth.txt"
        ground_truth = (
            ground_truth_path.read_text(encoding="utf-8").strip()
            if ground_truth_path.exists()
            else ""
        )

        results.append(
            {
                "case_id": case_dir.name,
                "predicted_disease": response.get("predicted"),
                "model_response": response.get("model_response"),
                "answer": response.get("answer"),
                "ground_truth": ground_truth,
                # Keep inference failures visible in printed/exported results.
                "error": response.get("error"),
            }
        )

        case_counter += 1

    print(f"\nProcessed {case_counter} cases.")
    return results


## Main
# Run the no-RAG evaluation over every test case, then print each case's result.
testcase_results = run_diagnosis_testcase()
pretty_print_result(testcase_results)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Device set to use cuda


Scanning for cases in: /content/drive/MyDrive/Data/Test-Data/extracted-data


Device set to use cuda:0
You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset



Processed 33 cases.

Case ID: 001

Answer:
 System:
You are an expert in tropical pulmonary diseases communicating with clinicians. Read the clinical case and provide the single most likely diagnosis.
### Format your response as:
**Predicted disease:**[Respond with most accurate diagnosis name]
**Reason:**[Short explanation]

User:
You are diagnosing tropical diseases for a patient:
Question: A 53-year-old woman presented with fever, cough, and malaise after returning from a visit to Lahore. On examination, her temperature was 38°C and she had a rash on her upper chest. A chest X-ray showed patchy basal consolidation and a full blood count revealed a relative lymphocytosis. Malaria films were negative. Blood cultures were drawn and later grew gram-negative bacilli.
Patient age: 53
Primary symptoms: List of symptoms of fever, cough, malaise

Assistant:
**Predicted disease:** Tuberculosis
**Reason:** The patient's symptoms (fever, cough, malaise), chest X-ray findings (patchy basal cons

In [25]:
# SAVE TO EXCEL
import pandas as pd

# Destination workbook (writing .xlsx needs an Excel engine such as openpyxl).
output_path = "/content/testcase_medgemma_without_rag.xlsx"
# One row per test case; the columns are the keys of the result dicts.
df = pd.DataFrame(testcase_results)
df.to_excel(output_path, index=False)
print(f"\nExcel saved to: {output_path}")


Excel saved to: /content/testcase_medgemma_without_rag.xlsx


# MedGemma with RAG

In [None]:
import os
import re
import time
import warnings
import yaml
from functools import lru_cache
from typing import Any, Dict, List, Optional, Union
import requests
import torch
from transformers import pipeline
from PIL import Image
from io import BytesIO
from pathlib import Path

# Initialize the MedGemma 4B instruction-tuned model with bfloat16 weights
# (no quantization is configured here), on CUDA when available, else CPU.
# NOTE(review): the right-hand side is built before any previous `pipe` is rebound,
# so re-running this cell briefly holds two models in memory; free the old one first.
pipe = pipeline(
    "image-text-to-text",
    model="google/medgemma-4b-it",
    dtype=torch.bfloat16,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# Configure dataset path and inference options
TEST_DATA_ROOT = Path("/content/drive/MyDrive/Data/Test-Data/extracted-data")
USE_RAG = False  # Default for run_diagnosis_testcase(use_rag=...); True routes cases through run_medgemma_rag
MAX_CASES = None  # Set to an integer to cap processed cases

# Keyword-extraction model and instruction used to build retrieval queries.
# Both can be overridden with the RAG_KEYWORD_MODEL / RAG_KEYWORD_PROMPT environment
# variables; the prompt is a str.format template filled with {max_keywords} and {prompt}.
KEYWORD_MODEL_DEFAULT = os.getenv("RAG_KEYWORD_MODEL", "google/flan-t5-small")
KEYWORD_PROMPT_DEFAULT = os.getenv(
    "RAG_KEYWORD_PROMPT",
    (
        "Extract up to {max_keywords} concise clinical keywords or short phrases "
        "from the following request. Return only a comma-separated list without numbering.\n"
        "Request: {prompt}"
    ),
)

@lru_cache(maxsize=4)
def get_keyword_extractor(model_id: str = KEYWORD_MODEL_DEFAULT):
    """Return a text2text-generation pipeline for ``model_id``, memoised per model id.

    Up to four distinct models are kept cached. The pipeline is placed on the first
    CUDA device when one is available, otherwise on CPU (device ``-1``).
    """
    target_device = -1
    if torch.cuda.is_available():
        target_device = 0
    return pipeline("text2text-generation", model=model_id, device=target_device)

def extract_keywords_with_llm(
    prompt_text: str,
    *,
    max_keywords: int = 6,
    model_id: Optional[str] = None,
    instruction_template: Optional[str] = None,
    **generate_kwargs: Any,
) -> List[str]:
    """Use a small seq2seq model to extract search-oriented keywords.

    Args:
        prompt_text: Free text to mine. Blank input returns ``[]`` without loading a model.
        max_keywords: Upper bound on returned items; ``<= 0`` returns ``[]``.
        model_id: Model id passed to ``get_keyword_extractor``; defaults to
            ``KEYWORD_MODEL_DEFAULT``.
        instruction_template: ``str.format`` template that may use ``{max_keywords}``,
            ``{prompt}`` and ``{user_prompt}``; defaults to ``KEYWORD_PROMPT_DEFAULT``.
        **generate_kwargs: Extra pipeline/generation options. They override the built-in
            ``max_new_tokens=64`` / ``num_return_sequences=1`` instead of colliding with them.

    Returns:
        Up to ``max_keywords`` keywords in generation order, de-duplicated case-insensitively.
    """
    if not prompt_text or not prompt_text.strip() or max_keywords <= 0:
        return []
    cleaned_prompt = prompt_text.strip()
    chosen_model = model_id or KEYWORD_MODEL_DEFAULT
    instruction = (instruction_template or KEYWORD_PROMPT_DEFAULT).format(
        max_keywords=max_keywords,
        prompt=cleaned_prompt,
        user_prompt=cleaned_prompt,
    )
    extractor = get_keyword_extractor(chosen_model)
    # Merge defaults with caller options: passing both as separate keyword arguments
    # raised "got multiple values for keyword argument" when a caller overrode one.
    call_kwargs: Dict[str, Any] = {"max_new_tokens": 64, "num_return_sequences": 1}
    call_kwargs.update(generate_kwargs)
    response = extractor(instruction, **call_kwargs)

    generated = ""
    if response and isinstance(response, list):
        candidate = response[0]
        if isinstance(candidate, dict):
            generated = candidate.get("generated_text") or candidate.get("text") or ""
    # Split on newlines, commas, semicolons and the standalone word "and".
    raw_tokens = re.split(r"[\n,;]|\band\b", str(generated))
    seen: set[str] = set()
    keywords: List[str] = []
    for token in raw_tokens:
        # Trim whitespace, separators and common bullet glyphs around each item.
        cleaned = token.strip("\t :.-\u2022\u2023\uf0b7")
        if not cleaned:
            continue
        lowered = cleaned.lower()
        if lowered in seen:
            continue
        seen.add(lowered)
        keywords.append(cleaned)
        if len(keywords) >= max_keywords:
            break
    return keywords

def render_text_template(
    base_text: Optional[str],
    *,
    template: Optional[str] = None,
    variables: Optional[Dict[str, Any]] = None,
    field_name: str = "prompt",
    defaults: Optional[Dict[str, Any]] = None,
) -> str:
    """Fill ``template`` from the base text and extra variables, or fall back to the text.

    ``{prompt}`` and ``{user_prompt}`` both resolve to ``base_text``; ``defaults`` are
    applied next and ``variables`` last, so explicit variables take precedence.

    Returns:
        The rendered template when it is non-blank, otherwise ``base_text`` when that
        is non-blank.

    Raises:
        ValueError: a placeholder has no value, or nothing non-blank is available.
    """
    source_text = base_text or ""
    has_text = bool(source_text.strip())

    if not template:
        if has_text:
            return source_text
        raise ValueError(f"Provide a {field_name} or {field_name} template.")

    context: Dict[str, Any] = dict(prompt=source_text, user_prompt=source_text)
    for overrides in (defaults, variables):
        if overrides:
            context.update(overrides)

    try:
        rendered = template.format(**context)
    except KeyError as missing:
        key = missing.args[0] if missing.args else "unknown"
        raise ValueError(f"{field_name} template expects variable '{key}'.") from missing

    if rendered.strip():
        return rendered
    if has_text:
        return source_text
    raise ValueError(f"{field_name.title()} template rendered an empty string.")

def load_image(image_url: Optional[str] = None, image_path: Optional[str] = None, timeout: float = 30.0):
    """Load an image from a URL or disk, returning the PIL image or an error dict.

    The URL takes precedence when both are given. The image is first validated with
    ``verify()`` on a throw-away handle, which is closed explicitly; a fresh image
    object is returned because a verified image cannot be used afterwards.

    Args:
        image_url: HTTP(S) URL to download.
        image_path: Local file path.
        timeout: Seconds to wait for the HTTP request; without one, ``requests``
            can block indefinitely on an unresponsive server.

    Returns:
        A PIL image, ``{"error": "Failed to load image", "message": ...}`` on any
        failure, or None when neither source is given.
    """
    if image_url:
        try:
            response = requests.get(image_url, timeout=timeout)
            response.raise_for_status()
            payload = response.content
            with Image.open(BytesIO(payload)) as probe:
                probe.verify()
            return Image.open(BytesIO(payload))
        except Exception as exc:
            return {"error": "Failed to load image", "message": str(exc)}
    if image_path:
        try:
            # Image.open keeps the file open lazily; `with` closes the probe handle.
            with Image.open(image_path) as probe:
                probe.verify()
            return Image.open(image_path)
        except Exception as exc:
            return {"error": "Failed to load image", "message": str(exc)}
    return None

def find_case_images(case_dir: Path) -> list[Path]:
    """Return the image files in ``case_dir`` and its optional ``images/`` subfolder.

    Extensions are compared case-insensitively, so ``.jpg``, ``.JPG`` and ``.Jpg`` all
    count. Each file appears once, even on case-insensitive filesystems where the old
    lower/upper glob pair matched the same file twice. Directories are never returned.
    A missing ``case_dir`` yields ``[]``.

    Returns:
        Sorted list of image paths.
    """
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif', '.webp'}
    if not case_dir.is_dir():
        return []

    search_dirs = [case_dir]
    images_dir = case_dir / "images"
    if images_dir.is_dir():
        search_dirs.append(images_dir)

    found = {
        path
        for directory in search_dirs
        for path in directory.iterdir()
        if path.suffix.lower() in image_extensions and path.is_file()
    }
    return sorted(found)

def load_case_images(case_dir: Path) -> list[Image.Image]:
    """Open every image found by ``find_case_images`` for a case folder.

    Each file is validated with ``verify()`` on a probe handle that is closed via
    ``with`` (Image.open is lazy and keeps the file open otherwise). The file is then
    reopened, because a verified image object cannot be used further. Files that fail
    to open or verify are reported on stdout and skipped.

    Returns:
        Lazily-loaded PIL images, in the order returned by ``find_case_images``.
    """
    loaded_images = []
    for img_path in find_case_images(case_dir):
        try:
            with Image.open(img_path) as probe:
                probe.verify()
            loaded_images.append(Image.open(img_path))
        except Exception as e:
            print(f"Failed to load image {img_path.name}: {e}")
    return loaded_images

def build_system_content(system_instruction: str, context: str = "") -> List[Dict[str, Any]]:
    """Assemble the system-turn text chunks: the instruction, then ``Context: ...`` if given."""
    parts = [system_instruction]
    if context:
        parts.append(f"Context: {context}")
    return [{"type": "text", "text": part} for part in parts]

def build_user_content(prompt: str, image: Optional[Image.Image] = None) -> List[Dict[str, Any]]:
    """Assemble the user-turn chunks: the prompt text, followed by the image when truthy."""
    text_chunk = {"type": "text", "text": prompt}
    if not image:
        return [text_chunk]
    return [text_chunk, {"type": "image", "image": image}]

def _content_to_text(content: Any) -> str:
    """Flatten a chat message ``content`` (a str or a list of typed chunks) to text.

    For a list, the first ``{"type": "text"}`` chunk's text is returned ("No text found"
    when that chunk has no ``text`` key); any other value is ``str()``-ed.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        for chunk in content:
            if isinstance(chunk, dict) and chunk.get("type") == "text":
                return chunk.get("text", "No text found")
    return str(content)


def extract_answer_from_response(response) -> str:
    """Extract the answer text from MedGemma's response structure.

    Handles:
    - chat-style output ``[{"generated_text": [messages...]}]``, using the last
      assistant message, or the last message when there is no assistant turn;
    - a plain ``generated_text`` string;
    - a bare string item.

    Returns:
        "No answer." for an empty/None response; the located text otherwise. Falls back
        to ``str(response)``, including for an empty message list (which previously
        raised IndexError).
    """
    if not response:
        return "No answer."

    if isinstance(response, list):
        first = response[0]
        if isinstance(first, dict) and "generated_text" in first:
            generated_text = first["generated_text"]
            if isinstance(generated_text, str):
                return generated_text
            if isinstance(generated_text, list) and generated_text:
                for msg in reversed(generated_text):
                    if isinstance(msg, dict) and msg.get("role") == "assistant":
                        return _content_to_text(msg.get("content", ""))
                # No assistant turn: use the last message, flattened so a chunk list
                # is never returned where a str is promised.
                last_msg = generated_text[-1]
                if isinstance(last_msg, dict) and "content" in last_msg:
                    return _content_to_text(last_msg["content"])
        elif isinstance(first, str):
            return first

    return str(response)

def _extract_predicted_disease(answer: str) -> str:
    """Extract predicted disease from answer text.

    Searches, in order and case-insensitively, for lines labelled "Predicted disease",
    "Diagnosis", "Condition" or "Disease". Markdown emphasis around the label or value
    is tolerated, so the requested ``**Predicted disease:** Tuberculosis`` format (and
    ``**Predicted disease**: ...``) yields "Tuberculosis" rather than "** Tuberculosis".

    Returns:
        "" for empty input; the first non-empty labelled value; otherwise the whole
        stripped answer.
    """
    if not answer:
        return ""

    labels = ("Predicted disease", "Diagnosis", "Condition", "Disease")
    for label in labels:
        # [*_\s]* on both sides of the colon skips bold/italic markers (and a line break).
        match = re.search(rf"{label}[*_\s]*:[*_\s]*(.+)", answer, flags=re.IGNORECASE)
        if match:
            value = match.group(1).strip(" \t*_")
            if value:
                return value

    return answer.strip()

def run_medgemma(
    prompt: Optional[str] = None,
    *,
    prompt_template: Optional[str] = None,
    prompt_variables: Optional[Dict[str, Any]] = None,
    image: Optional[Image.Image] = None,
    system_instruction: str = "You are an expert in tropical diseases.",
    system_instruction_template: Optional[str] = None,
    system_instruction_variables: Optional[Dict[str, Any]] = None,
    max_new_tokens: int = 2000,
    temperature: float = 0.0,
    **generate_kwargs: Any,
) -> Dict[str, Any]:
    """Generate an answer with MedGemma without RAG context.

    Rendering and inference failures are reported in ``output["error"]`` instead of
    being raised, so a batch loop can continue with the next case.

    Returns:
        Dict with keys:
          answer: the pipeline's full ``generated_text`` (for chat input, the
              transcript including system/user turns), or None.
          model_response: assistant reply text ("No answer." for an empty response).
          predicted: diagnosis parsed from the reply.
          resolved_prompt / resolved_system_instruction: the rendered texts.
          latency: wall-clock seconds.
          error: message or None.
    """
    start_time = time.time()
    output: Dict[str, Any] = {
        "answer": None,
        "model_response": None,  # present on every path so result rows share a schema
        "latency": None,
        "error": None,
        "resolved_prompt": None,
        "resolved_system_instruction": None,
        "predicted": None,
    }

    # Render final prompt
    try:
        final_prompt = render_text_template(
            prompt,
            template=prompt_template,
            variables=prompt_variables,
            field_name="prompt",
        )
        output["resolved_prompt"] = final_prompt
    except Exception as e:
        output["error"] = f"Prompt rendering error: {e}"
        output["latency"] = time.time() - start_time
        return output

    # Render system instruction; templates may reference {final_prompt}
    try:
        final_system_instruction = render_text_template(
            system_instruction,
            template=system_instruction_template,
            variables=system_instruction_variables,
            field_name="system instruction",
            defaults={"final_prompt": final_prompt},
        )
        output["resolved_system_instruction"] = final_system_instruction
    except Exception as e:
        output["error"] = f"System instruction rendering error: {e}"
        output["latency"] = time.time() - start_time
        return output

    # Build messages for MedGemma
    system_content = build_system_content(final_system_instruction)
    user_content = build_user_content(
        final_prompt,
        image=image if isinstance(image, Image.Image) else None,
    )
    messages = [
        {"role": "system", "content": system_content},
        {"role": "user", "content": user_content},
    ]

    try:
        response = pipe(
            messages,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            **generate_kwargs,
        )
        # extract_answer_from_response tolerates empty/odd shapes; indexing
        # response[0] directly used to raise IndexError on an empty response.
        extracted_answer = extract_answer_from_response(response)
        output["answer"] = response[0].get("generated_text") if response else None
        output["model_response"] = extracted_answer
        output["predicted"] = _extract_predicted_disease(extracted_answer)
    except Exception as exc:
        output["error"] = f"Model inference error: {exc}"
    finally:
        output["latency"] = time.time() - start_time

    return output

def _build_citations(references: List[Any]) -> List[Dict[str, Any]]:
    """Convert retrieval hits into citation dicts for reporting.

    Each hit is read as a dict whose optional ``payload`` dict carries ``case_id``,
    ``meta_ref``, ``chunk_id`` and ``source_pdf``. A missing or None payload yields a
    row of Nones instead of crashing. ``source_pdf`` gets an ``https://`` prefix unless
    it already has an http(s) scheme, and stays None when absent (no "https://None").
    """
    citations: List[Dict[str, Any]] = []
    for ref in references:
        payload = (ref.get("payload") if isinstance(ref, dict) else None) or {}
        source_pdf = payload.get("source_pdf")
        if source_pdf:
            source_pdf = str(source_pdf)
            if not source_pdf.startswith(("http://", "https://")):
                source_pdf = f"https://{source_pdf}"
        else:
            source_pdf = None
        citations.append({
            "case_id": payload.get("case_id"),
            "meta_ref": payload.get("meta_ref"),
            "chunk_id": payload.get("chunk_id"),
            "source_pdf": source_pdf,
        })
    return citations


def run_medgemma_rag(
    prompt: Optional[str] = None,
    *,
    prompt_template: Optional[str] = None,
    prompt_variables: Optional[Dict[str, Any]] = None,
    image: Optional[Image.Image] = None,
    rag_query: Optional[str] = None,
    rag_top_k: Optional[int] = None,
    rag_use_keyword_llm: bool = False,
    rag_keyword_model_id: Optional[str] = None,
    rag_keyword_max_items: int = 6,
    rag_keyword_prompt: Optional[str] = None,
    system_instruction: str = "You are an expert in tropical diseases.",
    system_instruction_template: Optional[str] = None,
    system_instruction_variables: Optional[Dict[str, Any]] = None,
    max_new_tokens: int = 2000,
    temperature: float = 0.0,
    **generate_kwargs: Any,
) -> Dict[str, Any]:
    """Generate an answer with MedGemma using RAG context when available.

    Retrieval query selection:
    - ``rag_query`` is a non-blank str: use it;
    - ``rag_query is False``: skip retrieval entirely;
    - otherwise use the LLM-extracted keywords (when ``rag_use_keyword_llm``),
      or else the rendered prompt.

    The retrieved context is appended to the system turn. Failures in prompt
    rendering, retrieval, system-instruction rendering or inference are reported in
    ``output["error"]`` rather than raised, matching run_medgemma, so a single bad
    case does not abort a batch run.

    Returns:
        Dict with answer, model_response, predicted, context, references, citations,
        rag_query_used, rag_keywords, resolved_prompt, resolved_system_instruction,
        latency (seconds) and error.
    """
    start_time = time.time()
    output: Dict[str, Any] = {
        "answer": None,
        "model_response": None,
        "citations": [],
        "context": "",
        "references": [],
        "latency": None,
        "error": None,
        "resolved_prompt": None,
        "resolved_system_instruction": None,
        "rag_query_used": None,
        "rag_keywords": [],
        "predicted": None,
    }

    # Render final prompt
    try:
        final_prompt = render_text_template(
            prompt,
            template=prompt_template,
            variables=prompt_variables,
            field_name="prompt",
        )
        output["resolved_prompt"] = final_prompt
    except Exception as e:
        output["error"] = f"Prompt rendering error: {e}"
        output["latency"] = time.time() - start_time
        return output

    # Keyword extraction + retrieval
    context_text = ""
    references: List[Any] = []
    try:
        keyword_query: Optional[str] = None
        if rag_use_keyword_llm and final_prompt.strip():
            extracted_keywords = extract_keywords_with_llm(
                final_prompt,
                max_keywords=rag_keyword_max_items,
                model_id=rag_keyword_model_id,
                instruction_template=rag_keyword_prompt,
            )
            if extracted_keywords:
                keyword_query = ", ".join(extracted_keywords)
                output["rag_keywords"] = extracted_keywords

        if rag_query is not False:  # False is the explicit "no retrieval" sentinel
            if isinstance(rag_query, str) and rag_query.strip():
                query_source = rag_query.strip()
            elif keyword_query:
                query_source = keyword_query
            else:
                query_source = final_prompt

            context_text, references = gather_rag_context(query_source, top_k=rag_top_k)
            output["context"] = context_text
            output["references"] = references
            output["rag_query_used"] = query_source
    except Exception as e:
        output["error"] = f"RAG retrieval error: {e}"
        output["latency"] = time.time() - start_time
        return output

    # Render system instruction; templates may use {final_prompt} and {rag_has_matches}
    try:
        final_system_instruction = render_text_template(
            system_instruction,
            template=system_instruction_template,
            variables=system_instruction_variables,
            field_name="system instruction",
            defaults={
                "final_prompt": final_prompt,
                "rag_has_matches": bool(references),
            },
        )
        output["resolved_system_instruction"] = final_system_instruction
    except Exception as e:
        output["error"] = f"System instruction rendering error: {e}"
        output["latency"] = time.time() - start_time
        return output

    # Build messages with RAG context
    system_content = build_system_content(final_system_instruction, context=context_text)
    user_content = build_user_content(
        final_prompt,
        image=image if isinstance(image, Image.Image) else None,
    )
    messages = [
        {"role": "system", "content": system_content},
        {"role": "user", "content": user_content},
    ]

    try:
        response = pipe(
            messages,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            **generate_kwargs,
        )
        generated = response[0].get("generated_text") if response else None
        output["answer"] = generated or "No answer."
        extracted_answer = extract_answer_from_response(response)
        output["model_response"] = extracted_answer
        output["predicted"] = _extract_predicted_disease(extracted_answer)
    except Exception as exc:
        output["error"] = f"Model inference error: {exc}"
    finally:
        output["latency"] = time.time() - start_time

    if references:
        output["citations"] = _build_citations(references)

    return output

def _format_chat_like_answer(answer: Any) -> str:
    """Render a chat transcript (list of role/content messages) as readable text.

    ``None`` becomes ``""``, strings pass through, and other non-list values are
    ``str()``-ed. Each dict message becomes ``Role:`` followed by its text pieces
    (image chunks shown as ``[Image provided]``); messages with no text are dropped,
    and non-dict messages are included verbatim.
    """
    if answer is None:
        return ""
    if isinstance(answer, str):
        return answer
    if not isinstance(answer, list):
        return str(answer)

    def render_chunk(chunk: Any) -> Optional[str]:
        # Non-dict chunks are stringified; unknown chunk types contribute nothing.
        if not isinstance(chunk, dict):
            return str(chunk)
        kind = chunk.get("type")
        if kind == "image":
            return "[Image provided]"
        if kind == "text":
            text = chunk.get("text")
            return str(text) if text else None
        return None

    blocks: List[str] = []
    for message in answer:
        if not isinstance(message, dict):
            blocks.append(str(message))
            continue
        role = str(message.get("role", "assistant")).strip() or "assistant"
        content = message.get("content")
        if isinstance(content, list):
            pieces = [render_chunk(chunk) for chunk in content]
        elif content is None:
            pieces = []
        else:
            pieces = [str(content)]
        body = "\n".join(piece for piece in pieces if piece).strip()
        if not body:
            continue
        blocks.append(f"{role.title()}:\n{body}")
    return "\n\n".join(blocks).strip()

def pretty_print_result(result) -> None:
    """Print one result dict, or each dict of a list under a ``Case ID`` banner."""
    if not isinstance(result, list):
        _print_single_result(result)
        return

    banner = "=" * 50
    for case_result in result:
        print(f"\n{banner}")
        print(f"Case ID: {case_result.get('case_id', 'Unknown')}")
        print(banner)
        _print_single_result(case_result)

def _print_single_result(result) -> None:
    """Print a single case result.

    Accepts both raw run_medgemma / run_medgemma_rag outputs (prediction under
    ``predicted``) and rows built by run_diagnosis_testcase (prediction under
    ``predicted_disease``). Latency is printed only when it is a number. Rows without
    it used to show a misleading "0.00s", and ``latency=None`` crashed the format.
    """
    if result.get("error"):
        print("Error:", result["error"])
        return
    if result.get("resolved_prompt"):
        print("Prompt (resolved):\n", result["resolved_prompt"])
    if result.get("rag_keywords"):
        print("\nKeywords for retrieval:")
        print(", ".join(result["rag_keywords"]))
    answer_text = _format_chat_like_answer(result.get("answer"))
    print("\nAnswer:\n", answer_text)
    predicted = result.get("predicted") or result.get("predicted_disease")
    if predicted:
        print(f"\nPredicted disease: {predicted}")
    if result.get("context"):
        print("\nContext (truncated):\n", result["context"][:500])
    if result.get("citations"):
        print("\nCitations:")
        for i, citation in enumerate(result.get("citations", []), 1):
            print(f"[{i}] {citation}")
    latency = result.get("latency")
    if isinstance(latency, (int, float)):
        print(f"\nLatency: {latency:.2f}s")

def derive_prompt_variables(
    user_prompt: str,
    symptom_limit: int = 12,
    age_model_id: Optional[str] = None,
    symptom_model_id: Optional[str] = None,
) -> Dict[str, Any]:
    """Build ``age`` / ``symptoms`` template variables from a free-text case.

    Age is matched with regular expressions, most specific first (``45-year-old``,
    ``45 y/o``, ``aged 45``, then a bare ``45 years``). An explicit age phrase
    therefore wins over an earlier duration such as "smoked for 20 years". Every unit
    must end on a word boundary, so "2 young children" is not read as age 2. When no
    pattern matches, the keyword LLM is asked for the age. Symptoms always come from
    the keyword LLM.

    Args:
        user_prompt: Case text (e.g. the contents of a case's query.txt).
        symptom_limit: Maximum number of symptoms to request.
        age_model_id: Optional model override for the age fallback.
        symptom_model_id: Optional model override for symptom extraction.

    Returns:
        ``{}`` for blank input. Otherwise ``age`` (int or "unknown") and ``symptoms``
        (a comma-separated string or "unspecified").
    """
    if not user_prompt or not user_prompt.strip():
        return {}
    variables: Dict[str, Any] = {}
    lowered = user_prompt.lower()

    age_patterns = (
        r"\b(\d{1,3})\s*-?\s*(?:years?|yrs?)\s*-?\s*old\b",  # 45-year-old, 45 years old
        r"\b(\d{1,3})\s*(?:yo|y/o)\b",                        # 45 yo, 45y/o
        r"\baged?\s*:?\s*(\d{1,3})\b",                        # aged 45, age: 45
        r"\b(\d{1,3})\s*[- ]?\s*(?:years?|yrs?)\b",           # 45 years (least specific)
    )
    for pattern in age_patterns:
        age_match = re.search(pattern, lowered)
        if age_match:
            variables["age"] = int(age_match.group(1))
            break
    else:
        # Fallback to LLM if no regex matches
        age_instruction = (
            "Extract the patient's age from the statement. "
            "Respond with a single number if the age is given, otherwise respond with NONE.\n"
            "Patient statement: {prompt}"
        )
        age_candidates = extract_keywords_with_llm(
            user_prompt,
            max_keywords=1,
            model_id=age_model_id,
            instruction_template=age_instruction,
        )
        for token in age_candidates:
            digit_match = re.search(r"\d{1,3}", token)
            if digit_match:
                variables["age"] = int(digit_match.group())
                break

    # Extract symptoms
    symptom_instruction = (
        "List up to {max_keywords} concise clinical symptoms or findings mentioned by the patient. "
        "Return a comma-separated list without numbering.\n"
        "Patient statement: {prompt}"
    )
    symptoms = extract_keywords_with_llm(
        user_prompt,
        max_keywords=symptom_limit,
        model_id=symptom_model_id,
        instruction_template=symptom_instruction,
    )
    if symptoms:
        variables["symptoms"] = ", ".join(symptoms)

    # Set defaults
    variables.setdefault("age", "unknown")
    variables.setdefault("symptoms", "unspecified")

    return variables

def run_diagnosis_testcase(use_rag: bool = USE_RAG, max_cases: Optional[int] = MAX_CASES):
    """Run diagnosis test cases with optional RAG.

    Iterates over the sub-directories of TEST_DATA_ROOT in sorted order. A directory
    counts as a case when it contains ``query.txt``; ``ground_truth.txt`` and image
    files are optional, and only the first image found is sent to the model.

    Args:
        use_rag: Route cases through run_medgemma_rag (retrieval + citations) instead
            of run_medgemma.
        max_cases: Stop after this many processed cases; None processes all.

    Returns:
        One dict per case with case_id, predicted_disease, model_response, answer,
        ground_truth and error (None on success), plus citations when use_rag is True.

    Raises:
        FileNotFoundError: TEST_DATA_ROOT does not exist.
    """
    if not TEST_DATA_ROOT.exists():
        raise FileNotFoundError(f"Test data directory not found: {TEST_DATA_ROOT}")

    results = []
    case_counter = 0

    print(f"Scanning for cases in: {TEST_DATA_ROOT}")
    print(f"Using RAG: {use_rag}")

    # The prompt config is the same for every case: load it once rather than
    # re-reading (and re-warning about) the YAML file inside the loop.
    prompt_type = "clinical_qa"
    try:
        with open(f'/content/prompts/medgemma/{prompt_type}.yaml', 'r', encoding='utf-8') as file:
            config = yaml.safe_load(file)
        prompt_template = config['prompt_template']
        system_instruction = config['system_prompt']
    except Exception as e:
        print(f"Warning: Could not load prompt template: {e}")
        # Fallback templates
        prompt_template = "Patient query: {prompt}\n\nPlease provide a diagnosis."
        system_instruction = "You are a medical expert specializing in tropical diseases."

    for case_dir in sorted(TEST_DATA_ROOT.iterdir()):
        if max_cases is not None and case_counter >= max_cases:
            break
        if not case_dir.is_dir():
            continue

        query_path = case_dir / "query.txt"
        if not query_path.exists():
            continue

        # Read user prompt and derive {age}/{symptoms} template variables
        user_prompt = query_path.read_text(encoding="utf-8").strip()
        prompt_variables = derive_prompt_variables(
            user_prompt=user_prompt,
            symptom_limit=12,
        )

        # Only the first image is used; close the others so their lazily-opened
        # file handles are released.
        case_images = load_case_images(case_dir)
        image = case_images[0] if case_images else None
        for extra_image in case_images[1:]:
            extra_image.close()

        common_kwargs: Dict[str, Any] = dict(
            prompt=user_prompt,
            prompt_template=prompt_template,
            prompt_variables=prompt_variables,
            image=image,
            system_instruction_template=system_instruction,
            system_instruction_variables={
                "specialty": "tropical pulmonary diseases",
                "audience": "clinicians",
                "test_id": case_dir.name,
            },
        )
        if use_rag:
            response = run_medgemma_rag(
                **common_kwargs,
                rag_query=None,
                rag_top_k=4,
                rag_use_keyword_llm=True,
                rag_keyword_max_items=5,
            )
        else:
            response = run_medgemma(**common_kwargs)

        # Read ground truth
        ground_truth_path = case_dir / "ground_truth.txt"
        ground_truth = (
            ground_truth_path.read_text(encoding="utf-8").strip()
            if ground_truth_path.exists()
            else ""
        )

        # Keep the error so failed cases are distinguishable from empty predictions
        result_data = {
            "case_id": case_dir.name,
            "predicted_disease": response.get("predicted"),
            "model_response": response.get("model_response"),
            "answer": response.get("answer"),
            "ground_truth": ground_truth,
            "error": response.get("error"),
        }
        if use_rag:
            result_data["citations"] = response.get("citations", [])

        results.append(result_data)
        case_counter += 1

    print(f"\nProcessed {case_counter} cases.")
    return results

## Main execution
# The baseline run without RAG is disabled here; uncomment the three lines below
# to run it again (the export cell below expects `testcase_results_no_rag`).
# print("=== Testing without RAG ===")
# testcase_results_no_rag = run_diagnosis_testcase(use_rag=False)
# pretty_print_result(testcase_results_no_rag)

# MedGemma with RAG: run_medgemma_rag calls gather_rag_context, which must already be
# defined in this kernel (it is not defined in this cell).
print("\n=== Testing with RAG ===")
testcase_results_rag = run_diagnosis_testcase(use_rag=True)
pretty_print_result(testcase_results_rag)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Device set to use cuda



=== Testing with RAG ===
Scanning for cases in: /content/drive/MyDrive/Data/Test-Data/extracted-data
Using RAG: True


Device set to use cuda:0
You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset



Processed 33 cases.

Case ID: 001

Answer:
 System:
You are an expert in tropical pulmonary diseases communicating with clinicians. Read the clinical case and provide the single most likely diagnosis.
### Format your response as:
**Predicted disease:**[Respond with most accurate diagnosis name]
**Reason:**[Short explanation]

Context: ## Case 51---A-34-Year-Old-HIV-Positive-Woman-from-Malawi-W_2022_Clinical-Cases-in-T (score=0.022)
**Patient Information**: A 34-year-old Malawian woman diagnosed with smear-positive pulmonary tuberculosis and HIV 5 months prior to presentation. At that time, her baseline CD4 count was 54/µL. She is currently taking antituberculous therapy, vitamin B6, antiretroviral therapy, and co-trimoxazole prophylaxis.
**Chief Complaint**: Slowly progressive weakness of the left arm and leg over a 3-month period.
**History Of Present Illness**: The patient's problems began approximately 3 months earlier with a limp in her left leg. The weakness progressed insidiousl

In [None]:
# SAVE TO EXCEL
# One row per case dict returned by run_diagnosis_testcase; list-valued fields such as
# citations are written as their string form.
# NOTE(review): DataFrame.to_excel needs an xlsx writer engine (e.g. openpyxl) installed.
import pandas as pd
# No-RAG export: disabled together with the no-RAG run in the main execution cell.
# df1 = pd.DataFrame(testcase_results_no_rag)
# output_path_1 = "/content/testcase_medgemma_without_rag.xlsx"
# df1.to_excel(output_path_1, index=False)
# print(f"\nExcel saved to: {output_path_1}")

df2 = pd.DataFrame(testcase_results_rag)
output_path_2 = "/content/testcase_medgemma_with_rag.xlsx"
df2.to_excel(output_path_2, index=False)
print(f"\nExcel saved to: {output_path_2}")


Excel saved to: /content/testcase_medgemma_with_rag.xlsx


# Reset Memory

In [20]:
import gc

# Drop the MedGemma pipeline. pop() (instead of `del`) keeps this cell safe to re-run.
globals().pop("pipe", None)
# get_keyword_extractor memoises its pipelines in an lru_cache, so the keyword model
# would stay referenced (and resident) after `pipe` is gone unless the cache is cleared.
if "get_keyword_extractor" in globals():
    get_keyword_extractor.cache_clear()
# Collect unreachable objects first so their CUDA blocks are unoccupied, then let
# PyTorch's caching allocator hand that memory back.
gc.collect()
torch.cuda.empty_cache()

In [21]:
import gc
# Force a full garbage-collection pass so objects dropped in the previous cell are
# reclaimed; the displayed value is the number of unreachable objects found.
gc.collect()


254

In [22]:
import torch
# After garbage collection, release cached-but-unused CUDA memory held by PyTorch's allocator.
torch.cuda.empty_cache()

In [23]:
# importlib.metadata (stdlib) replaces the deprecated pkg_resources API for reading
# installed distribution versions; importing pkg_resources emits deprecation warnings.
from importlib.metadata import version

import google.protobuf
import torch
import transformers

print("qdrant-client:", version("qdrant-client"))
print("protobuf:", google.protobuf.__version__)
print("transformers:", transformers.__version__)
print("torch:", torch.__version__)


qdrant-client: 1.16.0
protobuf: 5.29.5
transformers: 4.57.1
torch: 2.8.0+cu126
