# SpaceNet Roads Challenge

This notebook walks through the training of a [tiramisu](https://arxiv.org/pdf/1611.09326.pdf) segmentation network to be a road detector on the SpaceNet AOI 2 (Las Vegas) data set.

I use a relatively small version of the tiramisu (not the full 103-layer version preferred by that paper), with a batch size of 1 on relatively small (256x256) subimages, so that the effort fits into the limited RAM of the GPU in my home desktop.

All the same, I think I achieve reasonable results: the network trains up to an accuracy of 92% or 93% on the validation set after 9 hours, and begins to over-fit the training data after that. Training accuracy is shown in the first plot and validation accuracy in the second:

![Training accuracy](plots/training_accuracy_curve.png)
![Validation accuracy](plots/validation_accuracy_curve.png)

(Those plots are both captured from [TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard), aimed at the logfile produced by the `TensorBoard` callback used in training.)

I'll move forward with the network weights as they were around the peak of that validation accuracy curve and before it began to over-fit: **epoch 85**.

Here's what that network's performance looks like on a few random tiles from the validation data:

![performance](plots/val/run-01/epoch-085.png)

---------------------------------------------------------------------
**Note:** this training session cuts a few corners in important ways:
- I train the network to identify a 20-pixel-wide stripe running through the middle of each road, not the full road width (since the input labels given are the centerlines, and not a full per-pixel segmentation). This is rough at best, and very noisy: many roads are wider than this, and dilating the centerlines the way I have muddies up many intersections.
- I randomly sample tiles from within random images throughout training and validation: I do not ensure that every bit of training data has been seen by the network once for each epoch. Given enough training time, this works out just fine, but it makes the description of an "epoch" a little bit messy.
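
The second caveat can be made concrete with a little arithmetic. This is a minimal sketch, assuming each step draws one image uniformly at random (the actual sampling inside `SpacenetGenerator` may differ, and `n_images = 1000` is a made-up count). After as many draws as there are images, only about 63% of the images are expected to have been seen at least once, which is why an "epoch" here is not a full pass over the data.

```python
import random


def expected_coverage(n_images, n_draws):
    """Expected fraction of distinct images seen after n_draws uniform random draws."""
    return 1.0 - (1.0 - 1.0 / n_images) ** n_draws


def simulated_coverage(n_images, n_draws, seed=0):
    """Empirical fraction of distinct images seen in one simulated 'epoch'."""
    rng = random.Random(seed)
    seen = {rng.randrange(n_images) for _ in range(n_draws)}
    return len(seen) / n_images


n_images = 1000  # hypothetical data set size, for illustration only
print(expected_coverage(n_images, n_images))
print(simulated_coverage(n_images, n_images))
```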

## imports

In [10]:
from keras.callbacks import TensorBoard, ModelCheckpoint, ReduceLROnPlateau, Callback
from keras.optimizers import RMSprop
from os import path, makedirs
from keras.models import load_model

In [2]:
from subprocess import run
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

In [4]:
from data_generator import SpacenetGenerator
import tiramisu

## data generators

In [5]:
tile_size = (256, 256)

traingen = SpacenetGenerator('../../data/train/', tile_size=tile_size)
valgen = SpacenetGenerator('../../data/val/', tile_size=tile_size)

## a visualization callback

(this can be used to generate plots showing network performance on the validation data during training)

In [6]:
class PlotCallback(Callback):
    
    def __init__(self, datagen, period, savedir, *args, **kwargs):
        self.datagen = datagen
        self.period = period
        self.savedir = savedir
        # a custom colormap that is fully transparent except for its top entry (class 1, at 75% opacity)
        cmap = plt.cm.copper
        my_cmap = cmap(np.arange(cmap.N))
        my_cmap[:, -1] = np.zeros_like(my_cmap[:, -1])
        my_cmap[-1:, -1] = 0.75
        my_cmap = ListedColormap(my_cmap)
        self.my_cmap = my_cmap
        
        super().__init__(*args, **kwargs)

    @staticmethod
    def _reshape_output(output_array, img_shape):
        _reshaped = np.argmax(output_array[0], axis=-1)
        _reshaped = _reshaped.reshape(img_shape[0], img_shape[1])
        return _reshaped
        
    def on_epoch_end(self, epoch, logs=None):
        """
        Plots an array of five tiles and the network output for them at this epoch
        """
        if not epoch % self.period:
            fig, axs = plt.subplots(5, 2, sharex=True, sharey=True, figsize=(8, 20))
            plt.subplots_adjust(hspace=0.05, wspace=0.05)
            for i in range(5):
                img, lab = next(self.datagen)
                pred = self.model.predict(img, batch_size=1)

                lab_reshaped = self._reshape_output(lab, img[0].shape)
                pred_reshaped = self._reshape_output(pred, img[0].shape)

                axs[i, 0].imshow(img[0], aspect='equal')
                axs[i, 0].imshow(lab_reshaped, vmin=0, vmax=1, cmap=self.my_cmap, aspect='equal')
                axs[i, 1].imshow(img[0], aspect='equal')
                axs[i, 1].imshow(pred_reshaped, vmin=0, vmax=1, cmap=self.my_cmap,aspect='equal')
                if i == 0:
                    axs[i, 0].set_title('ground truth')
                    axs[i, 1].set_title('predictions (epoch {})'.format(epoch))

            file_name = path.join(self.savedir, 'epoch-{:03d}.png'.format(epoch))
            plt.savefig(file_name)
            plt.show()
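
To make the callback's schedule concrete: Keras passes a 0-based `epoch` index to `on_epoch_end`, so `not epoch % self.period` fires on epochs 0, 5, 10, and so on. Here is a small sketch of that logic outside Keras. The helper names `plotted_epochs` and `plot_filename` are hypothetical, and `'plots/val/run-01'` just mirrors the directory used in this notebook.

```python
from os import path


def plotted_epochs(n_epochs, period):
    """0-based epoch indices at which `not epoch % period` is true."""
    return [epoch for epoch in range(n_epochs) if not epoch % period]


def plot_filename(savedir, epoch):
    """Same file name template as PlotCallback.on_epoch_end."""
    return path.join(savedir, 'epoch-{:03d}.png'.format(epoch))


for epoch in plotted_epochs(n_epochs=16, period=5):
    print(plot_filename('plots/val/run-01', epoch))
```

Because the index is 0-based, `epoch-085.png` shows the network after its 86th round of `steps_per_epoch` batches.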

## put together the callbacks


In [7]:
this_run = 'run-01'

logdir = 'logs/'
snapdir = 'snapshots/'
valplotdir = 'plots/val'
trainplotdir = 'plots/train'
for directory in [logdir, snapdir, valplotdir, trainplotdir]:
    rundir = path.join(directory, this_run)
    makedirs(rundir, exist_ok=True)

callbacks = [
    ReduceLROnPlateau('loss', factor=0.2, verbose=1, patience=10, cooldown=5),
    TensorBoard(log_dir=path.join(logdir, this_run)),
    ModelCheckpoint(path.join(snapdir, this_run, 'modelweights.{epoch:02d}-{val_loss:.2f}.hdf5'),
                    save_weights_only=False, verbose=1, period=5),
    PlotCallback(valgen.random_generator(1), 5, path.join(valplotdir, this_run)),
]
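
As a rough guide to what `ReduceLROnPlateau` does here: each time the monitored loss fails to improve for `patience` epochs, the learning rate is multiplied by `factor`. Below is a sketch of the resulting sequence of rates, assuming the starting rate of 2e-4 passed to `RMSprop` in the model cell and the callback's default `min_lr` of 0. When the reductions happen depends on the loss curve, so only the values are shown; `lr_after_reductions` is a hypothetical helper.

```python
def lr_after_reductions(initial_lr, factor, n_reductions, min_lr=0.0):
    """Learning rate after n_reductions plateau-triggered reductions."""
    lr = initial_lr
    for _ in range(n_reductions):
        # each reduction scales the rate by `factor`, but never below `min_lr`
        lr = max(lr * factor, min_lr)
    return lr


for n in range(4):
    print(n, lr_after_reductions(2e-4, 0.2, n))
```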


## build the model

In [8]:
tiramisu.POOLING = (4, 4)
model = tiramisu.tiramisu(n_classes=2, input_shape=(tile_size[0], tile_size[1], 3),
                          blocks=[2, 3, 5, 7], bottleneck=9)

model.compile(loss='kld', optimizer=RMSprop(2e-4), metrics=["accuracy"])
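
The `'kld'` loss is the Kullback-Leibler divergence. For one-hot targets the entropy of the target is zero, so the KL divergence reduces to the categorical cross-entropy, up to the small epsilon clipping that Keras applies. Here is a plain-Python sketch for a single pixel. It assumes one-hot labels (the `argmax` in `PlotCallback` suggests this, but the generator code is not shown here) and uses a made-up prediction of `[0.2, 0.8]`.

```python
import math

EPSILON = 1e-7  # Keras' default fuzz factor (K.epsilon())


def kld(y_true, y_pred):
    """KL divergence for one pixel, with Keras-style clipping to [EPSILON, 1]."""
    total = 0.0
    for t, p in zip(y_true, y_pred):
        t = min(max(t, EPSILON), 1.0)
        p = min(max(p, EPSILON), 1.0)
        total += t * math.log(t / p)
    return total


def cross_entropy(y_true, y_pred):
    """Categorical cross-entropy for one pixel."""
    return -sum(t * math.log(max(p, EPSILON)) for t, p in zip(y_true, y_pred))


y_true = [0.0, 1.0]  # one-hot label: this pixel is "road"
y_pred = [0.2, 0.8]  # hypothetical softmax output
print(kld(y_true, y_pred), cross_entropy(y_true, y_pred))
```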

## train

In [None]:
batch_size = 1

model.fit_generator(traingen.random_generator(batch_size),
                    steps_per_epoch=len(traingen.images) // batch_size,
                    epochs=500, verbose=1, callbacks=callbacks,
                    validation_data=valgen.random_generator(batch_size),
                    validation_steps=len(valgen.images) // batch_size,
                    use_multiprocessing=True, workers=4,
                   )
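
Once training has run for a while, the snapshot directory holds one file per saved epoch, named from the `ModelCheckpoint` template above. Here is a sketch of picking out a snapshot by parsing those names. The file names in the example list are invented, including their loss values, and `parse_checkpoint` and `best_checkpoint` are hypothetical helpers, not part of Keras. The notebook's own choice of epoch 85 was made from the validation accuracy curve; the file names only record `val_loss`, so the sketch uses that instead.

```python
import re

# matches names produced by 'modelweights.{epoch:02d}-{val_loss:.2f}.hdf5'
CHECKPOINT_RE = re.compile(r'modelweights\.(\d+)-(\d+\.\d+)\.hdf5$')


def parse_checkpoint(filename):
    """Return (epoch, val_loss) parsed from a checkpoint file name, or None."""
    match = CHECKPOINT_RE.search(filename)
    if match is None:
        return None
    return int(match.group(1)), float(match.group(2))


def best_checkpoint(filenames):
    """File name with the lowest val_loss; ties go to the earlier epoch."""
    parsed = [(parse_checkpoint(f), f) for f in filenames]
    parsed = [(info, f) for info, f in parsed if info is not None]
    return min(parsed, key=lambda item: (item[0][1], item[0][0]))[1]


# invented example names, for illustration only
names = ['modelweights.80-0.21.hdf5', 'modelweights.85-0.19.hdf5',
         'modelweights.90-0.22.hdf5', 'notes.txt']
print(best_checkpoint(names))
```

The selected file could then be passed to `load_model`, which is already imported at the top of this notebook.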