Skip to content

Commit

Permalink
Only include under cudacc, constexpr fixes.
Browse files (browse the repository at this point in the history)
  • Loading branch information
griwes committed Nov 22, 2023
1 parent 8b13b12 commit 65e6f36
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 35 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -19,7 +19,7 @@
#include "../__utility/declval.h"
#include "../cstddef"

#ifdef __cuda_std__
#if defined(__cuda_std__) && defined(_LIBCUDACXX_CUDACC)
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#endif
Expand All @@ -38,7 +38,7 @@ template <class _Tp>
struct __numeric_type
{
_LIBCUDACXX_INLINE_VISIBILITY static void __test(...);
#ifdef __cuda_std__
#if defined(__cuda_std__) && defined(_LIBCUDACXX_CUDACC)
_LIBCUDACXX_INLINE_VISIBILITY static __half __test(__half);
_LIBCUDACXX_INLINE_VISIBILITY static __nv_bfloat16 __test(__nv_bfloat16);
#endif
Expand Down
63 changes: 30 additions & 33 deletions libcudacxx/include/cuda/std/detail/libcxx/include/complex
Original file line number | Diff line number | Diff line change
Expand Up @@ -424,19 +424,19 @@ class _LIBCUDACXX_TEMPLATE_VIS _LIBCUDACXX_COMPLEX_ALIGNAS(alignof(__half2)) com
public:
typedef __half value_type;

_LIBCUDACXX_INLINE_VISIBILITY constexpr complex(__half __re = 0.0f, __half __im = 0.0f)
_LIBCUDACXX_INLINE_VISIBILITY complex(__half __re = 0.0f, __half __im = 0.0f)
: __repr(__re, __im) {}
template <class _Int, typename = __enable_if_t<is_arithmetic<_Int>::value>>
_LIBCUDACXX_INLINE_VISIBILITY
explicit constexpr complex(_Int __re = _Int(), _Int __im = _Int())
explicit complex(_Int __re = _Int(), _Int __im = _Int())
: __repr(__re, __im) {}
_LIBCUDACXX_INLINE_VISIBILITY
explicit constexpr complex(const complex<float>& __c);
explicit complex(const complex<float>& __c);
_LIBCUDACXX_INLINE_VISIBILITY
explicit constexpr complex(const complex<double>& __c);
explicit complex(const complex<double>& __c);
#ifdef _LIBCUDACXX_HAS_COMPLEX_LONG_DOUBLE
_LIBCUDACXX_INLINE_VISIBILITY
explicit constexpr complex(const complex<long double>& __c);
explicit complex(const complex<long double>& __c);
#endif // _LIBCUDACXX_HAS_COMPLEX_LONG_DOUBLE

#if defined(__cuda_std__) && !defined(_LIBCUDACXX_COMPILER_NVRTC)
Expand All @@ -455,11 +455,11 @@ public:
}
#endif // defined(__cuda_std__) && !defined(_LIBCUDACXX_COMPILER_NVRTC)

_LIBCUDACXX_INLINE_VISIBILITY constexpr __half real() const {return __repr.x;}
_LIBCUDACXX_INLINE_VISIBILITY constexpr __half imag() const {return __repr.y;}
_LIBCUDACXX_INLINE_VISIBILITY __half real() const {return __repr.x;}
_LIBCUDACXX_INLINE_VISIBILITY __half imag() const {return __repr.y;}

_LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 void real(value_type __re) {__repr.x = __re;}
_LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 void imag(value_type __im) {__repr.y = __im;}
_LIBCUDACXX_INLINE_VISIBILITY void real(value_type __re) {__repr.x = __re;}
_LIBCUDACXX_INLINE_VISIBILITY void imag(value_type __im) {__repr.y = __im;}

_LIBCUDACXX_INLINE_VISIBILITY
complex& operator= (__half __re) { __repr.x = __re; __repr.y = value_type(); return *this;}
Expand Down Expand Up @@ -524,19 +524,19 @@ class _LIBCUDACXX_TEMPLATE_VIS _LIBCUDACXX_COMPLEX_ALIGNAS(alignof(__nv_bfloat16
public:
typedef __nv_bfloat16 value_type;

_LIBCUDACXX_INLINE_VISIBILITY constexpr complex(__nv_bfloat16 __re = 0.0f, __nv_bfloat16 __im = 0.0f)
_LIBCUDACXX_INLINE_VISIBILITY complex(__nv_bfloat16 __re = 0.0f, __nv_bfloat16 __im = 0.0f)
: __repr(__re, __im) {}
template <class _Int, typename = __enable_if_t<is_arithmetic<_Int>::value>>
_LIBCUDACXX_INLINE_VISIBILITY
explicit constexpr complex(_Int __re = _Int(), _Int __im = _Int())
explicit complex(_Int __re = _Int(), _Int __im = _Int())
: __repr(__re, __im) {}
_LIBCUDACXX_INLINE_VISIBILITY
explicit constexpr complex(const complex<float>& __c);
explicit complex(const complex<float>& __c);
_LIBCUDACXX_INLINE_VISIBILITY
explicit constexpr complex(const complex<double>& __c);
explicit complex(const complex<double>& __c);
#ifdef _LIBCUDACXX_HAS_COMPLEX_LONG_DOUBLE
_LIBCUDACXX_INLINE_VISIBILITY
explicit constexpr complex(const complex<long double>& __c);
explicit complex(const complex<long double>& __c);
#endif // _LIBCUDACXX_HAS_COMPLEX_LONG_DOUBLE

#if defined(__cuda_std__) && !defined(_LIBCUDACXX_COMPILER_NVRTC)
Expand All @@ -555,11 +555,11 @@ public:
}
#endif // defined(__cuda_std__) && !defined(_LIBCUDACXX_COMPILER_NVRTC)

_LIBCUDACXX_INLINE_VISIBILITY constexpr __nv_bfloat16 real() const {return __repr.x;}
_LIBCUDACXX_INLINE_VISIBILITY constexpr __nv_bfloat16 imag() const {return __repr.y;}
_LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 real() const {return __repr.x;}
_LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 imag() const {return __repr.y;}

_LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 void real(value_type __re) {__repr.x = __re;}
_LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 void imag(value_type __im) {__repr.y = __im;}
_LIBCUDACXX_INLINE_VISIBILITY void real(value_type __re) {__repr.x = __re;}
_LIBCUDACXX_INLINE_VISIBILITY void imag(value_type __im) {__repr.y = __im;}

_LIBCUDACXX_INLINE_VISIBILITY
complex& operator= (__nv_bfloat16 __re) { __repr.x = __re; __repr.y = value_type(); return *this;}
Expand Down Expand Up @@ -1041,27 +1041,25 @@ _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11_COMPLEX
__abcd_results<_Tp>
__complex_calculate_partials(_Tp __a, _Tp __b, _Tp __c, _Tp __d)
{
__abcd_results<_Tp> __ret;

__ret.__ac = __a * __c;
__ret.__bd = __b * __d;
__ret.__ad = __a * __d;
__ret.__bc = __b * __c;

return __ret;
return {
__a * __c,
__b * __d,
__a * __d,
__b * __c
};
}

// Compile-time check of __complex_calculate_partials: for the arguments
// (0, 1, 2, 3) the __bd member (__b * __d, i.e. 1 * 3) must equal 3.
// NOTE(review): this only compiles where __complex_calculate_partials can be
// evaluated in a constant expression. Whether
// _LIBCUDACXX_CONSTEXPR_AFTER_CXX11_COMPLEX expands to constexpr in every
// supported configuration is not visible here -- confirm, and confirm that
// this check is intended to stay in the header.
static_assert(__complex_calculate_partials(0, 1, 2, 3).__bd == 3, "");

// Multiplies the arguments pairwise and returns both products in an
// __ab_results<_Tp>: __x1 * __x2 and __y1 * __y2.
//
// NOTE(review): this listing shows two bodies back to back. Going by the hunk
// header (-1041,27 +1041,25), the named-temporary form appears to be the
// pre-commit text and the braced return its replacement -- confirm against the
// actual header before relying on either.
template<class _Tp>
_LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11_COMPLEX
__ab_results<_Tp>
__complex_piecewise_mul(_Tp __x1, _Tp __y1, _Tp __x2, _Tp __y2)
{
// Form 1: default-construct the result, then assign each member by name.
__ab_results<_Tp> __ret;

__ret.__a = __x1 * __x2; // product of the two "x" arguments
__ret.__b = __y1 * __y2; // product of the two "y" arguments

return __ret;
// Form 2: build the result directly from a braced initializer list.
// Presumably the initializers map to __a then __b; the declaration of
// __ab_results is not visible here -- verify the member order.
return {
__x1 * __x2,
__y1 * __y2
};
}

#ifdef __cuda_std__
Expand Down Expand Up @@ -1313,7 +1311,6 @@ operator/(const complex<_Tp>& __z, const complex<_Tp>& __w)
_CUDA_VSTD::__constexpr_isfinite(__d)) {
__a = _CUDA_VSTD::__constexpr_copysign(_CUDA_VSTD::__constexpr_isinf(__a) ? _Tp(1) : _Tp(0), __a);
__b = _CUDA_VSTD::__constexpr_copysign(_CUDA_VSTD::__constexpr_isinf(__b) ? _Tp(1) : _Tp(0), __b);
__partials = __complex_calculate_partials(__a, __b, __c, __d);
__x = _Tp(INFINITY) * (__a * __c + __b * __d);
__y = _Tp(INFINITY) * (__b * __c - __a * __d);
} else if (_CUDA_VSTD::__constexpr_isinf(__logbw) && __logbw > _Tp(0) && _CUDA_VSTD::__constexpr_isfinite(__a) &&
Expand Down

0 comments on commit 65e6f36

Please sign in to comment.