## Cropland mapping with Unet

This notebook trains a Unet convolutional neural network and outputs predictions on PlanetScope NICFI basemap data. It is adapted to run in Google's Colab.

You can either upload this notebook or clone the class repo into Colab.  

Code in this notebook was developed by Boka Luo and Sam Khallaghi for cropland mapping work conducted by the Agricultural Impacts Research Group.

## Set up

`rasterio` has to be installed first. 

In [1]:
!pip install rasterio

Then change the runtime type in Colab to GPU.

### Libraries

In [4]:
# Import the required libraries
# Numerical computation
import numpy as np
import numpy.ma as ma

# Structured data wrangling
import pandas as pd

# Computer vision libraries
from sklearn import metrics
from skimage import transform as trans
import matplotlib.pyplot as plt
import cv2
import gdal
import rasterio
from rasterio.windows import Window

# General libraries
import itertools
from itertools import product
import random
import math
import numbers
import copy
import os
import glob
import gc
import re
from datetime import datetime, timedelta
from pathlib import Path
import collections
from collections import OrderedDict
import urllib.parse as urlparse
import queue
import threading
import multiprocessing as mp

# DL package
import torch
import torch.nn.functional as F
from torch import optim
from torch import nn
from torch.nn import init
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torchvision import transforms
import torchvision.transforms.functional as TF
import torchvision.utils as vutils
import torch.cuda.comm as comm
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.tensorboard import SummaryWriter
#from tensorboardX import SummaryWriter

# Debugging library for jupyter notebook.
from IPython.core.debugger import set_trace

# Supress Warnings
import warnings
warnings.filterwarnings('ignore')

# Magic keywords for Ipython
%load_ext autoreload
%autoreload 2
%matplotlib inline

### Connect colab notebook to your Google Drive 
For the first connection, follow the pop-up window, grab the authorization token, and paste it into this cell.



In [None]:
# Mount Google Drive inside the Colab runtime under /content/gdrive.
from google.colab import drive
drive.mount("/content/gdrive")

### Check pytorch, CUDA, GPUs

In [None]:
# Report the PyTorch build, the CUDA/cuDNN versions it was built against and
# the GPU devices that are visible to it.
print("PyTorch version: {}".format(torch.__version__))
print("Cuda version : {}".format(torch.version.cuda))
print('CUDNN version:', torch.backends.cudnn.version())
print('Number of available GPU Devices:', torch.cuda.device_count())
# NOTE(review): presumably this call needs an active CUDA device -- confirm
# that the runtime type is set to GPU before running the cell.
print("current GPU Device: {}".format(torch.cuda.current_device()))

## Functions

### Manual Seeding

Manual seeding makes the model deterministic and keeps the experiments reproducible.


In [6]:
# "torch.backends.cudnn.benchmark = True" --> This will allow the cuda backend to optimize your graph during its first execution. 
# Be aware that if you change the network input/output tensor size the graph will be optimized each time a change occurs.
# This can lead to very slow runtime and out of memory errors. Only set this flag if your input and output have always the 
# same shape. Usually, this results in an improvement of about 20%.

def seed_everything(seed = 1234, cudnn = True):
    '''
    Seed every random number generator used in this notebook.
    Params:
        seed (int): Seed handed to Python's random module, NumPy and PyTorch,
            and written to the PYTHONHASHSEED environment variable
        cudnn (bool): When true, additionally seed all CUDA devices and switch
            cuDNN to deterministic mode
    '''

    # CPU side generators first, in a fixed order.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

    if not cudnn:
        return

    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

seed_everything()

### Image Augmentation

In [None]:
### Random Rotation around the center ###
def centerRotate(img, label, mask, degree):
    '''
    Rotate an image chip, its label and its mask around the chip center.
    Args:
    img (narray): Concatenated variables or brightness value with a dimension of (H, W, C)
    label (narray): Ground truth with a dimension of (H,W)
    mask (narray): Binary mask represents valid pixels in images and label, in a dimension of (H,W)
    degree (tuple or list): Range of degree for rotation; any other value is
        used directly as the rotation angle
    Returns:
    (narray, narray, narray) tuple of rotated image, label and mask. Label and
    mask are rounded to the nearest integer value after warping.
    '''

    # A (low, high) range collapses to a single random angle.
    if isinstance(degree, (tuple, list)):
        degree = random.uniform(degree[0], degree[1])

    rows, cols, _ = img.shape

    # OpenCV expects the pivot and the output size as (x, y) = (col, row).
    pivot = (cols // 2, rows // 2)
    out_size = (cols, rows)
    affine = cv2.getRotationMatrix2D(pivot, degree, 1.0)

    def _warp(arr):
        return cv2.warpAffine(arr, affine, out_size)

    img = _warp(img)
    label = np.rint(_warp(label))
    mask = np.rint(_warp(mask))

    return img, label, mask

### Horizontal, Vertical and Diagonal Flip ###
def flip(img, label, mask, ftype):
    '''
    Synthesize new image chips by reversing the input chip along one or both
    spatial axes.
    Args:
        img (narray): Concatenated variables or brightness value with a 
            dimension of (H, W, C)
        label (narray): Ground truth with a dimension of (H,W)
        mask (narray): Binary mask represents valid pixels in images and 
            label, in a dimension of (H,W)
        ftype (str): Flip type from ['vflip','hflip','dflip']
    Returns:
        (narray, narray, narray) tuple of flipped copies of image, label and
        mask
    Raises:
        ValueError: If ftype is not one of the three supported names
    Note:
        What each type does in this implementation:
            1) 'hflip' reverses axis 0 (the row order)
            2) 'vflip' reverses axis 1 (the column order)
            3) 'dflip' reverses axis 1 and then axis 0
    '''

    if ftype == 'hflip':
        axes = (0,)
    elif ftype == 'vflip':
        axes = (1,)
    elif ftype == 'dflip':
        axes = (1, 0)
    else:
        raise ValueError("Bad flip type")

    flipped = []
    for arr in (img, label, mask):
        for axis in axes:
            arr = np.flip(arr, axis)
        flipped.append(arr)

    # np.flip hands back views; copy so the results own their data.
    out_img, out_label, out_mask = flipped
    return out_img.copy(), out_label.copy(), out_mask.copy()

### Random Rescaling of image chips ###
def reScale(img, label, mask, scale=(0.8, 1.2), randResizeCrop=False, 
            diff=False, cenLocate=True):
    '''
    Synthesize new image chips by rescaling the input chip, then cropping or
    padding the result back towards the input size.
    Params:
        img (narray): Concatenated variables or brightness value with a 
            dimension of (H, W, C)
        label (narray): Ground truth with a dimension of (H,W)
        mask (narray): Binary mask represents valid pixels in images and
            label, in a dimension of (H,W)
        scale (tuple or list): Range of scale ratio
        randResizeCrop (bool): Whether crop the rescaled image chip 
            randomly or at the center if the chip is larger than input ones
        diff (bool): Whether change the aspect ratio. When false the resized
            width is set equal to the resized height
        cenLocate (bool): Whether locate the rescaled image chip at the center
            or a random position if the chip is smaller than input
    Returns:
        (narray, narray, narray) tuple of rescaled image, label and mask
    Raises:
        Exception: If scale is neither a tuple nor a list
    '''

    h, w, _ = img.shape
    if not isinstance(scale, (tuple, list)):
        raise Exception('Wrong scale type!')

    low, high = scale[0], scale[1]
    resizeH = round(random.uniform(low, high) * h)
    resizeW = round(random.uniform(low, high) * w) if diff else resizeH

    new_shape = (resizeH, resizeW)
    imgRe = trans.resize(img, new_shape, preserve_range=True)
    labelRe = trans.resize(label, new_shape, preserve_range=True)
    maskRe = trans.resize(mask, new_shape, preserve_range=True)

    # Surplus rows/columns that have to be cropped away (0 when not larger).
    surplusH = max(0, resizeH - h)
    surplusW = max(0, resizeW - w)
    if randResizeCrop:
        x_off = random.randint(0, surplusH)
        y_off = random.randint(0, surplusW)
    else:
        x_off = surplusH // 2
        y_off = surplusW // 2

    rows = slice(x_off, x_off + min(h, resizeH))
    cols = slice(y_off, y_off + min(w, resizeW))
    imgRe = imgRe[rows, cols, :]
    # Resizing leaves fractional values; snap label and mask back to integers.
    labelRe = np.rint(labelRe[rows, cols])
    maskRe = np.rint(maskRe[rows, cols])

    if resizeH < h or resizeW < w:
        # Missing rows/columns that have to be padded (0 when not smaller).
        gapH = max(0, h - resizeH)
        gapW = max(0, w - resizeW)
        if cenLocate:
            tlX = gapH // 2
            tlY = gapW // 2
        else:
            tlX = random.randint(0, gapH)
            tlY = random.randint(0, gapW)

        # The canvas is square with side h.
        imgRe, labelRe, maskRe = uniShape(imgRe, labelRe, maskRe, h, tlX, tlY)

    return imgRe, labelRe, maskRe

### Change pixel brightness to account for atmospheric and ### 
### illumination noise  ###
def shiftBrightness(img, gammaRange=(0.2, 2.0), shiftSubset=(4, 4), 
                    patchShift=True):
    '''
    Shift image brightness through gamma correction, one group of channels
    at a time.
    Params:
        img (narray): Concatenated variables or brightness value with a
            dimension of (H, W, C). The patch branch assigns the shifted
            values back into this array.
        gammaRange (tuple): Lower and upper bound of the triangular
            distribution (mode 1) that gamma values are drawn from
        shiftSubset (tuple): Number of bands or channels for each shift; each
            entry is one consecutive group of channels with its own gamma
        patchShift (bool): Whether apply the shift on small patches
     Returns:
        narray, brightness shifted image
    '''


    # index of the first channel of the group currently processed
    c_start = 0

    if patchShift:
        for i in shiftSubset:
            gamma = random.triangular(gammaRange[0], gammaRange[1], 1)

            h, w, _ = img.shape
            # Warp the channel group with a random rotation/zoom to obtain an
            # irregular patch.
            # NOTE(review): center is built as (0..h, 0..w); cv2 usually takes
            # (x, y) -- only equivalent for square chips, confirm.
            rotMtrx = cv2.getRotationMatrix2D(
                center=(random.randint(0, h), random.randint(0, w)),
                angle=random.randint(0, 90), scale=random.uniform(1, 2)
            )
            mask = cv2.warpAffine(img[:, :, c_start:c_start + i], 
                                  rotMtrx, (w, h))
            # non-zero warped pixels -> 0 (not masked), zero pixels -> 1 (masked)
            mask = np.where(mask, 0, 1)
            # apply mask
            img_ma = ma.masked_array(img[:, :, c_start:c_start + i], mask=mask)
            # presumably masked pixels keep their original value here -- verify
            img[:, :, c_start:c_start + i] = ma.power(img_ma, gamma)
            # default extra step -- shift on image
            gamma_full = random.triangular(0.5, 1.5, 1)
            img[:, :, c_start:c_start + i] = np.power(
                img[:, :, c_start:c_start + i], gamma_full
            )

            c_start += i
    else:
        # convert image dimension to (C, H, W) if len(img.shape)==3
        img = np.transpose(
            img, list(range(img.ndim)[-1:]) + list(range(img.ndim)[:-1])
        )
        for i in shiftSubset:
            # one gamma per channel group, applied to the whole group
            gamma = random.triangular(gammaRange[0], gammaRange[1], 1)
            img[c_start:c_start + i, ] = np.power(
                img[c_start:c_start + i, ], gamma
            )

            c_start += i
        # move the channel axis back to the last position
        img = np.transpose(img, list(range(img.ndim)[-img.ndim + 1:]) + [0])

    return img

### Unify image and label chips dimensions ###
def uniShape(img, label, mask, dsize, tlX=0, tlY=0):

    '''
    Unify dimension of images and labels to specified data size by pasting
    them onto zero-filled square canvases.
    Params:
    img (narray): Concatenated variables or brightness value with a dimension 
        of (H, W, C)
    label (narray): Ground truth with a dimension of (H,W)
    mask (narray): Binary mask represents valid pixels in images and label, 
        in a dimension of (H,W)
    dsize (int): Target data size (side length of the square canvas)
    tlX (int): Vertical offset by pixels
    tlY (int): Horizontal offset by pixels
    Returns:
    (narray, narray, narray) tuple of shape unified image, label and mask,
    each keeping the dtype of its own input
    '''

    resizeH, resizeW, c = img.shape

    canvas_img = np.zeros((dsize, dsize, c), dtype=img.dtype)
    canvas_label = np.zeros((dsize, dsize), dtype=label.dtype)
    # The mask canvas takes the mask's own dtype; allocating it with the
    # label's dtype would silently cast the mask values on assignment.
    canvas_mask = np.zeros((dsize, dsize), dtype=mask.dtype)

    rows = slice(tlX, tlX + resizeH)
    cols = slice(tlY, tlY + resizeW)
    canvas_img[rows, cols] = img
    canvas_label[rows, cols] = label
    canvas_mask[rows, cols] = mask

    return canvas_img, canvas_label, canvas_mask

### Utility Functions

In [None]:
### Load data ####
def load_data(dataPath, usage="train", window=None, isLabel=False):
    '''
    Read geographic data into numpy array
    Params:
        dataPath (str): Path of data to load
        usage (str): Usage of the data: "train", "validate", or "predict"
        window (tuple): The view onto a rectangular subset of the data, in the 
            format of (column offsets, row offsets, width in pixel, 
            height in pixel). Only used for images read for "train" or
            "validate"
        isLabel (binary): When true, return the single band of a label raster
            as is; otherwise return the min/max normalized image
    Returns:
        narray
    Raises:
        InputError: If a label raster does not have exactly one band
    '''

    with rasterio.open(dataPath, "r") as src:

        if isLabel:
            if src.count != 1:
                raise InputError("Label shape not applicable")
            return src.read(1)

        # Norm by tile: the whole raster is normalized before any cropping.
        img = mmNorm(src.read(), nodata=src.nodata)

        if usage in ['train', 'validate']:
            # Negative offsets are clamped to the raster edge.
            row_slice = slice(max(0, window[1]), window[1] + window[3])
            col_slice = slice(max(0, window[0]), window[0] + window[2])
            img = img[:, row_slice, col_slice]

    return img

### stack images ###
def get_stacked_img(imgPaths, usage, window=None):
    '''
    Load several images, stack their bands and return them as one (H, W, C)
    array.
    Params:
        imgPaths (list): List of paths for images
        usage (str): Usage of the image: "train", "validate", or "predict"
        window (tuple): The view onto a rectangular subset of the data, in 
            the format of (column offsets, row offsets, width in pixel, height in pixel)
    Returns:
        narray. For "train" and "validate", a stack smaller than the window is
        placed on a zero-filled canvas of the window size.
    Raises:
        ValueError: If usage is not one of the three supported values
    '''

    stacked = np.concatenate(
        [load_data(m, window=window, usage=usage) for m in imgPaths], axis=0
    )
    # (C, H, W) -> (H, W, C)
    img = stacked.transpose(1, 2, 0)

    if usage == "predict":
        return img

    if usage not in ["train", "validate"]:
        raise ValueError

    col_off, row_off, col_target, row_target = window
    row, col, c = img.shape

    if row >= row_target and col >= col_target:
        return img

    # A negative window offset becomes the position of the data on the canvas.
    row_off = abs(row_off) if row_off < 0 else 0
    col_off = abs(col_off) if col_off < 0 else 0

    canvas = np.zeros((row_target, col_target, c))
    canvas[row_off: row_off + row, col_off : col_off + col, :] = img

    return canvas

### Get buffered window ####

def get_buffered_window(srcPath, dstPath, buffer):
    '''
    Get bounding box representing subset of source image that overlaps with 
    buffered destination image, in format of (column offsets, row offsets, 
    width, height)
    Params:
        srcPath (str): Path of source image to get subset bounding box
        dstPath (str): Path of destination image as a reference to define the
            bounding box. Size of the bounding box is (destination width + 
            buffer * 2, destination height + buffer * 2)
        buffer (int): Buffer distance of bounding box edges to destination image 
            measured by pixel numbers
    Returns:
        tuple in form of (column offsets, row offsets, width, height)
    '''

    with rasterio.open(srcPath, "r") as src:
        gt_src = src.transform

    with rasterio.open(dstPath, "r") as dst:
        gt_dst, w_dst, h_dst = dst.transform, dst.width, dst.height

    # Difference of the transform origins divided by the source pixel size.
    x_shift = gt_dst[2] - gt_src[2]
    y_shift = gt_dst[5] - gt_src[5]
    pad = buffer * 2

    return (
        round(x_shift / gt_src[0]) - buffer,
        round(y_shift / gt_src[4]) - buffer,
        w_dst + pad,
        h_dst + pad,
    )

### Get meta from bounds ###

def get_meta_from_bounds(file, buffer):
    '''
    Get metadata of unbuffered region in given file
    Params:
        file (str):  File name of a image chip
        buffer (int): Buffer distance measured by pixel numbers
    Returns:
        dictionary: the file's metadata with width, height and transform of
        the unbuffered window, and count, nodata and dtype set to 1, -128
        and 'int8'
    '''

    with rasterio.open(file, "r") as src:
        meta = src.meta
        dst_width = src.width - 2 * buffer
        dst_height = src.height - 2 * buffer
        # Window that strips `buffer` pixels from every side.
        win_transform = src.window_transform(
            Window(buffer, buffer, dst_width, dst_height)
        )

    meta.update(
        width=dst_width,
        height=dst_height,
        transform=win_transform,
        count=1,
        nodata=-128,
        dtype='int8',
    )

    return meta

### Image Normalization ###
def mmNorm(img, nodata):
    '''
    Data normalization with min/max method
    Params:
        img (narray): The targeted image for normalization
        nodata: Value marking invalid pixels. Such pixels are left out when
            the minimum and maximum are computed, but they are still
            transformed with the same formula as every other pixel.
    Returns:
        narray, (img - min) / (max - min). An image whose valid pixels all
        share one value has a zero range; an all-zero array is returned for it
        instead of dividing by zero.
    '''

    valid = np.where(img == nodata, np.nan, img)
    img_max = np.nanmax(valid)
    img_min = np.nanmin(valid)

    shifted = img - img_min
    span = img_max - img_min
    if span == 0:
        # Constant image: avoid 0/0 which would fill the output with NaN/inf.
        return np.zeros_like(shifted)

    return shifted / span

### Get Chips ###
def get_chips(img, dsize, buffer):
    '''
    Cut an image into overlapping chips and record where each chip starts.
    The index marks the location of the upper-left pixel of a chip.
    Params:
        img (narray): Image in format of (H,W,C) to be crop, in this case it is 
            the concatenated image of growing season and off season
        dsize (int): Cropped chip size
        buffer (int):Number of overlapping pixels when extracting images chips
    Returns:
        list of cropped chips and corresponding coordinates
    '''

    h, w, _ = img.shape

    # Neighbouring chips share 2 * buffer pixels.
    stride = dsize - 2 * buffer
    row_starts = range(0, h - 2 * buffer, stride)
    col_starts = range(0, w - 2 * buffer, stride)

    index = list(itertools.product(row_starts, col_starts))
    img_ls = [img[x:x + dsize, y:y + dsize, :] for x, y in index]

    return img_ls, index

### Input Error handling ###
class InputError(Exception):
    '''
    Exception raised for errors in the input
    '''

    def __init__(self, message):
        '''
        Params:
            message (str): explanation of the error
        '''

        self.message = message

    def __str__(self):
        '''
        Build the text shown when the error is raised; the explanation is
        appended only when one was given.
        '''

        if not self.message:
            return 'InputError'
        return 'InputError, {} '.format(self.message)

### Polynomial Learning rate Decay Policy ###

class PolynomialLR(_LRScheduler):
    """Polynomial learning rate decay until step reach to max_decay_step

    The rate follows
    (base_lr - min_learning_rate) * (1 - step / max_decay_steps) ** power
    + min_learning_rate.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        max_decay_steps: after this step, we stop decreasing learning rate
        min_learning_rate: scheduler stoping learning rate decay, value of 
            learning rate must be this value
        power: The power of the polynomial.

    Raises:
        ValueError: If max_decay_steps is not greater than 1.
    """

    def __init__(self, optimizer, max_decay_steps, min_learning_rate=1e-5, 
                 power=1.0):

        if max_decay_steps <= 1.:
            raise ValueError('max_decay_steps should be greater than 1.')

        self.max_decay_steps = max_decay_steps
        self.min_learning_rate = min_learning_rate
        self.power = power
        self.last_step = 0

        super().__init__(optimizer)

    def _decayed_lrs(self):
        # Polynomial decay evaluated at the current step, one value per group.
        factor = (1 - self.last_step / self.max_decay_steps) ** (self.power)
        floor = self.min_learning_rate
        return [(base_lr - floor) * factor + floor for base_lr in self.base_lrs]

    def get_lr(self):
        # Past the decay horizon every group reports the floor value.
        if self.last_step > self.max_decay_steps:
            return [self.min_learning_rate for _ in self.base_lrs]

        return self._decayed_lrs()

    def step(self, step=None):

        if step is None:
            step = self.last_step + 1
        # Step 0 is bumped to 1.
        self.last_step = 1 if step == 0 else step

        # Beyond max_decay_steps the optimizer is left untouched.
        if self.last_step > self.max_decay_steps:
            return

        for param_group, lr in zip(self.optimizer.param_groups,
                                   self._decayed_lrs()):
            param_group['lr'] = lr

### Reading the dataframe input with parallel workers ###
def parallelize_df(df, func, n_cores=os.cpu_count(), **kwargs):
    '''
    Processes specified method on pandas dataframe using multiple cores
    Params:
        df (''pd.DataFrames''): Pandas dataframe to be processed
        func: Method to apply on provided dataframe. It receives one slice of
            the dataframe as its first argument.
        n_cores (int): Number of processes that the mother process splits into.
            It is capped by the number of rows and never drops below 1.
        **kwargs: Must hold one entry for every positional parameter of func
            after the first; they are passed to func in signature order.
    Returns:
        ''pd.DataFrames''
    Raises:
        KeyError: If a parameter of func is missing from kwargs
    '''

    # co_varnames lists the parameters first and then every other local
    # variable of func; only the parameters may be looked up in kwargs.
    code = func.__code__
    param_names = code.co_varnames[1:code.co_argcount]
    other_args = [kwargs['{}'.format(m)] for m in param_names]

    # os.cpu_count() can return None and an empty frame has zero rows.
    n_cores = max(1, min(n_cores or 1, len(df)))

    # Split row positions rather than the frame itself so that the chunks are
    # always real dataframes.
    positions = np.array_split(np.arange(len(df)), n_cores)
    df_split = [df.iloc[idx] for idx in positions]

    pool = mp.Pool(n_cores)
    try:
        df_map = pool.starmap(func, [(chunk, *other_args) for chunk in df_split])
    finally:
        # Release the workers even when func raises.
        pool.close()
        pool.join()

    return pd.concat(df_map)

### Multi-core ###

def multicore(func, args, n_cores=os.cpu_count()):
    '''
    Processes specified method on a series of arguments in parallel. Number of 
    cores is determined by whichever is smaller of the computer cores or the 
    arguments number.
    Params:
        func: function to apply on provided arguments
        args (list): a list of independent arguments
        n_cores (int): upper bound on the number of worker processes
    Returns:
        None; results of func are discarded
    '''

    # Nothing to do, and a pool cannot be created with zero processes.
    if len(args) == 0:
        return

    # os.cpu_count() can return None.
    n_cores = max(1, min(n_cores or 1, len(args)))

    pool = mp.Pool(processes=n_cores)
    try:
        pool.map(func, args)
    finally:
        # Release the workers even when func raises.
        pool.close()
        pool.join()

### Loss functions

In [None]:
### Balanced Cross Entropy Loss ###
class BalancedCrossEntropyLoss(nn.Module):
    r'''
    Cross Entropy loss weighted based on inverse class ratio strategy.

    The class weights are recomputed from the target on every call and are
    created on the device of the prediction tensor, so the loss works on CPU
    as well as on GPU.

    Arguments:
        ignore_index (int) -- Class index to ignore
        reduction (str) -- Reduction method to apply, return mean over batch if 'mean',
                         return sum if 'sum', return a tensor of shape [N,] if 'none'
    Returns:
        Loss tensor according to arg reduction
    '''

    def __init__(self, ignore_index=-100, reduction='mean'):
        super(BalancedCrossEntropyLoss, self).__init__()

        self.ignore_index = ignore_index
        self.reduction = reduction

    def forward(self, predict, target):
        # get class weights
        unique, unique_counts = torch.unique(target, return_counts=True)

        # calculate weight for only valid indices
        valid = unique != self.ignore_index
        unique = unique[valid]
        unique_counts = unique_counts[valid]
        ratio = unique_counts.float() / torch.numel(target)
        weight = (1. / ratio) / torch.sum(1. / ratio)

        # Classes absent from this target keep a tiny weight.
        lossWeight = torch.full((predict.shape[1],), 0.00001,
                                device=predict.device)
        lossWeight[unique.long()] = weight.to(lossWeight.device)

        loss = nn.CrossEntropyLoss(weight=lossWeight, ignore_index=self.ignore_index, reduction=self.reduction)

        return loss(predict, target)

### Dice Loss Family ###
class BinaryDiceLoss(nn.Module):
    r'''
    Dice loss of binary classes, computed over all elements of the batch at
    once: 1 - (2 * \sum{xy} + smooth) / (\sum{x^p} + \sum{y^p} + smooth).
    Arguments:
        smooth (float): A float number to smooth loss, and avoid NaN error, default: 1
        p (int): Denominator value: \sum{x^p} + \sum{y^p}, default: 1
        predict (torch.tensor): Predicted tensor of shape [N, *]
        target (torch.tensor): Target tensor of same shape with predict
    Returns:
        Loss tensor
    '''

    def __init__(self, smooth=1, p=1):
        super(BinaryDiceLoss, self).__init__()
        self.smooth = smooth
        self.p = p

    def forward(self, predict, target):
        assert predict.shape[0] == target.shape[0], "predict & target batch size do not match"

        flat_pred = predict.contiguous().view(-1)
        flat_true = target.contiguous().view(-1)

        overlap = (flat_pred * flat_true).sum()
        magnitude = (flat_pred.pow(self.p) + flat_true.pow(self.p)).sum()

        return 1 - (2 * overlap + self.smooth) / (magnitude + self.smooth)

###
class DiceLoss(nn.Module):
    r'''
    Dice loss: weighted sum of the binary dice loss of every class, computed
    on the softmax of the prediction.

    Arguments:
        weight (torch.tensor): Weight array of shape [num_classes,]. When
            omitted every class gets weight 1 / num_classes, created on the
            device of the prediction tensor.
        ignore_index (int): Class index to ignore
        predict (torch.tensor): Predicted tensor of shape [N, C, *]
        target (torch.tensor): Target tensor either in shape [N,*] or of same shape with predict
        other args pass to BinaryDiceLoss
    Returns:
        same as BinaryDiceLoss
    Raises:
        ValueError: If target has neither the shape of predict nor exactly one
            dimension less than predict
    '''

    def __init__(self, weight=None, ignore_index=-100, **kwargs):
        super(DiceLoss, self).__init__()
        self.kwargs = kwargs
        self.weight = weight
        self.ignore_index = ignore_index

    def forward(self, predict, target):
        nclass = predict.shape[1]
        if predict.shape != target.shape:
            if target.dim() != predict.dim() - 1:
                raise ValueError(
                    'Predict tensor shape of {} is not acceptable.'.format(predict.shape)
                )
            # one_hot appends the class axis last; move it to position 1 so
            # the target lines up with predict for any number of dimensions.
            order = [0, target.dim()] + list(range(1, target.dim()))
            target = F.one_hot(target, num_classes=nclass).permute(*order).contiguous()

        if self.weight is None:
            weight = torch.full((nclass,), 1. / nclass, device=predict.device)
        else:
            weight = self.weight
        assert weight.shape[0] == nclass, \
            'Expected weight tensor with shape [{}], but got[{}]'.format(nclass, weight.shape[0])

        dice = BinaryDiceLoss(**self.kwargs)
        predict = F.softmax(predict, dim=1)

        total_loss = 0
        for i in range(nclass):
            if i == self.ignore_index:
                continue
            total_loss = total_loss + dice(predict[:, i], target[:, i]) * weight[i]

        return total_loss

###
class BalancedDiceLoss(nn.Module):
    r'''
    Dice Loss weighted by inverse of label frequency. The weights are
    recomputed from the target on every call and created on the device of the
    prediction tensor.
    Arguments:
        ignore_index (int): Class index to ignore
        predict (torch.tensor): Predicted tensor of shape [N, C, *]
        target (torch.tensor): Target tensor either in shape [N,*] or of same shape with predict
        other args pass to BinaryDiceLoss
    Returns:
        same as BinaryDiceLoss
    '''

    def __init__(self, ignore_index=-100, **kwargs):
        super(BalancedDiceLoss, self).__init__()
        self.kwargs = kwargs
        self.ignore_index = ignore_index

    def forward(self, predict, target):
        # get class weights
        unique, unique_counts = torch.unique(target, return_counts=True)
        # calculate weight for only valid indices
        valid = unique != self.ignore_index
        unique = unique[valid]
        unique_counts = unique_counts[valid]
        ratio = unique_counts.float() / torch.numel(target)
        weight = (1. / ratio) / torch.sum(1. / ratio)

        # Classes absent from this target keep a tiny weight.
        lossWeight = torch.full((predict.shape[1],), 0.00001,
                                device=predict.device)
        lossWeight[unique.long()] = weight.to(lossWeight.device)

        loss = DiceLoss(weight=lossWeight, ignore_index=self.ignore_index, **self.kwargs)

        return loss(predict, target)

###
class DiceCELoss(nn.Module):
    r'''
    Combination of dice loss and cross entropy loss through summation

    Arguments:
        loss_weight (tensor): a manual rescaling weight given to each class. If given, has to be a Tensor of size C
        dice_weight (float): Weight on dice loss for the summation, while weight on cross entropy loss is
                             (1 - dice_weight)
        dice_smooth (float): A float number to smooth dice loss, and avoid NaN error, default: 1
        dice_p (int): Denominator value: \sum{x^p} + \sum{y^p}, default: 1
        ignore_index (int): Class index to ignore
    Returns:
        Loss tensor
    '''

    def __init__(self, loss_weight = None, dice_weight=0.5 , dice_smooth=1, dice_p=1, ignore_index=-100):
        super(DiceCELoss, self).__init__()
        self.loss_weight = loss_weight
        self.dice_weight = dice_weight
        self.dice_smooth = dice_smooth
        self.dice_p = dice_p
        self.ignore_index = ignore_index

    def forward(self, predict, target):
        assert predict.shape[0] == target.shape[0], "predict & target batch size do not match"

        # Both terms share the same class weights and ignored class.
        dice = DiceLoss(weight=self.loss_weight, ignore_index=self.ignore_index, smooth=self.dice_smooth, p=self.dice_p)
        ce = nn.CrossEntropyLoss(weight=self.loss_weight, ignore_index=self.ignore_index)

        dice_term = self.dice_weight * dice(predict, target)
        ce_term = (1 - self.dice_weight) * ce(predict, target)

        return dice_term + ce_term

###
class BalancedDiceCELoss(nn.Module):
    r'''
    Dice Cross Entropy weighted by inverse of label frequency. The weights are
    recomputed from the target on every call and created on the device of the
    prediction tensor.
    Arguments:
        ignore_index (int): Class index to ignore; it is excluded from the
            weight computation and forwarded to DiceCELoss
        predict (torch.tensor): Predicted tensor of shape [N, C, *]
        target (torch.tensor): Target tensor either in shape [N,*] or of same shape with predict
        other args pass to DiceCELoss, excluding loss_weight
    Returns:
        Same as DiceCELoss
    '''

    def __init__(self, ignore_index=-100, **kwargs):
        super(BalancedDiceCELoss, self).__init__()
        self.ignore_index =  ignore_index
        self.kwargs = kwargs

    def forward(self, predict, target):
        # get class weights
        unique, unique_counts = torch.unique(target, return_counts=True)
        # calculate weight for only valid indices
        valid = unique != self.ignore_index
        unique = unique[valid]
        unique_counts = unique_counts[valid]
        ratio = unique_counts.float() / torch.numel(target)
        weight = (1. / ratio) / torch.sum(1. / ratio)

        # Classes absent from this target keep a tiny weight.
        lossWeight = torch.full((predict.shape[1],), 0.00001,
                                device=predict.device)
        lossWeight[unique.long()] = weight.to(lossWeight.device)

        # ignore_index has to be handed on explicitly; otherwise the wrapped
        # loss falls back to its own default and stops ignoring the class.
        loss = DiceCELoss(loss_weight=lossWeight, ignore_index=self.ignore_index, **self.kwargs)

        return loss(predict, target)

### Tversky-Focal Loss Family ###
class BinaryTverskyFocalLoss(nn.Module):
    r'''
    Pytorch version of tversky focal loss proposed in paper
    'A novel focal Tversky loss function and improved Attention U-Net for lesion segmentation'
    (https://arxiv.org/abs/1810.07842)

    The loss is (1 - TI) ** (1 / gamma), where the Tversky index TI is
    (TP + smooth) / (TP + alpha * FN + (1 - alpha) * FP + smooth), summed over
    every element of the batch.

    Arguments:
        smooth (float): A float number to smooth loss, and avoid NaN error, default: 1
        alpha (float): Hyperparameters alpha, paired with (1 - alpha) to shift emphasis to improve recall
        gamma (float): Focal parameter; the loss is raised to 1 / gamma, default: 1.33
        predict (torch.tensor): Predicted tensor of shape [N, *]
        target (torch.tensor): Target tensor of same shape with predict

    Returns:
        Loss tensor
    '''

    def __init__(self, smooth=1, alpha=0.7, gamma=1.33):
        super(BinaryTverskyFocalLoss, self).__init__()
        self.smooth = smooth
        self.alpha = alpha
        self.beta = 1 - self.alpha
        self.gamma = gamma

    def forward(self, predict, target):
        assert predict.shape[0] == target.shape[0], "predict & target batch size do not match"

        # no reduction, same as original paper
        flat_pred = predict.contiguous().view(-1)
        flat_true = target.contiguous().view(-1)

        true_pos = (flat_pred * flat_true).sum()
        false_neg = ((1 - flat_pred) * flat_true).sum()
        false_pos = (flat_pred * (1 - flat_true)).sum()

        num = true_pos + self.smooth
        den = true_pos + self.alpha * false_neg + self.beta * false_pos + self.smooth

        return torch.pow(1 - num / den, 1 / self.gamma)
            
###
class TverskyFocalLoss(nn.Module):
    r'''
    Multi-class Tversky focal loss: weighted sum of one BinaryTverskyFocalLoss per class,
    evaluated on the softmax of predict along dim 1.

    Arguments:
        weight (torch.tensor): Weight array of shape [num_classes,]. When None, a uniform
            1 / num_classes weight is created on the same device as predict
        ignore_index (int): Class index (channel) that is left out of the sum
        predict (torch.tensor): Predicted tensor of shape [N, C, H, W] (any shape is accepted
            when target already has the same shape as predict)
        target (torch.tensor): Target tensor of class indices in shape [N, H, W], or a tensor
            of the same shape as predict
        other args pass to BinaryTverskyFocalLoss
    Returns:
        same as BinaryTverskyFocalLoss
    Raises:
        ValueError: if target does not share predict's shape and predict is not 4-dimensional
    '''
    def __init__(self, weight=None, ignore_index=-100, **kwargs):
        super(TverskyFocalLoss, self).__init__()
        self.kwargs = kwargs
        self.weight = weight
        self.ignore_index = ignore_index

    def forward(self, predict, target):
        nclass = predict.shape[1]
        if predict.shape == target.shape:
            pass
        elif len(predict.shape) == 4:
            # class indices [N, H, W] -> one-hot [N, C, H, W]
            target = F.one_hot(target, num_classes=nclass).permute(0, 3, 1, 2).contiguous()
        else:
            # the original `assert '<message>'` was always true and let bad shapes through
            raise ValueError('predict shape not applicable')

        if self.weight is None:
            # follow predict's device instead of forcing CUDA so the loss also runs on CPU
            weight = torch.full((nclass,), 1. / nclass, device=predict.device)
        else:
            weight = self.weight
        assert weight.shape[0] == nclass, \
            'Expect weight shape [{}], get[{}]'.format(nclass, weight.shape[0])

        tversky = BinaryTverskyFocalLoss(**self.kwargs)
        predict = F.softmax(predict, dim=1)

        total_loss = 0
        for i in range(nclass):
            if i == self.ignore_index:
                continue
            total_loss = total_loss + weight[i] * tversky(predict[:, i], target[:, i])

        return total_loss

###
class BalancedTverskyFocalLoss(nn.Module):
    r'''
    Tversky focal loss weighted by inverse of label frequency.

    The class weights are recomputed from target on every call: each label value present in
    target (other than ignore_index) gets a weight proportional to the inverse of its pixel
    count, normalised to sum to 1; every other class keeps a weight of 0.00001.

    Arguments:
        ignore_index (int): Class index to ignore
        predict (torch.tensor): Predicted tensor of shape [N, C, *]
        target (torch.tensor): Target tensor either in shape [N,*] or of same shape with predict
        other args pass to BinaryTverskyFocalLoss
    Returns:
        same as TverskyFocalLoss
    '''

    def __init__(self, ignore_index=-100, **kwargs):
        super(BalancedTverskyFocalLoss, self).__init__()
        self.kwargs = kwargs
        self.ignore_index = ignore_index

    def forward(self, predict, target):
        # label values present in this batch and how often each occurs
        unique, unique_counts = torch.unique(target, return_counts=True)
        # calculate weight for only valid indices
        valid = unique != self.ignore_index
        unique_counts = unique_counts[valid]
        unique = unique[valid]
        ratio = unique_counts.float() / torch.numel(target)
        weight = (1. / ratio) / torch.sum(1. / ratio)

        # allocate on predict's device rather than forcing CUDA, so CPU runs work too
        lossWeight = torch.ones(predict.shape[1], device=predict.device) * 0.00001
        lossWeight[unique.to(lossWeight.device)] = weight.to(lossWeight.device)

        # loss
        loss = TverskyFocalLoss(weight=lossWeight, ignore_index=self.ignore_index, **self.kwargs)

        return loss(predict, target)

###
class TverskyFocalCELoss(nn.Module):
    '''
    Weighted sum of the Tversky focal loss and the cross entropy loss.

    Arguments:
        loss_weight (tensor): a manual rescaling weight given to each class, handed to both
            loss terms. If given, has to be a Tensor of size C
        tversky_weight (float): Weight on the Tversky focal term; the cross entropy term is
            weighted by (1 - tversky_weight)
        tversky_smooth (float): Forwarded to TverskyFocalLoss as `smooth`, default: 1
        tversky_alpha (float): Forwarded to TverskyFocalLoss as `alpha`, default: 0.7
        tversky_gamma (float): Forwarded to TverskyFocalLoss as `gamma`, default: 0.9
        ignore_index (int): Class index to ignore, handed to both loss terms
    Returns:
        Loss tensor
    '''

    def __init__(self, loss_weight=None, tversky_weight=0.5, tversky_smooth=1, tversky_alpha=0.7,
                 tversky_gamma=0.9, ignore_index=-100):
        super().__init__()
        self.loss_weight = loss_weight
        self.ignore_index = ignore_index
        self.tversky_weight = tversky_weight
        self.tversky_smooth = tversky_smooth
        self.tversky_alpha = tversky_alpha
        self.tversky_gamma = tversky_gamma

    def forward(self, predict, target):
        assert predict.shape[0] == target.shape[0], "predict & target batch size do not match"

        tversky_term = TverskyFocalLoss(
            weight=self.loss_weight,
            ignore_index=self.ignore_index,
            smooth=self.tversky_smooth,
            alpha=self.tversky_alpha,
            gamma=self.tversky_gamma,
        )
        ce_term = nn.CrossEntropyLoss(weight=self.loss_weight, ignore_index=self.ignore_index)

        # Tversky part is evaluated first, then cross entropy
        tversky_value = tversky_term(predict, target)
        ce_value = ce_term(predict, target)

        return self.tversky_weight * tversky_value + (1 - self.tversky_weight) * ce_value

###
class BalancedTverskyFocalCELoss(nn.Module):
    r'''
    Combination of tversky focal loss and cross entropy loss weighted by inverse of label frequency.

    The class weights are recomputed from target on every call: each label value present in
    target (other than ignore_index) gets a weight proportional to the inverse of its pixel
    count, normalised to sum to 1; every other class keeps a weight of 0.00001.

    Arguments:
        ignore_index (int): Class index to ignore; excluded from the weight computation and
            forwarded to TverskyFocalCELoss
        predict (torch.tensor): Predicted tensor of shape [N, C, *]
        target (torch.tensor): Target tensor either in shape [N,*] or of same shape with predict
        other args pass to TverskyFocalCELoss, excluding loss_weight and ignore_index
    Returns:
        Same as TverskyFocalCELoss
    '''

    def __init__(self, ignore_index=-100, **kwargs):
        super(BalancedTverskyFocalCELoss, self).__init__()
        self.ignore_index = ignore_index
        self.kwargs = kwargs

    def forward(self, predict, target):
        # label values present in this batch and how often each occurs
        unique, unique_counts = torch.unique(target, return_counts=True)
        # calculate weight for only valid indices
        valid = unique != self.ignore_index
        unique_counts = unique_counts[valid]
        unique = unique[valid]
        ratio = unique_counts.float() / torch.numel(target)
        weight = (1. / ratio) / torch.sum(1. / ratio)

        # allocate on predict's device rather than forcing CUDA, so CPU runs work too
        lossWeight = torch.ones(predict.shape[1], device=predict.device) * 0.00001
        lossWeight[unique.to(lossWeight.device)] = weight.to(lossWeight.device)

        # ignore_index has to be forwarded explicitly: it is captured by __init__ and is
        # therefore never part of self.kwargs
        loss = TverskyFocalCELoss(loss_weight=lossWeight, ignore_index=self.ignore_index, **self.kwargs)

        return loss(predict, target)

### Accuracy Metrics + Evaluation 

Performed on the validation dataset, producing a CSV report.

In [10]:
class BinaryMetrics:
    '''
    Metrics measuring model performance for one binary (class vs. rest) problem.

    After construction, `confusion_matrix` is a ''pandas.DataFrame'' attribute (the method
    of the same name is replaced on the instance by its result) and `tp`, `fp`, `fn`, `tn`
    hold the confusion counts.
    '''

    def __init__(self, refArray, scoreArray, predArray=None):
        '''
        Params:
            refArray (narray): Array of ground truth
            scoreArray (narray): Array of pixels scores of positive class
            predArray (narray): Boolean array of predictions telling whether a pixel belongs to a specific class
        '''

        self.eps = 10e-6
        self.observation = refArray.flatten()
        self.score = scoreArray.flatten()

        # Validate before anything is derived from mismatched inputs.
        # NOTE(review): InputError is not defined in the visible part of this file; if it is
        # not defined elsewhere this line raises NameError instead -- confirm.
        if self.observation.shape != self.score.shape:
            raise InputError("Inconsistent input shape")

        if predArray is not None:
            self.prediction = predArray.flatten()
        # take score over 0.5 as prediction if predArray not provided
        else:
            self.prediction = np.where(self.score > 0.5, 1, 0)

        # the bound method is deliberately replaced by its result on this instance
        self.confusion_matrix = self.confusion_matrix()

    def __add__(self, other):
        """
        Add two BinaryMetrics instances by concatenating their observations, scores and predictions
        Params:
            other (''BinaryMetrics''): A BinaryMetrics instance
        Return:
            ''BinaryMetrics''
        """

        return BinaryMetrics(np.append(self.observation, other.observation),
                             np.append(self.score, other.score),
                             np.append(self.prediction, other.prediction))


    def __radd__(self, other):
        """
        Add a BinaryMetrics instance with reversed operands, so that the built-in sum(),
        which starts from 0, works on a list of instances
        Params:
            other
        Returns:
            ''BinaryMetrics''
        """

        if other == 0:
            return self
        else:
            return self.__add__(other)


    def confusion_matrix(self):
        """
        Calculate confusion matrix of given ground truth and predicted label,
        and store the counts in self.tp, self.fp, self.fn and self.tn
        Returns:
            ''pandas.dataframe'' of observation on the row and prediction on the column
        Raises:
            Exception: if either array holds a value greater than 1
        """

        refArray = self.observation
        predArray = self.prediction

        if refArray.max() > 1 or predArray.max() > 1:
            raise Exception("Invalid array")

        # signed integers: with unsigned inputs the subtraction below would wrap around
        # (e.g. 0 - 2 == 254 for uint8) and no cell would ever match
        refArray = refArray.astype(np.int64)
        predArray = predArray.astype(np.int64) * 2
        # ref - 2 * pred encodes each cell uniquely: tn = 0, fn = 1, fp = -2, tp = -1
        sub = refArray - predArray

        self.tp = np.sum(sub == -1)
        self.fp = np.sum(sub == -2)
        self.fn = np.sum(sub == 1)
        self.tn = np.sum(sub == 0)

        confusionMatrix = pd.DataFrame(data=np.array([[self.tn, self.fp], [self.fn, self.tp]]),
                                       index=['observation = 0', 'observation = 1'],
                                       columns = ['prediction = 0', 'prediction = 1'])
        return confusionMatrix


    def iou(self):
        """
        Calculate intersection over union
        Returns:
            float
        """

        return metrics.jaccard_score(self.observation, self.prediction)


    def precision(self):
        """
        Calculate precision
        Returns:
            float
        """

        return metrics.precision_score(self.observation, self.prediction)


    def recall(self):
        """
        Calculate recall
        Returns:
            float
        """

        return metrics.recall_score(self.observation, self.prediction)


    def accuracy(self):
        """
        Calculate accuracy
        Returns:
            float
        """

        return metrics.accuracy_score(self.observation, self.prediction)


    def tss(self):
        """
        Calculate true skill statistic (TSS): tp / (tp + fn) + tn / (tn + fp) - 1.
        Not guarded against a zero denominator.
        Returns:
            float
        """

        return self.tp / (self.tp + self.fn) + self.tn / (self.tn + self.fp) - 1


    def false_positive_rate(self):
        """
        Calculate false positive rate: fp / (tn + fp).
        Not guarded against a zero denominator.
        Returns:
             float
        """

        return self.fp / (self.tn + self.fp)

    def F1_measure(self):
        """
        Calculate F1 score.
        Returns:
            float, 0.0 when precision, recall or their sum has a zero denominator
        """

        # The counts are NumPy integers, whose division by zero yields nan with a warning
        # instead of raising; plain Python floats make the ZeroDivisionError fallback reachable.
        tp, fp, fn = float(self.tp), float(self.fp), float(self.fn)

        try:
            precision = tp / (tp + fp)
            recall = tp / (tp + fn)
            f1 = (2 * precision * recall) / (precision + recall)

        except ZeroDivisionError:
            precision = tp / (tp + fp + self.eps)
            recall = tp / (tp + fn + self.eps)
            f1 = (2 * precision * recall) / (precision + recall + self.eps)

        return f1


    def area_under_roc(self):
        """
        Compute Area Under the Curve (AUC)
        Returns:
            float
        """

        return metrics.roc_auc_score(self.observation, self.score)
    
##################################################

def evaluate(evalData, model, buffer, gpu, csv_fn):
    """
    Evaluate model: compute per-class binary metrics over the un-buffered part of every
    chip, print the report and save it as CSV.
    Params:
        evalData (''DataLoader''): Batch grouped data yielding (img, label)
        model: Trained model for validation
        buffer (int): Buffer added to the targeted grid when creating dataset. This allows metrics
            to calculate only at non-buffered region. 0 keeps the full chip
        gpu (bool): Decide whether to move the data to the GPU
        csv_fn (str): filename to save metrics

    Class 0 is not reported; row "class_k" describes class k against all other classes.
    """

    model.eval()

    # trims `buffer` pixels on both sides; the None end keeps buffer == 0 from
    # producing the empty slice that `0:-0` would give
    crop = slice(buffer, -buffer if buffer else None)

    # chip_metrics[k] collects one BinaryMetrics per chip for class k + 1
    chip_metrics = []

    # inference only: do not build the autograd graph
    with torch.no_grad():
        for img, label in evalData:

            # GPU setting
            if gpu:
                img = img.cuda()
                label = label.cuda()

            # Compute metrics
            out = F.softmax(model(img), 1)
            batch, nclass, height, width = out.size()
            # hard class decision, computed once per batch
            hard = out.max(dim=1)[1]

            while len(chip_metrics) < nclass - 1:
                chip_metrics.append([])

            for i in range(batch):
                label_chip = label[i, crop, crop].cpu().numpy()
                predict_chip = hard[i, crop, crop].cpu().numpy()
                for n in range(1, nclass):
                    class_out = out[i, n, crop, crop].cpu().numpy()
                    class_predict = np.where(predict_chip == n, 1, 0)
                    class_label = np.where(label_chip == n, 1, 0)
                    chip_metrics[n - 1].append(BinaryMetrics(class_label, class_out, class_predict))

    # sum() merges the chips of a class through BinaryMetrics.__radd__ / __add__
    class_metrics = [sum(m) for m in chip_metrics]
    report = pd.DataFrame({
        'tss': [m.tss() for m in class_metrics],
        'accuracy': [m.accuracy() for m in class_metrics],
        'precision': [m.precision() for m in class_metrics],
        'recall': [m.recall() for m in class_metrics],
        'fpr': [m.false_positive_rate() for m in class_metrics],
        'F1-score': [m.F1_measure() for m in class_metrics],
        'IoU': [m.iou() for m in class_metrics],
        'AUC': [m.area_under_roc() for m in class_metrics]
    }, index=["class_{}".format(m) for m in range(1, len(class_metrics) + 1)])

    print(report)
    with open(csv_fn, 'w', encoding = 'utf-8-sig') as f:
        report.to_csv(f)


### Custom Data Class

In [11]:
class planetData(Dataset):
    '''
    Dataset of planet scope image files for pytorch architecture
    '''

    def __init__(self, root_dir, catalog, dataSize, buffer, bufferComp, usage, imgPathCols, labelPathCol=None,
                 labelGroup = [0,1,2,3,4], catalogIndex=None, deRotate=(-90, 90), bShiftSubs=(4, 4), trans=None):

        '''
        Params:
            root_dir (str): Directory storing files of variables and labels
            catalog (Pandas.DataFrame): Pandas dataframe giving the list of data and their directories
            dataSize (int): Size of chips that is not buffered, i.e., the size of labels
            buffer (int): Distance to target chips' boundaries measured by number of pixels when extracting images
                (variables), i.e., variables size would be (dsize + buffer * 2) x (dsize + buffer * 2)
            bufferComp (int): Buffer used when creating composite. In the case of Ghana, it is 11.
            usage (str): Usage of the dataset : "train", "validate" or "predict"
            imgPathCols (list): Column names in the catalog referring to image paths
            labelPathCol(str): Column name in the catalog referring to label paths
            labelGroup (list): Group indices of labels to load, where each group corresponds to a specific level of label quality
            catalogIndex (int or None): Row index in catalog to load data for prediction. Only need to be specified when
                usage is "predict"
            deRotate (tuple or None): Range of degrees for rotation
            bShiftSubs (tuple or list): Number of bands or channels on dataset for each brightness shift
            trans (list): Data augmentation methods: one or multiple elements from ['vflip','hflip','dflip', 'rotate',
                'resize', 'shift_brightness']
        Raises:
            ValueError: if usage is not "train", "validate" or "predict"
        Note:
            Provided transformation are:
                1) 'vflip', vertical flip
                2) 'hflip', horizontal flip
                3) 'dflip', diagonal flip
                4) 'rotate', rotation
                5) 'resize', rescale image fitted into the specified data size
                6) 'shift_brightness', shift brightness of images
            Any value out of the range would cause an error
        Note:
            Catalog for train and validate contains at least columns for image path, label path, "usage" and
            "label_group".
            Catalog for prediction contains at least columns for image path, "tile_col", and "tile_row", where the
            "tile_col" and "tile_row" is the relative tile location for naming predictions in Learner
        '''

        self.buffer = buffer
        self.composite_buffer = bufferComp
        self.data_size = dataSize
        self.chip_size = self.data_size + self.buffer * 2

        self.usage = usage
        self.deRotate = deRotate
        self.bshift_subs = bShiftSubs
        self.trans = trans

        self.data_path = root_dir
        self.img_cols = imgPathCols if isinstance(imgPathCols, list) else [imgPathCols]
        self.label_col = labelPathCol

        if self.usage in ("train", "validate"):
            self.catalog = catalog.loc[(catalog['usage'] == self.usage) &
                                       (catalog['label_group'].isin(labelGroup))]
            self.img, self.label = self.get_train_validate_data()
            dataset_name = "training" if self.usage == "train" else "validation"
            print('-------------{} samples loaded in {} dataset-----------'.format(len(self.img), dataset_name))

        elif self.usage == "predict":
            self.catalog = catalog.iloc[catalogIndex]
            self.tile = (self.catalog['tile_col'], self.catalog['tile_row'])
            self.img, self.index, self.meta = self.get_predict_data()

        else:
            raise ValueError("Bad usage value")


    def get_train_validate_data(self):
        '''
        Get pairs of image, label for train and validation
        Returns:
            tuple of list of images and label
        '''

        def resolve(path, data_path):
            # "s3" paths are used as given, anything else is relative to data_path
            return path if path.startswith("s3") else os.path.join(data_path, path)

        def load_label(row, data_path):

            label = load_data(resolve(row[self.label_col], data_path), isLabel=True)
            # zero-pad the label by the buffer on every side
            label = np.pad(label, self.buffer, 'constant')

            return label

        def load_img(row, data_path):

            # the buffered window is derived from the same label file that load_label reads
            dir_label = resolve(row[self.label_col], data_path)
            dir_imgs = [resolve(row[m], data_path) for m in self.img_cols]
            window = get_buffered_window(dir_imgs[0], dir_label, self.buffer)
            img = get_stacked_img(dir_imgs, self.usage, window=window)

            return img

        global list_data # Local function not applicable in parallelism
        def list_data(catalog, data_path):

            catalog["img"] = catalog.apply(lambda row: load_img(row, data_path), axis=1)
            catalog["label"] = catalog.apply(lambda row: load_label(row, data_path), axis=1)

            return catalog.filter(items=['label', 'img'])

        catalog = parallelize_df(self.catalog, list_data, data_path = self.data_path)

        img_ls = catalog['img'].tolist()
        label_ls = catalog['label'].tolist()

        return img_ls, label_ls


    @staticmethod
    def _match_buffer(img, buffer_diff):
        '''
        Bring an (H, W, C) composite from its own buffer to the chip buffer.
        Params:
            img (narray): Composite image in (H, W, C)
            buffer_diff (int): chip buffer minus composite buffer. Positive values reflect-pad
                every side by that many pixels, negative values crop that many pixels from
                every side, 0 returns the image as is
        Returns:
            narray
        '''

        h, w, c = img.shape

        if buffer_diff > 0:
            canvas = np.zeros((h + buffer_diff * 2, w + buffer_diff * 2, c))
            for i in range(c):
                canvas[:, :, i] = np.pad(img[:, :, i], buffer_diff, mode='reflect')
            return canvas

        if buffer_diff < 0:
            # slicing with the negative value itself (img[d:h - d]) would keep only the
            # last |d| rows and columns, so crop with its magnitude
            trim = -buffer_diff
            return img[trim:h - trim, trim:w - trim, :]

        return img


    def get_predict_data(self):
        '''
        Get data for prediction
        Returns:
            list of cropped chips
            list of index representing location of each chip in tile
            dictionary of metadata of score map reconstructed from chips
        '''

        dir_imgs = [self.catalog[m] if self.catalog[m].startswith("s3") \
            else os.path.join(self.data_path, self.catalog[m]) for m in self.img_cols]
        img = get_stacked_img(dir_imgs, self.usage)  # entire composite image in (H, W, C)
        img = self._match_buffer(img, self.buffer - self.composite_buffer)

        meta = get_meta_from_bounds(dir_imgs[0], self.composite_buffer) # meta of composite buffer removed
        img_ls, index_ls = get_chips(img, self.chip_size, self.buffer)

        return img_ls, index_ls, meta


    def __getitem__(self, index):
        """
        Support dataset indexing and apply transformation
        Args:
            index -- Index of each small chips in the dataset
        Returns:
            tuple: (img, label, mask) for "train", (img, label) for "validate",
                (img, chip index) for "predict"; img is converted to a (C, H, W) float tensor
        """

        if self.usage in ["train", "validate"]:
            img = self.img[index]
            label = self.label[index]


            if self.usage == "train":
                # 1 inside the labelled data_size window, 0 in the surrounding buffer
                mask = np.pad(np.ones((self.data_size, self.data_size)), self.buffer, 'constant')
                trans = self.trans
                deRotate = self.deRotate

                if trans:

                    # 0.5 possibility to flip, with the flip type drawn from those requested;
                    # a single requested flip type is enough
                    trans_flip_ls = [m for m in trans if 'flip' in m]
                    if random.randint(0, 1) and trans_flip_ls:
                        trans_flip = random.sample(trans_flip_ls, 1)
                        img, label, mask = flip(img, label, mask, trans_flip[0])

                    # 0.5 possibility to resize
                    if random.randint(0, 1) and 'resize' in trans:
                        img, label, mask = reScale(img, label.astype(np.uint8), mask.astype(np.uint8),
                                                   randResizeCrop=True, diff=True, cenLocate=False)

                    # 0.5 possibility to rotate
                    if random.randint(0, 1) and 'rotate' in trans:
                        img, label, mask = centerRotate(img, label, mask, deRotate)

                    # 0.5 possibility to shift brightness
                    if random.randint(0, 1) and 'shift_brightness' in trans:
                        img = shiftBrightness(img, gammaRange = (0.2, 2), shiftSubset = self.bshift_subs, patchShift=True)

                # numpy to torch
                label = torch.from_numpy(label).long()
                mask = torch.from_numpy(mask).long()
                img = torch.from_numpy(img.transpose((2, 0, 1))).float()

                return img, label, mask

            else:
                # numpy to torch
                label = torch.from_numpy(label).long()
                img = torch.from_numpy(img.transpose((2, 0, 1))).float()

                return img, label

        else:

            img = self.img[index]
            index = self.index[index]

            img = torch.from_numpy(img.transpose((2, 0, 1))).float()

            return img, index


    def __len__(self):
        '''
        Get size of the dataset
        '''

        return len(self.img)


### Model Architecture - Unet

In [12]:
class Conv3x3_bn_relu(nn.Module):
    """3x3 convolution followed by batch normalisation and, optionally, an in-place ReLU."""

    def __init__(self, inch, outch, padding=0, stride=1, dilation=1, groups=1, relu=True):
        super().__init__()
        self.applyRelu = relu

        conv3x3 = nn.Conv2d(inch, outch, kernel_size=3, stride=stride, padding=padding,
                            dilation=dilation, groups=groups)
        self.conv = nn.Sequential(conv3x3, nn.BatchNorm2d(outch))

        # the activation module only exists when it is requested
        if relu:
            self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        features = self.conv(x)
        return self.relu(features) if self.applyRelu else features

#########################

class Conv1x1_bn_relu(nn.Module):
    """1x1 convolution followed by batch normalisation and, optionally, an in-place ReLU."""

    def __init__(self, inch, outch, stride=1, padding=0, dilation=1, groups=1, relu=True):
        super().__init__()
        self.applyRelu = relu

        conv1x1 = nn.Conv2d(inch, outch, kernel_size=1, stride=stride, padding=padding,
                            dilation=dilation, groups=groups)
        self.conv = nn.Sequential(conv1x1, nn.BatchNorm2d(outch))

        # the activation module only exists when it is requested
        if relu:
            self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # the convolution is fed a copy of the input tensor
        features = self.conv(x.clone())
        return self.relu(features) if self.applyRelu else features

#########################
# Consecutive 2 convolution with batch normalization and ReLU activation
class doubleConv(nn.Module):
    """Two stacked Conv3x3_bn_relu blocks with padding 1: inch -> outch -> outch channels."""

    def __init__(self, inch, outch):
        super().__init__()
        self.conv1 = Conv3x3_bn_relu(inch, outch, padding=1)
        self.conv2 = Conv3x3_bn_relu(outch, outch, padding=1)

    def forward(self, x):
        return self.conv2(self.conv1(x))

##################################################

# unet construction
class unet(nn.Module):
    """
    U-Net with six doubleConv encoder stages (64 to 2048 channels) joined by 2x2 max
    pooling, and five decoder stages that each upsample with a transposed convolution,
    concatenate the matching encoder output along the channel dimension and apply a
    doubleConv. A final 1x1 convolution maps 64 channels to classNum scores.

    Params:
        inch (int): Number of input channels
        classNum (int): Number of output channels (classes)

    The input is pooled five times, so height and width presumably need to be multiples
    of 32 for the skip connections to line up -- TODO confirm against the chip sizes used.
    """

    def __init__(self, inch, classNum):
        super().__init__()

        def upsample(cin, cout):
            # transposed convolution that doubles the spatial size
            return nn.ConvTranspose2d(cin, cout, kernel_size=4, stride=2, padding=1)

        # encoder
        self.dlyr1 = doubleConv(inch, 64)
        self.ds = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dlyr2 = doubleConv(64, 128)
        self.dlyr3 = doubleConv(128, 256)
        self.dlyr4 = doubleConv(256, 512)
        self.dlyr5 = doubleConv(512, 1024)
        self.dlyr6 = doubleConv(1024, 2048)

        # decoder
        self.us_init = upsample(2048, 1024)
        self.ulyr_init = doubleConv(2048, 1024)
        self.us6 = upsample(1024, 512)
        self.ulyr6 = doubleConv(1024, 512)
        self.us7 = upsample(512, 256)
        self.ulyr7 = doubleConv(512, 256)
        self.us8 = upsample(256, 128)
        self.ulyr8 = doubleConv(256, 128)
        self.us9 = upsample(128, 64)
        self.ulyr9 = doubleConv(128, 64)
        self.dimTrans = nn.Conv2d(64, classNum, kernel_size=1)

    def forward(self, x):
        # encoder: remember every stage output for its skip connection, then pool
        skips = []
        feat = x
        for encode in (self.dlyr1, self.dlyr2, self.dlyr3, self.dlyr4, self.dlyr5):
            feat = encode(feat)
            skips.append(feat)
            feat = self.ds(feat)
        feat = self.dlyr6(feat)

        # decoder: deepest skip first; channel is the second dimension, after batch
        decoder = (
            (self.us_init, self.ulyr_init),
            (self.us6, self.ulyr6),
            (self.us7, self.ulyr7),
            (self.us8, self.ulyr8),
            (self.us9, self.ulyr9),
        )
        for (upsample, fuse), skip in zip(decoder, reversed(skips)):
            feat = fuse(torch.cat([upsample(feat), skip], 1))

        return self.dimTrans(feat)

### Training, Validation and Prediction

In [54]:
def train(trainData, model, criterion, optimizer, scheduler, trainLoss=None, gpu=True):
    """
    Train model for one pass over trainData
    Params:
        trainData (''DataLoader''): Batch grouped data yielding (img, label, mask)
        model: Model to train
        criterion: Loss factory; it is called without arguments for every batch and the
            result is called with (out, label)
        optimizer: Function for optimization
        scheduler: Update policy for learning rate decay. It is only stepped here (once per
            batch) when it is a ''torch.optim.lr_scheduler.CyclicLR''
        trainLoss (list or None): If given, the average loss of this epoch is appended to it
        gpu (bool, optional): Decide whether to use GPU, default is True
    """

    model.train()

    # decided once so that the flag also exists when trainData yields nothing
    isCyclicLR = isinstance(scheduler, torch.optim.lr_scheduler.CyclicLR)

    # mini batch iteration
    epoch_loss = 0
    i = 0

    for img, label, mask in trainData:

        # forward
        if gpu:
            img = img.cuda()
            label = label.cuda()

        out = model(img)

        # zero out label and scores where mask is 0; the mask follows the model output's
        # device instead of being forced onto CUDA, so gpu=False works
        mask = mask.to(out.device)
        label = label * mask
        out = out * mask.unsqueeze(1)

        loss = criterion()(out, label)
        epoch_loss += loss.item()
        i += 1

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if isCyclicLR:
            scheduler.step()

    print(f'train loss:{epoch_loss / i}')
    if isCyclicLR:
        print(f"LR: {scheduler.get_last_lr()}")

    if trainLoss is not None:
        trainLoss.append(float(epoch_loss / i))

##################################################

def validate(valData, model, criterion, buffer, valLoss, gpu):
    """
        Validate model for one pass over valData
        Params:
            valData (''DataLoader''): Batch grouped data yielding (img, label)
            model: Trained model for validation
            criterion: Loss factory; it is called without arguments for every batch and the
                result is called with (out, label)
            buffer (int): Buffer added to the targeted grid when creating dataset.
                This allows loss to calculate at non-buffered region. 0 keeps the full chip
            valLoss (list or None): If given, the average loss of this epoch is appended to it
            gpu (bool): Decide whether to move the data to the GPU
    """

    model.eval()

    # trims `buffer` pixels on both sides; the None end keeps buffer == 0 from
    # producing the empty slice that `0:-0` would give
    crop = slice(buffer, -buffer if buffer else None)

    # mini batch iteration
    epoch_loss = 0
    i = 0

    # inference only: do not build the autograd graph
    with torch.no_grad():
        for img, label in valData:

            # GPU setting
            if gpu:
                img = img.cuda()
                label = label.cuda()

            out = model(img)

            loss = criterion()(out[:, :, crop, crop], label[:, crop, crop])
            epoch_loss += loss.item()
            i += 1

    print('validation loss: {}'.format(epoch_loss / i))

    if valLoss is not None:
        valLoss.append(float(epoch_loss / i))

##################################################

def predict(predData, model, buffer, gpu, shrinkPixel):
    """
    Predict by tile: run the model on every chip and mosaic the class scores
    (softmax * 100, stored as int8) into one canvas per non-background class.
    Params:
        predData (tuple): (''DataLoader'' yielding (img, index_batch), meta, tile), where
            index_batch holds the row and column position of every chip in the tile and
            meta provides 'height' and 'width'. meta is updated in place with
            'dtype': 'int8'
        model: Trained model for prediction
        buffer (int): Buffer to cut out when writing chips
        gpu (bool): Decide whether to move the data to the GPU
        shrinkPixel (int): pixel numbers to cut out on each side
            before averaging neighbors.
    Returns:
        list of narray, one score canvas for each of the classes 1 .. nclass - 1
    """
    predData, meta, tile = predData
    meta.update({
        'dtype': 'int8'
    })

    model.eval()

    # one canvas per class, created when the first chip of that class is written
    canvas_score_ls = []

    # inference only: do not build the autograd graph
    with torch.no_grad():
        for img, index_batch in predData:

            # GPU setting
            if gpu:
                img = img.cuda()

            out = F.softmax(model(img), 1)
            batch, nclass, height, width = out.size()
            chip_height = height - buffer * 2
            chip_width = width - buffer * 2
            max_index_0 = meta['height'] - chip_height
            max_index_1 = meta['width'] - chip_width

            for i in range(batch):
                row, col = int(index_batch[0][i]), int(index_batch[1][i])

                # chips at index 0 keep their leading buffer, all others drop it
                src_row = buffer if row != 0 else 0
                src_col = buffer if col != 0 else 0
                # chips touching a tile edge are extended by one buffer on that axis
                n_rows = chip_height + (buffer if (row == 0 or row == max_index_0) else 0)
                # the width of the chip governs the column extent (it used to reuse chip_height)
                n_cols = chip_width + (buffer if (col == 0 or col == max_index_1) else 0)

                dst_row = row + src_row
                dst_col = col + src_col

                # only score here; class 0 is skipped
                for n in range(nclass - 1):
                    out_score = out[
                        i, n + 1,
                        src_row:src_row + n_rows,
                        src_col:src_col + n_cols
                    ].cpu().numpy() * 100
                    out_score = out_score.astype(meta['dtype'])
                    score_height, score_width = out_score.shape

                    if n >= len(canvas_score_ls):
                        canvas_score_ls.append(np.zeros(
                            (meta['height'] + buffer * 2,
                             meta['width'] + buffer * 2), dtype=meta['dtype']
                        ))

                    canvas_score_ls[n][
                        dst_row:dst_row + score_height,
                        dst_col:dst_col + score_width
                    ] = out_score

    for j in range(len(canvas_score_ls)):
        canvas_score_ls[j] = canvas_score_ls[j][
            shrinkPixel:meta['height'] + buffer * 2 - shrinkPixel,
            shrinkPixel:meta['width'] + buffer * 2 - shrinkPixel
        ]

    return canvas_score_ls

### Model Compiler

In [3]:
def get_optimizer(optimizer, params, lr, momentum):
    """
    Build a torch optimizer from its (case-insensitive) name.

    Args:
        optimizer (str): one of 'sgd', 'nesterov', 'adam' or 'amsgrad'
        params (iterable): parameters (or parameter groups) to optimize
        lr (float): learning rate
        momentum (float or None): momentum factor, only used by 'sgd' and
            'nesterov'. ``None`` means plain SGD (momentum 0) for 'sgd'.

    Returns:
        torch.optim.Optimizer: the configured optimizer

    Raises:
        ValueError: if the optimizer name is not supported, or 'nesterov'
            is requested without a momentum value
    """

    optimizer = optimizer.lower()
    if optimizer == 'sgd':
        # torch compares momentum against 0.0, so None must be mapped to a number
        return torch.optim.SGD(params, lr, momentum=0.0 if momentum is None else momentum)
    elif optimizer == 'nesterov':
        if momentum is None:
            raise ValueError("'nesterov' optimizer requires a momentum value")
        return torch.optim.SGD(params, lr, momentum=momentum, nesterov=True)
    elif optimizer == 'adam':
        return torch.optim.Adam(params, lr)
    elif optimizer == 'amsgrad':
        return torch.optim.Adam(params, lr, amsgrad=True)
    else:
        raise ValueError(
            "{} currently not supported, " \
            "please customize your optimizer in compiler.py".format(optimizer)
        )


def weighted_average_overlay(predDict, overlayPixels):
    """
    Blend the edges of a centered 2-D array with its 4 neighbors.

    Along each edge that has a neighbor, a strip of ``overlayPixels`` rows
    (or columns) of the center is replaced by a linear blend of the center
    and the facing strip of the neighbor. The center's weight ramps from
    ``1 / overlayPixels`` at the outer edge up to 1 at the inner end of the
    strip; the neighbor gets the complement. Edges are processed in the
    order top, bottom, left, right, so corners are blended twice.

    Args:
        predDict (dict): arrays under the keys 'top', 'center', 'left',
            'right' and 'bottom'; a neighbor may be None to skip that edge
        overlayPixels (int): width of the blended strip in pixels

    Returns:
        numpy.ndarray: ``predDict['center']`` itself, modified in place

    Raises:
        TypeError: if ``predDict`` is not a dictionary
        KeyError: if any of the five required keys is missing
        ValueError: if the center array is not 2-dimensional
    """

    required_keys = ("top", "center", "left", "right", "bottom")
    if not isinstance(predDict, dict):
        raise TypeError(
            "Input must be dictionary containing data for centered image and its 4 neighbors, "
            "including 'top', 'left', 'right', and 'bottom'"
        )
    missing_keys = [k for k in required_keys if k not in predDict]
    if missing_keys:
        raise KeyError(
            "Input must be dictionary containing data for centered image and its 4 neighbors. "
            "Missed {}".format(", ".join(missing_keys))
        )

    target = predDict['center']
    if target.ndim != 2:
        raise ValueError("center array must be 2-dimensional, got shape {}".format(target.shape))

    # Weight of the center along a strip: 1/n at the outer edge ... 1 inside
    rising = 1. / overlayPixels * np.arange(1, overlayPixels + 1)
    falling = np.flip(rising)

    # top: weights vary down the rows, broadcast across the columns
    if predDict['top'] is not None:
        target_weight = rising[:, np.newaxis]
        comp_weight = 1. - target_weight
        target[:overlayPixels, :] = comp_weight * predDict['top'][-overlayPixels:, :] + \
                                    target_weight * target[:overlayPixels, :]
    # bottom
    if predDict['bottom'] is not None:
        target_weight = falling[:, np.newaxis]
        comp_weight = 1. - target_weight
        target[-overlayPixels:, :] = comp_weight * predDict['bottom'][:overlayPixels, :] + \
                                     target_weight * target[-overlayPixels:, :]
    # left: weights vary across the columns, broadcast down the rows
    if predDict['left'] is not None:
        target_weight = rising[np.newaxis, :]
        comp_weight = 1 - target_weight
        target[:, :overlayPixels] = comp_weight * predDict['left'][:, -overlayPixels:] + \
                                    target_weight * target[:, :overlayPixels]
    # right
    if predDict['right'] is not None:
        target_weight = falling[np.newaxis, :]
        comp_weight = 1 - target_weight
        target[:, -overlayPixels:] = comp_weight * predDict['right'][:, :overlayPixels] + \
                                     target_weight * target[:, -overlayPixels:]

    return target


class ModelCompiler:
    """
    Compiler of specified model

    Wraps a pytorch segmentation model with parameter loading, GPU
    placement, training, evaluation, prediction and saving. The working
    directory, output directory and model name are read from the
    module-level ``config`` dictionary.

    Inside the methods, the bare names ``train``, ``validate``, ``evaluate``
    and ``predict`` refer to module-level functions defined elsewhere in
    this file, not to the methods of this class.

    Args:
        model (''nn.Module''): pytorch model for segmentation
        buffer (int): distance to sample edges not considered in optimization
        gpuDevices (list): indices of gpu devices to use
        params_init (str): path to a saved state dict with initial parameters
        freeze_params (list): list of indices for parameters to keep frozen
    """

    # keys a prediction dictionary must carry when neighbors are averaged
    _NEIGHBOR_KEYS = ("top", "center", "left", "right", "bottom")

    def __init__(self, model, buffer, gpuDevices=[0], params_init=None, freeze_params=None):

        self.working_dir = config["working_dir"]
        self.out_dir = config["out_dir"]

        # model
        # copy so the shared default list / the caller's list is never aliased
        self.gpuDevices = list(gpuDevices) if gpuDevices else gpuDevices
        self.model = model

        self.model_name = self.model.__class__.__name__

        if params_init:
            self.load_params(params_init, freeze_params)

        self.buffer = buffer

        # gpu
        self.gpu = torch.cuda.is_available()

        if self.gpu:
            print('----------GPU available----------')
            # GPU setting
            if self.gpuDevices:
                torch.cuda.set_device(self.gpuDevices[0])
                self.model = torch.nn.DataParallel(self.model, device_ids=self.gpuDevices)
            self.model = self.model.cuda()

        num_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print("total number of trainable parameters: {:2.1f}M".format(num_params / 1000000))

        if params_init:
            print("---------- Pre-trained model compiled successfully ----------")
        else:
            print("---------- Vanilla Model compiled successfully ----------")

    def load_params(self, dir_params, freeze_params):
        """
        Load saved parameters into the model and optionally freeze some.

        Only entries whose (prefix-stripped) name exists in the model's own
        state dict are used; everything else in the file is ignored.

        Args:
            dir_params (str): path (or URL whose path component is used) of
                a state dict saved with ``torch.save``
            freeze_params (list or None): positions, in the order of
                ``model.parameters()``, of parameters to exclude from
                training
        """

        params_init = urlparse.urlparse(dir_params)

        # Tensors are moved to CPU below anyway; mapping at load time also lets
        # parameters saved from a GPU be read on a CPU-only runtime.
        inparams = torch.load(params_init.path, map_location="cpu")

        ## overwrite model entries with new parameters
        model_dict = self.model.state_dict()

        # State dicts saved from a DataParallel model prefix every key with
        # "module."; test for the actual prefix so that names which merely
        # contain "module" are not truncated.
        prefix = "module."
        first_key = next(iter(inparams), "")
        if first_key.startswith(prefix):
            inparams_filter = {k[len(prefix):]: v.cpu() for k, v in inparams.items()
                               if k[len(prefix):] in model_dict}
        else:
            inparams_filter = {k: v.cpu() for k, v in inparams.items() if k in model_dict}
        model_dict.update(inparams_filter)
        # load new state dict
        self.model.load_state_dict(model_dict)

        # freeze some layers
        if freeze_params is not None:
            frozen = set(freeze_params)
            for i, p in enumerate(self.model.parameters()):
                if i in frozen:
                    p.requires_grad = False

    def _build_scheduler(self, optimizer, lr_policy, kwargs):
        """Create the scheduler for a lower-cased policy name, or None if the name is unknown."""

        if lr_policy == "StepLR".lower():
            return torch.optim.lr_scheduler.StepLR(
                optimizer,
                step_size=kwargs.get("step_size", 3),
                gamma=kwargs.get("gamma", 0.98)
            )

        if lr_policy == "MultiStepLR".lower():
            return torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=kwargs.get("milestones", [15, 25, 35, 50, 70, 90, 120, 150, 200]),
                gamma=kwargs.get("gamma", 0.5),
            )

        if lr_policy == "ReduceLROnPlateau".lower():
            return torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode=kwargs.get('mode', 'min'),
                factor=kwargs.get('factor', 0.8),
                patience=kwargs.get('patience', 3),
                threshold=kwargs.get('threshold', 0.0001),
                threshold_mode=kwargs.get('threshold_mode', 'rel'),
                min_lr=kwargs.get('min_lr', 3e-6),
                verbose=kwargs.get('verbose', True)
            )

        if lr_policy == "PolynomialLR".lower():
            # PolynomialLR is the custom scheduler defined elsewhere in this file
            return PolynomialLR(
                optimizer,
                max_decay_steps=kwargs.get('max_decay_steps', 100),
                min_learning_rate=kwargs.get('min_learning_rate', 1e-5),
                power=kwargs.get('power', 0.8)
            )

        if lr_policy == "CyclicLR".lower():
            return torch.optim.lr_scheduler.CyclicLR(
                optimizer,
                base_lr=kwargs.get('base_lr', 3e-5),
                max_lr=kwargs.get('max_lr', 0.01),
                step_size_up=kwargs.get('step_size_up', 1100),
                mode=kwargs.get('mode', 'triangular')
            )

        return None

    def fit(self, trainDataset, valDataset, epochs, optimizer_name, lr_init, lr_policy,
            criterion, momentum = None, resume=False, resume_epoch=None, **kwargs):
        """
        Train the model, logging losses to tensorboard and saving checkpoints.

        Creates ``<working_dir>/<out_dir>/<model_name>_ep<epochs>/chkpt`` (names
        taken from ``config``) and changes the current directory to the model
        directory, where the tensorboard event file is written.

        Args:
            trainDataset: loader handed to the module-level ``train``
            valDataset: loader handed to the module-level ``validate``
            epochs (int): index of the last epoch to run (1-based)
            optimizer_name (str): name understood by ``get_optimizer``
            lr_init (float): initial learning rate
            lr_policy (str): 'StepLR', 'MultiStepLR', 'ReduceLROnPlateau',
                'PolynomialLR' or 'CyclicLR' (case-insensitive); any other
                value trains without a scheduler
            criterion: loss handed to ``train`` and ``validate``
            momentum (float or None): momentum handed to ``get_optimizer``
            resume (bool): continue from ``<resume_epoch>_checkpoint.pth.tar``
            resume_epoch (int): epoch number of the checkpoint to resume from
            **kwargs: scheduler options (see ``_build_scheduler``) and
                ``checkpoint_interval`` (int, default 2), the number of
                epochs between checkpoints

        Raises:
            ValueError: if ``resume`` is set and the checkpoint does not exist
        """

        self.model_dir = "{}/{}/{}_ep{}".format(self.working_dir, self.out_dir, config["model_name"], config["epochs"])

        # model_dir already contains working_dir and out_dir, so it is used
        # directly; joining it onto them again doubles a relative path.
        model_dir = Path(self.model_dir)
        self.checkpoint_dir = model_dir / "chkpt"
        os.makedirs(self.checkpoint_dir, exist_ok=True)

        os.chdir(model_dir)

        print("-------------------------- Start training --------------------------")
        start = datetime.now()

        train_loss = []
        val_loss = []

        optimizer = get_optimizer(optimizer_name,
                                  filter(lambda p: p.requires_grad, self.model.parameters()),
                                  lr_init,
                                  momentum)

        # initialize different learning rate scheduler
        lr_policy = lr_policy.lower()
        scheduler = self._build_scheduler(optimizer, lr_policy, kwargs)

        first_epoch = 0
        if resume:
            model_state_file = os.path.join(
                self.checkpoint_dir,
                f"{resume_epoch}_checkpoint.pth.tar"
            )

            # Resume the model from the specified checkpoint in the config file.
            if not os.path.exists(model_state_file):
                raise ValueError(f"{model_state_file} does not exist")

            print(f"Model is resumed from checkpoint: {model_state_file}")
            checkpoint = torch.load(model_state_file)
            first_epoch = checkpoint["epoch"]
            # a run without scheduler stores None, and has nothing to restore
            if scheduler is not None and checkpoint["scheduler"] is not None:
                scheduler.load_state_dict(checkpoint["scheduler"])
            self.model.load_state_dict(checkpoint["state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            train_loss = checkpoint["train_loss"]
            val_loss = checkpoint["val_loss"]

        checkpoint_interval = kwargs.get("checkpoint_interval", 2)

        # event file goes to the current directory, i.e. the model directory
        writer = SummaryWriter('./')
        try:
            for t in range(first_epoch, epochs):

                print(f"[{t+1}/{epochs}]")

                # start fitting
                # train/validate are indexed by epoch below, so they are
                # expected to append one loss per epoch to the lists given
                start_epoch = datetime.now()
                train(trainDataset, self.model, criterion, optimizer, scheduler,
                      gpu=self.gpu, trainLoss=train_loss)
                validate(valDataset, self.model, criterion, self.buffer,
                         gpu=self.gpu, valLoss=val_loss)

                # Update the scheduler
                if lr_policy in ["StepLR".lower(), "MultiStepLR".lower()]:
                    scheduler.step()
                    print(f"LR: {scheduler.get_last_lr()}")

                if lr_policy == "ReduceLROnPlateau".lower():
                    scheduler.step(val_loss[t])

                if lr_policy == "PolynomialLR".lower():
                    scheduler.step(t)
                    print(f"LR: {optimizer.param_groups[0]['lr']}")

                # time spent on single iteration
                print('time:', int((datetime.now() - start_epoch).total_seconds()))

                writer.add_scalars(
                    "Loss",
                    {"train_loss": train_loss[t],
                     "val_loss": val_loss[t]},
                    t + 1)

                if (t+1) % checkpoint_interval == 0:
                    torch.save(
                        {
                            "epoch": t+1,
                            "state_dict": self.model.state_dict(),
                            "scheduler": scheduler.state_dict() if scheduler is not None else None,
                            "optimizer": optimizer.state_dict(),
                            "train_loss": train_loss,
                            "val_loss": val_loss
                        }, os.path.join(
                            self.checkpoint_dir,
                            f"{t+1}_checkpoint.pth.tar")
                    )
        finally:
            # flush the event file even when training is interrupted
            writer.close()

        print(f"-------------------------- Training finished in {int((datetime.now() - start).total_seconds())}s --------------------------")


    def evaluate(self, evalDataset, csv_fn):
        """
        Run the module-level ``evaluate`` on a dataset.

        Args:
            evalDataset: loader handed to ``evaluate``
            csv_fn (str): file name of the metrics report, passed through to
                ``evaluate``
        """

        os.makedirs(Path(self.working_dir) / self.out_dir, exist_ok=True)

        print('-------------------------- Start evaluation --------------------------')
        start = datetime.now()

        evaluate(evalDataset, self.model, self.buffer, csv_fn, self.gpu)

        print(f"-------------------------- Evaluation finished in {int((datetime.now() - start).total_seconds())}s --------------------------")


    def predict(self, predDataset, out_prefix, predBuffer=None,
                averageNeighbors=False, shrinkBuffer=0):
        """
        Predict one tile and write one score raster per class.

        Files are named ``class_<n>_score_c<col>_r<row>.tif`` and written as
        int8 with the tile's metadata.

        Args:
            predDataset: ``(data_loader, meta, tile)`` for a single tile, or,
                when ``averageNeighbors`` is set, a dictionary of such
                tuples under the keys 'top', 'center', 'left', 'right' and
                'bottom' (neighbors may be None)
            out_prefix (str or None): output directory; None selects
                ``<working_dir>/<out_dir>/Inference_output``
            predBuffer (int): buffer handed to the module-level ``predict``
            averageNeighbors (bool): blend the tile edges with the
                predictions of the neighboring tiles
            shrinkBuffer (int): pixels handed to ``predict`` as ``shrinkPixel``

        Raises:
            TypeError: if ``predBuffer`` is None, or ``averageNeighbors`` is
                set and ``predDataset`` is not a dictionary
            KeyError: if ``averageNeighbors`` is set and a required key is
                missing from ``predDataset``
        """

        # predDataset must be dictionary containing target and all 4 neighbors if averageNeighbors
        if averageNeighbors:
            if not isinstance(predDataset, dict):
                raise TypeError(
                    "predDataset must be dictionary containing data for centered image and its 4 neighbors when "
                    "averageNeighbors set to be True, including 'top', 'left', 'right', 'bottom'"
                )
            key_miss_ls = [k for k in self._NEIGHBOR_KEYS if k not in predDataset]
            if key_miss_ls:
                raise KeyError(
                    "predDataset must be dictionary containing data for centered image and its 4 neighbors when "
                    "averageNeighbors set to be True. Missed {}".format(", ".join(key_miss_ls))
                )

        if predBuffer is None:
            raise TypeError("predBuffer must be an integer number of pixels, got None")

        print('-------------------------- Start prediction --------------------------')
        start = datetime.now()

        if out_prefix is None:
            out_prefix = Path(self.working_dir) / self.out_dir / "Inference_output"
        # also create a directory supplied by the caller
        out_prefix = Path(out_prefix)
        out_prefix.mkdir(parents=True, exist_ok=True)

        _, meta, tile = predDataset["center"] if isinstance(predDataset, dict) else predDataset
        name_score = 'score_c{}_r{}.tif'.format(tile[0], tile[1])
        # Updated in place on purpose: the same dict travels inside predDataset.
        # NOTE(review): the module-level predict presumably reads the dtype
        # from it - confirm before replacing this with a copy.
        meta.update({
            'dtype': 'int8'
        })

        new_buffer = predBuffer - shrinkBuffer

        if averageNeighbors:
            scores_dict = {k: predict(v, self.model, predBuffer, gpu=self.gpu, shrinkPixel=shrinkBuffer) if v
                           else None for k, v in predDataset.items()}

            nclass = len(list(scores_dict['center']))
            overlay_pixs = new_buffer * 2

            for n in range(nclass):
                score_dict = {k: v[n] if v else None for k, v in scores_dict.items()}
                score = weighted_average_overlay(score_dict, overlay_pixs)
                # crop the remaining buffer: rows by height, columns by width
                score = score[new_buffer: meta['height'] + new_buffer, new_buffer: meta['width'] + new_buffer]
                score = np.expand_dims(score, axis=0).astype(meta['dtype'])

                # write to Drive
                updated_name_score = "class_{}_".format(n) + name_score
                with rasterio.open(out_prefix / updated_name_score, "w", **meta) as dst:
                    dst.write(score)

        # when not averageNeighbors
        else:
            scores = predict(predDataset, self.model, predBuffer, gpu=self.gpu, shrinkPixel=shrinkBuffer)
            # write score of each non-background class
            for n, class_score in enumerate(scores):
                canvas = class_score[new_buffer: meta['height'] + new_buffer, new_buffer: meta['width'] + new_buffer]
                canvas = np.expand_dims(canvas, axis=0).astype(meta['dtype'])

                updated_name_score = "class_{}_".format(n) + name_score
                with rasterio.open(out_prefix / updated_name_score, "w", **meta) as dst:
                    dst.write(canvas)

        print('-------------------------- Prediction finished in {}s --------------------------' \
              .format(int((datetime.now() - start).total_seconds())))

    def save(self, object = "params"):
        """
        Save the model into the checkpoint directory created by ``fit``.

        Args:
            object (str): 'params' saves the state dict, 'model' saves the
                whole model object; both use the file name
                ``<model_name>_final_state.pth``

        Raises:
            AttributeError: if called before ``fit`` has set the checkpoint
                directory
            ValueError: for any other value of ``object``
        """

        if getattr(self, "checkpoint_dir", None) is None:
            raise AttributeError("checkpoint_dir is not set; call fit() before save()")

        out_file = os.path.join(self.checkpoint_dir, "{}_final_state.pth".format(config["model_name"]))

        if object == "params":
            torch.save(self.model.state_dict(), out_file)

            print("--------------------- Model parameters is saved to disk ---------------------")

        elif object == "model":
            torch.save(self.model, out_file)

        else:
            raise ValueError("Improper object type.")

## Train the Model

### Define parameters

We first set up a configuration dictionary with all the parameters the model needs to run

In [4]:
#!ls /content/gdrive/MyDrive/teaching/geog287387/data/fieldmapping/

In [None]:
# Central configuration dictionary. It is read by the cells below and, as a
# global, by ModelCompiler (working_dir, out_dir, model_name and epochs).
config = {
    
    #I/O setup
    "working_dir" : "/content/gdrive/MyDrive/working_folder",
    "out_dir" : "testing_11",
    
    "root_dir" : "/content/gdrive/MyDrive/field_mapping/",
    "catalog_train_fn" : "catalog_ghana_ecaas_ejura_tain.csv",
    "catalog_pred_fn" : "catalog_predict_nicfi_retiled_ejura_tain_2020-11.csv",
    
    # Train Dataset and Loader
    "patch_size" : 200,
    "buffer" : 12,
    "composite_buffer" : 11,
    "img_path_cols": ['dir_gs', 'dir_os'],
    "label_path_col": "dir_label",
    "label_group_train" : [2, 3, 4],
    "transformation" : ['vflip', 'hflip', 'rotate', 'resize', 'shift_brightness'],
    "rotate_degree" : [-90, 90],
    "brightness_shift_subsets": [4, 4],
    "train_batch" : 12,
    
    # Validation Dataset and Loader
    # NOTE: the key is spelled "label_grou_val" (sic) and is read under that
    # same spelling when the validation dataset is built; rename both together.
    "label_grou_val": [3, 4],
    "val_batch": 1,
    
    # Model
    # model_name is lower-cased to look up the model class, and is also used
    # in the names of the model directory and the saved parameter file
    "model_name" : "Unet",
    "img_bands" : 8,
    "class_numbers" : 3,
    
    # Compiler
    "gpus" : [0],
    # init_params is appended to root_dir when the model is compiled
    "init_params" : "unet_params.pth",
    "freeze_params": list(range(58)),
    #"init_params" : 
    #    "/content/gdrive/MyDrive/working_folder/params/Unet_ep5/"\
    #    "chkpt/Unet_final_state.pth",    
    
    # Model fitting
    # epochs is also part of the model directory name ("<model_name>_ep<epochs>")
    "epochs" : 5,
    "optimizer" : 'nesterov',
    "momentum" : 0.95,
    "lr_init" : 0.01,
    "LR_policy" : "StepLR",
    "criterion" :  BalancedTverskyFocalLoss(gamma = 0.9),
    "resume" : False,
    "resume_epoch" : None,
    
    # Evaluation report on validation dataset
    "val_metric_fname" : "validate_metrics.csv",
    
    # Prediction (Inference)
    "patch_size_pred" : 250,
    "buffer_pred" : 179,
    "composite_buffer_pred" : 179,
    "batch_pred" : 2,
    "average_neighbors": False,
    "shrink_pixels": 54,

    # None lets ModelCompiler.predict fall back to
    # <working_dir>/<out_dir>/Inference_output
    "out_prefix": None  
}

# Create the working directory if it is missing and make it the current directory
if not os.path.exists(config["working_dir"]):
    os.makedirs(config["working_dir"])
os.chdir(config["working_dir"])

### Read in training data 

Load the catalog

In [25]:
# Load the training catalog (csv) from the data root and display it
train_catalog = pd.read_csv(Path(config["root_dir"]) / config["catalog_train_fn"])
train_catalog

Run the dataloader, which takes a few minutes. This compiles the training dataset consisting of PlanetScope basemap imagery and labels. It chips up the imagery, pairs the chips with labels, puts the sample pairs into mini-batches, performs augmentations, and loads them onto the GPU to be fed into the model.


In [None]:
# Training dataset: every option comes from the configuration dictionary
train_data = planetData(
    root_dir=config["root_dir"],
    catalog=train_catalog,
    dataSize=config["patch_size"],
    buffer=config["buffer"],
    bufferComp=config["composite_buffer"],
    usage="train",
    imgPathCols=config["img_path_cols"],
    labelPathCol=config["label_path_col"],
    labelGroup=config["label_group_train"],
    deRotate=config["rotate_degree"],
    bShiftSubs=config["brightness_shift_subsets"],
    trans=config["transformation"],
)

# Shuffled mini-batches for training
train_dataloader = DataLoader(
    train_data,
    batch_size=config["train_batch"],
    shuffle=True,
)

And then do the same for the validation sample...

In [None]:
# Validation dataset: built from the same catalog as the training data, but
# with usage "validate", its own label groups and no augmentation options
validate_data = planetData(root_dir=config["root_dir"],
                           catalog=train_catalog,
                           dataSize=config["patch_size"],
                           buffer=config["buffer"],
                           bufferComp=config["composite_buffer"],
                           usage="validate",
                           imgPathCols=config["img_path_cols"],
                           labelPathCol=config["label_path_col"],
                           labelGroup=config["label_grou_val"])

# Fixed order, so validation results are comparable between epochs
validate_dataloader = DataLoader(validate_data,
                                 batch_size=config["val_batch"],
                                 shuffle=False)

### Initialize and compile the model

Initialize

In [29]:
# Look up the model class by its lower-cased name among the notebook's
# globals instead of eval()-ing a configuration string as code.
model_key = config["model_name"].lower()
if model_key not in globals():
    raise NameError("name '{}' is not defined".format(model_key))
model = globals()[model_key](config["img_bands"], config["class_numbers"])

Compile

In [None]:

# Wrap the network in the compiler. The initial parameter file is resolved
# with os.path.join so that it works whether or not root_dir ends with a
# slash, and so that an absolute "init_params" path is used as given.
model = ModelCompiler(
    model=model, 
    buffer=config["buffer"], 
    gpuDevices=config["gpus"], 
    params_init=os.path.join(config["root_dir"], config["init_params"]),
    freeze_params=config["freeze_params"]
)

### Fit the model

In [None]:
# Train with the optimizer, learning-rate policy and loss from the configuration
model.fit(trainDataset=train_dataloader,
          valDataset=validate_dataloader,
          epochs=config["epochs"],
          optimizer_name=config["optimizer"],
          lr_init=config["lr_init"],
          lr_policy=config["LR_policy"],
          criterion=config["criterion"],
          momentum=config["momentum"],
          resume=config["resume"],
          resume_epoch=config["resume_epoch"])

Save the trained model parameters for further use, either for fine-tuning or for making predictions.

In [None]:
# Store the state dict (not the whole model object) in the checkpoint directory
model.save("params")

### Evaluate model performance

Against the validation dataset, saving the report as a csv file.


In [None]:
# Switch to the output directory, then run the evaluation with the
# configured report file name
os.chdir(os.path.join(config["working_dir"], config["out_dir"]))
model.evaluate(validate_dataloader, config["val_metric_fname"])

In [None]:
# Copy the metrics report from the current directory into the model folder.
# NOTE(review): file name and destination are hard-coded; keep them in sync
# with "val_metric_fname", "working_dir", "out_dir", "model_name" and "epochs".
!cp validate_metrics.csv "/content/gdrive/MyDrive/working_folder/testing_11/Unet_ep5"

## Create Prediction Maps

Using the trained model, we will predict cropland locations on several PlanetScope image tiles.

### Prediction function

In [49]:
# Tile reader: builds the prediction dataset and loader for one row of the
# prediction csv (i.e. one tile), optionally together with its 4 neighbors.
def load_pred_data(dir_data, pred_patch_size, pred_buffer, pred_composite_buffer, 
                   pred_batch, catalog, catalog_row, img_path_cols, average_neighbors=False):
    """
    Load the prediction data of one tile, optionally with its neighbors.

    Args:
        dir_data (str): data root handed to ``planetData``
        pred_patch_size (int): patch size handed to ``planetData``
        pred_buffer (int): buffer handed to ``planetData``
        pred_composite_buffer (int): composite buffer handed to ``planetData``
        pred_batch (int): batch size of the data loader
        catalog (pandas.DataFrame): prediction catalog; the columns
            ``tile_col`` and ``tile_row`` are needed when
            ``average_neighbors`` is set. The catalog is not modified.
        catalog_row (int): catalog row of the tile to predict
        img_path_cols (list): image path columns handed to ``planetData``
        average_neighbors (bool): also load the tiles above, below, left
            and right of the requested tile

    Returns:
        tuple or dict: ``(data_loader, meta, tile)`` for the tile, or, with
        ``average_neighbors``, a dictionary of such tuples under the keys
        'center', 'top', 'left', 'right' and 'bottom', holding None for
        neighbors that are not in the catalog
    """
    def load_single_tile(catalog_ind = catalog_row):
        dataset = planetData(dir_data, catalog, pred_patch_size, pred_buffer, 
                             pred_composite_buffer, "predict", 
                             catalogIndex=catalog_ind, imgPathCols=img_path_cols)
        data_loader = DataLoader(dataset, batch_size=pred_batch, shuffle=False)
        return data_loader, dataset.meta, dataset.tile

    # direct crop edge pixels
    if not average_neighbors:
        return load_single_tile()

    center = catalog.iloc[catalog_row]
    tile_col, tile_row = center.tile_col, center.tile_row

    def find_tile(col, row):
        # index label of the first catalog row for this tile, None if absent
        matches = catalog.index[(catalog["tile_col"] == col) & (catalog["tile_row"] == row)]
        return matches[0] if len(matches) else None

    row_dict = {
        "center": catalog_row,
        "top": find_tile(tile_col, tile_row - 1),
        "left": find_tile(tile_col - 1, tile_row),
        "right": find_tile(tile_col + 1, tile_row),
        "bottom": find_tile(tile_col, tile_row + 1),
    }
    return {k: load_single_tile(catalog_ind=ind) if ind is not None else None
            for k, ind in row_dict.items()}

### Read prediction catalog

In [None]:
# Read the prediction catalog from the data root
pred_catalog = pd.read_csv(Path(config["root_dir"]) / config["catalog_pred_fn"])

# index values of the catalog rows whose type is 'center': the tiles to predict
inds = pred_catalog.index[pred_catalog["type"] == "center"].values

### Run predictions

# Predict tile by tile: load the data of one catalog row, then let the
# compiled model write its score maps
for i in inds:
    print(f"Predicting on index {i}")
    pred_dataloader = load_pred_data(config['root_dir'],
                                     config['patch_size_pred'],
                                     config['buffer_pred'],
                                     config["composite_buffer_pred"],
                                     config['batch_pred'],
                                     pred_catalog,
                                     i,
                                     config['img_path_cols'],
                                     average_neighbors=config['average_neighbors'])
    p = model.predict(pred_dataloader,
                      config["out_prefix"],
                      config['buffer_pred'],
                      averageNeighbors=config['average_neighbors'],
                      shrinkBuffer=config['shrink_pixels'])

### Plot predictions

Load in predictions

In [None]:
# Default output folder of the prediction step
pred_path = f'{config["working_dir"]}/{config["out_dir"]}/Inference_output'
# Every entry of the folder is treated as a score raster (order as returned by os.listdir)
preds = [f'{pred_path}/{file}' for file in os.listdir(pred_path)]
# The datasets stay open for the cells that follow; close them when done
score_maps = [rasterio.open(p) for p in preds]