diff --git a/grassmann_tensor/tensor.py b/grassmann_tensor/tensor.py index 16160ee..e4584f8 100644 --- a/grassmann_tensor/tensor.py +++ b/grassmann_tensor/tensor.py @@ -173,7 +173,7 @@ def reverse(self, indices: tuple[int, ...]) -> GrassmannTensor: f"Indices must be within tensor dimensions. Got {indices}." ) - arrow = tuple(self.arrow[i] ^ i in indices for i in range(self.tensor.dim())) + arrow = tuple(self.arrow[i] ^ (i in indices) for i in range(self.tensor.dim())) tensor = self.tensor total_parity = functools.reduce( @@ -457,20 +457,17 @@ def matmul(self, other: GrassmannTensor) -> GrassmannTensor: if tensor_b.arrow[-2] is not False: tensor_b = tensor_b.reverse((tensor_b.tensor.dim() - 2,)) - broadcast_a = tensor_a.tensor.dim() - 2 - broadcast_b = tensor_b.tensor.dim() - 2 - arrow = [] edges = [] - for i in range(-max(broadcast_a, broadcast_b), 0): + for i in range(-max(tensor_a.tensor.dim(), tensor_b.tensor.dim()), -2): arrow.append(False) candidate_a = candidate_b = 1 - if i >= -broadcast_a: - candidate_a = tensor_a.edges[i - 2][0] - if i >= -broadcast_b: - candidate_b = tensor_b.edges[i - 2][0] + if i >= -tensor_a.tensor.dim(): + candidate_a, _ = tensor_a.edges[i] + if i >= -tensor_b.tensor.dim(): + candidate_b, _ = tensor_b.edges[i] assert candidate_a == candidate_b or candidate_a == 1 or candidate_b == 1, ( - f"Cannot broadcast edges {tensor_a.edges[i - 2]} and {tensor_b.edges[i - 2]}." + f"Cannot broadcast edges {tensor_a.edges[i]} and {tensor_b.edges[i]}." ) edges.append((max(candidate_a, candidate_b), 0)) if not vector_a: diff --git a/tests/matmul_test.py b/tests/matmul_test.py index f00316e..6b91ea5 100644 --- a/tests/matmul_test.py +++ b/tests/matmul_test.py @@ -2,40 +2,31 @@ import torch from grassmann_tensor import GrassmannTensor -MatmulMatrixCase = tuple[bool, bool, tuple[int, int], tuple[int, int], tuple[int, int]] +Broadcast = tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]] +MatmulCase = tuple[bool, bool, tuple[int, int], tuple[int, int], tuple[int, int]] +@pytest.mark.parametrize("a_is_vector", [False, True]) +@pytest.mark.parametrize("b_is_vector", [False, True]) +@pytest.mark.parametrize("normal_arrow_order", [False, True]) @pytest.mark.parametrize( - "x", + "broadcast", [ - (False, False, (1, 1), (1, 1), (1, 1)), - (False, True, (1, 1), (1, 1), (1, 1)), - (True, False, (1, 1), (1, 1), (1, 1)), - (True, True, (1, 1), (1, 1), (1, 1)), - (False, False, (2, 2), (2, 2), (2, 2)), - (False, True, (2, 2), (2, 2), (2, 2)), - (True, False, (2, 2), (2, 2), (2, 2)), - (True, True, (2, 2), (2, 2), (2, 2)), + ((), (), ()), + ((2,), (), (2,)), + ((), (3,), (3,)), + ((1,), (4,), (4,)), + ((5,), (1,), (5,)), + ((6,), (6,), (6,)), + ((7, 8), (7, 8), (7, 8)), + ((1, 8), (7, 8), (7, 8)), + ((8,), (7, 8), (7, 8)), + ((7, 1), (7, 8), (7, 8)), + ((7, 8), (1, 8), (7, 8)), + ((7, 8), (8,), (7, 8)), + ((7, 8), (7, 1), (7, 8)), ], ) -def test_matmul_matrix_tf(x: MatmulMatrixCase) -> None: - arrow_a, arrow_b, edge_a, edge_common, edge_b = x - dim_a = sum(edge_a) - dim_common = sum(edge_common) - dim_b = sum(edge_b) - a = GrassmannTensor( - (arrow_a, True), (edge_a, edge_common), torch.randn([dim_a, dim_common]) - ).update_mask() - b = GrassmannTensor( - (False, arrow_b), (edge_common, edge_b), torch.randn([dim_common, dim_b]) - ).update_mask() - c = a.matmul(b) - expected = a.tensor.matmul(b.tensor) - assert c.arrow == (arrow_a, arrow_b) - assert c.edges == (edge_a, edge_b) - assert torch.allclose(c.tensor, expected) - - @pytest.mark.parametrize( "x", [ @@ -49,20 +40,62 @@ def test_matmul_matrix_tf(x: MatmulMatrixCase) -> None: (True, True, (2, 2), (2, 2), (2, 2)), ], ) -def test_matmul_matrix_ft(x: MatmulMatrixCase) -> None: +def test_matmul( + a_is_vector: bool, + b_is_vector: bool, + normal_arrow_order: bool, + broadcast: Broadcast, + x: MatmulCase, +) -> None: + broadcast_a, broadcast_b, broadcast_result = broadcast arrow_a, arrow_b, edge_a, edge_common, edge_b = x + if a_is_vector and broadcast_a != (): + pytest.skip("Vector a cannot be broadcasted") + if b_is_vector and broadcast_b != (): + pytest.skip("Vector b cannot be broadcasted") dim_a = sum(edge_a) dim_common = sum(edge_common) dim_b = sum(edge_b) - a = GrassmannTensor( - (arrow_a, False), (edge_a, edge_common), torch.randn([dim_a, dim_common]) - ).update_mask() - b = GrassmannTensor( - (True, arrow_b), (edge_common, edge_b), torch.randn([dim_common, dim_b]) - ).update_mask() + if a_is_vector: + a = GrassmannTensor( + (*(False for _ in broadcast_a), True if normal_arrow_order else False), + (*((i, 0) for i in broadcast_a), edge_common), + torch.randn([*broadcast_a, dim_common]), + ).update_mask() + else: + a = GrassmannTensor( + (*(False for _ in broadcast_a), arrow_a, True if normal_arrow_order else False), + (*((i, 0) for i in broadcast_a), edge_a, edge_common), + torch.randn([*broadcast_a, dim_a, dim_common]), + ).update_mask() + if b_is_vector: + b = GrassmannTensor( + (*(False for _ in broadcast_b), False if normal_arrow_order else True), + (*((i, 0) for i in broadcast_b), edge_common), + torch.randn([*broadcast_b, dim_common]), + ).update_mask() + else: + b = GrassmannTensor( + (*(False for _ in broadcast_b), False if normal_arrow_order else True, arrow_b), + (*((i, 0) for i in broadcast_b), edge_common, edge_b), + torch.randn([*broadcast_b, dim_common, dim_b]), + ).update_mask() c = a.matmul(b) expected = a.tensor.matmul(b.tensor) - expected[edge_a[0] :, edge_b[0] :] *= -1 - assert c.arrow == (arrow_a, arrow_b) - assert c.edges == (edge_a, edge_b) + if not a_is_vector and not b_is_vector and not normal_arrow_order: + expected[..., edge_a[0] :, edge_b[0] :] *= -1 + if a_is_vector: + if b_is_vector: + assert c.arrow == tuple(False for _ in broadcast_result) + assert c.edges == tuple((i, 0) for i in broadcast_result) + else: + assert c.arrow == (*(False for _ in broadcast_result), arrow_b) + assert c.edges == (*((i, 0) for i in broadcast_result), edge_b) + else: + if b_is_vector: + assert c.arrow == (*(False for _ in broadcast_result), arrow_a) + assert c.edges == (*((i, 0) for i in broadcast_result), edge_a) + else: + assert c.arrow == (*(False for _ in broadcast_result), arrow_a, arrow_b) + assert c.edges == (*((i, 0) for i in broadcast_result), edge_a, edge_b) assert torch.allclose(c.tensor, expected)