Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion example/04_gemm_add_add_fastgelu/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ struct ExecutionConfig final
};

inline bool
parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig config)
parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config)
{
if(argc == 1)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ using ADataType = BF16;
using BDataType = BF16;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = BF16;
using D1DataType = BF16;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = BF16;
using CDataType = F32; // C matrix doesn't exist in GPU memory, this is used for host verification
using D0DataType = BF16;
using D1DataType = BF16;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = BF16;

using ALayout = Row;
using BLayout = Col;
Expand All @@ -36,7 +37,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C

using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
AccDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = F16;
using D1DataType = F16;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = F16;
using CDataType = F32; // C matrix doesn't exist in GPU memory, this is used for host verification
using D0DataType = F16;
using D1DataType = F16;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = F16;

using ALayout = Row;
using BLayout = Col;
Expand All @@ -36,7 +37,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C

using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
AccDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include "common.hpp"
Expand All @@ -7,10 +6,11 @@ using ADataType = F32;
using BDataType = F32;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = F32;
using D1DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = F32;
using CDataType = F32; // C matrix doesn't exist in GPU memory, this is used for host verification
using D0DataType = F32;
using D1DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = F32;

using ALayout = Row;
using BLayout = Col;
Expand All @@ -36,7 +36,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C

using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
AccDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ using ADataType = I4;
using BDataType = I4;
using AccDataType = I32;
using CShuffleDataType = I32;
using D0DataType = I4;
using D1DataType = I4;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = I4;
using CDataType = I32; // C matrix doesn't exist in GPU memory, this is used for host verification
using D0DataType = I4;
using D1DataType = I4;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = I4;

using KernelADataType = I8;
using KernelBDataType = I8;
Expand Down Expand Up @@ -47,7 +48,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C

using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
AccDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ using ADataType = I8;
using BDataType = I8;
using AccDataType = I32;
using CShuffleDataType = I32;
using D0DataType = I8;
using D1DataType = I8;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = I8;
using CDataType = I32; // C matrix doesn't exist in GPU memory, this is used for host verification
using D0DataType = I8;
using D1DataType = I8;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = I8;

using ALayout = Row;
using BLayout = Col;
Expand All @@ -36,7 +37,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C

using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
AccDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionC

if(config.do_verification)
{
Tensor<AccDataType> c_m_n({M, N});
Tensor<CDataType> c_m_n({M, N});

auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
Expand Down
3 changes: 3 additions & 0 deletions include/ck/ck.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@
// tuning parameter
#define CK_WORKAROUND_SWDEV_325164 0

// workaround: compiler not emitting reciprocal instruction from __frcp_rn()
#define CK_WORKAROUND_SWDEV_383542 1

// flag to enable (1) or disable (0) the debugging output in some kernels
#define DEBUG_LOG 0

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"

namespace ck {
namespace tensor_operation {
Expand Down Expand Up @@ -280,43 +281,42 @@ struct AddHardswish
};
};

// C = A * B
// E = FastGelu(C + D)
struct AddFastGelu
{
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
__host__ __device__ static constexpr float GetFastGeLU(float x)
{
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = exp(-u);
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
return x * cdf;
}

template <typename T>
static inline constexpr bool is_valid_param_type_v =
std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>;

template <typename E, typename C, typename D>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;

template <>
__host__ __device__ constexpr void
operator()<float, float, float>(float& e, const float& c, const float& d) const
{
static_assert(is_valid_param_type_v<E> && is_valid_param_type_v<C> &&
is_valid_param_type_v<D>);
const float x = c + d;

FastGelu{}.template operator()<float, float>(e, x);
}

const float y = GetFastGeLU(type_convert<float>(c) + type_convert<float>(d));
template <>
__host__ __device__ constexpr void
operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
{
const half_t x = c + d;

e = type_convert<E>(y);
ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
}

template <typename D>
__host__ __device__ constexpr void operator()(float& e, const float& c, const D& d) const
template <>
__host__ __device__ constexpr void
operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
{
static_assert(is_valid_param_type_v<D>);
const float x0_f = c + d;

float x1_f = 0;

ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
x0_f);

e = GetFastGeLU(c + type_convert<float>(d));
e = type_convert<half_t>(x1_f);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace element_wise {
// Need to ensure compiler will fail if there is no matching candidate, instead of compiler
// silently do implicit type conversion
//
// Method 1:
// Example:
//
// struct ExampleElementwiseOp
// {
Expand All @@ -30,19 +30,6 @@ namespace element_wise {
// {
// }
// };
//
// Method 2:
//
// template <typename Y, typename X>
// struct ExampleElementwiseOp;
//
// template <>
// struct ExampleElementwiseOp<float, ck::bhalf_t>
// {
// __host__ __device__ void operator()(float& y, ck::bhalf_t& x) const
// {
// }
// };

struct AddReluAdd
{
Expand Down Expand Up @@ -208,41 +195,74 @@ struct AddMultiply
}
};

// C = A * B
// E = FastGelu(C + D0 + D1)
struct AddAddFastGelu
{
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
__host__ __device__ static constexpr float GetFastGeLU(float x)
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;

template <>
__host__ __device__ constexpr void operator()<float, float, float, float>(float& e,
const float& c,
const float& d0,
const float& d1) const
{
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = exp(-u);
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
return x * cdf;
const float x = c + d0 + d1;

FastGelu{}.template operator()<float, float>(e, x);
}

template <typename T>
static inline constexpr bool is_valid_param_type_v =
std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| std::is_same_v<T, ck::int4_t>
#endif
;
template <>
__host__ __device__ constexpr void operator()<half_t, half_t, half_t, half_t>(
half_t& e, const half_t& c, const half_t& d0, const half_t& d1) const
{
const half_t x = c + d0 + d1;

template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const
ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
}

template <>
__host__ __device__ constexpr void operator()<half_t, float, half_t, half_t>(
half_t& e, const float& c, const half_t& d0, const half_t& d1) const
{
static_assert(is_valid_param_type_v<E> && is_valid_param_type_v<C> &&
is_valid_param_type_v<D0> && is_valid_param_type_v<D1>);
const float x0_f = c + d0 + d1;

float x1_f = 0;

ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
x0_f);

e = type_convert<half_t>(x1_f);
}

template <>
__host__ __device__ constexpr void operator()<bhalf_t, float, bhalf_t, bhalf_t>(
bhalf_t& e, const float& c, const bhalf_t& d0, const bhalf_t& d1) const
{
const float x0_f = c + type_convert<float>(d0) + type_convert<float>(d1);

float x1_f = 0;

ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
x0_f);

e = type_convert<bhalf_t>(x1_f);
}

template <>
__host__ __device__ constexpr void operator()<int8_t, int32_t, int8_t, int8_t>(
int8_t& e, const int32_t& c, const int8_t& d0, const int8_t& d1) const
{
const float x0_f =
type_convert<float>(c) + type_convert<float>(d0) + type_convert<float>(d1);

float x1_f = 0;

const float y =
GetFastGeLU(type_convert<float>(c) + type_convert<float>(d0) + type_convert<float>(d1));
ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
x0_f);

e = type_convert<E>(y);
e = type_convert<int8_t>(x1_f);
}
};

Expand Down
Loading