add support for graph-pes models (#118)
Merged
Changes from all commits (15 commits)
- d74da97 working models (jla-gardner)
- e23124b add tutorial (jla-gardner)
- ff879d3 first set of fixes (jla-gardner)
- e2a4e1b fix tests (jla-gardner)
- a671f90 change rattling scale (jla-gardner)
- 8fe42bf use `medium-mpa-0` (jla-gardner)
- 295854e correct graph-pes calculator instantiation (jla-gardner)
- a03ddaf reduce bump (jla-gardner)
- ed18944 always build docs but only deploy on main (orionarcher)
- 1a397eb build docs on pull requests too (orionarcher)
- 4896578 update docs (jla-gardner)
- 459b710 try with normal `atol` (jla-gardner)
- 32b89c8 comment on high required rtol (jla-gardner)
- 2d3b67a fix docs (jla-gardner)
- 06423dd update tutorial text (jla-gardner)
New file (tutorial, 92 lines added):

```python
# %% [markdown]
# <details>
# <summary>Dependencies</summary>
# /// script
# dependencies = [
#     "graph-pes>=0.0.30",
#     "torch==2.5",
# ]
# ///
# </details>

# %% [markdown]
"""
# Integrating TorchSim with `graph-pes`

This brief tutorial demonstrates how to use models trained with the
[graph-pes](https://github.com/mir-group/graph-pes) package to drive
MD simulations and geometry optimizations in TorchSim.

## Step 1: loading a model

The `graph-pes-train` command outputs a path to a `.pt` file containing
your trained model. To use this model with TorchSim, pass the path to
this `.pt` file, or the model itself, to the `GraphPESWrapper` constructor.

Below, we create a dummy TensorNet model with random weights as a demonstration:
"""

# %%
from graph_pes.models import TensorNet, load_model

# if you had a model saved to disk, you could load it like this:
# model = load_model("path/to/model.pt")

# here, we just create a TensorNet model with random weights
model = TensorNet(cutoff=5.0)

print("Number of parameters:", sum(p.numel() for p in model.parameters()))

# %% [markdown]
"""
## Step 2: wrapping the model for use with TorchSim

We provide the `GraphPESWrapper` class to wrap a `graph-pes` model for use
with TorchSim. If you intend to drive simulations that require stresses,
set the `compute_stress` argument to `True`.
"""

# %%
from torch_sim.models import GraphPESWrapper

# wrap the model for use with TorchSim
ts_model = GraphPESWrapper(model, compute_stress=False)

# or, alternatively, pass a model path directly:
# ts_model = GraphPESWrapper("path/to/model.pt", compute_stress=False)

# %% [markdown]
"""
## Step 3: driving MD with the model

Now that we have a model, we can drive MD simulations with it. For this,
we will use the `integrate` function.
"""

# %%
from ase.build import molecule
import torch_sim as ts
from load_atoms import view

# NVT at 300 K
atoms = molecule("H2O")

final_state = ts.integrate(
    system=atoms,
    model=ts_model,
    integrator=ts.nvt_langevin,
    n_steps=50,
    temperature=300,
    timestep=0.001,
)

final_atoms = final_state.to_atoms()[0]
view(final_atoms, show_bonds=True)

# %% [markdown]
"""
Of course, this is a very simple example. However, you are now equipped to
use any `graph-pes` model that you have trained to drive any of the
functionality exposed by TorchSim!
"""
```
New file (tests, 214 lines added):

```python
import pytest
import torch
from ase.build import bulk, molecule

from tests.models.conftest import (
    consistency_test_simstate_fixtures,
    make_model_calculator_consistency_test,
    make_validate_model_outputs_test,
)
from torch_sim.io import atoms_to_state
from torch_sim.models.graphpes import GraphPESWrapper

try:
    from graph_pes.atomic_graph import AtomicGraph, to_batch
    from graph_pes.interfaces import mace_mp
    from graph_pes.models import LennardJones, SchNet, TensorNet, ZEmbeddingNequIP
except ImportError:
    pytest.skip("graph-pes not installed", allow_module_level=True)


@pytest.fixture
def dtype() -> torch.dtype:
    """Fixture to provide the default dtype for testing."""
    return torch.float32


def test_graphpes_isolated(device: torch.device):
    # test that the raw model and torch_sim wrapper give the same results
    # for an isolated, unbatched structure
    water_atoms = molecule("H2O")
    water_atoms.center(vacuum=10.0)

    gp_model = SchNet(cutoff=5.5)
    gp_graph = AtomicGraph.from_ase(water_atoms, cutoff=5.5)
    gp_energy = gp_model.predict_energy(gp_graph)

    ts_model = GraphPESWrapper(
        gp_model,
        device=device,
        dtype=torch.float32,
        compute_forces=True,
        compute_stress=False,
    )
    ts_output = ts_model(atoms_to_state([water_atoms], device, torch.float32))
    assert set(ts_output.keys()) == {"energy", "forces"}
    assert ts_output["energy"].shape == (1,)

    assert gp_energy.item() == pytest.approx(ts_output["energy"].item(), abs=1e-5)


def test_graphpes_periodic(device: torch.device):
    # test that the raw model and torch_sim wrapper give the same results
    # for a periodic, unbatched structure
    bulk_atoms = bulk("Al", "hcp", a=4.05)
    assert bulk_atoms.pbc.all()

    gp_model = TensorNet(cutoff=5.5)
    gp_graph = AtomicGraph.from_ase(bulk_atoms, cutoff=5.5)
    gp_forces = gp_model.predict_forces(gp_graph)

    ts_model = GraphPESWrapper(
        gp_model,
        device=device,
        dtype=torch.float32,
        compute_forces=True,
        compute_stress=True,
    )
    ts_output = ts_model(atoms_to_state([bulk_atoms], device, torch.float32))
    assert set(ts_output.keys()) == {"energy", "forces", "stress"}
    assert ts_output["energy"].shape == (1,)
    assert ts_output["forces"].shape == (len(bulk_atoms), 3)
    assert ts_output["stress"].shape == (1, 3, 3)

    torch.testing.assert_close(ts_output["forces"].to("cpu"), gp_forces)


def test_batching(device: torch.device):
    # test that the raw model and torch_sim wrapper give the same results
    # when batching is done via torch_sim's atoms_to_state function
    water = molecule("H2O")
    methane = molecule("CH4")
    systems = [water, methane]
    for s in systems:
        s.center(vacuum=10.0)

    gp_model = SchNet(cutoff=5.5)
    gp_graphs = [AtomicGraph.from_ase(s, cutoff=5.5) for s in systems]

    gp_energies = gp_model.predict_energy(to_batch(gp_graphs))

    ts_model = GraphPESWrapper(
        gp_model,
        device=device,
        dtype=torch.float32,
        compute_forces=True,
        compute_stress=True,
    )
    ts_output = ts_model(atoms_to_state(systems, device, torch.float32))

    assert set(ts_output.keys()) == {"energy", "forces", "stress"}
    assert ts_output["energy"].shape == (2,)
    assert ts_output["forces"].shape == (sum(len(s) for s in systems), 3)
    assert ts_output["stress"].shape == (2, 3, 3)

    assert gp_energies[0].item() == pytest.approx(ts_output["energy"][0].item(), abs=1e-5)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_graphpes_dtype(device: torch.device, dtype: torch.dtype):
    water = molecule("H2O")

    model = SchNet()

    ts_wrapper = GraphPESWrapper(model, device=device, dtype=dtype, compute_stress=False)
    ts_output = ts_wrapper(atoms_to_state([water], device, dtype))
    assert ts_output["energy"].dtype == dtype
    assert ts_output["forces"].dtype == dtype


_nequip_model = ZEmbeddingNequIP()


@pytest.fixture
def ts_nequip_model(device: torch.device, dtype: torch.dtype):
    return GraphPESWrapper(
        _nequip_model,
        device=device,
        dtype=dtype,
        compute_stress=False,
    )


@pytest.fixture
def ase_nequip_calculator(device: torch.device, dtype: torch.dtype):
    return _nequip_model.to(device, dtype).ase_calculator(skin=0.0)


test_graphpes_nequip_consistency = make_model_calculator_consistency_test(
    test_name="graphpes-nequip",
    model_fixture_name="ts_nequip_model",
    calculator_fixture_name="ase_nequip_calculator",
    sim_state_names=consistency_test_simstate_fixtures,
)

test_graphpes_nequip_model_outputs = make_validate_model_outputs_test(
    model_fixture_name="ts_nequip_model",
)


@pytest.fixture
def ts_mace_model(device: torch.device, dtype: torch.dtype):
    return GraphPESWrapper(
        mace_mp("medium-mpa-0"),
        device=device,
        dtype=dtype,
        compute_stress=False,
    )


@pytest.fixture
def ase_mace_calculator(device: torch.device, dtype: torch.dtype):
    return mace_mp("medium-mpa-0").to(device, dtype).ase_calculator(skin=0.0)


test_graphpes_mace_consistency = make_model_calculator_consistency_test(
    test_name="graphpes-mace",
    model_fixture_name="ts_mace_model",
    calculator_fixture_name="ase_mace_calculator",
    sim_state_names=consistency_test_simstate_fixtures,
    # graph-pes passes data directly to the underlying mace-torch model.
    # From test_mace.py, it seems that these mace-torch models can be
    # surprisingly variable in the CI (these tests pass locally on
    # MacBooks with no need for high tolerances).
    # While investigating, I found that mace-torch model predictions are
    # mildly sensitive to the order of items in the neighbourlist - this
    # could be the cause of the discrepancies between the ASE calculator
    # and the TorchSim wrapper, both here and in test_mace.py.
    rtol=6e-4,
    atol=1e-5,
)

test_graphpes_mace_model_outputs = make_validate_model_outputs_test(
    model_fixture_name="ts_mace_model",
)


_lj_model = LennardJones(sigma=0.5)


@pytest.fixture
def ts_lj_model(device: torch.device, dtype: torch.dtype):
    return GraphPESWrapper(
        _lj_model,
        device=device,
        dtype=dtype,
        compute_stress=False,
    )


@pytest.fixture
def ase_lj_calculator(device: torch.device, dtype: torch.dtype):
    return _lj_model.to(device, dtype).ase_calculator(skin=0.0)


test_graphpes_lj_consistency = make_model_calculator_consistency_test(
    test_name="graphpes-lj",
    model_fixture_name="ts_lj_model",
    calculator_fixture_name="ase_lj_calculator",
    sim_state_names=consistency_test_simstate_fixtures,
)
```
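The elevated `rtol=6e-4` in the MACE consistency test is interpreted through the usual combined-tolerance criterion used by `torch.testing.assert_close` and `numpy.allclose`: two values agree when `|actual - expected| <= atol + rtol * |expected|`. A plain-Python sketch of that check (the example values are illustrative, not from the test suite):

```python
def is_close(actual: float, expected: float,
             rtol: float = 6e-4, atol: float = 1e-5) -> bool:
    """Combined absolute/relative tolerance check, mirroring the
    criterion used by torch.testing.assert_close and numpy.allclose."""
    return abs(actual - expected) <= atol + rtol * abs(expected)

# for a reference value of 1.0, the allowed deviation is 1e-5 + 6e-4 * 1.0
print(is_close(1.0005, 1.0))  # True: 5e-4 is within the combined tolerance
print(is_close(1.0010, 1.0))  # False: 1e-3 exceeds it
```

Note that `rtol` scales with the magnitude of the expected value, so large force components are allowed proportionally larger absolute discrepancies, which is why the neighbourlist-ordering sensitivity described in the comment shows up mainly as an `rtol` problem.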