diff --git a/README.md b/README.md index 643f05c6..1fa32906 100644 --- a/README.md +++ b/README.md @@ -139,4 +139,4 @@ TorchSim is released under an [MIT license](LICENSE). ## Citation -If you use TorchSim in your research, please cite the [arXiv preprint](https://arxiv.org/abs/2508.06628). +If you use TorchSim in your research, please cite our [publication](https://iopscience.iop.org/article/10.1088/3050-287X/ae1799). diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index be7b3914..3c1c5648 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -155,14 +155,12 @@ def __init__( # Load model if provided as path if isinstance(model, str | Path): - # Implement model loading from file - raise NotImplementedError("Loading model from file not implemented yet") - if isinstance(model, torch.nn.Module): - self.model = model + self.model = torch.load(model, map_location=self._device) + elif isinstance(model, torch.nn.Module): + self.model = model.to(self._device) else: raise TypeError("Model must be a path or torch.nn.Module") - self.model = model.to(self._device) self.model = self.model.eval() if self.dtype is not None: @@ -239,7 +237,9 @@ def setup_from_system_idx( dtype=self.dtype, ) - def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: # noqa: C901 + def forward( # noqa: C901 + self, state: ts.SimState | StateDict + ) -> dict[str, torch.Tensor]: """Compute energies, forces, and stresses for the given atomic systems. Processes the provided state information and computes energies, forces, and