<a href="https://colab.research.google.com/github/JaiswalFelipe/Dissertation-Project/blob/main/mydataloader.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Requirements

In [1]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Not connected to a GPU


In [2]:
!pip install imagecodecs 

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting imagecodecs
  Downloading imagecodecs-2021.11.20-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (31.0 MB)
[K     |████████████████████████████████| 31.0 MB 83.4 MB/s 
Installing collected packages: imagecodecs
Successfully installed imagecodecs-2021.11.20


In [3]:
import os
import sys
import numpy as np
import imageio
import torch
import torchvision
import math
import scipy.stats as stats
import torch.nn.functional as F
import cv2

from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torch.utils import data
from skimage import transform
from skimage import img_as_float
from matplotlib import pyplot as plt
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from torchvision import transforms
from skimage import transform
from torch import nn
from torchvision import models


#from data_utils import create_or_load_statistics, create_distrib, normalize_images, data_augmentation

scaler = MinMaxScaler()

In [8]:
# Connect to drive
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Model

### Utils

In [None]:
#from torch import nn


def initialize_weights(*models):
    """Re-initialise the parameters of every given network in place.

    Conv2d, ConvTranspose2d and Linear layers get Kaiming-normal weights and a
    zeroed bias (when they have one). BatchNorm2d layers get weight 1 and bias 0.
    Every other module type is left untouched.

    models: one or more nn.Module instances; all their sub-modules are visited.
    """
    kaiming_layers = (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)

    for network in models:
        for layer in network.modules():
            if isinstance(layer, kaiming_layers):
                nn.init.kaiming_normal_(layer.weight)
                if layer.bias is not None:
                    layer.bias.data.zero_()
                continue

            if isinstance(layer, nn.BatchNorm2d):
                layer.weight.data.fill_(1)
                layer.bias.data.zero_()


### fcnwideresnet

In [None]:
#import torch
#from torch import nn
#from torchvision import models
#import torch.nn.functional as F

#from networks.utils import initialize_weights


class FCNWideResNet50(nn.Module):
    """Fully-convolutional segmentation network on a torchvision Wide-ResNet-50-2.

    The backbone's stem and four residual stages are reused. The deepest
    feature map (and, with ``skip=True``, the stage-2 map as well) is bilinearly
    upsampled to the input's spatial size and passed through a small
    convolutional head that emits ``num_classes`` channels per pixel.

    input_channels: number of channels of the input image. With 3 the
        backbone's own first convolution is reused; otherwise a new
        3x3, stride-1, 64-filter convolution replaces it.
    num_classes: number of output channels of the final convolution.
    pretrained: forwarded to ``models.wide_resnet50_2``. When falsy every layer
        is re-initialised with ``initialize_weights``; otherwise only the head
        (``classifier1`` and ``final``) is.
    skip: whether stage-2 features are concatenated with stage-4 features.
    """

    def __init__(self, input_channels, num_classes, pretrained=True, skip=True):
        super().__init__()
        self.skip = skip

        # Backbone (optionally with pre-trained weights).
        backbone = models.wide_resnet50_2(pretrained=pretrained, progress=False)

        # Stem: only the first convolution depends on the number of input channels.
        if input_channels == 3:
            first_conv = backbone.conv1
        else:
            first_conv = nn.Conv2d(input_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.init = nn.Sequential(first_conv, backbone.bn1, backbone.relu, backbone.maxpool)

        self.layer1 = backbone.layer1  # output feat = 256
        self.layer2 = backbone.layer2  # output feat = 512
        self.layer3 = backbone.layer3  # output feat = 1024
        self.layer4 = backbone.layer4  # output feat = 2048

        # With skip connections the head sees layer4 (2048) and layer2 (512) stacked.
        head_in_channels = 2048 + 512 if self.skip else 2048
        self.classifier1 = nn.Sequential(
            nn.Conv2d(head_in_channels, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Dropout2d(0.5),
        )
        self.final = nn.Conv2d(128, num_classes, kernel_size=3, padding=1)

        if pretrained:
            # Keep the loaded backbone weights; only the new head is re-initialised.
            initialize_weights(self.classifier1)
            initialize_weights(self.final)
        else:
            initialize_weights(self)

    def forward(self, x):
        """Return per-pixel class scores with the same height/width as ``x``."""
        target_size = x.size()[2:]

        def upsample(features):
            return F.interpolate(features, target_size, mode='bilinear', align_corners=False)

        stem_out = self.init(x)
        stage2 = self.layer2(self.layer1(stem_out))
        stage4 = self.layer4(self.layer3(stage2))

        if self.skip:
            # Forward on FCN with Skip Connections.
            fused = torch.cat([upsample(stage2), upsample(stage4)], 1)
        else:
            # Forward on FCN without Skip Connections.
            fused = upsample(stage4)

        return self.final(self.classifier1(fused))


# data_utils

In [5]:
def compute_image_mean(data):
    """Return the mean and sample standard deviation over the first three axes.

    For a stack of images this yields one mean and one standard deviation per
    channel, pooled over every pixel of every image.
    (assumes data is laid out as (image, row, column, channel) - TODO confirm)

    The previous implementation nested three ``np.std`` calls, i.e. it took the
    spread of the spread of the spread. That is not the standard deviation of
    the pixel values; for a simple ramp it evaluates to 0.

    data: array-like with at least three dimensions.
    returns: (_mean, _std), each shaped like ``data`` without its first three axes.
    """
    pooled_axes = (0, 1, 2)
    _mean = np.mean(data, axis=pooled_axes)
    # ddof=1 keeps the unbiased (sample) estimator the original asked for.
    _std = np.std(data, axis=pooled_axes, ddof=1)

    return _mean, _std
    

def normalize_images(data, _mean, _std):
    """Standardise each channel of ``data`` in place: (value - mean) / std.

    The array is modified directly and nothing is returned, so pass a copy if
    the original values are still needed.

    data: array indexed as [row, column, channel].
    _mean, _std: per-channel statistics; ``len(_mean)`` channels are processed.
    """
    for channel in range(len(_mean)):
        # Two separate assignments, so each intermediate result is stored back
        # into ``data`` (and cast to its dtype) before the next step.
        data[:, :, channel] = data[:, :, channel] - _mean[channel]
        data[:, :, channel] = data[:, :, channel] / _std[channel]



def data_augmentation(img, label=None):
    """Randomly flip and rotate an image, and its label map when one is given.

    Three random numbers are drawn: a left-right flip and an up-down flip each
    happen with probability 0.5, and a rotation of 270, 180 or 90 degrees each
    with probability 0.25 (otherwise no rotation). The image is rotated with
    linear interpolation (order=1); the label with nearest-neighbour (order=0)
    so class ids are never blended.

    img: image array with rows and columns as the first two axes.
    label: optional label map aligned with ``img``. The signature already
        advertised ``None`` as the default, but the previous body crashed on
        it; ``None`` is now passed through unchanged.
    returns: (img as float32, label as int64 or None).
    """
    rand_fliplr = np.random.random() > 0.50
    rand_flipud = np.random.random() > 0.50
    rand_rotate = np.random.random()
    has_label = label is not None

    if rand_fliplr:
        img = np.fliplr(img)
        if has_label:
            label = np.fliplr(label)
    if rand_flipud:
        img = np.flipud(img)
        if has_label:
            label = np.flipud(label)

    if rand_rotate < 0.25:
        angle = 270
    elif rand_rotate < 0.50:
        angle = 180
    elif rand_rotate < 0.75:
        angle = 90
    else:
        angle = None

    if angle is not None:
        img = transform.rotate(img, angle, order=1, preserve_range=True)
        if has_label:
            label = transform.rotate(label, angle, order=0, preserve_range=True)

    img = img.astype(np.float32)
    if has_label:
        label = label.astype(np.int64)

    return img, label


# utils

In [None]:
#import os
import random
#import numpy as np
#import imageio
import argparse
#import matplotlib.pyplot as plt
#import torch


def check_mkdir(dir_name):
    """Create ``dir_name`` if nothing exists at that path yet.

    Missing parent directories are created as well, and a directory created by
    someone else between the existence check and the creation is tolerated
    (the previous ``os.mkdir`` call failed in both situations).

    dir_name: path of the directory to create.
    """
    if not os.path.exists(dir_name):
        os.makedirs(dir_name, exist_ok=True)


def str2bool(v):
    """
    Convert a command-line style string into a boolean.

    v: a bool (returned unchanged) or a string. Matching is case-insensitive:
       'yes', 'true', 't', 'y', '1' give True;
       'no', 'false', 'f', 'n', '0' give False.

    Raises argparse.ArgumentTypeError for any other string.
    """
    if isinstance(v, bool):
        return v

    token = v.lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False

    raise argparse.ArgumentTypeError('Boolean value expected.')


def save_best_models(net, output_path, best_records, epoch, metric, num_saves=1,
                     patch_acc_loss=None, patch_occur=None, patch_chosen_values=None):
    """Keep the ``num_saves`` best checkpoints (by ``metric``) in ``output_path``.

    While fewer than ``num_saves`` records exist the current model is always
    saved. Afterwards it replaces the worst recorded checkpoint only when
    ``metric`` is strictly greater than that checkpoint's score. The updated
    ``best_records`` list is written to ``best_records.npy`` on every call.

    Every file is now addressed with ``os.path.join``. Previously the patch
    statistics were saved with plain string concatenation but removed with
    ``os.path.join``, so an ``output_path`` without a trailing separator wrote
    them next to the directory and the later removal raised FileNotFoundError.

    net: module whose ``state_dict()`` is saved as ``model_<epoch>.pth``.
    output_path: directory receiving all files.
    best_records: list of ``{'epoch': ..., 'kappa': ...}`` dicts; updated in place.
    epoch: identifier used in the file names.
    metric: score stored under the ``'kappa'`` key and used for ranking.
    num_saves: how many checkpoints to keep.
    patch_acc_loss, patch_occur, patch_chosen_values: optional arrays saved
        next to the model; they are only written when all three are given.
    """
    patch_stats = (
        ('patch_acc_loss_step_', patch_acc_loss),
        ('patch_occur_step_', patch_occur),
        ('patch_chosen_values_step_', patch_chosen_values),
    )
    has_patch_stats = all(array is not None for _, array in patch_stats)

    def model_file(step):
        return os.path.join(output_path, 'model_' + str(step) + '.pth')

    def patch_file(prefix, step):
        return os.path.join(output_path, prefix + str(step) + '.npy')

    def save_patch_stats():
        for prefix, array in patch_stats:
            np.save(patch_file(prefix, epoch), array)

    if len(best_records) < num_saves:
        best_records.append({'epoch': epoch, 'kappa': metric})

        torch.save(net.state_dict(), model_file(epoch))
        if has_patch_stats:
            save_patch_stats()
    else:
        # Locate the worst checkpoint kept so far (first one on ties).
        min_index = 0
        for i in range(len(best_records)):
            if best_records[min_index]['kappa'] > best_records[i]['kappa']:
                min_index = i

        # Replace it only when the current score is strictly better.
        if metric > best_records[min_index]['kappa']:
            min_step = str(best_records[min_index]['epoch'])

            os.remove(model_file(min_step))
            best_records[min_index] = {'epoch': epoch, 'kappa': metric}
            torch.save(net.state_dict(), model_file(epoch))

            if has_patch_stats:
                for prefix, _ in patch_stats:
                    os.remove(patch_file(prefix, min_step))
                save_patch_stats()

    np.save(os.path.join(output_path, 'best_records.npy'), best_records)


def define_multinomial_probs(values, dif_prob=2):
    """Build sampling probabilities over the integer range values[0]..values[-1].

    Each entry of ``values`` receives ``dif_prob`` times the uniform
    probability; the remaining mass is shared equally among the other integers
    in the range.
    (assumes values is sorted ascending with no duplicates - TODO confirm)

    When ``values`` already lists every integer of the range there are no
    "other" integers to absorb the remainder. The previous code divided by
    zero in that case; a uniform distribution is returned instead.

    values: sequence of integers.
    dif_prob: weight multiplier applied to the listed values.
    returns: 1-D float array of length values[-1] - values[0] + 1.
    """
    interval_size = values[-1] - values[0] + 1
    num_values = len(values)

    general_prob = 1.0 / float(interval_size)
    if interval_size == num_values:
        # Every slot is a preferred value, so they all share the same weight.
        return np.full(interval_size, general_prob)

    max_prob = general_prob * dif_prob  # for values
    other_prob = (1.0 - max_prob * num_values) / float(interval_size - num_values)

    probs = np.full(interval_size, other_prob)
    for value in values:
        probs[value - values[0]] = max_prob

    return probs


def select_best_patch_size(distribution_type, values, patch_acc_loss, patch_occur, is_loss_or_acc='acc',
                           patch_chosen_values=None, debug=False):
    """Pick a patch size from accumulated per-candidate scores.

    Zero entries of ``patch_occur`` are overwritten with 1 (in place), then each
    candidate's mean score is ``patch_acc_loss / patch_occur``.

    * ``is_loss_or_acc == 'acc'``: the candidate with the highest mean wins.
    * ``is_loss_or_acc == 'loss'``: candidates are scanned from lowest mean
      upwards and the first one with a positive occurrence count wins.

    The winning index maps to a patch size as ``int(values[index])`` for
    ``'multi_fixed'`` and as ``values[0] + index`` for ``'uniform'`` or
    ``'multinomial'``.

    patch_chosen_values: optional counter array; the winner's slot is incremented.
    debug: when exactly ``True``, the intermediate arrays and the choice are printed.
    returns: the selected patch size.
    """
    is_fixed = distribution_type == 'multi_fixed'
    is_ranged = distribution_type in ('uniform', 'multinomial')

    patch_occur[np.where(patch_occur == 0)] = 1
    patch_mean = patch_acc_loss / patch_occur

    def dump_state(label, ordering):
        # Shared debug dump of the score arrays plus the ranking information.
        print('patch_acc_loss', patch_acc_loss)
        print('patch_occur', patch_occur)
        print('patch_mean', patch_mean)
        print(label, ordering)

    def dump_choice(index):
        print('specific', index, patch_acc_loss[index], patch_occur[index], patch_mean[index])

    if is_loss_or_acc == 'acc':
        best = np.argmax(patch_mean)
        if is_fixed:
            cur_patch_val = int(values[best])
        elif is_ranged:
            cur_patch_val = values[0] + best

        if patch_chosen_values is not None:
            patch_chosen_values[int(best)] += 1

        if debug is True:
            dump_state('argmax_acc', best)
            dump_choice(best)

    elif is_loss_or_acc == 'loss':
        ranking = np.argsort(patch_mean)

        if debug is True:
            dump_state('arg_sort_out', ranking)

        # Number of ranked candidates to scan depends on the distribution type.
        if is_fixed:
            num_candidates = len(values)
        elif is_ranged:
            num_candidates = values[-1] - values[0] + 1
        else:
            num_candidates = 0

        for rank in range(num_candidates):
            candidate = ranking[rank]
            if patch_occur[candidate] > 0:
                if is_fixed:
                    cur_patch_val = int(values[candidate])  # -1*(i+1)
                else:
                    cur_patch_val = values[0] + candidate
                if patch_chosen_values is not None:
                    patch_chosen_values[candidate] += 1
                if debug is True:
                    dump_choice(candidate)
                break

    if debug is True:
        print('Current patch size ', cur_patch_val)
        if patch_chosen_values is not None:
            print('Distr of chosen sizes ', patch_chosen_values)

    return cur_patch_val


def create_prediction_map(img_name, prob_img, channels=False):
    """Write a prediction array to PNG file(s) whose names start with ``img_name``.

    channels is True: every slice ``prob_img[:, :, i]`` of the last axis is
        saved as ``<img_name>feat_<i>.png`` using the jet colour map.
    otherwise: ``prob_img`` is cast to uint8, multiplied by 255 and saved as
        ``<img_name>.png``.
    """
    if channels is not True:
        imageio.imwrite(img_name + '.png', prob_img.astype(np.uint8) * 255)
        return

    for feat in range(prob_img.shape[-1]):
        plt.imsave(img_name + 'feat_' + str(feat) + '.png', prob_img[:, :, feat], cmap=plt.cm.jet)


def calc_accuracy_by_crop(true_crop, pred_crop, num_classes, track_conf_matrix, masks=None):
    """Count correct pixels and build a confusion matrix for a batch of crops.

    true_crop, pred_crop: integer label arrays indexed as [batch, row, column];
        the iteration bounds come from ``pred_crop.shape``.
    num_classes: side length of the returned confusion matrix.
    track_conf_matrix: optional running confusion matrix, updated in place
        (rows = true class, columns = predicted class); may be None.
    masks: optional array of the same indexing; positions with a falsy mask
        value are skipped entirely.
    returns: (number of matching pixels, local uint32 confusion matrix).
    """
    batch, height, width = pred_crop.shape

    hits = 0
    local_conf_matrix = np.zeros((num_classes, num_classes), dtype=np.uint32)

    for position in np.ndindex(batch, height, width):
        if masks is not None and not masks[position]:
            continue

        truth = true_crop[position]
        guess = pred_crop[position]

        if truth == guess:
            hits = hits + 1
        if track_conf_matrix is not None:
            track_conf_matrix[truth][guess] += 1
        local_conf_matrix[truth][guess] += 1

    return hits, local_conf_matrix


def calc_accuracy_by_class(true_crop, pred_crop, num_classes, track_conf_matrix):
    """Count correct predictions and build a confusion matrix for flat label lists.

    true_crop, pred_crop: 1-D sequences of integer class ids of equal length.
    num_classes: side length of the returned confusion matrix.
    track_conf_matrix: running confusion matrix updated in place (rows = true
        class, columns = predicted class). ``None`` is accepted and skipped,
        matching ``calc_accuracy_by_crop``; previously ``None`` raised TypeError.
    returns: (number of matches, local uint32 confusion matrix).
    """
    acc = 0
    local_conf_matrix = np.zeros((num_classes, num_classes), dtype=np.uint32)

    for i in range(len(true_crop)):
        truth, guess = true_crop[i], pred_crop[i]
        if truth == guess:
            acc = acc + 1
        if track_conf_matrix is not None:
            track_conf_matrix[truth][guess] += 1
        local_conf_matrix[truth][guess] += 1

    return acc, local_conf_matrix


def create_cm(true, pred):
    """Build a confusion matrix (rows = true class, columns = predicted class).

    The matrix used to be sized by the number of distinct values in ``true``.
    That raised IndexError whenever ``pred`` contained a class absent from
    ``true``, or the labels in ``true`` were not exactly 0..k-1. It is now at
    least that large, and large enough to index the highest label found in
    either array.

    true, pred: integer label arrays indexed as [image, row, column];
        ``true.shape`` defines the positions that are counted.
    returns: square uint32 matrix.
    """
    num_classes = len(np.unique(true))
    if np.size(true) and np.size(pred):
        num_classes = max(num_classes, int(np.max(true)) + 1, int(np.max(pred)) + 1)

    conf_matrix = np.zeros((num_classes, num_classes), dtype=np.uint32)
    c, h, w = true.shape
    for i in range(c):
        for j in range(h):
            for k in range(w):
                conf_matrix[true[i, j, k]][pred[i, j, k]] += 1

    return conf_matrix


def kappa_with_cm(conf_matrix):
    """Compute a kappa agreement score from a square confusion matrix.

    With N the total count, A the sum of the diagonal and M the sum over
    classes of (column total * row total), the result is
    (N * A - M) / (N * N - M).

    conf_matrix: square matrix of counts.
    returns: the kappa value.
    """
    num_classes = len(conf_matrix)
    total = float(np.sum(conf_matrix))

    # Row/column totals do not change inside the accumulation; compute them once.
    column_totals = np.sum(conf_matrix, 0)
    row_totals = np.sum(conf_matrix, 1)

    acc = sum(conf_matrix[i][i] for i in range(num_classes))
    marginal = sum(column_totals[i] * row_totals[i] for i in range(num_classes))

    return (total * acc - marginal) / (total * total - marginal)


def f1_with_cm(conf_matrix):
    """Return the macro-averaged F1 score of a square confusion matrix.

    Rows are true classes and columns are predicted classes. Per class,
    precision = diagonal / column total and recall = diagonal / row total.

    A class that was never predicted, never present, or never hit used to
    produce a 0/0 division and turn the whole average into NaN. Such undefined
    precision, recall or F1 values are now counted as 0, the same convention
    ``jaccard_with_cm`` uses for a zero denominator.

    conf_matrix: square matrix of counts.
    returns: mean of the per-class F1 scores.
    """
    num_classes = len(conf_matrix)
    column_totals = np.sum(conf_matrix, 0)
    row_totals = np.sum(conf_matrix, 1)

    f1 = [0] * num_classes
    for i in range(num_classes):
        hits = conf_matrix[i][i]
        precision = hits / float(column_totals[i]) if column_totals[i] != 0 else 0.0
        recall = hits / float(row_totals[i]) if row_totals[i] != 0 else 0.0

        denominator = precision + recall
        f1[i] = 2 * ((precision * recall) / denominator) if denominator != 0 else 0.0

    return np.mean(f1)


def jaccard_with_cm(conf_matrix):
    """Return the intersection-over-union of class 1 from a confusion matrix.

    IoU = cm[1][1] / (sum of column 1 + sum of row 1 - cm[1][1]);
    0 is returned when that denominator is zero.

    conf_matrix: 2-D numpy array with at least two rows and columns.
    """
    intersection = conf_matrix[1][1]
    column_total = np.sum(conf_matrix[:, 1])
    row_total = np.sum(conf_matrix[1])

    union = float(column_total + row_total - intersection)
    if union == 0:
        return 0

    return intersection / union


# Dataset

### training_data and vali_data




In [6]:
class NGDataset(data.Dataset):
  """Segmentation dataset that loads every image and mask of two folders into memory.

  Each item is an (image, mask) pair: the image is standardised with the
  dataset-wide per-channel mean/std, randomly augmented, moved to
  channel-first layout and returned as a float tensor; the mask is returned as
  an integer tensor. Negative mask values are remapped to 2 so they can be
  ignored by the loss.

  img_dir: folder containing the image files.
  mask_dir: folder containing the mask files.
  output_path: stored on the instance; not used by the class itself.

  Image i is paired with mask i purely by position, so both folder listings
  are sorted to make that pairing deterministic.
  (assumes image and mask file names sort into the same order - TODO confirm)
  """

  def __init__(self, img_dir, mask_dir, output_path):
    self.img_dir = img_dir
    self.mask_dir = mask_dir
    # os.listdir returns entries in arbitrary order; sorting keeps images and
    # masks aligned index-for-index and makes runs reproducible.
    self.images = sorted(os.listdir(img_dir))
    self.masks = sorted(os.listdir(mask_dir))

    self.output_path = output_path

    print("debug1")  # NOTE(review): leftover progress trace
    # data and label
    self.data, self.labels = self.load_images()
    print("debug2")  # NOTE(review): leftover progress trace

    if self.data.ndim == 4:  # if all images have the same shape
      self.num_channels = self.data.shape[-1]  # get the number of channels
    else:
      self.num_channels = self.data[0].shape[-1]  # get the number of channels

    self.num_classes = 2  # binary - two classes
    # negative classes will be converted into 2 so they can be ignored in the loss
    self.labels[np.where(self.labels < 0)] = 2

    print('num_channels and labels', self.num_channels, self.num_classes, np.bincount(self.labels.flatten()))

    self.mean, self.std = compute_image_mean(self.data)

  # NOTE(review): this and the other class-level prints below run once, when the
  # class statement is executed (not per instance); they look like leftover debugging.
  print("debug3")

  def load_images(self):
    """Read all images (float64) and masks (int) and return them as two arrays."""
    images = []
    masks = []
    for img in self.images:
      temp_image = imageio.imread(os.path.join(self.img_dir, img)).astype(np.float64)
      temp_image[np.where(temp_image < -1.0e+38)] = 0  # remove extreme negative values (probably NO_DATA values)

      images.append(temp_image)

    for msk in self.masks:
      temp_mask = imageio.imread(os.path.join(self.mask_dir, msk)).astype(int)
      # Kept from the original; after the integer cast no value can be below
      # -1e38, so this assignment never selects anything.
      temp_mask[np.where(temp_mask < -1.0e+38)] = 0

      masks.append(temp_mask)

    return np.asarray(images), np.asarray(masks)

  print("debug4")

  def __getitem__(self, index):
    """Return the normalised, augmented (image, mask) tensors for ``index``."""
    # Work on a copy: normalize_images modifies its argument in place, and
    # self.data[index] is a view of the stored array. Without the copy every
    # access standardised the stored image again, so the same sample drifted
    # further each time it was read.
    img = self.data[index].copy()
    label = self.labels[index]

    # Normalization.
    normalize_images(img, self.mean, self.std)

    # Data augmentation
    img, label = data_augmentation(img, label)

    # (rows, cols, channels) -> (channels, rows, cols)
    img = np.transpose(img, (2, 0, 1))

    # Turning to tensors (copy() makes flipped/transposed views contiguous).
    img = torch.from_numpy(img.copy())
    label = torch.from_numpy(label.copy())

    # Returning to iterator.
    return img.float(), label

  def __len__(self):
    """Number of loaded images."""
    return len(self.data)

  print("debug5")


debug3
debug4
debug5


In [9]:
# Loading the train set: image and mask folders on the mounted Drive are read
# fully into memory by NGDataset.

img_path = '/content/drive/MyDrive/training_data/image'
mask_path = '/content/drive/MyDrive/training_data/mask'
output_path = '/content/drive/MyDrive/outputs/'

# NGDataset lists the two folders itself, so no file list is built here.

train_set = NGDataset(img_path, mask_path, output_path)

debug1
debug2
num_channels and labels 23 2 [29044396 17830604]


In [10]:
# sanity check
# Fetch the sample ONCE: every train_set[0] call runs a fresh random
# augmentation, so indexing twice would pair an image with a mask that was
# flipped/rotated differently.
img, msk = train_set[0]
print(img.shape)
print(msk.shape)

torch.Size([23, 250, 250])
torch.Size([250, 250])


In [11]:
# loading vali set

vali_img_path = '/content/drive/MyDrive/vali_data/images'
vali_mask_path = '/content/drive/MyDrive/vali_data/masks'
output_path = '/content/drive/MyDrive/outputs/'

validation_set = NGDataset(vali_img_path, vali_mask_path, output_path)

# sanity check
# Fetch the sample ONCE: each validation_set[0] call applies its own random
# augmentation, so two separate lookups would not return a matching pair.
i, m = validation_set[0]
print(i.shape)
print(m.shape)

debug1
debug2
num_channels and labels 23 2 [6525140 9099860]
torch.Size([23, 250, 250])
torch.Size([250, 250])


# training

In [None]:
import albumentations as A
import torch.optim as optim

# The class is named ToTensorV2; "ToTensorsV2" does not exist and fails at import.
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm

#from utils import (
#    load_checkpoint,
#    save_checkpoint,
#    get_loaders,
#    check_accuracy
#    save_predictions_as_imgs
#)


def train(train_loader, model, criterion, optimizer, epoch, display_step=None):
    """Train ``model`` for one epoch and report progress every ``display_step`` batches.

    train_loader: iterable of (inputs, labels) batches that supports ``len()``
        and exposes ``dataset.num_classes``.
    model: network to optimise; batches are moved to the device its parameters
        live on (previously hard-coded to CUDA).
    criterion: loss called as ``criterion(outputs, labels)``.
    optimizer: optimiser stepped once per batch.
    epoch: epoch number, used only in the progress message.
    display_step: log every this many batches. When omitted, the module-level
        ``DISPLAY_STEP`` is used if it exists, otherwise progress is logged
        once, on the last batch of the epoch.

    returns: (mean batch loss, normalised accuracy of the last logged batch).
        The accuracy is 0.0 when no batch was logged (it used to be an unbound
        variable), and the loss is NaN when the loader yielded no batches
        (it used to divide by zero).

    Raises AssertionError when a batch loss is NaN.
    """
    import datetime

    if display_step is None:
        display_step = globals().get('DISPLAY_STEP', max(len(train_loader), 1))

    # Setting network for training mode.
    model.train()

    # Follow the model's device so the loop also runs without a GPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    num_classes = train_loader.dataset.num_classes

    # Batch losses of this epoch.
    train_loss = list()
    # Sum of per-class recalls of the last logged batch; defined up-front so
    # the return statement is valid even if no batch gets logged.
    _sum = 0.0

    # Iterating over batches.
    for i, batch in enumerate(train_loader):
        inps, labels = batch[0], batch[1]

        # Tensors carry autograd state themselves; no Variable wrapper needed.
        inps = inps.to(device)
        labs = labels.to(device)

        # Clears the gradients of optimizer.
        optimizer.zero_grad()

        # Forwarding.
        outs = model(inps)

        # Computing loss.
        loss = criterion(outs, labs)

        if math.isnan(loss):
            inps_np = inps.detach().cpu().numpy()
            labels_np = labels.detach().cpu().numpy()
            outs_np = outs.detach().cpu().numpy()
            print('-------------------------NaN-----------------------------------------------')
            print(inps.shape, labels.shape, outs.shape, np.bincount(labels_np.flatten()))
            print(np.min(inps_np), np.max(inps_np), np.isnan(inps_np).any())
            print(np.min(labels_np), np.max(labels_np), np.isnan(labels_np).any())
            print(np.min(outs_np), np.max(outs_np), np.isnan(outs_np).any())
            print('-------------------------NaN-----------------------------------------------')
            raise AssertionError

        # Computing backpropagation.
        loss.backward()
        optimizer.step()

        # Updating loss meter.
        train_loss.append(loss.item())

        # Printing.
        if (i + 1) % display_step == 0:
            # Imported lazily: the metrics are only needed when logging.
            from sklearn.metrics import accuracy_score, confusion_matrix, f1_score

            soft_outs = F.softmax(outs, dim=1)
            # Obtaining predictions.
            prds = soft_outs.detach().cpu().numpy().argmax(axis=1).flatten()
            flat_labels = labels.detach().cpu().numpy().flatten()

            # Filtering out pixels carrying the "ignore" label (== num_classes).
            keep = np.where(flat_labels != num_classes)
            flat_labels = flat_labels[keep]
            prds = prds[keep]

            acc = accuracy_score(flat_labels, prds)
            conf_m = confusion_matrix(flat_labels, prds, labels=[0, 1])
            f1_s = f1_score(flat_labels, prds, average='weighted')

            _sum = 0.0
            for k in range(len(conf_m)):
                _sum += (conf_m[k][k] / float(np.sum(conf_m[k])) if np.sum(conf_m[k]) != 0 else 0)

            print("Training -- Epoch " + str(epoch) + " -- Iter " + str(i + 1) + "/" + str(len(train_loader)) +
                  " -- Time " + str(datetime.datetime.now().time()) +
                  " -- Training Minibatch: Loss= " + "{:.6f}".format(train_loss[-1]) +
                  " Overall Accuracy= " + "{:.4f}".format(acc) +
                  " Normalized Accuracy= " + "{:.4f}".format(_sum / float(num_classes)) +
                  " F1 Score= " + "{:.4f}".format(f1_s) +
                  " Confusion Matrix= " + np.array_str(conf_m).replace("\n", "")
                  )
            sys.stdout.flush()

    mean_loss = sum(train_loss) / len(train_loss) if train_loss else float('nan')
    return mean_loss, _sum / float(num_classes)

# CREATE OWN MAIN

def main():
    """Command-line entry point: train a model or evaluate the best saved one.

    ``--operation Train`` builds train/test loaders, trains for ``--epoch_num``
    epochs (optionally resuming from ``--model_path``), evaluates every
    ``VAL_INTERVAL`` epochs and keeps the best checkpoints by kappa.
    ``--operation Test`` loads the checkpoint with the highest recorded kappa
    from ``--output_path`` and evaluates it.

    Fixes relative to the previous version:
      * the Test branch built its dataset from ``--training_images`` instead of
        ``--testing_images``;
      * when resuming, the scheduler was advanced one step too many;
      * when resuming, ``best_records`` stayed a numpy array although
        ``save_best_models`` appends to it like a list;
      * the "not found" error message lacked a space;
      * ``pathlib`` was used without being imported.

    NOTE(review): ``SummaryWriter``, ``NUM_WORKERS``, ``VAL_INTERVAL``,
    ``model_factory`` and ``test_full_map`` are not defined anywhere in the
    visible notebook. ``DataLoader`` is bound by the import cell to
    ``torch.utils.data.DataLoader``, while the calls below pass
    ('Train'/'Test', path, image names, crop size, stride, output path), which
    looks like the signature of a project dataset class - confirm before running.
    """
    import pathlib

    parser = argparse.ArgumentParser(description='main')
    # general options
    parser.add_argument('--operation', type=str, required=True, help='Operation [Options: Train | Test]')
    parser.add_argument('--output_path', type=str, required=True,
                        help='Path to to save outcomes (such as images and trained models) of the algorithm.')

    # dataset options
    parser.add_argument('--dataset_path', type=str, required=True, help='Dataset path.')
    parser.add_argument('--training_images', type=str, nargs="+", required=True, help='Training image names.') # WHAT DO I DO IF I ONLY HAVE FOLDER?
    parser.add_argument('--testing_images', type=str, nargs="+", required=True, help='Testing image names.')
    parser.add_argument('--crop_size', type=int, required=True, help='Crop size.')
    parser.add_argument('--stride_crop', type=int, required=True, help='Stride size')

    # model options
    parser.add_argument('--model_name', type=str, required=True,
                        choices=['deeplab', 'fcnwideresnet'], help='Model to evaluate')
    parser.add_argument('--model_path', type=str, default=None, help='Model path.')
    parser.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=0.005, help='Weight decay')
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size')
    parser.add_argument('--epoch_num', type=int, default=500, help='Number of epochs')

    # handling imbalanced data
    parser.add_argument('--loss_weight', type=float, nargs='+', default=[1.0, 1.0], help='Weight Loss.')
    parser.add_argument('--weight_sampler', type=str2bool, default=False, help='Use weight sampler for loader?')
    args = parser.parse_args()
    print(args)

    # Making sure output directory is created.
    pathlib.Path(args.output_path).mkdir(parents=True, exist_ok=True)

    # writer for the tensorboard
    writer = SummaryWriter(os.path.join(args.output_path, 'logs'))

    if args.operation == 'Train':
        print('---- training data ----')
        train_set = DataLoader('Train', args.dataset_path, args.training_images, args.crop_size, args.stride_crop,
                               args.output_path)
        print('---- testing data ----')
        test_set = DataLoader('Test', args.dataset_path, args.testing_images, args.crop_size, args.stride_crop,
                              args.output_path)

        if args.weight_sampler is False:
            train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size,
                                                       shuffle=True, num_workers=NUM_WORKERS, drop_last=False)
        else:
            # Sample each item with a weight inversely proportional to its class frequency.
            class_loader_weights = 1. / np.bincount(train_set.gen_classes)
            samples_weights = class_loader_weights[train_set.gen_classes]
            sampler = torch.utils.data.sampler.WeightedRandomSampler(samples_weights, len(samples_weights),
                                                                     replacement=True)
            train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size,
                                                       num_workers=NUM_WORKERS, drop_last=False, sampler=sampler)

        test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size,
                                                  shuffle=False, num_workers=NUM_WORKERS, drop_last=False)

        # Setting network architecture.
        model = model_factory(args.model_name, train_set.num_channels, train_set.num_classes).cuda()

        # Label value == num_classes marks pixels excluded from the loss.
        criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(args.loss_weight),
                                        ignore_index=train_set.num_classes).cuda()

        # Setting optimizer.
        optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay,
                               betas=(0.9, 0.99))
        # optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay, momentum=0.9)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)

        curr_epoch = 1
        best_records = []
        if args.model_path is not None:
            print('Loading model ' + args.model_path)
            # save_best_models appends to best_records, so restore it as a list.
            best_records = list(np.load(os.path.join(args.output_path, 'best_records.npy'), allow_pickle=True))
            model.load_state_dict(torch.load(args.model_path))
            # optimizer.load_state_dict(torch.load(args.model_path.replace("model", "opt")))
            # File names look like model_<epoch>.pth; resume at the following epoch.
            curr_epoch += int(os.path.basename(args.model_path)[:-4].split('_')[-1])
            # The loop below steps the scheduler once at the end of every epoch,
            # so before epoch N it must have been stepped N - 1 times.
            for _ in range(curr_epoch - 1):
                scheduler.step()
        model.cuda()

        # Iterating over epochs.
        print('---- training ----')
        for epoch in range(curr_epoch, args.epoch_num + 1):
            # Training function.
            t_loss, t_nacc = train(train_loader, model, criterion, optimizer, epoch)
            writer.add_scalar('Train/loss', t_loss, epoch)
            writer.add_scalar('Train/acc', t_nacc, epoch)
            if epoch % VAL_INTERVAL == 0:
                # Computing test.
                acc, nacc, f1_s, kappa, track_cm = test_full_map(test_loader, model, epoch, args.output_path)
                writer.add_scalar('Test/acc', nacc, epoch)
                save_best_models(model, args.output_path, best_records, epoch, kappa)
                # patch_acc_loss=None, patch_occur=None, patch_chosen_values=None
            scheduler.step()
    elif args.operation == 'Test':
        print('---- testing data ----')
        # Evaluate on the testing images (this used to load the training images).
        test_set = DataLoader('Test', args.dataset_path, args.testing_images, args.crop_size, args.stride_crop,
                              args.output_path)
        test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size,
                                                  shuffle=False, num_workers=NUM_WORKERS, drop_last=False)

        # Setting network architecture.
        model = model_factory(args.model_name, test_set.num_channels, test_set.num_classes).cuda()

        # Pick the checkpoint with the highest recorded kappa.
        best_records = np.load(os.path.join(args.output_path, 'best_records.npy'), allow_pickle=True)
        index = 0
        for i in range(len(best_records)):
            if best_records[index]['kappa'] < best_records[i]['kappa']:
                index = i
        epoch = int(best_records[index]['epoch'])
        print("loading model_" + str(epoch) + '.pth')
        model.load_state_dict(torch.load(os.path.join(args.output_path, 'model_' + str(epoch) + '.pth')))
        model.cuda()

        test_full_map(test_loader, model, epoch, args.output_path)
    else:
        raise NotImplementedError("Process " + args.operation + " not found!")


# Script entry point: runs the command-line workflow when the module is executed directly.
# NOTE(review): main() declares several required argparse options; when this cell
# runs inside a notebook kernel those options are presumably absent from
# sys.argv - confirm how main() is meant to be invoked here.
if __name__ == "__main__":
    main()
