Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 150 additions & 0 deletions grassmann_tensor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,156 @@ def reverse(self, indices: tuple[int, ...]) -> GrassmannTensor:
_tensor=tensor,
)

def _reorder_indices(self, edges: tuple[tuple[int, int], ...]) -> tuple[int, int, torch.Tensor, torch.Tensor]:
parity = functools.reduce(
torch.logical_xor,
(self._unsqueeze(self._edge_mask(even, odd), index, len(edges)) for index, (even, odd) in enumerate(edges)),
torch.zeros([], dtype=torch.bool, device=self.tensor.device),
)
flatten_parity = parity.flatten()
even = (~flatten_parity).nonzero().squeeze()
odd = flatten_parity.nonzero().squeeze()
reorder = torch.cat([even, odd], dim=0)

total = functools.reduce(
torch.add,
(self._unsqueeze(self._edge_mask(even, odd), index, len(edges)).to(dtype=torch.int16) for index, (even, odd) in enumerate(edges)),
torch.zeros([], dtype=torch.int16, device=self.tensor.device),
)
count = total * (total - 1)
sign = (count & 2).to(dtype=torch.bool)
return len(even), len(odd), reorder, sign.flatten()

def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTensor:
"""
Reshape the Grassmann tensor, which may split or merge edges.

The new shape must be compatible with the original shape.
This operation does not change the arrow and it cannot merge two edges with different arrows.

The new shape should be a tuple of each new dimension, which is represented as either a single integer or a pair of two integers.
When a dimension is not changed, user could pass -1 to indicate that the dimension remains the same.
When a dimension is merged, user only needs to pass a single integer to indicate the new dimension size.
When a dimension is split, user must pass several pairs of two integers (even, odd) to indicate the new even and odd parts.

A single sign is generated during merging or splitting two edges, which should be applied to one of the connected two tensors.
This package always applies it to the tensor with arrow as True.
"""
# This function reshapes the Grassmann tensor according to the new shape, including the following steps:
# 1. Generate new arrow, edges, and shape for tensor
# 2. Reorder the indices for splitting
# 3. Apply the sign for splitting
# 4. reshape the core tensor according to the new shape
# 5. Apply the sign for merging
# 6. Reorder the indices for merging

# pylint: disable=too-many-branches, too-many-locals, too-many-statements

arrow: list[bool] = []
edges: list[tuple[int, int]] = []
shape: list[int] = []

splitting_sign: list[tuple[int, torch.Tensor]] = []
splitting_reorder: list[tuple[int, torch.Tensor]] = []
merging_reorder: list[tuple[int, torch.Tensor]] = []
merging_sign: list[tuple[int, torch.Tensor]] = []

cursor_plan: int = 0
cursor_self: int = 0
while True:
if new_shape[cursor_plan] == -1:
Copy link

Copilot AI Aug 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Accessing new_shape[cursor_plan] without bounds checking could cause an IndexError if cursor_plan exceeds the length of new_shape. This check should be performed before accessing the array element.

Copilot uses AI. Check for mistakes.
# Does not change
arrow.append(self.arrow[cursor_self])
edges.append(self.edges[cursor_self])
shape.append(self.tensor.shape[cursor_self])
cursor_self += 1
cursor_plan += 1
else:
cursor_new_shape = new_shape[cursor_plan]
total = cursor_new_shape if isinstance(cursor_new_shape, int) else cursor_new_shape[0] + cursor_new_shape[1]
if total >= self.tensor.shape[cursor_self]:
# Merging
new_cursor_self = cursor_self
self_total = 1
while True:
self_total *= self.tensor.shape[new_cursor_self]
new_cursor_self += 1
if self_total == total:
break
assert self_total < total, f"Dimension mismatch with edges {self.edges} and new shape {new_shape}."
assert new_cursor_self < self.tensor.dim(), f"New shape {new_shape} exceeds tensor dimensions {self.tensor.dim()}."
Copy link

Copilot AI Aug 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another unbounded while loop that could potentially run indefinitely without proper safeguards.

Suggested change
assert new_cursor_self < self.tensor.dim(), f"New shape {new_shape} exceeds tensor dimensions {self.tensor.dim()}."
self_total = 1
max_merge_iters = self.tensor.dim() - cursor_self
merge_iter_count = 0
while True:
self_total *= self.tensor.shape[new_cursor_self]
new_cursor_self += 1
merge_iter_count += 1
if self_total == total:
break
assert self_total < total, f"Dimension mismatch with edges {self.edges} and new shape {new_shape}."
assert new_cursor_self < self.tensor.dim(), f"New shape {new_shape} exceeds tensor dimensions {self.tensor.dim()}."
if merge_iter_count > max_merge_iters:
raise RuntimeError(f"Exceeded maximum merge iterations ({max_merge_iters}) in merging loop. Possible infinite loop with edges {self.edges} and new shape {new_shape}.")

Copilot uses AI. Check for mistakes.
even, odd, reorder, sign = self._reorder_indices(self.edges[cursor_self:new_cursor_self])
if isinstance(cursor_new_shape, tuple):
assert (even, odd) == cursor_new_shape, f"New even and odd number mismatch during merging {self.edges} to {new_shape}."
arrow.append(self.arrow[cursor_self])
assert all(
self_arrow == arrow[-1] for self_arrow in self.arrow[cursor_self:new_cursor_self]), f"Cannot merge edges with different arrows {self.arrow[cursor_self:new_cursor_self]}."
edges.append((even, odd))
shape.append(total)
if cursor_self + 1 != new_cursor_self:
# Really something merged
merging_sign.append((cursor_plan, sign))
merging_reorder.append((cursor_plan, reorder))
cursor_self = new_cursor_self
cursor_plan += 1
else:
# Splitting
new_cursor_plan = cursor_plan
plan_total = 1
while True:
Copy link

Copilot AI Aug 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This nested while loop also lacks proper bounds checking and could run indefinitely if the termination conditions aren't met properly.

Suggested change
while True:
while plan_total < self.tensor.shape[cursor_self] and new_cursor_plan < len(new_shape):

Copilot uses AI. Check for mistakes.
new_cursor_new_shape = new_shape[new_cursor_plan]
assert isinstance(new_cursor_new_shape, tuple), f"New shape must be a pair when splitting, got {new_cursor_new_shape}."
plan_total *= new_cursor_new_shape[0] + new_cursor_new_shape[1]
new_cursor_plan += 1
if plan_total == self.tensor.shape[cursor_self]:
break
assert plan_total < self.tensor.shape[cursor_self], f"Dimension mismatch with edges {self.edges} and new shape {new_shape}."
assert new_cursor_plan < len(new_shape), f"New shape {new_shape} exceeds specified dimensions {len(new_shape)}."
# new_shape has been verified to be tuple[int, int] in the loop
even, odd, reorder, sign = self._reorder_indices(typing.cast(tuple[tuple[int, int], ...], new_shape[cursor_plan:new_cursor_plan]))
assert (even, odd) == self.edges[cursor_self], f"New even and odd number mismatch during splitting {self.edges[cursor_self]} to {new_shape[cursor_plan:new_cursor_plan]}."
for i in range(cursor_plan, new_cursor_plan):
# new_shape has been verified to be tuple[int, int] in the loop
new_cursor_new_shape = typing.cast(tuple[int, int], new_shape[i])
arrow.append(self.arrow[cursor_self])
edges.append(new_cursor_new_shape)
shape.append(new_cursor_new_shape[0] + new_cursor_new_shape[1])
splitting_reorder.append((cursor_self, reorder))
splitting_sign.append((cursor_self, sign))
cursor_self += 1
cursor_plan = new_cursor_plan

if cursor_plan == len(new_shape) and cursor_self == self.tensor.dim():
Copy link

Copilot AI Aug 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The loop termination condition is placed after potential IndexError scenarios. If either cursor exceeds bounds during the loop iterations, the function will crash before reaching this check.

Suggested change
if cursor_plan == len(new_shape) and cursor_self == self.tensor.dim():

Copilot uses AI. Check for mistakes.
break

tensor = self.tensor

for index, reorder in splitting_reorder:
inverse_reorder = torch.empty_like(reorder)
inverse_reorder[reorder] = torch.arange(reorder.size(0), device=reorder.device)
tensor = tensor.index_select(index, inverse_reorder)

splitting_parity = functools.reduce(
torch.logical_xor,
(self._unsqueeze(sign, index, self.tensor.dim()) for index, sign in splitting_sign if self.arrow[index]),
Copy link

Copilot AI Aug 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The condition self.arrow[index] uses the original tensor's arrow, but index refers to positions in the original tensor during splitting. After tensor reshaping, the dimensions may not correspond correctly.

Suggested change
(self._unsqueeze(sign, index, self.tensor.dim()) for index, sign in splitting_sign if self.arrow[index]),
(self._unsqueeze(sign, index, self.tensor.dim()) for index, sign in splitting_sign if arrow[index]),

Copilot uses AI. Check for mistakes.
torch.zeros([], dtype=torch.bool, device=self.tensor.device),
)
tensor = torch.where(splitting_parity, -tensor, +tensor)

tensor = tensor.reshape(shape)

merging_parity = functools.reduce(
torch.logical_xor,
(self._unsqueeze(sign, index, tensor.dim()) for index, sign in merging_sign if arrow[index]),
torch.zeros([], dtype=torch.bool, device=self.tensor.device),
)
tensor = torch.where(merging_parity, -tensor, +tensor)

for index, reorder in merging_reorder:
tensor = tensor.index_select(index, reorder)

return GrassmannTensor(_arrow=tuple(arrow), _edges=tuple(edges), _tensor=tensor)

def __post_init__(self) -> None:
assert len(self._arrow) == self._tensor.dim(), f"Arrow length ({len(self._arrow)}) must match tensor dimensions ({self._tensor.dim()})."
assert len(self._edges) == self._tensor.dim(), f"Edges length ({len(self._edges)}) must match tensor dimensions ({self._tensor.dim()})."
Expand Down