In [None]:
import os
import sys
import random
from random import shuffle
import gc
import math
from pathlib import Path
import time
from datetime import datetime
from collections import Counter, defaultdict, OrderedDict
import urllib.parse as urlparse
import boto3
import shutil
import tqdm
import itertools
from datetime import datetime, timedelta

import numpy as np
from numpy import inf
import pandas as pd
from sklearn import metrics
import rasterio
import pickle
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch import optim
from torch.optim.lr_scheduler import _LRScheduler
import torch.utils.data
from torch.utils.data import Dataset, DataLoader, Sampler
import torch.nn.utils.rnn as rnn_util
from tensorboardX import SummaryWriter

from IPython.core.debugger import set_trace

In [None]:
# Report the software stack so results can be traced back to a specific environment.
print("PyTorch version: {}".format(torch.__version__))
print("Cuda version : {}".format(torch.version.cuda))
print('CUDNN version:', torch.backends.cudnn.version())
print('Number of available GPU Devices:', torch.cuda.device_count())
# torch.cuda.current_device() raises when no usable CUDA device exists, so only query it when available.
if torch.cuda.is_available():
    print("current GPU Device: {}".format(torch.cuda.current_device()))
else:
    print("current GPU Device: None (CUDA not available)")

In [None]:
def make_reproducible(seed = 42, cudnn = True):
    """Seed every random number generator used in this notebook from a shared seed.

    Params:
        seed (int): Seed for Python's ``random``, NumPy's global RNG and PyTorch (CPU and CUDA).
        cudnn (bool): When True, additionally seed all CUDA devices explicitly and force cuDNN into
            deterministic mode (deterministic kernels, no benchmark autotuning), trading speed for
            repeatability. Default True.

    Note: ``PYTHONHASHSEED`` is read only at interpreter start-up, so setting it here affects
    Python processes started afterwards (e.g. spawned DataLoader workers), not the running kernel.
    """
    random.seed(seed)
    np.random.seed(seed)
    # torch.manual_seed seeds the RNG of every device (it is the same function as torch.random.manual_seed).
    torch.manual_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    if cudnn:
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        # Benchmark mode autotunes kernels per input shape and may select non-deterministic algorithms.
        torch.backends.cudnn.benchmark = False

make_reproducible()

## Accuracy Evaluation and Metrics

In [None]:
"""
import numpy as np
import pandas as pd
from sklearn import metrics
"""

class BinaryMetrics:
    
    '''Metrics measuring binary (0/1) classification performance.

    The confusion-matrix counts ``tp``, ``fp``, ``fn`` and ``tn`` are computed once in ``__init__``
    and stored on the instance. Instances support ``+`` and the builtin ``sum`` to pool the samples
    of several evaluations. Count-based metrics whose denominator is zero (metric undefined) return 0.0.

    Note: ``__init__`` stores the confusion-matrix DataFrame in the attribute ``confusion_matrix``,
    which shadows the method of the same name on the instance.
    '''

    def __init__(self, refArray, scoreArray, predArray=None):
        '''
        Params:
            refArray (ndarray): Array of ground truth, values in {0, 1}.
            scoreArray (ndarray): Array of pixel scores of the positive class.
            predArray (ndarray, optional): Array of hard predictions, values in {0, 1}.
                If None, predictions are ``scoreArray > 0.5``.
        Raises:
            Exception: if the flattened arrays differ in size, or the labels/predictions
                contain values outside {0, 1}.
        '''

        self.observation = np.asarray(refArray).flatten()
        self.score = np.asarray(scoreArray).flatten()
        
        if self.observation.shape != self.score.shape:
            raise Exception("Inconsistent size between label and prediction arrays.")
        
        if predArray is not None:
            self.prediction = np.asarray(predArray).flatten()
            if self.prediction.shape != self.observation.shape:
                raise Exception("Inconsistent size between label and prediction arrays.")
        else:
            self.prediction = np.where(self.score > 0.5, 1, 0)

        # The method result replaces the bound method on the instance (kept for backward compatibility).
        self.confusion_matrix = self.confusion_matrix()

        
    def __add__(self, other):
        """
        Pool the samples of two BinaryMetrics instances (self's samples first).
        Params:
            other (''BinaryMetrics''): A BinaryMetrics instance
        Return:
            ''BinaryMetrics'' built from the concatenated observations, scores and predictions.
        """

        return BinaryMetrics(np.append(self.observation, other.observation),
                             np.append(self.score, other.score),
                             np.append(self.prediction, other.prediction))


    def __radd__(self, other):
        """
        Add a BinaryMetrics instance with reversed operands.
        ``0 + metrics`` returns the instance itself, which makes the builtin ``sum`` work.
        Params:
            other: 0 (the start value of ``sum``) or a BinaryMetrics instance.
        Returns:
            ''BinaryMetrics''
        """

        if other == 0:
            return self
        else:
            return self.__add__(other)


    def confusion_matrix(self):
        """
        Count true/false positives/negatives and build the confusion matrix.
        Sets ``self.tp``, ``self.fp``, ``self.fn`` and ``self.tn``.
        Returns:
            ''pandas.DataFrame'' with observation on the rows and prediction on the columns.
        Raises:
            Exception: if the observation or prediction array holds values outside {0, 1}.
        """

        refArray = self.observation
        predArray = self.prediction

        for arr in (refArray, predArray):
            if not np.isin(arr, (0, 1)).all():
                raise Exception("Invalid array")

        # Boolean masks are dtype-safe (an arithmetic encoding such as ref - 2 * pred wraps for unsigned ints).
        ref_pos = refArray == 1
        pred_pos = predArray == 1

        self.tp = np.sum(ref_pos & pred_pos)
        self.fp = np.sum(~ref_pos & pred_pos)
        self.fn = np.sum(ref_pos & ~pred_pos)
        self.tn = np.sum(~ref_pos & ~pred_pos)
        
        confusionMatrix = pd.DataFrame(data = np.array([[self.tn, self.fp],[self.fn, self.tp]]),
                                       index = ['observation = 0', 'observation = 1'],
                                       columns = ['prediction = 0', 'prediction = 1'])

        return confusionMatrix


    @staticmethod
    def _safe_div(numerator, denominator):
        """Return numerator / denominator, or 0.0 when the denominator is zero (metric undefined)."""
        # NumPy integer division by zero returns nan/inf with a warning instead of raising, so check explicitly.
        return numerator / denominator if denominator != 0 else 0.0


    def ir(self):
        """
        Imbalance Ratio (IR): number of positive instances divided by number of negative instances
        of the label. IR = 1 in the balanced case; 0.0 is returned when there are no negative instances.
        Returns:
             float
        """
        return self._safe_div(self.tp + self.fn, self.fp + self.tn)
    
    
    def oa(self):
        """
        Calculate Overall Accuracy.
        Returns:
            float
        """

        oa = metrics.accuracy_score(self.observation, self.prediction)
        
        return oa
    
    
    def producers_accuracy(self):
        """
        Calculate Producer's Accuracy (True Positive Rate | Sensitivity | hit rate | recall).
        Returns:
            float
        """

        return metrics.recall_score(self.observation, self.prediction)

    
    def users_accuracy(self):
        """
        Calculate User's Accuracy (Positive Predictive Value (PPV) | Precision).
        Returns:
            float
        """

        ua = metrics.precision_score(self.observation, self.prediction)
        
        return ua
    
    
    def npv(self):
        """
        Calculate Negative Predictive Value: tn / (tn + fn); 0.0 when nothing is predicted negative.
        Returns:
             float
        """
        return self._safe_div(self.tn, self.tn + self.fn)


    def specificity(self):
        """
        Calculate Specificity aka. True Negative Rate (TNR): tn / (tn + fp); 0.0 when there are no negatives.
        Returns:
             float
        """
        return self._safe_div(self.tn, self.tn + self.fp)

      
    def f1_measure(self):
        """
        Calculate F1 score.
        Returns:
            float
        """

        f1 = metrics.f1_score(self.observation, self.prediction)

        return f1
    
    
    def iou(self):
        """
        Calculate intersection over union (Jaccard index) for the positive class.
        Returns:
            float
        """

        return metrics.jaccard_score(self.observation, self.prediction)
    
    
    def miou(self):
        """
        Calculate mean intersection over union over the negative and positive classes.
        A class whose IoU is undefined (its union is empty) is left out of the mean; 0.0 if both are undefined.
        Returns:
            float
        """
        class_unions = ((self.tn, self.tn + self.fn + self.fp),
                        (self.tp, self.tp + self.fn + self.fp))
        ious = [self._safe_div(inter, union) for inter, union in class_unions if union != 0]
        return float(np.mean(ious)) if ious else 0.0
    
    
    def tss(self):
        """
        Calculate the True Skill Statistic (TSS), also called Bookmaker Informedness (BM):
        recall + specificity - 1. Scale of the metric: [-1, 1].
        An undefined recall or specificity (zero denominator) counts as 0.0.
        Returns:
            float
        """
        recall = self._safe_div(self.tp, self.tp + self.fn)
        return recall + self.specificity() - 1

##################################################    
    
def accuracy_evaluation(eval_data, model, gpu, out_prefix, weights, bucket=None):
    """
    Evaluate a trained model and save per-class binary metrics to ``Metrics.csv``.

    For every class ``n`` in ``1 .. nclass - 1`` (class 0 is not reported) the task is reduced to
    "class n vs. rest" and one ``BinaryMetrics`` is computed over all evaluated samples.

    Params:
        eval_data (''DataLoader''): Yields batches ``(s1_img, s2_img, label)``. NaNs in the images
            are replaced by 0 before they are fed to the model.
        model: Trained model, called as ``model(s1_img, s2_img)`` and returning
            ``(s1_logits, s2_logits, fused_logits)``, each of shape (B, nclass).
        gpu (bool): Whether to move the image batches to the GPU.
        out_prefix (str): Local output directory, or S3 key prefix when ``bucket`` is given.
        weights (sequence of 3 floats): Weights of the s1, s2 and fused logits in the combined logits.
        bucket (str, optional): Name of the S3 bucket to save the metrics to. Default None saves locally.
    """
    
    model.eval()

    # Per class n: lists of per-batch arrays, concatenated once at the end. This produces the same
    # arrays as summing per-sample BinaryMetrics, without the quadratic cost of repeated np.append.
    class_obs = defaultdict(list)
    class_scores = defaultdict(list)
    class_preds = defaultdict(list)

    with torch.no_grad():  # inference only: no autograd graph is needed
        for s1_img, s2_img, label in eval_data:
            # Zero out NaNs without mutating the tensors handed out by the loader.
            s1_img = s1_img.masked_fill(torch.isnan(s1_img), 0)
            s2_img = s2_img.masked_fill(torch.isnan(s2_img), 0)

            if gpu:
                s1_img = s1_img.cuda()
                s2_img = s2_img.cuda()

            s1_model_out, s2_model_out, fused_model_out = model(s1_img, s2_img)
            out_logits = s1_model_out * weights[0] + s2_model_out * weights[1] + fused_model_out * weights[2]
            model_out_prob = F.softmax(out_logits, 1)   # shape=(B, nclass)

            nclass = model_out_prob.size(1)
            probs = model_out_prob.cpu().numpy()
            batch_pred = model_out_prob.max(dim=1)[1].cpu().numpy()
            labels = label.cpu().numpy().reshape(-1)

            for n in range(1, nclass):
                class_obs[n].append(np.where(labels == n, 1, 0))
                class_scores[n].append(probs[:, n])
                class_preds[n].append(np.where(batch_pred == n, 1, 0))

    classes = sorted(class_obs)
    class_metrics = [BinaryMetrics(np.concatenate(class_obs[n]),
                                   np.concatenate(class_scores[n]),
                                   np.concatenate(class_preds[n])) for n in classes]
    
    report = pd.DataFrame({
        "Overal Accuracy" : [m.oa() for m in class_metrics],
        "Producer's Accuracy (recall)" : [m.producers_accuracy() for m in class_metrics],
        "User's Accuracy (precision)" : [m.users_accuracy() for m in class_metrics],
        "Negative Predictive Value" : [m.npv() for m in class_metrics],
        "Specificity (TNR)" : [m.specificity() for m in class_metrics],
        "F1 score" : [m.f1_measure() for m in class_metrics],
        "IoU" : [m.iou() for m in class_metrics],
        "mIoU" : [m.miou() for m in class_metrics],
        "TSS" : [m.tss() for m in class_metrics]
    }, index=["class_{}".format(n) for n in classes])
    
    if bucket:
        metrics_path = f"s3://{bucket}/{out_prefix}/Metrics.csv"
    else:
        metrics_path = Path(out_prefix).joinpath("Metrics.csv")
        Path(out_prefix).mkdir(parents=True, exist_ok=True)
        
    report.to_csv(metrics_path)


## Loading input data

In [None]:
################################### Helper functions for custom Dataset ######################################

def load_data(dataPath, isLabel = False):
    """Read one input array from disk.

    Args:
        dataPath (str) -- Path to a label raster (when ``isLabel``) or to a ``.npy`` image array.
        isLabel (bool) -- When True the file is opened with rasterio and its first band is returned;
            otherwise it is read with ``np.load``. Default is False.

    Returns:
        numpy ndarray with the loaded data.

    Raises:
        ValueError: if a label raster does not have exactly one band.
    """
    if not isLabel:
        return np.load(dataPath)

    with rasterio.open(dataPath, "r") as raster:
        band_count = raster.count
        if band_count != 1:
            raise ValueError("Label must have only 1 band but {} bands were detected.".format(band_count))
        return raster.read(1)

############################################################

def pickle_dataset(dataset, filePath):
    """Serialise ``dataset`` to ``filePath`` with the default pickle protocol, overwriting any existing file."""
    with open(filePath, "wb") as handle:
        pickle.dump(dataset, handle)
#####

def load_dataset(filePath):
    """Load an object previously written with pickle (e.g. by ``pickle_dataset``).

    Security: unpickling can execute arbitrary code; only load files from trusted sources.
    """
    dataset = pd.read_pickle(filePath)
    return dataset

############################################################

def get_test_pixel_coord(img_cube):
    """Return every ``(row, col)`` index pair over axes 1 and 2 of ``img_cube``, in row-major order.

    Args:
        img_cube (ndarray): array with at least 3 dimensions; only ``shape[1]`` and ``shape[2]`` are used.
    Returns:
        list of (int, int) tuples.
    """
    n_rows, n_cols = img_cube.shape[1], img_cube.shape[2]
    return [(row, col) for row in range(n_rows) for col in range(n_cols)]

############################################################

class CropTypeBatchSampler(Sampler):
    """
    Batch sampler that groups samples of similar sequence length.

    Sample indices are shuffled once, stably sorted by the sequence length (``shape[0]``) of the chosen
    image source and cut into consecutive batches, so the samples in a batch need as little zero padding
    as possible to reach equal length. The batches are built once in ``__init__``; every iteration yields
    the same batches, in order of increasing sequence length.
    
    Args:
            dataset (Pytorch dataset): indexable dataset of tuples (s1_img, s2_img, label).
            batch_size (int): Number of samples in a mini-batch; must be >= 1.
            sort_src (str) -- image source used for sorting: "s1" or "s2".
            drop_last (bool) -- Drop the last batch if it is shorter than batch_size. Default False.

    Iterating yields one list of sample indices per batch.
            
    Note 1: Samples in a batch are closest in sequence length only for the chosen image source.
    Note 2: The last batch might be shorter than batch_size if drop_last is False.
    Note 3: Separate padding might be required for both sources using 'collate_fn'.
    """
    
    def __init__(self, dataset, batch_size, sort_src, drop_last=False):
        super(CropTypeBatchSampler, self).__init__(dataset)
        
        assert sort_src in ["s1", "s2"]
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer, got {}".format(batch_size))
        self.batch_size = batch_size

        src_pos = 0 if sort_src == "s1" else 1
        indices_n_lengths = [(i, dataset[i][src_pos].shape[0]) for i in range(len(dataset))]

        # Shuffle before the stable sort so samples of equal length end up in random order.
        shuffle(indices_n_lengths)
        indices_n_lengths.sort(key = lambda x: x[1])
        ordered = [idx for idx, _ in indices_n_lengths]

        self.batches = [ordered[start:start + batch_size] for start in range(0, len(ordered), batch_size)]
        # Only the final batch can be shorter than batch_size.
        if drop_last and self.batches and len(self.batches[-1]) < batch_size:
            self.batches.pop()
    
    def __len__(self):
        return len(self.batches)
    
    def __iter__(self):
        for b in self.batches:
            yield b

############################################################

def collate_var_length(batch):
    """
    Collate ``(s1_img, s2_img, label)`` samples with different sequence lengths into one batch.

    Each image tensor is zero-padded at the end of its first dimension up to the longest sequence
    in the batch; S1 and S2 are padded independently. No truncation is applied.

    Args:
        batch (list of tuples): samples as returned by the dataset's ``__getitem__``.
    Returns:
        tuple (s1_img, s2_img, label): padded tensors of shape (B, max_len_s1, ...) and
        (B, max_len_s2, ...), and the labels stacked along a new first dimension.
    """
    labels = [sample[2] for sample in batch]
    label = torch.stack(labels)

    s1_grids = [sample[0] for sample in batch]
    s2_grids = [sample[1] for sample in batch]

    s1_img = rnn_util.pad_sequence(s1_grids, batch_first=True)
    s2_img = rnn_util.pad_sequence(s2_grids, batch_first=True)

    return s1_img, s2_img, label

In [None]:
class pixelDataset(Dataset):
    """
    Pixel time-series dataset pairing Sentinel-1 and Sentinel-2 samples.

    Args:
            root_dir (str): path to the main folder of the dataset, formatted as indicated in the readme
            usage (str): decide whether we are making a "train", "validation" or "test" dataset.
            num_samples (int) -- Number of samples for each crop type. Required for "train"/"validation".
            sampling_strategy (str) -- "ranked": samples are taken in increasing order of the rank encoded
                                       in each file name (described as the crop pixels with the lowest number
                                       of cloudy days), spread over the tiles (grids) of each crop type.
                                       "random": samples are chosen randomly from all available samples
                                       of each crop type.
            sources (sequence of str): Sensors of image acquisition, (S1 folder, S2 folder).
                                   Default ("Sentinel-1", "Sentinel-2").
            inference_index (int) : Only used at prediction time ("test"): position, in sorted file order,
                                    of the prediction tile to load. Required for "test".
            verbose (bool): Decide to print extra information on-screen.

    Attributes (after construction):
            s1, s2: lists of per-sample arrays; lbl: integer labels ("train"/"validation");
            img_coor, tile_id, meta: pixel coordinates, tile id and tile metadata ("test").
    """
    
    def __init__(self, root_dir, usage, num_samples=None, sampling_strategy="ranked", sources=("Sentinel-1", "Sentinel-2"), 
                 inference_index=None, verbose=False):
        
        self.usage = usage
        self.sources = sources
        self.num_samples = num_samples
        self.sampling_strategy = sampling_strategy
        
        assert self.usage in ["train", "validation", "test"], "Usage can only be one of 'train', 'validation' and 'test'."
        assert self.sampling_strategy in ["ranked", "random"], "Sampling strategy is invalid."
        
        if self.usage in ["train", "validation"]:
            assert num_samples is not None
            self._load_labeled(root_dir, verbose)
        
        if self.usage == "test":
            assert inference_index is not None, "inference_index is required when usage is 'test'."
            self._load_prediction_tile(root_dir, inference_index)

    @staticmethod
    def _find_files(directory, suffix, must_contain=None):
        """Sorted list of paths under ``directory`` (recursive) whose file name ends with ``suffix``
        and, if given, contains ``must_contain``."""
        found = [Path(dirpath) / f for (dirpath, _, filenames) in os.walk(directory)
                 for f in filenames if f.endswith(suffix) and (must_contain is None or must_contain in f)]
        found.sort()
        return found

    def _load_labeled(self, root_dir, verbose):
        """Select sample file pairs for every crop category and load them into ``s1``, ``s2`` and ``lbl``."""
        s1_dir = Path(root_dir).joinpath("Ghana", self.sources[0], self.usage, "categories")
        s2_dir = Path(root_dir).joinpath("Ghana", self.sources[1], self.usage, "categories")
        # os.listdir order is arbitrary; sort it so sample order and random.sample draws are reproducible.
        categories = sorted(name for name in os.listdir(s1_dir) if os.path.isdir(os.path.join(s1_dir, name)))

        s1_samples_ls = []
        s2_samples_ls = []

        for cat in categories:
            s1_fnames = self._find_files(s1_dir.joinpath(cat), ".npy")
            s2_fnames = self._find_files(s2_dir.joinpath(cat), ".npy")
            assert len(s1_fnames) == len(s2_fnames)

            if len(s1_fnames) < self.num_samples:
                s1_samples_ls.extend(s1_fnames)
                s2_samples_ls.extend(s2_fnames)
                print(f"===> Only {len(s1_fnames)} samples are available for {cat} class. All taken.")
            elif self.sampling_strategy == "ranked":
                s1_selected, s2_selected = self._ranked_selection(cat, s1_fnames, s2_fnames, verbose)
                s1_samples_ls.extend(s1_selected)
                s2_samples_ls.extend(s2_selected)
            else:
                random_indices = random.sample(range(len(s1_fnames)), self.num_samples)
                s1_samples_ls.extend(s1_fnames[idx] for idx in random_indices)
                s2_samples_ls.extend(s2_fnames[idx] for idx in random_indices)

        self.lbl = []
        self.s1 = []
        self.s2 = []

        assert len(s1_samples_ls) == len(s2_samples_ls)

        for s1_fn, s2_fn in tqdm.tqdm(zip(s1_samples_ls, s2_samples_ls), total = len(s1_samples_ls)):
            # The label is the last "_"-separated token of the file name and must agree between sources.
            s1_lbl = str(s1_fn).split("_")[-1].replace(".npy", "")
            s2_lbl = str(s2_fn).split("_")[-1].replace(".npy", "")
            assert s1_lbl == s2_lbl

            self.lbl.append(int(s1_lbl))
            self.s1.append(np.load(s1_fn))
            self.s2.append(np.load(s2_fn))

        print("------{} tuple samples of form (s1, s2, lbl) are loaded from the {} dataset------".format(len(self.s1), self.usage))

    def _ranked_selection(self, cat, s1_fnames, s2_fnames, verbose):
        """
        Pick up to ``self.num_samples`` (s1, s2) file pairs of one category by rank.

        File names encode a grid (tile) id at "_"-token -5 and a rank at token -3. Each pass over the grids
        takes, from every grid, the files whose rank lies in the window [i, i + width), then moves the
        window forward by ``step = num_samples // number_of_grids`` (at least 1).
        Returns:
            (list of str, list of str): selected S1 and S2 file paths.
        """
        grids = sorted(set(str(f).split("_")[-5] for f in s1_fnames))
        step = self.num_samples // len(grids)

        if verbose:
            print(f"Category: {cat}, Total number of tiles: {len(grids)} Num samples per ite: {step}")
            print("#####")

        # With more grids than requested samples the step would be 0 and the window would never move.
        step = max(step, 1)

        grid_files = {}
        for grid in grids:
            s1_in_grid = [str(f) for f in s1_fnames if "_" + grid + "_" in str(f)]
            s2_in_grid = [str(f) for f in s2_fnames if "_" + grid + "_" in str(f)]
            assert len(s1_in_grid) == len(s2_in_grid)
            grid_files[grid] = (s1_in_grid, s2_in_grid)
        # Once i reaches the size of the largest grid, every window below is empty.
        largest_grid = max(len(s1_in_grid) for s1_in_grid, _ in grid_files.values())

        def in_window(fnames, start, width):
            return [fn for fn in fnames if start <= int(fn.split("_")[-3]) < start + width]

        s1_selected = []
        s2_selected = []
        i = 1
        counter = 0
        while counter < self.num_samples:
            if i >= largest_grid:
                # Nothing more can be selected; stop instead of looping forever.
                print(f"===> Only {counter} ranked samples could be selected for {cat} class.")
                break
            for grid in grids:
                s1_in_grid, s2_in_grid = grid_files[grid]
                if len(s1_in_grid) > i + step:
                    width = min(step, self.num_samples - counter)
                else:
                    # NOTE(review): the window [i, len(grid)) excludes rank == len(grid); confirm the rank base.
                    width = len(s1_in_grid) - i
                s1_samples = in_window(s1_in_grid, i, width)
                s2_samples = in_window(s2_in_grid, i, width)

                if verbose:
                    print(f"grid: {grid}, counter: {counter}, i: {i}")
                    print(f"S1 samples: {s1_samples}")
                    print("")
                    print(f"S2 samples: {s2_samples}")
                    print("-----")
                s1_selected.extend(s1_samples)
                s2_selected.extend(s2_samples)
                counter += len(s1_samples)
                if counter >= self.num_samples:
                    break
            # Advance the rank window once per pass; without this every pass re-selects the
            # same ranks and duplicates samples.
            i += step

        return s1_selected, s2_selected

    def _load_prediction_tile(self, root_dir, inference_index):
        """Load prediction tile ``inference_index`` (sorted file order) and split it into per-pixel samples."""
        self.s1 = []
        self.s2 = []
        self.img_coor = []

        s1_dir = Path(root_dir).joinpath("prediction_tiles", self.sources[0])
        s1_fnames = self._find_files(s1_dir, ".npy")
        s1_meta_fnames = self._find_files(s1_dir, ".pickle")

        s2_dir = Path(root_dir).joinpath("prediction_tiles", self.sources[1])
        # S2 arrays are additionally filtered on "source" appearing in the file name.
        s2_fnames = self._find_files(s2_dir, ".npy", must_contain="source")
        s2_meta_fnames = self._find_files(s2_dir, ".pickle")

        # Tile ids: last "_" token of the array names, second-to-last token of the metadata names.
        s1_grid_id = str(s1_fnames[inference_index]).split("_")[-1].replace(".npy", "")
        s2_grid_id = str(s2_fnames[inference_index]).split("_")[-1].replace(".npy", "")
        s1_grid_meta_id = str(s1_meta_fnames[inference_index]).split("_")[-2]
        s2_grid_meta_id = str(s2_meta_fnames[inference_index]).split("_")[-2]
        assert s1_grid_id == s2_grid_id == s1_grid_meta_id == s2_grid_meta_id

        self.tile_id = s1_grid_id
        # Metadata of the selected tile (not always the first one). read_pickle can run arbitrary code:
        # only use trusted files.
        self.meta = pd.read_pickle(s1_meta_fnames[inference_index])

        # NOTE(review): the 1e-7 scale factor is not explained here; confirm it matches the training data.
        s1_array = np.load(s1_fnames[inference_index]) * 1e-7
        s2_array = np.load(s2_fnames[inference_index]) * 1e-7

        assert s1_array.shape[1] == s2_array.shape[1]
        assert s1_array.shape[2] == s2_array.shape[2]

        for coord in get_test_pixel_coord(s1_array):
            # One sample per pixel: all of axes 0 and 3 at (coord[0], coord[1]).
            self.s1.append(s1_array[:, coord[0], coord[1], :].copy())
            self.s2.append(s2_array[:, coord[0], coord[1], :].copy())
            self.img_coor.append(coord)

        del s1_array, s2_array
        gc.collect()
            
    def __getitem__(self, index):
        """Return (s1_img, s2_img, label) for train/validation or (s1_img, s2_img, coord) for test;
        images are the stored 2-D arrays transposed and converted to float tensors."""
        
        if self.usage in ["train", "validation"]:
            s1_img = self.s1[index]
            s2_img = self.s2[index]
            label = self.lbl[index]
            
            s1_img = torch.from_numpy(s1_img.transpose((1, 0))).float()
            s2_img = torch.from_numpy(s2_img.transpose((1, 0))).float()
            label = torch.from_numpy(np.asarray(label)).long()
                
            return s1_img, s2_img, label
        
        if self.usage == "test":
            s1_img = self.s1[index]
            s2_img = self.s2[index]
            coord = self.img_coor[index]
            
            s1_img = torch.from_numpy(s1_img.transpose((1, 0))).float()
            s2_img = torch.from_numpy(s2_img.transpose((1, 0))).float()
            
            return s1_img, s2_img, coord
        
    def __len__(self):
        return len(self.s1)

## Custom Loss function

In [None]:
"""
import torch
from torch import nn
"""

class BalancedCrossEntropyLoss(nn.Module):
    '''
    Cross entropy loss weighted by the inverse class frequency of the current batch.

    Each class present in ``target`` (except ``ignore_index``) gets a weight proportional to
    1 / (its frequency), normalised so the weights of the present classes sum to 1. Classes absent
    from the batch get a negligible weight of 1e-5.

    Params:
        ignore_index (int): Class index to ignore
        reduction (str): Reduction method to apply, return mean over batch if 'mean',
            return sum if 'sum', return a tensor of shape [N,] if 'none'
    Returns:
        Loss tensor according to arg reduction
    '''

    def __init__(self, ignore_index=-100, reduction='mean'):
        super(BalancedCrossEntropyLoss, self).__init__()
        self.ignore_index = ignore_index
        self.reduction = reduction

    def forward(self, predict, target):
        '''
        Params:
            predict (Tensor): unnormalised class scores with classes on dim 1.
            target (Tensor): integer class indices, as accepted by ``nn.CrossEntropyLoss``.
        '''
        unique, unique_counts = torch.unique(target, return_counts=True)
        # Weights are computed for valid (non-ignored) classes only.
        valid = unique != self.ignore_index
        unique = unique[valid]
        unique_counts = unique_counts[valid]
        ratio = unique_counts.float() / torch.numel(target)
        weight = (1. / ratio) / torch.sum(1. / ratio)

        # Build the weight vector on the logits' device and dtype (previously hard-coded to CUDA float32).
        lossWeight = torch.full((predict.shape[1],), 0.00001, device=predict.device, dtype=predict.dtype)
        lossWeight[unique] = weight.to(device=lossWeight.device, dtype=lossWeight.dtype)
        loss = nn.CrossEntropyLoss(weight=lossWeight, ignore_index=self.ignore_index, reduction=self.reduction)

        return loss(predict, target)


## Model architecture

### LSTM 

In [None]:
"""
import torch
import torch.nn as nn
import torch.utils.data
"""

# Dot-product attention between Bi-LSTM last states and its output.
class attention(nn.Module):
    """Unscaled dot-product attention of one query vector per sample over a sequence.

    Args:
        attn_dropout (float): dropout probability applied to the attention weights. Default 0.1.
    """

    def __init__(self, attn_dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v):
        """
        Args:
            q: query, shape [B, H].
            k: keys, shape [B, T, H].
            v: values, shape [B, T, H_v].
        Returns:
            (output [B, 1, H_v], attention weights [B, 1, T]).
        """
        scores = torch.bmm(q.unsqueeze(1), k.transpose(2, 1).contiguous())
        weights = self.dropout(self.softmax(scores))
        return torch.bmm(weights, v), weights

###################################################################################################

def get_inp_branch(x):
    """
    Split a per-branch hyper-parameter into its (S1, S2) values.

    Args:
        x: a tuple/list whose first two items are the S1 and S2 values, or a single real number
           (Python or NumPy scalar) used for both branches.
    Returns:
        tuple (br1, br2).
    Raises:
        TypeError: if ``x`` is neither a tuple/list nor a real number.
    """
    import numbers

    if isinstance(x, (tuple, list)):
        return x[0], x[1]
    if isinstance(x, numbers.Real):
        return x, x
    raise TypeError("Expected a number or a (s1, s2) pair, got {}.".format(type(x).__name__))
    

###########################################################################################

class Double_branch_stacked_biLSTM(torch.nn.Module):
    """
    Two-branch (S1 / S2) stacked LSTM classifier with late fusion of the branch logits.

    Each branch runs its own (optionally bidirectional) multi-layer LSTM (``batch_first=True``, so inputs
    are [B, T, input_dim]) and maps a summary of the final LSTM cell states to class logits with a linear
    layer. The summary is either the concatenated final cell states of all layers/directions
    (layer-normalised when ``use_layernorm``), or, with ``use_attention``, an attention-weighted sum of
    the LSTM outputs queried by the last layer's final cell state.

    Args:
        input_dims, hidden_dims, n_layers, dropout_rate: one value shared by both branches or a pair
            ``(s1_value, s2_value)``.
        n_classes (int): number of output classes.
        s1_weight (float): weight of the S1 logits in the fused logits; S2 gets ``1 - s1_weight``.
        bidirectional (bool): use bidirectional LSTMs.
        use_layernorm (bool): layer-normalise the inputs and the concatenated final cell states.
        use_batchnorm (bool): stored on the model but not used.
        use_attention (bool): use attention over the LSTM outputs instead of the concatenated states.
    """
    def __init__(self, input_dims = (4, 11), hidden_dims = (64, 64), n_classes = 4, n_layers = (2, 2), 
                 dropout_rate = (0.35, 0.45), s1_weight = 0.6, bidirectional = True, use_layernorm = True, 
                 use_batchnorm = False, use_attention = False):
        super(Double_branch_stacked_biLSTM, self).__init__()
        
        # Define object properties
        self.n_classes = n_classes
        self.s1_weight = s1_weight
        self.bidirectional = bidirectional
        self.use_layernorm = use_layernorm
        self.use_batchnorm = use_batchnorm  # NOTE: not used anywhere in this model
        self.use_attention = use_attention

        s1_in_dim, s2_in_dim = get_inp_branch(input_dims)
        s1_hidden_dim, s2_hidden_dim = get_inp_branch(hidden_dims)
        s1_n_layers, s2_n_layers = get_inp_branch(n_layers)
        s1_dropout_rate, s2_dropout_rate = get_inp_branch(dropout_rate)
        
        # Layer normalization for s1, s2 inputs and the concatenated final cell states of the LSTMs
        if self.use_layernorm:
            self.s1_inlayernorm = nn.LayerNorm(s1_in_dim)
            self.s1_clayernorm = nn.LayerNorm((s1_hidden_dim + s1_hidden_dim * self.bidirectional) * s1_n_layers)
            
            self.s2_inlayernorm = nn.LayerNorm(s2_in_dim)
            self.s2_clayernorm = nn.LayerNorm((s2_hidden_dim + s2_hidden_dim * self.bidirectional) * s2_n_layers)
        
        # LSTM layers for s1 and s2
        self.s1_lstm = nn.LSTM(input_size = s1_in_dim, hidden_size = s1_hidden_dim, 
                               num_layers = s1_n_layers, bias = False, batch_first = True, dropout = s1_dropout_rate, 
                               bidirectional = self.bidirectional)
        
        self.s2_lstm = nn.LSTM(input_size = s2_in_dim, hidden_size = s2_hidden_dim, 
                               num_layers = s2_n_layers, bias = False, batch_first = True, dropout = s2_dropout_rate, 
                               bidirectional = self.bidirectional)
        
        if self.bidirectional:
            s1_hidden_dim = s1_hidden_dim * 2
            s2_hidden_dim = s2_hidden_dim * 2
        
        if self.use_attention:
            self.attention = attention()
        
        # Linear classifier on top of each LSTM branch
        s1_linear_input_dim = s1_hidden_dim if self.use_attention else s1_hidden_dim * s1_n_layers
        self.s1_linear_class = nn.Linear(s1_linear_input_dim, self.n_classes, bias = True)
        
        s2_linear_input_dim = s2_hidden_dim if self.use_attention else s2_hidden_dim * s2_n_layers
        self.s2_linear_class = nn.Linear(s2_linear_input_dim, self.n_classes, bias = True)

    def _last_layer_state(self, c):
        """Final cell state of the last LSTM layer, [B, num_directions * hidden_dim],
        ordered (forward, backward) like the LSTM outputs."""
        if self.bidirectional:
            # States are laid out as (layer, direction): -2 is the last layer's forward direction, -1 its backward.
            return torch.cat([c[-2], c[-1]], 1)
        return c[-1]

    @staticmethod
    def _flatten_states(c):
        """[num_layers * num_directions, B, H] -> [B, num_layers * num_directions * H]."""
        n_states, batch_size, n_hidden = c.shape
        return c.transpose(0, 1).contiguous().view(batch_size, n_states * n_hidden)

    def _logits(self, s1, s2):
        """Return (s1_logits, s2_logits, s1_attention, s2_attention); attentions are None without attention."""
        if self.use_layernorm:
            s1 = self.s1_inlayernorm(s1)
            s2 = self.s2_inlayernorm(s2)
        
        # outputs: [B, T, num_directions * hidden_dim]; h, c: [num_layers * num_directions, B, hidden_dim]
        s1_outputs, (s1_h, s1_c) = self.s1_lstm(s1)
        s2_outputs, (s2_h, s2_c) = self.s2_lstm(s2)
        
        if self.use_attention:
            s1_query = self._last_layer_state(s1_c)
            s2_query = self._last_layer_state(s2_c)
            s1_h, s1_pts = self.attention(s1_query, s1_outputs, s1_outputs)
            s2_h, s2_pts = self.attention(s2_query, s2_outputs, s2_outputs)
            s1_h = s1_h.squeeze(1)
            s2_h = s2_h.squeeze(1)
        else:
            s1_h = self._flatten_states(s1_c)
            s2_h = self._flatten_states(s2_c)
            # The state layer norms exist only when use_layernorm is set.
            if self.use_layernorm:
                s1_h = self.s1_clayernorm(s1_h)
                s2_h = self.s2_clayernorm(s2_h)
            s1_pts = None
            s2_pts = None
        
        # Logits for each branch. Shape: [B, num_classes]
        s1_logits = self.s1_linear_class(s1_h)
        s2_logits = self.s2_linear_class(s2_h)
        
        return s1_logits, s2_logits, s1_pts, s2_pts
    
    def forward(self, s1, s2):
        """Return (s1_logits, s2_logits, fused_logits), fused = s1_weight * s1 + (1 - s1_weight) * s2."""
        s1_logits, s2_logits, s1_pts, s2_pts = self._logits(s1, s2)
        fused_logits = (s1_logits * self.s1_weight) + (s2_logits * (1 - self.s1_weight))
        
        return s1_logits, s2_logits, fused_logits

## Training and Inference procedure

In [None]:
"""
from datetime import datetime
from tensorboardX import SummaryWriter
from torch import optim
from torch.optim.lr_scheduler import _LRScheduler
"""

def get_optimizer(optimizer, params, lr, momentum):
    """Build a torch optimizer from its (case-insensitive) name.

    Params:
        optimizer (str): one of 'sgd', 'nesterov', 'adam', 'amsgrad'.
        params (iterable): parameters (or param groups) to optimize.
        lr (float): initial learning rate.
        momentum (float or None): momentum for the SGD variants; ``None`` is treated as 0.
            Ignored by the Adam variants.

    Returns:
        torch.optim.Optimizer

    Raises:
        ValueError: if ``optimizer`` is not a supported name.
    """
    name = optimizer.lower()
    # ModelCompiler.fit() defaults momentum to None, which torch's SGD rejects
    # (it compares momentum against 0.0); treat None as "no momentum".
    sgd_momentum = 0 if momentum is None else momentum

    if name == 'sgd':
        return torch.optim.SGD(params, lr, momentum=sgd_momentum)
    elif name == 'nesterov':
        return torch.optim.SGD(params, lr, momentum=sgd_momentum, nesterov=True)
    elif name == 'adam':
        return torch.optim.Adam(params, lr)
    elif name == 'amsgrad':
        return torch.optim.Adam(params, lr, amsgrad=True)
    else:
        raise ValueError("{} currently not supported, please customize your optimizer in compiler.py".format(optimizer))

###############################################################################################

class PolynomialLR(_LRScheduler):
    """Polynomial learning-rate decay until ``max_decay_steps``, constant afterwards.

    For step ``s`` (a requested step of 0 is treated as 1) each param group's learning rate is

        (base_lr - min_learning_rate) * (1 - s / max_decay_steps) ** power + min_learning_rate

    while ``s <= max_decay_steps``, and ``min_learning_rate`` once ``s`` exceeds it.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        max_decay_steps: after this step, we stop decreasing the learning rate.
        min_learning_rate: learning rate reached at ``max_decay_steps`` and kept afterwards.
        power: The power of the polynomial.
    """

    def __init__(self, optimizer, max_decay_steps, min_learning_rate=1e-5, power=1.0):
        if max_decay_steps <= 1.:
            raise ValueError('max_decay_steps should be greater than 1.')
        self.max_decay_steps = max_decay_steps
        self.min_learning_rate = min_learning_rate
        self.power = power
        # Must exist before the base constructor, which calls step() once.
        self.last_step = 0
        super().__init__(optimizer)

    def get_lr(self):
        """Learning rates for ``self.last_step``, one per param group."""
        if self.last_step > self.max_decay_steps:
            return [self.min_learning_rate for _ in self.base_lrs]

        decay = (1 - self.last_step / self.max_decay_steps) ** self.power
        return [(base_lr - self.min_learning_rate) * decay + self.min_learning_rate
                for base_lr in self.base_lrs]

    def step(self, step=None):
        """Move to ``step`` (default: previous step + 1; 0 is treated as 1) and apply its lr."""
        if step is None:
            step = self.last_step + 1
        self.last_step = step if step != 0 else 1
        # Always apply get_lr(): previously no update happened past max_decay_steps, so
        # jumping from an earlier step straight past it left the lr above the minimum.
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr
        # The base class' get_last_lr() reads _last_lr; without this it raised AttributeError.
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

##########################################################################

class ModelCompiler:
    '''
    Compiler of specified model: device placement, parameter loading/freezing, training,
    evaluation, inference and saving for a model whose forward returns
    ``(s1_logits, s2_logits, fused_logits)``.

    Args:
        model (''nn.Module'') -- pytorch model to train / run.
        working_dir (sys.path or str) -- path to the working directory.
        out_dir (sys.path or str) -- path, relative to working_dir, of the directory to store
                                     output prediction and associated files.
        gpuDevices (tuple) -- indices of gpu devices to use. Note: the default ``(0)`` is the int 0
                              (falsy), so by default the model is moved to the current CUDA device
                              without DataParallel; pass e.g. ``[0]`` to wrap it in DataParallel.
        br_weights (tuple) -- weights of the three branches (s1, s2, fused) used in the losses
                              and to blend logits at inference.
        params_init (sys.path or str) -- local path or ``s3://bucket/key`` of saved parameters to load.
        freeze_params (list of int) -- indices (in ``model.parameters()`` order) of parameters whose
                                       gradients are frozen. Useful in finetuning the model.

    NOTE(review): fit/accuracy_evaluation/inference change the process working directory, so a
    relative working_dir only resolves correctly until the first of them runs.
    '''

    def __init__(self, model, working_dir, out_dir, gpuDevices=(0), 
                 br_weights=(0.3, 0.3, 0.4), params_init=None, freeze_params=None):

        self.s3_client = boto3.client("s3")
        self.working_dir = working_dir
        self.out_dir = out_dir
        self.gpuDevices = gpuDevices
        self.br_weights = br_weights
        self.model = model

        # Taken before any DataParallel wrapping, so it is the user model's class name.
        self.model_name = self.model.__class__.__name__

        # Load before wrapping so the state_dict keys carry no "module." prefix.
        if params_init:
            self.load_params(params_init, freeze_params)

        # gpu
        self.gpu = torch.cuda.is_available()
        if self.gpu:
            print("----------GPU available----------")
            # GPU setting
            if gpuDevices:
                torch.cuda.set_device(gpuDevices[0])
                self.model = torch.nn.DataParallel(self.model, device_ids=gpuDevices)
            self.model = self.model.cuda()

        num_params = sum([p.numel() for p in self.model.parameters() if p.requires_grad])
        print("total number of trainable parameters: {:2.1f}M".format(num_params / 1000000))

        if params_init:
            print("---------- Pre-trained model compiled successfully ----------")
        else:
            print("---------- Vanilla Model compiled successfully ----------")


    def load_params(self, dir_params, freeze_params):
        '''
        Load (possibly partial) parameters into ``self.model`` and optionally freeze some.

        Params:
            dir_params (str or Path): local path or ``s3://bucket/key`` of a saved state_dict.
                Entries whose names are not in the model's state_dict are ignored.
            freeze_params (list of int or None): indices into ``self.model.parameters()`` whose
                ``requires_grad`` is set to False.
        '''
        params_init = urlparse.urlparse(str(dir_params))
        # load from s3
        if params_init.scheme == "s3":

            bucket = params_init.netloc
            params_key = params_init.path
            params_key = params_key[1:] if params_key.startswith('/') else params_key
            _, fn_params = os.path.split(params_key)

            self.s3_client.download_file(Bucket=bucket,
                                         Key=params_key,
                                         Filename=fn_params)
            try:
                # Load on CPU: every tensor is moved to CPU below anyway, and this works without
                # a GPU and without indexing gpuDevices (an int with the default argument).
                inparams = torch.load(fn_params, map_location="cpu")
            finally:
                os.remove(fn_params)  # remove the temporary download even if loading fails

        ## or load from local
        else:
            inparams = torch.load(dir_params, map_location="cpu")

        ## overwrite model entries with new parameters
        model_dict = self.model.state_dict()

        # Checkpoints saved from a DataParallel-wrapped model prefix every key with "module.".
        # Strip that prefix per key; the old substring test on the first key also truncated
        # names that merely contained "module" (e.g. "submodule.weight").
        prefix = "module."
        inparams_filter = {}
        for k, v in inparams.items():
            name = k[len(prefix):] if k.startswith(prefix) else k
            if name in model_dict:
                inparams_filter[name] = v.cpu()

        model_dict.update(inparams_filter)
        self.model.load_state_dict(model_dict)

        if freeze_params is not None:
            frozen = set(freeze_params)
            for i, p in enumerate(self.model.parameters()):
                if i in frozen:
                    p.requires_grad = False


    @staticmethod
    def _build_scheduler(optimizer, LR_policy):
        '''Return the LR scheduler named by ``LR_policy``, or None for any other value.'''
        if LR_policy == "StepLR":
            return optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.85)
        if LR_policy == "Exponential":
            return optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.85)
        if LR_policy == "PolynomialLR":
            return PolynomialLR(optimizer, max_decay_steps=100, min_learning_rate=1e-5, power=0.9)
        return None


    def fit(self, trainDataset, valDataset, epochs, optimizer_name, lr_init, LR_policy, criterion, momentum=None):
        '''
        Train the model for ``epochs`` epochs, validating after every epoch.

        Params:
            trainDataset, valDataset: iterables of ``(s1, s2, label)`` batches (e.g. DataLoaders).
            epochs (int): number of epochs.
            optimizer_name (str): name understood by ``get_optimizer``.
            lr_init (float): initial learning rate.
            LR_policy (str): "StepLR", "Exponential" or "PolynomialLR"; anything else means no scheduler.
            criterion: loss factory (e.g. ``nn.CrossEntropyLoss``) or a (train, validation) pair of them.
            momentum (float or None): momentum for the SGD-family optimizers.

        Side effects: creates ``self.model_dir``, makes it the current working directory and
        writes TensorBoard logs there.
        '''
        self.model_dir = "{}/{}/{}_ep{}".format(self.working_dir, self.out_dir, self.model_name, epochs)
        # self.model_dir already starts with working_dir/out_dir; joining it onto them again
        # (as before) duplicated that prefix whenever working_dir was a relative path.
        os.makedirs(self.model_dir, exist_ok=True)
        os.chdir(self.model_dir)

        print("--------------- Start training ---------------")
        start = datetime.now()

        # Tensorboard writer setting
        writer = SummaryWriter('./')
        try:
            train_loss = []
            val_loss = []

            optimizer = get_optimizer(optimizer_name, self.model.parameters(), lr_init, momentum)
            scheduler = self._build_scheduler(optimizer, LR_policy)

            if isinstance(criterion, (tuple, list)):
                train_criterion, val_criterion = criterion[0], criterion[1]
            else:
                train_criterion = val_criterion = criterion

            for t in range(epochs):

                print("[{}/{}]".format(t + 1, epochs))
                start_epoch = datetime.now()
                train(trainDataset, self.model, train_criterion, optimizer, self.br_weights,
                      gpu=self.gpu, train_loss=train_loss)
                validate(valDataset, self.model, val_criterion, self.br_weights,
                         gpu=self.gpu, val_loss=val_loss)

                # Update the scheduler
                if LR_policy in ["StepLR", "Exponential"]:
                    scheduler.step()
                    print("LR: {}".format(scheduler.get_last_lr()))
                elif LR_policy == "PolynomialLR":
                    # torch's scheduler constructor already applied step 1, so advance one step
                    # per epoch. The former step(t) mapped t=0 and t=1 back to step 1, keeping
                    # the first three epochs at the same learning rate.
                    scheduler.step()
                    print("LR: {}".format(optimizer.param_groups[0]['lr']))

                # time spent on single iteration
                print("time:", (datetime.now() - start_epoch).seconds)

                writer.add_scalars("Loss", {"train_loss": train_loss[t], "validation_loss": val_loss[t]}, t + 1)
        finally:
            # Close once, after training (previously closed inside the loop every epoch).
            writer.close()

        print("--------------- Training finished in {}s ---------------".format(
            int((datetime.now() - start).total_seconds())))

    def accuracy_evaluation(self, evalDataset, outPrefix, bucket=None):
        '''
        Evaluate the model on ``evalDataset`` with the module-level ``accuracy_evaluation``,
        after making ``working_dir/out_dir`` (created if needed) the current working directory.
        '''
        eval_dir = Path(self.working_dir) / self.out_dir
        os.makedirs(eval_dir, exist_ok=True)
        os.chdir(eval_dir)

        print("--------------- Start evaluation ---------------")
        start = datetime.now()

        accuracy_evaluation(evalDataset, self.model, self.gpu, outPrefix, self.br_weights, bucket)

        print("--------------- Evaluation finished in {}s ---------------".format(
            int((datetime.now() - start).total_seconds())))

    def inference(self, predDataset, out_prefix=None):
        '''
        Predict one tile and write its rasters under ``out_prefix`` (default
        ``working_dir/out_dir/Inference_output``): hard labels in ``HardScore/``,
        soft scores in ``SoftProb/``.

        Params:
            predDataset (tuple): ``(loader, meta, tile_id)`` as consumed by the module-level ``inference``.
            out_prefix (str or Path, optional): output root directory.
        '''
        print("-------------------------- Start Inference(Test) --------------------------")

        start = datetime.now()
        if out_prefix is None:
            out_prefix = Path(self.working_dir) / self.out_dir / "Inference_output"

        # Absolute paths: they are used after the chdir below, where relative ones would be
        # resolved against the new working directory.
        out_prefix = Path(os.path.abspath(out_prefix))
        prefix_hard = out_prefix / "HardScore"
        prefix_soft = out_prefix / "SoftProb"

        os.makedirs(prefix_hard, exist_ok=True)
        os.makedirs(prefix_soft, exist_ok=True)

        os.chdir(out_prefix)

        inference(predDataset, self.model, prefix_soft, prefix_hard, gpu=self.gpu, weights=self.br_weights)

        duration_in_sec = int((datetime.now() - start).total_seconds())
        duration_format = str(timedelta(seconds = duration_in_sec))
        print("-------------------------- Inference finished in {}s --------------------------".format(duration_format))

    def save(self, save_fldr, bucket=None, save_object="params"):
        '''
        Save the model's state_dict (``save_object="params"``) or the whole pickled model
        (``save_object="model"``).

        Without ``bucket`` the file goes to ``working_dir/out_dir/save_fldr``. With ``bucket`` it is
        staged in the current directory, uploaded with key ``os.path.join(outPrefix, file_name)``,
        then the staged copy is deleted.

        Raises:
            ValueError: for any other ``save_object``.
        '''
        outPrefix = Path(self.working_dir) / self.out_dir / save_fldr

        if save_object == "params":
            fname = "{}_params.pth".format(self.model_name)
            payload = self.model.state_dict()
            msg_s3 = "model parameters uploaded to s3!, at "
            msg_local = "model parameters is saved locally, at "
        elif save_object == "model":
            fname = "{}.pth".format(self.model_name)
            payload = self.model
            msg_s3 = "model uploaded to s3!, at "
            msg_local = "model saved locally, at "
        else:
            raise ValueError("Object type is not acceptable.")

        if bucket:
            torch.save(payload, fname)
            try:
                self.s3_client.upload_file(Filename=fname,
                                           Bucket=bucket,
                                           Key=os.path.join(outPrefix, fname))
                print(msg_s3, outPrefix)
            finally:
                # The staged file lives in the cwd; the old code removed outPrefix/fname instead,
                # which does not exist and raised FileNotFoundError.
                os.remove(fname)
        else:
            os.makedirs(outPrefix, exist_ok=True)
            torch.save(payload, outPrefix / fname)
            print(msg_local, outPrefix)

################################################################################################################
################################### Train, Evaluate, Validate and Predict ######################################
################################################################################################################

def train(trainData, model, criterion, optimizer, weights, gpu=True, train_loss=None):
    """Run one training epoch over ``trainData`` and report the mean batch loss.

    Params:
        trainData (iterable): yields ``(s1_img, s2_img, label)`` batches (e.g. a DataLoader).
        model (nn.Module): ``model(s1, s2)`` returns ``(s1_logits, s2_logits, fused_logits)``.
        criterion (callable): loss factory; ``criterion()`` must return a loss module
            (e.g. ``nn.CrossEntropyLoss``). It is instantiated once per epoch.
        optimizer (torch.optim.Optimizer): optimizer over ``model``'s parameters.
        weights (sequence): weights of the s1, s2 and fused branch losses (indices 0, 1, 2).
        gpu (bool): move each batch to the current CUDA device.
        train_loss (list or None): if given, the epoch's mean loss is appended to it
            (NaN when ``trainData`` yields no batches, so per-epoch indexing stays aligned).
    """
    model.train()
    # NOTE(review): unlike validate(), this loss does not use ignore_index=0 — confirm whether
    # label 0 is meant to contribute to the training loss.
    loss_fn = criterion()
    epoch_loss = 0.0
    n_batches = 0

    for s1_img, s2_img, label in trainData:
        if gpu:
            s1_img = s1_img.cuda()
            s2_img = s2_img.cuda()
            label = label.cuda()

        s1_out, s2_out, fused_out = model(s1_img, s2_img)
        total_loss = (loss_fn(s1_out, label) * weights[0]
                      + loss_fn(s2_out, label) * weights[1]
                      + loss_fn(fused_out, label) * weights[2])
        epoch_loss += total_loss.item()
        n_batches += 1

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

    mean_loss = epoch_loss / n_batches if n_batches else float("nan")
    print("train loss: {}".format(mean_loss))
    if train_loss is not None:
        train_loss.append(float(mean_loss))

##################################################

def validate(evalData, model, criterion, weights, gpu=True, val_loss=None):
    """Compute the mean weighted validation loss over ``evalData`` without updating the model.

    Params:
        evalData (iterable): yields ``(s1_img, s2_img, label)`` batches.
        model (nn.Module): ``model(s1, s2)`` returns ``(s1_logits, s2_logits, fused_logits)``.
        criterion (callable): loss factory accepting ``ignore_index``; ``criterion(ignore_index=0)``
            must return a loss module (e.g. ``nn.CrossEntropyLoss``), so label 0 is excluded.
        weights (sequence): weights of the s1, s2 and fused branch losses (indices 0, 1, 2).
        gpu (bool): move each batch to the current CUDA device.
        val_loss (list or None): if given, the mean loss is appended (NaN for an empty ``evalData``).

    NOTE(review): with a mean-reduced loss such as nn.CrossEntropyLoss, a batch whose labels are
    all 0 yields NaN (nothing left to average), which then makes the epoch mean NaN — easy to hit
    with batch size 1. Confirm whether validation batches can contain label 0.
    """
    # Left in eval mode; train() switches back with model.train() at the start of its epoch.
    model.eval()
    loss_fn = criterion(ignore_index=0)
    epoch_loss = 0.0
    n_batches = 0

    # Gradients are never used here; no_grad avoids building autograd graphs.
    with torch.no_grad():
        for s1_img, s2_img, label in evalData:
            if gpu:
                s1_img = s1_img.cuda()
                s2_img = s2_img.cuda()
                label = label.cuda()

            s1_out, s2_out, fused_out = model(s1_img, s2_img)
            total_loss = (loss_fn(s1_out, label) * weights[0]
                          + loss_fn(s2_out, label) * weights[1]
                          + loss_fn(fused_out, label) * weights[2])
            epoch_loss += total_loss.item()
            n_batches += 1

    mean_loss = epoch_loss / n_batches if n_batches else float("nan")
    print("validation loss: {}".format(mean_loss))
    if val_loss is not None:
        val_loss.append(float(mean_loss))

##################################################

def inference(testData, model, score_path, pred_path, gpu, weights):
    """Predict every sampled pixel of one tile and write the results as rasters.

    Params:
        testData (tuple): ``(loader, meta, tile_id)``. ``loader`` yields ``(s1_img, s2_img, coor)``
            where ``coor[0][i], coor[1][i]`` give the row/column of sample ``i``; ``meta`` is a
            rasterio profile with at least 'height' and 'width'; ``tile_id`` names the outputs.
        model (nn.Module): ``model(s1, s2)`` returns ``(s1_logits, s2_logits, fused_logits)``.
        score_path: directory for the soft-score raster ``prob_<tile_id>.tif``.
        pred_path: directory for the hard-label raster ``crisp_<tile_id>.tif``.
        gpu (bool): move batches to the current CUDA device.
        weights (sequence): weights combining the s1, s2 and fused logits (indices 0, 1, 2).

    Outputs:
        crisp_<tile_id>.tif: 1 band, uint8 argmax class per pixel (0 where no sample was given).
        prob_<tile_id>.tif: one float32 band per class 1..nclass-1 (band k = class k), holding
            softmax probability * 100. Previously each class was written in turn to this same
            single-band file, so only the last class survived and class 1 was never written.
    """
    loader, meta, tile_id = testData

    meta_hard = meta.copy()
    meta_hard.update({'dtype': 'uint8',
                      'nodata': None,
                      'count': 1,
                     })

    meta_soft = meta_hard.copy()
    meta_soft.update({
        "dtype": "float32"
    })

    name_prob = "prob_{}.tif".format(tile_id)
    name_crisp = "crisp_{}.tif".format(tile_id)

    model.eval()

    h_canvas = np.zeros((1, meta_hard['height'], meta_hard['width']), dtype=meta_hard["dtype"])
    # One plane per class 1..nclass-1; allocated on the first batch, once nclass is known.
    soft_canvas = None

    # Prediction needs no gradients; no_grad avoids building autograd graphs.
    with torch.no_grad():
        for s1_img, s2_img, coor in loader:
            if gpu:
                s1_img = s1_img.cuda()
                s2_img = s2_img.cuda()

            s1_out, s2_out, fused_out = model(s1_img, s2_img)
            pred_logits = s1_out * weights[0] + s2_out * weights[1] + fused_out * weights[2]
            pred_prob = F.softmax(pred_logits, 1)

            # Per-sample class and class-1.. scores for the whole batch, computed once. The old
            # code wrote the whole batch's predictions into each pixel (only valid for batch 1).
            hard = pred_prob.max(dim=1)[1].cpu().numpy()
            soft = (pred_prob[:, 1:].cpu().numpy() * 100).astype(np.float32)

            if soft_canvas is None:
                soft_canvas = np.zeros((soft.shape[1], meta_soft['height'], meta_soft['width']),
                                       dtype=meta_soft["dtype"])

            for i in range(hard.shape[0]):
                row, col = int(coor[0][i]), int(coor[1][i])
                h_canvas[0, row, col] = hard[i]
                soft_canvas[:, row, col] = soft[i]

    with rasterio.open(Path(pred_path) / name_crisp, "w", **meta_hard) as dst:
        dst.write(h_canvas)

    # Nothing to write for an empty loader or a single-class model.
    if soft_canvas is not None and soft_canvas.shape[0] > 0:
        meta_soft["count"] = soft_canvas.shape[0]
        with rasterio.open(Path(score_path) / name_prob, "w", **meta_soft) as dst:
            dst.write(soft_canvas)
            

## Function Calls

In [None]:
config = {
    
    # NOTE(review): absolute, machine-specific paths — adjust per environment.
    "working_dir" : "C:/My_documents/CropTypeData_Rustowicz/working_folder",
    "out_dir": "new_try_01",
    # Dataset & Loader
    "root_dir" : "C:/My_documents/CropTypeData_Rustowicz/CropType",
    "sampling_strategy" : "ranked",
    "lbl_fldrname" : "Labels",
    "sources" : ["Sentinel-1", "Sentinel-2"],
    "num_train_pixels" : 15000,
    "num_validation_pixels" : 10000,
    "test_label" : False,
    "batch_train" : 128,
    "batch_val" : 1,
    
    # Model Compiler
    "init_params" : None,
    "gpus" : [0],
    "input_dims" : (4, 11),
    "LSTM_hidden_dim" : (48, 64),
    "CNN_hidden_dim" : (48, 64),
    "CNN_kernel_size" : (5, 5),
    "CNN_sequence_length" : (57, 67),
    "n_classes": 4,
    "n_LSTM_layers" : (2, 4),
    "LSTM_lyr_dropout_rate" : (0.4, 0.5),
    "CNN_lyr_dropout_rate" : (0.25, 0.45),
    "s1_weight" : 0.5,
    
    # Model fitting
    "epoch" : 75,
    "optimizer" : "amsgrad",
    "momentum" : 0.95,
    "criterion" : nn.CrossEntropyLoss,
    # Loss/blending weights of the (s1, s2, fused) branches: exactly three values.
    # (Was "(0.3, 0,3, 0.4)" — a 4-tuple giving s2 weight 0 and fused weight 3.)
    "branch_weights" : (0.3, 0.3, 0.4),
    "lr_init" : 0.01,
    # "StepLR", "Exponential" or "PolynomialLR"; any other value disables LR scheduling.
    "LR_policy" : "",
    
    "bucket" : None,
    "save_fldr": "model_path",
    "prefix_out" : None
}

assert len(config["branch_weights"]) == 3, "branch_weights needs one weight per branch (s1, s2, fused)"

### For training and validation

In [None]:
# Build the training pixel dataset from the .npy source files.
train_dataset = pixelDataset(
    root_dir=config["root_dir"],
    usage="train",
    sources=config["sources"],
    sampling_strategy=config["sampling_strategy"],
    num_samples=config["num_train_pixels"],
    verbose=False,
)

# To cache it as a pickle for faster reloads, uncomment:
# filePath = Path(config["working_dir"]) / config["out_dir"] / "train_dataset.pickle"
# pickle_dataset(train_dataset, filePath)

In [None]:
# Load the training dataset from its pickle cache when one exists; otherwise keep the dataset
# built in the previous cell (the unguarded load crashed on a fresh run without a cache).
# NOTE: unpickling can execute arbitrary code — only load caches you created yourself.
filePath = Path(config["working_dir"]) / config["out_dir"] / "train_dataset.pickle"
if filePath.exists():
    train_dataset = load_dataset(filePath)
else:
    print("No cached training dataset at {}; using the freshly built one.".format(filePath))

In [None]:
# Batch the training pixels with the project's CropTypeBatchSampler (sort_src="s1") and
# collate_var_length (presumably handles variable-length sequences — see its definition).
sampler = CropTypeBatchSampler(
    train_dataset,
    batch_size=config["batch_train"],
    sort_src="s1",
    drop_last=False,
)
train_loader = DataLoader(train_dataset, collate_fn=collate_var_length, batch_sampler=sampler)

# Alternative without the dedicated sampler:
# train_loader = DataLoader(train_dataset, batch_size=config["batch_train"], shuffle=True, collate_fn=collate_var_length)

In [None]:
# Sanity check of what a training batch looks like. Off by default (the former triple-quoted
# string was just echoed as cell output); set to True to print the first batch only.
RUN_SANITY_CHECK = False
if RUN_SANITY_CHECK:
    for s1, s2, lbl in DataLoader(train_dataset, batch_sampler=sampler, collate_fn=collate_var_length):
        print(s1.shape)
        print(s1[1,:,:])
        print("---")
        print(s2.shape)
        print(s2[1,:,:])
        print("---")
        print(lbl)
        break  # one batch is enough; otherwise every batch of the dataset is printed

In [None]:
# Build the validation pixel dataset from the .npy source files.
validation_dataset = pixelDataset(
    root_dir=config["root_dir"],
    usage="validation",
    sources=config["sources"],
    sampling_strategy=config["sampling_strategy"],
    num_samples=config["num_validation_pixels"],
    verbose=False,
)

# To cache it as a pickle for faster reloads, uncomment:
# filePath = Path(config["working_dir"]) / config["out_dir"] / "validation_dataset.pickle"
# pickle_dataset(validation_dataset, filePath)

In [None]:
# Load the validation dataset from its pickle cache when one exists; otherwise keep the dataset
# built in the previous cell (the unguarded load crashed on a fresh run without a cache).
# NOTE: unpickling can execute arbitrary code — only load caches you created yourself.
filePath = Path(config["working_dir"]) / config["out_dir"] / "validation_dataset.pickle"
if filePath.exists():
    validation_dataset = load_dataset(filePath)
else:
    print("No cached validation dataset at {}; using the freshly built one.".format(filePath))

In [None]:
# Validation loader: config["batch_val"] samples per batch, shuffled, same collate function.
validation_loader = DataLoader(
    validation_dataset,
    collate_fn=collate_var_length,
    shuffle=True,
    batch_size=config["batch_val"],
)

In [None]:
# Two-branch stacked bidirectional LSTM with layer norm and without attention.
lstm_model = Double_branch_stacked_biLSTM(
    input_dims=config["input_dims"],
    hidden_dims=config["LSTM_hidden_dim"],
    n_layers=config["n_LSTM_layers"],
    dropout_rate=config["LSTM_lyr_dropout_rate"],
    n_classes=config["n_classes"],
    s1_weight=config["s1_weight"],
    bidirectional=True,
    use_attention=False,
    use_layernorm=True,
    use_batchnorm=False,
)

In [None]:
# Wrap the model: device placement, optional pre-trained weights and branch loss weights.
model = ModelCompiler(
    model=lstm_model,
    working_dir=config["working_dir"],
    out_dir=config["out_dir"],
    params_init=config["init_params"],
    freeze_params=None,
    gpuDevices=config["gpus"],
    br_weights=config["branch_weights"],
)

In [None]:
# Train for config["epoch"] epochs, validating after each one.
model.fit(trainDataset=train_loader,
          valDataset=validation_loader,
          epochs=config["epoch"],
          optimizer_name=config["optimizer"],
          lr_init=config["lr_init"],
          LR_policy=config["LR_policy"],
          criterion=config["criterion"],
          momentum=config["momentum"])

In [None]:
# Evaluate on the validation set. The branch weights are the compiler's br_weights (set from
# config["branch_weights"] when compiling), so no `weights` argument is passed here.
model.accuracy_evaluation(validation_loader,
                          outPrefix=config["prefix_out"],
                          bucket=config["bucket"])

In [None]:
# Persist the trained parameters (state_dict) under working_dir/out_dir/save_fldr,
# or upload them when a bucket is configured.
model.save(save_object="params",
           save_fldr=config["save_fldr"],
           bucket=config["bucket"])

### For Prediction

In [None]:
# Re-create the same architecture for prediction (must match the trained model's config).
lstm_model = Double_branch_stacked_biLSTM(
    input_dims=config["input_dims"],
    hidden_dims=config["LSTM_hidden_dim"],
    n_layers=config["n_LSTM_layers"],
    dropout_rate=config["LSTM_lyr_dropout_rate"],
    n_classes=config["n_classes"],
    s1_weight=config["s1_weight"],
    bidirectional=True,
    use_attention=False,
    use_layernorm=True,
    use_batchnorm=False,
)

In [None]:
# Parameters written by model.save(save_fldr=..., save_object="params"), derived from the
# config rather than a hard-coded copy of the same path.
config["init_params"] = (Path(config["working_dir"]) / config["out_dir"] / config["save_fldr"]
                         / "Double_branch_stacked_biLSTM_params.pth").as_posix()

In [None]:
# Wrap the prediction model, loading the trained parameters from config["init_params"].
model = ModelCompiler(
    model=lstm_model,
    working_dir=config["working_dir"],
    out_dir=config["out_dir"],
    params_init=config["init_params"],
    freeze_params=None,
    gpuDevices=config["gpus"],
    br_weights=config["branch_weights"],
)

In [None]:
def load_data_pred(usage, item):
    """Build the prediction inputs for tile number ``item``.

    Params:
        usage (str): dataset split passed to ``pixelDataset`` (callers pass "test"; it was
            previously ignored in favour of a hard-coded "test").
        item (int): tile index, passed as ``inference_index``.

    Returns:
        tuple: ``(data_loader, meta, tile_id)``, the shape the module-level ``inference`` unpacks.
    """
    dataset = pixelDataset(root_dir=config["root_dir"],
                           usage=usage,
                           sources=config["sources"],
                           inference_index=item)
    data_loader = DataLoader(dataset,
                             batch_size=config["batch_val"],
                             shuffle=False)

    return data_loader, dataset.meta, dataset.tile_id

In [None]:
# Run inference tile by tile: one tile per .npy file of the first configured source.
# (config[root_dir] used an undefined name and raised NameError.)
prediction_dir = Path(config["root_dir"]).joinpath("prediction_tiles", config["sources"][0])
s1_fnames = [Path(dirpath) / f for (dirpath, dirnames, filenames) in os.walk(prediction_dir) for f in filenames if f.endswith(".npy")]
tile_count = len(s1_fnames)
if tile_count == 0:
    print("No .npy tiles found under {}".format(prediction_dir))

for i in range(tile_count):
    pred_data = load_data_pred("test", i)
    model.inference(pred_data, config["prefix_out"])