In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
import os
import json
import random

from unetModules.unet_model import UNet
from unetModules.Utils import Utils
from sklearn.model_selection import KFold
from tqdm import tqdm
from PIL import Image
from torchmetrics import JaccardIndex
from torchmetrics import Dice

In [None]:
# --- Hyperparameters -------------------------------------------------------
cuda = True              # use the GPU if one is available
batch_size = 2
epochs_per_fold = 100
learn_rate = 5e-4
n_classes = 2
n_channels = 3
weight_decay = 2e-4
k_folds = 5
eval_every = 10          # run the per-fold Jaccard/Dice evaluation every N epochs
seed = 42                # seeds python `random`, numpy and torch below
description = "Diese Unet Evaluation wurde mit den River Blindess Datensatz trainiert. Erweiterere(Rotation + Sharpend + Contrast) Datenaugmentierung wurde angewendet. Validation set wurde separiert."

# --- Paths -----------------------------------------------------------------
path_save_model = "./modelEvaluation/"
# Dataset folder below ../../content/, e.g. "RiverBlindness" or "SchistosomaMansoni".
dataset_name = "RiverBlindness"
dataset_root = f"../../content/{dataset_name}/"
path_images = dataset_root + "img/"
path_labels = dataset_root + "labels/"
path_images_val = dataset_root + "val_img/"
path_labels_val = dataset_root + "val_labels/"

# --- Reproducibility -------------------------------------------------------
# The filename shuffle (which determines the KFold split) uses `random`, and
# model initialisation uses torch. Seed every RNG so a re-run gives the same
# folds and initial weights. torch.manual_seed also seeds CUDA devices; it does
# NOT make cuDNN kernels deterministic.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# --- Persist the run configuration -----------------------------------------
# Create the log directory up front: open(..., "w") does not create parents.
os.makedirs(path_save_model + "logs", exist_ok=True)

with open(path_save_model + "logs/settings.json", "w", encoding="utf-8") as f:
    settings = {
        "cuda": cuda,
        "batch_size": batch_size,
        "epochs_per_fold": epochs_per_fold,
        "learn_rate": learn_rate,
        "n_classes": n_classes,
        "n_channels": n_channels,
        "weight_decay": weight_decay,
        "k_folds": k_folds,
        "eval_every": eval_every,
        "seed": seed,
        "dataset": dataset_name,
        "description": description
    }

    json.dump(settings, f, ensure_ascii=False, indent=4)


In [None]:
TARGET_SIZE = (512, 512)  # (width, height) every image and mask is resized to


def _resize(image, interpolation: int) -> np.ndarray:
    """Convert a PIL image (or array) to a numpy array resized to TARGET_SIZE.

    `interpolation` MUST be passed to cv2.resize by keyword: its third
    positional parameter is `dst`, so `cv2.resize(a, size, cv2.INTER_NEAREST)`
    silently falls back to the default INTER_LINEAR.
    """
    return cv2.resize(np.array(image), TARGET_SIZE, interpolation=interpolation)


def _prepare_image(image) -> np.ndarray:
    """Resize an input image bilinearly and cut off a potential alpha channel."""
    # Bilinear is what the previous (positional-flag) code effectively used for images.
    return _resize(image, cv2.INTER_LINEAR)[:, :, 0:3]


def _prepare_mask(label) -> np.ndarray:
    """Resize a label mask with nearest-neighbour so no new class values are invented."""
    return _resize(label, cv2.INTER_NEAREST)


def _paired_filenames(image_dir: str, label_dir: str) -> list:
    """Sorted filenames of `image_dir`, checked to match `label_dir` exactly.

    Labels are opened under the image's filename, so both folders must hold the
    same names. Sorting makes the listing independent of filesystem order.
    """
    image_names = sorted(os.listdir(image_dir))
    label_names = sorted(os.listdir(label_dir))
    assert image_names == label_names, (
        f"{image_dir} and {label_dir} must contain identical filenames"
    )
    return image_names


# ---- Training data (augmented) ----------------------------------------------
img_filenames = _paired_filenames(path_images, path_labels)
# Shuffle the (sorted) file order; with a seeded `random` this gives reproducible folds.
random.shuffle(img_filenames)
label_filenames = np.array(sorted(img_filenames))
img_filenames = np.array(img_filenames)

input_train = []
label_train = []

print("[INFO]Loading trainings dataset.")
for file in tqdm(img_filenames):
    with Image.open(path_images + file) as img, Image.open(path_labels + file) as label:
        augmented_imgs, augmented_labels = Utils.imageAugmentation(img, label)
        # All augmented copies of one source image are appended contiguously;
        # the KFold cell relies on this (shuffle=False) to limit leakage.
        input_train.extend(_prepare_image(aug_img) for aug_img in augmented_imgs)
        label_train.extend(_prepare_mask(aug_label) for aug_label in augmented_labels)

assert len(input_train) == len(label_train), "augmentation returned unequal image/label counts"
input_train = np.array(input_train)
label_train = np.array(label_train)

# ---- Validation data (NOT augmented) ----------------------------------------
val_img_filenames = np.array(_paired_filenames(path_images_val, path_labels_val))
val_label_filenames = val_img_filenames.copy()

input_val = []
label_val = []

print("[INFO]Loading validation dataset")
for file in tqdm(val_img_filenames):
    with Image.open(path_images_val + file) as val_img:
        input_val.append(_prepare_image(val_img))
    with Image.open(path_labels_val + file) as val_label:
        label_val.append(_prepare_mask(val_label))

# (N, H, W, C) -> (N, C, H, W), the channel-first layout fed to the UNet.
input_val = torch.from_numpy(np.stack(input_val)).permute(0, 3, 1, 2)
label_val = torch.from_numpy(np.stack(label_val))

print("[INFO]Starting to define the class weights...")
class_weights = Utils.get_class_weights(label_train, n_classes)
print("[INFO]Fetched all class weights successfully!")

# Checking for cuda
if torch.cuda.is_available() and cuda:
    print("[INFO]CUDA is available!")
    device = torch.device("cuda:0")
else:
    print("[INFO]CUDA isn't available!")
    device = torch.device("cpu")

In [12]:
print("[INFO]Starting Training...")

# shuffle=False is deliberate: the loading cell appends all augmented copies of
# a source image contiguously, so contiguous folds keep those near-duplicates
# (mostly) on the same side of the split. Shuffling here would scatter them
# across train/test and inflate the fold metrics. Copies of a single source
# image can still straddle a fold boundary.
kf = KFold(n_splits=k_folds, shuffle=False)

# Important values and metrics for backup, keyed "fold1", "fold2", ...
train_losses_json = {}
val_losses_json = {}
jaccard_index_per_fold_json = {}
jaccard_index_means_json = {}
dice_score_per_fold_json = {}
dice_score_means_json = {}

# Mixed precision only on CUDA: on CPU, torch.cuda.amp.GradScaler warns and
# disables itself, and CPU autocast would run the forward pass in bfloat16.
use_amp = device.type == "cuda"

METRIC_SEPARATOR = "---------------------------------------------------------------------------------------------------------------------------------------"


def _to_nchw(images: np.ndarray) -> torch.Tensor:
    """Turn an (N, H, W, C) image array into an (N, C, H, W) tensor view (no extra copy)."""
    return torch.from_numpy(images).permute(0, 3, 1, 2)


def _num_batches(n_samples: int) -> int:
    """Number of `batch_size` mini-batches needed to cover n_samples (last may be partial)."""
    return (n_samples + batch_size - 1) // batch_size


def _evaluate_fold(model, X_test: torch.Tensor, y_test: torch.Tensor):
    """Per-sample Jaccard index and Dice score of `model` on a held-out fold.

    Returns (jaccard_indices, dice_scores): one entry per test sample, each a
    list with one value per class (torchmetrics with average="none").
    """
    jaccard = JaccardIndex(num_classes=n_classes, average="none")
    dice = Dice(num_classes=n_classes, average="none")
    jaccard_indices, dice_scores = [], []

    model.eval()
    with torch.no_grad():
        for sample_idx in tqdm(range(len(X_test))):
            logits = model(X_test[sample_idx].unsqueeze(0).to(device).float()).squeeze(0)
            # Predicted class index per pixel, moved to the CPU where the metric objects live.
            pred = logits.argmax(dim=0).to(torch.uint8).cpu()
            target = y_test[sample_idx]
            jaccard_indices.append(jaccard(pred, target).tolist())
            dice_scores.append(dice(pred, target).tolist())
    return jaccard_indices, dice_scores


def _format_means(metric_name: str, means_per_class) -> str:
    """One summary line: every per-class mean, then the mean over all classes."""
    per_class = " \t ".join(f"meanClass[{c}] = {m}" for c, m in enumerate(means_per_class))
    return f"{metric_name}: \t {per_class} \t mean = {np.mean(means_per_class)}"


def _dump_json(obj, filename: str) -> None:
    """Write `obj` as indented UTF-8 JSON to <path_save_model>/logs/<filename>."""
    with open(path_save_model + "logs/" + filename, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=4)


for k, (kfsplit_train, kfsplit_test) in enumerate(kf.split(input_train)):
    fold_key = f"fold{k+1}"
    print("===========================")
    print(f"         Fold: {k + 1}")
    print("===========================")

    # Fresh model / optimizer / scaler per fold so folds are independent.
    unet = UNet(n_channels=n_channels, n_classes=n_classes).to(device)
    grad_scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(class_weights).to(device))
    optimizer = torch.optim.Adam(unet.parameters(),
                                 lr=learn_rate,
                                 weight_decay=weight_decay)
    print("[INFO]Model,optimizer and loss function instantiated!")

    kf_X_train = _to_nchw(input_train[kfsplit_train])
    kf_X_test = _to_nchw(input_train[kfsplit_test])
    kf_y_train = torch.from_numpy(label_train[kfsplit_train])
    kf_y_test = torch.from_numpy(label_train[kfsplit_test])

    jaccard_index_per_fold_json[fold_key] = []
    dice_score_per_fold_json[fold_key] = []
    jaccard_index_means_json[fold_key] = []
    dice_score_means_json[fold_key] = []

    train_losses = []
    val_losses = []

    batch_count_train = _num_batches(len(kf_X_train))
    batch_count_val = _num_batches(len(input_val))

    # Defined up front so the checkpoint name is valid even when epochs_per_fold == 0.
    train_loss = 0.0

    for e in range(1, epochs_per_fold + 1):
        print("-" * 15, "Epoch %d" % e, "-" * 15)

        # ---- Training ----
        unet.train()
        train_loss = 0
        for batch_idx in tqdm(range(batch_count_train)):
            batch = slice(batch_size * batch_idx, batch_size * (batch_idx + 1))
            X_train = kf_X_train[batch].to(device)
            y_train = kf_y_train[batch].to(device)

            optimizer.zero_grad()
            with torch.autocast(device.type, enabled=use_amp):
                out = unet(X_train.float())
                loss = criterion(out, y_train.long())

            grad_scaler.scale(loss).backward()
            grad_scaler.step(optimizer)
            grad_scaler.update()

            # NOTE: losses are SUMMED over batches (not averaged), so the train and
            # validation totals are not directly comparable (different batch counts).
            train_loss += loss.item()

        print()
        train_losses.append(train_loss)

        print('Epoch {}/{}...'.format(e, epochs_per_fold),
              'Loss {:6f}'.format(train_loss))

        # ---- Validation on the separate, non-augmented set ----
        with torch.no_grad():
            print()
            print("Validation:")

            unet.eval()
            val_loss = 0
            for batch_idx in tqdm(range(batch_count_val)):
                batch = slice(batch_size * batch_idx, batch_size * (batch_idx + 1))
                X_val = input_val[batch].to(device)
                y_val = label_val[batch].to(device)

                out = unet(X_val.float())
                loss = criterion(out, y_val.long())

                val_loss += loss.item()

            print('Loss {:6f}'.format(val_loss))
            val_losses.append(val_loss)

        print('Epoch {}/{}...'.format(e, epochs_per_fold))

        # ---- Periodic evaluation on this fold's held-out split ----
        if e % eval_every == 0:
            print(f"[INFO]Starting Evaluation for Fold {k+1} Epoche {e}.")

            jaccard_indices, dice_scores = _evaluate_fold(unet, kf_X_test, kf_y_test)

            # Mean over test samples, separately for each of the n_classes classes.
            jaccard_index_means_per_class = np.mean(jaccard_indices, axis=0).tolist()
            dice_score_means_per_class = np.mean(dice_scores, axis=0).tolist()

            jaccard_index_per_fold_json[fold_key].append(jaccard_indices)
            dice_score_per_fold_json[fold_key].append(dice_scores)

            jaccard_index_means_json[fold_key].append(jaccard_index_means_per_class)
            dice_score_means_json[fold_key].append(dice_score_means_per_class)

            print(METRIC_SEPARATOR)
            print(_format_means("Jaccard Index", jaccard_index_means_per_class))
            print(_format_means("Dice Score", dice_score_means_per_class))
            print(METRIC_SEPARATOR)

    train_losses_json[fold_key] = train_losses
    val_losses_json[fold_key] = val_losses

    # Rewrite all logs after every fold so an interrupted run keeps finished folds.
    print("[INFO]Saving metric data and model.")
    for log_obj, log_filename in (
        (train_losses_json, "trainLosses.json"),
        (val_losses_json, "valLosses.json"),
        (jaccard_index_per_fold_json, "jaccardIndicesPerFold.json"),
        (dice_score_per_fold_json, "diceScoresPerFold.json"),
        (jaccard_index_means_json, "jaccardIndexMeans.json"),
        (dice_score_means_json, "diceScoreMeans.json"),
    ):
        _dump_json(log_obj, log_filename)

    checkpoint = {
        'fold': k+1,
        'state_dict': unet.state_dict()
    }

    # The filename embeds the fold number and the summed training loss of the last epoch.
    torch.save(checkpoint, path_save_model + 'ckpt-unet-{}-{}.pth'.format(k+1, train_loss))

    print("[INFO]Saving done.")

print("[INFO]Evaluation process complete!")

100%|██████████| 779/779 [02:34<00:00,  5.04it/s]



Epoch 81/100... Loss 70.626517

Validation:


100%|██████████| 28/28 [00:05<00:00,  5.60it/s]


Loss 12.138845
Epoch 81/100...
--------------- Epoch 82 ---------------


100%|██████████| 779/779 [02:33<00:00,  5.07it/s]



Epoch 82/100... Loss 68.403184

Validation:


100%|██████████| 28/28 [00:04<00:00,  6.29it/s]


Loss 8.719670
Epoch 82/100...
--------------- Epoch 83 ---------------


100%|██████████| 779/779 [02:31<00:00,  5.15it/s]



Epoch 83/100... Loss 63.268886

Validation:


100%|██████████| 28/28 [00:04<00:00,  6.28it/s]


Loss 10.564289
Epoch 83/100...
--------------- Epoch 84 ---------------


100%|██████████| 779/779 [02:32<00:00,  5.10it/s]



Epoch 84/100... Loss 64.970719

Validation:


100%|██████████| 28/28 [00:04<00:00,  6.29it/s]


Loss 20.484768
Epoch 84/100...
--------------- Epoch 85 ---------------


100%|██████████| 779/779 [02:32<00:00,  5.11it/s]



Epoch 85/100... Loss 66.784975

Validation:


100%|██████████| 28/28 [00:04<00:00,  6.29it/s]


Loss 9.818439
Epoch 85/100...
--------------- Epoch 86 ---------------


100%|██████████| 779/779 [02:31<00:00,  5.14it/s]



Epoch 86/100... Loss 66.448477

Validation:


100%|██████████| 28/28 [00:04<00:00,  6.29it/s]


Loss 10.372926
Epoch 86/100...
--------------- Epoch 87 ---------------


100%|██████████| 779/779 [02:31<00:00,  5.13it/s]



Epoch 87/100... Loss 71.851050

Validation:


100%|██████████| 28/28 [00:04<00:00,  6.15it/s]


Loss 8.599211
Epoch 87/100...
--------------- Epoch 88 ---------------


100%|██████████| 779/779 [02:36<00:00,  4.98it/s]



Epoch 88/100... Loss 65.500552

Validation:


100%|██████████| 28/28 [00:04<00:00,  6.29it/s]


Loss 13.856434
Epoch 88/100...
--------------- Epoch 89 ---------------


100%|██████████| 779/779 [02:31<00:00,  5.15it/s]



Epoch 89/100... Loss 60.118224

Validation:


100%|██████████| 28/28 [00:04<00:00,  6.29it/s]


Loss 10.353237
Epoch 89/100...
--------------- Epoch 90 ---------------


100%|██████████| 779/779 [02:33<00:00,  5.08it/s]



Epoch 90/100... Loss 60.582256

Validation:


100%|██████████| 28/28 [00:04<00:00,  6.29it/s]


Loss 10.051537
Epoch 90/100...
[INFO]Starting Evaluation for Fold 5 Epoche 90.


100%|██████████| 389/389 [00:58<00:00,  6.60it/s]


---------------------------------------------------------------------------------------------------------------------------------------
Jaccard Index: 	 meanClass[0] = 0.7760506389916403 	 meanClass[1] = 0.7701624864209587 	 mean = 0.7731065627062995
Dice Score: 	 meanClass[0] = 0.8605805895169167 	 meanClass[1] = 0.8540999518951597 	 mean = 0.8573402707060382
---------------------------------------------------------------------------------------------------------------------------------------
--------------- Epoch 91 ---------------


100%|██████████| 779/779 [02:32<00:00,  5.11it/s]



Epoch 91/100... Loss 68.745540

Validation:


100%|██████████| 28/28 [00:04<00:00,  6.25it/s]


Loss 16.132773
Epoch 91/100...
--------------- Epoch 92 ---------------


100%|██████████| 779/779 [02:31<00:00,  5.14it/s]



Epoch 92/100... Loss 66.634493

Validation:


100%|██████████| 28/28 [00:04<00:00,  6.28it/s]


Loss 11.578780
Epoch 92/100...
--------------- Epoch 93 ---------------


100%|██████████| 779/779 [02:34<00:00,  5.05it/s]



Epoch 93/100... Loss 66.892615

Validation:


100%|██████████| 28/28 [00:04<00:00,  5.87it/s]


Loss 12.751736
Epoch 93/100...
--------------- Epoch 94 ---------------


100%|██████████| 779/779 [02:31<00:00,  5.14it/s]



Epoch 94/100... Loss 60.301668

Validation:


100%|██████████| 28/28 [00:04<00:00,  6.28it/s]


Loss 8.641250
Epoch 94/100...
--------------- Epoch 95 ---------------


100%|██████████| 779/779 [02:33<00:00,  5.09it/s]



Epoch 95/100... Loss 60.742233

Validation:


100%|██████████| 28/28 [00:04<00:00,  6.29it/s]


Loss 10.165675
Epoch 95/100...
--------------- Epoch 96 ---------------


100%|██████████| 779/779 [02:33<00:00,  5.06it/s]



Epoch 96/100... Loss 66.360128

Validation:


100%|██████████| 28/28 [00:04<00:00,  6.28it/s]


Loss 9.388008
Epoch 96/100...
--------------- Epoch 97 ---------------


100%|██████████| 779/779 [02:35<00:00,  5.01it/s]



Epoch 97/100... Loss 59.366041

Validation:


100%|██████████| 28/28 [00:04<00:00,  5.98it/s]


Loss 11.175318
Epoch 97/100...
--------------- Epoch 98 ---------------


100%|██████████| 779/779 [02:38<00:00,  4.92it/s]



Epoch 98/100... Loss 61.747508

Validation:


100%|██████████| 28/28 [00:04<00:00,  5.70it/s]


Loss 8.143127
Epoch 98/100...
--------------- Epoch 99 ---------------


100%|██████████| 779/779 [02:41<00:00,  4.82it/s]



Epoch 99/100... Loss 60.205612

Validation:


100%|██████████| 28/28 [00:04<00:00,  5.79it/s]


Loss 13.554310
Epoch 99/100...
--------------- Epoch 100 ---------------


100%|██████████| 779/779 [02:41<00:00,  4.82it/s]



Epoch 100/100... Loss 61.435039

Validation:


100%|██████████| 28/28 [00:04<00:00,  5.76it/s]


Loss 13.380938
Epoch 100/100...
[INFO]Starting Evaluation for Fold 5 Epoche 100.


100%|██████████| 389/389 [01:00<00:00,  6.38it/s]


---------------------------------------------------------------------------------------------------------------------------------------
Jaccard Index: 	 meanClass[0] = 0.6892083600201766 	 meanClass[1] = 0.6145416365068221 	 mean = 0.6518749982634994
Dice Score: 	 meanClass[0] = 0.799184959697846 	 meanClass[1] = 0.7334921225578618 	 mean = 0.7663385411278538
---------------------------------------------------------------------------------------------------------------------------------------
[INFO]Saving metric data and model.
[INFO]Saving done.
[INFO]Evaluation process complete!
