# 🛰️ End-to-End Deep Learning Pipeline for Road Network Extraction from Satellite Imagery

This project implements a complete, state-of-the-art deep learning pipeline for **semantic segmentation of road networks** from high-resolution satellite imagery. Using the **SpaceNet Roads Challenge dataset**, this work goes beyond a simple model implementation to tackle the complex, real-world challenges of **geospatial data processing, model optimization, and post-processing** to generate clean, connected road graphs.

✅ The final model achieves a validation **IoU of ~0.60**, demonstrating strong performance in identifying road pixels.  
🛠️ More importantly, the project includes a **post-processing pipeline** to convert raw pixel-level predictions into a **coherent road network**, suitable for real-world applications.

Imports

In [None]:
import shutil
import random
import torch.optim.lr_scheduler as lr_scheduler
import torchvision
import torchmetrics
from torchvision import transforms, datasets
from torchvision.transforms import ToPILImage
from PIL import Image
from pytorch_lightning.callbacks import EarlyStopping ,ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader, TensorDataset, Dataset, random_split
from pathlib import Path
import rasterio
from rasterio.features import rasterize
from rasterio.warp import calculate_default_transform
import geopandas as gpd
import torch.nn as nn
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.losses import DiceLoss
from torchmetrics.classification import BinaryJaccardIndex
import cv2
import sknw
import torch
import pytorch_lightning as pl
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
import matplotlib.pyplot as plt
from skimage.morphology import skeletonize, remove_small_objects
from scipy.ndimage import label
import os
import re

Download the data

In [None]:
def download_and_prepare_data(output_dir="data"):
    """
    Downloads and extracts the SpaceNet 3 dataset for Paris (Train and Test).
    
    Args:
        output_dir (str): The directory to download and extract the data into.
    """
    # Create the output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    print(f"Data will be downloaded and extracted to: '{output_dir}'")

    # List of files to download
    s3_files_to_download = [
        "s3://spacenet-dataset/spacenet/SN3_roads/tarballs/SN3_roads_train_AOI_3_Paris.tar.gz",
        "s3://spacenet-dataset/spacenet/SN3_roads/tarballs/SN3_roads_train_AOI_3_Paris_geojson_roads_speed.tar.gz",
        "s3://spacenet-dataset/spacenet/SN3_roads/tarballs/SN3_roads_test_public_AOI_3_Paris.tar.gz"
    ]

    tar_files_to_extract = [
        "SN3_roads_train_AOI_3_Paris.tar.gz",
        "SN3_roads_train_AOI_3_Paris_geojson_roads_speed.tar.gz",
        "SN3_roads_test_public_AOI_3_Paris.tar.gz"
    ]

    # --- 1. Download Files ---
    for s3_path in s3_files_to_download:
        filename = os.path.basename(s3_path)
        local_path = os.path.join(output_dir, filename)
        
        if os.path.exists(local_path):
            print(f"File '{filename}' already exists. Skipping download.")
        else:
            command = ["aws", "s3", "cp", s3_path, local_path, "--no-sign-request"]
            run_command(command)

    # --- 2. Extract Files ---
    # The final directory that will be created
    final_data_dir = os.path.join(output_dir, "AOI_3_Paris")
    
    # Check if the data has already been extracted
    if os.path.exists(final_data_dir):
         print(f"Directory '{final_data_dir}' already exists. Assuming data is extracted. Skipping extraction.")
    else:
        for filename in tar_files_to_extract:
            local_path = os.path.join(output_dir, filename)
            command = ["tar", "-xzvf", local_path]
            # We run the command from within the output directory
            run_command(command, working_dir=output_dir)

    print("\nData preparation complete!")
    print(f"Your data should be located in: '{final_data_dir}'")

In [None]:
# Fetch the SpaceNet 3 Paris tarballs into ./data and unpack them; files and
# directories that already exist are skipped.
download_and_prepare_data(output_dir="data")

Explore the data

In [None]:
def contrast_stretch(band, lower_percentile=2, upper_percentile=98):
    """
    Linearly rescale a single image band to [0, 1] using percentile clipping.

    Values at or below the ``lower_percentile`` map to 0 and values at or above
    the ``upper_percentile`` map to 1. This keeps a few extreme pixels from
    compressing the visible dynamic range.

    Args:
        band (np.ndarray): Array holding one band.
        lower_percentile (float, optional): Lower clipping percentile. Defaults to 2.
        upper_percentile (float, optional): Upper clipping percentile. Defaults to 98.

    Returns:
        np.ndarray: Float array of the same shape with values in [0, 1]. If both
        percentiles coincide (e.g. a constant or all-nodata band), all zeros are
        returned instead of NaN/inf from a division by zero.
    """
    band = np.asarray(band)
    low, high = np.percentile(band, (lower_percentile, upper_percentile))
    if np.isclose(high, low):
        return np.zeros(band.shape, dtype=np.float64)
    return np.clip((band - low) / (high - low), 0, 1)

# Explore one pan-sharpened RGB (PS-RGB) training tile: read all of its bands,
# contrast-stretch the first three and display them as an RGB composite.
# PS-RGB file path
file_path = 'data/AOI_3_Paris/PS-RGB/SN3_roads_train_AOI_3_Paris_PS-RGB_img111.tif'

# Read every band as a (bands, rows, cols) array and keep the georeferencing
# metadata. These globals are not used further in this cell and are reassigned
# by the mask-rasterization cell below.
with rasterio.open(file_path) as src:
    image = src.read()
    sat_meta = src.meta.copy()
    width = src.width
    height = src.height
    transform = src.transform
    crs = src.crs



# Apply stretching to each of the first three bands.
# NOTE(review): assumes the file has at least 3 bands; fewer would raise an IndexError here.
stretched_image = np.stack([contrast_stretch(image[i]) for i in range(3)])

# Check the number of bands
print(f"Number of bands: {image.shape[0]}")

# Plot the stretched image; matplotlib expects channels last, hence the transpose.
plt.figure(figsize=(6, 6))
plt.imshow(stretched_image.transpose(1, 2, 0))
plt.title("Contrast-Stretched TIF Image (RGB Composite)")
plt.axis("off")
plt.show()

In [None]:
# Load the pan-sharpened multispectral (PS-MS) tile and display its near-infrared band.
# SpaceNet 3 PS-MS tiles are 8-band WorldView-3 products ordered
# Coastal, Blue, Green, Yellow, Red, Red Edge, NIR1, NIR2, so band 4 is Yellow
# and NIR1 is band 7. For 4-band products (B, G, R, NIR) NIR is band 4.
image_path = "data/AOI_3_Paris/PS-MS/SN3_roads_train_AOI_3_Paris_PS-MS_img111.tif"
with rasterio.open(image_path) as src:
    nir_band_index = 7 if src.count >= 8 else 4
    nir_band = src.read(nir_band_index)

print(f"Using band {nir_band_index} as NIR")

# Apply contrast stretching to the NIR band
stretched_nir = contrast_stretch(nir_band)

# Plot the NIR band as a grayscale image
plt.figure(figsize=(6, 6))
plt.imshow(stretched_nir, cmap="gray")
plt.title("NIR Band (Contrast-Stretched)")
plt.axis("off")
plt.show()

In [None]:
# Load the road centerlines of tile img118 and compare their CRS with the CRS of
# the image of the *same* tile, so the comparison is meaningful.
geojson_path = "data/AOI_3_Paris/geojson_roads_speed/SN3_roads_train_AOI_3_Paris_geojson_roads_speed_img118.geojson"
image_path = "data/AOI_3_Paris/PS-RGB/SN3_roads_train_AOI_3_Paris_PS-RGB_img118.tif"

gdf = gpd.read_file(geojson_path)

# Only the CRS of the image is needed here.
with rasterio.open(image_path) as src:
    image_crs = src.crs

# Plotting the geometries
gdf.plot(figsize=(5, 5), edgecolor='black', alpha=0.5)
plt.title("GeoJSON Visualization")
plt.show()

# Check the CRS of the image and GeoDataFrame
print(f"Image CRS: {image_crs}")
print(f"GeoDataFrame CRS: {gdf.crs}")

# Align the CRS if necessary. A GeoJSON without CRS metadata cannot be
# reprojected, so it is assumed to already be in the image CRS.
if gdf.crs is None:
    gdf = gdf.set_crs(image_crs)
    print(f"GeoDataFrame had no CRS; assigned the image CRS: {gdf.crs}")
elif gdf.crs != image_crs:
    gdf = gdf.to_crs(image_crs)
    print(f"GeoDataFrame CRS after alignment: {gdf.crs}")

In [None]:
# Rasterize the buffered road centerlines of tile img341 onto the pixel grid of
# the matching RGB image and show both side by side.
image_path = "data/AOI_3_Paris/PS-RGB/SN3_roads_train_AOI_3_Paris_PS-RGB_img341.tif"

with rasterio.open(image_path) as source:
    image = source.read([1, 2, 3])
    sat_meta = source.meta.copy()
    width = source.width
    height = source.height
    transform = source.transform
    crs = source.crs

# Apply stretching to each band
stretched_image = np.stack([contrast_stretch(image[i]) for i in range(3)])

# Load the road centerlines (vector GeoJSON) of the same tile
gdf = gpd.read_file("data/AOI_3_Paris/geojson_roads_speed/SN3_roads_train_AOI_3_Paris_geojson_roads_speed_img341.geojson")

# Buffer in a metric CRS so the distance is in meters.
# EPSG:32631 is UTM zone 31N, which covers Paris.
projected_crs = "EPSG:32631"
gdf_proj = gdf.to_crs(projected_crs)

# Buffer in meters (2 m on each side of the centerline)
buffered = gdf_proj.geometry.buffer(2)

# Reproject back to the raster's own CRS (not an assumed EPSG:4326) so the
# geometries line up with `transform`.
buffered_raster_crs = buffered.to_crs(crs)

# Prepare (geometry, burn value) pairs for rasterization: 1 for roads, 0 elsewhere
geoms = [(geom, 1) for geom in buffered_raster_crs if geom is not None and not geom.is_empty]

# rasterize() cannot handle an empty shape list (tiles without roads)
if geoms:
    mask = rasterize(
        geoms,
        out_shape=(height, width),
        transform=transform,
        fill=0,
        dtype='uint8'
    )
else:
    mask = np.zeros((height, width), dtype=np.uint8)

# Plot the satellite image and mask
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

# Stretched satellite image
axes[0].imshow(stretched_image.transpose(1, 2, 0))
axes[0].set_title(f"Contrast-Stretched TIF Image (RGB Composite) \n Shape{stretched_image.shape}")
axes[0].axis("off")

# Road mask
axes[1].imshow(mask, cmap="gray")
axes[1].set_title(f"Road Mask \n Shape{mask.shape}")
axes[1].axis("off")

plt.tight_layout()
plt.show()

Dataset and the PyTorch Lightning DataModule

In [None]:
class GeoImageDataset(Dataset):
    """
    PyTorch Dataset for loading and preprocessing satellite images and their
    corresponding road network masks from GeoJSON files.

    This class handles:
    - Pairing of image (.tif/.tiff) and mask (.geojson) files by the numeric
      ``img<N>`` token in their filenames. Files without such a token are ignored.
    - Rasterizing buffered road centerlines onto the exact pixel grid of the
      image, reprojecting between Coordinate Reference Systems where needed.
    - Applying data augmentations using the Albumentations library.
    - Zero-padding images and masks to at least ``pad_size`` x ``pad_size``
      (1312 by default, which is divisible by 32) for encoder/decoder models.
    """

    def __init__(self, image_dir, mask_dir, augmentations=None, buffer_distance_meters=3.0,
                 pad_size=1312, metric_crs="EPSG:32631"):
        """
        Args:
            image_dir (str): Path to the directory containing image files (.tif/.tiff).
            mask_dir (str): Path to the directory containing mask files (.geojson).
            augmentations (albumentations.Compose, optional): An Albumentations
                pipeline to apply to the image and mask. Defaults to None.
            buffer_distance_meters (float, optional): Distance the road centerlines
                are buffered by on each side before rasterization, in the units of
                the working CRS (meters for a UTM CRS). Defaults to 3.0.
            pad_size (int, optional): Minimum height and width after zero padding.
                Defaults to 1312.
            metric_crs (str, optional): Projected CRS used for buffering when the
                labels end up in a geographic (degree-based) CRS. Defaults to
                "EPSG:32631" (UTM zone 31N, covering Paris); change it for areas
                of interest in other UTM zones.

        Raises:
            FileNotFoundError: If either directory does not exist.
        """
        if not os.path.isdir(image_dir):
            raise FileNotFoundError(f"Image directory not found: {image_dir}")
        if not os.path.isdir(mask_dir):
            raise FileNotFoundError(f"Mask directory not found: {mask_dir}")

        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.augmentations = augmentations
        self.buffer_distance_meters = buffer_distance_meters
        self.pad_size = pad_size
        self.metric_crs = metric_crs

        self.image_files = sorted(f for f in os.listdir(image_dir) if f.endswith((".tif", ".tiff")))
        self.mask_files = sorted(f for f in os.listdir(mask_dir) if f.endswith(".geojson"))

        # Tile index -> mask filename. Files without an index are skipped so they
        # cannot all collide under a shared `None` key and get paired arbitrarily.
        self.index_to_mask = {}
        for mask_file in self.mask_files:
            index = self.extract_index(mask_file)
            if index is not None:
                self.index_to_mask[index] = mask_file

        self.paired_files = []
        for image_file in self.image_files:
            index = self.extract_index(image_file)
            if index is not None and index in self.index_to_mask:
                self.paired_files.append((image_file, self.index_to_mask[index]))

        if not self.paired_files:
            print(f"CRITICAL WARNING: No image-mask pairs found after attempting to match filenames in {image_dir}.")

    def extract_index(self, filename):
        """Return the digits following 'img' in a filename (e.g. '123' for '..._img123.tif'), or None."""
        base_name = os.path.splitext(filename)[0]
        match = re.search(r'img(\d+)', base_name)
        return match.group(1) if match else None

    def process_image(self, image_path):
        """
        Loads a satellite image, performs per-band 2-98 percentile contrast
        stretching, and returns it as a NumPy array along with its geographic metadata.

        Only the first three bands are read. Images with fewer than three bands
        are padded with all-zero bands.

        Args:
            image_path (str): The path to the image file.

        Returns:
            tuple: A tuple containing:
                - np.ndarray: The processed image as a uint8 NumPy array (H, W, 3).
                - dict: Geographic information with keys "crs", "transform",
                  "height" and "width".
        """
        with rasterio.open(image_path) as src:
            geo_info = {"crs": src.crs, "transform": src.transform, "height": src.height, "width": src.width}
            image_data = src.read(list(range(1, min(src.count, 3) + 1))).astype(np.float32)
            if image_data.shape[0] < 3:
                padding = np.zeros((3 - image_data.shape[0], geo_info["height"], geo_info["width"]), dtype=np.float32)
                image_data = np.concatenate([image_data, padding], axis=0)

        stretched_bands = []
        for band in image_data:
            min_val, max_val = np.percentile(band, (2, 98))
            if np.isclose(max_val, min_val):
                # Constant band: avoid a division by zero
                stretched = np.zeros_like(band)
            else:
                stretched = (band - min_val) / (max_val - min_val)
            stretched_bands.append(np.clip(stretched, 0, 1))

        image_hwc = np.stack(stretched_bands).transpose(1, 2, 0)
        return (image_hwc * 255).astype(np.uint8), geo_info

    def rasterize_mask(self, mask_path, image_geo_info):
        """
        Creates a binary raster mask from a GeoJSON file, aligned with the pixel
        grid of its source image, using a "project - buffer - reproject" workflow.

        The labels are brought into the image CRS. If that CRS is geographic,
        they are projected to ``self.metric_crs`` so the buffer distance is in
        meters. They are then buffered and burned into the image grid. Any pixel
        touched by a buffered road is set to 1.

        Args:
            mask_path (str): The path to the GeoJSON mask file.
            image_geo_info (dict): Metadata from ``process_image`` ("crs",
                "transform", "height", "width").

        Returns:
            np.ndarray: The rasterized binary mask as a uint8 NumPy array (H, W).
            An all-zero mask is returned for tiles without valid roads. It is also
            returned, with a warning, when the labels cannot be read or reprojected.
        """
        import warnings

        height, width = image_geo_info["height"], image_geo_info["width"]
        image_crs = image_geo_info["crs"]
        empty_mask = np.zeros((height, width), dtype=np.uint8)

        # Best effort: an unreadable label file becomes an all-background mask,
        # but the problem is reported instead of silently swallowed.
        try:
            gdf = gpd.read_file(mask_path)
        except Exception as err:
            warnings.warn(f"Could not read '{mask_path}' ({err}); using an empty mask.")
            return empty_mask

        if gdf.crs is None:
            # No CRS metadata: assume the labels are already in the image CRS.
            if image_crs is not None:
                gdf = gdf.set_crs(image_crs)
        elif gdf.crs != image_crs:
            try:
                gdf = gdf.to_crs(image_crs)
            except Exception as err:
                warnings.warn(f"Could not reproject '{mask_path}' to the image CRS ({err}); using an empty mask.")
                return empty_mask

        # A buffer distance in degrees would be meaningless, so buffer in a projected CRS.
        if gdf.crs is not None and gdf.crs.is_geographic:
            try:
                gdf = gdf.to_crs(self.metric_crs)
            except Exception as err:
                warnings.warn(f"Could not project '{mask_path}' to {self.metric_crs} ({err}); using an empty mask.")
                return empty_mask

        valid_gdf = gdf[gdf.geometry.is_valid & ~gdf.geometry.is_empty]
        if valid_gdf.empty:
            return empty_mask  # e.g. a tile without roads

        buffered = valid_gdf.geometry.buffer(self.buffer_distance_meters)
        if buffered.crs != image_crs:
            buffered = buffered.to_crs(image_crs)

        return rasterize(
            buffered,
            out_shape=(height, width),
            transform=image_geo_info["transform"],
            fill=0,
            dtype='uint8',
            all_touched=True,
        )

    def __len__(self):
        return len(self.paired_files)

    def __getitem__(self, idx):
        """
        Returns:
            tuple(torch.Tensor, torch.Tensor): ``(image, mask)`` as float tensors.
            The image is channels-first (3, H, W) and the mask is (H, W) in both
            code paths, matching what ``ToTensorV2`` produces by default. Without
            augmentations, the image keeps its 0-255 range (it is not normalized).
        """
        image_name, mask_name = self.paired_files[idx]
        image_np, geo_info = self.process_image(os.path.join(self.image_dir, image_name))
        mask_np = self.rasterize_mask(os.path.join(self.mask_dir, mask_name), geo_info)

        pad_transform = A.PadIfNeeded(min_height=self.pad_size, min_width=self.pad_size, border_mode=cv2.BORDER_CONSTANT)
        padded = pad_transform(image=image_np, mask=mask_np)
        image_np_padded, mask_np_padded = padded['image'], padded['mask']

        if self.augmentations:
            augmented = self.augmentations(image=image_np_padded, mask=mask_np_padded)
            image_tensor, mask_tensor = augmented['image'], augmented['mask']
        else:
            image_tensor = torch.from_numpy(np.ascontiguousarray(image_np_padded.transpose(2, 0, 1)))
            mask_tensor = torch.from_numpy(mask_np_padded)

        return image_tensor.float(), mask_tensor.float()

class GeoImageDataModule(pl.LightningDataModule):
    """
    PyTorch Lightning DataModule for the road segmentation task.

    All image/mask pairs found in ``image_dir``/``mask_dir`` are split into a
    training subset and a validation subset with a seeded random permutation,
    so the split is reproducible. Only training and validation dataloaders are
    provided.
    """

    def __init__(self, image_dir, mask_dir, augmentations_train=None, augmentations_val=None, batch_size=4,
                 num_workers=0, split_perc=0.8, split_seed=42, **kwargs):
        """
        Args:
            image_dir (str): Directory with the image tiles.
            mask_dir (str): Directory with the GeoJSON road labels.
            augmentations_train (albumentations.Compose, optional): Training pipeline.
            augmentations_val (albumentations.Compose, optional): Validation pipeline.
            batch_size (int, optional): Samples per batch. Defaults to 4.
            num_workers (int, optional): Dataloader worker processes. Defaults to 0.
            split_perc (float, optional): Fraction of pairs used for training. Defaults to 0.8.
            split_seed (int, optional): Seed of the permutation defining the split. Defaults to 42.
            **kwargs: Accepted for backward compatibility and ignored.
        """
        super().__init__()
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.augmentations_train = augmentations_train
        self.augmentations_val = augmentations_val
        self.split_perc = split_perc
        self.split_seed = split_seed

    def setup(self, stage=None):
        """Build the training and validation subsets."""
        # Two datasets over the same sorted file pairs that differ only in their
        # augmentations, so a given index refers to the same tile in both.
        train_dataset_full = GeoImageDataset(self.image_dir, self.mask_dir, augmentations=self.augmentations_train)
        val_dataset_full = GeoImageDataset(self.image_dir, self.mask_dir, augmentations=self.augmentations_val)

        num_paired_files = len(train_dataset_full)
        train_size = int(self.split_perc * num_paired_files)

        generator = torch.Generator().manual_seed(self.split_seed)
        indices = torch.randperm(num_paired_files, generator=generator).tolist()

        self.train_dataset = torch.utils.data.Subset(train_dataset_full, indices[:train_size])
        self.val_dataset = torch.utils.data.Subset(val_dataset_full, indices[train_size:])

    def _make_dataloader(self, dataset, shuffle, drop_last):
        """Create a DataLoader with the loading settings shared by both splits."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
            persistent_workers=self.num_workers > 0,  # only valid with worker processes
            pin_memory=torch.cuda.is_available(),  # pinning only helps host-to-GPU copies
            drop_last=drop_last,
        )

    def train_dataloader(self):
        # Dropping the last incomplete training batch avoids a tiny final batch
        # (presumably chosen to keep BatchNorm statistics stable).
        return self._make_dataloader(self.train_dataset, shuffle=True, drop_last=True)

    def val_dataloader(self):
        # Every validation tile is kept: dropping the last batch would silently
        # exclude samples from val_iou and from the evaluation.
        return self._make_dataloader(self.val_dataset, shuffle=False, drop_last=False)


Model

In [None]:
class GeoImageSegmentModel(pl.LightningModule):
    """
    A PyTorch Lightning module for binary road segmentation from satellite images.

    Wraps a segmentation-models-pytorch U-Net with a single output channel of
    logits. It is trained with ``0.8 * Dice + 0.2 * BCE`` and optimized with
    AdamW plus a cosine annealing learning-rate schedule. Pixel IoU is logged as
    ``train_iou`` / ``val_iou``.
    """

    def __init__(self, lr=1e-4, weight_decay=1e-4, encoder_name="resnet50", encoder_weights="imagenet"):
        """
        Initializes the model, loss functions, and metrics.

        Args:
            lr (float, optional): The learning rate for the optimizer. Defaults to 1e-4.
            weight_decay (float, optional): The weight decay for the AdamW optimizer. Defaults to 1e-4.
            encoder_name (str, optional): The name of the encoder backbone to use from
                segmentation-models-pytorch. Defaults to "resnet50".
            encoder_weights (str or None, optional): Pre-trained encoder weights passed
                to ``smp.Unet``. Defaults to "imagenet". Use ``None`` to skip the
                download, e.g. when all weights come from a checkpoint.
        """
        super().__init__()
        # Save hyperparameters to the checkpoint, allowing for easy reloading
        self.save_hyperparameters()

        self.model = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=3,
            classes=1,  # Binary output: road or not road
        )

        # Two components of the combination loss; both operate on raw logits.
        self.dice_loss = DiceLoss(mode='binary', from_logits=True)
        self.bce_loss = nn.BCEWithLogitsLoss()

        # One metric per stage: a torchmetrics Metric accumulates state between
        # calls, so sharing one instance would mix training and validation statistics.
        self.train_iou = BinaryJaccardIndex()
        self.val_iou = BinaryJaccardIndex()

    def forward(self, x):
        """
        Performs a forward pass through the model.

        Args:
            x (torch.Tensor): The input batch of images (B, 3, H, W).

        Returns:
            torch.Tensor: The raw output logits (B, 1, H, W).
        """
        return self.model(x)

    def _common_step(self, batch, batch_idx, stage):
        """
        Shared logic of the training and validation steps.

        Args:
            batch (tuple): The images and masks. Masks may be (B, H, W) or (B, 1, H, W).
            batch_idx (int): The index of the current batch.
            stage (str): The current stage, either "train" or "val".

        Returns:
            torch.Tensor: The calculated loss for the batch.
        """
        images, masks = batch
        if masks.ndim == 3:
            masks = masks.unsqueeze(1)
        masks = masks.float()

        logits = self(images)

        # Dice is weighted more heavily to cope with the road/background class imbalance.
        loss = 0.8 * self.dice_loss(logits, masks) + 0.2 * self.bce_loss(logits, masks)

        preds_binary = torch.sigmoid(logits) > 0.5

        # Calling the metric updates its running state and caches the batch value.
        # Logging the Metric object (not a tensor) lets Lightning compute and reset
        # it at epoch end, giving a dataset-level IoU rather than a mean of batch IoUs.
        metric = self.train_iou if stage == "train" else self.val_iou
        metric(preds_binary, masks.int())

        batch_size = images.size(0)
        self.log(f'{stage}_loss', loss, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size)
        self.log(f'{stage}_iou', metric, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size)

        return loss

    def training_step(self, batch, batch_idx):
        """Performs a single training step."""
        return self._common_step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        """Performs a single validation step."""
        return self._common_step(batch, batch_idx, "val")

    def configure_optimizers(self):
        """
        Configures the optimizer (AdamW) and learning rate scheduler (CosineAnnealingLR).

        The cosine schedule spans ``trainer.max_epochs`` epochs and decays to 1e-6.

        Returns:
            dict: A dictionary containing the optimizer and the LR scheduler configuration.
        """
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=self.trainer.max_epochs,
            eta_min=1e-6,
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",  # Step the scheduler at the end of each epoch
            },
        }

Train

In [None]:
def train_model(
    image_dir,
    mask_dir,
    checkpoint_dir,
    batch_size=4,
    num_workers=2,
    lr=1e-4,
    weight_decay=1e-3,
    encoder_name="resnet34",
    max_epochs=100,
    patience=10,
    precision="16-mixed",
    augs_train=None,
    augs_val=None,
    resume_from_checkpoint=None
):
    """
    A complete training function that encapsulates the entire training workflow.

    Args:
        image_dir (str): Path to the directory with training/validation images.
        mask_dir (str): Path to the directory with training/validation masks.
        checkpoint_dir (str): Path to the directory where checkpoints will be saved.
        batch_size (int, optional): Batch size for training. Defaults to 4.
        num_workers (int, optional): Number of workers for the dataloader. Defaults to 2.
        lr (float, optional): Learning rate. Defaults to 1e-4.
        weight_decay (float, optional): Weight decay for the optimizer. Defaults to 1e-3.
        encoder_name (str, optional): Encoder to use for the model. Defaults to "resnet34".
        max_epochs (int, optional): Maximum number of epochs to train for. Defaults to 100.
        patience (int, optional): Epochs without val_iou improvement before early stopping. Defaults to 10.
        precision (int or str, optional): Training precision passed to ``pl.Trainer``
            (e.g. 32, "16-mixed", "bf16-mixed"). Defaults to "16-mixed".
        augs_train (A.Compose, optional): Custom training augmentations. If None, uses a strong default.
        augs_val (A.Compose, optional): Custom validation augmentations. If None, uses a default.
        resume_from_checkpoint (str, optional): Path to a checkpoint to resume training from. Defaults to None.

    Returns:
        str: Path of the best checkpoint by ``val_iou``, or an empty string if none was saved.
    """
    # ImageNet statistics, matching the pre-trained encoder weights.
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)

    if augs_train is None:
        augs_train = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomRotate90(p=0.3),
            A.Affine(
                scale=(0.9, 1.1),
                translate_percent=0.05,
                rotate=(-15, 15),
                p=0.5,
                border_mode=cv2.BORDER_CONSTANT
            ),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.3),
            A.GaussianBlur(blur_limit=(3, 7), p=0.1),
            A.Normalize(mean=imagenet_mean, std=imagenet_std),
            ToTensorV2()
        ])

    if augs_val is None:
        augs_val = A.Compose([
            A.Normalize(mean=imagenet_mean, std=imagenet_std),
            ToTensorV2()
        ])

    data_module = GeoImageDataModule(
        image_dir=image_dir,
        mask_dir=mask_dir,
        batch_size=batch_size,
        num_workers=num_workers,
        augmentations_train=augs_train,
        augmentations_val=augs_val
    )

    model = GeoImageSegmentModel(lr=lr, weight_decay=weight_decay, encoder_name=encoder_name)

    # Keep only the single best checkpoint by validation IoU.
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename=f"{encoder_name}-best-model-{{epoch:02d}}-{{val_iou:.4f}}",
        save_top_k=1,
        monitor="val_iou",
        mode="max"
    )
    early_stopping_callback = EarlyStopping(monitor="val_iou", patience=patience, mode="max")

    trainer = pl.Trainer(
        accelerator="auto",
        devices=1,
        precision=precision,
        max_epochs=max_epochs,
        callbacks=[checkpoint_callback, early_stopping_callback]
    )

    trainer.fit(
        model,
        datamodule=data_module,
        ckpt_path=resume_from_checkpoint
    )

    print(f"Best checkpoint: {checkpoint_callback.best_model_path}")
    return checkpoint_callback.best_model_path

In [None]:
# Where checkpoints are written, and where the training tiles and their road
# annotations live.
CHECKPOINT_DIR = "checkpoints"
IMAGE_DIR = "data/AOI_3_Paris/PS-RGB"
MASK_DIR = "data/AOI_3_Paris/geojson_roads"

# Launch training with train_model's default hyperparameters.
train_model(image_dir=IMAGE_DIR, mask_dir=MASK_DIR, checkpoint_dir=CHECKPOINT_DIR)

Evaluate

In [None]:
def run_evaluation(
    use_close,
    close_kernel_size,
    use_min_object_size,
    min_object_size,
    image_dir,
    mask_dir,
    checkpoint_path,
    max_batches=None,
    batch_size=4,
    num_workers=2,
):
    """
    Loads a trained model from a checkpoint, runs it over the validation split,
    and computes the IoU for both the raw and post-processed predictions.

    Args:
        use_close (bool): If True, applies morphological closing.
        close_kernel_size (int): The size of the kernel for morphological closing.
        use_min_object_size (bool): If True, removes small objects from the thick mask.
        min_object_size (int): The minimum pixel area for an object to be kept.
        image_dir (str): Path to the directory of evaluation images.
        mask_dir (str): Path to the directory of evaluation masks.
        checkpoint_path (str): Path to the .ckpt file of the trained model.
        max_batches (int, optional): Limits the number of batches to evaluate.
            Defaults to None (the whole validation split).
        batch_size (int, optional): Batch size for evaluation. Defaults to 4.
        num_workers (int, optional): Number of dataloader workers. Defaults to 2.

    Returns:
        torch.Tensor: The IoU of the post-processed predictions.
    """
    try:
        from tqdm.auto import tqdm
    except ImportError:  # the progress bar is optional
        def tqdm(iterable, **_kwargs):
            return iterable

    # --- 1. Load Model ---
    model = GeoImageSegmentModel.load_from_checkpoint(checkpoint_path=checkpoint_path)
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f"Model loaded and moved to {device}")

    # --- 2. Prepare Data ---
    # Only the validation split is evaluated, so no training augmentations are needed.
    # The split itself is fixed by the DataModule's seed, independent of augmentations.
    augs_val = A.Compose([
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2()
    ])
    data_module = GeoImageDataModule(
        image_dir=image_dir,
        mask_dir=mask_dir,
        batch_size=batch_size,
        num_workers=num_workers,
        augmentations_train=None,
        augmentations_val=augs_val
    )
    data_module.setup()
    val_dataloader = data_module.val_dataloader()

    # --- 3. Run Evaluation Loop ---
    raw_iou_metric = BinaryJaccardIndex().to(device)
    post_processed_iou_metric = BinaryJaccardIndex().to(device)

    print("\nStarting evaluation over the validation set...")
    for batch_idx, batch in enumerate(tqdm(val_dataloader, desc="Evaluating")):
        if max_batches is not None and batch_idx >= max_batches:
            break

        images, ground_truth_masks = batch
        images = images.to(device)
        # Bring masks to (B, 1, H, W) to match the model output.
        if ground_truth_masks.ndim == 3:
            ground_truth_masks = ground_truth_masks.unsqueeze(1)
        ground_truth_masks = ground_truth_masks.to(device).int()

        with torch.no_grad():
            logits = model(images)
            raw_predictions = torch.sigmoid(logits) > 0.5

        raw_iou_metric.update(raw_predictions, ground_truth_masks)

        # NOTE(review): post_process_mask is defined elsewhere in the notebook;
        # it is assumed to return a 2-D mask with the same shape as its input.
        post_processed_preds_batch = [
            post_process_mask(
                raw_predictions[i].squeeze().cpu().numpy(),
                use_close=use_close,
                close_kernel_size=close_kernel_size,
                use_min_object_size=use_min_object_size,
                min_object_size=min_object_size,
            )
            for i in range(images.shape[0])
        ]

        pp_preds_tensor = torch.from_numpy(np.stack(post_processed_preds_batch)).unsqueeze(1).to(device)
        post_processed_iou_metric.update(pp_preds_tensor, ground_truth_masks)

    # --- 4. Print Final Results ---
    # BinaryJaccardIndex pools intersection/union over all evaluated pixels
    # (a micro-averaged IoU), not a mean of per-image IoUs.
    avg_raw_iou = raw_iou_metric.compute()
    avg_pp_iou = post_processed_iou_metric.compute()

    print("\n--- Evaluation Complete ---")
    print(f"Average RAW Prediction IoU:      {avg_raw_iou.item():.4f}")
    print(f"Average POST-PROCESSED IoU:      {avg_pp_iou.item():.4f}")

    return avg_pp_iou

def plot_evaluation_results(
    CHECKPOINT_PATH,
    image_dir,
    mask_dir,
    post_process_params=None,
):
    """
    Loads a trained model from a checkpoint, runs it over the first batch of the
    validation split, and plots four columns per image: the original image, the
    ground truth mask, the raw prediction and the post-processed prediction.

    Args:
        CHECKPOINT_PATH (str): Path to the .ckpt file of the trained model.
        image_dir (str): Path to the directory of evaluation images.
        mask_dir (str): Path to the directory of evaluation masks.
        post_process_params (dict, optional): Keyword arguments forwarded to
            ``post_process_mask`` (e.g. the parameters used in ``run_evaluation``).
            Defaults to None, which uses ``post_process_mask``'s own defaults.

    Raises:
        RuntimeError: If the validation split is empty.
    """
    post_process_params = {} if post_process_params is None else dict(post_process_params)

    # --- 1. Load the trained model ---
    model = GeoImageSegmentModel.load_from_checkpoint(checkpoint_path=CHECKPOINT_PATH)
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f"Model loaded and moved to {device}")

    # --- 2. Fetch the first validation batch ---
    # Only the validation split is used, so no training augmentations are needed.
    augs_val = A.Compose([
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2()
    ])
    data_module = GeoImageDataModule(
        image_dir=image_dir,
        mask_dir=mask_dir,
        batch_size=4,
        num_workers=2,
        augmentations_train=None,
        augmentations_val=augs_val
    )
    data_module.setup()

    try:
        images, ground_truth_masks = next(iter(data_module.val_dataloader()))
    except StopIteration:
        raise RuntimeError("The validation dataloader is empty; there is nothing to plot.") from None
    print("Fetched batch 1")

    # --- 3. Predict (images stay on the CPU for plotting; only a copy goes to the device) ---
    with torch.no_grad():
        logits = model(images.to(device))
        predictions = (torch.sigmoid(logits) > 0.5).cpu()

    # --- 4. Visualize: one row per image, four columns ---
    batch_size = images.shape[0]
    plt.figure(figsize=(20, 5 * batch_size))

    for i in range(batch_size):
        gt_mask = ground_truth_masks[i].squeeze()
        raw_pred = predictions[i].squeeze()
        # NOTE(review): post_process_mask and visualize_image are defined elsewhere in the notebook.
        final_prediction = post_process_mask(raw_pred.numpy(), **post_process_params)

        panels = [
            (visualize_image(images[i]), f"Image {i+1}", None),
            (gt_mask, f"Ground Truth {i+1}", "gray"),
            (raw_pred, f"Raw Prediction {i+1}", "gray"),
            (final_prediction, f"Post-Processed {i+1}", "gray"),
        ]
        for column, (panel, title, cmap) in enumerate(panels, start=1):
            plt.subplot(batch_size, 4, i * 4 + column)
            plt.imshow(panel, cmap=cmap)
            plt.title(title)
            plt.axis("off")

    plt.tight_layout()
    plt.show()


In [None]:
# Paths to the evaluation data and the best checkpoint.
# NOTE: train_model names checkpoints "<encoder_name>-best-model-epoch=..-val_iou=...ckpt";
# point CHECKPOINT_PATH at the file your own training run produced.
CHECKPOINT_PATH = "checkpoints/baseline-best-model-epoch=76-val_iou=0.5967.ckpt"
IMAGE_DIR = "data/AOI_3_Paris/PS-RGB"
MASK_DIR = "data/AOI_3_Paris/geojson_roads"
# Maximum number of validation *batches* to evaluate (passed as `max_batches`, not a batch size).
MAX_EVAL_BATCHES = 6

# Post-processing parameters to evaluate (e.g. the best ones found in an Optuna study)
best_params = {
    'use_close': True,
    'close_kernel_size': 5,
    'use_min_object_size': True,
    'min_object_size': 1500,
}

# Run the evaluation with exactly the parameters defined above
run_evaluation(
    **best_params,
    max_batches=MAX_EVAL_BATCHES,
    image_dir=IMAGE_DIR,
    mask_dir=MASK_DIR,
    checkpoint_path=CHECKPOINT_PATH,
)

# Visualize the results (the first batch)
plot_evaluation_results(CHECKPOINT_PATH=CHECKPOINT_PATH,
                        image_dir=IMAGE_DIR,
                        mask_dir=MASK_DIR)