diff --git a/example/19_gemm_softmax/CMakeLists.txt b/example/19_gemm_softmax/CMakeLists.txt
new file mode 100644
index 00000000000..740f931e5ab
--- /dev/null
+++ b/example/19_gemm_softmax/CMakeLists.txt
@@ -0,0 +1 @@
+add_example_executable(example_gemm_softmax_xdl_fp16 gemm_softmax_xdl_fp16.cpp)
diff --git a/example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp b/example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
new file mode 100644
index 00000000000..13518c1a6e4
--- /dev/null
+++ b/example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
@@ -0,0 +1,547 @@
+#include <iostream>
+#include <numeric>
+#include <initializer_list>
+#include <cstdlib>
+#include <vector>
+#include <string>
+#include <memory>
+#include <half.hpp>
+#include "check_err.hpp"
+#include "config.hpp"
+#include "device.hpp"
+#include "host_tensor.hpp"
+#include "host_tensor_generator.hpp"
+#include "host_reduce_util.hpp"
+#include "host_reduction.hpp"
+
+#include "device_tensor.hpp"
+#include "device_gemm_xdl.hpp"
+#include "device_gemm_xdl_cshuffle.hpp"
+#include "element_wise_operation.hpp"
+#include "reference_gemm.hpp"
+#include "gemm_specialization.hpp"
+
+#include "device_reduce_blockwise.hpp"
+#include "reduction_enums.hpp"
+#include "reduction_operator_mapping.hpp"
+#include "device_binary_elementwise_2d.hpp"
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+using ADataType              = F16;
+using BDataType              = F16;
+using CDataType              = F16;
+using AccDataType            = F32;
+using EltwiseComputeDataType = F32;
+
+using ALayout = ck::tensor_layout::gemm::RowMajor;
+using BLayout = ck::tensor_layout::gemm::ColumnMajor;
+using CLayout = ck::tensor_layout::gemm::RowMajor;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+
+// clang-format off
+using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle<
+    ALayout,        // ALayout
+    BLayout,        // BLayout
+    CLayout,        // CLayout
+    ADataType,      // ADataType
+    BDataType,      // BDataType
+    CDataType,      // CDataType
+    AccDataType,    // AccDataType
+    CDataType,      // CShuffleDataType
+    PassThrough,    // AElementwiseOperation
+    PassThrough,    // BElementwiseOperation
+    PassThrough,    // CElementwiseOperation
+    GemmDefault,    // GemmSpec
+    1,              // NumGemmKPrefetchStage
+    256,            // BlockSize
+    256,            // MPerBlock
+    128,            // NPerBlock
+    32,             // KPerBlock
+    8,              // AK1
+    8,              // BK1
+    32,             // MPerXDL
+    32,             // NPerXDL
+    4,              // MXdlPerWave
+    2,              // NXdlPerWave
+    S<4, 64, 1>,    // ABlockTransferThreadClusterLengths_K0_M_K1
+    S<1, 0, 2>,     // ABlockTransferThreadClusterArrangeOrder
+    S<1, 0, 2>,     // ABlockTransferSrcAccessOrder
+    2,              // ABlockTransferSrcVectorDim
+    8,              // ABlockTransferSrcScalarPerVector
+    8,              // ABlockTransferDstScalarPerVector_K1
+    true,           // ABlockLdsAddExtraM
+    S<4, 64, 1>,    // BBlockTransferThreadClusterLengths_K0_N_K1
+    S<1, 0, 2>,     // BBlockTransferThreadClusterArrangeOrder
+    S<1, 0, 2>,     // BBlockTransferSrcAccessOrder
+    2,              // BBlockTransferSrcVectorDim
+    8,              // BBlockTransferSrcScalarPerVector
+    8,              // BBlockTransferDstScalarPerVector_K1
+    true,           // BBlockLdsAddExtraN
+    1,              // CShuffleMXdlPerWavePerShuffle
+    1,              // CShuffleNXdlPerWavePerShuffle
+    S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
+    8>;             // CBlockTransferScalarPerVector_NWaveNPerXdl
+// clang-format on
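+
+// Softmax is taken along the M dimension of C = A * B (reduceDims{0} below) and is split into
+// four separate device kernels for clarity rather than fusion; roughly:
+//
+//   c_n_max[n]        = max_m C[m][n]                   (DeviceReduceMaxInstance)
+//   exp_m_n[m][n]     = exp(C[m][n] - c_n_max[n])       (SubExp via DeviceBinaryElementwise_2D)
+//   exp_n_sum[n]      = sum_m exp_m_n[m][n]             (DeviceReduceSumInstance)
+//   softmax_m_n[m][n] = exp_m_n[m][n] / exp_n_sum[n]    (Div via DeviceBinaryElementwise_2D)
+//
+// Subtracting the per-column maximum first is the usual trick to keep exp() from overflowing.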
+
+constexpr int Rank         = 2;
+constexpr int NumReduceDim = 1;
+constexpr ck::ReduceTensorOp ReduceMaxId = ck::ReduceTensorOp::MAX;
+constexpr ck::ReduceTensorOp ReduceSumId = ck::ReduceTensorOp::ADD;
+constexpr bool ReducePropagateNan = false;
+using ReduceMaxOp = typename ck::reduce_binary_operator<AccDataType, ReduceMaxId>::opType;
+using ReduceSumOp = typename ck::reduce_binary_operator<AccDataType, ReduceSumId>::opType;
+using ReduceMaxInElementwiseOperation =
+    typename ck::reduce_unary_operator<AccDataType, ReduceMaxId, true, true>::InElementwiseOperation;
+using ReduceMaxAccElementwiseOperation =
+    typename ck::reduce_unary_operator<AccDataType, ReduceMaxId, true, true>::AccElementwiseOperation;
+using ReduceSumInElementwiseOperation = typename ck::
+    reduce_unary_operator<AccDataType, ReduceSumId, true, true>::InElementwiseOperation;
+using ReduceSumAccElementwiseOperation = typename ck::
+    reduce_unary_operator<AccDataType, ReduceSumId, true, true>::AccElementwiseOperation;
+
+using DeviceReduceMaxInstance =
+    ck::tensor_operation::device::DeviceReduceBlockWise;
+
+using DeviceReduceSumInstance =
+    ck::tensor_operation::device::DeviceReduceBlockWise;
+
+struct SubExp
+{
+    __host__ __device__ constexpr void operator()(EltwiseComputeDataType& dst,
+                                                  const EltwiseComputeDataType& src1,
+                                                  const EltwiseComputeDataType& src2) const
+    {
+        dst = exp(src1 - src2);
+    }
+};
+
+struct Div
+{
+    __host__ __device__ constexpr void operator()(EltwiseComputeDataType& dst,
+                                                  const EltwiseComputeDataType& src1,
+                                                  const EltwiseComputeDataType& src2) const
+    {
+        dst = src1 / src2;
+    }
+};
+
+using DeviceElementwiseSubExpInstance =
+    ck::tensor_operation::device::DeviceBinaryElementwise_2D<CDataType,
+                                                             CDataType,
+                                                             CDataType,
+                                                             EltwiseComputeDataType,
+                                                             SubExp,
+                                                             8>;
+
+using DeviceElementwiseDivInstance =
+    ck::tensor_operation::device::DeviceBinaryElementwise_2D<CDataType,
+                                                             CDataType,
+                                                             CDataType,
+                                                             EltwiseComputeDataType,
+                                                             Div,
+                                                             8>;
+
+using HostGemmInstance = ck::tensor_operation::host::
+    ReferenceGemm;
+
+using HostReduceMaxInstance = ReductionHost;
+
+using HostReduceSumInstance = ReductionHost;
+
+template <typename HostTensorC,
+          typename HostTensorA,
+          typename HostTensorB,
+          typename ComputeDataType,
+          typename Functor,
+          int broadcastDim>
+void host_broadcast2D(
+    HostTensorC& C, const HostTensorA& A, const HostTensorB& B, int M, int N, Functor functor)
+{
+    for(int m = 0; m < M; ++m)
+    {
+        for(int n = 0; n < N; ++n)
+        {
+            ComputeDataType Amn = static_cast<ComputeDataType>(A(m, n));
+            ComputeDataType Cmn = 0;
+            if constexpr(broadcastDim == 0)
+            {
+                ComputeDataType Bn = static_cast<ComputeDataType>(B(n));
+                functor(Cmn, Amn, Bn);
+            }
+            else
+            {
+                ComputeDataType Bm = static_cast<ComputeDataType>(B(m));
+                functor(Cmn, Amn, Bm);
+            }
+            C(m, n) = static_cast<CDataType>(Cmn);
+        }
+    }
+}
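+
+// Usage note (illustrative): host_broadcast2D applies functor(C(m, n), A(m, n), B(i)) for every
+// (m, n), where i = n when broadcastDim == 0 (B is a length-N vector broadcast down the rows) and
+// i = m otherwise. With SubExp it therefore produces C(m, n) = exp(A(m, n) - B(n)), which is the
+// host reference for the device SubExp kernel below.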
+
+int main(int argc, char* argv[])
+{
+    bool do_verification = 0;
+    int init_method      = 0;
+    int nrepeat          = 5;
+
+    // GEMM shape
+    ck::index_t M = 3840;
+    ck::index_t N = 4096;
+    ck::index_t K = 4096;
+
+    ck::index_t StrideA = 4096;
+    ck::index_t StrideB = 4096;
+    ck::index_t StrideC = 4096;
+
+    const std::vector<int> reduceDims{0};
+
+    if(argc == 4)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        nrepeat         = std::stoi(argv[3]);
+    }
+    else if(argc == 10)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        nrepeat         = std::stoi(argv[3]);
+
+        M = std::stoi(argv[4]);
+        N = std::stoi(argv[5]);
+        K = std::stoi(argv[6]);
+
+        StrideA = std::stoi(argv[7]);
+        StrideB = std::stoi(argv[8]);
+        StrideC = std::stoi(argv[9]);
+    }
+    else
+    {
+        printf("arg1: verification (0=no, 1=yes)\n");
+        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
+        printf("arg3: run kernel # of times (>1)\n");
+        printf("arg4 to 9: M (256x), N (128x), K (32x), StrideA, StrideB, StrideC\n");
+        exit(0);
+    }
+
+    auto f_host_tensor_descriptor =
+        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+            if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
+            {
+                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
+                                            std::vector<std::size_t>({stride, 1}));
+            }
+            else
+            {
+                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
+                                            std::vector<std::size_t>({1, stride}));
+            }
+        };
+
+    Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
+    Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
+    Tensor<CDataType> c_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+    Tensor<CDataType> c_n_max(std::vector<std::size_t>({static_cast<std::size_t>(N)}),
+                              std::vector<std::size_t>({1}));
+    Tensor<CDataType> exp_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+    Tensor<CDataType> exp_n_sum(std::vector<std::size_t>({static_cast<std::size_t>(N)}),
+                                std::vector<std::size_t>({1}));
+    Tensor<CDataType> softmax_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+
+    const auto c_m_n_shape     = ck::to_int_vector(c_m_n.mDesc.GetLengths());
+    const auto c_m_n_stride    = ck::to_int_vector(c_m_n.mDesc.GetStrides());
+    const auto reduce_n_shape  = ck::to_int_vector(c_n_max.mDesc.GetLengths());
+    const auto reduce_n_stride = ck::to_int_vector(c_n_max.mDesc.GetStrides());
+
+    size_t reduce_total_length = c_m_n.mDesc.GetElementSize() / c_n_max.mDesc.GetElementSize();
+
+    std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
+    std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
+    std::cout << "c_m_n: " << c_m_n.mDesc << std::endl;
+    std::cout << "c_n_max: " << c_n_max.mDesc << std::endl;
+    std::cout << "exp_m_n: " << exp_m_n.mDesc << std::endl;
+    std::cout << "exp_n_sum: " << exp_n_sum.mDesc << std::endl;
+    std::cout << "softmax_m_n: " << softmax_m_n.mDesc << std::endl;
+
+    switch(init_method)
+    {
+    case 0: break;
+    case 1:
+        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
+        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
+        break;
+    default:
+        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
+        b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
+    }
+
+    DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
+    DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
+    DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace());
+    DeviceMem c_n_max_device_buf(sizeof(CDataType) * c_n_max.mDesc.GetElementSpace());
+    DeviceMem indices_device_buf(0);
+    DeviceMem exp_m_n_device_buf(sizeof(CDataType) * exp_m_n.mDesc.GetElementSpace());
+    DeviceMem exp_n_sum_device_buf(sizeof(CDataType) * exp_n_sum.mDesc.GetElementSpace());
+    DeviceMem softmax_m_n_device_buf(sizeof(CDataType) * softmax_m_n.mDesc.GetElementSpace());
+
+    a_m_k_device_buf.ToDevice(a_m_k.mData.data());
+    b_k_n_device_buf.ToDevice(b_k_n.mData.data());
+
+    // do GEMM
+    auto gemm         = DeviceGemmInstance{};
+    auto gemm_invoker = gemm.MakeInvoker();
+    auto gemm_argument =
+        gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
+                          static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
+                          static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
+                          M,
+                          N,
+                          K,
+                          StrideA,
+                          StrideB,
+                          StrideC,
+                          PassThrough{},
+                          PassThrough{},
+                          PassThrough{});
+
+    if(!gemm.IsSupportedArgument(gemm_argument))
+    {
+        throw std::runtime_error(
+            "wrong! device_gemm with the specified compilation parameters does "
+            "not support this GEMM problem");
+    }
+
+    gemm_invoker.Run(gemm_argument, nrepeat);
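+
+    // The reduction below runs over dimension 0 (M), so c_n_max holds one maximum per column of
+    // C; it is consumed by the SubExp kernel to shift every column before exponentiation.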
+
+    // do reduce max
+    auto reduce_max = DeviceReduceMaxInstance{};
+    auto reduce_max_workspace_size = reduce_max.GetWorkspaceSizeInBytes(c_m_n_shape, reduceDims);
+    DeviceMem reduce_max_workspace_device_buf(reduce_max_workspace_size);
+
+    auto reduce_max_argument_ptr = reduce_max.MakeArgumentPointer(
+        c_m_n_shape,
+        c_m_n_stride,
+        reduce_n_shape,
+        reduce_n_stride,
+        reduceDims,
+        1, // alpha
+        0, // beta
+        c_m_n_device_buf.GetDeviceBuffer(),
+        c_n_max_device_buf.GetDeviceBuffer(),
+        indices_device_buf.GetDeviceBuffer(),
+        reduce_max_workspace_device_buf.GetDeviceBuffer(),
+        ReduceMaxInElementwiseOperation{static_cast<int32_t>(reduce_total_length)},
+        ReduceMaxAccElementwiseOperation{static_cast<int32_t>(reduce_total_length)});
+
+    if(!reduce_max.IsSupportedArgument(reduce_max_argument_ptr.get()))
+    {
+        throw std::runtime_error(
+            "The runtime parameters are not supported by the DeviceReduce instance, exiting!");
+    };
+
+    auto reduce_max_invoker_ptr = reduce_max.MakeInvokerPointer();
+    reduce_max_invoker_ptr->Run(reduce_max_argument_ptr.get(), nrepeat);
+
+    // do broadcast sub and exp
+    auto broadcastSubExp = DeviceElementwiseSubExpInstance{};
+    auto broadcastSubExp_argument_ptr =
+        broadcastSubExp.MakeArgumentPointer(c_m_n_device_buf.GetDeviceBuffer(),
+                                            c_n_max_device_buf.GetDeviceBuffer(),
+                                            exp_m_n_device_buf.GetDeviceBuffer(),
+                                            {M, N},
+                                            {StrideC, 1},
+                                            {0, 1},
+                                            {StrideC, 1},
+                                            SubExp{},
+                                            256);
+
+    if(!broadcastSubExp.IsSupportedArgument(broadcastSubExp_argument_ptr.get()))
+    {
+        throw std::runtime_error("The runtime parameters are not supported by the "
+                                 "DeviceBinaryElementwise_2D instance, exiting!");
+    };
+
+    auto broadcastSubExp_invoker_ptr = broadcastSubExp.MakeInvokerPointer();
+    broadcastSubExp_invoker_ptr->Run(broadcastSubExp_argument_ptr.get(), nrepeat);
+
+    // do reduce sum - denominator of softmax
+    auto reduce_sum = DeviceReduceSumInstance{};
+    auto reduce_sum_workspace_size = reduce_sum.GetWorkspaceSizeInBytes(c_m_n_shape, reduceDims);
+    DeviceMem reduce_sum_workspace_device_buf(reduce_sum_workspace_size);
+
+    auto reduce_sum_argument_ptr = reduce_sum.MakeArgumentPointer(
+        c_m_n_shape,
+        c_m_n_stride,
+        reduce_n_shape,
+        reduce_n_stride,
+        reduceDims,
+        1, // alpha
+        0, // beta
+        exp_m_n_device_buf.GetDeviceBuffer(),
+        exp_n_sum_device_buf.GetDeviceBuffer(),
+        indices_device_buf.GetDeviceBuffer(),
+        reduce_sum_workspace_device_buf.GetDeviceBuffer(),
+        ReduceSumInElementwiseOperation{static_cast<int32_t>(reduce_total_length)},
+        ReduceSumAccElementwiseOperation{static_cast<int32_t>(reduce_total_length)});
+
+    if(!reduce_sum.IsSupportedArgument(reduce_sum_argument_ptr.get()))
+    {
+        throw std::runtime_error(
+            "The runtime parameters are not supported by the DeviceReduce instance, exiting!");
+    };
+
+    auto reduce_sum_invoker_ptr = reduce_sum.MakeInvokerPointer();
+    reduce_sum_invoker_ptr->Run(reduce_sum_argument_ptr.get(), nrepeat);
+
+    // do broadcast div
+    auto broadcastDiv = DeviceElementwiseDivInstance{};
+    auto broadcastDiv_argument_ptr =
+        broadcastDiv.MakeArgumentPointer(exp_m_n_device_buf.GetDeviceBuffer(),
+                                         exp_n_sum_device_buf.GetDeviceBuffer(),
+                                         softmax_m_n_device_buf.GetDeviceBuffer(),
+                                         {M, N},
+                                         {StrideC, 1},
+                                         {0, 1},
+                                         {StrideC, 1},
+                                         Div{},
+                                         256);
+
+    if(!broadcastDiv.IsSupportedArgument(broadcastDiv_argument_ptr.get()))
+    {
+        throw std::runtime_error("The runtime parameters are not supported by the "
+                                 "DeviceBinaryElementwise_2D instance, exiting!");
+    };
+
+    auto broadcastDiv_invoker_ptr = broadcastDiv.MakeInvokerPointer();
+    broadcastDiv_invoker_ptr->Run(broadcastDiv_argument_ptr.get(), nrepeat);
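+
+    // Verification below is stage-by-stage: each host reference is computed from the device
+    // output of the previous stage (e.g. host_reduce_max reads c_m_n copied back from the GPU),
+    // so a mismatch points at an individual kernel rather than at accumulated error.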
exiting!"); + }; + + auto broadcastDiv_invoker_ptr = broadcastDiv.MakeInvokerPointer(); + broadcastDiv_invoker_ptr->Run(broadcastDiv_argument_ptr.get(), nrepeat); + + if(do_verification) + { + std::cout << "verification..." << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n.mData.data()); + c_n_max_device_buf.FromDevice(c_n_max.mData.data()); + exp_m_n_device_buf.FromDevice(exp_m_n.mData.data()); + exp_n_sum_device_buf.FromDevice(exp_n_sum.mData.data()); + softmax_m_n_device_buf.FromDevice(softmax_m_n.mData.data()); + + const std::vector reduceInvariantDims{1}; + Tensor host_c_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor host_c_n_max(std::vector({static_cast(N)}), + std::vector({1})); + Tensor host_indices(host_c_n_max.mDesc.GetLengths()); + Tensor host_exp_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor host_exp_n_sum(std::vector({static_cast(N)}), + std::vector({1})); + Tensor host_softmax_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + auto host_gemm = HostGemmInstance{}; + auto host_gemm_invoker = host_gemm.MakeInvoker(); + auto host_gemm_argument = host_gemm.MakeArgument( + a_m_k, b_k_n, host_c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + auto host_reduce_max = HostReduceMaxInstance{ + host_c_m_n.mDesc, host_c_n_max.mDesc, reduceInvariantDims, reduceDims}; + + auto host_reduce_sum = HostReduceSumInstance{ + host_exp_m_n.mDesc, host_exp_n_sum.mDesc, reduceInvariantDims, reduceDims}; + + host_gemm_invoker.Run(host_gemm_argument); + host_reduce_max.Run(1, // alpha + c_m_n.mData.data(), + 0, // beta + host_c_n_max.mData.data(), + host_indices.mData.data()); + + host_broadcast2D, + Tensor, + Tensor, + EltwiseComputeDataType, + SubExp, + 0>(host_exp_m_n, c_m_n, c_n_max, M, N, SubExp{}); + + host_reduce_sum.Run(1, // alpha + exp_m_n.mData.data(), + 0, // beta + host_exp_n_sum.mData.data(), + host_indices.mData.data()); + + host_broadcast2D, + Tensor, + Tensor, + EltwiseComputeDataType, + Div, + 0>(host_softmax_m_n, exp_m_n, exp_n_sum, M, N, Div{}); + + bool result = true; + if(result &= ck::utils::check_err(c_m_n.mData, host_c_m_n.mData)) + std::cout << "[PASS] - gemm" << std::endl; + if(result &= ck::utils::check_err(c_n_max.mData, host_c_n_max.mData)) + std::cout << "[PASS] - reduce max" << std::endl; + if(result &= ck::utils::check_err(exp_m_n.mData, host_exp_m_n.mData)) + std::cout << "[PASS] - broadcast sub + exp" << std::endl; + if(result &= ck::utils::check_err(exp_n_sum.mData, host_exp_n_sum.mData)) + std::cout << "[PASS] - reduce sum" << std::endl; + if(result &= ck::utils::check_err(softmax_m_n.mData, host_softmax_m_n.mData)) + std::cout << "[PASS] - broadcast div" << std::endl; + } + return 0; +} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 5f041253056..6c43e9d948a 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -42,3 +42,4 @@ add_subdirectory(17_convnd_bwd_data_xdl) add_subdirectory(15_grouped_gemm) add_subdirectory(16_gemm_reduce) add_subdirectory(18_batched_gemm_reduce) +add_subdirectory(19_gemm_softmax) diff --git a/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp b/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp new file mode 100644 index 00000000000..eba2d7979f2 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp @@ -0,0 +1,30 @@ +#pragma once +#include +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct 
+template <typename ElementwiseFunctor>
+struct DeviceBinaryElementwise : public BaseOperator
+{
+    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
+                                                              const void* p_b,
+                                                              void* p_c,
+                                                              std::vector<index_t> shape_a,
+                                                              std::vector<index_t> stride_a,
+                                                              std::vector<index_t> shape_b,
+                                                              std::vector<index_t> stride_b,
+                                                              ElementwiseFunctor functor,
+                                                              index_t threadPerBlock) = 0;
+
+    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+};
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/device_binary_elementwise_2d.hpp b/include/ck/tensor_operation/gpu/device/device_binary_elementwise_2d.hpp
new file mode 100644
index 00000000000..9d37a1d6439
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/device_binary_elementwise_2d.hpp
@@ -0,0 +1,206 @@
+#pragma once
+#include <memory>
+#include <sstream>
+
+#include "device.hpp"
+#include "device_binary_elementwise.hpp"
+#include "gridwise_binary_elementwise_1d.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+template <typename ADataType,
+          typename BDataType,
+          typename CDataType,
+          typename ComputeDataType,
+          typename ElementwiseFunctor,
+          index_t ScalarPerVector>
+struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFunctor>
+{
+    static constexpr auto I0 = Number<0>{};
+
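+    // MakeDescriptor_M0 merges the (m, n) view into a single 1D length m0 = m * n and right-pads
+    // it to a multiple of gridSize * threadPerBlock * ScalarPerVector, so the grid-stride loop in
+    // the gridwise kernel needs no tail handling. Illustrative numbers only: for m = 3840,
+    // n = 4096, gridSize = 128, threadPerBlock = 256 and ScalarPerVector = 8, m0 = 15728640 is
+    // already a multiple of loop_step = 128 * 256 * 8 = 262144, so the pad is 0.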
+    static auto MakeDescriptor_M0(const std::vector<index_t>& shape,
+                                  const std::vector<index_t>& stride,
+                                  index_t gridSize,
+                                  index_t threadPerBlock)
+    {
+        const int m = shape[0];
+        const int n = shape[1];
+
+        // 2d desc - [m, n]
+        const auto desc_m_n =
+            make_naive_tensor_descriptor(make_tuple(m, n), make_tuple(stride[0], stride[1]));
+
+        // 1d desc - [m * n]
+        const auto desc_m0 =
+            transform_tensor_descriptor(desc_m_n,
+                                        make_tuple(make_merge_transform(make_tuple(m, n))),
+                                        make_tuple(Sequence<0, 1>{}),
+                                        make_tuple(Sequence<0>{}));
+
+        // pad
+        const auto m0           = desc_m0.GetLength(I0);
+        const index_t loop_step = gridSize * threadPerBlock * ScalarPerVector;
+        const auto pad          = math::integer_least_multiple(m0, loop_step) - m0;
+        const auto desc_m0_pad =
+            transform_tensor_descriptor(desc_m0,
+                                        make_tuple(make_right_pad_transform(m0, pad)),
+                                        make_tuple(Sequence<0>{}),
+                                        make_tuple(Sequence<0>{}));
+        return desc_m0_pad;
+    }
+
+    using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
+    using GridwiseBinEltwise = GridwiseBinaryElementwise_1D<ADataType,
+                                                            BDataType,
+                                                            CDataType,
+                                                            ComputeDataType,
+                                                            GridDesc_M0,
+                                                            ElementwiseFunctor,
+                                                            ScalarPerVector>;
+
+    struct Argument : public BaseArgument
+    {
+        Argument(const ADataType* p_a,
+                 const BDataType* p_b,
+                 CDataType* p_c,
+                 const std::vector<index_t>& shape,
+                 const std::vector<index_t>& stride_a,
+                 const std::vector<index_t>& stride_b,
+                 const std::vector<index_t>& stride_c,
+                 ElementwiseFunctor functor,
+                 index_t threadPerBlock)
+            : p_a_(p_a),
+              p_b_(p_b),
+              p_c_(p_c),
+              functor_(functor),
+              threadPerBlock_(threadPerBlock),
+              gridSize_(128) // FIXME - Calculate the grid size by number of CU in the future
+        {
+            a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, threadPerBlock_);
+            b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, threadPerBlock_);
+            c_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_c, gridSize_, threadPerBlock_);
+        }
+
+        const ADataType* p_a_;
+        const BDataType* p_b_;
+        CDataType* p_c_;
+        GridDesc_M0 a_grid_desc_m0_;
+        GridDesc_M0 b_grid_desc_m0_;
+        GridDesc_M0 c_grid_desc_m0_;
+        ElementwiseFunctor functor_;
+        index_t threadPerBlock_;
+        index_t gridSize_;
+    };
+
+    struct Invoker : public BaseInvoker
+    {
+        float Run(const Argument& arg, int nrepeat = 1)
+        {
+            const auto kernel = kernel_elementwise_1d<GridwiseBinEltwise,
+                                                      ADataType,
+                                                      BDataType,
+                                                      CDataType,
+                                                      GridDesc_M0,
+                                                      ElementwiseFunctor>;
+            float avgTime = 0;
+            if(nrepeat == 0)
+            {
+                launch_kernel(kernel,
+                              dim3(arg.gridSize_),
+                              dim3(arg.threadPerBlock_),
+                              0,
+                              arg.p_a_,
+                              arg.p_b_,
+                              arg.p_c_,
+                              arg.a_grid_desc_m0_,
+                              arg.b_grid_desc_m0_,
+                              arg.c_grid_desc_m0_,
+                              arg.functor_);
+            }
+            else
+            {
+                avgTime = launch_and_time_kernel(kernel,
+                                                 nrepeat,
+                                                 dim3(arg.gridSize_),
+                                                 dim3(arg.threadPerBlock_),
+                                                 0,
+                                                 arg.p_a_,
+                                                 arg.p_b_,
+                                                 arg.p_c_,
+                                                 arg.a_grid_desc_m0_,
+                                                 arg.b_grid_desc_m0_,
+                                                 arg.c_grid_desc_m0_,
+                                                 arg.functor_);
+            }
+            return avgTime;
+        }
+
+        float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+        {
+            return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+        };
+    };
+
+    bool IsSupportedArgument(const BaseArgument* p_arg) override
+    {
+        const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
+
+        if(pArg == nullptr)
+            return false;
+
+        // m * n
+        const auto m0 = pArg->c_grid_desc_m0_.GetLength(I0);
+
+        if(m0 % ScalarPerVector != 0)
+            return false;
+
+        return true;
+    };
+
+    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
+                                                      const void* p_b,
+                                                      void* p_c,
+                                                      std::vector<index_t> shape,
+                                                      std::vector<index_t> stride_a,
+                                                      std::vector<index_t> stride_b,
+                                                      std::vector<index_t> stride_c,
+                                                      ElementwiseFunctor functor,
+                                                      index_t threadPerBlock) override
+    {
+        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
+                                          static_cast<const BDataType*>(p_b),
+                                          static_cast<CDataType*>(p_c),
+                                          shape,
+                                          stride_a,
+                                          stride_b,
+                                          stride_c,
+                                          functor,
+                                          threadPerBlock);
+    }
+
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
+    {
+        return std::make_unique<Invoker>(Invoker{});
+    }
+
+    std::string GetTypeString() const override
+    {
+        auto str = std::stringstream();
+
+        // clang-format off
+        str << "DeviceBinaryElementwise_2D"
+            << "<"
+            << "ScalarPerVector = " << ScalarPerVector
+            << ">";
+        // clang-format on
+
+        return str.str();
+    }
+};
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp
new file mode 100644
index 00000000000..aea54ff53c4
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp
@@ -0,0 +1,150 @@
+#pragma once
+
+#include "cluster_descriptor.hpp"
+#include "data_type.hpp"
+#include "element_wise_operation.hpp"
+#include "threadwise_tensor_slice_transfer.hpp"
+
+namespace ck {
+
+template <typename GridwiseBinEltwise,
+          typename ADataType,
+          typename BDataType,
+          typename CDataType,
+          typename GridDesc_M0,
+          typename ElementwiseFunctor>
+__global__ void kernel_elementwise_1d(const ADataType* __restrict__ p_a_global,
+                                      const BDataType* __restrict__ p_b_global,
+                                      CDataType* __restrict__ p_c_global,
+                                      const GridDesc_M0 a_grid_desc_m0,
+                                      const GridDesc_M0 b_grid_desc_m0,
+                                      const GridDesc_M0 c_grid_desc_m0,
+                                      const ElementwiseFunctor functor)
+{
+    GridwiseBinEltwise::Run(p_a_global,
+                            p_b_global,
+                            p_c_global,
+                            a_grid_desc_m0,
+                            b_grid_desc_m0,
+                            c_grid_desc_m0,
+                            functor);
+}
+
+template <typename ADataType,
+          typename BDataType,
+          typename CDataType,
+          typename ComputeDataType,
+          typename GridDesc_M0,
+          typename ElementwiseFunctor,
+          index_t ScalarPerVector>
+struct GridwiseBinaryElementwise_1D
+{
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto thread_desc_m0 =
+        make_naive_tensor_descriptor_packed(make_tuple(Number<ScalarPerVector>{}));
+
+    using PassThrough = tensor_operation::element_wise::PassThrough;
+
+    static __device__ __host__ auto CalculateElementwiseIndex()
+    {
+        const index_t global_thread_id = get_thread_global_1d_id();
+        return make_multi_index(global_thread_id * ScalarPerVector);
+    }
+
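+    // Run is a grid-stride loop: each thread starts at global_thread_id * ScalarPerVector, loads
+    // ScalarPerVector contiguous elements of A and B, applies the functor element by element in
+    // registers, writes ScalarPerVector elements of C, then moves all three windows forward by
+    // blockPerGrid * threadPerBlock * ScalarPerVector until the padded length m0 is consumed.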
+    __device__ static void Run(const ADataType* __restrict__ p_a_global,
+                               const BDataType* __restrict__ p_b_global,
+                               CDataType* __restrict__ p_c_global,
+                               const GridDesc_M0 a_grid_desc_m0,
+                               const GridDesc_M0 b_grid_desc_m0,
+                               const GridDesc_M0 c_grid_desc_m0,
+                               const ElementwiseFunctor functor)
+    {
+        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_a_global, a_grid_desc_m0.GetElementSpaceSize());
+        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_b_global, b_grid_desc_m0.GetElementSpaceSize());
+        auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_c_global, c_grid_desc_m0.GetElementSpaceSize());
+
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> a_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> b_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> c_thread_buf;
+
+        const auto thread_to_global_offset = CalculateElementwiseIndex();
+
+        auto a_global_load =
+            ThreadwiseTensorSliceTransfer_v2<ADataType,
+                                             ComputeDataType,
+                                             GridDesc_M0,
+                                             decltype(thread_desc_m0),
+                                             Sequence<ScalarPerVector>, // SliceLengths
+                                             Sequence<0>,               // DimAccessOrder
+                                             0,                         // SrcVectorDim
+                                             ScalarPerVector,
+                                             1, // SrcScalarStrideInVector
+                                             false>{a_grid_desc_m0, thread_to_global_offset};
+
+        auto b_global_load =
+            ThreadwiseTensorSliceTransfer_v2<BDataType,
+                                             ComputeDataType,
+                                             GridDesc_M0,
+                                             decltype(thread_desc_m0),
+                                             Sequence<ScalarPerVector>, // SliceLengths
+                                             Sequence<0>,               // DimAccessOrder
+                                             0,                         // SrcVectorDim
+                                             ScalarPerVector,
+                                             1, // SrcScalarStrideInVector
+                                             false>{b_grid_desc_m0, thread_to_global_offset};
+
+        auto c_global_write =
+            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
+                                               CDataType,
+                                               decltype(thread_desc_m0),
+                                               GridDesc_M0,
+                                               PassThrough,
+                                               Sequence<ScalarPerVector>, // SliceLengths
+                                               Sequence<0>,               // DimAccessOrder
+                                               0,                         // DstVectorDim
+                                               ScalarPerVector,
+                                               InMemoryDataOperationEnum::Set,
+                                               1, // DstScalarStrideInVector
+                                               false>{
+                c_grid_desc_m0, thread_to_global_offset, PassThrough{}};
+
+        const index_t threadPerBlock = get_block_size();
+        const index_t blockPerGrid   = get_grid_size();
+        const auto m0                = c_grid_desc_m0.GetLength(I0);
+        const index_t loop_step      = blockPerGrid * threadPerBlock * ScalarPerVector;
+        const auto loop_step_index   = make_multi_index(loop_step);
+
+        index_t num_iter = m0 / loop_step;
+        do
+        {
+            // read and process ScalarPerVector elements
+            a_global_load.Run(
+                a_grid_desc_m0, a_global_buf, thread_desc_m0, make_tuple(I0), a_thread_buf);
+
+            b_global_load.Run(
+                b_grid_desc_m0, b_global_buf, thread_desc_m0, make_tuple(I0), b_thread_buf);
+
+            static_for<0, ScalarPerVector, 1>{}([&](auto m) {
+                constexpr auto offset = thread_desc_m0.CalculateOffset(make_tuple(m));
+                functor(c_thread_buf(Number<offset>{}),
+                        a_thread_buf(Number<offset>{}),
+                        b_thread_buf(Number<offset>{}));
+            });
+
+            c_global_write.Run(thread_desc_m0,
+                               make_tuple(I0), // SrcSliceOriginIdx
+                               c_thread_buf,
+                               c_grid_desc_m0,
+                               c_global_buf);
+
+            a_global_load.MoveSrcSliceWindow(a_grid_desc_m0, loop_step_index);
+            b_global_load.MoveSrcSliceWindow(b_grid_desc_m0, loop_step_index);
+            c_global_write.MoveDstSliceWindow(c_grid_desc_m0, loop_step_index);
+        } while(--num_iter);
+    }
+};
+
+} // namespace ck
diff --git a/include/ck/utility/get_id.hpp b/include/ck/utility/get_id.hpp
index f742512d400..d2a689c1cca 100644
--- a/include/ck/utility/get_id.hpp
+++ b/include/ck/utility/get_id.hpp
@@ -7,10 +7,14 @@ __device__ constexpr index_t get_wave_size() { return CK_GPU_WAVE_SIZE; }
 
 __device__ index_t get_thread_local_1d_id() { return threadIdx.x; }
 
+__device__ index_t get_thread_global_1d_id() { return blockIdx.x * blockDim.x + threadIdx.x; }
+
 __device__ index_t get_wave_local_1d_id() { return threadIdx.x / get_wave_size(); }
 
 __device__ index_t get_block_1d_id() { return blockIdx.x; }
 
 __device__ index_t get_grid_size() { return gridDim.x; }
 
+__device__ index_t get_block_size() { return blockDim.x; }
+
 } // namespace ck