In [5]:
import dash
from dash import dcc, html, Input, Output
from dash.dependencies import Input, Output
import plotly.graph_objects as go
import numpy as np
import nibabel as nib
from scipy.ndimage import binary_dilation, generate_binary_structure
from sklearn.decomposition import PCA
from skimage.measure import marching_cubes
import os
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from scipy.spatial import cKDTree
from collections import defaultdict

# Load the NIfTI file
folder_number = 'C0001'
file_path = os.path.join(r"D:\Utrecht Studying\Team challenge\Aneurysm_TC_data", folder_number, "corrected_segmentation.nii.gz")
nii_file = nib.load(file_path)
seg_data = nii_file.get_fdata()

x_max, y_max, z_max = seg_data.shape

def normalize_coords(x, y, z):
    return x / x_max, y / y_max, z / z_max


# Get the coordinates of voxels with values 1 and 2
x1, y1, z1 = np.where(seg_data == 1)  # Cerebral vessels
x2, y2, z2 = np.where(seg_data == 2)  # Aneurysms

# Downsampling function
def downsample_coordinates(x, y, z, factor=2):
    indices = np.arange(0, len(x), factor)
    return x[indices], y[indices], z[indices]

# Downsample coordinates
x1, y1, z1 = downsample_coordinates(x1, y1, z1, factor=1)
x2, y2, z2 = downsample_coordinates(x2, y2, z2, factor=1)

# Perform additional dilation on the connection region to make the connection surface thicker
def find_thick_connection_region(seg_data, iterations=1):
    # Create masks for vessels and aneurysms
    vessels_mask = (seg_data == 1)
    aneurysms_mask = (seg_data == 2)
    
    # Basic dilation to get the initial connection region
    structure = generate_binary_structure(3, 1)  # 3D connectivity structure
    dilated_vessels = binary_dilation(vessels_mask, structure=structure)
    dilated_aneurysms = binary_dilation(aneurysms_mask, structure=structure)
    connection_region = np.logical_and(dilated_vessels, dilated_aneurysms)
    

    x, y, z = np.where(connection_region)
    return x, y, z

# Use the further dilated connection region
x, y, z = find_thick_connection_region(seg_data, iterations=1)

# Calculate the normal vector of the connected region
def calculate_normal_vector(x, y, z):
    # Combine the coordinates into a set of points
    points = np.vstack((x, y, z)).T
    
    # Use PCA to fit a plane
    pca = PCA(n_components=3)
    pca.fit(points)
    
    # The normal vector is the eigenvector corresponding to the smallest eigenvalue of PCA
    normal_vector = pca.components_[2]
    return normal_vector

# Calculate the normal vector and center point
if len(x) > 0:
    normal_vector = calculate_normal_vector(x, y, z)
    center_x, center_y, center_z = np.mean(x), np.mean(y), np.mean(z)
else:
    normal_vector = np.array([0, 0, 1])  # Default normal vector
    center_x, center_y, center_z = 0, 0, 0  # Default center point

# Generate mesh for cerebral vessels and aneurysms
def generate_mesh(seg_data, label):
    mask = (seg_data == label)
    if np.any(mask):
        verts, faces, _, _ = marching_cubes(mask, level=0.5)
        return verts, faces
    return None, None

# Generate meshes for vessels (label 1) and aneurysms (label 2)
verts1, faces1 = generate_mesh(seg_data, 1)  # Cerebral vessels
verts2, faces2 = generate_mesh(seg_data, 2)  # Aneurysms

# Generate mesh for the (thickened) connected region
def generate_connection_mesh(x, y, z):
    if len(x) > 0:
        # Create a binary mask for the connected region
        connection_mask = np.zeros_like(seg_data, dtype=bool)
        connection_mask[x, y, z] = True
        
        # Generate mesh using marching cubes
        verts, faces, _, _ = marching_cubes(connection_mask, level=0.5)
        return verts, faces
    return None, None

# Generate mesh for the connected region
verts_conn, faces_conn = generate_connection_mesh(x, y, z)

# Calculate the two basis vectors of the orthogonal plane
def calculate_orthogonal_plane(normal_vector):
    # Choose a vector that is not parallel to the normal vector
    if normal_vector[0] != 0 or normal_vector[1] != 0:
        base_vector = np.array([0, 0, 1])
    else:
        base_vector = np.array([1, 0, 0])
    
    # Calculate the first orthogonal vector
    u = np.cross(normal_vector, base_vector)
    u /= np.linalg.norm(u)
    
    # Calculate the second orthogonal vector
    v = np.cross(normal_vector, u)
    v /= np.linalg.norm(v)
    
    return u, v

# Calculate the two basis vectors of the orthogonal plane
u, v = calculate_orthogonal_plane(normal_vector)

# Generate points on the plane
plane_size = 50 
plane_points = np.array([
    [center_x + u[0] * i + v[0] * j, 
     center_y + u[1] * i + v[1] * j, 
     center_z + u[2] * i + v[2] * j]
    for i in range(-plane_size, plane_size + 1, 10)
    for j in range(-plane_size, plane_size + 1, 10)
])

# Generate plane indices for Mesh3d
def generate_plane_indices(plane_size):
    indices = []
    for i in range(plane_size * 2):
        for j in range(plane_size * 2):
            indices.append([i, i + 1, j])
            indices.append([i + 1, j, j + 1])
    return np.array(indices)

plane_indices = generate_plane_indices(plane_size)

# Verify that the plane is orthogonal to the normal vector
def verify_orthogonality(plane_points, normal_vector):
    # Calculate the vector between two points on the plane
    vector_on_plane = plane_points[1] - plane_points[0]
    
    # The dot product should be zero if the plane is orthogonal to the normal vector
    dot_product = np.dot(vector_on_plane, normal_vector)
    return dot_product

# Verify orthogonality
dot_product = verify_orthogonality(plane_points, normal_vector)
print(f"Dot product between plane vector and normal vector: {dot_product}")

# Initialize Dash application
app = dash.Dash(__name__)

# Layout
app.layout = html.Div([

    # 3D Visualization
    html.Div([
        html.H3("3D and 2D Projection of Cerebral Vessels and Aneurysms"),
        # Add padding/margin to create space between the plot and angle values
        html.Div(id='camera-angles', style={'padding': '10px', 'font-size': '20px', 'textAlign': 'center', 'margin-bottom': '20px', 'color': 'black'}),
        dcc.Graph(
            id='3d-plot',
            figure=go.Figure(),
            config={'scrollZoom': True, 'displayModeBar': True},
            style={'height': '100vh'}
        )
    ]),

    # 2D Projection with Overlap
    html.Div([
        html.H3("2D Projection View (With Overlap)"),
        dcc.Graph(
            id='2d-plot-overlap',
            figure=go.Figure(),
            config={'displayModeBar': False},
            style={'height': '100vh'}
        )
    ]),

    # 2D Projection without Overlap
    html.Div([
        html.H3("2D Projection View (No Overlap)"),
        dcc.Graph(
            id='2d-plot-no-overlap',
            figure=go.Figure(),
            config={'displayModeBar': False},
            style={'height': '100vh'}
        )
    ]),

    # Angle Adjustment
    html.Div([
        html.Label("Adjust Rotation Angle (0-360°):"),
        dcc.Slider(
            id='angle-slider',
            min=0,
            max=360,
            step=1,
            value=0,
            marks={i: f"{i}°" for i in range(0, 361, 40)},
            tooltip={"placement": "bottom", "always_visible": True}
        )
    ]),

    # Toggle Structures
    html.Div([
        html.Label("Toggle Structures:"),
        dcc.Checklist(
            id='structure-toggle',
            options=[
                {'label': 'Cerebral Vessels', 'value': 'vessels'},
                {'label': 'Aneurysms', 'value': 'aneurysms'},
                {'label': 'Connection Region', 'value': 'connection'},
                {'label': 'Orthogonal Plane', 'value': 'plane'}
            ],
            value=['vessels', 'aneurysms', 'connection', 'plane'],
            inline=True
        )
    ]),
])


# Calculate camera view
def calculate_camera_view(normal_vector, angle, zoom_factor=1.5):
    angle_rad = np.deg2rad(angle)
    normal_vector = normal_vector / np.linalg.norm(normal_vector)
    u, v = calculate_orthogonal_plane(normal_vector) # Make sure u, v are orthogonal basis of the plane

    # Calculate the position of the eye (strictly on the plane)
    radius = 512 # Distance from the camera to the center
    plane_center=np.array([center_x, center_y, center_z])
    eye_pos = plane_center + radius * (np.cos(angle_rad) * u + np.sin(angle_rad) * v)

    # Project to the plane to eliminate numerical errors 
    projection_distance = np.dot(eye_pos - np.array([center_x, center_y, center_z]), normal_vector) 
    print(f"Eye position distance from plane: {projection_distance}") 
    eye_pos = eye_pos - projection_distance * (normal_vector) 

    if projection_distance> 1e-6: 
        eye_pos = plane_center + (eye_pos - plane_center) - np.dot(eye_pos - plane_center, normal_vector) * normal_vector 

    projection_distance = np.dot(eye_pos - np.array([center_x, center_y, center_z]), normal_vector) 
    print(f"Eye position distance from plane(corrected): {projection_distance}") 
    # Set camera parameters 
    camera_eye = dict(x=float(eye_pos[0]), y=float(eye_pos[1]), z=float(eye_pos[2])) 
    camera_center = dict(x=float(center_x), y=float(center_y), z=float(center_z)) 
    camera_up = dict(x=float(normal_vector[0]), y=float(normal_vector[1]), z=float(normal_vector[2])) 

    return camera_eye, camera_center, camera_up

# Project 3D points to 2D
def project_to_2d(x, y, z, camera_eye, camera_center, camera_up):
    camera_eye = np.array([camera_eye['x'], camera_eye['y'], camera_eye['z']])
    camera_center = np.array([camera_center['x'], camera_center['y'], camera_center['z']])
    camera_up = np.array([camera_up['x'], camera_up['y'], camera_up['z']])

    # Calculate the view direction (from eye to center)
    view_dir = camera_eye - camera_center
    view_dir = view_dir / np.linalg.norm(view_dir)

    # Calculate the right vector (orthogonal to camera_up and view_dir)
    right = np.cross(camera_up, view_dir)
    right = right / np.linalg.norm(right)

    # Recalculate the up vector to ensure orthogonality
    up = np.cross(view_dir, right)
    up = up / np.linalg.norm(up)

    # Convert to camera coordinate system
    points = np.column_stack([x, y, z])
    translated_points = points - camera_eye

    # Calculate the projected coordinates
    projected_x = np.dot(translated_points, right)
    projected_y = np.dot(translated_points, up)
    depth = np.dot(translated_points, view_dir) # Keep depth information

    # Calculate the projected coordinates of the center point (should be close to (0,0))
    center_proj_x = np.dot(camera_center - camera_eye, right)
    center_proj_y = np.dot(camera_center - camera_eye, up)

    # Translate all points so that the center point is (0,0)
    projected_x -= center_proj_x
    projected_y -= center_proj_y

    # Handle NaN values
    projected_x = np.nan_to_num(projected_x, nan=0)
    projected_y = np.nan_to_num(projected_y, nan=0)
    depth = np.nan_to_num(depth, nan=0)

    return projected_x, projected_y, depth # Return depth information

def calculate_azimuth_elevation(camera_eye, camera_center):
    # Calculate vector from camera to center
    direction = np.array([camera_center['x'] - camera_eye['x'],
                            camera_center['y'] - camera_eye['y'],
                            camera_center['z'] - camera_eye['z']])
    direction = direction / np.linalg.norm(direction)
    
    azimuth = np.arctan2(direction[1], direction[0]) * 180 / np.pi
    elevation = np.arctan2(direction[2], np.sqrt(direction[0]**2 + direction[1]**2)) * 180 / np.pi
    return azimuth, elevation

# Update the camera angles based on user interaction
@app.callback(
    dash.dependencies.Output('camera-angles', 'children'),
    [dash.dependencies.Input('3d-plot', 'relayoutData')]
)
def update_camera_angles(relayoutData):
    # Ensure relayoutData is not None and contains camera data
    if relayoutData is not None and 'scene.camera' in relayoutData:
        camera = relayoutData['scene.camera']
        if camera:
            eye = camera['eye']
            # Calculate azimuth and elevation
            azimuth = np.arctan2(eye['y'], eye['x']) * 180 / np.pi  # In degrees
            elevation = np.arctan2(eye['z'], np.sqrt(eye['x']**2 + eye['y']**2)) * 180 / np.pi
            return f"Azimuth: {azimuth:.2f}°, Elevation: {elevation:.2f}°", azimuth, elevation
    # Return a default value if relayoutData is None or doesn't contain camera information
    return "Azimuth: 0°, Elevation: 0°", 0, 0

@app.callback(
    [Output('3d-plot', 'figure'), 
     Output('2d-plot-overlap', 'figure'), 
     Output('2d-plot-no-overlap', 'figure')],
    [Input('angle-slider', 'value'), 
     Input('structure-toggle', 'value')]
)
def update_plots(angle, visible_structures):
    fig_3d = go.Figure()
    fig_2d_overlap = go.Figure()
    fig_2d_no_overlap = go.Figure()

    # Add cerebral vessels
    if 'vessels' in visible_structures and verts1 is not None and faces1 is not None:
        fig_3d.add_trace(go.Mesh3d(
            x=verts1[:, 0], y=verts1[:, 1], z=verts1[:, 2],
            i=faces1[:, 0], j=faces1[:, 1], k=faces1[:, 2],
            color='red',
            opacity=0.5,
            name="Cerebral Vessels (Label 1)"
        ))

    # Add aneurysms
    if 'aneurysms' in visible_structures and verts2 is not None and faces2 is not None:
        fig_3d.add_trace(go.Mesh3d(
            x=verts2[:, 0], y=verts2[:, 1], z=verts2[:, 2],
            i=faces2[:, 0], j=faces2[:, 1], k=faces2[:, 2],
            color='blue',
            opacity=0.5,
            name="Aneurysms (Label 2)"
        ))

    # Add connection region
    if 'connection' in visible_structures and verts_conn is not None and faces_conn is not None:
        fig_3d.add_trace(go.Mesh3d(
            x=verts_conn[:, 0], y=verts_conn[:, 1], z=verts_conn[:, 2],
            i=faces_conn[:, 0], j=faces_conn[:, 1], k=faces_conn[:, 2],
            color='green',
            opacity=0.5,
            name="Connection Region"
        ))

    # Add orthogonal plane
    if 'plane' in visible_structures:
        fig_3d.add_trace(go.Mesh3d(
            x=plane_points[:, 0], y=plane_points[:, 1], z=plane_points[:, 2],
            i=plane_indices[:, 0], j=plane_indices[:, 1], k=plane_indices[:, 2],
            color='yellow',
            opacity=0.5,
            name="Orthogonal Plane"
        ))

    # Add normal vector
    if len(x) > 0:
        start_point = np.array([center_x, center_y, center_z])
        end_point = start_point + normal_vector * 50
        fig_3d.add_trace(go.Scatter3d(
            x=[start_point[0], end_point[0]],
            y=[start_point[1], end_point[1]],
            z=[start_point[2], end_point[2]],
            mode='lines',
            line=dict(color='purple', width=3),
            name="Normal Vector"
        ))

    # Calculate camera view - now returns camera_center as well
    camera_eye, camera_center, camera_up = calculate_camera_view(normal_vector, angle, zoom_factor=2)
    azimuth, elevation = calculate_azimuth_elevation(camera_eye, camera_center)

    # Project points to 2D - now using camera_center
    projected_x1, projected_y1, depth1 = project_to_2d(x1, y1, z1, camera_eye, camera_center, camera_up)
    projected_x2, projected_y2, depth2 = project_to_2d(x2, y2, z2, camera_eye, camera_center, camera_up)

    if verts_conn is not None:
        projected_x_conn, projected_y_conn, depth_conn = project_to_2d(verts_conn[:, 0], verts_conn[:, 1], verts_conn[:, 2], camera_eye, camera_center, camera_up)
    else:
        projected_x_conn, projected_y_conn, depth_conn = np.array([]), np.array([]), np.array([])

    # Combine all points with their labels and depths
    all_points = np.concatenate([
        np.column_stack((projected_x1, projected_y1, depth1, np.zeros_like(depth1))),  # Label 0: Cerebral vessels
        np.column_stack((projected_x2, projected_y2, depth2, np.ones_like(depth2))),   # Label 1: Aneurysms
        np.column_stack((projected_x_conn, projected_y_conn, depth_conn, 2 * np.ones_like(depth_conn)))  # Label 2: Connection region
    ])
    sorted_indices = np.argsort(all_points[:, 2])  # Sort by depth (descending order)
    sorted_points = all_points[sorted_indices]
    
    # Use KDTree for faster overlap detection
    tree = cKDTree(sorted_points[:, :2])
    overlapping_indices_vessels_aneurysms = set()  # For vessels + aneurysms overlap (pink)
    overlapping_indices_vessels_connection = set()  # For vessels + connection overlap (gold)

    for i, point in enumerate(sorted_points[:, :2]):
        neighbors = tree.query_ball_point(point, r=0.3)  # Adjust radius as needed
        if len(neighbors) > 1:
            # Get labels of neighboring points
            labels = sorted_points[neighbors, 3]
            unique_labels = np.unique(labels)
            
            # Check for vessels + aneurysms overlap
            if 0 in unique_labels and 1 in unique_labels:
                overlapping_indices_vessels_aneurysms.update(neighbors)
            
            # Check for vessels + connection overlap
            if 0 in unique_labels and 2 in unique_labels:
                overlapping_indices_vessels_connection.update(neighbors)

    # Assign colors
    color_map = {0: 'red', 1: 'blue', 2: 'green'}
    colors = np.array([color_map[label] for label in sorted_points[:, 3]])

    # Mark overlapping points
    if len(overlapping_indices_vessels_aneurysms) > 0:
        colors[list(overlapping_indices_vessels_aneurysms)] = 'pink'  # Vessels + aneurysms overlap
    if len(overlapping_indices_vessels_connection) > 0:
        colors[list(overlapping_indices_vessels_connection)] = 'gold'  # Vessels + connection overlap

    # Add 2D projection with overlap after sorting
    fig_2d_overlap.add_trace(go.Scatter(
        x=sorted_points[:, 0],
        y=sorted_points[:, 1],
        mode='markers',
        marker=dict(size=6, color=colors),
        name='2D Projection (With Overlap)'
    ))

    # Add legend traces for the overlap plot
    fig_2d_overlap.add_trace(go.Scatter(
        x=[None], y=[None],
        mode='markers',
        marker=dict(size=2, color='red'),
        name="Cerebral Vessels"
    ))
    fig_2d_overlap.add_trace(go.Scatter(
        x=[None], y=[None],
        mode='markers',
        marker=dict(size=2, color='blue'),
        name="Aneurysms"
    ))
    fig_2d_overlap.add_trace(go.Scatter(
        x=[None], y=[None],
        mode='markers',
        marker=dict(size=2, color='green'),
        name="Connection Region"
    ))
    fig_2d_overlap.add_trace(go.Scatter(
        x=[None], y=[None],
        mode='markers',
        marker=dict(size=2, color='pink'),
        name="Overlap (Vessels + Aneurysms)"
    ))
    fig_2d_overlap.add_trace(go.Scatter(
        x=[None], y=[None],
        mode='markers',
        marker=dict(size=2, color='gold'),
        name="Overlap (Vessels + Connection)"
    ))
    
    # Set aspect ratio to equal for the overlap plot
    fig_2d_overlap.update_layout(
        xaxis=dict(scaleanchor="y"),  # This ensures that the x and y axes have the same scale
        yaxis=dict(scaleanchor="x"),
    )

    # Assign colors without overlap (no pink color)
    colors_sorted = np.array([color_map[label] for label in sorted_points[:, 3]])

    # Add 2D projection without overlap
    fig_2d_no_overlap.add_trace(go.Scatter(
        x=sorted_points[:, 0],
        y=sorted_points[:, 1],
        mode='markers',
        marker=dict(size=6, color=colors_sorted),
        name='2D Projection (No Overlap, Sorted by Depth)'
    ))

    # Add legend traces for the no-overlap plot
    fig_2d_no_overlap.add_trace(go.Scatter(
        x=[None], y=[None],
        mode='markers',
        marker=dict(size=2, color='red'),
        name="Cerebral Vessels"
    ))
    fig_2d_no_overlap.add_trace(go.Scatter(
        x=[None], y=[None],
        mode='markers',
        marker=dict(size=2, color='blue'),
        name="Aneurysms"
    ))
    fig_2d_no_overlap.add_trace(go.Scatter(
        x=[None], y=[None],
        mode='markers',
        marker=dict(size=2, color='green'),
        name="Connection Region"
    ))

    # Set aspect ratio to equal for the no-overlap plot
    fig_2d_no_overlap.update_layout(
        xaxis=dict(scaleanchor="y"),  # This ensures that the x and y axes have the same scale
        yaxis=dict(scaleanchor="x"),
    )

    # Update 3D layout with proper camera settings
    fig_3d.update_layout(
        scene=dict(
            xaxis=dict(title='X'),
            yaxis=dict(title='Y'),
            zaxis=dict(title='Z'),
            camera=dict(

                eye=dict( 
                x=camera_eye['x']/x_max, 
                y=camera_eye['y']/y_max, 
                z=camera_eye['z']/z_max 
                ),
                center=dict( 
                x=camera_center['x']/x_max, 
                y=camera_center['y']/y_max, 
                z=camera_center['z']/x_max 
                ),
                up=dict(
                    x=camera_up['x'],
                    y=camera_up['y'],
                    z=camera_up['z']
                )
            )
        ),
        title="3D Projection"
    )

    return fig_3d, fig_2d_overlap, fig_2d_no_overlap


# Run the app 
if __name__ == '__main__':
    app.run_server(debug=True, port=8052)

Dot product between plane vector and normal vector: -5.773159728050814e-15


Eye position distance from plane: 2.842170943040401e-14
Eye position distance from plane(corrected): 0.0
