In [1]:
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

# Training a Semantic Segmentation Model for Mapping Settlements

In this notebook, we will train a semantic segmentation model to identify **buildings** and **solar panels** in aerial imagery.

Accurate and up-to-date mapping is essential for supporting humanitarian efforts, particularly in refugee camps where infrastructure monitoring and development planning are critical. Traditional mapping services often fall short in these rapidly changing environments. Semantic segmentation models can accelerate the mapping process, providing detailed information that helps ensure resources are effectively allocated to meet the needs of the population.

We will cover the following steps:
1. [Creating Semantic Segmentation Masks](#Creating-Semantic-Segmentation-Masks)
    - [Collecting Annotations](#collecting-annotations)
    - [Burning Masks](#burning-masks)
    - [Clipping Masks to Imagery](#clipping-masks-to-imagery)
2. [Sampling Chips to Create a Dataset](#sampling-chips-for-semantic-segmentation-dataset)
3. [Model Training](#Model-Training)
4. [Evaluation and Inference](#evaluation-and-inference)

In [2]:
import argparse
import glob
import os
import pickle
import subprocess
from pathlib import Path

import geopandas as gpd
import matplotlib.colors as colors
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rasterio
import shapely.geometry
from shapely.ops import unary_union




## About the Data
The data used in this demo is sourced from OpenAerialMap and OpenStreetMap (OSM). We will use an aerial imagery scene downloaded from [OpenAerialMap](https://map.openaerialmap.org/#/34.745614528656006,3.7642264421477485,15/square/12233202121301210/6447e8162155f0000546362b?_k=0r8ow6), which covers a portion of the Kalobeyei Integrated Settlement Camp 2. This scene is provided as a GeoTIFF file and is divided into two spatially disjoint areas: one for training and validation, and the other for testing (Figure 1).

We use the bounds of the aerial imagery scene to extract features from OpenStreetMap using the `osmnx` package. Specifically, we retrieve features with the following OSM tags:
- `building=True`
- `generator:source=solar`
- `amenity=toilets`

Note that toilets in this area may be tagged either with `building=toilets` or `amenity=toilets`, so we retrieve features with both tags.

The raw OSM features do not perfectly align with our imagery scene (Figure 2). To correct this, we use QGIS to manually adjust the polygon features to align with the buildings visible in the imagery. Additionally, the `amenity=toilets` features are returned as point geometries. We use QGIS to buffer these points into square polygons and align them with the visible toilet structures before merging them with the other polygon layer (Figure 3). We also manually label numerous toilets and solar panels that were not initially labeled.
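
We did the point-to-square conversion by hand in QGIS, but the same idea can be sketched in code. The helper name `square_around`, the example coordinates, and the 3 m side length are illustrative assumptions rather than values from our workflow, and the coordinates are assumed to be in a projected CRS with metre units.

```python
def square_around(x, y, side):
    """Return the closed ring of an axis-aligned square centred on (x, y)."""
    half = side / 2.0
    return [
        (x - half, y - half),
        (x + half, y - half),
        (x + half, y + half),
        (x - half, y + half),
        (x - half, y - half),  # repeat the first vertex to close the ring
    ]

# Hypothetical toilet point in projected coordinates (metres)
ring = square_around(100.0, 200.0, side=3.0)
```

The resulting ring would still need to be nudged onto the visible structure, which is the part we handled manually.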

We also draw polygons to represent background areas without buildings, solar panels, or toilets, and label some background points over objects like cars, fences, hedges, and wells to serve as hard negative samples.

<div style="display: flex; justify-content: space-between;">

<div style="text-align: center; width: 32%;">
<img src="./assets/traintest_split.jpg" alt="Kalobeyei Camp 2" style="height: 250px; object-fit: cover;"/>
<p>Figure 1: Kalobeyei Camp 2</p>
</div>

<div style="text-align: center; width: 32%;">
<img src="./assets/raw_osm_features.png" alt="Raw OSM Features" style="height: 250px; object-fit: cover;"/>
<p>Figure 2: Raw OSM Features</p>
</div>

<div style="text-align: center; width: 32%;">
<img src="./assets/cleaned_osm_features.png" alt="Cleaned OSM Features" style="height: 250px; object-fit: cover;"/>
<p>Figure 3: Cleaned OSM Features</p>
</div>

</div>

In [3]:
# Directories for raw, interim, and processed data
RAW_DATA_DIR = './data/raw'
INTERIM_DATA_DIR = './data/interim'
PROCESSED_DATA_DIR = './data/processed'

# Paths to raw imagery files
TRAINVAL_IMAGE_FILE = f'{RAW_DATA_DIR}/images/trainval_Kalobeyei_2B_Flight_03.tif'
TEST_IMAGE_FILE = f'{RAW_DATA_DIR}/images/test_Kalobeyei_2B_Flight_03.tif'

# Paths to raw annotation files
ANNOTATIONS_FILE = f'{RAW_DATA_DIR}/annotations/osm_buildings-toilets-solar_aligned.geojson'
BACKGROUND_POLYGONS_FILE = f'{RAW_DATA_DIR}/annotations/background_polygons.geojson'
CAR_POINTS_FILE = f'{RAW_DATA_DIR}/annotations/car_points.geojson'
HEDGE_POINTS_FILE = f'{RAW_DATA_DIR}/annotations/hedge_points.geojson'
WELL_POINTS_FILE = f'{RAW_DATA_DIR}/annotations/well_points.geojson'

## Creating Semantic Segmentation Masks

In this section, we will continue preparing our data for training the semantic segmentation model. Our OSM and manually drawn annotations contain vector data for buildings, solar panels, and background regions. However, to train our model, we need pixelwise segmentation masks.

### Collecting Annotations
The OSM annotation file contains multiple columns with OSM tags for buildings and solar panels, while the background polygons file has only a `class` column with the value "background" for all rows. As a first step, we will resolve these inconsistencies to create a uniform polygon dataset.

To do this, we will use the `collect_annotations.py` script from this repository. This script performs the following tasks:
- Collects building and solar panel polygons from the OSM annotations and background polygons from our manual annotation file.
- Creates building boundary polygons by applying a buffer to the exterior of each building polygon to improve separation between closely packed buildings.
- Adds a `class` column for all the features.
- Combines the features into one dataframe with only the `class` and `geometry` columns.
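
The class assignment step can be pictured with the minimal sketch below. The rules are modelled on the tag filters used later in this notebook (`generator:source == "solar"`, non-null `building`), but the function `assign_class` and its exact logic are assumptions for illustration and are not taken from `collect_annotations.py`.

```python
def assign_class(tags, source):
    """Map one feature's attributes to a single class label.

    `tags` is a dict of attribute columns (missing tags are absent or None);
    `source` is "osm" or "background". Illustrative rules only.
    """
    if source == "background":
        return "background"
    if tags.get("generator:source") == "solar":
        return "solar"
    if tags.get("building") is not None:
        return "building"
    return None  # feature does not belong to any class of interest


# Hypothetical features: (attributes, source file)
features = [
    ({"building": "yes"}, "osm"),
    ({"power": "generator", "generator:source": "solar"}, "osm"),
    ({"class": "background"}, "background"),
    ({"highway": "path"}, "osm"),
]
classes = [assign_class(tags, source) for tags, source in features]
```

Once every feature has a single `class` value, the remaining tag columns can be dropped so that only `class` and `geometry` are kept.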

In [None]:
# Define output file targets for collected annotations
polygon_annotations_file = f'{INTERIM_DATA_DIR}/annotations/{os.path.basename(ANNOTATIONS_FILE).replace(".geojson", "_polygons.geojson")}'
output_file = f'{INTERIM_DATA_DIR}/semantic_segmentation_classes.gpkg'

# Construct the command to collect the annotations
command = (
    f"python ../scripts/semantic_segmentation/collect_annotations.py "
    f"--object-annotations {ANNOTATIONS_FILE} "
    f"--polygon-annotations {polygon_annotations_file} "
    f"--background-annotations {BACKGROUND_POLYGONS_FILE} "
    f"--save-path {output_file}"
)
# Display and run the command
print(command)
subprocess.run(command, shell=True)

# Load the collected annotations
semantic_segmentation_classes = gpd.read_file(output_file)
display(semantic_segmentation_classes.head())
display(semantic_segmentation_classes['class'].value_counts())


### Burning Masks
Now, we will create the semantic segmentation masks by using the `create_mask.py` script. This script burns polygon features into a raster in a specified order, converting our vector data into pixelwise segmentation masks suitable for training our model.
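
To see why the burn order matters, here is a toy sketch that uses pixel boxes in place of real polygons. It assumes that class IDs follow the position in the label order starting at 1, that 0 is reserved for unlabeled pixels, and that a class burned later overwrites one burned earlier. These assumptions are consistent with how the masks are used in this notebook but are not confirmed from `create_mask.py`.

```python
import numpy as np

# Assumed class IDs: 0 is reserved for unlabeled pixels
label_order = ["background", "building", "building_boundary", "solar"]
class_ids = {name: i + 1 for i, name in enumerate(label_order)}

# Toy "polygons" as axis-aligned pixel boxes: (row0, row1, col0, col1)
boxes = {
    "background": [(0, 8, 0, 8)],
    "building": [(2, 6, 2, 6)],
    "building_boundary": [(2, 6, 2, 3)],
    "solar": [(3, 5, 3, 5)],
}

mask = np.zeros((8, 10), dtype=np.uint8)
for name in label_order:
    for r0, r1, c0, c1 in boxes[name]:
        # Classes burned later overwrite classes burned earlier
        mask[r0:r1, c0:c1] = class_ids[name]
```

With this order, a solar panel drawn on top of a building keeps its own label instead of being absorbed into the building class.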

In [None]:
# Rasterize the annotations to create masks
for image_file in [TRAINVAL_IMAGE_FILE, TEST_IMAGE_FILE]:
    # Define the output file path for the mask
    output_file = os.path.join(
        INTERIM_DATA_DIR,
        'unclipped_masks',
        os.path.basename(image_file)
    )
    os.makedirs(os.path.dirname(output_file), exist_ok=True)  # Ensure the directory exists

    # Define the label file and column
    label_file = f'{INTERIM_DATA_DIR}/semantic_segmentation_classes.gpkg'
    label_column = 'class'

    # Define the order of labels for the mask
    label_order = ["background", "building", "building_boundary", "solar"]
    label_order_cli = ' '.join(label_order)  # Convert list to space-separated string for CLI

    # Construct the command to create the mask
    command = (
        f"python ../scripts/semantic_segmentation/create_mask.py "
        f"-i {image_file} "
        f"-l {label_file} "
        f"--label-column {label_column} "
        f"-o {output_file} "
        f"--label-order {label_order_cli}"
    )

    # Display and run the command, suppressing its output
    print(command)
    subprocess.run(command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)


Let us plot the ground truth segmentation mask raster we just created. Each pixel in the raster will be colored according to the class that was burned into the raster.

In [None]:
# Create a listed colormap for the mask
class_colors = {
    "unlabeled": "#D3D3D3",       # Light Grey
    "background": "#D2B48C",      # Tan
    "building": "#0000FF",        # Blue
    "building_boundary": "#FF0000",  # Red
    "solar": "#FFFF00"            # Yellow
}
segmentation_cmap = colors.ListedColormap(list(class_colors.values()))


# Plot the mask for the training/validation image
mask_file = f'{INTERIM_DATA_DIR}/unclipped_masks/{os.path.basename(TRAINVAL_IMAGE_FILE)}'
with rasterio.open(mask_file) as src:
    mask = src.read(1)

plt.figure(figsize=(4, 6))
plt.imshow(mask, cmap=segmentation_cmap, vmin=0, vmax=4)
plt.axis('off')

legend_patches = [mpatches.Patch(color=segmentation_cmap(i), label=label.replace("_", " ")) for i, label in enumerate(class_colors.keys())]
plt.legend(handles=legend_patches, loc='center left', bbox_to_anchor=(1, 0.5), frameon=False)

plt.show()

### Clipping Masks to Imagery

The current mask spans the entire bounding box of the train/validation imagery, including areas outside the actual imagery footprint. To accurately confine the mask to the imagery, we will create footprint shapefiles for the train/validation and test imagery files and use these shapefiles to clip the masks.
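
Conceptually, clipping resets every mask pixel outside the imagery footprint. The sketch below shows that idea on a tiny array. It assumes outside pixels become 0 (unlabeled); what `gdalwarp` actually writes there depends on the nodata settings, and the boolean `inside_footprint` array is a hypothetical stand-in for the rasterized footprint.

```python
import numpy as np

mask = np.array([
    [1, 1, 2, 2],
    [1, 2, 2, 1],
    [1, 1, 4, 1],
], dtype=np.uint8)

# Hypothetical footprint: True where the scene has valid imagery
inside_footprint = np.array([
    [False, True, True, True],
    [True, True, True, True],
    [True, True, True, False],
])

# Pixels outside the footprint are reset to 0 ("unlabeled")
clipped = np.where(inside_footprint, mask, 0).astype(mask.dtype)
```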

In [None]:
# Define the directory to save the image footprints
footprints_dir = f'{INTERIM_DATA_DIR}/image_footprints/'
os.makedirs(footprints_dir, exist_ok=True)  # Ensure the directory exists

# List of image files to process
for image_file in [TRAINVAL_IMAGE_FILE, TEST_IMAGE_FILE]:
    # Define the output file path for the footprint shapefile
    output_file = os.path.join(footprints_dir, os.path.basename(image_file).replace('.tif', '_footprint.shp'))
    
    # Construct the command to generate the footprint shapefile
    command = f"python ../scripts/data_preprocessing/get_raster_footprint.py -f {image_file} -o {output_file}"
    print(command)  # Display the command being run
    
    # Execute the command, suppressing its output
    subprocess.run(command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

# Plot the footprints
print("Plotting the image footprints...")
fig, ax = plt.subplots(figsize=(5, 5))
for image_file, color in zip([TRAINVAL_IMAGE_FILE, TEST_IMAGE_FILE], ['green', 'blue']):
    footprint_file = os.path.join(footprints_dir, os.path.basename(image_file).replace('.tif', '_footprint.shp'))
    extent = gpd.read_file(footprint_file)
    extent.plot(ax=ax, facecolor=color, alpha=0.5, edgecolor='none', linewidth=2)

# Create custom legend patches
train_patch = mpatches.Patch(facecolor='green', alpha=0.5, edgecolor='none', label='Train / Val')
test_patch = mpatches.Patch(facecolor='blue', alpha=0.5, edgecolor='none', label='Test')
plt.legend(handles=[train_patch, test_patch], loc='upper left', frameon=False)

plt.show()

In [None]:
# Clip masks to the footprints using gdalwarp cutline
print("Clipping the masks to the image footprints...")
for image_file in [TRAINVAL_IMAGE_FILE, TEST_IMAGE_FILE]:
    # Define the paths to the mask, footprint, and output files
    mask_file = os.path.join(
        INTERIM_DATA_DIR,
        'unclipped_masks',
        os.path.basename(image_file)
    )
    footprint_file = os.path.join(footprints_dir, os.path.basename(image_file).replace('.tif', '_footprint.shp'))
    output_file = os.path.join(
        RAW_DATA_DIR,
        'semantic_segmentation',
        'masks',
        os.path.basename(image_file)
    )
    os.makedirs(os.path.dirname(output_file), exist_ok=True) # Ensure the directory exists

    # Construct the command to clip the mask
    command = f"gdalwarp -cutline {footprint_file} -crop_to_cutline {mask_file} {output_file}"

    # Display and run the command, suppressing its output
    print(command)
    subprocess.run(command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)


# Plot the mask for the training/validation image
print("Plotting the clipped mask for the training/validation image...")
mask_file = f'{RAW_DATA_DIR}/semantic_segmentation/masks/{os.path.basename(TRAINVAL_IMAGE_FILE)}'
with rasterio.open(mask_file) as src:
    mask = src.read(1)

plt.figure(figsize=(10, 10))
plt.imshow(mask, cmap=segmentation_cmap, vmin=0, vmax=4)
plt.axis('off')

legend_patches = [mpatches.Patch(color=segmentation_cmap(i), label=label.replace("_", " ")) for i, label in enumerate(class_colors.keys())]
plt.legend(handles=legend_patches, loc='upper right', bbox_to_anchor=(1, 1))
plt.show()

## Sampling Chips for Semantic Segmentation Dataset

In this step, we will extract smaller image chips from our paired imagery and mask scenes to create a dataset for training our model. Pre-sampling these chips helps balance the dataset, ensuring that small features, such as solar panels, which occupy a relatively small area in the scenes, are adequately represented.

The `sample_chips` script is responsible for generating these smaller image chips. To tailor the chip sampling process to our specific data, we need to customize the `_get_candidate_points` function. This function implements a sampling strategy to generate diverse chips by:
- Including all solar panel centroids, since solar panels are a minority class.
- Including building centroids that don't contain solar panels (since those that do will already be present in chips sampled from the solar panel centroids).
- Including manually specified as well as randomly drawn background points, ensuring they are not too close to other points.

With these candidate chip centroids, the `sample_chips` function then samples the chips.
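
The two geometric ideas in this strategy, dropping background points that sit too close to object centroids and turning a centroid into a chip extent, can be sketched with plain Python. The helper names `exclude_near` and `chip_bounds`, the example coordinates, and the 0.1 m pixel size are assumptions for illustration; the repository's own helpers may behave differently.

```python
import math

def exclude_near(points, anchors, min_dist):
    """Keep only the points farther than `min_dist` from every anchor."""
    return [
        p for p in points
        if all(math.hypot(p[0] - a[0], p[1] - a[1]) > min_dist for a in anchors)
    ]

def chip_bounds(center, chip_size_px, pixel_size):
    """Return (left, bottom, right, top) of a square chip centred on `center`."""
    half = chip_size_px * pixel_size / 2.0
    x, y = center
    return (x - half, y - half, x + half, y + half)

# Hypothetical centroids and background candidates in projected coordinates
object_centroids = [(0.0, 0.0), (100.0, 0.0)]
background = [(5.0, 5.0), (50.0, 0.0), (100.0, 19.0), (200.0, 200.0)]

# Drop background candidates within 20 units of any object centroid
kept = exclude_near(background, object_centroids, min_dist=20)

# Extent of a 256 x 256 pixel chip around the first remaining point
bounds = chip_bounds(kept[0], chip_size_px=256, pixel_size=0.1)
```

Filtering before sampling keeps background chips from duplicating scenes that the solar and building chips already cover.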

Here’s the code to override the `_get_candidate_points` function and set up the arguments for the `sample_chips` function:

In [None]:
import scripts.semantic_segmentation.sample_chips
from scripts.semantic_segmentation.sample_chips import _filter_dataframe, _sample_diverse_values, _sample_points_from_footprints
from src.geo_utils import concat_geo_files, exclude_points_within_buffer

def _get_candidate_points(seed):
    """
    Helper function to get candidate points for chip sampling.
    It loads all polygon annotations and filters for solar and building polygons. Then,
    it removes building polygons that contain solar polygons (since these will already
    be included in the solar polygons). Next, it subsamples a diverse set of building
    polygons (since there are many more building polygons than solar polygons). Then,
    it gets centroids of these solar and building polygons.

    To get background points, it samples points from footprints and also loads
    manually-selected background points. It combines the footprint points and manual
    points, making sure to exclude points that are too close to solar or building
    centroids.

    Returns:
    tuple: A tuple containing the solar points, building points, and background points.
    """
    print("Getting candidate points for chip sampling...")
    POLYGON_ANNOTATIONS_PATH = polygon_annotations_file  # defined in the "Collecting Annotations" cell
    FOOTPRINT_PATHS = [
        "./data/interim/image_footprints/trainval_Kalobeyei_2B_Flight_03_footprint.shp",
        "./data/interim/image_footprints/test_Kalobeyei_2B_Flight_03_footprint.shp",
    ]
    BACKGROUND_ANNOTATIONS_PATHS = [
        CAR_POINTS_FILE,
        WELL_POINTS_FILE,
        HEDGE_POINTS_FILE,
    ]
    DATA_CRS = "EPSG:32636"

    # Load all polygon annotations
    polygons = gpd.read_file(POLYGON_ANNOTATIONS_PATH).to_crs(DATA_CRS)

    # Filter for solar and building polygons
    solar_conditions = {
        "power": lambda x: x == "generator",
        "generator:source": lambda x: x == "solar",
    }
    building_conditions = {"building": lambda x: ~x.isna()}

    solar = _filter_dataframe(polygons, solar_conditions)
    bldg = _filter_dataframe(polygons, building_conditions)
    print(f"Number of SOLAR polygons: {len(solar)}")

    # Get building polygons that do NOT contain solar polygons
    bldg_contains_solar = gpd.sjoin(
        bldg, solar, how="inner", predicate="contains"
    )
    bldg_no_solar = bldg[~bldg.index.isin(bldg_contains_solar.index)].copy()

    # Subsample diverse set of building polygons
    bldg_no_solar = _sample_diverse_values(
        bldg_no_solar, "building", num_to_sample=150, seed=seed
    )
    print(f"Number of BUILDING polygons: {len(bldg_no_solar)}")

    # Get centroids of solar and building polygons
    solar["geometry"] = solar.centroid
    bldg_no_solar["geometry"] = bldg_no_solar.centroid

    # Sample points from footprints
    footprint_points = _sample_points_from_footprints(
        FOOTPRINT_PATHS, DATA_CRS, num_to_sample=400, seed=seed
    )
    # Load manually-selected background points
    manual_points = concat_geo_files(BACKGROUND_ANNOTATIONS_PATHS, DATA_CRS)
    # Combine footprint points and manual points
    background = pd.concat([footprint_points, manual_points]).reset_index(
        drop=True
    )
    background["background_idx"] = background.index
    # Exclude points that are too close to solar or building centroids
    background = exclude_points_within_buffer(
        background, pd.concat([solar, bldg_no_solar]), 20
    )
    print(f"Number of BACKGROUND points: {len(background)}")

    return solar, bldg_no_solar, background

def _find_raster_file_for_bbox(bbox, raster_files, footprint_files_dir=f'{INTERIM_DATA_DIR}/image_footprints'):
    """
    Find the raster file that contains the bounding box
    """
    bbox_left, bbox_bottom, bbox_right, bbox_top = bbox
    bbox_polygon = shapely.geometry.box(bbox_left, bbox_bottom, bbox_right, bbox_top)
    for raster_file in raster_files:
        # Find the footprint file for the raster file
        footprint_file = glob.glob(f'{footprint_files_dir}/*{os.path.basename(raster_file).replace(".tif", "_footprint.shp")}')
        if len(footprint_file) == 0:
            raise ValueError(f'No footprint file found for {raster_file}')
        footprint_file = footprint_file[0]
        
        # Load the footprint file
        footprint = gpd.read_file(footprint_file)
        footprint = unary_union(footprint.geometry)
        
        # Check if the footprint completely contains the bounding box
        if footprint.contains(bbox_polygon):
            return raster_file
    # If no raster file contains the bounding box, return None
    return None

scripts.semantic_segmentation.sample_chips._get_candidate_points = _get_candidate_points
scripts.semantic_segmentation.sample_chips._find_raster_file_for_bbox = _find_raster_file_for_bbox

args = argparse.Namespace(
    seed=42,
    chip_size=256,
    input_imagery_dir = f'{RAW_DATA_DIR}/images',
    input_mask_dir = f'{RAW_DATA_DIR}/semantic_segmentation/masks',
    chip_output_dir = f'{PROCESSED_DATA_DIR}/chipped_datasets/semantic_segmentation',
    chip_locations_save_path = f'{INTERIM_DATA_DIR}/semantic_segmentation/chip_locations.gpkg',
)

# If this cell was already run and there are already chips in the output directory,
# remove them before re-sampling chips
for chip_file in Path(args.chip_output_dir).rglob('*.tif'):
    os.remove(chip_file)

scripts.semantic_segmentation.sample_chips.sample_chips(args)

In [None]:
image_chips_dir = f'{PROCESSED_DATA_DIR}/chipped_datasets/semantic_segmentation/images'
train_chips = glob.glob(f'{image_chips_dir}/trainval_Kalobeyei_2B_Flight_03_chip_*.tif')
print(f"Number of training chips: {len(train_chips)}")
test_chips = glob.glob(f'{image_chips_dir}/test_Kalobeyei_2B_Flight_03_chip_*.tif')
print(f"Number of test chips: {len(test_chips)}")

np.random.seed(42)
random_train_chips = np.random.choice(train_chips, 10, replace=False)

# Plot the 10 random image chips
fig, axes = plt.subplots(2, 5, figsize=(20, 8))
axes = axes.flatten()

for ax, chip in zip(axes, random_train_chips):
    with rasterio.open(chip) as src:
        image = src.read([1, 2, 3]).transpose(1, 2, 0)
    mask_chip = chip.replace('images', 'masks')
    with rasterio.open(mask_chip) as src:
        mask = src.read(1)
    ax.imshow(image)
    ax.imshow(mask, cmap=segmentation_cmap, alpha=0.3, vmin=0, vmax=4)
    ax.set_title(os.path.basename(chip))
    ax.axis('off')

plt.tight_layout()
plt.show()


## Model Training

In this section, we will train a U-Net model with a `resnext50_32x4d` backbone on the prepared dataset. U-Net is widely used for semantic segmentation because its encoder-decoder structure with skip connections combines fine spatial detail with broader context, and `resnext50_32x4d` is a common choice of feature extractor. Other architectures and backbones can also be used.

For this demo, we will use pixelwise cross-entropy loss, which is commonly used for multi-class classification problems. It compares the predicted class probabilities for each pixel to that pixel's true class label. Other loss functions, such as Jaccard loss (also known as Intersection over Union loss), can be used to optimize the model further, especially when the dataset is imbalanced or when the overlap between predicted and actual segments is critical.
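
To make the loss concrete, here is a minimal NumPy sketch of pixelwise cross-entropy that skips unlabeled pixels, mirroring the role of `ignore_index=0` in the task configuration. It is unweighted (no class weights) and is only an illustration; the training script relies on the deep learning framework's own loss, and the function name `pixelwise_cross_entropy` is ours.

```python
import numpy as np

def pixelwise_cross_entropy(logits, target, ignore_index=0):
    """Mean cross-entropy over pixels whose label is not `ignore_index`.

    logits: float array of shape (C, H, W); target: int array of shape (H, W).
    """
    # Numerically stable log-softmax over the class axis
    shifted = logits - logits.max(axis=0, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=0, keepdims=True))

    # Pick the log-probability of the true class at every pixel
    picked = np.take_along_axis(log_probs, target[None, ...], axis=0)[0]

    # Average only over labeled pixels
    valid = target != ignore_index
    return float(-picked[valid].mean())

# Toy example: 3 classes, 1 x 2 pixels, uniform logits
logits = np.zeros((3, 1, 2))
target = np.array([[1, 0]])  # the second pixel is unlabeled and is skipped
loss = pixelwise_cross_entropy(logits, target)
```

With uniform logits every class has probability 1/3, so the loss on the single labeled pixel is `log(3)`.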

Below is the configuration for the `CustomLogSemanticSegmentation` task used in our training script:

```python
task = CustomLogSemanticSegmentation(
    model="unet",
    backbone="resnext50_32x4d",
    weights="imagenet",  # use pretrained imagenet weights
    in_channels=3,
    num_classes=len(class_names) + 1,  # +1 for 0 "not labeled" class
    loss="ce",  # cross-entropy loss
    class_weights=args.class_weights,
    ignore_index=0,  # class 0 represents "not labeled" in the label masks
    learning_rate=args.learning_rate,
    learning_rate_schedule_patience=10,
    train_metrics_file=base_path / "train_metrics.csv",
    val_metrics_file=base_path / "val_metrics.csv",
    test_metrics_file=base_path / "test_metrics.csv",
)
```

In [None]:
# Define the output directory for the model
output_dir = './outputs'

# Define the directories for the chipped images and masks
chip_dir = f'{PROCESSED_DATA_DIR}/chipped_datasets/semantic_segmentation/images'
mask_dir = f'{PROCESSED_DATA_DIR}/chipped_datasets/semantic_segmentation/masks'
# Define the path to the tile split file
tile_split_file = './config/dataset_splits.yml'

# Construct the command to train the model
command = (
    f"python ../scripts/semantic_segmentation/train.py "
    "--exp-version demo "
    f"--chip-dir {chip_dir} "
    f"--mask-dir {mask_dir} "
    f"--tile-split-file {tile_split_file} "
    f"--output-dir {output_dir} "
    "--max-epochs 2 "
    # "--gpu-id -1"
)

# Display and run the command
print(command)
subprocess.run(command, shell=True)

## Evaluation and Inference

The `train.py` script saves predictions for training, validation, and test set chips using the best model from a training run. Let's visualize a sample of those saved predictions from the most recent model training run.

In [None]:
# Load chip predictions
output_dirs = sorted(Path(output_dir).glob('semantic_segmentation/demo-*-*-*'))
latest_output_dir = "./" + str(output_dirs[-1]) # Get most recent output directory
print(f"Latest output directory: {latest_output_dir}")
chip_preds_f = list(Path(latest_output_dir).glob('epoch=*-step=*_chip_inference.pkl'))[0]
with chip_preds_f.open('rb') as f:
    chip_preds = pickle.load(f)

# Plot 5 random chips with predictions
np.random.seed(42)
random_chips = np.random.choice(list(chip_preds['test_chips'].keys()), 5, replace=False)

fig, axes = plt.subplots(5, 3, figsize=(8, 10))
for i, chip_path in enumerate(random_chips):
    chip_rel_path = chip_path.relative_to('./')
    with rasterio.open(chip_rel_path) as ds:
        chip_img = ds.read()[0:3].transpose(1, 2, 0)
    mask_path = Path(*[part if part != 'images' else 'masks' for part in chip_rel_path.parts])
    with rasterio.open(mask_path) as ds:
        mask_img = ds.read().squeeze()
    chip_pred = chip_preds['test_chips'][chip_path]
    axes[i, 0].imshow(chip_img)
    axes[i, 0].set_title('Image')
    axes[i, 1].imshow(mask_img, cmap=segmentation_cmap, vmin=0, vmax=4)
    axes[i, 1].set_title('Mask')
    axes[i, 2].imshow(chip_pred, cmap=segmentation_cmap, vmin=0, vmax=4)
    axes[i, 2].set_title('Prediction')
    for ax in axes[i]:
        ax.axis('off')
plt.tight_layout()

In a real-world scenario, we would train the model for many more epochs, and possibly on a larger dataset. We've included outputs from training the model for 100 total epochs. Let's compare the test chip predictions from that run with the ones from our 2-epoch run.

In [None]:
# Load chip predictions from model trained for 100 epochs
max_epochs_100_output_dir = Path('./outputs/semantic_segmentation/demo-100-epochs')
max_epochs_100_chip_preds_f = list(max_epochs_100_output_dir.glob('epoch=*-step=*_chip_inference.pkl'))[0]
with max_epochs_100_chip_preds_f.open('rb') as f:
    max_epochs_100_chip_preds = pickle.load(f)

# Create a figure with subplots
fig, axes = plt.subplots(5, 4, figsize=(10, 10))

# Iterate over random chips and plot the images, masks, and predictions
for i, chip_path in enumerate(random_chips):
    # Get the relative path of the chip
    chip_rel_path = chip_path.relative_to('./')
    
    # Load the chip image
    with rasterio.open(chip_rel_path) as ds:
        chip_img = ds.read()[0:3].transpose(1, 2, 0)
    
    # Construct the mask path and load the mask image
    mask_path = Path(*[part if part != 'images' else 'masks' for part in chip_rel_path.parts])
    with rasterio.open(mask_path) as ds:
        mask_img = ds.read().squeeze()
    
    # Get the predictions for the chip
    chip_pred = chip_preds['test_chips'][chip_path]
    max_epochs_100_chip_pred = max_epochs_100_chip_preds['test_chips'][chip_path]
    
    # Plot the chip image
    axes[i, 0].imshow(chip_img)
    axes[i, 0].set_title('Image')
    
    # Plot the mask image
    axes[i, 1].imshow(mask_img, cmap=segmentation_cmap, vmin=0, vmax=4)
    axes[i, 1].set_title('Mask')
    
    # Plot the prediction
    axes[i, 2].imshow(chip_pred, cmap=segmentation_cmap, vmin=0, vmax=4)
    axes[i, 2].set_title('Prediction')
    
    # Plot the prediction from the model trained for 100 epochs
    axes[i, 3].imshow(max_epochs_100_chip_pred, cmap=segmentation_cmap, vmin=0, vmax=4)
    axes[i, 3].set_title('Prediction (100 epochs)')
    
    # Turn off axis for all subplots in the current row
    for ax in axes[i]:
        ax.axis('off')

# Adjust layout to prevent overlap
plt.tight_layout()

In [None]:
import ast

def read_test_metrics(file_path):
    """
    Reads the test metrics from a file and returns them as a dictionary.
    
    The file is expected to contain a single line holding a Python dict literal.
    
    Args:
        file_path (str or Path): Path to the test metrics file.
    
    Returns:
        dict: Dictionary containing the test metrics.
    """
    file_path = Path(file_path)
    
    # Read the single line from the file
    with file_path.open('r') as f:
        line = f.readline().strip()
    
    # Parse the line as a literal; unlike eval, this cannot execute arbitrary code
    metrics_dict = ast.literal_eval(line)
    
    return metrics_dict

# Path to the test metrics file from 100-epoch model
max_epochs_100_test_metrics_f = Path('./outputs/semantic_segmentation/demo-100-epochs/test_metrics.csv')
max_epochs_100_test_metrics = read_test_metrics(max_epochs_100_test_metrics_f)

# Path to the test metrics file from latest model
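# Use the same pattern as above so the bundled 'demo-100-epochs' run is never picked as the latest
output_dirs = sorted(Path(output_dir).glob('semantic_segmentation/demo-*-*-*'))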
latest_output_dir = output_dirs[-1]  # Get most recent output directory
latest_test_metrics_f = Path(latest_output_dir) / 'test_metrics.csv'
latest_test_metrics = read_test_metrics(latest_test_metrics_f)

# Create a DataFrame from the test metrics
model_run = ['demo-100-epochs', 'latest']
multiclass_accuracy = [
    max_epochs_100_test_metrics['test_MulticlassAccuracy'],
    latest_test_metrics['test_MulticlassAccuracy']
]
multiclass_jaccard = [
    max_epochs_100_test_metrics['test_MulticlassJaccardIndex'],
    latest_test_metrics['test_MulticlassJaccardIndex']
]
metrics_df = pd.DataFrame({
    'Model Run': model_run,
    'Multiclass Accuracy': multiclass_accuracy,
    'Multiclass Jaccard': multiclass_jaccard
})

# Display the DataFrame
display(metrics_df)
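
For readers unfamiliar with the Jaccard index (Intersection over Union), the sketch below computes it per class from a predicted chip and its ground-truth mask, skipping unlabeled pixels. It shows one common definition only; how `train.py` averages across classes is not shown in this notebook, and the function name `per_class_iou` and the toy arrays are ours.

```python
import numpy as np

def per_class_iou(pred, target, num_classes, ignore_index=0):
    """Return {class_id: IoU} computed over pixels not labelled `ignore_index`."""
    valid = target != ignore_index
    ious = {}
    for c in range(num_classes):
        if c == ignore_index:
            continue
        pred_c = (pred == c) & valid
        target_c = (target == c) & valid
        union = (pred_c | target_c).sum()
        if union == 0:
            continue  # class absent from both prediction and ground truth
        ious[c] = float((pred_c & target_c).sum() / union)
    return ious

# Toy ground truth and prediction; the bottom-right pixel is unlabeled
target = np.array([[1, 1, 2, 2],
                   [1, 1, 2, 0]])
pred = np.array([[1, 2, 2, 2],
                 [1, 1, 2, 4]])
ious = per_class_iou(pred, target, num_classes=5)
```

Here one background pixel is mistaken for a building, so both classes end up with an intersection of 3 pixels over a union of 4.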

Let's use the best checkpoint from the latest model training run to perform inference on the full test scene.

In [None]:
checkpoint = list((Path(latest_output_dir) / 'checkpoints').glob('epoch=*-step=*.ckpt'))[0]
checkpoint_f = f'./{checkpoint}'
image_dir = './data/raw/images'
inference_output_dir = f'{latest_output_dir}/inference'
# device = 'cpu'

# Construct the command to run inference
command = (
    f"python ../scripts/semantic_segmentation/inference.py "
    f"--checkpoint {checkpoint_f} "
    f"--image-dir {image_dir} "
    f"--image-glob-pattern '*test_Kalobeyei_2B_Flight_03.tif' "
    f"--output-dir {inference_output_dir} "
    # f"--device {device}"
)

# Display and run the command
print(command)
subprocess.run(command, shell=True)

In [None]:
# Load scene predictions
scene_preds_f = list(Path(inference_output_dir).glob('*_test_Kalobeyei_2B_Flight_03.tif'))[0]
with rasterio.open(scene_preds_f) as src:
    scene_preds = src.read().squeeze()

fig, ax = plt.subplots(1, 2, figsize=(5, 25))
# Plot the entire scene predictions
ax[0].imshow(scene_preds, cmap=segmentation_cmap, vmin=0, vmax=4)
ax[0].set_title('Kalobeyei_2B Test Predictions', fontsize=10)
# Plot a zoomed in section of the predictions
ax[1].imshow(scene_preds[7000:8000, 4000:5000], cmap=segmentation_cmap, vmin=0, vmax=4)
ax[1].set_title('Zoomed-in Predictions', fontsize=10)
plt.show()
