In [None]:
import dataloader
from UModel import UNet
import Config
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from torchvision import transforms
from imutils import paths
from tqdm import tqdm
import torch
import time
import os
from torchmetrics.functional import dice
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import KFold
import torch.optim.lr_scheduler as lr_scheduler
import cv2
from focal_loss import FocalLoss
import matplotlib.pyplot as plt
from class_weights import compute_class_weight
import Early_stopping as E
import torch.nn as nn

## load the image and mask filepaths in a sorted manner ##

imagePaths = np.array(sorted(paths.list_images(Config.Image_dataset_dir)))
maskPaths = np.array(sorted(paths.list_images(Config.Mask_dataset_dir)))

# Images and masks are paired purely by their sorted position, so fail fast if the two
# directories do not line up one-to-one (same count, same file stem at every index).
# The evaluation cell also rebuilds each mask path from the image's file name.
assert len(imagePaths) == len(maskPaths), (
    f"{len(imagePaths)} images but {len(maskPaths)} masks")
mismatched_pairs = [
    (img_path, mask_path) for img_path, mask_path in zip(imagePaths, maskPaths)
    if os.path.splitext(os.path.basename(img_path))[0] != os.path.splitext(os.path.basename(mask_path))[0]
]
assert not mismatched_pairs, f"Image/mask names do not match, e.g. {mismatched_pairs[0]}"

# Hold out 10% of the data as a fixed test set; the remaining 90% is used for cross-validation.
X_rest, X_test, y_rest, y_test = train_test_split(imagePaths, maskPaths, train_size=0.9, random_state=42, shuffle=True)

print("Saving testing image paths...")
with open(Config.TEST_PATHS, "w") as f:
    f.write("\n".join(X_test))

# Setting up the Kfold cross validation with 5 folds
# 5-fold cross-validation over the non-test data, seeded like the test split for reproducibility.
kf = KFold(n_splits=5, shuffle=True, random_state=42)

# Random augmentation is for training only. Validation must be deterministic so that the
# validation loss/Dice (which drive the LR scheduler) are comparable from epoch to epoch.
train_transform = A.Compose([A.Resize(Config.Input_Height, Config.Input_Width),
                             A.Normalize(mean=(0.0), std=(1.0)),
                             A.HorizontalFlip(p=0.5),
                             A.RandomRotate90(p=0.5),
                             ToTensorV2()])

val_transform = A.Compose([A.Resize(Config.Input_Height, Config.Input_Width),
                           A.Normalize(mean=(0.0), std=(1.0)),
                           ToTensorV2()])

# Turns the network's per-class logits (channel dim 1) into probabilities for the Dice metric.
softmax = nn.Softmax(dim=1)


def _soft_dice_loss(logits, target, num_classes, eps=1e-6):
    """Differentiable multi-class soft Dice loss: 1 - mean per-class Dice.

    torchmetrics' ``dice`` converts predictions to hard class labels and is not
    differentiable, so it cannot contribute gradients. This version works directly on
    softmax probabilities.

    Args:
        logits: raw network output, (N, C, *spatial).
        target: integer class map, (N, *spatial). This is assumed from how ``y`` is used
            with ``dice`` in this cell; TODO confirm against dataloader.MyDataset.
        num_classes: number of classes C.
        eps: smoothing term that avoids 0/0.

    Returns:
        Scalar tensor in [0, 1].
    """
    probs = torch.softmax(logits, dim=1)
    one_hot = nn.functional.one_hot(target.long(), num_classes).movedim(-1, 1).to(probs.dtype)
    reduce_dims = (0,) + tuple(range(2, probs.ndim))  # everything except the class axis
    intersection = (probs * one_hot).sum(reduce_dims)
    cardinality = (probs + one_hot).sum(reduce_dims)
    return 1 - ((2 * intersection + eps) / (cardinality + eps)).mean()


def _plot_history(CE_loss, dsc_loss, fold):
    """Plot one fold's per-epoch loss and Dice curves in a single two-panel figure."""
    plt.style.use("ggplot")
    fig, (ax_loss, ax_dice) = plt.subplots(2, 1, figsize=(8, 8))

    ax_loss.plot(CE_loss["train_loss"], label="train_loss")
    ax_loss.plot(CE_loss["val_loss"], label="validation_loss")
    ax_loss.set(title="Training Loss vs. Validation Loss", xlabel="Epoch #", ylabel="Loss")
    ax_loss.legend(loc="upper right")

    ax_dice.plot(dsc_loss["Dice_train"], label="train_dice_score")
    ax_dice.plot(dsc_loss["Dice_val"], label="validation_dice_score")
    ax_dice.set(title="Training Dice Score vs. Validation Dice Score", xlabel="Epoch #", ylabel="Dice Score")
    ax_dice.legend(loc="upper left")

    fig.suptitle(f"Fold {fold}")
    fig.tight_layout()
    plt.show()


for fold, (train_index, val_index) in enumerate(kf.split(X_rest, y_rest), start=1):
    print(f"===== Fold {fold}/{kf.get_n_splits()} =====")

    # Split the non-test data into this fold's training and validation sets
    X_train, X_val = X_rest[train_index], X_rest[val_index]
    y_train, y_val = y_rest[train_index], y_rest[val_index]

    # Create the train and validation datasets
    trainDS = dataloader.MyDataset(imagePaths=X_train, maskPaths=y_train, transform=train_transform)
    valDS = dataloader.MyDataset(imagePaths=X_val, maskPaths=y_val, transform=val_transform)
    print(f"There are {len(trainDS)} samples in the training set")
    print(f"There are {len(valDS)} samples in the validation set")
    print('************************************************')

    # Create the train and validation dataloaders (no need to shuffle for evaluation)
    trainLoader = DataLoader(trainDS, shuffle=True, batch_size=Config.Batch_size,
                             pin_memory=Config.PIN_MEMORY)
    valLoader = DataLoader(valDS, shuffle=False, batch_size=Config.Batch_size,
                           pin_memory=Config.PIN_MEMORY)

    class_weight = compute_class_weight(trainLoader, Config.No_classes).to(Config.DEVICE)
    print(class_weight)

    # Fresh UNet for every fold
    unet = UNet(n_channels=Config.No_channels, n_classes=Config.No_classes, bilinear=False).to(Config.DEVICE)

    # Class-weighted focal loss with Adam. The LR is reduced when the validation Dice
    # (a quantity to maximise, hence mode='max') stops improving.
    lossFunc = FocalLoss(weight=class_weight)
    opt = Adam(unet.parameters(), lr=Config.Init_LR)
    scheduler = lr_scheduler.ReduceLROnPlateau(opt, 'max', patience=2)

    # Per-epoch training history (plain Python floats)
    CE_loss = {"train_loss": [], "val_loss": []}
    dsc_loss = {"Dice_train": [], "Dice_val": []}

    print("Training the network...")
    startTime = time.time()

    for e in tqdm(range(Config.Num_epochs)):
        unet.train()
        totalTrainLoss, totalValLoss, dice_score_train, dice_score_val = 0.0, 0.0, 0.0, 0.0

        for x, y in trainLoader:
            x, y = x.to(Config.DEVICE), y.to(Config.DEVICE)

            pred = unet(x)
            CEloss = lossFunc(pred, y)
            # Optimise focal loss + a differentiable soft Dice; torchmetrics' dice is
            # only used for reporting because it carries no gradient.
            total_loss = CEloss + _soft_dice_loss(pred, y, Config.No_classes)

            opt.zero_grad()
            total_loss.backward()
            opt.step()

            # .item() yields a Python float, so each batch's autograd graph is released
            # immediately instead of being kept alive by the running sums.
            totalTrainLoss += CEloss.item()
            dice_score_train += dice(softmax(pred.detach()), y).item()

        unet.eval()
        with torch.no_grad():
            for x, y in valLoader:
                x, y = x.to(Config.DEVICE), y.to(Config.DEVICE)
                pred = unet(x)
                totalValLoss += lossFunc(pred, y).item()
                dice_score_val += dice(softmax(pred), y).item()

        # Average over the number of batches actually processed (includes a final partial
        # batch, and never divides by zero when a split is smaller than one batch).
        avgTrainLoss = totalTrainLoss / len(trainLoader)
        avgValLoss = totalValLoss / len(valLoader)
        avgDiceTrain = dice_score_train / len(trainLoader)
        avgDiceVal = dice_score_val / len(valLoader)

        # Adapt the learning rate based on the (maximised) validation Dice score
        scheduler.step(avgDiceVal)
        print(opt.param_groups[0]['lr'])

        CE_loss["train_loss"].append(avgTrainLoss)
        CE_loss["val_loss"].append(avgValLoss)
        dsc_loss["Dice_train"].append(avgDiceTrain)
        dsc_loss["Dice_val"].append(avgDiceVal)

        print("EPOCH: {}/{}".format(e + 1, Config.Num_epochs))
        print("Train loss: {:.4f}, Validation loss: {:.4f}".format(avgTrainLoss, avgValLoss))
        print("Training Dice Score: {:.4f}% , Validation Dice Score: {:.4f}%".format(avgDiceTrain * 100, avgDiceVal * 100))

    endTime = time.time()
    print("Total time taken to train the model: {:.2f}s".format(endTime - startTime))

    _plot_history(CE_loss, dsc_loss, fold)

    # NOTE: every fold writes to the same path, so only the last fold's model survives.
    torch.save(unet, Config.MODEL_PATH)

In [None]:
import matplotlib.pyplot as plt 

# Inspect a single training batch: tensor shapes/dtypes and the values present.
# One batch is enough; looping over the whole loader only floods the output.
img, mask = next(iter(trainLoader))
print('The shape of the image is: {}'.format(img.shape))
print('The type of the image is: {}'.format(img.dtype))
print('The shape of the mask is: {}'.format(mask.shape))
print('The type of the mask is: {}'.format(mask.dtype))
print(torch.unique(mask), torch.unique(img))

In [None]:
import matplotlib.pyplot as plt

# Show the first image of a few training batches next to its ground-truth mask.
# Drawing a figure for every batch of the loader would exhaust memory, so cap it.
N_PREVIEW_BATCHES = 3

for batch_no, (img, mask) in enumerate(trainLoader):
    if batch_no >= N_PREVIEW_BATCHES:
        break
    figure, ax = plt.subplots(nrows=1, ncols=2, figsize=(18, 18))
    img = np.transpose(img[0, :, :, :].numpy(), (1, 2, 0))  # (C, H, W) -> (H, W, C) for imshow
    mask = np.array(mask[0, :, :])
    ax[0].imshow(img)
    ax[1].imshow(mask)
    ax[0].set_title("Image")
    ax[1].set_title("Ground Truth")
    ax[0].grid(False)
    ax[1].grid(False)
    figure.tight_layout()
    plt.show()

In [None]:
import torch.nn as nn

# Reload the whole model pickled by the cross-validation cell (torch.save(unet, Config.MODEL_PATH))
# straight onto the inference device, instead of a user-specific absolute path.
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which cannot unpickle a
# full nn.Module; on those versions pass weights_only=False (only for files you trust).
model = torch.load(Config.MODEL_PATH, map_location=Config.DEVICE).to(Config.DEVICE)
model.eval()

def check_accuracy(loader, model, device=None):
    """Print and return the pixel accuracy and mean per-class Dice of ``model`` on ``loader``.

    Predictions are the argmax over the class dimension of the model output. Dice is
    computed per class from intersection/cardinality counts accumulated over the whole
    loader. It is then averaged over the classes that appear in either the predictions
    or the targets.

    Args:
        loader: iterable of ``(x, y)`` batches. ``y`` is assumed to be an integer class map
            of shape (N, H, W), as indexed by the plotting cells; TODO confirm.
        model: network returning per-class scores of shape (N, C, H, W).
        device: device to run on; defaults to ``Config.DEVICE``.

    Returns:
        ``(accuracy, mean_dice)`` as Python floats. Both are NaN if the loader is empty.
    """
    if device is None:
        device = Config.DEVICE

    was_training = model.training
    model.eval()

    num_correct = 0
    num_pixels = 0
    intersection = None  # per-class |pred ∩ target|, accumulated over batches
    cardinality = None   # per-class |pred| + |target|, accumulated over batches

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).long()
            scores = model(x.float())
            # softmax is monotonic, so argmax of the raw scores equals argmax of the probabilities
            labels = scores.argmax(dim=1)

            num_correct += (labels == y).sum().item()
            num_pixels += y.numel()

            num_classes = scores.shape[1]
            pred_1h = nn.functional.one_hot(labels, num_classes)
            true_1h = nn.functional.one_hot(y, num_classes)
            reduce_dims = tuple(range(pred_1h.ndim - 1))  # all dims except the trailing class dim
            batch_inter = (pred_1h * true_1h).sum(reduce_dims)
            batch_card = (pred_1h + true_1h).sum(reduce_dims)
            intersection = batch_inter if intersection is None else intersection + batch_inter
            cardinality = batch_card if cardinality is None else cardinality + batch_card

    # Restore whatever mode the caller had the model in.
    if was_training:
        model.train()

    if num_pixels == 0:
        print("Loader is empty - nothing to evaluate")
        return float("nan"), float("nan")

    accuracy = num_correct / num_pixels
    present = cardinality > 0
    mean_dice = (2 * intersection[present].float() / cardinality[present].float()).mean().item()

    print(f"Got {num_correct}/{num_pixels} with acc {accuracy * 100:.2f}")
    print(f"Dice score: {mean_dice}")
    return accuracy, mean_dice

In [None]:
# Evaluate the loaded model on the training loader left over from the last CV fold.
check_accuracy(loader=trainLoader, model=model);

In [None]:
# Evaluate the loaded model on the validation loader left over from the last CV fold.
check_accuracy(loader=valLoader, model=model);

In [None]:
import matplotlib.pyplot as plt
import torch.nn as nn
from UModel import UNet
import torch
import Config
import dataloader
import os
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader
import numpy as np
import cv2

# Rebuild the held-out test split from the image paths written by the training cell.
# Each mask is expected in the mask directory under the same file name as its image.
with open(Config.TEST_PATHS, "r") as f:
    X_test = [line for line in f.read().split("\n") if line.strip()]  # skip blank lines
filenames = [os.path.basename(image_path) for image_path in X_test]
y_test = [os.path.join(Config.Mask_dataset_dir, filename) for filename in filenames]

# Deterministic preprocessing only (same resize/normalisation as training, no augmentation).
test_transform = A.Compose([A.Resize(Config.Input_Height, Config.Input_Width),
                            A.Normalize(mean=(0.0), std=(1.0)),
                            ToTensorV2()])
testDS = dataloader.MyDataset(imagePaths=X_test, maskPaths=y_test, transform=test_transform)
# shuffle=False keeps batch order aligned with `filenames`, which the prediction loop relies on.
testLoader = DataLoader(testDS, shuffle=False, batch_size=Config.Batch_size,
                        pin_memory=Config.PIN_MEMORY)



# Load the whole pickled model saved by the training cell directly onto Config.DEVICE.
# The prediction loop moves its inputs there, so a CPU-mapped model would hit a
# device mismatch on CUDA. Config.MODEL_PATH replaces the user-specific absolute path,
# which also contained an invalid "\D" escape.
model = torch.load(Config.MODEL_PATH, map_location=Config.DEVICE)
model.eval()

# Predict every test image, save each prediction under its own file name, and plot the
# first sample of each batch next to its ground truth.
# NOTE(review): assumes Config.MODEL_PATH lives in the Dataset/output folder, so that
# predictions land in output/predictions as before. Adjust if not.
PREDICTIONS_DIR = os.path.join(os.path.dirname(Config.MODEL_PATH), "predictions")
os.makedirs(PREDICTIONS_DIR, exist_ok=True)
softmax = nn.Softmax(dim=1)

image_no = 0  # index into `filenames`; testLoader is unshuffled, so order matches
with torch.no_grad():
    for x, y in testLoader:
        x = x.to(Config.DEVICE)
        # Class-index maps, (N, H, W). uint8 is an image dtype cv2.imwrite accepts
        # (assumes fewer than 256 classes).
        batch_preds = torch.argmax(softmax(model(x.float())), dim=1).cpu().numpy().astype(np.uint8)

        # Save every prediction in the batch, advancing the filename index per image (not per batch).
        for pred in batch_preds:
            out_path = os.path.join(PREDICTIONS_DIR, filenames[image_no])
            if not cv2.imwrite(out_path, pred):
                raise IOError(f"Could not write prediction to {out_path}")
            image_no += 1

        img = np.transpose(x[0, :, :, :].cpu().numpy(), (1, 2, 0))
        mask = np.array(y[0, :, :])
        preds = batch_preds[0]
        figure, ax = plt.subplots(nrows=1, ncols=3, figsize=(10, 10))
        for axis, data, title in zip(ax, (img, mask, preds), ("Image", "Ground Truth", "Prediction")):
            axis.imshow(data)
            axis.set_title(title)
            axis.grid(False)
        figure.tight_layout()
        plt.show()

In [None]:
## Training loop ##
# Quick stand-alone training run on `trainLoader` (the last CV fold) with a fixed
# class-weighted focal loss, using mixed precision when running on CUDA.
model = UNet(1, 3).to(Config.DEVICE)
loss_fn = FocalLoss(weight=torch.FloatTensor([0.02, 0.38, 0.9]).to(Config.DEVICE))
opt = Adam(model.parameters(), lr=Config.Init_LR)
# Stepped once per batch on the batch loss ('min' mode). The LR drops after more than
# 30 consecutive batches without improvement.
scheduler = lr_scheduler.ReduceLROnPlateau(opt, 'min', patience=30)
loop = tqdm(range(10))
train_loss = []

# float16 autocast must be paired with a GradScaler, otherwise small gradients underflow
# to zero. Both become no-ops when not running on CUDA.
use_amp = torch.device(Config.DEVICE).type == "cuda"
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

startTime = time.time()
model.train()

for e in loop:
    for x, y in trainLoader:
        x = x.to(Config.DEVICE)
        y = y.to(Config.DEVICE)

        with torch.cuda.amp.autocast(enabled=use_amp):
            pred = model(x.float())
            CEloss = loss_fn(pred, y)

        opt.zero_grad()
        scaler.scale(CEloss).backward()
        scaler.step(opt)
        scaler.update()

        batch_loss = CEloss.item()
        scheduler.step(batch_loss)
        loop.set_postfix(loss=batch_loss)
        train_loss.append(batch_loss)

# Save the whole model, as the cross-validation cell does. The evaluation cells
# torch.load() this path and call .eval() on the result, which a bare state_dict breaks.
torch.save(model, Config.MODEL_PATH)
endTime = time.time()
print("Total time taken to train the model: {:.2f}s".format(endTime - startTime))