Set up a nearest-neighbour problem, profile it on the CPU, and check whether PyTorch on the GPU gives an improvement.

## Imports

In [1]:
import numpy as np
import torch

In [2]:
from common import save_numpy_array

In [3]:
torch.cuda.is_available()

True

## Setup problem

Start with about a million points (1024**2 = 1,048,576). The task is to return the point closest to a query point.

In [4]:
points = np.random.random((1024**2, 3)).astype(np.float32)
query = np.random.random(3).astype(np.float32)

## Get base measurement for NumPy

Write a one-line NumPy implementation and time it.

In [5]:
def numpy_implementation(points: np.ndarray, query_point: np.ndarray) -> np.ndarray:
    """ Return nearest neighbour point """
    return points[np.linalg.norm(points - query_point, axis=1).argmin()]

In [6]:
numpy_implementation(points, query)

array([0.72436523, 0.10307071, 0.2642327 ], dtype=float32)

In [7]:
%timeit numpy_implementation(points, query)

79.5 ms ± 4.42 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
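The one-liner above takes a full Euclidean norm, but only the position of the minimum is needed. Because the square root is monotonic, the argmin of the squared distances selects the same point. The sketch below illustrates that variant; the function name `numpy_squared_distance` and the small demo arrays are assumptions made for illustration, and whether it is actually faster on a given machine would need to be timed.

```python
import numpy as np


def numpy_squared_distance(points: np.ndarray, query_point: np.ndarray) -> np.ndarray:
    """Return the nearest neighbour point using squared distances (hypothetical variant)."""
    # Broadcast (N, 3) - (3,) -> (N, 3)
    diff = points - query_point
    # Row-wise sum of squares, shape (N,); no square root is taken
    sq_dist = np.einsum("ij,ij->i", diff, diff)
    return points[sq_dist.argmin()]


# Small demo data (assumed sizes, not the million points used in the notebook)
rng = np.random.default_rng(0)
demo_points = rng.random((1000, 3), dtype=np.float32)
demo_query = rng.random(3, dtype=np.float32)

nearest = numpy_squared_distance(demo_points, demo_query)
```

In the notebook this could be timed with `%timeit` in the same way as `numpy_implementation`.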


## Get base measurement for PyTorch

We are often limited to cases where the data can be preallocated, because copying from the CPU to the GPU on every call adds overhead. In this instance, the GPU allocation isn't too expensive. If we are allocating often, we can use page-locked (pinned) host memory, which is a CUDA concept that PyTorch exposes.

In [8]:
%timeit points_torch = torch.tensor(points).cuda()

3.72 ms ± 642 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
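As a rough illustration of the page-locked memory mentioned above, the sketch below pins the host tensor before copying it to the GPU. The helper name `to_device_pinned` and the small array size are assumptions; the code falls back to a plain copy when CUDA is not available.

```python
import numpy as np
import torch


def to_device_pinned(array: np.ndarray, device: torch.device) -> torch.Tensor:
    """Copy a NumPy array to `device`, via pinned host memory when on CUDA (hypothetical helper)."""
    host = torch.from_numpy(array)  # shares memory with the NumPy array
    if device.type == "cuda":
        # pin_memory() copies into page-locked host memory, which allows
        # an asynchronous host-to-device transfer with non_blocking=True
        return host.pin_memory().to(device, non_blocking=True)
    return host.to(device)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
small_points = np.random.random((1024, 3)).astype(np.float32)
small_points_torch = to_device_pinned(small_points, device)
```

Pinning is itself a host-side copy, so it is most likely to pay off when the same pinned buffer is reused across many transfers rather than created once per call.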


In [9]:
points_torch = torch.tensor(points).cuda()
query_torch = torch.tensor(query).cuda()

PyTorch also allows a one-line implementation, which can be used to check whether the operation is worth doing on the GPU.

In [10]:
def torch_implementation(points: torch.Tensor, query: torch.Tensor) -> torch.Tensor:
    """ Return nearest neighbour point from torch array """
    return points[torch.min(torch.norm(points - query, dim=1), dim=0)[1]].cpu()

We also have to copy the result back from the GPU to the host here, but it is a single 3-element point, so the transfer is small.

In [11]:
torch_implementation(points_torch, query_torch)

tensor([0.7244, 0.1031, 0.2642])
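The printed values match the NumPy result to the displayed precision. A programmatic check could look like the sketch below, which compares the distance to the query rather than the index, since two near-equidistant points could be ranked differently on the CPU and the GPU. The variable names and the array size are assumptions, and the code runs on the CPU when CUDA is not available.

```python
import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Small demo data (assumed size)
check_points = np.random.random((4096, 3)).astype(np.float32)
check_query = np.random.random(3).astype(np.float32)

# NumPy reference
numpy_index = int(np.linalg.norm(check_points - check_query, axis=1).argmin())
numpy_nearest = check_points[numpy_index]

# Same computation in PyTorch on the selected device
points_t = torch.from_numpy(check_points).to(device)
query_t = torch.from_numpy(check_query).to(device)
torch_index = int(torch.linalg.norm(points_t - query_t, dim=1).argmin())
torch_nearest = points_t[torch_index].cpu().numpy()

# Difference between the two nearest distances; zero when the same point is picked
distance_gap = abs(
    np.linalg.norm(numpy_nearest - check_query)
    - np.linalg.norm(torch_nearest - check_query)
)
```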

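Using PyTorch on the GPU gives a significant improvement: about 532 µs per call against 79.5 ms for NumPy, roughly 150 times faster, not counting the one-off copy of the points to the GPU.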

In [12]:
%timeit torch_implementation(points_torch, query_torch)

532 µs ± 18.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
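The timing above includes the `.cpu()` copy, which also forces the GPU work to finish before the call returns. If the result were left on the GPU, the kernels would be launched asynchronously and the timed region would need an explicit synchronisation. The sketch below shows one way to do that outside of `%timeit`; the helper `time_torch`, the lambda, and the array size are assumptions made for illustration.

```python
import time

import torch


def time_torch(fn, *args, repeats: int = 20) -> float:
    """Return mean seconds per call, synchronising around the timed region on CUDA."""
    use_cuda = torch.cuda.is_available()
    fn(*args)  # warm-up call, excluded from the measurement
    if use_cuda:
        torch.cuda.synchronize()  # make sure the warm-up has finished
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args)
    if use_cuda:
        torch.cuda.synchronize()  # wait for all queued kernels to finish
    return (time.perf_counter() - start) / repeats


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
timing_points = torch.rand((2**16, 3), device=device)
timing_query = torch.rand(3, device=device)

# Nearest neighbour that stays on the device (no .cpu() copy)
nearest_on_device = lambda p, q: p[torch.linalg.norm(p - q, dim=1).argmin()]

seconds_per_call = time_torch(nearest_on_device, timing_points, timing_query)
```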


## Save the points for later use

In [13]:
save_numpy_array("nearest_neighbour_points.npy", points)

In [14]:
save_numpy_array("nearest_neighbour_query.npy", query)