From 7ded7f10c2bfcb34520f495a5cd087ebc34376eb Mon Sep 17 00:00:00 2001 From: Gausshj Date: Wed, 10 Sep 2025 17:18:24 +0800 Subject: [PATCH] test(matmul): add assertion tests for matmul of impure tensors - Add assertion tests for matmul of impure tensors Signed-off-by: Gausshj Signed-off-by: Hao Zhang --- tests/matmul_test.py | 92 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) diff --git a/tests/matmul_test.py b/tests/matmul_test.py index 6b91ea5..744ea47 100644 --- a/tests/matmul_test.py +++ b/tests/matmul_test.py @@ -99,3 +99,95 @@ def test_matmul( assert c.arrow == (*(False for _ in broadcast_result), arrow_a, arrow_b) assert c.edges == (*((i, 0) for i in broadcast_result), edge_a, edge_b) assert torch.allclose(c.tensor, expected) + + +@pytest.mark.parametrize("a_is_vector", [False, True]) +@pytest.mark.parametrize("b_is_vector", [False, True]) +@pytest.mark.parametrize("normal_arrow_order", [False, True]) +@pytest.mark.parametrize( + "broadcast", + [ + ((), (), ()), + ((2,), (), (2,)), + ((), (3,), (3,)), + ((1,), (4,), (4,)), + ((5,), (1,), (5,)), + ((6,), (6,), (6,)), + ((7, 8), (7, 8), (7, 8)), + ((1, 8), (7, 8), (7, 8)), + ((8,), (7, 8), (7, 8)), + ((7, 1), (7, 8), (7, 8)), + ((7, 8), (1, 8), (7, 8)), + ((7, 8), (8,), (7, 8)), + ((7, 8), (7, 1), (7, 8)), + ], +) +@pytest.mark.parametrize( + "x", + [ + (False, False, (1, 1), (1, 1), (1, 1)), + (False, True, (1, 1), (1, 1), (1, 1)), + (True, False, (1, 1), (1, 1), (1, 1)), + (True, True, (1, 1), (1, 1), (1, 1)), + (False, False, (2, 2), (2, 2), (2, 2)), + (False, True, (2, 2), (2, 2), (2, 2)), + (True, False, (2, 2), (2, 2), (2, 2)), + (True, True, (2, 2), (2, 2), (2, 2)), + ], +) +@pytest.mark.parametrize("impure_even_for_broadcast_indices", [1, 2]) +def test_matmul_unpure_even( + a_is_vector: bool, + b_is_vector: bool, + normal_arrow_order: bool, + broadcast: Broadcast, + x: MatmulCase, + impure_even_for_broadcast_indices: int, +) -> None: + broadcast_a, broadcast_b, broadcast_result = broadcast + arrow_a, arrow_b, edge_a, edge_common, edge_b = x + if a_is_vector and broadcast_a != (): + pytest.skip("Vector a cannot be broadcasted") + if b_is_vector and broadcast_b != (): + pytest.skip("Vector b cannot be broadcasted") + if a_is_vector and b_is_vector: + pytest.skip("Both vectors are ignored.") + dim_a = sum(edge_a) + dim_common = sum(edge_common) + dim_b = sum(edge_b) + if a_is_vector: + a = GrassmannTensor( + (*(False for _ in broadcast_a), True if normal_arrow_order else False), + (*((i, impure_even_for_broadcast_indices) for i in broadcast_a), edge_common), + torch.randn( + [*[i + impure_even_for_broadcast_indices for i in broadcast_a], dim_common] + ), + ).update_mask() + else: + a = GrassmannTensor( + (*(False for _ in broadcast_a), arrow_a, True if normal_arrow_order else False), + (*((i, impure_even_for_broadcast_indices) for i in broadcast_a), edge_a, edge_common), + torch.randn( + [*[i + impure_even_for_broadcast_indices for i in broadcast_a], dim_a, dim_common] + ), + ).update_mask() + if b_is_vector: + b = GrassmannTensor( + (*(False for _ in broadcast_b), False if normal_arrow_order else True), + (*((i, impure_even_for_broadcast_indices) for i in broadcast_b), edge_common), + torch.randn( + [*[i + impure_even_for_broadcast_indices for i in broadcast_b], dim_common] + ), + ).update_mask() + else: + b = GrassmannTensor( + (*(False for _ in broadcast_b), False if normal_arrow_order else True, arrow_b), + (*((i, impure_even_for_broadcast_indices) for i in broadcast_b), edge_common, edge_b), + torch.randn( + [*[i + impure_even_for_broadcast_indices for i in broadcast_b], dim_common, dim_b] + ), + ).update_mask() + if a.tensor.dim() <= 2 and b.tensor.dim() <= 2: + pytest.skip("One of the two tensors needs to have a dimension greater than 2") + with pytest.raises(AssertionError, match="All edges except the last two must be pure even"): + _ = a.matmul(b)