diff --git a/grassmann_tensor/tensor.py b/grassmann_tensor/tensor.py index 21d6e5b..b82122a 100644 --- a/grassmann_tensor/tensor.py +++ b/grassmann_tensor/tensor.py @@ -308,7 +308,10 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens if (isinstance(new_shape_check, int) and new_shape_check == 1) or ( new_shape_check == (1, 0) ): - arrow.append(False) + if cursor_plan < len(self.arrow): + arrow.append(self.arrow[cursor_plan]) + else: + arrow.append(False) edges.append((1, 0)) shape.append(1) cursor_plan += 1 @@ -744,17 +747,18 @@ def _check_pairs_coverage(self, pairs: tuple[tuple[int, ...], tuple[int, ...]]) def exponential(self, pairs: tuple[tuple[int, ...], tuple[int, ...]]) -> GrassmannTensor: tensor, left_legs, right_legs = self._group_edges(pairs) - arrow_order = (False, True) - edges_to_reverse = tuple( - i for i, arrow in enumerate(arrow_order) if tensor.arrow[i] != arrow + assert tensor.arrow in ((False, True), (True, False)), ( + f"Exponentiation requires arrow (False, True) or (True, False), but got {tensor.arrow}" ) - if edges_to_reverse: - tensor = tensor.reverse(edges_to_reverse) + + tensor_reverse_flag = tensor.arrow != (False, True) + if tensor_reverse_flag: + tensor = tensor.reverse((0, 1)) left_dim, right_dim = tensor.tensor.shape assert left_dim == right_dim, ( - f"Exponential requires a square operator, but got {left_dim} x {right_dim}." + f"Exponentiation requires a square operator, but got {left_dim} x {right_dim}." ) (even_left, odd_left) = tensor.edges[0] @@ -774,8 +778,8 @@ def exponential(self, pairs: tuple[tuple[int, ...], tuple[int, ...]]) -> Grassma tensor_exp = dataclasses.replace(tensor, _tensor=tensor_exp) - if edges_to_reverse: - tensor_exp = tensor_exp.reverse(tuple(edges_to_reverse)) + if tensor_reverse_flag: + tensor_exp = tensor_exp.reverse((0, 1)) order = left_legs + right_legs edges_after_permute = tuple(self.edges[i] for i in order) @@ -787,6 +791,47 @@ def exponential(self, pairs: tuple[tuple[int, ...], tuple[int, ...]]) -> Grassma return tensor_exp + def identity(self, pairs: tuple[tuple[int, ...], tuple[int, ...]]) -> GrassmannTensor: + tensor, left_legs, right_legs = self._group_edges(pairs) + + assert tensor.arrow in ((False, True), (True, False)), ( + f"Identity requires arrow (False, True) or (True, False), but got {tensor.arrow}" + ) + + tensor_reverse_flag = tensor.arrow != (False, True) + if tensor_reverse_flag: + tensor = tensor.reverse((0, 1)) + + left_dim, right_dim = tensor.tensor.shape + + assert left_dim == right_dim, ( + f"Identity requires a square operator, but got {left_dim} x {right_dim}." + ) + + (even_left, odd_left) = tensor.edges[0] + (even_right, odd_right) = tensor.edges[1] + + assert even_left == even_right and odd_left == odd_right, ( + f"Parity blocks must be square, but got L=({even_left},{odd_left}), R=({even_right},{odd_right})" + ) + + I = torch.eye(left_dim, dtype=tensor.tensor.dtype, device=tensor.tensor.device) # noqa: E741 + + tensor_identity = dataclasses.replace(tensor, _tensor=I) + + if tensor_reverse_flag: + tensor_identity = tensor_identity.reverse((0, 1)) + + order = left_legs + right_legs + edges_after_permute = tuple(self.edges[i] for i in order) + tensor_identity = tensor_identity.reshape(edges_after_permute) + + inv_order = self._get_inv_order(order) + + tensor_identity = tensor_identity.permute(inv_order) + + return tensor_identity + def __post_init__(self) -> None: assert len(self._arrow) == self._tensor.dim(), ( f"Arrow length ({len(self._arrow)}) must match tensor dimensions ({self._tensor.dim()})." diff --git a/tests/exponential_test.py b/tests/exponential_test.py index 438cba8..9735115 100644 --- a/tests/exponential_test.py +++ b/tests/exponential_test.py @@ -1,24 +1,17 @@ import torch import pytest +from typing import TypeAlias from grassmann_tensor import GrassmannTensor - -def test_exponential() -> None: - a = GrassmannTensor( - (True, True, True, True), - ((4, 4), (8, 8), (4, 4), (8, 8)), - torch.randn(8, 16, 8, 16, dtype=torch.float64), - ) - b = a.exponential(((0, 3), (1, 2))) - c = a.exponential(((0, 3), (2, 1))) - assert not torch.allclose(b.tensor, c.tensor) +Tensor: TypeAlias = GrassmannTensor +Pairs: TypeAlias = tuple[tuple[int, ...], tuple[int, ...]] def test_exponential_with_empty_parity_block() -> None: - a = GrassmannTensor((False, True), ((1, 0), (1, 0)), torch.randn(1, 1)) + a = GrassmannTensor((False, True), ((1, 0), (1, 0)), torch.randn(1, 1, dtype=torch.float64)) a.exponential(((0,), (1,))) - b = GrassmannTensor((False, True), ((0, 1), (0, 1)), torch.randn(1, 1)) + b = GrassmannTensor((False, True), ((0, 1), (0, 1)), torch.randn(1, 1, dtype=torch.float64)) b.exponential(((0,), (1,))) @@ -28,5 +21,79 @@ def test_exponential_assertation() -> None: ((2, 2), (4, 4), (8, 8), (16, 16)), torch.randn(4, 8, 16, 32, dtype=torch.float64), ) - with pytest.raises(AssertionError, match="Exponential requires a square operator"): + with pytest.raises(AssertionError, match="Exponentiation requires arrow"): a.exponential(((0, 2), (1, 3))) + + b = GrassmannTensor( + (False, True, False, True), + ((2, 2), (4, 4), (8, 8), (16, 16)), + torch.randn(4, 8, 16, 32, dtype=torch.float64), + ) + with pytest.raises(AssertionError, match="Exponentiation requires a square operator"): + b.exponential(((0, 2), (1, 3))) + + c = GrassmannTensor( + (False, True, False, True), + ((1, 3), (3, 1), (3, 1), (3, 1)), + torch.randn(4, 4, 4, 4, dtype=torch.float64), + ) + with pytest.raises(AssertionError, match="Parity blocks must be square"): + c.exponential(((0, 2), (1, 3))) + + +@pytest.mark.parametrize( + "tensor, pairs", + [ + ( + GrassmannTensor( + (False, True), ((4, 4), (4, 4)), torch.randn(8, 8, dtype=torch.float64) + ), + ((0,), (1,)), + ), + ( + GrassmannTensor( + (True, False), ((4, 4), (4, 4)), torch.randn(8, 8, dtype=torch.float64) + ), + ((0,), (1,)), + ), + ( + GrassmannTensor( + (False, False, True), + ((4, 4), (4, 4), (32, 32)), + torch.randn(8, 8, 64, dtype=torch.float64), + ), + ((0, 1), (2,)), + ), + ( + GrassmannTensor( + (False, False, True, True), + ((4, 4), (8, 8), (4, 4), (8, 8)), + torch.randn(8, 16, 8, 16, dtype=torch.float64), + ), + ((0, 1), (2, 3)), + ), + ], +) +def test_exponential_via_taylor_expansion( + tensor: Tensor, + pairs: Pairs, +) -> None: + tensor_exp = tensor.exponential(pairs) + iter_tensor = tensor.identity(pairs) + iter_tensor, _, _ = iter_tensor._group_edges(pairs) + iter_tensor = iter_tensor.update_mask() + tensor_group_edges, left_legs, right_legs = tensor._group_edges(pairs) + tensor_group_edges = tensor_group_edges.update_mask() + + tensor_taylor_expansion = iter_tensor + for i in range(1, 50): + iter_tensor @= tensor_group_edges / i + tensor_taylor_expansion += iter_tensor + + order = left_legs + right_legs + edges_after_permute = tuple(tensor.edges[i] for i in order) + tensor_taylor_expansion = tensor_taylor_expansion.reshape(edges_after_permute) + inv_order = tensor._get_inv_order(order) + tensor_taylor_expansion = tensor_taylor_expansion.permute(inv_order) + + assert torch.allclose(tensor_taylor_expansion.tensor, tensor_exp.tensor) diff --git a/tests/identity_test.py b/tests/identity_test.py new file mode 100644 index 0000000..b1699d6 --- /dev/null +++ b/tests/identity_test.py @@ -0,0 +1,83 @@ +import pytest +import torch +from typing import TypeAlias + +from grassmann_tensor import GrassmannTensor + +Tensor: TypeAlias = GrassmannTensor +Pairs: TypeAlias = tuple[tuple[int, ...], tuple[int, ...]] + + +def test_identity_assertation() -> None: + a = GrassmannTensor( + (True, True, True, True), + ((2, 2), (4, 4), (8, 8), (16, 16)), + torch.randn(4, 8, 16, 32, dtype=torch.float64), + ) + with pytest.raises(AssertionError, match="Identity requires arrow"): + a.identity(((0, 2), (1, 3))) + + b = GrassmannTensor( + (False, True, False, True), + ((2, 2), (4, 4), (8, 8), (16, 16)), + torch.randn(4, 8, 16, 32, dtype=torch.float64), + ) + with pytest.raises(AssertionError, match="Identity requires a square operator"): + b.identity(((0, 2), (1, 3))) + + c = GrassmannTensor( + (False, True, False, True), + ((1, 3), (3, 1), (3, 1), (3, 1)), + torch.randn(4, 4, 4, 4, dtype=torch.float64), + ) + with pytest.raises(AssertionError, match="Parity blocks must be square"): + c.identity(((0, 2), (1, 3))) + + +@pytest.mark.parametrize( + "tensor, pairs", + [ + ( + GrassmannTensor( + (False, True), ((4, 4), (4, 4)), torch.randn(8, 8, dtype=torch.float64) + ), + ((0,), (1,)), + ), + ( + GrassmannTensor( + (True, False), ((4, 4), (4, 4)), torch.randn(8, 8, dtype=torch.float64) + ), + ((0,), (1,)), + ), + ( + GrassmannTensor( + (False, False, True), + ((4, 4), (4, 4), (32, 32)), + torch.randn(8, 8, 64, dtype=torch.float64), + ), + ((0, 1), (2,)), + ), + ( + GrassmannTensor( + (False, False, True, True), + ((4, 4), (8, 8), (4, 4), (8, 8)), + torch.randn(8, 16, 8, 16, dtype=torch.float64), + ), + ((0, 1), (2, 3)), + ), + ], +) +def test_identity_via_self_multiplication( + tensor: Tensor, + pairs: Pairs, +) -> None: + identity = tensor.identity(pairs) + identity, _, _ = identity._group_edges(pairs) + tensor, _, _ = tensor._group_edges(pairs) + tensor_reverse_flag = tensor.arrow != (False, True) + if tensor_reverse_flag: + identity = identity.reverse((0, 1)) + tensor = tensor.reverse((0, 1)) + assert torch.allclose((identity @ identity).tensor, identity.tensor) + assert torch.allclose((identity @ tensor).tensor, tensor.tensor) + assert torch.allclose((tensor @ identity).tensor, tensor.tensor) diff --git a/tests/reshape_test.py b/tests/reshape_test.py index 2b3b06c..77f1823 100644 --- a/tests/reshape_test.py +++ b/tests/reshape_test.py @@ -225,6 +225,8 @@ def test_reshape_with_one_dimension( assert ( len(a.arrow) == len(shape) and len(a.edges) == len(shape) and a.tensor.dim() == len(shape) ) + if len(shape) > len(arrow): + assert all(not a.arrow[i] for i in range(len(arrow), len(shape))) def test_reshape_trailing_nontrivial_dim_raises() -> None: