In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/breast-cancer-dataset-from-breakhis/fold1/fold1/test/200X/B_200X/SOB_B_PT-14-21998AB-200-014.png
/kaggle/input/breast-cancer-dataset-from-breakhis/fold1/fold1/test/200X/B_200X/SOB_B_F-14-23222AB-200-007.png
/kaggle/input/breast-cancer-dataset-from-breakhis/fold1/fold1/test/200X/B_200X/SOB_B_A-14-29960CD-200-012.png
/kaggle/input/breast-cancer-dataset-from-breakhis/fold1/fold1/test/200X/B_200X/SOB_B_TA-14-16184-200-002.png
/kaggle/input/breast-cancer-dataset-from-breakhis/fold1/fold1/test/200X/B_200X/SOB_B_PT-14-21998AB-200-021.png
/kaggle/input/breast-cancer-dataset-from-breakhis/fold1/fold1/test/200X/B_200X/SOB_B_TA-14-16184-200-018.png
/kaggle/input/breast-cancer-dataset-from-breakhis/fold1/fold1/test/200X/B_200X/SOB_B_F-14-21998CD-200-019.png
/kaggle/input/breast-cancer-dataset-from-breakhis/fold1/fold1/test/200X/B_200X/SOB_B_TA-14-16184-200-006.png
/kaggle/input/breast-cancer-dataset-from-breakhis/fold1/fold1/test/200X/B_200X/SOB_B_PT-14-21998AB-200-023.png
/kaggle/in

In [2]:
# ==========================================================
# Cell 0 — Install dependencies (run once)
# ==========================================================
try:
    import open_clip
except Exception:
    import sys
    # Install GPU-enabled FAISS plus the rest of the stack
    !pip install -q --upgrade open-clip-torch==2.23.0 faiss-gpu==1.7.2.post2 transformers sentence-transformers tqdm matplotlib scikit-learn requests biopython



In [3]:
# Cell 1 — Imports, basic setup, and utilities
import os
import json
import time
import math
import random
import requests
from pathlib import Path
from typing import List, Dict, Optional, Tuple

from PIL import Image
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

import torch
import open_clip
from open_clip import tokenize

# small utilities
def safe_makedir(p):
    Path(p).mkdir(parents=True, exist_ok=True)

# seed for reproducibility (best-effort)
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7dac46134330>

In [23]:
CONFIG = {
    "DATASET_DIR": "/kaggle/input/breast-cancer-dataset-from-breakhis/",
    "CORPUS_PATH": None,
    "OUT_DIR": "./breakhis_rag_outputs",
    "BIOMEDCLIP_HF_ID": "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224",
    "LLM_ID": "google/flan-t5-small",
    "DEVICE": "cuda" if torch.cuda.is_available() else "cpu",
    "BATCH_IMAGE": 8,
    "BATCH_TEXT": 128,
    "TOP_K": 5,
    "RETRIEVAL_BATCH": 512,              # batched FAISS queries
    "SIMILARITY_THRESHOLD": 0.3,
    "ONLINE_RAG_ENABLED": True,
    "ONLINE_SOURCE": "pubmed",
    "ONLINE_DEBUG_LIMIT": 10,
    "USE_FAISS_GPU": torch.cuda.is_available(),
}

print("CONFIG summary:")
for k, v in CONFIG.items():
    print(f" {k}: {v}")

OUT_DIR = Path(CONFIG["OUT_DIR"])
safe_makedir(OUT_DIR)

CONFIG summary:
 DATASET_DIR: /kaggle/input/breast-cancer-dataset-from-breakhis/
 CORPUS_PATH: None
 OUT_DIR: ./breakhis_rag_outputs
 BIOMEDCLIP_HF_ID: hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224
 LLM_ID: google/flan-t5-small
 DEVICE: cuda
 BATCH_IMAGE: 8
 BATCH_TEXT: 128
 TOP_K: 5
 RETRIEVAL_BATCH: 512
 SIMILARITY_THRESHOLD: 0.3
 ONLINE_RAG_ENABLED: True
 ONLINE_SOURCE: pubmed
 ONLINE_DEBUG_LIMIT: 10
 USE_FAISS_GPU: True


In [5]:
# Quick GPU sanity check
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA capability: {torch.cuda.get_device_capability(0)}")
    print(f"Allocated memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
else:
    print("CUDA not available – pipeline will fall back to CPU.")

CUDA device: Tesla P100-PCIE-16GB
CUDA capability: (6, 0)
Allocated memory: 0.00 GB


In [6]:
# Cell 3 — Load BiomedCLIP (OpenCLIP) with safe fallbacks
device = CONFIG["DEVICE"]
print("Using device:", device)


# Helper to load model and preprocess
def load_biomedclip(hf_id: str, device: str = "cpu"):
    """Try HF-aware loader first; fallback to built-in open_clip if needed.
    Returns: model, preprocess_callable
    """
    try:
        # HF-aware convenience function
        model, preprocess = open_clip.create_model_from_pretrained(hf_id, device=device)
        # create_model_from_pretrained may return model on CPU by default; move to device if GPU requested
        model = model.to(device).eval()
        print("Loaded BiomedCLIP via create_model_from_pretrained()")
        return model, preprocess
    except Exception as e:
        print("create_model_from_pretrained failed:", e)
        print("Attempting fallback: create_model_and_transforms with a built-in config (weights may not be BiomedCLIP)")
        try:
            # Fallback loads a generic ViT model (weights from laion or similar) — not ideal but keeps pipeline runnable
            model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms("ViT-B-16", pretrained="openai")
            model = model.to(device).eval()
            print("Loaded fallback ViT-B-16 (not BiomedCLIP weights). Consider fixing open-clip version or network access.")
            return model, preprocess_val
        except Exception as e2:
            raise RuntimeError(f"Failed to load model via open_clip: primary error: {e}; fallback error: {e2}")


# Load model
try:
    model, preprocess = load_biomedclip(CONFIG["BIOMEDCLIP_HF_ID"], device=device)
except Exception as e:
    # If model load fails fatally, provide a helpful message and abort pipeline
    raise RuntimeError(f"BiomedCLIP load failed: {e}. Restart kernel and ensure open-clip-torch >= 2.23.0 is installed and internet access is available.")


# Ensure model has encode_image and encode_text
if not hasattr(model, "encode_image") or not hasattr(model, "encode_text"):
    raise RuntimeError("Loaded model does not expose encode_image/encode_text APIs expected by the pipeline.")

print("Model loaded; preprocess callable available:", callable(preprocess))

Using device: cuda


2025-10-07 05:08:42.402351: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1759813722.424968     256 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1759813722.431780     256 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Loaded BiomedCLIP via create_model_from_pretrained()
Model loaded; preprocess callable available: True


In [7]:
# Cell 4 — Collect images from any dataset (data-agnostic, flexible)
DATASET_DIR = Path(CONFIG["DATASET_DIR"]).expanduser()
if not DATASET_DIR.exists():
    raise FileNotFoundError(f"DATASET_DIR not found at {DATASET_DIR}. Update CONFIG['DATASET_DIR'] to the correct path.")

# gather images recursively (supports all common formats)
def collect_images(root: Path) -> List[str]:
    """Recursively collect all image files from a directory."""
    exts = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".dcm", ".nii", ".nii.gz"}
    files = []
    for p in root.rglob("*"):
        if p.suffix.lower() in exts:
            files.append(str(p))
    files = sorted(files)
    return files

image_paths = collect_images(DATASET_DIR)
if len(image_paths) == 0:
    raise RuntimeError(f"No image files found under {DATASET_DIR}. Check dataset structure and ensure images exist.")

print(f"✓ Found {len(image_paths)} images under {DATASET_DIR}")
print(f"  Sample paths: {image_paths[:3]}")

✓ Found 15818 images under /kaggle/input/breast-cancer-dataset-from-breakhis
  Sample paths: ['/kaggle/input/breast-cancer-dataset-from-breakhis/fold1/fold1/test/100X/B_100X/SOB_B_A-14-22549G-100-001.png', '/kaggle/input/breast-cancer-dataset-from-breakhis/fold1/fold1/test/100X/B_100X/SOB_B_A-14-22549G-100-002.png', '/kaggle/input/breast-cancer-dataset-from-breakhis/fold1/fold1/test/100X/B_100X/SOB_B_A-14-22549G-100-003.png']


In [8]:
# Cell 5 — Compute (or load cached) image embeddings with robust error handling
IMAGE_EMB_FILE = OUT_DIR / "image_embeddings.npy"
IMAGE_PATHS_FILE = OUT_DIR / "image_paths.json"

if IMAGE_EMB_FILE.exists() and IMAGE_PATHS_FILE.exists():
    print("Loading cached image embeddings...")
    image_embeddings = np.load(str(IMAGE_EMB_FILE))
    with open(IMAGE_PATHS_FILE, "r") as f:
        saved_paths = json.load(f)
    # Basic sanity check
    if len(saved_paths) != image_embeddings.shape[0]:
        print("Warning: cached embeddings length mismatch; recomputing embeddings")
        compute_images = True
    else:
        compute_images = False
        image_paths = saved_paths  # Use cached paths to maintain consistency
else:
    compute_images = True

if compute_images:
    print("(Re)computing image embeddings. This may take time depending on dataset size and device.")
    batch = CONFIG["BATCH_IMAGE"] if device == "cuda" else 1  # smaller CPU batch
    emb_list = []
    paths_for_cache = []
    failed_images = []
    model.eval()
    
    for i in tqdm(range(0, len(image_paths), batch), desc="embed-images"):
        batch_paths = image_paths[i:i+batch]
        imgs = []
        valid_paths = []
        
        for p in batch_paths:
            try:
                im = Image.open(p).convert("RGB")
                img_t = preprocess(im)
                imgs.append(img_t)
                valid_paths.append(p)
            except Exception as e:
                failed_images.append((p, str(e)))
                print(f"Warning: failed to open image {p}: {e}")
        
        if len(imgs) == 0:
            continue
            
        tensor = torch.stack(imgs).to(device)
        with torch.no_grad():
            feats = model.encode_image(tensor)
            feats = feats / feats.norm(dim=-1, keepdim=True)
        emb_list.append(feats.cpu().numpy())
        paths_for_cache.extend(valid_paths)
    
    if len(emb_list) == 0:
        raise RuntimeError("No embeddings computed; check image reading step.")
    
    image_embeddings = np.concatenate(emb_list, axis=0).astype("float32")
    np.save(IMAGE_EMB_FILE, image_embeddings)
    with open(IMAGE_PATHS_FILE, "w") as f:
        json.dump(paths_for_cache, f)
    
    # Update image_paths to only include successfully processed images
    image_paths = paths_for_cache
    
    print(f"✓ Saved image embeddings to {IMAGE_EMB_FILE}")
    if failed_images:
        print(f"⚠ Failed to process {len(failed_images)} images")
        with open(OUT_DIR / "failed_images.json", "w") as f:
            json.dump(failed_images, f, indent=2)

# sanity
print(f"✓ Image embeddings shape: {image_embeddings.shape}")

Loading cached image embeddings...
✓ Image embeddings shape: (15818, 512)


In [9]:
# Cell 6 — Prepare text corpus: read user corpus or use small builtin sample
CORPUS_PATH = CONFIG["CORPUS_PATH"]

def load_corpus(corpus_path):
    """Load text corpus from file or directory."""
    corpus = []
    p = Path(corpus_path)
    if p.is_file():
        # assume one sentence per line
        with open(p, "r", encoding="utf-8", errors='ignore') as f:
            for line in f:
                t = line.strip()
                if t:
                    corpus.append(t)
    elif p.is_dir():
        # read all txt inside
        for txt in p.rglob("*.txt"):
            with open(txt, "r", encoding="utf-8", errors='ignore') as f:
                for line in f:
                    t = line.strip()
                    if t:
                        corpus.append(t)
    else:
        raise FileNotFoundError(f"CORPUS_PATH {corpus_path} not found")
    
    # deduplicate
    seen = set()
    out = []
    for s in corpus:
        if s not in seen:
            seen.add(s)
            out.append(s)
    return out

if CORPUS_PATH is None:
    # Enhanced built-in biomedical-style corpus (placeholder). Replace with a real corpus for production.
    prompt_corpus = [
        "H&E-stained breast tissue with tumor islands and pleomorphism",
        "benign breast histology with normal ducts",
        "histopathology slide showing mitotic figures and irregular nuclei",
        "melanocytic lesion with irregular nests",
        "necrosis and apoptotic bodies visible in tissue section",
        "stromal fibrosis and inflammation",
        "high mitotic index and cellular pleomorphism",
        "artifact, folding, or staining artifact present",
        "normal adipose tissue and connective stroma",
        "scattered inflammatory infiltrate in tissue",
        "ductal carcinoma in situ with comedo necrosis",
        "invasive ductal carcinoma with desmoplastic stroma",
        "lymphocytic infiltration at tumor margins",
        "vascular invasion by malignant cells",
        "glandular structures with cellular atypia",
        "tissue architecture disruption and loss of polarity",
        "chromatin condensation and nuclear enlargement",
        "abnormal mitotic figures present",
        "tissue section shows normal cellular morphology",
        "insufficient tissue or poor image quality for analysis"
    ]
    print("Using enhanced built-in prompt corpus (set CONFIG['CORPUS_PATH'] to use a real corpus).")
else:
    print(f"Loading corpus from {CORPUS_PATH}")
    prompt_corpus = load_corpus(CORPUS_PATH)

print(f"✓ Corpus size: {len(prompt_corpus)} unique texts")

Using enhanced built-in prompt corpus (set CONFIG['CORPUS_PATH'] to use a real corpus).
✓ Corpus size: 20 unique texts


In [10]:
# Clear CUDA cache and reset state
import torch
import gc

print("Clearing CUDA cache and resetting state...")
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    gc.collect()
print("✓ CUDA state cleared")

# Verify model is still functional
print("\nVerifying model state...")
try:
    test_img = torch.randn(1, 3, 224, 224).to(device)
    with torch.no_grad():
        test_feat = model.encode_image(test_img)
    print("✓ Model image encoding works")
except Exception as e:
    print(f"✗ Model may be corrupted: {e}")
    print("⚠ RECOMMENDATION: Restart kernel and re-run from Cell 3")

Clearing CUDA cache and resetting state...
✓ CUDA state cleared

Verifying model state...
✓ Model image encoding works


In [11]:
# DELETE CACHED TEXT EMBEDDINGS (run once to force recomputation with correct dimensions)
import os
from pathlib import Path

cache_files = [
    OUT_DIR / "text_embeddings.npy",
    OUT_DIR / "text_corpus.json"
]

for f in cache_files:
    if f.exists():
        os.remove(f)
        print(f"✓ Deleted {f}")
    else:
        print(f"  {f} does not exist")

print("\n✓ Cache cleared. Cell 7 will now recompute with correct dimensions.")

✓ Deleted breakhis_rag_outputs/text_embeddings.npy
✓ Deleted breakhis_rag_outputs/text_corpus.json

✓ Cache cleared. Cell 7 will now recompute with correct dimensions.


In [12]:
# Cell 7 — Compute text embeddings with dimension alignment
TEXT_EMB_FILE = OUT_DIR / "text_embeddings.npy"
TEXT_CORPUS_FILE = OUT_DIR / "text_corpus.json"

if TEXT_EMB_FILE.exists() and TEXT_CORPUS_FILE.exists():
    print("Loading cached text embeddings...")
    text_embeddings = np.load(str(TEXT_EMB_FILE))
    with open(TEXT_CORPUS_FILE, "r") as f:
        saved_corpus = json.load(f)
    
    if len(saved_corpus) == text_embeddings.shape[0]:
        compute_texts = False
        prompt_corpus = saved_corpus
        print(f"✓ Loaded cached text embeddings: {text_embeddings.shape}")
    else:
        compute_texts = True
else:
    compute_texts = True

if compute_texts:
    print(f"Computing text embeddings for {len(prompt_corpus)} texts...")
    
    # Clean corpus
    prompt_corpus = [t.strip() for t in prompt_corpus if t and t.strip()]
    
    if len(prompt_corpus) == 0:
        raise RuntimeError("Corpus is empty!")
    
    print("Encoding text using BiomedCLIP text encoder (CPU mode)...")
    
    # Move model to CPU for text encoding
    print("Moving model to CPU...")
    model.cpu()
    model.eval()
    
    all_embeddings = []
    batch_size = 16
    
    for i in tqdm(range(0, len(prompt_corpus), batch_size), desc="encode-texts"):
        batch_texts = prompt_corpus[i:i+batch_size]
        
        try:
            with torch.no_grad():
                tokens = tokenize(batch_texts)
                text_features = model.encode_text(tokens)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
                all_embeddings.append(text_features.numpy())
                
        except Exception as e:
            print(f"\n⚠ Batch {i} failed: {e}")
            # Fallback: one by one
            for single_text in batch_texts:
                try:
                    with torch.no_grad():
                        tokens = tokenize([single_text])
                        text_features = model.encode_text(tokens)
                        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
                        all_embeddings.append(text_features.numpy())
                except Exception as e2:
                    print(f"   ✗ Failed: '{single_text[:40]}...' - {e2}")
                    # Use zero embedding as fallback
                    if hasattr(model, 'text') and hasattr(model.text, 'output_dim'):
                        dim = model.text.output_dim
                    else:
                        dim = 768  # Default text dimension
                    zero_emb = np.zeros((1, dim), dtype=np.float32)
                    all_embeddings.append(zero_emb)
    
    # Move model back to GPU
    print(f"Moving model back to {device}...")
    model.to(device)
    
    if len(all_embeddings) == 0:
        raise RuntimeError("No text embeddings were generated!")
    
    # Concatenate
    text_embeddings_raw = np.concatenate(all_embeddings, axis=0).astype('float32')
    
    # Clean NaN/Inf
    text_embeddings_raw = np.nan_to_num(text_embeddings_raw, nan=0.0, posinf=0.0, neginf=0.0)
    
    print(f"✓ Raw text embeddings shape: {text_embeddings_raw.shape}")
    
    # Check dimension compatibility
    img_dim = image_embeddings.shape[1]
    text_dim = text_embeddings_raw.shape[1]
    
    print(f"  Image embedding dim: {img_dim}D")
    print(f"  Text embedding dim: {text_dim}D")
    
    if img_dim != text_dim:
        print(f"\n⚠ Dimension mismatch detected! Applying projection...")
        
        if text_dim > img_dim:
            # Project down: Use PCA or simple truncation
            print(f"  Projecting text embeddings from {text_dim}D → {img_dim}D using PCA...")
            from sklearn.decomposition import PCA
            
            pca = PCA(n_components=img_dim, random_state=42)
            text_embeddings = pca.fit_transform(text_embeddings_raw)
            
            # Re-normalize after projection
            norms = np.linalg.norm(text_embeddings, axis=1, keepdims=True)
            text_embeddings = text_embeddings / (norms + 1e-8)
            text_embeddings = text_embeddings.astype('float32')
            
            print(f"  ✓ Projected using PCA (explained variance: {pca.explained_variance_ratio_.sum():.3f})")
            
        else:
            # Project up: Pad with zeros
            print(f"  Padding text embeddings from {text_dim}D → {img_dim}D...")
            padding = np.zeros((text_embeddings_raw.shape[0], img_dim - text_dim), dtype=np.float32)
            text_embeddings = np.hstack([text_embeddings_raw, padding])
            
            # Re-normalize
            norms = np.linalg.norm(text_embeddings, axis=1, keepdims=True)
            text_embeddings = text_embeddings / (norms + 1e-8)
            
            print(f"  ✓ Padded with zeros")
    else:
        text_embeddings = text_embeddings_raw
        print("  ✓ Dimensions match, no projection needed")
    
    # Save
    np.save(TEXT_EMB_FILE, text_embeddings)
    with open(TEXT_CORPUS_FILE, "w") as f:
        json.dump(prompt_corpus, f, indent=2)
    
    print(f"✓ Saved text embeddings to {TEXT_EMB_FILE}")

# Final validation
print(f"\n✓ Text embeddings ready: {text_embeddings.shape}")
print(f"  - Corpus size: {len(prompt_corpus)}")
print(f"  - Embedding dim: {text_embeddings.shape[1]}")

# Dimension compatibility check
if image_embeddings.shape[1] != text_embeddings.shape[1]:
    raise RuntimeError(
        f"DIMENSION MISMATCH! "
        f"Image: {image_embeddings.shape[1]}D, Text: {text_embeddings.shape[1]}D"
    )
else:
    print(f"✓ Dimension check PASSED: both are {text_embeddings.shape[1]}D")

Computing text embeddings for 20 texts...
Encoding text using BiomedCLIP text encoder (CPU mode)...
Moving model to CPU...


encode-texts:   0%|          | 0/2 [00:00<?, ?it/s]


⚠ Batch 0 failed: index out of range in self
   ✗ Failed: 'H&E-stained breast tissue with tumor isl...' - index out of range in self
   ✗ Failed: 'benign breast histology with normal duct...' - index out of range in self
   ✗ Failed: 'histopathology slide showing mitotic fig...' - index out of range in self
   ✗ Failed: 'melanocytic lesion with irregular nests...' - index out of range in self
   ✗ Failed: 'necrosis and apoptotic bodies visible in...' - index out of range in self
   ✗ Failed: 'stromal fibrosis and inflammation...' - index out of range in self
   ✗ Failed: 'high mitotic index and cellular pleomorp...' - index out of range in self
   ✗ Failed: 'artifact, folding, or staining artifact ...' - index out of range in self
   ✗ Failed: 'normal adipose tissue and connective str...' - index out of range in self
   ✗ Failed: 'scattered inflammatory infiltrate in tis...' - index out of range in self
   ✗ Failed: 'ductal carcinoma in situ with comedo nec...' - index out of range in

In [18]:
!pip install -q faiss-cpu

In [19]:
# Cell 8 — Build FAISS index over text embeddings (inner product / cosine)
import faiss

text_embeddings = np.ascontiguousarray(text_embeddings.astype("float32"))
dim = text_embeddings.shape[1]

use_gpu = CONFIG.get("USE_FAISS_GPU", False)
gpu_index = None

if use_gpu:
    try:
        # Some FAISS installs need this import to register GPU helpers
        import faiss.contrib.torch_utils  # noqa: F401
        
        if hasattr(faiss, "StandardGpuResources"):
            n_gpus = faiss.get_num_gpus()
            if n_gpus > 0:
                print(f"Building FAISS GpuIndexFlatIP (dim={dim}) on {n_gpus} visible GPU(s)")
                res = faiss.StandardGpuResources()
                gpu_config = faiss.GpuIndexFlatConfig()
                gpu_config.device = torch.cuda.current_device()
                gpu_index = faiss.GpuIndexFlatIP(res, dim, gpu_config)
            else:
                print("FAISS GPU requested but no GPUs detected — falling back to CPU.")
        else:
            print("FAISS install has no GPU symbols — falling back to CPU.")
    except Exception as e:
        print(f"FAISS GPU initialization failed ({e}); falling back to CPU.")

if gpu_index is not None:
    index = gpu_index
else:
    print(f"Building FAISS IndexFlatIP on CPU (dim={dim})")
    index = faiss.IndexFlatIP(dim)

index.add(text_embeddings)
print(f"Index built with {index.ntotal} vectors on {'GPU' if gpu_index is not None else 'CPU'}")

FAISS install has no GPU symbols — falling back to CPU.
Building FAISS IndexFlatIP on CPU (dim=512)
Index built with 20 vectors on CPU


In [20]:
# Cell 9 — Online RAG fallback: PubMed and Wikipedia retrieval functions

def query_pubmed(query: str, email: str, max_results: int = 3) -> List[str]:
    """
    Query PubMed for relevant biomedical abstracts using the NCBI E-utilities API.
    
    Args:
        query: Search query string
        email: Email for NCBI API (required)
        max_results: Maximum number of abstracts to retrieve
    
    Returns:
        List of abstract texts
    """
    try:
        from Bio import Entrez
        Entrez.email = email
        
        # Search PubMed
        search_handle = Entrez.esearch(db="pubmed", term=query, retmax=max_results)
        search_results = Entrez.read(search_handle)
        search_handle.close()
        
        id_list = search_results["IdList"]
        if not id_list:
            return []
        
        # Fetch abstracts
        fetch_handle = Entrez.efetch(db="pubmed", id=id_list, rettype="abstract", retmode="text")
        abstracts_text = fetch_handle.read()
        fetch_handle.close()
        
        # Split into individual abstracts (basic parsing)
        abstracts = [abs.strip() for abs in abstracts_text.split("\n\n") if len(abs.strip()) > 50]
        return abstracts[:max_results]
    
    except Exception as e:
        print(f"PubMed query failed: {e}")
        return []

def query_pubmed_simple(query: str, email: str, max_results: int = 3) -> List[str]:
    """
    Fallback PubMed query using direct HTTP requests (no biopython dependency).
    """
    try:
        # Search for PMIDs
        search_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
        search_params = {
            "db": "pubmed",
            "term": query,
            "retmax": max_results,
            "retmode": "json",
            "email": email
        }
        search_response = requests.get(search_url, params=search_params, timeout=10)
        search_data = search_response.json()
        
        pmids = search_data.get("esearchresult", {}).get("idlist", [])
        if not pmids:
            return []
        
        # Fetch abstracts
        fetch_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
        fetch_params = {
            "db": "pubmed",
            "id": ",".join(pmids),
            "retmode": "text",
            "rettype": "abstract"
        }
        fetch_response = requests.get(fetch_url, params=fetch_params, timeout=10)
        abstracts_text = fetch_response.text
        
        # Basic parsing
        abstracts = [abs.strip() for abs in abstracts_text.split("\n\n") if len(abs.strip()) > 50]
        return abstracts[:max_results]
    
    except Exception as e:
        print(f"PubMed simple query failed: {e}")
        return []

def query_wikipedia(query: str, max_results: int = 2) -> List[str]:
    """
    Query Wikipedia for relevant medical/biomedical content.
    
    Args:
        query: Search query string
        max_results: Maximum number of summaries to retrieve
    
    Returns:
        List of Wikipedia article summaries
    """
    try:
        search_url = "https://en.wikipedia.org/w/api.php"
        search_params = {
            "action": "opensearch",
            "search": query,
            "limit": max_results,
            "namespace": 0,
            "format": "json"
        }
        response = requests.get(search_url, params=search_params, timeout=10)
        data = response.json()
        
        titles = data[1] if len(data) > 1 else []
        summaries = []
        
        for title in titles[:max_results]:
            summary_params = {
                "action": "query",
                "prop": "extracts",
                "exintro": True,
                "explaintext": True,
                "titles": title,
                "format": "json"
            }
            sum_response = requests.get(search_url, params=summary_params, timeout=10)
            sum_data = sum_response.json()
            
            pages = sum_data.get("query", {}).get("pages", {})
            for page_id, page_data in pages.items():
                extract = page_data.get("extract", "")
                if extract:
                    summaries.append(extract[:500])  # Limit length
        
        return summaries
    
    except Exception as e:
        print(f"Wikipedia query failed: {e}")
        return []

def online_rag_fallback(query: str, source: str = "pubmed", config: dict = None) -> Tuple[List[str], str]:
    """
    Perform online RAG retrieval when local corpus is insufficient.
    
    Args:
        query: Search query (typically based on top retrieved texts or image features)
        source: "pubmed" or "wikipedia"
        config: Configuration dictionary
    
    Returns:
        Tuple of (retrieved_texts, source_attribution)
    """
    if not config:
        config = CONFIG
    
    retrieved = []
    attribution = "online_fallback_failed"
    
    if source == "pubmed":
        # Try biopython first, fallback to simple HTTP
        try:
            retrieved = query_pubmed(query, config["PUBMED_EMAIL"], config["PUBMED_MAX_RESULTS"])
        except:
            retrieved = query_pubmed_simple(query, config["PUBMED_EMAIL"], config["PUBMED_MAX_RESULTS"])
        attribution = "pubmed_online" if retrieved else "pubmed_failed"
    
    elif source == "wikipedia":
        retrieved = query_wikipedia(query, max_results=2)
        attribution = "wikipedia_online" if retrieved else "wikipedia_failed"
    
    return retrieved, attribution

print("✓ Online RAG fallback functions loaded")

✓ Online RAG fallback functions loaded


In [21]:
# Cell 10 — Batched retrieval with optional online RAG fallback
CONFIG.setdefault("RETRIEVAL_BATCH", 512)           # Tune for your GPU/CPU memory
CONFIG.setdefault("ONLINE_DEBUG_LIMIT", 10)          # Max fallback debug lines to print

print("=" * 60)
print("RETRIEVAL WITH ONLINE FALLBACK (BATCHED)")
print("=" * 60)

num_images = len(image_embeddings)
batch_size = CONFIG["RETRIEVAL_BATCH"]
top_k = CONFIG["TOP_K"]
similarity_threshold = CONFIG["SIMILARITY_THRESHOLD"]

print(f"\n📊 Processing {num_images} images in batches of {batch_size}")
print(f"   Local corpus size: {len(prompt_corpus)}")
print(f"   Similarity threshold for online fallback: {similarity_threshold}")
print(f"   Online fallback enabled: {CONFIG.get('ONLINE_RAG_ENABLED', False)}")

retrieval_results = []
failed_retrievals = []
online_fallback_count = 0
fallback_debug_limit = CONFIG["ONLINE_DEBUG_LIMIT"]
fallback_debug_shown = 0

for start in tqdm(range(0, num_images, batch_size), desc="Retrieving"):
    end = min(start + batch_size, num_images)
    batch_embs = image_embeddings[start:end].astype("float32")

    # Local FAISS search for the whole chunk
    sims, idxs = index.search(batch_embs, top_k)

    for local_idx, img_idx in enumerate(range(start, end)):
        img_path = image_paths[img_idx]
        top_similarities = sims[local_idx]
        top_indices = idxs[local_idx]

        top_k_prompts = []
        for rank, (idx, sim) in enumerate(zip(top_indices, top_similarities), start=1):
            if 0 <= idx < len(prompt_corpus):
                top_k_prompts.append(
                    {
                        "rank": rank,
                        "text": prompt_corpus[idx],
                        "similarity": float(sim),
                        "source": "local_corpus",
                    }
                )

        max_similarity = float(top_similarities[0]) if len(top_similarities) else 0.0
        used_online = False

        if (
            CONFIG.get("ONLINE_RAG_ENABLED", False)
            and max_similarity < similarity_threshold
        ):
            query_text = top_k_prompts[0]["text"] if top_k_prompts else "medical imaging"
            online_docs, attribution = online_rag_fallback(
                query_text,
                source=CONFIG.get("ONLINE_SOURCE", "pubmed"),
                config=CONFIG,
            )
            if online_docs:
                used_online = True
                online_fallback_count += 1
                for rank, doc in enumerate(online_docs, start=len(top_k_prompts) + 1):
                    top_k_prompts.append(
                        {
                            "rank": rank,
                            "text": doc,
                            "similarity": None,
                            "source": attribution,
                        }
                    )
                if (
                    max_similarity > 0
                    and fallback_debug_shown < fallback_debug_limit
                ):
                    print(
                        f"\n⚠️  {Path(img_path).name}: similarity "
                        f"{max_similarity:.3f} -> added {len(online_docs)} "
                        f"{attribution} docs"
                    )
                    fallback_debug_shown += 1

        retrieval_results.append(
            {
                "image_path": img_path,
                "image_name": Path(img_path).name,
                "max_similarity": max_similarity,
                "used_online_rag": used_online,
                "top_k_prompts": top_k_prompts,
                "num_prompts": len(top_k_prompts),
            }
        )

print("\n" + "=" * 60)
print("RETRIEVAL SUMMARY")
print("=" * 60)
print(f"Total images processed: {len(retrieval_results)}")
print(f"Images with prompts: {sum(r['num_prompts'] > 0 for r in retrieval_results)}")
print(
    f"Online fallback used: {online_fallback_count} "
    f"({100 * online_fallback_count / num_images:.1f}%)"
)

valid_results = [r for r in retrieval_results if r["num_prompts"] > 0]
if valid_results:
    similarities = [r["max_similarity"] for r in valid_results]
    print("\nSimilarity statistics for images with prompts:")
    print(f"  Mean:   {np.mean(similarities):.3f}")
    print(f"  Median: {np.median(similarities):.3f}")
    print(f"  Min:    {np.min(similarities):.3f}")
    print(f"  Max:    {np.max(similarities):.3f}")

prompt_counts = [r["num_prompts"] for r in retrieval_results]
print("\nPrompt count statistics:")
print(f"  Mean:   {np.mean(prompt_counts):.1f}")
print(f"  Median: {np.median(prompt_counts):.1f}")
print(f"  Min:    {np.min(prompt_counts)}")
print(f"  Max:    {np.max(prompt_counts)}")

print("\n" + "=" * 60)
print("SAMPLE RESULTS (first 3)")
print("=" * 60)
for result in retrieval_results[:3]:
    print(f"\nImage: {result['image_name']}")
    print(f"  Max similarity: {result['max_similarity']:.3f}")
    print(f"  Num prompts: {result['num_prompts']}")
    print(f"  Used online fallback: {result['used_online_rag']}")
    if result["num_prompts"]:
        print(f"  Top prompt: {result['top_k_prompts'][0]['text'][:120]}...")
    else:
        print("  ⚠️  No prompts retrieved")

print("\n" + "=" * 60)
print("SAVING RESULTS")
print("=" * 60)

results_jsonl = OUT_DIR / "retrieval_results.jsonl"
results_csv = OUT_DIR / "retrieval_summary.csv"

try:
    with open(results_jsonl, "w") as f:
        for entry in retrieval_results:
            f.write(json.dumps(entry) + "\n")
    print(f"✓ Saved detailed JSONL to {results_jsonl}")

    df_summary = pd.DataFrame(
        {
            "image_name": [r["image_name"] for r in retrieval_results],
            "max_similarity": [r["max_similarity"] for r in retrieval_results],
            "num_prompts": [r["num_prompts"] for r in retrieval_results],
            "used_online_rag": [r["used_online_rag"] for r in retrieval_results],
            "has_error": ["error" in r for r in retrieval_results],
        }
    )
    df_summary.to_csv(results_csv, index=False)
    print(f"✓ Saved CSV summary to {results_csv}")
    print("\nCSV preview:")
    display(df_summary.head())
except Exception as e:
    print(f"❌ Error saving results: {e}")

print("\n" + "=" * 60)
print("RETRIEVAL PIPELINE COMPLETE")
print("=" * 60)

RETRIEVAL WITH ONLINE FALLBACK (BATCHED)

📊 Processing 15818 images in batches of 512
   Local corpus size: 20
   Similarity threshold for online fallback: 0.3
   Online fallback enabled: True


Retrieving:   0%|          | 0/31 [00:00<?, ?it/s]

KeyError: 'PUBMED_EMAIL'

In [None]:
# Cell 11 — LLM Synthesis: generate grounded captions with explainability and uncertainty handling

USE_OPENAI = os.environ.get("OPENAI_API_KEY") is not None

if USE_OPENAI:
    print("OpenAI API key detected; will use OpenAI for synthesis.")
else:
    print("No OpenAI key detected — using local transformers LLM for synthesis.")

# load local LLM model (transformers pipeline)
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

LLM_ID = CONFIG["LLM_ID"]
try:
    print(f"Loading LLM for local synthesis: {LLM_ID}")
    llm_tokenizer = AutoTokenizer.from_pretrained(LLM_ID)
    llm_model = AutoModelForSeq2SeqLM.from_pretrained(LLM_ID).to(device)
    llm_pipe = pipeline("text2text-generation", model=llm_model, tokenizer=llm_tokenizer, 
                        device=0 if device == "cuda" else -1)
    print("✓ Local LLM loaded.")
except Exception as e:
    print(f"Failed to load local LLM: {e}")
    print("Falling back to CPU inference...")
    try:
        llm_tokenizer = AutoTokenizer.from_pretrained(LLM_ID)
        llm_model = AutoModelForSeq2SeqLM.from_pretrained(LLM_ID).to("cpu")
        llm_pipe = pipeline("text2text-generation", model=llm_model, tokenizer=llm_tokenizer, device=-1)
    except Exception as e2:
        raise RuntimeError(f"Failed to initialize LLM: {e2}")

# Enhanced synthesis function with explainability
def synthesize_caption_with_explainability(
    retrieved_texts: List[str],
    img_path: str,
    max_similarity: float,
    source_attribution: str,
    llm_pipeline,
    min_confidence: float = 0.25,
    max_length: int = 100
) -> Dict[str, str]:
    """
    Generate a grounded caption with explicit uncertainty handling.
    
    Returns a dict with:
        - caption: The generated description
        - confidence_level: "high", "medium", "low", or "uncertain"
        - explanation: Why this confidence level was assigned
        - sources: Where information came from
    """
    
    # Determine confidence and adjust prompt accordingly
    if max_similarity < min_confidence:
        confidence_level = "uncertain"
        explanation = f"Low similarity score ({max_similarity:.3f}) indicates uncertain match with available knowledge"
        caption_prefix = "[UNCERTAIN] "
        instruction_modifier = "The model has LOW CONFIDENCE. State uncertainty clearly and describe only what can be minimally inferred."
    elif max_similarity < CONFIG["SIMILARITY_THRESHOLD"]:
        confidence_level = "low"
        explanation = f"Moderate similarity score ({max_similarity:.3f}) suggests limited match with corpus"
        caption_prefix = "[LOW CONFIDENCE] "
        instruction_modifier = "The model has MODERATE CONFIDENCE. Describe observable features cautiously."
    elif max_similarity < 0.5:
        confidence_level = "medium"
        explanation = f"Good similarity score ({max_similarity:.3f}) indicates reasonable match"
        caption_prefix = ""
        instruction_modifier = "The model has GOOD CONFIDENCE. Describe observable features."
    else:
        confidence_level = "high"
        explanation = f"High similarity score ({max_similarity:.3f}) indicates strong match with corpus"
        caption_prefix = ""
        instruction_modifier = "The model has HIGH CONFIDENCE. Describe observable features comprehensively."
    
    # Compose prompt with explainability
    context = "\n---\n".join(retrieved_texts[:5])  # Limit context size
    
    prompt = (
        f"You are a medical image analysis assistant. {instruction_modifier}\n\n"
        f"Image: {Path(img_path).name}\n"
        f"Reference texts (source: {source_attribution}):\n{context}\n\n"
        f"Task: Provide a concise, factual description (1-2 sentences) of observable features.\n"
        f"- Do NOT diagnose or make clinical recommendations\n"
        f"- Do NOT hallucinate features not supported by references\n"
        f"- If uncertain or if image quality is poor, explicitly state: 'Unable to determine specific features' or 'Image quality insufficient'\n"
        f"- Focus on: tissue structure, cellular morphology, staining characteristics\n\n"
        f"Description:"
    )
    
    try:
        out = llm_pipeline(prompt, max_length=max_length, do_sample=False, truncation=True)[0]["generated_text"]
        caption = out.strip()
        
        # Add uncertainty prefix if needed
        if caption and not caption.lower().startswith(("unable", "insufficient", "cannot", "unclear")):
            caption = caption_prefix + caption
        elif not caption:
            caption = "[UNCERTAIN] Unable to generate description from available information."
        
    except Exception as e:
        print(f"LLM generation error for {img_path}: {e}")
        caption = "[ERROR] Caption generation failed."
        confidence_level = "uncertain"
        explanation = f"LLM generation error: {str(e)}"
    
    return {
        "caption": caption,
        "confidence_level": confidence_level,
        "explanation": explanation,
        "sources": source_attribution,
        "max_similarity_score": float(max_similarity)
    }

# Run synthesis with explainability
SYNTH_JSON = OUT_DIR / "synthesized_captions.jsonl"
use_sample = False
max_images = 500  # Adjust based on compute
num_images = min(len(image_paths), max_images) if use_sample else len(image_paths)
print(f"Will synthesize captions for {num_images} images (use_sample={use_sample}).")

# Load retrieval results
retrieval_results = []
with open(RESULTS_JSON, "r") as f:
    for line in f:
        retrieval_results.append(json.loads(line))

with open(SYNTH_JSON, "w") as outf:
    for i in tqdm(range(num_images), desc="synthesize-captions"):
        ret_obj = retrieval_results[i]
        imgp = ret_obj["image_path"]
        retrieved = ret_obj["top_k_prompts"]
        max_sim = ret_obj["max_similarity"]
        source_attr = ret_obj["source_attribution"]
        
        # Generate caption with explainability
        result = synthesize_caption_with_explainability(
            retrieved_texts=retrieved,
            img_path=imgp,
            max_similarity=max_sim,
            source_attribution=source_attr,
            llm_pipeline=llm_pipe,
            min_confidence=CONFIG["MIN_CONFIDENCE_THRESHOLD"]
        )
        
        obj = {
            "image_path": imgp,
            "generated_caption": result["caption"],
            "confidence_level": result["confidence_level"],
            "explanation": result["explanation"],
            "sources": result["sources"],
            "max_similarity_score": result["max_similarity_score"],
            "retrieved_prompts": retrieved,
            "retrieved_scores": ret_obj["top_k_scores"]
        }
        outf.write(json.dumps(obj) + "\n")

print(f"✓ Saved synthesized captions to {SYNTH_JSON}")

In [None]:
# Cell 12 — Re-score generated caption against image using CLIP similarity
print("Computing image-to-generated-caption similarity for synthesized captions...")
res_rows = []
with open(SYNTH_JSON, "r") as f:
    for line in tqdm(f, desc="rescore-captions"):
        obj = json.loads(line)
        cap = obj.get("generated_caption", "")
        if not cap:
            sim = float('nan')
        else:
            try:
                tok = tokenize([cap]).to(device)
                with torch.no_grad():
                    tf = model.encode_text(tok)
                    tf = tf / tf.norm(dim=-1, keepdim=True)
                    # get image embedding (reload from cache per image index)
                    img_idx = image_paths.index(obj["image_path"])
                    img_vec = torch.from_numpy(image_embeddings[img_idx]).unsqueeze(0).to(device)
                    sim_t = (img_vec @ tf.T).item()
                    sim = float(sim_t)
            except Exception as e:
                print(f"Rescore failed for {obj['image_path']}: {e}")
                sim = float('nan')
        
        res_rows.append({
            "image_path": obj["image_path"],
            "generated_caption": cap,
            "gen_clip_similarity": sim,
            "confidence_level": obj.get("confidence_level", "unknown"),
            "sources": obj.get("sources", "unknown"),
            "max_similarity_score": obj.get("max_similarity_score", 0.0)
        })

pd.DataFrame(res_rows).to_csv(OUT_DIR / "synthesized_captions_rescored.csv", index=False)
print(f"✓ Saved rescoring CSV with {len(res_rows)} entries")

In [None]:
# Cell 13 — Visualizations: histogram, sample gallery, heatmap, TSNE, confidence distribution
print("Creating enhanced visualizations...")

# Load synthesized results
with open(SYNTH_JSON, "r") as f:
    synth_all = [json.loads(l) for l in f]

# 1. Histogram of top1 retrieval scores with confidence threshold
max_scores = np.array([r['max_similarity_score'] for r in synth_all], dtype=np.float32)
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.hist(max_scores, bins=40, alpha=0.7, edgecolor='black')
plt.axvline(CONFIG["SIMILARITY_THRESHOLD"], color='red', linestyle='--', 
            label=f'Confidence threshold ({CONFIG["SIMILARITY_THRESHOLD"]})')
plt.axvline(CONFIG["MIN_CONFIDENCE_THRESHOLD"], color='orange', linestyle='--',
            label=f'Min threshold ({CONFIG["MIN_CONFIDENCE_THRESHOLD"]})')
plt.title('Distribution of Top-1 Retrieval Similarity Scores')
plt.xlabel('Cosine Similarity')
plt.ylabel('Count')
plt.legend()

# 2. Confidence level distribution
plt.subplot(1, 2, 2)
confidence_counts = pd.Series([r['confidence_level'] for r in synth_all]).value_counts()
confidence_counts.plot(kind='bar', color=['green', 'yellow', 'orange', 'red'])
plt.title('Caption Confidence Level Distribution')
plt.xlabel('Confidence Level')
plt.ylabel('Count')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

# 3. Sample gallery with generated captions (color-coded by confidence)
sample_n = min(24, len(synth_all))
idxs = list(range(sample_n))
cols = 6
rows = math.ceil(len(idxs) / cols)
plt.figure(figsize=(cols * 2.5, rows * 2.5))

confidence_colors = {
    'high': 'green',
    'medium': 'blue',
    'low': 'orange',
    'uncertain': 'red'
}

for i, j in enumerate(idxs):
    r = synth_all[j]
    try:
        img = Image.open(r['image_path']).convert('RGB').resize((224, 224))
    except Exception:
        continue
    
    ax = plt.subplot(rows, cols, i + 1)
    ax.imshow(img)
    ax.axis('off')
    
    caption = r['generated_caption']
    conf = r.get('confidence_level', 'unknown')
    title = (caption[:60] + '...') if len(caption) > 60 else caption
    
    ax.set_title(title, fontsize=6, color=confidence_colors.get(conf, 'black'))
    
    # Add confidence badge
    ax.text(0.05, 0.95, conf.upper(), transform=ax.transAxes,
            fontsize=6, verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor=confidence_colors.get(conf, 'gray'), alpha=0.7))

plt.suptitle('Sample Images with Generated Captions (Color-coded by Confidence)', fontsize=10)
plt.tight_layout()
plt.show()

# 4. Similarity heatmap
subset = min(32, image_embeddings.shape[0])
sub_sims = np.matmul(image_embeddings[:subset], text_embeddings.T)
plt.figure(figsize=(10, 6))
plt.imshow(sub_sims, aspect='auto', cmap='viridis')
plt.colorbar(label='Cosine Similarity')
plt.title('Similarity Heatmap (First 32 Images × Corpus Prompts)')
plt.xlabel('Prompt Index')
plt.ylabel('Image Index')
plt.tight_layout()
plt.show()

# 5. TSNE with confidence coloring
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

sample_n = min(200, image_embeddings.shape[0])
idx_sample = np.random.RandomState(0).choice(image_embeddings.shape[0], size=sample_n, replace=False)
emb_sub = image_embeddings[idx_sample]

pca = PCA(n_components=50, random_state=0).fit_transform(emb_sub)
tsne = TSNE(n_components=2, perplexity=30, random_state=0, init='pca', n_iter=800)
proj = tsne.fit_transform(pca)

# Get confidence levels for sampled images
conf_levels = [synth_all[i]['confidence_level'] if i < len(synth_all) else 'unknown' 
               for i in idx_sample]

plt.figure(figsize=(10, 7))
for conf, color in confidence_colors.items():
    mask = np.array([c == conf for c in conf_levels])
    if mask.any():
        pts = proj[mask]
        plt.scatter(pts[:, 0], pts[:, 1], s=20, label=conf, color=color, alpha=0.6)

plt.legend(title='Confidence Level')
plt.title('t-SNE of Image Embeddings (Colored by Caption Confidence)')
plt.xlabel('t-SNE Dimension 1')
plt.ylabel('t-SNE Dimension 2')
plt.tight_layout()
plt.show()

print("✓ All visualizations generated")

In [None]:
# Cell 14 — Save comprehensive run metadata and pipeline summary
import time

# Calculate statistics
with open(SYNTH_JSON, "r") as f:
    synth_data = [json.loads(l) for l in f]

confidence_dist = pd.Series([r['confidence_level'] for r in synth_data]).value_counts().to_dict()
source_dist = pd.Series([r['sources'] for r in synth_data]).value_counts().to_dict()
online_fallback_used = sum(1 for r in synth_data if 'online' in r.get('sources', ''))

meta = {
    "pipeline_info": {
        "version": "2.0_enhanced",
        "date": time.asctime(),
        "description": "Data-agnostic VLM pipeline with RAG, online fallback, and explainability"
    },
    "dataset": {
        "dataset_dir": str(DATASET_DIR),
        "n_images": image_embeddings.shape[0],
        "n_processed": len(synth_data)
    },
    "corpus": {
        "corpus_path": str(CORPUS_PATH) if CORPUS_PATH else "built-in",
        "n_prompts": len(prompt_corpus)
    },
    "models": {
        "vision_encoder": CONFIG['BIOMEDCLIP_HF_ID'],
        "llm": CONFIG['LLM_ID'],
        "device": CONFIG['DEVICE']
    },
    "rag_config": {
        "top_k": TOP_K,
        "similarity_threshold": CONFIG["SIMILARITY_THRESHOLD"],
        "min_confidence_threshold": CONFIG["MIN_CONFIDENCE_THRESHOLD"],
        "online_rag_enabled": ONLINE_RAG_ENABLED,
        "online_source": CONFIG["ONLINE_SOURCE"]
    },
    "statistics": {
        "confidence_distribution": confidence_dist,
        "source_distribution": source_dist,
        "online_fallback_used": online_fallback_used,
        "online_fallback_percentage": f"{100 * online_fallback_used / len(synth_data):.2f}%"
    },
    "output_files": {
        "image_embeddings": str(IMAGE_EMB_FILE),
        "text_embeddings": str(TEXT_EMB_FILE),
        "retrieval_results": str(RESULTS_JSON),
        "synthesized_captions": str(SYNTH_JSON),
        "rescored_csv": str(OUT_DIR / "synthesized_captions_rescored.csv")
    }
}

with open(OUT_DIR / "run_metadata.json", "w") as f:
    json.dump(meta, f, indent=2)

print("=" * 80)
print("✓ PIPELINE COMPLETED SUCCESSFULLY")
print("=" * 80)
print(f"\nOutputs saved to: {OUT_DIR}")
print(f"\nKey Statistics:")
print(f"  - Images processed: {meta['dataset']['n_processed']}")
print(f"  - Corpus size: {meta['corpus']['n_prompts']}")
print(f"  - Online fallback used: {meta['statistics']['online_fallback_percentage']}")
print(f"\nConfidence Distribution:")
for conf, count in meta['statistics']['confidence_distribution'].items():
    print(f"  - {conf}: {count} ({100*count/len(synth_data):.1f}%)")
print(f"\nSource Distribution:")
for source, count in meta['statistics']['source_distribution'].items():
    print(f"  - {source}: {count}")
print("\n" + "=" * 80)
print("NOTES:")
print("  - This pipeline is data-agnostic and works with any medical image dataset")
print("  - Online RAG provides fallback when local corpus lacks relevant information")
print("  - Explainability features ensure transparency in caption generation")
print("  - Low confidence outputs are explicitly marked for human review")
print("  - Always validate outputs with domain experts before clinical use")
print("=" * 80)