In [6]:
%matplotlib inline

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output
import os

# Location of one patient / timepoint scan. These are Windows drive paths, so
# the pieces are joined with a literal backslash (not os.path.join) to produce
# the same strings on any OS. Change PATIENT_ID / TIMEPOINT to view another scan.
COHORT_DIR = r"Z:\Angie\SMILE_facialdeformation\StJude_cohort\abby"
PATIENT_ID = "1720903"
TIMEPOINT = "TP1"
PATIENT_DIR = "\\".join([COHORT_DIR, PATIENT_ID, TIMEPOINT])

# Directory holding the TotalSegmentator output masks for this scan.
base_path = "\\".join([PATIENT_DIR, "Total_Segmentator"])

# CT volume displayed underneath the segmentation overlays.
ct_path = "\\".join([PATIENT_DIR, f"{PATIENT_ID}_{TIMEPOINT}.nii"])

# Structures to overlay, stored as "<name>.nii.gz" inside base_path.
# Overlays are drawn in this order, so later entries are painted on top.
SEGMENT_NAMES = (
    "head",
    "mandible",
    "skull",
    "sinus_maxillary",
    "sinus_frontal",
    "teeth_lower",
    "teeth_upper",
)
seg_files = {name: f"{name}.nii.gz" for name in SEGMENT_NAMES}

# Load the CT as float32 (half the memory of get_fdata's float64 default).
# caching="unchanged" stops nibabel from keeping its own copy, so the in-place
# rescale below cannot corrupt ct_img's cached data.
ct_img = nib.load(ct_path)
X = ct_img.get_fdata(caching="unchanged", dtype=np.float32)

# Min-max normalise to [0, 1] for display.
x_min, x_max = float(X.min()), float(X.max())
X -= x_min
if x_max > x_min:  # a constant volume would otherwise divide by zero -> all-NaN image
    X /= x_max - x_min


def _load_mask(fname):
    """Load one segmentation from base_path as a boolean mask (voxel value > 0.5).

    Keeping masks boolean instead of float64 cuts memory per structure by 8x;
    0.5 is the same threshold show_slice applies when drawing overlays.

    Raises ValueError if the mask's shape differs from the CT volume X, since the
    viewer indexes both arrays with the same slice indices.
    """
    mask = np.asanyarray(nib.load(os.path.join(base_path, fname)).dataobj) > 0.5
    if mask.shape != X.shape:
        raise ValueError(
            f"Segmentation {fname!r} has shape {mask.shape}, "
            f"expected {X.shape} to match the CT volume"
        )
    return mask


segments = {name: _load_mask(fname) for name, fname in seg_files.items()}

# Voxel size along each array axis (0, 1, 2); used for the display aspect ratios.
spacing = ct_img.header.get_zooms()[:3]

# RGBA colour per segment, channels in [0, 1]. When drawing, show_slice keeps
# only the RGB part (for the overlay and the legend swatch); the alpha stored
# here is replaced by the opacity slider value.
overlay_colors = dict(
    head=(0.2, 0.4, 1.0, 0.25),             # blue
    mandible=(1.0, 1.0, 0.0, 0.35),         # yellow
    skull=(1.0, 0.5, 0.0, 0.30),            # orange
    sinus_maxillary=(0.0, 1.0, 1.0, 0.35),  # cyan
    sinus_frontal=(0.0, 0.7, 0.0, 0.35),    # green
    teeth_lower=(1.0, 0.0, 0.0, 0.35),      # red
    teeth_upper=(0.4, 0.0, 0.8, 0.35),      # purple
)

# (row_axis, col_axis) of the displayed 2-D slice for each plane, as array axes
# of the 3-D volume. Pixel aspect (height / width) = spacing[row_axis] / spacing[col_axis].
_PLANE_DISPLAY_AXES = {
    "Axial": (1, 0),
    "Coronal": (2, 0),
    "Sagittal": (2, 1),
}


def _plane_slice(volume, plane, slice_index):
    """Return slice `slice_index` of `volume` in `plane`, oriented for imshow(origin='lower').

    Axial slices along array axis 2, Coronal along axis 1, Sagittal along axis 0.
    Raises ValueError for any other plane name.
    """
    if plane == "Axial":
        return np.rot90(volume[:, :, slice_index])
    if plane == "Coronal":
        return np.flipud(np.rot90(volume[:, slice_index, :]))
    if plane == "Sagittal":
        return np.flipud(np.rot90(volume[slice_index, :, :]))
    raise ValueError(
        f"Unknown plane {plane!r}; expected one of {sorted(_PLANE_DISPLAY_AXES)}"
    )


def show_slice(slice_index, plane, alpha=0.35):
    """Render one CT slice with every segmentation overlaid, and display it.

    Parameters
    ----------
    slice_index : int
        Index along the array axis orthogonal to `plane`
        (axis 2 for 'Axial', 1 for 'Coronal', 0 for 'Sagittal').
    plane : str
        'Axial', 'Coronal' or 'Sagittal'.
    alpha : float
        Opacity applied to every overlay; replaces the alpha stored in
        `overlay_colors`.

    Raises
    ------
    ValueError
        If `plane` is not one of the three supported planes.

    Reads the notebook globals X, segments, spacing and overlay_colors.
    """
    # Validate the plane before creating a figure, so a bad call leaks nothing.
    img_slice = _plane_slice(X, plane, slice_index)
    row_axis, col_axis = _PLANE_DISPLAY_AXES[plane]
    aspect = spacing[row_axis] / spacing[col_axis]

    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        ax.imshow(img_slice, cmap="gray", origin="lower", aspect=aspect)

        # One RGBA layer per structure, drawn in `segments` order (later on top).
        for name, volume in segments.items():
            mask = _plane_slice(volume, plane, slice_index) > 0.5
            if not mask.any():
                continue  # structure absent from this slice: skip the empty layer
            overlay = np.zeros((*mask.shape, 4))
            overlay[mask] = (*overlay_colors[name][:3], alpha)
            ax.imshow(overlay, origin="lower", aspect=aspect)

        ax.set_title(f"{plane} slice {slice_index}")
        ax.set_axis_off()

        legend_handles = [
            plt.Line2D([0], [0], color=rgba[:3], lw=4, label=name)
            for name, rgba in overlay_colors.items()
        ]
        legend = ax.legend(handles=legend_handles, loc="upper right", fontsize=8,
                           frameon=False, title="Segments", labelcolor="white")
        legend.get_title().set_color("white")

        # Replace the previously shown frame with this one.
        clear_output(wait=True)
        display(fig)
    finally:
        # Always close, even on error, so figures don't accumulate across slider moves.
        plt.close(fig)

# Viewer controls. The initial slice-slider range covers array axis 2 (the
# 'Axial' plane); update_slider_range re-ranges it whenever the plane changes.
plane_dropdown = widgets.Dropdown(
    description="Plane:",
    options=["Axial", "Coronal", "Sagittal"],
    value="Axial",
)

# Slice index along the axis orthogonal to the selected plane; starts centred.
slice_slider = widgets.IntSlider(
    description="Slice:",
    value=X.shape[2] // 2,
    min=0,
    max=X.shape[2] - 1,
    step=1,
    continuous_update=True,
    layout=widgets.Layout(width="80%"),
)

# Overlay opacity passed to show_slice as `alpha` (0 = invisible, 1 = opaque).
alpha_slider = widgets.FloatSlider(
    description="Opacity:",
    value=0.35,
    min=0.0,
    max=1.0,
    step=0.05,
    continuous_update=True,
    layout=widgets.Layout(width="60%"),
)

def update_slider_range(*args):
    """Re-range slice_slider for the plane currently selected in plane_dropdown.

    Sets the slider maximum to the last index along the array axis that the
    plane slices through (2 for Axial, 1 for Coronal, 0 for Sagittal) and
    centres the slider on that axis. Any other plane value leaves the slider
    untouched. `*args` absorbs the change record passed by `observe`.
    """
    axis = {"Axial": 2, "Coronal": 1, "Sagittal": 0}.get(plane_dropdown.value)
    if axis is None:
        return
    n_slices = X.shape[axis]
    # Set max before value so the new midpoint is always within range.
    slice_slider.max = n_slices - 1
    slice_slider.value = n_slices // 2

# Keep the slice slider's range in sync with the selected plane, and apply it
# once now so the slider matches the dropdown's initial value.
plane_dropdown.observe(update_slider_range, names='value')
update_slider_range()

# Build the interactive viewer. The trailing ';' suppresses the echo of the
# function object returned by interact (a stray `<function show_slice ...>` output).
widgets.interact(show_slice,
                 slice_index=slice_slider,
                 plane=plane_dropdown,
                 alpha=alpha_slider);


interactive(children=(IntSlider(value=256, description='Slice:', layout=Layout(width='80%'), max=511), Dropdow…

<function __main__.show_slice(slice_index, plane, alpha=0.35)>