Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ test = [
"pytest-cov>=6",
"pytest>=8",
]
io = ["ase>=3.24", "phonopy>=2.37.0", "pymatgen>=2024.11.3"]
mace = ["mace-torch>=0.3.12"]
mattersim = ["mattersim>=0.1.2"]
metatensor = ["metatensor-torch>=0.7,<0.8", "metatrain[pet]>=2025.4"]
Expand All @@ -65,6 +66,7 @@ docs = [
"sphinx==8.1.3",
"sphinx_design==0.6.1",
]

[project.urls]
Repo = "https://github.com/radical-ai/torch-sim"

Expand Down
65 changes: 65 additions & 0 deletions tests/test_io.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import itertools
import sys
from typing import Any

import pytest
Expand Down Expand Up @@ -256,3 +257,67 @@ def test_state_round_trip(
# since both use their own isotope masses based on species,
# not the ones in the state
assert torch.allclose(sim_state.masses, round_trip_state.masses)


def test_state_to_atoms_importerror(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setitem(sys.modules, "ase", None)
monkeypatch.setitem(sys.modules, "ase.data", None)

with pytest.raises(
ImportError, match="ASE is required for state_to_atoms conversion"
):
ts.io.state_to_atoms(None)


def test_state_to_phonopy_importerror(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setitem(sys.modules, "phonopy", None)
monkeypatch.setitem(sys.modules, "phonopy.structure", None)
monkeypatch.setitem(sys.modules, "phonopy.structure.atoms", None)

with pytest.raises(
ImportError, match="Phonopy is required for state_to_phonopy conversion"
):
ts.io.state_to_phonopy(None)


def test_state_to_structures_importerror(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setitem(sys.modules, "pymatgen", None)
monkeypatch.setitem(sys.modules, "pymatgen.core", None)
monkeypatch.setitem(sys.modules, "pymatgen.core.structure", None)

with pytest.raises(
ImportError, match="Pymatgen is required for state_to_structures conversion"
):
ts.io.state_to_structures(None)


def test_atoms_to_state_importerror(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setitem(sys.modules, "ase", None)
monkeypatch.setitem(sys.modules, "ase.data", None)

with pytest.raises(
ImportError, match="ASE is required for atoms_to_state conversion"
):
ts.io.atoms_to_state(None, None, None)


def test_phonopy_to_state_importerror(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setitem(sys.modules, "phonopy", None)
monkeypatch.setitem(sys.modules, "phonopy.structure", None)
monkeypatch.setitem(sys.modules, "phonopy.structure.atoms", None)

with pytest.raises(
ImportError, match="Phonopy is required for phonopy_to_state conversion"
):
ts.io.phonopy_to_state(None, None, None)


def test_structures_to_state_importerror(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setitem(sys.modules, "pymatgen", None)
monkeypatch.setitem(sys.modules, "pymatgen.core", None)
monkeypatch.setitem(sys.modules, "pymatgen.core.structure", None)

with pytest.raises(
ImportError, match="Pymatgen is required for structures_to_state conversion"
):
ts.io.structures_to_state(None, None, None)
45 changes: 45 additions & 0 deletions tests/test_trajectory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import sys
from collections.abc import Callable, Generator
from pathlib import Path

Expand Down Expand Up @@ -796,3 +797,47 @@ def energy_calculator(state: SimState, model: torch.nn.Module) -> torch.Tensor:
np.testing.assert_allclose(batch_props["energy"], file_energy)

trajectory.close()


def test_get_atoms_importerror(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
# Simulate missing ase
monkeypatch.setitem(sys.modules, "ase", None)

traj = TorchSimTrajectory(tmp_path / "dummy.h5", mode="w")
# Write minimal data so get_atoms can be called
state = SimState(
positions=torch.zeros(1, 3),
masses=torch.ones(1),
cell=torch.eye(3).unsqueeze(0),
pbc=True,
atomic_numbers=torch.ones(1, dtype=torch.int),
)
traj.write_state(state, steps=0)

with pytest.raises(ImportError, match="ASE is required to convert to ASE Atoms"):
traj.get_atoms(0)
traj.close()


def test_write_ase_trajectory_importerror(
monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
# Simulate missing ase.io.trajectory
monkeypatch.setitem(sys.modules, "ase", None)
monkeypatch.setitem(sys.modules, "ase.io", None)
monkeypatch.setitem(sys.modules, "ase.io.trajectory", None)

traj = TorchSimTrajectory(tmp_path / "dummy.h5", mode="w")
# Write minimal data so write_ase_trajectory can be called
state = SimState(
positions=torch.zeros(1, 3),
masses=torch.ones(1),
cell=torch.eye(3).unsqueeze(0),
pbc=True,
atomic_numbers=torch.ones(1, dtype=torch.int),
)
traj.write_state(state, steps=0)

with pytest.raises(ImportError, match="ASE is required to convert to ASE trajectory"):
traj.write_ase_trajectory(tmp_path / "dummy.traj")
traj.close()
10 changes: 4 additions & 6 deletions torch_sim/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def state_to_structures(state: "SimState") -> list["Structure"]:
from pymatgen.core.periodic_table import Element
except ImportError:
raise ImportError(
"Pymatgen is required for state_to_structure conversion"
"Pymatgen is required for state_to_structures conversion"
) from None

# Convert tensors to numpy arrays on CPU
Expand Down Expand Up @@ -148,9 +148,7 @@ def state_to_phonopy(state: "SimState") -> list["PhonopyAtoms"]:
from ase.data import chemical_symbols
from phonopy.structure.atoms import PhonopyAtoms
except ImportError:
raise ImportError(
"Phonopy is required for state_to_phonopy_atoms conversion"
) from None
raise ImportError("Phonopy is required for state_to_phonopy conversion") from None

# Convert tensors to numpy arrays on CPU
positions = state.positions.detach().cpu().numpy()
Expand Down Expand Up @@ -209,7 +207,7 @@ def atoms_to_state(
try:
from ase import Atoms
except ImportError:
raise ImportError("ASE is required for state_to_atoms conversion") from None
raise ImportError("ASE is required for atoms_to_state conversion") from None

atoms_list = [atoms] if isinstance(atoms, Atoms) else atoms

Expand Down Expand Up @@ -280,7 +278,7 @@ def structures_to_state(
from pymatgen.core import Structure
except ImportError:
raise ImportError(
"Pymatgen is required for state_to_structure conversion"
"Pymatgen is required for structures_to_state conversion"
) from None

struct_list = [structure] if isinstance(structure, Structure) else structure
Expand Down
17 changes: 13 additions & 4 deletions torch_sim/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import pathlib
from collections.abc import Callable
from functools import partial
from typing import Any, Literal, Self
from typing import TYPE_CHECKING, Any, Literal, Self

import numpy as np
import tables
Expand All @@ -41,6 +41,10 @@
from torch_sim.state import SimState


if TYPE_CHECKING:
from ase import Atoms
from ase.io.trajectory import Trajectory

Check warning on line 46 in torch_sim/trajectory.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/trajectory.py#L45-L46

Added lines #L45 - L46 were not covered by tests

_DATA_TYPE_MAP = {
np.dtype("float32"): tables.Float32Atom(),
np.dtype("float64"): tables.Float64Atom(),
Expand Down Expand Up @@ -848,7 +852,7 @@
validate_proximity=False,
)

def get_atoms(self, frame: int = -1) -> Any:
def get_atoms(self, frame: int = -1) -> "Atoms":
"""Get an ASE Atoms object for a given frame.

Converts the state at the specified frame to an ASE Atoms object
Expand All @@ -863,7 +867,12 @@
Raises:
ImportError: If ASE is not installed
"""
from ase import Atoms
try:
from ase import Atoms
except ImportError:
raise ImportError(
"ASE is required to convert to ASE Atoms. Run `pip install ase`"
) from None

arrays = self._get_state_arrays(frame)

Expand Down Expand Up @@ -959,7 +968,7 @@
"""
return self._file.root.data.positions.shape[0]

def write_ase_trajectory(self, filename: str | pathlib.Path) -> Any:
def write_ase_trajectory(self, filename: str | pathlib.Path) -> "Trajectory":
"""Convert trajectory to ASE Trajectory format.

Writes the entire trajectory to a new file in ASE format for compatibility
Expand Down
Loading