In [1]:
%load_ext autoreload
%load_ext line_profiler
%autoreload 2

In [2]:
from cell_list import get_cell_list, get_neighbors_list, get_neighbor_ids
from verlet_list import get_neighbor_ids as get_verlet_neighbor_ids
from verlet_list import get_neighborhood, _pairwise_dist
from jax import random
import jax.numpy as jnp
from jax_md.partition import neighbor_list
from jax_md.space import periodic

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


## Setting up a simple box

This will create a $5 \times 5 \times 5$ box and puts 100 random particles in it.

In [3]:
box_edge = 5.0
box_size = jnp.array([box_edge, box_edge, box_edge])
positions = random.uniform(random.PRNGKey(0), (int(1e2), 3))*box_edge
cutoff_c = 1.0 # cutoff for cell list
cutoff_v = 1.0 # cutoff for verlet list
buffer = 30

# Cell list method
This will create a cell list which breaks the box into 125 boxes of edge size 1 (`cutoff_c`).

In [4]:
cell_list = get_cell_list(positions, box_size, cutoff_c)

## Method 1

Using `get_neighbors_list` we can get the neighbors for a list of particles. Under the hood, this functions calls `vmap` on another function `get_neighbor_ids` which is implemented for a single particle (+ some postprocessing).

In [5]:
idxs = jnp.asarray([i for i in range(10)]) # only the first 10 particles
nbors = get_neighbors_list(box_size=box_size, cutoff=cutoff_c, cell_idx=cell_list, idxs=idxs, buffer_size_cell=buffer, mask_self=False)
print(nbors) # padded with -1

[[99 95 93 85 78 77 75 71 70 69 68 63 62 56 38 32 31 30 12  2  0 -1 -1 -1
  -1 -1 -1 -1]
 [99 98 94 88 82 80 79 73 67 66 58 57 54 44 43 42 41 34 29 24 22 20 11 10
   8  7  4  1]
 [96 94 93 89 76 71 65 55 47 39 37 25 21 14  9  2  0 -1 -1 -1 -1 -1 -1 -1
  -1 -1 -1 -1]
 [92 81 80 79 75 72 60 56 54 48 42 40 35 33 30 28 25 24 23 18 17 12 11  3
  -1 -1 -1 -1]
 [99 98 94 88 82 80 79 73 67 66 58 57 54 44 43 42 41 34 29 24 22 20 11 10
   8  7  4  1]
 [96 91 86 83 69 68 61 59 55 52 50 45 38 32 27 26 19 16 15 13 12 10  7  5
  -1 -1 -1 -1]
 [94 89 88 87 69 62 45 41 39 38 33 28 22  8  6 -1 -1 -1 -1 -1 -1 -1 -1 -1
  -1 -1 -1 -1]
 [99 95 88 85 82 78 70 65 58 57 51 50 45 41 34 31 22 20 10  8  7  5  4  1
  -1 -1 -1 -1]
 [94 90 89 88 87 82 69 67 58 51 45 41 39 38 34 28 22 20 10  8  7  6  4  1
  -1 -1 -1 -1]
 [99 92 86 84 78 74 71 70 64 63 60 52 37 36 35 15 14  9  2 -1 -1 -1 -1 -1
  -1 -1 -1 -1]]


## Method 2

Using `get_neighbor_ids` directly with an external for loop and the postprocessing outside of the neighbor list.

In [6]:
nbors_2 = []
for i in idxs:
    n_i = get_neighbor_ids(box_size=box_size, cutoff=cutoff_c, cell_idx=cell_list, idx=i, buffer_size_cell=buffer)
    n_i = n_i[n_i != i]
    n_i = n_i[n_i != -1]
    nbors_2.append(n_i)
print(jnp.concatenate(nbors_2, axis=0)) # no padding

[99 38 69 85 62 70 78 31 95 32 68 63 71  2 77 93 30 56 75 12 43 11 24 42
 67 44 73 94 29 66 54 79 80 22 88  4 34 58 82 57  8 41 20 98 99  7 10  9
 14 37 47 65  0 76 71 21 93 55 96 25 89 39 94 30 56 75 12 60 35 92 40 33
 28 17 18 11 24 42 25 48 81 72 54 79 80 23 43 11 24 42 67 44 73 94 29 66
 54 79 80 22 88  1 34 58 82 57  8 41 20 98 99  7 10 38 69  7 10 45 15 16
 83 32 68 50 52 86 27 91 59 55 96 13 12 19 26 61 89 39 94 33 28 22 88 87
  8 41 38 69 62 45 22 88  1  4 34 58 82 57  8 41 20 51 99 10 45 85 65 70
 78 31  5 50 95 89 67 39 94 28 90 22 88  1  4 34 58 82  6 87 41 20 51 38
 69  7 10 45 64 84 36 15 99 74 52 86 14 37 70 78 60 35 92 63 71  2]


# Verlet list

Verlet list can be used together with a cell list to get exact cutoffs. Cell list will capture all particles within a `cutoff_c`$\times$`cutoff_c`$\times$`cutoff_c` box, regardless of exact cutoff selected. This hybrid method is more efficient than calculating pairwise distance for all particles and doesn't require a skin radius.

## Method 1
Here we can get the particle ids of the neighbors of particle $0$ and get its exact neighborhood (`cutoff_v`) with a verlet list using `get_neighborhood`. This functions returns a list of `bool`s that shows whether the particles in the cell list neighbor list are within the verlet cutoff as well.

In [7]:
atom0_n_ids = nbors[0][nbors[0] != -1] # get neighbor ids of atom 0 and remove padding
neighbors = get_neighborhood(positions[atom0_n_ids, :], positions[0, :], cutoff_v, box_size)
print(atom0_n_ids[jnp.where(neighbors)]) # note that 0 is in the list too.

[93 70 63 31  0]


## Method 2
We can also get the neighborhood for all particles in a position matrix. The result is a $N \times N$ matrix (or a list of size $N$ if `sparse = True`)

In [8]:
n_ids_verlet = get_verlet_neighbor_ids(positions, cutoff_v, box_size, sparse=False)
print(n_ids_verlet)

[[ True False False ... False False False]
 [False  True False ... False False False]
 [False False  True ... False False False]
 ...
 [False False False ...  True False False]
 [False False False ... False  True False]
 [False False False ... False False  True]]


In [9]:
n_ids_verlet_s = get_verlet_neighbor_ids(positions, cutoff_v, box_size, sparse=True, mask_self=True)
for particle in n_ids_verlet_s:
    print(particle)

[31 63 70 93]
[ 4 10 58 66 82 88]
[76]
[11 17 35 60 72 92]
[ 1 34 43 54 58 66 88]
[45 50 91]
[ 8 38 41 87]
[10]
[ 6 89]
[14 37]
[ 1  7 41]
[ 3 17 24 35 42 72 81 92]
[18]
[55 96]
[ 9 37]
[16 36]
[15]
[ 3 11 33 48 60 72 81]
[12]
[27 52 86]
[41]
[96]
[23 58 66 88]
[22 88]
[11 34 42 43 44 57 67 73 92]
[89]
[61 90 97]
[19 47 52 83 91]
[87]
[33 40 75]
[56 60 74 75 95]
[ 0 63 70 77]
[45 68]
[17 29 56 60 75]
[ 4 24 42 43 44 57 58 66 67 82]
[ 3 11 60 92]
[15]
[ 9 14]
[ 6 45 69]
[89]
[29]
[ 6 10 20]
[11 24 34 44 57 73 92]
[ 4 24 34 44 58 67]
[24 34 42 43 57 67 73 92]
[ 5 32 38 69]
[76]
[27 83 98]
[17 72 79 81 84]
[51 53]
[ 5 59]
[49]
[19 27 86]
[49]
[ 4 57 66 79 80]
[13 96]
[30 33 60 74 75 95]
[24 34 42 44 54 73]
[ 1  4 22 34 43 66 67 82 88]
[50 64 74]
[ 3 17 30 33 35 56]
[26 90 97]
[]
[ 0 31 70 77 78]
[59 74 84 95]
[98]
[ 1  4 22 34 54 58 88]
[24 34 43 44 58 82]
[32 86]
[38 45]
[ 0 31 63 78 99]
[92]
[ 3 11 17 48 81]
[24 42 44 57]
[30 56 59 64 95]
[29 30 33 56 77]
[ 2 46]
[31 63 75]
[63 70]
[48 

## Comparison with `jax_md` neighbor list

Define a periodic box and create the neighbor list

In [10]:
disp_fn, shift_fn = periodic(box_size)
nb_fn = neighbor_list(disp_fn, box_size, cutoff_v)
nbs = nb_fn.allocate(positions)

Check that all the returned neighbors are the same from `verlet_list` and `jax_md`.

In [11]:
for idx, nbl in enumerate(n_ids_verlet_s):
    for n in nbl:
        assert n in nbs.idx[idx]
    for n in nbs.idx[idx]:
        if n != 100:
            assert n in nbl
    print(f"{idx}: OK")
print("All OK")

0: OK
1: OK
2: OK
3: OK
4: OK
5: OK
6: OK
7: OK
8: OK
9: OK
10: OK
11: OK
12: OK
13: OK
14: OK
15: OK
16: OK
17: OK
18: OK
19: OK
20: OK
21: OK
22: OK
23: OK
24: OK
25: OK
26: OK
27: OK
28: OK
29: OK
30: OK
31: OK
32: OK
33: OK
34: OK
35: OK
36: OK
37: OK
38: OK
39: OK
40: OK
41: OK
42: OK
43: OK
44: OK
45: OK
46: OK
47: OK
48: OK
49: OK
50: OK
51: OK
52: OK
53: OK
54: OK
55: OK
56: OK
57: OK
58: OK
59: OK
60: OK
61: OK
62: OK
63: OK
64: OK
65: OK
66: OK
67: OK
68: OK
69: OK
70: OK
71: OK
72: OK
73: OK
74: OK
75: OK
76: OK
77: OK
78: OK
79: OK
80: OK
81: OK
82: OK
83: OK
84: OK
85: OK
86: OK
87: OK
88: OK
89: OK
90: OK
91: OK
92: OK
93: OK
94: OK
95: OK
96: OK
97: OK
98: OK
99: OK
All OK
