diff --git a/parity_tensor/parity_tensor.py b/parity_tensor/parity_tensor.py index 15612ff..38c7332 100644 --- a/parity_tensor/parity_tensor.py +++ b/parity_tensor/parity_tensor.py @@ -63,6 +63,35 @@ def update_mask(self) -> ParityTensor: self._tensor = torch.where(self.mask, self._tensor, 0) return self + def permute(self, before_by_after: tuple[int, ...]) -> ParityTensor: + """ + Permute the indices of the parity tensor. + """ + assert set(before_by_after) == set(range(self.tensor.dim())), "Permutation indices must cover all dimensions." + + edges = tuple(self.edges[i] for i in before_by_after) + tensor = self.tensor.permute(before_by_after) + parity = tuple(self.parity[i] for i in before_by_after) + mask = self.mask.permute(before_by_after) + + total_parity = functools.reduce( + torch.logical_xor, + ( + torch.logical_and(parity[i], parity[j]) + for j in range(self.tensor.dim()) + for i in range(0, j) # all 0 <= i < j < dim + if before_by_after[i] > before_by_after[j]), + torch.zeros([], dtype=torch.bool), + ) + tensor = torch.where(total_parity, -tensor, +tensor) + + return ParityTensor( + _edges=edges, + _tensor=tensor, + _parity=parity, + _mask=mask, + ) + def __post_init__(self) -> None: assert len(self._edges) == self._tensor.dim(), f"Edges length ({len(self._edges)}) must match tensor dimensions ({self._tensor.dim()})." for dim, (even, odd) in zip(self._tensor.shape, self._edges):