3 changes: 3 additions & 0 deletions tensorflow/core/BUILD
@@ -1181,6 +1181,7 @@ tf_gen_op_libs(
"function_ops",
"functional_ops",
"fused_embedding_ops",
"fused_l2_normalize_ops",
"hash_ops",
"hash_training_ops",
"fuserecv_ops",
@@ -1439,6 +1440,7 @@ cc_library(
":function_ops_op_lib",
":functional_ops_op_lib",
":fused_embedding_ops_op_lib",
":fused_l2_normalize_ops_op_lib",
":fuserecv_ops_op_lib",
":hash_ops_op_lib",
":hash_training_ops_op_lib",
@@ -1623,6 +1625,7 @@ cc_library(
"//tensorflow/core/kernels:functional_ops",
"//tensorflow/core/kernels:fused_embedding_ops",
"//tensorflow/core/kernels/data:parquet_dataset_ops",
"//tensorflow/core/kernels:fused_l2_normalize_ops",
"//tensorflow/core/kernels:grappler",
"//tensorflow/core/kernels:hash_ops",
"//tensorflow/core/kernels:histogram_op",
@@ -0,0 +1,3 @@
op {
  graph_op_name: "FusedL2Normalize"
}
@@ -0,0 +1,3 @@
op {
  graph_op_name: "FusedL2NormalizeGrad"
}
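These api_def entries only name the two ops, with no additional documentation yet; the REGISTER_OP definitions themselves come from the fused_l2_normalize_ops op library added to tensorflow/core/BUILD above and are not shown in this excerpt. Purely as an illustrative sketch of the interface the tests below exercise, the registrations would look roughly like the following. Input/output names, the attribute defaults, and the shape function are assumptions, not the PR's actual registration.

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"

namespace tensorflow {

// Hypothetical sketch: attrs mirror what the unit tests set (T, axis, epsilon).
REGISTER_OP("FusedL2Normalize")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {float}")
    .Attr("axis: int = 0")
    .Attr("epsilon: float = 1e-12")
    .SetShapeFn(shape_inference::UnchangedShape);

// Hypothetical sketch: the grad test feeds y_grad and x and reads one output.
REGISTER_OP("FusedL2NormalizeGrad")
    .Input("y_grad: T")
    .Input("x: T")
    .Output("x_grad: T")
    .Attr("T: {float}")
    .Attr("axis: int = 0")
    .Attr("epsilon: float = 1e-12")
    .SetShapeFn(shape_inference::UnchangedShape);

}  // namespace tensorflow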
31 changes: 31 additions & 0 deletions tensorflow/core/kernels/BUILD
@@ -5403,6 +5403,37 @@ tf_cc_test(
    ],
)

tf_kernel_library(
    name = "fused_l2_normalize_ops",
    srcs = [
        "fused_l2_normalize/fused_l2_normalize_op.cc",
    ],
    hdrs = ["fused_l2_normalize/compile_util.h"],
    deps = ["//third_party/eigen3"] + DYNAMIC_DEPS + mkl_deps(),
)

tf_cc_test(
    name = "fused_l2_normalize_ops_test",
    size = "small",
    srcs = [
        "fused_l2_normalize/fused_l2_normalize_op_test.cc",
        "fused_l2_normalize/fused_l2_normalize_grad_op_test.cc",
    ],
    deps = [
        ":fused_l2_normalize_ops",
        ":ops_testutil",
        ":ops_util",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:tensorflow",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
    ],
)

tf_kernel_library(
name = "run_graph_op",
prefix = "run_graph_op",
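With these targets in place, the new kernels and their unit tests should be runnable through the usual Bazel workflow, e.g. bazel test //tensorflow/core/kernels:fused_l2_normalize_ops_test (assuming a standard TensorFlow build configuration).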
41 changes: 41 additions & 0 deletions tensorflow/core/kernels/fused_l2_normalize/compile_util.h
@@ -0,0 +1,41 @@
#ifndef TENSORFLOW_CORE_KERNELS_FUSED_L2_NORMALIZE_COMPILE_UTIL_OP_H_
#define TENSORFLOW_CORE_KERNELS_FUSED_L2_NORMALIZE_COMPILE_UTIL_OP_H_

#include <type_traits>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"

namespace tensorflow {
namespace functor {

// A class for forced loop unrolling at compile time.
template <int i>
struct compile_time_for {
  template <typename Lambda, typename... Args>
  inline static void op(const Lambda& function, Args... args) {
    compile_time_for<i - 1>::op(function, args...);
    function(std::integral_constant<int, i - 1>{}, args...);
  }
};

template <>
struct compile_time_for<1> {
  template <typename Lambda, typename... Args>
  inline static void op(const Lambda& function, Args... args) {
    function(std::integral_constant<int, 0>{}, args...);
  }
};

template <>
struct compile_time_for<0> {
  // 0 iterations: do nothing.
  template <typename Lambda, typename... Args>
  inline static void op(const Lambda& function, Args... args) {}
};

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_FUSED_L2_NORMALIZE_COMPILE_UTIL_OP_H_
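
For readers unfamiliar with the idiom: compile_time_for<N>::op(f, args...) expands recursively into N calls of f, passing the iteration index as std::integral_constant<int, 0> through std::integral_constant<int, N-1>, so the loop is unrolled at compile time and the index can be used where a constant expression is required. A minimal, hypothetical usage sketch follows (not part of this PR; it assumes the header above, together with its Eigen dependency, is on the include path).

#include <cstdio>

#include "tensorflow/core/kernels/fused_l2_normalize/compile_util.h"

using tensorflow::functor::compile_time_for;

int main() {
  const float src[4] = {1.f, 2.f, 3.f, 4.f};
  float sum = 0.f;
  // Expands into four calls with k = integral_constant<int, 0> ... <int, 3>;
  // decltype(k)::value is a compile-time constant inside the lambda.
  compile_time_for<4>::op([&](auto k) { sum += src[decltype(k)::value]; });
  std::printf("sum = %f\n", sum);  // 1 + 2 + 3 + 4 = 10
  return 0;
}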


125 changes: 125 additions & 0 deletions tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_grad_op_test.cc
@@ -0,0 +1,125 @@
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/cc/ops/standard_ops.h"

namespace tensorflow {
namespace {

enum class Device { CPU, GPU };

class FusedL2NormalizeGradOpTest : public OpsTestBase {
 protected:
  void MakeOpAndSetDevice(Device device, DataType dtype, int axis,
                          float epsilon) {
    TF_EXPECT_OK(
        NodeDefBuilder("fused_l2_normalize_grad", "FusedL2NormalizeGrad")
            .Attr("T", dtype)
            .Attr("axis", axis)
            .Attr("epsilon", epsilon)
            .Input(FakeInput(DT_FLOAT))
            .Input(FakeInput(DT_FLOAT))
            .Finalize(node_def()));
    TF_EXPECT_OK(InitOp());
  }
};

TEST_F(FusedL2NormalizeGradOpTest, 2Dims_Float) {
  const int rows = 4;
  const int cols = 252;  // 128 + 64 + 32 + 16 + 8 + 4 = 252; rows * cols = 1008.

  MakeOpAndSetDevice(Device::CPU, DT_FLOAT, 0, 1e-12);

  // y_grad: all ones, except the last element of each row is 2.0.
  float y_grad_array[1008];
  for (int i = 0; i < rows * cols; i++) {
    y_grad_array[i] = 1.0;
  }
  y_grad_array[251] = 2.0;
  y_grad_array[503] = 2.0;
  y_grad_array[755] = 2.0;
  y_grad_array[1007] = 2.0;
  AddInputFromArray<float>(TensorShape({rows, cols}), y_grad_array);

  // x: all ones.
  float x_array[1008];
  for (int i = 0; i < rows * cols; i++) {
    x_array[i] = 1.0;
  }
  AddInputFromArray<float>(TensorShape({rows, cols}), x_array);

  TF_ASSERT_OK(RunOpKernel());
  TF_EXPECT_OK(device_->Sync());

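  // Expected values, assuming the op implements the standard L2-normalize
  // gradient dx = (dy - y * sum(y * dy)) / ||x|| with y = x / ||x||:
  // every x is 1.0, so per row ||x|| = sqrt(252) (epsilon is negligible) and
  // sum(y * dy) = (251 * 1 + 1 * 2) / sqrt(252) = 253 / sqrt(252). Hence
  // dx_i = (dy_i - 253 / 252) / sqrt(252), i.e. -1 / (252 * sqrt(252)) where
  // dy_i = 1 and 251 / (252 * sqrt(252)) in the last column where dy_i = 2.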
  {
    Tensor expected_output(allocator(), DT_FLOAT, TensorShape({rows, cols}));
    float output_array[1008];
    for (int i = 0; i < rows * cols; i++) {
      output_array[i] = -1.0 / (252 * std::sqrt(252));
    }
    output_array[251] = 251.0 / (252 * std::sqrt(252));
    output_array[503] = 251.0 / (252 * std::sqrt(252));
    output_array[755] = 251.0 / (252 * std::sqrt(252));
    output_array[1007] = 251.0 / (252 * std::sqrt(252));
    test::FillValues<float>(&expected_output, output_array);
    test::ExpectTensorNear<float>(expected_output, *GetOutput(0), 1e-6);
  }
}

//----------------------------------------------------------------------------//
//                          Performance benchmarks                             //
//----------------------------------------------------------------------------//
static Graph* FusedL2NormalizeGrad(int rows, int cols) {
  Graph* g = new Graph(OpRegistry::Global());
  DataType dtype = DT_FLOAT;

  Tensor in1(dtype, TensorShape({rows, cols}));
  in1.flat<float>().setRandom();
  Tensor in2(dtype, TensorShape({rows, cols}));
  in2.flat<float>().setRandom();

  Node* input_in1 = test::graph::Constant(g, in1);
  Node* input_in2 = test::graph::Constant(g, in2);
  auto nodeBuilder = NodeBuilder(g->NewName("n"), "FusedL2NormalizeGrad")
                         .Input(input_in1)
                         .Input(input_in2)
                         .Attr("T", dtype)
                         .Attr("axis", 0)
                         .Attr("epsilon", 1e-12);
  TF_CHECK_OK(nodeBuilder.Finalize(g, nullptr));

  return g;
}

#define BM_FusedL2NormGrad(ROWS, COLS, NTH)                                    \
  static void BM_FusedL2NormGrad##_##ROWS##_##COLS##_##NTH##_CPU(              \
      int iters) {                                                             \
    testing::UseRealTime();                                                    \
    testing::ItemsProcessed(static_cast<int64>(iters) * ROWS * COLS * 5);      \
    SessionOptions opts;                                                       \
    opts.config.set_intra_op_parallelism_threads(NTH);                         \
    test::Benchmark("cpu", FusedL2NormalizeGrad(ROWS, COLS), &opts)            \
        .Run(iters);                                                           \
  }                                                                            \
  BENCHMARK(BM_FusedL2NormGrad##_##ROWS##_##COLS##_##NTH##_CPU);

#define BM_FusedL2NormGrad_NTH(ROWS, COLS) \
  BM_FusedL2NormGrad(ROWS, COLS, 1);       \
  BM_FusedL2NormGrad(ROWS, COLS, 4);       \
  BM_FusedL2NormGrad(ROWS, COLS, 8);

BM_FusedL2NormGrad_NTH(1024, 63);
BM_FusedL2NormGrad_NTH(1024, 127);
BM_FusedL2NormGrad_NTH(1024, 255);
BM_FusedL2NormGrad_NTH(1024, 511);
BM_FusedL2NormGrad_NTH(1024, 1023);

}  // namespace
}  // namespace tensorflow