# 🎯 Attention U-Net for Kidney Stone Segmentation on KSSD2025

## 📊 Objective
Beat the baseline Modified U-Net Dice score of **97.06%** using an Attention U-Net.

## 🎯 Expected Results
- **Target Dice Score:** 97.5% - 98.2%
- **Strategy:** Attention mechanisms for small object detection
- **Architecture:** U-Net + Attention Gates

## 📦 Step 1: Install & Import Required Libraries

In [None]:
# Install required packages.
# `!pip` is IPython shell-escape syntax, so this cell only runs inside a notebook.
# NOTE(review): segmentation-models-pytorch is not imported by any visible cell
# (only albumentations is used) — confirm the extra dependency is needed.
!pip install -q segmentation-models-pytorch albumentations

# The message below is printed unconditionally; pip's exit status is not checked.
print("‚úÖ Libraries installed successfully!")
print("="*50)

In [None]:
# Core Libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import cv2
import gc
import warnings
import os
from glob import glob
warnings.filterwarnings('ignore')

# Deep Learning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Sklearn
from sklearn.model_selection import KFold

# Image Processing
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Progress Bar
from tqdm.auto import tqdm

# Set random seeds for reproducibility
def set_seed(seed=42):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module (which albumentations transforms may draw
    from), NumPy, and PyTorch on CPU and all CUDA devices. It also switches
    cuDNN to deterministic kernels.

    Args:
        seed: Integer seed applied to every generator.
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Explicitly seed every visible GPU, not just the current one.
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Fix all RNGs before any data loading or model construction happens.
set_seed(42)

# Report the runtime environment once so logs show which backend was used.
_cuda_ok = torch.cuda.is_available()
print("‚úÖ All libraries imported successfully!")
print(f"‚úÖ PyTorch Version: {torch.__version__}")
print(f"‚úÖ CUDA Available: {_cuda_ok}")
if _cuda_ok:
    print(f"‚úÖ CUDA Device: {torch.cuda.get_device_name(0)}")
print("=" * 50)

## 📂 Step 2: Configure Dataset Paths

**Paths are set for the KSSD2025 layout, where images and masks live under a `/data` subdirectory.**

In [None]:
# Configuration
class Config:
    """Paths and hyper-parameters shared by every cell of this notebook.

    Every setting is a plain class attribute; the notebook works with an
    instance (``config = Config()``), whose attributes can be overridden
    individually.
    """

    # --- Dataset location: KSSD2025 keeps images/ and masks/ under a 'data' folder ---
    DATA_PATH = "/kaggle/input/kssd2025-kidney-stone-segmentation-dataset/data"
    IMAGE_DIR = os.path.join(DATA_PATH, "images")
    MASK_DIR = os.path.join(DATA_PATH, "masks")

    # --- Preprocessing: every image is resized to IMG_SIZE x IMG_SIZE ---
    IMG_SIZE = 256

    # --- Optimisation schedule ---
    BATCH_SIZE = 16
    NUM_EPOCHS = 150
    LEARNING_RATE = 0.001
    NUM_FOLDS = 5

    # --- Network widths ---
    # NOTE(review): AttentionUNet in this notebook hard-codes the same widths and
    # does not read these two lists.
    ENCODER_CHANNELS = [16, 32, 64, 128]
    DECODER_CHANNELS = [128, 64, 32, 16]

    # --- Compute device: first CUDA GPU when present, otherwise the CPU ---
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # --- Artefact locations (Kaggle's writable working directory) ---
    SAVE_MODELS = True
    MODEL_DIR = "/kaggle/working/models"
    OUTPUT_DIR = "/kaggle/working/outputs"

config = Config()

# Make sure the artefact folders exist before anything tries to write into them.
for _out_dir in (config.MODEL_DIR, config.OUTPUT_DIR):
    Path(_out_dir).mkdir(parents=True, exist_ok=True)

# (label, value) pairs echoed below, one "  label: value" line each.
_settings_report = [
    ("üìÅ Data Path", config.DATA_PATH),
    ("üìÅ Image Dir", config.IMAGE_DIR),
    ("üìÅ Mask Dir", config.MASK_DIR),
    ("üñºÔ∏è  Image Size", f"{config.IMG_SIZE}x{config.IMG_SIZE}"),
    ("üì¶ Batch Size", config.BATCH_SIZE),
    ("üîÑ Epochs", config.NUM_EPOCHS),
    ("üìä K-Folds", config.NUM_FOLDS),
    ("üéØ Device", config.DEVICE),
]

print("‚öôÔ∏è Configuration Settings:")
for _label, _value in _settings_report:
    print(f"  {_label}: {_value}")
print("=" * 50)

## 🔍 Step 3: Explore Dataset Structure

First, inspect the dataset's directory structure.

In [None]:
# Explore the dataset structure
base_path = "/kaggle/input/kssd2025-kidney-stone-segmentation-dataset"

print("üìÅ Exploring dataset structure...\n")

def explore_directory(path, level=0, max_level=2, max_items=20):
    """Print an indented tree view of ``path``.

    Directories are printed with their entry count; files are printed by name.

    Args:
        path: Directory to list. Missing paths and non-directories are ignored.
        level: Depth of ``path`` relative to the first call; sets the indentation.
        max_level: Deepest level that is listed. The default of 2 lists the top
            directory plus two nested levels. That is the depth the previous
            version produced, because it ignored this argument.
        max_items: Maximum entries printed per directory; extras are summarised.
    """
    if level > max_level or not os.path.isdir(path):
        return

    indent = "  " * level
    try:
        items = sorted(os.listdir(path))
    except OSError as err:
        # Unreadable directory (e.g. permissions): report it instead of aborting the walk.
        print(f"{indent}(cannot list {path}: {err})")
        return

    for item in items[:max_items]:
        item_path = os.path.join(path, item)
        if os.path.isdir(item_path):
            try:
                count = len(os.listdir(item_path))
            except OSError:
                print(f"{indent}üìÅ {item}/ (unreadable)")
                continue
            print(f"{indent}üìÅ {item}/ ({count} items)")
            if level < max_level:
                explore_directory(item_path, level + 1, max_level, max_items)
        else:
            print(f"{indent}üìÑ {item}")

    if len(items) > max_items:
        print(f"{indent}... and {len(items) - max_items} more items")

explore_directory(base_path)
print("\n" + "="*50)

## 📊 Step 4: Load Dataset with Flexible Path Detection

In [None]:
# Function to find images with multiple extensions
def find_images(directory, extensions=('*.jpg', '*.jpeg', '*.png', '*.JPG', '*.PNG', '*.JPEG')):
    """Return the sorted, de-duplicated paths of image files under ``directory``.

    The search is recursive, so images in subdirectories are included.

    Args:
        directory: Root directory to search. A missing directory yields ``[]``.
        extensions: Glob patterns matched against file names. Matching is
            case-sensitive on case-sensitive filesystems, hence the upper-case
            variants in the default.

    Returns:
        Sorted list of unique matching file paths.
    """
    from glob import escape

    # Escape the root so directory names containing glob metacharacters such as
    # '[' or '*' are matched literally.
    root = escape(directory)
    found = set()
    for pattern in extensions:
        # With recursive=True, '**' matches zero or more directories, so this one
        # glob covers both the top level and every subdirectory.
        found.update(glob(os.path.join(root, '**', pattern), recursive=True))
    return sorted(found)

def _first_dir_with_images(base_path, candidates, label):
    """Return the first ``base_path/<candidate>`` directory that contains images, else None."""
    for dir_name in candidates:
        test_path = os.path.join(base_path, dir_name)
        if os.path.isdir(test_path) and find_images(test_path):
            print(f"‚úÖ Found {label} in: {dir_name}")
            return test_path
    return None

def auto_find_dataset_dirs(base_path):
    """Guess the image and mask directories under ``base_path``.

    Conventional folder names are tried in order. The first one that contains
    at least one image file (searched recursively) wins.

    NOTE(review): because the search is recursive, a candidate such as 'train'
    that holds both image and mask subfolders would be chosen as the image
    directory — confirm for datasets other than KSSD2025.

    Returns:
        ``(image_dir, mask_dir)``; either entry is None when nothing was found.
    """
    possible_image_dirs = ['images', 'image', 'img', 'train', 'train_images', 'data/images']
    possible_mask_dirs = ['masks', 'mask', 'labels', 'label', 'train_masks', 'data/masks', 'ground_truth', 'gt']

    image_dir = _first_dir_with_images(base_path, possible_image_dirs, "images")
    mask_dir = _first_dir_with_images(base_path, possible_mask_dirs, "masks")
    return image_dir, mask_dir

# Try to auto-detect
print("üîç Auto-detecting dataset structure...\n")
detected_image_dir, detected_mask_dir = auto_find_dataset_dirs(base_path)

# Update config if found
if detected_image_dir:
    config.IMAGE_DIR = detected_image_dir
if detected_mask_dir:
    config.MASK_DIR = detected_mask_dir

print(f"\nüìÅ Using directories:")
print(f"  Images: {config.IMAGE_DIR}")
print(f"  Masks: {config.MASK_DIR}")
print("="*50)

## 📥 Step 5: Load and Match Images with Masks

In [None]:
# Get image and mask paths.
# This cell collects image/mask files from the configured directories and
# reports what was found. It then pairs images and masks by file stem into
# `data_df`, a DataFrame with image_path / mask_path / filename columns.
# `data_df` is set to None when no pairing is possible.
print("üì• Loading dataset...\n")

image_paths = find_images(config.IMAGE_DIR)
mask_paths = find_images(config.MASK_DIR)

print(f"üìä Dataset Statistics:")
print(f"  üñºÔ∏è  Total Images Found: {len(image_paths)}")
print(f"  üé≠ Total Masks Found: {len(mask_paths)}")

if len(image_paths) == 0:
    print("\n‚ùå ERROR: No images found!")
    print("\nLet me search the entire dataset directory...")
    # Diagnostic only: lists .jpg/.png files anywhere under base_path. It does
    # not change config or populate data_df.
    all_images = glob(os.path.join(base_path, '**', '*.jpg'), recursive=True) + \
                 glob(os.path.join(base_path, '**', '*.png'), recursive=True)
    if len(all_images) > 0:
        print(f"\nFound {len(all_images)} images in total. Showing first 10:")
        for img in all_images[:10]:
            print(f"  {img}")
else:
    print("\n‚úÖ Images loaded successfully!")
    print(f"\nFirst 5 image paths:")
    for img in image_paths[:5]:
        print(f"  {img}")

if len(mask_paths) == 0:
    print("\n‚ùå ERROR: No masks found!")
else:
    print("\n‚úÖ Masks loaded successfully!")
    print(f"\nFirst 5 mask paths:")
    for mask in mask_paths[:5]:
        print(f"  {mask}")

# Match images and masks by filename
if len(image_paths) > 0 and len(mask_paths) > 0:
    # Extract filenames (without extension and path)
    def get_base_name(path):
        """Return the file name of ``path`` without directory or extension."""
        return os.path.splitext(os.path.basename(path))[0]

    # NOTE(review): stems are the dict keys. If two files share a stem (e.g.
    # a.png and a.jpg, or the same name in different subdirectories found by
    # the recursive search), only the last path in sorted order survives.
    # Confirm that stems are unique.
    image_dict = {get_base_name(p): p for p in image_paths}
    mask_dict = {get_base_name(p): p for p in mask_paths}

    # Find matching pairs. Pairing requires identical stems, so masks named with
    # a suffix such as '_mask' would not match. Masks that have no image are
    # not reported.
    matched_data = []
    unmatched_images = []

    for img_name, img_path in image_dict.items():
        if img_name in mask_dict:
            matched_data.append({
                'image_path': img_path,
                'mask_path': mask_dict[img_name],
                'filename': img_name
            })
        else:
            unmatched_images.append(img_name)

    data_df = pd.DataFrame(matched_data)

    print(f"\n‚úÖ Matched {len(data_df)} image-mask pairs")

    if len(unmatched_images) > 0:
        print(f"‚ö†Ô∏è Warning: {len(unmatched_images)} images without matching masks")
        if len(unmatched_images) <= 5:
            print(f"Unmatched: {unmatched_images}")

    if len(data_df) > 0:
        print(f"\nüìã Dataset Preview:")
        print(data_df.head(10))
    else:
        print("\n‚ùå No matching image-mask pairs found!")
        data_df = None
else:
    print("\n‚ùå Cannot create dataset - missing images or masks")
    data_df = None

print("\n" + "="*50)

## 🔍 Step 6: Visualize Sample Images and Masks

In [None]:
def visualize_samples(df, num_samples=4, output_dir=None):
    """Plot random image / mask / overlay triplets from ``df`` and save the figure.

    Args:
        df: DataFrame with ``image_path``, ``mask_path`` and ``filename``
            columns, or None.
        num_samples: Number of rows to draw; capped at ``len(df)``.
        output_dir: Directory for ``dataset_samples.png``; defaults to
            ``config.OUTPUT_DIR``.

    Raises:
        FileNotFoundError: if OpenCV cannot read an image or mask file.
    """
    if df is None or len(df) == 0:
        print("‚ùå No data available for visualization")
        return

    if output_dir is None:
        output_dir = config.OUTPUT_DIR

    num_samples = min(num_samples, len(df))
    # squeeze=False always yields a 2-D axes grid, including the single-row case.
    fig, axes = plt.subplots(num_samples, 3, figsize=(12, num_samples*3), squeeze=False)

    indices = np.random.choice(len(df), num_samples, replace=False)

    for row, sample_idx in enumerate(indices):
        record = df.iloc[sample_idx]
        image = cv2.imread(record['image_path'])
        mask = cv2.imread(record['mask_path'], cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals unreadable/missing files by returning None.
        if image is None or mask is None:
            raise FileNotFoundError(
                f"Could not read image or mask for sample '{record['filename']}'"
            )
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # The boolean overlay index below requires the mask to match the image's HxW.
        if mask.shape != image.shape[:2]:
            mask = cv2.resize(mask, (image.shape[1], image.shape[0]),
                              interpolation=cv2.INTER_NEAREST)

        # Original image
        axes[row, 0].imshow(image)
        axes[row, 0].set_title(f'Original Image\n{record["filename"]}')
        axes[row, 0].axis('off')

        # Mask
        axes[row, 1].imshow(mask, cmap='gray')
        axes[row, 1].set_title('Ground Truth Mask')
        axes[row, 1].axis('off')

        # Overlay: paint every non-zero mask pixel pure red.
        overlay = image.copy()
        overlay[mask > 0] = [255, 0, 0]
        axes[row, 2].imshow(overlay)
        axes[row, 2].set_title('Overlay (Red = Stone)')
        axes[row, 2].axis('off')

    plt.tight_layout()
    save_path = f'{output_dir}/dataset_samples.png'
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"‚úÖ Sample visualization saved to {save_path}")

# Only plot when the matching step produced a dataset.
if data_df is None:
    print("‚ö†Ô∏è Skipping visualization - no data loaded")
else:
    visualize_samples(data_df, num_samples=4)
    print("=" * 50)

## 🔄 Step 7: Data Augmentation Pipeline

In [None]:
# Training augmentation
# ImageNet channel statistics, shared by the training and validation pipelines.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, then light geometric jitter (flip, tiny rotation,
# tiny shift/scale) and brightness/contrast, then normalise and convert to tensors.
train_transform = A.Compose([
    A.Resize(config.IMG_SIZE, config.IMG_SIZE),
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=2.5, p=0.5),
    A.ShiftScaleRotate(shift_limit=0.0075, scale_limit=0.0075, rotate_limit=0, p=0.5),
    A.RandomBrightnessContrast(p=0.3),
    A.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ToTensorV2(),
])

# Validation pipeline: deterministic resize + normalise only.
val_transform = A.Compose([
    A.Resize(config.IMG_SIZE, config.IMG_SIZE),
    A.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ToTensorV2(),
])

for _line in (
    "‚úÖ Data augmentation pipelines created!",
    "  - Training: Flip, Rotate, Shift, Scale, Brightness",
    "  - Validation: Resize & Normalize only",
    "=" * 50,
):
    print(_line)

## 📦 Step 8: Custom Dataset Class

In [None]:
class KidneyStoneDataset(Dataset):
    """Paired image/mask dataset for binary kidney-stone segmentation.

    Each item is ``(image, mask)``. The mask is binarised to 0/1 float values
    and returned with a leading channel dimension.

    Args:
        df: DataFrame with ``image_path`` and ``mask_path`` columns; its index
            is reset so positional ``idx`` values are contiguous.
        transform: Optional albumentations-style callable, invoked as
            ``transform(image=..., mask=...)``, that returns a dict with
            ``'image'`` and ``'mask'`` (e.g. a pipeline ending in ``ToTensorV2``).
            When None, the RGB image becomes a float CHW tensor scaled to
            [0, 1] and the mask becomes a float tensor.
    """

    def __init__(self, df, transform=None):
        self.df = df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image_path = row['image_path']
        mask_path = row['mask_path']

        # cv2.imread returns None (rather than raising) for missing/unreadable files.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        # Binarize mask (0 or 1)
        mask = (mask > 0).astype(np.float32)

        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']
        else:
            # No transform: produce tensors ourselves so the DataLoader can collate them.
            image = torch.from_numpy(image).permute(2, 0, 1).float().div(255.0)
            mask = torch.from_numpy(mask)

        # A transform without ToTensorV2 can hand back a numpy mask.
        if isinstance(mask, np.ndarray):
            mask = torch.from_numpy(mask)

        # Add channel dimension to mask: (H, W) -> (1, H, W)
        return image, mask.unsqueeze(0)

print("‚úÖ Custom Dataset class created!", "=" * 50, sep="\n")

## 🏗️ Step 9: Build Attention U-Net Architecture

In [None]:
class ConvBlock(nn.Module):
    """Two stacked 3x3 Conv -> BatchNorm -> ReLU stages.

    padding=1 keeps the spatial size unchanged.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by both convolutions.
    """
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        # Attribute name and layer order are kept stable so saved state_dicts still load.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


class AttentionGate(nn.Module):
    """Additive attention gate that reweights skip-connection features.

    The process runs as follows:

    1. The gating signal ``g`` (decoder side) and the skip features ``x``
       (encoder side) are each projected to ``F_int`` channels by a 1x1 conv
       followed by BatchNorm.
    2. The projections are summed and passed through a ReLU.
    3. The result is collapsed to a single-channel sigmoid map that multiplies
       ``x`` element-wise.

    ``g`` and ``x`` must have the same spatial size.

    Args:
        F_g: Channels of the gating signal ``g``.
        F_l: Channels of the skip features ``x``.
        F_int: Channels of the intermediate projections.
    """
    def __init__(self, F_g, F_l, F_int):
        super(AttentionGate, self).__init__()

        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi


class AttentionUNet(nn.Module):
    """Four-level U-Net with attention-gated skip connections and a sigmoid head.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the predicted mask.
        features: Encoder widths for the four levels. The bottleneck uses
            ``2 * features[-1]`` channels and the decoder mirrors the encoder.
            The default ``(16, 32, 64, 128)`` reproduces the original network,
            including its state_dict keys and shapes.

    ``forward`` returns per-pixel probabilities in [0, 1] with the input's
    spatial size. Inputs whose sides are not multiples of 16 are supported:
    each upsampled decoder map is resized to its skip connection before gating.
    The input still needs sides of at least 16 so the bottleneck is non-empty.
    """
    def __init__(self, in_channels=3, out_channels=1, features=(16, 32, 64, 128)):
        super(AttentionUNet, self).__init__()
        if len(features) != 4:
            raise ValueError(f"features must have 4 entries, got {len(features)}")
        f1, f2, f3, f4 = features
        fb = f4 * 2

        # Encoder (registration order matches the original for identical init/state_dict)
        self.enc1 = ConvBlock(in_channels, f1)
        self.pool1 = nn.MaxPool2d(2)

        self.enc2 = ConvBlock(f1, f2)
        self.pool2 = nn.MaxPool2d(2)

        self.enc3 = ConvBlock(f2, f3)
        self.pool3 = nn.MaxPool2d(2)

        self.enc4 = ConvBlock(f3, f4)
        self.pool4 = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = ConvBlock(f4, fb)

        # Decoder: upsample, gate the skip features, concatenate, convolve.
        self.up4 = nn.ConvTranspose2d(fb, f4, kernel_size=2, stride=2)
        self.att4 = AttentionGate(F_g=f4, F_l=f4, F_int=f4 // 2)
        self.dec4 = ConvBlock(f4 * 2, f4)

        self.up3 = nn.ConvTranspose2d(f4, f3, kernel_size=2, stride=2)
        self.att3 = AttentionGate(F_g=f3, F_l=f3, F_int=f3 // 2)
        self.dec3 = ConvBlock(f3 * 2, f3)

        self.up2 = nn.ConvTranspose2d(f3, f2, kernel_size=2, stride=2)
        self.att2 = AttentionGate(F_g=f2, F_l=f2, F_int=f2 // 2)
        self.dec2 = ConvBlock(f2 * 2, f2)

        self.up1 = nn.ConvTranspose2d(f2, f1, kernel_size=2, stride=2)
        self.att1 = AttentionGate(F_g=f1, F_l=f1, F_int=f1 // 2)
        self.dec1 = ConvBlock(f1 * 2, f1)

        # Output: 1x1 conv to the requested number of mask channels.
        self.out = nn.Conv2d(f1, out_channels, kernel_size=1)

    @staticmethod
    def _match_size(up, skip):
        """Resize ``up`` to ``skip``'s spatial size when odd input sizes made them differ."""
        if up.shape[-2:] != skip.shape[-2:]:
            up = F.interpolate(up, size=skip.shape[-2:], mode='bilinear', align_corners=False)
        return up

    def _decode(self, up, att, dec, below, skip):
        """Run one decoder stage: upsample ``below``, gate ``skip``, concatenate, convolve."""
        d = self._match_size(up(below), skip)
        gated = att(d, skip)
        return dec(torch.cat([d, gated], dim=1))

    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        e4 = self.enc4(self.pool3(e3))

        # Bottleneck
        b = self.bottleneck(self.pool4(e4))

        # Decoder with attention-gated skips
        d4 = self._decode(self.up4, self.att4, self.dec4, b, e4)
        d3 = self._decode(self.up3, self.att3, self.dec3, d4, e3)
        d2 = self._decode(self.up2, self.att2, self.dec2, d3, e2)
        d1 = self._decode(self.up1, self.att1, self.dec1, d2, e1)

        # Output probabilities
        return torch.sigmoid(self.out(d1))


def count_parameters(model):
    """Return the total number of scalar weights in ``model`` that receive gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


# Test model
# Build the network once to confirm it constructs on the target device and report its size.
model = AttentionUNet().to(config.DEVICE)
_side = config.IMG_SIZE
print("‚úÖ Attention U-Net model created!")
print(f"  üìä Total Parameters: {count_parameters(model):,}")
print(f"  üéØ Input Size: {_side}x{_side}x3")
print(f"  üéØ Output Size: {_side}x{_side}x1")
print("=" * 50)

## 📊 Step 10: Define Loss Function and Metrics

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss over the whole batch, returned as ``1 - dice``.

    All elements of ``predictions`` and ``targets`` are flattened together, so
    the loss is one global Dice value rather than a mean of per-image scores.
    ``predictions`` are used as-is; the model in this notebook already applies
    a sigmoid.

    Args:
        smooth: Additive smoothing term that keeps the ratio defined when both
            prediction and target are empty.
    """
    def __init__(self, smooth=1.0):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, predictions, targets):
        # reshape (unlike view) also accepts non-contiguous tensors, e.g. after permute.
        predictions = predictions.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (predictions * targets).sum()
        dice = (2. * intersection + self.smooth) / \
               (predictions.sum() + targets.sum() + self.smooth)

        return 1 - dice


def _binarize_and_flatten(predictions, targets, threshold):
    """Threshold ``predictions`` to {0, 1} floats and flatten both tensors to 1-D.

    ``reshape`` is used instead of ``view`` so non-contiguous inputs are accepted.
    """
    binary = (predictions > threshold).float().reshape(-1)
    return binary, targets.reshape(-1)


def dice_coefficient(predictions, targets, threshold=0.5, smooth=1.0):
    """Calculate Dice coefficient.

    Returns ``(2|P∩T| + smooth) / (|P| + |T| + smooth)`` as a Python float,
    computed over all elements after thresholding ``predictions``.
    """
    predictions, targets = _binarize_and_flatten(predictions, targets, threshold)
    intersection = (predictions * targets).sum()
    dice = (2. * intersection + smooth) / (predictions.sum() + targets.sum() + smooth)
    return dice.item()


def iou_score(predictions, targets, threshold=0.5, smooth=1.0):
    """Calculate IoU (Intersection over Union).

    Returns ``(|P∩T| + smooth) / (|P∪T| + smooth)`` as a Python float.
    """
    predictions, targets = _binarize_and_flatten(predictions, targets, threshold)
    intersection = (predictions * targets).sum()
    union = predictions.sum() + targets.sum() - intersection
    iou = (intersection + smooth) / (union + smooth)
    return iou.item()


def precision_score(predictions, targets, threshold=0.5, smooth=1.0):
    """Calculate Precision: ``(TP + smooth) / (predicted positives + smooth)``."""
    predictions, targets = _binarize_and_flatten(predictions, targets, threshold)
    true_positive = (predictions * targets).sum()
    precision = (true_positive + smooth) / (predictions.sum() + smooth)
    return precision.item()


def recall_score(predictions, targets, threshold=0.5, smooth=1.0):
    """Calculate Recall (Sensitivity): ``(TP + smooth) / (actual positives + smooth)``."""
    predictions, targets = _binarize_and_flatten(predictions, targets, threshold)
    true_positive = (predictions * targets).sum()
    recall = (true_positive + smooth) / (targets.sum() + smooth)
    return recall.item()


for _msg in (
    "‚úÖ Loss function and metrics defined!",
    "  - Loss: Dice Loss",
    "  - Metrics: Dice, IoU, Precision, Recall",
    "=" * 50,
):
    print(_msg)

## 🏋️ Step 11: Training and Validation Functions

In [None]:
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Network to train; switched to ``train()`` mode.
        dataloader: Iterable of ``(images, masks)`` batches.
        criterion: Loss callable ``criterion(outputs, masks)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        ``(loss, dice, iou)`` for the epoch. Each per-batch value is weighted
        by its batch size, so a short final batch only counts proportionally.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    total_dice = 0.0
    total_iou = 0.0
    num_samples = 0

    pbar = tqdm(dataloader, desc='Training')

    for images, masks in pbar:
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)
        batch_size = images.size(0)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Metrics are for monitoring only; detach so they stay out of the autograd graph.
        preds = outputs.detach()
        dice = dice_coefficient(preds, masks)
        iou = iou_score(preds, masks)
        batch_loss = loss.item()

        total_loss += batch_loss * batch_size
        total_dice += dice * batch_size
        total_iou += iou * batch_size
        num_samples += batch_size

        pbar.set_postfix({
            'loss': f'{batch_loss:.4f}',
            'dice': f'{dice:.4f}',
            'iou': f'{iou:.4f}'
        })

    if num_samples == 0:
        raise ValueError("train_one_epoch received an empty dataloader")

    return total_loss / num_samples, total_dice / num_samples, total_iou / num_samples


def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without updating any weights.

    Args:
        model: Network to evaluate; switched to ``eval()`` mode.
        dataloader: Iterable of ``(images, masks)`` batches.
        criterion: Loss callable ``criterion(outputs, masks)``.
        device: Device the batches are moved to.

    Returns:
        ``(loss, dice, iou, precision, recall)``. Each per-batch value is
        weighted by its batch size, so a short final batch only counts
        proportionally.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    metric_names = ('loss', 'dice', 'iou', 'precision', 'recall')
    model.eval()
    totals = dict.fromkeys(metric_names, 0.0)
    num_samples = 0

    pbar = tqdm(dataloader, desc='Validation')

    with torch.no_grad():
        for images, masks in pbar:
            images = images.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)
            batch_size = images.size(0)

            outputs = model(images)
            batch_metrics = {
                'loss': criterion(outputs, masks).item(),
                'dice': dice_coefficient(outputs, masks),
                'iou': iou_score(outputs, masks),
                'precision': precision_score(outputs, masks),
                'recall': recall_score(outputs, masks),
            }
            for name, value in batch_metrics.items():
                totals[name] += value * batch_size
            num_samples += batch_size

            pbar.set_postfix({
                'loss': f"{batch_metrics['loss']:.4f}",
                'dice': f"{batch_metrics['dice']:.4f}"
            })

    if num_samples == 0:
        raise ValueError("validate received an empty dataloader")

    return tuple(totals[name] / num_samples for name in metric_names)


print("\n".join(["‚úÖ Training and validation functions defined!", "=" * 50]))

## 🔄 Step 12: 5-Fold Cross-Validation Training

**This will take 2–4 hours!** ☕

In [None]:
# Check if data is loaded
# K-fold cross-validation: train a fresh Attention U-Net per fold and keep the
# checkpoint and validation metrics of its best-Dice epoch.
if data_df is None or len(data_df) == 0:
    print("‚ùå Cannot start training - no data loaded!")
    print("Please fix the dataset path issues in the earlier cells.")
else:
    # Fixed random_state keeps fold membership reproducible across runs.
    kfold = KFold(n_splits=config.NUM_FOLDS, shuffle=True, random_state=42)

    # Page-locked host memory only helps when batches are copied to a GPU.
    use_pin_memory = config.DEVICE.type == 'cuda'

    # Epochs without improvement before reducing the LR / stopping the fold.
    lr_patience = 10
    max_patience = 20

    # Storage for results
    fold_results = []
    all_histories = []

    print("\n" + "="*70)
    print(" " * 15 + f"üöÄ STARTING {config.NUM_FOLDS}-FOLD CROSS-VALIDATION")
    print("="*70 + "\n")

    for fold, (train_idx, val_idx) in enumerate(kfold.split(data_df), 1):
        print(f"\n{'='*70}")
        print(f" " * 25 + f"üìä FOLD {fold}/{config.NUM_FOLDS}")
        print(f"{'='*70}\n")

        # Split data
        train_df = data_df.iloc[train_idx]
        val_df = data_df.iloc[val_idx]

        print(f"  üì¶ Training samples: {len(train_df)}")
        print(f"  üì¶ Validation samples: {len(val_df)}")

        # Create datasets
        train_dataset = KidneyStoneDataset(train_df, transform=train_transform)
        val_dataset = KidneyStoneDataset(val_df, transform=val_transform)

        # Create dataloaders
        train_loader = DataLoader(
            train_dataset,
            batch_size=config.BATCH_SIZE,
            shuffle=True,
            num_workers=2,
            pin_memory=use_pin_memory
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=config.BATCH_SIZE,
            shuffle=False,
            num_workers=2,
            pin_memory=use_pin_memory
        )

        # Fresh model / optimizer / scheduler for every fold.
        model = AttentionUNet().to(config.DEVICE)
        criterion = DiceLoss()
        optimizer = Adam(model.parameters(), lr=config.LEARNING_RATE)
        # The scheduler's `verbose` flag is deprecated in recent PyTorch; LR drops are printed below.
        scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=lr_patience)

        # Training history
        history = {
            'train_loss': [], 'train_dice': [], 'train_iou': [],
            'val_loss': [], 'val_dice': [], 'val_iou': [],
            'val_precision': [], 'val_recall': []
        }

        best_dice = 0.0
        best_epoch = 0
        # Validation metrics recorded at the best-Dice epoch, so the summary
        # reports Dice and IoU from the same checkpoint.
        best_metrics = {'val_loss': float('nan'), 'val_iou': float('nan'),
                        'val_precision': float('nan'), 'val_recall': float('nan')}
        patience_counter = 0

        print(f"\n  üèãÔ∏è Training started...\n")

        for epoch in range(1, config.NUM_EPOCHS + 1):
            print(f"\nEpoch {epoch}/{config.NUM_EPOCHS}")
            print("-" * 70)

            # Train
            train_loss, train_dice, train_iou = train_one_epoch(
                model, train_loader, criterion, optimizer, config.DEVICE
            )

            # Validate
            val_loss, val_dice, val_iou, val_precision, val_recall = validate(
                model, val_loader, criterion, config.DEVICE
            )

            # Update scheduler (plateau on validation loss)
            lr_before = optimizer.param_groups[0]['lr']
            scheduler.step(val_loss)
            lr_after = optimizer.param_groups[0]['lr']

            # Save history
            epoch_values = {
                'train_loss': train_loss, 'train_dice': train_dice, 'train_iou': train_iou,
                'val_loss': val_loss, 'val_dice': val_dice, 'val_iou': val_iou,
                'val_precision': val_precision, 'val_recall': val_recall,
            }
            for key, value in epoch_values.items():
                history[key].append(value)

            # Print metrics
            print(f"\nTrain Loss: {train_loss:.4f} | Train Dice: {train_dice:.4f} | Train IoU: {train_iou:.4f}")
            print(f"Val Loss: {val_loss:.4f} | Val Dice: {val_dice:.4f} | Val IoU: {val_iou:.4f}")
            print(f"Val Precision: {val_precision:.4f} | Val Recall: {val_recall:.4f}")
            if lr_after < lr_before:
                print(f"Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}")

            # Save best model
            if val_dice > best_dice:
                best_dice = val_dice
                best_epoch = epoch
                best_metrics = {'val_loss': val_loss, 'val_iou': val_iou,
                                'val_precision': val_precision, 'val_recall': val_recall}
                patience_counter = 0

                if config.SAVE_MODELS:
                    model_path = f"{config.MODEL_DIR}/attention_unet_fold{fold}_best.pth"
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'dice': val_dice,
                        'iou': val_iou,
                    }, model_path)
                    print(f"\n‚úÖ Best model saved! Dice: {best_dice:.4f}")
            else:
                patience_counter += 1

            # Early stopping
            if patience_counter >= max_patience:
                print(f"\n‚ö†Ô∏è Early stopping triggered at epoch {epoch}")
                break

        # Store fold results: best-epoch metrics plus the last epoch's, for reference.
        fold_results.append({
            'fold': fold,
            'best_dice': best_dice,
            'best_epoch': best_epoch,
            'best_val_loss': best_metrics['val_loss'],
            'best_val_iou': best_metrics['val_iou'],
            'best_val_precision': best_metrics['val_precision'],
            'best_val_recall': best_metrics['val_recall'],
            'final_val_loss': history['val_loss'][-1],
            'final_val_iou': history['val_iou'][-1],
            'final_val_precision': history['val_precision'][-1],
            'final_val_recall': history['val_recall'][-1]
        })

        all_histories.append(history)

        print(f"\n‚úÖ Fold {fold} completed! Best Dice: {best_dice:.4f}")
        print(f"{'='*70}\n")

        # Clear memory before the next fold
        del model, optimizer, scheduler
        gc.collect()
        torch.cuda.empty_cache()

    print("\n" + "="*70)
    print(" " * 20 + "üéâ ALL FOLDS COMPLETED!")
    print("="*70 + "\n")

    # Create results dataframe
    results_df = pd.DataFrame(fold_results)
    print("\nüìä Results Summary:")
    print(results_df.to_string(index=False))
    print(f"\nüìà Mean Dice Score: {results_df['best_dice'].mean():.4f} ¬± {results_df['best_dice'].std():.4f}")
    # IoU comes from the same best-Dice epoch as the Dice score, so the two are comparable.
    print(f"üìà Mean IoU Score: {results_df['best_val_iou'].mean():.4f} ¬± {results_df['best_val_iou'].std():.4f}")
    print("="*50)

## 📊 Step 13: Visualize Training Results

In [None]:
# Plot training curves
# Plot training curves for every completed fold plus a per-fold best-Dice bar chart.
# Uses fold_results rather than results_df / NUM_FOLDS, so it also works after an
# interrupted run where results_df was never built and fewer folds finished.
if 'all_histories' in locals() and len(all_histories) > 0:
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    fig.suptitle('Training Curves - All Folds', fontsize=16, fontweight='bold')

    # (axis, title, y-label, [(history key, legend suffix), ...])
    curve_panels = [
        (axes[0, 0], 'Loss', 'Loss', [('train_loss', ' Train'), ('val_loss', ' Val')]),
        (axes[0, 1], 'Dice Score', 'Dice', [('train_dice', ' Train'), ('val_dice', ' Val')]),
        (axes[0, 2], 'IoU Score', 'IoU', [('train_iou', ' Train'), ('val_iou', ' Val')]),
        (axes[1, 0], 'Validation Precision', 'Precision', [('val_precision', '')]),
        (axes[1, 1], 'Validation Recall', 'Recall', [('val_recall', '')]),
    ]

    for fold_idx, history in enumerate(all_histories, 1):
        for ax, _, _, series in curve_panels:
            for key, suffix in series:
                ax.plot(history[key], label=f'Fold {fold_idx}{suffix}', alpha=0.7)

    # Axis decoration only needs to happen once, after all folds are drawn.
    for ax, title, ylabel, _ in curve_panels:
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)

    # Bar plot of best dice scores, one bar per finished fold.
    fold_summary = pd.DataFrame(fold_results)
    bar_ax = axes[1, 2]
    bar_ax.bar(fold_summary['fold'], fold_summary['best_dice'],
               color='skyblue', edgecolor='navy')
    bar_ax.axhline(y=0.9706, color='r', linestyle='--', label='Baseline (97.06%)')
    bar_ax.set_title('Best Dice Score per Fold')
    bar_ax.set_xlabel('Fold')
    bar_ax.set_ylabel('Dice Score')
    # Keep the 0.90 floor, but lower it when a fold scored below it so no bar is hidden.
    lowest_dice = float(fold_summary['best_dice'].min())
    y_floor = 0.90 if lowest_dice >= 0.90 else max(0.0, lowest_dice - 0.02)
    bar_ax.set_ylim([y_floor, 1.0])
    bar_ax.legend()
    bar_ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig(f'{config.OUTPUT_DIR}/training_curves.png', dpi=150, bbox_inches='tight')
    plt.show()

    print(f"‚úÖ Training curves saved!")
else:
    print("‚ö†Ô∏è No training history available")

## 🔍 Step 14: Visualize Predictions

In [None]:
def visualize_predictions(model_path, df, num_samples=6, threshold=0.5):
    """Plot random samples next to their ground truth and model predictions.

    For each randomly chosen row of ``df`` one figure row with four panels is
    drawn: the original image, the ground-truth mask, the raw model output and
    that output binarised at ``threshold``. The figure is saved to
    ``{config.OUTPUT_DIR}/predictions.png`` and then shown.

    Relies on the notebook globals ``AttentionUNet``, ``config`` and
    ``val_transform``.

    Args:
        model_path: Path to a checkpoint dict holding a ``'model_state_dict'``
            entry for ``AttentionUNet``.
        df: DataFrame with ``'image_path'`` and ``'mask_path'`` columns.
        num_samples: Maximum number of samples to plot (capped at ``len(df)``).
        threshold: Cut-off applied to the model output for the binary panel.
            The 0.5 default assumes the model emits probabilities
            (NOTE(review): confirm AttentionUNet ends in a sigmoid).

    Returns:
        None. Returns early, printing a message, when there is nothing to plot.

    Raises:
        FileNotFoundError: If an image or mask listed in ``df`` cannot be read.
    """
    if df is None or len(df) == 0:
        print("‚ùå No data available")
        return

    # Validate the sample count before paying for model construction/loading.
    num_samples = min(num_samples, len(df))
    if num_samples <= 0:
        print("Nothing to plot: num_samples must be >= 1")
        return

    # Load model; map_location lets a checkpoint saved on GPU load on a
    # CPU-only runtime (and vice versa).
    model = AttentionUNet().to(config.DEVICE)
    checkpoint = torch.load(model_path, map_location=config.DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Random samples (drawn without replacement, so no duplicate rows).
    indices = np.random.choice(len(df), num_samples, replace=False)

    # squeeze=False keeps `axes` 2-D even when num_samples == 1.
    fig, axes = plt.subplots(num_samples, 4, figsize=(16, num_samples * 3),
                             squeeze=False)

    with torch.no_grad():
        for row_axes, sample_idx in zip(axes, indices):
            row = df.iloc[sample_idx]
            image_path = str(row['image_path'])
            mask_path = str(row['mask_path'])

            # cv2.imread signals failure by returning None instead of raising.
            image = cv2.imread(image_path)
            if image is None:
                raise FileNotFoundError(f"Could not read image: {image_path}")
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if mask is None:
                raise FileNotFoundError(f"Could not read mask: {mask_path}")
            mask = (mask > 0).astype(np.uint8)

            # Same preprocessing as validation; add a batch dimension.
            transformed = val_transform(image=image_rgb, mask=mask)
            image_tensor = transformed['image'].unsqueeze(0).to(config.DEVICE)

            pred = model(image_tensor).squeeze().cpu().numpy()
            pred_binary = (pred > threshold).astype(np.uint8)

            # NOTE(review): the image/mask panels are at file resolution while
            # the prediction panels are at whatever size val_transform yields.
            row_axes[0].imshow(image_rgb)
            row_axes[0].set_title('Original')
            row_axes[0].axis('off')

            row_axes[1].imshow(mask, cmap='gray')
            row_axes[1].set_title('Ground Truth')
            row_axes[1].axis('off')

            row_axes[2].imshow(pred, cmap='hot')
            row_axes[2].set_title('Prediction (Prob)')
            row_axes[2].axis('off')

            row_axes[3].imshow(pred_binary, cmap='gray')
            row_axes[3].set_title('Binary Prediction')
            row_axes[3].axis('off')

    plt.tight_layout()
    plt.savefig(f'{config.OUTPUT_DIR}/predictions.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("‚úÖ Predictions saved!")

# Plot predictions from the fold that reached the highest validation Dice,
# provided a results table exists and its checkpoint is on disk.
if 'results_df' not in locals() or len(results_df) == 0:
    print("‚ö†Ô∏è No results available")
else:
    best_fold = results_df.loc[results_df['best_dice'].idxmax(), 'fold']
    model_path = f"{config.MODEL_DIR}/attention_unet_fold{int(best_fold)}_best.pth"

    if not os.path.exists(model_path):
        print("‚ö†Ô∏è Model not found")
    else:
        print(f"\nüîç Visualizing from best fold ({int(best_fold)})...\n")
        visualize_predictions(model_path, data_df, num_samples=6)

## 📝 Step 15: Final Summary

In [None]:
def generate_final_report(fold_results, baseline_dice=0.9706):
    """Generate final report"""
    
    print("\n" + "="*70)
    print(" " * 20 + "üèÜ FINAL RESULTS REPORT")
    print("="*70)
    
    results_df = pd.DataFrame(fold_results)
    
    print("\nüìä PER-FOLD RESULTS:")
    print("-"*70)
    for _, row in results_df.iterrows():
        print(f"  Fold {int(row['fold'])}: Dice = {row['best_dice']:.4f} | "
              f"IoU = {row['final_val_iou']:.4f} | "
              f"Precision = {row['final_val_precision']:.4f} | "
              f"Recall = {row['final_val_recall']:.4f}")
    
    mean_dice = results_df['best_dice'].mean()
    std_dice = results_df['best_dice'].std()
    
    print("\nüìà SUMMARY:")
    print("-"*70)
    print(f"  Mean Dice:   {mean_dice:.4f} ¬± {std_dice:.4f}")
    print(f"  Min Dice:    {results_df['best_dice'].min():.4f}")
    print(f"  Max Dice:    {results_df['best_dice'].max():.4f}")
    
    improvement = (mean_dice - baseline_dice) * 100
    
    print("\nüéØ COMPARISON:")
    print("-"*70)
    print(f"  Baseline:     {baseline_dice:.4f} (97.06%)")
    print(f"  Our Model:    {mean_dice:.4f} ({mean_dice*100:.2f}%)")
    print(f"  Improvement:  {improvement:+.2f}%")
    
    if mean_dice > baseline_dice:
        print(f"\n  ‚úÖ SUCCESS! Beat the baseline! üéâ")
    else:
        print(f"\n  ‚ö†Ô∏è Below baseline")
    
    print("\n" + "="*70 + "\n")

# Emit the summary report only when the cross-validation loop left results behind.
if 'fold_results' not in locals():
    print("‚ö†Ô∏è No results available")
else:
    generate_final_report(fold_results)

## 💾 Step 16: Save Results

In [None]:
import pickle


def _save_results(histories, folds, output_dir):
    """Write training histories, a pickled summary and a text report to output_dir.

    The summary is recomputed from ``folds`` rather than read from a global
    ``results_df``, which may be undefined or belong to an earlier run.

    Args:
        histories: Per-fold training histories; pickled as-is.
        folds: Per-fold records with ``'fold'`` and ``'best_dice'`` entries.
        output_dir: Directory for ``training_histories.pkl``,
            ``results_summary.pkl`` and ``RESULTS.txt``; created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)

    summary_df = pd.DataFrame(folds)
    mean_dice = summary_df['best_dice'].mean()
    std_dice = summary_df['best_dice'].std()

    # Save histories
    with open(os.path.join(output_dir, 'training_histories.pkl'), 'wb') as f:
        pickle.dump(histories, f)

    # Save results
    with open(os.path.join(output_dir, 'results_summary.pkl'), 'wb') as f:
        pickle.dump({
            'fold_results': folds,
            'mean_dice': mean_dice,
            'std_dice': std_dice,
        }, f)

    # Save text file; explicit UTF-8 because the report contains non-ASCII
    # characters that a non-UTF-8 default locale may fail to encode.
    with open(os.path.join(output_dir, 'RESULTS.txt'), 'w', encoding='utf-8') as f:
        f.write("ATTENTION U-NET RESULTS\n")
        f.write("="*50 + "\n\n")
        for _, row in summary_df.iterrows():
            f.write(f"Fold {int(row['fold'])}: {row['best_dice']:.4f}\n")
        f.write(f"\nMean: {mean_dice:.4f} ¬± {std_dice:.4f}\n")


if 'all_histories' in locals() and 'fold_results' in locals():
    _save_results(all_histories, fold_results, config.OUTPUT_DIR)

    print("‚úÖ All results saved!")
    print(f"  üìÅ Models: {config.MODEL_DIR}/")
    print(f"  üìÅ Outputs: {config.OUTPUT_DIR}/")
else:
    print("‚ö†Ô∏è No results to save")

## 🎉 Complete!

Training complete! Outputs are written to:
- Model checkpoints: `/kaggle/working/models/`
- Visualizations and saved results: `/kaggle/working/outputs/`

**Good luck with your research! 🚀**