-
Notifications
You must be signed in to change notification settings - Fork 0
Add reshape function. #39
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
85dff0f
bf9fb27
1673994
c3e6fb3
d15b150
ac793e7
e14d550
663679f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -169,6 +169,156 @@ def reverse(self, indices: tuple[int, ...]) -> GrassmannTensor: | |||||||||||||||||||||||||||||
| _tensor=tensor, | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def _reorder_indices(self, edges: tuple[tuple[int, int], ...]) -> tuple[int, int, torch.Tensor, torch.Tensor]: | ||||||||||||||||||||||||||||||
| parity = functools.reduce( | ||||||||||||||||||||||||||||||
| torch.logical_xor, | ||||||||||||||||||||||||||||||
| (self._unsqueeze(self._edge_mask(even, odd), index, len(edges)) for index, (even, odd) in enumerate(edges)), | ||||||||||||||||||||||||||||||
| torch.zeros([], dtype=torch.bool, device=self.tensor.device), | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| flatten_parity = parity.flatten() | ||||||||||||||||||||||||||||||
| even = (~flatten_parity).nonzero().squeeze() | ||||||||||||||||||||||||||||||
| odd = flatten_parity.nonzero().squeeze() | ||||||||||||||||||||||||||||||
hzhangxyz marked this conversation as resolved.
Show resolved
Hide resolved
hzhangxyz marked this conversation as resolved.
Show resolved
Hide resolved
hzhangxyz marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||
| reorder = torch.cat([even, odd], dim=0) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| total = functools.reduce( | ||||||||||||||||||||||||||||||
| torch.add, | ||||||||||||||||||||||||||||||
| (self._unsqueeze(self._edge_mask(even, odd), index, len(edges)).to(dtype=torch.int16) for index, (even, odd) in enumerate(edges)), | ||||||||||||||||||||||||||||||
| torch.zeros([], dtype=torch.int16, device=self.tensor.device), | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| count = total * (total - 1) | ||||||||||||||||||||||||||||||
| sign = (count & 2).to(dtype=torch.bool) | ||||||||||||||||||||||||||||||
| return len(even), len(odd), reorder, sign.flatten() | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTensor: | ||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||
| Reshape the Grassmann tensor, which may split or merge edges. | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| The new shape must be compatible with the original shape. | ||||||||||||||||||||||||||||||
| This operation does not change the arrow and it cannot merge two edges with different arrows. | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| The new shape should be a tuple of each new dimension, which is represented as either a single integer or a pair of two integers. | ||||||||||||||||||||||||||||||
| When a dimension is not changed, user could pass -1 to indicate that the dimension remains the same. | ||||||||||||||||||||||||||||||
| When a dimension is merged, user only needs to pass a single integer to indicate the new dimension size. | ||||||||||||||||||||||||||||||
| When a dimension is split, user must pass several pairs of two integers (even, odd) to indicate the new even and odd parts. | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| A single sign is generated during merging or splitting two edges, which should be applied to one of the connected two tensors. | ||||||||||||||||||||||||||||||
| This package always applies it to the tensor with arrow as True. | ||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||
| # This function reshapes the Grassmann tensor according to the new shape, including the following steps: | ||||||||||||||||||||||||||||||
| # 1. Generate new arrow, edges, and shape for tensor | ||||||||||||||||||||||||||||||
| # 2. Reorder the indices for splitting | ||||||||||||||||||||||||||||||
| # 3. Apply the sign for splitting | ||||||||||||||||||||||||||||||
| # 4. reshape the core tensor according to the new shape | ||||||||||||||||||||||||||||||
| # 5. Apply the sign for merging | ||||||||||||||||||||||||||||||
| # 6. Reorder the indices for merging | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # pylint: disable=too-many-branches, too-many-locals, too-many-statements | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| arrow: list[bool] = [] | ||||||||||||||||||||||||||||||
| edges: list[tuple[int, int]] = [] | ||||||||||||||||||||||||||||||
| shape: list[int] = [] | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| splitting_sign: list[tuple[int, torch.Tensor]] = [] | ||||||||||||||||||||||||||||||
| splitting_reorder: list[tuple[int, torch.Tensor]] = [] | ||||||||||||||||||||||||||||||
| merging_reorder: list[tuple[int, torch.Tensor]] = [] | ||||||||||||||||||||||||||||||
| merging_sign: list[tuple[int, torch.Tensor]] = [] | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| cursor_plan: int = 0 | ||||||||||||||||||||||||||||||
| cursor_self: int = 0 | ||||||||||||||||||||||||||||||
| while True: | ||||||||||||||||||||||||||||||
hzhangxyz marked this conversation as resolved.
Show resolved
Hide resolved
hzhangxyz marked this conversation as resolved.
Show resolved
Hide resolved
hzhangxyz marked this conversation as resolved.
Show resolved
Hide resolved
hzhangxyz marked this conversation as resolved.
Show resolved
Hide resolved
hzhangxyz marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||
| if new_shape[cursor_plan] == -1: | ||||||||||||||||||||||||||||||
hzhangxyz marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||
| # Does not change | ||||||||||||||||||||||||||||||
| arrow.append(self.arrow[cursor_self]) | ||||||||||||||||||||||||||||||
| edges.append(self.edges[cursor_self]) | ||||||||||||||||||||||||||||||
| shape.append(self.tensor.shape[cursor_self]) | ||||||||||||||||||||||||||||||
| cursor_self += 1 | ||||||||||||||||||||||||||||||
| cursor_plan += 1 | ||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||
| cursor_new_shape = new_shape[cursor_plan] | ||||||||||||||||||||||||||||||
| total = cursor_new_shape if isinstance(cursor_new_shape, int) else cursor_new_shape[0] + cursor_new_shape[1] | ||||||||||||||||||||||||||||||
| if total >= self.tensor.shape[cursor_self]: | ||||||||||||||||||||||||||||||
| # Merging | ||||||||||||||||||||||||||||||
| new_cursor_self = cursor_self | ||||||||||||||||||||||||||||||
| self_total = 1 | ||||||||||||||||||||||||||||||
| while True: | ||||||||||||||||||||||||||||||
| self_total *= self.tensor.shape[new_cursor_self] | ||||||||||||||||||||||||||||||
| new_cursor_self += 1 | ||||||||||||||||||||||||||||||
| if self_total == total: | ||||||||||||||||||||||||||||||
| break | ||||||||||||||||||||||||||||||
| assert self_total < total, f"Dimension mismatch with edges {self.edges} and new shape {new_shape}." | ||||||||||||||||||||||||||||||
hzhangxyz marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||
| assert new_cursor_self < self.tensor.dim(), f"New shape {new_shape} exceeds tensor dimensions {self.tensor.dim()}." | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
| assert new_cursor_self < self.tensor.dim(), f"New shape {new_shape} exceeds tensor dimensions {self.tensor.dim()}." | |
| self_total = 1 | |
| max_merge_iters = self.tensor.dim() - cursor_self | |
| merge_iter_count = 0 | |
| while True: | |
| self_total *= self.tensor.shape[new_cursor_self] | |
| new_cursor_self += 1 | |
| merge_iter_count += 1 | |
| if self_total == total: | |
| break | |
| assert self_total < total, f"Dimension mismatch with edges {self.edges} and new shape {new_shape}." | |
| assert new_cursor_self < self.tensor.dim(), f"New shape {new_shape} exceeds tensor dimensions {self.tensor.dim()}." | |
| if merge_iter_count > max_merge_iters: | |
| raise RuntimeError(f"Exceeded maximum merge iterations ({max_merge_iters}) in merging loop. Possible infinite loop with edges {self.edges} and new shape {new_shape}.") |
Copilot
AI
Aug 9, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This nested while loop also lacks proper bounds checking and could run indefinitely if the termination conditions aren't met properly.
| while True: | |
| while plan_total < self.tensor.shape[cursor_self] and new_cursor_plan < len(new_shape): |
Copilot
AI
Aug 9, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The loop termination condition is placed after potential IndexError scenarios. If either cursor exceeds bounds during the loop iterations, the function will crash before reaching this check.
| if cursor_plan == len(new_shape) and cursor_self == self.tensor.dim(): |
Copilot
AI
Aug 9, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The condition self.arrow[index] uses the original tensor's arrow, but index refers to positions in the original tensor during splitting. After tensor reshaping, the dimensions may not correspond correctly.
| (self._unsqueeze(sign, index, self.tensor.dim()) for index, sign in splitting_sign if self.arrow[index]), | |
| (self._unsqueeze(sign, index, self.tensor.dim()) for index, sign in splitting_sign if arrow[index]), |
Uh oh!
There was an error while loading. Please reload this page.