# SAM3 / SamGeo3 — Multi-Prompt Segmentation → Union Mask → GeoPackage Layer

This notebook processes a georeferenced **RGB GeoTIFF (bands 1–3)** in tiles using **SAM3 via `samgeo.SamGeo3`**.

**Goal:** Use multiple keywords (e.g. `house`, `garage`, `building`) in one run, **union** the masks per tile, stitch tiles into a **full-resolution mask**, and export:

- `*_union_mask.tif` (GeoTIFF mask, 0/255)
- `*_union.gpkg` (GeoPackage, layer `objects`)

## Installation (reference only)

> Do not run this inside the notebook if your environment is already set up.

```bash
# Core
pip install --upgrade "segment-geospatial[samgeo3]" rasterio opencv-python tqdm numpy matplotlib

# PyTorch (choose ONE)
# CPU:
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
# GPU (example channel, adapt to your CUDA wheel channel):
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126

# Optional: GeoPackage export
pip install geopandas shapely pyproj fiona
```

In [None]:
from __future__ import annotations

import gc
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import rasterio
from rasterio.features import shapes
import cv2
from tqdm.auto import tqdm

import torch
from contextlib import nullcontext

from samgeo import SamGeo3


In [None]:
# --- Runtime / GPU Check ---
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA runtime: {torch.version.cuda}")


## Configuration

- `IMAGE_PATH`: RGB GeoTIFF (bands 1–3)
- `PROMPTS_RAW`: comma-separated prompts (German or English). Common German terms are mapped to English via `GERMAN_TO_ENGLISH`; anything else is passed through unchanged.
- `TILE_SIZE` / `OVERLAP`: tiling parameters (1024/128 is a solid default)
- `USE_FP16`: mixed precision via `torch.autocast` (CUDA only) → often faster and uses less VRAM
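
To estimate the workload before a long run, note that the tiling helper `create_tiles` (defined further down) steps through the image with stride `TILE_SIZE - OVERLAP`. The grid therefore has `ceil(H / stride) × ceil(W / stride)` tiles, and each tile is run once per prompt. The sketch below mirrors that loop arithmetic. The helper name `estimate_workload` and the 10000 × 8000 px image size are hypothetical:

```python
import math


def estimate_workload(height: int, width: int, tile_size: int, overlap: int, n_prompts: int):
    """Mirror the range(0, size, stride) loops in create_tiles to count tiles and model calls."""
    stride = tile_size - overlap
    if stride <= 0:
        raise ValueError("OVERLAP must be smaller than TILE_SIZE")
    rows = math.ceil(height / stride)  # number of y offsets from range(0, height, stride)
    cols = math.ceil(width / stride)   # number of x offsets from range(0, width, stride)
    n_tiles = rows * cols
    # every tile gets one set_image() call and one generate_masks() call per prompt
    return {"rows": rows, "cols": cols, "tiles": n_tiles, "prompt_calls": n_tiles * n_prompts}


# Hypothetical 10000 x 8000 px image, default tiling, three prompts
workload = estimate_workload(10000, 8000, tile_size=1024, overlap=128, n_prompts=3)
print(workload)
```

With this loop, even an image exactly 1024 px on a side yields a 2 × 2 grid, because tiles also start at offset 896. Those extra edge tiles consist mostly of padding.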

In [None]:
# --- User Config ---
IMAGE_PATH = Path(r"E:/path/to/your_rgb_geotiff.tif")   # <-- adjust
OUT_DIR = Path("output")                                # <-- adjust
PROMPTS_RAW = "building,house,garage"                   # <-- adjust (German also works, e.g. "Gebäude,Haus,Garage")

# Tiling / Performance
TILE_SIZE = 1024
OVERLAP = 128
USE_FP16 = True

# Post-processing / vectorization
MORPH_CLOSE = 3   # 0 disables
MORPH_OPEN  = 0   # 0 disables
MIN_AREA_M2 = 10.0  # Minimum polygon area (most meaningful with a projected CRS)

# Device handling:
# - "auto": use CUDA if available, else CPU
# - "cpu": force CPU (useful for testing / machines without NVIDIA GPU)
# - "cuda": force CUDA (will error if CUDA is not available)
DEVICE_PREFERENCE = "auto"  # "auto" | "cpu" | "cuda"

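if DEVICE_PREFERENCE == "cpu":
    DEVICE = "cpu"
elif DEVICE_PREFERENCE == "cuda":
    if not torch.cuda.is_available():
        raise RuntimeError('DEVICE_PREFERENCE="cuda" but CUDA is not available')
    DEVICE = "cuda"
else:
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

OUT_DIR.mkdir(parents=True, exist_ok=True)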

print("IMAGE_PATH:", IMAGE_PATH)
print("OUT_DIR:", OUT_DIR.resolve())
print("DEVICE:", DEVICE)


In [None]:
# ----------------------------
# Prompt normalization
# ----------------------------

GERMAN_TO_ENGLISH: Dict[str, str] = {
    "gebäude": "building",
    "gebaeude": "building",
    "haus": "house",
    "garage": "garage",
    "schuppen": "shed",
    "nebengebäude": "outbuilding",
    "nebengebaeude": "outbuilding",
    "dach": "roof",
}

DEFAULT_PROMPTS: List[str] = [
    "building",
    "house",
    "residential building",
    "roof",
    "garage",
    "shed",
    "outbuilding",
]

def normalize_prompts(raw: str) -> List[str]:
    """
    Normalizes a comma-separated prompt string:
    - split on commas
    - trim whitespace
    - map common German terms -> English
    - deduplicate case-insensitively (order is preserved)
    """
    if not raw:
        return DEFAULT_PROMPTS

    parts = [p.strip() for p in raw.split(",") if p.strip()]
    if not parts:
        return DEFAULT_PROMPTS

    out: List[str] = []
    seen = set()
    for p in parts:
        key = p.lower()
        mapped = GERMAN_TO_ENGLISH.get(key, p)
        if mapped.lower() not in seen:
            out.append(mapped)
            seen.add(mapped.lower())
    return out

PROMPTS = normalize_prompts(PROMPTS_RAW)
print("PROMPTS:", PROMPTS)
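
`GERMAN_TO_ENGLISH` above lists both spellings of words that contain umlauts (`gebäude` / `gebaeude`). One possible simplification is to fold umlauts to ASCII before the dictionary lookup, so each term needs only one key. This assumes ASCII transliteration is acceptable for matching. `fold_umlauts`, `map_prompt`, and `GERMAN_TO_ENGLISH_ASCII` are hypothetical names and are not part of the notebook:

```python
# str.maketrans accepts a dict whose values may be multi-character strings
UMLAUT_MAP = str.maketrans({"ä": "ae", "ö": "oe", "ü": "ue", "ß": "ss"})

# Hypothetical ASCII-only variant of GERMAN_TO_ENGLISH (one key per term)
GERMAN_TO_ENGLISH_ASCII = {
    "gebaeude": "building",
    "haus": "house",
    "garage": "garage",
    "schuppen": "shed",
    "nebengebaeude": "outbuilding",
    "dach": "roof",
}


def fold_umlauts(term: str) -> str:
    """Lowercase, trim, and transliterate German umlauts / eszett to ASCII."""
    return term.strip().lower().translate(UMLAUT_MAP)


def map_prompt(term: str) -> str:
    """Map a German term to English; unknown terms are returned unchanged (trimmed)."""
    return GERMAN_TO_ENGLISH_ASCII.get(fold_umlauts(term), term.strip())


print([map_prompt(t) for t in ["Gebäude", "Nebengebäude", "roof"]])  # ['building', 'outbuilding', 'roof']
```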


In [None]:
# ----------------------------
# I/O, tiling, mask handling
# ----------------------------

def clear_gpu_memory() -> None:
    """Aggressive GPU/RAM cleanup (helps for long runs)."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

def ensure_uint8_rgb(img: np.ndarray) -> np.ndarray:
    """
    SAM/SamGeo works most reliably with uint8 RGB in [0..255].
    - If the data is already uint8: keep as-is.
    - Otherwise (e.g. uint16 or float): robust per-band 1-99 percentile stretch to 0..255 (NaNs are ignored).
    """
    if img.dtype == np.uint8:
        return img

    img_f = img.astype(np.float32)
    # robust: per-band percentile stretch
    out = np.zeros_like(img_f, dtype=np.float32)
    for b in range(3):
        band = img_f[..., b]
        lo = np.nanpercentile(band, 1)
        hi = np.nanpercentile(band, 99)
        if hi <= lo:
            hi = lo + 1.0
        out[..., b] = np.clip((band - lo) / (hi - lo), 0.0, 1.0)
    return (out * 255.0).astype(np.uint8)

def create_tiles(image_array: np.ndarray, tile_size: int, overlap: int):
    """Split an image into overlapping tiles; edge tiles are zero-padded to full tile_size.

    Returns a list of dicts holding the tile array plus its placement in the full
    image (`x`, `y`, `x_end`, `y_end`) and a running `index`.
    """
    h, w = image_array.shape[:2]
    stride = tile_size - overlap
    tiles = []
    idx = 0

    for y in range(0, h, stride):
        for x in range(0, w, stride):
            y_end = min(y + tile_size, h)
            x_end = min(x + tile_size, w)
            tile = image_array[y:y_end, x:x_end]

            # pad if needed (keeps merge logic simple)
            if tile.shape[0] < tile_size or tile.shape[1] < tile_size:
                padded = np.zeros((tile_size, tile_size, tile.shape[2]), dtype=tile.dtype)
                padded[:tile.shape[0], :tile.shape[1]] = tile
                tile = padded

            tiles.append({
                "tile": tile,
                "x": x, "y": y,
                "x_end": x_end, "y_end": y_end,
                "index": idx
            })
            idx += 1

    return tiles

def process_mask_robust(seg: np.ndarray, tile_size: int) -> Optional[np.ndarray]:
    """
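    Convert SAM segmentation output into a binary uint8 mask (0/255).
    - torch.Tensor -> numpy
    - 3D -> 2D (handles both (1, H, W) and (H, W, 1) layouts)
    - resize -> tile_size
    - float -> threshold 0.5
    Returns None if the mask is empty.
    """
    if seg is None:
        return None

    if hasattr(seg, "cpu"):  # torch Tensor
        seg = seg.detach().cpu().numpy()
    seg = np.asarray(seg)

    if seg.ndim == 3:
        # channel-first (1, H, W) or channel-last (H, W, 1)
        seg = seg[0] if seg.shape[0] < seg.shape[-1] else seg[:, :, 0]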

    if seg.shape != (tile_size, tile_size):
        seg_float = seg.astype(np.float32) if seg.dtype != np.float32 else seg
        seg = cv2.resize(seg_float, (tile_size, tile_size), interpolation=cv2.INTER_NEAREST)

    if seg.dtype == bool:
        m = seg.astype(np.uint8) * 255
    elif seg.dtype.kind in ("f",):
        m = (seg > 0.5).astype(np.uint8) * 255
    else:
        m = (seg > 0).astype(np.uint8) * 255

    if int(np.sum(m > 0)) == 0:
        return None
    return m

def run_prompts_on_tile(
    sam3: SamGeo3,
    tile: np.ndarray,
    prompts: List[str],
    tile_size: int,
    device: str,
    use_fp16: bool,
) -> Tuple[Optional[np.ndarray], int, int]:
    """
    Runs *all* prompts on one tile and unions the resulting masks.
    Returns: (tile_mask or None, objects_seen, pixel_count)
    """
    try:
        sam3.set_image(tile)
        combined = np.zeros((tile_size, tile_size), dtype=np.uint8)
        total_objects = 0

        autocast_ctx = (
            torch.autocast(device_type="cuda", dtype=torch.float16)
            if (use_fp16 and device == "cuda" and torch.cuda.is_available())
            else nullcontext()
        )

        with autocast_ctx:
            for p in prompts:
                sam3.generate_masks(prompt=p)

                masks = getattr(sam3, "masks", None)
                if masks is not None and len(masks) > 0:
                    total_objects += len(masks)
                    for mask_item in masks:
                        # accept SAM-style dicts ({"segmentation": ...}) as well as raw mask arrays/tensors
                        seg = mask_item.get("segmentation") if isinstance(mask_item, dict) else mask_item
                        m = process_mask_robust(seg, tile_size)
                        if m is not None:
                            combined = np.maximum(combined, m)

        px = int(np.sum(combined > 0))
        if px > 0:
            return combined, total_objects, px
        return None, total_objects, 0

    except Exception as e:
        # deliberately non-fatal: individual tiles may fail, the pipeline keeps going
        tqdm.write(f"Tile processing failed ({type(e).__name__}): {e}")
        return None, 0, 0

def merge_masks(tiles_with_masks, original_shape, overlap: int) -> np.ndarray:
    """
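    Merge tile masks back to full image size.
    - Each tile is weighted with a linear fade toward its borders (overlap zone).
    - Final threshold: weighted average > 0 -> 255. Because every weight is > 0,
      this is effectively the union of all tile masks; the feathering only
      changes the result if the threshold is raised (e.g. to 127.5).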
    """
    h, w = original_shape[:2]
    merged = np.zeros((h, w), dtype=np.float32)
    weights = np.zeros((h, w), dtype=np.float32)

    for t in tqdm(tiles_with_masks, desc="Merge Masks", unit="tile"):
        if t["mask"] is None:
            continue

        x, y = t["x"], t["y"]
        x_end, y_end = t["x_end"], t["y_end"]
        m = t["mask"]

        actual_h = y_end - y
        actual_w = x_end - x
        m = m[:actual_h, :actual_w].astype(np.float32)

        wgt = np.ones((actual_h, actual_w), dtype=np.float32)
        if overlap > 0:
            fade = min(overlap // 2, actual_h // 4, actual_w // 4)
            if fade > 0:
                for i in range(fade):
                    alpha = (i + 1) / fade
                    wgt[i, :] *= alpha
                    wgt[-i - 1, :] *= alpha
                    wgt[:, i] *= alpha
                    wgt[:, -i - 1] *= alpha

        merged[y:y_end, x:x_end] += m * wgt
        weights[y:y_end, x:x_end] += wgt

    weights[weights == 0] = 1
    merged = merged / weights
    out = (merged > 0).astype(np.uint8) * 255
    return out

def apply_morphology(mask: np.ndarray, close_k: int, open_k: int) -> np.ndarray:
    """Optional cleanup: close (fill small holes/gaps), then open (remove small noise)."""
    out = mask
    if close_k and close_k > 0:
        k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (close_k, close_k))
        out = cv2.morphologyEx(out, cv2.MORPH_CLOSE, k, iterations=1)
    if open_k and open_k > 0:
        k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (open_k, open_k))
        out = cv2.morphologyEx(out, cv2.MORPH_OPEN, k, iterations=1)
    return out
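
Because `merge_masks` thresholds the weighted average at `> 0`, and every feather weight is strictly positive (the smallest, at a tile corner, is `1 / fade**2`), the feathering does not change the binary result. A pixel ends up at 255 whenever at least one tile with a mask predicted it. The numpy sketch below checks this on two synthetic overlapping tiles. `feathered_merge` is a simplified, hypothetical stand-in for the merge step, not the function above:

```python
import numpy as np


def feathered_merge(masks, offsets, shape, tile, fade, threshold=0.0):
    """Weighted average of 0/255 tile masks with linear edge fades, then binarize."""
    merged = np.zeros(shape, dtype=np.float32)
    weights = np.zeros(shape, dtype=np.float32)
    # Linear ramp toward each border; all weights stay > 0
    wgt = np.ones((tile, tile), dtype=np.float32)
    for i in range(fade):
        alpha = (i + 1) / fade
        wgt[i, :] *= alpha
        wgt[-i - 1, :] *= alpha
        wgt[:, i] *= alpha
        wgt[:, -i - 1] *= alpha
    for m, (y, x) in zip(masks, offsets):
        merged[y:y + tile, x:x + tile] += m.astype(np.float32) * wgt
        weights[y:y + tile, x:x + tile] += wgt
    weights[weights == 0] = 1
    return ((merged / weights) > threshold).astype(np.uint8) * 255


rng = np.random.default_rng(0)
tile, stride = 8, 6  # two tiles side by side, overlapping by 2 columns
a = (rng.random((tile, tile)) > 0.5).astype(np.uint8) * 255
b = (rng.random((tile, tile)) > 0.5).astype(np.uint8) * 255
offsets = [(0, 0), (0, stride)]
shape = (tile, stride + tile)

merged = feathered_merge([a, b], offsets, shape, tile, fade=2)
strict = feathered_merge([a, b], offsets, shape, tile, fade=2, threshold=127.5)

# Plain union (logical OR) of the same tiles, for comparison
union = np.zeros(shape, dtype=np.uint8)
union[:, :tile] = np.maximum(union[:, :tile], a)
union[:, stride:stride + tile] = np.maximum(union[:, stride:stride + tile], b)
print("identical to union:", np.array_equal(merged, union))
```

If seams or overlap disagreements matter, a threshold such as `127.5` turns the merge into a weighted vote, so edge predictions count for less. This result can only be a subset of the union.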


In [None]:
# ----------------------------
# Load image (RGB) + metadata
# ----------------------------

if not IMAGE_PATH.exists():
    raise FileNotFoundError(f"IMAGE_PATH not found: {IMAGE_PATH}")

with rasterio.open(IMAGE_PATH) as src:
    profile = src.profile.copy()
    transform = src.transform
    crs = src.crs

    # read bands 1-3 as RGB (shape: (3, H, W) -> (H, W, 3))
    rgb = src.read([1, 2, 3])
    rgb = np.transpose(rgb, (1, 2, 0))

rgb = ensure_uint8_rgb(rgb)

print("Image shape:", rgb.shape, "dtype:", rgb.dtype)
print("CRS:", crs)
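
`ensure_uint8_rgb` computes the 1st/99th percentiles over every pixel. `np.nanpercentile` already ignores NaN, but an integer nodata value (often 0 in orthomosaic collars) is counted as data. If it covers more than about 1 % of the image, it pulls the lower bound down to 0 and compresses the stretch. One option is to exclude a known nodata value before computing percentiles. The sketch assumes a single nodata value per band (for example `src.nodata`). `stretch_band_ignoring_nodata` is a hypothetical helper:

```python
import numpy as np


def stretch_band_ignoring_nodata(band, nodata=None, lo_pct=1.0, hi_pct=99.0):
    """Percentile-stretch one band to uint8, computing percentiles on valid pixels only."""
    band = band.astype(np.float32)
    valid = np.isfinite(band)
    if nodata is not None:
        valid &= band != nodata
    if not valid.any():
        return np.zeros(band.shape, dtype=np.uint8)
    lo, hi = np.percentile(band[valid], [lo_pct, hi_pct])
    if hi <= lo:
        hi = lo + 1.0
    out = np.clip((band - lo) / (hi - lo), 0.0, 1.0)
    out[~valid] = 0.0  # nodata / NaN pixels become black (avoids casting NaN to uint8)
    return (out * 255.0).astype(np.uint8)


# Synthetic uint16 band: left half nodata (0), right half valid values 1000..2000
band = np.zeros((100, 100), dtype=np.uint16)
band[:, 50:] = np.linspace(1000, 2000, 50).astype(np.uint16)

print("naive 1st percentile:", np.percentile(band, 1))
stretched = stretch_band_ignoring_nodata(band, nodata=0)
print("valid range after stretch:", stretched[:, 50:].min(), stretched[:, 50:].max())
```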


In [None]:
# ----------------------------
# Create tiles
# ----------------------------
tiles = create_tiles(rgb, tile_size=TILE_SIZE, overlap=OVERLAP)
print("Tiles:", len(tiles), "| tile_size:", TILE_SIZE, "| overlap:", OVERLAP)


In [None]:
# ----------------------------
# Initialize SAM3
# ----------------------------
clear_gpu_memory()

sam3 = SamGeo3(
    backend="transformers",
    device=DEVICE,
    checkpoint_path=None,
    load_from_HF=True,
)

print("SamGeo3 initialized on:", DEVICE)


In [None]:
# ----------------------------
# Process tiles: multi-prompt -> union mask per tile
# ----------------------------

tiles_with_masks = []
successful = 0
total_px = 0
total_objects_seen = 0

pbar = tqdm(tiles, desc="Process Tiles", unit="tile")
for t in pbar:
    mask, objs, px = run_prompts_on_tile(
        sam3=sam3,
        tile=t["tile"],
        prompts=PROMPTS,
        tile_size=TILE_SIZE,
        device=DEVICE,
        use_fp16=USE_FP16,
    )

    t2 = dict(t)
    t2["mask"] = mask
    tiles_with_masks.append(t2)

    total_objects_seen += int(objs)
    if mask is not None:
        successful += 1
        total_px += int(px)

    pbar.set_postfix({
        "ok": f"{successful}/{t['index'] + 1}",
        "px": f"{total_px:,}",
        "objs": total_objects_seen,
    })

    # periodically clear the GPU cache
    if (t["index"] + 1) % 10 == 0:
        clear_gpu_memory()

print("Done. Successful tiles:", successful, "/", len(tiles))


In [None]:
# ----------------------------
# Merge + post-processing
# ----------------------------

final_mask = merge_masks(tiles_with_masks, rgb.shape, overlap=OVERLAP)
final_mask = apply_morphology(final_mask, close_k=MORPH_CLOSE, open_k=MORPH_OPEN)

coverage = float(np.mean(final_mask > 0)) * 100.0
print(f"Mask coverage: {coverage:.2f}%")
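
`MIN_AREA_M2` and the coverage figure are easier to interpret in pixels. For a north-up raster (no rotation terms), one pixel covers `|transform.a * transform.e|` square map units, which is m² for a metre-based projected CRS. The sketch below converts between the two. The helper names and the 0.1 m pixel size are illustrative assumptions:

```python
import math


def pixel_area(a: float, e: float) -> float:
    """Area of one pixel in squared map units for a north-up affine transform."""
    return abs(a * e)


def min_area_in_pixels(min_area_m2: float, a: float, e: float) -> int:
    """Smallest pixel count a polygon needs to pass an area filter of min_area_m2."""
    # round() absorbs float noise such as 10 / 0.010000000000000002 = 999.9999999999998
    return math.ceil(round(min_area_m2 / pixel_area(a, e), 6))


def mask_area_m2(n_foreground_pixels: int, a: float, e: float) -> float:
    """Total foreground area for a given number of mask pixels."""
    return n_foreground_pixels * pixel_area(a, e)


# Illustrative 0.1 m ground sampling distance (a = 0.1, e = -0.1)
print(min_area_in_pixels(10.0, 0.1, -0.1))
print(mask_area_m2(250_000, 0.1, -0.1))
```

In the notebook, the equivalent call would be `mask_area_m2(int((final_mask > 0).sum()), transform.a, transform.e)`.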


In [None]:
# ----------------------------
# Export: GeoTIFF mask
# ----------------------------

out_mask = OUT_DIR / f"{IMAGE_PATH.stem}_union_mask.tif"

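out_profile = profile.copy()
# nodata=None: a source nodata value such as 65535 would be invalid for uint8
out_profile.update(dtype=rasterio.uint8, count=1, compress="lzw", nodata=None)
# source creation options like photometric="RGB"/"YCbCr" are invalid for a 1-band mask
out_profile.pop("photometric", None)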

with rasterio.open(out_mask, "w", **out_profile) as dst:
    dst.write(final_mask.astype(np.uint8), 1)

print("Mask saved:", out_mask.resolve())


## Export: Vector layer (GeoPackage)

Vectorization is done via `rasterio.features.shapes` → `shapely` polygons → optional `geopandas` export.

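- If `geopandas` is not installed, the notebook writes a **GeoJSON fallback** instead (without the area filter or attribute columns). If the GeoPackage write itself fails, the error is printed and no vector file is written.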
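
Passing `transform` to `shapes` is what turns pixel-grid outlines into map coordinates. rasterio's `Affine(a, b, c, d, e, f)` maps a pixel corner `(col, row)` to `x = a*col + b*row + c` and `y = d*col + e*row + f`. The stand-alone sketch below applies that formula by hand to a hypothetical UTM-like transform (0.1 m pixels, upper-left corner at 500000 / 5400000), showing where polygon vertices end up:

```python
from typing import Tuple


def pixel_to_world(col: float, row: float, a: float, b: float, c: float,
                   d: float, e: float, f: float) -> Tuple[float, float]:
    """Apply a rasterio/GDAL-style affine transform to a pixel-corner coordinate."""
    x = a * col + b * row + c
    y = d * col + e * row + f
    return x, y


# Hypothetical north-up transform: 0.1 m pixels, upper-left corner at (500000, 5400000)
T = dict(a=0.1, b=0.0, c=500_000.0, d=0.0, e=-0.1, f=5_400_000.0)

# Outline of a block 10 px wide and 20 px tall, starting at column 10, row 20 (closed ring)
ring_px = [(10, 20), (20, 20), (20, 40), (10, 40), (10, 20)]
ring_map = [pixel_to_world(col, row, **T) for col, row in ring_px]
print(ring_map)
```

Note that `y` decreases as `row` increases, because `e` is negative for north-up imagery.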

In [None]:
# ----------------------------
# Export: GeoPackage (optional)
# ----------------------------

from shapely.geometry import shape as shapely_shape

def mask_to_polygons(mask: np.ndarray, transform, crs, min_area_m2: float):
    """
    Convert binary mask (0/255) to polygons.
    - returns GeoDataFrame when geopandas is available
    - otherwise returns list[dict] as GeoJSON-like features (fallback)
    """
    mask_bin = (mask > 0).astype(np.uint8)

    geoms = []
    for geom, val in shapes(mask_bin, mask=mask_bin, transform=transform):
        if int(val) != 1:
            continue
        geoms.append(shapely_shape(geom))

    if not geoms:
        return None

    try:
        import geopandas as gpd
        import pandas as pd  # noqa: F401
    except Exception:
        return [{"geometry": g.__geo_interface__} for g in geoms]

    gdf = gpd.GeoDataFrame({"geometry": geoms}, crs=crs)

    # Area/perimeter are only meaningful in a projected CRS; values are in the CRS linear unit (assumed metres)
    is_projected = getattr(gdf.crs, "is_projected", False)
    if is_projected:
        gdf["area_m2"] = gdf.geometry.area.astype(float)
        gdf["perimeter_m"] = gdf.geometry.length.astype(float)
        if min_area_m2 and min_area_m2 > 0:
            gdf = gdf[gdf["area_m2"] >= float(min_area_m2)].copy()
    else:
        gdf["area_m2"] = np.nan
        gdf["perimeter_m"] = np.nan

    gdf = gdf.reset_index(drop=True)
    gdf["id"] = np.arange(1, len(gdf) + 1)
    gdf["prompts"] = ", ".join(PROMPTS)
    return gdf

gdf = mask_to_polygons(final_mask, transform=transform, crs=crs, min_area_m2=MIN_AREA_M2)

out_gpkg = OUT_DIR / f"{IMAGE_PATH.stem}_union.gpkg"

if gdf is None:
    print("No polygons extracted (mask empty after filtering).")
else:
    try:
        if hasattr(gdf, "to_file"):
            gdf.to_file(out_gpkg, layer="objects", driver="GPKG")
            print("GeoPackage saved:", out_gpkg.resolve(), "(layer='objects')")
        else:
            # GeoJSON fallback (coordinates stay in the raster CRS, not WGS84 as RFC 7946 expects)
            import json
            out_geojson = OUT_DIR / f"{IMAGE_PATH.stem}_union.geojson"
            fc = {
                "type": "FeatureCollection",
                "features": [{"type": "Feature", "properties": {}, **f} for f in gdf],
            }
            out_geojson.write_text(json.dumps(fc))
            print("GeoJSON saved (fallback):", out_geojson.resolve())
    except Exception as e:
        print("Vector export failed:", e)


## Quick QA: Visualization

- Overlay the union mask on top of RGB (quick sanity check)
- For proper QA: load the outputs in QGIS (GeoTIFF / GeoPackage)
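
If `matplotlib` is unavailable, or you want to save the preview to disk, the overlay can also be computed directly in numpy by alpha-blending a colour into the masked pixels only. `overlay_mask`, the red colour, and `alpha=0.35` are illustrative choices:

```python
import numpy as np


def overlay_mask(rgb: np.ndarray, mask: np.ndarray, color=(255, 0, 0), alpha: float = 0.35) -> np.ndarray:
    """Blend `color` into `rgb` where mask > 0; background pixels are returned unchanged."""
    out = rgb.astype(np.float32).copy()
    fg = mask > 0
    out[fg] = (1.0 - alpha) * out[fg] + alpha * np.asarray(color, dtype=np.float32)
    return np.clip(np.rint(out), 0, 255).astype(np.uint8)


# Tiny synthetic example: grey image, one detected pixel
img = np.full((4, 4, 3), 100, dtype=np.uint8)
msk = np.zeros((4, 4), dtype=np.uint8)
msk[0, 0] = 255
preview = overlay_mask(img, msk)
print(preview[0, 0], preview[1, 1])
```

To write the result with OpenCV, convert it first with `cv2.cvtColor(preview, cv2.COLOR_RGB2BGR)`, because `cv2.imwrite` expects BGR channel order.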

In [None]:
import matplotlib.pyplot as plt

# sample quick view (upper-left window)
H, W = rgb.shape[:2]
win = (slice(0, min(H, 1024)), slice(0, min(W, 1024)))

plt.figure(figsize=(10, 10))
plt.imshow(rgb[win])
plt.imshow(np.ma.masked_equal(final_mask[win], 0), alpha=0.35)  # hide background so only detections are tinted
plt.title("RGB + union mask (preview window)")
plt.axis("off")
plt.show()


## Operational notes / common tuning

- If the GPU runs out of memory:
  - lower `TILE_SIZE` (e.g. 512)
  - keep `USE_FP16 = True`
  - note that lowering `OVERLAP` (e.g. 64) only reduces the number of tiles and total runtime, not per-tile memory
- If you see too many false positives:
  - use tighter prompts (e.g. only `building,roof,garage`)
  - increase `MORPH_OPEN` (e.g. 3)
  - increase `MIN_AREA_M2` (e.g. 25–100)
- If tile seams are visible:
  - increase `OVERLAP` (128–256)
  - use moderate `MORPH_CLOSE` (3–5)
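
When tightening the prompt list, it helps to know how much each prompt adds to the union. The notebook only keeps the union per tile, so this sketch assumes you collect one binary mask per prompt yourself, for example by accumulating a separate mask for each `p` inside the prompt loop of `run_prompts_on_tile`. `prompt_contributions` is a hypothetical helper:

```python
from typing import Dict

import numpy as np


def prompt_contributions(masks_by_prompt: Dict[str, np.ndarray]) -> Dict[str, Dict[str, int]]:
    """Per prompt: total foreground pixels, and pixels that no other prompt covers."""
    stack = np.stack([m > 0 for m in masks_by_prompt.values()])  # (n_prompts, H, W) bool
    hits = stack.sum(axis=0)  # number of prompts covering each pixel
    stats = {}
    for i, name in enumerate(masks_by_prompt):
        fg = stack[i]
        stats[name] = {
            "pixels": int(fg.sum()),
            "unique_pixels": int((fg & (hits == 1)).sum()),
        }
    return stats


# Synthetic 4 x 4 example: "roof" fully contains "building"; "garage" is separate
building = np.zeros((4, 4), np.uint8); building[:2, :2] = 255
roof = np.zeros((4, 4), np.uint8); roof[:2, :3] = 255
garage = np.zeros((4, 4), np.uint8); garage[3, 3] = 255
stats = prompt_contributions({"building": building, "roof": roof, "garage": garage})
print(stats)
```

A prompt whose `unique_pixels` stays near zero adds little beyond the others and is a candidate for removal. A prompt with many unique pixels is worth checking visually for false positives.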