Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/test_trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def test_data_type_conversions(test_file: Path) -> None:

rng = np.random.default_rng(seed=0)
# Test data with different types
test_data = {
test_data: dict[str, np.ndarray | torch.Tensor] = {
# NumPy arrays
"np_float64": rng.random((10, 3)).astype(np.float64),
"np_float32": rng.random((10, 3)).astype(np.float32),
Expand Down Expand Up @@ -376,7 +376,7 @@ def test_scalar_dtype_handling(test_file: Path) -> None:
mode="w",
)

scalar_data = {
scalar_data: dict[str, np.ndarray | np.generic | torch.Tensor] = {
"float64_scalar": np.float64(1.0),
"float32_scalar": np.float32(1.0),
"int64_scalar": np.int64(1),
Expand Down
10 changes: 9 additions & 1 deletion torch_sim/autobatching.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from collections.abc import Callable, Iterator
from itertools import chain
from typing import Any, get_args
from typing import Any, get_args, overload

import torch

Expand Down Expand Up @@ -570,6 +570,14 @@ def load_states(

return self.max_memory_scaler

@overload
def next_batch(self, *, return_indices: bool = False) -> SimState | None: ...

@overload
def next_batch(
self, *, return_indices: bool = True
) -> tuple[SimState, list[int]] | None: ...

def next_batch(
self, *, return_indices: bool = False
) -> SimState | tuple[SimState, list[int]] | None:
Expand Down
11 changes: 8 additions & 3 deletions torch_sim/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def _configure_batches_iterator(
return batches


def integrate(
def integrate( # noqa: C901
system: StateLike,
model: ModelInterface,
*,
Expand Down Expand Up @@ -146,14 +146,19 @@ def integrate(
"""
unit_system = UnitSystem.metal
# create a list of temperatures
temps = temperature if hasattr(temperature, "__iter__") else [temperature] * n_steps
if hasattr(temperature, "__iter__"):
temps = torch.tensor(temperature, dtype=model.dtype, device=model.device)
else:
temps = torch.tensor(
[temperature] * n_steps, dtype=model.dtype, device=model.device
)
if len(temps) != n_steps:
raise ValueError(f"{len(temps)=:,}. It must equal n_steps = {n_steps=:,}")

# initialize the state
state: SimState = initialize_state(system, model.device, model.dtype)
dtype, device = state.dtype, state.device
kTs = torch.tensor(temps, dtype=dtype, device=device) * unit_system.temperature
kTs = temps * unit_system.temperature
init_fn, update_fn = integrator(
model=model,
kT=kTs[0],
Expand Down
45 changes: 28 additions & 17 deletions torch_sim/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import copy
import inspect
import pathlib
from collections.abc import Callable
from collections.abc import Callable, Mapping, Sequence
from functools import partial
from typing import TYPE_CHECKING, Any, Literal, Self

Expand All @@ -44,7 +44,7 @@

if TYPE_CHECKING:
from ase import Atoms
from ase.io.trajectory import Trajectory
from ase.io.trajectory import TrajectoryReader
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why change this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see my reply to Rhys below for the rationale


_DATA_TYPE_MAP = {
np.dtype("float32"): tables.Float32Atom(),
Expand Down Expand Up @@ -91,9 +91,18 @@ class TrajectoryReporter:
>>> reporter.close()
"""

state_frequency: int
trajectory_kwargs: dict[str, Any]
prop_calculators: dict[int, dict[str, Callable]]
state_kwargs: dict[str, Any]
metadata: dict[str, str] | None
shape_warned: bool
trajectories: list["TorchSimTrajectory"]
filenames: list[str | pathlib.Path] | None

def __init__(
self,
filenames: str | pathlib.Path | list[str | pathlib.Path] | None,
filenames: str | pathlib.Path | Sequence[str | pathlib.Path] | None,
state_frequency: int = 100,
*,
prop_calculators: dict[int, dict[str, Callable]] | None = None,
Expand Down Expand Up @@ -137,14 +146,13 @@ def __init__(
self.trajectories = []
if filenames is None:
self.filenames = None
self.trajectories = []
else:
self.load_new_trajectories(filenames)

self._add_model_arg_to_prop_calculators()

def load_new_trajectories(
self, filenames: str | pathlib.Path | list[str | pathlib.Path]
self, filenames: str | pathlib.Path | Sequence[str | pathlib.Path]
) -> None:
"""Load new trajectories into the reporter.

Expand All @@ -159,7 +167,9 @@ def load_new_trajectories(
"""
self.finish()

filenames = [filenames] if not isinstance(filenames, list) else filenames
filenames = (
[filenames] if isinstance(filenames, (str, pathlib.Path)) else list(filenames)
)
self.filenames = [pathlib.Path(filename) for filename in filenames]
if len(set(self.filenames)) != len(self.filenames):
raise ValueError("All filenames must be unique.")
Expand Down Expand Up @@ -455,7 +465,7 @@ def _initialize_type_map(

def write_arrays(
self,
data: dict[str, np.ndarray | torch.Tensor],
data: "Mapping[str, np.ndarray | np.generic | torch.Tensor]",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is dict not sufficient to type this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we run into https://mypy.readthedocs.io/en/stable/common_issues.html#variance with mypy

torch_sim/trajectory.py:784: note: "dict" is invariant -- see https://mypy.readthedocs.io/en/stable/common_issues.html#variance
torch_sim/trajectory.py:784: note: Consider using "Mapping" instead, which is covariant in the value type

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we also need the np.generic because of this error:

tests/test_trajectory.py:390: error: Argument 1 to "write_arrays" of "TorchSimTrajectory" has incompatible type "dict[str, ndarray[Any, Any] | generic | Tensor]"; expected "Mapping[str, ndarray[Any, Any] | Tensor]" [arg-type]

which allows us to write an np.bool_ scalar inside test_trajectory.py

scalar_data: dict[str, np.ndarray | np.generic | torch.Tensor] = {
    "float64_scalar": np.float64(1.0),
    "float32_scalar": np.float32(1.0),
    "int64_scalar": np.int64(1),
    "int32_scalar": np.int32(1),
    "bool_scalar": np.bool_(True),  # noqa: FBT003
    "torch_float_scalar": torch.tensor(1.0, dtype=torch.float64),
    "torch_int_scalar": torch.tensor(1, dtype=torch.int64),
    "torch_bool_scalar": torch.tensor(data=True),
}

traj.write_arrays(scalar_data, steps=0)

steps: int | list[int],
) -> None:
"""Write arrays to the trajectory file.
Expand All @@ -469,8 +479,8 @@ def write_arrays(
file and that the steps are monotonically increasing.

Args:
data (dict[str, np.ndarray | torch.Tensor]): Map of array names to numpy
arrays or torch tensors with shapes [n_frames, ...]
data (Mapping[str, np.ndarray | np.generic | torch.Tensor]): Map of array
names to numpy arrays or torch tensors with shapes [n_frames, ...]
steps (int | list[int]): Step number(s) for the frame(s) being written.
If steps is an integer, arrays will be treated as single frame data.

Expand All @@ -485,9 +495,12 @@ def write_arrays(
pad_first_dim = False

for name, array in data.items():
# TODO: coerce dtypes to numpy
# Normalize to numpy arrays
if isinstance(array, torch.Tensor):
array = array.cpu().detach().numpy()
elif not isinstance(array, np.ndarray):
# Convert numpy scalar (np.generic) or Python scalar to ndarray
array = np.array(array)

if pad_first_dim:
# pad 1st dim of array with 1
Expand Down Expand Up @@ -715,10 +728,8 @@ def write_state( # noqa: C901
steps = [steps]

if isinstance(system_index, int):
system_index = [system_index]
sub_states = [state[system_index] for state in state]
elif system_index is None and torch.unique(state[0].system_idx) == 0:
system_index = 0
sub_states = state
else:
raise ValueError(
Expand Down Expand Up @@ -772,7 +783,7 @@ def write_state( # noqa: C901
# Write all arrays to file
self.write_arrays(data, steps)

def _get_state_arrays(self, frame: int) -> dict[str, torch.Tensor]:
def _get_state_arrays(self, frame: int) -> dict[str, np.ndarray]:
"""Get all available state tensors for a given frame.

Retrieves all state-related arrays (positions, cell, masses, etc.) for a
Expand All @@ -782,7 +793,7 @@ def _get_state_arrays(self, frame: int) -> dict[str, torch.Tensor]:
frame (int): Frame index to retrieve (-1 for last frame)

Returns:
dict[str, torch.Tensor]: Map of tensor names to their values
dict[str, np.ndarray]: Map of array names to their values

Raises:
ValueError: If required arrays are missing from trajectory or frame is
Expand Down Expand Up @@ -916,7 +927,7 @@ def get_state(
positions=torch.tensor(arrays["positions"], device=device, dtype=dtype),
masses=torch.tensor(arrays.get("masses", None), device=device, dtype=dtype),
cell=torch.tensor(arrays["cell"], device=device, dtype=dtype),
pbc=arrays.get("pbc", True),
pbc=bool(arrays.get("pbc", True)),
atomic_numbers=torch.tensor(
arrays["atomic_numbers"], device=device, dtype=torch.int
),
Expand Down Expand Up @@ -970,7 +981,7 @@ def __len__(self) -> int:
"""
return self._file.root.data.positions.shape[0]

def write_ase_trajectory(self, filename: str | pathlib.Path) -> "Trajectory":
def write_ase_trajectory(self, filename: str | pathlib.Path) -> "TrajectoryReader":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this feels wrong?

Copy link
Collaborator Author

@curtischong curtischong Sep 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you scroll to line 1015, we reopen the trajectory in read mode — `return Trajectory(filename)  # Reopen in read mode` — and `Trajectory(filename)` with no mode argument returns a TrajectoryReader object

"""Convert trajectory to ASE Trajectory format.

Writes the entire trajectory to a new file in ASE format for compatibility
Expand All @@ -980,7 +991,7 @@ def write_ase_trajectory(self, filename: str | pathlib.Path) -> "Trajectory":
filename (str | pathlib.Path): Path to the output ASE trajectory file

Returns:
ase.io.trajectory.Trajectory: ASE trajectory object
ase.io.trajectory.TrajectoryReader: ASE trajectory object

Raises:
ImportError: If ASE is not installed
Expand Down