From 34a704e9f16b0effd2e3055ee6c2ec6af6f50cc0 Mon Sep 17 00:00:00 2001 From: Hao Zhang Date: Thu, 31 Jul 2025 09:51:58 +0800 Subject: [PATCH] Use dataclasses.replace to simplify coding. --- parity_tensor/parity_tensor.py | 91 ++++++++++++---------------------- 1 file changed, 32 insertions(+), 59 deletions(-) diff --git a/parity_tensor/parity_tensor.py b/parity_tensor/parity_tensor.py index 7d2db33..aad7d57 100644 --- a/parity_tensor/parity_tensor.py +++ b/parity_tensor/parity_tensor.py @@ -60,8 +60,8 @@ def to(self, device: torch.device) -> ParityTensor: """ Copy the tensor to a specified device. """ - return ParityTensor( - _edges=self._edges, + return dataclasses.replace( + self, _tensor=self._tensor.to(device), _parity=tuple(p.to(device) for p in self._parity) if self._parity is not None else None, _mask=self._mask.to(device) if self._mask is not None else None, @@ -96,7 +96,8 @@ def permute(self, before_by_after: tuple[int, ...]) -> ParityTensor: ) tensor = torch.where(total_parity, -tensor, +tensor) - return ParityTensor( + return dataclasses.replace( + self, _edges=edges, _tensor=tensor, _parity=parity, @@ -128,40 +129,32 @@ def _validate_edge_compatibility(self, other: ParityTensor) -> None: assert self._edges == other.edges, f"Edges must match for arithmetic operations. Got {self._edges} and {other.edges}." def __pos__(self) -> ParityTensor: - return ParityTensor( - _edges=self._edges, + return dataclasses.replace( + self, _tensor=+self._tensor, - _parity=self._parity, - _mask=self._mask, ) def __neg__(self) -> ParityTensor: - return ParityTensor( - _edges=self._edges, + return dataclasses.replace( + self, _tensor=-self._tensor, - _parity=self._parity, - _mask=self._mask, ) def __add__(self, other: typing.Any) -> ParityTensor: if isinstance(other, ParityTensor): self._validate_edge_compatibility(other) - return ParityTensor( - _edges=self._edges, + return dataclasses.replace( + self, _tensor=self._tensor + other._tensor, - _parity=self._parity, - _mask=self._mask, ) try: result = self._tensor + other except TypeError: return NotImplemented if isinstance(result, torch.Tensor): - return ParityTensor( - _edges=self._edges, + return dataclasses.replace( + self, _tensor=result, - _parity=self._parity, - _mask=self._mask, ) return NotImplemented @@ -171,11 +164,9 @@ def __radd__(self, other: typing.Any) -> ParityTensor: except TypeError: return NotImplemented if isinstance(result, torch.Tensor): - return ParityTensor( - _edges=self._edges, + return dataclasses.replace( + self, _tensor=result, - _parity=self._parity, - _mask=self._mask, ) return NotImplemented @@ -190,22 +181,18 @@ def __iadd__(self, other: typing.Any) -> ParityTensor: def __sub__(self, other: typing.Any) -> ParityTensor: if isinstance(other, ParityTensor): self._validate_edge_compatibility(other) - return ParityTensor( - _edges=self._edges, + return dataclasses.replace( + self, _tensor=self._tensor - other._tensor, - _parity=self._parity, - _mask=self._mask, ) try: result = self._tensor - other except TypeError: return NotImplemented if isinstance(result, torch.Tensor): - return ParityTensor( - _edges=self._edges, + return dataclasses.replace( + self, _tensor=result, - _parity=self._parity, - _mask=self._mask, ) return NotImplemented @@ -215,11 +202,9 @@ def __rsub__(self, other: typing.Any) -> ParityTensor: except TypeError: return NotImplemented if isinstance(result, torch.Tensor): - return ParityTensor( - _edges=self._edges, + return dataclasses.replace( + self, _tensor=result, - _parity=self._parity, - _mask=self._mask, ) return NotImplemented @@ -234,22 +219,18 @@ def __isub__(self, other: typing.Any) -> ParityTensor: def __mul__(self, other: typing.Any) -> ParityTensor: if isinstance(other, ParityTensor): self._validate_edge_compatibility(other) - return ParityTensor( - _edges=self._edges, + return dataclasses.replace( + self, _tensor=self._tensor * other._tensor, - _parity=self._parity, - _mask=self._mask, ) try: result = self._tensor * other except TypeError: return NotImplemented if isinstance(result, torch.Tensor): - return ParityTensor( - _edges=self._edges, + return dataclasses.replace( + self, _tensor=result, - _parity=self._parity, - _mask=self._mask, ) return NotImplemented @@ -259,11 +240,9 @@ def __rmul__(self, other: typing.Any) -> ParityTensor: except TypeError: return NotImplemented if isinstance(result, torch.Tensor): - return ParityTensor( - _edges=self._edges, + return dataclasses.replace( + self, _tensor=result, - _parity=self._parity, - _mask=self._mask, ) return NotImplemented @@ -278,22 +257,18 @@ def __imul__(self, other: typing.Any) -> ParityTensor: def __truediv__(self, other: typing.Any) -> ParityTensor: if isinstance(other, ParityTensor): self._validate_edge_compatibility(other) - return ParityTensor( - _edges=self._edges, + return dataclasses.replace( + self, _tensor=self._tensor / other._tensor, - _parity=self._parity, - _mask=self._mask, ) try: result = self._tensor / other except TypeError: return NotImplemented if isinstance(result, torch.Tensor): - return ParityTensor( - _edges=self._edges, + return dataclasses.replace( + self, _tensor=result, - _parity=self._parity, - _mask=self._mask, ) return NotImplemented @@ -303,11 +278,9 @@ def __rtruediv__(self, other: typing.Any) -> ParityTensor: except TypeError: return NotImplemented if isinstance(result, torch.Tensor): - return ParityTensor( - _edges=self._edges, + return dataclasses.replace( + self, _tensor=result, - _parity=self._parity, - _mask=self._mask, ) return NotImplemented