Merged
7 changes: 4 additions & 3 deletions .github/workflows/docs.yml
@@ -4,6 +4,7 @@ on:
workflow_dispatch:
push:
branches: [main]
pull_request:

# set GITHUB_TOKEN permissions to allow deployment to GitHub Pages
permissions:
@@ -13,7 +14,6 @@ permissions:

jobs:
build-docs:
if: github.repository_owner == 'Radical-AI' && github.ref == 'refs/heads/main'
runs-on: ubuntu-latest

steps:
@@ -35,8 +35,8 @@ jobs:
run: |
uv pip install .[test,docs] --system

- name: Install mace
run: uv pip install mace-torch --system
- name: Install extras for tutorial generation
run: uv pip install ".[graphpes,mace]" --system

- name: Copy tutorials
run: |
@@ -57,6 +57,7 @@ jobs:
path: ./docs_build

deploy:
if: github.repository_owner == 'Radical-AI' && github.ref == 'refs/heads/main'
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
1 change: 1 addition & 0 deletions .github/workflows/test.yml
@@ -66,6 +66,7 @@ jobs:
- { name: mattersim, test_path: "tests/models/test_mattersim.py" }
- { name: orb, test_path: "tests/models/test_orb.py" }
- { name: sevenn, test_path: "tests/models/test_sevennet.py" }
- { name: graphpes, test_path: "tests/models/test_graphpes.py" }
runs-on: ${{ matrix.os }}

steps:
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -62,7 +62,7 @@
"html_image",
]

autodoc_mock_imports = ["fairchem", "mace", "mattersim", "orb", "sevennet"]
autodoc_mock_imports = ["fairchem", "mace", "mattersim", "orb", "sevennet", "graphpes"]

# use type hints
autodoc_typehints = "description"
1 change: 1 addition & 0 deletions docs/tutorials/index.rst
@@ -18,3 +18,4 @@ versions of the tutorials can also be found in the `torch-sim /examples/tutorial
autobatching_tutorial
low_level_tutorial
hybrid_swap_tutorial
using_graphpes_tutorial
92 changes: 92 additions & 0 deletions examples/tutorials/using_graphpes_tutorial.py
@@ -0,0 +1,92 @@
# %% [markdown]
# <details>
# <summary>Dependencies</summary>
# /// script
# dependencies = [
# "graph-pes>=0.0.30",
# "torch==2.5",
# ]
# ///
# </details>


# %% [markdown]
"""
# Integrating TorchSim with `graph-pes`

This brief tutorial demonstrates how to use models trained with the
[graph-pes](https://github.com/mir-group/graph-pes) package to drive
MD simulations and geometry optimizations in TorchSim.

## Step 1: loading a model

The `graph-pes-train` command outputs a `.pt` file containing your trained
model. To use this model with TorchSim, pass either the path to this `.pt`
file or the model object itself to the `GraphPESWrapper` constructor.

Below, we create a dummy TensorNet model with random weights as a demonstration:
"""

# %%
from graph_pes.models import TensorNet, load_model

# if you had a model saved to disk, you could load it like this:
# model = load_model("path/to/model.pt")

# here, we just create a TensorNet model with random weights
model = TensorNet(cutoff=5.0)

print("Number of parameters:", sum(p.numel() for p in model.parameters()))

# %% [markdown]
"""
## Step 2: wrapping the model for use with TorchSim

The `GraphPESWrapper` class wraps any `graph-pes` model for use with TorchSim.
If you intend to drive simulations that require stresses, you will need to set the
`compute_stress` argument to `True`.
"""

# %%
from torch_sim.models import GraphPESWrapper

# wrap the model for use with TorchSim
ts_model = GraphPESWrapper(model, compute_stress=False)

# or, alternatively, pass a model path directly:
# ts_model = GraphPESWrapper("path/to/model.pt", compute_stress=False)

# %% [markdown]
"""
## Step 3: driving MD with the model

Now that we have a wrapped model, we can use it to drive MD simulations with the
`ts.integrate` function.
"""
# %%
from ase.build import molecule
import torch_sim as ts
from load_atoms import view

# NVT at 300K
atoms = molecule("H2O")

final_state = ts.integrate(
system=atoms,
model=ts_model,
integrator=ts.nvt_langevin,
n_steps=50,
temperature=300,
timestep=0.001,
)

final_atoms = final_state.to_atoms()[0]
view(final_atoms, show_bonds=True)

# %% [markdown]
"""
Of course, this is a very simple example. However, you are now equipped to
use any `graph-pes` model that you have trained to drive any of the functionality
exposed by TorchSim!
"""
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -46,11 +46,11 @@ test = [
]
mace = ["mace-torch>=0.3.11"]
mattersim = ["mattersim>=0.1.2"]

orb = [
"orb-models@git+https://github.com/orbital-materials/orb-models#egg=637a98d49cfb494e2491a457d9bbd28311fecf21",
]
sevenn = ["sevenn>=0.11.0"]
graphpes = ["graph-pes>=0.0.34", "mace-torch>=0.3.11"]
docs = [
"autodoc_pydantic==2.2.0",
"furo==2024.8.6",
214 changes: 214 additions & 0 deletions tests/models/test_graphpes.py
@@ -0,0 +1,214 @@
import pytest
import torch
from ase.build import bulk, molecule

from tests.models.conftest import (
consistency_test_simstate_fixtures,
make_model_calculator_consistency_test,
make_validate_model_outputs_test,
)
from torch_sim.io import atoms_to_state
from torch_sim.models.graphpes import GraphPESWrapper


try:
from graph_pes.atomic_graph import AtomicGraph, to_batch
from graph_pes.interfaces import mace_mp
from graph_pes.models import LennardJones, SchNet, TensorNet, ZEmbeddingNequIP
except ImportError:
pytest.skip("graph-pes not installed", allow_module_level=True)


@pytest.fixture
def dtype() -> torch.dtype:
"""Fixture to provide the default dtype for testing."""
return torch.float32


def test_graphpes_isolated(device: torch.device):
# test that the raw model and torch_sim wrapper give the same results
# for an isolated, unbatched structure

water_atoms = molecule("H2O")
water_atoms.center(vacuum=10.0)

gp_model = SchNet(cutoff=5.5)
gp_graph = AtomicGraph.from_ase(water_atoms, cutoff=5.5)
gp_energy = gp_model.predict_energy(gp_graph)

ts_model = GraphPESWrapper(
gp_model,
device=device,
dtype=torch.float32,
compute_forces=True,
compute_stress=False,
)
ts_output = ts_model(atoms_to_state([water_atoms], device, torch.float32))
assert set(ts_output.keys()) == {"energy", "forces"}
assert ts_output["energy"].shape == (1,)

assert gp_energy.item() == pytest.approx(ts_output["energy"].item(), abs=1e-5)


def test_graphpes_periodic(device: torch.device):
# test that the raw model and torch_sim wrapper give the same results
# for a periodic, unbatched structure

bulk_atoms = bulk("Al", "hcp", a=4.05)
assert bulk_atoms.pbc.all()

gp_model = TensorNet(cutoff=5.5)
gp_graph = AtomicGraph.from_ase(bulk_atoms, cutoff=5.5)
gp_forces = gp_model.predict_forces(gp_graph)

ts_model = GraphPESWrapper(
gp_model,
device=device,
dtype=torch.float32,
compute_forces=True,
compute_stress=True,
)
ts_output = ts_model(atoms_to_state([bulk_atoms], device, torch.float32))
assert set(ts_output.keys()) == {"energy", "forces", "stress"}
assert ts_output["energy"].shape == (1,)
assert ts_output["forces"].shape == (len(bulk_atoms), 3)
assert ts_output["stress"].shape == (1, 3, 3)

torch.testing.assert_close(ts_output["forces"].to("cpu"), gp_forces)


def test_batching(device: torch.device):
# test that the raw model and torch_sim wrapper give the same results
# when batching is done via torch_sim's atoms_to_state function

water = molecule("H2O")
methane = molecule("CH4")
systems = [water, methane]
for s in systems:
s.center(vacuum=10.0)

gp_model = SchNet(cutoff=5.5)
gp_graphs = [AtomicGraph.from_ase(s, cutoff=5.5) for s in systems]

gp_energies = gp_model.predict_energy(to_batch(gp_graphs))

ts_model = GraphPESWrapper(
gp_model,
device=device,
dtype=torch.float32,
compute_forces=True,
compute_stress=True,
)
ts_output = ts_model(atoms_to_state(systems, device, torch.float32))

assert set(ts_output.keys()) == {"energy", "forces", "stress"}
assert ts_output["energy"].shape == (2,)
assert ts_output["forces"].shape == (sum(len(s) for s in systems), 3)
assert ts_output["stress"].shape == (2, 3, 3)

assert gp_energies[0].item() == pytest.approx(ts_output["energy"][0].item(), abs=1e-5)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_graphpes_dtype(device: torch.device, dtype: torch.dtype):
water = molecule("H2O")

model = SchNet()

ts_wrapper = GraphPESWrapper(model, device=device, dtype=dtype, compute_stress=False)
ts_output = ts_wrapper(atoms_to_state([water], device, dtype))
assert ts_output["energy"].dtype == dtype
assert ts_output["forces"].dtype == dtype


_nequip_model = ZEmbeddingNequIP()


@pytest.fixture
def ts_nequip_model(device: torch.device, dtype: torch.dtype):
return GraphPESWrapper(
_nequip_model,
device=device,
dtype=dtype,
compute_stress=False,
)


@pytest.fixture
def ase_nequip_calculator(device: torch.device, dtype: torch.dtype):
return _nequip_model.to(device, dtype).ase_calculator(skin=0.0)


test_graphpes_nequip_consistency = make_model_calculator_consistency_test(
test_name="graphpes-nequip",
model_fixture_name="ts_nequip_model",
calculator_fixture_name="ase_nequip_calculator",
sim_state_names=consistency_test_simstate_fixtures,
)

test_graphpes_nequip_model_outputs = make_validate_model_outputs_test(
model_fixture_name="ts_nequip_model",
)


@pytest.fixture
def ts_mace_model(device: torch.device, dtype: torch.dtype):
return GraphPESWrapper(
mace_mp("medium-mpa-0"),
device=device,
dtype=dtype,
compute_stress=False,
)


@pytest.fixture
def ase_mace_calculator(device: torch.device, dtype: torch.dtype):
return mace_mp("medium-mpa-0").to(device, dtype).ase_calculator(skin=0.0)


test_graphpes_mace_consistency = make_model_calculator_consistency_test(
test_name="graphpes-mace",
model_fixture_name="ts_mace_model",
calculator_fixture_name="ase_mace_calculator",
sim_state_names=consistency_test_simstate_fixtures,
# graph-pes passes data directly to the underlying mace-torch model.
# As in test_mace.py, these mace-torch models can be surprisingly
# variable in CI (these tests pass locally on MacBooks with no need
# for high tolerances). While investigating, I found that mace-torch
# predictions are mildly sensitive to the order of items in the
# neighbour list, which could be the cause of the discrepancies
# between the ASE calculator and the TorchSim wrapper, both here
# and in test_mace.py.
rtol=6e-4,
atol=1e-5,
)

test_graphpes_mace_model_outputs = make_validate_model_outputs_test(
model_fixture_name="ts_mace_model",
)


_lj_model = LennardJones(sigma=0.5)


@pytest.fixture
def ts_lj_model(device: torch.device, dtype: torch.dtype):
return GraphPESWrapper(
_lj_model,
device=device,
dtype=dtype,
compute_stress=False,
)


@pytest.fixture
def ase_lj_calculator(device: torch.device, dtype: torch.dtype):
return _lj_model.to(device, dtype).ase_calculator(skin=0.0)


test_graphpes_lj_consistency = make_model_calculator_consistency_test(
test_name="graphpes-lj",
model_fixture_name="ts_lj_model",
calculator_fixture_name="ase_lj_calculator",
sim_state_names=consistency_test_simstate_fixtures,
)
5 changes: 5 additions & 0 deletions torch_sim/models/__init__.py
@@ -31,3 +31,8 @@
from torch_sim.models.mattersim import MatterSimModel
except ImportError:
pass

try:
from torch_sim.models.graphpes import GraphPESWrapper
except ImportError:
pass