"""
# 🎯 Attention U-Net for Kidney Stone Segmentation on KSSD2025

## 📊 Objective
Beat the baseline Modified U-Net score of **97.06%** using Attention U-Net

## 🎯 Expected Results
- **Target Dice Score:** 97.5% - 98.2%
- **Strategy:** Attention mechanisms for small object detection
- **Architecture:** U-Net + Attention Gates

## 📋 Implementation Plan
1. ✅ Setup & Import Libraries
2. ✅ Load KSSD2025 Dataset
3. ✅ Data Preprocessing & Augmentation
4. ✅ Build Attention U-Net Architecture
5. ✅ Training with 5-Fold Cross-Validation
6. ✅ Evaluation & Visualization
7. ✅ Results Comparison

## 📦 Step 1: Install & Import Required Libraries


In [1]:
import sys
!{sys.executable} -m pip install -q segmentation-models-pytorch albumentations

print("✅ Libraries installed successfully!")
print("="*50)

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.8/154.8 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25h✅ Libraries installed successfully!


In [2]:
# Core Libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import cv2
import gc
import warnings
warnings.filterwarnings('ignore')

# Deep Learning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Sklearn
from sklearn.model_selection import KFold

# Image Processing
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Progress Bar
from tqdm.auto import tqdm

# Set random seeds for reproducibility
def set_seed(seed=42):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

print("✅ All libraries imported successfully!")
print(f"✅ PyTorch Version: {torch.__version__}")
print(f"✅ CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"✅ CUDA Device: {torch.cuda.get_device_name(0)}")
print("="*50)

✅ All libraries imported successfully!
✅ PyTorch Version: 2.8.0+cu126
✅ CUDA Available: True
✅ CUDA Device: Tesla T4


## 📂 Step 2: Configure Dataset Paths

**IMPORTANT:** 

In [19]:
# Configuration
class Config:
    # === DATASET PATHS - CORRECTED FOR YOUR STRUCTURE ===
    # Your dataset has a 'data' subdirectory
    DATA_PATH = "/kaggle/input/kssd2025-kidney-stone-segmentation-dataset/data"
    
    # Image and mask directories
    IMAGE_DIR = f"{DATA_PATH}/image"
    MASK_DIR = f"{DATA_PATH}/label"
    
    # Image Settings
    IMG_SIZE = 256  # Resize images to 256x256
    
    # Training Settings
    BATCH_SIZE = 16
    NUM_EPOCHS = 150
    LEARNING_RATE = 0.001
    NUM_FOLDS = 5
    
    # Model Settings
    ENCODER_CHANNELS = [16, 32, 64, 128]
    DECODER_CHANNELS = [128, 64, 32, 16]
    
    # Device
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Save Settings
    SAVE_MODELS = True
    MODEL_DIR = "/kaggle/working/models"
    OUTPUT_DIR = "/kaggle/working/outputs"

config = Config()

# Create necessary directories
Path(config.MODEL_DIR).mkdir(parents=True, exist_ok=True)
Path(config.OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

print("⚙️ Configuration Settings:")
print(f"  📁 Data Path: {config.DATA_PATH}")
print(f"  📁 Image Dir: {config.IMAGE_DIR}")
print(f"  📁 Mask Dir: {config.MASK_DIR}")
print(f"  🖼️  Image Size: {config.IMG_SIZE}x{config.IMG_SIZE}")
print(f"  📦 Batch Size: {config.BATCH_SIZE}")
print(f"  🔄 Epochs: {config.NUM_EPOCHS}")
print(f"  📊 K-Folds: {config.NUM_FOLDS}")
print(f"  🎯 Device: {config.DEVICE}")
print("="*50)

⚙️ Configuration Settings:
  📁 Data Path: /kaggle/input/kssd2025-kidney-stone-segmentation-dataset/data
  📁 Image Dir: /kaggle/input/kssd2025-kidney-stone-segmentation-dataset/data/images
  📁 Mask Dir: /kaggle/input/kssd2025-kidney-stone-segmentation-dataset/data/masks
  🖼️  Image Size: 256x256
  📦 Batch Size: 16
  🔄 Epochs: 150
  📊 K-Folds: 5
  🎯 Device: cuda


## 🔍 Step 3: Load and Explore Dataset

In [20]:
# Explore the dataset structure
base_path = "/kaggle/input/kssd2025-kidney-stone-segmentation-dataset"

print("📁 Exploring dataset structure...\n")

def explore_directory(path, level=0, max_level=3):
    """Recursively explore directory structure"""
    if level > max_level or not os.path.exists(path):
        return
    
    indent = "  " * level
    items = sorted(os.listdir(path))
    
    for item in items[:20]:  # Limit to first 20 items
        item_path = os.path.join(path, item)
        if os.path.isdir(item_path):
            count = len(os.listdir(item_path))
            print(f"{indent}📁 {item}/ ({count} items)")
            if level < 2:  # Only go 2 levels deep
                explore_directory(item_path, level + 1, max_level)
        else:
            print(f"{indent}📄 {item}")
    
    if len(items) > 20:
        print(f"{indent}... and {len(items) - 20} more items")

explore_directory(base_path)
print("\n" + "="*50)

📁 Exploring dataset structure...

📁 data/ (2 items)
  📁 image/ (838 items)
    📄 1.tif
    📄 10.tif
    📄 1000.tif
    📄 1001.tif
    📄 1002.tif
    📄 1003.tif
    📄 1012.tif
    📄 1013.tif
    📄 1014.tif
    📄 1015.tif
    📄 1020.tif
    📄 1021.tif
    📄 1022.tif
    📄 1023.tif
    📄 1024.tif
    📄 1025.tif
    📄 1026.tif
    📄 1027.tif
    📄 1028.tif
    📄 1029.tif
    ... and 818 more items
  📁 label/ (838 items)
    📄 1.tif
    📄 10.tif
    📄 1000.tif
    📄 1001.tif
    📄 1002.tif
    📄 1003.tif
    📄 1012.tif
    📄 1013.tif
    📄 1014.tif
    📄 1015.tif
    📄 1020.tif
    📄 1021.tif
    📄 1022.tif
    📄 1023.tif
    📄 1024.tif
    📄 1025.tif
    📄 1026.tif
    📄 1027.tif
    📄 1028.tif
    📄 1029.tif
    ... and 818 more items



## 📊 Step 4: Load Dataset with Flexible Path Detection

In [21]:
# Function to find images with multiple extensions
def find_images(directory, extensions=['*.jpg', '*.jpeg', '*.png', '*.JPG', '*.PNG', '*.JPEG', '*.tif', '*.TIF', '*.tiff', '*.TIFF']):
    """Find all images in directory with given extensions"""
    all_images = []
    for ext in extensions:
        all_images.extend(glob(os.path.join(directory, ext)))
        # Also search recursively in case images are in subdirectories
        all_images.extend(glob(os.path.join(directory, '**', ext), recursive=True))
    return sorted(list(set(all_images)))  # Remove duplicates and sort

def auto_find_dataset_dirs(base_path):
    """Automatically find image and mask directories"""
    possible_image_dirs = ['images', 'image', 'img', 'train', 'train_images', 'data/images']
    possible_mask_dirs = ['masks', 'mask', 'labels', 'label', 'train_masks', 'data/masks', 'ground_truth', 'gt']
    
    image_dir = None
    mask_dir = None
    
    # Try to find image directory
    for dir_name in possible_image_dirs:
        test_path = os.path.join(base_path, dir_name)
        if os.path.exists(test_path):
            # Check if it has images
            test_images = find_images(test_path)
            if len(test_images) > 0:
                image_dir = test_path
                print(f"✅ Found images in: {dir_name}")
                break
    
    # Try to find mask directory
    for dir_name in possible_mask_dirs:
        test_path = os.path.join(base_path, dir_name)
        if os.path.exists(test_path):
            # Check if it has images
            test_masks = find_images(test_path)
            if len(test_masks) > 0:
                mask_dir = test_path
                print(f"✅ Found masks in: {dir_name}")
                break
    
    return image_dir, mask_dir

# Try to auto-detect
print("🔍 Auto-detecting dataset structure...\n")
detected_image_dir, detected_mask_dir = auto_find_dataset_dirs(base_path)

# Update config if found
if detected_image_dir:
    config.IMAGE_DIR = detected_image_dir
if detected_mask_dir:
    config.MASK_DIR = detected_mask_dir

print(f"\n📁 Using directories:")
print(f"  Images: {config.IMAGE_DIR}")
print(f"  Masks: {config.MASK_DIR}")
print("="*50)

🔍 Auto-detecting dataset structure...


📁 Using directories:
  Images: /kaggle/input/kssd2025-kidney-stone-segmentation-dataset/data/images
  Masks: /kaggle/input/kssd2025-kidney-stone-segmentation-dataset/data/masks


## 📥 Step 5: Load and Match Images with Masks

In [25]:
import os
from glob import glob
import pandas as pd

# ===============================
# CONFIG
# ===============================
class config:
    IMAGE_DIR = "/kaggle/input/kssd2025-kidney-stone-segmentation-dataset/data/image"
    MASK_DIR  = "/kaggle/input/kssd2025-kidney-stone-segmentation-dataset/data/label"

# base path only for fallback search
base_path = "/kaggle/input/kssd2025-kidney-stone-segmentation-dataset"

# ===============================
# Helper
# ===============================
def find_images(folder):
    return (
        glob(os.path.join(folder, "*.png")) +
        glob(os.path.join(folder, "*.jpg")) +
        glob(os.path.join(folder, "*.jpeg"))
    )

# Get image and mask paths
print("📥 Loading dataset...\n")

image_paths = find_images(config.IMAGE_DIR)
mask_paths = find_images(config.MASK_DIR)

print(f"📊 Dataset Statistics:")
print(f"  🖼️  Total Images Found: {len(image_paths)}")
print(f"  🎭 Total Masks Found: {len(mask_paths)}")

if len(image_paths) == 0:
    print("\n❌ ERROR: No images found!")
    print("\nLet me search the entire dataset directory...")
    all_images = glob(os.path.join(base_path, '**', '*.jpg'), recursive=True) + \
                 glob(os.path.join(base_path, '**', '*.png'), recursive=True)
    if len(all_images) > 0:
        print(f"\nFound {len(all_images)} images in total. Showing first 10:")
        for img in all_images[:10]:
            print(f"  {img}")
else:
    print("\n✅ Images loaded successfully!")
    print(f"\nFirst 5 image paths:")
    for img in image_paths[:5]:
        print(f"  {img}")

if len(mask_paths) == 0:
    print("\n❌ ERROR: No masks found!")
else:
    print("\n✅ Masks loaded successfully!")
    print(f"\nFirst 5 mask paths:")
    for mask in mask_paths[:5]:
        print(f"  {mask}")

# Match images and masks by filename
if len(image_paths) > 0 and len(mask_paths) > 0:
    # Extract filenames (without extension and path)
    def get_base_name(path):
        return os.path.splitext(os.path.basename(path))[0]
    
    image_dict = {get_base_name(p): p for p in image_paths}
    mask_dict = {get_base_name(p): p for p in mask_paths}
    
    # Find matching pairs
    matched_data = []
    unmatched_images = []
    
    for img_name, img_path in image_dict.items():
        if img_name in mask_dict:
            matched_data.append({
                'image_path': img_path,
                'mask_path': mask_dict[img_name],
                'filename': img_name
            })
        else:
            unmatched_images.append(img_name)
    
    data_df = pd.DataFrame(matched_data)
    
    print(f"\n✅ Matched {len(data_df)} image-mask pairs")
    
    if len(unmatched_images) > 0:
        print(f"⚠️ Warning: {len(unmatched_images)} images without matching masks")
        if len(unmatched_images) <= 5:
            print(f"Unmatched: {unmatched_images}")
    
    if len(data_df) > 0:
        print(f"\n📋 Dataset Preview:")
        print(data_df.head(10))
    else:
        print("\n❌ No matching image-mask pairs found!")
        data_df = None
else:
    print("\n❌ Cannot create dataset - missing images or masks")
    data_df = None

print("\n" + "="*50)

📥 Loading dataset...

📊 Dataset Statistics:
  🖼️  Total Images Found: 0
  🎭 Total Masks Found: 0

❌ ERROR: No images found!

Let me search the entire dataset directory...

❌ ERROR: No masks found!

❌ Cannot create dataset - missing images or masks

