diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 6f165b19eb5230..43ebe4e1196caf 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -277,6 +277,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     csinfo_.fused_depthwise_conv2d = "_FusedDepthwiseConv2dNative";
     csinfo_.fused_matmul = "_FusedMatMul";
     csinfo_.fused_matmul_grad = "_FusedMatMulGrad";
+    csinfo_.gelu = "Gelu";
+    csinfo_.gelu_grad = "GeluGrad";
     csinfo_.identity = "Identity";
     csinfo_.leakyrelu = "LeakyRelu";
     csinfo_.leakyrelu_grad = "LeakyReluGrad";
@@ -500,6 +502,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
                       CopyAttrsAll, AlwaysRewrite,
                       kRewriteForLayoutPropagation});
+    rinfo_.push_back({csinfo_.gelu, mkl_op_registry::GetMklOpName(csinfo_.gelu),
+                      CopyAttrsAll, GeluRewrite, kRewriteForLayoutPropagation});
+    rinfo_.push_back({csinfo_.gelu_grad,
+                      mkl_op_registry::GetMklOpName(csinfo_.gelu_grad),
+                      CopyAttrsAll, GeluRewrite, kRewriteForLayoutPropagation});
     rinfo_.push_back({csinfo_.identity,
                       mkl_op_registry::GetMklOpName(csinfo_.identity),
                       CopyAttrsAll, RewriteIfAtleastOneMklInput,
@@ -947,6 +954,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     string fused_depthwise_conv2d;
     string fused_matmul;
     string fused_matmul_grad;
+    string gelu;
+    string gelu_grad;
     string identity;
     string leakyrelu;
     string leakyrelu_grad;
@@ -1551,6 +1560,27 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     return false;
   }
 
+  // MKL-DNN's Gelu supports only the approximate (tanh) version, so we
+  // rewrite Gelu to the MKL op only when the 'approximate' attribute is true.
+  static bool GeluRewrite(const Node* n) {
+    DCHECK(n);
+
+    bool approximate = false;
+    bool has_attr = TryGetNodeAttr(n->def(), "approximate", &approximate);
+    DCHECK(has_attr);
+
+    // If approximate is true, rewrite the node.
+    // Otherwise the Eigen kernel is used instead.
+    if (approximate) {
+      return true;
+    }
+    VLOG(1) << "GeluRewrite: The model sets approximate to false, "
+            << "which is not optimized by Intel MKL; using the Eigen op "
+            << "for Gelu.";
+
+    return false;
+  }
+
   // If the depth_radius of LRN is not 2, then MKL DNN takes the unoptimized
   // path. The unoptimized path is slow. Thus we don't rewrite the node
   // and use default Eigen. But for depth_radius=2, MKL DNN optimized
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index cd34969aefe78f..59f9964724e1b1 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -2876,6 +2876,71 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Relu6Grad_Positive) {
       "A:control->DMT/_1:control;B->C:1;C->D:1;DMT/_0->C:2;DMT/_1->C:3");
 }
 
+#define REGISTER_TEST(NAME, T, INPUT)                            \
+  TEST_F(MklLayoutPassTest, NAME##_##T) {                        \
+    InitGraph("node { name: 'A' op: '" #INPUT                    \
+              "'}"                                               \
+              "node { name: 'B' op: 'Gelu'"                      \
+              " attr { key: 'T' value { type: " #T               \
+              " } }"                                             \
+              " attr { key: 'approximate' value { b: true } }"   \
+              " input: ['A']}"                                   \
+              "node { name: 'C' op: 'Zeta'"                      \
+              " attr { key: 'T' value { type: " #T               \
+              " } }"                                             \
+              " input: ['A', 'B'] }");                           \
+    EXPECT_EQ(DoMklLayoutOptimizationPass(),                     \
+              "A(" #INPUT                                        \
+              ");B(_MklGelu);C(Zeta);DMT/_0(Const)|A->B;A->C;"   \
+              "A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");   \
+  }
+REGISTER_TEST_ALL_TYPES(NodeRewrite_Gelu_Positive);
+#undef REGISTER_TEST
+
+#define REGISTER_TEST(NAME, T, INPUT)                            \
+  TEST_F(MklLayoutPassTest, NAME##_##T) {                        \
+    InitGraph("node { name: 'A' op: '" #INPUT                    \
+              "'}"                                               \
+              "node { name: 'B' op: 'Gelu'"                      \
+              " attr { key: 'T' value { type: " #T               \
+              " } }"                                             \
+              " attr { key: 'approximate' value { b: false } }"  \
+              " input: ['A']}"                                   \
+              "node { name: 'C' op: 'Zeta'"                      \
+              " attr { key: 'T' value { type: " #T               \
+              " } }"                                             \
+              " input: ['A', 'B'] }");                           \
+    EXPECT_EQ(DoMklLayoutOptimizationPass(),                     \
+              "A(" #INPUT ");B(Gelu);C(Zeta)|A->B;A->C;B->C:1"); \
+  }
+REGISTER_TEST_ALL_TYPES(NodeRewrite_Gelu_Negative);
+#undef REGISTER_TEST
+
+#define REGISTER_TEST(NAME, T, INPUT)                                        \
+  TEST_F(MklLayoutPassTest, NAME##_##T) {                                    \
+    InitGraph("node { name: 'A' op: '" #INPUT                                \
+              "'}"                                                           \
+              "node { name: 'B' op: '" #INPUT                                \
+              "'}"                                                           \
+              "node { name: 'C' op: 'GeluGrad'"                              \
+              " attr { key: 'T' value { type: " #T                           \
+              " } }"                                                         \
+              " attr { key: 'approximate' value { b: true } }"               \
+              " input: ['A', 'B']}"                                          \
+              "node { name: 'D' op: 'Zeta'"                                  \
+              " attr { key: 'T' value { type: " #T                           \
+              " } }"                                                         \
+              " input: ['A', 'C'] }");                                       \
+    EXPECT_EQ(                                                               \
+        DoMklLayoutOptimizationPass(),                                       \
+        "A(" #INPUT ");B(" #INPUT                                            \
+        ");C(_MklGeluGrad);D(Zeta);DMT/_0(Const);"                           \
+        "DMT/_1(Const)|A->C;A->D;A:control->DMT/_0:control;"                 \
+        "A:control->DMT/_1:control;B->C:1;C->D:1;DMT/_0->C:2;DMT/_1->C:3");  \
+  }
+REGISTER_TEST_ALL_TYPES(NodeRewrite_GeluGrad_Positive);
+#undef REGISTER_TEST
+
 TEST_F(MklLayoutPassTest, NodeRewrite_Relu6Relu6Grad_Positive) {
   InitGraph(
       "node { name: 'A' op: 'Input'}"
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index cafe5a1da04061..a233ce3e5d4415 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -8165,6 +8165,7 @@ tf_mkl_kernel_library(
     prefix = "mkl_relu",
    deps = [
         ":bounds_check",
+        ":no_op",
         ":ops_util",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
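For reference before the kernel changes: the scalar fallback paths added to mkl_relu_op.cc below compute the tanh-based Gelu approximation directly. Here is a minimal standalone sketch of the same math, which also spot-checks the gradient against a central finite difference; the `GeluTanh`/`GeluTanhGrad` names and the `main()` harness are illustrative only, not part of this patch:

```cpp
#include <cmath>
#include <cstdio>

// sqrt(2/pi); the kernels compute this as M_2_SQRTPI * M_SQRT1_2.
const double kAlpha = 0.7978845608028654;
// kAlpha * 0.044715 * 3, as in MklGeluGradOp.
const double kBeta = kAlpha * 0.044715 * 3.0;

// Gelu(x) ~= 0.5 * x * (1 + tanh(kAlpha * (x + 0.044715 * x^3)))
double GeluTanh(double x) {
  return 0.5 * x * (1.0 + std::tanh(kAlpha * (x + 0.044715 * x * x * x)));
}

// Analytic derivative, arranged the same way as MklGeluGradOp::Compute_Scalar:
// with y = tanh(u), dGelu/dx = 0.5 * ((x - x*y*y) * (kBeta*x*x + kAlpha) + 1 + y).
double GeluTanhGrad(double x) {
  const double y = std::tanh(kAlpha * (x + 0.044715 * x * x * x));
  return 0.5 * ((x - x * y * y) * (kBeta * x * x + kAlpha) + 1.0 + y);
}

int main() {
  // Spot-check the analytic gradient against a central finite difference.
  const double kEps = 1e-6;
  const double xs[] = {-2.0, -0.5, 0.0, 0.5, 2.0};
  for (double x : xs) {
    const double fd = (GeluTanh(x + kEps) - GeluTanh(x - kEps)) / (2.0 * kEps);
    std::printf("x=% .1f  gelu=% .6f  grad=% .6f  fd=% .6f\n", x, GeluTanh(x),
                GeluTanhGrad(x), fd);
  }
  return 0;
}
```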
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc
index ffbc1e28355806..52a2d40808a280 100644
--- a/tensorflow/core/kernels/mkl_relu_op.cc
+++ b/tensorflow/core/kernels/mkl_relu_op.cc
@@ -19,14 +19,15 @@ limitations under the License.
 #include <unordered_map>
 
 #include "mkldnn.hpp"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/no_op.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/util/mkl_types.h"
 #include "tensorflow/core/util/mkl_util.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 using mkldnn::algorithm;
 using mkldnn::eltwise_forward;
@@ -1163,6 +1164,101 @@ class MklLeakyReluGradOp
   }
 };
 
+template <typename Device, typename T>
+class MklGeluOp : public MklReluOpBase<Device, T, ALGORITHM::eltwise_gelu> {
+ public:
+  ~MklGeluOp() {}
+
+  explicit MklGeluOp(OpKernelConstruction* context)
+      : MklReluOpBase<Device, T, ALGORITHM::eltwise_gelu>(context, 0.0f,
+                                                          0.0f) {
+    bool approximate;
+    OP_REQUIRES_OK(context, context->GetAttr("approximate", &approximate));
+    OP_REQUIRES(
+        context, approximate,
+        errors::InvalidArgument("MKL Gelu only supports approximate=true. "
+                                "approximate is: ",
+                                approximate));
+  }
+
+  virtual void Compute_Scalar(OpKernelContext* context) {
+    const size_t src_index = 0;  // index of src input tensor
+    const size_t dst_index = 0;  // index of dst output tensor
+    const Tensor& src_tensor = MklGetInput(context, src_index);
+    MklDnnShape dnn_shape_src;
+    GetMklShape(context, src_index, &dnn_shape_src);
+
+    Tensor* dst_tensor = nullptr;
+    T* user_i = const_cast<T*>(src_tensor.flat<T>().data());
+    MklDnnShape dnn_shape_dst;
+    dnn_shape_dst.SetMklTensor(false);
+    AllocateOutputSetMklShape(context, dst_index, &dst_tensor,
+                              src_tensor.shape(), dnn_shape_dst);
+
+    T* out_o = dst_tensor->flat<T>().data();
+    T features = user_i[0];
+    // Gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
+    // where M_2_SQRTPI * M_SQRT1_2 == sqrt(2/pi).
+    out_o[0] =
+        static_cast<T>(0.5) * features *
+        (static_cast<T>(1) +
+         std::tanh(static_cast<T>(M_2_SQRTPI * M_SQRT1_2) *
+                   (features + static_cast<T>(0.044715) *
+                                   std::pow(features, static_cast<T>(3)))));
+    return;
+  }
+};
+
+template <typename Device, typename T>
+class MklGeluGradOp
+    : public MklReluGradOpBase<Device, T, ALGORITHM::eltwise_gelu> {
+ public:
+  ~MklGeluGradOp() {}
+
+  explicit MklGeluGradOp(OpKernelConstruction* context)
+      : MklReluGradOpBase<Device, T, ALGORITHM::eltwise_gelu>(context, 0.0f,
+                                                              0.0f) {
+    bool approximate;
+    OP_REQUIRES_OK(context, context->GetAttr("approximate", &approximate));
+    OP_REQUIRES(
+        context, approximate,
+        errors::InvalidArgument("MKL Gelu only supports approximate=true. "
+                                "approximate is: ",
+                                approximate));
+  }
+
+  virtual void Compute_Scalar(OpKernelContext* context) {
+    const size_t diff_dst_index = 0;  // index of diff_dst input tensor
+    const size_t src_index = 1;       // index of src input tensor
+    const size_t diff_src_index = 0;  // index of diff_src output tensor
+    const Tensor& src_tensor = MklGetInput(context, src_index);
+    const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
+    Tensor* diff_src_tensor = nullptr;
+
+    MklDnnShape dnn_shape_diff_dst;
+    GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);
+
+    MklDnnShape dnn_shape_diff_src;
+    dnn_shape_diff_src.SetMklTensor(false);
+    AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
+                              diff_dst_tensor.shape(), dnn_shape_diff_src);
+    T* out_o = diff_src_tensor->flat<T>().data();
+    T* user_i = const_cast<T*>(src_tensor.flat<T>().data());
+    T* user_g = const_cast<T*>(diff_dst_tensor.flat<T>().data());
+
+    T features = user_i[0];
+    const T kAlpha = static_cast<T>(M_2_SQRTPI * M_SQRT1_2);
+    const T kBeta = kAlpha * static_cast<T>(0.044715) * static_cast<T>(3);
+    const auto y = std::tanh(
+        (kAlpha *
+         ((static_cast<T>(0.044715) * std::pow(features, static_cast<T>(3))) +
+          features)));
+    // With y = tanh(u) and u = kAlpha * (x + 0.044715 * x^3), this is the
+    // analytic derivative d/dx[0.5 * x * (1 + y)], scaled by the incoming
+    // gradient.
+    out_o[0] = user_g[0] * static_cast<T>(0.5) *
+               ((-features * y * y + features) *
+                    (kBeta * features * features + kAlpha) +
+                static_cast<T>(1) + y);
+
+    return;
+  }
+};
+
 // register dnn kernels for supported operations and supported types
 #define REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES(type) \
   REGISTER_KERNEL_BUILDER(                              \
@@ -1245,6 +1341,27 @@ TF_CALL_bfloat16(REGISTER_RELU6_MKL_SUPPORTED_KERNELS_TYPES);
 TF_CALL_float(REGISTER_LeakyRelu_MKL_SUPPORTED_KERNELS_TYPES);
 TF_CALL_bfloat16(REGISTER_LeakyRelu_MKL_SUPPORTED_KERNELS_TYPES);
 
+#define REGISTER_GELU_MKL_SUPPORTED_KERNELS_TYPES(type)        \
+  REGISTER_KERNEL_BUILDER(                                     \
+      Name("_MklGelu")                                         \
+          .Device(DEVICE_CPU)                                  \
+          .TypeConstraint<type>("T")                           \
+          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
+      MklGeluOp<CPUDevice, type>);                             \
+  REGISTER_KERNEL_BUILDER(                                     \
+      Name("_MklGeluGrad")                                     \
+          .Device(DEVICE_CPU)                                  \
+          .TypeConstraint<type>("T")                           \
+          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
+      MklGeluGradOp<CPUDevice, type>);
+TF_CALL_float(REGISTER_GELU_MKL_SUPPORTED_KERNELS_TYPES);
+TF_CALL_bfloat16(REGISTER_GELU_MKL_SUPPORTED_KERNELS_TYPES);
+
)doc"); +REGISTER_OP("_MklGelu") + .Input("features: T") + .Input("mkl_features: uint8") + .Output("activations: T") + .Output("mkl_activations: uint8") + .Attr("T: {float, bfloat16} = DT_FLOAT") + .Attr("approximate: bool = true") + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +MKL version of Gelu operator. Uses MKL DNN APIs to implement +Gelu operator. + +NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. +)doc"); + +REGISTER_OP("_MklGeluGrad") + .Input("gradients: T") + .Input("features: T") + .Input("mkl_gradients: uint8") + .Input("mkl_features: uint8") + .Output("backprops: T") + .Output("mkl_backprops: uint8") + .Attr("T: {float, bfloat16} = DT_FLOAT") + .Attr("approximate: bool = true") + .SetShapeFn(shape_inference::MergeBothInputsShapeFn) + .Doc(R"doc( +MKL version of GeluGrad operator. Uses MKL DNN APIs to compute the +gradients for GeluGrad operation. + +NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. +)doc"); + REGISTER_OP("_MklElu") .Input("features: T") .Input("mkl_features: uint8")