From ff090b71a5272ae2c11d32bb7110c92559993b79 Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Thu, 10 Nov 2022 07:59:55 -0600 Subject: [PATCH] dpnp_array must expose usm_type --- dpnp/dpnp_array.py | 4 ++++ tests/test_random_state.py | 26 ++++++++++++++++---------- tests/test_sycl_queue.py | 6 ++---- 3 files changed, 22 insertions(+), 14 deletions(-) diff --git a/dpnp/dpnp_array.py b/dpnp/dpnp_array.py index 1ac50e12c381..c3d35ab0e729 100644 --- a/dpnp/dpnp_array.py +++ b/dpnp/dpnp_array.py @@ -115,6 +115,10 @@ def sycl_context(self): def device(self): return self._array_obj.device + @property + def usm_type(self): + return self._array_obj.usm_type + def __abs__(self): return dpnp.abs(self) diff --git a/tests/test_random_state.py b/tests/test_random_state.py index c09b5c17a880..9d2f14643c84 100644 --- a/tests/test_random_state.py +++ b/tests/test_random_state.py @@ -15,6 +15,12 @@ ) +def assert_cfd(data, exp_sycl_queue, exp_usm_type=None): + assert exp_sycl_queue == data.sycl_queue + if exp_usm_type: + assert exp_usm_type == data.usm_type + + class TestNormal: @pytest.mark.parametrize("dtype", [dpnp.float32, dpnp.float64, None], @@ -47,7 +53,7 @@ def test_distr(self, dtype, usm_type): assert_array_almost_equal(dpnp.asnumpy(data), desired, decimal=precision) # check if compute follows data isn't broken - assert sycl_queue == data.sycl_queue + assert_cfd(data, sycl_queue, usm_type) @pytest.mark.parametrize("dtype", @@ -138,7 +144,7 @@ def test_fallback(self, loc, scale): assert_array_almost_equal(actual, desired, decimal=precision) # check if compute follows data isn't broken - assert sycl_queue == data.sycl_queue + assert_cfd(data, sycl_queue) @pytest.mark.parametrize("dtype", @@ -174,17 +180,17 @@ def test_distr(self, usm_type): precision = numpy.finfo(dtype=numpy.float64).precision assert_array_almost_equal(dpnp.asnumpy(data), desired, decimal=precision) - assert sycl_queue == data.sycl_queue + assert_cfd(data, sycl_queue, usm_type) # call with the same seed has to draw the same values data = RandomState(seed, sycl_queue=sycl_queue).rand(3, 2, usm_type=usm_type) assert_array_almost_equal(dpnp.asnumpy(data), desired, decimal=precision) - assert sycl_queue == data.sycl_queue + assert_cfd(data, sycl_queue, usm_type) # call with omitted dimensions has to draw the first element from desired data = RandomState(seed, sycl_queue=sycl_queue).rand(usm_type=usm_type) assert_array_almost_equal(dpnp.asnumpy(data), desired[0, 0], decimal=precision) - assert sycl_queue == data.sycl_queue + assert_cfd(data, sycl_queue, usm_type) # rand() is an alias on random_sample(), map arguments with mock.patch('dpnp.random.RandomState.random_sample') as m: @@ -245,7 +251,7 @@ def test_distr(self, dtype, usm_type): [5, 3], [5, 7]], dtype=numpy.int32) assert_array_equal(dpnp.asnumpy(data), desired) - assert sycl_queue == data.sycl_queue + assert_cfd(data, sycl_queue, usm_type) # call with the same seed has to draw the same values data = RandomState(seed, sycl_queue=sycl_queue).randint(low=low, @@ -254,7 +260,7 @@ def test_distr(self, dtype, usm_type): dtype=dtype, usm_type=usm_type) assert_array_equal(dpnp.asnumpy(data), desired) - assert sycl_queue == data.sycl_queue + assert_cfd(data, sycl_queue, usm_type) # call with omitted dimensions has to draw the first element from desired data = RandomState(seed, sycl_queue=sycl_queue).randint(low=low, @@ -262,7 +268,7 @@ def test_distr(self, dtype, usm_type): dtype=dtype, usm_type=usm_type) assert_array_equal(dpnp.asnumpy(data), desired[0, 0]) - assert sycl_queue == data.sycl_queue + assert_cfd(data, sycl_queue, usm_type) # rand() is an alias on random_sample(), map arguments with mock.patch('dpnp.random.RandomState.uniform') as m: @@ -701,7 +707,7 @@ def test_distr(self, bounds, dtype, usm_type): assert_array_equal(dpnp.asnumpy(data), desired) # check if compute follows data isn't broken - assert sycl_queue == data.sycl_queue + assert_cfd(data, sycl_queue, usm_type) @pytest.mark.parametrize("dtype", @@ -766,7 +772,7 @@ def test_fallback(self, low, high): assert_array_almost_equal(actual, desired, decimal=precision) # check if compute follows data isn't broken - assert sycl_queue == data.sycl_queue + assert_cfd(data, sycl_queue) @pytest.mark.parametrize("dtype", diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index e1d902588afe..e3e8680e6aca 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -278,8 +278,7 @@ def test_uniform(usm_type, size): high = 2.0 res = dpnp.random.uniform(low, high, size=size, usm_type=usm_type) - res_usm_type = res.get_array().usm_type - assert usm_type == res_usm_type + assert usm_type == res.usm_type @pytest.mark.parametrize("usm_type", @@ -295,8 +294,7 @@ def test_rs_uniform(usm_type, seed): rs = dpnp.random.RandomState(seed, sycl_queue=sycl_queue) res = rs.uniform(low, high, usm_type=usm_type) - res_usm_type = res.get_array().usm_type - assert usm_type == res_usm_type + assert usm_type == res.usm_type res_sycl_queue = res.get_array().sycl_queue assert_sycl_queue_equal(res_sycl_queue, sycl_queue)