# In [1]:
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact
import imageio

# Absolute paths of the three NIfTI files displayed side by side
original_image_path = "/home/fp427/rds/rds-cam-segm-7tts6phZ4tw/mission/tof/tof_input/tof.nii.gz"
proxy_path = "/home/fp427/rds/rds-cam-segm-7tts6phZ4tw/mission/tof/tof_tta_prediction/proxies/tof.nii.gz"
tta_prediction_path = "/home/fp427/rds/rds-cam-segm-7tts6phZ4tw/mission/tof/tof_tta_prediction/tof.nii.gz"

# Display name -> file path; insertion order is the left-to-right panel order
dict_image_paths = dict(
    original=original_image_path,
    proxy=proxy_path,
    tta=tta_prediction_path,
)

# Open each file with nibabel, keyed by the same display names
dict_images = dict(
    zip(dict_image_paths, map(nib.load, dict_image_paths.values()))
)

# Function to visualize slices interactively
def interactive_slice_viewer(dict_images, dict_image_paths):
    """
    Visualize images interactively with an ipywidgets slice viewer.

    One panel is drawn per entry of ``dict_images``. A slider selects the
    index along the third array axis, and that slice is shown in every panel.
    The image keyed ``"original"`` uses the 'gray' colormap; all others use
    'hot'.

    Parameters:
        dict_images (dict): Dictionary of loaded NIfTI images with keys as names and values as nibabel objects.
        dict_image_paths (dict): Dictionary of image paths with keys as names and values as paths.

    Raises:
        ValueError: If ``dict_images`` is empty.
    """
    # Local import: the file only imports matplotlib.pyplot at module level
    from matplotlib.patches import Patch

    if not dict_images:
        raise ValueError("dict_images must contain at least one image")

    # Extract the data from nibabel images
    data_dict = {key: image.get_fdata() for key, image in dict_images.items()}

    # Limit the slider to the shallowest volume so every panel can be indexed
    # even when the images do not all have the same number of slices
    n_slices = min(data.shape[2] for data in data_dict.values())

    def plot_slice(slice_index):
        """
        Plot the selected slice with images and legend.

        Parameters:
            slice_index (int): Index of the slice to visualize.
        """
        # squeeze=False always returns a 2-D array of Axes, so a single image
        # works the same way as several
        fig, axs = plt.subplots(1, len(data_dict), figsize=(15, 5), squeeze=False)
        fig.suptitle(f"Slice Index: {slice_index}", fontsize=16)

        for ax, (key, data) in zip(axs[0], data_dict.items()):
            ax.imshow(data[:, :, slice_index], cmap='gray' if key == "original" else 'hot', interpolation='nearest')
            ax.set_title(f"{key.capitalize()} Image")
            ax.axis('off')

        # Add a legend with filenames. imshow artists provide no legend
        # handles, so invisible proxy patches carry the text entries.
        legend_handles = [
            Patch(facecolor='none', edgecolor='none', label=f"{key}: {path}")
            for key, path in dict_image_paths.items()
        ]
        fig.legend(
            handles=legend_handles,
            loc='upper center',
            bbox_to_anchor=(0.5, -0.05),
            ncol=2,
            fontsize=10,
            handlelength=0,
            handletextpad=0,
        )
        plt.tight_layout()
        plt.show()

    # Create the interactive viewer
    interact(plot_slice, slice_index=(0, n_slices - 1))

# Function to save a GIF of slice animations
def save_slices_as_gif(dict_images, output_path, fps=25):
    """
    Save an animated GIF of the slices for all images.

    Each frame is one figure with one panel per image, all showing the same
    index along the third array axis. The image keyed ``"original"`` uses the
    'gray' colormap; all others use 'hot'.

    Parameters:
        dict_images (dict): Dictionary of loaded NIfTI images.
        output_path (str): Path to save the GIF.
        fps (int): Frames per second for the GIF.

    Raises:
        ValueError: If ``dict_images`` is empty.
    """
    if not dict_images:
        raise ValueError("dict_images must contain at least one image")

    data_dict = {key: image.get_fdata() for key, image in dict_images.items()}

    # Stop at the shallowest volume so every panel can be indexed
    n_slices = min(data.shape[2] for data in data_dict.values())
    n_panels = len(data_dict)

    images = []
    for slice_index in range(n_slices):
        # squeeze=False always returns a 2-D array of Axes, so a single image
        # works the same way as several
        fig, axs = plt.subplots(1, n_panels, figsize=(15, 5), squeeze=False)
        try:
            fig.suptitle(f"Slice Index: {slice_index}", fontsize=16)

            for ax, (key, data) in zip(axs[0], data_dict.items()):
                ax.imshow(data[:, :, slice_index], cmap='gray' if key == "original" else 'hot', interpolation='nearest')
                ax.set_title(f"{key.capitalize()} Image")
                ax.axis('off')

            # Render, then read the pixels back. buffer_rgba() replaces the
            # removed tostring_rgb() and already has shape (height, width, 4).
            fig.canvas.draw()
            rgba = np.asarray(fig.canvas.buffer_rgba())
            # Drop alpha and copy, since the buffer belongs to the canvas
            # that is closed below
            images.append(rgba[..., :3].copy())
        finally:
            # Always release the figure, even if drawing failed
            plt.close(fig)

    # Save all collected images as a GIF
    imageio.mimsave(output_path, images, fps=fps)

# Visualize the images interactively
# interactive_slice_viewer(dict_images, dict_image_paths)

# Save slices as a GIF
# save_slices_as_gif(dict_images, "/home/fp427/rds/rds-cam-segm-7tts6phZ4tw/mission/tof/tof.gif")
