In [None]:
import os
import sys
import random
from random import shuffle
import gc
import math
from pathlib import Path
from datetime import datetime
from collections import Counter, defaultdict, OrderedDict
import urllib.parse as urlparse
import boto3
import shutil
import tqdm

import numpy as np
from numpy import inf
import pandas as pd
from sklearn import metrics
import rasterio
import pickle
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch import optim
from torch.optim.lr_scheduler import _LRScheduler
import torch.utils.data
from torch.utils.data import Dataset, DataLoader, Sampler
import torch.nn.utils.rnn as rnn_util
from tensorboardX import SummaryWriter

from IPython.core.debugger import set_trace

In [None]:
print("PyTorch version: {}".format(torch.__version__))
print("Cuda version : {}".format(torch.version.cuda))
print('CUDNN version:', torch.backends.cudnn.version())
print('Number of available GPU Devices:', torch.cuda.device_count())
# torch.cuda.current_device() initialises CUDA and raises on CPU-only machines,
# so only query it when a GPU is actually available.
if torch.cuda.is_available():
    print("current GPU Device: {}".format(torch.cuda.current_device()))
else:
    print("current GPU Device: none (CUDA is not available)")

In [None]:
def make_reproducable(seed = 42, cudnn = True):
    """Seed every random number generator used in this notebook from one shared seed.

    Args:
      seed (int) -- seed passed to Python's `random`, NumPy's global RNG and PyTorch. Default is 42.
      cudnn (bool) -- if True, also seed all CUDA devices and force cuDNN into deterministic
                      mode (deterministic kernels, auto-tuner disabled). Default is True.

    NOTE: PYTHONHASHSEED is read when an interpreter starts, so setting it here only affects
          Python processes launched after this call, not the running kernel.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    if cudnn:
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        # The cuDNN auto-tuner may pick different algorithms from run to run,
        # which breaks reproducibility even with deterministic=True.
        torch.backends.cudnn.benchmark = False

make_reproducable()

## Making the raw dataset analysis-ready

### Rename the data files so that the grid IDs in the filenames have leading zeros

In [None]:
import os
from pathlib import Path

def rename_w_leading_0s(root_dir, tif_content, num_digits, ftype, country, source, verbose = False):
    """
    Append the zero-padded grid ID to data/label filenames so that sorting filenames sorts by grid ID.

    root_dir (str) -- path to the main directory in which the data resides. Example: "C:/My_documents/Data".
    tif_content (str) -- "mask" for label folders or "data" for image folders.
    num_digits (int) -- width to which the grid ID is zero-padded (e.g. 6 turns "123" into "000123").
    ftype (str) -- only used when tif_content == "mask": "tif" renames .tif files, "npy" renames
                   .npy and .json files. When tif_content == "data", .tif and .json files are renamed.
    country (str) -- country folder in which the images and labels reside.
    source (str) -- folder name of the resource: either the remote sensing sensor used to acquire
                    the images or the label dataset.
    verbose (bool) -- if True, prints every old and new name. Default is False.

    Expected layout: <root_dir>/<country>/<source>/<sub_folder>/<file>. Entries of <source> that
    are not folders are ignored.
      * "mask": the last "_"-separated token of <sub_folder> is taken as the grid ID, zero-padded
        and appended: "label.tif" -> "label_000123.tif".
      * "data": the last five "_"-separated tokens of <sub_folder> are appended, with the second of
        them (presumably the grid ID) zero-padded.

    Raises:
      ValueError -- if tif_content == "mask" and ftype is neither "tif" nor "npy".

    NOTE: renaming is not idempotent; running it twice appends the suffix twice.
    """
    assert tif_content in ["mask", "data"]
    assert country in ["SouthSudan", "Ghana"]
    assert source in ["Sentinel-1", "Sentinel-2", "Labels"]

    # Extensions that are renamed in each mode.
    if tif_content == "mask":
        if ftype == "tif":
            extensions = (".tif",)
        elif ftype == "npy":
            extensions = (".npy", ".json")
        else:
            raise ValueError("Unsupported ftype for masks: {}".format(ftype))
    else:
        extensions = (".tif", ".json")

    path_to_src = Path(root_dir) / country / source
    renames = []

    for dirname in sorted(os.listdir(path_to_src)):
        dir_path = path_to_src / dirname
        if not dir_path.is_dir():
            continue

        if tif_content == "mask":
            # Pad the grid ID itself (not the whole filename) with leading zeros.
            suffix = dirname.split("_")[-1].zfill(num_digits)
        else:
            string_list = dirname.split("_")[-5:]
            string_list[1] = string_list[1].zfill(num_digits)
            suffix = "_".join(string_list)

        for filename in os.listdir(dir_path):
            for ext in extensions:
                if filename.endswith(ext):
                    stem = filename[:-len(ext)]
                    new_name = "{}_{}{}".format(stem, suffix, ext)
                    renames.append((dir_path / filename, dir_path / new_name))
                    break

    # Collect first, rename afterwards, so directory listings are not mutated while iterating.
    for old, new in renames:
        os.rename(old, new)
        if verbose:
            print('Renaming {} to {}'.format(old, new))

In [None]:
# Dataset root: set the CROPTYPE_ROOT_DIR environment variable instead of editing this path.
root_dir = os.environ.get("CROPTYPE_ROOT_DIR", "E:/CropType")
tif_content = "data"  # "data" renames image tiles, "mask" renames label tiles
num_digits = 6        # zero-padded width of grid IDs, e.g. "123" -> "000123"
ftype = "tif"
country = "Ghana"
source = "Sentinel-2"

In [None]:
rename_w_leading_0s(root_dir=root_dir, tif_content=tif_content, num_digits=num_digits,
                    ftype=ftype, country=country, source=source, verbose=False)

### Remove the grid IDs that have no recorded labels

In [None]:
"""
import os
from pathlib import Path
import numpy as np
import pickle
import rasterio
from collections import Counter
import json
"""

def get_grid_nums(root_dir, country, source, ftype, verbose = False):
    """
    List the files of one source folder together with the grid IDs encoded in their names.

    Args:
      root_dir (str) -- base data directory.
      country (str) -- country folder name; for .tif files only "Ghana" and "SouthSudan" are supported.
      source (str) -- source folder inside the country folder (e.g. "Sentinel-2").
      ftype (str) -- "tif" to list .tif files, "npy" to list .npy and .json files.
      verbose (bool) -- print every file together with the grid ID parsed from it.

    Returns:
      (grid_numbers, files) -- grid_numbers is the sorted list of grid IDs, one entry per file (so an
      ID repeats once per timestamp); files is the sorted list of file paths. The two lists are
      sorted independently and are therefore not guaranteed to be index-aligned.

    Raises:
      ValueError -- for an unsupported ftype, or an unsupported country with ftype == "tif".
    """
    src_path = Path(root_dir) / country / source

    # Position of the grid ID among the "_"-separated tokens of the path, counted from the end.
    if ftype == "tif":
        extensions = (".tif",)
        if country == "Ghana":
            id_position = -4
        elif country == "SouthSudan":
            id_position = -3
        else:
            raise ValueError("Unsupported country for .tif files: {}".format(country))
    elif ftype == "npy":
        extensions = (".npy", ".json")
        # NOTE(review): same position as Ghana .tif names -- confirm against the .npy naming scheme.
        id_position = -4
    else:
        raise ValueError("Unsupported file type: {}".format(ftype))

    files = sorted(Path(dirpath) / f for (dirpath, _, filenames) in os.walk(src_path)
                   for f in filenames if f.endswith(extensions))

    # Parse each ID from its own file so the verbose report pairs IDs and files correctly.
    pairs = [(str(f).split("_")[id_position], f) for f in files]

    if verbose:
        for grid_id, f in pairs:
            print('Grid-ID: {}\nAssociated file with same ID {}\n'.format(grid_id, f))

    grid_numbers = sorted(grid_id for grid_id, _ in pairs)
    return grid_numbers, files

##################################################

def get_empty_grids(root_dir, country, source, lbl_fldrname, verbose = True, ftype = "tif"):
    """
    Find grid IDs whose label raster has no labelled pixel and that also exist in the source folder.

    Args:
      root_dir (str) -- base data directory.
      country (str) -- country folder name.
      source (str) -- source folder whose grids are checked (e.g. "Sentinel-2").
      lbl_fldrname (str) -- folder (inside the country folder) holding the label .tif files.
      verbose (bool) -- print summary counts and the IDs to delete. Default is True.
      ftype (str) -- file type of the source files, forwarded to get_grid_nums. Default is "tif"
                     (previously this value was silently read from a global variable).

    Returns:
      set of str -- grid IDs whose label raster has no pixel > 0 and which appear among the
      source grid IDs.
    """
    lbl_dir = Path(root_dir) / country / lbl_fldrname

    # Sort the paths first and derive each ID from its own path, so (path, id) stay paired.
    mask_fnames = sorted(Path(dirpath) / f for (dirpath, _, filenames) in os.walk(lbl_dir)
                         for f in filenames if f.endswith('.tif'))
    mask_ids = [str(f).split('_')[-1].replace('.tif', '') for f in mask_fnames]

    valid_pixels_list = []
    empty_masks = []
    for mask_fname, mask_id in zip(mask_fnames, mask_ids):
        with rasterio.open(mask_fname) as src:
            # A pixel counts as labelled when its value is > 0.
            valid_pixels = np.sum(src.read() > 0)
        valid_pixels_list.append((mask_id, valid_pixels))
        if valid_pixels == 0:
            empty_masks.append(mask_id)

    grid_numbers, _ = get_grid_nums(root_dir, country, source, ftype, verbose = False)

    delete_me = set(empty_masks) & set(grid_numbers)

    if verbose:
        print("valid pixels list: ", len(valid_pixels_list))
        print('empty masks: ', len(empty_masks))
        print('delete me length: ', len(delete_me))
        print('delete me: ', sorted(delete_me))

    return delete_me

##################################################

def remove_irrelevant_files(root_dir, country, source, delete_list, ftype, verbose = True):
    """
    Delete every source file that belongs to one of the grid IDs in delete_list.

    Args:
      root_dir (str) -- base data directory.
      country (str) -- country folder name.
      source (str) -- source folder whose files are deleted (e.g. "Sentinel-2").
      delete_list (iterable of str) -- grid IDs to remove (e.g. the output of get_empty_grids).
      ftype (str) -- file type of the source files, forwarded to get_grid_nums.
      verbose (bool) -- print each grid and the files removed for it. Default is True.

    A file belongs to a grid when "_<grid_id>" occurs anywhere in its path.
    """
    if len(delete_list) == 0:
        print("There is no empty grid to remove.")
        return

    _, source_files = get_grid_nums(root_dir, country, source, ftype, verbose = False)

    for grid_to_rm in sorted(delete_list):
        files_to_rm = [str(f) for f in source_files if ''.join(['_', grid_to_rm]) in str(f)]

        if verbose:
            print("grid to remove: {}\n".format(grid_to_rm))
            print("files to remove: {}\n".format(files_to_rm))

        # Remove files inside the loop so every grid is handled, not only the last one.
        for f in files_to_rm:
            os.remove(f)

In [None]:
# Dataset root: set the CROPTYPE_ROOT_DIR environment variable instead of editing this path.
root_dir = os.environ.get("CROPTYPE_ROOT_DIR", "E:/CropType")
lbl_fldrname = "Labels"
ftype = "tif"
country = "Ghana"
source = "Sentinel-2"

In [None]:
grids_to_delete = get_empty_grids(root_dir=root_dir, country=country, source=source,
                                  lbl_fldrname=lbl_fldrname)

remove_irrelevant_files(root_dir=root_dir, country=country, source=source,
                        delete_list=grids_to_delete, ftype=ftype, verbose=True)

### Normalize tiles, add a day-of-year (DOY) band, add spectral indices to Sentinel-2, and save a temporal stack for each grid as a .npy file

In [None]:
"""
import numpy as np
"""

def normalize(grid, source, country, norm_type = "z-value"):
    r""" Normalization based on the chosen normalization type.
    Args:
      grid (numpy array) -- grid to be normalized, bands on the first axis (bands x rows x columns).
                            It is not modified.
      source (str) -- Satellite sensor ("Sentinel-1" or "Sentinel-2").
      country (str) -- Geographic region where the dataset is taken (statistics exist for "Ghana" only;
                       other values raise KeyError in z-value mode).
      norm_type (str) -- "z-value" (standardization, default) or "min/max".
    Returns:
      grid (numpy array) -- a normalized floating-point copy of the input grid. Floating inputs keep
                            their dtype; integer inputs are converted to float64.
    Raises:
      ValueError -- if norm_type is not recognized.

    NOTE: z-value normalization uses fixed per-band statistics (presumably computed over the whole
          temporal extent of the dataset).
          min/max normalization uses the statistics of each individual grid; zeros are ignored when
          finding the extremes but are still rescaled.
    """

    MEANS = {"Sentinel-1": {"Ghana": np.array([-10.50, -17.24, 1.17])},
             "Sentinel-2": {"Ghana": np.array([2620.00, 2519.89, 2630.31, 2739.81, 3225.22,
                                               3562.64, 3356.57, 3788.05, 2915.40, 2102.65])}
            }

    STDS = {"Sentinel-1": {"Ghana": np.array([3.57, 4.86, 5.60])},
            "Sentinel-2": {"Ghana": np.array([2171.62, 2085.69, 2174.37, 2084.56, 2058.97,
                                              2117.31, 1988.70, 2099.78, 1209.48, 918.19])}
           }

    # Work on a floating-point copy: the caller's array stays untouched and integer rasters are
    # not truncated when the normalized values are written back band by band.
    grid = np.asarray(grid)
    grid = grid.copy() if np.issubdtype(grid.dtype, np.floating) else grid.astype(float)

    num_bands = grid.shape[0]

    if norm_type == "z-value":
        means = MEANS[source][country]
        stds = STDS[source][country]

        for i in range(num_bands):
            grid[i, :, :] = (grid[i, :, :] - means[i]) / stds[i]

    elif norm_type == "min/max":
        nan_Corr_grid = np.where(grid == 0, np.nan, grid)
        grid_min = np.nanmin(nan_Corr_grid)
        grid_max = np.nanmax(nan_Corr_grid)
        grid = (grid - grid_min) / (grid_max - grid_min)

    else:
        raise ValueError("Normlaization type is not recognized.")

    return grid

##################################################

def date2doy(date, in_shp):
    r"""
    Convert a date string into a normalized day-of-year (DOY) band.

    Parameters:
        date (str) -- date in "YYYY_MM_DD" format, e.g. "2017_03_09".
        in_shp (tuple) -- (bands, rows, columns) shape of the tile the band will be stacked with;
                          only rows and columns are used.
    Returns:
        numpy array of shape (1, rows, columns), float64, filled with (doy - 177.5) / 177.5.
    """
    doy = datetime.strptime(date, '%Y_%m_%d').timetuple().tm_yday

    # Rescale DOY (1..366) to roughly [-1, 1.06] so it is on a scale similar to the z-normalized bands.
    norm_doy = (doy - 177.5) / 177.5

    _, rows, cols = in_shp
    return np.full((1, rows, cols), norm_doy)

##################################################

def get_spectral_indices(img, Channel_first = True):
    """
    Compute five spectral indices from a 10-band Sentinel-2 tile.

    Args:
      img (numpy array) -- 10-band image, either (bands, rows, columns) or (rows, columns, bands),
                           with bands ordered as in S2_BANDS below.
      Channel_first (bool) -- True if the band axis is first, False if it is last. Default is True.
    Returns:
      numpy array of shape (5, rows, columns) -- NDVI, EVI, NDWI, RGVI and BI, always channel-first.

    Integer inputs are converted to float64 first, so products such as red * red cannot overflow.
    Pixels whose denominators are zero produce NaN/inf.
    """
    img = np.asarray(img)
    ch = img.shape[0] if Channel_first else img.shape[-1]
    assert ch == 10, "Either Number of bands are incorrect or its not located in first or last dimensions."

    S2_BANDS = {"BLUE": 0, "GREEN": 1, "RED": 2, "RDED1": 3, "RDED2": 4, "RDED3": 5, "NIR": 6, "RDED4": 7, "SWIR1": 8, "SWIR2": 9}

    # EVI coefficients.
    G = 2.5
    C1 = 6
    C2 = 7.5

    if not np.issubdtype(img.dtype, np.floating):
        img = img.astype(float)
    if not Channel_first:
        # Move the band axis to the front so a single indexing path serves both layouts.
        img = np.moveaxis(img, -1, 0)

    blue = img[S2_BANDS["BLUE"]]
    green = img[S2_BANDS["GREEN"]]
    red = img[S2_BANDS["RED"]]
    nir = img[S2_BANDS["NIR"]]
    swir1 = img[S2_BANDS["SWIR1"]]
    swir2 = img[S2_BANDS["SWIR2"]]

    # Normalized Difference Vegetation Index
    ndvi = (nir - red) / (nir + red)

    # Enhanced Vegetation Index
    evi = G * (nir - red) / (nir + C1 * red - C2 * blue + 1)

    # Normalized Difference Water Index
    ndwi = (nir - swir2) / (nir + swir2)

    # Rice Growth Vegetation Index. Absorption properties of the middle infrared band cause a low
    # reflectance of rice plants in this channel (Lillesand & Kiefer 1994); in irrigated rice fields,
    # especially in early transplanting periods, the water environment dominates the rice spectrum.
    # Nuarsa, I. W., Nishio, F., & Hongo, C. (2011). Spectral characteristics and mapping of rice
    # plants using multi-temporal Landsat data. Journal of Agricultural Science.
    rgvi = 1 - (blue + red) / (nir + swir1 + swir2)

    # Brightness-style index based on the red/green ratio.
    bi = np.sqrt(((red * red) / (green * green)) / 2)

    return np.stack([ndvi, evi, ndwi, rgvi, bi], axis=0)

##################################################

"""
import os
from pathlib import Path
import numpy as np
import pickle
import rasterio
from collections import Counter
import json
"""

def get_img_cube(root_dir, country, source, out_format, lbl_fldrname, verbose):
    """
    Stack the per-date .tif tiles of every grid into z-normalized temporal cubes and save them.

    Args:
      root_dir (str) -- base data directory.
      country (str) -- country folder name (normalize() only has statistics for "Ghana").
      source (str) -- "Sentinel-1" or "Sentinel-2". Any source other than "Sentinel-2" goes through
                      the second (Sentinel-1) code path, but in "npy" mode its cube is only allocated
                      when source == "Sentinel-1".
      out_format (str) -- "pickle": one array for all grids written to a single .pickle file;
                          "npy": one .npy cube plus a .json with the dates per grid.
      lbl_fldrname (str) -- label folder name; labels are only counted for the verbose report.
      verbose (bool) -- print progress information.

    Inputs: .tif files under <root_dir>/<country>/<source> whose names contain "source" (and, for
    Sentinel-2, the ones containing "cloudmask"). The grid ID is the 4th "_"-separated token from
    the end of the path; the date is the last three tokens of the file name (YYYY_MM_DD).
    Outputs are written to <root_dir>/<country>/<source>/<out_format>.

    "npy" cube per grid, shaped bands x rows x columns x timestamps:
      Sentinel-2: z-normalized bands + 5 spectral indices + 1 day-of-year band (img bands + 6),
                  plus a separate cloud-mask cube saved under a name derived from the mask files.
      Sentinel-1: z-normalized bands + 1 day-of-year band.
    "pickle": grids x bands x rows x columns x max-timestamps with z-normalized bands only,
      pickled as the tuple (sorted grid IDs, array).
    """
    
    #set_trace()
    # Path to "Label" folder which contains label tiles. They are only counted for the verbose
    # report; nothing else in this function uses them.
    lbl_dir = Path(root_dir) / country / lbl_fldrname
    lbl_fnames = [Path(dirpath) / f for (dirpath, dirnames, filenames) in os.walk(lbl_dir) for f in filenames if f.endswith(".tif")]
    lbl_ids = [str(f).split("_")[-1].replace(".tif", "") for f in lbl_fnames]
    
    if verbose:
        print("Number of grids for country {}: {}".format(country, len(lbl_ids)))
    
    # Path to "source" folder which contains img tiles.
    src_path = Path(root_dir) / country / source
    
    # Outputs go to a sub-folder named after the output format.
    out_path = src_path / out_format
    Path(out_path).mkdir(parents=True, exist_ok=True)
    
    # List of img tiles for the current source (RS sensor) and the Grid-ID of each img tile
    # (4th "_"-separated token from the end of the full path).
    files = [Path(dirpath) / f for (dirpath, dirnames, filenames) in os.walk(src_path) for f in filenames if f.endswith(".tif") if "source" in f]
    grid_numbers = [str(f).split("_")[-4] for f in files]
    
    if source == "Sentinel-2":
        # cloud_mask categories --> {"clear":0, "cloud":1, "haze":2, "shadow":3}
        cloud_masks = [Path(dirpath) / f for (dirpath, dirnames, filenames) in os.walk(src_path) for f in filenames if f.endswith(".tif") if "cloudmask" in f]
        cloud_masks.sort()
    
    files.sort()
    grid_numbers.sort()
    
    # read one image from list to get dimensions; every tile is assumed to share this shape.
    # Raises IndexError if no "source" .tif was found.
    with rasterio.open(files[0]) as src:
        img = src.read()
    
    if out_format == "pickle":
        # dimensions: grids x bands x rows x columns x timestamps
        # The time axis is sized by the grid with the most tiles; grids with fewer tiles keep
        # zeros in their unused time slots.
        data_array = np.zeros((len(set(grid_numbers)), img.shape[0], img.shape[1], 
                               img.shape[2], Counter(grid_numbers).most_common(1)[0][1]))
        g, b, r, c, t = data_array.shape
        
    if verbose:
        print("-----------------------------")
        print("Image dimensions: {}".format(img.shape))
        print("Current data source: {}".format(source))
        print("Number of grids in set: {}".format(len(set(grid_numbers))))
        print("Set of grid numbers: {}".format(sorted(set(grid_numbers))))
        print("Maximum timestamps from this data source: {}".format(Counter(grid_numbers).most_common(1)[0][1]))
        if out_format == "pickle":
            print("Final array shape: {}".format(data_array.shape))
            
    for grid_idx, grid in enumerate(sorted(set(grid_numbers))):
        if verbose:
            print("Grid: {}".format(grid))
        # Tiles of this grid in sorted path order; the surrounding "_" avoid matching a longer ID.
        cur_grid_files = [str(f) for f in files if "_" + grid + "_" in str(f)]
        cur_grid_files.sort()
        
        if source == "Sentinel-2":
            cur_mask_files = [str(f) for f in cloud_masks if "_" + grid + "_" in str(f)]
            cur_mask_files.sort()
    
            if out_format == "npy":
                # dimensions: bands x rows x columns x timestamps
                # +6 bands = 5 spectral indices + 1 day-of-year band.
                data_array = np.zeros((img.shape[0]+6, img.shape[1], img.shape[2], len(cur_grid_files)))
                # NOTE(review): the mask cube gets the same band count as the image cube, so a
                # single-band cloud mask read below is broadcast into every band -- confirm intended.
                mask_array = np.zeros((img.shape[0]+6, img.shape[1], img.shape[2], len(cur_mask_files)))
        
        if source == "Sentinel-1" and out_format == "npy":
            # +1 band = day-of-year band.
            data_array = np.zeros((img.shape[0]+1, img.shape[1], img.shape[2], len(cur_grid_files)))
        
        dates = []

        if source == "Sentinel-2":
            # NOTE(review): zip() stops at the shorter list, so tiles without a matching cloud mask
            # (or vice versa) are silently skipped and their time slots stay zero.
            for idx, (fname, mname) in enumerate(zip(cur_grid_files, cur_mask_files)):
                if verbose:
                    print("idx: ", idx)
                    print("fname: ", fname)
                    print("mname: ", mname)
        
                with rasterio.open(fname) as src:
                    tile = src.read()
                    tile = tile.astype(float)
                    if out_format == "pickle":
                        data_array[grid_idx, :, :, :, idx] = normalize(tile, source, country, norm_type = "z-value")
            
                    elif out_format == "npy":
                        # Date = last three "_"-separated tokens of the file name (YYYY_MM_DD).
                        tmp = Path(fname).name.replace(".tif", "").split("_")
                        date_parts = tmp[-3:]
                        date = "_".join(date_parts)
                        dates.append(date)
                        
                        # Spectral indices are computed from the raw (un-normalized) tile.
                        si_bands = get_spectral_indices(tile, Channel_first = True)
                        normal_tile = normalize(tile, source, country, norm_type = "z-value")
                        doy_band = date2doy(date, tile.shape)
                        aug_array = np.concatenate([normal_tile, si_bands, doy_band], axis = 0)
                        
                        data_array[:, :, :, idx] = aug_array
                        #data_array[:, :, :, idx] = normalize(tile, source, country, norm_type = "z-value")
                        
                        
                        with rasterio.open(mname) as msrc:
                            mask_array[:, :, :, idx] = msrc.read()
    
            if out_format == "npy":
                # Output names = first three "_"-separated tokens of the last tile/mask of this grid.
                tmp_fn = Path(fname).name.replace(".tif", "").split("_")
                fn = "_".join(tmp_fn[0:3])
                out_fname = out_path / fn
                
                tmp_mn = Path(mname).name.replace(".tif", "").split("_")
                mn = "_".join(tmp_mn[0:3])
                out_mname = out_path / mn
    
                # store and save metadata
                meta = {}
                meta["dates"] = dates
    
                with open(str(out_fname) + ".json", "w") as fp:
                    json.dump(meta, fp)
    
                # save image stack as .npy
                np.save(out_fname, data_array)
                np.save(out_mname, mask_array)
        
        else:
            
            for idx, fname in enumerate(cur_grid_files):
                
                if verbose:
                    print("idx: ", idx)
                    print("fname: ", fname)
        
                with rasterio.open(fname) as src:
                    s1_tile = src.read()
                    s1_tile = s1_tile.astype(float)
                    
                    if out_format == "pickle":
                        data_array[grid_idx, :, :, :, idx] = normalize(s1_tile, source, country, norm_type = "z-value")
            
                    elif out_format == "npy":
                        
                        # Date = last three "_"-separated tokens of the file name (YYYY_MM_DD).
                        tmp = Path(fname).name.replace(".tif", "").split("_")
                        date_parts = tmp[-3:]
                        date = "_".join(date_parts)
                        dates.append(date)
                        
                        normal_tile = normalize(s1_tile, source, country, norm_type = "z-value")
                        doy_band = date2doy(date, s1_tile.shape)
                        
                        aug_array = np.concatenate([normal_tile, doy_band], axis = 0)
                        
                        data_array[:, :, :, idx] = aug_array
                        #data_array[:, :, :, idx] = normalize(tile, source, country, norm_type = "z-value")
                        
            
            if out_format == "npy":
                # Output name = first three "_"-separated tokens of the last tile of this grid.
                tmp_fn = Path(fname).name.replace(".tif", "").split("_")
                fn = "_".join(tmp_fn[0:3])
                out_fname = out_path / fn
    
                # store and save metadata
                meta = {}
                meta["dates"] = dates
    
                with open(str(out_fname) + ".json", "w") as fp:
                    json.dump(meta, fp)
    
                # save image stack as .npy
                np.save(out_fname, data_array)

        
    if out_format == "pickle":
        # File name encodes country, source and the array shape.
        out_fname = "_".join([country, source, "shape", "g" + str(g), "b" + str(b), "r" + str(r), "c" + str(c), "t" + str(t) + ".pickle"])
        
        with open(str(out_path / out_fname), "wb") as f:
            pickle.dump((sorted(set(grid_numbers)), data_array), f)


In [None]:
# Dataset root: set the CROPTYPE_ROOT_DIR environment variable (e.g. "C:/My_documents/Data")
# instead of editing this path.
root_dir = os.environ.get("CROPTYPE_ROOT_DIR", "E:/CropType")
lbl_fldrname = "Labels"
out_format = "npy"  # "npy" (one cube per grid) or "pickle" (one array for all grids)
country = "Ghana"
source = "Sentinel-2"
verbose = False

In [None]:
get_img_cube(root_dir=root_dir, country=country, source=source,
             out_format=out_format, lbl_fldrname=lbl_fldrname, verbose=verbose)

### Find tiles with NaN values in the spectral indices (SI) and zero values in the image bands

In [None]:
def findTilesWithNan(root_dir, country, source, lbl_fldrname):
    """
    Report the grids whose stacked .npy cube contains NaN values.

    Args:
      root_dir (str) -- base data directory.
      country (str) -- country folder name (statistics exist for "Ghana" only).
      source (str) -- "Sentinel-1" or "Sentinel-2".
      lbl_fldrname (str) -- label folder; a "chips_contain_nan" sub-folder is created in it.

    Returns:
      list of str -- grid IDs (last "_" token of the file name) of the cubes that contain NaNs,
                     in sorted file order.

    Every .npy file whose name contains "source" under <root_dir>/<country>/<source>/npy is checked,
    except files already inside a "chips_contain_nan" folder. For each flagged grid the number of
    values equal to a z-normalized raw 0 is printed as well. Nothing is moved by this function.
    """
    quarantine = "chips_contain_nan"
    lbl_dir = Path(root_dir) / country / lbl_fldrname
    src_path = Path(root_dir) / country / source / "npy"

    fnames = []
    for dirpath, dirnames, filenames in os.walk(src_path):
        # Skip the quarantine folder: tiles in it were already handled.
        dirnames[:] = [d for d in dirnames if d != quarantine]
        fnames.extend(Path(dirpath) / f for f in filenames if f.endswith(".npy") and "source" in f)
    fnames.sort()

    # The quarantine folders are still created here, as before.
    for p in (src_path / quarantine, lbl_dir / quarantine):
        os.makedirs(p, exist_ok=True)

    MEANS = {"Sentinel-1": {"Ghana": np.array([-10.50, -17.24, 1.17])},
             "Sentinel-2": {"Ghana": np.array([2620.00, 2519.89, 2630.31, 2739.81, 3225.22,
                                               3562.64, 3356.57, 3788.05, 2915.40, 2102.65])}
            }

    STDS = {"Sentinel-1": {"Ghana": np.array([3.57, 4.86, 5.60])},
            "Sentinel-2": {"Ghana": np.array([2171.62, 2085.69, 2174.37, 2084.56, 2058.97,
                                              2117.31, 1988.70, 2099.78, 1209.48, 918.19])}
           }

    # Value a raw 0 takes after z-normalization, per band; used to count "zero" pixels.
    comparison_array = (-1 * MEANS[source][country]) / STDS[source][country]

    nan_grids_list = []
    # Iterate over every source cube (no zip with metadata/label lists, which would silently
    # stop at the shortest list and leave some cubes unchecked).
    for fn in fnames:
        grid_id = str(fn).split("_")[-1].replace(".npy", "")
        src_array = np.load(fn)
        num_nans = np.count_nonzero(np.isnan(src_array))

        if num_nans > 0:
            num_zeros = np.count_nonzero(np.isin(src_array, comparison_array))
            # NOTE: the wording says "is moved", but files are only reported here.
            print("Grid ID: {} with {} number of NaN in indices and {} zero values in image band is moved.".format(grid_id, num_nans, num_zeros))
            nan_grids_list.append(grid_id)
    return nan_grids_list

In [None]:
def MoveTilesWithNan(root_dir, country, lbl_fldrname, verbose = True):
    """
    Move every grid whose Sentinel-1 or Sentinel-2 cube contains NaNs into "chips_contain_nan"
    quarantine folders, together with its metadata, cloud mask and label.

    Args:
      root_dir (str) -- base data directory.
      country (str) -- country folder name.
      lbl_fldrname (str) -- label folder; labels are read from its "reclass" sub-folder.
      verbose (bool) -- also print each moved grid on screen. Default is True.

    Returns:
      list of str -- IDs of the grids that were moved.

    Files are paired by their sorted position. The function asserts that all six file lists have
    the same length and that the S1, S2 and label grid IDs of each pair agree. Each moved grid is
    also appended to <root_dir>/<country>/detailed_report/NaN_contaminated_tiles.txt. Files already
    inside a quarantine folder are ignored, so the function can be re-run safely.
    """
    quarantine = "chips_contain_nan"

    def _list_files(base, keep):
        # Sorted paths under `base` whose file name satisfies `keep`, skipping quarantine folders.
        found = []
        for dirpath, dirnames, filenames in os.walk(base):
            dirnames[:] = [d for d in dirnames if d != quarantine]
            found.extend(Path(dirpath) / f for f in filenames if keep(f))
        return sorted(found)

    lbl_dir = Path(root_dir) / country / lbl_fldrname / "reclass"
    s1_path = Path(root_dir) / country / "Sentinel-1" / "npy"
    s2_path = Path(root_dir) / country / "Sentinel-2" / "npy"

    lbl_fnames = _list_files(lbl_dir, lambda f: f.endswith(".tif"))
    s1_fnames = _list_files(s1_path, lambda f: f.endswith(".npy") and "source" in f)
    s1_meta_fnames = _list_files(s1_path, lambda f: f.endswith(".json"))
    s2_fnames = _list_files(s2_path, lambda f: f.endswith(".npy") and "source" in f)
    s2_meta_fnames = _list_files(s2_path, lambda f: f.endswith(".json"))
    cloud_masks = _list_files(s2_path, lambda f: f.endswith(".npy") and "cloudmask" in f)

    # Files are paired by sorted position, so every kind must have exactly one file per grid.
    assert (len(s1_fnames) == len(s2_fnames) == len(lbl_fnames) == len(s1_meta_fnames)
            == len(s2_meta_fnames) == len(cloud_masks)), "File counts differ between S1, S2 and labels."

    s1_out_path = s1_path / quarantine
    s2_out_path = s2_path / quarantine
    lbl_out_path = lbl_dir / quarantine
    ext_file_path = Path(root_dir) / country / "detailed_report"

    for p in (s1_out_path, s2_out_path, lbl_out_path, ext_file_path):
        os.makedirs(p, exist_ok=True)

    report_fname = os.path.join(str(ext_file_path), "NaN_contaminated_tiles.txt")

    nan_grids_list = []
    for s1_fn, s1_meta_fn, s2_fn, s2_meta_fn, cmask_fn, lbl_fn in zip(s1_fnames, s1_meta_fnames, s2_fnames, s2_meta_fnames, cloud_masks, lbl_fnames):

        s1_grid_id = str(s1_fn).split("_")[-1].replace(".npy", "")
        s2_grid_id = str(s2_fn).split("_")[-1].replace(".npy", "")
        lbl_grid_id = str(lbl_fn).split("_")[-1].replace(".tif", "")
        assert s1_grid_id == s2_grid_id == lbl_grid_id, "Grid Id mis-match between Sentinel-1 & 2 chips."

        s1_num_nans = np.count_nonzero(np.isnan(np.load(s1_fn)))
        s2_num_nans = np.count_nonzero(np.isnan(np.load(s2_fn)))

        if (s1_num_nans > 0) or (s2_num_nans > 0):
            message = "Grid ID: {} with {} NaN indices in S1 and {} NaNs for S2 is moved.".format(s1_grid_id, s1_num_nans, s2_num_nans)

            # The `with` block closes the report file; no explicit close() is needed afterwards.
            with open(report_fname, "a") as external_file:
                print(message, file=external_file)
            if verbose:
                print(message)

            nan_grids_list.append(s1_grid_id)
            for src, dst in ((s1_fn, s1_out_path), (s1_meta_fn, s1_out_path),
                             (s2_fn, s2_out_path), (s2_meta_fn, s2_out_path),
                             (cmask_fn, s2_out_path), (lbl_fn, lbl_out_path)):
                shutil.move(str(src), str(dst))

    return nan_grids_list

In [None]:
def MoveLabels(root_dir, country, lbl_fldrname, verbose = True):
    """
    Move the label tiles of grids already quarantined for Sentinel-1 into the label quarantine folder.

    Args:
      root_dir (str) -- base data directory.
      country (str) -- country folder name.
      lbl_fldrname (str) -- label folder; labels are read from its "reclass" sub-folder.
      verbose (bool) -- print each moved label. Default is True.

    A label is moved when its grid ID (last "_" token of its path) matches the ID of a "source"
    .npy file in <root_dir>/<country>/Sentinel-1/npy/chips_contain_nan. Labels already inside the
    label quarantine folder are left alone, so re-running is safe.
    """
    quarantine = "chips_contain_nan"

    s1_path = Path(root_dir) / country / "Sentinel-1" / "npy" / quarantine
    # A set gives O(1) membership tests in the loop below.
    s1_ids = {str(Path(dirpath) / f).split("_")[-1].replace(".npy", "")
              for (dirpath, _, filenames) in os.walk(s1_path)
              for f in filenames if f.endswith(".npy") and "source" in f}

    lbl_dir = Path(root_dir) / country / lbl_fldrname / "reclass"
    lbl_fnames = []
    for dirpath, dirnames, filenames in os.walk(lbl_dir):
        # Skip the quarantine folder so already-moved labels are not moved onto themselves.
        dirnames[:] = [d for d in dirnames if d != quarantine]
        lbl_fnames.extend(Path(dirpath) / f for f in filenames if f.endswith(".tif"))
    lbl_fnames.sort()

    lbl_out_path = lbl_dir / quarantine
    os.makedirs(lbl_out_path, exist_ok=True)

    for lbl_fn in lbl_fnames:
        lbl_grid_id = str(lbl_fn).split("_")[-1].replace(".tif", "")
        if lbl_grid_id in s1_ids:
            shutil.move(str(lbl_fn), str(lbl_out_path))

            if verbose:
                print("Label tilw with Grid ID: {} is moved.".format(lbl_grid_id))

In [None]:
# Dataset root: set the CROPTYPE_ROOT_DIR environment variable instead of editing this path.
root_dir = os.environ.get("CROPTYPE_ROOT_DIR", "C:/My_documents/CropTypeData_Rustowicz")
country = "Ghana"
source = "Sentinel-1"
lbl_fldrname = "Labels3"

In [None]:
# Pass the sensor explicitly so the result does not depend on whichever cell last set `source`.
S1_grids_with_nan = findTilesWithNan(root_dir, country, "Sentinel-1", lbl_fldrname)

In [None]:
# The S2 check must scan the Sentinel-2 cubes; `source` is set to "Sentinel-1" in the config cell.
S2_grids_with_nan = findTilesWithNan(root_dir, country, "Sentinel-2", lbl_fldrname)

In [None]:
MoveTilesWithNan(root_dir=root_dir, country=country, lbl_fldrname=lbl_fldrname, verbose=True)

In [None]:
MoveLabels(root_dir=root_dir, country=country, lbl_fldrname=lbl_fldrname, verbose=True)

### Reclassify Labels

In [None]:
"""
import os
from pathlib import Path
import random
import math
import numpy as np
import rasterio
"""

def reclassify_lbl(root_dir, country, lbl_fldrname, categories = None):
    """
    Remap the crop-type codes of every label chip to a reduced legend and save the result as
    <root_dir>/<country>/<lbl_fldrname>/reclass/<same file name>.

    root_dir (str) -- path to the main directory which data resides. Example: "C:/My_documents/Data".
    country (str) -- country folder that contains the label folder.
    lbl_fldrname (str) -- name of the folder containing the raw label chips (.tif).
    categories (dict) -- {category name: new code}. Names of the full legend listed here keep their
                         own new code; the remaining legend names (except "unknown") are merged
                         into the extra key that is not part of the full legend (usually "other";
                         the first in sorted order if there are several). When None, nothing is
                         written.

    Raises:
        ValueError -- if some legend categories must be merged but `categories` has no extra key
                      to merge them into.
    """
    if categories is None:
        return

    lbl_dir = Path(root_dir) / country / lbl_fldrname
    reclass_lbl_out_path = lbl_dir / "reclass"

    complete_categories = {"unknown": 0, "ground nut": 1, "maize": 2, "rice": 3, "soya bean": 4, "yam": 5, 
                           "intercrop": 6, "sorghum": 7, "okra": 8, "cassava": 9, "millet": 10, "tomato": 11, 
                           "cowpea": 12, "sweet potato": 13, "babala beans": 14, "salad vegetables": 15, 
                           "bra and ayoyo": 16, "watermelon": 17, "zabla": 18, "nili": 19, "kpalika": 20, 
                           "cotton": 21, "akata": 22, "nyenabe": 23, "pepper": 24}

    # Legend categories to aggregate into the "other" class ("unknown" is never aggregated).
    categories_to_other = [cat for cat in np.setdiff1d(list(complete_categories.keys()), list(categories.keys()))
                           if complete_categories[cat] != complete_categories["unknown"]]
    # Name of the aggregation category, usually 'other'.
    aggregator_cat = list(np.setdiff1d(list(categories.keys()), list(complete_categories.keys())))
    if categories_to_other and not aggregator_cat:
        raise ValueError("categories needs an aggregation class (e.g. 'other') for: {}".format(
            ", ".join(str(cat) for cat in categories_to_other)))

    # old code -> new code. Codes that are not mapped (e.g. "unknown" when it is not in
    # `categories`, or values outside the legend) are left untouched.
    code_map = {}
    for category, old_code in complete_categories.items():
        if category in categories:
            code_map[old_code] = categories[category]
        elif category in categories_to_other:
            code_map[old_code] = categories[aggregator_cat[0]]

    # Collect the input chips, without descending into previously written outputs.
    lbl_fnames = []
    for dirpath, dirnames, filenames in os.walk(lbl_dir):
        dirnames[:] = [d for d in dirnames if Path(dirpath) / d != reclass_lbl_out_path]
        lbl_fnames.extend(Path(dirpath) / f for f in filenames if f.endswith(".tif"))

    reclass_lbl_out_path.mkdir(parents=True, exist_ok=True)

    for fn in lbl_fnames:
        with rasterio.open(fn) as src:
            lbl_array = src.read()

        # Every mask is computed on the untouched input. Remapping in place made freshly written
        # codes eligible for a later category (e.g. ground nut -> other (2) -> maize (1)).
        reclassified = lbl_array.copy()
        for old_code, new_code in code_map.items():
            reclassified[lbl_array == old_code] = new_code

        profile = {
            "driver": "GTiff", 
            "count": reclassified.shape[0],
            "height": reclassified.shape[1],
            "width": reclassified.shape[2],
            "dtype": "float64",
            "transform": rasterio.Affine(1, 0, 0, 0, 1, 0),
        }

        with rasterio.open(reclass_lbl_out_path / fn.name, "w", **profile) as dst:
            dst.write(reclassified)

In [None]:
# Data root; set the CROP_DATA_ROOT environment variable to override the local default path.
root_dir = os.environ.get("CROP_DATA_ROOT", "C:/My_documents/CropTypeData_Rustowicz")
lbl_fldrname = "Labels3"
country = "Ghana"
# Alternative, finer legend:
#categories = {"unknown": 0, "ground nut": 1, "maize": 2, "rice": 3, "soya bean": 4, "other": 5}
# Target legend (name -> new code); legend classes not listed here are merged into "other".
categories = {"unknown": 0, "maize": 1, "other": 2, "rice": 3}

In [None]:
# Write the reclassified chips into <lbl_fldrname>/reclass.
reclassify_lbl(
    root_dir=root_dir,
    country=country,
    lbl_fldrname=lbl_fldrname,
    categories=categories,
)

### Summarize categories

In [None]:
"""
import os
from pathlib import Path
import random
import math
import numpy as np
import pandas as pd
import rasterio
"""

def summarize_lbl(root_dir, country, lbl_fldrname, out_filename, category = None, subfolder = "validation"):
    """
    Count the pixels of every category in each label chip and save the table as a CSV file.

    root_dir (str) -- path to the main directory which data resides. Example: "C:/My_documents/Data".
    country (str) -- country folder that contains the label folder.
    lbl_fldrname (str) -- name of the label folder.
    out_filename (str) -- name of the CSV file written inside the summarized folder.
    category (dict) -- {category name: label code}. Defaults to the full Ghana legend when empty/None.
    subfolder (str) -- sub-folder of `lbl_fldrname` to summarize. Default is "validation" (the value
                       that used to be hard-coded); pass "" to summarize `lbl_fldrname` itself.

    Returns:
        pandas.DataFrame indexed by grid ID with one integer column of pixel counts per category
        (0 when absent). The same table is written to <folder>/<out_filename> with index label "Grid-ID".

    Raises:
        ValueError -- if a chip contains a code that is not part of the category dictionary.
    """
    lbl_dir = Path(root_dir) / country / lbl_fldrname / subfolder

    lbl_fnames = [Path(dirpath) / f for (dirpath, dirnames, filenames) in os.walk(lbl_dir) for \
                  f in filenames if f.endswith(".tif")]
    # Parse from the file name only, so directory names containing "_" cannot corrupt the IDs.
    lbl_ids = [fn.name.split("_")[-1].replace(".tif", "") for fn in lbl_fnames]

    if category:
        category_dict = category
    else:
        category_dict = {"unknown": 0, "ground nut": 1, "maize": 2, "rice": 3, "soya bean": 4, "yam": 5, 
                         "intercrop": 6, "sorghum": 7, "okra": 8, "cassava": 9, "millet": 10, "tomato": 11, 
                         "cowpea": 12, "sweet potato": 13, "babala beans": 14, "salad vegetables": 15, 
                         "bra and ayoyo": 16, "watermelon": 17, "zabla": 18, "nili": 19, "kpalika": 20, 
                         "cotton": 21, "akata": 22, "nyenabe": 23, "pepper": 24}

    key_list = list(category_dict.keys())
    # Reverse lookup code -> name; the first name wins if several names share a code.
    code_to_name = {}
    for name, code in category_dict.items():
        code_to_name.setdefault(code, name)

    counts_per_chip = {}
    for fn, lbl_id in zip(lbl_fnames, lbl_ids):
        with rasterio.open(fn) as src:
            lbl_array = src.read()
        codes, counts = np.unique(lbl_array, return_counts=True)
        row = counts_per_chip.setdefault(lbl_id, {})
        for code, count in zip(codes.tolist(), counts.tolist()):
            if code not in code_to_name:
                raise ValueError("Label code {} in {} is not in the category dictionary.".format(code, fn.name))
            row[code_to_name[code]] = count

    df = pd.DataFrame([counts_per_chip.get(lbl_id, {}) for lbl_id in lbl_ids],
                      index=lbl_ids, columns=key_list)
    df = df.fillna(0).astype("int64")
    df.to_csv(lbl_dir / out_filename, index_label='Grid-ID')
    return df

In [None]:
# Data root; set the CROP_DATA_ROOT environment variable to override the local default path.
root_dir = os.environ.get("CROP_DATA_ROOT", "C:/My_documents/CropTypeData_Rustowicz")
# NOTE(review): the reclassification cells use "Labels3"; confirm which label folder is summarized here.
lbl_fldrname = "Labels"
country = "Ghana"
out_filename = "report.csv"
category = {"unknown": 0, "maize": 1, "other": 2, "rice": 3}

In [None]:
report = summarize_lbl(
    root_dir=root_dir,
    country=country,
    lbl_fldrname=lbl_fldrname,
    out_filename=out_filename,
    category=category,
)
# Total number of pixels of every category over all chips.
report.sum(axis="index", skipna=True)

### Split the dataset into train and validation

In [None]:
"""
import os
from pathlib import Path
import random
import math
import shutil
import numpy as np
import pandas as pd
"""

def _list_split_candidates(root, suffix = None, skip_dirs = ("train", "validation", "chips_contain_nan")):
    """List files below `root` (optionally only those ending in `suffix`), skipping the split
    output folders and the NaN quarantine folder so their files are never re-distributed."""
    found = []
    for dirpath, dirnames, filenames in os.walk(root):
        # Pruning `dirnames` in place stops os.walk from descending into these folders.
        dirnames[:] = [d for d in dirnames if d not in skip_dirs]
        found.extend(Path(dirpath) / f for f in filenames if suffix is None or f.endswith(suffix))
    return found


def _resize_training_grids(training_grids, target_size, cat_grid, all_grid_ids):
    """Add or drop grids so that exactly `target_size` grids are used for training, preferring
    grids of the category with the most grids and falling back to any other grid if needed."""
    biggest_category = max(cat_grid, key=lambda cat: len(cat_grid[cat]))
    preferred = set(cat_grid[biggest_category])
    chosen = set(training_grids)

    if len(training_grids) < target_size:
        needed = target_size - len(training_grids)
        # random.sample needs a sequence (sets are rejected since Python 3.11).
        candidates = sorted(preferred - chosen)
        extra = random.sample(candidates, min(needed, len(candidates)))
        shortfall = needed - len(extra)
        if shortfall > 0:
            fallback = sorted(set(all_grid_ids) - chosen - set(extra))
            extra.extend(random.sample(fallback, shortfall))
        return training_grids + extra

    excess = len(training_grids) - target_size
    candidates = sorted(preferred & chosen)
    dropped = set(random.sample(candidates, min(excess, len(candidates))))
    shortfall = excess - len(dropped)
    if shortfall > 0:
        fallback = sorted(chosen - dropped)
        dropped.update(random.sample(fallback, shortfall))
    # Build a new list: removing from a list while iterating over it skips elements.
    return [grid for grid in training_grids if grid not in dropped]


def create_grid_splits(root_dir, country, sources, lbl_fldrname, csv_fn, split_threshold = 0.8):
    """
    Splitting the dataset into train and validation datasets (files are moved, not copied).
    
    root_dir (str) -- path to the main directory which data resides. Example: "C:/My_documents/Data".
    country (str) -- This is based on the organization of dataset that images and 
                     labels reside inside a country folder.
    sources (list) -- folder name of the image resource. example: ["Sentinel-1", "Sentinel-2"]
    lbl_fldrname (str) -- Name of the folder containing annotated grids.
    csv_fn (str) -- Name of the csv file (inside <lbl_fldrname>/reclass) summarizing the pixel count
                    of each category per grid, indexed by "Grid-ID".
    split_threshold (float) -- scalar value as a threshold to decide how many of the grids will be
                               in the training folder. Default is 0.8.

    Labels in <lbl_fldrname>/reclass are moved to its "train"/"validation" sub-folders, and every
    file of <source>/npy whose name contains a grid ID is moved to <source>/npy/train or
    <source>/npy/validation accordingly. Existing "train", "validation" and "chips_contain_nan"
    sub-folders are not scanned.
    """
    # Directory to labels
    lbl_dir = Path(root_dir) / country / lbl_fldrname / "reclass"
    lbl_fnames = _list_split_candidates(lbl_dir, ".tif")
    # List of grid-ids, parsed from the file names.
    lbl_ids = [fn.name.split("_")[-1].replace(".tif", "") for fn in lbl_fnames]

    # pandas parses the zero-padded Grid-IDs as integers; the padding is restored below.
    report = pd.read_csv(lbl_dir / csv_fn, index_col="Grid-ID")
    assert report.shape[0] == len(lbl_ids), "{} rows in {} but {} label chips found.".format(
        report.shape[0], csv_fn, len(lbl_ids))

    num_train_grids = math.ceil(split_threshold * len(lbl_ids))
    categories = list(report.columns)

    # Category name -> grid-ids containing at least one pixel of that category.
    cat_grid = {category: list(report[report[category] > 0].index)
                for category in categories if category != "unknown"}

    if cat_grid:
        # Start from the category with the fewest tiles so it is guaranteed its share of training grids.
        initial_category = min(cat_grid, key=lambda cat: len(cat_grid[cat]))
        num_initial_tiles = math.ceil(len(cat_grid[initial_category]) * split_threshold)
        training_grids = random.sample(cat_grid[initial_category], num_initial_tiles)

        # For each other category, count the grids already chosen and sample only the missing ones.
        for category in categories:
            if category in ("unknown", initial_category):
                continue
            similar_grids = set(cat_grid[category]).intersection(training_grids)
            # Clamp at 0: grids shared with earlier categories may already exceed this category's quota.
            num_samples_to_take = max(
                0, math.ceil(len(cat_grid[category]) * split_threshold) - len(similar_grids))
            allowable_grids = list(np.setdiff1d(cat_grid[category], training_grids))
            training_grids.extend(random.sample(allowable_grids, num_samples_to_take))

        # Make sure the training split has exactly the number of grids set by the split threshold.
        training_grids = _resize_training_grids(training_grids, num_train_grids, cat_grid,
                                                list(report.index))
    else:
        training_grids = random.sample(list(report.index), num_train_grids)

    # Restore the leading zeros lost when pandas parsed the Grid-IDs.
    id_width = max((len(grid_id) for grid_id in lbl_ids), default=6)
    training_grids = sorted(str(item).zfill(id_width) for item in training_grids)
    train_set = set(training_grids)
    val_grids = sorted(set(lbl_ids) - train_set)
    val_set = set(val_grids)

    # Create proper folders for the split dataset
    lbl_train_out_path = lbl_dir / "train"
    lbl_val_out_path = lbl_dir / "validation"
    lbl_train_out_path.mkdir(parents=True, exist_ok=True)
    lbl_val_out_path.mkdir(parents=True, exist_ok=True)

    # Move labels into train, validation subfolders.
    for grid_id, fn in zip(lbl_ids, lbl_fnames):
        if grid_id in train_set:
            shutil.move(str(fn), str(lbl_train_out_path))
        elif grid_id in val_set:
            shutil.move(str(fn), str(lbl_val_out_path))

    # Move the image files based on the grid-ID to the equivalent subfolders.
    for source in sources:
        src_dir = Path(root_dir) / country / source / "npy"
        src_train_out_path = src_dir / "train"
        src_val_out_path = src_dir / "validation"
        src_train_out_path.mkdir(parents=True, exist_ok=True)
        src_val_out_path.mkdir(parents=True, exist_ok=True)

        for fn in _list_split_candidates(src_dir):
            # Substring match (companion files may follow other naming schemes); the first
            # matching grid-id decides, so each file is moved at most once.
            grid_id = next((gid for gid in lbl_ids if gid in fn.name), None)
            if grid_id in train_set:
                shutil.move(str(fn), str(src_train_out_path))
            elif grid_id in val_set:
                shutil.move(str(fn), str(src_val_out_path))

In [None]:
# Data root; set the CROP_DATA_ROOT environment variable to override the local default path.
root_dir = os.environ.get("CROP_DATA_ROOT", "C:/My_documents/CropTypeData_Rustowicz")
country = "Ghana"
sources = ["Sentinel-1", "Sentinel-2"]
# NOTE(review): the reclassification cells use "Labels3"; confirm which label folder is split here.
lbl_fldrname = "Labels"
csv_fn = "report.csv"
split_threshold = 0.80

In [None]:
# Move labels and image chips into train / validation sub-folders.
create_grid_splits(
    root_dir=root_dir,
    country=country,
    sources=sources,
    lbl_fldrname=lbl_fldrname,
    csv_fn=csv_fn,
    split_threshold=split_threshold,
)

### Make pixel dataset

In [None]:
def Make_pixel_dataset(root_dir, country, sources, lbl_fldrname, usage):
    """
    Work in progress: walk the label, Sentinel-1, Sentinel-2 and cloud-mask chips of one split in
    lock-step and locate the pixels of every label value. Nothing is returned or written yet.

    root_dir (str) -- path to the main directory which data resides.
    country (str) -- country folder that contains the image sources and the label folder.
    sources (list) -- currently unused; the "Sentinel-1" / "Sentinel-2" folder names are hard-coded.
    lbl_fldrname (str) -- name of the label folder.
    usage (str) -- split sub-folder to read, e.g. "train" or "validation".

    Raises:
        AssertionError -- if the four file lists differ in length or their grid IDs do not line up.

    NOTE(review): create_grid_splits moves image chips to <source>/npy/<usage> and labels to
    <lbl_fldrname>/reclass/<usage>, while this function reads <source>/<usage> and
    <lbl_fldrname>/<usage> -- confirm the intended layout.
    """
    def _list_files(folder, keep):
        # Sorted so that the i-th entry of every list refers to the same grid ID.
        return sorted(Path(dirpath) / f for (dirpath, dirnames, filenames) in os.walk(folder)
                      for f in filenames if keep(f))

    def _grid_id(path, extension):
        # Grid ID is the token after the last underscore of the file name, without its extension.
        return path.name.split("_")[-1].replace(extension, "")

    lbl_fnames = _list_files(Path(root_dir) / country / lbl_fldrname / usage,
                             lambda f: f.endswith(".tif"))
    s1_fnames = _list_files(Path(root_dir) / country / "Sentinel-1" / usage,
                            lambda f: f.endswith(".npy"))
    s2_src_path = Path(root_dir) / country / "Sentinel-2" / usage
    s2_fnames = _list_files(s2_src_path, lambda f: f.endswith(".npy") and "source" in f)
    cmasks = _list_files(s2_src_path, lambda f: f.endswith(".npy") and "cloudmask" in f)

    # zip() below would silently drop unmatched chips.
    assert len(lbl_fnames) == len(s1_fnames) == len(s2_fnames) == len(cmasks), (
        "Unequal number of label / S1 / S2 / cloud-mask files: {} / {} / {} / {}".format(
            len(lbl_fnames), len(s1_fnames), len(s2_fnames), len(cmasks)))

    for lbl_fn, s1_fn, s2_fn, cmask_fn in zip(lbl_fnames, s1_fnames, s2_fnames, cmasks):

        lbl_grid_id = _grid_id(lbl_fn, ".tif")
        s1_grid_id = _grid_id(s1_fn, ".npy")
        s2_grid_id = _grid_id(s2_fn, ".npy")
        # Parsed from the cloud-mask file itself so the mask is really checked.
        cmask_grid_id = _grid_id(cmask_fn, ".npy")
        assert lbl_grid_id == s1_grid_id == s2_grid_id == cmask_grid_id, (
            "Grid IDs do not match: {} / {} / {} / {}".format(
                lbl_grid_id, s1_grid_id, s2_grid_id, cmask_grid_id))

        lbl_array = load_data(lbl_fn, isLabel = True)
        s1_array = load_data(s1_fn, isLabel = False)
        s2_array = load_data(s2_fn, isLabel = False)
        cmask_array = load_data(cmask_fn, isLabel = False)

        for val in np.unique(lbl_array):
            # Pixel coordinates of label value `val` (not used yet -- work in progress).
            crop_indices = np.where(lbl_array == [val])

In [None]:
# Columns of the pixel dataset (presumably -- confirm): source_s1_grid_id, crop_type, coord

## Augmentation

In [None]:
def shiftBrightness(img, gammaRange=(0.2, 2.0), shiftSubset=(4, 4), patchShift=True):
    '''
    Shift image brightness through gamma correction.
    Params:
        img (narray): Concatenated variables or brightness value with a dimension of (H, W, C)
        gammaRange (tuple): Range (low, high) of gamma values; every gamma is drawn from a
            triangular distribution with mode 1
        shiftSubset (tuple): Number of consecutive bands or channels sharing one gamma
        patchShift (bool): Whether to apply the shift on a randomly rotated/scaled patch (needs
            OpenCV), followed by a milder shift of the whole image
    Returns:
        narray, brightness shifted image. NOTE: `img` is modified in place as well.
    '''
    c_start = 0

    if patchShift:
        # Imported here so the non-patch path keeps working without OpenCV.
        import cv2
        import numpy.ma as ma

        for i in shiftSubset:
            gamma = random.triangular(gammaRange[0], gammaRange[1], 1)

            h, w, _ = img.shape
            band_slice = img[:, :, c_start:c_start + i]
            # OpenCV takes the rotation center as (x, y) = (column, row).
            rotMtrx = cv2.getRotationMatrix2D(center=(random.randint(0, w), random.randint(0, h)),
                                              angle=random.randint(0, 90),
                                              scale=random.uniform(1, 2))
            mask = cv2.warpAffine(band_slice, rotMtrx, (w, h))
            # Pixels that are zero in the warped copy are masked out (left unchanged). The reshape
            # restores the channel axis that OpenCV drops for single-channel input.
            mask = np.where(mask, 0, 1).reshape(band_slice.shape)
            # apply mask
            img_ma = ma.masked_array(band_slice, mask=mask)
            img[:, :, c_start:c_start + i] = ma.power(img_ma, gamma)
            # default extra step -- shift on the whole image
            gamma_full = random.triangular(0.5, 1.5, 1)
            img[:, :, c_start:c_start + i] = np.power(img[:, :, c_start:c_start + i], gamma_full)

            c_start += i
    else:
        # convert image dimension to (C, H, W) if len(img.shape)==3
        img = np.transpose(img, list(range(img.ndim)[-1:]) + list(range(img.ndim)[:-1]))
        for i in shiftSubset:
            gamma = random.triangular(gammaRange[0], gammaRange[1], 1)
            img[c_start:c_start + i, ] = np.power(img[c_start:c_start + i, ], gamma)

            c_start += i
        img = np.transpose(img, list(range(img.ndim)[-img.ndim + 1:]) + [0])

    return img

#########################################################################################

def jitter(x, sigma=0.03):
    """Return `x` plus i.i.d. Gaussian noise with mean 0 and standard deviation `sigma`.

    Jittering augmentation, see https://arxiv.org/pdf/1706.00527.pdf. A new array is returned;
    `x` itself is not modified.
    """
    noise = np.random.normal(loc=0., scale=sigma, size=x.shape)
    return x + noise

#########################################################################################

def shift_tstack_brightness(img, gamma_range=(0.2, 2.0), sample_number=5, patch_shift=False, mode = 'random'):
    '''
    Shift image brightness on sampled timestamps.
    Params:
        img (narray): Concatenated time series cube in a dimension of (N, T), where N could be either nothing or C, H, W
        gamma_range (tuple): Range of gamma values
        sample_number (int): Number of timestamps to sample (capped at T)
        patch_shift (bool): Whether to apply the shift on small patches; ignored if the input is a point stack
            or the chip is smaller than 2x2
        mode (str): 'random' shifts `sample_number` random timestamps independently; any other value
            (e.g. 'uni') shifts the whole stack with one gamma per call
    Returns:
        narray, brightness shifted time series cube with the same shape as `img`.
        NOTE: the transposes are views, so the input array is modified in place as well.
    '''
    assert img.ndim == 4 or img.ndim == 1
    is_point_stack = img.ndim == 1
    if is_point_stack:
        # A single-pixel series is handled as a (C=1, H=1, W=1, T) cube; the channel count `c`
        # used below was previously undefined for 1-D input.
        img = img.reshape(1, 1, 1, -1)
        patch_shift = False

    c, h, w, t = img.shape
    if h < 2 or w < 2:
        patch_shift = False
    # reset sample number
    sample_number = min(sample_number, t)

    if mode == 'random':
        # convert t dimension to the first dimension
        img = np.transpose(img, (3, 0, 1, 2))  # T, C, H, W
        tsamples = random.sample(range(t), sample_number)
        for tsample in tsamples:
            img_t = img[tsample, ].transpose(1, 2, 0)  # H, W, C
            shifted = shiftBrightness(img_t, gamma_range, [c], patch_shift)  # H, W, C
            img[tsample, ] = shifted.transpose(2, 0, 1)  # C, H, W
        # transpose back to C, H, W, T
        img = np.transpose(img, (1, 2, 3, 0))
    else:
        # convert c dimension to the last
        img = np.transpose(img, (1, 2, 3, 0))  # H, W, T, C
        shifted = shiftBrightness(img, gamma_range, [c], patchShift = False)
        # transpose back to C, H, W, T
        img = np.transpose(shifted, (3, 0, 1, 2))

    if is_point_stack:
        img = img.reshape(-1)
    return img

###############################################################################################

def jitter_tstack(img, sample_number=5):
    '''
    Apply jitter augmentation on randomly sampled timestamps.
    Params:
        img (narray): time series with time on the last axis, either a (C, H, W, T) cube or a (T,) point series
        sample_number (int): Number of timestamps to sample (capped at T)
    Returns:
        narray, jittered time series with the same shape as `img`.
        NOTE: the transposes are views, so the sampled timestamps of the input array are overwritten too.
    '''
    assert img.ndim == 4 or img.ndim == 1

    t = img.shape[-1]
    # reset sample number
    sample_number = min(sample_number, t)
    axes = list(range(img.ndim))
    # Move the time axis to the front: (T, ...).
    img = np.transpose(img, axes[-1:] + axes[:-1])

    # jitter
    tsamples = random.sample(range(t), sample_number)
    for tsample in tsamples:
        img[tsample, ] = jitter(img[tsample, ])

    # Move the time axis back to the end. axes[1:] + [0] is a valid permutation for 1-D input too,
    # whereas range(ndim)[-ndim + 1:] + [0] gave the invalid [0, 0].
    img = np.transpose(img, axes[1:] + [0])

    return img

## Accuracy Evaluation and Metrics

In [None]:
"""
import numpy as np
import pandas as pd
from sklearn import metrics
"""

class BinaryMetrics:
    
    '''Metrics measuring model performance on one binary (one-vs-rest) problem.'''

    def __init__(self, refArray, scoreArray, predArray = None):
        '''
        Params:
            refArray (narray): Array of ground truth (0/1)
            scoreArray (narray): Array of pixels scores of positive class
            predArray (narray, optional): Array of predicted labels (0/1); when None, scores
                above 0.5 are predicted as positive
        Raises:
            Exception: if label and score arrays differ in size, or the labels are not binary
        '''

        self.observation = refArray.flatten()
        self.score = scoreArray.flatten()
        
        if self.observation.shape != self.score.shape:
            raise Exception("Inconsistent size between label and prediction arrays.")
        
        if predArray is not None:
            self.prediction = predArray.flatten()
        else:
            self.prediction = np.where(self.score > 0.5, 1, 0)

        # Kept for backward compatibility: after construction `self.confusion_matrix` is the
        # DataFrame (it shadows the method on the instance). Also sets tp, fp, fn, tn.
        self.confusion_matrix = self.confusion_matrix()


    @staticmethod
    def _safe_div(numerator, denominator):
        """Return numerator / denominator as a float, or 0.0 when the denominator is zero."""
        # numpy integer division by zero returns nan/inf (with a warning) instead of raising
        # ZeroDivisionError, so the denominator has to be checked explicitly.
        return float(numerator) / float(denominator) if denominator != 0 else 0.0

        
    def __add__(self, other):
        """
        Pool two BinaryMetrics instances (their pixels are concatenated).
        Params:
            other (''BinaryMetrics''): A BinaryMetrics instance
        Return:
            ''BinaryMetrics''
        """

        return BinaryMetrics(np.append(self.observation, other.observation),
                             np.append(self.score, other.score),
                             np.append(self.prediction, other.prediction))


    def __radd__(self, other):
        """
        Add a BinaryMetrics instance with reversed operands; `0 + metrics` returns the instance
        itself so that the built-in sum() works on a list of BinaryMetrics.
        Params:
            other
        Returns:
            ''BinaryMetrics''
        """

        if other == 0:
            return self
        else:
            return self.__add__(other)


    def confusion_matrix(self):
        """
        Calculate confusion matrix of given ground truth and predicted label, and store the
        counts as self.tp, self.fp, self.fn and self.tn.
        Returns:
            ''pandas.dataframe'' with observation on the rows and prediction on the columns
        Raises:
            Exception: if the observation or prediction contains values greater than 1
        """

        refArray = self.observation
        predArray = self.prediction

        # max() of an empty array raises, so only non-empty arrays are validated.
        if (refArray.size and refArray.max() > 1) or (predArray.size and predArray.max() > 1):
            raise Exception("Invalid array")

        # observation - 2 * prediction encodes each pair: 0 -> TN, 1 -> FN, -2 -> FP, -1 -> TP.
        sub = refArray - predArray * 2

        self.tp = np.sum(sub == -1)
        self.fp = np.sum(sub == -2)
        self.fn = np.sum(sub == 1)
        self.tn = np.sum(sub == 0)
        
        confusionMatrix = pd.DataFrame(data = np.array([[self.tn, self.fp],[self.fn, self.tp]]),
                                       index = ['observation = 0', 'observation = 1'],
                                       columns = ['prediction = 0', 'prediction = 1'])

        return confusionMatrix


    def ir(self):
        """
        Imbalance Ratio (IR) is defined as the proportion between positive and negative instances of the label. 
        This value lies within the [0, ∞] range, having a value IR = 1 in the balanced case.
        Returns:
             float (0.0 when there are no negative instances)
        """
        return self._safe_div(self.tp + self.fn, self.fp + self.tn)
    
    
    def oa(self):
        """
        Calculate Overall Accuracy.
        Returns:
            float
        """

        oa = metrics.accuracy_score(self.observation, self.prediction)
        
        return oa
    
    
    def producers_accuracy(self):
        """
        Calculate Producer's Accuracy (True Positive Rate |Sensitivity |hit rate | recall).
        Returns:
            float
        """

        return metrics.recall_score(self.observation, self.prediction, average='binary')

    
    def users_accuracy(self):
        """
        Calculate User's Accuracy (Positive Prediction Value (PPV) | Precision).
        Returns:
            float
        """

        ua = metrics.precision_score(self.observation, self.prediction, average='binary')
        
        return ua
    
    
    def npv(self):
        """
        Calculate Negative Predictive Value or true negative accuracy.
        Returns:
             float (0.0 when nothing is predicted negative)
        """
        return self._safe_div(self.tn, self.tn + self.fn)


    def specificity(self):
        """
        Calculate Specificity aka. True negative rate (TNR), or inverse recall.
        Returns:
             float (0.0 when there are no negative observations)
        """
        return self._safe_div(self.tn, self.tn + self.fp)

      
    def F1_measure(self):
        """
        Calculate F1 score.
        Returns:
            float
        """

        f1 = metrics.f1_score(self.observation, self.prediction)

        return f1
    
    
    def iou(self):
        """
        Calculate intersection over union for the positive class.
        Returns:
            float
        """

        return metrics.jaccard_score(self.observation, self.prediction)
    
    
    def miou(self):
        """
        Calculate mean intersection over union considering both positive and negative classes.
        A class absent from both observation and prediction has an undefined IoU and is ignored.
        Returns:
            float (0.0 when both IoUs are undefined, i.e. empty input)
        """
        neg_union = self.tn + self.fn + self.fp
        pos_union = self.tp + self.fn + self.fp
        ious = [self.tn / neg_union if neg_union else np.nan,
                self.tp / pos_union if pos_union else np.nan]

        if np.all(np.isnan(ious)):
            return 0.0
        return np.nanmean(ious)
    
    
    def MCCn(self):
        """
        Calculate Matthews correlation coefficient (MCC), rescaled from [-1, 1] to [0, 1], expressed
        through lambda_pp = TPR and lambda_nn = TNR (an imbalance-independent formulation).
        Returns:
            float (0.0 when the coefficient is undefined)
        """
        lambda_pp = self._safe_div(self.tp, self.tp + self.fn)
        lambda_nn = self._safe_div(self.tn, self.tn + self.fp)
        denominator = (lambda_pp + (1 - lambda_nn)) * (lambda_nn + (1 - lambda_pp))
        if denominator <= 0:
            return 0.0

        return 0.5 * ((lambda_pp + lambda_nn - 1) / math.sqrt(denominator) + 1)


    def tss(self):
        """
        Calculates true skill statistic (TSS). Also called Bookmaker Informedness (BM). 
        Scale of the metric:[-1,1]. nan (with a RuntimeWarning) when a class is absent from the observation.
        Returns:
            float
        """  
        tss = self.tp / (self.tp + self.fn) + self.tn / (self.tn + self.fp) - 1
        
        return tss

##################################################    
    
def accuracy_evaluation(evalData, model, gpu, outPrefix, bucket = None):
    """
    Evaluate a trained model and save per-class binary metrics to "Metrics.csv".

    Params:
        evalData (''DataLoader''): Batches of (s1_img, s2_img, label)
        model: Trained model, called as model(s1_img, s2_img); its output is treated as
            (B, n_classes) logits
        gpu (bool): Decide whether to move the image batches to the GPU
        outPrefix (str): s3 prefix (when `bucket` is given) or local directory to save metrics
        bucket (str, optional): name of s3 bucket to save metrics
    Returns:
        pandas.DataFrame with one row per class 1 .. n_classes - 1 (class 0 is not evaluated).
    """
    
    model.eval()
    # Per evaluated class: lists of per-sample observation / score / prediction arrays.
    observations, scores, predictions = [], [], []

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for s1_img, s2_img, label in evalData:
            # Replace NaNs (the only values with x != x) by 0.
            s1_img[s1_img != s1_img] = 0    #shape=(B,T,C)
            s2_img[s2_img != s2_img] = 0

            if gpu:
                s1_img = s1_img.cuda()
                s2_img = s2_img.cuda()

            model_out_prob = F.softmax(model(s1_img, s2_img), 1)  #shape=(B, Class_num)
            batch, nclass = model_out_prob.size()

            probs = model_out_prob.cpu().numpy()
            batch_pred = model_out_prob.max(dim=1)[1].cpu().numpy()
            labels = label.cpu().numpy()

            for i in range(batch):
                while len(observations) < nclass - 1:
                    observations.append([])
                    scores.append([])
                    predictions.append([])

                for n in range(1, nclass):
                    observations[n - 1].append(np.where(labels[i] == n, 1, 0))
                    scores[n - 1].append(probs[i, n])
                    predictions[n - 1].append(np.where(batch_pred[i] == n, 1, 0))

    # One BinaryMetrics per class built from all samples at once; summing per-sample instances
    # re-copied the pooled arrays on every addition (quadratic in the number of samples).
    class_metrics = [
        BinaryMetrics(np.concatenate([np.ravel(o) for o in obs]),
                      np.concatenate([np.ravel(s) for s in scr]),
                      np.concatenate([np.ravel(p) for p in prd]))
        for obs, scr, prd in zip(observations, scores, predictions)
    ]
    
    report = pd.DataFrame({
        "Overal Accuracy" : [m.oa() for m in class_metrics],
        "Producer's Accuracy (recall)" : [m.producers_accuracy() for m in class_metrics],
        "User's Accuracy (precision)" : [m.users_accuracy() for m in class_metrics],
        "Negative Predictive Value" : [m.npv() for m in class_metrics],
        "Specificity (TNR)" : [m.specificity() for m in class_metrics],
        "F1 score" : [m.F1_measure() for m in class_metrics],
        "IoU" : [m.iou() for m in class_metrics],
        "mIoU" : [m.miou() for m in class_metrics],
        "TSS" : [m.tss() for m in class_metrics]
    }, index=["class_{}".format(m) for m in range(1, len(class_metrics) + 1)])
    
    if bucket:
        dir_metrics = "s3://{}/{}/Metrics.csv".format(bucket, outPrefix)
    else:
        dir_metrics = Path(outPrefix)/ "Metrics.csv"
        
        if not os.path.exists(Path(outPrefix)):
            os.makedirs(Path(outPrefix))
        
    report.to_csv(dir_metrics)
    return report


## Loading input data

### Rustowicz African crop custom dataset

In [None]:
################################### Helper functions for custom Dataset ######################################

def load_data(dataPath, isLabel = False):
    """Read one chip from disk.

    Args:
        dataPath (str) -- Path to either the image (.npy) or the label raster.
        isLabel (binary) -- True for a single-band label raster read with rasterio; False (default)
                            for a NumPy .npy array.

    Returns:
        numpy ndarray (only the first band for labels).

    Raises:
        ValueError -- if a label raster does not have exactly one band.
    """
    if not isLabel:
        return np.load(dataPath)

    with rasterio.open(dataPath, "r") as src:
        band_count = src.count
        if band_count != 1:
            raise ValueError("Label must have only 1 band but {} bands were detected.".format(band_count))
        return src.read(1)

############################################################

def get_pixel_coord(lbl_tile, lbl_grid_id, percnt_pixels = 1, sampling_strategy = "natural frequency", 
                    get_neg_samp = False, verbose = False, num_fixed_samples = 8, num_fixed_neg_samples = 3):
    """
    Get coordinates of the pixels to be sampled, in image coordinates (row|col).
    Args:
        lbl_tile (ndarray) -- 2-D array of label values (e.g. a 64x64 chip); 0 is the unknown class.
        lbl_grid_id (int) -- 6 digit index to identify the chips (only used in verbose output).
        percnt_pixels (float) -- value between (0, 1] indicating the percentage of crop pixels to be sampled
                               from each grid ('natural frequency' only).
        sampling_strategy (str) -- choice of sampling strategy. It can be either 'natural frequency' or 
                                   'fixed size' with the former as default choice.
        get_neg_samp (Binary) -- Decision whether to add negative samples from the unknown class (0).
        verbose (bool) -- print the number of sampled pixels per crop type.
        num_fixed_samples (int) -- maximum number of pixels per crop type for 'fixed size'. Default 8.
        num_fixed_neg_samples (int) -- maximum number of negative pixels for 'fixed size'. Default 3.
    Returns:
        A list of coordinate tuples in the form of [(sample1_row, sample1_col),...,(sampleN_row, sampleN_col)],
        crop pixels first, negative pixels (if requested) last.
    Raises:
        ValueError -- if the sampling strategy is not recognized.
        
    Note 1: 'natural frequency': stratified sampling of ceil(count * percnt_pixels) pixels per crop type, so
                                 the samples follow the natural frequency of the crop categories; negative
                                 samples are ceil(10% of percnt_pixels of the unknown pixels).
            'fixed size': min(count, num_fixed_samples) pixels per crop type and
                          min(unknown pixels, num_fixed_neg_samples) negative pixels.
    """
    # Validate first: an unknown strategy used to fail with UnboundLocalError when get_neg_samp was set.
    if sampling_strategy not in ("natural frequency", "fixed size"):
        raise ValueError("Sampling strategy is not recognized.")
    natural = sampling_strategy == "natural frequency"

    # Negative samples are drawn first so the random stream matches the crop-pixel order below.
    neg_samples = []
    if get_neg_samp:
        negative_indices = np.where(lbl_tile == [0])
        negative_coordinates = list(zip(negative_indices[0], negative_indices[1]))
        total_neg_pixels = len(negative_coordinates)
        if natural:
            num_negative_samples = math.ceil((total_neg_pixels * percnt_pixels) * 0.1)
        else:
            num_negative_samples = min(total_neg_pixels, num_fixed_neg_samples)
        neg_samples = random.sample(negative_coordinates, num_negative_samples)

    sampled_coordinates = []
    unique_vals, unique_counts = np.unique(lbl_tile, return_counts=True)

    if natural and verbose:
        print("Chip ID: {}".format(lbl_grid_id))
    for val, count in zip(unique_vals, unique_counts):
        if val == 0:
            continue
        count = int(count)
        if natural:
            num_samples_per_cat = math.ceil(count * percnt_pixels)
        else:
            num_samples_per_cat = min(count, num_fixed_samples)
        crop_indices = np.where(lbl_tile == [val])
        crop_coordinates = list(zip(crop_indices[0], crop_indices[1]))
        crop_samples = random.sample(crop_coordinates, num_samples_per_cat)
        if verbose:
            if not natural:
                print("Chip ID: {}".format(lbl_grid_id))
            print("Number of sampled pixels of crop type {}: {}".format(val, len(crop_samples)))
        sampled_coordinates.extend(crop_samples)

    sampled_coordinates.extend(neg_samples)
    return sampled_coordinates

############################################################

class CropTypeBatchSampler(Sampler):
    """
    Batch sampler whose batches only contain samples with an identical ``dataset[i][0].shape[0]``.

    Args:
            dataset (Pytorch dataset): list of tuples in the form of [(s1_img, s2_img, label),...,(s1_img, s2_img, label)]
            batch_size (int): Number of samples in a mini-batch training strategy.
    Returns:
            list of batches of list of sample indices.

    Note 1: All samples inside a batch share the same leading dimension of their first item,
            so S1 needs no separate padding in 'collate_fn'.
    Note 2: Batches may be shorter than batch_size; there is at most one such short batch per
            distinct leading-dimension value.
    Note 3: Batch contents are fixed at construction; only the order of the batches is
            reshuffled on every call to __iter__ (uses the global `random` state).
    NOTE(review): pixelDataset.__getitem__ returns S1 transposed under the comment
            "(N x C x T)", so shape[0] may be the channel count rather than the sequence
            length - confirm which axis is meant to be grouped on.
    """

    def __init__(self, dataset, batch_size):
        super(CropTypeBatchSampler, self).__init__(dataset)

        self.batch_size = batch_size
        # (sample index, leading dimension of the sample's first item) for every sample.
        self.indices_n_lengths = [(pos, dataset[pos][0].shape[0]) for pos in range(len(dataset))]
        shuffle(self.indices_n_lengths)

        # Group indices by length; plain dicts keep first-seen key order.
        groups = defaultdict(list)
        for pos, seq_len in self.indices_n_lengths:
            groups[seq_len].append(pos)

        # Cut every group into consecutive chunks of at most batch_size indices.
        self.batch_list = [
            members[start:start + self.batch_size]
            for members in groups.values()
            for start in range(0, len(members), self.batch_size)
        ]

    def __len__(self):
        return len(self.batch_list)

    def __iter__(self):
        shuffle(self.batch_list)
        yield from self.batch_list

############################################################

class CropTypeBatchSampler2(Sampler):
    """
    Batch sampler that groups samples with the closest ``dataset[i][0].shape[0]`` together.

    Args:
            dataset (Pytorch dataset): list of tuples in the form of [(s1_img, s2_img, label),...,(s1_img, s2_img, label)]
            batch_size (int): Number of samples in a mini-batch training strategy (must be >= 1).
    Returns:
            list of batches of list of sample indices

    Note 1: Samples are sorted by the leading dimension of their first item, so samples in a
            batch are as close as possible in that length.
    Note 2: The last batch might be shorter than the batch size.
    Note 3: Separate padding might be required for S1 using 'collate_fn'.
    Note 4: Batches and their order are fixed at construction; __iter__ does not reshuffle.

    Raises:
            ValueError: if batch_size < 1.
    """

    def __init__(self, dataset, batch_size):
        super(CropTypeBatchSampler2, self).__init__(dataset)

        if batch_size < 1:
            raise ValueError("batch_size must be >= 1, got {}".format(batch_size))
        self.batch_size = batch_size

        # Read lengths from the dataset actually passed in. The previous version read a global
        # `train_dataset`, which raised NameError when absent and batched the wrong data otherwise.
        indices_n_lengths = [(i, dataset[i][0].shape[0]) for i in range(len(dataset))]

        # Shuffle before the stable sort so equal-length samples end up in random order.
        shuffle(indices_n_lengths)
        indices_n_lengths.sort(key=lambda pair: pair[1])

        sorted_indices = [idx for idx, _ in indices_n_lengths]
        self.batches = [sorted_indices[start:start + self.batch_size]
                        for start in range(0, len(sorted_indices), self.batch_size)]

    def __len__(self):
        return len(self.batches)

    def __iter__(self):
        yield from self.batches

############################################################

def collate_var_length(batch):
    """
    Collate samples whose S1/S2 tensors may differ in their leading length.

    The labels are stacked into one tensor. S1 and S2 are each zero-padded along their first
    axis up to the longest sample in the batch and returned batch-first.

    Args:
        batch (list): samples whose items 0, 1 and 2 are the S1 tensor, the S2 tensor and
            the label tensor.
    Returns:
        tuple: (s1_img, s2_img, label)
    """
    label = torch.stack([sample[2] for sample in batch])

    # Pad S1 (position 0) first, then S2 (position 1).
    s1_img, s2_img = (
        rnn_util.pad_sequence([sample[pos] for sample in batch], batch_first=True)
        for pos in (0, 1)
    )
    return s1_img, s2_img, label

######################################## Custom Dataset ######################################
"""
import os
from pathlib import Path
import random
import math
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import rasterio
"""

class pixelDataset(Dataset):
    """
    Pixel-level dataset of Sentinel-1 / Sentinel-2 time series sampled from labelled grid chips.

    For every grid ID, the label raster and the matching S1/S2 arrays are loaded. Pixel
    coordinates are chosen with `get_pixel_coord`, and each sampled pixel's label and S1/S2
    vectors are kept in memory.

    Args:
            root_dir (str): path to the main folder of the dataset, formatted as indicated in the readme
            country (str): name of the country. Necessary based on the organization of dataset into folders.
            lbl_fldrname (str): Name of the folder containing the labels.
            usage (str): "train" or "validation". "test" is not implemented yet and raises
                         NotImplementedError.
            sources (list of str): Sensors of image acquisition. At the moment two sensors
                                   are used ["Sentinel-1", "Sentinel-2"]
            percnt_pixels (float): Fraction of field pixels of each crop type to sample from each
                                   Grid-ID (passed to `get_pixel_coord`). Default value is None.
            sampling_strategy (str): Passed to `get_pixel_coord`; "natural frequency" (default)
                                     or "fixed size".
            transform: If truthy and usage == "train", enables the augmentation hook in
                       __getitem__ (the augmentations themselves are currently disabled).
                       Default is None.

    Raises:
            NotImplementedError: if usage == "test".
            ValueError: if usage is not one of "train", "validation", "test".
    """

    def __init__(self, root_dir, country, lbl_fldrname, usage, sources = ["Sentinel-1", "Sentinel-2"],
                 percnt_pixels = None, sampling_strategy = "natural frequency", transform = None):

        if usage == "test":
            # Previously a `pass` placeholder that later crashed with AttributeError on self.s1.
            raise NotImplementedError("pixelDataset does not support usage='test' yet.")
        if usage not in ("train", "validation"):
            raise ValueError("usage must be 'train' or 'validation', got {!r}".format(usage))

        self.usage = usage
        self.sources = list(sources)  # copy so the shared default list is never aliased
        self.percnt_pixels = percnt_pixels
        self.sampling_strategy = sampling_strategy
        self.transform = transform

        self.lbl_dir = Path(root_dir) / country / lbl_fldrname / self.usage
        lbl_fnames = self._list_files(self.lbl_dir, ".tif")
        s1_fnames = self._list_files(Path(root_dir) / country / "Sentinel-1" / self.usage, ".npy")
        s2_fnames = self._list_files(Path(root_dir) / country / "Sentinel-2" / self.usage, ".npy",
                                     must_contain="source")

        assert len(lbl_fnames) == len(s1_fnames) == len(s2_fnames)

        self.lbl = []
        self.lbl_grid_ids = []
        self.s1 = []
        self.s2 = []

        for lbl_fn, s1_fn, s2_fn in tqdm.tqdm(zip(lbl_fnames, s1_fnames, s2_fnames), total = len(lbl_fnames)):

            lbl_grid_id = self._grid_id(lbl_fn, ".tif")
            s1_grid_id = self._grid_id(s1_fn, ".npy")
            s2_grid_id = self._grid_id(s2_fn, ".npy")
            # Check the file pairing before the expensive loads.
            assert lbl_grid_id == s1_grid_id == s2_grid_id

            lbl_array = load_data(lbl_fn, isLabel = True)
            s1_array = load_data(s1_fn, isLabel = False)
            s2_array = load_data(s2_fn, isLabel = False)
            s2_array[np.isinf(s2_array)] = 0  # replace +inf and -inf with 0

            # Label raster is indexed [row, col]; S1/S2 are indexed [:, row, col, :] below.
            assert lbl_array.shape[0] == s1_array.shape[1] == s2_array.shape[1]
            assert lbl_array.shape[1] == s1_array.shape[2] == s2_array.shape[2]

            sample_coordinates = get_pixel_coord(lbl_array, lbl_grid_id, self.percnt_pixels, self.sampling_strategy)

            for coord in sample_coordinates:
                row, col = coord[0], coord[1]
                # Copy so the per-pixel slices do not keep the full chip arrays alive.
                self.lbl.append(lbl_array[row, col].copy())
                self.s1.append(s1_array[:, row, col, :].copy())
                self.s2.append(s2_array[:, row, col, :].copy())
                self.lbl_grid_ids.append(lbl_grid_id)

            del lbl_array, s1_array, s2_array
            gc.collect()

        assert len(self.s1) == len(self.s2)
        print("------{} samples from each of the Sentinel sources are loaded in the {} dataset------".format(len(self.s1),
                                                                                                             self.usage))

    @staticmethod
    def _list_files(folder, suffix, must_contain=None):
        """Return every file under `folder` (recursively) ending in `suffix`, sorted by path."""
        found = [Path(dirpath) / f
                 for dirpath, _, filenames in os.walk(folder)
                 for f in filenames
                 if f.endswith(suffix) and (must_contain is None or must_contain in f)]
        found.sort()
        return found

    @staticmethod
    def _grid_id(fname, suffix):
        """Grid ID = text after the last underscore of the path, with `suffix` removed."""
        return str(fname).split("_")[-1].replace(suffix, "")

    def __getitem__(self, index):
        """Return (s1_img, s2_img, label) as float, float and long tensors."""
        s1_img = self.s1[index]
        s2_img = self.s2[index]
        label = self.lbl[index]

        if (self.usage == "train") and self.transform:
            # Augmentation hook; the augmentations are currently disabled:
            # if random.randint(0, 1) and "shift brightness" in self.transform:
            #     s1_img = shift_tstack_brightness(s1_img, gamma_range=(0.2, 2.0), sample_number=5,
            #                                      patch_shift=False, mode="random")
            # if random.randint(0, 1) and "jitter" in self.transform:
            #     s1_img = jitter_tstack(s1_img, sample_number=5)
            pass

        # numpy to torch
        # tensor shape: (N x C x T)
        s1_img = torch.from_numpy(s1_img.transpose((1, 0))).float()
        s2_img = torch.from_numpy(s2_img.transpose((1, 0))).float()
        label = torch.from_numpy(np.asarray(label)).long()

        return s1_img, s2_img, label

    def __len__(self):
        return len(self.s1)

### Rußwurm Bavarian Crop custom dataset

In [None]:
"""
import torch
import torch.utils.data
import pandas as pd
import os
import numpy as np
from numpy import genfromtxt
import tqdm
"""

BANDS = ['B1', 'B10', 'B11', 'B12', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B9']
NORMALIZING_FACTOR = 1e-4  # band values are multiplied by this on load
PADDING_VALUE = -1         # fill value for padded time steps in both X and y

class BavarianCropsDataset(Dataset):
    """
    Parcel-level Sentinel-2 time series from the BavarianCrops dataset.

    Each sample is one parcel read from ``{root}/csv/{region}/{id}.csv``. Its crop code
    ("nutzcode") is mapped to a class id through the ``classmapping`` CSV. Parsed arrays are
    cached as .npy files under ``{root}/npy/...`` so later runs can skip the CSVs.

    Args:
        root (str): dataset root folder (ids are read from ``ids/``, data from ``csv/``).
        partition (str): one of "train", "test", "trainvalid", "valid".
        classmapping (str): path to a CSV with ``nutzcode``, ``id``, ``classname`` and
            ``klassenname`` columns.
        mode (str or None): "trainvalid" / "traintest" for the random scheme, None for blocks.
        scheme (str): "random" or "blocks".
        region (str): region name used in the id and csv paths.
        samplet (int or None): if an int, each item has exactly ``samplet`` time steps
            (randomly subsampled, or padded with PADDING_VALUE when the sequence is shorter).
            If None, items are padded to the longest sequence length.
        cache (bool): whether an existing .npy cache may be used.
        seed (int): base seed for id shuffling.
        validfraction (float): fraction of train ids held out as "valid" when no test id file is used.
    """

    def __init__(self, root, partition, classmapping, mode=None, scheme="random", region=None, samplet=70,
                 cache=True, seed=0, validfraction=0.1):

        assert (mode in ["trainvalid", "traintest"] and scheme=="random") or (mode is None and scheme=="blocks")
        assert scheme in ["random","blocks"]
        assert partition in ["train","test","trainvalid","valid"]

        self.seed = seed
        self.validfraction = validfraction
        self.scheme = scheme

        # Partition-specific seed so different partitions draw differently but reproducibly.
        seed += sum([ord(ch) for ch in partition])
        np.random.seed(seed)
        torch.random.manual_seed(seed)

        self.mode = mode
        self.root = root

        if scheme=="random":
            if mode == "traintest":
                self.trainids = os.path.join(self.root, "ids", "random", region+"_train.txt")
                self.testids = os.path.join(self.root, "ids", "random", region+"_test.txt")
            elif mode == "trainvalid":
                self.trainids = os.path.join(self.root, "ids", "random", region+"_train.txt")
                self.testids = None

            self.read_ids = self.read_ids_random

        elif scheme=="blocks":
            self.trainids = os.path.join(self.root, "ids", "blocks", region+"_train.txt")
            self.testids = os.path.join(self.root, "ids", "blocks", region+"_test.txt")
            self.validids = os.path.join(self.root, "ids", "blocks", region + "_valid.txt")

            self.read_ids = self.read_ids_blocks

        self.mapping = pd.read_csv(classmapping, index_col=0).sort_values(by="id")
        self.mapping = self.mapping.set_index("nutzcode")
        self.classes = self.mapping["id"].unique()
        self.classname = self.mapping.groupby("id").first().classname.values
        self.klassenname = self.mapping.groupby("id").first().klassenname.values
        self.nclasses = len(self.classes)

        self.region = region
        self.partition = partition
        self.data_folder = "{root}/csv/{region}".format(root=self.root, region=self.region)
        self.samplet = samplet

        print("Initializing BavarianCropsDataset {} partition in {}".format(self.partition, self.region))

        self.cache = os.path.join(self.root,"npy",os.path.basename(classmapping), scheme,region, partition)
        print("read {} classes".format(self.nclasses))

        if cache and self.cache_exists() and not self.mapping_consistent_with_cache():
            self.clean_cache()

        if cache and self.cache_exists() and self.mapping_consistent_with_cache():
            print("precached dataset files found at " + self.cache)
            self.load_cached_dataset()
        else:
            print("no cached dataset found. iterating through csv folders in " + str(self.data_folder))
            self.cache_dataset()

        self.hist, _ = np.histogram(self.y, bins=self.nclasses)
        print("loaded {} samples".format(len(self.ids)))
        print(self)


    def __str__(self):
        return "Dataset {}. region {}. partition {}. X:{}, y:{} with {} classes".format(self.root, self.region,
                                                                                        self.partition,str(len(self.X)) +"x"+
                                                                                        str(self.X[0].shape), self.y.shape,
                                                                                        self.nclasses)

    def read_ids_random(self):
        """
        Return the parcel ids for the random scheme.

        Without a test id file, the train id file is shuffled with ``self.seed`` and split into
        "valid" (first ``validfraction``) and "train" (the rest). With a test id file, the
        "train" and "test" files are read as-is.
        """
        assert isinstance(self.seed, int)
        assert isinstance(self.validfraction, float)
        assert self.partition in ["train", "valid", "test"]
        assert self.trainids is not None
        assert os.path.exists(self.trainids)

        np.random.seed(self.seed)

        # trainids file provided and no testids file: sample a holdback set from trainids.
        if self.testids is None:
            assert self.partition in ["train", "valid"]

            # Adjacent literals instead of a backslash continuation inside the string, which was
            # a SyntaxError when followed by whitespace and otherwise embedded the indentation.
            print("partition {} and no test ids file provided. "
                  "Splitting trainids file in train and valid partitions".format(self.partition))

            with open(self.trainids,"r") as f:
                ids = [int(id) for id in f.readlines()]
            print("Found {} ids in {}".format(len(ids), self.trainids))

            np.random.shuffle(ids)

            validsize = int(len(ids) * self.validfraction)
            validids = ids[:validsize]
            trainids = ids[validsize:]

            print("splitting {} ids in {} for training and {} for validation".format(len(ids),
                                                                                     len(trainids), len(validids)))

            assert len(validids) + len(trainids) == len(ids)

            if self.partition == "train":
                return trainids
            if self.partition == "valid":
                return validids

        elif self.testids is not None:
            assert self.partition in ["train", "test"]

            if self.partition=="test":
                with open(self.testids,"r") as f:
                    test_ids = [int(id) for id in f.readlines()]
                print("Found {} ids in {}".format(len(test_ids), self.testids))
                return test_ids

            if self.partition == "train":
                with open(self.trainids, "r") as f:
                    train_ids = [int(id) for id in f.readlines()]
                return train_ids


    def read_ids_blocks(self):
        """Return the parcel ids for the blocks scheme ("trainvalid" = train ids + valid ids)."""
        assert self.partition in ["train", "valid", "test", "trainvalid"]
        assert os.path.exists(self.validids)
        assert os.path.exists(self.testids)
        assert os.path.exists(self.trainids)
        assert self.scheme == "blocks"
        assert self.mode is None

        def read(filename):
            with open(filename, "r") as f:
                ids = [int(id) for id in f.readlines()]
            return ids

        if self.partition == "train":
            ids = read(self.trainids)
        elif self.partition == "valid":
            ids = read(self.validids)
        elif self.partition == "test":
            ids = read(self.testids)
        elif self.partition == "trainvalid":
            ids = read(self.trainids) + read(self.validids)
        return ids


    def cache_dataset(self):
        """
        Read the CSV of every id in this partition and keep parcels whose crop code is in the
        class mapping. Sets X, y, ids, sequencelengths, ndims, hist and classweights, then
        writes them to the cache.
        """
        ids = self.read_ids()
        assert len(ids) > 0

        self.X = list()
        self.nutzcodes = list()
        self.stats = dict(not_found=list())
        self.ids = list()
        self.samples = list()

        for parcel_id in tqdm.tqdm(ids):

            id_file = self.data_folder + "/{id}.csv".format(id=parcel_id)
            if os.path.exists(id_file):
                self.samples.append(id_file)

                X, nutzcode = self.load(id_file)

                # Parcels left without time steps after NaN removal are skipped.
                if len(nutzcode) > 0:
                    # Only the first row's code is used (presumably constant per parcel).
                    nutzcode = nutzcode[0]
                    if nutzcode in self.mapping.index:
                        self.X.append(X)
                        self.nutzcodes.append(nutzcode)
                        self.ids.append(parcel_id)
            else:
                self.stats["not_found"].append(id_file)

        self.y = self.applyclassmapping(self.nutzcodes)

        self.sequencelengths = np.array([np.array(x).shape[0] for x in self.X])
        assert len(self.sequencelengths) > 0
        self.sequencelength = self.sequencelengths.max()
        self.ndims = np.array(self.X[0]).shape[1]

        # NOTE(review): np.histogram with bins=nclasses spans min(y)..max(y), so bin i only equals
        # class id i when y covers the full id range - confirm for partitions missing classes.
        self.hist,_ = np.histogram(self.y, bins=self.nclasses)
        self.classweights = 1 / self.hist
        self.cache_variables(self.y, self.sequencelengths, self.ids, self.ndims, self.X, self.classweights)


    def mapping_consistent_with_cache(self):
        # Consistency check is disabled: any existing cache is accepted.
        return True
        #return len(np.unique(np.load(os.path.join(self.cache, "y.npy")))) == self.nclasses


    def cache_variables(self, y, sequencelengths, ids, ndims, X, classweights):
        """Write the dataset arrays to ``self.cache`` (created if missing)."""
        os.makedirs(self.cache, exist_ok=True)
        np.save(os.path.join(self.cache, "classweights.npy"), classweights)
        np.save(os.path.join(self.cache, "y.npy"), y)
        np.save(os.path.join(self.cache, "ndims.npy"), ndims)
        np.save(os.path.join(self.cache, "sequencelengths.npy"), sequencelengths)
        np.save(os.path.join(self.cache, "ids.npy"), ids)
        # Sequences have different lengths, so store X as a 1-D object array of per-sample
        # arrays; np.save on a ragged list raises ValueError on NumPy >= 1.24.
        X_obj = np.empty(len(X), dtype=object)
        for i, sample in enumerate(X):
            X_obj[i] = sample
        np.save(os.path.join(self.cache, "X.npy"), X_obj)

    def load_cached_dataset(self):
        """Load the arrays written by cache_variables."""
        self.classweights = np.load(os.path.join(self.cache, "classweights.npy"))
        self.y = np.load(os.path.join(self.cache, "y.npy"))
        self.ndims = int(np.load(os.path.join(self.cache, "ndims.npy")))
        self.sequencelengths = np.load(os.path.join(self.cache, "sequencelengths.npy"))
        self.sequencelength = self.sequencelengths.max()
        self.ids = np.load(os.path.join(self.cache, "ids.npy"))
        self.X = np.load(os.path.join(self.cache, "X.npy"), allow_pickle=True)


    def cache_exists(self):
        """True when all six cache files exist."""
        weightsexist = os.path.exists(os.path.join(self.cache, "classweights.npy"))
        yexist = os.path.exists(os.path.join(self.cache, "y.npy"))
        ndimsexist = os.path.exists(os.path.join(self.cache, "ndims.npy"))
        sequencelengthsexist = os.path.exists(os.path.join(self.cache, "sequencelengths.npy"))
        idsexist = os.path.exists(os.path.join(self.cache, "ids.npy"))
        Xexists = os.path.exists(os.path.join(self.cache, "X.npy"))
        return yexist and sequencelengthsexist and idsexist and ndimsexist and Xexists and weightsexist


    def clean_cache(self):
        """Delete the cache files and the (then empty) cache directories."""
        os.remove(os.path.join(self.cache, "classweights.npy"))
        os.remove(os.path.join(self.cache, "y.npy"))
        os.remove(os.path.join(self.cache, "ndims.npy"))
        os.remove(os.path.join(self.cache, "sequencelengths.npy"))
        os.remove(os.path.join(self.cache, "ids.npy"))
        os.remove(os.path.join(self.cache, "X.npy"))
        os.removedirs(self.cache)


    def load(self, csv_file, load_pandas = False):
        """
        Read one parcel CSV and return (X, nutzcodes) with rows containing NaN bands dropped.

        Expected columns (after the index column):
        ['B1', 'B10', 'B11', 'B12', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8',
         'B8A', 'B9', 'QA10', 'QA20', 'QA60', 'doa', 'label', 'id']

        X holds the 13 bands scaled by NORMALIZING_FACTOR; nutzcodes holds the 'label' column.
        """
        if load_pandas:
            sample = pd.read_csv(csv_file, index_col=0)
            X = np.array((sample[BANDS] * NORMALIZING_FACTOR).values)
            nutzcodes = sample["label"].values

        else:  # load with numpy
            # np.genfromtxt: the bare `genfromtxt` name was never imported. atleast_2d keeps
            # single-row files two-dimensional so the column slicing below works.
            data = np.atleast_2d(np.genfromtxt(csv_file, delimiter=',', skip_header=1))
            X = data[:, 1:14] * NORMALIZING_FACTOR  # column 0 is the index, 1..13 the bands
            nutzcodes = data[:, 18]                 # 'label' column

        # drop time steps that contain NaNs
        rows_with_nans = np.isnan(X).any(axis=1)
        if rows_with_nans.any():
            X = X[~rows_with_nans]
            nutzcodes = nutzcodes[~rows_with_nans]

        return X, nutzcodes


    def applyclassmapping(self, nutzcodes):
        """uses a mapping table to replace nutzcodes (e.g. 451, 411) with class ids"""
        return np.array([self.mapping.loc[nutzcode]["id"] for nutzcode in nutzcodes])


    def __len__(self):
        return len(self.ids)

    @staticmethod
    def _pad_time(X, y, npad):
        """Append `npad` PADDING_VALUE time steps to X (rows) and y."""
        X = np.pad(X, [(0, npad), (0, 0)], 'constant', constant_values=PADDING_VALUE)
        y = np.pad(y, (0, npad), 'constant', constant_values=PADDING_VALUE)
        return X, y

    def __getitem__(self, idx):
        """Return (X float tensor [T, bands], y long tensor [T], parcel id)."""
        load_file = False  # set True to re-read the CSV instead of using the in-memory arrays
        if load_file:
            id = self.ids[idx]
            csvfile = os.path.join(self.data_folder, "{}.csv".format(id))
            X,nutzcodes = self.load(csvfile)
            y = self.applyclassmapping(nutzcodes=nutzcodes)
        else:
            X = self.X[idx]
            y = np.array([self.y[idx]] * X.shape[0])  # repeat y for each time step in X

        t = X.shape[0]

        if self.samplet is None:
            # pad up to the longest sequence length
            X, y = self._pad_time(X, y, self.sequencelengths.max() - t)
        elif t >= self.samplet:
            # random, time-ordered subset of samplet time steps
            idxs = np.random.choice(t, self.samplet, replace=False)
            idxs.sort()
            X = X[idxs]
            y = y[idxs]
        else:
            # Fewer time steps than requested: keep all and pad (sampling without replacement
            # would raise ValueError here).
            X, y = self._pad_time(X, y, self.samplet - t)

        X = torch.from_numpy(X).type(torch.FloatTensor)
        y = torch.from_numpy(y).type(torch.LongTensor)

        return X, y, self.ids[idx]


In [None]:
import torch
import numpy as np
import bisect
import warnings

class ConcatDataset(torch.utils.data.Dataset):
    """
    Dataset to concatenate multiple datasets.
    Purpose: useful to assemble different existing datasets, possibly
    large-scale datasets as the concatenation operation is done in an
    on-the-fly manner.

    Besides indexing, it exposes the metadata the training code reads from the
    member datasets:
        - nclasses, mapping, classes, ndims, classweights, classname, klassenname and
          partition are taken from the first dataset (assumed shared by all members);
        - hist is summed over all datasets;
        - y and sequencelengths are concatenated in sample order;
        - sequencelength is the maximum over all datasets.

    Arguments:
        datasets (iterable): datasets to be concatenated (must not be empty)
    """

    @staticmethod
    def cumsum(sequence):
        """Running totals of len(e), e.g. lengths [2, 3] -> [2, 5]."""
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets):
        super(ConcatDataset, self).__init__()
        # Materialise first so generators/iterables are accepted as well.
        self.datasets = list(datasets)
        assert len(self.datasets) > 0, 'datasets should not be an empty iterable'
        first = self.datasets[0]
        self.nclasses = first.nclasses
        self.mapping = first.mapping
        self.classes = first.classes
        self.ndims = first.ndims
        self.classweights = first.classweights
        self.classname = first.classname
        self.klassenname = first.klassenname
        self.partition = first.partition

        # Per-sample metadata must cover all members, in the same order as the samples
        # (previously only the first dataset's lengths were kept).
        self.sequencelengths = np.concatenate([np.asarray(d.sequencelengths) for d in self.datasets], axis=0)
        self.sequencelength = max(d.sequencelength for d in self.datasets)
        self.hist = np.array([d.hist for d in self.datasets]).sum(0)

        self.y = np.concatenate([d.y for d in self.datasets], axis=0)
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        """Map a global (possibly negative) index to (member dataset, local index)."""
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes

In [None]:
def prepare_dataset(args):
    """
    Build the train and test DataLoaders for the dataset named by ``args.dataset``.

    Supported names: "BavarianCrops", "VNRice", "BreizhCrops", "GAFv2". For the region-based
    datasets, one dataset is built per region (train: ``args.trainregions``, test:
    ``args.testregions``). Each list is wrapped in ConcatDataset.

    Args:
        args: namespace with at least ``dataset``, ``seed``, ``batchsize`` and ``workers``,
            plus the dataset-specific fields read in each branch below.

    Returns:
        (traindataloader, testdataloader): the train loader samples randomly; the test loader
        iterates sequentially.

    Raises:
        ValueError: if ``args.dataset`` is not a supported name.
    """
    if args.dataset == "BavarianCrops":
        root = os.path.join(args.dataroot, "BavarianCrops")

        def make_bavarian(region, partition):
            return BavarianCropsDataset(root=root, region=region, partition=partition,
                                        classmapping=args.classmapping, samplet=args.samplet,
                                        scheme=args.scheme, mode=args.mode, seed=args.seed)

        # Test datasets are built first, as before: each constructor reseeds the global RNGs
        # and prints, so the order is kept.
        test_dataset_list = [make_bavarian(region, args.test_on) for region in args.testregions]
        train_dataset_list = [make_bavarian(region, args.train_on) for region in args.trainregions]

    elif args.dataset == "VNRice":
        train_dataset_list = [VNRiceDataset(root=args.root, partition=args.train_on, samplet=args.samplet,
                                            mode=args.mode, seed=args.seed)]

        test_dataset_list = [VNRiceDataset(root=args.root, partition=args.test_on, samplet=args.samplet,
                                           mode=args.mode, seed=args.seed)]

    elif args.dataset == "BreizhCrops":
        # NOTE(review): hard-coded absolute path from the original author's machine.
        root = "/home/marc/projects/BreizhCrops/data"

        train_dataset_list = [CropsDataset(root=root, region=region, samplet=args.samplet)
                              for region in args.trainregions]
        test_dataset_list = [CropsDataset(root=root, region=region, samplet=args.samplet)
                             for region in args.testregions]

    elif args.dataset == "GAFv2":
        root = os.path.join(args.dataroot, "GAFdataset")

        test_dataset_list = [GAFDataset(root, region=region, partition="test", scheme=args.scheme,
                                        classmapping=args.classmapping, features=args.features)
                             for region in args.testregions]
        train_dataset_list = [GAFDataset(root, region=region, partition="train", scheme=args.scheme,
                                         classmapping=args.classmapping, features=args.features)
                              for region in args.trainregions]

    else:
        # Previously fell through to an UnboundLocalError on train_dataset_list.
        raise ValueError("Unknown dataset {!r}; expected one of 'BavarianCrops', 'VNRice', "
                         "'BreizhCrops', 'GAFv2'".format(args.dataset))

    print("setting random seed to "+str(args.seed))
    np.random.seed(args.seed)
    if args.seed is not None:
        torch.random.manual_seed(args.seed)

    # Samplers are referenced through torch.utils.data: the bare RandomSampler /
    # SequentialSampler names were never imported.
    traindataset = ConcatDataset(train_dataset_list)
    traindataloader = torch.utils.data.DataLoader(dataset=traindataset,
                                                  sampler=torch.utils.data.RandomSampler(traindataset),
                                                  batch_size=args.batchsize, num_workers=args.workers)

    testdataset = ConcatDataset(test_dataset_list)
    testdataloader = torch.utils.data.DataLoader(dataset=testdataset,
                                                 sampler=torch.utils.data.SequentialSampler(testdataset),
                                                 batch_size=args.batchsize, num_workers=args.workers)

    return traindataloader, testdataloader

## Custom Loss functions, Accuracy Metric and Evaluation

In [None]:
"""
import torch
from torch import nn
"""

class BalancedCrossEntropyLoss(nn.Module):
    '''
    Balanced cross entropy loss by weighting of inverse class ratio.

    Class weights are recomputed from every target batch. Each class present in `target`
    (other than ignore_index) gets a weight proportional to 1 / its count, normalised so
    these weights sum to 1. Classes absent from the batch get weight 1e-5.

    Params:
        ignore_index (int): Class index to ignore
        reduction (str): Reduction method to apply, return mean over batch if 'mean',
            return sum if 'sum', return a tensor of shape [N,] if 'none'
    Returns:
        Loss tensor according to arg reduction
    '''

    def __init__(self, ignore_index=-100, reduction='mean'):
        super(BalancedCrossEntropyLoss, self).__init__()
        self.ignore_index = ignore_index
        self.reduction = reduction

    def forward(self, predict, target):
        """predict: logits with classes on dim 1; target: class indices."""
        unique, unique_counts = torch.unique(target, return_counts=True)
        # calculate weights for valid indices only
        valid = unique != self.ignore_index
        unique = unique[valid]
        unique_counts = unique_counts[valid]
        # Computed in float32; the total count cancels out in the normalisation.
        ratio = unique_counts.float() / torch.numel(target)
        weight = (1. / ratio) / torch.sum(1. / ratio)

        # Build the weight vector on predict's device and dtype. The previous hard-coded
        # .cuda() failed on CPU, and float32 weights mismatched non-float32 logits.
        loss_weight = torch.full((predict.shape[1],), 0.00001, dtype=predict.dtype, device=predict.device)
        loss_weight[unique] = weight.to(device=predict.device, dtype=predict.dtype)

        loss = nn.CrossEntropyLoss(weight=loss_weight, ignore_index=self.ignore_index, reduction=self.reduction)
        return loss(predict, target)


## Model architecture

### LSTM 

In [None]:
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import os
"""
# Conv-based Expansive Parametric Spectral Index (CEPSI)
class CEPSI(torch.nn.Module):
    def __init__(self, input_dim, expanded_dim):
        super(CEPSI, self).__init__()

        layers = [nn.Conv1d(input_dim, input_dim, kernel_size = 1, stride = 1, padding = 0, bias=False),
                  nn.BatchNorm1d(input_dim),
                  nn.ReLU(inplace = True),]
        
        layers += [nn.Conv1d(input_dim, expanded_dim, kernel_size = 1, stride = 1, padding = 0, bias=False),
                  nn.BatchNorm1d(expanded_dim),
                  nn.ReLU(inplace = True),]
        
        layers += [nn.Conv1d(expanded_dim, expanded_dim, kernel_size = 1, stride = 1, padding = 0, bias=False),
                  nn.BatchNorm1d(expanded_dim),
                  nn.ReLU(inplace = True),]
        
        self.block = nn.Sequential(*layers)

    def forward(self, inputs):
        return self.block(inputs)

###################################################################################################

# Dot-product attention between Bi-LSTM last states and its output (scores are not scaled).
class attention(nn.Module):
    """
    Dot-product attention of one query vector over a sequence of keys/values.

    forward(q, k, v):
        q: [B, D] query; k: [B, T, D] keys; v: [B, T, Dv] values.
        Returns (output [B, 1, Dv], attn [B, 1, T]). attn is the softmax over T
        followed by dropout.
    """
    def __init__(self, attn_dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v):
        keys_t = k.transpose(2, 1).contiguous()        # [B, D, T]
        scores = torch.bmm(q.unsqueeze(1), keys_t)     # [B, 1, T]
        weights = self.dropout(self.softmax(scores))
        return torch.bmm(weights, v), weights

###########################################################################################

class Double_branch_stacked_biLSTM(torch.nn.Module):
    """
    Two-branch stacked (bi)LSTM classifier for Sentinel-1 (s1) and Sentinel-2 (s2) time series.

    Each branch encodes its own sequence with its own LSTM. Both branches share one
    linear classification head, and the output logits are
    ``s1_weight * s1_logits + (1 - s1_weight) * s2_logits``.

    Inputs are [B, T, C] (the LSTMs are batch_first and the input LayerNorms
    normalize the last, channel axis).

    Args:
        input_dims (tuple | list): (s1 channels, s2 channels).
        hidden_dim (int): LSTM hidden size per direction.
        n_classes (int): number of output classes.
        n_layers (int): stacked LSTM layers per branch.
        dropout_rate (float): dropout between stacked LSTM layers.
        s1_weight (float): weight of the s1 logits in the final combination.
        bidirectional (bool): use bidirectional LSTMs.
        use_layernorm (bool): layer-normalize the inputs and, without attention, the
            flattened final cell states.
        use_batchnorm (bool): stored only; not used by the model.
        use_attention (bool): summarize each branch with dot-product attention instead
            of the flattened final cell states.
        use_cepsi (bool): expand input channels with a CEPSI block (s1 x2, s2 x3).
    Raises:
        TypeError: if input_dims is not a tuple/list with at least two entries.
    """

    def __init__(self, input_dims = (4, 16), hidden_dim = 128, n_classes = 6, n_layers = 4, 
                 dropout_rate = 0.57, s1_weight = 0.8, bidirectional = True, use_layernorm = True, 
                 use_batchnorm = False, use_attention = False, use_cepsi=False):
        super(Double_branch_stacked_biLSTM, self).__init__()

        if not isinstance(input_dims, (tuple, list)) or len(input_dims) < 2:
            raise TypeError("input_dims must be a (s1_channels, s2_channels) tuple or list, "
                            "got {!r}".format(input_dims))

        # Define object properties
        self.n_classes = n_classes
        self.s1_weight = s1_weight
        self.bidirectional = bidirectional
        self.use_layernorm = use_layernorm
        self.use_batchnorm = use_batchnorm  # NOTE: currently unused
        self.use_attention = use_attention
        self.use_cepsi = use_cepsi
        self.model_depth = n_layers * hidden_dim

        s1_in_dim, s2_in_dim = input_dims[0], input_dims[1]

        # Feature width seen by the layer norms and LSTMs (expanded when CEPSI is used).
        s1_feat_dim, s2_feat_dim = s1_in_dim, s2_in_dim
        if self.use_cepsi:
            s1_feat_dim, s2_feat_dim = s1_in_dim * 2, s2_in_dim * 3
            self.s1_cepsi = CEPSI(s1_in_dim, s1_feat_dim)
            self.s2_cepsi = CEPSI(s2_in_dim, s2_feat_dim)

        num_directions = 2 if self.bidirectional else 1

        # Layer normalization for s1, s2 inputs and the flattened final cell states.
        if self.use_layernorm:
            self.s1_inlayernorm = nn.LayerNorm(s1_feat_dim)
            self.s2_inlayernorm = nn.LayerNorm(s2_feat_dim)
            self.clayernorm = nn.LayerNorm(hidden_dim * num_directions * n_layers)

        # LSTM layers for s1 and s2
        self.s1_lstm = nn.LSTM(input_size = s1_feat_dim, hidden_size = hidden_dim, 
                               num_layers = n_layers, bias = False, batch_first = True, dropout = dropout_rate, 
                               bidirectional = self.bidirectional)
        self.s2_lstm = nn.LSTM(input_size = s2_feat_dim, hidden_size = hidden_dim, 
                               num_layers = n_layers, bias = False, batch_first = True, dropout = dropout_rate, 
                               bidirectional = self.bidirectional)

        if self.use_attention:
            self.attention = attention()

        # Linear head shared by both branches.
        branch_dim = hidden_dim * num_directions
        linear_input_dim = branch_dim if self.use_attention else branch_dim * n_layers
        self.linear_class = nn.Linear(linear_input_dim, self.n_classes, bias = True)

    def _summarize_branch(self, outputs, c_n):
        """Reduce one branch's LSTM results to ([B, features], attention weights or None)."""
        if self.use_attention:
            if self.bidirectional:
                # c_n is (num_layers * 2, B, H): index -2 is the last layer's forward state and
                # -1 its backward state. LSTM outputs are [forward, backward] along the feature
                # axis, so the query is concatenated in that same order.
                query = torch.cat([c_n[-2], c_n[-1]], 1)
            else:
                query = c_n[-1]
            # NOTE(review): the query is the final *cell* state while the keys are LSTM outputs
            # (hidden states); kept as in the original design — confirm this is intended.
            summary, weights = self.attention(query, outputs, outputs)
            return summary.squeeze(1), weights

        n_states, batch_size, n_hidden = c_n.shape
        summary = c_n.transpose(0, 1).contiguous().view(batch_size, n_states * n_hidden)
        if self.use_layernorm:
            summary = self.clayernorm(summary)
        return summary, None

    def _logits(self, s1, s2):
        """Return (s1_logits, s2_logits, s1_attn, s2_attn); attention weights are None without attention."""
        if self.use_cepsi:
            # Conv1d wants [B, C, T]; transpose back so the layer norms and LSTMs always see [B, T, C].
            s1 = self.s1_cepsi(s1.transpose(2, 1).contiguous()).transpose(2, 1).contiguous()
            s2 = self.s2_cepsi(s2.transpose(2, 1).contiguous()).transpose(2, 1).contiguous()

        if self.use_layernorm:
            s1 = self.s1_inlayernorm(s1)
            s2 = self.s2_inlayernorm(s2)

        # outputs: [B, T, num_directions * hidden_dim]; c_n: [num_directions * num_layers, B, hidden_dim]
        s1_outputs, (_, s1_c) = self.s1_lstm(s1)
        s2_outputs, (_, s2_c) = self.s2_lstm(s2)

        s1_h, s1_pts = self._summarize_branch(s1_outputs, s1_c)
        s2_h, s2_pts = self._summarize_branch(s2_outputs, s2_c)

        # Logits for each branch, shape [B, num_classes]
        s1_logits = self.linear_class(s1_h)
        s2_logits = self.linear_class(s2_h)

        return s1_logits, s2_logits, s1_pts, s2_pts

    def forward(self, s1, s2):
        """Weighted combination of the two branches' logits, shape [B, num_classes]."""
        s1_logits, s2_logits, _, _ = self._logits(s1, s2)
        return (s1_logits * self.s1_weight) + (s2_logits * (1 - self.s1_weight))
        

## Training and inference procedure (in progress)

In [None]:
"""
from datetime import datetime
from tensorboardX import SummaryWriter
from torch import optim
from torch.optim.lr_scheduler import _LRScheduler
"""

def get_optimizer(optimizer, params, lr, momentum=None):
    """
    Build a torch optimizer by name.

    Args:
        optimizer (str): 'sgd', 'nesterov', 'adam' or 'amsgrad' (case-insensitive).
        params (iterable): parameters or parameter groups to optimize.
        lr (float): learning rate.
        momentum (float | None): momentum for the SGD variants; None is treated as 0.
            Ignored by the Adam variants.
    Returns:
        torch.optim.Optimizer
    Raises:
        ValueError: unsupported optimizer name (torch itself also raises ValueError for
            'nesterov' with zero momentum).
    """
    name = optimizer.lower()
    # ModelCompiler.fit passes momentum=None by default; torch.optim.SGD cannot handle None.
    sgd_momentum = 0 if momentum is None else momentum

    if name == 'sgd':
        return torch.optim.SGD(params, lr, momentum=sgd_momentum)
    elif name == 'nesterov':
        return torch.optim.SGD(params, lr, momentum=sgd_momentum, nesterov=True)
    elif name == 'adam':
        return torch.optim.Adam(params, lr)
    elif name == 'amsgrad':
        return torch.optim.Adam(params, lr, amsgrad=True)
    else:
        raise ValueError("{} currently not supported, please customize your optimizer in compiler.py".format(optimizer))

###############################################################################################

class PolynomialLR(_LRScheduler):
    """Polynomial learning-rate decay until the step count reaches ``max_decay_steps``.

    lr(step) = (base_lr - min_learning_rate) * (1 - step / max_decay_steps) ** power + min_learning_rate
    for step <= max_decay_steps, and ``min_learning_rate`` afterwards.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        max_decay_steps: step after which the learning rate stops decreasing (must be > 1).
        min_learning_rate: learning rate reached at ``max_decay_steps`` and held afterwards.
        power: The power of the polynomial.
    Raises:
        ValueError: if max_decay_steps <= 1.
    """

    def __init__(self, optimizer, max_decay_steps, min_learning_rate=1e-5, power=1.0):
        if max_decay_steps <= 1.:
            raise ValueError('max_decay_steps should be greater than 1.')
        self.max_decay_steps = max_decay_steps
        self.min_learning_rate = min_learning_rate
        self.power = power
        # Must exist before super().__init__, which may call self.step().
        self.last_step = 0
        super().__init__(optimizer)

    def _compute_lrs(self):
        """Learning rates for ``self.last_step`` (one per param group)."""
        if self.last_step > self.max_decay_steps:
            return [self.min_learning_rate for _ in self.base_lrs]

        decay = (1 - self.last_step / self.max_decay_steps) ** (self.power)
        return [(base_lr - self.min_learning_rate) * decay + self.min_learning_rate
                for base_lr in self.base_lrs]

    def get_lr(self):
        return self._compute_lrs()

    def step(self, step=None):
        """Advance to ``step`` (default: last_step + 1) and write the new lr into the optimizer."""
        if step is None:
            step = self.last_step + 1
        # An explicit step of 0 is treated as step 1 (behavior kept from the original implementation).
        self.last_step = step if step != 0 else 1
        # Always apply: past max_decay_steps this clamps to min_learning_rate even if steps were skipped.
        for param_group, lr in zip(self.optimizer.param_groups, self._compute_lrs()):
            param_group['lr'] = lr
        # Keep get_last_lr() in sync, as the base-class step() would.
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

##########################################################################

class ModelCompiler:
    '''
    Compiler of specified model: device placement, optional (pre-trained) parameter
    loading, training, evaluation and saving.

    Attributes:
        model (''nn.Module''): pytorch model; wrapped in nn.DataParallel when CUDA is
            available and gpuDevices is non-empty
        model_name (str): class name of the unwrapped model, used in output file names
        gpuDevices (list): indices of gpu devices to use
        working_dir (str), out_dir (str): output locations; taken from the notebook-level
            `config` dict unless passed explicitly
        s3_client: boto3 S3 client used for s3:// parameter loading and bucket uploads
    '''

    def __init__(self, model, gpuDevices=[0], params_init=None, freeze_params = None,
                 working_dir=None, out_dir=None):
        """
        Args:
            model (nn.Module): model to compile.
            gpuDevices (list): GPU indices; the first one becomes the current device.
                The list is only read, never mutated.
            params_init (str | None): local path or s3:// URL of a state_dict to load.
            freeze_params (iterable of int | None): positions in model.parameters() to freeze.
            working_dir, out_dir (str | None): override config["working_dir"] / config["out_dir"].
        """
        self.s3_client = boto3.client("s3")
        self.working_dir = config["working_dir"] if working_dir is None else working_dir
        self.out_dir = config["out_dir"] if out_dir is None else out_dir
        self.gpuDevices = gpuDevices
        self.model = model

        self.model_name = self.model.__class__.__name__

        # Load before any DataParallel wrapping so state_dict keys carry no "module." prefix.
        if params_init:
            self.load_params(params_init, freeze_params)

        # gpu
        self.gpu = torch.cuda.is_available()
        if self.gpu:
            print("----------GPU available----------")
            if gpuDevices:
                torch.cuda.set_device(gpuDevices[0])
                self.model = torch.nn.DataParallel(self.model, device_ids=gpuDevices)
            self.model = self.model.cuda()

        num_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print("total number of trainable parameters: {:2.1f}M".format(num_params / 1000000))

        if params_init:
            print("---------- Pre-trained model compiled successfully ----------")
        else:
            print("---------- Vanilla Model compiled successfully ----------")

    def load_params(self, dir_params, freeze_params):
        """
        Load a state_dict from a local path or an ``s3://bucket/key`` URL into self.model.

        Only entries whose names exist in the model's state_dict are copied; a leading
        ``module.`` prefix (DataParallel checkpoints) is stripped first. If
        ``freeze_params`` is given, parameters at those positions of
        ``self.model.parameters()`` get ``requires_grad = False``.
        """
        params_init = urlparse.urlparse(dir_params)

        if params_init.scheme == "s3":
            bucket = params_init.netloc
            params_key = params_init.path
            params_key = params_key[1:] if params_key.startswith('/') else params_key
            _, fn_params = os.path.split(params_key)

            self.s3_client.download_file(Bucket=bucket,
                                         Key=params_key,
                                         Filename=fn_params)
            try:
                # Load on the CPU: tensors are moved to the CPU below anyway, and this also
                # works on machines without CUDA.
                inparams = torch.load(fn_params, map_location="cpu")
            finally:
                os.remove(fn_params)  # remove the temporary download even if loading fails
        else:
            inparams = torch.load(dir_params, map_location="cpu")

        model_dict = self.model.state_dict()

        prefix = "module."
        if inparams and next(iter(inparams)).startswith(prefix):
            inparams = {k[len(prefix):]: v for k, v in inparams.items()}
        inparams_filter = {k: v.cpu() for k, v in inparams.items() if k in model_dict}

        model_dict.update(inparams_filter)
        self.model.load_state_dict(model_dict)

        if freeze_params is not None:
            frozen = set(freeze_params)
            for i, p in enumerate(self.model.parameters()):
                if i in frozen:
                    p.requires_grad = False

    def fit(self, trainDataset, valDataset, epochs, optimizer_name, lr_init, LR_policy, criterion, momentum = None):
        """
        Train self.model for `epochs` epochs and log per-epoch train/validation loss to TensorBoard.

        Outputs go to <working_dir>/<out_dir>/<ModelName>_ep<epochs>. That directory also
        becomes the process's current working directory (os.chdir).

        LR_policy: "StepLR" (x0.25 every 10 epochs), "PolynomialLR" (decay over 75 steps to
        1e-5 with power 0.85), anything else -> constant learning rate.
        """
        self.model_dir = str(Path(self.working_dir) / self.out_dir /
                             "{}_ep{}".format(self.model_name, epochs))
        os.makedirs(self.model_dir, exist_ok=True)
        os.chdir(self.model_dir)

        print("--------------- Start training ---------------")
        start = datetime.now()

        # Tensorboard writer setting
        writer = SummaryWriter('./')

        train_loss = []
        val_loss = []

        optimizer = get_optimizer(optimizer_name, self.model.parameters(), lr_init, momentum)

        if LR_policy == "StepLR":
            scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.25)
        elif LR_policy == "PolynomialLR":
            scheduler = PolynomialLR(optimizer, max_decay_steps=75, min_learning_rate=1e-5, power=0.85)
        else:
            scheduler = None

        try:
            for t in range(epochs):
                print("[{}/{}]".format(t + 1, epochs))
                start_epoch = datetime.now()
                train(trainDataset, self.model, criterion, optimizer, gpu=self.gpu, train_loss=train_loss)
                validate(valDataset, self.model, criterion, gpu=self.gpu, val_loss=val_loss)

                if LR_policy == "StepLR":
                    scheduler.step()
                    print("LR: {}".format(scheduler.get_last_lr()))
                elif LR_policy == "PolynomialLR":
                    scheduler.step(t)
                    print("LR: {}".format(optimizer.param_groups[0]['lr']))

                # time spent on single iteration
                print("time:", (datetime.now() - start_epoch).seconds)

                writer.add_scalars("Loss", {"train_loss": train_loss[t], "validation_loss": val_loss[t]}, t + 1)
        finally:
            # Close once, after all epochs (and on failure), so a single event file is written.
            writer.close()

        print("--------------- Training finished in {}s ---------------".format((datetime.now() - start).seconds))

    def accuracy_evaluation(self, evalDataset, outPrefix, bucket = None):
        """Run the module-level `accuracy_evaluation` from <working_dir>/<out_dir> (which becomes the cwd)."""
        eval_dir = Path(self.working_dir) / self.out_dir
        os.makedirs(eval_dir, exist_ok=True)
        os.chdir(eval_dir)

        print("--------------- Start evaluation ---------------")
        start = datetime.now()

        accuracy_evaluation(evalDataset, self.model, self.gpu, outPrefix, bucket)

        print("--------------- Evaluation finished in {}s ---------------".format((datetime.now() - start).seconds))

    def save(self, save_fldr, bucket = None, object = "params"):
        """
        Save the model's state_dict (object="params") or the whole pickled model (object="model").

        Without `bucket`, the file goes to <working_dir>/<out_dir>/<save_fldr>/. With `bucket`,
        the file is written to the current directory, uploaded to S3 under key
        os.path.join(<working_dir>/<out_dir>/<save_fldr>, filename) and then deleted locally.

        Raises:
            ValueError: if `object` is neither "params" nor "model".
        """
        outPrefix = Path(self.working_dir) / self.out_dir / save_fldr

        if object == "params":
            fn = "{}_params.pth".format(self.model_name)
            payload = self.model.state_dict()
            uploaded_msg = "model parameters uploaded to s3!, at "
            local_msg = "model parameters is saved locally, at "
        elif object == "model":
            fn = "{}.pth".format(self.model_name)
            payload = self.model
            uploaded_msg = "model uploaded to s3!, at "
            local_msg = "model saved locally, at "
        else:
            raise ValueError("Object type is not acceptable.")

        if bucket:
            torch.save(payload, fn)
            try:
                self.s3_client.upload_file(Filename=fn,
                                           Bucket=bucket,
                                           Key=os.path.join(outPrefix, fn))
                print(uploaded_msg, outPrefix)
            finally:
                # The temporary file lives in the current directory, not under outPrefix.
                os.remove(fn)
        else:
            os.makedirs(outPrefix, exist_ok=True)
            torch.save(payload, outPrefix / fn)
            print(local_msg, outPrefix)

################################################################################################################
################################### Train, Evaluate, Validate and Predict ######################################
################################################################################################################

def train(trainData, model, criterion, optimizer, gpu=True, train_loss=None):
    """
    Run one training epoch and report the mean per-batch loss.

    Args:
        trainData: iterable yielding (s1_img, s2_img, label) batches.
        model: module called as model(s1_img, s2_img) that returns class logits.
        criterion: a loss class (instantiated once here, e.g. BalancedCrossEntropyLoss)
            or an already-constructed loss callable.
        optimizer: torch optimizer over the model's parameters.
        gpu (bool): move every batch to the current CUDA device.
        train_loss (list | None): if given, the epoch's mean loss is appended to it.
    Raises:
        ValueError: if trainData yields no batches.
    """
    model.train()
    # Instantiate a loss class once per epoch rather than once per batch.
    loss_fn = criterion() if isinstance(criterion, type) else criterion
    epoch_loss = 0.0
    n_batches = 0

    for s1_img, s2_img, label in trainData:
        if gpu:
            s1_img = s1_img.cuda()
            s2_img = s2_img.cuda()
            label = label.cuda()

        loss = loss_fn(model(s1_img, s2_img), label)
        epoch_loss += loss.item()
        n_batches += 1

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if n_batches == 0:
        raise ValueError("train: trainData yielded no batches.")

    mean_loss = epoch_loss / n_batches
    print("train loss: {}".format(mean_loss))
    if train_loss is not None:
        train_loss.append(float(mean_loss))

##################################################

def validate(evalData, model, criterion, gpu=True, val_loss=None):
    """
    Compute the mean per-batch validation loss without tracking gradients.

    NaNs in both inputs are replaced by -100 before the forward pass (train() does not
    do this).

    Args:
        evalData: iterable yielding (s1_img, s2_img, label) batches.
        model: module called as model(s1_img, s2_img) that returns class logits.
        criterion: accepted for signature parity with train() but not used.
            NOTE(review): the loss is always plain nn.CrossEntropyLoss — confirm intended.
        gpu (bool): move every batch to the current CUDA device.
        val_loss (list | None): if given, the mean loss is appended to it.
    Raises:
        ValueError: if evalData yields no batches.
    """
    model.eval()
    loss_fn = nn.CrossEntropyLoss()
    epoch_loss = 0.0
    n_batches = 0

    # No autograd graph is needed for validation; saves memory and time.
    with torch.no_grad():
        for s1_img, s2_img, label in evalData:
            s1_img = s1_img.masked_fill(torch.isnan(s1_img), -100)
            s2_img = s2_img.masked_fill(torch.isnan(s2_img), -100)

            if gpu:
                s1_img = s1_img.cuda()
                s2_img = s2_img.cuda()
                label = label.cuda()

            loss = loss_fn(model(s1_img, s2_img), label)
            epoch_loss += loss.item()
            n_batches += 1

    if n_batches == 0:
        raise ValueError("validate: evalData yielded no batches.")

    mean_loss = epoch_loss / n_batches
    print("validation loss: {}".format(mean_loss))
    if val_loss is not None:
        val_loss.append(float(mean_loss))
        

## Function Calls

In [None]:
#["shift brightness", "jitter"]

# Run configuration consumed by the cells below.
# NOTE(review): paths are hard-coded absolute Windows paths; adjust per machine.
config = {
    
    # Output locations read by ModelCompiler (as self.working_dir / self.out_dir)
    "working_dir" : "C:/My_documents/CropTypeData_Rustowicz/working_folder",
    "out_dir": "try_10",
    # Dataset & Loader (passed to pixelDataset)
    "root_dir" : "C:/My_documents/CropTypeData_Rustowicz",
    "country" : "Ghana",
    "lbl_fldrname" : "Labels",
    "sources" : ["Sentinel-1", "Sentinel-2"],
    "percnt_pixels" : 0.3,
    "sampling_strategy" : "fixed size",
    "transform" : None,
    "batch_train" : 128,   # batch size given to CropTypeBatchSampler2 for the training loader
    "batch_val" : 1,       # batch size of the validation DataLoader
    
    # Model Compiler
    "init_params" : None,  # local path or s3:// URL of a state_dict; None -> no pre-trained params
    "gpus" : [0],          # GPU device indices (first one becomes the current device)
    
    "LSTM_input_dims" : (4, 16),        # (Sentinel-1, Sentinel-2) per-timestep feature sizes
    "LSTM_hidden_dim" : 128,            # hidden size per LSTM direction
    "n_classes": 4,
    "n_LSTM_layers" : 4,
    "LSTM_lyr_dropout_rate" : 0.2,      # dropout between stacked LSTM layers
    "s1_weight" : 0.6,                  # weight of the Sentinel-1 logits; Sentinel-2 gets 1 - s1_weight
    
    # Model fitting
    "epoch" : 150,
    "optimizer" : "nesterov",           # name passed to get_optimizer
    "momentum" : 0.95,                  # used by the SGD variants only
    "criterion" : BalancedCrossEntropyLoss,  # loss class; train() instantiates it
    "lr_init" : 0.01,
    "LR_policy" : "PolynomialLR",       # "StepLR", "PolynomialLR" or anything else for a constant lr
    
    "bucket" : None,                    # S3 bucket for save(); None -> save locally
    "save_fldr": "testing_results",     # sub-folder of working_dir/out_dir used by save()
    # outPrefix for accuracy_evaluation; NOTE: duplicates working_dir/out_dir — keep in sync
    "prefix_out" : "C:/My_documents/CropTypeData_Rustowicz/working_folder/try_10"
}

In [None]:
# Training split of the pixel dataset, fully driven by `config`.
train_dataset = pixelDataset(
    root_dir=config["root_dir"],
    country=config["country"],
    lbl_fldrname=config["lbl_fldrname"],
    usage="train",
    sources=config["sources"],
    percnt_pixels=config["percnt_pixels"],
    sampling_strategy=config["sampling_strategy"],
    transform=config["transform"],
)

In [None]:
# Training loader: batches are formed by CropTypeBatchSampler2 and collated by
# collate_var_length (both defined elsewhere in the notebook).
sampler2 = CropTypeBatchSampler2(train_dataset, batch_size=config["batch_train"])
train_loader = DataLoader(train_dataset, batch_sampler=sampler2, collate_fn=collate_var_length)

In [None]:
# Validation split; same options as the training split except `usage`.
validation_dataset = pixelDataset(
    root_dir=config["root_dir"],
    country=config["country"],
    lbl_fldrname=config["lbl_fldrname"],
    usage="validation",
    sources=config["sources"],
    percnt_pixels=config["percnt_pixels"],
    sampling_strategy=config["sampling_strategy"],
    transform=config["transform"],
)

In [None]:
# NOTE(review): shuffling is not needed for validation; kept to preserve current behavior.
validation_loader = DataLoader(
    validation_dataset,
    batch_size=config["batch_val"],
    shuffle=True,
)

In [None]:
# Two-branch (Sentinel-1 / Sentinel-2) bidirectional LSTM without attention or CEPSI.
lstm_model = Double_branch_stacked_biLSTM(
    input_dims=config["LSTM_input_dims"],
    hidden_dim=config["LSTM_hidden_dim"],
    n_classes=config["n_classes"],
    n_layers=config["n_LSTM_layers"],
    dropout_rate=config["LSTM_lyr_dropout_rate"],
    s1_weight=config["s1_weight"],
    bidirectional=True,
    use_layernorm=True,
    use_batchnorm=False,
    use_attention=False,
    use_cepsi=False,
)

In [None]:
# Wrap the network for device placement, optional parameter loading, training and saving.
model = ModelCompiler(
    model=lstm_model,
    gpuDevices=config["gpus"],
    params_init=config["init_params"],
    freeze_params=None,
)

In [None]:
# Train with the hyper-parameters from `config`.
model.fit(
    trainDataset=train_loader,
    valDataset=validation_loader,
    epochs=config["epoch"],
    optimizer_name=config["optimizer"],
    lr_init=config["lr_init"],
    LR_policy=config["LR_policy"],
    criterion=config["criterion"],
    momentum=config["momentum"],
)

In [None]:
# Evaluate on the validation loader; results are written under `prefix_out` (or to S3 if a bucket is set).
model.accuracy_evaluation(
    evalDataset=validation_loader,
    outPrefix=config["prefix_out"],
    bucket=config["bucket"],
)

In [None]:
# Persist the trained parameters (state_dict).
model.save(
    save_fldr=config["save_fldr"],
    bucket=config["bucket"],
    object="params",
)

## Debugging

In [None]:
# Scratch shapes: (64, 256, 10) and (64, 10, 1); drawn in this order from the global RNG.
tensor1, tensor2 = torch.randn(64, 256, 10), torch.randn(64, 10, 1)

In [None]:
# Scratch shapes mirroring the attention bmm: query (64, 1, 10) and values (64, 10, 256).
tensor1, tensor2 = torch.randn(64, 1, 10), torch.randn(64, 10, 256)

In [None]:
# Batched matrix product of the tensors defined in the previous cell.
tensor3 = tensor1.bmm(tensor2)
tensor3.shape

In [None]:
# Scratch comparison of candidate formulas for the number of negative (class-0) samples,
# assuming 4096 pixels in total.
percnt_pixels = 0.1
total_neg_pixels = 30
total_pos_pixels = 4096 - total_neg_pixels

num_negative_samples2 = math.ceil(
    total_neg_pixels * percnt_pixels
    * abs((total_pos_pixels - total_neg_pixels) / (total_pos_pixels + total_neg_pixels))
)
num_negative_samples = math.ceil(
    (total_neg_pixels * percnt_pixels) * (total_pos_pixels / 4096)
)
num_negative_samples3 = math.ceil(
    (total_neg_pixels * percnt_pixels)
    * (min(total_pos_pixels, total_neg_pixels) / max(total_pos_pixels, total_neg_pixels))
)
num_negative_samples4 = math.ceil((total_neg_pixels * percnt_pixels) * 0.1)

print(total_pos_pixels, num_negative_samples, num_negative_samples2,
      num_negative_samples3, num_negative_samples4, sep="\n")

In [None]:
def Chip_stats(img_tile, percnt_pixels = 1, sampling_strategy = "natural frequency"):
    """
    Sample (row, col) pixel coordinates from a 2-D label chip where 0 is the negative class.

    Negatives are always drawn: 3 random 0-valued pixels when the chip has at least 50 of
    them, otherwise all of them.

    Args:
        img_tile (np.ndarray): 2-D array of integer class labels.
        percnt_pixels (float): in (0, 1) -> per-class sampling governed by `sampling_strategy`;
            exactly 1 -> every non-zero pixel is returned.
        sampling_strategy (str): used only when 0 < percnt_pixels < 1:
            "natural frequency": ceil(count * percnt_pixels) pixels per class;
            "balanced": the smallest class count (including class 0) when the chip has more
                than two labels, otherwise ceil(count * percnt_pixels) for classes with more
                than 100 pixels and all pixels for smaller ones;
            "fixed size": min(count, 10) pixels per class.
    Returns:
        list of (row, col) tuples: sampled crop pixels followed by the negative samples.
    Raises:
        ValueError: unknown sampling strategy, or percnt_pixels outside (0, 1].
    """
    # Negatives are drawn first (same order of RNG consumption as before).
    negative_indices = np.where(img_tile == [0])
    negative_coordinates = list(zip(negative_indices[0], negative_indices[1]))
    total_neg_samples = len(negative_coordinates)
    num_negative_samples = 3 if total_neg_samples >= 50 else total_neg_samples
    neg_samples = random.sample(negative_coordinates, num_negative_samples)

    if 0 < percnt_pixels < 1:
        unique_vals, unique_counts = np.unique(img_tile, return_counts=True)
        smallest_category_count = min(unique_counts)

        sampled_coordinates = []
        for val, count in zip(unique_vals, unique_counts):
            if val == 0:
                continue

            if sampling_strategy == "natural frequency":
                num_samples_per_cat = math.ceil(count * percnt_pixels)
            elif sampling_strategy == "balanced":
                if len(unique_vals) > 2:
                    num_samples_per_cat = smallest_category_count
                elif count > 100:
                    num_samples_per_cat = math.ceil(count * percnt_pixels)
                else:
                    num_samples_per_cat = count
            elif sampling_strategy == "fixed size":
                num_samples_per_cat = min(count, 10)
            else:
                raise ValueError("Sampling strategy is not recognized.")

            crop_indices = np.where(img_tile == [val])
            crop_coordinates = list(zip(crop_indices[0], crop_indices[1]))
            crop_samples = random.sample(crop_coordinates, num_samples_per_cat)
            # Report the number of pixels actually sampled (not the class total).
            print("Number of sampled pixels of crop type {} from chip: {}".format(val, len(crop_samples)))
            sampled_coordinates.extend(crop_samples)

        sampled_coordinates.extend(neg_samples)
        return sampled_coordinates

    if percnt_pixels == 1:
        crop_indices = np.where(img_tile != [0])
        crop_coordinates = list(zip(crop_indices[0], crop_indices[1]))
        print("total number of crop pixels from chip: ", len(crop_coordinates))
        crop_coordinates.extend(neg_samples)
        return crop_coordinates

    raise ValueError("'percnt_pixels' argument is out of range of (0, 1].")