33#include < pybind11/numpy.h>
44#include < pybind11/pybind11.h>
55
6+ // clang-format off
7+ #include " dpctl_sycl_types.h"
68#include " ../_sycl_queue.h"
79#include " ../_sycl_queue_api.h"
8- #include " dpctl_sycl_types.h"
10+ #include " ../_sycl_device.h"
11+ #include " ../_sycl_device_api.h"
12+ // clang-format on
913
1014namespace py = pybind11;
1115
@@ -25,6 +29,34 @@ size_t get_max_compute_units(py::object queue)
2529 }
2630}
2731
32+ uint64_t get_device_global_mem_size (py::object device)
33+ {
34+ PyObject *device_pycapi = device.ptr ();
35+ if (PyObject_TypeCheck (device_pycapi, &PySyclDeviceType)) {
36+ DPCTLSyclDeviceRef DRef = get_device_ref (
37+ reinterpret_cast <PySyclDeviceObject *>(device_pycapi));
38+ sycl::device *d_ptr = reinterpret_cast <sycl::device *>(DRef);
39+ return d_ptr->get_info <sycl::info::device::global_mem_size>();
40+ }
41+ else {
42+ throw std::runtime_error (" expected dpctl.SyclDevice as argument" );
43+ }
44+ }
45+
46+ uint64_t get_device_local_mem_size (py::object device)
47+ {
48+ PyObject *device_pycapi = device.ptr ();
49+ if (PyObject_TypeCheck (device_pycapi, &PySyclDeviceType)) {
50+ DPCTLSyclDeviceRef DRef = get_device_ref (
51+ reinterpret_cast <PySyclDeviceObject *>(device_pycapi));
52+ sycl::device *d_ptr = reinterpret_cast <sycl::device *>(DRef);
53+ return d_ptr->get_info <sycl::info::device::local_mem_size>();
54+ }
55+ else {
56+ throw std::runtime_error (" expected dpctl.SyclDevice as argument" );
57+ }
58+ }
59+
2860py::array_t <int64_t >
2961offloaded_array_mod (py::object queue,
3062 py::array_t <int64_t , py::array::c_style> array,
@@ -82,11 +114,16 @@ offloaded_array_mod(py::object queue,
82114
83115PYBIND11_MODULE (pybind11_example, m)
84116{
85- // Import the dpctl._sycl_queue extension
117+ // Import the dpctl._sycl_queue, dpctl._sycl_device extensions
118+ import_dpctl___sycl_device ();
86119 import_dpctl___sycl_queue ();
87120 m.def (" get_max_compute_units" , &get_max_compute_units,
88121 " Computes max_compute_units property of the device underlying given "
89122 " dpctl.SyclQueue" );
123+ m.def (" get_device_global_mem_size" , &get_device_global_mem_size,
124+ " Computes amount of global memory of the given dpctl.SyclDevice" );
125+ m.def (" get_device_local_mem_size" , &get_device_local_mem_size,
126+ " Computes amount of global memory of the given dpctl.SyclDevice" );
90127 m.def (" offloaded_array_mod" , &offloaded_array_mod,
91128 " Compute offloaded modular reduction of integer-valued NumPy array" );
92129}
0 commit comments