# Data

https://www.kaggle.com/datasets/longvil/cvc-colondb

# Import Libraries

In [None]:
import os
import gc
import time
import csv
import random
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split
from sklearn.metrics import jaccard_score, precision_score, recall_score, f1_score
from skimage.segmentation import find_boundaries
from scipy.ndimage import distance_transform_edt as edt
import scipy.ndimage as ndi

import albumentations as A
from albumentations.pytorch import ToTensorV2

In [None]:
def set_seed(seed=42):
    """Seed every random number generator used here and make PyTorch deterministic.

    Args:
        seed: Integer seed applied to ``random``, NumPy and PyTorch
            (CPU and all CUDA devices).
    """
    # With use_deterministic_algorithms(True), cuBLAS-backed CUDA ops raise a
    # RuntimeError unless this variable is set before the first cuBLAS call.
    # setdefault keeps a value the user exported deliberately.
    os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')

    # Python reads PYTHONHASHSEED at interpreter start-up, so this only
    # influences processes spawned after this point (e.g. worker processes).
    os.environ['PYTHONHASHSEED'] = str(seed)

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Make PyTorch deterministic (slower but more reproducible)
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(42)

# Building the Image/Mask File List

In [None]:
# Dataset location on Kaggle (adjust when running elsewhere)
root_path = "/kaggle/input/cvc-colondb/CVC-ColonDB"
image_dir = os.path.join(root_path, "images")
mask_dir = os.path.join(root_path, "masks")

# PNG file names of both folders, each in sorted order
image_files = sorted(name for name in os.listdir(image_dir) if name.endswith('.png'))
mask_files = sorted(name for name in os.listdir(mask_dir) if name.endswith('.png'))

# Every image needs a mask carrying exactly the same file name
assert len(image_files) == len(mask_files), "Mismatched number of images and masks!"
assert image_files == mask_files, "Filename mismatch!"

# One row per image/mask pair; paths are relative to root_path
data = {
    "id": list(range(len(image_files))),
    "image_path": [os.path.join("images", name) for name in image_files],
    "mask_path": [os.path.join("masks", name) for name in mask_files],
}
df = pd.DataFrame(data)

# Persist the table (relative path, i.e. the current working directory)
output_csv = "cvc_colondb_pairs.csv"
df.to_csv(output_csv, index=False)

In [None]:
# Reload the image/mask pair table from the Kaggle working directory.
# NOTE(review): the table was saved with a relative path, so this absolute
# path only matches when the notebook runs from /kaggle/working - confirm.
data_metafile = "/kaggle/working/cvc_colondb_pairs.csv"
df = pd.read_csv(filepath_or_buffer=data_metafile)

In [None]:
# Absolute paths of the sample stored under index label 1 of the table
image_path = "/kaggle/input/cvc-colondb/CVC-ColonDB/" + df["image_path"][1]
mask_path = "/kaggle/input/cvc-colondb/CVC-ColonDB/" + df["mask_path"][1]

# Load the image as RGB and the mask as single-channel grayscale
img = Image.open(image_path).convert("RGB")
mask = Image.open(mask_path).convert("L")

# Show image and mask next to each other
plt.figure(figsize=(10, 5))
for position, (picture, colormap, title) in enumerate(
        [(img, None, "Original Image"), (mask, 'gray', "Mask")], start=1):
    plt.subplot(1, 2, position)
    plt.imshow(picture, cmap=colormap)
    plt.title(title)
    plt.axis('off')

plt.tight_layout()
plt.show()

# Attention U-Net Architecture

## UNet Parts

In [None]:
class DoubleConv(nn.Module):
    """Two consecutive (3x3 convolution -> batch norm -> ReLU) stages.

    Padding of 1 keeps the spatial size unchanged; the channel count goes
    from ``in_channels`` to ``out_channels`` in the first stage and stays at
    ``out_channels`` in the second.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        # Attribute name and layer order define the state_dict keys
        # ('conv_op.<index>.*'), so they must not change.
        self.conv_op = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both convolution stages to ``x``."""
        return self.conv_op(x)

class AttentionGate(nn.Module):
    """Additive attention gate that re-weights an encoder skip connection.

    The gating signal ``g`` and the skip features ``x`` are each projected to
    ``F_int`` channels with a 1x1 convolution (the skip path is strided by
    ``stride_x`` so both projections can be added), combined, squashed to a
    single-channel map in (0, 1) and resized back to the size of ``x``.

    Args:
        F_g: Channels of the gating signal.
        F_l: Channels of the skip connection.
        F_int: Channels of the intermediate projections.
        stride_x: Stride of the skip-path projection.
    """

    def __init__(self, F_g, F_l, F_int, stride_x=2):
        super().__init__()

        def projection(channels_in, channels_out, stride):
            # 1x1 convolution followed by batch normalisation
            return [
                nn.Conv2d(channels_in, channels_out, kernel_size=1,
                          stride=stride, padding=0, bias=True),
                nn.BatchNorm2d(channels_out),
            ]

        # Gating signal path (from decoder)
        self.W_g = nn.Sequential(*projection(F_g, F_int, 1))
        # Skip connection path (from encoder)
        self.W_x = nn.Sequential(*projection(F_l, F_int, stride_x))
        # Attention map
        self.psi = nn.Sequential(*projection(F_int, 1, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` scaled element-wise by the attention map."""
        combined = self.relu(self.W_g(g) + self.W_x(x))
        attention = self.psi(combined)
        # Bring the coarse attention map back to the resolution of the skip features
        attention = F.interpolate(attention, size=x.shape[2:],
                                  mode='bilinear', align_corners=True)
        return x * attention


class DownSample(nn.Module):
    """Encoder stage: a DoubleConv followed by 2x2 max pooling.

    ``forward`` returns a pair: the features before pooling (used as the
    skip connection) and the pooled features (input of the next stage).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return ``(features, pooled_features)`` for input ``x``."""
        features = self.conv(x)
        return features, self.pool(features)


class UpSample(nn.Module):
    """Decoder stage: transposed-conv upsampling, skip concatenation, DoubleConv.

    The transposed convolution halves the channel count, so concatenating a
    skip tensor with ``in_channels // 2`` channels restores ``in_channels``
    channels for the following DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, concatenate ``x2`` along the channel axis and convolve."""
        upsampled = self.up(x1)
        return self.conv(torch.cat((upsampled, x2), dim=1))

## Attention U-Net Class

In [None]:
class AttentionUNet(nn.Module):
    """U-Net with attention gates on all four skip connections.

    Four encoder stages (64, 128, 256, 512 channels) feed a 1024-channel
    bottleneck; each decoder stage gates the matching encoder features with
    the current decoder features before upsampling and merging. A final 1x1
    convolution maps to ``num_classes`` output channels (raw logits).

    The input is pooled by 2 four times, so its height and width should be
    divisible by 16 for the skip connections to line up.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        c1, c2, c3, c4, c5 = 64, 128, 256, 512, 1024

        # Encoder
        self.down_convolution_1 = DownSample(in_channels, c1)
        self.down_convolution_2 = DownSample(c1, c2)
        self.down_convolution_3 = DownSample(c2, c3)
        self.down_convolution_4 = DownSample(c3, c4)

        self.bottle_neck = DoubleConv(c4, c5)

        # Attention gates, deepest first
        self.att4 = AttentionGate(F_g=c5, F_l=c4, F_int=c4, stride_x=2)
        self.att3 = AttentionGate(F_g=c4, F_l=c3, F_int=c3, stride_x=2)
        self.att2 = AttentionGate(F_g=c3, F_l=c2, F_int=c2, stride_x=2)
        self.att1 = AttentionGate(F_g=c2, F_l=c1, F_int=c1, stride_x=2)

        # Decoder
        self.up_convolution_1 = UpSample(c5, c4)
        self.up_convolution_2 = UpSample(c4, c3)
        self.up_convolution_3 = UpSample(c3, c2)
        self.up_convolution_4 = UpSample(c2, c1)

        self.out = nn.Conv2d(in_channels=c1, out_channels=num_classes, kernel_size=1)

    def forward(self, x):
        """Return per-pixel logits with the same spatial size as ``x``."""
        encoders = (self.down_convolution_1, self.down_convolution_2,
                    self.down_convolution_3, self.down_convolution_4)
        skips = []
        features = x
        for encoder in encoders:
            skip, features = encoder(features)
            skips.append(skip)

        features = self.bottle_neck(features)

        # Decoder with attention: walk the skips from deepest to shallowest
        decoder_stages = (
            (self.att4, self.up_convolution_1),
            (self.att3, self.up_convolution_2),
            (self.att2, self.up_convolution_3),
            (self.att1, self.up_convolution_4),
        )
        for (gate, upsample), skip in zip(decoder_stages, reversed(skips)):
            gated_skip = gate(g=features, x=skip)
            features = upsample(features, gated_skip)

        return self.out(features)


if __name__ == "__main__":
    # Shape smoke test: a random 512x512 RGB batch must come back with
    # one channel per class at the same resolution.
    double_conv = DoubleConv(256, 256)  # only instantiated, not part of the check

    input_image = torch.rand(1, 3, 512, 512)
    model = AttentionUNet(in_channels=3, num_classes=10)
    output = model(input_image)
    print(output.size())  # Expected size: [1, 10, 512, 512]

# Dataset Class


In [None]:
class KvasirDataset(Dataset):
    """Segmentation dataset pairing images with masks listed in a DataFrame.

    The class name is historical; it reads whatever files ``filename_df``
    points to below ``root_path``.

    Each item is a tuple ``(image, mask, dist_map, id)``:
        image    -- image tensor produced by the transform pipeline
        mask     -- float tensor with a leading channel axis, 1.0 = foreground
        dist_map -- signed distance to the mask boundary scaled to [-1, 1]
                    (positive outside the object, negative inside)
        id       -- the row's ``id`` value from ``filename_df``

    Args:
        root_path: Directory the relative paths in ``filename_df`` are joined to.
        filename_df: DataFrame with ``image_path``, ``mask_path`` and ``id`` columns.
        target_size: ``(height, width)`` every sample is resized to.
        boundary_dist_max: Distance in pixels at which the signed distance saturates.
        augment: When True (default) random flips, rotation and
            brightness/contrast jitter are applied; pass False for
            deterministic evaluation data.
    """

    def __init__(self, root_path, filename_df, target_size=(320, 320),
                 boundary_dist_max=50, augment=True):
        self.root = root_path
        self.df = filename_df
        # Keep the DataFrame's row order: sorting the paths here would break
        # the link between position ``idx`` and ``df['id'].iloc[idx]`` once the
        # frame has been shuffled (e.g. by train_test_split).
        self.images = [os.path.join(root_path, f) for f in self.df['image_path']]
        self.masks = [os.path.join(root_path, f) for f in self.df['mask_path']]

        height, width = target_size
        steps = [A.Resize(height=height, width=width)]
        if augment:
            steps += [
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.5),
                A.Rotate(limit=30, p=0.5),
                A.RandomBrightnessContrast(p=0.5),
            ]
        steps += [
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ]
        self.transform = A.Compose(steps)
        self.boundary_dist_max = boundary_dist_max

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = Image.open(self.images[idx]).convert("RGB")
        m = Image.open(self.masks[idx]).convert("L")

        augmented = self.transform(image=np.array(img), mask=np.array(m))
        img_t = augmented['image']
        mask_t = augmented['mask'].unsqueeze(0).float()

        # NOTE(review): the mask appears to keep its raw grayscale range here
        # (Normalize is expected to touch only the image), so this threshold
        # marks every non-zero pixel as foreground - confirm for soft masks.
        mask_bin = (mask_t > 0.5).float()

        dist_map = self._compute_signed_distance(mask_bin.numpy()[0, :, :])
        dist_map_t = torch.from_numpy(dist_map).unsqueeze(0).float()

        return img_t, mask_bin, dist_map_t, self.df['id'].iloc[idx]

    def _compute_signed_distance(self, bin_mask):
        """Return the signed distance map of ``bin_mask`` scaled to [-1, 1].

        Distances are clipped at ``boundary_dist_max`` pixels before scaling.
        """
        posmask = bin_mask.astype(bool)
        # The distance transform measures the distance to the nearest zero
        # element; without any zero the result is not meaningful. Use the
        # saturated value instead: everything is "far outside" (+1) or
        # "far inside" (-1).
        if not posmask.any():
            return np.ones(posmask.shape, dtype=np.float64)
        if posmask.all():
            return -np.ones(posmask.shape, dtype=np.float64)

        negmask = ~posmask
        dist_out = edt(negmask)
        dist_in = edt(posmask)
        signed = dist_out - dist_in
        signed = np.clip(signed, -self.boundary_dist_max, self.boundary_dist_max)
        return signed / self.boundary_dist_max


# Loss

In [None]:
class StandardDiceLoss(nn.Module):
    """Soft Dice loss computed over all elements of the batch at once.

    ``pred`` is used as-is (no sigmoid is applied here), so callers pass
    probabilities. Returns ``1 - dice``.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        pred = pred.contiguous().view(-1)
        target = target.contiguous().view(-1)
        intersection = (pred * target).sum()
        dice_score = (2. * intersection + self.smooth) / (pred.sum() + target.sum() + self.smooth)
        return 1 - dice_score


class AdaptiveTVMFDiceLoss:
    """Dice-style loss built on a kappa-modulated cosine similarity.

    With ``kappa == 0`` (the initial value) the loss is ``(1 - cos)^2`` between
    prediction and target per sample and class. ``update_kappa`` must be called
    by the training code to make the loss adapt; it is never called implicitly.

    Args:
        num_classes: Number of kappa values (one per class/channel).
        kappa_max: Upper bound for kappa.
        lambda_kappa: Factor applied to the Dice score when updating kappa.
    """

    def __init__(self, num_classes=1, kappa_max=100.0, lambda_kappa=10.0):
        self.kappa = torch.zeros(num_classes, requires_grad=False)
        self.kappa_max = kappa_max
        self.lambda_kappa = lambda_kappa

    def update_kappa(self, dice_per_class):
        """Set kappa to ``lambda_kappa * dice`` capped at ``kappa_max``.

        Accepts a tensor, a NumPy array or a plain number.
        """
        dice = torch.as_tensor(dice_per_class, dtype=torch.float32).detach()
        self.kappa = torch.clamp(self.lambda_kappa * dice, max=self.kappa_max)

    def __call__(self, preds, targets):
        B, C = preds.shape[:2]
        # reshape (not view) so non-contiguous inputs are accepted as well
        preds_flat = preds.reshape(B, C, -1)
        targets_flat = targets.reshape(B, C, -1)
        dot = (preds_flat * targets_flat).sum(-1)
        norm = torch.norm(preds_flat, dim=-1) * torch.norm(targets_flat, dim=-1) + 1e-6
        cos_sim = dot / norm
        k = self.kappa.reshape(1, -1).to(device=preds.device, dtype=preds.dtype)
        t = (1 + cos_sim) / (1 + k * (1 - cos_sim)) - 1
        return torch.mean((1 - t) ** 2)


class BoundaryLoss(nn.Module):
    """Mean of the first prediction channel weighted by a signed distance map."""

    def forward(self, preds, dist_map):
        foreground = preds[:, 0, :, :]
        # A distance map with a singleton channel axis ([B, 1, H, W]) would
        # broadcast against [B, H, W] into [B, B, H, W] and mix samples;
        # drop the channel axis so the product stays per-sample.
        if dist_map.dim() == preds.dim() and dist_map.size(1) == 1:
            dist_map = dist_map[:, 0]
        return torch.mean(foreground * dist_map)


class FocalLoss(nn.Module):
    """Binary focal loss on raw logits.

    Args:
        alpha: Constant factor applied to every element.
        gamma: Focusing exponent; larger values down-weight easy examples more.
        reduce: Return the mean when True, the element-wise loss otherwise.
    """

    def __init__(self, alpha=1, gamma=2, reduce=True):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduce = reduce

    def forward(self, inputs, targets):
        bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        # exp(-BCE) is the probability the model assigns to the target value
        p_true = torch.exp(-bce)
        focal = self.alpha * (1 - p_true) ** self.gamma * bce
        return torch.mean(focal) if self.reduce else focal


class TverskyLoss(nn.Module):
    """Tversky loss on raw logits (sigmoid is applied internally).

    Args:
        alpha: Weight of false positives.
        beta: Weight of false negatives.
        smooth: Small constant that keeps the ratio defined for empty masks.
    """

    def __init__(self, alpha=0.3, beta=0.7, smooth=1e-6):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.smooth = smooth

    def forward(self, inputs, targets):
        probs = torch.sigmoid(inputs).view(-1)
        truth = targets.view(-1)

        true_pos = (probs * truth).sum()
        false_pos = ((1 - truth) * probs).sum()
        false_neg = (truth * (1 - probs)).sum()

        denominator = true_pos + self.alpha * false_pos + self.beta * false_neg + self.smooth
        return 1 - (true_pos + self.smooth) / denominator


class EdgeAwareLoss(nn.Module):
    """L1 distance between Sobel edge magnitudes of prediction and target.

    ``inputs`` are raw logits (sigmoid is applied internally); ``targets`` are
    compared as given. Both must be single-channel because the Sobel kernels
    have one input channel.

    Args:
        edge_weight: Factor the L1 edge difference is multiplied with.
    """

    def __init__(self, edge_weight=1e-4):
        super().__init__()
        self.edge_weight = edge_weight
        # Sobel kernels in conv2d weight layout (out_channels, in_channels, 3, 3)
        horizontal = torch.tensor([[-1, 0, 1],
                                   [-2, 0, 2],
                                   [-1, 0, 1]], dtype=torch.float32)
        vertical = torch.tensor([[-1, -2, -1],
                                 [0, 0, 0],
                                 [1, 2, 1]], dtype=torch.float32)
        self.register_buffer('sobel_x', horizontal.reshape(1, 1, 3, 3))
        self.register_buffer('sobel_y', vertical.reshape(1, 1, 3, 3))

    def forward(self, inputs, targets):
        probs = torch.sigmoid(inputs)
        # Follow the input's device even if the module itself was never moved
        kernel_x = self.sobel_x.to(probs.device)
        kernel_y = self.sobel_y.to(probs.device)

        def edge_magnitude(image):
            grad_x = F.conv2d(image, kernel_x, padding=1)
            grad_y = F.conv2d(image, kernel_y, padding=1)
            # The epsilon keeps the square root differentiable at zero gradient
            return torch.sqrt(grad_x ** 2 + grad_y ** 2 + 1e-6)

        pred_edges = edge_magnitude(probs)
        target_edges = edge_magnitude(targets)
        return self.edge_weight * F.l1_loss(pred_edges, target_edges)

class HybridFocalEdgeLoss(nn.Module):
    """Weighted sum of focal, edge-aware and Dice losses on raw logits.

    The inner EdgeAwareLoss is created with weight 1.0, so ``edge_weight``
    here is the only scaling applied to the edge term.

    Args:
        focal_weight: Weight of the focal term.
        edge_weight: Weight of the edge term.
        dice_weight: Weight of the Dice term.
    """

    def __init__(self, focal_weight=0.5, edge_weight=0.3, dice_weight=0.2):
        super().__init__()
        self.focal_loss = FocalLoss(alpha=1, gamma=2)
        self.edge_loss = EdgeAwareLoss(edge_weight=1.0)
        self.dice_loss = StandardDiceLoss()
        self.focal_weight = focal_weight
        self.edge_weight = edge_weight
        self.dice_weight = dice_weight

    def forward(self, inputs, targets):
        focal_term = self.focal_weight * self.focal_loss(inputs, targets)
        edge_term = self.edge_weight * self.edge_loss(inputs, targets)
        # The Dice loss works on probabilities, the other two on logits
        dice_term = self.dice_weight * self.dice_loss(torch.sigmoid(inputs), targets)
        return focal_term + edge_term + dice_term

class HausdorffDistanceLoss(nn.Module):
    """Mean 95th-percentile Hausdorff distance between prediction and target.

    Predictions are thresholded at 0.5; targets are cast to ``uint8``. The
    value is computed on detached NumPy copies, so NO gradient flows through
    this term - it acts as a monitored quantity rather than a trainable loss.

    Per sample:
        * both masks empty       -> 0
        * exactly one mask empty -> length of the image diagonal (the
          distance is otherwise undefined because there is nothing to
          measure to)
        * otherwise              -> max of the two directed 95th-percentile
          distances

    Returns a float32 scalar tensor on the device of ``preds``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, preds, targets):
        preds_np = preds.squeeze(1).detach().cpu().numpy()      # [B, H, W]
        targets_np = targets.squeeze(1).detach().cpu().numpy()

        distances = []
        for p, t in zip(preds_np, targets_np):
            p_bin = (p > 0.5).astype(np.uint8)
            t_bin = t.astype(np.uint8)
            has_pred = bool(np.any(p_bin == 1))
            has_gt = bool(np.any(t_bin == 1))

            if not has_pred and not has_gt:
                distances.append(0.0)
            elif has_pred != has_gt:
                distances.append(float(np.linalg.norm(p_bin.shape)))
            else:
                # Distance of every pixel to the nearest foreground pixel
                dt_pred = ndi.distance_transform_edt(1 - p_bin)
                dt_gt = ndi.distance_transform_edt(1 - t_bin)
                hd_pred = np.percentile(dt_gt[p_bin == 1], 95)
                hd_gt = np.percentile(dt_pred[t_bin == 1], 95)
                distances.append(float(max(hd_pred, hd_gt)))

        # Explicit float32: an all-integer list would give an int64 tensor
        # (mean() then raises) and float64 would promote the total loss.
        return torch.tensor(distances, dtype=torch.float32, device=preds.device).mean()

# Train Loop

In [None]:
def compute_metrics_per_image(pred_mask, gt_mask):
    """Compute region and boundary metrics for one binary prediction.

    Args:
        pred_mask: 2-D array of 0/1 predictions.
        gt_mask: 2-D array of 0/1 ground-truth labels, same shape.

    Returns:
        Tuple ``(dice, iou, bf1, biou, accuracy, precision, recall, f1,
        bg_acc, fg_acc)`` where ``bf1``/``biou`` are F1 and IoU of the inner
        boundary pixels and ``bg_acc``/``fg_acc`` are the per-class accuracies.
    """
    pred = pred_mask.flatten()
    gt = gt_mask.flatten()

    dice = (2*np.sum(pred*gt) + 1e-6) / (np.sum(pred) + np.sum(gt) + 1e-6)
    # zero_division=0 matches the sibling metrics and avoids an
    # UndefinedMetricWarning for every image where both masks are empty.
    iou = jaccard_score(gt, pred, zero_division=0)
    accuracy = (np.sum(pred == gt) / len(gt))
    precision = precision_score(gt, pred, zero_division=0)
    recall = recall_score(gt, pred, zero_division=0)
    f1 = f1_score(gt, pred, zero_division=0)

    # Boundary metrics
    bgt = find_boundaries(gt_mask, mode='inner').astype(int)
    bpr = find_boundaries(pred_mask, mode='inner').astype(int)
    tp = np.sum((bpr == 1) & (bgt == 1))
    fp = np.sum((bpr == 1) & (bgt == 0))
    fn = np.sum((bgt == 1) & (bpr == 0))
    prec = tp / (tp + fp + 1e-6)
    rec = tp / (tp + fn + 1e-6)
    bf1 = 2 * prec * rec / (prec + rec + 1e-6)
    inter = np.sum((bpr & bgt))
    union = np.sum(((bpr + bgt) > 0))
    biou = inter / (union + 1e-6)

    # Per-class accuracy (binary class: bg and fg)
    bg_acc = np.sum((pred == 0) & (gt == 0)) / (np.sum(gt == 0) + 1e-6)
    fg_acc = np.sum((pred == 1) & (gt == 1)) / (np.sum(gt == 1) + 1e-6)

    return dice, iou, bf1, biou, accuracy, precision, recall, f1, bg_acc, fg_acc

In [None]:
def train_branch(branch_name, config, train_df, val_df, out_dir, data_root,
                 BATCH=10, EPOCHS=20, LR=3e-4, seed=42):
    """Train an AttentionUNet with one loss configuration and log the results.

    Files written to ``out_dir``:
        metrics_per_image.csv     -- one row per validation image and epoch
        loss_log.csv              -- one row per epoch (losses, accuracies, time)
        <branch_name>_<epoch>.pth -- model weights after every epoch

    Args:
        branch_name: Label used in progress bars, CSV rows and checkpoint names.
        config: Dict selecting the loss terms. The flags ``use_dice``,
            ``use_tvmf``, ``use_focal``, ``use_tversky`` and ``use_boundary``
            enable a term; ``w_ce`` (default 1.0) and ``w_dice``,
            ``w_adaptive``, ``w_focal``, ``w_tversky``, ``w_bnd`` (default 0.0)
            weight it. ``boundary_type`` picks the boundary loss: 'hausdorff'
            (default), 'weighted' (EdgeAwareLoss) or any other value for the
            distance-map BoundaryLoss.
        train_df: DataFrame of training samples.
        val_df: DataFrame of validation samples.
        out_dir: Output directory (created if missing).
        data_root: Directory the DataFrame paths are relative to.
        BATCH: Training batch size (validation always uses batch size 1).
        EPOCHS: Number of epochs.
        LR: AdamW learning rate.
        seed: Seed of the DataLoader shuffling generators.
    """
    os.makedirs(out_dir, exist_ok=True)
    m_csv = os.path.join(out_dir, "metrics_per_image.csv")
    l_csv = os.path.join(out_dir, "loss_log.csv")

    with open(m_csv, 'w', newline='') as f:
        csv.writer(f).writerow(['image_id', 'epoch', 'branch', 'dice', 'iou', 'bf1', 'biou',
                                'accuracy', 'precision', 'recall', 'f1', 'bg_acc', 'fg_acc'])
    with open(l_csv, 'w', newline='') as f:
        csv.writer(f).writerow(['epoch', 'train_loss', 'val_loss', 'ce', 'dice_term', 'bnd_term', 'accuracy', 'bg_acc', 'fg_acc', 'time_s'])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Datasets & loaders
    train_ds = KvasirDataset(data_root, train_df)
    val_ds = KvasirDataset(data_root, val_df)
    # Validation must see the same pixels every epoch, otherwise the per-image
    # metrics are not comparable between epochs or branches: keep only the
    # deterministic steps of the dataset's pipeline.
    val_ds.transform = A.Compose([
        step for step in val_ds.transform.transforms
        if isinstance(step, (A.Resize, A.Normalize, ToTensorV2))
    ])

    tr_dl = DataLoader(train_ds, batch_size=BATCH, shuffle=True,
                       generator=torch.Generator().manual_seed(seed))
    vl_dl = DataLoader(val_ds, batch_size=1, shuffle=False,
                       generator=torch.Generator().manual_seed(seed))
    # Guard the per-epoch averages against empty loaders
    n_train = max(len(tr_dl), 1)
    n_val = max(len(vl_dl), 1)

    # Model & optimizer
    model = AttentionUNet(in_channels=3, num_classes=1).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=LR)

    # Loss function init based on config
    dice_loss_fn = StandardDiceLoss() if config.get('use_dice') else None
    # NOTE(review): update_kappa() is never called, so kappa stays 0 and the
    # "adaptive" loss reduces to (1 - cos)^2 - confirm whether a per-epoch
    # update from the validation Dice was intended.
    adaptive_dice_fn = AdaptiveTVMFDiceLoss(num_classes=1) if config.get('use_tvmf') else None
    focal_loss_fn = FocalLoss() if config.get('use_focal') else None
    tversky_loss_fn = TverskyLoss() if config.get('use_tversky') else None

    bnd_loss_fn = None
    btype = None
    if config.get('use_boundary'):
        btype = config.get('boundary_type', 'hausdorff')
        if btype == 'weighted':
            # NOTE(review): w_bnd is both the inner edge weight and the outer
            # term weight, so this variant scales with w_bnd squared - confirm.
            bnd_loss_fn = EdgeAwareLoss(edge_weight=config.get('w_bnd', 1e-4))
        elif btype == 'hausdorff':
            bnd_loss_fn = HausdorffDistanceLoss()
        else:
            bnd_loss_fn = BoundaryLoss()

    w_ce = config.get('w_ce', 1.0)
    w_dice = config.get('w_dice', 0.0)
    w_bnd = config.get('w_bnd', 0.0)
    w_focal = config.get('w_focal', 0.0)
    w_tversky = config.get('w_tversky', 0.0)
    w_adaptive = config.get('w_adaptive', 0.0)

    def as_float(value):
        # Disabled terms are the plain number 0, enabled ones are tensors
        return value.item() if torch.is_tensor(value) else float(value)

    def compute_losses(logits, mask, dist):
        """Return (prob, total, ce, dice, boundary) for one batch."""
        prob = torch.sigmoid(logits)

        ce_loss = F.binary_cross_entropy_with_logits(logits, mask)
        dice_loss = dice_loss_fn(prob, mask) if dice_loss_fn is not None else 0
        focal_loss = focal_loss_fn(logits, mask) if focal_loss_fn is not None else 0
        tversky_loss = tversky_loss_fn(logits, mask) if tversky_loss_fn is not None else 0
        adaptive_loss = adaptive_dice_fn(prob, mask) if adaptive_dice_fn is not None else 0

        # Each boundary variant expects different inputs
        if bnd_loss_fn is None:
            bnd_loss = 0
        elif btype == 'weighted':
            bnd_loss = bnd_loss_fn(logits, mask)   # applies sigmoid itself
        elif btype == 'hausdorff':
            bnd_loss = bnd_loss_fn(prob, mask)     # compares two masks
        else:
            # Drop the distance map's channel axis so it lines up with the
            # [B, H, W] foreground channel the loss extracts.
            bnd_loss = bnd_loss_fn(prob, dist[:, 0])

        # Every term enters the sum exactly once
        total = (w_ce * ce_loss + w_dice * dice_loss + w_bnd * bnd_loss +
                 w_focal * focal_loss + w_tversky * tversky_loss +
                 w_adaptive * adaptive_loss)
        return prob, total, ce_loss, dice_loss, bnd_loss

    for epoch in range(1, EPOCHS + 1):
        t0 = time.time()
        model.train()
        tr_ce = tr_d = tr_b = tot = 0.0

        for img, mask, dist, _ in tqdm(tr_dl, desc=f"[{branch_name}] Train"):
            img, mask, dist = img.to(device), mask.to(device), dist.to(device)
            _, loss, ce_loss, dice_loss, bnd_loss = compute_losses(model(img), mask, dist)

            opt.zero_grad()
            loss.backward()
            opt.step()

            tr_ce += ce_loss.item()
            tr_d += as_float(dice_loss)
            tr_b += as_float(bnd_loss)
            tot += loss.item()

        train_loss = tot / n_train

        model.eval()
        tval = val_acc = 0.0
        val_bg_acc = val_fg_acc = 0.0

        with torch.no_grad(), open(m_csv, 'a', newline='') as mf:
            mw = csv.writer(mf)
            for img, mask, dist, img_id in tqdm(vl_dl, desc=f"[{branch_name}] Val"):
                img, mask, dist = img.to(device), mask.to(device), dist.to(device)
                prob, loss, _, _, _ = compute_losses(model(img), mask, dist)
                tval += loss.item()

                pm = (prob.cpu().numpy() > 0.5).astype(np.uint8)[0, 0]
                gm = mask.cpu().numpy()[0, 0].astype(np.uint8)
                dice, iou, bf1, biou, acc, prec, rec, f1, bg_acc, fg_acc = compute_metrics_per_image(pm, gm)
                val_acc += acc
                val_bg_acc += bg_acc
                val_fg_acc += fg_acc

                # The loader batches the id (batch size 1); write the plain
                # value instead of a "tensor([..])" string.
                sample_id = img_id[0]
                if torch.is_tensor(sample_id):
                    sample_id = sample_id.item()

                mw.writerow([sample_id, epoch, branch_name, dice, iou, bf1, biou,
                             acc, prec, rec, f1, bg_acc, fg_acc])

        val_loss = tval / n_val
        avg_acc = val_acc / n_val
        avg_bg_acc = val_bg_acc / n_val
        avg_fg_acc = val_fg_acc / n_val
        t_elapsed = time.time() - t0

        with open(l_csv, 'a', newline='') as lf:
            csv.writer(lf).writerow([
                epoch, train_loss, val_loss,
                tr_ce / n_train, tr_d / n_train, tr_b / n_train, avg_acc, avg_bg_acc, avg_fg_acc,
                t_elapsed
            ])

        print(f"{branch_name} epoch{epoch}: tr_loss {train_loss:.4f} val_loss {val_loss:.4f}, "
              f"avg_acc {avg_acc:.4f}, bg_acc {avg_bg_acc:.4f}, fg_acc {avg_fg_acc:.4f}, time {t_elapsed:.1f}s")

        torch.save(model.state_dict(), os.path.join(out_dir, f"{branch_name}_{epoch}.pth"))


In [None]:
# Loss configurations to compare. A 'use_*' flag creates the loss object and
# the matching 'w_*' entry weights it; train_branch defaults w_ce to 1.0, so
# cross-entropy has to be switched off explicitly where it is not wanted.
branches = {
    'branch1_ce_only': {'w_ce': 1.0},
    'branch2_dice_only': {'use_dice': True, 'w_ce': 0.0, 'w_dice': 1.0},
    'branch3_ce_dice': {'use_dice': True, 'w_ce': 1.0, 'w_dice': 1.0},
    'branch4_ce_focal': {'use_focal': True, 'w_ce': 1.0, 'w_focal': 1.0},
    'branch5_ce_tversky': {'use_tversky': True, 'w_ce': 1.0, 'w_tversky': 1.0},
    'branch6_ce_adaptive_dice': {'use_tvmf': True, 'w_ce': 1.0, 'w_adaptive': 1.0},
    'branch7_ce_dice_focal': {'use_dice': True, 'use_focal': True, 'w_ce': 1.0, 'w_dice': 1.0, 'w_focal': 1.0},
    'branch8_ce_dice_tversky': {'use_dice': True, 'use_tversky': True, 'w_ce': 1.0, 'w_dice': 1.0, 'w_tversky': 1.0},
    'branch9_ce_adaptive_dice_focal': {'use_tvmf': True, 'use_focal': True, 'w_ce': 1.0, 'w_adaptive': 1.0, 'w_focal': 1.0},
    'branch10_ce_adaptive_dice_tversky': {'use_tvmf': True, 'use_tversky': True, 'w_ce': 1.0, 'w_adaptive': 1.0, 'w_tversky': 1.0},
    # The adaptive Dice term is weighted through 'w_adaptive'; 'w_dice' only
    # affects the standard Dice loss, which this branch does not enable.
    'branch11_ce_adaptive_dice_boundary': {'use_tvmf': True, 'use_boundary': True, 'w_ce': 1.0, 'w_adaptive': 1.0, 'w_bnd': 1.0},
}

# Fixed 80/20 split shared by all branches
train_df, val_df = train_test_split(
    df,
    test_size=0.2,
    random_state=42,
    shuffle=True
)

for name, cfg in branches.items():
    # Drop Python references first so the CUDA cache can actually be released
    gc.collect()
    torch.cuda.empty_cache()
    train_branch(name, cfg, train_df, val_df,
                 out_dir=f"/kaggle/working/experiments/{name}",
                 data_root="/kaggle/input/cvc-colondb/CVC-ColonDB/",
                 BATCH=10, EPOCHS=30, LR=3e-4)

# Test Model

In [None]:
def test_model(val_df, model_path, root_path="/kaggle/input/kvasirseg/Kvasir-SEG/Kvasir-SEG", seed=42, device="cuda"):
    """Load a checkpoint and plot image / ground truth / prediction for up to two samples.

    Args:
        val_df: DataFrame of samples; the first batch (at most two rows) is shown.
        model_path: Path of a state_dict saved from an AttentionUNet(3, 1).
        root_path: Directory the paths in ``val_df`` are relative to. The
            default points at Kvasir-SEG; pass the matching root for other
            datasets.
        seed: Seed of the DataLoader generator.
        device: Device to run on; falls back to CPU when CUDA is requested
            but not available.
    """
    if str(device).startswith("cuda") and not torch.cuda.is_available():
        device = "cpu"

    # Load model
    # NOTE: torch.load unpickles the file - only load checkpoints you trust.
    model = AttentionUNet(in_channels=3, num_classes=1).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    val_dataset = KvasirDataset(root_path=root_path, filename_df=val_df)
    vl_dl = DataLoader(val_dataset, batch_size=2, shuffle=False,
                       generator=torch.Generator().manual_seed(seed))

    # First batch only (two samples, or one if the frame has a single row)
    imgs, masks, _, filenames = next(iter(vl_dl))

    imgs = imgs.to(device)
    masks = masks.cpu().numpy()

    with torch.no_grad():
        outputs = model(imgs)
        probs = torch.sigmoid(outputs)
        preds = (probs > 0.5).float().cpu().numpy()

    for i in range(len(imgs)):
        img_np = imgs[i].cpu().permute(1, 2, 0).numpy()
        # Rescale the normalised image to [0, 1] for display; a constant
        # image would otherwise divide by zero.
        value_range = img_np.max() - img_np.min()
        img_np = (img_np - img_np.min()) / (value_range if value_range > 0 else 1.0)

        gt_mask = masks[i][0]  # [0] since it's single-channel
        pred_mask = preds[i][0]

        sample_id = filenames[i]
        if torch.is_tensor(sample_id):
            sample_id = sample_id.item()

        fig, ax = plt.subplots(1, 3, figsize=(12, 4))
        ax[0].imshow(img_np)
        ax[0].set_title(f"Original Image\n{sample_id}")
        ax[0].axis('off')

        ax[1].imshow(gt_mask, cmap='gray')
        ax[1].set_title("Ground Truth Mask")
        ax[1].axis('off')

        ax[2].imshow(pred_mask, cmap='gray')
        ax[2].set_title("Predicted Mask")
        ax[2].axis('off')

        plt.show()

In [None]:
# val_df holds CVC-ColonDB paths, so the dataset root must be passed explicitly
# (test_model's default root points at Kvasir-SEG).
test_model(val_df,
           model_path="/kaggle/input/aunet-version-2/pytorch/default/1/branch10_ce_adaptive_dice_tversky_30.pth",
           root_path="/kaggle/input/cvc-colondb/CVC-ColonDB")