From 85dff0f871f2b96bbaa8116ad372807c3296f75a Mon Sep 17 00:00:00 2001 From: Hao Zhang Date: Fri, 8 Aug 2025 04:03:25 +0800 Subject: [PATCH 1/8] Add reshape function. --- grassmann_tensor/tensor.py | 143 +++++++++++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) diff --git a/grassmann_tensor/tensor.py b/grassmann_tensor/tensor.py index 30ec73e..31aebc9 100644 --- a/grassmann_tensor/tensor.py +++ b/grassmann_tensor/tensor.py @@ -169,6 +169,149 @@ def reverse(self, indices: tuple[int, ...]) -> GrassmannTensor: _tensor=tensor, ) + def _reorder_indices(self, edges: tuple[tuple[int, int], ...]) -> tuple[int, int, torch.Tensor, torch.Tensor]: + parity = functools.reduce( + torch.logical_xor, + (self._unsqueeze(self._edge_mask(even, odd), index, len(edges)) for index, [even, odd] in enumerate(edges)), + torch.zeros([], dtype=torch.bool, device=self.tensor.device), + ) + flatten_parity = parity.flatten() + even = (~flatten_parity).nonzero().squeeze() + odd = flatten_parity.nonzero().squeeze() + reorder = torch.cat([even, odd], dim=0) + + total = functools.reduce( + torch.add, + (self._unsqueeze(self._edge_mask(even, odd), index, len(edges)).to(dtype=torch.int16) for index, [even, odd] in enumerate(edges)), + torch.zeros([], dtype=torch.int16, device=self.tensor.device), + ) + count = total * (total - 1) + sign = (count & 2).to(dtype=torch.bool) + return len(even), len(odd), reorder, sign.flatten() + + def shape(self, index: int) -> int: + """ + Get the shape of the Grassmann tensor at a specific index. + """ + assert 0 <= index < self.tensor.dim(), f"Index {index} out of bounds for tensor with {self.tensor.dim()} dimensions." + return self.edges[index][0] + self.edges[index][1] + + def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTensor: + """ + Reshape the Grassmann tensor, which may split or merge edges. + + The new shape must be compatible with the original shape. + This operation does not change the arrow and it cannot merge two edges with different arrows. + + The new shape should be a tuple of each new dimension, which is represented as either a single integer or a pair of two integers. + When a dimension is not changed, user could pass -1 to indicate that the dimension remains the same. + When a dimension is merged, user only needs to pass a single integer to indicate the new dimension size. + When a dimension is split, user must pass several pairs of two integers (even, odd) to indicate the new even and odd parts. + + A single sign is generated during merging or splitting two edges, which should be applied to one of the connected two tensors. + This package always applies it to the tensor with arrow as True. + """ + # This function reshapes the Grassmann tensor according to the new shape, including the following steps: + # 1. Generate new arrow, edges, and shape for tensor + # 2. Reorder the indices for splitting + # 3. Apply the sign for splitting + # 4. reshape the core tensor according to the new shape + # 5. Reorder the indices for merging + # 6. Apply the sign for merging + + tensor = self.tensor + arrow = [] + edges = [] + shape = [] + + splitting_sign = [] + splitting_reorder = [] + merging_reorder = [] + merging_sign = [] + + cursor_plan = 0 + cursor_self = 0 + while True: + if new_shape[cursor_plan] == -1: + # Does not change + arrow.append(self.arrow[cursor_self]) + edges.append(self.edges[cursor_self]) + shape.append(self.shape(cursor_self)) + cursor_self += 1 + cursor_plan += 1 + else: + total = new_shape[cursor_plan] if isinstance(new_shape[cursor_plan], int) else new_shape[cursor_plan][0] + new_shape[cursor_plan][1] + if total >= self.shape(cursor_self): + # Merging + new_cursor_self = cursor_self + self_total = 1 + while True: + self_total *= self.shape(new_cursor_self) + new_cursor_self += 1 + if self_total == total: + break + assert self_total < total, f"Dimension mismatch with edges {self.edges} and new shape {new_shape}." + even, odd, reorder, sign = self._reorder_indices(self.edges[cursor_self:new_cursor_self]) + if isinstance(new_shape[cursor_plan], tuple): + assert new_shape[cursor_plan][0] == even and new_shape[cursor_plan][1] == odd, f"New even and odd number dismatch during merging {self.edegs} to {new_shape}." + arrow.append(self.arrow[cursor_self]) + assert all( + self_arrow == arrow[-1] for self_arrow in self.arrow[cursor_self:new_cursor_self]), f"Cannot merge edges with different arrows {self.arrow[cursor_self:new_cursor_self]}." + edges.append((even, odd)) + shape.append(total) + if cursor_self + 1 != new_cursor_self: + # Really something merged + merging_sign.append((cursor_plan, sign)) + merging_reorder.append((cursor_plan, reorder)) + cursor_self = new_cursor_self + cursor_plan += 1 + else: + # Splitting + new_cursor_plan = cursor_plan + plan_total = 1 + while True: + assert isinstance(new_shape[new_cursor_plan], tuple), f"New shape must be a pair when splitting, got {new_shape[new_cursor_plan]}." + plan_total *= new_shape[new_cursor_plan][0] + new_shape[new_cursor_plan][1] + new_cursor_plan += 1 + if plan_total == self.shape(cursor_self): + break + assert plan_total < self.shape(cursor_self), f"Dimension mismatch with edges {self.edges} and new shape {new_shape}." + even, odd, reorder, sign = self._reorder_indices(new_shape[cursor_plan:new_cursor_plan]) + assert (even, odd) == self.edges[cursor_self], f"New even and odd number dismatch during splitting {self.edges[cursor_self]} to {new_shape[cursor_plan:new_cursor_plan]}." + for i in range(cursor_plan, new_cursor_plan): + arrow.append(self.arrow[cursor_self]) + edges.append(new_shape[i]) + shape.append(new_shape[i][0] + new_shape[i][1]) + splitting_reorder.append((cursor_self, reorder)) + splitting_sign.append((cursor_self, sign)) + cursor_self += 1 + cursor_plan = new_cursor_plan + + if cursor_plan == len(new_shape): + break + + splitting_parity = functools.reduce( + torch.logical_xor, + (self._unsqueeze(sign, index, self.tensor.dim()) for index, sign in splitting_sign if self.arrow[index]), + torch.zeros([], dtype=torch.bool, device=self.tensor.device), + ) + tensor = torch.where(splitting_parity, -tensor, +tensor) + for index, reorder in splitting_reorder: + tensor = tensor.index_select(index, reorder) + + tensor = tensor.reshape(shape) + + for index, reorder in merging_reorder: + tensor = tensor.index_select(index, reorder) + merging_parity = functools.reduce( + torch.logical_xor, + (self._unsqueeze(sign, index, tensor.dim()) for index, sign in merging_sign if arrow[index]), + torch.zeros([], dtype=torch.bool, device=self.tensor.device), + ) + tensor = torch.where(merging_parity, -tensor, +tensor) + + return GrassmannTensor(_arrow=tuple(arrow), _edges=tuple(edges), _tensor=tensor) + def __post_init__(self) -> None: assert len(self._arrow) == self._tensor.dim(), f"Arrow length ({len(self._arrow)}) must match tensor dimensions ({self._tensor.dim()})." assert len(self._edges) == self._tensor.dim(), f"Edges length ({len(self._edges)}) must match tensor dimensions ({self._tensor.dim()})." From bf9fb274f86c0d3e640db3a009aa9ab42e2312ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hao=20Zhang=28=E5=BC=A0=E6=B5=A9=29?= Date: Fri, 8 Aug 2025 15:19:13 +0800 Subject: [PATCH 2/8] Update grassmann_tensor/tensor.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- grassmann_tensor/tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grassmann_tensor/tensor.py b/grassmann_tensor/tensor.py index 31aebc9..3f6e0df 100644 --- a/grassmann_tensor/tensor.py +++ b/grassmann_tensor/tensor.py @@ -253,7 +253,7 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens assert self_total < total, f"Dimension mismatch with edges {self.edges} and new shape {new_shape}." even, odd, reorder, sign = self._reorder_indices(self.edges[cursor_self:new_cursor_self]) if isinstance(new_shape[cursor_plan], tuple): - assert new_shape[cursor_plan][0] == even and new_shape[cursor_plan][1] == odd, f"New even and odd number dismatch during merging {self.edegs} to {new_shape}." + assert new_shape[cursor_plan][0] == even and new_shape[cursor_plan][1] == odd, f"New even and odd number mismatch during merging {self.edges} to {new_shape}." arrow.append(self.arrow[cursor_self]) assert all( self_arrow == arrow[-1] for self_arrow in self.arrow[cursor_self:new_cursor_self]), f"Cannot merge edges with different arrows {self.arrow[cursor_self:new_cursor_self]}." From 16739945bd088aa1be85a35f5f5f7f90c7e6ff14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hao=20Zhang=28=E5=BC=A0=E6=B5=A9=29?= Date: Fri, 8 Aug 2025 15:19:26 +0800 Subject: [PATCH 3/8] Update grassmann_tensor/tensor.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- grassmann_tensor/tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grassmann_tensor/tensor.py b/grassmann_tensor/tensor.py index 3f6e0df..ac46241 100644 --- a/grassmann_tensor/tensor.py +++ b/grassmann_tensor/tensor.py @@ -277,7 +277,7 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens break assert plan_total < self.shape(cursor_self), f"Dimension mismatch with edges {self.edges} and new shape {new_shape}." even, odd, reorder, sign = self._reorder_indices(new_shape[cursor_plan:new_cursor_plan]) - assert (even, odd) == self.edges[cursor_self], f"New even and odd number dismatch during splitting {self.edges[cursor_self]} to {new_shape[cursor_plan:new_cursor_plan]}." + assert (even, odd) == self.edges[cursor_self], f"New even and odd number mismatch during splitting {self.edges[cursor_self]} to {new_shape[cursor_plan:new_cursor_plan]}." for i in range(cursor_plan, new_cursor_plan): arrow.append(self.arrow[cursor_self]) edges.append(new_shape[i]) From c3e6fb3133eba8aa085cac11c4ac8c6c74a999bc Mon Sep 17 00:00:00 2001 From: Hao Zhang Date: Sat, 9 Aug 2025 12:23:48 +0800 Subject: [PATCH 4/8] Use self.tensor.shape isntead of individual function shape. --- grassmann_tensor/tensor.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/grassmann_tensor/tensor.py b/grassmann_tensor/tensor.py index ac46241..8831528 100644 --- a/grassmann_tensor/tensor.py +++ b/grassmann_tensor/tensor.py @@ -189,13 +189,6 @@ def _reorder_indices(self, edges: tuple[tuple[int, int], ...]) -> tuple[int, int sign = (count & 2).to(dtype=torch.bool) return len(even), len(odd), reorder, sign.flatten() - def shape(self, index: int) -> int: - """ - Get the shape of the Grassmann tensor at a specific index. - """ - assert 0 <= index < self.tensor.dim(), f"Index {index} out of bounds for tensor with {self.tensor.dim()} dimensions." - return self.edges[index][0] + self.edges[index][1] - def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTensor: """ Reshape the Grassmann tensor, which may split or merge edges. @@ -236,17 +229,17 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens # Does not change arrow.append(self.arrow[cursor_self]) edges.append(self.edges[cursor_self]) - shape.append(self.shape(cursor_self)) + shape.append(self.tensor.shape[cursor_self]) cursor_self += 1 cursor_plan += 1 else: total = new_shape[cursor_plan] if isinstance(new_shape[cursor_plan], int) else new_shape[cursor_plan][0] + new_shape[cursor_plan][1] - if total >= self.shape(cursor_self): + if total >= self.tensor.shape[cursor_self]: # Merging new_cursor_self = cursor_self self_total = 1 while True: - self_total *= self.shape(new_cursor_self) + self_total *= self.tensor.shape[new_cursor_self] new_cursor_self += 1 if self_total == total: break @@ -273,9 +266,9 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens assert isinstance(new_shape[new_cursor_plan], tuple), f"New shape must be a pair when splitting, got {new_shape[new_cursor_plan]}." plan_total *= new_shape[new_cursor_plan][0] + new_shape[new_cursor_plan][1] new_cursor_plan += 1 - if plan_total == self.shape(cursor_self): + if plan_total == self.tensor.shape[cursor_self]: break - assert plan_total < self.shape(cursor_self), f"Dimension mismatch with edges {self.edges} and new shape {new_shape}." + assert plan_total < self.tensor.shape[cursor_self], f"Dimension mismatch with edges {self.edges} and new shape {new_shape}." even, odd, reorder, sign = self._reorder_indices(new_shape[cursor_plan:new_cursor_plan]) assert (even, odd) == self.edges[cursor_self], f"New even and odd number mismatch during splitting {self.edges[cursor_self]} to {new_shape[cursor_plan:new_cursor_plan]}." for i in range(cursor_plan, new_cursor_plan): From d15b150f4a1699f3b64b2ea0322e0744f7ef3ed7 Mon Sep 17 00:00:00 2001 From: Hao Zhang Date: Sat, 9 Aug 2025 12:49:36 +0800 Subject: [PATCH 5/8] Add some assertion and type annotation in reshape. --- grassmann_tensor/tensor.py | 47 +++++++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/grassmann_tensor/tensor.py b/grassmann_tensor/tensor.py index 8831528..555d117 100644 --- a/grassmann_tensor/tensor.py +++ b/grassmann_tensor/tensor.py @@ -212,18 +212,19 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens # 5. Reorder the indices for merging # 6. Apply the sign for merging - tensor = self.tensor - arrow = [] - edges = [] - shape = [] + # pylint: disable=too-many-branches, too-many-locals, too-many-statements + + arrow: list[bool] = [] + edges: list[tuple[int, int]] = [] + shape: list[int] = [] - splitting_sign = [] - splitting_reorder = [] - merging_reorder = [] - merging_sign = [] + splitting_sign: list[tuple[int, torch.Tensor]] = [] + splitting_reorder: list[tuple[int, torch.Tensor]] = [] + merging_reorder: list[tuple[int, torch.Tensor]] = [] + merging_sign: list[tuple[int, torch.Tensor]] = [] - cursor_plan = 0 - cursor_self = 0 + cursor_plan: int = 0 + cursor_self: int = 0 while True: if new_shape[cursor_plan] == -1: # Does not change @@ -233,7 +234,8 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens cursor_self += 1 cursor_plan += 1 else: - total = new_shape[cursor_plan] if isinstance(new_shape[cursor_plan], int) else new_shape[cursor_plan][0] + new_shape[cursor_plan][1] + cursor_new_shape = new_shape[cursor_plan] + total = cursor_new_shape if isinstance(cursor_new_shape, int) else cursor_new_shape[0] + cursor_new_shape[1] if total >= self.tensor.shape[cursor_self]: # Merging new_cursor_self = cursor_self @@ -244,9 +246,10 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens if self_total == total: break assert self_total < total, f"Dimension mismatch with edges {self.edges} and new shape {new_shape}." + assert new_cursor_self < self.tensor.dim(), f"New shape {new_shape} exceeds tensor dimensions {self.tensor.dim()}." even, odd, reorder, sign = self._reorder_indices(self.edges[cursor_self:new_cursor_self]) - if isinstance(new_shape[cursor_plan], tuple): - assert new_shape[cursor_plan][0] == even and new_shape[cursor_plan][1] == odd, f"New even and odd number mismatch during merging {self.edges} to {new_shape}." + if isinstance(cursor_new_shape, tuple): + assert (even, odd) == cursor_new_shape, f"New even and odd number mismatch during merging {self.edges} to {new_shape}." arrow.append(self.arrow[cursor_self]) assert all( self_arrow == arrow[-1] for self_arrow in self.arrow[cursor_self:new_cursor_self]), f"Cannot merge edges with different arrows {self.arrow[cursor_self:new_cursor_self]}." @@ -263,26 +266,32 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens new_cursor_plan = cursor_plan plan_total = 1 while True: - assert isinstance(new_shape[new_cursor_plan], tuple), f"New shape must be a pair when splitting, got {new_shape[new_cursor_plan]}." - plan_total *= new_shape[new_cursor_plan][0] + new_shape[new_cursor_plan][1] + new_cursor_new_shape = new_shape[new_cursor_plan] + assert isinstance(new_cursor_new_shape, tuple), f"New shape must be a pair when splitting, got {new_cursor_new_shape}." + plan_total *= new_cursor_new_shape[0] + new_cursor_new_shape[1] new_cursor_plan += 1 if plan_total == self.tensor.shape[cursor_self]: break assert plan_total < self.tensor.shape[cursor_self], f"Dimension mismatch with edges {self.edges} and new shape {new_shape}." - even, odd, reorder, sign = self._reorder_indices(new_shape[cursor_plan:new_cursor_plan]) + assert new_cursor_plan < len(new_shape), f"New shape {new_shape} exceeds specified dimensions {len(new_shape)}." + # new_shape has been verified to be tuple[int, int] in the loop + even, odd, reorder, sign = self._reorder_indices(typing.cast(tuple[tuple[int, int], ...], new_shape[cursor_plan:new_cursor_plan])) assert (even, odd) == self.edges[cursor_self], f"New even and odd number mismatch during splitting {self.edges[cursor_self]} to {new_shape[cursor_plan:new_cursor_plan]}." for i in range(cursor_plan, new_cursor_plan): + # new_shape has been verified to be tuple[int, int] in the loop + new_cursor_new_shape = typing.cast(tuple[int, int], new_shape[i]) arrow.append(self.arrow[cursor_self]) - edges.append(new_shape[i]) - shape.append(new_shape[i][0] + new_shape[i][1]) + edges.append(new_cursor_new_shape) + shape.append(new_cursor_new_shape[0] + new_cursor_new_shape[1]) splitting_reorder.append((cursor_self, reorder)) splitting_sign.append((cursor_self, sign)) cursor_self += 1 cursor_plan = new_cursor_plan - if cursor_plan == len(new_shape): + if cursor_plan == len(new_shape) and cursor_self == self.tensor.dim(): break + tensor = self.tensor splitting_parity = functools.reduce( torch.logical_xor, (self._unsqueeze(sign, index, self.tensor.dim()) for index, sign in splitting_sign if self.arrow[index]), From ac793e7fe47db3688020fe382dca67f27a11eba2 Mon Sep 17 00:00:00 2001 From: Hao Zhang Date: Sat, 9 Aug 2025 13:31:38 +0800 Subject: [PATCH 6/8] Fix wrong reorder and sign applying order. --- grassmann_tensor/tensor.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/grassmann_tensor/tensor.py b/grassmann_tensor/tensor.py index 555d117..ed42929 100644 --- a/grassmann_tensor/tensor.py +++ b/grassmann_tensor/tensor.py @@ -209,8 +209,8 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens # 2. Reorder the indices for splitting # 3. Apply the sign for splitting # 4. reshape the core tensor according to the new shape - # 5. Reorder the indices for merging - # 6. Apply the sign for merging + # 5. Apply the sign for merging + # 6. Reorder the indices for merging # pylint: disable=too-many-branches, too-many-locals, too-many-statements @@ -292,19 +292,20 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens break tensor = self.tensor + + for index, reorder in splitting_reorder: + inverse_reorder = torch.zeros_like(reorder).scatter_(0, reorder, torch.arange(reorder.size(0), device=reorder.device)) + tensor = tensor.index_select(index, inverse_reorder) + splitting_parity = functools.reduce( torch.logical_xor, (self._unsqueeze(sign, index, self.tensor.dim()) for index, sign in splitting_sign if self.arrow[index]), torch.zeros([], dtype=torch.bool, device=self.tensor.device), ) tensor = torch.where(splitting_parity, -tensor, +tensor) - for index, reorder in splitting_reorder: - tensor = tensor.index_select(index, reorder) tensor = tensor.reshape(shape) - for index, reorder in merging_reorder: - tensor = tensor.index_select(index, reorder) merging_parity = functools.reduce( torch.logical_xor, (self._unsqueeze(sign, index, tensor.dim()) for index, sign in merging_sign if arrow[index]), @@ -312,6 +313,9 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens ) tensor = torch.where(merging_parity, -tensor, +tensor) + for index, reorder in merging_reorder: + tensor = tensor.index_select(index, reorder) + return GrassmannTensor(_arrow=tuple(arrow), _edges=tuple(edges), _tensor=tensor) def __post_init__(self) -> None: From e14d55019021687a7c544951fa0d2f76f7d8a592 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hao=20Zhang=28=E5=BC=A0=E6=B5=A9=29?= Date: Sat, 9 Aug 2025 13:53:34 +0800 Subject: [PATCH 7/8] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- grassmann_tensor/tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/grassmann_tensor/tensor.py b/grassmann_tensor/tensor.py index ed42929..aaa3c40 100644 --- a/grassmann_tensor/tensor.py +++ b/grassmann_tensor/tensor.py @@ -172,7 +172,7 @@ def reverse(self, indices: tuple[int, ...]) -> GrassmannTensor: def _reorder_indices(self, edges: tuple[tuple[int, int], ...]) -> tuple[int, int, torch.Tensor, torch.Tensor]: parity = functools.reduce( torch.logical_xor, - (self._unsqueeze(self._edge_mask(even, odd), index, len(edges)) for index, [even, odd] in enumerate(edges)), + (self._unsqueeze(self._edge_mask(even, odd), index, len(edges)) for index, (even, odd) in enumerate(edges)), torch.zeros([], dtype=torch.bool, device=self.tensor.device), ) flatten_parity = parity.flatten() @@ -182,7 +182,7 @@ def _reorder_indices(self, edges: tuple[tuple[int, int], ...]) -> tuple[int, int total = functools.reduce( torch.add, - (self._unsqueeze(self._edge_mask(even, odd), index, len(edges)).to(dtype=torch.int16) for index, [even, odd] in enumerate(edges)), + (self._unsqueeze(self._edge_mask(even, odd), index, len(edges)).to(dtype=torch.int16) for index, (even, odd) in enumerate(edges)), torch.zeros([], dtype=torch.int16, device=self.tensor.device), ) count = total * (total - 1) From 663679faafff6fdec78f302e9dab83bb60d6a7a7 Mon Sep 17 00:00:00 2001 From: Hao Zhang Date: Sat, 9 Aug 2025 14:11:04 +0800 Subject: [PATCH 8/8] Use assign to be more pythonnic. --- grassmann_tensor/tensor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/grassmann_tensor/tensor.py b/grassmann_tensor/tensor.py index aaa3c40..94d6f8c 100644 --- a/grassmann_tensor/tensor.py +++ b/grassmann_tensor/tensor.py @@ -294,7 +294,8 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens tensor = self.tensor for index, reorder in splitting_reorder: - inverse_reorder = torch.zeros_like(reorder).scatter_(0, reorder, torch.arange(reorder.size(0), device=reorder.device)) + inverse_reorder = torch.empty_like(reorder) + inverse_reorder[reorder] = torch.arange(reorder.size(0), device=reorder.device) tensor = tensor.index_select(index, inverse_reorder) splitting_parity = functools.reduce(