diff --git a/tests/test_state.py b/tests/test_state.py index 1657b382..1b293668 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -619,3 +619,34 @@ class DerivedState(SimState): assert "is not allowed to be of type 'torch.Tensor | None' because torch.cat" in str( exc_info.value ) + + +def test_state_to_device_no_side_effects(si_sim_state: SimState) -> None: + """Test that SimState.to() doesn't modify the original state.""" + # Store original values + original_positions = si_sim_state.positions.clone() + original_dtype = si_sim_state.dtype + original_device = si_sim_state.device + + # Convert to different dtype + new_state = si_sim_state.to(dtype=torch.float64) + + # Verify original state is unchanged + assert torch.allclose(si_sim_state.positions, original_positions), ( + "Original state was modified!" + ) + assert si_sim_state.dtype == original_dtype, "Original state dtype was changed!" + assert si_sim_state.device == original_device, "Original state device was changed!" + assert si_sim_state is not new_state, "New state is not a different object!" + assert new_state.dtype == torch.float64, "New state doesn't have correct dtype!" + + # Test device conversion + if torch.cuda.is_available(): + new_state_gpu = si_sim_state.to(device=torch.device("cuda")) + assert si_sim_state.device == original_device, ( + "Original state device was changed!" + ) + assert new_state_gpu.device.type == "cuda", ( + "New state doesn't have correct device!" + ) + assert si_sim_state is not new_state_gpu, "New state is not a different object!" diff --git a/torch_sim/state.py b/torch_sim/state.py index a04fa5d3..645761d6 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -367,7 +367,7 @@ def to( Returns: SimState: A new SimState with tensors on the specified device and dtype """ - return state_to_device(self, device, dtype) + return _state_to_device(self, device, dtype) def __getitem__(self, system_indices: int | list[int] | slice | torch.Tensor) -> Self: """Enable standard Python indexing syntax for slicing batches. @@ -552,7 +552,7 @@ def _normalize_system_indices( raise TypeError(f"Unsupported index type: {type(system_indices)}") -def state_to_device[T: SimState]( +def _state_to_device[T: SimState]( state: T, device: torch.device | None = None, dtype: torch.dtype | None = None ) -> T: """Convert the SimState to a new device and dtype. @@ -856,7 +856,7 @@ def concatenate_states[T: SimState]( # noqa: C901 for state in states: # Move state to target device if needed if state.device != target_device: - state = state_to_device(state, target_device) + state = state.to(target_device) # Collect per-atom properties for prop, val in get_attrs_for_scope(state, "per-atom"): @@ -919,7 +919,7 @@ def initialize_state( # TODO: create a way to pass velocities from pmg and ase if isinstance(system, SimState): - return state_to_device(system, device, dtype) + return system.clone().to(device, dtype) if isinstance(system, list | tuple) and all(isinstance(s, SimState) for s in system): if not all(state.n_systems == 1 for state in system):