In [3]:
import torch

def k_nearest_neighbors(query_pos, all_pos, constraint_id=None, k_query=10):
    """
    Find the k nearest neighbors of each query point among a subset of atoms.

    - query_pos: (n_sample, 3) tensor of query positions
    - all_pos: (n_all, 3) tensor of all atom positions
    - constraint_id: (n_constrained) tensor of row indices into all_pos that
      are eligible as neighbors; None means every atom is eligible
    - k_query: number of neighbors to find. If fewer than k_query atoms are
      eligible, all eligible atoms are returned instead of raising.
    Returns (with k = min(k_query, n_constrained)):
    - indices of k-nearest neighbors: (n_sample, k) tensor of indices into
      all_pos, ordered nearest first
    - distances to k-nearest neighbors: (n_sample, k) tensor of Euclidean
      distances, in the same order

    Note: no self-exclusion is performed. If a query point coincides with an
    eligible atom, that atom is returned as its nearest neighbor at distance 0.
    """
    # Select the atoms of interest using the constraint_id
    if constraint_id is None:
        constraint_id = torch.arange(all_pos.size(0), device=all_pos.device)
    constrained_pos = all_pos[constraint_id]

    # Calculate squared distances using broadcasting
    # (n_sample, 1, 3) - (1, n_constrained, 3) -> (n_sample, n_constrained, 3)
    diff = query_pos.unsqueeze(1) - constrained_pos.unsqueeze(0)
    dist_squared = (diff ** 2).sum(dim=2)  # Sum over the coordinate dimension

    # torch.topk raises if k exceeds the size of the searched dimension, so
    # never ask for more neighbors than there are eligible atoms.
    k = min(k_query, constrained_pos.size(0))

    # k smallest squared distances per query, sorted ascending. Ranking on the
    # squared values gives the same order as ranking on the distances.
    distances, indices = torch.topk(dist_squared, k, largest=False, sorted=True)

    # topk indexes into the constrained subset; map back to rows of all_pos.
    actual_indices = constraint_id[indices]
    return actual_indices, torch.sqrt(distances)

def neighbors_within_distance(query_pos, all_pos, constraint_id=None, distance_threshold=5.0):
    """
    Find the neighbors within a distance threshold for each query point in a subset of atoms.

    - query_pos: (n_sample, 3) tensor of query positions
    - all_pos: (n_all, 3) tensor of all atom positions
    - constraint_id: (n_constrained) tensor of row indices into all_pos that
      are eligible as neighbors; None means every atom is eligible
    - distance_threshold: maximum Euclidean distance (inclusive) to consider a
      neighbor. A negative threshold matches nothing.
    Returns:
    - indices of neighbors within distance: list of n_sample tensors, each of
      shape (n_neighbors), holding indices into all_pos in constraint_id order
    - distances to neighbors within distance: list of n_sample tensors, each of
      shape (n_neighbors), aligned with the indices (not sorted by distance)
    """
    # Select the atoms of interest using the constraint_id
    if constraint_id is None:
        constraint_id = torch.arange(all_pos.size(0), device=all_pos.device)

    constrained_pos = all_pos[constraint_id]

    # Calculate squared distances using broadcasting
    # (n_sample, 1, 3) - (1, n_constrained, 3) -> (n_sample, n_constrained, 3)
    diff = query_pos.unsqueeze(1) - constrained_pos.unsqueeze(0)
    dist_squared = (diff ** 2).sum(dim=2)  # Sum over the coordinate dimension

    # Compare in squared space so sqrt is only taken for the surviving entries.
    # Squaring discards the sign of the threshold, so a negative radius has to
    # be rejected explicitly; otherwise -r would behave exactly like +r.
    threshold_squared = distance_threshold ** 2
    within_threshold = (dist_squared <= threshold_squared) & (distance_threshold >= 0)

    # Results are ragged (each query can have a different neighbor count), so
    # they are returned as per-query lists rather than a single tensor.
    indices = [constraint_id[mask] for mask in within_threshold]
    distances = [
        torch.sqrt(row[mask]) for row, mask in zip(dist_squared, within_threshold)
    ]

    return indices, distances


In [5]:
# Demo: run both neighbor searches on the same random point cloud.
n_sample = 5
all_atoms = 20
distance_threshold = 5.0  # cutoff radius for the within-distance search

# Coordinates are drawn from a standard normal; queries are drawn before atoms.
query_pos = torch.randn(n_sample, 3)
all_pos = torch.randn(all_atoms, 3)

# Only these rows of all_pos are eligible as neighbors.
constraint_id = torch.tensor([13, 14, 15, 16], dtype=torch.long)

# 1) Every eligible atom inside the cutoff radius of each query point.
neighbors_idx, neighbors_dist = neighbors_within_distance(
    query_pos,
    all_pos,
    constraint_id=constraint_id,
    distance_threshold=distance_threshold,
)

print("Indices of neighbors within distance:\n", neighbors_idx)
print("Distances to neighbors within distance:\n", neighbors_dist)


# 2) The k closest eligible atoms to each query point.
k_query = 3
k_neighbors_idx, k_neighbors_dist = k_nearest_neighbors(
    query_pos,
    all_pos,
    constraint_id=constraint_id,
    k_query=k_query,
)

print("Indices of k-nearest neighbors:\n", k_neighbors_idx)
print("Distances to k-nearest neighbors:\n", k_neighbors_dist)

Indices of neighbors within distance:
 [tensor([13, 14, 15, 16]), tensor([13, 14, 15, 16]), tensor([13, 14, 15, 16]), tensor([13, 14, 15, 16]), tensor([13, 14, 15, 16])]
Distances to neighbors within distance:
 [tensor([2.5375, 2.8132, 3.5797, 2.2820]), tensor([2.1110, 1.5143, 2.6021, 1.5827]), tensor([3.5923, 2.9311, 1.0664, 2.2926]), tensor([2.2816, 1.2792, 2.4509, 1.6545]), tensor([1.5104, 2.4684, 1.8955, 0.4210])]
Indices of k-nearest neighbors:
 tensor([[16, 13, 14],
        [14, 16, 13],
        [15, 16, 14],
        [14, 16, 13],
        [16, 13, 15]])
Distances to k-nearest neighbors:
 tensor([[2.2820, 2.5375, 2.8132],
        [1.5143, 1.5827, 2.1110],
        [1.0664, 2.2926, 2.9311],
        [1.2792, 1.6545, 2.2816],
        [0.4210, 1.5104, 1.8955]])
