In [1]:
import numpy as np
import scipy
from python.ADMPForce import ADMPGenerator
from scipy.stats import special_ortho_group
from python.utils import convert_cart2harm
import mpidplugin

# Per-exclusion-level scale factors handed to ADMPGenerator for the multipole (m),
# polarization (p) and dipole (d) terms; their exact meaning is defined by
# ADMPGenerator (presumably 1-2 / 1-3 / 1-4 / 1-5+ pairs -- TODO confirm).
mScales = np.array([0.0, 0.0, 0.0, 1.0])
pScales = np.array([0.0, 0.0, 0.0, 1.0])
dScales = np.array([0.0, 0.0, 0.0, 1.0])
rc = 8  # cutoff, in Angstrom
ethresh = 1e-4
SEED = 1000  # seed for the random test geometry below

pdb = 'tests/samples/waterdimer_aligned.pdb'
xml = 'tests/samples/mpidwater.xml'
generator = ADMPGenerator(pdb, xml, rc, ethresh, mScales, pScales, dScales)

# Random rotations to get a generic (non-symmetric) test geometry.
# `scipy.random` is only a deprecated alias of `numpy.random`, so seeding numpy's
# global RandomState directly reproduces exactly the same R1/R2 as before
# (special_ortho_group.rvs draws from that global state when no random_state is given).
np.random.seed(SEED)
R1 = special_ortho_group.rvs(3)
R2 = special_ortho_group.rvs(3)

# NOTE(review): if generator.positions is a NumPy array, `positions` is the same
# object (no copy), so the in-place edits below also modify the generator before
# create_force() is called. Kept as-is since create_force() may rely on that -- confirm.
positions = generator.positions
# Rotate rows 0-2 and rows 3-5 (presumably one water molecule each) about the
# coordinate origin, not about each molecule's own centroid.
positions[0:3] = positions[0:3].dot(R1)
positions[3:6] = positions[3:6].dot(R2)
# Rigidly shift everything from row 3 onward by 3.0 along x (same units as the PDB coordinates).
positions[3:] += np.array([3.0, 0.0, 0.0])


force = generator.create_force()
force.update()
# Hard-coded kappa (presumably the Ewald splitting parameter) overriding whatever
# the force object holds. NOTE(review): this is set *after* update(); if update()
# derives anything from kappa, those derived values still reflect the old one -- confirm.
force.kappa = 0.328532611

# Charges (as a column), dipoles and quadrupoles concatenated per site along
# axis 1, then converted from Cartesian to spherical-harmonic form up to lmax=2.
multipoles_lc = np.concatenate(
    (
        np.expand_dims(force.mpid_params['charges'], axis=1),
        force.mpid_params['dipoles'],
        force.mpid_params['quadrupoles'],
    ),
    axis=1,
)
Q_lh = convert_cart2harm(multipoles_lc, lmax=2)
axis_types = force.mpid_params['axis_types']
axis_indices = force.mpid_params['axis_indices']


In [2]:
# Prepare the reciprocal-space energy/force routine once, before the timed calls below.
# Presumably this triggers JIT compilation so later calls exclude compile time -- TODO confirm in ADMPForce.
force.compile_reci_space_energy_and_force()



In [3]:
# Reciprocal-space energy and forces for the current geometry, from the force object.
ene, f = force.calc_reci_space_energy_and_force()

In [4]:
# Reciprocal-space energy returned by the force object (units not visible here).
print(ene)

3.8504010162464084


In [5]:
# Time repeated reciprocal-space evaluations (assignments inside %timeit do not leak into the namespace).
# NOTE(review): if this returns JAX arrays, dispatch is asynchronous and the timing may exclude
# device work unless the results are blocked on (e.g. jax.block_until_ready) -- confirm.
%timeit ene, f = force.calc_reci_space_energy_and_force()

8.75 ms ± 118 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [1]:
# Check that the pure-JAX PME reciprocal-space energy gives the same result under
# different JAX transformations (plain, value_and_grad, jit, jit+value_and_grad).
# NOTE: this cell needs names created in the setup cell (np, positions, Q_lh,
# axis_types, axis_indices, force). Run that cell first; the NameError recorded
# below came from executing this cell alone in a fresh kernel.
from python.pme import gen_pme_reciprocal

from jax import value_and_grad, jit

pme_reciprocal_energy = gen_pme_reciprocal(axis_types, axis_indices)

# Positional arguments shared by every call below.
pme_args = (positions, force.box, Q_lh, force.kappa, force.lmax, force.K1, force.K2, force.K3)
# Indices of (lmax, K1, K2, K3) in pme_args, passed to jit as static arguments.
PME_STATIC_ARGNUMS = (4, 5, 6, 7)

print("Testing behaviour of the same function, under different jax wrappers")
print("====================================================================")
# value_and_grad differentiates w.r.t. argument 0 (positions) by default.
ene_vg, grad_vg = value_and_grad(pme_reciprocal_energy)(*pme_args)
print("value_and_grad:      ", ene_vg)
ene_plain = pme_reciprocal_energy(*pme_args)
print("only value:          ", ene_plain)
ene_jit_vg, grad_jit_vg = jit(value_and_grad(pme_reciprocal_energy), static_argnums=PME_STATIC_ARGNUMS)(*pme_args)
print("jit, value_and_grad: ", ene_jit_vg)
ene_jit = jit(pme_reciprocal_energy, static_argnums=PME_STATIC_ARGNUMS)(*pme_args)
print("jit, only value:     ", ene_jit)

# All wrappers evaluate the same function. Only floating-point reassociation from
# XLA compilation may differ, so the results must agree to a tight tolerance.
np.testing.assert_allclose([ene_vg, ene_jit_vg, ene_jit], ene_plain, rtol=1e-5)
np.testing.assert_allclose(grad_jit_vg, grad_vg, rtol=1e-5, atol=1e-5 * np.max(np.abs(grad_vg)))

# Preserve the names the original cell left behind, for later cells.
ene0, f = ene_jit, grad_jit_vg

NameError: name 'axis_types' is not defined