diff --git a/dpctl/tensor/_usmarray.pyx b/dpctl/tensor/_usmarray.pyx index 208975f970..e21b46e820 100644 --- a/dpctl/tensor/_usmarray.pyx +++ b/dpctl/tensor/_usmarray.pyx @@ -130,12 +130,17 @@ cdef class usm_ndarray: cdef Py_ssize_t _offset = offset cdef Py_ssize_t ary_min_displacement = 0 cdef Py_ssize_t ary_max_displacement = 0 + cdef Py_ssize_t tmp = 0 cdef char * data_ptr = NULL self._reset() if (not isinstance(shape, (list, tuple)) and not hasattr(shape, 'tolist')): - raise TypeError("Argument shape must be a list of a tuple.") + try: + tmp = shape + shape = [shape, ] + except Exception: + raise TypeError("Argument shape must be a list or a tuple.") nd = len(shape) typenum = dtype_to_typenum(dtype) itemsize = type_bytesize(typenum) diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index 6d565d8c88..e1cd9e9795 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -38,6 +38,7 @@ (4, 5), (2, 5, 2), (2, 2, 2, 2, 2, 2, 2, 2), + 5, ], ) @pytest.mark.parametrize("usm_type", ["shared", "host", "device"]) @@ -74,6 +75,7 @@ def test_allocate_usm_ndarray(shape, usm_type): "f8", "c8", "c16", + b"float32", np.dtype("d"), np.half, ], @@ -81,6 +83,15 @@ def test_allocate_usm_ndarray(shape, usm_type): def test_dtypes(dtype): Xusm = dpt.usm_ndarray((1,), dtype=dtype) assert Xusm.itemsize == np.dtype(dtype).itemsize + expected_fmt = (np.dtype(dtype).str)[1:] + actual_fmt = Xusm.__sycl_usm_array_interface__["typestr"][1:] + assert expected_fmt == actual_fmt + + +@pytest.mark.parametrize("dtype", ["", ">f4", "invalid", 123]) +def test_dtypes_invalid(dtype): + with pytest.raises((TypeError, ValueError)): + dpt.usm_ndarray((1,), dtype=dtype) def test_properties():