diff --git a/grassmann_tensor/tensor.py b/grassmann_tensor/tensor.py index d53c5ae..1edaff6 100644 --- a/grassmann_tensor/tensor.py +++ b/grassmann_tensor/tensor.py @@ -756,6 +756,19 @@ def __itruediv__(self, other: typing.Any) -> GrassmannTensor: return self return NotImplemented + def __matmul__(self, other: typing.Any) -> GrassmannTensor: + if isinstance(other, GrassmannTensor): + return self.matmul(other) + return NotImplemented + + def __rmatmul__(self, other: typing.Any) -> GrassmannTensor: + return NotImplemented + + def __imatmul__(self, other: typing.Any) -> GrassmannTensor: + if isinstance(other, GrassmannTensor): + return self.matmul(other) + return NotImplemented + def clone(self) -> GrassmannTensor: """ Create a deep copy of the Grassmann tensor. diff --git a/tests/matmul_test.py b/tests/matmul_test.py index 744ea47..589817d 100644 --- a/tests/matmul_test.py +++ b/tests/matmul_test.py @@ -1,17 +1,29 @@ import pytest import torch +import typing from grassmann_tensor import GrassmannTensor Broadcast = tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]] MatmulCase = tuple[bool, bool, tuple[int, int], tuple[int, int], tuple[int, int]] -@pytest.mark.parametrize("a_is_vector", [False, True]) -@pytest.mark.parametrize("b_is_vector", [False, True]) -@pytest.mark.parametrize("normal_arrow_order", [False, True]) -@pytest.mark.parametrize( - "broadcast", - [ +@pytest.fixture(params=[False, True]) +def a_is_vector(request: pytest.FixtureRequest) -> bool: + return request.param + + +@pytest.fixture(params=[False, True]) +def b_is_vector(request: pytest.FixtureRequest) -> bool: + return request.param + + +@pytest.fixture(params=[False, True]) +def normal_arrow_order(request: pytest.FixtureRequest) -> bool: + return request.param + + +@pytest.fixture( + params=[ ((), (), ()), ((2,), (), (2,)), ((), (3,), (3,)), @@ -27,9 +39,12 @@ ((7, 8), (7, 1), (7, 8)), ], ) -@pytest.mark.parametrize( - "x", - [ +def broadcast(request: pytest.FixtureRequest) -> Broadcast: + return request.param + + +@pytest.fixture( + params=[ (False, False, (1, 1), (1, 1), (1, 1)), (False, True, (1, 1), (1, 1), (1, 1)), (True, False, (1, 1), (1, 1), (1, 1)), @@ -40,6 +55,10 @@ (True, True, (2, 2), (2, 2), (2, 2)), ], ) +def x(request: pytest.FixtureRequest) -> MatmulCase: + return request.param + + def test_matmul( a_is_vector: bool, b_is_vector: bool, @@ -101,40 +120,6 @@ def test_matmul( assert torch.allclose(c.tensor, expected) -@pytest.mark.parametrize("a_is_vector", [False, True]) -@pytest.mark.parametrize("b_is_vector", [False, True]) -@pytest.mark.parametrize("normal_arrow_order", [False, True]) -@pytest.mark.parametrize( - "broadcast", - [ - ((), (), ()), - ((2,), (), (2,)), - ((), (3,), (3,)), - ((1,), (4,), (4,)), - ((5,), (1,), (5,)), - ((6,), (6,), (6,)), - ((7, 8), (7, 8), (7, 8)), - ((1, 8), (7, 8), (7, 8)), - ((8,), (7, 8), (7, 8)), - ((7, 1), (7, 8), (7, 8)), - ((7, 8), (1, 8), (7, 8)), - ((7, 8), (8,), (7, 8)), - ((7, 8), (7, 1), (7, 8)), - ], -) -@pytest.mark.parametrize( - "x", - [ - (False, False, (1, 1), (1, 1), (1, 1)), - (False, True, (1, 1), (1, 1), (1, 1)), - (True, False, (1, 1), (1, 1), (1, 1)), - (True, True, (1, 1), (1, 1), (1, 1)), - (False, False, (2, 2), (2, 2), (2, 2)), - (False, True, (2, 2), (2, 2), (2, 2)), - (True, False, (2, 2), (2, 2), (2, 2)), - (True, True, (2, 2), (2, 2), (2, 2)), - ], -) @pytest.mark.parametrize("impure_even_for_broadcast_indices", [1, 2]) def test_matmul_unpure_even( a_is_vector: bool, @@ -191,3 +176,99 @@ def test_matmul_unpure_even( pytest.skip("One of the two tensors needs to have a dimension greater than 2") with pytest.raises(AssertionError, match="All edges except the last two must be pure even"): _ = a.matmul(b) + + +def test_matmul_operator_matmul( + a_is_vector: bool, + b_is_vector: bool, + normal_arrow_order: bool, + broadcast: Broadcast, +) -> None: + normal_arrow_order = True + broadcast_a, broadcast_b, broadcast_result = (7, 8), (7, 1), (7, 8) + arrow_a, arrow_b, edge_a, edge_common, edge_b = True, True, (2, 2), (2, 2), (2, 2) + dim_a = sum(edge_a) + dim_common = sum(edge_common) + dim_b = sum(edge_b) + a = GrassmannTensor( + (*(False for _ in broadcast_a), arrow_a, True if normal_arrow_order else False), + (*((i, 0) for i in broadcast_a), edge_a, edge_common), + torch.randn([*broadcast_a, dim_a, dim_common]), + ).update_mask() + + b = GrassmannTensor( + (*(False for _ in broadcast_b), False if normal_arrow_order else True, arrow_b), + (*((i, 0) for i in broadcast_b), edge_common, edge_b), + torch.randn([*broadcast_b, dim_common, dim_b]), + ).update_mask() + + c = a @ b + expected = a.tensor.matmul(b.tensor) + assert c.arrow == (*(False for _ in broadcast_result), arrow_a, arrow_b) + assert c.edges == (*((i, 0) for i in broadcast_result), edge_a, edge_b) + assert torch.allclose(c.tensor, expected) + + +@pytest.fixture( + params=[ + GrassmannTensor((False, False), ((2, 2), (1, 3)), torch.randn([4, 4])), + GrassmannTensor((True, False, True), ((1, 1), (2, 2), (3, 1)), torch.randn([2, 4, 4])), + GrassmannTensor( + (True, True, False, False), ((1, 2), (2, 2), (1, 1), (3, 1)), torch.randn([3, 4, 2, 4]) + ), + ] +) +def tensors(request: pytest.FixtureRequest) -> GrassmannTensor: + return request.param + + +@pytest.mark.parametrize( + "unsupported_type", + [ + "string", # string + None, # NoneType + {"key", "value"}, # dict + [1, 2, 3], # list + {1, 2}, # set + object(), # arbitrary object + ], +) +def test_matmul_unsupported_type_raises_typeerror( + unsupported_type: typing.Any, + tensors: GrassmannTensor, +) -> None: + with pytest.raises(TypeError): + _ = tensors @ unsupported_type + + with pytest.raises(TypeError): + _ = unsupported_type @ tensors + + with pytest.raises(TypeError): + tensors @= unsupported_type + + +def test_matmul_operator_rmatmul() -> None: + normal_arrow_order = True + broadcast_a, broadcast_b, broadcast_result = (7, 8), (7, 1), (7, 8) + arrow_a, arrow_b, edge_a, edge_common, edge_b = True, True, (2, 2), (2, 2), (2, 2) + dim_a = sum(edge_a) + dim_common = sum(edge_common) + dim_b = sum(edge_b) + a = GrassmannTensor( + (*(False for _ in broadcast_a), arrow_a, True if normal_arrow_order else False), + (*((i, 0) for i in broadcast_a), edge_a, edge_common), + torch.randn([*broadcast_a, dim_a, dim_common]), + ).update_mask() + + b = GrassmannTensor( + (*(False for _ in broadcast_b), False if normal_arrow_order else True, arrow_b), + (*((i, 0) for i in broadcast_b), edge_common, edge_b), + torch.randn([*broadcast_b, dim_common, dim_b]), + ).update_mask() + + c = a + c @= b + expected = a.tensor.matmul(b.tensor) + assert c.arrow == (*(False for _ in broadcast_result), arrow_a, arrow_b) + assert c.edges == (*((i, 0) for i in broadcast_result), edge_a, edge_b) + assert torch.allclose(c.tensor, expected)