In [None]:
import os
import torch
import pathlib
import ase.io
import numpy as np
from tqdm.auto import tqdm

from nequip.train import Trainer
from nequip.utils import Config
from nequip.data import AtomicData, Collater, dataset_from_config
from nequip.data import AtomicDataDict

In [None]:
# --- Run configuration -------------------------------------------------------
# Everything a reader may want to tune lives in this cell.

# Training session directory and the checkpoint file stored inside it.
args_train_dir = "results/nmr_prod_all/bmrb_mace_prod_all1/"
model_name = "best_model.pth"
args_model = os.path.join(args_train_dir, model_name)

# YAML file describing the test dataset (merged into the model config later).
args_dataset_config = str(pathlib.Path("configs") / "nmr" / "bmrb_prod_all_test.yaml")

# Device the model and the batch are moved to.
device = "cuda:2"

# Index of the dataset entry that gets minimized.
test_frame_index = 206

# Minimization controls: maximum number of update steps, the summed squared
# error below which the loop stops early, and the factor applied to the
# negative loss gradient in each position update.
minimization_max_steps = 1000
minimization_threshold_error = 1.0
dtau = 5e-5

In [None]:
# Restore the trained model from the training session, move it to the target
# device and put it in evaluation mode.
model, model_config = Trainer.load_model_from_training_session(
    model_name=model_name,
    traindir=args_train_dir,
)
model = model.to(device)
model.eval()

# Overlay the test-dataset settings on the config returned with the model.
test_config = Config.from_file(str(args_dataset_config), defaults={})
model_config.update(test_config)

dataset, _ = dataset_from_config(model_config, prefix="test_dataset")

# Identifier of the selected entry: base name of its file, up to the first dot.
# NOTE(review): the identifier is read from `dataset.datasets[i]` while the
# frame below is read from `dataset[i]`; these refer to the same entry only if
# both indexings agree (e.g. one frame per sub-dataset) -- confirm.
selected_file_name = dataset.datasets[test_frame_index].file_name
pdb_code = selected_file_name.rsplit('/', 1)[-1].split('.', 1)[0]

c = Collater.for_dataset(dataset, exclude_keys=[])

# Build a one-element batch holding only the selected frame.
test_idcs = torch.arange(len(dataset.datasets))
this_batch_test_indexes = test_idcs[test_frame_index : test_frame_index + 1]
datas = [dataset[frame_id] for frame_id in this_batch_test_indexes.tolist()]

batch = c.collate(datas).to(device)
input_ = AtomicData.to_AtomicDataDict(batch)

if AtomicDataDict.PER_ATOM_ENERGY_KEY in input_:
    # True for every node whose per-atom target is not NaN.
    has_target = ~torch.isnan(input_[AtomicDataDict.PER_ATOM_ENERGY_KEY].flatten())
    labelled_nodes = torch.argwhere(has_target).flatten()

    # Keep only the edges whose row-0 index refers to a node with a target,
    # and drop the matching rows of the cell-shift tensor.
    not_nan_edge_filter = torch.isin(
        input_[AtomicDataDict.EDGE_INDEX_KEY][0],
        labelled_nodes,
    )
    input_[AtomicDataDict.EDGE_INDEX_KEY] = input_[AtomicDataDict.EDGE_INDEX_KEY][:, not_nan_edge_filter]
    input_[AtomicDataDict.EDGE_CELL_SHIFT_KEY] = input_[AtomicDataDict.EDGE_CELL_SHIFT_KEY][not_nan_edge_filter]

    # Preserve the full batch vector, then restrict the working one to the
    # nodes that have a target.
    input_[AtomicDataDict.ORIG_BATCH_KEY] = input_[AtomicDataDict.BATCH_KEY].clone()
    input_[AtomicDataDict.BATCH_KEY] = input_[AtomicDataDict.BATCH_KEY][has_target]

In [None]:
# Gradient-descent refinement of the atomic positions.
#
# Each step: run the model, take the squared error between the predicted and
# the reference per-atom values (reference entries that are NaN are ignored),
# differentiate the summed error w.r.t. the positions and move the positions
# by `dtau` along the negative gradient. The loop ends after
# `minimization_max_steps` steps or as soon as the summed error drops below
# `minimization_threshold_error`.
#
# NOTE: within one step, "cs" and "loss" are recorded for the positions
# *before* the update, while "pos" holds the positions *after* it, so
# results["pos"][k] is one update ahead of results["cs"][k].
results = {
    "pos": [],   # positions after each update (independent numpy copies)
    "cs": [],    # flattened model predictions of each step
    "loss": [],  # summed squared error of each step
}

with tqdm(total=minimization_max_steps) as pbar:
    for i in range(minimization_max_steps):
        input_[AtomicDataDict.POSITIONS_KEY].requires_grad_(True)
        out_ = model(input_)

        pred_cs = out_[AtomicDataDict.PER_ATOM_ENERGY_KEY]
        target_cs = input_[AtomicDataDict.PER_ATOM_ENERGY_KEY]
        # Indices of the nodes that actually have a reference value.
        not_nan_node_filter = torch.argwhere(~torch.isnan(target_cs.flatten())).flatten()

        loss = torch.pow((pred_cs[not_nan_node_filter] - target_cs[not_nan_node_filter]), 2)
        # Reduce once and reuse for the gradient, the log and the stop test.
        total_loss = loss.sum()

        # Only a first-order gradient is needed: the result is used as a plain
        # update direction and never differentiated again, so no higher-order
        # graph is requested (create_graph defaults to False) and the forward
        # graph is released right after this call.
        forces = -torch.autograd.grad(
            [total_loss],
            [out_[AtomicDataDict.POSITIONS_KEY]],
        )[0]

        # Turn gradient tracking off again before the in-place position update.
        out_[AtomicDataDict.POSITIONS_KEY].requires_grad_(False)
        input_[AtomicDataDict.POSITIONS_KEY] += forces * dtau

        # `.numpy()` on a CPU tensor shares memory with the tensor; copy so
        # that the in-place updates of later steps cannot rewrite the history.
        results['pos'].append(input_[AtomicDataDict.POSITIONS_KEY].detach().cpu().numpy().copy())
        results['cs'].append(pred_cs.detach().cpu().numpy().flatten())
        results['loss'].append(total_loss.detach().cpu().numpy())

        pbar.update(1)
        if total_loss.item() < minimization_threshold_error:
            # Converged: fill the remainder of the progress bar and stop.
            pbar.update(minimization_max_steps - i - 1)
            break

In [None]:
def get_test_xyz_filename(args_train_dir, pdb_code):
    """Find the evaluation ``.xyz`` file that belongs to ``pdb_code``.

    The directory tree below ``args_train_dir`` is searched recursively for
    files matching ``*evaluation.xyz`` whose file name contains ``pdb_code``.
    When several files match, the first one in sorted path order is returned
    so that the choice does not depend on filesystem iteration order.

    Args:
        args_train_dir: Root directory to search.
        pdb_code: Substring that must occur in the file name.

    Returns:
        The path of the matching file as a string.

    Raises:
        FileNotFoundError: If no matching file exists under ``args_train_dir``.
    """
    matches = sorted(
        item
        for item in pathlib.Path(args_train_dir).rglob("*evaluation.xyz")
        if pdb_code in item.name
    )
    if not matches:
        raise FileNotFoundError(
            f"no '*evaluation.xyz' file containing {pdb_code!r} found under {args_train_dir!r}"
        )
    return str(matches[0])

# Locate the evaluation file of the selected entry and write the result of the
# minimization next to it as "<name>_minimized.xyz".
test_xyz_filename = get_test_xyz_filename(args_train_dir, pdb_code)

# Build the output name from the file name only. Cutting the whole path at its
# first dot would truncate it as soon as any directory contains a dot
# (e.g. "./results/..." or ".../v1.2/...").
test_minimized_xyz_filename = os.path.join(
    os.path.dirname(test_xyz_filename),
    os.path.basename(test_xyz_filename).split('.')[0] + "_minimized.xyz",
)

# Only the first frame is used as a template, so read just that frame instead
# of parsing the whole file; `.copy()` keeps the original's behavior of working
# on a detached copy of the frame.
test_xyz = ase.io.read(test_xyz_filename, index=0, format="extxyz").copy()
# Overwrite the template with the last recorded positions and predictions.
test_xyz.arrays['positions'] = results['pos'][-1]
test_xyz.arrays['energies'] = results['cs'][-1]

ase.io.write(
    test_minimized_xyz_filename,
    test_xyz,
    format="extxyz",
    append=False,
)
print(f"minimization file {test_minimized_xyz_filename} saved!")