In [1]:
import numpy as np
from numba import cuda, float32, int32

In [2]:
# Kernel to compute the squared distances and reduce to find the closest point
@cuda.jit
def find_nearest_point_kernel(points, query, closest_point, min_distance):
    i = cuda.grid(1)

    # Define a large number for "infinity" for comparison
    inf = 1e20

    # Each thread calculates the squared Euclidean distance
    if i < points.shape[0]:
        dx = points[i, 0] - query[0]
        dy = points[i, 1] - query[1]
        dz = points[i, 2] - query[2]
        dist = dx**2 + dy**2 + dz**2
        
        # Atomic operation to update the closest point and minimum distance
        if dist < min_distance[0]:
            min_distance[0] = dist
            closest_point[0] = i  # Save the index of the closest point

def find_nearest_point_gpu(points_device, query_device, closest_point, min_distance):
    # Call the kernel to compute the closest point and minimum distance
    find_nearest_point_kernel[blocks, threads_per_block](points_device, query_device, closest_point, min_distance)

    # Retrieve the index of the closest point from the device
    return closest_point.copy_to_host()[0]

# Example usage
N = 1024**2  # Number of points
points_host = np.random.rand(N, 3).astype(np.float32)  # 1 million 3D points
points_device = cuda.to_device(points_host)  # Pre-allocate points on the device

query_host = np.array([0.1, 0.5, 0.9], dtype=np.float32)  # Query point

# Allocate device memory for the closest point and minimum distance
closest_point = cuda.device_array(1, dtype=np.int32)  # Use np.int32 for integer array
min_distance = cuda.device_array(1, dtype=np.float32)  # Use np.float32 for float array

query_device = cuda.to_device(np.array(query_host, dtype=np.float32))

# Define grid and block dimensions
threads_per_block = 352
blocks = (N + threads_per_block - 1) // threads_per_block

# Find the closest point on GPU
nearest_point = find_nearest_point_gpu(points_device, query_device, closest_point, min_distance)
print("Nearest point:", points_host[nearest_point])

Nearest point: [0.22603026 0.97334635 0.8298668 ]


In [3]:
%timeit points_host[find_nearest_point_gpu(points_device, query_device, closest_point, min_distance)]

693 µs ± 34.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
