Merged
91 changes: 32 additions & 59 deletions parity_tensor/parity_tensor.py
@@ -60,8 +60,8 @@ def to(self, device: torch.device) -> ParityTensor:
"""
Copy the tensor to a specified device.
"""
return ParityTensor(
_edges=self._edges,
return dataclasses.replace(
Copilot AI commented on Jul 31, 2025:
The dataclasses module is not imported. You need to add import dataclasses at the top of the file to use dataclasses.replace().
+            self,
             _tensor=self._tensor.to(device),
             _parity=tuple(p.to(device) for p in self._parity) if self._parity is not None else None,
             _mask=self._mask.to(device) if self._mask is not None else None,
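
Editorial note: the review comment above points at a real requirement of dataclasses.replace: the dataclasses module must be imported, and the class it is applied to must be a dataclass. The following is a minimal, self-contained sketch of that pattern; SimpleParity is a simplified stand-in for illustration only, not the actual ParityTensor definition from this repository.

import dataclasses

import torch


@dataclasses.dataclass(frozen=True)
class SimpleParity:
    # Simplified stand-in: the real ParityTensor has more fields (_parity, _mask, ...).
    _edges: tuple[int, ...]
    _tensor: torch.Tensor

    def __neg__(self) -> "SimpleParity":
        # dataclasses.replace copies every field of self and overrides only the
        # keyword arguments given here, so unchanged fields such as _edges do
        # not have to be repeated at each call site.
        return dataclasses.replace(self, _tensor=-self._tensor)


t = SimpleParity(_edges=(2, 2), _tensor=torch.ones(2, 2))
print((-t)._tensor)  # tensor of -1s; _edges is carried over unchanged

This is also why the diff below drops the explicit _edges, _parity, and _mask keyword arguments in the unary and binary operators: replace carries over any field that is not explicitly overridden.
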
@@ -96,7 +96,8 @@ def permute(self, before_by_after: tuple[int, ...]) -> ParityTensor:
         )
         tensor = torch.where(total_parity, -tensor, +tensor)

-        return ParityTensor(
+        return dataclasses.replace(
+            self,
             _edges=edges,
             _tensor=tensor,
             _parity=parity,
@@ -128,40 +129,32 @@ def _validate_edge_compatibility(self, other: ParityTensor) -> None:
         assert self._edges == other.edges, f"Edges must match for arithmetic operations. Got {self._edges} and {other.edges}."

     def __pos__(self) -> ParityTensor:
-        return ParityTensor(
-            _edges=self._edges,
+        return dataclasses.replace(
+            self,
             _tensor=+self._tensor,
-            _parity=self._parity,
-            _mask=self._mask,
         )

     def __neg__(self) -> ParityTensor:
-        return ParityTensor(
-            _edges=self._edges,
+        return dataclasses.replace(
+            self,
             _tensor=-self._tensor,
-            _parity=self._parity,
-            _mask=self._mask,
         )

     def __add__(self, other: typing.Any) -> ParityTensor:
         if isinstance(other, ParityTensor):
             self._validate_edge_compatibility(other)
-            return ParityTensor(
-                _edges=self._edges,
+            return dataclasses.replace(
+                self,
                 _tensor=self._tensor + other._tensor,
-                _parity=self._parity,
-                _mask=self._mask,
             )
         try:
             result = self._tensor + other
         except TypeError:
             return NotImplemented
         if isinstance(result, torch.Tensor):
-            return ParityTensor(
-                _edges=self._edges,
+            return dataclasses.replace(
+                self,
                 _tensor=result,
-                _parity=self._parity,
-                _mask=self._mask,
             )
         return NotImplemented

@@ -171,11 +164,9 @@ def __radd__(self, other: typing.Any) -> ParityTensor:
         except TypeError:
             return NotImplemented
         if isinstance(result, torch.Tensor):
-            return ParityTensor(
-                _edges=self._edges,
+            return dataclasses.replace(
+                self,
                 _tensor=result,
-                _parity=self._parity,
-                _mask=self._mask,
             )
         return NotImplemented

@@ -190,22 +181,18 @@ def __iadd__(self, other: typing.Any) -> ParityTensor:
     def __sub__(self, other: typing.Any) -> ParityTensor:
         if isinstance(other, ParityTensor):
             self._validate_edge_compatibility(other)
-            return ParityTensor(
-                _edges=self._edges,
+            return dataclasses.replace(
+                self,
                 _tensor=self._tensor - other._tensor,
-                _parity=self._parity,
-                _mask=self._mask,
             )
         try:
             result = self._tensor - other
         except TypeError:
             return NotImplemented
         if isinstance(result, torch.Tensor):
-            return ParityTensor(
-                _edges=self._edges,
+            return dataclasses.replace(
+                self,
                 _tensor=result,
-                _parity=self._parity,
-                _mask=self._mask,
             )
         return NotImplemented

@@ -215,11 +202,9 @@ def __rsub__(self, other: typing.Any) -> ParityTensor:
         except TypeError:
             return NotImplemented
         if isinstance(result, torch.Tensor):
-            return ParityTensor(
-                _edges=self._edges,
+            return dataclasses.replace(
+                self,
                 _tensor=result,
-                _parity=self._parity,
-                _mask=self._mask,
             )
         return NotImplemented

@@ -234,22 +219,18 @@ def __isub__(self, other: typing.Any) -> ParityTensor:
     def __mul__(self, other: typing.Any) -> ParityTensor:
         if isinstance(other, ParityTensor):
             self._validate_edge_compatibility(other)
-            return ParityTensor(
-                _edges=self._edges,
+            return dataclasses.replace(
+                self,
                 _tensor=self._tensor * other._tensor,
-                _parity=self._parity,
-                _mask=self._mask,
             )
         try:
             result = self._tensor * other
         except TypeError:
             return NotImplemented
         if isinstance(result, torch.Tensor):
-            return ParityTensor(
-                _edges=self._edges,
+            return dataclasses.replace(
+                self,
                 _tensor=result,
-                _parity=self._parity,
-                _mask=self._mask,
             )
         return NotImplemented

@@ -259,11 +240,9 @@ def __rmul__(self, other: typing.Any) -> ParityTensor:
         except TypeError:
             return NotImplemented
         if isinstance(result, torch.Tensor):
-            return ParityTensor(
-                _edges=self._edges,
+            return dataclasses.replace(
+                self,
                 _tensor=result,
-                _parity=self._parity,
-                _mask=self._mask,
             )
         return NotImplemented

@@ -278,22 +257,18 @@ def __imul__(self, other: typing.Any) -> ParityTensor:
     def __truediv__(self, other: typing.Any) -> ParityTensor:
         if isinstance(other, ParityTensor):
             self._validate_edge_compatibility(other)
-            return ParityTensor(
-                _edges=self._edges,
+            return dataclasses.replace(
+                self,
                 _tensor=self._tensor / other._tensor,
-                _parity=self._parity,
-                _mask=self._mask,
             )
         try:
             result = self._tensor / other
         except TypeError:
             return NotImplemented
         if isinstance(result, torch.Tensor):
-            return ParityTensor(
-                _edges=self._edges,
+            return dataclasses.replace(
+                self,
                 _tensor=result,
-                _parity=self._parity,
-                _mask=self._mask,
             )
         return NotImplemented

@@ -303,11 +278,9 @@ def __rtruediv__(self, other: typing.Any) -> ParityTensor:
         except TypeError:
             return NotImplemented
         if isinstance(result, torch.Tensor):
-            return ParityTensor(
-                _edges=self._edges,
+            return dataclasses.replace(
+                self,
                 _tensor=result,
-                _parity=self._parity,
-                _mask=self._mask,
             )
         return NotImplemented
