Merge pull request #232 from laserkelvin/predict-task-method
General predict method for tasks
laserkelvin committed Jun 3, 2024
2 parents 75ec524 + ca3b8a8 commit af6c5c1
Showing 2 changed files with 74 additions and 14 deletions.
3 changes: 1 addition & 2 deletions matsciml/interfaces/ase/base.py
@@ -248,13 +248,12 @@ def calculate(
# get into format ready for matsciml model
data_dict = self._format_pipeline(atoms)
# run the data structure through the model
output = self.task_module.predict(data_dict)
if isinstance(self.task_module, MultiTaskLitModule):
output = self.task_module.ase_calculate(data_dict)
# use a more complicated parser for multitasks
results = self.multitask_strategy(output, self.task_module)
self.results = results
else:
output = self.task_module(data_dict)
# add outputs to self.results as expected by ase
if "energy" in output:
self.results["energy"] = output["energy"].detach().item()
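For readability, the updated branch of `calculate` after this change reads roughly as follows; this is a reconstruction from the hunk above, with indentation inferred rather than copied verbatim.

```python
# Reconstructed sketch of calculate() after this change (indentation inferred).
# get into format ready for matsciml model
data_dict = self._format_pipeline(atoms)
# run the data structure through the model; single-task and multitask modules
# now share the same predict() entry point
output = self.task_module.predict(data_dict)
if isinstance(self.task_module, MultiTaskLitModule):
    # use a more complicated parser for multitasks
    results = self.multitask_strategy(output, self.task_module)
    self.results = results
else:
    # add outputs to self.results as expected by ase
    if "energy" in output:
        self.results["energy"] = output["energy"].detach().item()
```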
85 changes: 73 additions & 12 deletions matsciml/models/base.py
@@ -1015,6 +1015,38 @@ def _make_normalizers(self) -> dict[str, Normalizer]:
normalizers[key] = Normalizer(mean=mean, std=std, device=self.device)
return normalizers

def predict(self, batch: BatchDict) -> dict[str, torch.Tensor]:
"""
Implements what is effectively the 'inference' logic of the task:
we run the forward pass on a batch of samples and, if normalizers
were used during training, apply the inverse transform to bring the
values back to their original scale.
Not to be confused with `predict_step`, which is used by Lightning as
part of the prediction workflow. Since there is no one-size-fits-all
inference workflow we can define, this provides a convenient function
for users to call as a replacement.
Parameters
----------
batch : BatchDict
Batch of samples to pass to the model.
Returns
-------
dict[str, torch.Tensor]
Output dictionary as provided by the forward pass; if a normalizer
is available for a given task key, the inverse transform is applied
to its value.
"""
outputs = self(batch)
if self.uses_normalizers:
for key in self.task_keys:
if key in self.normalizers:
# apply the inverse transform if provided
outputs[key] = self.normalizers[key].denorm(outputs[key])
return outputs

@classmethod
def from_pretrained_encoder(cls, task_ckpt_path: str | Path, **kwargs):
"""
@@ -1706,6 +1738,36 @@ def energy_and_force(
outputs["node_energies"] = node_energies
return outputs

def predict(self, batch: BatchDict) -> dict[str, torch.Tensor]:
"""
Similar to the base method, but with two minor modifications to the
denormalization logic: the same energy normalization rescaling may
also be applied to the forces and node-level energies.
Parameters
----------
batch : BatchDict
Batch of samples to evaluate on.
Returns
-------
dict[str, torch.Tensor]
Output dictionary as provided by the forward call. For this task in
particular, the energy rescaling may also be applied to forces and
node energies if dedicated normalizers for those keys were not provided.
"""
output = super().predict(batch)
# for forces, in the event that a dedicated normalizer wasn't provided
# but we have an energy normalizer, we apply the same factors to the force
if self.uses_normalizers:
if "force" not in self.normalizers and "energy" in self.normalizers:
output["force"] = self.normalizers["energy"].denorm(output["force"])
if "node_energies" not in self.normalizers and "energy" in self.normalizers:
output["node_energies"] = self.normalizers["energy"].denorm(
output["node_energies"]
)
return output

def _get_targets(
self,
batch: dict[str, torch.Tensor | dgl.DGLGraph | dict[str, torch.Tensor]],
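The fallback rule above can be summarized with a small sketch; this helper is purely illustrative and not part of the library.

```python
# Illustrative helper mirroring the fallback in ForceRegressionTask.predict:
# a dedicated normalizer wins, otherwise "force" and "node_energies" reuse the
# energy normalizer when one is present.
def pick_normalizer(key: str, normalizers: dict):
    if key in normalizers:
        return normalizers[key]
    if key in ("force", "node_energies") and "energy" in normalizers:
        return normalizers["energy"]
    return None  # no rescaling applied for this key
```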
@@ -2471,20 +2533,18 @@ def forward(
results[task_type] = subtask(batch)
return results

def ase_calculate(self, batch: BatchDict) -> dict[str, dict[str, torch.Tensor]]:
def predict(self, batch: BatchDict) -> dict[str, dict[str, torch.Tensor]]:
"""
Currently "specialized" function that runs a set of data through
every single output head, ignoring the nominal dataset/subtask
unique mapping.
Similar logic to the `BaseTaskModule.predict` method, but implemented
for the multitask setting.
This is designed for ASE usage primarily, but ostensibly could be
used as _the_ inference call for a multitask module. Basically,
when the input data doesn't come from the same "datasets" used
for initialization/training, and we want to provide a "mixture of
experts" response.
The workflow combines the two: we run the joint embedder once, and
then rely on each subtask's `predict` method to return outputs at
their expected scales.
TODO: this could potentially be used as a template to redesign
the forward call to substantially simplify the multitask mapping.
This method also behaves a little differently from the other multitask
operations, as it runs a set of data through every single output head,
ignoring the nominal dataset/subtask unique mapping.
Parameters
----------
@@ -2511,7 +2571,8 @@ def ase_calculate(self, batch: BatchDict) -> dict[str, dict[str, torch.Tensor]]:
# now loop through every dataset/output head pair
for dset_name, subtask_name in self.dataset_task_pairs:
subtask = self.task_map[dset_name][subtask_name]
output = subtask(batch)
# use the predict method to get rescaled outputs
output = subtask.predict(batch)
# now add it to the rest of the results
if dset_name not in results:
results[dset_name] = {}
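A short sketch of how the nested multitask output might be consumed; the module and batch names are placeholders, and the dataset/task keys depend on how the multitask module was configured.

```python
# Hypothetical consumption of MultiTaskLitModule.predict(): results are keyed
# by dataset name, then by subtask name, with each entry being that subtask's
# own (already denormalized) predict() output.
results = multitask_module.predict(batch)
for dset_name, task_outputs in results.items():
    for subtask_name, output in task_outputs.items():
        print(dset_name, subtask_name, list(output.keys()))
```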
