diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 3b76145e..067abad8 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -318,15 +318,12 @@ def test_pbc_wrap_batched_triclinic() -> None: ) batch = torch.tensor([0, 1], device=DEVICE) - # Stack the cells for batched processing - cell = torch.stack([cell1, cell2]) - # Apply wrapping wrapped = ft.pbc_wrap_batched(positions, cell=cell, system_idx=batch) - # Calculate expected result for first atom (using original algorithm for verification) - expected1 = ft.pbc_wrap_general(positions[0:1], cell1) - expected2 = ft.pbc_wrap_general(positions[1:2], cell2) + # Calculate expected results by wrapping each system independently + expected1 = ft.wrap_positions(positions[0:1], cell1.T) + expected2 = ft.wrap_positions(positions[1:2], cell2.T) # Verify results match the expected values assert torch.allclose(wrapped[0:1], expected1, atol=1e-6) diff --git a/torch_sim/transforms.py b/torch_sim/transforms.py index 1b2c416b..53570a66 100644 --- a/torch_sim/transforms.py +++ b/torch_sim/transforms.py @@ -10,6 +10,7 @@ import torch from torch.types import _dtype +from typing_extensions import deprecated def get_fractional_coordinates( @@ -110,6 +111,7 @@ def inverse_box(box: torch.Tensor) -> torch.Tensor: raise ValueError(f"Box must be either: a scalar, a vector, or a matrix. Found {box}.") +@deprecated("Use wrap_positions instead") def pbc_wrap_general( positions: torch.Tensor, lattice_vectors: torch.Tensor ) -> torch.Tensor: