Skip to content

Commit 6533536

Browse files
Merge pull request #529 from vlad-perevezentsev/cython_wait
Add a static method `wait` for the SyclEventRaw class
2 parents 1320c34 + 21276ff commit 6533536

File tree

4 files changed

+41
-4
lines changed

4 files changed

+41
-4
lines changed

dpctl/_backend.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ cdef extern from "dpctl_sycl_event_interface.h":
226226
cdef DPCTLSyclEventRef DPCTLEvent_Create()
227227
cdef DPCTLSyclEventRef DPCTLEvent_Copy(const DPCTLSyclEventRef ERef)
228228
cdef void DPCTLEvent_Wait(DPCTLSyclEventRef ERef)
229+
cdef void DPCTLEvent_WaitAndThrow(DPCTLSyclEventRef ERef)
229230
cdef void DPCTLEvent_Delete(DPCTLSyclEventRef ERef)
230231
cdef _event_status_type DPCTLEvent_GetCommandExecutionStatus(DPCTLSyclEventRef ERef)
231232
cdef _backend_type DPCTLEvent_GetBackend(DPCTLSyclEventRef ERef)

dpctl/_sycl_event.pxd

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,5 +46,5 @@ cdef public class SyclEventRaw(_SyclEventRaw) [object PySyclEventRawObject, type
4646
cdef int _init_event_from__SyclEventRaw(self, _SyclEventRaw other)
4747
cdef int _init_event_from_SyclEvent(self, SyclEvent event)
4848
cdef int _init_event_from_capsule(self, object caps)
49-
cdef DPCTLSyclEventRef get_event_ref (self)
50-
cpdef void wait (self)
49+
cdef DPCTLSyclEventRef get_event_ref (self)
50+
cdef void _wait (SyclEventRaw event)

dpctl/_sycl_event.pyx

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import logging
2525

2626
from cpython cimport pycapsule
2727
from libc.stdint cimport uint64_t
28+
import collections.abc
2829

2930
from ._backend cimport ( # noqa: E211
3031
DPCTLEvent_Copy,
@@ -37,6 +38,7 @@ from ._backend cimport ( # noqa: E211
3738
DPCTLEvent_GetProfilingInfoSubmit,
3839
DPCTLEvent_GetWaitList,
3940
DPCTLEvent_Wait,
41+
DPCTLEvent_WaitAndThrow,
4042
DPCTLEventVector_Delete,
4143
DPCTLEventVector_GetAt,
4244
DPCTLEventVector_Size,
@@ -192,8 +194,25 @@ cdef class SyclEventRaw(_SyclEventRaw):
192194
"""
193195
return self._event_ref
194196

195-
cpdef void wait(self):
196-
DPCTLEvent_Wait(self._event_ref)
197+
@staticmethod
198+
cdef void _wait(SyclEventRaw event):
199+
DPCTLEvent_WaitAndThrow(event._event_ref)
200+
201+
@staticmethod
202+
def wait(event):
203+
""" Waits for a given event or a sequence of events.
204+
"""
205+
if (isinstance(event, collections.abc.Sequence) and
206+
all( (isinstance(el, SyclEventRaw) for el in event) )):
207+
for e in event:
208+
SyclEventRaw._wait(e)
209+
elif isinstance(event, SyclEventRaw):
210+
SyclEventRaw._wait(event)
211+
else:
212+
raise TypeError(
213+
"The passed argument is not a SyclEventRaw type or "
214+
"a sequence of such objects"
215+
)
197216

198217
def addressof_ref(self):
199218
""" Returns the address of the C API `DPCTLSyclEventRef` pointer as

dpctl/tests/test_sycl_event.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,23 @@ def test_create_event_raw_from_capsule():
8181
pytest.fail("Failed to create an event from capsule")
8282

8383

84+
def test_wait_with_event():
85+
event = dpctl.SyclEventRaw()
86+
try:
87+
dpctl.SyclEventRaw.wait(event)
88+
except ValueError:
89+
pytest.fail("Failed to wait for the event")
90+
91+
92+
def test_wait_with_list():
93+
event_1 = dpctl.SyclEventRaw()
94+
event_2 = dpctl.SyclEventRaw()
95+
try:
96+
dpctl.SyclEventRaw.wait([event_1, event_2])
97+
except ValueError:
98+
pytest.fail("Failed to wait for events from the list")
99+
100+
84101
def test_execution_status():
85102
event = dpctl.SyclEventRaw()
86103
try:

0 commit comments

Comments
 (0)