diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 9e3b401154a766..4d0e2c9101f1bd 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -282,6 +282,8 @@ bool IsGather(const NodeDef& node) { return op == "Gather" || op == "GatherV2"; } +bool IsGelu(const NodeDef& node) { return node.op() == "Gelu"; } + bool IsGreater(const NodeDef& node) { return node.op() == "Greater"; } bool IsGreaterEqual(const NodeDef& node) { return node.op() == "GreaterEqual"; } diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index b1624ac70c647d..ba88ff249631d0 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -83,6 +83,7 @@ bool IsFusedBatchNorm(const NodeDef& node); bool IsFusedBatchNormEx(const NodeDef& node); bool IsFusedBatchNormGrad(const NodeDef& node); bool IsGather(const NodeDef& node); +bool IsGelu(const NodeDef& node); bool IsGreater(const NodeDef& node); bool IsGreaterEqual(const NodeDef& node); bool IsHistogramSummary(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index 80ebe424256070..37ea5a0c2acbef 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -393,7 +393,11 @@ bool IsDeviceCompatible(const RemapperContext& ctx, Pattern& matched) { } bool IsSupportedActivation(const NodeDef& node) { +#ifdef INTEL_MKL + return IsRelu(node) || IsRelu6(node) || IsElu(node) || IsGelu(node); +#else return IsRelu(node) || IsRelu6(node) || IsElu(node); +#endif } inline bool HasControlFaninOrFanout(const utils::MutableNodeView& node_view) { diff --git a/tensorflow/core/kernels/mkl_fused_ops_test.cc b/tensorflow/core/kernels/mkl_fused_ops_test.cc index edd1201a09c4f8..5432bcf1a592d9 100644 --- a/tensorflow/core/kernels/mkl_fused_ops_test.cc +++ 
b/tensorflow/core/kernels/mkl_fused_ops_test.cc @@ -878,6 +878,12 @@ class MklFusedMatMulOpTest : public OpsTestBase { next_op = ops::Elu(root.WithOpName(last_op), next_op); } + if (std::find(fused_ops.begin(), fused_ops.end(), "Gelu") != + fused_ops.end()) { + last_op = "with_gelu"; + next_op = ops::Gelu(root.WithOpName(last_op), next_op); + } + CommonTestUtilities::RunAndFetch(root, last_op, output); }; @@ -965,11 +971,21 @@ TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndElu) { {"BiasAdd", "Elu"}); } +TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndGelu) { + const int batch = 3; + const int input_channel = 4; + const int output_channel = 5; + + this->VerifyFusedMatMul(batch, input_channel, output_channel, + {"BiasAdd", "Gelu"}); +} + REGISTER_TYPED_TEST_CASE_P(MklFusedMatMulOpTest, // WithBias, // WithBiasAndRelu, // WithBiasAndRelu6, // - WithBiasAndElu); + WithBiasAndElu, // + WithBiasAndGelu); using MklFusedMatMulDataTypes = ::testing::Types<float>; INSTANTIATE_TYPED_TEST_CASE_P(Test, MklFusedMatMulOpTest, diff --git a/tensorflow/core/kernels/mkl_matmul_op_fused.cc b/tensorflow/core/kernels/mkl_matmul_op_fused.cc index 95a4f41a5af450..5804f884f0cca7 100644 --- a/tensorflow/core/kernels/mkl_matmul_op_fused.cc +++ b/tensorflow/core/kernels/mkl_matmul_op_fused.cc @@ -221,6 +221,8 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase { params.post_op_params.push_back({"relu6", {1.0, 6.0, 0.0}}); } else if (post_op == "Elu") { params.post_op_params.push_back({"elu", {1.0, 1.0, 0.0}}); + } else if (post_op == "Gelu") { + params.post_op_params.push_back({"gelu", {1.0, 1.0, 0.0}}); } else { OP_REQUIRES_OK( ctx, errors::InvalidArgument( diff --git a/tensorflow/core/kernels/mkl_matmul_ops_common.h b/tensorflow/core/kernels/mkl_matmul_ops_common.h index c746e9c5036c61..442ba78918e83e 100644 --- a/tensorflow/core/kernels/mkl_matmul_ops_common.h +++ b/tensorflow/core/kernels/mkl_matmul_ops_common.h @@ -259,6 +259,13 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive { float
op_beta = post_op_param.param[2]; post_ops.append_eltwise(op_scale, ALGORITHM::eltwise_elu, op_alpha, op_beta); + } else if (post_op_param.name == "gelu") { + DCHECK_EQ(post_op_param.param.size(), 3); + float op_scale = post_op_param.param[0]; + float op_alpha = post_op_param.param[1]; + float op_beta = post_op_param.param[2]; + post_ops.append_eltwise(op_scale, ALGORITHM::eltwise_gelu, op_alpha, + op_beta); } else if (post_op_param.name == "output_scale") { DCHECK_EQ(post_op_param.param.size(), 1); std::vector<float> scales; @@ -268,6 +275,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive { DCHECK((post_op_param.name == "relu") || (post_op_param.name == "relu6") || (post_op_param.name == "elu") || + (post_op_param.name == "gelu") || (post_op_param.name == "output_scale")); } } @@ -372,11 +380,12 @@ class MklDnnMatMulFwdPrimitiveFactory : public MklPrimitiveFactory { key_creator.AddAsKey(mkldnn_matmul_fwd_dims.bias_dims); key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dst_dims); key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dtypes); + key_creator.AddAsKey(mkldnn_matmul_fwd_dims.weight_format); // Generate keys for post-ops for (auto const& post_op_param : mkldnn_matmul_fwd_dims.post_op_params) { if (post_op_param.name == "relu" || post_op_param.name == "relu6" || - post_op_param.name == "elu") { + post_op_param.name == "elu" || post_op_param.name == "gelu") { DCHECK_EQ(post_op_param.param.size(), 3); key_creator.AddAsKey(post_op_param.name); key_creator.AddAsKey(post_op_param.param[0]);