In [None]:
from google.colab import files
# Open the Colab upload widget. The shell steps below expect the uploaded file
# to be named kaggle.json and to sit in the current working directory.
# NOTE(review): kaggle.json presumably holds the Kaggle API token — keep it out
# of saved outputs and version control.
files.upload()

# Copy the uploaded file into ~/.kaggle/ (presumably where the kaggle CLI used
# in a later cell reads its credentials — confirm) and restrict it to
# owner-only read/write.
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

In [None]:
# Quietly install gdown into the runtime.
# NOTE(review): none of the visible cells import or call gdown (the dataset is
# fetched with the kaggle CLI) — confirm whether this install is still needed.
!pip install -q gdown

In [None]:
# Download the debeshjha1/kvasirseg dataset with the kaggle CLI. The unzip cell
# that follows expects the archive to land as kvasirseg.zip in the working
# directory.
!kaggle datasets download -d debeshjha1/kvasirseg

In [None]:
# Quietly extract the downloaded archive into kvasir_raw/.
!unzip -q kvasirseg.zip -d kvasir_raw

In [None]:
# Flatten the extracted archive into polyp_data/images and polyp_data/masks,
# the two directories the dataset class is pointed at later.
# The doubled slash in the source paths is harmless: POSIX path resolution
# treats "//" inside a path like a single "/".
!mkdir -p polyp_data/images polyp_data/masks
!cp -r kvasir_raw/Kvasir-SEG//Kvasir-SEG/images/* polyp_data/images/
!cp -r kvasir_raw/Kvasir-SEG//Kvasir-SEG/masks/* polyp_data/masks/

In [None]:
import os

# Sanity check: report how many files were staged in each directory.
for label, folder in (("Images:", "polyp_data/images"), ("Masks:", "polyp_data/masks")):
    print(label, len(os.listdir(folder)))

In [None]:
import torch
from torch.utils.data import Dataset
import os
from PIL import Image
import numpy as np
import random
from torchvision import transforms

class FewShotPolypDataset(Dataset):
    """Few-shot segmentation dataset with one fixed support set shared by all queries.

    ``support_shots`` image/mask pairs are drawn at random when the dataset is
    constructed; every remaining pair is served as a query. Each item bundles
    one query pair with the complete support set.

    Images and masks are paired by position in the sorted directory listings,
    so both directories are expected to list their files in matching order.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the segmentation masks.
        support_shots: Number of pairs reserved as the support set.
        image_size: Side length (in pixels) images and masks are resized to.
        seed: Optional seed for the support-set draw. ``None`` (the default)
            draws from the global ``random`` state; an integer makes the
            support/query split reproducible across instantiations.

    Raises:
        AssertionError: If the two directories hold a different number of files.
        ValueError: If ``support_shots`` is negative or larger than the number
            of available pairs.
    """

    def __init__(self, image_dir, mask_dir, support_shots=3, image_size=256, seed=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.support_shots = support_shots
        self.image_size = image_size

        # Sorting both listings is what pairs image i with mask i.
        self.image_list = sorted(os.listdir(image_dir))
        self.mask_list = sorted(os.listdir(mask_dir))

        assert len(self.image_list) == len(self.mask_list), "Mismatch in image and mask count"

        self.transform_image = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ])

        self.transform_mask = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),  # returns (1, H, W)
        ])

        # Random K-shot support set; everything else becomes a query.
        self.support_indices, self.query_indices = self._split_indices(
            len(self.image_list), support_shots, seed
        )

        # Filled lazily on first access: (support index tuple, images, masks).
        self._support_cache = None

    @staticmethod
    def _split_indices(num_samples, support_shots, seed=None):
        """Split ``range(num_samples)`` into random support indices and the remaining query indices."""
        if not 0 <= support_shots <= num_samples:
            raise ValueError(
                f"support_shots must be between 0 and {num_samples}, got {support_shots}"
            )
        # A private generator keeps a seeded draw from disturbing global RNG state.
        rng = random if seed is None else random.Random(seed)
        support = rng.sample(range(num_samples), support_shots)
        reserved = set(support)
        query = [i for i in range(num_samples) if i not in reserved]
        return support, query

    def _load_pair(self, index):
        """Load the image/mask pair at ``index`` as tensors; the mask is binarised at 0.5."""
        image_path = os.path.join(self.image_dir, self.image_list[index])
        mask_path = os.path.join(self.mask_dir, self.mask_list[index])

        with Image.open(image_path) as raw_image:
            image = self.transform_image(raw_image.convert('RGB'))
        with Image.open(mask_path) as raw_mask:
            mask = self.transform_mask(raw_mask.convert('L'))

        return image, (mask > 0.5).float()

    def _load_support_set(self):
        """Return stacked support images and masks, reading them from disk only once."""
        key = tuple(self.support_indices)
        cache = getattr(self, '_support_cache', None)
        # Rebuild if nothing is cached yet or support_indices was changed afterwards.
        if cache is None or cache[0] != key:
            pairs = [self._load_pair(i) for i in self.support_indices]
            images = torch.stack([image for image, _ in pairs])  # (K, 3, H, W)
            masks = torch.stack([mask for _, mask in pairs])     # (K, 1, H, W)
            cache = (key, images, masks)
            self._support_cache = cache
        return cache[1], cache[2]

    def __len__(self):
        """Number of query samples (support samples are excluded)."""
        return len(self.query_indices)

    def __getitem__(self, idx):
        """Return one query pair together with the shared support set.

        Returns:
            dict with ``query_image`` (3, H, W), ``query_mask`` (1, H, W),
            ``support_images`` (K, 3, H, W) and ``support_masks`` (K, 1, H, W).
        """
        query_image, query_mask = self._load_pair(self.query_indices[idx])
        support_images, support_masks = self._load_support_set()

        return {
            'query_image': query_image,
            'query_mask': query_mask,
            # Clones keep each item independent of the cached tensors.
            'support_images': support_images.clone(),
            'support_masks': support_masks.clone()
        }

In [None]:
# Smoke test: build the dataset and print the tensor shapes of the first item.
dataset = FewShotPolypDataset("polyp_data/images", "polyp_data/masks", support_shots=3)
sample = dataset[0]

for label, key in (
    ("Query image:", 'query_image'),
    ("Query mask :", 'query_mask'),
    ("Support images:", 'support_images'),
    ("Support masks :", 'support_masks'),
):
    print(label, sample[key].shape)

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock(nn.Module):
    """Two consecutive Conv3x3 -> BatchNorm -> ReLU stages.

    The first stage maps ``in_ch`` to ``out_ch`` channels, the second keeps
    ``out_ch``. With kernel size 3 and padding 1 the spatial size is unchanged.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # Stage 1 consumes in_ch channels, stage 2 consumes the out_ch it produced.
        for stage_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both stages."""
        return self.block(x)

class UNetProto(nn.Module):
    """U-Net-shaped encoder/decoder without skip connections.

    The encoder and decoder are exposed separately (``forward_encoder`` /
    ``forward_decoder``) so a caller can modify the bottleneck features before
    decoding. ``forward`` simply chains the two.

    Args:
        in_ch: Number of channels of the input image.
        base_ch: Channel width of the first encoder stage; deeper stages use
            multiples of it (x2, x4, x8 at the bottleneck).
    """

    def __init__(self, in_ch=3, base_ch=32):
        super().__init__()
        # Encoder
        self.enc1 = ConvBlock(in_ch, base_ch)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = ConvBlock(base_ch, base_ch * 2)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = ConvBlock(base_ch * 2, base_ch * 4)
        self.pool3 = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = ConvBlock(base_ch * 4, base_ch * 8)

        # Decoder (NO skip connections)
        self.up2 = nn.ConvTranspose2d(base_ch * 8, base_ch * 4, 2, stride=2)
        self.dec2 = ConvBlock(base_ch * 4, base_ch * 4)
        self.up1 = nn.ConvTranspose2d(base_ch * 4, base_ch * 2, 2, stride=2)
        self.dec1 = ConvBlock(base_ch * 2, base_ch * 2)
        self.final_up = nn.ConvTranspose2d(base_ch * 2, base_ch, 2, stride=2)
        self.final = nn.Conv2d(base_ch, 1, 1)

    def forward_encoder(self, x):
        """Encode ``x`` (N, in_ch, H, W) into bottleneck features.

        Three 2x max-pools are applied, so the result has ``base_ch * 8``
        channels at 1/8 of the input resolution.
        """
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        b = self.bottleneck(self.pool3(e3))
        return b

    def forward_decoder(self, x):
        """Decode bottleneck features into a single-channel map of raw logits.

        Three stride-2 transposed convolutions upsample by 8x in total; no
        activation is applied to the output.
        """
        x = self.up2(x)
        x = self.dec2(x)
        x = self.up1(x)
        x = self.dec1(x)
        x = self.final_up(x)
        out = self.final(x)
        return out

    def forward(self, x):
        """Encode then decode ``x``, returning raw logits.

        Lets the module be called directly; without it, ``nn.Module.__call__``
        raises ``NotImplementedError``.
        """
        return self.forward_decoder(self.forward_encoder(x))

In [None]:
def compute_prototype(support_feats, support_masks):
    """Masked average pooling of support features into one foreground prototype.

    Args:
        support_feats: Feature maps of shape (K, C, H, W).
        support_masks: Masks of shape (K, 1, H_m, W_m); resized to the feature
            resolution with nearest-neighbour interpolation.

    Returns:
        Tensor of shape (C,): the mask-weighted feature sum over all shots and
        positions divided by the mask sum (plus 1e-5 to avoid division by zero).
    """
    _, _, feat_h, feat_w = support_feats.shape
    mask_at_feat_res = F.interpolate(support_masks, size=(feat_h, feat_w), mode='nearest')

    # Reduce over shots and both spatial axes, keeping the channel axis.
    reduce_dims = (0, 2, 3)
    weighted_sum = (support_feats * mask_at_feat_res).sum(dim=reduce_dims)
    mask_total = mask_at_feat_res.sum(dim=reduce_dims) + 1e-5
    return weighted_sum / mask_total  # (C,)

In [None]:
class FewShotSegModel(nn.Module):
    """Prototype-guided few-shot segmentation model built on ``UNetProto``.

    A foreground prototype is pooled from the encoded support set, compared to
    every query feature location by cosine similarity, and the resulting
    similarity map re-weights the query features before they are decoded.
    """

    def __init__(self):
        super().__init__()
        self.unet = UNetProto()

    def forward(self, query_img, support_imgs, support_masks):
        """Predict a foreground probability map for the query.

        Args:
            query_img: Query image, either unbatched (3, H, W) or batched
                (B, 3, H, W). An unbatched image gets a batch axis added.
            support_imgs: Support images (K, 3, H, W).
            support_masks: Support masks (K, 1, H, W).

        Returns:
            Sigmoid probabilities of shape (B, 1, H', W'), with B == 1 for an
            unbatched query.
        """
        # Support encoding
        support_feats = self.unet.forward_encoder(support_imgs)  # (K, C, H, W)
        proto = compute_prototype(support_feats, support_masks)  # (C,)

        # Query encoding. Only add the batch axis when it is missing, so a
        # batched query sharing this support set works as well.
        if query_img.dim() == 3:
            query_img = query_img.unsqueeze(0)
        query_feat = self.unet.forward_encoder(query_img)  # (B, C, H, W)

        # Expand proto and fuse via cosine similarity
        proto = proto.view(1, -1, 1, 1)  # (1, C, 1, 1), broadcasts over batch and space
        sim_map = F.cosine_similarity(query_feat, proto, dim=1).unsqueeze(1)  # (B, 1, H, W)
        fused_feat = query_feat * sim_map  # element-wise weighting

        # Decode the fused feature map
        out = self.unet.forward_decoder(fused_feat)
        return torch.sigmoid(out)

In [None]:
import torch
# Prefer the GPU when the runtime exposes one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Using device:", device)

In [None]:
import torch.nn.functional as F

def dice_loss(pred, target, smooth=1.):
    """Soft Dice loss averaged over the batch and channel axes.

    Args:
        pred: Predicted probabilities, shape (N, C, H, W).
        target: Ground-truth mask of the same shape.
        smooth: Added to numerator and denominator so empty masks do not
            divide by zero.

    Returns:
        Scalar tensor ``1 - mean(dice)``, with Dice computed per (N, C) entry
        over the two spatial axes.
    """
    pred, target = pred.contiguous(), target.contiguous()

    spatial_axes = (2, 3)
    overlap = (pred * target).sum(dim=spatial_axes)
    denominator = pred.sum(dim=spatial_axes) + target.sum(dim=spatial_axes) + smooth
    dice_score = (2. * overlap + smooth) / denominator
    return 1 - dice_score.mean()

def combined_loss(pred, target):
    """Binary cross-entropy plus soft Dice loss.

    ``pred`` must already hold probabilities (not logits), because
    ``F.binary_cross_entropy`` is applied to it directly.
    """
    return F.binary_cross_entropy(pred, target) + dice_loss(pred, target)

In [None]:
import matplotlib.pyplot as plt

def visualize_sample(query_img, pred_mask, gt_mask):
    """Show the query image, its ground-truth mask and the thresholded prediction side by side.

    Args:
        query_img: Image tensor of shape (3, H, W); permuted to (H, W, 3) for display.
        pred_mask: Predicted probabilities; squeezed and binarised at 0.5.
        gt_mask: Ground-truth mask tensor; squeezed before display.
    """
    binary_pred = (pred_mask.squeeze().cpu().detach().numpy() > 0.5).astype(float)
    truth = gt_mask.squeeze().cpu().numpy()
    image = query_img.permute(1, 2, 0).cpu().numpy()

    # (title, data, colormap) for each panel, left to right.
    panels = (
        ("Query Image", image, None),
        ("Ground Truth", truth, 'gray'),
        ("Prediction", binary_pred, 'gray'),
    )

    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    for ax, (title, data, cmap) in zip(axs, panels):
        ax.imshow(data, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
from torch.utils.data import DataLoader
import torch.optim as optim

# Build a fresh 3-shot dataset. The constructor draws the support set at
# random, so it may differ from the one drawn by earlier cells.
dataset = FewShotPolypDataset("polyp_data/images", "polyp_data/masks", support_shots=3)
# batch_size=1: the model adds the query batch axis itself and takes the
# support set without a leading batch axis (see the squeeze(0) calls below).
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

# Model and optimizer
model = FewShotSegModel().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training: one optimizer step per query image.
for epoch in range(5):  # increase to 20+ for real training
    model.train()
    epoch_loss = 0.0

    for batch in dataloader:
        # Every entry arrives with a leading batch axis of size 1.
        query_img = batch['query_image'].to(device)
        query_mask = batch['query_mask'].to(device)
        support_imgs = batch['support_images'].to(device)
        support_masks = batch['support_masks'].to(device)

        optimizer.zero_grad()
        # Strip the loader's batch axis before calling the model.
        pred = model(
            query_img.squeeze(0),
            support_imgs.squeeze(0),
            support_masks.squeeze(0)
        )
        # query_mask keeps its batch axis here; the loss is computed between
        # pred and that 4-D mask.
        loss = combined_loss(pred, query_mask)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # Mean per-sample loss over the epoch.
    print(f"Epoch {epoch+1}, Loss: {epoch_loss/len(dataloader):.4f}")

In [None]:
# Qualitative check: predict the first query sample and plot it against its
# ground truth. Gradients are disabled because this is inference only.
model.eval()
with torch.no_grad():
    sample = dataset[0]
    q_img, q_mask, s_imgs, s_masks = (
        sample[key].to(device)
        for key in ('query_image', 'query_mask', 'support_images', 'support_masks')
    )

    pred_mask = model(q_img, s_imgs, s_masks)
    visualize_sample(q_img.cpu(), pred_mask.cpu(), q_mask.cpu())

In [None]:
def compute_all_metrics(pred_mask, true_mask, threshold=0.5):
    """Compute pixel-wise binary segmentation metrics.

    Args:
        pred_mask: Tensor of predicted probabilities.
        true_mask: Tensor holding the ground-truth mask; binarised at 0.5.
        threshold: Cut-off applied to ``pred_mask``.

    Returns:
        dict of Python floats with keys ``Dice``, ``IoU``, ``Precision``,
        ``Recall``, ``F1 Score`` and ``Accuracy``. For binary masks F1 and Dice
        are the same quantity up to the epsilon terms. A small epsilon in each
        denominator avoids division by zero, so metrics with a zero numerator
        and zero denominator evaluate to 0.
    """
    pred_bin = (pred_mask > threshold).float()
    true_bin = (true_mask > 0.5).float()

    # Masks that differ only in singleton axes, e.g. (B, 1, H, W) vs (B, H, W),
    # would otherwise broadcast to (B, B, H, W) and pair every prediction with
    # every ground truth. Flattening lines the elements up one-to-one.
    if pred_bin.shape != true_bin.shape and pred_bin.numel() == true_bin.numel():
        pred_bin = pred_bin.reshape(-1)
        true_bin = true_bin.reshape(-1)

    TP = (pred_bin * true_bin).sum()
    FP = (pred_bin * (1 - true_bin)).sum()
    FN = ((1 - pred_bin) * true_bin).sum()
    TN = ((1 - pred_bin) * (1 - true_bin)).sum()

    epsilon = 1e-8
    dice = (2 * TP) / (2 * TP + FP + FN + epsilon)
    iou = TP / (TP + FP + FN + epsilon)
    precision = TP / (TP + FP + epsilon)
    recall = TP / (TP + FN + epsilon)
    f1 = (2 * precision * recall) / (precision + recall + epsilon)
    accuracy = (TP + TN) / (TP + FP + FN + TN + epsilon)

    return {
        'Dice': dice.item(),
        'IoU': iou.item(),
        'Precision': precision.item(),
        'Recall': recall.item(),
        'F1 Score': f1.item(),
        'Accuracy': accuracy.item()
    }

In [None]:
# Score the prediction from the evaluation cell and print each metric to four decimals.
metrics = compute_all_metrics(pred_mask.cpu(), q_mask.cpu())
for metric_name, metric_value in metrics.items():
    print(f"{metric_name}: {metric_value:.4f}")