diff --git a/.github/workflows/cmake-action.yml b/.github/workflows/cmake-action.yml new file mode 100644 index 0000000..fd621db --- /dev/null +++ b/.github/workflows/cmake-action.yml @@ -0,0 +1,47 @@ +name: CMake + +on: + workflow_call: + inputs: + cuda-version: + required: true + type: string + +env: + # Customize the CMake build type here (Release, Debug, RelWithDebInfo, etc.) + BUILD_TYPE: Debug + +jobs: + build: + # The CMake configure and build commands are platform agnostic and should work equally well on Windows or Mac. + # You can convert this to a matrix build if you need cross-platform coverage. + # See: https://docs.github.com/en/free-pro-team@latest/actions/learn-github-actions/managing-complex-workflows#using-a-build-matrix + runs-on: ubuntu-latest + + steps: + - uses: Jimver/cuda-toolkit@v0.2.11 + id: cuda-toolkit + with: + method: network + sub-packages: '["nvcc"]' + cuda: ${{ inputs.cuda-version }} + + - uses: actions/checkout@v3 + with: + submodules: 'true' + + - name: Configure CMake + # Configure CMake in a 'build' subdirectory. `CMAKE_BUILD_TYPE` is only required if you are using a single-configuration generator such as make. + # See https://cmake.org/cmake/help/latest/variable/CMAKE_BUILD_TYPE.html?highlight=cmake_build_type + run: cmake -B ${{github.workspace}}/build -DCMAKE_BUILD_TYPE=${{env.BUILD_TYPE}} -DKERNEL_FLOAT_BUILD_TEST=1 -DKERNEL_FLOAT_BUILD_EXAMPLE=1 + + - name: Build + # Build your program with the given configuration + run: cmake --build ${{github.workspace}}/build --config ${{env.BUILD_TYPE}} + + - name: Test + working-directory: ${{github.workspace}}/build + # Execute tests defined by the CMake configuration. + # See https://cmake.org/cmake/help/latest/manual/ctest.1.html for more detail + run: ./tests/kernel_float_tests --durations=yes --success --verbosity=high ~[GPU] + diff --git a/.github/workflows/cmake.yml b/.github/workflows/cmake.yml new file mode 100644 index 0000000..136fcd3 --- /dev/null +++ b/.github/workflows/cmake.yml @@ -0,0 +1,28 @@ +name: CMake + +on: + push: + pull_request: + branches: [ "main" ] + +env: + # Customize the CMake build type here (Release, Debug, RelWithDebInfo, etc.) + BUILD_TYPE: Debug + +jobs: + build-cuda: + uses: ./.github/workflows/cmake-action.yml + with: + cuda-version: "12.2.0" + + build-cuda-11-7: + needs: build-cuda + uses: ./.github/workflows/cmake-action.yml + with: + cuda-version: "11.7.0" + + build-cuda-12-0: + needs: build-cuda + uses: ./.github/workflows/cmake-action.yml + with: + cuda-version: "12.0.0" diff --git a/combine.py b/combine.py index c22e2ba..5a7c857 100644 --- a/combine.py +++ b/combine.py @@ -2,6 +2,24 @@ import subprocess from datetime import datetime +license_boilerplate = """/* + * Kernel Float: Header-only library for vector types and reduced precision floating-point math. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +""" + directory = "include/kernel_float" contents = dict() @@ -28,7 +46,8 @@ except Exception as e: print(f"warning: {e}") -output = "\n".join([ +output = license_boilerplate +output += "\n".join([ "//" + "=" * 80, "// this file has been auto-generated, do not modify its contents!", f"// date: {date}", diff --git a/docs/api.rst b/docs/api.rst index e525b1c..85b407a 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -3,8 +3,10 @@ API Reference .. toctree:: api/types.rst api/primitives.rst + api/generation.rst api/unary_operators.rst api/binary_operators.rst api/reductions.rst - api/shuffling.rst api/mathematical.rst + api/conditional.rst + diff --git a/docs/build_api.py b/docs/build_api.py index 178769f..ee15b1e 100644 --- a/docs/build_api.py +++ b/docs/build_api.py @@ -65,51 +65,51 @@ def build_index_page(groups): return filename -aliases = [] -for ty in ["vec", "float", "double", "half", "bfloat16x", ""]: - if ty != "vec": - aliases.append(f"{ty}X") - +aliases = ["scalar", "vec"] +for ty in ["vec"]: for i in range(2, 8 + 1): aliases.append(f"{ty}{i}") groups = { "Types": [ ("vector", "vector", "struct"), - ("Aliases", [ - "unaligned_vec", - "vec", - ] + aliases, - "typedef"), + ("Aliases", aliases, "typedef"), ], "Primitives": [ - ("range", "range()"), - ("range", "range(F)"), "map", "reduce", "zip", "zip_common", "cast", "broadcast", - "resize", - "for_each", - ], - "Shuffling": [ + "convert", + "make_vec", + "into_vector", "concat", - "swizzle", - "first", - "last", - "reversed", - "rotate_left", - "rotate_right", + "select", + "for_each", ], - "Unary Operators": [ + "Generation": [ + "range", + "range_like", + "each_index", "fill", "fill_like", "zeros", "zeros_like", "ones", "ones_like", + ], + "Shuffling": [ + # "concat", + # "swizzle", + # "first", + # "last", + # "reversed", + # "rotate_left", + # "rotate_right", + ], + "Unary Operators": [ "negate", "bit_not", "logical_not", @@ -135,21 +135,21 @@ def build_index_page(groups): ("min", "min(L&&, R&&)"), "nextafter", "modf", - "pow", + ("pow", "pow(L&&, R&&)"), "remainder", #"rhypot", ], "Reductions": [ "sum", - ("max", "max(V&&)"), - ("min", "min(V&&)"), + ("max", "max(const V&)"), + ("min", "min(const V&)"), "product", "all", "any", "count", ], "Mathematical": [ - "abs", + ("abs", "abs(const V&)"), "acos", "acosh", "asin", @@ -166,14 +166,14 @@ def build_index_page(groups): "erfcinv", "erfcx", "erfinv", - "exp", + ("exp", "exp(const V&)"), "exp10", "exp2", "fabs", "floor", "ilogb", "lgamma", - "log", + ("log", "log(const V&)"), "log10", "logb", "nearbyint", @@ -181,7 +181,7 @@ def build_index_page(groups): "rcbrt", "sin", "sinh", - "sqrt", + ("sqrt", "sqrt(const V&)"), "tan", "tanh", "tgamma", @@ -193,6 +193,11 @@ def build_index_page(groups): "isinf", "isnan", ], + "Conditional": [ + ("where", "where(const C&, const L&, const R&)"), + ("where", "where(const C&, const L&)"), + ("where", "where(const C&)"), + ] } build_index_page(groups) diff --git a/examples/vector_add/main.cu b/examples/vector_add/main.cu index fe69857..ea78d1a 100644 --- a/examples/vector_add/main.cu +++ b/examples/vector_add/main.cu @@ -4,9 +4,7 @@ #include #include "kernel_float.h" -namespace kf = kernel_float; - -using x = kf::half; +using namespace kernel_float::prelude; void cuda_check(cudaError_t code) { if (code != cudaSuccess) { @@ -15,11 +13,7 @@ void cuda_check(cudaError_t code) { } template -__global__ void my_kernel( - int length, - const kf::unaligned_vec<__half, N>* input, - double constant, - kf::unaligned_vec* output) { +__global__ void my_kernel(int length, const khalf* input, double constant, kfloat* output) { int i = blockIdx.x * blockDim.x + threadIdx.x; if (i * N < length) { @@ -30,24 +24,24 @@ __global__ void my_kernel( template void run_kernel(int n) { double constant = 1.0; - std::vector<__half> input(n); + std::vector input(n); std::vector output_expected; std::vector output_result; // Generate input data for (int i = 0; i < n; i++) { - input[i] = __half(i); + input[i] = half(i); output_expected[i] = float(i + constant); } // Allocate device memory - kf::unaligned_vec<__half, items_per_thread>* input_dev; - kf::unaligned_vec* output_dev; - cuda_check(cudaMalloc(&input_dev, sizeof(__half) * n)); + khalf* input_dev; + kfloat* output_dev; + cuda_check(cudaMalloc(&input_dev, sizeof(half) * n)); cuda_check(cudaMalloc(&output_dev, sizeof(float) * n)); // Copy device memory - cuda_check(cudaMemcpy(input_dev, input.data(), sizeof(__half) * n, cudaMemcpyDefault)); + cuda_check(cudaMemcpy(input_dev, input.data(), sizeof(half) * n, cudaMemcpyDefault)); // Launch kernel! int block_size = 256; diff --git a/include/kernel_float.h b/include/kernel_float.h index 93ada9c..f2b796d 100644 --- a/include/kernel_float.h +++ b/include/kernel_float.h @@ -1,18 +1,18 @@ #ifndef KERNEL_FLOAT_H #define KERNEL_FLOAT_H +#include "kernel_float/base.h" #include "kernel_float/bf16.h" #include "kernel_float/binops.h" -#include "kernel_float/cast.h" +#include "kernel_float/conversion.h" #include "kernel_float/fp16.h" -#include "kernel_float/fp8.h" -#include "kernel_float/interface.h" #include "kernel_float/iterate.h" #include "kernel_float/macros.h" #include "kernel_float/meta.h" +#include "kernel_float/prelude.h" #include "kernel_float/reduce.h" -#include "kernel_float/storage.h" -#include "kernel_float/swizzle.h" +#include "kernel_float/triops.h" #include "kernel_float/unops.h" +#include "kernel_float/vector.h" -#endif \ No newline at end of file +#endif diff --git a/include/kernel_float/base.h b/include/kernel_float/base.h new file mode 100644 index 0000000..b3edb20 --- /dev/null +++ b/include/kernel_float/base.h @@ -0,0 +1,337 @@ +#ifndef KERNEL_FLOAT_BASE_H +#define KERNEL_FLOAT_BASE_H + +#include "macros.h" +#include "meta.h" + +namespace kernel_float { + +template +struct alignas(Alignment) aligned_array { + KERNEL_FLOAT_INLINE + T* data() { + return items_; + } + + KERNEL_FLOAT_INLINE + const T* data() const { + return items_; + } + + T items_[N] = {}; +}; + +template +struct aligned_array { + KERNEL_FLOAT_INLINE + T* data() { + while (true) + ; + } + + KERNEL_FLOAT_INLINE + const T* data() const { + while (true) + ; + } +}; + +template +struct alignas(Alignment) aligned_array { + KERNEL_FLOAT_INLINE + aligned_array(T value = {}) : x(value) {} + + KERNEL_FLOAT_INLINE + operator T() const { + return x; + } + + KERNEL_FLOAT_INLINE + T* data() { + return &x; + } + + KERNEL_FLOAT_INLINE + const T* data() const { + return &x; + } + + T x; +}; + +template +struct alignas(Alignment) aligned_array { + KERNEL_FLOAT_INLINE + aligned_array(T x, T y) : x(x), y(y) {} + + KERNEL_FLOAT_INLINE + aligned_array() : aligned_array(T {}, T {}) {} + + KERNEL_FLOAT_INLINE + T* data() { + return items; + } + + KERNEL_FLOAT_INLINE + const T* data() const { + return items; + } + + union { + T items[2]; + struct { + T x; + T y; + }; + }; +}; + +template +struct alignas(Alignment) aligned_array { + KERNEL_FLOAT_INLINE + aligned_array(T x, T y, T z) : x(x), y(y), z(z) {} + + KERNEL_FLOAT_INLINE + aligned_array() : aligned_array(T {}, T {}, T {}) {} + + KERNEL_FLOAT_INLINE + T* data() { + return items; + } + + KERNEL_FLOAT_INLINE + const T* data() const { + return items; + } + + union { + T items[3]; + struct { + T x; + T y; + T z; + }; + }; +}; + +template +struct alignas(Alignment) aligned_array { + KERNEL_FLOAT_INLINE + aligned_array(T x, T y, T z, T w) : x(x), y(y), z(z), w(w) {} + + KERNEL_FLOAT_INLINE + aligned_array() : aligned_array(T {}, T {}, T {}, T {}) {} + + KERNEL_FLOAT_INLINE + T* data() { + return items; + } + + KERNEL_FLOAT_INLINE + const T* data() const { + return items; + } + + union { + T items[4]; + struct { + T x; + T y; + T z; + T w; + }; + }; +}; + +KERNEL_FLOAT_INLINE +static constexpr size_t compute_max_alignment(size_t total_size, size_t min_align) { + if (total_size % 32 == 0 || min_align >= 32) { + return 32; + } else if (total_size % 16 == 0 || min_align == 16) { + return 16; + } else if (total_size % 8 == 0 || min_align == 8) { + return 8; + } else if (total_size % 4 == 0 || min_align == 4) { + return 4; + } else if (total_size % 2 == 0 || min_align == 2) { + return 2; + } else { + return 1; + } +} + +template +using vector_storage = aligned_array; + +template +struct extent; + +template +struct extent { + static constexpr size_t value = N; + static constexpr size_t size = N; +}; + +template +struct into_vector_impl { + using value_type = T; + using extent_type = extent<1>; + + KERNEL_FLOAT_INLINE + static vector_storage call(const T& input) { + return vector_storage {input}; + } +}; + +template +struct into_vector_impl { + using value_type = T; + using extent_type = extent; + + KERNEL_FLOAT_INLINE + static vector_storage call(const T (&input)[N]) { + return call(input, make_index_sequence()); + } + + private: + template + KERNEL_FLOAT_INLINE static vector_storage + call(const T (&input)[N], index_sequence) { + return {input[Is]...}; + } +}; + +template +struct into_vector_impl: into_vector_impl {}; + +template +struct into_vector_impl: into_vector_impl {}; + +template +struct into_vector_impl: into_vector_impl {}; + +template +struct into_vector_impl: into_vector_impl {}; + +template +struct into_vector_impl> { + using value_type = T; + using extent_type = extent; + + KERNEL_FLOAT_INLINE + static vector_storage call(const aligned_array& input) { + return input; + } +}; + +#define KERNEL_FLOAT_DEFINE_VECTOR_TYPE(T, T1, T2, T3, T4) \ + template<> \ + struct into_vector_impl<::T1> { \ + using value_type = T; \ + using extent_type = extent<1>; \ + \ + KERNEL_FLOAT_INLINE \ + static vector_storage call(::T1 v) { \ + return {v.x}; \ + } \ + }; \ + \ + template<> \ + struct into_vector_impl<::T2> { \ + using value_type = T; \ + using extent_type = extent<2>; \ + \ + KERNEL_FLOAT_INLINE \ + static vector_storage call(::T2 v) { \ + return {v.x, v.y}; \ + } \ + }; \ + \ + template<> \ + struct into_vector_impl<::T3> { \ + using value_type = T; \ + using extent_type = extent<3>; \ + \ + KERNEL_FLOAT_INLINE \ + static vector_storage call(::T3 v) { \ + return {v.x, v.y, v.z}; \ + } \ + }; \ + \ + template<> \ + struct into_vector_impl<::T4> { \ + using value_type = T; \ + using extent_type = extent<4>; \ + \ + KERNEL_FLOAT_INLINE \ + static vector_storage call(::T4 v) { \ + return {v.x, v.y, v.z, v.w}; \ + } \ + }; + +KERNEL_FLOAT_DEFINE_VECTOR_TYPE(char, char1, char2, char3, char4) +KERNEL_FLOAT_DEFINE_VECTOR_TYPE(short, short1, short2, short3, short4) +KERNEL_FLOAT_DEFINE_VECTOR_TYPE(int, int1, int2, int3, int4) +KERNEL_FLOAT_DEFINE_VECTOR_TYPE(long, long1, long2, long3, long4) +KERNEL_FLOAT_DEFINE_VECTOR_TYPE(long long, longlong1, longlong2, longlong3, longlong4) + +KERNEL_FLOAT_DEFINE_VECTOR_TYPE(unsigned char, uchar1, uchar2, uchar3, uchar4) +KERNEL_FLOAT_DEFINE_VECTOR_TYPE(unsigned short, ushort1, ushort2, ushort3, ushort4) +KERNEL_FLOAT_DEFINE_VECTOR_TYPE(unsigned int, uint1, uint2, uint3, uint4) +KERNEL_FLOAT_DEFINE_VECTOR_TYPE(unsigned long, ulong1, ulong2, ulong3, ulong4) +KERNEL_FLOAT_DEFINE_VECTOR_TYPE(unsigned long long, ulonglong1, ulonglong2, ulonglong3, ulonglong4) + +KERNEL_FLOAT_DEFINE_VECTOR_TYPE(float, float1, float2, float3, float4) +KERNEL_FLOAT_DEFINE_VECTOR_TYPE(double, double1, double2, double3, double4) + +template> +struct vector; + +template +struct into_vector_impl> { + using value_type = T; + using extent_type = E; + + KERNEL_FLOAT_INLINE + static vector_storage call(const vector& input) { + return input.storage(); + } +}; + +template +struct vector_traits; + +template +struct vector_traits> { + using value_type = T; + using extent_type = E; + using storage_type = S; + using vector_type = vector; +}; + +template +using vector_value_type = typename into_vector_impl::value_type; + +template +using vector_extent_type = typename into_vector_impl::extent_type; + +template +static constexpr size_t vector_extent = vector_extent_type::value; + +template +using into_vector_type = vector, vector_extent_type>; + +template +using vector_storage_type = vector_storage, vector_extent>; + +template +using promoted_vector_value_type = promote_t...>; + +template +KERNEL_FLOAT_INLINE vector_storage_type into_vector_storage(V&& input) { + return into_vector_impl::call(std::forward(input)); +} + +} // namespace kernel_float + +#endif diff --git a/include/kernel_float/bf16.h b/include/kernel_float/bf16.h index 8406615..9580a69 100644 --- a/include/kernel_float/bf16.h +++ b/include/kernel_float/bf16.h @@ -7,130 +7,196 @@ #include #include "binops.h" -#include "cast.h" -#include "interface.h" -#include "storage.h" -#include "unops.h" +#include "reduce.h" +#include "vector.h" namespace kernel_float { -KERNEL_FLOAT_DEFINE_COMMON_TYPE(__nv_bfloat16, bool) -KERNEL_FLOAT_DEFINE_COMMON_TYPE(float, __nv_bfloat16) -KERNEL_FLOAT_DEFINE_COMMON_TYPE(double, __nv_bfloat16) +KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(__nv_bfloat16) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, __nv_bfloat16) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, __nv_bfloat16) template<> -struct vector_traits<__nv_bfloat162> { +struct into_vector_impl<__nv_bfloat162> { using value_type = __nv_bfloat16; - static constexpr size_t size = 2; + using extent_type = extent<2>; KERNEL_FLOAT_INLINE - static __nv_bfloat162 fill(__nv_bfloat16 value) { -#if KERNEL_FLOAT_ON_DEVICE - return __bfloat162bfloat162(value); -#else - return {value, value}; -#endif + static vector_storage<__nv_bfloat16, 2> call(__nv_bfloat162 input) { + return {input.x, input.y}; } +}; +namespace detail { +template +struct map_bfloat16x2 { KERNEL_FLOAT_INLINE - static __nv_bfloat162 create(__nv_bfloat16 low, __nv_bfloat16 high) { -#if KERNEL_FLOAT_ON_DEVICE - return __halves2bfloat162(low, high); -#else - return {low, high}; -#endif + static __nv_bfloat162 call(F fun, __nv_bfloat162 input) { + __nv_bfloat16 a = fun(input.x); + __nv_bfloat16 b = fun(input.y); + return {a, b}; } +}; +template +struct zip_bfloat16x2 { KERNEL_FLOAT_INLINE - static __nv_bfloat16 get(__nv_bfloat162 self, size_t index) { -#if KERNEL_FLOAT_ON_DEVICE - if (index == 0) { - return __low2bfloat16(self); - } else { - return __high2bfloat16(self); + static __nv_bfloat162 call(F fun, __nv_bfloat162 left, __nv_bfloat162 right) { + __nv_bfloat16 a = fun(left.x, left.y); + __nv_bfloat16 b = fun(right.y, right.y); + return {a, b}; + } +}; + +template +struct apply_impl { + KERNEL_FLOAT_INLINE static void call(F fun, __nv_bfloat16* result, const __nv_bfloat16* input) { +#pragma unroll + for (size_t i = 0; 2 * i + 1 < N; i++) { + __nv_bfloat162 a = {input[2 * i], input[2 * i + 1]}; + __nv_bfloat162 b = map_bfloat16x2::call(fun, a); + result[2 * i + 0] = b.x; + result[2 * i + 1] = b.y; } -#else - if (index == 0) { - return self.x; - } else { - return self.y; + + if (N % 2 != 0) { + result[N - 1] = fun(input[N - 1]); } -#endif } +}; - KERNEL_FLOAT_INLINE - static void set(__nv_bfloat162& self, size_t index, __nv_bfloat16 value) { - if (index == 0) { - self.x = value; - } else { - self.y = value; +template +struct apply_impl { + KERNEL_FLOAT_INLINE static void + call(F fun, __nv_bfloat16* result, const __nv_bfloat16* left, const __nv_bfloat16* right) { +#pragma unroll + for (size_t i = 0; 2 * i + 1 < N; i++) { + __nv_bfloat162 a = {left[2 * i], left[2 * i + 1]}; + __nv_bfloat162 b = {right[2 * i], right[2 * i + 1]}; + __nv_bfloat162 c = zip_bfloat16x2::call(fun, a, b); + result[2 * i + 0] = c.x; + result[2 * i + 1] = c.y; + } + + if (N % 2 != 0) { + result[N - 1] = fun(left[N - 1], right[N - 1]); } } }; -template -struct default_storage<__nv_bfloat16, N, Alignment::Maximum, enabled_t<(N >= 2)>> { - using type = nested_array<__nv_bfloat162, N>; -}; +template +struct reduce_impl= 2)>> { + KERNEL_FLOAT_INLINE static __nv_bfloat16 call(F fun, const __nv_bfloat16* input) { + __nv_bfloat162 accum = {input[0], input[1]}; -template -struct default_storage<__nv_bfloat16, N, Alignment::Packed, enabled_t<(N >= 2 && N % 2 == 0)>> { - using type = nested_array<__nv_bfloat162, N>; +#pragma unroll + for (size_t i = 0; 2 * i + 1 < N; i++) { + __nv_bfloat162 a = {input[2 * i], input[2 * i + 1]}; + accum = zip_bfloat16x2::call(fun, accum, a); + } + + __nv_bfloat16 result = fun(accum.x, accum.y); + + if (N % 2 != 0) { + result = fun(result, input[N - 1]); + } + + return result; + } }; +} // namespace detail + +#define KERNEL_FLOAT_BF16_UNARY_FORWARD(NAME) \ + namespace ops { \ + template<> \ + struct NAME<__nv_bfloat16> { \ + KERNEL_FLOAT_INLINE __nv_bfloat16 operator()(__nv_bfloat16 input) { \ + return __nv_bfloat16(ops::NAME {}(float(input))); \ + } \ + }; \ + } -#if KERNEL_FLOAT_ON_DEVICE -#define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) \ +// There operations are not implemented in half precision, so they are forward to single precision +KERNEL_FLOAT_BF16_UNARY_FORWARD(tan) +KERNEL_FLOAT_BF16_UNARY_FORWARD(asin) +KERNEL_FLOAT_BF16_UNARY_FORWARD(acos) +KERNEL_FLOAT_BF16_UNARY_FORWARD(atan) +KERNEL_FLOAT_BF16_UNARY_FORWARD(expm1) + +#if KERNEL_FLOAT_IS_DEVICE +#define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) \ + namespace ops { \ + template<> \ + struct NAME<__nv_bfloat16> { \ + KERNEL_FLOAT_INLINE __nv_bfloat16 operator()(__nv_bfloat16 input) { \ + return FUN1(input); \ + } \ + }; \ + } \ + namespace detail { \ + template<> \ + struct map_bfloat16x2> { \ + KERNEL_FLOAT_INLINE static __nv_bfloat162 \ + call(ops::NAME<__nv_bfloat16>, __nv_bfloat162 input) { \ + return FUN2(input); \ + } \ + }; \ + } +#else +#define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) KERNEL_FLOAT_BF16_UNARY_FORWARD(NAME) +#endif + +KERNEL_FLOAT_BF16_UNARY_FUN(abs, ::__habs, ::__habs2) +KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2) +KERNEL_FLOAT_BF16_UNARY_FUN(ceil, ::hceil, ::h2ceil) +KERNEL_FLOAT_BF16_UNARY_FUN(cos, ::hcos, ::h2cos) +KERNEL_FLOAT_BF16_UNARY_FUN(exp, ::hexp, ::h2exp) +KERNEL_FLOAT_BF16_UNARY_FUN(exp10, ::hexp10, ::h2exp10) +KERNEL_FLOAT_BF16_UNARY_FUN(floor, ::hfloor, ::h2floor) +KERNEL_FLOAT_BF16_UNARY_FUN(log, ::hlog, ::h2log) +KERNEL_FLOAT_BF16_UNARY_FUN(log10, ::hlog10, ::h2log2) +KERNEL_FLOAT_BF16_UNARY_FUN(rint, ::hrint, ::h2rint) +KERNEL_FLOAT_BF16_UNARY_FUN(rsqrt, ::hrsqrt, ::h2rsqrt) +KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin) +KERNEL_FLOAT_BF16_UNARY_FUN(sqrt, ::hsqrt, ::h2sqrt) +KERNEL_FLOAT_BF16_UNARY_FUN(trunc, ::htrunc, ::h2trunc) + +KERNEL_FLOAT_BF16_UNARY_FUN(fast_exp, ::hexp, ::h2exp) +KERNEL_FLOAT_BF16_UNARY_FUN(fast_log, ::hlog, ::h2log) +KERNEL_FLOAT_BF16_UNARY_FUN(fast_cos, ::hcos, ::h2cos) +KERNEL_FLOAT_BF16_UNARY_FUN(fast_sin, ::hsin, ::h2sin) + +#if KERNEL_FLOAT_IS_DEVICE +#define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2) \ + namespace ops { \ + template<> \ + struct NAME<__nv_bfloat16> { \ + KERNEL_FLOAT_INLINE __nv_bfloat16 \ + operator()(__nv_bfloat16 left, __nv_bfloat16 right) const { \ + return FUN1(left, right); \ + } \ + }; \ + } \ + namespace detail { \ + template<> \ + struct zip_bfloat16x2> { \ + KERNEL_FLOAT_INLINE static __nv_bfloat162 \ + call(ops::NAME<__nv_bfloat16>, __nv_bfloat162 left, __nv_bfloat162 right) { \ + return FUN2(left, right); \ + } \ + }; \ + } +#else +#define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2) \ namespace ops { \ template<> \ struct NAME<__nv_bfloat16> { \ - KERNEL_FLOAT_INLINE __nv_bfloat16 operator()(__nv_bfloat16 input) { \ - return FUN1(input); \ - } \ - }; \ - } \ - namespace detail { \ - template<> \ - struct map_helper, __nv_bfloat162, __nv_bfloat162> { \ - KERNEL_FLOAT_INLINE static __nv_bfloat162 \ - call(ops::NAME<__nv_bfloat16>, __nv_bfloat162 input) { \ - return FUN2(input); \ + KERNEL_FLOAT_INLINE __nv_bfloat16 \ + operator()(__nv_bfloat16 left, __nv_bfloat16 right) const { \ + return __nv_bfloat16(ops::NAME {}(float(left), float(right))); \ } \ }; \ } - -KERNEL_FLOAT_BF16_UNARY_FUN(abs, ::__habs, ::__habs2); -KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2); -KERNEL_FLOAT_BF16_UNARY_FUN(ceil, ::hceil, ::h2ceil); -KERNEL_FLOAT_BF16_UNARY_FUN(cos, ::hcos, ::h2cos); -KERNEL_FLOAT_BF16_UNARY_FUN(exp, ::hexp, ::h2exp); -KERNEL_FLOAT_BF16_UNARY_FUN(exp10, ::hexp10, ::h2exp10); -KERNEL_FLOAT_BF16_UNARY_FUN(floor, ::hfloor, ::h2floor); -KERNEL_FLOAT_BF16_UNARY_FUN(log, ::hlog, ::h2log); -KERNEL_FLOAT_BF16_UNARY_FUN(log10, ::hlog10, ::h2log2); -KERNEL_FLOAT_BF16_UNARY_FUN(rint, ::hrint, ::h2rint); -KERNEL_FLOAT_BF16_UNARY_FUN(rsqrt, ::hrsqrt, ::h2rsqrt); -KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin); -KERNEL_FLOAT_BF16_UNARY_FUN(sqrt, ::hsqrt, ::h2sqrt); -KERNEL_FLOAT_BF16_UNARY_FUN(trunc, ::htrunc, ::h2trunc); - -#define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2) \ - namespace ops { \ - template<> \ - struct NAME<__nv_bfloat16> { \ - KERNEL_FLOAT_INLINE __nv_bfloat16 \ - operator()(__nv_bfloat16 left, __nv_bfloat16 right) const { \ - return FUN1(left, right); \ - } \ - }; \ - } \ - namespace detail { \ - template<> \ - struct zip_helper, __nv_bfloat162, __nv_bfloat162, __nv_bfloat162> { \ - KERNEL_FLOAT_INLINE static __nv_bfloat162 \ - call(ops::NAME<__nv_bfloat16>, __nv_bfloat162 left, __nv_bfloat162 right) { \ - return FUN2(left, right); \ - } \ - }; \ - } +#endif KERNEL_FLOAT_BF16_BINARY_FUN(add, __hadd, __hadd2) KERNEL_FLOAT_BF16_BINARY_FUN(subtract, __hsub, __hsub2) @@ -139,6 +205,8 @@ KERNEL_FLOAT_BF16_BINARY_FUN(divide, __hdiv, __h2div) KERNEL_FLOAT_BF16_BINARY_FUN(min, __hmin, __hmin2) KERNEL_FLOAT_BF16_BINARY_FUN(max, __hmax, __hmax2) +KERNEL_FLOAT_BF16_BINARY_FUN(fast_div, __hdiv, __h2div) + KERNEL_FLOAT_BF16_BINARY_FUN(equal_to, __heq, __heq2) KERNEL_FLOAT_BF16_BINARY_FUN(not_equal_to, __heq, __heq2) KERNEL_FLOAT_BF16_BINARY_FUN(less, __hlt, __hlt2) @@ -146,8 +214,6 @@ KERNEL_FLOAT_BF16_BINARY_FUN(less_equal, __hle, __hle2) KERNEL_FLOAT_BF16_BINARY_FUN(greater, __hgt, __hgt2) KERNEL_FLOAT_BF16_BINARY_FUN(greater_equal, __hge, __hgt2) -#endif - #define KERNEL_FLOAT_BF16_CAST(T, TO_HALF, FROM_HALF) \ namespace ops { \ template<> \ @@ -167,38 +233,76 @@ KERNEL_FLOAT_BF16_BINARY_FUN(greater_equal, __hge, __hgt2) KERNEL_FLOAT_BF16_CAST(double, __double2bfloat16(input), double(__bfloat162float(input))); KERNEL_FLOAT_BF16_CAST(float, __float2bfloat16(input), __bfloat162float(input)); +// clang-format off // there are no official char casts. Instead, cast to int and then to char KERNEL_FLOAT_BF16_CAST(char, __int2bfloat16_rn(input), (char)__bfloat162int_rz(input)); -KERNEL_FLOAT_BF16_CAST( - signed char, - __int2bfloat16_rn(input), - (signed char)__bfloat162int_rz(input)); -KERNEL_FLOAT_BF16_CAST( - unsigned char, - __int2bfloat16_rn(input), - (unsigned char)__bfloat162int_rz(input)); +KERNEL_FLOAT_BF16_CAST(signed char, __int2bfloat16_rn(input), (signed char)__bfloat162int_rz(input)); +KERNEL_FLOAT_BF16_CAST(unsigned char, __int2bfloat16_rn(input), (unsigned char)__bfloat162int_rz(input)); KERNEL_FLOAT_BF16_CAST(signed short, __bfloat162short_rz(input), __short2bfloat16_rn(input)); KERNEL_FLOAT_BF16_CAST(signed int, __bfloat162int_rz(input), __int2bfloat16_rn(input)); -KERNEL_FLOAT_BF16_CAST( - signed long, - __ll2bfloat16_rn(input), - (signed long)(__bfloat162ll_rz(input))); +KERNEL_FLOAT_BF16_CAST(signed long, __ll2bfloat16_rn(input), (signed long)(__bfloat162ll_rz(input))); KERNEL_FLOAT_BF16_CAST(signed long long, __ll2bfloat16_rn(input), __bfloat162ll_rz(input)); KERNEL_FLOAT_BF16_CAST(unsigned short, __bfloat162ushort_rz(input), __ushort2bfloat16_rn(input)); KERNEL_FLOAT_BF16_CAST(unsigned int, __bfloat162uint_rz(input), __uint2bfloat16_rn(input)); -KERNEL_FLOAT_BF16_CAST( - unsigned long, - __ull2bfloat16_rn(input), - (unsigned long)(__bfloat162ull_rz(input))); +KERNEL_FLOAT_BF16_CAST(unsigned long, __ull2bfloat16_rn(input), (unsigned long)(__bfloat162ull_rz(input))); KERNEL_FLOAT_BF16_CAST(unsigned long long, __ull2bfloat16_rn(input), __bfloat162ull_rz(input)); +// clang-format on using bfloat16 = __nv_bfloat16; -//KERNEL_FLOAT_TYPE_ALIAS(half, __nv_bfloat16) //KERNEL_FLOAT_TYPE_ALIAS(float16x, __nv_bfloat16) //KERNEL_FLOAT_TYPE_ALIAS(f16x, __nv_bfloat16) +#if KERNEL_FLOAT_IS_DEVICE +namespace detail { +template<> +struct dot_impl<__nv_bfloat16, 0> { + KERNEL_FLOAT_INLINE + static __nv_bfloat16 call(const __nv_bfloat16* left, const __nv_bfloat16* right) { + return __nv_bfloat16(0); + } +}; + +template<> +struct dot_impl<__nv_bfloat16, 1> { + KERNEL_FLOAT_INLINE + static __nv_bfloat16 call(const __nv_bfloat16* left, const __nv_bfloat16* right) { + return __hmul(left[0], right[0]); + } +}; + +template +struct dot_impl<__nv_bfloat16, N> { + static_assert(N >= 2, "internal error"); + + KERNEL_FLOAT_INLINE + static __nv_bfloat16 call(const __nv_bfloat16* left, const __nv_bfloat16* right) { + __nv_bfloat162 first_a = {left[0], left[1]}; + __nv_bfloat162 first_b = {right[0], right[1]}; + __nv_bfloat162 accum = __hmul2(first_a, first_b); + +#pragma unroll + for (size_t i = 2; i + 1 < N; i += 2) { + __nv_bfloat162 a = {left[i], left[i + 1]}; + __nv_bfloat162 b = {right[i], right[i + 1]}; + accum = __hfma2(a, b, accum); + } + + __nv_bfloat16 result = __hadd(accum.x, accum.y); + + if (N % 2 != 0) { + __nv_bfloat16 a = left[N - 1]; + __nv_bfloat16 b = right[N - 1]; + result = __hfma(a, b, result); + } + + return result; + } +}; +} // namespace detail +#endif + } // namespace kernel_float #if KERNEL_FLOAT_FP16_AVAILABLE @@ -206,7 +310,18 @@ using bfloat16 = __nv_bfloat16; namespace kernel_float { KERNEL_FLOAT_BF16_CAST(__half, __float2bfloat16(input), __bfloat162float(input)); -} + +template<> +struct promote_type<__nv_bfloat16, __half> { + using type = float; +}; + +template<> +struct promote_type<__half, __nv_bfloat16> { + using type = float; +}; + +} // namespace kernel_float #endif // KERNEL_FLOAT_FP16_AVAILABLE #endif diff --git a/include/kernel_float/binops.h b/include/kernel_float/binops.h index 6b9bf69..705562b 100644 --- a/include/kernel_float/binops.h +++ b/include/kernel_float/binops.h @@ -1,148 +1,118 @@ #ifndef KERNEL_FLOAT_BINOPS_H #define KERNEL_FLOAT_BINOPS_H +#include "conversion.h" #include "unops.h" namespace kernel_float { -namespace detail { -template -struct zip_helper { - KERNEL_FLOAT_INLINE static Output call(F fun, const Left& left, const Right& right) { - return call_with_indices(fun, left, right, make_index_sequence> {}); - } - - private: - template - KERNEL_FLOAT_INLINE static Output - call_with_indices(F fun, const Left& left, const Right& right, index_sequence = {}) { - return vector_traits::create(fun(vector_get(left), vector_get(right))...); - } -}; - -template -struct zip_helper, nested_array, nested_array> { - KERNEL_FLOAT_INLINE static nested_array - call(F fun, const nested_array& left, const nested_array& right) { - return call(fun, left, right, make_index_sequence::num_packets> {}); - } - - private: - template - KERNEL_FLOAT_INLINE static nested_array call( - F fun, - const nested_array& left, - const nested_array& right, - index_sequence) { - return {zip_helper::call(fun, left[Is], right[Is])...}; - } -}; -}; // namespace detail - -template -using common_vector_value_type = common_t...>; - -template -static constexpr size_t common_vector_size = common_size...>; template -using zip_type = default_storage_type< +using zip_type = vector< result_t, vector_value_type>, - common_vector_size>; + broadcast_vector_extent_type>; /** - * Applies ``fun`` to each pair of two elements from ``left`` and ``right`` and returns a new - * vector with the results. - * - * If ``left`` and ``right`` are not the same size, they will first be broadcast into a - * common size using ``resize``. + * Combines the elements from the two inputs (`left` and `right`) element-wise, applying a provided binary + * function (`fun`) to each pair of corresponding elements. * - * Note that this function does **not** cast the input vectors to a common element type. See - * ``zip_common`` for that functionality. + * Example + * ======= + * ``` + * vec make_negative = {true, false, true}; + * vec input = {1, 2, 3}; + * vec output = zip([](bool b, int n){ return b ? -n : +n; }, make_negative, input); // returns [-1, 2, -3] + * ``` */ -template> -KERNEL_FLOAT_INLINE vector zip(F fun, Left&& left, Right&& right) { - static constexpr size_t N = vector_size; - using LeftInput = default_storage_type, N>; - using RightInput = default_storage_type, N>; - - return detail::zip_helper::call( +template +KERNEL_FLOAT_INLINE zip_type zip(F fun, const L& left, const R& right) { + using A = vector_value_type; + using B = vector_value_type; + using O = result_t; + using E = broadcast_vector_extent_type; + vector_storage result; + + detail::apply_impl::call( fun, - broadcast(std::forward(left)), - broadcast(std::forward(right))); + result.data(), + detail::broadcast_impl, E>::call(into_vector_storage(left)).data(), + detail::broadcast_impl, E>::call(into_vector_storage(right)) + .data()); + + return result; } template -using zip_common_type = default_storage_type< - result_t, common_vector_value_type>, - common_vector_size>; +using zip_common_type = vector< + result_t, promoted_vector_value_type>, + broadcast_vector_extent_type>; /** - * Applies ``fun`` to each pair of two elements from ``left`` and ``right`` and returns a new - * vector with the results. - * - * If ``left`` and ``right`` are not the same size, they will first be broadcast into a - * common size using ``resize``. - * - * If ``left`` and ``right`` are not of the same type, they will first be case into a common - * data type. For example, zipping ``float`` and ``double`` first cast vectors to ``double``. + * Combines the elements from the two inputs (`left` and `right`) element-wise, applying a provided binary + * function (`fun`) to each pair of corresponding elements. The elements are promoted to a common type before applying + * the binary function. * * Example * ======= * ``` - * vec x = {1, 2, 3, 4}; - * vec = {8}; - * vec = zip_common([](auto a, auto b){ return a + b; }, x, y); // [9, 10, 11, 12] + * vec a = {1.0f, 2.0f, 3.0f}; + * vec b = {4, 5, 6}; + * vec c = zip_common([](float x, float y){ return x + y; }, a, b); // returns [5.0f, 7.0f, 9.0f] * ``` */ -template< - typename F, - typename Left, - typename Right, - typename Output = zip_common_type> -KERNEL_FLOAT_INLINE vector zip_common(F fun, Left&& left, Right&& right) { - static constexpr size_t N = vector_size; - using C = common_t, vector_value_type>; - using Input = default_storage_type; - - return detail::zip_helper::call( +template +KERNEL_FLOAT_INLINE zip_common_type zip_common(F fun, const L& left, const R& right) { + using T = promoted_vector_value_type; + using O = result_t; + using E = broadcast_vector_extent_type; + + vector_storage result; + + detail::apply_impl::call( fun, - broadcast(std::forward(left)), - broadcast(std::forward(right))); + result.data(), + detail::convert_impl, vector_extent_type, T, E>::call( + into_vector_storage(left)) + .data(), + detail::convert_impl, vector_extent_type, T, E>::call( + into_vector_storage(right)) + .data()); + + return result; } -#define KERNEL_FLOAT_DEFINE_BINARY(NAME, EXPR) \ - namespace ops { \ - template \ - struct NAME { \ - KERNEL_FLOAT_INLINE T operator()(T left, T right) { \ - return T(EXPR); \ - } \ - }; \ - } \ - template> \ - KERNEL_FLOAT_INLINE vector, L, R>> NAME(L&& left, R&& right) { \ - return zip_common(ops::NAME {}, std::forward(left), std::forward(right)); \ +#define KERNEL_FLOAT_DEFINE_BINARY(NAME, EXPR) \ + namespace ops { \ + template \ + struct NAME { \ + KERNEL_FLOAT_INLINE T operator()(T left, T right) { \ + return T(EXPR); \ + } \ + }; \ + } \ + template> \ + KERNEL_FLOAT_INLINE zip_common_type, L, R> NAME(L&& left, R&& right) { \ + return zip_common(ops::NAME {}, std::forward(left), std::forward(right)); \ } -#define KERNEL_FLOAT_DEFINE_BINARY_OP(NAME, OP) \ - KERNEL_FLOAT_DEFINE_BINARY(NAME, left OP right) \ - template> \ - KERNEL_FLOAT_INLINE vector, L, R>> operator OP( \ - const vector& left, \ - const vector& right) { \ - return zip_common(ops::NAME {}, left, right); \ - } \ - template> \ - KERNEL_FLOAT_INLINE vector, L, R>> operator OP( \ - const vector& left, \ - const R& right) { \ - return zip_common(ops::NAME {}, left, right); \ - } \ - template> \ - KERNEL_FLOAT_INLINE vector, L, R>> operator OP( \ - const L& left, \ - const vector& right) { \ - return zip_common(ops::NAME {}, left, right); \ +#define KERNEL_FLOAT_DEFINE_BINARY_OP(NAME, OP) \ + KERNEL_FLOAT_DEFINE_BINARY(NAME, left OP right) \ + template, typename E1, typename E2> \ + KERNEL_FLOAT_INLINE zip_common_type, vector, vector> operator OP( \ + const vector& left, \ + const vector& right) { \ + return zip_common(ops::NAME {}, left, right); \ + } \ + template>, typename E> \ + KERNEL_FLOAT_INLINE zip_common_type, vector, R> operator OP( \ + const vector& left, \ + const R& right) { \ + return zip_common(ops::NAME {}, left, right); \ + } \ + template, R>, typename E> \ + KERNEL_FLOAT_INLINE zip_common_type, L, vector> operator OP( \ + const L& left, \ + const vector& right) { \ + return zip_common(ops::NAME {}, left, right); \ } KERNEL_FLOAT_DEFINE_BINARY_OP(add, +) @@ -163,28 +133,29 @@ KERNEL_FLOAT_DEFINE_BINARY_OP(bit_or, |) KERNEL_FLOAT_DEFINE_BINARY_OP(bit_xor, ^) // clang-format off -template typename F, typename L, typename R> -static constexpr bool vector_assign_allowed = - common_vector_size == vector_size && - is_implicit_convertible< - result_t< - F, vector_value_type>>, - vector_value_type, - vector_value_type - >, - vector_value_type - >; +template typename F, typename T, typename E, typename R> +static constexpr bool is_vector_assign_allowed = + is_vector_broadcastable && + is_implicit_convertible< + result_t< + F>>, + T, + vector_value_type + >, + T + >; // clang-format on -#define KERNEL_FLOAT_DEFINE_BINARY_ASSIGN_OP(NAME, OP) \ - template< \ - typename L, \ - typename R, \ - typename T = enabled_t, vector_value_type>> \ - KERNEL_FLOAT_INLINE vector& operator OP(vector& lhs, const R& rhs) { \ - using F = ops::NAME; \ - lhs = zip_common(F {}, lhs.storage(), rhs); \ - return lhs; \ +#define KERNEL_FLOAT_DEFINE_BINARY_ASSIGN_OP(NAME, OP) \ + template< \ + typename T, \ + typename E, \ + typename R, \ + typename = enable_if_t>> \ + KERNEL_FLOAT_INLINE vector& operator OP(vector& lhs, const R& rhs) { \ + using F = ops::NAME; \ + lhs = zip_common(F {}, lhs, rhs); \ + return lhs; \ } KERNEL_FLOAT_DEFINE_BINARY_ASSIGN_OP(add, +=) @@ -201,16 +172,66 @@ KERNEL_FLOAT_DEFINE_BINARY_ASSIGN_OP(bit_xor, ^=) KERNEL_FLOAT_DEFINE_BINARY_FUN(min) KERNEL_FLOAT_DEFINE_BINARY_FUN(max) KERNEL_FLOAT_DEFINE_BINARY_FUN(copysign) -KERNEL_FLOAT_DEFINE_BINARY_FUN(hypot) KERNEL_FLOAT_DEFINE_BINARY_FUN(modf) KERNEL_FLOAT_DEFINE_BINARY_FUN(nextafter) KERNEL_FLOAT_DEFINE_BINARY_FUN(pow) KERNEL_FLOAT_DEFINE_BINARY_FUN(remainder) -#if KERNEL_FLOAT_CUDA_DEVICE -KERNEL_FLOAT_DEFINE_BINARY_FUN(rhypot) +KERNEL_FLOAT_DEFINE_BINARY(hypot, (ops::sqrt()(left * left + right * right))) +KERNEL_FLOAT_DEFINE_BINARY(rhypot, (T(1) / ops::hypot()(left, right))) + +namespace ops { +template<> +struct hypot { + KERNEL_FLOAT_INLINE double operator()(double left, double right) { + return ::hypot(left, right); + }; +}; + +template<> +struct hypot { + KERNEL_FLOAT_INLINE float operator()(float left, float right) { + return ::hypotf(left, right); + }; +}; + +// rhypot is only support on the GPU +#if KERNEL_FLOAT_IS_DEVICE +template<> +struct rhypot { + KERNEL_FLOAT_INLINE double operator()(double left, double right) { + return ::rhypot(left, right); + }; +}; + +template<> +struct rhypot { + KERNEL_FLOAT_INLINE float operator()(float left, float right) { + return ::rhypotf(left, right); + }; +}; +#endif +}; // namespace ops + +#if KERNEL_FLOAT_IS_DEVICE +#define KERNEL_FLOAT_DEFINE_BINARY_FAST(FUN_NAME, OP_NAME, FLOAT_FUN) \ + KERNEL_FLOAT_DEFINE_BINARY(FUN_NAME, ops::OP_NAME {}(left, right)) \ + namespace ops { \ + template<> \ + struct OP_NAME { \ + KERNEL_FLOAT_INLINE float operator()(float left, float right) { \ + return FLOAT_FUN(left, right); \ + } \ + }; \ + } +#else +#define KERNEL_FLOAT_DEFINE_BINARY_FAST(FUN_NAME, OP_NAME, FLOAT_FUN) \ + KERNEL_FLOAT_DEFINE_BINARY(FUN_NAME, ops::OP_NAME {}(left, right)) #endif +KERNEL_FLOAT_DEFINE_BINARY_FAST(fast_div, divide, __fdividef) +KERNEL_FLOAT_DEFINE_BINARY_FAST(fast_pow, pow, __powf) + namespace ops { template<> struct add { @@ -269,6 +290,39 @@ struct bit_xor { }; }; // namespace ops +namespace detail { +template +struct cross_impl { + KERNEL_FLOAT_INLINE + static vector> + call(const vector_storage& av, const vector_storage& bv) { + auto a = av.data(); + auto b = bv.data(); + vector> v0 = {a[1], a[2], a[0], a[2], a[0], a[1]}; + vector> v1 = {b[2], b[0], b[1], b[1], b[2], b[0]}; + vector> rv = v0 * v1; + + auto r = rv.data(); + vector> r0 = {r[0], r[1], r[2]}; + vector> r1 = {r[3], r[4], r[5]}; + return r0 - r1; + } +}; +}; // namespace detail + +/** + * Calculates the cross-product between two vectors of length 3. + */ +template< + typename L, + typename R, + typename T = promoted_vector_value_type, + typename = + enable_if_t> && is_vector_broadcastable>>> +KERNEL_FLOAT_INLINE vector> cross(const L& left, const R& right) { + return detail::cross_impl::call(convert_storage(left), convert_storage(right)); +} + } // namespace kernel_float -#endif //KERNEL_FLOAT_BINOPS_H +#endif diff --git a/include/kernel_float/cast.h b/include/kernel_float/cast.h deleted file mode 100644 index f88ebc8..0000000 --- a/include/kernel_float/cast.h +++ /dev/null @@ -1,203 +0,0 @@ -#ifndef KERNEL_FLOAT_CAST_H -#define KERNEL_FLOAT_CAST_H - -#include "storage.h" - -namespace kernel_float { -namespace ops { -template -struct cast { - KERNEL_FLOAT_INLINE R operator()(T input) noexcept { - return R(input); - } -}; - -template -struct cast { - KERNEL_FLOAT_INLINE T operator()(T input) noexcept { - return input; - } -}; -} // namespace ops - -namespace detail { - -// Cast a vector of type `Input` to type `Output`. Vectors must have the same size. -// The input vector has value type `T` -// The output vector has value type `R` -template< - typename Input, - typename Output, - typename T = vector_value_type, - typename R = vector_value_type> -struct cast_helper { - static_assert(vector_size == vector_size, "sizes must match"); - static constexpr size_t N = vector_size; - - KERNEL_FLOAT_INLINE static Output call(const Input& input) { - return call(input, make_index_sequence {}); - } - - private: - template - KERNEL_FLOAT_INLINE static Output call(const Input& input, index_sequence) { - ops::cast fun; - return vector_traits::create(fun(vector_get(input))...); - } -}; - -// Cast a vector of type `Input` to type `Output`. -// The input vector has value type `T` and size `N`. -// The output vector has value type `R` and size `M`. -template< - typename Input, - typename Output, - typename T = vector_value_type, - size_t N = vector_size, - typename R = vector_value_type, - size_t M = vector_size> -struct broadcast_helper; - -// T[1] => T[1] -template -struct broadcast_helper { - KERNEL_FLOAT_INLINE static Vector call(Vector input) { - return input; - } -}; - -// T[N] => T[N] -template -struct broadcast_helper { - KERNEL_FLOAT_INLINE static Vector call(Vector input) { - return input; - } -}; - -// T[1] => T[N] -template -struct broadcast_helper { - KERNEL_FLOAT_INLINE static Output call(Input input) { - return vector_traits::fill(vector_get<0>(input)); - } -}; - -// T[1] => T[1], but different vector types -template -struct broadcast_helper { - KERNEL_FLOAT_INLINE static Output call(Input input) { - return vector_traits::create(vector_get<0>(input)); - } -}; - -// T[N] => T[N], but different vector types -template -struct broadcast_helper { - KERNEL_FLOAT_INLINE static Output call(Input input) { - return cast_helper::call(input); - } -}; - -// T[1] => R[N] -template -struct broadcast_helper { - KERNEL_FLOAT_INLINE static Output call(Input input) { - return vector_traits::fill(ops::cast {}(vector_get<0>(input))); - } -}; - -// T[1] => R[1] -template -struct broadcast_helper { - KERNEL_FLOAT_INLINE static Output call(Input input) { - return vector_traits::create(ops::cast {}(vector_get<0>(input))); - } -}; - -// T[N] => R[N] -template -struct broadcast_helper { - KERNEL_FLOAT_INLINE static Output call(Input input) { - return cast_helper::call(input); - } -}; -} // namespace detail - -/** - * Cast the elements of the given vector ``input`` to the given type ``R`` and then widen the - * vector to length ``N``. The cast may lead to a loss in precision if ``R`` is a smaller data - * type. Widening is only possible if the input vector has size ``1`` or ``N``, other sizes - * will lead to a compilation error. - * - * Example - * ======= - * ``` - * vec x = {6}; - * vec y = broadcast(x); - * vec z = broadcast(y); - * ``` - */ -template> -KERNEL_FLOAT_INLINE vector broadcast(Input&& input) { - return detail::broadcast_helper, Output>::call( - into_storage(std::forward(input))); -} - -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template< - size_t N, - typename Input, - typename Output = default_storage_type, N>> -KERNEL_FLOAT_INLINE vector broadcast(Input&& input) { - return detail::broadcast_helper, Output>::call( - into_storage(std::forward(input))); -} - -template -KERNEL_FLOAT_INLINE vector broadcast(Input&& input) { - return detail::broadcast_helper, Output>::call( - into_storage(std::forward(input))); -} -#endif - -/** - * Widen the given vector ``input`` to length ``N``. Widening is only possible if the input vector - * has size ``1`` or ``N``, other sizes will lead to a compilation error. - * - * Example - * ======= - * ``` - * vec x = {6}; - * vec y = resize<3>(x); - * ``` - */ -template< - size_t N, - typename Input, - typename Output = default_storage_type, N>> -KERNEL_FLOAT_INLINE vector resize(Input&& input) noexcept { - return detail::broadcast_helper::call(std::forward(input)); -} - -template -using cast_type = default_storage_type>; - -/** - * Cast the elements of given vector ``input`` to the given type ``R``. Note that this cast may - * lead to a loss in precision if ``R`` is a smaller data type. - * - * Example - * ======= - * ``` - * vec x = {1.0f, 2.0f, 3.0f}; - * vec y = cast(x); - * vec z = cast(x); - * ``` - */ -template> -KERNEL_FLOAT_INLINE vector cast(Input&& input) noexcept { - return detail::broadcast_helper::call(std::forward(input)); -} -} // namespace kernel_float - -#endif //KERNEL_FLOAT_CAST_H diff --git a/include/kernel_float/complex.h b/include/kernel_float/complex.h new file mode 100644 index 0000000..aa133e3 --- /dev/null +++ b/include/kernel_float/complex.h @@ -0,0 +1,249 @@ +#ifndef KERNEL_FLOAT_COMPLEX_TYPE_H +#define KERNEL_FLOAT_COMPLEX_TYPE_H + +#include "macros.h" +#include "meta.h" + +namespace kernel_float { + +template +struct alignas(2 * alignof(T)) complex_type_storage { + T re; + T im; +}; + +template +struct complex_type: complex_type_storage { + using base_type = complex_type_storage; + + template + KERNEL_FLOAT_INLINE complex_type(complex_type that) : base_type(that.real(), that.imag()) {} + + KERNEL_FLOAT_INLINE + complex_type(T real = {}, T imag = {}) : base_type(real, imag) {} + + KERNEL_FLOAT_INLINE + T real() const { + return this->re; + } + + KERNEL_FLOAT_INLINE + T imag() const { + return this->im; + } + + KERNEL_FLOAT_INLINE + T norm() const { + return real() * real() + imag() * imag(); + } + + KERNEL_FLOAT_INLINE + complex_type conj() const { + return {real(), -imag()}; + } +}; + +template +KERNEL_FLOAT_INLINE complex_type operator+(complex_type v) { + return v; +} + +template +KERNEL_FLOAT_INLINE complex_type operator+(complex_type a, complex_type b) { + return {a.real() + b.real(), a.imag() + b.imag()}; +} + +template +KERNEL_FLOAT_INLINE complex_type operator+(T a, complex_type b) { + return {a + b.real(), b.imag()}; +} + +template +KERNEL_FLOAT_INLINE complex_type operator+(complex_type a, T b) { + return {a.real() + b, a.imag()}; +} + +template +KERNEL_FLOAT_INLINE complex_type& operator+=(complex_type& a, complex_type b) { + return (a = a + b); +} + +template +KERNEL_FLOAT_INLINE complex_type& operator+=(complex_type& a, T b) { + return (a = a + b); +} + +template +KERNEL_FLOAT_INLINE complex_type operator-(complex_type v) { + return {-v.real(), -v.imag()}; +} + +template +KERNEL_FLOAT_INLINE complex_type operator-(complex_type a, complex_type b) { + return {a.real() - b.real(), a.imag() - b.imag()}; +} + +template +KERNEL_FLOAT_INLINE complex_type operator-(T a, complex_type b) { + return {a - b.real(), -b.imag()}; +} + +template +KERNEL_FLOAT_INLINE complex_type operator-(complex_type a, T b) { + return {a.real() - b, a.imag()}; +} + +template +KERNEL_FLOAT_INLINE complex_type& operator-=(complex_type& a, complex_type b) { + return (a = a - b); +} + +template +KERNEL_FLOAT_INLINE complex_type& operator-=(complex_type& a, T b) { + return (a = a - b); +} + +template +KERNEL_FLOAT_INLINE complex_type operator*(complex_type a, complex_type b) { + return {a.real() * b.real() - a.imag() * b.imag(), a.real() * b.imag() + a.imag() * b.real()}; +} + +template +KERNEL_FLOAT_INLINE complex_type operator*(complex_type a, T b) { + return {a.real() * b, a.imag() * b}; +} + +template +KERNEL_FLOAT_INLINE complex_type* operator*=(complex_type& a, complex_type b) { + return (a = a * b); +} + +template +KERNEL_FLOAT_INLINE complex_type& operator*=(complex_type& a, T b) { + return (a = a * b); +} + +template +KERNEL_FLOAT_INLINE complex_type operator*(T a, complex_type b) { + return {a * b.real(), a * b.imag()}; +} + +template +KERNEL_FLOAT_INLINE complex_type operator/(complex_type a, complex_type b) { + T normi = T(1) / b.norm(); + + return { + (a.real() * b.real() + a.imag() * b.imag()) * normi, + (a.imag() * b.real() - a.real() * b.imag()) * normi}; +} + +template +KERNEL_FLOAT_INLINE complex_type operator/(complex_type a, T b) { + return a * (T(1) / b); +} + +template +KERNEL_FLOAT_INLINE complex_type operator/(T a, complex_type b) { + T normi = T(1) / b.norm(); + + return {a * b.real() * normi, -a * b.imag() * normi}; +} + +template +KERNEL_FLOAT_INLINE complex_type* operator/=(complex_type& a, complex_type b) { + return (a = a / b); +} + +template +KERNEL_FLOAT_INLINE complex_type& operator/=(complex_type& a, T b) { + return (a = a / b); +} + +template +KERNEL_FLOAT_INLINE T real(complex_type v) { + return v.real(); +} + +template +KERNEL_FLOAT_INLINE T imag(complex_type v) { + return v.imag(); +} + +template +KERNEL_FLOAT_INLINE T abs(complex_type v) { + return hypot(v.real(), v.imag()); +} + +template +KERNEL_FLOAT_INLINE T arg(complex_type v) { + return atan2(v.imag(), v.real()); +} + +template +KERNEL_FLOAT_INLINE complex_type sqrt(complex_type v) { + T radius = abs(v); + T cosA = v.real() / radius; + + complex_type out = { + sqrt(radius * (cosA + T(1)) * T(.5)), + sqrt(radius * (T(1) - cosA) * T(.5))}; + + // signbit should be false if x.y is negative + if (v.imag() < 0) { + out = complex_type {out.real, -out.im}; + } + + return out; +} + +template +KERNEL_FLOAT_INLINE complex_type norm(complex_type v) { + return v.real() * v.real() + v.imag() * v.imag(); +} + +template +KERNEL_FLOAT_INLINE complex_type conj(complex_type v) { + return {v.real(), -v.imag()}; +} + +template +KERNEL_FLOAT_INLINE complex_type exp(complex_type v) { + // TODO: Handle nan and inf correctly + T e = exp(v.real()); + T a = v.imag(); + return complex_type(e * cos(a), e * sin(a)); +} + +template +KERNEL_FLOAT_INLINE complex_type log(complex_type v) { + return {log(abs(v)), arg(v)}; +} + +template +KERNEL_FLOAT_INLINE complex_type pow(complex_type a, T b) { + return exp(a * log(b)); +} + +template +KERNEL_FLOAT_INLINE complex_type pow(complex_type a, complex_type b) { + return exp(a * log(b)); +} + +template +struct promote_type, complex_type> { + using type = complex_type>; +}; + +template +struct promote_type, R> { + using type = complex_type>; +}; + +template +struct promote_type> { + using type = complex_type>; +}; + +} // namespace kernel_float + +#endif diff --git a/include/kernel_float/constant.h b/include/kernel_float/constant.h new file mode 100644 index 0000000..1b98925 --- /dev/null +++ b/include/kernel_float/constant.h @@ -0,0 +1,116 @@ +#ifndef KERNEL_FLOAT_CONSTANT +#define KERNEL_FLOAT_CONSTANT + +#include "base.h" +#include "conversion.h" + +namespace kernel_float { + +template +struct constant { + template + KERNEL_FLOAT_INLINE explicit constexpr constant(const constant& that) { + auto f = ops::cast(); + value_ = f(that.get()); + } + + KERNEL_FLOAT_INLINE + constexpr constant(T value = {}) : value_(value) {} + + KERNEL_FLOAT_INLINE + constexpr T get() const { + return value_; + } + + KERNEL_FLOAT_INLINE + constexpr operator T() const { + return value_; + } + + private: + T value_; +}; + +// Deduction guide for `constant` +#if defined(__cpp_deduction_guides) +template +constant(T&&) -> constant>; +#endif + +template +KERNEL_FLOAT_INLINE constexpr constant make_constant(T value) { + return value; +} + +template +struct promote_type, constant> { + using type = constant::type>; +}; + +template +struct promote_type, R> { + using type = typename promote_type::type; +}; + +template +struct promote_type> { + using type = typename promote_type::type; +}; + +namespace ops { +template +struct cast, R> { + KERNEL_FLOAT_INLINE R operator()(const T& input) noexcept { + return cast {}(input); + } +}; + +template +struct cast, R, m> { + KERNEL_FLOAT_INLINE R operator()(const T& input) noexcept { + return cast {}(input); + } +}; +} // namespace ops + +#define KERNEL_FLOAT_CONSTANT_DEFINE_OP(OP) \ + template \ + KERNEL_FLOAT_INLINE auto operator OP(const constant& left, const R& right) { \ + auto f = ops::cast>(); \ + return f(left.get()) OP right; \ + } \ + \ + template \ + KERNEL_FLOAT_INLINE auto operator OP(const L& left, const constant& right) { \ + auto f = ops::cast>(); \ + return left OP f(right.get()); \ + } \ + \ + template \ + KERNEL_FLOAT_INLINE auto operator OP(const constant& left, const vector& right) { \ + auto f = ops::cast(); \ + return f(left.get()) OP right; \ + } \ + \ + template \ + KERNEL_FLOAT_INLINE auto operator OP(const vector& left, const constant& right) { \ + auto f = ops::cast(); \ + return left OP f(right.get()); \ + } \ + \ + template> \ + KERNEL_FLOAT_INLINE constant operator OP( \ + const constant& left, \ + const constant& right) { \ + return constant(left.get()) OP constant(right.get()); \ + } + +KERNEL_FLOAT_CONSTANT_DEFINE_OP(+) +KERNEL_FLOAT_CONSTANT_DEFINE_OP(-) +KERNEL_FLOAT_CONSTANT_DEFINE_OP(*) +KERNEL_FLOAT_CONSTANT_DEFINE_OP(/) +KERNEL_FLOAT_CONSTANT_DEFINE_OP(%) + +} // namespace kernel_float + +#endif diff --git a/include/kernel_float/conversion.h b/include/kernel_float/conversion.h new file mode 100644 index 0000000..2e6a454 --- /dev/null +++ b/include/kernel_float/conversion.h @@ -0,0 +1,338 @@ +#ifndef KERNEL_FLOAT_CAST_H +#define KERNEL_FLOAT_CAST_H + +#include "base.h" +#include "unops.h" + +namespace kernel_float { + +enum struct RoundingMode { ANY, DOWN, UP, NEAREST, TOWARD_ZERO }; + +namespace ops { +template +struct cast; + +template +struct cast { + KERNEL_FLOAT_INLINE R operator()(T input) noexcept { + return R(input); + } +}; + +template +struct cast { + KERNEL_FLOAT_INLINE T operator()(T input) noexcept { + return input; + } +}; + +template +struct cast { + KERNEL_FLOAT_INLINE T operator()(T input) noexcept { + return input; + } +}; +} // namespace ops + +/** + * Cast the elements of the given vector `input` to a different type `R`. + * + * This function casts each element of the input vector to a different data type specified by + * template parameter `R`. + * + * Optionally, the rounding mode can be set using the `Mode` template parameter. The default mode is `ANY`, which + * uses the fastest rounding mode available. + * + * Example + * ======= + * ``` + * vec input {1.2f, 2.7f, 3.5f, 4.9f}; + * auto casted = cast(input); // [1, 2, 3, 4] + * ``` + */ +template +KERNEL_FLOAT_INLINE vector> cast(const V& input) { + using F = ops::cast, R, Mode>; + return map(F {}, input); +} + +namespace detail { + +template +struct broadcast_extent_helper; + +template +struct broadcast_extent_helper { + using type = E; +}; + +template +struct broadcast_extent_helper, extent> { + using type = extent; +}; + +template +struct broadcast_extent_helper, extent> { + using type = extent; +}; + +template +struct broadcast_extent_helper, extent<1>> { + using type = extent; +}; + +template<> +struct broadcast_extent_helper, extent<1>> { + using type = extent<1>; +}; + +template +struct broadcast_extent_helper: + broadcast_extent_helper::type, C, Rest...> {}; + +} // namespace detail + +template +using broadcast_extent = typename detail::broadcast_extent_helper::type; + +template +using broadcast_vector_extent_type = broadcast_extent...>; + +template +static constexpr bool is_broadcastable = is_same_type, To>; + +template +static constexpr bool is_vector_broadcastable = is_broadcastable, To>; + +namespace detail { + +template +struct broadcast_impl; + +template +struct broadcast_impl, extent> { + KERNEL_FLOAT_INLINE static vector_storage call(const vector_storage& input) { + vector_storage output; + for (size_t i = 0; i < N; i++) { + output.data()[i] = input.data()[0]; + } + return output; + } +}; + +template +struct broadcast_impl, extent> { + KERNEL_FLOAT_INLINE static vector_storage call(vector_storage input) { + return input; + } +}; + +template +struct broadcast_impl, extent<1>> { + KERNEL_FLOAT_INLINE static vector_storage call(vector_storage input) { + return input; + } +}; + +} // namespace detail + +/** + * Takes the given vector `input` and extends its size to a length of `N`. This is only valid if the size of `input` + * is 1 or `N`. + * + * Example + * ======= + * ``` + * vec a = {1.0f}; + * vec x = broadcast<5>(a); // Returns [1.0f, 1.0f, 1.0f, 1.0f, 1.0f] + * + * vec b = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; + * vec y = broadcast<5>(b); // Returns [1.0f, 2.0f, 3.0f, 4.0f, 5.0f] + * ``` + */ +template +KERNEL_FLOAT_INLINE vector, extent> +broadcast(const V& input, extent new_size = {}) { + using T = vector_value_type; + return detail::broadcast_impl, extent>::call( + into_vector_storage(input)); +} + +/** + * Takes the given vector `input` and extends its size to the same length as vector `other`. This is only valid if the + * size of `input` is 1 or the same as `other`. + */ +template +KERNEL_FLOAT_INLINE vector, vector_extent_type> +broadcast_like(const V& input, const R& other) { + return broadcast(input, vector_extent_type {}); +} + +namespace detail { +/** + * Convert vector of element type `T` and extent type `E` to vector of element type `T2` and extent type `E2`. + * Specialization exist for the cases where `T==T2` and/or `E==E2`. + */ +template +struct convert_impl { + KERNEL_FLOAT_INLINE + static vector_storage call(vector_storage input) { + using F = ops::cast; + vector_storage intermediate; + detail::apply_impl::call(F {}, intermediate.data(), input.data()); + return detail::broadcast_impl::call(intermediate); + } +}; + +// T == T2, E == E2 +template +struct convert_impl { + KERNEL_FLOAT_INLINE + static vector_storage call(vector_storage input) { + return input; + } +}; + +// T == T2, E != E2 +template +struct convert_impl { + KERNEL_FLOAT_INLINE + static vector_storage call(vector_storage input) { + return detail::broadcast_impl::call(input); + } +}; + +// T != T2, E == E2 +template +struct convert_impl { + KERNEL_FLOAT_INLINE + static vector_storage call(vector_storage input) { + using F = ops::cast; + + vector_storage result; + detail::apply_impl::call(F {}, result.data(), input.data()); + return result; + } +}; +} // namespace detail + +template +KERNEL_FLOAT_INLINE vector_storage convert_storage(const V& input, extent new_size = {}) { + return detail::convert_impl, vector_extent_type, R, extent, M>::call( + into_vector_storage(input)); +} + +/** + * Cast the values of the given input vector to type `R` and then broadcast the result to the given size `N`. + * + * Example + * ======= + * ``` + * int a = 5; + * vec x = convert(a); // returns [5.0f, 5.0f, 5.0f] + * + * float b = 5.0f; + * vec x = convert(b); // returns [5.0f, 5.0f, 5.0f] + * + * vec c = {1, 2, 3}; + * vec x = convert(c); // returns [1.0f, 2.0f, 3.0f] + * ``` + */ +template +KERNEL_FLOAT_INLINE vector> convert(const V& input, extent new_size = {}) { + return convert_storage(input); +} + +/** + * Returns a vector containing `N` copies of `value`. + * + * Example + * ======= + * ``` + * vec a = fill<3>(42); // return [42, 42, 42] + * ``` + */ +template +KERNEL_FLOAT_INLINE vector> fill(T value = {}, extent = {}) { + vector_storage input = {value}; + return detail::broadcast_impl, extent>::call(input); +} + +/** + * Returns a vector containing `N` copies of `T(0)`. + * + * Example + * ======= + * ``` + * vec a = zeros(); // return [0, 0, 0] + * ``` + */ +template +KERNEL_FLOAT_INLINE vector> zeros(extent = {}) { + vector_storage input = {T {}}; + return detail::broadcast_impl, extent>::call(input); +} + +/** + * Returns a vector containing `N` copies of `T(1)`. + * + * Example + * ======= + * ``` + * vec a = ones(); // return [1, 1, 1] + * ``` + */ +template +KERNEL_FLOAT_INLINE vector> ones(extent = {}) { + vector_storage input = {T {1}}; + return detail::broadcast_impl, extent>::call(input); +} + +/** + * Returns a vector filled with `value` having the same type and size as input vector `V`. + * + * Example + * ======= + * ``` + * vec a = {1, 2, 3}; + * vec b = fill_like(a, 42); // return [42, 42, 42] + * ``` + */ +template, typename E = vector_extent_type> +KERNEL_FLOAT_INLINE vector fill_like(const V&, T value) { + return fill(value, E {}); +} + +/** + * Returns a vector filled with zeros having the same type and size as input vector `V`. + * + * Example + * ======= + * ``` + * vec a = {1, 2, 3}; + * vec b = zeros_like(a); // return [0, 0, 0] + * ``` + */ +template, typename E = vector_extent_type> +KERNEL_FLOAT_INLINE vector zeros_like(const V& = {}) { + return zeros(E {}); +} + +/** + * Returns a vector filled with ones having the same type and size as input vector `V`. + * + * Example + * ======= + * ``` + * vec a = {1, 2, 3}; + * vec b = ones_like(a); // return [1, 1, 1] + * ``` + */ +template, typename E = vector_extent_type> +KERNEL_FLOAT_INLINE vector ones_like(const V& = {}) { + return ones(E {}); +} + +} // namespace kernel_float + +#endif diff --git a/include/kernel_float/fp16.h b/include/kernel_float/fp16.h index d95edce..41330bb 100644 --- a/include/kernel_float/fp16.h +++ b/include/kernel_float/fp16.h @@ -6,74 +6,122 @@ #if KERNEL_FLOAT_FP16_AVAILABLE #include -#include "interface.h" +#include "vector.h" namespace kernel_float { -KERNEL_FLOAT_DEFINE_COMMON_TYPE(__half, bool) -KERNEL_FLOAT_DEFINE_COMMON_TYPE(float, __half) -KERNEL_FLOAT_DEFINE_COMMON_TYPE(double, __half) +KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(__half) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, __half) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, __half) template<> -struct vector_traits<__half2> { +struct into_vector_impl<__half2> { using value_type = __half; - static constexpr size_t size = 2; + using extent_type = extent<2>; KERNEL_FLOAT_INLINE - static __half2 fill(__half value) { -#if KERNEL_FLOAT_ON_DEVICE - return __half2half2(value); -#else - return {value, value}; -#endif + static vector_storage<__half, 2> call(__half2 input) { + return {input.x, input.y}; } +}; +namespace detail { +template +struct map_halfx2 { KERNEL_FLOAT_INLINE - static __half2 create(__half low, __half high) { -#if KERNEL_FLOAT_ON_DEVICE - return __halves2half2(low, high); -#else - return {low, high}; -#endif + static __half2 call(F fun, __half2 input) { + __half a = fun(input.x); + __half b = fun(input.y); + return {a, b}; } +}; +template +struct zip_halfx2 { KERNEL_FLOAT_INLINE - static __half get(__half2 self, size_t index) { -#if KERNEL_FLOAT_ON_DEVICE - if (index == 0) { - return __low2half(self); - } else { - return __high2half(self); + static __half2 call(F fun, __half2 left, __half2 right) { + __half a = fun(left.x, left.y); + __half b = fun(right.y, right.y); + return {a, b}; + } +}; + +template +struct apply_impl { + KERNEL_FLOAT_INLINE static void call(F fun, __half* result, const __half* input) { +#pragma unroll + for (size_t i = 0; 2 * i + 1 < N; i++) { + __half2 a = {input[2 * i], input[2 * i + 1]}; + __half2 b = map_halfx2::call(fun, a); + result[2 * i + 0] = b.x; + result[2 * i + 1] = b.y; } -#else - if (index == 0) { - return self.x; - } else { - return self.y; + + if (N % 2 != 0) { + result[N - 1] = fun(input[N - 1]); } -#endif } +}; - KERNEL_FLOAT_INLINE - static void set(__half2& self, size_t index, __half value) { - if (index == 0) { - self.x = value; - } else { - self.y = value; +template +struct apply_impl { + KERNEL_FLOAT_INLINE static void + call(F fun, __half* result, const __half* left, const __half* right) { +#pragma unroll + for (size_t i = 0; 2 * i + 1 < N; i++) { + __half2 a = {left[2 * i], left[2 * i + 1]}; + __half2 b = {right[2 * i], right[2 * i + 1]}; + __half2 c = zip_halfx2::call(fun, a, b); + result[2 * i + 0] = c.x; + result[2 * i + 1] = c.y; + } + + if (N % 2 != 0) { + result[N - 1] = fun(left[N - 1], right[N - 1]); } } }; -template -struct default_storage<__half, N, Alignment::Maximum, enabled_t<(N >= 2)>> { - using type = nested_array<__half2, N>; -}; +template +struct reduce_impl= 2)>> { + KERNEL_FLOAT_INLINE static __half call(F fun, const __half* input) { + __half2 accum = {input[0], input[1]}; -template -struct default_storage<__half, N, Alignment::Packed, enabled_t<(N >= 2 && N % 2 == 0)>> { - using type = nested_array<__half2, N>; +#pragma unroll + for (size_t i = 0; 2 * i + 1 < N; i++) { + __half2 a = {input[2 * i], input[2 * i + 1]}; + accum = zip_halfx2::call(fun, accum, a); + } + + __half result = fun(accum.x, accum.y); + + if (N % 2 != 0) { + result = fun(result, input[N - 1]); + } + + return result; + } }; -#if KERNEL_FLOAT_ON_DEVICE +}; // namespace detail + +#define KERNEL_FLOAT_FP16_UNARY_FORWARD(NAME) \ + namespace ops { \ + template<> \ + struct NAME<__half> { \ + KERNEL_FLOAT_INLINE __half operator()(__half input) { \ + return __half(ops::NAME {}(float(input))); \ + } \ + }; \ + } + +// There operations are not implemented in half precision, so they are forward to single precision +KERNEL_FLOAT_FP16_UNARY_FORWARD(tan) +KERNEL_FLOAT_FP16_UNARY_FORWARD(asin) +KERNEL_FLOAT_FP16_UNARY_FORWARD(acos) +KERNEL_FLOAT_FP16_UNARY_FORWARD(atan) +KERNEL_FLOAT_FP16_UNARY_FORWARD(expm1) + +#if KERNEL_FLOAT_IS_DEVICE #define KERNEL_FLOAT_FP16_UNARY_FUN(NAME, FUN1, FUN2) \ namespace ops { \ template<> \ @@ -85,28 +133,37 @@ struct default_storage<__half, N, Alignment::Packed, enabled_t<(N >= 2 && N % 2 } \ namespace detail { \ template<> \ - struct map_helper, __half2, __half2> { \ + struct map_halfx2> { \ KERNEL_FLOAT_INLINE static __half2 call(ops::NAME<__half>, __half2 input) { \ return FUN2(input); \ } \ }; \ } +#else +#define KERNEL_FLOAT_FP16_UNARY_FUN(NAME, FUN1, FUN2) KERNEL_FLOAT_FP16_UNARY_FORWARD(NAME) +#endif + +KERNEL_FLOAT_FP16_UNARY_FUN(abs, ::__habs, ::__habs2) +KERNEL_FLOAT_FP16_UNARY_FUN(negate, ::__hneg, ::__hneg2) +KERNEL_FLOAT_FP16_UNARY_FUN(ceil, ::hceil, ::h2ceil) +KERNEL_FLOAT_FP16_UNARY_FUN(cos, ::hcos, ::h2cos) +KERNEL_FLOAT_FP16_UNARY_FUN(exp, ::hexp, ::h2exp) +KERNEL_FLOAT_FP16_UNARY_FUN(exp10, ::hexp10, ::h2exp10) +KERNEL_FLOAT_FP16_UNARY_FUN(floor, ::hfloor, ::h2floor) +KERNEL_FLOAT_FP16_UNARY_FUN(log, ::hlog, ::h2log) +KERNEL_FLOAT_FP16_UNARY_FUN(log10, ::hlog10, ::h2log2) +KERNEL_FLOAT_FP16_UNARY_FUN(rint, ::hrint, ::h2rint) +KERNEL_FLOAT_FP16_UNARY_FUN(rsqrt, ::hrsqrt, ::h2rsqrt) +KERNEL_FLOAT_FP16_UNARY_FUN(sin, ::hsin, ::h2sin) +KERNEL_FLOAT_FP16_UNARY_FUN(sqrt, ::hsqrt, ::h2sqrt) +KERNEL_FLOAT_FP16_UNARY_FUN(trunc, ::htrunc, ::h2trunc) -KERNEL_FLOAT_FP16_UNARY_FUN(abs, ::__habs, ::__habs2); -KERNEL_FLOAT_FP16_UNARY_FUN(negate, ::__hneg, ::__hneg2); -KERNEL_FLOAT_FP16_UNARY_FUN(ceil, ::hceil, ::h2ceil); -KERNEL_FLOAT_FP16_UNARY_FUN(cos, ::hcos, ::h2cos); -KERNEL_FLOAT_FP16_UNARY_FUN(exp, ::hexp, ::h2exp); -KERNEL_FLOAT_FP16_UNARY_FUN(exp10, ::hexp10, ::h2exp10); -KERNEL_FLOAT_FP16_UNARY_FUN(floor, ::hfloor, ::h2floor); -KERNEL_FLOAT_FP16_UNARY_FUN(log, ::hlog, ::h2log); -KERNEL_FLOAT_FP16_UNARY_FUN(log10, ::hlog10, ::h2log2); -KERNEL_FLOAT_FP16_UNARY_FUN(rint, ::hrint, ::h2rint); -KERNEL_FLOAT_FP16_UNARY_FUN(rsqrt, ::hrsqrt, ::h2rsqrt); -KERNEL_FLOAT_FP16_UNARY_FUN(sin, ::hsin, ::h2sin); -KERNEL_FLOAT_FP16_UNARY_FUN(sqrt, ::hsqrt, ::h2sqrt); -KERNEL_FLOAT_FP16_UNARY_FUN(trunc, ::htrunc, ::h2trunc); +KERNEL_FLOAT_FP16_UNARY_FUN(fast_exp, ::hexp, ::h2exp) +KERNEL_FLOAT_FP16_UNARY_FUN(fast_log, ::hlog, ::h2log) +KERNEL_FLOAT_FP16_UNARY_FUN(fast_cos, ::hcos, ::h2cos) +KERNEL_FLOAT_FP16_UNARY_FUN(fast_sin, ::hsin, ::h2sin) +#if KERNEL_FLOAT_IS_DEVICE #define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2) \ namespace ops { \ template<> \ @@ -118,12 +175,23 @@ KERNEL_FLOAT_FP16_UNARY_FUN(trunc, ::htrunc, ::h2trunc); } \ namespace detail { \ template<> \ - struct zip_helper, __half2, __half2, __half2> { \ + struct zip_halfx2> { \ KERNEL_FLOAT_INLINE static __half2 call(ops::NAME<__half>, __half2 left, __half2 right) { \ return FUN2(left, right); \ } \ }; \ } +#else +#define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2) \ + namespace ops { \ + template<> \ + struct NAME<__half> { \ + KERNEL_FLOAT_INLINE __half operator()(__half left, __half right) const { \ + return __half(ops::NAME {}(float(left), float(right))); \ + } \ + }; \ + } +#endif KERNEL_FLOAT_FP16_BINARY_FUN(add, __hadd, __hadd2) KERNEL_FLOAT_FP16_BINARY_FUN(subtract, __hsub, __hsub2) @@ -131,6 +199,7 @@ KERNEL_FLOAT_FP16_BINARY_FUN(multiply, __hmul, __hmul2) KERNEL_FLOAT_FP16_BINARY_FUN(divide, __hdiv, __h2div) KERNEL_FLOAT_FP16_BINARY_FUN(min, __hmin, __hmin2) KERNEL_FLOAT_FP16_BINARY_FUN(max, __hmax, __hmax2) +KERNEL_FLOAT_FP16_BINARY_FUN(fast_div, __hdiv, __h2div) KERNEL_FLOAT_FP16_BINARY_FUN(equal_to, __heq, __heq2) KERNEL_FLOAT_FP16_BINARY_FUN(not_equal_to, __heq, __heq2) @@ -139,8 +208,6 @@ KERNEL_FLOAT_FP16_BINARY_FUN(less_equal, __hle, __hle2) KERNEL_FLOAT_FP16_BINARY_FUN(greater, __hgt, __hgt2) KERNEL_FLOAT_FP16_BINARY_FUN(greater_equal, __hge, __hgt2) -#endif - #define KERNEL_FLOAT_FP16_CAST(T, TO_HALF, FROM_HALF) \ namespace ops { \ template<> \ @@ -165,22 +232,69 @@ KERNEL_FLOAT_FP16_CAST(char, __int2half_rn(input), (char)__half2int_rz(input)); KERNEL_FLOAT_FP16_CAST(signed char, __int2half_rn(input), (signed char)__half2int_rz(input)); KERNEL_FLOAT_FP16_CAST(unsigned char, __int2half_rn(input), (unsigned char)__half2int_rz(input)); -KERNEL_FLOAT_FP16_CAST(signed short, __short2half_rn(input), __half2short_rz(input)); -KERNEL_FLOAT_FP16_CAST(signed int, __int2half_rn(input), __half2int_rz(input)); +KERNEL_FLOAT_FP16_CAST(signed short, __half2short_rz(input), __short2half_rn(input)); +KERNEL_FLOAT_FP16_CAST(signed int, __half2int_rz(input), __int2half_rn(input)); KERNEL_FLOAT_FP16_CAST(signed long, __ll2half_rn(input), (signed long)(__half2ll_rz(input))); KERNEL_FLOAT_FP16_CAST(signed long long, __ll2half_rn(input), __half2ll_rz(input)); -KERNEL_FLOAT_FP16_CAST(unsigned int, __uint2half_rn(input), __half2uint_rz(input)); -KERNEL_FLOAT_FP16_CAST(unsigned short, __ushort2half_rn(input), __half2ushort_rz(input)); +KERNEL_FLOAT_FP16_CAST(unsigned short, __half2ushort_rz(input), __ushort2half_rn(input)); +KERNEL_FLOAT_FP16_CAST(unsigned int, __half2uint_rz(input), __uint2half_rn(input)); KERNEL_FLOAT_FP16_CAST(unsigned long, __ull2half_rn(input), (unsigned long)(__half2ull_rz(input))); KERNEL_FLOAT_FP16_CAST(unsigned long long, __ull2half_rn(input), __half2ull_rz(input)); using half = __half; -using float16 = __half; -//KERNEL_FLOAT_TYPE_ALIAS(half, __half) //KERNEL_FLOAT_TYPE_ALIAS(float16x, __half) //KERNEL_FLOAT_TYPE_ALIAS(f16x, __half) +#if KERNEL_FLOAT_IS_DEVICE +namespace detail { +template<> +struct dot_impl<__half, 0> { + KERNEL_FLOAT_INLINE + static __half call(const __half* left, const __half* right) { + return __half(0); + } +}; + +template<> +struct dot_impl<__half, 1> { + KERNEL_FLOAT_INLINE + static __half call(const __half* left, const __half* right) { + return __hmul(left[0], right[0]); + } +}; + +template +struct dot_impl<__half, N> { + static_assert(N >= 2, "internal error"); + + KERNEL_FLOAT_INLINE + static __half call(const __half* left, const __half* right) { + __half2 first_a = {left[0], left[1]}; + __half2 first_b = {right[0], right[1]}; + __half2 accum = __hmul2(first_a, first_b); + +#pragma unroll + for (size_t i = 2; i + 2 <= N; i += 2) { + __half2 a = {left[i], left[i + 1]}; + __half2 b = {right[i], right[i + 1]}; + accum = __hfma2(a, b, accum); + } + + __half result = __hadd(accum.x, accum.y); + + if (N % 2 != 0) { + __half a = left[N - 1]; + __half b = right[N - 1]; + result = __hfma(a, b, result); + } + + return result; + } +}; +} // namespace detail +#endif + } // namespace kernel_float #endif diff --git a/include/kernel_float/fp8.h b/include/kernel_float/fp8.h deleted file mode 100644 index e69de29..0000000 diff --git a/include/kernel_float/interface.h b/include/kernel_float/interface.h deleted file mode 100644 index d191b65..0000000 --- a/include/kernel_float/interface.h +++ /dev/null @@ -1,294 +0,0 @@ -#ifndef KERNEL_FLOAT_INTERFACE_H -#define KERNEL_FLOAT_INTERFACE_H - -#include "binops.h" -#include "iterate.h" -#include "reduce.h" -#include "storage.h" -#include "swizzle.h" -#include "unops.h" - -namespace kernel_float { - -template -KERNEL_FLOAT_INLINE vector broadcast(Input&& input); - -template -struct index_proxy { - using value_type = typename vector_traits::value_type; - - KERNEL_FLOAT_INLINE - index_proxy(V& storage, I index) : storage_(storage), index_(index) {} - - KERNEL_FLOAT_INLINE - index_proxy& operator=(value_type value) { - vector_traits::set(storage_, index_, value); - return *this; - } - - KERNEL_FLOAT_INLINE - operator value_type() const { - return vector_traits::get(storage_, index_); - } - - private: - V& storage_; - I index_; -}; - -template -struct index_proxy> { - using value_type = typename vector_traits::value_type; - - KERNEL_FLOAT_INLINE - index_proxy(V& storage, const_index) : storage_(storage) {} - - KERNEL_FLOAT_INLINE - index_proxy& operator=(value_type value) { - vector_index::set(storage_, value); - return *this; - } - - KERNEL_FLOAT_INLINE - operator value_type() const { - return vector_index::get(storage_); - } - - private: - V& storage_; -}; - -template -struct vector { - using storage_type = V; - using traits_type = vector_traits; - using value_type = typename traits_type::value_type; - static constexpr size_t const_size = traits_type::size; - - vector(const vector&) = default; - vector(vector&) = default; - vector(vector&&) = default; - - vector& operator=(const vector&) = default; - vector& operator=(vector&) = default; - vector& operator=(vector&&) = default; - - KERNEL_FLOAT_INLINE - vector() : storage_(traits_type::fill(value_type {})) {} - - KERNEL_FLOAT_INLINE - vector(storage_type storage) : storage_(storage) {} - - template< - typename U, - enabled_t, value_type>, int> = 0> - KERNEL_FLOAT_INLINE vector(U&& init) : vector(broadcast(std::forward(init))) {} - - template = 0> - KERNEL_FLOAT_INLINE vector(Args&&... args) : storage_(traits_type::create(args...)) {} - - KERNEL_FLOAT_INLINE - operator storage_type() const { - return storage_; - } - - KERNEL_FLOAT_INLINE - storage_type& storage() { - return storage_; - } - - KERNEL_FLOAT_INLINE - const storage_type& storage() const { - return storage_; - } - - KERNEL_FLOAT_INLINE - value_type get(size_t index) const { - return traits_type::get(storage_, index); - } - - KERNEL_FLOAT_INLINE - void set(size_t index, value_type value) { - traits_type::set(storage_, index, value); - } - - template - KERNEL_FLOAT_INLINE value_type get(const_index) const { - return vector_index::get(storage_); - } - - template - KERNEL_FLOAT_INLINE void set(const_index, value_type value) { - return vector_index::set(storage_, value); - } - - KERNEL_FLOAT_INLINE - value_type operator[](size_t index) const { - return get(index); - } - - template - KERNEL_FLOAT_INLINE value_type operator[](const_index) const { - return get(const_index {}); - } - - KERNEL_FLOAT_INLINE - index_proxy operator[](size_t index) { - return {storage_, index}; - } - - template - KERNEL_FLOAT_INLINE index_proxy> operator[](const_index) { - return {storage_, const_index {}}; - } - - KERNEL_FLOAT_INLINE - static constexpr size_t size() { - return const_size; - } - - private: - storage_type storage_; -}; - -template -struct vector_traits> { - using value_type = vector_value_type; - static constexpr size_t size = vector_size; - - KERNEL_FLOAT_INLINE - static vector fill(value_type value) { - return vector_traits::fill(value); - } - - template - KERNEL_FLOAT_INLINE static vector create(Args... args) { - return vector_traits::create(args...); - } - - KERNEL_FLOAT_INLINE - static value_type get(const vector& self, size_t index) { - return vector_traits::get(self.storage(), index); - } - - KERNEL_FLOAT_INLINE - static void set(vector& self, size_t index, value_type value) { - vector_traits::set(self.storage(), index, value); - } -}; - -template -struct vector_index, I> { - using value_type = vector_value_type; - - KERNEL_FLOAT_INLINE - static value_type get(const vector& self) { - return vector_index::get(self.storage()); - } - - KERNEL_FLOAT_INLINE - static void set(vector& self, value_type value) { - vector_index::set(self.storage(), value); - } -}; - -template -struct into_storage_traits> { - using type = V; - - KERNEL_FLOAT_INLINE - static constexpr type call(const vector& self) { - return self.storage(); - } -}; - -template -struct vector_swizzle, index_sequence> { - KERNEL_FLOAT_INLINE static Output call(const vector& self) { - return vector_swizzle>::call(self.storage()); - } -}; - -template -using vec = vector>; - -template -using unaligned_vec = vector>; - -template -KERNEL_FLOAT_INLINE vec, sizeof...(Args)> make_vec(Args&&... args) { - using value_type = common_t; - using vector_type = default_storage_type; - return vector_traits::create(value_type(args)...); -} - -template -KERNEL_FLOAT_INLINE vector> into_vec(V&& input) { - return into_storage(input); -} - -using float32 = float; -using float64 = double; - -template -using vec1 = vec; -template -using vec2 = vec; -template -using vec3 = vec; -template -using vec4 = vec; -template -using vec5 = vec; -template -using vec6 = vec; -template -using vec7 = vec; -template -using vec8 = vec; - -#define KERNEL_FLOAT_TYPE_ALIAS(NAME, T) \ - template \ - using NAME##N = vec; \ - using NAME##1 = vec; \ - using NAME##2 = vec; \ - using NAME##3 = vec; \ - using NAME##4 = vec; \ - using NAME##5 = vec; \ - using NAME##6 = vec; \ - using NAME##7 = vec; \ - using NAME##8 = vec; \ - template \ - using unaligned_##NAME##X = unaligned_vec; \ - using unaligned_##NAME##1 = unaligned_vec; \ - using unaligned_##NAME##2 = unaligned_vec; \ - using unaligned_##NAME##3 = unaligned_vec; \ - using unaligned_##NAME##4 = unaligned_vec; \ - using unaligned_##NAME##5 = unaligned_vec; \ - using unaligned_##NAME##6 = unaligned_vec; \ - using unaligned_##NAME##7 = unaligned_vec; \ - using unaligned_##NAME##8 = unaligned_vec; - -KERNEL_FLOAT_TYPE_ALIAS(char, char) -KERNEL_FLOAT_TYPE_ALIAS(short, short) -KERNEL_FLOAT_TYPE_ALIAS(int, int) -KERNEL_FLOAT_TYPE_ALIAS(long, long) -KERNEL_FLOAT_TYPE_ALIAS(longlong, long long) - -KERNEL_FLOAT_TYPE_ALIAS(uchar, unsigned char) -KERNEL_FLOAT_TYPE_ALIAS(ushort, unsigned short) -KERNEL_FLOAT_TYPE_ALIAS(uint, unsigned int) -KERNEL_FLOAT_TYPE_ALIAS(ulong, unsigned long) -KERNEL_FLOAT_TYPE_ALIAS(ulonglong, unsigned long long) - -KERNEL_FLOAT_TYPE_ALIAS(float, float) -KERNEL_FLOAT_TYPE_ALIAS(f32x, float) -KERNEL_FLOAT_TYPE_ALIAS(float32x, float) - -KERNEL_FLOAT_TYPE_ALIAS(double, double) -KERNEL_FLOAT_TYPE_ALIAS(f64x, double) -KERNEL_FLOAT_TYPE_ALIAS(float64x, double) - -} // namespace kernel_float - -#endif //KERNEL_FLOAT_INTERFACE_H diff --git a/include/kernel_float/iterate.h b/include/kernel_float/iterate.h index 2b98194..68c1645 100644 --- a/include/kernel_float/iterate.h +++ b/include/kernel_float/iterate.h @@ -1,175 +1,303 @@ #ifndef KERNEL_FLOAT_ITERATE_H #define KERNEL_FLOAT_ITERATE_H -#include "storage.h" -#include "unops.h" +#include "base.h" +#include "conversion.h" namespace kernel_float { -namespace detail { -template>> -struct range_helper; - -template -struct range_helper> { - KERNEL_FLOAT_INLINE static V call(F fun) { - return vector_traits::create(fun(const_index {})...); - } -}; -} // namespace detail - /** - * Generate vector of length ``N`` by applying the given function ``fun`` to - * each index ``0...N-1``. + * Apply the function fun for each element from input. * * Example * ======= * ``` - * // returns [0, 2, 4] - * vector vec = range<3>([](auto i) { return float(i * 2); }); + * for_each(range(), [&](auto i) { + * printf("element: %d\n", i); + * }); * ``` */ -template< - size_t N, - typename F, - typename T = result_t, - typename Output = default_storage_type> -KERNEL_FLOAT_INLINE vector range(F fun) { - return detail::range_helper::call(fun); +template +void for_each(V&& input, F fun) { + auto storage = into_vector_storage(input); + +#pragma unroll + for (size_t i = 0; i < vector_extent; i++) { + fun(storage.data()[i]); + } } +namespace detail { +template +struct range_impl { + KERNEL_FLOAT_INLINE + static vector_storage call() { + vector_storage result; + +#pragma unroll + for (size_t i = 0; i < N; i++) { + result.data()[i] = T(i); + } + + return result; + } +}; +} // namespace detail + /** - * Generate vector consisting of the numbers ``0...N-1`` of type ``T``. + * Generate vector consisting of the numbers `0...N-1` of type `T` * * Example * ======= * ``` * // Returns [0, 1, 2] - * vector vec = range(); + * vec vec = range(); * ``` */ -template> -KERNEL_FLOAT_INLINE vector range() { - using F = ops::cast; - return detail::range_helper::call(F {}); -} - -/** - * Generate vector having same size and type as ``V``, but filled with the numbers ``0..N-1``. - */ -template> -KERNEL_FLOAT_INLINE vector range_like(const Input&) { - using F = ops::cast>; - return detail::range_helper::call(F {}); +template +KERNEL_FLOAT_INLINE vector> range() { + return detail::range_impl::call(); } /** - * Generate vector of `N` elements of type `T` + * Takes a vector `vec` and returns a new vector consisting of the numbers ``0...N-1`` of type ``T`` * * Example * ======= * ``` - * // Returns [1.0, 1.0, 1.0] - * vector = fill(1.0f); + * auto input = vec(5.0f, 10.0f, -1.0f); + * auto indices = range_like(input); // returns [0.0f, 1.0f, 2.0f] * ``` */ -template> -KERNEL_FLOAT_INLINE vector fill(T value) { - return vector_traits::fill(value); -} - -/** - * Generate vector having same size and type as ``V``, but filled with the given ``value``. - */ -template -KERNEL_FLOAT_INLINE vector fill_like(const Output&, vector_value_type value) { - return vector_traits::fill(value); +template +KERNEL_FLOAT_INLINE into_vector_type range_like(const V& = {}) { + return detail::range_impl, vector_extent>::call(); } /** - * Generate vector of ``N`` zeros of type ``T`` + * Takes a vector of size ``N`` and returns a new vector consisting of the numbers ``0...N-1``. The data type used + * for the indices is given by the first template argument, which is `size_t` by default. This function is useful when + * needing to iterate over the indices of a vector. * * Example * ======= * ``` - * // Returns [0.0, 0.0, 0.0] - * vector = zeros(); + * // Returns [0, 1, 2] of type size_t + * vec a = each_index(float3(6, 4, 2)); + * + * // Returns [0, 1, 2] of type int. + * vec b = each_index(float3(6, 4, 2)); + * + * vec input = {1.0f, 2.0f, 3.0f, 4.0f}; + * for (auto index: each_index(input)) { + * printf("%d] %f\n", index, input[index]); + * } * ``` */ -template> -KERNEL_FLOAT_INLINE vector zeros() { - return vector_traits::fill(T(0)); +template +KERNEL_FLOAT_INLINE vector> each_index(const V& = {}) { + return detail::range_impl>::call(); } -/** - * Generate vector having same size and type as ``V``, but filled with zeros. - * - */ -template -KERNEL_FLOAT_INLINE vector zeros_like(const Output& output = {}) { - return vector_traits::fill(0); -} +namespace detail { +template, size_t N = vector_extent> +struct flatten_impl { + using value_type = typename flatten_impl::value_type; + static constexpr size_t size = N * flatten_impl::size; + + template + KERNEL_FLOAT_INLINE static void call(U* output, const V& input) { + vector_storage storage = into_vector_storage(input); + +#pragma unroll + for (size_t i = 0; i < N; i++) { + flatten_impl::call(output + flatten_impl::size * i, storage.data()[i]); + } + } +}; + +template +struct flatten_impl { + using value_type = T; + static constexpr size_t size = 1; + + KERNEL_FLOAT_INLINE + static void call(T* output, const T& input) { + *output = input; + } + + template + KERNEL_FLOAT_INLINE static void call(U* output, const T& input) { + *output = ops::cast {}(input); + } +}; +} // namespace detail + +template +using flatten_value_type = typename detail::flatten_impl::value_type; + +template +static constexpr size_t flatten_size = detail::flatten_impl::size; + +template +using flatten_type = vector, extent>>; /** - * Generate vector of ``N`` ones of type ``T`` + * Flattens the elements of this vector. For example, this turns a `vec, 3>` into a `vec`. * * Example * ======= * ``` - * // Returns [1.0, 1.0, 1.0] - * vector = ones(); + * vec input = {{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}; + * vec result = flatten(input); // returns [1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f] * ``` */ -template> -KERNEL_FLOAT_INLINE vector ones() { - return vector_traits::fill(T(1)); -} - -/** - * Generate vector having same size and type as ``V``, but filled with ones. - * - */ -template -KERNEL_FLOAT_INLINE vector ones_like(const Output& output = {}) { - return vector_traits::fill(1); +template +KERNEL_FLOAT_INLINE flatten_type flatten(const V& input) { + vector_storage, flatten_size> output; + detail::flatten_impl::call(output.data(), input); + return output; } namespace detail { -template>> -struct iterate_helper; +template> +struct concat_base_impl { + static constexpr size_t size = vector_extent; -template -struct iterate_helper> { - KERNEL_FLOAT_INLINE - static void call(F fun, const V& input) {} + KERNEL_FLOAT_INLINE static void call(U* output, const V& input) { + vector_storage storage = into_vector_storage(input); + + for (size_t i = 0; i < size; i++) { + output[i] = ops::cast {}(storage.data()[i]); + } + } }; -template -struct iterate_helper> { - KERNEL_FLOAT_INLINE - static void call(F fun, const V& input) { - fun(vector_get(input)); - iterate_helper>::call(fun, input); +template +struct concat_base_impl { + static constexpr size_t size = 1; + + KERNEL_FLOAT_INLINE static void call(U* output, const T& input) { + *output = ops::cast {}(input); + } +}; + +template +struct concat_base_impl { + static constexpr size_t size = 1; + + KERNEL_FLOAT_INLINE static void call(T* output, const T& input) { + *output = input; + } +}; + +template +struct concat_impl {}; + +template +struct concat_impl { + using value_type = + typename promote_type, typename concat_impl::value_type>::type; + static constexpr size_t size = concat_base_impl::size + concat_impl::size; + + template + KERNEL_FLOAT_INLINE static void call(U* output, const V& input, const Vs&... rest) { + concat_base_impl::call(output, input); + concat_impl::call(output + concat_base_impl::size, rest...); } }; + +template<> +struct concat_impl<> { + using value_type = void; + static constexpr size_t size = 1; + + template + KERNEL_FLOAT_INLINE static void call(U* output) {} +}; } // namespace detail +template +using concat_value_type = promote_t::value_type>; + +template +static constexpr size_t concat_size = detail::concat_impl::size; + +template +using concat_type = vector, extent>>; + /** - * Apply the function ``fun`` for each element from ``input``. + * Concatenates the provided input values into a single one-dimensional vector. + * + * This function works in three steps: + * - All input values are converted into vectors using the `into_vector` operation. + * - The resulting vectors' elements are then promoted into a shared value type. + * - The resultant vectors are finally concatenated together. + * + * For instance, when invoking this function with arguments of types `float, double2, double`: + * - After the first step: `vec, vec, vec` + * - After the second step: `vec, vec, vec` + * - After the third step: `vec` * * Example * ======= * ``` - * for_each(range<3>(), [&](auto i) { - * printf("element: %d\n", i); - * }); + * double vec1 = 1.0; + * double3 vec2 = {3.0, 4.0, 5.0); + * double4 vec3 = {6.0, 7.0, 8.0, 9.0}; + * vec concatenated = concat(vec1, vec2, vec3); // contains [1, 2, 3, 4, 5, 6, 7, 8, 9] + * + * int num1 = 42; + * float num2 = 3.14159; + * int2 num3 = {-10, 10}; + * vec concatenated = concat(num1, num2, num3); // contains [42, 3.14159, -10, 10] * ``` */ -template -KERNEL_FLOAT_INLINE void for_each(const V& input, F fun) { - detail::iterate_helper>::call(fun, into_storage(input)); +template +KERNEL_FLOAT_INLINE concat_type concat(const Vs&... inputs) { + vector_storage, concat_size> output; + detail::concat_impl::call(output.data(), inputs...); + return output; +} + +template +using select_type = vector, extent>>; + +/** + * Selects elements from the this vector based on the specified indices. + * + * Example + * ======= + * ``` + * vec input = {0, 10, 20, 30, 40, 50}; + * vec vec1 = select(input, 0, 4, 4, 2); // [0, 40, 40, 20] + * + * vec indices = {0, 4, 4, 2}; + * vec vec2 = select(input, indices); // [0, 40, 40, 20] + * ``` + */ +template +KERNEL_FLOAT_INLINE select_type select(const V& input, const Is&... indices) { + using T = vector_value_type; + static constexpr size_t N = vector_extent; + static constexpr size_t M = concat_size; + + vector_storage index_set; + detail::concat_impl::call(index_set.data(), indices...); + + vector_storage inputs = into_vector_storage(input); + vector_storage outputs; + for (size_t i = 0; i < M; i++) { + size_t j = index_set.data()[i]; + + if (j < N) { + outputs.data()[i] = inputs.data()[j]; + } + } + + return outputs; } } // namespace kernel_float -#endif //KERNEL_FLOAT_ITERATE_H +#endif \ No newline at end of file diff --git a/include/kernel_float/macros.h b/include/kernel_float/macros.h index 761360e..ab70d2a 100644 --- a/include/kernel_float/macros.h +++ b/include/kernel_float/macros.h @@ -6,20 +6,20 @@ #ifdef __CUDA_ARCH__ #define KERNEL_FLOAT_INLINE __forceinline__ __device__ -#define KERNEL_FLOAT_ON_DEVICE (1) -#define KERNEL_FLOAT_ON_HOST (0) +#define KERNEL_FLOAT_IS_DEVICE (1) +#define KERNEL_FLOAT_IS_HOST (0) #define KERNEL_FLOAT_CUDA_ARCH (__CUDA_ARCH__) #else #define KERNEL_FLOAT_INLINE __forceinline__ __host__ -#define KERNEL_FLOAT_ON_DEVICE (0) -#define KERNEL_FLOAT_ON_HOST (1) +#define KERNEL_FLOAT_IS_DEVICE (0) +#define KERNEL_FLOAT_IS_HOST (1) #define KERNEL_FLOAT_CUDA_ARCH (0) #endif #else #define KERNEL_FLOAT_INLINE inline #define KERNEL_FLOAT_CUDA (0) -#define KERNEL_FLOAT_ON_HOST (1) -#define KERNEL_FLOAT_ON_DEVICE (0) +#define KERNEL_FLOAT_IS_HOST (1) +#define KERNEL_FLOAT_IS_DEVICE (0) #define KERNEL_FLOAT_CUDA_ARCH (0) #endif diff --git a/include/kernel_float/memory.h b/include/kernel_float/memory.h new file mode 100644 index 0000000..1c136a6 --- /dev/null +++ b/include/kernel_float/memory.h @@ -0,0 +1,268 @@ +#ifndef KERNEL_FLOAT_MEMORY_H +#define KERNEL_FLOAT_MEMORY_H + +/* +#include "binops.h" +#include "conversion.h" +#include "iterate.h" + +namespace kernel_float { + + namespace detail { + template > + struct load_helper; + + template + struct load_helper> { + KERNEL_FLOAT_INLINE + vector_storage call( + T* base, + vector_storage offsets + ) { + return {base[offsets.data()[Is]]...}; + } + + KERNEL_FLOAT_INLINE + vector_storage call( + T* base, + vector_storage offsets, + vector_storage mask + ) { + if (all(mask)) { + return call(base, offsets); + } else { + return { + (mask.data()[Is] ? base[offsets.data()[Is]] : T())... + }; + } + } + }; + } + + template < + typename T, + typename I, + typename M, + typename E = broadcast_vector_extent_type + > + KERNEL_FLOAT_INLINE + vector load(const T* ptr, const I& indices, const M& mask) { + static constexpr E new_size = {}; + + return detail::load_helper::call( + ptr, + convert_storage(indices, new_size), + convert_storage(mask, new_size) + ); + } + + template + KERNEL_FLOAT_INLINE + vector> load(const T* ptr, const I& indices) { + return detail::load_helper::value>::call( + ptr, + cast(indices) + ); + } + + template + KERNEL_FLOAT_INLINE + vector> load(const T* ptr, ptrdiff_t length) { + using index_type = vector_value_type; + return load_masked(ptr, range(), range() < length); + } + + template + KERNEL_FLOAT_INLINE + vector> load(const T* ptr) { + return load(ptr, range()); + } + + namespace detail { + template + struct store_helper { + KERNEL_FLOAT_INLINE + vector_storage call( + T* base, + vector_storage offsets, + vector_storage mask, + vector_storage values + ) { + for (size_t i = 0; i < N; i++) { + if (mask.data()[i]) { + base[offset.data()[i]] = values.data()[i]; + } + } + } + + KERNEL_FLOAT_INLINE + vector_storage call( + T* base, + vector_storage offsets, + vector_storage values + ) { + for (size_t i = 0; i < N; i++) { + base[offset.data()[i]] = values.data()[i]; + } + } + }; + } + + template < + typename T, + typename I, + typename M, + typename V, + typename E = broadcast_extent, broadcast_vector_extent_type>> + > + KERNEL_FLOAT_INLINE + void store(const T* ptr, const I& indices, const M& mask, const V& values) { + static constexpr E new_size = {}; + + return detail::store_helper::call( + ptr, + convert_storage(indices, new_size), + convert_storage(mask, new_size), + convert_storage(values, new_size) + ); + } + + template < + typename T, + typename I, + typename V, + typename E = broadcast_vector_extent_type + > + KERNEL_FLOAT_INLINE + void store(const T* ptr, const I& indices, const V& values) { + static constexpr E new_size = {}; + + return detail::store_helper::call( + ptr, + convert_storage(indices, new_size), + convert_storage(values, new_size) + ); + } + + + template < + typename T, + typename V + > + KERNEL_FLOAT_INLINE + void store(const T* ptr, const V& values) { + using E = vector_extent; + return store(ptr, range(), values); + } + + template + KERNEL_FLOAT_INLINE + void store(const T* ptr, const I& indices, const S& length, const V& values) { + using index_type = vector_value_type; + return store(ptr, indices, (indices >= I(0)) & (indices < length), values); + } + + + template + struct aligned_ptr_base { + static_assert(alignof(T) % alignment == 0, "invalid alignment, must be multiple of alignment of `T`"); + + KERNEL_FLOAT_INLINE + aligned_ptr_base(): ptr_(nullptr) {} + + KERNEL_FLOAT_INLINE + explicit aligned_ptr_base(T* ptr): ptr_(ptr) {} + + KERNEL_FLOAT_INLINE + T* get() const { + // TOOD: check if this way is support across all compilers +#if defined(__has_builtin) && __has_builtin(__builtin_assume_aligned) + return __builtin_assume_aligned(ptr_, alignment); +#else + return ptr_; +#endif + } + + KERNEL_FLOAT_INLINE + operator T*() const { + return get(); + } + + KERNEL_FLOAT_INLINE + T& operator*() const { + return *get(); + } + + template + KERNEL_FLOAT_INLINE + T& operator[](I index) const { + return get()[index); + } + + private: + T* ptr_ = nullptr; + }; + + template + struct aligned_ptr; + + template + struct aligned_ptr: aligned_ptr_base { + using base_type = aligned_ptr_base; + + KERNEL_FLOAT_INLINE + aligned_ptr(): base_type(nullptr) {} + + KERNEL_FLOAT_INLINE + explicit aligned_ptr(T* ptr): base_type(ptr) {} + + KERNEL_FLOAT_INLINE + aligned_ptr(aligned_ptr ptr): base_type(ptr.get()) {} + }; + + template + struct aligned_ptr: aligned_ptr_base { + using base_type = aligned_ptr_base; + + KERNEL_FLOAT_INLINE + aligned_ptr(): base_type(nullptr) {} + + KERNEL_FLOAT_INLINE + explicit aligned_ptr(T* ptr): base_type(ptr) {} + + KERNEL_FLOAT_INLINE + explicit aligned_ptr(const T* ptr): base_type(ptr) {} + + KERNEL_FLOAT_INLINE + aligned_ptr(aligned_ptr ptr): base_type(ptr.get()) {} + + KERNEL_FLOAT_INLINE + aligned_ptr(aligned_ptr ptr): base_type(ptr.get()) {} + }; + + + template + KERNEL_FLOAT_INLINE + T* operator+(aligned_ptr ptr, ptrdiff_t index) { + return ptr.get() + index; + } + + template + KERNEL_FLOAT_INLINE + T* operator+(ptrdiff_t index, aligned_ptr ptr) { + return ptr.get() + index; + } + + template + KERNEL_FLOAT_INLINE + ptrdiff_t operator-(aligned_ptr left, aligned_ptr right) { + return left.get() - right.get(); + } + + template + using unaligned_ptr = aligned_ptr; + +} +*/ + +#endif //KERNEL_FLOAT_MEMORY_H diff --git a/include/kernel_float/meta.h b/include/kernel_float/meta.h index e8bba21..9c133a3 100644 --- a/include/kernel_float/meta.h +++ b/include/kernel_float/meta.h @@ -5,15 +5,6 @@ namespace kernel_float { -template -struct const_index { - static constexpr size_t value = I; - - KERNEL_FLOAT_INLINE constexpr operator size_t() const noexcept { - return I; - } -}; - template struct index_sequence { static constexpr size_t size = sizeof...(Is); @@ -21,13 +12,13 @@ struct index_sequence { namespace detail { template -struct make_index_sequence_helper {}; +struct make_index_sequence_impl {}; // Benchmarks show that it is much faster to predefine all possible index sequences instead of doing something // recursive with variadic templates. #define KERNEL_FLOAT_INDEX_SEQ(N, ...) \ template<> \ - struct make_index_sequence_helper { \ + struct make_index_sequence_impl { \ using type = index_sequence<__VA_ARGS__>; \ }; @@ -53,169 +44,207 @@ KERNEL_FLOAT_INDEX_SEQ(17, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, } // namespace detail template -using make_index_sequence = typename detail::make_index_sequence_helper::type; +using make_index_sequence = typename detail::make_index_sequence_impl::type; namespace detail { template -struct decay_helper { +struct decay_impl { using type = T; }; template -struct decay_helper { +struct decay_impl { using type = T; }; template -struct decay_helper { +struct decay_impl { using type = T; }; template -struct decay_helper { +struct decay_impl { using type = T; }; template -struct decay_helper { +struct decay_impl { using type = T; }; } // namespace detail template -using decay_t = typename detail::decay_helper::type; +using decay_t = typename detail::decay_impl::type; + +template +struct promote_type; + +template +struct promote_type { + using type = T; +}; -template -struct common_type; +template +struct promote_type { + using type = T; +}; template -struct common_type { +struct promote_type { using type = T; }; -#define KERNEL_FLOAT_DEFINE_COMMON_TYPE(T, U) \ - template<> \ - struct common_type { \ - using type = T; \ - }; \ - template<> \ - struct common_type { \ - using type = T; \ +template<> +struct promote_type { + using type = void; +}; + +#define KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(T, U) \ + template<> \ + struct promote_type { \ + using type = T; \ + }; \ + template<> \ + struct promote_type { \ + using type = T; \ + }; + +// T + bool becomes T +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(char, bool) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(signed char, bool) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(signed short, bool) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(signed int, bool) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(signed long, bool) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(signed long long, bool) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(unsigned char, bool) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(unsigned short, bool) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(unsigned int, bool) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(unsigned long, bool) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(unsigned long long, bool) + +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, float) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(long double, float) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(long double, double) + +#define KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(T) \ + KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(T, char) \ + KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(T, signed char) \ + KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(T, signed short) \ + KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(T, signed int) \ + KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(T, signed long) \ + KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(T, signed long long) \ + KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(T, unsigned char) \ + KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(T, unsigned short) \ + KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(T, unsigned int) \ + KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(T, unsigned long) \ + KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(T, unsigned long long) \ + KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(T, bool) + +KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(float) +KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(double) +KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(long double) + +#define KERNEL_FLOAT_DEFINE_PROMOTED_INTEGRAL(T, U) \ + KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(signed T, signed U) \ + KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(unsigned T, unsigned U) + +KERNEL_FLOAT_DEFINE_PROMOTED_INTEGRAL(short, char) +KERNEL_FLOAT_DEFINE_PROMOTED_INTEGRAL(int, char) +KERNEL_FLOAT_DEFINE_PROMOTED_INTEGRAL(int, short) +KERNEL_FLOAT_DEFINE_PROMOTED_INTEGRAL(long, char) +KERNEL_FLOAT_DEFINE_PROMOTED_INTEGRAL(long, short) +KERNEL_FLOAT_DEFINE_PROMOTED_INTEGRAL(long, int) +KERNEL_FLOAT_DEFINE_PROMOTED_INTEGRAL(long long, char) +KERNEL_FLOAT_DEFINE_PROMOTED_INTEGRAL(long long, short) +KERNEL_FLOAT_DEFINE_PROMOTED_INTEGRAL(long long, int) +KERNEL_FLOAT_DEFINE_PROMOTED_INTEGRAL(long long, long) + +template +struct promote_type { + using type = T*; +}; + +#define KERNEL_FLOAT_DEFINE_PROMOTED_POINTER(I) \ + template \ + struct promote_type { \ + using type = T*; \ + }; \ + template \ + struct promote_type { \ + using type = T*; \ }; -KERNEL_FLOAT_DEFINE_COMMON_TYPE(long double, double) -KERNEL_FLOAT_DEFINE_COMMON_TYPE(long double, float) -KERNEL_FLOAT_DEFINE_COMMON_TYPE(double, float) -//KERNEL_FLOAT_DEFINE_COMMON_TYPE(double, half) -//KERNEL_FLOAT_DEFINE_COMMON_TYPE(float, half) - -#define KERNEL_FLOAT_DEFINE_COMMON_INTEGRAL(T, U) \ - KERNEL_FLOAT_DEFINE_COMMON_TYPE(signed T, signed U) \ - KERNEL_FLOAT_DEFINE_COMMON_TYPE(unsigned T, unsigned U) - -KERNEL_FLOAT_DEFINE_COMMON_INTEGRAL(long long, long) -KERNEL_FLOAT_DEFINE_COMMON_INTEGRAL(long long, int) -KERNEL_FLOAT_DEFINE_COMMON_INTEGRAL(long long, short) -KERNEL_FLOAT_DEFINE_COMMON_INTEGRAL(long long, char) -KERNEL_FLOAT_DEFINE_COMMON_INTEGRAL(long, int) -KERNEL_FLOAT_DEFINE_COMMON_INTEGRAL(long, short) -KERNEL_FLOAT_DEFINE_COMMON_INTEGRAL(long, char) -KERNEL_FLOAT_DEFINE_COMMON_INTEGRAL(int, short) -KERNEL_FLOAT_DEFINE_COMMON_INTEGRAL(int, char) -KERNEL_FLOAT_DEFINE_COMMON_INTEGRAL(short, char) - -KERNEL_FLOAT_DEFINE_COMMON_TYPE(long double, bool) -KERNEL_FLOAT_DEFINE_COMMON_TYPE(double, bool) -KERNEL_FLOAT_DEFINE_COMMON_TYPE(float, bool) - -KERNEL_FLOAT_DEFINE_COMMON_TYPE(signed long long, bool) -KERNEL_FLOAT_DEFINE_COMMON_TYPE(signed long, bool) -KERNEL_FLOAT_DEFINE_COMMON_TYPE(signed int, bool) -KERNEL_FLOAT_DEFINE_COMMON_TYPE(signed short, bool) -KERNEL_FLOAT_DEFINE_COMMON_TYPE(signed char, bool) - -KERNEL_FLOAT_DEFINE_COMMON_TYPE(unsigned long long, bool) -KERNEL_FLOAT_DEFINE_COMMON_TYPE(unsigned long, bool) -KERNEL_FLOAT_DEFINE_COMMON_TYPE(unsigned int, bool) -KERNEL_FLOAT_DEFINE_COMMON_TYPE(unsigned short, bool) -KERNEL_FLOAT_DEFINE_COMMON_TYPE(unsigned char, bool) +KERNEL_FLOAT_DEFINE_PROMOTED_POINTER(char) +KERNEL_FLOAT_DEFINE_PROMOTED_POINTER(signed char) +KERNEL_FLOAT_DEFINE_PROMOTED_POINTER(signed short) +KERNEL_FLOAT_DEFINE_PROMOTED_POINTER(signed int) +KERNEL_FLOAT_DEFINE_PROMOTED_POINTER(signed long) +KERNEL_FLOAT_DEFINE_PROMOTED_POINTER(signed long long) +KERNEL_FLOAT_DEFINE_PROMOTED_POINTER(unsigned char) +KERNEL_FLOAT_DEFINE_PROMOTED_POINTER(unsigned short) +KERNEL_FLOAT_DEFINE_PROMOTED_POINTER(unsigned int) +KERNEL_FLOAT_DEFINE_PROMOTED_POINTER(unsigned long) +KERNEL_FLOAT_DEFINE_PROMOTED_POINTER(unsigned long long) + +// half precision +// KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(half) +// KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(half, bool) +// KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, half) +// KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, half) +// KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(long double, half) namespace detail { template -struct common_type_helper; +struct multi_promote_type; template -struct common_type_helper { +struct multi_promote_type { using type = T; }; -template -struct common_type_helper { - using type = typename common_type::type; -}; +template +struct multi_promote_type: promote_type {}; + +template +struct multi_promote_type: + multi_promote_type::type, C, Rest...> {}; -template -struct common_type_helper: - common_type_helper::type, R, Rest...> {}; } // namespace detail template -using common_t = typename detail::common_type_helper...>::type; +using promote_t = typename detail::multi_promote_type...>::type; namespace detail { -template -struct common_size_helper; - -template<> -struct common_size_helper<> { - static constexpr size_t value = 1; -}; -template -struct common_size_helper { - static constexpr size_t value = N; -}; - -template -struct common_size_helper { - static constexpr size_t value = N; -}; - -template -struct common_size_helper { - static constexpr size_t value = N; -}; - -template -struct common_size_helper<1, N> { - static constexpr size_t value = N; +template +struct is_same_type_impl { + static constexpr bool value = false; }; -template<> -struct common_size_helper<1, 1> { - static constexpr size_t value = 1; +template +struct is_same_type_impl { + static constexpr bool value = true; }; } // namespace detail -template -static constexpr size_t common_size = detail::common_size_helper::value; +template +static constexpr bool is_same_type = detail::is_same_type_impl::value; namespace detail { - template -struct is_implicit_convertible_helper { +struct is_implicit_convertible_impl { static constexpr bool value = false; }; template -struct is_implicit_convertible_helper::type> { +struct is_implicit_convertible_impl::type> { static constexpr bool value = true; }; } // namespace detail template static constexpr bool is_implicit_convertible = - detail::is_implicit_convertible_helper, decay_t>::value; + detail::is_implicit_convertible_impl, decay_t>::value; namespace detail { template @@ -230,17 +259,17 @@ using result_t = decltype((detail::declval())(detail::declval()...)); namespace detail { template -struct enabled_helper {}; +struct enable_if_impl {}; template -struct enabled_helper { +struct enable_if_impl { using type = T; }; } // namespace detail template -using enabled_t = typename detail::enabled_helper::type; +using enable_if_t = typename detail::enable_if_impl::type; } // namespace kernel_float -#endif \ No newline at end of file +#endif diff --git a/include/kernel_float/prelude.h b/include/kernel_float/prelude.h new file mode 100644 index 0000000..2bc06a2 --- /dev/null +++ b/include/kernel_float/prelude.h @@ -0,0 +1,114 @@ +#ifndef KERNEL_FLOAT_PRELUDE_H +#define KERNEL_FLOAT_PRELUDE_H + +#include "bf16.h" +#include "constant.h" +#include "fp16.h" +#include "vector.h" + +namespace kernel_float { +namespace prelude { +namespace kf = ::kernel_float; + +template +using kscalar = vector>; + +template +using kvec = vector>; + +// clang-format off +template using kvec1 = kvec; +template using kvec2 = kvec; +template using kvec3 = kvec; +template using kvec4 = kvec; +template using kvec5 = kvec; +template using kvec6 = kvec; +template using kvec7 = kvec; +template using kvec8 = kvec; +// clang-format on + +#define KERNEL_FLOAT_TYPE_ALIAS(NAME, T) \ + template \ + using k##NAME = vector>; \ + using k##NAME##1 = vec; \ + using k##NAME##2 = vec; \ + using k##NAME##3 = vec; \ + using k##NAME##4 = vec; \ + using k##NAME##5 = vec; \ + using k##NAME##6 = vec; \ + using k##NAME##7 = vec; \ + using k##NAME##8 = vec; + +KERNEL_FLOAT_TYPE_ALIAS(char, char) +KERNEL_FLOAT_TYPE_ALIAS(short, short) +KERNEL_FLOAT_TYPE_ALIAS(int, int) +KERNEL_FLOAT_TYPE_ALIAS(long, long) +KERNEL_FLOAT_TYPE_ALIAS(longlong, long long) + +KERNEL_FLOAT_TYPE_ALIAS(uchar, unsigned char) +KERNEL_FLOAT_TYPE_ALIAS(ushort, unsigned short) +KERNEL_FLOAT_TYPE_ALIAS(uint, unsigned int) +KERNEL_FLOAT_TYPE_ALIAS(ulong, unsigned long) +KERNEL_FLOAT_TYPE_ALIAS(ulonglong, unsigned long long) + +KERNEL_FLOAT_TYPE_ALIAS(float, float) +KERNEL_FLOAT_TYPE_ALIAS(f32x, float) +KERNEL_FLOAT_TYPE_ALIAS(float32x, float) + +KERNEL_FLOAT_TYPE_ALIAS(double, double) +KERNEL_FLOAT_TYPE_ALIAS(f64x, double) +KERNEL_FLOAT_TYPE_ALIAS(float64x, double) + +#if KERNEL_FLOAT_FP16_AVAILABLE +KERNEL_FLOAT_TYPE_ALIAS(half, __half) +KERNEL_FLOAT_TYPE_ALIAS(f16x, __half) +KERNEL_FLOAT_TYPE_ALIAS(float16x, __half) +#endif + +#if KERNEL_FLOAT_BF16_AVAILABLE +KERNEL_FLOAT_TYPE_ALIAS(bfloat16, __nv_bfloat16) +KERNEL_FLOAT_TYPE_ALIAS(bf16, __nv_bfloat16) +#endif + +template +static constexpr extent kextent = {}; + +template +KERNEL_FLOAT_INLINE kvec, sizeof...(Args)> make_kvec(Args&&... args) { + return make_vec(std::forward(args)...); +}; + +template +using kconstant = constant; + +template +KERNEL_FLOAT_INLINE constexpr kconstant kconst(T value) { + return value; +} + +KERNEL_FLOAT_INLINE +static constexpr kconstant operator""_c(long double v) { + return static_cast(v); +} + +KERNEL_FLOAT_INLINE +static constexpr kconstant operator""_c(unsigned long long int v) { + return static_cast(v); +} + +// Deduction guides for aliases are only supported from C++20 +#if defined(__cpp_deduction_guides) && __cpp_deduction_guides >= 201907L +template +kscalar(T&&) -> kscalar>; + +template +kvec(Args&&...) -> kvec, sizeof...(Args)>; + +template +kconstant(T&&) -> kconstant>; +#endif + +} // namespace prelude +} // namespace kernel_float + +#endif diff --git a/include/kernel_float/reduce.h b/include/kernel_float/reduce.h index f3bc520..dfa52c3 100644 --- a/include/kernel_float/reduce.h +++ b/include/kernel_float/reduce.h @@ -5,29 +5,21 @@ namespace kernel_float { namespace detail { -template -struct reduce_helper { - using value_type = vector_value_type; - - KERNEL_FLOAT_INLINE static value_type call(F fun, const V& input) { - return call(fun, input, make_index_sequence> {}); +template +struct reduce_impl { + KERNEL_FLOAT_INLINE static T call(F fun, const T* input) { + return call(fun, input, make_index_sequence {}); } private: template - KERNEL_FLOAT_INLINE static value_type call(F fun, const V& vector, index_sequence<0, Is...>) { - return call(fun, vector, vector_get<0>(vector), index_sequence {}); - } - - template - KERNEL_FLOAT_INLINE static value_type - call(F fun, const V& vector, value_type accum, index_sequence) { - return call(fun, vector, fun(accum, vector_get(vector)), index_sequence {}); - } - - KERNEL_FLOAT_INLINE static value_type - call(F fun, const V& vector, value_type accum, index_sequence<>) { - return accum; + KERNEL_FLOAT_INLINE static T call(F fun, const T* input, index_sequence<0, Is...>) { + T result = input[0]; +#pragma unroll + for (size_t i = 1; i < N; i++) { + result = fun(result, input[i]); + } + return result; } }; } // namespace detail @@ -36,7 +28,7 @@ struct reduce_helper { * Reduce the elements of the given vector ``input`` into a single value using * the function ``fun``. This function should be a binary function that takes * two elements and returns one element. The order in which the elements - * are reduced is not specified and depends on the reduction function and + * are reduced is not specified and depends on both the reduction function and * the vector type. * * Example @@ -48,7 +40,9 @@ struct reduce_helper { */ template KERNEL_FLOAT_INLINE vector_value_type reduce(F fun, const V& input) { - return detail::reduce_helper>::call(fun, into_storage(input)); + return detail::reduce_impl, vector_value_type>::call( + fun, + into_vector_storage(input).data()); } /** @@ -57,7 +51,7 @@ KERNEL_FLOAT_INLINE vector_value_type reduce(F fun, const V& input) { * Example * ======= * ``` - * vec x = {5, 0, 2, 1, 0}; + * vec x = {5, 0, 2, 1, 0}; * int y = min(x); // Returns 0 * ``` */ @@ -72,7 +66,7 @@ KERNEL_FLOAT_INLINE T min(const V& input) { * Example * ======= * ``` - * vec x = {5, 0, 2, 1, 0}; + * vec x = {5, 0, 2, 1, 0}; * int y = max(x); // Returns 5 * ``` */ @@ -87,7 +81,7 @@ KERNEL_FLOAT_INLINE T max(const V& input) { * Example * ======= * ``` - * vec x = {5, 0, 2, 1, 0}; + * vec x = {5, 0, 2, 1, 0}; * int y = sum(x); // Returns 8 * ``` */ @@ -116,7 +110,7 @@ KERNEL_FLOAT_INLINE T product(const V& input) { * non-zero if ``bool(v)==true``. */ template -KERNEL_FLOAT_INLINE bool all(V&& input) { +KERNEL_FLOAT_INLINE bool all(const V& input) { return reduce(ops::bit_and {}, cast(input)); } @@ -125,7 +119,7 @@ KERNEL_FLOAT_INLINE bool all(V&& input) { * non-zero if ``bool(v)==true``. */ template -KERNEL_FLOAT_INLINE bool any(V&& input) { +KERNEL_FLOAT_INLINE bool any(const V& input) { return reduce(ops::bit_or {}, cast(input)); } @@ -140,9 +134,113 @@ KERNEL_FLOAT_INLINE bool any(V&& input) { * int y = count(x); // Returns 3 (5, 2, 1 are non-zero) * ``` */ -template -KERNEL_FLOAT_INLINE int count(V&& input) { - return sum(cast(cast(input))); +template +KERNEL_FLOAT_INLINE T count(const V& input) { + return sum(cast(cast(input))); +} + +namespace detail { +template +struct dot_impl { + KERNEL_FLOAT_INLINE + static T call(const T* left, const T* right) { + vector_storage intermediate; + detail::apply_impl, N, T, T, T>::call( + ops::multiply(), + intermediate.data(), + left, + right); + + return detail::reduce_impl, N, T>::call(ops::add(), intermediate.data()); + } +}; +} // namespace detail + +/** + * Compute the dot product of the given vectors ``left`` and ``right`` + * + * Example + * ======= + * ``` + * vec x = {1, 2, 3}; + * vec y = {4, 5, 6}; + * int y = dot(x, y); // Returns 1*4+2*5+3*6 = 32 + * ``` + */ +template> +KERNEL_FLOAT_INLINE T dot(const L& left, const R& right) { + using E = broadcast_vector_extent_type; + return detail::dot_impl::call( + convert_storage(left, E {}).data(), + convert_storage(right, E {}).data()); +} + +namespace detail { +template +struct magnitude_impl { + KERNEL_FLOAT_INLINE + static T call(const T* input) { + return ops::sqrt {}(detail::dot_impl::call(input, input)); + } +}; + +template +struct magnitude_impl { + KERNEL_FLOAT_INLINE + static T call(const T* input) { + return T {}; + } +}; + +template +struct magnitude_impl { + KERNEL_FLOAT_INLINE + static T call(const T* input) { + return ops::abs {}(input[0]); + } +}; + +template +struct magnitude_impl { + KERNEL_FLOAT_INLINE + static T call(const T* input) { + return ops::hypot()(input[0], input[1]); + } +}; + +// The 3-argument overload of hypot is only available on host from C++17 +#if defined(__cpp_lib_hypot) && KERNEL_FLOAT_IS_HOST +template<> +struct magnitude_impl { + static float call(const float* input) { + return ::hypot(input[0], input[1], input[2]); + } +}; + +template<> +struct magnitude_impl { + static double call(const double* input) { + return ::hypot(input[0], input[1], input[2]); + } +}; +#endif + +} // namespace detail + +/** + * Compute the magnitude of the given input vector. This calculates the square root of the sum of squares, also + * known as the Euclidian norm, of a vector. + * + * Example + * ======= + * ``` + * vec x = {2, 3, 6}; + * float y = mag(x); // Returns sqrt(2*2 + 3*3 + 6*6) = 7 + * ``` + */ +template> +KERNEL_FLOAT_INLINE T mag(const V& input) { + return detail::magnitude_impl>::call(into_vector_storage(input).data()); } } // namespace kernel_float diff --git a/include/kernel_float/storage.h b/include/kernel_float/storage.h deleted file mode 100644 index 4d28339..0000000 --- a/include/kernel_float/storage.h +++ /dev/null @@ -1,503 +0,0 @@ -#ifndef KERNEL_FLOAT_STORAGE -#define KERNEL_FLOAT_STORAGE - -#include "meta.h" - -namespace kernel_float { - -template -struct vector_traits { - using value_type = V; - static constexpr size_t size = 1; - - KERNEL_FLOAT_INLINE - static V fill(value_type value) { - return value; - } - - KERNEL_FLOAT_INLINE - static V create(value_type value) { - return value; - } - - KERNEL_FLOAT_INLINE - static value_type get(const V& self, size_t index) { - KERNEL_FLOAT_ASSERT(index == 0); - return self; - } - - KERNEL_FLOAT_INLINE - static void set(V& self, size_t index, value_type value) { - KERNEL_FLOAT_ASSERT(index == 0); - self = value; - } -}; - -template -struct into_storage_traits { - using type = V; - - KERNEL_FLOAT_INLINE - static constexpr type call(V self) { - return self; - } -}; - -template -struct into_storage_traits: into_storage_traits {}; - -template -struct into_storage_traits: into_storage_traits {}; - -template -struct into_storage_traits: into_storage_traits {}; - -template -using into_storage_type = typename into_storage_traits::type; - -template -KERNEL_FLOAT_INLINE into_storage_type into_storage(V&& input) { - return into_storage_traits::call(input); -} - -template -static constexpr size_t vector_size = vector_traits>::size; - -template -using vector_value_type = typename vector_traits>::value_type; - -template -struct vector_index { - using value_type = vector_value_type; - - KERNEL_FLOAT_INLINE - static value_type get(const V& self) { - return vector_traits::get(self, I); - } - - KERNEL_FLOAT_INLINE - static void set(V& self, value_type value) { - return vector_traits::set(self, I, value); - } -}; - -template -KERNEL_FLOAT_INLINE vector_value_type vector_get(const V& self, size_t index) { - return vector_traits::get(self, index); -} - -template -KERNEL_FLOAT_INLINE vector_value_type vector_get(const V& self, const_index = {}) { - return vector_index::get(self); -} - -template -struct vector_swizzle; - -template -struct vector_swizzle> { - KERNEL_FLOAT_INLINE static Output call(const Input& storage) { - return vector_traits::create(vector_get(storage)...); - } -}; - -template -struct vector; - -template -struct alignas(alignment) array { - T items_[N]; - - KERNEL_FLOAT_INLINE - T& operator[](size_t i) { - KERNEL_FLOAT_ASSERT(i < N); - return items_[i]; - } - - KERNEL_FLOAT_INLINE - const T& operator[](size_t i) const { - KERNEL_FLOAT_ASSERT(i < N); - return items_[i]; - } -}; - -template -struct vector_traits> { - using self_type = array; - using value_type = T; - static constexpr size_t size = N; - - template - KERNEL_FLOAT_INLINE static self_type create(Args&&... args) { - return {args...}; - } - - KERNEL_FLOAT_INLINE - static self_type fill(value_type value) { - self_type result; - for (size_t i = 0; i < N; i++) { - result[i] = value; - } - return result; - } - - KERNEL_FLOAT_INLINE - static value_type get(const self_type& self, size_t index) { - KERNEL_FLOAT_ASSERT(index < N); - return self[index]; - } - - KERNEL_FLOAT_INLINE - static void set(self_type& self, size_t index, value_type value) { - KERNEL_FLOAT_ASSERT(index < N); - self[index] = value; - } -}; - -template -struct array {}; - -template -struct vector_traits> { - using self_type = array; - using value_type = T; - static constexpr size_t size = 0; - - KERNEL_FLOAT_INLINE - static self_type create() { - return {}; - } - - KERNEL_FLOAT_INLINE - static self_type fill(value_type value) { - return {}; - } - - KERNEL_FLOAT_INLINE - static value_type get(const self_type& self, size_t index) { - KERNEL_FLOAT_UNREACHABLE; - } - - KERNEL_FLOAT_INLINE - static void set(self_type& self, size_t index, value_type value) { - KERNEL_FLOAT_UNREACHABLE; - } -}; - -enum struct Alignment { - Minimum, - Packed, - Maximum, -}; - -constexpr size_t calculate_alignment(Alignment required, size_t min_alignment, size_t total_size) { - size_t alignment = 1; - - if (required == Alignment::Maximum) { - if (total_size <= 1) { - alignment = 1; - } else if (total_size <= 2) { - alignment = 2; - } else if (total_size <= 4) { - alignment = 4; - } else if (total_size <= 8) { - alignment = 8; - } else { - alignment = 16; - } - } else if (required == Alignment::Packed) { - if (total_size % 16 == 0) { - alignment = 16; - } else if (total_size % 8 == 0) { - alignment = 8; - } else if (total_size % 4 == 0) { - alignment = 4; - } else if (total_size % 2 == 0) { - alignment = 2; - } else { - alignment = 1; - } - } - - if (min_alignment > alignment) { - alignment = min_alignment; - } - - return alignment; -} - -template -struct default_storage { - using type = array; -}; - -template -struct default_storage { - using type = T; -}; - -template -using default_storage_type = typename default_storage::type; - -#define KERNEL_FLOAT_DEFINE_VECTOR_TYPE(T, T1, T2, T3, T4) \ - template<> \ - struct vector_traits { \ - using value_type = T; \ - static constexpr size_t size = 1; \ - \ - KERNEL_FLOAT_INLINE \ - static T1 create(T x) { \ - return {x}; \ - } \ - \ - KERNEL_FLOAT_INLINE \ - static T1 fill(T v) { \ - return {v}; \ - } \ - \ - KERNEL_FLOAT_INLINE \ - static T get(const T1& self, size_t index) { \ - switch (index) { \ - case 0: \ - return self.x; \ - default: \ - KERNEL_FLOAT_UNREACHABLE; \ - } \ - } \ - \ - KERNEL_FLOAT_INLINE \ - static void set(T1& self, size_t index, T value) { \ - switch (index) { \ - case 0: \ - self.x = value; \ - default: \ - KERNEL_FLOAT_UNREACHABLE; \ - } \ - } \ - }; \ - \ - template<> \ - struct vector_traits { \ - using value_type = T; \ - static constexpr size_t size = 2; \ - \ - KERNEL_FLOAT_INLINE \ - static T2 create(T x, T y) { \ - return {x, y}; \ - } \ - \ - KERNEL_FLOAT_INLINE \ - static T2 fill(T v) { \ - return {v, v}; \ - } \ - \ - KERNEL_FLOAT_INLINE \ - static T get(const T2& self, size_t index) { \ - switch (index) { \ - case 0: \ - return self.x; \ - case 1: \ - return self.y; \ - default: \ - KERNEL_FLOAT_UNREACHABLE; \ - } \ - } \ - \ - KERNEL_FLOAT_INLINE \ - static void set(T2& self, size_t index, T value) { \ - switch (index) { \ - case 0: \ - self.x = value; \ - case 1: \ - self.y = value; \ - default: \ - KERNEL_FLOAT_UNREACHABLE; \ - } \ - } \ - }; \ - \ - template<> \ - struct vector_traits { \ - using value_type = T; \ - static constexpr size_t size = 3; \ - \ - KERNEL_FLOAT_INLINE \ - static T3 create(T x, T y, T z) { \ - return {x, y, z}; \ - } \ - \ - KERNEL_FLOAT_INLINE \ - static T3 fill(T v) { \ - return {v, v, v}; \ - } \ - \ - KERNEL_FLOAT_INLINE \ - static T get(const T3& self, size_t index) { \ - switch (index) { \ - case 0: \ - return self.x; \ - case 1: \ - return self.y; \ - case 2: \ - return self.z; \ - default: \ - KERNEL_FLOAT_UNREACHABLE; \ - } \ - } \ - \ - KERNEL_FLOAT_INLINE \ - static void set(T3& self, size_t index, T value) { \ - switch (index) { \ - case 0: \ - self.x = value; \ - return; \ - case 1: \ - self.y = value; \ - return; \ - case 2: \ - self.z = value; \ - return; \ - default: \ - KERNEL_FLOAT_UNREACHABLE; \ - } \ - } \ - }; \ - \ - template<> \ - struct vector_traits { \ - using value_type = T; \ - static constexpr size_t size = 4; \ - \ - KERNEL_FLOAT_INLINE \ - static T4 create(T x, T y, T z, T w) { \ - return {x, y, z, w}; \ - } \ - \ - KERNEL_FLOAT_INLINE \ - static T4 fill(T v) { \ - return {v, v, v, v}; \ - } \ - \ - KERNEL_FLOAT_INLINE \ - static T get(const T4& self, size_t index) { \ - switch (index) { \ - case 0: \ - return self.x; \ - case 1: \ - return self.y; \ - case 2: \ - return self.z; \ - case 3: \ - return self.w; \ - default: \ - KERNEL_FLOAT_UNREACHABLE; \ - } \ - } \ - \ - KERNEL_FLOAT_INLINE \ - static void set(T4& self, size_t index, T value) { \ - switch (index) { \ - case 0: \ - self.x = value; \ - return; \ - case 1: \ - self.y = value; \ - return; \ - case 2: \ - self.z = value; \ - return; \ - case 3: \ - self.w = value; \ - return; \ - default: \ - KERNEL_FLOAT_UNREACHABLE; \ - } \ - } \ - }; - -KERNEL_FLOAT_DEFINE_VECTOR_TYPE(char, char1, char2, char3, char4) -KERNEL_FLOAT_DEFINE_VECTOR_TYPE(short, short1, short2, short3, short4) -KERNEL_FLOAT_DEFINE_VECTOR_TYPE(int, int1, int2, int3, int4) -KERNEL_FLOAT_DEFINE_VECTOR_TYPE(long, long1, long2, long3, long4) -KERNEL_FLOAT_DEFINE_VECTOR_TYPE(long long, longlong1, longlong2, longlong3, longlong4) - -KERNEL_FLOAT_DEFINE_VECTOR_TYPE(unsigned char, uchar1, uchar2, uchar3, uchar4) -KERNEL_FLOAT_DEFINE_VECTOR_TYPE(unsigned short, ushort1, ushort2, ushort3, ushort4) -KERNEL_FLOAT_DEFINE_VECTOR_TYPE(unsigned int, uint1, uint2, uint3, uint4) -KERNEL_FLOAT_DEFINE_VECTOR_TYPE(unsigned long, ulong1, ulong2, ulong3, ulong4) -KERNEL_FLOAT_DEFINE_VECTOR_TYPE(unsigned long long, ulonglong1, ulonglong2, ulonglong3, ulonglong4) - -KERNEL_FLOAT_DEFINE_VECTOR_TYPE(float, float1, float2, float3, float4) -KERNEL_FLOAT_DEFINE_VECTOR_TYPE(double, double1, double2, double3, double4) - -template -struct nested_array { - static constexpr size_t num_packets = (N + vector_size - 1) / vector_size; - static_assert(num_packets * vector_size >= N, "internal error"); - - V packets[num_packets]; - - KERNEL_FLOAT_INLINE - V& operator[](size_t i) { - KERNEL_FLOAT_ASSERT(i < num_packets); - return packets[i]; - } - - KERNEL_FLOAT_INLINE - const V& operator[](size_t i) const { - KERNEL_FLOAT_ASSERT(i < num_packets); - return packets[i]; - } -}; - -template -struct vector_traits> { - using self_type = nested_array; - using value_type = vector_value_type; - static constexpr size_t size = N; - - template - KERNEL_FLOAT_INLINE static self_type create(Args&&... args) { - value_type items[N] = {args...}; - self_type output; - - size_t i = 0; - for (; i + vector_size - 1 < N; i += vector_size) { - // How to generalize this? - output.packets[i / vector_size] = vector_traits::create(items[i], items[i + 1]); - } - - for (; i < N; i++) { - vector_traits::set(output.packets[i / vector_size], i % vector_size, items[i]); - } - - return output; - } - - KERNEL_FLOAT_INLINE - static self_type fill(value_type value) { - self_type output; - - for (size_t i = 0; i < self_type::num_packets; i++) { - output.packets[i] = vector_traits::fill(value); - } - - return output; - } - - KERNEL_FLOAT_INLINE - static value_type get(const self_type& self, size_t index) { - KERNEL_FLOAT_ASSERT(index < N); - return vector_traits::get(self.packets[index / vector_size], index % vector_size); - } - - KERNEL_FLOAT_INLINE - static void set(self_type& self, size_t index, value_type value) { - KERNEL_FLOAT_ASSERT(index < N); - vector_traits::set(self.packets[index / vector_size], index % vector_size, value); - } -}; - -}; // namespace kernel_float - -#endif \ No newline at end of file diff --git a/include/kernel_float/swizzle.h b/include/kernel_float/swizzle.h deleted file mode 100644 index 50a023a..0000000 --- a/include/kernel_float/swizzle.h +++ /dev/null @@ -1,218 +0,0 @@ -#ifndef KERNEL_FLOAT_SWIZZLE_H -#define KERNEL_FLOAT_SWIZZLE_H - -#include "storage.h" - -namespace kernel_float { - -/** - * "Swizzles" the vector. Returns a new vector where the elements are provided by the given indices. - * - * # Example - * ``` - * vec x = {0, 1, 2, 3, 4, 5, 6}; - * vec a = swizzle<0, 1, 2>(x); // 0, 1, 2 - * vec b = swizzle<2, 1, 0>(x); // 2, 1, 0 - * vec c = swizzle<1, 1, 1>(x); // 1, 1, 1 - * vec d = swizzle<0, 2, 4, 6>(x); // 0, 2, 4, 6 - * ``` - */ -template< - size_t... Is, - typename V, - typename Output = default_storage_type, sizeof...(Is)>> -KERNEL_FLOAT_INLINE vector swizzle(const V& input, index_sequence _ = {}) { - return vector_swizzle, index_sequence>::call( - into_storage(input)); -} - -/** - * Takes the first ``N`` elements from the given vector and returns a new vector of length ``N``. - * - * # Example - * ``` - * vec x = {1, 2, 3, 4, 5, 6}; - * vec y = first<3>(x); // 1, 2, 3 - * int z = first(x); // 1 - * ``` - */ -template, K>> -KERNEL_FLOAT_INLINE vector first(const V& input) { - static_assert(K <= vector_size, "K cannot exceed vector size"); - using Indices = make_index_sequence; - return vector_swizzle, Indices>::call(into_storage(input)); -} - -namespace detail { -template -struct offset_index_sequence_helper; - -template -struct offset_index_sequence_helper> { - using type = index_sequence; -}; -} // namespace detail - -/** - * Takes the last ``N`` elements from the given vector and returns a new vector of length ``N``. - * - * # Example - * ``` - * vec x = {1, 2, 3, 4, 5, 6}; - * vec y = last<3>(x); // 4, 5, 6 - * int z = last(x); // 6 - * ``` - */ -template, K>> -KERNEL_FLOAT_INLINE vector last(const V& input) { - static_assert(K <= vector_size, "K cannot exceed vector size"); - using Indices = typename detail::offset_index_sequence_helper< // - vector_size - K, - make_index_sequence>::type; - - return vector_swizzle, Indices>::call(into_storage(input)); -} - -namespace detail { -template -struct reverse_index_sequence_helper: reverse_index_sequence_helper {}; - -template -struct reverse_index_sequence_helper<0, Is...> { - using type = index_sequence; -}; -} // namespace detail - -/** - * Reverses the elements in the given vector. - * - * # Example - * ``` - * vec x = {1, 2, 3, 4, 5, 6}; - * vec y = reversed(x); // 6, 5, 4, 3, 2, 1 - * ``` - */ -template> -KERNEL_FLOAT_INLINE vector reversed(const V& input) { - using Indices = typename detail::reverse_index_sequence_helper>::type; - - return vector_swizzle, Indices>::call(into_storage(input)); -} - -namespace detail { -template -struct concat_index_sequence_helper {}; - -template -struct concat_index_sequence_helper, index_sequence> { - using type = index_sequence; -}; -} // namespace detail - -/** - * Rotate the given vector ``K`` steps to the right. In other words, this move the front element to the back - * ``K`` times. This is the inverse of ``rotate_left``. - * - * # Example - * ``` - * vec x = {1, 2, 3, 4, 5, 6}; - * vec y = rotate_right<2>(x); // 5, 6, 1, 2, 3, 4 - * ``` - */ -template> -KERNEL_FLOAT_INLINE vector rotate_right(const V& input) { - static constexpr size_t N = vector_size; - static constexpr size_t I = (N > 0) ? (K % N) : 0; - - using First = - typename detail::offset_index_sequence_helper>::type; - using Second = make_index_sequence; - using Indices = typename detail::concat_index_sequence_helper::type; - - return vector_swizzle, Indices>::call(into_storage(input)); -} - -/** - * Rotate the given vector ``K`` steps to the left. In other words, this move the back element to the front - * ``K`` times. This is the inverse of ``rotate_right``. - * - * # Example - * ``` - * vec x = {1, 2, 3, 4, 5, 6}; - * vec y = rotate_left<4>(x); // 5, 6, 1, 2, 3, 4 - * ``` - */ -template> -KERNEL_FLOAT_INLINE vector rotate_left(const V& input) { - static constexpr size_t N = vector_size; - static constexpr size_t K_rev = N > 0 ? (N - K % N) : 0; - - return rotate_right(input); -} - -namespace detail { -template< - typename U, - typename V, - typename Is = make_index_sequence>, - typename Js = make_index_sequence>> -struct concat_helper; - -template -struct concat_helper, index_sequence> { - using type = default_storage_type< - common_t, vector_value_type>, - vector_size + vector_size>; - - KERNEL_FLOAT_INLINE static type call(const U& left, const V& right) { - return vector_traits::create(vector_get(left)..., vector_get(right)...); - } -}; - -template -struct recur_concat_helper; - -template -struct recur_concat_helper { - using type = U; - - KERNEL_FLOAT_INLINE static U call(U&& input) { - return input; - } -}; - -template -struct recur_concat_helper { - using recur_helper = recur_concat_helper::type, Rest...>; - using type = typename recur_helper::type; - - KERNEL_FLOAT_INLINE static type call(const U& left, const V& right, const Rest&... rest) { - return recur_helper::call(concat_helper::call(left, right), rest...); - } -}; -} // namespace detail - -template -using concat_type = typename detail::recur_concat_helper...>::type; - -/** - * Concatenate the given vectors into one large vector. For example, given vectors of size 3, size 2 and size 5, - * this function returns a new vector of size 3+2+5=8. If the vectors are not of the same element type, they - * will first be cast into a common data type. - * - * # Examples - * ``` - * vec x = {1, 2, 3}; - * int y = 4; - * vec z = {5, 6, 7, 8}; - * vec xyz = concat(x, y, z); // 1, 2, 3, 4, 5, 6, 7, 8 - * ``` - */ -template -KERNEL_FLOAT_INLINE vector> concat(const Vs&... inputs) { - return detail::recur_concat_helper...>::call(into_storage(inputs)...); -} - -} // namespace kernel_float - -#endif //KERNEL_FLOAT_SWIZZLE_H diff --git a/include/kernel_float/triops.h b/include/kernel_float/triops.h new file mode 100644 index 0000000..44f6db2 --- /dev/null +++ b/include/kernel_float/triops.h @@ -0,0 +1,147 @@ +#ifndef KERNEL_FLOAT_TRIOPS_H +#define KERNEL_FLOAT_TRIOPS_H + +#include "conversion.h" +#include "unops.h" + +namespace kernel_float { + +namespace ops { +template +struct conditional { + KERNEL_FLOAT_INLINE T operator()(bool cond, T true_value, T false_value) { + if (cond) { + return true_value; + } else { + return false_value; + } + } +}; +} // namespace ops + +/** + * Return elements chosen from `true_values` and `false_values` depending on `cond`. + * + * This function broadcasts all arguments to the same size and then promotes the values of `true_values` and + * `false_values` into the same type. Next, it casts the values of `cond` to booleans and returns a vector where + * the values are taken from `true_values` where the condition is true and `false_values` otherwise. + * + * @param cond The condition used for selection. + * @param true_values The vector of values to choose from when the condition is true. + * @param false_values The vector of values to choose from when the condition is false. + * @return A vector containing selected elements as per the condition. + */ +template< + typename C, + typename L, + typename R, + typename T = promoted_vector_value_type, + typename E = broadcast_vector_extent_type> +KERNEL_FLOAT_INLINE vector where(const C& cond, const L& true_values, const R& false_values) { + using F = ops::conditional; + vector_storage result; + + detail::apply_impl::call( + F {}, + result.data(), + detail::convert_impl, vector_extent_type, bool, E>::call( + into_vector_storage(cond)) + .data(), + detail::convert_impl, vector_extent_type, T, E>::call( + into_vector_storage(true_values)) + .data(), + detail::convert_impl, vector_extent_type, T, E>::call( + into_vector_storage(false_values)) + .data()); + + return result; +} + +/** + * Selects elements from `true_values` depending on `cond`. + * + * This function returns a vector where the values are taken from `true_values` where `cond` is `true` and `0` where + * `cond is `false`. + * + * @param cond The condition used for selection. + * @param true_values The vector of values to choose from when the condition is true. + * @return A vector containing selected elements as per the condition. + */ +template< + typename C, + typename L, + typename T = vector_value_type, + typename E = broadcast_vector_extent_type> +KERNEL_FLOAT_INLINE vector where(const C& cond, const L& true_values) { + vector> false_values = T {}; + return where(cond, true_values, false_values); +} + +/** + * Returns a vector having the value `T(1)` where `cond` is `true` and `T(0)` where `cond` is `false`. + * + * @param cond The condition used for selection. + * @return A vector containing elements as per the condition. + */ +template> +KERNEL_FLOAT_INLINE vector where(const C& cond) { + return cast(cast(cond)); +} + +namespace ops { +template +struct fma { + KERNEL_FLOAT_INLINE T operator()(T a, T b, T c) { + return a * b + c; + } +}; + +#if KERNEL_FLOAT_IS_DEVICE +template<> +struct fma { + KERNEL_FLOAT_INLINE float operator()(float a, float b, float c) { + return __fmaf_rn(a, b, c); + } +}; + +template<> +struct fma { + KERNEL_FLOAT_INLINE double operator()(double a, double b, double c) { + return __fma_rn(a, b, c); + } +}; +#endif +} // namespace ops + +/** + * Computes the result of `a * b + c`. This is done in a single operation if possible. + */ +template< + typename A, + typename B, + typename C, + typename T = promoted_vector_value_type, + typename E = broadcast_vector_extent_type> +KERNEL_FLOAT_INLINE vector fma(const A& a, const B& b, const C& c) { + using F = ops::fma; + vector_storage result; + + detail::apply_impl::call( + F {}, + result.data(), + detail::convert_impl, vector_extent_type, T, E>::call( + into_vector_storage(a)) + .data(), + detail::convert_impl, vector_extent_type, T, E>::call( + into_vector_storage(b)) + .data(), + detail::convert_impl, vector_extent_type, T, E>::call( + into_vector_storage(c)) + .data()); + + return result; +} + +} // namespace kernel_float + +#endif //KERNEL_FLOAT_TRIOPS_H diff --git a/include/kernel_float/unops.h b/include/kernel_float/unops.h index 6f1b4fd..bca3796 100644 --- a/include/kernel_float/unops.h +++ b/include/kernel_float/unops.h @@ -1,77 +1,69 @@ #ifndef KERNEL_FLOAT_UNOPS_H #define KERNEL_FLOAT_UNOPS_H -#include "cast.h" -#include "storage.h" +#include "base.h" namespace kernel_float { namespace detail { -template -struct map_helper { - KERNEL_FLOAT_INLINE static Output call(F fun, const Input& input) { - return call(fun, input, make_index_sequence> {}); - } - - private: - template - KERNEL_FLOAT_INLINE static Output call(F fun, const Input& input, index_sequence) { - return vector_traits::create(fun(vector_get(input))...); - } -}; - -template -struct map_helper, nested_array> { - KERNEL_FLOAT_INLINE static nested_array call(F fun, const nested_array& input) { - return call(fun, input, make_index_sequence::num_packets> {}); - } - private: - template - KERNEL_FLOAT_INLINE static nested_array - call(F fun, const nested_array& input, index_sequence) { - return {map_helper::call(fun, input[Is])...}; +template +struct apply_impl { + KERNEL_FLOAT_INLINE static void call(F fun, Output* result, const Args*... inputs) { +#pragma unroll + for (size_t i = 0; i < N; i++) { + result[i] = fun(inputs[i]...); + } } }; } // namespace detail -template -using map_type = default_storage_type>, vector_size>; +template +using map_type = vector>, vector_extent_type>; /** - * Applies ``fun`` to each element from vector ``input`` and returns a new vector with the results. - * This function is the basis for all unary operators like ``sin`` and ``sqrt``. + * Apply the function `F` to each element from the vector `input` and return the results as a new vector. * - * Example - * ======= + * Examples + * ======== * ``` - * vector v = {1, 2, 3}; - * vector w = map([](auto i) { return i * 2; }); // 2, 4, 6 + * vec input = {1.0f, 2.0f, 3.0f, 4.0f}; + * vec squared = map([](auto x) { return x * x; }, input); // [1.0f, 4.0f, 9.0f, 16.0f] * ``` */ -template> -KERNEL_FLOAT_INLINE Output map(F fun, const Input& input) { - return detail::map_helper>::call(fun, into_storage(input)); +template +KERNEL_FLOAT_INLINE map_type map(F fun, const V& input) { + using Input = vector_value_type; + using Output = result_t; + vector_storage> result; + + detail::apply_impl, Output, Input>::call( + fun, + result.data(), + into_vector_storage(input).data()); + + return result; } -#define KERNEL_FLOAT_DEFINE_UNARY(NAME, EXPR) \ - namespace ops { \ - template \ - struct NAME { \ - KERNEL_FLOAT_INLINE T operator()(T input) { \ - return T(EXPR); \ - } \ - }; \ - } \ - template \ - KERNEL_FLOAT_INLINE vector> NAME(const V& input) { \ - return map>, V, into_storage_type>({}, input); \ +#define KERNEL_FLOAT_DEFINE_UNARY(NAME, EXPR) \ + namespace ops { \ + template \ + struct NAME { \ + KERNEL_FLOAT_INLINE T operator()(T input) { \ + return T(EXPR); \ + } \ + }; \ + } \ + template \ + KERNEL_FLOAT_INLINE vector, vector_extent_type> NAME(const V& input) { \ + using F = ops::NAME>; \ + return map(F {}, input); \ } -#define KERNEL_FLOAT_DEFINE_UNARY_OP(NAME, OP, EXPR) \ - KERNEL_FLOAT_DEFINE_UNARY(NAME, EXPR) \ - template \ - KERNEL_FLOAT_INLINE vector operator OP(const vector& vec) { \ - return NAME(vec); \ +#define KERNEL_FLOAT_DEFINE_UNARY_OP(NAME, OP, EXPR) \ + KERNEL_FLOAT_DEFINE_UNARY(NAME, EXPR) \ + template \ + KERNEL_FLOAT_INLINE vector operator OP(const vector& vec) { \ + return NAME(vec); \ } KERNEL_FLOAT_DEFINE_UNARY_OP(negate, -, -input) @@ -128,6 +120,28 @@ KERNEL_FLOAT_DEFINE_UNARY_FUN(signbit) KERNEL_FLOAT_DEFINE_UNARY_FUN(isinf) KERNEL_FLOAT_DEFINE_UNARY_FUN(isnan) +#if KERNEL_FLOAT_IS_DEVICE +#define KERNEL_FLOAT_DEFINE_UNARY_FAST(FUN_NAME, OP_NAME, FLOAT_FUN) \ + KERNEL_FLOAT_DEFINE_UNARY(FUN_NAME, ops::OP_NAME {}(input)) \ + namespace ops { \ + template<> \ + struct OP_NAME { \ + KERNEL_FLOAT_INLINE float operator()(float input) { \ + return FLOAT_FUN(input); \ + } \ + }; \ + } +#else +#define KERNEL_FLOAT_DEFINE_UNARY_FAST(FUN_NAME, OP_NAME, FLOAT_FUN) \ + KERNEL_FLOAT_DEFINE_UNARY(FUN_NAME, ops::OP_NAME {}(input)) +#endif + +KERNEL_FLOAT_DEFINE_UNARY_FAST(fast_exp, exp, __expf) +KERNEL_FLOAT_DEFINE_UNARY_FAST(fast_log, log, __logf) +KERNEL_FLOAT_DEFINE_UNARY_FAST(fast_cos, cos, __cosf) +KERNEL_FLOAT_DEFINE_UNARY_FAST(fast_sin, sin, __sinf) +KERNEL_FLOAT_DEFINE_UNARY_FAST(fast_tan, tan, __tanf) + } // namespace kernel_float #endif //KERNEL_FLOAT_UNOPS_H diff --git a/include/kernel_float/vector.h b/include/kernel_float/vector.h new file mode 100644 index 0000000..642541b --- /dev/null +++ b/include/kernel_float/vector.h @@ -0,0 +1,336 @@ +#ifndef KERNEL_FLOAT_VECTOR_H +#define KERNEL_FLOAT_VECTOR_H + +#include "base.h" +#include "conversion.h" +#include "iterate.h" +#include "macros.h" +#include "reduce.h" +#include "unops.h" + +namespace kernel_float { + +/** + * Container that stores ``N`` elements of type ``T``. + * + * It is not recommended to use this class directly, but instead, use the type `vec` which is an alias for + * `vector, vector_storage>`. + * + * @tparam T The type of the values stored within the vector. + * @tparam E The size of this vector. Should be of type `extent`. + * @tparam S The object's storage class. Should be the type `vector_storage` + */ +template +struct vector: public S { + using value_type = T; + using extent_type = E; + using storage_type = S; + + // Copy another `vector` + vector(const vector&) = default; + + // Copy anything of type `storage_type` + KERNEL_FLOAT_INLINE + vector(const storage_type& storage) : storage_type(storage) {} + + // Copy anything of type `storage_type` + KERNEL_FLOAT_INLINE + vector(const value_type& input = {}) : + storage_type(detail::broadcast_impl, E>::call(input)) {} + + // For all other arguments, we convert it using `convert_storage` according to broadcast rules + template, T>, int> = 0> + KERNEL_FLOAT_INLINE vector(U&& input) : + storage_type(convert_storage(input, extent_type {})) {} + + template, T>, int> = 0> + KERNEL_FLOAT_INLINE explicit vector(U&& input) : + storage_type(convert_storage(input, extent_type {})) {} + + // List of `N` (where N >= 2), simply pass forward to the storage + template< + typename A, + typename B, + typename... Rest, + typename = enable_if_t> + KERNEL_FLOAT_INLINE vector(const A& a, const B& b, const Rest&... rest) : + storage_type {T(a), T(b), T(rest)...} {} + + /** + * Returns the number of elements in this vector. + */ + KERNEL_FLOAT_INLINE + static constexpr size_t size() { + return E::size; + } + + KERNEL_FLOAT_INLINE + storage_type& storage() { + return *this; + } + + KERNEL_FLOAT_INLINE + const storage_type& storage() const { + return *this; + } + + /** + * Returns a pointer to the underlying storage data. + */ + KERNEL_FLOAT_INLINE + T* data() { + return storage().data(); + } + + /** + * Returns a pointer to the underlying storage data. + */ + KERNEL_FLOAT_INLINE + const T* data() const { + return storage().data(); + } + + KERNEL_FLOAT_INLINE + const T* cdata() const { + return this->data(); + } + + /** + * Returns a reference to the item at index `i`. + */ + KERNEL_FLOAT_INLINE + T& at(size_t i) { + return *(this->data() + i); + } + + /** + * Returns a constant reference to the item at index `i`. + */ + KERNEL_FLOAT_INLINE + const T& at(size_t i) const { + return *(this->data() + i); + } + + /** + * Returns a reference to the item at index `i`. + */ + KERNEL_FLOAT_INLINE + T& operator[](size_t i) { + return at(i); + } + + /** + * Returns a constant reference to the item at index `i`. + */ + KERNEL_FLOAT_INLINE + const T& operator[](size_t i) const { + return at(i); + } + + KERNEL_FLOAT_INLINE + T& operator()(size_t i) { + return at(i); + } + + KERNEL_FLOAT_INLINE + const T& operator()(size_t i) const { + return at(i); + } + + /** + * Returns a pointer to the first element. + */ + KERNEL_FLOAT_INLINE + T* begin() { + return this->data(); + } + + /** + * Returns a pointer to the first element. + */ + KERNEL_FLOAT_INLINE + const T* begin() const { + return this->data(); + } + + /** + * Returns a pointer to the first element. + */ + KERNEL_FLOAT_INLINE + const T* cbegin() const { + return this->data(); + } + + /** + * Returns a pointer to one past the last element. + */ + KERNEL_FLOAT_INLINE + T* end() { + return this->data() + size(); + } + + /** + * Returns a pointer to one past the last element. + */ + KERNEL_FLOAT_INLINE + const T* end() const { + return this->data() + size(); + } + + /** + * Returns a pointer to one past the last element. + */ + KERNEL_FLOAT_INLINE + const T* cend() const { + return this->data() + size(); + } + + /** + * Copy the element at index `i`. + */ + KERNEL_FLOAT_INLINE + T get(size_t x) const { + return at(x); + } + + /** + * Set the element at index `i`. + */ + KERNEL_FLOAT_INLINE + void set(size_t x, T value) { + at(x) = std::move(value); + } + + /** + * Selects elements from the this vector based on the specified indices. + * + * Example + * ======= + * ``` + * vec input = {0, 10, 20, 30, 40, 50}; + * vec vec1 = select(input, 0, 4, 4, 2); // [0, 40, 40, 20] + * + * vec indices = {0, 4, 4, 2}; + * vec vec2 = select(input, indices); // [0, 40, 40, 20] + * ``` + */ + template + KERNEL_FLOAT_INLINE select_type select(const Is&... indices) { + return kernel_float::select(*this, indices...); + } + + /** + * Cast the elements of this vector to type `R` and returns a new vector. + */ + template + KERNEL_FLOAT_INLINE vector cast() const { + return kernel_float::cast(*this); + } + + /** + * Broadcast this vector into a new size `(Ns...)`. + */ + template + KERNEL_FLOAT_INLINE vector> broadcast(extent new_size = {}) const { + return kernel_float::broadcast(*this, new_size); + } + + /** + * Apply the given function `F` to each element of this vector and returns a new vector with the results. + */ + template + KERNEL_FLOAT_INLINE vector, E> map(F fun) const { + return kernel_float::map(fun, *this); + } + + /** + * Reduce the elements of the given vector input into a single value using the function `F`. + * + * This function should be a binary function that takes two elements and returns one element. The order in which + * the elements are reduced is not specified and depends on the reduction function and the vector type. + */ + template + KERNEL_FLOAT_INLINE T reduce(F fun) const { + return kernel_float::reduce(fun, *this); + } + + /** + * Flattens the elements of this vector. For example, this turns a `vec, 3>` into a `vec`. + */ + KERNEL_FLOAT_INLINE flatten_type flatten() const { + return kernel_float::flatten(*this); + } + + /** + * Apply the given function `F` to each element of this vector. + */ + template + KERNEL_FLOAT_INLINE void for_each(F fun) const { + return kernel_float::for_each(*this, std::move(fun)); + } +}; + +/** + * Convert the given `input` into a vector. This function can perform one of the following actions: + * + * - For vectors `vec`, it simply returns the original vector. + * - For primitive types `T` (e.g., `int`, `float`, `double`), it returns a `vec`. + * - For array-like types (e.g., `std::array`, `T[N]`), it returns `vec`. + * - For vector-like types (e.g., `int2`, `dim3`), it returns `vec`. + */ +template +KERNEL_FLOAT_INLINE into_vector_type into_vector(V&& input) { + return into_vector_impl::call(std::forward(input)); +} + +template +using scalar = vector>; + +template +using vec = vector>; + +// clang-format off +template using vec1 = vec; +template using vec2 = vec; +template using vec3 = vec; +template using vec4 = vec; +template using vec5 = vec; +template using vec6 = vec; +template using vec7 = vec; +template using vec8 = vec; +// clang-format on + +/** + * Create a vector from a variable number of input values. + * + * The resulting vector type is determined by promoting the types of the input values into a common type. + * The number of input values determines the dimension of the resulting vector. + * + * Example + * ======= + * ``` + * auto v1 = make_vec(1.0f, 2.0f, 3.0f); // Creates a vec [1.0f, 2.0f, 3.0f] + * auto v2 = make_vec(1, 2, 3, 4); // Creates a vec [1, 2, 3, 4] + * ``` + */ +template +KERNEL_FLOAT_INLINE vec, sizeof...(Args)> make_vec(Args&&... args) { + using T = promote_t; + return vector_storage {T(args)...}; +}; + +#if defined(__cpp_deduction_guides) +// Deduction guide for `vector` +template +vector(Args&&... args) -> vector, extent>; + +// Deduction guides for aliases are only supported from C++20 +#if __cpp_deduction_guides >= 201907L +template +vec(Args&&... args) -> vec, sizeof...(Args)>; +#endif +#endif + +} // namespace kernel_float + +#endif diff --git a/single_include/kernel_float.h b/single_include/kernel_float.h index 72edd39..c31082f 100644 --- a/single_include/kernel_float.h +++ b/single_include/kernel_float.h @@ -1,10 +1,25 @@ +/* + * Kernel Float: Header-only library for vector types and reduced precision floating-point math. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + //================================================================================ // this file has been auto-generated, do not modify its contents! -// date: 2023-03-31 16:04:17.777000 -// git hash: 5a6b682ac483b61ec8a1697bf6adf4e929021574 +// date: 2023-09-21 10:00:11.122069 +// git hash: 227f987d3fc10499e680bb68f00e1c579afeda97 //================================================================================ - #ifndef KERNEL_FLOAT_MACROS_H #define KERNEL_FLOAT_MACROS_H @@ -13,20 +28,20 @@ #ifdef __CUDA_ARCH__ #define KERNEL_FLOAT_INLINE __forceinline__ __device__ -#define KERNEL_FLOAT_ON_DEVICE (1) -#define KERNEL_FLOAT_ON_HOST (0) +#define KERNEL_FLOAT_IS_DEVICE (1) +#define KERNEL_FLOAT_IS_HOST (0) #define KERNEL_FLOAT_CUDA_ARCH (__CUDA_ARCH__) #else #define KERNEL_FLOAT_INLINE __forceinline__ __host__ -#define KERNEL_FLOAT_ON_DEVICE (0) -#define KERNEL_FLOAT_ON_HOST (1) +#define KERNEL_FLOAT_IS_DEVICE (0) +#define KERNEL_FLOAT_IS_HOST (1) #define KERNEL_FLOAT_CUDA_ARCH (0) #endif #else #define KERNEL_FLOAT_INLINE inline #define KERNEL_FLOAT_CUDA (0) -#define KERNEL_FLOAT_ON_HOST (1) -#define KERNEL_FLOAT_ON_DEVICE (0) +#define KERNEL_FLOAT_IS_HOST (1) +#define KERNEL_FLOAT_IS_DEVICE (0) #define KERNEL_FLOAT_CUDA_ARCH (0) #endif @@ -55,15 +70,6 @@ namespace kernel_float { -template -struct const_index { - static constexpr size_t value = I; - - KERNEL_FLOAT_INLINE constexpr operator size_t() const noexcept { - return I; - } -}; - template struct index_sequence { static constexpr size_t size = sizeof...(Is); @@ -71,13 +77,13 @@ struct index_sequence { namespace detail { template -struct make_index_sequence_helper {}; +struct make_index_sequence_impl {}; // Benchmarks show that it is much faster to predefine all possible index sequences instead of doing something // recursive with variadic templates. #define KERNEL_FLOAT_INDEX_SEQ(N, ...) \ template<> \ - struct make_index_sequence_helper { \ + struct make_index_sequence_impl { \ using type = index_sequence<__VA_ARGS__>; \ }; @@ -103,169 +109,207 @@ KERNEL_FLOAT_INDEX_SEQ(17, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, } // namespace detail template -using make_index_sequence = typename detail::make_index_sequence_helper::type; +using make_index_sequence = typename detail::make_index_sequence_impl::type; namespace detail { template -struct decay_helper { +struct decay_impl { using type = T; }; template -struct decay_helper { +struct decay_impl { using type = T; }; template -struct decay_helper { +struct decay_impl { using type = T; }; template -struct decay_helper { +struct decay_impl { using type = T; }; template -struct decay_helper { +struct decay_impl { using type = T; }; } // namespace detail template -using decay_t = typename detail::decay_helper::type; +using decay_t = typename detail::decay_impl::type; + +template +struct promote_type; + +template +struct promote_type { + using type = T; +}; -template -struct common_type; +template +struct promote_type { + using type = T; +}; template -struct common_type { +struct promote_type { using type = T; }; -#define KERNEL_FLOAT_DEFINE_COMMON_TYPE(T, U) \ - template<> \ - struct common_type { \ - using type = T; \ - }; \ - template<> \ - struct common_type { \ - using type = T; \ +template<> +struct promote_type { + using type = void; +}; + +#define KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(T, U) \ + template<> \ + struct promote_type { \ + using type = T; \ + }; \ + template<> \ + struct promote_type { \ + using type = T; \ + }; + +// T + bool becomes T +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(char, bool) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(signed char, bool) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(signed short, bool) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(signed int, bool) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(signed long, bool) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(signed long long, bool) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(unsigned char, bool) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(unsigned short, bool) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(unsigned int, bool) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(unsigned long, bool) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(unsigned long long, bool) + +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, float) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(long double, float) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(long double, double) + +#define KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(T) \ + KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(T, char) \ + KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(T, signed char) \ + KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(T, signed short) \ + KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(T, signed int) \ + KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(T, signed long) \ + KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(T, signed long long) \ + KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(T, unsigned char) \ + KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(T, unsigned short) \ + KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(T, unsigned int) \ + KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(T, unsigned long) \ + KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(T, unsigned long long) \ + KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(T, bool) + +KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(float) +KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(double) +KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(long double) + +#define KERNEL_FLOAT_DEFINE_PROMOTED_INTEGRAL(T, U) \ + KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(signed T, signed U) \ + KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(unsigned T, unsigned U) + +KERNEL_FLOAT_DEFINE_PROMOTED_INTEGRAL(short, char) +KERNEL_FLOAT_DEFINE_PROMOTED_INTEGRAL(int, char) +KERNEL_FLOAT_DEFINE_PROMOTED_INTEGRAL(int, short) +KERNEL_FLOAT_DEFINE_PROMOTED_INTEGRAL(long, char) +KERNEL_FLOAT_DEFINE_PROMOTED_INTEGRAL(long, short) +KERNEL_FLOAT_DEFINE_PROMOTED_INTEGRAL(long, int) +KERNEL_FLOAT_DEFINE_PROMOTED_INTEGRAL(long long, char) +KERNEL_FLOAT_DEFINE_PROMOTED_INTEGRAL(long long, short) +KERNEL_FLOAT_DEFINE_PROMOTED_INTEGRAL(long long, int) +KERNEL_FLOAT_DEFINE_PROMOTED_INTEGRAL(long long, long) + +template +struct promote_type { + using type = T*; +}; + +#define KERNEL_FLOAT_DEFINE_PROMOTED_POINTER(I) \ + template \ + struct promote_type { \ + using type = T*; \ + }; \ + template \ + struct promote_type { \ + using type = T*; \ }; -KERNEL_FLOAT_DEFINE_COMMON_TYPE(long double, double) -KERNEL_FLOAT_DEFINE_COMMON_TYPE(long double, float) -KERNEL_FLOAT_DEFINE_COMMON_TYPE(double, float) -//KERNEL_FLOAT_DEFINE_COMMON_TYPE(double, half) -//KERNEL_FLOAT_DEFINE_COMMON_TYPE(float, half) - -#define KERNEL_FLOAT_DEFINE_COMMON_INTEGRAL(T, U) \ - KERNEL_FLOAT_DEFINE_COMMON_TYPE(signed T, signed U) \ - KERNEL_FLOAT_DEFINE_COMMON_TYPE(unsigned T, unsigned U) - -KERNEL_FLOAT_DEFINE_COMMON_INTEGRAL(long long, long) -KERNEL_FLOAT_DEFINE_COMMON_INTEGRAL(long long, int) -KERNEL_FLOAT_DEFINE_COMMON_INTEGRAL(long long, short) -KERNEL_FLOAT_DEFINE_COMMON_INTEGRAL(long long, char) -KERNEL_FLOAT_DEFINE_COMMON_INTEGRAL(long, int) -KERNEL_FLOAT_DEFINE_COMMON_INTEGRAL(long, short) -KERNEL_FLOAT_DEFINE_COMMON_INTEGRAL(long, char) -KERNEL_FLOAT_DEFINE_COMMON_INTEGRAL(int, short) -KERNEL_FLOAT_DEFINE_COMMON_INTEGRAL(int, char) -KERNEL_FLOAT_DEFINE_COMMON_INTEGRAL(short, char) - -KERNEL_FLOAT_DEFINE_COMMON_TYPE(long double, bool) -KERNEL_FLOAT_DEFINE_COMMON_TYPE(double, bool) -KERNEL_FLOAT_DEFINE_COMMON_TYPE(float, bool) - -KERNEL_FLOAT_DEFINE_COMMON_TYPE(signed long long, bool) -KERNEL_FLOAT_DEFINE_COMMON_TYPE(signed long, bool) -KERNEL_FLOAT_DEFINE_COMMON_TYPE(signed int, bool) -KERNEL_FLOAT_DEFINE_COMMON_TYPE(signed short, bool) -KERNEL_FLOAT_DEFINE_COMMON_TYPE(signed char, bool) - -KERNEL_FLOAT_DEFINE_COMMON_TYPE(unsigned long long, bool) -KERNEL_FLOAT_DEFINE_COMMON_TYPE(unsigned long, bool) -KERNEL_FLOAT_DEFINE_COMMON_TYPE(unsigned int, bool) -KERNEL_FLOAT_DEFINE_COMMON_TYPE(unsigned short, bool) -KERNEL_FLOAT_DEFINE_COMMON_TYPE(unsigned char, bool) +KERNEL_FLOAT_DEFINE_PROMOTED_POINTER(char) +KERNEL_FLOAT_DEFINE_PROMOTED_POINTER(signed char) +KERNEL_FLOAT_DEFINE_PROMOTED_POINTER(signed short) +KERNEL_FLOAT_DEFINE_PROMOTED_POINTER(signed int) +KERNEL_FLOAT_DEFINE_PROMOTED_POINTER(signed long) +KERNEL_FLOAT_DEFINE_PROMOTED_POINTER(signed long long) +KERNEL_FLOAT_DEFINE_PROMOTED_POINTER(unsigned char) +KERNEL_FLOAT_DEFINE_PROMOTED_POINTER(unsigned short) +KERNEL_FLOAT_DEFINE_PROMOTED_POINTER(unsigned int) +KERNEL_FLOAT_DEFINE_PROMOTED_POINTER(unsigned long) +KERNEL_FLOAT_DEFINE_PROMOTED_POINTER(unsigned long long) + +// half precision +// KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(half) +// KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(half, bool) +// KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, half) +// KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, half) +// KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(long double, half) namespace detail { template -struct common_type_helper; +struct multi_promote_type; template -struct common_type_helper { +struct multi_promote_type { using type = T; }; -template -struct common_type_helper { - using type = typename common_type::type; -}; +template +struct multi_promote_type: promote_type {}; + +template +struct multi_promote_type: + multi_promote_type::type, C, Rest...> {}; -template -struct common_type_helper: - common_type_helper::type, R, Rest...> {}; } // namespace detail template -using common_t = typename detail::common_type_helper...>::type; +using promote_t = typename detail::multi_promote_type...>::type; namespace detail { -template -struct common_size_helper; - -template<> -struct common_size_helper<> { - static constexpr size_t value = 1; -}; - -template -struct common_size_helper { - static constexpr size_t value = N; -}; -template -struct common_size_helper { - static constexpr size_t value = N; -}; - -template -struct common_size_helper { - static constexpr size_t value = N; -}; - -template -struct common_size_helper<1, N> { - static constexpr size_t value = N; +template +struct is_same_type_impl { + static constexpr bool value = false; }; -template<> -struct common_size_helper<1, 1> { - static constexpr size_t value = 1; +template +struct is_same_type_impl { + static constexpr bool value = true; }; } // namespace detail -template -static constexpr size_t common_size = detail::common_size_helper::value; +template +static constexpr bool is_same_type = detail::is_same_type_impl::value; namespace detail { - template -struct is_implicit_convertible_helper { +struct is_implicit_convertible_impl { static constexpr bool value = false; }; template -struct is_implicit_convertible_helper::type> { +struct is_implicit_convertible_impl::type> { static constexpr bool value = true; }; } // namespace detail template static constexpr bool is_implicit_convertible = - detail::is_implicit_convertible_helper, decay_t>::value; + detail::is_implicit_convertible_impl, decay_t>::value; namespace detail { template @@ -280,434 +324,288 @@ using result_t = decltype((detail::declval())(detail::declval()...)); namespace detail { template -struct enabled_helper {}; +struct enable_if_impl {}; template -struct enabled_helper { +struct enable_if_impl { using type = T; }; } // namespace detail template -using enabled_t = typename detail::enabled_helper::type; +using enable_if_t = typename detail::enable_if_impl::type; } // namespace kernel_float #endif -#ifndef KERNEL_FLOAT_STORAGE -#define KERNEL_FLOAT_STORAGE +#ifndef KERNEL_FLOAT_BASE_H +#define KERNEL_FLOAT_BASE_H -namespace kernel_float { -template -struct vector_traits { - using value_type = V; - static constexpr size_t size = 1; +namespace kernel_float { +template +struct alignas(Alignment) aligned_array { KERNEL_FLOAT_INLINE - static V fill(value_type value) { - return value; + T* data() { + return items_; } KERNEL_FLOAT_INLINE - static V create(value_type value) { - return value; + const T* data() const { + return items_; } + T items_[N] = {}; +}; + +template +struct aligned_array { KERNEL_FLOAT_INLINE - static value_type get(const V& self, size_t index) { - KERNEL_FLOAT_ASSERT(index == 0); - return self; + T* data() { + while (true) + ; } KERNEL_FLOAT_INLINE - static void set(V& self, size_t index, value_type value) { - KERNEL_FLOAT_ASSERT(index == 0); - self = value; + const T* data() const { + while (true) + ; } }; -template -struct into_storage_traits { - using type = V; +template +struct alignas(Alignment) aligned_array { + KERNEL_FLOAT_INLINE + aligned_array(T value = {}) : x(value) {} KERNEL_FLOAT_INLINE - static constexpr type call(V self) { - return self; + operator T() const { + return x; } -}; - -template -struct into_storage_traits: into_storage_traits {}; - -template -struct into_storage_traits: into_storage_traits {}; - -template -struct into_storage_traits: into_storage_traits {}; - -template -using into_storage_type = typename into_storage_traits::type; - -template -KERNEL_FLOAT_INLINE into_storage_type into_storage(V&& input) { - return into_storage_traits::call(input); -} - -template -static constexpr size_t vector_size = vector_traits>::size; - -template -using vector_value_type = typename vector_traits>::value_type; - -template -struct vector_index { - using value_type = vector_value_type; KERNEL_FLOAT_INLINE - static value_type get(const V& self) { - return vector_traits::get(self, I); + T* data() { + return &x; } KERNEL_FLOAT_INLINE - static void set(V& self, value_type value) { - return vector_traits::set(self, I, value); + const T* data() const { + return &x; } + + T x; }; -template -KERNEL_FLOAT_INLINE vector_value_type vector_get(const V& self, size_t index) { - return vector_traits::get(self, index); -} +template +struct alignas(Alignment) aligned_array { + KERNEL_FLOAT_INLINE + aligned_array(T x, T y) : x(x), y(y) {} -template -KERNEL_FLOAT_INLINE vector_value_type vector_get(const V& self, const_index = {}) { - return vector_index::get(self); -} + KERNEL_FLOAT_INLINE + aligned_array() : aligned_array(T {}, T {}) {} -template -struct vector_swizzle; + KERNEL_FLOAT_INLINE + T* data() { + return items; + } -template -struct vector_swizzle> { - KERNEL_FLOAT_INLINE static Output call(const Input& storage) { - return vector_traits::create(vector_get(storage)...); + KERNEL_FLOAT_INLINE + const T* data() const { + return items; } + + union { + T items[2]; + struct { + T x; + T y; + }; + }; }; -template -struct vector; +template +struct alignas(Alignment) aligned_array { + KERNEL_FLOAT_INLINE + aligned_array(T x, T y, T z) : x(x), y(y), z(z) {} -template -struct alignas(alignment) array { - T items_[N]; + KERNEL_FLOAT_INLINE + aligned_array() : aligned_array(T {}, T {}, T {}) {} KERNEL_FLOAT_INLINE - T& operator[](size_t i) { - KERNEL_FLOAT_ASSERT(i < N); - return items_[i]; + T* data() { + return items; } KERNEL_FLOAT_INLINE - const T& operator[](size_t i) const { - KERNEL_FLOAT_ASSERT(i < N); - return items_[i]; + const T* data() const { + return items; } -}; -template -struct vector_traits> { - using self_type = array; - using value_type = T; - static constexpr size_t size = N; + union { + T items[3]; + struct { + T x; + T y; + T z; + }; + }; +}; - template - KERNEL_FLOAT_INLINE static self_type create(Args&&... args) { - return {args...}; - } +template +struct alignas(Alignment) aligned_array { + KERNEL_FLOAT_INLINE + aligned_array(T x, T y, T z, T w) : x(x), y(y), z(z), w(w) {} KERNEL_FLOAT_INLINE - static self_type fill(value_type value) { - self_type result; - for (size_t i = 0; i < N; i++) { - result[i] = value; - } - return result; - } + aligned_array() : aligned_array(T {}, T {}, T {}, T {}) {} KERNEL_FLOAT_INLINE - static value_type get(const self_type& self, size_t index) { - KERNEL_FLOAT_ASSERT(index < N); - return self[index]; + T* data() { + return items; } KERNEL_FLOAT_INLINE - static void set(self_type& self, size_t index, value_type value) { - KERNEL_FLOAT_ASSERT(index < N); - self[index] = value; - } + const T* data() const { + return items; + } + + union { + T items[4]; + struct { + T x; + T y; + T z; + T w; + }; + }; }; -template -struct array {}; +KERNEL_FLOAT_INLINE +static constexpr size_t compute_max_alignment(size_t total_size, size_t min_align) { + if (total_size % 32 == 0 || min_align >= 32) { + return 32; + } else if (total_size % 16 == 0 || min_align == 16) { + return 16; + } else if (total_size % 8 == 0 || min_align == 8) { + return 8; + } else if (total_size % 4 == 0 || min_align == 4) { + return 4; + } else if (total_size % 2 == 0 || min_align == 2) { + return 2; + } else { + return 1; + } +} + +template +using vector_storage = aligned_array; + +template +struct extent; + +template +struct extent { + static constexpr size_t value = N; + static constexpr size_t size = N; +}; -template -struct vector_traits> { - using self_type = array; +template +struct into_vector_impl { using value_type = T; - static constexpr size_t size = 0; + using extent_type = extent<1>; KERNEL_FLOAT_INLINE - static self_type create() { - return {}; + static vector_storage call(const T& input) { + return vector_storage {input}; } +}; - KERNEL_FLOAT_INLINE - static self_type fill(value_type value) { - return {}; - } +template +struct into_vector_impl { + using value_type = T; + using extent_type = extent; KERNEL_FLOAT_INLINE - static value_type get(const self_type& self, size_t index) { - KERNEL_FLOAT_UNREACHABLE; + static vector_storage call(const T (&input)[N]) { + return call(input, make_index_sequence()); } - KERNEL_FLOAT_INLINE - static void set(self_type& self, size_t index, value_type value) { - KERNEL_FLOAT_UNREACHABLE; + private: + template + KERNEL_FLOAT_INLINE static vector_storage + call(const T (&input)[N], index_sequence) { + return {input[Is]...}; } }; -enum struct Alignment { - Minimum, - Packed, - Maximum, -}; - -constexpr size_t calculate_alignment(Alignment required, size_t min_alignment, size_t total_size) { - size_t alignment = 1; +template +struct into_vector_impl: into_vector_impl {}; - if (required == Alignment::Maximum) { - if (total_size <= 1) { - alignment = 1; - } else if (total_size <= 2) { - alignment = 2; - } else if (total_size <= 4) { - alignment = 4; - } else if (total_size <= 8) { - alignment = 8; - } else { - alignment = 16; - } - } else if (required == Alignment::Packed) { - if (total_size % 16 == 0) { - alignment = 16; - } else if (total_size % 8 == 0) { - alignment = 8; - } else if (total_size % 4 == 0) { - alignment = 4; - } else if (total_size % 2 == 0) { - alignment = 2; - } else { - alignment = 1; - } - } +template +struct into_vector_impl: into_vector_impl {}; - if (min_alignment > alignment) { - alignment = min_alignment; - } +template +struct into_vector_impl: into_vector_impl {}; - return alignment; -} +template +struct into_vector_impl: into_vector_impl {}; -template -struct default_storage { - using type = array; -}; +template +struct into_vector_impl> { + using value_type = T; + using extent_type = extent; -template -struct default_storage { - using type = T; + KERNEL_FLOAT_INLINE + static vector_storage call(const aligned_array& input) { + return input; + } }; -template -using default_storage_type = typename default_storage::type; - #define KERNEL_FLOAT_DEFINE_VECTOR_TYPE(T, T1, T2, T3, T4) \ template<> \ - struct vector_traits { \ + struct into_vector_impl<::T1> { \ using value_type = T; \ - static constexpr size_t size = 1; \ + using extent_type = extent<1>; \ \ KERNEL_FLOAT_INLINE \ - static T1 create(T x) { \ - return {x}; \ - } \ - \ - KERNEL_FLOAT_INLINE \ - static T1 fill(T v) { \ - return {v}; \ - } \ - \ - KERNEL_FLOAT_INLINE \ - static T get(const T1& self, size_t index) { \ - switch (index) { \ - case 0: \ - return self.x; \ - default: \ - KERNEL_FLOAT_UNREACHABLE; \ - } \ - } \ - \ - KERNEL_FLOAT_INLINE \ - static void set(T1& self, size_t index, T value) { \ - switch (index) { \ - case 0: \ - self.x = value; \ - default: \ - KERNEL_FLOAT_UNREACHABLE; \ - } \ + static vector_storage call(::T1 v) { \ + return {v.x}; \ } \ }; \ \ template<> \ - struct vector_traits { \ + struct into_vector_impl<::T2> { \ using value_type = T; \ - static constexpr size_t size = 2; \ - \ - KERNEL_FLOAT_INLINE \ - static T2 create(T x, T y) { \ - return {x, y}; \ - } \ - \ - KERNEL_FLOAT_INLINE \ - static T2 fill(T v) { \ - return {v, v}; \ - } \ - \ - KERNEL_FLOAT_INLINE \ - static T get(const T2& self, size_t index) { \ - switch (index) { \ - case 0: \ - return self.x; \ - case 1: \ - return self.y; \ - default: \ - KERNEL_FLOAT_UNREACHABLE; \ - } \ - } \ + using extent_type = extent<2>; \ \ KERNEL_FLOAT_INLINE \ - static void set(T2& self, size_t index, T value) { \ - switch (index) { \ - case 0: \ - self.x = value; \ - case 1: \ - self.y = value; \ - default: \ - KERNEL_FLOAT_UNREACHABLE; \ - } \ + static vector_storage call(::T2 v) { \ + return {v.x, v.y}; \ } \ }; \ \ template<> \ - struct vector_traits { \ + struct into_vector_impl<::T3> { \ using value_type = T; \ - static constexpr size_t size = 3; \ - \ - KERNEL_FLOAT_INLINE \ - static T3 create(T x, T y, T z) { \ - return {x, y, z}; \ - } \ - \ - KERNEL_FLOAT_INLINE \ - static T3 fill(T v) { \ - return {v, v, v}; \ - } \ + using extent_type = extent<3>; \ \ KERNEL_FLOAT_INLINE \ - static T get(const T3& self, size_t index) { \ - switch (index) { \ - case 0: \ - return self.x; \ - case 1: \ - return self.y; \ - case 2: \ - return self.z; \ - default: \ - KERNEL_FLOAT_UNREACHABLE; \ - } \ - } \ - \ - KERNEL_FLOAT_INLINE \ - static void set(T3& self, size_t index, T value) { \ - switch (index) { \ - case 0: \ - self.x = value; \ - return; \ - case 1: \ - self.y = value; \ - return; \ - case 2: \ - self.z = value; \ - return; \ - default: \ - KERNEL_FLOAT_UNREACHABLE; \ - } \ + static vector_storage call(::T3 v) { \ + return {v.x, v.y, v.z}; \ } \ }; \ \ template<> \ - struct vector_traits { \ + struct into_vector_impl<::T4> { \ using value_type = T; \ - static constexpr size_t size = 4; \ - \ - KERNEL_FLOAT_INLINE \ - static T4 create(T x, T y, T z, T w) { \ - return {x, y, z, w}; \ - } \ - \ - KERNEL_FLOAT_INLINE \ - static T4 fill(T v) { \ - return {v, v, v, v}; \ - } \ - \ - KERNEL_FLOAT_INLINE \ - static T get(const T4& self, size_t index) { \ - switch (index) { \ - case 0: \ - return self.x; \ - case 1: \ - return self.y; \ - case 2: \ - return self.z; \ - case 3: \ - return self.w; \ - default: \ - KERNEL_FLOAT_UNREACHABLE; \ - } \ - } \ + using extent_type = extent<4>; \ \ KERNEL_FLOAT_INLINE \ - static void set(T4& self, size_t index, T value) { \ - switch (index) { \ - case 0: \ - self.x = value; \ - return; \ - case 1: \ - self.y = value; \ - return; \ - case 2: \ - self.z = value; \ - return; \ - case 3: \ - self.w = value; \ - return; \ - default: \ - KERNEL_FLOAT_UNREACHABLE; \ - } \ + static vector_storage call(::T4 v) { \ + return {v.x, v.y, v.z, v.w}; \ } \ }; @@ -726,572 +624,372 @@ KERNEL_FLOAT_DEFINE_VECTOR_TYPE(unsigned long long, ulonglong1, ulonglong2, ulon KERNEL_FLOAT_DEFINE_VECTOR_TYPE(float, float1, float2, float3, float4) KERNEL_FLOAT_DEFINE_VECTOR_TYPE(double, double1, double2, double3, double4) -template -struct nested_array { - static constexpr size_t num_packets = (N + vector_size - 1) / vector_size; - static_assert(num_packets * vector_size >= N, "internal error"); - - V packets[num_packets]; +template> +struct vector; - KERNEL_FLOAT_INLINE - V& operator[](size_t i) { - KERNEL_FLOAT_ASSERT(i < num_packets); - return packets[i]; - } +template +struct into_vector_impl> { + using value_type = T; + using extent_type = E; KERNEL_FLOAT_INLINE - const V& operator[](size_t i) const { - KERNEL_FLOAT_ASSERT(i < num_packets); - return packets[i]; + static vector_storage call(const vector& input) { + return input.storage(); } }; -template -struct vector_traits> { - using self_type = nested_array; - using value_type = vector_value_type; - static constexpr size_t size = N; - - template - KERNEL_FLOAT_INLINE static self_type create(Args&&... args) { - value_type items[N] = {args...}; - self_type output; +template +struct vector_traits; - size_t i = 0; - for (; i + vector_size - 1 < N; i += vector_size) { - // How to generalize this? - output.packets[i / vector_size] = vector_traits::create(items[i], items[i + 1]); - } +template +struct vector_traits> { + using value_type = T; + using extent_type = E; + using storage_type = S; + using vector_type = vector; +}; - for (; i < N; i++) { - vector_traits::set(output.packets[i / vector_size], i % vector_size, items[i]); - } +template +using vector_value_type = typename into_vector_impl::value_type; - return output; - } +template +using vector_extent_type = typename into_vector_impl::extent_type; - KERNEL_FLOAT_INLINE - static self_type fill(value_type value) { - self_type output; +template +static constexpr size_t vector_extent = vector_extent_type::value; - for (size_t i = 0; i < self_type::num_packets; i++) { - output.packets[i] = vector_traits::fill(value); - } +template +using into_vector_type = vector, vector_extent_type>; - return output; - } +template +using vector_storage_type = vector_storage, vector_extent>; - KERNEL_FLOAT_INLINE - static value_type get(const self_type& self, size_t index) { - KERNEL_FLOAT_ASSERT(index < N); - return vector_traits::get(self.packets[index / vector_size], index % vector_size); - } +template +using promoted_vector_value_type = promote_t...>; - KERNEL_FLOAT_INLINE - static void set(self_type& self, size_t index, value_type value) { - KERNEL_FLOAT_ASSERT(index < N); - vector_traits::set(self.packets[index / vector_size], index % vector_size, value); - } -}; +template +KERNEL_FLOAT_INLINE vector_storage_type into_vector_storage(V&& input) { + return into_vector_impl::call(std::forward(input)); +} -}; // namespace kernel_float +} // namespace kernel_float #endif -#ifndef KERNEL_FLOAT_CAST_H -#define KERNEL_FLOAT_CAST_H +#ifndef KERNEL_FLOAT_COMPLEX_TYPE_H +#define KERNEL_FLOAT_COMPLEX_TYPE_H + namespace kernel_float { -namespace ops { -template -struct cast { - KERNEL_FLOAT_INLINE R operator()(T input) noexcept { - return R(input); - } -}; template -struct cast { - KERNEL_FLOAT_INLINE T operator()(T input) noexcept { - return input; - } +struct alignas(2 * alignof(T)) complex_type_storage { + T re; + T im; }; -} // namespace ops -namespace detail { +template +struct complex_type: complex_type_storage { + using base_type = complex_type_storage; -// Cast a vector of type `Input` to type `Output`. Vectors must have the same size. -// The input vector has value type `T` -// The output vector has value type `R` -template< - typename Input, - typename Output, - typename T = vector_value_type, - typename R = vector_value_type> -struct cast_helper { - static_assert(vector_size == vector_size, "sizes must match"); - static constexpr size_t N = vector_size; + template + KERNEL_FLOAT_INLINE complex_type(complex_type that) : base_type(that.real(), that.imag()) {} - KERNEL_FLOAT_INLINE static Output call(const Input& input) { - return call(input, make_index_sequence {}); - } + KERNEL_FLOAT_INLINE + complex_type(T real = {}, T imag = {}) : base_type(real, imag) {} - private: - template - KERNEL_FLOAT_INLINE static Output call(const Input& input, index_sequence) { - ops::cast fun; - return vector_traits::create(fun(vector_get(input))...); + KERNEL_FLOAT_INLINE + T real() const { + return this->re; } -}; -// Cast a vector of type `Input` to type `Output`. -// The input vector has value type `T` and size `N`. -// The output vector has value type `R` and size `M`. -template< - typename Input, - typename Output, - typename T = vector_value_type, - size_t N = vector_size, - typename R = vector_value_type, - size_t M = vector_size> -struct broadcast_helper; - -// T[1] => T[1] -template -struct broadcast_helper { - KERNEL_FLOAT_INLINE static Vector call(Vector input) { - return input; + KERNEL_FLOAT_INLINE + T imag() const { + return this->im; } -}; -// T[N] => T[N] -template -struct broadcast_helper { - KERNEL_FLOAT_INLINE static Vector call(Vector input) { - return input; + KERNEL_FLOAT_INLINE + T norm() const { + return real() * real() + imag() * imag(); } -}; -// T[1] => T[N] -template -struct broadcast_helper { - KERNEL_FLOAT_INLINE static Output call(Input input) { - return vector_traits::fill(vector_get<0>(input)); + KERNEL_FLOAT_INLINE + complex_type conj() const { + return {real(), -imag()}; } }; -// T[1] => T[1], but different vector types -template -struct broadcast_helper { - KERNEL_FLOAT_INLINE static Output call(Input input) { - return vector_traits::create(vector_get<0>(input)); - } -}; +template +KERNEL_FLOAT_INLINE complex_type operator+(complex_type v) { + return v; +} -// T[N] => T[N], but different vector types -template -struct broadcast_helper { - KERNEL_FLOAT_INLINE static Output call(Input input) { - return cast_helper::call(input); - } -}; +template +KERNEL_FLOAT_INLINE complex_type operator+(complex_type a, complex_type b) { + return {a.real() + b.real(), a.imag() + b.imag()}; +} -// T[1] => R[N] -template -struct broadcast_helper { - KERNEL_FLOAT_INLINE static Output call(Input input) { - return vector_traits::fill(ops::cast {}(vector_get<0>(input))); - } -}; +template +KERNEL_FLOAT_INLINE complex_type operator+(T a, complex_type b) { + return {a + b.real(), b.imag()}; +} -// T[1] => R[1] -template -struct broadcast_helper { - KERNEL_FLOAT_INLINE static Output call(Input input) { - return vector_traits::create(ops::cast {}(vector_get<0>(input))); - } -}; +template +KERNEL_FLOAT_INLINE complex_type operator+(complex_type a, T b) { + return {a.real() + b, a.imag()}; +} -// T[N] => R[N] -template -struct broadcast_helper { - KERNEL_FLOAT_INLINE static Output call(Input input) { - return cast_helper::call(input); - } -}; -} // namespace detail +template +KERNEL_FLOAT_INLINE complex_type& operator+=(complex_type& a, complex_type b) { + return (a = a + b); +} -/** - * Cast the elements of the given vector ``input`` to the given type ``R`` and then widen the - * vector to length ``N``. The cast may lead to a loss in precision if ``R`` is a smaller data - * type. Widening is only possible if the input vector has size ``1`` or ``N``, other sizes - * will lead to a compilation error. - * - * Example - * ======= - * ``` - * vec x = {6}; - * vec y = broadcast(x); - * vec z = broadcast(y); - * ``` - */ -template> -KERNEL_FLOAT_INLINE vector broadcast(Input&& input) { - return detail::broadcast_helper, Output>::call( - into_storage(std::forward(input))); +template +KERNEL_FLOAT_INLINE complex_type& operator+=(complex_type& a, T b) { + return (a = a + b); } -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template< - size_t N, - typename Input, - typename Output = default_storage_type, N>> -KERNEL_FLOAT_INLINE vector broadcast(Input&& input) { - return detail::broadcast_helper, Output>::call( - into_storage(std::forward(input))); +template +KERNEL_FLOAT_INLINE complex_type operator-(complex_type v) { + return {-v.real(), -v.imag()}; } -template -KERNEL_FLOAT_INLINE vector broadcast(Input&& input) { - return detail::broadcast_helper, Output>::call( - into_storage(std::forward(input))); +template +KERNEL_FLOAT_INLINE complex_type operator-(complex_type a, complex_type b) { + return {a.real() - b.real(), a.imag() - b.imag()}; } -#endif -/** - * Widen the given vector ``input`` to length ``N``. Widening is only possible if the input vector - * has size ``1`` or ``N``, other sizes will lead to a compilation error. - * - * Example - * ======= - * ``` - * vec x = {6}; - * vec y = resize<3>(x); - * ``` - */ -template< - size_t N, - typename Input, - typename Output = default_storage_type, N>> -KERNEL_FLOAT_INLINE vector resize(Input&& input) noexcept { - return detail::broadcast_helper::call(std::forward(input)); +template +KERNEL_FLOAT_INLINE complex_type operator-(T a, complex_type b) { + return {a - b.real(), -b.imag()}; } -template -using cast_type = default_storage_type>; +template +KERNEL_FLOAT_INLINE complex_type operator-(complex_type a, T b) { + return {a.real() - b, a.imag()}; +} -/** - * Cast the elements of given vector ``input`` to the given type ``R``. Note that this cast may - * lead to a loss in precision if ``R`` is a smaller data type. - * - * Example - * ======= - * ``` - * vec x = {1.0f, 2.0f, 3.0f}; - * vec y = cast(x); - * vec z = cast(x); - * ``` - */ -template> -KERNEL_FLOAT_INLINE vector cast(Input&& input) noexcept { - return detail::broadcast_helper::call(std::forward(input)); +template +KERNEL_FLOAT_INLINE complex_type& operator-=(complex_type& a, complex_type b) { + return (a = a - b); } -} // namespace kernel_float -#endif //KERNEL_FLOAT_CAST_H -#ifndef KERNEL_FLOAT_SWIZZLE_H -#define KERNEL_FLOAT_SWIZZLE_H +template +KERNEL_FLOAT_INLINE complex_type& operator-=(complex_type& a, T b) { + return (a = a - b); +} +template +KERNEL_FLOAT_INLINE complex_type operator*(complex_type a, complex_type b) { + return {a.real() * b.real() - a.imag() * b.imag(), a.real() * b.imag() + a.imag() * b.real()}; +} +template +KERNEL_FLOAT_INLINE complex_type operator*(complex_type a, T b) { + return {a.real() * b, a.imag() * b}; +} -namespace kernel_float { +template +KERNEL_FLOAT_INLINE complex_type* operator*=(complex_type& a, complex_type b) { + return (a = a * b); +} -/** - * "Swizzles" the vector. Returns a new vector where the elements are provided by the given indices. - * - * # Example - * ``` - * vec x = {0, 1, 2, 3, 4, 5, 6}; - * vec a = swizzle<0, 1, 2>(x); // 0, 1, 2 - * vec b = swizzle<2, 1, 0>(x); // 2, 1, 0 - * vec c = swizzle<1, 1, 1>(x); // 1, 1, 1 - * vec d = swizzle<0, 2, 4, 6>(x); // 0, 2, 4, 6 - * ``` - */ -template< - size_t... Is, - typename V, - typename Output = default_storage_type, sizeof...(Is)>> -KERNEL_FLOAT_INLINE vector swizzle(const V& input, index_sequence _ = {}) { - return vector_swizzle, index_sequence>::call( - into_storage(input)); +template +KERNEL_FLOAT_INLINE complex_type& operator*=(complex_type& a, T b) { + return (a = a * b); } -/** - * Takes the first ``N`` elements from the given vector and returns a new vector of length ``N``. - * - * # Example - * ``` - * vec x = {1, 2, 3, 4, 5, 6}; - * vec y = first<3>(x); // 1, 2, 3 - * int z = first(x); // 1 - * ``` - */ -template, K>> -KERNEL_FLOAT_INLINE vector first(const V& input) { - static_assert(K <= vector_size, "K cannot exceed vector size"); - using Indices = make_index_sequence; - return vector_swizzle, Indices>::call(into_storage(input)); +template +KERNEL_FLOAT_INLINE complex_type operator*(T a, complex_type b) { + return {a * b.real(), a * b.imag()}; } -namespace detail { -template -struct offset_index_sequence_helper; +template +KERNEL_FLOAT_INLINE complex_type operator/(complex_type a, complex_type b) { + T normi = T(1) / b.norm(); -template -struct offset_index_sequence_helper> { - using type = index_sequence; -}; -} // namespace detail + return { + (a.real() * b.real() + a.imag() * b.imag()) * normi, + (a.imag() * b.real() - a.real() * b.imag()) * normi}; +} -/** - * Takes the last ``N`` elements from the given vector and returns a new vector of length ``N``. - * - * # Example - * ``` - * vec x = {1, 2, 3, 4, 5, 6}; - * vec y = last<3>(x); // 4, 5, 6 - * int z = last(x); // 6 - * ``` - */ -template, K>> -KERNEL_FLOAT_INLINE vector last(const V& input) { - static_assert(K <= vector_size, "K cannot exceed vector size"); - using Indices = typename detail::offset_index_sequence_helper< // - vector_size - K, - make_index_sequence>::type; +template +KERNEL_FLOAT_INLINE complex_type operator/(complex_type a, T b) { + return a * (T(1) / b); +} + +template +KERNEL_FLOAT_INLINE complex_type operator/(T a, complex_type b) { + T normi = T(1) / b.norm(); - return vector_swizzle, Indices>::call(into_storage(input)); + return {a * b.real() * normi, -a * b.imag() * normi}; } -namespace detail { -template -struct reverse_index_sequence_helper: reverse_index_sequence_helper {}; +template +KERNEL_FLOAT_INLINE complex_type* operator/=(complex_type& a, complex_type b) { + return (a = a / b); +} -template -struct reverse_index_sequence_helper<0, Is...> { - using type = index_sequence; -}; -} // namespace detail +template +KERNEL_FLOAT_INLINE complex_type& operator/=(complex_type& a, T b) { + return (a = a / b); +} -/** - * Reverses the elements in the given vector. - * - * # Example - * ``` - * vec x = {1, 2, 3, 4, 5, 6}; - * vec y = reversed(x); // 6, 5, 4, 3, 2, 1 - * ``` - */ -template> -KERNEL_FLOAT_INLINE vector reversed(const V& input) { - using Indices = typename detail::reverse_index_sequence_helper>::type; +template +KERNEL_FLOAT_INLINE T real(complex_type v) { + return v.real(); +} - return vector_swizzle, Indices>::call(into_storage(input)); +template +KERNEL_FLOAT_INLINE T imag(complex_type v) { + return v.imag(); } -namespace detail { -template -struct concat_index_sequence_helper {}; +template +KERNEL_FLOAT_INLINE T abs(complex_type v) { + return hypot(v.real(), v.imag()); +} -template -struct concat_index_sequence_helper, index_sequence> { - using type = index_sequence; -}; -} // namespace detail +template +KERNEL_FLOAT_INLINE T arg(complex_type v) { + return atan2(v.imag(), v.real()); +} -/** - * Rotate the given vector ``K`` steps to the right. In other words, this move the front element to the back - * ``K`` times. This is the inverse of ``rotate_left``. - * - * # Example - * ``` - * vec x = {1, 2, 3, 4, 5, 6}; - * vec y = rotate_right<2>(x); // 5, 6, 1, 2, 3, 4 - * ``` - */ -template> -KERNEL_FLOAT_INLINE vector rotate_right(const V& input) { - static constexpr size_t N = vector_size; - static constexpr size_t I = (N > 0) ? (K % N) : 0; +template +KERNEL_FLOAT_INLINE complex_type sqrt(complex_type v) { + T radius = abs(v); + T cosA = v.real() / radius; - using First = - typename detail::offset_index_sequence_helper>::type; - using Second = make_index_sequence; - using Indices = typename detail::concat_index_sequence_helper::type; + complex_type out = { + sqrt(radius * (cosA + T(1)) * T(.5)), + sqrt(radius * (T(1) - cosA) * T(.5))}; - return vector_swizzle, Indices>::call(into_storage(input)); -} + // signbit should be false if x.y is negative + if (v.imag() < 0) { + out = complex_type {out.real, -out.im}; + } -/** - * Rotate the given vector ``K`` steps to the left. In other words, this move the back element to the front - * ``K`` times. This is the inverse of ``rotate_right``. - * - * # Example - * ``` - * vec x = {1, 2, 3, 4, 5, 6}; - * vec y = rotate_left<4>(x); // 5, 6, 1, 2, 3, 4 - * ``` - */ -template> -KERNEL_FLOAT_INLINE vector rotate_left(const V& input) { - static constexpr size_t N = vector_size; - static constexpr size_t K_rev = N > 0 ? (N - K % N) : 0; + return out; +} - return rotate_right(input); +template +KERNEL_FLOAT_INLINE complex_type norm(complex_type v) { + return v.real() * v.real() + v.imag() * v.imag(); } -namespace detail { -template< - typename U, - typename V, - typename Is = make_index_sequence>, - typename Js = make_index_sequence>> -struct concat_helper; +template +KERNEL_FLOAT_INLINE complex_type conj(complex_type v) { + return {v.real(), -v.imag()}; +} -template -struct concat_helper, index_sequence> { - using type = default_storage_type< - common_t, vector_value_type>, - vector_size + vector_size>; +template +KERNEL_FLOAT_INLINE complex_type exp(complex_type v) { + // TODO: Handle nan and inf correctly + T e = exp(v.real()); + T a = v.imag(); + return complex_type(e * cos(a), e * sin(a)); +} - KERNEL_FLOAT_INLINE static type call(const U& left, const V& right) { - return vector_traits::create(vector_get(left)..., vector_get(right)...); - } -}; +template +KERNEL_FLOAT_INLINE complex_type log(complex_type v) { + return {log(abs(v)), arg(v)}; +} -template -struct recur_concat_helper; +template +KERNEL_FLOAT_INLINE complex_type pow(complex_type a, T b) { + return exp(a * log(b)); +} -template -struct recur_concat_helper { - using type = U; +template +KERNEL_FLOAT_INLINE complex_type pow(complex_type a, complex_type b) { + return exp(a * log(b)); +} - KERNEL_FLOAT_INLINE static U call(U&& input) { - return input; - } +template +struct promote_type, complex_type> { + using type = complex_type>; }; -template -struct recur_concat_helper { - using recur_helper = recur_concat_helper::type, Rest...>; - using type = typename recur_helper::type; - - KERNEL_FLOAT_INLINE static type call(const U& left, const V& right, const Rest&... rest) { - return recur_helper::call(concat_helper::call(left, right), rest...); - } +template +struct promote_type, R> { + using type = complex_type>; }; -} // namespace detail - -template -using concat_type = typename detail::recur_concat_helper...>::type; -/** - * Concatenate the given vectors into one large vector. For example, given vectors of size 3, size 2 and size 5, - * this function returns a new vector of size 3+2+5=8. If the vectors are not of the same element type, they - * will first be cast into a common data type. - * - * # Examples - * ``` - * vec x = {1, 2, 3}; - * int y = 4; - * vec z = {5, 6, 7, 8}; - * vec xyz = concat(x, y, z); // 1, 2, 3, 4, 5, 6, 7, 8 - * ``` - */ -template -KERNEL_FLOAT_INLINE vector> concat(const Vs&... inputs) { - return detail::recur_concat_helper...>::call(into_storage(inputs)...); -} +template +struct promote_type> { + using type = complex_type>; +}; } // namespace kernel_float -#endif //KERNEL_FLOAT_SWIZZLE_H +#endif #ifndef KERNEL_FLOAT_UNOPS_H #define KERNEL_FLOAT_UNOPS_H - namespace kernel_float { namespace detail { -template -struct map_helper { - KERNEL_FLOAT_INLINE static Output call(F fun, const Input& input) { - return call(fun, input, make_index_sequence> {}); - } - - private: - template - KERNEL_FLOAT_INLINE static Output call(F fun, const Input& input, index_sequence) { - return vector_traits::create(fun(vector_get(input))...); - } -}; -template -struct map_helper, nested_array> { - KERNEL_FLOAT_INLINE static nested_array call(F fun, const nested_array& input) { - return call(fun, input, make_index_sequence::num_packets> {}); - } - - private: - template - KERNEL_FLOAT_INLINE static nested_array - call(F fun, const nested_array& input, index_sequence) { - return {map_helper::call(fun, input[Is])...}; +template +struct apply_impl { + KERNEL_FLOAT_INLINE static void call(F fun, Output* result, const Args*... inputs) { +#pragma unroll + for (size_t i = 0; i < N; i++) { + result[i] = fun(inputs[i]...); + } } }; } // namespace detail -template -using map_type = default_storage_type>, vector_size>; +template +using map_type = vector>, vector_extent_type>; /** - * Applies ``fun`` to each element from vector ``input`` and returns a new vector with the results. - * This function is the basis for all unary operators like ``sin`` and ``sqrt``. + * Apply the function `F` to each element from the vector `input` and return the results as a new vector. * - * Example - * ======= + * Examples + * ======== * ``` - * vector v = {1, 2, 3}; - * vector w = map([](auto i) { return i * 2; }); // 2, 4, 6 + * vec input = {1.0f, 2.0f, 3.0f, 4.0f}; + * vec squared = map([](auto x) { return x * x; }, input); // [1.0f, 4.0f, 9.0f, 16.0f] * ``` */ -template> -KERNEL_FLOAT_INLINE Output map(F fun, const Input& input) { - return detail::map_helper>::call(fun, into_storage(input)); +template +KERNEL_FLOAT_INLINE map_type map(F fun, const V& input) { + using Input = vector_value_type; + using Output = result_t; + vector_storage> result; + + detail::apply_impl, Output, Input>::call( + fun, + result.data(), + into_vector_storage(input).data()); + + return result; } -#define KERNEL_FLOAT_DEFINE_UNARY(NAME, EXPR) \ - namespace ops { \ - template \ - struct NAME { \ - KERNEL_FLOAT_INLINE T operator()(T input) { \ - return T(EXPR); \ - } \ - }; \ - } \ - template \ - KERNEL_FLOAT_INLINE vector> NAME(const V& input) { \ - return map>, V, into_storage_type>({}, input); \ +#define KERNEL_FLOAT_DEFINE_UNARY(NAME, EXPR) \ + namespace ops { \ + template \ + struct NAME { \ + KERNEL_FLOAT_INLINE T operator()(T input) { \ + return T(EXPR); \ + } \ + }; \ + } \ + template \ + KERNEL_FLOAT_INLINE vector, vector_extent_type> NAME(const V& input) { \ + using F = ops::NAME>; \ + return map(F {}, input); \ } -#define KERNEL_FLOAT_DEFINE_UNARY_OP(NAME, OP, EXPR) \ - KERNEL_FLOAT_DEFINE_UNARY(NAME, EXPR) \ - template \ - KERNEL_FLOAT_INLINE vector operator OP(const vector& vec) { \ - return NAME(vec); \ +#define KERNEL_FLOAT_DEFINE_UNARY_OP(NAME, OP, EXPR) \ + KERNEL_FLOAT_DEFINE_UNARY(NAME, EXPR) \ + template \ + KERNEL_FLOAT_INLINE vector operator OP(const vector& vec) { \ + return NAME(vec); \ } KERNEL_FLOAT_DEFINE_UNARY_OP(negate, -, -input) @@ -1348,222 +1046,603 @@ KERNEL_FLOAT_DEFINE_UNARY_FUN(signbit) KERNEL_FLOAT_DEFINE_UNARY_FUN(isinf) KERNEL_FLOAT_DEFINE_UNARY_FUN(isnan) +#if KERNEL_FLOAT_IS_DEVICE +#define KERNEL_FLOAT_DEFINE_UNARY_FAST(FUN_NAME, OP_NAME, FLOAT_FUN) \ + KERNEL_FLOAT_DEFINE_UNARY(FUN_NAME, ops::OP_NAME {}(input)) \ + namespace ops { \ + template<> \ + struct OP_NAME { \ + KERNEL_FLOAT_INLINE float operator()(float input) { \ + return FLOAT_FUN(input); \ + } \ + }; \ + } +#else +#define KERNEL_FLOAT_DEFINE_UNARY_FAST(FUN_NAME, OP_NAME, FLOAT_FUN) \ + KERNEL_FLOAT_DEFINE_UNARY(FUN_NAME, ops::OP_NAME {}(input)) +#endif + +KERNEL_FLOAT_DEFINE_UNARY_FAST(fast_exp, exp, __expf) +KERNEL_FLOAT_DEFINE_UNARY_FAST(fast_log, log, __logf) +KERNEL_FLOAT_DEFINE_UNARY_FAST(fast_cos, cos, __cosf) +KERNEL_FLOAT_DEFINE_UNARY_FAST(fast_sin, sin, __sinf) +KERNEL_FLOAT_DEFINE_UNARY_FAST(fast_tan, tan, __tanf) + } // namespace kernel_float #endif //KERNEL_FLOAT_UNOPS_H -#ifndef KERNEL_FLOAT_BINOPS_H -#define KERNEL_FLOAT_BINOPS_H +#ifndef KERNEL_FLOAT_CAST_H +#define KERNEL_FLOAT_CAST_H + namespace kernel_float { -namespace detail { -template -struct zip_helper { - KERNEL_FLOAT_INLINE static Output call(F fun, const Left& left, const Right& right) { - return call_with_indices(fun, left, right, make_index_sequence> {}); - } - private: - template - KERNEL_FLOAT_INLINE static Output - call_with_indices(F fun, const Left& left, const Right& right, index_sequence = {}) { - return vector_traits::create(fun(vector_get(left), vector_get(right))...); +enum struct RoundingMode { ANY, DOWN, UP, NEAREST, TOWARD_ZERO }; + +namespace ops { +template +struct cast; + +template +struct cast { + KERNEL_FLOAT_INLINE R operator()(T input) noexcept { + return R(input); } }; -template -struct zip_helper, nested_array, nested_array> { - KERNEL_FLOAT_INLINE static nested_array - call(F fun, const nested_array& left, const nested_array& right) { - return call(fun, left, right, make_index_sequence::num_packets> {}); +template +struct cast { + KERNEL_FLOAT_INLINE T operator()(T input) noexcept { + return input; } +}; - private: - template - KERNEL_FLOAT_INLINE static nested_array call( - F fun, - const nested_array& left, - const nested_array& right, - index_sequence) { - return {zip_helper::call(fun, left[Is], right[Is])...}; +template +struct cast { + KERNEL_FLOAT_INLINE T operator()(T input) noexcept { + return input; } }; -}; // namespace detail - -template -using common_vector_value_type = common_t...>; - -template -static constexpr size_t common_vector_size = common_size...>; - -template -using zip_type = default_storage_type< - result_t, vector_value_type>, - common_vector_size>; - -/** - * Applies ``fun`` to each pair of two elements from ``left`` and ``right`` and returns a new - * vector with the results. - * - * If ``left`` and ``right`` are not the same size, they will first be broadcast into a - * common size using ``resize``. - * - * Note that this function does **not** cast the input vectors to a common element type. See - * ``zip_common`` for that functionality. - */ -template> -KERNEL_FLOAT_INLINE vector zip(F fun, Left&& left, Right&& right) { - static constexpr size_t N = vector_size; - using LeftInput = default_storage_type, N>; - using RightInput = default_storage_type, N>; - - return detail::zip_helper::call( - fun, - broadcast(std::forward(left)), - broadcast(std::forward(right))); -} - -template -using zip_common_type = default_storage_type< - result_t, common_vector_value_type>, - common_vector_size>; +} // namespace ops /** - * Applies ``fun`` to each pair of two elements from ``left`` and ``right`` and returns a new - * vector with the results. + * Cast the elements of the given vector `input` to a different type `R`. * - * If ``left`` and ``right`` are not the same size, they will first be broadcast into a - * common size using ``resize``. + * This function casts each element of the input vector to a different data type specified by + * template parameter `R`. * - * If ``left`` and ``right`` are not of the same type, they will first be case into a common - * data type. For example, zipping ``float`` and ``double`` first cast vectors to ``double``. + * Optionally, the rounding mode can be set using the `Mode` template parameter. The default mode is `ANY`, which + * uses the fastest rounding mode available. * * Example * ======= * ``` - * vec x = {1, 2, 3, 4}; - * vec = {8}; - * vec = zip_common([](auto a, auto b){ return a + b; }, x, y); // [9, 10, 11, 12] + * vec input {1.2f, 2.7f, 3.5f, 4.9f}; + * auto casted = cast(input); // [1, 2, 3, 4] * ``` */ -template< - typename F, - typename Left, - typename Right, - typename Output = zip_common_type> -KERNEL_FLOAT_INLINE vector zip_common(F fun, Left&& left, Right&& right) { - static constexpr size_t N = vector_size; - using C = common_t, vector_value_type>; - using Input = default_storage_type; - - return detail::zip_helper::call( - fun, - broadcast(std::forward(left)), - broadcast(std::forward(right))); -} - -#define KERNEL_FLOAT_DEFINE_BINARY(NAME, EXPR) \ - namespace ops { \ - template \ - struct NAME { \ - KERNEL_FLOAT_INLINE T operator()(T left, T right) { \ - return T(EXPR); \ - } \ - }; \ - } \ - template> \ - KERNEL_FLOAT_INLINE vector, L, R>> NAME(L&& left, R&& right) { \ - return zip_common(ops::NAME {}, std::forward(left), std::forward(right)); \ - } - -#define KERNEL_FLOAT_DEFINE_BINARY_OP(NAME, OP) \ - KERNEL_FLOAT_DEFINE_BINARY(NAME, left OP right) \ - template> \ - KERNEL_FLOAT_INLINE vector, L, R>> operator OP( \ - const vector& left, \ - const vector& right) { \ - return zip_common(ops::NAME {}, left, right); \ - } \ - template> \ - KERNEL_FLOAT_INLINE vector, L, R>> operator OP( \ - const vector& left, \ - const R& right) { \ - return zip_common(ops::NAME {}, left, right); \ - } \ - template> \ - KERNEL_FLOAT_INLINE vector, L, R>> operator OP( \ - const L& left, \ - const vector& right) { \ - return zip_common(ops::NAME {}, left, right); \ - } +template +KERNEL_FLOAT_INLINE vector> cast(const V& input) { + using F = ops::cast, R, Mode>; + return map(F {}, input); +} -KERNEL_FLOAT_DEFINE_BINARY_OP(add, +) -KERNEL_FLOAT_DEFINE_BINARY_OP(subtract, -) -KERNEL_FLOAT_DEFINE_BINARY_OP(divide, /) -KERNEL_FLOAT_DEFINE_BINARY_OP(multiply, *) -KERNEL_FLOAT_DEFINE_BINARY_OP(modulo, %) +namespace detail { -KERNEL_FLOAT_DEFINE_BINARY_OP(equal_to, ==) -KERNEL_FLOAT_DEFINE_BINARY_OP(not_equal_to, !=) -KERNEL_FLOAT_DEFINE_BINARY_OP(less, <) -KERNEL_FLOAT_DEFINE_BINARY_OP(less_equal, <=) -KERNEL_FLOAT_DEFINE_BINARY_OP(greater, >) -KERNEL_FLOAT_DEFINE_BINARY_OP(greater_equal, >=) +template +struct broadcast_extent_helper; -KERNEL_FLOAT_DEFINE_BINARY_OP(bit_and, &) -KERNEL_FLOAT_DEFINE_BINARY_OP(bit_or, |) -KERNEL_FLOAT_DEFINE_BINARY_OP(bit_xor, ^) +template +struct broadcast_extent_helper { + using type = E; +}; -// clang-format off -template typename F, typename L, typename R> -static constexpr bool vector_assign_allowed = - common_vector_size == vector_size && - is_implicit_convertible< - result_t< - F, vector_value_type>>, - vector_value_type, - vector_value_type - >, - vector_value_type - >; -// clang-format on +template +struct broadcast_extent_helper, extent> { + using type = extent; +}; -#define KERNEL_FLOAT_DEFINE_BINARY_ASSIGN_OP(NAME, OP) \ - template< \ - typename L, \ - typename R, \ - typename T = enabled_t, vector_value_type>> \ - KERNEL_FLOAT_INLINE vector& operator OP(vector& lhs, const R& rhs) { \ - using F = ops::NAME; \ - lhs = zip_common(F {}, lhs.storage(), rhs); \ - return lhs; \ - } +template +struct broadcast_extent_helper, extent> { + using type = extent; +}; -KERNEL_FLOAT_DEFINE_BINARY_ASSIGN_OP(add, +=) -KERNEL_FLOAT_DEFINE_BINARY_ASSIGN_OP(subtract, -=) -KERNEL_FLOAT_DEFINE_BINARY_ASSIGN_OP(divide, /=) -KERNEL_FLOAT_DEFINE_BINARY_ASSIGN_OP(multiply, *=) -KERNEL_FLOAT_DEFINE_BINARY_ASSIGN_OP(modulo, %=) -KERNEL_FLOAT_DEFINE_BINARY_ASSIGN_OP(bit_and, &=) -KERNEL_FLOAT_DEFINE_BINARY_ASSIGN_OP(bit_or, |=) -KERNEL_FLOAT_DEFINE_BINARY_ASSIGN_OP(bit_xor, ^=) +template +struct broadcast_extent_helper, extent<1>> { + using type = extent; +}; -#define KERNEL_FLOAT_DEFINE_BINARY_FUN(NAME) KERNEL_FLOAT_DEFINE_BINARY(NAME, ::NAME(left, right)) +template<> +struct broadcast_extent_helper, extent<1>> { + using type = extent<1>; +}; -KERNEL_FLOAT_DEFINE_BINARY_FUN(min) -KERNEL_FLOAT_DEFINE_BINARY_FUN(max) -KERNEL_FLOAT_DEFINE_BINARY_FUN(copysign) -KERNEL_FLOAT_DEFINE_BINARY_FUN(hypot) -KERNEL_FLOAT_DEFINE_BINARY_FUN(modf) +template +struct broadcast_extent_helper: + broadcast_extent_helper::type, C, Rest...> {}; + +} // namespace detail + +template +using broadcast_extent = typename detail::broadcast_extent_helper::type; + +template +using broadcast_vector_extent_type = broadcast_extent...>; + +template +static constexpr bool is_broadcastable = is_same_type, To>; + +template +static constexpr bool is_vector_broadcastable = is_broadcastable, To>; + +namespace detail { + +template +struct broadcast_impl; + +template +struct broadcast_impl, extent> { + KERNEL_FLOAT_INLINE static vector_storage call(const vector_storage& input) { + vector_storage output; + for (size_t i = 0; i < N; i++) { + output.data()[i] = input.data()[0]; + } + return output; + } +}; + +template +struct broadcast_impl, extent> { + KERNEL_FLOAT_INLINE static vector_storage call(vector_storage input) { + return input; + } +}; + +template +struct broadcast_impl, extent<1>> { + KERNEL_FLOAT_INLINE static vector_storage call(vector_storage input) { + return input; + } +}; + +} // namespace detail + +/** + * Takes the given vector `input` and extends its size to a length of `N`. This is only valid if the size of `input` + * is 1 or `N`. + * + * Example + * ======= + * ``` + * vec a = {1.0f}; + * vec x = broadcast<5>(a); // Returns [1.0f, 1.0f, 1.0f, 1.0f, 1.0f] + * + * vec b = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; + * vec y = broadcast<5>(b); // Returns [1.0f, 2.0f, 3.0f, 4.0f, 5.0f] + * ``` + */ +template +KERNEL_FLOAT_INLINE vector, extent> +broadcast(const V& input, extent new_size = {}) { + using T = vector_value_type; + return detail::broadcast_impl, extent>::call( + into_vector_storage(input)); +} + +/** + * Takes the given vector `input` and extends its size to the same length as vector `other`. This is only valid if the + * size of `input` is 1 or the same as `other`. + */ +template +KERNEL_FLOAT_INLINE vector, vector_extent_type> +broadcast_like(const V& input, const R& other) { + return broadcast(input, vector_extent_type {}); +} + +namespace detail { +/** + * Convert vector of element type `T` and extent type `E` to vector of element type `T2` and extent type `E2`. + * Specialization exist for the cases where `T==T2` and/or `E==E2`. + */ +template +struct convert_impl { + KERNEL_FLOAT_INLINE + static vector_storage call(vector_storage input) { + using F = ops::cast; + vector_storage intermediate; + detail::apply_impl::call(F {}, intermediate.data(), input.data()); + return detail::broadcast_impl::call(intermediate); + } +}; + +// T == T2, E == E2 +template +struct convert_impl { + KERNEL_FLOAT_INLINE + static vector_storage call(vector_storage input) { + return input; + } +}; + +// T == T2, E != E2 +template +struct convert_impl { + KERNEL_FLOAT_INLINE + static vector_storage call(vector_storage input) { + return detail::broadcast_impl::call(input); + } +}; + +// T != T2, E == E2 +template +struct convert_impl { + KERNEL_FLOAT_INLINE + static vector_storage call(vector_storage input) { + using F = ops::cast; + + vector_storage result; + detail::apply_impl::call(F {}, result.data(), input.data()); + return result; + } +}; +} // namespace detail + +template +KERNEL_FLOAT_INLINE vector_storage convert_storage(const V& input, extent new_size = {}) { + return detail::convert_impl, vector_extent_type, R, extent, M>::call( + into_vector_storage(input)); +} + +/** + * Cast the values of the given input vector to type `R` and then broadcast the result to the given size `N`. + * + * Example + * ======= + * ``` + * int a = 5; + * vec x = convert(a); // returns [5.0f, 5.0f, 5.0f] + * + * float b = 5.0f; + * vec x = convert(b); // returns [5.0f, 5.0f, 5.0f] + * + * vec c = {1, 2, 3}; + * vec x = convert(c); // returns [1.0f, 2.0f, 3.0f] + * ``` + */ +template +KERNEL_FLOAT_INLINE vector> convert(const V& input, extent new_size = {}) { + return convert_storage(input); +} + +/** + * Returns a vector containing `N` copies of `value`. + * + * Example + * ======= + * ``` + * vec a = fill<3>(42); // return [42, 42, 42] + * ``` + */ +template +KERNEL_FLOAT_INLINE vector> fill(T value = {}, extent = {}) { + vector_storage input = {value}; + return detail::broadcast_impl, extent>::call(input); +} + +/** + * Returns a vector containing `N` copies of `T(0)`. + * + * Example + * ======= + * ``` + * vec a = zeros(); // return [0, 0, 0] + * ``` + */ +template +KERNEL_FLOAT_INLINE vector> zeros(extent = {}) { + vector_storage input = {T {}}; + return detail::broadcast_impl, extent>::call(input); +} + +/** + * Returns a vector containing `N` copies of `T(1)`. + * + * Example + * ======= + * ``` + * vec a = ones(); // return [1, 1, 1] + * ``` + */ +template +KERNEL_FLOAT_INLINE vector> ones(extent = {}) { + vector_storage input = {T {1}}; + return detail::broadcast_impl, extent>::call(input); +} + +/** + * Returns a vector filled with `value` having the same type and size as input vector `V`. + * + * Example + * ======= + * ``` + * vec a = {1, 2, 3}; + * vec b = fill_like(a, 42); // return [42, 42, 42] + * ``` + */ +template, typename E = vector_extent_type> +KERNEL_FLOAT_INLINE vector fill_like(const V&, T value) { + return fill(value, E {}); +} + +/** + * Returns a vector filled with zeros having the same type and size as input vector `V`. + * + * Example + * ======= + * ``` + * vec a = {1, 2, 3}; + * vec b = zeros_like(a); // return [0, 0, 0] + * ``` + */ +template, typename E = vector_extent_type> +KERNEL_FLOAT_INLINE vector zeros_like(const V& = {}) { + return zeros(E {}); +} + +/** + * Returns a vector filled with ones having the same type and size as input vector `V`. + * + * Example + * ======= + * ``` + * vec a = {1, 2, 3}; + * vec b = ones_like(a); // return [1, 1, 1] + * ``` + */ +template, typename E = vector_extent_type> +KERNEL_FLOAT_INLINE vector ones_like(const V& = {}) { + return ones(E {}); +} + +} // namespace kernel_float + +#endif +#ifndef KERNEL_FLOAT_BINOPS_H +#define KERNEL_FLOAT_BINOPS_H + + + + +namespace kernel_float { + +template +using zip_type = vector< + result_t, vector_value_type>, + broadcast_vector_extent_type>; + +/** + * Combines the elements from the two inputs (`left` and `right`) element-wise, applying a provided binary + * function (`fun`) to each pair of corresponding elements. + * + * Example + * ======= + * ``` + * vec make_negative = {true, false, true}; + * vec input = {1, 2, 3}; + * vec output = zip([](bool b, int n){ return b ? -n : +n; }, make_negative, input); // returns [-1, 2, -3] + * ``` + */ +template +KERNEL_FLOAT_INLINE zip_type zip(F fun, const L& left, const R& right) { + using A = vector_value_type; + using B = vector_value_type; + using O = result_t; + using E = broadcast_vector_extent_type; + vector_storage result; + + detail::apply_impl::call( + fun, + result.data(), + detail::broadcast_impl, E>::call(into_vector_storage(left)).data(), + detail::broadcast_impl, E>::call(into_vector_storage(right)) + .data()); + + return result; +} + +template +using zip_common_type = vector< + result_t, promoted_vector_value_type>, + broadcast_vector_extent_type>; + +/** + * Combines the elements from the two inputs (`left` and `right`) element-wise, applying a provided binary + * function (`fun`) to each pair of corresponding elements. The elements are promoted to a common type before applying + * the binary function. + * + * Example + * ======= + * ``` + * vec a = {1.0f, 2.0f, 3.0f}; + * vec b = {4, 5, 6}; + * vec c = zip_common([](float x, float y){ return x + y; }, a, b); // returns [5.0f, 7.0f, 9.0f] + * ``` + */ +template +KERNEL_FLOAT_INLINE zip_common_type zip_common(F fun, const L& left, const R& right) { + using T = promoted_vector_value_type; + using O = result_t; + using E = broadcast_vector_extent_type; + + vector_storage result; + + detail::apply_impl::call( + fun, + result.data(), + detail::convert_impl, vector_extent_type, T, E>::call( + into_vector_storage(left)) + .data(), + detail::convert_impl, vector_extent_type, T, E>::call( + into_vector_storage(right)) + .data()); + + return result; +} + +#define KERNEL_FLOAT_DEFINE_BINARY(NAME, EXPR) \ + namespace ops { \ + template \ + struct NAME { \ + KERNEL_FLOAT_INLINE T operator()(T left, T right) { \ + return T(EXPR); \ + } \ + }; \ + } \ + template> \ + KERNEL_FLOAT_INLINE zip_common_type, L, R> NAME(L&& left, R&& right) { \ + return zip_common(ops::NAME {}, std::forward(left), std::forward(right)); \ + } + +#define KERNEL_FLOAT_DEFINE_BINARY_OP(NAME, OP) \ + KERNEL_FLOAT_DEFINE_BINARY(NAME, left OP right) \ + template, typename E1, typename E2> \ + KERNEL_FLOAT_INLINE zip_common_type, vector, vector> operator OP( \ + const vector& left, \ + const vector& right) { \ + return zip_common(ops::NAME {}, left, right); \ + } \ + template>, typename E> \ + KERNEL_FLOAT_INLINE zip_common_type, vector, R> operator OP( \ + const vector& left, \ + const R& right) { \ + return zip_common(ops::NAME {}, left, right); \ + } \ + template, R>, typename E> \ + KERNEL_FLOAT_INLINE zip_common_type, L, vector> operator OP( \ + const L& left, \ + const vector& right) { \ + return zip_common(ops::NAME {}, left, right); \ + } + +KERNEL_FLOAT_DEFINE_BINARY_OP(add, +) +KERNEL_FLOAT_DEFINE_BINARY_OP(subtract, -) +KERNEL_FLOAT_DEFINE_BINARY_OP(divide, /) +KERNEL_FLOAT_DEFINE_BINARY_OP(multiply, *) +KERNEL_FLOAT_DEFINE_BINARY_OP(modulo, %) + +KERNEL_FLOAT_DEFINE_BINARY_OP(equal_to, ==) +KERNEL_FLOAT_DEFINE_BINARY_OP(not_equal_to, !=) +KERNEL_FLOAT_DEFINE_BINARY_OP(less, <) +KERNEL_FLOAT_DEFINE_BINARY_OP(less_equal, <=) +KERNEL_FLOAT_DEFINE_BINARY_OP(greater, >) +KERNEL_FLOAT_DEFINE_BINARY_OP(greater_equal, >=) + +KERNEL_FLOAT_DEFINE_BINARY_OP(bit_and, &) +KERNEL_FLOAT_DEFINE_BINARY_OP(bit_or, |) +KERNEL_FLOAT_DEFINE_BINARY_OP(bit_xor, ^) + +// clang-format off +template typename F, typename T, typename E, typename R> +static constexpr bool is_vector_assign_allowed = + is_vector_broadcastable && + is_implicit_convertible< + result_t< + F>>, + T, + vector_value_type + >, + T + >; +// clang-format on + +#define KERNEL_FLOAT_DEFINE_BINARY_ASSIGN_OP(NAME, OP) \ + template< \ + typename T, \ + typename E, \ + typename R, \ + typename = enable_if_t>> \ + KERNEL_FLOAT_INLINE vector& operator OP(vector& lhs, const R& rhs) { \ + using F = ops::NAME; \ + lhs = zip_common(F {}, lhs, rhs); \ + return lhs; \ + } + +KERNEL_FLOAT_DEFINE_BINARY_ASSIGN_OP(add, +=) +KERNEL_FLOAT_DEFINE_BINARY_ASSIGN_OP(subtract, -=) +KERNEL_FLOAT_DEFINE_BINARY_ASSIGN_OP(divide, /=) +KERNEL_FLOAT_DEFINE_BINARY_ASSIGN_OP(multiply, *=) +KERNEL_FLOAT_DEFINE_BINARY_ASSIGN_OP(modulo, %=) +KERNEL_FLOAT_DEFINE_BINARY_ASSIGN_OP(bit_and, &=) +KERNEL_FLOAT_DEFINE_BINARY_ASSIGN_OP(bit_or, |=) +KERNEL_FLOAT_DEFINE_BINARY_ASSIGN_OP(bit_xor, ^=) + +#define KERNEL_FLOAT_DEFINE_BINARY_FUN(NAME) KERNEL_FLOAT_DEFINE_BINARY(NAME, ::NAME(left, right)) + +KERNEL_FLOAT_DEFINE_BINARY_FUN(min) +KERNEL_FLOAT_DEFINE_BINARY_FUN(max) +KERNEL_FLOAT_DEFINE_BINARY_FUN(copysign) +KERNEL_FLOAT_DEFINE_BINARY_FUN(modf) KERNEL_FLOAT_DEFINE_BINARY_FUN(nextafter) KERNEL_FLOAT_DEFINE_BINARY_FUN(pow) KERNEL_FLOAT_DEFINE_BINARY_FUN(remainder) -#if KERNEL_FLOAT_CUDA_DEVICE -KERNEL_FLOAT_DEFINE_BINARY_FUN(rhypot) +KERNEL_FLOAT_DEFINE_BINARY(hypot, (ops::sqrt()(left * left + right * right))) +KERNEL_FLOAT_DEFINE_BINARY(rhypot, (T(1) / ops::hypot()(left, right))) + +namespace ops { +template<> +struct hypot { + KERNEL_FLOAT_INLINE double operator()(double left, double right) { + return ::hypot(left, right); + }; +}; + +template<> +struct hypot { + KERNEL_FLOAT_INLINE float operator()(float left, float right) { + return ::hypotf(left, right); + }; +}; + +// rhypot is only support on the GPU +#if KERNEL_FLOAT_IS_DEVICE +template<> +struct rhypot { + KERNEL_FLOAT_INLINE double operator()(double left, double right) { + return ::rhypot(left, right); + }; +}; + +template<> +struct rhypot { + KERNEL_FLOAT_INLINE float operator()(float left, float right) { + return ::rhypotf(left, right); + }; +}; +#endif +}; // namespace ops + +#if KERNEL_FLOAT_IS_DEVICE +#define KERNEL_FLOAT_DEFINE_BINARY_FAST(FUN_NAME, OP_NAME, FLOAT_FUN) \ + KERNEL_FLOAT_DEFINE_BINARY(FUN_NAME, ops::OP_NAME {}(left, right)) \ + namespace ops { \ + template<> \ + struct OP_NAME { \ + KERNEL_FLOAT_INLINE float operator()(float left, float right) { \ + return FLOAT_FUN(left, right); \ + } \ + }; \ + } +#else +#define KERNEL_FLOAT_DEFINE_BINARY_FAST(FUN_NAME, OP_NAME, FLOAT_FUN) \ + KERNEL_FLOAT_DEFINE_BINARY(FUN_NAME, ops::OP_NAME {}(left, right)) #endif +KERNEL_FLOAT_DEFINE_BINARY_FAST(fast_div, divide, __fdividef) +KERNEL_FLOAT_DEFINE_BINARY_FAST(fast_pow, pow, __powf) + namespace ops { template<> struct add { @@ -1622,9 +1701,158 @@ struct bit_xor { }; }; // namespace ops +namespace detail { +template +struct cross_impl { + KERNEL_FLOAT_INLINE + static vector> + call(const vector_storage& av, const vector_storage& bv) { + auto a = av.data(); + auto b = bv.data(); + vector> v0 = {a[1], a[2], a[0], a[2], a[0], a[1]}; + vector> v1 = {b[2], b[0], b[1], b[1], b[2], b[0]}; + vector> rv = v0 * v1; + + auto r = rv.data(); + vector> r0 = {r[0], r[1], r[2]}; + vector> r1 = {r[3], r[4], r[5]}; + return r0 - r1; + } +}; +}; // namespace detail + +/** + * Calculates the cross-product between two vectors of length 3. + */ +template< + typename L, + typename R, + typename T = promoted_vector_value_type, + typename = + enable_if_t> && is_vector_broadcastable>>> +KERNEL_FLOAT_INLINE vector> cross(const L& left, const R& right) { + return detail::cross_impl::call(convert_storage(left), convert_storage(right)); +} + } // namespace kernel_float -#endif //KERNEL_FLOAT_BINOPS_H +#endif +#ifndef KERNEL_FLOAT_CONSTANT +#define KERNEL_FLOAT_CONSTANT + + + + +namespace kernel_float { + +template +struct constant { + template + KERNEL_FLOAT_INLINE explicit constexpr constant(const constant& that) { + auto f = ops::cast(); + value_ = f(that.get()); + } + + KERNEL_FLOAT_INLINE + constexpr constant(T value = {}) : value_(value) {} + + KERNEL_FLOAT_INLINE + constexpr T get() const { + return value_; + } + + KERNEL_FLOAT_INLINE + constexpr operator T() const { + return value_; + } + + private: + T value_; +}; + +// Deduction guide for `constant` +#if defined(__cpp_deduction_guides) +template +constant(T&&) -> constant>; +#endif + +template +KERNEL_FLOAT_INLINE constexpr constant make_constant(T value) { + return value; +} + +template +struct promote_type, constant> { + using type = constant::type>; +}; + +template +struct promote_type, R> { + using type = typename promote_type::type; +}; + +template +struct promote_type> { + using type = typename promote_type::type; +}; + +namespace ops { +template +struct cast, R> { + KERNEL_FLOAT_INLINE R operator()(const T& input) noexcept { + return cast {}(input); + } +}; + +template +struct cast, R, m> { + KERNEL_FLOAT_INLINE R operator()(const T& input) noexcept { + return cast {}(input); + } +}; +} // namespace ops + +#define KERNEL_FLOAT_CONSTANT_DEFINE_OP(OP) \ + template \ + KERNEL_FLOAT_INLINE auto operator OP(const constant& left, const R& right) { \ + auto f = ops::cast>(); \ + return f(left.get()) OP right; \ + } \ + \ + template \ + KERNEL_FLOAT_INLINE auto operator OP(const L& left, const constant& right) { \ + auto f = ops::cast>(); \ + return left OP f(right.get()); \ + } \ + \ + template \ + KERNEL_FLOAT_INLINE auto operator OP(const constant& left, const vector& right) { \ + auto f = ops::cast(); \ + return f(left.get()) OP right; \ + } \ + \ + template \ + KERNEL_FLOAT_INLINE auto operator OP(const vector& left, const constant& right) { \ + auto f = ops::cast(); \ + return left OP f(right.get()); \ + } \ + \ + template> \ + KERNEL_FLOAT_INLINE constant operator OP( \ + const constant& left, \ + const constant& right) { \ + return constant(left.get()) OP constant(right.get()); \ + } + +KERNEL_FLOAT_CONSTANT_DEFINE_OP(+) +KERNEL_FLOAT_CONSTANT_DEFINE_OP(-) +KERNEL_FLOAT_CONSTANT_DEFINE_OP(*) +KERNEL_FLOAT_CONSTANT_DEFINE_OP(/) +KERNEL_FLOAT_CONSTANT_DEFINE_OP(%) + +} // namespace kernel_float + +#endif #ifndef KERNEL_FLOAT_ITERATE_H #define KERNEL_FLOAT_ITERATE_H @@ -1633,173 +1861,569 @@ struct bit_xor { namespace kernel_float { +/** + * Apply the function fun for each element from input. + * + * Example + * ======= + * ``` + * for_each(range(), [&](auto i) { + * printf("element: %d\n", i); + * }); + * ``` + */ +template +void for_each(V&& input, F fun) { + auto storage = into_vector_storage(input); + +#pragma unroll + for (size_t i = 0; i < vector_extent; i++) { + fun(storage.data()[i]); + } +} + namespace detail { -template>> -struct range_helper; +template +struct range_impl { + KERNEL_FLOAT_INLINE + static vector_storage call() { + vector_storage result; + +#pragma unroll + for (size_t i = 0; i < N; i++) { + result.data()[i] = T(i); + } -template -struct range_helper> { - KERNEL_FLOAT_INLINE static V call(F fun) { - return vector_traits::create(fun(const_index {})...); + return result; } }; } // namespace detail /** - * Generate vector of length ``N`` by applying the given function ``fun`` to - * each index ``0...N-1``. + * Generate vector consisting of the numbers `0...N-1` of type `T` * * Example * ======= * ``` - * // returns [0, 2, 4] - * vector vec = range<3>([](auto i) { return float(i * 2); }); + * // Returns [0, 1, 2] + * vec vec = range(); * ``` */ -template< - size_t N, - typename F, - typename T = result_t, - typename Output = default_storage_type> -KERNEL_FLOAT_INLINE vector range(F fun) { - return detail::range_helper::call(fun); +template +KERNEL_FLOAT_INLINE vector> range() { + return detail::range_impl::call(); } /** - * Generate vector consisting of the numbers ``0...N-1`` of type ``T``. + * Takes a vector `vec` and returns a new vector consisting of the numbers ``0...N-1`` of type ``T`` * * Example * ======= * ``` - * // Returns [0, 1, 2] - * vector vec = range(); + * auto input = vec(5.0f, 10.0f, -1.0f); + * auto indices = range_like(input); // returns [0.0f, 1.0f, 2.0f] + * ``` + */ +template +KERNEL_FLOAT_INLINE into_vector_type range_like(const V& = {}) { + return detail::range_impl, vector_extent>::call(); +} + +/** + * Takes a vector of size ``N`` and returns a new vector consisting of the numbers ``0...N-1``. The data type used + * for the indices is given by the first template argument, which is `size_t` by default. This function is useful when + * needing to iterate over the indices of a vector. + * + * Example + * ======= + * ``` + * // Returns [0, 1, 2] of type size_t + * vec a = each_index(float3(6, 4, 2)); + * + * // Returns [0, 1, 2] of type int. + * vec b = each_index(float3(6, 4, 2)); + * + * vec input = {1.0f, 2.0f, 3.0f, 4.0f}; + * for (auto index: each_index(input)) { + * printf("%d] %f\n", index, input[index]); + * } + * ``` + */ +template +KERNEL_FLOAT_INLINE vector> each_index(const V& = {}) { + return detail::range_impl>::call(); +} + +namespace detail { +template, size_t N = vector_extent> +struct flatten_impl { + using value_type = typename flatten_impl::value_type; + static constexpr size_t size = N * flatten_impl::size; + + template + KERNEL_FLOAT_INLINE static void call(U* output, const V& input) { + vector_storage storage = into_vector_storage(input); + +#pragma unroll + for (size_t i = 0; i < N; i++) { + flatten_impl::call(output + flatten_impl::size * i, storage.data()[i]); + } + } +}; + +template +struct flatten_impl { + using value_type = T; + static constexpr size_t size = 1; + + KERNEL_FLOAT_INLINE + static void call(T* output, const T& input) { + *output = input; + } + + template + KERNEL_FLOAT_INLINE static void call(U* output, const T& input) { + *output = ops::cast {}(input); + } +}; +} // namespace detail + +template +using flatten_value_type = typename detail::flatten_impl::value_type; + +template +static constexpr size_t flatten_size = detail::flatten_impl::size; + +template +using flatten_type = vector, extent>>; + +/** + * Flattens the elements of this vector. For example, this turns a `vec, 3>` into a `vec`. + * + * Example + * ======= + * ``` + * vec input = {{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}; + * vec result = flatten(input); // returns [1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f] * ``` */ -template> -KERNEL_FLOAT_INLINE vector range() { - using F = ops::cast; - return detail::range_helper::call(F {}); +template +KERNEL_FLOAT_INLINE flatten_type flatten(const V& input) { + vector_storage, flatten_size> output; + detail::flatten_impl::call(output.data(), input); + return output; } -/** - * Generate vector having same size and type as ``V``, but filled with the numbers ``0..N-1``. - */ -template> -KERNEL_FLOAT_INLINE vector range_like(const Input&) { - using F = ops::cast>; - return detail::range_helper::call(F {}); -} +namespace detail { +template> +struct concat_base_impl { + static constexpr size_t size = vector_extent; + + KERNEL_FLOAT_INLINE static void call(U* output, const V& input) { + vector_storage storage = into_vector_storage(input); + + for (size_t i = 0; i < size; i++) { + output[i] = ops::cast {}(storage.data()[i]); + } + } +}; + +template +struct concat_base_impl { + static constexpr size_t size = 1; + + KERNEL_FLOAT_INLINE static void call(U* output, const T& input) { + *output = ops::cast {}(input); + } +}; + +template +struct concat_base_impl { + static constexpr size_t size = 1; + + KERNEL_FLOAT_INLINE static void call(T* output, const T& input) { + *output = input; + } +}; + +template +struct concat_impl {}; + +template +struct concat_impl { + using value_type = + typename promote_type, typename concat_impl::value_type>::type; + static constexpr size_t size = concat_base_impl::size + concat_impl::size; + + template + KERNEL_FLOAT_INLINE static void call(U* output, const V& input, const Vs&... rest) { + concat_base_impl::call(output, input); + concat_impl::call(output + concat_base_impl::size, rest...); + } +}; + +template<> +struct concat_impl<> { + using value_type = void; + static constexpr size_t size = 1; + + template + KERNEL_FLOAT_INLINE static void call(U* output) {} +}; +} // namespace detail + +template +using concat_value_type = promote_t::value_type>; + +template +static constexpr size_t concat_size = detail::concat_impl::size; + +template +using concat_type = vector, extent>>; + +/** + * Concatenates the provided input values into a single one-dimensional vector. + * + * This function works in three steps: + * - All input values are converted into vectors using the `into_vector` operation. + * - The resulting vectors' elements are then promoted into a shared value type. + * - The resultant vectors are finally concatenated together. + * + * For instance, when invoking this function with arguments of types `float, double2, double`: + * - After the first step: `vec, vec, vec` + * - After the second step: `vec, vec, vec` + * - After the third step: `vec` + * + * Example + * ======= + * ``` + * double vec1 = 1.0; + * double3 vec2 = {3.0, 4.0, 5.0); + * double4 vec3 = {6.0, 7.0, 8.0, 9.0}; + * vec concatenated = concat(vec1, vec2, vec3); // contains [1, 2, 3, 4, 5, 6, 7, 8, 9] + * + * int num1 = 42; + * float num2 = 3.14159; + * int2 num3 = {-10, 10}; + * vec concatenated = concat(num1, num2, num3); // contains [42, 3.14159, -10, 10] + * ``` + */ +template +KERNEL_FLOAT_INLINE concat_type concat(const Vs&... inputs) { + vector_storage, concat_size> output; + detail::concat_impl::call(output.data(), inputs...); + return output; +} + +template +using select_type = vector, extent>>; + +/** + * Selects elements from the this vector based on the specified indices. + * + * Example + * ======= + * ``` + * vec input = {0, 10, 20, 30, 40, 50}; + * vec vec1 = select(input, 0, 4, 4, 2); // [0, 40, 40, 20] + * + * vec indices = {0, 4, 4, 2}; + * vec vec2 = select(input, indices); // [0, 40, 40, 20] + * ``` + */ +template +KERNEL_FLOAT_INLINE select_type select(const V& input, const Is&... indices) { + using T = vector_value_type; + static constexpr size_t N = vector_extent; + static constexpr size_t M = concat_size; + + vector_storage index_set; + detail::concat_impl::call(index_set.data(), indices...); + + vector_storage inputs = into_vector_storage(input); + vector_storage outputs; + for (size_t i = 0; i < M; i++) { + size_t j = index_set.data()[i]; + + if (j < N) { + outputs.data()[i] = inputs.data()[j]; + } + } + + return outputs; +} + +} // namespace kernel_float + +#endif +#ifndef KERNEL_FLOAT_MEMORY_H +#define KERNEL_FLOAT_MEMORY_H + +/* + + + + +namespace kernel_float { + + namespace detail { + template > + struct load_helper; + + template + struct load_helper> { + KERNEL_FLOAT_INLINE + vector_storage call( + T* base, + vector_storage offsets + ) { + return {base[offsets.data()[Is]]...}; + } + + KERNEL_FLOAT_INLINE + vector_storage call( + T* base, + vector_storage offsets, + vector_storage mask + ) { + if (all(mask)) { + return call(base, offsets); + } else { + return { + (mask.data()[Is] ? base[offsets.data()[Is]] : T())... + }; + } + } + }; + } + + template < + typename T, + typename I, + typename M, + typename E = broadcast_vector_extent_type + > + KERNEL_FLOAT_INLINE + vector load(const T* ptr, const I& indices, const M& mask) { + static constexpr E new_size = {}; + + return detail::load_helper::call( + ptr, + convert_storage(indices, new_size), + convert_storage(mask, new_size) + ); + } + + template + KERNEL_FLOAT_INLINE + vector> load(const T* ptr, const I& indices) { + return detail::load_helper::value>::call( + ptr, + cast(indices) + ); + } + + template + KERNEL_FLOAT_INLINE + vector> load(const T* ptr, ptrdiff_t length) { + using index_type = vector_value_type; + return load_masked(ptr, range(), range() < length); + } + + template + KERNEL_FLOAT_INLINE + vector> load(const T* ptr) { + return load(ptr, range()); + } + + namespace detail { + template + struct store_helper { + KERNEL_FLOAT_INLINE + vector_storage call( + T* base, + vector_storage offsets, + vector_storage mask, + vector_storage values + ) { + for (size_t i = 0; i < N; i++) { + if (mask.data()[i]) { + base[offset.data()[i]] = values.data()[i]; + } + } + } + + KERNEL_FLOAT_INLINE + vector_storage call( + T* base, + vector_storage offsets, + vector_storage values + ) { + for (size_t i = 0; i < N; i++) { + base[offset.data()[i]] = values.data()[i]; + } + } + }; + } + + template < + typename T, + typename I, + typename M, + typename V, + typename E = broadcast_extent, broadcast_vector_extent_type>> + > + KERNEL_FLOAT_INLINE + void store(const T* ptr, const I& indices, const M& mask, const V& values) { + static constexpr E new_size = {}; + + return detail::store_helper::call( + ptr, + convert_storage(indices, new_size), + convert_storage(mask, new_size), + convert_storage(values, new_size) + ); + } + + template < + typename T, + typename I, + typename V, + typename E = broadcast_vector_extent_type + > + KERNEL_FLOAT_INLINE + void store(const T* ptr, const I& indices, const V& values) { + static constexpr E new_size = {}; + + return detail::store_helper::call( + ptr, + convert_storage(indices, new_size), + convert_storage(values, new_size) + ); + } + + + template < + typename T, + typename V + > + KERNEL_FLOAT_INLINE + void store(const T* ptr, const V& values) { + using E = vector_extent; + return store(ptr, range(), values); + } + + template + KERNEL_FLOAT_INLINE + void store(const T* ptr, const I& indices, const S& length, const V& values) { + using index_type = vector_value_type; + return store(ptr, indices, (indices >= I(0)) & (indices < length), values); + } + + + template + struct aligned_ptr_base { + static_assert(alignof(T) % alignment == 0, "invalid alignment, must be multiple of alignment of `T`"); + + KERNEL_FLOAT_INLINE + aligned_ptr_base(): ptr_(nullptr) {} + + KERNEL_FLOAT_INLINE + explicit aligned_ptr_base(T* ptr): ptr_(ptr) {} + + KERNEL_FLOAT_INLINE + T* get() const { + // TOOD: check if this way is support across all compilers +#if defined(__has_builtin) && __has_builtin(__builtin_assume_aligned) + return __builtin_assume_aligned(ptr_, alignment); +#else + return ptr_; +#endif + } + + KERNEL_FLOAT_INLINE + operator T*() const { + return get(); + } + + KERNEL_FLOAT_INLINE + T& operator*() const { + return *get(); + } + + template + KERNEL_FLOAT_INLINE + T& operator[](I index) const { + return get()[index); + } + + private: + T* ptr_ = nullptr; + }; -/** - * Generate vector of `N` elements of type `T` - * - * Example - * ======= - * ``` - * // Returns [1.0, 1.0, 1.0] - * vector = fill(1.0f); - * ``` - */ -template> -KERNEL_FLOAT_INLINE vector fill(T value) { - return vector_traits::fill(value); -} + template + struct aligned_ptr; -/** - * Generate vector having same size and type as ``V``, but filled with the given ``value``. - */ -template -KERNEL_FLOAT_INLINE vector fill_like(const Output&, vector_value_type value) { - return vector_traits::fill(value); -} + template + struct aligned_ptr: aligned_ptr_base { + using base_type = aligned_ptr_base; -/** - * Generate vector of ``N`` zeros of type ``T`` - * - * Example - * ======= - * ``` - * // Returns [0.0, 0.0, 0.0] - * vector = zeros(); - * ``` - */ -template> -KERNEL_FLOAT_INLINE vector zeros() { - return vector_traits::fill(T(0)); -} + KERNEL_FLOAT_INLINE + aligned_ptr(): base_type(nullptr) {} -/** - * Generate vector having same size and type as ``V``, but filled with zeros. - * - */ -template -KERNEL_FLOAT_INLINE vector zeros_like(const Output& output = {}) { - return vector_traits::fill(0); -} + KERNEL_FLOAT_INLINE + explicit aligned_ptr(T* ptr): base_type(ptr) {} -/** - * Generate vector of ``N`` ones of type ``T`` - * - * Example - * ======= - * ``` - * // Returns [1.0, 1.0, 1.0] - * vector = ones(); - * ``` - */ -template> -KERNEL_FLOAT_INLINE vector ones() { - return vector_traits::fill(T(1)); -} + KERNEL_FLOAT_INLINE + aligned_ptr(aligned_ptr ptr): base_type(ptr.get()) {} + }; -/** - * Generate vector having same size and type as ``V``, but filled with ones. - * - */ -template -KERNEL_FLOAT_INLINE vector ones_like(const Output& output = {}) { - return vector_traits::fill(1); -} + template + struct aligned_ptr: aligned_ptr_base { + using base_type = aligned_ptr_base; -namespace detail { -template>> -struct iterate_helper; + KERNEL_FLOAT_INLINE + aligned_ptr(): base_type(nullptr) {} -template -struct iterate_helper> { + KERNEL_FLOAT_INLINE + explicit aligned_ptr(T* ptr): base_type(ptr) {} + + KERNEL_FLOAT_INLINE + explicit aligned_ptr(const T* ptr): base_type(ptr) {} + + KERNEL_FLOAT_INLINE + aligned_ptr(aligned_ptr ptr): base_type(ptr.get()) {} + + KERNEL_FLOAT_INLINE + aligned_ptr(aligned_ptr ptr): base_type(ptr.get()) {} + }; + + + template KERNEL_FLOAT_INLINE - static void call(F fun, const V& input) {} -}; + T* operator+(aligned_ptr ptr, ptrdiff_t index) { + return ptr.get() + index; + } -template -struct iterate_helper> { + template KERNEL_FLOAT_INLINE - static void call(F fun, const V& input) { - fun(vector_get(input)); - iterate_helper>::call(fun, input); + T* operator+(ptrdiff_t index, aligned_ptr ptr) { + return ptr.get() + index; } -}; -} // namespace detail -/** - * Apply the function ``fun`` for each element from ``input``. - * - * Example - * ======= - * ``` - * for_each(range<3>(), [&](auto i) { - * printf("element: %d\n", i); - * }); - * ``` - */ -template -KERNEL_FLOAT_INLINE void for_each(const V& input, F fun) { - detail::iterate_helper>::call(fun, into_storage(input)); -} + template + KERNEL_FLOAT_INLINE + ptrdiff_t operator-(aligned_ptr left, aligned_ptr right) { + return left.get() - right.get(); + } -} // namespace kernel_float + template + using unaligned_ptr = aligned_ptr; + +} +*/ -#endif //KERNEL_FLOAT_ITERATE_H +#endif //KERNEL_FLOAT_MEMORY_H #ifndef KERNEL_FLOAT_REDUCE_H #define KERNEL_FLOAT_REDUCE_H @@ -1807,29 +2431,21 @@ KERNEL_FLOAT_INLINE void for_each(const V& input, F fun) { namespace kernel_float { namespace detail { -template -struct reduce_helper { - using value_type = vector_value_type; - - KERNEL_FLOAT_INLINE static value_type call(F fun, const V& input) { - return call(fun, input, make_index_sequence> {}); +template +struct reduce_impl { + KERNEL_FLOAT_INLINE static T call(F fun, const T* input) { + return call(fun, input, make_index_sequence {}); } private: template - KERNEL_FLOAT_INLINE static value_type call(F fun, const V& vector, index_sequence<0, Is...>) { - return call(fun, vector, vector_get<0>(vector), index_sequence {}); - } - - template - KERNEL_FLOAT_INLINE static value_type - call(F fun, const V& vector, value_type accum, index_sequence) { - return call(fun, vector, fun(accum, vector_get(vector)), index_sequence {}); - } - - KERNEL_FLOAT_INLINE static value_type - call(F fun, const V& vector, value_type accum, index_sequence<>) { - return accum; + KERNEL_FLOAT_INLINE static T call(F fun, const T* input, index_sequence<0, Is...>) { + T result = input[0]; +#pragma unroll + for (size_t i = 1; i < N; i++) { + result = fun(result, input[i]); + } + return result; } }; } // namespace detail @@ -1838,7 +2454,7 @@ struct reduce_helper { * Reduce the elements of the given vector ``input`` into a single value using * the function ``fun``. This function should be a binary function that takes * two elements and returns one element. The order in which the elements - * are reduced is not specified and depends on the reduction function and + * are reduced is not specified and depends on both the reduction function and * the vector type. * * Example @@ -1850,7 +2466,9 @@ struct reduce_helper { */ template KERNEL_FLOAT_INLINE vector_value_type reduce(F fun, const V& input) { - return detail::reduce_helper>::call(fun, into_storage(input)); + return detail::reduce_impl, vector_value_type>::call( + fun, + into_vector_storage(input).data()); } /** @@ -1859,7 +2477,7 @@ KERNEL_FLOAT_INLINE vector_value_type reduce(F fun, const V& input) { * Example * ======= * ``` - * vec x = {5, 0, 2, 1, 0}; + * vec x = {5, 0, 2, 1, 0}; * int y = min(x); // Returns 0 * ``` */ @@ -1874,7 +2492,7 @@ KERNEL_FLOAT_INLINE T min(const V& input) { * Example * ======= * ``` - * vec x = {5, 0, 2, 1, 0}; + * vec x = {5, 0, 2, 1, 0}; * int y = max(x); // Returns 5 * ``` */ @@ -1889,7 +2507,7 @@ KERNEL_FLOAT_INLINE T max(const V& input) { * Example * ======= * ``` - * vec x = {5, 0, 2, 1, 0}; + * vec x = {5, 0, 2, 1, 0}; * int y = sum(x); // Returns 8 * ``` */ @@ -1918,7 +2536,7 @@ KERNEL_FLOAT_INLINE T product(const V& input) { * non-zero if ``bool(v)==true``. */ template -KERNEL_FLOAT_INLINE bool all(V&& input) { +KERNEL_FLOAT_INLINE bool all(const V& input) { return reduce(ops::bit_and {}, cast(input)); } @@ -1927,7 +2545,7 @@ KERNEL_FLOAT_INLINE bool all(V&& input) { * non-zero if ``bool(v)==true``. */ template -KERNEL_FLOAT_INLINE bool any(V&& input) { +KERNEL_FLOAT_INLINE bool any(const V& input) { return reduce(ops::bit_or {}, cast(input)); } @@ -1942,307 +2560,600 @@ KERNEL_FLOAT_INLINE bool any(V&& input) { * int y = count(x); // Returns 3 (5, 2, 1 are non-zero) * ``` */ -template -KERNEL_FLOAT_INLINE int count(V&& input) { - return sum(cast(cast(input))); +template +KERNEL_FLOAT_INLINE T count(const V& input) { + return sum(cast(cast(input))); } -} // namespace kernel_float -#endif //KERNEL_FLOAT_REDUCE_H -#ifndef KERNEL_FLOAT_INTERFACE_H -#define KERNEL_FLOAT_INTERFACE_H +namespace detail { +template +struct dot_impl { + KERNEL_FLOAT_INLINE + static T call(const T* left, const T* right) { + vector_storage intermediate; + detail::apply_impl, N, T, T, T>::call( + ops::multiply(), + intermediate.data(), + left, + right); + + return detail::reduce_impl, N, T>::call(ops::add(), intermediate.data()); + } +}; +} // namespace detail + +/** + * Compute the dot product of the given vectors ``left`` and ``right`` + * + * Example + * ======= + * ``` + * vec x = {1, 2, 3}; + * vec y = {4, 5, 6}; + * int y = dot(x, y); // Returns 1*4+2*5+3*6 = 32 + * ``` + */ +template> +KERNEL_FLOAT_INLINE T dot(const L& left, const R& right) { + using E = broadcast_vector_extent_type; + return detail::dot_impl::call( + convert_storage(left, E {}).data(), + convert_storage(right, E {}).data()); +} +namespace detail { +template +struct magnitude_impl { + KERNEL_FLOAT_INLINE + static T call(const T* input) { + return ops::sqrt {}(detail::dot_impl::call(input, input)); + } +}; +template +struct magnitude_impl { + KERNEL_FLOAT_INLINE + static T call(const T* input) { + return T {}; + } +}; +template +struct magnitude_impl { + KERNEL_FLOAT_INLINE + static T call(const T* input) { + return ops::abs {}(input[0]); + } +}; +template +struct magnitude_impl { + KERNEL_FLOAT_INLINE + static T call(const T* input) { + return ops::hypot()(input[0], input[1]); + } +}; +// The 3-argument overload of hypot is only available on host from C++17 +#if defined(__cpp_lib_hypot) && KERNEL_FLOAT_IS_HOST +template<> +struct magnitude_impl { + static float call(const float* input) { + return ::hypot(input[0], input[1], input[2]); + } +}; +template<> +struct magnitude_impl { + static double call(const double* input) { + return ::hypot(input[0], input[1], input[2]); + } +}; +#endif +} // namespace detail -namespace kernel_float { +/** + * Compute the magnitude of the given input vector. This calculates the square root of the sum of squares, also + * known as the Euclidian norm, of a vector. + * + * Example + * ======= + * ``` + * vec x = {2, 3, 6}; + * float y = mag(x); // Returns sqrt(2*2 + 3*3 + 6*6) = 7 + * ``` + */ +template> +KERNEL_FLOAT_INLINE T mag(const V& input) { + return detail::magnitude_impl>::call(into_vector_storage(input).data()); +} +} // namespace kernel_float -template -KERNEL_FLOAT_INLINE vector broadcast(Input&& input); +#endif //KERNEL_FLOAT_REDUCE_H +#ifndef KERNEL_FLOAT_TRIOPS_H +#define KERNEL_FLOAT_TRIOPS_H -template -struct index_proxy { - using value_type = typename vector_traits::value_type; - KERNEL_FLOAT_INLINE - index_proxy(V& storage, I index) : storage_(storage), index_(index) {} - KERNEL_FLOAT_INLINE - index_proxy& operator=(value_type value) { - vector_traits::set(storage_, index_, value); - return *this; - } - KERNEL_FLOAT_INLINE - operator value_type() const { - return vector_traits::get(storage_, index_); - } +namespace kernel_float { - private: - V& storage_; - I index_; +namespace ops { +template +struct conditional { + KERNEL_FLOAT_INLINE T operator()(bool cond, T true_value, T false_value) { + if (cond) { + return true_value; + } else { + return false_value; + } + } }; +} // namespace ops -template -struct index_proxy> { - using value_type = typename vector_traits::value_type; +/** + * Return elements chosen from `true_values` and `false_values` depending on `cond`. + * + * This function broadcasts all arguments to the same size and then promotes the values of `true_values` and + * `false_values` into the same type. Next, it casts the values of `cond` to booleans and returns a vector where + * the values are taken from `true_values` where the condition is true and `false_values` otherwise. + * + * @param cond The condition used for selection. + * @param true_values The vector of values to choose from when the condition is true. + * @param false_values The vector of values to choose from when the condition is false. + * @return A vector containing selected elements as per the condition. + */ +template< + typename C, + typename L, + typename R, + typename T = promoted_vector_value_type, + typename E = broadcast_vector_extent_type> +KERNEL_FLOAT_INLINE vector where(const C& cond, const L& true_values, const R& false_values) { + using F = ops::conditional; + vector_storage result; + + detail::apply_impl::call( + F {}, + result.data(), + detail::convert_impl, vector_extent_type, bool, E>::call( + into_vector_storage(cond)) + .data(), + detail::convert_impl, vector_extent_type, T, E>::call( + into_vector_storage(true_values)) + .data(), + detail::convert_impl, vector_extent_type, T, E>::call( + into_vector_storage(false_values)) + .data()); + + return result; +} - KERNEL_FLOAT_INLINE - index_proxy(V& storage, const_index) : storage_(storage) {} +/** + * Selects elements from `true_values` depending on `cond`. + * + * This function returns a vector where the values are taken from `true_values` where `cond` is `true` and `0` where + * `cond is `false`. + * + * @param cond The condition used for selection. + * @param true_values The vector of values to choose from when the condition is true. + * @return A vector containing selected elements as per the condition. + */ +template< + typename C, + typename L, + typename T = vector_value_type, + typename E = broadcast_vector_extent_type> +KERNEL_FLOAT_INLINE vector where(const C& cond, const L& true_values) { + vector> false_values = T {}; + return where(cond, true_values, false_values); +} - KERNEL_FLOAT_INLINE - index_proxy& operator=(value_type value) { - vector_index::set(storage_, value); - return *this; +/** + * Returns a vector having the value `T(1)` where `cond` is `true` and `T(0)` where `cond` is `false`. + * + * @param cond The condition used for selection. + * @return A vector containing elements as per the condition. + */ +template> +KERNEL_FLOAT_INLINE vector where(const C& cond) { + return cast(cast(cond)); +} + +namespace ops { +template +struct fma { + KERNEL_FLOAT_INLINE T operator()(T a, T b, T c) { + return a * b + c; } +}; - KERNEL_FLOAT_INLINE - operator value_type() const { - return vector_index::get(storage_); +#if KERNEL_FLOAT_IS_DEVICE +template<> +struct fma { + KERNEL_FLOAT_INLINE float operator()(float a, float b, float c) { + return __fmaf_rn(a, b, c); } +}; - private: - V& storage_; +template<> +struct fma { + KERNEL_FLOAT_INLINE double operator()(double a, double b, double c) { + return __fma_rn(a, b, c); + } }; +#endif +} // namespace ops + +/** + * Computes the result of `a * b + c`. This is done in a single operation if possible. + */ +template< + typename A, + typename B, + typename C, + typename T = promoted_vector_value_type, + typename E = broadcast_vector_extent_type> +KERNEL_FLOAT_INLINE vector fma(const A& a, const B& b, const C& c) { + using F = ops::fma; + vector_storage result; + + detail::apply_impl::call( + F {}, + result.data(), + detail::convert_impl, vector_extent_type, T, E>::call( + into_vector_storage(a)) + .data(), + detail::convert_impl, vector_extent_type, T, E>::call( + into_vector_storage(b)) + .data(), + detail::convert_impl, vector_extent_type, T, E>::call( + into_vector_storage(c)) + .data()); + + return result; +} + +} // namespace kernel_float + +#endif //KERNEL_FLOAT_TRIOPS_H +#ifndef KERNEL_FLOAT_VECTOR_H +#define KERNEL_FLOAT_VECTOR_H + -template -struct vector { - using storage_type = V; - using traits_type = vector_traits; - using value_type = typename traits_type::value_type; - static constexpr size_t const_size = traits_type::size; - vector(const vector&) = default; - vector(vector&) = default; - vector(vector&&) = default; - vector& operator=(const vector&) = default; - vector& operator=(vector&) = default; - vector& operator=(vector&&) = default; + + + +namespace kernel_float { + +/** + * Container that stores ``N`` elements of type ``T``. + * + * It is not recommended to use this class directly, but instead, use the type `vec` which is an alias for + * `vector, vector_storage>`. + * + * @tparam T The type of the values stored within the vector. + * @tparam E The size of this vector. Should be of type `extent`. + * @tparam S The object's storage class. Should be the type `vector_storage` + */ +template +struct vector: public S { + using value_type = T; + using extent_type = E; + using storage_type = S; + + // Copy another `vector` + vector(const vector&) = default; + + // Copy anything of type `storage_type` KERNEL_FLOAT_INLINE - vector() : storage_(traits_type::fill(value_type {})) {} + vector(const storage_type& storage) : storage_type(storage) {} + // Copy anything of type `storage_type` KERNEL_FLOAT_INLINE - vector(storage_type storage) : storage_(storage) {} + vector(const value_type& input = {}) : + storage_type(detail::broadcast_impl, E>::call(input)) {} - template< - typename U, - enabled_t, value_type>, int> = 0> - KERNEL_FLOAT_INLINE vector(U&& init) : vector(broadcast(std::forward(init))) {} + // For all other arguments, we convert it using `convert_storage` according to broadcast rules + template, T>, int> = 0> + KERNEL_FLOAT_INLINE vector(U&& input) : + storage_type(convert_storage(input, extent_type {})) {} - template = 0> - KERNEL_FLOAT_INLINE vector(Args&&... args) : storage_(traits_type::create(args...)) {} + template, T>, int> = 0> + KERNEL_FLOAT_INLINE explicit vector(U&& input) : + storage_type(convert_storage(input, extent_type {})) {} + // List of `N` (where N >= 2), simply pass forward to the storage + template< + typename A, + typename B, + typename... Rest, + typename = enable_if_t> + KERNEL_FLOAT_INLINE vector(const A& a, const B& b, const Rest&... rest) : + storage_type {T(a), T(b), T(rest)...} {} + + /** + * Returns the number of elements in this vector. + */ KERNEL_FLOAT_INLINE - operator storage_type() const { - return storage_; + static constexpr size_t size() { + return E::size; } KERNEL_FLOAT_INLINE storage_type& storage() { - return storage_; + return *this; } KERNEL_FLOAT_INLINE const storage_type& storage() const { - return storage_; + return *this; } + /** + * Returns a pointer to the underlying storage data. + */ KERNEL_FLOAT_INLINE - value_type get(size_t index) const { - return traits_type::get(storage_, index); + T* data() { + return storage().data(); } + /** + * Returns a pointer to the underlying storage data. + */ KERNEL_FLOAT_INLINE - void set(size_t index, value_type value) { - traits_type::set(storage_, index, value); + const T* data() const { + return storage().data(); } - template - KERNEL_FLOAT_INLINE value_type get(const_index) const { - return vector_index::get(storage_); + KERNEL_FLOAT_INLINE + const T* cdata() const { + return this->data(); } - template - KERNEL_FLOAT_INLINE void set(const_index, value_type value) { - return vector_index::set(storage_, value); + /** + * Returns a reference to the item at index `i`. + */ + KERNEL_FLOAT_INLINE + T& at(size_t i) { + return *(this->data() + i); } + /** + * Returns a constant reference to the item at index `i`. + */ KERNEL_FLOAT_INLINE - value_type operator[](size_t index) const { - return get(index); + const T& at(size_t i) const { + return *(this->data() + i); } - template - KERNEL_FLOAT_INLINE value_type operator[](const_index) const { - return get(const_index {}); + /** + * Returns a reference to the item at index `i`. + */ + KERNEL_FLOAT_INLINE + T& operator[](size_t i) { + return at(i); } + /** + * Returns a constant reference to the item at index `i`. + */ KERNEL_FLOAT_INLINE - index_proxy operator[](size_t index) { - return {storage_, index}; + const T& operator[](size_t i) const { + return at(i); } - template - KERNEL_FLOAT_INLINE index_proxy> operator[](const_index) { - return {storage_, const_index {}}; + KERNEL_FLOAT_INLINE + T& operator()(size_t i) { + return at(i); } KERNEL_FLOAT_INLINE - static constexpr size_t size() { - return const_size; + const T& operator()(size_t i) const { + return at(i); } - private: - storage_type storage_; -}; - -template -struct vector_traits> { - using value_type = vector_value_type; - static constexpr size_t size = vector_size; - + /** + * Returns a pointer to the first element. + */ KERNEL_FLOAT_INLINE - static vector fill(value_type value) { - return vector_traits::fill(value); + T* begin() { + return this->data(); } - template - KERNEL_FLOAT_INLINE static vector create(Args... args) { - return vector_traits::create(args...); + /** + * Returns a pointer to the first element. + */ + KERNEL_FLOAT_INLINE + const T* begin() const { + return this->data(); } + /** + * Returns a pointer to the first element. + */ KERNEL_FLOAT_INLINE - static value_type get(const vector& self, size_t index) { - return vector_traits::get(self.storage(), index); + const T* cbegin() const { + return this->data(); } + /** + * Returns a pointer to one past the last element. + */ KERNEL_FLOAT_INLINE - static void set(vector& self, size_t index, value_type value) { - vector_traits::set(self.storage(), index, value); + T* end() { + return this->data() + size(); } -}; - -template -struct vector_index, I> { - using value_type = vector_value_type; + /** + * Returns a pointer to one past the last element. + */ KERNEL_FLOAT_INLINE - static value_type get(const vector& self) { - return vector_index::get(self.storage()); + const T* end() const { + return this->data() + size(); } + /** + * Returns a pointer to one past the last element. + */ KERNEL_FLOAT_INLINE - static void set(vector& self, value_type value) { - vector_index::set(self.storage(), value); + const T* cend() const { + return this->data() + size(); } -}; - -template -struct into_storage_traits> { - using type = V; + /** + * Copy the element at index `i`. + */ KERNEL_FLOAT_INLINE - static constexpr type call(const vector& self) { - return self.storage(); + T get(size_t x) const { + return at(x); } -}; -template -struct vector_swizzle, index_sequence> { - KERNEL_FLOAT_INLINE static Output call(const vector& self) { - return vector_swizzle>::call(self.storage()); + /** + * Set the element at index `i`. + */ + KERNEL_FLOAT_INLINE + void set(size_t x, T value) { + at(x) = std::move(value); + } + + /** + * Selects elements from the this vector based on the specified indices. + * + * Example + * ======= + * ``` + * vec input = {0, 10, 20, 30, 40, 50}; + * vec vec1 = select(input, 0, 4, 4, 2); // [0, 40, 40, 20] + * + * vec indices = {0, 4, 4, 2}; + * vec vec2 = select(input, indices); // [0, 40, 40, 20] + * ``` + */ + template + KERNEL_FLOAT_INLINE select_type select(const Is&... indices) { + return kernel_float::select(*this, indices...); + } + + /** + * Cast the elements of this vector to type `R` and returns a new vector. + */ + template + KERNEL_FLOAT_INLINE vector cast() const { + return kernel_float::cast(*this); + } + + /** + * Broadcast this vector into a new size `(Ns...)`. + */ + template + KERNEL_FLOAT_INLINE vector> broadcast(extent new_size = {}) const { + return kernel_float::broadcast(*this, new_size); + } + + /** + * Apply the given function `F` to each element of this vector and returns a new vector with the results. + */ + template + KERNEL_FLOAT_INLINE vector, E> map(F fun) const { + return kernel_float::map(fun, *this); + } + + /** + * Reduce the elements of the given vector input into a single value using the function `F`. + * + * This function should be a binary function that takes two elements and returns one element. The order in which + * the elements are reduced is not specified and depends on the reduction function and the vector type. + */ + template + KERNEL_FLOAT_INLINE T reduce(F fun) const { + return kernel_float::reduce(fun, *this); + } + + /** + * Flattens the elements of this vector. For example, this turns a `vec, 3>` into a `vec`. + */ + KERNEL_FLOAT_INLINE flatten_type flatten() const { + return kernel_float::flatten(*this); + } + + /** + * Apply the given function `F` to each element of this vector. + */ + template + KERNEL_FLOAT_INLINE void for_each(F fun) const { + return kernel_float::for_each(*this, std::move(fun)); } }; -template -using vec = vector>; - -template -using unaligned_vec = vector>; - -template -KERNEL_FLOAT_INLINE vec, sizeof...(Args)> make_vec(Args&&... args) { - using value_type = common_t; - using vector_type = default_storage_type; - return vector_traits::create(value_type(args)...); -} - +/** + * Convert the given `input` into a vector. This function can perform one of the following actions: + * + * - For vectors `vec`, it simply returns the original vector. + * - For primitive types `T` (e.g., `int`, `float`, `double`), it returns a `vec`. + * - For array-like types (e.g., `std::array`, `T[N]`), it returns `vec`. + * - For vector-like types (e.g., `int2`, `dim3`), it returns `vec`. + */ template -KERNEL_FLOAT_INLINE vector> into_vec(V&& input) { - return into_storage(input); +KERNEL_FLOAT_INLINE into_vector_type into_vector(V&& input) { + return into_vector_impl::call(std::forward(input)); } -using float32 = float; -using float64 = double; - -template -using vec1 = vec; -template -using vec2 = vec; -template -using vec3 = vec; -template -using vec4 = vec; -template -using vec5 = vec; -template -using vec6 = vec; -template -using vec7 = vec; template -using vec8 = vec; +using scalar = vector>; -#define KERNEL_FLOAT_TYPE_ALIAS(NAME, T) \ - template \ - using NAME##N = vec; \ - using NAME##1 = vec; \ - using NAME##2 = vec; \ - using NAME##3 = vec; \ - using NAME##4 = vec; \ - using NAME##5 = vec; \ - using NAME##6 = vec; \ - using NAME##7 = vec; \ - using NAME##8 = vec; \ - template \ - using unaligned_##NAME##X = unaligned_vec; \ - using unaligned_##NAME##1 = unaligned_vec; \ - using unaligned_##NAME##2 = unaligned_vec; \ - using unaligned_##NAME##3 = unaligned_vec; \ - using unaligned_##NAME##4 = unaligned_vec; \ - using unaligned_##NAME##5 = unaligned_vec; \ - using unaligned_##NAME##6 = unaligned_vec; \ - using unaligned_##NAME##7 = unaligned_vec; \ - using unaligned_##NAME##8 = unaligned_vec; +template +using vec = vector>; -KERNEL_FLOAT_TYPE_ALIAS(char, char) -KERNEL_FLOAT_TYPE_ALIAS(short, short) -KERNEL_FLOAT_TYPE_ALIAS(int, int) -KERNEL_FLOAT_TYPE_ALIAS(long, long) -KERNEL_FLOAT_TYPE_ALIAS(longlong, long long) +// clang-format off +template using vec1 = vec; +template using vec2 = vec; +template using vec3 = vec; +template using vec4 = vec; +template using vec5 = vec; +template using vec6 = vec; +template using vec7 = vec; +template using vec8 = vec; +// clang-format on -KERNEL_FLOAT_TYPE_ALIAS(uchar, unsigned char) -KERNEL_FLOAT_TYPE_ALIAS(ushort, unsigned short) -KERNEL_FLOAT_TYPE_ALIAS(uint, unsigned int) -KERNEL_FLOAT_TYPE_ALIAS(ulong, unsigned long) -KERNEL_FLOAT_TYPE_ALIAS(ulonglong, unsigned long long) +/** + * Create a vector from a variable number of input values. + * + * The resulting vector type is determined by promoting the types of the input values into a common type. + * The number of input values determines the dimension of the resulting vector. + * + * Example + * ======= + * ``` + * auto v1 = make_vec(1.0f, 2.0f, 3.0f); // Creates a vec [1.0f, 2.0f, 3.0f] + * auto v2 = make_vec(1, 2, 3, 4); // Creates a vec [1, 2, 3, 4] + * ``` + */ +template +KERNEL_FLOAT_INLINE vec, sizeof...(Args)> make_vec(Args&&... args) { + using T = promote_t; + return vector_storage {T(args)...}; +}; -KERNEL_FLOAT_TYPE_ALIAS(float, float) -KERNEL_FLOAT_TYPE_ALIAS(f32x, float) -KERNEL_FLOAT_TYPE_ALIAS(float32x, float) +#if defined(__cpp_deduction_guides) +// Deduction guide for `vector` +template +vector(Args&&... args) -> vector, extent>; -KERNEL_FLOAT_TYPE_ALIAS(double, double) -KERNEL_FLOAT_TYPE_ALIAS(f64x, double) -KERNEL_FLOAT_TYPE_ALIAS(float64x, double) +// Deduction guides for aliases are only supported from C++20 +#if __cpp_deduction_guides >= 201907L +template +vec(Args&&... args) -> vec, sizeof...(Args)>; +#endif +#endif } // namespace kernel_float -#endif //KERNEL_FLOAT_INTERFACE_H +#endif #ifndef KERNEL_FLOAT_FP16_H #define KERNEL_FLOAT_FP16_H @@ -2254,71 +3165,119 @@ KERNEL_FLOAT_TYPE_ALIAS(float64x, double) namespace kernel_float { -KERNEL_FLOAT_DEFINE_COMMON_TYPE(__half, bool) -KERNEL_FLOAT_DEFINE_COMMON_TYPE(float, __half) -KERNEL_FLOAT_DEFINE_COMMON_TYPE(double, __half) +KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(__half) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, __half) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, __half) template<> -struct vector_traits<__half2> { +struct into_vector_impl<__half2> { using value_type = __half; - static constexpr size_t size = 2; + using extent_type = extent<2>; KERNEL_FLOAT_INLINE - static __half2 fill(__half value) { -#if KERNEL_FLOAT_ON_DEVICE - return __half2half2(value); -#else - return {value, value}; -#endif + static vector_storage<__half, 2> call(__half2 input) { + return {input.x, input.y}; } +}; +namespace detail { +template +struct map_halfx2 { KERNEL_FLOAT_INLINE - static __half2 create(__half low, __half high) { -#if KERNEL_FLOAT_ON_DEVICE - return __halves2half2(low, high); -#else - return {low, high}; -#endif + static __half2 call(F fun, __half2 input) { + __half a = fun(input.x); + __half b = fun(input.y); + return {a, b}; } +}; +template +struct zip_halfx2 { KERNEL_FLOAT_INLINE - static __half get(__half2 self, size_t index) { -#if KERNEL_FLOAT_ON_DEVICE - if (index == 0) { - return __low2half(self); - } else { - return __high2half(self); + static __half2 call(F fun, __half2 left, __half2 right) { + __half a = fun(left.x, left.y); + __half b = fun(right.y, right.y); + return {a, b}; + } +}; + +template +struct apply_impl { + KERNEL_FLOAT_INLINE static void call(F fun, __half* result, const __half* input) { +#pragma unroll + for (size_t i = 0; 2 * i + 1 < N; i++) { + __half2 a = {input[2 * i], input[2 * i + 1]}; + __half2 b = map_halfx2::call(fun, a); + result[2 * i + 0] = b.x; + result[2 * i + 1] = b.y; } -#else - if (index == 0) { - return self.x; - } else { - return self.y; + + if (N % 2 != 0) { + result[N - 1] = fun(input[N - 1]); } -#endif } +}; - KERNEL_FLOAT_INLINE - static void set(__half2& self, size_t index, __half value) { - if (index == 0) { - self.x = value; - } else { - self.y = value; +template +struct apply_impl { + KERNEL_FLOAT_INLINE static void + call(F fun, __half* result, const __half* left, const __half* right) { +#pragma unroll + for (size_t i = 0; 2 * i + 1 < N; i++) { + __half2 a = {left[2 * i], left[2 * i + 1]}; + __half2 b = {right[2 * i], right[2 * i + 1]}; + __half2 c = zip_halfx2::call(fun, a, b); + result[2 * i + 0] = c.x; + result[2 * i + 1] = c.y; + } + + if (N % 2 != 0) { + result[N - 1] = fun(left[N - 1], right[N - 1]); } } }; -template -struct default_storage<__half, N, Alignment::Maximum, enabled_t<(N >= 2)>> { - using type = nested_array<__half2, N>; -}; +template +struct reduce_impl= 2)>> { + KERNEL_FLOAT_INLINE static __half call(F fun, const __half* input) { + __half2 accum = {input[0], input[1]}; -template -struct default_storage<__half, N, Alignment::Packed, enabled_t<(N >= 2 && N % 2 == 0)>> { - using type = nested_array<__half2, N>; +#pragma unroll + for (size_t i = 0; 2 * i + 1 < N; i++) { + __half2 a = {input[2 * i], input[2 * i + 1]}; + accum = zip_halfx2::call(fun, accum, a); + } + + __half result = fun(accum.x, accum.y); + + if (N % 2 != 0) { + result = fun(result, input[N - 1]); + } + + return result; + } }; -#if KERNEL_FLOAT_ON_DEVICE +}; // namespace detail + +#define KERNEL_FLOAT_FP16_UNARY_FORWARD(NAME) \ + namespace ops { \ + template<> \ + struct NAME<__half> { \ + KERNEL_FLOAT_INLINE __half operator()(__half input) { \ + return __half(ops::NAME {}(float(input))); \ + } \ + }; \ + } + +// There operations are not implemented in half precision, so they are forward to single precision +KERNEL_FLOAT_FP16_UNARY_FORWARD(tan) +KERNEL_FLOAT_FP16_UNARY_FORWARD(asin) +KERNEL_FLOAT_FP16_UNARY_FORWARD(acos) +KERNEL_FLOAT_FP16_UNARY_FORWARD(atan) +KERNEL_FLOAT_FP16_UNARY_FORWARD(expm1) + +#if KERNEL_FLOAT_IS_DEVICE #define KERNEL_FLOAT_FP16_UNARY_FUN(NAME, FUN1, FUN2) \ namespace ops { \ template<> \ @@ -2330,28 +3289,37 @@ struct default_storage<__half, N, Alignment::Packed, enabled_t<(N >= 2 && N % 2 } \ namespace detail { \ template<> \ - struct map_helper, __half2, __half2> { \ + struct map_halfx2> { \ KERNEL_FLOAT_INLINE static __half2 call(ops::NAME<__half>, __half2 input) { \ return FUN2(input); \ } \ }; \ } +#else +#define KERNEL_FLOAT_FP16_UNARY_FUN(NAME, FUN1, FUN2) KERNEL_FLOAT_FP16_UNARY_FORWARD(NAME) +#endif -KERNEL_FLOAT_FP16_UNARY_FUN(abs, ::__habs, ::__habs2); -KERNEL_FLOAT_FP16_UNARY_FUN(negate, ::__hneg, ::__hneg2); -KERNEL_FLOAT_FP16_UNARY_FUN(ceil, ::hceil, ::h2ceil); -KERNEL_FLOAT_FP16_UNARY_FUN(cos, ::hcos, ::h2cos); -KERNEL_FLOAT_FP16_UNARY_FUN(exp, ::hexp, ::h2exp); -KERNEL_FLOAT_FP16_UNARY_FUN(exp10, ::hexp10, ::h2exp10); -KERNEL_FLOAT_FP16_UNARY_FUN(floor, ::hfloor, ::h2floor); -KERNEL_FLOAT_FP16_UNARY_FUN(log, ::hlog, ::h2log); -KERNEL_FLOAT_FP16_UNARY_FUN(log10, ::hlog10, ::h2log2); -KERNEL_FLOAT_FP16_UNARY_FUN(rint, ::hrint, ::h2rint); -KERNEL_FLOAT_FP16_UNARY_FUN(rsqrt, ::hrsqrt, ::h2rsqrt); -KERNEL_FLOAT_FP16_UNARY_FUN(sin, ::hsin, ::h2sin); -KERNEL_FLOAT_FP16_UNARY_FUN(sqrt, ::hsqrt, ::h2sqrt); -KERNEL_FLOAT_FP16_UNARY_FUN(trunc, ::htrunc, ::h2trunc); - +KERNEL_FLOAT_FP16_UNARY_FUN(abs, ::__habs, ::__habs2) +KERNEL_FLOAT_FP16_UNARY_FUN(negate, ::__hneg, ::__hneg2) +KERNEL_FLOAT_FP16_UNARY_FUN(ceil, ::hceil, ::h2ceil) +KERNEL_FLOAT_FP16_UNARY_FUN(cos, ::hcos, ::h2cos) +KERNEL_FLOAT_FP16_UNARY_FUN(exp, ::hexp, ::h2exp) +KERNEL_FLOAT_FP16_UNARY_FUN(exp10, ::hexp10, ::h2exp10) +KERNEL_FLOAT_FP16_UNARY_FUN(floor, ::hfloor, ::h2floor) +KERNEL_FLOAT_FP16_UNARY_FUN(log, ::hlog, ::h2log) +KERNEL_FLOAT_FP16_UNARY_FUN(log10, ::hlog10, ::h2log2) +KERNEL_FLOAT_FP16_UNARY_FUN(rint, ::hrint, ::h2rint) +KERNEL_FLOAT_FP16_UNARY_FUN(rsqrt, ::hrsqrt, ::h2rsqrt) +KERNEL_FLOAT_FP16_UNARY_FUN(sin, ::hsin, ::h2sin) +KERNEL_FLOAT_FP16_UNARY_FUN(sqrt, ::hsqrt, ::h2sqrt) +KERNEL_FLOAT_FP16_UNARY_FUN(trunc, ::htrunc, ::h2trunc) + +KERNEL_FLOAT_FP16_UNARY_FUN(fast_exp, ::hexp, ::h2exp) +KERNEL_FLOAT_FP16_UNARY_FUN(fast_log, ::hlog, ::h2log) +KERNEL_FLOAT_FP16_UNARY_FUN(fast_cos, ::hcos, ::h2cos) +KERNEL_FLOAT_FP16_UNARY_FUN(fast_sin, ::hsin, ::h2sin) + +#if KERNEL_FLOAT_IS_DEVICE #define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2) \ namespace ops { \ template<> \ @@ -2363,12 +3331,23 @@ KERNEL_FLOAT_FP16_UNARY_FUN(trunc, ::htrunc, ::h2trunc); } \ namespace detail { \ template<> \ - struct zip_helper, __half2, __half2, __half2> { \ + struct zip_halfx2> { \ KERNEL_FLOAT_INLINE static __half2 call(ops::NAME<__half>, __half2 left, __half2 right) { \ return FUN2(left, right); \ } \ }; \ } +#else +#define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2) \ + namespace ops { \ + template<> \ + struct NAME<__half> { \ + KERNEL_FLOAT_INLINE __half operator()(__half left, __half right) const { \ + return __half(ops::NAME {}(float(left), float(right))); \ + } \ + }; \ + } +#endif KERNEL_FLOAT_FP16_BINARY_FUN(add, __hadd, __hadd2) KERNEL_FLOAT_FP16_BINARY_FUN(subtract, __hsub, __hsub2) @@ -2376,6 +3355,7 @@ KERNEL_FLOAT_FP16_BINARY_FUN(multiply, __hmul, __hmul2) KERNEL_FLOAT_FP16_BINARY_FUN(divide, __hdiv, __h2div) KERNEL_FLOAT_FP16_BINARY_FUN(min, __hmin, __hmin2) KERNEL_FLOAT_FP16_BINARY_FUN(max, __hmax, __hmax2) +KERNEL_FLOAT_FP16_BINARY_FUN(fast_div, __hdiv, __h2div) KERNEL_FLOAT_FP16_BINARY_FUN(equal_to, __heq, __heq2) KERNEL_FLOAT_FP16_BINARY_FUN(not_equal_to, __heq, __heq2) @@ -2384,8 +3364,6 @@ KERNEL_FLOAT_FP16_BINARY_FUN(less_equal, __hle, __hle2) KERNEL_FLOAT_FP16_BINARY_FUN(greater, __hgt, __hgt2) KERNEL_FLOAT_FP16_BINARY_FUN(greater_equal, __hge, __hgt2) -#endif - #define KERNEL_FLOAT_FP16_CAST(T, TO_HALF, FROM_HALF) \ namespace ops { \ template<> \ @@ -2410,22 +3388,69 @@ KERNEL_FLOAT_FP16_CAST(char, __int2half_rn(input), (char)__half2int_rz(input)); KERNEL_FLOAT_FP16_CAST(signed char, __int2half_rn(input), (signed char)__half2int_rz(input)); KERNEL_FLOAT_FP16_CAST(unsigned char, __int2half_rn(input), (unsigned char)__half2int_rz(input)); -KERNEL_FLOAT_FP16_CAST(signed short, __short2half_rn(input), __half2short_rz(input)); -KERNEL_FLOAT_FP16_CAST(signed int, __int2half_rn(input), __half2int_rz(input)); +KERNEL_FLOAT_FP16_CAST(signed short, __half2short_rz(input), __short2half_rn(input)); +KERNEL_FLOAT_FP16_CAST(signed int, __half2int_rz(input), __int2half_rn(input)); KERNEL_FLOAT_FP16_CAST(signed long, __ll2half_rn(input), (signed long)(__half2ll_rz(input))); KERNEL_FLOAT_FP16_CAST(signed long long, __ll2half_rn(input), __half2ll_rz(input)); -KERNEL_FLOAT_FP16_CAST(unsigned int, __uint2half_rn(input), __half2uint_rz(input)); -KERNEL_FLOAT_FP16_CAST(unsigned short, __ushort2half_rn(input), __half2ushort_rz(input)); +KERNEL_FLOAT_FP16_CAST(unsigned short, __half2ushort_rz(input), __ushort2half_rn(input)); +KERNEL_FLOAT_FP16_CAST(unsigned int, __half2uint_rz(input), __uint2half_rn(input)); KERNEL_FLOAT_FP16_CAST(unsigned long, __ull2half_rn(input), (unsigned long)(__half2ull_rz(input))); KERNEL_FLOAT_FP16_CAST(unsigned long long, __ull2half_rn(input), __half2ull_rz(input)); using half = __half; -using float16 = __half; -//KERNEL_FLOAT_TYPE_ALIAS(half, __half) //KERNEL_FLOAT_TYPE_ALIAS(float16x, __half) //KERNEL_FLOAT_TYPE_ALIAS(f16x, __half) +#if KERNEL_FLOAT_IS_DEVICE +namespace detail { +template<> +struct dot_impl<__half, 0> { + KERNEL_FLOAT_INLINE + static __half call(const __half* left, const __half* right) { + return __half(0); + } +}; + +template<> +struct dot_impl<__half, 1> { + KERNEL_FLOAT_INLINE + static __half call(const __half* left, const __half* right) { + return __hmul(left[0], right[0]); + } +}; + +template +struct dot_impl<__half, N> { + static_assert(N >= 2, "internal error"); + + KERNEL_FLOAT_INLINE + static __half call(const __half* left, const __half* right) { + __half2 first_a = {left[0], left[1]}; + __half2 first_b = {right[0], right[1]}; + __half2 accum = __hmul2(first_a, first_b); + +#pragma unroll + for (size_t i = 2; i + 2 <= N; i += 2) { + __half2 a = {left[i], left[i + 1]}; + __half2 b = {right[i], right[i + 1]}; + accum = __hfma2(a, b, accum); + } + + __half result = __hadd(accum.x, accum.y); + + if (N % 2 != 0) { + __half a = left[N - 1]; + __half b = right[N - 1]; + result = __hfma(a, b, result); + } + + return result; + } +}; +} // namespace detail +#endif + } // namespace kernel_float #endif @@ -2443,127 +3468,193 @@ using float16 = __half; - - namespace kernel_float { -KERNEL_FLOAT_DEFINE_COMMON_TYPE(__nv_bfloat16, bool) -KERNEL_FLOAT_DEFINE_COMMON_TYPE(float, __nv_bfloat16) -KERNEL_FLOAT_DEFINE_COMMON_TYPE(double, __nv_bfloat16) +KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(__nv_bfloat16) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, __nv_bfloat16) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, __nv_bfloat16) template<> -struct vector_traits<__nv_bfloat162> { +struct into_vector_impl<__nv_bfloat162> { using value_type = __nv_bfloat16; - static constexpr size_t size = 2; + using extent_type = extent<2>; KERNEL_FLOAT_INLINE - static __nv_bfloat162 fill(__nv_bfloat16 value) { -#if KERNEL_FLOAT_ON_DEVICE - return __bfloat162bfloat162(value); -#else - return {value, value}; -#endif + static vector_storage<__nv_bfloat16, 2> call(__nv_bfloat162 input) { + return {input.x, input.y}; } +}; +namespace detail { +template +struct map_bfloat16x2 { KERNEL_FLOAT_INLINE - static __nv_bfloat162 create(__nv_bfloat16 low, __nv_bfloat16 high) { -#if KERNEL_FLOAT_ON_DEVICE - return __halves2bfloat162(low, high); -#else - return {low, high}; -#endif + static __nv_bfloat162 call(F fun, __nv_bfloat162 input) { + __nv_bfloat16 a = fun(input.x); + __nv_bfloat16 b = fun(input.y); + return {a, b}; } +}; +template +struct zip_bfloat16x2 { KERNEL_FLOAT_INLINE - static __nv_bfloat16 get(__nv_bfloat162 self, size_t index) { -#if KERNEL_FLOAT_ON_DEVICE - if (index == 0) { - return __low2bfloat16(self); - } else { - return __high2bfloat16(self); + static __nv_bfloat162 call(F fun, __nv_bfloat162 left, __nv_bfloat162 right) { + __nv_bfloat16 a = fun(left.x, left.y); + __nv_bfloat16 b = fun(right.y, right.y); + return {a, b}; + } +}; + +template +struct apply_impl { + KERNEL_FLOAT_INLINE static void call(F fun, __nv_bfloat16* result, const __nv_bfloat16* input) { +#pragma unroll + for (size_t i = 0; 2 * i + 1 < N; i++) { + __nv_bfloat162 a = {input[2 * i], input[2 * i + 1]}; + __nv_bfloat162 b = map_bfloat16x2::call(fun, a); + result[2 * i + 0] = b.x; + result[2 * i + 1] = b.y; } -#else - if (index == 0) { - return self.x; - } else { - return self.y; + + if (N % 2 != 0) { + result[N - 1] = fun(input[N - 1]); } -#endif } +}; - KERNEL_FLOAT_INLINE - static void set(__nv_bfloat162& self, size_t index, __nv_bfloat16 value) { - if (index == 0) { - self.x = value; - } else { - self.y = value; +template +struct apply_impl { + KERNEL_FLOAT_INLINE static void + call(F fun, __nv_bfloat16* result, const __nv_bfloat16* left, const __nv_bfloat16* right) { +#pragma unroll + for (size_t i = 0; 2 * i + 1 < N; i++) { + __nv_bfloat162 a = {left[2 * i], left[2 * i + 1]}; + __nv_bfloat162 b = {right[2 * i], right[2 * i + 1]}; + __nv_bfloat162 c = zip_bfloat16x2::call(fun, a, b); + result[2 * i + 0] = c.x; + result[2 * i + 1] = c.y; + } + + if (N % 2 != 0) { + result[N - 1] = fun(left[N - 1], right[N - 1]); } } }; -template -struct default_storage<__nv_bfloat16, N, Alignment::Maximum, enabled_t<(N >= 2)>> { - using type = nested_array<__nv_bfloat162, N>; -}; +template +struct reduce_impl= 2)>> { + KERNEL_FLOAT_INLINE static __nv_bfloat16 call(F fun, const __nv_bfloat16* input) { + __nv_bfloat162 accum = {input[0], input[1]}; -template -struct default_storage<__nv_bfloat16, N, Alignment::Packed, enabled_t<(N >= 2 && N % 2 == 0)>> { - using type = nested_array<__nv_bfloat162, N>; +#pragma unroll + for (size_t i = 0; 2 * i + 1 < N; i++) { + __nv_bfloat162 a = {input[2 * i], input[2 * i + 1]}; + accum = zip_bfloat16x2::call(fun, accum, a); + } + + __nv_bfloat16 result = fun(accum.x, accum.y); + + if (N % 2 != 0) { + result = fun(result, input[N - 1]); + } + + return result; + } }; +} // namespace detail + +#define KERNEL_FLOAT_BF16_UNARY_FORWARD(NAME) \ + namespace ops { \ + template<> \ + struct NAME<__nv_bfloat16> { \ + KERNEL_FLOAT_INLINE __nv_bfloat16 operator()(__nv_bfloat16 input) { \ + return __nv_bfloat16(ops::NAME {}(float(input))); \ + } \ + }; \ + } + +// There operations are not implemented in half precision, so they are forward to single precision +KERNEL_FLOAT_BF16_UNARY_FORWARD(tan) +KERNEL_FLOAT_BF16_UNARY_FORWARD(asin) +KERNEL_FLOAT_BF16_UNARY_FORWARD(acos) +KERNEL_FLOAT_BF16_UNARY_FORWARD(atan) +KERNEL_FLOAT_BF16_UNARY_FORWARD(expm1) + +#if KERNEL_FLOAT_IS_DEVICE +#define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) \ + namespace ops { \ + template<> \ + struct NAME<__nv_bfloat16> { \ + KERNEL_FLOAT_INLINE __nv_bfloat16 operator()(__nv_bfloat16 input) { \ + return FUN1(input); \ + } \ + }; \ + } \ + namespace detail { \ + template<> \ + struct map_bfloat16x2> { \ + KERNEL_FLOAT_INLINE static __nv_bfloat162 \ + call(ops::NAME<__nv_bfloat16>, __nv_bfloat162 input) { \ + return FUN2(input); \ + } \ + }; \ + } +#else +#define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) KERNEL_FLOAT_BF16_UNARY_FORWARD(NAME) +#endif -#if KERNEL_FLOAT_ON_DEVICE -#define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) \ +KERNEL_FLOAT_BF16_UNARY_FUN(abs, ::__habs, ::__habs2) +KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2) +KERNEL_FLOAT_BF16_UNARY_FUN(ceil, ::hceil, ::h2ceil) +KERNEL_FLOAT_BF16_UNARY_FUN(cos, ::hcos, ::h2cos) +KERNEL_FLOAT_BF16_UNARY_FUN(exp, ::hexp, ::h2exp) +KERNEL_FLOAT_BF16_UNARY_FUN(exp10, ::hexp10, ::h2exp10) +KERNEL_FLOAT_BF16_UNARY_FUN(floor, ::hfloor, ::h2floor) +KERNEL_FLOAT_BF16_UNARY_FUN(log, ::hlog, ::h2log) +KERNEL_FLOAT_BF16_UNARY_FUN(log10, ::hlog10, ::h2log2) +KERNEL_FLOAT_BF16_UNARY_FUN(rint, ::hrint, ::h2rint) +KERNEL_FLOAT_BF16_UNARY_FUN(rsqrt, ::hrsqrt, ::h2rsqrt) +KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin) +KERNEL_FLOAT_BF16_UNARY_FUN(sqrt, ::hsqrt, ::h2sqrt) +KERNEL_FLOAT_BF16_UNARY_FUN(trunc, ::htrunc, ::h2trunc) + +KERNEL_FLOAT_BF16_UNARY_FUN(fast_exp, ::hexp, ::h2exp) +KERNEL_FLOAT_BF16_UNARY_FUN(fast_log, ::hlog, ::h2log) +KERNEL_FLOAT_BF16_UNARY_FUN(fast_cos, ::hcos, ::h2cos) +KERNEL_FLOAT_BF16_UNARY_FUN(fast_sin, ::hsin, ::h2sin) + +#if KERNEL_FLOAT_IS_DEVICE +#define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2) \ + namespace ops { \ + template<> \ + struct NAME<__nv_bfloat16> { \ + KERNEL_FLOAT_INLINE __nv_bfloat16 \ + operator()(__nv_bfloat16 left, __nv_bfloat16 right) const { \ + return FUN1(left, right); \ + } \ + }; \ + } \ + namespace detail { \ + template<> \ + struct zip_bfloat16x2> { \ + KERNEL_FLOAT_INLINE static __nv_bfloat162 \ + call(ops::NAME<__nv_bfloat16>, __nv_bfloat162 left, __nv_bfloat162 right) { \ + return FUN2(left, right); \ + } \ + }; \ + } +#else +#define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2) \ namespace ops { \ template<> \ struct NAME<__nv_bfloat16> { \ - KERNEL_FLOAT_INLINE __nv_bfloat16 operator()(__nv_bfloat16 input) { \ - return FUN1(input); \ + KERNEL_FLOAT_INLINE __nv_bfloat16 \ + operator()(__nv_bfloat16 left, __nv_bfloat16 right) const { \ + return __nv_bfloat16(ops::NAME {}(float(left), float(right))); \ } \ }; \ - } \ - namespace detail { \ - template<> \ - struct map_helper, __nv_bfloat162, __nv_bfloat162> { \ - KERNEL_FLOAT_INLINE static __nv_bfloat162 \ - call(ops::NAME<__nv_bfloat16>, __nv_bfloat162 input) { \ - return FUN2(input); \ - } \ - }; \ - } - -KERNEL_FLOAT_BF16_UNARY_FUN(abs, ::__habs, ::__habs2); -KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2); -KERNEL_FLOAT_BF16_UNARY_FUN(ceil, ::hceil, ::h2ceil); -KERNEL_FLOAT_BF16_UNARY_FUN(cos, ::hcos, ::h2cos); -KERNEL_FLOAT_BF16_UNARY_FUN(exp, ::hexp, ::h2exp); -KERNEL_FLOAT_BF16_UNARY_FUN(exp10, ::hexp10, ::h2exp10); -KERNEL_FLOAT_BF16_UNARY_FUN(floor, ::hfloor, ::h2floor); -KERNEL_FLOAT_BF16_UNARY_FUN(log, ::hlog, ::h2log); -KERNEL_FLOAT_BF16_UNARY_FUN(log10, ::hlog10, ::h2log2); -KERNEL_FLOAT_BF16_UNARY_FUN(rint, ::hrint, ::h2rint); -KERNEL_FLOAT_BF16_UNARY_FUN(rsqrt, ::hrsqrt, ::h2rsqrt); -KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin); -KERNEL_FLOAT_BF16_UNARY_FUN(sqrt, ::hsqrt, ::h2sqrt); -KERNEL_FLOAT_BF16_UNARY_FUN(trunc, ::htrunc, ::h2trunc); - -#define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2) \ - namespace ops { \ - template<> \ - struct NAME<__nv_bfloat16> { \ - KERNEL_FLOAT_INLINE __nv_bfloat16 \ - operator()(__nv_bfloat16 left, __nv_bfloat16 right) const { \ - return FUN1(left, right); \ - } \ - }; \ - } \ - namespace detail { \ - template<> \ - struct zip_helper, __nv_bfloat162, __nv_bfloat162, __nv_bfloat162> { \ - KERNEL_FLOAT_INLINE static __nv_bfloat162 \ - call(ops::NAME<__nv_bfloat16>, __nv_bfloat162 left, __nv_bfloat162 right) { \ - return FUN2(left, right); \ - } \ - }; \ } +#endif KERNEL_FLOAT_BF16_BINARY_FUN(add, __hadd, __hadd2) KERNEL_FLOAT_BF16_BINARY_FUN(subtract, __hsub, __hsub2) @@ -2572,6 +3663,8 @@ KERNEL_FLOAT_BF16_BINARY_FUN(divide, __hdiv, __h2div) KERNEL_FLOAT_BF16_BINARY_FUN(min, __hmin, __hmin2) KERNEL_FLOAT_BF16_BINARY_FUN(max, __hmax, __hmax2) +KERNEL_FLOAT_BF16_BINARY_FUN(fast_div, __hdiv, __h2div) + KERNEL_FLOAT_BF16_BINARY_FUN(equal_to, __heq, __heq2) KERNEL_FLOAT_BF16_BINARY_FUN(not_equal_to, __heq, __heq2) KERNEL_FLOAT_BF16_BINARY_FUN(less, __hlt, __hlt2) @@ -2579,8 +3672,6 @@ KERNEL_FLOAT_BF16_BINARY_FUN(less_equal, __hle, __hle2) KERNEL_FLOAT_BF16_BINARY_FUN(greater, __hgt, __hgt2) KERNEL_FLOAT_BF16_BINARY_FUN(greater_equal, __hge, __hgt2) -#endif - #define KERNEL_FLOAT_BF16_CAST(T, TO_HALF, FROM_HALF) \ namespace ops { \ template<> \ @@ -2600,38 +3691,76 @@ KERNEL_FLOAT_BF16_BINARY_FUN(greater_equal, __hge, __hgt2) KERNEL_FLOAT_BF16_CAST(double, __double2bfloat16(input), double(__bfloat162float(input))); KERNEL_FLOAT_BF16_CAST(float, __float2bfloat16(input), __bfloat162float(input)); +// clang-format off // there are no official char casts. Instead, cast to int and then to char KERNEL_FLOAT_BF16_CAST(char, __int2bfloat16_rn(input), (char)__bfloat162int_rz(input)); -KERNEL_FLOAT_BF16_CAST( - signed char, - __int2bfloat16_rn(input), - (signed char)__bfloat162int_rz(input)); -KERNEL_FLOAT_BF16_CAST( - unsigned char, - __int2bfloat16_rn(input), - (unsigned char)__bfloat162int_rz(input)); +KERNEL_FLOAT_BF16_CAST(signed char, __int2bfloat16_rn(input), (signed char)__bfloat162int_rz(input)); +KERNEL_FLOAT_BF16_CAST(unsigned char, __int2bfloat16_rn(input), (unsigned char)__bfloat162int_rz(input)); KERNEL_FLOAT_BF16_CAST(signed short, __bfloat162short_rz(input), __short2bfloat16_rn(input)); KERNEL_FLOAT_BF16_CAST(signed int, __bfloat162int_rz(input), __int2bfloat16_rn(input)); -KERNEL_FLOAT_BF16_CAST( - signed long, - __ll2bfloat16_rn(input), - (signed long)(__bfloat162ll_rz(input))); +KERNEL_FLOAT_BF16_CAST(signed long, __ll2bfloat16_rn(input), (signed long)(__bfloat162ll_rz(input))); KERNEL_FLOAT_BF16_CAST(signed long long, __ll2bfloat16_rn(input), __bfloat162ll_rz(input)); KERNEL_FLOAT_BF16_CAST(unsigned short, __bfloat162ushort_rz(input), __ushort2bfloat16_rn(input)); KERNEL_FLOAT_BF16_CAST(unsigned int, __bfloat162uint_rz(input), __uint2bfloat16_rn(input)); -KERNEL_FLOAT_BF16_CAST( - unsigned long, - __ull2bfloat16_rn(input), - (unsigned long)(__bfloat162ull_rz(input))); +KERNEL_FLOAT_BF16_CAST(unsigned long, __ull2bfloat16_rn(input), (unsigned long)(__bfloat162ull_rz(input))); KERNEL_FLOAT_BF16_CAST(unsigned long long, __ull2bfloat16_rn(input), __bfloat162ull_rz(input)); +// clang-format on using bfloat16 = __nv_bfloat16; -//KERNEL_FLOAT_TYPE_ALIAS(half, __nv_bfloat16) //KERNEL_FLOAT_TYPE_ALIAS(float16x, __nv_bfloat16) //KERNEL_FLOAT_TYPE_ALIAS(f16x, __nv_bfloat16) +#if KERNEL_FLOAT_IS_DEVICE +namespace detail { +template<> +struct dot_impl<__nv_bfloat16, 0> { + KERNEL_FLOAT_INLINE + static __nv_bfloat16 call(const __nv_bfloat16* left, const __nv_bfloat16* right) { + return __nv_bfloat16(0); + } +}; + +template<> +struct dot_impl<__nv_bfloat16, 1> { + KERNEL_FLOAT_INLINE + static __nv_bfloat16 call(const __nv_bfloat16* left, const __nv_bfloat16* right) { + return __hmul(left[0], right[0]); + } +}; + +template +struct dot_impl<__nv_bfloat16, N> { + static_assert(N >= 2, "internal error"); + + KERNEL_FLOAT_INLINE + static __nv_bfloat16 call(const __nv_bfloat16* left, const __nv_bfloat16* right) { + __nv_bfloat162 first_a = {left[0], left[1]}; + __nv_bfloat162 first_b = {right[0], right[1]}; + __nv_bfloat162 accum = __hmul2(first_a, first_b); + +#pragma unroll + for (size_t i = 2; i + 1 < N; i += 2) { + __nv_bfloat162 a = {left[i], left[i + 1]}; + __nv_bfloat162 b = {right[i], right[i + 1]}; + accum = __hfma2(a, b, accum); + } + + __nv_bfloat16 result = __hadd(accum.x, accum.y); + + if (N % 2 != 0) { + __nv_bfloat16 a = left[N - 1]; + __nv_bfloat16 b = right[N - 1]; + result = __hfma(a, b, result); + } + + return result; + } +}; +} // namespace detail +#endif + } // namespace kernel_float #if KERNEL_FLOAT_FP16_AVAILABLE @@ -2639,9 +3768,134 @@ using bfloat16 = __nv_bfloat16; namespace kernel_float { KERNEL_FLOAT_BF16_CAST(__half, __float2bfloat16(input), __bfloat162float(input)); -} + +template<> +struct promote_type<__nv_bfloat16, __half> { + using type = float; +}; + +template<> +struct promote_type<__half, __nv_bfloat16> { + using type = float; +}; + +} // namespace kernel_float #endif // KERNEL_FLOAT_FP16_AVAILABLE #endif #endif //KERNEL_FLOAT_BF16_H +#ifndef KERNEL_FLOAT_PRELUDE_H +#define KERNEL_FLOAT_PRELUDE_H + + + + + + +namespace kernel_float { +namespace prelude { +namespace kf = ::kernel_float; + +template +using kscalar = vector>; + +template +using kvec = vector>; + +// clang-format off +template using kvec1 = kvec; +template using kvec2 = kvec; +template using kvec3 = kvec; +template using kvec4 = kvec; +template using kvec5 = kvec; +template using kvec6 = kvec; +template using kvec7 = kvec; +template using kvec8 = kvec; +// clang-format on + +#define KERNEL_FLOAT_TYPE_ALIAS(NAME, T) \ + template \ + using k##NAME = vector>; \ + using k##NAME##1 = vec; \ + using k##NAME##2 = vec; \ + using k##NAME##3 = vec; \ + using k##NAME##4 = vec; \ + using k##NAME##5 = vec; \ + using k##NAME##6 = vec; \ + using k##NAME##7 = vec; \ + using k##NAME##8 = vec; + +KERNEL_FLOAT_TYPE_ALIAS(char, char) +KERNEL_FLOAT_TYPE_ALIAS(short, short) +KERNEL_FLOAT_TYPE_ALIAS(int, int) +KERNEL_FLOAT_TYPE_ALIAS(long, long) +KERNEL_FLOAT_TYPE_ALIAS(longlong, long long) + +KERNEL_FLOAT_TYPE_ALIAS(uchar, unsigned char) +KERNEL_FLOAT_TYPE_ALIAS(ushort, unsigned short) +KERNEL_FLOAT_TYPE_ALIAS(uint, unsigned int) +KERNEL_FLOAT_TYPE_ALIAS(ulong, unsigned long) +KERNEL_FLOAT_TYPE_ALIAS(ulonglong, unsigned long long) + +KERNEL_FLOAT_TYPE_ALIAS(float, float) +KERNEL_FLOAT_TYPE_ALIAS(f32x, float) +KERNEL_FLOAT_TYPE_ALIAS(float32x, float) + +KERNEL_FLOAT_TYPE_ALIAS(double, double) +KERNEL_FLOAT_TYPE_ALIAS(f64x, double) +KERNEL_FLOAT_TYPE_ALIAS(float64x, double) + +#if KERNEL_FLOAT_FP16_AVAILABLE +KERNEL_FLOAT_TYPE_ALIAS(half, __half) +KERNEL_FLOAT_TYPE_ALIAS(f16x, __half) +KERNEL_FLOAT_TYPE_ALIAS(float16x, __half) +#endif + +#if KERNEL_FLOAT_BF16_AVAILABLE +KERNEL_FLOAT_TYPE_ALIAS(bfloat16, __nv_bfloat16) +KERNEL_FLOAT_TYPE_ALIAS(bf16, __nv_bfloat16) +#endif + +template +static constexpr extent kextent = {}; + +template +KERNEL_FLOAT_INLINE kvec, sizeof...(Args)> make_kvec(Args&&... args) { + return make_vec(std::forward(args)...); +}; + +template +using kconstant = constant; + +template +KERNEL_FLOAT_INLINE constexpr kconstant kconst(T value) { + return value; +} + +KERNEL_FLOAT_INLINE +static constexpr kconstant operator""_c(long double v) { + return static_cast(v); +} + +KERNEL_FLOAT_INLINE +static constexpr kconstant operator""_c(unsigned long long int v) { + return static_cast(v); +} + +// Deduction guides for aliases are only supported from C++20 +#if defined(__cpp_deduction_guides) && __cpp_deduction_guides >= 201907L +template +kscalar(T&&) -> kscalar>; + +template +kvec(Args&&...) -> kvec, sizeof...(Args)>; + +template +kconstant(T&&) -> kconstant>; +#endif + +} // namespace prelude +} // namespace kernel_float + +#endif diff --git a/tests/basic.cu b/tests/basic.cu deleted file mode 100644 index b76580d..0000000 --- a/tests/basic.cu +++ /dev/null @@ -1,49 +0,0 @@ -#include "common.h" -#include "kernel_float.h" - -namespace kf = kernel_float; - -template> -struct basic_test; - -template -struct basic_test> { - __host__ __device__ void operator()(generator gen) { - T items[N] = {gen.next(Is)...}; - kf::vec a = {items[Is]...}; - - // check if getters work - ASSERT(equals(a.get(Is), items[Is]) && ...); - ASSERT(equals(a.get(kf::const_index {}), items[Is]) && ...); - ASSERT(equals(a[Is], items[Is]) && ...); - ASSERT(equals(a[kf::const_index {}], items[Is]) && ...); - - // check if setter works - T new_items[N] = {gen.next(Is)...}; - (a.set(kf::const_index {}, new_items[Is]), ...); - ASSERT(equals(a.get(Is), new_items[Is]) && ...); - - // check if setter works - T more_new_items[N] = {gen.next(Is)...}; - ((a[Is] = more_new_items[Is]), ...); - ASSERT(equals(a.get(Is), more_new_items[Is]) && ...); - - // check default constructor - kf::vec b; - ASSERT(equals(b.get(Is), T {}) && ...); - - // check broadcast constructor - T value = gen(); - kf::vec c {value}; - ASSERT(equals(c.get(Is), value) && ...); - - // check make_vec - kf::vec d = kf::make_vec(items[Is]...); - ASSERT(equals(d.get(Is), items[Is]) && ...); - } -}; - -TEST_CASE("basic") { - run_on_host_and_device(); - run_on_device(); -} diff --git a/tests/basics.cu b/tests/basics.cu new file mode 100644 index 0000000..9d15b56 --- /dev/null +++ b/tests/basics.cu @@ -0,0 +1,148 @@ +#include "common.h" + +struct basics_tests { + template + __host__ __device__ void operator()(generator gen, std::index_sequence) { + // default constructor + { + kf::vec x; + ASSERT(equals(x[I], T()) && ...); + } + + // filled with one + { + kf::vec x = {T((gen.next(I), 1))...}; + ASSERT(equals(x[I], T(1)) && ...); + } + + // filled with steps + { + kf::vec x = {T(I)...}; + ASSERT(equals(x[I], T(I)) && ...); + } + + // broadcast constructor + { + T init = gen.next(); + kf::vec x {init}; + ASSERT(equals(x[I], init) && ...); + } + + // Getters + T items[N] = {gen.next(I)...}; + kf::vec a = {items[I]...}; + + ASSERT(equals(a[I], items[I]) && ...); + ASSERT(equals(a.get(I), items[I]) && ...); + ASSERT(equals(a.at(I), items[I]) && ...); + ASSERT(equals(a(I), items[I]) && ...); + + // Data, begin, end + ASSERT(a.size() == N); + ASSERT(&a[0] == a.data()); + ASSERT(&a[0] == a.begin()); + ASSERT(&a[0] + N == a.end()); + ASSERT(&a[0] == a.cdata()); + ASSERT(&a[0] == a.cbegin()); + ASSERT(&a[0] + N == a.cend()); + + // setters + T new_items[N] = {gen.next(I)...}; + (a.set(I, new_items[I]), ...); + ASSERT(equals(a[I], new_items[I]) && ...); + } +}; + +REGISTER_TEST_CASE("basics", basics_tests, int, float) + +struct creation_tests { + __host__ __device__ void operator()(generator gen) { + using kernel_float::into_vector; + using kernel_float::make_vec; + + // into_vector on scalar + { + kf::vec a = into_vector(int(5)); + ASSERT(a[0] == 5); + } + + // into_vector on CUDA vector types + { + kf::vec a = into_vector(make_int1(5)); + kf::vec b = into_vector(make_int2(5, 4)); + kf::vec c = into_vector(make_int3(5, 4, -1)); + kf::vec d = into_vector(make_int4(5, 4, -1, 0)); + + ASSERT(a[0] == 5); + ASSERT(b[0] == 5 && b[1] == 4); + ASSERT(c[0] == 5 && c[1] == 4 && c[2] == -1); + ASSERT(d[0] == 5 && d[1] == 4 && d[2] == -1 && d[3] == 0); + } + + // into_vector on C-style array + { + int items[3] = {1, 2, 3}; + kf::vec a = into_vector(items); + ASSERT(a[0] == 1 && a[1] == 2 && a[2] == 3); + } + + // into_vector on kf array + { + kf::vec items = {1, 2, 3}; + kf::vec a = into_vector(items); + ASSERT(a[0] == 1 && a[1] == 2 && a[2] == 3); + } + + // make_vec + { + kf::vec a = make_vec(true, short(2), int(3)); + ASSERT(a[0] == 1 && a[1] == 2 && a[2] == 3); + } + } + + __host__ __device__ void operator()(generator gen) { + using kernel_float::into_vector; + using kernel_float::make_vec; + + // into_vector on scalar + { + kf::vec a = into_vector(float(5.0f)); + ASSERT(a[0] == 5.0f); + } + + // into_vector on CUDA vector types + { + kf::vec a = into_vector(make_float1(5.0f)); + kf::vec b = into_vector(make_float2(5.0f, 4.0f)); + kf::vec c = into_vector(make_float3(5.0f, 4.0f, -1.0f)); + kf::vec d = into_vector(make_float4(5.0f, 4.0f, -1.0f, 0.0f)); + + ASSERT(a[0] == 5.0f); + ASSERT(b[0] == 5.0f && b[1] == 4.0f); + ASSERT(c[0] == 5.0f && c[1] == 4.0f && c[2] == -1.0f); + ASSERT(d[0] == 5.0f && d[1] == 4.0f && d[2] == -1.0f && d[3] == 0.0f); + } + + // into_vector on C-style array + { + float items[3] = {1.0f, 2.0f, 3.0f}; + kf::vec a = into_vector(items); + ASSERT(a[0] == 1.0f && a[1] == 2.0f && a[2] == 3.0f); + } + + // into_vector on kf array + { + kf::vec items = {1.0f, 2.0f, 3.0f}; + kf::vec a = into_vector(items); + ASSERT(a[0] == 1.0f && a[1] == 2.0f && a[2] == 3.0f); + } + + // make_vec + { + kf::vec a = make_vec(true, int(2), 3.0f); + ASSERT(a[0] == 1.0f && a[1] == 2.0f && a[2] == 3.0f); + } + } +}; + +REGISTER_TEST_CASE("into_vec and make_vec", creation_tests, int, float) \ No newline at end of file diff --git a/tests/binops.cu b/tests/binops.cu index 8409b71..114889f 100644 --- a/tests/binops.cu +++ b/tests/binops.cu @@ -1,149 +1,137 @@ #include "common.h" -#include "kernel_float.h" -namespace kf = kernel_float; +struct binops_tests { + template + __host__ __device__ void operator()(generator gen, std::index_sequence) { + T x[N] = {gen.next(I)...}; + T y[N] = {gen.next(I)...}; -template> -struct arithmetic_test; - -template -struct arithmetic_test> { - __host__ __device__ void operator()(generator gen) { - kf::vec a {gen.next(Is)...}, b {gen.next(Is)...}, c; + kf::vec a = {x[I]...}; + kf::vec b = {y[I]...}; + kf::vec c; - // binary operator + // Arithmetic c = a + b; - ASSERT(equals(c.get(Is), a.get(Is) + b.get(Is)) && ...); + ASSERT(equals(T(x[I] + y[I]), c[I]) && ...); c = a - b; - ASSERT(equals(c.get(Is), a.get(Is) - b.get(Is)) && ...); + ASSERT(equals(T(x[I] - y[I]), c[I]) && ...); c = a * b; - ASSERT(equals(c.get(Is), a.get(Is) * b.get(Is)) && ...); + ASSERT(equals(T(x[I] * y[I]), c[I]) && ...); - c = a / b; - ASSERT(equals(c.get(Is), a.get(Is) / b.get(Is)) && ...); + // Results in division by zero + // c = a / b; + // ASSERT(equals(T(x[I] / y[I]), c[I]) && ...); - // assignment operator - c = a; - c += b; - ASSERT(equals(c.get(Is), a.get(Is) + b.get(Is)) && ...); + // Results in division by zero + // c = a % b; + // ASSERT(equals(T(x[I] % y[I]), c[I]) && ...); - c = a; - c -= b; - ASSERT(equals(c.get(Is), a.get(Is) - b.get(Is)) && ...); - - c = a; - c *= b; - ASSERT(equals(c.get(Is), a.get(Is) * b.get(Is)) && ...); + // Comparison + c = a < b; + ASSERT(equals(T(x[I] < y[I]), c[I]) && ...); - c = a; - c /= b; - ASSERT(equals(c.get(Is), a.get(Is) / b.get(Is)) && ...); - } -}; + c = a > b; + ASSERT(equals(T(x[I] > y[I]), c[I]) && ...); -template> -struct minmax_test; + c = a <= b; + ASSERT(equals(T(x[I] <= y[I]), c[I]) && ...); -template -struct minmax_test> { - __host__ __device__ void operator()(generator gen) { - kf::vec a {gen.next(Is)...}, b {gen.next(Is)...}, c; + c = a >= b; + ASSERT(equals(T(x[I] >= y[I]), c[I]) && ...); - c = kf::min(a, b); - ASSERT(equals(c.get(Is), a.get(Is) < b.get(Is) ? a.get(Is) : b.get(Is)) && ...); + c = a == b; + ASSERT(equals(T(x[I] == y[I]), c[I]) && ...); - c = kf::max(a, b); - ASSERT(equals(c.get(Is), a.get(Is) > b.get(Is) ? a.get(Is) : b.get(Is)) && ...); - } -}; + c = a != b; + ASSERT(equals(T(x[I] != y[I]), c[I]) && ...); -template -struct minmax_test> { - __host__ __device__ void operator()(generator gen) { - kf::vec a {gen.next(Is)...}, b {gen.next(Is)...}, c; + // Assignment + c = a; + c += b; + ASSERT(equals(T(x[I] + y[I]), c[I]) && ...); - c = kf::min(a, b); - ASSERT(equals(c.get(Is), fminf(a.get(Is), b.get(Is))) && ...); + c = a; + c -= b; + ASSERT(equals(T(x[I] - y[I]), c[I]) && ...); - c = kf::max(a, b); - ASSERT(equals(c.get(Is), fmaxf(a.get(Is), b.get(Is))) && ...); + c = a; + c *= b; + ASSERT(equals(T(x[I] * y[I]), c[I]) && ...); } }; -template -struct minmax_test> { - __host__ __device__ void operator()(generator gen) { - kf::vec a {gen.next(Is)...}, b {gen.next(Is)...}, c; - - c = kf::min(a, b); - ASSERT(equals(c.get(Is), fmin(a.get(Is), b.get(Is))) && ...); - - c = kf::max(a, b); - ASSERT(equals(c.get(Is), fmax(a.get(Is), b.get(Is))) && ...); - } -}; +REGISTER_TEST_CASE("binary operators", binops_tests, bool, int, float, double) +REGISTER_TEST_CASE_GPU("binary operators", binops_tests, __half, __nv_bfloat16) -template> -struct relational_test; +struct binops_float_tests { + template + __host__ __device__ void operator()(generator gen, std::index_sequence) { + T x[N] = {gen.next(I)...}; + T y[N] = {gen.next(I)...}; -template -struct relational_test> { - __host__ __device__ void operator()(generator gen) { - kf::vec a {gen.next(Is)...}; - kf::vec b {gen.next(Is)...}; + kf::vec a = {x[I]...}; + kf::vec b = {y[I]...}; kf::vec c; - c = a == b; - ASSERT(equals(c.get(Is), T(a.get(Is) == b.get(Is))) && ...); - - c = a != b; - ASSERT(equals(c.get(Is), T(a.get(Is) != b.get(Is))) && ...); - - c = a < b; - ASSERT(equals(c.get(Is), T(a.get(Is) < b.get(Is))) && ...); - - c = a <= b; - ASSERT(equals(c.get(Is), T(a.get(Is) <= b.get(Is))) && ...); + c = a / b; + ASSERT(equals(T(x[I] / y[I]), c[I]) && ...); - c = a > b; - ASSERT(equals(c.get(Is), T(a.get(Is) > b.get(Is))) && ...); + // remainder is not support for fp16 + if constexpr (is_none_of) { + // c = a % b; + // ASSERT(equals(T(fmod(x[I], y[I])), c[I]) && ...); + } + } +}; - c = a >= b; - ASSERT(equals(c.get(Is), T(a.get(Is) >= b.get(Is))) && ...); +REGISTER_TEST_CASE("binary float operators", binops_float_tests, float, double) +REGISTER_TEST_CASE_GPU("binary float operators", binops_float_tests, __half, __nv_bfloat16) + +struct minmax_tests { + template + __host__ __device__ void operator()(generator gen, std::index_sequence) { + T x[N] = {gen.next(I)...}; + T y[N] = {gen.next(I)...}; + + kf::vec a = {x[I]...}; + kf::vec b = {y[I]...}; + + kf::vec lo = min(a, b); + kf::vec hi = max(a, b); + + if constexpr (is_one_of) { + ASSERT(equals(fmin(a[I], b[I]), lo[I]) && ...); + ASSERT(equals(fmax(a[I], b[I]), hi[I]) && ...); + } else if constexpr (is_one_of) { + ASSERT(equals(fminf(a[I], b[I]), lo[I]) && ...); + ASSERT(equals(fmaxf(a[I], b[I]), hi[I]) && ...); + } else if constexpr (is_one_of) { + ASSERT(equals(__hmin(a[I], b[I]), lo[I]) && ...); + ASSERT(equals(__hmax(a[I], b[I]), hi[I]) && ...); + } else { + ASSERT(equals(x[I] < y[I] ? x[I] : y[I], lo[I]) && ...); + ASSERT(equals(x[I] < y[I] ? y[I] : x[I], hi[I]) && ...); + } } }; -template> -struct bitwise_test; +REGISTER_TEST_CASE("min/max functions", minmax_tests, bool, int, float, double) +REGISTER_TEST_CASE_GPU("min/max functions", minmax_tests, __half, __nv_bfloat16) -template -struct bitwise_test> { +struct cross_test { + template __host__ __device__ void operator()(generator gen) { - kf::vec a = {gen.next(Is)...}; - kf::vec b = {gen.next(Is)...}; - - kf::vec c = a | b; - ASSERT(equals(c.get(Is), T(a.get(Is) | b.get(Is))) && ...); - - c = a & b; - ASSERT(equals(c.get(Is), T(a.get(Is) & b.get(Is))) && ...); + kf::vec a = {1, 2, 3}; + kf::vec b = {4, 5, 6}; + kf::vec c = cross(a, b); - c = a ^ b; - ASSERT(equals(c.get(Is), T(a.get(Is) ^ b.get(Is))) && ...); + ASSERT(c[0] == T(-3)); + ASSERT(c[1] == T(6)); + ASSERT(c[2] == T(-3)); } }; -TEST_CASE("binary operators") { - run_on_host_and_device(); - run_on_device(); - - run_on_host_and_device(); - run_on_device(); - - run_on_host_and_device(); - run_on_device(); - - run_on_host_and_device(); -} +REGISTER_TEST_CASE("cross product", cross_test, float, double) +REGISTER_TEST_CASE_GPU("cross product", cross_test, __half, __nv_bfloat16) \ No newline at end of file diff --git a/tests/cast.cu b/tests/cast.cu deleted file mode 100644 index 9f13e7c..0000000 --- a/tests/cast.cu +++ /dev/null @@ -1,156 +0,0 @@ -#include "common.h" -#include "kernel_float.h" - -namespace kf = kernel_float; - -template> -struct cast_test; - -template -struct cast_test> { - __host__ __device__ void operator()(generator gen) { - kf::vec a {gen.next(Is)...}; - kf::vec b = kf::cast(a); - - ASSERT(equals(B(a.get(Is)), b.get(Is)) && ...); - } -}; - -template -struct cast_test> { - __host__ __device__ void operator()(generator gen) { - kf::vec a {gen.next(Is)...}; - kf::vec<__half, N> b = kf::cast<__half>(a); - - for (size_t i = 0; i < N; i++) { - printf("%d/%d] %f %d\n", int(i), int(N), (double)(b.get(i)), int(a[i])); - } - - ASSERT(equals(__half(a.get(Is)), b.get(Is)) && ...); - } -}; - -template -struct cast_test<__half, long, N, std::index_sequence> { - __host__ __device__ void operator()(generator<__half> gen) { - kf::vec<__half, N> a {gen.next(Is)...}; - kf::vec b = kf::cast(a); - ASSERT(equals((long)(long long)a.get(Is), b.get(Is)) && ...); - } -}; - -template -struct cast_test> { - __host__ __device__ void operator()(generator gen) { - kf::vec a {gen.next(Is)...}; - kf::vec<__half, N> b = kf::cast<__half>(a); - ASSERT(equals(__half((long long)a.get(Is)), b.get(Is)) && ...); - } -}; - -template -struct cast_test> { - __host__ __device__ void operator()(generator gen) { - kf::vec a {gen.next(Is)...}; - kf::vec<__half, N> b = kf::cast<__half>(a); - ASSERT(equals((__half)(unsigned long long)(a.get(Is)), b.get(Is)) && ...); - } -}; - -template -struct cast_test<__half, char, N, std::index_sequence> { - __host__ __device__ void operator()(generator<__half> gen) { - kf::vec<__half, N> a {gen.next(Is)...}; - kf::vec b = kf::cast(a); - ASSERT(equals((char)(int)(a.get(Is)), b.get(Is)) && ...); - } -}; - -template -struct cast_test<__nv_bfloat16, long, N, std::index_sequence> { - __host__ __device__ void operator()(generator<__nv_bfloat16> gen) { - kf::vec<__nv_bfloat16, N> a {gen.next(Is)...}; - kf::vec b = kf::cast(a); - ASSERT(equals((long)(long long)a.get(Is), b.get(Is)) && ...); - } -}; - -template -struct cast_test> { - __host__ __device__ void operator()(generator gen) { - kf::vec a {gen.next(Is)...}; - kf::vec<__nv_bfloat16, N> b = kf::cast<__nv_bfloat16>(a); - ASSERT(equals(__nv_bfloat16((long long)a.get(Is)), b.get(Is)) && ...); - } -}; - -template -struct cast_test> { - __host__ __device__ void operator()(generator gen) { - kf::vec a {gen.next(Is)...}; - kf::vec<__nv_bfloat16, N> b = kf::cast<__nv_bfloat16>(a); - ASSERT(equals((__nv_bfloat16)(unsigned long long)(a.get(Is)), b.get(Is)) && ...); - } -}; - -template -struct cast_test<__nv_bfloat16, char, N, std::index_sequence> { - __host__ __device__ void operator()(generator<__nv_bfloat16> gen) { - kf::vec<__nv_bfloat16, N> a {gen.next(Is)...}; - kf::vec b = kf::cast(a); - ASSERT(equals((char)(int)(a.get(Is)), b.get(Is)) && ...); - } -}; - -template -struct cast_test<__nv_bfloat16, __half, N, std::index_sequence> { - __host__ __device__ void operator()(generator<__nv_bfloat16> gen) { - kf::vec<__nv_bfloat16, N> a {gen.next(Is)...}; - kf::vec<__half, N> b = kf::cast<__half>(a); - ASSERT(equals((__half)(float)(a.get(Is)), b.get(Is)) && ...); - } -}; - -template -struct cast_test<__half, __nv_bfloat16, N, std::index_sequence> { - __host__ __device__ void operator()(generator<__half> gen) { - kf::vec<__half, N> a {gen.next(Is)...}; - kf::vec<__nv_bfloat16, N> b = kf::cast<__nv_bfloat16>(a); - ASSERT(equals((__nv_bfloat16)(float)(a.get(Is)), b.get(Is)) && ...); - } -}; - -template -struct cast_to { - template - using type = cast_test; -}; - -TEST_CASE("cast operators") { - auto types = type_sequence< - bool, - char, - short, - int, - unsigned int, - long, - unsigned long, - long long, - float, - double, - __half, - __nv_bfloat16> {}; - - run_on_host_and_device::template type>(types); - run_on_host_and_device::template type>(types); - run_on_host_and_device::template type>(types); - run_on_host_and_device::template type>(types); - run_on_host_and_device::template type>(types); - run_on_host_and_device::template type>(types); - run_on_host_and_device::template type>(types); - run_on_host_and_device::template type>(types); - run_on_host_and_device::template type>(types); - run_on_host_and_device::template type>(types); - - //bool, char, short, int, long long, __half, float, double -} diff --git a/tests/common.h b/tests/common.h index 712d945..dac12bb 100644 --- a/tests/common.h +++ b/tests/common.h @@ -4,19 +4,14 @@ #include #include -#include -#include #include "catch2/catch_all.hpp" #include "kernel_float.h" -#define ASSERT(expr) check_assertions((expr), #expr, __FILE__, __LINE__); - -static __host__ __device__ int -check_assertions(bool result, const char* expr, const char* file, int line) { - if (result) - return 0; +namespace kf = kernel_float; +namespace detail { +static __host__ __device__ void __assertion_failed(const char* expr, const char* file, int line) { #ifndef __CUDA_ARCH__ std::string msg = "assertion failed: " + std::string(expr) + " (" + file + ":" + std::to_string(line) + ")"; @@ -28,249 +23,352 @@ check_assertions(bool result, const char* expr, const char* file, int line) { ; #endif } +} // namespace detail + +#define ASSERT(...) \ + do { \ + bool __result = (__VA_ARGS__); \ + if (!__result) { \ + ::detail::__assertion_failed(#__VA_ARGS__, __FILE__, __LINE__); \ + } \ + } while (0) -template -__host__ __device__ void ignore(Ts...) {} +#define ASSERT_EQ(A, B) ASSERT(equals(A, B)) +#define ASSERT_APPROX(A, B) ASSERT(approx(A, B)) +#define ASSERT_ALL(E) ASSERT((E) && ...) +#define ASSERT_EQ_ALL(A, B) ASSERT_ALL(equals(A, B)) +#define ASSERT_APPROX_ALL(A, B) ASSERT_ALL(approx(A, B)) + +namespace detail { template struct equals_helper { - __host__ __device__ static bool call(T left, T right) { + static __host__ __device__ bool call(const T& left, const T& right) { return left == right; } }; template<> struct equals_helper { - __host__ __device__ static bool call(double left, double right) { - return (isnan(left) && isnan(right)) || (isinf(left) && isinf(right)) || left == right; + static __host__ __device__ bool call(const double& left, const double& right) { + return (isnan(left) && isnan(right)) || (left == right); + } +}; + +template<> +struct equals_helper { + static __host__ __device__ bool call(const float& left, const float& right) { + return (isnan(left) && isnan(right)) || (left == right); } }; template<> -struct equals_helper: equals_helper {}; +struct equals_helper<__half> { + static __host__ __device__ bool call(const __half& left, const __half& right) { + return equals_helper::call(float(left), float(right)); + } +}; template<> -struct equals_helper<__half>: equals_helper {}; +struct equals_helper<__nv_bfloat16> { + static __host__ __device__ bool call(const __nv_bfloat16& left, const __nv_bfloat16& right) { + return equals_helper::call(float(left), float(right)); + } +}; + +template +struct equals_helper> { + static __host__ __device__ bool call(const kf::vec& left, const kf::vec& right) { + for (int i = 0; i < N; i++) { + if (!equals_helper::call(left[i], right[i])) { + return false; + } + } + + return true; + } +}; + +} // namespace detail template -__host__ __device__ bool equals(T left, T right) { - return equals_helper::call(left, right); +__host__ __device__ bool equals(const T& left, const T& right) { + return detail::equals_helper::call(left, right); } -template -struct type_sequence {}; +namespace detail { +template +struct approx_helper { + static __host__ __device__ bool call(const T& left, const T& right) { + return equals_helper::call(left, right); + } +}; -template -struct size_sequence {}; +template<> +struct approx_helper { + static __host__ __device__ bool call(double left, double right, double threshold = 1e-8) { + return equals_helper::call(left, right) + || ::fabs(left - right) < threshold * ::fabs(left); + } +}; + +template<> +struct approx_helper { + static __host__ __device__ bool call(float left, float right) { + return approx_helper::call(double(left), double(right), 1e-4); + } +}; + +template<> +struct approx_helper<__half> { + static __host__ __device__ bool call(__half left, __half right) { + return approx_helper::call(double(left), double(right), 0.01); + } +}; + +template<> +struct approx_helper<__nv_bfloat16> { + static __host__ __device__ bool call(__nv_bfloat16 left, __nv_bfloat16 right) { + return approx_helper::call(double(left), double(right), 0.05); + } +}; +} // namespace detail template -struct type_name {}; -#define DEFINE_TYPE_NAME(T) \ - template<> \ - struct type_name { \ - static constexpr const char* value = #T; \ - }; +__host__ __device__ bool approx(const T& left, const T& right) { + return detail::approx_helper::call(left, right); +} -DEFINE_TYPE_NAME(bool) -DEFINE_TYPE_NAME(char) -DEFINE_TYPE_NAME(short) -DEFINE_TYPE_NAME(int) -DEFINE_TYPE_NAME(unsigned int) -DEFINE_TYPE_NAME(long) -DEFINE_TYPE_NAME(unsigned long) -DEFINE_TYPE_NAME(long long) -DEFINE_TYPE_NAME(__half) -DEFINE_TYPE_NAME(__nv_bfloat16) -DEFINE_TYPE_NAME(float) -DEFINE_TYPE_NAME(double) +namespace detail { +template +struct is_one_of_helper; + +template +struct is_one_of_helper: std::false_type {}; + +template +struct is_one_of_helper: std::true_type {}; + +template +struct is_one_of_helper: is_one_of_helper {}; +} // namespace detail + +template +static constexpr bool is_one_of = detail::is_one_of_helper::value; +template +static constexpr bool is_none_of = !detail::is_one_of_helper::value; + +namespace detail { template -struct generate_value; +struct generator_value; template<> -struct generate_value { - __host__ __device__ static bool call(uint64_t value) { - return bool(value & 0x1); +struct generator_value { + static __host__ __device__ bool call(uint64_t bits) { + return bool(bits % 2); } }; template -struct generate_value< - T, - typename std::enable_if::value && !std::is_same::value>::type> { - __host__ __device__ static T call(uint64_t value) { - return T(value); +struct generator_value && !std::is_same_v>> { + static constexpr T min_value = std::numeric_limits::min(); + static constexpr T max_value = std::numeric_limits::max(); + + static __host__ __device__ T call(uint64_t bits) { + if ((bits & 0xf) == 0xa) { + return T(0); + } else if ((bits & 0xf) == 0xb) { + return min_value; + } else if ((bits & 0xf) == 0xc) { + return max_value; + } else { + return T(bits); + } } }; template -struct generate_value::value>::type> { - __host__ __device__ static T call(uint64_t value) { - if ((value & 0xf) == 0) { +struct generator_value>> { + static constexpr T max_value = std::numeric_limits::max(); + + __host__ __device__ static T call(uint64_t bits) { + if ((bits & 0xf) == 0) { return T(0) / T(0); // nan - } else if ((value & 0xf) == 1) { + } else if ((bits & 0xf) == 1) { return T(1) / T(0); // inf - } else if ((value & 0xf) == 2) { - return -T(0) / T(0); // +inf - } else if ((value & 0xf) == 3) { - return 0; + } else if ((bits & 0xf) == 2) { + return -T(1) / T(0); // +inf + } else if ((bits & 0xf) == 3) { + return T(0); } else { - return T(value) / T(UINT64_MAX); + return (T(bits) / T(max_value)) * (bits % 2 ? T(-1) : T(+1)); } } }; template<> -struct generate_value<__half> { +struct generator_value<__half> { __host__ __device__ static __half call(uint64_t seed) { - return __half(generate_value::call(seed)); + return __half(generator_value::call(seed)); } }; template<> -struct generate_value<__nv_bfloat16> { +struct generator_value<__nv_bfloat16> { __host__ __device__ static __nv_bfloat16 call(uint64_t seed) { - return __nv_bfloat16(generate_value::call(seed)); + return __nv_bfloat16(generator_value::call(seed)); } }; +} // namespace detail -template +template struct generator { __host__ __device__ generator(uint64_t seed = 6364136223846793005ULL) : seed_(seed) { next(); } - __host__ __device__ T next(uint64_t ignore = 0) { + template + __host__ __device__ T next(R ignore = {}) { seed_ = 6364136223846793005ULL * seed_ + 1442695040888963407ULL; - return generate_value::call(seed_); - } - - __host__ __device__ T operator()() { - return next(); + return detail::generator_value::call(seed_); } private: uint64_t seed_; }; -template class F, typename T> -void run_sizes(size_sequence<>) { - // empty -} +template +struct type_name { + static constexpr const char* value = "???"; +}; -template class F, typename T, size_t N, size_t... Is, typename... Args> -void run_sizes(size_sequence, Args... args) { - //SECTION("size=" + std::to_string(N)) - { - INFO("N=" << N); - F {}(args...); - } +#define DEFINE_TYPE_NAME(T) \ + template<> \ + struct type_name { \ + static constexpr const char* value = #T; \ + }; - run_sizes(size_sequence {}, args...); -} +DEFINE_TYPE_NAME(bool) +DEFINE_TYPE_NAME(signed char) +DEFINE_TYPE_NAME(char) +DEFINE_TYPE_NAME(short) +DEFINE_TYPE_NAME(int) +DEFINE_TYPE_NAME(long) +DEFINE_TYPE_NAME(long long) +DEFINE_TYPE_NAME(unsigned char) +DEFINE_TYPE_NAME(unsigned short) +DEFINE_TYPE_NAME(unsigned int) +DEFINE_TYPE_NAME(unsigned long) +DEFINE_TYPE_NAME(unsigned long long) +DEFINE_TYPE_NAME(__half) +DEFINE_TYPE_NAME(__nv_bfloat16) +DEFINE_TYPE_NAME(float) +DEFINE_TYPE_NAME(double) -template< - template - class F, - typename T, - typename... Ts, - size_t... Is, - typename... Args> -void run_combinations(type_sequence, size_sequence, Args... args) { - //SECTION(std::string("type=") + type_name::value) - { - INFO("T=" << type_name::value); - run_sizes(size_sequence {}); - } +template +struct type_sequence {}; + +template +struct size_sequence {}; - run_combinations(type_sequence {}, size_sequence {}, args...); +using default_size_sequence = size_sequence<1, 2, 3, 4, 5, 6, 7, 8>; + +namespace detail { +template +void iterate_sizes(F runner, size_sequence) { + runner.template run(); + iterate_sizes(runner, size_sequence {}); } -template class F, typename... Ts, size_t... Is, typename... Args> -void run_combinations(type_sequence<>, size_sequence, Args... args) {} +template +void iterate_sizes(F, size_sequence<>) {} -template class F, typename T, size_t N> +template struct host_runner { - template - void operator()(Args... args) { - for (size_t i = 0; i < 5; i++) { - INFO("seed=" << i); - F {}(generator(i), args...); - } - } -}; + F fun; + + host_runner(F fun) : fun(fun) {} -template class F> -struct host_runner_helper { template - using type = host_runner; + void run() { + for (int seed = 0; seed < 5; seed++) { + INFO("T=" << type_name::value); + INFO("N=" << N); + INFO("seed=" << seed); + + if constexpr (std::is_invocable_v) { + fun(); + } else if constexpr (std::is_invocable_v>) { + fun(generator(seed)); + } else { + fun(generator(seed), std::make_index_sequence {}); + } + } + } }; -template class F, typename... Ts, size_t... Is> -void run_on_host(type_sequence, size_sequence) { - run_combinations::template type>( - type_sequence {}, - size_sequence {}); -} - -template class F, typename... Ts> -void run_on_host(type_sequence = {}) { - run_on_host(type_sequence {}, size_sequence<1, 2, 3, 4, 5, 6, 7, 8> {}); -} - template __global__ void kernel(F fun, Args... args) { fun(args...); } -template class F, typename T, size_t N> +template struct device_runner { - template - void operator()(Args... args) { - static bool gpu_enabled = true; - if (!gpu_enabled) { - return; - } + F fun; - cudaError_t code = cudaSetDevice(0); - if (code != cudaSuccess) { - gpu_enabled = false; - WARN("skipping device code"); - return; + device_runner(F fun) : fun(fun) {} + + template + void run() { + if (cudaSetDevice(0) != cudaSuccess) { + FAIL("failed to initialize CUDA device, does this machine have a GPU?"); } - //SECTION("environment=GPU") - { - for (size_t i = 0; i < 5; i++) { - INFO("seed=" << i); - CHECK(cudaDeviceSynchronize() == cudaSuccess); - kernel<<<1, 1>>>(F {}, generator(i), args...); - CHECK(cudaDeviceSynchronize() == cudaSuccess); + for (int seed = 0; seed < 5; seed++) { + INFO("T=" << type_name::value); + INFO("N=" << N); + INFO("seed=" << seed); + + CHECK(cudaDeviceSynchronize() == cudaSuccess); + + if constexpr (std::is_invocable_v) { + kernel<<<1, 1>>>(fun); + } else if constexpr (std::is_invocable_v>) { + kernel<<<1, 1>>>(fun, generator(seed)); + } else { + kernel<<<1, 1>>>(fun, generator(seed), std::make_index_sequence {}); } + + CHECK(cudaDeviceSynchronize() == cudaSuccess); } } }; +} // namespace detail -template class F> -struct device_runner_helper { - template - using type = device_runner; -}; - -template class F, typename... Ts, size_t... Is> -void run_on_device(type_sequence, size_sequence) { - run_combinations::template type>( - type_sequence {}, - size_sequence {}); +template +void run_tests_host(F fun, type_sequence, size_sequence) { + detail::iterate_sizes(detail::host_runner(fun), size_sequence {}); } -template class F, typename... Ts> -void run_on_device(type_sequence = {}) { - run_on_device(type_sequence {}, size_sequence<1, 2, 3, 4, 5, 6, 7, 8> {}); +template +void run_tests_device(F fun, type_sequence, size_sequence) { + detail::iterate_sizes(detail::device_runner(fun), size_sequence {}); } -template class F, typename... Ts> -void run_on_host_and_device(type_sequence = {}) { - run_on_host(type_sequence {}); - run_on_device(type_sequence {}); -} +#define REGISTER_TEST_CASE_CPU(NAME, F, ...) \ + TEMPLATE_TEST_CASE(NAME " - CPU", "", __VA_ARGS__) { \ + run_tests_host(F {}, type_sequence {}, default_size_sequence {}); \ + CHECK("done"); \ + } + +#define REGISTER_TEST_CASE_GPU(NAME, F, ...) \ + TEMPLATE_TEST_CASE(NAME " - GPU", "[GPU]", __VA_ARGS__) { \ + run_tests_device(F {}, type_sequence {}, default_size_sequence {}); \ + CHECK("done"); \ + } + +#undef REGISTER_TEST_CASE +#define REGISTER_TEST_CASE(NAME, F, ...) \ + REGISTER_TEST_CASE_CPU(NAME, F, __VA_ARGS__) \ + REGISTER_TEST_CASE_GPU(NAME, F, __VA_ARGS__) diff --git a/tests/constant.cu b/tests/constant.cu new file mode 100644 index 0000000..a6e011c --- /dev/null +++ b/tests/constant.cu @@ -0,0 +1,40 @@ +#include "common.h" + +#define ASSERT_TYPE(A, B) ASSERT(std::is_same::value); + +struct constant_tests { + template + __host__ __device__ void operator()(generator gen) { + T value = gen.next(); + kf::vec vector = {gen.next(), gen.next()}; + + ASSERT_EQ(kf::make_constant(5.0) + value, T(5) + value); + ASSERT_EQ(value + kf::make_constant(5.0), value + T(5)); + ASSERT_EQ(kf::make_constant(5.0) + vector, T(5) + vector); + ASSERT_EQ(vector + kf::make_constant(5.0), vector + T(5)); + + ASSERT_EQ(kf::make_constant(5.0) - value, T(5) - value); + ASSERT_EQ(value - kf::make_constant(5.0), value - T(5)); + ASSERT_EQ(kf::make_constant(5.0) - vector, T(5) - vector); + ASSERT_EQ(vector - kf::make_constant(5.0), vector - T(5)); + + ASSERT_EQ(kf::make_constant(5.0) * value, T(5) * value); + ASSERT_EQ(value * kf::make_constant(5.0), value * T(5)); + ASSERT_EQ(kf::make_constant(5.0) * vector, T(5) * vector); + ASSERT_EQ(vector * kf::make_constant(5.0), vector * T(5)); + + // These results in division by zero for integers + // ASSERT_EQ(kf::make_constant(5.0) / value, T(5) / value); + // ASSERT_EQ(value / kf::make_constant(5.0), value / T(5)); + // ASSERT_EQ(kf::make_constant(5.0) / vector, T(5) / vector); + // ASSERT_EQ(vector / kf::make_constant(5.0), vector / T(5)); + // + // ASSERT_EQ(kf::make_constant(5.0) % value, T(5) % value); + // ASSERT_EQ(value % kf::make_constant(5.0), value % T(5)); + // ASSERT_EQ(kf::make_constant(5.0) % vector, T(5) % vector); + // ASSERT_EQ(vector % kf::make_constant(5.0), vector % T(5)); + } +}; + +REGISTER_TEST_CASE("constant tests", constant_tests, int, float, double) +REGISTER_TEST_CASE_GPU("constant tests", constant_tests, __half, __nv_bfloat16) diff --git a/tests/promotion.cu b/tests/promotion.cu new file mode 100644 index 0000000..3367c3a --- /dev/null +++ b/tests/promotion.cu @@ -0,0 +1,117 @@ +#include "common.h" + +// Check if combining type `vec` and `vec` results in `vec` +#define CHECK_PROMOTION(A, B, C) \ + CHECK(std::is_same() + kf::vec()), kf::vec>::value); + +TEST_CASE("type promotion") { + CHECK_PROMOTION(int, int, int); + CHECK_PROMOTION(int, float, float); + CHECK_PROMOTION(int, double, double); + // CHECK_PROMOTION(int, unsigned int, int); + CHECK_PROMOTION(int, bool, int); + CHECK_PROMOTION(int, __half, __half); + CHECK_PROMOTION(int, __nv_bfloat16, __nv_bfloat16); + // CHECK_PROMOTION(int, char, int); + CHECK_PROMOTION(int, signed char, int); + // CHECK_PROMOTION(int, unsigned char, int); + + CHECK_PROMOTION(float, int, float); + CHECK_PROMOTION(float, float, float); + CHECK_PROMOTION(float, double, double); + CHECK_PROMOTION(float, unsigned int, float); + CHECK_PROMOTION(float, bool, float); + CHECK_PROMOTION(float, __half, float); + CHECK_PROMOTION(float, __nv_bfloat16, float); + CHECK_PROMOTION(float, char, float); + CHECK_PROMOTION(float, signed char, float); + CHECK_PROMOTION(float, unsigned char, float); + + CHECK_PROMOTION(double, int, double); + CHECK_PROMOTION(double, float, double); + CHECK_PROMOTION(double, double, double); + CHECK_PROMOTION(double, unsigned int, double); + CHECK_PROMOTION(double, bool, double); + CHECK_PROMOTION(double, __half, double); + CHECK_PROMOTION(double, __nv_bfloat16, double); + CHECK_PROMOTION(double, char, double); + CHECK_PROMOTION(double, signed char, double); + CHECK_PROMOTION(double, unsigned char, double); + + // CHECK_PROMOTION(unsigned int, int, unsigned int); + CHECK_PROMOTION(unsigned int, float, float); + CHECK_PROMOTION(unsigned int, double, double); + CHECK_PROMOTION(unsigned int, unsigned int, unsigned int); + CHECK_PROMOTION(unsigned int, bool, unsigned int); + CHECK_PROMOTION(unsigned int, __half, __half); + CHECK_PROMOTION(unsigned int, __nv_bfloat16, __nv_bfloat16); + // CHECK_PROMOTION(unsigned int, char, unsigned int); + // CHECK_PROMOTION(unsigned int, signed char, unsigned int); + CHECK_PROMOTION(unsigned int, unsigned char, unsigned int); + + CHECK_PROMOTION(bool, int, int); + CHECK_PROMOTION(bool, float, float); + CHECK_PROMOTION(bool, double, double); + CHECK_PROMOTION(bool, unsigned int, unsigned int); + CHECK_PROMOTION(bool, bool, bool); + CHECK_PROMOTION(bool, __half, __half); + CHECK_PROMOTION(bool, __nv_bfloat16, __nv_bfloat16); + CHECK_PROMOTION(bool, char, char); + CHECK_PROMOTION(bool, signed char, signed char); + CHECK_PROMOTION(bool, unsigned char, unsigned char); + + CHECK_PROMOTION(__half, int, __half); + CHECK_PROMOTION(__half, float, float); + CHECK_PROMOTION(__half, double, double); + CHECK_PROMOTION(__half, unsigned int, __half); + CHECK_PROMOTION(__half, bool, __half); + CHECK_PROMOTION(__half, __half, __half); + CHECK_PROMOTION(__half, __nv_bfloat16, float); + CHECK_PROMOTION(__half, char, __half); + CHECK_PROMOTION(__half, signed char, __half); + CHECK_PROMOTION(__half, unsigned char, __half); + + CHECK_PROMOTION(__nv_bfloat16, int, __nv_bfloat16); + CHECK_PROMOTION(__nv_bfloat16, float, float); + CHECK_PROMOTION(__nv_bfloat16, double, double); + CHECK_PROMOTION(__nv_bfloat16, unsigned int, __nv_bfloat16); + CHECK_PROMOTION(__nv_bfloat16, bool, __nv_bfloat16); + CHECK_PROMOTION(__nv_bfloat16, __half, float); + CHECK_PROMOTION(__nv_bfloat16, __nv_bfloat16, __nv_bfloat16); + CHECK_PROMOTION(__nv_bfloat16, char, __nv_bfloat16); + CHECK_PROMOTION(__nv_bfloat16, signed char, __nv_bfloat16); + CHECK_PROMOTION(__nv_bfloat16, unsigned char, __nv_bfloat16); + + // CHECK_PROMOTION(char, int, char); + CHECK_PROMOTION(char, float, float); + CHECK_PROMOTION(char, double, double); + // CHECK_PROMOTION(char, unsigned int, char); + CHECK_PROMOTION(char, bool, char); + CHECK_PROMOTION(char, __half, __half); + CHECK_PROMOTION(char, __nv_bfloat16, __nv_bfloat16); + CHECK_PROMOTION(char, char, char); + // CHECK_PROMOTION(char, signed char, char); + // CHECK_PROMOTION(char, unsigned char, char); + + CHECK_PROMOTION(signed char, int, int); + CHECK_PROMOTION(signed char, float, float); + CHECK_PROMOTION(signed char, double, double); + // CHECK_PROMOTION(signed char, unsigned int, signed char); + CHECK_PROMOTION(signed char, bool, signed char); + CHECK_PROMOTION(signed char, __half, __half); + CHECK_PROMOTION(signed char, __nv_bfloat16, __nv_bfloat16); + // CHECK_PROMOTION(signed char, char, signed char); + CHECK_PROMOTION(signed char, signed char, signed char); + // CHECK_PROMOTION(signed char, unsigned char, signed char); + + // CHECK_PROMOTION(unsigned char, int, unsigned char); + CHECK_PROMOTION(unsigned char, float, float); + CHECK_PROMOTION(unsigned char, double, double); + CHECK_PROMOTION(unsigned char, unsigned int, unsigned int); + CHECK_PROMOTION(unsigned char, bool, unsigned char); + CHECK_PROMOTION(unsigned char, __half, __half); + CHECK_PROMOTION(unsigned char, __nv_bfloat16, __nv_bfloat16); + // CHECK_PROMOTION(unsigned char, char, unsigned char); + // CHECK_PROMOTION(unsigned char, signed char, unsigned char); + CHECK_PROMOTION(unsigned char, unsigned char, unsigned char); +} \ No newline at end of file diff --git a/tests/reduce.cu b/tests/reduce.cu index 4190f0e..73a6752 100644 --- a/tests/reduce.cu +++ b/tests/reduce.cu @@ -1,59 +1,170 @@ #include "common.h" -#include "kernel_float.h" -namespace kf = kernel_float; +struct reduction_tests { + template + __host__ __device__ void operator()(generator gen) { + // TODO: these tests do not consider special numbers: NaN, -Inf, +Inf, and -0.0 -__host__ __device__ bool is_close(double a, double b) { - return (isnan(a) && isnan(b)) || (isinf(a) && isinf(b)) || fabs(a - b) < 0.0001; -} + { + kf::vec a; + ASSERT_APPROX(kf::min(a), T(0)); + ASSERT_APPROX(kf::max(a), T(0)); + ASSERT_APPROX(kf::sum(a), T(0)); + ASSERT_APPROX(kf::product(a), T(0)); + ASSERT_EQ(kf::all(a), false); + ASSERT_EQ(kf::any(a), false); + ASSERT_EQ(kf::count(a), 0); -__host__ __device__ bool is_close(__half a, __half b) { - return is_close(double(a), double(b)); -} + a = {T(1)}; + ASSERT_APPROX(kf::min(a), T(1)); + ASSERT_APPROX(kf::max(a), T(1)); + ASSERT_APPROX(kf::sum(a), T(1)); + ASSERT_APPROX(kf::product(a), T(1)); + ASSERT_EQ(kf::all(a), true); + ASSERT_EQ(kf::any(a), true); + ASSERT_EQ(kf::count(a), 1); -__host__ __device__ bool is_close(long long a, long long b) { - return a == b; -} + a = {T(5)}; + ASSERT_APPROX(kf::min(a), T(5)); + ASSERT_APPROX(kf::max(a), T(5)); + ASSERT_APPROX(kf::sum(a), T(5)); + ASSERT_APPROX(kf::product(a), T(5)); + ASSERT_EQ(kf::all(a), true); + ASSERT_EQ(kf::any(a), true); + ASSERT_EQ(kf::count(a), 1); + } -__host__ __device__ bool is_close(int a, int b) { - return a == b; -} + { + kf::vec a = {T(0), T(0)}; + ASSERT_APPROX(kf::min(a), T(0)); + ASSERT_APPROX(kf::max(a), T(0)); + ASSERT_APPROX(kf::sum(a), T(0)); + ASSERT_APPROX(kf::product(a), T(0)); + ASSERT_EQ(kf::all(a), false); + ASSERT_EQ(kf::any(a), false); + ASSERT_EQ(kf::count(a), 0); -template> -struct reduction_test; + a = {T(5), T(0)}; + ASSERT_APPROX(kf::min(a), T(0)); + ASSERT_APPROX(kf::max(a), T(5)); + ASSERT_APPROX(kf::sum(a), T(5)); + ASSERT_APPROX(kf::product(a), T(0)); + ASSERT_EQ(kf::all(a), false); + ASSERT_EQ(kf::any(a), true); + ASSERT_EQ(kf::count(a), 1); -template -struct reduction_test> { - __host__ __device__ void operator()(generator gen) { - kf::vec v {gen.next(Is)...}; + a = {T(5), T(-3)}; + ASSERT_APPROX(kf::min(a), T(-3)); + ASSERT_APPROX(kf::max(a), T(5)); + ASSERT_APPROX(kf::sum(a), T(2)); + ASSERT_APPROX(kf::product(a), T(-15)); + ASSERT_EQ(kf::all(a), true); + ASSERT_EQ(kf::any(a), true); + ASSERT_EQ(kf::count(a), 2); + } + + { + kf::vec a; + ASSERT_APPROX(kf::min(a), T(0)); + ASSERT_APPROX(kf::max(a), T(0)); + ASSERT_APPROX(kf::sum(a), T(0)); + ASSERT_APPROX(kf::product(a), T(0)); + ASSERT_EQ(kf::all(a), false); + ASSERT_EQ(kf::any(a), false); + ASSERT_EQ(kf::count(a), 0); + + a = {T(5), T(0), T(-1)}; + ASSERT_APPROX(kf::min(a), T(-1)); + ASSERT_APPROX(kf::max(a), T(5)); + ASSERT_APPROX(kf::sum(a), T(4)); + ASSERT_APPROX(kf::product(a), T(0)); + ASSERT_EQ(kf::all(a), false); + ASSERT_EQ(kf::any(a), true); + ASSERT_EQ(kf::count(a), 2); + + a = {T(5), T(-3), T(1)}; + ASSERT_APPROX(kf::min(a), T(-3)); + ASSERT_APPROX(kf::max(a), T(5)); + ASSERT_APPROX(kf::sum(a), T(3)); + ASSERT_APPROX(kf::product(a), T(-15)); + ASSERT_EQ(kf::all(a), true); + ASSERT_EQ(kf::any(a), true); + ASSERT_EQ(kf::count(a), 3); + } + + { + kf::vec a; + ASSERT_APPROX(kf::min(a), T(0)); + ASSERT_APPROX(kf::max(a), T(0)); + ASSERT_APPROX(kf::sum(a), T(0)); + ASSERT_APPROX(kf::product(a), T(0)); + ASSERT_EQ(kf::all(a), false); + ASSERT_EQ(kf::any(a), false); + ASSERT_EQ(kf::count(a), 0); - bool b = (bool(v.get(Is)) && ...); - ASSERT(kf::all(v) == b); + a = {T(5), T(0), T(-1), T(0)}; + ASSERT_APPROX(kf::min(a), T(-1)); + ASSERT_APPROX(kf::max(a), T(5)); + ASSERT_APPROX(kf::sum(a), T(4)); + ASSERT_APPROX(kf::product(a), T(0)); + ASSERT_EQ(kf::all(a), false); + ASSERT_EQ(kf::any(a), true); + ASSERT_EQ(kf::count(a), 2); - b = (bool(v.get(Is)) || ...); - ASSERT(kf::any(v) == b); + a = {T(5), T(-3), T(1), T(-2)}; + ASSERT_APPROX(kf::min(a), T(-3)); + ASSERT_APPROX(kf::max(a), T(5)); + ASSERT_APPROX(kf::sum(a), T(1)); + ASSERT_APPROX(kf::product(a), T(30)); + ASSERT_EQ(kf::all(a), true); + ASSERT_EQ(kf::any(a), true); + ASSERT_EQ(kf::count(a), 4); + } + } +}; + +REGISTER_TEST_CASE("reductions", reduction_tests, int, float, double) +REGISTER_TEST_CASE_GPU("reductions", reduction_tests, __half, __nv_bfloat16) + +struct dot_mag_tests { + template + __host__ __device__ void operator()(generator gen) { + { + kf::vec a = {-1}; + kf::vec b = {2}; + ASSERT_APPROX(kf::dot(a, b), T(-2)); + ASSERT_APPROX(kf::mag(a), T(1)); + } + + { + kf::vec a = {3, -4}; + kf::vec b = {2, 1}; + ASSERT_APPROX(kf::dot(a, b), T(2)); + ASSERT_APPROX(kf::mag(a), T(5)); + } - T sum = v.get(0); - for (int i = 1; i < N; i++) { - sum = sum + v.get(i); + { + kf::vec a = {2, -3, 6}; + kf::vec b = {2, -1, 3}; + ASSERT_APPROX(kf::dot(a, b), T(25)); + ASSERT_APPROX(kf::mag(a), T(7)); } - ASSERT(is_close(kf::sum(v), sum)); - T minimum = v.get(0); - for (int i = 1; i < N; i++) { - minimum = kf::ops::min {}(minimum, v.get(i)); + { + kf::vec a = {2, -4, 5, 6}; + kf::vec b = {2, 1, -3, 1}; + ASSERT_APPROX(kf::dot(a, b), T(-9)); + ASSERT_APPROX(kf::mag(a), T(9)); } - ASSERT(is_close(kf::min(v), minimum)); - T maximum = v.get(0); - for (int i = 1; i < N; i++) { - maximum = kf::ops::max {}(maximum, v.get(i)); + { + kf::vec a = {1, -3, 4, 5, 7}; + kf::vec b = {2, 0, 1, -1, 2}; + ASSERT_APPROX(kf::dot(a, b), T(15)); + ASSERT_APPROX(kf::mag(a), T(10)); } - ASSERT(is_close(kf::max(v), maximum)); } }; -TEST_CASE("reduction operations") { - run_on_host_and_device(); - run_on_device(); -} +REGISTER_TEST_CASE("dot product/magnitude", dot_mag_tests, float, double) +REGISTER_TEST_CASE_GPU("dot product/magnitude", dot_mag_tests, __half, __nv_bfloat16) diff --git a/tests/swizzle.cu b/tests/swizzle.cu deleted file mode 100644 index e44394e..0000000 --- a/tests/swizzle.cu +++ /dev/null @@ -1,42 +0,0 @@ -#include "common.h" -#include "kernel_float.h" - -namespace kf = kernel_float; - -template> -struct swizzle_test; - -template -struct swizzle_test> { - __host__ __device__ void operator()(generator gen) { - T items[N] = {gen.next(Is)...}; - kf::vec a = {items[Is]...}; - - ASSERT(equals(items[0], kf::first(a))); - ASSERT(equals(items[N - 1], kf::last(a))); - - kf::vec b = kf::reversed(a); - ASSERT(equals(b[Is], items[N - Is - 1]) && ...); - - b = kf::rotate_left<1>(a); - ASSERT(equals(b[Is], items[(Is + 1) % N]) && ...); - - b = kf::rotate_right<1>(a); - ASSERT(equals(b[Is], items[(Is + N - 1) % N]) && ...); - - b = kf::rotate_left<2>(a); - ASSERT(equals(b[Is], items[(Is + 2) % N]) && ...); - - b = kf::rotate_right<2>(a); - ASSERT(equals(b[Is], items[(Is + N - 2) % N]) && ...); - - kf::vec c = kf::concat(a, T {}, a); - ASSERT(equals(c[Is], items[Is]) && ...); - ASSERT(equals(c[N], T {})); - ASSERT(equals(c[N + 1 + Is], items[Is]) && ...); - } -}; - -TEST_CASE("swizzle") { - run_on_host_and_device(); -} diff --git a/tests/triops.cu b/tests/triops.cu new file mode 100644 index 0000000..4b899b1 --- /dev/null +++ b/tests/triops.cu @@ -0,0 +1,29 @@ +#include "common.h" + +struct triops_tests { + template + __host__ __device__ void operator()(generator gen, std::index_sequence) { + T x[N] = {gen.next(I)...}; + T y[N] = {gen.next(I)...}; + T z[N] = {gen.next(I)...}; + + kf::vec a = {x[I]...}; + kf::vec b = {y[I]...}; + kf::vec c = {z[I]...}; + + kf::vec answer = kf::where(a, b, c); + ASSERT_EQ_ALL(answer[I], bool(x[I]) ? y[I] : z[I]); + + answer = kf::where(a, b); + ASSERT_EQ_ALL(answer[I], bool(x[I]) ? y[I] : T()); + + answer = kf::where(a); + ASSERT_EQ_ALL(answer[I], T(bool(x[I]))); + + answer = kf::fma(a, b, c); + ASSERT_EQ_ALL(answer[I], x[I] * y[I] + z[I]); + } +}; + +REGISTER_TEST_CASE("ternary operators", triops_tests, int, float, double) +REGISTER_TEST_CASE_GPU("ternary operators", triops_tests, __half, __nv_bfloat16) diff --git a/tests/unops.cu b/tests/unops.cu index 2075c19..e50d8d3 100644 --- a/tests/unops.cu +++ b/tests/unops.cu @@ -1,90 +1,73 @@ #include "common.h" -#include "kernel_float.h" -namespace kf = kernel_float; - -template> -struct int_test; - -template -struct int_test> { - __host__ __device__ void operator()(generator gen) { - kf::vec a {gen.next(Is)...}; +struct unops_tests { + template + __host__ __device__ void operator()(generator gen, std::index_sequence) { + T items[N] = {gen.next(I)...}; + kf::vec a = {items[I]...}; kf::vec b; b = -a; - ASSERT((b.get(Is) == -(a.get(Is))) && ...); + ASSERT(equals(b[I], T(-items[I])) && ...); b = ~a; - ASSERT((b.get(Is) == ~(a.get(Is))) && ...); + ASSERT(equals(b[I], T(~items[I])) && ...); b = !a; - ASSERT((b.get(Is) == !(a.get(Is))) && ...); + ASSERT(equals(b[I], T(!items[I])) && ...); } }; -template> -struct float_test; +REGISTER_TEST_CASE("unary operators", unops_tests, bool, int) -template -struct float_test> { - __host__ __device__ void operator()(generator gen) { - kf::vec a {gen.next(Is)...}; +struct unops_float_tests { + template + __host__ __device__ void operator()(generator gen, std::index_sequence) { + double items[N] = {gen.next(I)...}; + kf::vec a = {T(items[I])...}; kf::vec b; b = -a; - ASSERT(equals(-a.get(Is), b.get(Is)) && ...); + ASSERT(equals(b[I], T(-items[I])) && ...); - // just some examples - b = kf::cos(a); - ASSERT(equals(cos(a.get(Is)), b.get(Is)) && ...); + b = !a; + ASSERT(equals(b[I], T(!items[I])) && ...); - b = kf::floor(a); - ASSERT(equals(floor(a.get(Is)), b.get(Is)) && ...); + // Ideally, we would test all unary operators, but that would be a lot of work and not that useful since + // all operators are generators by the same macro. Instead, we only check a few of them + if constexpr (is_one_of) { + b = sqrt(a); + ASSERT(equals(b[I], hsqrt(T(items[I]))) && ...); - b = kf::abs(a); - ASSERT(equals(abs(a.get(Is)), b.get(Is)) && ...); + b = sin(a); + ASSERT(equals(b[I], hsin(T(items[I]))) && ...); - b = kf::sqrt(a); - ASSERT(equals(sqrt(a.get(Is)), b.get(Is)) && ...); - } -}; + b = cos(a); + ASSERT(equals(b[I], hcos(T(items[I]))) && ...); -template -struct float_test<__half, N, std::index_sequence> { - template - __host__ __device__ void operator()(generator gen) { - kf::vec a {gen.next(Is)...}; - kf::vec b; + b = log(a); + ASSERT(equals(b[I], hlog(T(items[I]))) && ...); - b = -a; - ASSERT(equals(__hneg(a.get(Is)), b.get(Is)) && ...); + b = exp(a); + ASSERT(equals(b[I], hexp(T(items[I]))) && ...); + } else { + b = sqrt(a); + ASSERT(equals(b[I], sqrt(T(items[I]))) && ...); - // just some examples - b = kf::cos(a); - ASSERT(equals(hcos(a.get(Is)), b.get(Is)) && ...); + b = sin(a); + ASSERT(equals(b[I], sin(T(items[I]))) && ...); - b = kf::floor(a); - ASSERT(equals(hfloor(a.get(Is)), b.get(Is)) && ...); + b = cos(a); + ASSERT(equals(b[I], cos(T(items[I]))) && ...); - b = kf::abs(a); - ASSERT(equals(__habs(a.get(Is)), b.get(Is)) && ...); + b = log(a); + ASSERT(equals(b[I], log(T(items[I]))) && ...); - b = kf::sqrt(a); - ASSERT(equals(hsqrt(a.get(Is)), b.get(Is)) && ...); + b = exp(a); + ASSERT(equals(b[I], exp(T(items[I]))) && ...); + } } }; -template -struct float_test<__nv_bfloat16, N, std::index_sequence> { - __host__ __device__ void operator()(generator<__nv_bfloat16> gen) { - float_test<__half, N> {}(gen); - } -}; - -TEST_CASE("unary operators") { - run_on_host_and_device(); - - run_on_host_and_device(); - run_on_device(); -} +REGISTER_TEST_CASE("unary float operators", unops_float_tests, float, double) +REGISTER_TEST_CASE_GPU("unary float operators", unops_float_tests, __half, __nv_bfloat16) \ No newline at end of file