# SETTINGS

In [None]:
import numpy as np

############### EXECUTION SETTINGS ###############

PARENT_EXECUTION_DIR = "executions/unet/train-k_fold_val-test"
# Directory under which the directory of this execution (and its results) is created.


################### DATA SETTINGS #################

DATA_DIR = "data/train_data"
# Directory where the samples (data and labels folders) and the metadata.json file are located.

METADATA_DATASET = "metadata_38_samples.json"
# Name of the metadata file located in DATA_DIR; it describes the samples used to build the dataset.

TEST_PERCENTAGE_TRAIN_TEST_SPLIT = 15
# Percentage (0-100) of the dataset held out for testing during the train-test split.

CROSS_VALIDATION_K_FOLDS = 4
# Number of folds for k-fold cross-validation.

NUM_WORKERS = 4
# Number of worker processes handed to the DataLoaders to parallelize data loading.

############ TRAINING/INFERENCE SETTINGS ##########

BATCH_SIZE_TRAINING = 2
BATCH_SIZE_INFERENCE = 1
# The inference batch size does not matter: patches inside a batch are processed sequentially.

PATIENCE_EARLY_STOPPING = 50  # 7
# Number of epochs to wait before early stopping if no improvement is observed.

MAX_TRAINING_EPOCHS = 3000
# Maximum number of epochs for training.

LEARNING_RATE = 1e-4  # 1e-4
WEIGHT_DECAY = 1e-4  # 1e-2

WARMUP_EPOCHS_LR_SCHEDULER = int(np.ceil(0.10 * MAX_TRAINING_EPOCHS))
# Warm-up length: 10% of MAX_TRAINING_EPOCHS, rounded up.
# np.ceil returns a numpy float, so it is cast to a plain int: an epoch count must be
# an integer, and a built-in int is stored in settings.json as 300 rather than 300.0.

############### DATA PROCESSING SETTINGS ##########

# Both values are forwarded unchanged to the dataset class
# (preprocessing_technique / preprocessing_channels arguments).
PREPROCESSING_TECHNIQUE = "ppt"
PREPROCESSING_CHANNELS = 10

################### MODEL SETTINGS #################



# CHECKING DISK STORAGE

In [None]:
import shutil

# Query the filesystem that contains "/" for its total, used and free space (all in bytes)
total, used, free = shutil.disk_usage("/")

# Convert to human-readable format
def format_size(size):
    """Return a byte count as a string with two decimals and a binary unit.

    The value is divided by 1024 until it drops below 1024 or the largest
    known unit (TB) is reached.

    Args:
        size: Number of bytes (int or float).

    Returns:
        A string such as ``"1.50 MB"``. Values of 1024 TB or more are still
        expressed in TB (e.g. ``"2048.00 TB"``) instead of returning ``None``.
    """
    units = ('B', 'KB', 'MB', 'GB', 'TB')
    for unit in units[:-1]:
        if size < 1024:
            return f"{size:.2f} {unit}"
        size /= 1024
    # Largest unit: always return a string, whatever the magnitude
    return f"{size:.2f} {units[-1]}"

# Report the three figures, one per line
print(
    f"Total Space: {format_size(total)}",
    f"Used Space: {format_size(used)}",
    f"Free Space: {format_size(free)}",
    sep="\n",
)

In [None]:
import os

def get_directory_size(directory):
    """Return the combined size, in bytes, of every file found under a directory.

    The tree is walked recursively with ``os.walk``. Entries whose size cannot
    be read (broken symbolic links, files removed while the walk is running)
    are skipped.

    Args:
        directory: Path of the directory to measure.

    Returns:
        Total size in bytes; 0 when the directory does not exist or is empty.
    """
    total_size = 0
    for dirpath, _dirnames, filenames in os.walk(directory):
        for file in filenames:
            file_path = os.path.join(dirpath, file)
            try:
                total_size += os.path.getsize(file_path)
            except OSError:
                # Asking first with os.path.exists() leaves a window in which the
                # file can disappear; attempting the read and skipping failures does not.
                continue
    return total_size

def print_directory_size(directory):
    """Print the total size of a directory in a human-readable unit.

    The byte count returned by ``get_directory_size`` is divided by 1024 until
    it drops below 1024 or the largest unit (TB) is reached, then printed with
    two decimals.

    Args:
        directory: Path of the directory to measure.
    """
    size = get_directory_size(directory)
    units = ('B', 'KB', 'MB', 'GB', 'TB')
    for unit in units[:-1]:
        if size < 1024:
            break
        size /= 1024
    else:
        # Never dropped below 1024: report in the largest unit rather than printing nothing
        unit = units[-1]
    print(f"Size of '{directory}': {size:.2f} {unit}")

# Example usage: report the size of each listed directory.
# Edit the tuple to inspect other target directories.
for preprocessed_files_path in ("data/train_data/preprocessed_files", "executions"):
    print_directory_size(preprocessed_files_path)

# CREATING EXECUTION DIRECTORY

In [None]:
import os

def _list_subdirectories(path):
    """Return the names of the directories directly inside *path*.

    Jupyter's ``.ipynb_checkpoints`` folder is ignored, and a *path* that does
    not exist yields an empty list instead of raising.
    """
    if not os.path.isdir(path):
        return []
    return [
        name for name in os.listdir(path)
        if name != ".ipynb_checkpoints" and os.path.isdir(os.path.join(path, name))
    ]


def _parse_execution_id(dir_name):
    """Return the integer encoded as ``id=<n>`` at the start of *dir_name*, or None."""
    _, separator, value = dir_name.split('-')[0].partition('=')
    if not separator:
        return None
    try:
        return int(value)
    except ValueError:
        return None


# Collect the ids of every previous execution, stored as
# executions/<model>/<execution type>/id=<n>-...
# Directories that do not follow the naming scheme are skipped.
executions_ids = []
for model_dir in _list_subdirectories("executions"):
    model_path = os.path.join("executions", model_dir)
    for execution_type_dir in _list_subdirectories(model_path):
        execution_type_path = os.path.join(model_path, execution_type_dir)
        for previous_execution_dir in _list_subdirectories(execution_type_path):
            previous_execution_id = _parse_execution_id(previous_execution_dir)
            if previous_execution_id is not None:
                executions_ids.append(previous_execution_id)

print(f"Total executions previously done: {len(executions_ids)}")

parent_execution_dir_path = PARENT_EXECUTION_DIR

# Count the subdirectories inside parent_execution_dir_path (0 if it does not exist yet)
executions_parent_dir_count = len(_list_subdirectories(parent_execution_dir_path))

print(f"Total executions previously done ({PARENT_EXECUTION_DIR}): {executions_parent_dir_count}")

# Next free id: 0 when there is no previous execution, otherwise highest id + 1
execution_id = max(executions_ids, default=-1) + 1

print(f"Assigned execution_id: {execution_id}")

# Create a unique experiment configuration identifier
experiment_config_id = (
    f"id={execution_id}-"
    f"preprocessing=[{PREPROCESSING_TECHNIQUE}]-"
    f"preprocessing_channels=[{PREPROCESSING_CHANNELS}]-"
    f"batch_training=[{BATCH_SIZE_TRAINING}]"
)

# Define and create the execution directory (parents are created as needed)
execution_dir = os.path.join(parent_execution_dir_path, experiment_config_id)
os.makedirs(execution_dir, exist_ok=True)

print(f"Created execution directory: {execution_dir}")


# GENERATING SETTINGS JSON

In [None]:
import json

# Assemble the run configuration section by section. The nesting and key order
# used here are exactly what ends up in settings.json.

execution_section = {
    # Directory under which the results of this execution are stored
    "PARENT_EXECUTION_DIR": PARENT_EXECUTION_DIR,
}

data_section = {
    # Location of the samples and of the metadata file
    "DATA_DIR": DATA_DIR,
    # Metadata file (inside DATA_DIR) describing the samples of the dataset
    "METADATA_DATASET": METADATA_DATASET,
    # Percentage of the dataset held out for testing
    "TEST_PERCENTAGE_TRAIN_TEST_SPLIT": TEST_PERCENTAGE_TRAIN_TEST_SPLIT,
    # Number of folds for k-fold cross-validation
    "CROSS_VALIDATION_K_FOLDS": CROSS_VALIDATION_K_FOLDS,
    # Number of DataLoader workers
    "NUM_WORKERS": NUM_WORKERS,
}

training_inference_section = {
    "BATCH_SIZE_TRAINING": BATCH_SIZE_TRAINING,
    "BATCH_SIZE_INFERENCE": BATCH_SIZE_INFERENCE,
    # Epochs without improvement tolerated before early stopping
    "PATIENCE_EARLY_STOPPING": PATIENCE_EARLY_STOPPING,
    # Upper bound on the number of training epochs
    "MAX_TRAINING_EPOCHS": MAX_TRAINING_EPOCHS,
    "LEARNING_RATE": LEARNING_RATE,
    "WEIGHT_DECAY": WEIGHT_DECAY,
    # Warm-up length of the learning-rate scheduler, in epochs
    "WARMUP_EPOCHS_LR_SCHEDULER": WARMUP_EPOCHS_LR_SCHEDULER,
}

data_processing_section = {
    "PREPROCESSING_TECHNIQUE": PREPROCESSING_TECHNIQUE,
    "PREPROCESSING_CHANNELS": PREPROCESSING_CHANNELS,
}

settings = {
    "EXECUTION_SETTINGS": execution_section,
    "DATA_SETTINGS": data_section,
    "TRAINING_INFERENCE_SETTINGS": training_inference_section,
    "DATA_PROCESSING_SETTINGS": data_processing_section,
    "MODEL_SETTINGS": {},
}

# The settings file lives next to the results of this execution
settings_json_file_path = os.path.join(execution_dir, "settings.json")

# Serialize with a 4-space indent and write the text out in one go
with open(settings_json_file_path, "w") as json_file:
    json_file.write(json.dumps(settings, indent=4))

# Print success message
print(f"Settings JSON file successfully generated at: {settings_json_file_path}")

# DATA SPLITTING

In [None]:
import sys
import importlib

# Import the 'utils.data_splitter' module itself so that it can be reloaded
import utils.data_splitter

# Re-execute 'utils.data_splitter' so that the functions imported below come from
# its current source (useful when the file was edited after the first import)
importlib.reload(utils.data_splitter)

# Import the helper functions from the freshly reloaded module
from utils.data_splitter import load_metadata, stratified_split, k_fold_stratified_split
import os

# Load the dataset metadata from DATA_DIR/METADATA_DATASET
# (presumably a dict keyed by sample id -- confirm in utils.data_splitter)
metadata = load_metadata(os.path.join(DATA_DIR, METADATA_DATASET))
# Stratified train/test split; test_size is the held-out fraction (percentage / 100)
train_files, test_files = stratified_split(metadata, test_size = TEST_PERCENTAGE_TRAIN_TEST_SPLIT/100)
# Split the training part into k folds; the training cell iterates the result
# as (train_files, validation_files) pairs
k_fold_training_splits = k_fold_stratified_split(train_files, k=CROSS_VALIDATION_K_FOLDS)

# Import UNET_VGG11 LIGHTNING MODEL

In [None]:
import importlib
from models.unet import UNET_VGG11_lightning_model

# Re-execute the module so that the class imported below comes from its current source
# (useful when the file was edited after the kernel first imported it)
importlib.reload(UNET_VGG11_lightning_model)

# Import the model class from the reloaded module
from models.unet.UNET_VGG11_lightning_model import UNET_VGG11_LightningModel

# Import UNET_VGG11 DATASET

In [None]:
import importlib
from models.unet import UNET_VGG11_dataset

# Re-execute the module so that the class imported below comes from its current source
# (useful when the file was edited after the kernel first imported it)
importlib.reload(UNET_VGG11_dataset)

# Import the dataset class from the reloaded module
from models.unet.UNET_VGG11_dataset import UNET_VGG11_Dataset


# TRAINING

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
import pandas as pd
import time
import shutil
import gc

# Per-split bookkeeping, filled in by the training loop that follows
training_times_splits = []      # wall-clock training time of each split, in seconds
validation_losses_splits = []   # validation loss of the best checkpoint of each split

# Start from +inf so that the first finished split always becomes the best one.
# A finite sentinel (such as 1) would leave best_split as None whenever every
# split ends with a larger loss, and the testing cell would then have no model.
best_split_validation_loss = float("inf")
best_split = None

# Start of the whole cross-validation run (training + plotting of every split)
start_execution_time = time.time()

# Train one model per cross-validation split. For every split the loop:
#   1. builds the training/validation datasets and loaders,
#   2. fits a fresh model with early stopping and best-checkpoint saving,
#   3. plots the training/validation loss curves from the CSV log,
#   4. reloads the best checkpoint and plots its predictions on the validation patches,
#   5. records the split's best validation loss and tracks the best split so far.
for split_idx, (train_files, validation_files) in enumerate(k_fold_training_splits):

    # Clear the GPU cache between splits
    torch.cuda.empty_cache()

    ####################################################################################
    ############################# DATA PREPARATION #####################################
    ####################################################################################

    print()
    print(f"{'=' * 50}")
    print(f"{'=' * 20} SPLIT: {split_idx} {'=' * 20}")
    print(f"{'=' * 50}")
    print()

    print()
    print(f"{'=' * 15}> TRAINING DATA PREPARATION")
    print()

    print(f" SAMPLES: {train_files.keys()}")
    print()
    print()
    print()

    train_dataset = UNET_VGG11_Dataset(
        metadata_dict_with_files_selected=train_files,
        data_dir=DATA_DIR,
        preprocessing_technique=PREPROCESSING_TECHNIQUE,
        preprocessing_channels=PREPROCESSING_CHANNELS,
    )

    print()
    print()
    print(f"{'=' * 15}> VALIDATION DATA PREPARATION")
    print()

    print(f" SAMPLES: {validation_files.keys()}")
    print()
    print()
    print()

    val_dataset = UNET_VGG11_Dataset(
        metadata_dict_with_files_selected=validation_files,
        data_dir=DATA_DIR,
        preprocessing_technique=PREPROCESSING_TECHNIQUE,
        preprocessing_channels=PREPROCESSING_CHANNELS,
    )

    # Only the training loader shuffles; validation keeps the dataset order
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE_TRAINING, shuffle=True, num_workers=NUM_WORKERS)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE_INFERENCE, num_workers=NUM_WORKERS)

    print()

    ####################################################################################
    ############################# TRAINER SETTING ######################################
    ####################################################################################

    # Everything produced for this split (checkpoint, metrics.csv, plots) goes here
    split_dir = os.path.join(execution_dir, f"split_{split_idx}")

    # Model checkpoint callback: keep only the checkpoint with the lowest val_loss
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath=split_dir,
        filename="best-checkpoint-{epoch:02d}-{val_loss:.2f}",
        save_top_k=1,
        mode='min'
    )

    # Early stopping callback
    early_stopping_callback = EarlyStopping(
        monitor='val_loss',
        patience=PATIENCE_EARLY_STOPPING,
        mode='min'
    )

    # CSV logger; the metrics file is read back below from split_dir/metrics.csv
    logger = CSVLogger(
        save_dir=split_dir,
        name="",
        version="")

    trainer = Trainer(
        max_epochs=MAX_TRAINING_EPOCHS,
        callbacks=[checkpoint_callback, early_stopping_callback],
        log_every_n_steps=1,  # log every n batches
        logger=logger)

    ####################################################################################
    ############################### TRAINING ###########################################
    ####################################################################################

    print()
    print(f"{'=' * 10} TRAINING STARTS {'=' * 10}")
    print()

    # Measure start training time
    start_time = time.time()

    # Define and initialize a fresh model for this split
    model = UNET_VGG11_LightningModel(
        lr_optimizer=LEARNING_RATE,
        weight_decay_optimizer=WEIGHT_DECAY,
        warmup_epochs_lr_scheduler=WARMUP_EPOCHS_LR_SCHEDULER
    )
    # Fit the model
    trainer.fit(model, train_loader, val_loader)

    # Measure end training time
    end_time = time.time()

    # Calculate and print the elapsed time
    elapsed_time = end_time - start_time

    training_times_splits.append(elapsed_time)

    print(f"Training completed in {elapsed_time // 3600:.0f}h {elapsed_time % 3600 // 60:.0f}m {elapsed_time % 60:.0f}s")

    print()
    print(f"{'=' * 10} TRAINING FINISHED {'=' * 10}")
    print()

    ####################################################################################
    ################## PLOTTING TRAINING & VALIDATION LOSSES ###########################
    ####################################################################################

    print()
    print(f"{'=' * 10} PLOTTING TRAINING & VALIDATION LOSSES EVOLUTION {'=' * 10}")
    print()

    # Path to the metrics file written by the CSV logger
    metrics_file = os.path.join(split_dir, 'metrics.csv')

    # Load the logged metrics
    metrics_df = pd.read_csv(metrics_file)

    # Keep only the rows where each epoch-level loss is present (not NaN)
    train_loss = metrics_df['train_loss_epoch'].dropna()
    val_loss = metrics_df['val_loss_epoch'].dropna()

    # Use the 'epoch' column of those same rows as x-axis
    epochs_train = metrics_df.loc[metrics_df['train_loss_epoch'].notna(), 'epoch']
    epochs_val = metrics_df.loc[metrics_df['val_loss_epoch'].notna(), 'epoch']

    # Plot Training and Validation Loss
    loss_fig = plt.figure(figsize=(10, 6))
    plt.plot(epochs_train, train_loss, label="Training Loss")
    plt.plot(epochs_val, val_loss, label="Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title(experiment_config_id + "-train_patches=[" + str(len(train_dataset)) + "]" + "\n Training and Validation Loss Over Epochs")
    plt.legend()
    plt.grid(True)

    # Save the plot to the split folder
    save_path = os.path.join(split_dir, f"loss-evol_{experiment_config_id}.png")
    plt.savefig(save_path, dpi=300, bbox_inches='tight')  # Save as PNG with high resolution
    print(f"Plot saved at: {save_path}")

    plt.show()

    # Release the figure so open figures do not pile up across splits
    plt.close(loss_fig)

    ####################################################################################
    ###################### PLOTTING BEST_CHECKPOINT RESULTS ############################
    ####################################################################################

    print()
    print(f"{'=' * 10} PLOTTING BEST_CHECKPOINT RESULTS {'=' * 10}")
    print()

    # Path and file name of the best checkpoint
    best_checkpoint_path = checkpoint_callback.best_model_path
    best_checkpoint_name = os.path.basename(best_checkpoint_path)

    # Extract the best epoch number from the checkpoint file name,
    # e.g. best-checkpoint-epoch=02-val_loss=0.85.ckpt -> 2.0
    best_epoch = float(best_checkpoint_name.split("epoch=")[1].split("-")[0])

    # Get the epoch-level validation metrics logged for the best epoch
    metrics_for_best_epoch = metrics_df.loc[metrics_df['epoch'] == best_epoch]
    val_loss_epoch = metrics_for_best_epoch.loc[metrics_for_best_epoch['val_loss_epoch'].notna(), 'val_loss_epoch'].values[0]
    val_mean_iou_epoch = metrics_for_best_epoch.loc[metrics_for_best_epoch['val_mean_iou_epoch'].notna(), 'val_mean_iou_epoch'].values[0]
    val_dice_epoch = metrics_for_best_epoch.loc[metrics_for_best_epoch['val_dice_epoch'].notna(), 'val_dice_epoch'].values[0]
    val_fpr_epoch = metrics_for_best_epoch.loc[metrics_for_best_epoch['val_fpr_epoch'].notna(), 'val_fpr_epoch'].values[0]

    # Title line summarising the metrics of the best epoch
    plot_title_template = (
        f"Epoch: {best_epoch:.1f}, "
        f"val_loss_epoch: {val_loss_epoch:.4f}, "
        f"val_mean_iou_epoch: {val_mean_iou_epoch:.4f}, "
        f"val_dice_epoch: {val_dice_epoch:.4f}, "
        f"val_fpr_epoch: {val_fpr_epoch:.4f}"
    )

    # Load the best checkpoint
    best_model = UNET_VGG11_LightningModel.load_from_checkpoint(
        checkpoint_path=best_checkpoint_path,
        lr_optimizer=LEARNING_RATE,
        weight_decay_optimizer=WEIGHT_DECAY,
        warmup_epochs_lr_scheduler=WARMUP_EPOCHS_LR_SCHEDULER
    )
    best_model.eval()  # Set model to evaluation mode
    best_model.freeze()  # Freeze the model

    # Initialize lists to collect all results
    all_ground_truth = []
    all_predictions = []
    all_sample_ids = []

    # Inference on validation set
    for batch_idx, batch in enumerate(val_loader):
        sample_ids, x, y = batch  # sample_ids shape: (batch_size, 1), x shape: (batch_size, input_channels, height, width), y shape: (batch_size, output_channels, height, width)

        x = x.to(best_model.device)
        y = y.to(best_model.device)

        # Forward pass through the model
        y_hat = best_model(x)  # y_hat shape: (batch_size, output_channels, height, width)

        ##############  POST-PROCESSING   ###############

        # Turn the raw outputs into per-class probabilities
        y_hat_probabilities = F.softmax(y_hat, dim=1)

        ############## (FINAL OUTPUT SEGMENTATION) POST-PROCESSING   ###############

        for i in range(0, x.shape[0]):
            # Classify each pixel by assigning the class with the highest probability
            y_hat_categorized = torch.argmax(y_hat_probabilities[i], dim=0)
            # Convert the one-hot encoded y into a categorical tensor
            y_categorized = torch.argmax(y[i], dim=0)

            # Append to the global lists
            all_sample_ids.append(sample_ids[i])
            all_ground_truth.append(y_categorized)
            all_predictions.append(y_hat_categorized)

    # Find unique classes dynamically
    all_classes = set()
    for gt, pred in zip(all_ground_truth, all_predictions):
        all_classes.update(torch.unique(gt).tolist())  # Add classes from ground truth
        all_classes.update(torch.unique(pred).tolist())  # Add classes from predictions

    # Sort the classes to ensure order
    all_classes = sorted(all_classes)

    # Define class labels dynamically (for simplicity, use numeric labels for now)
    class_labels = {cls: f"Class {cls}" for cls in all_classes}
    num_classes = len(class_labels)

    # The images below are drawn with the continuous "viridis" map scaled to
    # [lowest_class, highest_class]. The legend colours are computed with the same
    # scaling so that they match the images even when class ids are not 0..N-1.
    lowest_class, highest_class = min(all_classes), max(all_classes)
    class_span = (highest_class - lowest_class) or 1  # avoid 0/0 with a single class
    colormap = plt.get_cmap("viridis")

    legend_patches = [
        mpatches.Patch(color=colormap((cls - lowest_class) / class_span), label=f"{cls}: {label}")
        for cls, label in class_labels.items()
    ]

    # Plot all results together: one row per validation patch
    num_samples = len(all_ground_truth)

    # Create the figure
    fig, axes = plt.subplots(num_samples, 2, figsize=(10, 5 * num_samples))

    # Ensure axes is always 2D
    if num_samples == 1:
        axes = axes[None, :]  # Ensure axes is 2D when there's only one sample

    # Plot ground truth (left column) and predictions (right column)
    for sample_idx, (sample_id, ground_truth, prediction) in enumerate(zip(all_sample_ids, all_ground_truth, all_predictions)):
        # Ground truth
        axes[sample_idx, 0].imshow(ground_truth.cpu().numpy(), cmap="viridis", interpolation="none", vmin=lowest_class, vmax=highest_class)
        axes[sample_idx, 0].set_title(f"Sample {sample_id} - Ground Truth")
        axes[sample_idx, 0].axis("off")

        # Prediction
        axes[sample_idx, 1].imshow(prediction.cpu().numpy(), cmap="viridis", interpolation="none", vmin=lowest_class, vmax=highest_class)
        axes[sample_idx, 1].set_title(f"Sample {sample_id} - Prediction")
        axes[sample_idx, 1].axis("off")

    # Add a single legend for the entire figure
    fig.legend(
        handles=legend_patches,
        loc="upper center",
        ncol=num_classes,
        bbox_to_anchor=(0.5, 1.02),
        fontsize=12
    )

    # Add a title for the entire figure
    fig.suptitle("RESULTS BEST CHECKPOINT\n" + plot_title_template, fontsize=16, y=1.10)
    plt.tight_layout()

    # Save the combined plot next to the checkpoint it describes
    combined_plot_path = os.path.join(split_dir, f"results_{best_checkpoint_name}.png")
    plt.savefig(combined_plot_path, dpi=300, bbox_inches="tight")
    plt.show()

    plt.close(fig)

    # Drop the references to tensors and models of this split before the next one starts
    del y
    del y_hat
    del x
    del model
    del best_model
    del trainer
    torch.cuda.empty_cache()
    gc.collect()

    ####################################################################################
    ############################ SPLIT PERFORMANCE COMPARISON ###########################
    ####################################################################################

    validation_losses_splits.append(val_loss_epoch)
    if val_loss_epoch < best_split_validation_loss:
        best_split = split_idx
        best_split_validation_loss = val_loss_epoch
        # The directories of all splits are kept on disk; the testing cell picks
        # the checkpoint from split_{best_split}.

# Wall-clock duration of the whole cross-validation run
end_execution_time = time.time()
elapsed_execution_time = end_execution_time - start_execution_time

####################################################################################
############################ ANALYSING SPLIT PERFORMANCE ############################
####################################################################################

# One x position per trained split
splits = np.arange(len(training_times_splits))

# Summary statistics of the per-split validation losses
avg_validation_loss = np.mean(validation_losses_splits)
std_dev_validation_loss = np.std(validation_losses_splits, ddof=1)  # ddof=1 -> sample standard deviation
# Coefficient of Variation (CV), expressed as a percentage
cv_validation_loss = 100 * (std_dev_validation_loss / avg_validation_loss)

# Break the total duration down for the title
elapsed_hours = elapsed_execution_time // 3600
elapsed_minutes = elapsed_execution_time % 3600 // 60
elapsed_seconds = elapsed_execution_time % 60

title_lines = [
    f"{experiment_config_id}-training_samples=[{len(train_dataset)}]",
    "Training Times and Validation Losses Across Splits",
    f"Complete Execution Time (Including Plotting): {elapsed_hours:.0f}h {elapsed_minutes:.0f}m {elapsed_seconds:.0f}s",
    f"Best Split: {best_split} - Validation Loss: {best_split_validation_loss:.4f}",
    f"Averaged Validation Loss: {avg_validation_loss:.4f} - Coefficient of Variation (CV): {cv_validation_loss:.2f}%",
]
plot_title = "\n".join(title_lines)

# Combined figure: bars for training time, line for validation loss
fig, ax1 = plt.subplots(figsize=(10, 6))

# Left axis: training time of each split
ax1.bar(splits, training_times_splits, color='lightgrey', label='Training Time (s)')
ax1.set_xlabel("Split Index")
ax1.set_ylabel("Training Time (seconds)", color='grey')
ax1.tick_params(axis='y', labelcolor='grey')

# Right axis: validation loss of each split, with the best split highlighted in red
ax2 = ax1.twinx()
ax2.plot(splits, validation_losses_splits, marker='o', color='blue', label='Validation Loss')
ax2.scatter(
    best_split,
    best_split_validation_loss,
    color='red',
    s=100,
    zorder=3,
    label=f'Best Split (Split {best_split})',
)
ax2.set_ylabel("Validation Loss", color='blue')
ax2.tick_params(axis='y', labelcolor='blue')

# Title, layout, legend and grid
fig.suptitle(plot_title)
fig.tight_layout()
fig.legend(loc="upper left", bbox_to_anchor=(0.017, 1.00))
ax1.grid(True)

# Save the comparison plot at the top level of the execution directory
plot_path = os.path.join(execution_dir, "splits_comparison.png")
fig.savefig(plot_path, dpi=300, bbox_inches="tight")
plt.show()

plt.close(fig)



# TESTING

In [None]:
import glob

####################################################################################
################################### DATA PREPARATION ###############################
####################################################################################

# Section banner: blank line, banner, blank line
print("", f"{'=' * 10} TESTING DATA PREPARATION {'=' * 10}", "", sep="\n")

# List the held-out samples, followed by a blank line
print(f" SAMPLES: {test_files.keys()}", "", sep="\n")

# Dataset over the held-out test samples, built with the same preprocessing as training
test_dataset_kwargs = {
    "metadata_dict_with_files_selected": test_files,
    "data_dir": DATA_DIR,
    "preprocessing_technique": PREPROCESSING_TECHNIQUE,
    "preprocessing_channels": PREPROCESSING_CHANNELS,
}
test_dataset = UNET_VGG11_Dataset(**test_dataset_kwargs)

# Loader without shuffling, using the inference batch size
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE_INFERENCE,
    num_workers=NUM_WORKERS,
)

####################################################################################
############################# TRAINER SETTING ######################################
####################################################################################

# CSV logger rooted at the execution directory, and a Trainer used only for testing
logger = CSVLogger(save_dir=execution_dir, name="", version="")
trainer = Trainer(logger=logger)

####################################################################################
############################## TESTING #############################################
####################################################################################

print("", f"{'=' * 10} TESTING STARTS {'=' * 10}", "", sep="\n")

# Start of the timed section (checkpoint lookup, loading and evaluation)
start_time = time.time()

# Directory of the split that achieved the lowest validation loss
best_split_dir = os.path.join(execution_dir, f"split_{best_split}")

# Exactly one checkpoint file is expected in that directory
ckpt_files = [entry for entry in os.listdir(best_split_dir) if entry.endswith('.ckpt')]

if not ckpt_files:
    raise FileNotFoundError(f"No .ckpt file found in the directory: {best_split_dir}")
if len(ckpt_files) > 1:
    raise RuntimeError(f"Multiple .ckpt files found in the directory: {best_split_dir} -> {ckpt_files}")

model_checkpoint_path = os.path.join(best_split_dir, ckpt_files[0])
print(f"Found checkpoint: {model_checkpoint_path}")

# Rebuild the model from the checkpoint with the same hyper-parameters used for training
best_model = UNET_VGG11_LightningModel.load_from_checkpoint(
    checkpoint_path=model_checkpoint_path,
    lr_optimizer=LEARNING_RATE,
    weight_decay_optimizer=WEIGHT_DECAY,
    warmup_epochs_lr_scheduler=WARMUP_EPOCHS_LR_SCHEDULER,
)

# Run the test loop; the first (and only) entry holds the metrics of test_loader
test_results = trainer.test(best_model, dataloaders=test_loader)[0]

# End of the timed section
end_time = time.time()
testing_elapsed_time = end_time - start_time

print(f"Testing completed in {testing_elapsed_time // 3600:.0f}h {testing_elapsed_time % 3600 // 60:.0f}m {testing_elapsed_time % 60:.0f}s")

print("", f"{'=' * 10} TESTING FINISHED {'=' * 10}", "", sep="\n")


####################################################################################
############################## PLOTTING TEST RESULTS ###############################
####################################################################################

# Per-sample outputs collected by the model during trainer.test(), plus the
# aggregated metrics returned by it
test_ground_truths = best_model.test_ground_truths
test_predictions = best_model.test_predictions
test_sample_ids = best_model.test_sample_ids
test_loss = test_results["test_loss_epoch"]
test_mean_iou = test_results["test_mean_iou_epoch"]
test_dice_coeff = test_results["test_dice_epoch"]
test_fpr = test_results["test_fpr_epoch"]

# Find unique classes dynamically
all_classes = set()
for gt, pred in zip(test_ground_truths, test_predictions):
    all_classes.update(torch.unique(gt).tolist())  # Add classes from ground truth
    all_classes.update(torch.unique(pred).tolist())  # Add classes from predictions

# Sort the classes to ensure order
all_classes = sorted(all_classes)

# Define class labels dynamically (for simplicity, use numeric labels for now)
class_labels = {cls: f"Class {cls}" for cls in all_classes}
num_classes = len(class_labels)

# The images below are drawn with the continuous "viridis" map scaled to
# [lowest_class, highest_class]. The legend colours are computed with the same
# scaling so that they match the images even when class ids are not 0..N-1.
lowest_class, highest_class = min(all_classes), max(all_classes)
class_span = (highest_class - lowest_class) or 1  # avoid 0/0 with a single class
colormap = plt.get_cmap("viridis")

legend_patches = [
    mpatches.Patch(color=colormap((cls - lowest_class) / class_span), label=f"{cls}: {label}")
    for cls, label in class_labels.items()
]

# Plot all results together: one row per test patch
num_samples = len(test_ground_truths)

# Create the figure
fig, axes = plt.subplots(num_samples, 2, figsize=(10, 5 * num_samples))

# Ensure axes is always 2D
if num_samples == 1:
    axes = axes[None, :]  # Ensure axes is 2D when there's only one sample

# Plot ground truth (left column) and predictions (right column)
for sample_idx, (sample_id, ground_truth, prediction) in enumerate(zip(test_sample_ids, test_ground_truths, test_predictions)):
    # Ground truth
    axes[sample_idx, 0].imshow(ground_truth.cpu().numpy(), cmap="viridis", interpolation="none", vmin=lowest_class, vmax=highest_class)
    axes[sample_idx, 0].set_title(f"Sample {sample_id} - Ground Truth")
    axes[sample_idx, 0].axis("off")

    # Prediction
    axes[sample_idx, 1].imshow(prediction.cpu().numpy(), cmap="viridis", interpolation="none", vmin=lowest_class, vmax=highest_class)
    axes[sample_idx, 1].set_title(f"Sample {sample_id} - Prediction")
    axes[sample_idx, 1].axis("off")

# Add a single legend for the entire figure
fig.legend(
    handles=legend_patches,
    loc="upper center",
    ncol=num_classes,
    bbox_to_anchor=(0.5, 0.99),
    fontsize=12
)

# Title lines: testing duration and aggregated test metrics
plot_title_template = (
    f"Testing Execution Time (without Plotting): {testing_elapsed_time // 3600:.0f}h {testing_elapsed_time % 3600 // 60:.0f}m {testing_elapsed_time % 60:.0f}s\n"
    f"test_loss: {test_loss:.4f}, "
    f"test_mean_iou: {test_mean_iou:.4f}, "
    f"test_dice_coeff: {test_dice_coeff:.4f}, "
    f"test_fpr: {test_fpr:.4f} "
)

# Add a title for the entire figure
fig.suptitle(f"TESTING RESULTS (Model from best split = {best_split})\n" + plot_title_template, fontsize=16, y=1.03)
plt.tight_layout()

# Save the combined plot at the top level of the execution directory
combined_plot_path = os.path.join(execution_dir, f"test_results_model_best_split[{best_split}].png")
plt.savefig(combined_plot_path, dpi=300, bbox_inches="tight")
plt.show()

plt.close(fig)
