From ff013d1345ab389a732f0b18086fabd1a1b0af2d Mon Sep 17 00:00:00 2001 From: Gausshj Date: Tue, 4 Nov 2025 17:31:28 +0800 Subject: [PATCH] fix(exponential): fix exponential processing logic - Modify the `_group_edges` function to support input of type `tuple[tuple[int, ...], tuple[int, ...]]` - Fix exponentiation logic; input parameters pairs must be of type `tuple[tuple[int, ...], tuple[int, ...]]` - Add new function `_check_pairs_coverage` to check the coverage of parameters pairs; ensuring that input parameter pairs cover all dimension and non-overlapping --- grassmann_tensor/tensor.py | 33 +++++++++++++++++++++++++++------ tests/exponential_test.py | 10 ++++++---- 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/grassmann_tensor/tensor.py b/grassmann_tensor/tensor.py index 32fced9..21d6e5b 100644 --- a/grassmann_tensor/tensor.py +++ b/grassmann_tensor/tensor.py @@ -571,12 +571,19 @@ def matmul(self, other: GrassmannTensor) -> GrassmannTensor: def _group_edges( self, - left_legs: typing.Iterable[int], + pairs: tuple[int, ...] | tuple[tuple[int, ...], tuple[int, ...]], ) -> tuple[GrassmannTensor, tuple[int, ...], tuple[int, ...]]: - left_legs = tuple(int(i) for i in left_legs) - right_legs = tuple(i for i in range(self.tensor.dim()) if i not in left_legs) - assert set(left_legs) | set(right_legs) == set(range(self.tensor.dim())), ( - "Left/right must cover all tensor legs." + if (isinstance(pairs, tuple) and len(pairs)) and all( + isinstance(x, tuple) and all(isinstance(i, int) for i in x) for x in pairs + ): + left_legs = typing.cast(tuple[int, ...], pairs[0]) + right_legs = typing.cast(tuple[int, ...], pairs[1]) + else: + left_legs = typing.cast(tuple[int, ...], pairs) + right_legs = tuple(i for i in range(self.tensor.dim()) if i not in left_legs) + + assert self._check_pairs_coverage((left_legs, right_legs)), ( + f"Input pairs must cover all dimension and disjoint, but got {(left_legs, right_legs)}" ) order = left_legs + right_legs @@ -724,7 +731,17 @@ def _get_inv_order(self, order: tuple[int, ...]) -> tuple[int, ...]: inv[origin_idx] = new_position return tuple(inv) - def exponential(self, pairs: tuple[int, ...]) -> GrassmannTensor: + def _check_pairs_coverage(self, pairs: tuple[tuple[int, ...], tuple[int, ...]]) -> bool: + set0 = set(pairs[0]) + set1 = set(pairs[1]) + + are_disjoint = set0.isdisjoint(set1) + + is_complete_union = (set0 | set1) == set(range(self.tensor.dim())) + + return are_disjoint and is_complete_union + + def exponential(self, pairs: tuple[tuple[int, ...], tuple[int, ...]]) -> GrassmannTensor: tensor, left_legs, right_legs = self._group_edges(pairs) arrow_order = (False, True) @@ -743,6 +760,10 @@ def exponential(self, pairs: tuple[int, ...]) -> GrassmannTensor: (even_left, odd_left) = tensor.edges[0] (even_right, odd_right) = tensor.edges[1] + assert even_left == even_right and odd_left == odd_right, ( + f"Parity blocks must be square, but got L=({even_left},{odd_left}), R=({even_right},{odd_right})" + ) + even_tensor = tensor.tensor[:even_left, :even_right] odd_tensor = tensor.tensor[even_left:, even_right:] diff --git a/tests/exponential_test.py b/tests/exponential_test.py index d926170..438cba8 100644 --- a/tests/exponential_test.py +++ b/tests/exponential_test.py @@ -10,14 +10,16 @@ def test_exponential() -> None: ((4, 4), (8, 8), (4, 4), (8, 8)), torch.randn(8, 16, 8, 16, dtype=torch.float64), ) - a.exponential((0, 3)) + b = a.exponential(((0, 3), (1, 2))) + c = a.exponential(((0, 3), (2, 1))) + assert not torch.allclose(b.tensor, c.tensor) def test_exponential_with_empty_parity_block() -> None: a = GrassmannTensor((False, True), ((1, 0), (1, 0)), torch.randn(1, 1)) - a.exponential((0,)) + a.exponential(((0,), (1,))) b = GrassmannTensor((False, True), ((0, 1), (0, 1)), torch.randn(1, 1)) - b.exponential((0,)) + b.exponential(((0,), (1,))) def test_exponential_assertation() -> None: @@ -27,4 +29,4 @@ def test_exponential_assertation() -> None: torch.randn(4, 8, 16, 32, dtype=torch.float64), ) with pytest.raises(AssertionError, match="Exponential requires a square operator"): - a.exponential((0, 2)) + a.exponential(((0, 2), (1, 3)))