From c27cd45dd392099cd151ed8dfc36697eed4609be Mon Sep 17 00:00:00 2001
From: Curtis Chong
Date: Tue, 9 Sep 2025 22:30:04 -0400
Subject: [PATCH] Move trajectory changes and apply them to the runners

Fix the trajectory and next_batch typing, revert unrelated changes,
clean up the trajectory file, and fix lints.
---
 tests/test_trajectory.py  |  4 ++--
 torch_sim/autobatching.py | 10 +++++++++-
 torch_sim/runners.py      | 11 ++++++++---
 torch_sim/trajectory.py   | 45 ++++++++++++++++++++++++++++-----------------
 4 files changed, 47 insertions(+), 23 deletions(-)

diff --git a/tests/test_trajectory.py b/tests/test_trajectory.py
index 67d9de1f..47206e10 100644
--- a/tests/test_trajectory.py
+++ b/tests/test_trajectory.py
@@ -260,7 +260,7 @@ def test_data_type_conversions(test_file: Path) -> None:
     rng = np.random.default_rng(seed=0)
 
     # Test data with different types
-    test_data = {
+    test_data: dict[str, np.ndarray | torch.Tensor] = {
         # NumPy arrays
         "np_float64": rng.random((10, 3)).astype(np.float64),
         "np_float32": rng.random((10, 3)).astype(np.float32),
@@ -376,7 +376,7 @@ def test_scalar_dtype_handling(test_file: Path) -> None:
         mode="w",
     )
 
-    scalar_data = {
+    scalar_data: dict[str, np.ndarray | np.generic | torch.Tensor] = {
         "float64_scalar": np.float64(1.0),
         "float32_scalar": np.float32(1.0),
         "int64_scalar": np.int64(1),
diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py
index a691de36..b611792f 100644
--- a/torch_sim/autobatching.py
+++ b/torch_sim/autobatching.py
@@ -22,7 +22,7 @@
 
 from collections.abc import Callable, Iterator
 from itertools import chain
-from typing import Any, get_args
+from typing import Any, get_args, overload
 
 import torch
 
@@ -570,6 +570,14 @@ def load_states(
 
         return self.max_memory_scaler
 
+    @overload
+    def next_batch(self, *, return_indices: bool = False) -> SimState | None: ...
+
+    @overload
+    def next_batch(
+        self, *, return_indices: bool = True
+    ) -> tuple[SimState, list[int]] | None: ...
+
     def next_batch(
         self, *, return_indices: bool = False
     ) -> SimState | tuple[SimState, list[int]] | None:
diff --git a/torch_sim/runners.py b/torch_sim/runners.py
index 14164375..5a11f86c 100644
--- a/torch_sim/runners.py
+++ b/torch_sim/runners.py
@@ -108,7 +108,7 @@ def _configure_batches_iterator(
     return batches
 
 
-def integrate(
+def integrate(  # noqa: C901
     system: StateLike,
     model: ModelInterface,
     *,
@@ -146,14 +146,19 @@
     """
     unit_system = UnitSystem.metal
     # create a list of temperatures
-    temps = temperature if hasattr(temperature, "__iter__") else [temperature] * n_steps
+    if hasattr(temperature, "__iter__"):
+        temps = torch.tensor(temperature, dtype=model.dtype, device=model.device)
+    else:
+        temps = torch.tensor(
+            [temperature] * n_steps, dtype=model.dtype, device=model.device
+        )
     if len(temps) != n_steps:
         raise ValueError(f"{len(temps)=:,}. 
It must equal n_steps = {n_steps=:,}") # initialize the state state: SimState = initialize_state(system, model.device, model.dtype) dtype, device = state.dtype, state.device - kTs = torch.tensor(temps, dtype=dtype, device=device) * unit_system.temperature + kTs = temps * unit_system.temperature init_fn, update_fn = integrator( model=model, kT=kTs[0], diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index fb170754..2d299772 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -30,7 +30,7 @@ import copy import inspect import pathlib -from collections.abc import Callable +from collections.abc import Callable, Mapping, Sequence from functools import partial from typing import TYPE_CHECKING, Any, Literal, Self @@ -44,7 +44,7 @@ if TYPE_CHECKING: from ase import Atoms - from ase.io.trajectory import Trajectory + from ase.io.trajectory import TrajectoryReader _DATA_TYPE_MAP = { np.dtype("float32"): tables.Float32Atom(), @@ -91,9 +91,18 @@ class TrajectoryReporter: >>> reporter.close() """ + state_frequency: int + trajectory_kwargs: dict[str, Any] + prop_calculators: dict[int, dict[str, Callable]] + state_kwargs: dict[str, Any] + metadata: dict[str, str] | None + shape_warned: bool + trajectories: list["TorchSimTrajectory"] + filenames: list[str | pathlib.Path] | None + def __init__( self, - filenames: str | pathlib.Path | list[str | pathlib.Path] | None, + filenames: str | pathlib.Path | Sequence[str | pathlib.Path] | None, state_frequency: int = 100, *, prop_calculators: dict[int, dict[str, Callable]] | None = None, @@ -137,14 +146,13 @@ def __init__( self.trajectories = [] if filenames is None: self.filenames = None - self.trajectories = [] else: self.load_new_trajectories(filenames) self._add_model_arg_to_prop_calculators() def load_new_trajectories( - self, filenames: str | pathlib.Path | list[str | pathlib.Path] + self, filenames: str | pathlib.Path | Sequence[str | pathlib.Path] ) -> None: """Load new trajectories into the reporter. @@ -159,7 +167,9 @@ def load_new_trajectories( """ self.finish() - filenames = [filenames] if not isinstance(filenames, list) else filenames + filenames = ( + [filenames] if isinstance(filenames, (str, pathlib.Path)) else list(filenames) + ) self.filenames = [pathlib.Path(filename) for filename in filenames] if len(set(self.filenames)) != len(self.filenames): raise ValueError("All filenames must be unique.") @@ -455,7 +465,7 @@ def _initialize_type_map( def write_arrays( self, - data: dict[str, np.ndarray | torch.Tensor], + data: "Mapping[str, np.ndarray | np.generic | torch.Tensor]", steps: int | list[int], ) -> None: """Write arrays to the trajectory file. @@ -469,8 +479,8 @@ def write_arrays( file and that the steps are monotonically increasing. Args: - data (dict[str, np.ndarray | torch.Tensor]): Map of array names to numpy - arrays or torch tensors with shapes [n_frames, ...] + data (Mapping[str, np.ndarray | np.generic | torch.Tensor]): Map of array + names to numpy arrays or torch tensors with shapes [n_frames, ...] steps (int | list[int]): Step number(s) for the frame(s) being written. If steps is an integer, arrays will be treated as single frame data. 
@@ -485,9 +495,12 @@ def write_arrays( pad_first_dim = False for name, array in data.items(): - # TODO: coerce dtypes to numpy + # Normalize to numpy arrays if isinstance(array, torch.Tensor): array = array.cpu().detach().numpy() + elif not isinstance(array, np.ndarray): + # Convert numpy scalar (np.generic) or Python scalar to ndarray + array = np.array(array) if pad_first_dim: # pad 1st dim of array with 1 @@ -715,10 +728,8 @@ def write_state( # noqa: C901 steps = [steps] if isinstance(system_index, int): - system_index = [system_index] sub_states = [state[system_index] for state in state] elif system_index is None and torch.unique(state[0].system_idx) == 0: - system_index = 0 sub_states = state else: raise ValueError( @@ -772,7 +783,7 @@ def write_state( # noqa: C901 # Write all arrays to file self.write_arrays(data, steps) - def _get_state_arrays(self, frame: int) -> dict[str, torch.Tensor]: + def _get_state_arrays(self, frame: int) -> dict[str, np.ndarray]: """Get all available state tensors for a given frame. Retrieves all state-related arrays (positions, cell, masses, etc.) for a @@ -782,7 +793,7 @@ def _get_state_arrays(self, frame: int) -> dict[str, torch.Tensor]: frame (int): Frame index to retrieve (-1 for last frame) Returns: - dict[str, torch.Tensor]: Map of tensor names to their values + dict[str, np.ndarray]: Map of array names to their values Raises: ValueError: If required arrays are missing from trajectory or frame is @@ -916,7 +927,7 @@ def get_state( positions=torch.tensor(arrays["positions"], device=device, dtype=dtype), masses=torch.tensor(arrays.get("masses", None), device=device, dtype=dtype), cell=torch.tensor(arrays["cell"], device=device, dtype=dtype), - pbc=arrays.get("pbc", True), + pbc=bool(arrays.get("pbc", True)), atomic_numbers=torch.tensor( arrays["atomic_numbers"], device=device, dtype=torch.int ), @@ -970,7 +981,7 @@ def __len__(self) -> int: """ return self._file.root.data.positions.shape[0] - def write_ase_trajectory(self, filename: str | pathlib.Path) -> "Trajectory": + def write_ase_trajectory(self, filename: str | pathlib.Path) -> "TrajectoryReader": """Convert trajectory to ASE Trajectory format. Writes the entire trajectory to a new file in ASE format for compatibility @@ -980,7 +991,7 @@ def write_ase_trajectory(self, filename: str | pathlib.Path) -> "Trajectory": filename (str | pathlib.Path): Path to the output ASE trajectory file Returns: - ase.io.trajectory.Trajectory: ASE trajectory object + ase.io.trajectory.TrajectoryReader: ASE trajectory object Raises: ImportError: If ASE is not installed
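
Note on the next_batch overloads added in autobatching.py: because both overload signatures accept return_indices: bool, type checkers such as mypy usually flag them as overlapping and cannot narrow the return type at a call site. The common fix is to key the overloads on typing.Literal. The sketch below shows only that pattern with toy stand-in types; it is not the torch-sim batcher or SimState.

from typing import Literal, overload


class _Batcher:
    """Toy stand-in for the real batcher; only the overload pattern matters."""

    @overload
    def next_batch(self, *, return_indices: Literal[False] = ...) -> list[int] | None: ...

    @overload
    def next_batch(
        self, *, return_indices: Literal[True]
    ) -> tuple[list[int], list[int]] | None: ...

    def next_batch(
        self, *, return_indices: bool = False
    ) -> list[int] | tuple[list[int], list[int]] | None:
        batch = [0, 1, 2]  # stand-in for the next batch of states
        return (batch, [0, 1, 2]) if return_indices else batch


just_batch = _Batcher().next_batch()                    # narrowed to list[int] | None
with_indices = _Batcher().next_batch(return_indices=True)  # narrowed to tuple | None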
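
Note on the temperature handling in the runners.py hunk: integrate now coerces the temperature argument to a torch tensor up front, broadcasting a scalar to n_steps entries and taking an iterable as per-step temperatures, so the later kT computation is a plain tensor multiply. A standalone sketch of that branch, using a hypothetical helper name and omitting the UnitSystem.metal conversion:

import torch


def per_step_temperatures(
    temperature, n_steps: int, *, dtype=torch.float64, device="cpu"
) -> torch.Tensor:
    # Mirror of the patched branch: iterables become per-step values,
    # scalars are repeated n_steps times.
    if hasattr(temperature, "__iter__"):
        temps = torch.tensor(temperature, dtype=dtype, device=device)
    else:
        temps = torch.tensor([temperature] * n_steps, dtype=dtype, device=device)
    if len(temps) != n_steps:
        raise ValueError(f"{len(temps)=:,}. It must equal {n_steps=:,}")
    return temps


print(per_step_temperatures(300.0, 4))           # four 300.0 entries
print(per_step_temperatures([300.0, 600.0], 2))  # per-step ramp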
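
Note on the Sequence[str | pathlib.Path] change to TrajectoryReporter: load_new_trajectories now normalizes on isinstance(filenames, (str, pathlib.Path)) rather than isinstance(filenames, list), so tuples and other path sequences are accepted. A standalone sketch of just that normalization, with a hypothetical helper name:

import pathlib
from collections.abc import Sequence


def normalize_filenames(
    filenames: str | pathlib.Path | Sequence[str | pathlib.Path],
) -> list[pathlib.Path]:
    # A lone str/Path becomes a one-element list; any other sequence is copied.
    as_list = (
        [filenames] if isinstance(filenames, (str, pathlib.Path)) else list(filenames)
    )
    paths = [pathlib.Path(filename) for filename in as_list]
    if len(set(paths)) != len(paths):
        raise ValueError("All filenames must be unique.")
    return paths


print(normalize_filenames("traj_a.h5"))
print(normalize_filenames(("traj_a.h5", pathlib.Path("traj_b.h5"))))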
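
Note on widening write_arrays to Mapping[str, np.ndarray | np.generic | torch.Tensor]: the new normalization branch detaches torch tensors to NumPy and wraps NumPy scalars (np.generic) or plain Python scalars with np.array(), which is what lets entries like np.float32(1.0) from the updated scalar test flow through the same path. A sketch of that normalization on toy data, outside the trajectory writer:

import numpy as np
import torch

data = {
    "positions": torch.rand(1, 10, 3),  # torch tensor -> .cpu().detach().numpy()
    "float32_scalar": np.float32(1.0),  # np.generic -> np.array(...)
    "step_count": 7,                    # plain Python scalar is also wrapped
}
normalized: dict[str, np.ndarray] = {}
for name, array in data.items():
    if isinstance(array, torch.Tensor):
        array = array.cpu().detach().numpy()
    elif not isinstance(array, np.ndarray):
        array = np.array(array)
    normalized[name] = array

print({name: (arr.shape, arr.dtype) for name, arr in normalized.items()})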