Skip to content

Commit

Permalink
Refactor thrust::complex as a struct derived from `cuda::std::compl…
Browse files Browse the repository at this point in the history
…ex` (#454)

* Replace `thrust::complex` with `std::complex`

There are some notable differences though. thrust::complex has been a
bit more lenient when determining the type of arithmetic operations.

That said, I believe being more strict is actually a feature not a bug

* Refactor thrust::complex as a struct derived from cuda::std::complex

This commit refactors the thrust::complex type to be a struct derived
from cuda::std::complex, enabling reuse of existing implementation
logic. However, to maintain backward compatibility, certain
operators are reintroduced to allow type promotion between 'float'
and 'double' for the underlying type.

* Make `thrust::complex` compile

* Remove obsolete test

* Fix complex build for gcc-12

* Use template evaluation short circuit for `complex` assignment operator.

* [skip-tests] Update the license after complete reimplementation of complex

---------

Co-authored-by: Michael Schellenberger Costa <miscco@nvidia.com>
  • Loading branch information
Blonck and miscco committed Sep 26, 2023
1 parent 69af06d commit 1f6e4bc
Show file tree
Hide file tree
Showing 29 changed files with 552 additions and 5,549 deletions.

This file was deleted.

80 changes: 65 additions & 15 deletions libcudacxx/include/cuda/std/detail/libcxx/include/complex
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,6 @@ template<class T> complex<T> tanh (const complex<T>&);
#ifndef __cuda_std__
#include <__config>
#include <stdexcept>
#if !defined(_LIBCUDACXX_HAS_NO_LOCALIZATION)
# include <sstream> // for _CUDA_VSTD::basic_ostringstream
#endif
#endif // __cuda_std__

#include "__assert" // all public C++ headers provide the assertion handler
Expand All @@ -253,6 +250,11 @@ template<class T> complex<T> tanh (const complex<T>&);
#include "type_traits"
#include "version"

#if !defined(_LIBCUDACXX_HAS_NO_LOCALIZATION) \
&& !defined(_LIBCUDACXX_COMPILER_NVRTC)
#include <sstream> // for std::basic_ostringstream
#endif // !_LIBCUDACXX_HAS_NO_LOCALIZATION && !_LIBCUDACXX_COMPILER_NVRTC

// Compatability helpers for thrust to convert between `std::complex` and `cuda::std::complex`
#if defined(__cuda_std__) && !defined(_LIBCUDACXX_COMPILER_NVRTC) && !defined(_LIBCUDACXX_COMPILER_MSVC)
#include <complex>
Expand Down Expand Up @@ -407,8 +409,10 @@ public:
: __re_(__re), __im_(__im) {}
_LIBCUDACXX_INLINE_VISIBILITY
explicit constexpr complex(const complex<double>& __c);
#ifdef _LIBCUDACXX_HAS_COMPLEX_LONG_DOUBLE
_LIBCUDACXX_INLINE_VISIBILITY
explicit constexpr complex(const complex<long double>& __c);
#endif // _LIBCUDACXX_HAS_COMPLEX_LONG_DOUBLE

#if defined(__cuda_std__) && !defined(_LIBCUDACXX_COMPILER_NVRTC) && !defined(_LIBCUDACXX_COMPILER_MSVC)
template <class _Up>
Expand Down Expand Up @@ -502,8 +506,11 @@ public:
: __re_(__re), __im_(__im) {}
_LIBCUDACXX_INLINE_VISIBILITY
constexpr complex(const complex<float>& __c);

#ifdef _LIBCUDACXX_HAS_COMPLEX_LONG_DOUBLE
_LIBCUDACXX_INLINE_VISIBILITY
explicit constexpr complex(const complex<long double>& __c);
#endif //_LIBCUDACXX_HAS_COMPLEX_LONG_DOUBLE

#if defined(__cuda_std__) && !defined(_LIBCUDACXX_COMPILER_NVRTC) && !defined(_LIBCUDACXX_COMPILER_MSVC)
template <class _Up>
Expand Down Expand Up @@ -585,20 +592,10 @@ public:
}
};

#ifdef _LIBCUDACXX_HAS_COMPLEX_LONG_DOUBLE
template<>
class _LIBCUDACXX_TEMPLATE_VIS _LIBCUDACXX_COMPLEX_ALIGNAS(2*sizeof(long double)) complex<long double>
{
#ifndef _LIBCUDACXX_HAS_COMPLEX_LONG_DOUBLE
public:
template <typename _Dummy = void>
_LIBCUDACXX_INLINE_VISIBILITY constexpr complex(long double __re = 0.0, long double __im = 0.0)
{static_assert(is_same<_Dummy, void>::value, "complex<long double> is not supported");}

template <typename _Tp, typename _Dummy = void>
_LIBCUDACXX_INLINE_VISIBILITY constexpr complex(const complex<_Tp> &__c)
{static_assert(is_same<_Dummy, void>::value, "complex<long double> is not supported");}

#else
long double __re_;
long double __im_;
public:
Expand Down Expand Up @@ -689,8 +686,8 @@ public:
*this = *this / complex(__c.real(), __c.imag());
return *this;
}
#endif // _LIBCUDACXX_HAS_COMPLEX_LONG_DOUBLE
};
#endif // _LIBCUDACXX_HAS_COMPLEX_LONG_DOUBLE

#if defined(_LIBCUDACXX_USE_PRAGMA_MSVC_WARNING)
// MSVC complains about narrowing conversions on these copy constructors regardless if they are used
Expand Down Expand Up @@ -1191,6 +1188,7 @@ arg(const complex<_Tp>& __c)
return _CUDA_VSTD::atan2(__c.imag(), __c.real());
}

#ifdef _LIBCUDACXX_HAS_COMPLEX_LONG_DOUBLE
template <class _Tp>
inline _LIBCUDACXX_INLINE_VISIBILITY
__enable_if_t<
Expand All @@ -1201,6 +1199,7 @@ arg(_Tp __re)
{
return _CUDA_VSTD::atan2l(0.L, __re);
}
#endif // _LIBCUDACXX_HAS_COMPLEX_LONG_DOUBLE

template<class _Tp>
inline _LIBCUDACXX_INLINE_VISIBILITY
Expand Down Expand Up @@ -1775,6 +1774,57 @@ operator<<(basic_ostream<_CharT, _Traits>& __os, const complex<_Tp>& __x)
return __os << __s.str();
}
#endif // !_LIBCUDACXX_HAS_NO_LOCALIZATION
#else // ^^^ !__cuda_std__ ^^^ / vvv __cuda_std__
#ifndef _LIBCUDACXX_COMPILER_NVRTC
template<typename ValueType,class charT, class traits>
::std::basic_ostream<charT, traits>& operator<<(::std::basic_ostream<charT, traits>& os, const complex<ValueType>& z)
{
os << '(' << z.real() << ',' << z.imag() << ')';
return os;
}

template<typename ValueType, typename charT, class traits>
::std::basic_istream<charT, traits>&
operator>>(::std::basic_istream<charT, traits>& is, complex<ValueType>& z)
{
ValueType re, im;

charT ch;
is >> ch;

if(ch == '(')
{
is >> re >> ch;
if (ch == ',')
{
is >> im >> ch;
if (ch == ')')
{
z = complex<ValueType>(re, im);
}
else
{
is.setstate(::std::ios_base::failbit);
}
}
else if (ch == ')')
{
z = re;
}
else
{
is.setstate(::std::ios_base::failbit);
}
}
else
{
is.putback(ch);
is >> re;
z = re;
}
return is;
}
#endif // _LIBCUDACXX_COMPILER_NVRTC
#endif // __cuda_std__

#if _LIBCUDACXX_STD_VER > 11 && defined(_LIBCUDACXX_HAS_STL_LITERALS)
Expand Down
12 changes: 12 additions & 0 deletions thrust/testing/complex.cu
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,18 @@ struct TestComplexSizeAndAlignment
};
SimpleUnitTest<TestComplexSizeAndAlignment, FloatingPointTypes> TestComplexSizeAndAlignmentInstance;

template <typename T>
struct TestComplexTypeCheck
{
void operator()()
{
THRUST_STATIC_ASSERT(thrust::is_complex<thrust::complex<T>>::value);
THRUST_STATIC_ASSERT(thrust::is_complex<std::complex<T>>::value);
THRUST_STATIC_ASSERT(thrust::is_complex<cuda::std::complex<T>>::value);
}
};
SimpleUnitTest<TestComplexTypeCheck, FloatingPointTypes> TestComplexTypeCheckInstance;

template <typename T>
struct TestComplexConstructionAndAssignment
{
Expand Down
29 changes: 6 additions & 23 deletions thrust/testing/unittest/assertions.h
Original file line number Diff line number Diff line change
Expand Up @@ -304,31 +304,14 @@ bool almost_equal(double a, double b, double a_tol, double r_tol)
return true;
}

namespace
{ // anonymous namespace

template <typename>
struct is_complex : public THRUST_NS_QUALIFIER::false_type
{};

template <typename T>
struct is_complex<THRUST_NS_QUALIFIER::complex<T>> : public THRUST_NS_QUALIFIER::true_type
{};

template <typename T>
struct is_complex<std::complex<T>> : public THRUST_NS_QUALIFIER::true_type
{};

} // namespace

template <typename T1, typename T2>
inline
typename THRUST_NS_QUALIFIER::detail::enable_if<is_complex<T1>::value && is_complex<T2>::value,
bool>::type
almost_equal(const T1 &a, const T2 &b, double a_tol, double r_tol)
typename THRUST_NS_QUALIFIER::detail::enable_if<THRUST_NS_QUALIFIER::is_complex<T1>::value &&
THRUST_NS_QUALIFIER::is_complex<T2>::value,
bool>::type
almost_equal(const T1 &a, const T2 &b, double a_tol, double r_tol)
{
return almost_equal(a.real(), b.real(), a_tol, r_tol) &&
almost_equal(a.imag(), b.imag(), a_tol, r_tol);
return almost_equal(a.real(), b.real(), a_tol, r_tol) &&
almost_equal(a.imag(), b.imag(), a_tol, r_tol);
}

template <typename T1, typename T2>
Expand Down

0 comments on commit 1f6e4bc

Please sign in to comment.