-
Notifications
You must be signed in to change notification settings - Fork 55
Type trajectory #244
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Type trajectory #244
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,7 +30,7 @@ | |
| import copy | ||
| import inspect | ||
| import pathlib | ||
| from collections.abc import Callable | ||
| from collections.abc import Callable, Mapping, Sequence | ||
| from functools import partial | ||
| from typing import TYPE_CHECKING, Any, Literal, Self | ||
|
|
||
|
|
@@ -44,7 +44,7 @@ | |
|
|
||
| if TYPE_CHECKING: | ||
| from ase import Atoms | ||
| from ase.io.trajectory import Trajectory | ||
| from ase.io.trajectory import TrajectoryReader | ||
|
|
||
| _DATA_TYPE_MAP = { | ||
| np.dtype("float32"): tables.Float32Atom(), | ||
|
|
@@ -91,9 +91,18 @@ class TrajectoryReporter: | |
| >>> reporter.close() | ||
| """ | ||
|
|
||
| state_frequency: int | ||
| trajectory_kwargs: dict[str, Any] | ||
| prop_calculators: dict[int, dict[str, Callable]] | ||
| state_kwargs: dict[str, Any] | ||
| metadata: dict[str, str] | None | ||
| shape_warned: bool | ||
| trajectories: list["TorchSimTrajectory"] | ||
| filenames: list[str | pathlib.Path] | None | ||
|
|
||
| def __init__( | ||
| self, | ||
| filenames: str | pathlib.Path | list[str | pathlib.Path] | None, | ||
| filenames: str | pathlib.Path | Sequence[str | pathlib.Path] | None, | ||
| state_frequency: int = 100, | ||
| *, | ||
| prop_calculators: dict[int, dict[str, Callable]] | None = None, | ||
|
|
@@ -137,14 +146,13 @@ def __init__( | |
| self.trajectories = [] | ||
| if filenames is None: | ||
| self.filenames = None | ||
| self.trajectories = [] | ||
| else: | ||
| self.load_new_trajectories(filenames) | ||
|
|
||
| self._add_model_arg_to_prop_calculators() | ||
|
|
||
| def load_new_trajectories( | ||
| self, filenames: str | pathlib.Path | list[str | pathlib.Path] | ||
| self, filenames: str | pathlib.Path | Sequence[str | pathlib.Path] | ||
| ) -> None: | ||
| """Load new trajectories into the reporter. | ||
|
|
||
|
|
@@ -159,7 +167,9 @@ def load_new_trajectories( | |
| """ | ||
| self.finish() | ||
|
|
||
| filenames = [filenames] if not isinstance(filenames, list) else filenames | ||
| filenames = ( | ||
| [filenames] if isinstance(filenames, (str, pathlib.Path)) else list(filenames) | ||
| ) | ||
| self.filenames = [pathlib.Path(filename) for filename in filenames] | ||
| if len(set(self.filenames)) != len(self.filenames): | ||
| raise ValueError("All filenames must be unique.") | ||
|
|
@@ -455,7 +465,7 @@ def _initialize_type_map( | |
|
|
||
| def write_arrays( | ||
| self, | ||
| data: dict[str, np.ndarray | torch.Tensor], | ||
| data: "Mapping[str, np.ndarray | np.generic | torch.Tensor]", | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why is dict not sufficient to type this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we run into https://mypy.readthedocs.io/en/stable/common_issues.html#variance with mypy torch_sim/trajectory.py:784: note: "dict" is invariant -- see https://mypy.readthedocs.io/en/stable/common_issues.html#variance There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we also need the np.generic because of this error: tests/test_trajectory.py:390: error: Argument 1 to "write_arrays" of "TorchSimTrajectory" has incompatible type "dict[str, ndarray[Any, Any] | generic | Tensor]"; expected "Mapping[str, ndarray[Any, Any] | Tensor]" [arg-type] which allows us to write a np.bool inside test_trajectory.py |
||
| steps: int | list[int], | ||
| ) -> None: | ||
| """Write arrays to the trajectory file. | ||
|
|
@@ -469,8 +479,8 @@ def write_arrays( | |
| file and that the steps are monotonically increasing. | ||
|
|
||
| Args: | ||
| data (dict[str, np.ndarray | torch.Tensor]): Map of array names to numpy | ||
| arrays or torch tensors with shapes [n_frames, ...] | ||
| data (Mapping[str, np.ndarray | np.generic | torch.Tensor]): Map of array | ||
| names to numpy arrays or torch tensors with shapes [n_frames, ...] | ||
| steps (int | list[int]): Step number(s) for the frame(s) being written. | ||
| If steps is an integer, arrays will be treated as single frame data. | ||
|
|
||
|
|
@@ -485,9 +495,12 @@ def write_arrays( | |
| pad_first_dim = False | ||
|
|
||
| for name, array in data.items(): | ||
| # TODO: coerce dtypes to numpy | ||
| # Normalize to numpy arrays | ||
| if isinstance(array, torch.Tensor): | ||
| array = array.cpu().detach().numpy() | ||
| elif not isinstance(array, np.ndarray): | ||
| # Convert numpy scalar (np.generic) or Python scalar to ndarray | ||
| array = np.array(array) | ||
|
|
||
| if pad_first_dim: | ||
| # pad 1st dim of array with 1 | ||
|
|
@@ -715,10 +728,8 @@ def write_state( # noqa: C901 | |
| steps = [steps] | ||
|
|
||
| if isinstance(system_index, int): | ||
| system_index = [system_index] | ||
| sub_states = [state[system_index] for state in state] | ||
| elif system_index is None and torch.unique(state[0].system_idx) == 0: | ||
| system_index = 0 | ||
| sub_states = state | ||
| else: | ||
| raise ValueError( | ||
|
|
@@ -772,7 +783,7 @@ def write_state( # noqa: C901 | |
| # Write all arrays to file | ||
| self.write_arrays(data, steps) | ||
|
|
||
| def _get_state_arrays(self, frame: int) -> dict[str, torch.Tensor]: | ||
| def _get_state_arrays(self, frame: int) -> dict[str, np.ndarray]: | ||
| """Get all available state tensors for a given frame. | ||
|
|
||
| Retrieves all state-related arrays (positions, cell, masses, etc.) for a | ||
|
|
@@ -782,7 +793,7 @@ def _get_state_arrays(self, frame: int) -> dict[str, torch.Tensor]: | |
| frame (int): Frame index to retrieve (-1 for last frame) | ||
|
|
||
| Returns: | ||
| dict[str, torch.Tensor]: Map of tensor names to their values | ||
| dict[str, np.ndarray]: Map of array names to their values | ||
|
|
||
| Raises: | ||
| ValueError: If required arrays are missing from trajectory or frame is | ||
|
|
@@ -916,7 +927,7 @@ def get_state( | |
| positions=torch.tensor(arrays["positions"], device=device, dtype=dtype), | ||
| masses=torch.tensor(arrays.get("masses", None), device=device, dtype=dtype), | ||
| cell=torch.tensor(arrays["cell"], device=device, dtype=dtype), | ||
| pbc=arrays.get("pbc", True), | ||
| pbc=bool(arrays.get("pbc", True)), | ||
| atomic_numbers=torch.tensor( | ||
| arrays["atomic_numbers"], device=device, dtype=torch.int | ||
| ), | ||
|
|
@@ -970,7 +981,7 @@ def __len__(self) -> int: | |
| """ | ||
| return self._file.root.data.positions.shape[0] | ||
|
|
||
| def write_ase_trajectory(self, filename: str | pathlib.Path) -> "Trajectory": | ||
| def write_ase_trajectory(self, filename: str | pathlib.Path) -> "TrajectoryReader": | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this feels wrong? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if you scroll to line 1015, we are opening |
||
| """Convert trajectory to ASE Trajectory format. | ||
|
|
||
| Writes the entire trajectory to a new file in ASE format for compatibility | ||
|
|
@@ -980,7 +991,7 @@ def write_ase_trajectory(self, filename: str | pathlib.Path) -> "Trajectory": | |
| filename (str | pathlib.Path): Path to the output ASE trajectory file | ||
|
|
||
| Returns: | ||
| ase.io.trajectory.Trajectory: ASE trajectory object | ||
| ase.io.trajectory.TrajectoryReader: ASE trajectory object | ||
|
|
||
| Raises: | ||
| ImportError: If ASE is not installed | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why change this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
view my comment to Rhys below