Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class HybridSwapMCState(ts.SwapMCState, MDState):

swap_state = ts.swap_mc_init(state=md_state, model=model)
hybrid_state = HybridSwapMCState(
**vars(md_state),
**md_state.attributes,
last_permutation=torch.arange(
md_state.n_atoms, device=md_state.device, dtype=torch.long
),
Expand Down
2 changes: 1 addition & 1 deletion examples/tutorials/hybrid_swap_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class HybridSwapMCState(SwapMCState, MDState):

# Create hybrid state combining both
hybrid_state = HybridSwapMCState(
**vars(md_state),
**md_state.attributes,
last_permutation=torch.arange(
md_state.n_atoms, device=md_state.device, dtype=torch.long
),
Expand Down
4 changes: 2 additions & 2 deletions tests/test_optimizer_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def optim_data() -> dict:

def test_optim_state_init(sim_state: SimState, optim_data: dict) -> None:
"""Test OptimState initialization."""
state = OptimState(**vars(sim_state), **optim_data)
state = OptimState(**sim_state.attributes, **optim_data)
assert torch.equal(state.forces, optim_data["forces"])
assert torch.equal(state.energy, optim_data["energy"])
assert torch.equal(state.stress, optim_data["stress"])
Expand All @@ -51,7 +51,7 @@ def test_fire_state_custom_values(sim_state: SimState, optim_data: dict) -> None
"n_pos": torch.tensor([5], dtype=torch.int32),
}

state = FireState(**vars(sim_state), **optim_data, **fire_data)
state = FireState(**sim_state.attributes, **optim_data, **fire_data)

assert torch.equal(state.velocities, fire_data["velocities"])
assert torch.equal(state.dt, fire_data["dt"])
Expand Down
35 changes: 27 additions & 8 deletions torch_sim/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,16 @@ def volume(self) -> torch.Tensor:
"""Volume of the system."""
return torch.det(self.cell)

@property
def attributes(self) -> dict[str, torch.Tensor]:
"""Get all public attributes of the state."""
return {
attr: getattr(self, attr)
for attr in self._atom_attributes
| self._system_attributes
| self._global_attributes
}

@property
def column_vector_cell(self) -> torch.Tensor:
"""Unit cell following the column vector convention."""
Expand Down Expand Up @@ -244,7 +254,7 @@ def clone(self) -> Self:
SimState: A new SimState object with the same properties as the original
"""
attrs = {}
for attr_name, attr_value in vars(self).items():
for attr_name, attr_value in self.attributes.items():
if isinstance(attr_value, torch.Tensor):
attrs[attr_name] = attr_value.clone()
else:
Expand Down Expand Up @@ -278,7 +288,7 @@ def from_state(cls, state: "SimState", **additional_attrs: Any) -> Self:
"""
# Copy all attributes from the source state
attrs = {}
for attr_name, attr_value in vars(state).items():
for attr_name, attr_value in state.attributes.items():
if isinstance(attr_value, torch.Tensor):
attrs[attr_name] = attr_value.clone()
else:
Expand Down Expand Up @@ -348,7 +358,7 @@ def pop(self, system_indices: int | list[int] | slice | torch.Tensor) -> list[Se
modified_state, popped_states = _pop_states(self, system_indices)

# Update all attributes of self with the modified state's attributes
for attr_name, attr_value in vars(modified_state).items():
for attr_name, attr_value in modified_state.attributes.items():
setattr(self, attr_name, attr_value)

return popped_states
Expand Down Expand Up @@ -443,12 +453,15 @@ def _assert_all_attributes_have_defined_scope(cls) -> None:
if hasattr(parent_cls, "__annotations__"):
all_annotations.update(parent_cls.__annotations__)

attributes_to_check = set(vars(cls)) | set(all_annotations)
# Get class namespace attributes (methods, properties, class vars with values)
class_namespace = vars(cls)
attributes_to_check = set(class_namespace.keys()) | set(all_annotations.keys())

for attr_name in attributes_to_check:
is_special_attribute = attr_name.startswith("__")
is_property = attr_name in vars(cls) and isinstance(
vars(cls).get(attr_name), property
is_private_attribute = attr_name.startswith("_") and not is_special_attribute
is_property = attr_name in class_namespace and isinstance(
class_namespace.get(attr_name), property
)
is_method = hasattr(cls, attr_name) and callable(getattr(cls, attr_name))
is_class_variable = (
Expand All @@ -457,7 +470,13 @@ def _assert_all_attributes_have_defined_scope(cls) -> None:
typing.get_origin(all_annotations.get(attr_name)) is typing.ClassVar
)

if is_special_attribute or is_property or is_method or is_class_variable:
if (
is_special_attribute
or is_private_attribute
or is_property
or is_method
or is_class_variable
):
continue

if attr_name not in all_defined_attributes:
Expand Down Expand Up @@ -573,7 +592,7 @@ def _state_to_device[T: SimState](
if dtype is None:
dtype = state.dtype

attrs = vars(state)
attrs = state.attributes
for attr_name, attr_value in attrs.items():
if isinstance(attr_value, torch.Tensor):
attrs[attr_name] = attr_value.to(device=device)
Expand Down
Loading