# Training and Cross Validation

## Load all required packages

In [1]:
import os

In [2]:
import fastai
print(fastai.__version__)

2.7.9


In [3]:
from fastai.vision.all import *

In [4]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import rasterio
import rasterio.plot
import json
from py_linq import Enumerable
import pandas as pd

In [5]:
from rasterio import logging

In [6]:
log = logging.getLogger()
log.setLevel(logging.ERROR)

In [7]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [8]:
import importlib
import Helpers.MODIS8DaysHelper as mh
import Helpers.GEEHelpers as GEEHelpers
import Helpers.StaticFeaturesHelper as StaticFeaturesHelper

In [9]:
importlib.reload(mh)
importlib.reload(GEEHelpers)
importlib.reload(StaticFeaturesHelper)

<module 'Helpers.StaticFeaturesHelper' from '/home/jgiezendanner/UA/cvpr23-earthvision-CNN-LSTM-Inundation/Source/Helpers/StaticFeaturesHelper.py'>

## Define Data Path and load data references

In [10]:
dataPath = Path('../../Data/')

In [11]:
lstmDF = pd.read_json(dataPath/'lstmFiles.json') # dataframe for LSTM with corresponding data

## Define functions to access file path and open files

In [12]:
def getStaticFeaturesFromLabel(filePath):
    """Load the static terrain layers that belong to a label file.

    The raster name is built from the first two '_'-separated tokens of the
    label file stem plus '.tif', and is looked up in the 'Elevation', 'Slope'
    and 'HAND' folders located two levels above the label file.

    Args:
        filePath: pathlib.Path of the label (ground truth) file.

    Returns:
        Array made of the elevation, slope and HAND layers, stacked in that
        order along axis 0 (each layer gets a leading axis before stacking).
    """
    # '<a>_<b>_<rest>' label stem -> '<a>_<b>.tif' static raster name
    rasterName = '_'.join(filePath.stem.split('_')[0:2]) + '.tif'
    studyDir = filePath.parent.parent

    elevation = np.expand_dims(StaticFeaturesHelper.getScaledElevation(studyDir/'Elevation'/rasterName), 0)
    # Each raster is read with the helper named after it: the slope raster
    # with getSlope and the HAND raster with getScaledHAND.
    slopeFile = np.expand_dims(StaticFeaturesHelper.getSlope(studyDir/'Slope'/rasterName), 0)
    hand = np.expand_dims(StaticFeaturesHelper.getScaledHAND(studyDir/'HAND'/rasterName), 0)
    return np.concatenate((elevation, slopeFile, hand))
    

In [13]:
def getModisFileFromLabel(filePath):
    """Return the MODIS feature file paths registered for a label file.

    The label's file name is looked up in the `File` column of `lstmDF`; the
    `FeatureFiles` entry of the first matching row supplies the file names,
    which are resolved inside the "MOD09A1.061" folder located two levels
    above the label file.
    """
    modisDir = filePath.parent.parent/"MOD09A1.061"
    matchingRows = lstmDF[lstmDF.File == filePath.name]
    featureNames = matchingRows.FeatureFiles.values[0]
    return [modisDir/name for name in featureNames]

In [14]:
def readImage(file, bandsToUse):
    """Read the requested bands of one MODIS file through the MODIS helper module."""
    scaledBands = mh.getScaledModisFileBands(file, bandsToUse)
    return scaledBands

In [15]:
timeSteps = 10
# Open MODIS files and indices
def open_features(fn, chnls=None):
    """Build the feature tensor for one label file.

    For each of the first `timeSteps` MODIS files registered for the label,
    the selected reflectance bands are read and the static terrain layers are
    appended to them. The per-file stacks are then joined along axis 0.

    Args:
        fn: pathlib.Path of the label file.
        chnls: unused; kept for interface compatibility.

    Returns:
        torch.Tensor of dtype float32 with all channels along the first axis.
    """
    # Stack MODIS time steps
    bandsToUse = ['sur_refl_b03', 'sur_refl_b02', 'sur_refl_b01', 'sur_refl_b04', 'sur_refl_b05', 'sur_refl_b06', 'sur_refl_b07']
    files = getModisFileFromLabel(fn)[0:timeSteps]
    
    staticFeatures = getStaticFeaturesFromLabel(fn)
    
    # Empty seed array; this fixes the expected tile size to 32x32
    img = np.empty((0,32,32))
    for file in files:
        try:
            newimg = readImage(file, bandsToUse)
        except Exception:
            # Best effort: retry the read once. A second failure propagates.
            # Only `Exception` is caught so that KeyboardInterrupt/SystemExit
            # are not swallowed by the retry.
            newimg = readImage(file, bandsToUse)
        
        newimg = np.concatenate((newimg, staticFeatures))
        # Each new time step is placed IN FRONT of the ones already stacked,
        # so the channel order is the reverse of the order of `files`.
        # NOTE(review): confirm this matches the time-step order the model
        # expects (it treats the first channel group as the current step).
        img = np.concatenate((newimg, img))
    
    img = img.astype(np.float32)
    img = torch.from_numpy(img)
    return img

# open ground truth
def open_mask(fn, chnls=None, cls=torch.Tensor):
    """Read the first band of a raster file as a single-channel float32 tensor.

    Args:
        fn: path of the raster file to open with rasterio.
        chnls: unused; kept for interface compatibility.
        cls: tensor type used to wrap the result.

    Returns:
        Instance of `cls` holding band 1 with a leading axis of size 1.
    """
    # Use the dataset as a context manager so the file handle is released
    # as soon as the band has been read.
    with rasterio.open(fn) as src:
        band = src.read(1)
    img = np.expand_dims(band, 0)
    img = img.astype(np.float32)
    npimg = torch.from_numpy(img)
    clsImg = cls(npimg)
    return clsImg

## Define functions to work with multi-band data

In [16]:
class MultiChannelTensorImage(TensorImage):
    """TensorImage variant for multi-band inputs loaded with `open_features`.

    `norm` and `show_composite` are plain functions stored on the class; they
    must be reached through the class (or an instance), because names defined
    in a class body are not visible as bare names inside its methods.
    """
    _show_args = ArrayImageBase._show_args
    def show(self, channels=[1], ctx=None, vmin=None, vmax=None, **kwargs):
        """Display three selected channels as an RGB composite.

        Args:
            channels: indices of the channels to display; a composite is only
                drawn when exactly three indices are given.
            ctx: optional matplotlib axes to draw on.
            vmin, vmax: optional bounds used to scale the values for display.
            **kwargs: forwarded to `show_composite` on top of `_show_args`.

        Returns:
            The axes drawn on, or None when `channels` does not hold exactly
            three entries (nothing is drawn in that case).
        """
        if len(channels) == 3: 
            # Resolved through the instance: a bare `show_composite` would
            # raise NameError since the function lives in the class namespace.
            return self.show_composite(channels=channels, ctx=ctx, vmin=vmin, vmax=vmax,
                                       **{**self._show_args, **kwargs})
    
            
    def norm(vals, vmin=None, vmax=None):
        """Rescale `vals` linearly so that vmin maps to 0 and vmax to 1.

        Missing bounds default to the 1% and 99% quantiles of `vals`.
        """
        vmin = ifnone(vmin, np.quantile(vals, 0.01))
        vmax = ifnone(vmax, np.quantile(vals, 0.99))
        return (vals - vmin)/(vmax-vmin)

    def show_composite(img, channels, ax=None, figsize=(3,3), title=None, scale=True,
                       ctx=None, vmin=None, vmax=None, **kwargs)->plt.Axes:
        """Draw three channels of `img` as an RGB image and return the axes.

        Args:
            img: tensor image; channels are indexed along its first axis.
            channels: three channel indices used as red, green and blue.
            ax, ctx: axes to draw on (`ax` wins); a new figure is created
                when both are None.
            figsize: accepted for compatibility; not used here.
            title: optional axes title.
            scale: when True, values are rescaled with `norm` before display.
            vmin, vmax: optional bounds passed to `norm`.
            **kwargs: forwarded to `ax.imshow`.
        """
        ax = ifnone(ax, ctx)
        if ax is None: _, ax = plt.subplots()    
        r, g, b = channels
        tempim = img.data.cpu().numpy()
        # channel-first tensor -> (height, width, 3) array for imshow
        im = np.zeros((tempim.shape[1], tempim.shape[2], 3))
        im[...,0] = tempim[r]
        im[...,1] = tempim[g]
        im[...,2] = tempim[b]
        # `norm` lives in the class namespace, so it is looked up on the class
        if scale: im = MultiChannelTensorImage.norm(im, vmin, vmax)
        ax.imshow(im, **kwargs)
        ax.axis('off')
        if title is not None: ax.set_title(title)
        return ax

    # NOTE(review): the `->None` return annotation looks deliberate; this
    # method is wrapped in a fastai `Transform` below, which presumably uses
    # return annotations to decide on type casting. Confirm before changing.
    @classmethod
    def create(cls, fn, chans=None,  **kwargs) ->None:
        """Load the features for label file `fn` and wrap them in `cls`."""
        return cls(open_features(fn=fn, chnls=chans))
        
    def __repr__(self): return f'{self.__class__.__name__} size={"x".join([str(d) for d in self.shape])}'
    
# Wrap the constructor in a Transform so it can be used in a TransformBlock
MultiChannelTensorImage.create = Transform(MultiChannelTensorImage.create)

def MultiChannelImageBlock(cls=MultiChannelTensorImage, chans=None):
    """Return a TransformBlock whose item loader is `cls.create` with `chans` pre-bound."""
    itemLoader = partial(cls.create, chans=chans)
    return TransformBlock(itemLoader)

## create image blocks

In [17]:
# create image blocks
# Input block: items are loaded through MultiChannelTensorImage.create.
# NOTE: this rebinds the name `ImageBlock`, which `from fastai.vision.all import *`
# also provides; from here on `ImageBlock` refers to the multi-channel block.
ImageBlock = MultiChannelImageBlock(chans=None)
# Target block: items are read with open_mask and wrapped as TensorImage.
MaskBlock = TransformBlock(type_tfms=[partial(open_mask, cls=TensorImage)])

## Split dataset

In [18]:
def FileSplitter(leaveOutYear):
    """Create a splitter that puts all items of `leaveOutYear` in the validation set.

    An item belongs to the left-out year when the third '_'-separated token of
    its file stem, read as an integer, lies between the GEE time stamps of
    January 1st and December 31st of that year (both bounds inclusive).

    Args:
        leaveOutYear: year whose items are held out.

    Returns:
        Callable taking the list of items (extra keyword arguments are
        ignored) and returning the result of fastai's `FuncSplitter`.
    """
    # The year bounds do not depend on the item: compute them once instead of
    # twice per item inside the predicate.
    # NOTE(review): the upper bound is whatever time stamp the helper returns
    # for Dec 31; if that is the start of the day, items stamped later on
    # Dec 31 are not matched. Confirm against GEEHelpers.
    yearStart = int(GEEHelpers.GetGEETimeStampFromDate(leaveOutYear,1,1))
    yearEnd = int(GEEHelpers.GetGEETimeStampFromDate(leaveOutYear,12,31))
    def _func(x): return yearStart <= int(x.stem.split('_')[2]) <= yearEnd
    def _inner(o, **kwargs): return FuncSplitter(_func)(o)
    return _inner

## define function to get dataset

In [19]:
def getFilesForStudy(path, items=lstmDF.File.values):
    """Build the label file paths for the given label file names.

    Each file is expected under
    `path/<first two '_'-separated tokens of the name>/Sen1FractionInundatedArea/<name>`.
    By default all names in the `File` column of `lstmDF` are used (the
    default is captured when the function is defined).
    """
    labelFiles = []
    for item in items:
        studyFolder = '_'.join(item.split('_')[0:2])
        labelFiles.append(path/studyFolder/'Sen1FractionInundatedArea'/item)
    return labelFiles

In [20]:
# uncomment this to check if items are found
# items = getFilesForStudy(dataPath, lstmDF.File.values)

## Define model

In [21]:
class Net(nn.Module):
    """CNN followed by an LSTM over the time steps of the input.

    The input holds `nbFeatures` channels for each of `nbTimeSteps` time steps
    along the channel axis. A grouped CNN reduces every time step to a single
    channel. The first output channel is kept as is, the remaining ones are
    flattened and fed to an LSTM as a sequence. The last LSTM output is
    expanded back to an image with a transpose convolution, concatenated with
    the first channel and mapped to one output channel in the range (0, 1).

    Args:
        nbFeatures: number of input channels per time step.
        initSize: number of filters per time step after the first convolution.
        nbLayers: number of convolution + residual block stages.
        nbTimeSteps: number of time steps (also the number of conv groups).
        input_size: LSTM input size; must equal height*width of the CNN output.
        hidden_size: LSTM hidden size, also the number of input channels of
            the transpose convolution.
    """
    def __init__(self, nbFeatures=10, initSize=32, nbLayers=1, nbTimeSteps=timeSteps, input_size=32*32, hidden_size=32*32):
        super().__init__()
        
        self.nbFeatures = nbFeatures
        self.input_size = input_size
        self.nbTimeSteps = nbTimeSteps
        
        # define helper functions for convolutions, either single convolution, or combined with res block
        def conv2_single(ni,nf): return nn.Conv2d(ni, nf, kernel_size=3, padding=1, padding_mode='reflect', stride=1)
        def conv2(ni,nf): return nn.Conv2d(ni, nf, groups=nbTimeSteps, kernel_size=3, padding=1, padding_mode='reflect', stride=1)
        def conv2_and_res(ni, nf): return nn.Sequential(conv2(ni,nf), ResBlock(2, ni, nf, groups=nbTimeSteps, stride=1))
        
        # Create CNN A
        # note that the number of groups equals the number of time steps, i.e. the
        # channels of each time step are convolved separately from the other time steps
        self.cnn = nn.Sequential(
            conv2(nbFeatures*nbTimeSteps,nbTimeSteps*initSize)
        )
        
        # each stage multiplies the number of channels by 4
        for k in range(nbLayers):
            self.cnn = self.cnn.append(conv2_and_res(nbTimeSteps*initSize * 4**k, nbTimeSteps*(initSize * 2) * 4**k))
        # final grouped convolution: one output channel per time step
        self.cnn = self.cnn.append(conv2(nbTimeSteps*(initSize * 2) * 4**(nbLayers-1)* 2,nbTimeSteps*1))
        
        # Create LSTM
        self.LSTM = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=1, batch_first=True, bidirectional=False)
        
        # Create transpose convolution
        # The input channels follow hidden_size (previously hard-coded to 1024,
        # which only matched the default). The 32x32 kernel turns the 1x1
        # input into a 32x32 image.
        self.convTrans = nn.ConvTranspose2d(hidden_size,1,kernel_size=32)
        
        # Set the output to be a single convolution and a sigmoid
        self.outLayer = nn.Sequential(conv2_single(2,1), SigmoidRange(0,1))

    def forward(self, x):
        """Run the network on a batch `x` of shape (batch, channels, height, width)."""
        # get size of problem
        batchSize = x.shape[0]
        imgSize = x.shape[2:4]
        
        # pass all time steps through the CNN
        x = self.cnn(x)
        # extract time step 0 (first output channel of the CNN)
        x_now = x[:,0,::].view((batchSize,1,imgSize[0],imgSize[1]))
        # pass time step -1 to -9 (remaining channels, flattened) through lstm
        x, (_,_) = self.LSTM(x[:,1:,::].view((batchSize,self.nbTimeSteps-1,-1)))
        # extract the output of the last sequence element and pass through transpose convolution
        x = x[:,-1,:].view((batchSize,-1,1,1))
        x = self.convTrans(x)
        # concatenate lstm output and time step t
        x = torch.cat((x_now,x),1)
        # pass output through output layer
        x = self.outLayer(x)
        return x

## Define model params, output dir and tfms

In [22]:
(Path('models')/'CNNLSTM').mkdir(parents=True, exist_ok=True)

In [23]:
batch_tfms = [Rotate(), Flip(), Dihedral()]

## Train Model

In [24]:
def train(leaveOutYear):
    """Train a fresh model with the items of `leaveOutYear` held out for validation.

    Training runs in three stages with decreasing learning rate. After each
    stage the learner is saved as 'CNNLSTM/<year>_<stage>', and the metrics
    are appended to 'history_<year>.csv'.
    """
    # Data: inputs/targets come from the custom blocks, validation set = left-out year
    dataBlock = DataBlock(
        blocks=(ImageBlock, MaskBlock),
        get_items=getFilesForStudy,
        splitter=FileSplitter(leaveOutYear),
        batch_tfms=batch_tfms,
    )
    loaders = dataBlock.dataloaders(dataPath, num_workers=20, bs=128)

    # Metrics, loss and model
    metrics = [mse, rmse, R2Score()]
    lossFunc = MSELossFlat()
    net = Net()

    # Learner; the training history of all stages is appended to one CSV per year
    historyLogger = CSVLogger(append=True, fname='history_' + str(leaveOutYear) + '.csv')
    learn = Learner(loaders, net, loss_func=lossFunc, metrics=metrics, opt_func=ranger, cbs=historyLogger)

    # To resume from a previously saved stage, load it here (and adapt the
    # stages below), e.g.:
    # learn.load('CNNLSTM/' + str(leaveOutYear) + "_" + str(0), with_opt=False)

    # To search for a learning rate instead of using the fixed ones:
    # lr = learn.lr_find().valley
    # print(lr)

    # (learning rate, epochs) per stage; the full schedule used epochs [20, 5, 5]
    stageLrs = [.001, .0001, .00001]
    stageEpochs = [3, 1, 1]

    # train
    for stage, (maxLr, nEpochs) in enumerate(zip(stageLrs, stageEpochs)):
        learn.fit_flat_cos(nEpochs, lr=slice(maxLr))
        learn.save('CNNLSTM/' + str(leaveOutYear) + "_" + str(stage))
        print('done saving ' + str(stage))

    print('Done with Leave-out year ' + str(leaveOutYear))

In [25]:
# Loop through cross validated years
# One full training run per held-out year, 2017 through 2021 inclusive
# (the upper bound of range() is exclusive).
for leaveOutYear in range(2017, 2022):
    print(' Starting Leave-out year ' + str(leaveOutYear))
    train(leaveOutYear)