<a href="https://colab.research.google.com/github/Long2511/ai-project/blob/main/MedGemma_client_%26_runner.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MedGemma client & runner

In [None]:
!pip install -U transformers
!pip install -U huggingface_hub transformers
!pip install transformers psutil torch datasets

Collecting huggingface_hub
  Downloading huggingface_hub-1.0.1-py3-none-any.whl.metadata (13 kB)
Collecting typer-slim (from huggingface_hub)
  Downloading typer_slim-0.20.0-py3-none-any.whl.metadata (16 kB)


In [None]:
import os
import sys

# Plain Google Colab: the `google.colab` module is loaded and VERTEX_PRODUCT is unset.
# `google_colab` is read again by the model-configuration cell.
google_colab = not os.environ.get("VERTEX_PRODUCT") and "google.colab" in sys.modules

if not google_colab:
    # Colab Enterprise: keep the Hugging Face cache under `/content`.
    # HF_HOME is set before `huggingface_hub` is imported below.
    if os.environ.get("VERTEX_PRODUCT") == "COLAB_ENTERPRISE":
        os.environ["HF_HOME"] = "/content/hf"
    # Interactive login only when no token is already available
    from huggingface_hub import get_token
    if get_token() is None:
        from huggingface_hub import notebook_login
        notebook_login()
else:
    # Google Colab: take the token from the notebook's secret store
    from google.colab import userdata
    os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")

In [None]:
!pip install --upgrade --quiet accelerate bitsandbytes transformers qdrant-client fastembed qdrant-client[fastembed]

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.1/60.1 MB[0m [31m15.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m337.3/337.3 kB[0m [31m13.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m105.3/105.3 kB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.6/61.6 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m103.3/103.3 kB[0m [31m7.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.4/17.4 MB[0m [31m44.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m324.8/324.8 kB[0m [31m19.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
from transformers import BitsAndBytesConfig
import torch

model_variant = "4b-it"  # @param ["4b-it", "27b-it", "27b-text-it", "4b-pt"]
model_id = f"google/medgemma-{model_variant}"

use_quantization = True  # @param {type: "boolean"}

# @markdown Set `is_thinking` to `True` to turn on thinking mode. **Note:** Thinking is supported for the 27B variants only.
is_thinking = False  # @param {type: "boolean"}

# If running a 27B variant in Google Colab, check if the runtime satisfies
# memory requirements
if "27b" in model_variant and google_colab:
    # Check CUDA availability first: `torch.cuda.get_device_name(0)` raises on a
    # runtime without a GPU, which would hide the actionable message below.
    has_a100 = torch.cuda.is_available() and "A100" in torch.cuda.get_device_name(0)
    if not (has_a100 and use_quantization):
        raise ValueError(
            "Runtime has insufficient memory to run a 27B variant. "
            "Please select an A100 GPU and use 4-bit quantization."
        )

# Keyword arguments intended for model loading (dtype and automatic device placement)
model_kwargs = dict(
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Optionally load the weights in 4-bit precision via bitsandbytes
if use_quantization:
    model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)

In [None]:
from dataclasses import dataclass
from functools import lru_cache
from typing import Any, Dict, List

from qdrant_client import QdrantClient, models

from fastembed import TextEmbedding


@dataclass
class RagSettings:
    """Connection and retrieval settings for the Qdrant-backed RAG helpers.

    Every default is read from an environment variable when this class body is
    executed, so changing the environment afterwards has no effect until the
    class is defined again (for example by re-running this cell).
    """

    # Qdrant server location.
    # NOTE(review): the fallback host is a hardcoded IP address - confirm it is
    # meant to be committed in the notebook.
    qdrant_host: str = os.getenv("QDRANT_HOST", "165.22.56.15")
    qdrant_port: int = int(os.getenv("QDRANT_PORT", 6333))
    # Optional API key; None when QDRANT_API_KEY is unset
    qdrant_api_key: str | None = os.getenv("QDRANT_API_KEY")
    # Collection to search and the named vector to query inside it
    qdrant_collection: str = os.getenv("QDRANT_COLLECTION", "tropical_cases_chunks")
    vector_name: str = os.getenv("QDRANT_VECTOR_NAME", "text")
    # Preferred embedding model, tried first
    embed_model: str = os.getenv("EMBED_MODEL", "qdrant/fastembed-multilingual-v3")
    # Comma-separated list of extra embedding models to try after `embed_model`
    embed_fallbacks: str = os.getenv("EMBED_MODEL_FALLBACKS", "")
    # Default number of references to retrieve per query
    rag_top_k: int = int(os.getenv("RAG_TOP_K", 4))
    # Sentence placed in front of the retrieved context in the system message
    context_prefix: str = os.getenv(
        "RAG_CONTEXT_PREFIX",
        "Use the following reference material when reasoning about the case:",
    )


# Module-level settings instance read by the RAG helper functions
rag_settings = RagSettings()


In [None]:
@lru_cache(maxsize=1)
def get_qdrant_client() -> QdrantClient:
    """Build the Qdrant client from `rag_settings`.

    The result is memoised by `lru_cache`, so repeated calls return the same
    client object.
    """
    connection_options = {
        "host": rag_settings.qdrant_host,
        "port": rag_settings.qdrant_port,
        "api_key": rag_settings.qdrant_api_key,
        "prefer_grpc": False,
        "timeout": 120.0,
    }
    return QdrantClient(**connection_options)


def _candidate_embed_models() -> List[str]:
    """List the embedding model names to try, in priority order.

    Order: the configured `embed_model`, then the comma-separated
    `embed_fallbacks` entries, then a built-in list. Empty names are dropped
    and duplicates keep only their first occurrence.
    """
    from_env: List[str] = []
    if rag_settings.embed_fallbacks:
        stripped = (entry.strip() for entry in rag_settings.embed_fallbacks.split(","))
        from_env = [entry for entry in stripped if entry]

    built_in = (
        "qdrant/fastembed-multilingual-v1",
        "qdrant/fastembed-multilingual",
        "BAAI/bge-m3",
        "BAAI/bge-base-en-v1.5",
    )

    ordered = [rag_settings.embed_model, *from_env, *built_in]
    # dict.fromkeys removes duplicates while preserving first-seen order
    return list(dict.fromkeys(name for name in ordered if name))


@lru_cache(maxsize=1)
def get_embedder() -> TextEmbedding:
    """Return an embedding model, trying each candidate name in order.

    The first candidate from `_candidate_embed_models()` that can be constructed
    is returned; `lru_cache` memoises it so later calls reuse the same object.
    When an earlier candidate failed and a later one was loaded instead, a
    `RuntimeWarning` is emitted so the substitution is visible.

    Raises:
        RuntimeError: if no candidate could be initialised; the message lists
            every attempted model with its error.
    """
    import warnings

    failures: List[str] = []
    for model_name in _candidate_embed_models():
        try:
            embedder = TextEmbedding(model_name=model_name)
        except Exception as exc:  # best effort: any failure moves on to the next candidate
            failures.append(f"{model_name}: {exc}")
            continue
        if failures:
            # NOTE(review): a fallback model presumably produces vectors that differ
            # from the ones stored in the collection - verify retrieval quality.
            warnings.warn(
                f"Using fallback embedding model '{model_name}' after failures: "
                + "; ".join(failures),
                RuntimeWarning,
                stacklevel=2,
            )
        return embedder

    error_msg = "Failed to initialise embedding model. "
    if failures:
        error_msg += "; ".join(failures)
    raise RuntimeError(error_msg)


def embed_query(query: str) -> List[float]:
    """Embed a single query string and return its vector as a list of floats.

    Args:
        query: Text to embed.

    Returns:
        The embedding as a plain Python list, matching the declared return type
        (the embedder's own result object was previously returned unchanged).

    Raises:
        ValueError: if the embedder yields no vector for the query.
    """
    embedder = get_embedder()
    vectors = list(embedder.embed([query]))
    if not vectors:
        raise ValueError("Failed to compute embedding for query.")
    first = vectors[0]
    # Array-like results expose `tolist()`; fall back to element-wise conversion
    if hasattr(first, "tolist"):
        return first.tolist()
    return [float(value) for value in first]


def retrieve_references(query: str, top_k: int | None = None) -> List[Dict[str, Any]]:
    """Search the configured Qdrant collection for chunks similar to `query`.

    Uses `QdrantClient.query_points`, the replacement for the deprecated
    `search` method. The named vector from `rag_settings.vector_name` is
    queried first; if the server reports that the name is "not configured",
    the search is retried against the collection's default vector.

    Args:
        query: Text to embed and search for.
        top_k: Maximum number of hits; falls back to `rag_settings.rag_top_k`
            when None or 0.

    Returns:
        One dict per hit with the keys "id", "score" and "payload" (an empty
        dict when the hit carries no payload).

    Raises:
        Exception: any client/server error other than the handled
            "not configured" case is propagated unchanged.
    """
    top_k = top_k or rag_settings.rag_top_k
    vector = embed_query(query)
    if hasattr(vector, "tolist"):
        # Send a plain list regardless of what the embedder returned
        vector = vector.tolist()
    client = get_qdrant_client()

    def _query(using: str | None):
        """Run one similarity query; `using=None` targets the default vector."""
        response = client.query_points(
            collection_name=rag_settings.qdrant_collection,
            query=vector,
            using=using,
            limit=top_k,
            with_payload=True,
        )
        return response.points

    vector_name = rag_settings.vector_name or None
    try:
        hits = _query(vector_name)
    except Exception as exc:
        if vector_name is None or "not configured" not in str(exc).lower():
            raise
        # Fall back to default unnamed vector when the requested name is absent.
        hits = _query(None)

    return [
        {
            "id": hit.id,
            "score": hit.score,
            "payload": hit.payload or {},
        }
        for hit in hits
    ]



def format_references(references: List[Dict[str, Any]]) -> str:
    """Render retrieved references as markdown-style sections.

    Each reference becomes a "## <title>" heading (with the score appended when
    one is present) followed by up to 2500 characters of its text. The title is
    taken from the payload keys "title" or "case_title", the text from
    "summary", "text" or "content". Sections are separated by a blank line;
    an empty input gives an empty string.
    """
    if not references:
        return ""

    sections: List[str] = []
    for position, reference in enumerate(references, start=1):
        data = reference.get("payload", {})
        heading = data.get("title") or data.get("case_title") or f"Reference {position}"
        body = str(data.get("summary") or data.get("text") or data.get("content") or "").strip()
        relevance = reference.get("score")
        if relevance is None:
            header = f"## {heading}\n"
        else:
            header = f"## {heading} (score={relevance:.3f})\n"
        sections.append(header + body[:2500])
    return "\n\n".join(sections)


def gather_rag_context(query: str | None, top_k: int | None = None) -> tuple[str, List[Dict[str, Any]]]:
    if not query:
        return "", []
    references = retrieve_references(query, top_k=top_k)
    return format_references(references), references


In [None]:
def build_system_content(base_instruction: str, *, context: str | None = None) -> List[Dict[str, str]]:
    content: List[Dict[str, str]] = [{"type": "text", "text": base_instruction}]
    if context:
        content.append(
            {
                "type": "text",
                "text": f"{rag_settings.context_prefix}\n\n{context}",
            }
        )
    return content


def build_user_content(prompt: str, *, image: Any | None = None) -> List[Dict[str, Any]]:
    content: List[Dict[str, Any]] = [{"type": "text", "text": prompt}]
    if image is not None:
        content.append({"type": "image", "image": image})
    return content


In [None]:
# Repeats the `pip install transformers psutil torch datasets` command from the
# first setup cell of this notebook.
!pip install transformers psutil torch datasets



In [None]:
from google.colab import files
from PIL import Image
uploaded = files.upload()  # Upload the image file from local system

if uploaded:
    # Get the uploaded file name (the first one when several files were uploaded)
    image_path = next(iter(uploaded))
    image = Image.open(image_path)
else:
    # Nothing came back from the upload dialog: continue without an image
    # instead of failing with a bare StopIteration from `next(iter(...))`.
    image_path = None
    image = None
    print("No file uploaded; continuing without an image.")


Saving ebola.jpg to ebola.jpg


# MedGemma with RAG

In [None]:
import time
from typing import Any, Dict, Optional
import requests
import torch
from transformers import pipeline
from PIL import Image
from io import BytesIO
from google.colab import files

# Initialize the MedGemma pipeline from the configuration cell: `model_id`
# selects the variant and `model_kwargs` carries the dtype, `device_map="auto"`
# and, when `use_quantization` is enabled, the 4-bit quantization config.
# Previously the model name was hardcoded here and those settings were ignored.
# NOTE(review): the "27b-text-it" variant presumably needs the "text-generation"
# task instead of "image-text-to-text" - confirm before selecting it.
pipe = pipeline(
    "image-text-to-text",
    model=model_id,
    model_kwargs=model_kwargs,
)


def load_image(
    image_url: Optional[str] = None,
    image_path: Optional[str] = None,
    *,
    timeout: float = 30.0,
):
    """Load an image from a URL or from disk.

    Args:
        image_url: Location to download the image from; takes precedence over
            `image_path` when both are given.
        image_path: Path of an image file on disk.
        timeout: Seconds to wait for the HTTP request before giving up, so a
            stalled server cannot block the notebook indefinitely.

    Returns:
        A PIL image on success, a dict with the keys "error" and "message" when
        loading fails, or None when neither source is given.
    """
    if image_url:
        try:
            response = requests.get(image_url, timeout=timeout)
            response.raise_for_status()
            raw_bytes = response.content
            # quick integrity check; the probe image is closed right after
            with Image.open(BytesIO(raw_bytes)) as probe:
                probe.verify()
            # reopen after verify: the verified object is not used for loading
            return Image.open(BytesIO(raw_bytes))
        except Exception as exc:
            return {"error": "Failed to load image", "message": str(exc)}
    if image_path:
        try:
            # the `with` block releases the file handle used for verification
            with Image.open(image_path) as probe:
                probe.verify()
            return Image.open(image_path)
        except Exception as exc:
            return {"error": "Failed to load image", "message": str(exc)}
    return None


def _extract_answer_text(response: Any) -> Optional[str]:
    """Return the assistant's reply from a pipeline response, or None if absent."""
    if not response:
        return None
    generated = response[0].get("generated_text")
    if not isinstance(generated, list):
        # Plain-string generations are returned unchanged
        return generated
    # Chat-style output echoes the whole conversation; the reply is the last
    # message with the assistant role.
    for message in reversed(generated):
        if isinstance(message, dict) and message.get("role") == "assistant":
            content = message.get("content")
            if isinstance(content, list):
                return "".join(
                    str(part.get("text") or "")
                    for part in content
                    if isinstance(part, dict) and part.get("type") == "text"
                )
            return content
    return None


def _collect_citations(references: Any) -> list:
    """Pick one non-empty citation label per reference."""
    labels = []
    for ref in references:
        payload = ref.get("payload", {})
        labels.append(
            payload.get("source")
            or payload.get("url")
            or payload.get("case_title")
            or payload.get("title")
            or str(ref.get("id"))
        )
    return [label for label in labels if label]


def run_medgemma_rag(
    prompt: str,
    *,
    image: Optional[Image.Image] = None,
    rag_query: Optional[str] = None,
    rag_top_k: Optional[int] = None,
    system_instruction: str = "You are an expert in tropical diseases.",
    max_new_tokens: int = 200,
    temperature: float = 0.0,
    **generate_kwargs: Any,
) -> Dict[str, Any]:
    """Generate an answer with MedGemma using Qdrant-powered RAG context when available.

    Args:
        prompt: User question sent to the model.
        image: Optional PIL image attached to the user message; anything that
            is not a PIL image is ignored.
        rag_query: Text used for retrieval; defaults to `prompt`. Pass False to
            skip retrieval entirely.
        rag_top_k: Number of references to retrieve (None uses the default).
        system_instruction: Base text of the system message.
        max_new_tokens: Generation length limit.
        temperature: Sampling temperature. A value of 0 or less requests
            non-sampling decoding (`do_sample=False` unless the caller set it)
            and is not forwarded to generation.
        **generate_kwargs: Extra keyword arguments forwarded to the pipeline call.

    Returns:
        A dict with the keys "answer" (the assistant's reply text), "citations",
        "context", "references", "latency" (seconds) and "error". When retrieval
        fails, generation continues without context and the failure is stored
        under the additional key "rag_error".
    """
    import warnings

    start_time = time.time()
    output: Dict[str, Any] = {
        "answer": None,
        "citations": [],
        "context": "",
        "references": [],
        "latency": None,
        "error": None,
    }

    context_text = ""
    references = []
    if rag_query is not False:
        query = rag_query or prompt
        try:
            context_text, references = gather_rag_context(query, top_k=rag_top_k)
        except Exception as exc:
            # Retrieval is optional: record the failure and answer without context
            output["rag_error"] = f"RAG retrieval error: {exc}"
            warnings.warn(output["rag_error"], RuntimeWarning, stacklevel=2)
        output["context"] = context_text
        output["references"] = references

    system_content = build_system_content(system_instruction, context=context_text)
    user_content = build_user_content(prompt, image=image if isinstance(image, Image.Image) else None)
    messages = [
        {"role": "system", "content": system_content},
        {"role": "user", "content": user_content},
    ]

    call_kwargs: Dict[str, Any] = dict(generate_kwargs)
    if temperature is not None and temperature > 0:
        call_kwargs["temperature"] = temperature
    else:
        # A non-positive temperature is not a valid sampling setting
        call_kwargs.setdefault("do_sample", False)

    try:
        response = pipe(
            text=messages,
            max_new_tokens=max_new_tokens,
            **call_kwargs,
        )
        output["answer"] = _extract_answer_text(response) or "No answer."
    except Exception as exc:
        output["error"] = f"Model inference error: {exc}"
    finally:
        output["latency"] = time.time() - start_time

    if references:
        output["citations"] = _collect_citations(references)
    return output


def pretty_print_rag_result(result: Dict[str, Any]) -> None:
    """Print a RAG result dict in a human-readable layout.

    An error, when present, is printed on its own and nothing else is shown.
    Otherwise the answer is printed, followed by the first 500 characters of
    the context and the citation list (each only when non-empty), and finally
    the latency in seconds.
    """
    failure = result.get("error")
    if failure:
        print("Error:", failure)
        return

    print("Answer:\n", result.get("answer", ""))

    context_text = result.get("context")
    if context_text:
        print("\nContext (truncated):\n", context_text[:500])

    cited = result.get("citations")
    if cited:
        print("\nCitations:")
        for entry in cited:
            print("-", entry)

    print(f"\nLatency: {result.get('latency', 0.0):.2f}s")


# Example usage
image_url = None  # e.g. "https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png"
# Reuse the image from the upload cell when that cell was run
image = globals().get("image", None)
if image is None and image_url:
    image = load_image(image_url=image_url)

# `load_image` reports failures as a dict; show the problem instead of silently
# running the model without the image.
if isinstance(image, dict):
    print(f"{image.get('error')}: {image.get('message')}")
    image = None

rag_result = run_medgemma_rag(
    prompt="Describe this X-ray with likely diagnosis.",
    image=image if isinstance(image, Image.Image) else None,
    rag_query="pulmonary tuberculosis chest radiograph findings",
    rag_top_k=4,
)
pretty_print_rag_result(rag_result)

config.json:   0%|          | 0.00/2.47k [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/90.6k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.64G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

processor_config.json:   0%|          | 0.00/70.0 [00:00<?, ?B/s]

chat_template.jinja:   0%|          | 0.00/1.53k [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

Device set to use cuda


Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

config.json:   0%|          | 0.00/740 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/695 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

model_optimized.onnx:   0%|          | 0.00/218M [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

  hits = client.search(
  hits = client.search(query_vector=vector, **search_kwargs)


Answer:
 [{'role': 'system', 'content': [{'type': 'text', 'text': 'You are an expert radiologist.'}, {'type': 'text', 'text': 'Use the following reference material when reasoning about the case:\n\n## A 24 Year Old Man of Turkish Origin With Jau_2022_Clinical Cases in Tro (score=0.087)\n\n\n## A 23 Year Old Farmer from Myanmar With Uni_2022_Clinical Cases in Tropi (score=0.086)\n\n\n## A 28 Year Old Male Fisherman from Malawi Wi_2022_Clinical Cases in Tropi (score=0.084)\n\n\n## 32 Year Old Woman from Nigeria With Jaund_2022_Clinical Cases in Tropic (score=0.083)\n'}]}, {'role': 'user', 'content': [{'type': 'text', 'text': 'Describe this X-ray with likely diagnosis.'}, {'type': 'image', 'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=660x426 at 0x7B5CF7E7B530>}]}, {'role': 'assistant', 'content': "Based on the X-ray image, the most likely diagnosis is **cutaneous anthrax**.\n\nHere's a breakdown of the findings and why they point to this diagnosis:\n\n*   **Multiple, ro