In [1]:
# https://github.com/davezdeng8/tlfpad/blob/master/util_ec.py

import torch
import numpy as np
import time

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if device == "cuda":
  start = torch.cuda.Event(enable_timing=True)
  end = torch.cuda.Event(enable_timing=True)
else:
  start = None
  end = None

In [2]:
def square_distance(src, dst):
    """
    Calculate Euclid distance between each two points.
    src^T * dst = xn * xm + yn * ym + zn * zm；
    sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
    sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
    dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
         = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    return dist

def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, C]
        new_xyz: query points, [B, S, C]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
    sqrdists = square_distance(new_xyz, xyz)
    group_idx[sqrdists > radius ** 2] = N
    mask = group_idx != N
    cnt = mask.sum(dim=-1)
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    mask = group_idx == N
    group_idx[mask] = group_first[mask]
    return group_idx, cnt

In [3]:
def knn_point(k, pos1, pos2):
    '''
    Input:
        k: int32, number of k in k-nn search
        pos1: (batch_size, ndataset, c) float32 array, input points
        pos2: (batch_size, npoint, c) float32 array, query points
    Output:
        val: (batch_size, npoint, k) float32 array, L2 distances
        idx: (batch_size, npoint, k) int32 array, indices to input points
    '''
    B, N, C = pos1.shape
    M = pos2.shape[1]
    pos1 = pos1.view(B,1,N,-1).repeat(1,M,1,1)
    pos2 = pos2.view(B,M,1,-1).repeat(1,1,N,1)
    dist = torch.sum(-(pos1-pos2)**2,-1)
    val,idx = dist.topk(k=k,dim = -1)
    return torch.sqrt(-val), idx

In [4]:
def measure_ball(device, start, end, radius, max_neighbor, pc_points, query_points):
  with torch.no_grad():
    if device == "cuda":
      start.record()
    else:
      start = time.time()
    neighbors, counts = query_ball_point(radius, max_neighbor, pc_points, query_points)
    if device == "cuda":
      end.record()
      torch.cuda.synchronize()
      return start.elapsed_time(end)
    else:
      return 1000 * (time.time() - start)

def measure_knn(device, start, end, max_neighbor, pc_points, query_points):
  with torch.no_grad():
    if device == "cuda":
      start.record()
    else:
      start = time.time()
    dists, indices = knn_point(max_neighbor, pc_points, query_points)
    if device == "cuda":
      end.record()
      torch.cuda.synchronize()
      return start.elapsed_time(end)
    else:
      return 1000 * (time.time() - start)

In [13]:
max_neighbor = 10
radius = 0.5
num_pc = 100000
num_query = 1000
variance = 10

In [14]:
avg_ball = 0
avg_knn = 0
num_reps = 1

with torch.no_grad():
    for i in range(num_reps):
      pc_points = variance * torch.randn(1, num_pc, 3).to(device)
      query_points = variance * torch.randn(1, num_query, 3).to(device)
      avg_knn += measure_knn(device, start, end, max_neighbor, pc_points, query_points)/num_reps

#     for i in range(num_reps):
#       pc_points = variance * torch.randn(1, num_pc, 3).to(device)
#       query_points = variance * torch.randn(1, num_query, 3).to(device)
#       avg_ball += measure_ball(device, start, end, radius, max_neighbor, pc_points, query_points)/num_reps

print(device)
print("Ball:", avg_ball)
print("KNN:", avg_knn)

cuda
Ball: 0
KNN: 0.1761913299560547
