diff --git a/.github/workflows/run_pytest_interfaces.yml b/.github/workflows/run_pytest_interfaces.yml
new file mode 100644
index 00000000..6c943821
--- /dev/null
+++ b/.github/workflows/run_pytest_interfaces.yml
@@ -0,0 +1,39 @@
+name: Run interface-related PyTests
+env:
+  COLUMNS: 120
+on:
+  pull_request:
+    paths:
+      - 'matsciml/interfaces/**'
+  workflow_dispatch:
+jobs:
+  interfaces-pytest:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - name: Create/reuse micromamba env
+        uses: mamba-org/setup-micromamba@v1
+        with:
+          micromamba-version: '1.5.7-0'
+          environment-file: conda.yml
+          init-shell: >-
+            bash
+          cache-environment: true
+          post-cleanup: 'all'
+          generate-run-shell: true
+      - name: Install current version of matsciml
+        run: |
+          pip install .
+        shell: micromamba-shell {0}
+      - name: Install PyTest
+        run: |
+          pip install pytest pytest-dependency pytest-pretty
+        shell: micromamba-shell {0}
+      - name: Print out environment
+        run: |
+          micromamba env export && pip freeze
+        shell: micromamba-shell {0}
+      - name: Run pytest on interfaces
+        run: |
+          pytest -v -m "not lmdb and not slow and not remote_request" ./matsciml/interfaces
+        shell: micromamba-shell {0}
diff --git a/examples/interfaces/ase_from_pretrained.py b/examples/interfaces/ase_from_pretrained.py
new file mode 100644
index 00000000..b1f4f56a
--- /dev/null
+++ b/examples/interfaces/ase_from_pretrained.py
@@ -0,0 +1,35 @@
+from __future__ import annotations
+
+from ase import Atoms, units
+from ase.md.verlet import VelocityVerlet
+
+from matsciml.interfaces.ase import MatSciMLCalculator
+from matsciml.datasets.transforms import (
+    PeriodicPropertiesTransform,
+    PointCloudToGraphTransform,
+)
+
+"""
+Demonstrates setting up a calculator from a pretrained
+`ForceRegressionTask` checkpoint.
+
+Replace 'model.ckpt' with the path to your checkpoint file.
+"""
+
+d = 2.9
+L = 10.0
+
+atoms = Atoms("C", positions=[[0, L / 2, L / 2]], cell=[d, L, L], pbc=[1, 0, 0])
+
+calc = MatSciMLCalculator.from_pretrained_force_regression(
+    "model.ckpt",
+    transforms=[
+        PeriodicPropertiesTransform(6.0, True),
+        PointCloudToGraphTransform("pyg"),
+    ],
+)
+# attach the matsciml calculator to the atoms
+atoms.calc = calc
+# run the simulation for 100 steps with a 5 fs timestep
+dyn = VelocityVerlet(atoms, timestep=5 * units.fs, logfile="md.log")
+dyn.run(100)
diff --git a/examples/interfaces/ase_from_scratch.py b/examples/interfaces/ase_from_scratch.py
new file mode 100644
index 00000000..8ea21fa4
--- /dev/null
+++ b/examples/interfaces/ase_from_scratch.py
@@ -0,0 +1,65 @@
+from __future__ import annotations
+
+from ase import Atoms, units
+from ase.md.verlet import VelocityVerlet
+import pytorch_lightning as pl
+
+from matsciml.lightning import MatSciMLDataModule
+from matsciml.models.base import ForceRegressionTask
+from matsciml.models.pyg import EGNN
+from matsciml.interfaces.ase import MatSciMLCalculator
+from matsciml.datasets.transforms import (
+    PeriodicPropertiesTransform,
+    PointCloudToGraphTransform,
+)
+
+"""
+Demonstrates setting up a calculator from a `ForceRegressionTask`
+trained from scratch - this is unlikely to be the way you would actually
+do this, but it demonstrates how the workflow is composed together
+in a single file.
+"""
+
+task = ForceRegressionTask(
+    encoder_class=EGNN,
+    encoder_kwargs={"hidden_dim": 32, "output_dim": 32},
+    output_kwargs={"lazy": False, "input_dim": 32, "hidden_dim": 32, "num_hidden": 3},
+)
+
+transforms = [
+    PeriodicPropertiesTransform(6.0, True),
+    PointCloudToGraphTransform("pyg"),
+]
+
+dm = MatSciMLDataModule.from_devset(
+    "LiPSDataset", batch_size=8, num_workers=0, dset_kwargs={"transforms": transforms}
+)
+
+# run the training loop
+trainer = pl.Trainer(
+    fast_dev_run=10,
+    logger=False,
+    enable_checkpointing=False,
+)
+trainer.fit(task, datamodule=dm)
+
+# put the task into eval mode for inference
+task = task.eval()
+
+# get a random frame from LiPS to initialize the propagation
+frame = dm.dataset[52]
+graph = frame["graph"]
+atoms = Atoms(
+    positions=graph["pos"].numpy(),
+    cell=frame["cell"].numpy().squeeze(),
+    numbers=graph["atomic_numbers"].numpy(),
+)
+
+# instantiate the calculator using the trained model,
+# reusing the same transforms as with the data module
+calc = MatSciMLCalculator(task, transforms=transforms)
+# attach the matsciml calculator to the atoms
+atoms.calc = calc
+# run the simulation for 100 steps with a 5 fs timestep
+dyn = VelocityVerlet(atoms, timestep=5 * units.fs, logfile="md.log")
+dyn.run(100)
diff --git a/matsciml/common/relaxation/ase_utils.py b/matsciml/common/relaxation/ase_utils.py
deleted file mode 100644
index 7be9a7e8..00000000
--- a/matsciml/common/relaxation/ase_utils.py
+++ /dev/null
@@ -1,174 +0,0 @@
-"""
-Copyright (c) Facebook, Inc. and its affiliates.
-
-This source code is licensed under the MIT license found in the
-LICENSE file in the root directory of this source tree.
-
-
-
-Utilities to interface OCP models/trainers with the Atomic Simulation
-Environment (ASE)
-"""
-from __future__ import annotations
-
-import copy
-import logging
-import os
-
-import torch
-import yaml
-from ase import Atoms
-from ase.calculators.calculator import Calculator
-from ase.calculators.singlepoint import SinglePointCalculator as sp
-from ase.constraints import FixAtoms
-
-from matsciml.common.registry import registry
-from matsciml.common.utils import radius_graph_pbc, setup_imports, setup_logging
-from matsciml.datasets.trajectory_lmdb import data_list_collater
-from matsciml.preprocessing import AtomsToGraphs
-
-
-def batch_to_atoms(batch):
-    n_systems = batch.neighbors.shape[0]
-    natoms = batch.natoms.tolist()
-    numbers = torch.split(batch.atomic_numbers, natoms)
-    fixed = torch.split(batch.fixed, natoms)
-    forces = torch.split(batch.force, natoms)
-    positions = torch.split(batch.pos, natoms)
-    tags = torch.split(batch.tags, natoms)
-    cells = batch.cell
-    energies = batch.y.tolist()
-
-    atoms_objects = []
-    for idx in range(n_systems):
-        atoms = Atoms(
-            numbers=numbers[idx].tolist(),
-            positions=positions[idx].cpu().detach().numpy(),
-            tags=tags[idx].tolist(),
-            cell=cells[idx].cpu().detach().numpy(),
-            constraint=FixAtoms(mask=fixed[idx].tolist()),
-            pbc=[True, True, True],
-        )
-        calc = sp(
-            atoms=atoms,
-            energy=energies[idx],
-            forces=forces[idx].cpu().detach().numpy(),
-        )
-        atoms.set_calculator(calc)
-        atoms_objects.append(atoms)
-
-    return atoms_objects
-
-
-class OCPCalculator(Calculator):
-    implemented_properties = ["energy", "forces"]
-
-    def __init__(self, config_yml=None, checkpoint=None, cutoff=6, max_neighbors=50):
-        """
-        OCP-ASE Calculator
-
-        Args:
-            config_yml (str):
-                Path to yaml config or could be a dictionary.
-            checkpoint (str):
-                Path to trained checkpoint.
-            cutoff (int):
-                Cutoff radius to be used for data preprocessing.
-            max_neighbors (int):
-                Maximum amount of neighbors to store for a given atom.
-        """
-        setup_imports()
-        setup_logging()
-        Calculator.__init__(self)
-
-        # Either the config path or the checkpoint path needs to be provided
-        assert config_yml or checkpoint is not None
-
-        if config_yml is not None:
-            if isinstance(config_yml, str):
-                config = yaml.safe_load(open(config_yml))
-
-                if "includes" in config:
-                    for include in config["includes"]:
-                        # Change the path based on absolute path of config_yml
-                        path = os.path.join(config_yml.split("configs")[0], include)
-                        include_config = yaml.safe_load(open(path))
-                        config.update(include_config)
-            else:
-                config = config_yml
-            # Only keeps the train data that might have normalizer values
-            config["dataset"] = config["dataset"][0]
-        else:
-            # Loads the config from the checkpoint directly
-            config = torch.load(checkpoint, map_location=torch.device("cpu"))["config"]
-
-        # Load the trainer based on the dataset used
-        if config["task"]["dataset"] == "trajectory_lmdb":
-            config["trainer"] = "forces"
-        else:
-            config["trainer"] = "energy"
-
-        config["model_attributes"]["name"] = config.pop("model")
-        config["model"] = config["model_attributes"]
-
-        # Calculate the edge indices on the fly
-        config["model"]["otf_graph"] = True
-
-        # Save config so obj can be transported over network (pkl)
-        self.config = copy.deepcopy(config)
-        self.config["checkpoint"] = checkpoint
-
-        if "normalizer" not in config:
-            del config["dataset"]["src"]
-            config["normalizer"] = config["dataset"]
-
-        self.trainer = registry.get_trainer_class(config.get("trainer", "energy"))(
-            task=config["task"],
-            model=config["model"],
-            dataset=None,
-            normalizer=config["normalizer"],
-            optimizer=config["optim"],
-            identifier="",
-            slurm=config.get("slurm", {}),
-            local_rank=config.get("local_rank", 0),
-            is_debug=config.get("is_debug", True),
-            cpu=True,
-        )
-
-        if checkpoint is not None:
-            self.load_checkpoint(checkpoint)
-
-        self.a2g = AtomsToGraphs(
-            max_neigh=max_neighbors,
-            radius=cutoff,
-            r_energy=False,
-            r_forces=False,
-            r_distances=False,
-            r_edges=False,
-        )
-
-    def load_checkpoint(self, checkpoint_path):
-        """
-        Load existing trained model
-
-        Args:
-            checkpoint_path: string
-                Path to trained model
-        """
-        try:
-            self.trainer.load_checkpoint(checkpoint_path)
-        except NotImplementedError:
-            logging.warning("Unable to load checkpoint!")
-
-    def calculate(self, atoms, properties, system_changes):
-        Calculator.calculate(self, atoms, properties, system_changes)
-        data_object = self.a2g.convert(atoms)
-        batch = data_list_collater([data_object], otf_graph=True)
-
-        predictions = self.trainer.predict(batch, per_image=False, disable_tqdm=True)
-        if self.trainer.name == "s2ef":
-            self.results["energy"] = predictions["energy"].item()
-            self.results["forces"] = predictions["forces"].cpu().numpy()
-
-        elif self.trainer.name == "is2re":
-            self.results["energy"] = predictions["energy"].item()
diff --git a/matsciml/common/relaxation/ml_relaxation.py b/matsciml/common/relaxation/ml_relaxation.py
deleted file mode 100644
index 924fd378..00000000
--- a/matsciml/common/relaxation/ml_relaxation.py
+++ /dev/null
@@ -1,60 +0,0 @@
-"""
-Copyright (c) Facebook, Inc. and its affiliates.
-
-This source code is licensed under the MIT license found in the
-LICENSE file in the root directory of this source tree.
-"""
-from __future__ import annotations
-
-from pathlib import Path
-
-import torch
-
-from matsciml.common.registry import registry
-from matsciml.common.relaxation.optimizers.lbfgs_torch import LBFGS, TorchCalc
-
-
-def ml_relax(
-    batch,
-    model,
-    steps,
-    fmax,
-    relax_opt,
-    device="cuda:0",
-    transform=None,
-    early_stop_batch=False,
-):
-    """
-    Runs ML-based relaxations.
-    Args:
-        batch: object
-        model: object
-        steps: int
-            Max number of steps in the structure relaxation.
-        fmax: float
-            Structure relaxation terminates when the max force
-            of the system is no bigger than fmax.
-        relax_opt: str
-            Optimizer and corresponding parameters to be used for structure relaxations.
-    """
-    batch = batch[0]
-    ids = batch.sid
-    calc = TorchCalc(model, transform)
-
-    # Run ML-based relaxation
-    traj_dir = relax_opt.get("traj_dir", None)
-    optimizer = LBFGS(
-        batch,
-        calc,
-        maxstep=relax_opt.get("maxstep", 0.04),
-        memory=relax_opt["memory"],
-        damping=relax_opt.get("damping", 1.0),
-        alpha=relax_opt.get("alpha", 70.0),
-        device=device,
-        traj_dir=Path(traj_dir) if traj_dir is not None else None,
-        traj_names=ids,
-        early_stop_batch=early_stop_batch,
-    )
-    relaxed_batch = optimizer.run(fmax=fmax, steps=steps)
-
-    return relaxed_batch
diff --git a/matsciml/common/relaxation/optimizers/__init__.py b/matsciml/common/relaxation/optimizers/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/matsciml/common/relaxation/optimizers/lbfgs_torch.py b/matsciml/common/relaxation/optimizers/lbfgs_torch.py
deleted file mode 100644
index 9b69e3a7..00000000
--- a/matsciml/common/relaxation/optimizers/lbfgs_torch.py
+++ /dev/null
@@ -1,194 +0,0 @@
-"""
-Copyright (c) Facebook, Inc. and its affiliates.
-
-This source code is licensed under the MIT license found in the
-LICENSE file in the root directory of this source tree.
-"""
-from __future__ import annotations
-
-import logging
-from collections import deque
-from pathlib import Path
-
-import ase
-import torch
-from ase import Atoms
-from torch_scatter import scatter
-
-from matsciml.common.relaxation.ase_utils import batch_to_atoms
-from matsciml.common.utils import radius_graph_pbc
-
-
-class LBFGS:
-    def __init__(
-        self,
-        atoms: Atoms,
-        model,
-        maxstep=0.01,
-        memory=100,
-        damping=0.25,
-        alpha=100.0,
-        force_consistent=None,
-        device="cuda:0",
-        traj_dir: Path = None,
-        traj_names=None,
-        early_stop_batch: bool = False,
-    ):
-        self.atoms = atoms
-        self.model = model
-        self.maxstep = maxstep
-        self.memory = memory
-        self.damping = damping
-        self.alpha = alpha
-        self.force_consistent = force_consistent
-        self.device = device
-        self.traj_dir = traj_dir
-        self.traj_names = traj_names
-        self.early_stop_batch = early_stop_batch
-        assert not self.traj_dir or (
-            traj_dir and len(traj_names)
-        ), "Trajectory names should be specified to save trajectories"
-        logging.info("Step Fmax(eV/A)")
-
-        self.model.update_graph(self.atoms)
-
-    def get_forces(self, apply_constraint=True):
-        energy, forces = self.model.get_forces(self.atoms, apply_constraint)
-        return energy, forces
-
-    def get_positions(self):
-        return self.atoms.pos
-
-    def set_positions(self, update, update_mask):
-        r = self.get_positions()
-        if not self.early_stop_batch:
-            update = torch.where(update_mask.unsqueeze(1), update, 0.0)
-        self.atoms.pos = r + update.to(dtype=torch.float32)
-        self.model.update_graph(self.atoms)
-
-    def check_convergence(self, iteration, update_mask, forces, force_threshold):
-        if forces is None:
-            return False
-        max_forces_ = scatter(
-            (forces**2).sum(axis=1).sqrt(),
-            self.atoms.batch,
-            reduce="max",
-        )
-        max_forces = max_forces_[self.atoms.batch]
-        update_mask = torch.logical_and(update_mask, max_forces.ge(force_threshold))
-        logging.info(
-            f"{iteration} " + " ".join(f"{x:0.3f}" for x in max_forces_.tolist()),
-        )
-        return update_mask
-
-    def run(self, fmax, steps):
-        s = deque(maxlen=self.memory)
-        y = deque(maxlen=self.memory)
-        rho = deque(maxlen=self.memory)
-        r0 = f0 = e0 = None
-        H0 = 1.0 / self.alpha
-        update_mask = torch.ones_like(self.atoms.batch).bool().to(self.device)
-
-        trajectories = None
-        if self.traj_dir:
-            self.traj_dir.mkdir(exist_ok=True, parents=True)
-            trajectories = [
-                ase.io.Trajectory(self.traj_dir / f"{name}.traj_tmp", mode="w")
-                for name in self.traj_names
-            ]
-
-        iteration = 0
-        converged = False
-        while iteration < steps and not converged:
-            r0, f0, e0 = self.step(iteration, r0, f0, H0, rho, s, y, update_mask)
-            iteration += 1
-            if trajectories is not None:
-                self.atoms.y, self.atoms.force = e0, f0
-                atoms_objects = batch_to_atoms(self.atoms)
-                update_mask_ = torch.split(update_mask, self.atoms.natoms.tolist())
-                for atm, traj, mask in zip(atoms_objects, trajectories, update_mask_):
-                    if mask[0]:
-                        traj.write(atm)
-            update_mask = self.check_convergence(iteration, update_mask, f0, fmax)
-            converged = torch.all(torch.logical_not(update_mask))
-            # GPU memory usage as per nvidia-smi seems to gradually build up as
-            # batches are processed. This releases unoccupied cached memory.
-            torch.cuda.empty_cache()
-
-        if trajectories is not None:
-            for traj in trajectories:
-                traj.close()
-            for name in self.traj_names:
-                traj_fl = Path(self.traj_dir / f"{name}.traj_tmp", mode="w")
-                traj_fl.rename(traj_fl.with_suffix(".traj"))
-
-        self.atoms.y, self.atoms.force = self.get_forces(apply_constraint=False)
-        return self.atoms
-
-    def step(self, iteration, r0, f0, H0, rho, s, y, update_mask):
-        def determine_step(dr):
-            steplengths = torch.norm(dr, dim=1)
-            longest_steps = scatter(steplengths, self.atoms.batch, reduce="max")
-            longest_steps = longest_steps[self.atoms.batch]
-            maxstep = longest_steps.new_tensor(self.maxstep)
-            scale = (longest_steps + 1e-7).reciprocal() * torch.min(
-                longest_steps,
-                maxstep,
-            )
-            dr *= scale.unsqueeze(1)
-            return dr * self.damping
-
-        e, f = self.get_forces()
-        f = f.to(self.device, dtype=torch.float64)
-        r = self.atoms.pos.to(self.device, dtype=torch.float64)
-
-        # Update s, y and rho
-        if iteration > 0:
-            s0 = (r - r0).flatten()
-            y0 = -(f - f0).flatten()
-            s.append(s0)
-            y.append(y0)
-            rho.append(1.0 / torch.dot(y0, s0))
-
-        loopmax = min(self.memory, iteration)
-        alpha = f.new_empty(loopmax)
-        q = -f.flatten()
-
-        for i in range(loopmax - 1, -1, -1):
-            alpha[i] = rho[i] * torch.dot(s[i], q)
-            q -= alpha[i] * y[i]
-        z = H0 * q
-        for i in range(loopmax):
-            beta = rho[i] * torch.dot(y[i], z)
-            z += s[i] * (alpha[i] - beta)
-        p = -z.reshape((-1, 3))  # descent direction
-        dr = determine_step(p)
-        if torch.abs(dr).max() < 1e-7:
-            # Same configuration again (maybe a restart):
-            return
-        self.set_positions(dr, update_mask)
-        return r, f, e
-
-
-class TorchCalc:
-    def __init__(self, model, transform=None):
-        self.model = model
-        self.transform = transform
-
-    def get_forces(self, atoms, apply_constraint=True):
-        predictions = self.model.predict(atoms, per_image=False, disable_tqdm=True)
-        energy = predictions["energy"]
-        forces = predictions["forces"]
-        if apply_constraint:
-            fixed_idx = torch.where(atoms.fixed == 1)[0]
-            forces[fixed_idx] = 0
-        return energy, forces
-
-    def update_graph(self, atoms):
-        edge_index, cell_offsets, num_neighbors = radius_graph_pbc(atoms, 6, 50)
-        atoms.edge_index = edge_index
-        atoms.cell_offsets = cell_offsets
-        atoms.neighbors = num_neighbors
-        if self.transform is not None:
-            atoms = self.transform(atoms)
-        return atoms
diff --git a/matsciml/common/relaxation/__init__.py b/matsciml/interfaces/__init__.py
similarity index 100%
rename from matsciml/common/relaxation/__init__.py
rename to matsciml/interfaces/__init__.py
diff --git a/matsciml/interfaces/ase/__init__.py b/matsciml/interfaces/ase/__init__.py
new file mode 100644
index 00000000..7a23053a
--- /dev/null
+++ b/matsciml/interfaces/ase/__init__.py
@@ -0,0 +1,5 @@
+from __future__ import annotations
+
+from matsciml.interfaces.ase.base import MatSciMLCalculator
+
+__all__ = ["MatSciMLCalculator"]
diff --git a/matsciml/interfaces/ase/base.py b/matsciml/interfaces/ase/base.py
new file mode 100644
index 00000000..c796ed3c
--- /dev/null
+++ b/matsciml/interfaces/ase/base.py
@@ -0,0 +1,282 @@
+from __future__ import annotations
+from pathlib import Path
+
+from typing import Callable, Literal
+
+import torch
+from ase import Atoms
+from ase.calculators.calculator import Calculator
+import numpy as np
+
+from matsciml.common.types import DataDict
+from matsciml.models.base import (
+    ScalarRegressionTask,
+    GradFreeForceRegressionTask,
+    ForceRegressionTask,
+    MultiTaskLitModule,
+)
+from matsciml.datasets.transforms.base import AbstractDataTransform
+
+__all__ = ["MatSciMLCalculator"]
+
+
+def recursive_type_cast(
+    data_dict: DataDict,
+    dtype: torch.dtype,
+    ignore_keys: list[str] = ["atomic_numbers"],
+    convert_numpy: bool = True,
+) -> DataDict:
+    """
+    Recursively cast a dictionary of data into a particular
+    numeric type.
+
+    This function will only type cast torch tensors; the ``convert_numpy``
+    argument will optionally convert NumPy arrays into tensors first,
+    _then_ perform the type casting.
+
+    Parameters
+    ----------
+    data_dict : DataDict
+        Dictionary of data to recurse through.
+    dtype : torch.dtype
+        Data type to convert to.
+    ignore_keys : list[str]
+        Keys to ignore in the process; useful for excluding
+        casting for certain things like ``atomic_numbers``
+        that are intended to be ``torch.long`` from being
+        erroneously cast to floats.
+    convert_numpy : bool, default True
+        If True, converts NumPy arrays into PyTorch tensors
+        before performing type casting.
+
+    Returns
+    -------
+    DataDict
+        Data dictionary with type-cast results.
+    """
+    for key, value in data_dict.items():
+        if ignore_keys and key in ignore_keys:
+            continue
+        # optionally convert numpy arrays into torch tensors
+        # prior to type casting
+        if isinstance(value, np.ndarray) and convert_numpy:
+            value = torch.from_numpy(value)
+        if isinstance(value, dict):
+            data_dict[key] = recursive_type_cast(value, dtype, ignore_keys, convert_numpy)
+        if isinstance(value, torch.Tensor):
+            data_dict[key] = value.to(dtype)
+    return data_dict
+
+
+def _checkpoint_conversion_exist_check(ckpt_path: str | Path) -> Path:
+    """Standardizes and checks for checkpoint path existence."""
+    if isinstance(ckpt_path, str):
+        ckpt_path = Path(ckpt_path)
+    if not ckpt_path.exists():
+        raise FileNotFoundError(f"Checkpoint file not found; passed {ckpt_path}")
+    return ckpt_path
+
+
+class MatSciMLCalculator(Calculator):
+    implemented_properties = ["energy", "forces", "stress", "dipole"]
+
+    def __init__(
+        self,
+        task_module: ScalarRegressionTask
+        | GradFreeForceRegressionTask
+        | ForceRegressionTask
+        | MultiTaskLitModule,
+        transforms: list[AbstractDataTransform | Callable] | None = None,
+        restart=None,
+        label=None,
+        atoms: Atoms | None = None,
+        directory=".",
+        conversion_factor: float | dict[str, float] = 1.0,
+        **kwargs,
+    ):
+        """
+        Initialize an instance of the ``MatSciMLCalculator`` used by ``ase``
+        simulations.
+
+        This class essentially acts as an adaptor to a select number of
+        ``matsciml`` tasks by converting ``Atoms`` data structures into
+        those expected by ``matsciml`` models, and then extracting the
+        output of the forward pass into the expected results dictionary
+        for ``ase``.
+
+        The recommended mode of usage of this class is to use one of the
+        constructor methods, e.g. ``MatSciMLCalculator.from_pretrained_force_regression``,
+        to set up the calculator based on one of the supported tasks.
+        A list of transforms can be passed as well in order to reuse the
+        same transformation pipeline as the rest of ``matsciml``.
+
+        Examples
+        --------
+        Create from a pretrained ``ForceRegressionTask``
+
+        >>> calc = MatSciMLCalculator.from_pretrained_force_regression(
+        ...     "lightning_logs/version_10/checkpoints/epoch=10-step=3000.ckpt",
+        ...     transforms=[PeriodicPropertiesTransform(6.0), PointCloudToGraphTransform("pyg")]
+        ... )
+
+        Parameters
+        ----------
+        task_module
+            Instance of a supported ``matsciml`` task. What is 'supported' is
+            intended to reflect the kinds of modeling tasks, e.g. energy/force
+            prediction.
+        transforms : list[AbstractDataTransform | Callable] | None, default None
+            An optional list of transforms, similar to what is used in the rest
+            of the ``matsciml`` pipeline.
+        restart
+            Argument passed into the ``ase`` Calculator base class.
+        label
+            Argument passed into the ``ase`` Calculator base class.
+        atoms : Atoms | None, default None
+            Optional ``Atoms`` object to attach this calculator to.
+        directory
+            Argument passed into the ``ase`` Calculator base class.
+        conversion_factor : float | dict[str, float]
+            Conversion factors for each property, specified as key/value
+            pairs where keys refer to data in ``self.results`` reported
+            to ``ase``. If a single ``float`` is passed, we assume that
+            the conversion is applied to the energy output. Each factor
+            is multiplied with the result.
+        """
+        super().__init__(
+            restart, label=label, atoms=atoms, directory=directory, **kwargs
+        )
+        assert isinstance(
+            task_module,
+            (
+                ForceRegressionTask,
+                ScalarRegressionTask,
+                GradFreeForceRegressionTask,
+                MultiTaskLitModule,
+            ),
+        ), f"Expected task to be one that is capable of energy/force prediction. Got {type(task_module)}."
+        if isinstance(task_module, MultiTaskLitModule):
+            assert any(
+                [
+                    isinstance(
+                        subtask,
+                        (
+                            ForceRegressionTask,
+                            ScalarRegressionTask,
+                            GradFreeForceRegressionTask,
+                        ),
+                    )
+                    for subtask in task_module.task_list
+                ]
+            ), "Expected at least one subtask to be an energy/force predictor."
+        self.task_module = task_module
+        self.transforms = transforms
+        self.conversion_factor = conversion_factor
+
+    @property
+    def conversion_factor(self) -> dict[str, float]:
+        return self._conversion_factor
+
+    @conversion_factor.setter
+    def conversion_factor(self, factor: float | dict[str, float]) -> None:
+        if isinstance(factor, float):
+            factor = {"energy": factor}
+        for key in factor.keys():
+            if key not in self.implemented_properties:
+                raise KeyError(
+                    f"Conversion factor {key} is not in `implemented_properties`."
+                )
+        self._conversion_factor = factor
+
+    @property
+    def dtype(self) -> torch.dtype | str:
+        dtype = self.task_module.dtype
+        return dtype
+
+    def _format_atoms(self, atoms: Atoms) -> DataDict:
+        data_dict = {}
+        pos = torch.from_numpy(atoms.get_positions())
+        atomic_numbers = torch.LongTensor(atoms.get_atomic_numbers())
+        cell = torch.from_numpy(atoms.get_cell(complete=True).array)
+        # add properties to the data dict
+        data_dict["pos"] = pos
+        data_dict["atomic_numbers"] = atomic_numbers
+        data_dict["cell"] = cell
+        return data_dict
+
+    def _format_pipeline(self, atoms: Atoms) -> DataDict:
+        """
+        Main function that takes an ``ase.Atoms`` object and gets it
+        ready for matsciml model consumption.
+
+        We call ``_format_atoms`` to get the data in a format that
+        is similar to what comes out of the datasets implemented in
+        matsciml, so that the remainder of the transform pipeline can
+        be used to obtain nominally the same behavior as you would in
+        the rest of the pipeline.
+        """
+        # initial formatting to get something akin to dataset outputs
+        data_dict = self._format_atoms(atoms)
+        # type cast into the type expected by the model
+        data_dict = recursive_type_cast(
+            data_dict, self.dtype, ignore_keys=["atomic_numbers"], convert_numpy=True
+        )
+        # now run through the same transform pipeline as for datasets
+        if self.transforms:
+            for transform in self.transforms:
+                data_dict = transform(data_dict)
+        return data_dict
+
+    def calculate(
+        self,
+        atoms=None,
+        properties: list[Literal["energy", "forces"]] = ["energy", "forces"],
+        system_changes=...,
+    ) -> None:
+        # retrieve atoms even if not passed
+        Calculator.calculate(self, atoms)
+        # get into format ready for matsciml model
+        data_dict = self._format_pipeline(atoms)
+        # run the data structure through the model
+        output = self.task_module(data_dict)
+        # add outputs to self.results as expected by ase
+        if "energy" in output:
+            self.results["energy"] = output["energy"].detach().item()
+        if "force" in output:
+            self.results["forces"] = output["force"].detach().numpy()
+        if "stress" in output:
+            self.results["stress"] = output["stress"].detach().numpy()
+        if "dipole" in output:
+            self.results["dipole"] = output["dipole"].detach().numpy()
+        if len(self.results) == 0:
+            raise RuntimeError(
+                f"No expected properties were written. Output dict: {output}"
+            )
+        # perform optional unit conversions
+        for key, value in self.conversion_factor.items():
+            if key in self.results:
+                self.results[key] *= value
+
+    @classmethod
+    def from_pretrained_force_regression(
+        cls, ckpt_path: str | Path, *args, **kwargs
+    ) -> MatSciMLCalculator:
+        ckpt_path = _checkpoint_conversion_exist_check(ckpt_path)
+        task = ForceRegressionTask.load_from_checkpoint(ckpt_path)
+        return cls(task, *args, **kwargs)
+
+    @classmethod
+    def from_pretrained_gradfree_task(
+        cls, ckpt_path: str | Path, *args, **kwargs
+    ) -> MatSciMLCalculator:
+        ckpt_path = _checkpoint_conversion_exist_check(ckpt_path)
+        task = GradFreeForceRegressionTask.load_from_checkpoint(ckpt_path)
+        return cls(task, *args, **kwargs)
+
+    @classmethod
+    def from_pretrained_scalar_task(
+        cls, ckpt_path: str | Path, *args, **kwargs
+    ) -> MatSciMLCalculator:
+        ckpt_path = _checkpoint_conversion_exist_check(ckpt_path)
+        task = ScalarRegressionTask.load_from_checkpoint(ckpt_path)
+        return cls(task, *args, **kwargs)
diff --git a/matsciml/interfaces/ase/tests/test_ase_calc.py b/matsciml/interfaces/ase/tests/test_ase_calc.py
new file mode 100644
index 00000000..fabf3303
--- /dev/null
+++ b/matsciml/interfaces/ase/tests/test_ase_calc.py
@@ -0,0 +1,69 @@
+from __future__ import annotations
+
+import pytest
+import numpy as np
+from ase import Atoms, units
+from ase.md.verlet import VelocityVerlet
+
+from matsciml.datasets.transforms import (
+    PeriodicPropertiesTransform,
+    PointCloudToGraphTransform,
+)
+from matsciml.interfaces.ase import MatSciMLCalculator
+from matsciml.models.base import (
+    ForceRegressionTask,
+)
+from matsciml.models.pyg import EGNN
+
+np.random.seed(21516136)
+
+
+@pytest.fixture
+def test_molecule() -> Atoms:
+    pos = np.random.normal(0.0, 1.0, size=(10, 3))
+    atomic_numbers = np.random.randint(1, 100, size=(10,))
+    return Atoms(numbers=atomic_numbers, positions=pos)
+
+
+@pytest.fixture
+def test_pbc() -> Atoms:
+    pos = np.random.normal(0.0, 1.0, size=(16, 3))
+    atomic_numbers = np.random.randint(1, 100, size=(16,))
+    cell = np.eye(3).astype(float)
+    return Atoms(numbers=atomic_numbers, positions=pos, cell=cell)
+
+
+@pytest.fixture
+def pbc_transform() -> list:
+    return [PeriodicPropertiesTransform(6.0, True), PointCloudToGraphTransform("pyg")]
+
+
+@pytest.fixture
+def egnn_config():
+    return {"hidden_dim": 32, "output_dim": 32}
+
+
+def test_egnn_energy_forces(egnn_config: dict, test_pbc: Atoms, pbc_transform: list):
+    """Just get the energy and force out of a ForceRegressionTask."""
+    task = ForceRegressionTask(
+        encoder_class=EGNN, encoder_kwargs=egnn_config, output_kwargs={"hidden_dim": 32}
+    )
+    calc = MatSciMLCalculator(task, transforms=pbc_transform)
+    atoms = test_pbc.copy()
+    atoms.calc = calc
+    energy = atoms.get_potential_energy()
+    assert np.isfinite(energy)
+    forces = atoms.get_forces()
+    assert np.isfinite(forces).all()
+
+
+def test_egnn_dynamics(egnn_config: dict, test_pbc: Atoms, pbc_transform: list):
+    """Run a few timesteps of MD to test the workflow end-to-end."""
+    task = ForceRegressionTask(
+        encoder_class=EGNN, encoder_kwargs=egnn_config, output_kwargs={"hidden_dim": 32}
+    )
+    calc = MatSciMLCalculator(task, transforms=pbc_transform)
+    atoms = test_pbc.copy()
+    atoms.calc = calc
+    dyn = VelocityVerlet(atoms, timestep=5 * units.fs, logfile="md.log")
+    dyn.run(3)
diff --git a/matsciml/models/base.py b/matsciml/models/base.py
index e462bfef..820be959 100644
--- a/matsciml/models/base.py
+++ b/matsciml/models/base.py
@@ -37,6 +37,7 @@
     "OpenCatalystInference",
     "IS2REInference",
     "S2EFInference",
+    "BaseTaskModule",
 ]
 
 """
@@ -1108,7 +1109,9 @@ def __init__(
     def _make_output_heads(self) -> nn.ModuleDict:
         modules = {}
         for key in self.task_keys:
-            modules[key] = OutputHead(1, **self.output_kwargs).to(self.device)
+            modules[key] = OutputHead(1, **self.output_kwargs).to(
+                self.device, dtype=self.dtype
+            )
         return nn.ModuleDict(modules)
 
     def _filter_task_keys(
@@ -1288,7 +1291,9 @@ def _compute_losses(
     def _make_output_heads(self) -> nn.ModuleDict:
         modules = {}
         for key in self.task_keys:
-            modules[key] = OutputHead(**self.output_kwargs[key]).to(self.device)
+            modules[key] = OutputHead(**self.output_kwargs[key]).to(
+                self.device, dtype=self.dtype
+            )
         return nn.ModuleDict(modules)
 
     def _filter_task_keys(
@@ -1445,7 +1450,9 @@ def __init__(
     def _make_output_heads(self) -> nn.ModuleDict:
         modules = {}
         for key in self.task_keys:
-            modules[key] = OutputHead(1, **self.output_kwargs).to(self.device)
+            modules[key] = OutputHead(1, **self.output_kwargs).to(
+                self.device, dtype=self.dtype
+            )
         return nn.ModuleDict(modules)
 
     def on_train_batch_start(self, batch: Any, batch_idx: int) -> int | None:
@@ -1520,7 +1527,11 @@ def __init__(
 
     def _make_output_heads(self) -> nn.ModuleDict:
         # this task only utilizes one output head
-        modules = {"energy": OutputHead(1, **self.output_kwargs).to(self.device)}
+        modules = {
+            "energy": OutputHead(1, **self.output_kwargs).to(
+                self.device, dtype=self.dtype
+            )
+        }
         return nn.ModuleDict(modules)
 
     def forward(
@@ -1601,7 +1612,9 @@ def readout(node_energies: torch.Tensor):
             )
         else:
             # assumes a batched pyg graph
-            batch = graph.batch
+            batch = getattr(graph, "batch", None)
+            if batch is None:
+                batch = torch.zeros_like(graph.atomic_numbers)
             from torch_geometric.utils import scatter
 
             def readout(node_energies: torch.Tensor):
@@ -1843,7 +1856,11 @@ def __init__(
         )
 
     def _make_output_heads(self) -> nn.ModuleDict:
-        modules = {"force": OutputHead(3, **self.output_kwargs).to(self.device)}
+        modules = {
+            "force": OutputHead(3, **self.output_kwargs).to(
+                self.device, dtype=self.dtype
+            )
+        }
         return nn.ModuleDict(modules)
 
     def _get_targets(
@@ -1979,7 +1996,11 @@ def __init__(
 
     def _make_output_heads(self) -> nn.ModuleDict:
         # this task only utilizes one output head; 230 possible space groups
-        modules = {"spacegroup": OutputHead(230, **self.output_kwargs).to(self.device)}
+        modules = {
+            "spacegroup": OutputHead(230, **self.output_kwargs).to(
+                self.device, dtype=self.dtype
+            )
+        }
         return nn.ModuleDict(modules)
 
     def on_train_batch_start(self, batch: Any, batch_idx: int) -> int | None:
@@ -3030,7 +3051,7 @@ def __init__(
 
     def _make_output_heads(self) -> nn.ModuleDict:
         # make a single output head for noise prediction applied to nodes
-        denoise = OutputHead(3, **self.output_kwargs).to(self.device)
+        denoise = OutputHead(3, **self.output_kwargs).to(self.device, dtype=self.dtype)
         return nn.ModuleDict({"denoise": denoise})
 
     def _filter_task_keys(
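
Beyond the bundled example scripts, the `conversion_factor` argument of `MatSciMLCalculator.__init__` deserves a quick illustration. The sketch below is not part of the patch: the checkpoint path is a placeholder, `ase.build.bulk` is just a convenient structure source, and the eV-to-kcal/mol factor is one plausible choice built from `ase.units`. A bare float would rescale only the energy, while a dict can target any key in `implemented_properties`.

```python
from ase import units
from ase.build import bulk

from matsciml.interfaces.ase import MatSciMLCalculator
from matsciml.datasets.transforms import (
    PeriodicPropertiesTransform,
    PointCloudToGraphTransform,
)

# 'my_model.ckpt' is a placeholder; substitute your own checkpoint path
calc = MatSciMLCalculator.from_pretrained_force_regression(
    "my_model.ckpt",
    transforms=[
        PeriodicPropertiesTransform(6.0, True),
        PointCloudToGraphTransform("pyg"),
    ],
    # each factor multiplies the corresponding entry in self.results;
    # mol/kcal converts an energy in eV to kcal/mol
    conversion_factor={"energy": units.mol / units.kcal},
)

atoms = bulk("Cu", "fcc", a=3.6)
atoms.calc = calc
print(atoms.get_potential_energy())  # reported in kcal/mol
```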
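The behavior of `recursive_type_cast` can likewise be pinned down with a small, self-contained check; this is an illustrative snippet rather than a test from the patch. It exercises the three code paths: NumPy arrays converted then cast, keys in `ignore_keys` skipped, and nested dictionaries recursed into.

```python
import numpy as np
import torch

from matsciml.interfaces.ase.base import recursive_type_cast

data = {
    "pos": np.random.normal(0.0, 1.0, size=(4, 3)),         # NumPy -> tensor -> cast
    "atomic_numbers": torch.LongTensor([1, 6, 6, 8]),        # in ignore_keys, untouched
    "extras": {"cell": torch.eye(3, dtype=torch.float64)},   # nested dict recursed
}
out = recursive_type_cast(data, torch.float32)

assert out["pos"].dtype == torch.float32
assert out["atomic_numbers"].dtype == torch.long
assert out["extras"]["cell"].dtype == torch.float32
```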
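Finally, the `dtype=self.dtype` additions to the various `_make_output_heads` implementations keep the output heads in step with the task's overall dtype, which is what lets the calculator's `recursive_type_cast(..., self.dtype)` work end to end: the raw `Atoms` data is cast to the task dtype before the transform pipeline runs, so the encoder and heads must agree with it. A sketch of running the calculator in double precision follows, assuming the chosen encoder (EGNN here) tolerates float64 throughout.

```python
import torch

from matsciml.interfaces.ase import MatSciMLCalculator
from matsciml.models.base import ForceRegressionTask
from matsciml.models.pyg import EGNN
from matsciml.datasets.transforms import (
    PeriodicPropertiesTransform,
    PointCloudToGraphTransform,
)

task = ForceRegressionTask(
    encoder_class=EGNN,
    encoder_kwargs={"hidden_dim": 32, "output_dim": 32},
    output_kwargs={"lazy": False, "input_dim": 32, "hidden_dim": 32, "num_hidden": 3},
)
# convert every parameter (encoder and output heads alike) to float64;
# MatSciMLCalculator.dtype mirrors the task dtype, so inputs are cast
# to float64 before the transforms run in _format_pipeline
task = task.double()
assert task.dtype == torch.float64

calc = MatSciMLCalculator(
    task,
    transforms=[
        PeriodicPropertiesTransform(6.0, True),
        PointCloudToGraphTransform("pyg"),
    ],
)
```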