In [None]:
import os
import ase
import math
import time
import numpy as np
import torch
from tqdm.autonotebook import tqdm
from torch_geometric.loader import DataLoader
import numpy as np

from nequip.data import AtomicDataDict
from nequip.scripts.deploy import load_deployed_model
from nequip.utils._global_options import _set_global_options
from nequip.utils import load_file, instantiate, Config
from nequip.train import Trainer

In [None]:
class Integrator:
    """Common base for the MD integrators in this notebook.

    Holds the time step and the inverse thermal energy ``beta``; the actual
    propagation rule is supplied by subclasses through ``timestep``.
    """

    # Physical constants used to express beta in mol/kJ.
    KBOLTZMANN = 1.38064852e-23  # J/K | 8.617333262e-5 eV/K
    AVOGADRO = 6.022140857e23    # 1/mol
    JPERKJ = 1000                # J/kJ
    NM_TO_A = 10.                # presumably nm -> Angstrom; not used inside this class

    def __init__(self, dt: float, temp: float):
        """Store the time step and derive beta from the temperature.

        Args:
            dt: time step in ps.
            temp: temperature in K (must be non-zero, it is divided by).
        """
        self.dt = dt
        # beta = 1 / (k_B * N_A * T), converted from mol/J to mol/kJ.
        self.beta = self.JPERKJ / self.KBOLTZMANN / self.AVOGADRO / temp  # mol/kJ

    def timestep(self, coords: torch.Tensor, momenta: torch.Tensor,
                 forces: torch.Tensor, masses: torch.Tensor):
        """Advance the system by one step; must be overridden by subclasses.

        Args:
            coords: positions, documented as (n_atoms, 3).
            momenta: momenta, documented as (n_atoms, 3).
            forces: forces, documented as (n_atoms, 3).
            masses: per-atom masses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

class Langevin(Integrator):
    """Integrator with a Langevin thermostat.

    Each step drifts the positions with the current momenta, kicks the
    momenta with the forces, then damps the momenta by ``exp(-friction * dt)``
    and adds Gaussian noise scaled so that the damping is compensated at the
    temperature encoded in ``beta``.
    """

    def __init__(
            self,
            dt: float,   # ps
            temp: float, # K
            friction: float = 10., # 1/ps
    ):
        """Precompute the damping and noise factors for a fixed time step.

        Args:
            dt: time step in ps.
            temp: temperature in K.
            friction: friction coefficient in 1/ps.
        """
        super().__init__(dt, temp)
        self.friction = friction
        # Fraction of the momentum that survives one step of friction.
        self.vscale = math.exp(-self.friction * self.dt)
        # Complementary weight of the random kick.
        self.noisescale = math.sqrt(1 - self.vscale * self.vscale)

    def timestep(
            self,
            coords: torch.Tensor,   # n_atoms, 3
            momenta: torch.Tensor, # n_atoms, 3
            forces: torch.Tensor,   # n_atoms, 3
            masses: torch.Tensor,   # n_atoms or n_atoms, 1
        ):
        """Propagate positions and momenta by one Langevin step.

        Args:
            coords: positions, (n_atoms, 3).
            momenta: momenta, (n_atoms, 3).
            forces: forces evaluated at ``coords``, (n_atoms, 3).
                NOTE(review): assumed to be in units consistent with ``dt``,
                ``masses`` and the kJ/mol-based ``beta`` - confirm.
            masses: per-atom masses, either (n_atoms,) or (n_atoms, 1).

        Returns:
            Tuple ``(coords_new, momenta_new)`` with the shapes of the inputs.
        """
        # A mass vector with one dimension fewer than the coordinates cannot
        # broadcast against (n_atoms, 3); give it a trailing axis so each
        # atom's mass applies to all of its Cartesian components.
        if masses.dim() == coords.dim() - 1:
            masses = masses.unsqueeze(-1)

        # Drift with the old momenta, kick with the forces at the old positions.
        coords_new = coords + momenta * self.dt / masses
        momenta_prime = momenta + self.dt * forces

        # Thermal velocity spread sqrt(1 / (beta * m)); multiplied by the mass
        # below to obtain a momentum. The noise is drawn in the dtype of the
        # state so the step does not silently change the state's precision.
        thermal_velocity = torch.sqrt(1. / self.beta / masses)
        gaussian = torch.randn(size=coords.size(), device=coords.device, dtype=coords.dtype)
        noise = thermal_velocity * gaussian
        momenta_new = self.vscale * momenta_prime + self.noisescale * noise * masses
        return coords_new, momenta_new


def get_edge_index(positions: torch.Tensor, r_max: float):
    """Return all directed pairs of distinct points lying within ``r_max``.

    The full pairwise distance matrix is built, its diagonal is set to +inf
    so that no point is paired with itself, and the indices of the entries
    that are ``<= r_max`` are returned.

    Args:
        positions: point coordinates, one row per point.
        r_max: inclusive distance cutoff.

    Returns:
        int64 tensor with one column per edge (shape ``(2, n_edges)`` for
        2-D ``positions``); both (i, j) and (j, i) are present.
    """
    displacements = positions.unsqueeze(1) - positions.unsqueeze(0)
    pair_distances = torch.norm(displacements, dim=-1)
    pair_distances.fill_diagonal_(float("inf"))
    within_cutoff = pair_distances <= r_max
    return torch.nonzero(within_cutoff).t().to(torch.long)

In [None]:
# Dictionary keys: cutoff entry in the deployed-model metadata, and the extra
# per-atom fields (momenta, masses) carried next to the model inputs.
R_MAX_KEY = "r_max"
MOMENTA_KEY = "momenta"
MASSES_KEY = "masses"

# Atomic number -> atomic mass for the elements handled here (1: H, 6: C).
# Values look like standard atomic weights (presumably amu - confirm).
atomic_type_to_mass = {
    1: 1.00797,
    6: 12.011,
}

# Factor applied to the model forces before integration; 1.0 uses them as-is.
# NOTE(review): confirm the model's energy/length units match the integrator's.
ENERGY_UNITS_TO_KJ_mol = 1.
# Number of MD steps to run.
n_steps = 1000

model_path = "deployed/mace-toluene-50k-exp1-deployed.pth"
# Keep the originally chosen second GPU when it exists; otherwise fall back to
# the first GPU, then to the CPU, so the notebook still runs on other machines.
device = "cuda:1" if torch.cuda.device_count() > 1 else ("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Load the deployed model and its metadata onto `device`.
# set_global_options=True: per the original note, passed so that no warning is
# raised about that setting.
# NOTE(review): the training-session cell below rebinds `model`, so in a
# top-to-bottom run this deployed model is replaced before it is used.
model, metadata = load_deployed_model(model_path, device=device, set_global_options=True)
model = model.to(device)
print("loaded deployed model.")
# r_max entry of the deployment metadata (presumably the model cutoff radius).
# Not referenced again in the visible code, which reads model_config['r_max'].
model_r_max = float(metadata[R_MAX_KEY])

In [None]:
# Training run to restore: directory of the session and checkpoint file name.
train_dir = "results/md17/mace-toluene-50k-exp1"
model_name = "best_model.pth"

# Apply the global options stored in the run's config.yaml; the config object
# is only needed for that call and is dropped right after.
session_config = Config.from_file(str(os.path.join(train_dir, "config.yaml")), defaults={})
_set_global_options(session_config)
del session_config

# load a training session model (this rebinds `model`, replacing whatever an
# earlier cell assigned to it) and switch it to evaluation mode on `device`
model, model_config = Trainer.load_model_from_training_session(traindir=train_dir, model_name=model_name)
model = model.to(device).eval()

In [None]:
# `import ase` does not by itself load the `ase.io` submodule; import it
# explicitly instead of relying on another package having done so.
import ase.io

# Starting structure.
xyz_data = ase.io.read("data/toluene.xyz")
atomic_numbers = xyz_data.get_atomic_numbers()

positions = torch.tensor(xyz_data.positions, device=device, dtype=torch.get_default_dtype())
# Masses are stored as a column, shape (n_atoms, 1), so they broadcast over x/y/z.
masses = torch.tensor([[atomic_type_to_mass[at]] for at in atomic_numbers], device=device, dtype=torch.get_default_dtype())
# rand_like inherits device and dtype from `positions`.
# NOTE(review): this draws velocities uniformly from [0, 1) for every
# component, so the initial net momentum is non-zero; a zero-mean thermal
# draw is probably what is wanted here - confirm.
momenta = torch.rand_like(positions) * masses

# Atomic number -> model atom-type index (1: H -> 0, 6: C -> 1).
# NOTE(review): presumably has to match the type ordering used in training.
atomic_number_to_atom_type = {
    1: 0,
    6: 1,
}
atom_type = torch.tensor([atomic_number_to_atom_type[an] for an in atomic_numbers], device=device)
# Every atom gets batch index 0.
batch = torch.zeros(len(positions), device=device, dtype=torch.long)
# Built from tensors that already live on `device`, so no extra transfer is needed.
edge_index = get_edge_index(positions=positions, r_max=model_config['r_max'])

# Model inputs plus the integrator's own state (momenta, masses).
data = {
    AtomicDataDict.POSITIONS_KEY: positions,
    MOMENTA_KEY: momenta,
    MASSES_KEY: masses,
    AtomicDataDict.ATOM_TYPE_KEY: atom_type,
    AtomicDataDict.BATCH_KEY: batch,
    AtomicDataDict.EDGE_INDEX_KEY: edge_index,
}

In [None]:
# Langevin MD: evaluate the model forces, advance positions/momenta by one
# step, write the new state back into `data`, repeat. Ctrl-C stops the run
# cleanly and leaves the last state in `data`.
integrator = Langevin(dt=0.01, temp=300.00)

with tqdm(total=n_steps) as pbar:
    try:
        for i in range(1, n_steps + 1):
            # The integrator state lives in `data`; read it from there rather
            # than from the model output, which is not guaranteed to echo the
            # extra momenta/masses keys back.
            coords = data[AtomicDataDict.POSITIONS_KEY]
            momenta = data[MOMENTA_KEY]
            masses = data[MASSES_KEY]

            # Pass a shallow copy so that any keys the model may add to its
            # input cannot leak into the next step.
            # NOTE(review): whether the model mutates its input is not visible here.
            out = model(dict(data))
            # Index (not .get) so a missing force entry fails with a clear KeyError.
            forces = out[AtomicDataDict.FORCE_KEY].detach() * ENERGY_UNITS_TO_KJ_mol

            # Detach before and after integrating: forces are presumably
            # obtained by differentiating w.r.t. the positions, and without
            # detaching each new position tensor would keep a reference to
            # the autograd history of all previous steps.
            coords_new, momenta_new = integrator.timestep(coords.detach(), momenta, forces, masses)

            data[AtomicDataDict.POSITIONS_KEY] = coords_new.detach()
            data[MOMENTA_KEY] = momenta_new.detach()
            # Atoms have moved, so rebuild the neighbour list; otherwise the
            # model keeps seeing the edges of the initial geometry.
            data[AtomicDataDict.EDGE_INDEX_KEY] = get_edge_index(
                positions=data[AtomicDataDict.POSITIONS_KEY],
                r_max=model_config['r_max'],
            )

            # TODO: features sketched in an earlier draft and not ported yet:
            #   - wrapping atoms back into the periodic cell before rebuilding the graph
            #   - rebuilding the graph only every N steps instead of every step
            #   - buffering coordinates/momenta every N steps and writing trajectory files
            #   - debug timing of model inference and integration

            pbar.update(1)
    except KeyboardInterrupt:
        print("Simulation manually stopped")