Skip to content

Commit cf2dbcb

Browse files
authored
[enhancement] unify policy creation in pybind11 (#2390)
* create get_policy * formatting * Update policy.cpp
1 parent 79af322 commit cf2dbcb

File tree

4 files changed

+70
-95
lines changed

4 files changed

+70
-95
lines changed

onedal/common/_backend.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -122,14 +122,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
122122
raise RuntimeError("Executing functions from SPMD backend requires a queue")
123123

124124
# craft the correct policy including the device queue
125-
if queue is None:
126-
policy = self.backend.host_policy()
127-
elif self.backend.is_spmd:
128-
policy = self.backend.spmd_data_parallel_policy(queue)
129-
elif self.backend.is_dpc:
130-
policy = self.backend.data_parallel_policy(queue)
131-
else:
132-
policy = self.backend.host_policy()
125+
policy = self.backend.get_policy(queue)
133126

134127
logger.debug(
135128
f"Dispatching function '{self.name}' with policy {policy} to {self.backend}"

onedal/common/policy.cpp

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@
1818
#include "onedal/common/policy.hpp"
1919
#include "onedal/common/pybind11_helpers.hpp"
2020

21+
#ifdef ONEDAL_DATA_PARALLEL_SPMD
22+
#include "oneapi/dal/detail/spmd_policy.hpp"
23+
#include "oneapi/dal/spmd/mpi/communicator.hpp"
24+
#endif // ONEDAL_DATA_PARALLEL_SPMD
25+
2126
namespace py = pybind11;
2227

2328
namespace oneapi::dal::python {
@@ -78,15 +83,78 @@ void instantiate_data_parallel_policy(py::module& m) {
7883
});
7984
m.def("get_used_memory", &get_used_memory, py::return_value_policy::take_ownership);
8085
}
86+
#endif // ONEDAL_DATA_PARALLEL
87+
#ifdef ONEDAL_DATA_PARALLEL_SPMD
88+
using dp_policy_t = dal::detail::data_parallel_policy;
89+
using spmd_policy_t = dal::detail::spmd_policy<dp_policy_t>;
90+
91+
inline spmd_policy_t make_spmd_policy(dp_policy_t&& local) {
92+
sycl::queue& queue = local.get_queue();
93+
using backend_t = dal::preview::spmd::backend::mpi;
94+
auto comm = dal::preview::spmd::make_communicator<backend_t>(queue);
95+
return spmd_policy_t{ std::forward<dp_policy_t>(local), std::move(comm) };
96+
}
97+
98+
template <typename... Args>
99+
inline spmd_policy_t make_spmd_policy(Args&&... args) {
100+
auto local = make_dp_policy(std::forward<Args>(args)...);
101+
return make_spmd_policy(std::move(local));
102+
}
81103

104+
template <typename Arg, typename Policy = spmd_policy_t>
105+
inline void instantiate_costructor(py::class_<Policy>& policy) {
106+
policy.def(py::init([](const Arg& arg) {
107+
return make_spmd_policy(arg);
108+
}));
109+
}
110+
111+
void instantiate_spmd_policy(py::module& m) {
112+
constexpr const char name[] = "spmd_data_parallel_policy";
113+
py::class_<spmd_policy_t> policy(m, name);
114+
policy.def(py::init<spmd_policy_t>());
115+
policy.def(py::init([](const dp_policy_t& local) {
116+
return make_spmd_policy(local);
117+
}));
118+
policy.def(py::init([](std::uint32_t id) {
119+
return make_spmd_policy(id);
120+
}));
121+
policy.def(py::init([](const std::string& filter) {
122+
return make_spmd_policy(filter);
123+
}));
124+
policy.def(py::init([](const py::object& syclobj) {
125+
return make_spmd_policy(syclobj);
126+
}));
127+
policy.def("get_device_id", [](const spmd_policy_t& policy) {
128+
return get_device_id(policy.get_local());
129+
});
130+
policy.def("get_device_name", [](const spmd_policy_t& policy) {
131+
return get_device_name(policy.get_local());
132+
});
133+
}
134+
#endif // ONEDAL_DATA_PARALLEL_SPMD
135+
136+
py::object get_policy(py::object obj) {
137+
if (!obj.is(py::none())) {
138+
#ifdef ONEDAL_DATA_PARALLEL_SPMD
139+
return py::type::of<spmd_policy_t>()(obj);
140+
#elif ONEDAL_DATA_PARALLEL
141+
return py::type::of<dp_policy_t>()(obj);
142+
#else
143+
throw std::invalid_argument("queues are not supported in the oneDAL backend");
82144
#endif // ONEDAL_DATA_PARALLEL
145+
}
146+
return py::type::of<host_policy_t>()();
147+
};
83148

84149
ONEDAL_PY_INIT_MODULE(policy) {
85150
instantiate_host_policy(m);
86151
instantiate_default_host_policy(m);
87152
#ifdef ONEDAL_DATA_PARALLEL
88153
instantiate_data_parallel_policy(m);
89154
#endif // ONEDAL_DATA_PARALLEL
155+
#ifdef ONEDAL_DATA_PARALLEL_SPMD
156+
instantiate_spmd_policy(m);
157+
#endif // ONEDAL_DATA_PARALLEL_SPMD
158+
m.def("get_policy", &get_policy, py::arg("queue") = py::none());
90159
}
91-
92160
} // namespace oneapi::dal::python

onedal/common/spmd_policy.cpp

Lines changed: 0 additions & 82 deletions
This file was deleted.

onedal/dal.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@ namespace oneapi::dal::python {
2323

2424
/* common */
2525
#ifdef ONEDAL_DATA_PARALLEL_SPMD
26-
ONEDAL_PY_INIT_MODULE(spmd_policy);
27-
2826
/* algorithms */
2927
ONEDAL_PY_INIT_MODULE(covariance);
3028
ONEDAL_PY_INIT_MODULE(dbscan);
@@ -85,8 +83,6 @@ ONEDAL_PY_INIT_MODULE(finiteness_checker);
8583

8684
#ifdef ONEDAL_DATA_PARALLEL_SPMD
8785
PYBIND11_MODULE(_onedal_py_spmd_dpc, m) {
88-
init_spmd_policy(m);
89-
9086
init_covariance(m);
9187
init_dbscan(m);
9288
init_decomposition(m);

0 commit comments

Comments
 (0)