In [12]:
import nibabel as nib
import numpy as np
from skimage.morphology import skeletonize
from scipy.ndimage import gaussian_filter
import plotly.graph_objects as go


In [13]:

def load_nifti(file_path):
    """Load the image stored at ``file_path`` and return its voxel data.

    Args:
        file_path: Path handed to ``nib.load``.

    Returns:
        The array produced by calling ``get_fdata()`` on the loaded image.
    """
    image = nib.load(file_path)
    voxel_data = image.get_fdata()
    return voxel_data


In [14]:

def compute_skeleton(binary_volume):
    """Skeletonize a binary volume.

    Thin wrapper around ``skimage.morphology.skeletonize``.

    Args:
        binary_volume: Mask passed unchanged to ``skeletonize``.

    Returns:
        Whatever ``skeletonize`` returns for ``binary_volume``.
    """
    skeleton = skeletonize(binary_volume)
    return skeleton


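# Math of Computing Normals

For a scalar field $f(x, y, z)$, the normal direction at each voxel is taken to be the normalized gradient, $\hat{n} = \nabla f / \lVert \nabla f \rVert$, with $\hat{n} = 0$ wherever $\nabla f = 0$.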


In [15]:

def compute_normals(skeleton):
    """Return the unit-length gradient direction of a 3D scalar field.

    The gradient is estimated with ``np.gradient`` along axes 0, 1 and 2 and
    each gradient vector is divided by its Euclidean norm. Voxels whose
    gradient is exactly zero get the zero vector instead of NaN.

    Args:
        skeleton: 3D array-like. Floating-point input is used as-is; boolean
            and integer input (e.g. a raw skeleton mask) is cast to float64
            first, because ``np.gradient`` cannot subtract boolean arrays.

    Returns:
        Array of shape ``skeleton.shape + (3,)`` whose last axis holds the
        (x, y, z) components of the normalized gradient.

    Raises:
        ValueError: Propagated from ``np.gradient`` when an axis is too short
            to take a finite difference (fewer than two elements).
    """
    field = np.asarray(skeleton)
    if not np.issubdtype(field.dtype, np.inexact):
        # Boolean masks would raise TypeError inside np.gradient; integers are
        # cast for consistency (np.gradient would do the same internally).
        field = field.astype(np.float64)

    # One call yields the per-axis derivatives in (x, y, z) order.
    gradients = np.gradient(field, axis=(0, 1, 2))

    normals = np.stack(gradients, axis=-1)  # shape: (x, y, z, 3)
    norm = np.linalg.norm(normals, axis=-1, keepdims=True)  # Euclidean length per voxel

    # Divide only where the length is non-zero; everything else stays at the
    # zeros supplied through ``out``.
    normals = np.divide(normals, norm, out=np.zeros_like(normals), where=norm != 0)

    return normals


In [16]:

def extract_skeleton_points(skeleton):
    """Return the indices of all non-zero voxels, one row per voxel.

    Args:
        skeleton: Array whose non-zero (truthy) entries mark skeleton voxels.

    Returns:
        Integer array of shape ``(n_points, skeleton.ndim)``; each row is the
        index of one non-zero entry.
    """
    per_axis_indices = np.nonzero(skeleton)
    # Transpose from one-array-per-axis to one-row-per-point.
    return np.transpose(per_axis_indices)


In [17]:

def visualize_results(volume, skeleton, normals, output_file, normal_length=5):
    """Plot skeleton voxels and their normal vectors in 3D, save to HTML, and show.

    Args:
        volume: Not used by this function; kept so existing callers that pass
            the source volume keep working.
        skeleton: 3D array whose non-zero voxels are drawn as blue markers.
        normals: Array indexed as ``normals[x, y, z]`` giving a 3-vector per
            voxel (e.g. the output of ``compute_normals``).
        output_file: Destination path passed to ``fig.write_html``.
        normal_length: Factor each normal vector is multiplied by before
            drawing. Defaults to 5, the previously hard-coded value.
    """
    skeleton_points = extract_skeleton_points(skeleton)

    # Create a 3D scatter plot for the skeleton
    scatter = go.Scatter3d(
        x=skeleton_points[:, 0],
        y=skeleton_points[:, 1],
        z=skeleton_points[:, 2],
        mode='markers',
        marker=dict(size=2, color='blue'),
        name='Skeleton'
    )

    # Build every normal segment at once instead of looping per point.
    starts = skeleton_points.astype(float)
    directions = normals[
        skeleton_points[:, 0], skeleton_points[:, 1], skeleton_points[:, 2]
    ]
    ends = starts + directions * normal_length
    # A NaN after each (start, end) pair breaks the line so consecutive
    # segments are not joined (plotly treats NaN as a gap).
    gaps = np.full_like(starts, np.nan)
    segments = np.stack([starts, ends, gaps], axis=1)  # (n_points, 3 stages, 3 axes)

    normal_lines = go.Scatter3d(
        x=segments[:, :, 0].ravel(),
        y=segments[:, :, 1].ravel(),
        z=segments[:, :, 2].ravel(),
        mode='lines',
        line=dict(color='red', width=1),
        name='Normals'
    )

    # Create the layout
    layout = go.Layout(
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z'
        ),
        title='Coronary Artery Skeleton with Normals',
        width=1200,
        height=800,
        legend=dict(
            x=0.7,
            y=0.9,
            bgcolor='rgba(255, 255, 255, 0.5)',
            bordercolor='rgba(0, 0, 0, 0.1)',
            borderwidth=1
        )
    )

    # Combine all traces
    fig = go.Figure(data=[scatter, normal_lines], layout=layout)

    # Save as HTML
    fig.write_html(output_file)

    # Show the plot
    fig.show()


In [19]:

def main(nifti_file_path, output_file):
    """Run the full pipeline: load, skeletonize, compute normals, visualize.

    Args:
        nifti_file_path: Path of the NIfTI file to load.
        output_file: Path the HTML visualization is written to.
    """
    image_data = load_nifti(nifti_file_path)

    # Every strictly positive voxel counts as foreground
    # (assumes the segmentation is binary).
    foreground_mask = image_data > 0
    centerline = compute_skeleton(foreground_mask)

    # Blur the skeleton before differentiating it.
    smoothed_centerline = gaussian_filter(centerline.astype(float), sigma=1)
    direction_field = compute_normals(smoothed_centerline)

    # Render the figure and write it out as HTML.
    visualize_results(image_data, centerline, direction_field, output_file)

if __name__ == "__main__":
    # Input segmentation and output report; both paths are relative to the
    # current working directory.
    nifti_file_path = "NIFTY/Img_001.nii.gz"
    output_file = "normals_web.html"
    main(nifti_file_path=nifti_file_path, output_file=output_file)