In [64]:
import numpy as np
from scipy.spatial.distance import pdist, cdist, squareform
from scipy.spatial import distance_matrix
from reconstrain.envs.motion_planning import topk, index_to_coo
import torch
import torch.nn.functional as F

# Seeded legacy RandomState so the random positions fed to the benchmarks are reproducible.
rng = np.random.RandomState(seed=0)

In [83]:
def knn_pdist(positions, K):
    """kNN edges built from SciPy's condensed pairwise distances.

    ``pdist`` yields the condensed (upper-triangle) distance vector, Euclidean
    by default; ``squareform`` expands it into the full symmetric square
    matrix that ``topk`` selects neighbours from.

    Args:
        positions: 2-D array, one row per point (as ``pdist`` requires).
        K: neighbour count, forwarded unchanged to ``topk``.

    Returns:
        The result of ``index_to_coo`` on the selected indices
        (presumably a COO edge list -- confirm against reconstrain).
    """
    square_dist = squareform(pdist(positions))
    neighbour_idx = topk(square_dist, K)
    return index_to_coo(neighbour_idx)

def knn_cdist(positions, K):
    """kNN edges built from SciPy's ``cdist`` of the point set against itself.

    Args:
        positions: 2-D array, one row per point.
        K: neighbour count, forwarded unchanged to ``topk``.

    Returns:
        The result of ``index_to_coo`` on the selected indices
        (presumably a COO edge list -- confirm against reconstrain).
    """
    # cdist(x, x) gives the full square distance matrix directly (Euclidean by default).
    return index_to_coo(topk(cdist(positions, positions), K))

def knn_distance(positions, K):
    """kNN edges built from ``scipy.spatial.distance_matrix`` (p=2 by default).

    Args:
        positions: 2-D array, one row per point.
        K: neighbour count, forwarded unchanged to ``topk``.

    Returns:
        The result of ``index_to_coo`` on the selected indices
        (presumably a COO edge list -- confirm against reconstrain).
    """
    pairwise = distance_matrix(positions, positions)
    return index_to_coo(topk(pairwise, K))

def knn_torch(positions, K, device=None):
    """Dense, symmetric kNN adjacency matrix computed with PyTorch.

    Unlike the SciPy variants in this notebook, this returns a dense
    ``(n, n)`` tensor rather than the output of ``index_to_coo``.

    Args:
        positions: array-like of shape (n_points, n_dims).
        K: number of neighbours per point, excluding the point itself.
            ``K + 1`` must not exceed ``n_points`` (``torch.topk`` raises otherwise).
        device: torch device to compute on. Defaults to ``"cuda"`` when a GPU
            is available and ``"cpu"`` otherwise. The original version always
            called ``.cuda()`` and crashed on CPU-only machines.

    Returns:
        Tensor ``adj`` of shape (n_points, n_points) with the dtype of
        ``positions``, on ``device``. ``adj[i, j] == 1`` when j is among i's
        K + 1 closest points or i is among j's, so the matrix is symmetric.
        Because a point is normally its own closest point (distance 0), the
        diagonal is normally 1 (self-loops). Duplicate points can break that tie.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    with torch.no_grad():
        points = torch.as_tensor(positions, device=device)
        n_agents = points.shape[0]
        dist = torch.cdist(points, points)
        # K + 1 so that the point itself (distance ~0) plus K true neighbours are kept.
        # Named nbr_idx rather than `topk` to avoid shadowing the imported topk helper.
        nbr_idx = torch.topk(dist, K + 1, largest=False, dim=1).indices
        # Build the row index on the same device as `adj` to avoid mixed-device indexing.
        rows = torch.arange(n_agents, device=device)[:, None]
        adj = torch.zeros_like(dist)
        adj[rows, nbr_idx] = 1
        # Mirror the edges so the graph is undirected.
        adj[nbr_idx, rows] = 1
    return adj

n_agents = 100
%timeit -n 100 knn_pdist(rng.uniform(size=(n_agents, 2)), 3)
%timeit -n 100 knn_cdist(rng.uniform(size=(n_agents, 5)), 3)
%timeit -n 100 knn_distance(rng.uniform(size=(n_agents, 5)), 3)
%timeit -n 100 knn_torch(rng.uniform(size=(n_agents, 5)), 3)

227 µs ± 63.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
202 µs ± 1.55 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
469 µs ± 2.56 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
275 µs ± 7.38 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
