In [1]:
import jax
import numpy as np
from deepmd_jax import md
# NOTE(review): JAX documents that x64 mode should be enabled at startup,
# before any arrays are created. Confirm that importing deepmd_jax above does
# not already run JAX computations; otherwise move this line above that import.
jax.config.update('jax_enable_x64', True)

# Validation data location. Every file below is resolved from these two
# prefixes instead of repeating the absolute path at each call site.
data_dir = '/root/water_128_val/'
path = data_dir + 'set.002/'

# First frame of the set, reshaped to one (x, y, z) row per atom (384 rows).
coord = np.load(path + 'coord.npy')[0].reshape(384, 3)
# Entry [0, 0] of box.npy; this value is what gets passed to Simulation as `box`.
box = np.load(path + 'box.npy')[0, 0]
# Atom type indices. genfromtxt yields floats by default, so request integers
# explicitly because these values are used as type labels.
type_idx = np.genfromtxt(data_dir + 'type.raw', dtype=int)

# NPT Nose-Hoover run set up from the first validation frame.
# NOTE(review): units of dt / temperature / pressure / tau are defined by
# deepmd_jax and are not visible here; confirm against its documentation.
sim = md.Simulation(
    model_path='/root/water.pkl',
    box=box,
    type_idx=type_idx,
    mass=[15.9994, 1.0078],  # one mass per atom type, in type-index order
    routine='NPT_Nose_Hoover',
    dt=0.01,
    initial_position=coord,
    temperature=300,
    pressure=1,
    debug=True,
    use_neighbor_list_when_possible=False,
    barostat_kwargs={'tau': 1000},
)
sim.report_interval = 1  # log every step

# Reference labels for the same frame, used to sanity-check the loaded model.
force = np.load(path + 'force.npy')[0].reshape(384, 3)
energy = np.load(path + 'energy.npy')[0]
force_pred = sim.getForce()
energy_pred = sim.getEnergy()

# Keep the starting velocities so later cells can diff against them.
init_velocity = sim._state.velocity

# Absolute energy error and root-mean-square force error versus the labels.
energy_abs_err = np.abs(energy - energy_pred)
force_rmse = ((force - force_pred) ** 2).mean() ** 0.5
print("Energy/Force error: ", energy_abs_err, force_rmse)

# Take a single step; the returned value is discarded.
_ = sim.run(1)

2024-11-02 23:57:12.708243: W external/xla/xla/service/gpu/nvptx_compiler.cc:930] The NVIDIA driver's CUDA version is 12.2 which is older than the PTX compiler version 12.6.77. Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


# DeepMD-JAX: Starting on 1 device(s): [CudaDevice(id=0)]
# Model loaded from '/root/water.pkl'.
# Lattice vectors for neighbor images: Max 1 out of 1 condidates.
Energy/Force error:  0.09313945774192689 12.687433180656557
# Running 1 steps...
Step     Temperature  KE           PE           Pressure     Invariant    Time  
0        315.324      15.651       -59878.626   -782.501     -59862.972   4.234 
1        316.423      15.706       -59878.629   -759.715     -59862.921   12.672
# Finished 1 steps in 0h 0m 16s.
# Performance: 0.000 ns/day, 0.000 step/μs/atom


{'position': array([[[0.4996, 0.2623, 0.9079],
         [0.5577, 0.2642, 0.9342],
         [0.5027, 0.308 , 0.8675],
         ...,
         [0.0123, 0.0653, 0.7637],
         [0.0429, 0.0727, 0.7127],
         [0.9743, 0.0151, 0.7547]],
 
        [[0.4996, 0.2623, 0.9079],
         [0.5577, 0.2642, 0.9342],
         [0.5027, 0.308 , 0.8675],
         ...,
         [0.0123, 0.0653, 0.7637],
         [0.0429, 0.0727, 0.7127],
         [0.9742, 0.0151, 0.7546]]]),
 'velocity': array([[[-0.0011, -0.0033, -0.004 ],
         [ 0.0085,  0.0037,  0.025 ],
         [-0.0073,  0.0292, -0.0065],
         ...,
         [-0.001 , -0.0011,  0.0001],
         [ 0.0146,  0.0003, -0.0108],
         [-0.0254,  0.0035, -0.0128]],
 
        [[-0.001 , -0.0033, -0.0039],
         [ 0.0074,  0.0034,  0.0245],
         [-0.0076,  0.0296, -0.0078],
         ...,
         [-0.0011, -0.0012,  0.0003],
         [ 0.0157,  0.0013, -0.013 ],
         [-0.0235,  0.0053, -0.0126]]]),
 'box': array([[15.6446, 15.6446

In [23]:
import jax_md
import jax.numpy as jnp

# Scratch check of jax_md's periodic_general: build the displacement function
# with one box (side 3) and then call it with a different per-call `box`
# (side 4). The recorded result of ~0.8 per component equals 0.2 * 4, which is
# consistent with the per-call box being the one applied.
#
# The construction-time box is named `demo_box` so that this cell does not
# overwrite the `box` value loaded from box.npy in the first cell.
demo_box = 3 * jnp.ones(3)
displacement_fn = jax_md.space.periodic_general(demo_box)[0]
x = jnp.ones(3)
displacement_fn(x * 0.1, -x * 0.1, box=jnp.ones(3) * 4)

Array([0.79999995, 0.79999995, 0.79999995], dtype=float32)

In [4]:
# Change in velocities relative to the snapshot stored in `init_velocity`.
sim._state.velocity - init_velocity

Array([[ 0.0001, -0.    ,  0.0001],
       [-0.0011, -0.0003, -0.0005],
       [-0.0002,  0.0004, -0.0013],
       ...,
       [-0.0001, -0.0002,  0.0001],
       [ 0.0012,  0.0011, -0.0021],
       [ 0.0019,  0.0019,  0.0001]], dtype=float64)

In [12]:
# Display the velocity array currently held in the simulation state.
sim._state.velocity

Array([[-0.0008,  0.003 , -0.002 ],
       [ 0.0083, -0.0034, -0.0047],
       [ 0.0211,  0.0019, -0.0085],
       ...,
       [ 0.0013, -0.0001, -0.0004],
       [-0.009 ,  0.0187, -0.0014],
       [ 0.0084, -0.008 , -0.0063]], dtype=float32)

In [4]:
# Advance the simulation by 10 steps and keep whatever run() returns.
trajectory = sim.run(10)
# Display the current pressure alongside the state's box_momentum.
sim.getPressure(), sim._state.box_momentum

# Running 10 steps...
Step     Temperature  KE           PE           Pressure     Invariant    Time  
0        274.130      13.607       -59878.629   -1352.789    -59865.020   4.651 
1        1199.632     59.545       -59883.211   13801.769    -59823.660   11.482
2        1184.543     58.796       -59881.645   25931.736    -59822.805   5.839 
3        2059.625     102.231      -59886.016   24758.908    -59783.637   0.004 
4        1729.469     85.844       -59883.703   13953.556    -59797.496   0.004 
5        2208.787     109.635      -59886.133   21642.129    -59775.801   0.004 
6        2334.092     115.855      -59885.883   34772.832    -59768.816   0.004 
7        2667.092     132.383      -59887.648   35368.426    -59753.324   0.004 
8        2351.886     116.738      -59886.184   22831.166    -59766.602   0.004 
9        2079.926     103.239      -59885.617   18484.316    -59778.617   0.005 
10       1907.387     94.675       -59884.820   26527.699    -59785.488   0.005 
# Fini

(Array(26527.7, dtype=float32), Array(805.9974, dtype=float32))

In [7]:
# Display the box stored in the simulation state.
sim._state.box

Array([15.6446, 15.6446, 15.6446], dtype=float32)

In [15]:
# Display the position array stored in the simulation state.
# NOTE(review): the recorded values all lie in [0, 1), so these look like
# fractional coordinates. Confirm against deepmd_jax.
sim._state.position

Array([[0.8153, 0.1035, 0.2032],
       [0.7251, 0.1329, 0.6143],
       [0.865 , 0.8189, 0.5708],
       ...,
       [0.1928, 0.0213, 0.9472],
       [0.6706, 0.1383, 0.1499],
       [0.2421, 0.2359, 0.8061]], dtype=float32)

In [9]:
# Display the velocity array currently held in the simulation state.
sim._state.velocity

Array([[-0.0008,  0.003 , -0.002 ],
       [ 0.0083, -0.0034, -0.0047],
       [ 0.0211,  0.0019, -0.0085],
       ...,
       [ 0.0013, -0.0001, -0.0004],
       [-0.009 ,  0.0187, -0.0014],
       [ 0.0084, -0.008 , -0.0063]], dtype=float32)

In [6]:
# Display the simulation's internal state object.
sim._state

NoseHooverChain(position=Array([0., 0., 0.], dtype=float32), momentum=Array([0., 0., 0.], dtype=float32), mass=Array([74453.76,    64.63,    64.63], dtype=float32), tau=50.0, kinetic_energy=Array(13.6067, dtype=float32), degrees_of_freedom=1152)

In [9]:
# Display the barostat component of the simulation state.
sim._state.barostat

NoseHooverChain(position=Array([-0., -0., -0.], dtype=float32), momentum=Array([-0.0129, -0.0129, -0.0129], dtype=float32), mass=Array([6462.9995, 6462.9995, 6462.9995], dtype=float32), tau=Array(500., dtype=float32, weak_type=True), kinetic_energy=Array(0., dtype=float32), degrees_of_freedom=1)

In [5]:
# Display the pressure reported by the simulation for its current state.
sim.getPressure()

Array(-1352.9045, dtype=float32)