# Inference

## Load all required packages

In [63]:
import fastai
# Print the installed fastai version for reproducibility (the cell output below shows the version used).
print(fastai.__version__)

2.7.9


In [64]:
from fastai.vision.all import *

In [65]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import rasterio
import rasterio.plot
import rioxarray as riox
from rioxarray.merge import merge_arrays
import json
from py_linq import Enumerable
import wget
import shutil
import multiprocessing as mp
from functools import partial

In [66]:
import os
import sys
# Make the project packages one level up (Helpers, ModelClasses) importable from this notebook.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [67]:
import importlib
import Helpers.MODIS8DaysHelper as mh
import Helpers.GEEHelpers as GEEHelpers
import Helpers.StaticFeaturesHelper as StaticFeaturesHelper

In [68]:
# Re-execute the helper modules so source edits are picked up without restarting the kernel.
for _helperModule in (mh, GEEHelpers, StaticFeaturesHelper):
    importlib.reload(_helperModule)

<module 'Helpers.StaticFeaturesHelper' from '/home/jgiezendanner/UA/cvpr23-earthvision-CNN-LSTM-Inundation/Source/Helpers/StaticFeaturesHelper.py'>

In [69]:
from ModelClasses.Model import CNNLSTM as CNNLSTM

## Define data path and load data references

In [70]:
# Root folder of the inference inputs (MODIS tiles, static features, item table).
dataPath = Path('../../Data', 'InferenceData', 'Data')

In [71]:
# Item table: the `File` column names each item, `FeatureFiles` lists its MODIS input files.
lstmDF = pd.read_json(dataPath.joinpath('lstmFilesInference.json'))

## Define number of time steps for LSTM

In [72]:
# Number of MODIS observations per sample: open_features keeps the first `timeSteps`
# feature files of each item and the CNN-LSTM is built with nbTimeSteps=timeSteps.
timeSteps = 10

## Define functions to resolve file paths and open files

In [73]:
def getStaticFeaturesFromLabel(filePath):
    """Load the static (time-invariant) feature stack for the tile of `filePath`.

    The tile id is the first two '_'-separated fields of the file stem. The
    static rasters are read from the `Elevation`, `Slope` and `HAND` folders
    two levels above `filePath`, each named `<tile id>.tif`.

    Returns:
        Array with the three features concatenated along axis 0, in the order
        elevation, slope, HAND (one channel each, assuming every helper returns
        a 2-D array — TODO confirm).
    """
    tileFileName = '_'.join(filePath.stem.split('_')[0:2]) + '.tif'
    staticRoot = filePath.parent.parent

    elevation = np.expand_dims(StaticFeaturesHelper.getScaledElevation(staticRoot/'Elevation'/tileFileName), 0)
    # Each raster goes through the helper named after it (Slope -> getSlope, HAND -> getScaledHAND).
    # NOTE(review): this must match the preprocessing used when the released weights were trained;
    # confirm the training code pairs the helpers the same way before relying on the predictions.
    slope = np.expand_dims(StaticFeaturesHelper.getSlope(staticRoot/'Slope'/tileFileName), 0)
    hand = np.expand_dims(StaticFeaturesHelper.getScaledHAND(staticRoot/'HAND'/tileFileName), 0)
    return np.concatenate((elevation, slope, hand))

In [74]:
def getModisFileFromLabel(filePath):
    """Return the MODIS feature-file paths listed in `lstmDF` for this item.

    The row whose `File` equals `filePath.name` is looked up and each entry of
    its `FeatureFiles` list is resolved relative to `filePath`'s folder.
    Raises IndexError if no row matches.
    """
    featureNames = lstmDF.loc[lstmDF.File == filePath.name, 'FeatureFiles'].values[0]
    folder = filePath.parent
    return [folder/name for name in featureNames]

In [75]:
def readImage(file, bandsToUse):
    """Read the requested bands of one MODIS file via MODIS8DaysHelper (scaled)."""
    scaledBands = mh.getScaledModisFileBands(file, bandsToUse)
    return scaledBands

In [76]:
def open_features(fn, chnls=None):
    """Build the stacked CNN-LSTM input tensor for one item.

    For each of the first `timeSteps` MODIS files associated with `fn`, the
    seven surface-reflectance bands below are read and the static features of
    the tile are appended. The per-file stacks are concatenated along axis 0
    with the LAST file first (the same order the previous prepend loop produced).

    Args:
        fn: path of the item file (see getModisFileFromLabel).
        chnls: unused; kept so MultiChannelTensorImage.create can pass it through.

    Returns:
        float32 torch tensor with all per-file stacks along axis 0. If no feature
        file is found, an empty (0, 32, 32) tensor is returned.
    """
    bandsToUse = ['sur_refl_b03', 'sur_refl_b02', 'sur_refl_b01', 'sur_refl_b04', 'sur_refl_b05', 'sur_refl_b06', 'sur_refl_b07']
    files = getModisFileFromLabel(fn)[0:timeSteps]

    staticFeatures = getStaticFeaturesFromLabel(fn)

    perFileStacks = []
    for file in files:
        try:
            newimg = readImage(file, bandsToUse)
        except Exception:
            # Retry once (presumably for transient read failures in dataloader workers — TODO confirm);
            # a second failure propagates. KeyboardInterrupt/SystemExit are no longer swallowed.
            newimg = readImage(file, bandsToUse)
        perFileStacks.append(np.concatenate((newimg, staticFeatures)))

    if perFileStacks:
        # One concatenation instead of re-copying the growing array per file; reversed to keep
        # the newest-prepended ordering the model was used with.
        img = np.concatenate(perFileStacks[::-1])
    else:
        img = np.empty((0, 32, 32))

    return torch.from_numpy(img.astype(np.float32))

def open_mask(fn, chnls=None, cls=torch.Tensor):
    """Read band 1 of raster `fn` as a (1, H, W) float32 tensor wrapped in `cls`.

    Args:
        fn: raster path.
        chnls: unused; kept for a signature parallel to open_features.
        cls: tensor type used to wrap the result (e.g. fastai TensorImage).
    """
    # Context manager closes the dataset; the previous open(...).read(1) leaked the handle.
    with rasterio.open(fn) as src:
        band = src.read(1)
    img = np.expand_dims(band, 0).astype(np.float32)
    return cls(torch.from_numpy(img))

In [77]:
class MultiChannelTensorImage(TensorImage):
    """TensorImage holding the stacked multi-temporal feature tensor of one item.

    Instances are built by `create`, which delegates to `open_features`; the
    other methods only affect display and repr.
    """
    _show_args = ArrayImageBase._show_args

    def show(self, channels=(1,), ctx=None, vmin=None, vmax=None, **kwargs):
        """Display selected channels of the image.

        Args:
            channels: one channel index (rendered as grayscale) or three indices
                (rendered as an RGB composite). Any other count displays nothing
                and returns None.
            ctx: optional matplotlib Axes to draw on.
            vmin, vmax: optional normalisation bounds; default to the 1st/99th
                percentiles of the selected channels.
            **kwargs: forwarded to `show_composite` and from there to `imshow`.
        """
        if len(channels) == 1:
            # Replicate a single channel into R, G and B so it renders as grayscale.
            channels = (channels[0],) * 3
        if len(channels) == 3:
            return self.show_composite(channels=channels, ctx=ctx, vmin=vmin, vmax=vmax,
                                       **{**self._show_args, **kwargs})
        return None

    @staticmethod
    def _norm_percentile(vals, vmin=None, vmax=None):
        """Linearly rescale `vals` so that `vmin` maps to 0 and `vmax` to 1.

        Bounds default to the 1st/99th percentiles of `vals`. A zero range
        (constant image) returns zeros instead of dividing by zero.
        """
        # Deliberately not called `norm`: a class attribute of that name would shadow torch.Tensor.norm.
        vmin = ifnone(vmin, np.quantile(vals, 0.01))
        vmax = ifnone(vmax, np.quantile(vals, 0.99))
        span = vmax - vmin
        if span == 0:
            return np.zeros_like(vals)
        return (vals - vmin)/span

    def show_composite(img, channels, ax=None, figsize=(3,3), title=None, scale=True,
                       ctx=None, vmin=None, vmax=None, **kwargs)->plt.Axes:
        """Render three channels of `img` (the instance itself) as an RGB image.

        Args:
            channels: (r, g, b) channel indices.
            ax, ctx: Axes to draw on (`ax` wins); a new figure of `figsize` is created if both are None.
            title: optional Axes title.
            scale: percentile-normalise the composite before display.
            vmin, vmax: optional normalisation bounds.
            **kwargs: forwarded to `Axes.imshow`.

        Returns:
            The Axes drawn on.
        """
        ax = ifnone(ax, ctx)
        if ax is None: _, ax = plt.subplots(figsize=figsize)
        r, g, b = channels
        tempim = img.data.cpu().numpy()
        # Channel-last (H, W, 3) composite, as imshow expects.
        im = np.stack((tempim[r], tempim[g], tempim[b]), axis=-1).astype(np.float64)
        if scale: im = MultiChannelTensorImage._norm_percentile(im, vmin, vmax)
        ax.imshow(im, **kwargs)
        ax.axis('off')
        if title is not None: ax.set_title(title)
        return ax

    @classmethod
    def create(cls, fn, chans=None,  **kwargs) ->None:
        """Open the item at `fn` with open_features and wrap it in `cls`."""
        # NOTE(review): the `-> None` annotation is kept on purpose — fastcore's Transform reads
        # return annotations for result casting, so changing it could alter behaviour; confirm before editing.
        return cls(open_features(fn=fn, chnls=chans))

    def __repr__(self): return f'{self.__class__.__name__} size={"x".join([str(d) for d in self.shape])}'

# Wrap `create` in a fastai Transform so it can be used as a type transform in a TransformBlock.
MultiChannelTensorImage.create = Transform(MultiChannelTensorImage.create)

def MultiChannelImageBlock(cls=MultiChannelTensorImage, chans=None):
    """Return a TransformBlock whose type transform opens items with `cls.create`."""
    # functools.partial rather than a lambda keeps the transform picklable,
    # which multi-worker DataLoaders may require.
    creator = partial(cls.create, chans=chans)
    return TransformBlock(type_tfms=creator)

## Create image blocks

In [78]:
# NOTE(review): these assignments rebind the fastai ImageBlock/MaskBlock names brought in by the star import.
# Targets: single-band float tensors read by open_mask and wrapped as TensorImage.
MaskBlock = TransformBlock(type_tfms=[partial(open_mask, cls=TensorImage)])
# Inputs: the stacked multi-channel features built by MultiChannelTensorImage.create.
ImageBlock = MultiChannelImageBlock()

## Get files for inference

In [79]:
def getFilesForStudy(path, items=lstmDF.File.values):
    return [path/('_'.join(item.split('_')[0:2]))/'MOD09A1.061'/item for item in items]

In [80]:
# All inference item paths, resolved from the item table.
items = getFilesForStudy(path=dataPath)

## Create data blocks

In [81]:
# Items come from getFilesForStudy(source); inputs and targets use the blocks defined above.
db = DataBlock(
    get_items=getFilesForStudy,
    blocks=(ImageBlock, MaskBlock),
)

## Run Inference

In [82]:
# CNN-LSTM over `timeSteps` stacked observations; per-year weights are loaded into it later.
model = CNNLSTM(nbTimeSteps=timeSteps)

In [83]:
# Batch size and worker count as used for the published inference run.
dl = db.dataloaders(dataPath, bs=180, num_workers=10)

In [84]:
# Regression setup: the metrics are only reported; inference uses get_preds further below.
acc_metric = [mse, rmse, R2Score()]
loss_fn = MSELossFlat()
learn = Learner(
    dl,
    model,
    opt_func=ranger,
    metrics=acc_metric,
    loss_func=loss_fn,
)

### Get model weights

In [92]:
# Download target for the released weights archive.
modelFolder = Path('models')
modelFolder.mkdir(parents=True, exist_ok=True)

In [55]:
# Fetch the released model weights and unpack them into modelFolder.
downloadLink = "https://github.com/GieziJo/cvpr23-earthvision-CNN-LSTM-Inundation/releases/download/v1.0.0/ModelWeights.zip"
zipFilePath = modelFolder/'ModelWeights.zip'
# Only download when the archive is missing. python-wget does not overwrite an existing `out`
# file; it saves under a new suffixed name instead (e.g. 'ModelWeights (1).zip'). Without this
# guard, every re-run would download again while unpack_archive keeps reading the old zip.
if not zipFilePath.exists():
    wget.download(downloadLink, out=zipFilePath.as_posix())
shutil.unpack_archive(zipFilePath, modelFolder)

In [103]:
# Relative on purpose: fastai's Learner.load presumably resolves names under learn.path/learn.model_dir
# ('models' by default), i.e. ./models/ModelWeights/<year>.pth, where the archive above is presumably
# unpacked — TODO confirm against the installed fastai version and the archive layout.
modelFolder = Path('ModelWeights')

### Prepare output folders

In [104]:
# Output layout: pickled predictions, per-tile GeoTIFFs and full mosaics.
outputFolder = Path('InferredData')
outputFolder_pickleData = outputFolder/'PickleData'
outputFolder_individualTifs = outputFolder/'IndividualTifs'
outputFolder_fullTifs = outputFolder/'FullTifs'

for _folder in (outputFolder, outputFolder_pickleData, outputFolder_individualTifs, outputFolder_fullTifs):
    _folder.mkdir(parents=True, exist_ok=True)

### Run inference and save results as rasters

In [108]:
# define raster creation function
def processResultAsRaster(k, items, targetPath, preds):
    """Write prediction `k` as a single-band GeoTIFF georeferenced like `items[k]`.

    The output goes to `targetPath/<tile id>/<items[k].name>`, where the tile id
    is the first two '_'-separated fields of the item stem. An existing output
    is left untouched, so an interrupted run can be resumed.

    Args:
        k: index into `items` and `preds`.
        items: source raster paths; their rasterio profile is reused for the output.
        targetPath: root output folder for one model/year.
        preds: indexable predictions where `preds[k]` is a 2-D array or CPU tensor.
    """
    item = items[k]
    fileTargetPath = targetPath/('_'.join(item.stem.split('_')[0:2]))
    # exist_ok keeps this safe when several pool workers create the same tile folder.
    fileTargetPath.mkdir(exist_ok=True, parents=True)

    targetFile = fileTargetPath/item.name
    if targetFile.exists():
        return

    # np.asarray accepts both numpy arrays and CPU torch tensors.
    band = np.asarray(preds[k], dtype=np.float32)

    with rasterio.open(item) as r:
        profile = r.profile.copy()
    # Store predictions as float32 whatever the template's dtype, so regression outputs
    # are not truncated when the template raster is an integer type.
    profile.update(count=1, dtype='float32')
    with rasterio.open(targetFile, 'w', **profile) as dst:
        dst.write(band, 1)

In [114]:
import pickle

test_dl = learn.dls.test_dl(items)

# run inference with each leave-out-year model
for year in range(2017, 2022):
    featurePicklePath = outputFolder_pickleData/(str(year) + '_Features')
    predictionPicklePath = outputFolder_pickleData/(str(year) + '_Infered')

    outputFolder_individualTifs_year = outputFolder_individualTifs/str(year)

    # only run inference if the cached pickle files don't exist
    if not (featurePicklePath.exists() and predictionPicklePath.exists()):

        # load the model for this leave-out year
        modelNamePath = modelFolder/str(year)
        learn.load(modelNamePath)

        # infer all items
        preds, _ = learn.get_preds(dl=test_dl)
        # NOTE(review): squeeze() also drops the batch axis when there is a single item;
        # squeeze(1) would be safer if the model output is (N, 1, H, W) — confirm.
        # Converted to numpy so both branches hand the same type to the workers; numpy arrays
        # pickle plainly, while tensors sent through multiprocessing may use shared-memory handles.
        preds = np.asarray(preds.squeeze())

        # save pickle files in case something goes wrong later
        with open(featurePicklePath, "wb") as f:
            pickle.dump(items, f)
        with open(predictionPicklePath, "wb") as f:
            pickle.dump(preds, f)

    else:
        # pickle.load can execute code from crafted files: only reuse caches this notebook wrote.
        # np.array also accepts older caches that stored a torch tensor.
        with open(predictionPicklePath, "rb") as f:
            preds = np.array(pickle.load(f))

    # write the per-tile GeoTIFFs in parallel; the context manager shuts the pool down
    # instead of leaking 10 worker processes per year
    func_part = partial(processResultAsRaster, items=items, targetPath=outputFolder_individualTifs_year, preds=preds)
    with mp.Pool(10) as pool:
        pool.map(func_part, range(len(items)))

## Results post-processing

### Create the full Bangladesh raster for each cross-validated year and each time step

In [110]:
def createRasterForTime(time, files, targetPath):
    """Mosaic all per-tile rasters of one time step into `targetPath/<time>.tif`.

    Args:
        time: integer time index, compared with the third '_'-separated field of each file stem.
        files: candidate per-tile GeoTIFF paths.
        targetPath: output folder.

    Returns:
        None. Nothing is written if the output already exists or no file matches `time`.
    """
    fileName = targetPath/(str(time) + '.tif')
    if fileName.exists():
        return

    filesForTime = [item for item in files if int(item.stem.split('_')[2]) == time]
    if not filesForTime:
        # Nothing to mosaic for this time step: skip instead of failing on rasters[0].
        return

    rasters = []
    try:
        for item in filesForTime:
            rasters.append(riox.open_rasterio(item))

        # Output grid uses the first tile's pixel size and CRS (assumes a north-up
        # affine transform: gt[0] is the pixel width, gt[4] the negative pixel height).
        gt = rasters[0].rio.transform()
        res = (gt[0], -gt[4])
        crs = str(rasters[0].rio.crs)

        merged_raster = merge_arrays(dataarrays=rasters, res=res, crs=crs, nodata=-9999)
        merged_raster.rio.to_raster(fileName)
    finally:
        # Close every opened tile even if merging or writing fails.
        for raster in rasters:
            raster.close()

In [112]:
# Sorted, de-duplicated time indices (third '_' field of each item stem); np.unique already sorts.
times = np.unique([int(item.stem.split('_')[2]) for item in items])

for year in range(2017, 2022):
    outputFolder_fullTifs_year = outputFolder_fullTifs/str(year)
    outputFolder_fullTifs_year.mkdir(exist_ok=True, parents=True)

    outputFolder_individualTifs_year = outputFolder_individualTifs/str(year)
    files = [p for p in outputFolder_individualTifs_year.rglob("*") if p.suffix == '.tif']

    # One mosaic per time step, in parallel; the context manager shuts the pool down afterwards.
    with mp.Pool(10) as pool:
        pool.map(partial(createRasterForTime, files=files, targetPath=outputFolder_fullTifs_year), times)

### Create the full Bangladesh ensemble raster for each time step from the cross-validated models

In [113]:
outputFolder_fullTifs_ensemble = outputFolder_fullTifs/'Ensemble'
outputFolder_fullTifs_ensemble.mkdir(exist_ok=True, parents=True)

# Leave-out years whose per-year mosaics are combined; must match the years processed above.
ensembleYears = list(range(2017, 2022))

for time in times:
    fileName = outputFolder_fullTifs_ensemble/(str(time) + '.tif')

    if fileName.exists():
        continue

    vals = []

    for year in ensembleYears:
        file = outputFolder_fullTifs/str(year)/(str(time) + '.tif')
        with riox.open_rasterio(file) as r:
            vals.append(r.values)

            # The last year's raster is still open here and serves as the georeferenced
            # template for the ensemble output (previously hard-coded as year 2021).
            if year == ensembleYears[-1]:
                # Pixel-wise median across the cross-validated models.
                out = np.median(np.array(vals), axis=0)
                # Negative medians (presumably from the -9999 nodata used by the mosaics) become nodata.
                out[out < 0] = -9999

                r.values = out
                r.rio.to_raster(fileName)