In [None]:
#%% Set data path
# Dataset root, relative to the notebook's working directory.
# NOTE(review): later cells use `main_path` or a literal './database' rather than this name.
data_path = r'./database'

In [None]:
#importing all needed packages
import os
import glob
import nibabel as nib
import numpy as np
import monai
import torch
import torch.optim as optim
import torch.nn as nn
import random
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from monai.transforms import Compose, ScaleIntensityd, EnsureChannelFirstd, RandZoomd, RandFlipd, RandRotated, Resized, RandAdjustContrastd, RandGaussianSmoothd, RandGaussianNoised, Rand2DElasticd, RandShiftIntensityd
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.networks.utils import one_hot
from monai.networks.nets import UNet
from sklearn.model_selection import KFold
from tqdm import tqdm
import wandb
import cv2
import shutil

In [None]:
#%% Adapted Script for Heart MRI Segmentation using a 2D UNet with Cross-Validation and Learning Rate Scheduling

#%% Utility function for loading NIfTI images
def load_nii(img_path):
    """Read a NIfTI file from disk.

    Args:
        img_path: path to a ``.nii`` / ``.nii.gz`` file.

    Returns:
        ``(data, affine, header)``: the voxel array from ``get_fdata()``, the
        image affine and the NIfTI header.
    """
    image = nib.load(img_path)
    data = image.get_fdata()
    affine = image.affine
    header = image.header
    return data, affine, header

##%% Build a dictionary for the test dataset
def build_test_dict(data_path):
    """Pair every test image with its ground-truth segmentation.

    Scans ``<data_path>/testing/image/*.nii.gz`` and
    ``<data_path>/testing/segmentation/*_gt.nii.gz``. An image is matched to the
    segmentation whose filename, with ``_gt`` removed, equals the image filename.
    Images without an existing segmentation file are skipped.

    Returns:
        A list (ordered by sorted image path) of dicts with keys ``patient``,
        ``img``, ``mask`` and ``img_meta_dict`` (``filename``, ``patient_id``,
        ``frame``). Patient is the first ``_``-separated token of the filename,
        frame the second token of the filename without its two extensions
        (e.g. ``patient101_frame01.nii.gz`` -> ``patient101`` / ``frame01``).
    """
    testing_root = os.path.join(data_path, "testing")
    image_files = sorted(glob.glob(os.path.join(testing_root, "image", "*.nii.gz")))
    gt_files = sorted(glob.glob(os.path.join(testing_root, "segmentation", "*_gt.nii.gz")))

    # Key each ground-truth file by the image filename it belongs to.
    gt_by_image_name = {}
    for gt_file in gt_files:
        gt_by_image_name[os.path.basename(gt_file).replace("_gt", "")] = gt_file

    entries = []
    for image_file in image_files:
        name = os.path.basename(image_file)
        patient = name.split('_')[0]
        # Strip both parts of the ".nii.gz" extension before taking the frame token.
        stem = os.path.splitext(os.path.splitext(name)[0])[0]
        frame = stem.split('_')[1]
        gt_file = gt_by_image_name.get(name, None)
        if not (gt_file and os.path.exists(gt_file)):
            continue
        entries.append({
            "patient": patient,
            "img": image_file,
            "mask": gt_file,
            "img_meta_dict": {"filename": name, "patient_id": patient, "frame": frame},
        })
    return entries

class LoadHeartData(monai.transforms.Transform):
    """MONAI transform that expands one image/mask NIfTI pair into per-slice samples.

    The last axis of each loaded volume is treated as the slice axis. Every
    slice becomes a dict with a float32 ``img``, a uint8 ``mask`` and meta dicts
    carrying the patient id, filename, frame and slice index read from
    ``sample['img_meta_dict']`` (``'unknown'`` when a key is missing).
    """

    def __call__(self, sample):
        image_volume, _affine, _header = load_nii(sample['img'])
        mask_volume, _affine, _header = load_nii(sample['mask'])
        # Move the slice axis to the front so slices can be indexed directly.
        image_slices = np.moveaxis(image_volume, -1, 0)
        mask_slices = np.moveaxis(mask_volume, -1, 0)

        source_meta = sample['img_meta_dict']
        patient_id = source_meta.get('patient_id', 'unknown')
        filename = source_meta.get('filename', 'unknown')
        frame = source_meta.get('frame', 'unknown')

        return [
            {
                'img': image_slices[slice_idx].astype(np.float32),
                'mask': mask_slices[slice_idx].astype(np.uint8),
                'img_meta_dict': {
                    'affine': np.eye(2),
                    'patient_id': patient_id,
                    'filename': filename,
                    'slice_index': slice_idx,
                    'frame': frame,
                },
                'mask_meta_dict': {'affine': np.eye(2)},
            }
            for slice_idx in range(image_slices.shape[0])
        ]

#%% Set data path
main_path = r'./database'
#main_path = r"/content/drive/MyDrive/ACDC/database"

#%% Build dataset dictionary
# NOTE(review): build_test_dict only reads the "testing" split, so this is the
# test set (not training + validation for CV). `dataset_dicts` is not referenced
# by the visible evaluation cells.
dataset_dicts = build_test_dict(main_path)

#%% Define transforms
# Deterministic preprocessing (per-slice loading -> add channel axis -> scale the
# image to [0, 1] -> resize to 256x256) followed by random augmentations applied
# jointly to image and mask; nearest-neighbour interpolation is used for the mask
# so its labels stay integral. Because of the random steps, this pipeline is only
# suitable for training data, not for evaluation.
transforms = Compose([
    LoadHeartData(),
    EnsureChannelFirstd(keys=['img', 'mask'], channel_dim="no_channel"),
    ScaleIntensityd(keys=['img'], minv=0, maxv=1),
    Resized(keys=['img', 'mask'], spatial_size=(256, 256), mode=['bilinear', 'nearest']),
    RandZoomd(keys=['img', 'mask'], min_zoom=0.9, max_zoom=1.1, mode=['bilinear', 'nearest'], prob=0.5),
    RandFlipd(keys=['img', 'mask'], prob=0.5, spatial_axis=1),
    RandRotated(keys=['img', 'mask'], range_x=0.1, range_y=0.1, mode=['bilinear', 'nearest'], prob=0.5),
    Rand2DElasticd(keys=['img', 'mask'], spacing=(5, 5), magnitude_range=(0, 0.1), prob=0.5, mode=['bilinear', 'nearest']),
    RandAdjustContrastd(keys=['img'], gamma=(0.8, 1.2), prob=0.3),
    RandGaussianSmoothd(keys=['img'], sigma_x=(0.5, 1.5), sigma_y=(0.5, 1.5), prob=0.3),
    RandGaussianNoised(keys=['img'], prob=0.3, mean=0.0, std=0.05),
    # NOTE(review): offsets of 10-20 are very large relative to the [0, 1] range
    # produced by ScaleIntensityd above — confirm this was intended.
    RandShiftIntensityd(keys=['img'], prob=0.5, offsets=(10,20))
])

In [None]:
#%% Flatten dataset to handle all slices
def flatten_dataset(dataset_list, transform):
    """Apply ``transform`` to every entry and concatenate the returned iterables.

    Used with a transform that turns one volume dict into a list of slice dicts,
    yielding a single flat list of slices for all volumes.
    """
    return [item for entry in dataset_list for item in transform(entry)]

#%% Cross-validation setup (carried over from the training notebook)
# NOTE(review): num_epochs, k_folds, kf and fold_results are not referenced by the
# visible evaluation cells; `device` and `dice_metric` are re-created further down.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
num_epochs, k_folds = 200, 5
kf = KFold(n_splits=k_folds, shuffle=True, random_state=42)

fold_results = []
dice_metric = DiceMetric(reduction="mean", include_background=True)

In [None]:
model_dir = "."

#%% Build and flatten the test dataset
# Evaluation must be deterministic. The `transforms` pipeline above randomly
# zooms, flips, rotates, warps and shifts the intensity of image AND mask, which
# would corrupt both the test inputs and their ground truth. Only the
# deterministic preprocessing steps are applied here.
test_transforms = Compose([
    LoadHeartData(),
    EnsureChannelFirstd(keys=['img', 'mask'], channel_dim="no_channel"),
    ScaleIntensityd(keys=['img'], minv=0, maxv=1),
    Resized(keys=['img', 'mask'], spatial_size=(256, 256), mode=['bilinear', 'nearest']),
])
test_dicts = build_test_dict(r'./database')
test_flat = flatten_dataset(test_dicts, test_transforms)

#%% Create test DataLoader (one slice per batch, fixed order)
test_dataset = monai.data.Dataset(data=test_flat)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)

#%% Function to load ensemble models
def load_ensemble_models(model_paths, device):
    """Build one 2D UNet per checkpoint and load its weights.

    Every network uses the same architecture (1 input channel, 4 output
    channels, 5 levels, 2 residual units). Weights are loaded with
    ``map_location=device`` and each model is put in eval mode.

    Args:
        model_paths: iterable of state-dict checkpoint paths.
        device: torch device the models are moved to.

    Returns:
        List of models in the same order as ``model_paths``.
    """
    unet_kwargs = dict(
        spatial_dims=2,
        in_channels=1,
        out_channels=4,
        channels=(32, 64, 128, 256, 512),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )

    def _load_one(checkpoint_path):
        # Build, move, restore weights, then switch to inference mode.
        net = UNet(**unet_kwargs).to(device)
        net.load_state_dict(torch.load(checkpoint_path, map_location=device))
        net.eval()
        return net

    return [_load_one(checkpoint_path) for checkpoint_path in model_paths]

# Uncertainty estimation via Monte Carlo dropout, following:
# https://medium.com/biased-algorithms/uncertainty-estimation-in-machine-learning-with-monte-carlo-dropout-72377f5ee276

# NOTE(review): not referenced by the visible cells.
ensemble_model_path = "ensembleHeartUNet.pt"

# One checkpoint per cross-validation fold (Fold1 ... Fold5) inside model_dir.
model_paths = [os.path.join(model_dir, f"bestHeartUNet_Fold{fold}.pt") for fold in range(1, 6)]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# EnsembleModel with MC Dropout for uncertainty calculation
class EnsembleModelWithUncertainty(torch.nn.Module):
    """Ensemble of segmentation networks with Monte Carlo (MC) dropout sampling.

    Each member model is run ``num_samples_mc`` times with only its dropout
    layers in training mode. The softmax outputs of all
    ``len(models) * num_samples_mc`` passes are reduced to a per-pixel mean
    (the prediction) and standard deviation (the uncertainty).

    NOTE(review): the UNets built by ``load_ensemble_models`` in this notebook
    pass no ``dropout`` argument (MONAI's UNet default is 0.0 — verify for the
    installed version). In that case the MC passes of one member are identical,
    and the std only reflects disagreement between ensemble members.
    """

    # Layer types that stay stochastic during MC sampling.
    _DROPOUT_TYPES = (
        torch.nn.Dropout,
        torch.nn.Dropout2d,
        torch.nn.Dropout3d,
        torch.nn.AlphaDropout,
    )

    def __init__(self, models, num_samples_mc=10):
        super(EnsembleModelWithUncertainty, self).__init__()
        self.models = torch.nn.ModuleList(models)
        self.num_samples_mc = num_samples_mc  # number of MC forward passes per member

    @classmethod
    def _enable_mc_dropout(cls, model):
        """Put ``model`` in eval mode, then re-enable only its dropout layers.

        Calling ``model.train()`` on the whole network would also switch
        normalization layers such as BatchNorm to batch statistics (and update
        their running averages), which is not part of MC dropout.
        """
        model.eval()
        for module in model.modules():
            if isinstance(module, cls._DROPOUT_TYPES):
                module.train()

    def forward(self, x):
        """Return ``(mean_output, std_output)``, each shaped like one member's output.

        ``x`` is passed unchanged to every member; each member output is
        softmaxed over dim 1 before aggregation.
        """
        outputs_list = []
        for model in self.models:
            self._enable_mc_dropout(model)
            model_outputs = [torch.softmax(model(x), dim=1) for _ in range(self.num_samples_mc)]
            outputs_list.append(torch.stack(model_outputs, dim=0))  # [num_samples_mc, B, C, H, W]

        # Combine all passes of all members: [num_samples_mc * len(models), B, C, H, W]
        all_outputs = torch.cat(outputs_list, dim=0)

        mean_output = torch.mean(all_outputs, dim=0)  # mean prediction
        std_output = torch.std(all_outputs, dim=0)    # uncertainty (standard deviation)
        return mean_output, std_output

# Evaluation phase:
def compute_ensemble_dice_and_uncertainty(model_paths, device, test_dataloader, num_samples_mc=10):
    """Evaluate the fold ensemble on a test loader.

    Args:
        model_paths: checkpoint paths passed to ``load_ensemble_models``.
        device: torch device used for the models and the batches.
        test_dataloader: yields dicts with ``"img"`` and ``"mask"`` tensors.
        num_samples_mc: stochastic forward passes per ensemble member.

    Returns:
        ``(dice_score, uncertainty_maps, avg_uncertainty)``:
        - ``dice_score``: mean Dice (background included) aggregated over all
          batches; ``nan`` for an empty loader;
        - ``uncertainty_maps``: one per-pixel std array per batch (CPU NumPy);
        - ``avg_uncertainty``: mean of the per-batch average std; 0.0 for an
          empty loader.
    """
    num_classes = 4  # must match the UNet's out_channels in load_ensemble_models

    ensemble_models = load_ensemble_models(model_paths, device)
    ensemble_model = EnsembleModelWithUncertainty(ensemble_models, num_samples_mc)

    dice_metric = DiceMetric(include_background=True, reduction="mean")
    dice_metric.reset()
    uncertainty_maps = []
    total_uncertainty = 0.0
    num_batches = 0

    with torch.no_grad():
        for batch in tqdm(test_dataloader):
            images = batch["img"].to(device)
            mean_pred, std_pred = ensemble_model(images)

            # Hard prediction: most probable class per pixel, kept as (B, 1, H, W).
            preds = torch.argmax(mean_pred, dim=1, keepdim=True)
            preds_onehot = one_hot(preds, num_classes=num_classes)

            masks = batch["mask"].to(device)
            # Masks get a channel axis from EnsureChannelFirstd, so a collated batch
            # is already (B, 1, H, W). Only add the axis if it is missing;
            # unconditionally unsqueezing produced a 5-D label tensor that no longer
            # matches the prediction shape.
            if masks.ndim == preds.ndim - 1:
                masks = masks.unsqueeze(1)
            masks_onehot = one_hot(masks, num_classes=num_classes)
            dice_metric(y_pred=preds_onehot, y=masks_onehot)

            # Per-batch scalar uncertainty = mean std over classes and pixels.
            total_uncertainty += std_pred.mean().item()
            num_batches += 1
            uncertainty_maps.append(std_pred.cpu().numpy())

    # aggregate() has nothing to reduce when the loader yielded no batches.
    dice_score = dice_metric.aggregate().item() if num_batches > 0 else float("nan")
    dice_metric.reset()

    avg_uncertainty = total_uncertainty / num_batches if num_batches > 0 else 0.0
    return dice_score, uncertainty_maps, avg_uncertainty

dice_score, uncertainty_maps, avg_uncertainty = compute_ensemble_dice_and_uncertainty(
    model_paths, device, test_dataloader
)

# Report the aggregate test-set metrics.
for report_line in (
    f"Ensemble Dice Score on Test Dataset: {dice_score}",
    f"Overall Uncertainty Score: {avg_uncertainty}",
):
    print(report_line)



In [None]:
#NEW CLASS ENSEMBLEMODEL TO OUTPUT UNCERTAINTY
class EnsembleModel(torch.nn.Module):
    """Ensemble of segmentation networks with Monte Carlo (MC) dropout sampling.

    Each member model is run ``num_samples_mc`` times with only its dropout
    layers in training mode. The softmax outputs of all
    ``len(models) * num_samples_mc`` passes are reduced to a per-pixel mean
    (the prediction) and standard deviation (the uncertainty).

    NOTE(review): the UNets built by ``load_ensemble_models`` in this notebook
    pass no ``dropout`` argument (MONAI's UNet default is 0.0 — verify for the
    installed version). In that case the MC passes of one member are identical,
    and the std only reflects disagreement between ensemble members.
    """

    # Layer types that stay stochastic during MC sampling.
    _DROPOUT_TYPES = (
        torch.nn.Dropout,
        torch.nn.Dropout2d,
        torch.nn.Dropout3d,
        torch.nn.AlphaDropout,
    )

    def __init__(self, models, num_samples_mc=10):
        super(EnsembleModel, self).__init__()
        self.models = torch.nn.ModuleList(models)
        self.num_samples_mc = num_samples_mc  # number of MC forward passes per member

    @classmethod
    def _enable_mc_dropout(cls, model):
        """Put ``model`` in eval mode, then re-enable only its dropout layers.

        Calling ``model.train()`` on the whole network would also switch
        normalization layers such as BatchNorm to batch statistics (and update
        their running averages), which is not part of MC dropout.
        """
        model.eval()
        for module in model.modules():
            if isinstance(module, cls._DROPOUT_TYPES):
                module.train()

    def forward(self, x):
        """Return ``(mean_output, std_output)``, each shaped like one member's output.

        ``x`` is passed unchanged to every member; each member output is
        softmaxed over dim 1 before aggregation.
        """
        outputs_list = []
        for model in self.models:
            self._enable_mc_dropout(model)
            model_outputs = [torch.softmax(model(x), dim=1) for _ in range(self.num_samples_mc)]
            outputs_list.append(torch.stack(model_outputs, dim=0))  # [num_samples_mc, B, C, H, W]

        # Combine all passes of all members: [num_samples_mc * len(models), B, C, H, W]
        all_outputs = torch.cat(outputs_list, dim=0)

        mean_output = torch.mean(all_outputs, dim=0)  # mean prediction
        std_output = torch.std(all_outputs, dim=0)    # uncertainty (standard deviation)
        return mean_output, std_output

In [None]:
#only patient 149
from collections import defaultdict

# Create directory for saving plots
output_dir = "Report_uncertainty_error"
os.makedirs(output_dir, exist_ok=True)

# Group the test slices per (patient, frame): one figure per group, one row per slice.
grouped_samples = defaultdict(list)
for idx in range(len(test_dataset)):
    sample = test_dataset[idx]
    pid = sample["img_meta_dict"]["patient_id"]
    frame = sample["img_meta_dict"].get("frame", "N/A")
    grouped_samples[(pid, frame)].append(sample)

# (title, colormap) per column; column 0 is titled with the slice index instead.
panel_specs = [
    ("Image", "gray"),
    ("GT Mask", "gray"),
    ("Prediction", "gray"),
    ("Uncertainty", "hot"),
    ("Total Error", "hot"),
]

with torch.no_grad():
    for (patient_id, frame), samples in grouped_samples.items():
        # Only the report example patient149 / frame12 is rendered.
        if not (str(patient_id) == "patient149" and str(frame) == "frame12"):
            continue

        # Load the ensemble once per figure rather than once per slice; the
        # checkpoints are the same for every slice.
        ensemble_model = EnsembleModel(load_ensemble_models(model_paths, device), num_samples_mc=10)

        fig, axes = plt.subplots(len(samples), len(panel_specs), figsize=(24, 3 * len(samples)))
        fig.subplots_adjust(top=0.92, wspace=0.01, hspace=0.4)
        fig.suptitle(f"Patient: {patient_id}, Frame: {frame}", fontsize=16, fontweight='bold')

        if len(samples) == 1:
            axes = [axes]  # plt.subplots returns a 1-D array for a single row

        for row, sample in enumerate(samples):
            slice_index = sample["img_meta_dict"].get("slice_index", 0)
            img = sample["img"].to(device).unsqueeze(0)
            mask = sample["mask"].to(device).unsqueeze(0)

            # mean_pred is already an average of softmax probabilities, so the
            # argmax is taken directly (another softmax cannot change it).
            mean_pred, uncertainty = ensemble_model(img)
            pred = torch.argmax(mean_pred, dim=1, keepdim=True)

            img_np = img.squeeze().cpu().numpy()
            mask_np = mask.squeeze().cpu().numpy()
            pred_np = pred.squeeze().cpu().numpy()
            # Average the per-class std over the leading (class) axis -> one 2-D map.
            uncertainty_np = uncertainty.squeeze().cpu().numpy().mean(axis=0)
            combined_error = (pred_np != mask_np).astype(np.uint8)

            panels = [img_np, mask_np, pred_np, uncertainty_np, combined_error]
            for col, (ax, panel, (title, cmap)) in enumerate(zip(axes[row], panels, panel_specs)):
                ax.imshow(panel, cmap=cmap)
                ax.set_title(f"Slice {slice_index}" if col == 0 else title)
                ax.axis("off")

        save_path = os.path.join(output_dir, f"Patient_{patient_id}_Frame_{frame}_ALL_SLICES.png")
        fig.savefig(save_path, bbox_inches='tight')
        plt.close(fig)


In [None]:
# VISUALIZATION — this cell can be dropped if others cover the visualization.
# Target directory for the per-slice plots.
output_dir = "Plots_uncertainty_error"
os.makedirs(output_dir, exist_ok=True)

# The displayed image is resized directly instead of going through the MONAI
# transforms (which alter its intensities), so it still matches the 256x256
# network resolution.
def resize_image(img, size=(256, 256)):
    """Resize a 2-D image with bilinear interpolation.

    Args:
        img: 2-D NumPy array, or a torch tensor (moved to CPU and converted).
        size: target size passed straight to ``cv2.resize``, which reads it as
            ``(width, height)``.

    Returns:
        The resized NumPy array.
    """
    array = img.cpu().numpy() if isinstance(img, torch.Tensor) else img
    return cv2.resize(array, size, interpolation=cv2.INTER_LINEAR)

# Visualize every test slice: raw image, ground truth, prediction, uncertainty and error maps.
indices = list(range(len(test_dataset)))

dice_metric = DiceMetric(include_background=True, reduction="mean")
save_plots = False  # set True to also write each figure into output_dir

# Load the ensemble once; the checkpoints do not change between slices.
ensemble_model = EnsembleModel(load_ensemble_models(model_paths, device), num_samples_mc=10)

# One-entry cache: consecutive slices come from the same volume, so the NIfTI
# file is read once per volume instead of once per slice.
cached_volume_path, cached_volume = None, None

with torch.no_grad():
    for idx in indices:
        sample = test_dataset[idx]
        meta = sample["img_meta_dict"]
        patient_id = meta["patient_id"]
        frame = meta.get("frame", "N/A")
        slice_index = meta.get("slice_index", "N/A")
        # The filename is stored in the meta dict, not at the top level of the sample.
        filename = meta.get("filename", f"Slice {idx}")
        print(f"Patient: {patient_id}, Frame: {frame}, Slice: {slice_index}")

        # Raw (untransformed) slice resized to the network resolution; displayed
        # instead of the scaled network input, whose intensities look distorted.
        img_data_resized = np.zeros((256, 256))  # fallback if the volume/slice is not found
        matching_sample = next(
            (s for s in test_dicts if s["patient"] == patient_id and s["img_meta_dict"]["frame"] == frame),
            None,
        )
        if matching_sample is not None and isinstance(slice_index, int):
            if matching_sample["img"] != cached_volume_path:
                cached_volume_path = matching_sample["img"]
                cached_volume = nib.load(cached_volume_path).get_fdata()
            if cached_volume.ndim == 3 and slice_index < cached_volume.shape[-1]:
                img_data_resized = resize_image(cached_volume[..., slice_index])

        img = sample["img"].to(device).unsqueeze(0)    # add batch dimension
        mask = sample["mask"].to(device).unsqueeze(0)  # keep 1-channel mask

        # mean_pred is already an average of softmax probabilities, so the argmax
        # is taken directly (another softmax cannot change it).
        mean_pred, uncertainty = ensemble_model(img)
        pred = torch.argmax(mean_pred, dim=1, keepdim=True)

        mask_np = mask.squeeze().cpu().numpy()
        pred_np = pred.squeeze().cpu().numpy()
        # Average the per-class std over the leading (class) axis -> one 2-D map.
        uncertainty_np = uncertainty.squeeze().cpu().numpy().mean(axis=0)

        # Per-structure disagreement maps (1 where prediction and GT differ for
        # that label). Comparing boolean masks avoids uint8 subtraction, which
        # wraps 0 - 1 to 255 and would give misses and false positives very
        # different intensities.
        error_rv = ((pred_np == 1) != (mask_np == 1)).astype(np.uint8)
        error_myo = ((pred_np == 2) != (mask_np == 2)).astype(np.uint8)
        error_lv = ((pred_np == 3) != (mask_np == 3)).astype(np.uint8)
        # Pixels that are background in the ground truth but predicted as something else.
        error_background = ((mask_np == 0) & (pred_np != 0)).astype(np.uint8)
        combined_error = (pred_np != mask_np).astype(np.uint8)

        # Slice-level Dice (background included), shown in the figure title.
        mask_onehot = one_hot(mask, num_classes=4)
        pred_onehot = one_hot(pred, num_classes=4)
        dice_metric(y_pred=pred_onehot, y=mask_onehot)
        dice_score = dice_metric.aggregate().item()
        dice_metric.reset()

        panels = [
            (img_data_resized, "gray", f"\n{filename}"),
            (mask_np, "gray", "Ground Truth Mask"),
            (pred_np, "gray", "Total Segmentation"),
            (uncertainty_np, "hot", "Uncertainty"),
            (error_lv, "hot", "LV Error"),
            (error_rv, "hot", "RV Error"),
            (error_myo, "hot", "MYO Error"),
            (error_background, "hot", "Background Errors"),
            (combined_error, "hot", "Combined Error Map"),
        ]
        fig, axes = plt.subplots(1, len(panels), figsize=(22, 4))
        fig.suptitle(
            f"Patient: {patient_id}, Frame: {frame}, Slice: {slice_index}, Dice: {dice_score:.3f}",
            fontsize=14, fontweight='bold',
        )
        for ax, (panel, cmap, title) in zip(axes, panels):
            ax.imshow(panel, cmap=cmap)
            ax.set_title(title)
            ax.axis("off")
        fig.tight_layout()

        if save_plots:
            save_path = os.path.join(output_dir, f"Patient_{patient_id}_Frame_{frame}_Slice_{slice_index}.png")
            fig.savefig(save_path, bbox_inches='tight')
        plt.show()
        plt.close(fig)


In [None]:
#WITH ALL ERROR PLOTS 
from collections import defaultdict

# Create directory for saving plots
output_dir = "Patients_uncertainty_error"
os.makedirs(output_dir, exist_ok=True)

# Group the test slices per (patient, frame): one figure per group, one row per slice.
grouped_samples = defaultdict(list)
for idx in range(len(test_dataset)):
    sample = test_dataset[idx]
    pid = sample["img_meta_dict"]["patient_id"]
    frame = sample["img_meta_dict"].get("frame", "N/A")
    grouped_samples[(pid, frame)].append(sample)

# (title, colormap) for the 9 columns; column 0 is titled with the slice index instead.
panel_specs = [
    ("Image", "gray"),
    ("GT Mask", "gray"),
    ("Prediction", "gray"),
    ("Uncertainty", "hot"),
    ("LV Error", "hot"),
    ("RV Error", "hot"),
    ("MYO Error", "hot"),
    ("Background Error", "hot"),
    ("Combined Error", "hot"),
]

# Load the ensemble once for the whole cell; the checkpoints are the same for every slice.
ensemble_model = EnsembleModel(load_ensemble_models(model_paths, device), num_samples_mc=10)

with torch.no_grad():
    for (patient_id, frame), samples in grouped_samples.items():
        fig, axes = plt.subplots(len(samples), len(panel_specs), figsize=(24, 3 * len(samples)))
        fig.suptitle(f"Patient: {patient_id}, Frame: {frame}", fontsize=16, fontweight='bold')

        if len(samples) == 1:
            axes = [axes]  # plt.subplots returns a 1-D array for a single row

        for row, sample in enumerate(samples):
            slice_index = sample["img_meta_dict"].get("slice_index", 0)
            img = sample["img"].to(device).unsqueeze(0)
            mask = sample["mask"].to(device).unsqueeze(0)

            # mean_pred is already an average of softmax probabilities, so the
            # argmax is taken directly (another softmax cannot change it).
            mean_pred, uncertainty = ensemble_model(img)
            pred = torch.argmax(mean_pred, dim=1, keepdim=True)

            img_np = img.squeeze().cpu().numpy()
            mask_np = mask.squeeze().cpu().numpy()
            pred_np = pred.squeeze().cpu().numpy()
            # Average the per-class std over the leading (class) axis -> one 2-D map.
            uncertainty_np = uncertainty.squeeze().cpu().numpy().mean(axis=0)

            # Per-structure disagreement maps. Comparing boolean masks avoids uint8
            # subtraction, which wraps 0 - 1 to 255 and would give misses and false
            # positives very different intensities.
            error_rv = ((pred_np == 1) != (mask_np == 1)).astype(np.uint8)
            error_myo = ((pred_np == 2) != (mask_np == 2)).astype(np.uint8)
            error_lv = ((pred_np == 3) != (mask_np == 3)).astype(np.uint8)
            error_background = ((mask_np == 0) & (pred_np != 0)).astype(np.uint8)
            combined_error = (pred_np != mask_np).astype(np.uint8)

            panels = [img_np, mask_np, pred_np, uncertainty_np,
                      error_lv, error_rv, error_myo, error_background, combined_error]
            for col, (ax, panel, (title, cmap)) in enumerate(zip(axes[row], panels, panel_specs)):
                ax.imshow(panel, cmap=cmap)
                ax.set_title(f"Slice {slice_index}" if col == 0 else title)
                ax.axis("off")

        # tight_layout only accounts for titles that already exist, so run it after
        # drawing; rect leaves room for the suptitle.
        fig.tight_layout(rect=[0, 0, 1, 0.95])
        save_path = os.path.join(output_dir, f"Patient_{patient_id}_Frame_{frame}_ALL_SLICES.png")
        fig.savefig(save_path, bbox_inches='tight')
        plt.close(fig)

In [None]:
import matplotlib.pyplot as plt
from collections import defaultdict
import torch
import numpy as np
import os

# Create directory for saving plots
# NOTE(review): this cell writes the same file name into the same directory as the
# earlier patient149 report cell, so whichever runs last wins.
output_dir = "Report_uncertainty_error"
os.makedirs(output_dir, exist_ok=True)

# Group the test slices per (patient, frame): one figure per group, one row per slice.
grouped_samples = defaultdict(list)
for idx in range(len(test_dataset)):
    sample = test_dataset[idx]
    pid = sample["img_meta_dict"]["patient_id"]
    frame = sample["img_meta_dict"].get("frame", "N/A")
    grouped_samples[(pid, frame)].append(sample)

# (title, colormap) per column; column 0 is titled with the slice index instead.
panel_specs = [
    ("Image", "gray"),
    ("GT Mask", "gray"),
    ("Prediction", "gray"),
    ("Uncertainty", "hot"),
    ("Total Error", "hot"),
]

with torch.no_grad():
    for (patient_id, frame), samples in grouped_samples.items():
        # Only the report example patient149 / frame12 is rendered.
        if not (str(patient_id) == "patient149" and str(frame) == "frame12"):
            continue

        # Load the ensemble once per figure rather than once per slice.
        ensemble_model = EnsembleModel(load_ensemble_models(model_paths, device), num_samples_mc=10)

        # Smaller figure than the earlier variant, with tight spacing applied below.
        fig, axes = plt.subplots(len(samples), len(panel_specs), figsize=(20, 3 * len(samples)))
        fig.suptitle(f"Patient: {patient_id}, Frame: {frame}", fontsize=16, fontweight='bold')

        if len(samples) == 1:
            axes = [axes]  # plt.subplots returns a 1-D array for a single row

        for row, sample in enumerate(samples):
            slice_index = sample["img_meta_dict"].get("slice_index", 0)
            img = sample["img"].to(device).unsqueeze(0)
            mask = sample["mask"].to(device).unsqueeze(0)

            # mean_pred is already an average of softmax probabilities, so the
            # argmax is taken directly (another softmax cannot change it).
            mean_pred, uncertainty = ensemble_model(img)
            pred = torch.argmax(mean_pred, dim=1, keepdim=True)

            img_np = img.squeeze().cpu().numpy()
            mask_np = mask.squeeze().cpu().numpy()
            pred_np = pred.squeeze().cpu().numpy()
            # Average the per-class std over the leading (class) axis -> one 2-D map.
            uncertainty_np = uncertainty.squeeze().cpu().numpy().mean(axis=0)
            combined_error = (pred_np != mask_np).astype(np.uint8)

            panels = [img_np, mask_np, pred_np, uncertainty_np, combined_error]
            for col, (ax, panel, (title, cmap)) in enumerate(zip(axes[row], panels, panel_specs)):
                ax.imshow(panel, cmap=cmap)
                ax.set_title(f"Slice {slice_index}" if col == 0 else title)
                ax.axis("off")

        # tight_layout only accounts for titles that already exist, so run it after drawing.
        fig.tight_layout(pad=0.4, w_pad=0.5, h_pad=1.0)
        save_path = os.path.join(output_dir, f"Patient_{patient_id}_Frame_{frame}_ALL_SLICES.png")
        fig.savefig(save_path, bbox_inches='tight')
        plt.close(fig)


In [None]:
import matplotlib.pyplot as plt
from collections import defaultdict
import torch
import numpy as np
import os

# Create directory for saving plots
output_dir = "Report_uncertainty_error"
os.makedirs(output_dir, exist_ok=True)

# Group the test slices per (patient, frame).
grouped_samples = defaultdict(list)
for idx in range(len(test_dataset)):
    sample = test_dataset[idx]
    pid = sample["img_meta_dict"]["patient_id"]
    frame = sample["img_meta_dict"].get("frame", "N/A")
    grouped_samples[(pid, frame)].append(sample)

# Slices of patient107 / frame01 shown in the report figure.
selected_slice_indices = [0, 1, 3, 5, 6, 7]
# (title, colormap) per column; column 0 is titled with the slice index instead.
panel_specs = [
    ("Image", "gray"),
    ("GT Mask", "gray"),
    ("Prediction", "gray"),
    ("Uncertainty", "hot"),
    ("Total Error", "hot"),
]

with torch.no_grad():
    for (patient_id, frame), samples in grouped_samples.items():
        if not (str(patient_id) == "patient107" and str(frame) == "frame01"):
            continue

        selected_samples = [
            s for s in samples
            if s["img_meta_dict"].get("slice_index", 0) in selected_slice_indices
        ]
        if not selected_samples:
            # plt.subplots cannot build a figure with zero rows.
            print(f"No selected slices found for {patient_id}, {frame}; skipping.")
            continue

        # Load the ensemble once per figure rather than once per slice.
        ensemble_model = EnsembleModel(load_ensemble_models(model_paths, device), num_samples_mc=10)

        fig, axes = plt.subplots(len(selected_samples), len(panel_specs), figsize=(20, 3 * len(selected_samples)))
        fig.subplots_adjust(hspace=0.2, wspace=0.05, top=0.95)
        fig.suptitle(f"Patient: {patient_id}, Frame: {frame}", fontsize=16, fontweight='bold')

        if len(selected_samples) == 1:
            axes = [axes]  # plt.subplots returns a 1-D array for a single row

        for row, sample in enumerate(selected_samples):
            slice_index = sample["img_meta_dict"].get("slice_index", 0)
            img = sample["img"].to(device).unsqueeze(0)
            mask = sample["mask"].to(device).unsqueeze(0)

            # mean_pred is already an average of softmax probabilities, so the
            # argmax is taken directly (another softmax cannot change it).
            mean_pred, uncertainty = ensemble_model(img)
            pred = torch.argmax(mean_pred, dim=1, keepdim=True)

            img_np = img.squeeze().cpu().numpy()
            mask_np = mask.squeeze().cpu().numpy()
            pred_np = pred.squeeze().cpu().numpy()
            # Average the per-class std over the leading (class) axis -> one 2-D map.
            uncertainty_np = uncertainty.squeeze().cpu().numpy().mean(axis=0)
            combined_error = (pred_np != mask_np).astype(np.uint8)

            panels = [img_np, mask_np, pred_np, uncertainty_np, combined_error]
            for col, (ax, panel, (title, cmap)) in enumerate(zip(axes[row], panels, panel_specs)):
                ax.imshow(panel, cmap=cmap)
                ax.set_title(f"Slice {slice_index}" if col == 0 else title)
                ax.axis("off")

        # Name the file after the slices actually plotted, so name and content
        # cannot drift apart (the hand-written "13567" omitted slice 0).
        slice_tag = "".join(str(s["img_meta_dict"].get("slice_index", 0)) for s in selected_samples)
        save_path = os.path.join(output_dir, f"Patient_{patient_id}_Frame_{frame}_SLICES_{slice_tag}.png")
        fig.savefig(save_path, dpi=200, bbox_inches='tight')
        plt.close(fig)


In [None]:
# Bundle every per-patient test figure into a single ZIP archive.
output_dir = "Patients_uncertainty_error"
archive_base_name = "Patients_uncertainty_error"
zip_filename = archive_base_name + ".zip"

# Create a ZIP of the output directory.
shutil.make_archive(archive_base_name, 'zip', output_dir)

print(f"✅ ZIP-bestand gemaakt: {zip_filename}")