In [None]:
import numpy as np
from typing import List, Tuple, Iterable
from ase.data import atomic_masses, covalent_radii, atomic_numbers
from ase.neighborlist import natural_cutoffs, NeighborList

In [3]:
def get_masses(symbols: Iterable[str]) -> np.ndarray:
    """Return masses (amu) for each element symbol."""
    out = []
    for s in symbols:
        Z = atomic_numbers[s]
        out.append(float(atomic_masses[Z]))
    return np.array(out, float)
    
def get_covalent_radii(symbols: Iterable[str]) -> np.ndarray:
    """Return covalent radii (Å) for each element symbol."""
    out = []
    for s in symbols:
        Z = atomic_numbers[s]
        out.append(float(covalent_radii[Z]))
    return np.array(out, float)

## Bond inference

infer covalent bonds from geometry using a distance threshold based on covalent radii.

$$
d_{ij} \le r_i^{(\mathrm{cut})} + r_j^{(\mathrm{cut})},\quad r_k^{(\mathrm{cut})} = \text{scale}\times r_{\mathrm{cov}}(Z_k).
$$

In [36]:
def guess_bonds(symbols: List[str], coords: np.ndarray, scale: float = 1.20, use_ase: bool = True) -> List[Tuple[int, int]]:
    """
    Infer bonds by distance threshold relative to covalent radii.
    Returns list of (i,j) with i<j.
    """
    n = len(symbols)
    if coords.shape != (n, 3):
        raise ValueError("coords must be shape (N,3)")

    from ase import Atoms
    atoms = Atoms(symbols=symbols, positions=coords)
    cutoffs = natural_cutoffs(atoms, mult = scale)
    nl = NeighborList(cutoffs, self_interaction=False, bothways=True)
    nl.update(atoms)
    edges = []
    for i in range(n):
        idxs, offsets = nl.get_neighbors(i)
        for j in idxs:
            if i < j:
                edges.append((i, j.item()))
    return edges

## Demo



In [39]:
from ase.build import molecule

target = 'H2O'
atoms = molecule(target); 
symbols = atoms.get_chemical_symbols()
coords = atoms.get_positions()       # (N,3) Å
masses = atoms.get_masses()          # (N,) amu

edges = guess_bonds(symbols, coords, scale=1.50)

print("Name: ",target)
print("N atoms:", len(symbols))
print("atoms:", symbols)
print("coords (Å):\n", coords)
print("masses (amu):", masses)
print("edges:", edges)


Name:  H2O
N atoms: 3
atoms: ['O', 'H', 'H']
coords (Å):
 [[ 0.        0.        0.119262]
 [ 0.        0.763239 -0.477047]
 [ 0.       -0.763239 -0.477047]]
masses (amu): [15.999  1.008  1.008]
edges: [(0, 1), (0, 2), (1, 2)]
