<a href="https://colab.research.google.com/github/JessemanGray/PHYLLOTAXIS/blob/main/Phyllo_benchmarks_kdtree.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
from scipy.spatial import KDTree
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import time

# Generate points on a sphere using the Fibonacci Sphere algorithm
def generate_fibonacci_sphere(n_points, radius=1.0):
    """Return ``n_points`` points on a sphere laid out on a Fibonacci (golden) spiral.

    Point ``i`` uses the half-offset ``i + 0.5``. Its polar angle (measured from
    the +z axis, i.e. colatitude) satisfies ``cos(polar) = 1 - 2*(i + 0.5)/n_points``,
    so the z-coordinates are evenly spaced. Its azimuth advances by
    ``2*pi / golden_ratio`` from one point to the next.

    Args:
        n_points: number of points to generate.
        radius: sphere radius.

    Returns:
        Array of shape ``(n_points, 3)`` whose columns are x, y, z.
    """
    offsets = np.arange(n_points, dtype=float) + 0.5
    golden = (1 + np.sqrt(5)) / 2

    polar = np.arccos(1 - 2 * offsets / n_points)  # colatitude, measured from +z
    azimuth = 2 * np.pi * offsets / golden         # longitude

    xs = radius * np.cos(azimuth) * np.sin(polar)
    ys = radius * np.sin(azimuth) * np.sin(polar)
    zs = radius * np.cos(polar)
    return np.column_stack((xs, ys, zs))

# Generate grid points
def generate_grid_points(n_points, radius=1.0):
    """Return exactly ``n_points`` points taken from a regular cubic lattice.

    The lattice has ``num_side`` evenly spaced samples per axis over
    ``[-radius, radius]``. ``num_side`` is the smallest integer with
    ``num_side**3 >= n_points``. When ``n_points`` is not a perfect cube, the
    lattice is truncated to its first ``n_points`` points in ``np.meshgrid``
    flattening order, so the last layer is only partly filled.

    Args:
        n_points: number of points to return.
        radius: half-width of the cube spanned by the lattice.

    Returns:
        Array of shape ``(n_points, 3)``.
    """
    # Flooring the cube root (the previous behaviour) returned fewer than
    # n_points whenever n_points was not a perfect cube. A float cube root is
    # also not guaranteed to be exact, so round it and then bump up until the
    # lattice is big enough.
    num_side = int(round(float(np.cbrt(n_points))))
    while num_side ** 3 < n_points:
        num_side += 1

    axis = np.linspace(-radius, radius, num_side)
    x_vals, y_vals, z_vals = np.meshgrid(axis, axis, axis)
    points = np.vstack((x_vals.ravel(), y_vals.ravel(), z_vals.ravel())).T
    return points[:n_points]  # trim the partial last layer to the exact count

# Generate random points
def generate_random_points(n_points, radius=1.0):
    """Draw ``n_points`` points uniformly from the cube ``[-radius, radius)^3``.

    Samples come from NumPy's global random state (``np.random``), so
    ``np.random.seed`` controls the result.

    Returns:
        Array of shape ``(n_points, 3)``.
    """
    shape = (n_points, 3)
    return np.random.uniform(low=-radius, high=radius, size=shape)

# Benchmark search
def benchmark_search(tree, query_points, k=1):
    """Time ``tree.query`` once per query point.

    Each query is timed on its own with ``time.perf_counter``. The query results
    are discarded.

    Args:
        tree: object with a ``query(point, k=k)`` method (e.g. a scipy KDTree).
        query_points: iterable of query points.
        k: number of neighbours requested per query.

    Returns:
        ``(mean, std)`` of the per-query wall-clock durations, in seconds.
    """
    clock = time.perf_counter
    durations = []
    for point in query_points:
        began = clock()
        tree.query(point, k=k)
        durations.append(clock() - began)
    durations = np.asarray(durations)
    return durations.mean(), durations.std()

# Visualize with Plotly (no grid)
def visualize_dataset(points, query_point, nearest_idx, time_vals, title):
    """Build a black-background 3-D Plotly figure for one dataset.

    The figure holds three traces, in this order:
      1. every point, coloured by ``time_vals`` on the 'Viridis' colorscale,
         with a colorbar titled 'Time Value';
      2. ``query_point`` as a larger red marker;
      3. the point at ``nearest_idx`` as a larger cyan marker.

    Axis backgrounds, tick labels, titles and grid lines are hidden.

    Args:
        points: array indexed as ``points[:, 0..2]`` (presumably shape (N, 3)).
        query_point: sequence whose first three entries are x, y, z.
        nearest_idx: index into ``points`` of the neighbour to highlight.
        time_vals: per-point colour values, passed straight to ``marker.color``.
        title: figure title text.

    Returns:
        ``go.Figure`` containing the three traces.
    """
    def _highlight(x, y, z, color, label):
        # A single enlarged marker, used for the query point and its neighbour.
        return go.Scatter3d(
            x=[x], y=[y], z=[z],
            mode='markers',
            marker=dict(size=8, color=color),
            name=label,
        )

    cloud = go.Scatter3d(
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        mode='markers',
        marker=dict(
            size=4,
            color=time_vals,
            colorscale='Viridis',
            opacity=0.8,
            colorbar=dict(title='Time Value'),
        ),
        name='Points',
    )
    query = _highlight(query_point[0], query_point[1], query_point[2], 'red', 'Query Point')
    neighbor = _highlight(
        points[nearest_idx, 0],
        points[nearest_idx, 1],
        points[nearest_idx, 2],
        'cyan',
        'Nearest Neighbor',
    )

    bare_axes = {
        axis_key: dict(showbackground=False, showticklabels=False, title='', showgrid=False)
        for axis_key in ('xaxis', 'yaxis', 'zaxis')
    }

    fig = go.Figure(data=[cloud, query, neighbor])
    fig.update_layout(
        scene=dict(bgcolor='black', **bare_axes),
        paper_bgcolor='black',
        margin=dict(l=0, r=0, b=0, t=0),
        showlegend=True,
        title=dict(text=title, font=dict(color='white')),
    )
    return fig

# Main script
if __name__ == "__main__":
    # Parameters
    n_points = 1000
    radius = 1.0
    num_queries = 100
    seed = 0  # fixed seed so the random dataset and the query choice are reproducible

    np.random.seed(seed)

    # Generate datasets
    phyllo_points = generate_fibonacci_sphere(n_points, radius)
    grid_points = generate_grid_points(n_points, radius)
    random_points = generate_random_points(n_points, radius)

    # Create KD-Trees for each dataset
    phyllo_tree = KDTree(phyllo_points)
    grid_tree = KDTree(grid_points)
    random_tree = KDTree(random_points)

    # Query points are sampled from the Fibonacci sphere itself.
    # NOTE(review): every query is an exact member of phyllo_tree (distance 0),
    # while the grid and random trees (points filling the cube
    # [-radius, radius]^3) are queried off their own point sets. The three
    # timings are therefore not measured on equal terms; consider independent
    # query points if a fair comparison is wanted.
    query_indices = np.random.choice(n_points, size=num_queries, replace=False)
    query_points = phyllo_points[query_indices]

    # Benchmark search for each dataset
    datasets = {
        "Fibonacci Sphere": phyllo_tree,
        "Grid": grid_tree,
        "Random": random_tree,
    }

    results = {}
    for name, tree in datasets.items():
        mean_time, std_time = benchmark_search(tree, query_points)
        results[name] = (mean_time, std_time)
        print(f"{name} Search: {mean_time * 1e6:.2f} µs ± {std_time * 1e6:.2f} µs per query")

    # Visualize one query (the first) against every dataset, one subplot each.
    query_point = query_points[0]
    panels = [
        ("Fibonacci Sphere", phyllo_points, phyllo_tree),
        ("Grid", grid_points, grid_tree),
        ("Random", random_points, random_tree),
    ]

    fig = make_subplots(
        rows=1, cols=len(panels),
        specs=[[{'type': 'scatter3d'} for _ in panels]],
        subplot_titles=tuple(title for title, _, _ in panels),
    )

    for col, (title, points, tree) in enumerate(panels, start=1):
        _, nearest_idx = tree.query(query_point, k=1)
        # Size the cyclic colour values from the dataset itself, so the colour
        # array always has one entry per plotted point.
        time_vals = np.sin(np.linspace(0, 2 * np.pi, len(points)))
        panel = visualize_dataset(points, query_point, nearest_idx, time_vals, title)
        for trace in panel.data:
            if col > 1:
                # Every panel colours by the same sin() values in [-1, 1]. Keep
                # only the first colorbar so the three bars do not stack on
                # top of each other.
                trace.update(marker=dict(showscale=False))
            fig.add_trace(trace, row=1, col=col)

    # Only the traces were copied from each per-dataset figure, not their
    # layouts, so the hidden-axis styling must be applied to every scene here.
    hidden_axis = dict(showbackground=False, showticklabels=False, title='', showgrid=False)
    fig.update_scenes(bgcolor='black', xaxis=hidden_axis, yaxis=hidden_axis, zaxis=hidden_axis)
    fig.update_layout(
        paper_bgcolor='black',
        font=dict(color='white'),  # subplot titles/colorbar text default to dark on black
        margin=dict(l=0, r=0, b=0, t=30),
        showlegend=False,
    )

    # Show the plot
    fig.show()

Fibonacci Sphere Search: 43.68 µs ± 27.96 µs per query
Grid Search: 48.46 µs ± 73.46 µs per query
Random Search: 41.06 µs ± 6.75 µs per query
