<a href="https://colab.research.google.com/github/JonasLewe/terramind_object_detection/blob/main/notebooks/sar_ship_detection_optimized.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [None]:
# Mount Google Drive so the persistent xview3 folders under
# /content/drive/MyDrive/xview3 (raw archives, meta CSVs, chips, runs) are reachable.
from google.colab import drive
drive.mount("/content/drive")


In [None]:
# Install the pinned TerraTorch release plus gdown and TensorBoard.
# NOTE(review): all output, including errors, is discarded by the redirect; drop
# `> /dev/null 2>&1` if later imports fail and the install needs debugging.
!pip install terratorch==1.1.1 gdown tensorboard > /dev/null 2>&1

In [None]:
import os
import gc
import torch
import gdown
import terratorch
import albumentations
import rasterio
import rasterio.windows
import rasterio.enums
import json
from collections import Counter
import shutil
import numpy as np
import pandas as pd
import lightning.pytorch as pl
import matplotlib.pyplot as plt
import pyproj
from pathlib import Path
from terratorch.datamodules import GenericNonGeoSegmentationDataModule
import warnings
warnings.filterwarnings("ignore")

# Initialize TensorBoard

In [None]:
# Start TensorBoard inline, reading event files from the Drive-backed runs folder.
# NOTE(review): RUNS (defined in the Dataset Exploration cell) points to the local
# /content/xview3/runs; confirm which directory training actually logs to.
%load_ext tensorboard
%tensorboard --logdir /content/drive/MyDrive/xview3/runs

# Stage Dataset (copy from Drive and extract)

In [None]:
# Create the local working directories (fast Colab disk) and their persistent Drive counterparts.
!mkdir -p /content/xview3/raw /content/xview3/chips /content/xview3/runs
!mkdir -p /content/drive/MyDrive/xview3/{chips,runs,meta}

In [None]:
# Copy the raw scene archives from Drive to local disk; rsync only transfers files
# that differ, so re-running this cell is cheap.
!rsync -ah --info=progress2 /content/drive/MyDrive/xview3/raw/ /content/xview3/raw/

In [None]:
# Extract every scene archive into raw/ and delete it afterwards; because of `&&`,
# an archive is only removed when its extraction succeeded.
!for f in /content/xview3/raw/*.tar.gz; do tar -xzvf "$f" -C /content/xview3/raw && rm "$f"; done

In [None]:
# Stage the train/validation label CSVs locally next to the raw scenes.
!mkdir -p /content/xview3/meta
!cp /content/drive/MyDrive/xview3/meta/train.csv /content/xview3/meta/train.csv
!cp /content/drive/MyDrive/xview3/meta/validation.csv /content/xview3/meta/validation.csv
!ls -lah /content/xview3/meta

# Dataset Exploration

In [None]:
# Local working layout: every sub-directory lives under the fast Colab disk.
BASE = Path("/content/xview3")
RAW, META, EXP, CHIP, RUNS = (
    BASE / sub for sub in ("raw", "meta", "explore", "chips", "runs")
)

# Overview plots are rendered at 1/DOWNSAMPLE_FACTOR resolution to save memory.
DOWNSAMPLE_FACTOR = 8

In [None]:
def list_scenes(raw_dir: Path):
    """Return the sorted names of sub-folders of *raw_dir* that hold both polarisations.

    A folder counts as a scene only if it contains ``VV_dB.tif`` and ``VH_dB.tif``.
    A missing *raw_dir* is reported on stdout and yields an empty list.
    """
    if not raw_dir.exists():
        print(f"Warning: {raw_dir} does not exist")
        return []

    required = ("VV_dB.tif", "VH_dB.tif")
    return [
        entry.name
        for entry in sorted(raw_dir.iterdir())
        if entry.is_dir() and all((entry / fname).exists() for fname in required)
    ]

# Enumerate the locally extracted scenes; later cells use `scenes` to skip labels
# whose imagery is not on disk.
scenes = list_scenes(RAW)
print(f"Found {len(scenes)} scenes under {RAW}")
print("First scenes:", scenes[:10])


In [None]:
# Cell: Dataset Overview
# Summarise label counts, scene counts and class balance of the xView3 label CSVs.
# (json / Counter are already imported in the setup cell.)

# Load metadata
train_df = pd.read_csv(META / "train.csv")
val_df = pd.read_csv(META / "validation.csv")

print("=" * 60)
print("xView3 Dataset Overview")
print("=" * 60 + "\n")

print(f"Training samples: {len(train_df):,}")
print(f"Validation samples: {len(val_df):,}")
print(f"Total labels: {len(train_df) + len(val_df):,}\n")

# Scene distribution
train_scenes = train_df['scene_id'].nunique()
val_scenes = val_df['scene_id'].nunique()
print(f"Training scenes: {train_scenes}")
print(f"Validation scenes: {val_scenes}\n")

# Class distribution. The flag columns may contain missing values; such rows are
# neither "1" nor "0", so they are reported separately instead of silently
# vanishing from both counts.
if 'is_vessel' in train_df.columns:
    is_vessel = train_df['is_vessel']
    print("Class Distribution (Train):")
    print(f"  Vessels: {int((is_vessel == 1).sum()):,}")
    print(f"  Non-vessels: {int((is_vessel == 0).sum()):,}")
    print(f"  Unknown (missing is_vessel): {int(is_vessel.isna().sum()):,}")
    if 'is_fishing' in train_df.columns:
        vessel_fishing = train_df.loc[is_vessel == 1, 'is_fishing']
        fishing = int((vessel_fishing == 1).sum())
        print(f"  Fishing vessels: {fishing:,}")
        print(f"  Non-fishing vessels: {int((vessel_fishing == 0).sum()):,}")
        print(f"  Vessels with unknown fishing status: {int(vessel_fishing.isna().sum()):,}\n")

# Labels per scene
labels_per_scene = train_df.groupby('scene_id').size()
print("Labels per scene (train):")
print(f"  Mean: {labels_per_scene.mean():.1f}")
print(f"  Median: {labels_per_scene.median():.1f}")
print(f"  Min/Max: {labels_per_scene.min()} / {labels_per_scene.max()}")

# Visualization Helper Functions

In [None]:
def load_sar_scene(scene_dir: Path, bands=("VV", "VH"),
                   downsample=1, window=None, nodata_threshold=-9999):
    """
    Load SAR polarisation bands (``{band}_dB.tif``) from a scene directory.

    Args:
        scene_dir: Path to scene folder
        bands: Sequence of band names to load (default ``("VV", "VH")``; a tuple
            so the shared default can never be mutated by a caller).
        downsample: Integer factor for a decimated overview read
            (e.g. 8 = 1/8 resolution, block-averaged). Ignored when *window* is given.
        window: rasterio.windows.Window for a full-resolution chip read
            (takes precedence over *downsample*).
        nodata_threshold: Values <= this are treated as NoData (replaced with NaN).

    Returns:
        dict mapping each band name to a float32 array, plus ``'meta'`` (taken
        from the first band) with ``transform`` (georeferencing of the *loaded*
        array), ``crs``, ``bounds``, ``shape`` (full raster), ``loaded_shape`` and,
        for downsampled reads only, ``downsample``.
    """
    data = {}
    for band in bands:
        path = scene_dir / f"{band}_dB.tif"
        with rasterio.open(path) as src:
            extra_meta = {}
            if window is not None:
                arr = src.read(1, window=window)
                transform = src.window_transform(window)
            elif downsample > 1:
                h, w = src.height // downsample, src.width // downsample
                arr = src.read(
                    1,
                    out_shape=(h, w),
                    resampling=rasterio.enums.Resampling.average,
                )
                # Enlarge the pixel size so the transform maps the decimated grid.
                transform = src.transform * src.transform.scale(
                    src.width / w,
                    src.height / h,
                )
                extra_meta['downsample'] = downsample
            else:
                arr = src.read(1)
                transform = src.transform

            arr = arr.astype("float32")

            if band == bands[0]:
                data['meta'] = {
                    'transform': transform,
                    'crs': src.crs,
                    'bounds': src.bounds,
                    'shape': src.shape,
                    'loaded_shape': arr.shape,
                    **extra_meta,
                }

            # Replace NoData with NaN
            data[band] = np.where(arr <= nodata_threshold, np.nan, arr)

    return data

def sar_to_rgb(vv, vh, percentile_clip=(2, 98)):
    """
    Create a false-colour RGB composite from dual-pol SAR backscatter in dB.

    R = VV, G = VH, B = VV/VH cross-pol ratio. Because the inputs are already in
    decibels, the linear ratio VV/VH is the *difference* ``VV_dB - VH_dB``;
    dividing two dB values is not a ratio and explodes wherever VH is near 0 dB.

    Each channel is contrast-stretched independently to [0, 1] between the given
    lower/upper percentiles of its finite values. NaN inputs stay NaN in the
    output; a channel with no finite values at all is returned as zeros.

    Args:
        vv, vh: Equally-shaped arrays of backscatter in dB (NaN = NoData).
        percentile_clip: (low, high) percentiles used for the stretch.

    Returns:
        Array of shape ``vv.shape + (3,)`` with values in [0, 1] (or NaN).
    """
    def _stretch(band):
        # Percentiles over finite pixels only; an all-NoData band has nothing to stretch.
        finite = band[np.isfinite(band)]
        if finite.size == 0:
            return np.zeros(band.shape)
        lo, hi = np.percentile(finite, percentile_clip)
        return np.clip((band - lo) / (hi - lo + 1e-6), 0, 1)

    vv = np.asarray(vv)
    vh = np.asarray(vh)

    # Cross-pol ratio in dB (indicator for surface roughness)
    ratio_db = vv - vh

    return np.stack([_stretch(vv), _stretch(vh), _stretch(ratio_db)], axis=-1)

def latlon_to_pixel(lats, lons, transform, crs, downsample=1):
    """
    Convert WGS84 latitude/longitude to integer (row, col) pixel indices.

    Args:
        lats, lons: Sequences of coordinates in degrees (EPSG:4326).
        transform: Affine geotransform of the full-resolution raster.
        crs: CRS of the raster (anything ``pyproj`` accepts).
        downsample: If the image was read at 1/downsample resolution, return
            indices into that decimated grid instead.

    Returns:
        (rows, cols) as integer numpy arrays. Indices are floored, so a point
        slightly left of / above the raster becomes -1 (out of bounds) instead of
        being truncated toward zero onto row/col 0.
    """
    transformer = pyproj.Transformer.from_crs("EPSG:4326", crs, always_xy=True)
    xs, ys = transformer.transform(np.asarray(lons, dtype=float), np.asarray(lats, dtype=float))
    xs = np.asarray(xs, dtype=float)
    ys = np.asarray(ys, dtype=float)

    # pixel = ~transform * (x, y), applied to all points at once
    inv = ~transform
    cols_f = inv.a * xs + inv.b * ys + inv.c
    rows_f = inv.d * xs + inv.e * ys + inv.f

    # Pixel k covers [k, k+1), so floor; fold the downsampling into the same step
    # (floor(floor(x) / n) == floor(x / n) for positive integer n).
    scale = max(int(downsample), 1)
    cols = np.floor(cols_f / scale).astype(int)
    rows = np.floor(rows_f / scale).astype(int)

    return rows, cols

def plot_sar_with_labels(scene_id, df, raw_dir=RAW, max_labels=None, downsample=DOWNSAMPLE_FACTOR):
    """
    Show VV, VH and an RGB composite of one scene with its labels marked.

    The scene is read at 1/``downsample`` resolution to keep memory low. Label
    positions are projected with the full-resolution transform and then scaled
    into the decimated grid. If ``max_labels`` is truthy, only the first
    ``max_labels`` label rows of the scene are drawn. Basic value statistics are
    printed after the figure.
    """
    scene_dir = raw_dir / scene_id
    labels = df[df['scene_id'] == scene_id].copy()
    if max_labels:
        labels = labels.head(max_labels)

    sar = load_sar_scene(scene_dir, downsample=downsample)
    vv = sar['VV']
    vh = sar['VH']
    meta = sar['meta']

    # Label projection needs the full-resolution georeferencing.
    with rasterio.open(scene_dir / "VV_dB.tif") as ref:
        full_transform, full_crs = ref.transform, ref.crs

    rows, cols = latlon_to_pixel(
        labels['detect_lat'].values,
        labels['detect_lon'].values,
        full_transform,
        full_crs,
        downsample=downsample,
    )

    rgb = sar_to_rgb(vv, vh)

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # Single-band panels share the same percentile stretch and red markers.
    gray_panels = (
        (vv, f'VV (dB) - {len(labels)} labels'),
        (vh, 'VH (dB)'),
    )
    for ax, (band, title) in zip(axes[:2], gray_panels):
        ax.imshow(band, cmap='gray',
                  vmin=np.nanpercentile(band, 2), vmax=np.nanpercentile(band, 98))
        ax.scatter(cols, rows, c='red', s=30, marker='x', alpha=0.7, linewidths=2)
        ax.set_title(title)
        ax.axis('off')

    rgb_ax = axes[2]
    rgb_ax.imshow(rgb)
    rgb_ax.scatter(cols, rows, c='cyan', s=30, marker='x', alpha=0.8, linewidths=2)
    rgb_ax.set_title('RGB Composite (VV/VH/Ratio)')
    rgb_ax.axis('off')

    ds_info = "" if downsample <= 1 else f" (1/{downsample} res)"
    plt.suptitle(f'Scene: {scene_id} | Original: {meta["shape"]}{ds_info}', fontsize=14, y=0.98)
    plt.tight_layout()
    plt.show()

    print(f"\nScene Stats:")
    print(f"  VV range: [{np.nanmin(vv):.2f}, {np.nanmax(vv):.2f}] dB")
    print(f"  VH range: [{np.nanmin(vh):.2f}, {np.nanmax(vh):.2f}] dB")
    print(f"  Labels in scene: {len(labels)}")
    print(f"  Loaded shape: {meta['loaded_shape']}")

    # Drop the large arrays before the next scene is loaded.
    del sar, vv, vh, rgb
    gc.collect()

In [None]:
# Cell: Detailed Scene Exploration
# Pick a locally available scene with a moderate number of labels (10-100);
# value_counts() is sorted descending, so this is the most-labelled such scene.
# Falls back to the most-labelled available scene.

scene_counts = train_df['scene_id'].value_counts()

# Filter to scenes that actually exist in RAW folder
available_scenes = set(scenes)
valid_scene_counts = scene_counts[scene_counts.index.isin(available_scenes)]
if valid_scene_counts.empty:
    raise RuntimeError(
        f"None of the {len(scenes)} extracted scenes under {RAW} has labels in train.csv; "
        "check the copy/extraction cells."
    )

candidates = valid_scene_counts[valid_scene_counts.between(10, 100)]
selected_scene = candidates.index[0] if len(candidates) > 0 else valid_scene_counts.index[0]

print(f"Analyzing scene: {selected_scene} ({valid_scene_counts[selected_scene]} labels)")
plot_sar_with_labels(selected_scene, train_df, max_labels=50)

In [None]:
# Cell: Multi-Scene Gallery
# Show up to two locally available scenes from each label-count bucket:
# sparse (< 10 labels), medium (10-50) and dense (> 50).

available_scenes = set(scenes)
valid_scene_counts = scene_counts[scene_counts.index.isin(available_scenes)]

low_count, mid_count, high_count = (
    bucket.index[:2]
    for bucket in (
        valid_scene_counts[valid_scene_counts < 10],
        valid_scene_counts[valid_scene_counts.between(10, 50)],
        valid_scene_counts[valid_scene_counts > 50],
    )
)
gallery_scenes = [*low_count, *mid_count, *high_count]

for scene_id in gallery_scenes[:6]:  # at most 2 per bucket
    print(f"\n{'='*80}")
    plot_sar_with_labels(scene_id, train_df, max_labels=30)

In [None]:
# Cell: Zoomed-in Label Chips (Memory-efficient with windowed reads)

def plot_label_chips(scene_id, df, n_chips=9, chip_size=128, nodata_threshold=-9999):
    """
    Plot a 3-column grid of full-resolution chips centred on randomly sampled labels.

    Each chip is read through a rasterio window (only that region is loaded),
    NoData (raw value <= ``nodata_threshold``) is masked to NaN, and the label
    position is marked with a red '+'. Chips are cropped at the raster border,
    so there the marker is not necessarily in the chip centre.

    Args:
        scene_id: Scene folder name under ``RAW``.
        df: Label table with ``scene_id``, ``detect_lat`` and ``detect_lon``.
        n_chips: Maximum number of labels to sample.
        chip_size: Edge length of each chip in pixels.
        nodata_threshold: Raw values at or below this are treated as NoData.
    """
    scene_labels_full = df[df['scene_id'] == scene_id]
    if scene_labels_full.empty:
        print(f"No labels for scene {scene_id}; nothing to plot.")
        return
    scene_labels = scene_labels_full.sample(min(n_chips, len(scene_labels_full)))

    scene_dir = RAW / scene_id
    half = chip_size // 2

    # Open files once, read chips as needed
    with rasterio.open(scene_dir / "VV_dB.tif") as vv_src, \
         rasterio.open(scene_dir / "VH_dB.tif") as vh_src:
        img_height, img_width = vv_src.height, vv_src.width

        rows, cols = latlon_to_pixel(
            scene_labels['detect_lat'].values,
            scene_labels['detect_lon'].values,
            vv_src.transform,
            vv_src.crs,
        )

        # Labels projecting outside the raster would produce empty/negative windows.
        inside = (rows >= 0) & (rows < img_height) & (cols >= 0) & (cols < img_width)
        rows, cols = rows[inside], cols[inside]
        if len(rows) == 0:
            print(f"No sampled labels of {scene_id} fall inside the raster.")
            return

        n_cols_plot = 3
        n_rows_plot = (len(rows) + n_cols_plot - 1) // n_cols_plot
        # squeeze=False always returns a 2-D array, so ravel() also works for a single
        # row (wrapping a 1-row axes array in a list broke for <= 3 chips).
        fig, axes = plt.subplots(n_rows_plot, n_cols_plot,
                                 figsize=(12, 4 * n_rows_plot), squeeze=False)
        axes = axes.ravel()

        for idx, (row, col) in enumerate(zip(rows, cols)):
            # Window bounds clipped to the raster
            r_start, c_start = max(0, row - half), max(0, col - half)
            r_end, c_end = min(img_height, row + half), min(img_width, col + half)
            window = rasterio.windows.Window(
                col_off=c_start,
                row_off=r_start,
                width=c_end - c_start,
                height=r_end - r_start,
            )

            # Read only the chip (memory efficient!)
            vv_chip = vv_src.read(1, window=window).astype("float32")
            vh_chip = vh_src.read(1, window=window).astype("float32")
            # Mask NoData so it does not distort the percentile stretch.
            vv_chip[vv_chip <= nodata_threshold] = np.nan
            vh_chip[vh_chip <= nodata_threshold] = np.nan
            rgb_chip = sar_to_rgb(vv_chip, vh_chip)

            ax = axes[idx]
            ax.imshow(rgb_chip)
            # Marker position relative to the (possibly border-cropped) chip
            ax.scatter([col - c_start], [row - r_start], c='red', s=100, marker='+', linewidths=3)
            ax.set_title(f'Label {idx+1}', fontsize=10)
            ax.axis('off')

        # Hide empty subplots
        for ax in axes[len(rows):]:
            ax.axis('off')

    plt.suptitle(f'Label Chips from {scene_id}', fontsize=14)
    plt.tight_layout()
    plt.show()

    gc.collect()

# Run for selected scene
plot_label_chips(selected_scene, train_df, n_chips=9, chip_size=192)

In [None]:
# Cell: Data Quality Analysis (Memory-efficient sampling)

def analyze_data_quality(df, scene_sample=10):
    """
    Print missing-value counts, coordinate ranges and coarse SAR statistics.

    SAR statistics come from at most ``scene_sample`` locally available scenes,
    chosen among the ``2 * scene_sample`` most-labelled scenes in *df*. Each scene
    is read at 1/16 resolution, so memory stays small.
    """
    banner = "=" * 60
    print(banner)
    print("Data Quality Checks")
    print(f"{banner}\n")

    # 1. Missing values
    print("Missing Values:")
    print(df.isnull().sum())
    print()

    # 2. Coordinate ranges
    lat, lon = df['detect_lat'], df['detect_lon']
    print("Coordinate Ranges:")
    print(f"  Latitude: [{lat.min():.4f}, {lat.max():.4f}]")
    print(f"  Longitude: [{lon.min():.4f}, {lon.max():.4f}]")
    print()

    # 3. Image statistics on a small, locally available subset of scenes
    on_disk = set(scenes)
    top_scenes = df['scene_id'].value_counts().head(scene_sample * 2).index
    sample_scene_ids = [s for s in top_scenes if s in on_disk][:scene_sample]

    print(f"Sampling {len(sample_scene_ids)} scenes for statistics...")

    per_scene = []  # (vv_min, vv_max, vh_min, vh_max, vv_nan, vh_nan)
    for scene_id in sample_scene_ids:
        scene_dir = RAW / scene_id
        if not scene_dir.exists():
            continue

        sar_data = load_sar_scene(scene_dir, downsample=16)  # tiny read, stats only
        vv, vh = sar_data['VV'], sar_data['VH']
        per_scene.append((
            np.nanmin(vv), np.nanmax(vv),
            np.nanmin(vh), np.nanmax(vh),
            np.isnan(vv).sum(), np.isnan(vh).sum(),
        ))

        del sar_data, vv, vh
        gc.collect()

    if not per_scene:
        return

    vv_mins, vv_maxs, vh_mins, vh_maxs, vv_nans, vh_nans = zip(*per_scene)
    print(f"\nSAR Value Ranges (sampled {len(per_scene)} scenes):")
    print(f"  VV: [{np.mean(vv_mins):.2f}, {np.mean(vv_maxs):.2f}] dB (avg)")
    print(f"  VH: [{np.mean(vh_mins):.2f}, {np.mean(vh_maxs):.2f}] dB (avg)")
    print(f"  NaN pixels (VV): {np.mean(vv_nans):.0f} avg (in downsampled)")
    print(f"  NaN pixels (VH): {np.mean(vh_nans):.0f} avg (in downsampled)")

analyze_data_quality(train_df, scene_sample=10)

# Dataset Generation Functions

In [None]:
# ============================================================
# DATASET GENERATION CONFIG
# ============================================================

CONFIG = {
    # Chip settings
    "chip_size": 256,              # Smaller chips to save storage
    "save_dtype": np.int16,        # On-disk pixel format
    "save_scale": 10,              # dB * 10 -> int16 (keeps 1 decimal place, i.e. 0.1 dB steps)
    "compression": "lzw",
    "pixel_size_m": 10,            # Sentinel-1 ~10m resolution

    # Mask generation
    "mask_type": "circle",         # "circle" = filled binary disks; any other value = gaussian blobs
    "mask_radius_default_m": 150,  # Used when vessel length is unknown; visible but not huge
    "mask_radius_min_px": 8,
    "mask_radius_max_px": 30,
    "gaussian_sigma_factor": 0.5,  # sigma = radius * factor

    # Sampling strategy
    "positive_chips_per_ship": 1,  # Chips centred on ships (NOTE(review): not read by process_scene, which makes one chip per label)
    "negative_ratio": 0.3,         # Target share of negative chips among all chips (30%)
    "negative_min_water_frac": 0.7,  # Negatives: prefer chips with >= 70% water pixels
    "random_offset_px": 30,        # Jitter for positive chips

    # Data filtering
    # "min_confidence": 0.5,         # xView3 confidence threshold
    "only_confirmed_vessels": True,  # Keep only labels with is_vessel == 1 (must be honoured by the scene processor)

    # Output
    "output_dir": Path("/content/drive/MyDrive/xview3/chips"),
    "n_workers": 4,
    "max_scenes": 50,

    # Split (train.csv becomes train+test, validation.csv stays val)
    "test_fraction": 0.1,          # 10% of train.csv -> test
}

# Derived (integer division: 150 m / 10 m -> 15 px)
CONFIG["mask_radius_default_px"] = CONFIG["mask_radius_default_m"] // CONFIG["pixel_size_m"]

print(f"Default mask radius: {CONFIG['mask_radius_default_px']} px")

In [None]:
# ============================================================
# HELPER FUNCTIONS
# ============================================================

from scipy.ndimage import gaussian_filter
from skimage.draw import disk
import rasterio.windows

def get_ship_radius_px(vessel_length_m, config=CONFIG):
    """Map a vessel length in metres to a mask radius in pixels.

    Missing (NaN/None) or non-positive lengths fall back to
    ``config["mask_radius_default_px"]``, returned as-is without clipping.
    Otherwise the radius is half the length converted to pixels, clipped to
    ``[mask_radius_min_px, mask_radius_max_px]`` and truncated to int.
    """
    if pd.isna(vessel_length_m) or vessel_length_m <= 0:
        return config["mask_radius_default_px"]

    half_length_px = vessel_length_m / 2 / config["pixel_size_m"]
    lower, upper = config["mask_radius_min_px"], config["mask_radius_max_px"]
    return int(np.clip(half_length_px, lower, upper))


def create_mask(shape, points, radii, config=CONFIG):
    """
    Rasterise point labels into a segmentation mask.

    Args:
        shape: (H, W) of output mask
        points: Iterable of (row, col) integer pixel positions
        radii: Radius in pixels for each point (same order as *points*)
        config: Reads ``mask_type`` ("circle" -> filled disks; any other value ->
            Gaussian blobs with sigma = radius * ``gaussian_sigma_factor``).

    Returns:
        uint8 0/1 mask for "circle"; otherwise a float32 mask in which
        overlapping blobs are combined with a pixel-wise maximum.
    """
    canvas = np.zeros(shape, dtype=np.float32)

    for (row, col), radius in zip(points, radii):
        if config["mask_type"] == "circle":
            rr, cc = disk((row, col), radius, shape=shape)
            canvas[rr, cc] = 1.0
            continue

        # Gaussian kernel on a (2*reach+1)^2 grid, pasted with border cropping.
        reach = radius * 4
        yy, xx = np.ogrid[-reach:reach + 1, -reach:reach + 1]
        sigma = radius * config["gaussian_sigma_factor"]
        kernel = np.exp(-(xx * xx + yy * yy) / (2 * sigma * sigma))

        top, bottom = max(0, row - reach), min(shape[0], row + reach + 1)
        left, right = max(0, col - reach), min(shape[1], col + reach + 1)
        k_top, k_bottom = reach - (row - top), reach + (bottom - row)
        k_left, k_right = reach - (col - left), reach + (right - col)

        canvas[top:bottom, left:right] = np.maximum(
            canvas[top:bottom, left:right],
            kernel[k_top:k_bottom, k_left:k_right],
        )

    if config["mask_type"] == "circle":
        return (canvas > 0.5).astype(np.uint8)
    return canvas


def estimate_water_fraction(vv_chip, water_threshold=-15):
    """Return the share of non-NaN pixels whose value is below *water_threshold* dB.

    Low backscatter is counted as water. A chip without any non-NaN pixel yields 0.
    """
    is_valid = ~np.isnan(vv_chip)
    n_valid = is_valid.sum()
    if not n_valid:
        return 0
    n_water = np.count_nonzero((vv_chip < water_threshold) & is_valid)
    return n_water / n_valid


def extract_chip_with_window(src_vv, src_vh, row, col, size):
    """
    Read a ``size`` x ``size`` VV/VH chip whose top-left corner is (row, col).

    Only the requested window is read from the already-open datasets. Raw values
    <= -9999 become NaN. Returns ``(None, None)`` when the window does not fit
    completely inside the VV raster or when more than 30% of the VV pixels are
    NoData; otherwise ``(vv, vh)`` as float32 arrays.
    """
    window = rasterio.windows.Window(col_off=col, row_off=row, width=size, height=size)

    fits_inside = (
        row >= 0 and col >= 0
        and row + size <= src_vv.height
        and col + size <= src_vv.width
    )
    if not fits_inside:
        return None, None

    bands = []
    for src in (src_vv, src_vh):
        raw = src.read(1, window=window).astype(np.float32)
        bands.append(np.where(raw <= -9999, np.nan, raw))
    vv, vh = bands

    max_nodata_px = size * size * 0.3
    if np.isnan(vv).sum() > max_nodata_px:
        return None, None
    return vv, vh

In [None]:
# ============================================================
# SCENE PROCESSOR
# ============================================================

def _labels_in_box(labels, row0, col0, size, margin=0):
    """Return rows of *labels* whose pixel position lies in the square box, grown by *margin* px on every side."""
    r, c = labels["pixel_row"], labels["pixel_col"]
    return labels[
        (r >= row0 - margin) & (r < row0 + size + margin)
        & (c >= col0 - margin) & (c < col0 + size + margin)
    ]


def process_scene(scene_id, labels_df, config=CONFIG):
    """
    Process one scene: extract positive (label-centred) and negative chips.

    Positives: one chip per in-bounds label, centred on it with random jitter of
    up to ``random_offset_px``; its mask rasterises every label that falls inside
    the chip, with radii derived from ``vessel_length_m`` under *config*.
    Negatives: random chips with no labelled object within a 50 px margin,
    preferring water (``negative_min_water_frac``), sized so that negatives make
    up ``negative_ratio`` of all chips of the scene.

    When ``config["only_confirmed_vessels"]`` is true and an ``is_vessel`` column
    exists, only labels with ``is_vessel == 1`` produce positives and masks (rows
    with a missing flag compare unequal to 1 and are excluded). All labels are
    still used to keep negatives away from objects.

    Returns: List of (chip_id, vv, vh, mask, is_positive)

    Raises:
        ValueError: if ``negative_ratio`` is not in [0, 1).
    """
    negative_ratio = config["negative_ratio"]
    if not 0 <= negative_ratio < 1:
        raise ValueError(f"negative_ratio must be in [0, 1), got {negative_ratio}")

    scene_dir = RAW / scene_id
    all_labels = labels_df[labels_df["scene_id"] == scene_id].copy()

    chips = []
    size = config["chip_size"]
    half = size // 2
    offset = config["random_offset_px"]

    with rasterio.open(scene_dir / "VV_dB.tif") as src_vv, \
         rasterio.open(scene_dir / "VH_dB.tif") as src_vh:

        img_h, img_w = src_vv.height, src_vv.width

        # Convert all labels to pixel coords (empty columns if there are none)
        rows = np.empty(0, dtype=int)
        cols = np.empty(0, dtype=int)
        if len(all_labels) > 0:
            rows, cols = latlon_to_pixel(
                all_labels["detect_lat"].values,
                all_labels["detect_lon"].values,
                src_vv.transform, src_vv.crs,
            )
        all_labels["pixel_row"] = rows
        all_labels["pixel_col"] = cols

        # Filter labels outside image bounds
        in_bounds = (rows >= 0) & (rows < img_h) & (cols >= 0) & (cols < img_w)
        all_labels = all_labels[in_bounds].copy()

        # Labels that become positives / are rasterised into masks
        if config.get("only_confirmed_vessels", False) and "is_vessel" in all_labels.columns:
            ship_labels = all_labels[all_labels["is_vessel"] == 1]
        else:
            ship_labels = all_labels

        # ========== POSITIVE CHIPS ==========
        for row, col in zip(ship_labels["pixel_row"].tolist(), ship_labels["pixel_col"].tolist()):
            # Random jitter so the ship is not always exactly centred
            row_start = row - half + np.random.randint(-offset, offset + 1)
            col_start = col - half + np.random.randint(-offset, offset + 1)

            # Clamp to image bounds
            row_start = max(0, min(img_h - size, row_start))
            col_start = max(0, min(img_w - size, col_start))

            vv, vh = extract_chip_with_window(src_vv, src_vh, row_start, col_start, size)
            if vv is None:
                continue

            # Find ALL ship labels in this chip
            in_chip = _labels_in_box(ship_labels, row_start, col_start, size)
            if len(in_chip) == 0:
                continue

            points = [
                (int(r - row_start), int(c - col_start))
                for r, c in zip(in_chip["pixel_row"], in_chip["pixel_col"])
            ]
            # Pass config explicitly so radius settings of a custom config are honoured
            radii = [get_ship_radius_px(length, config) for length in in_chip["vessel_length_m"]]
            mask = create_mask((size, size), points, radii, config)

            chips.append((f"{scene_id}_{row_start}_{col_start}", vv, vh, mask, True))

        # ========== NEGATIVE CHIPS ==========
        # negatives / (positives + negatives) == negative_ratio
        n_positive = len(chips)
        n_negative_target = int(n_positive * negative_ratio / (1 - negative_ratio))

        attempts = 0
        negatives_found = 0

        while negatives_found < n_negative_target and attempts < n_negative_target * 10:
            attempts += 1

            row_start = np.random.randint(0, max(1, img_h - size))
            col_start = np.random.randint(0, max(1, img_w - size))

            # Keep a 50 px margin around any labelled object (vessel or not)
            if len(_labels_in_box(all_labels, row_start, col_start, size, margin=50)) > 0:
                continue

            vv, vh = extract_chip_with_window(src_vv, src_vh, row_start, col_start, size)
            if vv is None:
                continue

            # Prefer water chips: non-water candidates survive with 20% probability
            water_frac = estimate_water_fraction(vv)
            if water_frac < config["negative_min_water_frac"] and np.random.random() > 0.2:
                continue

            mask = np.zeros((size, size), dtype=np.float32)
            # NOTE(review): the "_neg" suffix adds an extra "_"-separated token, so code
            # that recovers scene_id by dropping the last two tokens (create_splits)
            # will mis-parse negative chip ids.
            chips.append((f"{scene_id}_{row_start}_{col_start}_neg", vv, vh, mask, False))
            negatives_found += 1

    return chips

In [None]:
# ============================================================
# TEST RUN - 1 Scene, 5 Chips, Local Storage
# ============================================================

def test_pipeline(n_chips=5):
    """
    Smoke-test chip extraction + saving on one scene before the full run.

    Extracts chips from the first available scene (keeping at most ``n_chips``),
    writes them to a local test directory, reloads up to three of them for a
    visual RGB/mask check and extrapolates the full dataset size.

    Returns:
        True on success, False if a prerequisite is missing or extraction/saving failed.
    """
    import traceback

    # Temp config
    test_config = CONFIG.copy()
    test_config["output_dir"] = Path("/content/xview3/test_chips")
    test_config["positive_chips_per_ship"] = 1
    test_config["negative_ratio"] = 0.3

    if not scenes:
        print("✗ No scenes available under RAW; run the copy/extraction cells first.")
        return False
    if "save_chip" not in globals():
        # save_chip is defined in the later "MAIN PROCESSING LOOP" cell.
        print("✗ save_chip is not defined yet; run the 'MAIN PROCESSING LOOP' cell first.")
        return False

    # Use the first available scene
    test_scene = scenes[0]
    test_labels = train_df[train_df["scene_id"] == test_scene].head(n_chips)

    print(f"Testing with scene: {test_scene}")
    print(f"Labels in test: {len(test_labels)}")

    # Process
    try:
        chips = process_scene(test_scene, train_df, test_config)
        chips = chips[:n_chips]
        print(f"✓ process_scene: {len(chips)} chips generated")
    except Exception as e:
        print(f"✗ process_scene FAILED: {e}")
        traceback.print_exc()
        return False

    if not chips:
        print("✗ process_scene produced no chips for this scene; try another scene.")
        return False

    # Save
    try:
        for chip_id, vv, vh, mask, is_pos in chips:
            save_chip(chip_id, vv, vh, mask, test_config["output_dir"], test_config)
        print(f"✓ save_chip: OK")
    except Exception as e:
        print(f"✗ save_chip FAILED: {e}")
        traceback.print_exc()
        return False

    # Verify saved files: only those written by *this* run, not leftovers from earlier runs
    img_dir = test_config["output_dir"] / "images"
    mask_dir = test_config["output_dir"] / "masks"
    chip_ids = list(dict.fromkeys(chip[0] for chip in chips))
    saved_imgs = [img_dir / f"{cid}.tif" for cid in chip_ids if (img_dir / f"{cid}.tif").exists()]
    saved_masks = [mask_dir / f"{cid}.tif" for cid in chip_ids if (mask_dir / f"{cid}.tif").exists()]
    print(f"✓ Saved: {len(saved_imgs)} images, {len(saved_masks)} masks")
    if not saved_imgs:
        print("✗ No image files found after saving.")
        return False

    # Load back & visualize
    print(f"\n{'='*60}")
    print("VISUAL VERIFICATION")
    print(f"{'='*60}\n")

    n_show = min(3, len(saved_imgs))
    # squeeze=False keeps axes 2-D even for a single row
    fig, axes = plt.subplots(n_show, 3, figsize=(12, 4 * n_show), squeeze=False)
    scale = test_config["save_scale"]

    for idx, img_path in enumerate(saved_imgs[:n_show]):
        mask_path = mask_dir / img_path.name

        with rasterio.open(img_path) as src:
            # masked=True honours the file's nodata tag (if any); masked pixels -> NaN
            vv = src.read(1, masked=True).astype(np.float32).filled(np.nan) / scale
            vh = src.read(2, masked=True).astype(np.float32).filled(np.nan) / scale

        with rasterio.open(mask_path) as src:
            mask = src.read(1)

        rgb = sar_to_rgb(vv, vh)

        axes[idx, 0].imshow(rgb)
        axes[idx, 0].set_title(f"RGB ({img_path.stem[:25]}...)")
        axes[idx, 0].axis("off")

        axes[idx, 1].imshow(mask, cmap="Reds", vmin=0, vmax=1)
        axes[idx, 1].set_title(f"Mask ({mask.sum()} px)")
        axes[idx, 1].axis("off")

        axes[idx, 2].imshow(rgb)
        axes[idx, 2].imshow(mask, cmap="Reds", alpha=0.5)
        axes[idx, 2].set_title("Overlay")
        axes[idx, 2].axis("off")

    plt.tight_layout()
    plt.show()

    # ============================================================
    # SIZE ESTIMATION - based on the scenes actually available locally
    # ============================================================
    total_size = sum(f.stat().st_size for f in saved_imgs + saved_masks)
    avg_size_kb = total_size / len(saved_imgs) / 1024

    # Only count labels of scenes that actually exist locally
    train_labels_available = train_df[train_df["scene_id"].isin(scenes)]
    val_labels_available = val_df[val_df["scene_id"].isin(scenes)]
    total_labels = len(train_labels_available) + len(val_labels_available)

    # negative_ratio is the share of negatives among *all* chips, so
    # total = positives / (1 - ratio)  (0.3 -> x1.43, not x1.3).
    # NOTE(review): assumes one positive per label and ignores CONFIG["max_scenes"].
    estimated_chips = int(total_labels / (1 - CONFIG["negative_ratio"]))
    estimated_gb = avg_size_kb * estimated_chips / 1024 / 1024

    print(f"\n{'='*60}")
    print("SIZE ESTIMATION")
    print(f"{'='*60}")
    print(f"Available scenes: {len(scenes)}")
    print(f"Labels in available scenes: {total_labels:,}")
    print(f"Avg chip size: {avg_size_kb:.1f} KB")
    print(f"Estimated chips: ~{estimated_chips:,}")
    print(f"Estimated dataset size: {estimated_gb:.2f} GB")

    print(f"\n✅ Test passed! Run full generation when ready.")
    return True

# RUN TEST
# save_chip is defined in the later "MAIN PROCESSING LOOP" cell. When the notebook is
# run top to bottom it does not exist yet, and the save step would always "fail" with a
# NameError, so print a clear instruction instead.
if "save_chip" in globals():
    test_pipeline(n_chips=5)
else:
    print("⚠ save_chip is not defined yet — run the 'MAIN PROCESSING LOOP' cell, then re-run this cell.")

# Execute Dataset Generation

In [None]:
# ============================================================
# MAIN PROCESSING LOOP
# ============================================================

from tqdm.auto import tqdm
import traceback

def save_chip(chip_id, vv, vh, mask, output_dir, config=CONFIG):
    """
    Save one chip as a compressed 2-band (VV, VH) GeoTIFF plus a 1-band uint8 mask.

    Files go to ``output_dir/images/{chip_id}.tif`` and
    ``output_dir/masks/{chip_id}.tif``; directories are created as needed.

    Pixel values are stored as ``round(dB * config["save_scale"])`` in
    ``config["save_dtype"]``. For integer dtypes, values are clipped to the
    representable range and NaN (NoData) is stored as the dtype's minimum, which
    is also written as the file's ``nodata`` tag. (Mapping NaN to -9999 *before*
    scaling overflows int16: -9999 * 10 wraps to +31082, so NoData used to turn
    into very bright pixels.) For float dtypes NaN is kept and tagged as nodata.

    The mask is binarised with ``mask > 0.5``.
    """
    img_dir = output_dir / "images"
    mask_dir = output_dir / "masks"
    img_dir.mkdir(parents=True, exist_ok=True)
    mask_dir.mkdir(parents=True, exist_ok=True)

    # Stack VV + VH and scale in float64 before converting to the storage dtype
    dtype = np.dtype(config["save_dtype"])
    scaled = np.stack([vv, vh], axis=0).astype(np.float64) * config["save_scale"]

    if np.issubdtype(dtype, np.integer):
        info = np.iinfo(dtype)
        nodata = int(info.min)
        # Round (not truncate) and saturate instead of wrapping; NaN passes through
        # rint/clip unchanged. The minimum is reserved for NoData.
        scaled = np.clip(np.rint(scaled), info.min + 1, info.max)
    else:
        nodata = float("nan")
    img_stack = np.where(np.isnan(scaled), nodata, scaled).astype(dtype)

    # Save image
    with rasterio.open(
        img_dir / f"{chip_id}.tif", "w",
        driver="GTiff",
        height=vv.shape[0], width=vv.shape[1],
        count=2,
        dtype=dtype.name,
        nodata=nodata,
        compress=config["compression"],
    ) as dst:
        dst.write(img_stack)

    # Save mask (uint8, 0/1)
    mask_binary = (mask > 0.5).astype(np.uint8)
    with rasterio.open(
        mask_dir / f"{chip_id}.tif", "w",
        driver="GTiff",
        height=mask.shape[0], width=mask.shape[1],
        count=1,
        dtype=np.uint8,
        compress=config["compression"],
    ) as dst:
        dst.write(mask_binary, 1)


def generate_dataset(scenes, labels_df, config=CONFIG):
    """Process scenes into chips, save them, and return per-chip bookkeeping.

    Args:
        scenes: Iterable of scene ids (a list or a set). Processed in the order
            given, after the optional ``max_scenes`` cut.
        labels_df: Label table forwarded unchanged to ``process_scene``.
        config: Reads ``output_dir`` and the optional ``max_scenes``. A missing
            key or ``None`` means all scenes are processed. The whole config
            is also forwarded to ``process_scene`` and ``save_chip``.

    Returns:
        ``(all_chips, stats)``:
        - ``all_chips`` is a list of ``(chip_id, is_positive)`` for every chip
          that was written.
        - ``stats`` counts ``positive``/``negative`` chips and scene-level
          ``errors``.

    A scene that raises is reported (with traceback) and skipped. Chips from
    that scene that were already saved stay on disk and are counted.
    """
    output_dir = config["output_dir"]
    output_dir.mkdir(parents=True, exist_ok=True)

    # Slicing needs a sequence; this also accepts sets and generators.
    scenes = list(scenes)

    # Limit scenes for POC.
    max_scenes = config.get("max_scenes")
    if max_scenes is not None:
        scenes = scenes[:max_scenes]
        print(f"Limited to {len(scenes)} scenes (POC mode)")

    all_chips = []
    stats = {"positive": 0, "negative": 0, "errors": 0}

    for scene_id in tqdm(scenes, desc="Processing scenes"):
        try:
            chips = process_scene(scene_id, labels_df, config)

            for chip_id, vv, vh, mask, is_positive in chips:
                save_chip(chip_id, vv, vh, mask, output_dir, config)
                all_chips.append((chip_id, is_positive))
                stats["positive" if is_positive else "negative"] += 1

        except Exception as e:
            # Best-effort by design: one bad scene must not abort the whole run.
            print(f"Error processing {scene_id}: {e}")
            traceback.print_exc()
            stats["errors"] += 1

    print(f"\n{'='*60}")
    print("Dataset Generation Complete!")
    print(f"{'='*60}")
    print(f"Positive chips: {stats['positive']:,}")
    print(f"Negative chips: {stats['negative']:,}")
    print(f"Total: {stats['positive'] + stats['negative']:,}")
    print(f"Errors: {stats['errors']}")
    print(f"Output: {output_dir}")

    return all_chips, stats

In [None]:
# ============================================================
# GENERATE TRAIN/VAL/TEST SPLITS
# ============================================================

def _scene_id_from_chip(chip_id):
    """Return the scene id prefix of a chip id.

    Drops the last two ``_``-separated fields, or three when the id ends with
    ``_neg``. This assumes ids of the form ``{scene}_{a}_{b}[_neg]``
    -- TODO confirm against process_scene, which builds the ids.
    """
    n_suffix = 3 if chip_id.endswith("_neg") else 2
    return "_".join(chip_id.split("_")[:-n_suffix])


def create_splits(all_chips, train_scenes, val_scenes, config=CONFIG):
    """Create train/val/test split files under ``output_dir / "splits"``.

    Assignment rules:
    - Chips from ``val_scenes`` go to val.
    - Chips from ``train_scenes`` go to test with probability
      ``config["test_fraction"]``, otherwise to train.
    - Chips matching neither are skipped, and a warning is printed.

    The randomness comes from a dedicated generator seeded with
    ``config.get("split_seed", 42)``, so re-running yields identical splits.

    Args:
        all_chips: Iterable of ``(chip_id, is_positive)`` tuples.
        train_scenes: Collection of training scene ids.
        val_scenes: Collection of validation scene ids.
        config: Reads ``output_dir``, ``test_fraction`` and the optional
            ``split_seed``.

    Returns:
        ``(train_chips, val_chips, test_chips)`` as shuffled lists of chip ids.
        Each list is also written one id per line to ``{name}.txt``.
    """
    output_dir = config["output_dir"]
    splits_dir = output_dir / "splits"
    splits_dir.mkdir(parents=True, exist_ok=True)

    rng = np.random.default_rng(config.get("split_seed", 42))
    train_scenes = set(train_scenes)
    val_scenes = set(val_scenes)
    test_fraction = config["test_fraction"]

    train_chips, test_chips, val_chips = [], [], []
    unassigned = 0

    for chip_id, _is_positive in all_chips:
        scene_id = _scene_id_from_chip(chip_id)
        if scene_id in val_scenes:
            val_chips.append(chip_id)
        elif scene_id in train_scenes:
            # NOTE(review): this is a chip-level split. If chips of one scene
            # overlap or neighbor each other, test metrics are optimistic.
            # Consider a scene-level test split for unbiased evaluation.
            if rng.random() < test_fraction:
                test_chips.append(chip_id)
            else:
                train_chips.append(chip_id)
        else:
            unassigned += 1

    for chips in (train_chips, test_chips, val_chips):
        rng.shuffle(chips)

    if unassigned:
        print(f"⚠️ {unassigned:,} chips matched neither train nor val scenes and were skipped")

    # Save: one chip id per line.
    for name, chips in [("train", train_chips), ("val", val_chips), ("test", test_chips)]:
        with open(splits_dir / f"{name}.txt", "w", encoding="utf-8") as f:
            f.writelines(f"{chip_id}\n" for chip_id in chips)
        print(f"{name}: {len(chips):,} chips")

    return train_chips, val_chips, test_chips

In [None]:
# ============================================================
# EXECUTE DATASET GENERATION
# ============================================================

# 1. Filter labels
if CONFIG["only_confirmed_vessels"]:
    filtered_train = train_df[train_df["is_vessel"] == 1].copy()
    filtered_val = val_df[val_df["is_vessel"] == 1].copy()
else:
    filtered_train = train_df.copy()
    filtered_val = val_df.copy()

print(f"Filtered labels: {len(filtered_train):,} train, {len(filtered_val):,} val")

# 2. Get available scenes (labelled AND present locally in `scenes`)
available_scenes = set(scenes)
train_scene_ids = set(filtered_train["scene_id"].unique()) & available_scenes
val_scene_ids = set(filtered_val["scene_id"].unique()) & available_scenes

print(f"Available scenes: {len(train_scene_ids)} train, {len(val_scene_ids)} val")

# create_splits checks val membership first, so shared scenes end up in val.
overlap = train_scene_ids & val_scene_ids
if overlap:
    print(f"⚠️ {len(overlap)} scenes appear in both train and val labels; they will be assigned to val")

# 3. Combine labels
all_labels = pd.concat([filtered_train, filtered_val], ignore_index=True)
# Sorted: iteration order of a set of str changes between interpreter runs
# (hash randomization). A sorted list keeps both the processing order and the
# subset picked by CONFIG["max_scenes"] reproducible.
all_scene_ids = sorted(train_scene_ids | val_scene_ids)

print(f"Processing {len(all_scene_ids)} scenes...")

# 4. Generate!
all_chips, stats = generate_dataset(all_scene_ids, all_labels, CONFIG)

# 5. Create splits
train_chips, val_chips, test_chips = create_splits(
    all_chips,
    train_scene_ids,
    val_scene_ids,
    CONFIG
)

print(f"\n✅ Done! Dataset saved to {CONFIG['output_dir']}")

In [None]:
# ============================================================
# VERIFY DATASET
# ============================================================

def verify_dataset(config=CONFIG, n_samples=6):
    """Plot random generated chips (RGB composite, mask, overlay) and print counts.

    Args:
        config: Reads ``output_dir`` (chip root containing ``images/`` and
            ``masks/``) and ``save_scale``, the factor the stored values were
            multiplied by.
        n_samples: Maximum number of chips to plot. Fewer rows are drawn when
            the dataset has fewer chips.
    """
    img_dir = config["output_dir"] / "images"
    mask_dir = config["output_dir"] / "masks"

    # Sorted so the random pick depends only on the RNG state, not on glob order.
    chip_paths = sorted(img_dir.glob("*.tif"))
    n_show = min(n_samples, len(chip_paths))

    if not chip_paths:
        print(f"No chips found in {img_dir}")
    elif n_show > 0:
        picks = np.random.choice(len(chip_paths), n_show, replace=False)
        # squeeze=False keeps `axes` 2-D even when only one row is plotted.
        fig, axes = plt.subplots(n_show, 3, figsize=(12, 4 * n_show), squeeze=False)

        for row, pick in enumerate(picks):
            img_path = chip_paths[pick]
            chip_id = img_path.stem
            mask_path = mask_dir / f"{chip_id}.tif"

            with rasterio.open(img_path) as src:
                bands = src.read([1, 2]).astype(np.float32)
                if src.nodata is not None:
                    # Declared sentinel -> NaN. NOTE(review): assumes sar_to_rgb
                    # tolerates NaN like the raw scene data -- confirm.
                    bands[bands == src.nodata] = np.nan
            # Undo the save-time scaling so sar_to_rgb receives the original
            # (presumably dB) values again.
            vv, vh = bands / config["save_scale"]

            with rasterio.open(mask_path) as src:
                mask = src.read(1)

            rgb = sar_to_rgb(vv, vh)

            axes[row, 0].imshow(rgb)
            axes[row, 0].set_title(f"RGB: {chip_id[:20]}...")
            axes[row, 0].axis("off")

            axes[row, 1].imshow(mask, cmap="Reds", vmin=0, vmax=1)
            axes[row, 1].set_title(f"Mask: {mask.sum()} ship pixels")
            axes[row, 1].axis("off")

            axes[row, 2].imshow(rgb)
            axes[row, 2].imshow(mask, cmap="Reds", alpha=0.5)
            axes[row, 2].set_title("Overlay")
            axes[row, 2].axis("off")

        plt.tight_layout()
        plt.show()

    # Stats
    print("\nDataset Stats:")
    print(f"  Images: {len(chip_paths):,}")
    print(f"  Masks: {len(list(mask_dir.glob('*.tif'))):,}")

verify_dataset()


# Model Training

In [None]:
# Environment check: report the PyTorch version and whether a CUDA GPU is usable.
cuda_available = torch.cuda.is_available()
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {cuda_available}")
if cuda_available:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# ============================================================
# TRAINING CONFIG
# ============================================================

TRAIN_CONFIG = {
    # Paths (Google Drive - persistent!)
    "data_root": Path("/content/drive/MyDrive/xview3/chips"),
    "output_dir": Path("/content/drive/MyDrive/xview3/runs"),

    # Model
    "backbone": "terramind_v1_small",  # small for the POC, base for production
    "pretrained": True,
    "freeze_backbone": True,  # faster, less overfitting on small datasets

    # Training
    "batch_size": 8,
    "max_epochs": 30,
    "lr": 2e-4,  # higher learning rate because the backbone is frozen
    "num_workers": 2,

    # Data
    "num_classes": 2,  # Background + Ship
    "in_channels": 2,  # VV + VH
    "chip_size": 256,
    "save_scale": 10,  # same factor used when saving the chips (dB * 10)
}

# Verify the chip dataset exists. Raise explicitly rather than using `assert`,
# which is skipped when Python runs with -O.
_data_root = TRAIN_CONFIG["data_root"]
if not _data_root.exists():
    raise FileNotFoundError(f"Data not found: {TRAIN_CONFIG['data_root']}")
print(f"✓ Data found at {TRAIN_CONFIG['data_root']}")
print(f"  Images: {len(list((_data_root / 'images').glob('*.tif')))}")
print(f"  Masks: {len(list((_data_root / 'masks').glob('*.tif')))}")

# The DataModule cell reads these split files; warn early if any is missing.
_missing_splits = [
    name for name in ("train", "val", "test")
    if not (_data_root / "splits" / f"{name}.txt").exists()
]
if _missing_splits:
    print(f"⚠️ Missing split files in {_data_root / 'splits'}: {', '.join(_missing_splits)}")

In [None]:
# ============================================================
# DATA MODULE
# ============================================================

# `import albumentations` alone does not load the `albumentations.pytorch`
# submodule; import it explicitly instead of relying on terratorch having
# imported it as a side effect.
import albumentations.pytorch

# Standardization values for Sentinel-1 in stored units (dB * save_scale).
# Typical SAR dB values: VV ~ -15, VH ~ -22
# With scale=10: VV ~ -150, VH ~ -220
S1_MEANS = [-150, -220]  # adjusted for the dB * 10 storage scaling
S1_STDS = [50, 50]       # rough estimate -- NOTE(review): compute from train chips

# Augmentations
train_transform = albumentations.Compose([
    albumentations.HorizontalFlip(p=0.5),
    albumentations.VerticalFlip(p=0.5),
    albumentations.RandomRotate90(p=0.5),
    albumentations.pytorch.transforms.ToTensorV2(),
])

val_transform = albumentations.Compose([
    albumentations.pytorch.transforms.ToTensorV2(),
])

# DataModule
datamodule = GenericNonGeoSegmentationDataModule(
    batch_size=TRAIN_CONFIG["batch_size"],
    num_workers=TRAIN_CONFIG["num_workers"],
    num_classes=TRAIN_CONFIG["num_classes"],

    # Data paths (all from Google Drive)
    train_data_root=str(TRAIN_CONFIG["data_root"] / "images"),
    train_label_data_root=str(TRAIN_CONFIG["data_root"] / "masks"),
    val_data_root=str(TRAIN_CONFIG["data_root"] / "images"),
    val_label_data_root=str(TRAIN_CONFIG["data_root"] / "masks"),
    test_data_root=str(TRAIN_CONFIG["data_root"] / "images"),
    test_label_data_root=str(TRAIN_CONFIG["data_root"] / "masks"),

    # Split files
    train_split=str(TRAIN_CONFIG["data_root"] / "splits" / "train.txt"),
    val_split=str(TRAIN_CONFIG["data_root"] / "splits" / "val.txt"),
    test_split=str(TRAIN_CONFIG["data_root"] / "splits" / "test.txt"),
    # Exact stem matching. With the default (substring matching), the entry
    # "scene_1_2" would also select "scene_1_23.tif" and "scene_1_2_neg.tif",
    # pulling chips of other splits into this one.
    allow_substring_split_file=False,

    # File patterns
    img_grep="*.tif",
    label_grep="*.tif",

    # Normalization
    means=S1_MEANS,
    stds=S1_STDS,

    # Transforms
    train_transform=train_transform,
    val_transform=val_transform,
    test_transform=val_transform,

    # NaN handling
    no_label_replace=-1,
    no_data_replace=0,
)

# Setup
datamodule.setup("fit")
print(f"✓ Train samples: {len(datamodule.train_dataset)}")
print(f"✓ Val samples: {len(datamodule.val_dataset)}")

In [None]:
# ============================================================
# VERIFY DATA LOADING
# ============================================================

# Check one batch. NOTE(review): batches drawn directly from the dataloader
# are presumably not yet normalized with S1_MEANS/S1_STDS (terratorch applies
# that in a Trainer hook), so the printed range is in stored units.
batch = next(iter(datamodule.train_dataloader()))
images = batch["image"]
masks = batch["mask"]

print(f"Image batch shape: {images.shape}")  # [B, 2, 256, 256]
print(f"Mask batch shape: {masks.shape}")    # [B, 256, 256]
print(f"Image range: [{images.min():.2f}, {images.max():.2f}]")
print(f"Mask unique values: {torch.unique(masks)}")

# Visualize up to 4 samples; the batch may be smaller than 4.
n_show = min(4, images.shape[0])
# squeeze=False keeps `axes` 2-D even for a single column.
fig, axes = plt.subplots(2, n_show, figsize=(4 * n_show, 8), squeeze=False)
for i in range(n_show):
    # VV channel
    axes[0, i].imshow(images[i, 0].numpy(), cmap="gray")
    axes[0, i].set_title(f"VV {i}")
    axes[0, i].axis("off")

    # Mask overlay
    axes[1, i].imshow(images[i, 0].numpy(), cmap="gray")
    axes[1, i].imshow(masks[i].numpy(), cmap="Reds", alpha=0.5)
    axes[1, i].set_title(f"Mask {i}")
    axes[1, i].axis("off")

plt.tight_layout()
plt.show()

In [None]:
# ============================================================
# MODEL SETUP
# ============================================================

pl.seed_everything(42)

# Checkpoint callback.
# The monitored metric name contains "/". With auto_insert_metric_name=True
# Lightning writes "val/mIoU=..." into the file name, which creates a nested
# directory instead of a flat checkpoint file. So the names are spelled out
# in `filename` and auto-insertion is disabled.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath=str(TRAIN_CONFIG["output_dir"] / "checkpoints"),
    monitor="val/mIoU",
    mode="max",
    filename="best-epoch={epoch:02d}-val_mIoU={val/mIoU:.4f}",
    auto_insert_metric_name=False,
    save_top_k=1,
    save_weights_only=True,
)

# Early stopping
early_stop = pl.callbacks.EarlyStopping(
    monitor="val/mIoU",
    patience=10,
    mode="max",
)

# Model
model = terratorch.tasks.SemanticSegmentationTask(
    model_factory="EncoderDecoderFactory",
    model_args={
        # TerraMind backbone
        "backbone": TRAIN_CONFIG["backbone"],
        "backbone_pretrained": TRAIN_CONFIG["pretrained"],

        # SAR input - 2 channels (VV, VH)
        # Option A: create a new patch embedding for the 2-channel input.
        # NOTE(review): with freeze_backbone=True, confirm this new embedding
        # is still trainable. Otherwise consider TerraMind's pretrained
        # Sentinel-1 GRD modality ("S1GRD") instead -- verify its band order
        # and expected scaling.
        "backbone_modalities": [],
        "backbone_in_chans": TRAIN_CONFIG["in_channels"],

        # Necks for the hierarchical decoder
        "necks": [
            {"name": "SelectIndices", "indices": [2, 5, 8, 11]},  # small/base
            {"name": "ReshapeTokensToImage", "remove_cls_token": False},
            {"name": "LearnedInterpolateToPyramidal"},
        ],

        # Decoder
        "decoder": "UNetDecoder",
        "decoder_channels": [256, 128, 64, 32],

        # Head
        "head_dropout": 0.1,
        "num_classes": TRAIN_CONFIG["num_classes"],
    },

    # Training config
    loss="dice",  # well suited to class-imbalanced segmentation
    optimizer="AdamW",
    lr=TRAIN_CONFIG["lr"],
    ignore_index=-1,
    freeze_backbone=TRAIN_CONFIG["freeze_backbone"],
    freeze_decoder=False,

    # Logging
    # NOTE(review): validation plotting may assume RGB band indices; verify it
    # works with the 2-band input.
    plot_on_val=True,
    class_names=["Background", "Ship"],
)

print(f"✓ Model created: {TRAIN_CONFIG['backbone']}")
print(f"  Backbone frozen: {TRAIN_CONFIG['freeze_backbone']}")

In [None]:
# ============================================================
# TRAINING
# ============================================================

# Trainer
trainer = pl.Trainer(
    accelerator="auto",
    devices=1,
    precision="16-mixed",  # Faster training
    max_epochs=TRAIN_CONFIG["max_epochs"],
    logger=pl.loggers.TensorBoardLogger(
        save_dir=str(TRAIN_CONFIG["output_dir"]),
        name="tensorboard",
    ),
    callbacks=[
        checkpoint_callback,
        early_stop,
        pl.callbacks.RichProgressBar(),
    ],
    log_every_n_steps=5,
    default_root_dir=str(TRAIN_CONFIG["output_dir"]),
)

# Optional: TensorBoard starten
%load_ext tensorboard
%tensorboard --logdir {TRAIN_CONFIG["output_dir"]}

# TRAIN!
trainer.fit(model, datamodule=datamodule)

print(f"\n✅ Training complete!")
print(f"Best checkpoint: {checkpoint_callback.best_model_path}")

In [None]:
# ============================================================
# EVALUATION
# ============================================================

# Test on the best checkpoint.
datamodule.setup("test")

best_ckpt = checkpoint_callback.best_model_path
if not best_ckpt:
    # No checkpoint recorded (e.g. val/mIoU was never logged). Passing
    # ckpt_path=None makes Lightning use the weights currently held by `model`
    # instead of failing on an empty path.
    print("⚠️ No best checkpoint recorded; testing the current model weights.")

results = trainer.test(
    model,
    datamodule=datamodule,
    ckpt_path=best_ckpt or None,
)

print(f"\nTest Results:")
# trainer.test returns one metrics dict per test dataloader.
for metrics in results:
    for k, v in metrics.items():
        print(f"  {k}: {v:.4f}")

In [None]:
# ============================================================
# INFERENCE VISUALIZATION
# ============================================================

# Load the best model; fall back to the in-memory weights if no checkpoint exists.
best_ckpt = checkpoint_callback.best_model_path
if best_ckpt:
    model = terratorch.tasks.SemanticSegmentationTask.load_from_checkpoint(
        best_ckpt,
        model_factory="EncoderDecoderFactory",
        model_args=model.hparams.model_args,
    )
else:
    print("⚠️ No best checkpoint recorded; visualizing the current model weights.")
model.eval()
model.to("cuda" if torch.cuda.is_available() else "cpu")

# Predict on test samples
test_loader = datamodule.test_dataloader()
batch = next(iter(test_loader))

# The datamodule's means/stds normalization is applied in a Trainer hook
# (on_after_batch_transfer). A manually iterated dataloader therefore yields
# unnormalized images, which would make the predictions meaningless.
# NOTE(review): assumes terratorch exposes that normalizer as `datamodule.aug`
# (true for GenericNonGeoSegmentationDataModule) -- verify for the installed version.
normalize = getattr(datamodule, "aug", None)
if normalize is not None:
    batch = normalize(batch)
else:
    print("⚠️ datamodule has no `aug` normalizer; feeding unnormalized images")

images = batch["image"].to(model.device)
masks_gt = batch["mask"]

with torch.no_grad():
    outputs = model(images)
    preds = torch.argmax(outputs.output, dim=1).cpu()

# Visualize up to 4 samples; the batch may be smaller than 4.
n_show = min(4, images.shape[0])
# squeeze=False keeps `axes` 2-D even for a single row.
fig, axes = plt.subplots(n_show, 3, figsize=(12, 4 * n_show), squeeze=False)
for i in range(n_show):
    # Input
    axes[i, 0].imshow(images[i, 0].cpu().numpy(), cmap="gray")
    axes[i, 0].set_title("Input (VV)")
    axes[i, 0].axis("off")

    # Ground Truth. Count only ship pixels; ignore_index (-1) pixels would
    # otherwise reduce the sum.
    gt = masks_gt[i]
    axes[i, 1].imshow(gt.numpy(), cmap="Reds", vmin=0, vmax=1)
    axes[i, 1].set_title(f"GT ({int((gt == 1).sum())} px)")
    axes[i, 1].axis("off")

    # Prediction
    axes[i, 2].imshow(preds[i].numpy(), cmap="Reds", vmin=0, vmax=1)
    axes[i, 2].set_title(f"Pred ({int(preds[i].sum())} px)")
    axes[i, 2].axis("off")

plt.tight_layout()
plt.show()