# 02 - Download Sentinel-2 Imagery

Downloads Sentinel-2 L2A tiles and computes derived indices.

**Key design decisions:**
- Download at **128x128 px** (1.28 km × 1.28 km at 10 m/px) for more spatial context, then resize to 64x64 for the CNN
- Save **RGB plus derived indices** (NDVI, NDBI, SWIR_ratio) rather than the raw bands
- Track **scene_id** per tile for leakage-free splitting

**Run this notebook on Google Colab** for faster downloads.

**Input:** `data/labels/all_locations.csv`  
**Output:** `data/sentinel2/{tile_id}.npy` (6 channels: RGB + indices, 128x128 px), plus a `scene_id` column written back to `data/labels/all_locations.csv`

In [None]:
# --- Colab setup (uncomment if running on Colab) ---
# !pip install pystac-client planetary-computer rasterio pyproj pyyaml
# from google.colab import drive
# drive.mount('/content/drive')
# PROJECT_DIR = '/content/drive/MyDrive/sentinel-refugee-detection'

# --- Local setup ---
# Project root, given relative to the notebook's working directory.
PROJECT_DIR = ".."

In [None]:
import sys

# Put the project root first on the import path so that `src.utils` resolves.
sys.path.insert(0, PROJECT_DIR)

import time
from pathlib import Path

import numpy as np
import pandas as pd

from src.utils import download_tile, load_config, save_tile, search_sentinel2

In [None]:
# Run configuration and the labelled locations table.
config = load_config(f'{PROJECT_DIR}/configs/default.yaml')
locations = pd.read_csv(f'{PROJECT_DIR}/data/labels/all_locations.csv')

# Destination folder for the downloaded tiles (created on first run).
output_dir = Path(f'{PROJECT_DIR}/data/sentinel2')
output_dir.mkdir(exist_ok=True, parents=True)

# Tile edge length in metres = pixels * metres-per-pixel.
download_px = config['tile_size_download']
print(f"Locations to download: {len(locations)}")
print(locations['label'].value_counts())
print(f"\nTile size: {download_px}px = {download_px * config['resolution']}m")

## Download all tiles

For each location:
1. Search for least cloudy Sentinel-2 scene
2. Download 6 raw bands
3. Compute derived indices: RGB + NDVI + NDBI + SWIR_ratio
4. Save the **index tile** (what the model sees)
5. Track the **scene_id** for leakage-free splitting

In [None]:
import json

bands = config['bands']
tile_size = config['tile_size_download']  # 128
resolution = config['resolution']

success = 0
failures = []   # (tile_id, reason) pairs
scene_ids = {}  # tile_id -> scene_id mapping


def _download_one(row, out_path):
    """Search, download, validate and save the tile for one location row.

    Returns ``(scene_id, None)`` on success or ``(None, reason)`` when no
    scene is found or the index tile fails validation. Network/IO errors
    propagate to the caller.
    """
    items = search_sentinel2(row['lat'], row['lon'], config, max_items=3)
    if not items:
        return None, 'no scenes found'

    # download_tile returns (raw_tile, index_tile, meta); only the index tile is kept
    _raw_tile, index_tile, meta = download_tile(
        items[0], row['lat'], row['lon'],
        bands=bands, tile_size=tile_size, resolution=resolution,
    )

    # Validate before writing anything to disk
    if index_tile.shape != (6, tile_size, tile_size):
        return None, f'wrong shape {index_tile.shape}'
    if np.isnan(index_tile).any():
        return None, 'contains NaN'

    # Save the INDEX tile (RGB + NDVI + NDBI + SWIR_ratio)
    save_tile(index_tile, out_path, meta)
    return meta.get('scene_id', 'unknown'), None


# enumerate gives a 1-based position independent of the DataFrame's index labels
for pos, (_, row) in enumerate(locations.iterrows(), start=1):
    tile_id = row['tile_id']
    out_path = output_dir / f"{tile_id}.npy"

    if out_path.exists():
        # Already downloaded: recover scene_id from the sidecar metadata if present
        meta_path = out_path.with_suffix('.json')
        if meta_path.exists():
            with open(meta_path) as f:
                scene_ids[tile_id] = json.load(f).get('scene_id', 'unknown')
        success += 1
    else:
        try:
            scene_id, reason = _download_one(row, out_path)
        except Exception as e:  # record and keep going; one bad tile must not stop the run
            scene_id, reason = None, f'{type(e).__name__}: {e}'

        if reason is None:
            scene_ids[tile_id] = scene_id
            success += 1
        else:
            failures.append((tile_id, reason))

        # Throttle after EVERY remote attempt, including searches that found nothing
        time.sleep(0.5)

    # Report progress for every 10th location regardless of outcome
    if pos % 10 == 0:
        print(f"Progress: {pos}/{len(locations)} "
              f"(success={success}, failures={len(failures)})")

print(f"\nDone! Success: {success}, Failures: {len(failures)}")

In [None]:
# Save scene_id mapping for leakage-free splitting.
# Tiles with no scene recorded in this run fall back to any scene_id already
# stored in the CSV by a previous run, and only then to 'unknown'. Such tiles
# are failed downloads, or cached tiles without a .json sidecar.
new_scene_ids = locations['tile_id'].map(scene_ids)
if 'scene_id' in locations.columns:
    new_scene_ids = new_scene_ids.fillna(locations['scene_id'])
locations['scene_id'] = new_scene_ids.fillna('unknown')
locations.to_csv(f'{PROJECT_DIR}/data/labels/all_locations.csv', index=False)

# 'unknown' is a placeholder, not a real scene, so it is excluded from the count
known = locations['scene_id'] != 'unknown'
n_scenes = locations.loc[known, 'scene_id'].nunique()
print(f"Unique Sentinel-2 scenes: {n_scenes}")
print("Scene IDs will be used for leakage-free train/val split")

n_unknown = int((~known).sum())
if n_unknown:
    # A split grouped by scene_id would treat all of these as ONE group;
    # exclude or handle them before splitting.
    print(f"WARNING: {n_unknown} tiles have scene_id 'unknown'")

if failures:
    print(f"\nFailed downloads ({len(failures)}):")
    for tile_id, reason in failures[:10]:
        print(f"  {tile_id}: {reason}")
    if len(failures) > 10:
        print(f"  ... and {len(failures) - 10} more")

## Sanity check: visualize tiles

In [None]:
import matplotlib.pyplot as plt


def _to_display_rgb(tile, pct=98):
    """Return channels 0-2 of a channels-first tile as an HxWx3 image in [0, 1].

    Values are divided by the ``pct``-th percentile (a simple contrast
    stretch) and then clipped. A non-positive or non-finite percentile, e.g.
    from an all-zero tile, falls back to a scale of 1 instead of dividing by
    zero.
    """
    rgb = tile[:3].transpose(1, 2, 0)  # channels 0-2 are already R, G, B
    scale = np.percentile(rgb, pct)
    if not np.isfinite(scale) or scale <= 0:
        scale = 1.0
    return np.clip(rgb / scale, 0, 1)


fig, axes = plt.subplots(2, 4, figsize=(16, 8))

camp_files = sorted(output_dir.glob('camp_*.npy'))[:4]
neg_files = sorted(output_dir.glob('neg_*.npy'))[:4]

for files, ax_row, title in [(camp_files, axes[0], 'Camps'),
                              (neg_files, axes[1], 'Negatives')]:
    for ax in ax_row:
        # Remove ticks but keep the axes on: ax.axis('off') would also hide
        # the row ylabel set below. This also blanks unused slots.
        ax.set_xticks([])
        ax.set_yticks([])
    for f, ax in zip(files, ax_row):
        tile = np.load(f)  # expected (6, 128, 128): R, G, B, NDVI, NDBI, SWIR_ratio
        ax.imshow(_to_display_rgb(tile))
        ax.set_title(f.stem, fontsize=8)
    ax_row[0].set_ylabel(title, fontsize=14)

plt.suptitle('Sample Tiles (RGB from derived channels, 128x128)', fontsize=14)
plt.tight_layout()
plt.show()

# Also show every channel of the first camp tile
if camp_files:
    tile = np.load(camp_files[0])
    channel_names = ['Red', 'Green', 'Blue', 'NDVI', 'NDBI', 'SWIR_ratio']
    fig, axes = plt.subplots(1, len(channel_names), figsize=(18, 3))
    for ch, (ax, name) in enumerate(zip(axes, channel_names)):
        # Diverging colormap for the vegetation index, sequential for the rest
        im = ax.imshow(tile[ch], cmap='RdYlGn' if 'NDV' in name else 'viridis')
        ax.set_title(name)
        ax.axis('off')
        plt.colorbar(im, ax=ax, fraction=0.046)
    plt.suptitle(f'All 6 Channels: {camp_files[0].stem}', fontsize=14)
    plt.tight_layout()
    plt.show()