In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/breast-cancer-dataset-from-breakhis/fold1/fold1/test/200X/B_200X/SOB_B_PT-14-21998AB-200-014.png
/kaggle/input/breast-cancer-dataset-from-breakhis/fold1/fold1/test/200X/B_200X/SOB_B_F-14-23222AB-200-007.png
/kaggle/input/breast-cancer-dataset-from-breakhis/fold1/fold1/test/200X/B_200X/SOB_B_A-14-29960CD-200-012.png
/kaggle/input/breast-cancer-dataset-from-breakhis/fold1/fold1/test/200X/B_200X/SOB_B_TA-14-16184-200-002.png
/kaggle/input/breast-cancer-dataset-from-breakhis/fold1/fold1/test/200X/B_200X/SOB_B_PT-14-21998AB-200-021.png
/kaggle/input/breast-cancer-dataset-from-breakhis/fold1/fold1/test/200X/B_200X/SOB_B_TA-14-16184-200-018.png
/kaggle/input/breast-cancer-dataset-from-breakhis/fold1/fold1/test/200X/B_200X/SOB_B_F-14-21998CD-200-019.png
/kaggle/input/breast-cancer-dataset-from-breakhis/fold1/fold1/test/200X/B_200X/SOB_B_TA-14-16184-200-006.png
/kaggle/input/breast-cancer-dataset-from-breakhis/fold1/fold1/test/200X/B_200X/SOB_B_PT-14-21998AB-200-023.png
/kaggle/in

In [15]:
# ==========================================================
# Cell 0 — Install dependencies (run once)
# ==========================================================
try:
    import open_clip
except Exception:
    import sys
    # Install specific compatible versions
    !pip install -q --upgrade open-clip-torch==2.23.0 faiss-cpu transformers sentence-transformers tqdm matplotlib scikit-learn


In [16]:
# Cell 1 — Imports, basic setup, and utilities
import os
import json
import time
import math
import random
from pathlib import Path
from typing import List


from PIL import Image
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt


import torch
import open_clip
from open_clip import tokenize


# small utilities
def safe_makedir(p):
    Path(p).mkdir(parents=True, exist_ok=True)


# seed for reproducibility (best-effort)
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7c5020195010>

In [17]:

CONFIG = {
    # Set your BreakHis dataset root here (Kaggle input path). Example: "/kaggle/input/breakhis".
    "BREAKHIS_DIR": "/kaggle/input/breast-cancer-dataset-from-breakhis/",
        
    # Optional: path to a corpus (txt file) or a folder with .txt files; if None, a small sample corpus is used.
    "CORPUS_PATH": None, # e.g., "/kaggle/input/biomed-corpus/corpus.txt"
    
    # Output directory (local to notebook workspace)
    "OUT_DIR": "./breakhis_rag_outputs",
    
    # Model choices
    "BIOMEDCLIP_HF_ID": "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224",
    # Lightweight LLM for synthesis 
    "LLM_ID": "google/flan-t5-small", # small & practical for Kaggle; swap if you have more GPU
    
    # Runtime & behavior
    "DEVICE": "cuda" if torch.cuda.is_available() else "cpu",
    "BATCH_IMAGE": 8,
    "BATCH_TEXT": 128,
    "TOP_K": 5,
    }

# Ensure output dir exists
OUT_DIR = Path(CONFIG["OUT_DIR"])
safe_makedir(OUT_DIR)

print("CONFIG summary:")
for k,v in CONFIG.items():
    print(f" {k}: {v}")

CONFIG summary:
 BREAKHIS_DIR: /kaggle/input/breast-cancer-dataset-from-breakhis/
 CORPUS_PATH: None
 OUT_DIR: ./breakhis_rag_outputs
 BIOMEDCLIP_HF_ID: hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224
 LLM_ID: google/flan-t5-small
 DEVICE: cuda
 BATCH_IMAGE: 8
 BATCH_TEXT: 128
 TOP_K: 5


In [18]:
# Cell 3 — Load BiomedCLIP (OpenCLIP) with safe fallbacks
device = CONFIG["DEVICE"]
print("Using device:", device)


# Helper to load model and preprocess
def load_biomedclip(hf_id: str, device: str = "cpu"):
    """Try HF-aware loader first; fallback to built-in open_clip if needed.
    Returns: model, preprocess_callable
    """
    try:
        # HF-aware convenience function
        model, preprocess = open_clip.create_model_from_pretrained(hf_id, device=device)
        # create_model_from_pretrained may return model on CPU by default; move to device if GPU requested
        model = model.to(device).eval()
        print("Loaded BiomedCLIP via create_model_from_pretrained()")
        return model, preprocess
    except Exception as e:
        print("create_model_from_pretrained failed:", e)
        print("Attempting fallback: create_model_and_transforms with a built-in config (weights may not be BiomedCLIP)")
        try:
            # Fallback loads a generic ViT model (weights from laion or similar) — not ideal but keeps pipeline runnable
            model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms("ViT-B-16", pretrained="openai")
            model = model.to(device).eval()
            print("Loaded fallback ViT-B-16 (not BiomedCLIP weights). Consider fixing open-clip version or network access.")
            return model, preprocess_val
        except Exception as e2:
            raise RuntimeError(f"Failed to load model via open_clip: primary error: {e}; fallback error: {e2}")


# Load model
try:
    model, preprocess = load_biomedclip(CONFIG["BIOMEDCLIP_HF_ID"], device=device)
except Exception as e:
    # If model load fails fatally, provide a helpful message and abort pipeline
    raise RuntimeError(f"BiomedCLIP load failed: {e}. Restart kernel and ensure open-clip-torch >= 2.23.0 is installed and internet access is available.")


# Ensure model has encode_image and encode_text
if not hasattr(model, "encode_image") or not hasattr(model, "encode_text"):
    raise RuntimeError("Loaded model does not expose encode_image/encode_text APIs expected by the pipeline.")

print("Model loaded; preprocess callable available:", callable(preprocess))

Using device: cuda
Loaded BiomedCLIP via create_model_from_pretrained()
Model loaded; preprocess callable available: True


In [21]:
# Cell 4 — Collect BreakHis image paths (flexible)
BREAKHIS_DIR = Path(CONFIG["BREAKHIS_DIR"]).expanduser()
if not BREAKHIS_DIR.exists():
    raise FileNotFoundError(f"BREAKHIS_DIR not found at {BREAKHIS_DIR}. Update CONFIG['BREAKHIS_DIR'] to the correct mount path.")


# gather images recursively
def collect_images(root: Path) -> List[str]:
    exts = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"}
    files = [str(p) for p in root.rglob("*") if p.suffix.lower() in exts]
    files = sorted(files)
    return files


image_paths = collect_images(BREAKHIS_DIR)
if len(image_paths) == 0:
    raise RuntimeError(f"No image files found under {BREAKHIS_DIR}. Check dataset structure.")


print(f"Found {len(image_paths)} images under {BREAKHIS_DIR}")

Found 15818 images under /kaggle/input/breast-cancer-dataset-from-breakhis


In [22]:
# Cell 5 — Compute (or load cached) image embeddings
IMAGE_EMB_FILE = OUT_DIR / "image_embeddings.npy"
IMAGE_PATHS_FILE = OUT_DIR / "image_paths.json"


if IMAGE_EMB_FILE.exists() and IMAGE_PATHS_FILE.exists():
    print("Loading cached image embeddings...")
    image_embeddings = np.load(str(IMAGE_EMB_FILE))
    with open(IMAGE_PATHS_FILE, "r") as f:
        saved_paths = json.load(f)
    # Basic sanity check
    if len(saved_paths) != image_embeddings.shape[0]:
        print("Warning: cached embeddings length mismatch; recomputing embeddings")
        compute_images = True
    else:
        compute_images = False
else:
    compute_images = True


if compute_images:
    print("(Re)computing image embeddings. This may take time depending on dataset size and device.")
    batch = CONFIG["BATCH_IMAGE"] if device=="cuda" else 1 # smaller CPU batch
    emb_list = []
    paths_for_cache = []
    model.eval()
    from torch.utils.data import DataLoader

    for i in tqdm(range(0, len(image_paths), batch), desc="embed-images"):
        batch_paths = image_paths[i:i+batch]
        imgs = []
        for p in batch_paths:
            try:
                im = Image.open(p).convert("RGB")
                img_t = preprocess(im)
                imgs.append(img_t)
                paths_for_cache.append(p)
            except Exception as e:
                print(f"Warning: failed to open image {p}: {e}")
        if len(imgs)==0:
            continue
        tensor = torch.stack(imgs).to(device)
        with torch.no_grad():
            feats = model.encode_image(tensor)
            feats = feats / feats.norm(dim=-1, keepdim=True)
        emb_list.append(feats.cpu().numpy())
    if len(emb_list)==0:
        raise RuntimeError("No embeddings computed; check image reading step.")
    image_embeddings = np.concatenate(emb_list, axis=0).astype("float32")
    np.save(IMAGE_EMB_FILE, image_embeddings)
    with open(IMAGE_PATHS_FILE, "w") as f:
        json.dump(paths_for_cache, f)
    print("Saved image embeddings to", IMAGE_EMB_FILE)


# sanity
print("Image embeddings shape:", image_embeddings.shape)

Loading cached image embeddings...
Image embeddings shape: (15818, 512)


In [23]:
# Cell 6 — Prepare text corpus: read user corpus or use small builtin sample
CORPUS_PATH = CONFIG["CORPUS_PATH"]

def load_corpus(corpus_path):
    corpus = []
    p = Path(corpus_path)
    if p.is_file():
        # assume one sentence per line
        with open(p, "r", encoding="utf-8") as f:
            for line in f:
                t = line.strip()
                if t:
                    corpus.append(t)
    elif p.is_dir():
        # read all txt inside
        for txt in p.rglob("*.txt"):
            with open(txt, "r", encoding="utf-8") as f:
                for line in f:
                    t = line.strip()
                    if t:
                        corpus.append(t)
    else:
        raise FileNotFoundError(f"CORPUS_PATH {corpus_path} not found")
    # deduplicate
    seen = set(); out = []
    for s in corpus:
        if s not in seen:
            seen.add(s); out.append(s)
    return out

if CORPUS_PATH is None:
    # small built-in biomedical-style corpus (placeholder). Replace with a real corpus for production.
    prompt_corpus = [
        "H&E-stained breast tissue with tumor islands and pleomorphism",
        "benign breast histology with normal ducts",
        "histopathology slide showing mitotic figures and irregular nuclei",
        "melanocytic lesion with irregular nests",
        "necrosis and apoptotic bodies",
        "stromal fibrosis and inflammation",
        "high mitotic index and pleomorphism",
        "artifact / folding / staining artifact",
        "normal adipose tissue and connective stroma",
        "scattered inflammatory infiltrate"
    ]
    print("Using small built-in prompt corpus (set CONFIG['CORPUS_PATH'] to use a real corpus).")
else:
    print("Loading corpus from", CORPUS_PATH)
    prompt_corpus = load_corpus(CORPUS_PATH)

print("Corpus size:", len(prompt_corpus))


Using small built-in prompt corpus (set CONFIG['CORPUS_PATH'] to use a real corpus).
Corpus size: 10


In [24]:

# Cell 8 — Build FAISS index over text embeddings (inner product / cosine)
import faiss

# ensure contiguous
text_embeddings = np.ascontiguousarray(text_embeddings.astype('float32'))
D = text_embeddings.shape[1]
print("Building FAISS index (IndexFlatIP) with dim", D)
index = faiss.IndexFlatIP(D)
index.add(text_embeddings)
print("Index built. nlist:=", index.ntotal)

NameError: name 'text_embeddings' is not defined

In [None]:

# Cell 9 — Retrieval: query top-k prompts for every image (batched)
TOP_K = CONFIG["TOP_K"]
print("Retrieving top-{} prompts for each image...".format(TOP_K))

# prepare arrays
N = image_embeddings.shape[0]
all_top_idx = np.zeros((N, TOP_K), dtype=np.int32)
all_top_scores = np.zeros((N, TOP_K), dtype=np.float32)

# faiss search requires float32 contiguous
img_embs = np.ascontiguousarray(image_embeddings.astype('float32'))
B = 256
for i in tqdm(range(0, N, B), desc="faiss-query-batches"):
    q = img_embs[i:i+B]
    Dscores, I = index.search(q, TOP_K)
    all_top_idx[i:i+B] = I
    all_top_scores[i:i+B] = Dscores

# Build results objects and save to JSONL + CSV
print("Building result objects and saving outputs...")
RESULTS_JSON = OUT_DIR / "retrieval_results.jsonl"
rows = []
with open(RESULTS_JSON, "w") as outf:
    for i, img_path in enumerate(tqdm(image_paths, desc="write-results")):
        idxs = all_top_idx[i].tolist()
        scores = all_top_scores[i].tolist()
        prompts = [prompt_corpus[j] for j in idxs]
        obj = {
            "image_path": img_path,
            "top_k_prompts": prompts,
            "top_k_scores": scores
        }
        outf.write(json.dumps(obj) + "\n")
        rows.append({
            "image_path": img_path,
            "top1_prompt": prompts[0],
            "top1_score": scores[0]
        })
# Save CSV summary
pd.DataFrame(rows).to_csv(OUT_DIR/"retrieval_summary.csv", index=False)
print("Saved", RESULTS_JSON, "and retrieval_summary.csv")


In [None]:
# Cell 10 — LLM Synthesis: generate grounded captions from retrieved context
# Two modes:
#  - LOCAL LLM via transformers (default; uses flan-t5-small)
#  - EXTERNAL LLM via OpenAI API (if you set OPENAI_API_KEY env var and prefer that)

USE_OPENAI = os.environ.get("OPENAI_API_KEY") is not None

if USE_OPENAI:
    print("OpenAI API key detected; will use OpenAI for synthesis. Make sure 'openai' library is installed and API key set in env.")
else:
    print("No OpenAI key detected — using local transformers LLM (flan-t5-small) for synthesis by default.")

# load local LLM model (transformers pipeline)
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
LLM_ID = CONFIG["LLM_ID"]
try:
    print("Loading LLM for local synthesis:", LLM_ID)
    llm_tokenizer = AutoTokenizer.from_pretrained(LLM_ID)
    llm_model = AutoModelForSeq2SeqLM.from_pretrained(LLM_ID).to(device)
    llm_pipe = pipeline("text2text-generation", model=llm_model, tokenizer=llm_tokenizer, device=0 if device=="cuda" else -1)
    print("Local LLM loaded.")
except Exception as e:
    print("Failed to load local LLM (transformers). Error:", e)
    print("Falling back to a very small local generator (may be slow). Will try CPU inference.")
    try:
        llm_tokenizer = AutoTokenizer.from_pretrained(LLM_ID)
        llm_model = AutoModelForSeq2SeqLM.from_pretrained(LLM_ID)
        llm_model = llm_model.to("cpu")
        llm_pipe = pipeline("text2text-generation", model=llm_model, tokenizer=llm_tokenizer, device=-1)
    except Exception as e2:
        raise RuntimeError(f"Failed to initialize LLM for synthesis: {e}; fallback error: {e2}")

# Synthesis function: create a prompt with retrieved docs and ask LLM to summarize salient observable features
def synthesize_caption(retrieved_texts: List[str], img_path: str, llm_pipeline, max_length=80):
    # Compose a concise instruction that discourages hallucination and diagnostics
    context = "\n---\n".join(retrieved_texts)
    prompt = (
        "You are a clinical note assistant. Given the following biomedical text excerpts (references) and an image filepath,\n"
        "produce a concise, factual description (1-2 sentences) of observable histopathological features present in the image.\n"
        "Do NOT provide a diagnosis; only describe observable features. If uncertain, state uncertainty.\n\n"
        f"Image path: {img_path}\n\n"
        "References:\n" + context + "\n\nOutput:"
    )
    try:
        out = llm_pipeline(prompt, max_length=max_length, do_sample=False)[0]["generated_text"]
        return out.strip()
    except Exception as e:
        print("LLM generation error for image", img_path, ":", e)
        return ""

# Run synthesis for a sample subset or full dataset depending on compute
SYNTH_JSON = OUT_DIR / "synthesized_captions.jsonl"
use_sample = False
max_images = 500  # safe default; you can increase this if you have long runtime & GPU
num_images = min(len(image_paths), max_images) if use_sample else len(image_paths)
print(f"Will synthesize captions for {num_images} images (use_sample={use_sample}).")

with open(SYNTH_JSON, "w") as outf:
    for i in tqdm(range(num_images), desc="synthesize-captions"):
        idxs = all_top_idx[i].tolist()
        retrieved = [prompt_corpus[j] for j in idxs]
        imgp = image_paths[i]
        caption = synthesize_caption(retrieved, imgp, llm_pipe)
        obj = {
            "image_path": imgp,
            "generated_caption": caption,
            "retrieved_prompts": retrieved,
            "retrieved_scores": all_top_scores[i].tolist()
        }
        outf.write(json.dumps(obj) + "\n")

print("Saved synthesized captions to", SYNTH_JSON)


In [None]:

# Cell 11 — (Optional) Re-score generated caption against image using CLIP similarity
# This helps verify how well the generated caption aligns with image embedding
print("Computing image-to-generated-caption similarity for synthesized captions...")
res_rows = []
with open(SYNTH_JSON, "r") as f:
    for line in tqdm(f, desc="rescore-captions"):
        obj = json.loads(line)
        cap = obj.get("generated_caption", "")
        if not cap:
            sim = float('nan')
        else:
            try:
                tok = tokenize([cap]).to(device)
                with torch.no_grad():
                    tf = model.encode_text(tok)
                    tf = tf / tf.norm(dim=-1, keepdim=True)
                    # get image embedding (reload from cache per image index)
                    # find image index
                    img_idx = image_paths.index(obj["image_path"])  # O(N) but acceptable for moderate sizes
                    img_vec = torch.from_numpy(image_embeddings[img_idx]).unsqueeze(0).to(device)
                    sim_t = (img_vec @ tf.T).item()
                    sim = float(sim_t)
            except Exception as e:
                print("Rescore failed for", obj["image_path"], e)
                sim = float('nan')
        res_rows.append({
            "image_path": obj["image_path"],
            "generated_caption": cap,
            "gen_clip_similarity": sim
        })

pd.DataFrame(res_rows).to_csv(OUT_DIR/"synthesized_captions_rescored.csv", index=False)
print("Saved rescoring CSV.")


In [None]:

# Cell 12 — Visualizations: histogram, sample gallery, heatmap, TSNE
print("Creating visualizations (sampled).")

# Load synthesized results into memory (first N images for display)
with open(SYNTH_JSON, "r") as f:
    synth_all = [json.loads(l) for l in f]

# Histogram of top1 retrieval scores
max_scores = np.array([r['retrieved_scores'][0] for r in synth_all], dtype=np.float32)
plt.figure(figsize=(8,3))
plt.hist(max_scores, bins=40)
plt.title('Histogram of top-1 retrieval cosine similarity (image -> best prompt)')
plt.xlabel('cosine similarity')
plt.ylabel('count')
plt.show()

# Sample gallery with generated captions
sample_n = min(24, len(synth_all))
idxs = list(range(sample_n))
cols = 6
rows = math.ceil(len(idxs)/cols)
plt.figure(figsize=(cols*2.2, rows*2.2))
for i, j in enumerate(idxs):
    r = synth_all[j]
    try:
        img = Image.open(r['image_path']).convert('RGB').resize((224,224))
    except Exception:
        continue
    ax = plt.subplot(rows, cols, i+1)
    ax.imshow(img)
    ax.axis('off')
    title = (r['generated_caption'][:80] + '...') if len(r['generated_caption'])>80 else r['generated_caption']
    ax.set_title(title, fontsize=7)
plt.suptitle('Sample images with generated captions')
plt.tight_layout()
plt.show()

# Similarity heatmap for first 32 images vs prompts
subset = min(32, image_embeddings.shape[0])
sub_sims = np.matmul(image_embeddings[:subset], text_embeddings.T)
plt.figure(figsize=(10,6))
plt.imshow(sub_sims, aspect='auto')
plt.colorbar()
plt.title('Similarity heatmap (first 32 images x prompts)')
plt.xlabel('prompt index')
plt.ylabel('image index')
plt.show()

# TSNE of image embeddings colored by top-1 prompt index (sampled)
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
sample_n = min(200, image_embeddings.shape[0])
idx_sample = np.random.RandomState(0).choice(image_embeddings.shape[0], size=sample_n, replace=False)
emb_sub = image_embeddings[idx_sample]
pca = PCA(n_components=50, random_state=0).fit_transform(emb_sub)
tsne = TSNE(n_components=2, perplexity=30, random_state=0, init='pca', n_iter=800)
proj = tsne.fit_transform(pca)

# color by top1 idx
top1_idx = all_top_idx[idx_sample, 0]
plt.figure(figsize=(7,6))
for lab in np.unique(top1_idx):
    mask = top1_idx==lab
    pts = proj[mask]
    plt.scatter(pts[:,0], pts[:,1], s=10, label=str(lab))
plt.legend(title='top1 prompt idx', bbox_to_anchor=(1.05,1))
plt.title('TSNE of image embeddings colored by top-1 prompt index')
plt.show()


In [None]:

# Cell 13 — Save run metadata and finish
meta = {
    "date": time.asctime(),
    "breakhis_dir": str(BREAKHIS_DIR),
    "n_images": image_embeddings.shape[0],
    "n_prompts": len(prompt_corpus),
    "biomedclip_hf_id": CONFIG['BIOMEDCLIP_HF_ID'],
    "llm_id": CONFIG['LLM_ID'],
}
with open(OUT_DIR/"run_metadata.json", "w") as f:
    json.dump(meta, f, indent=2)

print("Pipeline completed. Outputs saved to:", OUT_DIR)
print("Files: image_embeddings.npy, text_embeddings.npy, retrieval_results.jsonl, synthesized_captions.jsonl, synthesized_captions_rescored.csv")

# End of pipeline
# -----------------------------------------------------------------------------
# Notes / next steps:
# - For production or large corpora, use FAISS HNSW/IVF indexes, store doc metadata, and consider vector DBs.
# - For better grounded generation, use a stronger biomedical LLM or OpenAI with prompt templates + chain-of-thought filtering.
# - Always validate generated text with domain experts before clinical use.
