In [2]:
# ONLY HAS TO BE RUN ONCE TO EXPORT DATASET FROM ZIP TO FOLDER
import zipfile

zip_path = "Resources.zip" 
extract_to = "Dataset"

with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_to)

In [1]:
# ONLY HAS TO BE RUN ONCE TO ENABLE 5 FOLD CROSS VALIDATION
import os
import shutil

# Making a dictionary to store the disease class of each patient
dataset_path = "Dataset/training"  # Path to train + validation folder
patient_class_dict = {}

for patient_folder in os.listdir(dataset_path): # For file in the training folder
    if patient_folder.startswith("."):  # Skip hidden folders (.ipynb_checkpoints)
        continue

    patient_path = os.path.join(dataset_path, patient_folder) # Patient ID, such as patient001

    if os.path.isdir(patient_path):  # Process only valid patient folders (skip MANDATATORY_CITATION)
        info_file = os.path.join(patient_path, "Info.cfg")
        with open(info_file, "r") as patient_file: # Open file
                lines = patient_file.readlines()
                patient_class = lines[2].strip()
                patient_class_dict[patient_folder] = patient_class # Add to dictionary
        
# Splitting dataset based on dictionary values (20 patients in each dataset)

group_DCM = "Dataset/group_DCM"
group_HCM = "Dataset/group_HCM"
group_MINF = "Dataset/group_MINF"
group_NOR = "Dataset/group_NOR"
group_RV = "Dataset/group_RV"

# Create the directories for each class
for group in [group_DCM, group_HCM, group_MINF, group_NOR, group_RV]:
    if not os.path.exists(group):
        os.makedirs(group)

# Loop through all patients
for patient_folder, disease in patient_class_dict.items():
    patient_path = os.path.join(dataset_path, patient_folder)

    # Check if the folder exists and it's a directory
    if os.path.isdir(patient_path):
        # Determine the target group based on disease
        if "DCM" in disease:
            target_group = group_DCM
        elif "HCM" in disease:
            target_group = group_HCM
        elif "MINF" in disease:
            target_group = group_MINF
        elif "NOR" in disease:
            target_group = group_NOR
        elif "RV" in disease:
            target_group = group_RV
        else:
            print('unknown class error')
            continue

        # Create the patient's folder inside the target group directory
        target_patient_folder = os.path.join(target_group, patient_folder)
        if not os.path.exists(target_patient_folder):
            os.makedirs(target_patient_folder)

        # Copy respective files to new folder
        for file_name in os.listdir(patient_path):
            file_path = os.path.join(patient_path, file_name)
            if os.path.isfile(file_path):  # Check if it's a file
                # Move the file to the respective patient folder in the group folder
                shutil.copytree(file_path, os.path.join(target_patient_folder, file_name)) # YOU CAN ONLY RUN THIS ONCE, AFTER THAT THE TRAINING SET IS EMPTY



NotADirectoryError: [Errno 20] Not a directory: 'Dataset/training/patient001/Info.cfg'

In [1]:
## ALL FUNCTION NEEDED TO TRAIN AND VALIDATE MODEL (Has to be run before any code segment)
import os
import shutil
data_path_train = "Dataset/new_training"
data_path_valid = "Dataset/new_validation"
data_path_test = "Dataset/testing"

import glob
import nibabel as nib
import numpy as np
import monai
import torch.nn.functional as F
from medpy.metric.binary import hd, dc

import time
import torch

from monai.data import CacheDataset, DataLoader
from monai.networks.nets import UNet
from monai.losses import DiceLoss
from monai.metrics import DiceMetric

from monai.transforms import (
    Compose,
    Lambda,
    LoadImaged,         # if using file paths
    AddChanneld,        # ensures channel-first format
    ScaleIntensityd,    # normalizes intensity to [minv, maxv]
    Spacingd,           # resamples to a uniform voxel spacing
    ResizeWithPadOrCropd,  # resizes images/masks to a fixed spatial size
    EnsureTyped,        # converts arrays to PyTorch tensors
    RandZoomd,          # random zoom augmentation
    RandFlipd,          # random flip augmentation
    RandRotated,        # random rotation augmentation
    RandShiftIntensityd # random intensity shift for brightness variation
)

# Start up wandb and start logging
import wandb



print("Done Importing!")

def get_ed_es_frames(config_path):
    """Extract ED and ES frame numbers from the info.cfg file."""

    ed_frame, es_frame = None, None
    with open(config_path, 'r') as f:
        for line in f:
            if line.startswith('ED:'):
                ed_frame = int(line.split(':')[1].strip())
            elif line.startswith('ES:'):
                es_frame = int(line.split(':')[1].strip())
    return ed_frame, es_frame


def build_dict_acdc(data_path, mode='train'):
    """
    This function returns a list of dictionaries, each containing the paths to the 2D slices 
    of the 3D MRI images and their corresponding masks.
    """
    if mode not in ["train", "val", "test"]:
        raise ValueError(f"Please choose a mode in ['train', 'val', 'test']. Current mode is {mode}.")
    
    dicts = []
    
    # Loop over all patient directories
    patient_dirs = [d for d in glob.glob(os.path.join(data_path, '*')) if os.path.isdir(d)]
    
    for patient_dir in patient_dirs:
        patient_id = os.path.basename(patient_dir)
        config_path = os.path.join(patient_dir, "Info.cfg")
        
        if not os.path.exists(config_path):
            continue
        
        ed_frame, es_frame = get_ed_es_frames(config_path)
        
        # Identify the ED and ES image and mask paths
        ed_img_path = os.path.join(patient_dir, f"{patient_id}_frame{ed_frame:02d}.nii.gz")
        ed_mask_path = os.path.join(patient_dir, f"{patient_id}_frame{ed_frame:02d}_gt.nii.gz")
        es_img_path = os.path.join(patient_dir, f"{patient_id}_frame{es_frame:02d}.nii.gz")
        es_mask_path = os.path.join(patient_dir, f"{patient_id}_frame{es_frame:02d}_gt.nii.gz")
        
        for img_path, mask_path in [(ed_img_path, ed_mask_path), (es_img_path, es_mask_path)]:
            if not os.path.exists(img_path) or not os.path.exists(mask_path):
                continue
            
            # Load the 3D image and mask using nibabel
            img_volume = nib.load(img_path).get_fdata()
            mask_volume = nib.load(mask_path).get_fdata()
            #print("Unique values in loaded ground truth mask:", np.unique(mask_volume))
            
            # Ensure we have the same number of slices for image and mask
            num_slices = img_volume.shape[2]
            
            # Extract 2D slices
            for slice_idx in range(num_slices):
                img_slice = img_volume[:, :, slice_idx]
                mask_slice = mask_volume[:, :, slice_idx]
                
                dicts.append({'img': img_slice, 'mask': mask_slice})
    
    return dicts

class LoadHeartData(monai.transforms.Transform):
    """
    This custom Monai transform loads 2D slices of MRI data and their corresponding mask for heart segmentation.
    """
    def __init__(self, keys=None):
        pass
    
    def __call__(self, sample):
        img_slice = sample['img']
        mask_slice = sample['mask'] 
        
        # Ensure the image and mask are in compatible formats
        img_slice = np.array(img_slice, dtype=np.float32)
        mask_slice = np.array(mask_slice, dtype=np.uint8) 
        
        # Return the slice and mask with metadata. NOT SURE ABOUT THE METATDATA
        return {'img': img_slice, 'mask': mask_slice, 'img_meta_dict': {'affine': np.eye(2)}, 
                'mask_meta_dict': {'affine': np.eye(2)}}

HEADER = ["Name", "Dice LV", "Volume LV", "Err LV(ml)",
          "Dice RV", "Volume RV", "Err RV(ml)",
          "Dice MYO", "Volume MYO", "Err MYO(ml)"]

#
# Functions to process files, directories and metrics aka loss function
#
def metrics(img_gt, img_pred, voxel_size):
    """
    Function to compute the metrics between two segmentation maps given as input.

    Parameters
    ----------
    img_gt: np.array
    Array of the ground truth segmentation map.

    img_pred: np.array
    Array of the predicted segmentation map.

    voxel_size: list, tuple or np.array
    The size of a voxel of the images used to compute the volumes.

    Return
    ------
    A list of metrics in this order, [Dice LV, Volume LV, Err LV(ml),
    Dice RV, Volume RV, Err RV(ml), Dice MYO, Volume MYO, Err MYO(ml)]
    """

    if img_gt.ndim != img_pred.ndim:
        raise ValueError("The arrays 'img_gt' and 'img_pred' should have the "
                         "same dimension, {} against {}".format(img_gt.ndim,
                                                                img_pred.ndim))
    #print("Unique values in ground truth:", np.unique(img_gt))
    #print("Unique values in prediction:", np.unique(img_pred))
    
    res = []
    # Loop on each classes of the input images
    for c in [3, 1, 2]:
        # Copy the gt image to not alterate the input
        gt_c_i = np.copy(img_gt)
        gt_c_i[gt_c_i != c] = 0

        # Copy the pred image to not alterate the input
        pred_c_i = np.copy(img_pred)
        pred_c_i[pred_c_i != c] = 0

        # Clip the value to compute the volumes
        gt_c_i = np.clip(gt_c_i, 0, 1)
        pred_c_i = np.clip(pred_c_i, 0, 1)

        # Compute the Dice
        dice = dc(gt_c_i, pred_c_i)

        # Compute volume
        volpred = pred_c_i.sum() * np.prod(voxel_size) / 1000.
        volgt = gt_c_i.sum() * np.prod(voxel_size) / 1000.

        res += [dice, volpred, volpred-volgt]

    return res

def compute_metrics_on_files(path_gt, path_pred):
    """
    Function to give the metrics for two files

    Parameters
    ----------

    path_gt: string
    Path of the ground truth image.

    path_pred: string
    Path of the predicted image.
    """
    gt, _, header = load_nii(path_gt)
    pred, _, _ = load_nii(path_pred)
    zooms = header.get_zooms()

    name = os.path.basename(path_gt)
    name = name.split('.')[0]
    res = metrics(gt, pred, zooms)
    res = ["{:.3f}".format(r) for r in res]

    formatting = "{:>14}, {:>7}, {:>9}, {:>10}, {:>7}, {:>9}, {:>10}, {:>8}, {:>10}, {:>11}"
    print(formatting.format(*HEADER))
    print(formatting.format(name, *res))

    
# Note that this is a clear, but slow way to do this, we might be better off with a quick hardcode since the patients are ordered anyways.
# So technically, the first 20 of train is just group DCM and we dont need to make any folders or copy any files.
# this could also be sped up by moving per patient and not per file per patient (the last loop)
# DOOR ERIC        
# Recombine into a training and a validation set (set 1 to validation and 4 to training)
def recombining_data(recombine_index):
    
    new_train_path = os.path.join("Dataset", 'new_training')
    new_val_path = os.path.join("Dataset", 'new_validation')
    
    # If folder does not exist yet
    if not os.path.exists(new_train_path):
        os.makedirs(new_train_path)  # Creates the new training folder
    if not os.path.exists(new_val_path):
        os.makedirs(new_val_path)  # Creates the new validation folder
    
    
    # Empty the new_validation folder
    if os.path.exists(new_val_path):
        shutil.rmtree(new_val_path)
        os.makedirs(new_val_path)

    # Empty the new_training folder
    if os.path.exists(new_train_path):
        shutil.rmtree(new_train_path)
        os.makedirs(new_train_path)
    
   
    val_id = [1,2,3,4]
    offset = (recombine_index - 1) * 4
    val_id = [element + offset for element in val_id] # Add the offset to each element of val_id
    #print(val_id)
    
    train_id = list(range(1,21))
    for element in val_id:
        train_id.remove(element) # remove the validation patients
    #print(train_id)
    
    # now there is a list of numbers for who should be in val, and who should be in train
    
    # Define the classes (group folders) you want to loop through
    class_folders = ['group_DCM', 'group_HCM', 'group_MINF', 'group_NOR', 'group_RV']
    
    # Loop through each class folder
    for class_folder in class_folders:
        class_folder_path = os.path.join("Dataset", class_folder)
        
        patients_in_class = [folder for folder in os.listdir(class_folder_path)] # list of all file names in class
        
        for val_target in val_id: # copy all validation patients
            val_patient_target = patients_in_class[val_target - 1] # get name of validation patient
            #print(f"Processing validation patient: {val_patient_target}")
            
            # copy from source to destination
            source_folder = os.path.join("Dataset", class_folder,val_patient_target)
            destination_folder = os.path.join(new_val_path, val_patient_target)
            shutil.copytree(source_folder, destination_folder)
            
        for train_target in train_id: # copy all training patients
            train_patient_target = patients_in_class[train_target - 1] # get name of training patient
            #print(f"Processing training patient: {train_patient_target}")
            
            # copy from source to destination
            source_folder = os.path.join("Dataset", class_folder,train_patient_target)
            destination_folder = os.path.join(new_train_path, train_patient_target)
            shutil.copytree(source_folder, destination_folder)


2025-04-10 11:24:16.916115: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-10 11:24:20.239389: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744277060.356549     234 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744277060.415490     234 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-10 11:24:20.602083: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

Done Importing!


In [2]:
# ENTIRE TRAINING AND MODEL SAVING LOOP, sending all training, validation and test data to wandb


experiment_name = "PreProc_BasicUNet_Cross_validation" # CHANGE THIS PER RUN!!!!!!!
# Define folder based on the WandB run name and create the folder.
folder_save_path = experiment_name +  "_models"
os.makedirs(folder_save_path, exist_ok=True)

print("Beginning the loop")
# Recombine dataset
recombine_index = [1, 2, 3, 4, 5]

# Beginning Model train loop
for idx in recombine_index:
    print("Beginning loop for index:" + str(idx))
    # Recombine new validation and training set (in folders)
    print("Recombining Data")
    recombining_data(idx)
    # Start a new wandb run to track this script.
    # --------------------------------------------------------------------------------------------------------------------
    # Initialize a new WandB run with configuration based on the experiment name.
    run = wandb.init(
        entity="DLMI_Project",
        project="DLMI_Project",
        config={
            "learning_rate": 1e-4,
            "architecture": experiment_name,  # Using experiment name as the architecture identifier
            "dataset": "ACDC",
            "epochs": 10,
        },
        name=f"{experiment_name}_run_{idx}"
    )

    # Combine the folder path with the model filename.
    model_save_path = os.path.join(folder_save_path, f"{experiment_name}_cross_variant_{idx}.pth")
    
    
    # load training and test set

    print("PyTorch version:", torch.__version__)
    print("CUDA version (PyTorch):", torch.version.cuda)

    # Define a common preprocessing pipeline
    common_transform = Compose([
        LoadHeartData(),  
        AddChanneld(keys=["img", "mask"]), # Add channel dimension
        ScaleIntensityd(keys=["img"], minv=0, maxv=1),  # Normalize intensity
        Spacingd(keys=["img", "mask"], pixdim=(1.25, 1.25), mode=("bilinear", "nearest")), # Resample voxel spacing in x and y
        # Resized(keys=["img", "mask"], spatial_size=(256, 256), mode=("area", "nearest")),
        ResizeWithPadOrCropd(keys=["img", "mask"], spatial_size=[256, 256]), # Ensures all images have the same dimensions (without getting stretched out). 
        EnsureTyped(keys=["img", "mask"])
    ])

    # Train Transform
    train_transforms = Compose([
        *common_transform.transforms,  # Apply all common steps first
        RandZoomd(keys=["img", "mask"], prob=0.1, min_zoom=0.9, max_zoom=1.1, keep_size=True), # Random zoom, not too much so that you don't remove important parts
        RandFlipd(keys=["img", "mask"], prob=0.1, spatial_axis=0),  # Random flip. Spatial axis=0 for up-down flipping. Left-right flipping is not good because the model has to distinguish the left and right ventricle
        RandRotated(keys=["img", "mask"], range_x=np.pi/12, prob=0.1, mode=("bilinear", "nearest")),
        # RandShiftIntensityd(keys=["img"], offsets=0.05, prob=0.5) # Not too much, so that left-right orientation is no problem
        # monai.transforms.RandSpatialCropd(keys=['img', 'mask'], roi_size=[256,256], random_size=False),  # Random crop
        # monai.transforms.RandShiftIntensityd(keys=['img'], offsets=0.05, prob=0.5),  
    ])

    test_transforms = common_transform
    valid_transforms = common_transform

    # validation_test_transforms = monai.transforms.Compose([
    #     LoadHeartData(),  # Load the heart data (must be first!)
    #     monai.transforms.AddChanneld(keys=['img', 'mask']),  # Add channel dimension for img and multilabel mask
    #     monai.transforms.ScaleIntensityd(keys=['img'], minv=0, maxv=1),  # Normalize intensity to the range [0, 1]
    #     monai.transforms.Spacingd(keys=["img", "mask"], pixdim=(1.25, 1.25), mode=("bilinear", "nearest")),  # Resample voxel spacing in x and y
    #     monai.transforms.ResizeWithPadOrCropd(keys=['img', 'mask'], spatial_size=[256, 256]),  # Ensure consistent size [256, 256]
    # ])


    # original training data
    train_data = build_dict_acdc(data_path_train, mode='train')

    # Create CacheDatasets for training and testing
    train_dataset = CacheDataset(
        data=train_data,
        transform=train_transforms
    )

    test_dataset = CacheDataset(
        data=build_dict_acdc(data_path_test, mode='test'),
        transform=test_transforms
    )

    # compose the 4 training datasets into one variable
    valid_dataset = CacheDataset(
        data=build_dict_acdc(data_path_valid, mode='val'),
        transform=test_transforms
    )

    # Construct CacheDataset from the list of dictionaries and apply the transform
    #train_dataset = monai.data.CacheDataset(data=train_dict_list, transform=LoadHeartData())
    #test_dataset = monai.data.CacheDataset(data=test_dict_list, transform=LoadHeartData())

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=4)
    valid_loader = DataLoader(valid_dataset, batch_size=4, shuffle=False, num_workers=4)
    # validation_loader



    # DEFINE THE ARCHITECTURE
    # ---------------------------------------------------------------------------------------------------------------
    # Define the device to use
    print("CUDA Available:", torch.cuda.is_available())
    print("CUDA Device Count:", torch.cuda.device_count())
    print("CUDA Current Device:", torch.cuda.current_device() if torch.cuda.is_available() else "No GPU")
    print("CUDA Device Name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU")
    print("PyTorch version:", torch.__version__)
    print("CUDA version (PyTorch):", torch.version.cuda)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    #device = torch.device("cpu")
    torch.cuda.empty_cache()
    # Initialize the U-Net model.
    # Here, dimensions=2 for 2D slices; in_channels=1 and out_channels=1 for binary segmentation.
    model = UNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=4,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    wandb.watch(model, log="all")
    # Define the loss function and optimizer.
    # DiceLoss with sigmoid=True is used for binary segmentation.
    loss_function = DiceLoss(softmax=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=wandb.config.learning_rate)

    # (Optional) DiceMetric for evaluation during training
    dice_metric = DiceMetric(include_background=True, reduction="mean")

    class_labels = {
        0: "Background",
        1: "Right Ventricular Endocardium",
        2: "Left Ventricular Epicardium",
        3: "Left Ventricular Endocardium"
    }
    print("Model loaded")


    print(torch.__version__)

    #torch.backends.cudnn.benchmark = False
    #torch.backends.cudnn.deterministic = True

    #unique_values = torch.unique(labels)
    #print("Unique values in batch labels:", unique_values)

    # START THE TRAINING
    # --------------------------------------------------------------------------------------------------------------------------
    # Training loop
    num_epochs = wandb.config.epochs
    for epoch in range(num_epochs):
        print("-" * 10, f"Epoch {epoch + 1}/{num_epochs}", "-" * 10)
        model.train()
        epoch_loss = 0
        step = 0
        start_time = time.time()

        for batch_data in train_loader:
            step += 1
            inputs = batch_data["img"].to(device)
            # Convert labels from shape (B, 1, H, W) to (B, H, W)
            labels = batch_data["mask"].squeeze(1).to(device)
            unique_values = torch.unique(labels)
            #print("Unique label values:", unique_values)

            optimizer.zero_grad()
            outputs = model(inputs)  # shape: (B, 4, H, W)
            outputs = outputs.contiguous()
            # Convert labels to one-hot encoding: shape becomes (B, H, W, 4)
            one_hot_labels = F.one_hot(labels.long(), num_classes=4)
            # Permute to get shape (B, 4, H, W)
            one_hot_labels = one_hot_labels.permute(0, 3, 1, 2).float()

            loss = loss_function(outputs, one_hot_labels)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            #print(f"{step}/{len(train_loader)}: loss = {loss.item():.4f}")
            wandb.log({"train_step_loss": loss.item(), "epoch": epoch + 1})

        epoch_loss /= step
        epoch_time = time.time() - start_time
        print(f"Epoch {epoch + 1} average loss: {epoch_loss:.4f}, time: {epoch_time:.2f} sec")

        # Evaluate on the test set at the end of each epoch
        voxel_size = (1.25, 1.25)  # For 2D slices; adjust as needed

        model.eval()
        all_metrics = []
        with torch.no_grad():
            for val_data in valid_loader: # switch to validation loader
                val_inputs = val_data["img"].to(device)
                val_labels = val_data["mask"].squeeze(1).to(device)  # shape: (B, H, W)
                val_outputs = model(val_inputs)  # shape: (B, 4, H, W)

                # For evaluation, use the integer label maps directly.
                pred_labels = torch.argmax(torch.softmax(val_outputs, dim=1), dim=1)  # (B, H, W)
                gt_labels = val_labels  # already in (B, H, W) after squeeze

                # Convert to numpy arrays
                pred_labels_np = pred_labels.cpu().numpy()
                gt_labels_np = gt_labels.cpu().numpy()

                for gt, pred in zip(gt_labels_np, pred_labels_np):
                    sample_metrics = metrics(gt, pred, voxel_size)
                    all_metrics.append(sample_metrics)

            avg_metrics = np.mean(all_metrics, axis=0)
            print("Validation metrics:", avg_metrics)
            wandb.log({
                "epoch": epoch + 1,
                "epoch_loss": epoch_loss,
                "Dice_LV": avg_metrics[0],
                "Volume_LV": avg_metrics[1],
                "Err_LV": avg_metrics[2],
                "Dice_RV": avg_metrics[3],
                "Volume_RV": avg_metrics[4],
                "Err_RV": avg_metrics[5],
                "Dice_MY0": avg_metrics[6],
                "Volume_MY0": avg_metrics[7],
                "Err_MY0": avg_metrics[8],
                "epoch_time_sec": epoch_time
            })
    

    # Save the trained model at the end

    # Save the model.
    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved at {model_save_path}")
    
    # Get test results
    model.eval()
    all_metrics = []
    with torch.no_grad():
        for val_data in test_loader: # switch to validation loader
            val_inputs = val_data["img"].to(device)
            val_labels = val_data["mask"].squeeze(1).to(device)  # shape: (B, H, W)
            val_outputs = model(val_inputs)  # shape: (B, 4, H, W)

            # For evaluation, use the integer label maps directly.
            pred_labels = torch.argmax(torch.softmax(val_outputs, dim=1), dim=1)  # (B, H, W)
            gt_labels = val_labels  # already in (B, H, W) after squeeze

            # Convert to numpy arrays
            pred_labels_np = pred_labels.cpu().numpy()
            gt_labels_np = gt_labels.cpu().numpy()

            for gt, pred in zip(gt_labels_np, pred_labels_np):
                sample_metrics = metrics(gt, pred, voxel_size)
                all_metrics.append(sample_metrics)

        avg_metrics = np.mean(all_metrics, axis=0)
        print("Validation metrics:", avg_metrics)
        wandb.log({
            "Dice_LV_test": avg_metrics[0],
            "Volume_LV_test": avg_metrics[1],
            "Err_LV_test": avg_metrics[2],
            "Dice_RV_test": avg_metrics[3],
            "Volume_RV_test": avg_metrics[4],
            "Err_RV_test": avg_metrics[5],
            "Dice_MY0_test": avg_metrics[6],
            "Volume_MY0_test": avg_metrics[7],
            "Err_MY0_test": avg_metrics[8]
        })
        
        print("Test metrics:", avg_metrics)
    
    run.finish()


Done Importing!
Beginning the loop
Beginning loop for index:1
Recombining Data


[34m[1mwandb[0m: Currently logged in as: [33mdejanhonderd100[0m ([33mDLMI_Project[0m). Use [1m`wandb login --relogin`[0m to force relogin


PyTorch version: 2.5.1+cu121
CUDA version (PyTorch): 12.1


Loading dataset: 100%|██████████| 1506/1506 [00:15<00:00, 94.53it/s] 
Loading dataset: 100%|██████████| 1076/1076 [00:11<00:00, 90.24it/s]
Loading dataset: 100%|██████████| 396/396 [00:03<00:00, 106.54it/s]


CUDA Available: True
CUDA Device Count: 2
CUDA Current Device: 0
CUDA Device Name: Tesla T4
PyTorch version: 2.5.1+cu121
CUDA version (PyTorch): 12.1
Using device: cuda
Model loaded
2.5.1+cu121
---------- Epoch 1/10 ----------


  ret = func(*args, **kwargs)
  ret = func(*args, **kwargs)
  ret = func(*args, **kwargs)
  ret = func(*args, **kwargs)
  t = cls([], dtype=storage.dtype, device=storage.device)


Epoch 1 average loss: 0.8320, time: 25.61 sec
Validation metrics: [ 0.12406754  5.70952888  5.12298769  0.22892039  3.87763573  3.32714646
  0.02943132 22.41736111 21.80479403]
---------- Epoch 2/10 ----------
Epoch 2 average loss: 0.7532, time: 13.78 sec
Validation metrics: [ 0.06869822 11.93392914 11.34738794  0.36016891  1.89311869  1.34262942
  0.39665709  1.75598169  1.14341461]
---------- Epoch 3/10 ----------
Epoch 3 average loss: 0.6579, time: 14.67 sec
Validation metrics: [0.08060192 9.5943971  9.0078559  0.49549249 0.97700836 0.4265191
 0.57933956 1.11699416 0.50442708]
---------- Epoch 4/10 ----------
Epoch 4 average loss: 0.5428, time: 13.46 sec
Validation metrics: [ 0.66214257  0.53847064 -0.04807055  0.51872242  0.88786695  0.33737768
  0.63221529  0.76912879  0.15656171]
---------- Epoch 5/10 ----------
Epoch 5 average loss: 0.3845, time: 13.65 sec
Validation metrics: [0.73582685 0.59107876 0.00453756 0.55744761 0.76214489 0.21165562
 0.65240133 0.63923217 0.02666509]
--

VBox(children=(Label(value='0.011 MB of 0.011 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Dice_LV,▂▁▁▇▇▇████
Dice_LV_test,▁
Dice_MY0,▁▅▇▇██████
Dice_MY0_test,▁
Dice_RV,▁▄▆▇▇▇████
Dice_RV_test,▁
Err_LV,▄█▇▁▁▁▁▁▁▁
Err_LV_test,▁
Err_MY0,█▁▁▁▁▁▁▁▁▁
Err_MY0_test,▁

0,1
Dice_LV,0.79791
Dice_LV_test,0.76375
Dice_MY0,0.69314
Dice_MY0_test,0.65449
Dice_RV,0.58333
Dice_RV_test,0.58243
Err_LV,-0.02173
Err_LV_test,-0.02089
Err_MY0,-0.00355
Err_MY0_test,0.05734


Beginning loop for index:2
Recombining Data


PyTorch version: 2.5.1+cu121
CUDA version (PyTorch): 12.1


Loading dataset: 100%|██████████| 1514/1514 [00:16<00:00, 93.96it/s] 
Loading dataset: 100%|██████████| 1076/1076 [00:11<00:00, 97.07it/s]
Loading dataset: 100%|██████████| 388/388 [00:03<00:00, 105.07it/s]


CUDA Available: True
CUDA Device Count: 2
CUDA Current Device: 0
CUDA Device Name: Tesla T4
PyTorch version: 2.5.1+cu121
CUDA version (PyTorch): 12.1
Using device: cuda
Model loaded
2.5.1+cu121
---------- Epoch 1/10 ----------
Epoch 1 average loss: 0.8516, time: 14.38 sec
Validation metrics: [ 0.02717469 19.97388048 19.2990174   0.21607139  4.96375644  4.26749356
  0.0623662   8.19134182  7.47769008]
---------- Epoch 2/10 ----------
Epoch 2 average loss: 0.7924, time: 14.09 sec
Validation metrics: [ 0.0204755  20.35251289 19.67764981  0.31729939  2.43342059  1.7371577
  0.25664161  4.0071279   3.29347616]
---------- Epoch 3/10 ----------
Epoch 3 average loss: 0.7355, time: 14.21 sec
Validation metrics: [ 0.02265566 15.93504349 15.26018041  0.4393896   1.57000242  0.87373953
  0.36957038  2.51162613  1.79797439]
---------- Epoch 4/10 ----------
Epoch 4 average loss: 0.6747, time: 13.97 sec
Validation metrics: [ 0.38969708  0.65990657 -0.01495651  0.48158847  1.2123067   0.51604381
  0.4

VBox(children=(Label(value='0.011 MB of 0.011 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Dice_LV,▁▁▁▆▇█████
Dice_LV_test,▁
Dice_MY0,▁▃▄▅▆▇████
Dice_MY0_test,▁
Dice_RV,▁▃▅▆▇▇▇███
Dice_RV_test,▁
Err_LV,██▆▁▁▁▁▁▁▁
Err_LV_test,▁
Err_MY0,█▄▃▂▂▁▁▁▁▁
Err_MY0_test,▁

0,1
Dice_LV,0.53604
Dice_LV_test,0.51407
Dice_MY0,0.67954
Dice_MY0_test,0.66447
Dice_RV,0.56514
Dice_RV_test,0.60456
Err_LV,-0.35557
Err_LV_test,-0.30944
Err_MY0,0.02275
Err_MY0_test,0.06002


Beginning loop for index:3
Recombining Data


PyTorch version: 2.5.1+cu121
CUDA version (PyTorch): 12.1


Loading dataset: 100%|██████████| 1538/1538 [00:17<00:00, 87.02it/s]
Loading dataset: 100%|██████████| 1076/1076 [00:11<00:00, 95.61it/s]
Loading dataset: 100%|██████████| 364/364 [00:03<00:00, 100.99it/s]


CUDA Available: True
CUDA Device Count: 2
CUDA Current Device: 0
CUDA Device Name: Tesla T4
PyTorch version: 2.5.1+cu121
CUDA version (PyTorch): 12.1
Using device: cuda
Model loaded
2.5.1+cu121
---------- Epoch 1/10 ----------
Epoch 1 average loss: 0.8152, time: 15.31 sec
Validation metrics: [0.15166227 6.2986693  5.64081817 0.02106528 8.10744763 7.58049451
 0.29508019 2.14590917 1.50621995]
---------- Epoch 2/10 ----------
Epoch 2 average loss: 0.7316, time: 15.81 sec
Validation metrics: [0.63067492 0.84399038 0.18613925 0.08656145 6.8160285  6.28907538
 0.51824934 1.297858   0.65816878]
---------- Epoch 3/10 ----------
Epoch 3 average loss: 0.6221, time: 15.64 sec
Validation metrics: [0.7158785  0.7091432  0.05129207 0.33767949 1.43497167 0.90801854
 0.60624328 0.87189217 0.23220295]
---------- Epoch 4/10 ----------
Epoch 4 average loss: 0.4756, time: 15.27 sec
Validation metrics: [ 0.74596707  0.60605254 -0.05179859  0.50696621  0.79362552  0.26667239
  0.63439662  0.7332203   0.093

VBox(children=(Label(value='0.011 MB of 0.011 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Dice_LV,▁▆▇███████
Dice_LV_test,▁
Dice_MY0,▁▅▇▇▇▇████
Dice_MY0_test,▁
Dice_RV,▁▂▅▇▇█████
Dice_RV_test,▁
Err_LV,█▁▁▁▁▁▁▁▁▁
Err_LV_test,▁
Err_MY0,█▄▂▁▁▁▁▁▁▁
Err_MY0_test,▁

0,1
Dice_LV,0.78612
Dice_LV_test,0.7834
Dice_MY0,0.68727
Dice_MY0_test,0.66391
Dice_RV,0.57441
Dice_RV_test,0.59994
Err_LV,0.03737
Err_LV_test,0.01525
Err_MY0,-0.00548
Err_MY0_test,-0.01779


Beginning loop for index:4
Recombining Data


PyTorch version: 2.5.1+cu121
CUDA version (PyTorch): 12.1


Loading dataset: 100%|██████████| 1498/1498 [00:16<00:00, 88.98it/s] 
Loading dataset: 100%|██████████| 1076/1076 [00:11<00:00, 95.82it/s]
Loading dataset: 100%|██████████| 404/404 [00:04<00:00, 90.89it/s]


CUDA Available: True
CUDA Device Count: 2
CUDA Current Device: 0
CUDA Device Name: Tesla T4
PyTorch version: 2.5.1+cu121
CUDA version (PyTorch): 12.1
Using device: cuda
Model loaded
2.5.1+cu121
---------- Epoch 1/10 ----------
Epoch 1 average loss: 0.8364, time: 14.62 sec
Validation metrics: [ 0.1880674   1.55946782  0.87210319  0.03545438 14.29804301 13.6268603
  0.26375692  3.55461015  2.90492729]
---------- Epoch 2/10 ----------
Epoch 2 average loss: 0.7651, time: 15.53 sec
Validation metrics: [ 0.4221296   0.45799428 -0.22937036  0.02948592 15.74422184 15.07303914
  0.40378781  2.14863861  1.49895575]
---------- Epoch 3/10 ----------
Epoch 3 average loss: 0.6999, time: 15.23 sec
Validation metrics: [ 0.48521923  0.37824103 -0.30912361  0.03539615 15.37170869 14.70052599
  0.48001632  1.50236309  0.85268023]
---------- Epoch 4/10 ----------
Epoch 4 average loss: 0.6246, time: 15.23 sec
Validation metrics: [ 0.49489173  0.33582534 -0.35153929  0.03604384 15.69794632 15.02676361
  0.5

VBox(children=(Label(value='0.011 MB of 0.011 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Dice_LV,▁▆▇▇▇█████
Dice_LV_test,▁
Dice_MY0,▁▄▆▇▇█▇███
Dice_MY0_test,▁
Dice_RV,▁▁▁▁▁▁▁▇██
Dice_RV_test,▁
Err_LV,█▂▁▁▁▁▁▁▁▁
Err_LV_test,▁
Err_MY0,█▄▂▁▂▁▁▁▁▁
Err_MY0_test,▁

0,1
Dice_LV,0.5229
Dice_LV_test,0.50242
Dice_MY0,0.57816
Dice_MY0_test,0.55274
Dice_RV,0.39953
Dice_RV_test,0.39034
Err_LV,-0.35911
Err_LV_test,-0.31144
Err_MY0,0.41236
Err_MY0_test,0.42053


Beginning loop for index:5
Recombining Data


PyTorch version: 2.5.1+cu121
CUDA version (PyTorch): 12.1


Loading dataset: 100%|██████████| 1552/1552 [00:17<00:00, 90.56it/s] 
Loading dataset: 100%|██████████| 1076/1076 [00:11<00:00, 96.59it/s]
Loading dataset: 100%|██████████| 350/350 [00:03<00:00, 104.94it/s]


CUDA Available: True
CUDA Device Count: 2
CUDA Current Device: 0
CUDA Device Name: Tesla T4
PyTorch version: 2.5.1+cu121
CUDA version (PyTorch): 12.1
Using device: cuda
Model loaded
2.5.1+cu121
---------- Epoch 1/10 ----------
Epoch 1 average loss: 0.8404, time: 15.65 sec
Validation metrics: [ 0.40698449  2.33869196  1.51647321  0.03188176 18.55836161 17.74334821
  0.0255951  21.89632143 21.00029464]
---------- Epoch 2/10 ----------
Epoch 2 average loss: 0.7808, time: 14.87 sec
Validation metrics: [ 0.61117535  1.1174375   0.29521875  0.0360621  18.65180804 17.83679464
  0.1643364   3.09306696  2.19704018]
---------- Epoch 3/10 ----------
Epoch 3 average loss: 0.7134, time: 12.55 sec
Validation metrics: [ 0.69163608  0.76673214 -0.05548661  0.03103127 16.36521429 15.55020089
  0.45345889  1.23683482  0.34080804]
---------- Epoch 4/10 ----------
Epoch 4 average loss: 0.6183, time: 15.05 sec
Validation metrics: [ 0.71273157  0.83241071  0.01019196  0.29354598  0.49544196 -0.31957143
  0.

VBox(children=(Label(value='0.011 MB of 0.011 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Dice_LV,▁▅▆▇▇█████
Dice_LV_test,▁
Dice_MY0,▁▂▆▆▇█████
Dice_MY0_test,▁
Dice_RV,▁▁▁▆▇█████
Dice_RV_test,▁
Err_LV,█▃▂▂▂▁▂▁▁▁
Err_LV_test,▁
Err_MY0,█▂▁▁▁▁▁▁▁▁
Err_MY0_test,▁

0,1
Dice_LV,0.7779
Dice_LV_test,0.75042
Dice_MY0,0.67387
Dice_MY0_test,0.64507
Dice_RV,0.39638
Dice_RV_test,0.39778
Err_LV,-0.19039
Err_LV_test,-0.06242
Err_MY0,-0.16192
Err_MY0_test,0.00914
