# In [None]:
from pathlib import Path

import numpy as np
import nibabel as nib
from sklearn.decomposition import PCA
import open3d as o3d
from scipy.ndimage import binary_erosion

# ------------------ Helper funcs ------------------

def load_mask(path):
    """
    Load a mask image with nibabel.

    Parameters
    ----------
    path : str or path-like
        Image file readable by ``nib.load`` (e.g. a ``.nii.gz``).

    Returns
    -------
    mask : ndarray of bool
        True wherever the stored (scaled) voxel value is non-zero.
    affine : (4, 4) ndarray
        The image's voxel-to-world affine.
    """
    img = nib.load(path)
    # Same truth values as ``img.get_fdata().astype(bool)``, but skips
    # get_fdata's conversion of the whole volume to float64, which for an
    # integer-typed mask allocates several times the memory of the mask.
    mask = np.asanyarray(img.dataobj) != 0
    return mask, img.affine

def mask_to_pointcloud(mask, affine, step=2):
    """
    Extract the iso-surface of a binary mask as world-space vertices.

    Runs marching cubes at level 0.5 on the mask and maps the resulting
    vertex coordinates through ``affine``.

    Parameters
    ----------
    mask : 3D array
        Non-zero / True voxels are foreground.
    affine : (4, 4) array
        Voxel-to-world matrix whose voxel axes match the array axes of
        ``mask`` (as with nibabel's ``img.affine`` for ``img``'s data array).
    step : int
        Marching-cubes step size in voxels; larger is coarser and faster.

    Returns
    -------
    world : (V, 3) ndarray
        Surface vertex coordinates in world space.
    faces : (F, 3) ndarray
        Triangle vertex indices into ``world``.

    Raises
    ------
    ValueError
        If the mask has no foreground voxels.
    """
    from skimage import measure

    if not np.any(mask):
        raise ValueError("mask is empty; there is no surface to extract")

    verts, faces, _, _ = measure.marching_cubes(
        mask.astype(np.uint8), level=0.5, step_size=step
    )
    # marching_cubes returns vertices in the same axis order as the input
    # array, which is exactly the (i, j, k) voxel order the affine maps, so
    # the vertices go through the affine without any axis reordering.
    affine = np.asarray(affine, dtype=float)
    world = verts @ affine[:3, :3].T + affine[:3, 3]
    return world, faces

def reflect_points(points, normal, d):
    """
    Mirror points across the plane ``normal . x + d = 0``.

    ``normal`` need not be unit length: the plane is taken exactly as
    written, so ``normal`` and ``d`` are rescaled together.

    Parameters
    ----------
    points : (N, 3) or (3,) array-like
        Points to reflect.
    normal : (3,) array-like
        Plane normal; must be non-zero.
    d : float
        Plane offset.

    Returns
    -------
    ndarray
        Reflected points, same shape as ``points`` (float).

    Raises
    ------
    ValueError
        If ``normal`` has zero length.
    """
    points = np.asarray(points, dtype=float)
    normal = np.asarray(normal, dtype=float)
    length = np.linalg.norm(normal)
    if length == 0:
        raise ValueError("normal must be non-zero")
    unit = normal / length
    # Signed Euclidean distance to the plane. d is divided by |normal| as
    # well; otherwise normalising the normal would silently move the plane.
    signed = np.asarray(points @ unit + d / length)
    return points - 2.0 * signed[..., np.newaxis] * unit

# ------------------ Midline estimation: PCA + ICP ------------------

def _initial_plane_pca(pts, lr_axis=None):
    """
    Return ``(point, unit_normal)`` of a first-guess symmetry plane from PCA.

    The plane passes through the point-cloud mean. Its normal is the
    principal axis with the smallest variance (a heuristic: the left-right
    extent is often the narrowest, but that depends on the skull shape), or,
    when ``lr_axis`` is given, the principal axis most aligned with it,
    oriented to point along ``lr_axis``.
    """
    pca = PCA(n_components=3)
    pca.fit(pts)
    if lr_axis is None:
        idx = int(np.argmin(pca.explained_variance_))
        normal = pca.components_[idx]
    else:
        axis = np.asarray(lr_axis, dtype=float)
        idx = int(np.argmax(np.abs(pca.components_ @ axis)))
        normal = pca.components_[idx]
        if normal @ axis < 0:
            normal = -normal
    return pca.mean_, normal / np.linalg.norm(normal)


def _plane_from_mirror_pairs(src, mirrored, normal_hint):
    """
    Return ``(point, unit_normal)`` of the plane bisecting each pair ``(src[i], mirrored[i])``.

    For an exact reflection every displacement ``mirrored[i] - src[i]`` is
    parallel to the plane normal and every midpoint lies on the plane; with
    an imperfect fit the summed displacement and the mean midpoint give an
    averaged estimate. The normal is oriented to agree with ``normal_hint``,
    which is also the fallback if all displacements cancel out.
    """
    midpoints = 0.5 * (src + mirrored)
    displacement = (mirrored - src).sum(axis=0)
    length = np.linalg.norm(displacement)
    normal = normal_hint if length == 0 else displacement / length
    if normal @ normal_hint < 0:
        normal = -normal
    return midpoints.mean(axis=0), normal


def estimate_midline_pca_icp(skull_mask, affine, icp_iter=50, icp_threshold=5.0, lr_axis=None):
    """
    Estimate the mid-sagittal (left-right symmetry) plane of a skull.

    Steps:
      1) Extract a surface point cloud from the mask in world coordinates.
      2) Use PCA to get an initial plane through the cloud's mean.
      3) Split the points by that plane and mirror one side across it.
      4) Run ICP to align the mirrored side onto the other side. "Mirror, then
         apply the ICP transform" maps each point on the mirrored side to its
         symmetric counterpart. The refined plane is the perpendicular
         bisector of those point/counterpart pairs.

    Parameters
    ----------
    skull_mask : 3D array
        Skull mask; non-zero / True voxels are skull.
    affine : (4, 4) array
        Voxel-to-world affine of the mask.
    icp_iter : int
        Maximum ICP iterations.
    icp_threshold : float
        ICP maximum correspondence distance, in the world units of
        ``affine`` (usually mm).
    lr_axis : (3,) array-like or None
        Optional world-space left-right direction used to pick the initial
        PCA normal instead of the smallest-variance heuristic. For example,
        use ``(1, 0, 0)`` for RAS+ world coordinates.

    Returns
    -------
    centroid : (3,) ndarray
        A point on the refined plane (mean of the pair midpoints).
    normal : (3,) ndarray
        Unit normal of the refined plane. It points to the same side as the
        initial PCA normal, i.e. along ``lr_axis`` when that is given.
    reg : open3d RegistrationResult
        The ICP result (``reg.transformation`` is the 4x4 rigid transform).

    Raises
    ------
    ValueError
        If the initial plane leaves one side of the point cloud empty.
    """
    # 1) Skull surface as a world-space point cloud
    pts, _ = mask_to_pointcloud(skull_mask, affine, step=2)

    # 2) Initial plane n . x + d = 0 from PCA
    centroid_init, normal_init = _initial_plane_pca(pts, lr_axis)
    d_init = -np.dot(normal_init, centroid_init)

    # 3) Split by the initial plane and mirror the non-positive side
    signed = pts @ normal_init + d_init
    left_pts = pts[signed > 0]
    right_pts = pts[signed <= 0]
    if len(left_pts) == 0 or len(right_pts) == 0:
        raise ValueError(
            "initial PCA plane does not split the skull point cloud into two sides"
        )
    right_reflected = reflect_points(right_pts, normal_init, d_init)

    # 4) Rigidly align the mirrored side onto the other side
    pcd_left = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(left_pts))
    pcd_ref = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(right_reflected))
    reg = o3d.pipelines.registration.registration_icp(
        pcd_ref, pcd_left, icp_threshold, np.eye(4),
        o3d.pipelines.registration.TransformationEstimationPointToPoint(),
        o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=icp_iter),
    )
    # transform() moves every point and keeps their order, so row i of
    # right_ref_aligned is the mirrored-and-aligned image of right_pts[i].
    right_ref_aligned = np.asarray(pcd_ref.transform(reg.transformation).points)

    # 5) Refined plane = bisector of each right point and its counterpart.
    # (Averaging left_pts with right_ref_aligned would give a point on one
    # half of the skull, not on the midline.)
    centroid_new, normal_refined = _plane_from_mirror_pairs(
        right_pts, right_ref_aligned, normal_init
    )
    return centroid_new, normal_refined, reg

# ------------------ Generate midline points for visualisation ------------------

def make_midline_line(centroid, normal, length=300, n_points=200, direction=None):
    """
    Sample evenly spaced points on a straight line lying in a plane.

    Used to visualise the mid-sagittal plane. The line is centred on
    ``centroid`` and runs along a single in-plane direction.

    Parameters
    ----------
    centroid : (3,) array-like
        Point on the plane; it becomes the midpoint of the line.
    normal : (3,) array-like
        Plane normal; need not be unit length.
    length : float
        Total end-to-end length of the line, in world units. The line extends
        ``length / 2`` on each side of ``centroid``.
    n_points : int
        Number of samples, including both ends.
    direction : (3,) array-like or None
        Optional vector whose projection onto the plane sets the line
        direction. If None, the coordinate axis least aligned with ``normal``
        is projected instead. That keeps the projection well-conditioned even
        when the normal is almost parallel to a coordinate axis.

    Returns
    -------
    (n_points, 3) ndarray
        Points on the line.

    Raises
    ------
    ValueError
        If ``normal`` is zero or ``direction`` is (nearly) parallel to it.
    """
    centroid = np.asarray(centroid, dtype=float)
    unit_normal = np.asarray(normal, dtype=float)
    normal_len = np.linalg.norm(unit_normal)
    if normal_len == 0:
        raise ValueError("normal must be non-zero")
    unit_normal = unit_normal / normal_len

    if direction is None:
        ref = np.zeros(3)
        ref[np.argmin(np.abs(unit_normal))] = 1.0
    else:
        ref = np.asarray(direction, dtype=float)

    # Remove the normal component so the line lies in the plane
    in_plane = ref - np.dot(ref, unit_normal) * unit_normal
    in_plane_len = np.linalg.norm(in_plane)
    if in_plane_len <= 1e-9 * np.linalg.norm(ref):
        raise ValueError("direction must not be parallel to the plane normal")
    line_dir = in_plane / in_plane_len

    ts = np.linspace(-length / 2, length / 2, n_points)
    return centroid + ts[:, np.newaxis] * line_dir

# ------------------ Example usage ------------------

if __name__ == "__main__":
    # Every input/output for this case lives in one directory; change it here.
    case_dir = Path(r"Z:\FacialDeformation_MPhys\rhabdo_data_proton\DICOMS\abby\UIDQQ0x7axQ0Q1\asymmetry")

    # Load skull mask
    skull_mask, skull_affine = load_mask(case_dir / "skull.nii.gz")

    # Estimate refined midline
    centroid, normal, icp_reg = estimate_midline_pca_icp(skull_mask, skull_affine)

    print("Refined midline centroid:", centroid)
    print("Refined normal:", normal)
    print("ICP transformation:", icp_reg.transformation)

    # Make a visualization line and save it as CSV with an "x,y,z" header row
    midline_pts = make_midline_line(centroid, normal, length=300, n_points=200)
    np.savetxt(case_dir / "midline_pts.csv", midline_pts, delimiter=",", header="x,y,z", comments='')
