From 1ed961d8a67388b04aca0f2a40c2a2bac12f0d04 Mon Sep 17 00:00:00 2001 From: vlad-perevezentsev Date: Tue, 17 Aug 2021 14:22:56 -0500 Subject: [PATCH 1/2] Add static method wait for SyclEventRaw class --- dpctl/_backend.pxd | 1 + dpctl/_sycl_event.pxd | 4 ++-- dpctl/_sycl_event.pyx | 17 +++++++++++++++-- dpctl/tests/test_sycl_event.py | 17 +++++++++++++++++ 4 files changed, 35 insertions(+), 4 deletions(-) diff --git a/dpctl/_backend.pxd b/dpctl/_backend.pxd index f485fd304f..4af83deb4e 100644 --- a/dpctl/_backend.pxd +++ b/dpctl/_backend.pxd @@ -226,6 +226,7 @@ cdef extern from "dpctl_sycl_event_interface.h": cdef DPCTLSyclEventRef DPCTLEvent_Create() cdef DPCTLSyclEventRef DPCTLEvent_Copy(const DPCTLSyclEventRef ERef) cdef void DPCTLEvent_Wait(DPCTLSyclEventRef ERef) + cdef void DPCTLEvent_WaitAndThrow(DPCTLSyclEventRef ERef) cdef void DPCTLEvent_Delete(DPCTLSyclEventRef ERef) cdef _event_status_type DPCTLEvent_GetCommandExecutionStatus(DPCTLSyclEventRef ERef) cdef _backend_type DPCTLEvent_GetBackend(DPCTLSyclEventRef ERef) diff --git a/dpctl/_sycl_event.pxd b/dpctl/_sycl_event.pxd index c30b64a660..8f385a711c 100644 --- a/dpctl/_sycl_event.pxd +++ b/dpctl/_sycl_event.pxd @@ -46,5 +46,5 @@ cdef public class SyclEventRaw(_SyclEventRaw) [object PySyclEventRawObject, type cdef int _init_event_from__SyclEventRaw(self, _SyclEventRaw other) cdef int _init_event_from_SyclEvent(self, SyclEvent event) cdef int _init_event_from_capsule(self, object caps) - cdef DPCTLSyclEventRef get_event_ref (self) - cpdef void wait (self) + cdef DPCTLSyclEventRef get_event_ref (self) + cdef void _wait (SyclEventRaw event) diff --git a/dpctl/_sycl_event.pyx b/dpctl/_sycl_event.pyx index 499949eff6..2bca5619b4 100644 --- a/dpctl/_sycl_event.pyx +++ b/dpctl/_sycl_event.pyx @@ -37,6 +37,7 @@ from ._backend cimport ( # noqa: E211 DPCTLEvent_GetProfilingInfoSubmit, DPCTLEvent_GetWaitList, DPCTLEvent_Wait, + DPCTLEvent_WaitAndThrow, DPCTLEventVector_Delete, DPCTLEventVector_GetAt, DPCTLEventVector_Size, @@ -192,8 +193,20 @@ cdef class SyclEventRaw(_SyclEventRaw): """ return self._event_ref - cpdef void wait(self): - DPCTLEvent_Wait(self._event_ref) + @staticmethod + cdef void _wait(SyclEventRaw event): + DPCTLEvent_WaitAndThrow(event._event_ref) + + @staticmethod + def wait(event): + if isinstance(event, list): + for e in event: + SyclEventRaw._wait(e) + elif isinstance(event, SyclEventRaw): + SyclEventRaw._wait(event) + else: + raise ValueError("The passed argument is not a list \ + or a SyclEventRaw type.") def addressof_ref(self): """ Returns the address of the C API `DPCTLSyclEventRef` pointer as diff --git a/dpctl/tests/test_sycl_event.py b/dpctl/tests/test_sycl_event.py index 4e0fe30f49..d1bae4ac0d 100644 --- a/dpctl/tests/test_sycl_event.py +++ b/dpctl/tests/test_sycl_event.py @@ -81,6 +81,23 @@ def test_create_event_raw_from_capsule(): pytest.fail("Failed to create an event from capsule") +def test_wait_with_event(): + event = dpctl.SyclEventRaw() + try: + dpctl.SyclEventRaw.wait(event) + except ValueError: + pytest.fail("Failed to wait for the event") + + +def test_wait_with_list(): + event_1 = dpctl.SyclEventRaw() + event_2 = dpctl.SyclEventRaw() + try: + dpctl.SyclEventRaw.wait([event_1, event_2]) + except ValueError: + pytest.fail("Failed to wait for events from the list") + + def test_execution_status(): event = dpctl.SyclEventRaw() try: From 21276ff2d5a17cd5106836a15248b6537f17c103 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Tue, 24 Aug 2021 09:48:55 -0500 Subject: [PATCH 2/2] Allow dpctl.SyclEventRaw.wait to take a sequence of events --- dpctl/_sycl_event.pyx | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/dpctl/_sycl_event.pyx b/dpctl/_sycl_event.pyx index 2bca5619b4..7cf529d4b3 100644 --- a/dpctl/_sycl_event.pyx +++ b/dpctl/_sycl_event.pyx @@ -25,6 +25,7 @@ import logging from cpython cimport pycapsule from libc.stdint cimport uint64_t +import collections.abc from ._backend cimport ( # noqa: E211 DPCTLEvent_Copy, @@ -199,14 +200,19 @@ cdef class SyclEventRaw(_SyclEventRaw): @staticmethod def wait(event): - if isinstance(event, list): + """ Waits for a given event or a sequence of events. + """ + if (isinstance(event, collections.abc.Sequence) and + all( (isinstance(el, SyclEventRaw) for el in event) )): for e in event: SyclEventRaw._wait(e) elif isinstance(event, SyclEventRaw): SyclEventRaw._wait(event) else: - raise ValueError("The passed argument is not a list \ - or a SyclEventRaw type.") + raise TypeError( + "The passed argument is not a SyclEventRaw type or " + "a sequence of such objects" + ) def addressof_ref(self): """ Returns the address of the C API `DPCTLSyclEventRef` pointer as