Add fused attention op backward and python layer. #36498

Merged 62 commits on Oct 26, 2021
Changes from 52 commits
Commits
f5eee9f
Add fused_attention_op: add impl wrappers.
limin2021 Sep 22, 2021
e16e3b3
Add fused_attention_op: forward.
limin2021 Sep 22, 2021
42f0372
Add fused_attention_op: forward impl.
limin2021 Sep 22, 2021
c6aebef
Remove useless code.
limin2021 Sep 22, 2021
2c0ab6c
Remove useless code.
limin2021 Sep 22, 2021
ece3c08
Remove docs.
limin2021 Sep 22, 2021
b18b405
Minors.
limin2021 Sep 22, 2021
b939159
Minors.
limin2021 Sep 23, 2021
07fd753
Update test_fused_attention_op.py
limin2021 Sep 23, 2021
ef89a94
Merge branch 'PaddlePaddle:develop' into fused_attention_op_2_fw
limin2021 Sep 23, 2021
b44d882
Remove static construction of python api.
limin2021 Sep 23, 2021
ff3df46
Modifications accordding to reviews.
limin2021 Sep 23, 2021
8a4c2a8
Modifications accordding to Xreki's review.
limin2021 Sep 26, 2021
739d9ca
Modifications unittest/cmakefile.txt.
limin2021 Sep 27, 2021
1d9e125
Fetch new fused_dropout_helper.h from #35843.
limin2021 Sep 27, 2021
4dd4260
Remove include fused_attention_op.h.
limin2021 Sep 27, 2021
2e3f4f2
Polish names of variants.
limin2021 Sep 27, 2021
f17c444
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
limin2021 Sep 29, 2021
13d4ff3
Revert "Polish names of variants."
limin2021 Oct 13, 2021
300ec35
Revert "Modifications accordding to Xreki's review."
limin2021 Oct 13, 2021
7b28f7c
Move fused_multi_head_attention from common.py.
limin2021 Oct 13, 2021
30fef54
Modify copyright and names with number.
limin2021 Oct 14, 2021
766ef85
Remove HIP and use OpTest and remove print.
limin2021 Oct 14, 2021
0bc03a6
Minors.
limin2021 Oct 14, 2021
99e36f9
Polish functional.fused_attention_op.
limin2021 Oct 14, 2021
2d9f727
Minors.
limin2021 Oct 14, 2021
f35b3c7
Remove commits of tools/__pycache__/.
limin2021 Oct 14, 2021
1433ba6
Minors.
limin2021 Oct 14, 2021
cf7be13
Add english doc for functional.fused_multi_head_attention
limin2021 Oct 14, 2021
ae875ca
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
limin2021 Oct 14, 2021
af08e6d
Add fused_attention bw and layer defs.
limin2021 Oct 18, 2021
eefe50a
Minors.
limin2021 Oct 18, 2021
10687a6
Add "#require gpu" for sample code in english doc.
limin2021 Oct 21, 2021
7a4c5ca
Merge remote-tracking branch 'origin/fused_attention_op_2_fw' into fu…
limin2021 Oct 21, 2021
56709c0
Add "require gpu" for layer's sample code.
limin2021 Oct 21, 2021
ea15676
Remove skip in unittest.
limin2021 Oct 21, 2021
13d25cd
Improve format of sample code.
limin2021 Oct 21, 2021
0f93775
Improve format of sample code.
limin2021 Oct 21, 2021
b5e01db
Merge remote-tracking branch 'origin/fused_attention_op_2_fw' into fu…
limin2021 Oct 21, 2021
92672c4
Add assert for unsupported parameters.
limin2021 Oct 21, 2021
3c4ce92
Minors.
limin2021 Oct 22, 2021
647ad38
Add assert for qkv_weight's shape and dim.
limin2021 Oct 22, 2021
99e210b
Merge branch 'develop' into fused_attention_bw
limin2021 Oct 22, 2021
b8a802e
Recover cmakefiles for typos.
limin2021 Oct 22, 2021
74338f8
Minors.
limin2021 Oct 22, 2021
6a79464
Improve english doc.
limin2021 Oct 24, 2021
8500174
Fix bugs in english doc.
limin2021 Oct 24, 2021
3cff8fe
Expose FusedMultiHeadAttention api in paddle.nn
limin2021 Oct 24, 2021
c1ce0c2
Add api to all[] and polish english docs.
limin2021 Oct 25, 2021
ffa11b5
Merge branch 'develop' into fused_attention_bw
limin2021 Oct 25, 2021
c65617c
Minors.
limin2021 Oct 25, 2021
43666eb
Minors.
limin2021 Oct 25, 2021
1f94fba
Move fused_attention function api path to incubate.
limin2021 Oct 25, 2021
abc2d20
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
limin2021 Oct 25, 2021
da2e6e8
Move fused_feedforward functional api path to incubate.
limin2021 Oct 25, 2021
6680652
Remove functional api in incubate/__init__.py.
limin2021 Oct 25, 2021
617a647
rm incubate/layer/fused_transformer.py.
limin2021 Oct 25, 2021
3be3a8e
Merge branch 'modify_fused_attention_functional_api_path' of https://…
limin2021 Oct 25, 2021
753abc5
Modify functional api path in sample code.
limin2021 Oct 25, 2021
dd16512
Merge branch 'modify_fused_attention_functional_api_path' of https://…
limin2021 Oct 25, 2021
2b3c379
Add FusedMultiHeadAttention api in incubate/nn/__init__.py
limin2021 Oct 26, 2021
5f54a0f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
limin2021 Oct 26, 2021
199 changes: 198 additions & 1 deletion paddle/fluid/operators/fused/fused_attention_op.cc
@@ -328,9 +328,206 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
  }
};

class FusedAttentionGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE_EQ(
        ctx->Attrs().Get<bool>("attn_dropout_is_test"), false,
        platform::errors::InvalidArgument(
            "GradOp is only callable when attn_dropout_is_test is false"));

    OP_INOUT_CHECK(ctx->HasInput("Ln2Mean"), "Input", "Ln2Mean",
                   "FusedAttentionGrad");
    OP_INOUT_CHECK(ctx->HasInput("Ln2Variance"), "Input", "Ln2Variance",
                   "FusedAttentionGrad");
    if (ctx->HasOutput(framework::GradVarName("Ln2Scale"))) {
      ctx->SetOutputDim(framework::GradVarName("Ln2Scale"),
                        ctx->GetInputDim("Ln2Scale"));
    }
    if (ctx->HasOutput(framework::GradVarName("Ln2Bias"))) {
      ctx->SetOutputDim(framework::GradVarName("Ln2Bias"),
                        ctx->GetInputDim("Ln2Bias"));
    }
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad");
    OP_INOUT_CHECK(ctx->HasInput("LnMean"), "Input", "LnMean",
                   "FusedAttentionGrad");
    OP_INOUT_CHECK(ctx->HasInput("LnVariance"), "Input", "LnVariance",
                   "FusedAttentionGrad");
    if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
      OP_INOUT_CHECK(ctx->HasInput("LnOut"), "Input", "LnOut",
                     "FusedAttentionGrad");
    }
    OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW",
                   "FusedAttentionGrad");
    OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias",
                   "FusedAttentionGrad");
    OP_INOUT_CHECK(ctx->HasInput("SrcMask"), "Input", "SrcMask",
                   "FusedAttentionGrad");
    OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW",
                   "FusedAttentionGrad");
    OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias",
                   "FusedAttentionGrad");

    if (ctx->HasOutput(framework::GradVarName("LnScale"))) {
      ctx->SetOutputDim(framework::GradVarName("LnScale"),
                        ctx->GetInputDim("LnScale"));
    }
    if (ctx->HasOutput(framework::GradVarName("LnBias"))) {
      ctx->SetOutputDim(framework::GradVarName("LnBias"),
                        ctx->GetInputDim("LnBias"));
    }
    if (ctx->HasOutput(framework::GradVarName("X"))) {
      ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
    }

    ctx->SetOutputDim(framework::GradVarName("OutLinearBias"),
                      ctx->GetInputDim("OutLinearBias"));
    ctx->SetOutputDim(framework::GradVarName("OutLinearW"),
                      ctx->GetInputDim("OutLinearW"));
    ctx->SetOutputDim(framework::GradVarName("QKVW"), ctx->GetInputDim("QKVW"));
    ctx->SetOutputDim(framework::GradVarName("QKVBias"),
                      ctx->GetInputDim("QKVBias"));

    ctx->SetOutputDim(framework::GradVarName("LnOut"),
                      ctx->GetInputDim("LnOut"));
    ctx->SetOutputDim(framework::GradVarName("FMHAOut"),
                      ctx->GetInputDim("FMHAOut"));
    ctx->SetOutputDim(framework::GradVarName("QKTVOut"),
                      ctx->GetInputDim("QKTVOut"));
    ctx->SetOutputDim(framework::GradVarName("TransposeOut2"),
                      ctx->GetInputDim("TransposeOut2"));
    ctx->SetOutputDim(framework::GradVarName("QKOut"),
                      ctx->GetInputDim("QKOut"));
    ctx->SetOutputDim(framework::GradVarName("SoftmaxOut"),
                      ctx->GetInputDim("SoftmaxOut"));
    ctx->SetOutputDim(framework::GradVarName("AttnDropoutOut"),
                      ctx->GetInputDim("AttnDropoutOut"));
    ctx->SetOutputDim(framework::GradVarName("SrcMaskOut"),
                      ctx->GetInputDim("SrcMaskOut"));
    ctx->SetOutputDim(framework::GradVarName("QKVOut"),
                      ctx->GetInputDim("QKVOut"));
    ctx->SetOutputDim(framework::GradVarName("QKVBiasOut"),
                      ctx->GetInputDim("QKVBiasOut"));
    ctx->SetOutputDim(framework::GradVarName("OutLinearOut"),
                      ctx->GetInputDim("OutLinearOut"));
    ctx->SetOutputDim(framework::GradVarName("BiasDropoutResidualOut"),
                      ctx->GetInputDim("BiasDropoutResidualOut"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto input = ctx.Input<Tensor>("X");
    auto input_data_type = input->type();
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }
};

template <typename T>
class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("fused_attention_grad");
    op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));

    // inputs x, parameters and their grad.
    op->SetInput("X", this->Input("X"));
    op->SetInput("QKVW", this->Input("QKVW"));
    op->SetInput("QKVBias", this->Input("QKVBias"));
    op->SetInput("SrcMask", this->Input("SrcMask"));
    op->SetInput("OutLinearW", this->Input("OutLinearW"));
    op->SetInput("OutLinearBias", this->Input("OutLinearBias"));
    if (this->HasInput("LnScale")) {
      op->SetInput("LnScale", this->Input("LnScale"));
      op->SetOutput(framework::GradVarName("LnScale"),
                    this->InputGrad("LnScale"));
    }
    if (this->HasInput("LnBias")) {
      op->SetInput("LnBias", this->Input("LnBias"));
      op->SetOutput(framework::GradVarName("LnBias"),
                    this->InputGrad("LnBias"));
    }
    if (this->HasInput("Ln2Scale")) {
      op->SetInput("Ln2Scale", this->Input("Ln2Scale"));
      op->SetOutput(framework::GradVarName("Ln2Scale"),
                    this->InputGrad("Ln2Scale"));
    }
    if (this->HasInput("Ln2Bias")) {
      op->SetInput("Ln2Bias", this->Input("Ln2Bias"));
      op->SetOutput(framework::GradVarName("Ln2Bias"),
                    this->InputGrad("Ln2Bias"));
    }

    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetOutput(framework::GradVarName("QKVW"), this->InputGrad("QKVW"));
    op->SetOutput(framework::GradVarName("QKVBias"),
                  this->InputGrad("QKVBias"));
    op->SetOutput(framework::GradVarName("OutLinearBias"),
                  this->InputGrad("OutLinearBias"));
    op->SetOutput(framework::GradVarName("OutLinearW"),
                  this->InputGrad("OutLinearW"));

    // use forward outputs as backward inputs.
    op->SetInput("LnOut", this->Output("LnOut"));
    op->SetInput("LnMean", this->Output("LnMean"));
    op->SetInput("LnVariance", this->Output("LnVariance"));
    op->SetInput("QKVOut", this->Output("QKVOut"));
    op->SetInput("QKVBiasOut", this->Output("QKVBiasOut"));
    op->SetInput("TransposeOut2", this->Output("TransposeOut2"));
    op->SetInput("QKOut", this->Output("QKOut"));
    op->SetInput("QKTVOut", this->Output("QKTVOut"));
    op->SetInput("SoftmaxOut", this->Output("SoftmaxOut"));
    op->SetInput("AttnDropoutMaskOut", this->Output("AttnDropoutMaskOut"));
    op->SetInput("AttnDropoutOut", this->Output("AttnDropoutOut"));
    op->SetInput("SrcMaskOut", this->Output("SrcMaskOut"));
    op->SetInput("FMHAOut", this->Output("FMHAOut"));
    op->SetInput("OutLinearOut", this->Output("OutLinearOut"));

    op->SetInput("Ln2Mean", this->Output("Ln2Mean"));
    op->SetInput("Ln2Variance", this->Output("Ln2Variance"));
    op->SetInput("DropoutMaskOut", this->Output("DropoutMaskOut"));
    op->SetInput("BiasDropoutResidualOut",
                 this->Output("BiasDropoutResidualOut"));

    // backward outputs: dinput
    op->SetOutput(framework::GradVarName("LnOut"), this->OutputGrad("LnOut"));
    op->SetOutput(framework::GradVarName("QKVOut"), this->OutputGrad("QKVOut"));
    op->SetOutput(framework::GradVarName("QKVBiasOut"),
                  this->OutputGrad("QKVBiasOut"));
    op->SetOutput(framework::GradVarName("QKTVOut"),
                  this->OutputGrad("QKTVOut"));
    op->SetOutput(framework::GradVarName("TransposeOut2"),
                  this->OutputGrad("TransposeOut2"));
    op->SetOutput(framework::GradVarName("QKOut"), this->OutputGrad("QKOut"));
    op->SetOutput(framework::GradVarName("SoftmaxOut"),
                  this->OutputGrad("SoftmaxOut"));
    op->SetOutput(framework::GradVarName("AttnDropoutOut"),
                  this->OutputGrad("AttnDropoutOut"));
    op->SetOutput(framework::GradVarName("SrcMaskOut"),
                  this->OutputGrad("SrcMaskOut"));
    op->SetOutput(framework::GradVarName("FMHAOut"),
                  this->OutputGrad("FMHAOut"));
    op->SetOutput(framework::GradVarName("BiasDropoutResidualOut"),
                  this->OutputGrad("BiasDropoutResidualOut"));
    op->SetOutput(framework::GradVarName("OutLinearOut"),
                  this->OutputGrad("OutLinearOut"));

    op->SetAttrMap(this->Attrs());
  }
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(fused_attention, ops::FusedAttentionOp,
                  ops::FusedAttentionOpMaker,
                  ops::FusedAttentionGradOpMaker<paddle::framework::OpDesc>,
                  ops::FusedAttentionGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(fused_attention_grad, ops::FusedAttentionGradOp);
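
The InferShape method in this diff repeats one pattern many times: each gradient output takes the dims of its forward variable, and some of those assignments are guarded by HasOutput so that absent optional gradients are skipped. The standalone sketch below illustrates that pattern outside of Paddle. ToyShapeContext, Dims and MirrorGradDims are hypothetical names introduced only for this illustration, and the local GradVarName is a stand-in that assumes the framework's convention of appending "@GRAD" to a variable name. It is not the PR's implementation.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for a tensor shape; not a Paddle type.
using Dims = std::vector<int64_t>;

// Stand-in for framework::GradVarName, assuming the "@GRAD" suffix convention.
std::string GradVarName(const std::string &name) { return name + "@GRAD"; }

// Toy shape-inference context (hypothetical): forward variables with known
// dims, plus the gradient outputs that the grad op actually declares.
struct ToyShapeContext {
  std::map<std::string, Dims> input_dims;
  std::map<std::string, Dims> output_dims;

  bool HasInput(const std::string &name) const {
    return input_dims.count(name) > 0;
  }
  bool HasOutput(const std::string &name) const {
    return output_dims.count(name) > 0;
  }
};

// For every forward variable in `names`, copy its dims to the matching
// gradient output, but only when that gradient output is declared.
// Undeclared gradients are left untouched.
void MirrorGradDims(ToyShapeContext *ctx,
                    const std::vector<std::string> &names) {
  for (const auto &name : names) {
    const std::string grad_name = GradVarName(name);
    if (!ctx->HasOutput(grad_name)) {
      continue;  // optional gradient not requested; nothing to infer
    }
    // A declared gradient needs its forward variable to read dims from.
    assert(ctx->HasInput(name));
    ctx->output_dims[grad_name] = ctx->input_dims.at(name);
  }
}
```

To use the sketch, fill input_dims with forward shapes, register the requested gradient names as keys in output_dims, then call MirrorGradDims with the list of forward variable names. A table-driven loop like this is one possible way to shorten long runs of SetOutputDim calls. Whether it suits the real operator depends on which gradients must stay unconditional.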