In [1]:
import numpy as np
import scipy
from python.ADMPForce import ADMPGenerator
from scipy.stats import special_ortho_group
from python.utils import convert_cart2harm
import mpidplugin

mScales = np.array([0.0, 0.0, 0.0, 1.0])
pScales = np.array([0.0, 0.0, 0.0, 1.0])
dScales = np.array([0.0, 0.0, 0.0, 1.0])
rc = 8 # in Angstrom
ethresh = 1e-4


pdb = 'tests/samples/waterdimer_aligned.pdb'
xml = 'tests/samples/mpidwater.xml'
generator = ADMPGenerator(pdb, xml, rc, ethresh, mScales, pScales, dScales, )
# get a random geometry for testing
scipy.random.seed(1000)
R1 = special_ortho_group.rvs(3)
R2 = special_ortho_group.rvs(3)

positions = generator.positions
positions[0:3] = positions[0:3].dot(R1)
positions[3:6] = positions[3:6].dot(R2)
positions[3:] += np.array([3.0, 0.0, 0.0])


force = generator.create_force()
force.update()
force.kappa = 0.328532611

multipoles_lc = np.concatenate((np.expand_dims(force.mpid_params['charges'], axis=1), force.mpid_params['dipoles'], force.mpid_params['quadrupoles']), axis=1)
Q_lh = convert_cart2harm(multipoles_lc, lmax=2)
axis_types = force.mpid_params['axis_types']
axis_indices = force.mpid_params['axis_indices']


In [2]:
force.compile_reci_space_energy_and_force()



In [36]:
ene, f = force.calc_reci_space_energy_and_force()
print(ene, '\n', f)

3.8504010162464084 
 [[  4.81163492  -2.88011513 -10.71335151]
 [ -2.12240909   1.67451737   6.10969664]
 [ -3.08651289   1.45696349   4.54239251]
 [  4.82683478   1.44850801  -0.18667916]
 [ -0.66875967  -1.83414292  -1.98458272]
 [ -3.76073239   0.1343097    2.23245487]]


In [5]:
%timeit ene, f = force.calc_reci_space_energy_and_force()

8.75 ms ± 118 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [37]:
pme_reci = jit(pme_reciprocal_energy, static_argnums=(4,5,6,7))

In [None]:
ene0 = pme_reci(positions, force.box,  Q_lh, force.kappa, force.lmax, force.K1, force.K2, force.K3)
findiff = np.empty((6, 3))
delta = np.zeros((6,3))
for i in range(6):
    for j in range(3):
        delta[i][j] = 1
        findiff[i][j] = pme_reci(positions +delta, force.box,  Q_lh, force.kappa, force.lmax, force.K1, force.K2, force.K3)
        delta[i][j] = 0
print(findiff)

[[3.85040102 3.85040102 3.85040102]
 [3.85040102 3.85040102 3.85040102]
 [3.85040102 3.85040102 3.85040102]
 [3.85040102 3.85040102 3.85040102]
 [3.85040102 3.85040102 3.85040102]
 [3.85040102 3.85040102 3.85040102]]


In [23]:
print("Testing behaviour of the same function, under different jax wrappers")
print("====================================================================")
ene1, f = value_and_grad(pme_reciprocal_energy)(positions, force.box,  Q_lh, force.kappa, force.lmax, force.K1, force.K2, force.K3)
print("value_and_grad:      ", ene1)
ene2 = pme_reciprocal_energy(positions, force.box,  Q_lh, force.kappa, force.lmax, force.K1, force.K2, force.K3)
print("only value:          ", ene2)
ene3, f = jit(value_and_grad(pme_reciprocal_energy), static_argnums=(4,5,6,7))(positions, force.box,  Q_lh, force.kappa, force.lmax, force.K1, force.K2, force.K3)
print("jit, value_and_grad: ", ene3)
ene4 = jit(pme_reciprocal_energy, static_argnums=(4,5,6,7))(positions, force.box,  Q_lh, force.kappa, force.lmax, force.K1, force.K2, force.K3)
print("jit, only value:     ", ene4)

Testing behaviour of the same function, under different jax wrappers
value_and_grad:       3.8504010162466096
only value:           3.850401016246607
jit, value_and_grad:  3.8504010162464084
jit, only value:      3.8504010162402884
