# How to Segment Buildings on Drone Imagery with Fast.ai & Cloud-Native GeoData Tools

## An Interactive Intro to Geospatial Deep Learning on Google Colab 

**by [@daveluo](https://github.com/daveluo)**

In this Google Colab notebook and accompanying [Medium post](https://medium.com/@anthropoco/how-to-segment-buildings-on-drone-imagery-with-fast-ai-cloud-native-geodata-tools-ae249612c321?source=friends_link&sk=57b82002ac47724ecf9a2aaa98de994b), we will learn all the code and concepts comprising a complete workflow to automatically detect and delineate building footprints (instance segmentation) from drone imagery with cutting-edge deep learning models.

All you'll need is a Google account, an internet connection, and a couple of hours to learn how to create the data and train a model that produces something like [this](https://alpha.anthropo.co/znz-demo):

![zanzibar demo](https://www.dropbox.com/s/vv6ebxkd3x7xxfy/zanzibar_banner.png?dl=1)

## In modular steps, we'll learn to…

### Preprocess image geoTIFFs and manually labeled data geoJSON files into training data for deep learning:

![geotiff and geojson](https://cdn-images-1.medium.com/max/1200/1*Myn-7f-tLhaMRaaNYg1oHw.png)

### Create a U-net segmentation model to predict what pixels in an image represent buildings (and building-related features):

![segmentation pred vs actual](https://cdn-images-1.medium.com/max/1200/1*4qsYToRH8Q-riSFtxuNWIg.png)

### Test our model's performance on unseen imagery with GPU or CPU:

![gpu inference](https://cdn-images-1.medium.com/max/900/1*ag6ERcdl-K1-Dj6ddyrBMA.png)
![cpu inference](https://cdn-images-1.medium.com/max/900/1*kNulsQQdDJAD5xN9Q8qa_A.png)

### Post-process raw model outputs into geo-registered building shapes evaluated against ground truth:

![eval](https://cdn-images-1.medium.com/max/1200/0*D47afqJ_7l54o-s3)

### And along the way, we'll get familiar with great geospatial data & deep learning tools/resources like:

- [Geopandas](http://geopandas.org/): "an open source project to make working with geospatial data in python easier. GeoPandas extends the datatypes used by pandas to allow spatial operations on geometric types."
- [Rasterio](https://github.com/mapbox/rasterio): "reads and writes geospatial raster datasets"
- [Supermercado](https://github.com/mapbox/supermercado): "supercharger for Mercantile" (spherical mercator tile and coordinate utilities)
- [Rio-tiler](https://github.com/cogeotiff/rio-tiler): "Rasterio plugin to read mercator tiles from Cloud Optimized GeoTIFF dataset"
- [Solaris](https://github.com/CosmiQ/solaris): "Geospatial Machine Learning Analysis Toolkit" by Cosmiq Works
- [Cloud-Optimized GeoTIFFs (COG)](https://www.cogeo.org/): "An imagery format for cloud-native geospatial processing"
- [Spatio-Temporal Asset Catalogs (STAC)](https://stacspec.org/): "Enabling online search and discovery of geospatial assets"
- [OpenAerialMap](https://openaerialmap.org/): "The open collection of aerial imagery"
- [Fast.ai](http://fast.ai/) for [geospatial deep learning](https://forums.fast.ai/t/geospatial-deep-learning-resources-study-group/31044): "The fastai library simplifies training fast and accurate neural nets using modern best practices" built on the [PyTorch](https://pytorch.org/) deep learning platform.

### How to get the most out of this tutorial:

This Colab notebook is our main learning resource - working interactively here is highly recommended!

Code is organized into modular sections, set up for installation/import of all required dependencies, and executable on either CPU or GPU runtimes (depending on the section). Links to load files generated at each step are also included so you can pick up and start from any section. Inline `#` comments (& references for further reading) are provided within code cells to explain steps or nuances in more detail as needed. Executing all code cells end-to-end takes <1 hour on GPU.

The Medium post serves as a high-level conceptual walkthrough and maps directly to sections within the Colab notebook. It works best either as a quick overview with handy bookmarks into Colab, or read side-by-side with this notebook as a code & concept companion.

This tutorial assumes you have a working knowledge of Python, data analysis with Pandas, making training/validation/test sets for machine learning, and a beginner practitioner's grasp of deep learning concepts. Or the motivation to gain what knowledge you're missing by following the ample references linked throughout this post and notebook.

### With that as mental prep, let's do some geospatial deep learning!

# Pre-Processing

**Note that the preprocessing section can be run on a CPU runtime:**

Change in menu: Runtime > Change runtime type > Hardware Accelerator = None

## Install all the geo things

`pip install` the geodata processing packages we'll be using, test that they import correctly in Colab, and create our output data directories.

In [0]:
!add-apt-repository ppa:ubuntugis/ppa -y
!apt-get update
!apt-get install -y python-numpy gdal-bin libgdal-dev
!apt-get install -y python3-rtree

!pip install rasterio
!pip install geopandas
!pip install descartes
!pip install solaris
!pip install rio-tiler


In [0]:
# for bleeding edge version of solaris:
# !pip install git+https://github.com/CosmiQ/solaris/@dev

In [0]:
import solaris as sol
import numpy as np
import geopandas as gpd
from matplotlib import pyplot as plt
from pathlib import Path
import rasterio
import os

data_dir = Path('data')
data_dir.mkdir(exist_ok=True)

img_path = data_dir/'images-256'
mask_path = data_dir/'masks-256'
img_path.mkdir(exist_ok=True)
mask_path.mkdir(exist_ok=True)

## Preview and load imagery and labels

For this tutorial, we'll use the [Tanzania Open AI Challenge dataset](https://competitions.codalab.org/competitions/20100#learn_the_details) of 7-cm resolution drone imagery and building footprint labels over Unguja Island, Zanzibar. 

Many thanks to the following organizations for producing, openly licensing, and making this invaluable dataset accessible:

- [CC-BY-4.0](https://creativecommons.org/licenses/by/4.0/) licensed by Commission for Lands (COLA) - Revolutionary Government of Zanzibar (RGoZ)
- Labeled data produced & processed by State University of Zanzibar (SUZA), [World Bank OpenDRI](https://opendri.org/project/zanzibar/), [WeRobotics](https://werobotics.org/)
- Drone imagery created by [Zanzibar Mapping Initiative](http://www.zanzibarmapping.com/)  and hosted on [OpenAerialMap](https://map.openaerialmap.org/#/39.40040588378906,-5.980094945523311,10/square/3001121111?_k=34xcng):

![](https://cdn-images-1.medium.com/max/1200/1*-kSPTKsU5vF2c9NEKjOqNQ.png)

For simplicity of demonstration, we'll create training and validation data from a single drone image (in cloud-optimized geoTIFF format) and its accompanying ground-truth labels of manually traced building outlines (in GeoJSON format).

We'll work with imagery and labels from image grid `znz001` which covers the northern tip of Zanzibar's main island of Unguja. Here is a [browsable preview](https://geoml-samples.netlify.com/item/9Eiufow7wPXLqQEP1Di2J5X8kXkBLgMsCBoN37VrtRPB/2sEaEKnnyjG2mx7CnN1ESAdjYAEQjoNRxSxTjc4vPGR?si=0&t=preview#15/-5.732621/39.301114) of the drone imagery with its building footprint labels, organized per the Spatio-Temporal Asset Catalog ([STAC](https://github.com/radiantearth/stac-spec/)) [label extension](https://github.com/radiantearth/stac-spec/tree/dev/extensions/label) and visualized in an instance of [STAC browser](https://github.com/radiantearth/stac-browser):

![znz001-preview](https://www.dropbox.com/s/lfrzpgukk922jtr/stacpreview-znz001.png?dl=1)

After previewing the labeled data and imagery, let's copy the direct download URLs from the Assets tab of the browser and test loading them with our geo-processing tools.

In [0]:
tif_url = 'http://oin-hotosm.s3.amazonaws.com/5afeda152b6a08001185f11a/0/5afeda152b6a08001185f11b.tif'
geojson_url = 'https://www.dropbox.com/sh/ct3s1x2a846x3yl/AAARCAOqhcRdoU7ULOb9GJl9a/grid_001.geojson?dl=1'

In [0]:
rasterio.open(tif_url).meta

In [0]:
# TODO: bug with rasterio/gdal not loading https urls, workaround by using http: urls or download file locally
# !export CURL_CA_BUNDLE=/etc/ssl/certs/ca-certificates.crt
!wget -O tmp.tif {tif_url}

In [0]:
rasterio.open('tmp.tif').meta

In [0]:
# load geojson for znz001 labels

label_df = gpd.read_file(geojson_url)
label_df = label_df[label_df['geometry'].notna()] # remove empty rows

In [0]:
label_df.plot(figsize=(10,10))

## Draw train and validation areas of interest (AOIs) with geojson.io

Since we are working with a single image, we need to delineate what sub-areas of the image and labels should be used as training versus validation data for model training.

Using [geojson.io](http://geojson.io), we'll draw our `trn` and `val` Areas of Interest (AOI) polygons in geojson format and add `dataset:trn` or `dataset:val` to the respective polygon `properties`.

The finished polygons look something like this in geojson.io:
![trn and val AOI polygons drawn in geojson.io](https://www.dropbox.com/s/v8u8ihnuj6b5lbl/geojson_screenshot3.png?dl=1)

And here is the exact GeoJSON file I created viewable in geojson.io:
http://geojson.io/#id=gist:daveluo/8e192744b2aa377db162bc34e0e0ae64&map=15/-5.7314/39.3026

**protip:** in geojson.io, you can display the drone imagery as a base layer via the menu: Meta > Add map layer > Layer URL: https://tiles.openaerialmap.org/5b100d4b2b6a08001185f344/0/5b100d4b2b6a08001185f345/{z}/{x}/{y}.png

For demonstration of later steps, I intentionally drew a more complex shape for each AOI, but we could have simply drawn adjacent rectangles instead.

In more complex cases, we could also choose to draw AOIs of smaller sub-areas that don't encompass the entire image: for instance, if we want to create training data only for specific types of environments, like dense urban areas or sparsely populated rural areas, or if we want to avoid using poorly labeled areas in our training data.

Drawing the AOIs as GeoJSON polygons in this way gives us the flexibility to choose exactly what and where our training and validation data represent.
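To make the expected structure concrete, here is a minimal sketch of an AOI FeatureCollection built with Python's standard `json` module rather than drawn in geojson.io. The two rectangles and the helper names (`rect`, `aoi_feature`, `split_by_dataset`) are hypothetical placeholders, not the AOIs actually used in this tutorial; the part that matters downstream is the `dataset` property of `trn` or `val` on each feature.

```python
import json

def rect(west, south, east, north):
    """Closed GeoJSON polygon ring (lon, lat order) for a lon/lat bounding box."""
    return [[[west, south], [east, south], [east, north], [west, north], [west, south]]]

def aoi_feature(dataset, bbox):
    """One AOI polygon tagged with dataset='trn' or 'val' in its properties."""
    return {"type": "Feature",
            "properties": {"dataset": dataset},
            "geometry": {"type": "Polygon", "coordinates": rect(*bbox)}}

# Hypothetical placeholder AOIs: two adjacent rectangles sharing the lon=39.30 edge.
aoi = {"type": "FeatureCollection",
       "features": [aoi_feature("trn", (39.29, -5.74, 39.30, -5.72)),
                    aoi_feature("val", (39.30, -5.74, 39.31, -5.72))]}

def split_by_dataset(fc, name):
    """New FeatureCollection with only the features whose 'dataset' equals name."""
    feats = [f for f in fc["features"] if f["properties"].get("dataset") == name]
    return {"type": "FeatureCollection", "features": feats}

geojson_text = json.dumps(aoi)      # what you would save as aoi.geojson
roundtrip = json.loads(geojson_text)
trn_fc = split_by_dataset(roundtrip, "trn")
val_fc = split_by_dataset(roundtrip, "val")
```

This plays the same role as the geopandas filter `aoi_df[aoi_df['dataset']=='trn']` used in the next section.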

## Convert train and validation AOIs to slippy map tile polygons with supermercado and geopandas

In this step, we'll use [supermercado](https://github.com/mapbox/supermercado) to generate square polygons representing all the [slippy map tiles](https://wiki.openstreetmap.org/wiki/Slippy_map_tilenames) at a specified zoom level that overlap the geojson training and validation AOIs we created above. 

For this tutorial, we'll work with slippy map tiles of `tile_size=256` and `zoom_level=19` which yields a manageable number of tiles and satisfactory segmentation results without too much preprocessing or model training time.  

You could also try setting a higher or lower `zoom_level`, which would generate more or fewer tiles at higher or lower resolutions, respectively.
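For intuition about how `zoom_level` trades tile count against resolution, here is a pure-Python sketch of the standard slippy map tile formulas (Web Mercator, 256-pixel tiles). The helper names are our own; in practice, `mercantile.tile(lng, lat, zoom)` and `mercantile.bounds(tile)` from the mercantile package used below perform the same conversions. Each zoom step doubles the number of tiles along x and along y (4x the tiles overall) and halves the ground distance covered by each pixel.

```python
import math

# Web Mercator uses the WGS84 equatorial radius; 2*pi*R is the world width at zoom 0.
EARTH_CIRCUMFERENCE_M = 2 * math.pi * 6378137

def lonlat_to_tile(lon, lat, z):
    """Slippy map (x, y, z) tile containing a lon/lat point."""
    n = 2 ** z
    x = math.floor((lon + 180.0) / 360.0 * n)
    y = math.floor((1.0 - math.asinh(math.tan(math.radians(lat))) / math.pi) / 2.0 * n)
    return x, y, z

def tile_bounds(x, y, z):
    """(west, south, east, north) of a tile, in degrees."""
    n = 2 ** z
    def lat_of(row):  # latitude of a tile row's top edge
        return math.degrees(math.atan(math.sinh(math.pi * (1 - 2 * row / n))))
    west = x / n * 360.0 - 180.0
    east = (x + 1) / n * 360.0 - 180.0
    return west, lat_of(y + 1), east, lat_of(y)

def ground_resolution(lat, z, tile_size=256):
    """Approximate meters per pixel at a latitude and zoom level."""
    return math.cos(math.radians(lat)) * EARTH_CIRCUMFERENCE_M / (tile_size * 2 ** z)

# Approximate center of the znz001 STAC preview link above.
lon, lat = 39.301114, -5.732621
for z in (18, 19, 20, 21):
    print(z, lonlat_to_tile(lon, lat, z), f"{ground_resolution(lat, z):.3f} m/px")
```

At this latitude, zoom 19 works out to roughly 0.30 m per pixel versus the imagery's native ~7 cm, so our zoom-19 tiles are a downsampled view; zoom 21 would be close to native resolution.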

Here is an example of different tile `zoom_levels` over the same area of Zanzibar (see the round, white satellite TV dish for a consistently sized visual reference):

![zoom level comparison](https://cdn-images-1.medium.com/max/1200/1*06aV0V5_-uu0_mQCe13sBA.png)

Learn more about slippy maps [here](https://wiki.openstreetmap.org/wiki/Slippy_Map), [here](https://developers.planet.com/tutorials/slippy-maps-101/), and [here](https://wiki.openstreetmap.org/wiki/Zoom_levels). 


Then we'll merge our supermercado-generated slippy map tile polygons into one `GeoDataFrame` with [geopandas](http://geopandas.org/). We'll also check for and reconcile overlapping train and validation tiles, which would otherwise leak training data into validation and throw off how we evaluate our progress with model training.
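Overlaps arise because supermercado burns every tile that intersects an AOI, so tiles straddling the shared trn/val border are emitted for both. The `pd.concat([trn_tiles, val_tiles])` + `drop_duplicates(subset=['id'])` step below resolves them with pandas' default `keep='first'`, so a border tile stays in `trn`. Here is a pure-Python sketch of that rule, using a hypothetical `merge_tiles` helper and made-up ids in the `'(x, y, z)'` string form that `reformat_xyz` parses below:

```python
def merge_tiles(trn_ids, val_ids):
    """Assign each tile id to one split; the first occurrence wins, so trn beats val.

    Mirrors pd.concat([trn_tiles, val_tiles]) followed by
    drop_duplicates(subset=['id']) with pandas' default keep='first'.
    """
    assigned = {}
    for dataset, ids in (("trn", trn_ids), ("val", val_ids)):
        for tile_id in ids:
            assigned.setdefault(tile_id, dataset)  # skip ids already assigned
    return assigned

# Made-up ids; '(2, 1, 19)' sits on the trn/val border and is claimed by both AOIs.
print(merge_tiles(["(1, 1, 19)", "(2, 1, 19)"], ["(2, 1, 19)", "(3, 1, 19)"]))
```

Putting `trn_tiles` first in the concat is therefore what decides which split keeps border tiles; reversing the order would hand them to validation instead.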

In [0]:
# download pre-made AOI geojson file:
!wget -O aoi.geojson https://www.dropbox.com/s/ojyjvvoer5guadr/znz001_trnval2.geojson?dl=1

In [0]:
tile_size = 256
zoom_level = 19

In [0]:
aoi_df = gpd.read_file('aoi.geojson')
aoi_df.plot()

In [0]:
aoi_df[aoi_df['dataset']=='trn']['geometry'].to_file('trn_aoi.geojson', driver='GeoJSON')
aoi_df[aoi_df['dataset']=='val']['geometry'].to_file('val_aoi.geojson', driver='GeoJSON')

In [0]:
# see https://github.com/mapbox/supermercado#supermercado-burn
!cat trn_aoi.geojson | supermercado burn {zoom_level} | mercantile shapes | fio collect > trn_aoi_z{zoom_level}tiles.geojson
!cat val_aoi.geojson | supermercado burn {zoom_level} | mercantile shapes | fio collect > val_aoi_z{zoom_level}tiles.geojson

In [0]:
trn_tiles = gpd.read_file(f'trn_aoi_z{zoom_level}tiles.geojson')
val_tiles = gpd.read_file(f'val_aoi_z{zoom_level}tiles.geojson')
trn_tiles['dataset'] = 'trn'
val_tiles['dataset'] = 'val'

In [0]:
# see if there's overlapping tiles between trn and val
fig, ax = plt.subplots(figsize=(10,10))
trn_tiles.plot(ax=ax, color='grey', alpha=0.5, edgecolor='red')
val_tiles.plot(ax=ax, color='grey', alpha=0.5, edgecolor='blue')

In [0]:
# merge into one gdf to keep all trn tiles while dropping overlapping/duplicate val tiles
import pandas as pd
tiles_gdf = gpd.GeoDataFrame(pd.concat([trn_tiles, val_tiles], ignore_index=True), crs=trn_tiles.crs)
tiles_gdf.drop_duplicates(subset=['id'], inplace=True)

In [0]:
# check that there's no more overlapping tiles between trn and val
fig, ax = plt.subplots(figsize=(10,10))
tiles_gdf[tiles_gdf['dataset'] == 'trn'].plot(ax=ax, color='grey', edgecolor='red', alpha=0.5)
tiles_gdf[tiles_gdf['dataset'] == 'val'].plot(ax=ax, color='grey', edgecolor='blue', alpha=0.5)

In [0]:
tiles_gdf.head()

In [0]:
# convert 'id' string like '(x, y, z)' to a list of ints [x, y, z]

def reformat_xyz(tile_gdf):
  tile_gdf['xyz'] = tile_gdf.id.apply(lambda s: [int(v) for v in s.strip('()').split(',')])
  return tile_gdf

In [0]:
tiles_gdf = reformat_xyz(tiles_gdf)
tiles_gdf.head()

## Load slippy map tile image from COG with rio-tiler and corresponding label with geopandas

Now we'll use [rio-tiler](https://github.com/cogeotiff/rio-tiler) and the slippy map tile polygons generated by supermercado to test-load a single 256x256 pixel tile from our znz001 COG image file. We will also load the znz001 GeoJSON labels into a geopandas GeoDataFrame and crop the building geometries to only those that intersect the bounds of the tile image.

Here is a great intro to COGs, rio-tiler, and exciting developments in the cloud-native geospatial toolbox by [Vincent Sarago](https://medium.com/@_VincentS_) of [Development Seed](https://developmentseed.org/): https://medium.com/devseed/cog-talk-part-1-whats-new-941facbcd3d1

We'll then create our corresponding 3-channel RGB mask by passing these cropped geometries to solaris' `df_to_px_mask` function. A pixel value of 255 in the generated mask:

- in the 1st (Red) channel represents building footprints,
- in the 2nd (Green) channel represents building boundaries (these look yellow on the RGB mask display because boundary pixels also overlap footprint pixels, and green + red = yellow),
- and in the 3rd (Blue) channel represents close contact points between adjacent buildings.
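The affine transform passed to `df_to_px_mask` below (built with `rasterio.transform.from_bounds`) is what places each building vertex at a pixel position in the 256x256 mask. As a rough sketch of that mapping for a north-up tile, here is the same arithmetic with a hypothetical `GeoTransform` stand-in rather than rasterio's `Affine` object, on made-up tile bounds:

```python
from typing import NamedTuple

class GeoTransform(NamedTuple):
    """Hypothetical stand-in for a north-up affine transform (no rotation/shear)."""
    west: float    # x of the top-left corner (Affine 'c')
    north: float   # y of the top-left corner (Affine 'f')
    px_w: float    # degrees per pixel along x (Affine 'a')
    px_h: float    # degrees per pixel along y, negative for north-up (Affine 'e')

def transform_from_bounds(west, south, east, north, width, height):
    """Same arithmetic as rasterio.transform.from_bounds for a north-up grid."""
    return GeoTransform(west, north, (east - west) / width, (south - north) / height)

def lonlat_to_pixel(t, lon, lat):
    """Fractional (col, row); (0, 0) is the top-left corner of the tile."""
    return (lon - t.west) / t.px_w, (lat - t.north) / t.px_h

def pixel_to_lonlat(t, col, row):
    """Inverse mapping from a (col, row) pixel position back to lon/lat."""
    return t.west + col * t.px_w, t.north + row * t.px_h

# Hypothetical bounds roughly the size of one zoom-19 tile.
tfm_sketch = transform_from_bounds(39.3000, -5.7336, 39.3006, -5.7330, 256, 256)
print(lonlat_to_pixel(tfm_sketch, 39.3003, -5.7333))  # tile center, approximately (128, 128)
```

With the real object, `~tfm * (lon, lat)` should give the same (col, row). This mapping is linear in degrees, just like the notebook's `from_bounds` call on lon/lat bounds, which is a close approximation to the Web Mercator tile grid over an area as small as one zoom-19 tile.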

In [0]:
from rio_tiler import main as rt_main

# import mercantile
from rasterio.transform import from_bounds
from shapely.geometry import Polygon
from shapely.ops import cascaded_union

In [0]:
idx = 220
tiles_gdf.iloc[idx]['xyz']

In [0]:
tile, mask = rt_main.tile(tif_url, *tiles_gdf.iloc[idx]['xyz'], tilesize=tile_size)

In [0]:
plt.imshow(np.moveaxis(tile,0,2))

In [0]:
# redisplay our labeled geojson file
label_df.plot(figsize=(10,10))

In [0]:
# get the geometries from the geodataframe
all_polys = label_df.geometry

In [0]:
# preemptively fix and merge any invalid or overlapping geoms that would otherwise throw errors during the rasterize step. 
# TODO: probably a better way to do this

# https://gis.stackexchange.com/questions/271733/geopandas-dissolve-overlapping-polygons
# https://nbviewer.jupyter.org/gist/rutgerhofste/6e7c6569616c2550568b9ce9cb4716a3

def explode(gdf):
    """    
    Will explode the geodataframe's muti-part geometries into single 
    geometries. Each row containing a multi-part geometry will be split into
    multiple rows with single geometries, thereby increasing the vertical size
    of the geodataframe. The index of the input geodataframe is no longer
    unique and is replaced with a multi-index. 

    The output geodataframe has an index based on two columns (multi-index) 
    i.e. 'level_0' (index of input geodataframe) and 'level_1' which is a new
    zero-based index for each single part geometry per multi-part geometry
    
    Args:
        gdf (gpd.GeoDataFrame) : input geodataframe with multi-geometries
        
    Returns:
        gdf (gpd.GeoDataFrame) : exploded geodataframe with each single 
                                 geometry as a separate entry in the 
                                 geodataframe. The GeoDataFrame has a multi-
                                 index set to columns level_0 and level_1
        
    """
    gs = gdf.explode()
    gdf2 = gs.reset_index().rename(columns={0: 'geometry'})
    gdf_out = gdf2.merge(gdf.drop('geometry', axis=1), left_on='level_0', right_index=True)
    gdf_out = gdf_out.set_index(['level_0', 'level_1']).set_geometry('geometry')
    gdf_out.crs = gdf.crs
    return gdf_out

def cleanup_invalid_geoms(all_polys):
  all_polys_merged = gpd.GeoDataFrame()
  all_polys_merged['geometry'] = gpd.GeoSeries(cascaded_union([p.buffer(0) for p in all_polys]))

  gdf_out = explode(all_polys_merged)
  gdf_out = gdf_out.reset_index()
  gdf_out.drop(columns=['level_0','level_1'], inplace=True)
  all_polys = gdf_out['geometry']
  return all_polys

all_polys = cleanup_invalid_geoms(all_polys)

In [0]:
# get the same tile polygon as our tile image above
tile_poly = tiles_gdf.iloc[idx]['geometry']
print(tile_poly.bounds)
tile_poly

In [0]:
# get affine transformation matrix for this tile using rasterio.transform.from_bounds: https://rasterio.readthedocs.io/en/stable/api/rasterio.transform.html#rasterio.transform.from_bounds
tfm = from_bounds(*tile_poly.bounds, tile_size, tile_size) 
tfm

In [0]:
# crop znz001 geometries to what overlaps our tile polygon bounds
cropped_polys = [poly for poly in all_polys if poly.intersects(tile_poly)]
cropped_polys_gdf = gpd.GeoDataFrame(geometry=cropped_polys, crs={'init': 'epsg:4326'})
cropped_polys_gdf.plot()

In [0]:
# burn a footprint/boundary/contact 3-channel mask with solaris: https://solaris.readthedocs.io/en/latest/tutorials/notebooks/api_masks_tutorial.html

fbc_mask = sol.vector.mask.df_to_px_mask(df=cropped_polys_gdf,
                                         channels=['footprint', 'boundary', 'contact'],
                                         affine_obj=tfm, shape=(tile_size,tile_size),
                                         boundary_width=5, boundary_type='inner', contact_spacing=5, meters=True)

In [0]:
fig, (ax1, ax2) = plt.subplots(1,2,figsize=(10, 5))
ax1.imshow(np.moveaxis(tile,0,2))
ax2.imshow(fbc_mask)

In [0]:
fig, (ax1, ax2, ax3) = plt.subplots(1,3,figsize=(10, 5))
ax1.imshow(fbc_mask[:,:,0])
ax2.imshow(fbc_mask[:,:,1])
ax3.imshow(fbc_mask[:,:,2])

## Make and save all the image and mask tiles

Now that we've successfully loaded one tile image from COG with rio-tiler and created its 3-channel RGB mask with solaris, let's generate our full training and validation datasets. 

We'll write some functions and loops to run through all of our `trn` and `val` tiles at `zoom_level=19` and save them as lossless `png` files in the appropriate folders with a filename schema of `{save_path}/{prefix}{z}_{x}_{y}` so we can easily identify and geolocate what tile each file represents.
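Because the filename encodes everything about a tile, we can recover the split and tile coordinates from it later. Here is a small sketch with Python's `re` module, assuming the `znz001{dataset}_` prefix used in the loops below. The regex and `parse_tile_name` are our own illustration, and the tile numbers in the example are made up:

```python
import re
from pathlib import Path

# Hypothetical pattern for names like 'znz001trn_19_<x>_<y>.png' and '..._mask.png'
# written by save_tile_img / save_tile_mask with prefix f'znz001{dataset}_'.
TILE_NAME = re.compile(
    r"^(?P<grid>[a-z0-9]+?)(?P<dataset>trn|val)_"
    r"(?P<z>\d+)_(?P<x>\d+)_(?P<y>\d+)(?P<mask>_mask)?\.png$"
)

def parse_tile_name(name):
    """Recover grid, split, (x, y, z) tile and mask flag from a saved tile path."""
    m = TILE_NAME.match(Path(name).name)
    if m is None:
        raise ValueError(f"unexpected tile filename: {name}")
    return {
        "grid": m["grid"],
        "dataset": m["dataset"],
        # filenames store z first; return (x, y, z) to match tiles_gdf['xyz']
        "xyz": (int(m["x"]), int(m["y"]), int(m["z"])),
        "is_mask": m["mask"] is not None,
    }

print(parse_tile_name("data/images-256/znz001val_19_319380_260417.png"))
```

Passing the returned `xyz` tuple to `mercantile.bounds(*xyz)` would give back the tile's lon/lat bounds, which is what makes each saved file geolocatable.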

In [0]:
import skimage.io  # import the io submodule explicitly for imsave/imread
from tqdm import tqdm

In [0]:
def save_tile_img(tif_url, xyz, tile_size, save_path='', prefix='', display=False):
  x,y,z = xyz
  tile, mask = rt_main.tile(tif_url, x,y,z, tilesize=tile_size)
  if display: 
    plt.imshow(np.moveaxis(tile,0,2))
    plt.show()
    
  skimage.io.imsave(f'{save_path}/{prefix}{z}_{x}_{y}.png',np.moveaxis(tile,0,2), check_contrast=False) 

In [0]:
def save_tile_mask(labels_poly, tile_poly, xyz, tile_size, save_path='', prefix='', display=False):
  x,y,z = xyz
  tfm = from_bounds(*tile_poly.bounds, tile_size, tile_size) 
  
  cropped_polys = [poly for poly in labels_poly if poly.intersects(tile_poly)]
  cropped_polys_gdf = gpd.GeoDataFrame(geometry=cropped_polys, crs={'init': 'epsg:4326'})
  
  fbc_mask = sol.vector.mask.df_to_px_mask(df=cropped_polys_gdf,
                                         channels=['footprint', 'boundary', 'contact'],
                                         affine_obj=tfm, shape=(tile_size,tile_size),
                                         boundary_width=5, boundary_type='inner', contact_spacing=5, meters=True)
  
  if display: plt.imshow(fbc_mask); plt.show()
  
  skimage.io.imsave(f'{save_path}/{prefix}{z}_{x}_{y}_mask.png',fbc_mask, check_contrast=False) 

In [0]:
tiles_gdf[tiles_gdf['dataset'] == 'trn'].shape, tiles_gdf[tiles_gdf['dataset'] == 'val'].shape

In [0]:
# we'll load our COG locally but could also load directly from url which is slower and subject to potentially more i/o issues
# TODO: try loading from url and catch i/o exceptions
# TODO: multithread/multiprocess this? Took ~3.5 mins to load and save 1261 image tiles on local COG file loading
for idx, tile in tqdm(tiles_gdf.iterrows()):
  dataset = tile['dataset']
  save_tile_img('tmp.tif', tile['xyz'], tile_size, save_path=img_path, prefix=f'znz001{dataset}_', display=False)

In [0]:
# TODO: multiprocess this? Took ~3 mins to burn and save 1261 masks
for idx, tile in tqdm(tiles_gdf.iterrows()):
  dataset = tile['dataset']
  tile_poly = tile['geometry']
  save_tile_mask(all_polys, tile_poly, tile['xyz'], tile_size, save_path=mask_path,prefix=f'znz001{dataset}_', display=False)
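The TODO comments above ask whether these loops could be parallelized. One possible approach is sketched below with the standard library's `concurrent.futures.ThreadPoolExecutor`. The `run_parallel` helper is hypothetical and has not been tested against rio-tiler/GDAL here: it assumes each `save_tile_img` call opens its own dataset handle and is called with `display=False`, since matplotlib is not thread-safe. Whether threads actually speed things up depends on how much of the work releases Python's GIL.

```python
from concurrent.futures import ThreadPoolExecutor

def run_parallel(fn, jobs, max_workers=8):
    """Call fn(*args) for every args tuple in jobs using a thread pool.

    Results come back in the same order as jobs; an exception raised by fn
    is re-raised here when its result is collected.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # map over the transposed job tuples so fn receives positional arguments
        return list(pool.map(fn, *zip(*jobs)))

# Hypothetical usage with the notebook's save_tile_img (not run here):
# jobs = [('tmp.tif', row['xyz'], tile_size, img_path, f"znz001{row['dataset']}_")
#         for _, row in tiles_gdf.iterrows()]
# run_parallel(save_tile_img, jobs)
```

Because the helper avoids lambdas, swapping in `ProcessPoolExecutor` is an option to try if threads don't help.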

In [0]:
# check that tile images and masks saved correctly
start_idx, end_idx = 200,205
for i,j in zip(sorted(img_path.iterdir())[start_idx:end_idx], sorted(mask_path.iterdir())[start_idx:end_idx]):
  fig, (ax1,ax2) = plt.subplots(1,2,figsize=(10,5))
  ax1.imshow(skimage.io.imread(i))
  ax2.imshow(skimage.io.imread(j))
  plt.show()

In [0]:
# compress and download
!tar -czf znz001trn.tar.gz data

## Save files to GDrive (or download to computer)

Colab does not persistently store any files created and saved in its runtimes for more than 8-12 hours (or less, depending on inactivity or overall demand on the system), so we'll transfer or download the files we create somewhere else. We can:

1. **Mount our own account's Google Drive storage and transfer files** there via a `!cp` command: see below cell
2. **Download files to local computer:** go to Files tab on left > find and right-click selected file > click Download > file will be prepared by Colab and automatically downloaded when ready

In [0]:
# to mount and transfer to GDrive: uncomment and run this and the next cell, then follow the instructions to authorize access to your GDrive

# from google.colab import drive
# drive.mount('/content/drive')

In [0]:
# copy training data compressed tarball to root of your GDrive
# !cp znz001trn.tar.gz /content/drive/My\ Drive/ 

# Train u-net segmentation model with fastai & pytorch

As our deep learning framework and library of tools, we'll use the excellent [fastai](https://github.com/fastai/fastai) library built on top of [PyTorch](https://pytorch.org/). 

For more info:
- about Fast.ai, the organization: https://www.fast.ai/about/
- direct links to the free MOOC series: 
  - Part 1 ("Practical Deep Learning for Coders"): https://course.fast.ai/index.html
  - Part 2 ("Deep Learning from the Foundations"): https://course.fast.ai/part2

## Download and install fastai

Let's download, install, and set up fastai v1 (currently at 1.0.55). And if we're not already on it, let's switch Colab to a GPU runtime. Switching runtimes starts a new environment and removes locally stored files, so you will have to re-download and untar the training dataset created in the steps above:

**SWITCH TO GPU RUNTIME: Menu > Runtime > Change runtime type > Hardware Accelerator = GPU**

Colab's free GPU type (for example, a Tesla K80 or T4) varies with availability. See the `===Hardware===` section of `show_install()` for which GPU type and how much GPU memory are available, since these affect the usable batch size and training time.

For all of these GPUs and memory sizes, a batch size of `bs=16` at `size=256` should train at <2 mins/epoch without out-of-memory errors; if you do hit one, lower `bs` to 8.


In [0]:
!curl https://course.fast.ai/setup/colab | bash

In [0]:
from fastai.vision import *
from fastai.callbacks import *

In [0]:
from fastai.utils.collect_env import *
show_install(True)

## Set up data

Now we'll set up our training dataset of tile images and masks created above to load correctly into fastai for training and validation. 

The code in this step tracks closely with that of fastai course's lesson3-camvid so please refer to that [lesson video](https://course.fast.ai/videos/?lesson=3) and [notebook](https://nbviewer.jupyter.org/github/fastai/course-v3/blob/master/nbs/dl1/lesson3-camvid.ipynb) for more detailed and excellent explanation by Jeremy Howard about the code and fastai's [Data Block API](https://docs.fast.ai/data_block.html).

The first main departure from the camvid lesson notebook is the use of filename string parsing to determine which image and mask files comprise the validation data.

And we'll subclass `SegmentationLabelList` to alter the behavior of `open_mask` and `PIL.Image` underlying it in order to open the 3-channel target masks as RGB images `(convert_mode='RGB')` instead of default greyscale 1-channel images `(convert_mode='L')`.

We'll also visually confirm that the image files and channels of the respective target mask file are loaded and paired correctly with a display function `show_3ch`.
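To see why the mask mode matters, here is a pure-Python sketch of what happens to a single mask pixel. Pillow documents its RGB-to-`'L'` conversion as the ITU-R 601-2 luma transform `L = R*299/1000 + G*587/1000 + B*114/1000` (stored as an 8-bit integer), so greyscale loading would fold footprint, boundary, and contact into one intensity per pixel, whereas RGB loading with `div=True` keeps three independent 0/1 targets for the per-channel losses defined later. The helper names are hypothetical:

```python
def to_greyscale(r, g, b):
    """ITU-R 601-2 luma transform that Pillow documents for convert('L')."""
    return r * 299 / 1000 + g * 587 / 1000 + b * 114 / 1000

def to_targets(r, g, b, div=True):
    """Per-channel targets as kept by convert_mode='RGB'; div=True rescales 0/255 to 0/1."""
    return (r / 255, g / 255, b / 255) if div else (r, g, b)

# A footprint pixel on a building edge: Red (footprint) and Green (boundary) both on.
pixel = (255, 255, 0)
print(to_greyscale(*pixel))  # one blended intensity, not three labels
print(to_targets(*pixel))    # (1.0, 1.0, 0.0): footprint=1, boundary=1, contact=0
```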

In [0]:
# if not already present in file storage, download and extract the training/validation dataset created in above sections
!wget -O znz001trn.tar.gz https://www.dropbox.com/s/2a2ikf7m265davv/znz001trn.tar.gz?dl=1
!tar -xf znz001trn.tar.gz

In [0]:
path = Path('data')
path.ls()

In [0]:
path_lbl = path/'masks-256'
path_img = path/'images-256'

In [0]:
fnames = get_image_files(path_img)
lbl_names = get_image_files(path_lbl)
print(len(fnames), len(lbl_names))
fnames[:3], lbl_names[:3] 

In [0]:
get_y_fn = lambda x: path_lbl/f'{x.stem}_mask.png'

In [0]:
# test that masks are opening correctly with open_mask() settings
img_f = fnames[121]
img = open_image(img_f)
mask = open_mask(get_y_fn(img_f), convert_mode='RGB', div=False)

fig,ax = plt.subplots(1,1, figsize=(10,10))
img.show(ax=ax)
mask.show(ax=ax, alpha=0.5)

In [0]:
plt.hist(mask.data.view(-1), bins=3)

In [0]:
# define the validation set by fn prefix
holdout_grids = ['znz001val_']
valid_idx = [i for i,o in enumerate(fnames) if any(c in str(o) for c in holdout_grids)]
print(len(valid_idx))

In [0]:
# subclassing SegmentationLabelList to set open_mask(fn, div=True, convert_mode='RGB') for 3 channel target masks

class SegLabelListCustom(SegmentationLabelList):
    def open(self, fn): return open_mask(fn, div=True, convert_mode='RGB')
    
class SegItemListCustom(SegmentationItemList):
    _label_cls = SegLabelListCustom

In [0]:
# the classes corresponding to each channel
codes = np.array(['Footprint','Boundary','Contact'])

In [0]:
size = 256
bs = 16

In [0]:
# define image transforms for data augmentation and create databunch. More about image tfms and data aug at https://docs.fast.ai/vision.transform.html 
tfms = get_transforms(flip_vert=True, max_warp=0.1, max_rotate=20, max_zoom=2, max_lighting=0.3)

src = (SegItemListCustom.from_folder(path_img)
        .split_by_idx(valid_idx)
        .label_from_func(get_y_fn, classes=codes))

data = (src.transform(tfms, size=size, tfm_y=True)
        .databunch(bs=bs)
        .normalize(imagenet_stats))

In [0]:
def show_3ch(imgitem, figsize=(10,5)):
    fig, (ax1, ax2, ax3) = plt.subplots(1,3, figsize=figsize)
    ax1.imshow(np.asarray(imgitem.data[0,None])[0])
    ax2.imshow(np.asarray(imgitem.data[1,None])[0])
    ax3.imshow(np.asarray(imgitem.data[2,None])[0])
    
    ax1.set_title('Footprint')
    ax2.set_title('Boundary')
    ax3.set_title('Contact')
    
    plt.show()

In [0]:
for idx in range(10,15):
    print(data.valid_ds.items[idx].name)
    fig, (ax1,ax2) = plt.subplots(1,2, figsize=(10,5))
    data.valid_ds.x[idx].show(ax=ax1)
    ax2.imshow(image2np(data.valid_ds.y[idx].data*255))
    plt.show()
    show_3ch(data.valid_ds.y[idx])

In [0]:
# visually inspect data-augmented training images
# TODO: show_batch doesn't display RGB mask correctly, setting alpha=0 to turn off for now
data.show_batch(4,figsize=(10,10), alpha=0.)

In [0]:
data

## Define custom losses and metrics to handle 3-channel targets

Here we implement some new loss functions like Dice Loss and Focal Loss, which have been shown to perform well in image segmentation tasks. Then we'll create a `MultiChComboLoss` class to combine multiple loss functions and calculate them across the 3 channels with adjustable weighting.

Combining a Dice or Jaccard loss, which considers image-wide context, with a pixel-focused Binary Cross Entropy or Focal loss, while adjustably weighting the 3 target mask channels, has been shown to consistently outperform single loss functions. This is well documented in Nick Weir's deep dive into the top results of the recent [SpaceNet 4 Off-Nadir Building Detection](https://spacenetchallenge.github.io/datasets/spacenet-OffNadir-summary.html) challenge: 

https://medium.com/the-downlinq/a-deep-dive-into-the-spacenet-4-winning-algorithms-8d611a5dfe25
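Written out, the losses implemented in the next cell are as follows, for one channel with logits $x_i$, binary targets $t_i$, predicted probabilities $p_i = \sigma(x_i)$, and $N$ pixels in the batch (the smoothing term of 1 and the defaults $\alpha = 1$, $\gamma = 2$ are taken from the code):

$$
L_{\text{dice}} = 1 - \frac{2\sum_i p_i t_i + 1}{\sum_i p_i + \sum_i t_i + 1}
$$

$$
\text{BCE}_i = -\left[t_i \log p_i + (1 - t_i)\log(1 - p_i)\right], \qquad p_{t,i} = e^{-\text{BCE}_i}
$$

$$
L_{\text{focal}} = \frac{1}{N}\sum_i \alpha\,(1 - p_{t,i})^{\gamma}\,\text{BCE}_i
$$

$$
L = \frac{\sum_{c=1}^{3} w_c \sum_k \lambda_k L_k^{(c)}}{\sum_{c=1}^{3} w_c}
$$

Here $w_c$ are `ch_wts`, $\lambda_k$ are `loss_wts`, and $L_k^{(c)}$ is the $k$-th loss function (Focal and Dice by default) evaluated on channel $c$. Since $p_{t,i}$ is the probability assigned to the correct label, setting $\gamma = 0$ and $\alpha = 1$ reduces the focal term to plain BCE, while larger $\gamma$ down-weights pixels the model already classifies confidently.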

Finally, we adapt our model evaluation metrics (accuracy and dice score) to calculate a score either pooled over all channels or on a single specified channel.
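One subtlety of the Dice metric below: when `one_ch` is `None`, `dice_multich` flattens all channels of each image together before summing, so the all-channel score is pooled over pixels rather than an average of three per-channel Dice scores. A pure-Python sketch of the same logic on tiny made-up inputs (the helper names are our own) shows the difference:

```python
import math

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def dice_score(logits, targets, thresh=0.5):
    """Hard Dice over flattened pixels: threshold sigmoid(logits), then 2|A&B| / (|A|+|B|)."""
    preds = [1.0 if sigmoid(v) > thresh else 0.0 for v in logits]
    intersect = sum(p * t for p, t in zip(preds, targets))
    union = sum(preds) + sum(targets)
    return 2.0 * intersect / union if union > 0 else 1.0  # empty vs empty counts as perfect

def dice_by_channel(logits_by_ch, targets_by_ch, one_ch=None):
    """Dice on one channel if one_ch is given, else pooled over all channels' pixels."""
    if one_ch is not None:
        return dice_score(logits_by_ch[one_ch], targets_by_ch[one_ch])
    flat_logits = [v for ch in logits_by_ch for v in ch]
    flat_targets = [v for ch in targets_by_ch for v in ch]
    return dice_score(flat_logits, flat_targets)

logits = [[5, 5, -5, -5], [5, -5, -5, -5]]   # 2 channels x 4 pixels
targets = [[1, 1, 0, 0], [1, 1, 0, 0]]
print(dice_by_channel(logits, targets, one_ch=0))  # 1.0
print(dice_by_channel(logits, targets, one_ch=1))  # 0.666...
print(dice_by_channel(logits, targets))            # 0.857...
```

Averaging the two per-channel scores would give about 0.83, while the pooled score is 6/7. For accuracy the distinction disappears, because every channel has the same number of pixels.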

In [0]:
import pdb

def dice_loss(input, target):
#     pdb.set_trace()
    smooth = 1.
    input = torch.sigmoid(input)
    iflat = input.contiguous().view(-1).float()
    tflat = target.contiguous().view(-1).float()
    intersection = (iflat * tflat).sum()
    return 1 - ((2. * intersection + smooth) / ((iflat + tflat).sum() +smooth))

# adapted from https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/65938
class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets.float(), reduction='none')
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

        if self.reduction == 'mean': return F_loss.mean()
        elif self.reduction == 'sum': return F_loss.sum()
        else: return F_loss

class DiceLoss(nn.Module):
    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction
        
    def forward(self, input, target):
        loss = dice_loss(input, target)
        if self.reduction == 'mean': return loss.mean()
        elif self.reduction == 'sum': return loss.sum()
        else: return loss

class MultiChComboLoss(nn.Module):
    "Weighted sum of `loss_funcs` computed separately per mask channel, then averaged with `ch_wts`."
    def __init__(self, reduction='mean', loss_funcs=None, loss_wts=[1,1], ch_wts=[1,1,1]):
        super().__init__()
        self.reduction = reduction
        self.ch_wts = ch_wts
        self.loss_wts = loss_wts
        # build fresh loss modules per instance instead of sharing one mutable default list
        self.loss_funcs = loss_funcs if loss_funcs is not None else [FocalLoss(), DiceLoss()]
        
    def forward(self, output, target):
        # propagate reduction on every forward pass: learn.get_preds(with_loss=True) switches it to 'none'
        for loss_func in self.loss_funcs: loss_func.reduction = self.reduction
        loss = 0
        channels = output.shape[1]
        assert len(self.ch_wts) == channels
        assert len(self.loss_wts) == len(self.loss_funcs)
        for c, ch_wt in enumerate(self.ch_wts):
            ch_loss = 0
            for loss_wt, loss_func in zip(self.loss_wts, self.loss_funcs):
                ch_loss += loss_wt * loss_func(output[:,c,None], target[:,c,None])
            loss += ch_wt * ch_loss
        return loss / sum(self.ch_wts)

In [0]:
# calculate metrics on one channel (i.e. ch 0 for building footprints only) or on all 3 channels

def acc_thresh_multich(input:Tensor, target:Tensor, thresh:float=0.5, sigmoid:bool=True, one_ch:int=None)->Rank0Tensor:
    "Compute accuracy when `y_pred` and `y_true` are the same size."
    
    if sigmoid: input = input.sigmoid()
    n = input.shape[0]
    
    if one_ch is not None:
        input = input[:,one_ch,None]
        target = target[:,one_ch,None]
    
    input = input.view(n,-1)
    target = target.view(n,-1)
    return ((input>thresh)==target.byte()).float().mean()

def dice_multich(input:Tensor, targs:Tensor, iou:bool=False, one_ch:int=None)->Rank0Tensor:
    "Dice coefficient metric for binary target. If iou=True, returns iou metric, classic for segmentation problems."
    n = targs.shape[0]
    input = input.sigmoid()
    
    if one_ch is not None:
        input = input[:,one_ch,None]
        targs = targs[:,one_ch,None]
    
    input = (input>0.5).view(n,-1).float()
    targs = targs.view(n,-1).float()

    intersect = (input * targs).sum().float()
    union = (input+targs).sum().float()
    if not iou: return (2. * intersect / union if union > 0 else union.new([1.]).squeeze())
    else: return intersect / (union-intersect+1.0)

## Set up model

We'll set up fastai's Dynamic Unet model with an ImageNet-pretrained resnet34 encoder. Together with fastai's training defaults, this U-net-inspired architecture uses many advanced deep learning techniques, such as:

- One cycle learning schedule: https://sgugger.github.io/the-1cycle-policy.html
- AdamW optimizer: https://www.fast.ai/2018/07/02/adam-weight-decay/
- Pixel shuffle upsampling with ICNR initialization from super resolution research: https://medium.com/@hirotoschwert/introduction-to-deep-super-resolution-c052d84ce8cf
- Optionally set leaky ReLU, blur, self attention: https://docs.fast.ai/vision.models.unet.html#DynamicUnet
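To make the first item concrete, here is a simplified pure-Python sketch of a one cycle learning-rate schedule. It assumes fastai v1's `fit_one_cycle` defaults as we understand them (`pct_start=0.3`, `div_factor=25`, a final LR of `max_lr / (div_factor * 1e4)`, cosine annealing in both phases) and ignores the inverse momentum schedule fastai also applies. `one_cycle_lr` is a hypothetical helper, not a fastai function.

```python
import math

def cos_anneal(start, end, pct):
    # cosine interpolation: returns start at pct=0 and end at pct=1
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)

def one_cycle_lr(step, total_steps, max_lr, pct_start=0.3, div_factor=25.0, final_div=None):
    if final_div is None:
        final_div = div_factor * 1e4
    warmup = int(total_steps * pct_start)
    if step < warmup:
        # phase 1: warm up from max_lr / div_factor to max_lr
        return cos_anneal(max_lr / div_factor, max_lr, step / warmup)
    # phase 2: anneal from max_lr down to a tiny final LR
    return cos_anneal(max_lr, max_lr / final_div, (step - warmup) / (total_steps - warmup))

# usage: LR per step for a 100-step cycle peaking at 3e-4
lrs = [one_cycle_lr(s, 100, 3e-4) for s in range(100)]
print(lrs[0], max(lrs), lrs[-1])
```

The warm-up lets the randomly initialized decoder settle before the largest updates, and the long anneal to a very small LR at the end is what lets the model converge.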

We'll define our `MultiChComboLoss` function as an equally weighted combination of Focal Loss and Dice Loss and set up our accuracy and dice metrics.

Also note the order of the metrics displayed during training: the two left-hand accuracy and dice columns are computed over the pixels of all 3 channels together, while the two right-most columns show accuracy and dice for channel 0 (the building footprint channel) only.

In [0]:
# set up metrics to show mean metrics for all channels as well as the building-only metrics (channel 0)

acc_ch0 = partial(acc_thresh_multich, one_ch=0)
dice_ch0 = partial(dice_multich, one_ch=0)
metrics = [acc_thresh_multich, dice_multich, acc_ch0, dice_ch0]

In [0]:
# combo Focal + Dice loss with equal channel wts

learn = unet_learner(data, models.resnet34, model_dir='../../models',
                     metrics=metrics, 
                     loss_func=MultiChComboLoss(
                        reduction='mean',
                        loss_funcs=[FocalLoss(gamma=1, alpha=0.95),
                                    DiceLoss(),
                                   ], 
                        loss_wts=[1,1],
                        ch_wts=[1,1,1])
                    )

In [0]:
learn.metrics

In [0]:
learn.loss_func

In [0]:
learn.summary()

## Train model, inspect results, unfreeze & train more, export for inference

First, we'll fine-tune our Unet on the decoder part only (leaving the weights for the ImageNet-pretrained resnet34 encoder frozen) for some epochs. Then we'll unfreeze all the trainable weights/layers of our model and train for some more epochs.
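The training cells below pass `max_lr=slice(3e-6, 3e-4)`. As we understand fastai v1, a slice with both ends set is spread geometrically across the model's layer groups (its `even_mults` helper), so the earliest encoder layers get the smallest learning rate and the decoder head the largest. A pure-Python sketch of that spacing, assuming 3 layer groups (check `len(learn.layer_groups)` for the real count); `lr_per_group` is a hypothetical name.

```python
def lr_per_group(start, stop, n_groups):
    # single-group guard is our own choice for this sketch
    if n_groups == 1:
        return [stop]
    # geometric spacing: each group's LR is a constant multiple of the previous one
    step = (stop / start) ** (1 / (n_groups - 1))
    return [start * step ** i for i in range(n_groups)]

print(lr_per_group(3e-6, 3e-4, 3))  # roughly [3e-06, 3e-05, 3e-04]
```

While the encoder is frozen, only the decoder group actually trains; after `learn.unfreeze()`, every group trains at its own rate, which protects the pretrained ImageNet features from large updates.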

We'll track the `valid_loss`, `acc_...`, and `dice_...` metrics per epoch as training progresses to make sure they keep improving and we're not overfitting. We also set a `SaveModelCallback`, which tracks the channel-0 dice score, saves a model checkpoint each time it improves, and reloads the best-performing checkpoint at the end of training.

We'll also inspect our model's results by setting `learn.model.eval()`, generating a few batches of predictions on the validation set, calculating and reshaping the image-wise loss values, and sorting by highest loss first to see the worst-performing results (as measured by the loss, which may differ in surprising ways from visually judging the results).

**Pro-tip:** display and view your results every chance you get! You'll pick up on all kinds of interesting clues about your model's behavior and how to make it better.
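With `with_loss=True`, `get_preds` returns unreduced loss values rather than one number per image, so the cells below average them per image before ranking. A NumPy sketch of that reshape-and-sort step; `rank_by_loss` is a hypothetical helper and the `losses` array is stand-in data.

```python
import numpy as np

def rank_by_loss(losses, n_images):
    # flatten each image's loss values and average them into one score per image
    per_image = losses.reshape(n_images, -1).mean(axis=1)
    # indices of the worst (highest-loss) images first
    order = np.argsort(-per_image, kind='stable')
    return per_image, order

# usage: three fake 4x4 per-pixel loss maps
losses = np.stack([np.full((4, 4), 0.1), np.full((4, 4), 0.9), np.full((4, 4), 0.5)])
per_image, order = rank_by_loss(losses, 3)
print(order)  # [1 2 0]
```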

Finally, we'll export our trained Unet segmentation model for inference purposes as a `.pkl` file. Learn more about exporting fastai models for inference in this tutorial: https://docs.fast.ai/tutorial.inference.html

In [0]:
# learn.lr_find()

In [0]:
# learn.recorder.plot(0,2,suggestion=True)

In [0]:
lr = slice(3e-6, 3e-4)
learn.fit_one_cycle(10, max_lr=lr, 
                    callbacks=[
                        SaveModelCallback(learn,
                                         monitor='dice_multich',
                                         mode='max',
                                         name='znz001trn-focaldice-stage1-best')
                    ]
                   )

In [0]:
learn.model.eval()
outputs,labels,losses = learn.get_preds(ds_type=DatasetType.Valid,n_batch=3,with_loss=True)
losses.shape

In [0]:
losses_reshaped = torch.mean(losses.view(outputs.shape[0],-1), dim=1)
sorted_idx = torch.argsort(losses_reshaped,descending=True)
losses_reshaped.shape

In [0]:
# look at predictions vs actual by channel sorted by highest image-wise loss first

for i in sorted_idx[:10]:

    print(f'{data.valid_ds.items[i].name}')
    print(f'loss: {losses_reshaped[i].mean()}')
    fig, (ax1, ax2) = plt.subplots(1,2, figsize=(10,5))
    
    data.valid_ds.x[i].show(ax=ax1)
    ax1.set_title('Prediction')
    ax1.imshow(image2np(outputs[i].sigmoid()), alpha=0.4)
    
    ax2.set_title('Ground Truth')
    data.valid_ds.x[i].show(ax=ax2)
    ax2.imshow(image2np(labels[i])*255, alpha=0.4)
    plt.show()
    
    print('Predicted:')
    show_3ch(outputs[i].sigmoid())
    print('Actual:')
    show_3ch(labels[i])

In [0]:
learn.load('znz001trn-focaldice-stage1-best')
learn.model.train()
learn.unfreeze()

In [0]:
learn.lr_find()

In [0]:
learn.recorder.plot(suggestion=True)

In [0]:
learn.fit_one_cycle(20, max_lr=slice(3e-6,3e-4), 
                    callbacks=[
                        SaveModelCallback(learn,
                                           monitor='dice_multich',
                                           mode='max',
                                           name='znz001trn-focaldice-unfrozen-best')
                    ]
                   )

In [0]:
learn.model.eval()
outputs,labels,losses = learn.get_preds(ds_type=DatasetType.Valid,n_batch=6,with_loss=True)
losses_reshaped = torch.mean(losses.view(outputs.shape[0],-1), dim=1)
sorted_idx = torch.argsort(losses_reshaped,descending=True)

In [0]:
# look at predictions vs actual by channel sorted by highest image-wise loss first

for i in sorted_idx[:10]:

    print(f'{data.valid_ds.items[i].name}')
    print(f'loss: {losses_reshaped[i].mean()}')
    fig, (ax1, ax2) = plt.subplots(1,2, figsize=(10,5))
    
    data.valid_ds.x[i].show(ax=ax1)
    ax1.set_title('Prediction')
    ax1.imshow(image2np(outputs[i].sigmoid()), alpha=0.4)
    
    ax2.set_title('Ground Truth')
    data.valid_ds.x[i].show(ax=ax2)
    ax2.imshow(image2np(labels[i])*255, alpha=0.4)
    plt.show()
    
    print('Predicted:')
    show_3ch(outputs[i].sigmoid())
    print('Actual:')
    show_3ch(labels[i])

In [0]:
# export pickles the learner, including custom classes like MultiChComboLoss, so they must be redefined before loading it elsewhere
learn.export('../../models/znz001trn-focaldice.pkl')

## Save files to GDrive (or download to computer)

Colab does not persistently store files created in its runtimes for more than 8-12 hours (or less, depending on inactivity or overall demand on the system), so we'll transfer or download the files we want to keep. We can:

1. **Mount our own account's Google Drive storage and transfer files** there via a `!cp` command: see below cell
2. **Download  files to local computer:** go to Files tab on left > find and right-click selected file > click Download > file will be prepared by Colab and automatically downloaded when ready
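If you'd rather copy files from Python than with `!cp`, here is a minimal standard-library sketch. `backup_files` is a hypothetical helper, and its default destination assumes Google Drive is mounted at `/content/drive` as in the cell below. If Drive isn't mounted, it will simply create a local folder with that name, so mount first.

```python
import shutil
from pathlib import Path

def backup_files(paths, dest_dir='/content/drive/My Drive'):
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    copied = []
    for p in map(Path, paths):
        if not p.exists():
            print(f'skipping missing file: {p}')
            continue
        target = dest / p.name
        shutil.copy2(p, target)  # copy2 also preserves file timestamps
        copied.append(target)
    return copied

# usage (after mounting Drive):
# backup_files(['models/znz001trn-focaldice.pkl'])
```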

In [0]:
# to mount and transfer files to GDrive: uncomment and run this and the next cell, then follow the instructions to authorize access to your GDrive

# from google.colab import drive
# drive.mount('/content/drive')

In [0]:
# copy model export .pkl file to root of your GDrive
# !cp models/znz001trn-focaldice.pkl /content/drive/My\ Drive/ 

# Inference on new imagery

With our segmentation model trained and exported for inference use, we will now re-load it as an inference-only model to test on new unseen imagery. We'll test the generalizability of our trained segmentation model on tiles from drone imagery captured over another part of Zanzibar and in other parts of the world as well as at varying `zoom_levels` (locations and zoom levels indicated):

![](https://cdn-images-1.medium.com/max/1200/1*DaS2dVfeaxZCg6cqOcHDrg.jpeg)

We'll also compare our model inference time per tile on GPU versus CPU.
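Zoom level determines how much ground each tile pixel covers, which is why testing across `zoom_levels` probes generalization. A sketch of the standard Web Mercator ground-resolution formula (about 156,543 m per pixel at zoom 0 on the equator, halving with each zoom level and shrinking with the cosine of latitude). `ground_resolution` is a hypothetical helper and assumes 256-pixel tiles, matching the tile size used later in this notebook.

```python
import math

# equatorial circumference of the WGS84 sphere used by Web Mercator, in metres
EARTH_CIRCUMFERENCE_M = 2 * math.pi * 6378137

def ground_resolution(lat_deg, zoom, tile_size=256):
    # metres on the ground covered by one tile pixel at this latitude and zoom level
    return EARTH_CIRCUMFERENCE_M * math.cos(math.radians(lat_deg)) / (tile_size * 2 ** zoom)

# usage: around Zanzibar (latitude ~ -5.87), z19 pixels cover roughly 0.3 m
for z in (18, 19, 20, 21):
    print(z, round(ground_resolution(-5.87, z), 3))
```

So a model trained mostly on z19 tiles sees buildings at about half their familiar pixel size on z18 tiles and twice that size on z20 tiles.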

## Load exported model for inference

In [0]:
!curl https://course.fast.ai/setup/colab | bash

In [0]:
from fastai.vision import *
from fastai.callbacks import *

In [0]:
from fastai.utils.collect_env import *
show_install(True)

In [0]:
# load_learner unpickles objects that reference these custom classes, so they must be defined first
# TODO: look into a better way of loading export.pkl without redefining these custom classes

class SegLabelListCustom(SegmentationLabelList):
    def open(self, fn): return open_mask(fn, div=True, convert_mode='RGB')
    
class SegItemListCustom(SegmentationItemList):
    _label_cls = SegLabelListCustom

def dice_loss(input, target):
    "Soft Dice loss on sigmoid probabilities; `smooth` keeps the ratio defined when both masks are empty."
    smooth = 1.
    input = torch.sigmoid(input)
    iflat = input.contiguous().view(-1).float()
    tflat = target.contiguous().view(-1).float()
    intersection = (iflat * tflat).sum()
    return 1 - ((2. * intersection + smooth) / ((iflat + tflat).sum() +smooth))

# adapted from https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/65938
class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets.float(), reduction='none')
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

        if self.reduction == 'mean': return F_loss.mean()
        elif self.reduction == 'sum': return F_loss.sum()
        else: return F_loss

class DiceLoss(nn.Module):
    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction
        
    def forward(self, input, target):
        loss = dice_loss(input, target)
        if self.reduction == 'mean': return loss.mean()
        elif self.reduction == 'sum': return loss.sum()
        else: return loss

class MultiChComboLoss(nn.Module):
    "Weighted sum of `loss_funcs` computed separately per mask channel, then averaged with `ch_wts`."
    def __init__(self, reduction='mean', loss_funcs=None, loss_wts=[1,1], ch_wts=[1,1,1]):
        super().__init__()
        self.reduction = reduction
        self.ch_wts = ch_wts
        self.loss_wts = loss_wts
        # build fresh loss modules per instance instead of sharing one mutable default list
        self.loss_funcs = loss_funcs if loss_funcs is not None else [FocalLoss(), DiceLoss()]
        
    def forward(self, output, target):
        # propagate reduction on every forward pass: learn.get_preds(with_loss=True) switches it to 'none'
        for loss_func in self.loss_funcs: loss_func.reduction = self.reduction
        loss = 0
        channels = output.shape[1]
        assert len(self.ch_wts) == channels
        assert len(self.loss_wts) == len(self.loss_funcs)
        for c, ch_wt in enumerate(self.ch_wts):
            ch_loss = 0
            for loss_wt, loss_func in zip(self.loss_wts, self.loss_funcs):
                ch_loss += loss_wt * loss_func(output[:,c,None], target[:,c,None])
            loss += ch_wt * ch_loss
        return loss / sum(self.ch_wts)

def acc_thresh_multich(input:Tensor, target:Tensor, thresh:float=0.5, sigmoid:bool=True, one_ch:int=None)->Rank0Tensor:
    "Compute accuracy when `y_pred` and `y_true` are the same size."
    
    if sigmoid: input = input.sigmoid()
    n = input.shape[0]
    
    if one_ch is not None:
        input = input[:,one_ch,None]
        target = target[:,one_ch,None]
    
    input = input.view(n,-1)
    target = target.view(n,-1)
    return ((input>thresh)==target.byte()).float().mean()

def dice_multich(input:Tensor, targs:Tensor, iou:bool=False, one_ch:int=None)->Rank0Tensor:
    "Dice coefficient metric for binary target. If iou=True, returns iou metric, classic for segmentation problems."
    n = targs.shape[0]
    input = input.sigmoid()
    
    if one_ch is not None:
        input = input[:,one_ch,None]
        targs = targs[:,one_ch,None]
    
    input = (input>0.5).view(n,-1).float()
    targs = targs.view(n,-1).float()

    intersect = (input * targs).sum().float()
    union = (input+targs).sum().float()
    if not iou: return (2. * intersect / union if union > 0 else union.new([1.]).squeeze())
    else: return intersect / (union-intersect+1.0)

In [0]:
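# make sure the target folder exists (e.g. in a fresh runtime) before downloading into it
!mkdir -p models
!wget -O models/znz001trn-focaldice.pkl https://www.dropbox.com/s/by3nc1xri8y7t4p/znz001trn-focaldice.pkl?dl=1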

In [0]:
# if you have your own model .pkl file to load, either:

# upload from computer: Files tab > Upload on left
# or mount GDrive and transfer file to Colab storage: uncomment below, change filepaths to the .pkl file on your GDrive if needed, and run:

# !cp /content/drive/My\ Drive/znz001trn-focaldice.pkl models/

In [0]:
inference_learner = load_learner(path='models/', file='znz001trn-focaldice.pkl')

## Inference on new unseen tiles


In [0]:
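import skimage.io
import time

def get_pred(learner, tile):
    # keep RGB only (drop any alpha channel) and scale uint8 pixels to [0, 1] floats
    t_img = Image(pil2tensor(tile[:,:,:3],np.float32).div_(255))
    outputs = learner.predict(t_img)
    # outputs[2] is the raw model output; apply sigmoid for per-channel probabilities
    im = image2np(outputs[2].sigmoid())
    # rescale to 0-255 uint8 so the 3-channel prediction displays as an RGB image
    im = (im*255).astype('uint8')
    return im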

In [0]:
# try a different tile by changing or adding your own urls to list

urls = [
  'https://tiles.openaerialmap.org/5b1009f22b6a08001185f24a/0/5b1009f22b6a08001185f24b/19/319454/270706.png',
  'https://tiles.openaerialmap.org/5b1e6fd42b6a08001185f7bf/0/5b1e6fd42b6a08001185f7c0/20/569034/537093.png',
  'https://tiles.openaerialmap.org/5beaaba463f9420005ef8db0/0/5beaaba463f9420005ef8db1/19/313479/283111.png',
  'https://tiles.openaerialmap.org/5d050c3673de290005853a91/0/5d050c3673de290005853a92/18/203079/117283.png',
  'https://tiles.openaerialmap.org/5c88ff77225fc20007ab4e26/0/5c88ff77225fc20007ab4e27/21/1035771/1013136.png',
  'https://tiles.openaerialmap.org/5d30bac2e757aa0005951652/0/5d30bac2e757aa0005951653/19/136700/197574.png'
]

### On GPU

In [0]:
for url in urls:
  test_tile = skimage.io.imread(url)
  # start the timer after downloading the tile so we time only the prediction
  t1 = time.time()
  result = get_pred(inference_learner, test_tile)
  t2 = time.time()
  
  print(url)
  print(f'GPU inference took {t2-t1:.2f} secs')
  fig, (ax1, ax2) = plt.subplots(1,2, figsize=(10,5))
  ax1.imshow(test_tile)
  ax2.imshow(result)
  ax1.axis('off')
  ax2.axis('off')
  plt.show()

### On CPU

In [0]:
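# assumes a runtime without a GPU (e.g. Colab: Runtime > Change runtime type > Hardware accelerator: None)
# with the inference learner reloaded in that runtime
for url in urls:
  test_tile = skimage.io.imread(url)
  print(url)
  # start the timer after downloading the tile so we time only the prediction
  t1 = time.time()
  result = get_pred(inference_learner, test_tile)
  t2 = time.time()
  
  print(f'CPU inference took {t2-t1:.2f} secs')
  fig, (ax1, ax2) = plt.subplots(1,2, figsize=(10,5))
  ax1.imshow(test_tile)
  ax2.imshow(result)
  ax1.axis('off')
  ax2.axis('off')
  plt.show()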

# Post-processing

## Predict on a tile, threshold, polygonize, and georegister

To evaluate model performance fairly against ground truth, we'll use another set of labeled data that the model was not trained on, taken from the larger Zanzibar dataset. Preview the imagery and ground truth labels for `znz029` in the STAC browser [here](https://geoml-samples.netlify.com/item/9Eiufow7wPXLqQEP1Di2J5X8kXkBLgMsCBoN37VrtRPB/2sEaEKnnyjG2mx7CnN1ESAdjYAEQjoNRxT2vgQRC9oB?si=0&t=preview#14/-5.865178/39.348986):

![](https://cdn-images-1.medium.com/max/1200/0*fGsRIu-2ExIWXzc0)

For demonstration, we'll use this particular tile at `z=19, x=319454, y=270706` from `znz029`:

![znz029 tile at z=19, x=319454, y=270706](https://tiles.openaerialmap.org/5b1009f22b6a08001185f24a/0/5b1009f22b6a08001185f24b/19/319454/270706.png)

Using solaris and geopandas, we'll convert our model's prediction as a 3-channel pixel raster output into a GeoJSON file by:

1. thresholding and combining the 3 channels of pixel values in our raw prediction output into a single-channel binary pixel mask
2. polygonizing this binary pixel mask into shape vectors representing the predicted footprint of every building
3. georegistering the x, y display coordinates of these vectorized building shapes into longitude, latitude coordinates
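Step 1 is the least obvious. As we understand solaris's behavior, `channel_scaling=[1,0,-1]` takes a per-pixel weighted sum of the channels (footprint minus contact, ignoring the boundary channel), and `bg_threshold` keeps only pixels whose sum exceeds it, which helps separate touching buildings. A NumPy sketch of that thresholding, assuming an `(H, W, 3)` uint8 prediction like our `result`; `combine_and_threshold` is a hypothetical name, not a solaris function.

```python
import numpy as np

def combine_and_threshold(pred, channel_scaling=(1, 0, -1), bg_threshold=245):
    # weighted sum across the channel axis; cast first so uint8 arithmetic can't wrap around
    combined = (pred.astype(np.float32) * np.asarray(channel_scaling, dtype=np.float32)).sum(axis=-1)
    # 1 where the combined score exceeds the threshold, else 0
    return (combined > bg_threshold).astype(np.uint8)

# usage: three illustrative pixels
pred = np.zeros((2, 2, 3), dtype=np.uint8)
pred[0, 0] = [255, 0, 0]    # confident footprint: kept
pred[0, 1] = [255, 0, 200]  # footprint but also predicted as contact: 55, dropped
pred[1, 0] = [240, 0, 0]    # footprint below threshold: dropped
print(combine_and_threshold(pred))  # [[1 0]
                                    #  [0 0]]
```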

In [0]:
# if not already loaded in this runtime:
# install fastai and load the inference learner from the "Inference on new imagery" section,
# then uncomment below to re-install the geo packages

# !add-apt-repository ppa:ubuntugis/ppa
# !apt-get update
# !apt-get install python-numpy gdal-bin libgdal-dev
# !apt install python3-rtree

# !pip install rasterio
# !pip install geopandas
# !pip install descartes
# !pip install solaris
# !pip install rio-tiler

In [0]:
import solaris as sol 
from affine import Affine
from rasterio.transform import from_bounds
from shapely.geometry import Polygon
import math
import geopandas as gpd
import skimage.io
  
def deg2num(lat_deg, lon_deg, zoom):
    lat_rad = math.radians(lat_deg)
    n = 2.0 ** zoom
    xtile = int((lon_deg + 180.0) / 360.0 * n)
    ytile = int((1.0 - math.log(math.tan(lat_rad) + (1 / math.cos(lat_rad))) / math.pi) / 2.0 * n)
    return (xtile, ytile)

def num2deg(xtile, ytile, zoom):
    n = 2.0 ** zoom
    lon_deg = xtile / n * 360.0 - 180.0
    lat_rad = math.atan(math.sinh(math.pi * (1 - 2 * ytile / n)))
    lat_deg = math.degrees(lat_rad)
    return (lat_deg, lon_deg)
  
def tile_to_poly(z,x,y, size):
    top, left = num2deg(x, y, z)
    bottom, right = num2deg(x+1, y+1, z)
    tfm = from_bounds(left, bottom, right, top, size, size)

    # shapely's from_bounds expects (minx, miny, maxx, maxy)
    return Polygon.from_bounds(left, bottom, right, top), tfm

In [0]:
z,x,y = 19,319454,270706
url= 'https://tiles.openaerialmap.org/5b1009f22b6a08001185f24a/0/5b1009f22b6a08001185f24b/19/319454/270706.png'

test_tile = skimage.io.imread(url)
result = get_pred(inference_learner, test_tile)
  
fig, (ax1, ax2) = plt.subplots(1,2, figsize=(10,5))
ax1.imshow(test_tile)
ax2.imshow(result)
plt.show()

In [0]:
# threshold and polygonize with solaris: https://solaris.readthedocs.io/en/latest/tutorials/notebooks/api_mask_to_vector.html

mask2poly = sol.vector.mask.mask_to_poly_geojson(result, 
                                                 channel_scaling=[1,0,-1], 
                                                 bg_threshold=245, 
                                                 simplify=True,
                                                 tolerance=2
                                                 )

In [0]:
mask2poly.plot(figsize=(10,10))

In [0]:
mask2poly.head()

In [0]:
# get the bounds of the tile and its affine tfm matrix for georegistering purposes
tile_poly, tile_tfm = tile_to_poly(z,x,y,256)
tile_tfm
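`tile_tfm` maps pixel (column, row) positions to (lon, lat). A pure-Python sketch of the same linear mapping built from the tile's corner coordinates, repeating the slippy-map math of `num2deg` so it runs on its own; `pixel_to_lonlat` is a hypothetical helper. Like `from_bounds`, it interpolates linearly in latitude, which is a close approximation within one high-zoom tile rather than an exact Web Mercator inverse.

```python
import math

def num2deg(xtile, ytile, zoom):
    # north-west corner of slippy-map tile (xtile, ytile) as (lat, lon) in degrees
    n = 2.0 ** zoom
    lon_deg = xtile / n * 360.0 - 180.0
    lat_deg = math.degrees(math.atan(math.sinh(math.pi * (1 - 2 * ytile / n))))
    return lat_deg, lon_deg

def pixel_to_lonlat(col, row, z, x, y, size=256):
    top, left = num2deg(x, y, z)
    bottom, right = num2deg(x + 1, y + 1, z)
    # same linear mapping as rasterio's from_bounds(left, bottom, right, top, size, size)
    lon = left + col * (right - left) / size
    lat = top - row * (top - bottom) / size
    return lon, lat

# usage: the top-left and bottom-right pixel corners of our demo tile
print(pixel_to_lonlat(0, 0, 19, 319454, 270706))
print(pixel_to_lonlat(256, 256, 19, 319454, 270706))
```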

In [0]:
# convert polys from pixel coords to geo coords: https://solaris.readthedocs.io/en/latest/api/vector.html?highlight=georegister_px_df#solaris.vector.polygon.georegister_px_df
result_polys = sol.vector.polygon.georegister_px_df(mask2poly, 
                                                   affine_obj=tile_tfm, 
                                                   crs='epsg:4326')

In [0]:
# show tile image to raw prediction to georegistered polygons
fig, (ax1, ax2, ax3) = plt.subplots(1,3, figsize=(15,5))

ax1.imshow(test_tile)
ax2.imshow(result)
result_polys.plot(ax=ax3)

In [0]:
result_polys.to_file('result_polys.geojson', driver='GeoJSON')

### Check that the saved result_polys.geojson is correctly georegistered on geojson.io

http://geojson.io/#id=gist:daveluo/3dfe4695e31b2b3a4c7c6e13ada5d1e6&map=19/-5.86910/39.35198

TMS layer link: https://tiles.openaerialmap.org/5ae242fd0b093000130afd38/0/5ae242fd0b093000130afd39/{z}/{x}/{y}.png

![predicted polygons checked on geojson.io](https://www.dropbox.com/s/kxideja8cx1ao14/check_predicted_polys.png?dl=1)




## Evaluate prediction against ground truth

Finally, with our georegistered building predictions saved as a GeoJSON file, we can evaluate them against the ground truth GeoJSON file for the same tile.

We'll clip the ground truth labels to the bounds of this particular tile and use solaris's Evaluator to calculate the precision, recall, and F1 score. We will also visualize our predicted buildings (in red) against the ground truth buildings (in blue) in this particular tile.

For more information about these common evaluation metrics for models applied to overhead imagery, see the following articles and more by the SpaceNet team:

https://medium.com/the-downlinq/the-spacenet-metric-612183cc2ddb

https://medium.com/the-downlinq/the-good-and-the-bad-in-the-spacenet-off-nadir-building-footprint-extraction-challenge-4c3a96ee9c72
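Under the SpaceNet metric, a predicted building counts as a true positive when it overlaps a not-yet-matched ground truth building with IoU ≥ 0.5; unmatched predictions are false positives and unmatched ground truth buildings are false negatives. A simplified pure-Python sketch of that matching, using axis-aligned boxes `(minx, miny, maxx, maxy)` as stand-ins for polygons (solaris's `Evaluator` works on the real polygon geometries). `box_iou` and `f1_at_iou` are hypothetical helpers, and the greedy best-IoU matching is our simplification.

```python
def box_iou(a, b):
    # intersection-over-union of two axis-aligned boxes
    ix = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    iy = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = ix * iy

    def area(r):
        return (r[2] - r[0]) * (r[3] - r[1])

    union = area(a) + area(b) - inter
    return inter / union if union > 0 else 0.0

def f1_at_iou(preds, truths, min_iou=0.5):
    unmatched = list(truths)
    tp = 0
    for p in preds:
        best = max(unmatched, key=lambda t: box_iou(p, t), default=None)
        if best is not None and box_iou(p, best) >= min_iou:
            tp += 1
            unmatched.remove(best)  # each ground truth building can be matched only once
    fp, fn = len(preds) - tp, len(truths) - tp
    precision = tp / (tp + fp) if preds else 0.0
    recall = tp / (tp + fn) if truths else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# usage: one good match (IoU 0.81), one spurious prediction, one missed building
truths = [(0, 0, 10, 10), (20, 20, 30, 30)]
preds = [(1, 1, 10, 10), (50, 50, 60, 60)]
print(f1_at_iou(preds, truths))  # (0.5, 0.5, 0.5)
```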

In [0]:
# get the ground truth labels for all znz029
labels_url = 'https://www.dropbox.com/sh/ct3s1x2a846x3yl/AADHytc8fSCf3gna0wNAW3lZa/grid_029.geojson?dl=1'

gt_gdf = gpd.read_file(labels_url)

In [0]:
print(tile_poly.bounds)

In [0]:
# visualize the tile (in red) against the entire labeled znz029 area (in blue)
fig, ax = plt.subplots(figsize=(10,10))
gt_gdf.plot(ax=ax)
gpd.GeoDataFrame(geometry=[tile_poly], crs='epsg:4326').plot(alpha=0.5, color='red', ax=ax)

In [0]:
# clip gt_gdf to the tile bounds
clipped_gt_polys = gpd.overlay(gt_gdf, gpd.GeoDataFrame(geometry=[tile_poly], crs='epsg:4326'), how='intersection')

In [0]:
clipped_gt_polys.plot()

In [0]:
result_polys.plot()

In [0]:
clipped_gt_polys.to_file('clipped_gt_polys.geojson', driver='GeoJSON')

In [0]:
# solaris tutorial on evaluation: https://solaris.readthedocs.io/en/latest/tutorials/notebooks/api_evaluation_tutorial.html 
evaluator = sol.eval.base.Evaluator('clipped_gt_polys.geojson')
evaluator.load_proposal('result_polys.geojson', proposalCSV=False, conf_field_list=[])
evaluator.eval_iou(calculate_class_scores=False)

In [0]:
# visualize predicted vs ground truth
fig, (ax1, ax2) = plt.subplots(1,2, figsize=(10,5))

ax1.imshow(test_tile)
clipped_gt_polys.plot(ax=ax2, color='blue', alpha=0.5) #gt
result_polys.plot(ax=ax2, color='red', alpha=0.5) #pred

# Ideas to Try for Performance Gains

Congratulations, you did it! 

You've completed the tutorial and now know how to do everything from producing training data to creating a deep learning model for segmentation to postprocessing and evaluating your model's performance.

To flex your newfound knowledge and make your model perform potentially **much better**, try implementing some or all of these ideas:

- Create and use more training data: there are 13 grids' worth of training data for Zanzibar released as part of the [Open AI Tanzania Building Footprint Segmentation Challenge dataset](https://docs.google.com/spreadsheets/d/1kHZo2KA0-VtCCcC5tL4N0SpyoxnvH7mLbybZIHZGTfE/edit#gid=0).

- Change the zoom_level of your training/validation tiles. Better yet, try using tiles across multiple zooms (e.g. z21, z20, z19, z18). Note that with multiple zoom levels over the same imagery, you should be extra careful about tiles that overlap across those different zoom levels. Test your understanding of slippy map tiles by working out what I mean here, but feel free to message me for the answer!

- Change the Unet's encoder to a bigger or different architecture (e.g. resnet50, resnet101, densenet).

- Change the combinations, weighting, and hyperparameters of the loss functions. Or implement completely new loss functions like [Lovasz Loss](https://github.com/bermanmaxim/LovaszSoftmax).

- Try different data augmentation combinations and techniques.

- Train for more epochs and with different learning rate schedules. Try [mixed-precision training](https://docs.fast.ai/callbacks.fp16.html) to train faster.

- Your idea here.

I look forward to seeing what you discover!

# Coming Up

If you liked this tutorial, look forward to upcoming ones, which may cover topics like:
- classifying building completeness (foundation, incomplete, complete)
- inference on multiple tiles and much larger images
- working with messy, sparse, imperfect training data
- model deployment and inference at scale
- examining data/model biases, considerations of fairness, accountability, transparency, and ethics

Curious about more geospatial deep learning topics? Did I miss something? Share your questions and thoughts in the [Medium post](https://medium.com/@anthropoco/how-to-segment-buildings-on-drone-imagery-with-fast-ai-cloud-native-geodata-tools-ae249612c321?source=friends_link&sk=57b82002ac47724ecf9a2aaa98de994b) so I can add them into this and next tutorials. 

Good luck and happy deep learning!


# Acknowledgments and Special Thanks to

- [World Bank GFDRR](https://www.gfdrr.org/en)'s Open Data for Resilience Initiative ([OpenDRI](https://opendri.org/)) for consultation projects which have inspired & informed.
- [Zanzibar Mapping Initiative](http://www.zanzibarmapping.com/), [OpenAerialMap](https://openaerialmap.org/), State University of Zanzibar ([SUZA](https://www.suza.ac.tz/)), Govt of Zanzibar's Commission for Lands, & [WeRobotics](https://werobotics.org/) for the [2018 Open AI Tanzania Building Footprint Segmentation Challenge](https://competitions.codalab.org/competitions/20100).
- [Fast.ai team](https://www.fast.ai/about/), [contributors](https://github.com/fastai/fastai/graphs/contributors), & [community](https://forums.fast.ai/) for both "making neural nets uncool again" and pushing its cutting edge (very cool).
- [SpaceNet](https://spacenet.ai/) & [Cosmiq Works](http://www.cosmiqworks.org/) for the open challenges, datasets, knowledge-sharing, [Solaris geoML toolkit](https://github.com/CosmiQ/solaris), & more that advance geospatial machine learning.
- Contributors to [COG](https://www.cogeo.org/), [STAC](https://stacspec.org/), and more initiatives advancing the [cloud native geospatial](https://medium.com/planet-stories/tagged/cloud-native-geospatial) ecosystem.
- [Free & open source](https://en.wikipedia.org/wiki/Free_and_open-source_software) creators & collaborators everywhere for the invaluable public goods you provide.

# Notebook Changelog

## v1 (2019-07-25):
-------
New: 1st public release

Changed: n/a

Fixed: n/a