In [1]:
# ---imports---
import sys
import os
import math
import random
sys.path.append('.')
sys.path.append('..')

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import lightning.pytorch as pl
import rasterio

from models.DeepLabV3_Lightning_ESRI_UrbanRural import DeepLabV3_Lightning_ESRI_UrbanRural

from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap
from matplotlib import patches as mpatches

### Class codecs

In [2]:
# --- Class codecs: raw dataset code -> class index ---
#
# Each codec is a 1-D int64 lookup table indexed by the raw class code, so an
# array of raw codes can be translated with `table[raw_codes]`.
# Codes that are not listed below keep the table's initial value 0.

# ESRI LULC 2020: raw code -> index
#    1 ->  0   Water
#    2 ->  1   Tree
#    4 ->  2   Flooded Vegetation
#    5 ->  3   Crops
#    7 ->  6   Built Area
#    8 ->  4   Bare Ground
#    9 -> -1   Snow / Ice
#   10 -> -1   Clouds
#   11 ->  5   Rangeland
#   12 -> -1   Missing
esri_classes = [1, 2, 4, 5, 7, 8, 9, 10, 11, 12]
labels = [0, 1, 2, 3, 6, 4, -1, -1, 5, -1]

esri_class_to_index_map = np.zeros(max(esri_classes) + 1, dtype=np.int64)
for raw_code, class_index in zip(esri_classes, labels):
    esri_class_to_index_map[raw_code] = class_index


# GHS-SMOD 2020: raw code -> index (everything collapses to 6 or 7)
#   30 -> 7   URBAN CENTRE GRID CELL
#   23 -> 7   DENSE URBAN CLUSTER GRID CELL
#   22 -> 7   SEMI-DENSE URBAN CLUSTER GRID CELL
#   21 -> 7   SUBURBAN OR PERI-URBAN GRID CELL
#   13 -> 6   RURAL CLUSTER GRID CELL
#   12 -> 6   LOW DENSITY RURAL GRID CELL
#   11 -> 6   VERY LOW DENSITY RURAL GRID CELL
#   10 -> 6   WATER GRID CELL
#   NoData [-Inf] -> 7   (not encoded in the table below; presumably handled
#                         where the table is applied -- TODO confirm)
smod_classes = [10, 11, 12, 13, 21, 22, 23, 30]
labels = [6, 6, 6, 6, 7, 7, 7, 7]  # NOTE: intentionally rebinds `labels` from the ESRI section

smod_class_to_index_map = np.zeros(max(smod_classes) + 1, dtype=np.int64)
for raw_code, class_index in zip(smod_classes, labels):
    smod_class_to_index_map[raw_code] = class_index

In [3]:
# One 8-bit RGB triplet per class index (row i colours class index i), rescaled
# to the 0-1 range. Row order presumably follows `class_labels` below -- the
# trailing names are taken from that list.
colormap_classes_rgb = np.array([
    [0, 100, 200],      # Water
    [0, 100, 0],        # Tree
    [0, 207, 117],      # Flooded Vegetation
    [240, 150, 255],    # Crops
    [180, 180, 180],    # Bare Ground
    [255, 187, 34],     # Rangeland
    [255, 0, 0],        # Rural
    [80, 0, 0],         # Urban
]) / 255

# Human-readable legend entries, one per class index.
class_labels = [
    'c1:Water',
    'c2:Tree',
    'c3:Flooded Vegetation',
    'c4:Crops',
    'c5:Bare Ground',
    'c6:Rangeland',
    'c7:Rural',
    'c8:Urban',
]

# Discrete 8-entry colormap used to render class-index rasters.
cmap_classes = ListedColormap(
    colormap_classes_rgb,
    name='ESRI_WorldCover_2020_class_cmap',
    N=8,
)

### Preprocessing

In [4]:
def get_patches(images):
    '''Read and preprocess images from a list of image paths; return an iterator.

    Parameters
    ----------
    images : iterable of str
        Paths of raster files readable by ``rasterio.open``. Each file is
        expected to hold at least 7 bands (assumed floating point, since the
        normalised values are written back in place -- TODO confirm).

    Yields
    ------
    numpy.ndarray
        The first 7 bands of each image, in input order, after preprocessing:
        bands 0-5 rescaled with (x - 1) / (65455 - 1), band 6 clipped to
        [-1.5, 193565] and rescaled to [0, 1].
    '''

    for im_path in images:

        # Open through a context manager so the dataset handle is closed as soon
        # as the pixels are in memory (the previous version never closed it and
        # leaked one handle per tile).
        with rasterio.open(im_path) as dataset:
            im = dataset.read()

        # input preprocessing: normalize input bands to range (0.0, 1.0), mask missing data
        im[0:6] = (im[0:6] - 1) / (65455 - 1)  # bands SR_B2 to SR_B7: BGR, NIR, SWIR1, SWIR2 (1, 65455)
        nl_clipped = np.clip(im[6], a_min=-1.5, a_max=193565)
        im[6] = (nl_clipped + 1.5) / (193565 + 1.5)   # band avg_rad: VIIRS (-1.5, 193565)

        # Pixels that are +/-inf in band 0 are treated as missing in all six
        # reflectance bands and zeroed.
        inf_mask = np.isinf(im[0])
        im[0:6, inf_mask] = 0

        # Bands 1 and 2 are additionally cleaned on their own inf pixels.
        # NOTE(review): bands 3-5 are only zeroed where band 0 is inf; an inf that
        # occurs solely in those bands is passed through unchanged -- confirm intended.
        im[1, np.isinf(im[1])] = 0
        im[2, np.isinf(im[2])] = 0

        yield im[:7]


In [5]:
def get_patches_sub(im):
    '''Split a supertile into 36 overlapping (250, 250) sub-tiles.

    The sub-tiles are laid out on a 6 x 6 grid with a stride of 150 px, so
    neighbouring sub-tiles overlap by 100 px and the grid spans 1000 px per
    axis (the function is intended for (1000, 1000) supertiles).

    Parameters
    ----------
    im : array-like, indexed as (channel, row, col)
        Supertile to split; only the first 7 channels are used.

    Yields
    ------
    tuple
        ``(image, None)`` in row-major sub-tile order, where ``image`` is the
        slice ``im[:7, r:r+250, c:c+250]``. The second element is a placeholder
        for labels, which are not available at inference time.
    '''

    # loop through each subtile
    for sub_tile_idx in range(36):

        # calculate offset of the sub-tile's top-left corner
        r = (sub_tile_idx // 6) * 150
        c = (sub_tile_idx % 6) * 150

        sub_tile = im[:, r:r+250, c:c+250]

        # retrieve input
        image = sub_tile[:7]

        # Yield an explicit None instead of the bare name `_`: `_` is only
        # defined by the interactive IPython session (last cell output), so the
        # old code raised NameError anywhere else and leaked arbitrary state.
        yield (image, None)

### Load Model

In [6]:
# Restore the trained segmentation model from a Lightning checkpoint.

# base directory under which the checkpoint paths are built
mimer = "/mimer/NOBACKUP/groups/globalpoverty1/albin_and_albin/scripts_and_notebooks/job_scripts"

# model 1 (alternative checkpoint, currently disabled)
#ckpt = mimer + "/lightning_logs/deeplabv3_esri_urban_rural_12345/checkpoints/12345_epoch=19-step=126080-val_loss=0.440.ckpt"

# model 3 (active checkpoint)
ckpt = f"{mimer}/lightning_logs/deeplabv3_esri_urban_rural_34512/checkpoints/34512_epoch=19-step=106240-val_loss=0.369.ckpt"

# NOTE: the training/validation fold arguments are required by the call but,
# per the original author, their values have no effect here.
lightning_model = DeepLabV3_Lightning_ESRI_UrbanRural.load_from_checkpoint(
    ckpt,
    training_folds=['Algeria'],
    validation_fold=['Libya'],
).cuda()  # move the restored model to the GPU

### Inference (+save output)

In [7]:
# --- configuration ---
location = 'abidjan'

# ENTER IMAGE INDICES HERE
# Tiles are listed in row-major order and must form a square grid
# (the merged image is grid_dim x grid_dim tiles).
im_indices = [2755, 1229, 1991, 1440,
              872, 1252, 3078, 2118,
              224, 2336, 1560, 2303,
              135, 2921, 2072, 1565]
grid_dim = int(math.sqrt(len(im_indices)))
if grid_dim == 0 or grid_dim * grid_dim != len(im_indices):
    # Fail before any inference runs: a non-square list would otherwise only
    # surface as a shape error while merging the first year's tiles.
    raise ValueError(
        f'im_indices must describe a non-empty square grid, got {len(im_indices)} tiles')

start_year = 2013
end_year = 2022

# base directory containing the per-location input/ and output/ folders
data_root = '/mimer/NOBACKUP/groups/globalpoverty1/albin_and_albin'

# Tiling geometry -- must stay in sync with get_patches_sub, which yields
# 36 sub-tiles of 250 px on a 6 x 6 grid with a stride of 150 px.
TILE_SIZE = 1000
SUBTILE_SIZE = 250
SUBTILE_STRIDE = 150
SUBTILES_PER_SIDE = 6
SUBTILE_MARGIN = 50  # px discarded on each interior edge (half of the 100 px overlap)
LAST_SUBTILE_OFFSET = (SUBTILES_PER_SIDE - 1) * SUBTILE_STRIDE


def _kept_span(offset):
    '''Return the (start, end) slice bounds of a sub-tile kept along one axis.

    Interior edges drop SUBTILE_MARGIN pixels so that adjacent sub-tiles abut
    without overlap; edges on the supertile border are kept in full.
    '''
    start = 0 if offset == 0 else SUBTILE_MARGIN
    end = SUBTILE_SIZE if offset == LAST_SUBTILE_OFFSET else SUBTILE_SIZE - SUBTILE_MARGIN
    return start, end


def _predict_supertile(tile):
    '''Predict class indices for one supertile by stitching sub-tile predictions.

    `tile` is passed to get_patches_sub; every sub-tile is run through
    `lightning_model` on the GPU and the central part of its argmax prediction
    is written into a (TILE_SIZE, TILE_SIZE) array, which is returned.
    Assumes the model returns a (1, n_classes, 250, 250) tensor -- TODO confirm.
    '''
    predictions_supertile = np.zeros((TILE_SIZE, TILE_SIZE))

    with torch.no_grad():
        # the second tuple element is a label placeholder and is not used here;
        # it is bound to a throwaway name so the module-level `labels` survives
        for sub_tile_idx, (in_data, _unused_labels) in enumerate(get_patches_sub(tile)):

            in_data = torch.unsqueeze(torch.tensor(in_data).cuda(), 0)  # add batch dimension

            outputs = lightning_model(in_data)
            class_probs = F.softmax(outputs, dim=1)
            class_predictions = torch.argmax(input=class_probs, dim=1)
            sub_tile = class_predictions[0].squeeze().cpu().numpy()

            # offset of this sub-tile inside the supertile
            r = (sub_tile_idx // SUBTILES_PER_SIDE) * SUBTILE_STRIDE
            c = (sub_tile_idx % SUBTILES_PER_SIDE) * SUBTILE_STRIDE

            # keep only the non-overlapping part and put it in the correct position
            row_start, row_end = _kept_span(r)
            col_start, col_end = _kept_span(c)
            predictions_supertile[r + row_start:r + row_end, c + col_start:c + col_end] = \
                sub_tile[row_start:row_end, col_start:col_end]

    return predictions_supertile


# evaluation mode only needs to be set once, not once per sub-tile
lightning_model.eval()

for year in range(end_year, start_year - 1, -1):
    # make list of image paths
    im_paths = [f'{data_root}/{location}/input/tile_{im_idx}.tif_{year}' for im_idx in im_indices]

    # path to output; make sure its folder exists before the expensive inference starts
    out = f'{data_root}/{location}/output/{location}_{year}.png'
    os.makedirs(os.path.dirname(out), exist_ok=True)

    # predict each tile of the grid (get_patches yields the tiles in im_paths order)
    predictions = [_predict_supertile(tile) for tile in get_patches(im_paths)]

    # merge all 1000x1000px tiles in the grid into one bigger tile
    predictions_merged = np.zeros((grid_dim * TILE_SIZE, grid_dim * TILE_SIZE))
    for tile_idx, tile_prediction in enumerate(predictions):

        # calculate offset
        r = (tile_idx // grid_dim) * TILE_SIZE
        c = (tile_idx % grid_dim) * TILE_SIZE

        predictions_merged[r:r + TILE_SIZE, c:c + TILE_SIZE] = tile_prediction

    # save the output for current year
    # vmax equals the number of classes so that class index k is normalised to
    # k / 8 and therefore selects entry k of the 8-colour colormap.
    plt.imsave(out, arr=predictions_merged, cmap=cmap_classes, vmin=0, vmax=len(colormap_classes_rgb))

