diff --git a/grassmann_tensor/tensor.py b/grassmann_tensor/tensor.py index 1edaff6..fa22531 100644 --- a/grassmann_tensor/tensor.py +++ b/grassmann_tensor/tensor.py @@ -221,6 +221,20 @@ def _reorder_indices( sign = (count & 2).to(dtype=torch.bool) return len(even), len(odd), reorder, sign.flatten() + def _calculate_even_odd(self) -> tuple[int, int]: + return self.calculate_even_odd(self.edges) + + @staticmethod + def calculate_even_odd(edges: tuple[tuple[int, int], ...]) -> tuple[int, int]: + return functools.reduce( + lambda accumulator, even_odd_pair: ( + accumulator[0] * even_odd_pair[0] + accumulator[1] * even_odd_pair[1], + accumulator[0] * even_odd_pair[1] + accumulator[1] * even_odd_pair[0], + ), + edges, + (1, 0), + ) + def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTensor: """ Reshape the Grassmann tensor, which may split or merge edges. @@ -253,15 +267,38 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens merging_reorder: list[tuple[int, torch.Tensor]] = [] merging_sign: list[tuple[int, torch.Tensor]] = [] + original_self_is_scalar = self.tensor.dim() == 0 + if original_self_is_scalar: + new_shape_list: list[tuple[int, int]] = [] + for item in new_shape: + if item == -1: + raise AssertionError("Cannot use -1 when reshaping from a scalar") + if isinstance(item, int): + if item != 1: + raise AssertionError( + f"Ambiguous integer dim {item} from scalar. " + "Use explicit (even, odd) pairs, or only use 1 for trivial edges." + ) + new_shape_list.append((1, 0)) + else: + new_shape_list.append(item) + new_shape = tuple(new_shape_list) + edges_only = typing.cast(tuple[tuple[int, int], ...], new_shape) + assert self.calculate_even_odd(edges_only) == (1, 0), ( + "Cannot split none edges into illegal edges" + ) + + if len(new_shape) == 0: + assert self._calculate_even_odd() == (1, 0), ( + "Only pure even edges can be merged into none edges" + ) + tensor = self.tensor.reshape(()) + return GrassmannTensor(_arrow=(), _edges=(), _tensor=tensor) + cursor_plan: int = 0 cursor_self: int = 0 while cursor_plan != len(new_shape) or cursor_self != self.tensor.dim(): - if len(new_shape) == 0: - assert all(edge == (0, 1) or edge == (1, 0) for edge in self.edges), ( - f"Edge must be (0, 1) or (1, 0) but got {self.edges}" - ) - cursor_self = self.tensor.dim() - 1 - elif cursor_plan != len(new_shape) and new_shape[cursor_plan] == -1: + if cursor_plan != len(new_shape) and new_shape[cursor_plan] == -1: # Does not change arrow.append(self.arrow[cursor_self]) edges.append(self.edges[cursor_self]) @@ -280,20 +317,14 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens # A trivial self edge cursor_self += 1 continue - if len(new_shape) == 0: - cursor_new_shape = typing.cast(int | tuple[int, int], tuple()) - total = 1 - else: - cursor_new_shape = new_shape[cursor_plan] - total = ( - cursor_new_shape - if isinstance(cursor_new_shape, int) - else cursor_new_shape[0] + cursor_new_shape[1] - ) + cursor_new_shape = new_shape[cursor_plan] + total = ( + cursor_new_shape + if isinstance(cursor_new_shape, int) + else cursor_new_shape[0] + cursor_new_shape[1] + ) # one of total and shape[cursor_self] is not trivial, otherwise it should be handled before - if self.tensor.dim() == 0: - merging = False - elif total == self.tensor.shape[cursor_self]: + if total == self.tensor.shape[cursor_self]: # We do not know whether it is merging or splitting, check more if isinstance(cursor_new_shape, int) or cursor_new_shape == self.edges[cursor_self]: # If the new shape is exactly the same as the current edge, we treat it as no change @@ -307,9 +338,6 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens cursor_self_finding = cursor_self cursor_self_found = False while True: - if len(new_shape) == 0: - cursor_self_found = True - break cursor_self_finding += 1 if cursor_self_finding == self.tensor.dim(): break @@ -329,10 +357,6 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens new_cursor_self = cursor_self self_total = 1 while True: - if len(new_shape) == 0: - new_cursor_self += 1 - even, odd, reorder, sign = self._reorder_indices(self.edges) - break # Try to include more dimension from self self_total *= self.tensor.shape[new_cursor_self] new_cursor_self += 1 @@ -354,26 +378,19 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens f"New shape exceeds in merging with edges {self.edges} and new shape {new_shape}." ) # The merging block [cursor_self, new_cursor_self) has been determined - if len(new_shape) == 0: - arrow = [] - edges = [] - shape = [] - merging_sign.append((cursor_plan, sign)) - cursor_self = new_cursor_self - else: - arrow.append(self.arrow[cursor_self]) - assert all( - self_arrow == arrow[-1] - for self_arrow in self.arrow[cursor_self:new_cursor_self] - ), ( - f"Cannot merge edges with different arrows {self.arrow[cursor_self:new_cursor_self]}." - ) - edges.append((even, odd)) - shape.append(total) - merging_sign.append((cursor_plan, sign)) - merging_reorder.append((cursor_plan, reorder)) - cursor_self = new_cursor_self - cursor_plan += 1 + arrow.append(self.arrow[cursor_self]) + assert all( + self_arrow == arrow[-1] + for self_arrow in self.arrow[cursor_self:new_cursor_self] + ), ( + f"Cannot merge edges with different arrows {self.arrow[cursor_self:new_cursor_self]}." + ) + edges.append((even, odd)) + shape.append(total) + merging_sign.append((cursor_plan, sign)) + merging_reorder.append((cursor_plan, reorder)) + cursor_self = new_cursor_self + cursor_plan += 1 else: # Splitting between [cursor_plan, new_cursor_plan) and the another side contains dimension as plan_total new_cursor_plan = cursor_plan @@ -387,23 +404,16 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens plan_total *= new_cursor_new_shape[0] + new_cursor_new_shape[1] new_cursor_plan += 1 # One dimension included, check if we can stop - if self.tensor.dim() == 0: + if plan_total == self.tensor.shape[cursor_self]: + # new_shape block has been verified to be always tuple[int, int] before even, odd, reorder, sign = self._reorder_indices( - typing.cast(tuple[tuple[int, int], ...], new_shape) - ) - new_cursor_plan = len(new_shape) - break - else: - if plan_total == self.tensor.shape[cursor_self]: - # new_shape block has been verified to be always tuple[int, int] before - even, odd, reorder, sign = self._reorder_indices( - typing.cast( - tuple[tuple[int, int], ...], - new_shape[cursor_plan:new_cursor_plan], - ) + typing.cast( + tuple[tuple[int, int], ...], + new_shape[cursor_plan:new_cursor_plan], ) - if (even, odd) == self.edges[cursor_self]: - break + ) + if (even, odd) == self.edges[cursor_self]: + break # For some reason we cannot stop here, continue to include more dimension, check something before continue assert plan_total <= self.tensor.shape[cursor_self], ( f"Dimension mismatch in splitting with edges {self.edges} and new shape {new_shape}." @@ -415,10 +425,7 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens for i in range(cursor_plan, new_cursor_plan): # new_shape block has been verified to be always tuple[int, int] in the loop new_cursor_new_shape = typing.cast(tuple[int, int], new_shape[i]) - if self.tensor.dim() == 0: - arrow.append(False) - else: - arrow.append(self.arrow[cursor_self]) + arrow.append(self.arrow[cursor_self]) edges.append(new_cursor_new_shape) shape.append(new_cursor_new_shape[0] + new_cursor_new_shape[1]) splitting_reorder.append((cursor_self, reorder)) @@ -447,9 +454,6 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens tensor = tensor.reshape(shape) - if len(new_shape) == 0: - return GrassmannTensor(_arrow=tuple(arrow), _edges=tuple(edges), _tensor=tensor) - merging_parity = functools.reduce( torch.logical_xor, ( diff --git a/tests/reshape_test.py b/tests/reshape_test.py index 2e1b9bc..a1dac6f 100644 --- a/tests/reshape_test.py +++ b/tests/reshape_test.py @@ -180,5 +180,20 @@ def test_reshape_equal_edges_nontrivial_merging_with_other_edge() -> None: def test_reshape_with_none() -> None: a = GrassmannTensor((), (), torch.tensor(2333)).reshape(((1, 0), (1, 0))).reshape(()) assert len(a.arrow) == 0 and len(a.edges) == 0 and a.tensor.dim() == 0 - b = GrassmannTensor((), (), torch.tensor(2333)).reshape(((0, 1), (0, 1))).reshape(()) + b = GrassmannTensor((), (), torch.tensor(2333)).reshape(((1, 0), (1, 0))).reshape(()) assert len(b.arrow) == 0 and len(b.edges) == 0 and b.tensor.dim() == 0 + c = GrassmannTensor((), (), torch.tensor(2333)).reshape((1, 1)) + assert len(c.arrow) == 2 and len(c.edges) == 2 and c.tensor.dim() == 2 + + +def test_reshape_with_none_edge_assertion() -> None: + with pytest.raises(AssertionError, match="Only pure even edges can be merged into none edges"): + _ = GrassmannTensor((True, True), ((0, 1), (1, 0)), torch.tensor([[2333]])).reshape(()) + with pytest.raises(AssertionError, match="Cannot split none edges into illegal edges"): + _ = GrassmannTensor((), (), torch.tensor(2333)).reshape(((0, 1),)) + with pytest.raises(AssertionError, match="Cannot split none edges into illegal edges"): + _ = GrassmannTensor((), (), torch.tensor(2333)).reshape(((0, 1), (1, 0))) + with pytest.raises(AssertionError, match="Cannot use -1 when reshaping from a scalar"): + _ = GrassmannTensor((), (), torch.tensor(2333)).reshape((1, -1)) + with pytest.raises(AssertionError, match="Ambiguous integer dim"): + _ = GrassmannTensor((), (), torch.tensor(2333)).reshape((2, 2))