From c965377f5839d337cc98586d45d8d83db9f74af0 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Thu, 16 May 2024 15:10:43 -0700 Subject: [PATCH 01/19] feat: added a base method for extracting multitask results --- matsciml/interfaces/ase/multitask.py | 83 ++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 matsciml/interfaces/ase/multitask.py diff --git a/matsciml/interfaces/ase/multitask.py b/matsciml/interfaces/ase/multitask.py new file mode 100644 index 00000000..3b142f30 --- /dev/null +++ b/matsciml/interfaces/ase/multitask.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from abc import abstractmethod, ABC + +import torch +import numpy as np + +from matsciml.models.base import ( + MultiTaskLitModule, +) +from matsciml.common.types import DataDict + + +__task_property_mapping__ = { + "ScalarRegressionTask": ["energy", "dipole"], + "ForceRegressionTask": ["energy", "forces"], + "GradFreeForceRegressionTask": ["forces"], +} + + +class AbstractStrategy(ABC): + @abstractmethod + def merge_outputs(self, *args, **kwargs) -> dict[str, float | np.ndarray]: ... + + def parse_outputs( + self, output_dict: DataDict, task: MultiTaskLitModule, *args, **kwargs + ) -> dict[str, dict[str, float | torch.Tensor]]: + """ + Map the task results into their appropriate fields. + + Expected output looks like: + {"IS2REDataset": {"energy": ..., "forces": ...}, ...} + + Parameters + ---------- + output_dict : DataDict + Multitask/multidata output from the ``MultiTaskLitModule`` + forward pass. + task : MultiTaskLitModule + Instance of the task module. This allows access to the + ``task.task_map``, which tells us which dataset/subtask + is mapped together. + + Returns + ------- + dict[str, dict[str, float | torch.Tensor]] + Dictionary mapping of results per dataset. The subdicts + correspond to the extracted outputs, per subtask (e.g. + energy/force from the IS2REDataset head). + + Raises + ------ + RuntimeError: + When no subresults are returned for a dataset that is + expected to have something on the basis that a task + _should_ produce something, e.g. ``ForceRegressionTask`` + should yield energy/force, and if it doesn't produce + anything, something is wrong. + """ + results = {} + # loop over the task map + for dset_name in task.task_map.keys(): + for subtask_name, subtask in task.task_map[dset_name].items(): + sub_results = {} + pos_fields = __task_property_mapping__.get(subtask, None) + if pos_fields is None: + continue + else: + for key in pos_fields: + output = output_dict[dset_name][subtask_name].detach() + if key == "energy": + output = output.item() + sub_results[key] = output + if len(sub_results) == 0: + raise RuntimeError( + f"Expected {subtask_name} to have {pos_fields} but got nothing." + ) + results[dset_name] = sub_results + return results + + +class AverageTasks(AbstractStrategy): + def merge_outputs(self): ... From aed05b2199869040e54cf7cbf173eabc08fa4bbe Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 17 May 2024 08:07:35 -0700 Subject: [PATCH 02/19] refactor: setting signature for merge output function --- matsciml/interfaces/ase/multitask.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/matsciml/interfaces/ase/multitask.py b/matsciml/interfaces/ase/multitask.py index 3b142f30..fd5012cb 100644 --- a/matsciml/interfaces/ase/multitask.py +++ b/matsciml/interfaces/ase/multitask.py @@ -20,7 +20,9 @@ class AbstractStrategy(ABC): @abstractmethod - def merge_outputs(self, *args, **kwargs) -> dict[str, float | np.ndarray]: ... + def merge_outputs( + self, outputs: dict[str, dict[str, float | torch.Tensor]], *args, **kwargs + ) -> dict[str, float | np.ndarray]: ... def parse_outputs( self, output_dict: DataDict, task: MultiTaskLitModule, *args, **kwargs @@ -80,4 +82,6 @@ def parse_outputs( class AverageTasks(AbstractStrategy): - def merge_outputs(self): ... + def merge_outputs( + self, outputs: dict[str, dict[str, float | torch.Tensor]], *args, **kwargs + ) -> dict[str, float | np.ndarray]: ... From 0f3066dc14dda1b8451ec58c280a16889452e76a Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 17 May 2024 08:39:15 -0700 Subject: [PATCH 03/19] refactor: returning per-key results --- matsciml/interfaces/ase/multitask.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/matsciml/interfaces/ase/multitask.py b/matsciml/interfaces/ase/multitask.py index fd5012cb..ab455cc9 100644 --- a/matsciml/interfaces/ase/multitask.py +++ b/matsciml/interfaces/ase/multitask.py @@ -26,7 +26,10 @@ def merge_outputs( def parse_outputs( self, output_dict: DataDict, task: MultiTaskLitModule, *args, **kwargs - ) -> dict[str, dict[str, float | torch.Tensor]]: + ) -> tuple[ + dict[str, dict[str, float | torch.Tensor]], + dict[str, list[float | torch.Tensor]], + ]: """ Map the task results into their appropriate fields. @@ -49,6 +52,10 @@ def parse_outputs( Dictionary mapping of results per dataset. The subdicts correspond to the extracted outputs, per subtask (e.g. energy/force from the IS2REDataset head). + dict[str, list[float | torch.Tensor]] + For convenience, this provides the same data without + differentiating between datasets, and instead, sorts + them by the property name (e.g. {"energy": [...]}). Raises ------ @@ -60,6 +67,7 @@ def parse_outputs( anything, something is wrong. """ results = {} + per_key_results = {} # loop over the task map for dset_name in task.task_map.keys(): for subtask_name, subtask in task.task_map[dset_name].items(): @@ -73,15 +81,21 @@ def parse_outputs( if key == "energy": output = output.item() sub_results[key] = output + # add to per_key_results as another sorting + if key not in per_key_results: + per_key_results[key] = [] + per_key_results[key].append(output) if len(sub_results) == 0: raise RuntimeError( f"Expected {subtask_name} to have {pos_fields} but got nothing." ) results[dset_name] = sub_results - return results + return results, per_key_results class AverageTasks(AbstractStrategy): def merge_outputs( self, outputs: dict[str, dict[str, float | torch.Tensor]], *args, **kwargs - ) -> dict[str, float | np.ndarray]: ... + ) -> dict[str, float | np.ndarray]: + for dset, results in outputs.items(): + ... From a0cb969f9b784ddd7c4071a9fc1201c5b743f55d Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 17 May 2024 08:55:42 -0700 Subject: [PATCH 04/19] refactor: adding abstract run and __call__ --- matsciml/interfaces/ase/multitask.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/matsciml/interfaces/ase/multitask.py b/matsciml/interfaces/ase/multitask.py index ab455cc9..c92f6f91 100644 --- a/matsciml/interfaces/ase/multitask.py +++ b/matsciml/interfaces/ase/multitask.py @@ -21,7 +21,11 @@ class AbstractStrategy(ABC): @abstractmethod def merge_outputs( - self, outputs: dict[str, dict[str, float | torch.Tensor]], *args, **kwargs + self, + outputs: dict[str, dict[str, float | torch.Tensor]] + | dict[str, list[float | torch.Tensor]], + *args, + **kwargs, ) -> dict[str, float | np.ndarray]: ... def parse_outputs( @@ -92,6 +96,14 @@ def parse_outputs( results[dset_name] = sub_results return results, per_key_results + @abstractmethod + def run(self, output_dict: DataDict, task: MultiTaskLitModule, *args, **kwargs): ... + + def __call__( + self, output_dict: DataDict, task: MultiTaskLitModule, *args, **kwargs + ) -> dict[str, float | np.ndarray]: + return self.run(output_dict, task, *args, **kwargs) + class AverageTasks(AbstractStrategy): def merge_outputs( From 907a4034348478721662e119cb5923c7ec05c94b Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 17 May 2024 08:56:10 -0700 Subject: [PATCH 05/19] feat: added merge method for averaging method --- matsciml/interfaces/ase/multitask.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/matsciml/interfaces/ase/multitask.py b/matsciml/interfaces/ase/multitask.py index c92f6f91..69f21a78 100644 --- a/matsciml/interfaces/ase/multitask.py +++ b/matsciml/interfaces/ase/multitask.py @@ -107,7 +107,18 @@ def __call__( class AverageTasks(AbstractStrategy): def merge_outputs( - self, outputs: dict[str, dict[str, float | torch.Tensor]], *args, **kwargs + self, outputs: dict[str, list[float | torch.Tensor]], *args, **kwargs ) -> dict[str, float | np.ndarray]: - for dset, results in outputs.items(): - ... + joined_results = {} + for key, results in outputs.items(): + if isinstance(results[0], float): + merged_results = sum(results) / len(results) + elif isinstance(results[0], torch.Tensor): + results = torch.stack(results, dim=0) + merged_results = results.mean(dim=0).numpy() + else: + raise TypeError( + f"Only floats and tensors are supported for merging; got {type(results[0])} for key {key}." + ) + joined_results[key] = merged_results + return joined_results From d542fa30a290dd78f5bd40ba17a44d6b71325e02 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 17 May 2024 08:58:01 -0700 Subject: [PATCH 06/19] feat: added run call for averaging strategy --- matsciml/interfaces/ase/multitask.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/matsciml/interfaces/ase/multitask.py b/matsciml/interfaces/ase/multitask.py index 69f21a78..35398f54 100644 --- a/matsciml/interfaces/ase/multitask.py +++ b/matsciml/interfaces/ase/multitask.py @@ -122,3 +122,10 @@ def merge_outputs( ) joined_results[key] = merged_results return joined_results + + def run( + self, output_dict: DataDict, task: MultiTaskLitModule, *args, **kwargs + ) -> dict[str, float | np.ndarray]: + _, per_key_results = self.parse_outputs(output_dict, task) + aggregated_results = self.merge_outputs(per_key_results) + return aggregated_results From db53652bc4ad83aa317fc0cc9b00b97d9c0e1deb Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 17 May 2024 09:10:36 -0700 Subject: [PATCH 07/19] feat: defining __all__ in multitask strategies --- matsciml/interfaces/ase/multitask.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/matsciml/interfaces/ase/multitask.py b/matsciml/interfaces/ase/multitask.py index 35398f54..a91a54cc 100644 --- a/matsciml/interfaces/ase/multitask.py +++ b/matsciml/interfaces/ase/multitask.py @@ -18,6 +18,9 @@ } +__all__ = ["AverageTasks"] + + class AbstractStrategy(ABC): @abstractmethod def merge_outputs( From 8c0ed4018b00ca875a9ce12a73db37f7f57ae9cb Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 17 May 2024 09:13:54 -0700 Subject: [PATCH 08/19] refactor: adding multi task strategy interface in calculator --- matsciml/interfaces/ase/base.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/matsciml/interfaces/ase/base.py b/matsciml/interfaces/ase/base.py index c796ed3c..c6b957e6 100644 --- a/matsciml/interfaces/ase/base.py +++ b/matsciml/interfaces/ase/base.py @@ -16,6 +16,7 @@ MultiTaskLitModule, ) from matsciml.datasets.transforms.base import AbstractDataTransform +from matsciml.interfaces.ase import multitask as mt __all__ = ["MatSciMLCalculator"] @@ -92,6 +93,7 @@ def __init__( atoms: Atoms | None = None, directory=".", conversion_factor: float | dict[str, float] = 1.0, + multitask_strategy: str | Callable | mt.AbstractStrategy = "AverageTasks", **kwargs, ): """ @@ -172,6 +174,14 @@ def __init__( self.task_module = task_module self.transforms = transforms self.conversion_factor = conversion_factor + if isinstance(multitask_strategy, str): + cls_name = getattr(mt, multitask_strategy, None) + if cls_name is None: + raise NameError( + f"Invalid multitask strategy name; supported strategies are {mt.__all__}" + ) + multitask_strategy = cls_name() + self.multitask_strategy = multitask_strategy @property def conversion_factor(self) -> dict[str, float]: From 17aa62772c4d71f09b7d82d676a987a781a5274c Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 17 May 2024 09:15:42 -0700 Subject: [PATCH 09/19] refactor: adding multi task strategy application to calculate --- matsciml/interfaces/ase/base.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/matsciml/interfaces/ase/base.py b/matsciml/interfaces/ase/base.py index c6b957e6..46c5388c 100644 --- a/matsciml/interfaces/ase/base.py +++ b/matsciml/interfaces/ase/base.py @@ -249,19 +249,24 @@ def calculate( data_dict = self._format_pipeline(atoms) # run the data structure through the model output = self.task_module(data_dict) - # add outputs to self.results as expected by ase - if "energy" in output: - self.results["energy"] = output["energy"].detach().item() - if "force" in output: - self.results["forces"] = output["force"].detach().numpy() - if "stress" in output: - self.results["stress"] = output["stress"].detach().numpy() - if "dipole" in output: - self.results["dipole"] = output["dipole"].detach().numpy() - if len(self.results) == 0: - raise RuntimeError( - f"No expected properties were written. Output dict: {output}" - ) + # use a more complicated parser for multitasks + if isinstance(self.task_module, MultiTaskLitModule): + results = self.multitask_strategy(output, self.task_module) + self.results = results + else: + # add outputs to self.results as expected by ase + if "energy" in output: + self.results["energy"] = output["energy"].detach().item() + if "force" in output: + self.results["forces"] = output["force"].detach().numpy() + if "stress" in output: + self.results["stress"] = output["stress"].detach().numpy() + if "dipole" in output: + self.results["dipole"] = output["dipole"].detach().numpy() + if len(self.results) == 0: + raise RuntimeError( + f"No expected properties were written. Output dict: {output}" + ) # perform optional unit conversions for key, value in self.conversion_factor.items(): if key in self.results: From c9827733555a81d3b2bb1d36430b281714bcae44 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 17 May 2024 10:41:58 -0700 Subject: [PATCH 10/19] test: added unit tests for multi task aggregations Signed-off-by: Lee, Kin Long Kelvin --- .../interfaces/ase/tests/test_multi_task.py | 117 ++++++++++++++++++ 1 file changed, 117 insertions(+) create mode 100644 matsciml/interfaces/ase/tests/test_multi_task.py diff --git a/matsciml/interfaces/ase/tests/test_multi_task.py b/matsciml/interfaces/ase/tests/test_multi_task.py new file mode 100644 index 00000000..004496ea --- /dev/null +++ b/matsciml/interfaces/ase/tests/test_multi_task.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +import pytest +import torch + +from matsciml.models.pyg import EGNN +from matsciml.models.base import ( + MultiTaskLitModule, + ScalarRegressionTask, + ForceRegressionTask, +) +from matsciml.interfaces.ase import multitask as mt + + +@pytest.fixture +def egnn_args(): + return {"hidden_dim": 32, "output_dim": 32} + + +@pytest.fixture +def single_data_multi_task_combo(egnn_args): + output = { + "IS2REDataset": { + "ScalarRegressionTask": {"energy": torch.rand(1, 1)}, + "ForceRegressionTask": { + "energy": torch.rand(1, 1), + "force": torch.rand(32, 3), + }, + } + } + task = MultiTaskLitModule( + ( + "IS2REDataset", + ScalarRegressionTask( + encoder_class=EGNN, encoder_kwargs=egnn_args, task_keys=["energy"] + ), + ), + ( + "IS2REDataset", + ForceRegressionTask( + encoder_class=EGNN, + encoder_kwargs=egnn_args, + output_kwargs={"lazy": False, "input_dim": 32}, + ), + ), + ) + return output, task + + +@pytest.fixture +def multi_data_multi_task_combo(egnn_args): + output = { + "IS2REDataset": { + "ScalarRegressionTask": {"energy": torch.rand(1, 1)}, + "ForceRegressionTask": { + "energy": torch.rand(1, 1), + "force": torch.rand(32, 3), + }, + }, + "S2EFDataset": { + "ForceRegressionTask": { + "energy": torch.rand(1, 1), + "force": torch.rand(32, 3), + } + }, + "AlexandriaDataset": { + "ForceRegressionTask": { + "energy": torch.rand(1, 1), + "force": torch.rand(32, 3), + } + }, + } + task = MultiTaskLitModule( + ( + "IS2REDataset", + ScalarRegressionTask( + encoder_class=EGNN, encoder_kwargs=egnn_args, task_keys=["energy"] + ), + ), + ( + "IS2REDataset", + ForceRegressionTask(encoder_class=EGNN, encoder_kwargs=egnn_args), + ), + ( + "S2EFDataset", + ForceRegressionTask(encoder_class=EGNN, encoder_kwargs=egnn_args), + ), + ( + "AlexandriaDataset", + ForceRegressionTask(encoder_class=EGNN, encoder_kwargs=egnn_args), + ), + ) + return output, task + + +def test_average_single_data(single_data_multi_task_combo): + # unpack the fixtrure + output, task = single_data_multi_task_combo + strat = mt.AverageTasks() + # test the parsing + _, parsed_output = strat.parse_outputs(output, task) + agg_results = strat.merge_outputs(parsed_output) + end = strat(output, task) + assert end + assert agg_results + + +def test_average_multi_data(multi_data_multi_task_combo): + # unpack the fixtrure + output, task = multi_data_multi_task_combo + strat = mt.AverageTasks() + # test the parsing + _, parsed_output = strat.parse_outputs(output, task) + agg_results = strat.merge_outputs(parsed_output) + end = strat(output, task) + assert end + assert agg_results From e5cdb8e3b5e81b28c7c73edb764d65e59f0b34cb Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 17 May 2024 10:47:54 -0700 Subject: [PATCH 11/19] test: added tests to check force output shape --- matsciml/interfaces/ase/tests/test_multi_task.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/matsciml/interfaces/ase/tests/test_multi_task.py b/matsciml/interfaces/ase/tests/test_multi_task.py index 004496ea..b41adc4b 100644 --- a/matsciml/interfaces/ase/tests/test_multi_task.py +++ b/matsciml/interfaces/ase/tests/test_multi_task.py @@ -103,6 +103,9 @@ def test_average_single_data(single_data_multi_task_combo): end = strat(output, task) assert end assert agg_results + for key in ["energy", "forces"]: + assert key in end, f"{key} was missing from agg results" + assert end["forces"].shape == (32, 3) def test_average_multi_data(multi_data_multi_task_combo): @@ -115,3 +118,6 @@ def test_average_multi_data(multi_data_multi_task_combo): end = strat(output, task) assert end assert agg_results + for key in ["energy", "forces"]: + assert key in end, f"{key} was missing from agg results" + assert end["forces"].shape == (32, 3) From 951c124c98a2bb43d3c2d8c590e783861114b8de Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 17 May 2024 10:48:20 -0700 Subject: [PATCH 12/19] refactor: added temporary step to ensure force key consistency --- matsciml/interfaces/ase/multitask.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/matsciml/interfaces/ase/multitask.py b/matsciml/interfaces/ase/multitask.py index a91a54cc..aa8f2ebf 100644 --- a/matsciml/interfaces/ase/multitask.py +++ b/matsciml/interfaces/ase/multitask.py @@ -105,7 +105,11 @@ def run(self, output_dict: DataDict, task: MultiTaskLitModule, *args, **kwargs): def __call__( self, output_dict: DataDict, task: MultiTaskLitModule, *args, **kwargs ) -> dict[str, float | np.ndarray]: - return self.run(output_dict, task, *args, **kwargs) + aggregated_results = self.run(output_dict, task, *args, **kwargs) + # TODO: homogenize keys so we don't have to do stuff like this :P + if "force" in aggregated_results: + aggregated_results["forces"] = aggregated_results["force"] + return aggregated_results class AverageTasks(AbstractStrategy): From 3bcc4464d7a2715bd1231d1690cb56756e3ba710 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 17 May 2024 10:49:18 -0700 Subject: [PATCH 13/19] refactor: making multitask output keys refer to task class names --- matsciml/interfaces/ase/multitask.py | 18 +++++++++++++----- matsciml/models/base.py | 2 +- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/matsciml/interfaces/ase/multitask.py b/matsciml/interfaces/ase/multitask.py index aa8f2ebf..6a8bbef1 100644 --- a/matsciml/interfaces/ase/multitask.py +++ b/matsciml/interfaces/ase/multitask.py @@ -13,8 +13,8 @@ __task_property_mapping__ = { "ScalarRegressionTask": ["energy", "dipole"], - "ForceRegressionTask": ["energy", "forces"], - "GradFreeForceRegressionTask": ["forces"], + "ForceRegressionTask": ["energy", "force"], + "GradFreeForceRegressionTask": ["force"], } @@ -79,14 +79,22 @@ def parse_outputs( for dset_name in task.task_map.keys(): for subtask_name, subtask in task.task_map[dset_name].items(): sub_results = {} - pos_fields = __task_property_mapping__.get(subtask, None) + pos_fields = __task_property_mapping__.get(subtask_name, None) if pos_fields is None: continue else: for key in pos_fields: - output = output_dict[dset_name][subtask_name].detach() + output = output_dict[dset_name][subtask_name].get(key, None) + # this means the task _can_ output the key but was + # not included in the actual training task keys + if output is None: + continue + if isinstance(output, torch.Tensor): + output = output.detach() if key == "energy": - output = output.item() + # squeeze is applied just in case we have too many + # extra dimensions + output = output.squeeze().item() sub_results[key] = output # add to per_key_results as another sorting if key not in per_key_results: diff --git a/matsciml/models/base.py b/matsciml/models/base.py index 820be959..967fedb6 100644 --- a/matsciml/models/base.py +++ b/matsciml/models/base.py @@ -2107,7 +2107,7 @@ def __init__( if index != 0: task.encoder = self.encoder # nest the task based on its category - task_map[dset_name][task.__task__] = task + task_map[dset_name][task.__class__.__name__] = task # add dataset names to determine forward logic dset_names.add(dset_name) # save hyperparameters from subtasks From b0d161efab62c27dc776d92f111d0f781a5ebb45 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 17 May 2024 15:33:29 -0700 Subject: [PATCH 14/19] test: updating test to make things work Signed-off-by: Lee, Kin Long Kelvin --- .../interfaces/ase/tests/test_multi_task.py | 60 +++++++++++++++++-- 1 file changed, 56 insertions(+), 4 deletions(-) diff --git a/matsciml/interfaces/ase/tests/test_multi_task.py b/matsciml/interfaces/ase/tests/test_multi_task.py index b41adc4b..a08bc617 100644 --- a/matsciml/interfaces/ase/tests/test_multi_task.py +++ b/matsciml/interfaces/ase/tests/test_multi_task.py @@ -2,6 +2,9 @@ import pytest import torch +import numpy as np +from ase import Atoms, units +from ase.md import VelocityVerlet from matsciml.models.pyg import EGNN from matsciml.models.base import ( @@ -9,7 +12,25 @@ ScalarRegressionTask, ForceRegressionTask, ) +from matsciml.datasets.transforms import ( + PeriodicPropertiesTransform, + PointCloudToGraphTransform, +) from matsciml.interfaces.ase import multitask as mt +from matsciml.interfaces.ase import MatSciMLCalculator + + +@pytest.fixture +def test_pbc() -> Atoms: + pos = np.random.normal(0.0, 1.0, size=(16, 3)) * 10.0 + atomic_numbers = np.random.randint(1, 100, size=(16,)) + cell = np.eye(3).astype(float) + return Atoms(numbers=atomic_numbers, positions=pos, cell=cell) + + +@pytest.fixture +def pbc_transform() -> list: + return [PeriodicPropertiesTransform(6.0, True), PointCloudToGraphTransform("pyg")] @pytest.fixture @@ -74,20 +95,35 @@ def multi_data_multi_task_combo(egnn_args): ( "IS2REDataset", ScalarRegressionTask( - encoder_class=EGNN, encoder_kwargs=egnn_args, task_keys=["energy"] + encoder_class=EGNN, + encoder_kwargs=egnn_args, + task_keys=["energy"], + output_kwargs={"lazy": False, "hidden_dim": 32, "input_dim": 32}, ), ), ( "IS2REDataset", - ForceRegressionTask(encoder_class=EGNN, encoder_kwargs=egnn_args), + ForceRegressionTask( + encoder_class=EGNN, + encoder_kwargs=egnn_args, + output_kwargs={"lazy": False, "hidden_dim": 32, "input_dim": 32}, + ), ), ( "S2EFDataset", - ForceRegressionTask(encoder_class=EGNN, encoder_kwargs=egnn_args), + ForceRegressionTask( + encoder_class=EGNN, + encoder_kwargs=egnn_args, + output_kwargs={"lazy": False, "hidden_dim": 32, "input_dim": 32}, + ), ), ( "AlexandriaDataset", - ForceRegressionTask(encoder_class=EGNN, encoder_kwargs=egnn_args), + ForceRegressionTask( + encoder_class=EGNN, + encoder_kwargs=egnn_args, + output_kwargs={"lazy": False, "hidden_dim": 32, "input_dim": 32}, + ), ), ) return output, task @@ -121,3 +157,19 @@ def test_average_multi_data(multi_data_multi_task_combo): for key in ["energy", "forces"]: assert key in end, f"{key} was missing from agg results" assert end["forces"].shape == (32, 3) + + +def test_calc_multi_data( + multi_data_multi_task_combo, test_pbc: Atoms, pbc_transform: list +): + output, task = multi_data_multi_task_combo + strat = mt.AverageTasks() + calc = MatSciMLCalculator(task, multitask_strategy=strat, transforms=pbc_transform) + atoms = test_pbc.copy() + atoms.calc = calc + energy = atoms.get_potential_energy() + assert np.isfinite(energy) + forces = atoms.get_forces() + assert np.isfinite(forces).all() + dyn = VelocityVerlet(atoms, timestep=5 * units.fs, logfile="md.log") + dyn.run(3) From c7c25ff073859aa789d1fab9be7bc091ff8fa7d2 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 24 May 2024 08:04:32 -0700 Subject: [PATCH 15/19] fix: correcting graph key retrieval from batch This isn't actually originally part of the PR scope, but apparently this typo never triggered as an issue until now! --- matsciml/models/base.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/matsciml/models/base.py b/matsciml/models/base.py index 967fedb6..ca981f9a 100644 --- a/matsciml/models/base.py +++ b/matsciml/models/base.py @@ -2390,17 +2390,17 @@ def _toggle_input_grads( input_keys = list(self.input_grad_keys.values()).pop(0) for key in input_keys: # set require grad for both point cloud and graph tensors - if "graph" in data: - g = data.get("g") + if "graph" in batch: + g = batch.get("graph") if isinstance(g, dgl.DGLGraph): if key in g.ndata: - data["graph"].ndata[key].requires_grad_(True) + batch["graph"].ndata[key].requires_grad_(True) else: # assume it's a PyG graph if key in g: getattr(g, key).requires_grad_(True) - if key in data: - target = data.get(key) + if key in batch: + target = batch.get(key) # for tensors just set them directly if isinstance(target, torch.Tensor): target.requires_grad_(True) From 4e87c5d010e69a6d2f1f48ba54c4b3783e4dc49b Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 24 May 2024 08:15:29 -0700 Subject: [PATCH 16/19] refactor: writing a dedicated method for multitask ase inference This substantially simplifies the workflow, albeit adds yet another method to multitask modules. The new method simply passes input data into the encoder, and maps it to every single subtask regardless, instead of requiring that the batch shares the dataset keys. --- matsciml/interfaces/ase/base.py | 5 ++-- matsciml/models/base.py | 47 +++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/matsciml/interfaces/ase/base.py b/matsciml/interfaces/ase/base.py index 46c5388c..94071047 100644 --- a/matsciml/interfaces/ase/base.py +++ b/matsciml/interfaces/ase/base.py @@ -248,12 +248,13 @@ def calculate( # get into format ready for matsciml model data_dict = self._format_pipeline(atoms) # run the data structure through the model - output = self.task_module(data_dict) - # use a more complicated parser for multitasks if isinstance(self.task_module, MultiTaskLitModule): + output = self.task_module.ase_calculate(data_dict) + # use a more complicated parser for multitasks results = self.multitask_strategy(output, self.task_module) self.results = results else: + output = self.task_module(data_dict) # add outputs to self.results as expected by ase if "energy" in output: self.results["energy"] = output["energy"].detach().item() diff --git a/matsciml/models/base.py b/matsciml/models/base.py index ca981f9a..ae925213 100644 --- a/matsciml/models/base.py +++ b/matsciml/models/base.py @@ -2469,6 +2469,53 @@ def forward( results[task_type] = subtask(batch) return results + def ase_calculate(self, batch: BatchDict) -> dict[str, dict[str, torch.Tensor]]: + """ + Currently "specialized" function that runs a set of data through + every single output head, ignoring the nominal dataset/subtask + unique mapping. + + This is designed for ASE usage primarily, but ostensibly could be + used as _the_ inference call for a multitask module. Basically, + when the input data doesn't come from the same "datasets" used + for initialization/training, and we want to provide a "mixture of + experts" response. + + TODO: this could potentially be used as a template to redesign + the forward call to substantially simplify the multitask mapping. + + Parameters + ---------- + batch + Input data dictionary, which should correspond to a formatted + ase.Atoms sample. + + Returns + ------- + dict[str, dict[str, torch.Tensor]] + Nested results dictionary, following a dataset/subtask structure. + For example, {'IS2REDataset': {'ForceRegressionTask': ..., 'ScalarRegressionTask': ...}} + """ + results = {} + _grads = getattr( + self, + "needs_dynamic_grads", + False, + ) # default to not needing grads + with dynamic_gradients_context(_grads, self.has_rnn): + # this function switches of `requires_grad_` for input tensors that need them + self._toggle_input_grads(batch) + batch["embedding"] = self.encoder(batch) + # now loop through every dataset/output head pair + for dset_name, subtask_name in self.dataset_task_pairs: + subtask = self.task_map[dset_name][subtask_name] + output = subtask(batch) + # now add it to the rest of the results + if dset_name not in results: + results[dset_name] = {} + results[dset_name][subtask_name] = output + return results + def on_train_batch_start(self, batch: Any, batch_idx: int) -> None: """ This callback is used to dynamically initialize output heads. From 2f8894c1c49c2a49f441f6dc328dd0482ce0e2ae Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 24 May 2024 08:45:43 -0700 Subject: [PATCH 17/19] fix: correcting graph get retrieval --- matsciml/models/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/matsciml/models/base.py b/matsciml/models/base.py index ae925213..ff69986c 100644 --- a/matsciml/models/base.py +++ b/matsciml/models/base.py @@ -2365,7 +2365,7 @@ def _toggle_input_grads( for key in input_keys: # set require grad for both point cloud and graph tensors if "graph" in data: - g = data.get("g") + g = data.get("graph") if isinstance(g, dgl.DGLGraph): if key in g.ndata: data["graph"].ndata[key].requires_grad_(True) From 6e54dadfb60249bb8f7023657804651ba7aeae20 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 24 May 2024 08:56:26 -0700 Subject: [PATCH 18/19] fix: patches input grad toggling based on incoming batch This adjusts the logic, albeit maybe inconsistent with the rest of multitask, where we check the incoming batch for dataset names at the top level to determine if it's a multidata batch, instead of relying on the model expectations. This fixes the ase calculate behavior, which would have been mismatched since the module is inherently multidata but the incoming batch is not. --- matsciml/models/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/matsciml/models/base.py b/matsciml/models/base.py index ff69986c..54e86edd 100644 --- a/matsciml/models/base.py +++ b/matsciml/models/base.py @@ -2357,7 +2357,9 @@ def _toggle_input_grads( """ need_grad_keys = getattr(self, "input_grad_keys", None) if need_grad_keys is not None: - if self.is_multidata: + # we determine if it's multidata based on the incoming batch + # as it should have dataset in its key + if any(["Dataset" in key for key in batch.keys()]): # if this is a multidataset task, loop over each dataset # and enable gradients for the inputs that need them for dset_name, data in batch.items(): From 22c5f36a0ba3efa475cb2f233ec6b1405ad31390 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Tue, 28 May 2024 07:54:24 -0700 Subject: [PATCH 19/19] script: added pretrained example from multitask Signed-off-by: Lee, Kin Long Kelvin --- .../ase_multitask_from_pretrained.py | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 examples/interfaces/ase_multitask_from_pretrained.py diff --git a/examples/interfaces/ase_multitask_from_pretrained.py b/examples/interfaces/ase_multitask_from_pretrained.py new file mode 100644 index 00000000..8ca5889e --- /dev/null +++ b/examples/interfaces/ase_multitask_from_pretrained.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from ase import Atoms, units +from ase.md.verlet import VelocityVerlet + +from matsciml.interfaces.ase import MatSciMLCalculator +from matsciml.interfaces.ase.multitask import AverageTasks +from matsciml.datasets.transforms import ( + PeriodicPropertiesTransform, + PointCloudToGraphTransform, +) + +""" +Demonstrates setting up a calculator from a pretrained +multitask/multidata module, using an averaging strategy to +merge output heads. + +As an example, if we trained force regression on multiple datasets +simultaneously, we would average the outputs from each "dataset", +similar to an ensemble prediction without any special weighting. + +Substitute 'model.ckpt' for the path to a checkpoint file. +""" + +d = 2.9 +L = 10.0 + +atoms = Atoms("C", positions=[[0, L / 2, L / 2]], cell=[d, L, L], pbc=[1, 0, 0]) + +calc = MatSciMLCalculator.from_pretrained_force_regression( + "model.ckpt", + transforms=[ + PeriodicPropertiesTransform(6.0, True), + PointCloudToGraphTransform("pyg"), + ], + multitask_strategy=AverageTasks(), # also can be specified as a string +) +# set the calculator to matsciml +atoms.calc = calc +# run the simulation for 100 timesteps, with 5 femtosecond timesteps +dyn = VelocityVerlet(atoms, timestep=5 * units.fs, logfile="md.log") +dyn.run(100)