# An Interactive Workflow for Geospatial ML Mapping

<img src="https://github.com/GeoAIAfrica/interactive_geospatial_mapping/blob/main/static/image_analysis.webp?raw=1" width="50%" />

<a href="https://colab.research.google.com/github/GeoAIAfrica/interactive_geospatial_mapping/blob/main/end2end_geospatial_ml_mapping.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**© AMLD 2024**. `MIT` License.

**Authors:** [Akram Zaytar](https://www.linkedin.com/in/akramz/), [Gilles Q. Hacheme](https://www.linkedin.com/in/gilles-q-hacheme-a0956ab7/), [Aisha Alaagib](https://www.linkedin.com/in/aishaalaagib/), [Girmaw A. Tadesse](https://www.linkedin.com/in/girmaw-abebe-tadesse/).

**Introduction:**

In this notebook, we will present an end-to-end workflow for geospatial mapping using deep neural networks.

We aim to cover the following:
1. Introduction to Geospatial Data
2. Pick a place & period of interest!
3. Load imagery into the interactive map!
4. Create a few labels for your object of interest! Export the Image/Mask!
5. _Train!_  Data augmentation, Regularization, Fine-tuning, object vectorization!
6. Export pixel-wise metrics for the local region!
7. Run the model over a much bigger region & find other instances!
8. Be ambitious! augmentation techniques, fusing layers (Sentinel-1), multi-class, other encoders/architectures, join our community!

**Topics:**

Content: <font color='blue'>`Geospatial Data Analysis`</font>, <font color='blue'>`Computer Vision`</font>, <font color='blue'>`Interactive Mapping`</font>.
Level: <font color='grey'>`Beginner`</font>, <font color='grey'>`Intermediate`</font>

**Outcome:**

- *The basics* of Geospatial Data analysis: learn about data formats & types, foundational concepts.
- *Interactive Mapping*: how to acquire & prepare the inputs (satellite images) and create the targets (objects of interest) within a Jupyter notebook environment.
- *Geospatial ML*: learn about geospatial Train/Val splitting, Data augmentation, Regularization, Fine-tuning, object vectorization, Evaluation, and inference.

**Prerequisites:**
- Basic to intermediate knowledge in Python and machine learning.
- Familiarity with satellite imagery and geospatial data is beneficial but not mandatory.
- Installation of necessary software and tools as detailed in the workshop's GitHub repository README file.
- Participants are encouraged to install and set up the required tools prior to the workshop for a more efficient hands-on session.

**Before you start:**

For this practical, you will need to use a GPU to speed up training. To do this, go to the "Runtime" menu in Colab, select "Change runtime type" and then in the popup menu, choose "GPU" in the "Hardware accelerator" box.

## Installation and Imports

In [None]:
%pip install rioxarray -q
%pip install leafmap -q
%pip install torchgeo==0.4.0 -q
%pip install lightning -q
%pip install geopandas -q
%pip install localtileserver -q
%pip install pystac_client -q
%pip install planetary_computer -q
%pip install stackstac -q

In [None]:
import os
import warnings
import shutil
import hashlib
import requests
import subprocess
from random import *
from tqdm import tqdm
from datetime import *
from pathlib import Path
from functools import reduce
from functools import partial
from typing import Any, Dict, cast

warnings.filterwarnings("ignore", category=FutureWarning)
if os.getenv("COLAB_RELEASE_TAG"):
    from google.colab import output
    output.enable_custom_widget_manager()

import numpy as np
import matplotlib.pyplot as plt

import rasterio
import rioxarray  # registers the `.rio` accessor used on xarray objects below
import stackstac
import pystac_client
import geopandas as gpd
import planetary_computer
from shapely.geometry import *
from rasterio.features import rasterize
from rasterio.transform import from_bounds
from localtileserver import get_leaflet_tile_layer, TileClient
from rasterio.windows import Window
import leafmap

import torch
import torch.nn as nn
from lightning import LightningDataModule, LightningModule
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torchgeo.datasets import RasterDataset
from torchgeo.datasets import stack_samples
from torchgeo.samplers import RandomBatchGeoSampler, GridGeoSampler
from torch import Tensor
import segmentation_models_pytorch as smp
from torchgeo.transforms import AugmentationSequential
import kornia.augmentation as K
from torchmetrics import MetricCollection
from torchmetrics.classification import (
    MulticlassJaccardIndex,
    Precision,
    Recall,
    F1Score,
)
import lightning.pytorch as pl

In [None]:
# Create directories to be used later.
store = Path("./data")
store.mkdir(exist_ok=True)
logs = Path("./logs")
logs.mkdir(exist_ok=True)
results = Path("./results")
results.mkdir(exist_ok=True)

---

## Introduction to Geospatial Data!

Geospatial data refers to **information that can be associated with locations on Earth**. It comes with attributes like **coordinates** and **geometry**.

Examples of geospatial data:
- Weather information.
- Transportation networks.
- Population density.

There are two types of geospatial data:

<div style="text-align:left;"> <figure> <img width="500px" src="https://i0.wp.com/pangeography.com/wp-content/uploads/2022/05/Raster_vector_tikz.png" /> <figcaption style="font-size:small;">Image credit: <a href="https://pangeography.com/geographic-data-structure-vector-data-and-raster-data/">Pan Geography</a></figcaption> </figure> </div>

- **Vector**: Points, Lines, Polygons, etc. Vector objects are geometries that may have multiple attributes. They are saved in vector files (e.g., `Shapefile` (.shp), `GeoJSON`, among others).
- **Raster**: represented as a grid of pixels, where each pixel contains a value that represents a measurement. Raster data is stored in formats like `GeoTIFF` and `NetCDF`.

For both **Vector** and **Raster** data, we need a way to map coordinates (for rasters, pixel positions) to locations on Earth. **Coordinate Reference Systems** (CRS) combine an Earth model and a projection, which translates the 3D surface of the Earth onto a 2D plane. Commonly used CRS include `WGS84`, often used for `GPS` data, and `UTM`, a set of projections that divide the world into a series of 6-degree longitudinal zones. When working with geospatial data, it is crucial to ensure that all datasets share the same CRS to avoid errors when aligning them.
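
As a rough illustration of how a raster's pixel grid is tied to map coordinates, the sketch below assumes a north-up raster described only by its top-left corner and a square pixel size. The function names and the sample numbers are hypothetical; libraries such as `rasterio` store the same information in an affine transform.

```python
def pixel_to_map(row, col, origin_x, origin_y, pixel_size):
    """Return the map coordinates of a pixel's center (north-up raster assumed)."""
    x = origin_x + (col + 0.5) * pixel_size
    # Rows grow downward on the grid, so y decreases as the row index increases
    y = origin_y - (row + 0.5) * pixel_size
    return x, y


def map_to_pixel(x, y, origin_x, origin_y, pixel_size):
    """Return the (row, col) of the pixel that contains the map point (x, y)."""
    col = int((x - origin_x) // pixel_size)
    row = int((origin_y - y) // pixel_size)
    return row, col


# Hypothetical 10 m raster whose top-left corner sits at (500000, 4000000)
center = pixel_to_map(0, 0, 500_000.0, 4_000_000.0, 10.0)
print(center)  # (500005.0, 3999995.0)
print(map_to_pixel(500_123.0, 3_999_877.0, 500_000.0, 4_000_000.0, 10.0))  # (12, 12)
```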

Converting between different CRS is known as "reprojection." Care must be taken during reprojection to maintain data integrity, especially when working with large areas or when precision is crucial. A CRS tells you how a geometry's numeric coordinates relate to positions on the Earth's surface. `GeoPandas` allows us to inspect a dataset's CRS and reproject the data if necessary.
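
One common choice when reprojecting from `WGS84` to a metric CRS is the UTM zone that contains the data. Here is a minimal sketch, assuming the standard 6-degree zone layout and ignoring the special-case zones around Norway and Svalbard (`utm_epsg` is a hypothetical helper, not part of `GeoPandas`):

```python
import math


def utm_epsg(lon, lat):
    """Return the EPSG code of the WGS84 / UTM zone containing (lon, lat)."""
    # Zones are 6 degrees wide and numbered 1..60 starting at 180 degrees West
    zone = int(math.floor((lon + 180.0) / 6.0)) % 60 + 1
    # EPSG:326xx covers the northern hemisphere, EPSG:327xx the southern one
    base = 32600 if lat >= 0 else 32700
    return base + zone


print(utm_epsg(36.82, -1.29))  # Nairobi -> 32737 (UTM zone 37S)
print(utm_epsg(6.14, 46.20))   # Geneva  -> 32632 (UTM zone 32N)
```

The returned code could then be passed to `GeoDataFrame.to_crs(epsg=...)`.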

In `Python`, we can use the [rasterio](https://github.com/rasterio/rasterio) library to read/write raster data and the [geopandas](https://github.com/geopandas/geopandas) library for vector data.

We present examples on how to read & visualize both vector & raster data:

In [None]:
# Use the geopandas API to read all countries in the world
# (note: `gpd.datasets` was removed in GeoPandas 1.0, so this call needs geopandas < 1.0)
world = gpd.read_file(gpd.datasets.get_path("naturalearth_lowres"))
print(f"Vector data coordinate reference system: {world.crs}")

# Filter African countries
africa = world[world["continent"] == "Africa"]

# Plot Africa
fig, ax = plt.subplots(figsize=(5, 5))
ax = africa.boundary.plot(ax=ax)
ax.axis("off")
plt.show()

We can use `leafmap` to interactively visualize a very high-resolution image (you can source others from [here](https://openaerialmap.org/)):

In [None]:
# Set the COG image URL
img_url = "https://oin-hotosm.s3.us-east-1.amazonaws.com/65de85797175970001f718f1/0/65de85797175970001f718f2.tif"

In [None]:
# Create a tile server from the remote COG
client = TileClient(img_url)

# Create ipyleaflet tile layer from that server
t = get_leaflet_tile_layer(client)

# Create the map
m = leafmap.Map(center=client.center(), zoom=8)
m.add(t)
m

We can also focus on a specific sub-region of the image and plot it:

In [None]:
# Read a 1000x1000 window at the center of the image
with rasterio.open(img_url) as src:
    height, width = src.shape
    hc, wc = height // 2, width // 2
    # Offset by half the window size so the window is centered
    y0, x0 = hc - 500, wc - 500
    # `Window` expects (col_off, row_off, width, height)
    arr = src.read(window=Window(x0, y0, 1_000, 1_000))

# Check the array shape
arr.shape
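
The offsets of a centered window can be worked out independently of the file being read. The helper below is a hypothetical sketch (not part of `rasterio`). It returns values in the `(col_off, row_off, width, height)` order that `rasterio.windows.Window` expects and shrinks the window when the image is smaller than the requested size.

```python
def centered_window(height, width, size):
    """Return (col_off, row_off, win_width, win_height) for a centered window."""
    # Never ask for more pixels than the image has
    win_h = min(size, height)
    win_w = min(size, width)
    # Split the leftover pixels evenly on both sides
    row_off = (height - win_h) // 2
    col_off = (width - win_w) // 2
    return col_off, row_off, win_w, win_h


print(centered_window(4000, 6000, 1000))  # (2500, 1500, 1000, 1000)
print(centered_window(500, 6000, 1000))   # (2500, 0, 1000, 500)
```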

In [None]:
fig, ax = plt.subplots(figsize=(5, 5))
ax.imshow(arr.transpose(1, 2, 0))
ax.axis("off")
plt.show()

In [None]:
# Clear the layers
m.clear_layers()

# No need for the previous map
del m

---
## 1. Pick a Region & Period of Interest

After briefly going over geospatial data types, let's use `leafmap` to pick a region of interest:

In [None]:
# Initialize the Map object
m = leafmap.Map()

# Output the map object
m

**Note: before running the next cell, zoom to an area you care about and draw a polygon that represents your region of interest (ROI)**.

Let's save the ROI as a `GeoJSON` file:

In [None]:
# Set this variable to `True` to draw your own ROI and labels; keep `False` to download the prepared ones
labeling = False

In [None]:
if labeling:

    # Get the region of interest
    roi = box(*m.user_roi_bounds())

    # Save it as a `GeoJSON` file
    gpd.GeoDataFrame(geometry=[roi], crs="EPSG:4326").to_file("./data/roi.geojson")

else:

    # Download the ROI file
    !wget -nc -P data/ https://raw.githubusercontent.com/GeoAIAfrica/interactive_geospatial_mapping/main/data/roi.geojson

After setting the region of interest, let's set the period of interest:

In [None]:
start_date = "2023-01-01"
end_date = "2024-01-01"
start_date, end_date

---

## 2. Load imagery into the interactive map!

After setting the region and period of interest, we will use [Microsoft Planetary Computer](https://planetarycomputer.microsoft.com/catalog) to do the following:
1. Search for cloud-free Sentinel-2 images that correspond to the ROI/period.
2. Stack the found items on the time dimension and create an `xarray` object (with `rioxarray` geospatial accessors).
3. Crop the images to the region of interest, compute the per-pixel median over time, and save the resulting mosaic.
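
To see why a per-pixel median over time works well as a mosaic, here is a small pure-Python sketch with made-up reflectance values (`median_composite` is a hypothetical helper; the notebook itself relies on the `median` method of the stacked array). A single bright outlier, such as a leftover cloud, does not affect the result.

```python
from statistics import median


def median_composite(stack, nodata=None):
    """Per-pixel median over a list of equally sized 2D images (lists of rows)."""
    n_rows, n_cols = len(stack[0]), len(stack[0][0])
    out = []
    for r in range(n_rows):
        row = []
        for c in range(n_cols):
            # Collect this pixel's time series, skipping nodata observations
            values = [img[r][c] for img in stack if img[r][c] != nodata]
            row.append(median(values) if values else nodata)
        out.append(row)
    return out


# Three hypothetical 1x2 scenes; the second one has a bright outlier
scenes = [
    [[0.10, 0.20]],
    [[0.12, 0.90]],
    [[0.11, 0.22]],
]
composite = median_composite(scenes)
print(composite)  # [[0.11, 0.22]]
```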

In [None]:
# Get the catalog
api_url = "https://planetarycomputer.microsoft.com/api/stac/v1"
catalog = pystac_client.Client.open(api_url, modifier=planetary_computer.sign_inplace)

In [None]:
# Set the search parameters
product = "sentinel-2-l2a"
roi = gpd.read_file("./data/roi.geojson").geometry.values[0]
bbox = roi.bounds
max_cloud_percent = 0
period = f"{start_date}/{end_date}"
print(f"- Bounding box: {bbox}")
print(f"- Satellite product: {product}")
print(f"- Period: {period}")

In [None]:
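# Search! (stored in `search` to avoid shadowing the `results` directory path)
search = catalog.search(
    collections=[product],
    bbox=bbox,
    datetime=period,
    query={
        "eo:cloud_cover": {"lte": max_cloud_percent}
    },  # cloud cover less than or equal to `max_cloud_percent`
)

# Get the items (`item_collection` replaces the deprecated `get_all_items`)
items = search.item_collection()
print(f"Items found: {len(items)}.")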

We can now stack the items and create a dataset object:

In [None]:
images = stackstac.stack(items, assets=["B02", "B03", "B04", "B08"])
images

Let's create a Mosaic focused on our region of interest and save/visualize it:

In [None]:
# Crop to the region of interest
X = images.rio.clip([roi], crs="EPSG:4326", drop=True)

# Create the Mosaic
X = X.median(dim="time", keep_attrs=True)

In [None]:
# Mark 0 as the nodata value
X.rio.write_nodata(0.0, encoded=True, inplace=True)

In [None]:
# Save the mosaic as a compressed GeoTIFF
X.rio.to_raster("./data/X.tif", compress="lzw", dtype="float32")

In [None]:
# Clear the previous drawing of the region of interest
if labeling:
    m.clear_drawings()

# Add the image to the map
m.add_raster("./data/X.tif", bands=[3, 2, 1], layer_name="X")

---

## 3. Create a few labels!

**Note: go back to the map widget and draw polygons for the object of interest! When you are done, continue from here...**

In [None]:
def contribute(file_path):
    """
    Uploads a file to a predefined URL using a POST request.

    Args:
        file_path (str): The path to the file to be uploaded.

    Returns:
        A requests.Response object containing the server's response to the HTTP request.
    """
    url = "https://labeluploader.azurewebsites.net/upload"

    try:
        # Open the file in a context manager so the handle is always closed
        with open(file_path, "rb") as f:
            files = {"file": (str(file_path), f, "application/geo+json")}
            response = requests.post(url, files=files)
        return response
    except requests.exceptions.RequestException as e:
        print(f"An error occurred: {e}")
        return None

In [None]:
if labeling:

    # Get the geometries from the map
    geoms = [Polygon(e["geometry"]["coordinates"][0]) for e in m.draw_features]

    # Set the name of the object you just labeled!
    object_name = ""
    assert object_name, "Please set the name of the object you labeled!"

    # Get the current time as a `datetime` object
    now = datetime.now()

    # Create a `GeoDataFrame` from the geometries (EPSG:4326), the object name, and the date
    gdf = gpd.GeoDataFrame(geometry=geoms, crs="EPSG:4326")
    gdf["object"] = object_name
    gdf["time"] = now

    # Derive the file name `{hash}.geojson` from an MD5 hash of the object name and timestamp
    fn_stem = hashlib.md5(f"{object_name}{now.strftime('%Y-%m-%d-%H-%M-%S')}".encode()).hexdigest()
    file_path = Path(f"./data/{fn_stem}.geojson")

    # Save the file
    gdf.to_file(file_path)

    # Contribute your labels to GeoAI Africa
    contribute(file_path)

else:

    # We download the labels
    !wget -nc -P data/ https://raw.githubusercontent.com/GeoAIAfrica/interactive_geospatial_mapping/main/data/labels.geojson

Now, let's create a function that takes the labels' `GeoDataFrame` and exports the mask file:

In [None]:
def create_mask(labels, crs, resolution, mask_path, pos_class=1, neg_class=2):
    """
    Create a raster mask from vector labels.

    Args:
        labels (GeoDataFrame): A GeoDataFrame containing geometries (labels).
        crs (CRS): Coordinate Reference System to use for the output raster.
        resolution (float): The pixel size in the units of the CRS.
        mask_path (str): Path where the raster mask will be saved.
        pos_class (int, optional): Value to assign for positive class (default: 1).
        neg_class (int, optional): Value to assign to pixels outside every geometry (default: 2).

    The function converts the geometries in the GeoDataFrame to a raster mask.
    Each pixel in the mask represents whether it falls inside a geometry (positive class)
    or outside (negative class). The output is a single-band GeoTIFF file.
    """

    # Ensure the labels' CRS matches the target CRS
    if str(labels.crs) != str(crs):
        labels = labels.to_crs(crs)

    # Get geometries from the labels and create a list of values for rasterization
    geoms = labels.geometry.tolist()
    vals = [pos_class] * len(geoms)

    # Compute bounds for the output mask and calculate the transform
    minx, miny, maxx, maxy = labels.unary_union.bounds
    mask_transform = from_bounds(
        minx,
        miny,
        maxx,
        maxy,
        width=int((maxx - minx) / resolution),
        height=int((maxy - miny) / resolution),
    )

    # Create metadata for the output raster file
    mask_metadata = {
        "driver": "GTiff",  # File format
        "dtype": "uint8",  # Data type of the raster
        "nodata": None,  # NoData value; None implies no NoData value
        "width": int((maxx - minx) / resolution),  # Raster width in pixels
        "height": int((maxy - miny) / resolution),  # Raster height in pixels
        "count": 1,  # Number of bands in the raster
        "crs": crs,  # Coordinate Reference System
        "transform": mask_transform,  # Affine transformation parameters
        "compress": "lzw",  # Compression algorithm
        "predictor": 2,  # Predictor for compression
    }

    # Write the raster mask to a file
    with rasterio.open(mask_path, "w", **mask_metadata) as out_img:
        # Rasterize the geometries into an array
        mask_arr = rasterize(
            tuple(zip(geoms, vals)),
            out_shape=(mask_metadata["height"], mask_metadata["width"]),
            transform=mask_transform,
            fill=neg_class,  # Value for pixels not covered by any geometry
            default_value=neg_class,  # Only used for shapes passed without an explicit value
        )
        # Write the array to the raster band
        out_img.write_band(1, mask_arr)
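
The idea behind the mask can be illustrated without `rasterio`. The sketch below is a simplified stand-in, not `rasterio`'s algorithm: it assumes axis-aligned boxes instead of arbitrary polygons and marks a pixel as positive when its center falls inside a box (`burn_boxes` and its sample inputs are hypothetical).

```python
def burn_boxes(boxes, bounds, resolution, pos_class=1, neg_class=2):
    """Burn (minx, miny, maxx, maxy) boxes into a grid covering `bounds`."""
    minx, miny, maxx, maxy = bounds
    # Same grid-size arithmetic as a bounds/resolution based raster
    width = int((maxx - minx) / resolution)
    height = int((maxy - miny) / resolution)
    # Start with every pixel set to the negative class
    mask = [[neg_class] * width for _ in range(height)]
    for row in range(height):
        # Row 0 is the northern edge, so y decreases as the row index grows
        cy = maxy - (row + 0.5) * resolution
        for col in range(width):
            cx = minx + (col + 0.5) * resolution
            inside = any(
                bx0 <= cx <= bx1 and by0 <= cy <= by1
                for bx0, by0, bx1, by1 in boxes
            )
            if inside:
                mask[row][col] = pos_class
    return mask


# One hypothetical 20 m x 20 m label in the north-west corner of a 40 m x 40 m area
mask = burn_boxes([(0, 20, 20, 40)], bounds=(0, 0, 40, 40), resolution=10)
for row in mask:
    print(row)
# [1, 1, 2, 2]
# [1, 1, 2, 2]
# [2, 2, 2, 2]
# [2, 2, 2, 2]
```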

We can now export a mask that corresponds to the labels that we just created:

In [None]:
# We want the mask to have the same CRS as the image
with rasterio.open("./data/X.tif") as src:
    img_crs = src.crs

# We know that the resolution is 10 meters (Sentinel-2)
resolution = 10

# We load the labels
labels = gpd.read_file("./data/labels.geojson").to_crs(img_crs)

# We set the path of the mask to export
mask_path = Path("./data/mask.tif")

# Write!
create_mask(labels, img_crs, resolution, mask_path)

In [None]:
# Visualize it alongside the image
m.add_raster("./data/mask.tif", bands=[1], layer_name="y")

---

## 4. Train!

In this section, we will use [`TorchGeo`](https://github.com/microsoft/torchgeo) to train a `UNet` model with a `ResNet18` encoder over the labeled pixels.

We will start by creating a `dataset` class that can index into a raster image:

In [None]:
class SingleRasterDataset(RasterDataset):
    """
    A class representing a single raster dataset, inheriting from RasterDataset.

    This class is designed to handle individual raster datasets by specifying a file name.
    It sets up the dataset by extracting the directory of the given file as its root directory
    and allows for the application of transformations to the dataset.

    Attributes:
        filename_regex (str): The base name of the file specified for the dataset.
                              This attribute is intended for internal use to identify
                              the dataset file.

    Parameters:
        fn (str): The path to the single raster file. This path is used to extract the
                  file name and the directory for initializing the dataset.
        transforms (callable, optional): A function/transform that takes in a sample and returns
                                         a transformed version. These transforms are applied to
                                         the dataset items. Default is None.

    Note:
        The `transforms` parameter allows for preprocessing or data augmentation operations
        to be applied to the dataset. Ensure that any transforms provided are compatible
        with raster data.
    """

    def __init__(self, fn, transforms=None):
        self.filename_regex = os.path.basename(fn)
        # Initialize the base RasterDataset class with the directory of the file and any transforms
        super().__init__(root=os.path.dirname(fn), transforms=transforms)

In [None]:
def preprocess(sample, remove_bbox=True, max_val=7938, bands=[2, 1, 0]):
    """
    Preprocesses a given sample by applying band selection, normalization, and optional removal of bounding boxes.

    This function modifies the input sample in-place by selecting specified bands from the "image" field,
    normalizing the selected image bands by a specified maximum value, converting the image data type to float,
    and optionally removing the bounding box information. If present, the "mask" field is squeezed to remove
    singleton dimensions and converted to long data type.

    Parameters:
        sample (dict): A dictionary representing a sample from a dataset. The sample is expected to contain
                       an "image" key with image data and optionally "mask" and "bbox" keys for the segmentation
                       mask and bounding box information, respectively.
        remove_bbox (bool, optional): A flag to indicate whether bounding box information ("bbox" key) should be
                                      removed from the sample. Defaults to True.
        max_val (int, optional): The maximum value used for normalizing the image data. Defaults to 7938, the
                                 scaling constant this notebook uses for its Sentinel-2 imagery.
        bands (list of int, optional): The indices of the bands to be selected from the image. Defaults to [2, 1, 0],
                                       typically corresponding to the RGB bands of satellite imagery.

    Returns:
        dict: The preprocessed sample with the image data selected, normalized, and converted to float type,
              the mask (if present) squeezed and converted to long type, and the bounding box information (if present
              and `remove_bbox` is True) removed.

    Note:
        The function modifies the input `sample` dictionary in-place, but also returns the modified dictionary
        for convenience and chaining operations.
    """
    if "image" in sample:
        # Select specified bands and normalize the image
        sample["image"] = sample["image"][bands]
        sample["image"] = (sample["image"] / max_val).float()
    if "mask" in sample:
        # Squeeze the mask to remove singleton dimensions and convert to long data type
        sample["mask"] = sample["mask"].squeeze().long()
    if remove_bbox and "bbox" in sample:
        # Remove the bounding box information if specified
        del sample["bbox"]
    return sample

...and now we create the data module class, responsible for creating the **dataloaders** used for *batch generation*:

In [None]:
class SegmentationDataModule(LightningDataModule):
    """PyTorch Lightning DataModule for a segmentation task."""

    def __init__(
        self,
        img_path: Path,
        mask_path: Path,
        batch_size: int = 64,
        patch_size: int = 256,
        batches_per_epoch: int = 512,
        workers: int = 4,
    ):
        """
        Initialize the SegmentationDataModule.
        Args:
        img_path (Path): The filepath to the input image file.
        mask_path (Path): The filepath to the mask image file.
        batch_size (int, optional): The number of samples per batch during training. Defaults to 64.
        patch_size (int, optional): The size of patches to be extracted from the images. Defaults to 256.
        batches_per_epoch (int, optional): The number of batches per training epoch. Defaults to 512.
        workers (int, optional): The number of worker processes for data loading. Defaults to 4.
        """
        super().__init__()

        # Verify that the image file exists
        if not img_path.exists():
            raise FileNotFoundError("The image file does not exist.")
        if not mask_path.exists():
            raise FileNotFoundError("The mask file does not exist.")

        # Save the path to the input file
        self.X_file = img_path
        self.y_file = mask_path

        # Save the rest of the hyperparameters
        self.batch_size = batch_size
        self.patch_size = patch_size
        self.batches_per_epoch = batches_per_epoch
        self.workers = workers

        self.train_ds = None
        self.val_ds = None

    def setup(self, stage=None):
        """
        Setup method to prepare the datasets for training and validation.
        The datasets are created only once; later calls reuse the existing ones.
        Args:
        stage (str, optional): The stage for which the setup is being run ('fit' or 'test'). Defaults to None.
        """

        # Only setup if datasets are not already initialized
        if self.train_ds is None:
            # Create the training dataset
            self.train_img_ds = SingleRasterDataset(self.X_file, transforms=preprocess)
            self.train_mask_ds = SingleRasterDataset(self.y_file, transforms=preprocess)
            self.train_mask_ds.is_image = False
            self.train_ds = self.train_img_ds & self.train_mask_ds

            # Because of lack of labels, we will use the same mask for validation
            self.val_img_ds = SingleRasterDataset(self.X_file, transforms=preprocess)
            self.val_mask_ds = SingleRasterDataset(self.y_file, transforms=preprocess)
            self.val_mask_ds.is_image = False
            self.val_ds = self.val_img_ds & self.val_mask_ds

    def train_dataloader(self):
        """
        Prepare the dataloader for the training dataset.
        Returns:
        DataLoader: Dataloader for the training dataset.
        """
        sampler = RandomBatchGeoSampler(
            self.train_ds,
            size=self.patch_size,
            batch_size=self.batch_size,
            length=self.batches_per_epoch * self.batch_size,
        )
        return DataLoader(
            self.train_ds,
            batch_sampler=sampler,
            num_workers=self.workers,
            collate_fn=stack_samples,
        )

    def val_dataloader(self):
        """
        Prepare the dataloader for the validation dataset.
        Returns:
        DataLoader: Dataloader for the validation dataset.
        """
        sampler = RandomBatchGeoSampler(
            self.val_ds,
            size=self.patch_size,
            batch_size=self.batch_size,
            length=self.batches_per_epoch * self.batch_size,
        )
        return DataLoader(
            self.val_ds,
            batch_sampler=sampler,
            num_workers=self.workers,
            collate_fn=stack_samples,
        )
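
To build intuition for random patch sampling, the following pure-Python sketch draws random square windows in pixel coordinates. It is only an approximation of the idea: `random_patch_batches` is a hypothetical helper, and `TorchGeo`'s samplers return geospatial bounding boxes rather than pixel offsets.

```python
import random


def random_patch_batches(height, width, patch_size, batch_size, batches_per_epoch, seed=None):
    """Yield batches of (row_off, col_off, patch_size, patch_size) windows."""
    if patch_size > height or patch_size > width:
        raise ValueError("patch_size must fit inside the image")
    rng = random.Random(seed)
    for _ in range(batches_per_epoch):
        batch = []
        for _ in range(batch_size):
            # Pick a top-left corner such that the patch stays inside the image
            row_off = rng.randint(0, height - patch_size)
            col_off = rng.randint(0, width - patch_size)
            batch.append((row_off, col_off, patch_size, patch_size))
        yield batch


# Hypothetical 1000 x 1200 image, 4 batches of 8 patches of 128 x 128 pixels
batches = list(random_patch_batches(1000, 1200, 128, 8, 4, seed=0))
print(len(batches), len(batches[0]))  # 4 8
```

Because patches are drawn with replacement, one epoch here is simply a fixed number of batches rather than a full pass over the image.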

Let's test the data module class:

In [None]:
store = Path("./data")
img = store / "X.tif"
mask = store / "mask.tif"
assert img.exists()
assert mask.exists()

In [None]:
dm = SegmentationDataModule(
    img_path=img, mask_path=mask, batch_size=8, patch_size=128, workers=0
)
dm.setup()

Let's visualize a few samples to make sure the data loaders are working correctly:

In [None]:
# Create the training data loader
train_dl = dm.train_dataloader()
batch = next(iter(train_dl))
imgs = batch["image"]
masks = batch["mask"]

# Visualize the images and masks side-by-side using matplotlib
fig, axs = plt.subplots(nrows=2, ncols=8, figsize=(32, 8))
# Use distinct loop names so the `img` and `mask` file paths are not overwritten
for i, (img_i, mask_i) in enumerate(zip(imgs, masks)):
    axs[0, i].imshow(img_i.numpy().transpose(1, 2, 0))
    axs[0, i].axis("off")
    axs[1, i].imshow(mask_i.numpy())
    axs[1, i].axis("off")
plt.tight_layout()
plt.show()

Let's create the training class adapted from [`TorchGeo`](https://github.com/microsoft/torchgeo/blob/e04e1a53fd6a21506693d53f8a8519dbf4261817/torchgeo/trainers/segmentation.py#L24):

In [None]:
class SemanticSegmentationTask(LightningModule):

    def __init__(self, **kwargs: Any) -> None:
        """
        Adapted from `TorchGeo`.
        """

        # Call the superclass constructor
        super().__init__()

        # Init the loss and model choices
        self.loss = None
        self.model = None
        self.class_weights = None

        # Creates `self.hparams` from kwargs
        self.save_hyperparameters()
        self.hyperparams = cast(Dict[str, Any], self.hparams)

        # Validate the hyperparameters
        self._validate_ignore_index(kwargs["ignore_index"])
        self._validate_model(kwargs["model"])
        self._validate_loss(kwargs["loss"])

        # Get the number of classes
        self.num_classes = self.hyperparams["num_classes"]

        # Set the indices for each class
        self.class2idx = dict()
        for i, class_name in enumerate(self.hyperparams["class_names"]):
            self.class2idx[class_name] = i + 1

        # Set the ignore index
        self.ignore_index = kwargs["ignore_index"]

        # Call the config task method
        self.config_task()

        # Set the color jitter parameters
        color_jitter_params = {
            "brightness": 0.2,
            "contrast": 0.2,
            "saturation": 0.2,
            "hue": 0.1,
            "p": 0.8,
        }

        # Create the augmentation pipeline (flips, rotation, and color jitter)
        augmentations = list()
        augmentations.append(K.RandomHorizontalFlip(p=0.5))
        augmentations.append(K.RandomVerticalFlip(p=0.5))
        augmentations.append(K.RandomRotation(degrees=90.0, p=0.5))
        augmentations.append(K.ColorJitter(**color_jitter_params))

        # Create the augmentation function
        self.aug = AugmentationSequential(*augmentations, data_keys=["image", "mask"])

        # Set the metrics of interest
        metrics = {
            "JaccardIndex": MulticlassJaccardIndex(
                num_classes=self.hyperparams["num_classes"],
                average=None,
                ignore_index=self.ignore_index,
            ),
            "Precision": Precision(
                task="multiclass",
                num_classes=self.hyperparams["num_classes"],
                average=None,
                ignore_index=self.ignore_index,
            ),
            "Recall": Recall(
                task="multiclass",
                num_classes=self.hyperparams["num_classes"],
                average=None,
                ignore_index=self.ignore_index,
            ),
            "F1": F1Score(
                task="multiclass",
                num_classes=self.hyperparams["num_classes"],
                average=None,
                ignore_index=self.ignore_index,
            ),
        }
        self.train_metrics = MetricCollection(metrics, prefix="train_")
        self.val_metrics = self.train_metrics.clone(prefix="val_")

        # Store the per-batch validation losses
        self.batch_val_losses = []

    def _validate_ignore_index(self, ignore_index):
        if not isinstance(ignore_index, (int, type(None))):
            raise ValueError("ignore_index must be an int or None")
        if (ignore_index is not None) and (self.hyperparams["loss"] == "jaccard"):
            warnings.warn(
                "ignore_index has no effect on training when loss='jaccard'",
                UserWarning,
            )

    def _validate_model(self, model):
        valid_models = [
            "unet",
            "deeplabv3+",
            "unet++",
            "manet",
            "linknet",
            "fpn",
            "pspnet",
            "pan",
            "deeplabv3",
        ]
        if model not in valid_models:
            raise ValueError(
                f"Model type '{model}' is not valid. "
                f"Supported models: {valid_models}."
            )

    def _validate_loss(self, loss):
        valid_losses = ["ce", "jaccard", "focal"]
        if loss not in valid_losses:
            raise ValueError(
                f"Loss type '{loss}' is not valid. "
                f"Supported losses: {valid_losses}."
            )

    def config_task(self) -> None:
        """Configures the task based on kwargs parameters passed to the constructor."""
        self._init_model()
        self._init_loss()

    def _init_model(self):
        if self.hyperparams["model"] == "unet":
            self.model = smp.Unet(
                encoder_name=self.hyperparams["backbone"],
                encoder_weights=self.hyperparams["weights"],
                in_channels=self.hyperparams["in_channels"],
                classes=self.hyperparams["num_classes"],
            )
        elif self.hyperparams["model"] == "deeplabv3+":
            self.model = smp.DeepLabV3Plus(
                encoder_name=self.hyperparams["backbone"],
                encoder_weights=self.hyperparams["weights"],
                in_channels=self.hyperparams["in_channels"],
                classes=self.hyperparams["num_classes"],
            )
        elif self.hyperparams["model"] == "unet++":
            self.model = smp.UnetPlusPlus(
                encoder_name=self.hyperparams["backbone"],
                encoder_depth=5,
                encoder_weights=self.hyperparams["weights"],
                in_channels=self.hyperparams["in_channels"],
                classes=self.hyperparams["num_classes"],
            )
        elif self.hyperparams["model"] == "manet":
            self.model = smp.MAnet(
                encoder_name=self.hyperparams["backbone"],
                encoder_depth=5,
                encoder_weights=self.hyperparams["weights"],
                in_channels=self.hyperparams["in_channels"],
                classes=self.hyperparams["num_classes"],
            )
        elif self.hyperparams["model"] == "linknet":
            self.model = smp.Linknet(
                encoder_name=self.hyperparams["backbone"],
                encoder_depth=5,
                encoder_weights=self.hyperparams["weights"],
                in_channels=self.hyperparams["in_channels"],
                classes=self.hyperparams["num_classes"],
            )
        elif self.hyperparams["model"] == "fpn":
            self.model = smp.FPN(
                encoder_name=self.hyperparams["backbone"],
                encoder_depth=5,
                encoder_weights=self.hyperparams["weights"],
                in_channels=self.hyperparams["in_channels"],
                classes=self.hyperparams["num_classes"],
            )
        elif self.hyperparams["model"] == "pspnet":
            self.model = smp.PSPNet(
                encoder_name=self.hyperparams["backbone"],
                encoder_weights=self.hyperparams["weights"],
                in_channels=self.hyperparams["in_channels"],
                classes=self.hyperparams["num_classes"],
            )
        elif self.hyperparams["model"] == "pan":
            self.model = smp.PAN(
                encoder_name=self.hyperparams["backbone"],
                encoder_weights=self.hyperparams["weights"],
                in_channels=self.hyperparams["in_channels"],
                classes=self.hyperparams["num_classes"],
            )
        elif self.hyperparams["model"] == "deeplabv3":
            self.model = smp.DeepLabV3(
                encoder_name=self.hyperparams["backbone"],
                encoder_depth=5,
                encoder_weights=self.hyperparams["weights"],
                in_channels=self.hyperparams["in_channels"],
                classes=self.hyperparams["num_classes"],
            )

    def _init_loss(self):
        if self.hyperparams["loss"] == "ce":
            if self.class_weights is not None:
                self.class_weights = self.class_weights.to(self.device)
            self.loss = nn.CrossEntropyLoss(
                weight=self.class_weights,
                ignore_index=-1000 if self.ignore_index is None else self.ignore_index,
            )
        elif self.hyperparams["loss"] == "jaccard":
            self.loss = smp.losses.JaccardLoss(
                mode="multiclass", classes=self.hyperparams["num_classes"]
            )
        elif self.hyperparams["loss"] == "focal":
            self.loss = smp.losses.FocalLoss(
                "multiclass", ignore_index=self.ignore_index, normalized=True
            )

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward pass of the model."""
        return self.model(*args, **kwargs)

    def training_step(self, *args: Any, **kwargs: Any) -> Tensor:
        """Compute and return the training loss."""
        batch = self.aug(args[0])  # if self.do_augment else args[0]
        x = batch["image"]
        y = batch["mask"]
        y_hat = self(x)

        # Report loss
        y_hat_hard = y_hat.argmax(dim=1)
        loss = self.loss(y_hat, y)
        self.log("train_loss", loss, on_step=True, on_epoch=False)

        # Report metrics
        self.train_metrics(y_hat_hard, y)
        return cast(Tensor, loss)

    def validation_step(self, *args: Any, **kwargs: Any) -> None:
        """Compute the validation loss and update the validation metrics."""
        batch = args[0]
        x = batch["image"]
        y = batch["mask"]
        y_hat = self(x)

        # Report loss
        y_hat_hard = y_hat.argmax(dim=1)
        loss = self.loss(y_hat, y)
        self.log("val_loss", loss, on_step=False, on_epoch=True)
        self.batch_val_losses.append(loss)

        # Report metrics
        self.val_metrics(y_hat_hard, y)

    def on_train_epoch_end(self) -> None:
        """Logs epoch level training metrics."""
        train_metrics = self.train_metrics.compute()
        new_metrics = dict()
        for k in train_metrics.keys():
            for category, cat_idx in self.class2idx.items():
                new_metrics[f"{k}_{category}"] = train_metrics[k][cat_idx]
        self.log_dict(new_metrics)
        self.train_metrics.reset()

    def on_validation_epoch_end(self) -> None:
        """Logs epoch level validation metrics."""

        # Calculate the rest of the metrics
        val_metrics = self.val_metrics.compute()
        new_metrics = dict()
        for k in val_metrics.keys():
            for category, cat_idx in self.class2idx.items():
                new_metrics[f"{k}_{category}"] = val_metrics[k][cat_idx]

        # Estimate the validation loss
        val_batch_losses = self.batch_val_losses
        val_loss = torch.nanmean(torch.stack(val_batch_losses))
        new_metrics["val_loss"] = val_loss

        self.log_dict(new_metrics)
        self.val_metrics.reset()
        self.batch_val_losses = list()

    def predict_step(self, *args: Any, **kwargs: Any) -> Tensor:
        """Compute and return the predictions."""
        batch = args[0]
        x = batch["image"]
        y_hat: Tensor = self(x).softmax(dim=1)
        return y_hat

    def configure_optimizers(self) -> Dict[str, Any]:
        """
        Configure the optimizer and learning rate scheduler based on the hyperparameters.
        Returns:
            A dictionary containing the optimizer and learning rate scheduler.
        """

        # Retrieve the optimizer name, learning rate, and weight decay from the hyperparameters
        optimizer_name = self.hyperparams["optimizer_name"]
        learning_rate = self.hyperparams["learning_rate"]
        weight_decay = self.hyperparams["weight_decay"]

        # Select the optimizer based on the specified name
        if optimizer_name == "SGD":
            optimizer = torch.optim.SGD(
                self.model.parameters(), lr=learning_rate, weight_decay=weight_decay
            )
        elif optimizer_name == "Adam":
            optimizer = torch.optim.Adam(
                self.model.parameters(), lr=learning_rate, weight_decay=weight_decay
            )
        elif optimizer_name == "RMSProp":
            optimizer = torch.optim.RMSprop(
                self.model.parameters(), lr=learning_rate, weight_decay=weight_decay
            )
        elif optimizer_name == "AdamW":
            optimizer = torch.optim.AdamW(
                self.model.parameters(), lr=learning_rate, weight_decay=weight_decay
            )
        else:
            raise ValueError(f"Unknown optimizer: {optimizer_name}")

        # Retrieve the scheduler name from the hyperparameters
        scheduler_name = self.hyperparams["scheduler_name"]

        # Select the learning rate scheduler based on the specified name
        if scheduler_name == "ReduceLROnPlateau":
            scheduler = ReduceLROnPlateau(
                optimizer, patience=self.hyperparams["learning_rate_schedule_patience"]
            )
        elif scheduler_name == "CosineAnnealingLR":
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=self.hyperparams["T_max"]
            )
        elif scheduler_name == "StepLR":
            scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer,
                step_size=self.hyperparams["step_size"],
                gamma=self.hyperparams["gamma"],
            )
        else:
            raise ValueError(f"Unknown scheduler: {scheduler_name}")

        # Return the optimizer and learning rate scheduler
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
            },
        }
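
The two epoch-end hooks above flatten each per-class metric into one scalar per class, named `{metric}_{category}`. The following is a minimal sketch of that renaming, with plain Python lists standing in for tensors. The metric names, the values, and the `class2idx` mapping are made up for illustration and are not taken from the task itself:

```python
def flatten_per_class(metrics, class2idx):
    """Turn {metric: [value per class index]} into {f"{metric}_{category}": value}."""
    flat = {}
    for name, per_class in metrics.items():
        for category, idx in class2idx.items():
            flat[f"{name}_{category}"] = per_class[idx]
    return flat


# Hypothetical values: one entry per class index for each metric
computed = {"val_F1": [0.0, 0.81, 0.93], "val_Recall": [0.0, 0.78, 0.95]}
# Hypothetical mapping from class name to class index
class2idx = {"urban": 1, "background": 2}

flat = flatten_per_class(computed, class2idx)
print(flat["val_F1_urban"])  # 0.81
```

This naming is why a single class can later be looked up by a key such as `val_F1_urban`.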

We will fix a set of hyperparameters:

In [None]:
# Define experiment setup and hyperparameters

# Check if there is a GPU
is_gpu = torch.cuda.is_available()

# Name of the experiment for tracking
experiment_name = "urban_mappings"

# Loss function to be used (categorical crossentropy in this case)
loss = "ce"

# Size of batches for training
batch_size = 128  # 256

# Size of the patches to be extracted from the images
patch_size = 128

# Number of batches to process in an epoch
batches_per_epoch = 32

# Number of worker processes for loading data
workers = 4 if is_gpu else 0

# Number of epochs with no improvement after which training will be stopped
early_stopping_patience = 10

# Minimum/maximum number of training epochs (early stopping cannot end training before min_epochs)
min_epochs = 1  # 5
max_epochs = 10  # 50

# Number of classes in the dataset
num_classes = 3

# Architecture/Encoder of the model
arch = "unet"
backbone = "mobilenet_v2"

# Pre-trained weights to initialize the backbone (Imagenet weights)
weights = "imagenet"

# Initial learning rate for training
learning_rate = 0.0001

# Number of epochs with no improvement on validation loss after which learning rate will be reduced
lr_schedule_patience = 10

# Weight decay (L2 penalty) for regularization
weight_decay = 1e-2

# Scheduler for adjusting learning rate
scheduler_name = "ReduceLROnPlateau"

# Optimizer for training
optimizer_name = "AdamW"

# Index to be ignored in loss computation, useful for masked areas in segmentation tasks
ignore_index = 0

# Paths to the input image and mask files, ensuring both exist
img_path = store / "X.tif"
mask_path = store / "mask.tif"
assert img_path.exists() and mask_path.exists()  # Ensure both paths exist
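
Both `early_stopping_patience` and `lr_schedule_patience` are set to 10 while `max_epochs` is also 10, so it is worth checking whether either mechanism can fire at all in such a short run. The sketch below uses a simplified patience counter. It is an assumption for illustration, not the exact rule implemented by Lightning's `EarlyStopping` or PyTorch's `ReduceLROnPlateau`, and `first_trigger_epoch` is a hypothetical helper:

```python
def first_trigger_epoch(val_losses, patience):
    """Return the 1-based epoch at which a patience counter fires, or None.

    Simplified rule: the counter resets whenever the loss improves on the best
    value seen so far, and fires after `patience` consecutive epochs without
    improvement.
    """
    best = float("inf")
    bad_epochs = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return epoch
    return None


# A 10-epoch run that improves once and then stalls for the remaining 9 epochs
losses = [0.9] + [0.95] * 9
print(first_trigger_epoch(losses, patience=10))  # None
print(first_trigger_epoch(losses, patience=3))   # 4
```

Under this simplified rule, a patience of 10 cannot trigger within 10 epochs, so a smaller patience (or a larger `max_epochs`, such as the commented value of 50) is needed for these mechanisms to have an effect.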

We can train!

In [None]:
# Set the experiment name and directory
results_dir = Path("./results")
results_dir.mkdir(exist_ok=True)
experiment_dir = results_dir / experiment_name
experiment_dir

In [None]:
# Create the data module
dm = SegmentationDataModule(
    img_path=img_path,
    mask_path=mask_path,
    batches_per_epoch=batches_per_epoch,
    batch_size=batch_size,
    patch_size=patch_size,
    workers=workers,
)
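
Because the data module receives `batches_per_epoch`, an epoch here is presumably a fixed number of sampled batches rather than one full pass over the image. Under that assumption, the amount of data seen per epoch follows directly from the hyperparameters. `epoch_budget` is a hypothetical helper used only for this arithmetic:

```python
def epoch_budget(batch_size, batches_per_epoch, patch_size):
    """Return (patches, pixels) processed in one epoch of fixed-length sampling."""
    patches = batch_size * batches_per_epoch
    pixels = patches * patch_size * patch_size
    return patches, pixels


patches, pixels = epoch_budget(batch_size=128, batches_per_epoch=32, patch_size=128)
print(patches)  # 4096
print(pixels)   # 67108864
```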

In [None]:
# Create the task
task = SemanticSegmentationTask(
    model=arch,
    backbone=backbone,
    weights=weights,
    in_channels=3,
    num_classes=num_classes,
    ignore_index=ignore_index,
    learning_rate=learning_rate,
    learning_rate_schedule_patience=lr_schedule_patience,
    loss=loss,
    weight_decay=weight_decay,
    scheduler_name=scheduler_name,
    optimizer_name=optimizer_name,
    class_names=["urban", "background"],
)

In [None]:
# Create the checkpoint callback
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath=experiment_dir,
    save_top_k=1,
    mode="min",
)

# Create the early stopping callback
early_stop_callback = EarlyStopping(
    monitor="val_loss", patience=early_stopping_patience, verbose=True, mode="min"
)

# Create the TensorBoard logger
tb_logger = TensorBoardLogger(save_dir="logs/", name=experiment_name)

In [None]:
# Trainer definition
trainer = pl.Trainer(
    logger=[tb_logger],
    max_epochs=max_epochs,
    min_epochs=min_epochs,
    callbacks=[checkpoint_callback, early_stop_callback],
    precision=16 if is_gpu else 32,  # use 16-bit mixed precision only when a GPU is available
    accelerator="gpu" if is_gpu else "cpu",
    devices=[0] if is_gpu else "auto",
)
trainer.fit(model=task, datamodule=dm)

Let's report the local Jaccard, F1, recall, and precision metrics:

In [None]:
report = dict()
class_of_interest = "urban"
for metric in ["F1", "JaccardIndex", "Recall", "Precision"]:
    report[metric] = round(
        float(trainer.logged_metrics[f"val_{metric}_{class_of_interest}"]), 3
    )
report
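
For a single class, all four reported metrics can be derived from three pixel counts: true positives, false positives, and false negatives. The sketch below shows those standard definitions with made-up counts; `pixel_metrics` is a hypothetical helper, not the metric implementation used by the task:

```python
def pixel_metrics(tp, fp, fn):
    """Compute per-class metrics from true-positive, false-positive and false-negative counts."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * tp / (2 * tp + fp + fn) if tp + fp + fn else 0.0
    jaccard = tp / (tp + fp + fn) if tp + fp + fn else 0.0
    return {
        "Precision": round(precision, 3),
        "Recall": round(recall, 3),
        "F1": round(f1, 3),
        "JaccardIndex": round(jaccard, 3),
    }


# Made-up counts for the class of interest
print(pixel_metrics(tp=80, fp=20, fn=40))
# {'Precision': 0.8, 'Recall': 0.667, 'F1': 0.727, 'JaccardIndex': 0.571}
```

The Jaccard index is always at most the F1 score for the same counts, since the two are related by `J = F1 / (2 - F1)`.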

---

## Inference

In this section, we will use our trained model to predict over the original image, vectorize, postprocess, and save our predictions:

In [None]:
def predict(
    img_path,
    model_path,
    out_fp,
    patch_size=128,
    batch_size=1,
    padding=1,
    num_workers=4,
    gpu=0,
):
    """Segments a given image using a fine-tuned model and saves the output.

    This function performs semantic segmentation on an input image using a specified model checkpoint. It processes
    the image in patches, predicts each patch's segmentation mask, and stitches them together to form a complete output
    mask, which is then saved to a specified output file path.

    Parameters:
        img_path (str): Path to the input image file.
        model_path (str): Path to the pre-trained model checkpoint.
        out_fp (str): Output file path where the prediction result will be saved.
        patch_size (int, optional): Size of the patches to be extracted from the images. Defaults to 128.
        batch_size (int, optional): Number of samples per batch during prediction. Defaults to 1.
        padding (int, optional): Number of pixels to pad the patches. Defaults to 1.
        num_workers (int, optional): Number of worker processes for data loading. Defaults to 4.
        gpu (int, optional): The ID of the GPU to use for prediction. Defaults to 0.

    Process:
        1. Sets up device for prediction based on available GPU.
        2. Loads the model from the checkpoint and prepares it for evaluation.
        3. Creates a dataset and dataloader for the input image, dividing it into patches for efficient processing.
        4. Iterates over the image patches, performs prediction, and accumulates the results.
        5. Saves the aggregated predictions as a geospatial raster image, preserving the input image's spatial reference.

    Note:
        - The function assumes the use of a `SemanticSegmentationTask` model architecture for prediction.
        - The input image is processed in patches to manage memory usage and adapt to different input sizes.
        - The output raster will have a single band, float32 data type, and will be compressed using LZW compression.
        - Padding is used to reduce edge effects in patch-based prediction, and is removed in the final output.
        - The function uses a deterministic sampling strategy (GridGeoSampler) to ensure complete coverage of the input image.
    """

    # Set the stride
    stride = patch_size - padding * 2

    # Set the device
    device = torch.device(
        f"cuda:{gpu}" if (gpu is not None) and torch.cuda.is_available() else "cpu"
    )

    # Load task and data
    task = SemanticSegmentationTask.load_from_checkpoint(model_path)
    task.freeze()
    model = task.model
    model = model.eval().to(device)

    # Create a dataset object from a single image file
    val_ds = SingleRasterDataset(
        img_path, transforms=partial(preprocess, remove_bbox=False)
    )

    # Create the sampler (not random because we want to predict the whole image deterministically)
    sampler = GridGeoSampler(val_ds, size=patch_size, stride=stride)

    # Create the dataloader
    val_dl = DataLoader(
        val_ds,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=stack_samples,
    )

    # Open the input file
    with rasterio.open(img_path) as f:
        input_height, input_width = f.shape
        profile = f.profile
        transform = profile["transform"]

    # Initialize the output numpy array to zeros
    output = np.zeros((input_height, input_width), dtype=np.float16)

    # Wrap the dataloader in a progress bar to iterate over the TIF patches in batches
    dl_enumerator = tqdm(val_dl)

    # Iterate over the image batches and predict
    for batch in dl_enumerator:
        # Get the images and their bounding boxes
        images = batch["image"].to(device)
        bboxes = batch["bbox"]

        with torch.inference_mode():
            # Predict over all the images
            y_hat = model(images)

            # Get the predicted probabilities for the class of interest (channel 1)
            y_hat_boma = y_hat.softmax(dim=1)[:, 1, ...].cpu().numpy()

        for i in range(len(bboxes)):
            bb = bboxes[i]

            left, top = ~transform * (bb.minx, bb.maxy)
            right, bottom = ~transform * (bb.maxx, bb.miny)
            left, right, top, bottom = (
                int(np.round(left)),
                int(np.round(right)),
                int(np.round(top)),
                int(np.round(bottom)),
            )

            assert right - left == patch_size
            assert bottom - top == patch_size

            output[
                top + padding : bottom - padding, left + padding : right - padding
            ] = y_hat_boma[i][padding:-padding, padding:-padding]

    # Save predictions once, after all batches have been processed
    profile["driver"] = "GTiff"
    profile["count"] = 1
    profile["dtype"] = "float32"
    profile["compress"] = "lzw"
    profile["predictor"] = 2
    profile["nodata"] = 0
    profile["blockxsize"] = 512
    profile["blockysize"] = 512
    profile["tiled"] = True
    profile["interleave"] = "pixel"

    # Save the file (cast to match the float32 dtype declared in the profile)
    with rasterio.open(out_fp, "w", **profile) as f:
        f.write(output.astype(np.float32), 1)
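
In `predict`, the stride is `patch_size - 2 * padding`, so neighbouring patches overlap by `2 * padding` pixels and only the interior of each patch is written to the output. The one-dimensional sketch below illustrates that idea. It is a simplification that does not reproduce how `GridGeoSampler` chooses patch positions, and both helpers are hypothetical:

```python
def patch_offsets(length, patch_size, padding):
    """Start offsets along one axis such that the patch interiors tile the axis."""
    stride = patch_size - 2 * padding
    # Stop once a full patch would no longer fit inside the axis
    return list(range(0, length - patch_size + 1, stride))


def interior_span(offset, patch_size, padding):
    """Pixel span [start, stop) written to the output for a patch starting at `offset`."""
    return offset + padding, offset + patch_size - padding


offsets = patch_offsets(length=380, patch_size=128, padding=1)
spans = [interior_span(o, 128, 1) for o in offsets]
print(offsets)  # [0, 126, 252]
print(spans)    # [(1, 127), (127, 253), (253, 379)]
```

The interiors are contiguous, but the outermost `padding` pixels of the axis (here pixels 0 and 379) are never written and keep their initial value of zero.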

In [None]:
# Get the best model checkpoint
best_model_path = checkpoint_callback.best_model_path

# Set the input/output paths
img_path = Path("./data/X.tif")
out_fp = Path("./data/y_hat.tif")

# Run the main function
predict(img_path, best_model_path, out_fp)
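
`predict` writes class probabilities in the range 0 to 1, while `raster_to_gdf` below keeps only the polygons whose `DN` value equals 1. Depending on how the polygonizer handles floating-point pixels, it may therefore help to threshold the probabilities into a 0/1 mask before vectorizing. The sketch below shows such a threshold on nested lists. The 0.5 cutoff and the `binarize` helper are assumptions for illustration and are not part of the original workflow:

```python
def binarize(prob_rows, threshold=0.5):
    """Map each probability to 1 (class of interest) or 0 (everything else)."""
    return [[1 if p >= threshold else 0 for p in row] for row in prob_rows]


# Made-up probabilities for a 2 x 3 raster
probs = [[0.02, 0.40, 0.91], [0.55, 0.49, 1.00]]
print(binarize(probs))  # [[0, 0, 1], [1, 0, 1]]
```

With a NumPy array, the same rule can be written as `(probs >= 0.5).astype("uint8")`.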

Let's vectorize our predictions and export object-wise metrics:

In [None]:
def raster_to_gdf(tif_path):
    """
    Converts a raster file to a GeoDataFrame by polygonizing it.

    This function takes the path to a raster (.tif) file, uses GDAL's gdal_polygonize utility to convert raster pixels
    into polygons, and then loads these polygons into a GeoDataFrame. It is particularly useful for converting rasterized
    masks or classifications into vector data for further geospatial analysis or visualization. The function specifically
    filters for polygons corresponding to the class of interest (with DN value of 1), removes unnecessary columns, and
    returns a clean GeoDataFrame.

    Parameters:
        tif_path (str): Path to the input raster (.tif) file.

    Returns:
        gpd.GeoDataFrame: A GeoDataFrame containing polygons for the specified class of interest from the raster. If
        polygonization times out, an empty GeoDataFrame is returned; any other failure raises an exception.
    """

    # Ensure GDAL's gdal_polygonize.py is available
    if not shutil.which("gdal_polygonize.py"):
        raise EnvironmentError(
            "gdal_polygonize.py is not available in the system path."
        )

    # Generate a random output shapefile name (integer bounds, as randint expects)
    output_shapefile = f"output_{randint(0, 10**6)}.shp"

    # Construct the command
    cmd = ["gdal_polygonize.py", tif_path, "-f", "ESRI Shapefile", output_shapefile]

    try:
        # Run the command with a timeout of 1 minute
        subprocess.run(cmd, check=True, timeout=60)

        # Load the shapefile into a GeoDataFrame
        if not os.path.exists(output_shapefile):
            raise FileNotFoundError(f"Output shapefile not found: {output_shapefile}")
        gdf = gpd.read_file(output_shapefile)

        # Filter for the class of interest & return
        return gdf[gdf["DN"] == 1].drop("DN", axis=1)

    except subprocess.TimeoutExpired:
        # Return an empty GeoDataFrame in case of a timeout
        return gpd.GeoDataFrame()

    finally:
        # Clean up whichever shapefile components were actually written
        for ext in [".shp", ".shx", ".dbf", ".prj"]:
            component = output_shapefile.replace(".shp", ext)
            if os.path.exists(component):
                os.remove(component)

In [None]:
# Vectorize the predicted masks
gdf = raster_to_gdf(out_fp)
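
`raster_to_gdf` writes a randomly named shapefile into the working directory and removes its components by hand. One alternative, sketched here as an assumption rather than as the notebook's method, is to write into a temporary directory that is deleted automatically. `with_temp_shapefile` and `fake_polygonize` are hypothetical helpers, and the latter only stands in for the real polygonization call:

```python
import os
import tempfile


def with_temp_shapefile(write_fn):
    """Call write_fn(path) on a shapefile path inside a self-cleaning temporary directory."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        shp_path = os.path.join(tmp_dir, "polygons.shp")
        result = write_fn(shp_path)
    # The directory and every file in it are gone once the block exits
    return result, tmp_dir


def fake_polygonize(shp_path):
    """Stand-in for polygonization: write empty component files and list them."""
    base = shp_path[: -len(".shp")]
    for ext in (".shp", ".shx", ".dbf", ".prj"):
        open(base + ext, "w").close()
    return sorted(os.listdir(os.path.dirname(shp_path)))


written, tmp_dir = with_temp_shapefile(fake_polygonize)
print(written)  # ['polygons.dbf', 'polygons.prj', 'polygons.shp', 'polygons.shx']
print(os.path.exists(tmp_dir))  # False
```

In a real version, the GeoDataFrame would need to be loaded inside `write_fn`, before the temporary directory is removed.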

In [None]:
# Visualize the vectorized predictions
ax = gdf.plot(color="red")
_ = ax.axis("off")
plt.show()

Next, we postprocess the predictions as follows:

1. Remove invalid geometries and geometries with a very small area (below the 1st percentile).
2. Simplify the geometries.
3. Smooth the geometries by dilation followed by erosion.
4. Fill any holes in the geometries.

In [None]:
# Filter valid geometries
gdf = gdf[gdf.geometry.is_valid]

# Filter small geometries
q01_area = gdf.geometry.area.quantile(0.01)
gdf = gdf[gdf.geometry.area > q01_area]

# Simplify the geometries (the tolerance is expressed in the units of the layer's CRS)
gdf["geometry"] = gdf.geometry.simplify(10)

# Dilate and erode the geometries (the buffer distance is also in CRS units)
buffer = 10
gdf["geometry"] = gdf.geometry.buffer(buffer)  # Dilation
gdf["geometry"] = gdf.geometry.buffer(-buffer)  # Erosion


def fillit(row):
    """Fill every hole (interior ring) of a polygon, whatever its area."""
    geom = row["geometry"]
    holes = [Polygon(ring) for ring in geom.interiors]
    if not holes:
        return geom
    # Union the original geometry with all of its holes
    return reduce(lambda geom1, geom2: geom1.union(geom2), [geom] + holes)


# Apply the function
gdf["geometry"] = gdf.apply(fillit, axis=1)
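
`fillit` fills every hole regardless of its size. If only small holes should be closed, one option is to compare each interior ring's area with a threshold first. The sketch below computes ring areas with the shoelace formula on plain coordinate lists. The threshold value and both helpers are assumptions for illustration, not part of the workflow above:

```python
def ring_area(coords):
    """Unsigned area of a ring given as [(x, y), ...] using the shoelace formula.

    A ring that repeats its first point at the end gives the same result,
    because the extra zero-length edge contributes nothing to the sum.
    """
    total = 0.0
    for (x1, y1), (x2, y2) in zip(coords, coords[1:] + coords[:1]):
        total += x1 * y2 - x2 * y1
    return abs(total) / 2.0


def holes_to_fill(holes, max_area):
    """Keep only the holes whose area does not exceed max_area."""
    return [ring for ring in holes if ring_area(ring) <= max_area]


small = [(0, 0), (2, 0), (2, 2), (0, 2)]  # 2 x 2 square
large = [(0, 0), (10, 0), (10, 10), (0, 10)]  # 10 x 10 square
print([ring_area(r) for r in (small, large)])  # [4.0, 100.0]
print(len(holes_to_fill([small, large], max_area=25)))  # 1
```

With Shapely geometries, `Polygon(ring).area` gives the same quantity, in squared CRS units.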

In [None]:
# Visualize the final predictions
ax = gdf.plot(color="red")
_ = ax.axis("off")
plt.show()

Let's save our predictions and add them to the map!

In [None]:
# Save the predictions
gdf.to_file("./data/predictions.geojson")

In [None]:
# Add to the map
m.add_vector("./data/predictions.geojson", layer_name="Predictions")

---

## Explore!

While this tutorial covered essential techniques and approaches for end-to-end mapping, the field is vast, and there's much more to explore. Here are some directions you might consider to enhance your models' training and inference capabilities further:

- **Different Architectures & Encoders**: Experiment with different neural network architectures and encoders to find the best combination for your task. The architectures and encoders available in Segmentation Models PyTorch (the `smp` library used above) are listed [here](https://smp.readthedocs.io/en/latest/).
- **Advanced Data Augmentation Techniques**: Data augmentation increases the diversity of your training dataset, which tends to produce more robust models. Investigate other augmentation techniques that simulate more varied conditions or introduce more complex transformations. Look [here](https://kornia.readthedocs.io/en/latest/augmentation.html).
- **Semi-Supervised Learning Techniques**: Semi-supervised learning can be particularly beneficial in scenarios where labeled data is scarce but unlabeled data is abundant. Explore how incorporating semi-supervised learning techniques can leverage unlabeled data to improve your model's performance.
- **Fuse Data from Different Sources**: Combining data from different sensors, such as `Sentinel-1` (radar) and `Sentinel-2` (optical), can provide complementary information that enhances model understanding and performance.
- **Scale to Multi-Class Categorization**: If your current model focuses on binary classification or a limited number of classes, consider expanding its capabilities to multi-class categorization. This expansion can increase the model's applicability and challenge it to capture more complex patterns in the data.

---

## Conclusion

Throughout this notebook, we explored geospatial machine learning and the challenge of starting machine learning-assisted mapping when no labels exist. The task is demanding because of the many steps involved: acquiring satellite imagery, preparing and processing it, setting up a labeling environment and budget, preparing a computational environment, and training the model. These requirements significantly hinder the application of deep learning models, particularly in the regions that need them most, such as Africa.

We demonstrated how **open-source software**, **publicly available satellite imagery**, and **free computational resources** can be combined to map a region of interest end to end. This approach matters most in low-resource settings, where mapping objects in satellite imagery has to rely on publicly accessible resources at minimal cost. The ultimate goal is to empower geospatial machine learning applications across the continent, especially in the regions that stand to benefit the most from these advancements.

As we conclude, it's important to remember that our journey through geospatial machine learning is just beginning. The field is ripe with opportunities for further exploration and innovation, promising to bring significant contributions to the world. Let's continue to learn, explore, and contribute to making a meaningful impact through geospatial machine learning. Happy learning!

---

## Resources

### Tutorials

- [Geospatial Primer](https://github.com/Akramz/geospatial-primer).
- [Deep Learning Indaba Geospatial Tutorial](https://github.com/deep-learning-indaba/indaba-pracs-2023).
- [Introduction to Geospatial Data](https://colab.research.google.com/drive/1-85h5tEB0AJYT8xQ5H1wtSnCafXuLTHo#scrollTo=JDT5jUmCiTH-).
- [Geospatial Data Analysis](https://colab.research.google.com/drive/1Yfkm63OV3eCtR3IVB-4owi2DJgj2Wd84).
- [Geospatial Deep learning: Getting started with TorchGeo](https://pytorch.org/blog/geospatial-deep-learning-with-torchgeo/).
- [Automating GIS-processes Course](https://autogis-site.readthedocs.io/en/latest/).
- [Geospatial Data with Python: Shapely and Fiona](https://macwright.com/2012/10/31/gis-with-python-shapely-fiona.html).
- [Introduction to Raster Data Processing in Open Source Python](https://www.earthdatascience.org/courses/use-data-open-source-python/intro-raster-data-python/raster-data-processing/).
- [Xarray fundamentals](https://rabernat.github.io/research_computing_2018/xarray.html).
- [XArray tutorials](https://github.com/xarray-contrib/xarray-tutorial).
- [Visualization: contextily tutorial](https://geopandas.org/en/stable/gallery/plotting_basemap_background.html).


### Libraries

- [Shapely](https://github.com/shapely/shapely).
- [GeoPandas](https://github.com/geopandas/geopandas).
- [Contextily](https://github.com/geopandas/contextily).
- [Rasterio](https://github.com/rasterio/rasterio).
- [Xarray](https://github.com/pydata/xarray).
- [RioXarray](https://github.com/corteva/rioxarray).
- [TorchGeo](https://github.com/microsoft/torchgeo).

---