In [1]:
import os
import random
import shutil
import time

import numpy as np
import torch
from matplotlib import pyplot as plt
from tqdm import tqdm

import ase
from ase import Atoms
from ase.constraints import FixAtoms
from ase.io import read, write
from ase.optimize import QuasiNewton
from ase.optimize.lbfgs import LBFGS

import schnetpack
import schnetpack as spk
from schnetpack import properties
from schnetpack.interfaces.ase_interface import AtomsConverter, AseInterface
from schnetpack.interfaces.batchwise_optimization import ASEBatchwiseLBFGS, BatchwiseCalculator
from schnetpack.interfaces.ensemble_calculator import EnsembleCalculator

In [2]:
from typing import Optional, List, Union, Dict


In [8]:
class AseInterfaceCustom(AseInterface):
    """
    Interface for ASE calculations (optimization and molecular dynamics)
    backed by an ensemble of models.

    Loads one structure, optionally fixes some of its atoms, and attaches an
    ``EnsembleCalculator`` built from one or several model files. The parent
    constructor is not called. Instead this ``__init__`` sets ``working_dir``,
    ``molecule``, ``optimizer_class`` and ``dynamics`` itself. Presumably these
    are the attributes the inherited ``optimize`` relies on; confirm against
    ``AseInterface``.
    """

    def __init__(
        self,
        molecule_path: str,
        working_dir: str,
        model_file: Union[str, List[str]],
        neighbor_list: schnetpack.transform.Transform,
        energy_key: str = "energy",
        force_key: str = "forces",
        stress_key: Optional[str] = None,
        energy_unit: Union[str, float] = "kcal/mol",
        position_unit: Union[str, float] = "Angstrom",
        device: Union[str, torch.device] = "cpu",
        dtype: torch.dtype = torch.float32,
        converter: AtomsConverter = AtomsConverter,
        optimizer_class: type = QuasiNewton,
        fixed_atoms: Optional[List[int]] = None,
        transforms: Union[
            schnetpack.transform.Transform, List[schnetpack.transform.Transform]
        ] = None,
        additional_inputs: Dict[str, torch.Tensor] = None,
    ):
        """
        Args:
            molecule_path: path to the initial structure, read with ``ase.io.read``.
            working_dir: output directory; created if it does not exist yet.
            model_file: a single model path or a list of model paths that
                form the ensemble. The value is passed unchanged to
                ``EnsembleCalculator``.
            neighbor_list: neighbor list transform for the calculator.
            energy_key: model output key holding the energy.
            force_key: model output key holding the forces.
            stress_key: model output key holding the stress (optional).
            energy_unit: energy unit of the model(s).
            position_unit: position unit of the model(s).
            device: device the calculator runs on.
            dtype: floating point precision handed to the calculator.
            converter: accepted for signature compatibility with
                ``AseInterface``; not used.
            optimizer_class: ASE optimizer class stored for later optimization
                runs (default ``QuasiNewton``).
            fixed_atoms: indices of atoms held fixed via an ASE ``FixAtoms``
                constraint. ``None`` or an empty list means no constraint.
            transforms: accepted for signature compatibility; currently not
                forwarded to the calculator.
            additional_inputs: accepted for signature compatibility; currently
                not forwarded to the calculator.
        """
        # Setup output directory (no-op if it already exists)
        self.working_dir = working_dir
        os.makedirs(self.working_dir, exist_ok=True)

        # Load the molecule
        self.molecule = read(molecule_path)

        # Apply position constraints
        if fixed_atoms:
            self.molecule.set_constraint(constraint=FixAtoms(indices=fixed_atoms))

        # Set up optimizer
        self.optimizer_class = optimizer_class

        # No AtomsConverter is built here because nothing would consume it.
        # `converter`, `transforms` and `additional_inputs` are therefore ignored.
        # TODO(review): forward them once EnsembleCalculator accepts them.
        calculator = EnsembleCalculator(
            model_file=model_file,
            neighbor_list=neighbor_list,
            device=device,
            energy_key=energy_key,
            force_key=force_key,
            stress_key=stress_key,
            energy_unit=energy_unit,
            position_unit=position_unit,
            dtype=dtype,
            ensemble_average_strategy=None,
        )

        self.molecule.calc = calculator

        self.dynamics = None

In [9]:
# Fix Python's global RNG so the random structure distortions are reproducible.
SEED = 42
random.seed(SEED)

In [10]:
# Ethanol test models from the repository's test data. They live in the same
# directory as the input structure ../../tests/testdata/md_ethanol.xyz, so they
# are resolved relative to this notebook rather than via a user-specific
# absolute path.
model_path_0 = "../../tests/testdata/md_ethanol.model"
model_path_1 = "../../tests/testdata/md_ethanol2.model"

# Number of distorted structures generated and relaxed in the batchwise part
# of this notebook.
batch_size = 4

# Use the GPU when one is available. Otherwise fall back to the CPU so the
# notebook also runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Neighbor list with the given cutoff radius
cutoff = 5.0
nbh_list = spk.transform.MatScipyNeighborList(cutoff=cutoff)

In [11]:
input_structure_file = "../../tests/testdata/md_ethanol.xyz"

# Output folder shared by the cells of this notebook
if not os.path.exists('howto_batchwise_relaxations_outputs'):
    os.makedirs('howto_batchwise_relaxations_outputs')

# Perturb the reference geometry. Each atom's Cartesian position vector is
# multiplied by its own random factor drawn from [0.95, 1.05]. The factors are
# drawn in atom order, one per atom.
mol = read(input_structure_file)
pos = mol.get_positions()
scale_factors = np.array([random.uniform(0.95, 1.05) for _ in range(pos.shape[0])])
pos = pos * scale_factors[:, np.newaxis]

at = Atoms(positions=pos, numbers=mol.get_atomic_numbers())
write("./howto_batchwise_relaxations_outputs/init_ethanol.xyz", at, format="xyz")


In [12]:
# Relax the distorted structure with the two-model ensemble.
relax_dir = "howto_batchwise_relaxations_outputs/relax"

# Always start from an empty output directory.
if os.path.exists(relax_dir):
    shutil.rmtree(relax_dir)
os.makedirs(relax_dir)

# Uses the class default optimizer (QuasiNewton). Add optimizer_class=LBFGS
# to these settings to switch to L-BFGS instead.
ensemble_settings = {
    "molecule_path": "./howto_batchwise_relaxations_outputs/init_ethanol.xyz",
    "working_dir": relax_dir,
    "model_file": [model_path_0, model_path_1],
    "neighbor_list": nbh_list,
    "device": device,
    "dtype": torch.float32,
    "energy_unit": "kcal/mol",
    "position_unit": "Ang",
}
ase_interface = AseInterfaceCustom(**ensemble_settings)

# NOTE(review): in the recorded run, fmax plateaus around 0.12 and never
# reaches 0.0005 within 1000 steps, while the energy only changes in the 6th
# decimal. The logged energies (~ -9.7e4) also look like kcal/mol rather than
# eV. Confirm that EnsembleCalculator converts model units to ASE units before
# trusting this threshold.
ase_interface.optimize(fmax=0.0005, steps=1000)

['/home/jonas/Documents/schnetpack/tests/testdata/md_ethanol.model', '/home/jonas/Documents/schnetpack/tests/testdata/md_ethanol2.model']
/home/jonas/Documents/schnetpack/tests/testdata/md_ethanol.model
/home/jonas/Documents/schnetpack/tests/testdata/md_ethanol2.model
                Step[ FC]     Time          Energy          fmax
BFGSLineSearch:    0[  0] 14:59:51   -97103.684891      222.2201
BFGSLineSearch:    1[  2] 14:59:51   -97134.118121      116.9887
BFGSLineSearch:    2[  4] 14:59:51   -97144.681615       40.8369
BFGSLineSearch:    3[  6] 14:59:51   -97146.064842       32.4911
BFGSLineSearch:    4[  8] 14:59:51   -97147.739568       34.5070
BFGSLineSearch:    5[ 10] 14:59:51   -97148.216091       35.6081
BFGSLineSearch:    6[ 12] 14:59:51   -97148.383482       30.5677
BFGSLineSearch:    7[ 14] 14:59:51   -97148.532490       32.4041
BFGSLineSearch:    8[ 16] 14:59:51   -97148.622036       33.3574
BFGSLineSearch:    9[ 18] 14:59:51   -97148.684193       31.0313
BFGSLineSearch: 

BFGSLineSearch:  121[878] 15:00:05   -97149.625854        0.1041
BFGSLineSearch:  122[886] 15:00:05   -97149.625853        0.1041
BFGSLineSearch:  123[890] 15:00:05   -97149.625853        0.1041
BFGSLineSearch:  124[894] 15:00:05   -97149.625852        0.1041
BFGSLineSearch:  125[905] 15:00:05   -97149.625853        0.1041
BFGSLineSearch:  126[917] 15:00:05   -97149.625853        0.1041
BFGSLineSearch:  127[925] 15:00:05   -97149.625855        0.1041
BFGSLineSearch:  128[929] 15:00:05   -97149.625854        0.1041
BFGSLineSearch:  129[933] 15:00:05   -97149.625853        0.1041
BFGSLineSearch:  130[941] 15:00:06   -97149.625852        0.1041
BFGSLineSearch:  131[952] 15:00:06   -97149.625854        0.1041
BFGSLineSearch:  132[959] 15:00:06   -97149.625854        0.1041
BFGSLineSearch:  133[968] 15:00:06   -97149.625855        0.1042
BFGSLineSearch:  134[976] 15:00:06   -97149.625853        0.1042
BFGSLineSearch:  135[985] 15:00:06   -97149.625854        0.1042
BFGSLineSearch:  136[989]

BFGSLineSearch:  246[1883] 15:00:20   -97149.625899        0.1205
BFGSLineSearch:  247[1898] 15:00:20   -97149.625898        0.1205
BFGSLineSearch:  248[1909] 15:00:20   -97149.625899        0.1205
BFGSLineSearch:  249[1921] 15:00:20   -97149.625901        0.1205
BFGSLineSearch:  250[1925] 15:00:20   -97149.625900        0.1205
BFGSLineSearch:  251[1929] 15:00:20   -97149.625899        0.1205
BFGSLineSearch:  252[1937] 15:00:21   -97149.625898        0.1205
BFGSLineSearch:  253[1949] 15:00:21   -97149.625900        0.1205
BFGSLineSearch:  254[1956] 15:00:21   -97149.625900        0.1205
BFGSLineSearch:  255[1960] 15:00:21   -97149.625898        0.1205
BFGSLineSearch:  256[1969] 15:00:22   -97149.625899        0.1205
BFGSLineSearch:  257[1977] 15:00:22   -97149.625900        0.1205
BFGSLineSearch:  258[1986] 15:00:22   -97149.625900        0.1205
BFGSLineSearch:  259[1990] 15:00:22   -97149.625899        0.1205
BFGSLineSearch:  260[2001] 15:00:22   -97149.625898        0.1205
BFGSLineSe

BFGSLineSearch:  371[2993] 15:00:36   -97149.625898        0.1195
BFGSLineSearch:  372[3001] 15:00:36   -97149.625898        0.1195
BFGSLineSearch:  373[3012] 15:00:36   -97149.625899        0.1195
BFGSLineSearch:  374[3018] 15:00:36   -97149.625898        0.1195
BFGSLineSearch:  375[3024] 15:00:36   -97149.625899        0.1195
BFGSLineSearch:  376[3032] 15:00:36   -97149.625898        0.1195
BFGSLineSearch:  377[3040] 15:00:36   -97149.625898        0.1195
BFGSLineSearch:  378[3048] 15:00:36   -97149.625898        0.1195
BFGSLineSearch:  379[3052] 15:00:36   -97149.625897        0.1195
BFGSLineSearch:  380[3070] 15:00:37   -97149.625899        0.1195
BFGSLineSearch:  381[3076] 15:00:37   -97149.625898        0.1195
BFGSLineSearch:  382[3084] 15:00:37   -97149.625898        0.1195
BFGSLineSearch:  383[3095] 15:00:37   -97149.625898        0.1195
BFGSLineSearch:  384[3104] 15:00:37   -97149.625897        0.1195
BFGSLineSearch:  385[3112] 15:00:37   -97149.625899        0.1195
BFGSLineSe

BFGSLineSearch:  496[3948] 15:00:48   -97149.625892        0.1195
BFGSLineSearch:  497[3952] 15:00:48   -97149.625890        0.1195
BFGSLineSearch:  498[3963] 15:00:49   -97149.625889        0.1195
BFGSLineSearch:  499[3974] 15:00:49   -97149.625891        0.1195
BFGSLineSearch:  500[3980] 15:00:49   -97149.625891        0.1195
BFGSLineSearch:  501[3988] 15:00:49   -97149.625892        0.1195
BFGSLineSearch:  502[3992] 15:00:49   -97149.625891        0.1195
BFGSLineSearch:  503[4003] 15:00:49   -97149.625892        0.1195
BFGSLineSearch:  504[4012] 15:00:49   -97149.625890        0.1195
BFGSLineSearch:  505[4026] 15:00:49   -97149.625891        0.1195
BFGSLineSearch:  506[4037] 15:00:50   -97149.625890        0.1195
BFGSLineSearch:  507[4046] 15:00:50   -97149.625891        0.1195
BFGSLineSearch:  508[4055] 15:00:50   -97149.625892        0.1195
BFGSLineSearch:  509[4064] 15:00:50   -97149.625891        0.1195
BFGSLineSearch:  510[4075] 15:00:50   -97149.625891        0.1195
BFGSLineSe

BFGSLineSearch:  621[5035] 15:01:03   -97149.625891        0.1190
BFGSLineSearch:  622[5044] 15:01:03   -97149.625891        0.1191
BFGSLineSearch:  623[5052] 15:01:03   -97149.625891        0.1191
BFGSLineSearch:  624[5061] 15:01:03   -97149.625891        0.1191
BFGSLineSearch:  625[5069] 15:01:03   -97149.625894        0.1191
BFGSLineSearch:  626[5073] 15:01:03   -97149.625890        0.1191
BFGSLineSearch:  627[5081] 15:01:03   -97149.625891        0.1191
BFGSLineSearch:  628[5097] 15:01:04   -97149.625894        0.1170
BFGSLineSearch:  629[5104] 15:01:04   -97149.625894        0.1170
BFGSLineSearch:  630[5111] 15:01:04   -97149.625893        0.1168
BFGSLineSearch:  631[5120] 15:01:04   -97149.625893        0.1168
BFGSLineSearch:  632[5127] 15:01:04   -97149.625893        0.1168
BFGSLineSearch:  633[5137] 15:01:04   -97149.625894        0.1168
BFGSLineSearch:  634[5141] 15:01:04   -97149.625893        0.1168
BFGSLineSearch:  635[5156] 15:01:04   -97149.625892        0.1168
BFGSLineSe

BFGSLineSearch:  746[6074] 15:01:16   -97149.625886        0.1168
BFGSLineSearch:  747[6078] 15:01:16   -97149.625883        0.1168
BFGSLineSearch:  748[6093] 15:01:17   -97149.625884        0.1167
BFGSLineSearch:  749[6105] 15:01:17   -97149.625887        0.1166
BFGSLineSearch:  750[6112] 15:01:17   -97149.625888        0.1166
BFGSLineSearch:  751[6117] 15:01:17   -97149.625886        0.1166
BFGSLineSearch:  752[6123] 15:01:17   -97149.625885        0.1166
BFGSLineSearch:  753[6133] 15:01:17   -97149.625885        0.1166
BFGSLineSearch:  754[6144] 15:01:17   -97149.625885        0.1166
BFGSLineSearch:  755[6153] 15:01:18   -97149.625886        0.1166
BFGSLineSearch:  756[6157] 15:01:18   -97149.625885        0.1166
BFGSLineSearch:  757[6166] 15:01:18   -97149.625886        0.1166
BFGSLineSearch:  758[6175] 15:01:18   -97149.625885        0.1166
BFGSLineSearch:  759[6183] 15:01:18   -97149.625884        0.1166
BFGSLineSearch:  760[6200] 15:01:18   -97149.625885        0.1166
BFGSLineSe

BFGSLineSearch:  871[7247] 15:01:31   -97149.625884        0.1161
BFGSLineSearch:  872[7254] 15:01:31   -97149.625885        0.1161
BFGSLineSearch:  873[7258] 15:01:31   -97149.625885        0.1161
BFGSLineSearch:  874[7269] 15:01:32   -97149.625885        0.1161
BFGSLineSearch:  875[7277] 15:01:32   -97149.625885        0.1161
BFGSLineSearch:  876[7283] 15:01:32   -97149.625885        0.1161
BFGSLineSearch:  877[7291] 15:01:32   -97149.625884        0.1161
BFGSLineSearch:  878[7300] 15:01:32   -97149.625885        0.1161
BFGSLineSearch:  879[7306] 15:01:32   -97149.625885        0.1161
BFGSLineSearch:  880[7310] 15:01:32   -97149.625885        0.1161
BFGSLineSearch:  881[7324] 15:01:32   -97149.625884        0.1161
BFGSLineSearch:  882[7344] 15:01:32   -97149.625885        0.1161
BFGSLineSearch:  883[7352] 15:01:33   -97149.625885        0.1161
BFGSLineSearch:  884[7362] 15:01:33   -97149.625885        0.1161
BFGSLineSearch:  885[7366] 15:01:33   -97149.625885        0.1161
BFGSLineSe

BFGSLineSearch:  996[8339] 15:01:45   -97149.625877        0.1157
BFGSLineSearch:  997[8348] 15:01:46   -97149.625877        0.1157
BFGSLineSearch:  998[8360] 15:01:46   -97149.625876        0.1157
BFGSLineSearch:  999[8369] 15:01:46   -97149.625874        0.1157
BFGSLineSearch:  1000[8383] 15:01:46   -97149.625876        0.1157


In [None]:
if not os.path.exists('howto_batchwise_relaxations_outputs'):
    os.makedirs('howto_batchwise_relaxations_outputs')

input_structure_file = "../../tests/testdata/md_ethanol.xyz"

# Write `batch_size` independently distorted copies of the reference geometry.
# Every copy starts from the unperturbed positions. Each atom's position vector
# is scaled by a fresh random factor from [0.95, 1.15], drawn in atom order.
# NOTE(review): this range is wider than the [0.95, 1.05] used for the single
# structure above. Confirm the asymmetry is intended.
mol = read(input_structure_file)
pos = mol.get_positions()
for at_idx in range(batch_size):
    factors = np.array([random.uniform(0.95, 1.15) for _ in range(pos.shape[0])])
    at = Atoms(positions=pos * factors[:, np.newaxis], numbers=mol.get_atomic_numbers())
    write(f"./howto_batchwise_relaxations_outputs/init_ethanol_{at_idx}.xyz", at, format="xyz")

# Read the written structures back as the list of initial structures
ats = [
    read(f"./howto_batchwise_relaxations_outputs/init_ethanol_{at_idx}.xyz")
    for at_idx in range(batch_size)
]

In [None]:
# Per-atom optimization mask: True marks a fixed atom, False a free one.
# No atom is fixed here. All input structures share the same per-atom mask,
# so the flattened mask is the single-structure mask repeated once per structure.
n_atoms = len(ats[0].get_atomic_numbers())
single_structure_mask = [False] * n_atoms
mask = [flag for _ in ats for flag in single_structure_mask]

In [None]:
# Reference run: relax each distorted structure separately with the standard
# single-model AseInterface, and time the whole loop so it can be compared
# with the batchwise optimizer.
t_start_indiv = time.perf_counter()
for at_idx in range(batch_size):

    relax_dir = f"howto_batchwise_relaxations_outputs/relax_{at_idx}"
    # Start every relaxation from an empty output directory.
    if os.path.exists(relax_dir):
        shutil.rmtree(relax_dir)
    os.makedirs(relax_dir)

    ase_interface = AseInterface(
        molecule_path=f"./howto_batchwise_relaxations_outputs/init_ethanol_{at_idx}.xyz",
        working_dir=relax_dir,
        # `model_path` is not defined in this notebook, so the first model of
        # the ensemble is used.
        # TODO(review): for a like-for-like comparison with an ensemble-based
        # batchwise run, use AseInterfaceCustom with [model_path_0, model_path_1].
        model_file=model_path_0,
        neighbor_list=nbh_list,
        device=device,
        dtype=torch.float32,
        energy_unit="kcal/mol",
        position_unit="Ang",
        optimizer_class=LBFGS,
    )
    ase_interface.optimize(fmax=0.0005, steps=1000)

# perf_counter is monotonic and intended for measuring elapsed durations.
t_end_indiv = time.perf_counter()
relaxation_time_indiv = t_end_indiv - t_start_indiv