# 05 - Generalization Testing

Tests the trained model on **unseen countries** (Chad, Ethiopia, Yemen).

**Key metrics:**
- Standard: Precision, Recall, F1, ROC-AUC
- **Precision@Top-K**: Of our top K detections (K = 10, 20, 50), how many are real camps? (operationally useful)
- **Error analysis by negative category**: Where does the model fail?

**Input:** `checkpoints/best_model.pth`, `configs/default.yaml`, `data/tiles/manifest.csv`, `data/tiles/norm_stats.npz`  
**Output:** Generalization metrics, candidate new detections

In [None]:
# --- Colab setup (uncomment if running on Colab) ---
# PROJECT_DIR = '/content/drive/MyDrive/sentinel-refugee-detection'

# --- Local setup ---
# PROJECT_DIR is the project root: later cells prepend it to sys.path (so the
# `src` package is importable) and resolve configs/, data/tiles/ and
# checkpoints/ against it.
# NOTE(review): '..' assumes the kernel's working directory is a direct
# subdirectory of the project root -- confirm for your setup.
PROJECT_DIR = '..'

In [None]:
import sys
sys.path.insert(0, PROJECT_DIR)

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from pathlib import Path
from torch.utils.data import DataLoader

from src.utils import load_config
from src.data import CampTileDataset
from src.model import create_camp_classifier
from src.train import evaluate, predict, precision_at_top_k

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Load the experiment configuration and locate the tile directory.
config = load_config(f'{PROJECT_DIR}/configs/default.yaml')
tiles_dir = Path(f'{PROJECT_DIR}/data/tiles')

# np.load on an .npz archive returns a lazily-read NpzFile that keeps the file
# open; the context manager closes it once both arrays have been read.
with np.load(tiles_dir / 'norm_stats.npz') as stats:
    norm_stats = {'low': stats['low'], 'high': stats['high']}

# Rebuild the architecture from the config, then restore the saved weights.
model = create_camp_classifier(config)
# NOTE(review): torch.load unpickles the checkpoint, which can execute
# arbitrary code -- only load checkpoints you trust. If the installed PyTorch
# supports it, consider passing weights_only=True.
model.load_state_dict(
    torch.load(f'{PROJECT_DIR}/checkpoints/best_model.pth', map_location=device)
)
model = model.to(device)
model.eval()  # switch to evaluation mode before running the test set
print('Model loaded.')

## 1. Test Set Evaluation (Unseen Countries)

In [None]:
import pandas as pd

# Held-out test split, no augmentation (transform=None), normalised with the
# statistics loaded from norm_stats.npz.
test_dataset = CampTileDataset(
    manifest_path=tiles_dir / 'manifest.csv',
    split='test',
    transform=None,
    normalize=True,
    norm_stats=norm_stats,
    model_size=config['tile_size_model'],
)

# shuffle=False: later cells mask test_dataset.manifest with the per-sample
# label/probability arrays, so sample order must stay fixed.
# NOTE(review): this assumes evaluate() returns results in loader order and
# that the dataset iterates its manifest rows in order -- confirm in src.
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Positive-class weight for the loss: either taken from the config, or (for
# None / 'auto') derived from the negative:positive ratio of the train split.
pos_weight_cfg = config['model'].get('pos_weight', 'auto')
if pos_weight_cfg in (None, 'auto'):
    manifest = pd.read_csv(tiles_dir / 'manifest.csv')
    train_labels = manifest[manifest['split'] == 'train']['label']
    is_positive = train_labels.isin(['camp', 'camp_context'])
    n_pos = is_positive.sum()
    n_neg = (~is_positive).sum()
    # Cast to a Python float so the tensor below gets the default dtype in
    # both branches (a NumPy float64 scalar would yield a float64 tensor).
    pos_weight_value = float(n_neg / n_pos) if n_pos > 0 else 1.0
else:
    pos_weight_value = float(pos_weight_cfg)

criterion = nn.BCEWithLogitsLoss(
    pos_weight=torch.tensor([pos_weight_value], device=device)
)

test_metrics = evaluate(model, test_loader, criterion, device)

# The original string literal was split across two physical lines, which is
# a SyntaxError; the intended leading blank line is written as '\n'.
print('\n' + '='*50)
print('TEST SET RESULTS (Unseen Countries)')
print('='*50)
print(f'Precision: {test_metrics["precision"]:.3f}')
print(f'Recall:    {test_metrics["recall"]:.3f}')
print(f'F1:        {test_metrics["f1"]:.3f}')
print(f'ROC-AUC:   {test_metrics["auc"]:.3f}')


## 2. Precision @ Top-K

The operationally useful metric: if we send a team to investigate the top 50 detections, how many are real camps?

In [None]:
# Per-sample ground truth and predicted probabilities from the test run;
# both are reused by the cells that follow.
labels = test_metrics['labels']
probs = test_metrics['probs']

# Report precision among the K highest-scoring tiles for several K.
for k in [10, 20, 50]:
    p_at_k, details = precision_at_top_k(labels, probs, k=k)
    n_tp = details["true_positives"]
    n_fp = details["false_positives"]
    lowest_prob = details["min_prob"]
    summary = f'Precision@{k:>3d}: {p_at_k:.3f} '
    summary += f'({n_tp} TP, {n_fp} FP, '
    summary += f'min_prob={lowest_prob:.3f})'
    print(summary)

## 3. Per-Country Breakdown

In [None]:
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score

# Manifest rows of the test split; masked below with the per-sample arrays.
# NOTE(review): assumes manifest row order matches the order of labels/probs
# -- confirm against CampTileDataset and evaluate().
test_manifest = test_dataset.manifest
# Hard predictions at a fixed 0.5 decision threshold (reused by later cells).
preds = (probs >= 0.5).astype(int)

print(f"{'Country':<15} {'N':>5} {'Prec':>8} {'Rec':>8} {'F1':>8} {'AUC':>8}")
print('-' * 55)

country_column = test_manifest['country'].values  # hoisted out of the loop
for country in config['test_countries']:
    mask = country_column == country
    if mask.sum() == 0:
        # No test tiles for this country: nothing to report.
        continue

    c_labels = labels[mask]
    c_probs = probs[mask]
    c_preds = preds[mask]

    # zero_division=0 keeps the metrics defined when a country has no
    # predicted (or no actual) positives.
    p = precision_score(c_labels, c_preds, zero_division=0)
    r = recall_score(c_labels, c_preds, zero_division=0)
    f = f1_score(c_labels, c_preds, zero_division=0)
    # ROC-AUC is undefined when only one class is present. Report NaN rather
    # than 0, which would read as "worse than random".
    if len(np.unique(c_labels)) > 1:
        auc = roc_auc_score(c_labels, c_probs)
    else:
        auc = float('nan')

    print(f'{country:<15} {mask.sum():>5} {p:>8.3f} {r:>8.3f} {f:>8.3f} {auc:>8.3f}')

## 4. Error Analysis by Negative Category

Where does the model make false positives? Rural? Urban? Barren?

In [None]:
# False positive analysis by negative category.
# Restrict to true negatives: any positive prediction here is a false positive.
neg_mask = labels == 0
neg_preds = preds[neg_mask]
# Empty neg_category cells become NaN when a CSV is read with pandas defaults,
# and NaN never compares equal to ''. Normalise missing values to '' so the
# uncategorised ('camp_peripheral') bucket below is not silently skipped.
neg_categories = test_manifest['neg_category'].fillna('').values[neg_mask]

print('False Positive Rate by Negative Category:')
print(f"{'Category':<15} {'N':>6} {'FP':>6} {'FPR':>8}")
print('-' * 40)

for cat in ['rural', 'urban', 'barren', '']:
    cat_mask = neg_categories == cat
    n_cat = cat_mask.sum()
    if n_cat == 0:
        # Category absent from the test negatives; also avoids dividing by 0.
        continue
    cat_preds = neg_preds[cat_mask]
    fp = cat_preds.sum()
    fpr = fp / n_cat
    # Negatives without a category are displayed as 'camp_peripheral'.
    cat_name = cat if cat else 'camp_peripheral'
    print(f'{cat_name:<15} {n_cat:>6} {fp:>6} {fpr:>8.3f}')

print('\nKey insight: high FPR on urban = model confuses dense formal settlements with camps')
print('             high FPR on barren = model triggers on bare soil')

## 5. High-Confidence Detections

In [None]:
# Operating threshold for "high-confidence" detections, taken from the config
# (distinct from the fixed 0.5 used to build `preds`).
threshold = config['inference']['confidence_threshold']

# Work on a copy so the dataset's own manifest is not modified.
test_manifest = test_manifest.copy()
test_manifest['prob'] = probs
test_manifest['predicted'] = preds

high_conf_mask = probs >= threshold
high_conf = test_manifest[high_conf_mask]

# Count hits with the binary ground-truth labels used by every other metric.
# Comparing the manifest's string label to "camp" alone would count
# 'camp_context' tiles -- treated as positive when the class weight is
# computed -- as false positives.
n_true_camps = int((labels[high_conf_mask] == 1).sum())
n_false_pos = len(high_conf) - n_true_camps

print(f'\nHigh-confidence detections (p >= {threshold}): {len(high_conf)}')
print(f'  True camps: {n_true_camps}')
print(f'  False positives: {n_false_pos}')

if len(high_conf) > 0:
    print('\nTop 20 detections:')
    print(high_conf[['tile_id', 'country', 'label', 'neg_category', 'prob']]
          .sort_values('prob', ascending=False)
          .head(20)
          .to_string(index=False))

## 6. Visualizations

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, confusion_matrix, ConfusionMatrixDisplay

# One row of three panels: ROC curve, confusion matrix, score histogram.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# ROC curve, with the chance diagonal for reference.
fpr, tpr, _ = roc_curve(labels, probs)
axes[0].plot(fpr, tpr, 'b-', label=f'AUC={test_metrics["auc"]:.3f}')
axes[0].plot([0, 1], [0, 1], 'k--')
axes[0].set_xlabel('False Positive Rate')
axes[0].set_ylabel('True Positive Rate')
axes[0].set_title('ROC Curve (Test Set - Unseen Countries)')
axes[0].legend()

# Confusion matrix at the 0.5 threshold used for `preds`.
# labels=[0, 1] forces a 2x2 matrix even if only one class appears in the
# data; otherwise the matrix can be 1x1 and the two display labels no longer
# fit it.
cm = confusion_matrix(labels, preds, labels=[0, 1])
ConfusionMatrixDisplay(cm, display_labels=['non-camp', 'camp']).plot(ax=axes[1])
axes[1].set_title('Confusion Matrix')

# Score distribution per true class, with the 0.5 decision threshold marked.
axes[2].hist(probs[labels == 0], bins=20, alpha=0.5, label='non-camp', color='blue')
axes[2].hist(probs[labels == 1], bins=20, alpha=0.5, label='camp', color='red')
axes[2].axvline(0.5, color='k', linestyle='--', label='threshold')
axes[2].set_xlabel('Predicted Probability')
axes[2].set_ylabel('Count')
axes[2].set_title('Score Distribution')
axes[2].legend()

plt.tight_layout()
plt.savefig(f'{PROJECT_DIR}/generalization_results.png', dpi=150)
plt.show()

In [None]:
# Example tiles: TP, FP, TN, FN
# Boolean masks over the test samples, one per confusion-matrix cell.
categories = {
    'True Positive': (labels == 1) & (preds == 1),
    'False Positive': (labels == 0) & (preds == 1),
    'True Negative': (labels == 0) & (preds == 0),
    'False Negative': (labels == 1) & (preds == 0),
}

fig, axes = plt.subplots(4, 3, figsize=(12, 16))

for row_idx, (cat_name, mask) in enumerate(categories.items()):
    # Hide ticks and frames on every panel of the row, including panels left
    # empty when a category has fewer than three examples. axis('off') is
    # deliberately avoided: it would also hide the row label set below.
    for ax in axes[row_idx]:
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

    # Show up to three examples per category.
    indices = np.where(mask)[0][:3]
    for col_idx, idx in enumerate(indices):
        info = test_manifest.iloc[idx]
        tile = np.load(info['path'])
        # Channels: R, G, B, NDVI, NDBI, SWIR_ratio
        rgb = tile[:3].transpose(1, 2, 0)
        # Contrast-stretch to the 98th percentile; fall back to a scale of 1
        # when the percentile is zero or NaN, to avoid dividing by zero.
        scale = np.percentile(rgb, 98)
        if not scale > 0:
            scale = 1.0
        rgb = np.clip(rgb / scale, 0, 1)
        axes[row_idx, col_idx].imshow(rgb)
        # A missing neg_category (e.g. on positive tiles) may be NaN; show it
        # as an empty string instead of 'nan'.
        neg_cat = info.get('neg_category', '')
        if pd.isna(neg_cat):
            neg_cat = ''
        axes[row_idx, col_idx].set_title(
            f"{info['tile_id']}\np={probs[idx]:.2f} | {neg_cat}",
            fontsize=8,
        )
    axes[row_idx, 0].set_ylabel(cat_name, fontsize=11)

plt.suptitle('Test Set Examples (Unseen Countries)', fontsize=14)
plt.tight_layout()
plt.savefig(f'{PROJECT_DIR}/example_tiles.png', dpi=150)
plt.show()