From a1b9c7560bbed1f42b018427732892d919e6ae1d Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 25 Oct 2025 11:07:04 -0700 Subject: [PATCH 1/6] rm pbc wrap general --- tests/test_transforms.py | 141 +-------------------------------------- torch_sim/transforms.py | 43 ------------ 2 files changed, 3 insertions(+), 181 deletions(-) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 3b76145e..6d7a750d 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -62,138 +62,6 @@ def test_inverse_box_single_element() -> None: assert torch.allclose(ft.inverse_box(x), torch.tensor(0.5)) -def test_pbc_wrap_general_orthorhombic() -> None: - """Test periodic boundary wrapping with orthorhombic cell. - - Tests wrapping of positions in a simple cubic/orthorhombic cell where - the lattice vectors are aligned with coordinate axes. This is the simplest - case where the lattice matrix is diagonal. - """ - # Simple cubic cell with length 2.0 - lattice = torch.eye(3) * 2.0 - - # Test positions outside box in various directions - positions = torch.tensor( - [ - [2.5, 0.5, 0.5], # Beyond +x face - [-0.5, 0.5, 0.5], # Beyond -x face - [0.5, 2.5, 0.5], # Beyond +y face - [0.5, 0.5, -2.5], # Beyond -z face - ] - ) - - expected = torch.tensor( - [ - [0.5, 0.5, 0.5], # Wrapped to +x face - [1.5, 0.5, 0.5], # Wrapped to -x face - [0.5, 0.5, 0.5], # Wrapped to +y face - [0.5, 0.5, 1.5], # Wrapped to -z face - ] - ) - - wrapped = ft.pbc_wrap_general(positions, lattice) - assert torch.allclose(wrapped, expected) - - -@pytest.mark.parametrize( - ("cell", "shift"), - [ - # Cubic cell, integer shift [1, 1, 1] - (torch.eye(3, dtype=torch.float64) * 2.0, [1, 1, 1]), - # Triclinic cell, integer shift [1, 1, 1] - (([[2.0, 0.0, 0.0], [0.5, 2.0, 0.0], [0.0, 0.3, 2.0]]), [1, 1, 1]), - # Triclinic cell, integer shift [-1, 2, 0] - (([[2.0, 0.5, 0.0], [0.0, 2.0, 0.0], [0.0, 0.3, 2.0]]), [-1, 2, 0]), - # triclinic, all negative shift - (([[2.0, 0.5, 0.0], [0.0, 2.0, 0.0], [0.0, 0.3, 2.0]]), [-2, -1, -3]), - # cubic, large mixed shift - (torch.eye(3, dtype=torch.float64) * 2.0, [5, 0, -10]), - # highly tilted cell - (([[1.3, 0.9, 0.8], [0.0, 1.0, 0.9], [0.0, 0.0, 1.0]]), [1, -2, 3]), - # Left-handed cell - (([[2.0, 0.0, 0.0], [0.0, -2.0, 0.0], [0.0, 0.0, 2.0]]), [1, 1, 1]), - ], -) -def test_pbc_wrap_general_param(cell: torch.Tensor, shift: torch.Tensor) -> None: - """Test periodic boundary wrapping for various cells and integer shifts.""" - cell = torch.as_tensor(cell, dtype=torch.float64) - shift = torch.as_tensor(shift, dtype=torch.float64) - base_frac = torch.tensor([[0.25, 0.5, 0.75]], dtype=torch.float64) - base_cart = base_frac @ cell.T - shifted_cart = base_cart + (shift @ cell.T) - wrapped = ft.pbc_wrap_general(shifted_cart, cell) - torch.testing.assert_close(wrapped, base_cart, rtol=1e-6, atol=1e-6) - - -def test_pbc_wrap_general_edge_case() -> None: - """Test periodic boundary wrapping at cell boundaries. - - Verifies correct handling of positions exactly on cell boundaries, - which should be wrapped to zero rather than one to maintain consistency. - """ - lattice = torch.eye(2) * 2.0 - positions = torch.tensor( - [ - [2.0, 1.0], # On +x boundary - [1.0, 2.0], # On +y boundary - [2.0, 2.0], # On corner - ] - ) - - expected = torch.tensor([[0.0, 1.0], [1.0, 0.0], [0.0, 0.0]]) - - wrapped = ft.pbc_wrap_general(positions, lattice) - assert torch.allclose(wrapped, expected) - - -def test_pbc_wrap_general_invalid_inputs() -> None: - """Test error handling for invalid inputs. - - Verifies that appropriate errors are raised for: - - Non-floating point tensors - - Non-square lattice matrix - - Mismatched dimensions between positions and lattice - """ - # Test integer tensors - with pytest.raises(TypeError): - ft.pbc_wrap_general(torch.ones(3, dtype=torch.int64), torch.eye(3)) - - # Test non-square lattice - with pytest.raises(ValueError): - ft.pbc_wrap_general(torch.ones(3), torch.ones(3, 2)) - - # Test dimension mismatch - with pytest.raises(ValueError): - ft.pbc_wrap_general(torch.ones(4), torch.eye(3)) - - -def test_pbc_wrap_general_batch() -> None: - """Test periodic boundary wrapping with batched positions. - - Verifies that the function correctly handles batched position inputs - while using a single lattice definition. - """ - lattice = torch.eye(3) * 2.0 - - # Batch of positions with shape (2, 4, 3) - positions = torch.tensor( - [ - [[2.5, 0.5, 0.5], [0.5, 2.5, 0.5], [0.5, 0.5, 2.5], [2.5, 2.5, 2.5]], - [[3.5, 1.5, 1.5], [-0.5, 1.5, 1.5], [1.5, -0.5, 1.5], [1.5, 1.5, -0.5]], - ] - ) - - expected = torch.tensor( - [ - [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]], - [[1.5, 1.5, 1.5], [1.5, 1.5, 1.5], [1.5, 1.5, 1.5], [1.5, 1.5, 1.5]], - ] - ) - - wrapped = ft.pbc_wrap_general(positions, lattice) - assert torch.allclose(wrapped, expected) - - @pytest.mark.parametrize( "pbc", [[True, True, True], [True, True, False], [False, False, False], True, False] ) @@ -318,15 +186,12 @@ def test_pbc_wrap_batched_triclinic() -> None: ) batch = torch.tensor([0, 1], device=DEVICE) - # Stack the cells for batched processing - cell = torch.stack([cell1, cell2]) - # Apply wrapping wrapped = ft.pbc_wrap_batched(positions, cell=cell, system_idx=batch) - # Calculate expected result for first atom (using original algorithm for verification) - expected1 = ft.pbc_wrap_general(positions[0:1], cell1) - expected2 = ft.pbc_wrap_general(positions[1:2], cell2) + # Calculate expected results by wrapping each system independently + expected1 = ft.wrap_positions(positions[0:1], cell1) + expected2 = ft.wrap_positions(positions[1:2], cell2) # Verify results match the expected values assert torch.allclose(wrapped[0:1], expected1, atol=1e-6) diff --git a/torch_sim/transforms.py b/torch_sim/transforms.py index 1b2c416b..4200cf9b 100644 --- a/torch_sim/transforms.py +++ b/torch_sim/transforms.py @@ -110,49 +110,6 @@ def inverse_box(box: torch.Tensor) -> torch.Tensor: raise ValueError(f"Box must be either: a scalar, a vector, or a matrix. Found {box}.") -def pbc_wrap_general( - positions: torch.Tensor, lattice_vectors: torch.Tensor -) -> torch.Tensor: - """Apply periodic boundary conditions using lattice - vector transformation method. - - This implementation follows the general matrix-based approach for - periodic boundary conditions in arbitrary triclinic cells: - 1. Transform positions to fractional coordinates using B = A^(-1) - 2. Wrap fractional coordinates to [0,1) using modulo - 3. Transform back to real space using A - - Args: - positions (torch.Tensor): Tensor of shape (..., d) - containing particle positions in real space. - lattice_vectors (torch.Tensor): Tensor of shape (d, d) containing - lattice vectors as columns (A matrix in the equations). - - Returns: - torch.Tensor: Wrapped positions in real space with same shape as input positions. - """ - # Validate inputs - if not torch.is_floating_point(positions) or not torch.is_floating_point( - lattice_vectors - ): - raise TypeError("Positions and lattice vectors must be floating point tensors.") - - if lattice_vectors.ndim != 2 or lattice_vectors.shape[0] != lattice_vectors.shape[1]: - raise ValueError("Lattice vectors must be a square matrix.") - - if positions.shape[-1] != lattice_vectors.shape[0]: - raise ValueError("Position dimensionality must match lattice vectors.") - - # Transform to fractional coordinates: f = Br - frac_coords = positions @ torch.linalg.inv(lattice_vectors).T - - # Wrap to reference cell [0,1) using modulo - wrapped_frac = frac_coords % 1.0 - - # Transform back to real space: r_row_wrapped = wrapped_f_row @ M_row - return wrapped_frac @ lattice_vectors.T - - def pbc_wrap_batched( positions: torch.Tensor, cell: torch.Tensor, system_idx: torch.Tensor ) -> torch.Tensor: From 648cc66fd60346a2ee371a5caa30364d2d3fcf0b Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 25 Oct 2025 12:35:23 -0700 Subject: [PATCH 2/6] fix test --- tests/test_transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 6d7a750d..de868694 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -190,8 +190,8 @@ def test_pbc_wrap_batched_triclinic() -> None: wrapped = ft.pbc_wrap_batched(positions, cell=cell, system_idx=batch) # Calculate expected results by wrapping each system independently - expected1 = ft.wrap_positions(positions[0:1], cell1) - expected2 = ft.wrap_positions(positions[1:2], cell2) + expected1 = ft.wrap_positions(positions[0:1], cell1.T) + expected2 = ft.wrap_positions(positions[1:2], cell2.T) # Verify results match the expected values assert torch.allclose(wrapped[0:1], expected1, atol=1e-6) From 7385ff383200f22dc72c9deb90cc75589cd03e21 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 25 Oct 2025 12:39:28 -0700 Subject: [PATCH 3/6] add back function but make it deprecated --- tests/test_transforms.py | 132 +++++++++++++++++++++++++++++++++++++++ torch_sim/transforms.py | 24 +++++++ 2 files changed, 156 insertions(+) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index de868694..067abad8 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -62,6 +62,138 @@ def test_inverse_box_single_element() -> None: assert torch.allclose(ft.inverse_box(x), torch.tensor(0.5)) +def test_pbc_wrap_general_orthorhombic() -> None: + """Test periodic boundary wrapping with orthorhombic cell. + + Tests wrapping of positions in a simple cubic/orthorhombic cell where + the lattice vectors are aligned with coordinate axes. This is the simplest + case where the lattice matrix is diagonal. + """ + # Simple cubic cell with length 2.0 + lattice = torch.eye(3) * 2.0 + + # Test positions outside box in various directions + positions = torch.tensor( + [ + [2.5, 0.5, 0.5], # Beyond +x face + [-0.5, 0.5, 0.5], # Beyond -x face + [0.5, 2.5, 0.5], # Beyond +y face + [0.5, 0.5, -2.5], # Beyond -z face + ] + ) + + expected = torch.tensor( + [ + [0.5, 0.5, 0.5], # Wrapped to +x face + [1.5, 0.5, 0.5], # Wrapped to -x face + [0.5, 0.5, 0.5], # Wrapped to +y face + [0.5, 0.5, 1.5], # Wrapped to -z face + ] + ) + + wrapped = ft.pbc_wrap_general(positions, lattice) + assert torch.allclose(wrapped, expected) + + +@pytest.mark.parametrize( + ("cell", "shift"), + [ + # Cubic cell, integer shift [1, 1, 1] + (torch.eye(3, dtype=torch.float64) * 2.0, [1, 1, 1]), + # Triclinic cell, integer shift [1, 1, 1] + (([[2.0, 0.0, 0.0], [0.5, 2.0, 0.0], [0.0, 0.3, 2.0]]), [1, 1, 1]), + # Triclinic cell, integer shift [-1, 2, 0] + (([[2.0, 0.5, 0.0], [0.0, 2.0, 0.0], [0.0, 0.3, 2.0]]), [-1, 2, 0]), + # triclinic, all negative shift + (([[2.0, 0.5, 0.0], [0.0, 2.0, 0.0], [0.0, 0.3, 2.0]]), [-2, -1, -3]), + # cubic, large mixed shift + (torch.eye(3, dtype=torch.float64) * 2.0, [5, 0, -10]), + # highly tilted cell + (([[1.3, 0.9, 0.8], [0.0, 1.0, 0.9], [0.0, 0.0, 1.0]]), [1, -2, 3]), + # Left-handed cell + (([[2.0, 0.0, 0.0], [0.0, -2.0, 0.0], [0.0, 0.0, 2.0]]), [1, 1, 1]), + ], +) +def test_pbc_wrap_general_param(cell: torch.Tensor, shift: torch.Tensor) -> None: + """Test periodic boundary wrapping for various cells and integer shifts.""" + cell = torch.as_tensor(cell, dtype=torch.float64) + shift = torch.as_tensor(shift, dtype=torch.float64) + base_frac = torch.tensor([[0.25, 0.5, 0.75]], dtype=torch.float64) + base_cart = base_frac @ cell.T + shifted_cart = base_cart + (shift @ cell.T) + wrapped = ft.pbc_wrap_general(shifted_cart, cell) + torch.testing.assert_close(wrapped, base_cart, rtol=1e-6, atol=1e-6) + + +def test_pbc_wrap_general_edge_case() -> None: + """Test periodic boundary wrapping at cell boundaries. + + Verifies correct handling of positions exactly on cell boundaries, + which should be wrapped to zero rather than one to maintain consistency. + """ + lattice = torch.eye(2) * 2.0 + positions = torch.tensor( + [ + [2.0, 1.0], # On +x boundary + [1.0, 2.0], # On +y boundary + [2.0, 2.0], # On corner + ] + ) + + expected = torch.tensor([[0.0, 1.0], [1.0, 0.0], [0.0, 0.0]]) + + wrapped = ft.pbc_wrap_general(positions, lattice) + assert torch.allclose(wrapped, expected) + + +def test_pbc_wrap_general_invalid_inputs() -> None: + """Test error handling for invalid inputs. + + Verifies that appropriate errors are raised for: + - Non-floating point tensors + - Non-square lattice matrix + - Mismatched dimensions between positions and lattice + """ + # Test integer tensors + with pytest.raises(TypeError): + ft.pbc_wrap_general(torch.ones(3, dtype=torch.int64), torch.eye(3)) + + # Test non-square lattice + with pytest.raises(ValueError): + ft.pbc_wrap_general(torch.ones(3), torch.ones(3, 2)) + + # Test dimension mismatch + with pytest.raises(ValueError): + ft.pbc_wrap_general(torch.ones(4), torch.eye(3)) + + +def test_pbc_wrap_general_batch() -> None: + """Test periodic boundary wrapping with batched positions. + + Verifies that the function correctly handles batched position inputs + while using a single lattice definition. + """ + lattice = torch.eye(3) * 2.0 + + # Batch of positions with shape (2, 4, 3) + positions = torch.tensor( + [ + [[2.5, 0.5, 0.5], [0.5, 2.5, 0.5], [0.5, 0.5, 2.5], [2.5, 2.5, 2.5]], + [[3.5, 1.5, 1.5], [-0.5, 1.5, 1.5], [1.5, -0.5, 1.5], [1.5, 1.5, -0.5]], + ] + ) + + expected = torch.tensor( + [ + [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]], + [[1.5, 1.5, 1.5], [1.5, 1.5, 1.5], [1.5, 1.5, 1.5], [1.5, 1.5, 1.5]], + ] + ) + + wrapped = ft.pbc_wrap_general(positions, lattice) + assert torch.allclose(wrapped, expected) + + @pytest.mark.parametrize( "pbc", [[True, True, True], [True, True, False], [False, False, False], True, False] ) diff --git a/torch_sim/transforms.py b/torch_sim/transforms.py index 4200cf9b..3193a3ab 100644 --- a/torch_sim/transforms.py +++ b/torch_sim/transforms.py @@ -7,6 +7,7 @@ from collections.abc import Callable, Iterable from functools import wraps +from typing_extensions import deprecated import torch from torch.types import _dtype @@ -110,6 +111,29 @@ def inverse_box(box: torch.Tensor) -> torch.Tensor: raise ValueError(f"Box must be either: a scalar, a vector, or a matrix. Found {box}.") +@deprecated("Use wrap_positions instead") +def pbc_wrap_general(positions: torch.Tensor, cell: torch.Tensor) -> torch.Tensor: + """Apply periodic boundary conditions using lattice + vector transformation method. + + This implementation follows the general matrix-based approach for + periodic boundary conditions in arbitrary triclinic cells: + 1. Transform positions to fractional coordinates using B = A^(-1) + 2. Wrap fractional coordinates to [0,1) using modulo + 3. Transform back to real space using A + + Args: + positions (torch.Tensor): Tensor of shape (..., d) + containing particle positions in real space. + lattice_vectors (torch.Tensor): Tensor of shape (d, d) containing + lattice vectors as columns (A matrix in the equations). + + Returns: + torch.Tensor: Wrapped positions in real space with same shape as input positions. + """ + return wrap_positions(positions, cell.T) + + def pbc_wrap_batched( positions: torch.Tensor, cell: torch.Tensor, system_idx: torch.Tensor ) -> torch.Tensor: From 70eef576abb1680a6b2c691fc210c6a39e803dad Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 25 Oct 2025 12:41:21 -0700 Subject: [PATCH 4/6] simplify function --- torch_sim/transforms.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/torch_sim/transforms.py b/torch_sim/transforms.py index 3193a3ab..0b678ce5 100644 --- a/torch_sim/transforms.py +++ b/torch_sim/transforms.py @@ -131,7 +131,14 @@ def pbc_wrap_general(positions: torch.Tensor, cell: torch.Tensor) -> torch.Tenso Returns: torch.Tensor: Wrapped positions in real space with same shape as input positions. """ - return wrap_positions(positions, cell.T) + return wrap_positions( + positions, + cell.T, + pbc=True, + center=(0.0, 0.0, 0.0), + pretty_translation=False, + eps=0.0, + ) def pbc_wrap_batched( From 9f539a872b6302df0fc2d9d67427959790316579 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 25 Oct 2025 12:44:38 -0700 Subject: [PATCH 5/6] add back original implementation --- torch_sim/transforms.py | 32 +++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/torch_sim/transforms.py b/torch_sim/transforms.py index 0b678ce5..90b8f907 100644 --- a/torch_sim/transforms.py +++ b/torch_sim/transforms.py @@ -112,7 +112,9 @@ def inverse_box(box: torch.Tensor) -> torch.Tensor: @deprecated("Use wrap_positions instead") -def pbc_wrap_general(positions: torch.Tensor, cell: torch.Tensor) -> torch.Tensor: +def pbc_wrap_general( + positions: torch.Tensor, lattice_vectors: torch.Tensor +) -> torch.Tensor: """Apply periodic boundary conditions using lattice vector transformation method. @@ -131,14 +133,26 @@ def pbc_wrap_general(positions: torch.Tensor, cell: torch.Tensor) -> torch.Tenso Returns: torch.Tensor: Wrapped positions in real space with same shape as input positions. """ - return wrap_positions( - positions, - cell.T, - pbc=True, - center=(0.0, 0.0, 0.0), - pretty_translation=False, - eps=0.0, - ) + # Validate inputs + if not torch.is_floating_point(positions) or not torch.is_floating_point( + lattice_vectors + ): + raise TypeError("Positions and lattice vectors must be floating point tensors.") + + if lattice_vectors.ndim != 2 or lattice_vectors.shape[0] != lattice_vectors.shape[1]: + raise ValueError("Lattice vectors must be a square matrix.") + + if positions.shape[-1] != lattice_vectors.shape[0]: + raise ValueError("Position dimensionality must match lattice vectors.") + + # Transform to fractional coordinates: f = Br + frac_coords = positions @ torch.linalg.inv(lattice_vectors).T + + # Wrap to reference cell [0,1) using modulo + wrapped_frac = frac_coords % 1.0 + + # Transform back to real space: r_row_wrapped = wrapped_f_row @ M_row + return wrapped_frac @ lattice_vectors.T def pbc_wrap_batched( From 2d8db0d4e848b399a69a3e22aacd5dcc2d0f4284 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 25 Oct 2025 12:53:32 -0700 Subject: [PATCH 6/6] fix lint --- torch_sim/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/transforms.py b/torch_sim/transforms.py index 90b8f907..53570a66 100644 --- a/torch_sim/transforms.py +++ b/torch_sim/transforms.py @@ -7,10 +7,10 @@ from collections.abc import Callable, Iterable from functools import wraps -from typing_extensions import deprecated import torch from torch.types import _dtype +from typing_extensions import deprecated def get_fractional_coordinates(