# Production-Grade Satellite Segmentation (DeepGlobe)

This notebook rebuilds the original U-Net workflow into a production-grade, GPU-efficient training and inference pipeline for 1024×1024 images, optimized for NVIDIA L4 (24 GB) and comparable GPUs.


Audience:
- ML engineers or applied researchers deploying semantic segmentation.

Prerequisites:
- Familiarity with PyTorch, CUDA, and image segmentation.
- Dataset available locally with `metadata.csv` and `class_dict.csv`.

By the end you will:
- Train a 1024×1024 model with mixed precision and robust metrics.
- Run tiled inference and compute class-change analytics.
- Save reproducible artifacts for production use.


## Outline

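0. Setup and configuration
1. Load metadata and class mapping
2. Dataset, encoding, and augmentations
3. Build splits and loaders
4. Metrics (mIoU and per-class IoU)
5. Model options (U-Net and SegFormer)
6. Training loop with AMP, gradient accumulation, and checkpointing
7. Inference and visualization
8. Sliding-window inference for large images
9. Change detection from predicted masks

The notebook closes with exercises, pitfalls, and extensions.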


## Step 0 - Setup and configuration

This cell configures the environment, GPU settings, and paths. Set `Config.data_root` to your dataset location; `DATA_ROOT` is derived from it.


In [None]:
from __future__ import annotations

import os
import json
import math
import random
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Dict, Tuple, List

import numpy as np
import pandas as pd
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import models
import albumentations as A
from tqdm import tqdm

# Optional: transformers v5 for SegFormer
try:
    from transformers import SegformerForSemanticSegmentation
    TRANSFORMERS_AVAILABLE = True
except Exception:
    TRANSFORMERS_AVAILABLE = False

# --------------------
# Reproducibility
# --------------------
# One seed is pushed into every RNG used by the pipeline: Python, NumPy,
# Torch (CPU) and Torch (all CUDA devices), in that order.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)

# --------------------
# CUDA / performance
# --------------------
# Prefer the GPU when one is visible; otherwise everything runs on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
_ON_GPU = DEVICE.type == "cuda"

if _ON_GPU:
    # Let cuDNN autotune convolution algorithms and allow TF32 math for
    # matmuls and convolutions.
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.set_float32_matmul_precision("high")

# Report the software/hardware stack for the run log.
print("Torch:", torch.__version__)
print("CUDA:", torch.version.cuda)
if _ON_GPU:
    print("GPU:", torch.cuda.get_device_name(0))

@dataclass
class Config:
    """Run configuration for the notebook (data locations, loader and optimizer settings).

    A single instance (`CFG`) is created below and read by the data loading,
    model selection, and training cells.
    """

    # Directory against which the CSV names and per-row image/mask paths are resolved.
    data_root: str = "."
    # File names (relative to data_root) of the metadata and class-color tables.
    metadata_csv: str = "metadata.csv"
    class_dict_csv: str = "class_dict.csv"
    # Output directory; created at startup and used for the best checkpoint.
    run_dir: str = "runs/l4_1024"
    # Side length (pixels) of the square crop taken by the train/val transforms.
    image_size: int = 1024
    # DataLoader batch size and worker-process count.
    batch_size: int = 2
    num_workers: int = 4
    # AdamW learning rate and weight decay.
    lr: float = 2e-4
    weight_decay: float = 1e-4
    # Number of training epochs; also used as T_max of the cosine LR schedule.
    epochs: int = 40
    # The loss is divided by this and the optimizer steps once per this many batches.
    grad_accum_steps: int = 2
    # Mixed precision switch; it only takes effect when running on CUDA.
    amp: bool = True
    model_type: str = "unet_resnet34"  # options: unet_resnet34, segformer_b2
    # NOTE(review): not referenced by any visible cell of this notebook — confirm intended use.
    save_every: int = 1
CFG = Config()

# Turn the configured locations into Path objects. The run directory is
# created right away (including missing parents) so later writes can rely on it.
DATA_ROOT, RUN_DIR = Path(CFG.data_root), Path(CFG.run_dir)
os.makedirs(RUN_DIR, exist_ok=True)

print(f"Device: {DEVICE}")
print(f"Run dir: {RUN_DIR.resolve()}")


## Step 1 - Load metadata and class mapping

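This step loads `metadata.csv` (image and mask paths plus the `split` column) and `class_dict.csv` (class names and RGB colors, where row order defines the class index). The splits themselves are built in Step 3; if validation masks are missing, a deterministic validation split is created there from the training set.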


In [None]:
# Both tables live directly under the dataset root.
metadata_path = DATA_ROOT / CFG.metadata_csv
class_dict_path = DATA_ROOT / CFG.class_dict_csv

# Fail early, naming the exact path that was looked up.
if not metadata_path.exists():
    raise FileNotFoundError(f"Missing metadata.csv at {metadata_path}")
if not class_dict_path.exists():
    raise FileNotFoundError(f"Missing class_dict.csv at {class_dict_path}")

meta_df, class_df = pd.read_csv(metadata_path), pd.read_csv(class_dict_path)

# The row position in class_dict.csv is the class index.
id2label = dict(enumerate(class_df["name"].tolist()))
label2id = {name: class_idx for class_idx, name in id2label.items()}

# RGB triplets per class, plus the reverse lookup used when encoding masks.
colors = class_df[["r", "g", "b"]].to_numpy().tolist()
color_to_id = {tuple(rgb): class_idx for class_idx, rgb in enumerate(colors)}

NUM_CLASSES = len(colors)
print(f"Classes: {NUM_CLASSES}")
print(id2label)


## Step 2 - Dataset, encoding, and augmentations

Masks are encoded into class indices (0..C-1). We keep augmentation minimal and label-safe: crops, flips, and 90° rotations never interpolate, so mask labels stay exact, and color jitter is applied to the image only.


In [None]:
def encode_mask(mask_rgb: np.ndarray, color_to_id: Dict[Tuple[int, int, int], int]) -> np.ndarray:
    """Convert an RGB-coded mask into a 2-D array of class indices.

    Args:
        mask_rgb: Array whose last axis holds the color channels.
        color_to_id: Mapping from an (r, g, b) tuple to its class index.

    Returns:
        An int64 array with the first two dimensions of ``mask_rgb``. Pixels
        whose color is not a key of ``color_to_id`` keep the initial value 0.
    """
    height, width = mask_rgb.shape[0], mask_rgb.shape[1]
    encoded = np.zeros((height, width), dtype=np.int64)
    for rgb, class_idx in color_to_id.items():
        # A pixel matches only when every channel equals the class color.
        matches = np.all(mask_rgb == rgb, axis=-1)
        encoded[matches] = class_idx
    return encoded

def decode_mask(mask: np.ndarray, colors: List[List[int]]) -> np.ndarray:
    """Paint a class-index mask with the per-class colors.

    Args:
        mask: 2-D array of class indices.
        colors: Color triplets, where position ``i`` is the color of class ``i``.

    Returns:
        A uint8 array of shape (H, W, 3). Indices with no entry in ``colors``
        stay black (all zeros).
    """
    rows, cols = mask.shape[0], mask.shape[1]
    canvas = np.zeros((rows, cols, 3), dtype=np.uint8)
    for class_idx, rgb in enumerate(colors):
        canvas[mask == class_idx] = rgb
    return canvas

class DeepGlobeDataset(Dataset):
    """Dataset over metadata rows that yields image tensors and, optionally, class-index masks.

    Each row must provide ``sat_image_path`` (relative to ``root``). When
    ``with_masks`` is true and the row has a non-empty string ``mask_path``,
    the RGB mask is read and encoded with the module-level ``color_to_id``.

    Items are ``(image, mask, image_id)`` when a mask was loaded and
    ``(image, image_id)`` otherwise. ``image`` is a float32 CHW tensor scaled
    by 1/255 and ``mask`` is an int64 tensor; ``image_id`` falls back to the
    integer index when the row has no such field.
    """

    def __init__(self, df: pd.DataFrame, root: Path, transforms=None, with_masks: bool = True):
        self.df = df.reset_index(drop=True)
        self.root = Path(root)
        self.transforms = transforms
        self.with_masks = with_masks

    def __len__(self):
        return len(self.df)

    def _load_image(self, record) -> np.ndarray:
        """Read the satellite image of ``record`` as an RGB array."""
        img_path = self.root / record["sat_image_path"]
        bgr = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
        if bgr is None:
            raise FileNotFoundError(f"Image not found: {img_path}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def _load_mask(self, record):
        """Return the encoded mask of ``record``, or None when masks are off or the path is empty."""
        if not self.with_masks:
            return None
        rel_path = record.get("mask_path", None)
        if not isinstance(rel_path, str) or not len(rel_path):
            return None
        mask_path = self.root / rel_path
        bgr = cv2.imread(str(mask_path), cv2.IMREAD_COLOR)
        if bgr is None:
            raise FileNotFoundError(f"Mask not found: {mask_path}")
        return encode_mask(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), color_to_id)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        image = self._load_image(record)
        mask = self._load_mask(record)

        if self.transforms is not None:
            if mask is None:
                image = self.transforms(image=image)["image"]
            else:
                # Image and mask go through the same call so they stay aligned.
                transformed = self.transforms(image=image, mask=mask)
                image, mask = transformed["image"], transformed["mask"]

        # HWC values in 0..255 -> float32 CHW in 0..1.
        image = torch.from_numpy(image.astype(np.float32) / 255.0).permute(2, 0, 1)
        sample_id = record.get("image_id", idx)

        if mask is None:
            return image, sample_id
        return image, torch.from_numpy(mask.astype(np.int64)), sample_id

# Training pipeline: a random square crop followed by flips, quarter-turn
# rotations, and occasional color jitter.
train_tfms = A.Compose(
    [
        A.RandomCrop(height=CFG.image_size, width=CFG.image_size),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.2),
        A.RandomRotate90(p=0.5),
        A.ColorJitter(p=0.2),
    ]
)

# Validation/test pipeline: a fixed center crop of the same size, no randomness.
val_tfms = A.Compose(
    [
        A.CenterCrop(height=CFG.image_size, width=CFG.image_size),
    ]
)


## Step 3 - Build splits and loaders

If validation masks are missing, we create a deterministic split from the training set. Test set is used for inference only.


In [None]:
from sklearn.model_selection import train_test_split

# Missing mask paths become "" so the string operations below are well defined.
meta_df["mask_path"] = meta_df["mask_path"].fillna("")

# One independent frame per value of the "split" column.
train_df, valid_df, test_df = (
    meta_df[meta_df["split"] == split_name].copy()
    for split_name in ("train", "valid", "test")
)

valid_has_masks = valid_df["mask_path"].str.len().gt(0).any()

if not valid_has_masks:
    # No labelled validation rows: hold out 20% of the training rows instead,
    # seeded so the split is the same on every run.
    train_df, valid_df = train_test_split(
        train_df,
        test_size=0.2,
        random_state=SEED,
        shuffle=True,
    )

train_ds = DeepGlobeDataset(train_df, DATA_ROOT, transforms=train_tfms, with_masks=True)
val_ds = DeepGlobeDataset(valid_df, DATA_ROOT, transforms=val_tfms, with_masks=True)
test_ds = DeepGlobeDataset(test_df, DATA_ROOT, transforms=val_tfms, with_masks=False)

# Settings shared by both loaders; only the shuffle flag differs.
_loader_kwargs = dict(
    batch_size=CFG.batch_size,
    num_workers=CFG.num_workers,
    pin_memory=True,
    persistent_workers=(CFG.num_workers > 0),
)

train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_kwargs)

print(f"Train: {len(train_ds)} | Val: {len(val_ds)} | Test: {len(test_ds)}")


## Step 4 - Metrics (mIoU + per-class IoU)


In [None]:
@torch.no_grad()
def compute_confusion(preds: torch.Tensor, targets: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Build a confusion matrix from predicted and target class indices.

    Args:
        preds: Integer tensor of predicted class indices (any shape).
        targets: Integer tensor of ground-truth class indices with the same
            number of elements as ``preds``.
        num_classes: Number of classes ``C``.

    Returns:
        A ``(C, C)`` integer tensor where entry ``[t, p]`` counts pixels with
        target ``t`` predicted as ``p``. Pixels whose target lies outside
        ``[0, C)`` (e.g. an ignore label) are skipped.
    """
    # reshape (not view) so non-contiguous inputs such as transposed or
    # sliced tensors are accepted instead of raising.
    preds = preds.reshape(-1)
    targets = targets.reshape(-1)
    valid = (targets >= 0) & (targets < num_classes)
    # Each (target, pred) pair is folded into one flat bin index.
    hist = torch.bincount(
        num_classes * targets[valid] + preds[valid],
        minlength=num_classes ** 2,
    ).reshape(num_classes, num_classes)
    return hist

@torch.no_grad()
def compute_iou(confusion: torch.Tensor) -> Tuple[float, Dict[int, float]]:
    """Derive mean IoU and per-class IoU from a confusion matrix.

    Args:
        confusion: ``(C, C)`` tensor of counts with targets on rows and
            predictions on columns.

    Returns:
        ``(miou, per_class)``. ``per_class[i]`` is the IoU of class ``i``, or
        NaN when the class appears in neither targets nor predictions.
        ``miou`` averages only the non-NaN classes and is 0.0 when no class
        is present at all.
    """
    # Float64 on the CPU keeps large pixel counts exact during the division.
    counts = confusion.detach().cpu().to(torch.float64)
    tp = torch.diag(counts)
    # union = TP + FP + FN = column sum + row sum - TP
    union = counts.sum(0) + counts.sum(1) - tp
    # 0/0 yields NaN for absent classes, so they can be left out of the mean
    # rather than being scored as 0.
    iou = (tp / union).numpy()
    present = ~np.isnan(iou)
    miou = float(iou[present].mean()) if present.any() else 0.0
    per_class = {i: float(value) for i, value in enumerate(iou)}
    return miou, per_class


## Step 5 - Model options

We provide a production-ready U-Net (ResNet34 encoder) and an optional SegFormer model (transformers v5). Use `CFG.model_type` to select.


In [None]:
class UNetResNet34(nn.Module):
    """U-Net style decoder on top of a torchvision ResNet-34 encoder.

    The encoder stages produce features at 1/2, 1/4, 1/8, 1/16 and 1/32 of the
    input resolution. Each decoder stage doubles the resolution with a
    transposed convolution, concatenates the matching encoder feature, and
    fuses them with a 3x3 convolution + ReLU. A final upsampling stage (no
    skip connection) and a 1x1 convolution give per-class logits.
    """

    def __init__(self, num_classes: int):
        super().__init__()
        backbone = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)

        # Encoder: stem, then the four residual stages.
        self.layer0 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu)
        self.pool = backbone.maxpool
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

        # Decoder: each `up*` doubles the resolution, each `dec*` fuses the
        # upsampled features with the skip connection of equal channel width.
        self.up4 = self._upsample(512, 256)
        self.dec4 = self._conv_relu(256 + 256, 256)

        self.up3 = self._upsample(256, 128)
        self.dec3 = self._conv_relu(128 + 128, 128)

        self.up2 = self._upsample(128, 64)
        self.dec2 = self._conv_relu(64 + 64, 64)

        self.up1 = self._upsample(64, 64)
        self.dec1 = self._conv_relu(64 + 64, 64)

        # Last stage has no encoder feature to concatenate.
        self.up0 = self._upsample(64, 32)
        self.dec0 = self._conv_relu(32, 32)

        self.out = nn.Conv2d(32, num_classes, kernel_size=1)

    @staticmethod
    def _upsample(in_channels: int, out_channels: int) -> nn.ConvTranspose2d:
        """2x learned upsampling."""
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    @staticmethod
    def _conv_relu(in_channels: int, out_channels: int) -> nn.Sequential:
        """3x3 same-padding convolution followed by an in-place ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        skip2 = self.layer0(x)                    # 1/2
        skip4 = self.layer1(self.pool(skip2))     # 1/4
        skip8 = self.layer2(skip4)                # 1/8
        skip16 = self.layer3(skip8)               # 1/16
        feats = self.layer4(skip16)               # 1/32

        # Walk back up, pairing every stage with its encoder skip.
        stages = (
            (self.up4, self.dec4, skip16),
            (self.up3, self.dec3, skip8),
            (self.up2, self.dec2, skip4),
            (self.up1, self.dec1, skip2),
        )
        for upsample, fuse, skip in stages:
            feats = fuse(torch.cat([upsample(feats), skip], dim=1))

        feats = self.dec0(self.up0(feats))
        return self.out(feats)


def build_model() -> nn.Module:
    """Instantiate the architecture named by ``CFG.model_type``.

    Returns:
        A ``UNetResNet34`` for ``"unet_resnet34"``, or a pretrained SegFormer
        with its head resized to ``NUM_CLASSES`` for ``"segformer_b2"``.

    Raises:
        ValueError: If ``CFG.model_type`` is not one of the supported names.
        RuntimeError: If SegFormer is requested but ``transformers`` could not
            be imported.
    """
    kind = CFG.model_type

    if kind == "unet_resnet34":
        return UNetResNet34(NUM_CLASSES)

    if kind != "segformer_b2":
        raise ValueError(f"Unknown model type: {CFG.model_type}")

    if not TRANSFORMERS_AVAILABLE:
        raise RuntimeError("transformers v5 is not available. Install transformers>=5 to use SegFormer.")

    # The checkpoint head has a different class count, so mismatched sizes
    # are tolerated when loading.
    return SegformerForSemanticSegmentation.from_pretrained(
        "nvidia/segformer-b2-finetuned-ade-512-512",
        num_labels=NUM_CLASSES,
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,
    )


## Step 6 - Training loop (AMP, gradient accumulation, checkpointing)


In [None]:
# Build the selected architecture and place it on the training device.
model = build_model()
model = model.to(DEVICE)

# Cross-entropy over class-index targets, AdamW, and a cosine learning-rate
# decay spanning the whole run.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=CFG.lr,
    weight_decay=CFG.weight_decay,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG.epochs)

# Loss scaling is active only when AMP is requested and a GPU is in use.
scaler = torch.cuda.amp.GradScaler(enabled=(CFG.amp and DEVICE.type == "cuda"))

# Best validation mIoU seen so far; starts below any reachable score.
best_miou = -1.0

def forward_logits(model, images):
    """Run ``model`` on ``images`` and return the raw logits tensor.

    Handles both plain tensor outputs and output objects that expose a
    ``.logits`` attribute.

    This function deliberately does NOT disable autograd: it is used on the
    training path, where gradients must flow through the forward pass.
    Inference callers wrap the call in ``torch.no_grad()`` themselves.

    Raises:
        RuntimeError: If the output is neither a tensor nor has ``.logits``.
    """
    out = model(images)
    if isinstance(out, torch.Tensor):
        return out
    if hasattr(out, "logits"):
        return out.logits
    raise RuntimeError("Unexpected model output type")


def train_one_epoch(epoch: int):
    """Train the global ``model`` for one pass over ``train_loader``.

    Uses mixed precision (when ``CFG.amp`` is set and the device is CUDA) and
    gradient accumulation over ``CFG.grad_accum_steps`` batches.

    Args:
        epoch: Epoch number, used only for the progress-bar label.

    Returns:
        The mean per-batch loss (unscaled, so it is directly comparable to
        the validation loss).
    """
    model.train()
    use_amp = CFG.amp and DEVICE.type == "cuda"
    # Guard against a zero/negative setting, which would break the modulo below.
    accum_steps = max(1, CFG.grad_accum_steps)
    num_batches = len(train_loader)
    running_loss = 0.0

    # Start from clean gradients so nothing carries over from earlier work.
    optimizer.zero_grad(set_to_none=True)

    for step, batch in enumerate(tqdm(train_loader, desc=f"Train {epoch}")):
        images, masks, _ = batch
        images = images.to(DEVICE, non_blocking=True)
        masks = masks.to(DEVICE, non_blocking=True)

        with torch.autocast(device_type=DEVICE.type, enabled=use_amp):
            logits = forward_logits(model, images)
            # Some models emit logits at reduced resolution; match the mask size.
            if logits.shape[-2:] != masks.shape[-2:]:
                logits = F.interpolate(logits, size=masks.shape[-2:], mode="bilinear", align_corners=False)
            loss = criterion(logits, masks)

        # Only the backward pass sees the accumulation-scaled loss.
        scaler.scale(loss / accum_steps).backward()

        # Step at every full accumulation window, and also on the final batch
        # so a trailing partial window is applied instead of leaking into the
        # next epoch.
        is_last_batch = (step + 1) == num_batches
        if (step + 1) % accum_steps == 0 or is_last_batch:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        running_loss += loss.item()

    return running_loss / max(1, num_batches)

@torch.no_grad()
def evaluate(epoch: int):
    """Evaluate the global ``model`` on ``val_loader``.

    Args:
        epoch: Epoch number, used only for the progress-bar label.

    Returns:
        ``(mean_loss, miou, per_class_iou)``, where the IoU values come from
        a confusion matrix accumulated over the whole validation set.
    """
    model.eval()
    # Integer counts: a float32 accumulator stops being exact above 2**24
    # pixels, which a handful of 1024x1024 images already exceeds.
    confusion = torch.zeros((NUM_CLASSES, NUM_CLASSES), dtype=torch.int64, device=DEVICE)
    running_loss = 0.0

    for batch in tqdm(val_loader, desc=f"Val {epoch}"):
        images, masks, _ = batch
        images = images.to(DEVICE, non_blocking=True)
        masks = masks.to(DEVICE, non_blocking=True)

        logits = forward_logits(model, images)
        # Some models emit logits at reduced resolution; match the mask size.
        if logits.shape[-2:] != masks.shape[-2:]:
            logits = F.interpolate(logits, size=masks.shape[-2:], mode="bilinear", align_corners=False)

        loss = criterion(logits, masks)
        running_loss += loss.item()

        preds = torch.argmax(logits, dim=1)
        confusion += compute_confusion(preds, masks, NUM_CLASSES)

    miou, per_class = compute_iou(confusion)
    return running_loss / max(1, len(val_loader)), miou, per_class

# Main loop: train, validate, advance the LR schedule, and keep the
# checkpoint with the highest validation mIoU.
for epoch in range(1, CFG.epochs + 1):
    train_loss = train_one_epoch(epoch)
    val_loss, val_miou, per_class = evaluate(epoch)
    scheduler.step()

    print(f"Epoch {epoch:02d} | train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | mIoU={val_miou:.4f}")

    # Written as a negated ">" so a NaN score is never treated as an improvement.
    if not (val_miou > best_miou):
        continue

    best_miou = val_miou
    ckpt_path = RUN_DIR / "best_model.pt"
    checkpoint = {
        "model_state": model.state_dict(),
        "config": asdict(CFG),
        "miou": val_miou,
    }
    torch.save(checkpoint, ckpt_path)
    print(f"Saved best model to {ckpt_path}")


## Step 7 - Inference and visualization

This step loads the best checkpoint, runs a prediction, and saves a colorized mask.


In [None]:
import matplotlib.pyplot as plt

# Restore the best weights and switch to inference behaviour.
ckpt = torch.load(RUN_DIR / "best_model.pt", map_location=DEVICE)
model.load_state_dict(ckpt["model_state"])
model.eval()

if len(test_ds) == 0:
    raise RuntimeError("Test split is empty; there is no image to run inference on.")

sample_image, sample_id = test_ds[0]
image = sample_image.unsqueeze(0).to(DEVICE)

with torch.no_grad():
    logits = forward_logits(model, image)
    # Some models emit logits at reduced resolution; bring them back to the
    # input size so the predicted mask lines up with the image.
    if logits.shape[-2:] != image.shape[-2:]:
        logits = F.interpolate(logits, size=image.shape[-2:], mode="bilinear", align_corners=False)
    preds = torch.argmax(logits, dim=1)[0].cpu().numpy()

pred_color = decode_mask(preds, colors)

# Persist the colorized mask next to the checkpoint. OpenCV writes BGR,
# so convert from the RGB palette first.
pred_path = RUN_DIR / f"{sample_id}_pred.png"
if not cv2.imwrite(str(pred_path), cv2.cvtColor(pred_color, cv2.COLOR_RGB2BGR)):
    raise OSError(f"Could not write predicted mask to {pred_path}")
print(f"Saved predicted mask to {pred_path}")

plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.title("Input")
plt.imshow(sample_image.permute(1, 2, 0).numpy())
plt.axis("off")

plt.subplot(1, 2, 2)
plt.title("Predicted Mask")
plt.imshow(pred_color)
plt.axis("off")
plt.show()


## Step 8 - Sliding window inference for large images

Use this to run inference on 2048×2048 images with a 1024×1024 model.


In [None]:
@torch.no_grad()
def sliding_window_predict(image: np.ndarray, tile_size: int = 1024, overlap: int = 128) -> np.ndarray:
    """Predict a class-index mask for a large image by averaging overlapping tiles.

    The image is covered with ``tile_size`` x ``tile_size`` windows placed
    every ``tile_size - overlap`` pixels. Windows that run past the border are
    reflect-padded to full size. Softmax probabilities are summed per pixel,
    divided by the number of windows covering that pixel, and the arg-max
    class is returned.

    Args:
        image: HxWxC array with values in 0..255 (scaled by 1/255 before the
            forward pass).
        tile_size: Side length of each square window.
        overlap: Number of pixels shared by neighbouring windows.

    Returns:
        An (H, W) array of class indices.

    Raises:
        ValueError: If ``tile_size`` is not positive or ``overlap`` is not in
            ``[0, tile_size)``.
    """
    if tile_size <= 0:
        raise ValueError(f"tile_size must be positive; got {tile_size}")
    if not 0 <= overlap < tile_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < tile_size; got overlap={overlap}, tile_size={tile_size}"
        )

    h, w, _ = image.shape
    stride = tile_size - overlap
    full_probs = np.zeros((NUM_CLASSES, h, w), dtype=np.float32)
    count = np.zeros((h, w), dtype=np.float32)

    # Inference must not run with train-mode layers; restore the mode afterwards.
    was_training = model.training
    model.eval()
    try:
        for y1 in range(0, h, stride):
            y2 = min(y1 + tile_size, h)
            for x1 in range(0, w, stride):
                x2 = min(x1 + tile_size, w)
                tile = image[y1:y2, x1:x2]

                pad_bottom = tile_size - (y2 - y1)
                pad_right = tile_size - (x2 - x1)
                if pad_bottom > 0 or pad_right > 0:
                    tile = cv2.copyMakeBorder(tile, 0, pad_bottom, 0, pad_right, cv2.BORDER_REFLECT_101)

                tile_t = torch.from_numpy(tile.astype(np.float32) / 255.0).permute(2, 0, 1).unsqueeze(0).to(DEVICE)
                logits = forward_logits(model, tile_t)
                # Some models emit logits at reduced resolution; upsample so
                # the crop below indexes real tile pixels.
                if logits.shape[-2:] != tile_t.shape[-2:]:
                    logits = F.interpolate(logits, size=tile_t.shape[-2:], mode="bilinear", align_corners=False)
                probs = torch.softmax(logits, dim=1)[0].cpu().numpy()

                # Drop the padded margin before accumulating.
                probs = probs[:, : (y2 - y1), : (x2 - x1)]
                full_probs[:, y1:y2, x1:x2] += probs
                count[y1:y2, x1:x2] += 1.0

                # This window already reached the right edge; later starts
                # would only re-predict the same pixels from mostly padded input.
                if x2 >= w:
                    break
            if y2 >= h:
                break
    finally:
        if was_training:
            model.train()

    full_probs /= np.maximum(count, 1e-6)
    pred = np.argmax(full_probs, axis=0)
    return pred


## Step 9 - Change detection using class-index masks

This avoids HSV heuristics and computes changes from class indices directly.


In [None]:
def class_percentages(mask: np.ndarray, num_classes: int) -> Dict[int, float]:
    """Return the share of pixels (in percent) belonging to each class.

    Args:
        mask: Array of class indices.
        num_classes: Number of classes; keys of the result are ``0..num_classes-1``.

    Returns:
        Mapping from class index to percentage of ``mask`` elements equal to
        that index. Values outside ``[0, num_classes)`` count towards the
        total but towards no class. An empty mask gives 0.0 for every class.
    """
    mask = np.asarray(mask)
    total = mask.size
    if total == 0:
        # Avoid 0/0 (NaN) when there are no pixels at all.
        return {c: 0.0 for c in range(num_classes)}
    return {
        c: float(np.count_nonzero(mask == c) / total * 100.0)
        for c in range(num_classes)
    }

def compare_masks(mask_a: np.ndarray, mask_b: np.ndarray, id2label: Dict[int, str]):
    """Tabulate per-class coverage of two masks and the change between them.

    Args:
        mask_a: Class-index mask for the first period.
        mask_b: Class-index mask for the second period.
        id2label: Class index to display name; a missing index is shown as
            its number.

    Returns:
        A DataFrame with one row per class (``NUM_CLASSES`` rows) and the
        columns ``class``, ``period_1``, ``period_2`` and ``difference``
        (period 2 minus period 1), all in percent of pixels.
    """
    before = class_percentages(mask_a, NUM_CLASSES)
    after = class_percentages(mask_b, NUM_CLASSES)
    return pd.DataFrame(
        [
            {
                "class": id2label.get(class_idx, str(class_idx)),
                "period_1": before[class_idx],
                "period_2": after[class_idx],
                "difference": after[class_idx] - before[class_idx],
            }
            for class_idx in range(NUM_CLASSES)
        ]
    )


## Exercises

1. Switch `CFG.model_type` to `segformer_b2` and compare mIoU.
2. Increase `CFG.image_size` to 1536 and tune batch size + grad accumulation.
3. Add class weights to `CrossEntropyLoss` based on training mask histograms.


In [None]:
# Exercise scaffold
# TODO: Compute class weights from training masks and re-train the model.


## Pitfalls and extensions

Common pitfall:
- Resizing masks with bilinear interpolation. Always use nearest neighbor for class labels.

Extensions:
- Add TensorBoard logging and model export (TorchScript or ONNX).
- Add multi-GPU training with DDP if you move to larger instances.
