In [None]:
# Colab-only: open a browser upload dialog. The steps below expect the uploaded
# file to be the Kaggle API token named `kaggle.json`.
from google.colab import files
files.upload()

# Place the token under ~/.kaggle and make it readable by the owner only.
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

#!pip install -q gdown

# Download the dataset archive into the current working directory.
!kaggle datasets download -d debeshjha1/kvasircapsuleseg

# Extract quietly into kvasir_raw/.
!unzip -q kvasircapsuleseg.zip -d kvasir_raw

# Copy images and masks into the flat polyp_data/images and polyp_data/masks
# folders that the training code reads from.
!mkdir -p polyp_data/images polyp_data/masks
!cp -r kvasir_raw/Kvasir-Capsule/images/* polyp_data/images/
!cp -r kvasir_raw/Kvasir-Capsule/masks/* polyp_data/masks/

import os

# Sanity check: report how many entries ended up in each folder.
print("Images:", len(os.listdir("polyp_data/images")))
print("Masks:", len(os.listdir("polyp_data/masks")))

#!pip install -q transformers==4.40.0 pillow tqdm matplotlib scikit-image kaggle

import os, random
import numpy as np
import torch
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import trange, tqdm
import pandas as pd
import torch.nn.functional as F

# Report which torch device is available.
# NOTE(review): the code visible in this notebook uses the upper-case DEVICE
# constant defined further down, not this lower-case `device` -- confirm whether
# this variable is still needed.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device:', device)

#!pip install git+https://github.com/openai/CLIP.git

#!pip install -q albumentations==1.4.6

import os
import random
import math
from pathlib import Path
from PIL import Image
import numpy as np
from tqdm import trange, tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

from transformers import CLIPModel, CLIPProcessor
from einops import rearrange
import matplotlib.pyplot as plt

# Torch device string used by the helpers and models defined below.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Side length (pixels) that images and masks are resized to when loaded.
IMG_SIZE = 256
# NOTE(review): not referenced by any code visible in this notebook; presumably
# the CLIP input resolution -- confirm or remove.
CLIP_IMG_SIZE = 224
# Number of support (image, mask) pairs sampled for each episode.
SUPPORT_SHOTS = 3

In [None]:
class SoftDiceLoss(nn.Module):
    """Soft Dice loss on sigmoid probabilities, averaged over the batch.

    Args:
        eps: Constant added to both numerator and denominator so an empty
            prediction/target pair does not divide by zero.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, logits, targets):
        """Return ``mean(1 - dice)`` where dice is computed per sample.

        Both tensors are reduced over dims 1, 2 and 3, so they must be 4-D.
        """
        reduce_dims = (1, 2, 3)
        probabilities = torch.sigmoid(logits)
        overlap = torch.sum(probabilities * targets, dim=reduce_dims)
        numerator = 2 * overlap + self.eps
        denominator = (
            torch.sum(probabilities, dim=reduce_dims)
            + torch.sum(targets, dim=reduce_dims)
            + self.eps
        )
        per_sample_loss = 1 - numerator / denominator
        return per_sample_loss.mean()

class FocalLossBinary(nn.Module):
    """Binary focal loss computed from raw logits.

    Args:
        alpha: Constant weight applied to every element.
        gamma: Focusing exponent applied to ``1 - p_t``.
        reduction: ``"mean"`` or ``"sum"``; any other value returns the
            unreduced per-element loss.
    """

    def __init__(self, alpha=0.25, gamma=2.0, reduction="mean"):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        """Return the focal loss between ``logits`` and ``targets``."""
        elementwise_bce = F.binary_cross_entropy_with_logits(
            logits, targets, reduction="none"
        )
        prob_positive = torch.sigmoid(logits)
        # p_t: probability the model assigns to the true class of each element.
        prob_true = targets * prob_positive + (1 - targets) * (1 - prob_positive)
        modulation = self.alpha * (1 - prob_true) ** self.gamma
        weighted = modulation * elementwise_bce

        if self.reduction == "mean":
            return weighted.mean()
        if self.reduction == "sum":
            return weighted.sum()
        return weighted

soft_dice_loss = SoftDiceLoss()
focal_loss = FocalLossBinary()


def improved_seg_loss(logits, gt):
    """Weighted segmentation loss: 0.6 * soft Dice + 0.2 * BCE + 0.2 * focal.

    Args:
        logits: Raw (pre-sigmoid) predictions.
        gt: Target mask with the same shape as ``logits``.

    Returns:
        Scalar tensor with the combined loss.
    """
    dice_term = soft_dice_loss(logits, gt)
    bce_term = F.binary_cross_entropy_with_logits(logits, gt)
    focal_term = focal_loss(logits, gt)
    return 0.6 * dice_term + 0.2 * bce_term + 0.2 * focal_term

In [None]:
def load_image_for_clip(path, processor):
    """Load an image, resize it to IMG_SIZE x IMG_SIZE and preprocess it.

    Args:
        path: Location of the image file.
        processor: Callable invoked as ``processor(images=..., return_tensors="pt")``
            whose result exposes a ``"pixel_values"`` tensor.

    Returns:
        Tuple ``(pixel_values, (height, width), pil_image)`` where the size is
        that of the *resized* image and ``pixel_values`` lives on ``DEVICE``.
    """
    resized = Image.open(path).convert("RGB").resize((IMG_SIZE, IMG_SIZE))
    width, height = resized.size
    encoded = processor(images=resized, return_tensors="pt")
    pixel_values = encoded["pixel_values"].to(DEVICE)
    return pixel_values, (height, width), resized

def load_mask_orig(path):
    """Load a mask file as a binary float tensor of shape (IMG_SIZE, IMG_SIZE).

    The mask is converted to 8-bit grayscale, resized with nearest-neighbour
    sampling, and thresholded at 127 (values above become 1.0, the rest 0.0).
    The result is placed on ``DEVICE``.
    """
    grayscale = Image.open(path).convert("L")
    grayscale = grayscale.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST)
    is_foreground = np.array(grayscale) > 127
    binary = is_foreground.astype(np.float32)
    return torch.from_numpy(binary).to(DEVICE)

def load_image_for_clip_from_pil(pil, processor):
    """Run ``processor`` on an already-loaded image and return its pixel values on DEVICE."""
    encoded = processor(images=pil, return_tensors="pt")
    pixel_values = encoded["pixel_values"]
    return pixel_values.to(DEVICE)

def pil_mask_to_tensor01(pil_mask):
    """Convert a PIL mask into a 0/1 float32 tensor on DEVICE (threshold: > 127)."""
    grayscale = np.array(pil_mask.convert("L"))
    binary = (grayscale > 127).astype(np.float32)
    tensor = torch.from_numpy(binary)
    return tensor.to(DEVICE)

In [None]:
class PairedImageDataset:
    """Index of (image_path, mask_path) pairs matched by identical file name.

    Only files whose suffix (case-insensitive) is ``.jpg`` or ``.png`` are
    considered, and an image is kept only when the mask directory contains a
    file with exactly the same name. Pairs are ordered by image path.

    Attributes are documented inline below.
    """

    # File suffixes (lower-case) that are treated as images/masks.
    _EXTENSIONS = (".jpg", ".png")

    def __init__(self, image_dir, mask_dir):
        """Scan both directories and build ``self.pairs``.

        Args:
            image_dir: Directory containing the input images.
            mask_dir: Directory containing the masks.
        """
        mask_by_name = {mask.name: mask for mask in self._list_files(mask_dir)}
        # List of (image_path, mask_path) tuples; every entry has a real mask,
        # so no separate "mask is None" filtering is needed.
        self.pairs = [
            (image, mask_by_name[image.name])
            for image in self._list_files(image_dir)
            if image.name in mask_by_name
        ]
        print("Total matched pairs:", len(self.pairs))

    @classmethod
    def _list_files(cls, directory):
        """Return the sorted supported files that sit directly inside ``directory``."""
        return sorted(
            path for path in Path(directory).glob("*")
            if path.suffix.lower() in cls._EXTENSIONS
        )

    def __len__(self):
        """Number of matched (image, mask) pairs."""
        return len(self.pairs)

    def sample_episode(self, k=3):
        """Draw one query pair and ``k`` distinct support pairs at random.

        Args:
            k: Number of support pairs (non-negative).

        Returns:
            ``(query_image_path, query_mask_path, supports)`` where ``supports``
            is a list of ``k`` ``(image_path, mask_path)`` tuples, none of
            which is the query.

        Raises:
            ValueError: If ``k`` is negative or fewer than ``k + 1`` pairs are
                available (for example because a directory path is wrong).
        """
        if k < 0:
            raise ValueError(f"k must be non-negative, got {k}")
        needed = k + 1
        if len(self.pairs) < needed:
            raise ValueError(
                f"sample_episode(k={k}) needs at least {needed} matched image/mask "
                f"pairs, but only {len(self.pairs)} were found"
            )
        idxs = random.sample(range(len(self.pairs)), needed)
        q_idx, sup_idxs = idxs[0], idxs[1:]
        q_img_path, q_mask_path = self.pairs[q_idx]
        supports = [self.pairs[i] for i in sup_idxs]
        return q_img_path, q_mask_path, supports

In [None]:
class ConvBlock(nn.Module):
    """Two (3x3 conv -> batch norm -> ReLU) stages followed by 2-D dropout.

    Spatial size is preserved (padding=1); channels go ``in_ch -> out_ch -> out_ch``.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        p_drop: Channel-dropout probability applied after the second stage.
    """

    def __init__(self, in_ch, out_ch, p_drop=0.1):
        super().__init__()
        layers = []
        # First stage maps in_ch -> out_ch, second stage keeps out_ch.
        for stage_in in (in_ch, out_ch):
            layers.append(nn.Conv2d(stage_in, out_ch, 3, padding=1))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Dropout2d(p_drop))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block to ``x``."""
        return self.block(x)

class UNetProto(nn.Module):
    """Small convolutional encoder/decoder exposed as two separate passes.

    The encoder has three ConvBlock levels separated by 2x max-pooling plus a
    bottleneck block; the decoder upsamples a bottleneck-shaped tensor three
    times by 2x and ends with a 1x1 convolution to a single channel. The
    decoder takes only the tensor it is given: encoder features ``e1``-``e3``
    are returned by ``forward_encoder`` but not consumed by ``forward_decoder``.

    Args:
        in_ch: Channels of the input image.
        base_ch: Channels of the first encoder level; deeper levels use
            2x, 4x and 8x this value.
    """

    def __init__(self, in_ch=3, base_ch=32):
        super().__init__()
        ch1 = base_ch
        ch2 = base_ch * 2
        ch4 = base_ch * 4
        ch8 = base_ch * 8

        # Encoder path.
        self.enc1 = ConvBlock(in_ch, ch1)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = ConvBlock(ch1, ch2)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = ConvBlock(ch2, ch4)
        self.pool3 = nn.MaxPool2d(2)
        self.bottleneck = ConvBlock(ch4, ch8)

        # Decoder path (transposed-conv upsampling followed by a ConvBlock).
        self.up2 = nn.ConvTranspose2d(ch8, ch4, 2, stride=2)
        self.dec2 = ConvBlock(ch4, ch4)
        self.up1 = nn.ConvTranspose2d(ch4, ch2, 2, stride=2)
        self.dec1 = ConvBlock(ch2, ch2)
        self.final_up = nn.ConvTranspose2d(ch2, ch1, 2, stride=2)
        self.final = nn.Conv2d(ch1, 1, 1)

    def forward_encoder(self, x):
        """Encode ``x`` and return a dict with keys 'bottleneck', 'e3', 'e2', 'e1'."""
        level1 = self.enc1(x)
        level2 = self.enc2(self.pool1(level1))
        level3 = self.enc3(self.pool2(level2))
        deepest = self.bottleneck(self.pool3(level3))
        return dict(bottleneck=deepest, e3=level3, e2=level2, e1=level1)

    def forward_decoder(self, feat):
        """Decode a bottleneck-shaped tensor into single-channel logits (8x upsampled)."""
        stages = (
            self.up2, self.dec2,
            self.up1, self.dec1,
            self.final_up, self.final,
        )
        out = feat
        for stage in stages:
            out = stage(out)
        return out

In [None]:
class PromptLearner(nn.Module):
    """Produces text tokens for each class name, prefixed by learnable context vectors.

    For every class a fixed prompt ``"a photo of a <class>"`` is encoded with
    the CLIP text model; ``ctx_len`` learnable vectors are prepended to the
    resulting token states and everything is projected to the vision hidden size.

    Args:
        clip_model: Model exposing ``text_model`` and ``vision_model`` with a
            ``config.hidden_size`` each.
        classnames: Iterable of class-name strings.
        ctx_len: Number of learnable context vectors.
    """

    def __init__(self, clip_model, classnames, ctx_len=8):
        super().__init__()
        self.clip = clip_model
        self.classnames = classnames
        self.ctx_len = ctx_len

        text_dim = clip_model.text_model.config.hidden_size
        vision_dim = clip_model.vision_model.config.hidden_size

        # Learnable context, shared by all classes: (ctx_len, text_dim).
        self.ctx = nn.Parameter(torch.randn(ctx_len, text_dim))
        # Maps text-width tokens to the vision hidden size.
        self.text_proj = nn.Linear(text_dim, vision_dim)

    def forward(self, processor):
        """Return projected tokens of shape (num_classes, ctx_len + T, vision_dim).

        ``T`` is the padded token length produced by ``processor`` for the prompts.
        """
        prompts = [f"a photo of a {name}" for name in self.classnames]
        encoded = processor(text=prompts, padding=True, return_tensors="pt").to(DEVICE)

        text_out = self.clip.text_model(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            output_hidden_states=True,
        )
        token_states = text_out.hidden_states[-1]

        num_prompts = token_states.size(0)
        learned_ctx = self.ctx.unsqueeze(0).expand(num_prompts, -1, -1)
        with_context = torch.cat([learned_ctx, token_states], dim=1)
        return self.text_proj(with_context)

class PrototypeRefiner(nn.Module):
    """Builds a foreground prototype from patch tokens and refines it by attention.

    The initial prototype is the mask-weighted mean of the patch tokens. It is
    then used (after a linear map) as a query that attends over the
    *foreground* tokens only; the attended vector is added back and the sum is
    layer-normalised.

    Args:
        dim: Width of the patch tokens.
    """

    def __init__(self, dim):
        super().__init__()
        self.lin = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)

    def forward(self, patch_tokens, mask_small):
        """Return the refined prototype of shape ``(dim,)``.

        Args:
            patch_tokens: Tensor of shape ``(N, dim)``.
            mask_small: Tensor of shape ``(N,)``; entries > 0 mark foreground
                patches and act as weights for the initial mean.
        """
        msum = mask_small.sum().clamp(min=1e-6)
        proto = (patch_tokens * mask_small.unsqueeze(-1)).sum(dim=0) / msum

        query = self.lin(proto)              # (dim,)
        logits = patch_tokens @ query        # (N,)

        foreground = mask_small > 0
        if bool(foreground.any()):
            # Exclude background patches from the softmax. Multiplying the
            # logits by the mask (the previous behaviour) only set them to 0,
            # which still gives every background patch a weight of exp(0).
            logits = logits.masked_fill(~foreground, float("-inf"))
        else:
            # Empty mask: keep the uniform attention the multiplication used
            # to produce, instead of a softmax over all -inf (NaN).
            logits = torch.zeros_like(logits)

        weights = torch.softmax(logits, dim=0)
        refined = (weights.unsqueeze(-1) * patch_tokens).sum(dim=0)
        return self.norm(proto + refined)

class CrossAttentionFusion(nn.Module):
    """One post-norm cross-attention block: query tokens attend to prototype tokens.

    Args:
        dim: Token width (must be divisible by ``n_heads``).
        n_heads: Number of attention heads.
    """

    def __init__(self, dim, n_heads=8):
        super().__init__()
        self.mha = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, query_tokens, support_protos):
        """Fuse prototypes into the query tokens.

        Args:
            query_tokens: Tensor of shape ``(B, N, dim)``.
            support_protos: Either ``(M, dim)`` (shared by the whole batch) or
                an already batched ``(B, M, dim)`` tensor.

        Returns:
            Tensor of shape ``(B, N, dim)``.
        """
        batch_size, _, _ = query_tokens.shape

        memory = support_protos
        if memory.dim() == 2:
            # Same prototypes for every item in the batch.
            memory = memory.unsqueeze(0).repeat(batch_size, 1, 1)

        attended, _ = self.mha(query_tokens, memory, memory)
        hidden = self.norm1(query_tokens + attended)
        hidden = self.norm2(hidden + self.ff(hidden))
        return hidden

class UNetDecoderFromPatches(nn.Module):
    """Decodes three sets of patch tokens into single-channel mask logits.

    Each token set is projected to ``mid_ch`` channels and reshaped into a
    square feature map. The first set is upsampled 4x, the second 2x, and the
    third is kept at patch resolution; they are then merged coarse-to-fine
    (third -> second -> first) with concatenation and 3x3 convolutions.

    Args:
        in_dim: Width of the incoming tokens.
        mid_ch: Channel count used throughout the decoder.
    """

    def __init__(self, in_dim, mid_ch=256):
        super().__init__()
        self.proj1 = nn.Linear(in_dim, mid_ch)
        self.proj2 = nn.Linear(in_dim, mid_ch)
        self.proj3 = nn.Linear(in_dim, mid_ch)

        self.conv1 = nn.Sequential(
            nn.Conv2d(mid_ch, mid_ch, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(mid_ch, mid_ch, 3, padding=1),
            nn.ReLU()
        )
        # conv2/conv3 take the upsampled coarser map concatenated with a finer one.
        self.conv2 = nn.Sequential(
            nn.Conv2d(mid_ch*2, mid_ch, 3, padding=1),
            nn.ReLU()
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(mid_ch*2, mid_ch, 3, padding=1),
            nn.ReLU()
        )

        self.final = nn.Conv2d(mid_ch, 1, 1)

    def forward(self, feats, out_size):
        """Return logits of shape ``(B, 1, *out_size)``.

        Args:
            feats: Sequence of three tensors, each ``(B, N, in_dim)`` with the
                same ``N``; ``N`` must be a perfect square.
            out_size: Target spatial size ``(H, W)`` of the output.

        Raises:
            ValueError: If the number of tokens is not a perfect square.
        """
        num_tokens = feats[0].shape[1]
        # Exact integer square root; int(N ** 0.5) would silently floor a
        # non-square count and fail later inside the reshape.
        sz = math.isqrt(num_tokens)
        if sz * sz != num_tokens:
            raise ValueError(
                f"Expected a square grid of patch tokens, got {num_tokens} tokens per image"
            )

        f1 = self.proj1(feats[0])
        f1 = rearrange(f1, 'b (h w) c -> b c h w', h=sz)
        f1 = F.interpolate(f1, scale_factor=4, mode='bilinear', align_corners=False)

        f2 = self.proj2(feats[1])
        f2 = rearrange(f2, 'b (h w) c -> b c h w', h=sz)
        f2 = F.interpolate(f2, scale_factor=2, mode='bilinear', align_corners=False)

        f3 = self.proj3(feats[2])
        f3 = rearrange(f3, 'b (h w) c -> b c h w', h=sz)

        # Coarse-to-fine merge: f3 (1x) -> f2 (2x) -> f1 (4x).
        p1 = self.conv1(f3)
        p1_up = F.interpolate(p1, size=f2.shape[2:], mode='bilinear', align_corners=False)
        p2 = self.conv2(torch.cat([p1_up, f2], dim=1))

        p2_up = F.interpolate(p2, size=f1.shape[2:], mode='bilinear', align_corners=False)
        p3 = self.conv3(torch.cat([p2_up, f1], dim=1))

        out = self.final(p3)
        out = F.interpolate(out, size=out_size, mode='bilinear', align_corners=False)
        return out

In [None]:
class IntegratedFewShotModel(nn.Module):
    """Few-shot segmenter that averages a CLIP-token branch with a small conv branch.

    For one episode (a query image plus support image/mask pairs):

    * CLIP branch: each support yields a masked prototype from its CLIP patch
      tokens; the mean prototype and the prompt-learner text tokens are fused
      into the query patch tokens by cross-attention, and the result is
      decoded to mask logits.
    * Conv branch: the query is encoded by ``UNetProto``; its bottleneck is
      scaled by the cosine similarity between the pooled bottleneck (mapped to
      CLIP width) and the mean fused token, then decoded to mask logits.

    The two logit maps are averaged. All CLIP weights are frozen.

    Args:
        clip_model: CLIP model exposing ``vision_model`` and ``text_model``.
        processor: Matching CLIP processor used for images and text.
        classnames: Class names handed to the prompt learner.
        base_ch: Base channel count of the conv branch.
    """

    # The default list is never mutated here; PromptLearner only stores and iterates it.
    def __init__(self, clip_model, processor, classnames=["polyp"], base_ch=32):
        super().__init__()
        self.unet = UNetProto(in_ch=3, base_ch=base_ch).to(DEVICE)

        self.clip = clip_model.to(DEVICE)
        self.processor = processor

        self.prompt_learner = PromptLearner(self.clip, classnames).to(DEVICE)

        vision_dim = self.clip.vision_model.config.hidden_size
        text_dim = self.clip.text_model.config.hidden_size

        self.refiner = PrototypeRefiner(dim=vision_dim).to(DEVICE)
        self.fusion = CrossAttentionFusion(dim=vision_dim, n_heads=8).to(DEVICE)
        self.clip_decoder = UNetDecoderFromPatches(in_dim=vision_dim, mid_ch=256).to(DEVICE)

        # NOTE(review): not used by any method of this class; kept so existing
        # checkpoints keep loading.
        self.text_proj = nn.Linear(text_dim, vision_dim).to(DEVICE)

        # Maps the pooled conv bottleneck (8 * base_ch channels) to CLIP width.
        # Placed on DEVICE like every other submodule.
        self.unet_to_clip = nn.Linear(8*base_ch, vision_dim).to(DEVICE)

        # CLIP is used as a frozen feature extractor.
        for p in self.clip.parameters():
            p.requires_grad = False

    def compute_clip_patch_tokens(self, pixel_values):
        """Return patch tokens from three depths of the CLIP vision model.

        Hidden states are taken at indices ``len//3``, ``2*len//3`` and
        ``len-1``; the first token of each (presumably the class token) is
        dropped.

        Returns:
            List of three tensors of shape ``(B, N, vision_dim)``.
        """
        outputs = self.clip.vision_model(pixel_values=pixel_values, output_hidden_states=True)
        hidden = outputs.hidden_states
        depth = len(hidden)
        return [
            hidden[i][:, 1:, :].contiguous()
            for i in (depth//3, 2*depth//3, depth-1)
        ]

    # Original (misspelled) name kept as an alias for backward compatibility.
    compute_clilp_patch_tokens = compute_clip_patch_tokens

    def compute_support_proto(self, sup_pixel_values, sup_mask_tensor):
        """Compute the refined foreground prototype of one support image.

        Args:
            sup_pixel_values: Preprocessed support image, batch size 1.
            sup_mask_tensor: 2-D binary mask of the support image.

        Returns:
            ``(prototype, feats)`` where ``prototype`` has shape
            ``(vision_dim,)`` and ``feats`` is the list of patch-token tensors.
        """
        feats = self.compute_clip_patch_tokens(sup_pixel_values)
        finest = feats[-1].squeeze(0)
        N = finest.shape[0]
        sz = int(N ** 0.5)

        # Downsample the mask to the patch grid so it lines up with the tokens.
        mask_small = F.interpolate(sup_mask_tensor.unsqueeze(0).unsqueeze(0),
                                   size=(sz, sz), mode='nearest').view(-1)
        mask_small = mask_small.to(finest.dtype)

        proto = self.refiner(finest, mask_small)
        return proto, feats

    def forward_episode(self, q_img_path, q_mask_path, supports):
        """Run one episode and return the averaged mask logits.

        Args:
            q_img_path: Path of the query image.
            q_mask_path: Path of the query ground-truth mask.
            supports: Iterable of ``(image_path, mask_path)`` support pairs.

        Returns:
            ``(mask_logits, q_mask, q_pil)``: logits of shape ``(1, 1, H, W)``,
            the query mask of shape ``(H, W)`` on DEVICE, and the resized
            query image as a PIL image.
        """
        q_pix, _, q_pil = load_image_for_clip(q_img_path, self.processor)
        q_mask_orig = load_mask_orig(q_mask_path)
        q_feats = self.compute_clip_patch_tokens(q_pix)
        q_patch_tokens = q_feats[-1]

        # --- Support prototypes (CLIP space) ---
        protos = []
        for s_img_path, s_mask_path in supports:
            s_pix, _, _ = load_image_for_clip(s_img_path, self.processor)
            s_mask_orig = load_mask_orig(s_mask_path)
            proto, _ = self.compute_support_proto(s_pix, s_mask_orig)
            protos.append(proto)
        protos = torch.stack(protos, dim=0)
        proto_mean = protos.mean(dim=0)
        proto_tokens = proto_mean.unsqueeze(0).unsqueeze(1)   # (1, 1, D)

        # Text tokens averaged over classes: (1, ctx_len + T, D).
        text_tokens = self.prompt_learner(self.processor)
        text_tokens = text_tokens.mean(dim=0, keepdim=True)

        # --- CLIP branch ---
        all_proto = torch.cat([proto_tokens, text_tokens], dim=1)
        fused = self.fusion(q_patch_tokens, all_proto)

        decoder_feats = [q_feats[0], q_feats[1], fused]
        mask_logits_clip = self.clip_decoder(decoder_feats, out_size=q_mask_orig.shape)

        # --- Conv branch ---
        # q_pil is already the RGB query resized to IMG_SIZE x IMG_SIZE, so it
        # is reused instead of decoding the file from disk a second time.
        q_tensor_unet = torch.from_numpy(np.array(q_pil).transpose(2,0,1)).float().unsqueeze(0) / 255.0
        q_tensor_unet = q_tensor_unet.to(DEVICE)

        unet_feats = self.unet.forward_encoder(q_tensor_unet)
        unet_bottleneck = unet_feats['bottleneck']

        B, C, Hb, Wb = unet_bottleneck.shape
        unet_vec = unet_bottleneck.view(B, C, -1).mean(dim=-1)   # global average pool
        clip_aligned = self.unet_to_clip(unet_vec)

        fused_proto = fused.mean(dim=1)
        proto = fused_proto.view(1, -1)

        # One cosine similarity per batch item, broadcast over the bottleneck map.
        sim = F.cosine_similarity(clip_aligned.unsqueeze(-1), proto.unsqueeze(-1), dim=1)
        sim_map = sim.view(B,1,1,1).expand(B,1,Hb,Wb)

        fused_bottleneck = unet_bottleneck * sim_map
        mask_logits_unet = self.unet.forward_decoder(fused_bottleneck)
        # Both branches are resized to the query-mask size, so their shapes agree.
        mask_logits_unet = F.interpolate(mask_logits_unet, size=q_mask_orig.shape, mode='bilinear', align_corners=False)

        mask_logits = 0.5 * mask_logits_clip + 0.5 * mask_logits_unet

        return mask_logits, q_mask_orig, q_pil

In [None]:
def combined_bce_dice_loss(logits, gt):
    """Return BCE-with-logits plus a Dice loss (smoothing constant 1).

    The Dice term is computed over all elements of the tensors at once, not
    per sample.
    """
    probabilities = torch.sigmoid(logits)
    overlap = torch.sum(probabilities * gt)
    dice_coefficient = (2 * overlap + 1) / (probabilities.sum() + gt.sum() + 1)
    bce_term = F.binary_cross_entropy_with_logits(logits, gt)
    dice_term = 1 - dice_coefficient
    return bce_term + dice_term

def build_and_train(num_epochs=10, episodes_per_epoch=30, lr=3e-4, wd=1e-4):
    """Build the few-shot model on top of frozen CLIP and train it episodically.

    Each step samples one episode (a query plus SUPPORT_SHOTS supports),
    computes the segmentation loss on the query, and updates the non-CLIP
    parameters with AdamW under mixed precision (on CUDA only).

    Args:
        num_epochs: Number of epochs.
        episodes_per_epoch: Optimisation steps per epoch.
        lr: AdamW learning rate.
        wd: AdamW weight decay.

    Returns:
        ``(model, dataset)``: the trained IntegratedFewShotModel and the
        PairedImageDataset it was trained on.
    """
    clip_name = "openai/clip-vit-base-patch32"
    clip_model = CLIPModel.from_pretrained(clip_name).to(DEVICE)
    processor = CLIPProcessor.from_pretrained(clip_name)
    print('Loaded CLIP', clip_name)

    integrated = IntegratedFewShotModel(clip_model, processor).to(DEVICE)

    trainable_modules = (
        integrated.prompt_learner,
        integrated.refiner,
        integrated.fusion,
        integrated.clip_decoder,
        integrated.unet_to_clip,
        integrated.unet,
    )
    # prompt_learner holds a reference to the frozen CLIP model, so filter on
    # requires_grad to keep frozen weights out of the optimizer and the clipping.
    trainable = [
        p
        for module in trainable_modules
        for p in module.parameters()
        if p.requires_grad
    ]

    use_amp = (DEVICE == 'cuda')
    optimizer = optim.AdamW(trainable, lr=lr, weight_decay=wd)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    dataset = PairedImageDataset('polyp_data/images', 'polyp_data/masks')

    for epoch in trange(num_epochs, desc="Training Epochs"):
        integrated.train()
        total_loss = 0.0
        for _ in trange(episodes_per_epoch, leave=False, desc=f"Epoch {epoch+1}"):
            q_img_path, q_mask_path, supports = dataset.sample_episode(k=SUPPORT_SHOTS)

            optimizer.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=use_amp):
                mask_logits, gt_mask, _ = integrated.forward_episode(
                    q_img_path, q_mask_path, supports
                )
                gt = gt_mask.unsqueeze(0).unsqueeze(0).to(DEVICE)
                loss = improved_seg_loss(mask_logits, gt)

            scaler.scale(loss).backward()
            # Gradients are still multiplied by the loss scale here; unscale
            # them first so max_norm applies to the true gradient norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(trainable, max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            total_loss += loss.item()

        # Advance the schedule by one epoch. Passing the finished epoch's index
        # (step(0) after the first epoch) would keep the LR one epoch behind.
        scheduler.step()
        avg_loss = total_loss / max(episodes_per_epoch, 1)
        print(f"Epoch {epoch+1}/{num_epochs} | Avg Loss: {avg_loss:.4f}")

    return integrated, dataset

In [None]:
def dice_score(pred, target, eps=1e-6):
    """Dice coefficient between two masks, computed over all elements.

    ``eps`` is added to numerator and denominator, so two empty masks score 1.
    """
    pred_f, target_f = pred.float(), target.float()
    overlap = torch.sum(pred_f * target_f)
    numerator = 2 * overlap + eps
    denominator = pred_f.sum() + target_f.sum() + eps
    return numerator / denominator

def iou_score(pred, target, eps=1e-6):
    """Intersection-over-union between two masks, computed over all elements.

    ``eps`` is added to numerator and denominator, so two empty masks score 1.
    """
    pred_f, target_f = pred.float(), target.float()
    overlap = torch.sum(pred_f * target_f)
    combined = pred_f.sum() + target_f.sum() - overlap
    return (overlap + eps) / (combined + eps)

def visualize_inference(model_inst, dataset_inst, n=3, thr=0.5):
    """Plot query / ground truth / prediction for ``n`` random episodes.

    Args:
        model_inst: Model exposing ``forward_episode``.
        dataset_inst: Dataset exposing ``sample_episode``.
        n: Number of episodes (one figure row each).
        thr: Probability threshold passed to the post-processing step.

    Returns:
        ``(dice_scores, iou_scores)``: two lists of floats, one entry per episode.
    """
    model_inst.eval()
    use_amp = (DEVICE=='cuda')
    dice_scores = []
    iou_scores = []

    plt.figure(figsize=(12, 4 * n))
    for row in range(n):
        q_img_path, q_mask_path, supports = dataset_inst.sample_episode(k=SUPPORT_SHOTS)

        with torch.no_grad(), torch.cuda.amp.autocast(enabled=use_amp):
            logits, gt_mask, q_pil = model_inst.forward_episode(
                q_img_path, q_mask_path, supports
            )
            probs = torch.sigmoid(logits).squeeze().float().cpu()
            pred = postprocess_prob(probs, thr=thr, min_size=64).float()

        gt = gt_mask.unsqueeze(0).unsqueeze(0).cpu()
        dice = dice_score(pred, gt)
        iou = iou_score(pred, gt)
        dice_scores.append(dice.item())
        iou_scores.append(iou.item())

        # (image, colormap, title) for the three columns of this row.
        panels = (
            (q_pil, None, "Query"),
            (gt.squeeze(), "gray", "GT"),
            (pred.squeeze(), "gray", f"Pred\nDice:{dice:.3f}, IoU:{iou:.3f}"),
        )
        for col, (image, cmap, title) in enumerate(panels, start=1):
            plt.subplot(n, 3, row * 3 + col)
            plt.imshow(image, cmap=cmap)
            plt.title(title)
            plt.axis("off")

    plt.tight_layout()
    plt.show()
    mean_dice = np.mean(dice_scores)
    mean_iou = np.mean(iou_scores)
    print(f"Mean Dice: {mean_dice:.4f} | Mean IoU: {mean_iou:.4f}")
    return dice_scores, iou_scores

def evaluate_model(model_inst, dataset_inst, num_samples=50, thr=0.5):
    """Score the model on ``num_samples`` randomly sampled episodes.

    Predictions are thresholded and cleaned by ``postprocess_prob`` before
    Dice and IoU are computed against the query mask.

    Returns:
        ``(dice_scores, iou_scores)``: two lists of floats, one entry per episode.
    """
    model_inst.eval()
    use_amp = (DEVICE=='cuda')
    dice_scores = []
    iou_scores = []

    with torch.no_grad(), torch.cuda.amp.autocast(enabled=use_amp):
        for _ in range(num_samples):
            q_img_path, q_mask_path, supports = dataset_inst.sample_episode(k=SUPPORT_SHOTS)
            logits, gt_mask, _ = model_inst.forward_episode(
                q_img_path, q_mask_path, supports
            )
            probs = torch.sigmoid(logits).squeeze().float().cpu()
            pred = postprocess_prob(probs, thr=thr, min_size=64).float()
            gt = gt_mask.unsqueeze(0).unsqueeze(0).cpu()

            dice = dice_score(pred, gt)
            iou = iou_score(pred, gt)
            dice_scores.append(dice.item())
            iou_scores.append(iou.item())

    dice_mean, dice_std = np.mean(dice_scores), np.std(dice_scores)
    iou_mean, iou_std = np.mean(iou_scores), np.std(iou_scores)
    print(f"Evaluation on {num_samples} episodes:")
    print(f"Mean Dice: {dice_mean:.4f} ± {dice_std:.4f}")
    print(f"Mean IoU: {iou_mean:.4f} ± {iou_std:.4f}")
    return dice_scores, iou_scores

In [None]:
def find_global_threshold(model_inst, dataset_inst, episodes=30, thr_grid=None):
    """Pick the threshold from ``thr_grid`` with the highest mean Dice.

    Every candidate is scored by a separate ``evaluate_model`` run of
    ``episodes`` episodes. Ties keep the earliest candidate.

    Args:
        model_inst: Model passed through to ``evaluate_model``.
        dataset_inst: Dataset passed through to ``evaluate_model``.
        episodes: Episodes evaluated per candidate threshold.
        thr_grid: Candidate thresholds; defaults to 7 evenly spaced values
            from 0.35 to 0.65.

    Returns:
        The best threshold, or 0.5 if no candidate scores above -1.
    """
    candidates = np.linspace(0.35, 0.65, 7) if thr_grid is None else thr_grid

    best_thr = 0.5
    best_dice = -1
    for thr in candidates:
        dice_values, _ = evaluate_model(model_inst, dataset_inst, num_samples=episodes, thr=thr)
        mean_dice = np.mean(dice_values)
        print(f"thr={thr:.2f} -> mean Dice {mean_dice:.4f}")
        if mean_dice > best_dice:
            best_thr = thr
            best_dice = mean_dice

    print(f"[Threshold search] Best thr={best_thr:.2f} with mean Dice={best_dice:.4f}")
    return best_thr

In [None]:
import numpy as np
from skimage.morphology import remove_small_holes, remove_small_objects, closing, disk
from skimage.measure import label

def postprocess_prob(prob_2d, thr=0.5, min_size=64, pad=False):
    """Threshold a probability map and clean the resulting binary mask.

    Steps: threshold at ``thr`` (inclusive), drop connected components smaller
    than ``min_size``, fill holes up to ``min_size``, then apply a
    morphological closing with a disk of radius 3.

    Args:
        prob_2d: Probability map as a torch tensor or NumPy array.
        thr: Threshold; values ``>= thr`` are foreground.
        min_size: Size limit used for both small-object and small-hole removal.
        pad: Accepted but not used by this function.

    Returns:
        Float32 CPU tensor with two leading singleton dimensions added.
    """
    prob = prob_2d.detach().cpu().numpy() if torch.is_tensor(prob_2d) else prob_2d

    binary = (prob >= thr).astype(np.uint8)
    foreground = label(binary) > 0
    cleaned = remove_small_objects(foreground, min_size=min_size)
    cleaned = remove_small_holes(cleaned, area_threshold=min_size)
    cleaned = closing(cleaned, disk(3))

    as_tensor = torch.from_numpy(cleaned.astype(np.float32))
    return as_tensor.unsqueeze(0).unsqueeze(0)

In [None]:
# Driver: train the model, pick an inference threshold, then visualise and
# evaluate. Every stage samples episodes from the same `dataset` object that
# build_and_train returns.
if __name__ == '__main__':
    # 10 epochs of 60 episodes each.
    trained_model, dataset = build_and_train(num_epochs=10, episodes_per_epoch=60)

    print("\nSearching best inference threshold (optional):")
    # Each candidate threshold is scored on 15 episodes.
    best_thr = find_global_threshold(trained_model, dataset, episodes=15)

    print("\nVisualizing some inference results:")
    visualize_inference(trained_model, dataset, n=3, thr=best_thr)

    print("\nEvaluating model performance:")
    evaluate_model(trained_model, dataset, num_samples=20, thr=best_thr)