Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,4 +139,4 @@ TorchSim is released under an [MIT license](LICENSE).

## Citation

If you use TorchSim in your research, please cite the [arXiv preprint](https://arxiv.org/abs/2508.06628).
If you use TorchSim in your research, please cite our [publication](https://iopscience.iop.org/article/10.1088/3050-287X/ae1799).
12 changes: 6 additions & 6 deletions torch_sim/models/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,12 @@ def __init__(

# Load model if provided as path
if isinstance(model, str | Path):
# Implement model loading from file
raise NotImplementedError("Loading model from file not implemented yet")
if isinstance(model, torch.nn.Module):
self.model = model
self.model = torch.load(model, map_location=self._device)
elif isinstance(model, torch.nn.Module):
self.model = model.to(self._device)
else:
raise TypeError("Model must be a path or torch.nn.Module")

self.model = model.to(self._device)
self.model = self.model.eval()

if self.dtype is not None:
Expand Down Expand Up @@ -239,7 +237,9 @@ def setup_from_system_idx(
dtype=self.dtype,
)

def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: # noqa: C901
def forward( # noqa: C901
self, state: ts.SimState | StateDict
) -> dict[str, torch.Tensor]:
"""Compute energies, forces, and stresses for the given atomic systems.

Processes the provided state information and computes energies, forces, and
Expand Down
Loading