In [1]:
import numpy as np

In [2]:
# import numpy as np

class Node:
    def __init__(self, point, left=None, right=None):
        self.point = point
        self.left = left
        self.right = right

def build_kdtree(points, depth=0):
    if len(points) == 0:
        return None

    # Select axis based on depth so that axis cycles through all valid dimensions
    k = len(points[0])  # assumes all points have the same dimension
    axis = depth % k

    # Sort point array and choose median as pivot element
    sorted_points = sorted(points, key=lambda point: point[axis])
    median = len(sorted_points) // 2

    # Create node and construct subtrees
    return Node(
        point=sorted_points[median],
        left=build_kdtree(sorted_points[:median], depth + 1),
        right=build_kdtree(sorted_points[median + 1:], depth + 1)
    )

# Sample usage
points = np.random.rand(10, 3)  # 10 random 3D points
kdtree = build_kdtree(points)


In [5]:
print(kdtree.point)
print(points)

[0.29981935 0.08311264 0.9599225 ]
[[0.16466299 0.37163582 0.31588025]
 [0.01113178 0.16983367 0.98883336]
 [0.4564964  0.24962949 0.6884301 ]
 [0.29981935 0.08311264 0.9599225 ]
 [0.93313405 0.90909413 0.4889762 ]
 [0.36033492 0.31862883 0.94381073]
 [0.25703502 0.32463779 0.60165438]
 [0.08792472 0.37545386 0.87325603]
 [0.22211796 0.61697032 0.90516935]
 [0.50358875 0.84208334 0.35603661]]


In [6]:
print(kdtree.left)

<__main__.Node object at 0x0000017809122960>


In [8]:
# Nearest neighbor search using NumPy
def nearest_neighbor(query, points):
    # Calculate squared Euclidean distances between the query point and all points
    distances = np.sum((points - query) ** 2, axis=1)
    
    # Find the index of the nearest point (minimum distance)
    nearest_idx = np.argmin(distances)
    
    # Return the nearest point and the square root of the distance
    return points[nearest_idx], np.sqrt(distances[nearest_idx])

# Example usage
query_point = np.array([0.5, 0.5, 0.5])
points = np.random.rand(10000, 3)  # 10000 random 3D points
nearest_point, distance = nearest_neighbor(query_point, points)
print(f"Nearest point: {nearest_point}, Distance: {distance}")


Nearest point: [0.51239851 0.51565852 0.47950725], Distance: 0.028615817333971873


In [25]:
import numpy as np

class Node:
    def __init__(self, point, left=None, right=None):
        self.point = point
        self.left = left
        self.right = right

def build_kdtree(points, lr=0, depth=0):
    if len(points) == 0:
        print('no point leftRight {}'.format(lr))
        return None

    k = len(points[0])  # assumes all points have the same dimension
    axis = depth % k
    # print('axis',axis)
    # Sort point array and choose median as pivot element
    sorted_points = sorted(points, key=lambda point: point[axis])
    print('sort {} leftRight {}'.format(axis, lr),len(sorted_points))
    print('sort {} '.format(axis),sorted_points)
    # print('leftRight', lr)
    # print('sort 1 ',sorted_points[1])
    # print('sort 2 ',sorted_points[2])
    median = len(sorted_points) // 2

    # Create node and construct subtrees
    return Node(
        point=sorted_points[median],
        left=build_kdtree(sorted_points[:median], 0, depth + 1),
        right=build_kdtree(sorted_points[median + 1:], 1, depth + 1)
    )

def squared_distance(point1, point2):
    """Returns the squared Euclidean distance between two points."""
    return np.sum((np.array(point1) - np.array(point2)) ** 2)

def nearest_neighbor_kdtree(root, query_point, depth=0, best=None):
    if root is None:
        return best

    k = len(query_point)
    axis = depth % k

    # Update best point so far
    if best is None or squared_distance(query_point, root.point) < squared_distance(query_point, best):
        best = root.point

    # Determine whether to go left or right in the tree
    if query_point[axis] < root.point[axis]:
        best = nearest_neighbor_kdtree(root.left, query_point, depth + 1, best)
        # If there could be a closer point in the other branch, search the other side as well
        if (query_point[axis] - root.point[axis]) ** 2 < squared_distance(query_point, best):
            best = nearest_neighbor_kdtree(root.right, query_point, depth + 1, best)
    else:
        best = nearest_neighbor_kdtree(root.right, query_point, depth + 1, best)
        if (query_point[axis] - root.point[axis]) ** 2 < squared_distance(query_point, best):
            best = nearest_neighbor_kdtree(root.left, query_point, depth + 1, best)

    return best

# Example usage
points = np.random.rand(10, 3)  # 10000 random 3D points
print(points)
kdtree = build_kdtree(points)

query_point = np.array([0.5, 0.5, 0.5])
nearest_point = nearest_neighbor_kdtree(kdtree, query_point)
distance = np.sqrt(squared_distance(query_point, nearest_point))

print(f"Nearest point: {nearest_point}, Distance: {distance}")


[[0.67943921 0.0558481  0.16369618]
 [0.2397218  0.8017082  0.36606032]
 [0.03546277 0.55561624 0.24857199]
 [0.93695451 0.04133317 0.08582425]
 [0.45887926 0.21934512 0.52912916]
 [0.9428053  0.73739106 0.95408341]
 [0.76498404 0.3744493  0.94370574]
 [0.48937524 0.46298671 0.06216865]
 [0.2516358  0.51022037 0.5249465 ]
 [0.36792518 0.98410422 0.57811382]]
sort 0 leftRight 0 10
sort 0  [array([0.03546277, 0.55561624, 0.24857199]), array([0.2397218 , 0.8017082 , 0.36606032]), array([0.2516358 , 0.51022037, 0.5249465 ]), array([0.36792518, 0.98410422, 0.57811382]), array([0.45887926, 0.21934512, 0.52912916]), array([0.48937524, 0.46298671, 0.06216865]), array([0.67943921, 0.0558481 , 0.16369618]), array([0.76498404, 0.3744493 , 0.94370574]), array([0.93695451, 0.04133317, 0.08582425]), array([0.9428053 , 0.73739106, 0.95408341])]
sort 1 leftRight 0 5
sort 1  [array([0.45887926, 0.21934512, 0.52912916]), array([0.2516358 , 0.51022037, 0.5249465 ]), array([0.03546277, 0.55561624, 0.24857

In [None]:
import cupy as cp

# Define the KD-tree nearest neighbor search kernel
kd_tree_search_kernel = cp.RawKernel(r'''
extern "C" __global__
void kd_tree_search(const float* __restrict__ queries, 
                    const float* __restrict__ tree_points,
                    const int* __restrict__ tree_axes,
                    const int* __restrict__ tree_left,
                    const int* __restrict__ tree_right,
                    float* __restrict__ results,
                    int num_queries, int dims, int num_nodes) {
    int q_idx = blockDim.x * blockIdx.x + threadIdx.x;

    if (q_idx < num_queries) {
        // Load query point
        const float* query = &queries[q_idx * dims];

        // Variables for nearest neighbor search
        float min_dist = 1e10;  // Large initial value for comparison
        int best_idx = -1;

        // Stack for tree traversal
        int stack[64];
        int stack_top = -1;

        // Start from the root node (node 0)
        stack[++stack_top] = 0;

        while (stack_top >= 0) {
            int node_idx = stack[stack_top--];

            if (node_idx < 0 || node_idx >= num_nodes) {
                continue;
            }

            // Load the node's data
            const float* node_point = &tree_points[node_idx * dims];
            int axis = tree_axes[node_idx];
            int left_idx = tree_left[node_idx];
            int right_idx = tree_right[node_idx];

            // Compute squared distance to the current node
            float dist = 0.0;
            for (int d = 0; d < dims; d++) {
                float diff = query[d] - node_point[d];
                dist += diff * diff;
            }

            // Check if this node is closer than the current best
            if (dist < min_dist) {
                min_dist = dist;
                best_idx = node_idx;
            }

            // Decide which branch to explore
            float diff = query[axis] - node_point[axis];
            int near_idx = (diff < 0) ? left_idx : right_idx;
            int far_idx = (diff < 0) ? right_idx : left_idx;

            stack[++stack_top] = near_idx;

            // Check if we need to explore the far side
            if (diff * diff < min_dist) {
                stack[++stack_top] = far_idx;
            }
        }

        // Store the result (index of the nearest neighbor)
        results[q_idx] = best_idx;
    }
}
''', 'kd_tree_search')

# Nearest neighbor search using KD-tree on GPU
def kd_tree_nearest_neighbor_gpu(queries, tree_points, tree_axes, tree_left, tree_right):
    num_queries, dims = queries.shape
    num_nodes = tree_points.shape[0]

    # Copy data to GPU
    queries_gpu = cp.asarray(queries, dtype=cp.float32)
    tree_points_gpu = cp.asarray(tree_points, dtype=cp.float32)
    tree_axes_gpu = cp.asarray(tree_axes, dtype=cp.int32)
    tree_left_gpu = cp.asarray(tree_left, dtype=cp.int32)
    tree_right_gpu = cp.asarray(tree_right, dtype=cp.int32)
    results_gpu = cp.empty(num_queries, dtype=cp.int32)

    # Launch the CUDA kernel
    threads_per_block = 256
    blocks_per_grid = (num_queries + threads_per_block - 1) // threads_per_block
    kd_tree_search_kernel((blocks_per_grid,), (threads_per_block,),
                          (queries_gpu, tree_points_gpu, tree_axes_gpu, tree_left_gpu, tree_right_gpu, results_gpu,
                           num_queries, dims, num_nodes))

    # Copy back the results
    nearest_neighbor_indices = results_gpu.get()

    return nearest_neighbor_indices

# Example usage
queries = np.random.rand(5, 3)  # 5 random 3D query points
nearest_indices = kd_tree_nearest_neighbor_gpu(queries, tree_points, tree_axes, tree_left, tree_right)
nearest_points = tree_points[nearest_indices]
print(nearest_points)
