Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 37 additions & 8 deletions parity_tensor/parity_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,45 @@ def mask(self) -> torch.Tensor:
self._mask = self._tensor_mask()
return self._mask

def to(self, device: torch.device) -> ParityTensor:
def to(self, whatever: torch.device | torch.dtype | str | None = None, *, device: torch.device | None = None, dtype: torch.dtype | None = None) -> ParityTensor:
"""
Copy the tensor to a specified device.
Copy the tensor to a specified device or copy it to a specified data type.
"""
return dataclasses.replace(
self,
_tensor=self._tensor.to(device),
_parity=tuple(p.to(device) for p in self._parity) if self._parity is not None else None,
_mask=self._mask.to(device) if self._mask is not None else None,
)
if whatever is None:
pass
elif isinstance(whatever, torch.device):
assert device is None, "Duplicate device specification."
device = whatever
elif isinstance(whatever, torch.dtype):
assert dtype is None, "Duplicate dtype specification."
dtype = whatever
elif isinstance(whatever, str):
assert device is None, "Duplicate device specification."
device = torch.device(whatever)
else:
raise TypeError(f"Unsupported type for 'to': {type(whatever)}. Expected torch.device, torch.dtype, or str.")
Copy link

Copilot AI Jul 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error message references the parameter name 'whatever' which is exposed to users. This should use a more appropriate parameter name or generic description like 'first argument' to maintain professionalism in user-facing error messages.

Suggested change
raise TypeError(f"Unsupported type for 'to': {type(whatever)}. Expected torch.device, torch.dtype, or str.")
raise TypeError(f"Unsupported type for the first argument of 'to': {type(whatever)}. Expected torch.device, torch.dtype, or str.")

Copilot uses AI. Check for mistakes.
match (device, dtype):
case (None, None):
return self
case (None, _):
return dataclasses.replace(
self,
_tensor=self._tensor.to(dtype=dtype),
)
case (_, None):
return dataclasses.replace(
self,
_tensor=self._tensor.to(device=device),
_parity=tuple(p.to(device) for p in self._parity) if self._parity is not None else None,
_mask=self._mask.to(device) if self._mask is not None else None,
)
case _:
return dataclasses.replace(
self,
_tensor=self._tensor.to(device=device, dtype=dtype),
_parity=tuple(p.to(device=device) for p in self._parity) if self._parity is not None else None,
_mask=self._mask.to(device=device) if self._mask is not None else None,
)

def update_mask(self) -> ParityTensor:
"""
Expand Down
56 changes: 56 additions & 0 deletions tests/conversion_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import typing
import pytest
import torch
from parity_tensor import ParityTensor


@pytest.fixture()
def x() -> ParityTensor:
return ParityTensor((False, False), ((2, 2), (1, 3)), torch.randn([4, 4]))


@pytest.mark.parametrize("dtype_arg", ["position", "keyword", "none"])
@pytest.mark.parametrize("device_arg", ["position", "keyword", "none"])
@pytest.mark.parametrize("device_format", ["object", "string"])
def test_conversion(
x: ParityTensor,
dtype_arg: typing.Literal["position", "keyword", "none"],
device_arg: typing.Literal["position", "keyword", "none"],
device_format: typing.Literal["object", "string"],
) -> None:
args: list[typing.Any] = []
kwargs: dict[str, typing.Any] = {}

device = torch.device("cpu") if device_format == "object" else "cpu"
match device_arg:
case "position":
args.append(device)
case "keyword":
kwargs["device"] = device
case _:
pass

match dtype_arg:
case "position":
args.append(torch.complex128)
case "keyword":
kwargs["dtype"] = torch.complex128
case _:
pass

if len(args) <= 1:
y = x.to(*args, **kwargs)


def test_conversion_invalid_type(x: ParityTensor) -> None:
with pytest.raises(TypeError):
x.to(2333) # type: ignore[arg-type]


def test_conversion_duplicated_value(x: ParityTensor) -> None:
with pytest.raises(AssertionError):
x.to(torch.device("cpu"), device=torch.device("cpu"))
with pytest.raises(AssertionError):
x.to(torch.complex128, dtype=torch.complex128)
with pytest.raises(AssertionError):
x.to("cpu", device=torch.device("cpu"))