# Nearest neighbor search (NNS)
NNS finds the points in a point cloud that are close to a query (an arbitrary coordinate) in 3D space. NNS is important because, in point cloud processing, each point often draws information from its neighbors. Examples include:
- When computing handcrafted features, normals, and similar per-point quantities from a point's neighbors, NNS is used to find those neighbors.
- In deep learning models, NNS is used to feed the features of a point's neighbors into a convolution module.

This section introduces the following NNS methods.
- kNN (k Nearest Neighbor)
- Radius Nearest Neighbor


In [1]:
%load_ext autoreload
%autoreload 2

## kNN (k Nearest Neighbor)
kNN finds the $k$ points in a point cloud that are nearest to a query.

This subsection uses the following code:

In [2]:
# for kNN
import numpy as np
from tutlibs.nns import k_nearest_neighbors
from tutlibs.operator import gather

# for description
from tutlibs.io import Points as io
from tutlibs.utils import single_color
from tutlibs.visualization import JupyterVisualizer as jv
import inspect

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [8]:
# load a point cloud data.
coords, _, _ = io.read("../data/bunny_pc.ply")

# define queries.
queries = coords[::500]

# kNN
k = 10
idx, dist = k_nearest_neighbors(queries, coords, k) # (N, k)
nn_coords = gather(coords, idx) # (N, k, 3)
nn_coords = nn_coords.reshape(len(nn_coords)*k, 3) # (N*k, 3)

# visualize results.
obj_queries = jv.point(queries, single_color("#ff0000", len(queries)))
obj_nn_points = jv.point(nn_coords, single_color("#00ff00", len(nn_coords)))
obj_points = jv.point(coords, single_color("#0000ff", len(coords)))
jv.display([obj_points, obj_nn_points, obj_queries])

Output()

The output above shows the $k$ neighbors (green) of each query (red) in the point cloud (blue). The implementation finds $k$ neighbors for every query.  
**Note**: because each query is itself a point of the cloud, its nearest neighbor is the query itself, and that green point is hidden under the red one. Each query therefore shows one red point plus $k - 1$ visible green points, for $k$ neighbors in total.
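
The self-match described in the note can be checked with plain NumPy. The following is a minimal sketch on a toy point cloud, not the `tutlibs` implementation; the coordinates and the names `query_ids`, `nearest_idx`, and `nearest_dist` are made up for illustration.

```python
import numpy as np

# Toy point cloud (made-up values); the queries are a subset of the same cloud.
coords = np.array([[0.0, 0.0, 0.0],
                   [1.0, 0.0, 0.0],
                   [0.0, 2.0, 0.0],
                   [5.0, 5.0, 5.0]])
query_ids = np.array([0, 3])
queries = coords[query_ids]  # (N, 3)

# Squared Euclidean distance between every query and every point: (N, M).
diff = queries[:, None, :] - coords[None, :, :]
sq_dists = np.sum(diff ** 2, axis=-1)

# The closest point to each query is the query itself, at distance 0.
nearest_idx = np.argmin(sq_dists, axis=-1)
nearest_dist = sq_dists[np.arange(len(queries)), nearest_idx]
```

Here `nearest_idx` equals `query_ids`, which is why one of the $k$ neighbors always coincides with the query when the queries are sampled from the cloud.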

Next, we look at the contents of the `k_nearest_neighbors` function.

In [4]:
print(inspect.getsource(k_nearest_neighbors))

def k_nearest_neighbors(
    coords1: np.ndarray, coords2: np.ndarray, k: int
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute k nearest neighbors between coords1 and coords2.

    Args:
        coords1: coordinates of centroid points (N, C)
        coords2: coordinates of all points (M, C)
        k: number of nearest neighbors

    Returns:
        idxs: indices of k nearest neighbors (N, k)
        square distances: square distance for kNN (N, k)
    """

    # compute distances between coords1 and coords2.
    point_pairwise_distances = square_distance(
        coords1, coords2
    )  # ((N, 3), (M, 3)) -> (N, M)

    # sort the distances between two points in order of closeness and get top-k indices.
    idxs = np.argsort(point_pairwise_distances, axis=-1)[:, :k]  # (N, M) -> (N, k)

    # get the distance between two points according to the top-k indices.
    square_dists = np.take_along_axis(
        point_pairwise_distances, idxs, axis=-1
    )  # ((N, M), (N, k)) -> (N, k)

    return idxs, square_dists



In the implementation above, `k_nearest_neighbors` returns the indices of the neighbor points in the point cloud together with their squared distances. The function works as follows:

1. Compute the pairwise squared distances between queries and points.
2. Sort the distances and take the indices of the $k$ closest points.
3. Gather the $k$ squared distances that correspond to those indices.
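
The three steps can be reproduced in a self-contained sketch. The source of `square_distance` is not shown in this section, so `square_distance_sketch` below is an assumption: it computes pairwise squared Euclidean distances by broadcasting. `knn_sketch` is likewise a hypothetical name, not part of `tutlibs`.

```python
from typing import Tuple

import numpy as np


def square_distance_sketch(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Pairwise squared Euclidean distances: (N, C), (M, C) -> (N, M)."""
    diff = a[:, None, :] - b[None, :, :]  # (N, M, C) via broadcasting
    return np.sum(diff ** 2, axis=-1)


def knn_sketch(
    queries: np.ndarray, points: np.ndarray, k: int
) -> Tuple[np.ndarray, np.ndarray]:
    # Step 1: pairwise squared distances between queries and points.
    dists = square_distance_sketch(queries, points)  # (N, M)
    # Step 2: sort each row and keep the indices of the k closest points.
    idxs = np.argsort(dists, axis=-1)[:, :k]  # (N, k)
    # Step 3: pick the squared distances that belong to those indices.
    square_dists = np.take_along_axis(dists, idxs, axis=-1)  # (N, k)
    return idxs, square_dists


# Made-up points on the x-axis and a single query near the second point.
points = np.array([[0.0, 0.0, 0.0],
                   [1.0, 0.0, 0.0],
                   [3.0, 0.0, 0.0],
                   [7.0, 0.0, 0.0]])
queries = np.array([[0.9, 0.0, 0.0]])
idxs, square_dists = knn_sketch(queries, points, k=2)
```

For this query the two nearest points are index 1 and then index 0. Note that the broadcast builds an `(N, M, C)` intermediate array, so this brute-force approach is practical only for moderately sized point clouds.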


## Radius Nearest Neighbor
Radius Nearest Neighbor finds the points in a point cloud that lie within a given radius of a query.

This subsection uses the following code:

In [10]:
# for radius nearest neighbor
import numpy as np
from tutlibs.nns import radius_nearest_neighbors
from tutlibs.operator import gather

# for description
from tutlibs.io import Points as io
from tutlibs.utils import single_color
from tutlibs.visualization import JupyterVisualizer as jv
import inspect

In [11]:
# load a point cloud data.
coords, _, _ = io.read("../data/bunny_pc.ply")

# define queries (point cloud A); coords itself acts as point cloud B.
queries = coords[::500]

# Radius NN
k = 64
r = 0.05
idx, dist, mask = radius_nearest_neighbors(queries, coords, r, k) # (N, k)
nn_coords = gather(coords, idx) # (N, k, 3)
nn_coords = nn_coords[mask] # (number of neighbors, 3)

# visualize results.
obj_points = jv.point(coords, single_color("#0000ff", len(coords)))
obj_nn_points = jv.point(nn_coords, single_color("#00ff00", len(nn_coords)))
obj_queries = jv.point(queries, single_color("#ff0000", len(queries)))
jv.display([obj_points, obj_nn_points, obj_queries])


Output()

The output above shows the neighbors (green) that lie within the radius of each query (red) in the point cloud (blue). The implementation finds at most $k$ neighbors per query. Because the number of neighbors within the radius differs from query to query, the function also returns a mask that selects the neighbors inside the radius.
**Note**: in the visualizer above, some green points are hidden under the red points, since each query is also a point of the cloud.
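
How an `(N, k)` boolean mask selects a varying number of neighbors per query can be seen in a small NumPy sketch. The arrays below are made-up stand-ins for the gathered neighbor coordinates and the radius mask, not actual output of `radius_nearest_neighbors`.

```python
import numpy as np

# Made-up neighbor coordinates for N=2 queries with k=3 candidates each: (N, k, 3).
nn_coords = np.arange(18, dtype=float).reshape(2, 3, 3)

# Made-up radius mask (N, k): True where the candidate lies within the radius.
mask = np.array([[True, True, False],
                 [True, False, False]])

# Boolean indexing with an (N, k) mask flattens the kept neighbors to (num_kept, 3).
kept = nn_coords[mask]

# Number of neighbors within the radius for each query.
counts = mask.sum(axis=1)
```

The flattened result loses the per-query grouping, which is fine for visualization; `counts` recovers how many neighbors each query kept.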

Next, we look at the contents of the `radius_nearest_neighbors` function.


In [8]:
print(inspect.getsource(radius_nearest_neighbors))

def radius_nearest_neighbors(coords1: np.ndarray, coords2: np.ndarray, r: float, k: int = 32) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Compute radius nearest neighbors between coords1 and coords2.

    Args:
        coords1: coordinates of centroid points (N, C)
        coords2: coordinates of all points (M, C)
        r: radius
        k: number of nearest neighbors

    Returns:
        idxs: indices of k nearest neighbors (N, k)
        square distances: square distance for kNN (N, k)
        mask: radius mask (bool) for idxs and distances (N, k)
    """

    # compute kNN.
    idxs, square_dists = k_nearest_neighbors(coords1, coords2, k)

    # get radius nearest neighbors mask.
    radius_mask = square_dists < r**2

    return idxs, square_dists, radius_mask




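The implementation above is broken into steps by its comments. Each step is listed below in the form `step (comment): note`, with more detailed explanations or supplementary notes indented beneath it.

The arguments are a point cloud `coords1` containing $N$ points, a point cloud `coords2` containing $M$ points, a distance threshold `r`, and the maximum number of neighbors $k$ (in the example above, `coords1` corresponds to point cloud A, the queries, and `coords2` to point cloud B, the full point cloud). The return values are an array of neighbor indices, an array of squared distances to those neighbors, and a mask indicating whether each neighbor lies within the distance threshold. The processing is as follows:

1. `compute kNN.`: use kNN to obtain the neighbor indices and the array of squared distances between points.
2. `get radius nearest neighbors mask.`: use the distance array to obtain a mask indicating whether each point lies within `r`.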
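
The two steps can be combined into a self-contained sketch. `radius_nn_sketch` is a hypothetical name, and the brute-force kNN inside it is an assumption standing in for `k_nearest_neighbors`; the point coordinates are made up.

```python
from typing import Tuple

import numpy as np


def radius_nn_sketch(
    queries: np.ndarray, points: np.ndarray, r: float, k: int
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    # Step 1 (compute kNN): brute-force squared distances, then top-k per query.
    diff = queries[:, None, :] - points[None, :, :]  # (N, M, C)
    dists = np.sum(diff ** 2, axis=-1)  # (N, M)
    idxs = np.argsort(dists, axis=-1)[:, :k]  # (N, k)
    square_dists = np.take_along_axis(dists, idxs, axis=-1)  # (N, k)
    # Step 2 (radius mask): compare squared distances with the squared radius.
    radius_mask = square_dists < r ** 2  # (N, k)
    return idxs, square_dists, radius_mask


# Made-up points on the x-axis; the two queries coincide with points 0 and 3.
points = np.array([[0.0, 0.0, 0.0],
                   [0.5, 0.0, 0.0],
                   [2.0, 0.0, 0.0],
                   [9.0, 0.0, 0.0]])
queries = np.array([[0.0, 0.0, 0.0],
                    [9.0, 0.0, 0.0]])
idxs, square_dists, mask = radius_nn_sketch(queries, points, r=1.0, k=3)
```

With `r=1.0`, the first query keeps two neighbors and the second keeps only itself. Because the search starts from kNN, $k$ acts as a cap: if more than $k$ points lie within `r`, only the $k$ nearest are returned.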
