From 48f81731af26c8f63a80a13c976c88238df54755 Mon Sep 17 00:00:00 2001
From: Curtis Chong
Date: Sun, 11 May 2025 16:34:00 -0700
Subject: [PATCH 1/4] Fix types so pylance's type checker doesn't complain

---
 .gitignore         |  2 ++
 torch_sim/state.py | 12 ++++++++----
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/.gitignore b/.gitignore
index 29646ebf..304841eb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -34,3 +34,5 @@ coverage.xml
 
 # env
 uv.lock
+
+.vscode/
\ No newline at end of file
diff --git a/torch_sim/state.py b/torch_sim/state.py
index e5997cf4..c51e7123 100644
--- a/torch_sim/state.py
+++ b/torch_sim/state.py
@@ -152,7 +152,7 @@ def n_batches(self) -> int:
         return torch.unique(self.batch).shape[0]
 
     @property
-    def volume(self) -> torch.Tensor:
+    def volume(self) -> torch.Tensor | None:
         """Volume of the system."""
         return torch.det(self.cell) if self.pbc else None
 
@@ -226,7 +226,7 @@ def to_phonopy(self) -> list["PhonopyAtoms"]:
         """
         return ts.io.state_to_phonopy(self)
 
-    def split(self) -> list[Self]:
+    def split(self) -> list["SimState"]:
         """Split the SimState into a list of single-batch SimStates.
 
         Divides the current state into separate states, each containing a single batch,
@@ -237,7 +237,9 @@
         """
         return _split_state(self)
 
-    def pop(self, batch_indices: int | list[int] | slice | torch.Tensor) -> list[Self]:
+    def pop(
+        self, batch_indices: int | list[int] | slice | torch.Tensor
+    ) -> list["SimState"]:
         """Pop off states with the specified batch indices.
 
         This method modifies the original state object by removing the specified
@@ -282,7 +284,9 @@ def to(
         """
         return state_to_device(self, device, dtype)
 
-    def __getitem__(self, batch_indices: int | list[int] | slice | torch.Tensor) -> Self:
+    def __getitem__(
+        self, batch_indices: int | list[int] | slice | torch.Tensor
+    ) -> "SimState":
         """Enable standard Python indexing syntax for slicing batches.
 
         Args:

From b345a0afc96ba10f5dc4612de53d955ab815746b Mon Sep 17 00:00:00 2001
From: Curtis Chong
Date: Sun, 11 May 2025 16:40:31 -0700
Subject: [PATCH 2/4] cast list of simstates

---
 torch_sim/state.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/torch_sim/state.py b/torch_sim/state.py
index c51e7123..b7cd5c1b 100644
--- a/torch_sim/state.py
+++ b/torch_sim/state.py
@@ -8,7 +8,7 @@
 import importlib
 import warnings
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Literal, Self
+from typing import TYPE_CHECKING, Literal, Self, cast
 
 import torch
 
@@ -391,7 +391,7 @@ def _normalize_batch_indices(
 
 def state_to_device(
     state: SimState, device: torch.device | None = None, dtype: torch.dtype | None = None
-) -> Self:
+) -> SimState:
     """Convert the SimState to a new device and dtype.
 
     Creates a new SimState with all tensors moved to the specified device and
@@ -868,6 +868,7 @@ def initialize_state(
         return state_to_device(system, device, dtype)
 
     if isinstance(system, list) and all(isinstance(s, SimState) for s in system):
+        system = cast(list[SimState], system)
         if not all(state.n_batches == 1 for state in system):
             raise ValueError(
                 "When providing a list of states, to the initialize_state function, "

From ffde2581db07735e01fd107de218f66ee2b590f7 Mon Sep 17 00:00:00 2001
From: Curtis Chong
Date: Sun, 11 May 2025 17:11:06 -0700
Subject: [PATCH 3/4] rename global to per graph

---
 examples/tutorials/state_tutorial.py |  2 +-
 tests/test_state.py                  |  4 +--
 torch_sim/state.py                   | 45 ++++++++++++++++------------
 3 files changed, 29 insertions(+), 22 deletions(-)

diff --git a/examples/tutorials/state_tutorial.py b/examples/tutorials/state_tutorial.py
index 1abf5d28..ca99d1b3 100644
--- a/examples/tutorials/state_tutorial.py
+++ b/examples/tutorials/state_tutorial.py
@@ -258,7 +258,7 @@
 
 print("MDState properties:")
 scope = infer_property_scope(md_state)
-print("Global properties:", scope["global"])
+print("Global properties:", scope["per_graph"])
 print("Per-atom properties:", scope["per_atom"])
 print("Per-batch properties:", scope["per_batch"])
 
diff --git a/tests/test_state.py b/tests/test_state.py
index 26fec0ea..faf6b904 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -26,7 +26,7 @@ def test_infer_sim_state_property_scope(si_sim_state: SimState) -> None:
     """Test inference of property scope."""
     scope = infer_property_scope(si_sim_state)
-    assert set(scope["global"]) == {"pbc"}
+    assert set(scope["per_graph"]) == {"pbc"}
     assert set(scope["per_atom"]) == {"positions", "masses", "atomic_numbers", "batch"}
     assert set(scope["per_batch"]) == {"cell"}
 
@@ -40,7 +40,7 @@ def test_infer_md_state_property_scope(si_sim_state: SimState) -> None:
         energy=torch.zeros((1,)),
     )
     scope = infer_property_scope(state)
-    assert set(scope["global"]) == {"pbc"}
+    assert set(scope["per_graph"]) == {"pbc"}
     assert set(scope["per_atom"]) == {
         "positions",
         "masses",
diff --git a/torch_sim/state.py b/torch_sim/state.py
index b7cd5c1b..a30886d0 100644
--- a/torch_sim/state.py
+++ b/torch_sim/state.py
@@ -8,7 +8,7 @@
 import importlib
 import warnings
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Literal, Self, cast
+from typing import TYPE_CHECKING, Literal, Self, TypedDict, cast
 
 import torch
 
@@ -270,7 +270,7 @@ def pop(
 
     def to(
         self, device: torch.device | None = None, dtype: torch.dtype | None = None
-    ) -> Self:
+    ) -> "SimState":
         """Convert the SimState to a new device and/or data type.
 
         Args:
@@ -423,10 +423,18 @@ def state_to_device(
     return type(state)(**attrs)
 
 
+class Scope(TypedDict):
+    """Dictionary mapping each scope category to a list of property names"""
+
+    per_graph: list[str]
+    per_atom: list[str]
+    per_batch: list[str]
+
+
 def infer_property_scope(
     state: SimState,
     ambiguous_handling: Literal["error", "globalize", "globalize_warn"] = "error",
-) -> dict[Literal["global", "per_atom", "per_batch"], list[str]]:
+) -> Scope:
     """Infer whether a property is global, per-atom, or per-batch.
 
     Analyzes the shapes of tensor attributes to determine their scope within
@@ -441,8 +449,7 @@ def infer_property_scope(
             - "globalize_warn": Treat ambiguous properties as global with a warning
 
     Returns:
-        dict[Literal["global", "per_atom", "per_batch"], list[str]]: Dictionary mapping
-            each scope category to a list of property names
+        Scope: Dictionary mapping each scope category to a list of property names
 
     Raises:
         ValueError: If n_atoms equals n_batches (making scope inference ambiguous) or
@@ -458,8 +465,8 @@ def infer_property_scope(
             "which means shapes cannot be inferred unambiguously."
         )
 
-    scope = {
-        "global": [],
+    scope: Scope = {
+        "per_graph": [],
         "per_atom": [],
         "per_batch": [],
     }
@@ -468,7 +475,7 @@ def infer_property_scope(
     for attr_name, attr_value in vars(state).items():
         # Handle scalar values (global properties)
         if not isinstance(attr_value, torch.Tensor):
-            scope["global"].append(attr_name)
+            scope["per_graph"].append(attr_name)
             continue
 
         # Handle tensor properties based on shape
@@ -476,7 +483,7 @@ def infer_property_scope(
 
         # Empty tensor case
         if len(shape) == 0:
-            scope["global"].append(attr_name)
+            scope["per_graph"].append(attr_name)
         # Vector/matrix with first dimension matching number of atoms
         elif shape[0] == state.n_atoms:
             scope["per_atom"].append(attr_name)
@@ -491,7 +498,7 @@ def infer_property_scope(
                     f"{state.n_batches} (per-batch), or a scalar (global)."
                 )
             elif ambiguous_handling in ("globalize", "globalize_warn"):
-                scope["global"].append(attr_name)
+                scope["per_graph"].append(attr_name)
 
                 if ambiguous_handling == "globalize_warn":
                     warnings.warn(
@@ -523,11 +530,11 @@ def _get_property_attrs(
     """
    scope = infer_property_scope(state, ambiguous_handling=ambiguous_handling)
 
-    attrs = {"global": {}, "per_atom": {}, "per_batch": {}}
+    attrs: dict[str, dict] = {"per_graph": {}, "per_atom": {}, "per_batch": {}}
 
     # Process global properties
-    for attr_name in scope["global"]:
-        attrs["global"][attr_name] = getattr(state, attr_name)
+    for attr_name in scope["per_graph"]:
+        attrs["per_graph"][attr_name] = getattr(state, attr_name)
 
     # Process per-atom properties
     for attr_name in scope["per_atom"]:
@@ -541,7 +548,7 @@ def _get_property_attrs(
 
 
 def _filter_attrs_by_mask(
-    attrs: dict[str, dict],
+    attrs: Scope,
     atom_mask: torch.Tensor,
     batch_mask: torch.Tensor,
 ) -> dict:
     """Filter attributes based on atom and batch masks.
 
     Selects subsets of attributes based on boolean masks for atoms and batches.
 
     Args:
-        attrs (dict[str, dict]): Dictionary with 'global', 'per_atom', and 'per_batch'
+        attrs (Scope): Dictionary with 'per_graph', 'per_atom', and 'per_batch'
             attributes
         atom_mask (torch.Tensor): Boolean mask for atoms to include with shape
             (n_atoms,)
@@ -563,7 +570,7 @@ def _filter_attrs_by_mask(
     filtered_attrs = {}
 
     # Copy global attributes directly
-    filtered_attrs.update(attrs["global"])
+    filtered_attrs.update(attrs["per_graph"])
 
     # Filter per-atom attributes
     for attr_name, attr_value in attrs["per_atom"].items():
@@ -642,7 +649,7 @@ def _split_state(
             # Add the split per-batch attributes
             **{attr_name: split_per_batch[attr_name][i] for attr_name in split_per_batch},
             # Add the global attributes
-            **attrs["global"],
+            **attrs["per_graph"],
         }
         states.append(type(state)(**batch_attrs))
 
@@ -789,12 +796,12 @@ def concatenate_states(
     # Get property scopes from the first state to identify
     # global/per-atom/per-batch properties
     first_scope = infer_property_scope(first_state)
-    global_props = set(first_scope["global"])
+    per_graph_props = set(first_scope["per_graph"])
     per_atom_props = set(first_scope["per_atom"])
     per_batch_props = set(first_scope["per_batch"])
 
     # Initialize result with global properties from first state
-    concatenated = {prop: getattr(first_state, prop) for prop in global_props}
+    concatenated = {prop: getattr(first_state, prop) for prop in per_graph_props}
 
     # Pre-allocate lists for tensors to concatenate
     per_atom_tensors = {prop: [] for prop in per_atom_props}

From e9c8b4e02e6119a1a0f21e98afe996fa08c2b715 Mon Sep 17 00:00:00 2001
From: Curtis Chong
Date: Sun, 11 May 2025 17:25:59 -0700
Subject: [PATCH 4/4] more renaming global to per-graph

---
 examples/tutorials/state_tutorial.py |  4 ++--
 torch_sim/models/lennard_jones.py    |  2 +-
 torch_sim/models/morse.py            |  2 +-
 torch_sim/state.py                   | 12 ++++++------
 4 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/examples/tutorials/state_tutorial.py b/examples/tutorials/state_tutorial.py
index ca99d1b3..4b615a9d 100644
--- a/examples/tutorials/state_tutorial.py
+++ b/examples/tutorials/state_tutorial.py
@@ -112,7 +112,7 @@
     f"Multi-state has {multi_state.n_atoms} total atoms across {multi_state.n_batches} batches"
 )
 
-# we can see how the shapes of batchwise, atomwise, and global properties change
+# we can see how the shapes of batchwise, atomwise, and per-graph properties change
 print(f"Positions shape: {multi_state.positions.shape}")
 print(f"Cell shape: {multi_state.cell.shape}")
 print(f"PBC: {multi_state.pbc}")
@@ -258,7 +258,7 @@
 
 print("MDState properties:")
 scope = infer_property_scope(md_state)
-print("Global properties:", scope["per_graph"])
+print("Per-graph properties:", scope["per_graph"])
 print("Per-atom properties:", scope["per_atom"])
 print("Per-batch properties:", scope["per_batch"])
 
diff --git a/torch_sim/models/lennard_jones.py b/torch_sim/models/lennard_jones.py
index d611d3f9..a47bb3d2 100644
--- a/torch_sim/models/lennard_jones.py
+++ b/torch_sim/models/lennard_jones.py
@@ -325,7 +325,7 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:
 
         # we always return tensors
         # per atom properties are returned as (atoms, ...) tensors
-        # global properties are returned as shape (..., n) tensors
+        # per-graph properties are returned as shape (..., n) tensors
         results = {}
         for key in ("stress", "energy"):
             if key in properties:
diff --git a/torch_sim/models/morse.py b/torch_sim/models/morse.py
index 4b3dbbcd..8b215d1b 100644
--- a/torch_sim/models/morse.py
+++ b/torch_sim/models/morse.py
@@ -308,7 +308,7 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:
 
         # we always return tensors
         # per atom properties are returned as (atoms, ...) tensors
-        # global properties are returned as shape (..., n) tensors
+        # per-graph properties are returned as shape (..., n) tensors
         results = {}
         for key in ("stress", "energy"):
             if key in properties:
diff --git a/torch_sim/state.py b/torch_sim/state.py
index a30886d0..64013126 100644
--- a/torch_sim/state.py
+++ b/torch_sim/state.py
@@ -455,7 +455,7 @@ def infer_property_scope(
         ValueError: If n_atoms equals n_batches (making scope inference ambiguous) or
             if ambiguous_handling="error" and an ambiguous property is encountered
     """
-    # TODO: this cannot effectively resolve global properties with
+    # TODO: this cannot effectively resolve per-graph properties with
    # length of n_atoms or n_batches, they will be classified incorrectly,
     # no clear fix
 
@@ -473,7 +473,7 @@ def infer_property_scope(
 
     # Iterate through all attributes
     for attr_name, attr_value in vars(state).items():
-        # Handle scalar values (global properties)
+        # Handle scalar values (per-graph properties)
         if not isinstance(attr_value, torch.Tensor):
             scope["per_graph"].append(attr_name)
             continue
@@ -532,7 +532,7 @@ def _get_property_attrs(
 
     attrs: dict[str, dict] = {"per_graph": {}, "per_atom": {}, "per_batch": {}}
 
-    # Process global properties
+    # Process per-graph properties
     for attr_name in scope["per_graph"]:
         attrs["per_graph"][attr_name] = getattr(state, attr_name)
 
@@ -764,7 +764,7 @@ def concatenate_states(
     """Concatenate a list of SimStates into a single SimState.
 
     Combines multiple states into a single state with multiple batches.
-    Global properties are taken from the first state, and per-atom and per-batch
+    per-graph properties are taken from the first state, and per-atom and per-batch
     properties are concatenated.
 
     Args:
@@ -794,13 +794,13 @@ def concatenate_states(
     target_device = device or first_state.device
 
     # Get property scopes from the first state to identify
-    # global/per-atom/per-batch properties
+    # per-graph/per-atom/per-batch properties
     first_scope = infer_property_scope(first_state)
     per_graph_props = set(first_scope["per_graph"])
     per_atom_props = set(first_scope["per_atom"])
     per_batch_props = set(first_scope["per_batch"])
 
-    # Initialize result with global properties from first state
+    # Initialize result with per-graph properties from first state
     concatenated = {prop: getattr(first_state, prop) for prop in per_graph_props}
 
     # Pre-allocate lists for tensors to concatenate
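
Illustrative note (not part of the patch series): a minimal sketch of how the renamed scope keys read once PATCH 3/4 and 4/4 are applied, mirroring the tutorial snippet touched above. It assumes infer_property_scope remains importable from torch_sim.state and that md_state stands in for any existing SimState instance.

    from torch_sim.state import infer_property_scope

    # scope is the Scope TypedDict introduced in PATCH 3/4
    scope = infer_property_scope(md_state)
    print("Per-graph properties:", scope["per_graph"])  # e.g. {"pbc"} per the updated tests
    print("Per-atom properties:", scope["per_atom"])    # e.g. positions, masses, atomic_numbers, batch
    print("Per-batch properties:", scope["per_batch"])  # e.g. cell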