Skip to content

Commit 3f8f4f3

Browse files
[Lapack][MKL-backend] Fix LAPACK specific exceptions (#663)
1 parent 08da10e commit 3f8f4f3

File tree

3 files changed

+633
-562
lines changed

3 files changed

+633
-562
lines changed

include/oneapi/math/lapack/exceptions.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ class exception {
4848
class computation_error : public oneapi::math::computation_error,
4949
public oneapi::math::lapack::exception {
5050
public:
51+
computation_error(const std::string& message, std::int64_t code)
52+
: oneapi::math::computation_error(message),
53+
oneapi::math::lapack::exception(this, code) {}
5154
computation_error(const std::string& function, const std::string& info, std::int64_t code)
5255
: oneapi::math::computation_error("LAPACK", function, info),
5356
oneapi::math::lapack::exception(this, code) {}
@@ -56,6 +59,12 @@ class computation_error : public oneapi::math::computation_error,
5659

5760
class batch_error : public oneapi::math::batch_error, public oneapi::math::lapack::exception {
5861
public:
62+
batch_error(const std::string& message, std::int64_t num_errors,
63+
std::vector<std::int64_t> ids = {}, std::vector<std::exception_ptr> exceptions = {})
64+
: oneapi::math::batch_error(message),
65+
oneapi::math::lapack::exception(this, num_errors),
66+
_ids(ids),
67+
_exceptions(exceptions) {}
5968
batch_error(const std::string& function, const std::string& info, std::int64_t num_errors,
6069
std::vector<std::int64_t> ids = {}, std::vector<std::exception_ptr> exceptions = {})
6170
: oneapi::math::batch_error("LAPACK", function, info),
@@ -78,6 +87,10 @@ class batch_error : public oneapi::math::batch_error, public oneapi::math::lapac
7887
class invalid_argument : public oneapi::math::invalid_argument,
7988
public oneapi::math::lapack::exception {
8089
public:
90+
invalid_argument(const std::string& message, std::int64_t arg_position = 0,
91+
std::int64_t detail = 0)
92+
: oneapi::math::invalid_argument(message),
93+
oneapi::math::lapack::exception(this, arg_position, detail) {}
8194
invalid_argument(const std::string& function, const std::string& info,
8295
std::int64_t arg_position = 0, std::int64_t detail = 0)
8396
: oneapi::math::invalid_argument("LAPACK", function, info),

src/include/common_onemkl_conversion.hpp

Lines changed: 76 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
#include "oneapi/math/types.hpp"
3030
#include "oneapi/math/exceptions.hpp"
31+
#include "oneapi/math/lapack/exceptions.hpp"
3132

3233
namespace oneapi {
3334
namespace math {
@@ -105,48 +106,88 @@ inline auto get_onemkl_order(oneapi::math::order param) {
105106
return *reinterpret_cast<oneapi::mkl::order*>(&param);
106107
}
107108

109+
} // namespace detail
110+
} // namespace math
111+
} // namespace oneapi
112+
108113
// Rethrow Intel(R) oneMKL exceptions as oneMath exceptions
109-
#define RETHROW_ONEMKL_EXCEPTIONS(EXPRESSION) \
110-
do { \
111-
try { \
112-
EXPRESSION; \
113-
} \
114-
catch (const oneapi::mkl::unsupported_device& e) { \
115-
throw unsupported_device(e.what()); \
116-
} \
117-
catch (const oneapi::mkl::host_bad_alloc& e) { \
118-
throw host_bad_alloc(e.what()); \
119-
} \
120-
catch (const oneapi::mkl::device_bad_alloc& e) { \
121-
throw device_bad_alloc(e.what()); \
122-
} \
123-
catch (const oneapi::mkl::unimplemented& e) { \
124-
throw unimplemented(e.what()); \
125-
} \
126-
catch (const oneapi::mkl::invalid_argument& e) { \
127-
throw invalid_argument(e.what()); \
128-
} \
129-
catch (const oneapi::mkl::uninitialized& e) { \
130-
throw uninitialized(e.what()); \
131-
} \
132-
catch (const oneapi::mkl::computation_error& e) { \
133-
throw computation_error(e.what()); \
134-
} \
135-
catch (const oneapi::mkl::batch_error& e) { \
136-
throw batch_error(e.what()); \
137-
} \
138-
catch (const oneapi::mkl::exception& e) { \
139-
throw exception(e.what()); \
140-
} \
114+
#define RETHROW_ONEMKL_EXCEPTIONS(EXPRESSION) \
115+
do { \
116+
try { \
117+
EXPRESSION; \
118+
} \
119+
catch (const oneapi::mkl::unsupported_device& e) { \
120+
throw oneapi::math::unsupported_device(e.what()); \
121+
} \
122+
catch (const oneapi::mkl::host_bad_alloc& e) { \
123+
throw oneapi::math::host_bad_alloc(e.what()); \
124+
} \
125+
catch (const oneapi::mkl::device_bad_alloc& e) { \
126+
throw oneapi::math::device_bad_alloc(e.what()); \
127+
} \
128+
catch (const oneapi::mkl::unimplemented& e) { \
129+
throw oneapi::math::unimplemented(e.what()); \
130+
} \
131+
catch (const oneapi::mkl::invalid_argument& e) { \
132+
throw oneapi::math::invalid_argument(e.what()); \
133+
} \
134+
catch (const oneapi::mkl::uninitialized& e) { \
135+
throw oneapi::math::uninitialized(e.what()); \
136+
} \
137+
catch (const oneapi::mkl::computation_error& e) { \
138+
throw oneapi::math::computation_error(e.what()); \
139+
} \
140+
catch (const oneapi::mkl::batch_error& e) { \
141+
throw oneapi::math::batch_error(e.what()); \
142+
} \
143+
catch (const oneapi::mkl::exception& e) { \
144+
throw oneapi::math::exception(e.what()); \
145+
} \
141146
} while (0)
142147

143148
#define RETHROW_ONEMKL_EXCEPTIONS_RET(EXPRESSION) \
144149
do { \
145150
RETHROW_ONEMKL_EXCEPTIONS(return EXPRESSION); \
146151
} while (0)
147152

148-
} // namespace detail
149-
} // namespace math
150-
} // namespace oneapi
153+
// Rethrow Intel(R) oneMKL LAPCK exceptions as oneMath LAPACK exceptions
154+
#define RETHROW_ONEMKL_LAPACK_EXCEPTIONS(EXPRESSION) \
155+
do { \
156+
try { \
157+
EXPRESSION; \
158+
} \
159+
catch (const oneapi::mkl::unsupported_device& e) { \
160+
throw oneapi::math::unsupported_device(e.what()); \
161+
} \
162+
catch (const oneapi::mkl::host_bad_alloc& e) { \
163+
throw oneapi::math::host_bad_alloc(e.what()); \
164+
} \
165+
catch (const oneapi::mkl::device_bad_alloc& e) { \
166+
throw oneapi::math::device_bad_alloc(e.what()); \
167+
} \
168+
catch (const oneapi::mkl::unimplemented& e) { \
169+
throw oneapi::math::unimplemented(e.what()); \
170+
} \
171+
catch (const oneapi::mkl::uninitialized& e) { \
172+
throw oneapi::math::uninitialized(e.what()); \
173+
} \
174+
catch (const oneapi::mkl::lapack::invalid_argument& e) { \
175+
throw oneapi::math::lapack::invalid_argument(e.what(), e.info(), e.detail()); \
176+
} \
177+
catch (const oneapi::mkl::lapack::computation_error& e) { \
178+
throw oneapi::math::lapack::computation_error(e.what(), e.info()); \
179+
} \
180+
catch (const oneapi::mkl::lapack::batch_error& e) { \
181+
throw oneapi::math::lapack::batch_error(e.what(), e.info(), e.ids(), e.exceptions()); \
182+
} \
183+
catch (const oneapi::mkl::exception& e) { \
184+
throw oneapi::math::exception(e.what()); \
185+
} \
186+
} while (0)
187+
188+
#define RETHROW_ONEMKL_LAPACK_EXCEPTIONS_RET(EXPRESSION) \
189+
do { \
190+
RETHROW_ONEMKL_LAPACK_EXCEPTIONS(return EXPRESSION); \
191+
} while (0)
151192

152193
#endif // _ONEMATH_SRC_INCLUDE_COMMON_ONEMKL_TYPES_CONVERSION_HPP_

0 commit comments

Comments
 (0)