In [None]:
import os

import nibabel as nib
import numpy as np
import pyvista as pv
from scipy.ndimage import binary_erosion
from sklearn.decomposition import PCA

# -------------------------------------------------------------
# 1) Load segmentation mask
# -------------------------------------------------------------
def load_mask(path):
    """Read a NIfTI segmentation from *path*.

    Returns
    -------
    tuple
        ``(data, affine, image)``: the voxel array produced by
        ``get_fdata()`` (floating point), the image's voxel-to-world affine,
        and the loaded nibabel image object itself (kept so callers can reuse
        its header when saving derived masks).
    """
    nifti = nib.load(path)
    voxels = nifti.get_fdata()
    return voxels, nifti.affine, nifti


# -------------------------------------------------------------
# 2) Compute skull mid-sagittal plane using PCA
# -------------------------------------------------------------
def compute_midline(mask, affine, lr_axis=(1.0, 0.0, 0.0)):
    """Estimate the skull mid-sagittal plane (MSP) with PCA.

    The plane passes through the centroid of the foreground voxels, expressed
    in world millimetres. Its normal is the principal axis (an eigenvector of
    the point covariance) that is best aligned with ``lr_axis``.

    PCA orders axes by variance only. A skull is usually longest
    front-to-back or top-to-bottom, so the *first* principal component is
    typically not the left-right direction. Choosing by alignment with the
    left-right axis removes that dependence on the skull's proportions.

    Parameters
    ----------
    mask : array_like
        3D segmentation; voxels > 0 are foreground.
    affine : array_like, shape (4, 4)
        Voxel-to-world affine (for nibabel images, world space is RAS+ mm).
    lr_axis : sequence of 3 floats, optional
        World direction regarded as left-right. The default ``(1, 0, 0)`` is
        +x, i.e. towards patient Right in nibabel's RAS+ convention.

    Returns
    -------
    centroid : ndarray, shape (3,)
        Foreground centroid in world mm; a point on the plane.
    normal : ndarray, shape (3,)
        Unit plane normal, sign-fixed so that ``dot(normal, lr_axis) > 0``.
        Eigenvector signs are otherwise arbitrary.

    Raises
    ------
    ValueError
        If fewer than 3 foreground voxels are present (no meaningful PCA).
    """
    coords = np.argwhere(np.asarray(mask) > 0)
    if coords.shape[0] < 3:
        raise ValueError(
            f"compute_midline needs at least 3 foreground voxels, got {coords.shape[0]}"
        )

    # Voxel indices -> world mm (same maths as nibabel.affines.apply_affine).
    affine = np.asarray(affine, dtype=float)
    coords_mm = coords @ affine[:3, :3].T + affine[:3, 3]

    centroid = coords_mm.mean(axis=0)

    # PCA == eigen-decomposition of the covariance matrix; eigh returns the
    # principal axes as orthonormal columns.
    _, axes = np.linalg.eigh(np.cov(coords_mm, rowvar=False))

    ref = np.asarray(lr_axis, dtype=float)
    ref = ref / np.linalg.norm(ref)
    alignment = axes.T @ ref
    best = int(np.argmax(np.abs(alignment)))

    normal = axes[:, best]
    if alignment[best] < 0:
        normal = -normal
    return centroid, normal


# -------------------------------------------------------------
# 3) Visualise skull + midline
# -------------------------------------------------------------
def visualize_midline(mask, affine, centroid, normal, downsample=True):
    """Show the skull surface and the mid-sagittal plane in one 3D view.

    Parameters
    ----------
    mask : array_like
        3D skull segmentation; voxels > 0 are treated as bone.
    affine : array_like, shape (4, 4)
        Voxel-to-world affine of ``mask``. The surface is mapped through it
        so that it lives in the same world-millimetre frame as ``centroid``.
    centroid : array_like, shape (3,)
        A point on the plane, in world millimetres.
    normal : array_like, shape (3,)
        Plane normal in world coordinates.
    downsample : bool, optional
        If True (default), keep every second voxel along each axis before
        meshing to lighten the mesh. The affine is scaled to match, so the
        surface stays in place.

    Notes
    -----
    Opens an interactive PyVista window and returns nothing.
    """
    step = 2 if downsample else 1

    # Binarise, then stride. Striding (rather than erosion) reduces the voxel
    # count without eating away a thin bony shell.
    skull = (np.asarray(mask) > 0)[::step, ::step, ::step].astype(np.float32)

    # Strided index i corresponds to original index step * i.
    vox_to_world = np.asarray(affine, dtype=float) @ np.diag([step, step, step, 1.0])

    # pv.wrap turns a 3D array into an image grid carrying the values as point
    # scalars. contour() at 0.5 traces the 0/1 boundary of the mask, whereas
    # extract_surface() would only return the outer faces of the whole grid.
    grid = pv.wrap(skull)
    mesh = grid.contour(isosurfaces=[0.5])
    if mesh.n_points:
        mesh = mesh.smooth(n_iter=20)
        # The grid is in voxel indices (unit spacing, zero origin); move the
        # surface into world millimetres so it lines up with the plane.
        mesh.points = nib.affines.apply_affine(vox_to_world, mesh.points)

    # Make plane (200 x 200 world units around the centroid)
    plane = pv.Plane(center=centroid,
                     direction=normal,
                     i_size=200,
                     j_size=200)

    # Plot; an empty surface is skipped because PyVista refuses empty meshes.
    pl = pv.Plotter()
    if mesh.n_points:
        pl.add_mesh(mesh, color="white", opacity=0.3)
    pl.add_mesh(plane, color="red", opacity=0.6)
    pl.show()


# -------------------------------------------------------------
# 4) Split mandible using midline plane
# -------------------------------------------------------------
def split_mandible(mandible_mask, affine, centroid, normal):
    """
    Split a mandible mask into left and right halves using a plane.

    Each foreground voxel centre is mapped to world millimetres with
    ``affine`` and classified by its signed distance
    ``d = dot(p - centroid, normal)`` to the plane.

    mandible_mask : binary mask of mandible (voxels > 0 are foreground)
    affine        : voxel-to-world affine from nibabel (world space is RAS+ mm)
    centroid      : a point on the mid-sagittal plane, in world mm
    normal        : plane normal in world coordinates

    If needed, the normal is first flipped to point towards world +x (patient
    Right in RAS+). Hence d >= 0 is the right half and d < 0 is the left half,
    whatever sign the caller's normal had. PCA eigenvector signs are
    arbitrary, so without this step the two labels could be swapped between
    patients. Voxels exactly on the plane go to the right half. A normal with
    zero x-component is used as given.

    Returns
    -------
    left_mask, right_mask : arrays with the same shape and dtype as
        ``mandible_mask``, 1 inside the respective half and 0 elsewhere.
    """
    affine = np.asarray(affine, dtype=float)
    centroid = np.asarray(centroid, dtype=float)
    normal = np.asarray(normal, dtype=float)
    if normal[0] < 0:
        normal = -normal

    # Foreground voxel indices -> world mm (same maths as apply_affine)
    coords = np.argwhere(np.asarray(mandible_mask) > 0)
    coords_mm = coords @ affine[:3, :3].T + affine[:3, 3]

    # Signed distance of each voxel to the plane
    d = (coords_mm - centroid) @ normal

    left_mask = np.zeros_like(mandible_mask)
    right_mask = np.zeros_like(mandible_mask)
    left_mask[tuple(coords[d < 0].T)] = 1
    right_mask[tuple(coords[d >= 0].T)] = 1

    return left_mask, right_mask


# -------------------------------------------------------------
# 5) Save NIfTI mask
# -------------------------------------------------------------
def save_mask(mask, reference_img, out_path):
    """Write *mask* as a uint8 NIfTI on the grid of *reference_img*.

    Parameters
    ----------
    mask : array_like
        Mask data; cast to uint8 before writing.
    reference_img : nibabel image
        Supplies the affine and header (geometry) for the output.
    out_path : str or path-like
        Destination file. Missing parent directories are created.

    When a header is passed to ``Nifti1Image``, nibabel keeps that header's
    on-disk data dtype, so the dtype is set to uint8 explicitly. Without
    this, a float reference header would store the mask as floats.
    """
    out_dir = os.path.dirname(os.fspath(out_path))
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)  # nib.save does not create folders

    nii = nib.Nifti1Image(np.asarray(mask).astype(np.uint8),
                          reference_img.affine,
                          reference_img.header)
    nii.set_data_dtype(np.uint8)
    nib.save(nii, out_path)
    print("Saved:", out_path)


In [None]:
# -------------------------------------------------------------
# USER INPUTS
# -------------------------------------------------------------
# Case folder for this patient/series. Every input and output path is derived
# from it, so processing another case means editing this one line.
case_dir = r"Z:\FacialDeformation_MPhys\rhabdo_data_proton\DICOMS\abby\UIDQQ0x7axQ0Q1"
skull_path = rf"{case_dir}\Total_Segmentator\skull.nii.gz"
mandible_path = rf"{case_dir}\Total_Segmentator\mandible.nii.gz"
out_dir = rf"{case_dir}\asymmetry"
left_out_path = rf"{out_dir}\mandible_left.nii.gz"
right_out_path = rf"{out_dir}\mandible_right.nii.gz"

# -------------------------------------------------------------
# LOAD MASKS
# -------------------------------------------------------------
skull_mask, skull_affine, skull_img = load_mask(skull_path)
mandible_mask, mandible_affine, mandible_img = load_mask(mandible_path)

# -------------------------------------------------------------
# COMPUTE MIDLINE FROM SKULL
# -------------------------------------------------------------
centroid, normal = compute_midline(skull_mask, skull_affine)

print("Midline centroid (mm):", centroid)
print("Midline normal:", normal)

# -------------------------------------------------------------
# VISUALISE MIDLINE + SKULL
# -------------------------------------------------------------
visualize_midline(skull_mask, skull_affine, centroid, normal)

# -------------------------------------------------------------
# SPLIT MANDIBLE USING MIDLINE
# -------------------------------------------------------------
# Both masks share the same world (mm) frame, so the skull plane can be
# applied to the mandible through the mandible's own affine.
left_mask, right_mask = split_mandible(mandible_mask, mandible_affine, centroid, normal)

# Quick sanity check: a badly placed plane shows up as a lopsided split.
print("Mandible voxels left / right:", int(left_mask.sum()), "/", int(right_mask.sum()))

# -------------------------------------------------------------
# SAVE RESULTS
# -------------------------------------------------------------
os.makedirs(out_dir, exist_ok=True)  # nib.save does not create missing folders
save_mask(left_mask, mandible_img, left_out_path)
save_mask(right_mask, mandible_img, right_out_path)
