# Example 3: PyTorch Neighbor List

This notebook demonstrates the use of `TorchNeighborList` for efficient neighbor finding in atomic structures.

**Features:**
- Isolated systems (molecules)
- Periodic boundary conditions (crystals)
- GPU acceleration (if available)
- Type-dependent cutoffs

```python
import torch
import numpy as np
import aenet.io.structure
from aenet.torch_featurize import TorchNeighborList
```

## 1. Isolated System: Water Molecule

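The neighbor list is integrated with `AtomicStructure` and can be used via its `get_neighbors()` convenience method.

Find the neighbors of atom 0 in a water molecule (a non-periodic system). Atom 0 is the oxygen, so the two hydrogen atoms lie within the 2.0 Å cutoff, and the result is printed as a structure with composition H2.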

```python
struc = aenet.io.structure.read('water.xyz')

# Neighbors of atom 0 within 2.0 Angstrom, excluding atom 0 itself
neighbors = struc.get_neighbors(i=0, cutoff=2.0, return_self=False)

print(neighbors)
```

```
 Composition        : H2
 Number of atoms    : 2
 Number of species  : 1
 Geometric center   : 0.00000000 0.00000000 -0.47116000 (Ang)
 Diameter           : 1.511 (Ang)

 Cartesian coordinates

 H         0.00000000       0.75545000      -0.47116000
 H         0.00000000      -0.75545000      -0.47116000
```

Alternatively, `TorchNeighborList` can be used directly on atomic coordinates and, for periodic systems, a unit cell. In this case, the results are returned as PyTorch tensors.

```python
from aenet.torch_featurize.neighborlist import TorchNeighborList
import numpy as np

# Create neighbor list
nbl = TorchNeighborList(cutoff=4.0, device='cpu')

# Find neighbors (accepts numpy arrays)
positions = np.array([[0.0, 0.0, 0.0],
                      [1.5, 0.0, 0.0],
                      [3.0, 0.0, 0.0]])

# Get neighbors of atom 0
result = nbl.get_neighbors_of_atom(0, positions)

neighbor_indices = result['indices']    # Which atoms are neighbors
distances = result['distances']         # Distances to neighbors
offsets = result['offsets']             # Cell offsets (None for isolated)
print(neighbor_indices)
```

```
tensor([1, 2])
```

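To make the semantics of `get_neighbors_of_atom` concrete for an isolated system, here is a minimal brute-force sketch in NumPy: compute the distance from one atom to every atom and keep those inside the cutoff, excluding the atom itself. The helper `neighbors_of_atom_isolated` is hypothetical and is not how `TorchNeighborList` is implemented; treating the cutoff as inclusive is also an assumption.

```python
import numpy as np


def neighbors_of_atom_isolated(i, positions, cutoff):
    """Brute-force neighbor search for one atom in a non-periodic system.

    Hypothetical helper for illustration only; not part of aenet.
    The cutoff is treated as inclusive (an assumption).
    """
    positions = np.asarray(positions, dtype=float)
    # Distance from atom i to every atom, including itself
    distances = np.linalg.norm(positions - positions[i], axis=1)
    # Keep atoms inside the cutoff sphere, excluding atom i itself
    mask = (distances <= cutoff) & (np.arange(len(positions)) != i)
    indices = np.nonzero(mask)[0]
    return indices, distances[indices]


positions = np.array([[0.0, 0.0, 0.0],
                      [1.5, 0.0, 0.0],
                      [3.0, 0.0, 0.0]])
indices, distances = neighbors_of_atom_isolated(0, positions, cutoff=4.0)
print(indices)    # [1 2]
print(distances)  # [1.5 3. ]
```

For the same three-atom chain and a 4.0 Å cutoff, the sketch selects atoms 1 and 2, matching the tensor printed above.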

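This also works for periodic structures. The example below uses the primitive cell of an FCC lattice with lattice constant a = 4.05 Å and a single atom. With a 4.0 Å cutoff, only the 12 nearest neighbors (at a/√2 ≈ 2.86 Å) are found, since the next shell lies at a = 4.05 Å.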

```python
import numpy as np

# FCC primitive cell, lattice constant a = 4.05 Angstrom
a = 4.05
cell = np.array([
    [0.0, 0.5*a, 0.5*a],
    [0.5*a, 0.0, 0.5*a],
    [0.5*a, 0.5*a, 0.0]
])

# A single atom at the origin
positions = np.array([[0.0, 0.0, 0.0]])

nbl = TorchNeighborList(cutoff=4.0)
result = nbl.get_neighbors_of_atom(0, positions, cell=cell)

# Offsets show which periodic images each neighbor belongs to
print(result['offsets'])
```

```
tensor([[-1,  0,  0],
        [-1,  0,  1],
        [-1,  1,  0],
        [ 0, -1,  0],
        [ 0, -1,  1],
        [ 0,  0, -1],
        [ 0,  0,  1],
        [ 0,  1, -1],
        [ 0,  1,  0],
        [ 1, -1,  0],
        [ 1,  0, -1],
        [ 1,  0,  0]])
```

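For periodic systems, neighbors are searched among periodic images of the cell, and each image is labelled by an integer offset vector (n1, n2, n3). The brute-force sketch below enumerates enough images in each lattice direction, using the spacing between lattice planes (1/|b_k|, where b_k are the reciprocal vectors without the 2π factor), and keeps the images within the cutoff. The helper name and its conventions (Cartesian positions wrapped into the cell, lattice vectors as the rows of `cell`) are assumptions for illustration, not the aenet API.

```python
import itertools

import numpy as np


def neighbors_of_atom_periodic(i, positions, cell, cutoff):
    """Brute-force periodic neighbor search for one atom (illustrative sketch).

    Hypothetical helper, not part of aenet. Assumes Cartesian `positions`
    wrapped into the cell and lattice vectors stored as the rows of `cell`.
    """
    positions = np.asarray(positions, dtype=float)
    cell = np.asarray(cell, dtype=float)
    # Rows of inv(cell).T are reciprocal vectors b_k (without 2*pi);
    # 1 / |b_k| is the spacing between lattice planes along direction k.
    spacing = 1.0 / np.linalg.norm(np.linalg.inv(cell).T, axis=1)
    # Enough image shells that no neighbor within the cutoff is missed
    n_max = np.ceil(cutoff / spacing).astype(int)

    indices, offsets, distances = [], [], []
    for n in itertools.product(*(range(-k, k + 1) for k in n_max)):
        shift = np.array(n) @ cell  # Cartesian translation of this image
        for j in range(len(positions)):
            if j == i and not any(n):
                continue  # skip atom i itself in the home cell
            d = np.linalg.norm(positions[j] + shift - positions[i])
            if d <= cutoff:
                indices.append(j)
                offsets.append(n)
                distances.append(d)
    return np.array(indices), np.array(offsets), np.array(distances)


a = 4.05
cell = np.array([[0.0, 0.5 * a, 0.5 * a],
                 [0.5 * a, 0.0, 0.5 * a],
                 [0.5 * a, 0.5 * a, 0.0]])
positions = np.array([[0.0, 0.0, 0.0]])
indices, offsets, distances = neighbors_of_atom_periodic(0, positions, cell, 4.0)
print(len(offsets))                            # 12
print(np.allclose(distances, a / np.sqrt(2)))  # True
```

For the FCC primitive cell above, the sketch finds the same set of 12 offsets, all at a/√2 ≈ 2.86 Å.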

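## 2. Low-Level Usage

At the lowest level, `get_neighbors()` returns the neighbor list of all atoms at once as a dictionary: `edge_index` holds the (atom, neighbor) index pairs, `distances` and `offsets` give the corresponding distances and periodic-image cell offsets, and `num_neighbors` counts the neighbors of each atom. Each pair appears once per direction, so the 48 edges below correspond to 4 atoms with 12 neighbors each.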

```python
# FCC conventional cell with 4 atoms (positions in fractional coordinates)
positions = torch.tensor([
    [0.0, 0.0, 0.0],
    [0.0, 0.5, 0.5],
    [0.5, 0.0, 0.5],
    [0.5, 0.5, 0.0]
], dtype=torch.float64)

# Cubic unit cell (4 x 4 x 4 Angstrom)
cell = torch.tensor([
    [4.0, 0.0, 0.0],
    [0.0, 4.0, 0.0],
    [0.0, 0.0, 4.0]
], dtype=torch.float64)

# Periodic in all directions
pbc = torch.tensor([True, True, True])

# Create neighbor list
nbl_pbc = TorchNeighborList(cutoff=2.85, device='cpu')

# Find neighbors with PBC
result_pbc = nbl_pbc.get_neighbors(positions, cell=cell, pbc=pbc)

print(f"Number of atom pairs found: {result_pbc['edge_index'].shape[1]}")
print("\nNumber of neighbors per atom:")
print(result_pbc['num_neighbors'])
print("\nFirst 5 edge pairs:")
print(result_pbc['edge_index'][:, :5])
print("\nFirst 5 distances:")
print(result_pbc['distances'][:5])
print("\nFirst 5 cell offsets:")
print(result_pbc['offsets'][:5])
```

```
Number of atom pairs found: 48

Number of neighbors per atom:
tensor([12, 12, 12, 12])

First 5 edge pairs:
tensor([[0, 0, 0, 1, 1],
        [3, 2, 2, 2, 3]])

First 5 distances:
tensor([2.8284, 2.8284, 2.8284, 2.8284, 2.8284], dtype=torch.float64)

First 5 cell offsets:
tensor([[-1, -1,  0],
        [-1,  0, -1],
        [-1,  0,  0],
        [-1,  0,  0],
        [-1,  0,  0]])
```

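The edge-index layout can be reproduced with a small dense NumPy sketch: compute every displacement r_j + n·cell - r_i over all atom pairs and image offsets, mask by the cutoff, and read the surviving (i, j) indices off the mask. The helper `build_edge_list` is hypothetical. It assumes fractional input coordinates (as the 2.8284 Å = 4/√2 distances above suggest) and lattice vectors stored as the rows of `cell`. Its memory grows with N² times the number of images, so it illustrates the data layout rather than an efficient algorithm.

```python
import itertools

import numpy as np


def build_edge_list(frac_positions, cell, cutoff):
    """All-pairs periodic neighbor list in an edge-index layout (sketch).

    Hypothetical helper, not part of aenet. Assumes fractional coordinates
    in [0, 1) and lattice vectors stored as the rows of `cell`.
    """
    cell = np.asarray(cell, dtype=float)
    cart = np.asarray(frac_positions, dtype=float) @ cell
    n_atoms = len(cart)
    # Image shells per direction from the lattice-plane spacings 1 / |b_k|
    spacing = 1.0 / np.linalg.norm(np.linalg.inv(cell).T, axis=1)
    n_max = np.ceil(cutoff / spacing).astype(int)
    images = np.array(list(itertools.product(*(range(-k, k + 1) for k in n_max))))
    # disp[i, m, j] = r_j + (image m) @ cell - r_i, shape (N, M, N, 3)
    disp = (cart[None, None, :, :]
            + (images @ cell)[None, :, None, :]
            - cart[:, None, None, :])
    dist = np.linalg.norm(disp, axis=-1)
    mask = dist <= cutoff
    # Remove each atom paired with itself in the home cell (offset 0, 0, 0)
    home = np.flatnonzero((images == 0).all(axis=1))[0]
    mask[np.arange(n_atoms), home, np.arange(n_atoms)] = False
    i_idx, img_idx, j_idx = np.nonzero(mask)
    edge_index = np.stack([i_idx, j_idx])  # row 0: atom, row 1: neighbor
    num_neighbors = np.bincount(i_idx, minlength=n_atoms)
    return edge_index, dist[mask], images[img_idx], num_neighbors


frac = np.array([[0.0, 0.0, 0.0],
                 [0.0, 0.5, 0.5],
                 [0.5, 0.0, 0.5],
                 [0.5, 0.5, 0.0]])
cell = 4.0 * np.eye(3)
edge_index, distances, offsets, num_neighbors = build_edge_list(frac, cell, 2.85)
print(edge_index.shape[1])  # 48
print(num_neighbors)        # [12 12 12 12]
```

The sketch recovers 48 directed edges and 12 neighbors per atom; the ordering of its edges may differ from the library's.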

## 3. Per-Atom Neighbor Access

`get_neighbors_of_atom()` returns the neighbors of a single atom as a dictionary with `indices`, `distances`, and `offsets`.

```python
# Get neighbors of the first atom (the corner atom at the origin)
atom_0_neighbors = nbl_pbc.get_neighbors_of_atom(
    0, positions, cell=cell, pbc=pbc
)

print("Neighbors of atom 0:")
print(f"  Neighbor indices: {atom_0_neighbors['indices']}")
print(f"  Distances: {atom_0_neighbors['distances']}")
print(f"  Cell offsets: {atom_0_neighbors['offsets']}")
```

```
Neighbors of atom 0:
  Neighbor indices: tensor([3, 2, 2, 3, 1, 1, 3, 1, 2, 1, 2, 3])
  Distances: tensor([2.8284, 2.8284, 2.8284, 2.8284, 2.8284, 2.8284, 2.8284, 2.8284, 2.8284,
        2.8284, 2.8284, 2.8284], dtype=torch.float64)
  Cell offsets: tensor([[-1, -1,  0],
        [-1,  0, -1],
        [-1,  0,  0],
        [-1,  0,  0],
        [ 0, -1, -1],
        [ 0, -1,  0],
        [ 0, -1,  0],
        [ 0,  0, -1],
        [ 0,  0, -1],
        [ 0,  0,  0],
        [ 0,  0,  0],
        [ 0,  0,  0]])
```

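The per-atom result above is consistent with selecting, from the full edge list of Section 2, the edges whose first row equals 0: the first three neighbor indices (3, 2, 2) and cell offsets match the first three edges printed there. The sketch below illustrates only that relationship on a small hypothetical edge list for the three-atom chain from Section 1; the helper name is an assumption, and whether `get_neighbors_of_atom` is implemented this way is not confirmed here.

```python
import numpy as np


def neighbors_of_atom_from_edges(i, edge_index, distances, offsets):
    """Select the edges whose source atom is i (hypothetical helper)."""
    sel = edge_index[0] == i
    return {"indices": edge_index[1, sel],
            "distances": distances[sel],
            "offsets": offsets[sel]}


# Hypothetical full edge list for the non-periodic three-atom chain
edge_index = np.array([[0, 0, 1, 1, 2, 2],
                       [1, 2, 0, 2, 0, 1]])
distances = np.array([1.5, 3.0, 1.5, 1.5, 3.0, 1.5])
offsets = np.zeros((6, 3), dtype=int)  # all neighbors in the home cell

res = neighbors_of_atom_from_edges(0, edge_index, distances, offsets)
print(res["indices"])    # [1 2]
print(res["distances"])  # [1.5 3. ]
```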

## 4. GPU Acceleration (Optional)

If CUDA is available, neighbor finding can run on the GPU by creating the neighbor list with `device='cuda'`.

```python
# Check if CUDA is available
if torch.cuda.is_available():
    print("CUDA is available!")

    # Create neighbor list on GPU
    nbl_gpu = TorchNeighborList(cutoff=3.5, device='cuda')

    # Move data to GPU (or it will be done automatically)
    positions_gpu = positions.cuda()
    cell_gpu = cell.cuda()
    pbc_gpu = pbc.cuda()

    # Find neighbors on GPU
    result_gpu = nbl_gpu.get_neighbors(
        positions_gpu, cell=cell_gpu, pbc=pbc_gpu
    )

    print(f"Found {result_gpu['edge_index'].shape[1]} pairs on GPU")
    print(f"Result tensors are on device: {result_gpu['distances'].device}")
else:
    print("CUDA not available, skipping GPU example")
```

```
CUDA not available, skipping GPU example
```


## 5. Type-Dependent Cutoffs

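Different cutoffs can be assigned to different pairs of atom types by passing `atom_types` and `cutoff_dict`; the `cutoff` argument then sets the maximum search radius.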

```python
# Atom types (0 = Cu, 1 = Au)
atom_types = torch.tensor([0, 0, 1, 1], dtype=torch.long)

# Define type-specific cutoffs in Angstrom
# (0, 0) = Cu-Cu, (0, 1) = Cu-Au, (1, 1) = Au-Au
cutoff_dict = {
    (0, 0): 2.85,  # Cu-Cu pairs: 2.85 Angstrom
    (0, 1): 4.05,  # Cu-Au pairs: 4.05 Angstrom
    (1, 1): 2.85,  # Au-Au pairs: 2.85 Angstrom
}

# Create neighbor list with type-dependent cutoffs
nbl_typed = TorchNeighborList(
    cutoff=5.0,  # Maximum cutoff
    atom_types=atom_types,
    cutoff_dict=cutoff_dict,
    device='cpu'
)

# Get neighbors for atom 0 (Cu) with type filtering
cu_neighbors = nbl_typed.get_neighbors_of_atom(
    0, positions, cell=cell, pbc=pbc
)

print("Cu atom neighbors (with type-specific cutoffs):")
print(f"  Neighbor indices: {cu_neighbors['indices']}")
print(f"  Neighbor types: {atom_types[cu_neighbors['indices']]}")
print(f"  Distances: {cu_neighbors['distances']}")
```

```
Cu atom neighbors (with type-specific cutoffs):
  Neighbor indices: tensor([3, 2, 2, 3, 1, 1, 3, 1, 2, 1, 2, 3])
  Neighbor types: tensor([1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1])
  Distances: tensor([2.8284, 2.8284, 2.8284, 2.8284, 2.8284, 2.8284, 2.8284, 2.8284, 2.8284,
        2.8284, 2.8284, 2.8284], dtype=torch.float64)
```

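One way to realize type-dependent cutoffs is to search once with the largest radius (the `cutoff=5.0` argument above) and then discard each pair that is longer than the cutoff for its pair of types. The sketch below applies such a filter to the neighbors of atom 0 printed above. The helper name and the sorted-pair dictionary lookup are assumptions, and whether `TorchNeighborList` filters this way internally is not confirmed here.

```python
import numpy as np


def filter_by_type_cutoff(edge_index, distances, atom_types, cutoff_dict):
    """Keep edges no longer than the cutoff of their (type_i, type_j) pair.

    Hypothetical helper, not part of aenet. Assumes `cutoff_dict` is keyed by
    sorted type pairs, so a (1, 0) pair is looked up as (0, 1).
    """
    ti = atom_types[edge_index[0]].tolist()
    tj = atom_types[edge_index[1]].tolist()
    pair_cutoff = np.array([cutoff_dict[(min(a, b), max(a, b))]
                            for a, b in zip(ti, tj)])
    keep = distances <= pair_cutoff
    return edge_index[:, keep], distances[keep]


atom_types = np.array([0, 0, 1, 1])  # 0 = Cu, 1 = Au, as in the notebook
# Neighbors of atom 0 copied from the output above, all at 2*sqrt(2) Angstrom
targets = np.array([3, 2, 2, 3, 1, 1, 3, 1, 2, 1, 2, 3])
edge_index = np.stack([np.zeros_like(targets), targets])
distances = np.full(len(targets), 2.0 * np.sqrt(2.0))

cutoffs = {(0, 0): 2.85, (0, 1): 4.05, (1, 1): 2.85}
kept, _ = filter_by_type_cutoff(edge_index, distances, atom_types, cutoffs)
print(kept.shape[1])  # 12

# Hypothetical variant: Cu-Au cutoff shorter than the Cu-Au distance
cutoffs_short = {(0, 0): 2.85, (0, 1): 2.5, (1, 1): 2.85}
kept, _ = filter_by_type_cutoff(edge_index, distances, atom_types, cutoffs_short)
print(kept[1])  # [1 1 1 1]
```

With the cutoffs used above, all 12 neighbors at 2.83 Å survive, consistent with the notebook output. Lowering the hypothetical Cu-Au cutoff to 2.5 Å would leave only the four Cu neighbors (periodic images of atom 1).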