In [1]:
%load_ext autoreload
%load_ext line_profiler
%autoreload 2

In [2]:
from cell_list import get_cell_list, get_neighbors_list, get_neighbor_ids
from verlet_list import get_neighbor_ids as get_verlet_neighbor_ids
from verlet_list import get_neighborhood, _pairwise_dist
from jax import random
import jax.numpy as jnp

## Setting up a simple box

This creates a $3 \times 3 \times 3$ box centered at the origin and places particles on the nodes of a grid with spacing 1. The first particle is placed at the center of the box, $(0, 0, 0)$, and the remaining 26 particles are placed at every other combination of offsets $0, \pm 1$ along each axis.

In [3]:
# Cubic simulation box with edge length `box_edge`.
box_edge = 3.0
box_size = jnp.array([box_edge] * 3)

# Alternative (random placement), kept for reference:
# positions = random.uniform(random.PRNGKey(0), (int(1e2), 3)) * box_edge

# 27 particles on a unit-spaced 3 x 3 x 3 grid centered at the origin.
# Particle ids follow the row order: z varies slowest, then x, then y,
# and each axis visits the offsets in the order 0, +1, -1
# (so particle 0 sits at the origin).
axis_offsets = (0.0, 1.0, -1.0)
positions = jnp.array([[x, y, z]
                       for z in axis_offsets
                       for x in axis_offsets
                       for y in axis_offsets])

# Cutoff used to build the cell list (cell edge size).
cutoff_c = 1.0
# Cutoff used for the exact (Verlet) neighborhood test.
cutoff_v = 1.0
# Passed as `buffer_size_cell` to the cell-list queries below.
buffer = 30

Particles are shifted so that all $x, y, z$ values are positive ($(0, 0, 0)$ at the corner instead of the center).

In [4]:
# Translate every particle by half a box edge so that all coordinates are
# positive (the box then spans [0, box_edge] along each axis).
# NOTE: this cell is not idempotent -- running it a second time shifts the
# particles again. Re-run the setup cell first if you need to re-execute it.
positions = positions + box_edge / 2

# Cell list method
This will create a cell list which breaks the box into 27 boxes of edge size 1 (`cutoff_c`). Because of how we placed the particles, each box will contain exactly one particle.

In [5]:
# Ids of the particles whose neighbors are queried below: particles 1..9
# (note that particle 0 is NOT part of this query set).
idxs = jnp.arange(1, 10)

# Bin the particles into cells of edge `cutoff_c`.
cell_list = get_cell_list(positions, box_size, cutoff_c)

## Method 1

Using `get_neighbors_list` we can get the neighbors for a list of particles. Under the hood, this function calls `vmap` on another function, `get_neighbor_ids`, which is implemented for a single particle (plus some postprocessing).

In [6]:
# Cell-list neighbors for every particle in `idxs`, one row per queried
# particle (row k belongs to particle idxs[k]).
nbors = get_neighbors_list(
    box_size=box_size,
    cutoff=cutoff_c,
    cell_idx=cell_list,
    idxs=idxs,
    buffer_size_cell=buffer,
    # With mask_self=False each row also lists the queried particle itself
    # (every row printed below contains all 27 ids).
    mask_self=False,
)
print(nbors)

[[26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10  9  8  7  6  5  4  3
   2  1  0]
 [26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10  9  8  7  6  5  4  3
   2  1  0]
 [26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10  9  8  7  6  5  4  3
   2  1  0]
 [26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10  9  8  7  6  5  4  3
   2  1  0]
 [26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10  9  8  7  6  5  4  3
   2  1  0]
 [26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10  9  8  7  6  5  4  3
   2  1  0]
 [26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10  9  8  7  6  5  4  3
   2  1  0]
 [26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10  9  8  7  6  5  4  3
   2  1  0]
 [26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10  9  8  7  6  5  4  3
   2  1  0]]


## Method 2

Using `get_neighbor_ids` directly with an external for loop and the postprocessing outside of the neighbor list.

In [7]:
# Same query as Method 1, but one particle at a time, with the
# postprocessing done here instead of inside the library call.
nbors_2 = []
for i in idxs:
    n_i = get_neighbor_ids(
        box_size=box_size,
        cutoff=cutoff_c,
        cell_idx=cell_list,
        idx=i,
        buffer_size_cell=buffer,
    )
    # Drop the particle itself and any -1 entries
    # (presumably padding that fills the fixed-size buffer -- TODO confirm).
    is_real_neighbor = (n_i != i) & (n_i != -1)
    n_i = n_i[is_real_neighbor]
    nbors_2.append(n_i)
print(nbors_2)

[Array([24,  6, 15, 25,  7, 16, 26,  8, 17, 18,  0,  9, 19, 10, 20,  2, 11,
       21,  3, 12, 22,  4, 13, 23,  5, 14], dtype=int32), Array([25,  7, 16, 26,  8, 17, 24,  6, 15, 19,  1, 10, 20, 11, 18,  0,  9,
       22,  4, 13, 23,  5, 14, 21,  3, 12], dtype=int32), Array([20,  2, 11, 18,  0,  9, 19,  1, 10, 23,  5, 14, 21, 12, 22,  4, 13,
       26,  8, 17, 24,  6, 15, 25,  7, 16], dtype=int32), Array([18,  0,  9, 19,  1, 10, 20,  2, 11, 21,  3, 12, 22, 13, 23,  5, 14,
       24,  6, 15, 25,  7, 16, 26,  8, 17], dtype=int32), Array([19,  1, 10, 20,  2, 11, 18,  0,  9, 22,  4, 13, 23, 14, 21,  3, 12,
       25,  7, 16, 26,  8, 17, 24,  6, 15], dtype=int32), Array([23,  5, 14, 21,  3, 12, 22,  4, 13, 26,  8, 17, 24, 15, 25,  7, 16,
       20,  2, 11, 18,  0,  9, 19,  1, 10], dtype=int32), Array([21,  3, 12, 22,  4, 13, 23,  5, 14, 24,  6, 15, 25, 16, 26,  8, 17,
       18,  0,  9, 19,  1, 10, 20,  2, 11], dtype=int32), Array([22,  4, 13, 23,  5, 14, 21,  3, 12, 25,  7, 16, 26, 17, 24,  

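Both methods find all 27 particles in the cell neighborhood of each queried particle (Method 2 then removes the particle itself, leaving 26 neighbors). This is the expected result because of the way the positions are set up: the box is only three cells wide in each direction, so the 27-cell neighborhood of any cell covers the whole box.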

# Verlet list

A Verlet list can be used together with a cell list to get exact cutoffs. The cell list captures all particles in the 27 neighboring cells, regardless of the exact cutoff selected. This hybrid method is more efficient than calculating pairwise distances for all particles and doesn't require a skin radius.

## Method 1
Here we get the particle ids of the cell-list neighbors of particle $0$ and then compute its exact neighborhood (`cutoff_v`) with a Verlet list using `get_neighborhood`. This function returns a list of `bool`s that shows whether each particle in the cell-list neighbor list is also within the Verlet cutoff.

In [8]:
# Particle whose exact (Verlet) neighborhood we want.
query_id = 0

# Ask the cell list for the candidates of `query_id` explicitly.
# `nbors[0]` must not be reused here: it is the row for idxs[0] (particle 1),
# not for particle 0. It only happened to give the right answer because, in
# this small box, every particle's cell neighborhood contains all particles.
atom0_candidates = get_neighbors_list(
    box_size=box_size,
    cutoff=cutoff_c,
    cell_idx=cell_list,
    idxs=jnp.asarray([query_id]),
    buffer_size_cell=buffer,
    mask_self=False,
)[0]

# Drop -1 entries (presumably buffer padding -- TODO confirm).
atom0_n_ids = atom0_candidates[atom0_candidates != -1]

# Boolean mask over the candidates: True where the candidate lies within
# `cutoff_v` of the query particle.
neighbors = get_neighborhood(positions[atom0_n_ids, :], positions[query_id, :], cutoff_v, box_size)
print(atom0_n_ids[jnp.where(neighbors)])

[18  9  6  3  2  1  0]


## Method 2
We can also get the neighborhood for all particles in a position matrix. The result is an $N \times N$ matrix (or a list of size $N$ if `sparse = True`).

In [9]:
# Dense form: an N x N boolean matrix; entry (i, j) is True when
# particle j is within `cutoff_v` of particle i.
n_ids_verlet = get_verlet_neighbor_ids(
    positions,
    cutoff_v,
    box_size,
    sparse=False,
)
print(n_ids_verlet)

[[ True  True  True  True False False  True False False  True False False
  False False False False False False  True False False False False False
  False False False]
 [ True  True False False  True False False  True False False  True False
  False False False False False False False  True False False False False
  False False False]
 [ True  True  True False False  True False False  True False False  True
  False False False False False False False False  True False False False
  False False False]
 [ True False False  True  True  True False False False False False False
   True False False False False False False False False  True False False
  False False False]
 [False  True False  True  True False False False False False False False
  False  True False False False False False False False False  True False
  False False False]
 [False False  True  True  True  True False False False False False False
  False False  True False False False False False False False False  True
  False

In [10]:
# Sparse form: a list with one entry per particle, each an array holding
# the ids of the particles within `cutoff_v` (the particle itself included,
# as the printed output shows).
n_ids_verlet_s = get_verlet_neighbor_ids(
    positions,
    cutoff_v,
    box_size,
    sparse=True,
)
print(n_ids_verlet_s)

[Array([ 0,  1,  2,  3,  6,  9, 18], dtype=int32), Array([ 0,  1,  4,  7, 10, 19], dtype=int32), Array([ 0,  1,  2,  5,  8, 11, 20], dtype=int32), Array([ 0,  3,  4,  5, 12, 21], dtype=int32), Array([ 1,  3,  4, 13, 22], dtype=int32), Array([ 2,  3,  4,  5, 14, 23], dtype=int32), Array([ 0,  3,  6,  7,  8, 15, 24], dtype=int32), Array([ 1,  4,  6,  7, 16, 25], dtype=int32), Array([ 2,  5,  6,  7,  8, 17, 26], dtype=int32), Array([ 0,  9, 10, 11, 12, 15], dtype=int32), Array([ 1,  9, 10, 13, 16], dtype=int32), Array([ 2,  9, 10, 11, 14, 17], dtype=int32), Array([ 3,  9, 12, 13, 14], dtype=int32), Array([ 4, 10, 12, 13], dtype=int32), Array([ 5, 11, 12, 13, 14], dtype=int32), Array([ 6,  9, 12, 15, 16, 17], dtype=int32), Array([ 7, 10, 13, 15, 16], dtype=int32), Array([ 8, 11, 14, 15, 16, 17], dtype=int32), Array([ 0,  9, 18, 19, 20, 21, 24], dtype=int32), Array([ 1, 10, 18, 19, 22, 25], dtype=int32), Array([ 2, 11, 18, 19, 20, 23, 26], dtype=int32), Array([ 3, 12, 18, 21, 22, 23], dtype