Type trajectory #244

Conversation
Walkthrough

Adds typing enhancements in tests and autobatching. Updates runners.integrate to construct temperatures directly as tensors on the model's device/dtype and simplifies kT computation. Expands torch_sim.trajectory with multi-file support, broader input types (Mapping/Sequence), new public attributes and a property, numpy-centric array handling, and adjusted type annotations.

Changes
Sequence Diagram(s)

sequenceDiagram
autonumber
participant Client
participant Reporter as TrajectoryReporter
participant Traj* as TorchSimTrajectory (per file)
Note over Client,Reporter: Multi-file trajectory initialization and writing
Client->>Reporter: __init__(filenames | None, state_frequency, ..., trajectory_kwargs)
alt filenames provided
Reporter->>Reporter: load_new_trajectories(filenames)
loop for each filename
Reporter->>Reporter: finish() (finalize prior, if any)
Reporter->>Traj*: construct TorchSimTrajectory(filename, metadata, trajectory_kwargs)
Reporter->>Reporter: store in trajectories[]
end
else no filenames
Note over Reporter: trajectories stays empty
end
Note right of Reporter: array_registry proxies first Traj*
loop Integration steps (state_frequency gating)
Client->>Reporter: write state/props
Reporter->>Traj*: write_arrays(data, steps)
Note over Traj*: Coerce Mapping values to numpy (torch.Tensor -> CPU numpy)
end
sequenceDiagram
autonumber
participant Caller
participant Runner as integrate(...)
participant Model
Caller->>Runner: integrate(model, n_steps, temperature, unit_system, ...)
Runner->>Model: query dtype/device
alt temperature is iterable
Runner->>Runner: temps = torch.tensor(temperature, dtype/device)
else temperature is scalar
Runner->>Runner: temps = torch.full([n_steps], scalar, dtype/device)
end
Runner->>Runner: kTs = temps * unit_system.temperature
Note over Runner: Remaining integration loop unchanged
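For readers unfamiliar with the new temperature handling, here is a minimal sketch of the logic the diagram above describes (a standalone illustration, not the actual runner code; the helper name _temperatures_to_tensor is hypothetical):

import torch

def _temperatures_to_tensor(temperature, n_steps, dtype, device):
    # An iterable of temperatures becomes a tensor directly on the model's device/dtype;
    # a scalar is broadcast to one value per step.
    if hasattr(temperature, "__iter__"):
        temps = torch.tensor(temperature, dtype=dtype, device=device)
    else:
        temps = torch.tensor([temperature] * n_steps, dtype=dtype, device=device)
    if len(temps) != n_steps:
        raise ValueError(f"{len(temps)=} must equal {n_steps=}")
    return temps

# kT per step then reduces to a single elementwise multiply:
# kTs = temps * unit_system.temperature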
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Pre-merge checks (2 passed, 1 inconclusive)

❌ Failed checks (1 inconclusive)
✅ Passed checks (2 passed)
e1a4c51 to
3767ce5
Compare
apply diff to runners wip fix trajectory fix more types fix types try to fix next batch type fix trajectories revert some changes revert more changes make trajectory file good fix lints
3767ce5 to
c27cd45
Compare
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
torch_sim/trajectory.py (2)
786-834: System selection is ignored; frames vs. systems mixed up
write_state builds arrays from the outer "state" list even when a specific system_index is selected, and mixes per-frame vs per-system shapes (e.g., cell). This will write wrong data and/or corrupt steps alignment when multiple systems are present.
Apply this refactor to consistently operate on the selected per-frame sub-states and fix shapes:
-        # we wrap
-        if isinstance(state, SimState):
-            state = [state]
-        if isinstance(steps, int):
-            steps = [steps]
-
-        if isinstance(system_index, int):
-            sub_states = [state[system_index] for state in state]
-        elif system_index is None and torch.unique(state[0].system_idx) == 0:
-            sub_states = state
-        else:
-            raise ValueError(
-                "System index must be specified if there are multiple systems"
-            )
+        # Normalize frames and steps
+        frames = [state] if isinstance(state, SimState) else list(state)
+        steps = [steps] if isinstance(steps, int) else list(steps)
+
+        # Select per-frame single-system states
+        if isinstance(system_index, int):
+            sub_states = [frame[system_index] for frame in frames]
+        elif frames[0].n_systems == 1:
+            sub_states = frames
+        else:
+            raise ValueError("system_index must be specified if there are multiple systems")
@@
-        if len(sub_states) != len(steps):
+        if len(sub_states) != len(steps):
             raise ValueError(f"{len(sub_states)=} must match the {len(steps)=}")
@@
-        data = {
-            "positions": torch.stack([s.positions for s in state]),
-        }
+        data = {"positions": torch.stack([s.positions for s in sub_states])}
@@
-        for array_name, should_save in optional_arrays.items():
+        for array_name, should_save in optional_arrays.items():
             if should_save:
-                if not hasattr(state[0], array_name):
+                if not hasattr(sub_states[0], array_name):
                     raise ValueError(
                         f"{array_name.capitalize()} can only be saved "
                         f"if included in the state being reported."
                     )
-                data[array_name] = torch.stack([getattr(s, array_name) for s in state])
+                data[array_name] = torch.stack([getattr(s, array_name) for s in sub_states])
@@
-        if variable_cell:
-            data["cell"] = torch.cat([s.cell for s in state])
-        elif "cell" not in self.array_registry:  # Save cell only for first frame
-            # we but cell in list because it doesn't need to be padded
-            self.write_arrays({"cell": state[0].cell}, [0])
+        if variable_cell:
+            data["cell"] = torch.cat([s.cell for s in sub_states])  # [n_frames, 3, 3]
+        elif "cell" not in self.array_registry:
+            self.write_arrays({"cell": sub_states[0].cell}, [0])
@@
-        if variable_masses:
-            data["masses"] = torch.stack([s.masses for s in state])
+        if variable_masses:
+            data["masses"] = torch.stack([s.masses for s in sub_states])
         elif "masses" not in self.array_registry:  # Save masses only for first frame
-            self.write_arrays({"masses": state[0].masses}, 0)
+            self.write_arrays({"masses": sub_states[0].masses}, 0)
@@
-        if variable_atomic_numbers:
-            data["atomic_numbers"] = torch.stack([s.atomic_numbers for s in state])
+        if variable_atomic_numbers:
+            data["atomic_numbers"] = torch.stack([s.atomic_numbers for s in sub_states])
         elif "atomic_numbers" not in self.array_registry:
             # Save atomic numbers only for first frame
-            self.write_arrays({"atomic_numbers": state[0].atomic_numbers}, 0)
+            self.write_arrays({"atomic_numbers": sub_states[0].atomic_numbers}, 0)
@@
-        if "pbc" not in self.array_registry:
-            self.write_arrays({"pbc": np.array(state[0].pbc)}, 0)
+        if "pbc" not in self.array_registry:
+            self.write_arrays({"pbc": np.array(sub_states[0].pbc)}, 0)

This also replaces the fragile torch.unique(...) check with frames[0].n_systems == 1.
923-934: Casting pbc via bool(...) breaks when pbc is a 3-vector
bool(np.array([...])) raises an "ambiguous truth value" error. Convert arrays with np.all(...) or pass scalars through unchanged.
-        arrays = self._get_state_arrays(frame)
+        arrays = self._get_state_arrays(frame)
+        # Normalize pbc to a plain bool (True only if all periodic)
+        pbc_val = arrays.get("pbc", True)
+        pbc_bool = bool(pbc_val) if np.isscalar(pbc_val) else bool(np.all(pbc_val))
@@
-        return SimState(
+        return SimState(
             positions=torch.tensor(arrays["positions"], device=device, dtype=dtype),
             masses=torch.tensor(arrays.get("masses", None), device=device, dtype=dtype),
             cell=torch.tensor(arrays["cell"], device=device, dtype=dtype),
-            pbc=bool(arrays.get("pbc", True)),
+            pbc=pbc_bool,
             atomic_numbers=torch.tensor(
                 arrays["atomic_numbers"], device=device, dtype=torch.int
             ),
         )
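For context, a quick self-contained demonstration of the failure mode and the suggested normalization (illustrative only, not code from the PR):

import numpy as np

pbc = np.array([True, True, True])  # per-axis pbc vector
try:
    bool(pbc)
except ValueError as exc:
    print(exc)  # "The truth value of an array with more than one element is ambiguous..."

# Collapse the vector explicitly: periodic only if all axes are periodic
pbc_bool = bool(pbc) if np.isscalar(pbc) else bool(np.all(pbc))
print(pbc_bool)  # True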
🧹 Nitpick comments (9)
torch_sim/runners.py (4)
111-111: Remove unused noqa for C901
Ruff flagged this as unused. Drop the directive or enable the rule; otherwise it's noise.
-def integrate(  # noqa: C901
+def integrate(
149-154: Harden iterable detection and avoid Python list materialization
Using hasattr(..., "__iter__") will treat strings/bytes as iterables and builds a Python list for the scalar branch. Prefer explicit checks and torch-native allocation.
Apply this diff, and import Sequence at the top if not already:
-    if hasattr(temperature, "__iter__"):
-        temps = torch.tensor(temperature, dtype=model.dtype, device=model.device)
-    else:
-        temps = torch.tensor(
-            [temperature] * n_steps, dtype=model.dtype, device=model.device
-        )
+    if isinstance(temperature, torch.Tensor):
+        temps = temperature.to(dtype=model.dtype, device=model.device).flatten()
+    elif isinstance(temperature, (list, tuple)):
+        temps = torch.as_tensor(temperature, dtype=model.dtype, device=model.device)
+    else:
+        temps = torch.full(
+            (n_steps,), float(temperature), dtype=model.dtype, device=model.device
+        )

Additionally, the docstring says "float | ArrayLike" while the type hints say "float | list | torch.Tensor". Align those for consistency.
156-156: Shorten/simplify the ValueError message (TRY003)
The long f-string triggers TRY003. Make it concise.
-    if len(temps) != n_steps:
-        raise ValueError(f"{len(temps)=:,}. It must equal n_steps = {n_steps=:,}")
+    if len(temps) != n_steps:
+        raise ValueError(f"len(temps)={len(temps):,} must equal n_steps={n_steps:,}")
161-166: Confirm unit handling for kT
Multiplying a torch tensor by unit_system.temperature assumes it's a plain scalar compatible with the model dtype/device. If it can be a quantity/wrapper, cast once to the right dtype on the device to avoid dtype/device surprises.
-    kTs = temps * unit_system.temperature
+    kTs = temps * float(unit_system.temperature)

torch_sim/autobatching.py (2)
573-580: Tighten overloads with Literal for better type narrowing
Using Literal improves callers' type inference.
-@overload
-def next_batch(self, *, return_indices: bool = False) -> SimState | None: ...
+@overload
+def next_batch(self, *, return_indices: Literal[False] = False) -> SimState | None: ...
-@overload
-def next_batch(
-    self, *, return_indices: bool = True
-) -> tuple[SimState, list[int]] | None: ...
+@overload
+def next_batch(
+    self, *, return_indices: Literal[True]
+) -> tuple[SimState, list[int]] | None: ...

(Remember to import Literal if not already.)
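As a quick illustration of why the Literal overloads help, here is a self-contained toy (Batcher and its str payload are stand-ins, not the real autobatcher):

from typing import Literal, overload

class Batcher:
    @overload
    def next_batch(self, *, return_indices: Literal[False] = False) -> str | None: ...
    @overload
    def next_batch(self, *, return_indices: Literal[True]) -> tuple[str, list[int]] | None: ...
    def next_batch(self, *, return_indices: bool = False):
        batch = "batch-0"  # placeholder payload
        return (batch, [0]) if return_indices else batch

batcher = Batcher()
only_batch = batcher.next_batch()                         # checker infers str | None
batch_and_idx = batcher.next_batch(return_indices=True)   # checker infers tuple[str, list[int]] | None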
581-625: Docstring example is inconsistent with this API
The example references parameters from a different method signature. Update to a simple pull-loop using return_indices when desired.
-    Example::
-
-        # Get batches one by one
-        all_converged_state, convergence = [], None
-        while (result := batcher.next_batch(state, convergence))[0] is not None:
-            state, converged_states = result
-            all_converged_states.extend(converged_states)
-
-            evolve_batch(state)
-            convergence = convergence_criterion(state)
-        else:
-            all_converged_states.extend(result[1])
+    Example::
+
+        # Get batches one by one
+        while (batch := batcher.next_batch()) is not None:
+            process_batch(batch)
+
+        # Or with indices
+        while (res := batcher.next_batch(return_indices=True)) is not None:
+            batch, idx = res
+            process_batch(batch)

torch_sim/trajectory.py (1)
155-173: Sequence handling for filenames is good; minor nit on uniqueness check
Converting to Path and deduplicating via set(...) is fine. If you want to guard against differing string representations (e.g., "./a.h5" vs "a.h5"), consider resolving paths before checking uniqueness.
-        self.filenames = [pathlib.Path(filename) for filename in filenames]
+        self.filenames = [pathlib.Path(filename).resolve() for filename in filenames]
263-263: Prefer Mapping and align union with write_arrays contract
Use a covariant Mapping and include np.generic to mirror the accepted input types; this avoids overspecifying mutability and matches the API.
-    test_data: dict[str, np.ndarray | torch.Tensor] = {
+    test_data: Mapping[str, np.ndarray | np.generic | torch.Tensor] = {

Add the required import:
-from collections.abc import Callable, Generator
+from collections.abc import Callable, Generator, Mapping
379-379: Unify scalar_data typing with Mapping and shared union
Match the container/type union used above for consistency and API parity.
-    scalar_data: dict[str, np.ndarray | np.generic | torch.Tensor] = {
+    scalar_data: Mapping[str, np.ndarray | np.generic | torch.Tensor] = {
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
tests/test_trajectory.py (2 hunks)
torch_sim/autobatching.py (2 hunks)
torch_sim/runners.py (2 hunks)
torch_sim/trajectory.py (13 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
torch_sim/autobatching.py (2)
  torch_sim/state.py (1)
    SimState (29-493)
  tests/test_autobatching.py (3)
    test_binning_auto_batcher_with_indices (215-241)
    test_binning_auto_batcher (127-165)
    test_in_flight_auto_batcher (307-352)
torch_sim/runners.py (2)
  torch_sim/state.py (4)
    dtype (185-187)
    device (180-182)
    SimState (29-493)
    initialize_state (919-987)
  torch_sim/models/interface.py (4)
    dtype (90-92)
    dtype (95-99)
    device (78-80)
    device (83-87)
torch_sim/trajectory.py (1)
  tests/test_trajectory.py (2)
    trajectory (42-46)
    prop_calculators (523-532)
🪛 Ruff (0.12.2)
torch_sim/runners.py
111-111: Unused noqa directive (non-enabled: C901)
Remove unused noqa directive
(RUF100)
156-156: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (1)
torch_sim/trajectory.py (1)
469-504: Input normalization looks good
Accepting Mapping and auto-normalizing torch Tensors and numpy scalars to ndarray is a solid improvement for writer ergonomics.
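As a rough sketch of what that normalization amounts to (illustrative only; _to_numpy is a hypothetical helper, not the PR's implementation):

from collections.abc import Mapping
import numpy as np
import torch

def _to_numpy(value: np.ndarray | np.generic | torch.Tensor) -> np.ndarray:
    # torch tensors come back as CPU numpy arrays; numpy scalars become 0-d arrays
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return np.asarray(value)

def normalize(data: Mapping[str, np.ndarray | np.generic | torch.Tensor]) -> dict[str, np.ndarray]:
    return {name: _to_numpy(value) for name, value in data.items()}

normalize({"positions": torch.zeros(4, 3), "pbc": np.bool_(True)})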
    def write_arrays(
        self,
-       data: dict[str, np.ndarray | torch.Tensor],
+       data: "Mapping[str, np.ndarray | np.generic | torch.Tensor]",
why is dict not sufficient to type this?
we run into https://mypy.readthedocs.io/en/stable/common_issues.html#variance with mypy
torch_sim/trajectory.py:784: note: "dict" is invariant -- see https://mypy.readthedocs.io/en/stable/common_issues.html#variance
torch_sim/trajectory.py:784: note: Consider using "Mapping" instead, which is covariant in the value type
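A minimal standalone reproduction of that variance issue (not taken from the repo):

from collections.abc import Mapping
import numpy as np
import torch

def takes_dict(data: dict[str, np.ndarray | torch.Tensor]) -> None: ...
def takes_mapping(data: Mapping[str, np.ndarray | torch.Tensor]) -> None: ...

narrow: dict[str, torch.Tensor] = {"x": torch.zeros(3)}
takes_dict(narrow)     # mypy arg-type error: dict is invariant in its value type
takes_mapping(narrow)  # OK: Mapping is covariant in its value type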
we also need the np.generic because of this error:
tests/test_trajectory.py:390: error: Argument 1 to "write_arrays" of "TorchSimTrajectory" has incompatible type "dict[str, ndarray[Any, Any] | generic | Tensor]"; expected "Mapping[str, ndarray[Any, Any] | Tensor]" [arg-type]
which allows us to write a np.bool inside test_trajectory.py
scalar_data: dict[str, np.ndarray | np.generic | torch.Tensor] = {
"float64_scalar": np.float64(1.0),
"float32_scalar": np.float32(1.0),
"int64_scalar": np.int64(1),
"int32_scalar": np.int32(1),
"bool_scalar": np.bool_(True), # noqa: FBT003
"torch_float_scalar": torch.tensor(1.0, dtype=torch.float64),
"torch_int_scalar": torch.tensor(1, dtype=torch.int64),
"torch_bool_scalar": torch.tensor(data=True),
}
traj.write_arrays(scalar_data, steps=0)
        return self._file.root.data.positions.shape[0]


-    def write_ase_trajectory(self, filename: str | pathlib.Path) -> "Trajectory":
+    def write_ase_trajectory(self, filename: str | pathlib.Path) -> "TrajectoryReader":
this feels wrong?
if you scroll to line 1015, we open the Trajectory in read mode (return Trajectory(filename)  # Reopen in read mode), which is a TrajectoryReader object
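For reference, a tiny check of that behavior with ASE installed (example.traj is a hypothetical filename, not from the PR):

from ase import Atoms
from ase.io.trajectory import Trajectory, TrajectoryReader

writer = Trajectory("example.traj", mode="w")
writer.write(Atoms("H2", positions=[[0, 0, 0], [0, 0, 0.74]]))
writer.close()

traj = Trajectory("example.traj")  # default mode "r" reopens for reading
print(isinstance(traj, TrajectoryReader))  # True, hence the TrajectoryReader annotation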
if TYPE_CHECKING:
    from ase import Atoms
-   from ase.io.trajectory import Trajectory
+   from ase.io.trajectory import TrajectoryReader
Why change this?
view my comment to Rhys below
Summary
Checklist
Before a pull request can be merged, the following items must be checked:
We highly recommend installing the pre-commit hooks that run in CI locally to speed up the development process. Simply run pip install pre-commit && pre-commit install to install the hooks, which will check your code before each commit.

Summary by CodeRabbit
New Features
Improvements
Tests