Skip to content

[enhancement] unify policy creation in pybind11 #2390

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 31, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions onedal/common/_backend.py
Original file line number Diff line number Diff line change
@@ -122,14 +122,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
raise RuntimeError("Executing functions from SPMD backend requires a queue")

# craft the correct policy including the device queue
if queue is None:
policy = self.backend.host_policy()
elif self.backend.is_spmd:
policy = self.backend.spmd_data_parallel_policy(queue)
elif self.backend.is_dpc:
policy = self.backend.data_parallel_policy(queue)
else:
policy = self.backend.host_policy()
policy = self.backend.get_policy(queue)

logger.debug(
f"Dispatching function '{self.name}' with policy {policy} to {self.backend}"
70 changes: 69 additions & 1 deletion onedal/common/policy.cpp
Original file line number Diff line number Diff line change
@@ -18,6 +18,11 @@
#include "onedal/common/policy.hpp"
#include "onedal/common/pybind11_helpers.hpp"

#ifdef ONEDAL_DATA_PARALLEL_SPMD
#include "oneapi/dal/detail/spmd_policy.hpp"
#include "oneapi/dal/spmd/mpi/communicator.hpp"
#endif // ONEDAL_DATA_PARALLEL_SPMD

namespace py = pybind11;

namespace oneapi::dal::python {
@@ -78,15 +83,78 @@ void instantiate_data_parallel_policy(py::module& m) {
});
m.def("get_used_memory", &get_used_memory, py::return_value_policy::take_ownership);
}
#endif // ONEDAL_DATA_PARALLEL
#ifdef ONEDAL_DATA_PARALLEL_SPMD
using dp_policy_t = dal::detail::data_parallel_policy;
using spmd_policy_t = dal::detail::spmd_policy<dp_policy_t>;

inline spmd_policy_t make_spmd_policy(dp_policy_t&& local) {
sycl::queue& queue = local.get_queue();
using backend_t = dal::preview::spmd::backend::mpi;
auto comm = dal::preview::spmd::make_communicator<backend_t>(queue);
return spmd_policy_t{ std::forward<dp_policy_t>(local), std::move(comm) };
}

template <typename... Args>
inline spmd_policy_t make_spmd_policy(Args&&... args) {
auto local = make_dp_policy(std::forward<Args>(args)...);
return make_spmd_policy(std::move(local));
}

template <typename Arg, typename Policy = spmd_policy_t>
inline void instantiate_costructor(py::class_<Policy>& policy) {
policy.def(py::init([](const Arg& arg) {
return make_spmd_policy(arg);
}));
}

void instantiate_spmd_policy(py::module& m) {
constexpr const char name[] = "spmd_data_parallel_policy";
py::class_<spmd_policy_t> policy(m, name);
policy.def(py::init<spmd_policy_t>());
policy.def(py::init([](const dp_policy_t& local) {
return make_spmd_policy(local);
}));
policy.def(py::init([](std::uint32_t id) {
return make_spmd_policy(id);
}));
policy.def(py::init([](const std::string& filter) {
return make_spmd_policy(filter);
}));
policy.def(py::init([](const py::object& syclobj) {
return make_spmd_policy(syclobj);
}));
policy.def("get_device_id", [](const spmd_policy_t& policy) {
return get_device_id(policy.get_local());
});
policy.def("get_device_name", [](const spmd_policy_t& policy) {
return get_device_name(policy.get_local());
});
}
#endif // ONEDAL_DATA_PARALLEL_SPMD

py::object get_policy(py::object obj) {
if (!obj.is(py::none())) {
#ifdef ONEDAL_DATA_PARALLEL_SPMD
return py::type::of<spmd_policy_t>()(obj);
#elif ONEDAL_DATA_PARALLEL
return py::type::of<dp_policy_t>()(obj);
#else
throw std::invalid_argument("queues are not supported in the oneDAL backend");
#endif // ONEDAL_DATA_PARALLEL
}
return py::type::of<host_policy_t>()();
};

ONEDAL_PY_INIT_MODULE(policy) {
instantiate_host_policy(m);
instantiate_default_host_policy(m);
#ifdef ONEDAL_DATA_PARALLEL
instantiate_data_parallel_policy(m);
#endif // ONEDAL_DATA_PARALLEL
#ifdef ONEDAL_DATA_PARALLEL_SPMD
instantiate_spmd_policy(m);
#endif // ONEDAL_DATA_PARALLEL_SPMD
m.def("get_policy", &get_policy, py::arg("queue") = py::none());
}

} // namespace oneapi::dal::python
82 changes: 0 additions & 82 deletions onedal/common/spmd_policy.cpp

This file was deleted.

4 changes: 0 additions & 4 deletions onedal/dal.cpp
Original file line number Diff line number Diff line change
@@ -23,8 +23,6 @@ namespace oneapi::dal::python {

/* common */
#ifdef ONEDAL_DATA_PARALLEL_SPMD
ONEDAL_PY_INIT_MODULE(spmd_policy);

/* algorithms */
ONEDAL_PY_INIT_MODULE(covariance);
ONEDAL_PY_INIT_MODULE(dbscan);
@@ -85,8 +83,6 @@ ONEDAL_PY_INIT_MODULE(finiteness_checker);

#ifdef ONEDAL_DATA_PARALLEL_SPMD
PYBIND11_MODULE(_onedal_py_spmd_dpc, m) {
init_spmd_policy(m);

init_covariance(m);
init_dbscan(m);
init_decomposition(m);
Loading
Oops, something went wrong.