# TorchMD Simulation

This notebook energy-minimizes a small water box and then runs a short Langevin dynamics simulation of it with TorchMD.

In [1]:
import torch
import numpy as np

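Seed the PyTorch and NumPy random number generators so that runs are reproducible. (Some GPU operations can remain nondeterministic even with fixed seeds.)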

In [2]:
SEED = 1  # single seed shared by the PyTorch and NumPy global RNGs
torch.manual_seed(SEED)
np.random.seed(SEED)

## Build Test System

The files defining the water box are part of the `torchmd` test suite, and the setup is defined in the module `waterbox.py`.
The box contains 96 water molecules and one ion pair (a sodium cation and a chloride anion).

In [3]:
from waterbox import WaterBox

# Run on the first GPU when one is available; otherwise fall back to the CPU so
# the notebook still executes (more slowly) on machines without CUDA.
# NOTE(review): assumes WaterBox accepts any torch device string — TODO confirm.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
waterbox = WaterBox(nreplicas=1, device=device)

## Energy Minimization

In [4]:
from torchmd.minimizers import minimize_bfgs

# Relax the starting geometry before dynamics. minimize_bfgs prints a table of
# potential energy (Epot) and maximum force (fmax) per iteration.
minimization_steps = 500  # presumably an upper bound; the run shown below stops at iteration 309
minimize_bfgs(waterbox.system, waterbox.forces, steps=minimization_steps)

Iter  Epot            fmax    
   0   -691.583104    57.640355
   1   -249.736811    212.411990
   2   -731.377728    19.837357
   3   -740.159184    20.139690
   4   -786.285572    43.255250
   5   -792.541717    66.349296
   6   -815.886303    22.912048
   7   -830.756328    21.255851
   8   -842.972508    26.154271
   9   -852.190711    18.151096
  10   -857.584259    92.193282
  11   -871.894553    28.030490
  12   -882.560371    23.991500
  13   -893.327520    46.935104
  14   -897.904242    31.250359
  15   -903.884071    12.794050
  16   -910.240080    20.790360
  17   -917.728850    23.770905
  18   -927.771507    25.981902
  19   -933.377868    28.852858
  20   -941.155280    16.347461
  21   -945.924789    15.937664
  22   -951.955511    17.664942
  23   -957.618717    48.166647
  24   -968.034143    20.388571
  25   -975.991177    19.592388
  26   -982.013637    14.397089
  27   -984.929809    30.084176
  28   -990.443739    14.777269
  29   -996.866784    15.398847
  30   -

 280   -1226.104091    5.546917
 281   -1226.280714    3.453054
 282   -1226.395066    4.422568
 283   -1226.475216    8.147781
 284   -1226.641286    3.621295
 285   -1226.783197    2.174641
 286   -1226.895291    2.626520
 287   -1226.997881    3.811947
 288   -1227.077061    2.015832
 289   -1227.185835    1.812612
 290   -1227.275335    3.052026
 291   -1227.385819    2.327054
 292   -1227.431608    3.056663
 293   -1227.519760    1.719529
 294   -1227.551042    1.539336
 295   -1227.642719    1.609306
 296   -1227.688440    4.066972
 297   -1227.759590    1.332864
 298   -1227.807063    0.998892
 299   -1227.850369    2.098982
 300   -1227.908440    2.989852
 301   -1227.977466    1.747174
 302   -1228.040340    1.200483
 303   -1228.090601    2.115769
 304   -1228.130384    1.343683
 305   -1228.159857    1.075215
 306   -1228.231449    1.595242
 307   -1228.264136    2.795175
 308   -1228.303523    1.174278
 309   -1228.342720    0.671764


## Prepare Simulation

In [5]:
from torchmd.integrator import Integrator
from torchmd.wrapper import Wrapper

# Langevin thermostat and integration settings.
langevin_temperature = 300.0  # K
langevin_gamma = 1            # friction coefficient (units as defined by torchmd's Integrator — TODO confirm)
timestep = 1                  # fs

integrator = Integrator(
    waterbox.system,
    waterbox.forces,
    timestep,
    waterbox.device,
    gamma=langevin_gamma,
    T=langevin_temperature,
)

# The wrapper receives the bond list only when the molecule has bonds;
# an empty bond list is passed as None.
wrap_bonds = waterbox.mol.bonds if len(waterbox.mol.bonds) > 0 else None
wrapper = Wrapper(waterbox.mol.numAtoms, wrap_bonds, waterbox.device)

In [6]:
from torchmd.utils import LogWriter

# Columns written once per output frame by the simulation loop.
log_keys = ('iter', 'ns', 'epot', 'ekin', 'etot', 'T')
logger = LogWriter(path="logs/", keys=log_keys, name='monitor.csv')

Writing logs to  logs/monitor.csv


## Run Simulation

In [7]:
from tqdm import tqdm
import numpy as np

FS2NS = 1E-6  # Femtosecond to nanosecond conversion

steps = 10010       # total number of integration steps
output_period = 10  # integration steps per recorded frame / log row
save_period = 100   # checkpoint the trajectory file every this many steps
traj = []

# Each recorded frame stacks positions and velocities (index [0], presumably
# replica 0), so the saved array is indexed as (frame, [pos, vel], ...).
trajectoryout = "xyz_vel.npy"

system = waterbox.system
forces = waterbox.forces

n_outputs = steps // output_period  # number of recorded frames
iterator = tqdm(range(1, n_outputs + 1))

# Evaluate forces once before the first step so the integrator starts from
# forces matching the minimized positions (presumably written into
# system.forces in place). The returned energy is replaced inside the loop.
Epot = forces.compute(system.pos, system.box, system.forces)
for i in iterator:
    Ekin, Epot, T = integrator.step(niter=output_period)
    wrapper.wrap(system.pos, system.box)
    currpos = system.pos.detach().cpu().numpy().copy()
    currvel = system.vel.detach().cpu().numpy().copy()
    traj.append(np.stack([currpos[0], currvel[0]], axis=0))

    # Checkpoint periodically, and always after the final frame, so the file is
    # complete even when `steps` is not a multiple of `save_period`.
    if (i * output_period) % save_period == 0 or i == n_outputs:
        np.save(trajectoryout, np.stack(traj, axis=0))

    logger.write_row({
        'iter': i * output_period,
        'ns': FS2NS * i * output_period * timestep,
        'epot': Epot,
        'ekin': Ekin,
        'etot': Epot + Ekin,
        'T': T
    })

100%|██████████| 1001/1001 [00:56<00:00, 17.67it/s]
