# 🚀 How to Run This Notebook in Google Colab

Follow these steps to run this notebook in Google Colab:

1. **Upload the notebook to Google Colab**:
   - Download this notebook file
   - Go to [Google Colab](https://colab.research.google.com/)
   - Click `File > Upload notebook` and select the downloaded file

2. **Mount Google Drive**:
   - Execute the first code cell to mount your Google Drive
   - Follow the authentication instructions to grant Colab access to your Drive

3. **Set up Kaggle API**:
   - Go to [your Kaggle account](https://www.kaggle.com/account)
   - Click on "Create New API Token" to download `kaggle.json`
   - In Colab, click the folder icon on the left sidebar, then upload the `kaggle.json` file
   - Run the second code cell to set up the Kaggle API and download the dataset

4. **Run the Code**:
   - Execute all cells in order
   - When ready to train, execute the `main_colab()` function instead of the original `main()` function
   
5. **Check Your Results**:
   - Once training is complete, your model checkpoints and outputs will be saved to:
     - Google Drive: `/content/drive/MyDrive/Contour_Mamba_Models/`
     - Local Colab: `/content/polygen_output/`

6. **Accessing Your Model After the Session Ends**:
   - Your model will be safe in Google Drive even after the Colab session expires
   - Find it at: `MyDrive/Contour_Mamba_Models/checkpoints/best_model_polygen.pth`

# Setup for Google Colab: Mount Drive and Connect to Kaggle Dataset

This notebook contains code to:
1. Mount Google Drive to save model checkpoints
2. Connect to and download the PolyGen2021 Kaggle dataset
3. Configure paths to save the best model to your Drive

In [1]:
# Step 1: Mount Google Drive to save your model
from google.colab import drive
drive.mount('/content/drive')

# Create a directory for saving models in Google Drive
import os
DRIVE_MODEL_DIR = "/content/drive/MyDrive/Contour_Mamba_Models"
os.makedirs(DRIVE_MODEL_DIR, exist_ok=True)
print(f"Model checkpoint directory created at: {DRIVE_MODEL_DIR}")

Mounted at /content/drive
Model checkpoint directory created at: /content/drive/MyDrive/Contour_Mamba_Models


In [2]:
# Step 2: Connect to Kaggle API and download the dataset
# You'll need to upload your Kaggle API token to Colab
# Follow these steps:
# 1. Go to kaggle.com → Account → Create API Token → download kaggle.json
# 2. Upload this file to Colab using the file browser

# Run this after uploading your kaggle.json
!mkdir -p ~/.kaggle
!cp /content/kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

# Download the PolyGen2021 dataset
# URL: https://www.kaggle.com/datasets/kokoroou/polypgen2021-segmentation-2
!kaggle datasets download -d kokoroou/polypgen2021-segmentation-2

# Extract the dataset
import zipfile
!mkdir -p /content/polypgen2021
!unzip -q polypgen2021-segmentation-2.zip -d /content/polypgen2021

# Verify the dataset structure
!ls -la /content/polypgen2021

Dataset URL: https://www.kaggle.com/datasets/kokoroou/polypgen2021-segmentation-2
License(s): unknown
Downloading polypgen2021-segmentation-2.zip to /content
 91% 999M/1.07G [00:00<00:00, 1.05GB/s]
100% 1.07G/1.07G [00:01<00:00, 1.10GB/s]
total 12
drwxr-xr-x 3 root root 4096 Jun 14 05:58 .
drwxr-xr-x 1 root root 4096 Jun 14 05:58 ..
drwxr-xr-x 4 root root 4096 Jun 14 05:58 PolypGen2021_segmentation


In [3]:
# Step 3: Modify the paths for Google Colab and add functionality to save to Drive
# This function updates the main() function to use Google Drive for checkpoints

def update_paths_for_colab():
    """
    Update the global paths to work with Google Colab and Google Drive
    Run this before executing the main() function
    """
    global OUTPUT_DIR, CHECKPOINT_DIR, POLYGEN_DATASET_PATH

    # Set paths for Google Colab
    OUTPUT_DIR = "/content/polygen_output"
    CHECKPOINT_DIR = "/content/drive/MyDrive/Contour_Mamba_Models/checkpoints"
    POLYGEN_DATASET_PATH = "/content/polypgen2021"

    # Create directories
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    print(f"Output directory set to: {OUTPUT_DIR}")
    print(f"Checkpoint directory set to: {CHECKPOINT_DIR}")
    print(f"Dataset path set to: {POLYGEN_DATASET_PATH}")

    return OUTPUT_DIR, CHECKPOINT_DIR, POLYGEN_DATASET_PATH

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
from einops import rearrange
import math
import os
import zipfile
from tqdm import tqdm
import random
from PIL import Image
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from torch.utils.checkpoint import checkpoint
from torch.amp import GradScaler, autocast

# -------------------------
# Multi-Kernel Positional Embedding Module
# -------------------------
class MultiKernelPositionalEmbedding(nn.Module):
    def __init__(self, in_channels, reduction=8):
        super(MultiKernelPositionalEmbedding, self).__init__()
        self.mid_channels = max(8, in_channels // reduction)

        self.conv3x3 = nn.Conv2d(in_channels, self.mid_channels, kernel_size=3, padding=1)
        self.conv5x5 = nn.Conv2d(in_channels, self.mid_channels, kernel_size=5, padding=2)
        self.conv7x7 = nn.Conv2d(in_channels, self.mid_channels, kernel_size=7, padding=3)

        self.position_attention = nn.Sequential(
            nn.Conv2d(self.mid_channels * 3, in_channels, kernel_size=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, in_channels, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        feat_3x3 = self.conv3x3(x)
        feat_5x5 = self.conv5x5(x)
        feat_7x7 = self.conv7x7(x)

        multi_scale_feat = torch.cat([feat_3x3, feat_5x5, feat_7x7], dim=1)
        attention_map = self.position_attention(multi_scale_feat)
        enhanced = x * attention_map
        return enhanced

# -------------------------
# Learnable Contour Extractor
# -------------------------
class LearnableContourExtractor(nn.Module):
    def __init__(self, in_channels):
        super(LearnableContourExtractor, self).__init__()

        self.sobel_x = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=False, groups=in_channels)
        self.sobel_y = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=False, groups=in_channels)

        sobel_x_kernel = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32)
        sobel_y_kernel = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32)

        with torch.no_grad():
            for i in range(in_channels):
                self.sobel_x.weight[i, 0] = sobel_x_kernel
                self.sobel_y.weight[i, 0] = sobel_y_kernel

        self.edge_enhance = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, groups=in_channels),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, groups=in_channels),
            nn.Sigmoid()
        )

        self.contour_refine = nn.Conv2d(in_channels, in_channels, kernel_size=1, bias=False)

    def forward(self, x):
        grad_x = self.sobel_x(x)
        grad_y = self.sobel_y(x)
        gradient_magnitude = torch.sqrt(grad_x ** 2 + grad_y ** 2 + 1e-8)
        enhanced_edges = self.edge_enhance(gradient_magnitude)
        contour_features = self.contour_refine(enhanced_edges)
        return contour_features

# -------------------------
# Contour-Guided Selective Scan
# -------------------------
def contour_guided_selective_scan(u, delta, A, B, C, D, contour_features):
    try:
        batch_size, L, D_in = u.shape
        N = A.shape[-1]

        A = torch.clamp(A, min=-8.0, max=0.1)
        delta = torch.clamp(delta, min=1e-8, max=10.0)

        if contour_features.dim() == 4:
            b, c, h, w = contour_features.shape
            contour_features = contour_features.flatten(2).transpose(1, 2)
            if contour_features.shape[1] != L:
                contour_features = F.interpolate(
                    contour_features.transpose(1, 2).view(b, c, h, w),
                    size=(int(L**0.5), int(L**0.5)),
                    mode='bilinear',
                    align_corners=False
                ).flatten(2).transpose(1, 2)
            if contour_features.shape[2] != D_in:
                contour_features = F.adaptive_avg_pool1d(
                    contour_features.transpose(1, 2), D_in
                ).transpose(1, 2)

        contour_weight = torch.sigmoid(contour_features)
        delta_modulated = delta * (1.0 + 0.2 * contour_weight)

        chunk_size = min(64, L)
        outputs = []
        h_state = torch.zeros(batch_size, D_in, N, device=u.device, dtype=u.dtype)

        for i in range(0, L, chunk_size):
            end_idx = min(i + chunk_size, L)
            chunk_len = end_idx - i

            u_chunk = u[:, i:end_idx]
            delta_chunk = delta_modulated[:, i:end_idx]
            B_chunk = B[:, i:end_idx]
            C_chunk = C[:, i:end_idx]

            dA = torch.einsum('bld,dn->bldn', delta_chunk, A)
            dB_u = torch.einsum('bld,bld,bln->bldn', delta_chunk, u_chunk, B_chunk)

            dA_exp = torch.exp(torch.clamp(dA, min=-10.0, max=10.0))
            h_chunk = torch.zeros_like(dB_u)

            for t in range(chunk_len):
                h_state = h_state * dA_exp[:, t] + dB_u[:, t]
                h_chunk[:, t] = h_state

            y_chunk = torch.einsum('bldn,bln->bld', h_chunk, C_chunk)
            outputs.append(y_chunk)

        y = torch.cat(outputs, dim=1)
        return y + u * D.view(1, 1, -1).expand(batch_size, L, D_in)

    except Exception as e:
        print(f"Error in contour_guided_selective_scan: {e}")
        print(f"u shape: {u.shape}, delta shape: {delta.shape}")
        print(f"A shape: {A.shape}, B shape: {B.shape}, C shape: {C.shape}")
        print(f"contour_features shape: {contour_features.shape}")
        raise

# -------------------------
# Channel Attention Module
# -------------------------
class ChannelAttention(nn.Module):
    def __init__(self, in_channels, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction),
            nn.ReLU(),
            nn.Linear(in_channels // reduction, in_channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, l, c = x.shape
        y = self.avg_pool(x.transpose(1, 2)).view(b, c)
        y = self.fc(y).view(b, 1, c)
        return x * y.expand_as(x)

# -------------------------
# Contour-Aware Mamba Block
# -------------------------
class ContourAwareMambaBlock(nn.Module):
    def __init__(self, d_model, d_state=32, expand_factor=1.5):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = int(expand_factor * d_model)

        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
        self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, kernel_size=3, padding=1, groups=self.d_inner)
        self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 16, bias=False)
        self.delta_proj = nn.Linear(16, self.d_inner, bias=True)

        self.contour_extractor = LearnableContourExtractor(d_model)
        self.contour_proj = nn.Linear(d_model, self.d_inner, bias=False)
        self.channel_attention = ChannelAttention(self.d_inner)

        A = torch.arange(1, d_state + 1, dtype=torch.float32).unsqueeze(0).repeat(self.d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))

        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
        self.dropout = nn.Dropout(0.1)
        self.norm_contour = nn.LayerNorm(self.d_inner)
        self._initialize_weights()

    def _initialize_weights(self):
        nn.init.xavier_uniform_(self.in_proj.weight, gain=0.5)
        nn.init.xavier_uniform_(self.x_proj.weight, gain=0.5)
        nn.init.xavier_uniform_(self.delta_proj.weight, gain=0.5)
        nn.init.xavier_uniform_(self.out_proj.weight, gain=0.5)
        nn.init.xavier_uniform_(self.contour_proj.weight, gain=0.5)
        nn.init.constant_(self.delta_proj.bias, 1.0)

    def forward(self, x, spatial_dims=None):
        b, l, d = x.shape

        if spatial_dims is None:
            h = w = int(l ** 0.5)
            if h * w != l:
                h = int((l / 4) ** 0.5) * 2
                w = l // h
        else:
            h, w = spatial_dims

        x_spatial = x.transpose(1, 2).reshape(b, d, h, w)
        contour_spatial = self.contour_extractor(x_spatial)
        contour_flat = contour_spatial.flatten(2).transpose(1, 2)
        contour_features = self.contour_proj(contour_flat)
        contour_features = self.norm_contour(contour_features)

        x_and_res = self.in_proj(x)
        x_ssm, res = x_and_res.chunk(2, dim=-1)

        x_ssm = rearrange(x_ssm, 'b l d -> b d l')
        x_ssm = self.conv1d(x_ssm)
        x_ssm = rearrange(x_ssm, 'b d l -> b l d')
        x_ssm = F.silu(x_ssm)

        A = -torch.exp(self.A_log)
        x_dbl = self.x_proj(x_ssm)
        delta, B, C = torch.split(x_dbl, [16, self.d_state, self.d_state], dim=-1)
        delta = F.softplus(self.delta_proj(delta))

        y = contour_guided_selective_scan(x_ssm, delta, A, B, C, self.D, contour_features)
        y = self.channel_attention(y)
        y = y * F.silu(res)
        y = self.out_proj(y)
        return self.dropout(y)

# -------------------------
# Enhanced Residual Block
# -------------------------
class ContourAwareResidualBlock(nn.Module):
    def __init__(self, d_model, d_state=32, dropout_rate=0.2):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.mixer = ContourAwareMambaBlock(d_model, d_state)
        self.dropout = nn.Dropout(dropout_rate)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, spatial_dims=None):
        residual = x
        x = self.norm1(x)
        x = self.mixer(x, spatial_dims)
        x = self.dropout(x)
        x = residual + x
        return self.norm2(x)

# -------------------------
# Double Convolution with MKPE
# -------------------------
class DoubleConvWithMKPE(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.same_channels = in_channels == out_channels

        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        self.mkpe = MultiKernelPositionalEmbedding(out_channels)

        if not self.same_channels:
            self.project = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        identity = x if self.same_channels else self.project(x)
        x = self.double_conv(x)
        x = self.mkpe(x)
        x = x + identity
        return x

# -------------------------
# MKPE + Contour-Aware Mamba UNet
# -------------------------
class MKPEContourAwareMambaUNet(nn.Module):
    def __init__(self, num_classes=2, d_model=128, d_state=32, num_mamba_layers=6):
        super().__init__()
        self.d_model = d_model
        self.num_classes = num_classes

        # Encoder path
        self.encoder1 = DoubleConvWithMKPE(3, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.encoder2 = DoubleConvWithMKPE(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.encoder3 = DoubleConvWithMKPE(128, 256)
        self.pool3 = nn.MaxPool2d(2)

        # Mamba blocks
        self.mamba_blocks = nn.ModuleList([
            ContourAwareResidualBlock(d_model, d_state)
            for _ in range(num_mamba_layers)
        ])

        # Bridge layers
        self.bridge_down = nn.Conv2d(256, d_model, kernel_size=1)
        self.bridge_up = nn.Conv2d(d_model, 256, kernel_size=1)
        self.bottleneck_mkpe = MultiKernelPositionalEmbedding(256)

        # Decoder path
        self.upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder3 = DoubleConvWithMKPE(256, 128)
        self.deep_sup3 = nn.Conv2d(128, num_classes, kernel_size=1)

        self.upconv2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder2 = DoubleConvWithMKPE(128, 64)
        self.deep_sup2 = nn.Conv2d(64, num_classes, kernel_size=1)

        self.upconv1 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.decoder1 = DoubleConvWithMKPE(35, 32)

        self.final_conv = nn.Conv2d(32, num_classes, kernel_size=1)
        self.output_mkpe = MultiKernelPositionalEmbedding(num_classes, reduction=2)

    def forward(self, x, return_deep=False):
        input_x = x

        # Encoder
        enc1 = self.encoder1(x)
        enc1_pool = self.pool1(enc1)
        enc2 = self.encoder2(enc1_pool)
        enc2_pool = self.pool2(enc2)
        enc3 = self.encoder3(enc2_pool)
        enc3_pool = self.pool3(enc3)

        # Bridge with Mamba processing
        bridge_out = self.bridge_down(enc3_pool)
        b, c, h, w = bridge_out.size()
        mamba_input = bridge_out.permute(0, 2, 3, 1).reshape(b, h * w, c)

        mamba_output = mamba_input
        for mamba_block in self.mamba_blocks:
            mamba_output = mamba_block(mamba_output, spatial_dims=(h, w))

        mamba_output = mamba_output.reshape(b, h, w, c).permute(0, 3, 1, 2)
        mamba_output = self.bridge_up(mamba_output)
        mamba_output = self.bottleneck_mkpe(mamba_output)

        # Decoder path
        dec3 = self.upconv3(mamba_output)
        if dec3.shape[2:] != enc2.shape[2:]:
            dec3 = F.interpolate(dec3, size=enc2.shape[2:], mode='bilinear', align_corners=True)
        dec3 = torch.cat([dec3, enc2], dim=1)
        dec3 = self.decoder3(dec3)
        deep_out3 = self.deep_sup3(dec3)

        dec2 = self.upconv2(dec3)
        if dec2.shape[2:] != enc1.shape[2:]:
            dec2 = F.interpolate(dec2, size=enc1.shape[2:], mode='bilinear', align_corners=True)
        dec2 = torch.cat([dec2, enc1], dim=1)
        dec2 = self.decoder2(dec2)
        deep_out2 = self.deep_sup2(dec2)

        dec1 = self.upconv1(dec2)
        if dec1.shape[2:] != input_x.shape[2:]:
            dec1 = F.interpolate(dec1, size=input_x.shape[2:], mode='bilinear', align_corners=True)
        dec1 = torch.cat([dec1, input_x], dim=1)
        dec1 = self.decoder1(dec1)

        # Final output
        out = self.final_conv(dec1)
        out = self.output_mkpe(out)

        if return_deep:
            deep_out2 = F.interpolate(deep_out2, size=input_x.shape[2:], mode='bilinear', align_corners=True)
            deep_out3 = F.interpolate(deep_out3, size=input_x.shape[2:], mode='bilinear', align_corners=True)
            return out, deep_out2, deep_out3
        return out

# -------------------------
# Combined Loss Function
# -------------------------
class CombinedLoss(nn.Module):
    def __init__(self, bce_weight=0.5, dice_weight=0.5):
        super().__init__()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.bce_loss = nn.CrossEntropyLoss()

    def forward(self, inputs, targets):
        bce = self.bce_loss(inputs, targets)
        inputs_soft = F.softmax(inputs, dim=1)
        targets_one_hot = F.one_hot(targets, num_classes=inputs.shape[1]).permute(0, 3, 1, 2).float()
        intersection = (inputs_soft * targets_one_hot).sum(dim=(2, 3))
        cardinality = inputs_soft.sum(dim=(2, 3)) + targets_one_hot.sum(dim=(2, 3))
        dice = (2. * intersection / (cardinality + 1e-6)).mean()
        dice_loss = 1 - dice
        return self.bce_weight * bce + self.dice_weight * dice_loss

# -------------------------
# Deep Supervision Loss
# -------------------------
class DeepSupervisionLoss(nn.Module):
    def __init__(self, main_weight=0.8, deep2_weight=0.1, deep3_weight=0.1):
        super().__init__()
        self.main_weight = main_weight
        self.deep2_weight = deep2_weight
        self.deep3_weight = deep3_weight
        self.criterion = CombinedLoss(bce_weight=0.5, dice_weight=0.5)

    def forward(self, outputs, target):
        main_out, deep2, deep3 = outputs
        loss_main = self.criterion(main_out, target)
        loss_deep2 = self.criterion(deep2, target)
        loss_deep3 = self.criterion(deep3, target)
        total_loss = (
            self.main_weight * loss_main +
            self.deep2_weight * loss_deep2 +
            self.deep3_weight * loss_deep3
        )
        return total_loss

# -------------------------
# Evaluation Metrics
# -------------------------
def calculate_iou(pred_mask, gt_mask):
    pred_mask = (pred_mask > 0).cpu().numpy().astype(bool)
    gt_mask = (gt_mask > 0).cpu().numpy().astype(bool)
    intersection = np.logical_and(pred_mask, gt_mask).sum()
    union = np.logical_or(pred_mask, gt_mask).sum()
    return intersection / union if union != 0 else 1.0

def calculate_dice(pred_mask, gt_mask):
    pred_mask = (pred_mask > 0).cpu().numpy().astype(bool)
    gt_mask = (gt_mask > 0).cpu().numpy().astype(bool)
    intersection = np.logical_and(pred_mask, gt_mask).sum()
    sum_areas = pred_mask.sum() + gt_mask.sum()
    return 2.0 * intersection / sum_areas if sum_areas != 0 else 1.0

# -------------------------
# Enhanced PolyGen2021 Dataset Class
# -------------------------
class PolyGen2021Dataset(Dataset):
    def __init__(self, images_dir, masks_dir, split='train', transform=None, augment=True, img_size=256,
                 split_ratio=(0.7, 0.15, 0.15), seed=42):
        """
        Enhanced dataset class for PolyGen2021

        Args:
            images_dir: Path to images directory
            masks_dir: Path to masks directory
            split: 'train', 'val', or 'test'
            transform: Image transformations
            augment: Whether to use data augmentation
            img_size: Target image size
            split_ratio: Train/val/test split ratio as tuple (train, val, test)
            seed: Random seed for reproducibility
        """
        self.transform = transform
        self.augment = augment and split == 'train'
        self.img_size = img_size
        self.images_dir = images_dir
        self.masks_dir = masks_dir

        if not os.path.exists(self.images_dir):
            raise ValueError(f"Images directory not found: {self.images_dir}")
        if not os.path.exists(self.masks_dir):
            raise ValueError(f"Masks directory not found: {self.masks_dir}")

        # Get all valid image files with verified mask pairs
        self.image_mask_pairs = []

        potential_images = sorted([f for f in os.listdir(self.images_dir) if f.endswith(('.jpg', '.png', '.jpeg'))])

        print(f"Finding valid image-mask pairs...")
        for img_name in tqdm(potential_images, desc="Checking image-mask pairs"):
            base_name = os.path.splitext(img_name)[0]
            mask_found = False
            mask_path = None

            # Check for exact filename match with different extensions
            for ext in ['.jpg', '.png', '.jpeg', '.tif']:
                candidate = os.path.join(self.masks_dir, base_name + ext)
                if os.path.exists(candidate):
                    mask_found = True
                    mask_path = candidate
                    break

            # Try to find mask by substring matching if exact match fails
            if not mask_found:
                for mask_file in os.listdir(self.masks_dir):
                    # Check if base name is contained in mask filename or vice versa
                    mask_base = os.path.splitext(mask_file)[0]
                    if base_name in mask_base or mask_base in base_name:
                        mask_path = os.path.join(self.masks_dir, mask_file)
                        mask_found = True
                        break

            # If a matching mask is found, add to valid pairs
            if mask_found:
                self.image_mask_pairs.append((img_name, os.path.basename(mask_path)))

        if not self.image_mask_pairs:
            raise ValueError(f"No valid image-mask pairs found. Check that images and masks directories are correct.")

        print(f"Found {len(self.image_mask_pairs)} valid image-mask pairs")

        # Create train/val/test split based on the split_ratio
        np.random.seed(seed)
        indices = list(range(len(self.image_mask_pairs)))
        np.random.shuffle(indices)

        train_end = int(len(indices) * split_ratio[0])
        val_end = int(len(indices) * (split_ratio[0] + split_ratio[1]))

        if split == 'train':
            self.indices = indices[:train_end]
        elif split == 'val':
            self.indices = indices[train_end:val_end]
        else:  # test
            self.indices = indices[val_end:]

        print(f"Created {split} dataset with {len(self.indices)} image-mask pairs")

        # Initialize augmentations
        self.basic_transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        img_name, mask_name = self.image_mask_pairs[self.indices[idx]]
        img_path = os.path.join(self.images_dir, img_name)
        mask_path = os.path.join(self.masks_dir, mask_name)

        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        # Apply augmentations for training with controlled randomness
        if self.augment:
            # Spatial transforms
            if random.random() > 0.5:
                image = TF.hflip(image)
                mask = TF.hflip(mask)

            if random.random() > 0.5:
                image = TF.vflip(image)
                mask = TF.vflip(mask)

            # Rotation: More controlled rotation for medical images
            if random.random() > 0.5:
                angle = random.choice([0, 90, 180, 270])
                image = TF.rotate(image, angle, fill=0)
                mask = TF.rotate(mask, angle, fill=0)

            # Intensity transformations
            if random.random() > 0.5:
                brightness = random.uniform(0.85, 1.15)
                contrast = random.uniform(0.85, 1.15)
                saturation = random.uniform(0.85, 1.15)

                image = TF.adjust_brightness(image, brightness)
                image = TF.adjust_contrast(image, contrast)
                image = TF.adjust_saturation(image, saturation)

            # Color jitter
            if random.random() > 0.7:
                color_jitter = transforms.ColorJitter(0.1, 0.1, 0.1, 0.05)
                image = color_jitter(image)

        # Apply basic transforms
        image = self.basic_transform(image)

        # Process mask
        mask = TF.resize(mask, (self.img_size, self.img_size), interpolation=TF.InterpolationMode.NEAREST)
        mask_array = np.array(mask)
        mask_binary = (mask_array > 127).astype(np.int64)  # Threshold for binary mask
        mask = torch.from_numpy(mask_binary).long()

        # Apply normalization only after all other transforms
        if self.transform:
            image = self.transform(image)

        return image, mask

# -------------------------
# Dataset Utility Functions
# -------------------------
def explore_dataset_structure(base_dir):
    """Explore the dataset structure to determine the actual paths"""
    print(f"Exploring dataset structure at {base_dir}...")

    if not os.path.exists(base_dir):
        print(f"Base directory {base_dir} does not exist")
        return None, None

    # Print the directory structure to understand what we're working with
    print("Directory structure:")

    found_images_dir = None
    found_masks_dir = None

    # Level 1 directories
    for root_item in os.listdir(base_dir):
        root_path = os.path.join(base_dir, root_item)
        if os.path.isdir(root_path):
            print(f"- {root_item}/")

            # Level 2 directories
            for sub_item in os.listdir(root_path):
                sub_path = os.path.join(root_path, sub_item)
                if os.path.isdir(sub_path):
                    print(f"  - {sub_item}/")

                    # Check for 'images' and 'masks' directories
                    if sub_item == 'images' and not found_images_dir:
                        found_images_dir = sub_path
                        print(f"    Found images directory: {found_images_dir}")

                    elif sub_item == 'masks' and not found_masks_dir:
                        found_masks_dir = sub_path
                        print(f"    Found masks directory: {found_masks_dir}")

                    # Level 3 directories (for nested structures)
                    else:
                        for third_item in os.listdir(sub_path):
                            third_path = os.path.join(sub_path, third_item)
                            if os.path.isdir(third_path):
                                print(f"    - {third_item}/")

                                # Check for deeply nested image/mask directories
                                if third_item == 'images' and not found_images_dir:
                                    found_images_dir = third_path
                                    print(f"      Found images directory: {found_images_dir}")

                                elif third_item == 'masks' and not found_masks_dir:
                                    found_masks_dir = third_path
                                    print(f"      Found masks directory: {found_masks_dir}")

    # Check if we found special folder structure from error message
    if not found_images_dir and os.path.exists(os.path.join(base_dir, "PolypGen2021_MultiCenterData_v3", "positive", "images")):
        found_images_dir = os.path.join(base_dir, "PolypGen2021_MultiCenterData_v3", "positive", "images")
        print(f"Found images directory: {found_images_dir}")

    if not found_masks_dir and os.path.exists(os.path.join(base_dir, "PolypGen2021_MultiCenterData_v3", "positive", "masks")):
        found_masks_dir = os.path.join(base_dir, "PolypGen2021_MultiCenterData_v3", "positive", "masks")
        print(f"Found masks directory: {found_masks_dir}")

    if found_images_dir:
        # Check image count
        image_files = [f for f in os.listdir(found_images_dir) if f.endswith(('.jpg', '.png', '.jpeg'))]
        print(f"Found {len(image_files)} images in {found_images_dir}")
        if image_files:
            print(f"Sample image names: {image_files[:3]}")

    if found_masks_dir:
        # Check mask count
        mask_files = [f for f in os.listdir(found_masks_dir) if f.endswith(('.jpg', '.png', '.jpeg', '.tif'))]
        print(f"Found {len(mask_files)} masks in {found_masks_dir}")
        if mask_files:
            print(f"Sample mask names: {mask_files[:3]}")

    return found_images_dir, found_masks_dir

def find_dataset_paths():
    """Find the correct matching image and mask paths for PolyGen2021"""
    kaggle_input_path = "/kaggle/input/polypgen2021"

    if os.path.exists(kaggle_input_path):
        print(f"Found PolyGen2021 dataset at {kaggle_input_path}")

        # Based on the error message, we know the structure of the dataset
        if os.path.exists("/kaggle/input/polypgen2021/PolypGen2021_MultiCenterData_v3"):
            # For polyp segmentation, we need the positive directory (images with polyps and their masks)
            images_dir = "/kaggle/input/polypgen2021/PolypGen2021_MultiCenterData_v3/positive/images"
            masks_dir = "/kaggle/input/polypgen2021/PolypGen2021_MultiCenterData_v3/positive/masks"

            if os.path.exists(images_dir) and os.path.exists(masks_dir):
                print(f"Using positive images from: {images_dir}")
                print(f"Using masks from: {masks_dir}")

                # Verify some image-mask pairs exist
                image_files = sorted([f for f in os.listdir(images_dir) if f.endswith(('.jpg', '.png', '.jpeg'))])
                mask_files = sorted([f for f in os.listdir(masks_dir) if f.endswith(('.jpg', '.png', '.jpeg'))])

                print(f"Found {len(image_files)} positive images")
                print(f"Found {len(mask_files)} masks")

                if image_files and mask_files:
                    # Show some sample names to help with debugging
                    print(f"Sample image names: {image_files[:3]}")
                    print(f"Sample mask names: {mask_files[:3]}")

                return images_dir, masks_dir
            else:
                if not os.path.exists(images_dir):
                    print(f"WARNING: Positive images directory not found at {images_dir}")
                if not os.path.exists(masks_dir):
                    print(f"WARNING: Masks directory not found at {masks_dir}")

        # If the specific structure wasn't found, explore the dataset
        print("Exploring dataset structure to find matching image and mask directories...")
        for root, dirs, files in os.walk(kaggle_input_path):
            if os.path.basename(root) == "images" and "positive" in root:
                potential_images_dir = root
                # Look for a parallel masks directory
                potential_masks_dir = os.path.join(os.path.dirname(root), "masks")
                if os.path.exists(potential_masks_dir):
                    print(f"Found matching directories:")
                    print(f"Images: {potential_images_dir}")
                    print(f"Masks: {potential_masks_dir}")
                    return potential_images_dir, potential_masks_dir

    # If automated finding fails, ask for manual input
    print("Could not find matching image and mask directories automatically.")
    print("Please enter the paths manually:")

    print("\nAvailable dataset contents:")
    for root, dirs, files in os.walk("/kaggle/input/polypgen2021", topdown=True, maxdepth=3):
        print(f"- {root}")
        if len(files) > 0:
            print(f"  ({len(files)} files, first few: {files[:3]})")

    images_dir = input("Enter the path to the POSITIVE images directory: ")
    masks_dir = input("Enter the path to the masks directory: ")

    if os.path.exists(images_dir) and os.path.exists(masks_dir):
        return images_dir, masks_dir
    else:
        raise ValueError("Invalid directory paths provided")

# -------------------------
# Enhanced Data Loaders
# -------------------------
def get_data_loaders(images_dir, masks_dir, batch_size=8, img_size=256,
                     split_ratio=(0.7, 0.15, 0.15), num_workers=4, seed=42):
    """
    Create data loaders for train, validation, and test sets

    Args:
        images_dir: Directory containing images
        masks_dir: Directory containing masks
        batch_size: Batch size for training
        img_size: Size of input images
        split_ratio: Tuple of (train, val, test) split ratios
        num_workers: Number of worker processes for data loading
        seed: Random seed for reproducibility
    """
    # Define normalization pipeline
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Create datasets
    train_dataset = PolyGen2021Dataset(
        images_dir=images_dir,
        masks_dir=masks_dir,
        split='train',
        transform=normalize,
        augment=True,
        img_size=img_size,
        split_ratio=split_ratio,
        seed=seed
    )

    val_dataset = PolyGen2021Dataset(
        images_dir=images_dir,
        masks_dir=masks_dir,
        split='val',
        transform=normalize,
        augment=False,
        img_size=img_size,
        split_ratio=split_ratio,
        seed=seed
    )

    test_dataset = PolyGen2021Dataset(
        images_dir=images_dir,
        masks_dir=masks_dir,
        split='test',
        transform=normalize,
        augment=False,
        img_size=img_size,
        split_ratio=split_ratio,
        seed=seed
    )

    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True  # Prevent issues with batch norm on small last batch
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    return train_loader, val_loader, test_loader, (len(train_dataset), len(val_dataset), len(test_dataset))

In [None]:
# -------------------------
# Enhanced Training Function
# -------------------------
def train_model(model, train_loader, val_loader, num_epochs=200,
                learning_rate=1.8e-4, device='cuda', checkpoint_dir='/kaggle/working/checkpoints',
                patience=15, scheduler_type='cosine'):
    """
    Enhanced training function with various optimizations

    Args:
        model: Model to train
        train_loader: Training data loader
        val_loader: Validation data loader
        num_epochs: Number of epochs to train for
        learning_rate: Initial learning rate
        device: Device to train on
        checkpoint_dir: Directory to save checkpoints
        patience: Early stopping patience
        scheduler_type: Type of learning rate scheduler to use
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    model = model.to(device)

    # Combined loss with dice and focal loss components
    criterion = CombinedLoss(bce_weight=0.3, dice_weight=0.7, focal_weight=0.3)
    deep_criterion = DeepSupervisionLoss(main_weight=0.8, deep2_weight=0.1, deep3_weight=0.1)

    # Optimizer with weight decay and grad clipping
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=1e-5,
        betas=(0.9, 0.95)
    )

    # Learning rate scheduler
    if scheduler_type == 'cosine':
        def lr_lambda(epoch):
            if epoch < 10:  # Warm-up period
                return 0.1 + 0.9 * (epoch / 10.0)
            else:  # Cosine decay
                return 0.5 * (1.0 + math.cos(math.pi * (epoch - 10) / (num_epochs - 10)))

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
    elif scheduler_type == 'plateau':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', factor=0.5, patience=5, min_lr=1e-6
        )
    else:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=10, T_mult=2, eta_min=1e-6
        )

    # Mixed precision training
    scaler = GradScaler('cuda')

    # Training tracking
    best_iou = 0.0
    best_dice = 0.0
    best_epoch = 0
    early_stop_counter = 0
    history = {
        'train_losses': [],
        'val_losses': [],
        'val_ious': [],
        'val_dices': [],
        'lr': []
    }

    # Training loop with progress bar
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_iou = 0.0
        train_dice = 0.0
        train_batches = 0

        # Progress bar for training
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")

        for batch_idx, (images, masks) in enumerate(progress_bar):
            images, masks = images.to(device, non_blocking=True), masks.to(device, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)  # More efficient than zero_grad()

            with autocast('cuda'):
                if epoch < 15:
                    outputs = model(images, return_deep=False)
                    loss = criterion(outputs, masks)
                else:
                    outputs = model(images, return_deep=True)
                    if epoch < 25:
                        deep_criterion.main_weight = 0.9
                        deep_criterion.deep2_weight = 0.05
                        deep_criterion.deep3_weight = 0.05
                    else:
                        deep_criterion.main_weight = 0.8
                        deep_criterion.deep2_weight = 0.1
                        deep_criterion.deep3_weight = 0.1
                    loss = deep_criterion(outputs, masks)

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            # Update metrics
            train_loss += loss.item()

            # Calculate training metrics occasionally for progress monitoring
            if batch_idx % 20 == 0:
                with torch.no_grad():
                    main_output = outputs[0] if isinstance(outputs, tuple) else outputs
                    predictions = torch.argmax(main_output, dim=1)
                    batch_iou = 0.0
                    batch_dice = 0.0

                    for i in range(min(4, images.size(0))):  # Check just a few samples to save time
                        pred_mask = predictions[i]
                        gt_mask = masks[i]
                        batch_iou += calculate_iou(pred_mask, gt_mask)
                        batch_dice += calculate_dice(pred_mask, gt_mask)

                    batch_iou /= min(4, images.size(0))
                    batch_dice /= min(4, images.size(0))
                    train_iou += batch_iou
                    train_dice += batch_dice
                    train_batches += 1

            # Update progress bar
            progress_bar.set_postfix({
                'loss': f"{loss.item():.4f}",
                'lr': f"{optimizer.param_groups[0]['lr']:.6f}"
            })

            if batch_idx % 5 == 0:
                torch.cuda.empty_cache()

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_iou = 0.0
        val_dice = 0.0
        num_val_samples = 0

        # Progress bar for validation
        val_progress = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]")

        with torch.no_grad():
            for images, masks in val_progress:
                images, masks = images.to(device), masks.to(device)
                outputs = model(images, return_deep=(epoch >= 15))

                if epoch < 15:
                    loss = criterion(outputs, masks)
                else:
                    loss = deep_criterion(outputs, masks)

                val_loss += loss.item()
                main_output = outputs[0] if isinstance(outputs, tuple) else outputs
                predictions = torch.argmax(main_output, dim=1)

                for i in range(images.size(0)):
                    pred_mask = predictions[i]
                    gt_mask = masks[i]
                    iou = calculate_iou(pred_mask, gt_mask)
                    dice = calculate_dice(pred_mask, gt_mask)
                    val_iou += iou
                    val_dice += dice
                    num_val_samples += 1

                val_progress.set_postfix({'loss': f"{loss.item():.4f}"})

        # Calculate epoch metrics
        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        avg_val_iou = val_iou / num_val_samples
        avg_val_dice = val_dice / num_val_samples

        if train_batches > 0:
            avg_train_iou = train_iou / train_batches
            avg_train_dice = train_dice / train_batches
        else:
            avg_train_iou = 0
            avg_train_dice = 0

        # Update history
        history['train_losses'].append(avg_train_loss)
        history['val_losses'].append(avg_val_loss)
        history['val_ious'].append(avg_val_iou)
        history['val_dices'].append(avg_val_dice)
        history['lr'].append(optimizer.param_groups[0]['lr'])

        # Update scheduler
        if scheduler_type == 'plateau':
            scheduler.step(avg_val_iou)  # Step with validation IoU for plateau scheduler
        else:
            scheduler.step()

        # Print epoch summary
        print(f'\nEpoch {epoch+1}/{num_epochs}:')
        print(f'  Train Loss: {avg_train_loss:.4f}, IoU: {avg_train_iou:.4f}, Dice: {avg_train_dice:.4f}')
        print(f'  Val Loss: {avg_val_loss:.4f}, IoU: {avg_val_iou:.4f}, Dice: {avg_val_dice:.4f}')
        print(f'  Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}')

        # Check for best model (using IoU)
        if avg_val_iou > best_iou:
            best_iou = avg_val_iou
            best_dice = avg_val_dice
            best_epoch = epoch
            early_stop_counter = 0

            # Save best model
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_iou': best_iou,
                'best_dice': best_dice,
                'history': history
            }, os.path.join(checkpoint_dir, 'best_model_polygen.pth'))
            print(f'  New best model saved with IoU: {best_iou:.4f}, Dice: {best_dice:.4f}')
        else:
            early_stop_counter += 1
            print(f'  No improvement for {early_stop_counter} epochs. Best IoU: {best_iou:.4f} at epoch {best_epoch+1}')

            # Save regular checkpoint every 10 epochs
            if epoch % 10 == 0:
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'best_iou': best_iou,
                    'history': history
                }, os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth'))

        # Plot progress every 10 epochs
        if epoch % 10 == 9 or epoch == num_epochs - 1:
            plot_training_progress(history, save_path=os.path.join(checkpoint_dir, f'progress_epoch_{epoch+1}.png'))

        # Early stopping
        if early_stop_counter >= patience:
            print(f"Early stopping after {patience} epochs without improvement")
            break

        print('-' * 50)

    print(f"Training complete! Best validation IoU: {best_iou:.4f}, Dice: {best_dice:.4f} at epoch {best_epoch+1}")
    return history, best_iou, best_dice, best_epoch

# -------------------------
# Enhanced Visualization Functions
# -------------------------
def visualize_predictions(model, test_loader, device='cuda', num_samples=5, save_path=None):
    """Enhanced prediction visualization with overlay and error analysis"""
    model.eval()
    all_images = []
    all_masks = []
    all_preds = []

    # Get samples
    with torch.no_grad():
        for images, masks in test_loader:
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images, return_deep=False)
            predictions = torch.argmax(outputs, dim=1)

            # Add to collections
            all_images.extend(images.cpu())
            all_masks.extend(masks.cpu())
            all_preds.extend(predictions.cpu())

            if len(all_images) >= num_samples:
                break

    # Only take the requested number of samples
    all_images = all_images[:num_samples]
    all_masks = all_masks[:num_samples]
    all_preds = all_preds[:num_samples]

    # Denormalize images
    mean = torch.tensor([0.485, 0.456, 0.406])
    std = torch.tensor([0.229, 0.224, 0.225])

    # Create visualization grid
    fig, axes = plt.subplots(num_samples, 4, figsize=(16, 4*num_samples))

    # If only one sample, wrap axes in a list to make indexing consistent
    if num_samples == 1:
        axes = np.array([axes])

    for i in range(num_samples):
        # Get image, mask, and prediction
        img = all_images[i] * std.view(3, 1, 1) + mean.view(3, 1, 1)
        img = torch.clamp(img, 0, 1).permute(1, 2, 0).numpy()

        mask = all_masks[i].numpy()
        pred = all_preds[i].numpy()

        # Calculate metrics
        iou_score = calculate_iou(all_preds[i], all_masks[i])
        dice_score = calculate_dice(all_preds[i], all_masks[i])

        # Create overlay for better visualization
        # Green: True Positive, Blue: False Negative, Red: False Positive
        overlay = np.zeros((*img.shape[:2], 4))

        # True positive (both pred and gt are 1) - green
        true_positive = np.logical_and(pred == 1, mask == 1)
        overlay[true_positive] = [0, 1, 0, 0.5]  # Semi-transparent green

        # False positive (pred is 1 but gt is 0) - red
        false_positive = np.logical_and(pred == 1, mask == 0)
        overlay[false_positive] = [1, 0, 0, 0.5]  # Semi-transparent red

        # False negative (pred is 0 but gt is 1) - blue
        false_negative = np.logical_and(pred == 0, mask == 1)
        overlay[false_negative] = [0, 0, 1, 0.5]  # Semi-transparent blue

        # Plot original image
        axes[i, 0].imshow(img)
        axes[i, 0].set_title('Original Image')
        axes[i, 0].axis('off')

        # Plot ground truth
        axes[i, 1].imshow(mask, cmap='gray')
        axes[i, 1].set_title('Ground Truth')
        axes[i, 1].axis('off')

        # Plot prediction
        axes[i, 2].imshow(pred, cmap='gray')
        axes[i, 2].set_title('Prediction')
        axes[i, 2].axis('off')

        # Plot overlay on original image
        axes[i, 3].imshow(img)
        axes[i, 3].imshow(overlay)
        axes[i, 3].set_title(f'Overlay (IoU: {iou_score:.4f}, Dice: {dice_score:.4f})')
        axes[i, 3].axis('off')

        # Add legend to the overlay
        legend_elements = [
            plt.Rectangle((0, 0), 1, 1, color=(0, 1, 0, 0.5), label='True Positive'),
            plt.Rectangle((0, 0), 1, 1, color=(1, 0, 0, 0.5), label='False Positive'),
            plt.Rectangle((0, 0), 1, 1, color=(0, 0, 1, 0.5), label='False Negative')
        ]

        # Place legend outside the last subplot
        if i == 0:  # Only add legend to the first row
            axes[i, 3].legend(handles=legend_elements, bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()

    return all_images, all_masks, all_preds

def visualize_contours(model, test_loader, device='cuda', num_samples=5, save_path=None):
    """Visualize contour extraction and feature maps"""
    model.eval()
    images, masks = next(iter(test_loader))
    images = images.to(device)[:num_samples]

    # Extract intermediate activations
    with torch.no_grad():
        # Get encoder features
        enc1 = model.encoder1(images)
        enc1_pool = model.pool1(enc1)
        enc2 = model.encoder2(enc1_pool)
        enc2_pool = model.pool2(enc2)
        enc3 = model.encoder3(enc2_pool)
        enc3_pool = model.pool3(enc3)

        # Get contour features
        bridge_out = model.bridge_down(enc3_pool)
        b, c, h, w = bridge_out.size()
        mamba_input = bridge_out.permute(0, 2, 3, 1).reshape(b, h * w, c)

        # Just visualize the first mamba block's contour extractor
        contour_spatial = model.mamba_blocks[0].contour_extractor(
            mamba_input.reshape(b, h, w, c).permute(0, 3, 1, 2)
        )

    # Prepare visualization
    fig, axes = plt.subplots(num_samples, 4, figsize=(16, 4*num_samples))

    # If only one sample, wrap axes in a list to make indexing consistent
    if num_samples == 1:
        axes = np.array([axes])

    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    for i in range(num_samples):
        # Original image
        img = images[i].cpu().permute(1, 2, 0).numpy()
        img = img * std + mean
        img = np.clip(img, 0, 1)

        # Encoder features (high level)
        enc_features = enc3[i].mean(dim=0).cpu().numpy()
        enc_features = (enc_features - enc_features.min()) / (enc_features.max() - enc_features.min() + 1e-8)

        # Contour features
        contour = contour_spatial[i].mean(dim=0).cpu().numpy()
        contour = (contour - contour.min()) / (contour.max() - contour.min() + 1e-8)

        # Heatmap overlay
        heatmap = cv2.applyColorMap(np.uint8(255 * contour), cv2.COLORMAP_JET)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) / 255.0

        # Resize heatmap to match original image size
        heatmap_resized = cv2.resize(heatmap, (img.shape[1], img.shape[0]))

        # Create blend of original image and heatmap
        blend = 0.7 * img + 0.3 * heatmap_resized

        # Plot each visualization
        axes[i, 0].imshow(img)
        axes[i, 0].set_title('Input Image')
        axes[i, 0].axis('off')

        axes[i, 1].imshow(enc_features, cmap='viridis')
        axes[i, 1].set_title('Encoder Features')
        axes[i, 1].axis('off')

        axes[i, 2].imshow(contour, cmap='magma')
        axes[i, 2].set_title('Contour Features')
        axes[i, 2].axis('off')

        axes[i, 3].imshow(blend)
        axes[i, 3].set_title('Contour Heatmap Overlay')
        axes[i, 3].axis('off')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()

def plot_training_progress(history, save_path=None, figsize=(18, 6)):
    """Plot detailed training progress with multiple metrics"""
    plt.figure(figsize=figsize)

    # Plot training and validation loss
    plt.subplot(1, 3, 1)
    plt.plot(history['train_losses'], label='Train Loss')
    plt.plot(history['val_losses'], label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Plot validation IoU
    plt.subplot(1, 3, 2)
    plt.plot(history['val_ious'], label='IoU')
    if 'val_dices' in history:
        plt.plot(history['val_dices'], label='Dice')
    plt.xlabel('Epoch')
    plt.ylabel('Metric Value')
    plt.title('Validation Metrics')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Plot learning rate
    plt.subplot(1, 3, 3)
    plt.plot(history['lr'])
    plt.xlabel('Epoch')
    plt.ylabel('Learning Rate')
    plt.title('Learning Rate Schedule')
    plt.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=200, bbox_inches='tight')

    plt.show()

# -------------------------
# Enhanced Evaluation Metrics
# -------------------------
class CombinedLoss(nn.Module):
    def __init__(self, bce_weight=0.3, dice_weight=0.5, focal_weight=0.2, gamma=2.0):
        """
        Combined loss function with BCE, Dice, and Focal loss components

        Args:
            bce_weight: Weight for binary cross-entropy loss
            dice_weight: Weight for dice loss
            focal_weight: Weight for focal loss
            gamma: Focal loss focusing parameter
        """
        super().__init__()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.focal_weight = focal_weight
        self.gamma = gamma
        self.bce_loss = nn.CrossEntropyLoss(reduction='mean')

    def forward(self, inputs, targets):
        # BCE Loss
        bce = self.bce_loss(inputs, targets)

        # Dice Loss
        inputs_soft = F.softmax(inputs, dim=1)
        targets_one_hot = F.one_hot(targets, num_classes=inputs.shape[1]).permute(0, 3, 1, 2).float()

        intersection = (inputs_soft * targets_one_hot).sum(dim=(2, 3))
        cardinality = inputs_soft.sum(dim=(2, 3)) + targets_one_hot.sum(dim=(2, 3))
        dice = (2. * intersection / (cardinality + 1e-6)).mean()
        dice_loss = 1 - dice

        # Focal Loss (for hard examples)
        if self.focal_weight > 0:
            # Calculate focal factor
            pt = torch.exp(-bce)
            focal_factor = (1 - pt) ** self.gamma
            focal_loss = focal_factor * bce

            # Return combined loss
            return self.bce_weight * bce + self.dice_weight * dice_loss + self.focal_weight * focal_loss.mean()
        else:
            return self.bce_weight * bce + self.dice_weight * dice_loss

# Enhanced evaluation metrics and test function
def calculate_metrics(pred_mask, gt_mask):
    """Calculate multiple evaluation metrics"""
    pred_mask = pred_mask.cpu().numpy().astype(bool)
    gt_mask = gt_mask.cpu().numpy().astype(bool)

    # Intersection and union
    intersection = np.logical_and(pred_mask, gt_mask).sum()
    union = np.logical_or(pred_mask, gt_mask).sum()

    # Pixel counts
    gt_pixels = gt_mask.sum()
    pred_pixels = pred_mask.sum()

    # Calculate metrics
    iou = intersection / union if union > 0 else 1.0
    dice = 2 * intersection / (gt_pixels + pred_pixels) if (gt_pixels + pred_pixels) > 0 else 1.0

    # Precision and recall
    precision = intersection / pred_pixels if pred_pixels > 0 else 1.0
    recall = intersection / gt_pixels if gt_pixels > 0 else 1.0

    # F1 score (equivalent to Dice but calculated differently)
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    return {
        'iou': float(iou),
        'dice': float(dice),
        'precision': float(precision),
        'recall': float(recall),
        'f1': float(f1)
    }

def test_model_comprehensive(model, test_loader, device='cuda', use_tta=True, use_crf=False):
    """Comprehensive model evaluation with TTA and detailed metrics"""
    model.eval()
    all_metrics = []
    all_metrics_by_size = {'small': [], 'medium': [], 'large': []}

    # For ROC curve
    all_probs = []
    all_gt = []

    progress_bar = tqdm(test_loader, desc="Testing")

    with torch.no_grad():
        for images, masks in progress_bar:
            images, masks = images.to(device), masks.to(device)

            # Standard prediction
            outputs = model(images, return_deep=False)

            # Test-time augmentation
            if use_tta:
                # Horizontal flip
                images_hflip = torch.flip(images, [3])
                outputs_hflip = model(images_hflip, return_deep=False)
                outputs_hflip = torch.flip(outputs_hflip, [3])

                # Vertical flip
                images_vflip = torch.flip(images, [2])
                outputs_vflip = model(images_vflip, return_deep=False)
                outputs_vflip = torch.flip(outputs_vflip, [2])

                # Combined prediction
                outputs = (outputs + outputs_hflip + outputs_vflip) / 3

            # Store probability maps for ROC analysis
            probs = F.softmax(outputs, dim=1)[:, 1].cpu().numpy().flatten()
            all_probs.extend(probs)

            # Get predictions
            predictions = torch.argmax(outputs, dim=1)

            for i in range(images.size(0)):
                pred_mask = predictions[i]
                gt_mask = masks[i]

                # Store ground truth for ROC analysis
                all_gt.extend(gt_mask.cpu().numpy().flatten())

                # Apply CRF post-processing if requested
                if use_crf:
                    try:
                        import pydensecrf.densecrf as dcrf
                        from pydensecrf.utils import unary_from_softmax

                        # Create CRF input
                        image = images[i].cpu().numpy().transpose(1, 2, 0)
                        image = (image * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406]))
                        image = (image * 255).astype(np.uint8)

                        # Get probability map
                        prob = F.softmax(outputs[i], dim=0).cpu().numpy()

                        # Create CRF model
                        d = dcrf.DenseCRF2D(image.shape[1], image.shape[0], 2)
                        U = unary_from_softmax(prob)
                        d.setUnaryEnergy(U)

                        # Add pairwise terms
                        d.addPairwiseGaussian(sxy=3, compat=3)
                        d.addPairwiseBilateral(sxy=80, srgb=13, rgbim=image, compat=10)

                        # Perform inference
                        Q = d.inference(5)
                        pred_mask = torch.from_numpy(np.argmax(Q, axis=0).reshape(gt_mask.shape)).to(device)
                    except ImportError:
                        print("Warning: pydensecrf not available. Skipping CRF post-processing.")

                # Calculate comprehensive metrics
                metrics = calculate_metrics(pred_mask, gt_mask)
                all_metrics.append(metrics)

                # Categorize by polyp size
                gt_size = gt_mask.sum().item() / (gt_mask.shape[0] * gt_mask.shape[1])

                if gt_size < 0.05:  # Small polyps (less than 5% of image)
                    all_metrics_by_size['small'].append(metrics)
                elif gt_size < 0.15:  # Medium polyps (5-15% of image)
                    all_metrics_by_size['medium'].append(metrics)
                else:  # Large polyps (>15% of image)
                    all_metrics_by_size['large'].append(metrics)

    # Calculate average metrics
    avg_metrics = {}
    for key in all_metrics[0].keys():
        avg_metrics[key] = np.mean([m[key] for m in all_metrics])
        avg_metrics[f"{key}_std"] = np.std([m[key] for m in all_metrics])

    # Calculate average metrics by size
    avg_metrics_by_size = {}
    for size, metrics_list in all_metrics_by_size.items():
        if metrics_list:  # Only calculate if there are samples in this category
            avg_metrics_by_size[size] = {}
            for key in metrics_list[0].keys():
                avg_metrics_by_size[size][key] = np.mean([m[key] for m in metrics_list])
                avg_metrics_by_size[size][f"{key}_std"] = np.std([m[key] for m in metrics_list])
            avg_metrics_by_size[size]['count'] = len(metrics_list)

    # Print summary
    print("\nTest Results:")
    print(f"Number of samples: {len(all_metrics)}")
    print(f"IoU: {avg_metrics['iou']:.4f} ± {avg_metrics['iou_std']:.4f}")
    print(f"Dice: {avg_metrics['dice']:.4f} ± {avg_metrics['dice_std']:.4f}")
    print(f"Precision: {avg_metrics['precision']:.4f} ± {avg_metrics['precision_std']:.4f}")
    print(f"Recall: {avg_metrics['recall']:.4f} ± {avg_metrics['recall_std']:.4f}")
    print(f"F1 Score: {avg_metrics['f1']:.4f} ± {avg_metrics['f1_std']:.4f}")

    print("\nResults by polyp size:")
    for size, metrics in avg_metrics_by_size.items():
        if metrics:
            print(f"\n{size.capitalize()} polyps ({metrics['count']} samples):")
            print(f"  IoU: {metrics['iou']:.4f} ± {metrics['iou_std']:.4f}")
            print(f"  Dice: {metrics['dice']:.4f} ± {metrics['dice_std']:.4f}")

    # Calculate high IoU percentage
    high_iou_count = sum(1 for m in all_metrics if m['iou'] > 0.9)
    high_iou_percent = high_iou_count / len(all_metrics) * 100 if all_metrics else 0
    print(f"\nSamples with IoU > 0.9: {high_iou_count} ({high_iou_percent:.1f}%)")

    # Return all collected data for further analysis
    return {
        'avg_metrics': avg_metrics,
        'metrics_by_size': avg_metrics_by_size,
        'all_metrics': all_metrics,
        'all_probs': np.array(all_probs),
        'all_gt': np.array(all_gt)
    }

def plot_roc_curve(all_probs, all_gt, save_path=None):
    """Plot ROC curve from classification results"""
    from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score

    # Calculate ROC curve and area
    fpr, tpr, _ = roc_curve(all_gt, all_probs)
    roc_auc = auc(fpr, tpr)

    # Calculate PR curve and area
    precision, recall, _ = precision_recall_curve(all_gt, all_probs)
    pr_auc = average_precision_score(all_gt, all_probs)

    # Create figure with two subplots
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Plot ROC curve
    ax1.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.3f})')
    ax1.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    ax1.set_xlim([0.0, 1.0])
    ax1.set_ylim([0.0, 1.05])
    ax1.set_xlabel('False Positive Rate')
    ax1.set_ylabel('True Positive Rate')
    ax1.set_title('Receiver Operating Characteristic (ROC)')
    ax1.legend(loc="lower right")
    ax1.grid(True, alpha=0.3)

    # Plot PR curve
    ax2.plot(recall, precision, color='green', lw=2, label=f'PR curve (area = {pr_auc:.3f})')
    ax2.set_xlim([0.0, 1.0])
    ax2.set_ylim([0.0, 1.05])
    ax2.set_xlabel('Recall')
    ax2.set_ylabel('Precision')
    ax2.set_title('Precision-Recall Curve')
    ax2.legend(loc="lower left")
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()

    return roc_auc, pr_auc

In [None]:
# # -------------------------
# # Main Function
# # -------------------------
# def main():
#     # Setup directories
#     OUTPUT_DIR = "/kaggle/working/polygen_output"
#     CHECKPOINT_DIR = "/kaggle/working/polygen_checkpoints"
#     os.makedirs(OUTPUT_DIR, exist_ok=True)
#     os.makedirs(CHECKPOINT_DIR, exist_ok=True)

#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     print(f"Using device: {device}")

#     # Print GPU info if available
#     if torch.cuda.is_available():
#         print(f"GPU: {torch.cuda.get_device_name(0)}")
#         print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

#     # Set seeds for reproducibility
#     torch.manual_seed(42)
#     np.random.seed(42)
#     random.seed(42)
#     if torch.cuda.is_available():
#         torch.cuda.manual_seed(42)
#         torch.backends.cudnn.deterministic = True
#         torch.backends.cudnn.benchmark = False

#     print("Finding PolyGen2021 dataset in Kaggle environment...")
#     images_dir, masks_dir = find_dataset_paths()

#     if not images_dir or not masks_dir:
#         # Try with the known structure based on the error message
#         images_dir = "/kaggle/input/polypgen2021/PolypGen2021_MultiCenterData_v3/positive/images"
#         masks_dir = "/kaggle/input/polypgen2021/PolypGen2021_MultiCenterData_v3/positive/masks"

#     print(f"Using images from: {images_dir}")
#     print(f"Using masks from: {masks_dir}")

#     print("Creating data loaders...")
#     # Using batch size of 8 as requested with seed for reproducibility
#     train_loader, val_loader, test_loader, split_sizes = get_data_loaders(
#         images_dir, masks_dir, batch_size=8, img_size=256,
#         split_ratio=(0.7, 0.15, 0.15), num_workers=4, seed=42
#     )

#     print(f"Dataset split: Train={split_sizes[0]}, Val={split_sizes[1]}, Test={split_sizes[2]}")

#     print("Creating model...")
#     model = MKPEContourAwareMambaUNet(num_classes=2, d_model=128, d_state=32, num_mamba_layers=6)
#     model.gradient_checkpointing = True

#     total_params = sum(p.numel() for p in model.parameters())
#     trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
#     print(f"Total parameters: {total_params:,}")
#     print(f"Trainable parameters: {trainable_params:,}")

#     print("Starting training...")
#     # Using 200 epochs with learning rate from GOOD_mkpe_x
#     history, best_iou, best_dice, best_epoch = train_model(
#         model, train_loader, val_loader,
#         num_epochs=200,
#         learning_rate=1.8e-4,  # Matching the learning rate from GOOD_mkpe_x
#         device=device,
#         checkpoint_dir=CHECKPOINT_DIR,
#         patience=15,  # Early stopping patience
#         scheduler_type='cosine'  # Using cosine scheduler
#     )

#     # Load best model for testing
#     print("\nLoading best model for testing...")
#     checkpoint = torch.load(os.path.join(CHECKPOINT_DIR, 'best_model_polygen.pth'))
#     model.load_state_dict(checkpoint['model_state_dict'])

#     print("Performing comprehensive model testing...")
#     test_results = test_model_comprehensive(model, test_loader, device, use_tta=True)

#     # Plot ROC curve
#     print("\nGenerating ROC and PR curves...")
#     roc_auc, pr_auc = plot_roc_curve(
#         test_results['all_probs'],
#         test_results['all_gt'],
#         save_path=os.path.join(OUTPUT_DIR, 'roc_pr_curves.png')
#     )

#     print("\nGenerating visualizations...")
#     visualize_predictions(
#         model, test_loader, device, num_samples=8,
#         save_path=os.path.join(OUTPUT_DIR, 'prediction_visualization.png')
#     )

#     visualize_contours(
#         model, test_loader, device, num_samples=5,
#         save_path=os.path.join(OUTPUT_DIR, 'contour_visualization.png')
#     )

#     # Save detailed test metrics
#     metrics_summary = {
#         'overall': test_results['avg_metrics'],
#         'by_size': test_results['metrics_by_size'],
#         'roc_auc': roc_auc,
#         'pr_auc': pr_auc,
#         'best_model': {
#             'epoch': best_epoch,
#             'iou': best_iou,
#             'dice': best_dice
#         }
#     }

#     # Save metrics summary as JSON
#     import json
#     with open(os.path.join(OUTPUT_DIR, 'test_metrics.json'), 'w') as f:
#         # Convert numpy values to Python native types for JSON serialization
#         def convert_to_serializable(obj):
#             if isinstance(obj, (np.integer, np.floating, np.bool_)):
#                 return obj.item()
#             elif isinstance(obj, np.ndarray):
#                 return obj.tolist()
#             return obj

#         # Process the metrics dictionary to make it JSON serializable
#         serializable_metrics = json.loads(
#             json.dumps(metrics_summary, default=convert_to_serializable)
#         )
#         json.dump(serializable_metrics, f, indent=4)

#     # Final summary
#     print("\nTraining and evaluation complete!")
#     print(f"Best validation IoU: {best_iou:.4f} (Epoch {best_epoch+1})")
#     print(f"Best validation Dice: {best_dice:.4f}")
#     print(f"Final test IoU: {test_results['avg_metrics']['iou']:.4f}")
#     print(f"Final test Dice: {test_results['avg_metrics']['dice']:.4f}")
#     print(f"ROC AUC: {roc_auc:.4f}")
#     print(f"All results saved to {OUTPUT_DIR}")

# if __name__ == "__main__":
#     try:
#         main()
#     except KeyboardInterrupt:
#         print("Training interrupted by user")
#     except Exception as e:
#         import traceback
#         print(f"Error during training: {e}")
#         traceback.print_exc()

#         # Try to save a backup of the model if an error occurs
#         try:
#             if 'model' in locals() and 'optimizer' in locals():
#                 torch.save({
#                     'model_state_dict': model.state_dict(),
#                     'optimizer_state_dict': optimizer.state_dict(),
#                 }, '/kaggle/working/emergency_backup.pth')
#                 print("Emergency backup saved to /kaggle/working/emergency_backup.pth")
#         except:
#             print("Could not save emergency backup")

Using device: cuda
GPU: Tesla P100-PCIE-16GB
Memory: 17.06 GB
Finding PolyGen2021 dataset in Kaggle environment...
Found PolyGen2021 dataset at /kaggle/input/polypgen2021
Using positive images from: /kaggle/input/polypgen2021/PolypGen2021_MultiCenterData_v3/positive/images
Using masks from: /kaggle/input/polypgen2021/PolypGen2021_MultiCenterData_v3/positive/masks
Found 3762 positive images
Found 3762 masks
Sample image names: ['C1_100H0050.jpg', 'C1_100S0001.jpg', 'C1_100S0003.jpg']
Sample mask names: ['C1_100H0050.jpg', 'C1_100S0001.jpg', 'C1_100S0003.jpg']
Using images from: /kaggle/input/polypgen2021/PolypGen2021_MultiCenterData_v3/positive/images
Using masks from: /kaggle/input/polypgen2021/PolypGen2021_MultiCenterData_v3/positive/masks
Creating data loaders...
Finding valid image-mask pairs...


Checking image-mask pairs: 100%|██████████| 3762/3762 [00:07<00:00, 478.84it/s]


Found 3762 valid image-mask pairs
Created train dataset with 2633 image-mask pairs
Finding valid image-mask pairs...


Checking image-mask pairs: 100%|██████████| 3762/3762 [00:01<00:00, 2477.17it/s]


Found 3762 valid image-mask pairs
Created val dataset with 564 image-mask pairs
Finding valid image-mask pairs...


Checking image-mask pairs: 100%|██████████| 3762/3762 [00:00<00:00, 72604.08it/s]


Found 3762 valid image-mask pairs
Created test dataset with 565 image-mask pairs
Dataset split: Train=2633, Val=564, Test=565
Creating model...
Total parameters: 4,998,682
Trainable parameters: 4,998,682
Starting training...


Epoch 1/100 [Train]:  31%|███▏      | 103/329 [05:58<13:04,  3.47s/it, loss=0.6777, lr=0.000010]