diff --git a/dpctl/_backend.pxd b/dpctl/_backend.pxd index f485fd304f..4af83deb4e 100644 --- a/dpctl/_backend.pxd +++ b/dpctl/_backend.pxd @@ -226,6 +226,7 @@ cdef extern from "dpctl_sycl_event_interface.h": cdef DPCTLSyclEventRef DPCTLEvent_Create() cdef DPCTLSyclEventRef DPCTLEvent_Copy(const DPCTLSyclEventRef ERef) cdef void DPCTLEvent_Wait(DPCTLSyclEventRef ERef) + cdef void DPCTLEvent_WaitAndThrow(DPCTLSyclEventRef ERef) cdef void DPCTLEvent_Delete(DPCTLSyclEventRef ERef) cdef _event_status_type DPCTLEvent_GetCommandExecutionStatus(DPCTLSyclEventRef ERef) cdef _backend_type DPCTLEvent_GetBackend(DPCTLSyclEventRef ERef) diff --git a/dpctl/_sycl_event.pxd b/dpctl/_sycl_event.pxd index c30b64a660..8f385a711c 100644 --- a/dpctl/_sycl_event.pxd +++ b/dpctl/_sycl_event.pxd @@ -46,5 +46,5 @@ cdef public class SyclEventRaw(_SyclEventRaw) [object PySyclEventRawObject, type cdef int _init_event_from__SyclEventRaw(self, _SyclEventRaw other) cdef int _init_event_from_SyclEvent(self, SyclEvent event) cdef int _init_event_from_capsule(self, object caps) - cdef DPCTLSyclEventRef get_event_ref (self) - cpdef void wait (self) + cdef DPCTLSyclEventRef get_event_ref (self) + cdef void _wait (SyclEventRaw event) diff --git a/dpctl/_sycl_event.pyx b/dpctl/_sycl_event.pyx index 499949eff6..7cf529d4b3 100644 --- a/dpctl/_sycl_event.pyx +++ b/dpctl/_sycl_event.pyx @@ -25,6 +25,7 @@ import logging from cpython cimport pycapsule from libc.stdint cimport uint64_t +import collections.abc from ._backend cimport ( # noqa: E211 DPCTLEvent_Copy, @@ -37,6 +38,7 @@ from ._backend cimport ( # noqa: E211 DPCTLEvent_GetProfilingInfoSubmit, DPCTLEvent_GetWaitList, DPCTLEvent_Wait, + DPCTLEvent_WaitAndThrow, DPCTLEventVector_Delete, DPCTLEventVector_GetAt, DPCTLEventVector_Size, @@ -192,8 +194,25 @@ cdef class SyclEventRaw(_SyclEventRaw): """ return self._event_ref - cpdef void wait(self): - DPCTLEvent_Wait(self._event_ref) + @staticmethod + cdef void _wait(SyclEventRaw event): + DPCTLEvent_WaitAndThrow(event._event_ref) + + @staticmethod + def wait(event): + """ Waits for a given event or a sequence of events. + """ + if (isinstance(event, collections.abc.Sequence) and + all( (isinstance(el, SyclEventRaw) for el in event) )): + for e in event: + SyclEventRaw._wait(e) + elif isinstance(event, SyclEventRaw): + SyclEventRaw._wait(event) + else: + raise TypeError( + "The passed argument is not a SyclEventRaw type or " + "a sequence of such objects" + ) def addressof_ref(self): """ Returns the address of the C API `DPCTLSyclEventRef` pointer as diff --git a/dpctl/tests/test_sycl_event.py b/dpctl/tests/test_sycl_event.py index 4e0fe30f49..d1bae4ac0d 100644 --- a/dpctl/tests/test_sycl_event.py +++ b/dpctl/tests/test_sycl_event.py @@ -81,6 +81,23 @@ def test_create_event_raw_from_capsule(): pytest.fail("Failed to create an event from capsule") +def test_wait_with_event(): + event = dpctl.SyclEventRaw() + try: + dpctl.SyclEventRaw.wait(event) + except ValueError: + pytest.fail("Failed to wait for the event") + + +def test_wait_with_list(): + event_1 = dpctl.SyclEventRaw() + event_2 = dpctl.SyclEventRaw() + try: + dpctl.SyclEventRaw.wait([event_1, event_2]) + except ValueError: + pytest.fail("Failed to wait for events from the list") + + def test_execution_status(): event = dpctl.SyclEventRaw() try: