Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 120 additions & 54 deletions parity_tensor/parity_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,54 @@ class ParityTensor:
Each dimension of the tensor is composed of an even and an odd part, represented as a pair of integers.
"""

edges: tuple[tuple[int, int], ...]
tensor: torch.Tensor
mask: torch.Tensor | None = None
_edges: tuple[tuple[int, int], ...]
_tensor: torch.Tensor
_parity: tuple[torch.Tensor, ...] | None = None
_mask: torch.Tensor | None = None

@property
def edges(self) -> tuple[tuple[int, int], ...]:
"""
The edges of the tensor, represented as a tuple of pairs (even, odd).
"""
return self._edges

@property
def tensor(self) -> torch.Tensor:
"""
The underlying tensor data.
"""
return self._tensor

@property
def parity(self) -> tuple[torch.Tensor, ...]:
"""
The parity of each edge, represented as a tuple of tensors.
"""
if self._parity is None:
self._parity = tuple(self._edge_mask(even, odd) for (even, odd) in self._edges)
return self._parity

@property
def mask(self) -> torch.Tensor:
"""
The mask of the tensor, which has the same shape as the tensor and indicates which elements could be non-zero based on the parity.
"""
if self._mask is None:
self._mask = self._tensor_mask()
return self._mask

def update_mask(self) -> ParityTensor:
"""
Update the mask of the tensor based on its parity.
"""
self._tensor = torch.where(self.mask, self._tensor, 0)
return self

def __post_init__(self) -> None:
assert len(self.edges) == self.tensor.dim(), f"Edges length ({len(self.edges)}) must match tensor dimensions ({self.tensor.dim()})."
for dim, (even, odd) in zip(self.tensor.shape, self.edges):
assert len(self._edges) == self._tensor.dim(), f"Edges length ({len(self._edges)}) must match tensor dimensions ({self._tensor.dim()})."
for dim, (even, odd) in zip(self._tensor.shape, self._edges):
assert even >= 0 and odd >= 0 and dim == even + odd, f"Dimension {dim} must equal sum of even ({even}) and odd ({odd}) parts, and both must be non-negative."
if self.mask is None:
self.mask = self._tensor_mask()

@classmethod
def _unqueeze(cls, tensor: torch.Tensor, index: int, dim: int) -> torch.Tensor:
Expand All @@ -41,176 +79,204 @@ def _edge_mask(cls, even: int, odd: int) -> torch.Tensor:
def _tensor_mask(self) -> torch.Tensor:
return functools.reduce(
torch.logical_xor,
(self._unqueeze(self._edge_mask(even, odd), index, self.tensor.dim()) for index, (even, odd) in enumerate(self.edges)),
torch.ones_like(self.tensor, dtype=torch.bool),
(self._unqueeze(parity, index, self._tensor.dim()) for index, parity in enumerate(self.parity)),
torch.ones_like(self._tensor, dtype=torch.bool),
)

def _validate_edge_compatibility(self, other: ParityTensor) -> None:
"""
Validate that the edges of two ParityTensor instances are compatible for arithmetic operations.
"""
assert self.edges == other.edges, f"Edges must match for arithmetic operations. Got {self.edges} and {other.edges}."
assert self._edges == other.edges, f"Edges must match for arithmetic operations. Got {self._edges} and {other.edges}."

def __pos__(self) -> ParityTensor:
return ParityTensor(
edges=self.edges,
tensor=+self.tensor,
_edges=self._edges,
_tensor=+self._tensor,
_parity=self._parity,
_mask=self._mask,
)

def __neg__(self) -> ParityTensor:
return ParityTensor(
edges=self.edges,
tensor=-self.tensor,
_edges=self._edges,
_tensor=-self._tensor,
_parity=self._parity,
_mask=self._mask,
)

def __add__(self, other: typing.Any) -> ParityTensor:
if isinstance(other, ParityTensor):
self._validate_edge_compatibility(other)
return ParityTensor(
edges=self.edges,
tensor=self.tensor + other.tensor,
_edges=self._edges,
_tensor=self._tensor + other._tensor,
_parity=self._parity,
_mask=self._mask,
Comment on lines 111 to +115
Copy link

Copilot AI Jul 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When adding two ParityTensor instances, copying only the left operand's parity and mask may be incorrect. The result should have its own computed parity and mask, or both operands' cached values should be considered if they're compatible.

Copilot uses AI. Check for mistakes.
Comment on lines 111 to +115
Copy link

Copilot AI Jul 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to parity, copying only the left operand's mask in addition operations may be incorrect. The mask should be recomputed for the result or properly validated against the right operand's mask.

Copilot uses AI. Check for mistakes.
Comment on lines 111 to +115
Copy link

Copilot AI Jul 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The mask from self is being copied to the result of addition with another ParityTensor, but the result tensor may have different values that could invalidate this mask. The mask should be recomputed or set to None to ensure correctness.

Copilot uses AI. Check for mistakes.
)
try:
result = self.tensor + other
result = self._tensor + other
except TypeError:
return NotImplemented
if isinstance(result, torch.Tensor):
return ParityTensor(
edges=self.edges,
tensor=result,
_edges=self._edges,
_tensor=result,
_parity=self._parity,
_mask=self._mask,
)
return NotImplemented

def __radd__(self, other: typing.Any) -> ParityTensor:
try:
result = other + self.tensor
result = other + self._tensor
except TypeError:
return NotImplemented
if isinstance(result, torch.Tensor):
return ParityTensor(
edges=self.edges,
tensor=result,
_edges=self._edges,
_tensor=result,
_parity=self._parity,
_mask=self._mask,
)
return NotImplemented

def __iadd__(self, other: typing.Any) -> ParityTensor:
if isinstance(other, ParityTensor):
self._validate_edge_compatibility(other)
self.tensor += other.tensor
self._tensor += other._tensor
else:
self.tensor += other
self._tensor += other
return self

def __sub__(self, other: typing.Any) -> ParityTensor:
if isinstance(other, ParityTensor):
self._validate_edge_compatibility(other)
return ParityTensor(
edges=self.edges,
tensor=self.tensor - other.tensor,
_edges=self._edges,
_tensor=self._tensor - other._tensor,
_parity=self._parity,
_mask=self._mask,
Copy link

Copilot AI Jul 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The mask from self is being copied to the result of subtraction with another ParityTensor, but the result tensor may have different values that could invalidate this mask. The mask should be recomputed or set to None to ensure correctness.

Suggested change
_mask=self._mask,
_mask=None,

Copilot uses AI. Check for mistakes.
)
try:
result = self.tensor - other
result = self._tensor - other
except TypeError:
return NotImplemented
if isinstance(result, torch.Tensor):
return ParityTensor(
edges=self.edges,
tensor=result,
_edges=self._edges,
_tensor=result,
_parity=self._parity,
_mask=self._mask,
)
return NotImplemented

def __rsub__(self, other: typing.Any) -> ParityTensor:
try:
result = other - self.tensor
result = other - self._tensor
except TypeError:
return NotImplemented
if isinstance(result, torch.Tensor):
return ParityTensor(
edges=self.edges,
tensor=result,
_edges=self._edges,
_tensor=result,
_parity=self._parity,
_mask=self._mask,
)
return NotImplemented

def __isub__(self, other: typing.Any) -> ParityTensor:
if isinstance(other, ParityTensor):
self._validate_edge_compatibility(other)
self.tensor -= other.tensor
self._tensor -= other._tensor
else:
self.tensor -= other
self._tensor -= other
return self

def __mul__(self, other: typing.Any) -> ParityTensor:
if isinstance(other, ParityTensor):
self._validate_edge_compatibility(other)
return ParityTensor(
edges=self.edges,
tensor=self.tensor * other.tensor,
_edges=self._edges,
_tensor=self._tensor * other._tensor,
_parity=self._parity,
_mask=self._mask,
Copy link

Copilot AI Jul 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The mask from self is being copied to the result of multiplication with another ParityTensor, but element-wise multiplication can change which elements should be zero according to the parity rules. The mask should be recomputed or set to None to ensure correctness.

Suggested change
_mask=self._mask,
_mask=None,

Copilot uses AI. Check for mistakes.
)
try:
result = self.tensor * other
result = self._tensor * other
except TypeError:
return NotImplemented
if isinstance(result, torch.Tensor):
return ParityTensor(
edges=self.edges,
tensor=result,
_edges=self._edges,
_tensor=result,
_parity=self._parity,
_mask=self._mask,
)
return NotImplemented

def __rmul__(self, other: typing.Any) -> ParityTensor:
try:
result = other * self.tensor
result = other * self._tensor
except TypeError:
return NotImplemented
if isinstance(result, torch.Tensor):
return ParityTensor(
edges=self.edges,
tensor=result,
_edges=self._edges,
_tensor=result,
_parity=self._parity,
_mask=self._mask,
)
return NotImplemented

def __imul__(self, other: typing.Any) -> ParityTensor:
if isinstance(other, ParityTensor):
self._validate_edge_compatibility(other)
self.tensor *= other.tensor
self._tensor *= other._tensor
else:
self.tensor *= other
self._tensor *= other
return self

def __truediv__(self, other: typing.Any) -> ParityTensor:
if isinstance(other, ParityTensor):
self._validate_edge_compatibility(other)
return ParityTensor(
edges=self.edges,
tensor=self.tensor / other.tensor,
_edges=self._edges,
_tensor=self._tensor / other._tensor,
_parity=self._parity,
_mask=self._mask,
Copy link

Copilot AI Jul 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The mask from self is being copied to the result of division with another ParityTensor, but division can change which elements should be zero according to the parity rules. The mask should be recomputed or set to None to ensure correctness.

Suggested change
_mask=self._mask,
_mask=None,

Copilot uses AI. Check for mistakes.
)
try:
result = self.tensor / other
result = self._tensor / other
except TypeError:
return NotImplemented
if isinstance(result, torch.Tensor):
return ParityTensor(
edges=self.edges,
tensor=result,
_edges=self._edges,
_tensor=result,
_parity=self._parity,
_mask=self._mask,
)
return NotImplemented

def __rtruediv__(self, other: typing.Any) -> ParityTensor:
try:
result = other / self.tensor
result = other / self._tensor
except TypeError:
return NotImplemented
if isinstance(result, torch.Tensor):
return ParityTensor(
edges=self.edges,
tensor=result,
_edges=self._edges,
_tensor=result,
_parity=self._parity,
_mask=self._mask,
)
return NotImplemented

def __itruediv__(self, other: typing.Any) -> ParityTensor:
if isinstance(other, ParityTensor):
self._validate_edge_compatibility(other)
self.tensor /= other.tensor
self._tensor /= other._tensor
else:
self.tensor /= other
self._tensor /= other
return self