From 24cf554a242f0054f2882bc5b88a0ef55347eb6e Mon Sep 17 00:00:00 2001 From: Hao Zhang Date: Sat, 9 Aug 2025 12:15:04 +0800 Subject: [PATCH] Add tests for permute. --- grassmann_tensor/tensor.py | 1 + tests/permute_test.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) create mode 100644 tests/permute_test.py diff --git a/grassmann_tensor/tensor.py b/grassmann_tensor/tensor.py index 9c5d109..82539c3 100644 --- a/grassmann_tensor/tensor.py +++ b/grassmann_tensor/tensor.py @@ -115,6 +115,7 @@ def permute(self, before_by_after: tuple[int, ...]) -> GrassmannTensor: """ Permute the indices of the Grassmann tensor. """ + assert len(before_by_after) == len(set(before_by_after)), "Permutation indices must be unique." assert set(before_by_after) == set(range(self.tensor.dim())), "Permutation indices must cover all dimensions." arrow = tuple(self.arrow[i] for i in before_by_after) diff --git a/tests/permute_test.py b/tests/permute_test.py new file mode 100644 index 0000000..095008b --- /dev/null +++ b/tests/permute_test.py @@ -0,0 +1,35 @@ +import pytest +import torch +from grassmann_tensor import GrassmannTensor + +PermuteCase = tuple[tuple[bool, ...], tuple[tuple[int, int], ...], torch.Tensor, tuple[int, ...], torch.Tensor] + + +@pytest.mark.parametrize("x", [ + ((False, False), ((1, 1), (1, 1)), torch.tensor([[1, 0], [0, 4]]), (0, 1), torch.tensor([[1, 0], [0, 4]])), + ((True, False), ((1, 1), (1, 1)), torch.tensor([[1, 0], [0, 4]]), (1, 0), torch.tensor([[1, 0], [0, -4]])), + ((False, True, True), ((1, 1), (1, 1), (1, 1)), torch.tensor([[[1, 0], [0, 2]], [[0, 3], [4, 0]]]), (0, 2, 1), torch.tensor([[[1, 0], [0, -2]], [[0, 4], [3, 0]]])), + ((True, True, True), ((1, 1), (1, 1), (1, 1)), torch.tensor([[[1, 0], [0, 2]], [[0, 3], [4, 0]]]), (1, 0, 2), torch.tensor([[[1, 0], [0, 3]], [[0, 2], [-4, 0]]])), + ((True, False, False), ((1, 1), (1, 1), (1, 1)), torch.tensor([[[1, 0], [0, 2]], [[0, 3], [4, 0]]]), (2, 1, 0), torch.tensor([[[1, 0], [0, -4]], [[0, -3], [-2, 0]]])), + ((False, False, False), ((1, 1), (1, 1), (1, 1)), torch.tensor([[[1, 0], [0, 2]], [[0, 3], [4, 0]]]), (2, 0, 1), torch.tensor([[[1, 0], [0, 4]], [[0, -2], [-3, 0]]])), +]) +def test_permute(x: PermuteCase) -> None: + arrow, edges, tensor, before_by_after, expected = x + grassmann_tensor = GrassmannTensor(arrow, edges, tensor) + result = grassmann_tensor.permute(before_by_after) + assert torch.allclose(result.tensor, expected) + + +PermuteFailCase = tuple[tuple[bool, ...], tuple[tuple[int, int], ...], torch.Tensor, tuple[int, ...]] + + +@pytest.mark.parametrize("x", [ + ((False, False), ((1, 1), (1, 1)), torch.tensor([[1, 0], [0, 4]]), (0, 0)), + ((False, False), ((1, 1), (1, 1)), torch.tensor([[1, 0], [0, 4]]), (2, 0)), + ((False, False), ((1, 1), (1, 1)), torch.tensor([[1, 0], [0, 4]]), (0, 0, 1)), +]) +def test_permute_fail(x: PermuteFailCase) -> None: + arrow, edges, tensor, before_by_after = x + grassmann_tensor = GrassmannTensor(arrow, edges, tensor) + with pytest.raises(AssertionError): + grassmann_tensor.permute(before_by_after)