diff --git a/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_attention_fwd_func.cc b/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_attention_fwd_func.cc index 334cdc6736f1e..a84c53e33a106 100644 --- a/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_attention_fwd_func.cc +++ b/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_attention_fwd_func.cc @@ -572,15 +572,17 @@ fused_attention_dygraph_function( egr::EagerUtils::CheckAndRetainGrad(SoftmaxOut); grad_node->SetGradOutMeta(SoftmaxOut, 19); - auto AttnDropoutOut_accumulation_node = - std::make_shared( - p_autograd_AttnDropoutOut); - egr::EagerUtils::SetOutRankWithSlot(p_autograd_AttnDropoutOut, 0); - egr::EagerUtils::SetHistory(p_autograd_AttnDropoutOut, - AttnDropoutOut_accumulation_node); - AttnDropoutOut_accumulation_node->SetGradInMeta(AttnDropoutOut, 0); - egr::EagerUtils::CheckAndRetainGrad(AttnDropoutOut); - grad_node->SetGradOutMeta(AttnDropoutOut, 20); + if (AttnDropoutOut.initialized()) { + auto AttnDropoutOut_accumulation_node = + std::make_shared( + p_autograd_AttnDropoutOut); + egr::EagerUtils::SetOutRankWithSlot(p_autograd_AttnDropoutOut, 0); + egr::EagerUtils::SetHistory(p_autograd_AttnDropoutOut, + AttnDropoutOut_accumulation_node); + AttnDropoutOut_accumulation_node->SetGradInMeta(AttnDropoutOut, 0); + egr::EagerUtils::CheckAndRetainGrad(AttnDropoutOut); + grad_node->SetGradOutMeta(AttnDropoutOut, 20); + } auto FMHAOut_accumulation_node = std::make_shared(p_autograd_FMHAOut); diff --git a/paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h b/paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h index 7e0d679689c4a..82e1b24447987 100644 --- a/paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h +++ b/paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h @@ -476,7 +476,7 @@ class fused_attentionGradNodeCompat : public egr::GradNodeBase { SoftmaxOut_ = egr::TensorWrapper(SoftmaxOut, false); } void SetTensorWrapperSrcMask(const paddle::experimental::Tensor& SrcMask) { - SrcMask_ = egr::TensorWrapper(SrcMask, false); + SrcMask_ = egr::TensorWrapper(SrcMask, true); } void SetTensorWrapperSrcMaskOut( const paddle::experimental::Tensor& SrcMaskOut) { diff --git a/paddle/fluid/operators/fused/fmha_ref.h b/paddle/fluid/operators/fused/fmha_ref.h index fc5f9cf71d349..9a518e7bcecd5 100644 --- a/paddle/fluid/operators/fused/fmha_ref.h +++ b/paddle/fluid/operators/fused/fmha_ref.h @@ -104,7 +104,6 @@ class FMHARef { T* qk_out_data = qk_out_tensor->data(); T* qktv_out_data = qktv_out_tensor->data(); T* softmax_out_data = softmax_out_tensor->data(); - T* dropout_out_data = dropout_out_tensor->data(); T* fmha_out_data = fmha_out_tensor->data(); auto out_seq_len = seq_len_; @@ -221,6 +220,7 @@ class FMHARef { dropout_mask_out_tensor, dropout_out_tensor, false); + T* dropout_out_data = dropout_out_tensor->data(); blas.BatchedGEMM(transA, transB, gemm_m, @@ -464,8 +464,6 @@ class FMHARef { const T* softmax_out_data = softmax_out_tensor.data(); T* softmax_out_grad_data = softmax_out_grad_tensor->data(); - const T* dropout_out_data = dropout_out_tensor.data(); - T* dropout_out_grad_data = dropout_out_grad_tensor->data(); T* qktv_out_grad_data = qktv_out_grad_tensor->data(); // transpose bw @@ -487,6 +485,7 @@ class FMHARef { int64_t stride_b = gemm_k * gemm_n; // bw: dy = x^t * dout if (dropout_param_.dropout_prob_) { + const T* dropout_out_data = dropout_out_tensor.data(); blas.BatchedGEMM(transA, transB, gemm_m, @@ -524,6 +523,7 @@ class FMHARef { stride_a = gemm_m * gemm_k; stride_b = 
gemm_k * gemm_n; if (dropout_param_.dropout_prob_) { + T* dropout_out_grad_data = dropout_out_grad_tensor->data(); blas.BatchedGEMM(transA, transB, gemm_m, diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc index 03c97ec345fb8..6a8289cd968c8 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cc +++ b/paddle/fluid/operators/fused/fused_attention_op.cc @@ -547,8 +547,10 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel { ctx->GetInputDim("QKOut")); ctx->SetOutputDim(framework::GradVarName("SoftmaxOut"), ctx->GetInputDim("SoftmaxOut")); - ctx->SetOutputDim(framework::GradVarName("AttnDropoutOut"), - ctx->GetInputDim("AttnDropoutOut")); + if (ctx->HasOutput(framework::GradVarName("AttnDropoutOut"))) { + ctx->SetOutputDim(framework::GradVarName("AttnDropoutOut"), + ctx->GetInputDim("AttnDropoutOut")); + } if (ctx->HasOutput(framework::GradVarName("SrcMaskOut"))) { ctx->SetOutputDim(framework::GradVarName("SrcMaskOut"), @@ -709,7 +711,8 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(FusedAttentionGradNoNeedBufferInferer, "QKVOut", "QKOut", "QKTVOut", - "OutLinearOut"); + "OutLinearOut", + "SrcMask"); } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu index ef5087f0534e1..14f7408f12fd0 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cu +++ b/paddle/fluid/operators/fused/fused_attention_op.cu @@ -123,6 +123,10 @@ class FusedAttentionOpKernel : public framework::OpKernel { const float ln_epsilon = ctx.Attr("ln_epsilon"); float attn_dropout_rate = ctx.Attr("attn_dropout_rate"); + const bool has_attn_dropout = (attn_dropout_rate != 0.0f); + DropoutParam dropout_param2(ctx, 0); + const bool has_dropout = (dropout_param2.dropout_prob != 0.0f); + bool is_test_1 = ctx.Attr("is_test"); auto &dropout_implementation_1 = ctx.Attr("attn_dropout_implementation"); @@ -171,11 +175,16 @@ class FusedAttentionOpKernel : public framework::OpKernel { src_mask_out->numel() * sizeof(T)); auto *softmax_out_data = dev_ctx.template Alloc( softmax_out, softmax_out->numel() * sizeof(T)); - auto *attn_dropout_mask_out_data = dev_ctx.template Alloc( - attn_dropout_mask_out, - attn_dropout_mask_out->numel() * sizeof(uint8_t)); - auto *attn_dropout_out_data = dev_ctx.template Alloc( - attn_dropout_out, attn_dropout_out->numel() * sizeof(T)); + auto *attn_dropout_mask_out_data = + has_attn_dropout ? dev_ctx.template Alloc( + attn_dropout_mask_out, + attn_dropout_mask_out->numel() * sizeof(uint8_t)) + : nullptr; + auto *attn_dropout_out_data = + has_attn_dropout + ? dev_ctx.template Alloc(attn_dropout_out, + attn_dropout_out->numel() * sizeof(T)) + : nullptr; auto *fmha_out_data = dev_ctx.template Alloc(fmha_out, fmha_out->numel() * sizeof(T)); @@ -187,8 +196,11 @@ class FusedAttentionOpKernel : public framework::OpKernel { out_linear_out, out_linear_out->numel() * sizeof(T)); // get data ptr for bias+dropout+residual+layernorm - auto *dropout_mask_out_data = dev_ctx.template Alloc( - dropout_mask_out, dropout_mask_out->numel() * sizeof(uint8_t)); + auto *dropout_mask_out_data = + has_dropout + ? 
dev_ctx.template Alloc( + dropout_mask_out, dropout_mask_out->numel() * sizeof(uint8_t)) + : nullptr; auto *final_out_data = dev_ctx.template Alloc(out, out->numel() * sizeof(T)); @@ -248,7 +260,6 @@ class FusedAttentionOpKernel : public framework::OpKernel { input_size, output_size, false); - DropoutParam dropout_param2(ctx, 0); FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( ctx.cuda_device_context(), bsz_seq, @@ -369,7 +380,11 @@ class FusedAttentionGradKernel : public framework::OpKernel { const float epsilon = ctx.Attr("epsilon"); const float ln2epsilon = ctx.Attr("ln_epsilon"); - float attn_dropout_prob = ctx.Attr("attn_dropout_rate"); + const float attn_dropout_prob = ctx.Attr("attn_dropout_rate"); + const bool has_attn_dropout = (attn_dropout_prob != 0.0f); + DropoutParam dropout_param2(ctx, 0); + const bool has_dropout = (dropout_param2.dropout_prob != 0.0f); + auto &dev_ctx = ctx.template device_context(); bool is_test_1 = ctx.Attr("is_test"); auto &dropout_implementation_1 = @@ -400,7 +415,6 @@ class FusedAttentionGradKernel : public framework::OpKernel { auto *qkv_bias = ctx.Input("QKVBias"); auto *out_linear_weight = ctx.Input("OutLinearW"); auto *out_linear_bias = ctx.Input("OutLinearBias"); - auto *src_mask_data = (src_mask == nullptr ? nullptr : src_mask->data()); auto *qkv_weight_data = qkv_weight->data(); auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data(); auto *out_linear_weight_data = out_linear_weight->data(); @@ -426,7 +440,8 @@ class FusedAttentionGradKernel : public framework::OpKernel { auto *softmax_out_data = softmax_out->data(); auto *src_mask_out_data = (src_mask == nullptr) ? nullptr : src_mask_out->data(); - auto *dropout_mask_out_data = dropout_mask_out->data(); + auto *dropout_mask_out_data = + has_dropout ? dropout_mask_out->data() : nullptr; // output's grad auto *d_x = ctx.Output(framework::GradVarName("X")); @@ -472,8 +487,11 @@ class FusedAttentionGradKernel : public framework::OpKernel { dev_ctx.template Alloc(d_qk_out, d_qk_out->numel() * sizeof(T)); auto *d_softmax_out_data = dev_ctx.template Alloc( d_softmax_out, d_softmax_out->numel() * sizeof(T)); - auto *d_attn_dropout_out_data = dev_ctx.template Alloc( - d_attn_dropout_out, d_attn_dropout_out->numel() * sizeof(T)); + auto *d_attn_dropout_out_data = + has_attn_dropout + ? dev_ctx.template Alloc(d_attn_dropout_out, + d_attn_dropout_out->numel() * sizeof(T)) + : nullptr; auto *d_src_mask_out_data = (src_mask == nullptr) ? nullptr @@ -573,7 +591,6 @@ class FusedAttentionGradKernel : public framework::OpKernel { input_size, output_size, compute_bias); - DropoutParam dropout_param2(ctx, 0); FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( ctx.cuda_device_context(), bsz_seq, @@ -633,7 +650,7 @@ class FusedAttentionGradKernel : public framework::OpKernel { if (qkv_bias != nullptr) { fmha_ref_compute.ComputeBackward(*transpose_out_2, - src_mask, + has_attn_dropout ? src_mask : nullptr, *softmax_out, *attn_dropout_mask_out, *attn_dropout_out, @@ -650,7 +667,7 @@ class FusedAttentionGradKernel : public framework::OpKernel { d_qkv_bias_out); } else { fmha_ref_compute.ComputeBackward(*transpose_out_2, - src_mask, + has_attn_dropout ? 
src_mask : nullptr, *softmax_out, *attn_dropout_mask_out, *attn_dropout_out, diff --git a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h index a529271250e5d..c65364d2818d1 100644 --- a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h +++ b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h @@ -109,7 +109,8 @@ template + bool ScaleBiasWithSameTypeX = false, + bool HasDropout = true> __global__ void FusedLayernormResidualDropoutBias( const size_t rows, const size_t cols, @@ -133,7 +134,9 @@ __global__ void FusedLayernormResidualDropoutBias( int row_id = blockIdx.x; int idx = row_id * cols + col_id; curandStatePhilox4_32_10_t state; - curand_init(seed, idx, increment, &state); + if (HasDropout) { + curand_init(seed, idx, increment, &state); + } T factor = GetFactor(dropout_prob, is_upscale_in_train, is_test); @@ -151,21 +154,24 @@ __global__ void FusedLayernormResidualDropoutBias( VecSize, true, false, - phi::funcs::ReluFunctor>(row_id, - i, - cols, - &state, - dropout_prob, - factor, - src, - residual, - bias, - dst, - mask, - is_test, - &mean_val, - &var_val, - relu); + phi::funcs::ReluFunctor, + T, + T, + HasDropout>(row_id, + i, + cols, + &state, + dropout_prob, + factor, + src, + residual, + bias, + dst, + mask, + is_test, + &mean_val, + &var_val, + relu); } mean_val = BlockReduceSum(mean_val, shared_mean); @@ -197,6 +203,86 @@ __global__ void FusedLayernormResidualDropoutBias( invvar); } +template +void LaunchFusedLayernormResidualDropoutBiasCUDAKernel( + int grid_dim, + int block_dim, + gpuStream_t stream, + const size_t rows, + const size_t cols, + uint64_t seed, + const float dropout_prob, + const bool is_upscale_in_train, + const bool is_test, + const uint64_t increment, + const float epsilon, + const T *src, + const T *residual, + const T *bias, + const LayerNormScaleBiasT *scale, + const LayerNormScaleBiasT *layernorm_bias, + MaskType *mask, + T *dst, + T *layernorm_dst, + LayerNormParamType *mean, + LayerNormParamType *var) { + if (dropout_prob != 0.0f) { + FusedLayernormResidualDropoutBias + <<>>(rows, + cols, + seed, + dropout_prob, + is_upscale_in_train, + is_test, + increment, + epsilon, + src, + residual, + bias, + scale, + layernorm_bias, + mask, + dst, + layernorm_dst, + mean, + var); + } else { + FusedLayernormResidualDropoutBias + <<>>(rows, + cols, + seed, + dropout_prob, + is_upscale_in_train, + is_test, + increment, + epsilon, + src, + residual, + bias, + scale, + layernorm_bias, + mask, + dst, + layernorm_dst, + mean, + var); + } +} + /** * @brief layernorm(residual + dropout(src + bias)); * @param @@ -328,29 +414,32 @@ struct FusedLayernormResidualDropoutBiasFunctor { cudaStream_t stream) { int blockDim = GetDesiredBlockDim(cols / VecSize); if (mean != nullptr && var != nullptr) { - FusedLayernormResidualDropoutBias - <<>>(rows, - cols, - seed, - dropout_prob, - is_upscale_in_train, - is_test, - increment, - epsilon, - src, - residual, - bias, - scale, - layernorm_bias, - mask, - dst, - layernorm_dst, - mean, - var); + LaunchFusedLayernormResidualDropoutBiasCUDAKernel( + rows, + blockDim, + stream, + rows, + cols, + seed, + dropout_prob, + is_upscale_in_train, + is_test, + increment, + epsilon, + src, + residual, + bias, + scale, + layernorm_bias, + mask, + dst, + layernorm_dst, + mean, + var); } else { FusedLayernormResidualDropoutBiasInfer(dropout_prob, is_upscale_in_train, is_test); @@ -510,7 +602,7 @@ __global__ 
__launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel( } MaskStoreT mask_vec[LDGS]; - if (!is_test) { + if (!is_test && HasDropout) { #pragma unroll for (int it = 0; it < LDGS; it++) { float rand[VecSize]; @@ -585,7 +677,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel( x[it], residual_out_ptr + row * ELTS_PER_ROW + col * VecSize); col += THREADS_PER_ROW; } - if (!is_test) { + if (!is_test && HasDropout) { #pragma unroll for (int it = 0, col = c; it < LDGS; it++) { phi::Store( @@ -787,62 +879,109 @@ void LaunchLayernormResidualDropoutBias( return; } -#define LAUNCH_FUSED_FAST_LN_KERNEL_BASE(cols) \ - case (cols): { \ - constexpr int WARPS_N = cols < 1024 ? 1 : (cols / 1024); \ - constexpr int WARPS_M = 4 / WARPS_N; \ - const int THREADS_PER_WARP = 32; \ - const int BYTES_PER_LDG = 16; \ - const int VecSize = BYTES_PER_LDG / sizeof(T); \ - const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M; \ - const int ROWS_PER_CTA = WARPS_M; \ - const int THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP; \ - const int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW * VecSize; \ - const int LDGS = cols / ELTS_PER_ROW_PER_CTA; \ - const int grid = \ - static_cast(std::ceil(rows / static_cast(ROWS_PER_CTA))); \ - fused_fast_ln_fwd_kernel< \ - T, \ - U, \ - LayerNormScaleBiasT, \ - uint8_t, \ - VecSize, \ - WARPS_M, \ - WARPS_N, \ - BYTES_PER_LDG, \ - cols, \ - THREADS_PER_WARP, \ - THREADS_PER_ROW, \ - THREADS_PER_CTA, \ - ROWS_PER_CTA, \ - ELTS_PER_ROW_PER_CTA, \ - LDGS, \ - InType, \ - OutType> \ - <<>>(rows, \ - cols, \ - seed, \ - dropout_prob, \ - is_upscale_in_train, \ - is_test, \ - increment, \ - epsilon, \ - src, \ - residual, \ - bias, \ - scale, \ - layernorm_bias, \ - mask_data, \ - mean, \ - var, \ - dst, \ - layernorm_dst, \ - quant_last_in_scale, \ - dequant_out_scale_data, \ - quant_next_in_scale, \ - quant_round_type, \ - quant_max_bound, \ - quant_min_bound); \ +#define LAUNCH_FUSED_FAST_LN_KERNEL_BASE(cols) \ + case (cols): { \ + constexpr int WARPS_N = cols < 1024 ? 
1 : (cols / 1024); \ + constexpr int WARPS_M = 4 / WARPS_N; \ + const int THREADS_PER_WARP = 32; \ + const int BYTES_PER_LDG = 16; \ + const int VecSize = BYTES_PER_LDG / sizeof(T); \ + const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M; \ + const int ROWS_PER_CTA = WARPS_M; \ + const int THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP; \ + const int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW * VecSize; \ + const int LDGS = cols / ELTS_PER_ROW_PER_CTA; \ + const int grid = \ + static_cast(std::ceil(rows / static_cast(ROWS_PER_CTA))); \ + if (dropout_prob != 0.0f) { \ + fused_fast_ln_fwd_kernel< \ + true, \ + T, \ + U, \ + LayerNormScaleBiasT, \ + uint8_t, \ + VecSize, \ + WARPS_M, \ + WARPS_N, \ + BYTES_PER_LDG, \ + cols, \ + THREADS_PER_WARP, \ + THREADS_PER_ROW, \ + THREADS_PER_CTA, \ + ROWS_PER_CTA, \ + ELTS_PER_ROW_PER_CTA, \ + LDGS, \ + InType, \ + OutType> \ + <<>>(rows, \ + cols, \ + seed, \ + dropout_prob, \ + is_upscale_in_train, \ + is_test, \ + increment, \ + epsilon, \ + src, \ + residual, \ + bias, \ + scale, \ + layernorm_bias, \ + mask_data, \ + mean, \ + var, \ + dst, \ + layernorm_dst, \ + quant_last_in_scale, \ + dequant_out_scale_data, \ + quant_next_in_scale, \ + quant_round_type, \ + quant_max_bound, \ + quant_min_bound); \ + } else { \ + fused_fast_ln_fwd_kernel< \ + false, \ + T, \ + U, \ + LayerNormScaleBiasT, \ + uint8_t, \ + VecSize, \ + WARPS_M, \ + WARPS_N, \ + BYTES_PER_LDG, \ + cols, \ + THREADS_PER_WARP, \ + THREADS_PER_ROW, \ + THREADS_PER_CTA, \ + ROWS_PER_CTA, \ + ELTS_PER_ROW_PER_CTA, \ + LDGS, \ + InType, \ + OutType> \ + <<>>(rows, \ + cols, \ + seed, \ + dropout_prob, \ + is_upscale_in_train, \ + is_test, \ + increment, \ + epsilon, \ + src, \ + residual, \ + bias, \ + scale, \ + layernorm_bias, \ + mask_data, \ + mean, \ + var, \ + dst, \ + layernorm_dst, \ + quant_last_in_scale, \ + dequant_out_scale_data, \ + quant_next_in_scale, \ + quant_round_type, \ + quant_max_bound, \ + quant_min_bound); \ + } \ } break #define LAUNCH_FUSED_FAST_LN_KERNEL \ @@ -866,26 +1005,32 @@ void LaunchLayernormResidualDropoutBias( const int VecSize = MAX_CACHE_BYTES / sizeof(T); if (cols % VecSize != 0) { int blockDim = GetDesiredBlockDim(cols); - FusedLayernormResidualDropoutBias - <<>>( - rows, - cols, - seed, - dropout_prob, - is_upscale_in_train, - is_test, - increment, - epsilon, - reinterpret_cast(src), - residual, - bias, - scale, - layernorm_bias, - mask_data, - dst, - reinterpret_cast(layernorm_dst), - mean, - var); + LaunchFusedLayernormResidualDropoutBiasCUDAKernel( + rows, + blockDim, + ctx.stream(), + rows, + cols, + seed, + dropout_prob, + is_upscale_in_train, + is_test, + increment, + epsilon, + reinterpret_cast(src), + residual, + bias, + scale, + layernorm_bias, + mask_data, + dst, + reinterpret_cast(layernorm_dst), + mean, + var); } else { if (can_call_fast_ln_kernel) { switch (cols) { @@ -898,30 +1043,32 @@ void LaunchLayernormResidualDropoutBias( } } else { int blockDim = GetDesiredBlockDim(cols / VecSize); - FusedLayernormResidualDropoutBias - <<>>( - rows, - cols, - seed, - dropout_prob, - is_upscale_in_train, - is_test, - increment, - epsilon, - reinterpret_cast(src), - residual, - bias, - scale, - layernorm_bias, - mask_data, - dst, - reinterpret_cast(layernorm_dst), - mean, - var); + LaunchFusedLayernormResidualDropoutBiasCUDAKernel( + rows, + blockDim, + ctx.stream(), + rows, + cols, + seed, + dropout_prob, + is_upscale_in_train, + is_test, + increment, + epsilon, + reinterpret_cast(src), + residual, + bias, + scale, + layernorm_bias, + 
mask_data, + dst, + reinterpret_cast(layernorm_dst), + mean, + var); } } } diff --git a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias_test.cu b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias_test.cu index f383d6846f946..f4353dd9bd4d7 100644 --- a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias_test.cu +++ b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias_test.cu @@ -290,7 +290,7 @@ struct TestFusedLayernormResidualDropoutBias { framework::TensorToVector(layernorm_out, *ctx, &_layernorm_out); framework::TensorToVector(means, *ctx, &_means); framework::TensorToVector(vars, *ctx, &_vars); - if (!is_test) { + if (!is_test && dropout_prob != 0.0f) { framework::TensorToVector(mask, *ctx, &_mask); } ctx->Wait(); @@ -298,7 +298,9 @@ struct TestFusedLayernormResidualDropoutBias { for (int i = 0; i < n; i++) { EXPECT_LT(std::abs(_out[i] - correct_out[i]), diff); EXPECT_LT(std::abs(_layernorm_out[i] - correct_layernorm_out[i]), diff); - if (!is_test) EXPECT_EQ(_mask[i], correct_mask[i]); + if (!is_test && dropout_prob != 0.0f) { + EXPECT_EQ(_mask[i], correct_mask[i]); + } } for (int i = 0; i < rows; i++) { EXPECT_LT(std::abs(_means[i] - correct_means[i]), static_cast(diff)); diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h index 972bbe3326a5d..43c25bcaf10cb 100644 --- a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h @@ -30,7 +30,8 @@ template + typename OutType = T, + bool HasDropout = true> __forceinline__ __device__ void FusedResidualDropoutBiasOneThread( const int row_id, const int col_id, @@ -84,7 +85,7 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread( } MaskStoreT mask_vec; - if (!is_test) { + if (!is_test && HasDropout) { float rand[VecSize]; RandVec(state, rand); #pragma unroll @@ -114,8 +115,12 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread( if (Activation) { tmp = act_func(tmp); } - dest_vec[ii] = - tmp * static_cast(mask_vec[ii]) * factor + residual_vec[ii]; + if (HasDropout) { + dest_vec[ii] = + tmp * static_cast(mask_vec[ii]) * factor + residual_vec[ii]; + } else { + dest_vec[ii] = tmp * factor + residual_vec[ii]; + } if (ComputeLayerNorm) { U tmp = static_cast(dest_vec[ii]); *mean_val += tmp; @@ -138,7 +143,7 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread( phi::Store(dest_vec, reinterpret_cast(&dst[row_id * cols + col_id])); } - if (!is_test) { + if (!is_test && HasDropout) { phi::Store(mask_vec, &mask[row_id * cols + col_id]); } } @@ -154,7 +159,8 @@ template + typename OutType = T, + bool HasDropout = true> __global__ void FusedResidualDropoutBias( const size_t rows, const size_t cols, @@ -175,8 +181,15 @@ __global__ void FusedResidualDropoutBias( int row_id = blockIdx.y; int idx = row_id * cols + col_id; curandStatePhilox4_32_10_t state; - curand_init(seed, idx, increment, &state); - const T factor = GetFactor(dropout_prob, is_upscale_in_train, is_test); + if (HasDropout) { + curand_init(seed, idx, increment, &state); + } + T factor; + if (HasDropout) { + factor = GetFactor(dropout_prob, is_upscale_in_train, is_test); + } else { + factor = static_cast(1); + } phi::funcs::ReluFunctor relu; for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) { for (int i = col_id * VecSize; i < cols; @@ -188,24 +201,25 @@ __global__ void FusedResidualDropoutBias( false, phi::funcs::ReluFunctor, 
InType, - OutType>(r, - i, - cols, - &state, - dropout_prob, - factor, - src, - residual, - bias, - dst, - mask, - is_test, - nullptr, - nullptr, - relu, - quant_last_in_scale, - dequant_out_scale_data, - quant_next_in_scale); + OutType, + HasDropout>(r, + i, + cols, + &state, + dropout_prob, + factor, + src, + residual, + bias, + dst, + mask, + is_test, + nullptr, + nullptr, + relu, + quant_last_in_scale, + dequant_out_scale_data, + quant_next_in_scale); } } } @@ -256,43 +270,64 @@ void LaunchResidualDropoutBias(const uint32_t rows, const int VecSize = MAX_CACHE_BYTES / sizeof(T); const int real_vec_size = cols % VecSize == 0 ? VecSize : 1; auto config = Get1DBlocksAnd2DGrids(ctx, rows, cols, real_vec_size); - if (cols % VecSize == 0) { - FusedResidualDropoutBias - <<>>( - rows, - cols, - seed, - dropout_prob, - is_upscale_in_train, - src, - residual, - bias, - mask_data, - dst, - increment, - is_test, - quant_last_in_scale, - dequant_out_scale_data, - quant_next_in_scale); + +#define PD_LAUNCH_FUSED_RESIDUAL_DROPOUT_BIAS_KERNEL(__has_dropout) \ + do { \ + if (cols % VecSize == 0) { \ + FusedResidualDropoutBias \ + <<>>(rows, \ + cols, \ + seed, \ + dropout_prob, \ + is_upscale_in_train, \ + src, \ + residual, \ + bias, \ + mask_data, \ + dst, \ + increment, \ + is_test, \ + quant_last_in_scale, \ + dequant_out_scale_data, \ + quant_next_in_scale); \ + } else { \ + FusedResidualDropoutBias \ + <<>>(rows, \ + cols, \ + seed, \ + dropout_prob, \ + is_upscale_in_train, \ + src, \ + residual, \ + bias, \ + mask_data, \ + dst, \ + increment, \ + is_test, \ + quant_last_in_scale, \ + dequant_out_scale_data, \ + quant_next_in_scale); \ + } \ + } while (0) + + if (dropout_prob != 0.0f) { + PD_LAUNCH_FUSED_RESIDUAL_DROPOUT_BIAS_KERNEL(true); } else { - FusedResidualDropoutBias - <<>>( - rows, - cols, - seed, - dropout_prob, - is_upscale_in_train, - src, - residual, - bias, - mask_data, - dst, - increment, - is_test, - quant_last_in_scale, - dequant_out_scale_data, - quant_next_in_scale); + PD_LAUNCH_FUSED_RESIDUAL_DROPOUT_BIAS_KERNEL(false); } + +#undef PD_LAUNCH_FUSED_RESIDUAL_DROPOUT_BIAS_KERNEL } /* @@ -334,7 +369,8 @@ template + int VecSize, + bool HasDropout> __global__ void FusedResidualDropoutBiasGrad(const T *dout, const MaskType *mask, const T factor, @@ -350,6 +386,9 @@ __global__ void FusedResidualDropoutBiasGrad(const T *dout, T tmp_sum[VecSize] = {static_cast(0)}; // calculate the dx and temporary sum + const bool not_need_dx = (dx == nullptr) || (dx == dout && !HasDropout && + factor == static_cast(1.0)); + if (col_id * VecSize < cols) { for (int row_id = threadIdx.y; row_id < rows; row_id += blockDim.y) { int index = row_id * cols + col_id * VecSize; @@ -357,15 +396,27 @@ __global__ void FusedResidualDropoutBiasGrad(const T *dout, MaskLoadT mask_vec; StoreT dx_vec; phi::Load(&dout[index], &out_vec); - phi::Load(&mask[index], &mask_vec); + if (HasDropout) { + phi::Load(&mask[index], &mask_vec); + } + if (not_need_dx) { +#pragma unroll + for (int i = 0; i < VecSize; i++) { + tmp_sum[i] += out_vec[i]; + } + } else { #pragma unroll - for (int i = 0; i < VecSize; i++) { - dx_vec[i] = out_vec[i] * static_cast(mask_vec[i]) * factor; - tmp_sum[i] += out_vec[i]; + for (int i = 0; i < VecSize; i++) { + if (HasDropout) { + dx_vec[i] = out_vec[i] * static_cast(mask_vec[i]) * factor; + } else { + dx_vec[i] = out_vec[i] * factor; + } + tmp_sum[i] += out_vec[i]; + } + phi::Store(dx_vec, &dx[index]); } - - phi::Store(dx_vec, &dx[index]); } } @@ -395,35 +446,68 @@ void 
LaunchResidualDropoutBiasGrad(const T *dout, const int VecSize = MAX_CACHE_BYTES / sizeof(T); int real_vec_size = cols % VecSize == 0 ? VecSize : 1; - if (dbias != nullptr) { - const auto threads = 8; - auto blocks = std::max(static_cast(1), - (cols / real_vec_size + threads - 1) / threads); - dim3 block_dim(threads, 128, 1); - dim3 grid_dim(blocks, 1, 1); - if (cols % VecSize == 0) { - FusedResidualDropoutBiasGrad - <<>>( - dout, mask, factor, rows, cols, dx, dbias); - } else { - FusedResidualDropoutBiasGrad - <<>>( - dout, mask, factor, rows, cols, dx, dbias); - } + +#define PD_LAUNCH_FUSED_RESIDUAL_DROPOUT_BIAS_GRAD_KERNEL(__has_dropout) \ + do { \ + if (dbias != nullptr) { \ + const auto threads = 8; \ + auto blocks = std::max(static_cast(1), \ + (cols / real_vec_size + threads - 1) / threads); \ + dim3 block_dim(threads, 128, 1); \ + dim3 grid_dim(blocks, 1, 1); \ + if (cols % VecSize == 0) { \ + FusedResidualDropoutBiasGrad \ + <<>>( \ + dout, mask, factor, rows, cols, dx, dbias); \ + } else { \ + FusedResidualDropoutBiasGrad \ + <<>>( \ + dout, mask, factor, rows, cols, dx, dbias); \ + } \ + } else { \ + if (dropout_prob == 0.0f) { \ + if (dx == nullptr || dx == dout) { \ + return; \ + } \ + memory::Copy(ctx.GetPlace(), \ + dx, \ + ctx.GetPlace(), \ + dout, \ + rows *cols * sizeof(T), \ + ctx.stream()); \ + } else { \ + const uint64_t n = rows * cols; \ + platform::GpuLaunchConfig config = \ + platform::GetGpuLaunchConfig1D(ctx, n / real_vec_size); \ + if (n % VecSize == 0) { \ + FusedResidualDropoutGrad \ + <<>>(dout, mask, factor, n, dx); \ + } else { \ + FusedResidualDropoutGrad \ + <<>>(dout, mask, factor, n, dx); \ + } \ + } \ + } \ + } while (0) + + if (dropout_prob != 0.0f) { + PD_LAUNCH_FUSED_RESIDUAL_DROPOUT_BIAS_GRAD_KERNEL(true); } else { - const uint64_t n = rows * cols; - platform::GpuLaunchConfig config = - platform::GetGpuLaunchConfig1D(ctx, n / real_vec_size); - if (n % VecSize == 0) { - FusedResidualDropoutGrad - <<>>( - dout, mask, factor, n, dx); - } else { - FusedResidualDropoutGrad - <<>>( - dout, mask, factor, n, dx); - } + PD_LAUNCH_FUSED_RESIDUAL_DROPOUT_BIAS_GRAD_KERNEL(false); } + +#undef PD_LAUNCH_FUSED_RESIDUAL_DROPOUT_BIAS_GRAD_KERNEL } } // namespace operators diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu b/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu index 79eb5f64cf0ec..d74bd9bfe1750 100644 --- a/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu @@ -258,14 +258,14 @@ struct FusedResidualDropoutBiasTester { std::vector fused_out(n); std::vector fused_mask(n); framework::TensorToVector(out, *ctx, &fused_out); - if (!is_test) { + if (!is_test && dropout_prob != 0.0f) { framework::TensorToVector(mask, *ctx, &fused_mask); } ctx->Wait(); for (int i = 0; i < n; i++) { EXPECT_LT(std::abs(fused_out[i] - correct_out[i]), diff); - if (!is_test) { + if (!is_test && dropout_prob != 0.0f) { EXPECT_EQ(fused_mask[i], correct_mask[i]); } } diff --git a/paddle/fluid/operators/layer_norm_kernel.cu.h b/paddle/fluid/operators/layer_norm_kernel.cu.h index 3d1bd7490795d..30aa74978354e 100644 --- a/paddle/fluid/operators/layer_norm_kernel.cu.h +++ b/paddle/fluid/operators/layer_norm_kernel.cu.h @@ -502,7 +502,8 @@ __inline__ __device__ void cuLoadAddStridedInputs(const int64_t i1_block, } #ifdef PADDLE_WITH_CUDA -template (0), T *d_dropout_src_ptr = nullptr) { + static_assert( + !IsFusedDropoutResidualLn || NeedDDropoutSrcPtr, + "When 
IsFusedDropoutResidualLn = true, NeedDDropoutSrcPtr must be true."); + using Vec = phi::AlignedVector; using Vec_scale = phi::AlignedVector; using MaskLoadT = phi::AlignedVector; @@ -586,7 +591,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_fast_kernel( phi::Load(dout_ptr + row * ELTS_PER_ROW + col * VecSize, &dout[it]); phi::Load(x_ptr + row * ELTS_PER_ROW + col * VecSize, &x[it]); - if (isFusedDropoutResidualLn) { + if (IsFusedDropoutResidualLn) { phi::Load( mask_ptr + row * ELTS_PER_ROW + col * VecSize, &mask_vec[it]); } @@ -672,7 +677,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_fast_kernel( U dx_tmp = var_cur_row * (dy_tmp - sum_loss2 * y_tmp - sum_loss1); // Note: reuse x and dout vec register to store dx and d_dropout_src. x[it][jt] = static_cast(dx_tmp); - if (isFusedDropoutResidualLn) { + if (IsFusedDropoutResidualLn) { dout[it][jt] = x[it][jt] * static_cast(mask_vec[it][jt]) * factor; } } @@ -684,9 +689,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_fast_kernel( for (int it = 0; it < LDGS; it++) { phi::Store(x[it], dx_ptr + row * ELTS_PER_ROW + col * VecSize); - if (isFusedDropoutResidualLn) { + if (IsFusedDropoutResidualLn) { phi::Store( dout[it], d_dropout_src_ptr + row * ELTS_PER_ROW + col * VecSize); + } else if (NeedDDropoutSrcPtr) { + phi::Store( + x[it], d_dropout_src_ptr + row * ELTS_PER_ROW + col * VecSize); } col += THREADS_PER_ROW; } @@ -956,6 +964,7 @@ void ln_bwd_fast_kernel_driver(const phi::GPUContext &dev_ctx, } #define LAUNCH_MASK_FUSED_LN_BWD_FAST_KERNEL(vec_size, ele_per_row) \ fused_ln_bwd_fast_kernel
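
The recurring pattern in this patch is to hoist the runtime question "is dropout active?" into a compile-time template parameter (the `HasDropout` bool added to `FusedResidualDropoutBiasOneThread`, `FusedLayernormResidualDropoutBias`, `fused_fast_ln_fwd_kernel`, and `FusedResidualDropoutBiasGrad`), with the host launcher dispatching on `dropout_prob != 0.0f`. The zero-dropout instantiation then compiles out `curand_init`, mask generation, and all mask loads/stores. Below is a minimal sketch of that dispatch, with hypothetical names and plain `float` data instead of Paddle's vectorized load/store helpers; it illustrates the technique only, not the actual kernels in the patch.

```cuda
#include <cstdint>
#include <curand_kernel.h>

template <bool HasDropout>
__global__ void ResidualDropoutKernel(const float *src, const float *residual,
                                      float *dst, uint8_t *mask, size_t n,
                                      float dropout_prob, uint64_t seed,
                                      uint64_t increment) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= n) return;

  curandStatePhilox4_32_10_t state;
  if (HasDropout) {
    // Philox initialization is compiled out of the p == 0 instantiation.
    curand_init(seed, idx, increment, &state);
  }

  if (HasDropout) {
    const float factor = 1.0f / (1.0f - dropout_prob);  // upscale_in_train scaling
    const uint8_t keep = curand_uniform(&state) >= dropout_prob;
    mask[idx] = keep;  // the mask buffer is only touched on this path
    dst[idx] = src[idx] * static_cast<float>(keep) * factor + residual[idx];
  } else {
    dst[idx] = src[idx] + residual[idx];  // no RNG, no mask traffic
  }
}

void LaunchResidualDropout(const float *src, const float *residual, float *dst,
                           uint8_t *mask, size_t n, float dropout_prob,
                           uint64_t seed, uint64_t increment,
                           cudaStream_t stream) {
  const int block = 256;
  const int grid = static_cast<int>((n + block - 1) / block);
  if (dropout_prob != 0.0f) {
    ResidualDropoutKernel<true><<<grid, block, 0, stream>>>(
        src, residual, dst, mask, n, dropout_prob, seed, increment);
  } else {
    // mask may legitimately be nullptr on this path; the kernel never reads it.
    ResidualDropoutKernel<false><<<grid, block, 0, stream>>>(
        src, residual, dst, /*mask=*/nullptr, n, dropout_prob, seed, increment);
  }
}
```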
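
On the host side the same flag gates allocation: the attention-dropout output/mask and the residual-dropout mask are allocated only when the corresponding probability is non-zero, and the resulting pointers stay `nullptr` otherwise. In the same spirit, `fmha_ref.h` now calls `dropout_out_tensor->data()` only after the branch that actually allocates and fills that tensor. A small framework-free sketch of the pattern, with hypothetical names and `std::vector` standing in for device allocation:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

struct Workspace {
  std::vector<float> attn_dropout_out;
  std::vector<uint8_t> attn_dropout_mask;
};

struct DropoutBuffers {
  float *out = nullptr;
  uint8_t *mask = nullptr;
};

DropoutBuffers PrepareAttnDropout(Workspace *ws, size_t numel,
                                  float dropout_prob) {
  DropoutBuffers buf;
  const bool has_dropout = (dropout_prob != 0.0f);
  if (has_dropout) {
    // Allocate first, then take the raw pointer -- mirroring the fmha_ref.h
    // change that moves dropout_out_tensor->data() below the allocating branch.
    ws->attn_dropout_out.resize(numel);
    ws->attn_dropout_mask.resize(numel);
    buf.out = ws->attn_dropout_out.data();
    buf.mask = ws->attn_dropout_mask.data();
  }
  return buf;  // both pointers remain nullptr when dropout_prob == 0
}
```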
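
In the backward pass, zero dropout makes the dropout gradient the identity, so the launcher can avoid the masked kernel entirely: `FusedResidualDropoutBiasGrad`'s `not_need_dx` check skips writing `dx` when it aliases `dout`, and the bias-free path degenerates to a plain device copy (the `memory::Copy` branch). A compilable sketch of that shortcut, again with hypothetical names and `float` data:

```cuda
#include <cstddef>
#include <cstdint>
#include <cuda_runtime.h>

__global__ void MaskedScaleKernel(const float *dout, const uint8_t *mask,
                                  float factor, size_t n, float *dx) {
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) dx[i] = dout[i] * static_cast<float>(mask[i]) * factor;
}

void ResidualDropoutGrad(const float *dout, const uint8_t *mask, float *dx,
                         size_t n, float dropout_prob, bool upscale_in_train,
                         cudaStream_t stream) {
  if (dropout_prob == 0.0f) {
    // d(dropout(x))/dx is the identity when p == 0: either nothing to do
    // (dx aliases dout, cf. the not_need_dx check) or a plain device copy.
    if (dx == nullptr || dx == dout) return;
    cudaMemcpyAsync(dx, dout, n * sizeof(float), cudaMemcpyDeviceToDevice,
                    stream);
    return;
  }
  const float factor =
      upscale_in_train ? 1.0f / (1.0f - dropout_prob) : 1.0f;
  const int block = 256;
  const int grid = static_cast<int>((n + block - 1) / block);
  MaskedScaleKernel<<<grid, block, 0, stream>>>(dout, mask, factor, n, dx);
}
```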
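
Two smaller pieces of the patch serve the same goal of shrinking forward-pass memory. First, the backward kernel never reads `SrcMask`'s contents (the unused `src_mask_data` pointer is deleted), so `SrcMask` is marked as a no-need-buffer input on both paths: `SetTensorWrapperSrcMask` constructs its `egr::TensorWrapper` with the flag set to `true` for eager mode, and `"SrcMask"` is added to `FusedAttentionGradNoNeedBufferInferer` for static graph, which should let the framework release that buffer after the forward op while keeping shape/dtype metadata for gradient wiring. Second, because `AttnDropoutOut` is never allocated when attention dropout is disabled, the eager forward builds its accumulation node and grad meta only when `AttnDropoutOut.initialized()` holds, and `FusedAttentionGradOp`'s shape inference sets the dims of `AttnDropoutOut@GRAD` only when that grad output actually exists.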