diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 0c48c8d3..a2f4394a 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -33,10 +33,10 @@ jobs: - name: Install dependencies run: | - uv pip install .[test,docs] --system + uv pip install ".[test,docs]" --system - name: Install extras for tutorial generation - run: uv pip install ".[graphpes,mace]" --system + run: uv pip install ".[graphpes,mace,metatensor]" --system - name: Copy tutorials run: | diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 32f28cff..9a21e889 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -43,7 +43,8 @@ jobs: --ignore=tests/models/test_mace.py \ --ignore=tests/models/test_fairchem.py \ --ignore=tests/models/test_orb.py \ - --ignore=tests/models/test_sevennet.py + --ignore=tests/models/test_sevennet.py \ + --ignore=tests/models/test_metatensor.py - name: Upload coverage to Codecov uses: codecov/codecov-action@v5 @@ -64,6 +65,7 @@ jobs: - { name: mace, test_path: "tests/models/test_mace.py" } - { name: mace, test_path: "tests/test_elastic.py" } - { name: mattersim, test_path: "tests/models/test_mattersim.py" } + - { name: metatensor, test_path: "tests/models/test_metatensor.py" } - { name: orb, test_path: "tests/models/test_orb.py" } - { name: sevenn, test_path: "tests/models/test_sevennet.py" } - { name: graphpes, test_path: "tests/models/test_graphpes.py" } @@ -112,7 +114,6 @@ jobs: if: ${{ matrix.model.name != 'fairchem' }} run: uv pip install -e .[test,${{ matrix.model.name }}] --resolution=${{ matrix.version.resolution }} --system - - name: Run Tests with Coverage run: | pytest --cov=torch_sim --cov-report=xml ${{ matrix.model.test_path }} diff --git a/README.md b/README.md index 9d2ac766..64ee92d7 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ era. By rewriting the core primitives of atomistic simulation in Pytorch, it all orders of magnitude acceleration of popular machine learning potentials. * Automatic batching and GPU memory management allowing significant simulation speedup -* Support for MACE, Fairchem, and SevenNet MLIP models with more in progress +* Support for MACE, Fairchem, SevenNet, ORB, MatterSim and metatensor MLIP models * Support for classical lennard jones, morse, and soft-sphere potentials * Molecular dynamics integration schemes like NVE, NVT Langevin, and NPT Langevin * Relaxation of atomic positions and cell with gradient descent and FIRE diff --git a/docs/conf.py b/docs/conf.py index 23221250..bd46a1d6 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -62,7 +62,15 @@ "html_image", ] -autodoc_mock_imports = ["fairchem", "mace", "mattersim", "orb", "sevennet", "graphpes"] +autodoc_mock_imports = [ + "fairchem", + "mace", + "mattersim", + "metatensor", + "orb", + "sevennet", + "graphpes", +] # use type hints autodoc_typehints = "description" diff --git a/docs/index.md b/docs/index.md index a129fe9a..eede93c5 100644 --- a/docs/index.md +++ b/docs/index.md @@ -30,7 +30,7 @@ about/contributing about/license ``` -# torch_sim documentation +# TorchSim documentation **Date**: {sub-ref}`today` @@ -50,7 +50,7 @@ TorchSim is a next-generation open-source atomistic simulation engine for the ML :class-header: bg-light **User Guide** 🚀 ^^^ -The user guide provides in-depth information and tutorials for using *torch_sim*. +The user guide provides in-depth information and tutorials for using *TorchSim*. 
::: :::{grid-item-card} @@ -59,7 +59,7 @@ The user guide provides in-depth information and tutorials for using *torch_sim* :class-header: bg-light **API reference** 📖 ^^^ -The reference guide contains a detailed description of the *torch_sim* API. It +The reference guide contains a detailed description of the *TorchSim* API. It assumes that you have an understanding of the key concepts. ::: diff --git a/docs/tutorials/index.rst b/docs/tutorials/index.rst index 745c08ad..3517bff5 100644 --- a/docs/tutorials/index.rst +++ b/docs/tutorials/index.rst @@ -19,3 +19,4 @@ versions of the tutorials can also be found in the `torch-sim /examples/tutorial low_level_tutorial hybrid_swap_tutorial using_graphpes_tutorial + metatensor_tutorial diff --git a/examples/tutorials/metatensor_tutorial.py b/examples/tutorials/metatensor_tutorial.py new file mode 100644 index 00000000..d2afe47e --- /dev/null +++ b/examples/tutorials/metatensor_tutorial.py @@ -0,0 +1,73 @@ +# %% [markdown] +#
+# Dependencies +# /// script +# dependencies = [ +# "metatrain[pet] >=2025.4", +# "metatensor-torch >=0.7,<0.8" +# ] +# /// +#
+ + +# %% [markdown] +""" +# Using the PET-MAD model with metatensor + +This tutorial explains how to use the PET-MAD model (https://arxiv.org/abs/2503.14118) +via TorchSim's metatensor interface. + +## Loading the model + +Loading the model is simple: you simply need to specify the model name (in this case +"pet-mad"), as shown below. All other arguments are optional: for example, you could +specify the device. (If the device is not specified, like in this case, the optimal +device is chosen automatically.) +""" + +# %% +from torch_sim.models import MetatensorModel + +model = MetatensorModel("pet-mad") + +# %% [markdown] +""" +## Using the model to run a molecular dynamics simulations + +Once the model is loaded, you can use it just like any other TorchSim model to run +simulations. Here, we show how to run a simple MD simulation consisting of an initial +NVT equilibration run followed by an NVE run. +""" +# %% +from ase.build import bulk +import torch_sim as ts + +atoms = bulk("Si", "diamond", a=5.43, cubic=True) + +equilibrated_state = ts.integrate( + system=atoms, + model=model, + integrator=ts.nvt_langevin, + n_steps=100, + temperature=300, # K + timestep=0.001, # ps +) + +final_state = ts.integrate( + system=equilibrated_state, + model=model, + integrator=ts.nve, + n_steps=100, + temperature=300, # K + timestep=0.001, # ps +) + +# %% [markdown] +""" +## Further steps + +Of course, in reality, you would want to run the simulation for much longer, probably +save trajectories, and much more. However, this is all you need to get started with +metatensor and PET-MAD. For more details on how to use TorchSim, you can refer to the +other tutorials in this section. +""" diff --git a/pyproject.toml b/pyproject.toml index 38dd59ed..8558f845 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ test = [ ] mace = ["mace-torch>=0.3.11"] mattersim = ["mattersim>=0.1.2"] +metatensor = ["metatensor-torch >=0.7,<0.8", "metatrain[pet] >=2025.4"] orb = [ "orb-models@git+https://github.com/orbital-materials/orb-models#egg=637a98d49cfb494e2491a457d9bbd28311fecf21", ] diff --git a/tests/models/test_metatensor.py b/tests/models/test_metatensor.py new file mode 100644 index 00000000..806aacd4 --- /dev/null +++ b/tests/models/test_metatensor.py @@ -0,0 +1,65 @@ +import pytest +import torch + +from tests.models.conftest import ( + consistency_test_simstate_fixtures, + make_model_calculator_consistency_test, + make_validate_model_outputs_test, +) + + +try: + from metatensor.torch.atomistic import ase_calculator + from metatrain.utils.io import load_model + + from torch_sim.models.metatensor import MetatensorModel +except ImportError: + pytest.skip("Metatensor not installed", allow_module_level=True) + + +@pytest.fixture +def dtype() -> torch.dtype: + """Fixture to provide the default dtype for testing.""" + return torch.float32 + + +@pytest.fixture +def metatensor_calculator(device: torch.device): + """Load a pretrained metatensor model for testing.""" + return ase_calculator.MetatensorCalculator( + model=load_model( + "https://huggingface.co/lab-cosmo/pet-mad/resolve/main/models/pet-mad-latest.ckpt" + ).export(), + device=device, + ) + + +@pytest.fixture +def metatensor_model(device: torch.device) -> MetatensorModel: + """Create an MetatensorModel wrapper for the pretrained model.""" + return MetatensorModel( + model="pet-mad", + device=device, + ) + + +def test_metatensor_initialization(device: torch.device) -> None: + """Test that the metatensor model initializes correctly.""" + model = 
MetatensorModel(
        model="pet-mad",
        device=device,
    )
    assert model.device == device
    assert model.dtype == torch.float32


test_metatensor_consistency = make_model_calculator_consistency_test(
    test_name="metatensor",
    model_fixture_name="metatensor_model",
    calculator_fixture_name="metatensor_calculator",
    sim_state_names=consistency_test_simstate_fixtures,
)

test_metatensor_model_outputs = make_validate_model_outputs_test(
    model_fixture_name="metatensor_model",
)
diff --git a/torch_sim/models/__init__.py b/torch_sim/models/__init__.py
index 28675a71..924284bc 100644
--- a/torch_sim/models/__init__.py
+++ b/torch_sim/models/__init__.py
@@ -36,3 +36,8 @@
     from torch_sim.models.graphpes import GraphPESWrapper
 except ImportError:
     pass
+
+try:
+    from torch_sim.models.metatensor import MetatensorModel
+except ImportError:
+    pass
diff --git a/torch_sim/models/metatensor.py b/torch_sim/models/metatensor.py
new file mode 100644
index 00000000..18065cc0
--- /dev/null
+++ b/torch_sim/models/metatensor.py
@@ -0,0 +1,287 @@
"""Wrapper for metatensor-based models in TorchSim.

This module provides a TorchSim wrapper of metatensor models for computing
energies, forces, and stresses for atomistic systems, including batched computations
for multiple systems simultaneously.

The MetatensorModel class adapts metatensor models to the ModelInterface protocol,
allowing them to be used within the broader torch_sim simulation framework.

Notes:
    This module depends on the metatensor-torch and metatrain packages.
"""

from pathlib import Path

import torch
import vesin.torch.metatensor

from torch_sim.models.interface import ModelInterface
from torch_sim.state import SimState, StateDict


try:
    from metatensor.torch.atomistic import (
        ModelEvaluationOptions,
        ModelOutput,
        System,
        load_atomistic_model,
    )
    from metatrain.utils.io import load_model

except ImportError:

    class MetatensorModel(torch.nn.Module, ModelInterface):
        """Metatensor model wrapper for torch_sim.

        This class is a placeholder for the MetatensorModel class.
        It raises an ImportError if metatensor is not installed.
        """

        def __init__(self, *args, **kwargs) -> None:  # noqa: ARG002
            """Dummy constructor."""
            raise ImportError("metatensor must be installed to use MetatensorModel.")


class MetatensorModel(torch.nn.Module, ModelInterface):
    """Computes energies for a list of systems using a metatensor model.

    This class wraps a metatensor model to compute energies, forces, and stresses for
    atomic systems within the TorchSim framework. It supports batched calculations
    for multiple systems and handles the necessary transformations between
    TorchSim's data structures and metatensor's expected inputs.

    Attributes:
        ...
    """

    def __init__(
        self,
        model: str | Path | None = None,
        extensions_path: str | Path | None = None,
        device: torch.device | str | None = None,
        *,
        check_consistency: bool = False,
        compute_forces: bool = True,
        compute_stress: bool = True,
    ) -> None:
        """Initialize the metatensor model for energy, force and stress calculations.

        Sets up a metatensor model for energy, force, and stress calculations within
        the TorchSim framework. The model can be loaded from a local checkpoint
        (.ckpt) or exported (.pt) file, or selected by name from the available
        pre-defined models.

        Args:
            model (str | Path | None): Path to the metatensor model file or a
                pre-defined model name. Currently only "pet-mad"
                (https://arxiv.org/abs/2503.14118) is supported as a pre-defined model.
            extensions_path (str | Path | None): Optional path to the folder containing
                compiled extensions for the model.
            device (torch.device | str | None): Device on which to run the model. If
                None, defaults to "cuda" if available, otherwise "cpu".
            check_consistency (bool): Whether to perform various consistency checks
                during model evaluation. This should only be used in case of anomalous
                behavior, as it can hurt performance significantly.
            compute_forces (bool): Whether to compute forces.
            compute_stress (bool): Whether to compute stresses.

        Raises:
            ValueError: If model is None, or is neither a path to a .ckpt/.pt file
                nor "pet-mad".
        """
        super().__init__()

        if model is None:
            raise ValueError(
                "A model path, or the name of a pre-defined model, must be provided. "
                'Currently only "pet-mad" is available as a pre-defined model.'
            )

        if model == "pet-mad":
            path = "https://huggingface.co/lab-cosmo/pet-mad/resolve/main/models/pet-mad-latest.ckpt"
            self._model = load_model(path).export()
        elif str(model).endswith(".ckpt"):
            path = str(model)
            self._model = load_model(path).export()
        elif str(model).endswith(".pt"):
            path = str(model)
            self._model = load_atomistic_model(path, extensions_path)
        else:
            raise ValueError('Model must be a path to a .ckpt/.pt file, or "pet-mad".')

        if "energy" not in self._model.capabilities().outputs:
            raise ValueError(
                "This model does not support energy predictions. "
                "The model must have an `energy` output to be used in torch-sim."
            )

        self._device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        if isinstance(self._device, str):
            self._device = torch.device(self._device)
        if self._device.type not in self._model.capabilities().supported_devices:
            raise ValueError(
                f"Model does not support device {self._device}. Supported devices: "
                f"{self._model.capabilities().supported_devices}. You might want to "
                f"set the `device` argument to a supported device."
            )

        self._dtype = getattr(torch, self._model.capabilities().dtype)
        self._model.to(self._device)
        self._compute_forces = compute_forces
        self._compute_stress = compute_stress
        self._memory_scales_with = "n_atoms_x_density"  # for the majority of models
        self._check_consistency = check_consistency
        self._requested_neighbor_lists = self._model.requested_neighbor_lists()
        self._evaluation_options = ModelEvaluationOptions(
            length_unit="angstrom",
            outputs={
                "energy": ModelOutput(
                    quantity="energy",
                    unit="eV",
                    per_atom=False,
                )
            },
        )

    def forward(  # noqa: C901, PLR0915
        self,
        state: SimState | StateDict,
    ) -> dict[str, torch.Tensor]:
        """Compute energies, forces, and stresses for the given atomic systems.

        Processes the provided state information and computes energies, forces, and
        stresses using the underlying metatensor model. Handles batched calculations
        for multiple systems and constructs the necessary neighbor lists.

        Args:
            state (SimState | StateDict): State object containing positions, cell,
                and other system information. Can be either a SimState object or a
                dictionary with the relevant fields.

        Returns:
            dict[str, torch.Tensor]: Dictionary containing:
                - 'energy': System energies with shape [n_systems]
                - 'forces': Atomic forces with shape [n_atoms, 3] if compute_forces=True
                - 'stress': System stresses with shape [n_systems, 3, 3] if
                    compute_stress=True
        """
        # Extract required data from input
        if isinstance(state, dict):
            state = SimState(**state, masses=torch.ones_like(state["positions"]))

        # Input validation is already done inside the forward method of the
        # MetatensorAtomisticModel class, so we don't need to do it again here.

        atomic_numbers = state.atomic_numbers
        cell = state.row_vector_cell
        positions = state.positions
        pbc = state.pbc

        # Check dtype (metatensor models require a specific input dtype)
        if positions.dtype != self._dtype:
            raise TypeError(
                f"Positions dtype {positions.dtype} does not match model dtype "
                f"{self._dtype}"
            )

        # Compared to other models, metatensor models have two peculiarities:
        # - different structures are fed to the models separately as a list of System
        #   objects, and not as a single graph-like batch
        # - the model does not compute forces and stresses itself, but rather the
        #   caller code needs to call torch.autograd.grad or similar to compute them
        #   from the energy output

        # Process each system separately
        systems: list[System] = []
        strains = []
        for b in range(len(cell)):
            system_mask = state.batch == b
            system_positions = positions[system_mask]
            system_cell = cell[b]
            system_pbc = torch.tensor(
                [pbc, pbc, pbc], device=self._device, dtype=torch.bool
            )
            system_atomic_numbers = atomic_numbers[system_mask]

            # Create a System object for this system
            if self._compute_forces:
                system_positions.requires_grad_()
            if self._compute_stress:
                strain = torch.eye(
                    3, device=self._device, dtype=self._dtype, requires_grad=True
                )
                strains.append(strain)
                system_positions = system_positions @ strain
                system_cell = system_cell @ strain

            systems.append(
                System(
                    positions=system_positions,
                    types=system_atomic_numbers,
                    cell=system_cell,
                    pbc=system_pbc,
                )
            )

        # Calculate the required neighbor list(s) for all the systems
        vesin.torch.metatensor.compute_requested_neighbors(
            systems, system_length_unit="Angstrom", model=self._model
        )

        # Get model output
        model_outputs = self._model(
            systems=systems,
            options=self._evaluation_options,
            check_consistency=self._check_consistency,
        )

        results = {}
        results["energy"] = model_outputs["energy"].block().values.detach().squeeze(-1)

        # Compute forces and/or stresses if requested
        tensors_for_autograd = []
        if self._compute_forces:
            for system in systems:
                tensors_for_autograd.append(system.positions)  # noqa: PERF401
        if self._compute_stress:
            for strain in strains:
                tensors_for_autograd.append(strain)  # noqa: PERF402

        if self._compute_forces or self._compute_stress:
            derivatives = torch.autograd.grad(
                outputs=model_outputs["energy"].block().values,
                inputs=tensors_for_autograd,
                grad_outputs=torch.ones_like(model_outputs["energy"].block().values),
            )
        else:
            derivatives = []

        results_by_system: dict[str, list[torch.Tensor]] = {}
        if self._compute_forces and self._compute_stress:
            results_by_system["forces"] = [-d for d in derivatives[: len(systems)]]
            results_by_system["stress"] = [
                d / torch.abs(torch.det(system.cell.detach()))
                for d, system in zip(derivatives[len(systems) :], systems, strict=False)
            ]
        elif self._compute_forces:
            results_by_system["forces"] = [-d for d in derivatives]
        elif self._compute_stress:
            results_by_system["stress"] = [
                d / torch.abs(torch.det(system.cell.detach()))
                for d, system in zip(derivatives, systems, strict=False)
            ]

        # Concatenate/stack forces and stresses
        if self._compute_forces:
            if len(results_by_system["forces"]) > 0:
                results["forces"] = torch.cat(results_by_system["forces"])
            else:
                results["forces"] = torch.empty_like(positions)
        if self._compute_stress:
            if len(results_by_system["stress"]) > 0:
                results["stress"] = torch.stack(results_by_system["stress"])
            else:
                results["stress"] = torch.empty_like(cell)

        return results
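
Note on the force/stress recipe used in `MetatensorModel.forward` above: the metatensor model only returns an energy, and forces and stresses are recovered by differentiating that energy with respect to the atomic positions and an identity strain. The snippet below is a minimal, self-contained sketch of that same autograd pattern, using a toy energy function in place of a metatensor model; the toy energy, shapes, and variable names are illustrative only and are not part of this PR.

import torch


def toy_energy(positions: torch.Tensor, cell: torch.Tensor) -> torch.Tensor:
    """Scalar stand-in for the metatensor model's "energy" output."""
    pair_term = 0.5 * ((positions.unsqueeze(0) - positions.unsqueeze(1)) ** 2).sum()
    return pair_term + torch.det(cell).abs()


positions = torch.rand(8, 3, dtype=torch.float64, requires_grad=True)
cell = 5.0 * torch.eye(3, dtype=torch.float64)

# Route positions and cell through a differentiable identity strain, as in forward().
strain = torch.eye(3, dtype=torch.float64, requires_grad=True)
strained_positions = positions @ strain
strained_cell = cell @ strain

energy = toy_energy(strained_positions, strained_cell)

# One autograd call yields both derivatives of the scalar energy.
grad_positions, grad_strain = torch.autograd.grad(energy, [strained_positions, strain])

forces = -grad_positions  # negative gradient w.r.t. positions, shape [n_atoms, 3]
stress = grad_strain / torch.det(cell).abs()  # strain derivative per unit volume, [3, 3]

print(forces.shape, stress.shape)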