<a href="https://colab.research.google.com/github/JonasLewe/terramind_object_detection/blob/main/notebooks/SAR_Ship_Detection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive at /content/drive; later cells read the xView3 archives
# and label CSVs from /content/drive/MyDrive/xview3.
from google.colab import drive
drive.mount("/content/drive")


Mounted at /content/drive


In [2]:
!pip install terratorch==1.1.1 gdown tensorboard > /dev/null 2>&1

In [3]:
import os
import torch
import gdown
import terratorch
import albumentations
import rasterio
import json
from collections import Counter
import shutil
import numpy as np
import pandas as pd
import lightning.pytorch as pl
import matplotlib.pyplot as plt
import pyproj
from pathlib import Path
from terratorch.datamodules import GenericNonGeoSegmentationDataModule
import warnings
warnings.filterwarnings("ignore")

Flax classes are deprecated and will be removed in Diffusers v1.0.0. We recommend migrating to PyTorch classes or pinning your version of Diffusers.
Flax classes are deprecated and will be removed in Diffusers v1.0.0. We recommend migrating to PyTorch classes or pinning your version of Diffusers.
  _C._set_float32_matmul_precision(precision)


# Download Dataset

In [None]:
!mkdir -p /content/xview3/raw /content/xview3/chips /content/xview3/runs
!mkdir -p /content/drive/MyDrive/xview3/{chips,runs,meta}

In [None]:
# Copy everything under the Drive raw/ folder to local VM storage
# (-a archive mode, -h human-readable sizes, one overall progress line).
!rsync -ah --info=progress2 /content/drive/MyDrive/xview3/raw/ /content/xview3/raw/

In [None]:
# Unpack every .tar.gz into raw/; the `&&` means an archive is deleted only
# after tar reports success for it.
!for f in /content/xview3/raw/*.tar.gz; do tar -xzvf "$f" -C /content/xview3/raw && rm "$f"; done

In [None]:
# Copy the train/validation label CSVs from Drive to the local meta/ folder,
# then list the folder as a quick sanity check.
!mkdir -p /content/xview3/meta
!cp /content/drive/MyDrive/xview3/meta/train.csv /content/xview3/meta/train.csv
!cp /content/drive/MyDrive/xview3/meta/validation.csv /content/xview3/meta/validation.csv
!ls -lah /content/xview3/meta

# Dataset Exploration

In [None]:
# Local directory layout on the Colab VM; every path hangs off BASE.
BASE = Path("/content/xview3")
RAW  = BASE / "raw"      # extracted archives; list_scenes() looks for scene folders here
META = BASE / "meta"     # train.csv / validation.csv label tables
EXP  = BASE / "explore"  # NOTE(review): no mkdir cell above creates this folder — confirm before writing to it
CHIP = BASE / "chips"
RUNS = BASE / "runs"

In [None]:
def list_scenes(raw_dir: Path):
    """Return the names of the scene folders located directly under ``raw_dir``.

    A sub-directory counts as a scene only when it contains both
    ``VV_dB.tif`` and ``VH_dB.tif``. Names are returned in sorted order.
    """
    required_files = ("VV_dB.tif", "VH_dB.tif")
    return [
        entry.name
        for entry in sorted(raw_dir.iterdir())
        if entry.is_dir() and all((entry / fname).exists() for fname in required_files)
    ]

# Discover which scene folders are actually present on local disk under RAW
# and show the first ten names as a quick check.
scenes = list_scenes(RAW)
print(f"Found {len(scenes)} scenes under {RAW}")
print("First scenes:", scenes[:10])


In [23]:
# Cell: Dataset Overview
# json and Counter are already imported in the notebook's import cell and are
# not used by this cell, so they are not re-imported here.

# Load the label tables
train_df = pd.read_csv(META / "train.csv")
val_df = pd.read_csv(META / "validation.csv")

print(f"{'='*60}")
print("xView3 Dataset Overview")
print(f"{'='*60}\n")

print(f"Training samples: {len(train_df):,}")
print(f"Validation samples: {len(val_df):,}")
print(f"Total labels: {len(train_df) + len(val_df):,}\n")

# Scene distribution
train_scenes = train_df['scene_id'].nunique()
val_scenes = val_df['scene_id'].nunique()
print(f"Training scenes: {train_scenes}")
print(f"Validation scenes: {val_scenes}\n")

# Class distribution; printed only when the label columns are present
if 'is_vessel' in train_df.columns:
    vessel_mask = train_df['is_vessel'] == 1
    n_vessels = train_df['is_vessel'].sum()
    print("Class Distribution (Train):")
    print(f"  Vessels: {n_vessels:,}")
    # Explicit comparison with 0 (rather than ~): a row whose is_vessel value
    # is missing counts toward neither the vessel nor the non-vessel line.
    print(f"  Non-vessels: {(train_df['is_vessel'] == 0).sum():,}")
    if 'is_fishing' in train_df.columns:
        # Single .loc lookup instead of chained indexing
        fishing = train_df.loc[vessel_mask, 'is_fishing'].sum()
        print(f"  Fishing vessels: {fishing:,}")
        print(f"  Non-fishing vessels: {n_vessels - fishing:,}\n")

# Labels per scene
labels_per_scene = train_df.groupby('scene_id').size()
print("Labels per scene (train):")
print(f"  Mean: {labels_per_scene.mean():.1f}")
print(f"  Median: {labels_per_scene.median():.1f}")
print(f"  Min/Max: {labels_per_scene.min()} / {labels_per_scene.max()}")

xView3 Dataset Overview

Training samples: 64,113
Validation samples: 19,224
Total labels: 83,337

Training scenes: 554
Validation scenes: 50

Class Distribution (Train):
  Vessels: 36,375
  Non-vessels: 16,692
  Fishing vessels: 12,510
  Non-fishing vessels: 23,865

Labels per scene (train):
  Mean: 115.7
  Median: 84.5
  Min/Max: 26 / 435


# Visualization Helper Functions

In [25]:
def load_sar_scene(scene_dir: Path, bands=["VV", "VH"]):
    """Read the ``<band>_dB.tif`` rasters of one scene folder.

    Returns a dict that maps every requested band name to the first raster
    band of its file as a float32 array. A ``'meta'`` entry (transform, crs,
    bounds, shape) is added from the file of the first name in ``bands``.
    """
    scene = {}
    for band in bands:
        band_path = scene_dir / f"{band}_dB.tif"
        with rasterio.open(band_path) as src:
            scene[band] = src.read(1).astype("float32")
            if band != bands[0]:
                continue
            # Georeferencing is taken from the first requested band only
            scene['meta'] = dict(
                transform=src.transform,
                crs=src.crs,
                bounds=src.bounds,
                shape=src.shape,
            )
    return scene

def sar_to_rgb(vv, vh, percentile_clip=(2, 98)):
    """
    Build a 3-channel composite from two SAR polarisation arrays.

    Channels, stacked on the last axis: R = VV, G = VH, B = element-wise
    VV / VH quotient. Each channel is stretched between its own lower and
    upper ``percentile_clip`` percentiles and clipped to [0, 1].
    """
    def _stretch(band, bounds):
        # Map [bounds[0], bounds[1]] linearly onto [0, 1]; the epsilon keeps
        # the division defined when both percentiles coincide.
        low, high = bounds[0], bounds[1]
        return np.clip((band - low) / (high - low + 1e-6), 0, 1)

    red = _stretch(vv, np.nanpercentile(vv, percentile_clip))
    green = _stretch(vh, np.nanpercentile(vh, percentile_clip))

    # Quotient channel; pixels where vh is exactly 0 are set to 0
    quotient = np.where(vh != 0, vv / (vh + 1e-6), 0)
    finite_values = quotient[np.isfinite(quotient)]
    blue = _stretch(quotient, np.nanpercentile(finite_values, percentile_clip))

    return np.stack([red, green, blue], axis=-1)

def latlon_to_pixel(lats, lons, transform, crs):
    """Convert latitude/longitude pairs (EPSG:4326) to integer raster indices.

    Args:
        lats: iterable of latitudes.
        lons: iterable of longitudes, same length as ``lats``.
        transform: the raster's pixel-to-world transform; it must support
            ``~transform`` and ``inverse * (x, y)`` (e.g. rasterio's
            ``src.transform``).
        crs: coordinate reference system of the raster; the points are
            reprojected from EPSG:4326 into it.

    Returns:
        ``(rows, cols)`` as integer numpy arrays. Fractional indices are
        floored, so a point just outside the top/left edge gets index -1
        instead of being truncated onto pixel 0. Indices are not clipped to
        the raster extent.

    Raises:
        ValueError: if any projected point has a non-finite pixel coordinate
            (for example a NaN latitude or longitude).
    """
    transformer = pyproj.Transformer.from_crs("EPSG:4326", crs, always_xy=True)
    xs, ys = transformer.transform(lons, lats)

    # World -> pixel is the inverse of the raster's pixel -> world transform
    inv_transform = ~transform
    pixels = [inv_transform * (x, y) for x, y in zip(xs, ys)]
    cols_frac = np.array([p[0] for p in pixels], dtype=float)
    rows_frac = np.array([p[1] for p in pixels], dtype=float)

    # Fail loudly rather than casting NaN/inf to an arbitrary integer
    if not (np.isfinite(cols_frac).all() and np.isfinite(rows_frac).all()):
        raise ValueError("latlon_to_pixel: non-finite pixel coordinate (check for NaN lat/lon)")

    # floor, not truncation: int() rounds toward zero and would map
    # fractional indices in (-1, 0) onto pixel 0
    rows = np.floor(rows_frac).astype(int)
    cols = np.floor(cols_frac).astype(int)

    return rows, cols

def plot_sar_with_labels(scene_id, df, raw_dir=RAW, window_size=1024, max_labels=None):
    """
    Show one SAR scene as three panels (VV, VH, composite) with label markers.

    Labels are the rows of ``df`` whose ``scene_id`` matches; when
    ``max_labels`` is truthy only the first ``max_labels`` rows are drawn.
    ``window_size`` is accepted but not used by this function.
    Summary statistics of the scene are printed after the figure.
    """
    labels = df[df['scene_id'] == scene_id].copy()
    if max_labels:
        labels = labels.head(max_labels)

    # Read both polarisations and the georeferencing of the scene
    scene = load_sar_scene(raw_dir / scene_id)
    vv = scene['VV']
    vh = scene['VH']
    meta = scene['meta']

    # Project label coordinates into the raster's pixel grid
    rows, cols = latlon_to_pixel(
        labels['detect_lat'].values,
        labels['detect_lon'].values,
        meta['transform'],
        meta['crs'],
    )

    rgb = sar_to_rgb(vv, vh)

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # The two grayscale panels share everything except data and title
    gray_panels = (
        (vv, f'VV (dB) - {len(labels)} labels'),
        (vh, 'VH (dB)'),
    )
    for ax, (band, title) in zip(axes, gray_panels):
        ax.imshow(band, cmap='gray', vmin=np.nanpercentile(band, 2), vmax=np.nanpercentile(band, 98))
        ax.scatter(cols, rows, c='red', s=30, marker='x', alpha=0.7, linewidths=2)
        ax.set_title(title)
        ax.axis('off')

    # Composite panel uses cyan markers
    composite_ax = axes[2]
    composite_ax.imshow(rgb)
    composite_ax.scatter(cols, rows, c='cyan', s=30, marker='x', alpha=0.8, linewidths=2)
    composite_ax.set_title('RGB Composite (VV/VH/Ratio)')
    composite_ax.axis('off')

    plt.suptitle(f'Scene: {scene_id} | Shape: {meta["shape"]}', fontsize=14, y=0.98)
    plt.tight_layout()
    plt.show()

    # Text summary below the figure
    summary_lines = (
        f"\nScene Stats:",
        f"  VV range: [{np.nanmin(vv):.2f}, {np.nanmax(vv):.2f}] dB",
        f"  VH range: [{np.nanmin(vh):.2f}, {np.nanmax(vh):.2f}] dB",
        f"  Labels in scene: {len(labels)}",
        f"  Image shape: {meta['shape']}",
    )
    for line in summary_lines:
        print(line)

In [None]:
# Cell: Detailed Scene Exploration

# Scene pinned for this walkthrough. Set to None to fall back to the first
# scene that has between 10 and 100 labels in the training table.
PINNED_SCENE = "05bc615a9b0e1159t"

scene_counts = train_df['scene_id'].value_counts()
if PINNED_SCENE is not None:
    selected_scene = PINNED_SCENE
else:
    # Pick a scene with a reasonable number of labels
    selected_scene = scene_counts[scene_counts.between(10, 100)].index[0]

print(f"Analyzing scene: {selected_scene}")
plot_sar_with_labels(selected_scene, train_df, max_labels=50)

Analyzing scene: 05bc615a9b0e1159t


In [None]:
# Cell: Multi-Scene Gallery

# Label counts per scene, restricted to scenes whose folder exists under RAW:
# the label table can list scenes that were never copied to local disk, and
# plotting one of those would fail when the rasters are opened.
scene_counts = train_df['scene_id'].value_counts()
on_disk_mask = np.array([(RAW / sid).exists() for sid in scene_counts.index], dtype=bool)
available_counts = scene_counts.loc[on_disk_mask]

# Select diverse scenes (different label counts), up to two per bucket
low_count = available_counts[available_counts < 10].index[:2]
mid_count = available_counts[available_counts.between(10, 50)].index[:2]
high_count = available_counts[available_counts > 50].index[:2]

gallery_scenes = list(low_count) + list(mid_count) + list(high_count)
if not gallery_scenes:
    print(f"No labelled scenes found under {RAW}; nothing to plot.")

for scene_id in gallery_scenes[:6]:  # Show first 6
    print(f"\n{'='*80}")
    plot_sar_with_labels(scene_id, train_df, max_labels=30)

In [None]:
# Cell: Zoomed-in Label Chips

def extract_chip(vv, vh, row, col, chip_size=128):
    """Cut a window of up to ``chip_size`` pixels around ``(row, col)``.

    The window spans ``chip_size // 2`` pixels on each side of the centre and
    is clipped to the bounds of ``vv``; ``vh`` is sliced with the same
    indices. Near an edge the chip is therefore smaller and the centre pixel
    is not in the middle of it — use the returned offset to locate it.

    Args:
        vv, vh: 2-D arrays indexed as ``[row, col]``.
        row, col: integer centre position; may lie outside the array.
        chip_size: nominal window size in pixels.

    Returns:
        ``(vv_chip, vh_chip, (r_start, c_start))``. The chips are empty when
        the window does not overlap the array at all.
    """
    half = chip_size // 2
    r_start = max(0, row - half)
    c_start = max(0, col - half)
    # Clamp the end indices at 0 as well: a negative end would be read by the
    # slice as "count from the far edge" and return a large, unrelated window.
    r_end = max(0, min(vv.shape[0], row + half))
    c_end = max(0, min(vv.shape[1], col + half))

    vv_chip = vv[r_start:r_end, c_start:c_end]
    vh_chip = vh[r_start:r_end, c_start:c_end]

    return vv_chip, vh_chip, (r_start, c_start)

def plot_label_chips(scene_id, df, n_chips=9, chip_size=128, random_state=None):
    """Plot a grid of zoomed-in composite chips around randomly sampled labels.

    Args:
        scene_id: scene whose labels and rasters are used.
        df: label table with ``scene_id``, ``detect_lat`` and ``detect_lon``.
        n_chips: maximum number of labels to sample.
        chip_size: nominal chip size in pixels, passed to ``extract_chip``.
        random_state: forwarded to ``DataFrame.sample``; pass an int for a
            reproducible selection. The default ``None`` draws a fresh sample
            on every call.

    Prints a message and returns without plotting when there is nothing to
    sample.
    """
    in_scene = df[df['scene_id'] == scene_id]
    n_pick = min(n_chips, len(in_scene))
    if n_pick <= 0:
        # plt.subplots cannot build a grid with zero rows
        print(f"No labels to plot for scene {scene_id}.")
        return
    scene_labels = in_scene.sample(n_pick, random_state=random_state)

    scene_dir = RAW / scene_id
    sar_data = load_sar_scene(scene_dir)
    vv, vh = sar_data['VV'], sar_data['VH']
    meta = sar_data['meta']

    rows, cols = latlon_to_pixel(
        scene_labels['detect_lat'].values,
        scene_labels['detect_lon'].values,
        meta['transform'],
        meta['crs']
    )

    n_cols = 3
    n_rows = (len(scene_labels) + n_cols - 1) // n_cols
    # squeeze=False always yields a 2-D array of axes, so flatten() works for
    # a single row too (wrapping the 1-D array in a list made axes[0] an array).
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 4*n_rows), squeeze=False)
    axes = axes.flatten()

    for idx, (row, col) in enumerate(zip(rows, cols)):
        ax = axes[idx]
        vv_chip, vh_chip, (r_start, c_start) = extract_chip(vv, vh, row, col, chip_size)

        if vv_chip.size == 0:
            # The label falls outside the raster: leave the panel blank
            ax.set_title(f'Label {idx+1}', fontsize=10)
            ax.axis('off')
            continue

        rgb_chip = sar_to_rgb(vv_chip, vh_chip)

        # Label position inside the chip. At an image edge the chip is
        # clipped, so the label is not at chip_size // 2; use the chip offset.
        marker_row = row - r_start
        marker_col = col - c_start

        ax.imshow(rgb_chip)
        ax.scatter([marker_col], [marker_row], c='red', s=100, marker='+', linewidths=3)
        ax.set_title(f'Label {idx+1}', fontsize=10)
        ax.axis('off')

    # Hide empty subplots
    for idx in range(len(scene_labels), len(axes)):
        axes[idx].axis('off')

    plt.suptitle(f'Label Chips from {scene_id}', fontsize=14)
    plt.tight_layout()
    plt.show()

# Run for the scene held in `selected_scene`: sample up to 9 labels and
# request 192-pixel windows around each of them.
plot_label_chips(selected_scene, train_df, n_chips=9, chip_size=192)

In [None]:
# Cell: Data Quality Analysis

def analyze_data_quality(df, scene_sample=10):
    """Print basic quality checks for a label table and a sample of its scenes.

    Reports missing values per column, the latitude/longitude extent, and —
    for the ``scene_sample`` scenes with the most labels that exist under
    RAW — the average min/max of each band and the average NaN pixel count.

    Args:
        df: label table with ``scene_id``, ``detect_lat`` and ``detect_lon``.
        scene_sample: number of most-labelled scenes to inspect on disk.

    Returns:
        None; all results are printed.
    """
    print(f"{'='*60}")
    print("Data Quality Checks")
    print(f"{'='*60}\n")

    # 1. Missing values
    print("Missing Values:")
    print(df.isnull().sum())
    print()

    # 2. Coordinate ranges
    print("Coordinate Ranges:")
    print(f"  Latitude: [{df['detect_lat'].min():.4f}, {df['detect_lat'].max():.4f}]")
    print(f"  Longitude: [{df['detect_lon'].min():.4f}, {df['detect_lon'].max():.4f}]")
    print()

    # 3. Sample scenes for image quality (value_counts sorts by label count)
    sample_scenes = df['scene_id'].value_counts().head(scene_sample).index

    vv_ranges, vh_ranges = [], []
    nan_counts = []

    for scene_id in sample_scenes:
        scene_dir = RAW / scene_id
        if not scene_dir.exists():
            # Scene is listed in the table but its folder is not on disk
            continue

        sar_data = load_sar_scene(scene_dir)
        vv, vh = sar_data['VV'], sar_data['VH']

        vv_ranges.append((np.nanmin(vv), np.nanmax(vv)))
        vh_ranges.append((np.nanmin(vh), np.nanmax(vh)))
        nan_counts.append((np.isnan(vv).sum(), np.isnan(vh).sum()))

    if not vv_ranges:
        # zip(*[]) cannot be unpacked into (mins, maxs); report and stop
        print(f"SAR Value Ranges: none of the {len(sample_scenes)} sampled scenes "
              "exist on disk; skipping image statistics.")
        return

    print(f"SAR Value Ranges (sampled {len(vv_ranges)} scenes):")
    vv_mins, vv_maxs = zip(*vv_ranges)
    vh_mins, vh_maxs = zip(*vh_ranges)
    print(f"  VV: [{np.mean(vv_mins):.2f}, {np.mean(vv_maxs):.2f}] dB (avg)")
    print(f"  VH: [{np.mean(vh_mins):.2f}, {np.mean(vh_maxs):.2f}] dB (avg)")
    print(f"  NaN pixels (VV): {np.mean([n[0] for n in nan_counts]):.0f} avg")
    print(f"  NaN pixels (VH): {np.mean([n[1] for n in nan_counts]):.0f} avg")

# Run the checks on the training labels; image statistics are taken from the
# 10 scenes with the most labels that are present on disk.
analyze_data_quality(train_df, scene_sample=10)