In [151]:
# (Re)load the line_profiler extension so `%lprun` is available in this session.
%reload_ext line_profiler
import sys
# Make the local jax_geometry sources importable (presumably where
# `fast_neighborlist_1`, imported in the next cell, lives).
# NOTE(review): relative path — only works when the kernel's working directory
# is a sibling of jax_geometry/.
sys.path.insert(0, "../jax_geometry/")

In [152]:
from fast_neighborlist_1 import FastPrimitiveNeighborList
from ase.build import bulk


# Test system: fcc Cu (lattice constant 3.5, via ASE `bulk`) repeated 2x2x2.
atoms = bulk('Cu', 'fcc', 3.5) * (2,2,2)


# Reference list from the local fast implementation: per-atom cutoff 2.5 with
# skin 0.3, full (bothways) list, no self-interaction.
# NOTE(review): whether `nl.cutoffs` already includes the skin depends on
# FastPrimitiveNeighborList — confirm. Later cells read `nl._cache` and
# `nl.cutoffs`, so this cell must run before them.
nl = FastPrimitiveNeighborList(cutoffs=[2.5]*len(atoms), skin=0.3,
                               sorted=False, self_interaction=False,
                               bothways=True, use_scaled_positions=False)
nl.build(atoms.pbc, atoms.cell, atoms.positions)

In [153]:
import jax
import numpy as np
from jax import numpy as jnp
import itertools

from ase.geometry import minkowski_reduce, wrap_positions
from ase.data import covalent_radii as COV_R





# @jax.jit
def _prod(x, y, dtype=int):
    x = jnp.asarray(x, dtype=dtype)
    y = jnp.asarray(y, dtype=dtype)
    return jnp.append(x,y)


# Minkowski-reduce the cell; r_op is the integer operator with rcell = r_op @ cell.
r_op = minkowski_reduce(atoms.cell, atoms.pbc)[1]
pos, rcell = atoms.positions, r_op @ atoms.cell
# Wrap positions into the reduced cell (eps=0: no tolerance shift before wrapping).
pos_mksk = wrap_positions(pos, rcell, atoms.pbc, eps=0)
# Integer lattice translation applied by the wrap, in the ORIGINAL cell basis:
# pos_mksk - pos == offsets @ atoms.cell (up to rounding). Needed to map shifts
# found with pos_mksk back to the original positions.
offsets = atoms.cell.scaled_positions(pos_mksk - pos).round().astype(int)
# Per-axis image range read from the fast list's private cache.
# NOTE(review): relies on FastPrimitiveNeighborList internals (`_cache['N']`).
N = nl._cache.get('N')
# Every image shift S in [-N, N] (inclusive) along each axis, one row per shift.
n123 = np.asarray(list(itertools.product(
            range(-N[0], N[0] + 1),
            range(-N[1], N[1] + 1),
            range(-N[2], N[2] + 1))), dtype=int)
# Alternative (unused): reuse the cached shift table instead of rebuilding it.
# n123 = np.asarray(nl._cache.get('n123'), dtype=int)


# Candidate rows [i, j, S0, S1, S2]: every i <= j pair (diagonal included so an
# atom can neighbor its own periodic images) combined with every shift in n123.
# Outer vmap runs over shifts, inner over pairs, so rows are grouped by shift.
ijS = jax.vmap(jax.vmap(_prod, in_axes=(0,None)), in_axes=(None, 0))(
    jnp.column_stack(jnp.triu_indices(len(pos))), n123).reshape(-1, 5)
# True self-interaction: same atom and zero shift. Drop those rows.
is_self_interaction = jnp.logical_and(
    ijS[:, 0] == ijS[:, 1],
    jax.vmap(jnp.allclose, in_axes=(0,None))(
        ijS[:, 2:], 0))
ijS = ijS[~ is_self_interaction]

def is_neighbor_(ijS, rcell, pos, cutoffs):
    """Coarse axis-aligned box test for a single candidate row.

    Args:
        ijS: integer row ``[i, j, S0, S1, S2]`` (atom indices and image shift).
        rcell: cell matrix the shift ``S`` is expressed in.
        pos: per-atom Cartesian positions.
        cutoffs: per-atom cutoff radii; the pair threshold is their sum.

    Returns:
        Boolean scalar: True when every Cartesian component of
        ``pos[i] + S @ rcell - pos[j]`` is within ``cutoffs[i] + cutoffs[j]``
        in absolute value (a superset of the Euclidean test).
    """
    first, second = ijS[0], ijS[1]
    shift = ijS[2:5]
    separation = pos[first] + shift @ rcell - pos[second]
    reach = cutoffs[first] + cutoffs[second]
    return jnp.all(jnp.abs(separation) <= reach)


def is_neighbor_tight(ijS, rcell, pos, cutoffs):
    """Exact (Euclidean) neighbor test for a single candidate row.

    Args:
        ijS: integer row ``[i, j, S0, S1, S2]`` (atom indices and image shift).
        rcell: cell matrix the shift ``S`` is expressed in.
        pos: per-atom Cartesian positions.
        cutoffs: per-atom cutoff radii; the pair threshold is their sum.

    Returns:
        Boolean scalar: True when ``|pos[i] + S @ rcell - pos[j]|`` is at most
        ``cutoffs[i] + cutoffs[j]`` (inclusive).
    """
    i_atom, j_atom = ijS[0], ijS[1]
    image_shift = ijS[2:5] @ rcell
    separation = pos[i_atom] + image_shift - pos[j_atom]
    within = jnp.linalg.norm(separation) <= cutoffs[i_atom] + cutoffs[j_atom]
    return jnp.all(within)

# Exact distance test for every candidate row. The shifts in ijS index the
# reduced cell, so the reduced cell and the wrapped positions are used here.
is_nb = jax.vmap(is_neighbor_tight, in_axes=(0,None,None,None))(
    ijS,
    jnp.asarray(rcell),
    jnp.asarray(pos_mksk),
    jnp.asarray(nl.cutoffs))


# Number of surviving (i <= j) neighbor rows.
ijS[is_nb].shape

(360, 5)

In [154]:
# Inspect the reduction operator (identity for this cell, per the output below).
r_op

array([[1, 0, 0],
       [0, 1, 0],
       [0, 0, 1]])

In [155]:
# Scratch check of jnp.repeat on a length-1 array; the result is not stored.
jnp.repeat(jnp.asarray([3.5]), 5)

Array([3.5, 3.5, 3.5, 3.5, 3.5], dtype=float32)

In [156]:
from ase.neighborlist import primitive_neighbor_list

def f(atoms, cutoffs):
    """Reference neighbor list from ASE, packed into one integer table.

    Args:
        atoms: ASE Atoms-like object providing ``pbc``, ``cell`` and ``positions``.
        cutoffs: cutoff specification forwarded unchanged to
            ``primitive_neighbor_list``.

    Returns:
        Array with one row ``[i, j, S0, S1, S2]`` per neighbor entry, built by
        column-stacking the ``'ijS'`` quantities.
    """
    first, second, shifts = primitive_neighbor_list(
        'ijS', atoms.pbc, atoms.cell, atoms.positions, cutoffs)
    return np.column_stack((first, second, shifts))

# ASE reference list (full list) vs. the prototype's i <= j rows.
a1 = f(atoms, nl.cutoffs)
a2 = ijS[is_nb]
# Map image shifts from the reduced-cell basis back to the ORIGINAL cell.
# The distance test used wrapped positions pos_mksk = pos + offsets @ cell and
# rcell = r_op @ cell, so
#   pos_mksk[i] + S @ rcell - pos_mksk[j]
#     = pos[i] + (S @ r_op + offsets[i] - offsets[j]) @ cell - pos[j].
# Omitting the offsets term gives wrong shifts for any atom the wrap moved.
# NOTE(review): with this convention the row (i, j, S) describes the same bond
# vector as ASE's row (j, i, S), since ASE uses D = pos[j] - pos[i] + S @ cell.
a2 = a2.at[:, 2:].set(
    a2[:, 2:] @ r_op
    + jnp.asarray(offsets)[a2[:, 0]]
    - jnp.asarray(offsets)[a2[:, 1]])

# Number of ASE neighbor entries for atom 0.
print(a1[a1[:, 0] == 0].shape)

(78, 5)


In [157]:
# Sign-convention check: the 0/1 separation through the third cell vector,
# taken in both directions.
pos[0] - pos[1] - atoms.cell[-1], pos[1] - pos[0] + atoms.cell[-1]

(array([-5.25, -5.25,  0.  ]), array([5.25, 5.25, 0.  ]))

In [158]:
# Prototype rows (only i <= j pairs are generated) vs. ASE's list.
print(a2.shape, a1.shape)

(360, 5) (624, 5)


In [159]:
from ase.neighborlist import PrimitiveNeighborList
# Cross-check with ASE's own PrimitiveNeighborList at a smaller per-atom cutoff
# (1.5, skin 0.3), using the same flags as the fast list built earlier.
nl1 = PrimitiveNeighborList(cutoffs=[1.5]*len(atoms), skin=0.3,
                               sorted=False, self_interaction=False,
                               bothways=True, use_scaled_positions=False)
nl1.build(atoms.pbc, atoms.cell, atoms.positions)
# Neighbor indices of atom 1 together with their integer cell offsets.
nl1.get_neighbors(1)

(array([2, 3, 4, 5, 6, 0, 0, 2, 3, 4, 5, 6, 6, 6, 6, 6, 7, 7]),
 array([[ 0,  0,  0],
        [ 0,  0,  0],
        [ 0,  0,  0],
        [ 0,  0,  0],
        [ 0,  0,  0],
        [ 0,  0,  1],
        [ 0,  0,  0],
        [ 0, -1,  1],
        [ 0, -1,  0],
        [-1,  0,  1],
        [-1,  0,  0],
        [ 0, -1,  1],
        [ 0, -1,  0],
        [-1,  0,  1],
        [-1,  0,  0],
        [-1, -1,  1],
        [ 0, -1,  0],
        [-1,  0,  0]]))