# 03 - Prepare Tiles Dataset

Creates the train/val/test split manifest and computes per-band normalization statistics from the training split.

- **Train/Val:** Syria + South Sudan (80/20 split)
- **Test:** Chad + Ethiopia + Yemen (entirely held out)

**Input:** `data/sentinel2/*.npy`, `data/labels/all_locations.csv`  
**Output:** `data/tiles/manifest.csv`, `data/tiles/norm_stats.npz`

In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import pandas as pd
from pathlib import Path

from src.utils import load_config
from src.data import create_manifest, compute_norm_stats

In [None]:
# Input and output locations, given relative to the current working directory.
tiles_dir = Path('../data/sentinel2')
output_dir = Path('../data/tiles')

# Project configuration and the table of labelled locations.
config = load_config('../configs/default.yaml')
labels_df = pd.read_csv('../data/labels/all_locations.csv')

# Create the output directory (and any missing parents); a directory that
# already exists is not an error.
output_dir.mkdir(parents=True, exist_ok=True)

# Sanity report: number of label rows versus number of .npy tiles found.
print(f"Total locations: {len(labels_df)}")
print(f"Available tiles: {len(list(tiles_dir.glob('*.npy')))}")

## 1. Create manifest with train/val/test split

In [None]:
# Split settings: the country lists come from the loaded config, while the
# validation fraction is fixed in this notebook.
split_settings = {
    'train_countries': config['train_countries'],
    'test_countries': config['test_countries'],
    'val_fraction': 0.2,
}

# Build the manifest from the tile directory and the labels table.
# output_path is the manifest.csv that the later cells read back.
manifest = create_manifest(
    tiles_dir=tiles_dir,
    labels_df=labels_df,
    output_path=output_dir / 'manifest.csv',
    **split_settings,
)

## 2. Compute normalization statistics from the training set

In [None]:
# The low/high settings passed as low_pct/high_pct live under
# config['normalization'].
pct_cfg = config['normalization']

# Statistics are computed on the 'train' split of the manifest only.
norm_stats = compute_norm_stats(
    manifest_path=output_dir / 'manifest.csv',
    split='train',
    low_pct=pct_cfg['low'],
    high_pct=pct_cfg['high'],
)

# Persist both arrays in a single .npz archive under the keys 'low' and 'high'.
stats_to_save = {'low': norm_stats['low'], 'high': norm_stats['high']}
np.savez(output_dir / 'norm_stats.npz', **stats_to_save)

# Report one line per configured band.
# NOTE(review): assumes config['bands'] is ordered like the entries of the
# stats arrays; confirm against compute_norm_stats.
print("Normalization stats per band:")
for i, band in enumerate(config['bands']):
    print(f"  {band}: low={norm_stats['low'][i]:.1f}, high={norm_stats['high'][i]:.1f}")

## 3. Verify dataset

In [None]:
from src.data import CampTileDataset
from src.data import SatelliteAugmentation

# Smoke test: build the training dataset with augmentation and normalization
# enabled, reusing the statistics computed above.
dataset_kwargs = {
    'manifest_path': output_dir / 'manifest.csv',
    'split': 'train',
    'transform': SatelliteAugmentation(),
    'normalize': True,
    'norm_stats': norm_stats,
}
train_dataset = CampTileDataset(**dataset_kwargs)

print(f"Train dataset size: {len(train_dataset)}")

# Fetch the first sample and report its shape, value range and label.
first_sample = train_dataset[0]
image, label = first_sample
print(f"Image shape: {image.shape}")
print(f"Image range: [{image.min():.3f}, {image.max():.3f}]")
print(f"Label: {label}")

In [None]:
import matplotlib.pyplot as plt

# A 2 x 4 grid of panels, one training sample drawn per panel.
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
n_available = len(train_dataset)

for i, ax in enumerate(axes.flat):
    # Wrap around so small datasets still fill every panel.
    idx = i % n_available
    image, label = train_dataset[idx]

    # Pick channels 2, 1, 0 and move the leading axis last for imshow, then
    # clamp values to [0, 1].
    # NOTE(review): assumes channels 2/1/0 are red/green/blue; confirm
    # against config['bands'].
    rgb = np.clip(image[[2, 1, 0]].numpy().transpose(1, 2, 0), 0, 1)

    ax.imshow(rgb)
    ax.set_title(f"Label: {'camp' if label == 1 else 'non-camp'}")
    ax.axis('off')

plt.suptitle('Augmented Training Samples', fontsize=14)
plt.tight_layout()
plt.show()