From d9398929606dc49ba83c400214793ec0370cde86 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 23 Jun 2023 18:05:04 +0200 Subject: [PATCH 1/2] Return default floating point type of the device when new_dtype=None --- dpctl/tensor/_copy_utils.py | 6 ++++-- dpctl/tests/test_usm_ndarray_ctor.py | 5 +++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/dpctl/tensor/_copy_utils.py b/dpctl/tensor/_copy_utils.py index 86d87ff46c..07c0bf3c56 100644 --- a/dpctl/tensor/_copy_utils.py +++ b/dpctl/tensor/_copy_utils.py @@ -23,6 +23,7 @@ import dpctl.tensor as dpt import dpctl.tensor._tensor_impl as ti import dpctl.utils +from dpctl.tensor._ctors import _get_dtype from dpctl.tensor._device import normalize_queue_device __doc__ = ( @@ -364,7 +365,8 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True): array (usm_ndarray): An input array. new_dtype (dtype): - The data type of the resulting array. + The data type of the resulting array. If None, gives default + floating point type supported by device where `array` is allocated. order ({"C", "F", "A", "K"}, optional): Controls memory layout of the resulting array if a copy is returned. @@ -392,7 +394,7 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True): "Recognized values are 'A', 'C', 'F', or 'K'" ) ary_dtype = usm_ary.dtype - target_dtype = dpt.dtype(newdtype) + target_dtype = _get_dtype(newdtype, usm_ary.sycl_queue) if not dpt.can_cast(ary_dtype, target_dtype, casting=casting): raise TypeError( f"Can not cast from {ary_dtype} to {newdtype} " diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index da2201c75f..67bc162b81 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -1196,6 +1196,11 @@ def test_astype(): assert np.allclose(dpt.to_numpy(Y), np.full(Y.shape, 7, dtype="f4")) Y = dpt.astype(X[::2, ::-1], "i4", order="K", copy=False) assert Y.usm_data is X.usm_data + Y = dpt.astype(X, None, order="K") + if X.sycl_queue.sycl_device.has_aspect_fp64: + assert Y.dtype is dpt.float64 + else: + assert Y.dtype is dpt.float32 def test_astype_invalid_order(): From 1518258147e539e2c1cc805bf011bf13affd2b32 Mon Sep 17 00:00:00 2001 From: vlad-perevezentsev Date: Mon, 26 Jun 2023 23:24:06 +0200 Subject: [PATCH 2/2] Fix Co-authored-by: Oleksandr Pavlyk --- dpctl/tensor/_copy_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpctl/tensor/_copy_utils.py b/dpctl/tensor/_copy_utils.py index 07c0bf3c56..c220b61b26 100644 --- a/dpctl/tensor/_copy_utils.py +++ b/dpctl/tensor/_copy_utils.py @@ -365,7 +365,7 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True): array (usm_ndarray): An input array. new_dtype (dtype): - The data type of the resulting array. If None, gives default + The data type of the resulting array. If `None`, gives default floating point type supported by device where `array` is allocated. order ({"C", "F", "A", "K"}, optional): Controls memory layout of the resulting array if a copy