Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions tests/matmul_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,95 @@ def test_matmul(
assert c.arrow == (*(False for _ in broadcast_result), arrow_a, arrow_b)
assert c.edges == (*((i, 0) for i in broadcast_result), edge_a, edge_b)
assert torch.allclose(c.tensor, expected)


@pytest.mark.parametrize("a_is_vector", [False, True])
@pytest.mark.parametrize("b_is_vector", [False, True])
@pytest.mark.parametrize("normal_arrow_order", [False, True])
@pytest.mark.parametrize(
"broadcast",
[
((), (), ()),
((2,), (), (2,)),
((), (3,), (3,)),
((1,), (4,), (4,)),
((5,), (1,), (5,)),
((6,), (6,), (6,)),
((7, 8), (7, 8), (7, 8)),
((1, 8), (7, 8), (7, 8)),
((8,), (7, 8), (7, 8)),
((7, 1), (7, 8), (7, 8)),
((7, 8), (1, 8), (7, 8)),
((7, 8), (8,), (7, 8)),
((7, 8), (7, 1), (7, 8)),
],
)
@pytest.mark.parametrize(
"x",
[
(False, False, (1, 1), (1, 1), (1, 1)),
(False, True, (1, 1), (1, 1), (1, 1)),
(True, False, (1, 1), (1, 1), (1, 1)),
(True, True, (1, 1), (1, 1), (1, 1)),
(False, False, (2, 2), (2, 2), (2, 2)),
(False, True, (2, 2), (2, 2), (2, 2)),
(True, False, (2, 2), (2, 2), (2, 2)),
(True, True, (2, 2), (2, 2), (2, 2)),
],
)
@pytest.mark.parametrize("impure_even_for_broadcast_indices", [1, 2])
def test_matmul_unpure_even(
a_is_vector: bool,
b_is_vector: bool,
normal_arrow_order: bool,
broadcast: Broadcast,
x: MatmulCase,
impure_even_for_broadcast_indices: int,
) -> None:
broadcast_a, broadcast_b, broadcast_result = broadcast
arrow_a, arrow_b, edge_a, edge_common, edge_b = x
if a_is_vector and broadcast_a != ():
pytest.skip("Vector a cannot be broadcasted")
if b_is_vector and broadcast_b != ():
pytest.skip("Vector b cannot be broadcasted")
if a_is_vector and b_is_vector:
pytest.skip("Both vectors are ignored.")
dim_a = sum(edge_a)
dim_common = sum(edge_common)
dim_b = sum(edge_b)
if a_is_vector:
a = GrassmannTensor(
(*(False for _ in broadcast_a), True if normal_arrow_order else False),
(*((i, impure_even_for_broadcast_indices) for i in broadcast_a), edge_common),
torch.randn(
[*[i + impure_even_for_broadcast_indices for i in broadcast_a], dim_common]
),
).update_mask()
else:
a = GrassmannTensor(
(*(False for _ in broadcast_a), arrow_a, True if normal_arrow_order else False),
(*((i, impure_even_for_broadcast_indices) for i in broadcast_a), edge_a, edge_common),
torch.randn(
[*[i + impure_even_for_broadcast_indices for i in broadcast_a], dim_a, dim_common]
),
).update_mask()
if b_is_vector:
b = GrassmannTensor(
(*(False for _ in broadcast_b), False if normal_arrow_order else True),
(*((i, impure_even_for_broadcast_indices) for i in broadcast_b), edge_common),
torch.randn(
[*[i + impure_even_for_broadcast_indices for i in broadcast_b], dim_common]
),
).update_mask()
else:
b = GrassmannTensor(
(*(False for _ in broadcast_b), False if normal_arrow_order else True, arrow_b),
(*((i, impure_even_for_broadcast_indices) for i in broadcast_b), edge_common, edge_b),
torch.randn(
[*[i + impure_even_for_broadcast_indices for i in broadcast_b], dim_common, dim_b]
),
).update_mask()
if a.tensor.dim() <= 2 and b.tensor.dim() <= 2:
pytest.skip("One of the two tensors needs to have a dimension greater than 2")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里真的skip了什么东西么?应该没有吧?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

因为后面判断了AssertionError:“All edges except the last two must be pure even”,所以张量围度必须大于2。

with pytest.raises(AssertionError, match="All edges except the last two must be pure even"):
_ = a.matmul(b)