Skip to content

Commit

Permalink
Closes gh-1279 for dpt.sqrt
Browse files Browse the repository at this point in the history
This change provides private method csqrt to evaluate square-root
for complex types. It handles special values as mandated by array API.

The finite input, it provides its own implementation based on std::hypot
and std::sqrt for real types instead of calling std::sqrt on finite
input of complex type.

Compile with -DUSE_STD_SQRT_FOR_COMPLEX_TYPES to use std::sqrt instead
of custom implementation.

Cursory performance study suggests that custom implementation is at least
not worse than std::sqrt one.
  • Loading branch information
oleksandr-pavlyk committed Aug 15, 2023
1 parent 5f298e6 commit 4a2578f
Showing 1 changed file with 93 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
#pragma once
#include <CL/sycl.hpp>
#include <cmath>
#include <complex>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <type_traits>

#include "kernels/elementwise_functions/common.hpp"
Expand Down Expand Up @@ -66,7 +68,97 @@ template <typename argT, typename resT> struct SqrtFunctor

resT operator()(const argT &in)
{
return std::sqrt(in);
if constexpr (is_complex<argT>::value) {
// #ifdef _WINDOWS
// return csqrt(in);
// #else
// return std::sqrt(in);
// #endif
return csqrt(in);
}
else {
return std::sqrt(in);
}
}

private:
template <typename T> std::complex<T> csqrt(std::complex<T> const &z) const
{
// csqrt(x + y*1j)
// * csqrt(x - y * 1j) = conj(csqrt(x + y * 1j))
// * If x is either +0 or -0 and y is +0, the result is +0 + 0j.
// * If x is any value (including NaN) and y is +infinity, the result
// is +infinity + infinity j.
// * If x is a finite number and y is NaN, the result is NaN + NaN j.

// * If x -infinity and y is a positive (i.e., greater than 0) finite
// number, the result is NaN + NaN j.
// * If x is +infinity and y is a positive (i.e., greater than 0)
// finite number, the result is +0 + infinity j.
// * If x is -infinity and y is NaN, the result is NaN + infinity j
// (sign of the imaginary component is unspecified).
// * If x is +infinity and y is NaN, the result is +infinity + NaN j.
// * If x is NaN and y is any value, the result is NaN + NaN j.

using realT = T;
constexpr realT q_nan = std::numeric_limits<realT>::quiet_NaN();
constexpr realT p_inf = std::numeric_limits<realT>::infinity();
constexpr realT zero = realT(0);

realT x = std::real(z);
realT y = std::imag(z);

if (std::isinf(y)) {
return {p_inf, y};
}
else if (std::isnan(x)) {
return {x, q_nan};
}
else if (std::isinf(x)) { // x is an infinity
// y is either finite, or nan
if (std::signbit(x)) { // x == -inf
return {(std::isfinite(y) ? zero : y), std::copysign(p_inf, y)};
}
else {
return {p_inf, (std::isfinite(y) ? std::copysign(zero, y) : y)};
}
}
else { // x is finite
if (std::isfinite(y)) {
#ifdef USE_STD_SQRT_FOR_COMPLEX_TYPES
return std::sqrt(z);
#else
return csqrt_finite(x, y);
#endif
}
else {
return {q_nan, y};
}
}
}

template <typename T>
std::complex<T> csqrt_finite(T const &x, T const &y) const
{
// csqrt(x + y*1j) =
// sqrt((cabs(x, y) + x) / 2) +
// 1j * copysign(sqrt((cabs(x, y) - x) / 2), y)

using realT = T;
constexpr realT half = realT(0x1.0p-1f); // 1/2
constexpr realT zero = realT(0);

if (std::signbit(x)) {
realT m = std::hypot(x, y);
realT d = std::sqrt((m - x) * half);
return {(d == zero ? zero : std::abs(y) / d * half),
std::copysign(d, y)};
}
else {
realT m = std::hypot(x, y);
realT d = std::sqrt((m + x) * half);
return {d, (d == zero) ? std::copysign(zero, y) : y * half / d};
}
}
};

Expand Down

0 comments on commit 4a2578f

Please sign in to comment.