In [None]:
import torch
from torch_fps import farthest_point_sampling, farthest_point_sampling_with_knn

# Create example inputs
torch.manual_seed(0)  # fixed seed so the random point clouds (and sampled indices) are reproducible on re-run
B, N, D = 4, 1000, 5  # batch size, points per cloud, per-point feature dimension
points = torch.randn(B, N, D)              # [B, N, D] - batch of point clouds
mask = torch.ones(B, N, dtype=torch.bool)  # [B, N] - valid point mask (every point valid here)
K = 512  # Number of samples per batch (must be <= number of valid points)
assert K <= int(mask.sum(dim=1).min()), "K exceeds the number of valid points in some batch element"

# Perform farthest point sampling
idx = farthest_point_sampling(points, mask, K)  # [B, K] - selected point indices
assert idx.shape == (B, K)

# Use indices to gather sampled points. Expand the index over the feature axis
# using the actual D (not a hard-coded width) so this keeps working if D changes.
sampled_points = points.gather(1, idx.unsqueeze(-1).expand(-1, -1, D))  # [B, K, D]
assert sampled_points.shape == (B, K, D)

# Fused FPS + kNN: get centroids and their k nearest neighbors in one pass.
# Reuse K from above so both calls always sample the same number of centroids.
k_neighbors = 32
centroid_idx, neighbor_idx = farthest_point_sampling_with_knn(
    points, mask, K=K, k_neighbors=k_neighbors
)  # centroid_idx: [B, K], neighbor_idx: [B, K, k_neighbors]
assert centroid_idx.shape == (B, K)
assert neighbor_idx.shape == (B, K, k_neighbors)

In [None]:
# CUDA smoke test: run FPS across a range of feature dimensions, from tiny D to
# very wide D, and check each call completes and returns well-formed indices.
# Skipped (rather than crashing) on machines without a GPU so Restart & Run All still works.
if not torch.cuda.is_available():
    print("CUDA not available; skipping GPU smoke test")
else:
    B, N = 2, 257     # 257 = 256 + 1; presumably chosen to exercise a non-power-of-two N — confirm
    num_samples = 64  # FPS samples per batch element (<= N); distinct name so cell 1's K is not clobbered
    for D in [3, 16, 32, 64, 128, 256, 768, 1024, 2048, 4096]:
        # Freshly created tensors are already contiguous; no .contiguous() needed.
        pts = torch.randn(B, N, D, device="cuda", dtype=torch.float32)
        msk = torch.ones(B, N, device="cuda", dtype=torch.uint8)  # uint8 mask (cell 1 uses bool)
        out = farthest_point_sampling(pts, msk, num_samples)
        # CUDA launches are asynchronous; synchronize so a kernel failure surfaces for this D.
        torch.cuda.synchronize()
        assert out.shape == (B, num_samples), f"D={D}: unexpected shape {tuple(out.shape)}"
        assert ((out >= 0) & (out < N)).all(), f"D={D}: index out of range [0, {N})"
        print(D, out.shape, out.dtype, out.device)

3 torch.Size([2, 128]) torch.int64 cuda:0
8 torch.Size([2, 128]) torch.int64 cuda:0
