# Nearest neighbor search (NNS)
NNS finds points in a point cloud that are close to a query (an arbitrary coordinate) in 3D space. NNS is important because, in point cloud processing, each point often gathers information from its neighbors. For example:
- When computing handcrafted features, normals, etc. of a point from its neighbors, we use NNS to find those neighbors.
- In deep learning models, NNS is used to gather the features of a point's neighbors as input to a convolution module.

This section introduces the following NNS methods:
- kNN (k Nearest Neighbor)
- Radius Nearest Neighbor
- Radius and k Nearest Neighbor
- kNN with KDTree


```python
%load_ext autoreload
%autoreload 2
```


## kNN (k Nearest Neighbor)
kNN finds the $k$ points in a point cloud that are closest to a query.
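For a single query, kNN can be written directly from this definition: compute the distance from the query to every point and keep the $k$ smallest. The sketch below is a minimal brute-force version on a toy 2D cloud; the name `knn_single_query` is hypothetical and is not part of `tutlibs`.

```python
import numpy as np


def knn_single_query(query: np.ndarray, points: np.ndarray, k: int) -> np.ndarray:
    """Hypothetical helper: indices of the k points closest to `query` (brute force)."""
    # squared Euclidean distance from the query to every point, shape (M,)
    sq_dists = np.sum((points - query) ** 2, axis=1)
    # stable sort keeps the original order for tied distances
    return np.argsort(sq_dists, kind="stable")[:k]


points = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [3.0, 3.0]])
print(knn_single_query(np.array([0.1, 0.1]), points, k=2))  # [0 1]
```

Squared distances are enough for ranking because the square root is monotonic, which is why the tutorial functions also work with squared distances.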

This subsection uses the following code:

```python
# for kNN
from tutlibs.nns import k_nearest_neighbors

# for description
from tutlibs.io import Points as io
from tutlibs.utils import single_color
from tutlibs.visualization import JupyterVisualizer as jv
import inspect
```




```python
# load point cloud data.
coords, _, _ = io.read("../data/bunny_pc.ply")

# define queries (every 500th point).
queries = coords[::500]

# kNN
k = 10
idx, dist = k_nearest_neighbors(queries, coords, k)  # (N, k); dist holds squared distances
# you can access the neighbors of each query (e.g., coords[idx[query_index]]),
# but this example flattens all neighbor coordinates into one array for visualization.
nn_coords = coords[idx].reshape(-1, 3)  # (N * k, 3)

# visualize results.
obj_queries = jv.point(queries, single_color("#ff0000", len(queries)))
obj_nn_points = jv.point(nn_coords, single_color("#00ff00", len(nn_coords)))
obj_points = jv.point(coords, single_color("#0000ff", len(coords)))
jv.display([obj_points, obj_nn_points, obj_queries])
```



The above output shows $k$ neighbors (green) in a point cloud (blue). Red points are queries. This implementation can find $k$ neighbors for each query.  
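**Note**: in the above visualizer, some green points are hidden under the red points. Because the queries are taken from the point cloud itself, each query's nearest neighbor is the query point (squared distance 0), so for each query the neighbors shown are the red point plus $k - 1$ visible green points, $k$ in total.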
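The hidden green points can be checked numerically: when the queries are a subset of the point cloud, the first neighbor of each query is the query itself at distance 0. The sketch below uses a small brute-force stand-in (`brute_force_knn`, hypothetical) with the same `(N, k)` return convention as `k_nearest_neighbors`, and random points in place of the bunny.

```python
import numpy as np


def brute_force_knn(queries: np.ndarray, points: np.ndarray, k: int):
    """Hypothetical stand-in for k_nearest_neighbors: (N, k) indices and squared distances."""
    sq = np.sum((queries[:, None, :] - points[None, :, :]) ** 2, axis=-1)  # (N, M)
    idx = np.argsort(sq, axis=-1)[:, :k]  # (N, k), nearest first
    return idx, np.take_along_axis(sq, idx, axis=-1)


rng = np.random.default_rng(0)
coords = rng.random((2000, 3))
query_ids = np.arange(0, len(coords), 500)  # same points as coords[::500]
queries = coords[query_ids]

idx, dist = brute_force_knn(queries, coords, k=10)
print(np.array_equal(idx[:, 0], query_ids))  # True: each query's first neighbor is itself
print(np.allclose(dist[:, 0], 0.0))  # True
```

If the query itself should not count as a neighbor, one option is to request $k + 1$ neighbors and drop the first column (`idx[:, 1:]`); this assumes the cloud contains no duplicate points.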

Next, we look at the contents of the `k_nearest_neighbors` function.

```python
print(inspect.getsource(k_nearest_neighbors))
```


```python
def k_nearest_neighbors(
    coords1: np.ndarray, coords2: np.ndarray, k: int
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute k nearest neighbors between coords1 and coords2.

    Args:
        coords1: coordinates of centroid points (N, C)
        coords2: coordinates of all points (M, C)
        k: number of nearest neighbors

    Returns:
        idxs: indices of k nearest neighbors (N, k)
        square distances: square distance for kNN (N, k)
    """

    # compute distances between coords1 and coords2.
    point_pairwise_distances = square_distance(
        coords1, coords2
    )  # ((N, 3), (M, 3)) -> (N, M)

    # sort the distances between two points in order of closeness and get top-k indices.
    idxs = np.argsort(point_pairwise_distances, axis=-1)[:, :k]  # (N, M) -> (N, k)

    # get the distance between two points according to the top-k indices.
    square_dists = np.take_along_axis(
        point_pairwise_distances, idxs, axis=-1
    )  # ((N, M), (N, k)) -> (N, k)

    return idxs, square_dists
```


In the above implementation, `k_nearest_neighbors` returns, for each query, the indices of its $k$ nearest points in the point cloud (`coords2`) along with their squared distances. The process is as follows:

1. Compute the pairwise squared distances between the queries and the points.
2. Sort the distances of each query in ascending order and keep the indices of the $k$ smallest.
3. Gather the $k$ squared distances that correspond to those indices.
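`square_distance` is not shown in this section. A common way to compute pairwise squared distances is the expansion $\|a - b\|^2 = \|a\|^2 - 2a^\top b + \|b\|^2$, which needs only one matrix product. Also, step 2 sorts all $M$ distances per query even though only $k$ are kept; `np.argpartition` selects the $k$ smallest in linear time per row, after which only those $k$ need sorting. The sketch below combines both ideas; `pairwise_square_distance` and `knn_partition` are hypothetical names, and this is an alternative to, not a copy of, the `tutlibs` implementation.

```python
import numpy as np


def pairwise_square_distance(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Hypothetical helper: squared distances between rows of a (N, C) and b (M, C) -> (N, M)."""
    d = np.sum(a ** 2, axis=1)[:, None] - 2.0 * a @ b.T + np.sum(b ** 2, axis=1)[None, :]
    # the expansion can produce tiny negative values from rounding; clamp them to 0
    return np.maximum(d, 0.0)


def knn_partition(queries: np.ndarray, points: np.ndarray, k: int):
    """Hypothetical kNN variant: argpartition for the top-k, then sort only those k (requires k <= M)."""
    d = pairwise_square_distance(queries, points)  # (N, M)
    part = np.argpartition(d, k - 1, axis=-1)[:, :k]  # k smallest per row, unordered
    part_d = np.take_along_axis(d, part, axis=-1)  # (N, k)
    order = np.argsort(part_d, axis=-1)  # sort just the k candidates
    idx = np.take_along_axis(part, order, axis=-1)
    return idx, np.take_along_axis(part_d, order, axis=-1)


rng = np.random.default_rng(1)
pts = rng.random((500, 3))
qs = pts[::100]
idx, sq = knn_partition(qs, pts, k=8)
print(idx.shape, sq.shape)  # (5, 8) (5, 8)
```

Both variants still build the full $(N, M)$ distance matrix, so memory grows with $N \times M$; the savings are only in the selection step.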


## Radius Nearest Neighbor
Radius Nearest Neighbor finds all points in a point cloud within a radius $r$ of a query.

This subsection uses the following code:

```python
# for radius nearest neighbor
import numpy as np
from tutlibs.nns import radius_nearest_neighbors

# for description
from tutlibs.io import Points as io
from tutlibs.utils import single_color
from tutlibs.visualization import JupyterVisualizer as jv
import inspect
```


```python
# load point cloud data.
coords, _, _ = io.read("../data/bunny_pc.ply")

# define queries (every 500th point).
queries = coords[::500]

# Radius NN
r = 0.05
idx, dist = radius_nearest_neighbors(queries, coords, r)  # lists of N arrays
# you can access the neighbors of each query (e.g., coords[idx[query_index]]),
# but this example concatenates all neighbor indices for visualization.
nn_coords = coords[np.concatenate(idx)]

# visualize results.
obj_points = jv.point(coords, single_color("#0000ff", len(coords)))
obj_nn_points = jv.point(nn_coords, single_color("#00ff00", len(nn_coords)))
obj_queries = jv.point(queries, single_color("#ff0000", len(queries)))
jv.display([obj_points, obj_nn_points, obj_queries])
```



The above output shows neighbors (green) within a radius in a point cloud (blue). Red points are queries. This implementation finds all neighbors within radius $r$ of each query. Since the number of neighbors differs from query to query, the function returns lists with one array of neighbor indices (and one array of squared distances) per query.  
**Note**: in the above visualizer, the green points are hidden under the red points.
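When the query that each neighbor belongs to is also needed (for example, to aggregate neighbor features per query), one option is to flatten the ragged lists and build a matching array of query ids with `np.repeat`. The sketch below uses hand-written lists in place of the output of `radius_nearest_neighbors`.

```python
import numpy as np

# hand-written example of ragged output: one index array per query
idx = [np.array([3, 7]), np.array([5]), np.array([0, 2, 9])]

counts = np.array([len(i) for i in idx])  # neighbors per query
flat_idx = np.concatenate(idx)  # all neighbor indices, in query order
query_of = np.repeat(np.arange(len(idx)), counts)  # query id for each entry of flat_idx

print(counts.tolist())  # [2, 1, 3]
print(flat_idx.tolist())  # [3, 7, 5, 0, 2, 9]
print(query_of.tolist())  # [0, 0, 1, 2, 2, 2]
```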

Next, we look at the contents of the `radius_nearest_neighbors` function.


```python
print(inspect.getsource(radius_nearest_neighbors))
```


```python
def radius_nearest_neighbors(
    coords1: np.ndarray, coords2: np.ndarray, r: float
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """Compute radius nearest neighbors between coords1 and coords2.
    The number of neighbors is different for each centroid point, so return list data.

    Args:
        coords1: coordinates of centroid points (N, C)
        coords2: coordinates of all points (M, C)
        r: radius

    Returns:
        idxs: indices of neighbors within a radius
        square distances: square distance between pairwise points
    """

    # compute nearest neighbors.
    idxs, square_dists = k_nearest_neighbors(coords1, coords2, len(coords2))

    # get radius nearest neighbors masks.
    radius_masks = square_dists < r ** 2

    # get nearest neighbors according to masks
    radius_neighbor_indices = []
    radius_neighbor_square_dists = []
    for i, radius_mask in enumerate(radius_masks):
        radius_neighbor_indices.append(idxs[i, radius_mask])
        radius_neighbor_square_dists.append(square_dists[i, radius_mask])

    return radius_neighbor_indices, radius_neighbor_square_dists
```

In the above implementation, `radius_nearest_neighbors` returns, for each query, an array of the indices of the points within radius $r$ and an array of their squared distances. The process is as follows:
1. Compute kNN with $k = M$ (all points), which sorts every point by its distance to each query.
2. Get masks indicating whether each point is within the radius.
3. Build the per-query neighbor lists from the masks.
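Step 1 sorts all $M$ distances for every query, which costs $O(NM \log M)$ even though only the points inside the radius are kept. An alternative sketch skips the full sort: it thresholds the distance matrix directly and sorts only the neighbors that survive. `radius_nn_threshold` is a hypothetical name; nearest-first ordering and the strict `<` comparison are chosen to match the tutorial function.

```python
from typing import List, Tuple

import numpy as np


def radius_nn_threshold(
    queries: np.ndarray, points: np.ndarray, r: float
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """Hypothetical radius NN: threshold squared distances, then sort only the hits."""
    sq = np.sum((queries[:, None, :] - points[None, :, :]) ** 2, axis=-1)  # (N, M)
    idxs, dists = [], []
    for row in sq:
        hits = np.flatnonzero(row < r ** 2)  # indices strictly within the radius
        order = np.argsort(row[hits])  # nearest first
        idxs.append(hits[order])
        dists.append(row[hits][order])
    return idxs, dists


points = np.array([[0.0, 0.0, 0.0], [0.03, 0.0, 0.0], [0.0, 0.01, 0.0], [1.0, 1.0, 1.0]])
idxs, dists = radius_nn_threshold(points[:1], points, r=0.05)
print(idxs[0].tolist())  # [0, 2, 1]
```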


## Radius and k Nearest Neighbor
Radius and k Nearest Neighbor finds up to $k$ nearest neighbor points within a radius from a query. Unlike Radius Nearest Neighbor, its output has a fixed size of $k$ neighbors per query.

This subsection uses the following code:

```python
# for radius and k nearest neighbor
import numpy as np
from tutlibs.nns import radius_and_k_nearest_neighbors

# for description
from tutlibs.io import Points as io
from tutlibs.utils import single_color
from tutlibs.visualization import JupyterVisualizer as jv
import inspect
```


```python
# load point cloud data.
coords, _, _ = io.read("../data/bunny_pc.ply")

# define queries (every 500th point).
queries = coords[::500]

# radius and k NN
k = 64
r = 0.05
idx, dist = radius_and_k_nearest_neighbors(queries, coords, r, k)  # (N, k)
# drop padded entries (index len(coords)) before gathering neighbor coordinates.
nn_coords = coords[idx[idx != len(coords)]]

# visualize results.
obj_points = jv.point(coords, single_color("#0000ff", len(coords)))
obj_nn_points = jv.point(nn_coords, single_color("#00ff00", len(nn_coords)))
obj_queries = jv.point(queries, single_color("#ff0000", len(queries)))
jv.display([obj_points, obj_nn_points, obj_queries])
```



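The above output shows neighbors (green) within a radius in a point cloud (blue). Red points are queries. This implementation finds at most $k$ neighbors within radius $r$ for each query and returns fixed-size $(N, k)$ arrays. If fewer than $k$ points lie within the radius, the remaining slots are filled with the out-of-range index `len(coords)` (the number of points) and the squared distance `-1`.  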
**Note**: in the above visualizer, the green points are hidden under the red points.
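Padding with `len(coords)` keeps the output a fixed-size `(N, k)` array, but that value is not a valid index into `coords`. One way to gather neighbor coordinates (or features) for every slot without a Python loop is to append one dummy row, so the padding index points at it. The sketch below assumes the padding convention described above; the all-zero dummy row is an arbitrary choice.

```python
import numpy as np

coords = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0]])
M = len(coords)
# padded (N, k) output in the style of radius_and_k_nearest_neighbors: M marks "no neighbor"
idx = np.array([[0, 1, M], [2, M, M]])

padded = np.concatenate([coords, np.zeros((1, 3))], axis=0)  # (M + 1, 3), last row is the dummy
neighbors = padded[idx]  # (N, k, 3); padded slots read the dummy row
valid = idx != M  # (N, k) mask of real neighbors

print(neighbors.shape)  # (2, 3, 3)
print(valid.sum(axis=1).tolist())  # [2, 1]
```

The `valid` mask can then exclude the dummy entries, for example when averaging neighbor features per query.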

Next, we look at the contents of the `radius_and_k_nearest_neighbors` function.


```python
print(inspect.getsource(radius_and_k_nearest_neighbors))
```


```python
def radius_and_k_nearest_neighbors(
    coords1: np.ndarray, coords2: np.ndarray, r: float, k: int = 32
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute radius and k nearest neighbors between coords1 and coords2.

    Args:
        coords1: coordinates of centroid points (N, C)
        coords2: coordinates of all points (M, C)
        r: radius
        k: number of nearest neighbors

    Returns:
        idxs: indices of neighbors (N, k)
        square distances: square distance between pairwise points (N, k)
    """

    # compute kNN.
    idxs, square_dists = k_nearest_neighbors(coords1, coords2, k)

    # get radius nearest neighbors mask.
    radius_masks = square_dists < r ** 2

    # radius mask
    idxs[radius_masks == False] = len(coords2)
    square_dists[radius_masks == False] = -1

    return idxs, square_dists
```



In the above implementation, `radius_and_k_nearest_neighbors` returns $(N, k)$ arrays of neighbor indices and squared distances, with out-of-radius entries replaced by ignore values. The process is as follows:
1. Compute the $k$ nearest neighbors, sorted by distance.
2. Get masks indicating whether each neighbor is within the radius.
3. Assign the ignore index (`len(coords2)`) and the ignore distance (`-1`) to entries outside the radius.
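Because the $k$ neighbors are sorted before masking, the valid entries of each row come first, and their number is $\min(k, n_r)$, where $n_r$ is the number of points within the radius. The sketch below reproduces the three steps with `np.where` (which returns new arrays instead of modifying `idxs` in place) and checks this property on random points; `radius_and_knn_sketch` is a hypothetical name.

```python
import numpy as np


def radius_and_knn_sketch(queries: np.ndarray, points: np.ndarray, r: float, k: int):
    """Hypothetical re-implementation of the three steps; padding index = len(points)."""
    sq = np.sum((queries[:, None, :] - points[None, :, :]) ** 2, axis=-1)  # (N, M)
    idx = np.argsort(sq, axis=-1)[:, :k]  # 1. sorted kNN
    d = np.take_along_axis(sq, idx, axis=-1)
    inside = d < r ** 2  # 2. radius mask
    return np.where(inside, idx, len(points)), np.where(inside, d, -1.0)  # 3. ignore values


rng = np.random.default_rng(0)
points = rng.random((300, 3))
queries = points[::50]
r, k = 0.2, 16
idx, dist = radius_and_knn_sketch(queries, points, r, k)

# number of points within the radius for each query
n_r = (np.sum((queries[:, None] - points[None]) ** 2, axis=-1) < r ** 2).sum(axis=1)
valid = (idx != len(points)).sum(axis=1)
print(np.array_equal(valid, np.minimum(k, n_r)))  # True
```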


## kNN with KDTree
TODO: add implementation
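As a placeholder, here is a minimal pure-NumPy KD-tree sketch; it is not the tutorial's planned implementation, and `build_kdtree`, `kdtree_knn`, the dict-based node layout, and `leaf_size=8` are all assumptions. The tree splits the points at the median along x, y, z in turn. A query first descends into the side of each split that contains it, then visits the other side only if the splitting plane is closer than the current $k$-th nearest neighbor, which lets it skip most points. In practice, a library implementation such as SciPy's `scipy.spatial.cKDTree` (its `query` method) would typically be used instead.

```python
import heapq

import numpy as np


def build_kdtree(points: np.ndarray, ids=None, depth: int = 0, leaf_size: int = 8):
    """Hypothetical KD-tree: nested dicts, median split along each axis in turn."""
    if ids is None:
        ids = np.arange(len(points))
    if len(ids) <= leaf_size:
        return {"leaf": ids}
    axis = depth % points.shape[1]
    ids = ids[np.argsort(points[ids, axis])]  # sort this node's points along the split axis
    mid = len(ids) // 2
    return {
        "axis": axis,
        "split": points[ids[mid], axis],  # left side <= split <= right side
        "left": build_kdtree(points, ids[:mid], depth + 1, leaf_size),
        "right": build_kdtree(points, ids[mid:], depth + 1, leaf_size),
    }


def kdtree_knn(tree, points: np.ndarray, query: np.ndarray, k: int):
    """Return (indices, squared distances) of the k nearest points, nearest first."""
    heap = []  # max-heap of the current best k via negated distances: (-sq_dist, index)

    def search(node):
        if "leaf" in node:
            for i in node["leaf"]:
                d = float(np.sum((points[i] - query) ** 2))
                if len(heap) < k:
                    heapq.heappush(heap, (-d, int(i)))
                elif d < -heap[0][0]:  # closer than the current k-th neighbor
                    heapq.heapreplace(heap, (-d, int(i)))
            return
        diff = query[node["axis"]] - node["split"]
        near, far = (node["left"], node["right"]) if diff < 0 else (node["right"], node["left"])
        search(near)
        # every point on the far side is at least diff**2 away from the query
        if len(heap) < k or diff ** 2 < -heap[0][0]:
            search(far)

    search(tree)
    best = sorted((-nd, i) for nd, i in heap)  # ascending squared distance
    return np.array([i for _, i in best]), np.array([d for d, _ in best])


rng = np.random.default_rng(0)
points = rng.random((1000, 3))
tree = build_kdtree(points)
idx, sq = kdtree_knn(tree, points, points[0], k=5)
print(idx[0])  # 0: the query point itself is its own nearest neighbor
```

This sketch answers one query at a time; building the tree once and reusing it for many queries is where a KD-tree pays off compared with the brute-force $(N, M)$ distance matrix.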