From 2d77db5fef7e82c0df40aed6c47a1f66cc4d92f9 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Thu, 31 Jul 2025 19:54:08 -0700 Subject: [PATCH 01/12] we can properly identify variables that are tensor | None --- torch_sim/state.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/torch_sim/state.py b/torch_sim/state.py index ce21ef9b..1e9905fd 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -6,6 +6,7 @@ import copy import importlib +import typing import warnings from dataclasses import dataclass, field from typing import TYPE_CHECKING, Literal, Self @@ -372,6 +373,32 @@ def __getitem__(self, system_indices: int | list[int] | slice | torch.Tensor) -> return _slice_state(self, system_indices) + def __init_subclass__(cls, **kwargs) -> None: + """Enforce that all derived states cannot have tensor attributes that can also be + None. This is because torch.concatenate cannot concat between a tensor and a None. + See https://github.com/Radical-AI/torch-sim/pull/219 for more details. + """ + # We need to use get_type_hints to correctly inspect the types + type_hints = typing.get_type_hints(cls) + for attr_name, attr_type in type_hints.items(): + origin = typing.get_origin(attr_type) + + is_union = origin is typing.Union + if not is_union and origin is not None: + # For Python 3.10+ `|` syntax, origin is types.UnionType + # We check by name to be robust against module reloading/patching issues + is_union = origin.__module__ == "types" and origin.__name__ == "UnionType" + if is_union: + args = typing.get_args(attr_type) + if torch.Tensor in args and type(None) in args: + raise TypeError( + f"Attribute '{attr_name}' in class '{cls.__name__}' is not " + "allowed to be of type 'torch.Tensor | None'. " + "Optional tensor attributes are disallowed in SimState " + "subclasses to prevent concatenation errors." + ) + super().__init_subclass__(**kwargs) + class DeformGradMixin: """Mixin for states that support deformation gradients.""" From 6767059d585f98ad59f204174fbee7b65f2a6888 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Thu, 31 Jul 2025 20:07:24 -0700 Subject: [PATCH 02/12] simstate properly handles system_idx_init now --- torch_sim/state.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index 1e9905fd..e88bef02 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -8,7 +8,7 @@ import importlib import typing import warnings -from dataclasses import dataclass, field +from dataclasses import InitVar, dataclass from typing import TYPE_CHECKING, Literal, Self import torch @@ -82,9 +82,10 @@ class SimState: cell: torch.Tensor pbc: bool # TODO: do all calculators support mixed pbc? atomic_numbers: torch.Tensor - system_idx: torch.Tensor | None = field(default=None, kw_only=True) + system_idx: torch.Tensor + system_idx_init: InitVar[torch.Tensor | None] - def __post_init__(self) -> None: + def __post_init__(self, system_idx_init: torch.Tensor | None) -> None: """Validate and process the state after initialization.""" # data validation and fill system_idx # should make pbc a tensor here @@ -108,11 +109,9 @@ def __post_init__(self) -> None: f"masses {shapes[1]}, atomic_numbers {shapes[2]}" ) - if self.cell.ndim != 3 and self.system_idx is None: - self.cell = self.cell.unsqueeze(0) - - if self.cell.shape[-2:] != (3, 3): - raise ValueError("Cell must have shape (n_systems, 3, 3)") + if system_idx_init is not None and self.system_idx is None: + # we check if system_idx is none to prevent overriding system_idx^ + self.system_idx = system_idx_init if self.system_idx is None: self.system_idx = torch.zeros( @@ -126,6 +125,12 @@ def __post_init__(self) -> None: if not torch.all(counts == torch.bincount(self.system_idx)): raise ValueError("System indices must be unique consecutive integers") + if self.cell.ndim != 3 and self.system_idx is None: + self.cell = self.cell.unsqueeze(0) + + if self.cell.shape[-2:] != (3, 3): + raise ValueError("Cell must have shape (n_systems, 3, 3)") + if self.cell.shape[0] != self.n_systems: raise ValueError( f"Cell must have shape (n_systems, 3, 3), got {self.cell.shape}" From a1fc8b5d8c695d52eb9fea530c27dbbf770a1f21 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Thu, 31 Jul 2025 20:10:50 -0700 Subject: [PATCH 03/12] properly set the system idx field init = false --- torch_sim/io.py | 2 +- torch_sim/optimizers.py | 4 ++-- torch_sim/state.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/torch_sim/io.py b/torch_sim/io.py index 8f61bd49..3ead1849 100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -241,7 +241,7 @@ def atoms_to_state( cell=cell, pbc=all(atoms_list[0].pbc), atomic_numbers=atomic_numbers, - system_idx=system_idx, + system_idx_init=system_idx, ) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 98a83d64..91eb5487 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -474,7 +474,7 @@ class FireState(SimState): # Required attributes not in SimState forces: torch.Tensor energy: torch.Tensor - velocities: torch.Tensor | None + velocities: torch.Tensor # FIRE algorithm parameters dt: torch.Tensor @@ -965,7 +965,7 @@ class FrechetCellFIREState(SimState, DeformGradMixin): # Cell attributes cell_positions: torch.Tensor - cell_velocities: torch.Tensor | None + cell_velocities: torch.Tensor cell_forces: torch.Tensor cell_masses: torch.Tensor diff --git a/torch_sim/state.py b/torch_sim/state.py index e88bef02..a281722f 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -8,7 +8,7 @@ import importlib import typing import warnings -from dataclasses import InitVar, dataclass +from dataclasses import InitVar, dataclass, field from typing import TYPE_CHECKING, Literal, Self import torch @@ -82,7 +82,7 @@ class SimState: cell: torch.Tensor pbc: bool # TODO: do all calculators support mixed pbc? atomic_numbers: torch.Tensor - system_idx: torch.Tensor + system_idx: torch.Tensor = field(init=False) system_idx_init: InitVar[torch.Tensor | None] def __post_init__(self, system_idx_init: torch.Tensor | None) -> None: From 80cc62aeadf6099c753c9c9d42535234531fb914 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Thu, 31 Jul 2025 20:27:00 -0700 Subject: [PATCH 04/12] ran into problem with using the system_index hack --- torch_sim/io.py | 2 +- torch_sim/optimizers.py | 2 +- torch_sim/state.py | 11 ++++------- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/torch_sim/io.py b/torch_sim/io.py index 3ead1849..350d2993 100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -241,7 +241,7 @@ def atoms_to_state( cell=cell, pbc=all(atoms_list[0].pbc), atomic_numbers=atomic_numbers, - system_idx_init=system_idx, + system_index=system_idx, ) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 91eb5487..5f16850f 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -587,7 +587,7 @@ def fire_init( masses=state.masses.clone(), cell=state.cell.clone(), atomic_numbers=state.atomic_numbers.clone(), - system_idx=state.system_idx.clone(), + system_index=state.system_idx.clone(), pbc=state.pbc, velocities=None, forces=forces, diff --git a/torch_sim/state.py b/torch_sim/state.py index a281722f..5722eecd 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -83,9 +83,9 @@ class SimState: pbc: bool # TODO: do all calculators support mixed pbc? atomic_numbers: torch.Tensor system_idx: torch.Tensor = field(init=False) - system_idx_init: InitVar[torch.Tensor | None] + system_index: InitVar[torch.Tensor | None] - def __post_init__(self, system_idx_init: torch.Tensor | None) -> None: + def __post_init__(self, system_index: torch.Tensor | None) -> None: """Validate and process the state after initialization.""" # data validation and fill system_idx # should make pbc a tensor here @@ -109,15 +109,12 @@ def __post_init__(self, system_idx_init: torch.Tensor | None) -> None: f"masses {shapes[1]}, atomic_numbers {shapes[2]}" ) - if system_idx_init is not None and self.system_idx is None: - # we check if system_idx is none to prevent overriding system_idx^ - self.system_idx = system_idx_init - - if self.system_idx is None: + if system_index is None: self.system_idx = torch.zeros( self.n_atoms, device=self.device, dtype=torch.int64 ) else: + self.system_idx = system_index # assert that system indices are unique consecutive integers # TODO(curtis): I feel like this logic is not reliable. # I'll come up with something better later. From 5cf6b76fd1cdb1b657ad6d3a67ea983e4493bf2a Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 2 Aug 2025 12:14:14 -0700 Subject: [PATCH 05/12] init subclass runs properly --- torch_sim/io.py | 2 +- torch_sim/optimizers.py | 2 +- torch_sim/state.py | 86 +++++++++++++++++++++++++++++++++-------- 3 files changed, 72 insertions(+), 18 deletions(-) diff --git a/torch_sim/io.py b/torch_sim/io.py index 350d2993..bdfa4067 100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -241,7 +241,7 @@ def atoms_to_state( cell=cell, pbc=all(atoms_list[0].pbc), atomic_numbers=atomic_numbers, - system_index=system_idx, + init_system_idx=system_idx, ) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 5f16850f..5e570893 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -587,7 +587,7 @@ def fire_init( masses=state.masses.clone(), cell=state.cell.clone(), atomic_numbers=state.atomic_numbers.clone(), - system_index=state.system_idx.clone(), + init_system_idx=state.system_idx.clone(), pbc=state.pbc, velocities=None, forces=forces, diff --git a/torch_sim/state.py b/torch_sim/state.py index 5722eecd..67b40e12 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -6,10 +6,11 @@ import copy import importlib +import inspect import typing import warnings from dataclasses import InitVar, dataclass, field -from typing import TYPE_CHECKING, Literal, Self +from typing import TYPE_CHECKING, Literal, Self, TypeVar import torch @@ -83,9 +84,9 @@ class SimState: pbc: bool # TODO: do all calculators support mixed pbc? atomic_numbers: torch.Tensor system_idx: torch.Tensor = field(init=False) - system_index: InitVar[torch.Tensor | None] + init_system_idx: InitVar[torch.Tensor | None] - def __post_init__(self, system_index: torch.Tensor | None) -> None: + def __post_init__(self, init_system_idx: torch.Tensor | None) -> None: """Validate and process the state after initialization.""" # data validation and fill system_idx # should make pbc a tensor here @@ -109,12 +110,12 @@ def __post_init__(self, system_index: torch.Tensor | None) -> None: f"masses {shapes[1]}, atomic_numbers {shapes[2]}" ) - if system_index is None: + if init_system_idx is None: self.system_idx = torch.zeros( self.n_atoms, device=self.device, dtype=torch.int64 ) else: - self.system_idx = system_index + self.system_idx = init_system_idx # assert that system indices are unique consecutive integers # TODO(curtis): I feel like this logic is not reliable. # I'll come up with something better later. @@ -275,7 +276,7 @@ def clone(self) -> Self: else: attrs[attr_name] = copy.deepcopy(attr_value) - return self.__class__(**attrs) + return construct_state(self, attrs) def to_atoms(self) -> list["Atoms"]: """Convert the SimState to a list of ASE Atoms objects. @@ -382,8 +383,8 @@ def __init_subclass__(cls, **kwargs) -> None: """ # We need to use get_type_hints to correctly inspect the types type_hints = typing.get_type_hints(cls) - for attr_name, attr_type in type_hints.items(): - origin = typing.get_origin(attr_type) + for attr_name, attr_typehint in type_hints.items(): + origin = typing.get_origin(attr_typehint) is_union = origin is typing.Union if not is_union and origin is not None: @@ -391,14 +392,43 @@ def __init_subclass__(cls, **kwargs) -> None: # We check by name to be robust against module reloading/patching issues is_union = origin.__module__ == "types" and origin.__name__ == "UnionType" if is_union: - args = typing.get_args(attr_type) + args = typing.get_args(attr_typehint) if torch.Tensor in args and type(None) in args: raise TypeError( f"Attribute '{attr_name}' in class '{cls.__name__}' is not " "allowed to be of type 'torch.Tensor | None'. " "Optional tensor attributes are disallowed in SimState " - "subclasses to prevent concatenation errors." + "subclasses to prevent concatenation errors.\n" + "If this attribute will take on a default value in the " + "post_init method, please use an InitVar for that attribute " + "but with a prepended 'init_' to the name. (e.g. init_system_idx)" ) + + # Validate InitVar fields + for attr_name, attr_typehint in cls.__annotations__.items(): + # 1) validate InitVar fields + if type(attr_typehint) is InitVar: + # make sure its prefix is "init_" + if not attr_name.startswith("init_"): + raise TypeError( + f"Attribute '{attr_name}' in class '{cls.__name__}' is not " + "allowed to be an InitVar. It must be prefixed with 'init_'" + ) + # make sure there is a corresponding non-InitVar field + non_init_attr_name = attr_name.removeprefix("init_") + if non_init_attr_name not in type_hints: + raise TypeError( + f"Attribute '{attr_name}' in class '{cls.__name__}' is not " + "allowed to be an InitVar. It must have a corresponding " + f"non-InitVar field {non_init_attr_name}" + ) + + # 2) forbid non init vars to have a "init_" prefix + elif attr_name.startswith("init_"): + raise TypeError( + f"Attribute '{attr_name}' in class '{cls.__name__}' is not " + "allowed to have an 'init_' prefix as it's a non-InitVar field." + ) super().__init_subclass__(**kwargs) @@ -519,7 +549,7 @@ def state_to_device( attrs["masses"] = attrs["masses"].to(dtype=dtype) attrs["cell"] = attrs["cell"].to(dtype=dtype) attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int) - return type(state)(**attrs) + return construct_state(state, attrs) def infer_property_scope( @@ -748,11 +778,35 @@ def _split_state( # Add the global attributes **attrs["global"], } - states.append(type(state)(**system_attrs)) + states.append(construct_state(state, system_attrs)) return states +SimStateT = TypeVar("SimStateT", bound=SimState) + + +def construct_state( + old_state: SimStateT, + new_state_attrs: dict[str, typing.Any], +) -> SimStateT: + """Construct a new state from an old state and new state parameters.""" + # 1) process the attrs so they are the init params + processed_params = {} + for param in inspect.signature(old_state.__class__).parameters: + if param.startswith("init_"): + # this is an InitVar field + # we need to rename the corresponding field in system_attrs to have + # an "init_" prefix + non_init_attr_name = param.removeprefix("init_") + processed_params[param] = new_state_attrs[non_init_attr_name] + else: + processed_params[param] = new_state_attrs[param] + + # 2) construct the new state + return type(old_state)(**processed_params) + + def _pop_states( state: SimState, pop_indices: list[int] | torch.Tensor, @@ -799,10 +853,10 @@ def _pop_states( pop_attrs = _filter_attrs_by_mask(attrs, pop_atom_mask, pop_system_mask) # Create the keep state - keep_state = type(state)(**keep_attrs) + keep_state = construct_state(state, keep_attrs) # Create and split the pop state - pop_state = type(state)(**pop_attrs) + pop_state = construct_state(state, pop_attrs) pop_states = _split_state(pop_state, ambiguous_handling) return keep_state, pop_states @@ -852,7 +906,7 @@ def _slice_state( filtered_attrs = _filter_attrs_by_mask(attrs, atom_mask, system_mask) # Create the sliced state - return type(state)(**filtered_attrs) + return construct_state(state, filtered_attrs) def concatenate_states( @@ -941,7 +995,7 @@ def concatenate_states( concatenated["system_idx"] = torch.cat(new_system_indices) # Create a new instance of the same class - return state_class(**concatenated) + return construct_state(first_state, concatenated) def initialize_state( From fcb9e7b7e7223937f00d722c520d640c3c88c6d0 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 2 Aug 2025 12:24:51 -0700 Subject: [PATCH 06/12] revert setting velocities to | None --- torch_sim/optimizers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 5e570893..178a297e 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -670,11 +670,11 @@ class UnitCellFireState(SimState, DeformGradMixin): forces: torch.Tensor energy: torch.Tensor stress: torch.Tensor - velocities: torch.Tensor + velocities: torch.Tensor | None # Cell attributes cell_positions: torch.Tensor - cell_velocities: torch.Tensor + cell_velocities: torch.Tensor | None cell_forces: torch.Tensor cell_masses: torch.Tensor From fa5c697ada25109187a92c91250a6c8ff1818721 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 2 Aug 2025 12:38:34 -0700 Subject: [PATCH 07/12] cleanup --- torch_sim/state.py | 30 ++++-------------------------- 1 file changed, 4 insertions(+), 26 deletions(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index 67b40e12..0cf424b0 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -377,32 +377,8 @@ def __getitem__(self, system_indices: int | list[int] | slice | torch.Tensor) -> return _slice_state(self, system_indices) def __init_subclass__(cls, **kwargs) -> None: - """Enforce that all derived states cannot have tensor attributes that can also be - None. This is because torch.concatenate cannot concat between a tensor and a None. - See https://github.com/Radical-AI/torch-sim/pull/219 for more details. - """ - # We need to use get_type_hints to correctly inspect the types + """Enforce that all subclasses have valid InitVar fields.""" type_hints = typing.get_type_hints(cls) - for attr_name, attr_typehint in type_hints.items(): - origin = typing.get_origin(attr_typehint) - - is_union = origin is typing.Union - if not is_union and origin is not None: - # For Python 3.10+ `|` syntax, origin is types.UnionType - # We check by name to be robust against module reloading/patching issues - is_union = origin.__module__ == "types" and origin.__name__ == "UnionType" - if is_union: - args = typing.get_args(attr_typehint) - if torch.Tensor in args and type(None) in args: - raise TypeError( - f"Attribute '{attr_name}' in class '{cls.__name__}' is not " - "allowed to be of type 'torch.Tensor | None'. " - "Optional tensor attributes are disallowed in SimState " - "subclasses to prevent concatenation errors.\n" - "If this attribute will take on a default value in the " - "post_init method, please use an InitVar for that attribute " - "but with a prepended 'init_' to the name. (e.g. init_system_idx)" - ) # Validate InitVar fields for attr_name, attr_typehint in cls.__annotations__.items(): @@ -790,7 +766,9 @@ def construct_state( old_state: SimStateT, new_state_attrs: dict[str, typing.Any], ) -> SimStateT: - """Construct a new state from an old state and new state parameters.""" + """Construct a new state of the same class as the old state with the specified + attrs. + """ # 1) process the attrs so they are the init params processed_params = {} for param in inspect.signature(old_state.__class__).parameters: From 904b2686894363977320a90dc2f7e5b2a6c80ced Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 2 Aug 2025 12:39:03 -0700 Subject: [PATCH 08/12] more cleanup --- torch_sim/optimizers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 178a297e..b045af99 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -474,7 +474,7 @@ class FireState(SimState): # Required attributes not in SimState forces: torch.Tensor energy: torch.Tensor - velocities: torch.Tensor + velocities: torch.Tensor | None # FIRE algorithm parameters dt: torch.Tensor From afd7aaa996745a4574908251b70cdce738cb53ff Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 2 Aug 2025 12:40:13 -0700 Subject: [PATCH 09/12] fix types of firestates --- torch_sim/optimizers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index b045af99..f8877764 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -670,11 +670,11 @@ class UnitCellFireState(SimState, DeformGradMixin): forces: torch.Tensor energy: torch.Tensor stress: torch.Tensor - velocities: torch.Tensor | None + velocities: torch.Tensor # Cell attributes cell_positions: torch.Tensor - cell_velocities: torch.Tensor | None + cell_velocities: torch.Tensor cell_forces: torch.Tensor cell_masses: torch.Tensor @@ -965,7 +965,7 @@ class FrechetCellFIREState(SimState, DeformGradMixin): # Cell attributes cell_positions: torch.Tensor - cell_velocities: torch.Tensor + cell_velocities: torch.Tensor | None cell_forces: torch.Tensor cell_masses: torch.Tensor From e5ca217339d8894975c9d5ab1a168a197b60f7dc Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 2 Aug 2025 12:43:07 -0700 Subject: [PATCH 10/12] document system_idx better --- torch_sim/state.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index 0cf424b0..a2b00e1c 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -49,9 +49,11 @@ class SimState: used by ASE. pbc (bool): Boolean indicating whether to use periodic boundary conditions atomic_numbers (torch.Tensor): Atomic numbers with shape (n_atoms,) - system_idx (torch.Tensor, optional): Maps each atom index to its system index. - Has shape (n_atoms,), defaults to None, must be unique consecutive - integers starting from 0 + system_idx (torch.Tensor): Maps each atom index to its system index. + Has shape (n_atoms,), must be unique consecutive integers starting from 0. + This attribute is initialized by the init_system_idx parameter in the + constructor. If init_system_idx is not provided, it is initialized to + zeros. Properties: wrap_positions (torch.Tensor): Positions wrapped according to periodic boundary From 456ad46c164402d47fea5759a94a240413d57698 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 2 Aug 2025 12:47:46 -0700 Subject: [PATCH 11/12] cleanup _construct_state --- torch_sim/state.py | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index a2b00e1c..ce026506 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -10,12 +10,12 @@ import typing import warnings from dataclasses import InitVar, dataclass, field -from typing import TYPE_CHECKING, Literal, Self, TypeVar +from typing import TYPE_CHECKING, Literal, Self import torch import torch_sim as ts -from torch_sim.typing import StateLike +from torch_sim.typing import SimStateVar, StateLike if TYPE_CHECKING: @@ -278,7 +278,7 @@ def clone(self) -> Self: else: attrs[attr_name] = copy.deepcopy(attr_value) - return construct_state(self, attrs) + return _construct_state(self, attrs) def to_atoms(self) -> list["Atoms"]: """Convert the SimState to a list of ASE Atoms objects. @@ -527,7 +527,7 @@ def state_to_device( attrs["masses"] = attrs["masses"].to(dtype=dtype) attrs["cell"] = attrs["cell"].to(dtype=dtype) attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int) - return construct_state(state, attrs) + return _construct_state(state, attrs) def infer_property_scope( @@ -756,20 +756,27 @@ def _split_state( # Add the global attributes **attrs["global"], } - states.append(construct_state(state, system_attrs)) + states.append(_construct_state(state, system_attrs)) return states -SimStateT = TypeVar("SimStateT", bound=SimState) - - -def construct_state( - old_state: SimStateT, +def _construct_state( + old_state: SimStateVar, new_state_attrs: dict[str, typing.Any], -) -> SimStateT: +) -> SimStateVar: """Construct a new state of the same class as the old state with the specified attrs. + + Args: + old_state (SimStateVar): We will construct a new state of the same class as this + one + new_state_attrs (dict[str, typing.Any]): The attributes to use to construct + the new state + + Returns: + SimStateVar: A new state of the same class as the old state with the specified + attributes """ # 1) process the attrs so they are the init params processed_params = {} @@ -833,10 +840,10 @@ def _pop_states( pop_attrs = _filter_attrs_by_mask(attrs, pop_atom_mask, pop_system_mask) # Create the keep state - keep_state = construct_state(state, keep_attrs) + keep_state = _construct_state(state, keep_attrs) # Create and split the pop state - pop_state = construct_state(state, pop_attrs) + pop_state = _construct_state(state, pop_attrs) pop_states = _split_state(pop_state, ambiguous_handling) return keep_state, pop_states @@ -886,7 +893,7 @@ def _slice_state( filtered_attrs = _filter_attrs_by_mask(attrs, atom_mask, system_mask) # Create the sliced state - return construct_state(state, filtered_attrs) + return _construct_state(state, filtered_attrs) def concatenate_states( @@ -975,7 +982,7 @@ def concatenate_states( concatenated["system_idx"] = torch.cat(new_system_indices) # Create a new instance of the same class - return construct_state(first_state, concatenated) + return _construct_state(first_state, concatenated) def initialize_state( From 5efbb7ca983fd1d67c1f66d552670c0f3166ac32 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 2 Aug 2025 12:49:13 -0700 Subject: [PATCH 12/12] rm state_class --- torch_sim/state.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index ce026506..c0bdfbf8 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -924,8 +924,7 @@ def concatenate_states( first_state = states[0] # Ensure all states are of the same class - state_class = type(first_state) - if not all(isinstance(state, state_class) for state in states): + if not all(isinstance(state, type(first_state)) for state in states): raise TypeError("All states must be of the same type") # Use the target device or default to the first state's device