From 536cd64ed170ccf4b53322634dfd8583c6634bd3 Mon Sep 17 00:00:00 2001 From: Hao Zhang Date: Thu, 31 Jul 2025 16:45:58 +0800 Subject: [PATCH] Fix bug in permute function and fix a type. --- parity_tensor/parity_tensor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/parity_tensor/parity_tensor.py b/parity_tensor/parity_tensor.py index 5b7fec1..77fe9e6 100644 --- a/parity_tensor/parity_tensor.py +++ b/parity_tensor/parity_tensor.py @@ -126,7 +126,7 @@ def permute(self, before_by_after: tuple[int, ...]) -> ParityTensor: total_parity = functools.reduce( torch.logical_xor, ( - torch.logical_and(parity[i], parity[j]) + torch.logical_and(self._unsqueeze(parity[i], i, self.tensor.dim()), self._unsqueeze(parity[j], j, self.tensor.dim())) for j in range(self.tensor.dim()) for i in range(0, j) # all 0 <= i < j < dim if before_by_after[i] > before_by_after[j]), @@ -149,7 +149,7 @@ def __post_init__(self) -> None: for dim, (even, odd) in zip(self._tensor.shape, self._edges): assert even >= 0 and odd >= 0 and dim == even + odd, f"Dimension {dim} must equal sum of even ({even}) and odd ({odd}) parts, and both must be non-negative." - def _unqueeze(self, tensor: torch.Tensor, index: int, dim: int) -> torch.Tensor: + def _unsqueeze(self, tensor: torch.Tensor, index: int, dim: int) -> torch.Tensor: return tensor.view([-1 if i == index else 1 for i in range(dim)]) def _edge_mask(self, even: int, odd: int) -> torch.Tensor: @@ -158,7 +158,7 @@ def _edge_mask(self, even: int, odd: int) -> torch.Tensor: def _tensor_mask(self) -> torch.Tensor: return functools.reduce( torch.logical_xor, - (self._unqueeze(parity, index, self._tensor.dim()) for index, parity in enumerate(self.parity)), + (self._unsqueeze(parity, index, self._tensor.dim()) for index, parity in enumerate(self.parity)), torch.ones_like(self._tensor, dtype=torch.bool), )