From 95014796dc07ba17da38cf600a5a7154ba18ef9b Mon Sep 17 00:00:00 2001 From: orionarcher Date: Fri, 31 Oct 2025 13:41:11 -0400 Subject: [PATCH 1/2] replace vars(state) with state.attributes --- .../scripts/3_Dynamics/3.10_Hybrid_swap_mc.py | 2 +- examples/tutorials/hybrid_swap_tutorial.py | 2 +- tests/test_optimizer_states.py | 4 +- torch_sim/state.py | 37 +++++++++++++++---- 4 files changed, 33 insertions(+), 12 deletions(-) diff --git a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py index 331d69a1..016c4556 100644 --- a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py +++ b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py @@ -80,7 +80,7 @@ class HybridSwapMCState(ts.SwapMCState, MDState): swap_state = ts.swap_mc_init(state=md_state, model=model) hybrid_state = HybridSwapMCState( - **vars(md_state), + **md_state.attributes, last_permutation=torch.arange( md_state.n_atoms, device=md_state.device, dtype=torch.long ), diff --git a/examples/tutorials/hybrid_swap_tutorial.py b/examples/tutorials/hybrid_swap_tutorial.py index c41a7f0b..8344c6b4 100644 --- a/examples/tutorials/hybrid_swap_tutorial.py +++ b/examples/tutorials/hybrid_swap_tutorial.py @@ -129,7 +129,7 @@ class HybridSwapMCState(SwapMCState, MDState): # Create hybrid state combining both hybrid_state = HybridSwapMCState( - **vars(md_state), + **md_state.attributes, last_permutation=torch.arange( md_state.n_atoms, device=md_state.device, dtype=torch.long ), diff --git a/tests/test_optimizer_states.py b/tests/test_optimizer_states.py index bdd002e4..eee40875 100644 --- a/tests/test_optimizer_states.py +++ b/tests/test_optimizer_states.py @@ -34,7 +34,7 @@ def optim_data() -> dict: def test_optim_state_init(sim_state: SimState, optim_data: dict) -> None: """Test OptimState initialization.""" - state = OptimState(**vars(sim_state), **optim_data) + state = OptimState(**sim_state.attributes, **optim_data) assert torch.equal(state.forces, optim_data["forces"]) assert torch.equal(state.energy, optim_data["energy"]) assert torch.equal(state.stress, optim_data["stress"]) @@ -51,7 +51,7 @@ def test_fire_state_custom_values(sim_state: SimState, optim_data: dict) -> None "n_pos": torch.tensor([5], dtype=torch.int32), } - state = FireState(**vars(sim_state), **optim_data, **fire_data) + state = FireState(**sim_state.attributes, **optim_data, **fire_data) assert torch.equal(state.velocities, fire_data["velocities"]) assert torch.equal(state.dt, fire_data["dt"]) diff --git a/torch_sim/state.py b/torch_sim/state.py index 645761d6..2538dcc2 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -206,6 +206,16 @@ def volume(self) -> torch.Tensor: """Volume of the system.""" return torch.det(self.cell) + @property + def attributes(self) -> dict[str, torch.Tensor]: + """Get all public attributes of the state.""" + return { + attr: getattr(self, attr) + for attr in self._atom_attributes + | self._system_attributes + | self._global_attributes + } + @property def column_vector_cell(self) -> torch.Tensor: """Unit cell following the column vector convention.""" @@ -244,7 +254,7 @@ def clone(self) -> Self: SimState: A new SimState object with the same properties as the original """ attrs = {} - for attr_name, attr_value in vars(self).items(): + for attr_name, attr_value in self.attributes.items(): if isinstance(attr_value, torch.Tensor): attrs[attr_name] = attr_value.clone() else: @@ -278,7 +288,7 @@ def from_state(cls, state: "SimState", **additional_attrs: Any) -> Self: """ # Copy all attributes from the source state attrs = {} - for attr_name, attr_value in vars(state).items(): + for attr_name, attr_value in state.attributes.items(): if isinstance(attr_value, torch.Tensor): attrs[attr_name] = attr_value.clone() else: @@ -348,7 +358,7 @@ def pop(self, system_indices: int | list[int] | slice | torch.Tensor) -> list[Se modified_state, popped_states = _pop_states(self, system_indices) # Update all attributes of self with the modified state's attributes - for attr_name, attr_value in vars(modified_state).items(): + for attr_name, attr_value in modified_state.attributes.items(): setattr(self, attr_name, attr_value) return popped_states @@ -443,12 +453,17 @@ def _assert_all_attributes_have_defined_scope(cls) -> None: if hasattr(parent_cls, "__annotations__"): all_annotations.update(parent_cls.__annotations__) - attributes_to_check = set(vars(cls)) | set(all_annotations) + # Get class namespace attributes (methods, properties, class vars with values) + class_namespace = vars(cls) + attributes_to_check = set(class_namespace.keys()) | set(all_annotations.keys()) for attr_name in attributes_to_check: is_special_attribute = attr_name.startswith("__") - is_property = attr_name in vars(cls) and isinstance( - vars(cls).get(attr_name), property + is_private_attribute = attr_name.startswith("_") and not attr_name.startswith( + "__" + ) + is_property = attr_name in class_namespace and isinstance( + class_namespace.get(attr_name), property ) is_method = hasattr(cls, attr_name) and callable(getattr(cls, attr_name)) is_class_variable = ( @@ -457,7 +472,13 @@ def _assert_all_attributes_have_defined_scope(cls) -> None: typing.get_origin(all_annotations.get(attr_name)) is typing.ClassVar ) - if is_special_attribute or is_property or is_method or is_class_variable: + if ( + is_special_attribute + or is_private_attribute + or is_property + or is_method + or is_class_variable + ): continue if attr_name not in all_defined_attributes: @@ -573,7 +594,7 @@ def _state_to_device[T: SimState]( if dtype is None: dtype = state.dtype - attrs = vars(state) + attrs = state.attributes for attr_name, attr_value in attrs.items(): if isinstance(attr_value, torch.Tensor): attrs[attr_name] = attr_value.to(device=device) From f1b5827cdc167c67c80bed55b193c96a352750ae Mon Sep 17 00:00:00 2001 From: orionarcher Date: Fri, 31 Oct 2025 13:49:36 -0400 Subject: [PATCH 2/2] nit --- torch_sim/state.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index 2538dcc2..94881810 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -459,9 +459,7 @@ def _assert_all_attributes_have_defined_scope(cls) -> None: for attr_name in attributes_to_check: is_special_attribute = attr_name.startswith("__") - is_private_attribute = attr_name.startswith("_") and not attr_name.startswith( - "__" - ) + is_private_attribute = attr_name.startswith("_") and not is_special_attribute is_property = attr_name in class_namespace and isinstance( class_namespace.get(attr_name), property )