In [1]:
import numpy as np
from scipy import interpolate
import plotly.graph_objects as go
from scipy.integrate import cumtrapz

def generate_sphere_point_cloud(center, radius, point_density):
    """Sample points uniformly inside a solid ball.

    The number of points is ``int(ball_volume * point_density)``. Directions
    are drawn with a uniform azimuth and a uniform cos(polar angle), and the
    radial distance is ``radius * cbrt(U)``, so points are uniform in volume.

    Returns an (n, 3) array of points offset by ``center``.
    """
    ball_volume = 4/3 * np.pi * radius**3
    n_samples = int(ball_volume * point_density)

    # Draw order (azimuth, cos-polar, radial) fixes the global RNG stream.
    azimuth = np.random.uniform(0, 2*np.pi, n_samples)
    polar = np.arccos(np.random.uniform(-1, 1, n_samples))
    radial = radius * np.cbrt(np.random.uniform(0, 1, n_samples))

    sin_polar = np.sin(polar)
    offsets = np.stack(
        [radial * sin_polar * np.cos(azimuth),
         radial * sin_polar * np.sin(azimuth),
         radial * np.cos(polar)],
        axis=1,
    )
    return offsets + center

def generate_random_control_points(num_points, scale):
    """Draw ``num_points`` random 3-D points from an isotropic Gaussian.

    Each coordinate is standard normal multiplied by ``scale``; the result
    has shape (num_points, 3).
    """
    return scale * np.random.standard_normal((num_points, 3))

def _random_perpendicular_frame(direction):
    """Return (tangent, normal, binormal) for a segment direction.

    ``tangent`` is ``direction`` normalized. ``normal`` is a random unit
    vector perpendicular to it, and ``binormal = tangent x normal``.
    Consumes one ``np.random.randn(3)`` draw from the global NumPy RNG.
    """
    tangent = direction / np.linalg.norm(direction)
    normal = np.random.randn(3)
    normal = normal - normal.dot(tangent) * tangent
    normal = normal / np.linalg.norm(normal)
    binormal = np.cross(tangent, normal)
    return tangent, normal, binormal


def _nearest_skeleton_indices(skeleton, points, start_index, end_index):
    """Return, for each row of ``points``, the index of its nearest skeleton point.

    The search is restricted to ``skeleton[start_index:end_index]``; ties go
    to the first index (``np.argmin``). Returns an integer array of length
    ``len(points)``.
    """
    window = skeleton[start_index:end_index]
    distances = np.linalg.norm(points[:, np.newaxis, :] - window[np.newaxis, :, :], axis=2)
    return start_index + np.argmin(distances, axis=1)


def generate_curvilinear_structure_with_spheres(max_points=10000):
    """Generate a random tube around a spline skeleton, plus spheres attached to its side.

    A cubic spline through 5-9 random control points is sampled at 500
    parameter values to form the skeleton. Points are scattered uniformly
    inside a tube of random radius around it. Spheres centred on random
    skeleton points are then shifted along a random local normal by
    (tube radius + sphere radius). The point density is chosen so that the
    tube receives roughly ``max_points / 2`` points; spheres use the same
    density. All randomness comes from the global ``np.random`` state.

    Args:
    - max_points (int): controls the point density (see above).

    Returns:
    - skeleton (numpy array): (500, 3) skeleton points.
    - combined_point_cloud (numpy array): (K, 3) tube points followed by sphere points.
    - combined_indices (numpy array): (K,) float part labels. 0 marks tube
      points and k marks points of the k-th sphere.
    - combined_skeleton_indices (numpy array): (K,) int index of the skeleton
      point each point is associated with.
    """
    num_control_points = np.random.randint(5, 10)
    curve_scale = np.random.uniform(1, 3)
    slice_radius = np.random.uniform(0.3, 0.7)

    control_points = generate_random_control_points(num_control_points, curve_scale)

    degree = 3
    t = np.linspace(0, 1, 500)
    tck, _ = interpolate.splprep(control_points.T, k=degree, s=0)
    skeleton = np.array(interpolate.splev(t, tck)).T

    curve_length = np.sum(np.linalg.norm(np.diff(skeleton, axis=0), axis=1))
    curve_volume = np.pi * slice_radius**2 * curve_length
    # Half of max_points is budgeted for the tube; spheres reuse this density.
    point_density = max_points / (2 * curve_volume)

    main_points = []
    skeleton_indices = []
    for i in range(len(skeleton) - 1):
        segment_length = np.linalg.norm(skeleton[i+1] - skeleton[i])
        # int() floors per segment, so the tube gets somewhat fewer than max_points / 2 points.
        num_points = int(segment_length * np.pi * slice_radius**2 * point_density)

        tangent, normal, binormal = _random_perpendicular_frame(skeleton[i+1] - skeleton[i])

        for _ in range(num_points):
            # sqrt(U) makes the radial distribution uniform over the disk cross-section.
            r = slice_radius * np.sqrt(np.random.uniform(0, 1))
            theta = np.random.uniform(0, 2*np.pi)
            point = skeleton[i] + segment_length * np.random.uniform(0, 1) * tangent + r * (np.cos(theta) * normal + np.sin(theta) * binormal)
            main_points.append(point)
            skeleton_indices.append(i)

    # Explicit shape/dtype: an empty tube (small max_points) must still vstack
    # with the (k, 3) spheres and give integer skeleton indices.
    main_points = np.array(main_points, dtype=float).reshape(-1, 3)
    skeleton_indices = np.array(skeleton_indices, dtype=int)

    # NOTE(review): randint's upper bound is exclusive, so this is always 3 --
    # confirm whether 3-4 spheres were intended.
    num_spheres = np.random.randint(3, 4)
    avg_segment_length = curve_length / (len(skeleton) - 1)
    additional_spheres = []
    sphere_skeleton_indices = []
    for _ in range(num_spheres):
        P_index = np.random.randint(0, len(skeleton))
        P = skeleton[P_index]

        # Forward difference, or backward difference at the last skeleton point.
        if P_index == len(skeleton) - 1:
            direction = skeleton[P_index] - skeleton[P_index-1]
        else:
            direction = skeleton[P_index+1] - skeleton[P_index]
        _, normal, _ = _random_perpendicular_frame(direction)

        sphere_radius = slice_radius * np.random.uniform(1.0, 2.0)

        sphere = generate_sphere_point_cloud(P, sphere_radius, point_density)

        # Search window: about one sphere diameter of skeleton around P_index,
        # at least 2 and at most len(skeleton) - 1 points wide.
        segment_range = int(np.ceil(2 * sphere_radius / avg_segment_length))
        segment_range = max(2, min(segment_range, len(skeleton) - 1))

        start_index = max(0, P_index - segment_range // 2)
        end_index = min(len(skeleton), P_index + segment_range // 2)

        # Indices are assigned on the unshifted sphere (centred on P); the
        # shift below is along the local normal.
        sphere_indices = _nearest_skeleton_indices(skeleton, sphere, start_index, end_index)

        # Offset by tube radius + sphere radius so the sphere touches the tube from outside.
        sphere_shifted = sphere + normal * (slice_radius + sphere_radius)

        additional_spheres.append(sphere_shifted)
        sphere_skeleton_indices.append(sphere_indices)

    combined_point_cloud = np.vstack([main_points] + additional_spheres)
    combined_indices = np.concatenate([np.zeros(len(main_points))] + [np.full(len(sphere), i+1) for i, sphere in enumerate(additional_spheres)])
    combined_skeleton_indices = np.concatenate([skeleton_indices] + sphere_skeleton_indices)

    return skeleton, combined_point_cloud, combined_indices, combined_skeleton_indices

# Generate one sample structure (tube around a random spline + attached spheres) and report its sizes.
skeleton, combined_point_cloud, combined_indices, combined_skeleton_indices = generate_curvilinear_structure_with_spheres(max_points=100000)

# Number of distinct part labels present (label 0 = tube, 1..k = spheres).
num_parts = len(np.unique(combined_indices))

print(f"Skeleton shape: {skeleton.shape}")
print(f"Combined point cloud shape: {combined_point_cloud.shape}")
print(f"Combined indices shape: {combined_indices.shape}")
print(f"Combined skeleton indices shape: {combined_skeleton_indices.shape}")
print(f"Number of additional spheres: {num_parts - 1}")
for sphere_label in range(1, num_parts):
    print(f"Number of points in additional sphere {sphere_label}: {np.sum(combined_indices == sphere_label)}")
print(f"Total number of points: {len(combined_point_cloud)}")

Skeleton shape: (500, 3)
Combined point cloud shape: (64262, 3)
Combined indices shape: (64262,)
Combined skeleton indices shape: (64262,)
Number of additional spheres: 3
Number of points in additional sphere 1: 4808
Number of points in additional sphere 2: 1666
Number of points in additional sphere 3: 8023
Total number of points: 64262


In [None]:
# Visualize the generated structure: tube points (label 0) semi-transparent, sphere points opaque red.

my_data = []
pc = combined_point_cloud
lb = combined_indices
skel_id = combined_skeleton_indices

# Split the cloud by part label: 0 = tube around the skeleton, non-zero = attached spheres.
indx_in = np.flatnonzero(lb == 0)
indx_an = np.flatnonzero(lb != 0)
pc_in = pc[indx_in]
pc_an = pc[indx_an]

# Skeleton and a short slice of it; neither is plotted below.
cl = skeleton
cl_t = skeleton[190:195]

x, y, z = pc_in.T
data = go.Scatter3d(x=x, y=y, z=z, mode='markers',
                    marker=dict(size=1.5, opacity=0.5, color='#FD836E'))
my_data.append(data)

x, y, z = pc_an.T
data = go.Scatter3d(x=x, y=y, z=z, mode='markers',
                    marker=dict(size=1.5, opacity=1, color='red'))
my_data.append(data)

fig = go.Figure(data=my_data)

fig.update_layout(margin=dict(l=0, r=0, b=0, t=0), template="plotly_white")
# Hide grid, tick labels and titles on all three axes.
fig.update_scenes(
    **{f'{axis}axis': dict(showgrid=False, showticklabels=False, title=dict(text=''))
       for axis in 'xyz'},
    camera=dict(eye=dict(x=-1, y=-1.25, z=1)),
)
fig.update_shapes(visible=False)

In [1]:
import numpy as np
from scipy import interpolate
import h5py

import numpy as np
import plotly.graph_objects as go
# from NeAR.visualization import plot_vox
import nibabel as nib
from scipy.interpolate import interp1d,UnivariateSpline
import os

def smooth_3d_array(points, num=None, **kwargs):
    """Fit one UnivariateSpline per coordinate of a 3-D polyline and resample it.

    Row i of ``points`` is treated as parameter value i, so the splines are
    fitted on the parameter range [0, n-1].

    Args:
    - points (array-like): (n, 3) ordered points.
    - num (int, optional): number of output samples; defaults to n.
    - **kwargs: forwarded to ``scipy.interpolate.UnivariateSpline`` (e.g. ``s``, ``k``).

    Returns:
    - (num, 3) numpy array sampled uniformly over [0, n-1]. With ``num`` equal
      to n this lands exactly on the input parameter values.
    """
    points = np.asarray(points, dtype=float)
    n = len(points)
    # Resolve the default before allocating the output (np.zeros((None, 3)) raises).
    if num is None:
        num = n
    w = np.arange(n)
    # Stay inside the fitted range [0, n-1]; an upper bound of n would extrapolate past the last point.
    wnew = np.linspace(0, n - 1, num)
    smoothed = np.zeros((num, 3))
    for axis in range(3):
        smoothed[:, axis] = UnivariateSpline(w, points[:, axis], **kwargs)(wnew)
    return smoothed

def _perpendicular_unit(T, epsilon=1e-8):
    """Return a unit vector perpendicular to each row of ``T`` ((n, 3) tangents).

    Uses cross(T, e_z) = (T_y, -T_x, 0). For rows where T is (anti)parallel
    to the z-axis that vector vanishes, so cross(T, e_x) = (0, T_z, -T_y) is
    used instead. Rows where T itself is zero or NaN still come out NaN.
    """
    T = np.atleast_2d(T)
    perp = np.column_stack((T[:, 1], -T[:, 0], np.zeros(len(T))))
    along_z = np.linalg.norm(perp, axis=1) < epsilon
    if np.any(along_z):
        perp[along_z] = np.column_stack((np.zeros(np.count_nonzero(along_z)),
                                         T[along_z, 2], -T[along_z, 1]))
    return perp / np.linalg.norm(perp, axis=1)[:, np.newaxis]


def _orthonormalize_against(N, T, epsilon=1e-8):
    """Make each row of ``N`` a unit vector perpendicular to the matching row of ``T``.

    Rows whose component perpendicular to T (nearly) vanishes are replaced by
    ``_perpendicular_unit`` rather than normalized into NaN. This covers, for
    example, an interpolated normal that passes through zero or a normal
    parallel to T.
    """
    N = N - np.sum(N * T, axis=1)[:, np.newaxis] * T
    norms = np.linalg.norm(N, axis=1)
    degenerate = norms < epsilon
    out = np.empty_like(N, dtype=float)
    ok = ~degenerate
    out[ok] = N[ok] / norms[ok, np.newaxis]
    if np.any(degenerate):
        out[degenerate] = _perpendicular_unit(T[degenerate], epsilon)
    return out


def calculate_tnb_frame(curve, epsilon=1e-8):
    """Compute unit tangent / normal / binormal vectors at every point of a sampled 3-D curve.

    Derivatives are finite differences along the point index (``np.gradient``).
    Where the normal is undefined (near-zero dT, or near-zero tangent), it is
    filled in from neighbouring curved parts. Between two curved parts it is
    linearly interpolated; next to a single curved part it is copied from that
    side. If the whole curve is straight, an arbitrary perpendicular is used.

    Args:
    - curve (array-like): (n, 3) ordered curve points.
    - epsilon (float): magnitude below which a tangent/normal counts as zero.

    Returns:
    - T, N, B (numpy arrays): (n, 3) unit tangents, normals and binormals, with B along T x N.

    Prints which case applied: entire curve straight, some straight parts, or
    none. Points where ``np.gradient`` gives a zero tangent (e.g. repeated
    samples) yield NaN rows.
    """
    curve = np.asarray(curve)

    # Tangent from central differences (one-sided at the ends).
    T = np.gradient(curve, axis=0)
    T_norms = np.linalg.norm(T, axis=1)
    T = T / T_norms[:, np.newaxis]

    # Points whose raw tangent is (nearly) zero have no usable direction.
    is_straight = T_norms < epsilon

    # Normal: component of dT perpendicular to T (left unnormalized here).
    dT = np.gradient(T, axis=0)
    N = dT - np.sum(dT * T, axis=1)[:, np.newaxis] * T
    N_norms = np.linalg.norm(N, axis=1)

    undefined_N = (N_norms < epsilon) | is_straight

    if np.all(undefined_N):
        print("the entire curve is straight")
        N = _perpendicular_unit(T, epsilon)
    elif np.any(undefined_N):
        print("handling straight parts")
        # Split indices into maximal runs of equal undefined_N.
        segment_changes = np.where(np.diff(undefined_N))[0] + 1
        segments = np.split(np.arange(len(curve)), segment_changes)

        for segment in segments:
            if not undefined_N[segment[0]]:
                continue  # curved run: keep its computed normal
            left_curved = np.where(~undefined_N[:segment[0]])[0]
            right_curved = np.where(~undefined_N[segment[-1]+1:])[0] + segment[-1] + 1

            if len(left_curved) > 0 and len(right_curved) > 0:
                # Blend linearly from the last curved normal on the left to the first on the right.
                left_N = N[left_curved[-1]]
                right_N = N[right_curved[0]]
                t = np.linspace(0, 1, len(segment))
                N[segment] = (1-t[:, np.newaxis]) * left_N + t[:, np.newaxis] * right_N
            elif len(left_curved) > 0:
                N[segment] = N[left_curved[-1]]
            elif len(right_curved) > 0:
                N[segment] = N[right_curved[0]]
            else:
                N[segment] = _perpendicular_unit(T[segment[:1]], epsilon)[0]

            N[segment] = _orthonormalize_against(N[segment], T[segment], epsilon)
    else:
        print("no straight parts")

    # Final Gram-Schmidt: unit N perpendicular to T, then B perpendicular to both.
    N = _orthonormalize_against(N, T, epsilon)
    B = np.cross(T, N)
    B = B - np.sum(B * T, axis=1)[:, np.newaxis] * T
    B = B - np.sum(B * N, axis=1)[:, np.newaxis] * N
    B = B / np.linalg.norm(B, axis=1)[:, np.newaxis]

    return T, N, B

def straighten_using_frenet(helix, points, skel_idx=None):
    """
    Straighten the structure based on the helix (skeleton) using the Frenet frame.

    The skeleton is laid along the z-axis by cumulative arclength. Each
    surrounding point is re-expressed relative to its associated skeleton
    point, using that point's (T, N, B) frame from ``calculate_tnb_frame``.
    The azimuth theta is measured with the B component as x and the N
    component as y.

    Args:
    - helix (numpy array): (M, 3) points forming the helix (skeleton).
    - points (numpy array): (P, 3) points surrounding the helix.
    - skel_idx (array-like of int, optional): index into ``helix`` of the
      skeleton point each point belongs to. When given, points use a
      cylindrical mapping: x, y = r*cos(theta), r*sin(theta) and z = arclength.
      When omitted, each point is assigned to its nearest skeleton point and
      uses a spherical mapping that also keeps its offset along T in z.
      NOTE(review): the two branches therefore use different mappings --
      confirm this is intended.

    Returns:
    - straightened_helix (numpy array): (M, 3) skeleton on the z-axis.
    - straightened_points (numpy array): (P, 3) transformed surrounding points.
    """
    helix = np.asarray(helix, dtype=float)
    points = np.asarray(points, dtype=float).reshape(-1, 3)

    T, N, B = calculate_tnb_frame(helix)

    # Arclength parameter of each skeleton point.
    segment_lengths = np.linalg.norm(np.diff(helix, axis=0), axis=1)
    cumulative_distances = np.insert(np.cumsum(segment_lengths), 0, 0)

    straightened_helix = np.column_stack((np.zeros_like(cumulative_distances),
                                          np.zeros_like(cumulative_distances),
                                          cumulative_distances))

    if skel_idx is not None:
        closest_idx = np.asarray(skel_idx).astype(int)
    else:
        # Brute-force nearest skeleton point per point (first index on ties, as np.argmin).
        closest_idx = np.array([np.argmin(np.linalg.norm(helix - p, axis=1)) for p in points],
                               dtype=int)

    vectors = points - helix[closest_idx]
    r = np.linalg.norm(vectors, axis=1)
    theta = np.arctan2(np.sum(vectors * N[closest_idx], axis=1),
                       np.sum(vectors * B[closest_idx], axis=1))
    arclength = cumulative_distances[closest_idx]

    if skel_idx is not None:
        # Mapping 1 (cylindrical)
        x = r * np.cos(theta)
        y = r * np.sin(theta)
        z = arclength
    else:
        # Mapping 2 (spherical). cos(phi) = (v . T) / |v|. Guard |v| = 0 (point
        # on the skeleton) and clip rounding past +-1; otherwise arccos gives NaN.
        safe_r = np.where(r > 0, r, 1.0)
        cos_phi = np.clip(np.sum(vectors * T[closest_idx], axis=1) / safe_r, -1.0, 1.0)
        phi = np.arccos(cos_phi)
        x = r * np.sin(phi) * np.cos(theta)
        y = r * np.sin(phi) * np.sin(theta)
        z = arclength + r * np.cos(phi)

    return straightened_helix, np.column_stack((x, y, z))

def visualize_frenet_frame(points, T, N, B, step=5, extension_length=5, arrow_length=10):
    """
    Visualize the curve and its Frenet frame using Plotly with arrows at the end of the extended TNB vectors.

    Args:
    - points (array-like): sequence of 3D points representing the curve.
    - T, N, B (array-like): per-point tangent, normal, and binormal vectors.
    - step (int): draw the frame at every ``step``-th point.
    - extension_length (float): length of the line drawn along each frame vector.
    - arrow_length (float): scale of the cone direction vectors (u, v, w) at each line end.

    Shows the figure via ``fig.show()`` and returns None.
    """
    points = np.asarray(points, dtype=float)
    X, Y, Z = points[:, 0], points[:, 1], points[:, 2]

    fig = go.Figure()

    fig.add_trace(go.Scatter3d(x=X, y=Y, z=Z, mode='lines', name='Curve', line=dict(color='black')))

    # (vectors, colour, line legend name, cone legend name)
    frame_axes = (
        (T, 'red', 'Extended Tangent', 'Tangent'),
        (N, 'green', 'Extended Normal', 'Normal'),
        (B, 'blue', 'Extended Binormal', 'Binormal'),
    )

    for i in range(0, len(points), step):
        first = (i == 0)  # only the first sample contributes legend entries
        ends = [points[i] + vec[i] * extension_length for vec, _, _, _ in frame_axes]

        for (vec, color, line_name, _), end in zip(frame_axes, ends):
            fig.add_trace(go.Scatter3d(x=[X[i], end[0]], y=[Y[i], end[1]], z=[Z[i], end[2]],
                                       mode='lines', line=dict(color=color), showlegend=first,
                                       name=line_name if first else ""))

        for (vec, color, _, cone_name), end in zip(frame_axes, ends):
            fig.add_trace(go.Cone(x=[end[0]], y=[end[1]], z=[end[2]],
                                  u=[vec[i][0]*arrow_length], v=[vec[i][1]*arrow_length], w=[vec[i][2]*arrow_length],
                                  sizemode='scaled', sizeref=0.1, anchor='cm', showscale=False,
                                  colorscale=[[0, color], [1, color]],
                                  name=cone_name if first else "", showlegend=first))

    fig.update_layout(scene=dict(aspectmode='auto'))

    fig.show()

# Dataset generation. generate_dataset() below calls
# generate_curvilinear_structure_with_spheres() and its helpers
# (generate_random_control_points, generate_sphere_point_cloud), which are
# defined in the first cell of this notebook, not in this cell. Run those
# definitions in the same kernel session before calling generate_dataset().

def sample_points(points, labels, skeleton_indices, n_points=4096):
    """Draw exactly ``n_points`` rows from three aligned arrays using one shared index.

    Sampling is without replacement when there are at least ``n_points``
    rows. Otherwise it is with replacement, so short clouds are padded with
    duplicates. Uses the global ``np.random`` state.
    """
    need_replacement = len(points) < n_points
    chosen = np.random.choice(len(points), n_points, replace=need_replacement)
    return points[chosen], labels[chosen], skeleton_indices[chosen]

def generate_dataset(num_samples, n_points=4096):
    """Build ``num_samples`` synthetic samples, each resampled to ``n_points`` points.

    Each entry is the tuple (point_cloud, labels, skeleton_indices, skeleton,
    pc_trans, skel_trans). ``labels`` is 1 for sphere points and 0 for tube
    points; ``pc_trans`` / ``skel_trans`` come from ``straighten_using_frenet``
    applied with the per-point skeleton indices.
    """
    samples = []
    for _ in range(num_samples):
        skel, cloud, part_ids, skel_idx = generate_curvilinear_structure_with_spheres()

        # Any non-zero part id (a sphere) becomes label 1; the tube stays 0.
        binary_labels = (part_ids > 0).astype(int)
        cloud, binary_labels, skel_idx = sample_points(cloud, binary_labels, skel_idx, n_points)

        skel_straight, cloud_straight = straighten_using_frenet(skel, cloud, skel_idx)

        samples.append((cloud, binary_labels, skel_idx, skel,
                        np.array(cloud_straight), np.array(skel_straight)))
    return samples

def save_h5(filename, points, labels, skeleton_indices, skeleton, points_trans, skel_trans):
    """Write one sample to an HDF5 file, truncating any existing file.

    Dataset keys: 'pc', 'lb', 'idx_idx' (skeleton indices), 'skel',
    'pc_trans', 'skel_trans'.
    """
    datasets = (
        ('pc', points),
        ('lb', labels),
        ('idx_idx', skeleton_indices),
        ('skel', skeleton),
        ('pc_trans', points_trans),
        ('skel_trans', skel_trans),
    )
    with h5py.File(filename, 'w') as h5_file:
        for key, values in datasets:
            h5_file.create_dataset(key, data=values)