Skip to content
Merged
2 changes: 2 additions & 0 deletions dpctl/_backend.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,8 @@ cdef extern from "syclinterface/dpctl_sycl_platform_interface.h":
cdef const char *DPCTLPlatform_GetVendor(const DPCTLSyclPlatformRef)
cdef const char *DPCTLPlatform_GetVersion(const DPCTLSyclPlatformRef)
cdef DPCTLPlatformVectorRef DPCTLPlatform_GetPlatforms()
cdef DPCTLSyclContextRef DPCTLPlatform_GetDefaultContext(
const DPCTLSyclPlatformRef)


cdef extern from "syclinterface/dpctl_sycl_context_interface.h":
Expand Down
20 changes: 20 additions & 0 deletions dpctl/_sycl_device.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ from ._backend cimport ( # noqa: E211
DPCTLDevice_GetMaxWriteImageArgs,
DPCTLDevice_GetName,
DPCTLDevice_GetParentDevice,
DPCTLDevice_GetPlatform,
DPCTLDevice_GetPreferredVectorWidthChar,
DPCTLDevice_GetPreferredVectorWidthDouble,
DPCTLDevice_GetPreferredVectorWidthFloat,
Expand Down Expand Up @@ -80,6 +81,7 @@ from ._backend cimport ( # noqa: E211
DPCTLSize_t_Array_Delete,
DPCTLSyclDeviceRef,
DPCTLSyclDeviceSelectorRef,
DPCTLSyclPlatformRef,
_aspect_type,
_backend_type,
_device_type,
Expand All @@ -91,6 +93,8 @@ from .enum_types import backend_type, device_type
from libc.stdint cimport int64_t, uint32_t
from libc.stdlib cimport free, malloc

from ._sycl_platform cimport SyclPlatform

import collections
import warnings

Expand Down Expand Up @@ -639,6 +643,22 @@ cdef class SyclDevice(_SyclDevice):
self._device_ref
)

@property
def sycl_platform(self):
""" Returns the platform associated with this device.

Returns:
:class:`dpctl.SyclPlatform`: The platform associated with this
device.
"""
cdef DPCTLSyclPlatformRef PRef = (
DPCTLDevice_GetPlatform(self._device_ref)
)
if (PRef == NULL):
raise RuntimeError("Could not get platform for device.")
else:
return SyclPlatform._create(PRef)

@property
def preferred_vector_width_char(self):
""" Returns the preferred native vector width size for built-in scalar
Expand Down
25 changes: 23 additions & 2 deletions dpctl/_sycl_platform.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ from ._backend cimport ( # noqa: E211
DPCTLPlatform_CreateFromSelector,
DPCTLPlatform_Delete,
DPCTLPlatform_GetBackend,
DPCTLPlatform_GetDefaultContext,
DPCTLPlatform_GetName,
DPCTLPlatform_GetPlatforms,
DPCTLPlatform_GetVendor,
Expand All @@ -40,15 +41,19 @@ from ._backend cimport ( # noqa: E211
DPCTLPlatformVector_GetAt,
DPCTLPlatformVector_Size,
DPCTLPlatformVectorRef,
DPCTLSyclContextRef,
DPCTLSyclDeviceSelectorRef,
DPCTLSyclPlatformRef,
_backend_type,
)

import warnings

from ._sycl_context import SyclContextCreationError
from .enum_types import backend_type

from ._sycl_context cimport SyclContext

__all__ = [
"get_platforms",
"lsplatform",
Expand Down Expand Up @@ -236,10 +241,10 @@ cdef class SyclPlatform(_SyclPlatform):

@property
def backend(self):
"""Returns the backend_type enum value for this device
"""Returns the backend_type enum value for this platform

Returns:
backend_type: The backend for the device.
backend_type: The backend for the platform.
"""
cdef _backend_type BTy = (
DPCTLPlatform_GetBackend(self._platform_ref)
Expand All @@ -255,6 +260,22 @@ cdef class SyclPlatform(_SyclPlatform):
else:
raise ValueError("Unknown backend type.")

@property
def default_context(self):
"""Returns the default platform context for this platform

Returns:
SyclContext: The default context for the platform.
"""
cdef DPCTLSyclContextRef CRef = (
DPCTLPlatform_GetDefaultContext(self._platform_ref)
)

if (CRef == NULL):
raise
else:
return SyclContext._create(CRef)


def lsplatform(verbosity=0):
"""
Expand Down
7 changes: 7 additions & 0 deletions dpctl/tests/test_sycl_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,11 @@ def check_profiling_timer_resolution(device):
assert isinstance(resol, int) and resol > 0


def check_platform(device):
p = device.sycl_platform
assert isinstance(p, dpctl.SyclPlatform)


list_of_checks = [
check_get_max_compute_units,
check_get_max_work_item_dims,
Expand Down Expand Up @@ -552,6 +557,8 @@ def check_profiling_timer_resolution(device):
check_repr,
check_get_global_mem_size,
check_get_local_mem_size,
check_profiling_timer_resolution,
check_platform,
]


Expand Down
5 changes: 5 additions & 0 deletions dpctl/tests/test_sycl_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ def check_repr(platform):
assert r != ""


def check_default_context(platform):
r = platform.default_context
assert type(r) is dpctl.SyclContext


list_of_checks = [
check_name,
check_vendor,
Expand Down
12 changes: 12 additions & 0 deletions libsyclinterface/include/dpctl_sycl_platform_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,4 +142,16 @@ DPCTLPlatform_GetVersion(__dpctl_keep const DPCTLSyclPlatformRef PRef);
DPCTL_API
__dpctl_give DPCTLPlatformVectorRef DPCTLPlatform_GetPlatforms(void);

/*!
* @brief Returns a DPCTLSyclContextRef for default platform context.
*
* @param PRef Opaque pointer to a sycl::platform
* @return A DPCTLSyclContextRef value for the default platform associated
* with this platform.
* @ingroup PlatformInterface
*/
DPCTL_API
__dpctl_give DPCTLSyclContextRef
DPCTLPlatform_GetDefaultContext(__dpctl_keep const DPCTLSyclPlatformRef PRef);

DPCTL_C_EXTERN_C_END
17 changes: 17 additions & 0 deletions libsyclinterface/source/dpctl_sycl_platform_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ using namespace cl::sycl;
namespace
{
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(platform, DPCTLSyclPlatformRef);
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(context, DPCTLSyclContextRef);
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(device_selector, DPCTLSyclDeviceSelectorRef);
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(std::vector<DPCTLSyclPlatformRef>,
DPCTLPlatformVectorRef);
Expand Down Expand Up @@ -202,3 +203,19 @@ __dpctl_give DPCTLPlatformVectorRef DPCTLPlatform_GetPlatforms()
// the wrap function is defined inside dpctl_vector_templ.cpp
return wrap(Platforms);
}

__dpctl_give DPCTLSyclContextRef
DPCTLPlatform_GetDefaultContext(__dpctl_keep const DPCTLSyclPlatformRef PRef)
{
auto P = unwrap(PRef);
if (P) {
auto default_ctx = P->ext_oneapi_get_default_context();
return wrap(new context(default_ctx));
}
else {
error_handler(
"Default platform cannot be obtained up for a NULL platform.",
__FILE__, __func__, __LINE__);
return nullptr;
}
}
24 changes: 24 additions & 0 deletions libsyclinterface/tests/test_sycl_platform_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
//===----------------------------------------------------------------------===//

#include "Support/CBindingWrapping.h"
#include "dpctl_sycl_context_interface.h"
#include "dpctl_sycl_device_selector_interface.h"
#include "dpctl_sycl_platform_interface.h"
#include "dpctl_sycl_platform_manager.h"
Expand Down Expand Up @@ -82,6 +83,16 @@ void check_platform_backend(__dpctl_keep const DPCTLSyclPlatformRef PRef)
}());
}

void check_platform_default_context(
__dpctl_keep const DPCTLSyclPlatformRef PRef)
{
DPCTLSyclContextRef CRef = nullptr;
EXPECT_NO_FATAL_FAILURE(CRef = DPCTLPlatform_GetDefaultContext(PRef));
EXPECT_TRUE(CRef != nullptr);

EXPECT_NO_FATAL_FAILURE(DPCTLContext_Delete(CRef));
}

} // namespace

struct TestDPCTLSyclPlatformInterface
Expand Down Expand Up @@ -167,6 +178,14 @@ TEST_F(TestDPCTLSyclPlatformNull, ChkGetVersion)
ASSERT_TRUE(version == nullptr);
}

TEST_F(TestDPCTLSyclPlatformNull, ChkGetDefaultConext)
{
DPCTLSyclContextRef CRef = nullptr;

EXPECT_NO_FATAL_FAILURE(CRef = DPCTLPlatform_GetDefaultContext(NullPRef));
EXPECT_TRUE(CRef == nullptr);
}

struct TestDPCTLSyclDefaultPlatform : public ::testing::Test
{
DPCTLSyclPlatformRef PRef = nullptr;
Expand Down Expand Up @@ -207,6 +226,11 @@ TEST_P(TestDPCTLSyclPlatformInterface, ChkGetBackend)
check_platform_backend(PRef);
}

TEST_P(TestDPCTLSyclPlatformInterface, ChkGetDefaultContext)
{
check_platform_default_context(PRef);
}

TEST_P(TestDPCTLSyclPlatformInterface, ChkCopy)
{
DPCTLSyclPlatformRef Copied_PRef = nullptr;
Expand Down
1 change: 0 additions & 1 deletion libsyclinterface/tests/test_sycl_queue_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,6 @@ TEST_P(TestDPCTLQueueMemberFunctions, CheckMemset)

ASSERT_NO_FATAL_FAILURE(DPCTLfree_with_queue(p, QRef));

bool equal = true;
for (size_t i = 0; i < nbytes; ++i) {
ASSERT_TRUE(host_arr[i] == val);
}
Expand Down