
Conversation

@curtischong curtischong commented Sep 9, 2025

Summary

  • Feature 1
  • Fix 1

Checklist

Before a pull request can be merged, the following items must be checked:

  • Doc strings have been added in the Google docstring format.
  • Ruff has been run on your code.
  • Tests have been added for any new functionality or bug fixes.

We highly recommend installing the pre-commit hooks that run in CI locally to speed up the development process. Simply run pip install pre-commit && pre-commit install to install the hooks, which will check your code before each commit.

Summary by CodeRabbit

  • New Features

    • Write to multiple trajectory files in a single run.
    • Accept multiple filenames and flexible path/sequence inputs.
    • Expose a unified array registry for trajectory data.
  • Improvements

    • Trajectory writing now accepts NumPy arrays/scalars and PyTorch tensors, converting automatically as needed.
    • Consistent temperature handling on the correct device/dtype during integration.
    • Support for configurable metadata and state/trajectory options when creating trajectories.
  • Tests

    • Added typing annotations to test data (no runtime changes).

@cla-bot cla-bot bot added the cla-signed Contributor license agreement signed label Sep 9, 2025

coderabbitai bot commented Sep 9, 2025

Walkthrough

Adds typing enhancements in tests and autobatching. Updates runners.integrate to construct temperatures directly as tensors on the model’s device/dtype and simplifies kT computation. Expands torch_sim.trajectory with multi-file support, broader input types (Mapping/Sequence), new public attributes and property, numpy-centric array handling, and adjusted type annotations.

Changes

  • Typing refinements (tests): tests/test_trajectory.py
    Added local type annotations for test dictionaries, including np.ndarray, np.generic, and torch.Tensor variants; no runtime changes.
  • API typing overloads (autobatching): torch_sim/autobatching.py
    Added typing.overload variants for BinningAutoBatcher.next_batch to express the return type when return_indices is True/False; imported overload; no logic changes.
  • Temperature tensorization in runner: torch_sim/runners.py
    Construct temps as a torch.Tensor on the model's dtype/device for both scalar and iterable inputs; compute kTs by direct multiplication; added a noqa comment; control flow otherwise unchanged.
  • Trajectory multi-file and API expansion: torch_sim/trajectory.py
    Introduced Sequence/Mapping in public APIs; TrajectoryReporter now manages multiple TorchSimTrajectory instances; added public attributes and an array_registry property; load_new_trajectories accepts multiple filenames and finalizes previous trajectories; write_arrays normalizes inputs to numpy (including torch.Tensor to CPU numpy); _get_state_arrays returns numpy; adjusted TYPE_CHECKING and return annotations (TrajectoryReader).

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant Client
    participant Reporter as TrajectoryReporter
    participant Traj* as TorchSimTrajectory (per file)
    Note over Client,Reporter: Multi-file trajectory initialization and writing

    Client->>Reporter: __init__(filenames | None, state_frequency, ..., trajectory_kwargs)
    alt filenames provided
        Reporter->>Reporter: load_new_trajectories(filenames)
        loop for each filename
            Reporter->>Reporter: finish() (finalize prior, if any)
            Reporter->>Traj*: construct TorchSimTrajectory(filename, metadata, trajectory_kwargs)
            Reporter->>Reporter: store in trajectories[]
        end
    else no filenames
        Note over Reporter: trajectories stays empty
    end

    Note right of Reporter: array_registry proxies first Traj*

    loop Integration steps (state_frequency gating)
        Client->>Reporter: write state/props
        Reporter->>Traj*: write_arrays(data, steps)
        Note over Traj*: Coerce Mapping values to numpy (torch.Tensor -> CPU numpy)
    end
sequenceDiagram
    autonumber
    participant Caller
    participant Runner as integrate(...)
    participant Model

    Caller->>Runner: integrate(model, n_steps, temperature, unit_system, ...)
    Runner->>Model: query dtype/device
    alt temperature is iterable
        Runner->>Runner: temps = torch.tensor(temperature, dtype/device)
    else temperature is scalar
        Runner->>Runner: temps = torch.full([n_steps], scalar, dtype/device)
    end
    Runner->>Runner: kTs = temps * unit_system.temperature
    Note over Runner: Remaining integration loop unchanged

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Pre-merge checks (2 passed, 1 inconclusive)

❓ Inconclusive checks (1)
  • Title Check: The title “Type trajectory” is generic and fails to clearly communicate the primary change of adding comprehensive type annotations and API enhancements to the trajectory modules and related tests. Resolution: rename the pull request to a more specific title such as “Add type annotations and API updates to trajectory modules” to clearly convey the scope and intent of the changes.
✅ Passed checks (2)
  • Description Check: Check skipped - CodeRabbit’s high-level summary is enabled.
  • Docstring Coverage: No functions found in the changes. Docstring coverage check skipped.

Poem

I thump my paws on fresh new trails,
Splitting files like carrot bales.
Temps now torchy, warm and bright,
kT rolls on in tidy light.
With maps and queues I hop with glee—
Many paths, one memory. 🥕✨


apply diff to runners

wip fix trajectory

fix more types

fix types

try to fix next batch type

fix trajectories

revert some changes

revert more changes

make trajectory file good

fix lints
@curtischong curtischong marked this pull request as ready for review September 10, 2025 03:17
@coderabbitai coderabbitai bot left a comment
Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
torch_sim/trajectory.py (2)

786-834: System selection is ignored; frames vs. systems mixed up

write_state builds arrays from the outer “state” list even when a specific system_index is selected, and mixes per-frame vs per-system shapes (e.g., cell). This will write wrong data and/or corrupt steps alignment when multiple systems are present.

Apply this refactor to consistently operate on the selected per-frame sub-states and fix shapes:

-        # we wrap
-        if isinstance(state, SimState):
-            state = [state]
-        if isinstance(steps, int):
-            steps = [steps]
-
-        if isinstance(system_index, int):
-            sub_states = [state[system_index] for state in state]
-        elif system_index is None and torch.unique(state[0].system_idx) == 0:
-            sub_states = state
-        else:
-            raise ValueError(
-                "System index must be specified if there are multiple systems"
-            )
+        # Normalize frames and steps
+        frames = [state] if isinstance(state, SimState) else list(state)
+        steps = [steps] if isinstance(steps, int) else list(steps)
+
+        # Select per-frame single-system states
+        if isinstance(system_index, int):
+            sub_states = [frame[system_index] for frame in frames]
+        elif frames[0].n_systems == 1:
+            sub_states = frames
+        else:
+            raise ValueError("system_index must be specified if there are multiple systems")
@@
-        if len(sub_states) != len(steps):
+        if len(sub_states) != len(steps):
             raise ValueError(f"{len(sub_states)=} must match the {len(steps)=}")
@@
-        data = {
-            "positions": torch.stack([s.positions for s in state]),
-        }
+        data = {"positions": torch.stack([s.positions for s in sub_states])}
@@
-        for array_name, should_save in optional_arrays.items():
+        for array_name, should_save in optional_arrays.items():
             if should_save:
-                if not hasattr(state[0], array_name):
+                if not hasattr(sub_states[0], array_name):
                     raise ValueError(
                         f"{array_name.capitalize()} can only be saved "
                         f"if included in the state being reported."
                     )
-                data[array_name] = torch.stack([getattr(s, array_name) for s in state])
+                data[array_name] = torch.stack([getattr(s, array_name) for s in sub_states])
@@
-        if variable_cell:
-            data["cell"] = torch.cat([s.cell for s in state])
-        elif "cell" not in self.array_registry:  # Save cell only for first frame
-            # we but cell in list because it doesn't need to be padded
-            self.write_arrays({"cell": state[0].cell}, [0])
+        if variable_cell:
+            data["cell"] = torch.cat([s.cell for s in sub_states])  # [n_frames, 3, 3]
+        elif "cell" not in self.array_registry:
+            self.write_arrays({"cell": sub_states[0].cell}, [0])
@@
-        if variable_masses:
-            data["masses"] = torch.stack([s.masses for s in state])
+        if variable_masses:
+            data["masses"] = torch.stack([s.masses for s in sub_states])
         elif "masses" not in self.array_registry:  # Save masses only for first frame
-            self.write_arrays({"masses": state[0].masses}, 0)
+            self.write_arrays({"masses": sub_states[0].masses}, 0)
@@
-        if variable_atomic_numbers:
-            data["atomic_numbers"] = torch.stack([s.atomic_numbers for s in state])
+        if variable_atomic_numbers:
+            data["atomic_numbers"] = torch.stack([s.atomic_numbers for s in sub_states])
         elif "atomic_numbers" not in self.array_registry:
             # Save atomic numbers only for first frame
-            self.write_arrays({"atomic_numbers": state[0].atomic_numbers}, 0)
+            self.write_arrays({"atomic_numbers": sub_states[0].atomic_numbers}, 0)
@@
-        if "pbc" not in self.array_registry:
-            self.write_arrays({"pbc": np.array(state[0].pbc)}, 0)
+        if "pbc" not in self.array_registry:
+            self.write_arrays({"pbc": np.array(sub_states[0].pbc)}, 0)

This also replaces the fragile torch.unique(...) check with frames[0].n_systems == 1.


923-934: Casting pbc via bool(...) breaks when pbc is a 3-vector

Calling bool() on a multi-element np.ndarray raises an “ambiguous truth value” error. Convert arrays with np.all(...) or pass scalars through unchanged.

-        arrays = self._get_state_arrays(frame)
+        arrays = self._get_state_arrays(frame)
+        # Normalize pbc to a plain bool (True only if all periodic)
+        pbc_val = arrays.get("pbc", True)
+        pbc_bool = bool(pbc_val) if np.isscalar(pbc_val) else bool(np.all(pbc_val))
@@
-        return SimState(
+        return SimState(
             positions=torch.tensor(arrays["positions"], device=device, dtype=dtype),
             masses=torch.tensor(arrays.get("masses", None), device=device, dtype=dtype),
             cell=torch.tensor(arrays["cell"], device=device, dtype=dtype),
-            pbc=bool(arrays.get("pbc", True)),
+            pbc=pbc_bool,
             atomic_numbers=torch.tensor(
                 arrays["atomic_numbers"], device=device, dtype=torch.int
             ),
         )
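Standalone, the suggested normalization behaves like this (pbc_to_bool is a hypothetical helper name for the sketch, not torch_sim API):

```python
import numpy as np

def pbc_to_bool(pbc_val) -> bool:
    # True only when every axis is periodic; plain bool() on a
    # multi-element array would raise "truth value ... is ambiguous"
    if np.isscalar(pbc_val):
        return bool(pbc_val)
    return bool(np.all(pbc_val))

assert pbc_to_bool(True) is True
assert pbc_to_bool(np.array([True, True, True])) is True
assert pbc_to_bool(np.array([True, True, False])) is False
```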
🧹 Nitpick comments (9)
torch_sim/runners.py (4)

111-111: Remove unused noqa for C901

Ruff flagged this as unused. Drop the directive or enable the rule; otherwise it’s noise.

-def integrate(  # noqa: C901
+def integrate(

149-154: Harden iterable detection and avoid Python list materialization

Using hasattr(temperature, "__iter__") will treat strings/bytes as iterables, and the scalar branch builds a Python list. Prefer explicit type checks and torch-native allocation.

Apply this diff, and import Sequence at the top if not already:

-    if hasattr(temperature, "__iter__"):
-        temps = torch.tensor(temperature, dtype=model.dtype, device=model.device)
-    else:
-        temps = torch.tensor(
-            [temperature] * n_steps, dtype=model.dtype, device=model.device
-        )
+    if isinstance(temperature, torch.Tensor):
+        temps = temperature.to(dtype=model.dtype, device=model.device).flatten()
+    elif isinstance(temperature, (list, tuple)):
+        temps = torch.as_tensor(temperature, dtype=model.dtype, device=model.device)
+    else:
+        temps = torch.full(
+            (n_steps,), float(temperature), dtype=model.dtype, device=model.device
+        )

Additionally, the docstring says “float | ArrayLike” while the type hints say “float | list | torch.Tensor”. Align those for consistency.
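The pitfall is easy to demonstrate without torch; normalize_temps below is a hypothetical plain-Python sketch of the suggested branching:

```python
# Strings pass a bare __iter__ check, so a temperature passed as "300"
# would be treated as a sequence of characters
assert hasattr("300", "__iter__")
assert not hasattr(300.0, "__iter__")

def normalize_temps(temperature, n_steps):
    # Explicit type checks instead of duck-typed iteration
    if isinstance(temperature, (list, tuple)):
        return [float(t) for t in temperature]
    return [float(temperature)] * n_steps

assert normalize_temps(300, 3) == [300.0, 300.0, 300.0]
assert normalize_temps([100, 200], 3) == [100.0, 200.0]
```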


156-156: Shorten/simplify the ValueError message (TRY003)

The long f-string triggers TRY003. Make it concise.

-    if len(temps) != n_steps:
-        raise ValueError(f"{len(temps)=:,}. It must equal n_steps = {n_steps=:,}")
+    if len(temps) != n_steps:
+        raise ValueError(f"len(temps)={len(temps):,} must equal n_steps={n_steps:,}")

161-166: Confirm unit handling for kT

Multiplying a torch tensor by unit_system.temperature assumes it’s a plain scalar compatible with the model dtype/device. If it can be a quantity/wrapper, cast once to the right dtype on the device to avoid dtype/device surprises.

-    kTs = temps * unit_system.temperature
+    kTs = temps * float(unit_system.temperature)
torch_sim/autobatching.py (2)

573-580: Tighten overloads with Literal for better type narrowing

Using Literal improves callers’ type inference.

-@overload
-def next_batch(self, *, return_indices: bool = False) -> SimState | None: ...
+@overload
+def next_batch(self, *, return_indices: Literal[False] = False) -> SimState | None: ...
 
-@overload
-def next_batch(
-    self, *, return_indices: bool = True
-) -> tuple[SimState, list[int]] | None: ...
+@overload
+def next_batch(
+    self, *, return_indices: Literal[True]
+) -> tuple[SimState, list[int]] | None: ...

(Remember to import Literal if not already.)
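The narrowing effect can be seen on a toy class; Batcher here is a hypothetical stand-in for BinningAutoBatcher:

```python
from __future__ import annotations

from typing import Literal, overload

class Batcher:
    """Toy batcher: with return_indices=True a checker infers tuple[str, int] | None."""

    def __init__(self, batches: list[str]) -> None:
        self._batches = batches
        self._i = 0

    @overload
    def next_batch(self, *, return_indices: Literal[False] = False) -> str | None: ...
    @overload
    def next_batch(self, *, return_indices: Literal[True]) -> tuple[str, int] | None: ...

    def next_batch(self, *, return_indices: bool = False):
        if self._i >= len(self._batches):
            return None
        batch, idx = self._batches[self._i], self._i
        self._i += 1
        return (batch, idx) if return_indices else batch

b = Batcher(["a", "b"])
assert b.next_batch() == "a"
assert b.next_batch(return_indices=True) == ("b", 1)
assert b.next_batch() is None
```

With bool overloads, mypy cannot tell the two return types apart at a call site; the Literal versions let it pick the right one from the keyword argument.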


581-625: Docstring example is inconsistent with this API

The example references parameters from a different method signature. Update to a simple pull-loop using return_indices when desired.

-        Example::
-
-            # Get batches one by one
-            all_converged_state, convergence = [], None
-            while (result := batcher.next_batch(state, convergence))[0] is not None:
-                state, converged_states = result
-                all_converged_states.extend(converged_states)
-
-                evolve_batch(state)
-                convergence = convergence_criterion(state)
-            else:
-                all_converged_states.extend(result[1])
+        Example::
+
+            # Get batches one by one
+            while (batch := batcher.next_batch()) is not None:
+                process_batch(batch)
+
+            # Or with indices
+            while (res := batcher.next_batch(return_indices=True)) is not None:
+                batch, idx = res
+                process_batch(batch)
torch_sim/trajectory.py (1)

155-173: Sequence handling for filenames is good; minor nit on uniqueness check

Converting to Path and deduplicating via set(...) is fine. If you want to guard against differing string representations (e.g., "./a.h5" vs "a.h5"), consider resolving paths before checking uniqueness.

-        self.filenames = [pathlib.Path(filename) for filename in filenames]
+        self.filenames = [pathlib.Path(filename).resolve() for filename in filenames]
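A quick check of the resolution behavior (note that pathlib already drops a bare "./" prefix at construction, so the sketch uses a "sub/.." spelling to show where resolve() makes a difference):

```python
import pathlib

# Path() drops a leading "./" but does NOT collapse ".." segments
assert pathlib.Path("./a.h5") == pathlib.Path("a.h5")
assert pathlib.Path("sub/../a.h5") != pathlib.Path("a.h5")

names = ["sub/../a.h5", "a.h5", "b.h5"]
assert len({pathlib.Path(n) for n in names}) == 3  # duplicate slips through
assert len({pathlib.Path(n).resolve() for n in names}) == 2  # resolve() catches it
```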
tests/test_trajectory.py (2)

263-263: Prefer Mapping and align union with write_arrays contract

Use a covariant Mapping and include np.generic to mirror the accepted input types; avoids overspecifying mutability and matches the API.

-    test_data: dict[str, np.ndarray | torch.Tensor] = {
+    test_data: Mapping[str, np.ndarray | np.generic | torch.Tensor] = {

Add the required import:

-from collections.abc import Callable, Generator
+from collections.abc import Callable, Generator, Mapping

379-379: Unify scalar_data typing with Mapping and shared union

Match the container/type union used above for consistency and API parity.

-    scalar_data: dict[str, np.ndarray | np.generic | torch.Tensor] = {
+    scalar_data: Mapping[str, np.ndarray | np.generic | torch.Tensor] = {
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 390c071 and c27cd45.

📒 Files selected for processing (4)
  • tests/test_trajectory.py (2 hunks)
  • torch_sim/autobatching.py (2 hunks)
  • torch_sim/runners.py (2 hunks)
  • torch_sim/trajectory.py (13 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
torch_sim/autobatching.py (2)
torch_sim/state.py (1)
  • SimState (29-493)
tests/test_autobatching.py (3)
  • test_binning_auto_batcher_with_indices (215-241)
  • test_binning_auto_batcher (127-165)
  • test_in_flight_auto_batcher (307-352)
torch_sim/runners.py (2)
torch_sim/state.py (4)
  • dtype (185-187)
  • device (180-182)
  • SimState (29-493)
  • initialize_state (919-987)
torch_sim/models/interface.py (4)
  • dtype (90-92)
  • dtype (95-99)
  • device (78-80)
  • device (83-87)
torch_sim/trajectory.py (1)
tests/test_trajectory.py (2)
  • trajectory (42-46)
  • prop_calculators (523-532)
🪛 Ruff (0.12.2)
torch_sim/runners.py

111-111: Unused noqa directive (non-enabled: C901)

Remove unused noqa directive

(RUF100)


156-156: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (1)
torch_sim/trajectory.py (1)

469-504: Input normalization looks good

Accepting Mapping and auto-normalizing torch Tensors and numpy scalars to ndarray is a solid improvement for writer ergonomics.

     def write_arrays(
         self,
-        data: dict[str, np.ndarray | torch.Tensor],
+        data: "Mapping[str, np.ndarray | np.generic | torch.Tensor]",
Member:
why is dict not sufficient to type this?

Collaborator Author:

we run into https://mypy.readthedocs.io/en/stable/common_issues.html#variance with mypy

torch_sim/trajectory.py:784: note: "dict" is invariant -- see https://mypy.readthedocs.io/en/stable/common_issues.html#variance
torch_sim/trajectory.py:784: note: Consider using "Mapping" instead, which is covariant in the value type
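A minimal repro of the variance issue (hypothetical function names; the commented-out call is the line mypy rejects):

```python
from __future__ import annotations

from collections.abc import Mapping

import numpy as np

def write_dict(data: dict[str, np.ndarray]) -> int:
    return len(data)

def write_mapping(data: Mapping[str, np.ndarray | np.generic]) -> int:
    return len(data)

payload: dict[str, np.ndarray | np.generic] = {"x": np.float64(1.0)}

# write_dict(payload)  # mypy arg-type error: dict is invariant in its value type
assert write_mapping(payload) == 1  # OK: Mapping is covariant in the value type
```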

Collaborator Author:

we also need the np.generic because of this error:

tests/test_trajectory.py:390: error: Argument 1 to "write_arrays" of "TorchSimTrajectory" has incompatible type "dict[str, ndarray[Any, Any] | generic | Tensor]"; expected "Mapping[str, ndarray[Any, Any] | Tensor]" [arg-type]

which allows us to write a np.bool inside test_trajectory.py

scalar_data: dict[str, np.ndarray | np.generic | torch.Tensor] = {
    "float64_scalar": np.float64(1.0),
    "float32_scalar": np.float32(1.0),
    "int64_scalar": np.int64(1),
    "int32_scalar": np.int32(1),
    "bool_scalar": np.bool_(True),  # noqa: FBT003
    "torch_float_scalar": torch.tensor(1.0, dtype=torch.float64),
    "torch_int_scalar": torch.tensor(1, dtype=torch.int64),
    "torch_bool_scalar": torch.tensor(data=True),
}

traj.write_arrays(scalar_data, steps=0)
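For reference, the writer-side coercion those types flow into can be sketched numpy-only (to_numpy is a hypothetical helper; the torch.Tensor branch is duck-typed on .detach() so the snippet runs without torch):

```python
import numpy as np

def to_numpy(value):
    # torch.Tensor exposes .detach(); route it through .cpu().numpy()
    if hasattr(value, "detach"):
        return value.detach().cpu().numpy()
    # np.generic scalars and ndarrays alike become ndarrays, dtype preserved
    return np.asarray(value)

out = to_numpy(np.bool_(True))
assert isinstance(out, np.ndarray)
assert out.dtype == np.bool_ and out.shape == ()

arr = to_numpy(np.float32(1.0))
assert arr.dtype == np.float32
```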

return self._file.root.data.positions.shape[0]

-    def write_ase_trajectory(self, filename: str | pathlib.Path) -> "Trajectory":
+    def write_ase_trajectory(self, filename: str | pathlib.Path) -> "TrajectoryReader":
Member:
this feels wrong?

@curtischong curtischong (Collaborator Author) Sep 11, 2025:
if you scroll to line 1015, we reopen the trajectory in read mode (return Trajectory(filename)  # Reopen in read mode), which yields a TrajectoryReader object

 if TYPE_CHECKING:
     from ase import Atoms
-    from ase.io.trajectory import Trajectory
+    from ase.io.trajectory import TrajectoryReader
Collaborator:
Why change this?

Collaborator Author:
view my comment to Rhys below

@CompRhys CompRhys removed the cla-signed Contributor license agreement signed label Sep 30, 2025
@CompRhys CompRhys merged commit c6552b3 into main Sep 30, 2025
89 checks passed
@CompRhys CompRhys deleted the type-trajectory branch September 30, 2025 22:38
