diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp index 6c7ca414448..6d0abd88038 100644 --- a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp @@ -6,7 +6,8 @@ using ADataType = F16; using BDataType = F16; using AccDataType = F32; -using CShuffleDataType = F32; +using CShuffleDataType = F16; +using CDataType = F16; // C matrix doesn't exist, this is used for verification using D0DataType = F16; using D1DataType = F16; using DsDataType = ck::Tuple; @@ -36,7 +37,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm c_m_n(HostTensorDescriptor{M, N}); + Tensor c_m_n(HostTensorDescriptor{M, N}); auto ref_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index 9ae3e18ed1a..18f081ba9a2 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" namespace ck { namespace tensor_operation { @@ -225,43 +226,28 @@ struct AddHardswish }; }; -// C = A * B // E = FastGelu(C + D) struct AddFastGelu { - // Fast GeLU - // https://paperswithcode.com/method/gelu - // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))) - __host__ __device__ static constexpr float GetFastGeLU(float x) - { - const float u = 2.f * x * (0.035677f * x * x + 0.797885f); - const float emu = exp(-u); - const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f); - return x * cdf; - } - - template 
- static inline constexpr bool is_valid_param_type_v = - std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v; - template - __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const - { - static_assert(is_valid_param_type_v && is_valid_param_type_v && - is_valid_param_type_v); + __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; - const float y = GetFastGeLU(type_convert(c) + type_convert(d)); + template <> + __host__ __device__ constexpr void + operator()(float& e, const float& c, const float& d) const + { + const float x = c + d; - e = type_convert(y); + FastGelu{}.template operator()(e, x); } - template - __host__ __device__ constexpr void operator()(float& e, const float& c, const D& d) const + template <> + __host__ __device__ constexpr void + operator()(half_t& e, const half_t& c, const half_t& d) const { - static_assert(is_valid_param_type_v); + const half_t x = c + d; - e = GetFastGeLU(c + type_convert(d)); + ck::tensor_operation::element_wise::FastGelu{}.template operator()(e, x); } }; diff --git a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp index 47d018095d2..9f4d43329fd 100644 --- a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -15,7 +15,7 @@ namespace element_wise { // Need to ensure compiler will fail if there is no matching candidate, instead of compiler // siliently do implicit type conversion // -// Method 1: +// Example: // // struct ExampleElementwiseOp // { @@ -29,19 +29,6 @@ namespace element_wise { // { // } // }; -// -// Method 2: -// -// template -// struct ExampleElementwiseOp; -// -// template <> -// struct ExampleElementwiseOp -// { -// __host__ __device__ void operator()(float& y, ck::bhalf_t& x) const -// { -// } -// }; struct AddReluAdd { @@ -142,7 +129,6 @@ 
struct AddHardswishAdd } }; -// C = A * B // E = C + D0 + D1 struct AddAdd { @@ -171,41 +157,33 @@ struct AddAdd } }; -// C = A * B // E = FastGelu(C + D0 + D1) struct AddAddFastGelu { - // Fast GeLU - // https://paperswithcode.com/method/gelu - // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))) - __host__ __device__ static constexpr float GetFastGeLU(float x) - { - const float u = 2.f * x * (0.035677f * x * x + 0.797885f); - const float emu = exp(-u); - const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f); - return x * cdf; - } - - template - static inline constexpr bool is_valid_param_type_v = - std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v -#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 - || std::is_same_v -#endif - ; - template __host__ __device__ constexpr void - operator()(E& e, const C& c, const D0& d0, const D1& d1) const + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()(float& e, + const float& c, + const float& d0, + const float& d1) const { - static_assert(is_valid_param_type_v && is_valid_param_type_v && - is_valid_param_type_v && is_valid_param_type_v); + const float x = c + d0 + d1; - const float y = - GetFastGeLU(type_convert(c) + type_convert(d0) + type_convert(d1)); + FastGelu{}.template operator()(e, x); + } + + template <> + __host__ __device__ constexpr void operator()(half_t& e, + const half_t& c, + const half_t& d0, + const half_t& d1) const + { + const half_t x = c + d0 + d1; - e = type_convert(y); + ck::tensor_operation::element_wise::FastGelu{}.template operator()(e, x); } }; diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 699b05fe3c4..eb4e1eeda28 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ 
b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -10,6 +10,8 @@ namespace ck { namespace tensor_operation { namespace element_wise { +extern "C" __device__ float __ocml_native_recip_f32(float); + struct PassThrough { template @@ -195,20 +197,65 @@ struct Relu // https://paperswithcode.com/method/gelu // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))) +// host code use higher accuracy "exp" and "div" +// device code use lower accuracy "__expf" and "rcp" function struct FastGelu { template - __host__ __device__ void operator()(Y& y, const X& x) const; + __host__ void operator()(Y& y, const X& x) const; + + template + __device__ void operator()(Y& y, const X& x) const; template <> - __host__ __device__ void operator()(float& y, const float& x) const + __host__ void operator()(float& y, const float& x) const { - const float u = float(2) * x * (float(0.035677) * x * x + float(0.797885)); + const float u = 2.f * x * (0.035677f * x * x + 0.797885f); const float emu = exp(-u); - const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1)); + const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f); y = x * cdf; } + + template <> + __host__ void operator()(half_t& y, const half_t& x) const + { + float y_f; + + this->operator()(y_f, type_convert(x)); + + y = type_convert(y_f); + } + + // device code, use lower precision "__expf" and "rcp" + template <> + __device__ void operator()(float& y, const float& x) const + { +#if 0 + const float u = 2.f * x * (0.035677f * x * x + 0.797885f); + const float emu = exp(-u); + const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f); + + y = x * cdf; +#else + const float u = 2.f * x * (0.035677f * x * x + 0.797885f); + const float emu = __expf(-u); + const float cdf = 0.5f + 0.5f * (2.f * __ocml_native_recip_f32(1.f + emu) - 1.f); + + y = x * cdf; +#endif + } + + // device code, use lower precision "__expf" and "rcp" + template <> + __device__ void operator()(half_t& y, const half_t& 
x) const + { + float y_f; + + this->operator()(y_f, type_convert(x)); + + y = type_convert(y_f); + } }; // https://paperswithcode.com/method/gelu