|
18 | 18 | #include "onedal/common/policy.hpp"
|
19 | 19 | #include "onedal/common/pybind11_helpers.hpp"
|
20 | 20 |
|
| 21 | +#ifdef ONEDAL_DATA_PARALLEL_SPMD |
| 22 | +#include "oneapi/dal/detail/spmd_policy.hpp" |
| 23 | +#include "oneapi/dal/spmd/mpi/communicator.hpp" |
| 24 | +#endif // ONEDAL_DATA_PARALLEL_SPMD |
| 25 | + |
21 | 26 | namespace py = pybind11;
|
22 | 27 |
|
23 | 28 | namespace oneapi::dal::python {
|
@@ -78,15 +83,78 @@ void instantiate_data_parallel_policy(py::module& m) {
|
78 | 83 | });
|
79 | 84 | m.def("get_used_memory", &get_used_memory, py::return_value_policy::take_ownership);
|
80 | 85 | }
|
| 86 | +#endif // ONEDAL_DATA_PARALLEL |
| 87 | +#ifdef ONEDAL_DATA_PARALLEL_SPMD |
| 88 | +using dp_policy_t = dal::detail::data_parallel_policy; |
| 89 | +using spmd_policy_t = dal::detail::spmd_policy<dp_policy_t>; |
| 90 | + |
| 91 | +inline spmd_policy_t make_spmd_policy(dp_policy_t&& local) { |
| 92 | + sycl::queue& queue = local.get_queue(); |
| 93 | + using backend_t = dal::preview::spmd::backend::mpi; |
| 94 | + auto comm = dal::preview::spmd::make_communicator<backend_t>(queue); |
| 95 | + return spmd_policy_t{ std::forward<dp_policy_t>(local), std::move(comm) }; |
| 96 | +} |
| 97 | + |
| 98 | +template <typename... Args> |
| 99 | +inline spmd_policy_t make_spmd_policy(Args&&... args) { |
| 100 | + auto local = make_dp_policy(std::forward<Args>(args)...); |
| 101 | + return make_spmd_policy(std::move(local)); |
| 102 | +} |
81 | 103 |
|
| 104 | +template <typename Arg, typename Policy = spmd_policy_t> |
| 105 | +inline void instantiate_costructor(py::class_<Policy>& policy) { |
| 106 | + policy.def(py::init([](const Arg& arg) { |
| 107 | + return make_spmd_policy(arg); |
| 108 | + })); |
| 109 | +} |
| 110 | + |
| 111 | +void instantiate_spmd_policy(py::module& m) { |
| 112 | + constexpr const char name[] = "spmd_data_parallel_policy"; |
| 113 | + py::class_<spmd_policy_t> policy(m, name); |
| 114 | + policy.def(py::init<spmd_policy_t>()); |
| 115 | + policy.def(py::init([](const dp_policy_t& local) { |
| 116 | + return make_spmd_policy(local); |
| 117 | + })); |
| 118 | + policy.def(py::init([](std::uint32_t id) { |
| 119 | + return make_spmd_policy(id); |
| 120 | + })); |
| 121 | + policy.def(py::init([](const std::string& filter) { |
| 122 | + return make_spmd_policy(filter); |
| 123 | + })); |
| 124 | + policy.def(py::init([](const py::object& syclobj) { |
| 125 | + return make_spmd_policy(syclobj); |
| 126 | + })); |
| 127 | + policy.def("get_device_id", [](const spmd_policy_t& policy) { |
| 128 | + return get_device_id(policy.get_local()); |
| 129 | + }); |
| 130 | + policy.def("get_device_name", [](const spmd_policy_t& policy) { |
| 131 | + return get_device_name(policy.get_local()); |
| 132 | + }); |
| 133 | +} |
| 134 | +#endif // ONEDAL_DATA_PARALLEL_SPMD |
| 135 | + |
| 136 | +py::object get_policy(py::object obj) { |
| 137 | + if (!obj.is(py::none())) { |
| 138 | +#ifdef ONEDAL_DATA_PARALLEL_SPMD |
| 139 | + return py::type::of<spmd_policy_t>()(obj); |
| 140 | +#elif ONEDAL_DATA_PARALLEL |
| 141 | + return py::type::of<dp_policy_t>()(obj); |
| 142 | +#else |
| 143 | + throw std::invalid_argument("queues are not supported in the oneDAL backend"); |
82 | 144 | #endif // ONEDAL_DATA_PARALLEL
|
| 145 | + } |
| 146 | + return py::type::of<host_policy_t>()(); |
| 147 | +}; |
83 | 148 |
|
84 | 149 | ONEDAL_PY_INIT_MODULE(policy) {
|
85 | 150 | instantiate_host_policy(m);
|
86 | 151 | instantiate_default_host_policy(m);
|
87 | 152 | #ifdef ONEDAL_DATA_PARALLEL
|
88 | 153 | instantiate_data_parallel_policy(m);
|
89 | 154 | #endif // ONEDAL_DATA_PARALLEL
|
| 155 | +#ifdef ONEDAL_DATA_PARALLEL_SPMD |
| 156 | + instantiate_spmd_policy(m); |
| 157 | +#endif // ONEDAL_DATA_PARALLEL_SPMD |
| 158 | + m.def("get_policy", &get_policy, py::arg("queue") = py::none()); |
90 | 159 | }
|
91 |
| - |
92 | 160 | } // namespace oneapi::dal::python
|
0 commit comments