Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,3 +619,34 @@ class DerivedState(SimState):
assert "is not allowed to be of type 'torch.Tensor | None' because torch.cat" in str(
exc_info.value
)


def test_state_to_device_no_side_effects(si_sim_state: SimState) -> None:
"""Test that SimState.to() doesn't modify the original state."""
# Store original values
original_positions = si_sim_state.positions.clone()
original_dtype = si_sim_state.dtype
original_device = si_sim_state.device

# Convert to different dtype
new_state = si_sim_state.to(dtype=torch.float64)

# Verify original state is unchanged
assert torch.allclose(si_sim_state.positions, original_positions), (
"Original state was modified!"
)
assert si_sim_state.dtype == original_dtype, "Original state dtype was changed!"
assert si_sim_state.device == original_device, "Original state device was changed!"
assert si_sim_state is not new_state, "New state is not a different object!"
assert new_state.dtype == torch.float64, "New state doesn't have correct dtype!"

# Test device conversion
if torch.cuda.is_available():
new_state_gpu = si_sim_state.to(device=torch.device("cuda"))
assert si_sim_state.device == original_device, (
"Original state device was changed!"
)
assert new_state_gpu.device.type == "cuda", (
"New state doesn't have correct device!"
)
assert si_sim_state is not new_state_gpu, "New state is not a different object!"
8 changes: 4 additions & 4 deletions torch_sim/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ def to(
Returns:
SimState: A new SimState with tensors on the specified device and dtype
"""
return state_to_device(self, device, dtype)
return _state_to_device(self, device, dtype)

def __getitem__(self, system_indices: int | list[int] | slice | torch.Tensor) -> Self:
"""Enable standard Python indexing syntax for slicing batches.
Expand Down Expand Up @@ -552,7 +552,7 @@ def _normalize_system_indices(
raise TypeError(f"Unsupported index type: {type(system_indices)}")


def state_to_device[T: SimState](
def _state_to_device[T: SimState](
state: T, device: torch.device | None = None, dtype: torch.dtype | None = None
) -> T:
"""Convert the SimState to a new device and dtype.
Expand Down Expand Up @@ -856,7 +856,7 @@ def concatenate_states[T: SimState]( # noqa: C901
for state in states:
# Move state to target device if needed
if state.device != target_device:
state = state_to_device(state, target_device)
state = state.to(target_device)

# Collect per-atom properties
for prop, val in get_attrs_for_scope(state, "per-atom"):
Expand Down Expand Up @@ -919,7 +919,7 @@ def initialize_state(
# TODO: create a way to pass velocities from pmg and ase

if isinstance(system, SimState):
return state_to_device(system, device, dtype)
return system.clone().to(device, dtype)

if isinstance(system, list | tuple) and all(isinstance(s, SimState) for s in system):
if not all(state.n_systems == 1 for state in system):
Expand Down
Loading