# Using E3_layers with TorchMD
This notebook is adapted from the official TorchMD examples: https://github.com/torchmd/torchmd.

## System setup

We use the `moleculekit` library to read the input topology, the starting coordinates, and the box dimensions.

In [None]:
from moleculekit.molecule import Molecule
import os

testdir = "../torchmd/test-data/prod_alanine_dipeptide_amber/"

# Build the molecule from its topology, then read the starting coordinates
# followed by the box dimensions into it.
mol = Molecule(os.path.join(testdir, "structure.prmtop"))
for extra_file in ("input.coor", "input.xsc"):
    mol.read(os.path.join(testdir, extra_file))

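Next, we load the force-field parameters from the topology file above. The forces in this notebook come from a neural network, but these parameters still provide quantities the simulation needs, such as the atom masses and atom types used below.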

In [None]:
from torchmd.forcefields.forcefield import ForceField
from torchmd.parameters import Parameters
import torch

precision = torch.float
# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# on machines without CUDA (previously this was hard-coded to "cuda:0").
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# TorchMD parameters derived from the AMBER topology. In this notebook they
# supply the masses (for the integrator) and the mapped atom types (for the
# model), not classical forces.
ff = ForceField.create(mol, os.path.join(testdir, "structure.prmtop"))
parameters = Parameters(ff, mol, precision=precision, device=device)

Now we can create a `System` object which will contain the state of the system during the simulation, including:
1. The current atom coordinates
1. The current box size
1. The current atom velocities
1. The current atom forces

In [None]:
from torchmd.integrator import maxwell_boltzmann
from torchmd.systems import System

n_replicas = 1

# Allocate the simulation state, then fill in the coordinates, the periodic
# box and 300 K Maxwell-Boltzmann initial velocities, in that order.
system = System(mol.numAtoms, nreplicas=n_replicas, precision=precision, device=device)
system.set_positions(mol.coords)
system.set_box(mol.box)
initial_velocities = maxwell_boltzmann(parameters.masses, T=300, replicas=n_replicas)
system.set_velocities(initial_velocities)

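Lastly, we create a force object that evaluates the potential on a given `System` state. Instead of a classical force field, it wraps an E3_layers model behind the same `compute(pos, box, forces)` interface. The trained weights are loaded in the following cell.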

In [None]:
from e3_layers.utils import build
from e3_layers import configs
from e3_layers.data import Batch, computeEdgeIndex

class MyClass():
    """TorchMD-style force provider backed by an E3_layers model.

    The integrator and minimizer used in this notebook interact with this
    object through its ``par`` attribute and its ``compute(pos, box, forces)``
    method, so the neural network is wrapped behind that interface.

    Args:
        config: E3_layers model config passed to ``build``. Its ``r_max`` is
            used as the neighbour cutoff unless ``r_max`` is given.
        atom_types: per-atom species fed to the model under the ``species`` key.
        parameters: TorchMD parameters, exposed as ``self.par``.
        r_max: optional cutoff radius overriding ``config.r_max``.
    """

    def __init__(self, config, atom_types, parameters, r_max=None):
        self.par = parameters  # information such as masses, used by the integrator
        self.atom_types = atom_types
        # Uses the notebook-level `device` defined in the setup cell.
        self.model = build(config).to(device)
        # Node count (number of entries in atom_types) stored as a (1, 1) tensor.
        self.n_nodes = torch.ones((1, 1), dtype=torch.long) * atom_types.shape[0]
        self.r_max = config.r_max if r_max is None else r_max

    def compute(self, pos, box, forces):
        """Evaluate every replica, writing forces in place.

        Args:
            pos: positions indexed by replica first (``pos[i]`` is one replica).
            box: periodic box; currently not passed to the model.
            forces: output tensor; ``forces[i, :]`` is overwritten for each
                replica ``i``.

        Returns:
            list of float: the potential energy of each replica, in order.
        """
        # Irreps layout of the inputs; identical for every replica.
        attrs = {'pos': ('node', '1x1o'), 'species': ('node', '1x0e')}
        energies = []
        # Previously only replica 0 was evaluated, which left the forces of
        # any other replica stale. Loop over all of them.
        for replica in range(len(pos)):
            data = {'pos': pos[replica], 'species': self.atom_types, '_n_nodes': self.n_nodes}
            batch = Batch(attrs, **data).to(device)
            # NOTE(review): the edge index is built without using `box`.
            # Confirm whether periodic images are needed for boxed systems.
            batch = computeEdgeIndex(batch, r_max=self.r_max)
            batch = self.model(batch)
            forces[replica, :] = batch['forces'].detach()
            energies.append(batch['energy'].item())
        return energies
    
# Default energy+force model configuration with a narrower hidden width.
config = configs.config_energy_force().model_config
config.n_dim = 32  # to prevent OOM

# These are TorchMD's mapped atom types. If the model was trained on an
# unmapped type encoding, a conversion would be needed here.
atom_types = parameters.mapped_atom_types
forces = MyClass(config=config, atom_types=atom_types, parameters=parameters)

In [None]:
# Path to the trained E3_layers checkpoint. The original notebook never
# defined it, so this cell raised NameError.
model_path = "model.pt"  # TODO: point this at your trained checkpoint
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
state_dict = torch.load(model_path, map_location=device)

# Remove any DistributedDataParallel 'module.' prefix so the keys match the
# bare model.
ddp_prefix = 'module.'
model_state_dict = {
    (key[len(ddp_prefix):] if key.startswith(ddp_prefix) else key): value
    for key, value in state_dict.items()
}
forces.model.load_state_dict(model_state_dict)

## Dynamics

To perform the dynamics, we create an `Integrator` object, which integrates the time steps of the simulation, and a `Wrapper` object, which wraps the system coordinates back into the periodic cell.

In [None]:
from torchmd.integrator import Integrator
from torchmd.wrapper import Wrapper

# Langevin thermostat settings and the integration time step
langevin_temperature = 300  # K
langevin_gamma = 0.1
timestep = 1  # fs

integrator = Integrator(
    system, forces, timestep, device,
    gamma=langevin_gamma, T=langevin_temperature,
)
# Pass None instead of an empty bond list when the topology has no bonds.
bonds = mol.bonds if len(mol.bonds) else None
wrapper = Wrapper(mol.numAtoms, bonds, device)

In [None]:
from torchmd.minimizers import minimize_bfgs

# Relax the starting structure under the learned potential before running MD.
n_min_steps = 500
minimize_bfgs(system, forces, steps=n_min_steps)

Create a CSV file logger for the simulation which keeps track of the energies and temperature.

In [None]:
from torchmd.utils import LogWriter

# Columns written for each logged row of the simulation.
log_keys = ('iter', 'ns', 'epot', 'ekin', 'etot', 'T')
logger = LogWriter(path="logs/", keys=log_keys, name='monitor.csv')

Now we can finally run the full dynamics.

In [None]:
from tqdm import tqdm 
import numpy as np

FS2NS = 1E-6  # Femtosecond to nanosecond conversion

steps = 1000
output_period = 10
save_period = 100
traj = []

trajectoryout = "mytrajectory.npy"

# One iteration advances output_period steps. Trailing steps that do not fill
# a whole output period are not simulated.
iterator = tqdm(range(1, steps // output_period + 1))
# Populate system.forces and Epot before the first integrator step.
# Presumably the integrator reads the current forces at the start of a step;
# TODO confirm against torchmd.integrator.
Epot = forces.compute(system.pos, system.box, system.forces)
for i in iterator:
    Ekin, Epot, T = integrator.step(niter=output_period)
    wrapper.wrap(system.pos, system.box)
    currpos = system.pos.detach().cpu().numpy().copy()
    traj.append(currpos)

    # Periodic checkpoint of the whole trajectory collected so far.
    if (i * output_period) % save_period == 0:
        np.save(trajectoryout, np.stack(traj, axis=2))

    logger.write_row({'iter':i*output_period,'ns':FS2NS*i*output_period*timestep,'epot':Epot,'ekin':Ekin,'etot':Epot+Ekin,'T':T})

# Final save so frames after the last multiple of save_period are not lost.
if traj:
    np.save(trajectoryout, np.stack(traj, axis=2))