In [2]:
import numpy as np
import torch

In [3]:
# Check that PyTorch can see a CUDA device; the GPU cells further down need one.
torch.cuda.is_available()

True

In [4]:
# Synthetic benchmark data: 1024 * 1024 random 3-D points and one random 3-D query, both float32.
points = np.random.random(size=(1024 * 1024, 3)).astype(np.float32)
query = np.random.random(size=3).astype(np.float32)

In [5]:
def numpy_implementation(points: np.ndarray, query_point: np.ndarray) -> np.ndarray:
    """Find the row of ``points`` that lies closest to ``query_point``.

    Distance is the Euclidean norm of the row-wise difference between
    ``points`` and ``query_point``; the row with the smallest norm is returned.
    """
    offsets = points - query_point
    distances = np.linalg.norm(offsets, axis=1)
    nearest_index = np.argmin(distances)
    return points[nearest_index]

In [6]:
# Run the NumPy version once and display the nearest point it finds for `query`.
numpy_implementation(points, query)

array([0.50451374, 0.44584525, 0.8137127 ], dtype=float32)

In [7]:
# CPU baseline: time the NumPy nearest-neighbour search over the full point set.
%timeit numpy_implementation(points, query)

90.2 ms ± 8.45 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [8]:
# Stage both inputs on the GPU ahead of time so the transfer stays out of the search timing.
points_torch = torch.tensor(points).to("cuda")
query_torch = torch.tensor(query).to("cuda")

In [9]:
# Time the NumPy -> GPU tensor conversion of `points` on its own; the search timing
# further down runs on tensors that are already on the device and does not include it.
%timeit points_torch = torch.tensor(points).cuda()

8 ms ± 365 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [10]:
def torch_implementation(points: torch.Tensor, query: torch.Tensor) -> torch.Tensor:
    """Return the nearest neighbour of ``query`` among the rows of ``points``.

    The search runs on whatever device ``points`` and ``query`` live on; the
    selected row is copied to host memory before it is returned.

    Args:
        points: Candidate points, one per row.
        query: Point to search for; must broadcast against the rows of ``points``.

    Returns:
        The row of ``points`` with the smallest Euclidean distance to ``query``,
        as a CPU tensor.
    """
    # torch.norm is deprecated; torch.linalg.vector_norm is its documented replacement.
    distances = torch.linalg.vector_norm(points - query, ord=2, dim=1)
    # argmin along dim 0 yields the same index the original took from torch.min(...)[1].
    nearest_index = torch.argmin(distances, dim=0)
    return points[nearest_index].cpu()

In [11]:
# Sanity check: run the GPU version on the same data; it should report the same
# point as the NumPy version.
torch_implementation(points_torch, query_torch)

tensor([0.5045, 0.4458, 0.8137])

In [13]:
# GPU timing: inputs are already on the device. The function copies its result back
# to the host, so that copy is part of the measured time.
%timeit torch_implementation(points_torch, query_torch)

924 µs ± 264 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
