diff --git a/examples/interfaces/ase_multitask_from_pretrained.py b/examples/interfaces/ase_multitask_from_pretrained.py
new file mode 100644
index 00000000..8ca5889e
--- /dev/null
+++ b/examples/interfaces/ase_multitask_from_pretrained.py
@@ -0,0 +1,42 @@
+from __future__ import annotations
+
+from ase import Atoms, units
+from ase.md.verlet import VelocityVerlet
+
+from matsciml.interfaces.ase import MatSciMLCalculator
+from matsciml.interfaces.ase.multitask import AverageTasks
+from matsciml.datasets.transforms import (
+    PeriodicPropertiesTransform,
+    PointCloudToGraphTransform,
+)
+
+"""
+Demonstrates setting up a calculator from a pretrained
+multitask/multidata module, using an averaging strategy to
+merge output heads.
+
+As an example, if we trained force regression on multiple datasets
+simultaneously, we would average the outputs from each "dataset",
+similar to an ensemble prediction without any special weighting.
+
+Substitute 'model.ckpt' with the path to a real checkpoint file.
+"""
+
+d = 2.9
+L = 10.0
+
+atoms = Atoms("C", positions=[[0, L / 2, L / 2]], cell=[d, L, L], pbc=[1, 0, 0])
+
+calc = MatSciMLCalculator.from_pretrained_force_regression(
+    "model.ckpt",
+    transforms=[
+        PeriodicPropertiesTransform(6.0, True),
+        PointCloudToGraphTransform("pyg"),
+    ],
+    multitask_strategy=AverageTasks(),  # can also be specified as a string
+)
+# attach the matsciml calculator to the atoms object
+atoms.calc = calc
+# run the simulation for 100 steps with a 5 fs timestep
+dyn = VelocityVerlet(atoms, timestep=5 * units.fs, logfile="md.log")
+dyn.run(100)
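Note: as the inline comment above mentions, the strategy may equivalently be passed as a string, which MatSciMLCalculator resolves by name against matsciml.interfaces.ase.multitask and rejects with a NameError if unknown (see base.py below). A minimal variant of the same call, with the transforms elided:

calc = MatSciMLCalculator.from_pretrained_force_regression(
    "model.ckpt",
    transforms=[...],  # same transforms as in the example above
    multitask_strategy="AverageTasks",
)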
Output dict: {output}" - ) + if isinstance(self.task_module, MultiTaskLitModule): + output = self.task_module.ase_calculate(data_dict) + # use a more complicated parser for multitasks + results = self.multitask_strategy(output, self.task_module) + self.results = results + else: + output = self.task_module(data_dict) + # add outputs to self.results as expected by ase + if "energy" in output: + self.results["energy"] = output["energy"].detach().item() + if "force" in output: + self.results["forces"] = output["force"].detach().numpy() + if "stress" in output: + self.results["stress"] = output["stress"].detach().numpy() + if "dipole" in output: + self.results["dipole"] = output["dipole"].detach().numpy() + if len(self.results) == 0: + raise RuntimeError( + f"No expected properties were written. Output dict: {output}" + ) # perform optional unit conversions for key, value in self.conversion_factor.items(): if key in self.results: diff --git a/matsciml/interfaces/ase/multitask.py b/matsciml/interfaces/ase/multitask.py new file mode 100644 index 00000000..6a8bbef1 --- /dev/null +++ b/matsciml/interfaces/ase/multitask.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from abc import abstractmethod, ABC + +import torch +import numpy as np + +from matsciml.models.base import ( + MultiTaskLitModule, +) +from matsciml.common.types import DataDict + + +__task_property_mapping__ = { + "ScalarRegressionTask": ["energy", "dipole"], + "ForceRegressionTask": ["energy", "force"], + "GradFreeForceRegressionTask": ["force"], +} + + +__all__ = ["AverageTasks"] + + +class AbstractStrategy(ABC): + @abstractmethod + def merge_outputs( + self, + outputs: dict[str, dict[str, float | torch.Tensor]] + | dict[str, list[float | torch.Tensor]], + *args, + **kwargs, + ) -> dict[str, float | np.ndarray]: ... + + def parse_outputs( + self, output_dict: DataDict, task: MultiTaskLitModule, *args, **kwargs + ) -> tuple[ + dict[str, dict[str, float | torch.Tensor]], + dict[str, list[float | torch.Tensor]], + ]: + """ + Map the task results into their appropriate fields. + + Expected output looks like: + {"IS2REDataset": {"energy": ..., "forces": ...}, ...} + + Parameters + ---------- + output_dict : DataDict + Multitask/multidata output from the ``MultiTaskLitModule`` + forward pass. + task : MultiTaskLitModule + Instance of the task module. This allows access to the + ``task.task_map``, which tells us which dataset/subtask + is mapped together. + + Returns + ------- + dict[str, dict[str, float | torch.Tensor]] + Dictionary mapping of results per dataset. The subdicts + correspond to the extracted outputs, per subtask (e.g. + energy/force from the IS2REDataset head). + dict[str, list[float | torch.Tensor]] + For convenience, this provides the same data without + differentiating between datasets, and instead, sorts + them by the property name (e.g. {"energy": [...]}). + + Raises + ------ + RuntimeError: + When no subresults are returned for a dataset that is + expected to have something on the basis that a task + _should_ produce something, e.g. ``ForceRegressionTask`` + should yield energy/force, and if it doesn't produce + anything, something is wrong. 
+ """ + results = {} + per_key_results = {} + # loop over the task map + for dset_name in task.task_map.keys(): + for subtask_name, subtask in task.task_map[dset_name].items(): + sub_results = {} + pos_fields = __task_property_mapping__.get(subtask_name, None) + if pos_fields is None: + continue + else: + for key in pos_fields: + output = output_dict[dset_name][subtask_name].get(key, None) + # this means the task _can_ output the key but was + # not included in the actual training task keys + if output is None: + continue + if isinstance(output, torch.Tensor): + output = output.detach() + if key == "energy": + # squeeze is applied just in case we have too many + # extra dimensions + output = output.squeeze().item() + sub_results[key] = output + # add to per_key_results as another sorting + if key not in per_key_results: + per_key_results[key] = [] + per_key_results[key].append(output) + if len(sub_results) == 0: + raise RuntimeError( + f"Expected {subtask_name} to have {pos_fields} but got nothing." + ) + results[dset_name] = sub_results + return results, per_key_results + + @abstractmethod + def run(self, output_dict: DataDict, task: MultiTaskLitModule, *args, **kwargs): ... + + def __call__( + self, output_dict: DataDict, task: MultiTaskLitModule, *args, **kwargs + ) -> dict[str, float | np.ndarray]: + aggregated_results = self.run(output_dict, task, *args, **kwargs) + # TODO: homogenize keys so we don't have to do stuff like this :P + if "force" in aggregated_results: + aggregated_results["forces"] = aggregated_results["force"] + return aggregated_results + + +class AverageTasks(AbstractStrategy): + def merge_outputs( + self, outputs: dict[str, list[float | torch.Tensor]], *args, **kwargs + ) -> dict[str, float | np.ndarray]: + joined_results = {} + for key, results in outputs.items(): + if isinstance(results[0], float): + merged_results = sum(results) / len(results) + elif isinstance(results[0], torch.Tensor): + results = torch.stack(results, dim=0) + merged_results = results.mean(dim=0).numpy() + else: + raise TypeError( + f"Only floats and tensors are supported for merging; got {type(results[0])} for key {key}." 
diff --git a/matsciml/interfaces/ase/tests/test_multi_task.py b/matsciml/interfaces/ase/tests/test_multi_task.py
new file mode 100644
index 00000000..a08bc617
--- /dev/null
+++ b/matsciml/interfaces/ase/tests/test_multi_task.py
@@ -0,0 +1,175 @@
+from __future__ import annotations
+
+import pytest
+import torch
+import numpy as np
+from ase import Atoms, units
+from ase.md import VelocityVerlet
+
+from matsciml.models.pyg import EGNN
+from matsciml.models.base import (
+    MultiTaskLitModule,
+    ScalarRegressionTask,
+    ForceRegressionTask,
+)
+from matsciml.datasets.transforms import (
+    PeriodicPropertiesTransform,
+    PointCloudToGraphTransform,
+)
+from matsciml.interfaces.ase import multitask as mt
+from matsciml.interfaces.ase import MatSciMLCalculator
+
+
+@pytest.fixture
+def test_pbc() -> Atoms:
+    pos = np.random.normal(0.0, 1.0, size=(16, 3)) * 10.0
+    atomic_numbers = np.random.randint(1, 100, size=(16,))
+    cell = np.eye(3).astype(float)
+    return Atoms(numbers=atomic_numbers, positions=pos, cell=cell)
+
+
+@pytest.fixture
+def pbc_transform() -> list:
+    return [PeriodicPropertiesTransform(6.0, True), PointCloudToGraphTransform("pyg")]
+
+
+@pytest.fixture
+def egnn_args():
+    return {"hidden_dim": 32, "output_dim": 32}
+
+
+@pytest.fixture
+def single_data_multi_task_combo(egnn_args):
+    output = {
+        "IS2REDataset": {
+            "ScalarRegressionTask": {"energy": torch.rand(1, 1)},
+            "ForceRegressionTask": {
+                "energy": torch.rand(1, 1),
+                "force": torch.rand(32, 3),
+            },
+        }
+    }
+    task = MultiTaskLitModule(
+        (
+            "IS2REDataset",
+            ScalarRegressionTask(
+                encoder_class=EGNN, encoder_kwargs=egnn_args, task_keys=["energy"]
+            ),
+        ),
+        (
+            "IS2REDataset",
+            ForceRegressionTask(
+                encoder_class=EGNN,
+                encoder_kwargs=egnn_args,
+                output_kwargs={"lazy": False, "input_dim": 32},
+            ),
+        ),
+    )
+    return output, task
+
+
+@pytest.fixture
+def multi_data_multi_task_combo(egnn_args):
+    output = {
+        "IS2REDataset": {
+            "ScalarRegressionTask": {"energy": torch.rand(1, 1)},
+            "ForceRegressionTask": {
+                "energy": torch.rand(1, 1),
+                "force": torch.rand(32, 3),
+            },
+        },
+        "S2EFDataset": {
+            "ForceRegressionTask": {
+                "energy": torch.rand(1, 1),
+                "force": torch.rand(32, 3),
+            }
+        },
+        "AlexandriaDataset": {
+            "ForceRegressionTask": {
+                "energy": torch.rand(1, 1),
+                "force": torch.rand(32, 3),
+            }
+        },
+    }
+    task = MultiTaskLitModule(
+        (
+            "IS2REDataset",
+            ScalarRegressionTask(
+                encoder_class=EGNN,
+                encoder_kwargs=egnn_args,
+                task_keys=["energy"],
+                output_kwargs={"lazy": False, "hidden_dim": 32, "input_dim": 32},
+            ),
+        ),
+        (
+            "IS2REDataset",
+            ForceRegressionTask(
+                encoder_class=EGNN,
+                encoder_kwargs=egnn_args,
+                output_kwargs={"lazy": False, "hidden_dim": 32, "input_dim": 32},
+            ),
+        ),
+        (
+            "S2EFDataset",
+            ForceRegressionTask(
+                encoder_class=EGNN,
+                encoder_kwargs=egnn_args,
+                output_kwargs={"lazy": False, "hidden_dim": 32, "input_dim": 32},
+            ),
+        ),
+        (
+            "AlexandriaDataset",
+            ForceRegressionTask(
+                encoder_class=EGNN,
+                encoder_kwargs=egnn_args,
+                output_kwargs={"lazy": False, "hidden_dim": 32, "input_dim": 32},
+            ),
+        ),
+    )
+    return output, task
+
+
+def test_average_single_data(single_data_multi_task_combo):
+    # unpack the fixture
+    output, task = single_data_multi_task_combo
+    strat = mt.AverageTasks()
+    # test the parsing
+    _, parsed_output = strat.parse_outputs(output, task)
+    agg_results = strat.merge_outputs(parsed_output)
+    end = strat(output, task)
+    assert end
+    assert agg_results
+    for key in ["energy", "forces"]:
+        assert key in end, f"{key} was missing from agg results"
+    assert end["forces"].shape == (32, 3)
+
+
+def test_average_multi_data(multi_data_multi_task_combo):
+    # unpack the fixture
+    output, task = multi_data_multi_task_combo
+    strat = mt.AverageTasks()
+    # test the parsing
+    _, parsed_output = strat.parse_outputs(output, task)
+    agg_results = strat.merge_outputs(parsed_output)
+    end = strat(output, task)
+    assert end
+    assert agg_results
+    for key in ["energy", "forces"]:
+        assert key in end, f"{key} was missing from agg results"
+    assert end["forces"].shape == (32, 3)
+
+
+def test_calc_multi_data(
+    multi_data_multi_task_combo, test_pbc: Atoms, pbc_transform: list
+):
+    output, task = multi_data_multi_task_combo
+    strat = mt.AverageTasks()
+    calc = MatSciMLCalculator(task, multitask_strategy=strat, transforms=pbc_transform)
+    atoms = test_pbc.copy()
+    atoms.calc = calc
+    energy = atoms.get_potential_energy()
+    assert np.isfinite(energy)
+    forces = atoms.get_forces()
+    assert np.isfinite(forces).all()
+    dyn = VelocityVerlet(atoms, timestep=5 * units.fs, logfile="md.log")
+    dyn.run(3)
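To make the averaging semantics of the fixtures above concrete: each of the three dataset heads emits an energy and a (32, 3) force tensor, and AverageTasks reduces them to one scalar and one (32, 3) array, mirroring merge_outputs:

import torch

energies = [0.1, 0.5, 0.9]                    # one scalar per dataset head
forces = [torch.rand(32, 3) for _ in range(3)]

avg_energy = sum(energies) / len(energies)            # 0.5
avg_forces = torch.stack(forces, dim=0).mean(dim=0)   # still shape (32, 3)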
diff --git a/matsciml/models/base.py b/matsciml/models/base.py
index 820be959..54e86edd 100644
--- a/matsciml/models/base.py
+++ b/matsciml/models/base.py
@@ -2107,7 +2107,7 @@ def __init__(
         if index != 0:
             task.encoder = self.encoder
         # nest the task based on its category
-        task_map[dset_name][task.__task__] = task
+        task_map[dset_name][task.__class__.__name__] = task
         # add dataset names to determine forward logic
         dset_names.add(dset_name)
         # save hyperparameters from subtasks
@@ -2357,7 +2357,9 @@ def _toggle_input_grads(
         """
         need_grad_keys = getattr(self, "input_grad_keys", None)
         if need_grad_keys is not None:
-            if self.is_multidata:
+            # we determine whether this is a multidata batch from the batch
+            # itself, as its keys should contain dataset names
+            if any("Dataset" in key for key in batch.keys()):
                 # if this is a multidataset task, loop over each dataset
                 # and enable gradients for the inputs that need them
                 for dset_name, data in batch.items():
@@ -2365,7 +2367,7 @@
                 for key in input_keys:
                     # set require grad for both point cloud and graph tensors
                     if "graph" in data:
-                        g = data.get("g")
+                        g = data.get("graph")
                        if isinstance(g, dgl.DGLGraph):
                            if key in g.ndata:
                                data["graph"].ndata[key].requires_grad_(True)
@@ -2390,17 +2392,17 @@
             input_keys = list(self.input_grad_keys.values()).pop(0)
             for key in input_keys:
                 # set require grad for both point cloud and graph tensors
-                if "graph" in data:
-                    g = data.get("g")
+                if "graph" in batch:
+                    g = batch.get("graph")
                     if isinstance(g, dgl.DGLGraph):
                         if key in g.ndata:
-                            data["graph"].ndata[key].requires_grad_(True)
+                            batch["graph"].ndata[key].requires_grad_(True)
                     else:
                         # assume it's a PyG graph
                         if key in g:
                             getattr(g, key).requires_grad_(True)
-                if key in data:
-                    target = data.get(key)
+                if key in batch:
+                    target = batch.get(key)
                     # for tensors just set them directly
                     if isinstance(target, torch.Tensor):
                         target.requires_grad_(True)
@@ -2469,6 +2471,53 @@
             results[task_type] = subtask(batch)
         return results
 
+    def ase_calculate(self, batch: BatchDict) -> dict[str, dict[str, torch.Tensor]]:
+        """
+        Currently "specialized" function that runs a set of data through
+        every single output head, ignoring the nominal dataset/subtask
+        unique mapping.
+
+        This is designed primarily for ASE usage, but could ostensibly
+        be used as _the_ inference call for a multitask module: when
+        the input data doesn't come from the same "datasets" used for
+        initialization/training, this provides a "mixture of experts"
+        style response.
+
+        TODO: this could potentially be used as a template to redesign
+        the forward call to substantially simplify the multitask mapping.
+
+        Parameters
+        ----------
+        batch
+            Input data dictionary, which should correspond to a formatted
+            ase.Atoms sample.
+
+        Returns
+        -------
+        dict[str, dict[str, torch.Tensor]]
+            Nested results dictionary, following a dataset/subtask structure.
+            For example, {'IS2REDataset': {'ForceRegressionTask': ..., 'ScalarRegressionTask': ...}}
+        """
+        results = {}
+        _grads = getattr(
+            self,
+            "needs_dynamic_grads",
+            False,
+        )  # default to not needing grads
+        with dynamic_gradients_context(_grads, self.has_rnn):
+            # this function switches on `requires_grad_` for input tensors that need them
+            self._toggle_input_grads(batch)
+            batch["embedding"] = self.encoder(batch)
+            # now loop through every dataset/output head pair
+            for dset_name, subtask_name in self.dataset_task_pairs:
+                subtask = self.task_map[dset_name][subtask_name]
+                output = subtask(batch)
+                # now add it to the rest of the results
+                if dset_name not in results:
+                    results[dset_name] = {}
+                results[dset_name][subtask_name] = output
+        return results
+
     def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
         """
         This callback is used to dynamically initialize output heads.
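For reference, the multitask branch added to MatSciMLCalculator.calculate() above amounts to two calls, which can also be driven by hand (a sketch; here task is a MultiTaskLitModule and data_dict stands for a formatted ase.Atoms sample produced by the calculator's transform pipeline):

from matsciml.interfaces.ase.multitask import AverageTasks

output = task.ase_calculate(data_dict)
# nested: {"IS2REDataset": {"ForceRegressionTask": {"energy": ..., "force": ...}, ...}, ...}
results = AverageTasks()(output, task)
# flat: {"energy": <float>, "force": <ndarray>, "forces": <ndarray>}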