cuda::std::complex specializations for half and bfloat #1140

Merged: 37 commits, merged Mar 12, 2024. Showing changes from 22 commits.

Commits
e578bfc
Complex for small float types.
griwes Nov 21, 2023
4e81a2a
Add a missing bfloat include to promote.h.
griwes Nov 22, 2023
cc0eeda
Only include under cudacc, constexpr fixes.
griwes Nov 22, 2023
b9b831f
Add tests and special cases for cmath functions for half/bfloat.
griwes Dec 19, 2023
5c2b197
Add an opt-out from including bf16, and respect CUB's opt-out.
griwes Jan 27, 2024
b139717
Detect existence of both fp headers, fix C++11.
griwes Jan 27, 2024
a126862
Silence unused function warnings from cuda_bf16.h in clang.
griwes Jan 27, 2024
205db7e
Rework the detection logic for FP16 and BF16.
griwes Jan 30, 2024
a7e1ec9
Test fixes.
griwes Feb 7, 2024
8121bba
Address review comments.
griwes Feb 7, 2024
d33debb
Use structs instead of extended lambdas in the float tests.
griwes Feb 7, 2024
67cdec9
Use the correct type to silence an msvc warning this time.
griwes Feb 7, 2024
40df11e
Also enable 16-bit complex with clang cuda (with CUDA 12.2+).
griwes Feb 9, 2024
ec0870e
Address the remaining review comment.
griwes Feb 13, 2024
340cce5
Documentation updates.
griwes Feb 13, 2024
f1d6771
Address review comments from miscco
miscco Feb 22, 2024
15ec0fa
Merge branch 'main' into pr/griwes/1140
miscco Feb 22, 2024
b47fa9e
Fix some compilers
miscco Feb 22, 2024
3796172
namespaces...
miscco Feb 23, 2024
1ab3940
Actually make the cmath subheaders work
miscco Feb 23, 2024
0073d2d
Do not mess up namespaces around includes
miscco Feb 23, 2024
a6e2a06
Use proper qualification
miscco Feb 23, 2024
53acac9
Add a reference to the hisinf NVCC bug.
griwes Feb 27, 2024
aae4fa5
Remove no longer needed #ifs.
griwes Feb 27, 2024
3c6f0f9
Merge remote-tracking branch 'origin/main' into feature/small-complex
griwes Feb 27, 2024
ec05b1c
Update the docs to mention the 2.4.0 version
griwes Feb 29, 2024
0a7b850
Update libcudacxx/docs/standard_api/numerics_library/complex.md
miscco Mar 11, 2024
df9a93c
Merge branch 'main' into feature/small-complex
miscco Mar 11, 2024
81172ac
Fix half and bfloat in ptx header
miscco Mar 11, 2024
8d7742f
Merge branch 'main' into pr/griwes/1140
miscco Mar 11, 2024
07c2ab6
Actually define the half / bfloat constructors from float / double
miscco Mar 12, 2024
74c86e8
Add fallbacks for trigonometric functions for half / float
miscco Mar 12, 2024
b0c5a10
Actually reorg the whole half / bfloat organization
miscco Mar 12, 2024
7283372
Merge branch 'main' into feature/small-complex
miscco Mar 12, 2024
fafff52
Add `inline` to the trigonometric specializations
miscco Mar 12, 2024
1b1e449
Add missing host device
miscco Mar 12, 2024
4b01e5b
Drop long double
miscco Mar 12, 2024
13 changes: 13 additions & 0 deletions libcudacxx/docs/standard_api/numerics_library/complex.md
@@ -19,9 +19,22 @@ User-defined floating-point literals must be specified in terms of

## Customizations

### Handling of infinities

Our implementation recovers infinite values during multiplication and division by default. This canonicalization adds significant runtime overhead, so we allow disabling it when it is not desired.

Defining `LIBCUDACXX_ENABLE_SIMPLIFIED_COMPLEX_OPERATIONS` disables canonicalization for both multiplication *and* division.

Defining `LIBCUDACXX_ENABLE_SIMPLIFIED_COMPLEX_MULTIPLICATION` or `LIBCUDACXX_ENABLE_SIMPLIFIED_COMPLEX_DIVISION` disables canonicalization for multiplication or division individually.

### Support for half and bfloat16 (since libcu++ 2.4.0)

Our implementation includes support for the `__half` type from `<cuda_fp16.h>`, when the CUDA toolkit version is at
least 12.2. This is detected automatically when compiling through NVCC. If you are compiling a host-only translation
unit directly with the host compiler, you must define the macro `LIBCUDACXX_ENABLE_HOST_NVFP16` prior to including any
libcu++ headers, and you must ensure that the `<cuda_fp16.h>` header that's found by the compiler comes from a CUDA
toolkit version 12.2 or higher.

Our implementation includes support for the `__nv_bfloat16` type from `<cuda_bf16.h>`, when the conditions for the
support of `__half` are fulfilled, and when `CUB_DISABLE_BF16_SUPPORT` is **not** defined.
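For a host-only translation unit compiled directly with the host compiler, the opt-in might look like the following sketch. It assumes a CUDA 12.2+ toolkit's headers are on the include path and is therefore not buildable without one; the macro ordering is the essential point:

```cpp
// Must be defined before any libcu++ header is included, so that the
// detection logic in <__config> sees it.
#define LIBCUDACXX_ENABLE_HOST_NVFP16

#include <cuda/std/complex>

int main()
{
    // __half comes from <cuda_fp16.h>, pulled in by libcu++.
    cuda::std::complex<__half> z{__half{1.0f}, __half{2.0f}};
    z *= z; // evaluated entirely on the host
}
```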

@@ -84,6 +84,10 @@ set(files
__cuda/barrier.h
__cuda/chrono.h
__cuda/climits_prelude.h
__cuda/cmath_nvbf16.h
__cuda/cmath_nvfp16.h
__cuda/complex_nvbf16.h
__cuda/complex_nvfp16.h
__cuda/cstddef_prelude.h
__cuda/cstdint_prelude.h
__cuda/latch.h
26 changes: 24 additions & 2 deletions libcudacxx/include/cuda/std/detail/libcxx/include/__config
@@ -86,6 +86,9 @@
#if defined(_LIBCUDACXX_CUDACC) && _LIBCUDACXX_CUDACC_VER < 1108000
#define _LIBCUDACXX_CUDACC_BELOW_11_8
#endif // defined(_LIBCUDACXX_CUDACC) && _LIBCUDACXX_CUDACC_VER < 1108000
#if defined(_LIBCUDACXX_CUDACC) && _LIBCUDACXX_CUDACC_VER < 1202000
#define _LIBCUDACXX_CUDACC_BELOW_12_2
#endif // defined(_LIBCUDACXX_CUDACC) && _LIBCUDACXX_CUDACC_VER < 1202000
#if defined(_LIBCUDACXX_CUDACC) && _LIBCUDACXX_CUDACC_VER < 1203000
#define _LIBCUDACXX_CUDACC_BELOW_12_3
#endif // defined(_LIBCUDACXX_CUDACC) && _LIBCUDACXX_CUDACC_VER < 1203000
@@ -1151,8 +1154,27 @@ typedef __char32_t char32_t;
#endif
#endif // _LIBCUDACXX_HAS_NO_LONG_DOUBLE

#ifndef _LIBCUDACXX_HAS_NO_ATTRIBUTE_NO_UNIQUE_ADDRESS
#if __has_cpp_attribute(msvc::no_unique_address)
# ifndef _LIBCUDACXX_HAS_NVFP16
# if __has_include(<cuda_fp16.h>) \
&& defined(__cuda_std__) \
&& (defined(_LIBCUDACXX_COMPILER_CLANG_CUDA) || !defined(_LIBCUDACXX_CUDACC_BELOW_12_2)) \
&& (!defined(_LIBCUDACXX_COMPILER_CLANG_CUDA) || CUDA_VERSION >= 12020) \
&& (defined(_LIBCUDACXX_CUDACC) || defined(LIBCUDACXX_ENABLE_HOST_NVFP16))
# define _LIBCUDACXX_HAS_NVFP16
# endif
# endif // !_LIBCUDACXX_HAS_NVFP16

# ifndef _LIBCUDACXX_HAS_NVBF16
# if __has_include(<cuda_bf16.h>) \
&& defined(__cuda_std__) \
&& defined(_LIBCUDACXX_HAS_NVFP16) \
&& !defined(CUB_DISABLE_BF16_SUPPORT)
# define _LIBCUDACXX_HAS_NVBF16
# endif
# endif // !_LIBCUDACXX_HAS_NVBF16

# ifndef _LIBCUDACXX_HAS_NO_ATTRIBUTE_NO_UNIQUE_ADDRESS
# if __has_cpp_attribute(msvc::no_unique_address)
// MSVC implements [[no_unique_address]] as a silent no-op currently.
// (If/when MSVC breaks its C++ ABI, it will be changed to work as intended.)
// However, MSVC implements [[msvc::no_unique_address]] which does what
@@ -0,0 +1,162 @@
// -*- C++ -*-
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef _LIBCUDACXX___CUDA_CMATH_NVBF16_H
#define _LIBCUDACXX___CUDA_CMATH_NVBF16_H

#ifndef __cuda_std__
# include <config>
#endif // __cuda_std__

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#if defined(_LIBCUDACXX_HAS_NVBF16)

_CCCL_DIAG_PUSH
_CCCL_DIAG_SUPPRESS_CLANG("-Wunused-function")
# include <cuda_bf16.h>
_CCCL_DIAG_POP

# include <nv/target>

# include "../__type_traits/integral_constant.h"
# include "../cmath"

_LIBCUDACXX_BEGIN_NAMESPACE_STD

// trigonometric and other transcendental functions
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 sin(__nv_bfloat16 __v)
{
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hsin(__v);), (return __nv_bfloat16(::sin(float(__v)));))
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 sinh(__nv_bfloat16 __v)
{
return __nv_bfloat16(::sinh(float(__v)));
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 cos(__nv_bfloat16 __v)
{
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hcos(__v);), (return __nv_bfloat16(::cos(float(__v)));))
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 cosh(__nv_bfloat16 __v)
{
return __nv_bfloat16(::cosh(float(__v)));
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 exp(__nv_bfloat16 __v)
{
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hexp(__v);), (return __nv_bfloat16(::exp(float(__v)));))
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 hypot(__nv_bfloat16 __x, __nv_bfloat16 __y)
{
return __nv_bfloat16(::hypot(float(__x), float(__y)));
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 atan2(__nv_bfloat16 __x, __nv_bfloat16 __y)
{
return __nv_bfloat16(::atan2(float(__x), float(__y)));
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 log(__nv_bfloat16 __x)
{
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hlog(__x);), (return __nv_bfloat16(::log(float(__x)));))
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 sqrt(__nv_bfloat16 __x)
{
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hsqrt(__x);), (return __nv_bfloat16(::sqrt(float(__x)));))
}

// floating point helper
inline _LIBCUDACXX_INLINE_VISIBILITY bool signbit(__nv_bfloat16 __v)
{
return ::signbit(::__bfloat162float(__v));
}

inline _LIBCUDACXX_INLINE_VISIBILITY bool __constexpr_isnan(__nv_bfloat16 __x) noexcept
{
return ::__hisnan(__x);
}

inline _LIBCUDACXX_INLINE_VISIBILITY bool isnan(__nv_bfloat16 __v)
{
return __constexpr_isnan(__v);
}

inline _LIBCUDACXX_INLINE_VISIBILITY bool __constexpr_isinf(__nv_bfloat16 __x) noexcept
{
# if _CCCL_STD_VER >= 2020
// Workaround: ::__hisinf misbehaves under C++20 with NVCC.
// XXX nvbug number pending
// Detect infinity arithmetically instead: only an infinity is itself
// non-NaN yet produces NaN under __x - __x.
return !::__hisnan(__x) && ::__hisnan(__x - __x);
# else // ^^^ C++20 ^^^ / vvv C++17 vvv
return ::__hisinf(__x) != 0;
# endif // _CCCL_STD_VER <= 2017
}

inline _LIBCUDACXX_INLINE_VISIBILITY bool isinf(__nv_bfloat16 __v)
{
return __constexpr_isinf(__v);
}

inline _LIBCUDACXX_INLINE_VISIBILITY bool __constexpr_isfinite(__nv_bfloat16 __x) noexcept
{
return !__constexpr_isnan(__x) && !__constexpr_isinf(__x);
}

inline _LIBCUDACXX_INLINE_VISIBILITY bool isfinite(__nv_bfloat16 __v)
{
return __constexpr_isfinite(__v);
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 __constexpr_copysign(__nv_bfloat16 __x, __nv_bfloat16 __y) noexcept
{
return __nv_bfloat16(::copysignf(float(__x), float(__y)));
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 copysign(__nv_bfloat16 __x, __nv_bfloat16 __y)
{
return __constexpr_copysign(__x, __y);
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 __constexpr_fabs(__nv_bfloat16 __x) noexcept
{
return ::__habs(__x);
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 fabs(__nv_bfloat16 __x)
{
return __constexpr_fabs(__x);
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 abs(__nv_bfloat16 __x)
{
return __constexpr_fabs(__x);
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 __constexpr_fmax(__nv_bfloat16 __x, __nv_bfloat16 __y) noexcept
{
return ::__hmax(__x, __y);
}

_LIBCUDACXX_END_NAMESPACE_STD

#endif // _LIBCUDACXX_HAS_NVBF16

#endif // _LIBCUDACXX___CUDA_CMATH_NVBF16_H