<a href="https://colab.research.google.com/github/JonasLewe/terramind_object_detection/blob/main/notebooks/SAR_Ship_Detection_optimized.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive; the raw archives and label CSVs are
# copied from /content/drive/MyDrive/xview3 in the cells below.
from google.colab import drive
drive.mount("/content/drive")


In [None]:
# Install pinned TerraTorch plus gdown/tensorboard.
# NOTE(review): stdout AND stderr are discarded, so a failed install is silent
# here and will presumably only surface as an ImportError in the imports cell.
!pip install terratorch==1.1.1 gdown tensorboard > /dev/null 2>&1

In [None]:
import os
import gc
import torch
import gdown
import terratorch
import albumentations
import rasterio
import rasterio.windows
import rasterio.enums
import json
from collections import Counter
import shutil
import numpy as np
import pandas as pd
import lightning.pytorch as pl
import matplotlib.pyplot as plt
import pyproj
from pathlib import Path
from terratorch.datamodules import GenericNonGeoSegmentationDataModule
import warnings
# Suppress ALL Python warnings for the rest of the session. Keeps output clean,
# but also hides e.g. NumPy RuntimeWarnings from nan-reductions on all-NaN data.
warnings.filterwarnings("ignore")

# Copy and Extract Dataset from Google Drive

In [None]:
# Local working directories under /content/xview3 (raw scenes, generated chips, training runs).
!mkdir -p /content/xview3/raw /content/xview3/chips /content/xview3/runs
# Matching directories on the mounted Drive.
# NOTE(review): {chips,runs,meta} relies on bash brace expansion.
!mkdir -p /content/drive/MyDrive/xview3/{chips,runs,meta}

In [None]:
# Copy the raw scene archives from Drive to local disk. The trailing slashes copy
# the *contents* of raw/ (not the directory itself); --info=progress2 shows overall progress.
!rsync -ah --info=progress2 /content/drive/MyDrive/xview3/raw/ /content/xview3/raw/

In [None]:
# Extract every local archive into raw/ and delete the local .tar.gz only if tar
# succeeded (&&); the Drive copy is untouched. -v lists each extracted file,
# which can produce very long cell output.
!for f in /content/xview3/raw/*.tar.gz; do tar -xzvf "$f" -C /content/xview3/raw && rm "$f"; done

In [None]:
# Copy the train/validation label CSVs from Drive into the local meta/ folder.
!mkdir -p /content/xview3/meta
!cp /content/drive/MyDrive/xview3/meta/train.csv /content/xview3/meta/train.csv
!cp /content/drive/MyDrive/xview3/meta/validation.csv /content/xview3/meta/validation.csv
!ls -lah /content/xview3/meta

# Dataset Exploration

In [None]:
# Root of the local working copy; all dataset folders are sub-directories of it.
BASE = Path("/content/xview3")
RAW, META, EXP, CHIP, RUNS = (
    BASE / sub for sub in ("raw", "meta", "explore", "chips", "runs")
)

# Overview plots read rasters at 1/DOWNSAMPLE_FACTOR resolution to limit RAM use.
DOWNSAMPLE_FACTOR = 8

In [None]:
def list_scenes(raw_dir: Path):
    """Return the sorted folder names under *raw_dir* that hold both polarizations.

    A sub-directory qualifies when it contains both ``VV_dB.tif`` and
    ``VH_dB.tif``. If *raw_dir* itself is missing, a warning is printed and an
    empty list is returned.
    """
    if raw_dir.exists():
        required = ("VV_dB.tif", "VH_dB.tif")
        return [
            entry.name
            for entry in sorted(raw_dir.iterdir())
            if entry.is_dir() and all((entry / fname).exists() for fname in required)
        ]
    print(f"Warning: {raw_dir} does not exist")
    return []

# Scene IDs whose VV/VH rasters exist locally; later cells intersect the label
# tables with set(scenes) so that only locally available scenes are plotted.
scenes = list_scenes(RAW)
print(f"Found {len(scenes)} scenes under {RAW}")
print("First scenes:", scenes[:10])


In [None]:
# Cell: Dataset Overview

# Load label tables (one row per label)
train_df = pd.read_csv(META / "train.csv")
val_df = pd.read_csv(META / "validation.csv")

print(f"{'='*60}")
print("xView3 Dataset Overview")
print(f"{'='*60}\n")

print(f"Training samples: {len(train_df):,}")
print(f"Validation samples: {len(val_df):,}")
print(f"Total labels: {len(train_df) + len(val_df):,}\n")

# Scene distribution
train_scenes = train_df['scene_id'].nunique()
val_scenes = val_df['scene_id'].nunique()
print(f"Training scenes: {train_scenes}")
print(f"Validation scenes: {val_scenes}\n")


def _flag_counts(series):
    """Return (n_true, n_false, n_missing) for a 1/0 or True/False column that may hold NaN.

    Counting ``== 1`` and ``== 0`` separately keeps missing values out of both
    buckets instead of folding them silently into one of them.
    """
    return int((series == 1).sum()), int((series == 0).sum()), int(series.isna().sum())


# Class distribution (only if the optional 'is_vessel' / 'is_fishing' columns exist)
if 'is_vessel' in train_df.columns:
    n_vessel, n_non_vessel, n_vessel_unknown = _flag_counts(train_df['is_vessel'])
    print("Class Distribution (Train):")
    print(f"  Vessels: {n_vessel:,}")
    print(f"  Non-vessels: {n_non_vessel:,}")
    print(f"  Unknown vessel status: {n_vessel_unknown:,}")
    if 'is_fishing' in train_df.columns:
        # Fishing status is only counted for confirmed vessels; unknowns are
        # reported on their own line rather than lumped into "non-fishing".
        vessel_rows = train_df[train_df['is_vessel'] == 1]
        n_fishing, n_non_fishing, n_fishing_unknown = _flag_counts(vessel_rows['is_fishing'])
        print(f"  Fishing vessels: {n_fishing:,}")
        print(f"  Non-fishing vessels: {n_non_fishing:,}")
        print(f"  Vessels with unknown fishing status: {n_fishing_unknown:,}")
    print()

# Labels per scene
labels_per_scene = train_df.groupby('scene_id').size()
print("Labels per scene (train):")
print(f"  Mean: {labels_per_scene.mean():.1f}")
print(f"  Median: {labels_per_scene.median():.1f}")
print(f"  Min/Max: {labels_per_scene.min()} / {labels_per_scene.max()}")

# Visualization Helper Functions

In [None]:
def load_sar_scene(scene_dir: Path, bands=("VV", "VH"),
                   downsample=1, window=None, mask_nodata=True):
    """
    Load SAR bands (``<band>_dB.tif``) from a scene directory.

    Exactly one read mode is used, in this priority order:
      1. ``window`` given   -> read only that window at full resolution (chips).
      2. ``downsample > 1`` -> read the whole raster averaged down to
         ``1/downsample`` resolution (overview plots).
      3. otherwise          -> read the whole raster at full resolution
         (WARNING: high RAM usage!).

    Args:
        scene_dir: Path to scene folder.
        bands: Band name prefixes to load; each maps to ``<band>_dB.tif``.
            Any sequence works (a tuple default avoids a shared mutable default).
        downsample: Integer reduction factor for mode 2 (e.g. 8 = 1/8 resolution).
        window: rasterio.windows.Window for chip extraction (overrides downsample).
        mask_nodata: If True (default), pixels that rasterio's masked read flags
            as invalid (the raster's nodata value / mask) are returned as NaN, so
            the NaN-aware statistics used by the plotting and QA helpers skip
            them. Set False to get the raw stored values.

    Returns:
        dict mapping each band name to a float32 array, plus a ``'meta'`` entry
        taken from the first band: ``transform`` (of the returned array),
        ``crs``, ``bounds`` and ``shape`` (of the full raster), ``loaded_shape``
        and, in downsample mode, ``downsample``.
    """
    data = {}
    for band in bands:
        path = scene_dir / f"{band}_dB.tif"
        with rasterio.open(path) as src:
            extra_meta = {}
            if window is not None:
                # Read specific window only (for chips)
                read_kwargs = {"window": window}
                transform = src.window_transform(window)
            elif downsample > 1:
                # Whole raster averaged down for overview plots. max(1, ...) keeps
                # rasters smaller than the factor from producing a 0-sized output
                # (and a division by zero in the transform below).
                h = max(1, src.height // downsample)
                w = max(1, src.width // downsample)
                read_kwargs = {
                    "out_shape": (h, w),
                    "resampling": rasterio.enums.Resampling.average,
                }
                # Enlarge the pixel size so the transform matches the reduced grid
                transform = src.transform * src.transform.scale(
                    src.width / w,
                    src.height / h
                )
                extra_meta["downsample"] = downsample
            else:
                # Full resolution (WARNING: high RAM usage!)
                read_kwargs = {}
                transform = src.transform

            arr = src.read(1, masked=mask_nodata, **read_kwargs).astype("float32")
            if mask_nodata:
                # Invalid (nodata / masked) pixels -> NaN
                arr = arr.filled(np.nan)
            data[band] = arr

            if band == bands[0]:
                data['meta'] = {
                    'transform': transform,
                    'crs': src.crs,
                    'bounds': src.bounds,
                    'shape': src.shape,
                    'loaded_shape': arr.shape,
                    **extra_meta,
                }
    return data

def sar_to_rgb(vv, vh, percentile_clip=(2, 98), inputs_in_db=True):
    """
    Create a false-colour RGB composite from dual-pol SAR data.

    R = VV, G = VH, B = VV/VH cross-pol ratio. Each channel gets its own
    percentile stretch (computed over finite values only) and is clipped to [0, 1].

    Args:
        vv, vh: Arrays of the same shape (NaN allowed).
        percentile_clip: (low, high) percentiles used for the contrast stretch.
        inputs_in_db: True (default) when vv/vh are in decibels, like the
            ``*_dB.tif`` bands in this dataset. In dB the VV/VH power ratio is
            the difference ``VV - VH``; dividing two dB values is not
            meaningful and explodes where VH is close to 0 dB. Pass False for
            linear-power inputs to use the quotient ``vv / vh`` instead.

    Returns:
        Float array of shape ``vv.shape + (3,)`` with values in [0, 1].
        Channel values that are NaN (e.g. from NaN input pixels) are set to 0.
    """
    def stretch(channel):
        # Percentile stretch on finite values only, then clip to [0, 1]
        finite = channel[np.isfinite(channel)]
        if finite.size == 0:
            return np.zeros(channel.shape, dtype="float32")
        lo, hi = np.percentile(finite, percentile_clip)
        return np.clip((channel - lo) / (hi - lo + 1e-6), 0, 1)

    vv = np.asarray(vv)
    vh = np.asarray(vh)

    # Cross-pol ratio (indicator for surface roughness)
    if inputs_in_db:
        ratio = vv - vh
    else:
        with np.errstate(divide="ignore", invalid="ignore"):
            ratio = np.where(vh != 0, vv / (vh + 1e-6), 0)

    rgb = np.stack([stretch(vv), stretch(vh), stretch(ratio)], axis=-1)
    # NaN survives np.clip; map it to 0 so imshow gets valid RGB values
    return np.nan_to_num(rgb, nan=0.0)

def latlon_to_pixel(lats, lons, transform, crs, downsample=1):
    """
    Convert WGS84 (EPSG:4326) lat/lon to integer raster pixel indices.

    Args:
        lats, lons: Latitudes / longitudes in degrees (scalars or array-likes).
        transform: Affine geotransform of the full-resolution raster (pixel -> CRS).
        crs: CRS of the raster; points are projected into it first.
        downsample: If image was downsampled, adjust pixel coords accordingly
            (indices are integer-divided into the reduced grid).

    Returns:
        (rows, cols) as integer numpy arrays. Indices are floored, so a point
        left of / above the raster gets a negative index rather than being
        truncated onto pixel 0. Points outside the raster are NOT filtered;
        callers must bounds-check.

    Raises:
        ValueError: if any coordinate is NaN or cannot be projected.
    """
    transformer = pyproj.Transformer.from_crs("EPSG:4326", crs, always_xy=True)
    # always_xy=True -> argument/return order is (lon, lat) -> (x, y)
    xs, ys = transformer.transform(np.asarray(lons, dtype=float), np.asarray(lats, dtype=float))
    xs = np.asarray(xs, dtype=float)
    ys = np.asarray(ys, dtype=float)

    # Apply the inverse affine (CRS -> pixel) to all points at once
    inv = ~transform
    col_f = inv.a * xs + inv.b * ys + inv.c
    row_f = inv.d * xs + inv.e * ys + inv.f
    if not (np.isfinite(col_f).all() and np.isfinite(row_f).all()):
        raise ValueError("latlon_to_pixel: non-finite coordinate (NaN lat/lon or failed projection)")

    # floor, not truncation toward zero: -0.5 belongs to pixel -1, not pixel 0
    cols = np.floor(col_f).astype(int)
    rows = np.floor(row_f).astype(int)

    # Adjust for downsampling
    if downsample > 1:
        cols = cols // downsample
        rows = rows // downsample

    return rows, cols

def plot_sar_with_labels(scene_id, df, raw_dir=RAW, max_labels=None, downsample=DOWNSAMPLE_FACTOR):
    """
    Show VV, VH and an RGB composite of one scene side by side, with labels marked.

    Rasters are read at 1/``downsample`` resolution to keep memory low. Label
    positions are projected with the full-resolution geotransform and then
    scaled into the reduced grid by ``latlon_to_pixel``.

    Args:
        scene_id: Scene folder under ``raw_dir``; also matched against ``df['scene_id']``.
        df: Label table with ``scene_id``, ``detect_lat`` and ``detect_lon`` columns.
        raw_dir: Directory containing the scene folders.
        max_labels: If truthy, only the first ``max_labels`` labels of the scene are drawn.
        downsample: Read-time reduction factor passed to ``load_sar_scene``.
    """
    scene_path = raw_dir / scene_id
    labels = df.loc[df['scene_id'] == scene_id].copy()
    if max_labels:
        labels = labels.head(max_labels)

    loaded = load_sar_scene(scene_path, downsample=downsample)
    vv = loaded['VV']
    vh = loaded['VH']
    meta = loaded['meta']

    # Project with the *full-resolution* transform; the downsample factor is applied afterwards.
    with rasterio.open(scene_path / "VV_dB.tif") as vv_src:
        full_transform, full_crs = vv_src.transform, vv_src.crs

    rows, cols = latlon_to_pixel(
        labels['detect_lat'].values,
        labels['detect_lon'].values,
        full_transform,
        full_crs,
        downsample=downsample,
    )

    rgb = sar_to_rgb(vv, vh)

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    # One spec per panel: (image, imshow kwargs, marker colour, marker alpha, title)
    panels = (
        (vv,
         dict(cmap='gray', vmin=np.nanpercentile(vv, 2), vmax=np.nanpercentile(vv, 98)),
         'red', 0.7, f'VV (dB) - {len(labels)} labels'),
        (vh,
         dict(cmap='gray', vmin=np.nanpercentile(vh, 2), vmax=np.nanpercentile(vh, 98)),
         'red', 0.7, 'VH (dB)'),
        (rgb, dict(), 'cyan', 0.8, 'RGB Composite (VV/VH/Ratio)'),
    )
    for ax, (image, show_kwargs, colour, alpha, title) in zip(axes, panels):
        ax.imshow(image, **show_kwargs)
        ax.scatter(cols, rows, c=colour, s=30, marker='x', alpha=alpha, linewidths=2)
        ax.set_title(title)
        ax.axis('off')

    suffix = "" if downsample <= 1 else f" (1/{downsample} res)"
    plt.suptitle(f'Scene: {scene_id} | Original: {meta["shape"]}{suffix}', fontsize=14, y=0.98)
    plt.tight_layout()
    plt.show()

    summary = [
        "\nScene Stats:",
        f"  VV range: [{np.nanmin(vv):.2f}, {np.nanmax(vv):.2f}] dB",
        f"  VH range: [{np.nanmin(vh):.2f}, {np.nanmax(vh):.2f}] dB",
        f"  Labels in scene: {len(labels)}",
        f"  Loaded shape: {meta['loaded_shape']}",
    ]
    print("\n".join(summary))

    # Drop references to the large arrays before collecting
    del loaded, vv, vh, rgb
    gc.collect()

In [None]:
# Cell: Detailed Scene Exploration

# Pick a scene with a reasonable number of labels (automatic selection).
# value_counts() sorts by count in descending order, so index[0] below is the
# most-labelled scene within the chosen range.
scene_counts = train_df['scene_id'].value_counts()

# Filter to scenes that actually exist in RAW folder
available_scenes = set(scenes)
valid_scene_counts = scene_counts[scene_counts.index.isin(available_scenes)]
if valid_scene_counts.empty:
    raise RuntimeError(
        f"None of the {len(scene_counts)} labelled training scenes were found under {RAW}; "
        "check that the archives were copied and extracted."
    )

# Prefer a scene with 10-100 labels; otherwise fall back to the most-labelled available scene
candidates = valid_scene_counts[valid_scene_counts.between(10, 100)]
selected_scene = candidates.index[0] if not candidates.empty else valid_scene_counts.index[0]

print(f"Analyzing scene: {selected_scene} ({valid_scene_counts[selected_scene]} labels)")
plot_sar_with_labels(selected_scene, train_df, max_labels=50)

In [None]:
# Cell: Multi-Scene Gallery

# Restrict the label counts (computed in the previous cell) to locally available scenes
available_scenes = set(scenes)
valid_scene_counts = scene_counts[scene_counts.index.isin(available_scenes)]

# Up to two scenes per label-density bucket: sparse (<10), medium (10-50), dense (>50)
low_count, mid_count, high_count = (
    valid_scene_counts[mask].index[:2]
    for mask in (
        valid_scene_counts < 10,
        valid_scene_counts.between(10, 50),
        valid_scene_counts > 50,
    )
)
gallery_scenes = [*low_count, *mid_count, *high_count]

for scene_id in gallery_scenes[:6]:  # at most 6 panels
    print(f"\n{'='*80}")
    plot_sar_with_labels(scene_id, train_df, max_labels=30)

In [None]:
# Cell: Zoomed-in Label Chips (Memory-efficient with windowed reads)

def plot_label_chips(scene_id, df, n_chips=9, chip_size=128, random_state=None):
    """
    Plot a grid of zoomed-in chips centred on randomly sampled labels of one scene.

    Only the chip windows are read from disk (windowed rasterio reads), so the
    full scene is never loaded into memory.

    Args:
        scene_id: Scene folder under ``RAW``; also matched against ``df['scene_id']``.
        df: Label table with ``scene_id``, ``detect_lat`` and ``detect_lon`` columns.
        n_chips: Maximum number of labels to sample (fewer if the scene has fewer).
        chip_size: Chip edge length in full-resolution pixels (clipped at raster edges).
        random_state: Seed forwarded to ``DataFrame.sample`` for a reproducible
            selection; ``None`` keeps the sampling non-deterministic.
    """
    scene_labels_full = df[df['scene_id'] == scene_id]
    scene_labels = scene_labels_full.sample(
        min(n_chips, len(scene_labels_full)), random_state=random_state
    )
    if scene_labels.empty:
        # plt.subplots(0, ...) would raise; there is simply nothing to show
        print(f"No labels to plot for scene {scene_id}")
        return

    scene_dir = RAW / scene_id
    half = chip_size // 2
    n_cols_plot = 3
    n_rows_plot = (len(scene_labels) + n_cols_plot - 1) // n_cols_plot

    # Open both bands once; read only the chip windows from them
    with rasterio.open(scene_dir / "VV_dB.tif") as vv_src, \
            rasterio.open(scene_dir / "VH_dB.tif") as vh_src:
        img_height, img_width = vv_src.height, vv_src.width

        # Label positions in full-resolution pixel coordinates
        rows, cols = latlon_to_pixel(
            scene_labels['detect_lat'].values,
            scene_labels['detect_lon'].values,
            vv_src.transform,
            vv_src.crs
        )

        # squeeze=False always returns a 2-D array of Axes, so ravel() yields a
        # flat array for any row count (without it, a single row comes back as a
        # 1-D array and would need special-casing).
        fig, axes = plt.subplots(n_rows_plot, n_cols_plot,
                                 figsize=(12, 4 * n_rows_plot), squeeze=False)
        axes = axes.ravel()

        for idx, (row, col) in enumerate(zip(rows, cols)):
            ax = axes[idx]
            row, col = int(row), int(col)

            # Window bounds clipped to the raster
            c_start, c_end = max(0, col - half), min(img_width, col + half)
            r_start, r_end = max(0, row - half), min(img_height, row + half)
            if c_end <= c_start or r_end <= r_start:
                # Label projects outside the raster: the window would be empty/negative
                ax.set_title(f'Label {idx+1} (outside raster)', fontsize=10)
                ax.axis('off')
                continue

            window = rasterio.windows.Window(
                col_off=c_start,
                row_off=r_start,
                width=c_end - c_start,
                height=r_end - r_start
            )

            # Read only the chip (memory efficient!)
            vv_chip = vv_src.read(1, window=window).astype("float32")
            vh_chip = vh_src.read(1, window=window).astype("float32")
            rgb_chip = sar_to_rgb(vv_chip, vh_chip)

            ax.imshow(rgb_chip)
            # Label marker in chip-relative coordinates
            ax.scatter([col - c_start], [row - r_start], c='red', s=100, marker='+', linewidths=3)
            ax.set_title(f'Label {idx+1}', fontsize=10)
            ax.axis('off')

        # Hide unused subplots in the last row
        for ax in axes[len(scene_labels):]:
            ax.axis('off')

    plt.suptitle(f'Label Chips from {scene_id}', fontsize=14)
    plt.tight_layout()
    plt.show()

    gc.collect()

# Run for selected scene
plot_label_chips(selected_scene, train_df, n_chips=9, chip_size=192)

In [None]:
# Cell: Data Quality Analysis (Memory-efficient sampling)

def analyze_data_quality(df, scene_sample=10, available_scenes=None):
    """
    Print basic data-quality diagnostics for a label table and a sample of its scenes.

    Reports per-column missing values, the lat/lon extent of the labels, and the
    VV/VH value ranges plus NaN-pixel counts for up to ``scene_sample`` of the
    most-labelled scenes that exist locally. Rasters are read at 1/16 resolution.

    Args:
        df: Label table with ``scene_id``, ``detect_lat`` and ``detect_lon`` columns.
        scene_sample: Maximum number of scenes to load for raster statistics.
        available_scenes: Iterable of scene IDs present under ``RAW``; defaults
            to the notebook-level ``scenes`` list.
    """
    print(f"{'='*60}")
    print("Data Quality Checks")
    print(f"{'='*60}\n")

    # 1. Missing values
    print("Missing Values:")
    print(df.isnull().sum())
    print()

    # 2. Coordinate ranges
    print("Coordinate Ranges:")
    print(f"  Latitude: [{df['detect_lat'].min():.4f}, {df['detect_lat'].max():.4f}]")
    print(f"  Longitude: [{df['detect_lon'].min():.4f}, {df['detect_lon'].max():.4f}]")
    print()

    # 3. Sample scenes for image quality (memory efficient)
    if available_scenes is None:
        available_scenes = scenes
    available = set(available_scenes)
    label_counts = df['scene_id'].value_counts()  # most-labelled first
    # Filter by local availability *before* truncating, so up to `scene_sample`
    # scenes are used whenever that many exist locally.
    sample_scene_ids = list(
        label_counts.index[label_counts.index.isin(available)][:scene_sample]
    )

    vv_ranges, vh_ranges = [], []
    nan_counts = []

    print(f"Sampling {len(sample_scene_ids)} scenes for statistics...")

    for scene_id in sample_scene_ids:
        scene_dir = RAW / scene_id
        if not scene_dir.exists():
            continue

        # Memory-efficient: read downsampled version
        sar_data = load_sar_scene(scene_dir, downsample=16)  # Very small for stats
        vv, vh = sar_data['VV'], sar_data['VH']

        vv_ranges.append((np.nanmin(vv), np.nanmax(vv)))
        vh_ranges.append((np.nanmin(vh), np.nanmax(vh)))
        nan_counts.append((np.isnan(vv).sum(), np.isnan(vh).sum()))

        del sar_data, vv, vh
        gc.collect()

    if vv_ranges:
        print(f"\nSAR Value Ranges (sampled {len(vv_ranges)} scenes):")
        vv_mins, vv_maxs = zip(*vv_ranges)
        vh_mins, vh_maxs = zip(*vh_ranges)
        print(f"  VV: [{np.mean(vv_mins):.2f}, {np.mean(vv_maxs):.2f}] dB (avg)")
        print(f"  VH: [{np.mean(vh_mins):.2f}, {np.mean(vh_maxs):.2f}] dB (avg)")
        print(f"  NaN pixels (VV): {np.mean([n[0] for n in nan_counts]):.0f} avg (in downsampled)")
        print(f"  NaN pixels (VH): {np.mean([n[1] for n in nan_counts]):.0f} avg (in downsampled)")

analyze_data_quality(train_df, scene_sample=10)