From ffdbe3ecd54d4161d5f78fa0d89b118a603cc996 Mon Sep 17 00:00:00 2001 From: Samanvya Tripathi Date: Sun, 19 Oct 2025 18:48:43 -0400 Subject: [PATCH 1/5] Fix state_to_device side effects --- test_fix_demo.py | 118 +++++++++++++++++++++++++++++++++++++++++++++ torch_sim/state.py | 12 +++-- 2 files changed, 125 insertions(+), 5 deletions(-) create mode 100644 test_fix_demo.py diff --git a/test_fix_demo.py b/test_fix_demo.py new file mode 100644 index 00000000..dd8c3272 --- /dev/null +++ b/test_fix_demo.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python3 +""" +Demonstration script showing that the state_to_device side effect fix works correctly. + +This script demonstrates that: +1. SimState.to() now creates a new state without modifying the original +2. concatenate_states no longer modifies input states +3. initialize_state no longer modifies input states +""" + +import torch +import torch_sim as ts + + +def test_state_to_device_fix(): + """Test that SimState.to() doesn't modify the original state.""" + print("Testing SimState.to() side effect fix...") + + # Create a test state + state = ts.SimState( + positions=torch.randn(4, 3), + masses=torch.ones(4), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.ones(4, dtype=torch.long) + ) + + # Store original values + original_positions = state.positions.clone() + original_dtype = state.dtype + + # Convert to different dtype + new_state = state.to(dtype=torch.float64) + + # Verify original state is unchanged + assert torch.allclose(state.positions, original_positions), "Original state was modified!" + assert state.dtype == original_dtype, "Original state dtype was changed!" + assert state is not new_state, "New state is not a different object!" + assert new_state.dtype == torch.float64, "New state doesn't have correct dtype!" + + print("[OK] SimState.to() fix works correctly") + + +def test_concatenate_states_fix(): + """Test that concatenate_states doesn't modify input states.""" + print("Testing concatenate_states side effect fix...") + + # Create two test states + state1 = ts.SimState( + positions=torch.randn(4, 3), + masses=torch.ones(4), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.ones(4, dtype=torch.long) + ) + + state2 = ts.SimState( + positions=torch.randn(6, 3), + masses=torch.ones(6), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.ones(6, dtype=torch.long) + ) + + # Store original values + original_positions1 = state1.positions.clone() + original_positions2 = state2.positions.clone() + + # Concatenate states + concatenated = ts.concatenate_states([state1, state2]) + + # Verify original states are unchanged + assert torch.allclose(state1.positions, original_positions1), "State1 was modified!" + assert torch.allclose(state2.positions, original_positions2), "State2 was modified!" + assert concatenated.n_atoms == 10, "Concatenated state has wrong number of atoms!" + + print("[OK] concatenate_states fix works correctly") + + +def test_initialize_state_fix(): + """Test that initialize_state doesn't modify input state.""" + print("Testing initialize_state side effect fix...") + + # Create a test state + original_state = ts.SimState( + positions=torch.randn(4, 3), + masses=torch.ones(4), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.ones(4, dtype=torch.long) + ) + + # Store original values + original_positions = original_state.positions.clone() + original_dtype = original_state.dtype + + # Initialize from existing state + new_state = ts.initialize_state(original_state, torch.device('cpu'), torch.float64) + + # Verify original state is unchanged + assert torch.allclose(original_state.positions, original_positions), "Original state was modified!" + assert original_state.dtype == original_dtype, "Original state dtype was changed!" + assert original_state is not new_state, "New state is not a different object!" + assert new_state.dtype == torch.float64, "New state doesn't have correct dtype!" + + print("[OK] initialize_state fix works correctly") + + +if __name__ == "__main__": + print("Demonstrating state_to_device side effect fix (Issue #293)") + print("=" * 60) + + test_state_to_device_fix() + test_concatenate_states_fix() + test_initialize_state_fix() + + print("=" * 60) + print("All tests passed! The fix successfully resolves the side effect issues.") diff --git a/torch_sim/state.py b/torch_sim/state.py index a04fa5d3..7da897f7 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -367,7 +367,7 @@ def to( Returns: SimState: A new SimState with tensors on the specified device and dtype """ - return state_to_device(self, device, dtype) + return _state_to_device(self, device, dtype) def __getitem__(self, system_indices: int | list[int] | slice | torch.Tensor) -> Self: """Enable standard Python indexing syntax for slicing batches. @@ -552,7 +552,7 @@ def _normalize_system_indices( raise TypeError(f"Unsupported index type: {type(system_indices)}") -def state_to_device[T: SimState]( +def _state_to_device[T: SimState]( state: T, device: torch.device | None = None, dtype: torch.dtype | None = None ) -> T: """Convert the SimState to a new device and dtype. @@ -573,7 +573,9 @@ def state_to_device[T: SimState]( if dtype is None: dtype = state.dtype - attrs = vars(state) + # Use copy.copy to create a shallow copy of the attributes dict + # This avoids modifying the input state's internal dictionary + attrs = copy.copy(vars(state)) for attr_name, attr_value in attrs.items(): if isinstance(attr_value, torch.Tensor): attrs[attr_name] = attr_value.to(device=device) @@ -856,7 +858,7 @@ def concatenate_states[T: SimState]( # noqa: C901 for state in states: # Move state to target device if needed if state.device != target_device: - state = state_to_device(state, target_device) + state = state.to(target_device) # Collect per-atom properties for prop, val in get_attrs_for_scope(state, "per-atom"): @@ -919,7 +921,7 @@ def initialize_state( # TODO: create a way to pass velocities from pmg and ase if isinstance(system, SimState): - return state_to_device(system, device, dtype) + return system.clone().to(device, dtype) if isinstance(system, list | tuple) and all(isinstance(s, SimState) for s in system): if not all(state.n_systems == 1 for state in system): From 4c5483a5512f8f6548af8a19f608556675615661 Mon Sep 17 00:00:00 2001 From: Samanvya Tripathi Date: Sun, 19 Oct 2025 19:01:58 -0400 Subject: [PATCH 2/5] add unit tests for states --- test_fix_demo.py | 118 -------------------------------------------- tests/test_state.py | 72 +++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 118 deletions(-) delete mode 100644 test_fix_demo.py diff --git a/test_fix_demo.py b/test_fix_demo.py deleted file mode 100644 index dd8c3272..00000000 --- a/test_fix_demo.py +++ /dev/null @@ -1,118 +0,0 @@ -#!/usr/bin/env python3 -""" -Demonstration script showing that the state_to_device side effect fix works correctly. - -This script demonstrates that: -1. SimState.to() now creates a new state without modifying the original -2. concatenate_states no longer modifies input states -3. initialize_state no longer modifies input states -""" - -import torch -import torch_sim as ts - - -def test_state_to_device_fix(): - """Test that SimState.to() doesn't modify the original state.""" - print("Testing SimState.to() side effect fix...") - - # Create a test state - state = ts.SimState( - positions=torch.randn(4, 3), - masses=torch.ones(4), - cell=torch.eye(3).unsqueeze(0), - pbc=True, - atomic_numbers=torch.ones(4, dtype=torch.long) - ) - - # Store original values - original_positions = state.positions.clone() - original_dtype = state.dtype - - # Convert to different dtype - new_state = state.to(dtype=torch.float64) - - # Verify original state is unchanged - assert torch.allclose(state.positions, original_positions), "Original state was modified!" - assert state.dtype == original_dtype, "Original state dtype was changed!" - assert state is not new_state, "New state is not a different object!" - assert new_state.dtype == torch.float64, "New state doesn't have correct dtype!" - - print("[OK] SimState.to() fix works correctly") - - -def test_concatenate_states_fix(): - """Test that concatenate_states doesn't modify input states.""" - print("Testing concatenate_states side effect fix...") - - # Create two test states - state1 = ts.SimState( - positions=torch.randn(4, 3), - masses=torch.ones(4), - cell=torch.eye(3).unsqueeze(0), - pbc=True, - atomic_numbers=torch.ones(4, dtype=torch.long) - ) - - state2 = ts.SimState( - positions=torch.randn(6, 3), - masses=torch.ones(6), - cell=torch.eye(3).unsqueeze(0), - pbc=True, - atomic_numbers=torch.ones(6, dtype=torch.long) - ) - - # Store original values - original_positions1 = state1.positions.clone() - original_positions2 = state2.positions.clone() - - # Concatenate states - concatenated = ts.concatenate_states([state1, state2]) - - # Verify original states are unchanged - assert torch.allclose(state1.positions, original_positions1), "State1 was modified!" - assert torch.allclose(state2.positions, original_positions2), "State2 was modified!" - assert concatenated.n_atoms == 10, "Concatenated state has wrong number of atoms!" - - print("[OK] concatenate_states fix works correctly") - - -def test_initialize_state_fix(): - """Test that initialize_state doesn't modify input state.""" - print("Testing initialize_state side effect fix...") - - # Create a test state - original_state = ts.SimState( - positions=torch.randn(4, 3), - masses=torch.ones(4), - cell=torch.eye(3).unsqueeze(0), - pbc=True, - atomic_numbers=torch.ones(4, dtype=torch.long) - ) - - # Store original values - original_positions = original_state.positions.clone() - original_dtype = original_state.dtype - - # Initialize from existing state - new_state = ts.initialize_state(original_state, torch.device('cpu'), torch.float64) - - # Verify original state is unchanged - assert torch.allclose(original_state.positions, original_positions), "Original state was modified!" - assert original_state.dtype == original_dtype, "Original state dtype was changed!" - assert original_state is not new_state, "New state is not a different object!" - assert new_state.dtype == torch.float64, "New state doesn't have correct dtype!" - - print("[OK] initialize_state fix works correctly") - - -if __name__ == "__main__": - print("Demonstrating state_to_device side effect fix (Issue #293)") - print("=" * 60) - - test_state_to_device_fix() - test_concatenate_states_fix() - test_initialize_state_fix() - - print("=" * 60) - print("All tests passed! The fix successfully resolves the side effect issues.") diff --git a/tests/test_state.py b/tests/test_state.py index 1657b382..7f40d541 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -619,3 +619,75 @@ class DerivedState(SimState): assert "is not allowed to be of type 'torch.Tensor | None' because torch.cat" in str( exc_info.value ) + + +def test_state_to_device_no_side_effects(si_sim_state: SimState) -> None: + """Test that SimState.to() doesn't modify the original state.""" + # Store original values + original_positions = si_sim_state.positions.clone() + original_dtype = si_sim_state.dtype + original_device = si_sim_state.device + + # Convert to different dtype + new_state = si_sim_state.to(dtype=torch.float64) + + # Verify original state is unchanged + assert torch.allclose(si_sim_state.positions, original_positions), "Original state was modified!" + assert si_sim_state.dtype == original_dtype, "Original state dtype was changed!" + assert si_sim_state.device == original_device, "Original state device was changed!" + assert si_sim_state is not new_state, "New state is not a different object!" + assert new_state.dtype == torch.float64, "New state doesn't have correct dtype!" + + # Test device conversion + if torch.cuda.is_available(): + new_state_gpu = si_sim_state.to(device=torch.device("cuda")) + assert si_sim_state.device == original_device, "Original state device was changed!" + assert new_state_gpu.device.type == "cuda", "New state doesn't have correct device!" + assert si_sim_state is not new_state_gpu, "New state is not a different object!" + + +def test_concatenate_states_no_side_effects(si_sim_state: SimState) -> None: + """Test that concatenate_states doesn't modify input states.""" + # Create two copies of the state + state1 = si_sim_state.clone() + state2 = si_sim_state.clone() + + # Store original values + original_positions1 = state1.positions.clone() + original_positions2 = state2.positions.clone() + original_dtype1 = state1.dtype + original_dtype2 = state2.dtype + + # Concatenate states + concatenated = ts.concatenate_states([state1, state2]) + + # Verify input states are unchanged + assert torch.allclose(state1.positions, original_positions1), "First input state was modified!" + assert torch.allclose(state2.positions, original_positions2), "Second input state was modified!" + assert state1.dtype == original_dtype1, "First input state dtype was changed!" + assert state2.dtype == original_dtype2, "Second input state dtype was changed!" + + # Verify concatenated state is correct + assert concatenated.n_systems == 2, "Concatenated state has wrong number of systems!" + assert concatenated.positions.shape[0] == state1.positions.shape[0] + state2.positions.shape[0], "Concatenated state has wrong number of atoms!" + + +def test_initialize_state_no_side_effects(si_sim_state: SimState) -> None: + """Test that initialize_state doesn't modify input state.""" + # Store original values + original_positions = si_sim_state.positions.clone() + original_dtype = si_sim_state.dtype + original_device = si_sim_state.device + + # Initialize from existing state + new_state = ts.initialize_state(si_sim_state, torch.device("cpu"), torch.float64) + + # Verify original state is unchanged + assert torch.allclose(si_sim_state.positions, original_positions), "Original state was modified!" + assert si_sim_state.dtype == original_dtype, "Original state dtype was changed!" + assert si_sim_state.device == original_device, "Original state device was changed!" + assert si_sim_state is not new_state, "New state is not a different object!" + + # Verify new state has correct properties + assert new_state.device.type == "cpu", "New state doesn't have correct device!" + assert new_state.dtype == torch.float64, "New state doesn't have correct dtype!" \ No newline at end of file From 3611bd1b02fba9a54389b99150e148567cedb07d Mon Sep 17 00:00:00 2001 From: Samanvya Tripathi Date: Sun, 19 Oct 2025 22:11:34 -0400 Subject: [PATCH 3/5] add comment --- tests/test_state.py | 54 +++++++++++++++++++++++++++++---------------- 1 file changed, 35 insertions(+), 19 deletions(-) diff --git a/tests/test_state.py b/tests/test_state.py index 7f40d541..4f23636a 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -627,22 +627,28 @@ def test_state_to_device_no_side_effects(si_sim_state: SimState) -> None: original_positions = si_sim_state.positions.clone() original_dtype = si_sim_state.dtype original_device = si_sim_state.device - + # Convert to different dtype new_state = si_sim_state.to(dtype=torch.float64) - + # Verify original state is unchanged - assert torch.allclose(si_sim_state.positions, original_positions), "Original state was modified!" + assert torch.allclose(si_sim_state.positions, original_positions), ( + "Original state was modified!" + ) assert si_sim_state.dtype == original_dtype, "Original state dtype was changed!" assert si_sim_state.device == original_device, "Original state device was changed!" assert si_sim_state is not new_state, "New state is not a different object!" assert new_state.dtype == torch.float64, "New state doesn't have correct dtype!" - + # Test device conversion if torch.cuda.is_available(): new_state_gpu = si_sim_state.to(device=torch.device("cuda")) - assert si_sim_state.device == original_device, "Original state device was changed!" - assert new_state_gpu.device.type == "cuda", "New state doesn't have correct device!" + assert si_sim_state.device == original_device, ( + "Original state device was changed!" + ) + assert new_state_gpu.device.type == "cuda", ( + "New state doesn't have correct device!" + ) assert si_sim_state is not new_state_gpu, "New state is not a different object!" @@ -651,25 +657,32 @@ def test_concatenate_states_no_side_effects(si_sim_state: SimState) -> None: # Create two copies of the state state1 = si_sim_state.clone() state2 = si_sim_state.clone() - + # Store original values original_positions1 = state1.positions.clone() original_positions2 = state2.positions.clone() original_dtype1 = state1.dtype original_dtype2 = state2.dtype - + # Concatenate states concatenated = ts.concatenate_states([state1, state2]) - + # Verify input states are unchanged - assert torch.allclose(state1.positions, original_positions1), "First input state was modified!" - assert torch.allclose(state2.positions, original_positions2), "Second input state was modified!" + assert torch.allclose(state1.positions, original_positions1), ( + "First input state was modified!" + ) + assert torch.allclose(state2.positions, original_positions2), ( + "Second input state was modified!" + ) assert state1.dtype == original_dtype1, "First input state dtype was changed!" assert state2.dtype == original_dtype2, "Second input state dtype was changed!" - + # Verify concatenated state is correct assert concatenated.n_systems == 2, "Concatenated state has wrong number of systems!" - assert concatenated.positions.shape[0] == state1.positions.shape[0] + state2.positions.shape[0], "Concatenated state has wrong number of atoms!" + assert ( + concatenated.positions.shape[0] + == state1.positions.shape[0] + state2.positions.shape[0] + ), "Concatenated state has wrong number of atoms!" def test_initialize_state_no_side_effects(si_sim_state: SimState) -> None: @@ -678,16 +691,19 @@ def test_initialize_state_no_side_effects(si_sim_state: SimState) -> None: original_positions = si_sim_state.positions.clone() original_dtype = si_sim_state.dtype original_device = si_sim_state.device - + # Initialize from existing state new_state = ts.initialize_state(si_sim_state, torch.device("cpu"), torch.float64) - - # Verify original state is unchanged - assert torch.allclose(si_sim_state.positions, original_positions), "Original state was modified!" + + # Verify original state is unchanged, + # i.e. the original state is not modified by the initialize_state function + assert torch.allclose(si_sim_state.positions, original_positions), ( + "Original state was modified!" + ) assert si_sim_state.dtype == original_dtype, "Original state dtype was changed!" assert si_sim_state.device == original_device, "Original state device was changed!" assert si_sim_state is not new_state, "New state is not a different object!" - + # Verify new state has correct properties assert new_state.device.type == "cpu", "New state doesn't have correct device!" - assert new_state.dtype == torch.float64, "New state doesn't have correct dtype!" \ No newline at end of file + assert new_state.dtype == torch.float64, "New state doesn't have correct dtype!" From bba9b40457c4c65ed12af8d59eb49210eb30a2f4 Mon Sep 17 00:00:00 2001 From: Samanvya Tripathi Date: Wed, 22 Oct 2025 15:04:35 -0400 Subject: [PATCH 4/5] remove downstream behaviour tests --- tests/test_state.py | 57 --------------------------------------------- 1 file changed, 57 deletions(-) diff --git a/tests/test_state.py b/tests/test_state.py index 4f23636a..1b293668 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -650,60 +650,3 @@ def test_state_to_device_no_side_effects(si_sim_state: SimState) -> None: "New state doesn't have correct device!" ) assert si_sim_state is not new_state_gpu, "New state is not a different object!" - - -def test_concatenate_states_no_side_effects(si_sim_state: SimState) -> None: - """Test that concatenate_states doesn't modify input states.""" - # Create two copies of the state - state1 = si_sim_state.clone() - state2 = si_sim_state.clone() - - # Store original values - original_positions1 = state1.positions.clone() - original_positions2 = state2.positions.clone() - original_dtype1 = state1.dtype - original_dtype2 = state2.dtype - - # Concatenate states - concatenated = ts.concatenate_states([state1, state2]) - - # Verify input states are unchanged - assert torch.allclose(state1.positions, original_positions1), ( - "First input state was modified!" - ) - assert torch.allclose(state2.positions, original_positions2), ( - "Second input state was modified!" - ) - assert state1.dtype == original_dtype1, "First input state dtype was changed!" - assert state2.dtype == original_dtype2, "Second input state dtype was changed!" - - # Verify concatenated state is correct - assert concatenated.n_systems == 2, "Concatenated state has wrong number of systems!" - assert ( - concatenated.positions.shape[0] - == state1.positions.shape[0] + state2.positions.shape[0] - ), "Concatenated state has wrong number of atoms!" - - -def test_initialize_state_no_side_effects(si_sim_state: SimState) -> None: - """Test that initialize_state doesn't modify input state.""" - # Store original values - original_positions = si_sim_state.positions.clone() - original_dtype = si_sim_state.dtype - original_device = si_sim_state.device - - # Initialize from existing state - new_state = ts.initialize_state(si_sim_state, torch.device("cpu"), torch.float64) - - # Verify original state is unchanged, - # i.e. the original state is not modified by the initialize_state function - assert torch.allclose(si_sim_state.positions, original_positions), ( - "Original state was modified!" - ) - assert si_sim_state.dtype == original_dtype, "Original state dtype was changed!" - assert si_sim_state.device == original_device, "Original state device was changed!" - assert si_sim_state is not new_state, "New state is not a different object!" - - # Verify new state has correct properties - assert new_state.device.type == "cpu", "New state doesn't have correct device!" - assert new_state.dtype == torch.float64, "New state doesn't have correct dtype!" From f81c3f92a9ea9e739fcbadd10c5eeffec9fc30a6 Mon Sep 17 00:00:00 2001 From: Samanvya Tripathi Date: Wed, 22 Oct 2025 15:04:58 -0400 Subject: [PATCH 5/5] remove copy vars --- torch_sim/state.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index 7da897f7..645761d6 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -573,9 +573,7 @@ def _state_to_device[T: SimState]( if dtype is None: dtype = state.dtype - # Use copy.copy to create a shallow copy of the attributes dict - # This avoids modifying the input state's internal dictionary - attrs = copy.copy(vars(state)) + attrs = vars(state) for attr_name, attr_value in attrs.items(): if isinstance(attr_value, torch.Tensor): attrs[attr_name] = attr_value.to(device=device)