In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import openmm.app as app
import openmm.unit as unit
from dmff import Hamiltonian, NeighborListFreud



In [4]:
path = '/Users/arminsh/Documents/GADES/examples/AlanineDipeptide/'
prmtop = app.AmberPrmtopFile(f"{path}/alanine-dipeptide.prmtop")
inpcrd = app.AmberInpcrdFile(f"{path}/alanine-dipeptide.inpcrd")

In [5]:
ff = Hamiltonian(f"{path}/protein.ff14SB.xml")

potentials = ff.createPotential(prmtop.topology, nonbondedMethod=app.NoCutoff, useDispersionCorrection=False)
params = ff.getParameters()
positions = jnp.array(inpcrd.getPositions(asNumpy=True).value_in_unit(unit.nanometer))

box = jnp.array([[20.0, 0.0, 0.0], [0.0, 20.0, 0.0], [0.0, 0.0, 20.0]])

nbList = NeighborListFreud(box, 3, potentials.meta["cov_map"]) # maybe remove the box?
nbList.allocate(positions)
pairs = nbList.pairs
efunc_all = potentials.getPotentialFunc()
dmff_e = efunc_all(positions, box, pairs, ff.paramset)

print("DMFF Pot E: ", dmff_e)

DMFF Pot E:  -55.794260834909345


In [6]:
params = ff.paramset # easy for implemenentation - must not change

efunc_pos = lambda pos: efunc_all(pos.reshape(-1, 3), box, pairs, ff.paramset) # for easy implementation 
grad_vec = jax.grad(efunc_pos)

force = grad_vec(positions.flatten())
print("Forces: \n", force)
force.shape

Forces: 
 [-1.71865366e+02 -3.18532912e+01  6.93394280e-01 -2.69338193e+02
 -3.51439108e+02 -2.29458474e+00 -5.28602801e+00  4.40961711e+01
  4.26381335e+01 -6.79627810e+00  4.33137449e+01 -4.09523736e+01
  4.52549770e+02  2.17764116e+02 -2.35854039e+02  6.63423153e+02
  4.01344132e+02 -3.88141697e+02  9.49967236e+01  7.69630908e+02
  4.01534867e+02 -4.90775901e+01 -4.57633388e+01  7.38115682e+00
 -4.28528076e+02 -3.13662037e+02 -1.89956133e+02  1.74440184e+01
  8.66267050e+00 -4.23307622e+01  1.04952670e+01  1.14678495e+02
  3.50646113e+02 -1.77497554e+01  2.76645948e+01  6.28447219e+01
 -8.25922834e+01  2.01080786e+02  1.30534383e+02 -4.08666513e+02
 -2.61975927e+02  3.33848622e+02  5.80858470e+01 -1.93137316e+01
 -2.68883324e+02  2.16593345e+02  4.39660829e+01  1.80234687e+01
  6.44641858e+01 -3.42606470e+02 -1.88895749e+02  1.40152566e+02
  1.03662316e+02  6.88386192e+00 -2.76119653e+02 -2.02536969e+02
  1.60757282e-01 -3.92006329e+01 -2.84702293e+02  9.29432738e-01
  1.83045552e+0

(66,)

In [7]:
hessian_vec = jax.hessian(efunc_pos)

hess = hessian_vec(positions.flatten())
print(hess.shape)

(66, 66)


In [10]:
# Compute the eigenvalues
w, v = jnp.linalg.eigh(hess)
w_idx = w.argsort()
w = w[w_idx]
v = v[:, w_idx]

print(w[0])

-1205.1526670440148
