From 156b5d0669bfcf8cb235cb5b2ec76d96030ed270 Mon Sep 17 00:00:00 2001 From: Gausshj Date: Wed, 10 Sep 2025 17:18:24 +0800 Subject: [PATCH 1/4] test(matmul): add assertion tests for matmul of impure tensors - Add assertion tests for matmul of impure tensors Signed-off-by: Gausshj --- tests/matmul_test.py | 84 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) diff --git a/tests/matmul_test.py b/tests/matmul_test.py index 6b91ea5..b9e90e4 100644 --- a/tests/matmul_test.py +++ b/tests/matmul_test.py @@ -99,3 +99,87 @@ def test_matmul( assert c.arrow == (*(False for _ in broadcast_result), arrow_a, arrow_b) assert c.edges == (*((i, 0) for i in broadcast_result), edge_a, edge_b) assert torch.allclose(c.tensor, expected) + + +@pytest.mark.parametrize("a_is_vector", [False, True]) +@pytest.mark.parametrize("b_is_vector", [False, True]) +@pytest.mark.parametrize("normal_arrow_order", [False, True]) +@pytest.mark.parametrize( + "broadcast", + [ + ((), (), ()), + ((2,), (), (2,)), + ((), (3,), (3,)), + ((1,), (4,), (4,)), + ((5,), (1,), (5,)), + ((6,), (6,), (6,)), + ((7, 8), (7, 8), (7, 8)), + ((1, 8), (7, 8), (7, 8)), + ((8,), (7, 8), (7, 8)), + ((7, 1), (7, 8), (7, 8)), + ((7, 8), (1, 8), (7, 8)), + ((7, 8), (8,), (7, 8)), + ((7, 8), (7, 1), (7, 8)), + ], +) +@pytest.mark.parametrize( + "x", + [ + (False, False, (1, 1), (1, 1), (1, 1)), + (False, True, (1, 1), (1, 1), (1, 1)), + (True, False, (1, 1), (1, 1), (1, 1)), + (True, True, (1, 1), (1, 1), (1, 1)), + (False, False, (2, 2), (2, 2), (2, 2)), + (False, True, (2, 2), (2, 2), (2, 2)), + (True, False, (2, 2), (2, 2), (2, 2)), + (True, True, (2, 2), (2, 2), (2, 2)), + ], +) +@pytest.mark.parametrize("impure_even", [1, 2, 3, 4, 5, 6, 7, 8]) +def test_matmul_unpure_even( + a_is_vector: bool, + b_is_vector: bool, + normal_arrow_order: bool, + broadcast: Broadcast, + x: MatmulCase, + impure_even: int, +) -> None: + broadcast_a, broadcast_b, broadcast_result = broadcast + arrow_a, arrow_b, edge_a, edge_common, edge_b = x + if a_is_vector and broadcast_a != (): + pytest.skip("Vector a cannot be broadcasted") + if b_is_vector and broadcast_b != (): + pytest.skip("Vector b cannot be broadcasted") + if a_is_vector and b_is_vector: + pytest.skip("Both vectors are ignored.") + dim_a = sum(edge_a) + dim_common = sum(edge_common) + dim_b = sum(edge_b) + if a_is_vector: + a = GrassmannTensor( + (*(False for _ in broadcast_a), True if normal_arrow_order else False), + (*((i, impure_even) for i in broadcast_a), edge_common), + torch.randn([*[x + impure_even for x in broadcast_a], dim_common]), + ).update_mask() + else: + a = GrassmannTensor( + (*(False for _ in broadcast_a), arrow_a, True if normal_arrow_order else False), + (*((i, impure_even) for i in broadcast_a), edge_a, edge_common), + torch.randn([*[x + impure_even for x in broadcast_a], dim_a, dim_common]), + ).update_mask() + if b_is_vector: + b = GrassmannTensor( + (*(False for _ in broadcast_b), False if normal_arrow_order else True), + (*((i, impure_even) for i in broadcast_b), edge_common), + torch.randn([*[x + impure_even for x in broadcast_b], dim_common]), + ).update_mask() + else: + b = GrassmannTensor( + (*(False for _ in broadcast_b), False if normal_arrow_order else True, arrow_b), + (*((i, impure_even) for i in broadcast_b), edge_common, edge_b), + torch.randn([*[x + impure_even for x in broadcast_b], dim_common, dim_b]), + ).update_mask() + if a.tensor.dim() <= 2 and b.tensor.dim() <= 2: + pytest.skip("One of the two tensors needs to have a dimension greater than 2") + with pytest.raises(AssertionError, match="All edges except the last two must be pure even"): + _ = a.matmul(b) From 520984c7c77411f89aa7782edc616d76e47cc736 Mon Sep 17 00:00:00 2001 From: Gausshj Date: Fri, 12 Sep 2025 15:36:32 +0800 Subject: [PATCH 2/4] feat(trace): add trace support for tensor - Add trace support for tensor - Add simply test for trace Signed-off-by: Gausshj --- grassmann_tensor/tensor.py | 41 ++++++++++++++++++++++++++++++++++++++ tests/trace_test.py | 11 ++++++++++ 2 files changed, 52 insertions(+) create mode 100644 tests/trace_test.py diff --git a/grassmann_tensor/tensor.py b/grassmann_tensor/tensor.py index e4584f8..a41af2d 100644 --- a/grassmann_tensor/tensor.py +++ b/grassmann_tensor/tensor.py @@ -488,6 +488,47 @@ def matmul(self, other: GrassmannTensor) -> GrassmannTensor: _tensor=tensor, ) + def trace(self, trace_pair: tuple[int, int]) -> torch.Tensor: + assert len(trace_pair) == 2, ( + f"The length of trace pair must be 2, but got {len(trace_pair)}." + ) + + assert trace_pair[0] != trace_pair[1], ( + f"Trace requires two distinct axes but got {trace_pair[0]} and {trace_pair[1]}." + ) + + assert self.arrow[trace_pair[0]] == self.arrow[trace_pair[1]], ( + f"Trace requires two different arrows but got {self.arrow[trace_pair[0]]} and {self.arrow[trace_pair[1]]}." + ) + + tensor = self + + if trace_pair[0] > trace_pair[1]: + trace_pair = trace_pair[::-1] + + edge_first, edge_end = self.edges[trace_pair[0]], self.edges[trace_pair[1]] + if edge_first != edge_end: + raise ValueError(f"Incompatible edges: {edge_first} and {edge_end}.") + + if tensor.arrow[trace_pair[0]] != tensor.arrow[trace_pair[1]]: + tensor = tensor.reverse((trace_pair[1],)) + + order = list(range(tensor.tensor.dim())) + order_first = order.pop(trace_pair[0]) + order_end = order.pop(trace_pair[1]) + order[trace_pair[0] : trace_pair[0]] = [order_first, order_end] + tensor = tensor.permute(tuple(order)) + + shape = tensor.tensor.shape[0] + tensor.tensor.shape[1] + tensor.reshape( + ( + shape, + *(-1 for _ in range(tensor.tensor.dim() - 2)), + ) + ) + + return tensor.tensor[0].trace() + def __post_init__(self) -> None: assert len(self._arrow) == self._tensor.dim(), ( f"Arrow length ({len(self._arrow)}) must match tensor dimensions ({self._tensor.dim()})." diff --git a/tests/trace_test.py b/tests/trace_test.py new file mode 100644 index 0000000..df1fb9a --- /dev/null +++ b/tests/trace_test.py @@ -0,0 +1,11 @@ +import torch +from grassmann_tensor.tensor import GrassmannTensor + + +def test_trace() -> None: + gt = GrassmannTensor( + (False, False, False), + ((1, 1), (1, 1), (1, 1)), + torch.tensor([[[1, 0], [0, 2]], [[0, 3], [4, 0]]]), + ) + gt.trace((0, 1)) From babaef34b9d5cf1bc0d268f3a73da89f25746f21 Mon Sep 17 00:00:00 2001 From: Gausshj Date: Fri, 12 Sep 2025 15:45:57 +0800 Subject: [PATCH 3/4] Revert "test(matmul): add assertion tests for matmul of impure tensors" This reverts commit 156b5d0669bfcf8cb235cb5b2ec76d96030ed270. --- tests/matmul_test.py | 84 -------------------------------------------- 1 file changed, 84 deletions(-) diff --git a/tests/matmul_test.py b/tests/matmul_test.py index b9e90e4..6b91ea5 100644 --- a/tests/matmul_test.py +++ b/tests/matmul_test.py @@ -99,87 +99,3 @@ def test_matmul( assert c.arrow == (*(False for _ in broadcast_result), arrow_a, arrow_b) assert c.edges == (*((i, 0) for i in broadcast_result), edge_a, edge_b) assert torch.allclose(c.tensor, expected) - - -@pytest.mark.parametrize("a_is_vector", [False, True]) -@pytest.mark.parametrize("b_is_vector", [False, True]) -@pytest.mark.parametrize("normal_arrow_order", [False, True]) -@pytest.mark.parametrize( - "broadcast", - [ - ((), (), ()), - ((2,), (), (2,)), - ((), (3,), (3,)), - ((1,), (4,), (4,)), - ((5,), (1,), (5,)), - ((6,), (6,), (6,)), - ((7, 8), (7, 8), (7, 8)), - ((1, 8), (7, 8), (7, 8)), - ((8,), (7, 8), (7, 8)), - ((7, 1), (7, 8), (7, 8)), - ((7, 8), (1, 8), (7, 8)), - ((7, 8), (8,), (7, 8)), - ((7, 8), (7, 1), (7, 8)), - ], -) -@pytest.mark.parametrize( - "x", - [ - (False, False, (1, 1), (1, 1), (1, 1)), - (False, True, (1, 1), (1, 1), (1, 1)), - (True, False, (1, 1), (1, 1), (1, 1)), - (True, True, (1, 1), (1, 1), (1, 1)), - (False, False, (2, 2), (2, 2), (2, 2)), - (False, True, (2, 2), (2, 2), (2, 2)), - (True, False, (2, 2), (2, 2), (2, 2)), - (True, True, (2, 2), (2, 2), (2, 2)), - ], -) -@pytest.mark.parametrize("impure_even", [1, 2, 3, 4, 5, 6, 7, 8]) -def test_matmul_unpure_even( - a_is_vector: bool, - b_is_vector: bool, - normal_arrow_order: bool, - broadcast: Broadcast, - x: MatmulCase, - impure_even: int, -) -> None: - broadcast_a, broadcast_b, broadcast_result = broadcast - arrow_a, arrow_b, edge_a, edge_common, edge_b = x - if a_is_vector and broadcast_a != (): - pytest.skip("Vector a cannot be broadcasted") - if b_is_vector and broadcast_b != (): - pytest.skip("Vector b cannot be broadcasted") - if a_is_vector and b_is_vector: - pytest.skip("Both vectors are ignored.") - dim_a = sum(edge_a) - dim_common = sum(edge_common) - dim_b = sum(edge_b) - if a_is_vector: - a = GrassmannTensor( - (*(False for _ in broadcast_a), True if normal_arrow_order else False), - (*((i, impure_even) for i in broadcast_a), edge_common), - torch.randn([*[x + impure_even for x in broadcast_a], dim_common]), - ).update_mask() - else: - a = GrassmannTensor( - (*(False for _ in broadcast_a), arrow_a, True if normal_arrow_order else False), - (*((i, impure_even) for i in broadcast_a), edge_a, edge_common), - torch.randn([*[x + impure_even for x in broadcast_a], dim_a, dim_common]), - ).update_mask() - if b_is_vector: - b = GrassmannTensor( - (*(False for _ in broadcast_b), False if normal_arrow_order else True), - (*((i, impure_even) for i in broadcast_b), edge_common), - torch.randn([*[x + impure_even for x in broadcast_b], dim_common]), - ).update_mask() - else: - b = GrassmannTensor( - (*(False for _ in broadcast_b), False if normal_arrow_order else True, arrow_b), - (*((i, impure_even) for i in broadcast_b), edge_common, edge_b), - torch.randn([*[x + impure_even for x in broadcast_b], dim_common, dim_b]), - ).update_mask() - if a.tensor.dim() <= 2 and b.tensor.dim() <= 2: - pytest.skip("One of the two tensors needs to have a dimension greater than 2") - with pytest.raises(AssertionError, match="All edges except the last two must be pure even"): - _ = a.matmul(b) From 8d23116eb24573ec0049825d374b63a68b80164b Mon Sep 17 00:00:00 2001 From: Gausshj Date: Thu, 9 Oct 2025 16:53:14 +0800 Subject: [PATCH 4/4] dev(trace): remove unnecessary judge - Remove unnecessary judge for arrow Signed-off-by: Gausshj --- grassmann_tensor/tensor.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/grassmann_tensor/tensor.py b/grassmann_tensor/tensor.py index a41af2d..8fd93e4 100644 --- a/grassmann_tensor/tensor.py +++ b/grassmann_tensor/tensor.py @@ -510,9 +510,6 @@ def trace(self, trace_pair: tuple[int, int]) -> torch.Tensor: if edge_first != edge_end: raise ValueError(f"Incompatible edges: {edge_first} and {edge_end}.") - if tensor.arrow[trace_pair[0]] != tensor.arrow[trace_pair[1]]: - tensor = tensor.reverse((trace_pair[1],)) - order = list(range(tensor.tensor.dim())) order_first = order.pop(trace_pair[0]) order_end = order.pop(trace_pair[1])