Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dpctl/_backend.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ cdef extern from "dpctl_sycl_event_interface.h":
cdef DPCTLSyclEventRef DPCTLEvent_Create()
cdef DPCTLSyclEventRef DPCTLEvent_Copy(const DPCTLSyclEventRef ERef)
cdef void DPCTLEvent_Wait(DPCTLSyclEventRef ERef)
cdef void DPCTLEvent_WaitAndThrow(DPCTLSyclEventRef ERef)
cdef void DPCTLEvent_Delete(DPCTLSyclEventRef ERef)
cdef _event_status_type DPCTLEvent_GetCommandExecutionStatus(DPCTLSyclEventRef ERef)
cdef _backend_type DPCTLEvent_GetBackend(DPCTLSyclEventRef ERef)
Expand Down
4 changes: 2 additions & 2 deletions dpctl/_sycl_event.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,5 @@ cdef public class SyclEventRaw(_SyclEventRaw) [object PySyclEventRawObject, type
cdef int _init_event_from__SyclEventRaw(self, _SyclEventRaw other)
cdef int _init_event_from_SyclEvent(self, SyclEvent event)
cdef int _init_event_from_capsule(self, object caps)
cdef DPCTLSyclEventRef get_event_ref (self)
cpdef void wait (self)
cdef DPCTLSyclEventRef get_event_ref (self)
cdef void _wait (SyclEventRaw event)
23 changes: 21 additions & 2 deletions dpctl/_sycl_event.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import logging

from cpython cimport pycapsule
from libc.stdint cimport uint64_t
import collections.abc

from ._backend cimport ( # noqa: E211
DPCTLEvent_Copy,
Expand All @@ -37,6 +38,7 @@ from ._backend cimport ( # noqa: E211
DPCTLEvent_GetProfilingInfoSubmit,
DPCTLEvent_GetWaitList,
DPCTLEvent_Wait,
DPCTLEvent_WaitAndThrow,
DPCTLEventVector_Delete,
DPCTLEventVector_GetAt,
DPCTLEventVector_Size,
Expand Down Expand Up @@ -192,8 +194,25 @@ cdef class SyclEventRaw(_SyclEventRaw):
"""
return self._event_ref

cpdef void wait(self):
DPCTLEvent_Wait(self._event_ref)
@staticmethod
cdef void _wait(SyclEventRaw event):
DPCTLEvent_WaitAndThrow(event._event_ref)

@staticmethod
def wait(event):
""" Waits for a given event or a sequence of events.
"""
if (isinstance(event, collections.abc.Sequence) and
all( (isinstance(el, SyclEventRaw) for el in event) )):
for e in event:
SyclEventRaw._wait(e)
elif isinstance(event, SyclEventRaw):
SyclEventRaw._wait(event)
else:
raise TypeError(
"The passed argument is not a SyclEventRaw type or "
"a sequence of such objects"
)

def addressof_ref(self):
""" Returns the address of the C API `DPCTLSyclEventRef` pointer as
Expand Down
17 changes: 17 additions & 0 deletions dpctl/tests/test_sycl_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,23 @@ def test_create_event_raw_from_capsule():
pytest.fail("Failed to create an event from capsule")


def test_wait_with_event():
event = dpctl.SyclEventRaw()
try:
dpctl.SyclEventRaw.wait(event)
except ValueError:
pytest.fail("Failed to wait for the event")


def test_wait_with_list():
event_1 = dpctl.SyclEventRaw()
event_2 = dpctl.SyclEventRaw()
try:
dpctl.SyclEventRaw.wait([event_1, event_2])
except ValueError:
pytest.fail("Failed to wait for events from the list")


def test_execution_status():
event = dpctl.SyclEventRaw()
try:
Expand Down