# Example 03: PyTorch Neighbor List

This notebook demonstrates how to use `TorchNeighborList` to find neighbors efficiently in atomic structures.

**Features:**
- Isolated systems (molecules)
- Periodic boundary conditions (crystals)
- GPU acceleration (if available)
- Type-dependent cutoffs

In [None]:
import torch
import numpy as np
from aenet.torch_featurize import TorchNeighborList

## 1. Isolated System: Water Molecule

Find neighbors in a water molecule (non-periodic system).

In [None]:
# Water molecule positions (Angstroms)
positions = torch.tensor([
    [0.000, 0.000,  0.118],  # O
    [0.000, 0.755, -0.471],  # H
    [0.000, -0.755, -0.471]  # H
], dtype=torch.float64)

# Create neighbor list with 2.0 Angstrom cutoff
nbl = TorchNeighborList(cutoff=2.0, device='cpu')

# Find neighbors
result = nbl.get_neighbors(positions)

print(f"Number of atom pairs found: {result['edge_index'].shape[1]}")
print("\nEdge index (source, target pairs):")
print(result['edge_index'])
print("\nDistances (Angstroms):")
print(result['distances'])
print("\nNumber of neighbors per atom:")
print(result['num_neighbors'])
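
As a cross-check for the isolated case, the same pairs can be computed by brute force in plain PyTorch. This is a minimal sketch, not the `TorchNeighborList` implementation: it assumes pairs are reported in both directions (i to j and j to i) and that a pair counts as a neighbor when its distance is at most the cutoff. The variable names are illustrative only.

```python
import torch

# Water molecule positions (Angstroms), same as above
positions = torch.tensor([
    [0.000, 0.000,  0.118],  # O
    [0.000, 0.755, -0.471],  # H
    [0.000, -0.755, -0.471]  # H
], dtype=torch.float64)
cutoff = 2.0

# All pairwise distances, shape (N, N)
dist = torch.cdist(positions, positions)

# Keep pairs within the cutoff, excluding each atom paired with itself
n_atoms = positions.shape[0]
mask = (dist <= cutoff) & ~torch.eye(n_atoms, dtype=torch.bool)

# Directed (source, target) pairs and their distances
src, dst = mask.nonzero(as_tuple=True)
edge_index = torch.stack([src, dst])
distances = dist[src, dst]
num_neighbors = mask.sum(dim=1)

print(edge_index)
print(distances)
print(num_neighbors)
```

The O-H distance here is about 0.958 Angstroms and the H-H distance is 1.51 Angstroms, so with a 2.0 Angstrom cutoff every atom sees the other two. The brute-force approach scales quadratically with the number of atoms and is only suitable for small checks like this one.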

## 2. Periodic System: Cubic AuCu Crystal

Find neighbors in a periodic crystal structure.

In [None]:
# Four-atom cubic cell with atoms on FCC sites (AuCu example)
positions = torch.tensor([
    [0.0, 0.0, 0.0],  # Cu corner
    [0.0, 2.0, 2.0],  # Cu face center
    [2.0, 0.0, 2.0],  # Au face center
    [2.0, 2.0, 0.0]   # Au face center
], dtype=torch.float64)

# Cubic unit cell (4x4x4 Angstroms)
cell = torch.tensor([
    [4.0, 0.0, 0.0],
    [0.0, 4.0, 0.0],
    [0.0, 0.0, 4.0]
], dtype=torch.float64)

# Periodic in all directions
pbc = torch.tensor([True, True, True])

# Create neighbor list
nbl_pbc = TorchNeighborList(cutoff=3.5, device='cpu')

# Find neighbors with PBC
result_pbc = nbl_pbc.get_neighbors(positions, cell=cell, pbc=pbc)

print(f"Number of atom pairs found: {result_pbc['edge_index'].shape[1]}")
print("\nNumber of neighbors per atom:")
print(result_pbc['num_neighbors'])
print("\nFirst 5 edge pairs:")
print(result_pbc['edge_index'][:, :5])
print("\nFirst 5 distances:")
print(result_pbc['distances'][:5])
print("\nFirst 5 cell offsets:")
print(result_pbc['offsets'][:5])
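
To make the role of the cell offsets concrete, the periodic search can be sketched by brute force over the 27 neighboring cell images. This is an illustrative sketch under stated assumptions, not the library's algorithm: the rows of `cell` are taken to be the lattice vectors, pairs are reported in both directions, and one shell of images is enough only because the cutoff (3.5) is smaller than the cell edge (4.0). The function name `brute_force_pbc_neighbors` is hypothetical.

```python
import itertools
import torch

def brute_force_pbc_neighbors(positions, cell, cutoff):
    """Hypothetical helper: neighbor pairs over the 27 nearest cell images."""
    n_atoms = positions.shape[0]
    # Integer cell shifts (-1, 0, 1) along each lattice vector, shape (27, 3)
    shifts = torch.tensor(
        list(itertools.product((-1, 0, 1), repeat=3)), dtype=positions.dtype
    )
    # Cartesian translation of each image (assumes rows of `cell` are lattice vectors)
    shifts_cart = shifts @ cell
    # delta[s, i, j] = (r_j + shift_s) - r_i, shape (27, N, N, 3)
    delta = (positions[None, None, :, :] + shifts_cart[:, None, None, :]
             - positions[None, :, None, :])
    dist = delta.norm(dim=-1)
    mask = dist <= cutoff
    # Exclude each atom paired with itself in the home cell
    home = (shifts == 0).all(dim=1)
    mask &= ~(home[:, None, None] & torch.eye(n_atoms, dtype=torch.bool)[None])
    s, i, j = mask.nonzero(as_tuple=True)
    return torch.stack([i, j]), dist[s, i, j], shifts[s].long()

positions = torch.tensor([
    [0.0, 0.0, 0.0],
    [0.0, 2.0, 2.0],
    [2.0, 0.0, 2.0],
    [2.0, 2.0, 0.0]
], dtype=torch.float64)
cell = torch.eye(3, dtype=torch.float64) * 4.0

edge_index, distances, offsets = brute_force_pbc_neighbors(positions, cell, cutoff=3.5)
neighbors_per_atom = torch.bincount(edge_index[0], minlength=positions.shape[0])
print(edge_index.shape[1], neighbors_per_atom)
```

In this geometry every atom has 12 nearest neighbors at a distance of 4/√2 ≈ 2.83 Angstroms, and the next shell lies at 4.0 Angstroms, outside the cutoff, so the sketch finds 48 directed pairs. Whether `TorchNeighborList` uses the same pair-counting and offset conventions should be checked against its output.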

## 3. Per-Atom Neighbor Access

Access neighbors for individual atoms.

In [None]:
# Get neighbors of first atom (Cu corner)
atom_0_neighbors = nbl_pbc.get_neighbors_of_atom(
    0, positions, cell=cell, pbc=pbc
)

print("Neighbors of atom 0:")
print(f"  Neighbor indices: {atom_0_neighbors['indices']}")
print(f"  Distances: {atom_0_neighbors['distances']}")
print(f"  Cell offsets: {atom_0_neighbors['offsets']}")
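
Per-atom access can also be understood as a filter on a full pair list. The sketch below uses a small hand-written pair list rather than real library output, and it assumes that the first row of `edge_index` holds the source atom and that offsets are integer multiples of the lattice vectors stored as rows of `cell`. Both are assumptions about the convention and not confirmed behavior of `get_neighbors_of_atom`.

```python
import torch

cell = torch.eye(3, dtype=torch.float64) * 4.0
positions = torch.tensor([
    [0.0, 0.0, 0.0],
    [0.0, 2.0, 2.0]
], dtype=torch.float64)

# Hypothetical pair list: atom 0 sees atom 1 in the home cell and in the
# image shifted by -1 along y; atom 1 sees atom 0 in the home cell.
edge_index = torch.tensor([[0, 0, 1],
                           [1, 1, 0]])
offsets = torch.tensor([[0, 0, 0],
                        [0, -1, 0],
                        [0, 0, 0]], dtype=torch.float64)

atom = 0
mask = edge_index[0] == atom          # pairs whose source is the chosen atom
neighbor_indices = edge_index[1, mask]
neighbor_offsets = offsets[mask]

# Image position of each neighbor: r_j + offset @ cell
neighbor_positions = positions[neighbor_indices] + neighbor_offsets @ cell
neighbor_distances = (neighbor_positions - positions[atom]).norm(dim=1)

print(neighbor_indices)
print(neighbor_distances)
```

The same neighbor index can appear more than once for one atom, because different periodic images of that neighbor can fall within the cutoff. The offsets are what distinguish those entries.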

## 4. GPU Acceleration (Optional)

If CUDA is available, the neighbor search can run on the GPU.

In [None]:
# Check if CUDA is available
if torch.cuda.is_available():
    print("CUDA is available!")
    
    # Create neighbor list on GPU
    nbl_gpu = TorchNeighborList(cutoff=3.5, device='cuda')
    
    # Move inputs to the GPU explicitly (otherwise they are moved automatically)
    positions_gpu = positions.cuda()
    cell_gpu = cell.cuda()
    pbc_gpu = pbc.cuda()
    
    # Find neighbors on GPU
    result_gpu = nbl_gpu.get_neighbors(
        positions_gpu, cell=cell_gpu, pbc=pbc_gpu
    )
    
    print(f"Found {result_gpu['edge_index'].shape[1]} pairs on GPU")
    print(f"Result tensors are on device: {result_gpu['distances'].device}")
else:
    print("CUDA not available, skipping GPU example")
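
A common alternative to branching on CUDA availability is to pick the device once and write the rest of the code device-agnostically. The sketch below uses only standard PyTorch calls and a plain distance computation as a stand-in for the neighbor search. It does not involve `TorchNeighborList`.

```python
import torch

# Select the GPU when present, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

positions = torch.tensor([
    [0.0, 0.0, 0.0],
    [0.0, 2.0, 2.0]
], dtype=torch.float64)

# .to(device) is a no-op when the tensor already lives on that device
positions_dev = positions.to(device)

# Stand-in computation that runs on whichever device was selected
dist = torch.cdist(positions_dev, positions_dev)

# Bring results back to the CPU before converting to NumPy or printing
dist_cpu = dist.cpu()
print(f"Computed on {dist.device}, distance = {dist_cpu[0, 1].item():.4f}")
```

With this pattern the same cell runs unchanged on machines with and without a GPU. For a system as small as the four-atom example, the transfer overhead likely outweighs any speedup, so the GPU path mainly pays off for larger structures.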

## 5. Type-Dependent Cutoffs

Use different cutoffs for different atom type pairs.

In [None]:
# Atom types (0 = Cu, 1 = Au)
atom_types = torch.tensor([0, 0, 1, 1], dtype=torch.long)

# Define type-specific cutoffs
# (0,0) = Cu-Cu, (0,1) = Cu-Au, (1,1) = Au-Au
cutoff_dict = {
    (0, 0): 2.8,  # Cu-Cu pairs: 2.8 Angstroms
    (0, 1): 3.0,  # Cu-Au pairs: 3.0 Angstroms
    (1, 1): 3.2,  # Au-Au pairs: 3.2 Angstroms
}

# Create neighbor list with type-dependent cutoffs
nbl_typed = TorchNeighborList(
    cutoff=3.5,  # Maximum cutoff
    atom_types=atom_types,
    cutoff_dict=cutoff_dict,
    device='cpu'
)

# Get neighbors for atom 0 (Cu) with type filtering
cu_neighbors = nbl_typed.get_neighbors_of_atom(
    0, positions, cell=cell, pbc=pbc
)

print("Cu atom neighbors (with type-specific cutoffs):")
print(f"  Neighbor indices: {cu_neighbors['indices']}")
print(f"  Neighbor types: {atom_types[cu_neighbors['indices']]}")
print(f"  Distances: {cu_neighbors['distances']}")
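
One way to picture type-dependent filtering is as a lookup into a symmetric cutoff matrix followed by a per-pair comparison. The sketch below is an assumption about how such filtering could work, not a description of the `TorchNeighborList` internals. The helper `build_cutoff_matrix` and the candidate pair list are hypothetical.

```python
import torch

def build_cutoff_matrix(cutoff_dict, n_types, dtype=torch.float64):
    """Hypothetical helper: expand {(a, b): cutoff} into a symmetric matrix."""
    matrix = torch.zeros((n_types, n_types), dtype=dtype)
    for (a, b), cutoff in cutoff_dict.items():
        matrix[a, b] = cutoff
        matrix[b, a] = cutoff  # (a, b) and (b, a) share one cutoff
    return matrix

atom_types = torch.tensor([0, 0, 1, 1])  # 0 = Cu, 1 = Au
cutoff_dict = {(0, 0): 2.8, (0, 1): 3.0, (1, 1): 3.2}
cutoff_matrix = build_cutoff_matrix(cutoff_dict, n_types=2)

# Hypothetical candidate pairs found with the maximum cutoff,
# all at the nearest-neighbor distance 4/sqrt(2) of the example cell
edge_index = torch.tensor([[0, 0, 0, 2],
                           [1, 2, 3, 3]])
distances = torch.full((4,), 8.0 ** 0.5, dtype=torch.float64)

# Look up the cutoff for each pair from the types of its two atoms
pair_cutoffs = cutoff_matrix[atom_types[edge_index[0]], atom_types[edge_index[1]]]

# Keep only pairs within their own type-specific cutoff
keep = distances <= pair_cutoffs
filtered_edges = edge_index[:, keep]

print(pair_cutoffs)
print(filtered_edges)
```

In the example cell the nearest-neighbor distance is about 2.83 Angstroms, slightly above the 2.8 Angstrom Cu-Cu cutoff. Under this filtering rule the Cu-Cu pair is dropped while the Cu-Au and Au-Au pairs are kept.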

## Summary

This notebook demonstrated:
- Finding neighbors in isolated systems (molecules)
- Finding neighbors with periodic boundary conditions (crystals)
- Per-atom neighbor access
- GPU acceleration
- Type-dependent cutoffs

The `TorchNeighborList` class provides efficient neighbor finding on both CPU and GPU, for isolated and periodic systems, with optional type-dependent cutoffs.