diff --git a/grassmann_tensor/tensor.py b/grassmann_tensor/tensor.py index 430e465..697bb7e 100644 --- a/grassmann_tensor/tensor.py +++ b/grassmann_tensor/tensor.py @@ -319,6 +319,65 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens return GrassmannTensor(_arrow=tuple(arrow), _edges=tuple(edges), _tensor=tensor) + def matmul(self, other: GrassmannTensor) -> GrassmannTensor: + """ + Perform matrix multiplication with another Grassmann tensor. + Both of them should be rank 2 tensors, except some pure even edges could exist before the last two edges. + """ + # The creation operator order from arrow is (False True) + # So (x, True) * (False, y) = (x, y) + tensor_a = self + tensor_b = other + + vector_a = False + if tensor_a.tensor.dim() == 1: + tensor_a = tensor_a.reshape(((1, 0), -1)) + vector_a = True + vector_b = False + if tensor_b.tensor.dim() == 1: + tensor_b = tensor_b.reshape((-1, (1, 0))) + vector_b = True + + assert all(odd == 0 for (even, odd) in tensor_a.edges[:-2]), f"All edges except the last two must be pure even. Got {tensor_a.edges[:-2]}." + assert all(odd == 0 for (even, odd) in tensor_b.edges[:-2]), f"All edges except the last two must be pure even. Got {tensor_b.edges[:-2]}." + + if tensor_a.arrow[-1] is not True: + tensor_a = tensor_a.reverse((tensor_a.tensor.dim() - 1,)) + if tensor_b.arrow[-2] is not False: + tensor_b = tensor_b.reverse((tensor_b.tensor.dim() - 2,)) + + broadcast_a = tensor_a.tensor.dim() - 2 + broadcast_b = tensor_b.tensor.dim() - 2 + + arrow = [] + edges = [] + for i in range(-max(broadcast_a, broadcast_b), 0): + arrow.append(False) + candidate_a = candidate_b = 1 + if i >= -broadcast_a: + candidate_a = tensor_a.edges[i - 2][0] + if i >= -broadcast_b: + candidate_b = tensor_b.edges[i - 2][0] + assert candidate_a == candidate_b or candidate_a == 1 or candidate_b == 1, f"Cannot broadcast edges {tensor_a.edges[i - 2]} and {tensor_b.edges[i - 2]}." + edges.append((max(candidate_a, candidate_b), 0)) + if not vector_a: + arrow.append(tensor_a.arrow[-2]) + edges.append(tensor_a.edges[-2]) + if not vector_b: + arrow.append(tensor_b.arrow[-1]) + edges.append(tensor_b.edges[-1]) + tensor = torch.matmul(tensor_a.tensor, tensor_b.tensor) + if vector_a: + tensor = tensor.squeeze(-2) + if vector_b: + tensor = tensor.squeeze(-1) + + return GrassmannTensor( + _arrow=tuple(arrow), + _edges=tuple(edges), + _tensor=tensor, + ) + def __post_init__(self) -> None: assert len(self._arrow) == self._tensor.dim(), f"Arrow length ({len(self._arrow)}) must match tensor dimensions ({self._tensor.dim()})." assert len(self._edges) == self._tensor.dim(), f"Edges length ({len(self._edges)}) must match tensor dimensions ({self._tensor.dim()})." diff --git a/tests/matmul_test.py b/tests/matmul_test.py new file mode 100644 index 0000000..d7eb048 --- /dev/null +++ b/tests/matmul_test.py @@ -0,0 +1,54 @@ +import pytest +import torch +from grassmann_tensor import GrassmannTensor + +MatmulMatrixCase = tuple[bool, bool, tuple[int, int], tuple[int, int], tuple[int, int]] + + +@pytest.mark.parametrize("x", [ + (False, False, (1, 1), (1, 1), (1, 1)), + (False, True, (1, 1), (1, 1), (1, 1)), + (True, False, (1, 1), (1, 1), (1, 1)), + (True, True, (1, 1), (1, 1), (1, 1)), + (False, False, (2, 2), (2, 2), (2, 2)), + (False, True, (2, 2), (2, 2), (2, 2)), + (True, False, (2, 2), (2, 2), (2, 2)), + (True, True, (2, 2), (2, 2), (2, 2)), +]) +def test_matmul_matrix_tf(x: MatmulMatrixCase) -> None: + arrow_a, arrow_b, edge_a, edge_common, edge_b = x + dim_a = sum(edge_a) + dim_common = sum(edge_common) + dim_b = sum(edge_b) + a = GrassmannTensor((arrow_a, True), (edge_a, edge_common), torch.randn([dim_a, dim_common])).update_mask() + b = GrassmannTensor((False, arrow_b), (edge_common, edge_b), torch.randn([dim_common, dim_b])).update_mask() + c = a.matmul(b) + expected = a.tensor.matmul(b.tensor) + assert c.arrow == (arrow_a, arrow_b) + assert c.edges == (edge_a, edge_b) + assert torch.allclose(c.tensor, expected) + + +@pytest.mark.parametrize("x", [ + (False, False, (1, 1), (1, 1), (1, 1)), + (False, True, (1, 1), (1, 1), (1, 1)), + (True, False, (1, 1), (1, 1), (1, 1)), + (True, True, (1, 1), (1, 1), (1, 1)), + (False, False, (2, 2), (2, 2), (2, 2)), + (False, True, (2, 2), (2, 2), (2, 2)), + (True, False, (2, 2), (2, 2), (2, 2)), + (True, True, (2, 2), (2, 2), (2, 2)), +]) +def test_matmul_matrix_ft(x: MatmulMatrixCase) -> None: + arrow_a, arrow_b, edge_a, edge_common, edge_b = x + dim_a = sum(edge_a) + dim_common = sum(edge_common) + dim_b = sum(edge_b) + a = GrassmannTensor((arrow_a, False), (edge_a, edge_common), torch.randn([dim_a, dim_common])).update_mask() + b = GrassmannTensor((True, arrow_b), (edge_common, edge_b), torch.randn([dim_common, dim_b])).update_mask() + c = a.matmul(b) + expected = a.tensor.matmul(b.tensor) + expected[edge_a[0]:, edge_b[0]:] *= -1 + assert c.arrow == (arrow_a, arrow_b) + assert c.edges == (edge_a, edge_b) + assert torch.allclose(c.tensor, expected)