Node embedding projections for ForceRegressionTask #204

Merged · 12 commits · May 3, 2024
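
In outline, the change swaps the pooled system-embedding energy projection for per-node energy contributions: the "energy" output head is applied to every node embedding, node energies are reduced per graph (sum by default), and forces come from differentiating that energy with respect to positions. A minimal, self-contained sketch of the idea; the toy encoder, head, and shapes are illustrative assumptions, not the repository's API:

```python
import torch
import torch.nn as nn

# Toy encoder so the node embeddings actually depend on positions.
encoder = nn.Linear(3, 64)
energy_head = nn.Linear(64, 1)   # projects each node embedding to a scalar

pos = torch.randn(10, 3, requires_grad=True)   # [num_nodes, 3]
node_embeddings = encoder(pos)                 # [num_nodes, 64]
node_energies = energy_head(node_embeddings)   # [num_nodes, 1]
energy = node_energies.sum()                   # reduce nodes -> system energy

# Force is the negative gradient of the energy w.r.t. positions.
(force,) = torch.autograd.grad(energy, pos, create_graph=True)
force = -force
```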
61 changes: 54 additions & 7 deletions matsciml/models/base.py
@@ -6,7 +6,7 @@
from collections.abc import Iterable
from contextlib import ExitStack, nullcontext
from pathlib import Path
from typing import Any, ContextManager, Dict, List, Optional, Type, Union
from typing import Any, Callable, ContextManager, Dict, List, Optional, Type, Union
from warnings import warn

import pytorch_lightning as pl
@@ -1501,6 +1501,7 @@ def __init__(
loss_func: type[nn.Module] | nn.Module = nn.L1Loss,
task_keys: list[str] | None = None,
output_kwargs: dict[str, Any] = {},
embedding_reduction_type: str = "sum",
**kwargs,
) -> None:
super().__init__(
@@ -1510,6 +1511,7 @@
loss_func,
task_keys,
output_kwargs,
embedding_reduction_type=embedding_reduction_type,
**kwargs,
)
self.save_hyperparameters(ignore=["encoder", "loss_func"])
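
For reference, the new `embedding_reduction_type` argument can be overridden at construction time. A hypothetical configuration sketch; the encoder kwargs are placeholders, not a tested setup:

```python
from matsciml.models import FAENet
from matsciml.models.base import ForceRegressionTask

# Hypothetical override: average node energies per graph rather than
# summing them ("sum" is the default added by this diff).
task = ForceRegressionTask(
    encoder_class=FAENet,
    encoder_kwargs={"hidden_channels": 128, "out_dim": 128, "pred_as_dict": False},
    embedding_reduction_type="mean",
)
```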
@@ -1543,6 +1545,7 @@ def forward(
fa_rot = getattr(graph, "fa_rot", None)
fa_pos = getattr(graph, "fa_pos", None)
else:
graph = None
# assume point cloud otherwise
pos: torch.Tensor = batch.get("pos")
# no frame averaging architecture yet for point clouds
@@ -1569,7 +1572,9 @@ def forward(
else:
embeddings = self.encoder(batch)
natoms = batch.get("natoms", None)
outputs = self.process_embedding(embeddings, pos, fa_rot, fa_pos, natoms)
outputs = self.process_embedding(
embeddings, pos, fa_rot, fa_pos, natoms, graph
)
return outputs

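
For orientation, frame averaging attaches per-frame attributes to the graph, which `forward` unpacks via `getattr` above; a schematic sketch of that layout (the frame count and rotation shape are assumptions):

```python
import torch

# Schematic frame-averaged attributes consumed by forward(); 3D frame
# averaging can yield up to eight frames, two are shown for brevity.
num_nodes, num_frames = 10, 2
fa_pos = [torch.randn(num_nodes, 3, requires_grad=True) for _ in range(num_frames)]
fa_rot = [torch.eye(3).unsqueeze(0) for _ in range(num_frames)]  # per-frame rotations
```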
def process_embedding(
@@ -1579,13 +1584,47 @@ def process_embedding(
fa_rot: None | torch.Tensor = None,
fa_pos: None | torch.Tensor = None,
natoms: None | torch.Tensor = None,
graph: None | AbstractGraph = None,
) -> dict[str, torch.Tensor]:
outputs = {}
# compute node-level contributions to the energy
node_energies = self.output_heads["energy"](embeddings.point_embedding)
# figure out how we're going to reduce node-level energies
# depending on the representation and/or the graph framework
if graph is not None:
if isinstance(graph, dgl.DGLGraph):
graph.ndata["node_energies"] = node_energies

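# note: this branch reads node energies from graph.ndata rather than its argument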
def readout(node_energies: torch.Tensor):
return dgl.readout_nodes(
graph, "node_energies", op=self.embedding_reduction_type
)
else:
# assumes a batched pyg graph
batch = graph.batch
from torch_geometric.utils import scatter

def readout(node_energies: torch.Tensor):
return scatter(
node_energies,
batch,
dim=-2,
reduce=self.embedding_reduction_type,
)
else:

def readout(node_energies: torch.Tensor):
return reduce(
node_energies, "b ... d -> b ()", self.embedding_reduction_type
)

def energy_and_force(
pos: torch.Tensor, system_embedding: torch.Tensor
pos: torch.Tensor, node_energies: torch.Tensor, readout: Callable
) -> tuple[torch.Tensor, torch.Tensor]:
energy = self.output_heads["energy"](system_embedding)
# reduce over nodes to a per-system energy, keeping dimension as 1
energy = readout(node_energies)
if energy.ndim == 1:
energy = energy.unsqueeze(-1)
# now use autograd for force calculation
force = (
-1
@@ -1598,14 +1637,17 @@ def energy_and_force(
)
return energy, force

# not using frame averaging
if fa_pos is None:
energy, force = energy_and_force(pos, embeddings.system_embedding)
energy, force = energy_and_force(pos, node_energies, readout)
else:
energy = []
force = []
for idx, pos in enumerate(fa_pos):
frame_embedding = embeddings.system_embedding[:, idx, :]
frame_energy, frame_force = energy_and_force(pos, frame_embedding)
frame_node_energies = node_energies[:, idx, :]
frame_energy, frame_force = energy_and_force(
pos, frame_node_energies, readout
)
force.append(frame_force)
energy.append(frame_energy)

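
The loop above yields one energy and one force per frame, and the frame axis is subsequently reduced away by the `reduce` calls below. A toy sketch of that pattern; the shapes and the `mean` reduction are assumptions:

```python
import torch
from einops import reduce

# Assumed toy shapes: 4 graphs, 3 frames, one scalar energy per frame.
frame_energies = [torch.randn(4, 1) for _ in range(3)]

# Stack per-frame predictions along a frame axis, then reduce it away,
# mirroring the reduce(energy, "b ... d -> b d", ...) call below.
energy = torch.stack(frame_energies, dim=1)      # [4, 3, 1]
energy = reduce(energy, "b f d -> b d", "mean")  # frame-averaged, [4, 1]
```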
@@ -1634,12 +1676,17 @@ def energy_and_force(
self.embedding_reduction_type,
d=3,
)
# this may be a no-op when not frame averaging, since the
# reduction was already applied in the energy_and_force call
outputs["energy"] = reduce(
energy,
"b ... d -> b d",
self.embedding_reduction_type,
d=1,
)
# also expose a scalar value for every node, representing
# its contribution to the total energy
outputs["node_energies"] = node_energies
return outputs

def _get_targets(
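
The three readout branches in `process_embedding` should agree on the same batch of node energies. A small equivalence sketch, assuming both dgl and torch_geometric are installed; the graph sizes are illustrative:

```python
import dgl
import torch
from torch_geometric.utils import scatter

# Toy batch of two graphs with 3 and 2 nodes; energies are placeholders.
node_energies = torch.randn(5, 1)

# DGL branch: store node energies on the batched graph, reduce per graph.
g = dgl.batch([dgl.rand_graph(3, 4), dgl.rand_graph(2, 2)])
g.ndata["node_energies"] = node_energies
dgl_energy = dgl.readout_nodes(g, "node_energies", op="sum")   # [2, 1]

# PyG branch: scatter-reduce with the batch index vector.
batch_idx = torch.tensor([0, 0, 0, 1, 1])
pyg_energy = scatter(node_energies, batch_idx, dim=-2, reduce="sum")

assert torch.allclose(dgl_energy, pyg_energy)
```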
37 changes: 36 additions & 1 deletion matsciml/models/tests/test_tasks.py
@@ -8,9 +8,10 @@
PointCloudToGraphTransform,
PeriodicPropertiesTransform,
NoisyPositions,
FrameAveraging,
)
from matsciml.lightning.data_utils import MatSciMLDataModule
from matsciml.models import PLEGNNBackbone
from matsciml.models import PLEGNNBackbone, FAENet
from matsciml.models.base import (
ForceRegressionTask,
GradFreeForceRegressionTask,
@@ -56,6 +57,18 @@ def egnn_config():
return {"encoder_class": PLEGNNBackbone, "encoder_kwargs": model_args}


@pytest.fixture
def faenet_config():
model_args = {
"average_frame_embeddings": False,
"pred_as_dict": False,
"hidden_channels": 128,
"out_dim": 128,
"tag_hidden_channels": 0,
}
return {"encoder_class": FAENet, "encoder_kwargs": model_args}


def test_force_regression(egnn_config):
devset = MatSciMLDataModule.from_devset(
"S2EFDataset",
@@ -77,6 +90,28 @@ def test_force_regression(egnn_config):
assert f"train_{key}" in trainer.logged_metrics


def test_fa_force_regression(faenet_config):
devset = MatSciMLDataModule.from_devset(
"S2EFDataset",
dset_kwargs={
"transforms": [
PeriodicPropertiesTransform(6.0, True),
PointCloudToGraphTransform(
"pyg",
node_keys=["pos", "force", "atomic_numbers"],
),
FrameAveraging(frame_averaging="3D", fa_method="stochastic"),
],
},
)
task = ForceRegressionTask(**faenet_config)
trainer = pl.Trainer(max_steps=5, logger=False, enable_checkpointing=False)
trainer.fit(task, datamodule=devset)
# make sure losses are tracked
for key in ["energy", "force"]:
assert f"train_{key}" in trainer.logged_metrics


def test_gradfree_force_regression(egnn_config):
devset = MatSciMLDataModule.from_devset(
"S2EFDataset",
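
Finally, a smoke-test sketch of the new per-node output, mirroring the fixture and transforms above; the import paths, dataloader access pattern, and `mean` override are illustrative, not part of the test suite:

```python
import pytorch_lightning as pl
from matsciml.datasets.transforms import (
    FrameAveraging,
    PeriodicPropertiesTransform,
    PointCloudToGraphTransform,
)
from matsciml.lightning.data_utils import MatSciMLDataModule
from matsciml.models import FAENet
from matsciml.models.base import ForceRegressionTask

devset = MatSciMLDataModule.from_devset(
    "S2EFDataset",
    dset_kwargs={
        "transforms": [
            PeriodicPropertiesTransform(6.0, True),
            PointCloudToGraphTransform("pyg", node_keys=["pos", "force", "atomic_numbers"]),
            FrameAveraging(frame_averaging="3D", fa_method="stochastic"),
        ],
    },
)
task = ForceRegressionTask(
    encoder_class=FAENet,
    encoder_kwargs={
        "average_frame_embeddings": False,
        "pred_as_dict": False,
        "hidden_channels": 128,
        "out_dim": 128,
        "tag_hidden_channels": 0,
    },
    embedding_reduction_type="mean",
)

# Pull one batch and check the outputs added by this PR.
devset.setup("fit")
batch = next(iter(devset.train_dataloader()))
outputs = task(batch)
assert "node_energies" in outputs   # per-node scalar contributions
assert outputs["energy"].shape[-1] == 1
```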