Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions parity_tensor/parity_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,35 @@ def update_mask(self) -> ParityTensor:
self._tensor = torch.where(self.mask, self._tensor, 0)
return self

def permute(self, before_by_after: tuple[int, ...]) -> ParityTensor:
"""
Permute the indices of the parity tensor.
Copy link

Copilot AI Jul 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring is incomplete. It should include parameter descriptions, return value description, and example usage for a public API method.

Suggested change
Permute the indices of the parity tensor.
Permute the indices of the parity tensor.
This method rearranges the dimensions of the tensor and its associated metadata
(edges, parity, and mask) according to the specified permutation.
Parameters:
----------
before_by_after : tuple[int, ...]
A tuple specifying the new order of the dimensions. Each element in the tuple
represents the index of the original dimension that should appear in the
corresponding position in the new order.
Returns:
-------
ParityTensor
A new `ParityTensor` instance with permuted indices, updated edges, parity,
and mask.
Example:
--------
>>> tensor = torch.tensor([[1, 2], [3, 4]])
>>> edges = ((0, 1), (1, 0))
>>> parity_tensor = ParityTensor(_edges=edges, _tensor=tensor)
>>> permuted_tensor = parity_tensor.permute((1, 0))
>>> print(permuted_tensor.tensor)
tensor([[1, 3],
[2, 4]])

Copilot uses AI. Check for mistakes.
"""
assert set(before_by_after) == set(range(self.tensor.dim())), "Permutation indices must cover all dimensions."

edges = tuple(self.edges[i] for i in before_by_after)
tensor = self.tensor.permute(before_by_after)
parity = tuple(self.parity[i] for i in before_by_after)
mask = self.mask.permute(before_by_after)

total_parity = functools.reduce(
torch.logical_xor,
(
torch.logical_and(parity[i], parity[j])
for j in range(self.tensor.dim())
for i in range(0, j) # all 0 <= i < j < dim
if before_by_after[i] > before_by_after[j]),
Copy link

Copilot AI Jul 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The method lacks input validation. It should verify that before_by_after contains valid indices (0 to tensor.dim()-1) and that all indices are unique to prevent runtime errors.

Copilot uses AI. Check for mistakes.
torch.zeros([], dtype=torch.bool),
)
tensor = torch.where(total_parity, -tensor, +tensor)

return ParityTensor(
_edges=edges,
_tensor=tensor,
_parity=parity,
_mask=mask,
)

def __post_init__(self) -> None:
assert len(self._edges) == self._tensor.dim(), f"Edges length ({len(self._edges)}) must match tensor dimensions ({self._tensor.dim()})."
for dim, (even, odd) in zip(self._tensor.shape, self._edges):
Expand Down