diff --git a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py index 6fcb50e2..fb3f7983 100644 --- a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py +++ b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py @@ -76,6 +76,9 @@ class HybridSwapMCState(MDState): """ last_permutation: torch.Tensor + _atom_attributes = ( + MDState._atom_attributes | {"last_permutation"} # noqa: SLF001 + ) nvt_init, nvt_step = nvt_langevin(model=model, dt=0.002, kT=kT) diff --git a/examples/tutorials/hybrid_swap_tutorial.py b/examples/tutorials/hybrid_swap_tutorial.py index d8adbbf8..08c3914b 100644 --- a/examples/tutorials/hybrid_swap_tutorial.py +++ b/examples/tutorials/hybrid_swap_tutorial.py @@ -34,9 +34,11 @@ """ # %% +from typing import ClassVar import torch import torch_sim as ts from mace.calculators.foundations_models import mace_mp +from torch_sim.integrators.md import MDState from torch_sim.models.mace import MaceModel # Initialize the mace model @@ -104,6 +106,9 @@ class HybridSwapMCState(ts.integrators.MDState): """ last_permutation: torch.Tensor + _atom_attributes = ( + MDState._atom_attributes | {"last_permutation"} # noqa: SLF001 + ) # %% [markdown] diff --git a/examples/tutorials/state_tutorial.py b/examples/tutorials/state_tutorial.py index 0d3eca96..5cd43b17 100644 --- a/examples/tutorials/state_tutorial.py +++ b/examples/tutorials/state_tutorial.py @@ -79,16 +79,28 @@ the base SimState. Names are singular. * Global attributes have any other shape or type, just `pbc` here. Names are singular. -You can use the `infer_property_scope` function to analyze a state's properties. 
This +For TorchSim to know which attributes are atomwise, systemwise, and global, each attribute's +name is explicitly declared in one of the class variables `_atom_attributes`, `_system_attributes`, or `_global_attributes`: + +_atom_attributes = {"positions", "masses", "atomic_numbers", "system_idx"} +_system_attributes = {"cell"} +_global_attributes = {"pbc"} + +You can use the `get_attrs_for_scope` generator function to analyze a state's properties. This is mostly used internally but can be useful for debugging. """ # %% -from torch_sim.state import infer_property_scope +from torch_sim.state import get_attrs_for_scope -scope = infer_property_scope(si_state) -print(scope) +# loop through each attribute: +for attr_name, attr_value in get_attrs_for_scope(si_state, "per-atom"): + print(f"per-atom attribute: {attr_name}") + print(f"value: {attr_value}") +# or access the attributes via a dict: +print("Per-system attributes:", dict(get_attrs_for_scope(si_state, "per-system"))) # noqa: E501 +print("Global attributes:", dict(get_attrs_for_scope(si_state, "global"))) # %% [markdown] """ @@ -257,10 +269,9 @@ ) print("MDState properties:") -scope = infer_property_scope(md_state) -print("Global properties:", scope["global"]) -print("Per-atom properties:", scope["per_atom"]) -print("Per-system properties:", scope["per_system"]) +print("Per-atom attributes:", dict(get_attrs_for_scope(md_state, "per-atom"))) +print("Per-system attributes:", dict(get_attrs_for_scope(md_state, "per-system"))) +print("Global attributes:", dict(get_attrs_for_scope(md_state, "global"))) # %% [markdown] diff --git a/tests/test_state.py b/tests/test_state.py index 81109bf3..67a757fe 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -13,7 +13,7 @@ _pop_states, _slice_state, concatenate_states, - infer_property_scope, + get_attrs_for_scope, initialize_state, ) @@ -24,38 +24,52 @@ from pymatgen.core import Structure -def test_infer_sim_state_property_scope(si_sim_state: ts.SimState) -> None: - """Test inference of 
property scope.""" - scope = infer_property_scope(si_sim_state) - assert set(scope["global"]) == {"pbc"} - assert set(scope["per_atom"]) == { +def test_get_attrs_for_scope(si_sim_state: ts.SimState) -> None: + """Test getting attributes for a scope.""" + per_atom_attrs = dict(get_attrs_for_scope(si_sim_state, "per-atom")) + assert set(per_atom_attrs.keys()) == { "positions", "masses", "atomic_numbers", "system_idx", } - assert set(scope["per_system"]) == {"cell"} + per_system_attrs = dict(get_attrs_for_scope(si_sim_state, "per-system")) + assert set(per_system_attrs.keys()) == {"cell"} + global_attrs = dict(get_attrs_for_scope(si_sim_state, "global")) + assert set(global_attrs.keys()) == {"pbc"} -def test_infer_md_state_property_scope(si_sim_state: ts.SimState) -> None: - """Test inference of property scope.""" - state = MDState( - **asdict(si_sim_state), - momenta=torch.randn_like(si_sim_state.positions), - forces=torch.randn_like(si_sim_state.positions), - energy=torch.zeros((1,)), - ) - scope = infer_property_scope(state) - assert set(scope["global"]) == {"pbc"} - assert set(scope["per_atom"]) == { - "positions", - "masses", - "atomic_numbers", - "system_idx", - "forces", - "momenta", - } - assert set(scope["per_system"]) == {"cell", "energy"} +def test_all_attributes_must_be_specified_in_scopes() -> None: + """Test that an error is raised when we forget to specify the scope + for an attribute in a child SimState class.""" + with pytest.raises(TypeError) as excinfo: + + class ChildState(SimState): + attribute_specified_in_scopes: bool + attribute_not_specified_in_scopes: bool + + _atom_attributes = ( + SimState._atom_attributes | {"attribute_specified_in_scopes"} # noqa: SLF001 + ) + + assert "attribute_not_specified_in_scopes" in str(excinfo.value) + assert "attribute_specified_in_scopes" not in str(excinfo.value) + + +def test_no_duplicate_attributes_in_scopes() -> None: + """Test that no attributes are specified in multiple scopes.""" + + # Capture the 
exception information using "as excinfo" + with pytest.raises(TypeError) as excinfo: + + class ChildState(SimState): + duplicated_attribute: bool + + _system_attributes = SimState._system_attributes | {"duplicated_attribute"} # noqa: SLF001 + _global_attributes = SimState._global_attributes | {"duplicated_attribute"} # noqa: SLF001 + + assert "are declared multiple times" in str(excinfo.value) + assert "duplicated_attribute" in str(excinfo.value) def test_slice_substate( @@ -497,6 +511,11 @@ def test_column_vector_cell(si_sim_state: ts.SimState) -> None: class DeformState(SimState, DeformGradMixin): """Test class that combines SimState with DeformGradMixin.""" + _system_attributes = ( + SimState._system_attributes # noqa: SLF001 + | DeformGradMixin._system_attributes # noqa: SLF001 + ) + def __init__( self, *args, diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index 069f62eb..490e3528 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -41,6 +41,13 @@ class MDState(SimState): energy: torch.Tensor forces: torch.Tensor + _atom_attributes = ( + SimState._atom_attributes | {"momenta", "forces"} # noqa: SLF001 + ) + _system_attributes = ( + SimState._system_attributes | {"energy"} # noqa: SLF001 + ) + @property def velocities(self) -> torch.Tensor: """Velocities calculated from momenta and masses with shape diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index e1ac1d39..f5f23c5c 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -67,6 +67,18 @@ class NPTLangevinState(SimState): cell_velocities: torch.Tensor cell_masses: torch.Tensor + _atom_attributes = ( + SimState._atom_attributes | {"forces", "velocities"} # noqa: SLF001 + ) + _system_attributes = SimState._system_attributes | { # noqa: SLF001 + "stress", + "cell_positions", + "cell_velocities", + "cell_masses", + "reference_cell", + "energy", + } + @property def momenta(self) -> torch.Tensor: """Calculate momenta 
from velocities and masses.""" @@ -867,6 +879,25 @@ class NPTNoseHooverState(MDState): barostat: NoseHooverChain barostat_fns: NoseHooverChainFns + _system_attributes = ( + MDState._system_attributes # noqa: SLF001 + | { + "reference_cell", + "cell_position", + "cell_momentum", + "cell_mass", + } + ) + _global_attributes = ( + MDState._global_attributes # noqa: SLF001 + | { + "thermostat", + "barostat", + "thermostat_fns", + "barostat_fns", + } + ) + @property def velocities(self) -> torch.Tensor: """Calculate particle velocities from momenta and masses. diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 8309afa1..18f0ae15 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -266,6 +266,10 @@ class NVTNoseHooverState(MDState): chain: NoseHooverChain _chain_fns: NoseHooverChainFns + _global_attributes = ( + MDState._global_attributes | {"chain", "_chain_fns"} # noqa: SLF001 + ) + @property def velocities(self) -> torch.Tensor: """Velocities calculated from momenta and masses with shape diff --git a/torch_sim/monte_carlo.py b/torch_sim/monte_carlo.py index 64aad3c7..be2d99a8 100644 --- a/torch_sim/monte_carlo.py +++ b/torch_sim/monte_carlo.py @@ -36,6 +36,9 @@ class SwapMCState(SimState): energy: torch.Tensor last_permutation: torch.Tensor + _atom_attributes = SimState._atom_attributes | {"last_permutation"} # noqa: SLF001 + _system_attributes = SimState._system_attributes | {"energy"} # noqa: SLF001 + def generate_swaps( state: SimState, generator: torch.Generator | None = None diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 9a41cd8c..a88d40d9 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -33,6 +33,29 @@ MdFlavor = Literal["vv_fire", "ase_fire"] vv_fire_key, ase_fire_key = get_args(MdFlavor) +_md_atom_attributes = SimState._atom_attributes | {"forces", "velocities"} # noqa: SLF001 +_fire_system_attributes = ( + SimState._system_attributes # noqa: SLF001 + | 
DeformGradMixin._system_attributes # noqa: SLF001 + | { + "energy", + "stress", + "cell_positions", + "cell_velocities", + "cell_forces", + "cell_masses", + "cell_factor", + "pressure", + "dt", + "alpha", + "n_pos", + } +) +_fire_global_attributes = SimState._global_attributes | { # noqa: SLF001 + "hydrostatic_strain", + "constant_volume", +} + @dataclass class GDState(SimState): @@ -56,6 +79,9 @@ class GDState(SimState): forces: torch.Tensor energy: torch.Tensor + _atom_attributes = SimState._atom_attributes | {"forces"} # noqa: SLF001 + _system_attributes = SimState._system_attributes | {"energy"} # noqa: SLF001 + def gradient_descent( model: ModelInterface, *, lr: torch.Tensor | float = 0.01 @@ -149,7 +175,7 @@ def gd_step(state: GDState, lr: torch.Tensor = lr) -> GDState: return gd_init, gd_step -@dataclass +@dataclass(kw_only=True) class UnitCellGDState(GDState, DeformGradMixin): """State class for batched gradient descent optimization with unit cell. @@ -195,6 +221,22 @@ class UnitCellGDState(GDState, DeformGradMixin): cell_forces: torch.Tensor cell_masses: torch.Tensor + _system_attributes = ( + GDState._system_attributes # noqa: SLF001 + | DeformGradMixin._system_attributes # noqa: SLF001 + | { + "cell_forces", + "pressure", + "stress", + "cell_positions", + "cell_factor", + "cell_masses", + } + ) + _global_attributes = ( + GDState._global_attributes | {"hydrostatic_strain", "constant_volume"} # noqa: SLF001 + ) + def unit_cell_gradient_descent( # noqa: PLR0915, C901 model: ModelInterface, @@ -438,7 +480,7 @@ def gd_step( return gd_init, gd_step -@dataclass +@dataclass(kw_only=True) class FireState(SimState): """State information for batched FIRE optimization. 
@@ -482,6 +524,17 @@ class FireState(SimState): alpha: torch.Tensor n_pos: torch.Tensor + _atom_attributes = _md_atom_attributes + _system_attributes = ( + SimState._system_attributes # noqa: SLF001 + | { + "energy", + "dt", + "alpha", + "n_pos", + } + ) + def fire( model: ModelInterface, @@ -619,7 +672,7 @@ def fire_init( return fire_init, functools.partial(step_func, **step_func_kwargs) -@dataclass +@dataclass(kw_only=True) class UnitCellFireState(SimState, DeformGradMixin): """State information for batched FIRE optimization with unit cell degrees of freedom. @@ -682,7 +735,6 @@ class UnitCellFireState(SimState, DeformGradMixin): cell_masses: torch.Tensor # Optimization-specific attributes - reference_cell: torch.Tensor cell_factor: torch.Tensor pressure: torch.Tensor hydrostatic_strain: bool @@ -693,6 +745,10 @@ class UnitCellFireState(SimState, DeformGradMixin): alpha: torch.Tensor n_pos: torch.Tensor + _atom_attributes = _md_atom_attributes + _system_attributes = _fire_system_attributes + _global_attributes = _fire_global_attributes + def unit_cell_fire( model: ModelInterface, @@ -907,7 +963,7 @@ def fire_init( return fire_init, functools.partial(step_func, **step_func_kwargs) -@dataclass +@dataclass(kw_only=True) class FrechetCellFIREState(SimState, DeformGradMixin): """State class for batched FIRE optimization with Frechet cell derivatives. 
@@ -964,7 +1020,6 @@ class FrechetCellFIREState(SimState, DeformGradMixin): stress: torch.Tensor # Optimization-specific attributes - reference_cell: torch.Tensor cell_factor: torch.Tensor pressure: torch.Tensor hydrostatic_strain: bool @@ -981,6 +1036,10 @@ class FrechetCellFIREState(SimState, DeformGradMixin): alpha: torch.Tensor n_pos: torch.Tensor + _atom_attributes = _md_atom_attributes + _system_attributes = _fire_system_attributes + _global_attributes = _fire_global_attributes + def frechet_cell_fire( model: ModelInterface, @@ -1261,7 +1320,7 @@ def _vv_fire_step( # noqa: C901, PLR0915 state.positions[nan_velocities] ) if is_cell_optimization: - if not isinstance(state, AnyFireCellState): + if not isinstance(state, get_args(AnyFireCellState)): raise ValueError( f"Cell optimization requires one of {get_args(AnyFireCellState)}." ) @@ -1483,7 +1542,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 ) forces = state.forces if is_cell_optimization: - if not isinstance(state, AnyFireCellState): + if not isinstance(state, get_args(AnyFireCellState)): raise ValueError( f"Cell optimization requires one of {get_args(AnyFireCellState)}." 
) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 8f2917da..14164375 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -541,6 +541,13 @@ class StaticState(type(state)): forces: torch.Tensor stress: torch.Tensor + _atom_attributes = ( + state._atom_attributes | {"forces"} # noqa: SLF001 + ) + _system_attributes = ( + state._system_attributes | {"energy", "stress"} # noqa: SLF001 + ) + all_props: list[dict[str, torch.Tensor]] = [] og_filenames = trajectory_reporter.filenames @@ -564,8 +571,12 @@ class StaticState(type(state)): sub_state = StaticState( **vars(sub_state), energy=model_outputs["energy"], - forces=model_outputs["forces"] if model.compute_forces else None, - stress=model_outputs["stress"] if model.compute_stress else None, + forces=model_outputs["forces"] + if model.compute_forces + else torch.full_like(sub_state.positions, fill_value=float("nan")), + stress=model_outputs["stress"] + if model.compute_stress + else torch.full_like(sub_state.cell, fill_value=float("nan")), ) props = trajectory_reporter.report(sub_state, 0, model=model) diff --git a/torch_sim/state.py b/torch_sim/state.py index d2ec8351..fa898e1c 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -8,8 +8,10 @@ import importlib import typing import warnings +from collections import defaultdict +from collections.abc import Generator from dataclasses import dataclass -from typing import TYPE_CHECKING, Literal, Self, cast +from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self, cast import torch @@ -83,6 +85,15 @@ class SimState: atomic_numbers: torch.Tensor system_idx: torch.Tensor + _atom_attributes: ClassVar[set[str]] = { + "positions", + "masses", + "atomic_numbers", + "system_idx", + } + _system_attributes: ClassVar[set[str]] = {"cell"} + _global_attributes: ClassVar[set[str]] = {"pbc"} + def __init__( self, positions: torch.Tensor, @@ -204,7 +215,7 @@ def n_atoms_per_batch(self) -> torch.Tensor: return self.n_atoms_per_system @property - def 
batch(self) -> torch.Tensor | None: + def batch(self) -> torch.Tensor: """System indices. deprecated:: @@ -405,7 +416,16 @@ def __init_subclass__(cls, **kwargs) -> None: """Enforce that all derived states cannot have tensor attributes that can also be None. This is because torch.concatenate cannot concat between a tensor and a None. See https://github.com/Radical-AI/torch-sim/pull/219 for more details. + + Also enforce that each child-class attribute is listed in _atom_attributes, + _system_attributes, or _global_attributes. """ + cls._assert_no_tensor_attributes_can_be_none() + cls._assert_all_attributes_have_defined_scope() + super().__init_subclass__(**kwargs) + + @classmethod + def _assert_no_tensor_attributes_can_be_none(cls) -> None: # We need to use get_type_hints to correctly inspect the types type_hints = typing.get_type_hints(cls) for attr_name, attr_typehint in type_hints.items(): @@ -426,14 +446,65 @@ def __init_subclass__(cls, **kwargs) -> None: "the tensor with dummy values and track the 'None' case." 
) - super().__init_subclass__(**kwargs) + @classmethod + def _assert_all_attributes_have_defined_scope(cls) -> None: + all_defined_attributes = ( + cls._atom_attributes | cls._system_attributes | cls._global_attributes + ) + # 1) assert that no attribute is defined twice in all_defined_attributes + duplicates = ( + (cls._atom_attributes & cls._system_attributes) + | (cls._atom_attributes & cls._global_attributes) + | (cls._system_attributes & cls._global_attributes) + ) + if duplicates: + raise TypeError( + f"Attributes {duplicates} are declared multiple times in {cls.__name__} " + "in _atom_attributes, _system_attributes, or _global_attributes" + ) + + # 2) assert that all attributes are defined in all_defined_attributes + all_annotations = {} + for c in cls.mro(): + if hasattr(c, "__annotations__"): + all_annotations.update(c.__annotations__) + + attributes_to_check = set(vars(cls).keys()) | set(all_annotations.keys()) + + for attr_name in attributes_to_check: + is_special_attribute = attr_name.startswith("__") + is_property = attr_name in vars(cls) and isinstance( + vars(cls).get(attr_name), property + ) + is_method = hasattr(cls, attr_name) and callable(getattr(cls, attr_name)) + is_class_variable = ( + # Note: _atom_attributes, _system_attributes, and _global_attributes + # are all class variables + typing.get_origin(all_annotations.get(attr_name)) is typing.ClassVar + ) + if is_special_attribute or is_property or is_method or is_class_variable: + continue + if attr_name not in all_defined_attributes: + raise TypeError( + f"Attribute '{attr_name}' is not defined in {cls.__name__} in any " + "of _atom_attributes, _system_attributes, or _global_attributes" + ) + + +@dataclass(kw_only=True) class DeformGradMixin: """Mixin for states that support deformation gradients.""" reference_cell: torch.Tensor - row_vector_cell: torch.Tensor + + _system_attributes: ClassVar[set[str]] = {"reference_cell"} + + if TYPE_CHECKING: + # define this under a TYPE_CHECKING block to 
avoid it being included in the + # dataclass __init__ during runtime + row_vector_cell: torch.Tensor @property def reference_row_vector_cell(self) -> torch.Tensor: @@ -543,125 +614,34 @@ def state_to_device( return type(state)(**attrs) -def infer_property_scope( - state: SimState, - ambiguous_handling: Literal["error", "globalize", "globalize_warn"] = "error", -) -> dict[Literal["global", "per_atom", "per_system"], list[str]]: - """Infer whether a property is global, per-atom, or per-system. - - Analyzes the shapes of tensor attributes to determine their scope within - the atomistic system representation. - - Args: - state (SimState): The state to analyze - ambiguous_handling ("error" | "globalize" | "globalize_warn"): How to - handle properties with ambiguous scope. Options: - - "error": Raise an error for ambiguous properties - - "globalize": Treat ambiguous properties as global - - "globalize_warn": Treat ambiguous properties as global with a warning - - Returns: - dict[Literal["global", "per_atom", "per_system"], list[str]]: Map of scope - category to list of property names - - Raises: - ValueError: If n_atoms equals n_systems (making scope inference ambiguous) or - if ambiguous_handling="error" and an ambiguous property is encountered - """ - # TODO: this cannot effectively resolve global properties with - # length of n_atoms or n_systems, they will be classified incorrectly, - # no clear fix - - if state.n_atoms == state.n_systems: - raise ValueError( - f"n_atoms ({state.n_atoms}) and n_systems ({state.n_systems}) are equal, " - "which means shapes cannot be inferred unambiguously." 
- ) - - scope = { - "global": [], - "per_atom": [], - "per_system": [], - } - - # Iterate through all attributes - for attr_name, attr_value in vars(state).items(): - # Handle scalar values (global properties) - if not isinstance(attr_value, torch.Tensor): - scope["global"].append(attr_name) - continue - - # Handle tensor properties based on shape - shape = attr_value.shape - - # Empty tensor case - if len(shape) == 0: - scope["global"].append(attr_name) - # Vector/matrix with first dimension matching number of atoms - elif shape[0] == state.n_atoms: - scope["per_atom"].append(attr_name) - # Tensor with first dimension matching number of systems - elif shape[0] == state.n_systems: - scope["per_system"].append(attr_name) - # Any other shape is ambiguous - elif ambiguous_handling == "error": - raise ValueError( - f"Cannot categorize property '{attr_name}' with shape {shape}. " - f"Expected first dimension to be either {state.n_atoms} (per-atom) or " - f"{state.n_systems} (per-system), or a scalar (global)." - ) - elif ambiguous_handling in ("globalize", "globalize_warn"): - scope["global"].append(attr_name) - - if ambiguous_handling == "globalize_warn": - warnings.warn( - f"Property '{attr_name}' with shape {shape} is ambiguous, " - "treating as global. This may lead to unexpected behavior " - "and suggests the State is not being used as intended.", - stacklevel=1, - ) - - return scope - - -def _get_property_attrs( - state: SimState, ambiguous_handling: Literal["error", "globalize"] = "error" -) -> dict[str, dict]: - """Get global, per-atom, and per-system attributes from a state. - - Categorizes all attributes of the state based on their scope - (global, per-atom, or per-system). +def get_attrs_for_scope( + state: SimState, scope: Literal["per-atom", "per-system", "global"] +) -> Generator[tuple[str, Any], None, None]: + """Get attributes for a given scope. 
Args: - state (SimState): The state to extract attributes from - ambiguous_handling ("error" | "globalize"): How to handle ambiguous - properties + state (SimState): The state to get attributes for + scope (Literal["per-atom", "per-system", "global"]): The scope to get + attributes for Returns: - dict[str, dict]: Keys are 'global', 'per_atom', and 'per_system', each - containing a dictionary of attribute names to values + Generator[tuple[str, Any], None, None]: A generator of attribute names and values """ - scope = infer_property_scope(state, ambiguous_handling=ambiguous_handling) - - attrs = {"global": {}, "per_atom": {}, "per_system": {}} - - # Process global properties - for attr_name in scope["global"]: - attrs["global"][attr_name] = getattr(state, attr_name) - - # Process per-atom properties - for attr_name in scope["per_atom"]: - attrs["per_atom"][attr_name] = getattr(state, attr_name) - - # Process per-system properties - for attr_name in scope["per_system"]: - attrs["per_system"][attr_name] = getattr(state, attr_name) - - return attrs + match scope: + case "per-atom": + attr_names = state._atom_attributes # noqa: SLF001 + case "per-system": + attr_names = state._system_attributes # noqa: SLF001 + case "global": + attr_names = state._global_attributes # noqa: SLF001 + case _: + raise ValueError(f"Unknown scope: {scope!r}") + for attr_name in attr_names: + yield attr_name, getattr(state, attr_name) def _filter_attrs_by_mask( - attrs: dict[str, dict], + state: SimState, atom_mask: torch.Tensor, system_mask: torch.Tensor, ) -> dict: @@ -670,8 +650,7 @@ def _filter_attrs_by_mask( Selects subsets of attributes based on boolean masks for atoms and systems. 
Args: - attrs (dict[str, dict]): Keys are 'global', 'per_atom', and 'per_system', each - containing a dictionary of attribute names to values + state (SimState): The state to filter atom_mask (torch.Tensor): Boolean mask for atoms to include with shape (n_atoms,) system_mask (torch.Tensor): Boolean mask for systems to include with shape @@ -680,13 +659,11 @@ def _filter_attrs_by_mask( Returns: dict: Filtered attributes with appropriate handling for each scope """ - filtered_attrs = {} - # Copy global attributes directly - filtered_attrs.update(attrs["global"]) + filtered_attrs = dict(get_attrs_for_scope(state, "global")) # Filter per-atom attributes - for attr_name, attr_value in attrs["per_atom"].items(): + for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"): if attr_name == "system_idx": # Get the old system indices for the selected atoms old_system_idxs = attr_value[atom_mask] @@ -710,7 +687,7 @@ def _filter_attrs_by_mask( filtered_attrs[attr_name] = attr_value[atom_mask] # Filter per-system attributes - for attr_name, attr_value in attrs["per_system"].items(): + for attr_name, attr_value in get_attrs_for_scope(state, "per-system"): filtered_attrs[attr_name] = attr_value[system_mask] return filtered_attrs @@ -718,7 +695,6 @@ def _filter_attrs_by_mask( def _split_state( state: SimStateVar, - ambiguous_handling: Literal["error", "globalize"] = "error", ) -> list[SimStateVar]: """Split a SimState into a list of states, each containing a single system. @@ -727,33 +703,28 @@ def _split_state( Args: state (SimState): The SimState to split - ambiguous_handling ("error" | "globalize"): How to handle ambiguous - properties. If "error", an error is raised if a property has ambiguous - scope. If "globalize", properties with ambiguous scope are treated as - global. 
Returns: list[SimState]: A list of SimState objects, each containing a single system """ - attrs = _get_property_attrs(state, ambiguous_handling) system_sizes = torch.bincount(state.system_idx).tolist() - # Split per-atom attributes by system split_per_atom = {} - for attr_name, attr_value in attrs["per_atom"].items(): - if attr_name == "system_idx": - continue - split_per_atom[attr_name] = torch.split(attr_value, system_sizes, dim=0) + for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"): + if attr_name != "system_idx": + split_per_atom[attr_name] = torch.split(attr_value, system_sizes, dim=0) - # Split per-system attributes into individual elements split_per_system = {} - for attr_name, attr_value in attrs["per_system"].items(): + for attr_name, attr_value in get_attrs_for_scope(state, "per-system"): split_per_system[attr_name] = torch.split(attr_value, 1, dim=0) + global_attrs = dict(get_attrs_for_scope(state, "global")) + # Create a state for each system states = [] - for i in range(state.n_systems): + n_systems = len(system_sizes) + for i in range(n_systems): system_attrs = { # Create a system tensor with all zeros for this system "system_idx": torch.zeros( @@ -767,7 +738,7 @@ def _split_state( for attr_name in split_per_system }, # Add the global attributes - **attrs["global"], + **global_attrs, } states.append(type(state)(**system_attrs)) @@ -777,7 +748,6 @@ def _split_state( def _pop_states( state: SimState, pop_indices: list[int] | torch.Tensor, - ambiguous_handling: Literal["error", "globalize"] = "error", ) -> tuple[SimState, list[SimState]]: """Pop off the states with the specified indices. @@ -786,10 +756,6 @@ def _pop_states( Args: state (SimState): The SimState to modify pop_indices (list[int] | torch.Tensor): The system indices to extract and remove - ambiguous_handling ("error" | "globalize"): How to handle ambiguous - properties. If "error", an error is raised if a property has ambiguous - scope. 
If "globalize", properties with ambiguous scope are treated as - global. Returns: tuple[SimState, list[SimState]]: A tuple containing: @@ -805,8 +771,6 @@ def _pop_states( if isinstance(pop_indices, list): pop_indices = torch.tensor(pop_indices, device=state.device, dtype=torch.int64) - attrs = _get_property_attrs(state, ambiguous_handling) - # Create masks for the atoms and systems to keep and pop system_range = torch.arange(state.n_systems, device=state.device) pop_system_mask = torch.isin(system_range, pop_indices) @@ -816,15 +780,15 @@ def _pop_states( keep_atom_mask = ~pop_atom_mask # Filter attributes for keep and pop states - keep_attrs = _filter_attrs_by_mask(attrs, keep_atom_mask, keep_system_mask) - pop_attrs = _filter_attrs_by_mask(attrs, pop_atom_mask, pop_system_mask) + keep_attrs = _filter_attrs_by_mask(state, keep_atom_mask, keep_system_mask) + pop_attrs = _filter_attrs_by_mask(state, pop_atom_mask, pop_system_mask) # Create the keep state keep_state = type(state)(**keep_attrs) # Create and split the pop state pop_state = type(state)(**pop_attrs) - pop_states = _split_state(pop_state, ambiguous_handling) + pop_states = _split_state(pop_state) return keep_state, pop_states @@ -832,7 +796,6 @@ def _pop_states( def _slice_state( state: SimStateVar, system_indices: list[int] | torch.Tensor, - ambiguous_handling: Literal["error", "globalize"] = "error", ) -> SimStateVar: """Slice a substate from the SimState containing only the specified system indices. @@ -843,10 +806,6 @@ def _slice_state( state (SimState): The state to slice system_indices (list[int] | torch.Tensor): System indices to include in the sliced state - ambiguous_handling ("error" | "globalize"): How to handle ambiguous - properties. If "error", an error is raised if a property has ambiguous - scope. If "globalize", properties with ambiguous scope are treated as - global. 
Returns: SimState: A new SimState object containing only the specified systems @@ -862,15 +821,13 @@ def _slice_state( if len(system_indices) == 0: raise ValueError("system_indices cannot be empty") - attrs = _get_property_attrs(state, ambiguous_handling) - # Create masks for the atoms and systems to include system_range = torch.arange(state.n_systems, device=state.device) system_mask = torch.isin(system_range, system_indices) atom_mask = torch.isin(state.system_idx, system_indices) # Filter attributes - filtered_attrs = _filter_attrs_by_mask(attrs, atom_mask, system_mask) + filtered_attrs = _filter_attrs_by_mask(state, atom_mask, system_mask) # Create the sliced state return type(state)(**filtered_attrs) @@ -911,19 +868,12 @@ def concatenate_states( # Use the target device or default to the first state's device target_device = device or first_state.device - # Get property scopes from the first state to identify - # global/per-atom/per-system properties - first_scope = infer_property_scope(first_state) - global_props = set(first_scope["global"]) - per_atom_props = set(first_scope["per_atom"]) - per_system_props = set(first_scope["per_system"]) - # Initialize result with global properties from first state - concatenated = {prop: getattr(first_state, prop) for prop in global_props} + concatenated = dict(get_attrs_for_scope(first_state, "global")) # Pre-allocate lists for tensors to concatenate - per_atom_tensors = {prop: [] for prop in per_atom_props} - per_system_tensors = {prop: [] for prop in per_system_props} + per_atom_tensors = defaultdict(list) + per_system_tensors = defaultdict(list) new_system_indices = [] system_offset = 0 @@ -934,14 +884,15 @@ def concatenate_states( state = state_to_device(state, target_device) # Collect per-atom properties - for prop in per_atom_props: - # if hasattr(state, prop): - per_atom_tensors[prop].append(getattr(state, prop)) + for prop, val in get_attrs_for_scope(state, "per-atom"): + if prop == "system_idx": + # skip 
system_idx, it will be handled below + continue + per_atom_tensors[prop].append(val) # Collect per-system properties - for prop in per_system_props: - # if hasattr(state, prop): - per_system_tensors[prop].append(getattr(state, prop)) + for prop, val in get_attrs_for_scope(state, "per-system"): + per_system_tensors[prop].append(val) # Update system indices num_systems = state.n_systems