# Query the index of a location within the 3D grid

### This functionality is used widely throughout the program
**Old version**

- Input: a location given by its x, y, z coordinates
- Output: the index of the closest grid point in the 3D grid

**Method**:

1. Calculate the distance between the given location and every location in the grid.
2. Find the grid location with the smallest distance to the given location.
3. Return the index of that grid location.
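
A minimal NumPy sketch of these three steps for a single location. The function name `nearest_index_brute_force` and the tiny 2 x 2 x 2 grid are hypothetical, for illustration only; the sketch compares squared distances, which select the same grid point because the square root preserves ordering. Every call still scans all N grid points.

```python
import numpy as np


def nearest_index_brute_force(grid: np.ndarray, loc: np.ndarray) -> int:
    """Old method for one location: scan every grid point.

    grid: (N, 3) array of grid locations; loc: (3,) query location.
    """
    # Step 1: squared distance from loc to every grid point (sqrt skipped, ordering unchanged)
    sq_dist = np.sum((grid - loc) ** 2, axis=1)
    # Steps 2-3: the position of the minimum is the index of the closest grid point
    return int(np.argmin(sq_dist))


# Usage on a tiny 2 x 2 x 2 grid whose index is 4*ix + 2*iy + iz
small_grid = np.array(
    [[x, y, z] for x in (0.0, 1.0) for y in (0.0, 1.0) for z in (0.0, 1.0)]
)
print(nearest_index_brute_force(small_grid, np.array([0.9, 0.1, 0.8])))  # → 5
```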

**New version**

- Input: a location given by its x, y, z coordinates
- Output: the index of the closest grid point in the 3D grid

**Method**:

1. Build a KDTree from the 3D grid points (once, then reuse it for every query).
2. Query the KDTree for the nearest neighbour of the given location; the returned index is the index of that grid point.
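
A minimal sketch of the two steps with SciPy's `scipy.spatial.KDTree`, the class the notebook below uses. The tiny grid and the names `small_grid` and `small_tree` are hypothetical. A single (3,) location returns a scalar distance and index; an (M, 3) array returns arrays of length M.

```python
import numpy as np
from scipy.spatial import KDTree

# Tiny 2 x 2 x 2 grid; with this construction the index is 4*ix + 2*iy + iz
small_grid = np.array(
    [[x, y, z] for x in (0.0, 1.0) for y in (0.0, 1.0) for z in (0.0, 1.0)]
)

# Step 1: build the tree once and reuse it for every query
small_tree = KDTree(small_grid)

# Step 2: a single location returns a scalar distance and a scalar index
dist, idx = small_tree.query([0.9, 0.1, 0.8], k=1)
print(idx, small_grid[idx])  # → 5 [1. 0. 1.]

# Several locations at once return one distance and one index per row
dists, idxs = small_tree.query([[0.9, 0.1, 0.8], [0.1, 0.2, 0.3]], k=1)
print(idxs)  # → [5 0]
```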

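**Reasons for the change**

- The old version is slow because every query computes the distance to every grid point: for M query locations and N grid points that is M × N distance evaluations (here 1,000 × 1,000,000 = 10⁹), and `cdist` also holds the full M × N distance matrix in memory.
- The new version is faster because the KDTree is built once (O(N log N)) and then reused; each query only visits a small part of the tree, typically close to O(log N) for low-dimensional data such as 3D points.
- In the benchmark below, with the trees already built, querying 1,000 points takes about 1.8 ms with SciPy's `KDTree` and 0.5 ms with pykdtree, versus about 4.1 s for the brute-force `cdist` search.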




```python
import time

import numpy as np
from scipy.spatial import KDTree
from scipy.spatial.distance import cdist
from pykdtree.kdtree import KDTree as KDTree_pykdtree
```

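```python
# Build a regular n x n x n grid on the unit cube (n**3 = 1,000,000 points)
n = 100
xv = np.linspace(0, 1, n)
yv = np.linspace(0, 1, n)
zv = np.linspace(0, 1, n)
xx, yy, zz = np.meshgrid(xv, yv, zv)
# Stack the flattened coordinates into an (n**3, 3) array of grid locations
grid = np.column_stack((xx.ravel(), yy.ravel(), zz.ravel()))

# 1,000 random query locations inside the unit cube (seeded for reproducibility)
rng = np.random.default_rng(0)
query_points = rng.random((1000, 3))

# Build both trees once; the build cost is not included in the query timings below
tree = KDTree(grid)
tree_pykdtree = KDTree_pykdtree(grid)


def get_ind_from_locations(loc: np.ndarray) -> np.ndarray:
    """
    Brute-force (old) method: return the index of the closest grid point
    for each row of ``loc`` (shape (M, 3)).
    """
    # Full (M, N) distance matrix: 1,000 x 1,000,000 float64 values here (~8 GB)
    dist = cdist(loc, grid)
    # Column of the smallest distance in each row = index of the nearest grid point
    ind = np.argmin(dist, axis=1)
    return ind
```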
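
Because this test grid is uniform, the nearest index can also be computed directly by rounding each coordinate, with no search at all. This is a sketch valid only under that assumption (a uniform n x n x n grid on [0, 1]³ built with `np.meshgrid`'s default `indexing="xy"` and flattened in C order); the helper `nearest_index_uniform` is hypothetical, and for irregular grids the KDTree remains the general solution. With `indexing="xy"` the meshgrid arrays have shape (len(yv), len(xv), len(zv)), so the flat index is `iy * n**2 + ix * n + iz`.

```python
import numpy as np


def nearest_index_uniform(loc: np.ndarray, n: int) -> np.ndarray:
    """Hypothetical O(1)-per-point lookup for the uniform grid described above.

    loc: (M, 3) query locations; returns (M,) flat grid indices.
    """
    # Nearest grid coordinate along each axis; clipping maps points outside the cube to the boundary
    ijk = np.clip(np.rint(loc * (n - 1)), 0, n - 1).astype(np.int64)
    ix, iy, iz = ijk[:, 0], ijk[:, 1], ijk[:, 2]
    # meshgrid(xv, yv, zv) with indexing="xy" orders the flattened grid as (y, x, z)
    return iy * n * n + ix * n + iz


# Usage: compare against a brute-force search on a small grid built the same way
n_small = 5
v = np.linspace(0, 1, n_small)
sx, sy, sz = np.meshgrid(v, v, v)
small_grid = np.column_stack((sx.ravel(), sy.ravel(), sz.ravel()))
pts = np.random.default_rng(1).random((200, 3))
brute = np.argmin(((pts[:, None, :] - small_grid[None, :, :]) ** 2).sum(axis=2), axis=1)
print(np.array_equal(nearest_index_uniform(pts, n_small), brute))
```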


```python
# Method 1: KDTree from SciPy
%timeit tree.query(query_points, k=1)
dist1, ind1 = tree.query(query_points, k=1)
```

```
1.79 ms ± 10.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```


```python
# Method 2: KDTree from pykdtree
%timeit tree_pykdtree.query(query_points, k=1)
dist2, ind2 = tree_pykdtree.query(query_points, k=1)
```

```
503 µs ± 30 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```


```python
# Method 3: brute force with cdist (old version), timed once (seconds)
t1 = time.perf_counter()
ind3 = get_ind_from_locations(query_points)
t2 = time.perf_counter()
print(f"Time: {t2 - t1}")
```

```
Time: 4.143543004989624
```
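
If the brute-force search has to be kept (for example as a reference), processing the queries in chunks bounds peak memory at `chunk_size` × N distances instead of M × N, although the total work stays M × N. A minimal sketch; the name `get_ind_from_locations_chunked` and its `grid`/`chunk_size` parameters are hypothetical, and it uses `cdist`'s `"sqeuclidean"` metric, which skips the square root without changing the argmin.

```python
import numpy as np
from scipy.spatial.distance import cdist


def get_ind_from_locations_chunked(loc, grid, chunk_size=100):
    """Hypothetical memory-bounded variant of the brute-force search.

    At most a (chunk_size, N) distance matrix exists at any time.
    """
    ind = np.empty(len(loc), dtype=np.intp)
    for start in range(0, len(loc), chunk_size):
        stop = start + chunk_size
        # Squared Euclidean distances for this block of query rows only
        dist = cdist(loc[start:stop], grid, metric="sqeuclidean")
        ind[start:stop] = np.argmin(dist, axis=1)
    return ind


# Usage: matches the full-matrix version on small random data
demo_rng = np.random.default_rng(0)
demo_grid = demo_rng.random((500, 3))
demo_queries = demo_rng.random((250, 3))
chunked = get_ind_from_locations_chunked(demo_queries, demo_grid, chunk_size=64)
print(np.array_equal(chunked, np.argmin(cdist(demo_queries, demo_grid), axis=1)))
```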


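pykdtree's `query()` has no `num_threads` keyword, so the original call `tree_pykdtree.query(query_points, k=1, num_threads=4)` raised `TypeError: query() got an unexpected keyword argument 'num_threads'`. pykdtree parallelises with OpenMP and reads its thread count from the `OMP_NUM_THREADS` environment variable, which generally has to be set before pykdtree is imported; depending on how it was built, Method 2 may therefore already use several threads. SciPy's `KDTree.query` exposes the thread count directly through its `workers` argument:

```python
# Method 4: multi-threaded KDTree query (SciPy, 4 worker threads)
%timeit tree.query(query_points, k=1, workers=4)
dist4, ind4 = tree.query(query_points, k=1, workers=4)
```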
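
To check whether extra threads pay off for a given workload, a small sketch can time SciPy's `KDTree.query` across several `workers` settings with the standard-library `timeit` module. The helper `time_workers` and the demo sizes are hypothetical; `workers=-1` means "use all available CPU threads", and the `workers` argument requires SciPy 1.6 or later.

```python
import timeit

import numpy as np
from scipy.spatial import KDTree


def time_workers(tree, points, workers_options=(1, 2, 4, -1), number=10):
    """Average seconds per tree.query call for each `workers` setting."""
    results = {}
    for w in workers_options:
        # Bind w as a default argument so each lambda uses the current setting
        total = timeit.timeit(lambda w=w: tree.query(points, k=1, workers=w), number=number)
        results[w] = total / number
    return results


# Usage on random demo data (results depend on the machine, so none are shown)
demo_rng = np.random.default_rng(0)
demo_tree = KDTree(demo_rng.random((20_000, 3)))
demo_points = demo_rng.random((5_000, 3))
for w, sec in time_workers(demo_tree, demo_points).items():
    print(f"workers={w}: {sec * 1e3:.2f} ms")
```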

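```python
# Check that all methods return the same nearest-grid-point indices
print(np.all(ind1 == ind2))
print(np.all(ind1 == ind3))
print(np.all(ind2 == ind3))
print(np.all(ind1 == ind4))
```

The recorded run used bare expressions in one cell, so only the value of the last one, `np.all(ind1 == ind4)`, was displayed:

```
True
```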
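
Exact index equality can fail when a query location is equidistant from two grid points and the methods break the tie differently, even though both answers are correct. A minimal sketch of a more tolerant check that compares the distances of the chosen grid points instead; the helper `same_nearest_distance` and its `atol` default are hypothetical.

```python
import numpy as np


def same_nearest_distance(query, grid, ind_a, ind_b, atol=1e-12):
    """True if two index arrays pick grid points equally far from each query row."""
    d_a = np.linalg.norm(grid[ind_a] - query, axis=1)
    d_b = np.linalg.norm(grid[ind_b] - query, axis=1)
    return bool(np.allclose(d_a, d_b, rtol=0.0, atol=atol))


# Usage: a query exactly halfway between grid points 0 and 1
demo_grid = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
demo_query = np.array([[0.5, 0.0, 0.0]])
print(same_nearest_distance(demo_query, demo_grid, np.array([0]), np.array([1])))  # → True
```

In the notebook above, a call such as `same_nearest_distance(query_points, grid, ind1, ind3)` would apply the same check to the real results.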