Skip to content

Commit 6da3fd5

Browse files
authored
Merge branch 'master' into reuse_dpctl_moveaxis
2 parents 7c1f5cf + 8e2c5f7 commit 6da3fd5

File tree

11 files changed

+264
-81
lines changed

11 files changed

+264
-81
lines changed

.github/workflows/build-sphinx.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ jobs:
1717

1818
env:
1919
python-ver: '3.9'
20+
CHANNELS: '-c dppy/label/dev -c intel -c conda-forge --override-channels'
2021

2122
steps:
2223
- name: Cancel Previous Runs
@@ -74,10 +75,10 @@ jobs:
7475
- name: Install dpnp dependencies
7576
run: |
7677
conda install dpctl mkl-devel-dpcpp onedpl-devel tbb-devel dpcpp_linux-64 \
77-
cmake cython pytest ninja scikit-build -c dppy/label/dev -c intel -c conda-forge
78+
cmake cython pytest ninja scikit-build sysroot_linux-64">=2.28" ${{ env.CHANNELS }}
7879
7980
- name: Install cuPy dependencies
80-
run: conda install -c conda-forge cupy cudatoolkit=10.0
81+
run: conda install cupy cudatoolkit=10.0
8182

8283
- name: Conda info
8384
run: conda info

.github/workflows/generate_coverage.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ jobs:
1515

1616
env:
1717
python-ver: '3.10'
18+
CHANNELS: '-c dppy/label/dev -c intel -c conda-forge --override-channels'
1819

1920
steps:
2021
- name: Cancel Previous Runs
@@ -34,15 +35,14 @@ jobs:
3435
python-version: ${{ env.python-ver }}
3536
miniconda-version: 'latest'
3637
activate-environment: 'coverage'
37-
channels: intel, conda-forge
3838

3939
- name: Install Lcov
4040
run: |
4141
sudo apt-get install lcov
4242
- name: Install dpnp dependencies
4343
run: |
4444
conda install cython llvm cmake scikit-build ninja pytest pytest-cov coverage[toml] \
45-
dppy/label/dev::dpctl dpcpp_linux-64 mkl-devel-dpcpp tbb-devel onedpl-devel
45+
dpctl dpcpp_linux-64 sysroot_linux-64">=2.28" mkl-devel-dpcpp tbb-devel onedpl-devel ${{ env.CHANNELS }}
4646
- name: Conda info
4747
run: |
4848
conda info
@@ -54,7 +54,7 @@ jobs:
5454
- name: Install coverall dependencies
5555
run: |
5656
sudo gem install coveralls-lcov
57-
conda install coveralls
57+
pip install coveralls==3.2.0
5858
- name: Upload coverage data to coveralls.io
5959
run: |
6060
echo "Processing pytest-coverage"

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ set(CYTHON_FLAGS "-t -w \"${CMAKE_SOURCE_DIR}\"")
5757
find_package(Cython REQUIRED)
5858
find_package(Dpctl REQUIRED)
5959

60+
message(STATUS "Dpctl_INCLUDE_DIRS=" ${Dpctl_INCLUDE_DIRS})
61+
message(STATUS "Dpctl_TENSOR_INCLUDE_DIR=" ${Dpctl_TENSOR_INCLUDE_DIR})
62+
6063
if(WIN32)
6164
string(CONCAT WARNING_FLAGS
6265
"-Wall "

dpnp/backend/extensions/lapack/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_
4545
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../src)
4646

4747
target_include_directories(${python_module_name} PUBLIC ${Dpctl_INCLUDE_DIRS})
48+
target_include_directories(${python_module_name} PUBLIC ${Dpctl_TENSOR_INCLUDE_DIR})
4849

4950
if (WIN32)
5051
target_compile_options(${python_module_name} PRIVATE

dpnp/backend/extensions/lapack/heevd.cpp

Lines changed: 64 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,12 @@
2626

2727
#include <pybind11/pybind11.h>
2828

29+
// dpctl tensor headers
30+
#include "utils/memory_overlap.hpp"
31+
#include "utils/type_utils.hpp"
32+
2933
#include "heevd.hpp"
34+
#include "types_matrix.hpp"
3035

3136
#include "dpnp_utils.hpp"
3237

@@ -42,19 +47,34 @@ namespace lapack
4247

4348
namespace mkl_lapack = oneapi::mkl::lapack;
4449
namespace py = pybind11;
50+
namespace type_utils = dpctl::tensor::type_utils;
51+
52+
// Pointer type for a type-specialized heevd implementation (see heevd_impl<T, RealT>).
// Arguments, in order, as named by heevd_impl:
//   exec_q, jobz, upper_lower, n, in_a, out_w, host_task_events, depends.
// The two char* arguments are untyped data pointers; heevd_impl reinterpret_casts
// them to T* and RealT* respectively. Returns a sycl::event.
typedef sycl::event (*heevd_impl_fn_ptr_t)(sycl::queue,
53+
const oneapi::mkl::job,
54+
const oneapi::mkl::uplo,
55+
const std::int64_t,
56+
char*,
57+
char*,
58+
std::vector<sycl::event>&,
59+
const std::vector<sycl::event>&);
60+
61+
// Table of heevd implementations, looked up in heevd() as
// [eig_vecs type id][eig_vals type id]. Filled by init_heevd_dispatch_table();
// heevd() throws py::value_error when the selected entry is nullptr.
static heevd_impl_fn_ptr_t heevd_dispatch_table[dpctl_td_ns::num_types][dpctl_td_ns::num_types];
4562

4663
template <typename T, typename RealT>
47-
static sycl::event call_heevd(sycl::queue exec_q,
64+
static sycl::event heevd_impl(sycl::queue exec_q,
4865
const oneapi::mkl::job jobz,
4966
const oneapi::mkl::uplo upper_lower,
5067
const std::int64_t n,
51-
T* a,
52-
RealT* w,
68+
char* in_a,
69+
char* out_w,
5370
std::vector<sycl::event>& host_task_events,
5471
const std::vector<sycl::event>& depends)
5572
{
56-
validate_type_for_device<T>(exec_q);
57-
validate_type_for_device<RealT>(exec_q);
73+
type_utils::validate_type_for_device<T>(exec_q);
74+
type_utils::validate_type_for_device<RealT>(exec_q);
75+
76+
T* a = reinterpret_cast<T*>(in_a);
77+
RealT* w = reinterpret_cast<RealT*>(out_w);
5878

5979
const std::int64_t lda = std::max<size_t>(1UL, n);
6080
const std::int64_t scratchpad_size = mkl_lapack::heevd_scratchpad_size<T>(exec_q, jobz, upper_lower, n, lda);
@@ -163,13 +183,11 @@ std::pair<sycl::event, sycl::event> heevd(sycl::queue exec_q,
163183
throw py::value_error("Execution queue is not compatible with allocation queues");
164184
}
165185

166-
// check that arrays do not overlap, and concurrent access is safe.
167-
// TODO: need to be exposed by DPCTL headers
168-
// auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
169-
// if (overlap(eig_vecs, eig_vals))
170-
// {
171-
// throw py::value_error("Arrays index overlapping segments of memory");
172-
// }
186+
auto const& overlap = dpctl::tensor::overlap::MemoryOverlap();
187+
if (overlap(eig_vecs, eig_vals))
188+
{
189+
throw py::value_error("Arrays with eigenvectors and eigenvalues are overlapping segments of memory");
190+
}
173191

174192
bool is_eig_vecs_f_contig = eig_vecs.is_f_contiguous();
175193
bool is_eig_vals_c_contig = eig_vals.is_c_contiguous();
@@ -182,38 +200,51 @@ std::pair<sycl::event, sycl::event> heevd(sycl::queue exec_q,
182200
throw py::value_error("An array with output eigenvalues must be C-contiguous");
183201
}
184202

185-
int eig_vecs_typenum = eig_vecs.get_typenum();
186-
int eig_vals_typenum = eig_vals.get_typenum();
187-
auto const& dpctl_capi = dpctl::detail::dpctl_capi::get();
203+
auto array_types = dpctl_td_ns::usm_ndarray_types();
204+
int eig_vecs_type_id = array_types.typenum_to_lookup_id(eig_vecs.get_typenum());
205+
int eig_vals_type_id = array_types.typenum_to_lookup_id(eig_vals.get_typenum());
188206

189-
sycl::event heevd_ev;
190-
std::vector<sycl::event> host_task_events;
207+
heevd_impl_fn_ptr_t heevd_fn = heevd_dispatch_table[eig_vecs_type_id][eig_vals_type_id];
208+
if (heevd_fn == nullptr)
209+
{
210+
throw py::value_error("No heevd implementation defined for a pair of type for eigenvectors and eigenvalues");
211+
}
212+
213+
char* eig_vecs_data = eig_vecs.get_data();
214+
char* eig_vals_data = eig_vals.get_data();
191215

192216
const std::int64_t n = eig_vecs_shape[0];
193217
const oneapi::mkl::job jobz_val = static_cast<oneapi::mkl::job>(jobz);
194218
const oneapi::mkl::uplo uplo_val = static_cast<oneapi::mkl::uplo>(upper_lower);
195219

196-
if ((eig_vecs_typenum == dpctl_capi.UAR_CDOUBLE_) && (eig_vals_typenum == dpctl_capi.UAR_DOUBLE_))
197-
{
198-
std::complex<double>* a = reinterpret_cast<std::complex<double>*>(eig_vecs.get_data());
199-
double* w = reinterpret_cast<double*>(eig_vals.get_data());
220+
std::vector<sycl::event> host_task_events;
221+
sycl::event heevd_ev =
222+
heevd_fn(exec_q, jobz_val, uplo_val, n, eig_vecs_data, eig_vals_data, host_task_events, depends);
200223

201-
heevd_ev = call_heevd(exec_q, jobz_val, uplo_val, n, a, w, host_task_events, depends);
202-
}
203-
else if ((eig_vecs_typenum == dpctl_capi.UAR_CFLOAT_) && (eig_vals_typenum == dpctl_capi.UAR_FLOAT_))
204-
{
205-
std::complex<float>* a = reinterpret_cast<std::complex<float>*>(eig_vecs.get_data());
206-
float* w = reinterpret_cast<float*>(eig_vals.get_data());
224+
sycl::event args_ev = dpctl::utils::keep_args_alive(exec_q, {eig_vecs, eig_vals}, host_task_events);
225+
return std::make_pair(args_ev, heevd_ev);
226+
}
207227

208-
heevd_ev = call_heevd(exec_q, jobz_val, uplo_val, n, a, w, host_task_events, depends);
209-
}
210-
else
228+
// Factory passed to dpctl_td_ns::DispatchTableBuilder in init_heevd_dispatch_table().
// get() returns heevd_impl<T, RealT> when
// types::HeevdTypePairSupportFactory<T, RealT>::is_defined holds, and nullptr otherwise.
// NOTE(review): presumably instantiated once per (T, RealT) pair by the builder -
// confirm against the dpctl type-dispatch headers.
template <typename fnT, typename T, typename RealT>
229+
struct HeevdContigFactory
230+
{
231+
fnT get()
211232
{
212-
throw py::value_error("Unexpected types of either eigenvectors or eigenvalues");
233+
if constexpr (types::HeevdTypePairSupportFactory<T, RealT>::is_defined)
234+
{
235+
return heevd_impl<T, RealT>;
236+
}
237+
else
238+
{
239+
return nullptr;
240+
}
213241
}
242+
};
214243

215-
sycl::event args_ev = dpctl::utils::keep_args_alive(exec_q, {eig_vecs, eig_vals}, host_task_events);
216-
return std::make_pair(args_ev, heevd_ev);
244+
// Populates heevd_dispatch_table through a DispatchTableBuilder parameterized
// with HeevdContigFactory. Called from init_dispatch_tables() in lapack_py.cpp.
void init_heevd_dispatch_table(void)
245+
{
246+
dpctl_td_ns::DispatchTableBuilder<heevd_impl_fn_ptr_t, HeevdContigFactory, dpctl_td_ns::num_types> contig;
247+
contig.populate_dispatch_table(heevd_dispatch_table);
217248
}
218249
}
219250
}

dpnp/backend/extensions/lapack/heevd.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ namespace lapack
4545
dpctl::tensor::usm_ndarray eig_vecs,
4646
dpctl::tensor::usm_ndarray eig_vals,
4747
const std::vector<sycl::event>& depends);
48+
49+
extern void init_heevd_dispatch_table(void);
4850
}
4951
}
5052
}

dpnp/backend/extensions/lapack/lapack_py.cpp

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,25 +33,45 @@
3333
#include "heevd.hpp"
3434
#include "syevd.hpp"
3535

36+
namespace lapack_ext = dpnp::backend::ext::lapack;
3637
namespace py = pybind11;
3738

39+
// Populate dispatch vectors: delegates to lapack_ext::init_syevd_dispatch_vector().
// Invoked at the top of PYBIND11_MODULE(_lapack_impl, m), before any binding is registered.
40+
void init_dispatch_vectors(void)
41+
{
42+
lapack_ext::init_syevd_dispatch_vector();
43+
}
44+
45+
// Populate dispatch tables: delegates to lapack_ext::init_heevd_dispatch_table().
// Invoked at the top of PYBIND11_MODULE(_lapack_impl, m), right after init_dispatch_vectors().
46+
void init_dispatch_tables(void)
47+
{
48+
lapack_ext::init_heevd_dispatch_table();
49+
}
50+
3851
// Python extension module `_lapack_impl`.
// First fills the dispatch vectors and tables, then registers two bindings,
// `_heevd` and `_syevd`. Each takes the keyword arguments sycl_queue, jobz,
// upper_lower, eig_vecs, eig_vals, plus an optional `depends` that defaults
// to an empty py::list().
PYBIND11_MODULE(_lapack_impl, m)
3952
{
53+
init_dispatch_vectors();
54+
init_dispatch_tables();
55+
4056
m.def("_heevd",
41-
&dpnp::backend::ext::lapack::heevd,
57+
&lapack_ext::heevd,
4258
"Call `heevd` from OneMKL LAPACK library to return "
4359
"the eigenvalues and eigenvectors of a complex Hermitian matrix",
4460
py::arg("sycl_queue"),
45-
py::arg("jobz"), py::arg("upper_lower"),
46-
py::arg("eig_vecs"), py::arg("eig_vals"),
61+
py::arg("jobz"),
62+
py::arg("upper_lower"),
63+
py::arg("eig_vecs"),
64+
py::arg("eig_vals"),
4765
py::arg("depends") = py::list());
4866

4967
m.def("_syevd",
50-
&dpnp::backend::ext::lapack::syevd,
68+
&lapack_ext::syevd,
5169
"Call `syevd` from OneMKL LAPACK library to return "
5270
"the eigenvalues and eigenvectors of a real symmetric matrix",
5371
py::arg("sycl_queue"),
54-
py::arg("jobz"), py::arg("upper_lower"),
55-
py::arg("eig_vecs"), py::arg("eig_vals"),
72+
py::arg("jobz"),
73+
py::arg("upper_lower"),
74+
py::arg("eig_vecs"),
75+
py::arg("eig_vals"),
5676
py::arg("depends") = py::list());
5777
}

0 commit comments

Comments
 (0)