# Neighborlist generation

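Generating a neighbor list is not a trivial exercise: it can be slow, and it is difficult to define exactly which atoms are and are not neighbors. Below, ASE's `NeighborList` is compared against a brute-force pairwise-distance search on the same molecule.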

In [None]:
import time
import numpy as np
from collections import defaultdict

import ase.db
from ase.neighborlist import NeighborList
from ase.data import covalent_radii
from ase.io import write

In [None]:
# Connect to the ase-db and pull every row.
db = ase.db.connect('../../data/solar.db')
atoms = list(db.select())
# random.shuffle(atoms)  # optional reordering; would need `import random`

# Compile parallel lists of molecules and target energies. Rows where reading
# `Energy` (or converting to Atoms) raises AttributeError are skipped. Both
# values are fetched before either is appended, so `alist[k]` always pairs
# with `targets[k]`.
alist = []
targets = []
for row in atoms:
    try:
        energy = row.Energy
        molecule = row.toatoms()
    except AttributeError:
        continue
    targets.append(energy)
    alist.append(molecule)

# Analyze the size (number of atoms) of molecules pulled from the db.
print('pulled {} molecules from db'.format(len(alist)))
size = [len(a) for a in alist]

# min/max/mean are undefined for an empty selection; skip the summary then.
if size:
    print('min: {0}, mean: {1:.0f}, max: {2} molecule size'.format(
        min(size), sum(size) / len(size), max(size)))
else:
    print('no molecules with an Energy value found')

In [None]:
# Use the first molecule from the database for the comparison.
atoms = alist[0]

# perf_counter is monotonic and high-resolution, unlike time.time(), so it is
# the appropriate clock for measuring elapsed time.
st = time.perf_counter()

# One cutoff per atom: its covalent radius. NeighborList combines the cutoffs
# of each pair, together with the `skin` margin, into the pair threshold
# (NOTE(review): confirm the exact combination rule in the ASE docs).
cutoffs = [covalent_radii[a.number] for a in atoms]
nl = NeighborList(
    cutoffs, skin=0.3, sorted=False, self_interaction=False, bothways=True)

nl.build(atoms)

# Map atom index -> sorted list of neighbor indices. Only the first element of
# get_neighbors' result (the indices) is kept; the second is ignored here.
neighborlist = {
    i: sorted(map(int, nl.get_neighbors(i)[0])) for i in range(len(atoms))
}

print('compiled neighborlist in {}'.format(
    time.perf_counter() - st))

In [None]:
# Brute-force connectivity: test every ordered pair of distinct atoms.
# perf_counter is monotonic, so it is the right clock for elapsed time.
st = time.perf_counter()

# dx: per-element buffer added to the bond-length threshold (None -> default).
# neighbor_number: which shell to collect (1 = directly bonded neighbors).
dx, neighbor_number = None, 1

# Default buffer: half of each element's covalent radius.
if dx is None:
    dx = {z: covalent_radii[z] / 2. for z in set(atoms.get_atomic_numbers())}

# Hoist per-atom data out of the O(n^2) pair loop.
positions = atoms.get_positions()
numbers = atoms.get_atomic_numbers()

# An atom j is in shell `neighbor_number` of atom i when
#   (neighbor_number - 1) * (r_i + r_j) + dxi < d_ij < neighbor_number * (r_i + r_j) + dxi
# with the lower bound replaced by 0 for the first shell, and
# dxi = (dx[z_i] + dx[z_j]) / 2. With the default dx this means the
# first-shell threshold is r_i + r_j + (r_i + r_j) / 4.
conn = {}
for i, (p1, z1) in enumerate(zip(positions, numbers)):
    r1 = covalent_radii[z1]
    c = []
    for j, (p2, z2) in enumerate(zip(positions, numbers)):
        if i == j:
            continue
        d = np.linalg.norm(p1 - p2)
        r2 = covalent_radii[z2]
        dxi = (dx[z1] + dx[z2]) / 2.
        if neighbor_number == 1:
            d_max1 = 0.
        else:
            d_max1 = ((neighbor_number - 1) * (r1 + r2)) + dxi
        d_max2 = (neighbor_number * (r1 + r2)) + dxi
        if d_max1 < d < d_max2:
            c.append(j)
    # Sort once per atom, after all of its candidate pairs have been tested
    # (previously this ran inside the inner loop, re-sorting on every pair).
    conn[i] = sorted(c)

print('compiled neighborlist in {}'.format(
    time.perf_counter() - st))

In [None]:
# Cross-check the ASE neighbor list against the brute-force connectivity.
# NOTE(review): the ASE list uses covalent-radius cutoffs plus skin=0.3, while
# the brute-force search uses r_i + r_j plus a dx buffer. The two criteria are
# not identical by construction, so a mismatch may reflect the thresholds
# rather than a bug.
mismatched = sorted(
    i for i in set(neighborlist) | set(conn)
    if neighborlist.get(i) != conn.get(i))
assert not mismatched, 'neighbor lists differ for atoms {}'.format(
    mismatched[:10])