# Medical Imaging & Vesuvius Challenge Training Notebook

**Notebook Version** - Self-contained notebook with duplicated data folder.

This notebook is adapted from `train.py` to run on Kaggle. It supports:

## Medical Image Classification
- **Classification** across multiple medical imaging domains (brain, lungs, skin, breast, bone)
- **Segmentation** with a UNet model

## Vesuvius Challenge (3D CT Volume Segmentation)
- **3D volume segmentation** for ancient scroll CT scans
- **Competition metrics**: Surface Dice, TopoScore, VOI
- **.tif volume I/O** with variable dimension handling
- **2.5D UNet** with multi-slice context

All code is self-contained - no external Python files needed!

## Required Dependencies

**Automatic Installation**: Run cell 0 (Install Required Packages) to automatically install all dependencies.

**Manual Installation** (if needed):
```bash
pip install torch torchvision
pip install tifffile scipy scikit-image
pip install pillow numpy
```

**Note**: The notebook includes fallback handling if some dependencies are missing, but full functionality requires all packages. After installing packages, restart the kernel and run all cells.
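
To see which packages are missing before running the install cell, one option is to probe for each import name without importing it. The sketch below uses the standard library's `importlib.util.find_spec`; the helper name `find_missing` and the `REQUIRED` mapping are illustrative and not part of the notebook.

```python
import importlib.util

# pip name -> import name (the two differ for scikit-image and pillow)
REQUIRED = {
    "torch": "torch",
    "torchvision": "torchvision",
    "tifffile": "tifffile",
    "scipy": "scipy",
    "scikit-image": "skimage",
    "pillow": "PIL",
    "numpy": "numpy",
}

def find_missing(required):
    """Return the pip names whose import name cannot be located."""
    missing = []
    for pip_name, import_name in required.items():
        # find_spec returns None when a top-level module is not installed
        if importlib.util.find_spec(import_name) is None:
            missing.append(pip_name)
    return missing

missing = find_missing(REQUIRED)
if missing:
    print("pip install " + " ".join(missing))
else:
    print("All required packages are importable.")
```

Because nothing is imported, this check is cheap and has no side effects, but it only confirms that a package can be found, not that it imports cleanly.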

## Data Directory Setup

The notebook uses a configurable `DATA_ROOT` directory (set in the code cell of section 1, Imports and Setup) that works across different environments:

- **Local PC**: `DATA_ROOT = 'data'` (default)
- **Kaggle**: Set `DATA_ROOT = '/kaggle/input/vesuvius-challenge-surface-detection'` or use environment variable
- **Google Colab**: Set `DATA_ROOT = '/content/data'` or mount your drive
- **Cloud/Remote**: Set via environment variable: `export DATA_ROOT=/path/to/data`

All data paths in the notebook are relative to `DATA_ROOT`, making it easy to switch between environments.
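
The lookup order described above can be sketched as a small helper. This is a minimal sketch, not the notebook's own code: `resolve_data_root` and the `CANDIDATES` tuple are hypothetical, and the candidate paths are simply the ones listed in this section.

```python
import os

# Hypothetical candidate locations, checked in order; adjust to your setup.
CANDIDATES = (
    "/kaggle/input/vesuvius-challenge-surface-detection",  # Kaggle
    "/content/data",                                        # Google Colab
    "data",                                                 # local default
)

def resolve_data_root(env=os.environ, candidates=CANDIDATES):
    """Pick DATA_ROOT: the environment variable wins, then the first existing candidate."""
    override = env.get("DATA_ROOT")
    if override:
        return override
    for path in candidates:
        if os.path.isdir(path):
            return path
    # Nothing exists yet: fall back to the last candidate (the local default)
    return candidates[-1]

DATA_ROOT = resolve_data_root()
# Every dataset path is then built relative to DATA_ROOT
train_dir = os.path.join(DATA_ROOT, "brain/brain/Training")
print(train_dir)
```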

---

## Citation

If you use this notebook for the Vesuvius Challenge, please cite:

```bibtex
@misc{vesuvius-challenge-surface-detection,
    author = {Sean Johnson and David Josey and Elian Rafael Dal Prà and Hendrik Schilling and Youssef Nader and Johannes Rudolph and Forrest McDonald and Paul Henderson and Giorgio Angelotti and Sohier Dane and María Cruz},
    title = {Vesuvius Challenge - Surface Detection},
    year = {2025},
    howpublished = {\url{https://kaggle.com/competitions/vesuvius-challenge-surface-detection}},
    note = {Kaggle}
}
```

**Competition**: [Vesuvius Challenge - Surface Detection](https://kaggle.com/competitions/vesuvius-challenge-surface-detection)


## 0. Install Required Packages

**Run this cell first** to automatically install all required dependencies.

- ✅ **Kaggle**: Most packages are pre-installed, but this will install any missing ones
- ✅ **Google Colab**: Will install all required packages
- ✅ **Local Jupyter**: Will install packages in your current environment
- ⚠️ **After installation**: Restart the kernel and run all cells for imports to work

You can skip this cell if all packages are already installed.


In [20]:
# Install required packages
# This cell will install all dependencies needed for the notebook
# You can skip this if packages are already installed

import subprocess
import sys

def install_package(pip_name, import_name=None):
    """
    Install a package using pip, with error handling.
    
    Args:
        pip_name: Package name for pip install (e.g., 'scikit-image')
        import_name: Import name if different (e.g., 'skimage'). If None, uses pip_name
    """
    if import_name is None:
        import_name = pip_name
    
    # Check if already installed
    try:
        __import__(import_name)
        print(f"✅ {pip_name} is already installed")
        return True
    except ImportError:
        print(f"📦 Installing {pip_name}...")
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", pip_name, "--quiet"])
            # Try importing again to verify
            try:
                __import__(import_name)
                print(f"✅ {pip_name} installed successfully")
                return True
            except ImportError:
                print(f"⚠️  {pip_name} installed but import '{import_name}' failed - may need kernel restart")
                return True  # Still return True as package was installed
        except subprocess.CalledProcessError:
            print(f"❌ Failed to install {pip_name}")
            return False

# List of required packages: (pip_name, import_name)
# Some packages have different pip names vs import names
required_packages = [
    ("torch", "torch"),
    ("torchvision", "torchvision"),
    ("tifffile", "tifffile"),
    ("scipy", "scipy"),
    ("scikit-image", "skimage"),  # pip: scikit-image, import: skimage
    ("pillow", "PIL"),  # pip: pillow, import: PIL
    ("numpy", "numpy"),
]

print("="*60)
print("Checking and installing required packages...")
print("="*60)

installed_count = 0
for pip_name, import_name in required_packages:
    if install_package(pip_name, import_name):
        installed_count += 1

print("="*60)
print(f"Package installation complete: {installed_count}/{len(required_packages)} packages ready")
print("="*60)
print("\n⚠️  Note: If packages were just installed, you may need to restart the kernel")
print("   and run the cells again for imports to work properly.")
print("   In Jupyter: Kernel -> Restart & Run All")
print("   In Kaggle/Colab: Runtime -> Restart runtime")


Checking and installing required packages...
✅ torch is already installed
✅ torchvision is already installed
✅ tifffile is already installed
✅ scipy is already installed
📦 Installing scikit-image...


❌ Failed to install scikit-image
✅ pillow is already installed
✅ numpy is already installed
Package installation complete: 6/7 packages ready

⚠️  Note: If packages were just installed, you may need to restart the kernel
   and run the cells again for imports to work properly.
   In Jupyter: Kernel -> Restart & Run All
   In Kaggle/Colab: Runtime -> Restart runtime
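
The log above shows that installing scikit-image failed but not why, because `install_package` does not capture pip's output. A minimal sketch, assuming you want the error text: run the command with `subprocess.run` and keep the tail of stderr. `run_and_capture` and `pip_install` are hypothetical helper names.

```python
import subprocess
import sys

def run_and_capture(cmd, tail_lines=5):
    """Run cmd; return (succeeded, last few lines of stderr)."""
    result = subprocess.run(cmd, capture_output=True, text=True)
    tail = "\n".join(result.stderr.strip().splitlines()[-tail_lines:])
    return result.returncode == 0, tail

def pip_install(pip_name):
    """Install one package with the current interpreter's pip."""
    return run_and_capture([sys.executable, "-m", "pip", "install", pip_name])

# Usage (left commented out so that this sketch installs nothing on its own):
# ok, err = pip_install("scikit-image")
# if not ok:
#     print(err)
```

The last lines of stderr usually name the cause, for example a missing network connection or an incompatible Python version.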


## 1. Imports and Setup


In [21]:
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms, models
from PIL import ImageFile, Image
import warnings
import random
import numpy as np
from collections import defaultdict
from pathlib import Path
from typing import Tuple, Optional, List

# ============================================================================
# Data Directory Configuration
# ============================================================================
# Set base data directory - can be overridden via environment variable
# This allows the notebook to work on multiple PCs and cloud environments
DATA_ROOT = os.environ.get('DATA_ROOT', 'data')  # Default: 'data' in current directory
# Alternative: Set manually for your environment:
# DATA_ROOT = '/kaggle/input/vesuvius-challenge-surface-detection'  # Kaggle
# DATA_ROOT = '/content/data'  # Google Colab
# DATA_ROOT = './data'  # Local development

print(f"Data root directory: {DATA_ROOT}")
print("(Set DATA_ROOT environment variable to change, or edit DATA_ROOT in this cell)")

# Scientific computing imports (for metrics)
try:
    from scipy import ndimage
    from scipy.ndimage import distance_transform_edt, binary_erosion, zoom
    SCIPY_AVAILABLE = True
except ImportError:
    print("Warning: scipy not installed. Some metrics may not work. Install with: pip install scipy")
    SCIPY_AVAILABLE = False

try:
    from skimage import measure, morphology
    SKIMAGE_AVAILABLE = True
except ImportError:
    print("Warning: scikit-image not installed. Some metrics may not work. Install with: pip install scikit-image")
    SKIMAGE_AVAILABLE = False

# Vesuvius Challenge specific imports
try:
    import tifffile
    TIF_FILE_AVAILABLE = True
except ImportError:
    print("Warning: tifffile not installed. Install with: pip install tifffile")
    TIF_FILE_AVAILABLE = False

# Allow loading truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Suppress PIL warnings
warnings.filterwarnings('ignore', category=UserWarning, module='PIL')

# Kaggle/Linux compatibility - use 2 workers on Linux, 0 on Windows
NUM_WORKERS = 0 if sys.platform == 'win32' else 2

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Print dependency status
print("="*60)
print("Dependency Status:")
print("="*60)
print(f"Device: {DEVICE}")
print(f"tifffile: {'✅ Available' if TIF_FILE_AVAILABLE else '❌ Not installed (required for Vesuvius Challenge)'}")
print(f"scipy: {'✅ Available' if SCIPY_AVAILABLE else '❌ Not installed (required for metrics and resizing)'}")
print(f"scikit-image: {'✅ Available' if SKIMAGE_AVAILABLE else '❌ Not installed (required for topology metrics)'}")
print("="*60)

# Warn if critical dependencies missing
# Each entry is (pip package name, what it is needed for)
missing_deps = []
if not TIF_FILE_AVAILABLE:
    missing_deps.append(("tifffile", "for Vesuvius Challenge"))
if not SCIPY_AVAILABLE:
    missing_deps.append(("scipy", "for metrics and resizing"))
if not SKIMAGE_AVAILABLE:
    missing_deps.append(("scikit-image", "for topology metrics"))

if missing_deps:
    described = ", ".join(f"{name} ({why})" for name, why in missing_deps)
    print(f"\n⚠️  Warning: Missing dependencies: {described}")
    # Only the package names belong in the pip command
    print("   Install with: pip install " + " ".join(name for name, _ in missing_deps))
    print("   Some features may not work correctly.\n")
else:
    print("\n✅ All dependencies available!\n")


Data root directory: data
(Set DATA_ROOT environment variable to change, or edit DATA_ROOT in this cell)
Dependency Status:
Device: cpu
tifffile: ✅ Available
scipy: ✅ Available
scikit-image: ❌ Not installed (required for topology metrics)

   Install with: pip install scikit-image
   Some features may not work correctly.



## 2. Utility Functions


In [22]:
def set_seed(seed=42):
    """Set random seed for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Set seed for reproducibility
set_seed(42)
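
A quick way to confirm that seeding takes effect is to draw values, re-seed, and draw again. The sketch below does this for the `random` and NumPy generators only, so it runs without PyTorch; `seed_basic` is a hypothetical stand-in for the non-torch part of `set_seed`.

```python
import random
import numpy as np

def seed_basic(seed=42):
    """Seed Python's and NumPy's global generators (the non-torch part of set_seed)."""
    random.seed(seed)
    np.random.seed(seed)

def draw():
    """Draw one sample from each generator."""
    return random.random(), np.random.rand(3).tolist()

seed_basic(42)
first = draw()
seed_basic(42)
second = draw()
print(first == second)  # True: identical seeds give identical draws
```

The same check can be repeated with `torch.rand` after calling `set_seed`. Note that `DataLoader` worker processes and some GPU kernels may still introduce run-to-run differences that global seeding alone does not remove.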


## 3. Model Definitions


In [23]:
# ResNet Classifier
class ResNetClassifier(nn.Module):
    """
    ResNet-based classifier for medical image classification.
    Uses a pretrained ResNet18 as backbone with a custom classification head.
    """
    
    def __init__(self, num_classes=4, pretrained=True):
        """
        Args:
            num_classes: Number of output classes
            pretrained: Whether to use pretrained ResNet weights
        """
        super(ResNetClassifier, self).__init__()
        
        # Load pretrained ResNet18
        if pretrained:
            resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        else:
            resnet = models.resnet18(weights=None)
        
        # Remove the final fully connected layer
        self.features = nn.Sequential(*list(resnet.children())[:-1])
        
        # Get the number of features from ResNet18
        num_features = resnet.fc.in_features
        
        # Custom classification head
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(num_features, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )
    
    def forward(self, x):
        # Extract features
        x = self.features(x)
        x = x.view(x.size(0), -1)  # Flatten
        
        # Classification
        x = self.classifier(x)
        return x
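
For a sense of the size of the custom head: ResNet18's final feature vector has 512 entries, so the head is a 512 → 512 → `num_classes` stack. The arithmetic below counts its trainable parameters; `linear_params` and `head_param_count` are illustrative helpers, not part of the notebook.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

def head_param_count(num_features=512, hidden=512, num_classes=4):
    """Parameters in Linear(num_features, hidden) + Linear(hidden, num_classes).

    Dropout and ReLU contribute no parameters.
    """
    return linear_params(num_features, hidden) + linear_params(hidden, num_classes)

for n in (2, 3, 4, 5, 9):  # class counts used in CONFIG
    print(n, head_param_count(num_classes=n))
```

Nearly all of the head's parameters sit in the first linear layer (262,656), so the count changes little with the number of classes.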


In [24]:
# UNet Segmentation Model
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.net(x)


class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.net = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels),
        )

    def forward(self, x):
        return self.net(x)


class Up(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
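        # The decoder feature map entering this block has in_channels // 2 channels;
        # the other half of in_channels comes from the skip connection after concatenation.
        self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)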

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Pad if needed
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)
        x1 = F.pad(
            x1,
            [diff_x // 2, diff_x - diff_x // 2,
             diff_y // 2, diff_y - diff_y // 2],
        )
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class UNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        self.inc = DoubleConv(in_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        self.up1 = Up(1024, 256)
        self.up2 = Up(512, 128)
        self.up3 = Up(256, 64)
        self.up4 = Up(128, 64)
        self.outc = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits


class UNet2_5D(nn.Module):
    """2.5D UNet for Vesuvius Challenge - processes multi-slice context"""
    def __init__(self, in_channels=3, out_channels=1):
        """
        Args:
            in_channels: Number of input channels (slice_window for 2.5D)
            out_channels: Number of output channels (1 for binary segmentation)
        """
        super().__init__()
        self.inc = DoubleConv(in_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        self.up1 = Up(1024, 256)
        self.up2 = Up(512, 128)
        self.up3 = Up(256, 64)
        self.up4 = Up(128, 64)
        self.outc = nn.Conv2d(64, out_channels, kernel_size=1)
        
        # Optional: Add attention mechanism for slice importance
        self.slice_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # x: [B, C, H, W] where C = slice_window
        # Apply slice attention if multi-channel
        if x.shape[1] > 1:
            attn = self.slice_attention(x)
            x = x * attn
        
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits
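
The channel arguments of the `Up` blocks follow from the skip connections: each decoder stage concatenates the upsampled feature map with the encoder map of the same resolution. The pure-Python trace below reproduces that bookkeeping. It assumes the upsampling step keeps the channel count of the incoming decoder map, and `trace_unet` is a hypothetical helper.

```python
ENCODER = [64, 128, 256, 512, 512]   # output channels of inc, down1..down4
DECODER = [256, 128, 64, 64]         # output channels of up1..up4

def trace_unet(height, width):
    """Return (stage, concat_channels, out_channels, h, w) for each decoder stage."""
    # Each Down halves the resolution (MaxPool2d(2) floors odd sizes)
    sizes = [(height, width)]
    for _ in ENCODER[1:]:
        h, w = sizes[-1]
        sizes.append((h // 2, w // 2))

    rows = []
    channels = ENCODER[-1]                      # bottleneck feature map (x5)
    for i, out_ch in enumerate(DECODER):
        skip_ch = ENCODER[-2 - i]               # x4, x3, x2, x1
        skip_h, skip_w = sizes[-2 - i]          # output is padded to the skip's size
        concat = skip_ch + channels             # torch.cat([x2, x1], dim=1)
        rows.append((f"up{i + 1}", concat, out_ch, skip_h, skip_w))
        channels = out_ch
    return rows

for row in trace_unet(256, 256):
    print(row)
```

The concatenated channel counts come out as 1024, 512, 256 and 128, matching the first argument of `up1` through `up4`. Input sizes divisible by 16 pass through the four pooling steps without any rounding, so the padding in `Up.forward` becomes a no-op.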


## 4. Dataset Classes


In [25]:
class SegmentationDataset(Dataset):
    """Simple image/mask dataset for segmentation.
    Expects matching filenames in images_dir and masks_dir.
    """

    def __init__(self, images_dir, masks_dir, input_size=(256, 256)):
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.input_size = input_size
        self.image_files = sorted([
            f for f in os.listdir(images_dir)
            if os.path.isfile(os.path.join(images_dir, f))
        ])
        self.img_transform = transforms.Compose([
            transforms.Resize(self.input_size),
            transforms.Grayscale(),
            transforms.ToTensor(),
        ])
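        self.mask_transform = transforms.Compose([
            # Nearest-neighbour resizing keeps mask values binary (no blended edge pixels)
            transforms.Resize(self.input_size, interpolation=transforms.InterpolationMode.NEAREST),
            transforms.Grayscale(),
            transforms.ToTensor(),
        ])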

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        fname = self.image_files[idx]
        img_path = os.path.join(self.images_dir, fname)
        mask_path = os.path.join(self.masks_dir, fname)

        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        image = self.img_transform(image)
        mask = self.mask_transform(mask)
        # binarize mask
        mask = (mask > 0.5).float()

        return image, mask


class VesuviusVolumeDataset(Dataset):
    """Dataset for 3D CT volumes from Vesuvius Challenge.
    Handles variable dimensions and .tif volume files.
    """
    
    def __init__(
        self,
        volume_dir: str,
        mask_dir: Optional[str] = None,
        mode: str = "train",
        slice_window: int = 3,  # 2.5D: use N slices as context
        target_size: Optional[Tuple[int, int]] = None,
        augment: bool = False
    ):
        """
        Args:
            volume_dir: Directory containing .tif volume files
            mask_dir: Directory containing .tif mask files (optional for test)
            mode: 'train', 'val', or 'test'
            slice_window: Number of slices to use as context (for 2.5D)
            target_size: Optional (H, W) to resize to, None keeps original size
            augment: Whether to apply data augmentation
        """
        self.volume_dir = Path(volume_dir)
        self.mask_dir = Path(mask_dir) if mask_dir else None
        self.mode = mode
        self.slice_window = slice_window
        self.target_size = target_size
        self.augment = augment
        
        # Find all volume files
        self.volume_files = sorted(self.volume_dir.glob("*.tif"))
        if not self.volume_files:
            self.volume_files = sorted(self.volume_dir.glob("*.tiff"))
        
        if not self.volume_files:
            raise ValueError(f"No .tif/.tiff files found in {volume_dir}")
        
        print(f"Found {len(self.volume_files)} volume files")
        
        # Pre-load volume metadata
        self.volumes_info = []
        for vol_path in self.volume_files:
            if TIF_FILE_AVAILABLE:
                try:
                    # Just read shape without loading full volume
                    with tifffile.TiffFile(vol_path) as tif:
                        shape = tif.series[0].shape
                        self.volumes_info.append({
                            'path': vol_path,
                            'shape': shape,
                            'num_slices': shape[0] if len(shape) == 3 else 1
                        })
                except Exception as e:
                    print(f"Warning: Could not read {vol_path}: {e}")
            else:
                # Fallback: assume standard shape
                self.volumes_info.append({
                    'path': vol_path,
                    'shape': None,
                    'num_slices': 100  # Default estimate
                })
    
    def __len__(self):
        """Total number of slices across all volumes"""
        return sum(info['num_slices'] for info in self.volumes_info)
    
    def _get_volume_and_slice_idx(self, idx):
        """Convert global index to (volume_idx, slice_idx)"""
        current = 0
        for vol_idx, info in enumerate(self.volumes_info):
            if idx < current + info['num_slices']:
                slice_idx = idx - current
                return vol_idx, slice_idx
            current += info['num_slices']
        # Fallback to last volume
        return len(self.volumes_info) - 1, self.volumes_info[-1]['num_slices'] - 1
    
    def _load_slice_with_context(self, volume_path: Path, slice_idx: int) -> np.ndarray:
        """Load a slice with surrounding context slices (2.5D)"""
        if not TIF_FILE_AVAILABLE:
            raise ImportError("tifffile required for volume loading")
        
        volume = tifffile.imread(str(volume_path))
        num_slices = volume.shape[0]
        
        # Get slice range with padding
        half_window = self.slice_window // 2
        start_idx = max(0, slice_idx - half_window)
        end_idx = min(num_slices, slice_idx + half_window + 1)
        
        # Extract slices
        slices = volume[start_idx:end_idx]
        
        # Pad if needed
        if len(slices) < self.slice_window:
            padding = self.slice_window - len(slices)
            if start_idx == 0:
                # Pad at beginning
                slices = np.concatenate([slices[:1].repeat(padding, axis=0), slices], axis=0)
            else:
                # Pad at end
                slices = np.concatenate([slices, slices[-1:].repeat(padding, axis=0)], axis=0)
        
        # Stack slices: [C, H, W] where C = slice_window
        if len(slices.shape) == 2:
            slices = slices[np.newaxis, :]
        
        # Take center slice if we have more than needed
        if slices.shape[0] > self.slice_window:
            center = slices.shape[0] // 2
            start = center - half_window
            slices = slices[start:start + self.slice_window]
        
        # slices is already [C, H, W]; a single slice is returned as [H, W]
        if slices.shape[0] == 1:
            slice_data = slices[0]
        else:
            slice_data = slices
        
        return slice_data.astype(np.float32) / 255.0
    
    def __getitem__(self, idx):
        vol_idx, slice_idx = self._get_volume_and_slice_idx(idx)
        vol_info = self.volumes_info[vol_idx]
        
        # Load slice with context
        slice_data = self._load_slice_with_context(vol_info['path'], slice_idx)
        
        # Load corresponding mask if available
        mask = None
        if self.mask_dir and self.mode != "test":
            mask_path = self.mask_dir / vol_info['path'].name
            if mask_path.exists():
                try:
                    mask_volume = tifffile.imread(str(mask_path))
                    mask = mask_volume[slice_idx].astype(np.float32)
                    mask = (mask > 0.5).astype(np.float32)
                except Exception as e:
                    print(f"Warning: Could not load mask {mask_path}: {e}")
                    mask = np.zeros_like(slice_data[0] if len(slice_data.shape) == 3 else slice_data)
        
        # Resize if target_size specified
        if self.target_size:
            if not SCIPY_AVAILABLE:
                raise ImportError("scipy required for resizing. Install with: pip install scipy")
            h, w = self.target_size
            if len(slice_data.shape) == 3:
                # Multi-channel
                current_h, current_w = slice_data.shape[1], slice_data.shape[2]
                zoom_factors = (1, h / current_h, w / current_w)
                slice_data = zoom(slice_data, zoom_factors, order=1)
                if mask is not None:
                    mask = zoom(mask, (h / current_h, w / current_w), order=0)
            else:
                # Single channel
                current_h, current_w = slice_data.shape
                slice_data = zoom(slice_data, (h / current_h, w / current_w), order=1)
                if mask is not None:
                    mask = zoom(mask, (h / current_h, w / current_w), order=0)
        
        # Convert to torch tensors
        slice_tensor = torch.from_numpy(slice_data).float()
        if mask is not None:
            mask_tensor = torch.from_numpy(mask).float().unsqueeze(0)
        else:
            mask_tensor = torch.zeros(slice_tensor.shape[-2:]).float().unsqueeze(0)
        
        return {
            'image': slice_tensor,
            'mask': mask_tensor,
            'volume_id': vol_info['path'].stem,
            'slice_idx': slice_idx
        }
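
Two pieces of index arithmetic in `VesuviusVolumeDataset` can be tried in isolation: mapping a flat dataset index to a (volume, slice) pair, and gathering an edge-padded window of neighbouring slices. The sketch below is a compact NumPy variant, not the class's own code. `locate_slice` and `context_window` are hypothetical names, the index is assumed to be in range, and the window is assumed to be odd.

```python
import bisect
import numpy as np

def locate_slice(global_idx, slices_per_volume):
    """Map a flat dataset index to (volume_idx, slice_idx)."""
    # Running totals, e.g. [3, 5, 9] for volumes with 3, 2 and 4 slices
    ends = np.cumsum(slices_per_volume).tolist()
    vol_idx = bisect.bisect_right(ends, global_idx)
    start = ends[vol_idx - 1] if vol_idx > 0 else 0
    return vol_idx, global_idx - start

def context_window(volume, slice_idx, window=3):
    """Return `window` slices centred on slice_idx as [C, H, W].

    Neighbours that fall outside the volume are replaced by the nearest
    edge slice, which mirrors repeat-padding at the volume boundaries.
    """
    half = window // 2
    idx = np.arange(slice_idx - half, slice_idx + half + 1)
    idx = np.clip(idx, 0, volume.shape[0] - 1)
    return volume[idx]

# Toy volume: 4 slices of 2x2 pixels, where slice z is filled with the value z
volume = np.arange(4, dtype=np.float32)[:, None, None] * np.ones((1, 2, 2), dtype=np.float32)
print(locate_slice(3, [3, 2, 4]))               # (1, 0)
print(context_window(volume, 0)[:, 0, 0])       # [0. 0. 1.]
```

Clipping the indices avoids the separate pad-at-start and pad-at-end branches, at the cost of supporting only odd window sizes.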


## 5. Loss Functions and Competition Metrics


In [26]:
def dice_loss_logits(logits, targets, eps=1e-6):
    """Dice loss for binary segmentation given logits."""
    probs = torch.sigmoid(logits)
    num = 2 * (probs * targets).sum(dim=(2, 3))
    den = probs.sum(dim=(2, 3)) + targets.sum(dim=(2, 3)) + eps
    return 1 - (num / den).mean()


# ============================================================================
# Vesuvius Challenge Competition Metrics
# ============================================================================

def surface_dice_score(pred: np.ndarray, target: np.ndarray, tolerance: float = 1.0) -> float:
    """
    Surface Dice Score - measures distance between surfaces.
    
    Args:
        pred: Binary prediction mask [H, W] or [Z, H, W]
        target: Binary target mask [H, W] or [Z, H, W]
        tolerance: Distance tolerance in pixels
    
    Returns:
        Surface Dice score (0-1, higher is better)
    """
    pred = (pred > 0.5).astype(np.uint8)
    target = (target > 0.5).astype(np.uint8)
    
    if pred.ndim == 3:
        # 3D volume: compute per-slice and average
        scores = []
        for z in range(pred.shape[0]):
            if target[z].sum() > 0 or pred[z].sum() > 0:
                score = surface_dice_score(pred[z], target[z], tolerance)
                scores.append(score)
        return np.mean(scores) if scores else 0.0
    
    # 2D: compute surface distances
    if not SCIPY_AVAILABLE:
        # Fallback to simple Dice if scipy not available
        intersection = (pred * target).sum()
        union = pred.sum() + target.sum()
        return float(2 * intersection / union) if union > 0 else 0.0
    
    # Get surfaces (boundaries)
    pred_surface = pred - binary_erosion(pred)
    target_surface = target - binary_erosion(target)
    
    if pred_surface.sum() == 0 and target_surface.sum() == 0:
        return 1.0
    
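    # Distance from every pixel to the nearest surface pixel.
    # The surfaces are uint8 arrays, so compare with 0 instead of using `~`
    # (bitwise NOT on uint8 maps 0 -> 255 and 1 -> 254, leaving no zeros).
    dist_pred_to_target = distance_transform_edt(target_surface == 0)
    dist_target_to_pred = distance_transform_edt(pred_surface == 0)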
    
    # Count surface points within tolerance
    pred_in_tolerance = (dist_pred_to_target[pred_surface > 0] <= tolerance).sum()
    target_in_tolerance = (dist_target_to_pred[target_surface > 0] <= tolerance).sum()
    
    total_pred_surface = pred_surface.sum()
    total_target_surface = target_surface.sum()
    
    if total_pred_surface == 0 or total_target_surface == 0:
        return 0.0
    
    surface_dice = (pred_in_tolerance + target_in_tolerance) / (total_pred_surface + total_target_surface)
    return float(surface_dice)


def topo_score(pred: np.ndarray, target: np.ndarray) -> float:
    """
    TopoScore - measures topological similarity.
    Penalizes artificial mergers and splits.
    
    Args:
        pred: Binary prediction mask [H, W] or [Z, H, W]
        target: Binary target mask [H, W] or [Z, H, W]
    
    Returns:
        TopoScore (0-1, higher is better)
    """
    pred = (pred > 0.5).astype(np.uint8)
    target = (target > 0.5).astype(np.uint8)
    
    if pred.ndim == 3:
        # 3D: compute per-slice and average
        scores = []
        for z in range(pred.shape[0]):
            if target[z].sum() > 0 or pred[z].sum() > 0:
                score = topo_score(pred[z], target[z])
                scores.append(score)
        return np.mean(scores) if scores else 0.0
    
    # 2D: count connected components
    if not SKIMAGE_AVAILABLE:
        # Fallback: simple component counting
        return 1.0 if (pred == target).all() else 0.5
    
    pred_labels = measure.label(pred)
    target_labels = measure.label(target)
    
    num_pred_components = pred_labels.max()
    num_target_components = target_labels.max()
    
    if num_target_components == 0:
        return 1.0 if num_pred_components == 0 else 0.0
    
    # Penalize difference in number of components
    component_ratio = min(num_pred_components, num_target_components) / max(num_pred_components, num_target_components)
    
    return float(component_ratio)


def voi_score(pred: np.ndarray, target: np.ndarray) -> float:
    """
    Variation of Information (VOI) - measures information distance.
    
    Args:
        pred: Binary prediction mask [H, W] or [Z, H, W]
        target: Binary target mask [H, W] or [Z, H, W]
    
    Returns:
        Normalized score in [0, 1], computed as 1 - VOI / max_entropy; higher is better (a lower raw VOI gives a higher score)
    """
    pred = (pred > 0.5).astype(np.uint8)
    target = (target > 0.5).astype(np.uint8)
    
    if pred.ndim == 3:
        # 3D: compute per-slice and average
        scores = []
        for z in range(pred.shape[0]):
            if target[z].sum() > 0 or pred[z].sum() > 0:
                score = voi_score(pred[z], target[z])
                scores.append(score)
        return np.mean(scores) if scores else 0.0
    
    # 2D: compute VOI
    if not SKIMAGE_AVAILABLE:
        # Fallback: simple similarity
        intersection = (pred * target).sum()
        union = pred.sum() + target.sum()
        return float(intersection / union) if union > 0 else 0.0
    
    pred_labels = measure.label(pred)
    target_labels = measure.label(target)
    
    # Get unique labels
    pred_unique = np.unique(pred_labels)
    target_unique = np.unique(target_labels)
    
    # Remove background (0)
    pred_unique = pred_unique[pred_unique > 0]
    target_unique = target_unique[target_unique > 0]
    
    if len(pred_unique) == 0 and len(target_unique) == 0:
        return 1.0
    
    # Total pixel count, used to convert label areas into probabilities
    h, w = pred.shape
    total_pixels = h * w
    
    # Compute entropy
    pred_entropy = 0.0
    for label in pred_unique:
        p = (pred_labels == label).sum() / total_pixels
        if p > 0:
            pred_entropy -= p * np.log2(p)
    
    target_entropy = 0.0
    for label in target_unique:
        p = (target_labels == label).sum() / total_pixels
        if p > 0:
            target_entropy -= p * np.log2(p)
    
    # Compute mutual information
    mi = 0.0
    for p_label in pred_unique:
        for t_label in target_unique:
            joint = ((pred_labels == p_label) & (target_labels == t_label)).sum() / total_pixels
            p_pred = (pred_labels == p_label).sum() / total_pixels
            p_target = (target_labels == t_label).sum() / total_pixels
            if joint > 0 and p_pred > 0 and p_target > 0:
                mi += joint * np.log2(joint / (p_pred * p_target))
    
    # VOI = H(pred) + H(target) - 2*MI
    voi = pred_entropy + target_entropy - 2 * mi
    
    # Normalize to [0, 1] (higher is better)
    max_entropy = max(pred_entropy, target_entropy) * 2
    normalized_voi = 1.0 - (voi / max_entropy) if max_entropy > 0 else 1.0
    
    return float(np.clip(normalized_voi, 0.0, 1.0))


def compute_competition_metrics(pred: np.ndarray, target: np.ndarray) -> dict:
    """
    Compute all Vesuvius Challenge competition metrics.
    
    Args:
        pred: Prediction mask [H, W] or [Z, H, W]
        target: Target mask [H, W] or [Z, H, W]
    
    Returns:
        Dictionary with metrics: surface_dice, topo_score, voi_score, combined_score
    """
    surface_dice = surface_dice_score(pred, target)
    topo = topo_score(pred, target)
    voi = voi_score(pred, target)
    
    # Combined score: weighted average of the three metrics.
    # The 0.4 / 0.3 / 0.3 weights are this notebook's defaults and are not equal;
    # check the competition's evaluation page for the official weighting.
    combined = (surface_dice * 0.4 + topo * 0.3 + voi * 0.3)
    
    return {
        'surface_dice': surface_dice,
        'topo_score': topo,
        'voi_score': voi,
        'combined_score': combined
    }
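
For comparison, the textbook definition of Variation of Information treats each labelling as a full partition (background included) and works from the joint label distribution: VOI = H(A) + H(B) - 2 I(A; B). The sketch below implements that definition with NumPy only. It is not the competition's official scorer, and its values will differ from `voi_score` above, which leaves the background out of the entropies; `variation_of_information` is a hypothetical name.

```python
import numpy as np

def _entropy(p):
    """Shannon entropy in bits, ignoring zero-probability cells."""
    p = p[p > 0]
    return float(-(p * np.log2(p)).sum())

def variation_of_information(labels_a, labels_b):
    """VOI between two labelings of the same pixels (0 means identical partitions)."""
    a = np.asarray(labels_a).ravel()
    b = np.asarray(labels_b).ravel()
    # Re-index labels to 0..K-1 so they can address a contingency table
    _, a_idx = np.unique(a, return_inverse=True)
    _, b_idx = np.unique(b, return_inverse=True)
    joint = np.zeros((a_idx.max() + 1, b_idx.max() + 1))
    np.add.at(joint, (a_idx, b_idx), 1)   # count co-occurring label pairs
    joint /= joint.sum()                  # counts -> joint probabilities
    h_a = _entropy(joint.sum(axis=1))     # marginal entropy of labelling A
    h_b = _entropy(joint.sum(axis=0))     # marginal entropy of labelling B
    h_ab = _entropy(joint.ravel())        # joint entropy
    mutual_info = h_a + h_b - h_ab
    return h_a + h_b - 2.0 * mutual_info

same = variation_of_information([0, 0, 1, 1], [0, 0, 1, 1])
independent = variation_of_information([0, 0, 1, 1], [0, 1, 0, 1])
print(same, independent)
```

Applied to the outputs of `measure.label`, this gives the raw VOI in bits. Mapping it to a 0-1 score still requires a normalization choice such as the one in `voi_score`.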


## 6. Configuration


In [27]:
# Classification configuration per domain
# All paths are relative to DATA_ROOT (set above)
CONFIG = {
    "brain": {
        "train_path": os.path.join(DATA_ROOT, "brain/brain/Training"),
        "val_path": os.path.join(DATA_ROOT, "brain/brain/Testing"),
        "test_path": os.path.join(DATA_ROOT, "brain/brain/Testing"),  # Using Testing as test set
        "model": ResNetClassifier(num_classes=4),
        "input_size": (224, 224),
    },
    "lungs": {
        "train_path": os.path.join(DATA_ROOT, "lungs/train"),
        "val_path": os.path.join(DATA_ROOT, "lungs/val"),
        "test_path": os.path.join(DATA_ROOT, "lungs/val"),  # Using val as test if no separate test set
        "model": ResNetClassifier(num_classes=5),
        "input_size": (224, 224),
    },
    "skin": {
        "train_path": os.path.join(DATA_ROOT, "skin/train"),
        "val_path": os.path.join(DATA_ROOT, "skin/val"),
        "test_path": os.path.join(DATA_ROOT, "skin/val"),  # Using val as test if no separate test set
        # Skin dataset has 9 classes (including Vascular lesion)
        "model": ResNetClassifier(num_classes=9),
        "input_size": (128, 128),
    },
    "breast": {
        "train_path": os.path.join(DATA_ROOT, "breast_split/train") if os.path.exists(os.path.join(DATA_ROOT, "breast_split/train")) else os.path.join(DATA_ROOT, "breast"),
        "val_path": os.path.join(DATA_ROOT, "breast_split/val") if os.path.exists(os.path.join(DATA_ROOT, "breast_split/val")) else os.path.join(DATA_ROOT, "breast"),
        "test_path": os.path.join(DATA_ROOT, "breast_split/test") if os.path.exists(os.path.join(DATA_ROOT, "breast_split/test")) else os.path.join(DATA_ROOT, "breast"),
        "model": ResNetClassifier(num_classes=3),
        "input_size": (128, 128),
    },
    "bone": {
        "train_path": os.path.join(DATA_ROOT, "Bone/train"),
        "val_path": os.path.join(DATA_ROOT, "Bone/val"),
        "test_path": os.path.join(DATA_ROOT, "Bone/test"),
        "model": ResNetClassifier(num_classes=2),
        "input_size": (224, 224),
    },
}

print("Configuration loaded for domains:", list(CONFIG.keys()))
print(f"Data root: {DATA_ROOT}")


Configuration loaded for domains: ['brain', 'lungs', 'skin', 'breast', 'bone']
Data root: data
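
`CONFIG` constructs all five `ResNetClassifier` instances as soon as the cell runs, even when only one domain is trained. If that becomes a concern, one option is to store a factory and build the model on demand. The sketch below illustrates the idea with a hypothetical stand-in class `TinyClassifier` and a hypothetical `model_factory` key; neither is part of the notebook.

```python
from functools import partial


class TinyClassifier:
    """Hypothetical stand-in for the notebook's ResNetClassifier."""

    instances = 0  # counts how many models have been constructed

    def __init__(self, num_classes):
        TinyClassifier.instances += 1
        self.num_classes = num_classes


# Store a callable per domain instead of an already-built model
LAZY_CONFIG = {
    "brain": {"model_factory": partial(TinyClassifier, num_classes=4)},
    "bone": {"model_factory": partial(TinyClassifier, num_classes=2)},
}

built_before = TinyClassifier.instances        # nothing constructed yet
model = LAZY_CONFIG["bone"]["model_factory"]()  # build only the requested model
built_after = TinyClassifier.instances
```

With this layout, `train_model` would need to call `config["model_factory"]()` instead of reading `config["model"]`.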


## 7. Training Functions


In [28]:
def train_model(name, config):
    print(f"\n{'='*60}")
    print(f"Training on: {name.upper()}")
    print(f"{'='*60}")

    # Use simple torchvision transforms for training (with augmentation)
    train_transform = transforms.Compose([
        transforms.Resize(config["input_size"]),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    
    # Use simple transforms for validation and testing (no augmentation)
    eval_transform = transforms.Compose([
        transforms.Resize(config["input_size"]),
        transforms.ToTensor(),
    ])

    # Load datasets
    train_dir = config["train_path"]
    val_dir = config["val_path"]
    test_dir = config.get("test_path", val_dir)  # Use test_path if available, else val

    # Check if directories exist
    if not os.path.exists(train_dir):
        print(f"[ERROR] Training directory not found: {train_dir}")
        return
    
    if not os.path.exists(val_dir):
        print(f"[WARNING] Validation directory not found: {val_dir}, using train directory")
        val_dir = train_dir
    
    if not os.path.exists(test_dir):
        print(f"[WARNING] Test directory not found: {test_dir}, using validation directory")
        test_dir = val_dir

    train_dataset = datasets.ImageFolder(train_dir, transform=train_transform)
    val_dataset = datasets.ImageFolder(val_dir, transform=eval_transform)
    test_dataset = datasets.ImageFolder(test_dir, transform=eval_transform)
    
    print(f"\nDataset Information:")
    print(f"   Training samples: {len(train_dataset)}")
    print(f"   Validation samples: {len(val_dataset)}")
    print(f"   Testing samples: {len(test_dataset)}")
    print(f"   Classes: {train_dataset.classes}")
    print(f"   Number of classes: {len(train_dataset.classes)}")

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True,
                              num_workers=NUM_WORKERS,
                              pin_memory=True if torch.cuda.is_available() else False)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False,
                            num_workers=NUM_WORKERS,
                            pin_memory=True if torch.cuda.is_available() else False)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False,
                             num_workers=NUM_WORKERS,
                             pin_memory=True if torch.cuda.is_available() else False)

    # Model setup
    model = config["model"].to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Training loop
    num_epochs = 5
    print(f"\nStarting training for {num_epochs} epochs...")
    print(f"Device: {DEVICE}")
    print(f"Batch size: 32\n")
    
    for epoch in range(1, num_epochs + 1):
        # Training phase
        model.train()
        running_loss, correct, total = 0.0, 0, 0
        num_batches = 0

        for batch_idx, (imgs, labels) in enumerate(train_loader):
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)

            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            num_batches += 1
            
            # Print progress every 50 batches
            if (batch_idx + 1) % 50 == 0:
                current_acc = 100. * correct / total
                print(f"  Batch {batch_idx + 1}/{len(train_loader)}: Loss={running_loss/num_batches:.4f}, Acc={current_acc:.2f}%")

        train_acc = 100. * correct / total
        
        # Validation phase
        model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        with torch.no_grad():
            for imgs, labels in val_loader:
                imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
                outputs = model(imgs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, preds = torch.max(outputs, 1)
                val_correct += (preds == labels).sum().item()
                val_total += labels.size(0)
        
        val_acc = 100. * val_correct / val_total if val_total > 0 else 0.0
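        # Report per-batch average losses (the raw sums grow with dataset size)
        avg_train_loss = running_loss / max(num_batches, 1)
        avg_val_loss = val_loss / max(len(val_loader), 1)
        print(f"Epoch {epoch}: Train Loss={avg_train_loss:.4f} | Train Acc={train_acc:.2f}% | Val Loss={avg_val_loss:.4f} | Val Acc={val_acc:.2f}%")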

    # Final test evaluation
    print(f"\n{'='*60}")
    print(f"Final Test Evaluation")
    print(f"{'='*60}")
    model.eval()
    test_loss, test_correct, test_total = 0.0, 0, 0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            test_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            test_correct += (preds == labels).sum().item()
            test_total += labels.size(0)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    test_acc = 100. * test_correct / test_total if test_total > 0 else 0.0
    print(f"Test Results:")
    print(f"   Test Loss: {test_loss / max(len(test_loader), 1):.4f}")
    print(f"   Test Accuracy: {test_acc:.2f}%")
    print(f"   Correct: {test_correct}/{test_total}")
    
    # Calculate per-class accuracy
    class_correct = defaultdict(int)
    class_total = defaultdict(int)
    for pred, label in zip(all_preds, all_labels):
        class_total[label] += 1
        if pred == label:
            class_correct[label] += 1
    
    print(f"\nPer-Class Accuracy:")
    for i, class_name in enumerate(test_dataset.classes):
        if class_total[i] > 0:
            acc = 100. * class_correct[i] / class_total[i]
            print(f"   {class_name}: {acc:.2f}% ({class_correct[i]}/{class_total[i]})")

    # Save model weights into central 'trained' folder
    os.makedirs("trained", exist_ok=True)
    save_path = os.path.join("trained", f"{name.lower()}_model.pth")
    torch.save(model.state_dict(), save_path)
    print(f"\n[SUCCESS] Model saved: {save_path}")
    print(f"{'='*60}\n")


## 9. Vesuvius Challenge: Volume Output Generation


In [29]:
def predict_volume(model, volume_path: str, output_path: str, 
                   slice_window: int = 3, device: str = "cuda",
                   batch_size: int = 8) -> np.ndarray:
    """
    Generate predictions for an entire 3D volume and save as .tif.
    
    Args:
        model: Trained segmentation model
        volume_path: Path to input .tif volume
        output_path: Path to save output .tif mask
        slice_window: Number of slices for 2.5D context
        device: Device to run inference on
        batch_size: Batch size for inference (currently unused; slices are processed one at a time)
    
    Returns:
        Predicted volume mask as numpy array
    """
    if not TIF_FILE_AVAILABLE:
        raise ImportError("tifffile required for volume I/O")
    
    model.eval()
    model = model.to(device)
    
    # Load volume
    print(f"Loading volume: {volume_path}")
    volume = tifffile.imread(volume_path)
    print(f"Volume shape: {volume.shape}")
    
    num_slices, height, width = volume.shape
    predictions = np.zeros((num_slices, height, width), dtype=np.float32)
    
    # Process one slice at a time, each with its neighboring slices as context
    half_window = slice_window // 2
    
    with torch.no_grad():
        for slice_idx in range(num_slices):
            # Get slice range with context
            start_idx = max(0, slice_idx - half_window)
            end_idx = min(num_slices, slice_idx + half_window + 1)
            
            # Extract slices
            slices = volume[start_idx:end_idx].astype(np.float32) / 255.0
            
            # Pad if needed
            if len(slices) < slice_window:
                padding = slice_window - len(slices)
                if start_idx == 0:
                    slices = np.concatenate([slices[:1].repeat(padding, axis=0), slices], axis=0)
                else:
                    slices = np.concatenate([slices, slices[-1:].repeat(padding, axis=0)], axis=0)
            
            # Take center slice if we have more than needed
            if len(slices) > slice_window:
                center = len(slices) // 2
                start = center - half_window
                slices = slices[start:start + slice_window]
            
            # Convert to tensor: [C, H, W]
            if len(slices.shape) == 2:
                slice_tensor = torch.from_numpy(slices).float().unsqueeze(0)
            else:
                slice_tensor = torch.from_numpy(slices).float()
                if slice_tensor.shape[0] != slice_window:
                    # Reshape if needed
                    slice_tensor = slice_tensor.transpose(1, 2).transpose(0, 1)
            
            slice_tensor = slice_tensor.unsqueeze(0).to(device)  # [1, C, H, W]
            
            # Predict
            logits = model(slice_tensor)
            pred = torch.sigmoid(logits).cpu().numpy()[0, 0]  # [H, W]
            
            predictions[slice_idx] = pred
            
            if (slice_idx + 1) % 10 == 0:
                print(f"Processed {slice_idx + 1}/{num_slices} slices")
    
    # Binarize predictions
    predictions_binary = (predictions > 0.5).astype(np.uint8) * 255
    
    # Save as .tif
    print(f"Saving predictions to: {output_path}")
    tifffile.imwrite(output_path, predictions_binary)
    
    return predictions_binary


def generate_submission(model, test_volume_dir: str, output_dir: str,
                       slice_window: int = 3, device: str = "cuda"):
    """
    Generate submission files for all test volumes.
    
    Args:
        model: Trained segmentation model
        test_volume_dir: Directory containing test .tif volumes
        output_dir: Directory to save submission .tif files
        slice_window: Number of slices for 2.5D context
        device: Device to run inference on
    """
    os.makedirs(output_dir, exist_ok=True)
    
    volume_dir = Path(test_volume_dir)
    volume_files = sorted(volume_dir.glob("*.tif"))
    if not volume_files:
        volume_files = sorted(volume_dir.glob("*.tiff"))
    
    print(f"Found {len(volume_files)} test volumes")
    
    for vol_path in volume_files:
        image_id = vol_path.stem
        output_path = Path(output_dir) / f"{image_id}.tif"
        
        print(f"\nProcessing {image_id}...")
        try:
            predict_volume(model, str(vol_path), str(output_path), 
                          slice_window=slice_window, device=device)
            print(f"✅ Saved: {output_path}")
        except Exception as e:
            print(f"❌ Error processing {image_id}: {e}")
    
    print(f"\n✅ Submission files saved to: {output_dir}")
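
The edge handling in `predict_volume` (repeat the first or last slice when the context window runs past the volume) can also be expressed as clamped indices. The sketch below uses a hypothetical helper `context_indices`, not part of the notebook. It should select the same slices as the padding logic above when the window is odd and the volume has at least as many slices as the window.

```python
def context_indices(center, num_slices, window=3):
    """Indices of `window` slices centered on `center`, clamped to the volume bounds."""
    if window < 1 or window % 2 == 0:
        raise ValueError("window must be a positive odd number")
    if not 0 <= center < num_slices:
        raise IndexError("center is outside the volume")
    half = window // 2
    # Out-of-range neighbors collapse onto the nearest valid slice
    return [min(max(i, 0), num_slices - 1)
            for i in range(center - half, center + half + 1)]


first = context_indices(0, 10)             # [0, 0, 1]
last = context_indices(9, 10)              # [8, 9, 9]
middle = context_indices(5, 10, window=5)  # [3, 4, 5, 6, 7]
```

Indexing the volume with such a list (for example `volume[context_indices(i, len(volume))]` on a NumPy array) yields a `[window, H, W]` stack directly, without a separate padding step.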


## 10. Vesuvius Challenge Training Function


In [30]:
def train_vesuvius_challenge(
    volume_dir: str,
    mask_dir: str,
    output_dir: str = "trained",
    slice_window: int = 3,
    epochs: int = 20,
    batch_size: int = 4,
    learning_rate: float = 1e-4,
    target_size: Optional[Tuple[int, int]] = None
):
    """
    Train model for Vesuvius Challenge 3D volume segmentation.
    
    Args:
        volume_dir: Directory containing training .tif volumes
        mask_dir: Directory containing training .tif masks
        output_dir: Directory to save trained model
        slice_window: Number of slices for 2.5D context (3, 5, or 7)
        epochs: Number of training epochs
        batch_size: Batch size
        learning_rate: Learning rate
        target_size: Optional (H, W) to resize to, None keeps original size
    """
    if not TIF_FILE_AVAILABLE:
        raise ImportError("tifffile required. Install with: pip install tifffile")
    
    print(f"\n{'='*60}")
    print("Vesuvius Challenge Training")
    print(f"{'='*60}")
    print(f"Volume directory: {volume_dir}")
    print(f"Mask directory: {mask_dir}")
    print(f"Slice window: {slice_window}")
    print(f"Target size: {target_size}")
    print(f"Epochs: {epochs}")
    print(f"Batch size: {batch_size}")
    
    # Create dataset
    dataset = VesuviusVolumeDataset(
        volume_dir=volume_dir,
        mask_dir=mask_dir,
        mode="train",
        slice_window=slice_window,
        target_size=target_size,
        augment=True
    )
    
    if len(dataset) == 0:
        print("[ERROR] No samples found in dataset")
        return
    
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True if torch.cuda.is_available() else False
    )
    
    # Create model
    model = UNet2_5D(in_channels=slice_window, out_channels=1).to(DEVICE)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
    
    print(f"\nModel: UNet2.5D (in_channels={slice_window})")
    print(f"Dataset size: {len(dataset)} samples")
    print(f"Device: {DEVICE}")
    
    os.makedirs(output_dir, exist_ok=True)
    best_score = 0.0
    
    for epoch in range(1, epochs + 1):
        # Training phase
        model.train()
        epoch_loss = 0.0
        all_preds = []
        all_targets = []
        
        for batch_idx, batch in enumerate(loader):
            images = batch['image'].to(DEVICE)  # [B, C, H, W]
            masks = batch['mask'].to(DEVICE)     # [B, 1, H, W]
            
            optimizer.zero_grad()
            logits = model(images)
            
            # Combined loss: BCE + Dice
            bce_loss = criterion(logits, masks)
            dice_loss = dice_loss_logits(logits, masks)
            loss = 0.5 * bce_loss + 0.5 * dice_loss
            
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
            
            # Store predictions for metrics (detach from the autograd graph before converting to NumPy)
            preds = torch.sigmoid(logits).detach().cpu().numpy()
            targets = masks.cpu().numpy()
            all_preds.append(preds)
            all_targets.append(targets)
            
            if (batch_idx + 1) % 10 == 0:
                print(f"  Batch {batch_idx + 1}/{len(loader)}: Loss={loss.item():.4f}")
        
        avg_loss = epoch_loss / len(loader)
        
        # Compute competition metrics
        all_preds = np.concatenate(all_preds, axis=0)
        all_targets = np.concatenate(all_targets, axis=0)
        
        # Average metrics across all samples seen in this epoch
        metrics_list = []
        for i in range(len(all_preds)):
            pred = all_preds[i, 0]  # [H, W]
            target = all_targets[i, 0]  # [H, W]
            metrics = compute_competition_metrics(pred, target)
            metrics_list.append(metrics)
        
        avg_metrics = {
            'surface_dice': np.mean([m['surface_dice'] for m in metrics_list]),
            'topo_score': np.mean([m['topo_score'] for m in metrics_list]),
            'voi_score': np.mean([m['voi_score'] for m in metrics_list]),
            'combined_score': np.mean([m['combined_score'] for m in metrics_list])
        }
        
        scheduler.step(avg_loss)
        
        print(f"\nEpoch {epoch}/{epochs}:")
        print(f"  Loss: {avg_loss:.4f}")
        print(f"  Surface Dice: {avg_metrics['surface_dice']:.4f}")
        print(f"  TopoScore: {avg_metrics['topo_score']:.4f}")
        print(f"  VOI Score: {avg_metrics['voi_score']:.4f}")
        print(f"  Combined Score: {avg_metrics['combined_score']:.4f}")
        
        # Save best model
        if avg_metrics['combined_score'] > best_score:
            best_score = avg_metrics['combined_score']
            model_path = os.path.join(output_dir, "vesuvius_unet2.5d_best.pth")
            torch.save(model.state_dict(), model_path)
            print(f"  [BEST] Saved model to {model_path}")
    
    print(f"\n{'='*60}")
    print(f"Training finished. Best combined score: {best_score:.4f}")
    print(f"{'='*60}\n")
    
    return model
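
The training loop mixes BCE with `dice_loss_logits`, which is defined in an earlier cell. For readers unfamiliar with the idea, here is a plain-Python sketch of a soft Dice loss on logits, assuming the common formulation `1 - (2*sum(p*t) + eps) / (sum(p) + sum(t) + eps)`. The notebook's own implementation may differ in smoothing or reduction, and the names below are illustrative.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def soft_dice_loss_from_logits(logits, targets, eps=1.0):
    """Soft Dice loss over flat sequences of logits and 0/1 targets."""
    probs = [sigmoid(z) for z in logits]
    intersection = sum(p * t for p, t in zip(probs, targets))
    denom = sum(probs) + sum(targets)
    return 1.0 - (2.0 * intersection + eps) / (denom + eps)


targets = [1, 1, 0, 0]
good = soft_dice_loss_from_logits([10, 10, -10, -10], targets)  # close to 0
bad = soft_dice_loss_from_logits([-10, -10, 10, 10], targets)   # close to 0.8
```

Unlike BCE, this loss depends on the overlap between prediction and mask rather than on per-pixel agreement, which is one reason the two are often combined for sparse foreground masks.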


## 11. Run Vesuvius Challenge Training

Configure paths and run training:


In [31]:
# ============================================================================
# Vesuvius Challenge Configuration
# ============================================================================
# Update these paths to match your data directory structure
# All paths are relative to DATA_ROOT (set in cell 2)

VESUVIUS_VOLUME_DIR = os.path.join(DATA_ROOT, "vesuvius/train/volumes")      # Directory with .tif volumes
VESUVIUS_MASK_DIR = os.path.join(DATA_ROOT, "vesuvius/train/masks")         # Directory with .tif masks
VESUVIUS_SLICE_WINDOW = 3                                # 2.5D: use 3 slices (center + 1 on each side)
                                                          # Options: 3, 5, or 7 (more slices = more context but slower)
VESUVIUS_EPOCHS = 20
VESUVIUS_BATCH_SIZE = 4                                  # Reduce if OOM errors occur
VESUVIUS_TARGET_SIZE = None                              # None = keep original size, or (256, 256) to resize
VESUVIUS_LEARNING_RATE = 1e-4
VESUVIUS_OUTPUT_DIR = "trained"                          # Directory to save trained models

# Check if dependencies are available
if not TIF_FILE_AVAILABLE:
    print("❌ ERROR: tifffile is required for Vesuvius Challenge training!")
    print("   Install with: pip install tifffile")
elif not SCIPY_AVAILABLE:
    print("⚠️  WARNING: scipy is recommended for resizing and metrics")
    print("   Install with: pip install scipy")
elif not SKIMAGE_AVAILABLE:
    print("⚠️  WARNING: scikit-image is recommended for topology metrics")
    print("   Install with: pip install scikit-image")
else:
    print("✅ All dependencies available for Vesuvius Challenge training")
    print(f"\nConfiguration:")
    print(f"  Volume directory: {VESUVIUS_VOLUME_DIR}")
    print(f"  Mask directory: {VESUVIUS_MASK_DIR}")
    print(f"  Slice window: {VESUVIUS_SLICE_WINDOW}")
    print(f"  Epochs: {VESUVIUS_EPOCHS}")
    print(f"  Batch size: {VESUVIUS_BATCH_SIZE}")
    print(f"  Target size: {VESUVIUS_TARGET_SIZE}")
    print(f"\nTo start training, uncomment the code below:\n")

# Uncomment to run Vesuvius Challenge training:
# try:
#     model = train_vesuvius_challenge(
#         volume_dir=VESUVIUS_VOLUME_DIR,
#         mask_dir=VESUVIUS_MASK_DIR,
#         output_dir=VESUVIUS_OUTPUT_DIR,
#         slice_window=VESUVIUS_SLICE_WINDOW,
#         epochs=VESUVIUS_EPOCHS,
#         batch_size=VESUVIUS_BATCH_SIZE,
#         learning_rate=VESUVIUS_LEARNING_RATE,
#         target_size=VESUVIUS_TARGET_SIZE
#     )
#     print("\n✅ Training completed successfully!")
# except FileNotFoundError as e:
#     print(f"\n❌ Error: Data directory not found: {e}")
#     print("   Please update VESUVIUS_VOLUME_DIR and VESUVIUS_MASK_DIR above")
# except ImportError as e:
#     print(f"\n❌ Error: Missing dependency: {e}")
#     print("   Install required packages: pip install tifffile scipy scikit-image")
# except Exception as e:
#     print(f"\n❌ Error during training: {e}")
#     import traceback
#     traceback.print_exc()


⚠️  WARNING: scikit-image is recommended for topology metrics
   Install with: pip install scikit-image


## 12. Generate Submission Files

After training, use this to generate .tif submission files:


In [32]:
# ============================================================================
# Generate Submission Files for Vesuvius Challenge
# ============================================================================
# After training, use this cell to generate .tif submission files

# Configuration
# Paths are relative to current directory (for outputs) or DATA_ROOT (for inputs)
MODEL_PATH = "trained/vesuvius_unet2.5d_best.pth"        # Path to trained model (relative to notebook)
TEST_VOLUME_DIR = os.path.join(DATA_ROOT, "vesuvius/test/volumes")           # Directory with test .tif volumes
SUBMISSION_DIR = "submission"                            # Directory to save submission files (relative to notebook)

# Check dependencies
if not TIF_FILE_AVAILABLE:
    print("❌ ERROR: tifffile is required for submission generation!")
    print("   Install with: pip install tifffile")
elif not os.path.exists(MODEL_PATH):
    print(f"❌ ERROR: Model file not found: {MODEL_PATH}")
    print("   Please train a model first or update MODEL_PATH above")
elif not os.path.exists(TEST_VOLUME_DIR):
    print(f"⚠️  WARNING: Test volume directory not found: {TEST_VOLUME_DIR}")
    print("   Please update TEST_VOLUME_DIR above or create the directory")
else:
    print("✅ Ready to generate submission files")
    print(f"\nConfiguration:")
    print(f"  Model: {MODEL_PATH}")
    print(f"  Test volumes: {TEST_VOLUME_DIR}")
    print(f"  Output directory: {SUBMISSION_DIR}")
    print(f"  Slice window: {VESUVIUS_SLICE_WINDOW}")
    print(f"\nTo generate submission, uncomment the code below:\n")

# Uncomment to generate submission files:
# try:
#     # Load model
#     print(f"Loading model from {MODEL_PATH}...")
#     model = UNet2_5D(in_channels=VESUVIUS_SLICE_WINDOW, out_channels=1)
#     model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
#     model = model.to(DEVICE)
#     model.eval()
#     print("✅ Model loaded successfully")
#     
#     # Generate submission
#     print(f"\nGenerating submission files...")
#     generate_submission(
#         model=model,
#         test_volume_dir=TEST_VOLUME_DIR,
#         output_dir=SUBMISSION_DIR,
#         slice_window=VESUVIUS_SLICE_WINDOW,
#         device=str(DEVICE)
#     )
#     
#     print(f"\n✅ Submission files saved to: {SUBMISSION_DIR}")
#     print("   Zip the directory and submit to Kaggle!")
#     
# except FileNotFoundError as e:
#     print(f"\n❌ Error: File not found: {e}")
#     print("   Please check MODEL_PATH and TEST_VOLUME_DIR")
# except RuntimeError as e:
#     print(f"\n❌ Error: Model loading failed: {e}")
#     print("   Make sure MODEL_PATH points to a valid trained model")
#     print("   Check that VESUVIUS_SLICE_WINDOW matches training configuration")
# except Exception as e:
#     print(f"\n❌ Error during submission generation: {e}")
#     import traceback
#     traceback.print_exc()


❌ ERROR: Model file not found: trained/vesuvius_unet2.5d_best.pth
   Please train a model first or update MODEL_PATH above


## Quick Start Guide

### For Medical Image Classification:
1. Run cells 1-6 (imports, models, config)
2. Set `TASK = "classification"` and `DOMAIN` in cell 8
3. Run cell 8 to train

### For Vesuvius Challenge:
1. **Install dependencies**: `pip install tifffile scipy scikit-image`
2. Run cells 1-6 (imports, models)
3. Update paths in cell 11 (Vesuvius Challenge Configuration)
4. Uncomment and run training code in cell 11
5. After training, update paths in cell 12 (Submission Generation)
6. Uncomment and run submission code in cell 12
7. Zip the submission directory and upload to Kaggle
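
The zipping step can be done from Python with the standard library. This is a minimal sketch using a hypothetical helper `zip_submission` and a temporary directory in place of the real `submission` folder. The notebook does not state the archive layout the competition expects, so treat the flat layout here as an assumption.

```python
import os
import shutil
import tempfile
import zipfile


def zip_submission(submission_dir, archive_base):
    """Zip everything under submission_dir into archive_base + '.zip'; return the archive path."""
    if not os.path.isdir(submission_dir):
        raise FileNotFoundError(f"Submission directory not found: {submission_dir}")
    # root_dir makes archive member names relative to the submission directory
    return shutil.make_archive(archive_base, "zip", root_dir=submission_dir)


# Demo on a throwaway directory with one placeholder file
with tempfile.TemporaryDirectory() as tmp:
    sub_dir = os.path.join(tmp, "submission")
    os.makedirs(sub_dir)
    with open(os.path.join(sub_dir, "example_volume.tif"), "wb") as fh:
        fh.write(b"placeholder")  # not a real TIFF; stands in for a predicted mask

    archive_path = zip_submission(sub_dir, os.path.join(tmp, "submission"))
    with zipfile.ZipFile(archive_path) as zf:
        archived_names = zf.namelist()
```

In the notebook this would be called as `zip_submission(SUBMISSION_DIR, "submission")` after the submission cell has run.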

### Troubleshooting:
- **Missing dependencies**: Check cell 2 for dependency status
- **OOM errors**: Reduce `VESUVIUS_BATCH_SIZE`, or set `VESUVIUS_TARGET_SIZE` to a smaller size such as `(256, 256)`
- **File not found**: Update directory paths in configuration cells
- **Model loading errors**: Ensure `VESUVIUS_SLICE_WINDOW` matches training config

---

## References

**Vesuvius Challenge Competition:**
- [Kaggle Competition Page](https://kaggle.com/competitions/vesuvius-challenge-surface-detection)
- [Official Website](https://scrollprize.org)
- [Data Portal](https://dl.ash2txt.org)

**Citation:**
```bibtex
@misc{vesuvius-challenge-surface-detection,
    author = {Sean Johnson and David Josey and Elian Rafael Dal Prà and Hendrik Schilling and Youssef Nader and Johannes Rudolph and Forrest McDonald and Paul Henderson and Giorgio Angelotti and Sohier Dane and María Cruz},
    title = {Vesuvius Challenge - Surface Detection},
    year = {2025},
    howpublished = {\url{https://kaggle.com/competitions/vesuvius-challenge-surface-detection}},
    note = {Kaggle}
}
```


In [33]:
def train_segmentation(images_dir, masks_dir, output_path,
                       input_size=(256, 256), epochs=20, batch_size=4):
    print(f"\n{'='*60}")
    print("Training segmentation UNet")
    print(f"{'='*60}")
    print(f"Images dir: {images_dir}")
    print(f"Masks dir:  {masks_dir}")

    if not os.path.isdir(images_dir):
        print(f"[ERROR] Images directory not found: {images_dir}")
        return
    if not os.path.isdir(masks_dir):
        print(f"[ERROR] Masks directory not found: {masks_dir}")
        return

    dataset = SegmentationDataset(images_dir, masks_dir, input_size=input_size)
    if len(dataset) == 0:
        print("[ERROR] No segmentation samples found.")
        return

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        num_workers=NUM_WORKERS)

    model = UNet(in_channels=1, out_channels=1).to(DEVICE)
    bce = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    print(f"Dataset size: {len(dataset)} samples")
    print(f"Device: {DEVICE}")

    best_loss = float("inf")
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

    for epoch in range(1, epochs + 1):
        model.train()
        epoch_loss = 0.0
        for imgs, masks in loader:
            imgs = imgs.to(DEVICE)
            masks = masks.to(DEVICE)

            optimizer.zero_grad()
            logits = model(imgs)
            loss = 0.5 * bce(logits, masks) + 0.5 * dice_loss_logits(logits, masks)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(loader)
        print(f"Epoch {epoch}/{epochs}: Segmentation loss={avg_loss:.4f}")

        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), output_path)
            print(f"  [BEST] Saved UNet to {output_path}")

    print(f"\nSegmentation training finished. Best loss: {best_loss:.4f}")
    print(f"{'='*60}\n")


## 8. Run Training

Configure the task and domain below, then run this cell:


In [34]:
# Configuration: choose task and domain
# Options:
#   TASK: "classification" or "segmentation"
#   DOMAIN (for classification): one of list(CONFIG.keys()) or "all"

TASK = "classification"   # @param ["classification", "segmentation"]
DOMAIN = "all"            # @param ["all", "brain", "lungs", "skin", "breast", "bone"]

# Segmentation paths (used only if TASK == "segmentation")
SEG_IMAGES = os.path.join(DATA_ROOT, "brain_seg/train/images")
SEG_MASKS = os.path.join(DATA_ROOT, "brain_seg/train/masks")
SEG_OUTPUT = "trained/unet_brain_tumor.pth"

print(f"Task: {TASK}")
if TASK == "classification":
    print(f"Domain: {DOMAIN}")
    if DOMAIN == "all":
        for name, cfg in CONFIG.items():
            train_model(name, cfg)
    else:
        if DOMAIN in CONFIG:
            train_model(DOMAIN, CONFIG[DOMAIN])
        else:
            print(f"[ERROR] Unknown domain: {DOMAIN}")
            print(f"Available domains: {', '.join(CONFIG.keys())}")
else:
    train_segmentation(
        images_dir=SEG_IMAGES,
        masks_dir=SEG_MASKS,
        output_path=SEG_OUTPUT,
        input_size=(256, 256),
        epochs=20,
        batch_size=4,
    )


Task: classification
Domain: all

Training on: BRAIN
[ERROR] Training directory not found: data\brain/brain/Training

Training on: LUNGS
[ERROR] Training directory not found: data\lungs/train

Training on: SKIN
[ERROR] Training directory not found: data\skin/train

Training on: BREAST
[ERROR] Training directory not found: data\breast

Training on: BONE
[ERROR] Training directory not found: data\Bone/train
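
In the run above, every domain stops at its missing training directory. A pre-flight check can list all missing directories at once before any training starts. The sketch below uses a hypothetical helper `missing_dataset_dirs` and a temporary directory standing in for `DATA_ROOT`; it assumes each config entry stores its paths under `train_path`, `val_path`, and `test_path`, as `CONFIG` does.

```python
import os
import tempfile


def missing_dataset_dirs(config, keys=("train_path", "val_path", "test_path")):
    """Return {domain: [missing paths]} for domains with at least one missing directory."""
    missing = {}
    for domain, cfg in config.items():
        absent = [cfg[k] for k in keys if k in cfg and not os.path.isdir(cfg[k])]
        if absent:
            missing[domain] = absent
    return missing


# Demo: only lungs/train exists, so lungs/val is reported as missing
with tempfile.TemporaryDirectory() as root:
    os.makedirs(os.path.join(root, "lungs", "train"))
    demo_config = {
        "lungs": {
            "train_path": os.path.join(root, "lungs", "train"),
            "val_path": os.path.join(root, "lungs", "val"),
        },
    }
    report = missing_dataset_dirs(demo_config)

for domain, paths in report.items():
    print(f"[WARNING] {domain}: missing {len(paths)} director{'y' if len(paths) == 1 else 'ies'}")
```

Running `missing_dataset_dirs(CONFIG)` before the training cell would show which values of `DATA_ROOT` or which dataset folders still need attention.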
