In [None]:
# Colab only: mount Google Drive at /content/gdrive so files stored on Drive
# can be read from this runtime.
from google.colab import drive
drive.mount("/content/gdrive")

In [None]:
# rasterio is required by load_data() below to open raster files.
!pip install rasterio

In [None]:
import os, sys, copy, time, math, random, numbers, itertools, tqdm, importlib, re
import numpy as np
import numpy.ma as ma
import pandas as pd
import matplotlib.pyplot as plt
#import cv2
import rasterio
import torch

from sklearn import metrics
from skimage import transform as trans
from pathlib import Path
from collections.abc import Sequence
from datetime import datetime, timedelta
from scipy.ndimage import rotate
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch import optim
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter

from IPython.core.debugger import set_trace

%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
#@title Utilities

def load_data(data_path, is_label=False, apply_normalization=False, dtype=np.float32, verbose=False):
    r"""
    Read a raster file with rasterio and return its pixels as a numpy array.

    Arguments:
            data_path (string) -- Full path, including file name, of the raster to read.
            is_label (binary) -- If True the raster is a ground-truth layer (category
                                 indices) and must have exactly one band; if False it
                                 is treated as a stack of reflectance bands.
            apply_normalization (binary) -- If True, image bands are passed through
                                            do_normalization with bounds (0, 1) and
                                            clip_val=1. Ignored for labels.
            dtype (np.dtype) -- Output dtype of image bands. Labels keep the file's dtype.
            verbose (binary) -- If True, print the path that is being loaded.

    Returns:
            np.ndarray -- For a label, band 1 as an (H, W) array. Otherwise all bands
                          as a (C, H, W) array (rasterio's band-first layout) cast
                          to ``dtype``.

    Raises:
            ValueError -- If ``is_label`` is True and the raster does not have exactly
                          one band.
    """
    if verbose:
        print('loading file:{}'.format(data_path))

    with rasterio.open(data_path, "r") as src:
        if is_label:
            if src.count != 1:
                raise ValueError("Expected Label to have exactly one channel.")
            # Category indices are returned untouched (no cast, no normalization).
            return src.read(1)
        bands = src.read()

    if apply_normalization:
        bands = do_normalization(bands, bounds=(0, 1), clip_val=1)
    return bands.astype(dtype)


def make_deterministic(seed=None, cudnn=True):
    """
    Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed (int, optional): Seed value. If None (default), a seed is derived
            from the current time and process id. Such a run is only
            reproducible if the returned seed is recorded and passed back in.
        cudnn (bool): If True (default), also seed every CUDA device and make
            cuDNN deterministic. Benchmark auto-tuning is disabled as well,
            because it may select different kernels between runs.

    Returns:
        int: The seed that was applied, so it can be logged and reused.
    """
    if seed is None:
        # np.random.seed only accepts values in [0, 2**32), so fold the derived seed.
        seed = (int(time.time()) + int(os.getpid())) % (2 ** 32)

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Only affects subprocesses started after this point; the running
    # interpreter's hash randomization was fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    if cudnn:
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    return seed

In [None]:
#@title Input Normalization

def do_normalization(img, normal_strategy="min_max", bounds=(0, 1), clip_val=None):
    """
    Normalize each band of an image stack independently.

    Args:
        img (np.ndarray): Stacked image bands with shape (C, H, W). NaNs are
            ignored when computing statistics.
        normal_strategy (str): 'min_max' linearly rescales each band into
            ``bounds``; 'z_value' standardizes each band to zero mean and unit
            standard deviation (``bounds`` is not used).
        bounds (tuple or list): (lower, upper) target range for 'min_max'.
            Default is (0, 1).
        clip_val (float, optional): Percentage cut from each tail before
            'min_max' rescaling. For example, 1 clips every band to its own
            [1st, 99th] percentile range. Default is None (no clipping).

    Returns:
        np.ndarray: Normalized image stack of shape (C, H, W). Integer input
        is promoted to float64 before processing.

    Raises:
        ValueError: If ``normal_strategy`` or ``bounds`` is invalid.

    Notes:
        - Common bounds for satellite image processing are (0, 1) or (0, 255).
        - All statistics (percentiles, min/max, mean/std) are computed per
          band and for each image tile separately.
        - A constant band (zero range or zero std) maps to ``lower`` for
          'min_max' or to 0 for 'z_value', instead of NaN/inf.
    """
    if normal_strategy not in ["min_max", "z_value"]:
        raise ValueError("Normalization strategy is not recognized.")

    if not isinstance(bounds, (tuple, list)) or len(bounds) != 2:
        raise ValueError("Normalization bounds should be a tuple or list of length 2.")

    lower_bound, upper_bound = map(float, bounds)

    # Integer rasters (e.g. uint16 reflectance) are promoted so they can be
    # clipped against float percentiles and scaled without truncation.
    if not np.issubdtype(img.dtype, np.floating):
        img = img.astype(np.float64)

    if normal_strategy == "min_max":
        if clip_val is not None:
            # Per-band percentiles, shaped (C, 1, 1) to broadcast over H and W.
            low = np.nanpercentile(img, clip_val, axis=(1, 2), keepdims=True)
            high = np.nanpercentile(img, 100 - clip_val, axis=(1, 2), keepdims=True)
            img = np.clip(img, low, high)

        # Statistics are taken AFTER clipping so the clipped range is stretched
        # over the whole target interval.
        band_min = np.nanmin(img, axis=(1, 2), keepdims=True)
        band_range = np.nanmax(img, axis=(1, 2), keepdims=True) - band_min
        safe_range = np.where(band_range > 0, band_range, 1)
        return lower_bound + (upper_bound - lower_bound) * (img - band_min) / safe_range

    band_mean = np.nanmean(img, axis=(1, 2), keepdims=True)
    band_std = np.nanstd(img, axis=(1, 2), keepdims=True)
    safe_std = np.where(band_std > 0, band_std, 1)
    return (img - band_mean) / safe_std

In [None]:
#@title Image Augmentation

def center_rotate(img, label, degree):
    r"""
    Synthesize a new (image, label) pair by rotating both chips about their center.

    Implemented with ``scipy.ndimage.rotate`` (imported at the top of the
    notebook as ``rotate``), so it does not depend on OpenCV.

    Arguments:
            img (ndarray) -- Stacked image bands with shape (H, W, C).
            label (ndarray) -- Ground-truth layer with shape (H, W).
            degree (tuple or list) -- With exactly two elements, the rotation angle is
                                      drawn uniformly from [degree[0], degree[1]].
                                      With more than two, one element is chosen at
                                      random. Positive angles rotate counter-clockwise.
    Returns:
        img -- Rotated image with the input's shape (bilinear interpolation; pixels
               rotated in from outside the chip are 0).
        label -- Rotated label with the input's shape and dtype (nearest neighbour,
                 so class ids are never blended; pixels from outside are 0).

    Raises:
        ValueError -- On wrong input types/ranks, or fewer than two ``degree`` values.
    """
    if not isinstance(img, np.ndarray) or not isinstance(label, np.ndarray):
        raise ValueError("img and label must be numpy arrays.")
    if img.ndim != 3:
        raise ValueError("img must have dimensions (H, W, C).")
    if label.ndim != 2:
        raise ValueError("label must have dimensions (H, W).")
    if not isinstance(degree, (tuple, list)):
        raise ValueError("Degree must be a tuple or a list.")

    if len(degree) == 2:
        rotation_degree = random.uniform(degree[0], degree[1])
    elif len(degree) > 2:
        rotation_degree = random.choice(degree)
    else:
        raise ValueError("Parameter degree needs at least two elements.")

    # axes=(1, 0) rotates in the (H, W) plane, so every band of the 3-D image is
    # rotated identically. reshape=False keeps the chip size. order=1 (bilinear)
    # is used for the image; order=0 keeps the label's class ids intact.
    img = rotate(img, rotation_degree, axes=(1, 0), reshape=False,
                 order=1, mode='constant', cval=0)
    label = rotate(label, rotation_degree, axes=(1, 0), reshape=False,
                   order=0, mode='constant', cval=0)

    return img, label


# use scipy package if there is issue installing opencv
def rotate_image_and_label(image, label, angle):
    """
    Apply the same random rotation to an image patch and its label.

    Both arrays are rotated in the plane of their first two axes, so ``image``
    is expected as (H, W[, C]) and ``label`` as (H, W). The chip size is kept.
    The area rotated in from outside the chip is filled by reflection for both
    arrays, so image and label stay aligned in the corners.

    Args:
        image (numpy array): The input image patch.
        label (numpy array): The corresponding label.
        angle (list or tuple of floats): With exactly two elements they are the
            lower and upper bounds (degrees) of a uniformly drawn angle. With
            more than two, one value is chosen at random as the angle.

    Returns:
        A tuple (rotated_image, rotated_label) of new numpy arrays. The label
        is resampled by nearest neighbour, so it only contains values that
        exist in the input label.

    Raises:
        ValueError: If ``angle`` is not a tuple/list or has fewer than two elements.
    """
    if not isinstance(angle, (tuple, list)):
        raise ValueError(
            "Rotation bound param for augmentation must be a tuple or list."
        )
    if len(angle) == 2:
        rotation_degree = random.uniform(angle[0], angle[1])
    elif len(angle) > 2:
        rotation_degree = random.choice(angle)
    else:
        raise ValueError("Parameter angle needs at least two elements.")

    # Image: scipy's default cubic spline interpolation.
    rotated_image = rotate(input=image, angle=rotation_degree, axes=(1, 0),
                           reshape=False, mode='reflect')

    # Label: order=0, because splines would blend class ids into non-existent
    # classes. The same boundary mode as the image keeps the corners aligned.
    rotated_label = rotate(input=label, angle=rotation_degree, axes=(1, 0),
                           reshape=False, order=0, mode='reflect')

    return rotated_image.copy(), rotated_label.copy()


def flip(img, label, flip_type):
    r"""
    Synthesize a new (image, label) pair by flipping both chips the same way.

    Arguments:
            img (ndarray) -- Image chip with shape (H, W, C).
            label (ndarray) -- Ground truth with shape (H, W).
            flip_type (str) -- One of:
                                   'h_flip' -- reverse the row order (axis 0, upside-down)
                                   'v_flip' -- reverse the column order (axis 1, left-right)
                                   'd_flip' -- swap rows and columns (main-diagonal transpose)
    Returns:
            img -- Flipped image as a new array; shape (W, H, C) for 'd_flip'.
            label -- Flipped label as a new array; shape (W, H) for 'd_flip'.

    Raises:
            ValueError -- On wrong input types/ranks or an unknown ``flip_type``.
    """
    if not isinstance(img, np.ndarray) or not isinstance(label, np.ndarray):
        raise ValueError("img and label must be numpy arrays.")
    if img.ndim != 3:
        raise ValueError("img must have dimensions (H, W, C).")
    if label.ndim != 2:
        raise ValueError("label must have dimensions (H, W).")
    if not isinstance(flip_type, str):
        raise ValueError("Flip type must be a string.")

    if flip_type == "h_flip":
        img = np.flip(img, 0)
        label = np.flip(label, 0)

    elif flip_type == "v_flip":
        img = np.flip(img, 1)
        label = np.flip(label, 1)

    elif flip_type == "d_flip":
        # Swap only the spatial axes; the channel axis must stay last.
        img = np.transpose(img, axes=(1, 0, 2))
        label = np.transpose(label, axes=(1, 0))

    else:
        raise ValueError("Unsupported flip type. Valid options are: 'h_flip', 'v_flip', 'd_flip'.")

    # np.flip / np.transpose return views; copy so callers get independent arrays.
    return img.copy(), label.copy()


def re_scale(img, label, scale=(0.75, 1.5), crop_strategy="center"):
    r"""
    Synthesize a new (image, label) pair by zooming both chips by a random factor.

    A factor is drawn uniformly from ``scale``, and both chips are resized by it
    with the aspect ratio preserved. They are then brought back to the original
    (H, W): along each spatial axis, the resized chip is cropped if it is larger
    than the original and zero-padded if it is smaller.

    Arguments:
            img (ndarray) -- Image chip with shape (H, W, C).
            label (ndarray) -- Reference annotation layer with shape (H, W).
            scale (tuple or list) -- (min, max) range of the zoom factor.
            crop_strategy (str) -- 'center' places the crop/pad window at the center;
                                   'random' places it at a random offset.
    Returns:
           Tuple[np.ndarray, np.ndarray] including:
            resampled_img -- Zoomed image of shape (H, W, C) with the dtype of ``img``.
            resampled_label -- Zoomed label of shape (H, W) with the dtype of ``label``.
                               It is resampled by nearest neighbour, so class ids are
                               never blended.

    Raises:
            ValueError -- On wrong input types/ranks.
            Exception -- If ``scale`` is not a sequence.
            AssertionError -- If ``crop_strategy`` is not recognized.
    """
    if not isinstance(img, np.ndarray) or not isinstance(label, np.ndarray):
        raise ValueError("img and label must be numpy arrays.")
    if img.ndim != 3:
        raise ValueError("img must have dimensions (H, W, C).")
    if label.ndim != 2:
        raise ValueError("label must have dimensions (H, W).")

    h, w, c = img.shape

    if isinstance(scale, Sequence):
        factor = random.uniform(scale[0], scale[1])
    else:
        raise Exception('Wrong scale type!')

    assert crop_strategy in ["center", "random"], "'crop_strategy' is not recognized."

    # Same factor on both axes, so non-square chips keep their aspect ratio.
    resize_h = max(1, round(factor * h))
    resize_w = max(1, round(factor * w))

    # preserve_range keeps the original value scale. The image uses skimage's
    # default interpolation; the label uses nearest neighbour without
    # anti-aliasing so that it stays categorical.
    resampled_img = trans.resize(img, (resize_h, resize_w), preserve_range=True)
    resampled_label = trans.resize(label, (resize_h, resize_w), order=0,
                                   preserve_range=True, anti_aliasing=False)
    if np.issubdtype(img.dtype, np.integer):
        # Round rather than truncate when casting back to an integer dtype.
        resampled_img = np.rint(resampled_img)

    def _offset(resized, target):
        # Start of the crop (when enlarged) or of the paste (when shrunk).
        gap = abs(resized - target)
        return gap // 2 if crop_strategy == "center" else random.randint(0, gap)

    def _windows(resized, target, offset):
        # (source slice, destination slice) along one spatial axis.
        if resized >= target:
            return slice(offset, offset + target), slice(0, target)
        return slice(0, resized), slice(offset, offset + resized)

    src_h, dst_h = _windows(resize_h, h, _offset(resize_h, h))
    src_w, dst_w = _windows(resize_w, w, _offset(resize_w, w))

    canvas_img = np.zeros((h, w, c), dtype=img.dtype)
    canvas_label = np.zeros((h, w), dtype=label.dtype)
    canvas_img[dst_h, dst_w] = resampled_img[src_h, src_w]
    canvas_label[dst_h, dst_w] = resampled_label[src_h, src_w]

    return canvas_img, canvas_label


def _random_patch_mask(bands):
    """Boolean mask (same shape as ``bands``) of a randomly rotated and zoomed copy of the non-zero pixels of ``bands``."""
    h, w = bands.shape[:2]
    # Random rotation center given as (x=column, y=row), angle in degrees, and zoom.
    cx, cy = random.randint(0, w), random.randint(0, h)
    theta = math.radians(random.randint(0, 90))
    zoom = random.uniform(1, 2)

    # Inverse-map every output pixel to its source pixel (nearest neighbour).
    ys, xs = np.indices((h, w))
    dx, dy = xs - cx, ys - cy
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    src_x = np.rint((cos_t * dx - sin_t * dy) / zoom + cx).astype(int)
    src_y = np.rint((sin_t * dx + cos_t * dy) / zoom + cy).astype(int)
    inside = (src_x >= 0) & (src_x < w) & (src_y >= 0) & (src_y < h)

    # Pixels mapped from outside the chip count as background (0).
    sampled = bands[np.clip(src_y, 0, h - 1), np.clip(src_x, 0, w - 1)]
    return inside[..., None] & (sampled != 0)


def shift_brightness(img, gamma_range=(0.2, 2.0), shift_subset=(4, 4, 4), patch_shift=False):
    """
    Randomly shift image brightness through per-band-group gamma correction.

    Params:

        img (ndarray): Image chip with shape (H, W, C). Values are raised to a
            power, so they should be non-negative (e.g. min/max normalized);
            negative values give NaN for non-integer gammas.
        gamma_range (tuple): (low, high) range of a gamma drawn from a
            triangular distribution with mode 1.
        shift_subset (tuple): Sizes of consecutive channel groups; each group
            gets its own gamma. Slices past the last channel are empty, so they
            have no effect.
        patch_shift (bool): If True, each group's gamma is applied only inside
            a random rotated/zoomed patch. A milder whole-chip gamma drawn from
            [0.5, 1.5] is then applied to the group.

     Returns:

        ndarray: A brightness-shifted copy of ``img``; the input is not modified.

    """
    img = np.array(img, copy=True)
    c_start = 0
    for group in shift_subset:
        # View into the copy; writing through it updates ``img``.
        bands = img[:, :, c_start:c_start + group]
        gamma = random.triangular(gamma_range[0], gamma_range[1], 1)
        if patch_shift:
            patch = _random_patch_mask(bands)
            bands[...] = np.where(patch, np.power(bands, gamma), bands)
            # Extra step: shift the whole chip as well.
            gamma_full = random.triangular(0.5, 1.5, 1)
            bands[...] = np.power(bands, gamma_full)
        else:
            bands[...] = np.power(bands, gamma)
        c_start += group

    return img


def gaussian_blur(img, kernel_size):
    """
    Apply a Gaussian blur over the spatial axes of an image.

    Implemented with ``scipy.ndimage.gaussian_filter`` so it does not depend on
    OpenCV. Sigma follows OpenCV's rule for ``sigmaX=0``:
    ``sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8``. The kernel is
    truncated to ``kernel_size`` taps. Borders are mirrored without repeating
    the edge pixel, which corresponds to OpenCV's BORDER_REFLECT_101.
    For kernel sizes <= 7, OpenCV uses fixed tables, so values can differ
    slightly from ``cv2.GaussianBlur``.

    Args:
        img (np.ndarray): Image with shape (H, W) or (H, W, C); each channel is
            blurred independently.
        kernel_size (int): Positive, odd kernel size.

    Returns:
        np.ndarray: Blurred image with the same shape and dtype as ``img``.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer, or ``img``
            is not 2-D or 3-D.
    """
    from scipy.ndimage import gaussian_filter

    if not isinstance(kernel_size, numbers.Integral) or kernel_size < 1 or kernel_size % 2 == 0:
        raise ValueError("kernel_size must be a positive odd integer.")
    if img.ndim not in (2, 3):
        raise ValueError("img must have dimensions (H, W) or (H, W, C).")

    sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8
    radius = (kernel_size - 1) // 2
    # Blur only H and W; a sigma of 0 on the channel axis keeps bands unmixed.
    sigmas = (sigma, sigma) + (0,) * (img.ndim - 2)
    return gaussian_filter(img, sigma=sigmas, mode='mirror', truncate=radius / sigma)


def adjust_brightness(img, value=-0.2):
    """
    Adjust the brightness of the input image by adding a value to each pixel.

    Args:
        img (np.ndarray): Input image as a NumPy array.
        value (float): Value added to each pixel. Default is -0.2.

    Returns:
        np.ndarray: Brightness-adjusted image. Any floating-point input returns
        float32; integer input keeps its dtype, with values rounded to the
        nearest integer.

    Notes:
        - Floating-point images (any float dtype) are assumed to be in [0, 1],
          and the result is clipped to that range.
        - Integer images are clipped to the range of their dtype
          (e.g. [0, 255] for np.uint8).
    """
    is_float = np.issubdtype(img.dtype, np.floating)
    if is_float:
        dtype_min, dtype_max, out_dtype = 0, 1, np.float32
    else:
        # Non-float dtypes must be integers; np.iinfo rejects anything else.
        limits = np.iinfo(img.dtype)
        dtype_min, dtype_max, out_dtype = limits.min, limits.max, img.dtype

    shifted = img.astype(np.float64) + value
    if not is_float:
        shifted = np.rint(shifted)

    return np.clip(shifted, dtype_min, dtype_max).astype(out_dtype)


def adjust_contrast(img, factor=1):
    """
    Adjust the contrast of the input image by multiplying it with a contrast factor.

    Args:
        img (np.ndarray): Input image as a NumPy array.
        factor (float): Contrast factor. Default is 1.0.

    Returns:
        np.ndarray: Contrast-adjusted image. Any floating-point input returns
        float32, clipped to [0, 1]. Integer input keeps its dtype, is clipped to
        that dtype's range, and its values are rounded to the nearest integer.
    """
    is_float = np.issubdtype(img.dtype, np.floating)
    if is_float:
        dtype_min, dtype_max, out_dtype = 0, 1, np.float32
    else:
        # Non-float dtypes must be integers; np.iinfo rejects anything else.
        limits = np.iinfo(img.dtype)
        dtype_min, dtype_max, out_dtype = limits.min, limits.max, img.dtype

    scaled = img.astype(np.float64) * factor
    if not is_float:
        scaled = np.rint(scaled)

    return np.clip(scaled, dtype_min, dtype_max).astype(out_dtype)

In [None]:
#@title Custom loss functions

class BalancedCrossEntropyLoss(nn.Module):
    """
    Cross entropy loss with per-batch class weights derived from label counts.

    Args:
        ignore_index (int): Target value to ignore. Default -100.
        reduction (str): Reduction method to apply to loss.
                         Options: 'mean', 'sum', 'none'.
        weight_scheme (str): How class weights are derived from the per-class
                      pixel counts ``n_c`` of the current target:
                      "icr" -- inverse square root of the count, 1 / sqrt(n_c)
                      "mcf" -- median class frequency, median(n) / n_c

    Returns:
        Loss tensor according to the specified reduction.

    Notes:
        - Target values outside [0, C), such as the default ignore_index of
          -100, are excluded from the counts.
        - Classes absent from the batch get weight 0 instead of inf.
        - Weights are normalized to sum to 1 whenever any weight is non-zero.
    """

    def __init__(self, ignore_index=-100, reduction="mean", weight_scheme="icr"):
        super(BalancedCrossEntropyLoss, self).__init__()
        self.ignore_index = ignore_index
        self.reduction = reduction

        assert weight_scheme in ["icr", "mcf"], "'weight_scheme' is not recognized."
        self.weight_scheme = weight_scheme

    def forward(self, predict, target):
        """
        Args:
            predict (torch.Tensor): Logits of shape (N, C, ...).
            target (torch.Tensor): Class indices of shape (N, ...).
        """
        num_classes = predict.shape[1]
        flat = target.reshape(-1)
        # torch.bincount rejects negative values, and ids >= C would lengthen the
        # weight vector, so keep only valid class ids.
        valid = flat[(flat >= 0) & (flat < num_classes)]
        class_counts = torch.bincount(valid, minlength=num_classes).float()

        present = class_counts > 0
        safe_counts = torch.where(present, class_counts, torch.ones_like(class_counts))
        if self.weight_scheme == "icr":
            class_weights = 1.0 / torch.sqrt(safe_counts)
        else:
            median_frequency = torch.median(class_counts)
            class_weights = median_frequency / safe_counts
        # Absent classes contribute no pixels; weight them 0 instead of inf.
        class_weights = torch.where(present, class_weights, torch.zeros_like(class_weights))

        # set weight of ignore index to 0
        if 0 <= self.ignore_index < num_classes:
            class_weights[self.ignore_index] = 0

        total = class_weights.sum()
        if total > 0:
            class_weights = class_weights / total

        loss_fn = nn.CrossEntropyLoss(
            weight=class_weights.to(device=predict.device, dtype=predict.dtype),
            ignore_index=self.ignore_index,
            reduction=self.reduction)

        return loss_fn(predict, target)

class OhemCrossEntropyLoss(nn.Module):
    """
    Online Hard Example Mining (OHEM) cross entropy loss for semantic segmentation.

    Only the hardest ``ohem_ratio`` fraction of the non-ignored pixels (those
    with the highest per-pixel cross entropy) contributes to the loss.

    Params:
        ignore_index (int): Target value to ignore; such pixels are never selected.
        reduction (str): 'mean' returns the mean over the selected pixels and
            'sum' their sum. Any other value returns the selected per-pixel
            losses as a 1-D tensor, ordered from hardest to easiest.
        ohem_ratio (float): Fraction of non-ignored pixels to keep (at least one).

    Returns:
        Loss tensor according to arg reduction
    """
    def __init__(self, ignore_index=-100, reduction='mean', ohem_ratio=0.25):
        super(OhemCrossEntropyLoss, self).__init__()
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.ohem_ratio = ohem_ratio

    def forward(self, predict, target):
        # Per-pixel cross entropy; ignored pixels get a loss of 0 here.
        loss_fn = nn.CrossEntropyLoss(ignore_index=self.ignore_index, reduction='none')
        pixel_losses = loss_fn(predict, target).flatten()

        # Keep only non-ignored pixels, so they neither count toward the number
        # of hard examples nor dilute the mean with zeros.
        valid_losses = pixel_losses[target.flatten() != self.ignore_index]

        if valid_losses.numel() == 0:
            # Nothing to mine: return a zero that stays attached to the graph.
            if self.reduction in ('mean', 'sum'):
                return pixel_losses.sum() * 0.0
            return valid_losses

        # At least one pixel, and never more than are available.
        num_hard = min(valid_losses.numel(),
                       max(1, int(self.ohem_ratio * valid_losses.numel())))
        ohem_losses, _ = valid_losses.topk(num_hard)

        if self.reduction == 'mean':
            return ohem_losses.mean()
        if self.reduction == 'sum':
            return ohem_losses.sum()
        return ohem_losses


class BinaryDiceLoss(nn.Module):
    r'''
        Dice loss for a single (binary) class.

        loss = 1 - (2 * \sum{x * y} + smooth) / (\sum{x^p} + \sum{y^p} + smooth)

        Params:
            smooth (float): Added to numerator and denominator to avoid division
                by zero. Default: 1.
            p (int): Denominator exponent, \sum{x^p} + \sum{y^p}. Default: 1
                (plain sums); 2 gives the squared variant. Controls the
                sensitivity of the loss.
            predict (torch.tensor): Predicted tensor of shape [N, *].
            target (torch.tensor): Target tensor with the same shape as predict.
                It is cast to predict's dtype, so integer (e.g. one-hot)
                targets are accepted.
        Returns:
            Scalar loss tensor.
    '''
    def __init__(self, smooth=1, p=1):
        super(BinaryDiceLoss, self).__init__()
        self.smooth = smooth
        self.p = p

    def forward(self, predict, target):

        assert predict.shape == target.shape, "predict & target shape do not match"

        predict = predict.reshape(-1)
        # One-hot targets from F.one_hot are int64; cast them instead of
        # rejecting them, so float predictions and integer targets can be mixed.
        target = target.reshape(-1).to(predict.dtype)

        num = 2 * (predict * target).sum() + self.smooth
        den = (predict.pow(self.p) + target.pow(self.p)).sum() + self.smooth
        loss = 1 - num / den

        return loss


class DiceLoss(nn.Module):
    r"""
    Multi-class Dice loss: a weighted sum of per-class BinaryDiceLoss terms
    computed on softmax probabilities.

    Arguments:
        weight (torch.tensor or sequence, optional): Per-class weights of shape
            [num_classes,]. Default: 1 / num_classes for every class.
        ignore_index (int): Class channel to skip. When target holds class
            indices, values outside [0, num_classes) (e.g. the default -100)
            are masked out of every channel.
        predict (torch.tensor): Logits of shape [N, C, *].
        target (torch.tensor): Class indices of shape [N, *], or a tensor with
            the same shape as predict (already one-hot / soft labels).
        other args pass to BinaryDiceLoss
    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: If target is neither [N, *] nor the same shape as predict.
    """

    def __init__(self, weight=None, ignore_index=-100, **kwargs):
        super(DiceLoss, self).__init__()
        self.kwargs = kwargs
        self.weight = weight
        self.ignore_index = ignore_index

    def forward(self, predict, target):
        nclass = predict.shape[1]
        probs = F.softmax(predict, dim=1)

        if predict.shape == target.shape:
            # Already one-hot / soft targets; match the probability dtype.
            target = target.to(probs.dtype)
        elif target.dim() == predict.dim() - 1:
            # Index targets. Values outside [0, nclass) cannot be one-hot
            # encoded: encode them as 0 and mask those pixels out of every channel.
            valid = (target >= 0) & (target < nclass)
            safe_target = torch.where(valid, target, torch.zeros_like(target)).long()
            target = F.one_hot(safe_target, num_classes=nclass).movedim(-1, 1).to(probs.dtype)
            if not bool(valid.all()):
                pixel_mask = valid.unsqueeze(1).to(probs.dtype)
                probs = probs * pixel_mask
                target = target * pixel_mask
        else:
            raise ValueError(
                'Predict tensor shape of {} is not acceptable for target shape {}.'.format(
                    tuple(predict.shape), tuple(target.shape)))

        # Default weights are created on predict's device, so CPU runs work.
        if self.weight is None:
            weight = torch.full((nclass,), 1.0 / nclass, device=predict.device, dtype=probs.dtype)
        else:
            weight = torch.as_tensor(self.weight, device=predict.device, dtype=probs.dtype)
        assert weight.shape[0] == nclass, \
            'Expected weight tensor with shape [{}], but got[{}]'.format(nclass, weight.shape[0])

        dice = BinaryDiceLoss(**self.kwargs)
        total_loss = probs.new_zeros(())
        for i in range(nclass):
            if i != self.ignore_index:
                total_loss = total_loss + weight[i] * dice(probs[:, i], target[:, i])

        return total_loss


class BalancedDiceLoss(nn.Module):
    """
    Dice loss whose per-class weights come from the inverse label frequency in
    the current target.

    Arguments:
        ignore_index (int): Class index to ignore. It is excluded from the
            frequency statistics and passed on to DiceLoss.
        **kwargs: Additional arguments passed through DiceLoss to BinaryDiceLoss
            (e.g. ``smooth``, ``p``).

    Returns:
        Loss tensor
    """

    def __init__(self, ignore_index=-100, **kwargs):
        super(BalancedDiceLoss, self).__init__()
        self.kwargs = kwargs
        self.ignore_index = ignore_index

    def forward(self, predict, target):
        loss_weight = self._dense_class_weights(target, predict.shape[1], predict.device)

        loss = DiceLoss(weight=loss_weight, ignore_index=self.ignore_index, **self.kwargs)

        return loss(predict, target)

    def _dense_class_weights(self, target, num_classes, device):
        """Weight vector of length ``num_classes``; classes absent from ``target`` keep 1e-5."""
        loss_weight = torch.full((num_classes,), 0.00001, device=device)
        classes, class_weights = self._class_weight_pairs(target)
        # Place each weight at its class id (not at its position among the
        # classes present), skipping ids that are not valid channels.
        in_range = (classes >= 0) & (classes < num_classes)
        loss_weight[classes[in_range].to(device)] = class_weights[in_range].to(device)
        return loss_weight

    def calculate_class_weights(self, target):
        """Weights of the non-ignored classes present in ``target``, ordered by class id."""
        return self._class_weight_pairs(target)[1]

    def _class_weight_pairs(self, target):
        """Return (sorted present class ids, their inverse-frequency weights)."""
        unique, unique_counts = torch.unique(target[target != self.ignore_index], return_counts=True)
        class_ratios = unique_counts.float() / torch.numel(target)
        class_weights = 1.0 / class_ratios
        # NOTE(review): sum(1 / w) equals the fraction of non-ignored pixels, so
        # this does not make the weights sum to 1 — confirm this is intended.
        class_weights /= torch.sum(1. / class_weights)

        return unique, class_weights


class DiceCELoss(nn.Module):
    r"""
    Weighted sum of a multi-class Dice loss and a cross-entropy loss.

        loss = dice_weight * Dice + (1 - dice_weight) * CrossEntropy

    Arguments:
        loss_weight (Tensor, optional): Per-class rescaling weights of size C,
                                        shared by both terms. Default: None.
        dice_weight (float): Weight of the Dice term; the cross-entropy term
                             gets (1 - dice_weight). Default: 0.5.
        dice_smooth (float, optional): Smoothing term of the Dice loss. Default: 1.
        dice_p (int, optional): Dice denominator exponent, \sum{x^p} + \sum{y^p}.
                                Default: 1.
        ignore_index (int, optional): Class index ignored by both terms. Default: -100.

    Returns:
        Loss tensor
    """

    def __init__(self, loss_weight=None, dice_weight=0.5, dice_smooth=1,
                 dice_p=1, ignore_index=-100):

        super(DiceCELoss, self).__init__()
        self.loss_weight = loss_weight
        self.dice_weight = dice_weight
        self.dice_smooth = dice_smooth
        self.dice_p = dice_p
        self.ignore_index = ignore_index

        # Both terms share the same class weights and ignored class.
        self.dice_loss = DiceLoss(weight=loss_weight, ignore_index=ignore_index,
                                  smooth=dice_smooth, p=dice_p)
        self.ce_loss = nn.CrossEntropyLoss(weight=loss_weight, ignore_index=ignore_index)

    def forward(self, predict, target):
        assert predict.shape[0] == target.shape[0], "predict & target batch size do not match"

        dice_term = self.dice_loss(predict, target)
        ce_term = self.ce_loss(predict, target)
        return self.dice_weight * dice_term + (1 - self.dice_weight) * ce_term


class BalancedDiceCELoss(nn.Module):
    r"""
    Dice + cross-entropy loss with class weights from the inverse label
    frequency in the current target.

    Arguments:
        ignore_index (int): Class index to ignore. It is excluded from the
            frequency statistics and forwarded to DiceCELoss, so both loss
            terms ignore it too.
        predict (torch.tensor): Predicted tensor of shape [N, C, *]
        target (torch.tensor): Target tensor either in shape [N,*] or of same shape with predict
        other args pass to DiceCELoss, excluding loss_weight and ignore_index
    Returns:
        Same as DiceCELoss
    """

    def __init__(self, ignore_index=-100, **kwargs):
        super(BalancedDiceCELoss, self).__init__()
        self.ignore_index = ignore_index
        self.kwargs = kwargs

    def forward(self, predict, target):
        loss_weight = self._dense_class_weights(target, predict.shape[1], predict.device)

        # Forward ignore_index so Dice and CE skip the same class that was
        # excluded from the weight statistics.
        loss = DiceCELoss(loss_weight=loss_weight, ignore_index=self.ignore_index, **self.kwargs)

        return loss(predict, target)

    def _dense_class_weights(self, target, num_classes, device):
        """Weight vector of length ``num_classes``; classes absent from ``target`` keep 1e-5."""
        loss_weight = torch.full((num_classes,), 0.00001, device=device)
        classes, class_weights = self._class_weight_pairs(target)
        # Place each weight at its class id (not at its position among the
        # classes present), skipping ids that are not valid channels.
        in_range = (classes >= 0) & (classes < num_classes)
        loss_weight[classes[in_range].to(device)] = class_weights[in_range].to(device)
        return loss_weight

    def calculate_class_weights(self, target):
        """Weights of the non-ignored classes present in ``target``, ordered by class id."""
        return self._class_weight_pairs(target)[1]

    def _class_weight_pairs(self, target):
        """Return (sorted present class ids, their inverse-frequency weights)."""
        unique, unique_counts = torch.unique(target[target != self.ignore_index], return_counts=True)
        class_ratios = unique_counts.float() / torch.numel(target)
        class_weights = 1.0 / class_ratios
        # NOTE(review): sum(1 / w) equals the fraction of non-ignored pixels, so
        # this does not make the weights sum to 1 — confirm this is intended.
        class_weights /= torch.sum(1. / class_weights)

        return unique, class_weights

In [None]:
#@title Custom Dataset

from pathlib import Path
import tqdm, re
import pandas as pd
from torch.utils.data import Dataset
from torch.utils.data.sampler import SubsetRandomSampler

class CropData(Dataset):
    r"""
    Create an iterable dataset of image chips.

    Arguments:
        src_dir (str): Path to the folder that contains the data folders and files.
        usage (str): One of "train", "validation", or "inference".
        dataset_name (str): Name of the sub-folder of ``src_dir`` that is searched
                            (recursively) for image and label ``.tif`` files.
        split_ratio (float): Number in the range (0,1) that decides the portion of
                             samples used for training. The remaining samples are
                             assigned to the "validation" dataset. Default is 0.8.
        apply_normalization (bool): Forwarded to ``load_data`` for image chips.
        trans (list of str): Data augmentation methods applied in "train" usage; list
                             elements could be chosen from:
                             ['v_flip', 'h_flip', 'd_flip', 'rotate', 'resize',
                              'shift_brightness']
        **kwargs: Optional augmentation settings read in ``__getitem__``:
                  ``scale_factor`` (default (0.75, 1.5)), ``rotation_degree``
                  (default (-90, 90)), ``bshift_subs`` (default (3, 3)),
                  ``bshift_gamma_range`` (default (0.2, 2.0)), ``patch_shift``
                  (default True).

    The train/validation split shuffles the file indices right after
    ``np.random.seed(0)``. As a result, "train" and "validation" instances built from
    the same folder get complementary subsets. Note that this reseeds NumPy's global
    random state.

    Returns:
        A tuple of (image, label) for training and validation but only the image
        for inference.
    """

    def __init__(self, src_dir, usage, dataset_name, split_ratio=0.8,
                 apply_normalization=False, trans=None, **kwargs):

        self.usage = usage
        self.dataset_name = dataset_name
        self.split_ratio = split_ratio
        self.apply_normalization = apply_normalization
        self.trans = trans
        # Augmentation options are consumed later in __getitem__, so keep them on self.
        self.kwargs = kwargs

        # Initialized up front so every usage (including "inference") has them.
        self.img_chips = []
        self.lbl_chips = []

        assert self.usage in ["train", "validation", "inference"], "Usage is not recognized."

        # NOTE(review): `train_ids` is a global that must be defined elsewhere in the notebook.
        img_fnames = [Path(dirpath) / f
                      for (dirpath, dirnames, filenames) in os.walk(Path(src_dir) / self.dataset_name)
                      for f in filenames
                      if f.endswith(".tif") and "merged" in f and "_".join(f.split("_")[1:3]) in train_ids]
        img_fnames.sort()

        # NOTE(review): images and labels are paired purely by sorted position; confirm
        # that both filters select the same number of files in matching order.
        lbl_fnames = [Path(dirpath) / f
                      for (dirpath, dirnames, filenames) in os.walk(Path(src_dir) / self.dataset_name)
                      for f in filenames
                      if f.endswith(".tif") and "merged" in f and re.search(r"_\d{3}_\d{3}", f)]
        lbl_fnames.sort()

        if self.usage in ["train", "validation"]:
            total_samples = len(img_fnames)
            indices = np.arange(total_samples)
            split_index = int(total_samples * self.split_ratio)

            # A fixed seed makes separately constructed train/validation datasets disjoint.
            np.random.seed(0)
            np.random.shuffle(indices)

            selected = indices[:split_index] if self.usage == "train" else indices[split_index:]
            img_fnames = [img_fnames[i] for i in selected]
            lbl_fnames = [lbl_fnames[i] for i in selected]

            for img_fname, lbl_fname in tqdm.tqdm(zip(img_fnames, lbl_fnames),
                                                  total=len(img_fnames)):
                # os.walk already yields paths rooted at src_dir/dataset_name. Joining them
                # onto that prefix again doubles the path when src_dir is relative.
                img_chip = load_data(img_fname,
                                     apply_normalization=self.apply_normalization,
                                     is_label=False)
                img_chip = img_chip.transpose((1, 2, 0))

                lbl_chip = load_data(lbl_fname, is_label=True)

                self.img_chips.append(img_chip)
                self.lbl_chips.append(lbl_chip)
        else:
            # TODO(review): inference loading is not implemented; the dataset is empty here.
            pass

        print(f"------ {self.usage} dataset with {len(self.img_chips)} patches created ------")

    def __getitem__(self, index):
        """
        Support indexing such that dataset[index] returns the (index)th sample.

        Returns (image, label) tensors for "train"/"validation" and only the image
        tensor for "inference". Images are returned channel-first (C, H, W) as float,
        and labels as int64.
        """
        if self.usage in ["train", "validation"]:
            img_chip = self.img_chips[index]
            lbl_chip = self.lbl_chips[index]

            if self.trans and self.usage == "train":
                opts = getattr(self, "kwargs", {})
                trans_flip_ls = [m for m in self.trans if "flip" in m]
                # A single configured flip type is enough to enable flipping.
                if random.randint(0, 1) and trans_flip_ls:
                    trans_flip = random.sample(trans_flip_ls, 1)
                    img_chip, lbl_chip = flip(img_chip, lbl_chip, trans_flip[0])

                if random.randint(0, 1) and "resize" in self.trans:
                    scale_factor = opts.get("scale_factor", (0.75, 1.5))
                    img_chip, lbl_chip = re_scale(img_chip, lbl_chip.astype(np.uint8),
                                                  scale=scale_factor, crop_strategy="center")

                if random.randint(0, 1) and "rotate" in self.trans:
                    deRotate = opts.get("rotation_degree", (-90, 90))
                    img_chip, lbl_chip = center_rotate(img_chip, lbl_chip, deRotate)

                if random.randint(0, 1) and 'shift_brightness' in self.trans:
                    bshift_subs = opts.get("bshift_subs", (3, 3))
                    bshift_gamma_range = opts.get("bshift_gamma_range", (0.2, 2.0))
                    patch_shift = opts.get("patch_shift", True)
                    img_chip = shift_brightness(img_chip, gamma_range=bshift_gamma_range,
                                                shift_subset=bshift_subs, patch_shift=patch_shift)

            label = torch.from_numpy(np.ascontiguousarray(lbl_chip)).long()
            # (H,W,C) --> (C,H,W). Make the array contiguous because torch.from_numpy
            # rejects negative strides (e.g. produced by a flip).
            img_tensor = torch.from_numpy(np.ascontiguousarray(img_chip.transpose((2, 0, 1)))).float()

            return img_tensor, label

        else:
            img_chip = self.img_chips[index]
            return torch.from_numpy(np.ascontiguousarray(img_chip.transpose((2, 0, 1)))).float()

    def __len__(self):
        return len(self.img_chips)


In [None]:
#@title training and validation functions

def _resolve_criterion(criterion):
    r"""
    Turn ``criterion`` into a callable loss function.

    Arguments:
        criterion -- Either a callable ``loss(prediction, target)`` or a string holding a
                     Python expression that builds one (e.g. a loss-class constructor call).
    Returns:
        The callable loss.

    SECURITY: strings are executed with ``eval`` in this module's namespace. Only pass
    expressions from a trusted configuration, never from external input.
    """
    if callable(criterion):
        return criterion
    return eval(criterion)


def train_one_epoch(trainData, model, criterion, optimizer, scheduler, lr_policy, device, train_loss=None):
    r"""
    Train the model for one epoch.

    Arguments:
            trainData (DataLoader object) -- Batches of (image, label) chips from the custom dataset.
            model (initialized model) -- Choice of segmentation Model to train.
            criterion -- Loss callable, or a string expression that builds one (evaluated once per epoch).
            optimizer -- Chosen function for optimization.
            scheduler -- Update policy for learning rate decay (stepped per batch only for "CyclicLR").
            lr_policy (str) -- Learning rate decay policy.
            device --(str) Either 'cuda' or 'cpu'.
            train_loss -- (list or None) If given, the epoch's average training loss is appended to it.
    Raises:
            ValueError -- if ``trainData`` yields no batches.
    """
    num_train_batches = len(trainData)
    if num_train_batches == 0:
        raise ValueError("trainData yielded no batches; cannot compute an average training loss.")

    # Build the loss once per epoch instead of re-evaluating the expression every batch.
    loss_fn = _resolve_criterion(criterion)

    model.train()
    epoch_loss = 0

    for img_chips, labels in trainData:
        img = img_chips.to(device)
        label = labels.to(device)

        model_out = model(img)

        loss = loss_fn(model_out, label)
        epoch_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # CyclicLR is a per-iteration schedule; other policies are stepped per epoch by the caller.
        if lr_policy == "CyclicLR":
            scheduler.step()

    mean_loss = epoch_loss / num_train_batches
    print('train loss:{}'.format(mean_loss))

    if lr_policy == "CyclicLR":
        print("LR: {}".format(scheduler.get_last_lr()))

    if train_loss is not None:
        train_loss.append(float(mean_loss))


def validate_one_epoch(valData, model, criterion, device, val_loss=None):
    """
    Evaluate the model on the validation batches without updating its weights.

    Params:
        valData (DataLoader object) -- Batches of (image, label) chips from the custom dataset.
        model -- Choice of segmentation Model.
        criterion -- Loss callable, or a string expression that builds one (evaluated once).
        device (str): Either 'cuda' or 'cpu'.
        val_loss (list or None): If given, the epoch's average validation loss is appended to it.
    Raises:
        ValueError -- if ``valData`` yields no batches.
    """
    num_val_batches = len(valData)
    if num_val_batches == 0:
        raise ValueError("valData yielded no batches; cannot compute an average validation loss.")

    loss_fn = _resolve_criterion(criterion)

    model.eval()

    # mini batch iteration
    eval_epoch_loss = 0

    with torch.no_grad():
        for img_chips, labels in valData:
            img = img_chips.to(device)
            label = labels.to(device)

            pred = model(img)

            loss = loss_fn(pred, label)
            eval_epoch_loss += loss.item()

    mean_loss = eval_epoch_loss / num_val_batches
    print('validation loss: {:.4f}'.format(mean_loss))

    if val_loss is not None:
        val_loss.append(float(mean_loss))

In [None]:
#@title Accuracy Metrics Assessment

import numpy as np
import pandas as pd
from torch.autograd import Variable
import torch.nn.functional as F


class BinaryMetrics:
    """
    Metrics measuring model performance for one binary (positive vs. rest) class,
    derived from a pixel-level confusion matrix.
    """

    def __init__(self, ref_array, score_array, pred_array=None):
        """
        Params:
            ref_array (ndarray): Array of ground truth (values 0 or 1).
            score_array (ndarray): Array of pixels scores of positive class
            pred_array (ndarray): Boolean/0-1 array of predictions telling whether
                                 a pixel belongs to a specific class. If omitted,
                                 scores above 0.5 are taken as positive.
        """

        self.tp = None
        self.fp = None
        self.fn = None
        self.tn = None
        # i.e. 1e-5. It is added to a denominator only when that denominator is zero.
        self.eps = 10e-6
        self.observation = np.asarray(ref_array).flatten()
        self.score = np.asarray(score_array).flatten()
        if pred_array is not None:
            self.prediction = np.asarray(pred_array).flatten()
        # take score over 0.5 as prediction if pred_array not provided
        else:
            self.prediction = np.where(self.score > 0.5, 1, 0)

        # Validate before computing anything from the arrays.
        assert self.observation.shape == self.score.shape, "Inconsistent input shapes"
        assert self.observation.shape == self.prediction.shape, "Inconsistent input shapes"

        # NOTE: this instance attribute (a DataFrame) intentionally shadows the
        # confusion_matrix() method after construction.
        self.confusion_matrix = self.confusion_matrix()

    def __add__(self, other):
        """
        Add two BinaryMetrics instances by concatenating their pixels.
        Params:
            other (''BinaryMetrics''): A BinaryMetrics instance
        Return:
            ''BinaryMetrics''
        """

        return BinaryMetrics(np.append(self.observation, other.observation),
                             np.append(self.score, other.score),
                             np.append(self.prediction, other.prediction))

    def __radd__(self, other):
        """
        Add a BinaryMetrics instance with reversed operands. ``0 + m`` returns ``m``,
        which lets the builtin ``sum`` work on a list of instances.
        Params:
            other
        Returns:
            ''BinaryMetrics
        """

        if other == 0:
            return self
        else:
            return self.__add__(other)

    def _safe_div(self, numerator, denominator):
        """Divide, adding eps to a zero denominator.

        NumPy integer counts do not raise ZeroDivisionError (they return nan with a
        warning), so the zero case must be checked explicitly.
        """
        if denominator == 0:
            denominator = denominator + self.eps
        return numerator / denominator

    def confusion_matrix(self):
        """
        Calculate confusion matrix of given ground truth and predicted label
        Returns:
            "pandas.dataframe" of observation on the row and prediction on the column
        Raises:
            Exception -- if either array holds values outside [0, 1].
        """

        ref_array = self.observation
        pred_array = self.prediction

        for arr in (ref_array, pred_array):
            if arr.size and (arr.min() < 0 or arr.max() > 1):
                raise Exception("Invalid array")

        # Boolean masks avoid the wrap-around that `ref - 2 * pred` suffers on unsigned dtypes.
        obs_pos = ref_array == 1
        obs_neg = ref_array == 0
        pred_pos = pred_array == 1
        pred_neg = pred_array == 0

        self.tp = np.sum(obs_pos & pred_pos)
        self.fp = np.sum(obs_neg & pred_pos)
        self.fn = np.sum(obs_pos & pred_neg)
        self.tn = np.sum(obs_neg & pred_neg)

        confusionMatrix = pd.DataFrame(data=np.array([[self.tn, self.fp], [self.fn, self.tp]]),
                                       index=['observation = 0', 'observation = 1'],
                                       columns=['prediction = 0', 'prediction = 1'])
        return confusionMatrix

    def ir(self):
        """
        Imbalance Ratio (IR) is defined as the proportion between positive and negative
        instances of the label. This value lies within the [0, ∞] range, having a value
        IR = 1 in the balanced case.
        Returns:
                float
        """
        return self._safe_div(self.tp + self.fn, self.fp + self.tn)

    def accuracy(self):
        """
        Calculate Overall (Global) Accuracy.
        Returns:
            float scalar
        """
        return self._safe_div(self.tp + self.tn, self.tp + self.tn + self.fp + self.fn)

    def precision(self):
        """
        Calculate User’s Accuracy (Positive Prediction Value (PPV) | UA).
        Returns:
            float
        """
        return self._safe_div(self.tp, self.tp + self.fp)

    def recall(self):
        """
        Calculate Producer's Accuracy (True Positive Rate |Sensitivity |hit rate | recall).
        Returns:
            float
        """
        return self._safe_div(self.tp, self.tp + self.fn)

    def false_positive_rate(self):
        """
        Calculate False Positive Rate(FPR) aka. False Alarm Rate (FAR), or Fallout.
        Returns:
             float
        """
        return self._safe_div(self.fp, self.tn + self.fp)

    def iou(self):
        """
        Calculate intersection over union for the positive class.
        Returns:
            float
        """
        return self._safe_div(self.tp, self.tp + self.fp + self.fn)

    def f1_measure(self):
        """
        Calculate F1 score (harmonic mean of precision and recall).
        Returns:
            float
        """
        precision = self._safe_div(self.tp, self.tp + self.fp)
        recall = self._safe_div(self.tp, self.tp + self.fn)
        return self._safe_div(2 * precision * recall, precision + recall)

    def tss(self):
        """
        Calculate true skill statistic (TSS) = recall + specificity - 1.
        Returns:
            float
        """
        return (self._safe_div(self.tp, self.tp + self.fn)
                + self._safe_div(self.tn, self.tn + self.fp) - 1)


def do_accuracy_evaluation(eval_data, model, filename, gpu=True):
    r"""
    Evaluate the model on held-out batches and write per-class metrics to a CSV file.

    Arguments:
    eval_data -- Iterable of (image, label) batches; images are N x C x H x W and labels N x H x W.
    model -- Choice of segmentation Model to evaluate.
    filename -- (str) Name of the csv file to report metrics.
    gpu --(binary) If False the model will run on CPU instead of GPU. Default is True.

    One metrics row is produced for every class index 1 .. n_class-1 (class 0 is skipped).
    Rows are written in class order without the index column.

    Note: to harden the class prediction around a higher probability, drop 'class_pred' argument
          and increase the threshold of 'prediction' in the 'BinaryMetrics' class '__init__' function.
    """

    model.eval()

    # Per-class lists (position n-1 holds class n) of flattened per-chip arrays. They are
    # concatenated once at the end instead of summing BinaryMetrics chip by chip, which
    # re-copies every accumulated array on each addition.
    observations, scores, predictions = [], [], []

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for img_chips, label in eval_data:

            if gpu:
                img_chips = img_chips.cuda()
                label = label.cuda()

            pred_prob = F.softmax(model(img_chips), 1)  # batch size x number of categories x H x W
            batch, n_class, height, width = pred_prob.size()

            hard_pred = pred_prob.max(dim=1)[1].cpu().numpy()
            prob_np = pred_prob.cpu().numpy()
            label_np = label.cpu().numpy()

            for i in range(batch):
                for n in range(1, n_class):
                    if len(observations) < n:
                        observations.append([])
                        scores.append([])
                        predictions.append([])
                    observations[n - 1].append(np.where(label_np[i] == n, 1, 0).ravel())
                    scores[n - 1].append(prob_np[i, n].ravel())
                    predictions[n - 1].append(np.where(hard_pred[i] == n, 1, 0).ravel())

    class_metrics = [BinaryMetrics(np.concatenate(obs), np.concatenate(sc), np.concatenate(pr))
                     for obs, sc, pr in zip(observations, scores, predictions)]

    report = pd.DataFrame({
        "Imbalance Ratio": [m.ir() for m in class_metrics],
        "Overall Accuracy": [m.accuracy() for m in class_metrics],
        "Precision (UA or PPV)": [m.precision() for m in class_metrics],
        "Recall (PA or TPR or Sensitivity)": [m.recall() for m in class_metrics],
        "False Positive Rate": [m.false_positive_rate() for m in class_metrics],
        "IoU": [m.iou() for m in class_metrics],
        "F1-score": [m.f1_measure() for m in class_metrics],
        "TSS": [m.tss() for m in class_metrics]
    }, ["class_{}".format(m) for m in range(1, len(class_metrics) + 1)])

    report.to_csv(filename, index=False)

In [None]:
#@title Model Compiler

import os
from pathlib import Path
from datetime import datetime, timedelta
import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import _LRScheduler
from torch.nn import init


def get_optimizer(optimizer, params, lr, momentum):
    """
    Get an instance of the specified optimizer with the given parameters.

    Parameters:
        optimizer (str): The name of the optimizer (case-insensitive). Options:
                              "sgd", "nesterov", "adam", "amsgrad".
        params (iterable): The parameters to optimize.
        lr (float): The learning rate.
        momentum (float or None): The momentum factor for the SGD variants; None is
                                  treated as 0. Ignored by the Adam variants.

    Returns:
        torch.optim.Optimizer: An instance of the specified optimizer with the
        given parameters.

    Raises:
        ValueError: if the optimizer name is not supported.
    """
    name = optimizer.lower()

    # ModelCompiler.fit defaults momentum to None, but torch's SGD needs a number.
    momentum = 0.0 if momentum is None else momentum

    if name == "sgd":
        return torch.optim.SGD(params, lr, momentum=momentum)
    if name == "nesterov":
        return torch.optim.SGD(params, lr, momentum=momentum, nesterov=True)
    if name == "adam":
        return torch.optim.Adam(params, lr)
    if name == "amsgrad":
        return torch.optim.Adam(params, lr, amsgrad=True)

    raise ValueError(f"{name} currently not supported, please choose a valid optimizer")


def _init_module_weights(module, init_type, gain):
    """Initialize the parameters owned directly by ``module`` (no recursion)."""
    class_name = module.__class__.__name__
    weight = getattr(module, "weight", None)
    bias = getattr(module, "bias", None)

    if isinstance(weight, torch.Tensor) and ("Conv" in class_name or "Linear" in class_name):
        if init_type == "normal":
            init.normal_(weight.data, 0.0, gain)
        elif init_type == "xavier":
            init.xavier_normal_(weight.data, gain=gain)
        elif init_type == "kaiming":
            # kaiming ignores `gain`.
            init.kaiming_normal_(weight.data, a=0, mode="fan_out")
        elif init_type == "orthogonal":
            init.orthogonal_(weight.data, gain=gain)
        else:
            raise NotImplementedError(f"initialization method {init_type} is not implemented.")

    if isinstance(bias, torch.Tensor):
        init.constant_(bias.data, 0.0)

    # affine=False BatchNorm layers have no weight/bias tensors; skip them.
    if "BatchNorm2d" in class_name and isinstance(weight, torch.Tensor):
        init.normal_(weight.data, 1.0, gain)
        if isinstance(bias, torch.Tensor):
            init.constant_(bias.data, 0.0)


def init_weights(model, init_type="normal", gain=0.02):
    """Initialize the network weights using various initialization methods.

    Every sub-module of ``model`` (``model.modules()``, including ``model`` itself) is
    visited. Conv*/Linear weights are re-drawn, all bias tensors are zeroed, and
    BatchNorm2d weights are drawn around 1. Previously only the top-level module was
    inspected, so calling this on a whole network changed nothing. Note that it
    re-initializes every sub-module, including any that already carry loaded weights.

    Args:
        model (torch.nn.Module): The initialized model.
        init_type (str): The initialization type. Supported initialization methods:
                         "normal", "xavier", "kaiming", "orthogonal"
                         Default is "normal" for random initialization
                         using a normal distribution.
        gain (float): The scaling factor for the initialized weights.

    Raises:
        NotImplementedError: if ``init_type`` is unknown and a Conv/Linear layer is present.
    """
    for module in model.modules():
        _init_module_weights(module, init_type, gain)

    print(f"initialize network with {init_type}.")


class PolynomialLR(_LRScheduler):
    """Polynomial learning rate decay until the step reaches the max_decay_steps.

    lr(step) = (base_lr - min_learning_rate) * (1 - step / max_decay_steps) ** power
               + min_learning_rate, for step <= max_decay_steps. After that the
    learning rate is left unchanged.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        max_decay_steps (int): The maximum number of steps after which the learning
                               rate stops decreasing.
        min_learning_rate (float): The minimum value of the learning rate.
                                   Learning rate decay stops at this value.
        power (float): The power of the polynomial.
    """

    def __init__(self, optimizer, max_decay_steps, min_learning_rate=1e-5, power=1.0):

        if max_decay_steps <= 1.:
            raise ValueError('max_decay_steps should be greater than 1.')

        self.max_decay_steps = max_decay_steps
        self.min_learning_rate = min_learning_rate
        self.power = power
        self.last_step = 0

        # NOTE(review): torch's scheduler constructor performs an initial self.step()
        # call, so the step-1 decay is applied on construction; confirm for your torch version.
        super().__init__(optimizer)

    def _decayed_lrs(self):
        """Learning rates for ``self.last_step`` according to the polynomial formula."""
        return [(base_lr - self.min_learning_rate) *
                ((1 - self.last_step / self.max_decay_steps) ** self.power) +
                self.min_learning_rate for base_lr in self.base_lrs]

    def get_lr(self):
        if self.last_step > self.max_decay_steps:
            return [self.min_learning_rate for _ in self.base_lrs]

        return self._decayed_lrs()

    def step(self, step=None):
        """Advance to ``step`` (or by one if None; a step of 0 is treated as 1) and update the lrs."""
        if step is None:
            step = self.last_step + 1
        self.last_step = step if step != 0 else 1

        if self.last_step <= self.max_decay_steps:
            for param_group, lr in zip(self.optimizer.param_groups, self._decayed_lrs()):
                param_group['lr'] = lr

        # The base-class step() that normally records this is overridden; keep
        # get_last_lr() working.
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]


class ModelCompiler:

    def __init__(self, model, working_dir, out_dir, num_classes, inch, gpu_devices=[0],
                 model_init_type="kaiming", params_init=None, freeze_params=None):
        r"""
        Train the model.

        Arguments:
            model (ordered Dict) -- initialized model either vanilla or pre-trained depending on
                                    the argument 'params_init'.
            working_dir (str) -- General Directory to store output from any experiment.
            out_dir (str) -- specific output directory for the current experiment.
            num_classes (int) -- number of output classes based on the classification scheme.
            inch (int) -- number of input channels.
            gpu_devices (list) -- list of GPU indices to use for parallelism if multiple GPUs are available.
                                  Default is set to index 0 for a single GPU.
            model_init_type -- (str) model initialization choice if it's not pre-trained.
            params_init --(str or None) Path to the saved model parameters. If set to 'None', a vanilla model will
                          be initialized.
            freeze_params (list) -- list of integers that show the index of layers in a pre-trained
                                    model (on the source domain) that we want to freeze for fine-tuning
                                    the model on the target domain used in the model-based transfer learning.
        """

        self.working_dir = working_dir
        self.out_dir = out_dir

        self.num_classes = num_classes
        self.inch = inch
        self.gpu_devices = gpu_devices
        self.use_sync_bn = use_sync_bn
        self.model_init_type = model_init_type
        self.params_init = params_init
        self.checkpoint_dirpath = None

        self.model = model
        self.model_name = self.model.__class__.__name__

        if self.params_init:
            self.load_params(self.params_init, freeze_params)

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        if self.device.type == "cuda":
            print("----------GPU available----------")
            if self.gpu_devices:
                torch.cuda.set_device(self.gpu_devices[0])
                self.model = torch.nn.DataParallel(self.model, device_ids=self.gpu_devices)
        else:
            print('----------No GPU available, using CPU instead----------')
            self.model = self.model.to(device)


        if params_init is None:
            init_weights(self.model, self.model_init_type, gain=0.01)

        num_params = sum([p.numel() for p in self.model.parameters() if p.requires_grad])
        print("total number of trainable parameters: {:2.1f}M".format(num_params / 1000000))

        if self.params_init:
            print("---------- Pre-trained model compiled successfully ----------")
        else:
            print("---------- Vanilla Model compiled successfully ----------")

    def load_params(self, dir_params, freeze_params):
        """
        Load parameters from a file and update the model's state dictionary.

        Args:
            dir_params (str): Directory path to the parameters file.
            freeze_params (list): List of indices corresponding to the model's parameters that should be frozen.

        Returns:
            None
        """

        # inparams = torch.load(self.params_init, map_location='cuda:0')
        inparams = torch.load(self.params_init)

        model_dict = self.model.state_dict()

        if "module" in list(inparams.keys())[0]:
            inparams_filter = {k[7:]: v.cpu() for k, v in inparams.items() if k[7:] in model_dict}
        else:
            inparams_filter = {k: v.cpu() for k, v in inparams.items() if k in model_dict}

        model_dict.update(inparams_filter)

        # load new state dict
        self.model.load_state_dict(model_dict)

        if freeze_params:
            for i, p in enumerate(self.model.parameters()):
                if i in freeze_params:
                    p.requires_grad = False

    def fit(self, trainDataset, valDataset, epochs, optimizer_name, lr_init, 
            lr_policy, criterion, momentum=None, resume=False, resume_epoch=None, **kwargs):
        """
        Train the model on the provided datasets.

        Args:
            trainDataset: The loaded training dataset.
            valDataset: The loaded validation dataset.
            epochs (int): The number of epochs to train.
            optimizer_name (str): The name of the optimizer to use.
            lr_init (float): The initial learning rate.
            lr_policy (str): The learning rate policy.
            criterion: The loss criterion.
            momentum (float, optional): The momentum factor for the optimizer (default: None).
            resume (bool, optional): Whether to resume training from a checkpoint (default: False).
            resume_epoch (int, optional): The epoch from which to resume training (default: None).
            **kwargs: Additional arguments specific to certain learning rate policies.

        Returns:
            None
        """

        # Set the folder to save results.
        working_dir = self.working_dir
        out_dir = self.out_dir
        model_dir = "{}/{}/{}_ep{}".format(working_dir, out_dir, self.model_name, epochs)

        if not os.path.exists(Path(working_dir) / out_dir / model_dir):
            os.makedirs(Path(working_dir) / out_dir / model_dir)

        self.checkpoint_dirpath = Path(working_dir) / out_dir / model_dir / "chkpt"
        if not os.path.exists(self.checkpoint_dirpath):
            os.makedirs(self.checkpoint_dirpath)

        os.chdir(Path(working_dir) / out_dir / model_dir)

        print("-------------------------- Start training --------------------------")
        start = datetime.now()

        writer = SummaryWriter('../')
        lr = lr_init

        optimizer = get_optimizer(optimizer_name,
                                  filter(lambda p: p.requires_grad, self.model.parameters()),
                                  lr,
                                  momentum)

        # Initialize the learning rate scheduler
        if lr_policy == "StepLR":
            step_size = kwargs.get("step_size", 3)
            gamma = kwargs.get("gamma", 0.98)
            scheduler = optim.lr_scheduler.StepLR(optimizer,
                                                  step_size=step_size,
                                                  gamma=gamma)

        elif lr_policy == "MultiStepLR":
            milestones = kwargs.get("milestones", [5, 10, 20, 35, 50, 70, 90])
            gamma = kwargs.get("gamma", 0.5)
            scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                       milestones=milestones,
                                                       gamma=gamma)

        elif lr_policy == "ReduceLROnPlateau":
            mode = kwargs.get("mode", "min")
            factor = kwargs.get("factor", 0.8)
            patience = kwargs.get("patience", 3)
            threshold = kwargs.get("threshold", 0.0001)
            threshold_mode = kwargs.get("threshold_mode", "rel")
            min_lr = kwargs.get("min_lr", 3e-6)
            verbose = kwargs.get("verbose", True)
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                             mode=mode,
                                                             factor=factor,
                                                             patience=patience,
                                                             threshold=threshold,
                                                             threshold_mode=threshold_mode,
                                                             min_lr=min_lr,
                                                             verbose=verbose)

        elif lr_policy == "PolynomialLR":
            max_decay_steps = kwargs.get("max_decay_steps", 75)
            min_learning_rate = kwargs.get("min_learning_rate", 1e-5)
            power = kwargs.get("power", 0.8)
            scheduler = PolynomialLR(optimizer,
                                     max_decay_steps=max_decay_steps,
                                     min_learning_rate=min_learning_rate,
                                     power=power)

        elif lr_policy == "CyclicLR":
            base_lr = kwargs.get("base_lr", 3e-5)
            max_lr = kwargs.get("max_lr", 0.01)
            step_size_up = kwargs.get("step_size_up", 1100)
            mode = kwargs.get("mode", "triangular")
            scheduler = optim.lr_scheduler.CyclicLR(optimizer,
                                                    base_lr=base_lr,
                                                    max_lr=max_lr,
                                                    step_size_up=step_size_up,
                                                    mode=mode)

        else:
            scheduler = None

        # Resume the model from the specified checkpoint in the config file.
        train_loss = []
        val_loss = []

        if resume:
            model_state_file = os.path.join(self.checkpoint_dirpath, "{}_checkpoint.pth.tar".format(resume_epoch))
            if os.path.isfile(model_state_file):
                checkpoint = torch.load(model_state_file)
                resume_epoch = checkpoint["epoch"]
                scheduler.load_state_dict(checkpoint["scheduler"])
                self.model.load_state_dict(checkpoint["state_dict"])
                optimizer.load_state_dict(checkpoint["optimizer"])
                train_loss = checkpoint["train loss"]
                val_loss = checkpoint["Evaluation loss"]

        # epoch iteration
        if resume:
            iterable = range(resume_epoch, epochs)
        else:
            iterable = range(epochs)

        for t in iterable:

            print("Epoch [{}/{}]".format(t + 1, epochs))

            start_epoch = datetime.now()

            train_one_epoch(trainDataset, self.model, criterion, optimizer, 
                            scheduler, lr_policy, device=self.device, 
                            train_loss=train_loss)
            validate_one_epoch(valDataset, self.model, criterion, device=self.device, 
                               val_loss=val_loss)

            # Update the scheduler
            if lr_policy in ["StepLR", "MultiStepLR"]:
                scheduler.step()
                print("LR: {}".format(scheduler.get_last_lr()))

            if lr_policy == "ReduceLROnPlateau":
                scheduler.step(val_loss[t])

            if lr_policy == "PolynomialLR":
                scheduler.step(t)
                print("LR: {}".format(optimizer.param_groups[0]['lr']))

            print('time:', (datetime.now() - start_epoch).seconds)

            # Adjust logger to resume status and save checkpoints in defined intervals.
            checkpoint_interval = 20

            writer.add_scalars("Loss",
                               {"train loss": train_loss[t],
                                "Evaluation loss": val_loss[t]},
                               t + 1)

            if (t + 1) % checkpoint_interval == 0:
                torch.save({"epoch": t + 1,
                            "state_dict": self.model.state_dict() if len(self.gpu_devices) > 1 else \
                                self.model.module.state_dict(),
                            "scheduler": scheduler.state_dict(),
                            "optimizer": optimizer.state_dict(),
                            "train loss": train_loss,
                            "Evaluation loss": val_loss},
                           os.path.join(self.checkpoint_dirpath, f"{t + 1}_checkpoint.pth.tar"))

        writer.close()

        duration_in_sec = (datetime.now() - start).seconds
        duration_format = str(timedelta(seconds=duration_in_sec))
        print(f"----------- Training finished in {duration_format} -----------")

    def accuracy_evaluation(self, evalDataset, filename):
        """
        Evaluate the accuracy of the model on the provided evaluation dataset.

        Creates ``<working_dir>/<out_dir>`` when it does not exist yet, makes it the
        process's current working directory, and then delegates the actual metric
        computation to ``do_accuracy_evaluation``.

        Args:
            evalDataset (DataLoader): The evaluation dataset to evaluate the model on.
            filename (str): The filename to save the evaluation results in the output CSV.
        """
        out_path = Path(self.working_dir) / self.out_dir

        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(out_path, exist_ok=True)

        # NOTE: this changes the cwd of the whole process, so later relative paths
        # resolve inside out_path. Presumably do_accuracy_evaluation writes `filename`
        # relative to the cwd — confirm before removing this call.
        os.chdir(out_path)

        print("---------------- Start evaluation ----------------")

        start = datetime.now()

        # NOTE(review): other methods use self.device / self.gpu_devices; confirm that
        # self.gpu is set by __init__.
        do_accuracy_evaluation(evalDataset, self.model, filename, self.gpu)

        duration_in_sec = (datetime.now() - start).seconds
        print(
            f"---------------- Evaluation finished in {duration_in_sec}s ----------------")


    def save(self, save_object="params"):
        """
        Save model parameters or the entire model to disk.

        Both options write to ``<checkpoint_dirpath>/<model_name>_final_state.pth``.

        Args:
            save_object (str): ``"params"`` saves the network's ``state_dict``;
                ``"model"`` saves the whole ``self.model`` object with ``torch.save``.
                Defaults to "params".

        Raises:
            ValueError: If ``save_object`` is neither ``"params"`` nor ``"model"``.
        """
        if save_object not in ("params", "model"):
            raise ValueError("Improper object type.")

        out_path = os.path.join(self.checkpoint_dirpath,
                                "{}_final_state.pth".format(self.model_name))

        if save_object == "params":
            # Unwrap (Distributed)DataParallel by type rather than by GPU count: a model
            # wrapped on a single GPU would otherwise be saved with "module."-prefixed
            # keys, and an unwrapped model on a multi-GPU config has no `.module`.
            network = self.model
            if isinstance(network, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
                network = network.module
            torch.save(network.state_dict(), out_path)

            print("--------------------- Model parameters is saved to disk ---------------------")

        else:
            torch.save(self.model, out_path)


## Function calls: configure, train, save, and evaluate the model

In [None]:
config = {

    # Custom dataset params
    "src_dir": "/content/gdrive/MyDrive/PondDataset",
    "train_dataset_name": "chips_binary_filtered",
    "split_ratio": 0.8,
    "apply_normalization": True,
    "transformations": ["v_flip", "h_flip", "d_flip", "rotate", "resize"],
    # Key must be "aug_params": the training-dataset cell reads config["aug_params"].
    "aug_params": {
        "scale_factor": (0.75, 1.3),
        "rotation_degree": (-90, 90),
        "bshift_gamma_range": (0.2, 2),
        "bshift_subs": (3, 3),
        "patch_shift": True
    },

    # DataLoader
    "train_BatchSize": 16,
    "val_test_BatchSize": 1,

    # Model initialization params ("n_classes" is also passed to the model compiler)
    "n_classes": 2,
    "input_channels": 4,

    # Model compiler params
    "working_dir": "",
    # Read by the ModelCompiler cell; evaluation joins it onto working_dir.
    # NOTE(review): placeholder folder name — set the desired output directory.
    "out_dir": "outputs",
    "gpuDevices": [0],
    "init_type": "kaiming",
    "params_init": None,
    "freeze_params": None,

    # Model fitting
    "epochs": 100,
    "optimizer": "SGD",
    "LR": 0.003,
    "LR_policy": "PolynomialLR",
    "criterion": BalancedTverskyFocalCELoss,
    "momentum": 0.95,
    "resume": False,
    "resume_epoch": None,
    # NOTE(review): the fit cell does not pass these entries, so fit falls back to the
    # defaults in its kwargs.get(...) calls.
    "lr_prams": {
        # StepLR & MultiStepLR
        "step_size": 3,
        "milestones": [5, 10, 20, 35, 50, 70, 90],
        "gamma": 0.98,
        # ReduceLROnPlateau
        # A dict holds only one "mode" key; the CyclicLR value below is the one in
        # effect. Presumably ReduceLROnPlateau reads the same key, so set
        # "mode": "min" there when using that policy.
        "factor": 0.8,
        "patience": 3,
        "threshold": 0.0001,
        "threshold_mode": "rel",
        "min_lr": 3e-6,
        # PolynomialLR
        "max_decay_steps": 75,
        "min_learning_rate": 1e-5,
        "power": 0.8,
        # CyclicLR
        "base_lr": 3e-5,
        "max_lr": 0.01,
        "step_size_up": 1100,
        "mode": "triangular",
    },

    # Model accuracy evaluation
    "val_metric_fname": "validate_metrics.csv",
}

# Create the working directory and move into it. An empty working_dir means "stay in
# the current directory": os.makedirs("") and os.chdir("") both raise FileNotFoundError.
if config["working_dir"]:
    os.makedirs(config["working_dir"], exist_ok=True)
    os.chdir(config["working_dir"])

In [None]:
# Training split: the only split that receives the augmentation settings (trans, aug_params).
train_dataset = CropData(
    src_dir=config["src_dir"],
    usage="train",
    dataset_name=config["train_dataset_name"],
    split_ratio=config["split_ratio"],
    apply_normalization=config["apply_normalization"],
    trans=config["transformations"],
    aug_params=config["aug_params"],
)

In [None]:
# Mini-batches of training chips, reshuffled by the DataLoader every epoch.
train_loader = DataLoader(
    train_dataset,
    batch_size=config["train_BatchSize"],
    shuffle=True,
)

In [None]:
# Validation split from the same source and split ratio; no augmentation arguments are passed.
val_dataset = CropData(
    src_dir=config["src_dir"],
    usage="validation",
    dataset_name=config["train_dataset_name"],
    split_ratio=config["split_ratio"],
    apply_normalization=config["apply_normalization"],
)

In [None]:
# Validation batches keep a fixed order (no shuffling).
val_loader = DataLoader(
    val_dataset,
    batch_size=config["val_test_BatchSize"],
    shuffle=False,
)

In [None]:
# Bare U-Net; the next cell wraps it in a ModelCompiler.
model = Unet(
    n_classes=config["n_classes"],
    in_channels=config["input_channels"],
    use_skipAtt=False,  # presumably disables attention on the skip connections
)

In [None]:
# NOTE: this rebinds `model` from the bare Unet to its ModelCompiler wrapper, so
# re-running this cell alone would wrap the compiler itself — re-run the Unet cell first.
model = ModelCompiler(
    model,
    working_dir=config["working_dir"],
    out_dir=config["out_dir"],
    num_classes=config["n_classes"],
    inch=config["input_channels"],
    gpu_devices=config["gpuDevices"],
    model_init_type=config["init_type"],
    params_init=config["params_init"],
    freeze_params=config["freeze_params"],
)

In [None]:
# Train on the training loader and track validation loss on the validation loader.
model.fit(
    train_loader,
    val_loader,
    epochs=config["epochs"],
    optimizer_name=config["optimizer"],
    lr_init=config["LR"],
    lr_policy=config["LR_policy"],
    criterion=config["criterion"],
    momentum=config["momentum"],
    resume=config["resume"],
    resume_epoch=config["resume_epoch"],
)

In [None]:
# Persist only the trained weights (state_dict), not the whole model object.
model.save("params")

In [None]:
# Score the trained model on the validation split; metrics go to the configured CSV name.
model.accuracy_evaluation(evalDataset=val_loader, filename=config["val_metric_fname"])