In [None]:
# Upload the Kaggle API token from the local machine via the Colab file picker.
# The upload is expected to be named kaggle.json in the working directory (the cp below reads it).
from google.colab import files
files.upload()

# Install the token under ~/.kaggle for the Kaggle CLI and restrict it to owner read/write.
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

In [None]:
# NOTE(review): gdown is not referenced by any other cell in this notebook; this install
# looks unused — confirm before removing.
!pip install -q gdown

In [None]:
# Download the Kvasir-SEG archive (saved as kvasirseg.zip) with the Kaggle CLI.
# Requires the kaggle.json token installed in the first cell.
!kaggle datasets download -d debeshjha1/kvasirseg

In [None]:
!unzip -q kvasirseg.zip -d kvasir_raw

In [None]:
!mkdir -p polyp_data/images polyp_data/masks
!cp -r kvasir_raw/Kvasir-SEG//Kvasir-SEG/images/* polyp_data/images/
!cp -r kvasir_raw/Kvasir-SEG//Kvasir-SEG/masks/* polyp_data/masks/

In [None]:
import os

# Sanity check: after the copy both folders should contain one file per sample.
for label, folder in (("Images:", "polyp_data/images"), ("Masks:", "polyp_data/masks")):
    print(label, len(os.listdir(folder)))

In [None]:
!pip install -q transformers==4.40.0 pillow tqdm matplotlib scikit-image kaggle

In [None]:
import os, random
import numpy as np
import torch
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import trange, tqdm
import pandas as pd
import torch.nn.functional as F

# Seed every RNG the notebook draws from (episode sampling uses `random`, module init and
# training use torch), so that a Restart-&-Run-All reproduces the same run.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device:', device)

In [None]:
from transformers import CLIPModel, CLIPProcessor

# Pretrained CLIP ViT-B/32 backbone and its matching tokenizer/image preprocessor.
model_name = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(model_name)
# .eval() returns the module itself, so it can be chained after the device move.
model = CLIPModel.from_pretrained(model_name).to(device).eval()
print("Loaded:", model_name)

In [None]:
# Fail fast if the loaded checkpoint lacks the CLIP components this notebook relies on.
required_clip_attrs = [
    ("visual_projection", "Strict CLIP (HF) requires model.visual_projection to exist."),
    ("text_projection", "Strict CLIP (HF) requires model.text_projection to exist."),
    ("logit_scale", "Strict CLIP requires logit_scale parameter."),
]
for attr_name, err_msg in required_clip_attrs:
    if not hasattr(model, attr_name):
        raise RuntimeError(err_msg)

# Wrap logit_scale only if it is not already a Parameter. Re-wrapping an existing
# Parameter unconditionally replaces the registered object with a new one, which silently
# invalidates any reference (e.g. an optimizer param group) that holds the old object.
if not isinstance(model.logit_scale, torch.nn.Parameter):
    model.logit_scale = torch.nn.Parameter(model.logit_scale)

print("Strict HuggingFace CLIP enforced successfully.")

In [None]:
# Dump the full module tree (vision/text towers and projection heads) for inspection.
print(f"{model}")

In [None]:
images_dir = Path("polyp_data/images")
masks_dir = Path("polyp_data/masks")
# Accepted image extensions (case-insensitive); .jpeg is included alongside .jpg.
valid_suffixes = {".jpg", ".jpeg", ".png"}

img_files = sorted(p for p in images_dir.glob("*") if p.suffix.lower() in valid_suffixes)
mask_files = sorted(p for p in masks_dir.glob("*") if p.suffix.lower() in valid_suffixes)

# Each mask shares its image's filename; pair by exact name. The membership test already
# guarantees a mask exists, so no separate None-filter is needed.
mask_map = {m.name: m for m in mask_files}
pairs = [(img, mask_map[img.name]) for img in img_files if img.name in mask_map]
print("Total matched pairs:", len(pairs))
assert pairs, f"No image/mask pairs found under {images_dir} and {masks_dir}; check the copy step."

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from PIL import Image
import numpy as np
import random
import matplotlib.pyplot as plt

class PromptLearner(nn.Module):
    """Learnable context vector(s) combined with CLIP token embeddings of class names.

    For each class name, forward() tokenizes it, passes the ids through
    clip_model.text_model.embeddings (without gradients) and averages over tokens. It then
    adds the mean of the learnable context rows. One vector per class is returned.

    NOTE(review): in this notebook the output is not consumed by the prediction path, so
    `ctx` receives no gradient. Also, `clip_model` is registered as a submodule here, so
    .parameters()/.train() on this module reach the whole CLIP model.
    Relies on the notebook-level `processor` and `device` globals.
    """

    def __init__(self, clip_model, classnames, n_ctx=8):
        super().__init__()
        self.clip_model = clip_model
        self.classnames = classnames
        text_width = clip_model.text_model.config.hidden_size
        self.n_ctx = n_ctx
        # Small-std init so the context starts as a minor perturbation of the name embedding.
        self.ctx = nn.Parameter(torch.randn(n_ctx, text_width) * 0.02)

    def forward(self):
        """Return a (n_classes, D_text) tensor, one row per entry of self.classnames."""
        ctx_summary = self.ctx.mean(dim=0, keepdim=True)  # (1, D)
        rows = []
        for name in self.classnames:
            ids = processor.tokenizer(name, return_tensors="pt").to(device).input_ids
            with torch.no_grad():
                token_vecs = self.clip_model.text_model.embeddings(ids)  # (1, T, D)
            rows.append((ctx_summary + token_vecs.mean(dim=1)).squeeze(0))
        return torch.stack(rows, dim=0)

class PrototypeRefiner(nn.Module):
    """Build a foreground prototype from support patch tokens and refine it with attention.

    The initial prototype is the mask-weighted mean of the patch tokens. A linear map of it
    acts as a query that attends over the foreground patches only. The attended vector is
    added back to the prototype, and the sum is LayerNorm-ed.
    """

    def __init__(self, dim):
        super().__init__()
        self.lin = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)

    def forward(self, patch_tokens, mask_small):
        """
        Args:
            patch_tokens: (N, D) patch features of one support image.
            mask_small: (N,) mask on the patch grid; entries > 0 count as foreground.

        Returns:
            (D,) refined prototype.
        """
        msum = mask_small.sum().clamp(min=1e-6)
        proto = (patch_tokens * mask_small.unsqueeze(-1)).sum(dim=0) / msum  # (D,)
        q = self.lin(proto).unsqueeze(0)  # (1, D)
        logits = (patch_tokens @ q.t()).squeeze(-1)  # (N,)

        # Restrict attention to foreground patches. Multiplying the logits by the mask would
        # set background logits to 0 instead of excluding them. Background patches would then
        # still receive weight exp(0), and could dominate whenever the foreground logits are
        # negative.
        fg = mask_small > 0
        if fg.any():
            logits = logits.masked_fill(~fg, float("-inf"))
        else:
            # No foreground patch (e.g. a tiny polyp lost when downsampling the mask): fall
            # back to uniform attention over all patches rather than an all -inf softmax (NaN).
            logits = torch.zeros_like(logits)
        w = torch.softmax(logits, dim=0)
        refined = (w.unsqueeze(-1) * patch_tokens).sum(dim=0)
        return self.norm(proto + refined)

class CrossAttentionFusion(nn.Module):
    """Post-norm transformer block in which query tokens cross-attend to support prototypes.

    x = LN(query + MHA(query, protos, protos)); out = LN(x + FFN(x)), with a 4x-wide GELU FFN.
    """

    def __init__(self, dim, n_heads=8):
        super().__init__()
        hidden = dim * 4
        self.mha = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, query_tokens, support_protos):
        """
        Args:
            query_tokens: (B, N, D) query patch tokens.
            support_protos: (K, D), shared across the batch, or (B, K, D) per-sample.

        Returns:
            (B, N, D) fused tokens.
        """
        batch, _, _ = query_tokens.shape
        kv = support_protos
        if kv.dim() == 2:
            # Share the same K prototypes with every batch element.
            kv = kv.unsqueeze(0).repeat(batch, 1, 1)
        attended, _ = self.mha(query_tokens, kv, kv)
        hidden = self.norm1(query_tokens + attended)
        return self.norm2(hidden + self.ff(hidden))

class UNetDecoderFromPatches(nn.Module):
    """Decode three ViT token sequences into a single-channel mask logit map.

    Each sequence (B, N, in_dim) is linearly projected to mid_ch channels and folded into a
    grid with int(sqrt(N)) rows. The three grids are upsampled by 4x, 2x and 1x
    respectively. They are then fused starting from the 1x grid, using conv blocks and
    skip concatenation, and the final 1-channel map is bilinearly resized to `out_size`.
    """

    def __init__(self, in_dim, mid_ch=256):
        super().__init__()
        # Per-level token -> channel projections.
        self.proj1 = nn.Linear(in_dim, mid_ch)
        self.proj2 = nn.Linear(in_dim, mid_ch)
        self.proj3 = nn.Linear(in_dim, mid_ch)

        def conv_relu(in_ch):
            # 3x3 same-padding conv to mid_ch channels followed by ReLU.
            return [nn.Conv2d(in_ch, mid_ch, 3, padding=1), nn.ReLU()]

        self.conv1 = nn.Sequential(*conv_relu(mid_ch), *conv_relu(mid_ch))
        self.conv2 = nn.Sequential(*conv_relu(mid_ch * 2))
        self.conv3 = nn.Sequential(*conv_relu(mid_ch * 2))

        self.final = nn.Conv2d(mid_ch, 1, 1)

    @staticmethod
    def _fold(tokens, side):
        """(B, h*w, C) token sequence -> (B, C, h, w) map with h = side."""
        return rearrange(tokens, 'b (h w) c -> b c h w', h=side)

    def forward(self, feats, out_size):
        """
        Args:
            feats: list of three (B, N, in_dim) tensors, decoded as [coarsest, mid, finest].
            out_size: target (H, W) of the returned logits.

        Returns:
            (B, 1, H, W) mask logits.
        """
        _, n_tokens, _ = feats[0].shape
        side = int(n_tokens ** 0.5)

        def up(x, **kw):
            return F.interpolate(x, mode='bilinear', align_corners=False, **kw)

        map1 = up(self._fold(self.proj1(feats[0]), side), scale_factor=4)
        map2 = up(self._fold(self.proj2(feats[1]), side), scale_factor=2)
        map3 = self._fold(self.proj3(feats[2]), side)

        x = self.conv1(map3)
        x = self.conv2(torch.cat([up(x, size=map2.shape[2:]), map2], dim=1))
        x = self.conv3(torch.cat([up(x, size=map1.shape[2:]), map1], dim=1))
        return up(self.final(x), size=out_size)

In [None]:
classnames = ["polyp"]
vision_dim = model.vision_model.config.hidden_size  # width of the CLIP vision tokens

prompt_learner = PromptLearner(model, classnames).to(device)
refiner = PrototypeRefiner(dim=vision_dim).to(device)
fusion = CrossAttentionFusion(dim=vision_dim, n_heads=8).to(device)
decoder = UNetDecoderFromPatches(in_dim=vision_dim, mid_ch=256).to(device)

# Freeze the CLIP backbone.
for p in model.parameters():
    p.requires_grad = False

# PromptLearner stores the CLIP model as a submodule (self.clip_model), so
# prompt_learner.parameters() also yields every frozen CLIP weight. Give Adam only the
# tensors that actually require gradients.
trainable_params = [
    p
    for module in (prompt_learner, refiner, fusion, decoder)
    for p in module.parameters()
    if p.requires_grad
]
optimizer = torch.optim.Adam(trainable_params, lr=3e-4)
print(f"Trainable tensors: {len(trainable_params)} "
      f"({sum(p.numel() for p in trainable_params):,} parameters)")

def load_image_for_clip(path):
    """Load an RGB image and preprocess it for the CLIP vision tower.

    The stock CLIP preprocessing resizes the shortest side and then center-crops to a
    square, which drops the borders of non-square images. Masks, however, are downsampled
    from the full uncropped frame, and predictions are upsampled back to it. A crop would
    therefore misalign patches and mask cells. Resizing the whole frame to the processor's
    crop size first turns the processor's own resize and crop into no-ops, so the full
    field of view is kept.

    Returns:
        pix: pixel_values tensor (1, 3, crop_h, crop_w) on `device`.
        orig_size: (H, W) of the original image.
        pil: the original, unresized PIL image (for display).
    """
    pil = Image.open(path).convert("RGB")
    orig_size = pil.size[::-1]  # PIL reports (W, H); flip to (H, W)
    crop = processor.image_processor.crop_size  # {"height": ..., "width": ...}
    squared = pil.resize((crop["width"], crop["height"]), Image.BICUBIC)
    pix = processor(images=squared, return_tensors="pt")["pixel_values"].to(device)
    return pix, orig_size, pil

def load_mask_orig(path):
    """Load a mask at its original resolution as a binary float tensor.

    Grayscale values above 127 become 1.0 and everything else becomes 0.0.

    Returns:
        (H, W) float32 tensor on `device`.
    """
    gray = np.array(Image.open(path).convert("L"))
    foreground = gray > 127
    return torch.from_numpy(foreground.astype(np.float32)).to(device)

def sample_episode(k=3, verbose=False, rng=None, pool=None):
    """Draw a random few-shot episode: one query pair plus `k` distinct support pairs.

    Args:
        k: number of support (image, mask) pairs; must be >= 1.
        verbose: print the query filename and support count. Defaults to False because the
            training and evaluation loops call this hundreds of times.
        rng: optional `random.Random` for reproducible sampling; defaults to the global
            `random` module.
        pool: optional sequence of (image_path, mask_path) pairs; defaults to the
            notebook-level `pairs`.

    Returns:
        (query_image_path, query_mask_path, supports), where supports is a list of k pairs,
        none of which is the query.

    Raises:
        ValueError: if k < 1 or the pool holds fewer than k + 1 pairs.
    """
    if pool is None:
        pool = pairs
    if rng is None:
        rng = random
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}.")
    if len(pool) < k + 1:
        raise ValueError(f"Need at least {k + 1} (image, mask) pairs for k={k}, got {len(pool)}.")

    q_idx, *sup_idxs = rng.sample(range(len(pool)), k + 1)
    q_img_path, q_mask_path = pool[q_idx]
    supports = [pool[i] for i in sup_idxs]
    if verbose:
        q_name = getattr(q_img_path, "name", q_img_path)
        print(f"Episode generated: Query = {q_name}, Support count = {len(supports)}")
    return q_img_path, q_mask_path, supports

In [None]:
def get_multilayer_patch_tokens(pixel_values):
    """Extract patch tokens from three depths of the CLIP vision tower.

    Runs `model.vision_model` with output_hidden_states=True and keeps the hidden states
    at indices len//3, 2*len//3 and len-1 (shallow to deep, which the decoder treats as
    coarse -> fine). The leading CLS token (position 0) is dropped from each.

    Returns:
        list of three (B, N, D) tensors, where N excludes the CLS token.
    """
    all_states = model.vision_model(pixel_values=pixel_values, output_hidden_states=True).hidden_states
    n_states = len(all_states)  # L + 1 entries (embedding output + one per layer)
    layer_ids = (n_states // 3, 2 * n_states // 3, n_states - 1)
    return [all_states[i][:, 1:, :].contiguous() for i in layer_ids]

def compute_support_proto(sup_pixel_values, sup_mask_tensor):
    """Compute the refined foreground prototype for one support image.

    Args:
        sup_pixel_values: CLIP-preprocessed pixel_values, e.g. (1, 3, 224, 224).
        sup_mask_tensor: original-resolution (H, W) mask tensor on `device`.

    Returns:
        (proto, feats): the (D,) prototype from `refiner` and the list of multi-layer
        patch-token tensors of the support image.
    """
    feats = get_multilayer_patch_tokens(sup_pixel_values)
    tokens = feats[-1].squeeze(0)  # deepest layer, (N, D)
    grid = int(tokens.shape[0] ** 0.5)  # side of the square patch grid

    # Nearest-neighbour downsample of the mask onto the patch grid, flattened to (N,).
    mask_small = F.interpolate(sup_mask_tensor[None, None], size=(grid, grid), mode='nearest').view(-1)
    mask_small = mask_small.to(tokens.dtype)

    return refiner(tokens, mask_small), feats

In [None]:
def forward_episode(query_img_path, query_mask_path, supports):
    """Run one few-shot segmentation episode.

    Args:
        query_img_path: path of the query image.
        query_mask_path: path of the query's ground-truth mask.
        supports: non-empty sequence of (image_path, mask_path) support pairs.

    Returns:
        mask_logits: (1, 1, H, W) logits at the query's original mask resolution.
        q_mask_orig: (H, W) binary float ground-truth mask on `device`.
        q_pil: the query PIL image (for display).

    Raises:
        ValueError: if `supports` is empty.
    """
    if len(supports) == 0:
        raise ValueError("forward_episode needs at least one support (image, mask) pair.")

    q_pix, _, q_pil = load_image_for_clip(query_img_path)
    q_mask_orig = load_mask_orig(query_mask_path)  # (H_orig, W_orig)

    q_feats = get_multilayer_patch_tokens(q_pix)  # list of 3 x (1, N, D)
    q_patch_tokens = q_feats[-1]  # finest level, (1, N, D)

    protos = []
    for s_img_path, s_mask_path in supports:
        s_pix, _, _ = load_image_for_clip(s_img_path)
        proto, _ = compute_support_proto(s_pix, load_mask_orig(s_mask_path))
        protos.append(proto)
    protos = torch.stack(protos, dim=0)  # (K, D)

    # Attend over all K prototypes. Averaging them into a single key makes the
    # cross-attention degenerate: a softmax over one key is always 1, so every query token
    # would receive the same attention output regardless of its content.
    fused = fusion(q_patch_tokens, protos)  # (1, N, D); fusion broadcasts (K, D) to (1, K, D)

    # NOTE(review): prompt_learner's text embeddings were never used in the prediction, so
    # they are not computed here; the prompt learner is not wired into this model.
    mask_logits = decoder([q_feats[0], q_feats[1], fused], out_size=q_mask_orig.shape)  # (1, 1, H, W)
    return mask_logits, q_mask_orig, q_pil

In [None]:
num_epochs = 10
episodes_per_epoch = 30

# Heads that receive gradients. prompt_learner is left out: it holds the whole CLIP model
# as a submodule, so calling .train() on it would also switch the frozen backbone to train mode.
trainable_modules = (refiner, fusion, decoder)

for epoch in range(num_epochs):
    # Train mode belongs on the trainable heads; the frozen CLIP backbone stays in eval mode.
    for module in trainable_modules:
        module.train()
    model.eval()

    total_loss = 0.0
    for _ in range(episodes_per_epoch):
        q_img, q_mask, supports = sample_episode(k=3)
        mask_logits, gt_mask, _ = forward_episode(q_img, q_mask, supports)
        gt = gt_mask.unsqueeze(0).unsqueeze(0)  # (1, 1, H, W)

        # BCE + soft Dice (smoothing term 1).
        loss_bce = F.binary_cross_entropy_with_logits(mask_logits, gt)
        pred_prob = torch.sigmoid(mask_logits)
        intersection = (pred_prob * gt).sum()
        loss_dice = 1 - (2 * intersection + 1) / (pred_prob.sum() + gt.sum() + 1)
        loss = loss_bce + loss_dice

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs} | Avg Loss: {total_loss/episodes_per_epoch:.4f}")

# Leave the heads in eval mode for the inference and evaluation cells that follow.
for module in trainable_modules:
    module.eval()

In [None]:
# Visualise one random episode: query image, ground truth, thresholded prediction.
q_img, q_mask, supports = sample_episode(k=3)
with torch.no_grad():  # inference only; no autograd graph needed
    mask_logits, gt_mask, q_pil = forward_episode(q_img, q_mask, supports)
pred = (torch.sigmoid(mask_logits) > 0.5).cpu().numpy().squeeze()

fig, axes = plt.subplots(1, 3, figsize=(14, 6))
panels = (
    (q_pil, "Query Image", None),
    (gt_mask.cpu(), "GT Mask", "gray"),
    (pred, "Pred Mask", "gray"),
)
for ax, (img, title, cmap) in zip(axes, panels):
    ax.imshow(img, cmap=cmap)
    ax.set_title(title)
    ax.axis('off')
plt.show()

In [None]:
# Show 5 random prediction samples after training, one row per episode:
# query image, ground-truth mask, thresholded prediction.
num_samples_to_show = 5
fig, axes = plt.subplots(num_samples_to_show, 3, figsize=(15, num_samples_to_show * 3), squeeze=False)

for row_axes in axes:
    q_img, q_mask, supports = sample_episode(k=3)
    with torch.no_grad():  # inference only; no autograd graph needed
        mask_logits, gt_mask, q_pil = forward_episode(q_img, q_mask, supports)
    pred = (torch.sigmoid(mask_logits) > 0.5).cpu().numpy().squeeze()

    panels = (
        (q_pil, "Query Image", None),
        (gt_mask.cpu(), "GT Mask", "gray"),
        (pred, "Pred Mask", "gray"),
    )
    for ax, (img, title, cmap) in zip(row_axes, panels):
        ax.imshow(img, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
from sklearn.metrics import precision_score, recall_score, f1_score

def dice_coeff(pred, target):
    """Smoothed Dice coefficient 2|A∩B| / (|A| + |B|) for binary arrays/tensors.

    A 1e-6 term is added to numerator and denominator, so two empty masks score 1.0.
    """
    smooth = 1e-6
    overlap = (pred * target).sum()
    total = pred.sum() + target.sum()
    return (2. * overlap + smooth) / (total + smooth)

def iou_score(pred, target):
    """Smoothed intersection-over-union for integer/bool binary masks (uses bitwise & and |).

    A 1e-6 term is added to numerator and denominator, so two empty masks score 1.0.
    """
    smooth = 1e-6
    overlap = pred & target
    combined = pred | target
    return (overlap.sum() + smooth) / (combined.sum() + smooth)

# Evaluate on random episodes.
# NOTE(review): episodes are drawn from the same `pairs` pool used for training (there is
# no held-out split), so these numbers are optimistic estimates.
num_eval_samples = 20
accs, dices, ious, precisions, recalls, f1s = [], [], [], [], [], []

for _ in range(num_eval_samples):
    q_img, q_mask, supports = sample_episode(k=3)
    with torch.no_grad():  # inference only; no autograd graph needed
        mask_logits, gt_mask, _ = forward_episode(q_img, q_mask, supports)
    pred_bin = (torch.sigmoid(mask_logits) > 0.5).cpu().numpy().squeeze().astype(np.uint8)
    gt_bin = gt_mask.cpu().numpy().astype(np.uint8)

    pred_flat = pred_bin.flatten()
    gt_flat = gt_bin.flatten()

    accs.append((pred_bin == gt_bin).mean())
    dices.append(dice_coeff(pred_bin, gt_bin))
    ious.append(iou_score(pred_bin, gt_bin))
    precisions.append(precision_score(gt_flat, pred_flat, zero_division=0))
    recalls.append(recall_score(gt_flat, pred_flat, zero_division=0))
    f1s.append(f1_score(gt_flat, pred_flat, zero_division=0))

print(f"Evaluation over {num_eval_samples} episodes:")
print(f"Accuracy : {np.mean(accs):.4f}")
print(f"Dice     : {np.mean(dices):.4f}")
print(f"IoU      : {np.mean(ious):.4f}")
print(f"Precision: {np.mean(precisions):.4f}")
print(f"Recall   : {np.mean(recalls):.4f}")
print(f"F1-score : {np.mean(f1s):.4f}")

In [None]:
import numpy as np
import torch
from sklearn.metrics import precision_score, recall_score, f1_score

def evaluate_masks(pred_mask, true_mask, threshold=0.5, eps=1e-6):
    """Compute soft and binary segmentation metrics for one predicted mask.

    Args:
        pred_mask: predicted probabilities as a torch.Tensor or np.ndarray shaped
            (H, W), (1, H, W) or (1, 1, H, W).
        true_mask: ground-truth mask (tensor or array) that squeezes to the same shape.
        threshold: probability cut-off used to binarize ``pred_mask``.
        eps: smoothing term added to numerators and denominators.

    Returns:
        dict with keys "Accuracy", "Soft Dice", "Binary Dice", "IoU", "Precision",
        "Recall", "F1-score", in that order. The ground truth is always binarized at 0.5.
    """
    def to_float_array(mask):
        # Tensors are detached and copied to host memory before conversion.
        if torch.is_tensor(mask):
            mask = mask.detach().cpu().numpy()
        return np.squeeze(mask).astype(np.float32)

    pred_arr = to_float_array(pred_mask)
    true_arr = to_float_array(true_mask)

    # Soft Dice on the raw probabilities (no thresholding).
    pred_soft = pred_arr.flatten()
    true_soft = true_arr.flatten()
    soft_overlap = np.sum(pred_soft * true_soft)
    soft_dice = (2. * soft_overlap + eps) / (np.sum(pred_soft) + np.sum(true_soft) + eps)

    # Binary versions of both masks.
    pred_bin = (pred_arr > threshold).astype(np.uint8)
    true_bin = (true_arr > 0.5).astype(np.uint8)
    pred_vec = pred_bin.flatten()
    true_vec = true_bin.flatten()

    tp = np.sum(pred_vec * true_vec)
    n_pred = np.sum(pred_vec)
    n_true = np.sum(true_vec)
    bin_dice = (2. * tp + eps) / (n_pred + n_true + eps)
    iou = (tp + eps) / ((n_pred + n_true - tp) + eps)

    return {
        "Accuracy": np.mean(pred_bin == true_bin),
        "Soft Dice": soft_dice,
        "Binary Dice": bin_dice,
        "IoU": iou,
        "Precision": precision_score(true_vec, pred_vec, zero_division=0),
        "Recall": recall_score(true_vec, pred_vec, zero_division=0),
        "F1-score": f1_score(true_vec, pred_vec, zero_division=0),
    }

In [None]:
# NOTE(review): episodes come from the same `pairs` pool used for training (no held-out
# split), so these metrics are optimistic.
num_eval_samples = 20
# Run evaluate_masks once on dummy input only to get its metric names (in order).
metrics_accum = {k: [] for k in evaluate_masks(torch.zeros(1,1,10,10), torch.zeros(1,1,10,10)).keys()}

for _ in range(num_eval_samples):
    q_img, q_mask, supports = sample_episode(k=3)
    with torch.no_grad():  # inference only; no autograd graph needed
        mask_logits, gt_mask, _ = forward_episode(q_img, q_mask, supports)
    pred_probs = torch.sigmoid(mask_logits)

    res = evaluate_masks(pred_probs, gt_mask.unsqueeze(0))
    for k, v in res.items():
        metrics_accum[k].append(v)

print(f"Evaluation over {num_eval_samples} episodes:")
for k, v in metrics_accum.items():
    print(f"{k:12}: {np.mean(v):.4f}")