In [1]:
import warnings
warnings.simplefilter('ignore')

# Molecular Dynamics

In this tutorial we will cover how to use trained models to drive MD simulations.
For this purpose, apax offers two options: ASE and JaxMD. JaxMD can run on GPUs and TPUs and is therefore usually much faster.
Both will be covered below.

## Basic Model Training

First we need to train a model.
If you already have the parameters from tutorial 01, you can point the paths below to that model and skip ahead to the [ASE MD](#The-ASE-calculator) or the [JaxMD](#JaxMD) section.

In [2]:
!apax template train    # generating the config file in the cwd

In [3]:
from pathlib import Path
from apax.utils.datasets import download_benzene_DFT, mod_md_datasets
from apax.train.run import run
from apax.utils.helpers import mod_config
import yaml


# Download and modify the dataset
data_path = Path("project")
experiment = "benzene_md"

file_path = download_benzene_DFT(data_path)
file_path = mod_md_datasets(file_path)


# Modify the config file (can be done manually)
config_path = Path("config.yaml")

config_updates = {
    "n_epochs": 100,
    "data": {
        "experiment": experiment,
        "directory": str(data_path / "models"),
        "data_path": str(file_path),
        "energy_unit": "kcal/mol",
        "pos_unit": "Ang",
    }
}
config_dict = mod_config(config_path, config_updates)


# Write the updated config to disk so the CLI commands below can use it
with open(config_path, "w") as conf:
    yaml.dump(config_dict, conf, default_flow_style=False)


# Train model
run(config_dict)


Precomputing NL: 100%|███████████████████████████████████████| 1000/1000 [00:00<00:00, 12924.12it/s]
Precomputing NL: 100%|█████████████████████████████████████████| 100/100 [00:00<00:00, 11632.43it/s]
Epochs: 100%|██████████████████████████████████████| 100/100 [03:36<00:00,  2.17s/it, val_loss=0.31]
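
The `config_updates` dictionary above lists only the keys that differ from the template, including nested ones under `data`. A minimal sketch of what such a nested update amounts to is shown here. It assumes a recursive merge; `deep_update` is a hypothetical helper written for illustration, the template values are made up, and apax's own `mod_config` may behave differently.

```python
from copy import deepcopy


def deep_update(base: dict, updates: dict) -> dict:
    """Return a copy of `base` with `updates` merged in recursively.

    Hypothetical helper for illustration; not part of apax.
    """
    merged = deepcopy(base)
    for key, value in updates.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are mappings: merge them key by key.
            merged[key] = deep_update(merged[key], value)
        else:
            # Scalars, lists, and new keys overwrite or extend the copy.
            merged[key] = deepcopy(value)
    return merged


# Made-up template values, only to show the merge behaviour.
template = {"n_epochs": 10, "data": {"experiment": "apax", "batch_size": 32}}
updates = {"n_epochs": 100, "data": {"experiment": "benzene_md"}}

merged = deep_update(template, updates)
print(merged)
```

Keys that are not mentioned in the updates (here the made-up `batch_size`) keep their template values, which is why the update dictionary can stay short.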


## The ASE calculator

If you require some ASE features during your simulation, we provide an alternative to the JaxMD interface.

Please refer to the [ASE documentation](https://wiki.fysik.dtu.dk/ase/ase/calculators/calculators.html) to see how to use ASE calculators.

An ASE calculator of a trained model can be instantiated as follows.

In [4]:
from ase.io import read
from apax.md import ASECalculator
from ase.md.langevin import Langevin
from ase import units
import numpy as np
from ase.io.trajectory import Trajectory


# Read the starting structure and define the model path
atoms = read(file_path, index=0)
model_dir = data_path / f"models/{experiment}"


# Initialize the apax ASE calculator and assign it to the starting structure
calc = ASECalculator(model_dir=model_dir)
atoms.calc = calc


# perform MD simulation
dyn = Langevin(
    atoms=atoms,
    timestep=0.5 * units.fs,
    temperature_K=300,
    friction=0.01 / units.fs,
)

# Write every 100th step to a trajectory file
traj = Trajectory('example.traj', 'w', atoms)
dyn.attach(traj.write, interval=100)
dyn.run(1000)
traj.close()
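
After the run, a quick sanity check is to average the potential energy and temperature over the stored frames. The sketch below is one way to do that. `summarize_frames` is a hypothetical helper, not part of apax or ASE; it only relies on the `get_potential_energy()` and `get_temperature()` methods that ASE `Atoms` objects provide, and it assumes the trajectory was written with a calculator attached, as above.

```python
from statistics import mean


def summarize_frames(frames):
    """Return (mean potential energy, mean temperature) over a trajectory.

    `frames` is any iterable of objects exposing the ASE `Atoms` methods
    `get_potential_energy()` and `get_temperature()` (temperature in K).
    """
    frames = list(frames)
    if not frames:
        raise ValueError("trajectory contains no frames")
    energies = [frame.get_potential_energy() for frame in frames]
    temperatures = [frame.get_temperature() for frame in frames]
    return mean(energies), mean(temperatures)
```

With the `read` function imported above, it could be called as `summarize_frames(read("example.traj", index=":"))`, where `index=":"` loads all stored frames.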

## JaxMD

While the ASE interface is convenient and flexible, it is not meant for high-performance applications.
For these purposes, apax comes with an interface to [JaxMD](https://jax-md.readthedocs.io/en/main/#).
JaxMD is a high-performance molecular dynamics engine built on top of [Jax](https://jax.readthedocs.io/en/latest/index.html).
The CLI provides easy access to standard NVT and NPT simulations.
More complex simulation loops are relatively easy to build yourself in JaxMD (see their Colab notebooks for examples).
Trained apax models can be used as the `energy_fn` in such custom simulations.
If you have a suggestion for adding an MD feature or thermostat to the core of `apax`, feel free to open an issue on GitHub.


### Configuration
We can once again use the template command to give ourselves a quickstart.


In [5]:
!apax template md


Open the config and specify the starting structure and simulation parameters.
If you specify the data set file itself, the first structure of the data set is used as the initial structure.
Your `md_config.yaml` should look similar to this:

```yaml
ensemble:
    temperature: 300 # K
    
duration: 20_000 # fs
initial_structure: project/benzene_mod.xyz
```


In [6]:
from apax.utils.helpers import mod_config
import yaml


config_path = Path("md_config.yaml")

config_updates = {
    "initial_structure": str(file_path),    # change this if the model from tutorial 01 is used
    "duration": 1000,  # fs
    "ensemble": {
        "temperature": 300,
    }
}
config_dict = mod_config(config_path, config_updates)

with open("md_config.yaml", "w") as conf:
    yaml.dump(config_dict, conf, default_flow_style=False)


As with training configurations, we can use the `validate` command to ensure our input is valid before we submit the calculation.


In [7]:
!apax validate md md_config.yaml

Success!
md_config.yaml is a valid MD config.


### Running the simulation

The simulation can be started by passing the training config and the MD config to the `apax md` command:

In [8]:
!apax md config.yaml md_config.yaml

INFO | 12:39:38 | reading structure
INFO | 12:39:39 | Unable to initialize backend 'cuda': 
INFO | 12:39:39 | Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO | 12:39:39 | Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO | 12:39:39 | initializing model
INFO | 12:39:39 | loading checkpoint from /home/linux3_i1/segreto/uni/dev/apax/examples/project/models/benzene_md/best
INFO | 12:39:39 | Initializing new trajectory file at md/md.h5
INFO | 12:39:39 | initializing simulation
INFO | 12:39:41 | running simulation for 1.0 ps
Simulation: 100%|███████████████████████████████████| 2000/2000 [00:10<00:00, 183.72it/s, T=196.3 K]
INFO | 12:39:52 | simulation finished after elapsed time: 10.93 s




Here, `config.yaml` is the configuration file that was used to train the model.
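
The relation between `duration` and the number of integration steps is simple arithmetic. The sketch below assumes a 0.5 fs timestep, which is inferred from the log above (2000 steps for 1.0 ps) rather than read from the config; `n_steps` is a hypothetical helper, not an apax function.

```python
def n_steps(duration_fs: float, timestep_fs: float) -> int:
    """Number of MD steps for a given duration, rounded to the nearest integer."""
    if timestep_fs <= 0:
        raise ValueError("timestep must be positive")
    # round() guards against floating point noise such as 1999.9999999.
    return round(duration_fs / timestep_fs)


print(n_steps(1000, 0.5))    # 2000, matching the progress bar above
print(n_steps(20_000, 0.5))  # 40000 for the 20_000 fs example config
```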

During the simulation, a progress bar tracks the instantaneous temperature at each outer step.

`prog bar`
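
The instantaneous temperature follows from the kinetic energy via equipartition, `T = 2 * E_kin / (N_dof * k_B)`. The sketch below shows this in SI units. It assumes `N_dof = 3N`, since whether apax removes centre-of-mass or constrained degrees of freedom is not stated here; `instantaneous_temperature` is a hypothetical helper, not an apax or JaxMD function.

```python
import math

K_B = 1.380649e-23  # Boltzmann constant in J/K (exact SI value)


def instantaneous_temperature(masses, velocities):
    """Temperature in K from masses (kg) and velocity vectors (m/s)."""
    if len(masses) == 0 or len(masses) != len(velocities):
        raise ValueError("need one velocity vector per atom")
    # Kinetic energy: sum over atoms of 1/2 * m * |v|^2.
    e_kin = sum(0.5 * m * sum(c * c for c in v) for m, v in zip(masses, velocities))
    n_dof = 3 * len(masses)  # assumption: no degrees of freedom removed
    return 2.0 * e_kin / (n_dof * K_B)


# One carbon atom moving along x with the rms thermal speed at 300 K.
m_carbon = 12.011 * 1.66053906660e-27  # kg
v_x = math.sqrt(3 * K_B * 300.0 / m_carbon)
print(instantaneous_temperature([m_carbon], [(v_x, 0.0, 0.0)]))  # approximately 300 K
```

For a 12-atom molecule such as benzene, relative fluctuations of order `sqrt(2 / (3N))`, roughly 24%, are expected in the canonical ensemble, so a single reading such as `T=196.3 K` at a 300 K target does not by itself indicate a problem.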

### Observables

TODO

To remove all the created files and clean up your working directory, run

In [None]:
!rm -r project md config.yaml example.traj md_config.yaml