# Road segmentation

In this notebook, we'll be segmenting road networks from aerial imagery.

In [None]:
import geocube.api.core
import geopandas as gpd
import matplotlib.pyplot as plt
import rioxarray
import xarray as xr

## Data preprocessing

### Get image data from OpenAerialMap

- OpenAerialMap images over Vanuatu - https://map.openaerialmap.org/#/168.3819580078125,-16.688816956180833,7

We'll be using a Maxar WorldView-2 image of Port Vila with a spatial resolution of 32 cm, captured on 19 October 2024.

- Preview at https://map.openaerialmap.org/#/168.31419467926025,-17.73086527059167,14/square/311123113/676737799f511a0001cc98c1

In [None]:
image_url = "https://oin-hotosm-temp.s3.us-east-1.amazonaws.com/676733089f511a0001cc98b6/0/676733089f511a0001cc98b7.tif"

The RGB images are distributed in a Cloud-optimized GeoTIFF (COG) format.
We'll follow https://corteva.github.io/rioxarray/stable/examples/COG.html to open the file in Python.

**Note**: We set `overview_level=2` to get a lower-resolution image.
The spatial resolutions at different overview levels are:
- Level -1 (native): 0.32 meters
- Level 0: 0.64 meters
- Level 1: 1.28 meters
- Level 2: 2.57 meters
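
The pattern in this list can be reproduced with a short calculation. The sketch below assumes that each overview level halves the resolution of the previous one (a common COG layout, but not confirmed for this file) and uses a hypothetical helper name, `overview_resolution`. The nominal value for level 2 comes out as 2.56 m; the 2.57 m listed above presumably reflects a native pixel size that is not exactly 0.32 m.

```python
NATIVE_RESOLUTION_M = 0.32  # native pixel size quoted in the list above


def overview_resolution(level: int, native: float = NATIVE_RESOLUTION_M) -> float:
    """Return the nominal pixel size in meters for an overview level.

    Hypothetical helper. Level -1 is treated as the native resolution,
    and every further level is assumed to double the pixel size.
    """
    if level < -1:
        raise ValueError("level must be -1 (native) or greater")
    return native * 2 ** (level + 1)


for level in (-1, 0, 1, 2):
    print(f"Level {level}: {overview_resolution(level):.2f} m")
```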

In [None]:
rda = rioxarray.open_rasterio(filename=image_url, overview_level=2)
rda

In [None]:
# Check spatial resolution in meters
rda.rio.resolution()

In [None]:
# Check coordinate reference system
rda.rio.crs.to_string()

In [None]:
# Check bounding box extent
bbox = rda.rio.bounds()
bbox

The image in an `xarray.DataArray` can be plotted using `.plot.imshow(rgb="band")`.

In [None]:
rda.plot.imshow(rgb="band")

There are some black NoData/NaN areas. Let's crop them out using
[`.rio.clip_box`](https://corteva.github.io/rioxarray/stable/examples/clip_box.html#Clip-using-a-bounding-box).

In [None]:
rda_portvila = rda.rio.clip_box(minx=18732000, miny=-2012000, maxx=18742000, maxy=-2002000)
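
As a rough sanity check on the clipped image size, the box extent can be divided by the pixel size. This is only a sketch: `expected_shape` is a hypothetical helper, the 2.57 m pixel size is the approximate level-2 value from the list above, and the real clip may differ by a pixel or so depending on how the box lines up with the pixel grid.

```python
def expected_shape(minx, miny, maxx, maxy, resolution):
    """Estimate (height, width) in pixels for a bounding box.

    Hypothetical helper; assumes square pixels of size `resolution`
    in the same units as the bounding box coordinates.
    """
    width = round((maxx - minx) / resolution)
    height = round((maxy - miny) / resolution)
    return height, width


# Same box as the clip_box call above, with the approximate level-2 pixel size
shape = expected_shape(
    minx=18732000, miny=-2012000, maxx=18742000, maxy=-2002000, resolution=2.57
)
print(shape)
```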

In [None]:
rda_portvila.plot.imshow(rgb="band")

In [None]:
bbox_portvila = rda_portvila.rio.bounds()
bbox_portvila

### Load road linestrings from shapefile

Read from zipfile containing "Roads_Vanuatu_Cleaned_UNOSAT.shp"
into a [geopandas.GeoDataFrame](https://geopandas.org/en/v1.0.1/docs/reference/geodataframe.html).

In [None]:
gdf_roads = gpd.read_file(filename="Roads_VUT.zip")
gdf_roads.head()

### Reproject vector roads to match aerial image

The vector road shapefile is in EPSG:4326,
so we need to reproject it to EPSG:3857 to match the RGB image.
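
To give a feel for what this reprojection does, here is a minimal sketch of the forward spherical Mercator formula that EPSG:3857 is based on. It is not how geopandas performs the transform (that is delegated to pyproj/PROJ). The function names are hypothetical, and the test point is the approximate map centre taken from the OpenAerialMap preview URL above.

```python
import math

EARTH_RADIUS_M = 6378137.0  # sphere radius used by EPSG:3857 (WGS 84 semi-major axis)


def lonlat_to_web_mercator(lon_deg: float, lat_deg: float) -> tuple[float, float]:
    """Convert longitude/latitude in degrees to Web Mercator x/y."""
    x = EARTH_RADIUS_M * math.radians(lon_deg)
    y = EARTH_RADIUS_M * math.asinh(math.tan(math.radians(lat_deg)))
    return x, y


def mercator_scale_factor(lat_deg: float) -> float:
    """Ratio of map distance to ground distance at a given latitude."""
    return 1.0 / math.cos(math.radians(lat_deg))


# Approximate centre of Port Vila from the preview URL
x, y = lonlat_to_web_mercator(lon_deg=168.31419, lat_deg=-17.73087)
print(f"x={x:.0f}, y={y:.0f}")
print(f"scale factor: {mercator_scale_factor(-17.73087):.3f}")
```

The point lands inside the box passed to `.rio.clip_box` earlier. The scale factor is also worth keeping in mind: distances measured in EPSG:3857 units at this latitude are about 5% longer than true ground distances.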

In [None]:
gdf_roads_3857 = gdf_roads.to_crs(crs="EPSG:3857")

Next, we'll also clip the roads to the bounding box extent of the RGB image.

In [None]:
gdf_roads_portvila = gdf_roads_3857.clip(mask=bbox_portvila)

Plot the clipped vector roads using
[`.plot()`](https://geopandas.org/en/v1.0.1/docs/reference/api/geopandas.GeoDataFrame.plot.html).

In [None]:
gdf_roads_portvila.plot()

### Rasterize road lines

The vector road lines need to be converted into a raster format for the machine learning model.
We'll first buffer the road lines to become polygons, and then rasterize them using
[`geocube.api.core.make_geocube`](https://corteva.github.io/geocube/stable/geocube.html#make-geocube).
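
The buffer-then-rasterize idea can be illustrated with a toy pure-Python sketch. This is not what geocube does internally; it simply marks every pixel whose center lies within the buffer distance of a line segment. All function names, the grid size and the buffer distance below are made up for illustration.

```python
import math


def distance_to_segment(px, py, ax, ay, bx, by):
    """Shortest distance from point (px, py) to the segment from a to b."""
    dx, dy = bx - ax, by - ay
    length_sq = dx * dx + dy * dy
    if length_sq == 0.0:
        return math.hypot(px - ax, py - ay)
    # Project the point onto the segment and clamp to its endpoints
    t = max(0.0, min(1.0, ((px - ax) * dx + (py - ay) * dy) / length_sq))
    return math.hypot(px - (ax + t * dx), py - (ay + t * dy))


def rasterize_buffered_line(segment, buffer_distance, width, height, pixel_size):
    """Return a height x width grid of 0/1 values.

    A pixel is 1 when its center is within `buffer_distance` of the segment.
    For simplicity, y increases with the row index here, unlike most rasters.
    """
    (ax, ay), (bx, by) = segment
    grid = []
    for row in range(height):
        cy = (row + 0.5) * pixel_size
        grid.append([
            1 if distance_to_segment((col + 0.5) * pixel_size, cy, ax, ay, bx, by)
            <= buffer_distance else 0
            for col in range(width)
        ])
    return grid


# Toy horizontal road across a 10 x 6 pixel grid with 2.57 m pixels
road = ((0.0, 7.71), (25.7, 7.71))
mask = rasterize_buffered_line(road, buffer_distance=3.0, width=10, height=6, pixel_size=2.57)
for row in mask:
    print("".join("#" if value else "." for value in row))
```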

In [None]:
# Buffer road lines by 12 m on each side (polygons are about 24 m wide in EPSG:3857 units)
gdf_roads_portvila.geometry = gdf_roads_portvila.buffer(distance=12)

In [None]:
rds_roads = geocube.api.core.make_geocube(
    vector_data=gdf_roads_portvila,
    like=rda_portvila,
    measurements=["FID_Road_w"],
)

In [None]:
# Convert to binary where 0=no_roads, 1=roads
rda_roads = rds_roads.FID_Road_w.notnull()
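
A small pure-Python sketch of why a not-null test is used for the binary mask rather than a truthiness test. It assumes that cells not covered by any road polygon come back as NaN, as the `.notnull()` call above implies. The grid values are invented.

```python
import math

# Toy rasterized grid of road IDs; NaN means no road polygon covers the cell.
# Note that 0.0 is a valid road ID here, not "no road".
fid_grid = [
    [math.nan, 0.0, math.nan],
    [7.0, 7.0, math.nan],
]

# Not-null test: any non-NaN value counts as road, including an ID of 0
road_mask = [[0 if math.isnan(value) else 1 for value in row] for row in fid_grid]

# Truthiness test, for comparison (NaN is truthy in Python, 0.0 is not)
truthy_mask = [[1 if value else 0 for value in row] for row in fid_grid]

print(road_mask)
print(truthy_mask)
```

With the not-null test, the cell holding road ID 0 is still marked as road, whereas a truthiness test gets both that cell and the NaN cells wrong.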

In [None]:
rda_roads.plot.imshow()

### Stack RGB image and road mask together

We now have an RGB aerial image and a rasterized road mask,
both in an `xarray.DataArray` format with the same
spatial resolution and bounding box spatial extent.
Let's stack them together using
[`xarray.merge`](https://docs.xarray.dev/en/v2025.03.1/generated/xarray.merge.html).

In [None]:
ds_image_and_mask = xr.merge(
    objects=[rda_portvila.rename("image"), rda_roads.rename("mask")],
    join="override",
)
ds_image_and_mask
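
The `join="override"` argument tells xarray to skip coordinate alignment and reuse the coordinates of the first object, which requires both objects to have the same size along each shared dimension. The toy sketch below shows why that can matter. It assumes the two grids differ by a tiny floating-point offset, which is an invented scenario and not something verified for the arrays in this notebook.

```python
import numpy as np
import xarray as xr

# Toy 3 x 4 grids; names and values are illustrative only
y = np.array([20.0, 10.0, 0.0])
x = np.arange(4) * 2.57

image = xr.DataArray(
    np.zeros((3, 4)), coords={"y": y, "x": x}, dims=("y", "x"), name="image"
)
# Same grid, but with x coordinates shifted by a tiny floating-point offset
mask = xr.DataArray(
    np.ones((3, 4)), coords={"y": y, "x": x + 1e-9}, dims=("y", "x"), name="mask"
)

# An outer join keeps every distinct coordinate, so the x dimension doubles in size
ds_outer = xr.merge([image, mask], join="outer")
# Override reuses the x coordinates of the first object (image)
ds_override = xr.merge([image, mask], join="override")

print(dict(ds_outer.sizes))
print(dict(ds_override.sizes))
```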

Double check that the resulting `xarray.Dataset`'s image and mask look OK.

In [None]:
# Create subplot with RGB image on the left and Road mask on the right
fig, axs = plt.subplots(ncols=2, figsize=(11.5, 4.5), sharey=True)
ds_image_and_mask.image.plot.imshow(ax=axs[0], rgb="band")
axs[0].set_title("Maxar RGB image")
ds_image_and_mask.mask.plot.imshow(ax=axs[1], cmap="Blues")
axs[1].set_title("Road mask")
plt.show()