# Vesuvius Challenge Surface Detection - Final Notebook

This notebook contains all code needed for training and generating a Kaggle submission.

In [None]:
# Install required packages (run this cell first on Kaggle)
!pip install -q monai tifffile imagecodecs scikit-learn

## 1. Setup & Configuration

In [None]:
# ===== PATHS =====
# The active values are the Kaggle locations. For a local run, swap in the
# commented-out pair further down.
OUTPUT_DIR = '/kaggle/working'
DATA_DIR = '/kaggle/input/vesuvius-challenge-surface-detection'

# For local testing:
# DATA_DIR = '../dataset/raw'
# OUTPUT_DIR = '.'

# Everything below is derived from DATA_DIR.
TRAIN_IMAGES = f'{DATA_DIR}/train_images'
TRAIN_LABELS = f'{DATA_DIR}/train_labels'
TEST_IMAGES = f'{DATA_DIR}/test_images'
TRAIN_CSV = f'{DATA_DIR}/train.csv'
TEST_CSV = f'{DATA_DIR}/test.csv'

# ===== HYPERPARAMETERS =====
# Model / optimisation
MODEL_NAME = 'segresnet'
LEARNING_RATE = 1e-4
EPOCHS = 50
BATCH_SIZE = 1

# Data handling
TARGET_SIZE = 128   # edge length passed to CenterCropOrPad
VAL_RATIO = 0.2     # fraction of volumes held out for validation
NUM_WORKERS = 0
IGNORE_LABEL = 2    # label value of unlabeled pixels (excluded from loss/metrics)

# Reproducibility
SEED = 42

In [None]:
import os
import json
import random
from pathlib import Path
from typing import Optional, Tuple, List, Dict, Union, Callable

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tifffile
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast
from scipy import ndimage

# Reproducibility helper
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed applied to every generator.

    When CUDA is available, all CUDA devices are seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

# Seed every RNG first, then pick the GPU when PyTorch can see one.
set_seed(SEED)
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device: {DEVICE}')

## 2. Data Exploration (EDA)

In [None]:
# Load CSVs.
# The `id` column is parsed as text at read time: ids are later used to build
# file names (`{id}.tif`), and parsing them as numbers first and converting
# back to str would silently drop leading zeros or otherwise reformat them.
train_df = pd.read_csv(TRAIN_CSV, dtype={'id': str})
test_df = pd.read_csv(TEST_CSV, dtype={'id': str})

print(f'Train samples: {len(train_df)}')
print(f'Test samples: {len(test_df)}')
print(f'\nTrain columns: {list(train_df.columns)}')
print(f'\nScroll ID distribution:')
print(train_df['scroll_id'].value_counts().sort_index())

In [None]:
# Keep only the training ids for which both the image and the label TIFF exist.
train_images_dir = Path(TRAIN_IMAGES)
train_labels_dir = Path(TRAIN_LABELS)

valid_ids = [
    vol_id
    for vol_id in train_df['id']
    if (train_images_dir / f'{vol_id}.tif').exists()
    and (train_labels_dir / f'{vol_id}.tif').exists()
]

print(f'Valid samples (with both image and label): {len(valid_ids)}')
print(f'Missing: {len(train_df) - len(valid_ids)}')

## 3. Data Visualization

In [None]:
# Visualize a sample volume (skipped when no complete image/label pair was found)
if valid_ids:
    sample_id = valid_ids[0]
    img = tifffile.imread(train_images_dir / f'{sample_id}.tif')
    lbl = tifffile.imread(train_labels_dir / f'{sample_id}.tif')

    print(f'Sample ID: {sample_id}')
    print(f'Image shape: {img.shape}, dtype: {img.dtype}')
    print(f'Label shape: {lbl.shape}, unique values: {np.unique(lbl)}')

    # Middle slice along axis 0, shown as: raw image, label map, simple overlay
    mid_z = img.shape[0] // 2
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    panels = (
        (img[mid_z], dict(cmap='gray'), 'Image (Z-slice)'),
        (lbl[mid_z], dict(cmap='tab10', vmin=0, vmax=2), 'Label (Z-slice)'),
        (img[mid_z] * 0.5 + lbl[mid_z] * 50, dict(cmap='gray'), 'Overlay'),
    )
    for ax, (data, imshow_kwargs, title) in zip(axes, panels):
        ax.imshow(data, **imshow_kwargs)
        ax.set_title(title)
    plt.tight_layout()
    plt.show()

## 4. Transforms & Dataset

In [None]:
class CenterCropOrPad:
    """Center-crop or constant-pad the first three axes to a cubic target size.

    Axes longer than the target are cropped symmetrically around the center;
    shorter axes are padded on both sides (the extra element, if any, goes
    after the data).

    Args:
        target_size: Edge length of the output cube.
        image_pad_value: Constant used to pad the image (default 0).
        label_pad_value: Constant used to pad the label (default 2, the value
            this notebook uses for unlabeled voxels).
    """

    def __init__(self, target_size: int = 128, image_pad_value=0, label_pad_value=2):
        self.target_size = (target_size, target_size, target_size)
        self.image_pad_value = image_pad_value
        self.label_pad_value = label_pad_value

    @staticmethod
    def _pad_axis(array, axis, before, after, value):
        """Pad a single axis of `array` with a constant value."""
        pad_width = [(0, 0)] * array.ndim
        pad_width[axis] = (before, after)
        return np.pad(array, pad_width, mode='constant', constant_values=value)

    @staticmethod
    def _crop_axis(array, axis, start, length):
        """Take `length` elements starting at `start` along a single axis."""
        index = [slice(None)] * array.ndim
        index[axis] = slice(start, start + length)
        return array[tuple(index)]

    def __call__(self, image, label=None):
        """Apply the crop/pad to `image` and, when given, identically to `label`.

        Returns:
            `(image, label)` when a label is passed, otherwise just `image`.
        """
        for dim, target in enumerate(self.target_size):
            current = image.shape[dim]
            if current > target:
                start = (current - target) // 2
                image = self._crop_axis(image, dim, start, target)
                if label is not None:
                    label = self._crop_axis(label, dim, start, target)
            elif current < target:
                pad_before = (target - current) // 2
                pad_after = target - current - pad_before
                image = self._pad_axis(image, dim, pad_before, pad_after, self.image_pad_value)
                if label is not None:
                    label = self._pad_axis(label, dim, pad_before, pad_after, self.label_pad_value)
        return (image, label) if label is not None else image

class Normalize:
    """Shift and scale intensities: `(image - mean) / std`, computed in float32.

    With the defaults (127.5, 127.5) the value 0 maps to -1 and 255 maps to 1.
    The label, if given, is passed through untouched.
    """

    def __init__(self, mean=127.5, std=127.5):
        self.mean = mean
        self.std = std

    def __call__(self, image, label=None):
        """Return the normalized image, or `(image, label)` when a label is given."""
        scaled = (image.astype(np.float32) - self.mean) / self.std
        if label is None:
            return scaled
        return scaled, label

class ZJitter:
    """Randomly shift a volume along axis 0 by up to `jitter_range` slices.

    The slices vacated by the shift are filled with 0 in the image and with 2
    in the label. A shift of 0 returns the inputs unchanged.
    """

    def __init__(self, jitter_range=5):
        self.jitter_range = jitter_range

    def __call__(self, image, label=None):
        """Return the shifted image, or `(image, label)` when a label is given."""
        shift = np.random.randint(-self.jitter_range, self.jitter_range + 1)
        if shift != 0:
            # Slices that wrapped around during the roll and must be blanked out.
            vacated = slice(None, shift) if shift > 0 else slice(shift, None)
            image = np.roll(image, shift, axis=0)  # np.roll returns a new array
            image[vacated] = 0
            if label is not None:
                label = np.roll(label, shift, axis=0)
                label[vacated] = 2
        if label is None:
            return image
        return image, label

class BasicAugs:
    """Random flips and 90-degree rotations in the plane of axes 1 and 2.

    Each of the three augmentations (flip along axis 2, flip along axis 1,
    rotation by 1-3 quarter turns) is applied independently with probability
    0.5, and identically to the image and the label.
    """

    def __call__(self, image, label=None):
        """Return the augmented image, or `(image, label)` when a label is given."""
        has_label = label is not None

        # Horizontal flip (axis 2) first, then vertical flip (axis 1).
        for axis in (2, 1):
            if np.random.random() < 0.5:
                image = np.flip(image, axis=axis).copy()
                if has_label:
                    label = np.flip(label, axis=axis).copy()

        # Rotation by a random number of quarter turns.
        if np.random.random() < 0.5:
            quarter_turns = np.random.randint(1, 4)
            image = np.rot90(image, k=quarter_turns, axes=(1, 2)).copy()
            if has_label:
                label = np.rot90(label, k=quarter_turns, axes=(1, 2)).copy()

        if has_label:
            return image, label
        return image

In [None]:
class VesuviusDataset(Dataset):
    """Dataset of TIFF volumes, optionally paired with label volumes.

    Each item is a dict with keys `'image'` (float tensor with a leading
    channel axis), `'label'` (long tensor) and `'id'`. When no label is
    available, `'label'` is an all-zero tensor shaped like the image volume.

    Args:
        vol_ids: Sequence of volume ids; files are looked up as `{id}.tif`.
        images_dir: Directory holding the image TIFFs.
        labels_dir: Directory holding the label TIFFs, or None for unlabeled data.
        transforms: List of callables invoked as `t(image, label)`.
        is_train: Stored on the instance; not otherwise used by this class.
    """

    def __init__(self, vol_ids, images_dir, labels_dir=None, transforms=None, is_train=True):
        self.vol_ids = vol_ids
        self.images_dir = Path(images_dir)
        self.labels_dir = Path(labels_dir) if labels_dir else None
        self.transforms = transforms or []
        self.is_train = is_train

    def __len__(self):
        return len(self.vol_ids)

    def _read_label(self, vol_id):
        """Return the label volume for `vol_id`, or None if there is none on disk."""
        if not self.labels_dir:
            return None
        lbl_path = self.labels_dir / f'{vol_id}.tif'
        if not lbl_path.exists():
            return None
        return tifffile.imread(lbl_path)

    def __getitem__(self, idx):
        vol_id = self.vol_ids[idx]
        image = tifffile.imread(self.images_dir / f'{vol_id}.tif')
        label = self._read_label(vol_id)

        for transform in self.transforms:
            if label is None:
                # Transforms may return either the bare image or a tuple.
                out = transform(image, None)
                image = out[0] if isinstance(out, tuple) else out
            else:
                image, label = transform(image, label)

        image = torch.from_numpy(image).float().unsqueeze(0)  # [1, D, H, W]
        if label is None:
            # Placeholder so every item has the same keys.
            label = torch.zeros(image.shape[1:], dtype=torch.long)
        else:
            label = torch.from_numpy(label.astype(np.int64))

        return {'image': image, 'label': label, 'id': vol_id}

In [None]:
# Create train/val split (random split over the volumes that have both files)
from sklearn.model_selection import train_test_split

train_ids, val_ids = train_test_split(
    valid_ids,
    test_size=VAL_RATIO,
    random_state=SEED,
)
print(f'Train: {len(train_ids)}, Val: {len(val_ids)}')

# Transforms: training adds random Z shifts and flips/rotations between the
# crop and the normalization; validation only crops and normalizes.
train_transforms = [
    CenterCropOrPad(TARGET_SIZE),
    ZJitter(5),
    BasicAugs(),
    Normalize(),
]
val_transforms = [
    CenterCropOrPad(TARGET_SIZE),
    Normalize(),
]

# Datasets
train_ds = VesuviusDataset(train_ids, TRAIN_IMAGES, TRAIN_LABELS, train_transforms, is_train=True)
val_ds = VesuviusDataset(val_ids, TRAIN_IMAGES, TRAIN_LABELS, val_transforms, is_train=False)

# Loaders share everything except shuffling.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

print(f'Batches - Train: {len(train_loader)}, Val: {len(val_loader)}')

## 5. Model & Metrics

In [None]:
def get_model(name='segresnet', in_channels=1, out_channels=2):
    """Build one of the supported 3D MONAI segmentation networks.

    Args:
        name: `'segresnet'` or `'unet'`.
        in_channels: Number of input channels.
        out_channels: Number of output channels (classes).

    Returns:
        The constructed (untrained) network.

    Raises:
        ValueError: If `name` is not one of the supported models.
    """
    if name not in ('segresnet', 'unet'):
        raise ValueError(f'Unknown model: {name}')

    # Arguments common to both architectures.
    shared = dict(spatial_dims=3, in_channels=in_channels, out_channels=out_channels)

    if name == 'segresnet':
        from monai.networks.nets import SegResNet
        return SegResNet(
            init_filters=32,
            blocks_down=(1, 2, 2, 4),
            blocks_up=(1, 1, 1),
            **shared,
        )

    from monai.networks.nets import UNet
    return UNet(
        channels=(32, 64, 128, 256, 512),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        **shared,
    )

# Build the configured network, move it to the selected device, report its size.
model = get_model(MODEL_NAME)
model = model.to(DEVICE)
print(f'Model: {MODEL_NAME}, Parameters: {sum(p.numel() for p in model.parameters()):,}')

In [None]:
# Metrics
def dice_coefficient(pred, gt):
    """Dice overlap between two binary masks.

    Args:
        pred: Predicted mask; any array whose non-zero entries mean foreground.
        gt: Ground-truth mask with the same shape as `pred`.

    Returns:
        `2 * |pred & gt| / (|pred| + |gt|)` as a float in [0, 1]. Two empty
        masks count as a perfect match (1.0), consistent with `surface_dice`,
        so a correct "nothing here" prediction is not scored as a miss.
    """
    # Work on booleans so the result does not depend on the input dtype.
    pred = np.asarray(pred).astype(bool)
    gt = np.asarray(gt).astype(bool)
    total = int(pred.sum()) + int(gt.sum())
    if total == 0:
        return 1.0
    intersection = int(np.logical_and(pred, gt).sum())
    return 2.0 * intersection / total

def surface_dice(pred, gt, tau=2.0):
    """Surface Dice at tolerance `tau` between two binary masks.

    The surface of a mask is the set of its voxels removed by one binary
    erosion. The score is the fraction of surface voxels (of both masks
    together) that lie within `tau` of the other mask's surface.

    Args:
        pred: Predicted mask; non-zero entries mean foreground.
        gt: Ground-truth mask with the same shape as `pred`.
        tau: Distance tolerance, in voxels.

    Returns:
        A float in [0, 1]; 1.0 if both masks are empty, 0.0 if exactly one is.
    """
    from scipy.ndimage import distance_transform_edt, binary_erosion

    # Cast to bool first: on integer masks `~` is a bitwise NOT (0 -> 255,
    # 1 -> 254), which would hand the distance transform an input without any
    # zeros and make every distance meaningless.
    pred = np.asarray(pred).astype(bool)
    gt = np.asarray(gt).astype(bool)

    pred_empty = not pred.any()
    gt_empty = not gt.any()
    if pred_empty and gt_empty:
        return 1.0
    if pred_empty or gt_empty:
        return 0.0

    pred_surface = pred ^ binary_erosion(pred)
    gt_surface = gt ^ binary_erosion(gt)

    # distance_transform_edt measures the distance to the nearest zero, so
    # inverting a surface mask yields the distance to that surface.
    dist_to_gt_surface = distance_transform_edt(~gt_surface)
    dist_to_pred_surface = distance_transform_edt(~pred_surface)

    pred_close = pred_surface & (dist_to_gt_surface <= tau)
    gt_close = gt_surface & (dist_to_pred_surface <= tau)

    total = int(pred_surface.sum()) + int(gt_surface.sum())
    return float(int(pred_close.sum()) + int(gt_close.sum())) / (total + 1e-8)

def compute_metrics(pred, gt, ignore_label=2):
    """Compute evaluation metrics on the labeled part of a volume.

    Args:
        pred: Array of predicted class indices.
        gt: Array of ground-truth labels, same shape as `pred`.
        ignore_label: Label value marking positions excluded from scoring.

    Returns:
        Dict with key `'dice'`: Dice of class 1 over non-ignored positions.
    """
    keep = gt != ignore_label  # positions that carry a real label
    fg_pred = (pred[keep] == 1).astype(np.uint8)
    fg_gt = (gt[keep] == 1).astype(np.uint8)
    return dict(dice=dice_coefficient(fg_pred, fg_gt))

## 6. Training Loop

In [None]:
# Custom loss that handles ignore_label
class MaskedDiceCELoss(nn.Module):
    """Cross-entropy plus foreground soft-Dice, both skipping ignored voxels.

    Args:
        ignore_label: Target value marking voxels that must not contribute to
            either loss term.
    """

    def __init__(self, ignore_label=2):
        super().__init__()
        self.ignore_label = ignore_label
        self.ce = nn.CrossEntropyLoss(ignore_index=ignore_label)

    def dice_loss(self, pred, target):
        """Soft Dice loss for class 1, restricted to non-ignored voxels.

        Args:
            pred: [B, C, D, H, W] raw logits (softmax is applied here).
            target: [B, D, H, W] class indices, possibly containing
                `ignore_label`.

        Returns:
            Scalar tensor `1 - dice`.
        """
        valid = target != self.ignore_label
        # Explicit float32 so the large reductions below cannot overflow in
        # half precision.
        probs = torch.softmax(pred.float(), dim=1)

        # Only compute dice for foreground class (class 1)
        pred_fg = probs[:, 1] * valid          # [B, D, H, W]
        target_fg = (target == 1).float() * valid

        intersection = (pred_fg * target_fg).sum()
        union = pred_fg.sum() + target_fg.sum()
        dice = (2 * intersection + 1e-5) / (union + 1e-5)
        return 1 - dice

    def forward(self, pred, target):
        """Return `cross_entropy + dice_loss` for logits `pred` and `target`.

        If every voxel in `target` is ignored there is nothing to learn from:
        the mean-reduced cross-entropy would be NaN (mean over zero elements),
        so a zero loss that is still connected to `pred` is returned instead.
        """
        if not (target != self.ignore_label).any():
            return pred.float().sum() * 0.0
        ce_loss = self.ce(pred, target)
        dice_loss = self.dice_loss(pred, target)
        return ce_loss + dice_loss

# Loss: cross-entropy + Dice, skipping voxels labeled IGNORE_LABEL.
criterion = MaskedDiceCELoss(ignore_label=IGNORE_LABEL)

# AdamW with a cosine learning-rate schedule spanning the whole run.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=1e-5,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=EPOCHS,
    eta_min=1e-7,
)

# Gradient scaler for mixed-precision training.
scaler = GradScaler()

# Bookkeeping: best validation Dice so far and per-epoch curves.
best_metric = 0.0
history = {key: [] for key in ('train_loss', 'val_loss', 'val_dice')}

In [None]:
# torch.cuda.amp autocast only applies on CUDA; switch it off elsewhere so
# CPU runs go through plain float32.
use_amp = DEVICE.type == 'cuda'
best_model_path = f'{OUTPUT_DIR}/best_model.pth'

for epoch in range(1, EPOCHS + 1):
    # ---- Train ----
    model.train()
    train_loss = 0.0
    for batch in tqdm(train_loader, desc=f'Epoch {epoch} [Train]', leave=False):
        images = batch['image'].to(DEVICE)
        labels = batch['label'].to(DEVICE)  # [B, D, H, W] with values 0, 1, 2

        optimizer.zero_grad()
        with autocast(enabled=use_amp):
            outputs = model(images)  # [B, 2, D, H, W]
            loss = criterion(outputs, labels)
        # Scale the loss before backward, then let the scaler drive the
        # optimizer step and update its scale factor.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        train_loss += loss.item()

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    train_loss /= max(len(train_loader), 1)
    history['train_loss'].append(train_loss)

    # ---- Validate ----
    model.eval()
    val_loss = 0.0
    all_dice = []
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f'Epoch {epoch} [Val]', leave=False):
            images = batch['image'].to(DEVICE)
            labels = batch['label'].to(DEVICE)

            with autocast(enabled=use_amp):
                outputs = model(images)
                loss = criterion(outputs, labels)
            val_loss += loss.item()

            # Per-volume Dice on hard (argmax) predictions.
            preds = outputs.argmax(dim=1).cpu().numpy()
            gts = batch['label'].cpu().numpy()
            for p, g in zip(preds, gts):
                m = compute_metrics(p, g, ignore_label=IGNORE_LABEL)
                all_dice.append(m['dice'])

    val_loss /= max(len(val_loader), 1)
    # An empty validation set scores 0.0 instead of NaN from np.mean([]).
    val_dice = float(np.mean(all_dice)) if all_dice else 0.0
    history['val_loss'].append(val_loss)
    history['val_dice'].append(val_dice)

    scheduler.step()

    # Save best. A checkpoint is also written whenever none exists yet, so the
    # later cell that loads best_model.pth works even if validation Dice never
    # rises above the initial best_metric.
    improved = val_dice > best_metric
    if improved or not Path(best_model_path).exists():
        if improved:
            best_metric = val_dice
        torch.save(model.state_dict(), best_model_path)
        print(f'Epoch {epoch}: New best model saved! Dice={val_dice:.4f}')

    print(f'Epoch {epoch}/{EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Dice: {val_dice:.4f}')

print(f'\nTraining complete! Best Dice: {best_metric:.4f}')

In [None]:
# Plot training curves: losses on the left, validation Dice on the right
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
loss_ax, dice_ax = axes

loss_ax.plot(history['train_loss'], label='Train')
loss_ax.plot(history['val_loss'], label='Val')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.legend()
loss_ax.set_title('Loss')

dice_ax.plot(history['val_dice'])
dice_ax.set_xlabel('Epoch')
dice_ax.set_ylabel('Dice')
dice_ax.set_title('Validation Dice')

plt.tight_layout()
plt.show()

## 7. Test Prediction & Submission

In [None]:
# Restore the best checkpoint written during training and switch to eval mode.
best_state = torch.load(f'{OUTPUT_DIR}/best_model.pth', map_location=DEVICE)
model.load_state_dict(best_state)
model.eval()

# Test dataset - IMPORTANT: Use val_transforms (no augmentation)
# NOTE(review): val_transforms starts with CenterCropOrPad(TARGET_SIZE), so each
# test volume is reduced (or padded) to a TARGET_SIZE cube before inference —
# confirm that this is the region the submission is expected to cover.
test_ids = test_df['id'].tolist()
test_ds = VesuviusDataset(
    test_ids,
    TEST_IMAGES,
    labels_dir=None,
    transforms=val_transforms,
    is_train=False,
)
test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=0)

print(f'Test samples: {len(test_ds)}')

In [None]:
# Generate predictions and create submission

def rle_encode(mask):
    """Run-length encode a binary mask.

    The mask is flattened in C order and encoded as space-separated
    `start length` pairs, with starts counted from 1. An all-zero mask gives
    an empty string.

    Args:
        mask: Array of 0/1 values of any shape.

    Returns:
        The run-length encoding as a string.
    """
    # Zero sentinels on both ends so runs touching a border are closed.
    flat = np.concatenate([[0], np.asarray(mask).flatten(), [0]])
    # 1-based positions where the value changes: run starts and run ends
    # alternate, beginning with a start.
    runs = np.where(flat[1:] != flat[:-1])[0] + 1
    # Turn each end position into a length.
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)


# torch.cuda.amp autocast only applies on CUDA; switch it off elsewhere.
use_amp = DEVICE.type == 'cuda'

# NOTE(review): the predicted mask has the shape produced by the test
# transforms, not necessarily the shape of the original volume — confirm this
# matches the expected submission format.
submissions = []

for batch in tqdm(test_loader, desc='Generating predictions'):
    vol_id = batch['id'][0]  # batch_size is 1, so there is a single id
    image = batch['image'].to(DEVICE)

    with torch.no_grad(), autocast(enabled=use_amp):
        output = model(image)
        pred = output.argmax(dim=1).squeeze().cpu().numpy()

    # RLE encode the prediction
    rle = rle_encode(pred)

    submissions.append({'id': vol_id, 'rle': rle})

# Save submission
submission_df = pd.DataFrame(submissions)
submission_df.to_csv(f'{OUTPUT_DIR}/submission.csv', index=False)
print(f'Submission saved to {OUTPUT_DIR}/submission.csv')
print(submission_df.head())

In [None]:
# Closing summary: best validation score and where the submission was written.
for summary_line in (
    '\n=== NOTEBOOK COMPLETE ===',
    f'Best validation Dice: {best_metric:.4f}',
    f'Submission file: {OUTPUT_DIR}/submission.csv',
):
    print(summary_line)