Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions parity_tensor/parity_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def permute(self, before_by_after: tuple[int, ...]) -> ParityTensor:
total_parity = functools.reduce(
torch.logical_xor,
(
torch.logical_and(parity[i], parity[j])
torch.logical_and(self._unsqueeze(parity[i], i, self.tensor.dim()), self._unsqueeze(parity[j], j, self.tensor.dim()))
Copy link

Copilot AI Jul 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variable parity is not defined in this scope. Based on the context, it appears this should be self.parity to access the parity property of the class instance.

Suggested change
torch.logical_and(self._unsqueeze(parity[i], i, self.tensor.dim()), self._unsqueeze(parity[j], j, self.tensor.dim()))
torch.logical_and(self._unsqueeze(self.parity[i], i, self.tensor.dim()), self._unsqueeze(self.parity[j], j, self.tensor.dim()))

Copilot uses AI. Check for mistakes.
for j in range(self.tensor.dim())
for i in range(0, j) # all 0 <= i < j < dim
if before_by_after[i] > before_by_after[j]),
Expand All @@ -149,7 +149,7 @@ def __post_init__(self) -> None:
for dim, (even, odd) in zip(self._tensor.shape, self._edges):
assert even >= 0 and odd >= 0 and dim == even + odd, f"Dimension {dim} must equal sum of even ({even}) and odd ({odd}) parts, and both must be non-negative."

def _unqueeze(self, tensor: torch.Tensor, index: int, dim: int) -> torch.Tensor:
def _unsqueeze(self, tensor: torch.Tensor, index: int, dim: int) -> torch.Tensor:
return tensor.view([-1 if i == index else 1 for i in range(dim)])

def _edge_mask(self, even: int, odd: int) -> torch.Tensor:
Expand All @@ -158,7 +158,7 @@ def _edge_mask(self, even: int, odd: int) -> torch.Tensor:
def _tensor_mask(self) -> torch.Tensor:
return functools.reduce(
torch.logical_xor,
(self._unqueeze(parity, index, self._tensor.dim()) for index, parity in enumerate(self.parity)),
(self._unsqueeze(parity, index, self._tensor.dim()) for index, parity in enumerate(self.parity)),
torch.ones_like(self._tensor, dtype=torch.bool),
)

Expand Down