<a href="https://colab.research.google.com/github/JonasLewe/terramind_object_detection/blob/main/notebook/SAR_Ship_Detection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive; the download list and label CSVs live under /content/drive/MyDrive/xview3.
from google.colab import drive

DRIVE_MOUNT_POINT = "/content/drive"
drive.mount(DRIVE_MOUNT_POINT)


In [None]:
!pip install terratorch==1.1.1
!pip install gdown tensorboard

In [None]:
# Notebook-wide imports: standard library first, then third-party packages.
import os
import warnings
from pathlib import Path

import numpy as np
import torch
import gdown
import terratorch
import albumentations
import lightning.pytorch as pl
import matplotlib.pyplot as plt
from terratorch.datamodules import GenericNonGeoSegmentationDataModule

# NOTE(review): this silences every Python warning for the rest of the kernel session;
# consider narrowing it to specific categories/modules once the noisy ones are known.
warnings.filterwarnings("ignore")

# Download Dataset

In [None]:
!mkdir -p /content/xview3/raw /content/xview3/chips /content/xview3/runs
!mkdir -p /content/drive/MyDrive/xview3/{chips,runs,meta}
!apt-get -qq update && apt-get -qq install -y aria2

In [None]:
!aria2c --input-file=/content/drive/MyDrive/xview3/download/tiny \
  --auto-file-renaming=false \
  --continue=true \
  --dir=/content/xview3/raw \
  --dry-run=false

In [None]:
!for f in /content/xview3/raw/*.tar.gz; do tar -xzvf "$f" -C /content/xview3/raw && rm "$f" done

In [None]:
!mkdir -p /content/xview3/meta
!cp /content/drive/MyDrive/xview3/meta/train.csv /content/xview3/meta/train.csv
!cp /content/drive/MyDrive/xview3/meta/validation.csv /content/xview3/meta/validation.csv
!ls -lah /content/xview3/meta

# Dataset Exploration

In [None]:
import rasterio
from pathlib import Path

# Print the georeferencing metadata of one example scene's VV and VH rasters
# (dB scale, per the file names).
scene_id = "72dba3e82f782f67t"
scene_dir = Path("/content/xview3/raw") / scene_id
vv = scene_dir / "VV_dB.tif"
vh = scene_dir / "VH_dB.tif"

with rasterio.open(vv) as ds:
    print("VV path:", vv)
    print("CRS:", ds.crs)
    print("Shape (H,W):", ds.height, ds.width)
    print("Dtype:", ds.dtypes)
    # The NoData value matters for display stretching: it must be masked out before
    # computing percentiles.
    print("NoData:", ds.nodata)
    print("Transform:", ds.transform)
    print("Bounds:", ds.bounds)

with rasterio.open(vh) as ds:
    # A real newline ("\n"); the escaped "\\n" would print a literal backslash-n.
    print("\nVH dtype:", ds.dtypes, "shape:", (ds.height, ds.width))

In [None]:
import pandas as pd

train_path = "/content/xview3/meta/train.csv"
val_path   = "/content/xview3/meta/validation.csv"

# Overview of the training label table.
df = pd.read_csv(train_path)
print("Columns:", list(df.columns))
print("Rows:", len(df))
print(df.head(3).to_string(index=False))

scene_id = "72dba3e82f782f67t"
# The scene identifier column may be called "scene_id" or something similar,
# so list every column whose name mentions "scene".
candidates = [c for c in df.columns if "scene" in c.lower()]
print("\nScene-like columns:", candidates)

# Prefer the canonical name; otherwise fall back to the first scene-like column.
if "scene_id" in df.columns:
    scene_col = "scene_id"
elif candidates:
    scene_col = candidates[0]
else:
    raise KeyError(f"No scene column found in {train_path}; columns: {list(df.columns)}")

sub = df[df[scene_col].astype(str) == scene_id].head(10)
print(f"\nFirst labels for scene {scene_id} (n={len(sub)} shown):")
print(sub.to_string(index=False))

In [None]:
import pandas as pd
import rasterio
import numpy as np
import matplotlib.pyplot as plt
import pyproj  # transforms label lat/lon into the raster's CRS

scene_id = "72dba3e82f782f67t"
scene_dir = f"/content/xview3/raw/{scene_id}"
vv_path = f"{scene_dir}/VV_dB.tif"
train_path = "/content/xview3/meta/train.csv"

# Half-size in pixels of the square preview window centred on the first label.
half_win = 256

df = pd.read_csv(train_path)

# Adjust these if needed after inspecting the CSV columns:
scene_col = "scene_id"
lat_col   = "detect_lat"
lon_col   = "detect_lon"

sub = df[df[scene_col].astype(str) == scene_id].head(5).copy()
print("Using rows:", len(sub))
if sub.empty:
    # Without any labels there is no point to centre the window on (rows[0] below).
    raise ValueError(f"No labels for scene {scene_id!r} in {train_path}")
print(sub[[lat_col, lon_col]].to_string(index=False))

with rasterio.open(vv_path) as ds:
    # Labels are WGS84 lat/lon (EPSG:4326); always_xy=True means arguments are (lon, lat).
    transformer = pyproj.Transformer.from_crs("EPSG:4326", ds.crs, always_xy=True)
    x_utm, y_utm = transformer.transform(sub[lon_col].values, sub[lat_col].values)

    # Map raster-CRS (x, y) -> pixel (row, col) using the dataset's transform.
    rc = [ds.index(x, y) for x, y in zip(x_utm, y_utm)]
    rows = np.array([r for r, c in rc], dtype=int)
    cols = np.array([c for r, c in rc], dtype=int)

    # Window around the first point, clipped to the raster extent so a label near the
    # right/bottom edge does not request pixels outside the image.
    r0, c0 = rows[0], cols[0]
    win = rasterio.windows.Window(col_off=max(c0 - half_win, 0), row_off=max(r0 - half_win, 0),
                                  width=2 * half_win, height=2 * half_win)
    win = win.intersection(rasterio.windows.Window(0, 0, ds.width, ds.height))
    # masked=True masks NoData pixels; filling them with NaN keeps them out of the
    # nanpercentile contrast stretch below.
    img = ds.read(1, window=win, masked=True).astype("float32").filled(np.nan)

# Display with a 2-98 percentile contrast stretch.
win_h, win_w = img.shape
fig, ax = plt.subplots(figsize=(6, 6))
p2, p98 = np.nanpercentile(img, [2, 98])
img_vis = np.clip((img - p2) / (p98 - p2 + 1e-6), 0, 1)
ax.imshow(img_vis, cmap="gray")
# Translate all mapped points into window coordinates and draw those inside the window
# (bounds come from the actual image shape, which may be smaller after clipping).
rr = rows - int(win.row_off)
cc = cols - int(win.col_off)
inside = (rr >= 0) & (rr < win_h) & (cc >= 0) & (cc < win_w)
ax.scatter(cc[inside], rr[inside], s=60, marker="x", color="red")
ax.set_title("VV_dB window + mapped label points")
ax.axis("off")
plt.show()

print("Example pixel coords (row,col):", list(zip(rows, cols))[:5])