cuda::std::complex specializations for half and bfloat #1140

Merged
merged 37 commits into from
Mar 12, 2024

griwes
Collaborator

@griwes griwes commented Nov 21, 2023

Description

Resolves #1139

Introduce specializations of complex<T> for half and bfloat.

Checklist

  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.

Additional checklist

  • The documentation contains the actual release this will be made available in.

@griwes griwes requested review from a team as code owners November 21, 2023 23:42
@griwes griwes requested review from ericniebler and alliepiper and removed request for a team November 21, 2023 23:42
@griwes griwes marked this pull request as draft November 21, 2023 23:45
@griwes griwes force-pushed the feature/small-complex branch 2 times, most recently from 65e6f36 to 744f2d1 Compare November 22, 2023 19:57
Collaborator

@miscco miscco left a comment

That is a great job working around the quirks of those types 👏

I would love to move some of the traits around (e.g. into is_floating_point.h) and importantly add a proper named define that one can grep for.

Collaborator

@gonzalobg gonzalobg left a comment

LGTM in general, thanks for working on this @griwes !
I think I missed some static_asserts for the size and alignment of complex half and bfloat, do we have these somewhere? Thanks!
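
A minimal sketch of the kind of layout check being asked about here, written in plain C++ so it compiles without the CUDA toolkit. `half_stub` and `small_complex` are hypothetical stand-ins for `__half` and the new `complex<__half>` specialization. The 4-byte alignment is an assumption about the intended layout (a packed pair of 16-bit values), not something this thread confirms.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical 16-bit stand-in for __half / __nv_bfloat16 (storage only).
struct half_stub {
  std::uint16_t bits;
};

// Hypothetical stand-in for the complex specialization: two 16-bit parts,
// over-aligned so the pair can be loaded as a single 32-bit word.
template <class T>
struct alignas(2 * alignof(T)) small_complex {
  T re;
  T im;
};

static_assert(sizeof(small_complex<half_stub>) == 2 * sizeof(half_stub),
              "complex of a 16-bit type should have no padding");
static_assert(alignof(small_complex<half_stub>) == 2 * alignof(half_stub),
              "complex of a 16-bit type should be aligned as a packed pair");
static_assert(offsetof(small_complex<half_stub>, im) == sizeof(half_stub),
              "imaginary part should directly follow the real part");
```

In the real headers the same assertions would presumably be written against `cuda::std::complex<__half>` and `cuda::std::complex<__nv_bfloat16>` inside a test file.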

@griwes griwes force-pushed the feature/small-complex branch 3 times, most recently from c2d87c2 to add3d52 Compare January 27, 2024 05:52
Collaborator

@miscco miscco left a comment

I am wondering whether we should just keep all the _LIBCUDACXX_HAS_NO_NVFP16 guards in place and define the macro conditionally for host.

Specifically:
* disable BF16 when FP16 is disabled, since the former includes the
  latter;
* disable both when the toolkit version is lower than 12.2, since 12.2
  is when both types got the host versions of a lot of functions we need
  to make useful heterogeneous things with them;
* disable both in host-only TU, as there's no easy way I could find to
  detect the condition above. I've included an opt-in macro for
  asserting that the headers (if available) are from a sufficiently new
  CTK, will add that to docs in a later commit.
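
The three conditions above can be sketched as a preprocessor cascade. This is only an illustration: every `SKETCH_*` macro name is hypothetical, including the opt-in macro, whose real name is not given in this thread. `__CUDACC__`, `__CUDACC_VER_MAJOR__`, and `__CUDACC_VER_MINOR__` are the predefined NVCC macros.

```cpp
// Hypothetical macro names (SKETCH_*) used for illustration only; the real
// libcu++ configuration macros may be spelled and structured differently.

#if !defined(__CUDACC__) && !defined(SKETCH_HOST_HEADERS_ARE_NEW_ENOUGH)
// Host-only TU without the opt-in: the toolkit version cannot be checked.
#  define SKETCH_NO_FP16
#elif defined(__CUDACC_VER_MAJOR__) && \
    (__CUDACC_VER_MAJOR__ * 100 + __CUDACC_VER_MINOR__ < 1202)
// Toolkit older than 12.2: host versions of the needed functions are missing.
#  define SKETCH_NO_FP16
#endif

// BF16 includes FP16, so it is disabled whenever FP16 is.
#if defined(SKETCH_NO_FP16)
#  define SKETCH_NO_BF16
#endif

// Expose the outcome as constants so it can be inspected from code.
#if defined(SKETCH_NO_FP16)
constexpr bool sketch_has_fp16 = false;
#else
constexpr bool sketch_has_fp16 = true;
#endif

#if defined(SKETCH_NO_BF16)
constexpr bool sketch_has_bf16 = false;
#else
constexpr bool sketch_has_bf16 = true;
#endif
```

Deriving the BF16 switch from the FP16 one keeps the "BF16 implies FP16" rule in a single place, so no combination of compiler and opt-in can enable BF16 alone.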

NVCC is spewing code that makes various versions of clang unhappy about
a deprecated implicit copy constructor of a lambda wrapper, so just work
around that by not using one.

@griwes griwes added the libcu++ For all items related to libcu++ label Feb 27, 2024
@miscco miscco mentioned this pull request Feb 28, 2024
2 tasks
miscco and others added 2 commits March 11, 2024 18:46
@miscco miscco enabled auto-merge (squash) March 11, 2024 17:47
@leofang
Member

leofang commented Mar 11, 2024

Note: As discussed offline, local tests show that at least on sm86/89 we need this patch for performance reasons. I haven't had a chance to test on sm70/80/90, though.

diff --git a/libcudacxx/include/cuda/std/detail/libcxx/include/complex b/libcudacxx/include/cuda/std/detail/libcxx/include/complex
index 3ba249779..416c0e71d 100644
--- a/libcudacxx/include/cuda/std/detail/libcxx/include/complex
+++ b/libcudacxx/include/cuda/std/detail/libcxx/include/complex
@@ -1702,6 +1702,16 @@ atanh(const complex<_Tp>& __x)
     return complex<_Tp>(__constexpr_copysign(__z.real(), __x.real()), __constexpr_copysign(__z.imag(), __x.imag()));
 }
 
+// we add a specialization for fp16 atanh because of performance issues
+template<>
+_LIBCUDACXX_INLINE_VISIBILITY complex<__half>
+atanh(const complex<__half>& __x)
+{
+    complex<float> __temp(__x);
+    __temp = _CUDA_VSTD::atanh(__temp);
+    return complex<__half>(__temp.real(), __temp.imag());
+}
+
 // sinh
 
 template<class _Tp>
@@ -1815,6 +1825,16 @@ atan(const complex<_Tp>& __x)
     return complex<_Tp>(__z.imag(), -__z.real());
 }
 
+// we add a specialization for fp16 atan because of performance issues
+template<>
+_LIBCUDACXX_INLINE_VISIBILITY complex<__half>
+atan(const complex<__half>& __x)
+{
+    complex<float> __temp(__x);
+    __temp = _CUDA_VSTD::atan(__temp);
+    return complex<__half>(__temp.real(), __temp.imag());
+}
+
 // sin
 
 template<class _Tp>
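
The pattern in the patch above (widen to `complex<float>`, call the float implementation, narrow the result) can be sketched in standard C++ without the CUDA headers. `lowp`, `lowp_complex`, and `atanh_via_float` are hypothetical names. The stand-in scalar stores a `float`, so this shows only the control flow and not the actual rounding to half precision.

```cpp
#include <cassert>
#include <cmath>
#include <complex>

// Hypothetical stand-in for a reduced-precision scalar such as __half.
struct lowp {
  float value;
};

// Hypothetical stand-in for complex<__half>.
struct lowp_complex {
  lowp re;
  lowp im;
};

// Promote to complex<float>, run the float implementation, narrow back.
inline lowp_complex atanh_via_float(const lowp_complex& x) {
  std::complex<float> wide(x.re.value, x.im.value);
  wide = std::atanh(wide);
  return lowp_complex{lowp{wide.real()}, lowp{wide.imag()}};
}

// Small sanity check: atanh(0) is 0 in both components.
inline void sanity_check() {
  const lowp_complex zero =
      atanh_via_float(lowp_complex{lowp{0.0f}, lowp{0.0f}});
  assert(zero.re.value == 0.0f && zero.im.value == 0.0f);
}
```

The same wrapper shape would apply to any other function given this treatment; only the call in the middle changes.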

@miscco
Collaborator

miscco commented Mar 12, 2024

@leofang I added some workarounds for asinh, acosh, atanh, and cosh.

@miscco miscco merged commit ae0ee04 into NVIDIA:main Mar 12, 2024
584 checks passed
Labels
libcu++ For all items related to libcu++
Projects
Status: Done
Development

Successfully merging this pull request may close these issues.

Specializations of complex<T> for half and bfloat
6 participants