![NASA logo](https://upload.wikimedia.org/wikipedia/commons/thumb/e/e5/NASA_logo.svg/110px-NASA_logo.svg.png) ![IBM Research logo](https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcSwHxsDwxcOHsQUD2pghQ32j90pzsZLcOujpGCyU1yE&s)

# Geospatial Foundation Model: Flood mapping fine-tuning

This is an example of how to fine-tune a model for flood mapping from HLS data using the IBM Geospatial Foundation models as a starting point.  

To run a fine-tuning experiment for flood mapping we will use the MMSegmentation library (https://github.ibm.com/GeoFM-Finetuning/mmsegmentation) to fine-tune a model starting from the geospatial foundation model trained on HLS data.

The following notebook assumes that your project files are placed in a folder on the shared volume, with the following folder structure:
```
configs                   - folder to place experiment configuration files
fine-tune-checkpoints     - folder where training outputs will be generated
GFM-Models                - folder containing the checkpoint files from the pre-trained GFM
inference                 - folder where we will carry out our inference tasks
training_data             - folder containing the training dataset (including labels and test/train splits etc)
```
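
Before continuing, it can help to confirm that the expected folders exist. The following is a minimal sketch, not part of `geoft`: the helper name `missing_project_folders` is hypothetical, and the folder names are copied from the listing above. The download cells later in this notebook write to `gfm-models` and `training-data`, so check which spelling your project actually uses.

```python
from pathlib import Path

# Folder names copied from the listing above; adjust if your project differs.
EXPECTED_FOLDERS = [
    "configs",
    "fine-tune-checkpoints",
    "GFM-Models",
    "inference",
    "training_data",
]


def missing_project_folders(project_root):
    """Return the expected sub-folders that do not exist under project_root."""
    root = Path(project_root)
    return [name for name in EXPECTED_FOLDERS if not (root / name).is_dir()]
```

An empty list returned by `missing_project_folders(path_to_shared_volume + project_name)` would mean every folder in the listing is present.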

You then create your configuration script before submitting it to the cluster to run.  The notebook will then guide you to:
* monitor and visualise the training,
* run the test tasks,
* use the trained model for local inference.


In [None]:
import json
import pandas as pd
import glob
import matplotlib.pyplot as plt
import os
import subprocess
from pprint import pprint
from dotenv import load_dotenv
import datetime
import string
import sys

sys.path.append('../')
import geoft

# Load environment variables
load_dotenv()

# Grab cluster details
login_url, namespace, path_to_shared_volume = geoft.get_cluster_details()

# Create S3 client (for pulling data and model weights)
aws_access_key_id = os.getenv("AWS_ACCESS_KEY_ID")
aws_access_key_secret = os.getenv("AWS_ACCESS_KEY_SECRET")

s3 = geoft.create_s3_client(aws_access_key_id, aws_access_key_secret)

# S3 bucket where data and model weights reside
bucket_name = "nasa-gfm-summer-school"

In [None]:
#------- Define the project name you wish to use

project_name = "flood"

## Connect to cluster for task submission

The first thing you need to do in order to submit a training job to the cluster is log in to the cluster.  This will only need to be done once per 24 hours.

Run the cell below (`login_url`), and click on the generated URL.

Authenticate, then copy and paste the `oc login` command into the cell below (with `%%sh` at the top). This will log you in to the cluster and allow you to submit and monitor jobs.

In [None]:
login_url

In [None]:
%%sh
# <Paste oc command here>

## Project setup

If we are starting a new fine-tuning project, we can create a new set of folders and download the training data and labels and the pre-trained foundation model weights.  We create the folder structure described above, then pull the data and weights from an S3 bucket.

In [None]:
#------- Create project folder structure 
geoft.create_project_folders(project_name)

In [None]:
#------- Download the pre-trained model weights
model_name = 'epoch-916-loss-0.0779.pt' # best for flood mapping
s3.download_file(bucket_name, 'gfm-models/' + model_name, path_to_shared_volume + project_name + '/gfm-models/' + model_name)


In [None]:
#------- Download the training data

dataset = 'sen1floods11'

training_data_path = path_to_shared_volume + project_name + '/training-data/'

# Download training data
subfolder = 'dataset/S2Hand/'
geoft.download_s3_dir(dataset + '/' + subfolder, training_data_path, bucket_name, client=s3)

# Download training data labels
subfolder = 'dataset/LabelHand/'
geoft.download_s3_dir(dataset + '/' + subfolder, training_data_path, bucket_name, client=s3)

# Download training data splits
subfolder = 'data_splits/'
geoft.download_s3_dir(dataset + '/' + subfolder, training_data_path, bucket_name, client=s3)
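
The internals of `geoft.download_s3_dir` are not shown in this notebook. As a rough mental model only, a prefix download with a boto3-style S3 client could look like the sketch below. The function name `download_prefix` is hypothetical, and whether `geoft` preserves sub-folders in the same way is an assumption.

```python
import os


def download_prefix(client, bucket, prefix, local_dir):
    """Download every object under `prefix` in `bucket` into `local_dir`.

    `client` is expected to behave like a boto3 S3 client.
    Returns the list of local file paths written.
    """
    written = []
    paginator = client.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        for obj in page.get("Contents", []):
            key = obj["Key"]
            if key.endswith("/"):  # skip "folder" placeholder objects
                continue
            # Keep the path below the prefix so sub-folders are preserved
            target = os.path.join(local_dir, os.path.relpath(key, prefix))
            os.makedirs(os.path.dirname(target), exist_ok=True)
            client.download_file(bucket, key, target)
            written.append(target)
    return written
```

Pagination matters here because a single `list_objects_v2` response is capped in size, so larger folders arrive over several pages.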

## Creating fine-tuning configuration

![Fine-tune architecture](../images/finetune_arch.png)

### Brief introduction to the hyperparameters we will adapt as part of this session
**Loss function**:
    Both tasks we will solve as part of this exercise are binary semantic segmentation tasks (e.g., pixelwise classification of flood vs. background). There are two available loss functions for the task:

| Loss functions | Description | Code |
| -------------- | ----------- | ---- |
| CrossEntropyLoss | is sensitive to class imbalance but very general and a good choice for an initial training | `type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1, class_weight=[0.15, 0.85, 0]` |
| DiceLoss | is invariant against class imbalance but tends to be more sensitive to other hyperparameters | `type='DiceLoss', use_sigmoid=False, loss_weight=1` |

**Weighting classes** in the loss function:
    As described above, some loss functions (like CE loss) are sensitive to class imbalance. We can counter class imbalance by weighting the classes in the loss. For example, for flood mapping ~5-10% of the pixels represent parts of flood events while the rest is background. To compensate, we can set the class weight of the flood class to, e.g., 90%, while the background class is assigned a class weight of 10%. The data contains a third class for invalid pixels (for example, cloud coverage). The class weight for this class is set to zero.
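
To make the effect of class weights concrete, here is a small plain-Python sketch of a weighted cross-entropy. It is illustrative only: the function name is hypothetical, the class order (0 = background, 1 = flood, 2 = invalid) is an assumption for this example rather than something taken from the dataset, and dividing by the summed weights is one common normalisation that may differ from the reduction MMSegmentation applies.

```python
import math


def weighted_cross_entropy(probs, labels, class_weight):
    """Weighted mean of per-pixel cross-entropy.

    probs: per-pixel lists of class probabilities (each sums to 1)
    labels: per-pixel integer class indices
    class_weight: one weight per class
    """
    weighted_sum = 0.0
    weight_total = 0.0
    for pixel_probs, label in zip(probs, labels):
        w = class_weight[label]
        if w == 0:
            continue  # zero-weight classes (e.g. invalid pixels) are ignored
        weighted_sum += -w * math.log(pixel_probs[label])
        weight_total += w
    return weighted_sum / weight_total if weight_total > 0 else 0.0


# Nine well-predicted background pixels and one poorly predicted flood pixel
probs = [[0.9, 0.1, 0.0]] * 9 + [[0.8, 0.2, 0.0]]
labels = [0] * 9 + [1]

print("unweighted:", weighted_cross_entropy(probs, labels, [1, 1, 1]))
print("weighted:  ", weighted_cross_entropy(probs, labels, [0.1, 0.9, 0]))
```

With the flood class up-weighted, the single badly predicted flood pixel contributes far more to the loss than it does in the unweighted case, which is the behaviour the class weights are meant to produce.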
    
Cross Entropy Loss options:

| Weight water class | Weight land class | Weight cloud class | Code | 
| ------------------ | ----------------- | ------------------ | ---- |
| 0.7 | 0.3 | 0.0 | `[0.7, 0.3, 0]` |
| 0.9 | 0.1 | 0.0 | `[0.9, 0.1, 0]` |

**Batch size**: Defines the number of training examples in a single forward/backward pass (i.e., how many images the model sees as part of one iteration). Batch sizes are typically powers of 2 to facilitate GPU computations. Options in this exercise are: `2`, `4`, `8`, `16`, `32`

**Learning rate**: Defines how much we want the model to change in response to the estimated error each time the model weights are updated. Options: `6e-4`, `6e-5`, `6e-6`

**Auxiliary head**: To stabilize the finetuning process, the model not only includes an encoder and a decode head for segmentation, but also an auxiliary head. This part of the architecture helps to make the model more robust during finetuning. You can add and remove the auxiliary head using the boolean option: `aux_head=True`, `aux_head=False`

**Depth of the decoder**: Generally, the decoder is quite light-weight compared to the GeoFM encoder. A default choice would be one or two layers of convolutions. Increasing this value will result in more parameters that the model can leverage to adapt to the downstream task -- at the cost of heavier computations (finetuning will take more time!). Options: `decode_head_conv = 1`, `decode_head_conv = 2`

**Number of epochs**: Deep neural networks typically require a certain number of epochs to converge. For example, in our experiments, we observed that the finetuning for flood mapping achieves a desirable level of fitness after ~40-50 epochs. *Please do not extend the number of epochs beyond 50, to keep the computation time manageable. :-)*
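
Batch size and number of epochs together determine how many weight updates a run performs. The sketch below works this out for the batch sizes offered in this exercise, using the `number_training_files` value of 252 and the 40 epochs from the configuration cell in this notebook. It assumes a single GPU and that the last, partially filled batch of each epoch is kept; the helper name is hypothetical.

```python
import math


def iterations_per_epoch(num_files, batch_size):
    """Number of optimiser steps needed to see every training file once."""
    return math.ceil(num_files / batch_size)


NUM_TRAINING_FILES = 252  # value of `number_training_files` used in this notebook
NUM_EPOCHS = 40

for batch_size in (2, 4, 8, 16, 32):
    steps = iterations_per_epoch(NUM_TRAINING_FILES, batch_size)
    print(f"batch_size={batch_size:>2}: {steps:>3} iterations/epoch, "
          f"{steps * NUM_EPOCHS} iterations in total")
```

Larger batches mean fewer updates per epoch, which is one reason the learning rate and batch size are usually tuned together.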

### Setting up your experiment
Now, to set up your experiment, populate the dictionary below with the options you wish to choose (based on the description above and discussions).  Don't edit the `gfm_ckpt`, `num_epochs`, `number_training_files` or `project_name`.

You then generate your config, which writes the options you have chosen into a configuration file that you can view with the following cell.


In [None]:
conf = {'gfm_ckpt': 'epoch-916-loss-0.0779.pt',
        'loss_function': '''type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1, class_weight=[0.15, 0.85, 0]''',
        'batch_size': '2',
        'learning_rate': '6e-5',
        'aux_head': 'True',
        'decode_head_conv': '1',
        'num_epochs': 40,
        'number_training_files': 252,
        'project_name': project_name}
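
Since the values in `conf` are free-form strings, a quick sanity check before generating the config can catch typos. This is a minimal sketch, not part of `geoft`: `check_conf` is a hypothetical helper, and the allowed values are copied from the option lists in the hyperparameter introduction above.

```python
ALLOWED_OPTIONS = {
    "batch_size": {"2", "4", "8", "16", "32"},
    "learning_rate": {"6e-4", "6e-5", "6e-6"},
    "aux_head": {"True", "False"},
    "decode_head_conv": {"1", "2"},
}
MAX_EPOCHS = 50


def check_conf(conf):
    """Return a list of human-readable problems; an empty list means conf looks fine."""
    problems = []
    for key, allowed in ALLOWED_OPTIONS.items():
        value = str(conf.get(key))
        if value not in allowed:
            problems.append(f"{key}={value!r} is not one of {sorted(allowed)}")
    if int(conf.get("num_epochs", 0)) > MAX_EPOCHS:
        problems.append(f"num_epochs should not exceed {MAX_EPOCHS}")
    return problems
```

Calling `check_conf(conf)` on the dictionary above should return an empty list; any entries in the list describe which option to revisit.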


In [None]:
experiment_name, experiment_filepath = geoft.generate_config(project_name, conf, "flood_config.py.template")

In [None]:
geoft.view_config(experiment_filepath)
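
How `geoft.generate_config` fills in the template is not shown in this notebook. As a rough mental model only, placeholder substitution with Python's standard `string.Template` could look like this. The template fragment, the `$name` placeholder syntax and the `render_config` helper are all hypothetical; the real `flood_config.py.template` may use a different mechanism and different field names.

```python
from string import Template

# Hypothetical template fragment; the real flood_config.py.template is not shown here.
TEMPLATE_TEXT = """\
samples_per_gpu = $batch_size
optimizer = dict(type='Adam', lr=$learning_rate)
loss_decode = dict($loss_function)
"""


def render_config(template_text, options):
    """Substitute $name placeholders; raises KeyError if an option is missing."""
    return Template(template_text).substitute(options)
```

This also illustrates why the options in `conf` are given as strings: they are pasted verbatim into Python source, so `'6e-5'` ends up as the literal `6e-5` in the generated file.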


## Submitting fine-tuning job to run
Now that we have the configuration script ready, we can send it to the cluster to run.  The next cell will submit the job to the cluster using TorchX.  This will spin up a new pod/container where the fine-tuning will run.


In [None]:
mcad_id = geoft.submit_tune(project_name,
                namespace,
                experiment_name,
                image='quay.io/bedwards-ibm/mmsegmentation-geo:latest',
                num_gpus=1,
                memory_mb=14000)

print(mcad_id)

## Monitoring training job
Once you have submitted the job to the cluster, we can monitor it using the following commands.


In [None]:
%%sh
torchx list -s kubernetes_mcad

In [None]:
check_log_cmd = '''torchx log ''' + str(mcad_id) +  ''' | tail -n20'''
os.system(check_log_cmd)
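
If you would rather have the log text available as a Python string (for example, to search it), the same `torchx log` command can be run through `subprocess` instead of `os.system`. This is a sketch under the assumption that `torchx` is on the `PATH`; the helper names are hypothetical, and the `run` argument exists only so the function can be exercised without a cluster.

```python
import subprocess


def tail_lines(text, n=20):
    """Return the last n lines of text as a single string."""
    return "\n".join(text.splitlines()[-n:])


def tail_job_log(app_id, n=20, run=subprocess.run):
    """Run `torchx log <app_id>` and return the last n lines of its output.

    `run` is injectable so the function can be exercised without a cluster.
    """
    result = run(["torchx", "log", str(app_id)], capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"torchx log failed: {result.stderr.strip()}")
    return tail_lines(result.stdout, n)
```

For example, `print(tail_job_log(mcad_id))` would show the same last 20 lines as the cell above, while raising an error if the command itself fails.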


## Viewing the training metrics

Now that we have run (or at least are running) the experiment, we can view the training metrics.  To do this we will load the log file and extract the metrics into two dataframes (`train_df` and `val_df`).

In [None]:
train_df, val_df = geoft.load_tune_metrics(project_name, experiment_name)
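
The internals of `geoft.load_tune_metrics` are not shown here. As an illustration of the general idea, the sketch below assumes the training run writes a JSON-lines log in which each record carries a `mode` field of `train` or `val` (the layout used by MMCV's text logger in MMSegmentation 0.x; the fork used here may differ). The function name `split_metrics` is hypothetical.

```python
import json


def split_metrics(log_lines):
    """Split JSON-lines log records into (train_records, val_records) by their 'mode'."""
    train, val = [], []
    for line in log_lines:
        line = line.strip()
        if not line:
            continue
        record = json.loads(line)
        mode = record.get("mode")
        if mode == "train":
            train.append(record)
        elif mode == "val":
            val.append(record)
        # records without a 'mode' (e.g. a metadata header) are skipped
    return train, val
```

Each returned list can be passed straight to `pd.DataFrame(...)` to obtain a table with one row per logged step.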

In [None]:
plt.figure().set_figwidth(15)
plt.subplot(1, 2, 1)
plt.plot(train_df.index, train_df.loss, '-r');
plt.ylabel('Training Loss');
# plt.yscale('log')

plt.subplot(1, 2, 2)
plt.plot(train_df.index, train_df.loss_val, '-b');
plt.ylabel('Validation Loss');

## Test output model

In [None]:
test_mcad_id = geoft.submit_test(project_name,
                        namespace,
                        experiment_name,
                        checkpoint='latest.pth',
                        num_gpus=1,
                        memory_mb=8000)

In [None]:
check_log_cmd = '''torchx log ''' + str(test_mcad_id) +  ''' | tail -n20'''
os.system(check_log_cmd)

In [None]:
test_metrics = geoft.get_test_metrics(project_name, experiment_name)

## Running inference using the trained model

Once we have a trained model, we can use it to run inference on other images.

In [None]:
!pip install rasterio folium morecantile tqdm

In [None]:
# import required libraries
import json
import morecantile

import requests

from tqdm import tqdm

In [None]:
# Event details of known locations with date time, boundingbox, collection name, and bands used.
EVENT_DETAILS = {
    'pakistan_flood': {
        'start_date': '2022-06-12',
        'bounding_box': [-89.974358, 33.428227, -89.863122, 33.355992],
        'collection': 'HLSL30',
        'bands': [1, 2, 3, 8]
    },
    'mongolian_fire': {
        'start_date': '2022-04-19',
        'bounding_box': [119.147681, 47.030565, 119.047681, 46.830565],
        'collection': 'HLSL30',
        'bands': [1, 2, 3, 4, 5, 6]
    },
    'new_mexico_black_fire': {
        'start_date': '2020-05-16T00:00:00Z',
        'bounding_box': [-107.951695, 33.326903, -107.651695, 33.126903],
        'collection': 'HLSS30',
        'bands': [1, 2, 3, 4, 5, 6]
    }
}

# URL to acquire tiles from
TILE_URL = "https://kv9drwgv6l.execute-api.us-west-2.amazonaws.com/mosaic/tiles/{search_id}/UTM31WGS84Quad/{z}/{x}/{y}@2x.tif?{assets}"
ZOOM_LEVEL = 12 # for resolution close to 30m

# Method to download tif files from tile server.
def prepare_tiles(project, event, search_id, selected_date):
    """
    Download the list of tiles specific to the bounding box, date, and project provided.
    ARGS:
        project (str): Name of the project folder the tiles are saved under
        event (str): Event name (one of pakistan_flood, mongolian_fire, new_mexico_black_fire) 
        search_id (str): searchid retrieved from tiler endpoint
        selected_date (str): Selected date
    """
    assets_key = "assets={band}"
    assets = []
    for band in EVENT_DETAILS[event]['bands']:
        assets.append(assets_key.format(band="B%02d" % band))
    tms = morecantile.tms.get("UTM31WGS84Quad")
    tiles = tms.tiles(*EVENT_DETAILS[event]['bounding_box'], ZOOM_LEVEL)
    for tile in tqdm(tiles):
        tile_url = TILE_URL.format(
                search_id=search_id,
                x=tile.x,
                y=tile.y,
                z=tile.z,
                assets="&".join(assets)
            )
        tile_response = requests.get(
            tile_url
        )
        filename = "/opt/app-root/src/data/{project}/inference/{event}-{selected_date}-x{x}_y{y}_z{z}.tif".format(
            project=project,
            event=event,
            selected_date=selected_date, 
            x=tile.x,
            y=tile.y,
            z=tile.z
        )
        if tile_response.status_code == 200:
            print(tile_url)
            # Use a context manager so the file handle is closed after writing
            with open(filename, 'wb') as tile_file:
                tile_file.write(tile_response.content)
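
morecantile's `tms.tiles` takes its bounds in `west, south, east, north` order, while the boxes in `EVENT_DETAILS` list the larger latitude before the smaller one. If the tile search returns fewer tiles than expected, reordering the corners is worth trying. The helper below is a hypothetical sketch that treats the four numbers as two opposite corners; it does not handle boxes that cross the antimeridian.

```python
def normalise_bbox(bbox):
    """Reorder two corner points into (west, south, east, north).

    `bbox` is [x1, y1, x2, y2] where (x1, y1) and (x2, y2) are any two
    opposite corners in longitude/latitude.
    """
    x1, y1, x2, y2 = bbox
    return (min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2))
```

It could be used as `tms.tiles(*normalise_bbox(EVENT_DETAILS[event]['bounding_box']), ZOOM_LEVEL)`.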
                                     


In [None]:
event = 'pakistan_flood'

selected_date = EVENT_DETAILS[event]['start_date']
selected_datetime_range = f"{selected_date}T00:00:00Z/{selected_date}T23:59:59Z"
selected_datetime_range

In [None]:
collection = EVENT_DETAILS[event]['collection']
collection

In [None]:
register_parameters = {
    "datetime": selected_datetime_range,
    "filter-lang": "cql-json",
    "collections": [collection]
}
register_parameters

In [None]:
# Get search ID for the selected parameters
response = requests.post(
    'https://d1nzvsko7rbono.cloudfront.net/mosaic/register', 
    data=json.dumps(register_parameters)
).json()
print(response)
search_id = response['searchid']

# Save tiles under the project's inference folder, where the later cells look for them
prepare_tiles(project_name, event, search_id, selected_date)


In [None]:
infer_mcad_id = geoft.submit_inference(project_name,
                namespace,
                experiment_name,
                checkpoint='latest.pth',
                image='quay.io/bedwards-ibm/mmsegmentation-geo:latest',
                num_gpus=1,
                memory_mb=8000)

## Visualizing the predictions

In [None]:
import folium
import folium.plugins as plugins
import numpy as np
import rasterio
from rasterio.plot import show

def colorize(array, cmax, cmin=0, cmap="rainbow"):
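    """Converts a 2D numpy array of values into an RGBA array given a colour map and range.
    Args:
        array (ndarray): 2D array of values to colour
        cmax (float): Max value for colour range (currently unused: the array maximum is used instead)
        cmin (float): Min value for colour range
        cmap (string): Colour map to use (from matplotlib colourmaps)
    Returns:
            rgba_array (ndarray): 3D RGBA array which can be plotted.
    """
    # Note: values are scaled by the array maximum, not by cmax
    normed_data = (array - cmin) / (array.max() - cmin)
    # plt.get_cmap is used because plt.cm.get_cmap is deprecated in newer matplotlib releases
    cm = plt.get_cmap(cmap)
    return cm(normed_data)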



In [None]:
inference_files = sorted(glob.glob('/opt/app-root/src/data/' + project_name + '/inference/*.tif'))
inference_files

In [None]:
filenum = 0
original_file = inference_files[filenum]
predict_file = inference_files[filenum].replace('/inference','/inference/pred/' + experiment_name).replace('.tif','_pred.tif')

# Load the original image layer
with rasterio.open(original_file) as src:
    redArray = src.read(1)
    greenArray = src.read(2)
    blueArray = src.read(3)
    bounds = src.bounds
    nd = src.nodata
    midLat = (bounds[3] + bounds[1]) / 2
    midLon = (bounds[2] + bounds[0]) / 2
    im_rgb = np.moveaxis(np.array([redArray,greenArray,blueArray]), 0, -1)/2048
    # im_rgb = im_rgb/np.max(im_rgb)


# Create the map
m = folium.Map(location=[midLat, midLon], tiles='openstreetmap', max_zoom=22)

# Add the prediction layer to the map
with rasterio.open(predict_file) as src:
    dataArray = src.read(1)
    bounds = src.bounds
    nd = src.nodata

# cmax = np.max(dataArray)
cmax = 1000
dataArrayMasked = np.ma.masked_where(dataArray == nd, dataArray)
dataArrayMasked = np.ma.masked_where(dataArray == 0, dataArrayMasked)
imc = colorize(dataArrayMasked, cmax, cmin=0, cmap="viridis")

# Add the layers to the map
pred = folium.raster_layers.ImageOverlay(imc, [[bounds[1], bounds[0]], [bounds[3], bounds[2]]], name="Prediction", opacity=0.8)
orig = folium.raster_layers.ImageOverlay(im_rgb, [[bounds[1], bounds[0]], [bounds[3], bounds[2]]], name="Original image", opacity=1.0)

orig.add_to(m)
pred.add_to(m)

folium.LayerControl().add_to(m)
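# fit_bounds expects [[south, west], [north, east]], not rasterio's (left, bottom, right, top)
m.fit_bounds([[bounds[1], bounds[0]], [bounds[3], bounds[2]]])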

m
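
The cell above converts raster bounds to the corner format folium uses in several places. rasterio reports bounds as `(left, bottom, right, top)`, whereas folium overlays and `fit_bounds` take `[[south, west], [north, east]]`. A small hypothetical helper makes that conversion explicit. The namedtuple here is only a stand-in with rasterio's field order, and the sketch assumes the raster is in longitude/latitude coordinates; a raster in a projected CRS would need reprojecting first.

```python
from collections import namedtuple

# Stand-in with the same field order as rasterio's BoundingBox.
BoundingBox = namedtuple("BoundingBox", ["left", "bottom", "right", "top"])


def to_folium_bounds(bounds):
    """Convert (left, bottom, right, top) into [[south, west], [north, east]]."""
    left, bottom, right, top = bounds
    return [[bottom, left], [top, right]]
```

With this helper, each `[[bounds[1], bounds[0]], [bounds[3], bounds[2]]]` expression could be written as `to_folium_bounds(bounds)`.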