diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 0bab9757..0c48c8d3 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -4,6 +4,7 @@ on: workflow_dispatch: push: branches: [main] + pull_request: # set GITHUB_TOKEN permissions to allow deployment to GitHub Pages permissions: @@ -13,7 +14,6 @@ permissions: jobs: build-docs: - if: github.repository_owner == 'Radical-AI' && github.ref == 'refs/heads/main' runs-on: ubuntu-latest steps: @@ -35,8 +35,8 @@ jobs: run: | uv pip install .[test,docs] --system - - name: Install mace - run: uv pip install mace-torch --system + - name: Install extras for tutorial generation + run: uv pip install ".[graphpes,mace]" --system - name: Copy tutorials run: | @@ -57,6 +57,7 @@ jobs: path: ./docs_build deploy: + if: github.repository_owner == 'Radical-AI' && github.ref == 'refs/heads/main' environment: name: github-pages url: ${{ steps.deployment.outputs.page_url }} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index eeb94034..db059fe9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -66,6 +66,7 @@ jobs: - { name: mattersim, test_path: "tests/models/test_mattersim.py" } - { name: orb, test_path: "tests/models/test_orb.py" } - { name: sevenn, test_path: "tests/models/test_sevennet.py" } + - { name: graphpes, test_path: "tests/models/test_graphpes.py" } runs-on: ${{ matrix.os }} steps: diff --git a/docs/conf.py b/docs/conf.py index f80954ef..23221250 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -62,7 +62,7 @@ "html_image", ] -autodoc_mock_imports = ["fairchem", "mace", "mattersim", "orb", "sevennet"] +autodoc_mock_imports = ["fairchem", "mace", "mattersim", "orb", "sevennet", "graphpes"] # use type hints autodoc_typehints = "description" diff --git a/docs/tutorials/index.rst b/docs/tutorials/index.rst index 53bd7eb2..745c08ad 100644 --- a/docs/tutorials/index.rst +++ b/docs/tutorials/index.rst @@ -18,3 +18,4 @@ versions of the tutorials can also be found in the `torch-sim /examples/tutorial autobatching_tutorial low_level_tutorial hybrid_swap_tutorial + using_graphpes_tutorial diff --git a/examples/tutorials/using_graphpes_tutorial.py b/examples/tutorials/using_graphpes_tutorial.py new file mode 100644 index 00000000..bbc6b5db --- /dev/null +++ b/examples/tutorials/using_graphpes_tutorial.py @@ -0,0 +1,92 @@ +# %% [markdown] +#
+# Dependencies +# /// script +# dependencies = [ +# "graph-pes>=0.0.30", +# "torch==2.5", +# ] +# /// +#
+ + +# %% [markdown] +""" +# Integrating TorchSim with `graph-pes` + +This brief tutorial demonstrates how to use models trained with the +[graph-pes](https://github.com/mir-group/graph-pes) package to drive +MD simulations and geometry optimizations in TorchSim. + +## Step 1: loading a model + +As an output of the `graph-pes-train` command, you receive a path +to a `.pt` file containing your trained model. To use this model +with TorchSim, pass the path to this `.pt` file, or the model itself, +to the `GraphPESWrapper` constructor. + +Below, we create a dummy TensorNet model with random weights as a demonstration: +""" + +# %% +from graph_pes.models import TensorNet, load_model + +# if you had a model saved to disk, you could load it like this: +# model = load_model("path/to/model.pt") + +# here, we just create a TensorNet model with random weights +model = TensorNet(cutoff=5.0) + +print("Number of parameters:", sum(p.numel() for p in model.parameters())) + +# %% [markdown] +""" +## Step 2: wrapping the model for use with TorchSim + +We provide the `GraphPESWrapper` class to wrap a `graph-pes` model for use with TorchSim. +If you intend to drive simulations that require stresses, you will need to specify the +`compute_stress` argument to `True`. +""" + +# %% +from torch_sim.models import GraphPESWrapper + +# wrap the model for use with TorchSim +ts_model = GraphPESWrapper(model, compute_stress=False) + +# or, alternatively, pass a model path directly: +# ts_model = GraphPESWrapper("path/to/model.pt", compute_stress=False) + +# %% [markdown] +""" +## Step 3: driving MD with the model + +Now that we have a model, we can drive MD simulations with it. For this, we will use the +`integrate` function. +""" +# %% +from ase.build import molecule +import torch_sim as ts +from load_atoms import view + +# NVT at 300K +atoms = molecule("H2O") + +final_state = ts.integrate( + system=atoms, + model=ts_model, + integrator=ts.nvt_langevin, + n_steps=50, + temperature=300, + timestep=0.001, +) + +final_atoms = final_state.to_atoms()[0] +view(final_atoms, show_bonds=True) + +# %% [markdown] +""" +Of course, this is a very simple example. However, you are now equipped to +use any `graph-pes` model that you have trained to drive any of the functionality +exposed by TorchSim! +""" diff --git a/pyproject.toml b/pyproject.toml index ac2b64c6..38dd59ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,11 +46,11 @@ test = [ ] mace = ["mace-torch>=0.3.11"] mattersim = ["mattersim>=0.1.2"] - orb = [ "orb-models@git+https://github.com/orbital-materials/orb-models#egg=637a98d49cfb494e2491a457d9bbd28311fecf21", ] sevenn = ["sevenn>=0.11.0"] +graphpes = ["graph-pes>=0.0.34", "mace-torch>=0.3.11"] docs = [ "autodoc_pydantic==2.2.0", "furo==2024.8.6", diff --git a/tests/models/test_graphpes.py b/tests/models/test_graphpes.py new file mode 100644 index 00000000..cd32eff2 --- /dev/null +++ b/tests/models/test_graphpes.py @@ -0,0 +1,214 @@ +import pytest +import torch +from ase.build import bulk, molecule + +from tests.models.conftest import ( + consistency_test_simstate_fixtures, + make_model_calculator_consistency_test, + make_validate_model_outputs_test, +) +from torch_sim.io import atoms_to_state +from torch_sim.models.graphpes import GraphPESWrapper + + +try: + from graph_pes.atomic_graph import AtomicGraph, to_batch + from graph_pes.interfaces import mace_mp + from graph_pes.models import LennardJones, SchNet, TensorNet, ZEmbeddingNequIP +except ImportError: + pytest.skip("graph-pes not installed", allow_module_level=True) + + +@pytest.fixture +def dtype() -> torch.dtype: + """Fixture to provide the default dtype for testing.""" + return torch.float32 + + +def test_graphpes_isolated(device: torch.device): + # test that the raw model and torch_sim wrapper give the same results + # for an isolated, unbatched structure + + water_atoms = molecule("H2O") + water_atoms.center(vacuum=10.0) + + gp_model = SchNet(cutoff=5.5) + gp_graph = AtomicGraph.from_ase(water_atoms, cutoff=5.5) + gp_energy = gp_model.predict_energy(gp_graph) + + ts_model = GraphPESWrapper( + gp_model, + device=device, + dtype=torch.float32, + compute_forces=True, + compute_stress=False, + ) + ts_output = ts_model(atoms_to_state([water_atoms], device, torch.float32)) + assert set(ts_output.keys()) == {"energy", "forces"} + assert ts_output["energy"].shape == (1,) + + assert gp_energy.item() == pytest.approx(ts_output["energy"].item(), abs=1e-5) + + +def test_graphpes_periodic(device: torch.device): + # test that the raw model and torch_sim wrapper give the same results + # for a periodic, unbatched structure + + bulk_atoms = bulk("Al", "hcp", a=4.05) + assert bulk_atoms.pbc.all() + + gp_model = TensorNet(cutoff=5.5) + gp_graph = AtomicGraph.from_ase(bulk_atoms, cutoff=5.5) + gp_forces = gp_model.predict_forces(gp_graph) + + ts_model = GraphPESWrapper( + gp_model, + device=device, + dtype=torch.float32, + compute_forces=True, + compute_stress=True, + ) + ts_output = ts_model(atoms_to_state([bulk_atoms], device, torch.float32)) + assert set(ts_output.keys()) == {"energy", "forces", "stress"} + assert ts_output["energy"].shape == (1,) + assert ts_output["forces"].shape == (len(bulk_atoms), 3) + assert ts_output["stress"].shape == (1, 3, 3) + + torch.testing.assert_close(ts_output["forces"].to("cpu"), gp_forces) + + +def test_batching(device: torch.device): + # test that the raw model and torch_sim wrapper give the same results + # when batching is done via torch_sim's atoms_to_state function + + water = molecule("H2O") + methane = molecule("CH4") + systems = [water, methane] + for s in systems: + s.center(vacuum=10.0) + + gp_model = SchNet(cutoff=5.5) + gp_graphs = [AtomicGraph.from_ase(s, cutoff=5.5) for s in systems] + + gp_energies = gp_model.predict_energy(to_batch(gp_graphs)) + + ts_model = GraphPESWrapper( + gp_model, + device=device, + dtype=torch.float32, + compute_forces=True, + compute_stress=True, + ) + ts_output = ts_model(atoms_to_state(systems, device, torch.float32)) + + assert set(ts_output.keys()) == {"energy", "forces", "stress"} + assert ts_output["energy"].shape == (2,) + assert ts_output["forces"].shape == (sum(len(s) for s in systems), 3) + assert ts_output["stress"].shape == (2, 3, 3) + + assert gp_energies[0].item() == pytest.approx(ts_output["energy"][0].item(), abs=1e-5) + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) +def test_graphpes_dtype(device: torch.device, dtype: torch.dtype): + water = molecule("H2O") + + model = SchNet() + + ts_wrapper = GraphPESWrapper(model, device=device, dtype=dtype, compute_stress=False) + ts_output = ts_wrapper(atoms_to_state([water], device, dtype)) + assert ts_output["energy"].dtype == dtype + assert ts_output["forces"].dtype == dtype + + +_nequip_model = ZEmbeddingNequIP() + + +@pytest.fixture +def ts_nequip_model(device: torch.device, dtype: torch.dtype): + return GraphPESWrapper( + _nequip_model, + device=device, + dtype=dtype, + compute_stress=False, + ) + + +@pytest.fixture +def ase_nequip_calculator(device: torch.device, dtype: torch.dtype): + return _nequip_model.to(device, dtype).ase_calculator(skin=0.0) + + +test_graphpes_nequip_consistency = make_model_calculator_consistency_test( + test_name="graphpes-nequip", + model_fixture_name="ts_nequip_model", + calculator_fixture_name="ase_nequip_calculator", + sim_state_names=consistency_test_simstate_fixtures, +) + +test_graphpes_nequip_model_outputs = make_validate_model_outputs_test( + model_fixture_name="ts_nequip_model", +) + + +@pytest.fixture +def ts_mace_model(device: torch.device, dtype: torch.dtype): + return GraphPESWrapper( + mace_mp("medium-mpa-0"), + device=device, + dtype=dtype, + compute_stress=False, + ) + + +@pytest.fixture +def ase_mace_calculator(device: torch.device, dtype: torch.dtype): + return mace_mp("medium-mpa-0").to(device, dtype).ase_calculator(skin=0.0) + + +test_graphpes_mace_consistency = make_model_calculator_consistency_test( + test_name="graphpes-mace", + model_fixture_name="ts_mace_model", + calculator_fixture_name="ase_mace_calculator", + sim_state_names=consistency_test_simstate_fixtures, + # graph-pes passes data directly to the underlying mace-torch model + # from test_mace.py, it seems that these mace-torch models can be + # surprisingly variable in the CI (these tests pass locally on + # MacBooks with no need for high tolerances) + # While investigating, I found that mace-torch model predictions are + # mildly sensitive to the order of items in the neighbourlist - this + # could be the cause of the discrepancies between the ASE calculator + # and the TorchSim wrapper, both here and in test_mace.py + rtol=6e-4, + atol=1e-5, +) + +test_graphpes_mace_model_outputs = make_validate_model_outputs_test( + model_fixture_name="ts_mace_model", +) + + +_lj_model = LennardJones(sigma=0.5) + + +@pytest.fixture +def ts_lj_model(device: torch.device, dtype: torch.dtype): + return GraphPESWrapper( + _lj_model, + device=device, + dtype=dtype, + compute_stress=False, + ) + + +@pytest.fixture +def ase_lj_calculator(device: torch.device, dtype: torch.dtype): + return _lj_model.to(device, dtype).ase_calculator(skin=0.0) + + +test_graphpes_lj_consistency = make_model_calculator_consistency_test( + test_name="graphpes-lj", + model_fixture_name="ts_lj_model", + calculator_fixture_name="ase_lj_calculator", + sim_state_names=consistency_test_simstate_fixtures, +) diff --git a/torch_sim/models/__init__.py b/torch_sim/models/__init__.py index 9c0a10ad..28675a71 100644 --- a/torch_sim/models/__init__.py +++ b/torch_sim/models/__init__.py @@ -31,3 +31,8 @@ from torch_sim.models.mattersim import MatterSimModel except ImportError: pass + +try: + from torch_sim.models.graphpes import GraphPESWrapper +except ImportError: + pass diff --git a/torch_sim/models/graphpes.py b/torch_sim/models/graphpes.py new file mode 100644 index 00000000..6f6ed5b5 --- /dev/null +++ b/torch_sim/models/graphpes.py @@ -0,0 +1,185 @@ +"""An interface for using arbitrary GraphPESModels in torch_sim. + +This module provides a TorchSim wrapper of the GraphPES models for computing +energies, forces, and stresses of atomistic systems. It serves as a wrapper around +the graph_pes library, integrating it with the torch_sim framework to enable seamless +simulation of atomistic systems with machine learning potentials. + +The GraphPESWrapper class adapts GraphPESModels to the ModelInterface protocol, +allowing them to be used within the broader torch_sim simulation framework. + +Notes: + This implementation requires graph_pes to be installed and accessible. + It supports various model configurations through model instances or model paths. +""" + +import typing +from pathlib import Path + +import torch + +from torch_sim.models.interface import ModelInterface +from torch_sim.neighbors import vesin_nl_ts +from torch_sim.state import SimState, StateDict + + +try: + from graph_pes import AtomicGraph, GraphPESModel + from graph_pes.atomic_graph import PropertyKey, to_batch + from graph_pes.models import load_model + +except ImportError: + PropertyKey = str + + class GraphPESWrapper(torch.nn.Module, ModelInterface): # type: ignore[reportRedeclaration] + """GraphPESModel wrapper for torch_sim. + + This class is a placeholder for the GraphPESWrapper class. + It raises an ImportError if graph_pes is not installed. + """ + + def __init__(self, *_args: typing.Any, **_kwargs: typing.Any) -> None: # noqa: D107 + raise ImportError("graph_pes must be installed to use this model.") + + class AtomicGraph: # type: ignore[reportRedeclaration] # noqa: D101 + def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: # noqa: D107,ARG002 + raise ImportError("graph_pes must be installed to use this model.") + + class GraphPESModel(torch.nn.Module): # type: ignore[reportRedeclaration] # noqa: D101 + pass + + +def state_to_atomic_graph(state: SimState, cutoff: torch.Tensor) -> AtomicGraph: + """Convert a SimState object into an AtomicGraph object. + + Args: + state: SimState object containing atomic positions, cell, and atomic numbers + cutoff: Cutoff radius for the neighbor list + + Returns: + AtomicGraph object representing the batched structures + """ + graphs = [] + + for i in range(state.n_batches): + batch_mask = state.batch == i + R = state.positions[batch_mask] + Z = state.atomic_numbers[batch_mask] + cell = state.row_vector_cell[i] + nl, shifts = vesin_nl_ts( + R, + cell, + state.pbc, + # graph-pes models internally trim the neighbour list to the + # model's cutoff value. To ensure no strange edge effects whereby + # edges that are exactly `cutoff` long are included/excluded, + # we bump this up slightly here + cutoff + 1e-5, + ) + + graphs.append( + AtomicGraph( + Z=Z.long(), + R=R, + cell=cell, + neighbour_list=nl.long(), + neighbour_cell_offsets=shifts, + properties={}, + cutoff=cutoff.item(), + other={}, + ) + ) + + return to_batch(graphs) + + +class GraphPESWrapper(torch.nn.Module, ModelInterface): + """Wrapper for GraphPESModel in TorchSim. + + This class provides a TorchSim wrapper around GraphPESModel instances, + allowing them to be used within the broader torch_sim simulation framework. + + The graph-pes package allows for the training of existing model architectures, + including SchNet, PaiNN, MACE, NequIP, TensorNet, EDDP and more. + You can use any of these, as well as your own custom architectures, with this wrapper. + See the the graph-pes repo for more details: https://github.com/jla-gardner/graph-pes + + Args: + model: GraphPESModel instance, or a path to a model file + device: Device to run the model on + dtype: Data type for the model + compute_forces: Whether to compute forces + compute_stress: Whether to compute stress + + Example: + >>> from torch_sim.models import GraphPESWrapper + >>> from graph_pes.models import load_model + >>> model = load_model("path/to/model.pt") + >>> wrapper = GraphPESWrapper(model) + >>> state = SimState( + ... positions=torch.randn(10, 3), + ... cell=torch.eye(3), + ... atomic_numbers=torch.randint(1, 104, (10,)), + ... ) + >>> wrapper(state) + """ + + def __init__( + self, + model: GraphPESModel | str | Path, + device: torch.device | None = None, + dtype: torch.dtype = torch.float64, + *, + compute_forces: bool = True, + compute_stress: bool = True, + ) -> None: + """Initialize the GraphPESWrapper. + + Args: + model: GraphPESModel instance, or a path to a model file + device: Device to run the model on + dtype: Data type for the model + compute_forces: Whether to compute forces + compute_stress: Whether to compute stress + """ + super().__init__() + self._device = device or torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) + self._dtype = dtype + + _model = typing.cast( + "GraphPESModel", + ( + model if isinstance(model, GraphPESModel) else load_model(model) # type: ignore[arg-type] + ), + ) + self._gp_model = _model.to(device=self._device, dtype=self._dtype) + + self._compute_forces = compute_forces + self._compute_stress = compute_stress + + self._properties: list[PropertyKey] = ["energy"] + if self._compute_forces: + self._properties.append("forces") + if self._compute_stress: + self._properties.append("stress") + + if self._gp_model.cutoff.item() < 0.5: + self._memory_scales_with = "n_atoms" + + def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]: + """Forward pass for the GraphPESWrapper. + + Args: + state: SimState object containing atomic positions, cell, and atomic numbers + + Returns: + Dictionary containing the computed energies, forces, and stresses + (where applicable) + """ + if not isinstance(state, SimState): + state = SimState(**state) # type: ignore[arg-type] + + atomic_graph = state_to_atomic_graph(state, self._gp_model.cutoff) + return self._gp_model.predict(atomic_graph, self._properties) # type: ignore[return-value]