# 04 - Train Camp Classifier

Trains a ResNet-18 binary classifier (camp vs non-camp) and compares it against two baselines: a Random Forest on band features and an NDVI threshold rule.

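**Recommended: run this notebook on Google Colab with a GPU runtime** (it also runs locally and falls back to CPU when no GPU is available, but training will be slow).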

**Input:** `data/tiles/manifest.csv`, `data/tiles/norm_stats.npz`  
**Output:** `checkpoints/best_model.pth`, `training_curves.png`, training metrics

In [None]:
# --- Local setup (default) ---
# '..' is resolved relative to the current working directory.
PROJECT_DIR = '..'

# --- Colab setup ---
# Uncomment the lines below when running on Colab. They deliberately come
# AFTER the local setup so the Colab PROJECT_DIR overrides the local default
# (previously the local assignment ran last and silently won).
# !pip install pyyaml
# from google.colab import drive
# drive.mount('/content/drive')
# PROJECT_DIR = '/content/drive/MyDrive/sentinel-refugee-detection'

In [None]:
import sys
sys.path.insert(0, PROJECT_DIR)

import numpy as np
import torch
from pathlib import Path

from src.utils import load_config
from src.data import CampTileDataset, SatelliteAugmentation
from src.model import create_camp_classifier
from src.train import train_model, evaluate

# Select the compute device: the GPU when PyTorch can see one, otherwise CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

In [None]:
config = load_config(f'{PROJECT_DIR}/configs/default.yaml')
tiles_dir = Path(f'{PROJECT_DIR}/data/tiles')

# Load normalization stats. np.load on an .npz returns an NpzFile that keeps
# the underlying file open; the context manager closes it once both arrays
# have been read out by indexing.
with np.load(tiles_dir / 'norm_stats.npz') as stats_file:
    norm_stats = {'low': stats_file['low'], 'high': stats_file['high']}

## 1. Load Datasets

In [None]:
# Augmentation pipeline built from the config; it is attached to the
# training split only (validation uses transform=None below).
aug_config = config['augmentation']
augmentation = SatelliteAugmentation(
    rotation=aug_config['random_rotation'],
    flip_h=aug_config['horizontal_flip'],
    flip_v=aug_config['vertical_flip'],
    brightness=aug_config['brightness_jitter'],
)

# Arguments shared by both splits: same manifest, same normalization stats.
common_kwargs = dict(
    manifest_path=tiles_dir / 'manifest.csv',
    normalize=True,
    norm_stats=norm_stats,
)
train_dataset = CampTileDataset(split='train', transform=augmentation, **common_kwargs)
val_dataset = CampTileDataset(split='val', transform=None, **common_kwargs)  # No augmentation for validation

for split_name, ds in (('Train', train_dataset), ('Val', val_dataset)):
    print(f"{split_name}: {len(ds)} tiles")

## 2. Train ResNet-18 Classifier

In [None]:
# Build the classifier from the config and report its total parameter count.
model = create_camp_classifier(config)
param_count = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {param_count:,}")

# Train; checkpoints are written under <PROJECT_DIR>/checkpoints by train_model.
checkpoint_dir = f'{PROJECT_DIR}/checkpoints'
history = train_model(
    model=model,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    config=config,
    device=device,
    checkpoint_dir=checkpoint_dir,
)

## 3. Training Curves

In [None]:
import matplotlib.pyplot as plt


def _epochs(values):
    """Return 1-based epoch numbers for a per-epoch metric series."""
    return range(1, len(values) + 1)


fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Loss
axes[0].plot(_epochs(history['train_loss']), history['train_loss'], label='Train')
axes[0].plot(_epochs(history['val_loss']), history['val_loss'], label='Val')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Loss')
axes[0].legend()

# Precision, Recall & F1
for key, label in (('val_precision', 'Precision'),
                   ('val_recall', 'Recall'),
                   ('val_f1', 'F1')):
    axes[1].plot(_epochs(history[key]), history[key], label=label)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Score')
axes[1].set_title('Precision / Recall / F1')
axes[1].legend()

# AUC
axes[2].plot(_epochs(history['val_auc']), history['val_auc'])
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('AUC')
axes[2].set_title('ROC-AUC')

plt.tight_layout()
plt.savefig(f'{PROJECT_DIR}/training_curves.png', dpi=150)
plt.show()

print("Best val metrics:")
# Report the epoch with the lowest validation loss. nanargmin skips NaN
# losses (e.g. from a diverged epoch); plain argmin would return the index
# of the first NaN as the "best" epoch.
# NOTE(review): confirm train_model picks checkpoints/best_model.pth by the
# same criterion, otherwise these numbers describe a different model.
best_epoch = int(np.nanargmin(history['val_loss']))
print(f"  Epoch: {best_epoch + 1}")  # 1-based, matching the plot x-axis
for key, label in (('val_precision', 'Precision'),
                   ('val_recall', 'Recall'),
                   ('val_f1', 'F1'),
                   ('val_auc', 'AUC')):
    print(f"  {label}: {history[key][best_epoch]:.3f}")

## 4. Baselines

In [None]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, roc_auc_score
from src.model import extract_band_features, ndvi_threshold_classifier

def _load_split(dataset):
    """Load every tile array and integer label listed in *dataset*'s manifest.

    Reads the 'path' and 'label' columns directly instead of iterating the
    manifest row by row twice. Labels missing from CampTileDataset.LABEL_MAP
    map to 0.
    NOTE(review): unknown labels are silently treated as class 0 (presumably
    non-camp); confirm this matches how CampTileDataset labels its tiles.
    """
    manifest = dataset.manifest
    tiles = [np.load(path) for path in manifest['path']]
    labels = [CampTileDataset.LABEL_MAP.get(label, 0) for label in manifest['label']]
    return tiles, labels


# Load raw tiles for the non-deep baselines, then extract band features.
train_tiles, train_labels = _load_split(train_dataset)
val_tiles, val_labels = _load_split(val_dataset)

X_train, y_train = extract_band_features(train_tiles, train_labels)
X_val, y_val = extract_band_features(val_tiles, val_labels)

In [None]:
# Random Forest baseline on the extracted band features.
# class_weight='balanced' reweights classes inversely to their frequency in y_train.
rf = RandomForestClassifier(n_estimators=100, random_state=42, class_weight='balanced')
rf.fit(X_train, y_train)
rf_preds = rf.predict(X_val)
# Column 1 is the probability of rf.classes_[1], i.e. class 1 when y_train
# contains both classes 0 and 1.
rf_probs = rf.predict_proba(X_val)[:, 1]

print("=== Random Forest Baseline ===")
# labels=[0, 1] keeps the report aligned with target_names even if one class
# is absent from both y_val and the predictions (otherwise it raises).
# Assumes LABEL_MAP encodes non-camp=0, camp=1, as target_names implies.
print(classification_report(y_val, rf_preds, labels=[0, 1],
                            target_names=['non-camp', 'camp'], zero_division=0))
# ROC-AUC is undefined (roc_auc_score raises) when y_val holds a single class;
# same guard as the NDVI baseline.
if len(np.unique(y_val)) > 1:
    print(f"ROC-AUC: {roc_auc_score(y_val, rf_probs):.3f}")

In [None]:
# NDVI threshold baseline (threshold passed to ndvi_threshold_classifier).
ndvi_preds, ndvi_scores = ndvi_threshold_classifier(val_tiles, threshold=-0.1)

print("=== NDVI Threshold Baseline ===")
# labels=[0, 1] keeps the report aligned with target_names even if one class
# is absent from both y_val and the predictions.
print(classification_report(y_val, ndvi_preds, labels=[0, 1],
                            target_names=['non-camp', 'camp'], zero_division=0))
# ROC-AUC is undefined when y_val holds a single class.
if len(np.unique(y_val)) > 1:
    # Scores are negated so lower NDVI ranks as more camp-like (presumably
    # the classifier's convention; confirm in ndvi_threshold_classifier).
    # np.asarray makes the negation valid even if a plain list is returned.
    print(f"ROC-AUC: {roc_auc_score(y_val, -np.asarray(ndvi_scores)):.3f}")

## 5. Comparison Summary

In [None]:
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score

# Table width: 20-char model name + four metric columns of 10 chars, each
# preceded by one space. Used for all separator lines so they match the rows.
_TABLE_WIDTH = 20 + 4 * 11


def _auc_cell(y_true, scores):
    """Format ROC-AUC to 3 decimals, or 'N/A' when y_true has a single class.

    roc_auc_score raises ValueError for single-class targets, which would
    otherwise abort the whole summary.
    """
    if len(np.unique(y_true)) < 2:
        return 'N/A'
    return f"{roc_auc_score(y_true, scores):.3f}"


def _baseline_row(name, y_true, preds, scores):
    """Format one comparison-table row from hard predictions and scores."""
    return (f"{name:<20} "
            f"{precision_score(y_true, preds, zero_division=0):>10.3f} "
            f"{recall_score(y_true, preds, zero_division=0):>10.3f} "
            f"{f1_score(y_true, preds, zero_division=0):>10.3f} "
            f"{_auc_cell(y_true, scores):>10}")


print("\n" + "=" * _TABLE_WIDTH)
print("MODEL COMPARISON (Validation Set)")
print("=" * _TABLE_WIDTH)
print(f"{'Model':<20} {'Precision':>10} {'Recall':>10} {'F1':>10} {'AUC':>10}")
print("-" * _TABLE_WIDTH)
print(f"{'ResNet-18':<20} {history['val_precision'][best_epoch]:>10.3f} "
      f"{history['val_recall'][best_epoch]:>10.3f} "
      f"{history['val_f1'][best_epoch]:>10.3f} "
      f"{history['val_auc'][best_epoch]:>10.3f}")
print(_baseline_row('Random Forest', y_val, rf_preds, rf_probs))
# NDVI scores are negated exactly as in the NDVI baseline cell, so the AUC
# reported here matches the one printed there instead of showing N/A.
print(_baseline_row('NDVI Threshold', y_val, ndvi_preds, -np.asarray(ndvi_scores)))