diff --git a/grassmann_tensor/tensor.py b/grassmann_tensor/tensor.py index e4584f8..d53c5ae 100644 --- a/grassmann_tensor/tensor.py +++ b/grassmann_tensor/tensor.py @@ -256,7 +256,12 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens cursor_plan: int = 0 cursor_self: int = 0 while cursor_plan != len(new_shape) or cursor_self != self.tensor.dim(): - if new_shape[cursor_plan] == -1: + if len(new_shape) == 0: + assert all(edge == (0, 1) or edge == (1, 0) for edge in self.edges), ( + f"Edge must be (0, 1) or (1, 0) but got {self.edges}" + ) + cursor_self = self.tensor.dim() - 1 + elif cursor_plan != len(new_shape) and new_shape[cursor_plan] == -1: # Does not change arrow.append(self.arrow[cursor_self]) edges.append(self.edges[cursor_self]) @@ -264,25 +269,31 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens cursor_self += 1 cursor_plan += 1 continue - if new_shape[cursor_plan] == (1, 0): - # An trivial plan edge + elif cursor_plan != len(new_shape) and new_shape[cursor_plan] == (1, 0): + # A trivial plan edge arrow.append(False) edges.append((1, 0)) shape.append(1) cursor_plan += 1 continue - if self.edges[cursor_self] == (1, 0): - # An trivial self edge + elif cursor_self != self.tensor.dim() and self.edges[cursor_self] == (1, 0): + # A trivial self edge cursor_self += 1 continue - cursor_new_shape = new_shape[cursor_plan] - total = ( - cursor_new_shape - if isinstance(cursor_new_shape, int) - else cursor_new_shape[0] + cursor_new_shape[1] - ) + if len(new_shape) == 0: + cursor_new_shape = typing.cast(int | tuple[int, int], tuple()) + total = 1 + else: + cursor_new_shape = new_shape[cursor_plan] + total = ( + cursor_new_shape + if isinstance(cursor_new_shape, int) + else cursor_new_shape[0] + cursor_new_shape[1] + ) # one of total and shape[cursor_self] is not trivial, otherwise it should be handled before - if total == self.tensor.shape[cursor_self]: + if self.tensor.dim() == 0: + merging = False + elif total == self.tensor.shape[cursor_self]: # We do not know whether it is merging or splitting, check more if isinstance(cursor_new_shape, int) or cursor_new_shape == self.edges[cursor_self]: # If the new shape is exactly the same as the current edge, we treat it as no change @@ -296,6 +307,9 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens cursor_self_finding = cursor_self cursor_self_found = False while True: + if len(new_shape) == 0: + cursor_self_found = True + break cursor_self_finding += 1 if cursor_self_finding == self.tensor.dim(): break @@ -306,15 +320,19 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens break break merging = cursor_self_found - if total > self.tensor.shape[cursor_self]: + elif total > self.tensor.shape[cursor_self]: merging = True - if total < self.tensor.shape[cursor_self]: + elif total < self.tensor.shape[cursor_self]: merging = False if merging: # Merging between [cursor_self, new_cursor_self) and the another side contains dimension as self_total new_cursor_self = cursor_self self_total = 1 while True: + if len(new_shape) == 0: + new_cursor_self += 1 + even, odd, reorder, sign = self._reorder_indices(self.edges) + break # Try to include more dimension from self self_total *= self.tensor.shape[new_cursor_self] new_cursor_self += 1 @@ -336,19 +354,26 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens f"New shape exceeds in merging with edges {self.edges} and new shape {new_shape}." ) # The merging block [cursor_self, new_cursor_self) has been determined - arrow.append(self.arrow[cursor_self]) - assert all( - self_arrow == arrow[-1] - for self_arrow in self.arrow[cursor_self:new_cursor_self] - ), ( - f"Cannot merge edges with different arrows {self.arrow[cursor_self:new_cursor_self]}." - ) - edges.append((even, odd)) - shape.append(total) - merging_sign.append((cursor_plan, sign)) - merging_reorder.append((cursor_plan, reorder)) - cursor_self = new_cursor_self - cursor_plan += 1 + if len(new_shape) == 0: + arrow = [] + edges = [] + shape = [] + merging_sign.append((cursor_plan, sign)) + cursor_self = new_cursor_self + else: + arrow.append(self.arrow[cursor_self]) + assert all( + self_arrow == arrow[-1] + for self_arrow in self.arrow[cursor_self:new_cursor_self] + ), ( + f"Cannot merge edges with different arrows {self.arrow[cursor_self:new_cursor_self]}." + ) + edges.append((even, odd)) + shape.append(total) + merging_sign.append((cursor_plan, sign)) + merging_reorder.append((cursor_plan, reorder)) + cursor_self = new_cursor_self + cursor_plan += 1 else: # Splitting between [cursor_plan, new_cursor_plan) and the another side contains dimension as plan_total new_cursor_plan = cursor_plan @@ -362,15 +387,23 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens plan_total *= new_cursor_new_shape[0] + new_cursor_new_shape[1] new_cursor_plan += 1 # One dimension included, check if we can stop - if plan_total == self.tensor.shape[cursor_self]: - # new_shape block has been verified to be always tuple[int, int] before + if self.tensor.dim() == 0: even, odd, reorder, sign = self._reorder_indices( - typing.cast( - tuple[tuple[int, int], ...], new_shape[cursor_plan:new_cursor_plan] - ) + typing.cast(tuple[tuple[int, int], ...], new_shape) ) - if (even, odd) == self.edges[cursor_self]: - break + new_cursor_plan = len(new_shape) + break + else: + if plan_total == self.tensor.shape[cursor_self]: + # new_shape block has been verified to be always tuple[int, int] before + even, odd, reorder, sign = self._reorder_indices( + typing.cast( + tuple[tuple[int, int], ...], + new_shape[cursor_plan:new_cursor_plan], + ) + ) + if (even, odd) == self.edges[cursor_self]: + break # For some reason we cannot stop here, continue to include more dimension, check something before continue assert plan_total <= self.tensor.shape[cursor_self], ( f"Dimension mismatch in splitting with edges {self.edges} and new shape {new_shape}." @@ -382,12 +415,16 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens for i in range(cursor_plan, new_cursor_plan): # new_shape block has been verified to be always tuple[int, int] in the loop new_cursor_new_shape = typing.cast(tuple[int, int], new_shape[i]) - arrow.append(self.arrow[cursor_self]) + if self.tensor.dim() == 0: + arrow.append(False) + else: + arrow.append(self.arrow[cursor_self]) edges.append(new_cursor_new_shape) shape.append(new_cursor_new_shape[0] + new_cursor_new_shape[1]) splitting_reorder.append((cursor_self, reorder)) splitting_sign.append((cursor_self, sign)) - cursor_self += 1 + if self.tensor.dim() != 0: + cursor_self += 1 cursor_plan = new_cursor_plan tensor = self.tensor @@ -402,7 +439,7 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens ( self._unsqueeze(sign, index, self.tensor.dim()) for index, sign in splitting_sign - if self.arrow[index] + if self.tensor.dim() != 0 and self.arrow[index] ), torch.zeros([], dtype=torch.bool, device=self.tensor.device), ) @@ -410,6 +447,9 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens tensor = tensor.reshape(shape) + if len(new_shape) == 0: + return GrassmannTensor(_arrow=tuple(arrow), _edges=tuple(edges), _tensor=tensor) + merging_parity = functools.reduce( torch.logical_xor, ( diff --git a/tests/reshape_test.py b/tests/reshape_test.py index 52c49a9..2e1b9bc 100644 --- a/tests/reshape_test.py +++ b/tests/reshape_test.py @@ -63,7 +63,7 @@ def test_reshape_trivial_edges(arrow: tuple[bool, ...], plan_range: tuple[int, i assert a.edges == c.edges -def test_reshape_merging_dimension_mismatch_edges_because_of_nonequal() -> None: +def test_reshape_merging_dimension_mismatch_edges_because_of_unequal() -> None: arrow = (True, True, True) edges = ((2, 2), (8, 8), (2, 2)) a = GrassmannTensor(arrow, edges, torch.randn([4, 16, 4])) @@ -113,7 +113,7 @@ def test_reshape_splitting_shape_type() -> None: _ = a.reshape((2, (2, 2))) -def test_reshape_splitting_dimension_mismatch_edges_because_of_nonequal() -> None: +def test_reshape_splitting_dimension_mismatch_edges_because_of_unequal() -> None: arrow = (True,) edges = ((8, 8),) a = GrassmannTensor(arrow, edges, torch.randn([16])) @@ -175,3 +175,10 @@ def test_reshape_equal_edges_nontrivial_merging_with_other_edge() -> None: edges = ((1, 3), (1, 0), (0, 1), (2, 2)) a = GrassmannTensor(arrow, edges, torch.randn([4, 1, 1, 4])) _ = a.reshape(((3, 1), (2, 2))) + + +def test_reshape_with_none() -> None: + a = GrassmannTensor((), (), torch.tensor(2333)).reshape(((1, 0), (1, 0))).reshape(()) + assert len(a.arrow) == 0 and len(a.edges) == 0 and a.tensor.dim() == 0 + b = GrassmannTensor((), (), torch.tensor(2333)).reshape(((0, 1), (0, 1))).reshape(()) + assert len(b.arrow) == 0 and len(b.edges) == 0 and b.tensor.dim() == 0