# Notebook 4: Fine-tuning a SpaceNet pre-trained model with __Solaris__

This notebook is developed for the FOSS4G International 2019 `solaris` Workshop. If you're using it outside of that context, some of the working environment materials will be unavailable. Check the GitHub repo for instructions on how to alter the notebooks for usage outside of the workshop.
 
## Summary

This notebook shows how to take a pre-trained SpaceNet Challenge-winning model and fine-tune it to work on a new imagery dataset. Note that this task requires a fair amount of computational oomph, and will be very slow without access to a GPU.

This notebook is split into 3 parts:

1. [__Checking model performance on a new dataset__](#Checking-model-performance-on-a-new-dataset)
    1. [Checking performance on the original input imagery](#Checking-performance-on-the-original-input-imagery)
    2. [Calculating dataset mean and standard deviation](#Calculating-dataset-mean-and-standard-deviation)
    3. [Re-writing the YAML config file for a new experiment](#Re-writing-the-YAML-config-file-for-a-new-experiment)
    4. [Evaluating prediction quality on Khartoum data](#Evaluating-prediction-quality-on-Khartoum-data)
2. [__Fine-tuning the model__](#Fine-tuning-the-model)
    1. [Creating training masks](#Creating-training-masks)
    2. [Building the config file](#Building-the-config-file)
    3. [Model training](#Model-training)
    4. [Predictions with the new model](#Predictions-with-the-new-model)
3. [__Scoring model performance after fine-tuning__](#Scoring-model-performance-after-fine-tuning)

## Checking model performance on a new dataset

When a model is trained on imagery from one geography (or even a small set of geographies), it may not _"generalize"_ well, i.e. it may perform poorly on previously unseen geographies. Let's test that out with the [Khartoum AOI from the SpaceNet Dataset](https://spacenet.ai/spacenet-buildings-dataset-v2/).

We'll check to see how well the model trained on Atlanta data performs when we test on this image of Khartoum, Sudan from the [SpaceNet 2: Building Footprint Extraction Challenge](https://spacenet.ai/spacenet-buildings-dataset-v2/):

<img src="files/khartoum_infer_for_viz.png">

First, let's see how inference performs on the untouched, raw image:

### Checking performance on the original input imagery

We'll run inference just as we did previously, but using a config file that points to the Khartoum imagery instead of the MVOI (Atlanta) imagery. Note that we're not doing _any_ normalization here - we're just going to put the raw image in and see what happens.


In [None]:
import solaris as sol
import numpy as np
import os
import matplotlib.pyplot as plt
import time
import skimage.io  # import the io submodule explicitly; a bare `import skimage` may not expose it
import geopandas as gpd
from shapely.ops import cascaded_union  # just for visualization purposes


data_path = '/data'   # NON-WORKSHOP PARTICIPANTS: change this path to point to the directory where you've stored the data.

print('Loading config...')
config = sol.utils.config.parse(os.path.join(data_path, 'workshop_configs/xdxd_workshop_khartoum_infer_raw.yml'))
print('config loaded. Initializing model...')
xdxd_inferer = sol.nets.infer.Inferer(config)
print('model initialized. Loading dataset...')
inf_df = sol.nets.infer.get_infer_df(config)
print('dataset loaded. Running inference on the image.')
start_time = time.time()
xdxd_inferer(inf_df)
end_time = time.time()
print('running inference on one image took {} seconds'.format(end_time-start_time))
print('vectorizing output...')
resulting_preds = skimage.io.imread('xdxd_inference_out/RGB-PanSharpen_AOI_5_Khartoum_img924.tif')
predicted_footprints = sol.vector.mask.mask_to_poly_geojson(
    pred_arr=resulting_preds,
    reference_im=inf_df.loc[0, 'image'],
    do_transform=True,
    min_area=1e-10)  # need a tiny, non-zero min_area since the coord system is lat/long (degrees) rather than UTM (metric)
print('output vectorized.')
predicted_footprints.to_file('xdxd_inference_out/Khartoum_img924_raw.geojson', driver='GeoJSON')

In [None]:
src_im_path = os.path.join(data_path, 'Khartoum_data/RGB_imagery/RGB-PanSharpen_AOI_5_Khartoum_img924.tif')
# read the image in
im_arr = skimage.io.imread(src_im_path)
# rescale to min/max in each channel
im_arr = im_arr.astype('float') - np.amin(im_arr, axis=(0,1))
im_arr = im_arr/np.amax(im_arr, axis=(0,1))
im_arr = (im_arr*255).astype('uint8')
# generate mask from the predictions
pred_arr = skimage.io.imread('xdxd_inference_out/RGB-PanSharpen_AOI_5_Khartoum_img924.tif')
preds = (pred_arr[:, :, 0] > 0).astype('uint8')
ground_truth = sol.vector.mask.footprint_mask(
    os.path.join(data_path, 'Khartoum_data/geojson/buildings_AOI_5_Khartoum_img924.geojson'),
    reference_im=src_im_path)

f, axarr = plt.subplots(1, 3, figsize=(16,12))
axarr[0].imshow(im_arr)
axarr[0].set_title('Source image', size=14)
axarr[1].imshow(preds, cmap='gray')
axarr[1].set_title('Predictions', size=14)
axarr[2].imshow(ground_truth, cmap='gray')
axarr[2].set_title('Ground Truth', size=14);

Clearly those results are garbage.

This shows how important it is to make sure your inference target imagery is _normalized the same way your training data was_ when you pass it into a neural network. If it's not, the network has no idea what to do with the values it sees in the array!

When you passed imagery from MVOI into the neural net, it was _Z-scored_ - that is, each band's mean pixel intensity was shifted to zero and each band's standard deviation was scaled to 1. By contrast, the Khartoum image that you just fed in was raw 16-bit integers, with values in the hundreds or thousands. This explains why the model performed poorly!
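To make the Z-scoring idea concrete, here is a minimal sketch on a synthetic chip. The array, its value range, and all variable names are made up for illustration; in the real pipeline the mean and standard deviation come from the whole dataset rather than from a single chip.

```python
import numpy as np

# Hypothetical stand-in for a 16-bit, 3-band image chip with shape (H, W, C).
rng = np.random.default_rng(0)
chip = rng.integers(0, 2000, size=(64, 64, 3)).astype("uint16")

# Z-score each band: subtract the band mean, then divide by the band standard deviation.
chip_f = chip.astype("float64")
band_mean = chip_f.mean(axis=(0, 1))
band_std = chip_f.std(axis=(0, 1))
z_scored = (chip_f - band_mean) / band_std

print("raw range:    ", chip.min(), "to", chip.max())
print("z-scored mean:", z_scored.mean(axis=(0, 1)).round(6))
print("z-scored std: ", z_scored.std(axis=(0, 1)).round(6))
```

The raw values span a range of thousands, while the Z-scored values are centered on zero with unit spread, which is the kind of input the network saw during training.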

Next we'll go through how to Z-score the Khartoum imagery, then re-do the inference with that normalization.

### Calculating dataset mean and standard deviation
First, it's important to ensure that the data from Khartoum is normalized the same way as the data from Atlanta, as differences in intensity will propagate through the entire network, disrupting model performance. The Atlanta data is Z-scored, so we will do the same for Khartoum; to this end, we need to calculate the mean and standard deviation for each channel in the Khartoum dataset.

In [None]:
# only keep GeoTIFFs so stray files in the directory don't break imread
ims = [f for f in os.listdir(os.path.join(data_path, 'Khartoum_data/RGB_imagery')) if f.endswith('.tif')]

R_cts = np.zeros(shape=(199,), dtype='uint32')
G_cts = np.zeros(shape=(199,), dtype='uint32')
B_cts = np.zeros(shape=(199,), dtype='uint32')
bins = np.arange(0, 2000, 10)

for idx, im in enumerate(ims):
    curr_im = sol.utils.io.imread(os.path.join(data_path, 'Khartoum_data', 'RGB_imagery', im))
    R_cts += np.array(np.histogram(curr_im[:, :, 0], bins=bins)[0], dtype='uint32')
    G_cts += np.array(np.histogram(curr_im[:, :, 1], bins=bins)[0], dtype='uint32')
    B_cts += np.array(np.histogram(curr_im[:, :, 2], bins=bins)[0], dtype='uint32')
    if idx%100 == 0:
        print("# {} of {} completed".format(idx, len(ims)))

_The above cell takes a couple of minutes - be patient!_
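The cell above accumulates per-image histograms instead of holding every pixel in memory at once. As a rough sanity check on that approach, the sketch below uses synthetic data (all names and values are hypothetical) to show that summed per-chip counts match one big histogram, and that mean and standard deviation estimated from bin centers land close to the exact values. Unlike the notebook, this sketch keeps the first bin, since the synthetic data has no nodata pixels.

```python
import numpy as np

rng = np.random.default_rng(42)
# Hypothetical stand-in for one band across several image chips.
chips = [rng.normal(600, 150, size=(50, 50)).clip(0, 1999) for _ in range(8)]
bins = np.arange(0, 2000, 10)  # same bin edges as the cell above

# Accumulate counts chip by chip.
cts = np.zeros(len(bins) - 1, dtype="uint64")
for chip in chips:
    cts += np.histogram(chip, bins=bins)[0].astype("uint64")

# A single histogram over all pixels gives identical counts.
all_pixels = np.concatenate([c.ravel() for c in chips])
same_counts = np.array_equal(cts, np.histogram(all_pixels, bins=bins)[0])

# Estimate mean and standard deviation from the bin centers.
centers = bins[:-1] + 5
est_mean = np.sum(cts * centers) / np.sum(cts)
est_std = np.sqrt(np.sum(cts * (centers - est_mean) ** 2) / np.sum(cts))

print("counts match:", same_counts)
print("mean: histogram estimate {:.2f}, exact {:.2f}".format(est_mean, all_pixels.mean()))
print("std:  histogram estimate {:.2f}, exact {:.2f}".format(est_std, all_pixels.std()))
```

The estimates differ from the exact values only by the rounding that 10-unit bins introduce, which is small compared with the spread of the data.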

Let's look at the histogram of values for the channels. We're going to skip the first bin, as this is almost exclusively made up of `0` values. Zeros correspond to no data in these images.

In [None]:
f, ax = plt.subplots()
ax.plot(bins[1:-1] + 5, R_cts[1:], label='Red', color='red')
ax.plot(bins[1:-1] + 5, G_cts[1:], label='Green', color='green')
ax.plot(bins[1:-1] + 5, B_cts[1:], label='Blue', color='blue')
ax.legend(loc='upper right')
ax.set_xlabel('Intensity value', size=16)
ax.set_ylabel('Counts', size=16)
ax.set_title('Counts of pixel intensities for Khartoum data,\nsplit by channel',
             size=16);


In the next cell we'll calculate the mean and standard deviation to normalize these intensities.

In [None]:
def mean_and_std_from_histogram(bins, cts):
    """Calculate the mean and standard deviation from a histogram."""
    bin_centers = bins[1:-1] + ((bins[1]-bins[0])/2.)
    # skip the first bin since it contains the nodata values
    mean = np.sum(cts[1:]*bin_centers)/np.sum(cts[1:])
    std = np.sqrt(np.sum(cts[1:]*np.square(bin_centers-mean))/np.sum(cts[1:]))
    return mean, std

r_mean, r_std = mean_and_std_from_histogram(bins, R_cts)
print("Red mean: {}".format(r_mean))
print("Red standard deviation: {}".format(r_std))
g_mean, g_std = mean_and_std_from_histogram(bins, G_cts)
print("Green mean: {}".format(g_mean))
print("Green standard deviation: {}".format(g_std))
b_mean, b_std = mean_and_std_from_histogram(bins, B_cts)
print("Blue mean: {}".format(b_mean))
print("Blue standard deviation: {}".format(b_std))

Because the [`albumentations`](https://albumentations.readthedocs.io/en/latest/index.html) library used in `solaris` divides pixel intensities by the maximum value for the image's bit depth before performing normalization, we need to divide these by 65535 (the unsigned 16-bit max) for use as parameters in the pipeline.

In [None]:
print("r_mean for config file: {}".format(r_mean/65535))
print("g_mean for config file: {}".format(g_mean/65535))
print("b_mean for config file: {}".format(b_mean/65535))
print("r_std for config file: {}".format(r_std/65535))
print("g_std for config file: {}".format(g_std/65535))
print("b_std for config file: {}".format(b_std/65535))

These values (to a few decimal places) should be used in the config file.
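Why does dividing by 65535 give the right result? The sketch below assumes the normalization step has the form `(img - mean * max_pixel_value) / (std * max_pixel_value)`, which is how we understand the `albumentations` normalization to behave. The `normalize` function here is a hypothetical NumPy stand-in, not the library call itself.

```python
import numpy as np

MAX_PIXEL_VALUE = 65535.0  # unsigned 16-bit maximum

def normalize(img, mean, std, max_pixel_value):
    """Hypothetical stand-in for the assumed normalization formula."""
    return (img - mean * max_pixel_value) / (std * max_pixel_value)

rng = np.random.default_rng(1)
img = rng.integers(1, 2000, size=(32, 32, 3)).astype("float64")

# Statistics in raw intensity units, as computed from the histograms.
raw_mean = img.mean(axis=(0, 1))
raw_std = img.std(axis=(0, 1))

# Values as they would be written to the config file.
cfg_mean = raw_mean / MAX_PIXEL_VALUE
cfg_std = raw_std / MAX_PIXEL_VALUE

via_config = normalize(img, cfg_mean, cfg_std, MAX_PIXEL_VALUE)
direct = (img - raw_mean) / raw_std
print("same result as a direct Z-score:", np.allclose(via_config, direct))
```

Under that assumption, the scaled-down config values reproduce a plain Z-score of the raw 16-bit intensities.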

### Re-writing the YAML config file for a new experiment

There are three other changes that need to be made to the original config files:
1. Remove the `DropChannel` pre-processing step: unlike the Atlanta dataset, these image files only have three channels. We therefore don't need to drop a 4th channel.
2. `SwapChannels`: The MVOI Atlanta dataset is B-G-R channel order, but Khartoum is R-G-B. We therefore need to use the `SwapChannels` pre-processing step to switch the channels at index `0` and `2`. Because we're using PyTorch models, these channels will be at axis 1.
3. `inference_data_csv`: because we're running inference on different imagery, we'll need to point to a CSV specifying that imagery. That CSV, `khartoum_inf.csv`, can be found in the `workshop_configs` directory.

Feel free to try to create this yourself from a copy of the `xdxd_workshop_infer.yml` file. Alternatively, we've provided the file for you as `xdxd_workshop_khartoum_infer.yml`.
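If it's unclear what the first two changes do to the array, here is a minimal NumPy sketch of the idea. The `swap_channels` and `drop_channel` helpers are hypothetical illustrations written for this example; they are not the `solaris` implementations.

```python
import numpy as np

def swap_channels(arr, first, second, axis=1):
    """Return a copy of arr with two channel indices exchanged along axis."""
    order = list(range(arr.shape[axis]))
    order[first], order[second] = order[second], order[first]
    return np.take(arr, order, axis=axis)

def drop_channel(arr, idx, axis=1):
    """Return a copy of arr with one channel removed along axis."""
    return np.delete(arr, idx, axis=axis)

# Hypothetical batch in PyTorch layout: (batch, channels, height, width).
batch = np.zeros((2, 3, 4, 4), dtype="float32")
batch[:, 0] = 1.0  # pretend this is R
batch[:, 1] = 2.0  # G
batch[:, 2] = 3.0  # B

swapped = swap_channels(batch, 0, 2)
print(swapped[0, :, 0, 0])  # channel order is now B-G-R

# A 4-channel batch, as in the Atlanta data, loses its last channel.
four_band = np.zeros((2, 4, 4, 4), dtype="float32")
print(drop_channel(four_band, 3).shape)
```

Note that both helpers work along axis 1, matching the channel position mentioned above for PyTorch models.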

__Let's try it out!__

### Evaluating prediction quality on Khartoum data

In [None]:
print('Loading config...')
config = sol.utils.config.parse(os.path.join(data_path, 'workshop_configs/xdxd_workshop_khartoum_infer.yml'))
print('config loaded. Initializing model...')
xdxd_inferer = sol.nets.infer.Inferer(config)
print('model initialized. Loading dataset...')
inf_df = sol.nets.infer.get_infer_df(config)
print('dataset loaded. Running inference on the image.')
start_time = time.time()
xdxd_inferer(inf_df)
end_time = time.time()
print('running inference on one image took {} seconds'.format(end_time-start_time))
print('vectorizing output...')
resulting_preds = skimage.io.imread('xdxd_inference_out/RGB-PanSharpen_AOI_5_Khartoum_img924.tif')
predicted_footprints = sol.vector.mask.mask_to_poly_geojson(
    pred_arr=resulting_preds,
    reference_im=inf_df.loc[0, 'image'],
    do_transform=True,
    min_area=1e-10)  # need a tiny, non-zero min_area since the coord system is lat/long (degrees) rather than UTM (metric)
print('output vectorized.')
predicted_footprints.to_file('xdxd_inference_out/Khartoum_img924.geojson', driver='GeoJSON')

In [None]:
src_im_path = os.path.join(data_path, 'Khartoum_data/RGB_imagery/RGB-PanSharpen_AOI_5_Khartoum_img924.tif')
# read the image in
im_arr = skimage.io.imread(src_im_path)
# rescale to min/max in each channel
im_arr = im_arr.astype('float') - np.amin(im_arr, axis=(0,1))
im_arr = im_arr/np.amax(im_arr, axis=(0,1))
im_arr = (im_arr*255).astype('uint8')
# generate mask from the predictions
pred_arr = skimage.io.imread('xdxd_inference_out/RGB-PanSharpen_AOI_5_Khartoum_img924.tif')
preds = (pred_arr[:, :, 0] > 0).astype('uint8')
ground_truth = sol.vector.mask.footprint_mask(
    os.path.join(data_path, 'Khartoum_data/geojson/buildings_AOI_5_Khartoum_img924.geojson'),
    reference_im=src_im_path)

f, axarr = plt.subplots(1, 3, figsize=(16,12))
axarr[0].imshow(im_arr)
axarr[0].set_title('Source image', size=14)
axarr[1].imshow(preds, cmap='gray')
axarr[1].set_title('Predictions', size=14)
axarr[2].imshow(ground_truth, cmap='gray')
axarr[2].set_title('Ground Truth', size=14);

In [None]:
f, axarr = plt.subplots(1, 2, figsize=(12,4))
axarr[0].imshow(pred_arr[:, :, 0], cmap='gray')
axarr[0].axis('off')
axarr[0].set_title('Raw predictions', size=16)
axarr[1].hist(pred_arr.flatten(), bins=25, density=True)
axarr[1].set_xlabel('Raw confidence', size=14)
axarr[1].set_ylabel('Fraction of pixels', size=14)
axarr[1].set_title('Prediction histogram', size=16);

These predictions are clearly terrible - the model is only finding one of the buildings in this image. However, if we directly examine the prediction outputs, we'll see that we're not _too_ far from a good model - it's finding some buildings, just at such a low raw confidence value that it can't distinguish them from background. Remember that the model takes anything with a raw prediction > 0 as a building.
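To see why a fixed cutoff at 0 can hide buildings the model has weakly detected, consider this sketch on synthetic values. The numbers are invented for illustration and are not taken from the model's actual output.

```python
import numpy as np

rng = np.random.default_rng(7)
# Hypothetical raw model output: strongly negative background...
raw = rng.normal(-4.0, 1.0, size=(100, 100))
# ...with one weakly detected "building" whose values rarely reach 0.
raw[40:60, 40:60] = rng.normal(-1.0, 0.3, size=(20, 20))

for threshold in (0.0, -2.0):
    mask = raw > threshold
    print("threshold {:>4}: {:.1%} of pixels flagged, {:.1%} of the building recovered".format(
        threshold, mask.mean(), mask[40:60, 40:60].mean()))
```

In this toy case the building stands out from the background in the raw values, but almost none of it survives the `> 0` cutoff. Fine-tuning aims to push the raw values for real buildings above that cutoff instead.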

__Pause here to go through the CosmiQ_Solaris_Training_Intro slides!__

So, what can we do to improve model performance? Let's try fine-tuning!

## Fine-tuning the model

### Creating training masks
Before we can continue training a model, we need target masks: images that the model will learn to create during training. We'll follow [this tutorial](https://solaris.readthedocs.io/en/latest/tutorials/notebooks/api_masks_tutorial.html) to create masks. __Note for workshop participants:__ this cell won't work because the `/data` directory is read-only; we've made the training masks for you, but this cell shows how to do it.

In [None]:
mask_dir = os.path.join(data_path, 'Khartoum_data', 'training_masks')
geojson_dir = os.path.join(data_path, 'Khartoum_data', 'geojson')
im_dir = os.path.join(data_path, 'Khartoum_data', 'RGB_imagery')
geojson_list = [f for f in os.listdir(geojson_dir) if f.endswith('.geojson')]
im_list = [f for f in os.listdir(im_dir) if f.endswith('.tif')]
n_chips = len(geojson_list)

if not os.path.exists(mask_dir):
    os.mkdir(mask_dir)
    
    for idx, gj in enumerate(geojson_list):
        # get the 'img[number]' chip ID for the image
        chip_id = os.path.splitext(gj)[0].split('_')[-1]
        matching_im = 'RGB-PanSharpen_AOI_5_Khartoum_' + chip_id + '.tif'
        mask_fname = 'mask_' + chip_id + '.tif'
        fp_mask = sol.vector.mask.footprint_mask(df=os.path.join(geojson_dir, gj),
                                                 out_file=os.path.join(mask_dir, mask_fname),
                                                 reference_im=os.path.join(im_dir, matching_im),
                                                 shape=(650, 650))
        if (idx+1)%100 == 0:
            print('chip {} of {} done'.format(idx+1, n_chips), flush=True)

Let's look at one of these just to make sure they came out right:

In [None]:
f, axarr = plt.subplots(1, 2, figsize=(10, 6))
axarr[0].imshow(skimage.io.imread('files/khartoum_infer_for_viz.tif'))
axarr[0].axis('off')
axarr[1].imshow(skimage.io.imread(os.path.join(data_path, 'Khartoum_data', 'training_masks', 'mask_img924.tif')),
                cmap='gray')
axarr[1].axis('off');

Looks good! We're ready to set up for training.

### Building the config file

With model fine-tuning, we'll load the pre-trained weights used above, and continue training at a much lower learning rate for a couple of epochs. To this end we'll need _another_ config with a few more modifications:

1. A reduced learning rate - we'll try `1e-5` instead of `1e-4`
2. Change `train=False` to `train=True`
3. Specify where the newly trained versions are saved with the `training['callbacks']['model_checkpoint']` arguments
4. Specify a training data CSV. In this case, we'll use a CSV created [per this tutorial](https://solaris.readthedocs.io/en/latest/tutorials/notebooks/creating_im_reference_csvs.html) that points to all of the images and the masks that we just created, save for one: the image that we ran inference on earlier, which we'll hold out as a test image. The CSV, named `khartoum_fine_tune.csv`, is available in the `workshop_configs` directory.

As earlier, feel free to create this config yourself; otherwise, you can use `xdxd_workshop_khartoum_train.yml`.
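For a sense of what the training CSV from item 4 contains, here is a rough sketch that pairs each image with its mask and holds out the test chip. The `image` column matches what the inference cells read from `inf_df`; the `label` column name, the example file names, and the output file name are assumptions made for illustration, so check the linked tutorial for the exact format.

```python
import csv
import os
import tempfile

def chip_id_from_geojson(fname):
    """'buildings_AOI_5_Khartoum_img924.geojson' -> 'img924'."""
    return os.path.splitext(fname)[0].split('_')[-1]

def build_rows(geojson_names, im_dir, mask_dir, holdout='img924'):
    """Pair each image with its mask, skipping the held-out test chip."""
    rows = []
    for gj in sorted(geojson_names):
        chip_id = chip_id_from_geojson(gj)
        if chip_id == holdout:
            continue
        rows.append({
            'image': os.path.join(im_dir, 'RGB-PanSharpen_AOI_5_Khartoum_' + chip_id + '.tif'),
            'label': os.path.join(mask_dir, 'mask_' + chip_id + '.tif'),  # assumed column name
        })
    return rows

# Hypothetical file names for illustration.
names = ['buildings_AOI_5_Khartoum_img1.geojson',
         'buildings_AOI_5_Khartoum_img924.geojson',
         'buildings_AOI_5_Khartoum_img37.geojson']
rows = build_rows(names, '/data/Khartoum_data/RGB_imagery', '/data/Khartoum_data/training_masks')

# Write to a temporary directory so nothing in the working directory is touched.
out_path = os.path.join(tempfile.mkdtemp(), 'khartoum_fine_tune_sketch.csv')
with open(out_path, 'w', newline='') as f:
    writer = csv.DictWriter(f, fieldnames=['image', 'label'])
    writer.writeheader()
    writer.writerows(rows)
print(len(rows), 'rows written to', out_path)
```

Keeping the test chip out of the training CSV is what makes the before-and-after comparison later in the notebook meaningful.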

### Model training

Let's try it! <font style="color: red;">__WARNING: this is EXTREMELY slow without a GPU (each epoch may take several hours).__</font>

In [None]:
print('Loading config...')
config = sol.utils.config.parse(os.path.join(data_path, 'workshop_configs/xdxd_workshop_khartoum_train.yml'))
print('config loaded. Initializing Trainer instance...')
xdxd_trainer = sol.nets.train.Trainer(config)
print('model initialized. Beginning training...')
print()
start_time = time.time()
xdxd_trainer.train()
end_time = time.time()
print()
print('training took {} minutes'.format((end_time-start_time)/60))


### Predictions with the new model

We'll now run inference with the newly tuned model. Note that if your config file specifies `train=True` and you pass that config to an `Inferer` instance, `solaris` will automatically use the newly trained model for inference.

In [None]:
print('Loading config...')
config = sol.utils.config.parse(os.path.join(data_path, 'workshop_configs/xdxd_workshop_khartoum_train.yml'))
print('config loaded. Initializing model...')
xdxd_inferer = sol.nets.infer.Inferer(config)
print('model initialized. Loading dataset...')
inf_df = sol.nets.infer.get_infer_df(config)
print('dataset loaded. Running inference on the image.')
start_time = time.time()
xdxd_inferer(inf_df)
end_time = time.time()
print('running inference on one image took {} seconds'.format(end_time-start_time))
print('vectorizing output...')
resulting_preds = skimage.io.imread('xdxd_retrain_inference_out/RGB-PanSharpen_AOI_5_Khartoum_img924.tif')
predicted_footprints = sol.vector.mask.mask_to_poly_geojson(
    pred_arr=resulting_preds,
    reference_im=inf_df.loc[0, 'image'],
    do_transform=True,
    min_area=1e-10)  # need a tiny, non-zero min_area since the coord system is lat/long (degrees) rather than UTM (metric)
print('output vectorized.')
predicted_footprints.to_file('xdxd_retrain_inference_out/Khartoum_img924.geojson', driver='GeoJSON')

In [None]:
src_im_path = os.path.join(data_path, 'Khartoum_data/RGB_imagery/RGB-PanSharpen_AOI_5_Khartoum_img924.tif')
# read the image in
im_arr = skimage.io.imread(src_im_path)
# rescale to min/max in each channel
im_arr = im_arr.astype('float') - np.amin(im_arr, axis=(0,1))
im_arr = im_arr/np.amax(im_arr, axis=(0,1))
im_arr = (im_arr*255).astype('uint8')
# generate mask from the predictions
old_pred_arr = skimage.io.imread('xdxd_inference_out/RGB-PanSharpen_AOI_5_Khartoum_img924.tif')
old_preds = (old_pred_arr[:, :, 0] > 0).astype('uint8')
new_pred_arr = skimage.io.imread('xdxd_retrain_inference_out/RGB-PanSharpen_AOI_5_Khartoum_img924.tif')
new_preds = (new_pred_arr[:, :, 0] > 0).astype('uint8')

ground_truth = sol.vector.mask.footprint_mask(
    os.path.join(data_path, 'Khartoum_data/geojson/buildings_AOI_5_Khartoum_img924.geojson'),
    reference_im=src_im_path)

f, axarr = plt.subplots(2, 2, figsize=(12, 8))
axarr[0, 0].imshow(im_arr)
axarr[0, 0].set_title('Source image', size=14)
axarr[0, 0].axis('off')
axarr[0, 1].imshow(old_preds, cmap='gray')
axarr[0, 1].set_title('Predictions before fine-tuning', size=14)
axarr[0, 1].axis('off')
axarr[1, 1].imshow(new_preds, cmap='gray')
axarr[1, 1].set_title('Predictions after fine-tuning', size=14)
axarr[1, 1].axis('off')
axarr[1, 0].imshow(ground_truth, cmap='gray')
axarr[1, 0].set_title('Ground Truth', size=14)
axarr[1, 0].axis('off');

Wow. This appears to show a _marked_ improvement with _just two epochs of training!_ How do the scores come out?

## Scoring model performance after fine-tuning

In [None]:
evaluator = sol.eval.base.Evaluator(os.path.join(data_path, 'Khartoum_data/geojson/buildings_AOI_5_Khartoum_img924.geojson'))
prediction_dirs = ['xdxd_inference_out', 'xdxd_retrain_inference_out']
model_names = ['Original', 'Fine-tuned']

f1_scores = []
precision = []
recall = []
for i in range(2):
    evaluator.load_proposal(os.path.join(prediction_dirs[i],'Khartoum_img924.geojson'),
                            pred_row_geo_value='geometry',
                            conf_field_list=[])
    results = evaluator.eval_iou(miniou=0.5, calculate_class_scores=False)
    f1_scores.append(results[0]['F1Score'])
    precision.append(results[0]['Precision'])
    recall.append(results[0]['Recall'])

f, axarr = plt.subplots(1, 3, figsize=(10, 4))
f.subplots_adjust(wspace=0.6)
axarr[0].bar(model_names, f1_scores)
axarr[0].set_ylabel('$F_1$ Score', size=16)
axarr[1].bar(model_names, precision)
axarr[1].set_ylabel('Precision', size=16)
axarr[2].bar(model_names, recall)
axarr[2].set_ylabel('Recall', size=16);
f.suptitle('Comparison of original vs. fine-tuned model performance', size=16);

Clearly, this is only _one_ sample image; however, it's noteworthy that this model briefly fine-tuned on Khartoum imagery [__achieved a higher score here than some of the prize-winning models trained on Khartoum for days during the SpaceNet Challenge Round 2__](https://medium.com/the-downlinq/2nd-spacenet-competition-winners-code-release-c7473eea7c11).
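If you're curious how precision, recall, and $F_1$ relate to an IoU threshold like the `miniou=0.5` used above, here is a simplified sketch. It uses axis-aligned boxes and greedy one-to-one matching as stand-ins; the `solaris` evaluator works on polygon footprints, and its exact matching procedure may differ from this assumption.

```python
def box_iou(a, b):
    """IoU of two axis-aligned boxes given as (xmin, ymin, xmax, ymax)."""
    ix = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    iy = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = ix * iy
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0

def score(ground_truth, proposals, min_iou=0.5):
    """Greedy one-to-one matching, then precision, recall, and F1."""
    unmatched = list(ground_truth)
    tp = 0
    for prop in proposals:
        ious = [box_iou(prop, gt) for gt in unmatched]
        if ious and max(ious) >= min_iou:
            # each ground-truth footprint can be matched only once
            unmatched.pop(ious.index(max(ious)))
            tp += 1
    fp = len(proposals) - tp
    fn = len(unmatched)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# Hypothetical footprints: one good match, one poor overlap, one false alarm.
gt = [(0, 0, 10, 10), (20, 20, 30, 30), (50, 50, 60, 60)]
props = [(1, 1, 10, 10), (20, 20, 26, 26), (80, 80, 90, 90)]
precision, recall, f1 = score(gt, props)
print("precision={:.3f} recall={:.3f} F1={:.3f}".format(precision, recall, f1))
```

Only the first proposal clears the 0.5 IoU bar here, so one true positive, two false positives, and two missed buildings give precision, recall, and $F_1$ of one third each.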

# Congratulations! You've completed the FOSS4G 2019 Solaris tutorial.

Hang around for a quick teaser on the SpaceNet 5 challenge that's starting soon!

## What's next?

Here are a few more resources that will help you as you continue to work with `solaris`:

- [Solaris documentation](https://solaris.readthedocs.io)
- [A blog post from Jake Shermeyer about using Solaris for car detection in the COWC dataset](https://medium.com/the-downlinq/beyond-infrastructure-mapping-finding-vehicles-with-solaris-11e08da0dab)
- [The Solaris GitHub repository](https://github.com/cosmiq/solaris)