In [None]:
import os
import random
import shutil
import time
import warnings
from typing import Optional, List, Union, Dict

import numpy as np
import torch
from tqdm import tqdm 
from matplotlib import pyplot as plt

import ase
from ase import Atoms
from ase.constraints import FixAtoms
from ase.io import read, write
from ase.optimize import QuasiNewton
from ase.optimize.lbfgs import LBFGS

import schnetpack as spk
from schnetpack import properties
from schnetpack.interfaces.ase_interface import AtomsConverter, AseInterface
from schnetpack.interfaces.ensemble_calculator import EnsembleCalculator
from schnetpack.interfaces.batchwise_optimization import ASEBatchwiseLBFGS, BatchwiseCalculator

In [None]:
class AseInterfaceCustom(AseInterface):
    """
    Interface for ASE calculations (optimization and molecular dynamics) that
    evaluates energies/forces with an ``EnsembleCalculator``.

    Only ``__init__`` is overridden. It keeps the ``AseInterface`` constructor
    signature but attaches an ``EnsembleCalculator`` (built from one or several
    model files) to the loaded structure. Other methods such as ``optimize``
    are inherited from ``AseInterface``.

    Args:
        molecule_path: Structure file, loaded with ``ase.io.read``.
        working_dir: Output directory; created if it does not exist.
        model_file: Model path(s) forwarded to ``EnsembleCalculator``
            (pass a list of paths to build an ensemble).
        neighbor_list: Neighbor-list transform forwarded to the calculator.
        energy_key, force_key, stress_key: Model output keys forwarded to the
            calculator.
        energy_unit, position_unit: Units forwarded to the calculator.
        device, dtype: Device and floating point type forwarded to the calculator.
        converter, transforms, additional_inputs: Accepted for signature
            compatibility with ``AseInterface`` but NOT forwarded to
            ``EnsembleCalculator``; a ``UserWarning`` is emitted when any of
            them is given a non-default value.
        optimizer_class: ASE optimizer class, stored as ``self.optimizer_class``.
        fixed_atoms: Indices of atoms frozen with an ASE ``FixAtoms`` constraint.
    """

    def __init__(
        self,
        molecule_path: str,
        working_dir: str,
        model_file: Union[str, List[str]],
        neighbor_list: spk.transform.Transform,
        energy_key: str = "energy",
        force_key: str = "forces",
        stress_key: Optional[str] = None,
        energy_unit: Union[str, float] = "kcal/mol",
        position_unit: Union[str, float] = "Angstrom",
        device: Union[str, torch.device] = "cpu",
        dtype: torch.dtype = torch.float32,
        converter: AtomsConverter = AtomsConverter,
        optimizer_class: type = QuasiNewton,
        fixed_atoms: Optional[List[int]] = None,
        transforms: Union[
            spk.transform.Transform, List[spk.transform.Transform]
        ] = None,
        additional_inputs: Dict[str, torch.Tensor] = None,
    ):
        # NOTE: AseInterface.__init__ is deliberately not called (presumably
        # because it would build a single-model calculator); the attributes
        # are set directly below instead.

        # Setup directory (no-op if it already exists)
        self.working_dir = working_dir
        os.makedirs(self.working_dir, exist_ok=True)

        # Load the molecule
        self.molecule = read(molecule_path)

        # Freeze the requested atoms; None / empty list leaves all atoms free.
        if fixed_atoms:
            self.molecule.set_constraint(constraint=FixAtoms(fixed_atoms))

        # Optimizer class, presumably consumed by the inherited `optimize`.
        self.optimizer_class = optimizer_class

        # The ensemble calculator below receives no converter, transforms or
        # additional inputs, so warn rather than silently dropping them.
        if (
            converter is not AtomsConverter
            or transforms is not None
            or additional_inputs is not None
        ):
            warnings.warn(
                "AseInterfaceCustom does not forward `converter`, `transforms` "
                "or `additional_inputs` to EnsembleCalculator; they are ignored.",
                stacklevel=2,
            )

        # Set up calculator
        calculator = EnsembleCalculator(
            model_file=model_file,
            neighbor_list=neighbor_list,
            device=device,
            energy_key=energy_key,
            force_key=force_key,
            stress_key=stress_key,
            energy_unit=energy_unit,
            position_unit=position_unit,
            dtype=dtype,
            # NOTE(review): None presumably selects the calculator's default
            # (or no) ensemble averaging — confirm against EnsembleCalculator.
            ensemble_average_strategy=None,
        )

        self.molecule.calc = calculator

        # Initialized empty; presumably populated by the inherited MD setup.
        self.dynamics = None

In [None]:
# Trained ethanol models forming the ensemble. Paths are resolved relative to
# the notebook directory (assumes the notebook lives two levels below the
# schnetpack repository root — adjust `testdata_dir` otherwise) instead of an
# author-specific absolute path.
testdata_dir = "../../tests/testdata"
model_path_0 = os.path.join(testdata_dir, "md_ethanol.model")
model_path_1 = os.path.join(testdata_dir, "md_ethanol2.model")

# set device: use the GPU when available, otherwise fall back to the CPU so the
# notebook also runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# define neighbor list (cutoff in the structure's length unit — presumably Angstrom)
cutoff = 5.0
nbh_list = spk.transform.MatScipyNeighborList(cutoff=cutoff)

In [None]:
input_structure_file = "../../tests/testdata/md_ethanol.xyz"
random.seed(42)

# Create the output directory on first run.
output_dir = 'howto_batchwise_relaxations_outputs'
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# Load the reference ethanol geometry.
mol = read(input_structure_file)
pos = mol.get_positions()

# Distort the structure: every atom's position vector is multiplied by its own
# random factor in [0.95, 1.05]. Factors are drawn in atom order, so the result
# is reproducible given the seed above.
scale_factors = [random.uniform(0.95, 1.05) for _ in range(pos.shape[0])]
pos *= np.asarray(scale_factors)[:, np.newaxis]

# Store the distorted geometry as the starting point for the relaxation.
at = Atoms(positions=pos, numbers=mol.get_atomic_numbers())
write("./howto_batchwise_relaxations_outputs/init_ethanol.xyz", at, format="xyz")

In [None]:
relax_dir = "howto_batchwise_relaxations_outputs/relax"

# Start from an empty directory so results from earlier runs are not mixed in.
if os.path.exists(relax_dir):
    shutil.rmtree(relax_dir)
os.makedirs(relax_dir)

# Two-model ensemble evaluated on the distorted ethanol geometry.
ase_interface = AseInterfaceCustom(
    molecule_path="./howto_batchwise_relaxations_outputs/init_ethanol.xyz",
    working_dir=relax_dir,
    model_file=[model_path_0, model_path_1],
    neighbor_list=nbh_list,
    device=device,
    dtype=torch.float32,
    energy_unit="kcal/mol",
    position_unit="Ang",
)

# Convergence settings forwarded to the inherited `optimize`: presumably the
# ASE force threshold (fmax) and the maximum number of optimizer steps.
relax_fmax = 0.0005
relax_max_steps = 1000
ase_interface.optimize(fmax=relax_fmax, steps=relax_max_steps)