@@ -6,7 +6,8 @@
 using ADataType        = F16;
 using BDataType        = F16;
 using AccDataType      = F32;
-using CShuffleDataType = F32;
+using CShuffleDataType = F16;
+using CDataType        = F16; // the C matrix doesn't exist; this type is only used for verification
 using D0DataType       = F16;
 using D1DataType       = F16;
 using DsDataType       = ck::Tuple<D0DataType, D1DataType>;
@@ -36,7 +37,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
 
 using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                         BDataType,
-                                                                        AccDataType,
+                                                                        CDataType,
                                                                         AccDataType,
                                                                         AElementOp,
                                                                         BElementOp,
@@ -108,7 +108,7 @@ bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionC
         throw std::runtime_error("wrong! this device_op instance does not support this problem");
     }
 
-    float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, true});
 
     std::size_t flop      = 2_uz * M * N * K;
     std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
@@ -124,7 +124,7 @@ bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionC
 
     if(config.do_verification)
     {
-        Tensor<AccDataType> c_m_n(HostTensorDescriptor{M, N});
+        Tensor<CDataType> c_m_n(HostTensorDescriptor{M, N});
 
         auto ref_gemm    = ReferenceGemmInstance{};
         auto ref_invoker = ref_gemm.MakeInvoker();
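For context, the flop and num_btype counters above feed the perf report printed after the timed run. A standalone sketch (standard C++, not part of the PR; the M/N/K values, the 1.25 ms timing, and the assumption that the truncated num_btype tail adds the two D tensors and E are all illustrative) of how CK examples typically convert them, given that ave_time is in milliseconds:

#include <cstddef>
#include <cstdio>

int main()
{
    const std::size_t M = 3840, N = 4096, K = 4096;

    // same counters as in the hunk above; sizeof(F16) == 2,
    // and the (assumed) tail adds D0, D1 and E, each M x N
    const std::size_t flop      = std::size_t{2} * M * N * K;
    const std::size_t num_btype = 2 * M * K + 2 * K * N + 2 * M * N * 3;

    const float ave_time = 1.25f; // milliseconds, e.g. as returned by invoker.Run

    const float tflops     = static_cast<float>(flop) / 1.e9f / ave_time;
    const float gb_per_sec = static_cast<float>(num_btype) / 1.e6f / ave_time;

    std::printf("%.3f TFLOPS, %.2f GB/s\n", tflops, gb_per_sec);
}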
@@ -4,6 +4,7 @@
 #pragma once
 
 #include "ck/utility/data_type.hpp"
+#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
 
 namespace ck {
 namespace tensor_operation {
@@ -225,43 +226,28 @@ struct AddHardswish
 };
 };
 
-// C = A * B
 // E = FastGelu(C + D)
 struct AddFastGelu
 {
-    // Fast GeLU
-    // https://paperswithcode.com/method/gelu
-    // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
-    __host__ __device__ static constexpr float GetFastGeLU(float x)
-    {
-        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
-        const float emu = exp(-u);
-        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
-        return x * cdf;
-    }
-
-    template <typename T>
-    static inline constexpr bool is_valid_param_type_v =
-        std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
-        std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>;
-
     template <typename E, typename C, typename D>
-    __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const
-    {
-        static_assert(is_valid_param_type_v<E> && is_valid_param_type_v<C> &&
-                      is_valid_param_type_v<D>);
+    __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
 
-        const float y = GetFastGeLU(type_convert<float>(c) + type_convert<float>(d));
+    template <>
+    __host__ __device__ constexpr void
+    operator()<float, float, float>(float& e, const float& c, const float& d) const
+    {
+        const float x = c + d;
 
-        e = type_convert<E>(y);
+        FastGelu{}.template operator()<float, float>(e, x);
     }
 
-    template <typename D>
-    __host__ __device__ constexpr void operator()(float& e, const float& c, const D& d) const
+    template <>
+    __host__ __device__ constexpr void
+    operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
    {
-        static_assert(is_valid_param_type_v<D>);
+        const half_t x = c + d;
 
-        e = GetFastGeLU(c + type_convert<float>(d));
+        ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
     }
 };
 
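The rewrite above adopts the convention spelled out in the next file's header comment: declare the primary operator() template but define only explicit specializations, so an unsupported type combination fails at compile or link time instead of being accepted through silent implicit conversions. A minimal free-function analogue (standard C++; the names are illustrative, and CK applies the idea to a member operator(), which HIP's Clang accepts in-class):

#include <cstdio>

// Primary template: declared but intentionally never defined.
template <typename E, typename C, typename D>
void add_fast_gelu_like(E& e, const C& c, const D& d);

// Only the supported combination gets a definition.
template <>
void add_fast_gelu_like<float, float, float>(float& e, const float& c, const float& d)
{
    e = c + d; // stand-in for FastGelu(c + d)
}

int main()
{
    float e = 0.f;
    add_fast_gelu_like(e, 1.f, 2.f); // OK: deduced as <float, float, float>
    // double e2 = 0.0;
    // add_fast_gelu_like(e2, 1.0, 2.0); // compiles, but fails at link: no definition
    std::printf("%f\n", e);
}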
@@ -15,7 +15,7 @@ namespace element_wise {
 // Need to ensure compiler will fail if there is no matching candidate, instead of the compiler
 // silently doing implicit type conversion
 //
-// Method 1:
+// Example:
 //
 // struct ExampleElementwiseOp
 // {
@@ -29,19 +29,6 @@ namespace element_wise {
 //     {
 //     }
 // };
-//
-// Method 2:
-//
-// template <typename Y, typename X>
-// struct ExampleElementwiseOp;
-//
-// template <>
-// struct ExampleElementwiseOp<float, ck::bhalf_t>
-// {
-//     __host__ __device__ void operator()(float& y, ck::bhalf_t& x) const
-//     {
-//     }
-// };
 
 struct AddReluAdd
 {
@@ -142,7 +129,6 @@ struct AddHardswishAdd
     }
 };
 
-// C = A * B
 // E = C + D0 + D1
 struct AddAdd
 {
@@ -171,41 +157,33 @@ struct AddAdd
     }
 };
 
-// C = A * B
 // E = FastGelu(C + D0 + D1)
 struct AddAddFastGelu
 {
-    // Fast GeLU
-    // https://paperswithcode.com/method/gelu
-    // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
-    __host__ __device__ static constexpr float GetFastGeLU(float x)
-    {
-        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
-        const float emu = exp(-u);
-        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
-        return x * cdf;
-    }
-
-    template <typename T>
-    static inline constexpr bool is_valid_param_type_v =
-        std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
-        std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>
-#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
-        || std::is_same_v<T, ck::int4_t>
-#endif
-        ;
-
     template <typename E, typename C, typename D0, typename D1>
     __host__ __device__ constexpr void
-    operator()(E& e, const C& c, const D0& d0, const D1& d1) const
-    {
-        static_assert(is_valid_param_type_v<E> && is_valid_param_type_v<C> &&
-                      is_valid_param_type_v<D0> && is_valid_param_type_v<D1>);
+    operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
 
-        const float y =
-            GetFastGeLU(type_convert<float>(c) + type_convert<float>(d0) + type_convert<float>(d1));
+    template <>
+    __host__ __device__ constexpr void operator()<float, float, float>(float& e,
+                                                                       const float& c,
+                                                                       const float& d0,
+                                                                       const float& d1) const
+    {
+        const float x = c + d0 + d1;
 
-        e = type_convert<E>(y);
+        FastGelu{}.template operator()<float, float>(e, x);
+    }
+
+    template <>
+    __host__ __device__ constexpr void operator()<half_t, half_t, half_t>(half_t& e,
+                                                                          const half_t& c,
+                                                                          const half_t& d0,
+                                                                          const half_t& d1) const
+    {
+        const half_t x = c + d0 + d1;
+
+        ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
     }
 };
 
@@ -10,6 +10,8 @@ namespace ck {
 namespace tensor_operation {
 namespace element_wise {
 
+extern "C" __device__ float __ocml_native_recip_f32(float);
+
 struct PassThrough
 {
     template <typename Y, typename X>
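__ocml_native_recip_f32 is a fast, reduced-accuracy reciprocal from ROCm's OCML device library; no header on this include path declares it, hence the extern "C" forward declaration, with the definition resolved from the device library at link time. A minimal HIP sketch of the same declare-and-call pattern (the kernel, launch shape, and test values are illustrative, not part of the PR):

#include <hip/hip_runtime.h>
#include <cstdio>

// Forward-declare the OCML builtin, exactly as the PR does.
extern "C" __device__ float __ocml_native_recip_f32(float);

__global__ void recip_demo(const float* x, float* y, int n)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
    {
        y[i] = __ocml_native_recip_f32(x[i]); // fast approximate 1/x
    }
}

int main()
{
    float h_x[4] = {1.f, 2.f, 4.f, 8.f}, h_y[4];
    float *d_x, *d_y;
    hipMalloc(&d_x, sizeof(h_x));
    hipMalloc(&d_y, sizeof(h_y));
    hipMemcpy(d_x, h_x, sizeof(h_x), hipMemcpyHostToDevice);
    recip_demo<<<1, 4>>>(d_x, d_y, 4);
    hipMemcpy(h_y, d_y, sizeof(h_y), hipMemcpyDeviceToHost);
    for(float v : h_y)
        std::printf("%f\n", v);
    hipFree(d_x);
    hipFree(d_y);
}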
@@ -195,20 +197,65 @@ struct Relu
 
 // https://paperswithcode.com/method/gelu
 // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
+// host code uses the higher-accuracy "exp" and "div"
+// device code uses the lower-accuracy "__expf" and "rcp" functions
 struct FastGelu
 {
     template <typename Y, typename X>
-    __host__ __device__ void operator()(Y& y, const X& x) const;
+    __host__ void operator()(Y& y, const X& x) const;
+
+    template <typename Y, typename X>
+    __device__ void operator()(Y& y, const X& x) const;
 
     template <>
-    __host__ __device__ void operator()<float, float>(float& y, const float& x) const
+    __host__ void operator()<float, float>(float& y, const float& x) const
     {
-        const float u   = float(2) * x * (float(0.035677) * x * x + float(0.797885));
+        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
         const float emu = exp(-u);
-        const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1));
+        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
 
         y = x * cdf;
     }
 
+    template <>
+    __host__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
+    {
+        float y_f;
+
+        this->operator()<float, float>(y_f, type_convert<float>(x));
+
+        y = type_convert<half_t>(y_f);
+    }
+
+    // device code, use lower precision "__expf" and "rcp"
+    template <>
+    __device__ void operator()<float, float>(float& y, const float& x) const
+    {
+#if 0
+        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
+        const float emu = exp(-u);
+        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
+
+        y = x * cdf;
+#else
+        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
+        const float emu = __expf(-u);
+        const float cdf = 0.5f + 0.5f * (2.f * __ocml_native_recip_f32(1.f + emu) - 1.f);
+
+        y = x * cdf;
+#endif
+    }
+
+    // device code, use lower precision "__expf" and "rcp"
+    template <>
+    __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
+    {
+        float y_f;
+
+        this->operator()<float, float>(y_f, type_convert<float>(x));
+
+        y = type_convert<half_t>(y_f);
+    }
 };
 
 // https://paperswithcode.com/method/gelu
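For reference, the constants in FastGelu are the sigmoid rewrite of the tanh formula quoted in the comment: tanh(t) = 2/(1+exp(-2t)) - 1, so with u = 2t = 2*sqrt(2/pi)*(x + 0.044715*x^3) = 2*x*(0.797885 + 0.035677*x^2) the two forms coincide. A standalone host-side check (standard C++, not part of the PR) that they agree:

#include <cmath>
#include <cstdio>

// tanh form, as quoted in the comment
float gelu_tanh(float x)
{
    const float k = 0.7978845608f; // sqrt(2/pi)
    return 0.5f * x * (1.f + std::tanh(k * (x + 0.044715f * x * x * x)));
}

// sigmoid form, as implemented by FastGelu
float gelu_sigmoid(float x)
{
    const float u   = 2.f * x * (0.035677f * x * x + 0.797885f); // u = 2t
    const float emu = std::exp(-u);
    const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
    return x * cdf;
}

int main()
{
    for(float x : {-3.f, -1.f, 0.f, 0.5f, 2.f})
        std::printf("x=% .2f  tanh=% .6f  sigmoid=% .6f\n", x, gelu_tanh(x), gelu_sigmoid(x));
}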