# Uncertainty Heatmaps
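This notebook generates voxel-wise uncertainty heatmaps from sets of inverted segmentation masks (ISMs).
Please note that get_corresponding_ISMs() matches ISMs to each ground truth segmentation mask by file name, so it may fail (or match the wrong files) depending on the naming convention used during the inversion process.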

In [None]:
import logging
import os
import re
import tkinter as tk
import traceback
from tkinter import filedialog

import nibabel as nib
import numpy as np
from tqdm import tqdm

#### pick_files():
Opens a series of file dialogs for the user to select directories and returns the selected paths.

The function prompts the user to select the following directories:
1. The folder containing the ground truth segmentation masks (GTSMs).
2. The folder containing the inverted segmentation masks (ISMs).
3. The output directory where processed files will be saved.
4. The directory for storing the logging file.

Returns:
    tuple: A tuple containing four paths (GTSM_dir_path, ISM_dir_path, output_path, logging_output_path).

In [None]:
def pick_files():
    """Ask the user, via Tk directory dialogs, for the input, output and logging folders.

    Four dialogs are shown in order: ground truth segmentation masks (GTSMs),
    inverted segmentation masks (ISMs), output directory, logging directory.
    Each selection is echoed to stdout.

    Returns:
        tuple: (GTSM_dir_path, ISM_dir_path, output_path, logging_output_path),
            exactly as returned by ``filedialog.askdirectory``. A cancelled
            dialog presumably yields an empty value, so check before use.
    """
    root = tk.Tk()
    root.withdraw()  # Hide the main tkinter window; only the dialogs should appear.
    try:
        GTSM_dir_path = filedialog.askdirectory(title="Select the folder with the GROUND TRUTH SEGMENTATION MASKS")
        print(f"Selected folder: {GTSM_dir_path}")
        ISM_dir_path = filedialog.askdirectory(title="Select the folder with the INVERTED SEGMENTATION MASKS")
        print(f"Selected folder: {ISM_dir_path}")
        output_path = filedialog.askdirectory(title="Select the LOCATION for the OUTPUT FILES")
        print(f"Selected folder: {output_path}")
        logging_output_path = filedialog.askdirectory(title="Select the LOCATION for the LOGGING FILE")
        print(f"Selected folder: {logging_output_path}")
    finally:
        # Tear down the hidden root window; otherwise it (and its Tcl interpreter)
        # stays alive for the rest of the session, e.g. in a notebook kernel.
        root.destroy()
    return GTSM_dir_path, ISM_dir_path, output_path, logging_output_path

#### load_nifti_file():
Load a NIfTI file and return its data as a numpy array.

Parameters

    filepath : str
        The path to the NIfTI file to load.

Returns

    ndarray
        The data of the NIfTI file as a numpy array.

In [None]:
def load_nifti_file(filepath):
    """Read the NIfTI image stored at *filepath* and return its voxel data.

    The image is opened with ``nib.load`` and its data is materialised with
    ``get_fdata()``, which yields a numpy array.
    """
    image = nib.load(filepath)
    voxel_data = image.get_fdata()
    return voxel_data

#### calculate_uncertainty():
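Calculate the voxel-wise uncertainty map from a set of segmentations as described in the paper: for each voxel, $U = 1 - \frac{n_{\text{mode}}}{N}$, where $N$ is the number of segmentations and $n_{\text{mode}}$ is the number of segmentations that assign the most frequent label to that voxel.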

Parameters

    labels : ndarray
        Shape: (num_segmentations, height, width, depth)
        The set of segmentations to calculate uncertainty from.

Returns

    uncertainty_map : ndarray
        Shape: (height, width, depth)
        The uncertainty map, where each voxel value represents the uncertainty
        in the segmentation.

In [None]:
def calculate_uncertainty(labels):  
    labels = labels.astype(int)
    num_segmentations, height, width, depth = labels.shape
    labels_reshaped = labels.reshape(num_segmentations, -1)  # Shape: (num_segmentations, height*width*depth)

    # Use np.apply_along_axis to compute uncertainty
    def calculate_for_voxel(voxel_labels):        
        counts = np.bincount(voxel_labels)
        mode_label_count = np.max(counts)
        agreement_fraction = mode_label_count / len(voxel_labels)
        uncertainty = 1 - agreement_fraction
        return uncertainty

    # Vectorized computation across all voxels
    uncertainty_results = np.apply_along_axis(calculate_for_voxel, axis=0, arr=labels_reshaped)

    # Reshape the flat uncertainty map back to the original spatial dimensions
    uncertainty_map = uncertainty_results.reshape(height, width, depth)

    return uncertainty_map

#### get_corresponding_ISMs():
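Given a path to a ground truth segmentation mask (GTSM) and a list of
paths to inverted segmentation masks (ISMs), load the ISMs that correspond
to the GTSM and return their data. Correspondence is decided by file name:
an ISM matches when its name contains the part of the GTSM file name that
precedes `_SEGMENTATION`.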

Parameters

    GTSM_path : str
        The path to the GTSM.
    ISM_paths : list of str
        A list of paths to the ISMs.

Returns

    list of ndarray
        A list of the ISMs that correspond to the GTSM.

In [None]:
def _strip_nifti_extension(file_name):
    """Return *file_name* without a trailing '.nii.gz' / '.nii' (or any other single extension)."""
    for extension in ('.nii.gz', '.nii'):
        if file_name.endswith(extension):
            return file_name[:-len(extension)]
    return os.path.splitext(file_name)[0]


def get_corresponding_ISMs(GTSM_path, ISM_paths, loader=None):
    """Load the inverted segmentation masks (ISMs) that belong to a ground truth mask (GTSM).

    The case identifier is the GTSM file name without its NIfTI extension,
    truncated at the first '_SEGMENTATION'. An ISM corresponds to the GTSM
    when its extension-less file name contains that identifier. If the
    identifier starts or ends with a digit, the match must not be extended by
    another digit on that side, so 'case1' does not pick up 'case10_...' or
    'case12_...'.

    Parameters
    ----------
    GTSM_path : str
        Path to the GTSM.
    ISM_paths : list of str
        Paths to candidate ISMs.
    loader : callable, optional
        Function applied to each matching ISM path. Defaults to
        ``load_nifti_file``.

    Returns
    -------
    list
        ``loader(path)`` for every matching ISM, in the order of ``ISM_paths``.
    """
    if loader is None:
        loader = load_nifti_file

    GTSM_stem = _strip_nifti_extension(os.path.basename(GTSM_path))
    case_id = GTSM_stem.split('_SEGMENTATION')[0]

    pattern = re.escape(case_id)
    # Guard numeric edges of the identifier so one case number cannot be a
    # prefix/suffix of another (e.g. 'case1' vs 'case10').
    if case_id[:1].isdecimal():
        pattern = r'(?<!\d)' + pattern
    if case_id[-1:].isdecimal():
        pattern += r'(?!\d)'
    case_id_regex = re.compile(pattern)

    corresponding_label_masks = []
    for ISM_path in ISM_paths:
        ISM_stem = _strip_nifti_extension(os.path.basename(ISM_path))
        if case_id_regex.search(ISM_stem):
            corresponding_label_masks.append(loader(ISM_path))
    return corresponding_label_masks

#### get_nifti_file_paths():
Retrieves the file paths of NIfTI files in the specified directory.

This function searches for files with '.nii' or '.nii.gz' extensions
in the given directory and returns their full paths.

Args:
    directory_path (str): The path to the directory containing NIfTI files.

Returns:
    list of str: A list of full file paths for each NIfTI file found in the directory.

In [None]:
def get_nifti_file_paths(directory_path):
    """Return the full paths of the NIfTI files ('.nii' / '.nii.gz') directly inside *directory_path*.

    Only regular files are returned (sub-directories are skipped, no recursion).
    The result is sorted by file name because ``os.listdir`` returns entries in
    arbitrary order, and a fixed order keeps runs and logs reproducible.
    """
    nifti_paths = []
    for entry_name in sorted(os.listdir(directory_path)):
        full_path = os.path.join(directory_path, entry_name)
        if entry_name.endswith(('.nii', '.nii.gz')) and os.path.isfile(full_path):
            nifti_paths.append(full_path)
    return nifti_paths

### Running the methods

In [None]:
# Pick directories and set up logging
GTSM_dir_path, ISM_dir_path, output_path, logging_output_path = pick_files()
if not all((GTSM_dir_path, ISM_dir_path, output_path, logging_output_path)):
    # A cancelled dialog presumably returns an empty value; without this check
    # os.listdir('') further down fails with an unhelpful error.
    raise SystemExit("A directory selection was cancelled; nothing to process.")
GTSM_paths = get_nifti_file_paths(GTSM_dir_path)
ISM_paths = get_nifti_file_paths(ISM_dir_path)
logging_output_path = os.path.join(logging_output_path, 'heatmap_error.log')
logging.basicConfig(
    filename=logging_output_path,
    level=logging.ERROR,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# For each ground truth segmentation mask (GTSM): compute the uncertainty heatmap
# from its inverted segmentation masks (ISMs), give it the GTSM's affine, and save it.
for GTSM_path in tqdm(GTSM_paths, total=len(GTSM_paths), desc='Calculating Uncertainty Heatmap'):
    try:
        GTSM_file = nib.load(GTSM_path)
        corresponding_ISMs = get_corresponding_ISMs(GTSM_path, ISM_paths)
        if not corresponding_ISMs:
            logging.error(f"No inverted segmentation masks found for {GTSM_path}, skipping...")
            continue
        heatmap = calculate_uncertainty(np.array(corresponding_ISMs))

        # Strip '.nii.gz' or '.nii' so the output is always '<stem>_HEATMAP.nii.gz'.
        GTSM_file_name = os.path.basename(GTSM_path)
        if GTSM_file_name.endswith('.nii.gz'):
            GTSM_stem = GTSM_file_name[:-len('.nii.gz')]
        else:
            GTSM_stem = os.path.splitext(GTSM_file_name)[0]
        output_file = os.path.join(output_path, f"{GTSM_stem}_HEATMAP.nii.gz")
        nib.save(nib.Nifti1Image(heatmap, GTSM_file.affine), output_file)
    except Exception as e:
        # Best-effort batch: record the failure with its traceback and move on.
        logging.exception(f"Error processing {GTSM_path}: {e}, continuing...")