# Sliding-window Geospatial Inference on Sentinel-2 Imagery

This notebook runs sliding-window inference on **Sentinel-2 satellite imagery** (GeoTIFF format), performs aircraft detection, stitches results back to geographic coordinates, and visualizes them.

**Requirements:** 
- `rasterio`, `geopandas`, `ultralytics`, `torch`, `torchvision`
- The `sentinel2_yolo` package must be installed: `pip install -e ..`

## What is Sentinel-2?
Sentinel-2 is a satellite constellation providing free multispectral imagery in 13 spectral bands:
- **Spatial resolution**: 10m (Blue, Green, Red, NIR), 20m (Red Edge, narrow NIR, SWIR), 60m (Coastal Aerosol, Water Vapour, Cirrus)
- **Temporal resolution**: 5 days at equator (A+B constellation)
- **Band combinations**: 
  - Natural Color: Bands 4, 3, 2 (Red, Green, Blue)
  - False Color (Vegetation): Bands 8, 4, 3 (NIR, Red, Green)
  - SWIR: Bands 12, 11, 4
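
Band *numbers* (B4, B8, B11) are not the same as the 1-based band *indices* that `rasterio` uses to read a GeoTIFF. The two coincide only up to B8, and only when the file stacks bands in ESA order starting at B1. The following is a minimal sketch of the translation. It assumes a 13-band stack ordered B1 to B8, then B8A, then B9 to B12. That order is an assumption, so check your file's band descriptions first (for example `src.descriptions` in rasterio):

```python
# Assumed band order of a full 13-band Sentinel-2 stack (Level-1C style).
# Level-2A products usually omit B10, which shifts B11 and B12 down by one.
S2_BAND_ORDER = ["B1", "B2", "B3", "B4", "B5", "B6", "B7",
                 "B8", "B8A", "B9", "B10", "B11", "B12"]


def raster_indices(band_names, band_order=S2_BAND_ORDER):
    """Translate Sentinel-2 band names into 1-based raster indices for rasterio's read()."""
    lookup = {name: i + 1 for i, name in enumerate(band_order)}
    missing = [b for b in band_names if b not in lookup]
    if missing:
        raise KeyError(f"Bands not present in this stack: {missing}")
    return [lookup[b] for b in band_names]


print(raster_indices(["B4", "B3", "B2"]))    # natural color -> [4, 3, 2]
print(raster_indices(["B12", "B11", "B4"]))  # SWIR composite -> [13, 12, 4]
```

Under this assumed layout, the SWIR combination "12, 11, 4" needs raster indices `[13, 12, 4]`.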

## Workflow
1. **Setup**: Check dependencies and configure Sentinel-2 paths
2. **Preview**: View Sentinel-2 scene with band selection
3. **Inference**: Run sliding-window detection on GeoTIFF
4. **Analysis**: Display detection statistics and metrics
5. **Visualization**: Overlay detections on satellite imagery
6. **Export**: Save results in multiple GIS formats
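
Stitching results back to geographic coordinates relies on the GeoTIFF's affine geotransform. The transform maps a pixel (column, row) to map coordinates as `x = a*col + b*row + c` and `y = d*col + e*row + f`. The sketch below converts a box detected inside one tile into map coordinates. It adds the tile's pixel offset, then transforms the box corners. The function name `pixel_box_to_map` is hypothetical and not part of `sentinel2_yolo`. With rasterio, `tuple(src.transform)[:6]` should give the six coefficients `(a, b, c, d, e, f)`.

```python
def pixel_box_to_map(box_px, tile_offset, transform):
    """Map a tile-local pixel box (x0, y0, x1, y1) to map coordinates (minx, miny, maxx, maxy).

    transform: six affine coefficients (a, b, c, d, e, f) with
    x = a*col + b*row + c and y = d*col + e*row + f.
    """
    a, b, c, d, e, f = transform
    col_off, row_off = tile_offset
    # Shift the box into full-scene pixel coordinates, then transform all four corners
    corners = [(col_off + col, row_off + row)
               for col in (box_px[0], box_px[2])
               for row in (box_px[1], box_px[3])]
    xs = [a * col + b * row + c for col, row in corners]
    ys = [d * col + e * row + f for col, row in corners]
    return min(xs), min(ys), max(xs), max(ys)


# 10 m north-up grid (e is negative: rows increase southwards); tile starting at column 824
utm_10m = (10, 0, 600000, 0, -10, 5000040)
print(pixel_box_to_map((10, 20, 17, 27), (824, 0), utm_10m))
# -> (608340, 4999770, 608410, 4999840)
```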


In [None]:
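# Check and install dependencies if needed
import subprocess
import sys

def check_and_install_deps():
    """Verify required packages are importable; pip-install any missing PyPI packages."""
    # Import name -> pip distribution name
    required_packages = {
        'rasterio': 'rasterio',
        'geopandas': 'geopandas',
        'ultralytics': 'ultralytics',
        'torch': 'torch',
        'torchvision': 'torchvision',
    }
    
    missing = []
    for name, pip_name in required_packages.items():
        try:
            __import__(name)
            print(f"✓ {name}")
        except ImportError:
            missing.append(pip_name)
            print(f"✗ {name} - NOT FOUND")
    
    # sentinel2_yolo comes from this repository (pip install -e ..), not from PyPI
    try:
        __import__('sentinel2_yolo')
        print("✓ sentinel2_yolo")
    except ImportError:
        print("✗ sentinel2_yolo - NOT FOUND: run `pip install -e ..` from this notebook's directory")
    
    if missing:
        print(f"\nInstalling missing packages: {', '.join(missing)}")
        subprocess.check_call([sys.executable, "-m", "pip", "install"] + missing)
        print("✓ Installation complete (restart the kernel if imports still fail)")
    else:
        print("\n✓ All PyPI dependencies already installed")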

check_and_install_deps()
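
`__import__` fully imports each package, which can be slow for large packages such as `torch`. If you only need to know whether a package is installed, `importlib.util.find_spec` locates it without executing it. A minimal alternative check follows; the helper name is illustrative:

```python
import importlib.util


def missing_modules(module_names):
    """Return the top-level module names that cannot be found, without importing them."""
    return [name for name in module_names if importlib.util.find_spec(name) is None]


print(missing_modules(["rasterio", "geopandas", "ultralytics", "torch",
                       "torchvision", "sentinel2_yolo"]))
```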


In [None]:
# Imports
from pathlib import Path
import logging
from typing import Optional
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
import rasterio
from rasterio.plot import show
import geopandas as gpd
import numpy as np
from sentinel2_yolo.inference import sliding_window_inference

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

print("✓ All imports successful")


In [None]:
# Configuration for Sentinel-2 Inference
from pathlib import Path

# Try to auto-detect paths, with fallbacks
def find_latest_weights(search_dir: Path = Path('runs/train')) -> Optional[Path]:
    """Find the most recently trained model"""
    if not search_dir.exists():
        return None
    
    weights_files = list(search_dir.rglob('best.pt'))
    if not weights_files:
        return None
    
    # Return most recently modified
    return max(weights_files, key=lambda p: p.stat().st_mtime)

def find_sentinel2_geotiffs(search_dir: Path = Path('data')) -> list:
    """Find available Sentinel-2 GeoTIFF files"""
    if not search_dir.exists():
        return []
    
    # Look in test directory first, then all subdirectories
    test_tiffs = list((search_dir / 'test').glob('*.tif*')) if (search_dir / 'test').exists() else []
    if test_tiffs:
        return sorted(test_tiffs)
    
    return sorted(list(search_dir.rglob('*.tif')) + list(search_dir.rglob('*.tiff')))

# Auto-detect defaults
detected_weights = find_latest_weights()
detected_geotiffs = find_sentinel2_geotiffs()

# Set defaults
DEFAULT_GEOTIFF = detected_geotiffs[0] if detected_geotiffs else Path('data/test/scene.tif')
DEFAULT_WEIGHTS = detected_weights if detected_weights else Path('runs/train/notebook_enhanced/weights/best.pt')
DEFAULT_OUTPUT = Path('data/detections/sentinel2_inference_results.geojson')

print(f"Auto-detected options:")
print(f"  Sentinel-2 GeoTIFFs found: {len(detected_geotiffs)}")
if detected_geotiffs:
    for g in detected_geotiffs[:3]:
        print(f"    - {g.name}")
    if len(detected_geotiffs) > 3:
        print(f"    ... and {len(detected_geotiffs)-3} more")
print(f"  Latest trained model: {DEFAULT_WEIGHTS.name if DEFAULT_WEIGHTS.exists() else 'NOT FOUND'}")

# === SENTINEL-2 CONFIGURATION ===
# Paths
GEO_TIFF = DEFAULT_GEOTIFF
WEIGHTS = DEFAULT_WEIGHTS
OUT_GEOJSON = DEFAULT_OUTPUT

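# Sentinel-2 band reference (band name: description, resolution):
# B1: Coastal aerosol (60m)      B2: Blue (10m)
# B3: Green (10m)                B4: Red (10m)
# B5, B6, B7: Vegetation Red Edge (20m)
# B8: NIR (10m)                  B8A: Narrow NIR (20m)
# B9: Water vapour (60m)         B10: SWIR-Cirrus (60m)
# B11: SWIR 1 (20m)              B12: SWIR 2 (20m)
#
# NOTE: band names are not raster indices. rasterio reads bands by 1-based index,
# which matches the band number only up to B8, and only if the GeoTIFF stacks bands
# in ESA order starting at B1 (B8A shifts every band after it). Check src.descriptions.

# Common band combinations (Sentinel-2 band names):
# Natural Color: B4, B3, B2 (Red, Green, Blue)
# False Color (Vegetation): B8, B4, B3 (NIR, Red, Green)
# SWIR: B12, B11, B4 (SWIR2, SWIR1, Red)

# BAND_INDICES is used for display AND passed to sliding_window_inference as
# band_indices (presumably selecting the model's input bands), so keep it
# consistent with the bands the model was trained on.
BAND_INDICES = [4, 3, 2]  # Natural color RGB (1-based raster indices)
BAND_COMBO_NAME = "Natural Color (RGB)"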
BAND_COMBO_NAME = "Natural Color (RGB)"

# Inference parameters
TILE_SIZE = 1024          # Tile size in pixels (at 10m resolution = 10.24km)
OVERLAP = 200             # Overlap between tiles in pixels
CONF_THRESHOLD = 0.5      # Confidence threshold (0.0-1.0)
DEVICE = 'cpu'            # 'cpu' or 'cuda' or device number
NMS_IOU = 0.3             # NMS IoU threshold: of two boxes with IoU > 0.3, the lower-confidence one is dropped

# Debugging options (not currently passed to sliding_window_inference)
SAVE_INTERMEDIATE = False # Save tiles for debugging
VERBOSE = True            # Detailed logging

print("\n=== Sentinel-2 Inference Configuration ===")
print(f"GeoTIFF: {GEO_TIFF}")
print(f"Weights: {WEIGHTS}")
print(f"Output: {OUT_GEOJSON}")
print(f"Band Visualization: {BAND_COMBO_NAME} (Bands {BAND_INDICES})")
print(f"Inference parameters:")
print(f"  Tile size: {TILE_SIZE}px (~{TILE_SIZE*10}m at 10m resolution)")
print(f"  Overlap: {OVERLAP}px")
print(f"  Confidence threshold: {CONF_THRESHOLD}")
print(f"  NMS IOU: {NMS_IOU}")
print(f"  Device: {DEVICE}")
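
`sliding_window_inference` handles tiling internally, and its code is not shown here. As a rough illustration of how `TILE_SIZE` and `OVERLAP` interact, a common layout steps by `TILE_SIZE - OVERLAP` pixels and clamps the last tile to the image edge. This is a sketch under that assumption; `tile_origins` is a hypothetical helper:

```python
def tile_origins(length, tile, overlap):
    """Start offsets (px) along one axis for tiles of size `tile` overlapping by `overlap`.

    The last tile is shifted back so it ends exactly at the image edge.
    """
    if not 0 <= overlap < tile:
        raise ValueError("overlap must be in [0, tile)")
    if length <= tile:
        return [0]
    stride = tile - overlap
    starts = list(range(0, length - tile, stride))
    starts.append(length - tile)  # clamp the final tile to the edge
    return starts


# Full Sentinel-2 granule at 10 m: 10980 x 10980 px
starts = tile_origins(10980, 1024, 200)
print(len(starts), starts[:3], starts[-1])  # 14 [0, 824, 1648] 9956
print(len(starts) ** 2)                     # 196 tiles in total
```

With the default settings, this layout gives 14 tile positions per axis on a full granule. Because the last tile is clamped to the edge, it overlaps its neighbour by far more than `OVERLAP`.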


In [None]:
# Validate paths and preview Sentinel-2 GeoTIFF
def validate_paths():
    """Check if required files exist"""
    errors = []
    
    if not GEO_TIFF.exists():
        errors.append(f"GeoTIFF not found: {GEO_TIFF}")
    
    if not WEIGHTS.exists():
        errors.append(f"Model weights not found: {WEIGHTS}")
    
    if errors:
        print("❌ Configuration errors:")
        for err in errors:
            print(f"  - {err}")
        return False
    
    print("✓ All required files found")
    return True

if not validate_paths():
    print("\nPlease update GEO_TIFF and WEIGHTS paths in the previous cell")
else:
    # Display Sentinel-2 GeoTIFF metadata and preview
    with rasterio.open(GEO_TIFF) as src:
        print(f"\n=== Sentinel-2 GeoTIFF Information ===")
        print(f"File: {GEO_TIFF.name}")
        print(f"Size: {src.width} x {src.height} pixels")
        print(f"Bands: {src.count}")
        print(f"Data type: {src.dtypes[0]}")
        print(f"CRS: {src.crs}")
        print(f"Bounds: {src.bounds}")
        
        if hasattr(src, 'indexes'):
            print(f"Band indexes available: {src.indexes}")
        
        # Show preview
        fig, ax = plt.subplots(1, 1, figsize=(10, 10))
        try:
            # Use the configured bands if present; otherwise fall back to bands 1-3, or band 1 alone
            if src.count >= max(BAND_INDICES):
                preview_bands = BAND_INDICES
            elif src.count >= 3:
                preview_bands = [1, 2, 3]
            else:
                preview_bands = [1]
            
            # Decimated read keeps memory bounded on full scenes (longest side <= 2048 px)
            step = int(np.ceil(max(src.width, src.height) / 2048))
            data = src.read(preview_bands,
                            out_shape=(len(preview_bands), src.height // step, src.width // step))
            
            # Stretch each band (2nd-98th percentile of non-zero pixels) to 8-bit for display;
            # Sentinel-2 reflectance is usually stored as 16-bit integers
            rgb_display = np.zeros((data.shape[1], data.shape[2], len(preview_bands)), dtype=np.uint8)
            for i, band in enumerate(data):
                valid = band[band > 0]
                vmin, vmax = np.percentile(valid, [2, 98]) if valid.size else (0, 255)
                span = max(vmax - vmin, 1e-6)  # avoid division by zero on flat bands
                rgb_display[:, :, i] = np.clip((band.astype(float) - vmin) / span * 255, 0, 255).astype(np.uint8)
            
            # A single band is shown as grayscale (imshow cannot draw an (H, W, 1) array)
            if rgb_display.shape[2] == 1:
                ax.imshow(rgb_display[:, :, 0], cmap='gray')
            else:
                ax.imshow(rgb_display)
            
            ax.set_title(f'Sentinel-2 Preview: {BAND_COMBO_NAME}')
            plt.tight_layout()
            plt.show()
            
        except Exception as e:
            print(f"⚠ Preview failed: {e}")


In [None]:
# Run sliding-window inference with error handling
import time

print("=== Running Inference ===")
print(f"Processing: {GEO_TIFF.name}")
print(f"Model: {WEIGHTS.name}")
print(f"Parameters: tile_size={TILE_SIZE}, overlap={OVERLAP}, conf={CONF_THRESHOLD}, device={DEVICE}")

if not validate_paths():
    print("❌ Cannot run inference — paths invalid (see cell above)")
else:
    try:
        # Create output directory
        OUT_GEOJSON.parent.mkdir(parents=True, exist_ok=True)
        
        # Run inference with progress tracking
        start_time = time.time()
        
        print(f"\n⏳ Starting inference...")
        out_path = sliding_window_inference(
            geotiff_path=str(GEO_TIFF),
            weights=str(WEIGHTS),
            tile_size_px=TILE_SIZE,
            overlap=OVERLAP,
            out_geojson=str(OUT_GEOJSON),
            conf_threshold=CONF_THRESHOLD,
            device=DEVICE,
            band_indices=BAND_INDICES,
            nms_iou=NMS_IOU
        )
        
        elapsed = time.time() - start_time
        print(f"\n✓ Inference completed in {elapsed:.1f}s")
        print(f"✓ Output saved: {out_path}")
        
        # Load and display statistics
        gdf = gpd.read_file(out_path)
        print(f"\n=== Detection Statistics ===")
        print(f"Total detections: {len(gdf)}")
        if len(gdf) > 0:
            print(f"Confidence - Mean: {gdf['confidence'].mean():.3f}, Min: {gdf['confidence'].min():.3f}, Max: {gdf['confidence'].max():.3f}")
            print(f"Classes: {gdf['class'].unique().tolist()}")
            print(f"Spatial bounds: {gdf.total_bounds}")
        
    except FileNotFoundError as e:
        print(f"❌ File error: {e}")
    except ValueError as e:
        print(f"❌ Validation error: {e}")
    except Exception as e:
        print(f"❌ Inference failed: {e}")
        import traceback
        traceback.print_exc()
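
Because tiles overlap, an aircraft near a tile border can be detected in two tiles, and `NMS_IOU` controls how such duplicates are merged. How `sliding_window_inference` does this internally is not shown here. The following is therefore a generic sketch of greedy non-maximum suppression (NMS). It assumes boxes are axis-aligned `(minx, miny, maxx, maxy)` in a shared coordinate system, such as the raster's map CRS:

```python
import numpy as np


def box_iou(box, boxes):
    """IoU between one box and an (N, 4) array of boxes, all as (minx, miny, maxx, maxy)."""
    ix0 = np.maximum(box[0], boxes[:, 0])
    iy0 = np.maximum(box[1], boxes[:, 1])
    ix1 = np.minimum(box[2], boxes[:, 2])
    iy1 = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(ix1 - ix0, 0, None) * np.clip(iy1 - iy0, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / np.maximum(area + areas - inter, 1e-12)


def greedy_nms(boxes, scores, iou_threshold):
    """Return indices of boxes kept by greedy NMS, highest score first."""
    boxes = np.asarray(boxes, dtype=float)
    order = np.argsort(scores)[::-1]
    keep = []
    while order.size > 0:
        best = order[0]
        keep.append(int(best))
        rest = order[1:]
        # Drop every remaining box that overlaps the current best box too much
        order = rest[box_iou(boxes[best], boxes[rest]) <= iou_threshold]
    return keep


# Two near-duplicate detections from overlapping tiles plus one separate aircraft
boxes = [[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]]
scores = [0.9, 0.8, 0.7]
print(greedy_nms(boxes, scores, iou_threshold=0.3))  # [0, 2]
```

A lower threshold merges more aggressively. Raising it keeps more overlapping boxes, which can leave duplicates along tile seams.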


In [None]:
# Load detections and create enhanced visualization
print("=== Detection Visualization ===")

if not OUT_GEOJSON.exists():
    print(f"⚠ GeoJSON not found: {OUT_GEOJSON}")
    print("Please run the inference cell above first")
else:
    try:
        # Load detections
        gdf = gpd.read_file(OUT_GEOJSON)
        print(f"Loaded {len(gdf)} detections from {OUT_GEOJSON.name}")
        
        if len(gdf) == 0:
            print("⚠ No detections found - try lowering CONF_THRESHOLD")
        else:
            # Create visualization with confidence coloring
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 9))
            
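            # Left: Base map with detections
            with rasterio.open(GEO_TIFF) as src:
                show(src.read(BAND_INDICES if src.count >= max(BAND_INDICES) else [1, 2, 3]),
                     transform=src.transform, ax=ax1)
                ax1.set_title('GeoTIFF with Detections')
                # Detections may be stored in another CRS (e.g. WGS 84 lon/lat, the GeoJSON default);
                # reproject them to the raster CRS so the overlay lines up
                if gdf.crs is not None and src.crs is not None:
                    gdf = gdf.to_crs(src.crs)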
            
            # Color code by confidence
            norm = Normalize(vmin=gdf['confidence'].min(), vmax=gdf['confidence'].max())
            gdf.plot(ax=ax1, column='confidence', cmap='RdYlGn', 
                    edgecolor='black', linewidth=1.5, alpha=0.7, norm=norm, legend=False)
            
            # Add colorbar
            sm = ScalarMappable(cmap='RdYlGn', norm=norm)
            sm.set_array([])
            cbar1 = plt.colorbar(sm, ax=ax1, label='Confidence Score')
            
            # Right: Detections only
            gdf.plot(ax=ax2, column='confidence', cmap='RdYlGn', 
                    edgecolor='black', linewidth=2, alpha=0.8, norm=norm, legend=True)
            ax2.set_title(f'Detections ({len(gdf)} objects)')
            
            plt.tight_layout()
            plt.show()
            
            # Display detection table
            print(f"\n=== Detection Details ===")
            display_df = gdf[['class', 'confidence', 'class_id']].copy()
            # One formatted bounding box per detection row: (minx, miny, maxx, maxy)
            display_df['bounds'] = gdf.geometry.bounds.apply(
                lambda b: f"({b['minx']:.2f}, {b['miny']:.2f}, {b['maxx']:.2f}, {b['maxy']:.2f})", axis=1
            )
            print(display_df.to_string())
            
            # Summary by class
            if 'class' in gdf.columns:
                print(f"\n=== Detections by Class ===")
                class_summary = gdf.groupby('class').agg({
                    'confidence': ['count', 'mean', 'min', 'max']
                }).round(3)
                print(class_summary)
                
    except Exception as e:
        print(f"❌ Visualization failed: {e}")
        import traceback
        traceback.print_exc()


## Export Results

The detections are already saved as GeoJSON in the inference step. You can export to other formats as needed:


In [None]:
# Export detections to multiple formats
if not OUT_GEOJSON.exists():
    print("⚠ No detections to export - run inference first")
else:
    gdf = gpd.read_file(OUT_GEOJSON)
    print(f"Exporting {len(gdf)} detections...")
    
    # GeoPackage
    gpkg_path = OUT_GEOJSON.parent / 'detections.gpkg'
    try:
        gdf.to_file(gpkg_path, driver='GPKG')
        print(f"✓ GeoPackage: {gpkg_path}")
    except Exception as e:
        print(f"✗ GeoPackage failed: {e}")
    
    # Shapefile (note: the format truncates field names to 10 characters)
    shp_path = OUT_GEOJSON.parent / 'detections.shp'
    try:
        gdf.to_file(shp_path, driver='ESRI Shapefile')
        print(f"✓ Shapefile: {shp_path}")
    except Exception as e:
        print(f"✗ Shapefile failed: {e}")
    
    # CSV (geometry stored as WKT text in a 'wkt' column)
    csv_path = OUT_GEOJSON.parent / 'detections.csv'
    try:
        csv_df = gdf.drop(columns=gdf.geometry.name)
        csv_df['wkt'] = gdf.geometry.to_wkt()
        csv_df.to_csv(csv_path, index=False)
        print(f"✓ CSV: {csv_path}")
    except Exception as e:
        print(f"✗ CSV failed: {e}")
    
    print("\n✓ Exports finished (see messages above for any failures)")


## Troubleshooting & Tips for Sentinel-2 Inference

### Common Issues

| Issue | Cause | Solution |
|-------|-------|----------|
| "GeoTIFF not found" | Wrong path or file missing | Ensure GeoTIFF is in `data/test/` or update GEO_TIFF path |
| "Model weights not found" | Model not trained | Run training notebook first or download pre-trained weights |
| "No detections found" | Confidence threshold too high | Lower CONF_THRESHOLD (e.g., 0.3) for more detections |
| "Band indices out of range" | Wrong bands for this GeoTIFF | Check actual band count and adjust BAND_INDICES (1-indexed) |
| "CUDA out of memory" | Tiles too large for GPU memory | Reduce TILE_SIZE (e.g., 512) or use `DEVICE = 'cpu'` |
| Very slow inference | Large scene or small tiles | Increase TILE_SIZE (1024+) or use GPU (`DEVICE = '0'`) |
| Poor detection quality | Model trained on different data | Retrain model with your Sentinel-2 imagery |
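
The "Band indices out of range" case can be caught before inference with an explicit check. The following is a small sketch; the helper name is hypothetical. In the notebook, it would be called inside `with rasterio.open(GEO_TIFF) as src:` as `check_band_indices(BAND_INDICES, src.count)`:

```python
def check_band_indices(band_indices, band_count):
    """Raise a descriptive ValueError if any 1-based band index is outside 1..band_count."""
    bad = [i for i in band_indices if not 1 <= i <= band_count]
    if bad:
        raise ValueError(
            f"Band indices {bad} are out of range for a {band_count}-band GeoTIFF "
            f"(rasterio band indices are 1-based: valid values are 1..{band_count})"
        )
    return list(band_indices)


print(check_band_indices([4, 3, 2], band_count=13))  # [4, 3, 2]
```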

### Sentinel-2 Band Combinations

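Common band combinations, given as Sentinel-2 band numbers. These match rasterio's 1-based band indices only up to B8, and only if the GeoTIFF stores bands in ESA order starting at B1: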

| Name | Bands | Description | Use Case |
|------|-------|-------------|----------|
| Natural Color | 4, 3, 2 | RGB similar to natural vision | General viewing |
| False Color (Vegetation) | 8, 4, 3 | NIR, Red, Green | Vegetation analysis |
| SWIR | 12, 11, 4 | SWIR2, SWIR1, Red | Cloud detection, built-up areas |
| Agriculture | 11, 8, 2 | SWIR1, NIR, Blue | Agricultural analysis |

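Update `BAND_INDICES` in the configuration cell to use a different combination. The same list is also passed to `sliding_window_inference` as `band_indices`, so changing it changes the model's input bands and not just the display. Keep it consistent with the bands the model was trained on.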

### Performance Tips

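- **Tile Size**: Larger tiles (1024+) mean fewer tiles and less per-tile overhead, but more memory per tile. Smaller tiles (512) mean more tiles and slower runs. If the model resizes each tile to a fixed input size, smaller tiles are also downscaled less, which helps with small objects.
- **Overlap**: 15-20% of tile size is a reasonable default. Make it at least as large as the biggest object (in pixels) so that every object appears whole in at least one tile.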
- **Confidence**: Start at 0.5, lower to 0.3 for more detections, raise to 0.7 for fewer false positives.
- **GPU**: Use `DEVICE = '0'` (or GPU number) for ~5-10x speedup on NVIDIA GPUs.
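- **Resolution**: Avoid downsampling before inference, because at 10m per pixel an aircraft spans only a few pixels. For very large scenes (>10,000x10,000 px), crop to an area of interest instead.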
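
Each feature in the output GeoJSON carries a `confidence` property (the notebook reads it as `gdf['confidence']`). This means raising the threshold does not require re-running inference: you can filter the saved file instead. Lowering the threshold does require a re-run, because detections below the `CONF_THRESHOLD` used at inference time are presumably never written. The following is a stdlib-only sketch that works on a FeatureCollection loaded with `json.load`; the function name is illustrative:

```python
def filter_by_confidence(feature_collection, min_conf):
    """Return a copy of a GeoJSON FeatureCollection keeping features with confidence >= min_conf."""
    kept = [
        feat for feat in feature_collection.get("features", [])
        if (feat.get("properties") or {}).get("confidence", 0.0) >= min_conf
    ]
    return {**feature_collection, "features": kept}


sample = {
    "type": "FeatureCollection",
    "features": [
        {"type": "Feature", "geometry": None, "properties": {"confidence": 0.55}},
        {"type": "Feature", "geometry": None, "properties": {"confidence": 0.82}},
    ],
}
print(len(filter_by_confidence(sample, 0.7)["features"]))  # 1
```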

### Sentinel-2 Specifics

1. **Data Availability**: Download from Copernicus (EODATA) or use cloud platforms (AWS, GCP)
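2. **Radiometric Correction**: Level-1C products contain Top-of-Atmosphere (TOA) reflectance. Level-2A products contain atmospherically corrected Bottom-of-Atmosphere (surface) reflectance. Both come from a 12-bit sensor and are stored as 16-bit integers (reflectance scaled by 10,000).
3. **Cloud Masking**: Use the Scene Classification Layer (SCL, 20m) shipped with Level-2A products. Classes 3 (cloud shadows), 8 and 9 (medium- and high-probability cloud), and 10 (thin cirrus) mark cloud-affected pixels.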
4. **Temporal Analysis**: Compare scenes from different dates for change detection
5. **Multi-spectral Training**: Consider using NIR band (8) for better aircraft/ship distinction from clouds
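
The following is a minimal cloud-mask sketch built from the Level-2A SCL. It assumes the SCL array has already been read (it is delivered at 20m) and that a mask on the 10m grid is wanted. The nearest-neighbour 2x upsampling via `repeat` is a simplification that assumes both grids share the same origin. Which classes count as "cloudy" is a judgment call:

```python
import numpy as np

# SCL classes treated as cloud-affected in this sketch:
# 3 = cloud shadows, 8 = cloud medium probability, 9 = cloud high probability, 10 = thin cirrus
CLOUDY_SCL_CLASSES = (3, 8, 9, 10)


def cloud_mask_10m(scl_20m, classes=CLOUDY_SCL_CLASSES):
    """Boolean cloud mask on the 10 m grid from a 20 m SCL array (nearest-neighbour 2x upsampling)."""
    mask_20m = np.isin(scl_20m, classes)
    return mask_20m.repeat(2, axis=0).repeat(2, axis=1)


scl = np.array([[4, 9],
                [8, 6]], dtype=np.uint8)
mask = cloud_mask_10m(scl)
print(mask.shape, mask.mean())  # (4, 4) 0.5
```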

### Next Steps

1. **Inspect Results**: Examine detections overlaid on satellite imagery
2. **Validate**: Compare with ground truth or manual inspection
3. **Refine Parameters**: Adjust CONF_THRESHOLD, TILE_SIZE, and bands
4. **Batch Processing**: Adapt code to process multiple scenes
5. **Integration**: Use exported GeoJSON/Shapefile in QGIS or ArcGIS
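
For batch processing, one option is to loop over scenes and derive one output file per scene. The sketch below assumes the keyword arguments used in the inference cell (`geotiff_path`, `out_geojson`, etc.). It takes the inference function as a parameter, so it can be exercised without a model. `run_batch` is a hypothetical helper, not part of `sentinel2_yolo`:

```python
from pathlib import Path


def run_batch(scenes, out_dir, infer, **infer_kwargs):
    """Run `infer` on each scene, writing <scene stem>_detections.geojson into out_dir.

    Returns {scene path: output path, or None if that scene failed}.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    results = {}
    for scene in scenes:
        scene = Path(scene)
        out_path = out_dir / f"{scene.stem}_detections.geojson"
        try:
            results[scene] = infer(geotiff_path=str(scene), out_geojson=str(out_path), **infer_kwargs)
        except Exception as exc:  # keep going if one scene fails
            print(f"✗ {scene.name}: {exc}")
            results[scene] = None
    return results
```

In the notebook, this might be called as `run_batch(detected_geotiffs, OUT_GEOJSON.parent, sliding_window_inference, weights=str(WEIGHTS), tile_size_px=TILE_SIZE, overlap=OVERLAP, conf_threshold=CONF_THRESHOLD, device=DEVICE, band_indices=BAND_INDICES, nms_iou=NMS_IOU)`.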
