Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions grassmann_tensor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,65 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens

return GrassmannTensor(_arrow=tuple(arrow), _edges=tuple(edges), _tensor=tensor)

def matmul(self, other: GrassmannTensor) -> GrassmannTensor:
"""
Perform matrix multiplication with another Grassmann tensor.
Both of them should be rank 2 tensors, except some pure even edges could exist before the last two edges.
"""
# The creation operator order from arrow is (False True)
# So (x, True) * (False, y) = (x, y)
tensor_a = self
tensor_b = other

vector_a = False
if tensor_a.tensor.dim() == 1:
tensor_a = tensor_a.reshape(((1, 0), -1))
vector_a = True
vector_b = False
if tensor_b.tensor.dim() == 1:
tensor_b = tensor_b.reshape((-1, (1, 0)))
vector_b = True

assert all(odd == 0 for (even, odd) in tensor_a.edges[:-2]), f"All edges except the last two must be pure even. Got {tensor_a.edges[:-2]}."
assert all(odd == 0 for (even, odd) in tensor_b.edges[:-2]), f"All edges except the last two must be pure even. Got {tensor_b.edges[:-2]}."

if tensor_a.arrow[-1] is not True:
tensor_a = tensor_a.reverse((tensor_a.tensor.dim() - 1,))
if tensor_b.arrow[-2] is not False:
tensor_b = tensor_b.reverse((tensor_b.tensor.dim() - 2,))

broadcast_a = tensor_a.tensor.dim() - 2
broadcast_b = tensor_b.tensor.dim() - 2

arrow = []
edges = []
for i in range(-max(broadcast_a, broadcast_b), 0):
arrow.append(False)
candidate_a = candidate_b = 1
if i >= -broadcast_a:
candidate_a = tensor_a.edges[i - 2][0]
if i >= -broadcast_b:
candidate_b = tensor_b.edges[i - 2][0]
assert candidate_a == candidate_b or candidate_a == 1 or candidate_b == 1, f"Cannot broadcast edges {tensor_a.edges[i - 2]} and {tensor_b.edges[i - 2]}."
edges.append((max(candidate_a, candidate_b), 0))
if not vector_a:
arrow.append(tensor_a.arrow[-2])
edges.append(tensor_a.edges[-2])
if not vector_b:
arrow.append(tensor_b.arrow[-1])
edges.append(tensor_b.edges[-1])
tensor = torch.matmul(tensor_a.tensor, tensor_b.tensor)
if vector_a:
tensor = tensor.squeeze(-2)
if vector_b:
tensor = tensor.squeeze(-1)

return GrassmannTensor(
_arrow=tuple(arrow),
_edges=tuple(edges),
_tensor=tensor,
)

def __post_init__(self) -> None:
assert len(self._arrow) == self._tensor.dim(), f"Arrow length ({len(self._arrow)}) must match tensor dimensions ({self._tensor.dim()})."
assert len(self._edges) == self._tensor.dim(), f"Edges length ({len(self._edges)}) must match tensor dimensions ({self._tensor.dim()})."
Expand Down
54 changes: 54 additions & 0 deletions tests/matmul_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import pytest
import torch
from grassmann_tensor import GrassmannTensor

MatmulMatrixCase = tuple[bool, bool, tuple[int, int], tuple[int, int], tuple[int, int]]


@pytest.mark.parametrize("x", [
(False, False, (1, 1), (1, 1), (1, 1)),
(False, True, (1, 1), (1, 1), (1, 1)),
(True, False, (1, 1), (1, 1), (1, 1)),
(True, True, (1, 1), (1, 1), (1, 1)),
(False, False, (2, 2), (2, 2), (2, 2)),
(False, True, (2, 2), (2, 2), (2, 2)),
(True, False, (2, 2), (2, 2), (2, 2)),
(True, True, (2, 2), (2, 2), (2, 2)),
])
def test_matmul_matrix_tf(x: MatmulMatrixCase) -> None:
arrow_a, arrow_b, edge_a, edge_common, edge_b = x
dim_a = sum(edge_a)
dim_common = sum(edge_common)
dim_b = sum(edge_b)
a = GrassmannTensor((arrow_a, True), (edge_a, edge_common), torch.randn([dim_a, dim_common])).update_mask()
b = GrassmannTensor((False, arrow_b), (edge_common, edge_b), torch.randn([dim_common, dim_b])).update_mask()
c = a.matmul(b)
expected = a.tensor.matmul(b.tensor)
assert c.arrow == (arrow_a, arrow_b)
assert c.edges == (edge_a, edge_b)
assert torch.allclose(c.tensor, expected)


@pytest.mark.parametrize("x", [
(False, False, (1, 1), (1, 1), (1, 1)),
(False, True, (1, 1), (1, 1), (1, 1)),
(True, False, (1, 1), (1, 1), (1, 1)),
(True, True, (1, 1), (1, 1), (1, 1)),
(False, False, (2, 2), (2, 2), (2, 2)),
(False, True, (2, 2), (2, 2), (2, 2)),
(True, False, (2, 2), (2, 2), (2, 2)),
(True, True, (2, 2), (2, 2), (2, 2)),
])
def test_matmul_matrix_ft(x: MatmulMatrixCase) -> None:
arrow_a, arrow_b, edge_a, edge_common, edge_b = x
dim_a = sum(edge_a)
dim_common = sum(edge_common)
dim_b = sum(edge_b)
a = GrassmannTensor((arrow_a, False), (edge_a, edge_common), torch.randn([dim_a, dim_common])).update_mask()
b = GrassmannTensor((True, arrow_b), (edge_common, edge_b), torch.randn([dim_common, dim_b])).update_mask()
c = a.matmul(b)
expected = a.tensor.matmul(b.tensor)
expected[edge_a[0]:, edge_b[0]:] *= -1
assert c.arrow == (arrow_a, arrow_b)
assert c.edges == (edge_a, edge_b)
assert torch.allclose(c.tensor, expected)