Fix issue when reshaping with no edges. #67
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Changes from all commits
9a34fea
834f661
9a91366
ab9ddd5
771a40b
ce00566
2a2ab4e
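For orientation, here is a minimal usage sketch of the case this PR appears to address: reshaping a GrassmannTensor whose edges are all trivial down to an empty shape. The import path is assumed for illustration only; the constructor keywords follow the ones used in the diff below.

import torch

from grassmann_tensor import GrassmannTensor  # assumed import path, not verified

# Two trivial edges, each of total dimension 1; the new assert in this PR
# ("Edge must be (0, 1) or (1, 0)") only permits such edges when new_shape is empty.
t = GrassmannTensor(
    _arrow=(False, False),
    _edges=((1, 0), (0, 1)),
    _tensor=torch.ones(1, 1),
)
scalar = t.reshape(())  # reshape with no edges: the case this fix is about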
@@ -256,33 +256,44 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTensor:
         cursor_plan: int = 0
         cursor_self: int = 0
         while cursor_plan != len(new_shape) or cursor_self != self.tensor.dim():
-            if new_shape[cursor_plan] == -1:
+            if len(new_shape) == 0:
+                assert all(edge == (0, 1) or edge == (1, 0) for edge in self.edges), (
+                    f"Edge must be (0, 1) or (1, 0) but got {self.edges}"
+                )
+                cursor_self = self.tensor.dim() - 1
+            elif cursor_plan != len(new_shape) and new_shape[cursor_plan] == -1:
                 # Does not change
                 arrow.append(self.arrow[cursor_self])
                 edges.append(self.edges[cursor_self])
                 shape.append(self.tensor.shape[cursor_self])
                 cursor_self += 1
                 cursor_plan += 1
                 continue
-            if new_shape[cursor_plan] == (1, 0):
-                # An trivial plan edge
+            elif cursor_plan != len(new_shape) and new_shape[cursor_plan] == (1, 0):
+                # A trivial plan edge
                 arrow.append(False)
                 edges.append((1, 0))
                 shape.append(1)
                 cursor_plan += 1
                 continue
-            if self.edges[cursor_self] == (1, 0):
-                # An trivial self edge
+            elif cursor_self != self.tensor.dim() and self.edges[cursor_self] == (1, 0):
+                # A trivial self edge
                 cursor_self += 1
                 continue
-            cursor_new_shape = new_shape[cursor_plan]
-            total = (
-                cursor_new_shape
-                if isinstance(cursor_new_shape, int)
-                else cursor_new_shape[0] + cursor_new_shape[1]
-            )
+            if len(new_shape) == 0:
+                cursor_new_shape = typing.cast(int | tuple[int, int], tuple())
+                total = 1
+            else:
+                cursor_new_shape = new_shape[cursor_plan]
+                total = (
+                    cursor_new_shape
+                    if isinstance(cursor_new_shape, int)
+                    else cursor_new_shape[0] + cursor_new_shape[1]
+                )
             # one of total and shape[cursor_self] is not trivial, otherwise it should be handled before
-            if total == self.tensor.shape[cursor_self]:
+            if self.tensor.dim() == 0:
+                merging = False
+            elif total == self.tensor.shape[cursor_self]:
                 # We do not know whether it is merging or splitting, check more
                 if isinstance(cursor_new_shape, int) or cursor_new_shape == self.edges[cursor_self]:
                     # If the new shape is exactly the same as the current edge, we treat it as no change
@@ -296,6 +307,9 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTensor:
                 cursor_self_finding = cursor_self
                 cursor_self_found = False
                 while True:
+                    if len(new_shape) == 0:
+                        cursor_self_found = True
+                        break
                     cursor_self_finding += 1
                     if cursor_self_finding == self.tensor.dim():
                         break
@@ -306,15 +320,19 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTensor:
                         break
                     break
                 merging = cursor_self_found
-            if total > self.tensor.shape[cursor_self]:
+            elif total > self.tensor.shape[cursor_self]:
                 merging = True
-            if total < self.tensor.shape[cursor_self]:
+            elif total < self.tensor.shape[cursor_self]:
                 merging = False
             if merging:
                 # Merging between [cursor_self, new_cursor_self) and the another side contains dimension as self_total
                 new_cursor_self = cursor_self
                 self_total = 1
                 while True:
+                    if len(new_shape) == 0:
+                        new_cursor_self += 1
+                        even, odd, reorder, sign = self._reorder_indices(self.edges)
+                        break
                     # Try to include more dimension from self
                     self_total *= self.tensor.shape[new_cursor_self]
                     new_cursor_self += 1
@@ -336,19 +354,26 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTensor:
                         f"New shape exceeds in merging with edges {self.edges} and new shape {new_shape}."
                     )
                 # The merging block [cursor_self, new_cursor_self) has been determined
-                arrow.append(self.arrow[cursor_self])
-                assert all(
-                    self_arrow == arrow[-1]
-                    for self_arrow in self.arrow[cursor_self:new_cursor_self]
-                ), (
-                    f"Cannot merge edges with different arrows {self.arrow[cursor_self:new_cursor_self]}."
-                )
-                edges.append((even, odd))
-                shape.append(total)
-                merging_sign.append((cursor_plan, sign))
-                merging_reorder.append((cursor_plan, reorder))
-                cursor_self = new_cursor_self
-                cursor_plan += 1
+                if len(new_shape) == 0:
+                    arrow = []
+                    edges = []
+                    shape = []
+                    merging_sign.append((cursor_plan, sign))
+                    cursor_self = new_cursor_self
+                else:
+                    arrow.append(self.arrow[cursor_self])
+                    assert all(
+                        self_arrow == arrow[-1]
+                        for self_arrow in self.arrow[cursor_self:new_cursor_self]
+                    ), (
+                        f"Cannot merge edges with different arrows {self.arrow[cursor_self:new_cursor_self]}."
+                    )
+                    edges.append((even, odd))
+                    shape.append(total)
+                    merging_sign.append((cursor_plan, sign))
+                    merging_reorder.append((cursor_plan, reorder))
+                    cursor_self = new_cursor_self
+                    cursor_plan += 1
             else:
                 # Splitting between [cursor_plan, new_cursor_plan) and the another side contains dimension as plan_total
                 new_cursor_plan = cursor_plan
@@ -362,15 +387,23 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTensor:
                     plan_total *= new_cursor_new_shape[0] + new_cursor_new_shape[1]
                     new_cursor_plan += 1
                     # One dimension included, check if we can stop
-                    if plan_total == self.tensor.shape[cursor_self]:
-                        # new_shape block has been verified to be always tuple[int, int] before
+                    if self.tensor.dim() == 0:
Review comments on the line above:
- You check self.tensor.dim() in many places throughout the whole flow; it would be better to hoist the check outside. Re-checking it here and there strains readability...
- If, after thinking it through, you have confirmed that only the dim = 0 cases can actually trigger the problem, just do the check at the outermost level.
- But the merge and split logic for a zero-dimensional tensor is basically the same as the general case; only these special situations need extra handling. Doing the check at the outermost level would mean copying the same logic outside, which could make this function even longer.
- OK, the current version 2a2ab4e reads well enough.
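As an illustration of the trade-off discussed in this thread, here is a minimal sketch (purely hypothetical, not code from this repository, with an invented helper name _reshape_zero_dim) of what hoisting the zero-dimensional handling to the top of reshape could look like:

# Illustrative sketch only; assumes a hypothetical helper, not the PR's actual approach.
def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> "GrassmannTensor":
    if len(new_shape) == 0 or self.tensor.dim() == 0:
        # All zero-dimensional special cases would live in one place here.
        return self._reshape_zero_dim(new_shape)  # hypothetical helper
    ...  # the regular merge/split loop, free of dim() == 0 checks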
                         even, odd, reorder, sign = self._reorder_indices(
-                            typing.cast(
-                                tuple[tuple[int, int], ...], new_shape[cursor_plan:new_cursor_plan]
-                            )
+                            typing.cast(tuple[tuple[int, int], ...], new_shape)
                         )
-                        if (even, odd) == self.edges[cursor_self]:
-                            break
+                        new_cursor_plan = len(new_shape)
+                        break
+                    else:
+                        if plan_total == self.tensor.shape[cursor_self]:
+                            # new_shape block has been verified to be always tuple[int, int] before
+                            even, odd, reorder, sign = self._reorder_indices(
+                                typing.cast(
+                                    tuple[tuple[int, int], ...],
+                                    new_shape[cursor_plan:new_cursor_plan],
+                                )
+                            )
+                            if (even, odd) == self.edges[cursor_self]:
+                                break
                     # For some reason we cannot stop here, continue to include more dimension, check something before continue
                     assert plan_total <= self.tensor.shape[cursor_self], (
                         f"Dimension mismatch in splitting with edges {self.edges} and new shape {new_shape}."
@@ -382,12 +415,16 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTensor:
                 for i in range(cursor_plan, new_cursor_plan):
                     # new_shape block has been verified to be always tuple[int, int] in the loop
                     new_cursor_new_shape = typing.cast(tuple[int, int], new_shape[i])
-                    arrow.append(self.arrow[cursor_self])
+                    if self.tensor.dim() == 0:
+                        arrow.append(False)
+                    else:
+                        arrow.append(self.arrow[cursor_self])
                     edges.append(new_cursor_new_shape)
                     shape.append(new_cursor_new_shape[0] + new_cursor_new_shape[1])
                 splitting_reorder.append((cursor_self, reorder))
                 splitting_sign.append((cursor_self, sign))
-                cursor_self += 1
+                if self.tensor.dim() != 0:
+                    cursor_self += 1
                 cursor_plan = new_cursor_plan

         tensor = self.tensor
@@ -402,14 +439,17 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTensor:
             (
                 self._unsqueeze(sign, index, self.tensor.dim())
                 for index, sign in splitting_sign
-                if self.arrow[index]
+                if self.tensor.dim() != 0 and self.arrow[index]
             ),
             torch.zeros([], dtype=torch.bool, device=self.tensor.device),
         )
         tensor = torch.where(splitting_parity, -tensor, +tensor)

         tensor = tensor.reshape(shape)

+        if len(new_shape) == 0:
+            return GrassmannTensor(_arrow=tuple(arrow), _edges=tuple(edges), _tensor=tensor)
+
         merging_parity = functools.reduce(
             torch.logical_xor,
             (