In [None]:
import os
import nibabel as nib

# Input image and its segmentation to visualise.
input_img_path = "/home/fp427/rds/rds-cam-segm-7tts6phZ4tw/mission/nnunet/nnUNet_raw/Dataset008_synthseg/imagesTs/synthseg_0000_0001.nii.gz"
segmentation_img_path = "/home/fp427/rds/rds-cam-segm-7tts6phZ4tw/mission/nnunet/nnUNet_results/Dataset008_synthseg/synthseg_0000.nii.gz"

import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider
from matplotlib.colors import ListedColormap

# Load the images
input_img = nib.load(input_img_path)
segmentation_img = nib.load(segmentation_img_path)

input_data = input_img.get_fdata()
# get_fdata() returns floating-point values. Round before the integer cast so that a
# label coming back fractionally below its integer value (e.g. after header scaling)
# is not truncated to the neighbouring label.
seg_data = np.rint(segmentation_img.get_fdata()).astype(int)

# Both volumes are later indexed as [:, :, z] with z ranging over the input's third
# axis. They must therefore be 3-D and share a shape. Fail here with a clear message
# rather than with an IndexError (or a silently misaligned overlay) later on.
if input_data.ndim != 3:
    raise ValueError(f"Expected a 3-D input volume, got shape {input_data.shape}")
if seg_data.shape != input_data.shape:
    raise ValueError(
        f"Segmentation shape {seg_data.shape} does not match input shape {input_data.shape}"
    )

# Extract shape (X, Y, Z)
X, Y, Z = input_data.shape

# Define the label mapping (from the given JSON).
# The label values are consecutive integers starting at 0, so the mapping is
# generated from the ordered list of names: each name maps to its list position.
labels = {
    name: index
    for index, name in enumerate([
        "background",
        "left_cerebral_white_matter/right_cerebral_white_matter",
        "left_cerebral_cortex/right_cerebral_cortex",
        "left_lateral_ventricle/right_lateral_ventricle",
        "left_inferior_lateral_ventricle/right_inferior_lateral_ventricle",
        "left_cerebellum_white_matter/right_cerebellum_white_matter",
        "left_cerebellum_cortex/right_cerebellum_cortex",
        "left_thalamus_proper/right_thalamus_proper",
        "left_caudate/right_caudate",
        "left_putamen/right_putamen",
        "left_pallidum/right_pallidum",
        "left_hippocampus/right_hippocampus",
        "left_amygdala/right_amygdala",
        "left_accumbens_area/right_accumbens_area",
        "left_ventral_dc/right_ventral_dc",
        "left_vessel/right_vessel",
        "left_choroid_plexus/right_choroid_plexus",
        "third_ventricle",
        "fourth_ventricle",
        "brain_stem",
        "csf",
        "left_undetermined",
        "fifth_ventricle",
        "wm_hypointensities",
        "non_wm_hypointensities",
        "optic_chiasm",
        "air_internal",
        "artery",
        "eyes",
        "other_tissues",
        "rectus_muscles",
        "mucosa",
        "skin",
        "spinal_cord",
        "vein",
        "bone_cortical",
        "bone_cancellous",
        "cortical_csf",
        "optic_nerve",
    ])
}

# Build one colour per label value (0 .. max label) for the segmentation overlay.
# Background (0) is black; every other label gets its own opaque colour.
num_classes = max(labels.values()) + 1
colors = np.zeros((num_classes, 4))
colors[0] = [0, 0, 0, 1]  # black for background

# tab20 has only 20 entries. Sampling it at num_classes points (39 here) therefore
# maps pairs of labels onto the same colour. Stacking the three 20-colour qualitative
# palettes (tab20, tab20b, tab20c) gives 60 distinct colours instead.
palette = np.vstack([plt.cm.tab20.colors, plt.cm.tab20b.colors, plt.cm.tab20c.colors])
if num_classes - 1 > len(palette):
    raise ValueError(
        f"{num_classes - 1} foreground labels but only {len(palette)} distinct palette colours"
    )
colors[1:, :3] = palette[:num_classes - 1]
colors[1:, 3] = 1.0  # fully opaque

# Create a ListedColormap: entry i is the colour of label value i.
seg_cmap = ListedColormap(colors)

# Prepare a nice legend for the classes: one coloured patch per label, used as custom
# legend handles. Every label is listed, whether or not it occurs in this volume.
import matplotlib.patches as mpatches

label_patches = [
    mpatches.Patch(color=seg_cmap(value), label=f"{name} ({value})")
    for name, value in labels.items()
]

def plot_slice(z):
    """Show slice ``z`` of the input MRI next to its segmentation, with a label legend.

    Reads the notebook globals ``input_data``, ``seg_data``, ``seg_cmap`` and
    ``label_patches``. It is meant to be driven by an ``interact`` slider.

    Parameters
    ----------
    z : int
        Index along the third array axis of both volumes.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    fig.suptitle(f"Slice {z}", fontsize=16)

    # Left: input image (transposed and drawn with origin='lower').
    axes[0].imshow(input_data[:, :, z].T, cmap='gray', origin='lower')
    axes[0].set_title("Input MRI")
    axes[0].axis('off')

    # Right: segmentation. Pin the colour range to the full label range 0 .. N-1.
    # Otherwise imshow rescales the colormap to the min/max label present in this
    # slice, and the colours stop matching the legend, which looks colours up by
    # label value.
    axes[1].imshow(
        seg_data[:, :, z].T,
        cmap=seg_cmap,
        origin='lower',
        interpolation='nearest',
        vmin=0,
        vmax=seg_cmap.N - 1,
    )
    axes[1].set_title("Segmentation")
    axes[1].axis('off')

    # Add the legend in its own axes to the right of the two images.
    fig.subplots_adjust(right=0.8)
    legend_ax = fig.add_axes([0.82, 0.1, 0.15, 0.8])
    legend_ax.axis('off')
    legend_ax.legend(handles=label_patches, loc='upper left', fontsize='small', borderaxespad=0.)

    plt.show()

# Explore the volume slice by slice with an ipywidgets slider, starting at the middle slice.
slice_slider = IntSlider(min=0, max=Z - 1, step=1, value=Z // 2)
interact(plot_slice, z=slice_slider)

# GIF export settings. They are not used in this cell; the GIF cell defines its own copies.
gif_output_path = "/home/fp427/rds/rds-cam-segm-7tts6phZ4tw/mission/gif_outputs"
gif_file_name = "synaptive_t1_anatomical_segmentation.gif"

interactive(children=(IntSlider(value=77, description='z', max=154), Output()), _dom_classes=('widget-interact…

<function __main__.plot_slice(z)>

In [1]:
import os
import imageio
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import matplotlib.patches as mpatches

# Paths
input_img_path = "/home/fp427/rds/rds-cam-segm-7tts6phZ4tw/mission/nnunet/nnUNet_raw/Dataset008_synthseg/imagesTs/synthseg_0000_0001.nii.gz"
segmentation_img_path = "/home/fp427/rds/rds-cam-segm-7tts6phZ4tw/mission/nnunet/nnUNet_results/Dataset008_synthseg/synthseg_0000.nii.gz"

gif_output_path = "/home/fp427/rds/rds-cam-segm-7tts6phZ4tw/mission/gif_outputs"
gif_file_name = "synaptive_t1_anatomical_segmentation.gif"

# Load the images
input_img = nib.load(input_img_path)
segmentation_img = nib.load(segmentation_img_path)

input_data = input_img.get_fdata()
# get_fdata() returns floating-point values. Round before the integer cast so that a
# label coming back fractionally below its integer value (e.g. after header scaling)
# is not truncated to the neighbouring label.
seg_data = np.rint(segmentation_img.get_fdata()).astype(int)

# Every frame indexes both volumes as [:, :, z] with z ranging over the input's third
# axis. They must therefore be 3-D and share a shape. Fail before rendering any frames.
if input_data.ndim != 3:
    raise ValueError(f"Expected a 3-D input volume, got shape {input_data.shape}")
if seg_data.shape != input_data.shape:
    raise ValueError(
        f"Segmentation shape {seg_data.shape} does not match input shape {input_data.shape}"
    )

# Extract shape (X, Y, Z)
X, Y, Z = input_data.shape

# Define the label mapping (from the given JSON).
# The label values are consecutive integers starting at 0, so the mapping is
# generated from the ordered list of names: each name maps to its list position.
labels = {
    name: index
    for index, name in enumerate([
        "background",
        "left_cerebral_white_matter/right_cerebral_white_matter",
        "left_cerebral_cortex/right_cerebral_cortex",
        "left_lateral_ventricle/right_lateral_ventricle",
        "left_inferior_lateral_ventricle/right_inferior_lateral_ventricle",
        "left_cerebellum_white_matter/right_cerebellum_white_matter",
        "left_cerebellum_cortex/right_cerebellum_cortex",
        "left_thalamus_proper/right_thalamus_proper",
        "left_caudate/right_caudate",
        "left_putamen/right_putamen",
        "left_pallidum/right_pallidum",
        "left_hippocampus/right_hippocampus",
        "left_amygdala/right_amygdala",
        "left_accumbens_area/right_accumbens_area",
        "left_ventral_dc/right_ventral_dc",
        "left_vessel/right_vessel",
        "left_choroid_plexus/right_choroid_plexus",
        "third_ventricle",
        "fourth_ventricle",
        "brain_stem",
        "csf",
        "left_undetermined",
        "fifth_ventricle",
        "wm_hypointensities",
        "non_wm_hypointensities",
        "optic_chiasm",
        "air_internal",
        "artery",
        "eyes",
        "other_tissues",
        "rectus_muscles",
        "mucosa",
        "skin",
        "spinal_cord",
        "vein",
        "bone_cortical",
        "bone_cancellous",
        "cortical_csf",
        "optic_nerve",
    ])
}

# Create a colormap for segmentation: one colour per label value (0 .. max label),
# with background (0) black and every other label a distinct opaque colour.
num_classes = max(labels.values()) + 1
colors = np.zeros((num_classes, 4))
colors[0] = [0, 0, 0, 1]  # black for background

# Two problems with the previous plt.cm.get_cmap('hsv', num_classes) approach:
# - plt.cm.get_cmap is deprecated and removed in Matplotlib 3.9.
# - 'hsv' is cyclic, so its first and last samples are both red and neighbouring
#   labels differ only by small hue steps.
# Stacking the three 20-colour qualitative palettes gives 60 distinct colours instead.
palette = np.vstack([plt.cm.tab20.colors, plt.cm.tab20b.colors, plt.cm.tab20c.colors])
if num_classes - 1 > len(palette):
    raise ValueError(
        f"{num_classes - 1} foreground labels but only {len(palette)} distinct palette colours"
    )
colors[1:, :3] = palette[:num_classes - 1]
colors[1:, 3] = 1.0  # fully opaque

# Entry i of the colormap is the colour of label value i.
seg_cmap = ListedColormap(colors)

# Create legend patches: one coloured patch per label, labelled "name (value)".
label_patches = [
    mpatches.Patch(color=seg_cmap(value), label=f"{name} ({value})")
    for name, value in labels.items()
]

# Create output directories if not existing
os.makedirs(gif_output_path, exist_ok=True)
temp_dir = os.path.join(gif_output_path, "temp_frames")
os.makedirs(temp_dir, exist_ok=True)

# Render one PNG frame per slice along the third axis: input on the left,
# segmentation on the right, label legend alongside.
frames = []
for z_idx in range(Z):
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    try:
        fig.suptitle(f"Slice {z_idx}", fontsize=16)

        # Left: Input MRI
        axes[0].imshow(input_data[:, :, z_idx].T, cmap='gray', origin='lower')
        axes[0].set_title("Input MRI")
        axes[0].axis('off')

        # Right: Segmentation. vmin/vmax pin the colour range to the full label range.
        # Without them imshow rescales to the labels present in each slice, so colours
        # drift from frame to frame and no longer match the legend.
        axes[1].imshow(
            seg_data[:, :, z_idx].T,
            cmap=seg_cmap,
            origin='lower',
            interpolation='nearest',
            vmin=0,
            vmax=seg_cmap.N - 1,
        )
        axes[1].set_title("Segmentation")
        axes[1].axis('off')

        # Add legend
        fig.subplots_adjust(right=0.8)
        legend_ax = fig.add_axes([0.82, 0.1, 0.15, 0.8])
        legend_ax.axis('off')
        legend_ax.legend(handles=label_patches, loc='upper left', fontsize='small', borderaxespad=0.)

        # Save this figure explicitly rather than pyplot's implicit "current figure".
        frame_path = os.path.join(temp_dir, f"frame_{z_idx:03d}.png")
        fig.savefig(frame_path, dpi=72)
    finally:
        # Always release the figure so open figures do not accumulate across slices.
        plt.close(fig)
    frames.append(frame_path)

# Create a GIF from the saved frames. imageio.v2.imread keeps the current imageio.imread
# behaviour without the ImageIO-v3 DeprecationWarning.
images = [imageio.v2.imread(frame) for frame in frames]
imageio.mimsave(os.path.join(gif_output_path, gif_file_name), images, fps=5)  # Adjust fps as needed


  cmap_base = plt.cm.get_cmap('hsv', num_classes)
  images = [imageio.imread(frame) for frame in frames]
