-
Notifications
You must be signed in to change notification settings - Fork 0
Add field parity, and copy parity and mask in arithmetic operators. #11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -19,16 +19,54 @@ class ParityTensor: | |||||
| Each dimension of the tensor is composed of an even and an odd part, represented as a pair of integers. | ||||||
| """ | ||||||
|
|
||||||
| edges: tuple[tuple[int, int], ...] | ||||||
| tensor: torch.Tensor | ||||||
| mask: torch.Tensor | None = None | ||||||
| _edges: tuple[tuple[int, int], ...] | ||||||
| _tensor: torch.Tensor | ||||||
| _parity: tuple[torch.Tensor, ...] | None = None | ||||||
| _mask: torch.Tensor | None = None | ||||||
|
|
||||||
| @property | ||||||
| def edges(self) -> tuple[tuple[int, int], ...]: | ||||||
| """ | ||||||
| The edges of the tensor, represented as a tuple of pairs (even, odd). | ||||||
| """ | ||||||
| return self._edges | ||||||
|
|
||||||
| @property | ||||||
| def tensor(self) -> torch.Tensor: | ||||||
| """ | ||||||
| The underlying tensor data. | ||||||
| """ | ||||||
| return self._tensor | ||||||
|
|
||||||
| @property | ||||||
| def parity(self) -> tuple[torch.Tensor, ...]: | ||||||
| """ | ||||||
| The parity of each edge, represented as a tuple of tensors. | ||||||
| """ | ||||||
| if self._parity is None: | ||||||
| self._parity = tuple(self._edge_mask(even, odd) for (even, odd) in self._edges) | ||||||
| return self._parity | ||||||
|
|
||||||
| @property | ||||||
| def mask(self) -> torch.Tensor: | ||||||
| """ | ||||||
| The mask of the tensor, which has the same shape as the tensor and indicates which elements could be non-zero based on the parity. | ||||||
| """ | ||||||
| if self._mask is None: | ||||||
| self._mask = self._tensor_mask() | ||||||
| return self._mask | ||||||
|
|
||||||
| def update_mask(self) -> ParityTensor: | ||||||
| """ | ||||||
| Update the mask of the tensor based on its parity. | ||||||
| """ | ||||||
| self._tensor = torch.where(self.mask, self._tensor, 0) | ||||||
| return self | ||||||
|
|
||||||
| def __post_init__(self) -> None: | ||||||
| assert len(self.edges) == self.tensor.dim(), f"Edges length ({len(self.edges)}) must match tensor dimensions ({self.tensor.dim()})." | ||||||
| for dim, (even, odd) in zip(self.tensor.shape, self.edges): | ||||||
| assert len(self._edges) == self._tensor.dim(), f"Edges length ({len(self._edges)}) must match tensor dimensions ({self._tensor.dim()})." | ||||||
| for dim, (even, odd) in zip(self._tensor.shape, self._edges): | ||||||
| assert even >= 0 and odd >= 0 and dim == even + odd, f"Dimension {dim} must equal sum of even ({even}) and odd ({odd}) parts, and both must be non-negative." | ||||||
| if self.mask is None: | ||||||
| self.mask = self._tensor_mask() | ||||||
|
|
||||||
| @classmethod | ||||||
| def _unqueeze(cls, tensor: torch.Tensor, index: int, dim: int) -> torch.Tensor: | ||||||
|
|
@@ -41,176 +79,204 @@ def _edge_mask(cls, even: int, odd: int) -> torch.Tensor: | |||||
| def _tensor_mask(self) -> torch.Tensor: | ||||||
| return functools.reduce( | ||||||
| torch.logical_xor, | ||||||
| (self._unqueeze(self._edge_mask(even, odd), index, self.tensor.dim()) for index, (even, odd) in enumerate(self.edges)), | ||||||
| torch.ones_like(self.tensor, dtype=torch.bool), | ||||||
| (self._unqueeze(parity, index, self._tensor.dim()) for index, parity in enumerate(self.parity)), | ||||||
| torch.ones_like(self._tensor, dtype=torch.bool), | ||||||
| ) | ||||||
|
|
||||||
| def _validate_edge_compatibility(self, other: ParityTensor) -> None: | ||||||
| """ | ||||||
| Validate that the edges of two ParityTensor instances are compatible for arithmetic operations. | ||||||
| """ | ||||||
| assert self.edges == other.edges, f"Edges must match for arithmetic operations. Got {self.edges} and {other.edges}." | ||||||
| assert self._edges == other.edges, f"Edges must match for arithmetic operations. Got {self._edges} and {other.edges}." | ||||||
|
|
||||||
| def __pos__(self) -> ParityTensor: | ||||||
| return ParityTensor( | ||||||
| edges=self.edges, | ||||||
| tensor=+self.tensor, | ||||||
| _edges=self._edges, | ||||||
| _tensor=+self._tensor, | ||||||
| _parity=self._parity, | ||||||
| _mask=self._mask, | ||||||
| ) | ||||||
|
|
||||||
| def __neg__(self) -> ParityTensor: | ||||||
| return ParityTensor( | ||||||
| edges=self.edges, | ||||||
| tensor=-self.tensor, | ||||||
| _edges=self._edges, | ||||||
| _tensor=-self._tensor, | ||||||
| _parity=self._parity, | ||||||
| _mask=self._mask, | ||||||
| ) | ||||||
|
|
||||||
| def __add__(self, other: typing.Any) -> ParityTensor: | ||||||
| if isinstance(other, ParityTensor): | ||||||
| self._validate_edge_compatibility(other) | ||||||
| return ParityTensor( | ||||||
| edges=self.edges, | ||||||
| tensor=self.tensor + other.tensor, | ||||||
| _edges=self._edges, | ||||||
| _tensor=self._tensor + other._tensor, | ||||||
| _parity=self._parity, | ||||||
| _mask=self._mask, | ||||||
|
Comment on lines
111
to
+115
|
||||||
| ) | ||||||
| try: | ||||||
| result = self.tensor + other | ||||||
| result = self._tensor + other | ||||||
| except TypeError: | ||||||
| return NotImplemented | ||||||
| if isinstance(result, torch.Tensor): | ||||||
| return ParityTensor( | ||||||
| edges=self.edges, | ||||||
| tensor=result, | ||||||
| _edges=self._edges, | ||||||
| _tensor=result, | ||||||
| _parity=self._parity, | ||||||
| _mask=self._mask, | ||||||
| ) | ||||||
| return NotImplemented | ||||||
|
|
||||||
| def __radd__(self, other: typing.Any) -> ParityTensor: | ||||||
| try: | ||||||
| result = other + self.tensor | ||||||
| result = other + self._tensor | ||||||
| except TypeError: | ||||||
| return NotImplemented | ||||||
| if isinstance(result, torch.Tensor): | ||||||
| return ParityTensor( | ||||||
| edges=self.edges, | ||||||
| tensor=result, | ||||||
| _edges=self._edges, | ||||||
| _tensor=result, | ||||||
| _parity=self._parity, | ||||||
| _mask=self._mask, | ||||||
| ) | ||||||
| return NotImplemented | ||||||
|
|
||||||
| def __iadd__(self, other: typing.Any) -> ParityTensor: | ||||||
| if isinstance(other, ParityTensor): | ||||||
| self._validate_edge_compatibility(other) | ||||||
| self.tensor += other.tensor | ||||||
| self._tensor += other._tensor | ||||||
| else: | ||||||
| self.tensor += other | ||||||
| self._tensor += other | ||||||
| return self | ||||||
|
|
||||||
| def __sub__(self, other: typing.Any) -> ParityTensor: | ||||||
| if isinstance(other, ParityTensor): | ||||||
| self._validate_edge_compatibility(other) | ||||||
| return ParityTensor( | ||||||
| edges=self.edges, | ||||||
| tensor=self.tensor - other.tensor, | ||||||
| _edges=self._edges, | ||||||
| _tensor=self._tensor - other._tensor, | ||||||
| _parity=self._parity, | ||||||
| _mask=self._mask, | ||||||
|
||||||
| _mask=self._mask, | |
| _mask=None, |
Copilot
AI
Jul 30, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The mask from self is being copied to the result of multiplication with another ParityTensor, but element-wise multiplication can change which elements should be zero according to the parity rules. The mask should be recomputed or set to None to ensure correctness.
| _mask=self._mask, | |
| _mask=None, |
Copilot
AI
Jul 30, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The mask from self is being copied to the result of division with another ParityTensor, but division can change which elements should be zero according to the parity rules. The mask should be recomputed or set to None to ensure correctness.
| _mask=self._mask, | |
| _mask=None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When adding two ParityTensor instances, copying only the left operand's parity and mask may be incorrect. The result should have its own computed parity and mask, or both operands' cached values should be considered if they're compatible.