# Qwen2-VL + Qdrant RAG Evaluation (Text → case_id → Image Join)

This notebook evaluates your current **Qwen2-VL fine-tuned model** with a **RAG pipeline**:

1) Retrieve relevant text chunks from `case_text_bge_v1_5` (BGE embeddings)  
2) Rerank with `BAAI/bge-reranker-base`  
3) Extract top `case_id`(s) from those text hits  
4) **Join** into `case_image_clip_l14` by the same `case_id` to fetch matching `image_path`(s)  
5) Load **patient images** (from test folder) + **reference images** (from matched retrieved cases)  
6) Feed everything into **Qwen2-VL** and save results to Excel

You’ll also get debug columns showing whether reference images were loaded successfully.


In [None]:
# =========================
# 0) CONFIG
# =========================

# --- Qdrant vector store ----------------------------------------------------
QDRANT_URL = "http://165.22.56.15:6333"
TEXT_COLLECTION = "case_text_bge_v1_5"      # text chunks, searched with BGE query embeddings
IMAGE_COLLECTION = "case_image_clip_l14"    # figure images, joined to text hits by case_id

# --- Paths (Google Drive, mounted in section 1) ------------------------------
EVAL_ROOT = "/content/drive/MyDrive/Test-Data/extracted-data"   # holds 001/, 002/, ... test-case folders
OUT_XLSX = "/content/drive/MyDrive/eval_results.xlsx"
# Image refs stored in the payload as *relative* paths are resolved as LOCAL_IMAGE_ROOT / ref.
LOCAL_IMAGE_ROOT = "/content/drive/MyDrive/Project-AI/Data_for_RAG/diseases_extracted_images"

# --- Model ---------------------------------------------------------------------
QWEN_MODEL_ID = "BaoNgoc29/qwen2-tropical-disease-train-structure-epoch-3"

# --- Retrieval / rerank -------------------------------------------------------
VEC_TOPN = 50         # hits pulled by the initial vector search
RERANK_TOP = 4        # hits kept after cross-encoder reranking (source of candidates + evidence)
MAX_REF_CASES = 3     # distinct retrieved case_ids joined into the image collection
MAX_REF_IMGS = 3      # total reference images per query; small to avoid CUDA OOM

# --- Patient images -------------------------------------------------------------
MAX_PATIENT_IMGS = 3  # images read from each test-case folder; small to avoid CUDA OOM

# --- Generation -------------------------------------------------------------------
MAX_NEW_TOKENS = 512
MAX_INPUT_TOKENS = 2048      # handed to the processor as max_length (NOTE: inert while qwen_predict passes truncation=False)
MAX_IMAGE_SIDE = 768         # images are downscaled so their longest side is at most this many pixels
MAX_META_FIELD_CHARS = 180   # per-field cap for meta_ref snippets; longer text is cut and tagged " [TRUNCATED]"
DO_SAMPLE = False            # greedy decoding keeps evaluation runs reproducible

# Set True to print the first two search hits of every Qdrant query (noisy).
QDRANT_DEBUG = False


In [None]:
# Install runtime dependencies. `!` is an IPython shell escape: this line only runs inside Jupyter/Colab.
# NOTE(review): the model repo ships adapter_config.json / adapter_model.safetensors (see the Qwen
# download log), so loading it presumably relies on `peft`, which is not listed here — confirm the
# runtime image already provides it.
!pip -q install "transformers>=4.44" accelerate sentence-transformers qwen-vl-utils pillow requests pandas openpyxl rapidfuzz


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m47.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.2/41.2 MB[0m [31m19.4 MB/s[0m eta [36m0:00:00[0m
[?25h

## 1) Mount Google Drive


In [None]:
from google.colab import drive

# Mount Google Drive so EVAL_ROOT, OUT_XLSX and LOCAL_IMAGE_ROOT resolve under /content/drive.
drive.mount(mountpoint="/content/drive")


Mounted at /content/drive


## 2) Imports + device


In [None]:
import os, re, json, io, glob
import gc
import requests
import pandas as pd
import torch
from PIL import Image
from rapidfuzz.fuzz import token_set_ratio

from sentence_transformers import SentenceTransformer, CrossEncoder
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

# Qwen2-VL uses the GPU when present; the retrieval models (BGE embedder and
# reranker) are kept on CPU in that case so the VL model gets all the VRAM.
device = "cuda" if torch.cuda.is_available() else "cpu"
AUX_DEVICE = "cpu" if device == "cuda" else device
print("Device (Qwen):", device, "| Device (retrieval):", AUX_DEVICE)

# Explicit check rather than `assert`: assertions are stripped under `python -O`.
if not os.path.isdir(EVAL_ROOT):
    raise FileNotFoundError(f"Eval folder not found: {EVAL_ROOT}")
print("Eval root:", EVAL_ROOT)


Device (Qwen): cuda | Device (retrieval): cpu
Eval root: /content/drive/MyDrive/Test-Data/extracted-data


## 3) Qdrant REST helpers

- `/points/search` for vector search (text collection)  
- `/points/scroll` for:
  - sampling payload keys
  - **filtered scroll** to fetch all images that match a retrieved `case_id`


In [None]:
def qdrant_post(path, payload, timeout=60):
    """POST a JSON body to the Qdrant REST API and return the decoded JSON response.

    Args:
        path: API path appended to QDRANT_URL (e.g. "/collections/x/points/search").
        payload: JSON-serializable request body.
        timeout: seconds to wait for connect/read. Without a timeout, `requests`
            waits forever, so an unreachable server would hang the notebook.

    Raises:
        requests.HTTPError: on a non-2xx response. The message includes the start
            of Qdrant's error body, which explains *why* a filter/search was rejected.
    """
    url = f"{QDRANT_URL}{path}"
    r = requests.post(
        url,
        headers={"Content-Type": "application/json"},
        data=json.dumps(payload),
        timeout=timeout,
    )
    try:
        r.raise_for_status()
    except requests.HTTPError as err:
        raise requests.HTTPError(f"{err} | body: {r.text[:500]}", response=r) from err
    return r.json()

def qdrant_scroll(collection, limit=8, with_payload=True, with_vectors=False):
    """Fetch the first page of points from `collection` (no filter).

    Used to sample payloads for key inference. Returns the list of point dicts,
    or [] when the response carries no points.
    """
    body = {"limit": limit, "with_payload": with_payload, "with_vectors": with_vectors}
    response = qdrant_post(f"/collections/{collection}/points/scroll", body)
    result = response.get("result", {})
    return result.get("points", [])

def qdrant_search(collection, vector, limit=10, with_payload=True):
    """Nearest-neighbour search in `collection` for `vector`; returns the hit list.

    Vectors are never returned. When QDRANT_DEBUG is on, the first two hits are printed.
    """
    body = {
        "vector": vector,
        "limit": limit,
        "with_payload": with_payload,
        "with_vectors": False,
    }
    hits = qdrant_post(f"/collections/{collection}/points/search", body).get("result", [])
    if QDRANT_DEBUG:
        print(hits[:2])
    return hits

def qdrant_scroll_filtered(collection, qfilter, limit=64, offset=None, with_payload=True, with_vectors=False):
    """Fetch one page of a filtered scroll and return the raw response JSON.

    To continue paging, pass the previous page's result["next_page_offset"]
    back in as `offset`. The offset is omitted for the first page.
    """
    page_request = dict(
        limit=limit,
        with_payload=with_payload,
        with_vectors=with_vectors,
        filter=qfilter,
    )
    if offset is not None:
        page_request["offset"] = offset
    return qdrant_post(f"/collections/{collection}/points/scroll", page_request)


## 4) Infer payload keys (case_id / text / image_path)

Collections often use slightly different payload field names.  
We sample a few points and infer:
- text collection: `case_id` key and `text` key
- image collection: `case_id` key and `image_path` (or similar) key


In [None]:
def infer_keys(payloads):
    """Guess payload field names from a sample of Qdrant payloads.

    Args:
        payloads: iterable of payload dicts. Non-dict entries (None, strings, ...)
            are ignored.

    Returns:
        (case_key, text_key, image_key, all_keys_sorted). Each of the first
        three may be None when nothing plausible is found.

    Fallback heuristics scan the keys in sorted order. A Python `set` of
    strings iterates in hash order, and that order changes between interpreter
    runs (PYTHONHASHSEED), so scanning the raw set made the inferred keys
    non-deterministic.
    """
    dict_payloads = [p for p in (payloads or []) if isinstance(p, dict)]

    all_keys = set()
    for p in dict_payloads:
        all_keys |= set(p.keys())
    ordered_keys = sorted(all_keys, key=str)

    def pick_case_key():
        for k in ["case_id", "caseId", "case", "caseid", "id_case"]:
            if k in all_keys:
                return k
        for k in ordered_keys:
            lk = str(k).lower()
            if "case" in lk and "id" in lk:
                return k
        for k in ordered_keys:
            if "case" in str(k).lower():
                return k
        return None

    def pick_text_key():
        for k in ["text", "chunk", "content", "passage", "context", "document", "doc"]:
            if k in all_keys:
                return k
        # Fallback: the key holding the longest string value across the sample.
        bestk, bestlen = None, 0
        for p in dict_payloads:
            for k, v in p.items():
                if isinstance(v, str) and len(v) > bestlen:
                    bestk, bestlen = k, len(v)
        return bestk

    def pick_img_key():
        for k in ["image_path", "image", "path", "url", "file", "filename", "img", "image_relpath"]:
            if k in all_keys:
                return k
        for k in ordered_keys:
            lk = str(k).lower()
            if any(x in lk for x in ["image", "img", "url", "path", "file", "name"]):
                return k
        return None

    return pick_case_key(), pick_text_key(), pick_img_key(), sorted(list(all_keys))

# Sample a few points per collection and infer which payload fields hold the
# case id, the chunk text and the image reference.
text_pts = qdrant_scroll(TEXT_COLLECTION, limit=8, with_payload=True, with_vectors=False)
img_pts  = qdrant_scroll(IMAGE_COLLECTION, limit=8, with_payload=True, with_vectors=False)

TEXT_CASE_KEY, TEXT_TEXT_KEY, _, text_keys = infer_keys([p.get("payload", {}) for p in text_pts])
IMG_CASE_KEY,  _, IMG_REF_KEY, img_keys    = infer_keys([p.get("payload", {}) for p in img_pts])

print("[Inferred keys]")
print("TEXT case:", TEXT_CASE_KEY, "text:", TEXT_TEXT_KEY, "| example keys:", text_keys[:12])
print("IMG  case:", IMG_CASE_KEY,  "img_ref:", IMG_REF_KEY, "| example keys:", img_keys[:12])

# Explicit checks instead of `assert` (stripped under `python -O`). The messages
# also say how many points were sampled, because an empty collection is the
# most common cause of a failed inference.
if TEXT_CASE_KEY is None or TEXT_TEXT_KEY is None:
    raise RuntimeError(
        f"Could not infer text collection keys "
        f"(sampled {len(text_pts)} point(s) from {TEXT_COLLECTION!r}; keys={text_keys})"
    )
if IMG_CASE_KEY is None or IMG_REF_KEY is None:
    raise RuntimeError(
        f"Could not infer image collection keys "
        f"(sampled {len(img_pts)} point(s) from {IMAGE_COLLECTION!r}; keys={img_keys})"
    )


[Inferred keys]
TEXT case: case_id text: text | example keys: ['case_id', 'chunk_id', 'disease', 'meta_ref', 'modality', 'source_embedding', 'source_pdf', 'split', 'text']
IMG  case: case_id img_ref: image_path | example keys: ['case_id', 'disease', 'figure_no', 'image_path', 'meta_ref', 'modality', 'source_pdf']


## 5) Embedders + reranker

- `BAAI/bge-base-en-v1.5`: text embeddings for initial retrieval  
- `BAAI/bge-reranker-base`: reranks the retrieved passages for quality


In [None]:
print("Loading BGE embedder...")
bge = SentenceTransformer("BAAI/bge-base-en-v1.5", device=AUX_DEVICE)

print("Loading reranker...")
reranker = CrossEncoder("BAAI/bge-reranker-base", device=AUX_DEVICE)


@torch.no_grad()
def embed_text_bge(text: str):
    """Embed one string with BGE and return the vector as a plain Python list.

    The embedding is L2-normalized (normalize_embeddings=True), so its dot
    product with another normalized vector equals cosine similarity.
    """
    (vector,) = bge.encode([text], normalize_embeddings=True)
    return vector.tolist()


Loading BGE embedder...


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/777 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/366 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Loading reranker...


config.json:   0%|          | 0.00/799 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.11G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/443 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.1M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/279 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

## 6) Image loading utilities

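`load_image_any(ref)` tries, in order:
1) a URL (`http://` / `https://`)  
2) an existing local path (absolute, or relative to the working directory)  
3) `LOCAL_IMAGE_ROOT / ref` (for relative refs stored in Qdrant payload)

It returns `None` instead of raising when the image cannot be loaded.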


In [None]:
def _downscale_image(im, max_side=MAX_IMAGE_SIDE):
    """Return `im` as RGB, shrunk in place so neither side exceeds `max_side` pixels.

    Aspect ratio is preserved (PIL `thumbnail`), and images are never enlarged.
    Best effort: if conversion or resizing fails, the image is returned as far
    as it got. None passes straight through.
    """
    if im is not None:
        try:
            im = im.convert("RGB")
            im.thumbnail((max_side, max_side))
        except Exception:
            pass  # deliberately best effort: keep whatever image we already have
    return im

def _open_rgb(src):
    """Open an image path/file-object, force-load it as RGB and close the source."""
    # The context manager releases the file handle deterministically; convert()
    # decodes the pixels before the handle is closed.
    with Image.open(src) as im:
        return im.convert("RGB")


def load_image_any(ref):
    """Best-effort load of an image reference as a downscaled RGB PIL image.

    `ref` may be a URL string, a local path, a path relative to LOCAL_IMAGE_ROOT,
    or a list whose first element is one of those. The resolution order is:
      1) http(s) URL (20 s timeout),
      2) an existing local *file* at `ref` (absolute, or relative to the CWD),
      3) LOCAL_IMAGE_ROOT / ref.

    Returns None for missing/empty refs and on any load failure. Callers treat
    None as "image not loaded".
    """
    if ref is None:
        return None
    if isinstance(ref, list):
        if not ref:
            return None  # previously stringified to the bogus path "[]"
        ref = ref[0]
    ref = str(ref).strip()
    if not ref:
        return None

    try:
        if ref.startswith(("http://", "https://")):
            r = requests.get(ref, timeout=20)
            r.raise_for_status()
            return _downscale_image(_open_rgb(io.BytesIO(r.content)))
        # isfile (not exists): a directory ref must not stop us from trying the
        # LOCAL_IMAGE_ROOT fallback.
        for cand in (ref, os.path.join(LOCAL_IMAGE_ROOT, ref)):
            if os.path.isfile(cand):
                return _downscale_image(_open_rgb(cand))
    except Exception:
        return None  # deliberate best effort: unreadable image / network error
    return None

def load_case_images_from_folder(cdir, max_images=2):
    """Load up to `max_images` patient images from a test-case folder.

    Images under "<cdir>/images/" come first (sorted), followed by any other
    images anywhere under <cdir> (sorted). Extensions are matched
    case-insensitively, so ".JPG"/".PNG" files are found too. Files that fail
    to decode are skipped and do not count toward `max_images`.
    `max_images=None` means no limit.

    Returns a list of RGB PIL images, each downscaled to MAX_IMAGE_SIDE.
    """
    exts = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tif", ".tiff"}

    def _image_paths_under(root):
        # Recursive, case-insensitive extension match, sorted for stable order.
        found = []
        for dirpath, _dirnames, filenames in os.walk(root):
            for name in filenames:
                if os.path.splitext(name)[1].lower() in exts:
                    found.append(os.path.join(dirpath, name))
        return sorted(found)

    img_dir = os.path.join(cdir, "images")
    roots = ([img_dir] if os.path.isdir(img_dir) else []) + [cdir]

    # Preferred root first. De-dup because <cdir> also contains <cdir>/images.
    ordered, seen = [], set()
    for root in roots:
        for p in _image_paths_under(root):
            if p not in seen:
                seen.add(p)
                ordered.append(p)

    imgs = []
    for p in ordered:
        if max_images is not None and len(imgs) >= max_images:
            break
        try:
            with Image.open(p) as im:
                rgb = im.convert("RGB")
        except Exception:
            continue  # not a decodable image: skip it, try the next file
        imgs.append(_downscale_image(rgb, max_side=MAX_IMAGE_SIDE))
    return imgs


## 7) Payload helpers


In [None]:
def get_case(payload, preferred_key):
    """Return the payload's case id as a string, or "unknown".

    `preferred_key` wins when present. Otherwise the value of the first key
    whose name contains "case" (case-insensitive) is used.
    """
    if not isinstance(payload, dict):
        return "unknown"
    if preferred_key and preferred_key in payload:
        return str(payload[preferred_key])
    case_like = [value for key, value in payload.items() if "case" in str(key).lower()]
    return str(case_like[0]) if case_like else "unknown"

def get_text(payload):
    """Return the chunk text stored under TEXT_TEXT_KEY, or None.

    List values are joined with newlines. Anything else is stringified.
    """
    if isinstance(payload, dict) and TEXT_TEXT_KEY and TEXT_TEXT_KEY in payload:
        raw = payload[TEXT_TEXT_KEY]
        return "\n".join(str(part) for part in raw) if isinstance(raw, list) else str(raw)
    return None

def get_disease(payload):
    """Return the payload's "disease" label as a string, or None if absent/falsy."""
    label = payload.get("disease") if isinstance(payload, dict) else None
    return str(label) if label else None

def get_img_ref(payload):
    """Return the raw image reference stored under IMG_REF_KEY, or None."""
    if isinstance(payload, dict) and IMG_REF_KEY:
        return payload.get(IMG_REF_KEY)
    return None



# --------------------
# meta_ref -> structured case JSON (optional)
# --------------------
def get_meta_ref(payload):
    """Locate the structured-case JSON reference (Drive path or URL) in a payload.

    A fixed list of likely key names is checked first, and the first truthy
    value is returned as-is. Otherwise, any string value ending in ".json"
    whose key mentions "meta" or "ref" is returned, stripped. Returns None if
    nothing qualifies or the payload is not a dict.
    """
    if not isinstance(payload, dict):
        return None

    known_keys = ("meta_ref", "metaRef", "meta_path", "metaPath", "json_path", "jsonPath", "meta_json", "meta_json_path")
    direct = next((payload[k] for k in known_keys if payload.get(k)), None)
    if direct:
        return direct

    # Heuristic: a JSON-looking path under a meta/ref-ish key.
    for key, value in payload.items():
        if not isinstance(value, str):
            continue
        key_l = str(key).lower()
        if value.strip().lower().endswith(".json") and ("meta" in key_l or "ref" in key_l):
            return value.strip()
    return None

_META_CACHE = {}


def _read_meta_file(p):
    """Read JSON at `p` (or MyDrive-relative `p`). Returns None when unavailable/invalid."""
    if p.startswith(("http://", "https://")):
        # Network fetches are intentionally disabled during eval (enable requests.get if needed).
        return None
    path = p
    if not os.path.exists(path):
        # Common case: the payload stores a path relative to /content/drive/MyDrive.
        drive_path = os.path.join("/content/drive/MyDrive", p.lstrip("/"))
        if os.path.exists(drive_path):
            path = drive_path
    if not os.path.exists(path):
        return None
    try:
        with open(path, "r", encoding="utf-8") as fh:
            return json.load(fh)
    except Exception:
        return None  # best effort: unreadable or malformed JSON is treated as missing


def load_meta_json(meta_ref):
    """Load the structured case JSON referenced by `meta_ref`, memoized per reference.

    `meta_ref` may be a path string or a list whose first element is one. Results,
    including failures (None), are cached in _META_CACHE keyed by the stripped
    reference string, so repeated lookups never touch the disk again.
    """
    if meta_ref is None:
        return None
    if isinstance(meta_ref, list) and meta_ref:
        meta_ref = meta_ref[0]
    cache_key = str(meta_ref).strip()
    if not cache_key:
        return None

    if cache_key not in _META_CACHE:
        _META_CACHE[cache_key] = _read_meta_file(cache_key)
    return _META_CACHE[cache_key]

def _clean_snip(x: str, max_len=MAX_META_FIELD_CHARS):
    """Collapse all whitespace runs to single spaces and cap the length.

    Text longer than `max_len` is cut at `max_len` characters (trailing
    whitespace removed) and tagged " [TRUNCATED]".
    """
    collapsed = re.sub(r"\s+", " ", str(x)).strip()
    if len(collapsed) <= max_len:
        return collapsed
    return f"{collapsed[:max_len].rstrip()} [TRUNCATED]"

def format_meta_for_prompt(meta: dict):
    """Render selected fields of a structured case JSON as one compact prompt line.

    The fields are emitted in a fixed clinical order as "Title Case Label: text",
    joined by " | ". Each value is cleaned/truncated by _clean_snip. Returns
    None when `meta` is not a dict or none of the fields has a truthy value.
    """
    if not isinstance(meta, dict):
        return None

    sections = (
        "chief_complaint",
        "history_of_present_illness",
        "exposure_and_epidemiology",
        "vitals",
        "physical_exam",
        "labs_and_diagnostics",
    )
    rendered = [
        f"{name.replace('_', ' ').title()}: {_clean_snip(meta[name])}"
        for name in sections
        if meta.get(name)
    ]
    return " | ".join(rendered) or None

def collect_meta_from_hits(reranked_hits, max_cases=2):
    """Build the "structured case notes" prompt block from the top reranked hits.

    Each distinct case id is visited once, in rank order. Its meta_ref JSON is
    loaded and formatted, and non-empty snippets are emitted as
    "[case=<id>] <snippet>" lines. The scan stops once `max_cases` snippets are
    collected. Returns the newline-joined lines, or "(no structured meta found)".
    """
    snippets = []
    visited = set()

    for hit in reranked_hits or []:
        payload = hit.get("payload") or {}
        case_id = get_case(payload, TEXT_CASE_KEY)
        if case_id in visited:
            continue
        visited.add(case_id)

        snippet = format_meta_for_prompt(load_meta_json(get_meta_ref(payload)))
        if snippet:
            snippets.append(f"[case={case_id}] {snippet}")

        if len(snippets) >= max_cases:
            break

    return "\n".join(snippets) or "(no structured meta found)"


## 8) Retrieve text → rerank → candidates


In [None]:
def retrieve_with_rerank(query_text: str):
    """Vector-search the text collection, rerank with the cross-encoder, derive candidates.

    Returns (reranked_hits, candidates, debug):
      reranked_hits -- up to RERANK_TOP Qdrant hits, best first. Each hit dict is
                       annotated in place with a float "_rerank_score".
      candidates    -- distinct payload "disease" labels of those hits, de-duplicated
                       case-insensitively (first spelling kept, whitespace-stripped).
      debug         -- {"text_hits": ...}: the reranked hits, or the raw search hits
                       when none of them carried any text.
    """
    raw_hits = qdrant_search(TEXT_COLLECTION, embed_text_bge(query_text), limit=VEC_TOPN, with_payload=True)

    # Only hits that actually carry text can be scored by the cross-encoder.
    scorable = [(hit, get_text(hit.get("payload") or {})) for hit in raw_hits]
    scorable = [(hit, txt) for hit, txt in scorable if txt]
    if not scorable:
        return [], [], {"text_hits": raw_hits}

    scores = reranker.predict([(query_text, txt) for _, txt in scorable])  # higher = more relevant

    ranked_idx = sorted(range(len(scorable)), key=lambda i: scores[i], reverse=True)[:RERANK_TOP]
    reranked_hits = []
    for i in ranked_idx:
        hit = scorable[i][0]
        hit["_rerank_score"] = float(scores[i])
        reranked_hits.append(hit)

    first_spelling = {}
    for hit in reranked_hits:
        disease = get_disease(hit.get("payload") or {})
        if disease:
            first_spelling.setdefault(disease.strip().lower(), disease.strip())
    candidates = list(first_spelling.values())

    return reranked_hits, candidates, {"text_hits": reranked_hits}


## 9) JOIN: retrieved text case_id → image collection


In [None]:
def top_case_ids_from_text_hits(reranked_hits, max_cases=3):
    """Return up to `max_cases` distinct, known case ids from the hits, in rank order.

    Hits whose case id resolves to "unknown" (or empty) are skipped.
    `max_cases <= 0` returns []. Previously the limit was only checked after
    appending, so `max_cases=0` (e.g. to disable reference images) still
    returned one case.
    """
    if max_cases <= 0:
        return []

    seen = set()
    out = []
    for h in (reranked_hits or []):
        cid = get_case(h.get("payload") or {}, TEXT_CASE_KEY)
        if not cid or cid == "unknown":
            continue
        if cid not in seen:
            seen.add(cid)
            out.append(cid)
        if len(out) >= max_cases:
            break
    return out

def build_case_id_filter(case_id, key_name):
    """Build a Qdrant `should` filter matching `case_id` stored as a string or an int.

    Ids may have been ingested as integers for purely numeric case ids, so a
    numeric id also gets an integer match OR-ed in.
    """
    s = str(case_id).strip()
    should = [{"key": key_name, "match": {"value": s}}]
    # isdecimal(), not isdigit(): isdigit() accepts characters such as "²"
    # that int() rejects, which made this function raise ValueError.
    if s.isdecimal():
        should.append({"key": key_name, "match": {"value": int(s)}})
    return {"should": should}

def qdrant_fetch_all_by_case_id(collection, case_id, case_key, limit_per_page=64, max_points=200):
    """Page through every point whose `case_key` equals `case_id` (string or int form).

    Scrolls page by page, following next_page_offset, until the server reports no
    further pages or at least `max_points` points have been gathered. Returns at
    most `max_points` point dicts, with payloads and without vectors.
    """
    case_filter = build_case_id_filter(case_id, case_key)
    collected = []
    cursor = None
    while True:
        page = qdrant_scroll_filtered(
            collection,
            case_filter,
            limit=limit_per_page,
            offset=cursor,
            with_payload=True,
            with_vectors=False,
        )
        collected += page.get("result", {}).get("points", []) or []
        cursor = page.get("result", {}).get("next_page_offset", None)
        if cursor is None or len(collected) >= max_points:
            return collected[:max_points]

def fetch_case_images_from_qdrant(case_ids, max_images_per_case=2, max_total_images=4):
    """Load reference images for the given case ids via the image collection.

    For each case id (in order), every image point for that case is fetched and
    loaded until `max_images_per_case` succeed for that case. The scan stops as
    soon as `max_total_images` images are loaded in total.

    Returns (images, debug). `debug` holds one {"case_id", "ref", "loaded"} record
    per load attempt.

    Both caps are now always respected. Before, hitting the per-case cap broke out
    of the inner loop before the total cap was checked, so the total could be
    exceeded (the sanity run loaded 4 images with MAX_REF_IMGS=3). A cap of 0
    also still loaded one image.
    """
    imgs = []
    debug = []  # {case_id, ref, loaded}
    if max_total_images <= 0 or max_images_per_case <= 0:
        return imgs, debug

    for cid in (case_ids or []):
        if len(imgs) >= max_total_images:
            break
        pts = qdrant_fetch_all_by_case_id(
            collection=IMAGE_COLLECTION,
            case_id=cid,
            case_key=IMG_CASE_KEY,
            limit_per_page=64,
            max_points=200
        )

        loaded_for_case = 0
        for pt in pts:
            p = pt.get("payload") or {}
            ref = get_img_ref(p)
            im = load_image_any(ref)
            ok = im is not None
            debug.append({"case_id": str(cid), "ref": str(ref), "loaded": ok})
            if not ok:
                continue

            imgs.append(im)
            loaded_for_case += 1
            # Check both caps after every successful load.
            if len(imgs) >= max_total_images or loaded_for_case >= max_images_per_case:
                break
    return imgs, debug


## 10) Load Qwen2-VL model


In [None]:
print("Loading Qwen2-VL:", QWEN_MODEL_ID)

# fp16 halves VRAM on GPU; CPU inference stays in fp32.
if device == "cuda":
    dtype = torch.float16
else:
    dtype = torch.float32

# NOTE(review): recent transformers warns that `torch_dtype` is deprecated in favour of
# `dtype`. It is kept because the install cell only pins transformers>=4.44.
qwen = Qwen2VLForConditionalGeneration.from_pretrained(
    QWEN_MODEL_ID,
    device_map="auto",
    torch_dtype=dtype,
)
proc = AutoProcessor.from_pretrained(QWEN_MODEL_ID)


`torch_dtype` is deprecated! Use `dtype` instead!


Loading Qwen2-VL: BaoNgoc29/qwen2-tropical-disease-train-structure-epoch-3


adapter_config.json:   0%|          | 0.00/978 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/3.99G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/429M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/272 [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/2.20M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/826 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/392 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/613 [00:00<?, ?B/s]

chat_template.jinja: 0.00B [00:00, ?B/s]

video_preprocessor_config.json:   0%|          | 0.00/910 [00:00<?, ?B/s]

## 11) Prompt + generation (patient images + reference images)


In [None]:
def build_prompt(query_text, evidence_lines, candidates, ref_case_ids=None, n_patient_images=0, n_ref_images=0, meta_block=None):
    """Assemble the user-turn text sent to Qwen2-VL alongside the images.

    The prompt fixes a 2-line answer format ("Predicted disease: ..." then
    "Reasoning: ..."). It explains that patient images precede reference images,
    and appends the structured meta notes, up to 10 candidate diseases, the
    question and the evidence.

    Fix: the "You do not have to select from candidate diseases." sentence had no
    trailing newline and sat between the two format lines. The model therefore saw
    "...candidate diseases.Reasoning: <...>", which broke the 2-line spec. It now
    stands on its own line before the format block.
    """
    cand_block = "\n".join([f"- {c}" for c in candidates[:10]]) if candidates else "- (no candidates)"
    evidence = "\n".join(evidence_lines).strip() if evidence_lines else "(no evidence)"

    ref_case_ids = ref_case_ids or []
    ref_case_line = ", ".join(map(str, ref_case_ids)) if ref_case_ids else "(none)"
    meta_block = meta_block or "(none)"

    return (
        "You are a Vision Language Model specialized in tropical disease recognition from clinical images. Your task is to analyze the provided clinical images and symptoms, then provide the most likely diagnosis and how to confirm it. Answer in ENGLISH ONLY. Keep the answer concise and medically grounded. Avoid adding unrelated information.\n"
        "You do not have to select from candidate diseases.\n"
        "Output EXACTLY 2 lines:\n"
        "Predicted disease: <one most likely diagnosis>\n"
        "Reasoning: <1 full relevant sentence based on the image and symptoms>\n\n"
        f"Images note:\n"
        f"- First {n_patient_images} image(s) are PATIENT images.\n"
        f"- Next {n_ref_images} image(s) are REFERENCE images from retrieved cases.\n"
        f"- Reference case_id(s): {ref_case_line}\n\n"
        f"Retrieved structured case notes (from meta_ref):\n{meta_block}\n\n"
        f"Candidate diseases:\n{cand_block}\n\n"
        f"Question:\n{query_text}\n\n"
        f"Evidence:\n{evidence}"
    )


@torch.no_grad()
def qwen_predict(query_text: str, reranked_hits, candidates, patient_images=None, ref_images=None, ref_case_ids=None):
    """Generate one Qwen2-VL answer for a query plus retrieved context and images.

    Evidence is the top-2 reranked chunks, each whitespace-collapsed and cut to 900
    characters and prefixed with its case id. Structured meta notes cover up to 2
    retrieved cases. Images are attached patient-first, then reference, which is
    the order the prompt describes.

    Returns the decoded generated text, with prompt tokens removed and special
    tokens skipped.
    """
    evidence_lines = []
    for hit in (reranked_hits or [])[:2]:
        payload = hit.get("payload") or {}
        case_id = get_case(payload, TEXT_CASE_KEY)
        chunk = re.sub(r"\s+", " ", get_text(payload) or "").strip()
        evidence_lines.append(f"[{case_id}] {chunk[:900]}")

    patient_images = patient_images or []
    ref_images = ref_images or []
    ref_case_ids = ref_case_ids or []

    prompt_text = build_prompt(
        query_text=query_text,
        evidence_lines=evidence_lines,
        candidates=candidates,
        ref_case_ids=ref_case_ids,
        n_patient_images=len(patient_images),
        n_ref_images=len(ref_images),
        meta_block=collect_meta_from_hits(reranked_hits, max_cases=2),
    )

    # Patient images first, then reference images (the prompt relies on this order).
    content = [{"type": "image", "image": im} for im in [*patient_images, *ref_images]]
    content.append({"type": "text", "text": prompt_text})
    messages = [{"role": "user", "content": content}]

    chat = proc.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    if not image_inputs:
        image_inputs = None  # text-only turn: pass None rather than an empty image list

    # NOTE(review): with truncation=False, max_length appears to be ignored by the
    # tokenizer, so MAX_INPUT_TOKENS does not actually cap the prompt. Blind truncation
    # could also cut image placeholder tokens, so confirm before enabling it.
    inputs = proc(
        text=[chat],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        truncation=False,
        max_length=MAX_INPUT_TOKENS,
        return_tensors="pt"
    ).to(device)

    with torch.inference_mode():
        generated = qwen.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=DO_SAMPLE
        )

    n_prompt_tokens = inputs["input_ids"].shape[1]
    return proc.tokenizer.decode(generated[0][n_prompt_tokens:], skip_special_tokens=True).strip()


## 12) Parse output + snap diagnosis to candidate list


In [None]:
def norm(s: str) -> str:
    """Normalize a disease label for comparison.

    The text is lower-cased. Every character other than a-z, 0-9, whitespace,
    "-", "(" or ")" becomes a space, and whitespace runs collapse to one space.
    None is treated as "".
    """
    cleaned = re.sub(r"[^a-z0-9\s\-\(\)]", " ", (s or "").lower())
    return re.sub(r"\s+", " ", cleaned).strip()

def extract_diagnosis_line(ans: str) -> str:
    """Pull the diagnosis text out of a model answer.

    The value after a "Predicted disease:" label (the format build_prompt asks for)
    or a "Diagnosis:" label is recognized case-insensitively. Markdown decoration
    such as "**Predicted disease:** X" or "- Diagnosis: X" is tolerated, and the
    label itself is stripped. Without a label, the first non-empty line is
    returned. Empty input gives "".

    Previously only "Diagnosis:" was recognized, so the usual answer
    "Predicted disease: X" leaked its label into the prediction. The extra
    "disease" token then biased fuzzy candidate snapping.
    """
    t = (ans or "").strip()
    m = re.search(
        r"(?im)^[ \t*_#>-]*(?:predicted[ \t]+disease|diagnosis)[ \t*_]*:[\s*_]*(.+)$",
        t,
    )
    if m:
        value = re.split(r"[\n\r]", m.group(1).strip(), maxsplit=1)[0]
        return value.strip().strip("*_").strip()
    for ln in t.splitlines():
        ln = ln.strip()
        if ln:
            return ln
    return ""

def snap_to_candidates(pred: str, candidates, min_score=0):
    """Map a free-text prediction onto the retrieved candidate list.

    Returns (label, mode):
      (pred, "no_candidates")       -- nothing to snap to;
      (pred, "empty_pred")          -- prediction is blank after normalization.
                                       Previously this snapped to candidates[0]
                                       with score 0, crediting an empty answer;
      (candidate, "exact")          -- normalized exact match;
      (best, "fuzzy(<score>)")      -- best token_set_ratio match, score >= min_score;
      (pred, "no_snap(<score>)")    -- best match scored below min_score.

    Ties keep the earliest candidate. min_score=0 preserves the historical
    always-snap behaviour for non-empty predictions.
    """
    if not candidates:
        return pred, "no_candidates"
    if not (pred or "").strip():
        return pred, "empty_pred"
    pred_n = norm(pred)
    if not pred_n:
        return pred, "empty_pred"

    for c in candidates:
        if norm(c) == pred_n:
            return c, "exact"

    best, best_score = None, -1
    for c in candidates:
        sc = token_set_ratio(norm(c), pred_n)
        if sc > best_score:
            best_score, best = sc, c
    if best_score < min_score:
        return pred, f"no_snap({best_score})"
    return best, f"fuzzy({best_score})"


## 13) Quick sanity test (optional)


In [None]:
# Numbered test-case folders (001, 002, ...). Same selection rule as the evaluation loop:
# directories only, so a stray file named like "001" is not picked up.
folders = sorted(
    d for d in os.listdir(EVAL_ROOT)
    if re.fullmatch(r"\d{3}", d) and os.path.isdir(os.path.join(EVAL_ROOT, d))
)
print("Found folders:", folders)

if folders:
    cdir = os.path.join(EVAL_ROOT, folders[0])
    query_path = os.path.join(cdir, "query.txt")
    if os.path.exists(query_path):
        # `with` closes the file deterministically instead of leaking the handle.
        with open(query_path, "r", encoding="utf-8") as f:
            query_text = f.read().strip()
        reranked_hits, candidates, _ = retrieve_with_rerank(query_text)

        patient_imgs = load_case_images_from_folder(cdir, max_images=MAX_PATIENT_IMGS)
        ref_case_ids = top_case_ids_from_text_hits(reranked_hits, max_cases=MAX_REF_CASES)
        ref_imgs, ref_dbg = fetch_case_images_from_qdrant(ref_case_ids, max_images_per_case=2, max_total_images=MAX_REF_IMGS)

        print("patient_imgs:", len(patient_imgs))
        print("ref_case_ids:", ref_case_ids)
        print("ref_imgs:", len(ref_imgs))
        print("ref_dbg sample:", ref_dbg[:3])


Found folders: ['001', '002', '003', '004', '005', '006', '007', '008', '009', '010', '011', '012', '013', '014', '015', '016', '017', '018', '019', '020', '021', '022', '023', '024', '025', '026', '027', '028', '029', '030', '031', '032', '033']
patient_imgs: 0
ref_case_ids: ['21---A-35-Year-Old-American-Man-With-Fatigue-_2022_Clinical-Cases-in-Tropica', '93---A-35-Year-Old-Male-Logger-from-Peru-With-Fe_2022_Clinical-Cases-in-Trop', '52---A-56-Year-Old-Man-from-Peru-With-Prolonged-_2022_Clinical-Cases-in-Trop']
ref_imgs: 4
ref_dbg sample: [{'case_id': '21---A-35-Year-Old-American-Man-With-Fatigue-_2022_Clinical-Cases-in-Tropica', 'ref': '/content/drive/MyDrive/Project-AI/Data_for_RAG/diseases_extracted_images/021/021_p1_fig_1.png', 'loaded': True}, {'case_id': '93---A-35-Year-Old-Male-Logger-from-Peru-With-Fe_2022_Clinical-Cases-in-Trop', 'ref': '/content/drive/MyDrive/Project-AI/Data_for_RAG/diseases_extracted_images/093/093_p2_fig_1.png', 'loaded': True}, {'case_id': '93---A-35-Year

## 14) Evaluation loop → Excel


In [None]:
case_dirs = sorted(
    os.path.join(EVAL_ROOT, d) for d in os.listdir(EVAL_ROOT)
    if os.path.isdir(os.path.join(EVAL_ROOT, d)) and re.fullmatch(r"\d{3}", d)
)
# Explicit check instead of `assert` (stripped under `python -O`).
if not case_dirs:
    raise FileNotFoundError(f"No numbered folders like 001 found under {EVAL_ROOT}")


def _read_text(path):
    """Read a UTF-8 text file, strip surrounding whitespace, close the handle."""
    with open(path, "r", encoding="utf-8") as fh:
        return fh.read().strip()


rows = []
for cdir in case_dirs:
    folder = os.path.basename(cdir)
    q_path = os.path.join(cdir, "query.txt")
    gt_path = os.path.join(cdir, "ground_truth.txt")
    meta_path = os.path.join(cdir, "meta_eval.json")

    if not (os.path.exists(q_path) and os.path.exists(gt_path) and os.path.exists(meta_path)):
        rows.append({"folder": folder, "status": "skipped_missing_files"})
        continue

    print("Processing case:", folder)
    try:
        query = _read_text(q_path)
        gt_answer = _read_text(gt_path)
        with open(meta_path, "r", encoding="utf-8") as fh:
            meta = json.load(fh)

        test_id = meta.get("test_id", folder)
        gt_disease = meta.get("ground_truth_disease", "")

        # --- Retrieve text + rerank
        reranked_hits, candidates, _ = retrieve_with_rerank(query)

        # --- Patient images from folder
        patient_images = load_case_images_from_folder(cdir, max_images=MAX_PATIENT_IMGS)

        # --- JOIN reference images by case_id
        ref_case_ids = top_case_ids_from_text_hits(reranked_hits, max_cases=MAX_REF_CASES)
        ref_images, ref_debug = fetch_case_images_from_qdrant(
            ref_case_ids,
            max_images_per_case=2,
            max_total_images=MAX_REF_IMGS
        )

        # --- Model prediction
        pred_answer = qwen_predict(
            query_text=query,
            reranked_hits=reranked_hits,
            candidates=candidates,
            patient_images=patient_images,
            ref_images=ref_images,
            ref_case_ids=ref_case_ids
        )

        raw_diag = extract_diagnosis_line(pred_answer)
        pred_disease, snap_mode = snap_to_candidates(raw_diag, candidates)
        fuzzy = token_set_ratio(norm(pred_disease), norm(gt_disease)) if gt_disease else None

        # Trace from reranked top hits (for debugging)
        trace = [
            {
                "case_id": get_case(h.get("payload") or {}, TEXT_CASE_KEY),
                "rerank_score": float(h.get("_rerank_score", 0.0)),
                "disease": get_disease(h.get("payload") or {}),
            }
            for h in reranked_hits[:5]
        ]
    except Exception as e:
        # Loop boundary: one bad case (CUDA OOM, Qdrant/network error, malformed
        # meta JSON) is recorded and skipped instead of aborting the whole run and
        # losing every row collected so far.
        err = f"{type(e).__name__}: {str(e)[:300]}"
        print(f"  !! case {folder} failed: {err}")
        rows.append({"folder": folder, "status": f"error: {err}"})
        continue
    finally:
        # Memory hygiene (important on 16GB Colab GPUs), also after a failed case.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    rows.append({
        "test_id": test_id,
        "folder": folder,
        "status": "ok",
        "gt_disease": gt_disease,
        "pred_disease": pred_disease,
        "disease_fuzzy": fuzzy,
        "snap_mode": snap_mode,
        "candidates": "; ".join(candidates[:12]),
        "ref_case_ids": "; ".join(map(str, ref_case_ids)),
        "n_patient_images": len(patient_images),
        "n_ref_images": len(ref_images),
        "ref_image_debug": json.dumps(ref_debug, ensure_ascii=False),
        "rag_trace": json.dumps(trace, ensure_ascii=False),
        "query": query,
        "gt_answer": gt_answer,
        "pred_answer": pred_answer,
        "raw_diagnosis_line": raw_diag,
    })


df = pd.DataFrame(rows)

with pd.ExcelWriter(OUT_XLSX, engine="openpyxl") as writer:
    df.to_excel(writer, index=False, sheet_name="eval")
    ws = writer.sheets["eval"]
    ws.freeze_panes = "A2"

print("Saved:", OUT_XLSX)
# reindex (not df[[...]]) so the summary still prints when some columns are absent,
# e.g. when every case was skipped or failed.
print(df.reindex(columns=["test_id", "folder", "gt_disease", "pred_disease", "disease_fuzzy", "snap_mode", "n_patient_images", "n_ref_images"]).head(10))


Processing case: 001


The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processing case: 002
Processing case: 003
Processing case: 004
Processing case: 005
Processing case: 006
Processing case: 007
Processing case: 008
Processing case: 009
Processing case: 010
Processing case: 011
Processing case: 012
Processing case: 013
Processing case: 014
Processing case: 015
Processing case: 016
Processing case: 017
Processing case: 018
Processing case: 019
Processing case: 020
Processing case: 021
Processing case: 022
Processing case: 023
Processing case: 024
Processing case: 025
Processing case: 026
Processing case: 027
Processing case: 028
Processing case: 029
Processing case: 030
Processing case: 031
Processing case: 032
Processing case: 033
Saved: /content/drive/MyDrive/eval_results.xlsx
  test_id folder                               gt_disease  \
0       1    001                  Enteric (typhoid) fever   
1       2    002                                  Malaria   
2       3    003                   Visceral leishmaniasis   
3       4    004                    