diff --git a/tests/test_trajectory.py b/tests/test_trajectory.py index 67d9de1f..47206e10 100644 --- a/tests/test_trajectory.py +++ b/tests/test_trajectory.py @@ -260,7 +260,7 @@ def test_data_type_conversions(test_file: Path) -> None: rng = np.random.default_rng(seed=0) # Test data with different types - test_data = { + test_data: dict[str, np.ndarray | torch.Tensor] = { # NumPy arrays "np_float64": rng.random((10, 3)).astype(np.float64), "np_float32": rng.random((10, 3)).astype(np.float32), @@ -376,7 +376,7 @@ def test_scalar_dtype_handling(test_file: Path) -> None: mode="w", ) - scalar_data = { + scalar_data: dict[str, np.ndarray | np.generic | torch.Tensor] = { "float64_scalar": np.float64(1.0), "float32_scalar": np.float32(1.0), "int64_scalar": np.int64(1), diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index 83c341bf..4f556ce0 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -22,7 +22,7 @@ from collections.abc import Callable, Iterator from itertools import chain -from typing import Any, get_args +from typing import Any, get_args, overload import torch @@ -570,6 +570,14 @@ def load_states( return self.max_memory_scaler + @overload + def next_batch(self, *, return_indices: bool = False) -> SimState | None: ... + + @overload + def next_batch( + self, *, return_indices: bool = True + ) -> tuple[SimState, list[int]] | None: ... + def next_batch( self, *, return_indices: bool = False ) -> SimState | tuple[SimState, list[int]] | None: diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 14164375..5a11f86c 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -108,7 +108,7 @@ def _configure_batches_iterator( return batches -def integrate( +def integrate( # noqa: C901 system: StateLike, model: ModelInterface, *, @@ -146,14 +146,19 @@ def integrate( """ unit_system = UnitSystem.metal # create a list of temperatures - temps = temperature if hasattr(temperature, "__iter__") else [temperature] * n_steps + if hasattr(temperature, "__iter__"): + temps = torch.tensor(temperature, dtype=model.dtype, device=model.device) + else: + temps = torch.tensor( + [temperature] * n_steps, dtype=model.dtype, device=model.device + ) if len(temps) != n_steps: raise ValueError(f"{len(temps)=:,}. It must equal n_steps = {n_steps=:,}") # initialize the state state: SimState = initialize_state(system, model.device, model.dtype) dtype, device = state.dtype, state.device - kTs = torch.tensor(temps, dtype=dtype, device=device) * unit_system.temperature + kTs = temps * unit_system.temperature init_fn, update_fn = integrator( model=model, kT=kTs[0], diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index fb170754..2d299772 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -30,7 +30,7 @@ import copy import inspect import pathlib -from collections.abc import Callable +from collections.abc import Callable, Mapping, Sequence from functools import partial from typing import TYPE_CHECKING, Any, Literal, Self @@ -44,7 +44,7 @@ if TYPE_CHECKING: from ase import Atoms - from ase.io.trajectory import Trajectory + from ase.io.trajectory import TrajectoryReader _DATA_TYPE_MAP = { np.dtype("float32"): tables.Float32Atom(), @@ -91,9 +91,18 @@ class TrajectoryReporter: >>> reporter.close() """ + state_frequency: int + trajectory_kwargs: dict[str, Any] + prop_calculators: dict[int, dict[str, Callable]] + state_kwargs: dict[str, Any] + metadata: dict[str, str] | None + shape_warned: bool + trajectories: list["TorchSimTrajectory"] + filenames: list[str | pathlib.Path] | None + def __init__( self, - filenames: str | pathlib.Path | list[str | pathlib.Path] | None, + filenames: str | pathlib.Path | Sequence[str | pathlib.Path] | None, state_frequency: int = 100, *, prop_calculators: dict[int, dict[str, Callable]] | None = None, @@ -137,14 +146,13 @@ def __init__( self.trajectories = [] if filenames is None: self.filenames = None - self.trajectories = [] else: self.load_new_trajectories(filenames) self._add_model_arg_to_prop_calculators() def load_new_trajectories( - self, filenames: str | pathlib.Path | list[str | pathlib.Path] + self, filenames: str | pathlib.Path | Sequence[str | pathlib.Path] ) -> None: """Load new trajectories into the reporter. @@ -159,7 +167,9 @@ def load_new_trajectories( """ self.finish() - filenames = [filenames] if not isinstance(filenames, list) else filenames + filenames = ( + [filenames] if isinstance(filenames, (str, pathlib.Path)) else list(filenames) + ) self.filenames = [pathlib.Path(filename) for filename in filenames] if len(set(self.filenames)) != len(self.filenames): raise ValueError("All filenames must be unique.") @@ -455,7 +465,7 @@ def _initialize_type_map( def write_arrays( self, - data: dict[str, np.ndarray | torch.Tensor], + data: "Mapping[str, np.ndarray | np.generic | torch.Tensor]", steps: int | list[int], ) -> None: """Write arrays to the trajectory file. @@ -469,8 +479,8 @@ def write_arrays( file and that the steps are monotonically increasing. Args: - data (dict[str, np.ndarray | torch.Tensor]): Map of array names to numpy - arrays or torch tensors with shapes [n_frames, ...] + data (Mapping[str, np.ndarray | np.generic | torch.Tensor]): Map of array + names to numpy arrays or torch tensors with shapes [n_frames, ...] steps (int | list[int]): Step number(s) for the frame(s) being written. If steps is an integer, arrays will be treated as single frame data. @@ -485,9 +495,12 @@ def write_arrays( pad_first_dim = False for name, array in data.items(): - # TODO: coerce dtypes to numpy + # Normalize to numpy arrays if isinstance(array, torch.Tensor): array = array.cpu().detach().numpy() + elif not isinstance(array, np.ndarray): + # Convert numpy scalar (np.generic) or Python scalar to ndarray + array = np.array(array) if pad_first_dim: # pad 1st dim of array with 1 @@ -715,10 +728,8 @@ def write_state( # noqa: C901 steps = [steps] if isinstance(system_index, int): - system_index = [system_index] sub_states = [state[system_index] for state in state] elif system_index is None and torch.unique(state[0].system_idx) == 0: - system_index = 0 sub_states = state else: raise ValueError( @@ -772,7 +783,7 @@ def write_state( # noqa: C901 # Write all arrays to file self.write_arrays(data, steps) - def _get_state_arrays(self, frame: int) -> dict[str, torch.Tensor]: + def _get_state_arrays(self, frame: int) -> dict[str, np.ndarray]: """Get all available state tensors for a given frame. Retrieves all state-related arrays (positions, cell, masses, etc.) for a @@ -782,7 +793,7 @@ def _get_state_arrays(self, frame: int) -> dict[str, torch.Tensor]: frame (int): Frame index to retrieve (-1 for last frame) Returns: - dict[str, torch.Tensor]: Map of tensor names to their values + dict[str, np.ndarray]: Map of array names to their values Raises: ValueError: If required arrays are missing from trajectory or frame is @@ -916,7 +927,7 @@ def get_state( positions=torch.tensor(arrays["positions"], device=device, dtype=dtype), masses=torch.tensor(arrays.get("masses", None), device=device, dtype=dtype), cell=torch.tensor(arrays["cell"], device=device, dtype=dtype), - pbc=arrays.get("pbc", True), + pbc=bool(arrays.get("pbc", True)), atomic_numbers=torch.tensor( arrays["atomic_numbers"], device=device, dtype=torch.int ), @@ -970,7 +981,7 @@ def __len__(self) -> int: """ return self._file.root.data.positions.shape[0] - def write_ase_trajectory(self, filename: str | pathlib.Path) -> "Trajectory": + def write_ase_trajectory(self, filename: str | pathlib.Path) -> "TrajectoryReader": """Convert trajectory to ASE Trajectory format. Writes the entire trajectory to a new file in ASE format for compatibility @@ -980,7 +991,7 @@ def write_ase_trajectory(self, filename: str | pathlib.Path) -> "Trajectory": filename (str | pathlib.Path): Path to the output ASE trajectory file Returns: - ase.io.trajectory.Trajectory: ASE trajectory object + ase.io.trajectory.TrajectoryReader: ASE trajectory object Raises: ImportError: If ASE is not installed