# In [None]:
import nibabel as nib
import numpy as np
from skimage import measure
import plotly.graph_objects as go

def load_masks(path_L, path_R):
    """Read the left and right mandible segmentations as boolean voxel masks.

    Parameters
    ----------
    path_L, path_R : str or path-like
        NIfTI files (e.g. ``.nii.gz``) holding the left and right masks.

    Returns
    -------
    tuple of ndarray
        ``(mandible_L, mandible_R)``: boolean arrays where every voxel whose
        (scaled) image value is non-zero becomes ``True``.
    """
    # Image headers are discarded, so all downstream geometry is expressed
    # in voxel indices rather than scanner/world coordinates.
    return tuple(
        nib.load(mask_path).get_fdata().astype(bool)
        for mask_path in (path_L, path_R)
    )

def fit_midplane(mandible_L, mandible_R):
    """Estimate the mid-sagittal plane separating the left and right mandible.

    The plane is the perpendicular bisector of the segment joining the
    centroids of the two masks. If the mandible were perfectly mirror-symmetric,
    reflecting across its symmetry plane would map the left half onto the right
    half, and hence one centroid onto the other. That bisector therefore *is*
    the symmetry plane.

    A least-squares (SVD) plane through all voxels is deliberately not used.
    Its normal follows the direction of least spread of the whole bone, which
    for a U-shaped mandible is generally not the left-right direction, and its
    sign is arbitrary.

    Using the midpoint of the two half-centroids, rather than the centroid of
    all voxels, also keeps the plane from being pulled toward whichever side
    happens to contain more voxels.

    Parameters
    ----------
    mandible_L, mandible_R : ndarray
        Voxel masks; non-zero entries mark the left / right structure.

    Returns
    -------
    centroid : ndarray of float, shape (ndim,)
        A point on the plane: the midpoint between the two mask centroids,
        in voxel-index coordinates.
    normal : ndarray of float, shape (ndim,)
        Unit normal oriented from the left centroid toward the right one, so
        the signed distance ``(p - centroid) . normal`` is positive on the
        right-mask side and negative on the left-mask side.

    Raises
    ------
    ValueError
        If either mask is empty, or if the two centroids coincide (the
        bisecting plane is then undefined).
    """
    coords_L = np.argwhere(mandible_L)
    coords_R = np.argwhere(mandible_R)
    if len(coords_L) == 0 or len(coords_R) == 0:
        raise ValueError("Both left and right masks must contain at least one voxel.")

    centroid_L = coords_L.mean(axis=0)
    centroid_R = coords_R.mean(axis=0)

    normal = centroid_R - centroid_L
    length = np.linalg.norm(normal)
    if length == 0:
        raise ValueError("Left and right mask centroids coincide; mid-plane is undefined.")
    normal /= length

    centroid = (centroid_L + centroid_R) / 2.0
    return centroid, normal

def create_mandible_mesh(mandible_L, mandible_R):
    """Extract a triangle surface of the whole mandible (left OR right mask).

    Parameters
    ----------
    mandible_L, mandible_R : ndarray
        Voxel masks of identical shape; non-zero entries mark the structure.

    Returns
    -------
    verts : ndarray, shape (V, 3)
        Surface vertices in the voxel-index coordinates of the input arrays.
    faces : ndarray, shape (F, 3)
        Triangles as indices into ``verts``.

    Raises
    ------
    ValueError
        If the union of the two masks contains no voxels.
    """
    combined = np.logical_or(mandible_L, mandible_R)
    if not combined.any():
        raise ValueError("Cannot mesh an empty mandible mask.")

    # Surround the volume with one layer of background. Without it, any part
    # of the mask touching the array boundary would leave a hole in the
    # surface: marching cubes only emits triangles between voxels that exist.
    padded = np.pad(combined, 1, mode="constant", constant_values=False)
    verts, faces, _, _ = measure.marching_cubes(padded.astype(np.float32), level=0.5)

    # Undo the padding offset so vertices line up with the original indices.
    verts -= 1
    return verts, faces

def create_plane_mesh(point, normal, size=100):
    """Build a square patch of the plane through ``point`` with normal ``normal``.

    Parameters
    ----------
    point : array-like, shape (3,)
        Centre of the square patch.
    normal : array-like, shape (3,)
        Plane normal. It need not be unit length but must be non-zero.
    size : float, optional
        Half the side length of the square. The corners lie at ``+/-size``
        along two orthonormal in-plane axes.

    Returns
    -------
    plane_pts : ndarray, shape (4, 3)
        Corner vertices, listed in order around the square's perimeter.
    plane_faces : ndarray, shape (2, 3)
        Two triangles (indices into ``plane_pts``) covering the square.

    Raises
    ------
    ValueError
        If ``normal`` has zero length.
    """
    point = np.asarray(point, dtype=float)
    normal = np.asarray(normal, dtype=float)
    length = np.linalg.norm(normal)
    if length == 0:
        raise ValueError("Plane normal must be non-zero.")
    # Normalize so that u (built from the normal below) also has unit length
    # and the patch is a true square.
    normal = normal / length

    # First in-plane axis: cross the normal with the z-axis. When the normal
    # is (anti)parallel to z that cross product vanishes, so use the y-axis
    # instead. The operand order makes a +z normal still give v = +x.
    if abs(normal[2]) > 1.0 - 1e-9:
        v = np.cross([0.0, 1.0, 0.0], normal)
    else:
        v = np.cross(normal, [0.0, 0.0, 1.0])
    v /= np.linalg.norm(v)
    u = np.cross(normal, v)  # unit length, since normal and v are orthonormal

    corners = np.array([[-size, -size],
                        [ size, -size],
                        [ size,  size],
                        [-size,  size]], dtype=float)
    plane_pts = point + corners[:, 0, None] * u + corners[:, 1, None] * v
    plane_faces = np.array([[0, 1, 2], [0, 2, 3]])
    return plane_pts, plane_faces

def classify_points(points, plane_point, plane_normal):
    """Report which side of a plane each point lies on.

    Parameters
    ----------
    points : ndarray, shape (N, 3) or (3,)
        Points to classify. A single 1-D point yields a scalar result.
    plane_point : ndarray, shape (3,)
        Any point on the plane.
    plane_normal : ndarray, shape (3,)
        Plane normal; only its direction affects the sign.

    Returns
    -------
    ndarray of shape (N,), or scalar
        +1 where a point lies on the side ``plane_normal`` points toward,
        -1 on the opposite side, and 0 exactly on the plane. These read as
        "right" (+1) / "left" (-1) only when the normal is oriented from the
        left structure toward the right one.
    """
    offsets = points - plane_point
    projection = np.dot(offsets, plane_normal)
    return np.sign(projection)

def plot_mandible_with_plane(verts, plane_pts, plane_faces, faces=None):
    """Show the mandible surface and the fitted mid-plane in one 3-D figure.

    Parameters
    ----------
    verts : ndarray, shape (V, 3)
        Mandible mesh vertices.
    plane_pts : ndarray, shape (4, 3)
        Corner vertices of the plane patch.
    plane_faces : ndarray, shape (F, 3)
        Triangles of the plane patch, as indices into ``plane_pts``.
    faces : ndarray, shape (F, 3), optional
        Triangles of the mandible mesh, as indices into ``verts``.
        Without them, Plotly must guess a triangulation from the bare points
        (its ``alphahull`` fallback), which does not reproduce the extracted
        surface. Pass them whenever they are available.
    """
    mandible_tri = {}
    if faces is not None:
        faces = np.asarray(faces)
        mandible_tri = dict(i=faces[:, 0], j=faces[:, 1], k=faces[:, 2])
    mesh = go.Mesh3d(
        x=verts[:, 0], y=verts[:, 1], z=verts[:, 2],
        color='orange', opacity=0.5,
        **mandible_tri,
    )

    # Always pass the plane's triangles explicitly. Plotly's default
    # triangulation projects along z, which degenerates for a plane whose
    # normal has no z component: its corners project onto a line.
    plane_faces = np.asarray(plane_faces)
    plane = go.Mesh3d(
        x=plane_pts[:, 0], y=plane_pts[:, 1], z=plane_pts[:, 2],
        i=plane_faces[:, 0], j=plane_faces[:, 1], k=plane_faces[:, 2],
        color='blue', opacity=0.3,
    )

    fig = go.Figure(data=[mesh, plane])
    fig.update_layout(scene=dict(
        xaxis_title='X', yaxis_title='Y', zaxis_title='Z',
        # Keep the axes at true relative scale so the plane's orientation
        # relative to the bone is not visually distorted.
        aspectmode='data',
    ))
    fig.show()


if __name__ == "__main__":
    path_L = r"Z:\FacialDeformation_MPhys\rhabdo_data_proton\DICOMS\abby\UIDQQ0x7axQ0Q1\asymmetry\Mandible_L.nii.gz"
    path_R = r"Z:\FacialDeformation_MPhys\rhabdo_data_proton\DICOMS\abby\UIDQQ0x7axQ0Q1\asymmetry\Mandible_R.nii.gz"

    # NOTE(review): all geometry below is in voxel indices, because the NIfTI
    # affine is not applied. With anisotropic voxels, distances and angles
    # are distorted; consider mapping through the image affine / spacing.
    mandible_L, mandible_R = load_masks(path_L, path_R)
    plane_point, plane_normal = fit_midplane(mandible_L, mandible_R)
    verts, faces = create_mandible_mesh(mandible_L, mandible_R)
    plane_pts, plane_faces = create_plane_mesh(plane_point, plane_normal)

    print("Plane point:", plane_point)
    print("Plane normal:", plane_normal)

    plot_mandible_with_plane(verts, plane_pts, plane_faces, faces=faces)

    # Arbitrary probe voxel. The +1/-1 = right/left reading assumes
    # plane_normal points from the left mask toward the right mask.
    test_point = np.array([50, 60, 30])
    side = classify_points(test_point, plane_point, plane_normal)
    print("Side (+1=right, -1=left):", side)
