From c2407f4778f257718e9d6e6c63d45488da62b882 Mon Sep 17 00:00:00 2001 From: zhangbopd <1299246947@qq.com> Date: Fri, 24 Mar 2023 07:06:00 +0000 Subject: [PATCH 1/8] change judgement for DropoutGradGPUKernelDriver --- paddle/phi/kernels/funcs/dropout_impl.cu.h | 60 +++++++++++----------- 1 file changed, 31 insertions(+), 29 deletions(-) diff --git a/paddle/phi/kernels/funcs/dropout_impl.cu.h b/paddle/phi/kernels/funcs/dropout_impl.cu.h index de82561c1fefb..6d0ceba2d9096 100644 --- a/paddle/phi/kernels/funcs/dropout_impl.cu.h +++ b/paddle/phi/kernels/funcs/dropout_impl.cu.h @@ -310,7 +310,7 @@ void DropoutFwGPUKernelDriver( auto* mask_data = mask->data(); size_t size = phi::product(mask->dims()); - if (dropout_prob == 1.0f) { + if (abs(dropout_prob / 1.0f - 1) < 5.96e-08) { #ifdef PADDLE_WITH_HIP PADDLE_ENFORCE_GPU_SUCCESS( hipMemsetAsync(y_data, 0, x_numel * sizeof(T), stream)); @@ -454,41 +454,43 @@ void DropoutGradGPUKernelDriver(const phi::GPUContext& dev_ctx, // y = factor * x ScaleByDropoutFactor(dev_ctx, grad_y, grad_x, factor); } else { - phi::DenseTensor broadcasted_mask; - if (is_dropout_nd) { - broadcasted_mask.Resize(grad_y.dims()); - dev_ctx.template Alloc(&broadcasted_mask); - - std::vector broadcast_ins = {&mask}; - std::vector broadcast_outs = {&broadcasted_mask}; - phi::funcs::BroadcastKernel(dev_ctx, - broadcast_ins, - &broadcast_outs, - -1, - kps::IdentityFunctor()); - } - - std::vector ins = { - &grad_y, is_dropout_nd ? &broadcasted_mask : &mask}; - std::vector outs = {grad_x}; - if (upscale_in_train) { - if (dropout_prob == 1.0f) { + if (upscale_in_train && abs(dropout_prob / 1.0f - 1) < 5.96e-08) { #ifdef PADDLE_WITH_HIP - hipMemset(grad_x->data(), 0, grad_x->numel() * sizeof(T)); + hipMemset(grad_x->data(), 0, grad_x->numel() * sizeof(T)); #else - cudaMemset(grad_x->data(), 0, grad_x->numel() * sizeof(T)); + cudaMemset(grad_x->data(), 0, grad_x->numel() * sizeof(T)); #endif + } else { + MT factor = upscale_in_train + ? static_cast(1.0f / (1.0f - dropout_prob)) + : static_cast(1.0f); + if (is_dropout_nd) { + phi::DenseTensor broadcasted_mask; + + broadcasted_mask.Resize(grad_y.dims()); + dev_ctx.template Alloc(&broadcasted_mask); + + std::vector broadcast_ins = {&mask}; + std::vector broadcast_outs = {&broadcasted_mask}; + phi::funcs::BroadcastKernel(dev_ctx, + broadcast_ins, + &broadcast_outs, + -1, + kps::IdentityFunctor()); + + std::vector ins = {&grad_y, &broadcasted_mask}; + std::vector outs = {grad_x}; + phi::funcs::ElementwiseKernel( + dev_ctx, ins, &outs, CudaDropoutGradFunctor(factor)); + } else { - MT factor = static_cast(1.0f / (1.0f - dropout_prob)); + std::vector ins = {&grad_y, &mask}; + std::vector outs = {grad_x}; phi::funcs::ElementwiseKernel( dev_ctx, ins, &outs, CudaDropoutGradFunctor(factor)); } - } else { - MT factor = static_cast(1.0f); - phi::funcs::ElementwiseKernel( - dev_ctx, ins, &outs, CudaDropoutGradFunctor(factor)); } } } From 48ff6cb1d5265222a93c05267692b5c5eaab75b0 Mon Sep 17 00:00:00 2001 From: zhangbopd <1299246947@qq.com> Date: Tue, 28 Mar 2023 05:23:55 +0000 Subject: [PATCH 2/8] add UnrollerWithoutVecSize and after this Loaddata to be refined --- paddle/phi/kernels/funcs/broadcast_function.h | 134 ++++++++++-------- paddle/phi/kernels/funcs/elementwise_base.h | 27 +++- 2 files changed, 97 insertions(+), 64 deletions(-) diff --git a/paddle/phi/kernels/funcs/broadcast_function.h b/paddle/phi/kernels/funcs/broadcast_function.h index f96a1764c24a5..054d6cb06aa61 100644 --- a/paddle/phi/kernels/funcs/broadcast_function.h +++ b/paddle/phi/kernels/funcs/broadcast_function.h @@ -31,20 +31,28 @@ namespace funcs { enum BroadcastLoadType { kMixed = 1, kBroadcast = 2, kElementwise = 3 }; -template +template struct LoaderTypeClassifier { public: int64_t numel{0}; - int vec_size{1}; + int vec_size{4}; int broadcast_num{0}; bool all_elementwise{true}; - phi::Array use_broadcast; + phi::Array use_broadcast; phi::Array ins_data; + // phi::Array ins_data; LoaderTypeClassifier() {} LoaderTypeClassifier(const std::vector &ins, std::vector *outs) { + using Traits = phi::funcs::FunctionTraits; + using ArgsT = typename Traits::ArgsTuple; + ArgsT arg; + // The Arg VecSize=1 is to match the Unroller template. uint64_t out_addr = reinterpret_cast((*outs)[0]->data()); + + UnrollerWithoutVecSize::step(ins, arg, &vec_size); + for (auto i = 1; i < outs->size(); ++i) { PADDLE_ENFORCE_EQ( (*outs)[i]->dims(), @@ -56,28 +64,27 @@ struct LoaderTypeClassifier { out_addr = (out_addr | reinterpret_cast((*outs)[i]->data())); } - int out_vec_size = - phi::GetVectorizedSize(reinterpret_cast(out_addr)); - uint64_t in_addr = static_cast(0); + vec_size = std::min( + vec_size, + phi::GetVectorizedSize(reinterpret_cast(out_addr))); numel = (*outs)[0]->numel(); + // UnrollerWithoutVecSize::step(ins, &ins_data); + +#pragma unroll for (int i = 0; i < Arity; ++i) { + // FIXME: get type of ins_data find all parts used in_data auto in_data = ins[i]->data(); ins_data[i] = (const _ptr_ InT *)(in_data); - bool is_same_dim = ins[i]->numel() == numel; if (is_same_dim) { use_broadcast[i] = false; - in_addr = (in_addr | reinterpret_cast(in_data)); } else { use_broadcast[i] = true; broadcast_num++; } all_elementwise &= is_same_dim; } - int in_vec_size = std::min( - 4, phi::GetVectorizedSize(reinterpret_cast(in_addr))); - vec_size = std::min(out_vec_size, in_vec_size); } }; @@ -89,7 +96,7 @@ struct BroadcastDataLoader { T args[Arity][VecSize], const phi::Array &ins, const phi::Array &configs, - const phi::Array &use_broadcast, + const phi::Array &use_broadcast, const int block_offset, const int num, const uint32_t numel) { @@ -114,7 +121,7 @@ struct BroadcastDataLoader { T args[Arity][VecSize], const phi::Array &ins, const phi::Array &configs, - const phi::Array &use_broadcast, + const phi::Array &use_broadcast, const int block_offset, const int num, const uint32_t numel) { @@ -140,7 +147,7 @@ struct BroadcastDataLoader { T args[Arity][VecSize], const phi::Array &ins, const phi::Array &configs, - const phi::Array &use_broadcast, + const phi::Array &use_broadcast, const int block_offset, const int num, const uint32_t numel) { @@ -168,7 +175,7 @@ struct BroadcastDataLoader { T args[Arity][VecSize], const phi::Array &ins, const phi::Array &configs, - const phi::Array &use_broadcast, + const phi::Array &use_broadcast, const int block_offset, const int num, const uint32_t numel) { @@ -224,7 +231,7 @@ template &ins, phi::Array<_ptr_ OutT *, NumOuts> outs, - const phi::Array &use_broadcast, + const phi::Array &use_broadcast, const uint32_t numel, const phi::Array &configs, int num, @@ -274,7 +281,7 @@ template ins, phi::Array<_ptr_ OutT *, NumOuts> outs, - phi::Array use_broadcast, + phi::Array use_broadcast, uint32_t numel, phi::Array configs, int main_offset, @@ -363,7 +370,7 @@ __global__ void VectorizedBroadcastKernel( template @@ -371,9 +378,9 @@ void LaunchBroadcastKernel( const KPDevice &ctx, const std::vector &ins, std::vector *outs, - Func func, + Functor func, const phi::Array &configs, - const LoaderTypeClassifier &loader_classifier) { + const LoaderTypeClassifier &loader_classifier) { phi::Array<_ptr_ OutT *, NumOuts> outs_data; for (int i = 0; i < NumOuts; ++i) { outs_data[i] = (_ptr_ OutT *)(ctx.Alloc((*outs)[i])); @@ -388,7 +395,7 @@ void LaunchBroadcastKernel( int main_offset = (numel / (read_lens * threads)) * read_lens * threads; int tail_tid = numel % (read_lens * threads); - VectorizedBroadcastKernel + VectorizedBroadcastKernel <<>>(loader_classifier.ins_data, outs_data, loader_classifier.use_broadcast, @@ -409,7 +416,7 @@ void LaunchBroadcastKernel( int tail_tid = numel % (VecSize * threads); if (loader_classifier.all_elementwise) { - VectorizedBroadcastKernel (Arity >> 1)) { constexpr BroadcastLoadType type_ = (Arity > 1) ? kBroadcast : kMixed; - VectorizedBroadcastKernel + VectorizedBroadcastKernel <<>>(loader_classifier.ins_data, outs_data, loader_classifier.use_broadcast, @@ -438,7 +451,13 @@ void LaunchBroadcastKernel( VecSize, func); } else { - VectorizedBroadcastKernel + VectorizedBroadcastKernel <<>>(loader_classifier.ins_data, outs_data, loader_classifier.use_broadcast, @@ -847,6 +866,7 @@ template void BroadcastKernelForDifferentVecSize( const KPDevice &ctx, @@ -854,43 +874,14 @@ void BroadcastKernelForDifferentVecSize( std::vector *outs, int axis, Functor func) { - using Traits = phi::funcs::FunctionTraits; - const int kArity = - Traits::has_pointer_args ? static_cast(ET) : Traits::arity; - PADDLE_ENFORCE_EQ( - ins.size(), - kArity, - phi::errors::InvalidArgument("The number of inputs is expected to be " - "equal to the " - "arity of functor. But received: the " - "number of inputs " - "is %d, the arity of functor is %d.", - ins.size(), - kArity)); - PADDLE_ENFORCE_LE( - kArity, - 3, - phi::errors::InvalidArgument("Currently only broadcast of ternary is " - "supported " - "and verified, but received %d.", - kArity)); - PADDLE_ENFORCE_EQ( - outs->size(), - NumOuts, - phi::errors::InvalidArgument("Number of outputs shall equal to number " - "of functions, " - "but number of outputs is %d, of " - "functions is %d.", - outs->size(), - NumOuts)); - #ifndef PADDLE_WITH_XPU_KP constexpr bool kEnabledInt64IndexKernel = (NumOuts == 1 && kArity <= 3); bool use_int64_index_kernel = kEnabledInt64IndexKernel && (*outs)[0]->numel() >= std::numeric_limits::max(); if (use_int64_index_kernel) { - auto loader_classifier = LoaderTypeClassifier(ins, outs); + auto loader_classifier = + LoaderTypeClassifier(ins, outs); switch (loader_classifier.vec_size) { case VecSizeL: { LaunchBroadcastKernelWithInt64IndexHelper(); + auto loader_classifier = LoaderTypeClassifier(); const auto dims_simplifier = BroadcastDimsSimplifier(ins, (*outs)[0]->dims(), axis); if (VLOG_IS_ON(6)) { @@ -968,7 +959,8 @@ void BroadcastKernelForDifferentVecSize( bool is_optimize = configs[0].cmp_type != type; int vec_size = is_optimize ? VecSizeL : VecSizeM; #else - auto loader_classifier = LoaderTypeClassifier(ins, outs); + auto loader_classifier = + LoaderTypeClassifier(ins, outs); if (!loader_classifier.all_elementwise) { const auto dims_simplifier = BroadcastDimsSimplifier(ins, (*outs)[0]->dims(), axis); @@ -1013,6 +1005,8 @@ void BroadcastKernelForDifferentVecSize( } } +// FIXME: delete (ElementwiseType ET & typename InT) +// default: axis = 1 template ; + const int kArity = + Traits::has_pointer_args ? static_cast(ET) : Traits::arity; + PADDLE_ENFORCE_EQ( + ins.size(), + kArity, + phi::errors::InvalidArgument("The number of inputs is expected to be " + "equal to the " + "arity of functor. But received: the " + "number of inputs " + "is %d, the arity of functor is %d.", + ins.size(), + kArity)); + PADDLE_ENFORCE_EQ( + outs->size(), + NumOuts, + phi::errors::InvalidArgument("Number of outputs shall equal to number " + "of functions, " + "but number of outputs is %d, of " + "functions is %d.", + outs->size(), + NumOuts)); + int max_rank = 0; int min_rank = phi::DDim::kMaxRank; for (auto *in : ins) { @@ -1037,7 +1055,7 @@ void BroadcastKernel(const KPDevice &ctx, max_rank = std::max(max_rank, (*outs)[0]->dims().size()); } axis = axis == -1 ? max_rank - min_rank : axis; - BroadcastKernelForDifferentVecSize( + BroadcastKernelForDifferentVecSize( ctx, ins, outs, axis, func); } diff --git a/paddle/phi/kernels/funcs/elementwise_base.h b/paddle/phi/kernels/funcs/elementwise_base.h index 1d40d4d8c2957..e2a8d0b569819 100644 --- a/paddle/phi/kernels/funcs/elementwise_base.h +++ b/paddle/phi/kernels/funcs/elementwise_base.h @@ -35,7 +35,7 @@ namespace kps = phi::kps; namespace phi { -enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 }; +enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3 }; /* Packing scalar type T(float, int etc.) into Array type for supporting multiple-output feature in elementwise system.*/ template @@ -508,6 +508,22 @@ struct Unroller { static HOSTDEVICE inline void step(Args &&...args) {} }; +// static unroller without VecSize for broadcast +template