diff --git a/grassmann_tensor/tensor.py b/grassmann_tensor/tensor.py
index 30ec73e..94d6f8c 100644
--- a/grassmann_tensor/tensor.py
+++ b/grassmann_tensor/tensor.py
@@ -169,6 +169,156 @@ def reverse(self, indices: tuple[int, ...]) -> GrassmannTensor:
             _tensor=tensor,
         )
 
+    def _reorder_indices(self, edges: tuple[tuple[int, int], ...]) -> tuple[int, int, torch.Tensor, torch.Tensor]:
+        parity = functools.reduce(
+            torch.logical_xor,
+            (self._unsqueeze(self._edge_mask(even, odd), index, len(edges)) for index, (even, odd) in enumerate(edges)),
+            torch.zeros([], dtype=torch.bool, device=self.tensor.device),
+        )
+        flatten_parity = parity.flatten()
+        even = (~flatten_parity).nonzero().squeeze()
+        odd = flatten_parity.nonzero().squeeze()
+        reorder = torch.cat([even, odd], dim=0)
+
+        total = functools.reduce(
+            torch.add,
+            (self._unsqueeze(self._edge_mask(even, odd), index, len(edges)).to(dtype=torch.int16) for index, (even, odd) in enumerate(edges)),
+            torch.zeros([], dtype=torch.int16, device=self.tensor.device),
+        )
+        count = total * (total - 1)
+        sign = (count & 2).to(dtype=torch.bool)
+        return len(even), len(odd), reorder, sign.flatten()
+
+    def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTensor:
+        """
+        Reshape the Grassmann tensor, which may split or merge edges.
+
+        The new shape must be compatible with the original shape.
+        This operation does not change the arrow, and it cannot merge edges with different arrows.
+
+        The new shape should be a tuple with one entry per new dimension, each given either as a single integer or as a pair of integers.
+        When a dimension is unchanged, the user can pass -1 to indicate that it remains the same.
+        When dimensions are merged, the user only needs to pass a single integer giving the merged dimension size.
+        When a dimension is split, the user must pass several pairs of integers (even, odd) giving the new even and odd parts.
+
+        A single sign is generated when merging or splitting edges, and it should be applied to exactly one of the two connected tensors.
+        This package always applies it to the tensor whose arrow is True.
+        """
+        # This function reshapes the Grassmann tensor according to the new shape, in the following steps:
+        # 1. Generate the new arrow, edges, and shape for the tensor
+        # 2. Reorder the indices for splitting
+        # 3. Apply the sign for splitting
+        # 4. Reshape the core tensor according to the new shape
+        # 5. Apply the sign for merging
+        # 6. Reorder the indices for merging
+
+        # pylint: disable=too-many-branches, too-many-locals, too-many-statements
+
+        arrow: list[bool] = []
+        edges: list[tuple[int, int]] = []
+        shape: list[int] = []
+
+        splitting_sign: list[tuple[int, torch.Tensor]] = []
+        splitting_reorder: list[tuple[int, torch.Tensor]] = []
+        merging_reorder: list[tuple[int, torch.Tensor]] = []
+        merging_sign: list[tuple[int, torch.Tensor]] = []
+
+        cursor_plan: int = 0
+        cursor_self: int = 0
+        while True:
+            if new_shape[cursor_plan] == -1:
+                # Does not change
+                arrow.append(self.arrow[cursor_self])
+                edges.append(self.edges[cursor_self])
+                shape.append(self.tensor.shape[cursor_self])
+                cursor_self += 1
+                cursor_plan += 1
+            else:
+                cursor_new_shape = new_shape[cursor_plan]
+                total = cursor_new_shape if isinstance(cursor_new_shape, int) else cursor_new_shape[0] + cursor_new_shape[1]
+                if total >= self.tensor.shape[cursor_self]:
+                    # Merging
+                    new_cursor_self = cursor_self
+                    self_total = 1
+                    while True:
+                        self_total *= self.tensor.shape[new_cursor_self]
+                        new_cursor_self += 1
+                        if self_total == total:
+                            break
+                        assert self_total < total, f"Dimension mismatch with edges {self.edges} and new shape {new_shape}."
+                        assert new_cursor_self < self.tensor.dim(), f"New shape {new_shape} exceeds tensor dimensions {self.tensor.dim()}."
+                    even, odd, reorder, sign = self._reorder_indices(self.edges[cursor_self:new_cursor_self])
+                    if isinstance(cursor_new_shape, tuple):
+                        assert (even, odd) == cursor_new_shape, f"New even and odd number mismatch during merging {self.edges} to {new_shape}."
+                    arrow.append(self.arrow[cursor_self])
+                    assert all(
+                        self_arrow == arrow[-1] for self_arrow in self.arrow[cursor_self:new_cursor_self]), f"Cannot merge edges with different arrows {self.arrow[cursor_self:new_cursor_self]}."
+                    edges.append((even, odd))
+                    shape.append(total)
+                    if cursor_self + 1 != new_cursor_self:
+                        # Really something merged
+                        merging_sign.append((cursor_plan, sign))
+                        merging_reorder.append((cursor_plan, reorder))
+                    cursor_self = new_cursor_self
+                    cursor_plan += 1
+                else:
+                    # Splitting
+                    new_cursor_plan = cursor_plan
+                    plan_total = 1
+                    while True:
+                        new_cursor_new_shape = new_shape[new_cursor_plan]
+                        assert isinstance(new_cursor_new_shape, tuple), f"New shape must be a pair when splitting, got {new_cursor_new_shape}."
+                        plan_total *= new_cursor_new_shape[0] + new_cursor_new_shape[1]
+                        new_cursor_plan += 1
+                        if plan_total == self.tensor.shape[cursor_self]:
+                            break
+                        assert plan_total < self.tensor.shape[cursor_self], f"Dimension mismatch with edges {self.edges} and new shape {new_shape}."
+                        assert new_cursor_plan < len(new_shape), f"New shape {new_shape} exceeds specified dimensions {len(new_shape)}."
+                    # new_shape has been verified to be tuple[int, int] in the loop
+                    even, odd, reorder, sign = self._reorder_indices(typing.cast(tuple[tuple[int, int], ...], new_shape[cursor_plan:new_cursor_plan]))
+                    assert (even, odd) == self.edges[cursor_self], f"New even and odd number mismatch during splitting {self.edges[cursor_self]} to {new_shape[cursor_plan:new_cursor_plan]}."
+                    for i in range(cursor_plan, new_cursor_plan):
+                        # new_shape has been verified to be tuple[int, int] in the loop
+                        new_cursor_new_shape = typing.cast(tuple[int, int], new_shape[i])
+                        arrow.append(self.arrow[cursor_self])
+                        edges.append(new_cursor_new_shape)
+                        shape.append(new_cursor_new_shape[0] + new_cursor_new_shape[1])
+                    splitting_reorder.append((cursor_self, reorder))
+                    splitting_sign.append((cursor_self, sign))
+                    cursor_self += 1
+                    cursor_plan = new_cursor_plan
+
+            if cursor_plan == len(new_shape) and cursor_self == self.tensor.dim():
+                break
+
+        tensor = self.tensor
+
+        for index, reorder in splitting_reorder:
+            inverse_reorder = torch.empty_like(reorder)
+            inverse_reorder[reorder] = torch.arange(reorder.size(0), device=reorder.device)
+            tensor = tensor.index_select(index, inverse_reorder)
+
+        splitting_parity = functools.reduce(
+            torch.logical_xor,
+            (self._unsqueeze(sign, index, self.tensor.dim()) for index, sign in splitting_sign if self.arrow[index]),
+            torch.zeros([], dtype=torch.bool, device=self.tensor.device),
+        )
+        tensor = torch.where(splitting_parity, -tensor, +tensor)
+
+        tensor = tensor.reshape(shape)
+
+        merging_parity = functools.reduce(
+            torch.logical_xor,
+            (self._unsqueeze(sign, index, tensor.dim()) for index, sign in merging_sign if arrow[index]),
+            torch.zeros([], dtype=torch.bool, device=self.tensor.device),
+        )
+        tensor = torch.where(merging_parity, -tensor, +tensor)
+
+        for index, reorder in merging_reorder:
+            tensor = tensor.index_select(index, reorder)
+
+        return GrassmannTensor(_arrow=tuple(arrow), _edges=tuple(edges), _tensor=tensor)
+
     def __post_init__(self) -> None:
         assert len(self._arrow) == self._tensor.dim(), f"Arrow length ({len(self._arrow)}) must match tensor dimensions ({self._tensor.dim()})."
         assert len(self._edges) == self._tensor.dim(), f"Edges length ({len(self._edges)}) must match tensor dimensions ({self._tensor.dim()})."
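A minimal usage sketch of the new reshape method follows. It is not part of the patch: the constructor call and the grassmann_tensor import path are assumptions inferred from the _arrow/_edges/_tensor fields used in this diff, and the expected (even, odd) counts assume the even sector is stored before the odd sector within each edge.

import torch

from grassmann_tensor import GrassmannTensor  # assumed import path

# A rank-3 tensor whose edges each carry one even and one odd sector.
t = GrassmannTensor(
    _arrow=(True, True, False),
    _edges=((1, 1), (1, 1), (1, 1)),
    _tensor=torch.randn(2, 2, 2),
)

# Merge the first two edges into a single edge of dimension 4; keep the last edge as-is.
merged = t.reshape((4, -1))  # expected edges: ((2, 2), (1, 1))

# Split the merged edge back into two (1, 1) edges.
split = merged.reshape(((1, 1), (1, 1), -1))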
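The sign computed in _reorder_indices, (total * (total - 1)) & 2, is a bit trick for the parity of k*(k-1)/2, which appears to implement the usual (-1)^(k(k-1)/2) Grassmann sign convention for k odd sub-edges. A standalone arithmetic check of the bit trick (plain Python, no package assumptions):

# k*(k-1) is always even, so its bit 1 equals the parity of k*(k-1)/2.
for k in range(16):
    assert bool((k * (k - 1)) & 2) == (k * (k - 1) // 2 % 2 == 1)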
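During splitting, the method applies the inverse of the even/odd reordering permutation by index assignment before calling index_select. The same inversion trick on a toy permutation, as a self-contained sketch:

import torch

reorder = torch.tensor([0, 3, 1, 2])  # e.g. even indices first, then odd indices
inverse = torch.empty_like(reorder)
inverse[reorder] = torch.arange(reorder.size(0))
# Composing the permutation with its inverse restores the identity order.
assert torch.equal(reorder[inverse], torch.arange(4))
assert torch.equal(inverse[reorder], torch.arange(4))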