From 6bf1100d34476701a4194434111c01bc5547dccb Mon Sep 17 00:00:00 2001 From: Hao Zhang Date: Wed, 30 Jul 2025 22:10:29 +0800 Subject: [PATCH 1/6] Add the basic parity tensor class and its arithmetic operators. --- parity_tensor/__init__.py | 3 +- parity_tensor/parity_tensor.py | 152 +++++++++++++++++++++++++++++++++ 2 files changed, 154 insertions(+), 1 deletion(-) create mode 100644 parity_tensor/parity_tensor.py diff --git a/parity_tensor/__init__.py b/parity_tensor/__init__.py index 30fea6e..0f46506 100644 --- a/parity_tensor/__init__.py +++ b/parity_tensor/__init__.py @@ -2,6 +2,7 @@ A parity tensor package. """ -__all__ = ["__version__"] +__all__ = ["__version__", "ParityTensor"] from .version import __version__ +from .parity_tensor import ParityTensor diff --git a/parity_tensor/parity_tensor.py b/parity_tensor/parity_tensor.py new file mode 100644 index 0000000..95fb91b --- /dev/null +++ b/parity_tensor/parity_tensor.py @@ -0,0 +1,152 @@ +""" +A parity tensor class. +""" + +from __future__ import annotations + +__all__ = ["ParityTensor"] + +import dataclasses +import torch + + +@dataclasses.dataclass +class ParityTensor: + """ + A parity tensor class, which stores a tensor along with information about its edges. + Each dimension of the tensor is composed of an even and an odd part, represented as a pair of integers. + """ + + edges: tuple[tuple[int, int], ...] + tensor: torch.Tensor + + def __post_init__(self): + assert len(self.edges) == self.tensor.dim(), "Edges length must match tensor dimensions." + for dim, (even, odd) in zip(self.tensor.shape, self.edges): + assert even >= 0 and odd >= 0 and dim == even + odd, "Each dimension must match the sum of even and odd parts." + + def _validate_edge_compatibility(self, other: ParityTensor) -> None: + """ + Validate that the edges of two ParityTensor instances are compatible for arithmetic operations. + """ + assert self.edges == other.edges, "Edges must match for arithmetic operations." + + def __add__(self, other): + if isinstance(other, ParityTensor): # pylint: disable=no-else-return + self._validate_edge_compatibility(other) + return ParityTensor( + edges=self.edges, + tensor=self.tensor + other.tensor, + ) + else: + return ParityTensor( + edges=self.edges, + tensor=self.tensor + other, + ) + + def __radd__(self, other): + return ParityTensor( + edges=self.edges, + tensor=other + self.tensor, + ) + + def __iadd__(self, other): + if isinstance(other, ParityTensor): + self._validate_edge_compatibility(other) + self.tensor += other.tensor + else: + self.tensor += other + return self + + def __sub__(self, other): + if isinstance(other, ParityTensor): # pylint: disable=no-else-return + self._validate_edge_compatibility(other) + return ParityTensor( + edges=self.edges, + tensor=self.tensor - other.tensor, + ) + else: + return ParityTensor( + edges=self.edges, + tensor=self.tensor - other, + ) + + def __rsub__(self, other): + return ParityTensor( + edges=self.edges, + tensor=other - self.tensor, + ) + + def __isub__(self, other): + if isinstance(other, ParityTensor): + self._validate_edge_compatibility(other) + self.tensor -= other.tensor + else: + self.tensor -= other + return self + + def __mul__(self, other): + if isinstance(other, ParityTensor): # pylint: disable=no-else-return + self._validate_edge_compatibility(other) + return ParityTensor( + edges=self.edges, + tensor=self.tensor * other.tensor, + ) + else: + return ParityTensor( + edges=self.edges, + tensor=self.tensor * other, + ) + + def __rmul__(self, other): + return ParityTensor( + edges=self.edges, + tensor=other * self.tensor, + ) + + def __imul__(self, other): + if isinstance(other, ParityTensor): + self._validate_edge_compatibility(other) + self.tensor *= other.tensor + else: + self.tensor *= other + return self + + def __truediv__(self, other): + if isinstance(other, ParityTensor): # pylint: disable=no-else-return + self._validate_edge_compatibility(other) + return ParityTensor( + edges=self.edges, + tensor=self.tensor / other.tensor, + ) + else: + return ParityTensor( + edges=self.edges, + tensor=self.tensor / other, + ) + + def __rtruediv__(self, other): + return ParityTensor( + edges=self.edges, + tensor=other / self.tensor, + ) + + def __itruediv__(self, other): + if isinstance(other, ParityTensor): + self._validate_edge_compatibility(other) + self.tensor /= other.tensor + else: + self.tensor /= other + return self + + def __pos__(self): + return ParityTensor( + edges=self.edges, + tensor=+self.tensor, + ) + + def __neg__(self): + return ParityTensor( + edges=self.edges, + tensor=-self.tensor, + ) From 40ba62a807f3d7767a432762e442ed83fd619277 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hao=20Zhang=28=E5=BC=A0=E6=B5=A9=29?= Date: Wed, 30 Jul 2025 23:00:38 +0800 Subject: [PATCH 2/6] Update parity_tensor/parity_tensor.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- parity_tensor/parity_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parity_tensor/parity_tensor.py b/parity_tensor/parity_tensor.py index 95fb91b..ed92fbc 100644 --- a/parity_tensor/parity_tensor.py +++ b/parity_tensor/parity_tensor.py @@ -29,7 +29,7 @@ def _validate_edge_compatibility(self, other: ParityTensor) -> None: """ Validate that the edges of two ParityTensor instances are compatible for arithmetic operations. """ - assert self.edges == other.edges, "Edges must match for arithmetic operations." + assert self.edges == other.edges, f"Edges must match for arithmetic operations. Got {self.edges} and {other.edges}." def __add__(self, other): if isinstance(other, ParityTensor): # pylint: disable=no-else-return From 4ee3380785d8ae90334e18e9921a929c148f6ab6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hao=20Zhang=28=E5=BC=A0=E6=B5=A9=29?= Date: Wed, 30 Jul 2025 23:01:04 +0800 Subject: [PATCH 3/6] Update parity_tensor/parity_tensor.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- parity_tensor/parity_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parity_tensor/parity_tensor.py b/parity_tensor/parity_tensor.py index ed92fbc..1814ae8 100644 --- a/parity_tensor/parity_tensor.py +++ b/parity_tensor/parity_tensor.py @@ -23,7 +23,7 @@ class ParityTensor: def __post_init__(self): assert len(self.edges) == self.tensor.dim(), "Edges length must match tensor dimensions." for dim, (even, odd) in zip(self.tensor.shape, self.edges): - assert even >= 0 and odd >= 0 and dim == even + odd, "Each dimension must match the sum of even and odd parts." + assert even >= 0 and odd >= 0 and dim == even + odd, f"Dimension {dim} must equal sum of even ({even}) and odd ({odd}) parts, and both must be non-negative." def _validate_edge_compatibility(self, other: ParityTensor) -> None: """ From 93e2bb82e587f0478314d70d0c88af0a11e565fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hao=20Zhang=28=E5=BC=A0=E6=B5=A9=29?= Date: Wed, 30 Jul 2025 23:01:17 +0800 Subject: [PATCH 4/6] Update parity_tensor/parity_tensor.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- parity_tensor/parity_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parity_tensor/parity_tensor.py b/parity_tensor/parity_tensor.py index 1814ae8..4bd8650 100644 --- a/parity_tensor/parity_tensor.py +++ b/parity_tensor/parity_tensor.py @@ -21,7 +21,7 @@ class ParityTensor: tensor: torch.Tensor def __post_init__(self): - assert len(self.edges) == self.tensor.dim(), "Edges length must match tensor dimensions." + assert len(self.edges) == self.tensor.dim(), f"Edges length ({len(self.edges)}) must match tensor dimensions ({self.tensor.dim()})." for dim, (even, odd) in zip(self.tensor.shape, self.edges): assert even >= 0 and odd >= 0 and dim == even + odd, f"Dimension {dim} must equal sum of even ({even}) and odd ({odd}) parts, and both must be non-negative." From bee4047d751c6945706af518f4953f772d7a73fb Mon Sep 17 00:00:00 2001 From: Hao Zhang Date: Wed, 30 Jul 2025 22:53:46 +0800 Subject: [PATCH 5/6] Add annotation for return type of functions. --- parity_tensor/parity_tensor.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/parity_tensor/parity_tensor.py b/parity_tensor/parity_tensor.py index 4bd8650..cb2d645 100644 --- a/parity_tensor/parity_tensor.py +++ b/parity_tensor/parity_tensor.py @@ -20,7 +20,7 @@ class ParityTensor: edges: tuple[tuple[int, int], ...] tensor: torch.Tensor - def __post_init__(self): + def __post_init__(self) -> None: assert len(self.edges) == self.tensor.dim(), f"Edges length ({len(self.edges)}) must match tensor dimensions ({self.tensor.dim()})." for dim, (even, odd) in zip(self.tensor.shape, self.edges): assert even >= 0 and odd >= 0 and dim == even + odd, f"Dimension {dim} must equal sum of even ({even}) and odd ({odd}) parts, and both must be non-negative." @@ -31,7 +31,7 @@ def _validate_edge_compatibility(self, other: ParityTensor) -> None: """ assert self.edges == other.edges, f"Edges must match for arithmetic operations. Got {self.edges} and {other.edges}." - def __add__(self, other): + def __add__(self, other) -> ParityTensor: if isinstance(other, ParityTensor): # pylint: disable=no-else-return self._validate_edge_compatibility(other) return ParityTensor( @@ -44,13 +44,13 @@ def __add__(self, other): tensor=self.tensor + other, ) - def __radd__(self, other): + def __radd__(self, other) -> ParityTensor: return ParityTensor( edges=self.edges, tensor=other + self.tensor, ) - def __iadd__(self, other): + def __iadd__(self, other) -> ParityTensor: if isinstance(other, ParityTensor): self._validate_edge_compatibility(other) self.tensor += other.tensor @@ -58,7 +58,7 @@ def __iadd__(self, other): self.tensor += other return self - def __sub__(self, other): + def __sub__(self, other) -> ParityTensor: if isinstance(other, ParityTensor): # pylint: disable=no-else-return self._validate_edge_compatibility(other) return ParityTensor( @@ -71,13 +71,13 @@ def __sub__(self, other): tensor=self.tensor - other, ) - def __rsub__(self, other): + def __rsub__(self, other) -> ParityTensor: return ParityTensor( edges=self.edges, tensor=other - self.tensor, ) - def __isub__(self, other): + def __isub__(self, other) -> ParityTensor: if isinstance(other, ParityTensor): self._validate_edge_compatibility(other) self.tensor -= other.tensor @@ -85,7 +85,7 @@ def __isub__(self, other): self.tensor -= other return self - def __mul__(self, other): + def __mul__(self, other) -> ParityTensor: if isinstance(other, ParityTensor): # pylint: disable=no-else-return self._validate_edge_compatibility(other) return ParityTensor( @@ -98,13 +98,13 @@ def __mul__(self, other): tensor=self.tensor * other, ) - def __rmul__(self, other): + def __rmul__(self, other) -> ParityTensor: return ParityTensor( edges=self.edges, tensor=other * self.tensor, ) - def __imul__(self, other): + def __imul__(self, other) -> ParityTensor: if isinstance(other, ParityTensor): self._validate_edge_compatibility(other) self.tensor *= other.tensor @@ -112,7 +112,7 @@ def __imul__(self, other): self.tensor *= other return self - def __truediv__(self, other): + def __truediv__(self, other) -> ParityTensor: if isinstance(other, ParityTensor): # pylint: disable=no-else-return self._validate_edge_compatibility(other) return ParityTensor( @@ -125,13 +125,13 @@ def __truediv__(self, other): tensor=self.tensor / other, ) - def __rtruediv__(self, other): + def __rtruediv__(self, other) -> ParityTensor: return ParityTensor( edges=self.edges, tensor=other / self.tensor, ) - def __itruediv__(self, other): + def __itruediv__(self, other) -> ParityTensor: if isinstance(other, ParityTensor): self._validate_edge_compatibility(other) self.tensor /= other.tensor @@ -139,13 +139,13 @@ def __itruediv__(self, other): self.tensor /= other return self - def __pos__(self): + def __pos__(self) -> ParityTensor: return ParityTensor( edges=self.edges, tensor=+self.tensor, ) - def __neg__(self): + def __neg__(self) -> ParityTensor: return ParityTensor( edges=self.edges, tensor=-self.tensor, From ae6aaf97459b1d1e9e2d5492bf98eb0a076dc17d Mon Sep 17 00:00:00 2001 From: Hao Zhang Date: Wed, 30 Jul 2025 23:25:47 +0800 Subject: [PATCH 6/6] Add typing for arithmetic operator rhs. --- parity_tensor/parity_tensor.py | 125 +++++++++++++++++++-------------- 1 file changed, 73 insertions(+), 52 deletions(-) diff --git a/parity_tensor/parity_tensor.py b/parity_tensor/parity_tensor.py index cb2d645..132f591 100644 --- a/parity_tensor/parity_tensor.py +++ b/parity_tensor/parity_tensor.py @@ -7,6 +7,7 @@ __all__ = ["ParityTensor"] import dataclasses +import typing import torch @@ -31,26 +32,43 @@ def _validate_edge_compatibility(self, other: ParityTensor) -> None: """ assert self.edges == other.edges, f"Edges must match for arithmetic operations. Got {self.edges} and {other.edges}." - def __add__(self, other) -> ParityTensor: - if isinstance(other, ParityTensor): # pylint: disable=no-else-return + def __pos__(self) -> ParityTensor: + return ParityTensor( + edges=self.edges, + tensor=+self.tensor, + ) + + def __neg__(self) -> ParityTensor: + return ParityTensor( + edges=self.edges, + tensor=-self.tensor, + ) + + def __add__(self, other: typing.Any) -> ParityTensor: + if isinstance(other, ParityTensor): self._validate_edge_compatibility(other) return ParityTensor( edges=self.edges, tensor=self.tensor + other.tensor, ) - else: + result = self.tensor + other + if isinstance(result, torch.Tensor): return ParityTensor( edges=self.edges, - tensor=self.tensor + other, + tensor=result, ) + return NotImplemented - def __radd__(self, other) -> ParityTensor: - return ParityTensor( - edges=self.edges, - tensor=other + self.tensor, - ) + def __radd__(self, other: typing.Any) -> ParityTensor: + result = other + self.tensor + if isinstance(result, torch.Tensor): + return ParityTensor( + edges=self.edges, + tensor=result, + ) + return NotImplemented - def __iadd__(self, other) -> ParityTensor: + def __iadd__(self, other: typing.Any) -> ParityTensor: if isinstance(other, ParityTensor): self._validate_edge_compatibility(other) self.tensor += other.tensor @@ -58,26 +76,31 @@ def __iadd__(self, other) -> ParityTensor: self.tensor += other return self - def __sub__(self, other) -> ParityTensor: - if isinstance(other, ParityTensor): # pylint: disable=no-else-return + def __sub__(self, other: typing.Any) -> ParityTensor: + if isinstance(other, ParityTensor): self._validate_edge_compatibility(other) return ParityTensor( edges=self.edges, tensor=self.tensor - other.tensor, ) - else: + result = self.tensor - other + if isinstance(result, torch.Tensor): return ParityTensor( edges=self.edges, - tensor=self.tensor - other, + tensor=result, ) + return NotImplemented - def __rsub__(self, other) -> ParityTensor: - return ParityTensor( - edges=self.edges, - tensor=other - self.tensor, - ) + def __rsub__(self, other: typing.Any) -> ParityTensor: + result = other - self.tensor + if isinstance(result, torch.Tensor): + return ParityTensor( + edges=self.edges, + tensor=result, + ) + return NotImplemented - def __isub__(self, other) -> ParityTensor: + def __isub__(self, other: typing.Any) -> ParityTensor: if isinstance(other, ParityTensor): self._validate_edge_compatibility(other) self.tensor -= other.tensor @@ -85,26 +108,31 @@ def __isub__(self, other) -> ParityTensor: self.tensor -= other return self - def __mul__(self, other) -> ParityTensor: - if isinstance(other, ParityTensor): # pylint: disable=no-else-return + def __mul__(self, other: typing.Any) -> ParityTensor: + if isinstance(other, ParityTensor): self._validate_edge_compatibility(other) return ParityTensor( edges=self.edges, tensor=self.tensor * other.tensor, ) - else: + result = self.tensor * other + if isinstance(result, torch.Tensor): return ParityTensor( edges=self.edges, - tensor=self.tensor * other, + tensor=result, ) + return NotImplemented - def __rmul__(self, other) -> ParityTensor: - return ParityTensor( - edges=self.edges, - tensor=other * self.tensor, - ) + def __rmul__(self, other: typing.Any) -> ParityTensor: + result = other * self.tensor + if isinstance(result, torch.Tensor): + return ParityTensor( + edges=self.edges, + tensor=result, + ) + return NotImplemented - def __imul__(self, other) -> ParityTensor: + def __imul__(self, other: typing.Any) -> ParityTensor: if isinstance(other, ParityTensor): self._validate_edge_compatibility(other) self.tensor *= other.tensor @@ -112,41 +140,34 @@ def __imul__(self, other) -> ParityTensor: self.tensor *= other return self - def __truediv__(self, other) -> ParityTensor: - if isinstance(other, ParityTensor): # pylint: disable=no-else-return + def __truediv__(self, other: typing.Any) -> ParityTensor: + if isinstance(other, ParityTensor): self._validate_edge_compatibility(other) return ParityTensor( edges=self.edges, tensor=self.tensor / other.tensor, ) - else: + result = self.tensor / other + if isinstance(result, torch.Tensor): return ParityTensor( edges=self.edges, - tensor=self.tensor / other, + tensor=result, ) + return NotImplemented - def __rtruediv__(self, other) -> ParityTensor: - return ParityTensor( - edges=self.edges, - tensor=other / self.tensor, - ) + def __rtruediv__(self, other: typing.Any) -> ParityTensor: + result = other / self.tensor + if isinstance(result, torch.Tensor): + return ParityTensor( + edges=self.edges, + tensor=result, + ) + return NotImplemented - def __itruediv__(self, other) -> ParityTensor: + def __itruediv__(self, other: typing.Any) -> ParityTensor: if isinstance(other, ParityTensor): self._validate_edge_compatibility(other) self.tensor /= other.tensor else: self.tensor /= other return self - - def __pos__(self) -> ParityTensor: - return ParityTensor( - edges=self.edges, - tensor=+self.tensor, - ) - - def __neg__(self) -> ParityTensor: - return ParityTensor( - edges=self.edges, - tensor=-self.tensor, - )