diff --git a/.gitignore b/.gitignore index 29646ebf..304841eb 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,5 @@ coverage.xml # env uv.lock + +.vscode/ \ No newline at end of file diff --git a/examples/tutorials/state_tutorial.py b/examples/tutorials/state_tutorial.py index 1abf5d28..4b615a9d 100644 --- a/examples/tutorials/state_tutorial.py +++ b/examples/tutorials/state_tutorial.py @@ -112,7 +112,7 @@ f"Multi-state has {multi_state.n_atoms} total atoms across {multi_state.n_batches} batches" ) -# we can see how the shapes of batchwise, atomwise, and global properties change +# we can see how the shapes of batchwise, atomwise, and per-graph properties change print(f"Positions shape: {multi_state.positions.shape}") print(f"Cell shape: {multi_state.cell.shape}") print(f"PBC: {multi_state.pbc}") @@ -258,7 +258,7 @@ print("MDState properties:") scope = infer_property_scope(md_state) -print("Global properties:", scope["global"]) +print("Per-graph properties:", scope["per_graph"]) print("Per-atom properties:", scope["per_atom"]) print("Per-batch properties:", scope["per_batch"]) diff --git a/tests/test_state.py b/tests/test_state.py index 26fec0ea..faf6b904 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -26,7 +26,7 @@ def test_infer_sim_state_property_scope(si_sim_state: SimState) -> None: """Test inference of property scope.""" scope = infer_property_scope(si_sim_state) - assert set(scope["global"]) == {"pbc"} + assert set(scope["per_graph"]) == {"pbc"} assert set(scope["per_atom"]) == {"positions", "masses", "atomic_numbers", "batch"} assert set(scope["per_batch"]) == {"cell"} @@ -40,7 +40,7 @@ def test_infer_md_state_property_scope(si_sim_state: SimState) -> None: energy=torch.zeros((1,)), ) scope = infer_property_scope(state) - assert set(scope["global"]) == {"pbc"} + assert set(scope["per_graph"]) == {"pbc"} assert set(scope["per_atom"]) == { "positions", "masses", diff --git a/torch_sim/models/lennard_jones.py b/torch_sim/models/lennard_jones.py index d611d3f9..a47bb3d2 100644 --- a/torch_sim/models/lennard_jones.py +++ b/torch_sim/models/lennard_jones.py @@ -325,7 +325,7 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]: # we always return tensors # per atom properties are returned as (atoms, ...) tensors - # global properties are returned as shape (..., n) tensors + # per-graph properties are returned as shape (..., n) tensors results = {} for key in ("stress", "energy"): if key in properties: diff --git a/torch_sim/models/morse.py b/torch_sim/models/morse.py index 4b3dbbcd..8b215d1b 100644 --- a/torch_sim/models/morse.py +++ b/torch_sim/models/morse.py @@ -308,7 +308,7 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]: # we always return tensors # per atom properties are returned as (atoms, ...) tensors - # global properties are returned as shape (..., n) tensors + # per-graph properties are returned as shape (..., n) tensors results = {} for key in ("stress", "energy"): if key in properties: diff --git a/torch_sim/state.py b/torch_sim/state.py index e5997cf4..64013126 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -8,7 +8,7 @@ import importlib import warnings from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Literal, Self +from typing import TYPE_CHECKING, Literal, Self, TypedDict, cast import torch @@ -152,7 +152,7 @@ def n_batches(self) -> int: return torch.unique(self.batch).shape[0] @property - def volume(self) -> torch.Tensor: + def volume(self) -> torch.Tensor | None: """Volume of the system.""" return torch.det(self.cell) if self.pbc else None @@ -226,7 +226,7 @@ def to_phonopy(self) -> list["PhonopyAtoms"]: """ return ts.io.state_to_phonopy(self) - def split(self) -> list[Self]: + def split(self) -> list["SimState"]: """Split the SimState into a list of single-batch SimStates. Divides the current state into separate states, each containing a single batch, @@ -237,7 +237,9 @@ def split(self) -> list[Self]: """ return _split_state(self) - def pop(self, batch_indices: int | list[int] | slice | torch.Tensor) -> list[Self]: + def pop( + self, batch_indices: int | list[int] | slice | torch.Tensor + ) -> list["SimState"]: """Pop off states with the specified batch indices. This method modifies the original state object by removing the specified @@ -268,7 +270,7 @@ def pop(self, batch_indices: int | list[int] | slice | torch.Tensor) -> list[Sel def to( self, device: torch.device | None = None, dtype: torch.dtype | None = None - ) -> Self: + ) -> "SimState": """Convert the SimState to a new device and/or data type. Args: @@ -282,7 +284,9 @@ def to( """ return state_to_device(self, device, dtype) - def __getitem__(self, batch_indices: int | list[int] | slice | torch.Tensor) -> Self: + def __getitem__( + self, batch_indices: int | list[int] | slice | torch.Tensor + ) -> "SimState": """Enable standard Python indexing syntax for slicing batches. Args: @@ -387,7 +391,7 @@ def _normalize_batch_indices( def state_to_device( state: SimState, device: torch.device | None = None, dtype: torch.dtype | None = None -) -> Self: +) -> SimState: """Convert the SimState to a new device and dtype. Creates a new SimState with all tensors moved to the specified device and @@ -419,10 +423,18 @@ def state_to_device( return type(state)(**attrs) +class Scope(TypedDict): + """Dictionary mapping each scope category to a list of property names""" + + per_graph: list[str] + per_atom: list[str] + per_batch: list[str] + + def infer_property_scope( state: SimState, ambiguous_handling: Literal["error", "globalize", "globalize_warn"] = "error", -) -> dict[Literal["global", "per_atom", "per_batch"], list[str]]: +) -> Scope: """Infer whether a property is global, per-atom, or per-batch. Analyzes the shapes of tensor attributes to determine their scope within @@ -437,14 +449,13 @@ def infer_property_scope( - "globalize_warn": Treat ambiguous properties as global with a warning Returns: - dict[Literal["global", "per_atom", "per_batch"], list[str]]: Dictionary mapping - each scope category to a list of property names + Scope: Dictionary mapping each scope category to a list of property names Raises: ValueError: If n_atoms equals n_batches (making scope inference ambiguous) or if ambiguous_handling="error" and an ambiguous property is encountered """ - # TODO: this cannot effectively resolve global properties with + # TODO: this cannot effectively resolve per-graph properties with # length of n_atoms or n_batches, they will be classified incorrectly, # no clear fix @@ -454,17 +465,17 @@ def infer_property_scope( "which means shapes cannot be inferred unambiguously." ) - scope = { - "global": [], + scope: Scope = { + "per_graph": [], "per_atom": [], "per_batch": [], } # Iterate through all attributes for attr_name, attr_value in vars(state).items(): - # Handle scalar values (global properties) + # Handle scalar values (per-graph properties) if not isinstance(attr_value, torch.Tensor): - scope["global"].append(attr_name) + scope["per_graph"].append(attr_name) continue # Handle tensor properties based on shape @@ -472,7 +483,7 @@ def infer_property_scope( # Empty tensor case if len(shape) == 0: - scope["global"].append(attr_name) + scope["per_graph"].append(attr_name) # Vector/matrix with first dimension matching number of atoms elif shape[0] == state.n_atoms: scope["per_atom"].append(attr_name) @@ -487,7 +498,7 @@ def infer_property_scope( f"{state.n_batches} (per-batch), or a scalar (global)." ) elif ambiguous_handling in ("globalize", "globalize_warn"): - scope["global"].append(attr_name) + scope["per_graph"].append(attr_name) if ambiguous_handling == "globalize_warn": warnings.warn( @@ -519,11 +530,11 @@ def _get_property_attrs( """ scope = infer_property_scope(state, ambiguous_handling=ambiguous_handling) - attrs = {"global": {}, "per_atom": {}, "per_batch": {}} + attrs: dict[str, dict] = {"per_graph": {}, "per_atom": {}, "per_batch": {}} - # Process global properties - for attr_name in scope["global"]: - attrs["global"][attr_name] = getattr(state, attr_name) + # Process per-graph properties + for attr_name in scope["per_graph"]: + attrs["per_graph"][attr_name] = getattr(state, attr_name) # Process per-atom properties for attr_name in scope["per_atom"]: @@ -537,7 +548,7 @@ def _get_property_attrs( def _filter_attrs_by_mask( - attrs: dict[str, dict], + attrs: Scope, atom_mask: torch.Tensor, batch_mask: torch.Tensor, ) -> dict: @@ -546,7 +557,7 @@ def _filter_attrs_by_mask( Selects subsets of attributes based on boolean masks for atoms and batches. Args: - attrs (dict[str, dict]): Dictionary with 'global', 'per_atom', and 'per_batch' + attrs (Scope): Dictionary with 'per_graph', 'per_atom', and 'per_batch' attributes atom_mask (torch.Tensor): Boolean mask for atoms to include with shape (n_atoms,) @@ -559,7 +570,7 @@ def _filter_attrs_by_mask( filtered_attrs = {} # Copy global attributes directly - filtered_attrs.update(attrs["global"]) + filtered_attrs.update(attrs["per_graph"]) # Filter per-atom attributes for attr_name, attr_value in attrs["per_atom"].items(): @@ -638,7 +649,7 @@ def _split_state( # Add the split per-batch attributes **{attr_name: split_per_batch[attr_name][i] for attr_name in split_per_batch}, # Add the global attributes - **attrs["global"], + **attrs["per_graph"], } states.append(type(state)(**batch_attrs)) @@ -753,7 +764,7 @@ def concatenate_states( """Concatenate a list of SimStates into a single SimState. Combines multiple states into a single state with multiple batches. - Global properties are taken from the first state, and per-atom and per-batch + per-graph properties are taken from the first state, and per-atom and per-batch properties are concatenated. Args: @@ -783,14 +794,14 @@ def concatenate_states( target_device = device or first_state.device # Get property scopes from the first state to identify - # global/per-atom/per-batch properties + # per-graph/per-atom/per-batch properties first_scope = infer_property_scope(first_state) - global_props = set(first_scope["global"]) + per_graph_props = set(first_scope["per_graph"]) per_atom_props = set(first_scope["per_atom"]) per_batch_props = set(first_scope["per_batch"]) - # Initialize result with global properties from first state - concatenated = {prop: getattr(first_state, prop) for prop in global_props} + # Initialize result with per-graph properties from first state + concatenated = {prop: getattr(first_state, prop) for prop in per_graph_props} # Pre-allocate lists for tensors to concatenate per_atom_tensors = {prop: [] for prop in per_atom_props} @@ -864,6 +875,7 @@ def initialize_state( return state_to_device(system, device, dtype) if isinstance(system, list) and all(isinstance(s, SimState) for s in system): + system = cast(list[SimState], system) if not all(state.n_batches == 1 for state in system): raise ValueError( "When providing a list of states, to the initialize_state function, "