diff --git a/parity_tensor/parity_tensor.py b/parity_tensor/parity_tensor.py index 0b95554..5b7fec1 100644 --- a/parity_tensor/parity_tensor.py +++ b/parity_tensor/parity_tensor.py @@ -64,16 +64,45 @@ def mask(self) -> torch.Tensor: self._mask = self._tensor_mask() return self._mask - def to(self, device: torch.device) -> ParityTensor: + def to(self, whatever: torch.device | torch.dtype | str | None = None, *, device: torch.device | None = None, dtype: torch.dtype | None = None) -> ParityTensor: """ - Copy the tensor to a specified device. + Copy the tensor to a specified device or copy it to a specified data type. """ - return dataclasses.replace( - self, - _tensor=self._tensor.to(device), - _parity=tuple(p.to(device) for p in self._parity) if self._parity is not None else None, - _mask=self._mask.to(device) if self._mask is not None else None, - ) + if whatever is None: + pass + elif isinstance(whatever, torch.device): + assert device is None, "Duplicate device specification." + device = whatever + elif isinstance(whatever, torch.dtype): + assert dtype is None, "Duplicate dtype specification." + dtype = whatever + elif isinstance(whatever, str): + assert device is None, "Duplicate device specification." + device = torch.device(whatever) + else: + raise TypeError(f"Unsupported type for 'to': {type(whatever)}. Expected torch.device, torch.dtype, or str.") + match (device, dtype): + case (None, None): + return self + case (None, _): + return dataclasses.replace( + self, + _tensor=self._tensor.to(dtype=dtype), + ) + case (_, None): + return dataclasses.replace( + self, + _tensor=self._tensor.to(device=device), + _parity=tuple(p.to(device) for p in self._parity) if self._parity is not None else None, + _mask=self._mask.to(device) if self._mask is not None else None, + ) + case _: + return dataclasses.replace( + self, + _tensor=self._tensor.to(device=device, dtype=dtype), + _parity=tuple(p.to(device=device) for p in self._parity) if self._parity is not None else None, + _mask=self._mask.to(device=device) if self._mask is not None else None, + ) def update_mask(self) -> ParityTensor: """ diff --git a/tests/conversion_test.py b/tests/conversion_test.py new file mode 100644 index 0000000..ea2e303 --- /dev/null +++ b/tests/conversion_test.py @@ -0,0 +1,56 @@ +import typing +import pytest +import torch +from parity_tensor import ParityTensor + + +@pytest.fixture() +def x() -> ParityTensor: + return ParityTensor((False, False), ((2, 2), (1, 3)), torch.randn([4, 4])) + + +@pytest.mark.parametrize("dtype_arg", ["position", "keyword", "none"]) +@pytest.mark.parametrize("device_arg", ["position", "keyword", "none"]) +@pytest.mark.parametrize("device_format", ["object", "string"]) +def test_conversion( + x: ParityTensor, + dtype_arg: typing.Literal["position", "keyword", "none"], + device_arg: typing.Literal["position", "keyword", "none"], + device_format: typing.Literal["object", "string"], +) -> None: + args: list[typing.Any] = [] + kwargs: dict[str, typing.Any] = {} + + device = torch.device("cpu") if device_format == "object" else "cpu" + match device_arg: + case "position": + args.append(device) + case "keyword": + kwargs["device"] = device + case _: + pass + + match dtype_arg: + case "position": + args.append(torch.complex128) + case "keyword": + kwargs["dtype"] = torch.complex128 + case _: + pass + + if len(args) <= 1: + y = x.to(*args, **kwargs) + + +def test_conversion_invalid_type(x: ParityTensor) -> None: + with pytest.raises(TypeError): + x.to(2333) # type: ignore[arg-type] + + +def test_conversion_duplicated_value(x: ParityTensor) -> None: + with pytest.raises(AssertionError): + x.to(torch.device("cpu"), device=torch.device("cpu")) + with pytest.raises(AssertionError): + x.to(torch.complex128, dtype=torch.complex128) + with pytest.raises(AssertionError): + x.to("cpu", device=torch.device("cpu"))