# Geo Building Pipeline (Notebook)

Interactive notebook version of `geo_building_pipeline.py`. It mirrors the CLI pipeline
but keeps results visible inline (mask overlays and polygon previews), so you can inspect
the output without leaving the notebook. Adjust the configuration cell below, then run the
pipeline cell to process your GeoTIFF tiles. Vector (polygon) export requires `geopandas`.


In [None]:

import logging
from pathlib import Path
from typing import Optional, Tuple

import numpy as np
import rasterio
import torch
from IPython.display import display
from PIL import Image
from rasterio.features import shapes
from shapely.geometry import shape as shapely_shape

from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor

try:
    import geopandas as gpd
except ImportError:  # geopandas is optional
    gpd = None

# Emit INFO-level progress messages (tile counts, per-tile mask stats) with timestamps.
# Note: per the logging docs, basicConfig() does nothing if the root logger already has
# handlers (e.g. configured by an imported library), so this format may not take effect.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")


In [None]:

# Configuration for the pipeline. Update paths/prompts as needed.
INPUT_DIR = Path("assets/geotiff_tiles")         # Folder scanned for .tif/.tiff GeoTIFF tiles
MASK_OUT_DIR = Path("out/notebook_masks")        # Labeled uint16 instance masks (GeoTIFF)
VECTOR_OUT_DIR = Path("out/notebook_vectors")    # Polygon GeoPackages; None skips export (needs geopandas)
OVERLAY_OUT_DIR = Path("out/notebook_overlays")  # PNG overlays; None skips saving them

CONCEPT = "building"  # Text prompt for SAM3
# Minimum polygon area in *squared* CRS units (e.g. square metres for a metric projected
# CRS, square degrees for a geographic CRS); smaller polygons are dropped.
MIN_AREA = 20.0
CONFIDENCE_THRESHOLD = 0.25  # Forwarded to Sam3Processor(confidence_threshold=...)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # Prefer the GPU when PyTorch sees one
MAX_TILES = 1          # Limit for quick previews; set to None to process all tiles
OVERWRITE = True       # If False, tiles whose mask file already exists are skipped


In [None]:

# Models already built in this kernel session, keyed by device string. Re-running this
# cell resets the cache (and therefore forces a reload on the next call).
_SAM3_MODEL_CACHE: dict = {}


def load_sam3_model(device: str):
    """Load the SAM3 image model on `device`, reusing an already-loaded instance.

    The first call for a device builds the model, switches it to eval mode and records the
    device on the model as ``_geo_device`` (read later by the processor factory). Later
    calls with the same device return that same model object instead of rebuilding it,
    so re-running the pipeline cell does not reload the weights every time.
    """
    model = _SAM3_MODEL_CACHE.get(device)
    if model is None:
        model = build_sam3_image_model(device=device)
        model.eval()
        model._geo_device = device
        _SAM3_MODEL_CACHE[device] = model
    return model


def _get_processor(model, confidence_threshold: float) -> Sam3Processor:
    """Return a Sam3Processor bound to `model`, rebuilding it only when the threshold changes.

    The processor, and the threshold it was created with, are stored as attributes on the
    model object, so repeated calls with an unchanged threshold share one instance. A new
    processor runs on the device recorded as ``model._geo_device``; if that attribute is
    absent, it uses CUDA when available and otherwise the CPU.
    """
    cached = getattr(model, "_geo_processor", None)
    if cached is not None and getattr(model, "_geo_conf_threshold", None) == confidence_threshold:
        return cached

    target_device = getattr(model, "_geo_device", None)
    if target_device is None:
        target_device = "cuda" if torch.cuda.is_available() else "cpu"

    fresh = Sam3Processor(
        model, device=target_device, confidence_threshold=confidence_threshold
    )
    model._geo_processor = fresh
    model._geo_conf_threshold = confidence_threshold
    return fresh


def _labels_from_instance_masks(masks_np: np.ndarray, scores_np: np.ndarray) -> np.ndarray:
    """Flatten (N, H, W) boolean instance masks into a single (H, W) uint16 label image.

    Instances are ranked by descending score and labelled 1..N in that order. Each pixel
    takes the label of the highest-ranked mask covering it. A mask completely hidden by
    better-scoring masks contributes no pixels, so its label is absent from the output.

    Raises:
        ValueError: if there are more instances than uint16 labels can represent.
    """
    if len(masks_np) > np.iinfo(np.uint16).max:
        raise ValueError(f"{len(masks_np)} instances exceed the uint16 label range")
    order = np.argsort(scores_np)[::-1]  # Highest score first
    labeled_mask = np.zeros(masks_np.shape[1:], dtype=np.uint16)
    for label, idx in enumerate(order, start=1):
        # Only claim still-unlabelled pixels, so higher-scoring instances win overlaps.
        labeled_mask[masks_np[idx] & (labeled_mask == 0)] = label
    return labeled_mask


def segment_buildings_for_tile(
    model,
    image_array: np.ndarray,
    concept: str,
    confidence_threshold: float = CONFIDENCE_THRESHOLD,
) -> np.ndarray:
    """
    Use SAM3 to perform concept segmentation on image_array for the given concept.

    Args:
        model: SAM3 image model (as returned by ``load_sam3_model``).
        image_array: image passed to ``PIL.Image.fromarray``; the pipeline supplies an
            (H, W, 3) uint8 array.
        concept: text prompt, e.g. "building".
        confidence_threshold: forwarded to the Sam3Processor.

    Returns:
        A 2D numpy array (uint16) with 0 as background and positive integers for each
        detected instance. Masks are sorted by score, and higher-scoring instances
        claim pixels first. The spatial size is that of the returned masks; when nothing
        is detected, an all-zero array of the image's (H, W) is returned.

    Raises:
        ValueError: if the processor returns a different number of masks and scores, or
            more instances than fit in uint16 labels.
    """
    processor = _get_processor(model, confidence_threshold=confidence_threshold)
    state = processor.set_image(Image.fromarray(image_array))
    state = processor.set_text_prompt(prompt=concept, state=state)

    masks = state.get("masks")
    scores = state.get("scores")
    if masks is None or scores is None or masks.numel() == 0:
        height, width = image_array.shape[:2]
        return np.zeros((height, width), dtype=np.uint16)

    # Collapse any leading/singleton dims to (N, H, W): handles (N, 1, H, W), (N, H, W)
    # and a lone (H, W) mask uniformly.
    mask_h, mask_w = masks.shape[-2:]
    masks_np = masks.detach().reshape(-1, mask_h, mask_w).cpu().numpy().astype(np.bool_)
    # .float(): NumPy has no bfloat16, so half-precision scores would fail in .numpy().
    scores_np = scores.detach().float().reshape(-1).cpu().numpy()
    if len(scores_np) != len(masks_np):
        raise ValueError(
            f"SAM3 returned {len(masks_np)} masks but {len(scores_np)} scores"
        )

    return _labels_from_instance_masks(masks_np, scores_np)


In [None]:

def mask_to_polygons(mask: np.ndarray, transform, crs, min_area: float):
    """Convert a labeled mask into polygons with georeferencing.

    Args:
        mask: 2D integer label image; 0 is background, every other value is an instance.
        transform: affine pixel-to-CRS transform passed to ``rasterio.features.shapes``.
        crs: coordinate reference system assigned to the output frame.
        min_area: polygons whose area (squared CRS units) is below this are dropped.

    Returns:
        GeoDataFrame with columns ``label``, ``area`` and ``geometry``. There is one row
        per connected region, so an instance split into disconnected parts yields
        several rows sharing the same label.

    Raises:
        ImportError: if geopandas is not installed.
    """
    if gpd is None:
        raise ImportError(
            "geopandas is required for vector export but is not installed"
        )

    label_image = mask.astype(np.int32)
    # Excluding background pixels up front stops rasterio from tracing the (often large,
    # hole-riddled) background polygon that would be discarded anyway.
    foreground = label_image != 0

    polygons, labels, areas = [], [], []
    for geom_mapping, value in shapes(label_image, mask=foreground, transform=transform):
        geometry = shapely_shape(geom_mapping)
        area = geometry.area
        if area < min_area:
            continue
        polygons.append(geometry)
        labels.append(int(value))
        areas.append(float(area))

    return gpd.GeoDataFrame(
        {"label": labels, "area": areas},
        geometry=polygons,
        crs=crs,
    )


def _prepare_image_array(dataset: rasterio.io.DatasetReader) -> np.ndarray:
    """
    Read a rasterio dataset and convert it to an 8-bit RGB-like array (H, W, 3) for SAM3.

    Band selection: datasets with more than 3 bands keep the first 3, 2-band datasets
    keep only the first band, and a single band is replicated to 3 channels. Each channel
    is then min/max-stretched independently to 0..255.

    Pixels that are non-finite (NaN/inf) or equal to ``dataset.nodata`` are excluded from
    the stretch statistics and rendered as 0. This keeps fill values such as -9999 from
    squashing the contrast of the real data, and keeps NaNs from corrupting a channel.
    """
    image = dataset.read()  # (C, H, W)
    band_count = image.shape[0]
    if band_count > 3:
        image = image[:3]
    elif band_count == 2:
        image = image[:1]

    image = np.moveaxis(image, 0, -1)  # (H, W, C) with C in {1, 3}
    if image.shape[2] == 1:
        image = np.repeat(image, 3, axis=2)

    img = image.astype(np.float32)
    valid = np.isfinite(img)
    nodata = dataset.nodata
    if nodata is not None and not np.isnan(nodata):
        valid &= image != nodata  # compare in the native dtype to avoid float rounding

    # Per-channel range over valid pixels only. A channel with no valid pixels keeps the
    # infinite initial values, which are neutralized to 0 below.
    mins = img.min(axis=(0, 1), keepdims=True, initial=np.inf, where=valid)
    maxs = img.max(axis=(0, 1), keepdims=True, initial=-np.inf, where=valid)
    mins = np.where(np.isfinite(mins), mins, np.float32(0))
    maxs = np.where(np.isfinite(maxs), maxs, np.float32(0))

    denom = np.clip(maxs - mins, 1e-6, None)
    with np.errstate(invalid="ignore"):
        scaled = (img - mins) / denom
    scaled = np.where(valid, scaled, np.float32(0))
    return np.clip(scaled * 255.0, 0, 255).astype(np.uint8)


def _normalize_image_for_overlay(image_array: np.ndarray) -> np.ndarray:
    img = image_array.astype(np.float32)
    mins = img.min(axis=(0, 1), keepdims=True)
    maxs = img.max(axis=(0, 1), keepdims=True)
    denom = np.clip(maxs - mins, 1e-6, None)
    img = (img - mins) / denom
    img = np.clip(img * 255.0, 0, 255).astype(np.uint8)
    return img


def _render_overlay(image_array: np.ndarray, mask: np.ndarray, alpha: float = 0.5) -> Image.Image:
    """Tint every labelled instance of `mask` on top of `image_array` and return a PIL image.

    The base image is contrast-stretched first, and label 0 (background) is left
    untouched. Each label gets a colour from a fixed-seed generator, drawn in ascending
    label order, so the same mask always renders with the same colours. `alpha` is the
    weight of the colour in the blend.
    """
    canvas = _normalize_image_for_overlay(image_array).copy()
    palette = np.random.default_rng(12345)
    present = np.unique(mask)
    for label_value in present[present != 0]:
        tint = palette.integers(0, 255, size=3, dtype=np.uint8).astype(np.float32)
        region = mask == label_value
        blended = (1.0 - alpha) * canvas[region].astype(np.float32) + alpha * tint
        canvas[region] = blended.astype(np.uint8)
    return Image.fromarray(canvas)


In [None]:

def _tile_output_name(stem: str, concept_slug: str, suffix: str) -> str:
    """Return ``<stem>_<concept_slug>_<suffix>``, or ``<stem>_<suffix>`` for an empty slug."""
    return f"{stem}_{concept_slug}_{suffix}" if concept_slug else f"{stem}_{suffix}"


def _find_tiles(input_dir: Path, max_tiles: Optional[int]) -> list:
    """Return sorted GeoTIFF files (.tif/.tiff, any letter case) in input_dir, capped at max_tiles.

    A missing input directory yields an empty list rather than an error.
    """
    if not input_dir.is_dir():
        logging.warning("Input directory %s is not a directory", input_dir)
        return []
    tile_paths = sorted(
        path
        for path in input_dir.iterdir()
        if path.is_file() and path.suffix.lower() in (".tif", ".tiff")
    )
    return tile_paths if max_tiles is None else tile_paths[:max_tiles]


def _mask_block_size(dim: int, limit: int = 512) -> int:
    """Return a GeoTIFF tile edge for a raster side of `dim` pixels.

    Tiled GeoTIFFs require block sizes that are multiples of 16. This rounds `dim` up to
    the next multiple of 16 (minimum 16) and caps the result at `limit`.
    """
    return min(limit, max(16, -(-dim // 16) * 16))


def _write_label_mask(mask_path: Path, labeled_mask: np.ndarray, profile) -> None:
    """Write `labeled_mask` as a single-band, LZW-compressed, tiled uint16 GeoTIFF.

    `profile` is the source tile's rasterio profile. Its georeferencing and size are
    reused, so the mask must have the source raster's (height, width).

    Raises:
        ValueError: if the mask shape differs from the source raster size.
    """
    height, width = labeled_mask.shape
    if (height, width) != (profile["height"], profile["width"]):
        raise ValueError(
            f"Mask shape {labeled_mask.shape} does not match source raster "
            f"({profile['height']}, {profile['width']})"
        )
    mask_profile = profile.copy()
    # Multi-band source settings such as PHOTOMETRIC=RGB/YCbCr are invalid for a 1-band mask.
    mask_profile.pop("photometric", None)
    mask_profile.update(
        dtype="uint16",
        count=1,
        nodata=0,
        compress="lzw",
        tiled=True,
        # Derived from the mask size rather than copied from the source: striped sources
        # report e.g. blockysize=1, which is not a valid (multiple-of-16) tile size.
        blockxsize=_mask_block_size(width),
        blockysize=_mask_block_size(height),
    )
    with rasterio.open(mask_path, "w", **mask_profile) as dst:
        dst.write(labeled_mask.astype(np.uint16), 1)


def run_geo_building_pipeline(
    input_dir: Path,
    mask_out_dir: Path,
    vector_out_dir: Optional[Path],
    overlay_out_dir: Optional[Path],
    concept: str,
    min_area: float,
    confidence_threshold: float,
    device: str,
    max_tiles: Optional[int],
    overwrite: bool,
    display_inline: bool = True,
    preview_polygons: int = 5,
) -> list:
    """
    Run SAM3 open-vocabulary segmentation on GeoTIFF tiles and keep results visible
    inside the notebook. Returns a list of result dictionaries (one per tile).

    Each result dict has these keys:
        tile: input tile path.
        mask_path: written label-mask GeoTIFF.
        overlay_image: PIL overlay, or None if neither saved nor displayed.
        vector_path: written vector file, or None if vector export was disabled or no
            polygon passed min_area.
        polygon_count: number of polygons (0 when vector export is disabled).
        mask_labels: highest instance label in the mask.

    Tiles that are skipped (existing mask with overwrite=False) or that fail are
    not included; failures are logged and processing continues with the next tile.
    The SAM3 model is loaded lazily, so nothing heavy is loaded if no tile needs
    processing. Errors while loading the model propagate to the caller.

    Raises:
        ImportError: if vector_out_dir is given but geopandas is not installed.
    """
    # Fail fast on a missing optional dependency, before any expensive work.
    if vector_out_dir and gpd is None:
        raise ImportError(
            "geopandas is required for vector export; install it or omit vector_out_dir"
        )
    mask_out_dir.mkdir(parents=True, exist_ok=True)
    if vector_out_dir:
        vector_out_dir.mkdir(parents=True, exist_ok=True)
    if overlay_out_dir:
        overlay_out_dir.mkdir(parents=True, exist_ok=True)

    concept_slug = concept.lower().replace(" ", "_")
    tile_paths = _find_tiles(input_dir, max_tiles)
    n_tiles = len(tile_paths)
    logging.info("Found %d tiles in %s", n_tiles, input_dir)

    model = None
    results = []
    for idx, tile_path in enumerate(tile_paths, start=1):
        mask_path = mask_out_dir / _tile_output_name(tile_path.stem, concept_slug, "mask.tif")
        # Decide on skipping before touching the raster, so skipped tiles cost no I/O.
        if not overwrite and mask_path.exists():
            logging.info(
                "[%d/%d] Skipping %s; mask exists and overwrite is False",
                idx,
                n_tiles,
                tile_path.name,
            )
            continue

        vector_path = None
        if vector_out_dir:
            vector_path = vector_out_dir / _tile_output_name(
                tile_path.stem, concept_slug, "polygons.gpkg"
            )

        if model is None:
            logging.info("Loading SAM3 model on device %s", device)
            model = load_sam3_model(device=device)

        try:
            with rasterio.open(tile_path) as src:
                image_array = _prepare_image_array(src)
                profile = src.profile
                crs = src.crs
                transform = src.transform

            logging.info("[%d/%d] Processing %s", idx, n_tiles, tile_path.name)
            labeled_mask = segment_buildings_for_tile(
                model=model,
                image_array=image_array,
                concept=concept,
                confidence_threshold=confidence_threshold,
            )
            n_labels = int(labeled_mask.max())
            logging.info(
                "Mask stats for %s: labels=%d (confidence_threshold=%.3f)",
                tile_path.name,
                n_labels,
                confidence_threshold,
            )
            _write_label_mask(mask_path, labeled_mask, profile)

            overlay_img = None
            if overlay_out_dir or display_inline:
                overlay_img = _render_overlay(image_array, labeled_mask)
                if overlay_out_dir:
                    overlay_img.save(
                        overlay_out_dir
                        / _tile_output_name(tile_path.stem, concept_slug, "overlay.png")
                    )

            gdf = None
            written_vector_path = None
            if vector_path is not None:
                gdf = mask_to_polygons(
                    mask=labeled_mask,
                    transform=transform,
                    crs=crs,
                    min_area=min_area,
                )
                if len(gdf) == 0:
                    logging.info(
                        "No polygons above min_area=%s for %s", min_area, tile_path.name
                    )
                else:
                    driver = "GPKG" if vector_path.suffix.lower() == ".gpkg" else "GeoJSON"
                    gdf.to_file(vector_path, driver=driver)
                    written_vector_path = vector_path

            result = {
                "tile": tile_path,
                "mask_path": mask_path,
                "overlay_image": overlay_img,
                # Only report a vector file that was actually written in this run.
                "vector_path": written_vector_path,
                "polygon_count": 0 if gdf is None else len(gdf),
                "mask_labels": n_labels,
            }

            if display_inline:
                print(
                    f"{tile_path.name}: labels={result['mask_labels']} -> mask saved to {mask_path.name}"
                )
                if overlay_img is not None:
                    print("Overlay preview:")
                    display(overlay_img)
                if gdf is not None:
                    preview_rows = min(len(gdf), preview_polygons)
                    print(f"Polygon preview (first {preview_rows}):")
                    display(gdf.head(preview_rows))

            results.append(result)
        except Exception as exc:
            # Batch boundary: log with traceback and move on to the next tile.
            logging.exception("Failed processing %s: %s", tile_path, exc)

    return results


In [None]:

# Run the pipeline. Set display_inline=False if you only want files on disk.
# `results` gets one dict per successfully processed tile; tiles that are skipped
# (mask exists and OVERWRITE is False) or that fail are not included. Leaving `results`
# as the last expression makes the notebook render it below the cell.
results = run_geo_building_pipeline(
    input_dir=INPUT_DIR,
    mask_out_dir=MASK_OUT_DIR,
    vector_out_dir=VECTOR_OUT_DIR,
    overlay_out_dir=OVERLAY_OUT_DIR,
    concept=CONCEPT,
    min_area=MIN_AREA,
    confidence_threshold=CONFIDENCE_THRESHOLD,
    device=DEVICE,
    max_tiles=MAX_TILES,
    overwrite=OVERWRITE,
    display_inline=True,
    preview_polygons=5,  # number of polygon rows shown per tile
)
results
