<a href="https://colab.research.google.com/github/AkhilaNacham/MedICat/blob/main/Testing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers torch torchvision pillow tqdm gradio




In [None]:
# Mount Google Drive so the MediCat figures, annotations and fine-tuned
# checkpoints kept under MyDrive/MediCat become readable (Colab only).
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
import json
import math
import os
import re

import torch
import torch.nn.functional as F
from tqdm import tqdm
from PIL import Image
import torchvision.transforms as T
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from transformers import AutoTokenizer, AutoModel
import gradio as gr

# --- DATA LAYER ---
# Root of the MedICaT copy on Google Drive. Set the MEDICAT_ROOT environment
# variable to point the notebook at another location; the default keeps the
# original Colab/Drive layout.
MEDICAT_ROOT = os.environ.get("MEDICAT_ROOT", "/content/drive/MyDrive/MediCat")
IMAGE_DIR = os.path.join(MEDICAT_ROOT, "figures")
ANNOT_PATH = os.path.join(MEDICAT_ROOT, "subcaptions_public.jsonl")

In [None]:
def handle_request(query):
    """User Layer: Request Handler.

    Validate and normalise a raw query coming from the UI.

    Args:
        query: Text typed by the user. ``None`` or non-string values are
            treated as empty / converted with ``str`` instead of crashing.

    Returns:
        A ``(normalised_query, error)`` tuple. On success the query is
        stripped and lower-cased and ``error`` is ``None``; for an empty or
        whitespace-only query it is ``(None, <user-facing message>)``.
    """
    if query is None:
        query = ""
    elif not isinstance(query, str):
        query = str(query)
    normalised = query.strip()
    if not normalised:
        return None, "❌ Empty query. Please enter a valid prompt."
    return normalised.lower(), None


In [None]:
MEDICAL_KEYWORDS = [
    # Imaging Modalities
    "mri", "ct", "x-ray", "xray", "ultrasound", "radiograph", "radiography", "fluoroscopy",
    "angiography", "pet", "spect", "echocardiogram", "ecg", "eeg", "emg", "endoscopy",
    "colonoscopy", "laparoscopy", "biopsy", "mammogram", "microscopy", "histopathology",
    "radiology", "tomography", "scan", "imaging", "radiation",

    # Organs / Body Parts
    "brain", "heart", "lung", "liver", "kidney", "pancreas", "spleen", "stomach", "intestine",
    "colon", "esophagus", "bladder", "uterus", "ovary", "testis", "spine", "spinal", "eye",
    "retina", "cornea", "ear", "nose", "throat", "thyroid", "bone", "joint", "muscle", "skin",
    "nerve", "blood", "vessel", "artery", "vein", "gland", "cartilage", "ligament", "pelvis",
    "hip", "knee", "shoulder", "hand", "wrist", "foot", "ankle", "neck", "chest", "abdomen",

    # Diseases / Conditions
    "tumor", "cancer", "carcinoma", "sarcoma", "melanoma", "leukemia", "lymphoma", "adenoma",
    "metastasis", "infection", "pneumonia", "tuberculosis", "asthma", "covid", "influenza",
    "diabetes", "hypertension", "stroke", "aneurysm", "infarction", "thrombosis", "embolism",
    "arthritis", "osteoporosis", "fracture", "scoliosis", "meningitis", "encephalitis",
    "hepatitis", "cirrhosis", "renal", "nephritis", "gastritis", "ulcer", "appendicitis",
    "colitis", "bronchitis", "dermatitis", "eczema", "psoriasis", "glaucoma", "cataract",
    "anemia", "obesity", "sepsis", "trauma", "injury", "lesion", "inflammation", "edema",
    "necrosis", "fibrosis", "degeneration", "atrophy", "tumour", "swelling",

    # Procedures / Treatments
    "surgery", "operation", "transplant", "resection", "therapy", "chemotherapy",
    "dialysis", "immunotherapy", "gene therapy", "stem cell", "vaccination", "anesthesia",
    "stent", "catheter", "implant", "prosthesis", "rehabilitation", "screening",
    "monitoring", "examination", "diagnosis", "treatment", "management",

    # General Medical / Research Terms
    "clinical", "pathology", "histology", "physiology", "anatomy", "genetic", "genome",
    "mutation", "biomarker", "epidemiology", "toxicology", "oncology", "cardiology",
    "neurology", "neuroscience", "gastroenterology", "urology", "orthopedic", "dermatology",
    "ophthalmology", "otolaryngology", "pulmonology", "endocrinology", "immunology",
    "hematology", "radiotherapy", "psychology", "psychiatry", "pharmacology", "vaccine",
    "virus", "bacteria", "fungus", "parasite", "microbiology", "virology", "bioinformatics",
    "biomedical", "medical", "clinical study", "symptom", "diagnostic", "scan report",
    "laboratory", "tissue", "cell", "organism", "disease", "syndrome", "disorder", "case study"
]

# Keywords must match as whole words/phrases (optionally pluralised with "s" or
# "es"). Plain substring matching let short terms fire inside unrelated words,
# e.g. "ct" in "detect", "ear" in "search", "pet" in "competitor".
# dict.fromkeys de-duplicates while keeping a deterministic order; longest
# terms come first so multi-word phrases are tried before their parts.
_MEDICAL_TERM_PATTERN = re.compile(
    r"\b(?:"
    + "|".join(
        re.escape(keyword)
        for keyword in sorted(dict.fromkeys(MEDICAL_KEYWORDS), key=len, reverse=True)
    )
    + r")(?:e?s)?\b",
    re.IGNORECASE,
)


def is_medical_term(query):
    """Check if the user query likely relates to medical content.

    Args:
        query: Free-text query (any case). ``None``/empty returns ``False``.

    Returns:
        ``True`` if at least one entry of MEDICAL_KEYWORDS occurs in the query
        as a whole word or phrase (a trailing plural "s"/"es" is allowed).
    """
    if not query:
        return False
    return _MEDICAL_TERM_PATTERN.search(str(query)) is not None

In [None]:
def load_jsonl(path):
    """Load a JSON Lines file into a list of parsed records.

    Args:
        path: Path to a ``.jsonl`` file containing one JSON document per line.

    Returns:
        List of decoded JSON values in file order. Blank lines (for example a
        trailing empty line at the end of the file) are skipped.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    records = []
    # Explicit UTF-8: captions may contain non-ASCII text, and the platform's
    # default encoding is not guaranteed to be UTF-8.
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            records.append(json.loads(line))
    return records

# Load all annotation records once; later cells read their fig_uri / pdf_hash /
# text fields to locate figures and captions.
annotations = load_jsonl(ANNOT_PATH)
print(f"✅ Loaded {len(annotations)} dataset entries.")


✅ Loaded 2118 dataset entries.


In [None]:
# üß† Full Fine-Tuning of CLIP on MedICaT (High-Accuracy Version)
!pip install open_clip_torch tqdm pillow --quiet

import torch, json, os, random
from PIL import Image
from tqdm import tqdm
import open_clip

# ---------------- Paths ----------------
# IMAGE_DIR and ANNOT_PATH come from the configuration cell at the top of the
# notebook, and load_jsonl from the data-loading cell. They are not redefined
# here, so there is a single source of truth for them.
SAVE_DIR  = "/content/drive/MyDrive/MediCat/finetuned_medclip_full"
SEED = 42

# Seed both RNGs used below (batch order and model dropout/initialisation).
random.seed(SEED)
torch.manual_seed(SEED)

# ---------------- Load dataset ----------------
anns = load_jsonl(ANNOT_PATH)
print(f"Loaded {len(anns)} annotations")

pairs = []
for ann in anns:
    fig_uri, pdf_hash = ann.get("fig_uri"), ann.get("pdf_hash")
    if not (fig_uri and pdf_hash):
        continue
    img_path = os.path.join(IMAGE_DIR, f"{pdf_hash}_{fig_uri}")
    if not os.path.exists(img_path):
        # Same fallback the indexing cell uses: some figures are stored as .jpg.
        img_path = os.path.splitext(img_path)[0] + ".jpg"
        if not os.path.exists(img_path):
            continue
    caption = (ann.get("text") or "").strip()
    if caption:
        pairs.append((img_path, caption))

print(f"Prepared {len(pairs)} image–text pairs for training")
if len(pairs) < 2:
    # A contrastive batch needs at least one negative per positive.
    raise RuntimeError(
        f"Contrastive fine-tuning needs at least 2 image–text pairs, found {len(pairs)}. "
        "Check IMAGE_DIR / ANNOT_PATH."
    )

# ---------------- Initialize CLIP ----------------
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='openai')
tokenizer = open_clip.get_tokenizer('ViT-B-32')

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # mixed precision is only enabled on GPU
model.to(device)
model.train()  # fine-tune both image & text towers

# ---------------- Training setup ----------------
batch_size = 16
epochs = 2
# Never request more samples than exist (random.sample would raise ValueError).
effective_batch_size = min(batch_size, len(pairs))
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6, weight_decay=0.01)
loss_fn = torch.nn.CrossEntropyLoss()
# torch.amp API replaces the deprecated torch.cuda.amp one; disabled on CPU.
scaler = torch.amp.GradScaler("cuda", enabled=use_amp)
# Upper bound for CLIP's learned temperature (logit scale of 100).
MAX_LOGIT_SCALE = math.log(100)


def get_batch(indices=None):
    """Load one mini-batch of (images, tokenized captions) onto `device`.

    Args:
        indices: Optional positions into `pairs`. When omitted, a random sample
            of `effective_batch_size` pairs is drawn.

    Returns:
        ``(images, tokens)``: stacked preprocessed images and the tokenizer
        output for the matching captions, both moved to `device`.
    """
    if indices is None:
        batch = random.sample(pairs, effective_batch_size)
    else:
        batch = [pairs[i] for i in indices]
    imgs = []
    for img_path, _caption in batch:
        # The context manager closes the file as soon as the pixels are decoded.
        with Image.open(img_path) as im:
            imgs.append(preprocess(im.convert("RGB")))
    caps = tokenizer([caption for _path, caption in batch])
    return torch.stack(imgs).to(device), caps.to(device)


steps_per_epoch = len(pairs) // effective_batch_size
print(f"Training for {epochs} epochs × {steps_per_epoch} steps / epoch")

# ---------------- Training loop ----------------
for epoch in range(epochs):
    # One pass over a fresh permutation per epoch: every pair is used at most
    # once, instead of independent random draws that repeat or skip pairs.
    order = list(range(len(pairs)))
    random.shuffle(order)
    running = 0.0
    for step in tqdm(range(steps_per_epoch), desc=f"Epoch {epoch+1}/{epochs}"):
        start = step * effective_batch_size
        imgs, caps = get_batch(order[start:start + effective_batch_size])
        optimizer.zero_grad()

        with torch.autocast(device_type=device, enabled=use_amp):
            img_feat = model.encode_image(imgs)
            txt_feat = model.encode_text(caps)
            img_feat = img_feat / img_feat.norm(dim=1, keepdim=True)
            txt_feat = txt_feat / txt_feat.norm(dim=1, keepdim=True)
            # Apply CLIP's learned temperature. Without it the cosine logits lie
            # in [-1, 1], the softmax is almost uniform, and the loss stays near
            # ln(batch_size).
            logits_i = model.logit_scale.exp() * img_feat @ txt_feat.T
            logits_t = logits_i.T
            labels = torch.arange(len(imgs), device=device)
            loss_i = loss_fn(logits_i, labels)
            loss_t = loss_fn(logits_t, labels)
            loss = (loss_i + loss_t) / 2

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # Keep the temperature in range, as in the original CLIP training recipe.
        with torch.no_grad():
            model.logit_scale.clamp_(0, MAX_LOGIT_SCALE)
        running += loss.item()

    print(f"Epoch {epoch+1} avg loss = {running/steps_per_epoch:.4f}")

# ---------------- Save model ----------------
os.makedirs(SAVE_DIR, exist_ok=True)
torch.save(model.state_dict(), os.path.join(SAVE_DIR, "medclip_finetuned_full.pt"))
print(f"✅ Saved fine-tuned CLIP model to: {SAVE_DIR}")


[?25l   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m0.0/1.5 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m1.5/1.5 MB[0m [31m48.3 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m0.0/44.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m44.8/44.8 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[?25hLoaded 2118 annotations
Prepared 69 image–text pairs for training


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


open_clip_model.safetensors:   0%|          | 0.00/605M [00:00<?, ?B/s]

  scaler = torch.cuda.amp.GradScaler()


Training for 2 epochs × 4 steps / epoch


  with torch.cuda.amp.autocast():
Epoch 1/2: 100%|██████████| 4/4 [00:44<00:00, 11.19s/it]


Epoch 1 avg loss = 2.7146


Epoch 2/2: 100%|██████████| 4/4 [00:20<00:00,  5.19s/it]


Epoch 2 avg loss = 2.6301
✅ Saved fine-tuned CLIP model to: /content/drive/MyDrive/MediCat/finetuned_medclip_full


In [None]:
# ---------- Enhanced CLIP-based indexing & retrieval (auto-detect fine-tuned model) ----------
from sentence_transformers import util
from PIL import Image
import os, torch
from tqdm import tqdm
import numpy as np
import open_clip

# --- Auto-detect which fine-tuned model exists ---
# Prefer the full fine-tune checkpoint; fall back to the "fast" one.
fast_model_path = "/content/drive/MyDrive/MediCat/finetuned_medclip/medclip_finetuned.pt"
full_model_path = "/content/drive/MyDrive/MediCat/finetuned_medclip_full/medclip_finetuned_full.pt"

if os.path.exists(full_model_path):
    model_path = full_model_path
elif os.path.exists(fast_model_path):
    model_path = fast_model_path
else:
    raise FileNotFoundError("❌ No fine-tuned CLIP model found. Please run the fine-tuning cell first.")

print("✅ Loading fine-tuned CLIP model from:", model_path)
# Build the same architecture/preprocessing the checkpoint was fine-tuned from
# (pretrained='openai'), then overlay the fine-tuned weights.
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='openai')
# The checkpoint is a plain state_dict, so weights_only=True avoids unpickling
# arbitrary objects from Drive.
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
# strict=False tolerates checkpoints from other scripts, but any key mismatch
# is reported instead of silently keeping the un-fine-tuned weights.
load_report = model.load_state_dict(state_dict, strict=False)
if load_report.missing_keys or load_report.unexpected_keys:
    print(
        f"⚠️ Checkpoint key mismatch: {len(load_report.missing_keys)} missing, "
        f"{len(load_report.unexpected_keys)} unexpected. "
        f"e.g. missing={load_report.missing_keys[:3]} unexpected={load_report.unexpected_keys[:3]}"
    )
model.eval()
clip_model = model

print("✅ CLIP model loaded and ready.\n")

# --- Helper: build subcaption text ---
def extract_subcaptions_from_ann(ann):
    """Rebuild each subcaption's text from the annotation's token list.

    Args:
        ann: One annotation record. Reads ``ann["tokens"]`` (entries with a
            ``"text"`` field) and ``ann["subcaptions"]`` (a mapping from
            subfigure label to a list of token indices). Missing or ``None``
            values are treated as empty.

    Returns:
        Dict mapping each subcaption label to its detokenised text. Indices
        outside ``[0, len(tokens))`` are ignored; negative indices would
        otherwise silently wrap to the end of the token list. A label whose
        index list is malformed maps to ``""``.
    """
    tokens = ann.get("tokens") or []
    token_texts = [t.get("text", "") if isinstance(t, dict) else "" for t in tokens]
    n_tokens = len(token_texts)
    subcaptions = {}
    for label, idxs in (ann.get("subcaptions") or {}).items():
        try:
            words = [token_texts[i] for i in idxs if 0 <= i < n_tokens]
            # Re-attach punctuation that the tokenizer split off.
            text = " ".join(words).replace(" .", ".").replace(" ,", ",").strip()
        except (TypeError, ValueError):
            # Malformed record (e.g. idxs is None or holds non-integers).
            text = ""
        subcaptions[label] = text
    return subcaptions

# --- Helper: get image embedding ---
def get_image_embedding_once(img_path):
    """Encode one figure image with the CLIP image tower (`clip_model`).

    Args:
        img_path: Path to an image file on disk.

    Returns:
        The tensor returned by ``clip_model.encode_image`` for a batch of one
        preprocessed image, or ``None`` if the file could not be opened or
        encoded (the error is printed; callers treat ``None`` as "skip").
    """
    try:
        # The context manager releases the file handle as soon as it is decoded.
        with Image.open(img_path) as img:
            img_tensor = preprocess(img.convert("RGB")).unsqueeze(0)
        with torch.no_grad():
            return clip_model.encode_image(img_tensor)
    except Exception as e:  # best-effort: one unreadable figure must not abort indexing
        print("⚠️ Image load error:", img_path, e)
        return None

# --- Build improved subfigure index ---
# One entry per annotated subfigure (or per figure when none are annotated).
# Each entry stores the whole-figure image embedding plus a text embedding of
# its subcaption/caption text.
image_index = []
fig_image_emb_cache = {}
INDEX_LIMIT = 400   # can increase later
# Build the tokenizer once instead of once per indexed entry.
index_tokenizer = open_clip.get_tokenizer('ViT-B-32')


def _encode_index_text(text):
    """Return the CLIP text embedding of `text` without recording gradients.

    Without no_grad, every stored embedding keeps its autograd graph alive
    for as long as `image_index` exists.
    """
    with torch.no_grad():
        return clip_model.encode_text(index_tokenizer([text]))


print("⚙️ Building CLIP subfigure index...")
for ann in tqdm(annotations[:INDEX_LIMIT], desc="Indexing Figures"):
    fig_uri = ann.get("fig_uri")
    pdf_hash = ann.get("pdf_hash")
    if not fig_uri or not pdf_hash:
        continue
    img_path = os.path.join(IMAGE_DIR, f"{pdf_hash}_{fig_uri}")
    if not os.path.exists(img_path):
        # Some figures are stored under a .jpg extension instead.
        alt_path = os.path.splitext(img_path)[0] + ".jpg"
        if not os.path.exists(alt_path):
            continue
        img_path = alt_path

    # Cache the whole-figure embedding in case several annotations share a file.
    fig_img_emb = fig_image_emb_cache.get(img_path)
    if fig_img_emb is None:
        fig_img_emb = get_image_embedding_once(img_path)
        if fig_img_emb is None:
            continue
        fig_image_emb_cache[img_path] = fig_img_emb

    fig_caption = (ann.get("text") or "").strip()
    subcaptions = extract_subcaptions_from_ann(ann)
    subfigs = ann.get("subfigures", []) or []

    if subfigs:
        for s in subfigs:
            label = s.get("label", "")
            subcap_text = subcaptions.get(label, "").strip()
            if not subcap_text:
                # Fall back to the first sentence of the figure caption.
                subcap_text = (fig_caption.split(".")[0] + ".").strip() if fig_caption else ""
            # Append the full caption unless the subcaption already contains it.
            extra = fig_caption if fig_caption and fig_caption not in subcap_text else ""
            combined_text = (subcap_text + " " + extra).strip()
            if not combined_text:
                # Nothing to embed; an empty-text entry would only add noise.
                continue

            image_index.append({
                "id": ann.get("fig_key", ""),
                "fig_path": img_path,
                "subfig_label": label,
                "caption": fig_caption,
                "inline_ref": subcap_text,
                "combined_text": combined_text,
                "img_emb": fig_img_emb,
                "txt_emb": _encode_index_text(combined_text),
            })
    else:
        combined_text = fig_caption
        if not combined_text:
            continue
        image_index.append({
            "id": ann.get("fig_key", ""),
            "fig_path": img_path,
            "subfig_label": "",
            "caption": fig_caption,
            "inline_ref": fig_caption.split(".")[0] if fig_caption else "",
            "combined_text": combined_text,
            "img_emb": fig_img_emb,
            "txt_emb": _encode_index_text(combined_text),
        })

print(f"✅ Indexed {len(image_index)} figure/subfigure entries (limit {INDEX_LIMIT}).")

# --- Retrieval: weighted text + image similarity ---
def search_images(query, top_k=5, text_weight=0.75, image_weight=0.25):
    """Rank the indexed (sub)figures against a free-text query.

    The query is embedded with the CLIP text tower. It is compared by cosine
    similarity with each entry's text embedding and with its figure's image
    embedding. The two similarities are blended as
    ``text_weight * text + image_weight * image``.

    Args:
        query: Search text.
        top_k: Number of best-scoring entries to return.
        text_weight: Weight of the query-vs-text similarity.
        image_weight: Weight of the query-vs-image similarity.

    Returns:
        Up to `top_k` dicts, best first, with the figure path, subfigure
        label, caption, inline reference and the three scores rounded to
        4 decimals.
    """
    text_tokenizer = open_clip.get_tokenizer('ViT-B-32')
    with torch.no_grad():
        query_vec = clip_model.encode_text(text_tokenizer([query]))
    query_vec = query_vec / query_vec.norm(dim=1, keepdim=True)

    def cosine_to_query(emb):
        unit = emb / emb.norm(dim=1, keepdim=True)
        return torch.matmul(query_vec, unit.T).item()

    ranked = []
    for entry in image_index:
        text_score = cosine_to_query(entry["txt_emb"])
        image_score = cosine_to_query(entry["img_emb"])
        blended = text_weight * text_score + image_weight * image_score
        ranked.append((blended, text_score, image_score, entry))
    ranked.sort(key=lambda row: row[0], reverse=True)

    return [
        {
            "path": entry["fig_path"],
            "subfig_label": entry["subfig_label"],
            "caption": entry["caption"],
            "inline_ref": entry["inline_ref"],
            "score_combined": round(float(blended), 4),
            "score_text": round(float(text_score), 4),
            "score_image": round(float(image_score), 4),
        }
        for blended, text_score, image_score, entry in ranked[:top_k]
    ]

print("🎯 Retrieval system ready (subfigure-based, fine-tuned CLIP).")

✅ Loading fine-tuned CLIP model from: /content/drive/MyDrive/MediCat/finetuned_medclip_full/medclip_finetuned_full.pt




✅ CLIP model loaded and ready.

⚙️ Building CLIP subfigure index...


Indexing Figures: 100%|██████████| 400/400 [00:12<00:00, 30.96it/s]

✅ Indexed 37 figure/subfigure entries (limit 400).
üéØ Retrieval system ready (subfigure-based, fine-tuned CLIP).





In [None]:
from PIL import Image, ImageDraw, ImageFont

# Placeholder picture handed to the Gradio gallery whenever no real figure can
# be shown. It is rendered once and then reused from disk on later runs.
FALLBACK_PATH = "/content/fallback_placeholder.png"
if not os.path.exists(FALLBACK_PATH):
    img = Image.new(mode="RGB", size=(400, 300), color=(220, 220, 220))
    draw = ImageDraw.Draw(img)
    draw.text(xy=(100, 140), text="No Image", fill=(0, 0, 0))
    img.save(fp=FALLBACK_PATH)

def medicat_search(query, top_k=5):
    """
    Safely handles user query and returns valid image-caption pairs for Gradio.
    Always returns actual images — never None.

    Pipeline: handle_request (validation) -> is_medical_term (topic filter)
    -> search_images (CLIP retrieval over the subfigure index) -> gallery items.

    Args:
        query: Raw text from the Gradio textbox.
        top_k: Maximum number of matches requested from the retriever.

    Returns:
        A non-empty list of ``(image_path, caption_text)`` tuples. Validation
        failures, retrieval errors and "no match" cases come back as a single
        tuple that points at FALLBACK_PATH.
    """
    query, error = handle_request(query)
    if error:
        return [(FALLBACK_PATH, error)]

    if not is_medical_term(query):
        return [(FALLBACK_PATH, "⚠️ Not a medical term or context. Try another query.")]

    try:
        # search_images embeds the query itself and returns ranked result dicts.
        # The helpers called here before (get_text_embedding /
        # retrieve_references) are not defined in this notebook, so every
        # search ended in a NameError.
        results = search_images(query, top_k=top_k)
    except Exception as e:
        # Report the failure in the gallery instead of crashing the UI callback.
        return [(FALLBACK_PATH, f"❌ Retrieval error: {str(e)}")]

    output = []

    for res in results:
        # Accept both (score, item) tuples and result dicts.
        if isinstance(res, tuple) and len(res) == 2:
            score, item = res
        elif isinstance(res, dict):
            score, item = res.get("score_combined", 0.0), res
        else:
            continue
        if not isinstance(item, dict):
            continue

        # First existing file among the known path keys, else the placeholder.
        img_path = next(
            (
                item[key]
                for key in ("path", "fig_path", "image_path")
                if item.get(key) and os.path.exists(item[key])
            ),
            FALLBACK_PATH,
        )

        caption = item.get("caption") or "No caption available"
        inline_ref = item.get("inline_ref", "")
        similarity = round(float(score) * 100, 2) if isinstance(score, (float, int)) else 0.0

        caption_text = (
            f"🩻 Caption: {caption}\n"
            f"📖 Inline Ref: {inline_ref or 'None'}\n"
            f"🎯 Match Score: {similarity}%"
        )

        output.append((img_path, caption_text))

    # Always return at least one valid image
    if not output:
        output = [(FALLBACK_PATH, f"⚠️ No valid image matches found for query '{query}'.")]

    return output

In [None]:
# Gradio front end: one free-text query in, a gallery of (image, caption) out.
iface = gr.Interface(
    fn=medicat_search,
    inputs=gr.Textbox(label="🔎 Enter Medical Query", placeholder="e.g., MRI images related to brain tumor"),
    outputs=gr.Gallery(label="Retrieved Medical Images with Captions"),
    title="🧠 MedICaT Image–Text Retrieval System",
    description="Architecture Flow: User Layer → Application Layer → Data Layer → Output Layer"
)

# debug=True keeps the cell running (blocking) so server errors and logs show up
# in the notebook; interrupt the cell to stop the server.
iface.launch(debug=True)

It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://0e454484130b896e72.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://0e454484130b896e72.gradio.live


