# In [None]:
#New_Training - Validation - Testing
import os
import time
from glob import glob

from operator import add, sub
import numpy as np
import matplotlib.pyplot as plt

from sklearn.model_selection import KFold
import random
import csv

import torch
from torch.utils.data import Dataset, DataLoader,TensorDataset,random_split,SubsetRandomSampler, ConcatDataset
import torch.nn as nn

import math
from sklearn.metrics import accuracy_score, f1_score, jaccard_score, precision_score, recall_score,confusion_matrix,roc_auc_score

# This function loads the training dataset and splits it into two non-overlapping subsets (training and validation) for one fold of k-fold cross-validation.

def get_train_valid_loader(data_dir,
                           batch_size,
                           random_seed,
                           kfolde=10,
                           shuffle=True,
                           show_sample=False,
                           num_workers=1,
                           pin_memory=False,
                           shuffelthevaluditation=1):
    """Build train/validation loaders for one fold of k-fold cross-validation.

    The ``*.png`` files under ``data_dir/image`` and ``data_dir/mask`` are
    listed once (sorted), wrapped in two ``DriveDataset`` objects (the first
    1536 files and the remainder) and concatenated. The sample indices are
    optionally shuffled with a fixed seed and split into ``kfolde`` folds; fold
    number ``shuffelthevaluditation`` (1-based) is used for validation and all
    other samples for training.

    If using CUDA, num_workers should be set to 1 and pin_memory to True.

    Params
    ------
    - data_dir: path directory to the dataset (with ``image`` and ``mask``
      sub-directories).
    - batch_size: how many samples per batch to load.
    - random_seed: seed passed to ``np.random.seed`` before shuffling. Pass the
      same seed for every fold so that the folds are disjoint.
    - kfolde: the number of folds in cross-validation (>= 2).
    - shuffle: whether to shuffle the indices before splitting into folds.
    - show_sample: unused; kept for backward compatibility.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.
    - shuffelthevaluditation: 1-based index of the fold used for validation.

    Returns
    -------
    - train_loader: training set iterator.
    - valid_loader: validation set iterator.

    Raises
    ------
    - ValueError: if ``kfolde`` < 2 or the fold index is not in [1, kfolde].
    """
    kfolde = int(kfolde)
    fold = int(shuffelthevaluditation)
    if kfolde < 2:
        raise ValueError(f"kfolde must be >= 2, got {kfolde}")
    if not 1 <= fold <= kfolde:
        raise ValueError(f"shuffelthevaluditation must be in [1, {kfolde}], got {fold}")

    # List each directory once; slicing the same sorted listing keeps the
    # two datasets consistent with each other.
    image_paths = sorted(glob(os.path.join(data_dir, "image", "*.png")))
    mask_paths = sorted(glob(os.path.join(data_dir, "mask", "*.png")))

    train_x, valid_x = image_paths[:1536], image_paths[1536:]
    train_y, valid_y = mask_paths[:1536], mask_paths[1536:]

    train_dataset = DriveDataset(train_x, train_y, augmentation=None)
    valid_dataset = DriveDataset(valid_x, valid_y, augmentation=None)

    print(len(train_x), len(train_y), len(valid_x), len(valid_y))

    dataset = ConcatDataset([train_dataset, valid_dataset])

    num_samples = len(dataset)
    indices = list(range(num_samples))
    # Integer division avoids the float round-trip (n / k) / n * n, which can
    # floor to one less than n // k.
    fold_size = num_samples // kfolde

    if shuffle:
        # Reseeding the global NumPy RNG gives the same permutation on every
        # call with the same seed, which is what keeps the folds disjoint.
        np.random.seed(random_seed)
        np.random.shuffle(indices)

    start = fold_size * (fold - 1)
    # The last fold absorbs the remainder so every sample is validated in
    # exactly one fold (previously the trailing n % k samples never were).
    stop = num_samples if fold == kfolde else fold_size * fold
    valid_idx = indices[start:stop]
    train_idx = indices[:start] + indices[stop:]

    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, sampler=train_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )

    valid_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, sampler=valid_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )

    return (train_loader, valid_loader)

#best val accuracy
#written to CSV file
#K fold Cross Validation

def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator`` as a float, or NaN if the denominator is 0.

    Every zero-denominator case in ``calculate_metrics`` is a 0/0 case, which
    the previous NumPy scalar division also turned into NaN (with a warning).
    """
    if denominator == 0:
        return float("nan")
    return float(numerator) / float(denominator)


def calculate_metrics(y_true, y_pred):
    """Compute pixel-wise binary segmentation metrics for one batch.

    Both tensors are moved to the CPU, detached, thresholded at 0.5 and
    flattened, so every element is treated as one binary sample.

    Args:
        y_true: ground-truth tensor.
        y_pred: prediction tensor (may be attached to the autograd graph).

    Returns:
        A list of 12 floats in this order:
        [acc, fdr, recall, precision, f1, specificity, sensitivity, auc,
         iou, dice_coeff, g_mean, kappa_score].
        A metric whose denominator is zero (e.g. precision when nothing is
        predicted positive) is NaN, and AUC is NaN when ``y_true`` contains a
        single class. Previously these batches raised IndexError/ValueError.
    """
    # Ground truth and prediction as flat 0/1 uint8 arrays.
    y_true = (y_true.cpu().detach().numpy() > 0.5).astype(np.uint8).reshape(-1)
    y_pred = (y_pred.cpu().detach().numpy() > 0.5).astype(np.uint8).reshape(-1)

    # Code 2*t + p maps each pixel to 0=TN, 1=FP, 2=FN, 3=TP. minlength=4
    # guarantees all four counts even for single-class batches.
    tn, fp, fn, tp = (int(c) for c in np.bincount(2 * y_true + y_pred, minlength=4))
    total = tn + fp + fn + tp

    acc = _safe_ratio(tp + tn, total)
    fdr = _safe_ratio(fp, fp + tp)             # False Detection Rate = FP / (FP + TP)
    recall = _safe_ratio(tp, tp + fn)          # Recall = TP / (TP + FN)
    precision = _safe_ratio(tp, tp + fp)       # Precision = TP / (TP + FP)
    f1 = _safe_ratio(2 * precision * recall, precision + recall)
    specificity = _safe_ratio(tn, tn + fp)     # Specificity = TN / (TN + FP)
    sensitivity = recall                       # Sensitivity = TP / (TP + FN)

    # With hard 0/1 predictions the ROC curve is (0,0)-(FPR,TPR)-(1,1), whose
    # area is (TPR + TNR) / 2 -- the value roc_auc_score returns, without
    # sorting every pixel. It is undefined when y_true has a single class.
    if tp + fn > 0 and tn + fp > 0:
        auc = (sensitivity + specificity) / 2.0
    else:
        auc = float("nan")

    iou = _safe_ratio(tp, tp + fp + fn)
    dice_coeff = _safe_ratio(2 * tp, (tp + fn) + (tp + fp))
    g_mean = math.sqrt(specificity * sensitivity)

    # Cohen's kappa: chance agreement pe from the row/column marginals.
    pe = _safe_ratio((tp + fn) * (tp + fp) + (tn + fp) * (tn + fn), total ** 2)
    kappa_score = _safe_ratio(acc - pe, 1 - pe)

    return [acc, fdr, recall, precision, f1, specificity, sensitivity, auc,
            iou, dice_coeff, g_mean, kappa_score]

def train(model, loader, optimizer, loss_fn, device):
    """Run a single training epoch over ``loader``.

    For each ``(x, y)`` batch: cast both to float32 on ``device``, take one
    optimizer step on ``loss_fn(model(x), y)``, and accumulate the batch loss
    together with the 12 values returned by ``calculate_metrics``.

    Returns:
        ``(mean_loss, mean_metrics)``: the summed loss and the element-wise
        summed metric list, each divided by ``len(loader)``.
    """
    model.train()
    loss_sum = 0.0
    metric_sums = [0.0] * 12

    for inputs, targets in loader:
        inputs = inputs.to(device, dtype=torch.float32)
        targets = targets.to(device, dtype=torch.float32)

        optimizer.zero_grad()
        outputs = model(inputs)
        batch_loss = loss_fn(outputs, targets)
        batch_loss.backward()
        optimizer.step()
        loss_sum += batch_loss.item()

        batch_metrics = calculate_metrics(targets, outputs)
        metric_sums = [total + value for total, value in zip(metric_sums, batch_metrics)]

    n_batches = len(loader)
    mean_loss = loss_sum / n_batches
    mean_metrics = [total / n_batches for total in metric_sums]
    return mean_loss, mean_metrics

def evaluate(model, loader, loss_fn, device, mode):
    """Score ``model`` on ``loader`` without updating its weights.

    The model is switched to eval mode and gradients are disabled. Each batch
    is cast to float32 on ``device``, and both the loss and the 12 values from
    ``calculate_metrics`` are summed.

    Args:
        mode: not used by this function (callers pass e.g. 'valid').

    Returns:
        ``(mean_loss, mean_metrics)``, both averaged over ``len(loader)``
        batches.
    """
    model.eval()
    loss_sum = 0.0
    metric_sums = [0.0] * 12

    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device, dtype=torch.float32)
            batch_y = batch_y.to(device, dtype=torch.float32)

            output = model(batch_x)
            loss_sum += loss_fn(output, batch_y).item()
            metric_sums = [s + v for s, v in zip(metric_sums, calculate_metrics(batch_y, output))]

        n_batches = len(loader)
        mean_loss = loss_sum / n_batches
        mean_metrics = [s / n_batches for s in metric_sums]
    return mean_loss, mean_metrics

def test(model, loader, loss_fn, device, mode):
    """Score ``model`` on the test loader.

    The computation is exactly the one performed by ``evaluate``: eval mode,
    no gradients, and mean loss plus mean per-batch metrics over
    ``len(loader)`` batches. This function therefore delegates to it instead
    of keeping a line-for-line copy that could drift out of sync.

    Args:
        model, loader, loss_fn, device: as for ``evaluate``.
        mode: forwarded to ``evaluate``, which does not use it.

    Returns:
        ``(mean_loss, mean_metrics)`` as returned by ``evaluate``.
    """
    return evaluate(model, loader, loss_fn, device, mode)



if __name__ == "__main__":
    # Seeding
    seeding(42)

    create_dir("/content/drive/MyDrive/Farhana/RBVS_OCTA/chk_nbu/")  # For checkpoint
    create_dir("/content/drive/MyDrive/Farhana/RBVS_OCTA/Result_nbu/")  # For CSV files

    # Datasets: the train directory is split into k folds; the test set is separate.
    data_dir = "/content/drive/MyDrive/Farhana/RBVS_OCTA/Aug_OCTA/SVC/train"

    test_x = sorted(glob("/content/drive/MyDrive/Farhana/RBVS_OCTA/Aug_OCTA/SVC/test/image/*"))
    test_y = sorted(glob("/content/drive/MyDrive/Farhana/RBVS_OCTA/Aug_OCTA/SVC/test/mask/*"))

    # Hyperparameters
    num_epochs = 50
    batch_size = 4
    k = 5
    lr = 0.0005

    checkpoint_path = "/content/drive/MyDrive/Farhana/RBVS_OCTA/chk_nbu/chk02.pth"
    checkpoint_path2 = "/content/drive/MyDrive/Farhana/RBVS_OCTA/chk_nbu/chk021.pth"
    checkpoint_path3 = "/content/drive/MyDrive/Farhana/RBVS_OCTA/chk_nbu/chk022.pth"

    # Dataset and loaders: fold 1 is held out for validation in the first epoch.
    train_loader, valid_loader = get_train_valid_loader(data_dir=data_dir,
                                                        batch_size=batch_size,
                                                        random_seed=False,
                                                        kfolde=k,
                                                        shuffle=True,
                                                        show_sample=False,
                                                        num_workers=1,
                                                        pin_memory=False,
                                                        shuffelthevaluditation=1)
    print("train data loader", len(train_loader))
    print("valid data loader", len(valid_loader))

    test_dataset = DriveDataset(test_x, test_y)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=1)
    print("Test data loader", len(test_loader))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # map_location lets a checkpoint saved on a GPU be resumed on a CPU-only machine.
    loaded_checkpoint = torch.load(checkpoint_path, map_location=device)

    model = RIMNet(dropRate=0.00)
    model = model.to(device)
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.0001)
    # NOTE(review): scheduler.step() is never called, so this schedule does not
    # change the learning rate. With total_iters=5, stepping it once per epoch
    # would bring the LR to 0 after 5 epochs -- confirm the intent before enabling.
    scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=5, power=0.9, last_epoch=-1, verbose=True)

    model.load_state_dict(loaded_checkpoint)  # loading the trained model

    loss_fn = DiceLoss().to(device)

    # Open (or create) the CSV logs. Each row is: epoch, loss, then the 12
    # metrics in the order returned by calculate_metrics.
    result_dir = "/content/drive/MyDrive/Farhana/RBVS_OCTA/Result_nbu/"
    tr_csv = os.path.join(result_dir, 'tr02.csv')
    val_csv = os.path.join(result_dir, 'val02.csv')
    te_csv = os.path.join(result_dir, 'te02.csv')
    metric_columns = ['Acc', 'FDR', 'Recall', 'Precision', 'F1', 'Specificity', 'Sensitivity',
                      'AUC', 'IoU', 'Dice_coeff', 'G_Mean', 'Kappa']

    # newline='' is what the csv module expects for files it writes to.
    with open(tr_csv, 'a+', newline='') as td, \
            open(val_csv, 'a+', newline='') as vd, \
            open(te_csv, 'a+', newline='') as ted:
        training_details = csv.writer(td)
        val_details = csv.writer(vd)
        test_details = csv.writer(ted)

        # The files are opened in append mode, so only write a header into a
        # file that is still empty; otherwise every run would append another one.
        for csv_path, csv_writer, loss_column in ((tr_csv, training_details, 'TrLoss'),
                                                  (val_csv, val_details, 'ValLoss'),
                                                  (te_csv, test_details, 'TestLoss')):
            if os.path.getsize(csv_path) == 0:
                csv_writer.writerow(['Epoch', loss_column] + metric_columns)

        # Training the model
        best_valid_loss = float("inf")

        best_tr_acc = 0.000
        best_val_acc = 0.000
        # Test metrics, indexed as returned by calculate_metrics:
        # acc, fdr, recall, precision,  f1, Specificity, Sensitivity, auc, iou, dice_coeff, g_mean, kappa_score
        # 0     1     2       3          4      5             6         7   8     9           10      11
        best_test_acc = 0.000
        best_fdr = float("inf")  # lower FDR is better, so this tracks a minimum
        best_recall = 0.000
        best_precision = 0.000
        best_f1 = 0.000
        best_sp = 0.000
        best_se = 0.000
        best_auc = 0.000
        best_iou = 0.000
        best_dice = 0.00
        best_g_mean = 0.00
        best_kappa_score = 0.00

        # Early stopping after `patience` consecutive epochs without a better validation loss.
        patience = 10
        trigger_times = 0

        for epoch in range(num_epochs):
            start_time = time.time()
            train_loss, train_score = train(model, train_loader, optimizer, loss_fn, device)
            print("Training Done")
            valid_loss, val_score = evaluate(model, valid_loader, loss_fn, device, 'valid')
            print("Validation Done")
            test_loss, test_score = test(model, test_loader, loss_fn, device, 'valid')
            print("Test Done")

            # Save the model whenever the validation loss improves.
            if valid_loss < best_valid_loss:
                data_str = f"valid loss improved from {best_valid_loss:2.4f} to {valid_loss:2.4f}. Saving checkpoint: {checkpoint_path}"
                print(data_str)
                best_valid_loss = valid_loss
                torch.save(model.state_dict(), checkpoint_path)
                trigger_times = 0  # the counter measures *consecutive* non-improving epochs
            else:
                trigger_times += 1
                print('Trigger Times:', trigger_times)

            # max()/min() keep the current best when a score is NaN, because
            # every comparison with NaN is False.
            best_tr_acc = max(best_tr_acc, train_score[0])
            best_val_acc = max(best_val_acc, val_score[0])

            # NOTE(review): each test metric is maximised (FDR minimised)
            # independently across epochs, so the reported numbers can come
            # from different epochs and are selected on the test set itself.
            best_test_acc = max(best_test_acc, test_score[0])
            best_fdr = min(best_fdr, test_score[1])
            best_recall = max(best_recall, test_score[2])
            best_precision = max(best_precision, test_score[3])
            best_f1 = max(best_f1, test_score[4])
            best_sp = max(best_sp, test_score[5])
            best_se = max(best_se, test_score[6])
            best_auc = max(best_auc, test_score[7])
            best_iou = max(best_iou, test_score[8])
            best_dice = max(best_dice, test_score[9])
            best_g_mean = max(best_g_mean, test_score[10])
            best_kappa_score = max(best_kappa_score, test_score[11])

            end_time = time.time()
            epoch_mins, epoch_secs = epoch_time(start_time, end_time)

            data_str = f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s\n'
            data_str += f'\tTrain Loss: {train_loss:.3f}\n'
            data_str += f'\tVal. Loss: {valid_loss:.3f}'
            print(data_str)

            print(f"\tTrain: Best Acc: {best_tr_acc:1.4f} ")
            print(f"\tVal: Best Acc: {best_val_acc:1.4f} ")
            print(f"\tTest: Best Acc: {best_test_acc:1.4f} \t Best IoU: {best_iou:1.4f} \t Best dice: {best_dice:1.4f}")

            # Scores are already in header-column order, so write them as-is.
            training_details.writerow([str(epoch + 1), float(train_loss)] + [float(v) for v in train_score])
            val_details.writerow([str(epoch + 1), float(valid_loss)] + [float(v) for v in val_score])
            test_details.writerow([str(epoch + 1), float(test_loss)] + [float(v) for v in test_score])

            if trigger_times >= patience:
                print('Early stopping!\nStart to test process.')
                break

            # Rotate the validation fold so epoch e uses fold (e % k) + 1; the
            # loaders built here are for epoch + 1.
            # NOTE(review): because the fold rotates every epoch, each sample is
            # also trained on in other epochs, so later validation losses are
            # not measured on truly held-out data.
            train_loader, valid_loader = get_train_valid_loader(kfolde=k, shuffelthevaluditation=((epoch + 1) % k) + 1,
                                                                random_seed=False, data_dir=data_dir,
                                                                batch_size=batch_size, shuffle=True)

            torch.cuda.empty_cache()

    print("Completed Successfully")
    print(f"Train Accuracy:{best_tr_acc:1.4f}")
    print(f"Validation Accuracy:{best_val_acc:1.4f}\n")
    print("Result on Test dataset")
    print(f"Test Accuracy:{best_test_acc:1.4f}")
    print(f"FDR:{best_fdr:1.4f}")
    print(f"F1-Score:{best_f1:1.4f}")
    print(f"Recall:{best_recall:1.4f}")
    print(f"Precision:{best_precision:1.4f}")
    print(f"Sensitivity:{best_se:1.4f}")
    print(f"Specificity:{best_sp:1.4f}")
    print(f"AUC:{best_auc:1.4f}")
    print(f"IoU:{best_iou:1.4f}")
    print(f"DC:{best_dice:1.4f}")
    print(f"G Mean:{best_g_mean:1.4f}")
    print(f"Kappa:{best_kappa_score:1.4f}")
