# Import Python Modules, Dataset Creator, Scoring Metrics and Model

In [None]:
import os
import sys
import torch
from torch.utils.data import DataLoader
import albumentations as A
import pandas as pd
import time
import math
# Custom modules
import LOADER
import SIA_METRICS
import FCBSWINV2_TRANSFORMER

# Define the results file prefix (checkpoints and CSVs are saved as `<prefix>_<suffix>`; the prefix may include a folder)

In [None]:
# Prefix for every file this notebook writes: checkpoints and CSVs are saved as
# save_string + "_<suffix>". It may contain a directory part, e.g. "results/run1".
save_string = "RESULTS_FOLDER_NAME"
# Create the parent directory now so torch.save / to_csv don't fail mid-training
os.makedirs(os.path.dirname(save_string) or ".", exist_ok=True)

# Load image IDs, images and masks for each split (images/masks from the resized-image folder, IDs from the split CSV files)

In [None]:
# All three splits read from the same resized image/mask folder at the same
# resolution; they differ only in the CSV file listing the IDs of that split.
IMAGE_SIZE = 384
RESIZED_DATA_DIR = "PATH TO 384x384 RESIZED IMAGES & MASKS FOLDER"

train_IDs, train_X, train_Y = LOADER.load_data_to_model(
    IMAGE_SIZE, RESIZED_DATA_DIR, "PATH TO TRAIN DATA SPLIT csv FILE"
)
valid_IDs, valid_X, valid_Y = LOADER.load_data_to_model(
    IMAGE_SIZE, RESIZED_DATA_DIR, "PATH TO VALID DATA SPLIT csv FILE"
)
test_IDs, test_X, test_Y = LOADER.load_data_to_model(
    IMAGE_SIZE, RESIZED_DATA_DIR, "PATH TO TEST  DATA SPLIT csv FILE"
)

# Define Data Augmentations

In [None]:
# Geometric augmentations for training (passed to the dataset as geo_transform;
# presumably applied jointly to image and mask — TODO confirm in LOADER).
# The flips use albumentations' default probability; Transpose fires with p=0.5.
# The random affine runs on every sample (p=1.0 is the non-deprecated spelling of
# always_apply=True). The shear range is symmetric, like the other ranges.
geometric = A.Compose([
    A.HorizontalFlip(),
    A.VerticalFlip(),
    A.Transpose(p=0.5),
    A.Affine(
        scale=(0.5, 1.5),
        translate_percent=(-0.125, 0.125),
        rotate=(-180, 180),
        shear=(-22.5, 22.5),
        p=1.0,
    ),
])

# Photometric augmentation for training (passed as color_transform), applied to
# every sample.
color = A.Compose([
    A.ColorJitter(brightness=(0.6, 1.6), contrast=0.2, saturation=0.1, hue=0.01, p=1.0),
])

# Create Train, Validation and Test Datasets

In [None]:
# Only the training split is augmented; validation and test images are used as-is.
train_dataset = LOADER.Polyp_Dataset(
    train_IDs, train_X, train_Y,
    geo_transform=geometric,
    color_transform=color,
)

no_augmentation = {"geo_transform": None, "color_transform": None}
valid_dataset = LOADER.Polyp_Dataset(valid_IDs, valid_X, valid_Y, **no_augmentation)
test_dataset = LOADER.Polyp_Dataset(test_IDs, test_X, test_Y, **no_augmentation)

# Create Train, Validation and Test Dataloaders

In [None]:
# Batch size per split; the training/test loops divide accumulated metrics by
# the number of batches, so these also set the averaging granularity.
train_batch_size = 2
valid_batch_size = 1
test_batch_size = 1

# Worker processes shared by all three loaders
loader_workers = 6

# Only the training loader shuffles; validation/test keep a fixed order
train_loader = DataLoader(
    train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=loader_workers
)
valid_loader = DataLoader(
    valid_dataset, batch_size=valid_batch_size, shuffle=False, num_workers=loader_workers
)
test_loader = DataLoader(
    test_dataset, batch_size=test_batch_size, shuffle=False, num_workers=loader_workers
)

# Create Model

In [None]:
# Build the FCB-SwinV2 model for 384x384 inputs; checkpoint_path points at the
# pre-trained SwinV2 weights file (placeholder path — set before running).
model = FCBSWINV2_TRANSFORMER.FCBSwinV2_Transformer(size=384, checkpoint_path="PATH TO PRE-TRAINED SWINV2 MODEL WEIGHTS")
# Run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Define PyTorch Optimizer, Scheduler and Loss Metrics

In [None]:
optim = torch.optim

# AdamW over all model parameters with a very small base learning rate
optimizer = optim.AdamW(
    model.parameters(),
    lr=1e-6,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=1e-4,
)

# Multiply the LR by 0.6 after 10 epochs without improvement of the value passed
# to scheduler.step() (mode='min': lower is better)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.6,
    patience=10,
)

# Project loss/metric objects from SIA_METRICS (described by the author as
# image-wise averaged — TODO confirm in SIA_METRICS)
criterion = SIA_METRICS.DiceBCELoss()  # combined Dice + BCE training loss
dice = SIA_METRICS.DiceLoss()
IoU = SIA_METRICS.IoULoss()

# Train the model and evaluate on validation data

In [None]:
# Training configuration
epochs = 200
start_epoch = 0  # offset added to printed epoch numbers (e.g. when resuming)

# Upper bounds on the number of batches per epoch; lower them for quick smoke tests
max_trn_batch = 1000000
max_val_batch = 1000000

# Smoothing term for the thresholded Dice/IoU, precision and recall metrics.
# Shared by training and validation, and equal to the value the test cell uses,
# so the checkpoint chosen by max validation thresholded Dice is ranked with the
# same metric definition that is reported at test time.
metric_smooth = 1e-6

# Dataset sizes (only used in progress messages)
train_size = len(train_dataset)
valid_size = len(valid_dataset)

# Per-epoch history of batch-averaged losses and scores (exported to CSV later)
train_ce_losses_lst = []
valid_ce_losses_lst = []

train_dice_losses_lst = []
valid_dice_losses_lst = []

train_dice_scores_lst = []
valid_dice_scores_lst = []

train_thresh_dice_losses_lst = []
valid_thresh_dice_losses_lst = []

train_thresh_dice_scores_lst = []
valid_thresh_dice_scores_lst = []

train_iou_losses_lst = []
valid_iou_losses_lst = []

train_iou_scores_lst = []
valid_iou_scores_lst = []

train_thresh_iou_losses_lst = []
valid_thresh_iou_losses_lst = []

train_thresh_iou_scores_lst = []
valid_thresh_iou_scores_lst = []

train_precision_score_lst = []
valid_precision_score_lst = []

train_recall_score_lst = []
valid_recall_score_lst = []

# Metric key -> (train history list, validation history list). Appending through
# this table guarantees each metric lands in its own list.
history = {
    "ce_loss":           (train_ce_losses_lst, valid_ce_losses_lst),
    "dice_loss":         (train_dice_losses_lst, valid_dice_losses_lst),
    "dice_score":        (train_dice_scores_lst, valid_dice_scores_lst),
    "thresh_dice_loss":  (train_thresh_dice_losses_lst, valid_thresh_dice_losses_lst),
    "thresh_dice_score": (train_thresh_dice_scores_lst, valid_thresh_dice_scores_lst),
    "iou_loss":          (train_iou_losses_lst, valid_iou_losses_lst),
    "iou_score":         (train_iou_scores_lst, valid_iou_scores_lst),
    "thresh_iou_loss":   (train_thresh_iou_losses_lst, valid_thresh_iou_losses_lst),
    "thresh_iou_score":  (train_thresh_iou_scores_lst, valid_thresh_iou_scores_lst),
    "precision":         (train_precision_score_lst, valid_precision_score_lst),
    "recall":            (train_recall_score_lst, valid_recall_score_lst),
}

# Best validation values seen so far; they decide which checkpoints are saved
min_val_loss = 100
max_val_dice = 0
max_val_tice = 0

# Learning rate in effect at the start of each epoch
lrs = []

# Report formatting
STAR_RULE = '*****************************************************************************************************************'
HASH_RULE = '#################################################################################################################'
INTERIM_INDENT = " " * 54
EPOCH_INDENT = " " * 14


def _batch_metrics(output, target, loss):
    """Return every tracked metric for one batch as Python floats.

    output: sigmoid probabilities from the model; target: float mask;
    loss: the already computed criterion (combined) loss tensor.
    """
    dice_loss = dice(output, target)
    iou_loss = IoU(output, target)
    thresh_dice_loss = SIA_METRICS.Threshold_DiceLoss(output, target, thresh=0.5, smooth=metric_smooth)
    thresh_iou_loss = SIA_METRICS.Threshold_IoULoss(output, target, thresh=0.5, smooth=metric_smooth)
    return {
        "ce_loss": loss.item(),
        "dice_loss": dice_loss.item(),
        "dice_score": (1 - dice_loss).item(),
        "thresh_dice_loss": thresh_dice_loss.item(),
        "thresh_dice_score": (1 - thresh_dice_loss).item(),
        "iou_loss": iou_loss.item(),
        "iou_score": (1 - iou_loss).item(),
        "thresh_iou_loss": thresh_iou_loss.item(),
        "thresh_iou_score": (1 - thresh_iou_loss).item(),
        "precision": SIA_METRICS.custom_precision_score(output, target, thresh=0.5, smooth=metric_smooth).item(),
        "recall": SIA_METRICS.custom_recall_score(output, target, thresh=0.5, smooth=metric_smooth).item(),
    }


def _averages(totals, n_batches):
    """Mean per batch of each summed metric; NaN when no batch was processed."""
    return {k: (v / n_batches if n_batches else float("nan")) for k, v in totals.items()}


def _print_metric_lines(indent, avg, elapsed_min):
    """Print the metric lines shared by the interim and end-of-epoch reports."""
    print(f"{indent}TICE LOSS: {avg['thresh_dice_loss']:7.5f} TOU LOSS: {avg['thresh_iou_loss']:7.5f}")
    print(f"{indent}DICE SCRE: {avg['dice_score']:7.5f} IOU SCRE: {avg['iou_score']:7.5f}")
    print(f"{indent}TICE SCRE: {avg['thresh_dice_score']:7.5f} TOU SCRE: {avg['thresh_iou_score']:7.5f}")
    print(f"{indent}PREC SCRE: {avg['precision']:7.5f} REC SCRE: {avg['recall']:7.5f}")
    print(f"{indent}TIME:  {elapsed_min:5.2f} ")


def _print_interim(epoch, n_batches, n_images, dataset_size, avg, elapsed_min):
    """Print running averages part-way through an epoch."""
    print(STAR_RULE)
    print(f"EPOCH: {epoch:2} BATCH: {n_batches:4} [{n_images:6}/{dataset_size}] "
          f"COMB LOSS: {avg['ce_loss']:7.5f} DICE LOSS: {avg['dice_loss']:7.5f} IOU LOSS: {avg['iou_loss']:7.5f}")
    _print_metric_lines(INTERIM_INDENT, avg, elapsed_min)


def _print_epoch(epoch, tag, avg, elapsed_min):
    """Print the end-of-epoch averages for one phase (tag is 'TRN' or 'VAL')."""
    print(HASH_RULE)
    print(f"EPOCH: {epoch:2} {tag} COMB LOSS: {avg['ce_loss']:7.5f}")
    print(f"{EPOCH_INDENT}DICE LOSS: {avg['dice_loss']:7.5f} IOU LOSS: {avg['iou_loss']:7.5f}")
    _print_metric_lines(EPOCH_INDENT, avg, elapsed_min)
    print(HASH_RULE)


for i in range(epochs):
    epoch = i + start_epoch
    tic = time.time()

    lrs.append(optimizer.param_groups[0]["lr"])
    print("LR for epoch ", i," = ", optimizer.param_groups[0]["lr"])

    # ------------------------------ training ------------------------------
    model.train()
    train_img_pro = 0
    n_train_batches = 0
    train_totals = dict.fromkeys(history, 0.0)

    for b, (_name, img_train, msk_train) in enumerate(train_loader):
        # Checked before the forward pass so a skipped batch is neither computed
        # nor counted in the image/batch tallies used for averaging
        if b == max_trn_batch:
            break
        img_train, msk_train = img_train.to(device), msk_train.float().to(device)

        pred = model(img_train)
        # Sigmoid turns logits into probabilities for the Dice/IoU based losses
        output = torch.sigmoid(pred)
        loss = criterion(output, msk_train)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # The extra metrics are for monitoring only: keep them out of autograd
        with torch.no_grad():
            batch_metrics = _batch_metrics(output, msk_train, loss)

        train_img_pro += pred.size(0)
        n_train_batches += 1
        for key, value in batch_metrics.items():
            train_totals[key] += value

        if n_train_batches % (100 * train_batch_size) == 0:
            _print_interim(epoch, n_train_batches, train_img_pro, train_size,
                           _averages(train_totals, n_train_batches), (time.time() - tic) / 60)

    train_avg = _averages(train_totals, n_train_batches)
    for key, value in train_avg.items():
        history[key][0].append(value)
    _print_epoch(epoch, "TRN", train_avg, (time.time() - tic) / 60)

    # ----------------------------- validation -----------------------------
    model.eval()
    valid_img_pro = 0
    n_valid_batches = 0
    valid_totals = dict.fromkeys(history, 0.0)

    with torch.no_grad():
        for b, (_name, img_valid, msk_valid) in enumerate(valid_loader):
            if b == max_val_batch:
                break
            img_valid, msk_valid = img_valid.to(device), msk_valid.float().to(device)

            pred = model(img_valid)
            output = torch.sigmoid(pred)
            loss = criterion(output, msk_valid)
            batch_metrics = _batch_metrics(output, msk_valid, loss)

            valid_img_pro += pred.size(0)
            n_valid_batches += 1
            for key, value in batch_metrics.items():
                valid_totals[key] += value

            if n_valid_batches % (100 * valid_batch_size) == 0:
                _print_interim(epoch, n_valid_batches, valid_img_pro, valid_size,
                               _averages(valid_totals, n_valid_batches), (time.time() - tic) / 60)

    valid_avg = _averages(valid_totals, n_valid_batches)
    for key, value in valid_avg.items():
        history[key][1].append(value)
    _print_epoch(epoch, "VAL", valid_avg, (time.time() - tic) / 60)

    # ReduceLROnPlateau tracks its own best value, so it must be given this
    # epoch's validation loss (not the running minimum kept for checkpointing)
    scheduler.step(valid_avg["ce_loss"])

    # No checkpoints during the first epochs (i = 0..5)
    if i > 5:
        if valid_avg["ce_loss"] <= min_val_loss:
            print("SAVING MIN VAL COMB LOSS MODEL")
            min_val_loss = valid_avg["ce_loss"]
            torch.save(model.state_dict(), save_string + "_min_val_comb_loss.pt")
        if valid_avg["dice_score"] >= max_val_dice:
            print("SAVING MAX VAL DICE SCORE MODEL")
            max_val_dice = valid_avg["dice_score"]
            torch.save(model.state_dict(), save_string + "_max_val_dice_score.pt")
        if valid_avg["thresh_dice_score"] >= max_val_tice:
            print("SAVING MAX VAL TICE SCORE MODEL")
            max_val_tice = valid_avg["thresh_dice_score"]
            torch.save(model.state_dict(), save_string + "_max_val_tice_score.pt")

# Save the per-epoch training/validation losses and scores to a CSV file

In [None]:
# (column name, per-epoch list) pairs, Train/Valid side by side for each metric;
# dict() keeps this order, which becomes the CSV column order.
train_val_columns = [
    ('Train_Losses', train_ce_losses_lst), ('Valid_Losses', valid_ce_losses_lst),
    ('Train_DICE_Losses', train_dice_losses_lst), ('Valid_DICE_Losses', valid_dice_losses_lst),
    ('Train_DICE_Threshold_Losses', train_thresh_dice_losses_lst), ('Valid_DICE_Threshold_Losses', valid_thresh_dice_losses_lst),
    ('Train_DICE_Scores', train_dice_scores_lst), ('Valid_DICE_Scores', valid_dice_scores_lst),
    ('Train_DICE_Threshold_Scores', train_thresh_dice_scores_lst), ('Valid_DICE_Threshold_Scores', valid_thresh_dice_scores_lst),
    ('Train_IOU_Losses', train_iou_losses_lst), ('Valid_IOU_Losses', valid_iou_losses_lst),
    ('Train_IOU_Threshold_Losses', train_thresh_iou_losses_lst), ('Valid_IOU_Threshold_Losses', valid_thresh_iou_losses_lst),
    ('Train_IOU_Scores', train_iou_scores_lst), ('Valid_IOU_Scores', valid_iou_scores_lst),
    ('Train_IOU_Threshold_Scores', train_thresh_iou_scores_lst), ('Valid_IOU_Threshold_Scores', valid_thresh_iou_scores_lst),
    ('Train_PREC_Scores', train_precision_score_lst), ('Valid_PREC_Scores', valid_precision_score_lst),
    ('Train_RECALL_Scores', train_recall_score_lst), ('Valid_RECALL_Scores', valid_recall_score_lst),
]

df_train_val_results = pd.DataFrame(data=dict(train_val_columns))
df_train_val_results.to_csv(save_string + "_TRAIN-VAL.csv", sep=',', index=False)

# Load model instance with highest validation thresholded dice score

In [None]:
# Reload the weights saved when the validation thresholded Dice score peaked
model_path = save_string + "_max_val_tice_score.pt"
# map_location places the tensors on the current device, so a checkpoint written
# on a GPU can also be restored on a CPU-only machine
weights = torch.load(model_path, map_location=device)
model.load_state_dict(weights, strict=True)

# Generate test predictions: save a binary mask per test image and compute per-image and averaged test metrics

In [None]:
from PIL import Image  # required by Image.fromarray below

# Test-pass accumulators (sums over batches)
test_img_pro = 0
n_test_batches = 0
total_test_ce_loss = 0
total_test_dice_loss = 0
total_test_dice_score = 0
total_test_thresh_dice_loss = 0
total_test_thresh_dice_score = 0
total_test_iou_loss = 0
total_test_iou_score = 0
total_test_thresh_iou_loss = 0
total_test_thresh_iou_score = 0
total_test_precision_score = 0
total_test_recall_score = 0
# Upper bound on test batches; lower it for a quick smoke test
max_test_batch = 1000000
test_size = len(test_dataset)
# Single-entry lists holding the averaged test metrics
test_ce_losses_lst = []
test_dice_losses_lst = []
test_dice_scores_lst = []
test_thresh_dice_losses_lst = []
test_thresh_dice_scores_lst = []
test_iou_losses_lst = []
test_iou_scores_lst = []
test_thresh_iou_losses_lst = []
test_thresh_iou_scores_lst = []
test_precision_score_lst = []
test_recall_score_lst = []

# Folder that receives one binary mask image per test image (created if missing)
pred_dir = "PATH TO MASK PREDICTION FOLDER"
os.makedirs(pred_dir, exist_ok=True)

# Per-image result rows; turned into a DataFrame once after the loop
result_columns = [
    'Image Name', 'Dice Loss', 'Thresh Dice Loss',
    'Thresh Dice Score', 'IoU Loss', 'Thresh IoU Loss',
    'Thresh IoU Score', 'Precision Score', 'Recall Score'
]
result_rows = []

STAR_RULE = '*****************************************************************************************************************'
HASH_RULE = '#################################################################################################################'
INDENT = " " * 54


def _test_avg(total):
    """Average a summed test metric over the processed batches (NaN if none)."""
    return total / n_test_batches if n_test_batches else float("nan")


# eval() puts batch norm / dropout in inference mode; no_grad disables autograd
model.eval()
with torch.no_grad():
    for b, (name, img_test, msk_test) in enumerate(test_loader):
        if b == max_test_batch:
            break
        img_test, msk_test = img_test.to(device), msk_test.float().to(device)
        pred = model(img_test)
        test_img_pro += pred.size(0)
        n_test_batches += 1

        # Sigmoid turns logits into probabilities (needed for Dice and IoU)
        output = torch.sigmoid(pred)

        # Binarise at 0.5 and store as 0/255 uint8 so each mask is viewable as an image
        pred_masks = ((output > 0.5).to(torch.uint8) * 255).cpu().numpy()
        for idx, img_name in enumerate(name):
            Image.fromarray(pred_masks[idx].squeeze()).save(os.path.join(pred_dir, img_name))

        loss = criterion(output, msk_test)
        test_dice_loss = dice(output, msk_test)
        test_dice_score = 1 - test_dice_loss
        test_iou_loss = IoU(output, msk_test)
        test_iou_score = 1 - test_iou_loss

        # Thresholded metrics
        test_thresh_dice_loss = SIA_METRICS.Threshold_DiceLoss(output, msk_test, thresh=0.5, smooth=1e-6)
        test_thresh_dice_score = 1 - test_thresh_dice_loss
        test_thresh_iou_loss = SIA_METRICS.Threshold_IoULoss(output, msk_test, thresh=0.5, smooth=1e-6)
        test_thresh_iou_score = 1 - test_thresh_iou_loss
        test_precision_score = SIA_METRICS.custom_precision_score(output, msk_test, thresh=0.5, smooth=1e-6)
        test_recall_score = SIA_METRICS.custom_recall_score(output, msk_test, thresh=0.5, smooth=1e-6)

        # Metrics are computed per batch: with test_batch_size > 1 every image
        # of a batch is given the same batch-level values.
        batch_row = {
            'Dice Loss': test_dice_loss.item(),
            'Thresh Dice Loss': test_thresh_dice_loss.item(),
            'Thresh Dice Score': test_thresh_dice_score.item(),
            'IoU Loss': test_iou_loss.item(),
            'Thresh IoU Loss': test_thresh_iou_loss.item(),
            'Thresh IoU Score': test_thresh_iou_score.item(),
            'Precision Score': test_precision_score.item(),
            'Recall Score': test_recall_score.item(),
        }
        for img_name in name:
            result_rows.append({'Image Name': img_name, **batch_row})

        # Accumulate sums for the averaged results
        total_test_ce_loss += loss.item()
        total_test_dice_loss += test_dice_loss.item()
        total_test_dice_score += test_dice_score.item()
        total_test_iou_loss += test_iou_loss.item()
        total_test_iou_score += test_iou_score.item()
        total_test_thresh_dice_loss += test_thresh_dice_loss.item()
        total_test_thresh_dice_score += test_thresh_dice_score.item()
        total_test_thresh_iou_loss += test_thresh_iou_loss.item()
        total_test_thresh_iou_score += test_thresh_iou_score.item()
        total_test_precision_score += test_precision_score.item()
        total_test_recall_score += test_recall_score.item()

        # Print running averages
        if b % test_batch_size == 0:
            print(STAR_RULE)
            print(f' BATCH: {n_test_batches:4} [{test_img_pro:6}/{test_size}] COMB LOSS: {_test_avg(total_test_ce_loss):7.5f} DICE LOSS: {_test_avg(total_test_dice_loss):7.5f} IOU LOSS: {_test_avg(total_test_iou_loss):7.5f}')
            print(f'{INDENT}TICE LOSS: {_test_avg(total_test_thresh_dice_loss):7.5f} TOU LOSS: {_test_avg(total_test_thresh_iou_loss):7.5f}')
            print(f'{INDENT}DICE SCRE: {_test_avg(total_test_dice_score):7.5f} IOU SCRE: {_test_avg(total_test_iou_score):7.5f}')
            print(f'{INDENT}TICE SCRE: {_test_avg(total_test_thresh_dice_score):7.5f} TOU SCRE: {_test_avg(total_test_thresh_iou_score):7.5f}')
            print(f'{INDENT}PREC SCRE: {_test_avg(total_test_precision_score):7.5f} REC SCRE: {_test_avg(total_test_recall_score):7.5f}')
            print(STAR_RULE)

# Store the averaged test metrics
test_ce_losses_lst.append(_test_avg(total_test_ce_loss))
test_dice_losses_lst.append(_test_avg(total_test_dice_loss))
test_thresh_dice_losses_lst.append(_test_avg(total_test_thresh_dice_loss))
test_thresh_iou_losses_lst.append(_test_avg(total_test_thresh_iou_loss))
test_iou_losses_lst.append(_test_avg(total_test_iou_loss))
test_precision_score_lst.append(_test_avg(total_test_precision_score))
test_recall_score_lst.append(_test_avg(total_test_recall_score))
test_dice_scores_lst.append(_test_avg(total_test_dice_score))
test_thresh_dice_scores_lst.append(_test_avg(total_test_thresh_dice_score))
test_iou_scores_lst.append(_test_avg(total_test_iou_score))
test_thresh_iou_scores_lst.append(_test_avg(total_test_thresh_iou_score))

# Print results on the test set
print(HASH_RULE)
print(f'               TEST COMB LOSS: {_test_avg(total_test_ce_loss):7.5f}')
print(f'               DICE LOSS: {_test_avg(total_test_dice_loss):7.5f} IOU LOSS: {_test_avg(total_test_iou_loss):7.5f}')
print(f'               TICE LOSS: {_test_avg(total_test_thresh_dice_loss):7.5f} TOU LOSS: {_test_avg(total_test_thresh_iou_loss):7.5f}')
print(f'               DICE SCRE: {_test_avg(total_test_dice_score):7.5f} IOU SCRE: {_test_avg(total_test_iou_score):7.5f}')
print(f'               TICE SCRE: {_test_avg(total_test_thresh_dice_score):7.5f} TOU SCRE: {_test_avg(total_test_thresh_iou_score):7.5f}')
print(f'               PREC SCRE: {_test_avg(total_test_precision_score):7.5f} REC SCRE: {_test_avg(total_test_recall_score):7.5f}')
print(HASH_RULE)

# Save per-image results (built once from the collected rows)
results_df = pd.DataFrame(result_rows, columns=result_columns)
results_df.to_csv(save_string + '_TEST_RESULTS.csv', index=False)

# Save the averaged test scores, using the same per-batch averaging as the
# printed results (identical to a per-image mean when test_batch_size == 1)
average_scores = {'Dice Loss': _test_avg(total_test_dice_loss),
                  'Thresh Dice Loss': _test_avg(total_test_thresh_dice_loss),
                  'Thresh Dice Score': _test_avg(total_test_thresh_dice_score),
                  'IoU Loss': _test_avg(total_test_iou_loss),
                  'Thresh IoU Loss': _test_avg(total_test_thresh_iou_loss),
                  'Thresh IoU Score': _test_avg(total_test_thresh_iou_score),
                  'Precision Score': _test_avg(total_test_precision_score),
                  'Recall Score': _test_avg(total_test_recall_score)}
avg_df = pd.DataFrame([average_scores])
avg_df.to_csv(save_string + '_AVERAGE_TEST_RESULTS.csv', index=False)