diff --git a/grassmann_tensor/tensor.py b/grassmann_tensor/tensor.py index e4584f8..8fd93e4 100644 --- a/grassmann_tensor/tensor.py +++ b/grassmann_tensor/tensor.py @@ -488,6 +488,44 @@ def matmul(self, other: GrassmannTensor) -> GrassmannTensor: _tensor=tensor, ) + def trace(self, trace_pair: tuple[int, int]) -> torch.Tensor: + assert len(trace_pair) == 2, ( + f"The length of trace pair must be 2, but got {len(trace_pair)}." + ) + + assert trace_pair[0] != trace_pair[1], ( + f"Trace requires two distinct axes but got {trace_pair[0]} and {trace_pair[1]}." + ) + + assert self.arrow[trace_pair[0]] == self.arrow[trace_pair[1]], ( + f"Trace requires two different arrows but got {self.arrow[trace_pair[0]]} and {self.arrow[trace_pair[1]]}." + ) + + tensor = self + + if trace_pair[0] > trace_pair[1]: + trace_pair = trace_pair[::-1] + + edge_first, edge_end = self.edges[trace_pair[0]], self.edges[trace_pair[1]] + if edge_first != edge_end: + raise ValueError(f"Incompatible edges: {edge_first} and {edge_end}.") + + order = list(range(tensor.tensor.dim())) + order_first = order.pop(trace_pair[0]) + order_end = order.pop(trace_pair[1]) + order[trace_pair[0] : trace_pair[0]] = [order_first, order_end] + tensor = tensor.permute(tuple(order)) + + shape = tensor.tensor.shape[0] + tensor.tensor.shape[1] + tensor.reshape( + ( + shape, + *(-1 for _ in range(tensor.tensor.dim() - 2)), + ) + ) + + return tensor.tensor[0].trace() + def __post_init__(self) -> None: assert len(self._arrow) == self._tensor.dim(), ( f"Arrow length ({len(self._arrow)}) must match tensor dimensions ({self._tensor.dim()})." diff --git a/tests/trace_test.py b/tests/trace_test.py new file mode 100644 index 0000000..df1fb9a --- /dev/null +++ b/tests/trace_test.py @@ -0,0 +1,11 @@ +import torch +from grassmann_tensor.tensor import GrassmannTensor + + +def test_trace() -> None: + gt = GrassmannTensor( + (False, False, False), + ((1, 1), (1, 1), (1, 1)), + torch.tensor([[[1, 0], [0, 2]], [[0, 3], [4, 0]]]), + ) + gt.trace((0, 1))