In [1]:
import numpy as np
import torch

In [2]:
# Confirm a CUDA device is visible before the cells below move tensors with .cuda().
torch.cuda.is_available()

True

In [3]:
# Benchmark inputs: a cloud of 1024**2 random 3-D points and one random query point,
# all float32 and uniformly distributed in [0, 1).
# A seeded generator keeps the data (and therefore the nearest neighbour reported
# further down) identical across kernel restarts.
SEED = 42
N_POINTS = 1024**2
rng = np.random.default_rng(SEED)
# Draw float32 directly instead of drawing float64 and casting afterwards.
points = rng.random((N_POINTS, 3), dtype=np.float32)
query = rng.random(3, dtype=np.float32)

In [4]:
def numpy_implementation(points: np.ndarray, query_point: np.ndarray) -> np.ndarray:
    """Find the row of ``points`` that lies closest to ``query_point``.

    Closeness is measured by squared Euclidean distance; the square root is
    skipped because it does not change which row is the minimum.

    Args:
        points: Array with one candidate point per row (expected to be 2-D).
        query_point: Array that broadcasts against the rows of ``points``.

    Returns:
        The row of ``points`` with the smallest squared distance. When several
        rows are tied, the first one is returned.
    """
    offsets = points - query_point
    # Sum the squared per-coordinate offsets along each row.
    squared_distances = (offsets * offsets).sum(axis=1)
    nearest_index = squared_distances.argmin()
    return points[nearest_index]

In [5]:
# CPU reference result: the nearest neighbour of `query` according to NumPy.
numpy_implementation(points, query)

array([0.43225968, 0.02949279, 0.6148025 ], dtype=float32)

In [17]:
# Baseline timing of the NumPy (CPU) implementation.
%timeit numpy_implementation(points, query)

125 ms ± 12.5 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [7]:
# Copy the NumPy inputs into torch tensors and place them in GPU memory.
points_torch = torch.tensor(points).to("cuda")
query_torch = torch.tensor(query).to("cuda")

In [8]:
# Cost of getting the point cloud onto the GPU: torch.tensor copies the NumPy
# array into a CPU tensor, then .cuda() transfers it to the device.
# The assignment is local to the timed statement, so the global `points_torch`
# is not rebound by this cell.
%timeit points_torch = torch.tensor(points).cuda()

9.27 ms ± 2.99 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [9]:
def torch_implementation(points: torch.Tensor, query: torch.Tensor) -> torch.Tensor:
    """Return the row of ``points`` nearest to ``query``, as a CPU tensor.

    The nearest row is the one that minimises the squared Euclidean distance
    to ``query``; the computation runs on whatever device the inputs live on.

    Args:
        points: Tensor with one candidate point per row (expected to be 2-D).
        query: Tensor that broadcasts against the rows of ``points``.

    Returns:
        The selected row of ``points``, moved to host memory with ``.cpu()``.
    """
    squared_distances = (points - query).pow(2).sum(dim=1)
    # torch.min along a dim yields (values, indices); only the index is needed.
    _, nearest_index = torch.min(squared_distances, dim=0)
    return points[nearest_index].cpu()

In [10]:
# Same query on the GPU tensors; the result should agree with the NumPy one.
torch_implementation(points_torch, query_torch)

tensor([0.4323, 0.0295, 0.6148])

In [11]:
# GPU timing. torch_implementation ends with .cpu(), which has to wait for the
# queued GPU work before copying the result to the host, so this figure covers
# the kernels plus that device-to-host copy (but not the initial upload).
%timeit torch_implementation(points_torch, query_torch)

1.02 ms ± 412 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
