# Vesuvius Challenge - Surface Detection Submission

This notebook performs surface segmentation on 3D CT scans and generates submission.zip.

## Pipeline
1. Load pre-trained ResidualUNet3D model
2. Load test volumes (3D CT scans)
3. Run sliding window inference
4. Binarize the predictions at the configured threshold and apply post-processing
5. Save 3D binary masks as .tif files
6. Create submission.zip

**Output:** `submission.zip` containing one `[image_id].tif` file per test volume (each a 3D `uint8` binary mask)

## 1. Setup and Imports

In [None]:
import sys
import time
import zipfile
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import tifffile

# Add src to path
if Path('src').exists():
    sys.path.insert(0, str(Path.cwd()))

from src.vesuvius.models import build_model
from src.vesuvius.infer import sliding_window_predict
from src.vesuvius.postprocess import apply_postprocessing

# Report the runtime environment up front so a CPU-only session is obvious
# in the notebook log.
cuda_ok = torch.cuda.is_available()
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {cuda_ok}")
if cuda_ok:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Configuration

In [None]:
def first_existing(*candidates: Path) -> Path:
    """Return the first candidate path that exists on disk.

    When none of the candidates exists, the first one is returned so that it
    remains the default and any later "not found" error names it.
    """
    return next((path for path in candidates if path.exists()), candidates[0])


# Paths: the Kaggle location is listed first and is used whenever it exists;
# the local-testing location is picked up automatically otherwise, so no line
# has to be edited when switching environments.
DATA_ROOT = first_existing(
    Path('/kaggle/input/vesuvius-challenge-surface-detection'),
    Path('vesuvius_kaggle_data'),
)
MODEL_PATH = first_existing(
    Path('/kaggle/input/vesuvius-model-weights/last_exp001.pt'),
    Path('checkpoints/last_exp001.pt'),
)

OUTPUT_DIR = Path('.')  # Kaggle working directory

# Model config: passed unchanged to build_model(). It presumably has to match
# the architecture the checkpoint was trained with -- TODO confirm.
MODEL_CONFIG = {
    'type': 'unet3d_residual',
    'in_channels': 1,
    'out_channels': 1,
    'base_channels': 40,
    'channel_multipliers': [1, 2, 2, 4, 4],
    'blocks_per_stage': 3,
    'deep_supervision': True,
    'dropout': 0.15,
    'activation': 'mish',
    'norm': 'instance',
}

# Inference config: passed unchanged to sliding_window_predict().
# 'patch_size' also sizes the dummy input of the model smoke test, and
# 'threshold' is the cut-off used to binarize the prediction.
INFERENCE_CONFIG = {
    'patch_size': [64, 128, 128],
    'overlap': [32, 96, 96],
    'gaussian_blend_sigma': 0.125,
    'tta': 'none',  # Set to 'flips' for better accuracy (slower)
    'threshold': 0.42,
}

# Post-processing config: passed unchanged to apply_postprocessing().
POSTPROCESS_CONFIG = {
    'remove_small_components_voxels': 600,
    'fill_holes': False,
    'closing_radius': 3,
}

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f"Data root: {DATA_ROOT}")
print(f"Model: {MODEL_PATH}")
print(f"Device: {DEVICE}")

## 3. Load Model

In [None]:
print("Loading model...")
model = build_model(MODEL_CONFIG)

# Deserialize the checkpoint onto the CPU; the module is moved to DEVICE
# only after its weights have been restored.
checkpoint = torch.load(MODEL_PATH, map_location='cpu')
state_dict = checkpoint['state_dict']
model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()

num_params = sum(param.numel() for param in model.parameters())
print(f"✓ Model loaded: {num_params/1e6:.1f}M parameters")

# Smoke test: push one random patch-sized input (batch and channel axes of
# size 1) through the model without tracking gradients.
dummy_shape = (1, 1, *INFERENCE_CONFIG['patch_size'])
with torch.no_grad():
    test_input = torch.randn(*dummy_shape).to(DEVICE)
    test_output = model(test_input)
print(f"✓ Test pass: {test_input.shape} -> {test_output['logits'].shape}")

## 4. Load Test Data

In [None]:
# Read test.csv. The 'Id' column is parsed as text so identifiers reach the
# file-name lookup exactly as written: with dtype inference, an id such as
# "00123" would be read as the integer 123 and no longer match "00123.tif".
test_csv = DATA_ROOT / 'test.csv'
test_df = pd.read_csv(test_csv, dtype={'Id': str})

# Get test image IDs (astype(str) also stringifies any missing entries)
test_ids = test_df['Id'].astype(str).tolist()

print(f"Found {len(test_ids)} test volumes:")
for test_id in test_ids:
    print(f"  - {test_id}")

## 5. Inference Functions

In [None]:
def load_test_volume(image_id: str) -> np.ndarray:
    """Read the test volume stored at ``DATA_ROOT/test_images/<image_id>.tif``.

    Args:
        image_id: Identifier of the volume; it is the TIFF file's base name.

    Returns:
        The array produced by ``tifffile.imread`` for that file.

    Raises:
        FileNotFoundError: If the TIFF file does not exist.
    """
    tif_file = DATA_ROOT / 'test_images' / f"{image_id}.tif"
    if tif_file.exists():
        return tifffile.imread(str(tif_file))
    raise FileNotFoundError(f"Volume not found: {tif_file}")


def normalize_volume(volume: np.ndarray) -> np.ndarray:
    """Min-max normalize ``volume`` to ``[0, 1]`` and return it as ``float32``.

    The arithmetic is carried out in floating point rather than in the input
    dtype, so signed or narrow integer inputs cannot wrap around when the
    minimum is subtracted. A constant volume (``max == min``) carries no
    contrast and is mapped to all zeros so the result stays inside ``[0, 1]``.

    Args:
        volume: Non-empty numeric array; it is not modified.

    Returns:
        A new ``float32`` array with the same shape as ``volume``.
    """
    # Python floats: the range below is computed without integer overflow.
    vmin = float(volume.min())
    vmax = float(volume.max())
    if vmax == vmin:
        return np.zeros(volume.shape, dtype=np.float32)

    # float32 is exact for 8/16-bit inputs; wider inputs are promoted to
    # float64 so no precision is lost before the subtraction.
    work_dtype = np.result_type(volume.dtype, np.float32)
    # astype() copies, so the in-place steps never touch the caller's array.
    normalized = volume.astype(work_dtype)
    normalized -= vmin
    normalized /= vmax - vmin
    return normalized.astype(np.float32, copy=False)


def predict_volume(image_id: str) -> np.ndarray:
    """Produce the post-processed binary mask for one test volume.

    Loads and normalizes the volume, runs ``sliding_window_predict`` on it
    with gradients disabled, thresholds the prediction at
    ``INFERENCE_CONFIG['threshold']`` and hands the ``uint8`` mask to
    ``apply_postprocessing``. Progress, timing and the percentage of
    foreground voxels are printed along the way.

    Args:
        image_id: Identifier of the test volume to process.

    Returns:
        The mask returned by ``apply_postprocessing``.
    """
    print(f"\nProcessing {image_id}...")
    t0 = time.time()

    vol = load_test_volume(image_id)
    print(f"  Volume: {vol.shape}, {vol.dtype}")

    # Normalize, then prepend two singleton axes before handing the volume
    # to the model (presumably batch and channel -- TODO confirm).
    vol = normalize_volume(vol)
    batch = torch.from_numpy(vol[None, None, ...])

    print("  Running inference...")
    with torch.no_grad():
        pred_map = sliding_window_predict(model, batch, INFERENCE_CONFIG, DEVICE)

    # Threshold to a 0/1 mask, then clean it up.
    mask = (pred_map >= INFERENCE_CONFIG['threshold']).astype(np.uint8)
    mask = apply_postprocessing(mask, POSTPROCESS_CONFIG)

    elapsed = time.time() - t0
    coverage = mask.sum() / mask.size * 100
    print(f"  ✓ Done in {elapsed:.1f}s | Coverage: {coverage:.2f}%")

    return mask

## 6. Run Inference on All Test Volumes

In [None]:
separator = "=" * 80
print(separator)
print("Running inference on test volumes")
print(separator)

# Maps each image id to its predicted mask.
predictions = {}
total_start = time.time()

for image_id in test_ids:
    try:
        predictions[image_id] = predict_volume(image_id)
    except Exception as e:
        # Report the failure, then abort the run by re-raising.
        print(f"  ✗ Error: {e}")
        raise

total_elapsed = time.time() - total_start
print("\n" + separator)
print(f"✓ Inference complete: {total_elapsed/60:.1f} minutes")
print(f"✓ Processed {len(predictions)} volumes")
print(separator)

## 7. Create Submission ZIP

In [None]:
import tempfile

print("\nCreating submission.zip...")

submission_path = OUTPUT_DIR / 'submission.zip'

# Each mask is first written as a TIFF into a private staging directory and
# then copied into the archive. The staging directory keeps the temporary
# files from overwriting (and later deleting) any same-named file in the
# working directory, and it is removed automatically even when writing or
# zipping fails part-way through.
with tempfile.TemporaryDirectory(dir=OUTPUT_DIR) as staging_dir, \
        zipfile.ZipFile(submission_path, 'w', zipfile.ZIP_DEFLATED) as zf:
    for image_id, mask_3d in predictions.items():
        member_name = f'{image_id}.tif'
        temp_path = str(Path(staging_dir) / member_name)

        # copy=False skips the extra full-volume copy when the mask is
        # already uint8.
        tifffile.imwrite(temp_path, mask_3d.astype(np.uint8, copy=False))

        # Add to ZIP
        zf.write(temp_path, arcname=member_name)

        # Free the disk space right away instead of waiting for the end
        Path(temp_path).unlink()

        print(f"  ✓ Added {image_id}.tif to ZIP")

zip_size = submission_path.stat().st_size / (1024 * 1024)
print(f"\n✓ Created submission.zip ({zip_size:.1f} MB)")

## 8. Validate Submission

In [None]:
import io

print("\nValidating submission...")

with zipfile.ZipFile(submission_path, 'r') as zf:
    files = zf.namelist()
    print(f"Files in ZIP: {len(files)}")

    # The archive must hold exactly one TIFF per test id; without this check
    # an empty or partial ZIP would pass validation.
    expected_files = {f"{test_id}.tif" for test_id in test_ids}
    missing = sorted(expected_files - set(files))
    unexpected = sorted(set(files) - expected_files)
    if missing or unexpected:
        raise ValueError(
            f"ZIP contents do not match test ids "
            f"(missing: {missing}, unexpected: {unexpected})"
        )

    for filename in files:
        # Decode the TIFF straight from the archive bytes
        mask = tifffile.imread(io.BytesIO(zf.read(filename)))

        # Validate the format before reporting the file as OK. Explicit
        # raises are used because assert statements are skipped under -O.
        if mask.dtype != np.uint8:
            raise ValueError(f"{filename}: Wrong dtype: {mask.dtype}")
        if mask.ndim != 3:
            raise ValueError(f"{filename}: Must be 3D, got {mask.ndim}D")
        # dtype is uint8 here, so "max <= 1" means every voxel is 0 or 1.
        if mask.max(initial=0) > 1:
            raise ValueError(f"{filename}: Must be binary (0 or 1)")

        print(f"  ✓ {filename}: {mask.shape}, {mask.dtype}, "
              f"{mask.sum()/mask.size*100:.1f}% coverage")

print("\n✓ Validation passed!")
print("=" * 80)
print("✓ SUBMISSION READY")
print(f"✓ File: {submission_path}")
print(f"✓ Size: {zip_size:.1f} MB")
print("=" * 80)