diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 4745ec6fcb8..fcdc1c865ba 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1181,6 +1181,7 @@ tf_gen_op_libs(
         "function_ops",
         "functional_ops",
         "fused_embedding_ops",
+        "fused_l2_normalize_ops",
         "hash_ops",
         "hash_training_ops",
         "fuserecv_ops",
@@ -1439,6 +1440,7 @@ cc_library(
         ":function_ops_op_lib",
         ":functional_ops_op_lib",
         ":fused_embedding_ops_op_lib",
+        ":fused_l2_normalize_ops_op_lib",
         ":fuserecv_ops_op_lib",
         ":hash_ops_op_lib",
         ":hash_training_ops_op_lib",
@@ -1623,6 +1625,7 @@ cc_library(
         "//tensorflow/core/kernels:functional_ops",
         "//tensorflow/core/kernels:fused_embedding_ops",
         "//tensorflow/core/kernels/data:parquet_dataset_ops",
+        "//tensorflow/core/kernels:fused_l2_normalize_ops",
         "//tensorflow/core/kernels:grappler",
         "//tensorflow/core/kernels:hash_ops",
         "//tensorflow/core/kernels:histogram_op",
diff --git a/tensorflow/core/api_def/base_api/api_def_FusedL2Normalize.pbtxt b/tensorflow/core/api_def/base_api/api_def_FusedL2Normalize.pbtxt
new file mode 100644
index 00000000000..403e88193c1
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_FusedL2Normalize.pbtxt
@@ -0,0 +1,3 @@
+op {
+  graph_op_name: "FusedL2Normalize"
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_FusedL2NormalizeGrad.pbtxt b/tensorflow/core/api_def/base_api/api_def_FusedL2NormalizeGrad.pbtxt
new file mode 100644
index 00000000000..86b652754b3
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_FusedL2NormalizeGrad.pbtxt
@@ -0,0 +1,3 @@
+op {
+  graph_op_name: "FusedL2NormalizeGrad"
+}
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index ff8f572c246..804cb641935 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -5403,6 +5403,37 @@ tf_cc_test(
     ],
 )
 
+tf_kernel_library(
+    name = "fused_l2_normalize_ops",
+    srcs = [
+        "fused_l2_normalize/fused_l2_normalize_op.cc",
+    ],
+    hdrs = ["fused_l2_normalize/compile_util.h"],
+    deps = ["//third_party/eigen3"] + DYNAMIC_DEPS + mkl_deps(),
+)
+
+tf_cc_test(
+    name = "fused_l2_normalize_ops_test",
+    size = "small",
+    srcs = [
+        "fused_l2_normalize/fused_l2_normalize_op_test.cc",
+        "fused_l2_normalize/fused_l2_normalize_grad_op_test.cc",
+    ],
+    deps = [
+        ":fused_l2_normalize_ops",
+        ":ops_testutil",
+        ":ops_util",
+        "//tensorflow/cc:cc_ops",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:tensorflow",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
 tf_kernel_library(
     name = "run_graph_op",
     prefix = "run_graph_op",
diff --git a/tensorflow/core/kernels/fused_l2_normalize/compile_util.h b/tensorflow/core/kernels/fused_l2_normalize/compile_util.h
new file mode 100644
index 00000000000..646b503ce00
--- /dev/null
+++ b/tensorflow/core/kernels/fused_l2_normalize/compile_util.h
@@ -0,0 +1,41 @@
+#ifndef TENSORFLOW_CORE_KERNELS_FUSED_L2_NORMALIZE_COMPILE_UTIL_OP_H_
+#define TENSORFLOW_CORE_KERNELS_FUSED_L2_NORMALIZE_COMPILE_UTIL_OP_H_
+
+#include <type_traits>
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+namespace functor {
+
+// A helper for forced loop unrolling at compile time: invokes
+// function(std::integral_constant<int, 0>{}, args...) through
+// function(std::integral_constant<int, i - 1>{}, args...).
+template <int i>
+struct compile_time_for {
+  template <typename Lambda, typename... Args>
+  inline static void op(const Lambda& function, Args... args) {
+    compile_time_for<i - 1>::op(function, args...);
+    function(std::integral_constant<int, i - 1>{}, args...);
+  }
+};
+
+template <>
+struct compile_time_for<1> {
+  template <typename Lambda, typename... Args>
+  inline static void op(const Lambda& function, Args... args) {
+    function(std::integral_constant<int, 0>{}, args...);
+  }
+};
+
+template <>
+struct compile_time_for<0> {
+  // 0 iterations: do nothing
+  template <typename Lambda, typename... Args>
+  inline static void op(const Lambda& function, Args... args) {}
+};
+
+}  // namespace functor
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_FUSED_L2_NORMALIZE_COMPILE_UTIL_OP_H_
diff --git a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_grad_op_test.cc b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_grad_op_test.cc
new file mode 100644
index 00000000000..0af5a3c104a
--- /dev/null
+++ b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_grad_op_test.cc
@@ -0,0 +1,125 @@
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/conv_ops_gpu.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+
+namespace tensorflow {
+namespace {
+
+enum class Device { CPU, GPU };
+
+class FusedL2NormalizeGradOpTest : public OpsTestBase {
+ protected:
+  void MakeOpAndSetDevice(Device device, DataType dtype, int axis,
+                          float epsilon) {
+    TF_EXPECT_OK(NodeDefBuilder("fused_l2_normalize_grad", "FusedL2NormalizeGrad")
+                     .Attr("T", dtype)
+                     .Attr("axis", axis)
+                     .Attr("epsilon", epsilon)
+                     .Input(FakeInput(DT_FLOAT))
+                     .Input(FakeInput(DT_FLOAT))
+                     .Finalize(node_def()));
+    TF_EXPECT_OK(InitOp());
+  }
+};
+
+TEST_F(FusedL2NormalizeGradOpTest, 2Dims_Float) {
+  const int rows = 4;
+  const int cols = 252;  // 252 = 128 + 64 + 32 + 16 + 8 + 4; 4 * 252 = 1008 elements
+
+  MakeOpAndSetDevice(Device::CPU, DT_FLOAT, 0, 1e-12);
+
+  // y_grad: all ones, except the last element of each row is 2.0
+  float y_grad_array[1008];
+  for (int i = 0; i < rows * cols; i++) {
+    y_grad_array[i] = 1.0;
+  }
+  y_grad_array[251] = 2.0;
+  y_grad_array[503] = 2.0;
+  y_grad_array[755] = 2.0;
+  y_grad_array[1007] = 2.0;
+  AddInputFromArray<float>(TensorShape({rows, cols}), y_grad_array);
+
+  // x: all ones
+  float x_array[1008];
+  for (int i = 0; i < rows * cols; i++) {
+    x_array[i] = 1.0;
+  }
+  AddInputFromArray<float>(TensorShape({rows, cols}), x_array);
+
+  TF_ASSERT_OK(RunOpKernel());
+  TF_EXPECT_OK(device_->Sync());
+
+  {
+    Tensor expected_output(allocator(), DT_FLOAT, TensorShape({rows, cols}));
+    float output_array[1008];
+    for (int i = 0; i < rows * cols; i++) {
+      output_array[i] = -1.0 / (252 * std::sqrt(252));
+    }
+    output_array[251] = 251.0 / (252 * std::sqrt(252));
+    output_array[503] = 251.0 / (252 * std::sqrt(252));
+    output_array[755] = 251.0 / (252 * std::sqrt(252));
+    output_array[1007] = 251.0 / (252 * std::sqrt(252));
+    test::FillValues<float>(&expected_output, output_array);
+    test::ExpectTensorNear<float>(expected_output, *GetOutput(0), 1e-6);
+  }
+}
+
+//----------------------------------------------------------------------------//
+// Performance benchmarks                                                      //
+//----------------------------------------------------------------------------//
+static Graph* FusedL2NormalizeGrad(int rows, int cols) {
+  Graph* g = new Graph(OpRegistry::Global());
+  DataType dtype = DT_FLOAT;
+
+  Tensor in1(dtype, TensorShape({rows, cols}));
+  in1.flat<float>().setRandom();
+  Tensor in2(dtype, TensorShape({rows, cols}));
+  in2.flat<float>().setRandom();
+
+  Node* input_in1 = test::graph::Constant(g, in1);
+  Node* input_in2 = test::graph::Constant(g, in2);
+  auto nodeBuilder = NodeBuilder(g->NewName("n"), "FusedL2NormalizeGrad")
+                         .Input(input_in1)
+                         .Input(input_in2)
+                         .Attr("T", dtype)
+                         .Attr("axis", 0)
+                         .Attr("epsilon", 1e-12);
+  TF_CHECK_OK(nodeBuilder.Finalize(g, nullptr));
+
+  return g;
+}
+
+#define BM_FusedL2NormGrad(ROWS, COLS, NTH)                                    \
+  static void BM_FusedL2NormGrad##_##ROWS##_##COLS##_##NTH##_CPU(int iters) {  \
+    testing::UseRealTime();                                                    \
+    testing::ItemsProcessed(static_cast<int64>(iters) * ROWS * COLS * 5);      \
+    SessionOptions opts;                                                       \
+    opts.config.set_intra_op_parallelism_threads(NTH);                         \
+    test::Benchmark("cpu", FusedL2NormalizeGrad(ROWS, COLS), &opts).Run(iters);\
+  }                                                                            \
+  BENCHMARK(BM_FusedL2NormGrad##_##ROWS##_##COLS##_##NTH##_CPU);
+
+#define BM_FusedL2NormGrad_NTH(ROWS, COLS) \
+  BM_FusedL2NormGrad(ROWS, COLS, 1);       \
+  BM_FusedL2NormGrad(ROWS, COLS, 4);       \
+  BM_FusedL2NormGrad(ROWS, COLS, 8);
+
+BM_FusedL2NormGrad_NTH(1024, 63);
+BM_FusedL2NormGrad_NTH(1024, 127);
+BM_FusedL2NormGrad_NTH(1024, 255);
+BM_FusedL2NormGrad_NTH(1024, 511);
+BM_FusedL2NormGrad_NTH(1024, 1023);
+
+}  // namespace
+}  // namespace tensorflow
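The expected values in `FusedL2NormalizeGradOpTest` above are hand-derived: with `x` all ones (252 per row) and `y_grad` all ones except a single 2.0, the backward formula gives `-1/(252*sqrt(252))` for the ordinary entries and `251/(252*sqrt(252))` for the bumped one. A standalone NumPy sketch (verification only, not part of the patch) reproduces them:

```python
import numpy as np

cols = 252
x = np.ones(cols)
y_grad = np.ones(cols)
y_grad[-1] = 2.0

# The formula the backward kernel implements (epsilon is negligible here).
rvar = 1.0 / np.sqrt(np.sum(x * x))  # 1/sqrt(252)
s = np.sum(y_grad * x)               # 251*1 + 1*2 = 253
x_grad = y_grad * rvar - x * ((s * rvar) * (rvar * rvar))

print(x_grad[0])   # -1/(252*sqrt(252))  ~ -2.4998e-04
print(x_grad[-1])  # 251/(252*sqrt(252)) ~  6.2744e-02
```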
diff --git a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc
new file mode 100644
index 00000000000..f0013d7bfd6
--- /dev/null
+++ b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc
@@ -0,0 +1,499 @@
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_var.h"
+#include "tensorflow/core/framework/bounds_check.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+
+#include "compile_util.h"
+
+#include <immintrin.h>
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+template <typename T>
+class FusedL2NormalizeOp : public OpKernel {
+ public:
+  explicit FusedL2NormalizeOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("axis", &axis));
+    OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
+  }
+
+  ~FusedL2NormalizeOp() {}
+
+  void Compute(OpKernelContext* context) override {
+    // Grab the input
+    const Tensor* input_tensor = &context->input(0);
+    const T* input = input_tensor->flat<T>().data();
+
+    // Validate the input
+    OP_REQUIRES(context, (input_tensor->dims() >= 2),
+                errors::InvalidArgument("Input dimension should be >= 2"));
+
+    int64 cols = input_tensor->dim_size(input_tensor->dims() - 1);
+    int64 rows = 1;
+    for (int64 i = 0; i < input_tensor->dims() - 1; ++i) {
+      rows *= input_tensor->dim_size(i);
+    }
+
+    // Create the output tensor
+    Tensor* output_tensor = NULL;
+    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor->shape(),
+                                                     &output_tensor));
+    T* output = output_tensor->flat<T>().data();
+
+    // Let every thread compute 16 rows to avoid false sharing
+    #define BLOCK_SIZE 16
+    const int64 total_unit = (rows + BLOCK_SIZE - 1) / BLOCK_SIZE;
+    const int64 unit_cost =
+        BLOCK_SIZE * cols * 50;  // assume every element consumes 50 cycles
+
+    auto& worker_threads =
+        *(context->device()->tensorflow_cpu_worker_threads());
+    thread::ThreadPool* thread_pool = worker_threads.workers;
+
+#ifdef __AVX512F__
+    // Do forward in 128-element (8*16) blocks; AVX512 handles 16 floats at a time
+    int64 block_num = cols >> 7;                     // number of full 128-element blocks
+    int64 remainder_128 = cols & 0x7F;               // remainder after the 128 split
+    int64 remainder_16 = remainder_128 & 0x0F;       // remainder after the 16 split
+    int64 remainder_block_num = remainder_128 >> 4;  // full 16-element blocks in the 128-remainder
+    int64 remainder_block_num_total =
+        remainder_block_num + !!remainder_16;        // total 16-element blocks in the remainder
+
+    thread_pool->ParallelFor(
+        total_unit, unit_cost,
+        [&input, &output, rows, cols, block_num, remainder_block_num,
+         remainder_block_num_total, remainder_128, remainder_16,
+         this](int64 begin_unit, int64 end_unit) {
+          auto begin_row = begin_unit * BLOCK_SIZE;
+          auto end_row = end_unit * BLOCK_SIZE;
+          if (end_row > rows) {
+            end_row = rows;
+          }
+          forward_avx512<8>(input, output, begin_row, end_row, cols, block_num,
+                            remainder_block_num, remainder_block_num_total,
+                            remainder_128, remainder_16);
+        });
+#else
+    thread_pool->ParallelFor(
+        total_unit, unit_cost,
+        [&input, &output, rows, cols, this](int64 begin_unit, int64 end_unit) {
+          auto begin_row = begin_unit * BLOCK_SIZE;
+          auto end_row = end_unit * BLOCK_SIZE;
+          if (end_row > rows) {
+            end_row = rows;
+          }
+          forward<8>(input, output, begin_row, end_row, cols);
+        });
+#endif
+  }
+
+ private:
+  // The unfused graph this kernel replaces:
+  //   temp = tf.math.square(inputs)
+  //   temp = tf.math.reduce_sum(temp, reduction_indices=axis, keepdims=True)
+  //   temp = tf.math.maximum(temp, epsilon)
+  //   temp = tf.math.rsqrt(temp)
+  //   outputs = tf.math.multiply(temp, inputs)
+  // (Note: this kernel adds epsilon to the squared sum instead of taking the
+  // maximum.)
+  template <int SUM_BLOCK_SIZE>
+  void forward(const T* input, T* output, int64 begin_row, int64 end_row,
+               int64 cols) {
+    int64 remainder = cols % SUM_BLOCK_SIZE;
+    for (int64 i = begin_row; i < end_row; ++i) {
+      T row_sum = 0;
+      // Sum of squares of the inputs, unrolled by SUM_BLOCK_SIZE (= 8)
+      for (int64 j = 0; j < cols - remainder; j += SUM_BLOCK_SIZE) {
+        T data_0 = input[i * cols + j];
+        T data_1 = input[i * cols + j + 1];
+        T data_2 = input[i * cols + j + 2];
+        T data_3 = input[i * cols + j + 3];
+        T data_4 = input[i * cols + j + 4];
+        T data_5 = input[i * cols + j + 5];
+        T data_6 = input[i * cols + j + 6];
+        T data_7 = input[i * cols + j + 7];
+        row_sum += data_0 * data_0 + data_1 * data_1 + data_2 * data_2 +
+                   data_3 * data_3 + data_4 * data_4 + data_5 * data_5 +
+                   data_6 * data_6 + data_7 * data_7;
+      }
+      for (int64 j = cols - remainder; j < cols; j++) {
+        T data_0 = input[i * cols + j];
+        row_sum += data_0 * data_0;
+      }
+
+      // Reciprocal square root
+      row_sum += epsilon;
+      row_sum = 1.0 / std::sqrt(row_sum);
+
+      // Mul
+      for (int64 j = 0; j < cols; ++j) {
+        output[i * cols + j] = input[i * cols + j] * row_sum;
+      }
+    }
+  }
+
+#ifdef __AVX512F__
+  template <int SUM_BLOCK_SIZE>
+  void forward_avx512(const T* input, T* output, int64 begin_row,
+                      int64 end_row, int64 cols, int64 block_num,
+                      int64 remainder_block_num,
+                      int64 remainder_block_num_total, int64 remainder_128,
+                      int64 remainder_16) {
+    for (int64 i = begin_row; i < end_row; ++i) {
+      float row_sum = 0.0;
+      // Sum of squares of the inputs
+      for (int64 j = 0; j < block_num; ++j) {
+        __m512 inputs[SUM_BLOCK_SIZE];
+        auto load = [&](auto idx) {
+          inputs[idx] = _mm512_loadu_ps(input + cols * i +
+                                        16 * SUM_BLOCK_SIZE * j + 16 * idx);
+          inputs[idx] = _mm512_mul_ps(inputs[idx], inputs[idx]);
+        };
+        functor::compile_time_for<SUM_BLOCK_SIZE>::op(load);
+        __m512 block_sum = reduce_sum_block<8>(inputs);
+        row_sum += _mm512_reduce_add_ps(block_sum);
+      }
+      if (remainder_block_num_total) {
+        __m512 inputs[remainder_block_num_total];
+
+        for (int64 idx = 0; idx < remainder_block_num; idx++) {
+          inputs[idx] = _mm512_loadu_ps(input + cols * i + cols -
+                                        remainder_128 + 16 * idx);
+          inputs[idx] = _mm512_mul_ps(inputs[idx], inputs[idx]);
+        }
+        if (remainder_16) {
+          // Mask covering the first remainder_16 lanes
+          __mmask16 mask = 0xFFFF >> (16 - remainder_16);
+          inputs[remainder_block_num] = _mm512_maskz_loadu_ps(
+              mask, input + cols * i + cols - remainder_16);
+          inputs[remainder_block_num] = _mm512_mul_ps(
+              inputs[remainder_block_num], inputs[remainder_block_num]);
+        }
+
+        __m512 block_sum =
+            reduce_sum_block_ps(inputs, remainder_block_num_total);
+        row_sum += _mm512_reduce_add_ps(block_sum);
+      }
+
+      // Reciprocal square root
+      row_sum += epsilon;
+      row_sum = 1.0 / std::sqrt(row_sum);
+
+      // Mul & store
+      __m512 row_sums = _mm512_set1_ps(row_sum);
+      for (int64 j = 0; j < cols - 15; j += 16) {
+        __m512 inputs = _mm512_loadu_ps(input + cols * i + j);
+        inputs = _mm512_mul_ps(inputs, row_sums);
+        _mm512_storeu_ps(output + cols * i + j, inputs);
+      }
+      if (remainder_16) {
+        __mmask16 mask = 0xFFFF >> (16 - remainder_16);
+        __m512 inputs = _mm512_maskz_loadu_ps(
+            mask, input + cols * i + cols - remainder_16);
+        inputs = _mm512_mul_ps(inputs, row_sums);
+        _mm512_mask_storeu_ps(output + cols * i + cols - remainder_16, mask,
+                              inputs);
+      }
+    }
+  }
+
+  // Data type: FP32, 16 FP32 lanes per __m512
+  //   v0:  v0_0, v0_1, ..., v0_15
+  //   v1:  v1_0, v1_1, ..., v1_15
+  //   ...
+  //   v7:  v7_0, v7_1, ..., v7_15
+  //   sum: v_0,  v_1,  ..., v_15
+  template <int BLOCK_NUM>
+  inline __m512 reduce_sum_block(const __m512* v) {
+    __m512 block_sum = _mm512_setzero_ps();
+    auto reduce_sum = [&](auto idx) {
+      block_sum = _mm512_add_ps(block_sum, v[idx]);
+    };
+    functor::compile_time_for<BLOCK_NUM>::op(reduce_sum);
+    return block_sum;
+  }
+
+  inline __m512 reduce_sum_block_ps(const __m512* v, int64 BLOCK_NUM) {
+    switch (BLOCK_NUM) {
+      case 1:
+        return reduce_sum_block<1>(v);
+      case 2:
+        return reduce_sum_block<2>(v);
+      case 3:
+        return reduce_sum_block<3>(v);
+      case 4:
+        return reduce_sum_block<4>(v);
+      case 5:
+        return reduce_sum_block<5>(v);
+      case 6:
+        return reduce_sum_block<6>(v);
+      case 7:
+        return reduce_sum_block<7>(v);
+      case 8:
+        return reduce_sum_block<8>(v);
+      default:
+        return _mm512_setzero_ps();  // unreachable: BLOCK_NUM is in [1, 8]
+    }
+  }
+#endif
+
+ private:
+  float epsilon;
+  int32 axis;
+};
+
+REGISTER_KERNEL_BUILDER(Name("FusedL2Normalize")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<float>("T"),
+                        FusedL2NormalizeOp<float>);
+
+
+template <typename T>
+class FusedL2NormalizeGradOp : public OpKernel {
+ public:
+  explicit FusedL2NormalizeGradOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("axis", &axis));
+    OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
+  }
+
+  ~FusedL2NormalizeGradOp() {}
+
+  void Compute(OpKernelContext* context) override {
+    // Grab the inputs
+    const Tensor* y_grad_tensor = &context->input(0);
+    const Tensor* x_tensor = &context->input(1);
+
+    const T* y_grad = y_grad_tensor->flat<T>().data();
+    const T* x = x_tensor->flat<T>().data();
+
+    int64 cols = x_tensor->dim_size(x_tensor->dims() - 1);
+    int64 rows = 1;
+    for (int64 i = 0; i < x_tensor->dims() - 1; ++i) {
+      rows *= x_tensor->dim_size(i);
+    }
+
+    // Create the output tensor
+    Tensor* x_grad_tensor = NULL;
+    OP_REQUIRES_OK(context, context->allocate_output(0, x_tensor->shape(),
+                                                     &x_grad_tensor));
+    T* x_grad = x_grad_tensor->flat<T>().data();
+
+    // Do it in parallel
+    #define BLOCK_SIZE 16
+    const int64 total_unit = (rows + BLOCK_SIZE - 1) / BLOCK_SIZE;
+    const int64 unit_cost =
+        BLOCK_SIZE * cols * 50;  // assume every element consumes 50 cycles
+
+    auto& worker_threads =
+        *(context->device()->tensorflow_cpu_worker_threads());
+    thread::ThreadPool* thread_pool = worker_threads.workers;
+
+#ifdef __AVX512F__
+    // Do backward in 128-element (8*16) blocks; AVX512 handles 16 floats at a time
+    int64 block_num = cols >> 7;                     // number of full 128-element blocks
+    int64 remainder_128 = cols & 0x7F;               // remainder after the 128 split
+    int64 remainder_16 = remainder_128 & 0x0F;       // remainder after the 16 split
+    int64 remainder_block_num = remainder_128 >> 4;  // full 16-element blocks in the 128-remainder
+    int64 remainder_block_num_total =
+        remainder_block_num + !!remainder_16;        // total 16-element blocks in the remainder
+
+    thread_pool->ParallelFor(
+        total_unit, unit_cost,
+        [&y_grad, &x, &x_grad, rows, cols, block_num, remainder_block_num,
+         remainder_block_num_total, remainder_128, remainder_16,
+         this](int64 begin_unit, int64 end_unit) {
+          auto begin_row = begin_unit * BLOCK_SIZE;
+          auto end_row = end_unit * BLOCK_SIZE;
+          if (end_row > rows) {
+            end_row = rows;
+          }
+          backward_avx512<8>(y_grad, x, x_grad, begin_row, end_row, cols,
+                             block_num, remainder_block_num,
+                             remainder_block_num_total, remainder_128,
+                             remainder_16);
+        });
+#else
+    thread_pool->ParallelFor(
+        total_unit, unit_cost,
+        [&y_grad, &x, &x_grad, rows, cols, this](int64 begin_unit,
+                                                 int64 end_unit) {
+          auto begin_row = begin_unit * BLOCK_SIZE;
+          auto end_row = end_unit * BLOCK_SIZE;
+          if (end_row > rows) {
+            end_row = rows;
+          }
+          backward<8>(y_grad, x, x_grad, begin_row, end_row, cols);
+        });
+#endif
+  }
+
+ private:
+  // The unfused graph this kernel replaces:
+  //   rvar = tf.math.rsqrt(tf.math.reduce_sum(x * x, reduction_indices=1,
+  //                                           keepdims=True) + 1e-12)
+  //   sum = tf.math.reduce_sum(y_grad * x, reduction_indices=1, keepdims=True)
+  //   x_grad = y_grad * rvar - x * ((sum * rvar) * (rvar * rvar))
+  template <int SUM_BLOCK_SIZE>
+  void backward(const T* y_grad, const T* x, T* x_grad, int64 begin_row,
+                int64 end_row, int64 cols) {
+    int64 remainder = cols % SUM_BLOCK_SIZE;
+    for (int64 i = begin_row; i < end_row; ++i) {
+      T x_row_sum = 0.0;
+      T y_grad_row_sum = 0.0;
+      // Sum of squares of x, and sum of y_grad * x
+      for (int64 j = 0; j < cols - remainder; j += SUM_BLOCK_SIZE) {
+        T x_0 = x[i * cols + j];
+        T x_1 = x[i * cols + j + 1];
+        T x_2 = x[i * cols + j + 2];
+        T x_3 = x[i * cols + j + 3];
+        T x_4 = x[i * cols + j + 4];
+        T x_5 = x[i * cols + j + 5];
+        T x_6 = x[i * cols + j + 6];
+        T x_7 = x[i * cols + j + 7];
+        x_row_sum += x_0 * x_0 + x_1 * x_1 + x_2 * x_2 + x_3 * x_3 +
+                     x_4 * x_4 + x_5 * x_5 + x_6 * x_6 + x_7 * x_7;
+
+        T y_grad_0 = y_grad[i * cols + j];
+        T y_grad_1 = y_grad[i * cols + j + 1];
+        T y_grad_2 = y_grad[i * cols + j + 2];
+        T y_grad_3 = y_grad[i * cols + j + 3];
+        T y_grad_4 = y_grad[i * cols + j + 4];
+        T y_grad_5 = y_grad[i * cols + j + 5];
+        T y_grad_6 = y_grad[i * cols + j + 6];
+        T y_grad_7 = y_grad[i * cols + j + 7];
+        y_grad_row_sum += x_0 * y_grad_0 + x_1 * y_grad_1 + x_2 * y_grad_2 +
+                          x_3 * y_grad_3 + x_4 * y_grad_4 + x_5 * y_grad_5 +
+                          x_6 * y_grad_6 + x_7 * y_grad_7;
+      }
+      for (int64 j = cols - remainder; j < cols; j++) {
+        T x_0 = x[i * cols + j];
+        x_row_sum += x_0 * x_0;
+
+        T y_grad_0 = y_grad[i * cols + j];
+        y_grad_row_sum += x_0 * y_grad_0;
+      }
+      x_row_sum += epsilon;
+      x_row_sum = 1.0 / std::sqrt(x_row_sum);  // rvar
+      y_grad_row_sum = (y_grad_row_sum * x_row_sum) * (x_row_sum * x_row_sum);
+      // x_grad = y_grad * rvar - x * ((sum * rvar) * (rvar * rvar))
+      for (int64 j = 0; j < cols; ++j) {
+        x_grad[i * cols + j] =
+            y_grad[i * cols + j] * x_row_sum - x[i * cols + j] * y_grad_row_sum;
+      }
+    }
+  }
+
+#ifdef __AVX512F__
+  template <int SUM_BLOCK_SIZE>
+  void backward_avx512(const T* y_grad, const T* x, T* x_grad,
+                       int64 begin_row, int64 end_row, int64 cols,
+                       int64 block_num, int64 remainder_block_num,
+                       int64 remainder_block_num_total, int64 remainder_128,
+                       int64 remainder_16) {
+    for (int64 i = begin_row; i < end_row; ++i) {
+      T x_row_sum = 0.0;
+      T y_grad_row_sum = 0.0;
+      // Sum of squares of x, and sum of y_grad * x
+      for (int64 j = 0; j < block_num; ++j) {
+        __m512 xs[SUM_BLOCK_SIZE];
+        auto x_load = [&](auto idx) {
+          xs[idx] = _mm512_loadu_ps(x + cols * i + 16 * SUM_BLOCK_SIZE * j +
+                                    16 * idx);
+          xs[idx] = _mm512_mul_ps(xs[idx], xs[idx]);
+        };
+        functor::compile_time_for<SUM_BLOCK_SIZE>::op(x_load);
+        __m512 x_block_sum = reduce_sum_block<8>(xs);
+        x_row_sum += _mm512_reduce_add_ps(x_block_sum);
+
+        __m512 y_grads[SUM_BLOCK_SIZE];
+        auto y_grad_load = [&](auto idx) {
+          y_grads[idx] = _mm512_loadu_ps(y_grad + cols * i +
+                                         16 * SUM_BLOCK_SIZE * j + 16 * idx);
+          // Reload x: xs was squared in place above
+          xs[idx] = _mm512_loadu_ps(x + cols * i + 16 * SUM_BLOCK_SIZE * j +
+                                    16 * idx);
+          y_grads[idx] = _mm512_mul_ps(y_grads[idx], xs[idx]);
+        };
+        functor::compile_time_for<SUM_BLOCK_SIZE>::op(y_grad_load);
+        __m512 y_grad_block_sum = reduce_sum_block<8>(y_grads);
+        y_grad_row_sum += _mm512_reduce_add_ps(y_grad_block_sum);
+      }
+      if (remainder_block_num_total) {
+        __m512 xs[remainder_block_num_total];
+        for (int64 idx = 0; idx < remainder_block_num; idx++) {
+          xs[idx] =
+              _mm512_loadu_ps(x + cols * i + cols - remainder_128 + 16 * idx);
+          xs[idx] = _mm512_mul_ps(xs[idx], xs[idx]);
+        }
+        if (remainder_16) {
+          __mmask16 mask = 0xFFFF >> (16 - remainder_16);
+          xs[remainder_block_num] =
+              _mm512_maskz_loadu_ps(mask, x + cols * i + cols - remainder_16);
+          xs[remainder_block_num] = _mm512_mul_ps(xs[remainder_block_num],
+                                                  xs[remainder_block_num]);
+        }
+        __m512 x_block_sum = reduce_sum_block_ps(xs, remainder_block_num_total);
+        x_row_sum += _mm512_reduce_add_ps(x_block_sum);
+
+        __m512 y_grads[remainder_block_num_total];
+        for (int64 idx = 0; idx < remainder_block_num; idx++) {
+          y_grads[idx] = _mm512_loadu_ps(y_grad + cols * i + cols -
+                                         remainder_128 + 16 * idx);
+          xs[idx] =
+              _mm512_loadu_ps(x + cols * i + cols - remainder_128 + 16 * idx);
+          y_grads[idx] = _mm512_mul_ps(y_grads[idx], xs[idx]);
+        }
+        if (remainder_16) {
+          __mmask16 mask = 0xFFFF >> (16 - remainder_16);
+          y_grads[remainder_block_num] = _mm512_maskz_loadu_ps(
+              mask, y_grad + cols * i + cols - remainder_16);
+          xs[remainder_block_num] =
+              _mm512_maskz_loadu_ps(mask, x + cols * i + cols - remainder_16);
+          y_grads[remainder_block_num] = _mm512_mul_ps(
+              y_grads[remainder_block_num], xs[remainder_block_num]);
+        }
+        __m512 y_grad_block_sum =
+            reduce_sum_block_ps(y_grads, remainder_block_num_total);
+        y_grad_row_sum += _mm512_reduce_add_ps(y_grad_block_sum);
+      }
+
+      x_row_sum += epsilon;
+      x_row_sum = 1.0 / std::sqrt(x_row_sum);  // rvar
+      y_grad_row_sum = (y_grad_row_sum * x_row_sum) * (x_row_sum * x_row_sum);
+      // x_grad = y_grad * rvar - x * ((sum * rvar) * (rvar * rvar))
+      __m512 x_row_sums = _mm512_set1_ps(x_row_sum);
+      __m512 y_grad_row_sums = _mm512_set1_ps(y_grad_row_sum);
+      for (int64 j = 0; j < cols - 15; j += 16) {
+        __m512 y_grads = _mm512_loadu_ps(y_grad + cols * i + j);
+        __m512 xs = _mm512_loadu_ps(x + cols * i + j);
+        y_grads = _mm512_mul_ps(y_grads, x_row_sums);
+        xs = _mm512_mul_ps(xs, y_grad_row_sums);
+        y_grads = _mm512_sub_ps(y_grads, xs);
+        _mm512_storeu_ps(x_grad + cols * i + j, y_grads);
+      }
+      if (remainder_16 > 0) {
+        __mmask16 mask = 0xFFFF >> (16 - remainder_16);
+        __m512 y_grads = _mm512_maskz_loadu_ps(
+            mask, y_grad + cols * i + cols - remainder_16);
+        __m512 xs =
+            _mm512_maskz_loadu_ps(mask, x + cols * i + cols - remainder_16);
+        y_grads = _mm512_mul_ps(y_grads, x_row_sums);
+        xs = _mm512_mul_ps(xs, y_grad_row_sums);
+        y_grads = _mm512_sub_ps(y_grads, xs);
+        _mm512_mask_storeu_ps(x_grad + cols * i + cols - remainder_16, mask,
+                              y_grads);
+      }
+    }
+  }
+
+  template <int BLOCK_NUM>
+  inline __m512 reduce_sum_block(const __m512* v) {
+    __m512 block_sum = _mm512_setzero_ps();
+    auto reduce_sum = [&](auto idx) {
+      block_sum = _mm512_add_ps(block_sum, v[idx]);
+    };
+    functor::compile_time_for<BLOCK_NUM>::op(reduce_sum);
+    return block_sum;
+  }
+
+  inline __m512 reduce_sum_block_ps(const __m512* v, int64 BLOCK_NUM) {
+    switch (BLOCK_NUM) {
+      case 1:
+        return reduce_sum_block<1>(v);
+      case 2:
+        return reduce_sum_block<2>(v);
+      case 3:
+        return reduce_sum_block<3>(v);
+      case 4:
+        return reduce_sum_block<4>(v);
+      case 5:
+        return reduce_sum_block<5>(v);
+      case 6:
+        return reduce_sum_block<6>(v);
+      case 7:
+        return reduce_sum_block<7>(v);
+      case 8:
+        return reduce_sum_block<8>(v);
+      default:
+        return _mm512_setzero_ps();  // unreachable: BLOCK_NUM is in [1, 8]
+    }
+  }
+#endif
+
+ private:
+  float epsilon;
+  int32 axis;
+};
+
+REGISTER_KERNEL_BUILDER(Name("FusedL2NormalizeGrad")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<float>("T"),
+                        FusedL2NormalizeGradOp<float>);
+
+}  // namespace tensorflow
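Per row, the two kernels above compute `y = x * rsqrt(sum(x^2) + epsilon)` and `x_grad = y_grad * rvar - x * ((sum(y_grad * x) * rvar) * (rvar * rvar))` with `rvar = rsqrt(sum(x^2) + epsilon)`. A NumPy reference of both, plus a central-difference check that the backward formula is the derivative of the forward (a verification sketch, not part of the patch):

```python
import numpy as np

def forward(x, epsilon=1e-12):
    # y = x * rsqrt(sum(x^2) + epsilon), over the last axis
    return x / np.sqrt(np.sum(x * x, axis=-1, keepdims=True) + epsilon)

def backward(y_grad, x, epsilon=1e-12):
    rvar = 1.0 / np.sqrt(np.sum(x * x, axis=-1, keepdims=True) + epsilon)
    s = np.sum(y_grad * x, axis=-1, keepdims=True)
    return y_grad * rvar - x * ((s * rvar) * (rvar * rvar))

rng = np.random.RandomState(0)
x = rng.rand(3, 7)
y_grad = rng.rand(3, 7)

analytic = backward(y_grad, x)

# Central differences of sum(y_grad * forward(x)) w.r.t. each x[idx]
numeric = np.zeros_like(x)
h = 1e-6
for idx in np.ndindex(*x.shape):
    dx = np.zeros_like(x)
    dx[idx] = h
    numeric[idx] = np.sum(y_grad * (forward(x + dx) - forward(x - dx))) / (2 * h)

print(np.max(np.abs(analytic - numeric)))  # agreement to ~1e-9 or better
```

Note that the fused forward adds `epsilon` to the squared sum, while the graph it replaces takes `maximum(square_sum, epsilon)`; the two agree whenever `sum(x^2) >> epsilon`.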
diff --git a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op_test.cc b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op_test.cc
new file mode 100644
index 00000000000..98f25942438
--- /dev/null
+++ b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op_test.cc
@@ -0,0 +1,103 @@
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/conv_ops_gpu.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+
+namespace tensorflow {
+namespace {
+
+enum class Device { CPU, GPU };
+
+class FusedL2NormalizeOpTest : public OpsTestBase {
+ protected:
+  void MakeOpAndSetDevice(Device device, DataType dtype, int axis,
+                          float epsilon) {
+    TF_EXPECT_OK(NodeDefBuilder("fused_l2_normalize", "FusedL2Normalize")
+                     .Attr("T", dtype)
+                     .Attr("axis", axis)
+                     .Attr("epsilon", epsilon)
+                     .Input(FakeInput(DT_FLOAT))
+                     .Finalize(node_def()));
+    TF_EXPECT_OK(InitOp());
+  }
+};
+
+TEST_F(FusedL2NormalizeOpTest, 2Dims_Float) {
+  const int rows = 4;
+  const int cols = 252;  // 252 = 128 + 64 + 32 + 16 + 8 + 4; 4 * 252 = 1008 elements
+
+  MakeOpAndSetDevice(Device::CPU, DT_FLOAT, 0, 1e-12);
+
+  // x: all ones
+  float input_array[1008];
+  for (int i = 0; i < sizeof(input_array) / sizeof(float); i++) {
+    input_array[i] = 1.0;
+  }
+  AddInputFromArray<float>(TensorShape({rows, cols}), input_array);
+
+  TF_ASSERT_OK(RunOpKernel());
+  TF_EXPECT_OK(device_->Sync());
+
+  {
+    Tensor expected_output(allocator(), DT_FLOAT, TensorShape({rows, cols}));
+    float output_array[1008];
+    for (int i = 0; i < sizeof(output_array) / sizeof(float); i++) {
+      output_array[i] = 0.062994122505188;  // ~= 1/sqrt(252)
+    }
+    test::FillValues<float>(&expected_output, output_array);
+    test::ExpectTensorNear<float>(expected_output, *GetOutput(0), 1e-6);
+  }
+}
+
+//----------------------------------------------------------------------------//
+// Performance benchmarks                                                      //
+//----------------------------------------------------------------------------//
+static Graph* FusedL2Normalize(int rows, int cols) {
+  Graph* g = new Graph(OpRegistry::Global());
+  DataType dtype = DT_FLOAT;
+
+  Tensor in(dtype, TensorShape({rows, cols}));
+  in.flat<float>().setRandom();
+
+  Node* input_in = test::graph::Constant(g, in);
+  auto nodeBuilder = NodeBuilder(g->NewName("n"), "FusedL2Normalize")
+                         .Input(input_in)
+                         .Attr("T", dtype)
+                         .Attr("axis", 0)
+                         .Attr("epsilon", 1e-12);
+  TF_CHECK_OK(nodeBuilder.Finalize(g, nullptr));
+
+  return g;
+}
+
+#define BM_FusedL2Norm(ROWS, COLS, NTH)                                       \
+  static void BM_FusedL2Norm##_##ROWS##_##COLS##_##NTH##_CPU(int iters) {     \
+    testing::UseRealTime();                                                   \
+    testing::ItemsProcessed(static_cast<int64>(iters) * ROWS * COLS * 3);     \
+    SessionOptions opts;                                                      \
+    opts.config.set_intra_op_parallelism_threads(NTH);                        \
+    test::Benchmark("cpu", FusedL2Normalize(ROWS, COLS), &opts).Run(iters);   \
+  }                                                                           \
+  BENCHMARK(BM_FusedL2Norm##_##ROWS##_##COLS##_##NTH##_CPU);
+
+#define BM_FusedL2Norm_NTH(ROWS, COLS) \
+  BM_FusedL2Norm(ROWS, COLS, 1);       \
+  BM_FusedL2Norm(ROWS, COLS, 4);       \
+  BM_FusedL2Norm(ROWS, COLS, 8);
+
+BM_FusedL2Norm_NTH(1024, 63);
+BM_FusedL2Norm_NTH(1024, 127);
+BM_FusedL2Norm_NTH(1024, 255);
+BM_FusedL2Norm_NTH(1024, 511);
+BM_FusedL2Norm_NTH(1024, 1023);
+
+}  // namespace
+}  // namespace tensorflow
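The expected constant in the forward test is just `1/sqrt(252)` for all-ones rows of length 252 (epsilon is negligible); the float32 kernel output matches it within the test's `1e-6` tolerance:

```python
import numpy as np
print(1.0 / np.sqrt(252.0))  # 0.06299407883487121, ~ the 0.062994122505188 used above
```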
+// )doc"); + +} // namespace tensorflow diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 9fabf845128..2af4ef5083b 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -2115,6 +2115,16 @@ tf_gen_op_wrapper_private_py( ] ) +tf_gen_op_wrapper_private_py( + name = "fused_l2_normalize_ops_gen", + visibility = [ + "//tensorflow:__subpackages__", + ], + deps = [ + "//tensorflow/core:fused_l2_normalize_ops_op_lib" + ] +) + tf_gen_op_wrapper_private_py( name = "image_ops_gen", visibility = ["//learning/brain/python/ops:__pkg__"], @@ -3536,6 +3546,7 @@ py_library( ":sparse_ops", ":util", ":variables", + ":fused_l2_normalize_ops_gen" ], ) @@ -3553,6 +3564,7 @@ py_library( ":sparse_ops", ":tensor_util", "//tensorflow/python/eager:context", + ":fused_l2_normalize_ops_gen" ], ) diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py index 83e111fd1a3..34a6f32ea47 100644 --- a/tensorflow/python/ops/nn_grad.py +++ b/tensorflow/python/ops/nn_grad.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import gen_fused_l2_normalize_ops @ops.RegisterGradient("Conv2DBackpropInput") @@ -1167,3 +1168,15 @@ def _NthElementGrad(op, grad): num_selected = array_ops.expand_dims(math_ops.reduce_sum(indicators, -1), -1) return [math_ops.div(indicators, num_selected) * grad, None] + + +@ops.RegisterGradient("FusedL2Normalize") +def _FusedL2NormalizeGrad(op, grad): + """Return the gradients for FusedL2Normalize""" + + x = op.inputs[0] # pylint: disable=redefined-builtin + axis = op.get_attr("axis") + epsilon = op.get_attr("epsilon") + + return gen_fused_l2_normalize_ops.fused_l2_normalize_grad( + grad, x, axis=axis, epsilon=epsilon) \ No newline at end of file diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index b3d4952f500..d3a15134832 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import gen_fused_l2_normalize_ops from tensorflow.python.ops import gen_sparse_ops from tensorflow.python.ops import variables from tensorflow.python.ops.losses import util as losses_util @@ -613,11 +614,11 @@ def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None): A `Tensor` with the same shape as `x`. """ axis = deprecated_argument_lookup("axis", axis, "dim", dim) - return l2_normalize_v2(x, axis, epsilon, name) + return l2_normalize_v2(x, axis, epsilon, name, False) @tf_export("math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize", v1=[]) -def l2_normalize_v2(x, axis=None, epsilon=1e-12, name=None): +def l2_normalize_v2(x, axis=None, epsilon=1e-12, name=None, do_fusion=True): """Normalizes along dimension `axis` using an L2 norm. For a 1-D tensor with `axis = 0`, computes @@ -633,16 +634,45 @@ def l2_normalize_v2(x, axis=None, epsilon=1e-12, name=None): integers. epsilon: A lower bound value for the norm. Will use `sqrt(epsilon)` as the divisor if `norm < sqrt(epsilon)`. + do_fusion: Whether fuse op when doing l2 norm on last axis, enabled by default. name: A name for this operation (optional). Returns: A `Tensor` with the same shape as `x`. 
""" - with ops.name_scope(name, "l2_normalize", [x]) as name: + if do_fusion and x.dtype == dtypes.float32 and ( + axis is None or axis== x.shape.rank - 1): + return fused_l2_normalize(x, epsilon=epsilon, name=name) + else: + with ops.name_scope(name, "l2_normalize", [x]) as name: + x = ops.convert_to_tensor(x, name="x") + square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keepdims=True) + x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon)) + return math_ops.multiply(x, x_inv_norm, name=name) + +def fused_l2_normalize(x, epsilon=1e-12, name=None): + """Normalizes along last dimension using an L2 norm. + + For a 1-D tensor, computes + + output = x / sqrt(max(sum(x**2), epsilon)) + + For `x` with more dimensions, independently normalizes each 1-D slice along + lastdimension. + + Args: + x: A `Tensor`. + epsilon: A lower bound value for the norm. Will use `sqrt(epsilon)` as the + divisor if `norm < sqrt(epsilon)`. + name: A name for this operation (optional). + + Returns: + A `Tensor` with the same shape as `x`. + """ + with ops.name_scope(name, "fused_l2_normalize", [x]) as name: x = ops.convert_to_tensor(x, name="x") - square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keepdims=True) - x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon)) - return math_ops.multiply(x, x_inv_norm, name=name) + return gen_fused_l2_normalize_ops.fused_l2_normalize(x, + epsilon=epsilon, name=name) def _count_nonzero(input_tensor, dtype=dtypes.int64): diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index a3bf2e6b739..44604f883da 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -39,6 +39,7 @@ from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables +from tensorflow.python.ops import gen_fused_l2_normalize_ops import tensorflow.python.ops.nn_grad # pylint: disable=unused-import from tensorflow.python.ops.nn_impl import _compute_sampled_logits from tensorflow.python.platform import test as test_lib @@ -313,6 +314,31 @@ def testL2NormalizeGradient(self): x_shape) print("L2Normalize gradient err = %g " % err) self.assertLess(err, 1e-4) + + @test_util.run_deprecated_v1 + def testFusedL2Normalize(self): + x_shape = [20, 7, 3] + np.random.seed(0) + x_np = np.random.random_sample(x_shape).astype(np.float32) + for dim in [2]: + y_np = self._l2Normalize(x_np, dim) + x_tf = constant_op.constant(x_np, name="x") + y_tf = nn_impl.fused_l2_normalize(x_tf) + self.assertAllClose(y_np, self.evaluate(y_tf)) + + @test_util.run_deprecated_v1 + def testFusedL2NormalizeGradient(self): + x_shape = [20, 7, 3] + np.random.seed(0) + x_np = np.random.random_sample(x_shape).astype(np.float32) + for dim in [2]: + with self.cached_session(): + x_tf = constant_op.constant(x_np, name="x") + y_tf = nn_impl.fused_l2_normalize(x_tf) + err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, + x_shape) + print("FusedL2Normalize gradient err = %g " % err) + self.assertLess(err, 1e-4) class DropoutTest(test_lib.TestCase):