Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Frame Averaging Energy Reduction Bugfix #205

Merged
8 changes: 6 additions & 2 deletions matsciml/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1649,7 +1649,7 @@ def energy_and_force(
pos, frame_embedding, readout
)
force.append(frame_force)
energy.append(frame_energy)
energy.append(frame_energy.unsqueeze(-1))

# check to see if we are frame averaging
if fa_rot is not None:
Expand All @@ -1667,8 +1667,12 @@ def energy_and_force(
force[frame_idx].view(-1, 1, 3).bmm(repeat_rot.transpose(1, 2))
)
all_forces.append(rotated_forces)
# combine all the force data into a single tensor
# combine all the force and energy data into a single tensor
laserkelvin marked this conversation as resolved.
Show resolved Hide resolved
# using frame averaging, the expected shapes after concatenation are:
# force - [num positions, num frames, 3]
# energy - [batch size, num frames, 1]
force = torch.cat(all_forces, dim=1)
energy = torch.cat(energy, dim=1)
# reduce outputs to what are expected shapes
outputs["force"] = reduce(
force,
Expand Down
8 changes: 6 additions & 2 deletions matsciml/models/pyg/faenet/faenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,15 +251,19 @@ def read_batch(self, batch: BatchDict) -> DataDict:
DataDict
Input data for FAENet as a dictionary.
"""

data = {"graph": batch.get("graph")}
graph = batch.get("graph")
for key in ["edge_feats", "graph_feats"]:
data[key] = getattr(graph, key, None)
pos: torch.Tensor = getattr(graph, "pos")
data["pos"] = pos
data["graph"].cell = batch["cell"]
data["graph"].natoms = batch["natoms"].squeeze(-1).to(torch.int32)
if "natoms" not in batch:
laserkelvin marked this conversation as resolved.
Show resolved Hide resolved
_, natoms = torch.unique(graph.batch, return_counts=True)
data["graph"].natoms = natoms
else:
data["graph"].natoms = batch["natoms"].squeeze(-1).to(torch.int32)

edge_index, cell_offsets, neighbors = radius_graph_pbc(
data["graph"],
self.cutoff,
Expand Down
132 changes: 132 additions & 0 deletions matsciml/models/pyg/tests/test_faenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from __future__ import annotations

import pytest
import torch
import pytorch_lightning as pl

# this import is not used, but ensures that the registry is updated
from matsciml.common.registry import registry
from matsciml.datasets.transforms import (
PeriodicPropertiesTransform,
PointCloudToGraphTransform,
FrameAveraging,
)
from matsciml.lightning import MatSciMLDataModule
from matsciml.models.pyg import FAENet
from matsciml.models.base import ForceRegressionTask


@pytest.fixture
def faenet_architecture() -> FAENet:
"""
Fixture for a nominal FAENet architecture.

Some lightweight (but realistic) hyperparameters are
used to test data flowing through the model.

Returns
-------
FAENet
Concrete FAENet object
"""
faenet_kwargs = {
"average_frame_embeddings": False, # set to false for use with FA transform
"pred_as_dict": False,
"hidden_dim": 128,
"out_dim": 128,
"tag_hidden_channels": 0,
}
model = FAENet(**faenet_kwargs)
return model


# here we filter out datasets from the registry that don't make sense
ignore_dset = ["Multi", "M3G", "PyG", "Cdvae", "SyntheticPointGroupDataset"]
filtered_list = list(
filter(
lambda x: all([target_str not in x for target_str in ignore_dset]),
registry.__entries__["datasets"].keys(),
),
)


@pytest.mark.parametrize(
"dset_class_name",
filtered_list,
)
def test_model_forward_nograd(dset_class_name: str, faenet_architecture: FAENet):
# these are necessary for the model to work as intended
"""
This test checks model ``forward`` compatibility with datasets.

The test is parameterized to run on all datasets in the registry
that have *not* been filtered out; this list should be sparse,
as the idea is to maximize coverage and we can just ignore failing
combinations if they do not make sense and we can at least be
aware of them.

Parameters
----------
dset_class_name : str
Name of the dataset class to retrieve
faenet_architecture : FAENet
Concrete FAENet object with some parameters
"""
transforms = [
PeriodicPropertiesTransform(cutoff_radius=6.0),
PointCloudToGraphTransform(
"pyg",
node_keys=["pos", "atomic_numbers"],
),
FrameAveraging(frame_averaging="3D", fa_method="stochastic"),
]
dm = MatSciMLDataModule.from_devset(
dset_class_name,
batch_size=8,
dset_kwargs={"transforms": transforms},
)
# dummy initialization
dm.setup("fit")
loader = dm.train_dataloader()
batch = next(iter(loader))
# run the model without gradient tracking
with torch.no_grad():
embeddings = faenet_architecture(batch)
# returns embeddings, and runs numerical checks
for z in [embeddings.system_embedding, embeddings.point_embedding]:
assert torch.isreal(z).all()
assert ~torch.isnan(z).all() # check there are no NaNs
assert torch.isfinite(z).all()
assert torch.all(torch.abs(z) <= 1000) # ensure reasonable values


def test_force_regression(faenet_architecture):
devset = MatSciMLDataModule.from_devset(
"S2EFDataset",
dset_kwargs={
"transforms": [
PeriodicPropertiesTransform(cutoff_radius=6.0, adaptive_cutoff=True),
PointCloudToGraphTransform(
"pyg",
node_keys=["pos", "atomic_numbers"],
),
FrameAveraging(frame_averaging="3D", fa_method="stochastic"),
],
},
)
task = ForceRegressionTask(
faenet_architecture,
)
trainer = pl.Trainer(
max_steps=5, logger=False, enable_checkpointing=False, accelerator="cpu"
)
trainer.fit(task, datamodule=devset)
# make sure losses are tracked
for key in ["energy", "force"]:
assert f"train_{key}" in trainer.logged_metrics

loader = devset.train_dataloader()
batch = next(iter(loader))
outputs = task(batch)
assert outputs["energy"].size(0) == batch["natoms"].size(0)
assert outputs["force"].size(0) == sum(batch["natoms"]).item()
Loading