From 5cfd5526a45daa17cd22a54435a82e0230564605 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sun, 27 Jul 2025 03:25:01 +0000 Subject: [PATCH 01/32] use hardcoded attr names add test for the new pre-defined attributes wip make fire simstate attributes predefined rename features to attributes define attribute scope for fire state consolidate attribute definitions and add it to npt rm infer_property_scope import __init_subclass__ to enforce that all attributes are specified find callable attributes do not check some user attributes filter for @properties __init_subclass__ doesn't catch for static state since it's a grandchild state added documentation for get_attrs_for_scope more examples for the other scopes cleaner documentation rename batch to system more docs cleanup test for running the init subclass more tests define more scope for the integrators split state handles none properties a bit better --- .../scripts/3_Dynamics/3.10_Hybrid_swap_mc.py | 1 + examples/tutorials/state_tutorial.py | 27 +- pyproject.toml | 1 + tests/test_state.py | 73 +++-- torch_sim/integrators/md.py | 3 + torch_sim/integrators/npt.py | 27 ++ torch_sim/integrators/nvt.py | 6 + torch_sim/monte_carlo.py | 3 + torch_sim/optimizers.py | 52 ++++ torch_sim/runners.py | 3 + torch_sim/state.py | 280 +++++++----------- 11 files changed, 272 insertions(+), 204 deletions(-) diff --git a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py index 6fcb50e2..ab64a2da 100644 --- a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py +++ b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py @@ -76,6 +76,7 @@ class HybridSwapMCState(MDState): """ last_permutation: torch.Tensor + _system_attributes = (*MDState._system_attributes, "last_permutation") # noqa: SLF001 nvt_init, nvt_step = nvt_langevin(model=model, dt=0.002, kT=kT) diff --git a/examples/tutorials/state_tutorial.py b/examples/tutorials/state_tutorial.py index 0d3eca96..3c705262 100644 --- a/examples/tutorials/state_tutorial.py +++ b/examples/tutorials/state_tutorial.py @@ -79,16 +79,28 @@ the base SimState. Names are singular. * Global attributes have any other shape or type, just `pbc` here. Names are singular. -You can use the `infer_property_scope` function to analyze a state's properties. This +For TorchSim to know which attributes are atomwise, systemwise, and global, each attribute's +name is explicitly defined in the `_atom_attributes`, `_system_attributes`, and `_global_attributes`: + +_atom_attributes = ("positions", "masses", "atomic_numbers", "system_idx") +_system_attributes = ("cell",) +_global_attributes = ("pbc",) + +You can use the `get_attrs_for_scope` generator function to analyze a state's properties. This is mostly used internally but can be useful for debugging. """ # %% -from torch_sim.state import infer_property_scope +from torch_sim.state import get_attrs_for_scope -scope = infer_property_scope(si_state) -print(scope) +# loop through each attribute: +for attr_name, attr_value in get_attrs_for_scope(si_state, "per-atom"): + print(f"per-atom attribute: {attr_name}") + print(f"value: {attr_value}") +# or access the attributes via a dict: +print("Per-system attributes:", dict(get_attrs_for_scope(si_state, "per-system"))) # noqa: E501 +print("Global attributes:", dict(get_attrs_for_scope(si_state, "global"))) # %% [markdown] """ @@ -257,10 +269,9 @@ ) print("MDState properties:") -scope = infer_property_scope(md_state) -print("Global properties:", scope["global"]) -print("Per-atom properties:", scope["per_atom"]) -print("Per-system properties:", scope["per_system"]) +print("Per-atom attributes:", dict(get_attrs_for_scope(si_state, "per-atom"))) +print("Per-system attributes:", dict(get_attrs_for_scope(si_state, "per-system"))) +print("Global attributes:", dict(get_attrs_for_scope(si_state, "global"))) # %% [markdown] diff --git a/pyproject.toml b/pyproject.toml index 723edafb..f59e7ed3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ test = [ "pymatgen>=2024.11.3", "pytest-cov>=6", "pytest>=8", + "pytest-xdist>=3.8.0", ] io = ["ase>=3.24", "phonopy>=2.37.0", "pymatgen>=2024.11.3"] mace = ["mace-torch>=0.3.12"] diff --git a/tests/test_state.py b/tests/test_state.py index 81109bf3..ea19bec8 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -13,7 +13,7 @@ _pop_states, _slice_state, concatenate_states, - infer_property_scope, + get_attrs_for_scope, initialize_state, ) @@ -24,38 +24,59 @@ from pymatgen.core import Structure -def test_infer_sim_state_property_scope(si_sim_state: ts.SimState) -> None: - """Test inference of property scope.""" - scope = infer_property_scope(si_sim_state) - assert set(scope["global"]) == {"pbc"} - assert set(scope["per_atom"]) == { +def test_get_attrs_for_scope(si_sim_state: ts.SimState) -> None: + """Test getting attributes for a scope.""" + per_atom_attrs = dict(get_attrs_for_scope(si_sim_state, "per-atom")) + assert set(per_atom_attrs.keys()) == { "positions", "masses", "atomic_numbers", "system_idx", } - assert set(scope["per_system"]) == {"cell"} + per_system_attrs = dict(get_attrs_for_scope(si_sim_state, "per-system")) + assert set(per_system_attrs.keys()) == {"cell"} + global_attrs = dict(get_attrs_for_scope(si_sim_state, "global")) + assert set(global_attrs.keys()) == {"pbc"} -def test_infer_md_state_property_scope(si_sim_state: ts.SimState) -> None: - """Test inference of property scope.""" - state = MDState( - **asdict(si_sim_state), - momenta=torch.randn_like(si_sim_state.positions), - forces=torch.randn_like(si_sim_state.positions), - energy=torch.zeros((1,)), - ) - scope = infer_property_scope(state) - assert set(scope["global"]) == {"pbc"} - assert set(scope["per_atom"]) == { - "positions", - "masses", - "atomic_numbers", - "system_idx", - "forces", - "momenta", - } - assert set(scope["per_system"]) == {"cell", "energy"} +def test_all_attributes_must_be_specified_in_scopes() -> None: + """Test that an error is raised when we forget to specify the scope + for an attribute in a child SimState class.""" + with pytest.raises(TypeError) as excinfo: + + class ChildState(SimState): + attribute_specified_in_scopes: bool + attribute_not_specified_in_scopes: bool + + _atom_attributes = ( + *SimState._atom_attributes, # noqa: SLF001 + "attribute_specified_in_scopes", + ) + + assert "attribute_not_specified_in_scopes" in str(excinfo.value) + assert "attribute_specified_in_scopes" not in str(excinfo.value) + + +def test_no_duplicate_attributes_in_scopes() -> None: + """Test that no attributes are specified in multiple scopes.""" + + # Capture the exception information using "as excinfo" + with pytest.raises(TypeError) as excinfo: + + class ChildState(SimState): + duplicated_attribute: bool + + _system_attributes = ( + *SimState._atom_attributes, # noqa: SLF001 + "duplicated_attribute", + ) + _global_attributes = ( + *SimState._global_attributes, # noqa: SLF001 + "duplicated_attribute", + ) + + assert "are declared multiple times" in str(excinfo.value) + assert "duplicated_attribute" in str(excinfo.value) def test_slice_substate( diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index 069f62eb..ae367456 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -41,6 +41,9 @@ class MDState(SimState): energy: torch.Tensor forces: torch.Tensor + _atom_attributes = (*SimState._atom_attributes, "momenta", "forces") # noqa: SLF001 + _system_attributes = (*SimState._system_attributes, "energy") # noqa: SLF001 + @property def velocities(self) -> torch.Tensor: """Velocities calculated from momenta and masses with shape diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index e1ac1d39..c3c3d2af 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -15,6 +15,7 @@ construct_nose_hoover_chain, ) from torch_sim.models.interface import ModelInterface +from torch_sim.optimizers import md_atom_attributes from torch_sim.quantities import calc_kinetic_energy from torch_sim.state import SimState from torch_sim.typing import StateDict @@ -67,6 +68,17 @@ class NPTLangevinState(SimState): cell_velocities: torch.Tensor cell_masses: torch.Tensor + _atom_attributes = md_atom_attributes + _system_attributes = ( + *SimState._system_attributes, # noqa: SLF001 + "stress", + "cell_positions", + "cell_velocities", + "cell_masses", + "reference_cell", + "energy", + ) + @property def momenta(self) -> torch.Tensor: """Calculate momenta from velocities and masses.""" @@ -867,6 +879,21 @@ class NPTNoseHooverState(MDState): barostat: NoseHooverChain barostat_fns: NoseHooverChainFns + _system_attributes = ( + *MDState._system_attributes, # noqa: SLF001 + "reference_cell", + "cell_position", + "cell_momentum", + "cell_mass", + ) + _global_attributes = ( + *MDState._global_attributes, # noqa: SLF001 + "thermostat", + "barostat", + "thermostat_fns", + "barostat_fns", + ) + @property def velocities(self) -> torch.Tensor: """Calculate particle velocities from momenta and masses. diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 8309afa1..69922b76 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -266,6 +266,12 @@ class NVTNoseHooverState(MDState): chain: NoseHooverChain _chain_fns: NoseHooverChainFns + _global_attributes = ( + *MDState._global_attributes, # noqa: SLF001 + "chain", + "_chain_fns", + ) + @property def velocities(self) -> torch.Tensor: """Velocities calculated from momenta and masses with shape diff --git a/torch_sim/monte_carlo.py b/torch_sim/monte_carlo.py index 64aad3c7..1184f10d 100644 --- a/torch_sim/monte_carlo.py +++ b/torch_sim/monte_carlo.py @@ -36,6 +36,9 @@ class SwapMCState(SimState): energy: torch.Tensor last_permutation: torch.Tensor + _atom_attributes = (*SimState._atom_attributes, "last_permutation") # noqa: SLF001 + _system_attributes = (*SimState._system_attributes, "energy") # noqa: SLF001 + def generate_swaps( state: SimState, generator: torch.Generator | None = None diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 9a41cd8c..9e9a16e8 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -33,6 +33,28 @@ MdFlavor = Literal["vv_fire", "ase_fire"] vv_fire_key, ase_fire_key = get_args(MdFlavor) +md_atom_attributes = (*SimState._atom_attributes, "forces", "velocities") # noqa: SLF001 +_fire_system_attributes = ( + *SimState._system_attributes, # noqa: SLF001 + "energy", + "stress", + "cell_positions", + "cell_velocities", + "cell_forces", + "cell_masses", + "reference_cell", + "cell_factor", + "pressure", + "dt", + "alpha", + "n_pos", +) +_fire_global_attributes = ( + *SimState._global_attributes, # noqa: SLF001 + "hydrostatic_strain", + "constant_volume", +) + @dataclass class GDState(SimState): @@ -56,6 +78,9 @@ class GDState(SimState): forces: torch.Tensor energy: torch.Tensor + _atom_attributes = (*SimState._atom_attributes, "forces") # noqa: SLF001 + _system_attributes = (*SimState._system_attributes, "energy") # noqa: SLF001 + def gradient_descent( model: ModelInterface, *, lr: torch.Tensor | float = 0.01 @@ -195,6 +220,22 @@ class UnitCellGDState(GDState, DeformGradMixin): cell_forces: torch.Tensor cell_masses: torch.Tensor + _system_attributes = ( + *GDState._system_attributes, # noqa: SLF001 + "cell_forces", + "pressure", + "stress", + "cell_positions", + "cell_factor", + "cell_masses", + ) + _global_attributes = ( + *GDState._global_attributes, # noqa: SLF001 + "reference_cell", + "hydrostatic_strain", + "constant_volume", + ) + def unit_cell_gradient_descent( # noqa: PLR0915, C901 model: ModelInterface, @@ -482,6 +523,9 @@ class FireState(SimState): alpha: torch.Tensor n_pos: torch.Tensor + _atom_attributes = md_atom_attributes + _system_attributes = (*SimState._system_attributes, "energy", "dt", "alpha", "n_pos") # noqa: SLF001 + def fire( model: ModelInterface, @@ -693,6 +737,10 @@ class UnitCellFireState(SimState, DeformGradMixin): alpha: torch.Tensor n_pos: torch.Tensor + _atom_attributes = md_atom_attributes + _system_attributes = _fire_system_attributes + _global_attributes = _fire_global_attributes + def unit_cell_fire( model: ModelInterface, @@ -981,6 +1029,10 @@ class FrechetCellFIREState(SimState, DeformGradMixin): alpha: torch.Tensor n_pos: torch.Tensor + _atom_attributes = md_atom_attributes + _system_attributes = _fire_system_attributes + _global_attributes = _fire_global_attributes + def frechet_cell_fire( model: ModelInterface, diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 8f2917da..89401459 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -541,6 +541,9 @@ class StaticState(type(state)): forces: torch.Tensor stress: torch.Tensor + _atom_attributes = (*state._atom_attributes, "forces") # noqa: SLF001 + _system_attributes = (*state._system_attributes, "energy", "stress") # noqa: SLF001 + all_props: list[dict[str, torch.Tensor]] = [] og_filenames = trajectory_reporter.filenames diff --git a/torch_sim/state.py b/torch_sim/state.py index d2ec8351..9f9b5f67 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -8,8 +8,10 @@ import importlib import typing import warnings -from dataclasses import dataclass -from typing import TYPE_CHECKING, Literal, Self, cast +from collections import defaultdict +from collections.abc import Generator +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Literal, Self, cast import torch @@ -83,6 +85,10 @@ class SimState: atomic_numbers: torch.Tensor system_idx: torch.Tensor + _atom_attributes = ("positions", "masses", "atomic_numbers", "system_idx") + _system_attributes = ("cell",) + _global_attributes = ("pbc",) + def __init__( self, positions: torch.Tensor, @@ -405,6 +411,9 @@ def __init_subclass__(cls, **kwargs) -> None: """Enforce that all derived states cannot have tensor attributes that can also be None. This is because torch.concatenate cannot concat between a tensor and a None. See https://github.com/Radical-AI/torch-sim/pull/219 for more details. + + Also enforce all of child classes's attributes are specified in _atom_attributes, + _system_attributes, or _global_attributes. """ # We need to use get_type_hints to correctly inspect the types type_hints = typing.get_type_hints(cls) @@ -426,6 +435,50 @@ def __init_subclass__(cls, **kwargs) -> None: "the tensor with dummy values and track the 'None' case." ) + all_defined_attributes = ( + cls._atom_attributes + cls._system_attributes + cls._global_attributes + ) + # 1) assert that no attribute is defined twice in all_defined_attributes + duplicates = [ + x for x in all_defined_attributes if all_defined_attributes.count(x) > 1 + ] + if duplicates: + raise TypeError( + f"Attributes {duplicates} are declared multiple times in {cls.__name__} " + "in _atom_attributes, _system_attributes, or _global_attributes" + ) + + # 2) assert that all attributes are defined in all_defined_attributes + all_annotations = {} + for c in reversed(cls.mro()): + if hasattr(c, "__annotations__"): + all_annotations.update(c.__annotations__) + + attributes_to_check = set(vars(cls).keys()) | set(all_annotations.keys()) + + for attr_name in attributes_to_check: + if ( + attr_name.startswith("__") + or (hasattr(cls, attr_name) and callable(getattr(cls, attr_name))) + or ( + attr_name in vars(cls) + and isinstance(vars(cls).get(attr_name), property) + ) + ): + continue + + if attr_name in [ + "_atom_attributes", + "_system_attributes", + "_global_attributes", + ]: + continue + + if attr_name not in all_defined_attributes: + raise TypeError( + f"Attribute '{attr_name}' is not defined in {cls.__name__} in any " + "of _atom_attributes, _system_attributes, or _global_attributes" + ) super().__init_subclass__(**kwargs) @@ -543,125 +596,32 @@ def state_to_device( return type(state)(**attrs) -def infer_property_scope( - state: SimState, - ambiguous_handling: Literal["error", "globalize", "globalize_warn"] = "error", -) -> dict[Literal["global", "per_atom", "per_system"], list[str]]: - """Infer whether a property is global, per-atom, or per-system. - - Analyzes the shapes of tensor attributes to determine their scope within - the atomistic system representation. - - Args: - state (SimState): The state to analyze - ambiguous_handling ("error" | "globalize" | "globalize_warn"): How to - handle properties with ambiguous scope. Options: - - "error": Raise an error for ambiguous properties - - "globalize": Treat ambiguous properties as global - - "globalize_warn": Treat ambiguous properties as global with a warning - - Returns: - dict[Literal["global", "per_atom", "per_system"], list[str]]: Map of scope - category to list of property names - - Raises: - ValueError: If n_atoms equals n_systems (making scope inference ambiguous) or - if ambiguous_handling="error" and an ambiguous property is encountered - """ - # TODO: this cannot effectively resolve global properties with - # length of n_atoms or n_systems, they will be classified incorrectly, - # no clear fix - - if state.n_atoms == state.n_systems: - raise ValueError( - f"n_atoms ({state.n_atoms}) and n_systems ({state.n_systems}) are equal, " - "which means shapes cannot be inferred unambiguously." - ) - - scope = { - "global": [], - "per_atom": [], - "per_system": [], - } - - # Iterate through all attributes - for attr_name, attr_value in vars(state).items(): - # Handle scalar values (global properties) - if not isinstance(attr_value, torch.Tensor): - scope["global"].append(attr_name) - continue - - # Handle tensor properties based on shape - shape = attr_value.shape - - # Empty tensor case - if len(shape) == 0: - scope["global"].append(attr_name) - # Vector/matrix with first dimension matching number of atoms - elif shape[0] == state.n_atoms: - scope["per_atom"].append(attr_name) - # Tensor with first dimension matching number of systems - elif shape[0] == state.n_systems: - scope["per_system"].append(attr_name) - # Any other shape is ambiguous - elif ambiguous_handling == "error": - raise ValueError( - f"Cannot categorize property '{attr_name}' with shape {shape}. " - f"Expected first dimension to be either {state.n_atoms} (per-atom) or " - f"{state.n_systems} (per-system), or a scalar (global)." - ) - elif ambiguous_handling in ("globalize", "globalize_warn"): - scope["global"].append(attr_name) - - if ambiguous_handling == "globalize_warn": - warnings.warn( - f"Property '{attr_name}' with shape {shape} is ambiguous, " - "treating as global. This may lead to unexpected behavior " - "and suggests the State is not being used as intended.", - stacklevel=1, - ) - - return scope - - -def _get_property_attrs( - state: SimState, ambiguous_handling: Literal["error", "globalize"] = "error" -) -> dict[str, dict]: - """Get global, per-atom, and per-system attributes from a state. - - Categorizes all attributes of the state based on their scope - (global, per-atom, or per-system). +def get_attrs_for_scope( + state: SimState, scope: Literal["per-atom", "per-system", "global"] +) -> Generator[tuple[str, Any], None, None]: + """Get attributes for a given scope. Args: - state (SimState): The state to extract attributes from - ambiguous_handling ("error" | "globalize"): How to handle ambiguous - properties + state (SimState): The state to get attributes for + scope (Literal["per-atom", "per-system", "global"]): The scope to get + attributes for Returns: - dict[str, dict]: Keys are 'global', 'per_atom', and 'per_system', each - containing a dictionary of attribute names to values + Generator[tuple[str, Any], None, None]: A generator of attribute names and values """ - scope = infer_property_scope(state, ambiguous_handling=ambiguous_handling) - - attrs = {"global": {}, "per_atom": {}, "per_system": {}} - - # Process global properties - for attr_name in scope["global"]: - attrs["global"][attr_name] = getattr(state, attr_name) - - # Process per-atom properties - for attr_name in scope["per_atom"]: - attrs["per_atom"][attr_name] = getattr(state, attr_name) - - # Process per-system properties - for attr_name in scope["per_system"]: - attrs["per_system"][attr_name] = getattr(state, attr_name) - - return attrs + match scope: + case "per-atom": + attr_names = state._atom_attributes # noqa: SLF001 + case "per-system": + attr_names = state._system_attributes # noqa: SLF001 + case "global": + attr_names = state._global_attributes # noqa: SLF001 + for attr_name in attr_names: + yield attr_name, getattr(state, attr_name) def _filter_attrs_by_mask( - attrs: dict[str, dict], + state: SimState, atom_mask: torch.Tensor, system_mask: torch.Tensor, ) -> dict: @@ -670,8 +630,7 @@ def _filter_attrs_by_mask( Selects subsets of attributes based on boolean masks for atoms and systems. Args: - attrs (dict[str, dict]): Keys are 'global', 'per_atom', and 'per_system', each - containing a dictionary of attribute names to values + state (SimState): The state to filter atom_mask (torch.Tensor): Boolean mask for atoms to include with shape (n_atoms,) system_mask (torch.Tensor): Boolean mask for systems to include with shape @@ -680,13 +639,11 @@ def _filter_attrs_by_mask( Returns: dict: Filtered attributes with appropriate handling for each scope """ - filtered_attrs = {} - # Copy global attributes directly - filtered_attrs.update(attrs["global"]) + filtered_attrs = dict(get_attrs_for_scope(state, "global")) # Filter per-atom attributes - for attr_name, attr_value in attrs["per_atom"].items(): + for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"): if attr_name == "system_idx": # Get the old system indices for the selected atoms old_system_idxs = attr_value[atom_mask] @@ -710,7 +667,7 @@ def _filter_attrs_by_mask( filtered_attrs[attr_name] = attr_value[atom_mask] # Filter per-system attributes - for attr_name, attr_value in attrs["per_system"].items(): + for attr_name, attr_value in get_attrs_for_scope(state, "per-system"): filtered_attrs[attr_name] = attr_value[system_mask] return filtered_attrs @@ -718,7 +675,6 @@ def _filter_attrs_by_mask( def _split_state( state: SimStateVar, - ambiguous_handling: Literal["error", "globalize"] = "error", ) -> list[SimStateVar]: """Split a SimState into a list of states, each containing a single system. @@ -727,33 +683,38 @@ def _split_state( Args: state (SimState): The SimState to split - ambiguous_handling ("error" | "globalize"): How to handle ambiguous - properties. If "error", an error is raised if a property has ambiguous - scope. If "globalize", properties with ambiguous scope are treated as - global. Returns: list[SimState]: A list of SimState objects, each containing a single system """ - attrs = _get_property_attrs(state, ambiguous_handling) system_sizes = torch.bincount(state.system_idx).tolist() + n_systems = len(system_sizes) - # Split per-atom attributes by system - split_per_atom = {} - for attr_name, attr_value in attrs["per_atom"].items(): - if attr_name == "system_idx": - continue - split_per_atom[attr_name] = torch.split(attr_value, system_sizes, dim=0) + def split_attr( + attr_value: torch.Tensor | None, split_sizes: list[int] + ) -> list[torch.Tensor | None]: + return ( + [None] * n_systems + if attr_value is None + else torch.split(attr_value, split_sizes) + ) - # Split per-system attributes into individual elements - split_per_system = {} - for attr_name, attr_value in attrs["per_system"].items(): - split_per_system[attr_name] = torch.split(attr_value, 1, dim=0) + split_per_atom = { + name: split_attr(value, system_sizes) + for name, value in get_attrs_for_scope(state, "per-atom") + if name != "system_idx" + } + split_per_system = { + name: split_attr(value, 1) + for name, value in get_attrs_for_scope(state, "per-system") + } + + global_attrs = dict(get_attrs_for_scope(state, "global")) # Create a state for each system states = [] - for i in range(state.n_systems): + for i in range(n_systems): system_attrs = { # Create a system tensor with all zeros for this system "system_idx": torch.zeros( @@ -767,7 +728,7 @@ def _split_state( for attr_name in split_per_system }, # Add the global attributes - **attrs["global"], + **global_attrs, } states.append(type(state)(**system_attrs)) @@ -777,7 +738,6 @@ def _split_state( def _pop_states( state: SimState, pop_indices: list[int] | torch.Tensor, - ambiguous_handling: Literal["error", "globalize"] = "error", ) -> tuple[SimState, list[SimState]]: """Pop off the states with the specified indices. @@ -786,10 +746,6 @@ def _pop_states( Args: state (SimState): The SimState to modify pop_indices (list[int] | torch.Tensor): The system indices to extract and remove - ambiguous_handling ("error" | "globalize"): How to handle ambiguous - properties. If "error", an error is raised if a property has ambiguous - scope. If "globalize", properties with ambiguous scope are treated as - global. Returns: tuple[SimState, list[SimState]]: A tuple containing: @@ -805,8 +761,6 @@ def _pop_states( if isinstance(pop_indices, list): pop_indices = torch.tensor(pop_indices, device=state.device, dtype=torch.int64) - attrs = _get_property_attrs(state, ambiguous_handling) - # Create masks for the atoms and systems to keep and pop system_range = torch.arange(state.n_systems, device=state.device) pop_system_mask = torch.isin(system_range, pop_indices) @@ -816,15 +770,15 @@ def _pop_states( keep_atom_mask = ~pop_atom_mask # Filter attributes for keep and pop states - keep_attrs = _filter_attrs_by_mask(attrs, keep_atom_mask, keep_system_mask) - pop_attrs = _filter_attrs_by_mask(attrs, pop_atom_mask, pop_system_mask) + keep_attrs = _filter_attrs_by_mask(state, keep_atom_mask, keep_system_mask) + pop_attrs = _filter_attrs_by_mask(state, pop_atom_mask, pop_system_mask) # Create the keep state keep_state = type(state)(**keep_attrs) # Create and split the pop state pop_state = type(state)(**pop_attrs) - pop_states = _split_state(pop_state, ambiguous_handling) + pop_states = _split_state(pop_state) return keep_state, pop_states @@ -832,7 +786,6 @@ def _pop_states( def _slice_state( state: SimStateVar, system_indices: list[int] | torch.Tensor, - ambiguous_handling: Literal["error", "globalize"] = "error", ) -> SimStateVar: """Slice a substate from the SimState containing only the specified system indices. @@ -843,10 +796,6 @@ def _slice_state( state (SimState): The state to slice system_indices (list[int] | torch.Tensor): System indices to include in the sliced state - ambiguous_handling ("error" | "globalize"): How to handle ambiguous - properties. If "error", an error is raised if a property has ambiguous - scope. If "globalize", properties with ambiguous scope are treated as - global. Returns: SimState: A new SimState object containing only the specified systems @@ -862,15 +811,13 @@ def _slice_state( if len(system_indices) == 0: raise ValueError("system_indices cannot be empty") - attrs = _get_property_attrs(state, ambiguous_handling) - # Create masks for the atoms and systems to include system_range = torch.arange(state.n_systems, device=state.device) system_mask = torch.isin(system_range, system_indices) atom_mask = torch.isin(state.system_idx, system_indices) # Filter attributes - filtered_attrs = _filter_attrs_by_mask(attrs, atom_mask, system_mask) + filtered_attrs = _filter_attrs_by_mask(state, atom_mask, system_mask) # Create the sliced state return type(state)(**filtered_attrs) @@ -911,19 +858,12 @@ def concatenate_states( # Use the target device or default to the first state's device target_device = device or first_state.device - # Get property scopes from the first state to identify - # global/per-atom/per-system properties - first_scope = infer_property_scope(first_state) - global_props = set(first_scope["global"]) - per_atom_props = set(first_scope["per_atom"]) - per_system_props = set(first_scope["per_system"]) - # Initialize result with global properties from first state - concatenated = {prop: getattr(first_state, prop) for prop in global_props} + concatenated = dict(get_attrs_for_scope(first_state, "global")) # Pre-allocate lists for tensors to concatenate - per_atom_tensors = {prop: [] for prop in per_atom_props} - per_system_tensors = {prop: [] for prop in per_system_props} + per_atom_tensors = defaultdict(list) + per_system_tensors = defaultdict(list) new_system_indices = [] system_offset = 0 @@ -934,14 +874,14 @@ def concatenate_states( state = state_to_device(state, target_device) # Collect per-atom properties - for prop in per_atom_props: + for prop, val in get_attrs_for_scope(state, "per-atom"): # if hasattr(state, prop): - per_atom_tensors[prop].append(getattr(state, prop)) + per_atom_tensors[prop].append(val) # Collect per-system properties - for prop in per_system_props: + for prop, val in get_attrs_for_scope(state, "per-system"): # if hasattr(state, prop): - per_system_tensors[prop].append(getattr(state, prop)) + per_system_tensors[prop].append(val) # Update system indices num_systems = state.n_systems From 4781bc47489b39e60f09066a097043598c1f5a73 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sun, 3 Aug 2025 14:49:25 -0700 Subject: [PATCH 02/32] fix bad rebase --- torch_sim/state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index 9f9b5f67..feb45d35 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -10,7 +10,7 @@ import warnings from collections import defaultdict from collections.abc import Generator -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Literal, Self, cast import torch From c4f8ee0fc8e17755f57d51cc1071b9493d36f7d5 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Tue, 5 Aug 2025 18:12:48 -0700 Subject: [PATCH 03/32] add last_permutation to system attributes --- examples/tutorials/hybrid_swap_tutorial.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/tutorials/hybrid_swap_tutorial.py b/examples/tutorials/hybrid_swap_tutorial.py index d8adbbf8..0513e512 100644 --- a/examples/tutorials/hybrid_swap_tutorial.py +++ b/examples/tutorials/hybrid_swap_tutorial.py @@ -37,6 +37,7 @@ import torch import torch_sim as ts from mace.calculators.foundations_models import mace_mp +from torch_sim.integrators.md import MDState from torch_sim.models.mace import MaceModel # Initialize the mace model @@ -104,6 +105,7 @@ class HybridSwapMCState(ts.integrators.MDState): """ last_permutation: torch.Tensor + _system_attributes = (*MDState._system_attributes, "last_permutation") # noqa: SLF001 # %% [markdown] From 5204370ac85337edcaa70f688e247350e85bb96a Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Fri, 8 Aug 2025 19:31:14 -0700 Subject: [PATCH 04/32] cleanup init subclass --- torch_sim/state.py | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index feb45d35..6ab34803 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -210,7 +210,7 @@ def n_atoms_per_batch(self) -> torch.Tensor: return self.n_atoms_per_system @property - def batch(self) -> torch.Tensor | None: + def batch(self) -> torch.Tensor: """System indices. deprecated:: @@ -415,6 +415,12 @@ def __init_subclass__(cls, **kwargs) -> None: Also enforce all of child classes's attributes are specified in _atom_attributes, _system_attributes, or _global_attributes. """ + cls._assert_no_tensor_attributes_can_be_none() + cls._assert_all_attributes_have_defined_scope() + super().__init_subclass__(**kwargs) + + @classmethod + def _assert_no_tensor_attributes_can_be_none(cls) -> None: # We need to use get_type_hints to correctly inspect the types type_hints = typing.get_type_hints(cls) for attr_name, attr_typehint in type_hints.items(): @@ -435,6 +441,8 @@ def __init_subclass__(cls, **kwargs) -> None: "the tensor with dummy values and track the 'None' case." ) + @classmethod + def _assert_all_attributes_have_defined_scope(cls) -> None: all_defined_attributes = ( cls._atom_attributes + cls._system_attributes + cls._global_attributes ) @@ -450,28 +458,25 @@ def __init_subclass__(cls, **kwargs) -> None: # 2) assert that all attributes are defined in all_defined_attributes all_annotations = {} - for c in reversed(cls.mro()): + for c in cls.mro(): if hasattr(c, "__annotations__"): all_annotations.update(c.__annotations__) attributes_to_check = set(vars(cls).keys()) | set(all_annotations.keys()) for attr_name in attributes_to_check: - if ( - attr_name.startswith("__") - or (hasattr(cls, attr_name) and callable(getattr(cls, attr_name))) - or ( - attr_name in vars(cls) - and isinstance(vars(cls).get(attr_name), property) - ) - ): - continue - - if attr_name in [ + is_special_attribute = attr_name.startswith("__") + is_property = attr_name in vars(cls) and isinstance( + vars(cls).get(attr_name), property + ) + is_method = hasattr(cls, attr_name) and callable(getattr(cls, attr_name)) + is_scope_list = attr_name in [ "_atom_attributes", "_system_attributes", "_global_attributes", - ]: + ] + + if is_special_attribute or is_property or is_method or is_scope_list: continue if attr_name not in all_defined_attributes: @@ -479,7 +484,6 @@ def __init_subclass__(cls, **kwargs) -> None: f"Attribute '{attr_name}' is not defined in {cls.__name__} in any " "of _atom_attributes, _system_attributes, or _global_attributes" ) - super().__init_subclass__(**kwargs) class DeformGradMixin: From b403fa3011ace5a743c305b550d567328b652502 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Fri, 8 Aug 2025 19:38:50 -0700 Subject: [PATCH 05/32] define scope for deformgradmixin attributes --- tests/test_state.py | 5 +++++ torch_sim/optimizers.py | 2 ++ torch_sim/state.py | 2 ++ 3 files changed, 9 insertions(+) diff --git a/tests/test_state.py b/tests/test_state.py index ea19bec8..4bf33283 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -518,6 +518,11 @@ def test_column_vector_cell(si_sim_state: ts.SimState) -> None: class DeformState(SimState, DeformGradMixin): """Test class that combines SimState with DeformGradMixin.""" + _system_attributes = ( + *SimState._system_attributes, # noqa: SLF001 + *DeformGradMixin._system_attributes, # noqa: SLF001 + ) + def __init__( self, *args, diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 9e9a16e8..795c351b 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -36,6 +36,7 @@ md_atom_attributes = (*SimState._atom_attributes, "forces", "velocities") # noqa: SLF001 _fire_system_attributes = ( *SimState._system_attributes, # noqa: SLF001 + *DeformGradMixin._system_attributes, # noqa: SLF001 "energy", "stress", "cell_positions", @@ -222,6 +223,7 @@ class UnitCellGDState(GDState, DeformGradMixin): _system_attributes = ( *GDState._system_attributes, # noqa: SLF001 + *DeformGradMixin._system_attributes, # noqa: SLF001 "cell_forces", "pressure", "stress", diff --git a/torch_sim/state.py b/torch_sim/state.py index 6ab34803..e7a5ce3c 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -492,6 +492,8 @@ class DeformGradMixin: reference_cell: torch.Tensor row_vector_cell: torch.Tensor + _system_attributes = ("reference_cell", "row_vector_cell") + @property def reference_row_vector_cell(self) -> torch.Tensor: """Get the original unit cell in terms of row vectors.""" From a9027004a50274de1fec65bdccc37b8d575819f1 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Fri, 8 Aug 2025 19:48:26 -0700 Subject: [PATCH 06/32] remove duplicate definitions of reference_cell. also fix scope definition --- torch_sim/optimizers.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 795c351b..cfd8746d 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -209,7 +209,6 @@ class UnitCellGDState(GDState, DeformGradMixin): """ # Required attributes not in BatchedGDState - reference_cell: torch.Tensor cell_factor: torch.Tensor hydrostatic_strain: bool constant_volume: bool @@ -233,7 +232,6 @@ class UnitCellGDState(GDState, DeformGradMixin): ) _global_attributes = ( *GDState._global_attributes, # noqa: SLF001 - "reference_cell", "hydrostatic_strain", "constant_volume", ) @@ -728,7 +726,6 @@ class UnitCellFireState(SimState, DeformGradMixin): cell_masses: torch.Tensor # Optimization-specific attributes - reference_cell: torch.Tensor cell_factor: torch.Tensor pressure: torch.Tensor hydrostatic_strain: bool @@ -1014,7 +1011,6 @@ class FrechetCellFIREState(SimState, DeformGradMixin): stress: torch.Tensor # Optimization-specific attributes - reference_cell: torch.Tensor cell_factor: torch.Tensor pressure: torch.Tensor hydrostatic_strain: bool From 9e71a89ba4f5278791b913b4c7f68762a38116fb Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Fri, 8 Aug 2025 19:55:44 -0700 Subject: [PATCH 07/32] fix _fire_system_attributes declaration --- torch_sim/optimizers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index cfd8746d..c5dc33d3 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -43,7 +43,6 @@ "cell_velocities", "cell_forces", "cell_masses", - "reference_cell", "cell_factor", "pressure", "dt", From 8110f43db04149ff4a22be02b25d27f65a741ef6 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Fri, 8 Aug 2025 20:10:29 -0700 Subject: [PATCH 08/32] add back reference cell attr to be defined --- torch_sim/optimizers.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index c5dc33d3..59b0fb5c 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -208,6 +208,7 @@ class UnitCellGDState(GDState, DeformGradMixin): """ # Required attributes not in BatchedGDState + reference_cell: torch.Tensor cell_factor: torch.Tensor hydrostatic_strain: bool constant_volume: bool @@ -725,6 +726,7 @@ class UnitCellFireState(SimState, DeformGradMixin): cell_masses: torch.Tensor # Optimization-specific attributes + reference_cell: torch.Tensor cell_factor: torch.Tensor pressure: torch.Tensor hydrostatic_strain: bool @@ -1010,6 +1012,7 @@ class FrechetCellFIREState(SimState, DeformGradMixin): stress: torch.Tensor # Optimization-specific attributes + reference_cell: torch.Tensor cell_factor: torch.Tensor pressure: torch.Tensor hydrostatic_strain: bool From f2be537e6e496ff9285c5eb071aec8bae2adf007 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Fri, 8 Aug 2025 20:22:40 -0700 Subject: [PATCH 09/32] manually add row_vector_cell to each reference_cell --- torch_sim/optimizers.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 59b0fb5c..2d72dc3a 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -209,6 +209,7 @@ class UnitCellGDState(GDState, DeformGradMixin): # Required attributes not in BatchedGDState reference_cell: torch.Tensor + row_vector_cell: torch.Tensor cell_factor: torch.Tensor hydrostatic_strain: bool constant_volume: bool @@ -724,9 +725,10 @@ class UnitCellFireState(SimState, DeformGradMixin): cell_velocities: torch.Tensor cell_forces: torch.Tensor cell_masses: torch.Tensor + reference_cell: torch.Tensor + row_vector_cell: torch.Tensor # Optimization-specific attributes - reference_cell: torch.Tensor cell_factor: torch.Tensor pressure: torch.Tensor hydrostatic_strain: bool @@ -1013,6 +1015,7 @@ class FrechetCellFIREState(SimState, DeformGradMixin): # Optimization-specific attributes reference_cell: torch.Tensor + row_vector_cell: torch.Tensor cell_factor: torch.Tensor pressure: torch.Tensor hydrostatic_strain: bool From 2012d74000e4da71a1c578669672d9779579318d Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 9 Aug 2025 07:58:52 -0700 Subject: [PATCH 10/32] use kwonly to hopefully fix default arg issues --- torch_sim/optimizers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 2d72dc3a..4e83e4d7 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -480,7 +480,7 @@ def gd_step( return gd_init, gd_step -@dataclass +@dataclass(kw_only=True) class FireState(SimState): """State information for batched FIRE optimization. @@ -957,7 +957,7 @@ def fire_init( return fire_init, functools.partial(step_func, **step_func_kwargs) -@dataclass +@dataclass(kw_only=True) class FrechetCellFIREState(SimState, DeformGradMixin): """State class for batched FIRE optimization with Frechet cell derivatives. From ba31ea13055dc89902de7e2b6c81ba6d22c15ffa Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 9 Aug 2025 11:17:55 -0700 Subject: [PATCH 11/32] more kwargs --- torch_sim/optimizers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 4e83e4d7..5b696a56 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -174,7 +174,7 @@ def gd_step(state: GDState, lr: torch.Tensor = lr) -> GDState: return gd_init, gd_step -@dataclass +@dataclass(kw_only=True) class UnitCellGDState(GDState, DeformGradMixin): """State class for batched gradient descent optimization with unit cell. From 6768f4e166c00cf9894cbe503c106b4539601b89 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 9 Aug 2025 11:40:17 -0700 Subject: [PATCH 12/32] make more states kwonly --- torch_sim/optimizers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 5b696a56..7499c39b 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -664,7 +664,7 @@ def fire_init( return fire_init, functools.partial(step_func, **step_func_kwargs) -@dataclass +@dataclass(kw_only=True) class UnitCellFireState(SimState, DeformGradMixin): """State information for batched FIRE optimization with unit cell degrees of freedom. From 4b832941b76ab4aa89c67af18a1471dc76279015 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 9 Aug 2025 14:53:08 -0700 Subject: [PATCH 13/32] try different params --- torch_sim/optimizers.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 7499c39b..ab422e5c 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -1014,8 +1014,8 @@ class FrechetCellFIREState(SimState, DeformGradMixin): stress: torch.Tensor # Optimization-specific attributes - reference_cell: torch.Tensor - row_vector_cell: torch.Tensor + # reference_cell: torch.Tensor + # row_vector_cell: torch.Tensor cell_factor: torch.Tensor pressure: torch.Tensor hydrostatic_strain: bool @@ -1234,6 +1234,9 @@ def fire_init( cell_velocities=torch.full( cell_forces.shape, torch.nan, device=device, dtype=dtype ), + row_vector_cell=torch.full( + cell_forces.shape, torch.nan, device=device, dtype=dtype + ), cell_forces=cell_forces, cell_masses=cell_masses, # Optimization attributes From dc860c3b42b5a58d1e4b2f370066b93193ffe9f9 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 9 Aug 2025 16:00:32 -0700 Subject: [PATCH 14/32] make the parent class a dataclass so attributes just propagate down --- torch_sim/state.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_sim/state.py b/torch_sim/state.py index e7a5ce3c..7ded6038 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -486,6 +486,7 @@ def _assert_all_attributes_have_defined_scope(cls) -> None: ) +@dataclass class DeformGradMixin: """Mixin for states that support deformation gradients.""" From 099e38c509554cea7ddc1b608ad504ef022598ff Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sat, 9 Aug 2025 21:13:58 -0400 Subject: [PATCH 15/32] fix: row_vector_cell is just an alias to cell.mT not an attribute --- torch_sim/optimizers.py | 8 +------- torch_sim/state.py | 3 +-- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index ab422e5c..27ce069a 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -209,7 +209,6 @@ class UnitCellGDState(GDState, DeformGradMixin): # Required attributes not in BatchedGDState reference_cell: torch.Tensor - row_vector_cell: torch.Tensor cell_factor: torch.Tensor hydrostatic_strain: bool constant_volume: bool @@ -726,7 +725,6 @@ class UnitCellFireState(SimState, DeformGradMixin): cell_forces: torch.Tensor cell_masses: torch.Tensor reference_cell: torch.Tensor - row_vector_cell: torch.Tensor # Optimization-specific attributes cell_factor: torch.Tensor @@ -1014,8 +1012,7 @@ class FrechetCellFIREState(SimState, DeformGradMixin): stress: torch.Tensor # Optimization-specific attributes - # reference_cell: torch.Tensor - # row_vector_cell: torch.Tensor + reference_cell: torch.Tensor cell_factor: torch.Tensor pressure: torch.Tensor hydrostatic_strain: bool @@ -1234,9 +1231,6 @@ def fire_init( cell_velocities=torch.full( cell_forces.shape, torch.nan, device=device, dtype=dtype ), - row_vector_cell=torch.full( - cell_forces.shape, torch.nan, device=device, dtype=dtype - ), cell_forces=cell_forces, cell_masses=cell_masses, # Optimization attributes diff --git a/torch_sim/state.py b/torch_sim/state.py index 7ded6038..22fcad23 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -491,9 +491,8 @@ class DeformGradMixin: """Mixin for states that support deformation gradients.""" reference_cell: torch.Tensor - row_vector_cell: torch.Tensor - _system_attributes = ("reference_cell", "row_vector_cell") + _system_attributes = ("reference_cell",) @property def reference_row_vector_cell(self) -> torch.Tensor: From e9c9aaed86764b24a759c905d01a2ddfb9b6223a Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 9 Aug 2025 17:34:15 -0700 Subject: [PATCH 16/32] make deform kwonly true --- torch_sim/state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index 22fcad23..305b3001 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -486,7 +486,7 @@ def _assert_all_attributes_have_defined_scope(cls) -> None: ) -@dataclass +@dataclass(kw_only=True) class DeformGradMixin: """Mixin for states that support deformation gradients.""" From c23ae3a6a67c1f417d0bf438973448f9b07bf3f5 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 9 Aug 2025 18:22:01 -0700 Subject: [PATCH 17/32] splitting logic now just uses torch.split --- torch_sim/state.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index 305b3001..9e791ed5 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -697,22 +697,13 @@ def _split_state( system_sizes = torch.bincount(state.system_idx).tolist() n_systems = len(system_sizes) - def split_attr( - attr_value: torch.Tensor | None, split_sizes: list[int] - ) -> list[torch.Tensor | None]: - return ( - [None] * n_systems - if attr_value is None - else torch.split(attr_value, split_sizes) - ) - split_per_atom = { - name: split_attr(value, system_sizes) + name: torch.split(value, system_sizes) for name, value in get_attrs_for_scope(state, "per-atom") if name != "system_idx" } split_per_system = { - name: split_attr(value, 1) + name: torch.split(value, 1) for name, value in get_attrs_for_scope(state, "per-system") } From aa29f81513f47a1a8bdf281ab7ce7318f320ab30 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 9 Aug 2025 18:26:28 -0700 Subject: [PATCH 18/32] revert to old method of split per atom and system (more readable) --- torch_sim/state.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index 9e791ed5..dc77cae4 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -697,15 +697,14 @@ def _split_state( system_sizes = torch.bincount(state.system_idx).tolist() n_systems = len(system_sizes) - split_per_atom = { - name: torch.split(value, system_sizes) - for name, value in get_attrs_for_scope(state, "per-atom") - if name != "system_idx" - } - split_per_system = { - name: torch.split(value, 1) - for name, value in get_attrs_for_scope(state, "per-system") - } + split_per_atom = {} + for name, value in get_attrs_for_scope(state, "per-atom"): + if name != "system_idx": + split_per_atom[name] = torch.split(value, system_sizes, dim=0) + + split_per_system = {} + for name, value in get_attrs_for_scope(state, "per-system"): + split_per_system[name] = torch.split(value, 1, dim=0) global_attrs = dict(get_attrs_for_scope(state, "global")) From af09df4f49e83c2787b18a9488fa3ccb35909221 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 9 Aug 2025 18:38:28 -0700 Subject: [PATCH 19/32] see if it works if I remove reference_cell: torch.Tensor --- torch_sim/optimizers.py | 1 - torch_sim/state.py | 12 ++++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 27ce069a..8281b533 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -724,7 +724,6 @@ class UnitCellFireState(SimState, DeformGradMixin): cell_velocities: torch.Tensor cell_forces: torch.Tensor cell_masses: torch.Tensor - reference_cell: torch.Tensor # Optimization-specific attributes cell_factor: torch.Tensor diff --git a/torch_sim/state.py b/torch_sim/state.py index dc77cae4..787a76f1 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -695,21 +695,21 @@ def _split_state( system """ system_sizes = torch.bincount(state.system_idx).tolist() - n_systems = len(system_sizes) split_per_atom = {} - for name, value in get_attrs_for_scope(state, "per-atom"): - if name != "system_idx": - split_per_atom[name] = torch.split(value, system_sizes, dim=0) + for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"): + if attr_name != "system_idx": + split_per_atom[attr_name] = torch.split(attr_value, system_sizes, dim=0) split_per_system = {} - for name, value in get_attrs_for_scope(state, "per-system"): - split_per_system[name] = torch.split(value, 1, dim=0) + for attr_name, attr_value in get_attrs_for_scope(state, "per-system"): + split_per_system[attr_name] = torch.split(attr_value, 1, dim=0) global_attrs = dict(get_attrs_for_scope(state, "global")) # Create a state for each system states = [] + n_systems = len(system_sizes) for i in range(n_systems): system_attrs = { # Create a system tensor with all zeros for this system From 0e130d83e48142851ec7920f9bb31a8e9c1a7f4e Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 9 Aug 2025 19:13:57 -0700 Subject: [PATCH 20/32] coderabbit comments --- examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py | 2 +- examples/tutorials/hybrid_swap_tutorial.py | 2 +- torch_sim/runners.py | 8 ++++++-- torch_sim/state.py | 2 ++ 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py index ab64a2da..ce4bd73b 100644 --- a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py +++ b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py @@ -76,7 +76,7 @@ class HybridSwapMCState(MDState): """ last_permutation: torch.Tensor - _system_attributes = (*MDState._system_attributes, "last_permutation") # noqa: SLF001 + _atom_attributes = (*MDState._atom_attributes, "last_permutation") # noqa: SLF001 nvt_init, nvt_step = nvt_langevin(model=model, dt=0.002, kT=kT) diff --git a/examples/tutorials/hybrid_swap_tutorial.py b/examples/tutorials/hybrid_swap_tutorial.py index 0513e512..bdfa547c 100644 --- a/examples/tutorials/hybrid_swap_tutorial.py +++ b/examples/tutorials/hybrid_swap_tutorial.py @@ -105,7 +105,7 @@ class HybridSwapMCState(ts.integrators.MDState): """ last_permutation: torch.Tensor - _system_attributes = (*MDState._system_attributes, "last_permutation") # noqa: SLF001 + _atom_attributes = (*MDState._atom_attributes, "last_permutation") # noqa: SLF001 # %% [markdown] diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 89401459..70159dd3 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -567,8 +567,12 @@ class StaticState(type(state)): sub_state = StaticState( **vars(sub_state), energy=model_outputs["energy"], - forces=model_outputs["forces"] if model.compute_forces else None, - stress=model_outputs["stress"] if model.compute_stress else None, + forces=model_outputs["forces"] + if model.compute_forces + else torch.full_like(sub_state.positions, fill_value=float("nan")), + stress=model_outputs["stress"] + if model.compute_stress + else torch.full_like(sub_state.cell, fill_value=float("nan")), ) props = trajectory_reporter.report(sub_state, 0, model=model) diff --git a/torch_sim/state.py b/torch_sim/state.py index 787a76f1..eb15352a 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -622,6 +622,8 @@ def get_attrs_for_scope( attr_names = state._system_attributes # noqa: SLF001 case "global": attr_names = state._global_attributes # noqa: SLF001 + case _: + raise ValueError(f"Unknown scope: {scope!r}") for attr_name in attr_names: yield attr_name, getattr(state, attr_name) From 3d89fb3655e8a8a7caaa23a4e26fa85b68df7896 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sun, 10 Aug 2025 13:55:01 -0700 Subject: [PATCH 21/32] rm xdist and remove dependency between integrators and optimizers --- pyproject.toml | 1 - torch_sim/integrators/npt.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f59e7ed3..723edafb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,6 @@ test = [ "pymatgen>=2024.11.3", "pytest-cov>=6", "pytest>=8", - "pytest-xdist>=3.8.0", ] io = ["ase>=3.24", "phonopy>=2.37.0", "pymatgen>=2024.11.3"] mace = ["mace-torch>=0.3.12"] diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index c3c3d2af..eea744c9 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -15,7 +15,6 @@ construct_nose_hoover_chain, ) from torch_sim.models.interface import ModelInterface -from torch_sim.optimizers import md_atom_attributes from torch_sim.quantities import calc_kinetic_energy from torch_sim.state import SimState from torch_sim.typing import StateDict @@ -68,7 +67,7 @@ class NPTLangevinState(SimState): cell_velocities: torch.Tensor cell_masses: torch.Tensor - _atom_attributes = md_atom_attributes + _atom_attributes = (*SimState._atom_attributes, "forces", "velocities") # noqa: SLF001 _system_attributes = ( *SimState._system_attributes, # noqa: SLF001 "stress", From 9cbcd878debe21b9224eb277b9766bcd9469b29a Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sun, 10 Aug 2025 14:45:59 -0700 Subject: [PATCH 22/32] swap from tuple to set --- .../scripts/3_Dynamics/3.10_Hybrid_swap_mc.py | 5 +- examples/tutorials/hybrid_swap_tutorial.py | 5 +- tests/test_state.py | 21 ++--- torch_sim/integrators/md.py | 9 +- torch_sim/integrators/npt.py | 39 ++++---- torch_sim/integrators/nvt.py | 8 +- torch_sim/monte_carlo.py | 9 +- torch_sim/optimizers.py | 91 ++++++++++--------- torch_sim/runners.py | 10 +- torch_sim/state.py | 25 +++-- 10 files changed, 129 insertions(+), 93 deletions(-) diff --git a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py index ce4bd73b..75639caa 100644 --- a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py +++ b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py @@ -8,6 +8,7 @@ # /// from dataclasses import dataclass +from typing import ClassVar import torch from mace.calculators.foundations_models import mace_mp @@ -76,7 +77,9 @@ class HybridSwapMCState(MDState): """ last_permutation: torch.Tensor - _atom_attributes = (*MDState._atom_attributes, "last_permutation") # noqa: SLF001 + _atom_attributes: ClassVar[set[str]] = ( + MDState._atom_attributes | {"last_permutation"} # noqa: SLF001 + ) nvt_init, nvt_step = nvt_langevin(model=model, dt=0.002, kT=kT) diff --git a/examples/tutorials/hybrid_swap_tutorial.py b/examples/tutorials/hybrid_swap_tutorial.py index bdfa547c..67794774 100644 --- a/examples/tutorials/hybrid_swap_tutorial.py +++ b/examples/tutorials/hybrid_swap_tutorial.py @@ -34,6 +34,7 @@ """ # %% +from typing import ClassVar import torch import torch_sim as ts from mace.calculators.foundations_models import mace_mp @@ -105,7 +106,9 @@ class HybridSwapMCState(ts.integrators.MDState): """ last_permutation: torch.Tensor - _atom_attributes = (*MDState._atom_attributes, "last_permutation") # noqa: SLF001 + _atom_attributes: ClassVar[set[str]] = ( + MDState._atom_attributes | {"last_permutation"} # noqa: SLF001 + ) # %% [markdown] diff --git a/tests/test_state.py b/tests/test_state.py index 4bf33283..5975cfc2 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -48,9 +48,8 @@ class ChildState(SimState): attribute_specified_in_scopes: bool attribute_not_specified_in_scopes: bool - _atom_attributes = ( - *SimState._atom_attributes, # noqa: SLF001 - "attribute_specified_in_scopes", + _atom_attributes: typing.ClassVar[set[str]] = ( + SimState._atom_attributes | {"attribute_specified_in_scopes"} # noqa: SLF001 ) assert "attribute_not_specified_in_scopes" in str(excinfo.value) @@ -66,13 +65,11 @@ def test_no_duplicate_attributes_in_scopes() -> None: class ChildState(SimState): duplicated_attribute: bool - _system_attributes = ( - *SimState._atom_attributes, # noqa: SLF001 - "duplicated_attribute", + _system_attributes: typing.ClassVar[set[str]] = ( + SimState._atom_attributes | {"duplicated_attribute"} # noqa: SLF001 ) - _global_attributes = ( - *SimState._global_attributes, # noqa: SLF001 - "duplicated_attribute", + _global_attributes: typing.ClassVar[set[str]] = ( + SimState._global_attributes | {"duplicated_attribute"} # noqa: SLF001 ) assert "are declared multiple times" in str(excinfo.value) @@ -518,9 +515,9 @@ def test_column_vector_cell(si_sim_state: ts.SimState) -> None: class DeformState(SimState, DeformGradMixin): """Test class that combines SimState with DeformGradMixin.""" - _system_attributes = ( - *SimState._system_attributes, # noqa: SLF001 - *DeformGradMixin._system_attributes, # noqa: SLF001 + _system_attributes: typing.ClassVar[set[str]] = ( + SimState._system_attributes # noqa: SLF001 + | DeformGradMixin._system_attributes # noqa: SLF001 ) def __init__( diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index ae367456..ff04f899 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -2,6 +2,7 @@ from collections.abc import Callable from dataclasses import dataclass +from typing import ClassVar import torch @@ -41,8 +42,12 @@ class MDState(SimState): energy: torch.Tensor forces: torch.Tensor - _atom_attributes = (*SimState._atom_attributes, "momenta", "forces") # noqa: SLF001 - _system_attributes = (*SimState._system_attributes, "energy") # noqa: SLF001 + _atom_attributes: ClassVar[set[str]] = ( + SimState._atom_attributes | {"momenta", "forces"} # noqa: SLF001 + ) + _system_attributes: ClassVar[set[str]] = ( + SimState._system_attributes | {"energy"} # noqa: SLF001 + ) @property def velocities(self) -> torch.Tensor: diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index eea744c9..da6093cd 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -2,7 +2,7 @@ from collections.abc import Callable from dataclasses import dataclass -from typing import Any +from typing import Any, ClassVar import torch @@ -67,16 +67,17 @@ class NPTLangevinState(SimState): cell_velocities: torch.Tensor cell_masses: torch.Tensor - _atom_attributes = (*SimState._atom_attributes, "forces", "velocities") # noqa: SLF001 - _system_attributes = ( - *SimState._system_attributes, # noqa: SLF001 + _atom_attributes: ClassVar[set[str]] = ( + SimState._atom_attributes | {"forces", "velocities"} # noqa: SLF001 + ) + _system_attributes: ClassVar[set[str]] = SimState._system_attributes | { # noqa: SLF001 "stress", "cell_positions", "cell_velocities", "cell_masses", "reference_cell", "energy", - ) + } @property def momenta(self) -> torch.Tensor: @@ -878,19 +879,23 @@ class NPTNoseHooverState(MDState): barostat: NoseHooverChain barostat_fns: NoseHooverChainFns - _system_attributes = ( - *MDState._system_attributes, # noqa: SLF001 - "reference_cell", - "cell_position", - "cell_momentum", - "cell_mass", + _system_attributes: ClassVar[set[str]] = ( + MDState._system_attributes # noqa: SLF001 + | { + "reference_cell", + "cell_position", + "cell_momentum", + "cell_mass", + } ) - _global_attributes = ( - *MDState._global_attributes, # noqa: SLF001 - "thermostat", - "barostat", - "thermostat_fns", - "barostat_fns", + _global_attributes: ClassVar[set[str]] = ( + MDState._global_attributes # noqa: SLF001 + | { + "thermostat", + "barostat", + "thermostat_fns", + "barostat_fns", + } ) @property diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 69922b76..3540b715 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -2,7 +2,7 @@ from collections.abc import Callable from dataclasses import dataclass -from typing import Any +from typing import Any, ClassVar import torch @@ -266,10 +266,8 @@ class NVTNoseHooverState(MDState): chain: NoseHooverChain _chain_fns: NoseHooverChainFns - _global_attributes = ( - *MDState._global_attributes, # noqa: SLF001 - "chain", - "_chain_fns", + _global_attributes: ClassVar[set[str]] = ( + MDState._global_attributes | {"chain", "_chain_fns"} # noqa: SLF001 ) @property diff --git a/torch_sim/monte_carlo.py b/torch_sim/monte_carlo.py index 1184f10d..656a52d3 100644 --- a/torch_sim/monte_carlo.py +++ b/torch_sim/monte_carlo.py @@ -12,6 +12,7 @@ from collections.abc import Callable from dataclasses import dataclass +from typing import ClassVar import torch @@ -36,8 +37,12 @@ class SwapMCState(SimState): energy: torch.Tensor last_permutation: torch.Tensor - _atom_attributes = (*SimState._atom_attributes, "last_permutation") # noqa: SLF001 - _system_attributes = (*SimState._system_attributes, "energy") # noqa: SLF001 + _atom_attributes: ClassVar[set[str]] = ( + SimState._atom_attributes | {"last_permutation"} # noqa: SLF001 + ) + _system_attributes: ClassVar[set[str]] = ( + SimState._system_attributes | {"energy"} # noqa: SLF001 + ) def generate_swaps( diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 8281b533..9632fe7b 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -20,7 +20,7 @@ import functools from collections.abc import Callable from dataclasses import dataclass -from typing import Any, Literal, get_args +from typing import Any, ClassVar, Literal, get_args import torch @@ -33,27 +33,28 @@ MdFlavor = Literal["vv_fire", "ase_fire"] vv_fire_key, ase_fire_key = get_args(MdFlavor) -md_atom_attributes = (*SimState._atom_attributes, "forces", "velocities") # noqa: SLF001 +md_atom_attributes = SimState._atom_attributes | {"forces", "velocities"} # noqa: SLF001 _fire_system_attributes = ( - *SimState._system_attributes, # noqa: SLF001 - *DeformGradMixin._system_attributes, # noqa: SLF001 - "energy", - "stress", - "cell_positions", - "cell_velocities", - "cell_forces", - "cell_masses", - "cell_factor", - "pressure", - "dt", - "alpha", - "n_pos", + SimState._system_attributes # noqa: SLF001 + | DeformGradMixin._system_attributes # noqa: SLF001 + | { + "energy", + "stress", + "cell_positions", + "cell_velocities", + "cell_forces", + "cell_masses", + "cell_factor", + "pressure", + "dt", + "alpha", + "n_pos", + } ) -_fire_global_attributes = ( - *SimState._global_attributes, # noqa: SLF001 +_fire_global_attributes = SimState._global_attributes | { # noqa: SLF001 "hydrostatic_strain", "constant_volume", -) +} @dataclass @@ -78,8 +79,8 @@ class GDState(SimState): forces: torch.Tensor energy: torch.Tensor - _atom_attributes = (*SimState._atom_attributes, "forces") # noqa: SLF001 - _system_attributes = (*SimState._system_attributes, "energy") # noqa: SLF001 + _atom_attributes = SimState._atom_attributes | {"forces"} # noqa: SLF001 + _system_attributes = SimState._system_attributes | {"energy"} # noqa: SLF001 def gradient_descent( @@ -220,20 +221,20 @@ class UnitCellGDState(GDState, DeformGradMixin): cell_forces: torch.Tensor cell_masses: torch.Tensor - _system_attributes = ( - *GDState._system_attributes, # noqa: SLF001 - *DeformGradMixin._system_attributes, # noqa: SLF001 - "cell_forces", - "pressure", - "stress", - "cell_positions", - "cell_factor", - "cell_masses", + _system_attributes: ClassVar[set[str]] = ( + GDState._system_attributes # noqa: SLF001 + | DeformGradMixin._system_attributes # noqa: SLF001 + | { + "cell_forces", + "pressure", + "stress", + "cell_positions", + "cell_factor", + "cell_masses", + } ) - _global_attributes = ( - *GDState._global_attributes, # noqa: SLF001 - "hydrostatic_strain", - "constant_volume", + _global_attributes: ClassVar[set[str]] = ( + GDState._global_attributes | {"hydrostatic_strain", "constant_volume"} # noqa: SLF001 ) @@ -523,8 +524,16 @@ class FireState(SimState): alpha: torch.Tensor n_pos: torch.Tensor - _atom_attributes = md_atom_attributes - _system_attributes = (*SimState._system_attributes, "energy", "dt", "alpha", "n_pos") # noqa: SLF001 + _atom_attributes: ClassVar[set[str]] = md_atom_attributes + _system_attributes: ClassVar[set[str]] = ( + SimState._system_attributes # noqa: SLF001 + | { + "energy", + "dt", + "alpha", + "n_pos", + } + ) def fire( @@ -736,9 +745,9 @@ class UnitCellFireState(SimState, DeformGradMixin): alpha: torch.Tensor n_pos: torch.Tensor - _atom_attributes = md_atom_attributes - _system_attributes = _fire_system_attributes - _global_attributes = _fire_global_attributes + _atom_attributes: ClassVar[set[str]] = md_atom_attributes + _system_attributes: ClassVar[set[str]] = _fire_system_attributes + _global_attributes: ClassVar[set[str]] = _fire_global_attributes def unit_cell_fire( @@ -1028,9 +1037,9 @@ class FrechetCellFIREState(SimState, DeformGradMixin): alpha: torch.Tensor n_pos: torch.Tensor - _atom_attributes = md_atom_attributes - _system_attributes = _fire_system_attributes - _global_attributes = _fire_global_attributes + _atom_attributes: ClassVar[set[str]] = md_atom_attributes + _system_attributes: ClassVar[set[str]] = _fire_system_attributes + _global_attributes: ClassVar[set[str]] = _fire_global_attributes def frechet_cell_fire( diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 70159dd3..8e6c107a 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -9,7 +9,7 @@ from collections.abc import Callable from dataclasses import dataclass from itertools import chain -from typing import Any +from typing import Any, ClassVar import torch from tqdm import tqdm @@ -541,8 +541,12 @@ class StaticState(type(state)): forces: torch.Tensor stress: torch.Tensor - _atom_attributes = (*state._atom_attributes, "forces") # noqa: SLF001 - _system_attributes = (*state._system_attributes, "energy", "stress") # noqa: SLF001 + _atom_attributes: ClassVar[set[str]] = ( + state._atom_attributes | {"forces"} # noqa: SLF001 + ) + _system_attributes: ClassVar[set[str]] = ( + state._system_attributes | {"energy", "stress"} # noqa: SLF001 + ) all_props: list[dict[str, torch.Tensor]] = [] og_filenames = trajectory_reporter.filenames diff --git a/torch_sim/state.py b/torch_sim/state.py index eb15352a..91248a0b 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -11,7 +11,7 @@ from collections import defaultdict from collections.abc import Generator from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Literal, Self, cast +from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self, cast import torch @@ -85,9 +85,14 @@ class SimState: atomic_numbers: torch.Tensor system_idx: torch.Tensor - _atom_attributes = ("positions", "masses", "atomic_numbers", "system_idx") - _system_attributes = ("cell",) - _global_attributes = ("pbc",) + _atom_attributes: ClassVar[set[str]] = { + "positions", + "masses", + "atomic_numbers", + "system_idx", + } + _system_attributes: ClassVar[set[str]] = {"cell"} + _global_attributes: ClassVar[set[str]] = {"pbc"} def __init__( self, @@ -444,12 +449,14 @@ def _assert_no_tensor_attributes_can_be_none(cls) -> None: @classmethod def _assert_all_attributes_have_defined_scope(cls) -> None: all_defined_attributes = ( - cls._atom_attributes + cls._system_attributes + cls._global_attributes + cls._atom_attributes | cls._system_attributes | cls._global_attributes ) # 1) assert that no attribute is defined twice in all_defined_attributes - duplicates = [ - x for x in all_defined_attributes if all_defined_attributes.count(x) > 1 - ] + duplicates = ( + (cls._atom_attributes & cls._system_attributes) + | (cls._atom_attributes & cls._global_attributes) + | (cls._system_attributes & cls._global_attributes) + ) if duplicates: raise TypeError( f"Attributes {duplicates} are declared multiple times in {cls.__name__} " @@ -492,7 +499,7 @@ class DeformGradMixin: reference_cell: torch.Tensor - _system_attributes = ("reference_cell",) + _system_attributes: ClassVar[set[str]] = {"reference_cell"} @property def reference_row_vector_cell(self) -> torch.Tensor: From 5e181edc87e6d63febfc94a4f576c0d22ebb9544 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sun, 10 Aug 2025 18:22:45 -0700 Subject: [PATCH 23/32] rm classvar annotation from all simstate --- .../scripts/3_Dynamics/3.10_Hybrid_swap_mc.py | 3 +-- examples/tutorials/hybrid_swap_tutorial.py | 2 +- tests/test_state.py | 2 +- torch_sim/integrators/md.py | 5 ++--- torch_sim/integrators/npt.py | 10 ++++----- torch_sim/integrators/nvt.py | 4 ++-- torch_sim/monte_carlo.py | 9 ++------ torch_sim/optimizers.py | 22 +++++++++---------- torch_sim/runners.py | 6 ++--- torch_sim/state.py | 12 +++++----- 10 files changed, 34 insertions(+), 41 deletions(-) diff --git a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py index 75639caa..fb3f7983 100644 --- a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py +++ b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py @@ -8,7 +8,6 @@ # /// from dataclasses import dataclass -from typing import ClassVar import torch from mace.calculators.foundations_models import mace_mp @@ -77,7 +76,7 @@ class HybridSwapMCState(MDState): """ last_permutation: torch.Tensor - _atom_attributes: ClassVar[set[str]] = ( + _atom_attributes = ( MDState._atom_attributes | {"last_permutation"} # noqa: SLF001 ) diff --git a/examples/tutorials/hybrid_swap_tutorial.py b/examples/tutorials/hybrid_swap_tutorial.py index 67794774..08c3914b 100644 --- a/examples/tutorials/hybrid_swap_tutorial.py +++ b/examples/tutorials/hybrid_swap_tutorial.py @@ -106,7 +106,7 @@ class HybridSwapMCState(ts.integrators.MDState): """ last_permutation: torch.Tensor - _atom_attributes: ClassVar[set[str]] = ( + _atom_attributes = ( MDState._atom_attributes | {"last_permutation"} # noqa: SLF001 ) diff --git a/tests/test_state.py b/tests/test_state.py index 5975cfc2..4914278e 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -66,7 +66,7 @@ class ChildState(SimState): duplicated_attribute: bool _system_attributes: typing.ClassVar[set[str]] = ( - SimState._atom_attributes | {"duplicated_attribute"} # noqa: SLF001 + SimState._system_attributes | {"duplicated_attribute"} # noqa: SLF001 ) _global_attributes: typing.ClassVar[set[str]] = ( SimState._global_attributes | {"duplicated_attribute"} # noqa: SLF001 diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index ff04f899..490e3528 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -2,7 +2,6 @@ from collections.abc import Callable from dataclasses import dataclass -from typing import ClassVar import torch @@ -42,10 +41,10 @@ class MDState(SimState): energy: torch.Tensor forces: torch.Tensor - _atom_attributes: ClassVar[set[str]] = ( + _atom_attributes = ( SimState._atom_attributes | {"momenta", "forces"} # noqa: SLF001 ) - _system_attributes: ClassVar[set[str]] = ( + _system_attributes = ( SimState._system_attributes | {"energy"} # noqa: SLF001 ) diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index da6093cd..f5f23c5c 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -2,7 +2,7 @@ from collections.abc import Callable from dataclasses import dataclass -from typing import Any, ClassVar +from typing import Any import torch @@ -67,10 +67,10 @@ class NPTLangevinState(SimState): cell_velocities: torch.Tensor cell_masses: torch.Tensor - _atom_attributes: ClassVar[set[str]] = ( + _atom_attributes = ( SimState._atom_attributes | {"forces", "velocities"} # noqa: SLF001 ) - _system_attributes: ClassVar[set[str]] = SimState._system_attributes | { # noqa: SLF001 + _system_attributes = SimState._system_attributes | { # noqa: SLF001 "stress", "cell_positions", "cell_velocities", @@ -879,7 +879,7 @@ class NPTNoseHooverState(MDState): barostat: NoseHooverChain barostat_fns: NoseHooverChainFns - _system_attributes: ClassVar[set[str]] = ( + _system_attributes = ( MDState._system_attributes # noqa: SLF001 | { "reference_cell", @@ -888,7 +888,7 @@ class NPTNoseHooverState(MDState): "cell_mass", } ) - _global_attributes: ClassVar[set[str]] = ( + _global_attributes = ( MDState._global_attributes # noqa: SLF001 | { "thermostat", diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 3540b715..18f0ae15 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -2,7 +2,7 @@ from collections.abc import Callable from dataclasses import dataclass -from typing import Any, ClassVar +from typing import Any import torch @@ -266,7 +266,7 @@ class NVTNoseHooverState(MDState): chain: NoseHooverChain _chain_fns: NoseHooverChainFns - _global_attributes: ClassVar[set[str]] = ( + _global_attributes = ( MDState._global_attributes | {"chain", "_chain_fns"} # noqa: SLF001 ) diff --git a/torch_sim/monte_carlo.py b/torch_sim/monte_carlo.py index 656a52d3..be2d99a8 100644 --- a/torch_sim/monte_carlo.py +++ b/torch_sim/monte_carlo.py @@ -12,7 +12,6 @@ from collections.abc import Callable from dataclasses import dataclass -from typing import ClassVar import torch @@ -37,12 +36,8 @@ class SwapMCState(SimState): energy: torch.Tensor last_permutation: torch.Tensor - _atom_attributes: ClassVar[set[str]] = ( - SimState._atom_attributes | {"last_permutation"} # noqa: SLF001 - ) - _system_attributes: ClassVar[set[str]] = ( - SimState._system_attributes | {"energy"} # noqa: SLF001 - ) + _atom_attributes = SimState._atom_attributes | {"last_permutation"} # noqa: SLF001 + _system_attributes = SimState._system_attributes | {"energy"} # noqa: SLF001 def generate_swaps( diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 9632fe7b..30a6c814 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -20,7 +20,7 @@ import functools from collections.abc import Callable from dataclasses import dataclass -from typing import Any, ClassVar, Literal, get_args +from typing import Any, Literal, get_args import torch @@ -221,7 +221,7 @@ class UnitCellGDState(GDState, DeformGradMixin): cell_forces: torch.Tensor cell_masses: torch.Tensor - _system_attributes: ClassVar[set[str]] = ( + _system_attributes = ( GDState._system_attributes # noqa: SLF001 | DeformGradMixin._system_attributes # noqa: SLF001 | { @@ -233,7 +233,7 @@ class UnitCellGDState(GDState, DeformGradMixin): "cell_masses", } ) - _global_attributes: ClassVar[set[str]] = ( + _global_attributes = ( GDState._global_attributes | {"hydrostatic_strain", "constant_volume"} # noqa: SLF001 ) @@ -524,8 +524,8 @@ class FireState(SimState): alpha: torch.Tensor n_pos: torch.Tensor - _atom_attributes: ClassVar[set[str]] = md_atom_attributes - _system_attributes: ClassVar[set[str]] = ( + _atom_attributes = md_atom_attributes + _system_attributes = ( SimState._system_attributes # noqa: SLF001 | { "energy", @@ -745,9 +745,9 @@ class UnitCellFireState(SimState, DeformGradMixin): alpha: torch.Tensor n_pos: torch.Tensor - _atom_attributes: ClassVar[set[str]] = md_atom_attributes - _system_attributes: ClassVar[set[str]] = _fire_system_attributes - _global_attributes: ClassVar[set[str]] = _fire_global_attributes + _atom_attributes = md_atom_attributes + _system_attributes = _fire_system_attributes + _global_attributes = _fire_global_attributes def unit_cell_fire( @@ -1037,9 +1037,9 @@ class FrechetCellFIREState(SimState, DeformGradMixin): alpha: torch.Tensor n_pos: torch.Tensor - _atom_attributes: ClassVar[set[str]] = md_atom_attributes - _system_attributes: ClassVar[set[str]] = _fire_system_attributes - _global_attributes: ClassVar[set[str]] = _fire_global_attributes + _atom_attributes = md_atom_attributes + _system_attributes = _fire_system_attributes + _global_attributes = _fire_global_attributes def frechet_cell_fire( diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 8e6c107a..14164375 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -9,7 +9,7 @@ from collections.abc import Callable from dataclasses import dataclass from itertools import chain -from typing import Any, ClassVar +from typing import Any import torch from tqdm import tqdm @@ -541,10 +541,10 @@ class StaticState(type(state)): forces: torch.Tensor stress: torch.Tensor - _atom_attributes: ClassVar[set[str]] = ( + _atom_attributes = ( state._atom_attributes | {"forces"} # noqa: SLF001 ) - _system_attributes: ClassVar[set[str]] = ( + _system_attributes = ( state._system_attributes | {"energy", "stress"} # noqa: SLF001 ) diff --git a/torch_sim/state.py b/torch_sim/state.py index 91248a0b..912d1606 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -477,13 +477,13 @@ def _assert_all_attributes_have_defined_scope(cls) -> None: vars(cls).get(attr_name), property ) is_method = hasattr(cls, attr_name) and callable(getattr(cls, attr_name)) - is_scope_list = attr_name in [ - "_atom_attributes", - "_system_attributes", - "_global_attributes", - ] + is_class_variable = ( + # Note: _atom_attributes, _system_attributes, and _global_attributes + # are all class variables + typing.get_origin(all_annotations.get(attr_name)) is typing.ClassVar + ) - if is_special_attribute or is_property or is_method or is_scope_list: + if is_special_attribute or is_property or is_method or is_class_variable: continue if attr_name not in all_defined_attributes: From 50fa99dc16f3982761be1cb53598f41000d8a483 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sun, 10 Aug 2025 20:15:00 -0700 Subject: [PATCH 24/32] make md_atom_attributes private --- torch_sim/optimizers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 30a6c814..739496a0 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -33,7 +33,7 @@ MdFlavor = Literal["vv_fire", "ase_fire"] vv_fire_key, ase_fire_key = get_args(MdFlavor) -md_atom_attributes = SimState._atom_attributes | {"forces", "velocities"} # noqa: SLF001 +_md_atom_attributes = SimState._atom_attributes | {"forces", "velocities"} # noqa: SLF001 _fire_system_attributes = ( SimState._system_attributes # noqa: SLF001 | DeformGradMixin._system_attributes # noqa: SLF001 @@ -524,7 +524,7 @@ class FireState(SimState): alpha: torch.Tensor n_pos: torch.Tensor - _atom_attributes = md_atom_attributes + _atom_attributes = _md_atom_attributes _system_attributes = ( SimState._system_attributes # noqa: SLF001 | { @@ -745,7 +745,7 @@ class UnitCellFireState(SimState, DeformGradMixin): alpha: torch.Tensor n_pos: torch.Tensor - _atom_attributes = md_atom_attributes + _atom_attributes = _md_atom_attributes _system_attributes = _fire_system_attributes _global_attributes = _fire_global_attributes @@ -1037,7 +1037,7 @@ class FrechetCellFIREState(SimState, DeformGradMixin): alpha: torch.Tensor n_pos: torch.Tensor - _atom_attributes = md_atom_attributes + _atom_attributes = _md_atom_attributes _system_attributes = _fire_system_attributes _global_attributes = _fire_global_attributes From aee26df61f42d8060b386e99eb82e20ee3f8c4a2 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sun, 10 Aug 2025 20:43:38 -0700 Subject: [PATCH 25/32] fix isinstance type check --- torch_sim/optimizers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 739496a0..2464098b 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -1321,7 +1321,7 @@ def _vv_fire_step( # noqa: C901, PLR0915 state.positions[nan_velocities] ) if is_cell_optimization: - if not isinstance(state, AnyFireCellState): + if not isinstance(state, get_args(AnyFireCellState)): raise ValueError( f"Cell optimization requires one of {get_args(AnyFireCellState)}." ) @@ -1543,7 +1543,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 ) forces = state.forces if is_cell_optimization: - if not isinstance(state, AnyFireCellState): + if not isinstance(state, get_args(AnyFireCellState)): raise ValueError( f"Cell optimization requires one of {get_args(AnyFireCellState)}." ) From 971b7bdd2baf937103228ed274ec777c7e60dcde Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sun, 10 Aug 2025 20:45:42 -0700 Subject: [PATCH 26/32] fix docs --- examples/tutorials/state_tutorial.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/tutorials/state_tutorial.py b/examples/tutorials/state_tutorial.py index 3c705262..5cd43b17 100644 --- a/examples/tutorials/state_tutorial.py +++ b/examples/tutorials/state_tutorial.py @@ -82,9 +82,9 @@ For TorchSim to know which attributes are atomwise, systemwise, and global, each attribute's name is explicitly defined in the `_atom_attributes`, `_system_attributes`, and `_global_attributes`: -_atom_attributes = ("positions", "masses", "atomic_numbers", "system_idx") -_system_attributes = ("cell",) -_global_attributes = ("pbc",) +_atom_attributes = {"positions", "masses", "atomic_numbers", "system_idx"} +_system_attributes = {"cell"} +_global_attributes = {"pbc"} You can use the `get_attrs_for_scope` generator function to analyze a state's properties. This is mostly used internally but can be useful for debugging. From 318e9cfe9e8c7d53f115a9779ed96ad69280f88f Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sun, 10 Aug 2025 20:47:32 -0700 Subject: [PATCH 27/32] cleanup test type annotation --- tests/test_state.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/test_state.py b/tests/test_state.py index 4914278e..67a757fe 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -48,7 +48,7 @@ class ChildState(SimState): attribute_specified_in_scopes: bool attribute_not_specified_in_scopes: bool - _atom_attributes: typing.ClassVar[set[str]] = ( + _atom_attributes = ( SimState._atom_attributes | {"attribute_specified_in_scopes"} # noqa: SLF001 ) @@ -65,12 +65,8 @@ def test_no_duplicate_attributes_in_scopes() -> None: class ChildState(SimState): duplicated_attribute: bool - _system_attributes: typing.ClassVar[set[str]] = ( - SimState._system_attributes | {"duplicated_attribute"} # noqa: SLF001 - ) - _global_attributes: typing.ClassVar[set[str]] = ( - SimState._global_attributes | {"duplicated_attribute"} # noqa: SLF001 - ) + _system_attributes = SimState._system_attributes | {"duplicated_attribute"} # noqa: SLF001 + _global_attributes = SimState._global_attributes | {"duplicated_attribute"} # noqa: SLF001 assert "are declared multiple times" in str(excinfo.value) assert "duplicated_attribute" in str(excinfo.value) @@ -515,7 +511,7 @@ def test_column_vector_cell(si_sim_state: ts.SimState) -> None: class DeformState(SimState, DeformGradMixin): """Test class that combines SimState with DeformGradMixin.""" - _system_attributes: typing.ClassVar[set[str]] = ( + _system_attributes = ( SimState._system_attributes # noqa: SLF001 | DeformGradMixin._system_attributes # noqa: SLF001 ) From d8f67ffba57bb5eda940658065610c5b49815739 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sun, 10 Aug 2025 20:51:39 -0700 Subject: [PATCH 28/32] add back row_vector_cell to deformgrad mixin --- torch_sim/state.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index 912d1606..2b89af48 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -498,8 +498,9 @@ class DeformGradMixin: """Mixin for states that support deformation gradients.""" reference_cell: torch.Tensor + row_vector_cell: torch.Tensor - _system_attributes: ClassVar[set[str]] = {"reference_cell"} + _system_attributes: ClassVar[set[str]] = {"reference_cell", "row_vector_cell"} @property def reference_row_vector_cell(self) -> torch.Tensor: From 6f7439b5237491698c43295ee5f77758df216d40 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sun, 10 Aug 2025 21:06:21 -0700 Subject: [PATCH 29/32] clone row_vector_cell --- torch_sim/optimizers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 2464098b..3cd89274 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -385,6 +385,7 @@ def gd_init( cell=state.cell, pbc=state.pbc, reference_cell=state.cell.clone(), + row_vector_cell=state.row_vector_cell.clone(), cell_factor=cell_factor, hydrostatic_strain=hydrostatic_strain, constant_volume=constant_volume, @@ -936,6 +937,7 @@ def fire_init( cell_masses=cell_masses, # Optimization attributes reference_cell=state.cell.clone(), + row_vector_cell=state.row_vector_cell.clone(), cell_factor=cell_factor, pressure=pressure, dt=dt_start, @@ -1020,7 +1022,6 @@ class FrechetCellFIREState(SimState, DeformGradMixin): stress: torch.Tensor # Optimization-specific attributes - reference_cell: torch.Tensor cell_factor: torch.Tensor pressure: torch.Tensor hydrostatic_strain: bool @@ -1243,6 +1244,7 @@ def fire_init( cell_masses=cell_masses, # Optimization attributes reference_cell=state.cell.clone(), + row_vector_cell=state.row_vector_cell.clone(), cell_factor=cell_factor, pressure=pressure, dt=dt_start, From 77bdc90a4633047337b45ebd1d1128f4ce503033 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Mon, 11 Aug 2025 06:24:54 -0700 Subject: [PATCH 30/32] rm clone of row_vector_cell --- torch_sim/optimizers.py | 3 --- torch_sim/state.py | 3 +-- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 3cd89274..a88d40d9 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -385,7 +385,6 @@ def gd_init( cell=state.cell, pbc=state.pbc, reference_cell=state.cell.clone(), - row_vector_cell=state.row_vector_cell.clone(), cell_factor=cell_factor, hydrostatic_strain=hydrostatic_strain, constant_volume=constant_volume, @@ -937,7 +936,6 @@ def fire_init( cell_masses=cell_masses, # Optimization attributes reference_cell=state.cell.clone(), - row_vector_cell=state.row_vector_cell.clone(), cell_factor=cell_factor, pressure=pressure, dt=dt_start, @@ -1244,7 +1242,6 @@ def fire_init( cell_masses=cell_masses, # Optimization attributes reference_cell=state.cell.clone(), - row_vector_cell=state.row_vector_cell.clone(), cell_factor=cell_factor, pressure=pressure, dt=dt_start, diff --git a/torch_sim/state.py b/torch_sim/state.py index 2b89af48..912d1606 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -498,9 +498,8 @@ class DeformGradMixin: """Mixin for states that support deformation gradients.""" reference_cell: torch.Tensor - row_vector_cell: torch.Tensor - _system_attributes: ClassVar[set[str]] = {"reference_cell", "row_vector_cell"} + _system_attributes: ClassVar[set[str]] = {"reference_cell"} @property def reference_row_vector_cell(self) -> torch.Tensor: From 81f830b9039dcce83c49a2e3fae7cd377a3254e0 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Mon, 11 Aug 2025 07:17:49 -0700 Subject: [PATCH 31/32] try to define row_vector_cell for the typechecker --- torch_sim/state.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torch_sim/state.py b/torch_sim/state.py index 912d1606..2ff4d0ae 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -501,6 +501,11 @@ class DeformGradMixin: _system_attributes: ClassVar[set[str]] = {"reference_cell"} + if TYPE_CHECKING: + # define this under a TYPE_CHECKING block to avoid it being included in the + # dataclass __init__ during runtime + row_vector_cell: torch.Tensor + @property def reference_row_vector_cell(self) -> torch.Tensor: """Get the original unit cell in terms of row vectors.""" From 7ba046b43baf294683006ca5804224b39c1cad4c Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Mon, 11 Aug 2025 08:02:33 -0700 Subject: [PATCH 32/32] skip processing system_idx as recommended by coderabbit --- torch_sim/state.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index 2ff4d0ae..fa898e1c 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -885,12 +885,13 @@ def concatenate_states( # Collect per-atom properties for prop, val in get_attrs_for_scope(state, "per-atom"): - # if hasattr(state, prop): + if prop == "system_idx": + # skip system_idx, it will be handled below + continue per_atom_tensors[prop].append(val) # Collect per-system properties for prop, val in get_attrs_for_scope(state, "per-system"): - # if hasattr(state, prop): per_system_tensors[prop].append(val) # Update system indices