Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 7 additions & 10 deletions grassmann_tensor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def reverse(self, indices: tuple[int, ...]) -> GrassmannTensor:
f"Indices must be within tensor dimensions. Got {indices}."
)

arrow = tuple(self.arrow[i] ^ i in indices for i in range(self.tensor.dim()))
arrow = tuple(self.arrow[i] ^ (i in indices) for i in range(self.tensor.dim()))
tensor = self.tensor

total_parity = functools.reduce(
Expand Down Expand Up @@ -457,20 +457,17 @@ def matmul(self, other: GrassmannTensor) -> GrassmannTensor:
if tensor_b.arrow[-2] is not False:
tensor_b = tensor_b.reverse((tensor_b.tensor.dim() - 2,))

broadcast_a = tensor_a.tensor.dim() - 2
broadcast_b = tensor_b.tensor.dim() - 2

arrow = []
edges = []
for i in range(-max(broadcast_a, broadcast_b), 0):
for i in range(-max(tensor_a.tensor.dim(), tensor_b.tensor.dim()), -2):
arrow.append(False)
candidate_a = candidate_b = 1
if i >= -broadcast_a:
candidate_a = tensor_a.edges[i - 2][0]
if i >= -broadcast_b:
candidate_b = tensor_b.edges[i - 2][0]
if i >= -tensor_a.tensor.dim():
candidate_a, _ = tensor_a.edges[i]
if i >= -tensor_b.tensor.dim():
candidate_b, _ = tensor_b.edges[i]
assert candidate_a == candidate_b or candidate_a == 1 or candidate_b == 1, (
f"Cannot broadcast edges {tensor_a.edges[i - 2]} and {tensor_b.edges[i - 2]}."
f"Cannot broadcast edges {tensor_a.edges[i]} and {tensor_b.edges[i]}."
)
edges.append((max(candidate_a, candidate_b), 0))
if not vector_a:
Expand Down
109 changes: 71 additions & 38 deletions tests/matmul_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,40 +2,31 @@
import torch
from grassmann_tensor import GrassmannTensor

MatmulMatrixCase = tuple[bool, bool, tuple[int, int], tuple[int, int], tuple[int, int]]
Broadcast = tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]
MatmulCase = tuple[bool, bool, tuple[int, int], tuple[int, int], tuple[int, int]]


@pytest.mark.parametrize("a_is_vector", [False, True])
@pytest.mark.parametrize("b_is_vector", [False, True])
@pytest.mark.parametrize("normal_arrow_order", [False, True])
@pytest.mark.parametrize(
"x",
"broadcast",
[
(False, False, (1, 1), (1, 1), (1, 1)),
(False, True, (1, 1), (1, 1), (1, 1)),
(True, False, (1, 1), (1, 1), (1, 1)),
(True, True, (1, 1), (1, 1), (1, 1)),
(False, False, (2, 2), (2, 2), (2, 2)),
(False, True, (2, 2), (2, 2), (2, 2)),
(True, False, (2, 2), (2, 2), (2, 2)),
(True, True, (2, 2), (2, 2), (2, 2)),
((), (), ()),
((2,), (), (2,)),
((), (3,), (3,)),
((1,), (4,), (4,)),
((5,), (1,), (5,)),
((6,), (6,), (6,)),
((7, 8), (7, 8), (7, 8)),
((1, 8), (7, 8), (7, 8)),
((8,), (7, 8), (7, 8)),
((7, 1), (7, 8), (7, 8)),
((7, 8), (1, 8), (7, 8)),
((7, 8), (8,), (7, 8)),
((7, 8), (7, 1), (7, 8)),
],
)
def test_matmul_matrix_tf(x: MatmulMatrixCase) -> None:
arrow_a, arrow_b, edge_a, edge_common, edge_b = x
dim_a = sum(edge_a)
dim_common = sum(edge_common)
dim_b = sum(edge_b)
a = GrassmannTensor(
(arrow_a, True), (edge_a, edge_common), torch.randn([dim_a, dim_common])
).update_mask()
b = GrassmannTensor(
(False, arrow_b), (edge_common, edge_b), torch.randn([dim_common, dim_b])
).update_mask()
c = a.matmul(b)
expected = a.tensor.matmul(b.tensor)
assert c.arrow == (arrow_a, arrow_b)
assert c.edges == (edge_a, edge_b)
assert torch.allclose(c.tensor, expected)


@pytest.mark.parametrize(
"x",
[
Expand All @@ -49,20 +40,62 @@ def test_matmul_matrix_tf(x: MatmulMatrixCase) -> None:
(True, True, (2, 2), (2, 2), (2, 2)),
],
)
def test_matmul_matrix_ft(x: MatmulMatrixCase) -> None:
def test_matmul(
a_is_vector: bool,
b_is_vector: bool,
normal_arrow_order: bool,
broadcast: Broadcast,
x: MatmulCase,
) -> None:
broadcast_a, broadcast_b, broadcast_result = broadcast
arrow_a, arrow_b, edge_a, edge_common, edge_b = x
if a_is_vector and broadcast_a != ():
pytest.skip("Vector a cannot be broadcasted")
if b_is_vector and broadcast_b != ():
pytest.skip("Vector b cannot be broadcasted")
dim_a = sum(edge_a)
dim_common = sum(edge_common)
dim_b = sum(edge_b)
a = GrassmannTensor(
(arrow_a, False), (edge_a, edge_common), torch.randn([dim_a, dim_common])
).update_mask()
b = GrassmannTensor(
(True, arrow_b), (edge_common, edge_b), torch.randn([dim_common, dim_b])
).update_mask()
if a_is_vector:
a = GrassmannTensor(
(*(False for _ in broadcast_a), True if normal_arrow_order else False),
(*((i, 0) for i in broadcast_a), edge_common),
torch.randn([*broadcast_a, dim_common]),
).update_mask()
else:
a = GrassmannTensor(
(*(False for _ in broadcast_a), arrow_a, True if normal_arrow_order else False),
(*((i, 0) for i in broadcast_a), edge_a, edge_common),
torch.randn([*broadcast_a, dim_a, dim_common]),
).update_mask()
if b_is_vector:
b = GrassmannTensor(
(*(False for _ in broadcast_b), False if normal_arrow_order else True),
(*((i, 0) for i in broadcast_b), edge_common),
torch.randn([*broadcast_b, dim_common]),
).update_mask()
else:
b = GrassmannTensor(
(*(False for _ in broadcast_b), False if normal_arrow_order else True, arrow_b),
(*((i, 0) for i in broadcast_b), edge_common, edge_b),
torch.randn([*broadcast_b, dim_common, dim_b]),
).update_mask()
c = a.matmul(b)
expected = a.tensor.matmul(b.tensor)
expected[edge_a[0] :, edge_b[0] :] *= -1
assert c.arrow == (arrow_a, arrow_b)
assert c.edges == (edge_a, edge_b)
if not a_is_vector and not b_is_vector and not normal_arrow_order:
expected[..., edge_a[0] :, edge_b[0] :] *= -1
if a_is_vector:
if b_is_vector:
assert c.arrow == tuple(False for _ in broadcast_result)
assert c.edges == tuple((i, 0) for i in broadcast_result)
else:
assert c.arrow == (*(False for _ in broadcast_result), arrow_b)
assert c.edges == (*((i, 0) for i in broadcast_result), edge_b)
else:
if b_is_vector:
assert c.arrow == (*(False for _ in broadcast_result), arrow_a)
assert c.edges == (*((i, 0) for i in broadcast_result), edge_a)
else:
assert c.arrow == (*(False for _ in broadcast_result), arrow_a, arrow_b)
assert c.edges == (*((i, 0) for i in broadcast_result), edge_a, edge_b)
assert torch.allclose(c.tensor, expected)
Loading