From 1306a1870f2b2cb45666e8b3eeb184a601189f6c Mon Sep 17 00:00:00 2001
From: "Li, Changqing"
Date: Mon, 11 Jul 2022 16:54:22 +0800
Subject: [PATCH 01/17] [Op] Add fused l2 normalize op and grad op.

---
 tensorflow/core/kernels/BUILD                 |  30 ++
 .../kernels/fused_l2_normalize/compile_util.h |  41 +++
 .../fused_l2_normalize_op.cc                  | 319 ++++++++++++++++++
 .../fused_l2_normalize_op_test.cc             |  70 ++++
 tensorflow/core/ops/fused_l2_normalize_ops.cc |  39 +++
 5 files changed, 499 insertions(+)
 create mode 100644 tensorflow/core/kernels/fused_l2_normalize/compile_util.h
 create mode 100644 tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc
 create mode 100644 tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op_test.cc
 create mode 100644 tensorflow/core/ops/fused_l2_normalize_ops.cc

diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index ff8f572c246..a292b6f87f6 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -5403,6 +5403,36 @@ tf_cc_test(
     ],
 )
 
+tf_kernel_library(
+    name = "fused_l2_normalize_ops",
+    srcs = [
+        "fused_l2_normalize/fused_l2_normalize_op.cc",
+    ],
+    hdrs = ["fused_l2_normalize/compile_util.h"],
+    deps = ["//third_party/eigen3"] + DYNAMIC_DEPS + mkl_deps(),
+)
+
+tf_cc_test(
+    name = "fused_l2_normalize_ops_test",
+    size = "small",
+    srcs = ["fused_l2_normalize/fused_l2_normalize_op_test.cc"],
+    deps = [
+        ":fused_l2_normalize_ops",
+        ":ops_testutil",
+        ":ops_util",
+        "//tensorflow/cc:cc_ops",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:tensorflow",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
 tf_kernel_library(
     name = "run_graph_op",
     prefix = "run_graph_op",
diff --git a/tensorflow/core/kernels/fused_l2_normalize/compile_util.h b/tensorflow/core/kernels/fused_l2_normalize/compile_util.h
new file mode 100644
index 00000000000..646b503ce00
--- /dev/null
+++ b/tensorflow/core/kernels/fused_l2_normalize/compile_util.h
@@ -0,0 +1,41 @@
+#ifndef TENSORFLOW_CORE_KERNELS_FUSED_L2_NORMALIZE_COMPILE_UTIL_OP_H_
+#define TENSORFLOW_CORE_KERNELS_FUSED_L2_NORMALIZE_COMPILE_UTIL_OP_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+namespace functor {
+
+#include
+
+// A class for forced loop unrolling at compile time
+template <int N>
+struct compile_time_for {
+  template <typename Lambda, typename... Args>
+  inline static void op(const Lambda& function, Args... args) {
+    compile_time_for<N - 1>::op(function, args...);
+    function(std::integral_constant<int, N - 1>{}, args...);
+  }
+};
+template <>
+struct compile_time_for<1> {
+  template <typename Lambda, typename... Args>
+  inline static void op(const Lambda& function, Args... args) {
+    function(std::integral_constant<int, 0>{}, args...);
+  }
+};
+template <>
+struct compile_time_for<0> {
+  // 0 loops, do nothing
+  template <typename Lambda, typename... Args>
+  inline static void op(const Lambda& function, Args... 
args) { + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_FUSED_L2_NORMALIZE_COMPILE_UTIL_OP_H_ + + diff --git a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc new file mode 100644 index 00000000000..12899927da0 --- /dev/null +++ b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc @@ -0,0 +1,319 @@ +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_var.h" +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/lib/core/threadpool.h" + +#include "ln_util.h" + +#include + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template +class FusedL2NormalizeOp : public OpKernel { +public: + explicit FusedL2NormalizeOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("axis", &axis)); + OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); + } + + ~FusedL2NormalizeOp() {} + + void Compute(OpKernelContext* context) override { + // Grab the input + const Tensor *input_tensor = &context->input(0); + const T *input = input_tensor->flat().data(); + + // To check the input + OP_REQUIRES(context, (input_tensor->dims() >= 2), + errors::InvalidArgument("Input dimension should be >= 2")); + + int64 cols = input_tensor->dim_size(input_tensor->dims() - 1); + int64 rows = 1; + for (int64 i = 0; i < input_tensor->dims() - 1; ++i) { + rows *= input_tensor->dim_size(i); + } + + // Create output tensors + Tensor *output_tensor = NULL; + OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor->shape(), + &output_tensor)); + T *output = output_tensor->flat().data(); + + // Let every thread compute 16 rows to avoid false sharing + #define BLOCK_SIZE 16 + const int64 total_unit = (rows + BLOCK_SIZE - 1) / BLOCK_SIZE; + const int64 unit_cost = BLOCK_SIZE * cols * 50; // assume every element consumes 50 cycles + + auto &worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); + thread::ThreadPool *thread_pool = worker_threads.workers; + + thread_pool->ParallelFor(total_unit, unit_cost, + [&input, &output, rows, cols, this](int64 begin_unit, int64 end_unit) { + auto begin_row = begin_unit * BLOCK_SIZE; + auto end_row = end_unit * BLOCK_SIZE; + if (end_row > rows) { + end_row = rows; + } + forward<8>(input, output, begin_row, end_row, cols); + }); + } + +private: + // temp = tf.math.square(inputs) + // temp = tf.math.reduce_sum(temp, reduction_indices=axis, keepdims=True) + // temp = tf.math.maximum(temp, epsilon) + // temp = tf.math.rsqrt(temp) + // outputs = tf.math.multiply(temp, inputs) + template + void ref_forward(const T* input, T* output, int64 begin_row, int64 end_row, int64 cols) { + for (int64 i = begin_row; i < end_row; ++i) { + T row_sum = 0; + // must be SUM_BLOCK_SIZE block !!! 
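+      // The unrolled loop below assumes cols is a multiple of
+      // SUM_BLOCK_SIZE; otherwise the last block reads past the end of
+      // the row. The AVX-512 forward() path likewise only reduces
+      // complete 16 * SUM_BLOCK_SIZE element blocks, i.e. cols is
+      // expected to be a multiple of 128 when SUM_BLOCK_SIZE == 8.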
+ for (int64 j = 0; j < cols; j += SUM_BLOCK_SIZE) { + T data_0 = input[i * cols + j]; + T data_1 = input[i * cols + j + 1]; + T data_2 = input[i * cols + j + 2]; + T data_3 = input[i * cols + j + 3]; + T data_4 = input[i * cols + j + 4]; + T data_5 = input[i * cols + j + 5]; + T data_6 = input[i * cols + j + 6]; + T data_7 = input[i * cols + j + 7]; + row_sum += data_0 * data_0 + data_1 * data_1 + + data_2 * data_2 + data_3 * data_3 + + data_4 * data_4 + data_5 * data_5 + + data_6 * data_6 + data_7 * data_7; + } + row_sum += epsilon; + row_sum = 1.0 / std::sqrt(row_sum); + for (int64 j = 0; j < cols; ++j) { + output[i * cols + j] = input[i * cols + j] * row_sum; + } + } + } + + template + void forward(const T* input, T* output, int64 begin_row, int64 end_row, int64 cols) { + int64 avx3_block_num = cols >> 7; // cols / 128 + // printf("cols: %d, avx3_block_num: %d\n", cols, avx3_block_num); + for (int64 i = begin_row; i < end_row; ++i) { + float row_sum = 0.0; + for (int64 j = 0; j < avx3_block_num; ++j) { + __m512 inputs[SUM_BLOCK_SIZE]; + auto load = [&](auto idx) { + inputs[idx] = _mm512_loadu_ps(input + cols * i + 16 * SUM_BLOCK_SIZE * j + 16 * idx); + inputs[idx] = _mm512_mul_ps(inputs[idx], inputs[idx]); + }; + functor::compile_time_for::op(load); + __m512 block_sum = reduce_sum_block8_ps(inputs); + row_sum += _mm512_reduce_add_ps(block_sum); + } + row_sum += epsilon; + row_sum = 1.0 / std::sqrt(row_sum); + __m512 row_sums = _mm512_set1_ps(row_sum); + for (int64 j = 0; j < cols; j += 16) { + __m512 inputs = _mm512_loadu_ps(input + cols * i + j); + inputs = _mm512_mul_ps(inputs, row_sums); + _mm512_storeu_ps(output + cols * i + j, inputs); + } + } + } + + // data type: FP32, 16 FP32 per __m512 + // v0: v0_0, v0_1, ..., v0_15 + // v1: v1_0, v1_1, ..., v1_15 + // ... 
+ // v7: v7_0, v7_1, ..., v7_15 + // sum: v_0, v_1, ..., v_15 + inline __m512 reduce_sum_block8_ps(const __m512 (&v)[8]) { + __m512 block_sum = _mm512_add_ps(v[0], v[1]); + block_sum = _mm512_add_ps(block_sum, v[2]); + block_sum = _mm512_add_ps(block_sum, v[3]); + block_sum = _mm512_add_ps(block_sum, v[4]); + block_sum = _mm512_add_ps(block_sum, v[5]); + block_sum = _mm512_add_ps(block_sum, v[6]); + block_sum = _mm512_add_ps(block_sum, v[7]); + return block_sum; + } + +private: + float epsilon; + int32 axis; +}; + +REGISTER_KERNEL_BUILDER(Name("FusedL2Normalize") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + FusedL2NormalizeOp); + + +template +class FusedL2NormalizeGradOp : public OpKernel { +public: + explicit FusedL2NormalizeGradOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("axis", &axis)); + OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); + } + + ~FusedL2NormalizeGradOp() {} + + void Compute(OpKernelContext* context) override { + // Grab the input + const Tensor *y_grad_tensor = &context->input(0); + const Tensor *x_tensor = &context->input(1); + + const T *y_grad = y_grad_tensor->flat().data(); + const T *x = x_tensor->flat().data(); + + int64 cols = x_tensor->dim_size(x_tensor->dims() - 1); + int64 rows = 1; + for (int64 i = 0; i < x_tensor->dims() - 1; ++i) { + rows *= x_tensor->dim_size(i); + } + + // Create output tensors + Tensor *x_grad_tensor = NULL; + OP_REQUIRES_OK(context, context->allocate_output(0, x_tensor->shape(), + &x_grad_tensor)); + T *x_grad = x_grad_tensor->flat().data(); + + // Do it in parallel + #define BLOCK_SIZE 16 + const int64 total_unit = (rows + BLOCK_SIZE - 1) / BLOCK_SIZE; + const int64 unit_cost = BLOCK_SIZE * cols * 50; // assume every element consumes 50 cycles + + auto &worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); + thread::ThreadPool *thread_pool = worker_threads.workers; + + thread_pool->ParallelFor(total_unit, unit_cost, + [&y_grad, &x, &x_grad, rows, cols, this](int64 begin_unit, int64 end_unit) { + auto begin_row = begin_unit * BLOCK_SIZE; + auto end_row = end_unit * BLOCK_SIZE; + if (end_row > rows) { + end_row = rows; + } + backward<8>(y_grad, x, x_grad, begin_row, end_row, cols); + }); + } + +private: + // rvar = tf.math.rsqrt(tf.math.reduce_sum(x * x, reduction_indices=1, keepdims=True) + 1e-12) # rsqrt quickly + // sum = tf.math.reduce_sum(y_grad * x, reduction_indices=1, keepdims=True) + // grad_x = y_grad * rvar - x * ((sum * rvar) * (rvar * rvar)) + template + void ref_backward(const float *y_grad, const float *x, float *x_grad, int64 begin_row, int64 end_row, int64 cols) { + for (int64 i = begin_row; i < end_row; ++i) { + int64 new_row = i - begin_row; + T x_row_sum = 0.0; + T y_grad_row_sum = 0.0; + for (int64 j = cols - 1; j > 0; j -= SUM_BLOCK_SIZE) { + T x_0 = x[i * cols + j]; + T x_1 = x[i * cols + j - 1]; + T x_2 = x[i * cols + j - 2]; + T x_3 = x[i * cols + j - 3]; + T x_4 = x[i * cols + j - 4]; + T x_5 = x[i * cols + j - 5]; + T x_6 = x[i * cols + j - 6]; + T x_7 = x[i * cols + j - 7]; + x_row_sum += x_0 * x_0 + x_1 * x_1 + + x_2 * x_2 + x_3 * x_3 + + x_4 * x_4 + x_5 * x_5 + + x_6 * x_6 + x_7 * x_7; + + T y_grad_0 = y_grad[i * cols + j]; + T y_grad_1 = y_grad[i * cols + j - 1]; + T y_grad_2 = y_grad[i * cols + j - 2]; + T y_grad_3 = y_grad[i * cols + j - 3]; + T y_grad_4 = y_grad[i * cols + j - 4]; + T y_grad_5 = y_grad[i * cols + j - 5]; + T y_grad_6 = y_grad[i * cols + j - 6]; + T y_grad_7 = y_grad[i * cols + j - 7]; + 
y_grad_row_sum += x_0 * y_grad_0 + x_1 * y_grad_1 + + x_2 * y_grad_2 + x_3 * y_grad_3 + + x_4 * y_grad_4 + x_5 * y_grad_5 + + x_6 * y_grad_6 + x_7 * y_grad_7; + } + x_row_sum += epsilon; + x_row_sum = 1.0 / std::sqrt(x_row_sum); // rvar + y_grad_row_sum = (y_grad_row_sum * x_row_sum) * (x_row_sum * x_row_sum); + for (int64 j = 0; j < cols; ++j) { + x_grad[i * cols + j] = y_grad[i * cols + j] * x_row_sum + - x[i * cols + j] * y_grad_row_sum; + } + } + } + + template + void backward(const float *y_grad, const float *x, float *x_grad, int64 begin_row, int64 end_row, int64 cols) { + int64 avx3_block_num = cols >> 7; // cols / 128 + // printf("backward cols: %d, avx3_block_num: %d\n", cols, avx3_block_num); + for (int64 i = begin_row; i < end_row; ++i) { + T x_row_sum = 0.0; + T y_grad_row_sum = 0.0; + for (int64 j = 0; j < avx3_block_num; ++j) { + __m512 xs[SUM_BLOCK_SIZE]; + auto x_load = [&](auto idx) { + xs[idx] = _mm512_loadu_ps(x + cols * i + 16 * SUM_BLOCK_SIZE * j + 16 * idx); + xs[idx] = _mm512_mul_ps(xs[idx], xs[idx]); + }; + functor::compile_time_for::op(x_load); + __m512 x_block_sum = reduce_sum_block8_ps(xs); + x_row_sum += _mm512_reduce_add_ps(x_block_sum); + + __m512 y_grads[SUM_BLOCK_SIZE]; + auto y_grad_load = [&](auto idx) { + y_grads[idx] = _mm512_loadu_ps(y_grad + cols * i + 16 * SUM_BLOCK_SIZE * j + 16 * idx); + xs[idx] = _mm512_loadu_ps(x + cols * i + 16 * SUM_BLOCK_SIZE * j + 16 * idx); + y_grads[idx] = _mm512_mul_ps(y_grads[idx], xs[idx]); + }; + functor::compile_time_for::op(y_grad_load); + __m512 y_grad_block_sum = reduce_sum_block8_ps(y_grads); + y_grad_row_sum += _mm512_reduce_add_ps(y_grad_block_sum); + } + x_row_sum += epsilon; + x_row_sum = 1.0 / std::sqrt(x_row_sum); + y_grad_row_sum = (y_grad_row_sum * x_row_sum) * (x_row_sum * x_row_sum); + __m512 x_row_sums = _mm512_set1_ps(x_row_sum); + __m512 y_grad_row_sums = _mm512_set1_ps(y_grad_row_sum); + for (int64 j = 0; j < cols; j += 16) { + __m512 y_grads = _mm512_loadu_ps(y_grad + cols * i + j); + __m512 xs = _mm512_loadu_ps(x + cols * i + j); + y_grads = _mm512_mul_ps(y_grads, x_row_sums); + xs = _mm512_mul_ps(xs, y_grad_row_sums); + y_grads = _mm512_sub_ps(y_grads, xs); + _mm512_storeu_ps(x_grad + cols * i + j, y_grads); + } + } + } + + inline __m512 reduce_sum_block8_ps(const __m512 (&v)[8]) { + __m512 block_sum = _mm512_add_ps(v[0], v[1]); + block_sum = _mm512_add_ps(block_sum, v[2]); + block_sum = _mm512_add_ps(block_sum, v[3]); + block_sum = _mm512_add_ps(block_sum, v[4]); + block_sum = _mm512_add_ps(block_sum, v[5]); + block_sum = _mm512_add_ps(block_sum, v[6]); + block_sum = _mm512_add_ps(block_sum, v[7]); + return block_sum; + } + +private: + float epsilon; + int32 axis; +}; + +REGISTER_KERNEL_BUILDER(Name("FusedL2NormalizeGrad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + FusedL2NormalizeGradOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op_test.cc b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op_test.cc new file mode 100644 index 00000000000..11265b057c6 --- /dev/null +++ b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op_test.cc @@ -0,0 +1,70 @@ +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/conv_ops_gpu.h" +#include 
"tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/cc/ops/standard_ops.h" + +namespace tensorflow { +namespace { + +class FusedL2NormalizeOpTest : public OpsTestBase { + protected: + void MakeOpAndSetDevice(Device device, DataType dtype, int axis, int epsilon) { + TF_EXPECT_OK(NodeDefBuilder("fused_l2_normalize", + "FusedL2Normalize") + .Attr("T", dtype) + .Attr("axis", axis) + .Attr("epsilon", epsilon) + .Input(FakeInput(DT_FLOAT)) + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + } +}; + +TEST_F(FusedL2NormalizeOpTest, 2Dims_Float) { + const int rows = 4; + const int cols = 16; + + MakeOpAndSetDevice(Device::CPU, DT_FLOAT, 1, 1e-12); + + // emb_shards + AddInputFromArray(TensorShape({rows, cols}), { + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}); + + TF_ASSERT_OK(RunOpKernel()); + TF_EXPECT_OK(device_->Sync()); + + { + Tensor expected_output(allocator(), DT_FLOAT, + TensorShape({rows, cols})); + test::FillValues(&expected_output, { + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}); + test::ExpectTensorNear(expected_output, *GetOutput(0), 1e-6); + } +} + +//----------------------------------------------------------------------------// +// Performance benchmarks // +//----------------------------------------------------------------------------// + diff --git a/tensorflow/core/ops/fused_l2_normalize_ops.cc b/tensorflow/core/ops/fused_l2_normalize_ops.cc new file mode 100644 index 00000000000..82ca2406056 --- /dev/null +++ b/tensorflow/core/ops/fused_l2_normalize_ops.cc @@ -0,0 +1,39 @@ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +using shape_inference::DimensionHandle; +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; + +REGISTER_OP("FusedL2Normalize") + .Input("x: T") + .Output("y: T") + .Attr("T: {float}") + .Attr("axis: int = 1") + .Attr("epsilon: float = 1e-12") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext *c) { + c->set_output(0, c->input(0)); + return Status::OK(); + }) + .Doc(R"doc( +FusedL2Normalize ops. + )doc"); + +REGISTER_OP("FusedL2NormalizeGrad") + .Input("y_grad: T") + .Input("x: T") + .Output("x_grad: T") + .Attr("T: {float}") + .Attr("axis: int = 1") + .Attr("epsilon: float = 1e-12") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext *c) { + c->set_output(0, c->input(0)); + return Status::OK(); + }) + .Doc(R"doc( +FusedL2NormalizeGrad ops. + )doc"); + +} // namespace tensorflow From 5bfa60383079f285fce299f6c9d7bafdc126d698 Mon Sep 17 00:00:00 2001 From: marvinYu Date: Fri, 22 Apr 2022 16:12:00 +0800 Subject: [PATCH 02/17] [Op] fix compile issue. 
bazel test --action_env=TF_CPP_MIN_VLOG_LEVEL=1 --action_env=TF_CPP_MIN_LOG_LEVEL=0 --flaky_test_attempts 1 --test_output=all --nocache_test_results --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0 --copt=-march=skylake-avx512 -- //tensorflow/core/kernels:fused_l2_normalize_ops_test --- tensorflow/core/BUILD | 3 +++ .../fused_l2_normalize/fused_l2_normalize_op.cc | 6 +++--- .../fused_l2_normalize/fused_l2_normalize_op_test.cc | 12 +++++++----- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 4745ec6fcb8..fcdc1c865ba 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1181,6 +1181,7 @@ tf_gen_op_libs( "function_ops", "functional_ops", "fused_embedding_ops", + "fused_l2_normalize_ops", "hash_ops", "hash_training_ops", "fuserecv_ops", @@ -1439,6 +1440,7 @@ cc_library( ":function_ops_op_lib", ":functional_ops_op_lib", ":fused_embedding_ops_op_lib", + ":fused_l2_normalize_ops_op_lib", ":fuserecv_ops_op_lib", ":hash_ops_op_lib", ":hash_training_ops_op_lib", @@ -1623,6 +1625,7 @@ cc_library( "//tensorflow/core/kernels:functional_ops", "//tensorflow/core/kernels:fused_embedding_ops", "//tensorflow/core/kernels/data:parquet_dataset_ops", + "//tensorflow/core/kernels:fused_l2_normalize_ops", "//tensorflow/core/kernels:grappler", "//tensorflow/core/kernels:hash_ops", "//tensorflow/core/kernels:histogram_op", diff --git a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc index 12899927da0..fb050bd7707 100644 --- a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc +++ b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc @@ -8,7 +8,7 @@ #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/lib/core/threadpool.h" -#include "ln_util.h" +#include "compile_util.h" #include @@ -16,7 +16,7 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; -template +template class FusedL2NormalizeOp : public OpKernel { public: explicit FusedL2NormalizeOp(OpKernelConstruction* context) @@ -155,7 +155,7 @@ REGISTER_KERNEL_BUILDER(Name("FusedL2Normalize") \ FusedL2NormalizeOp); -template +template class FusedL2NormalizeGradOp : public OpKernel { public: explicit FusedL2NormalizeGradOp(OpKernelConstruction* context) diff --git a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op_test.cc b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op_test.cc index 11265b057c6..f204866f32c 100644 --- a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op_test.cc +++ b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op_test.cc @@ -14,11 +14,12 @@ namespace tensorflow { namespace { +enum class Device { CPU, GPU }; + class FusedL2NormalizeOpTest : public OpsTestBase { protected: - void MakeOpAndSetDevice(Device device, DataType dtype, int axis, int epsilon) { - TF_EXPECT_OK(NodeDefBuilder("fused_l2_normalize", - "FusedL2Normalize") + void MakeOpAndSetDevice(Device device, DataType dtype, int axis, float epsilon) { + TF_EXPECT_OK(NodeDefBuilder("fused_l2_normalize", "FusedL2Normalize") .Attr("T", dtype) .Attr("axis", axis) .Attr("epsilon", epsilon) @@ -32,7 +33,7 @@ TEST_F(FusedL2NormalizeOpTest, 2Dims_Float) { const int rows = 4; const int cols = 16; - MakeOpAndSetDevice(Device::CPU, DT_FLOAT, 1, 1e-12); + MakeOpAndSetDevice(Device::CPU, DT_FLOAT, 0, 1e-12); // emb_shards AddInputFromArray(TensorShape({rows, cols}), { @@ -67,4 +68,5 @@ 
TEST_F(FusedL2NormalizeOpTest, 2Dims_Float) { //----------------------------------------------------------------------------// // Performance benchmarks // //----------------------------------------------------------------------------// - +} +} From 3f32097288894a39b645c7fda3ebff53b9a41ce6 Mon Sep 17 00:00:00 2001 From: marvinYu Date: Sun, 24 Apr 2022 17:05:38 +0800 Subject: [PATCH 03/17] [UT] python API implement. bazel test --action_env=TF_CPP_MIN_VLOG_LEVEL=1 --action_env=TF_CPP_MIN_LOG_LEVEL=0 --flaky_test_attempts 1 --test_output=all --nocache_test_results --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0 --copt=-march=skylake-avx512 -- //tensorflow/python:nn_test --- tensorflow/python/BUILD | 12 + tensorflow/python/ops/nn_grad.py | 13 + tensorflow/python/ops/nn_impl.py | 29 + tensorflow/python/ops/nn_test.py | 3152 +++++++++++++++--------------- 4 files changed, 1643 insertions(+), 1563 deletions(-) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 9fabf845128..2af4ef5083b 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -2115,6 +2115,16 @@ tf_gen_op_wrapper_private_py( ] ) +tf_gen_op_wrapper_private_py( + name = "fused_l2_normalize_ops_gen", + visibility = [ + "//tensorflow:__subpackages__", + ], + deps = [ + "//tensorflow/core:fused_l2_normalize_ops_op_lib" + ] +) + tf_gen_op_wrapper_private_py( name = "image_ops_gen", visibility = ["//learning/brain/python/ops:__pkg__"], @@ -3536,6 +3546,7 @@ py_library( ":sparse_ops", ":util", ":variables", + ":fused_l2_normalize_ops_gen" ], ) @@ -3553,6 +3564,7 @@ py_library( ":sparse_ops", ":tensor_util", "//tensorflow/python/eager:context", + ":fused_l2_normalize_ops_gen" ], ) diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py index 83e111fd1a3..34a6f32ea47 100644 --- a/tensorflow/python/ops/nn_grad.py +++ b/tensorflow/python/ops/nn_grad.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import gen_fused_l2_normalize_ops @ops.RegisterGradient("Conv2DBackpropInput") @@ -1167,3 +1168,15 @@ def _NthElementGrad(op, grad): num_selected = array_ops.expand_dims(math_ops.reduce_sum(indicators, -1), -1) return [math_ops.div(indicators, num_selected) * grad, None] + + +@ops.RegisterGradient("FusedL2Normalize") +def _FusedL2NormalizeGrad(op, grad): + """Return the gradients for FusedL2Normalize""" + + x = op.inputs[0] # pylint: disable=redefined-builtin + axis = op.get_attr("axis") + epsilon = op.get_attr("epsilon") + + return gen_fused_l2_normalize_ops.fused_l2_normalize_grad( + grad, x, axis=axis, epsilon=epsilon) \ No newline at end of file diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index b3d4952f500..d4bfc676621 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import gen_fused_l2_normalize_ops from tensorflow.python.ops import gen_sparse_ops from tensorflow.python.ops import variables from tensorflow.python.ops.losses import util as losses_util @@ -645,6 +646,34 @@ def l2_normalize_v2(x, axis=None, epsilon=1e-12, name=None): return math_ops.multiply(x, x_inv_norm, name=name) +@tf_export(v1=["math.fused_l2_normalize", "linalg.fused_l2_normalize", "nn.fused_l2_normalize"]) +def fused_l2_normalize(x, axis=None, epsilon=1e-12, 
name=None): + """Normalizes along dimension `axis` using an L2 norm. + + For a 1-D tensor with `axis = 0`, computes + + output = x / sqrt(max(sum(x**2), epsilon)) + + For `x` with more dimensions, independently normalizes each 1-D slice along + dimension `axis`. + + Args: + x: A `Tensor`. + axis: Dimension along which to normalize. A scalar or a vector of + integers. + epsilon: A lower bound value for the norm. Will use `sqrt(epsilon)` as the + divisor if `norm < sqrt(epsilon)`. + name: A name for this operation (optional). + + Returns: + A `Tensor` with the same shape as `x`. + """ + with ops.name_scope(name, "fused_l2_normalize", [x]) as name: + x = ops.convert_to_tensor(x, name="x") + return gen_fused_l2_normalize_ops.fused_l2_normalize( + x, axis=axis, epsilon=epsilon, name=name) + + def _count_nonzero(input_tensor, dtype=dtypes.int64): """Same as math_ops.count_nonzero. diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index a3bf2e6b739..6c8bfd77abe 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -39,231 +39,232 @@ from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables +from tensorflow.python.ops import gen_fused_l2_normalize_ops import tensorflow.python.ops.nn_grad # pylint: disable=unused-import from tensorflow.python.ops.nn_impl import _compute_sampled_logits from tensorflow.python.platform import test as test_lib -class ZeroFractionTest(test_lib.TestCase): - - def _ZeroFraction(self, x): - assert x.shape - total_elements = np.prod(x.shape) - nonzeros = np.count_nonzero(x.flatten()) - return 1.0 - nonzeros / total_elements - - @test_util.run_deprecated_v1 - def testZeroFraction(self): - x_shape = [5, 17] - x_np = np.random.randint(0, 2, size=x_shape).astype(np.float32) - y_np = self._ZeroFraction(x_np) - - x_tf = constant_op.constant(x_np) - x_tf.set_shape(x_shape) - y_tf = nn_impl.zero_fraction(x_tf) - y_tf_np = self.evaluate(y_tf) - - eps = 1e-8 - self.assertAllClose(y_tf_np, y_np, eps) - - @test_util.run_deprecated_v1 - def testZeroFractionEmpty(self): - x = np.zeros(0) - y = self.evaluate(nn_impl.zero_fraction(x)) - self.assertTrue(np.isnan(y)) - - @test_util.run_deprecated_v1 - def testZeroFraction2_27Zeros(self): - sparsity = nn_impl.zero_fraction( - array_ops.zeros([int(2**27 * 1.01)], dtype=dtypes.int8)) - self.assertAllClose(1.0, self.evaluate(sparsity)) - - @test_util.run_deprecated_v1 - def testZeroFraction2_27Ones(self): - sparsity = nn_impl.zero_fraction( - array_ops.ones([int(2**27 * 1.01)], dtype=dtypes.int8)) - self.assertAllClose(0.0, self.evaluate(sparsity)) - - @test_util.run_deprecated_v1 - def testUnknownSize(self): - value = array_ops.placeholder(dtype=dtypes.float32) - sparsity = nn_impl.zero_fraction(value) - with self.cached_session() as sess: - self.assertAllClose( - 0.25, - sess.run(sparsity, {value: [[0., 1.], [0.3, 2.]]})) - - -class SoftmaxTest(test_lib.TestCase, parameterized.TestCase): - - def _softmax(self, x): - assert len(x.shape) == 2 - m = x.max(1)[:, np.newaxis] - u = np.exp(x - m) - z = u.sum(1)[:, np.newaxis] - return u / z - - @test_util.run_in_graph_and_eager_modes - def testSoftmax(self): - x_shape = [5, 10] - x_np = np.random.randn(*x_shape).astype(np.float32) - y_np = self._softmax(x_np) - x_tf = constant_op.constant(x_np) - y_tf = nn_ops.softmax_v2(x_tf) - y_tf_last_dim = nn_ops.softmax_v2(x_tf, 1) - y_tf_np = self.evaluate(y_tf) - y_tf_last_dim_np = 
self.evaluate(y_tf_last_dim) - eps = 1e-3 - self.assertAllClose(y_tf_np, y_np, eps) - self.assertAllClose(y_tf_last_dim_np, y_np, eps) - - def testSoftmaxAxes(self): - arr = np.linspace(0., 1, 12).reshape(3, 4) - x_neg_axis = nn_ops.softmax_v2(arr, axis=-2) - y_pos_axis = nn_ops.softmax_v2(arr, axis=0) - z_gt_axis = nn_ops.softmax_v2(arr, axis=0) - x_neg_axis_tf = self.evaluate(x_neg_axis) - y_pos_axis_tf = self.evaluate(y_pos_axis) - z_gt_axis_tf = self.evaluate(z_gt_axis) - eps = 1e-3 - self.assertAllClose(x_neg_axis_tf, y_pos_axis_tf, eps) - self.assertAllClose(y_pos_axis_tf, z_gt_axis_tf, eps) - - def testSoftmaxExtendType(self): - x_shape = [5, 10] - x_np = np.random.randn(*x_shape).astype(np.float32) - - x_f32_tf = constant_op.constant(x_np) - x_bf16_tf = math_ops.cast(x_f32_tf, dtypes.bfloat16) - y_f32_tf = self.evaluate(nn_ops.softmax(x_f32_tf)) - y_bf16_tf = self.evaluate(nn_ops.softmax(x_bf16_tf)) - expected = math_ops.cast(y_f32_tf, dtypes.bfloat16) - tol = x_shape[1] * 1e-3 - self.assertAllClose(y_bf16_tf, expected, rtol=tol, atol=tol) - - @parameterized.parameters(((5, 10),), ((2, 3, 4),)) - @test_util.run_deprecated_v1 - def testGradient(self, x_shape): - x_np = np.random.randn(*x_shape).astype(np.float64) - with self.cached_session(): - x_tf = constant_op.constant(x_np) - y_tf = nn_ops.softmax_v2(x_tf) - err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, - x_shape) - eps = 2e-8 - self.assertLess(err, eps) - - -class LogPoissonLossTest(test_lib.TestCase): - - def _log_poisson_loss(self, x, z, compute_full_loss=False): - lpl = np.exp(x) - z * x - if compute_full_loss: - stirling_approx = z * np.log(z) - z + 0.5 * np.log(2. * np.pi * z) - lpl += np.ma.masked_array(stirling_approx, mask=(z <= 1)).filled(0.) - return lpl - - @test_util.run_in_graph_and_eager_modes - def testLogPoissonLoss(self): - x_shape = [5, 10] - x_np = np.random.randn(*x_shape).astype(np.float32) - z_np = np.random.randint(0, 5, size=x_shape).astype(np.float32) - y_np = self._log_poisson_loss(x_np, z_np, compute_full_loss=False) - y_np_stirling = self._log_poisson_loss(x_np, z_np, compute_full_loss=True) - y_tf = nn_impl.log_poisson_loss(z_np, x_np, compute_full_loss=False) - y_tf_stirling = nn_impl.log_poisson_loss(z_np, x_np, compute_full_loss=True) - y_tf_np = self.evaluate(y_tf) - y_tf_np_stirling = self.evaluate(y_tf_stirling) - eps = 1e-3 - self.assertAllClose(y_tf_np, y_np, eps) - self.assertAllClose(y_tf_np_stirling, y_np_stirling, eps) - - @test_util.run_deprecated_v1 - def testGradient(self): - x_shape = [5, 10] - x_np = np.random.randn(*x_shape).astype(np.float64) - z_np = np.random.randint(0, 5, size=x_shape).astype(np.float64) - with self.cached_session(): - x_tf = constant_op.constant(x_np) - y_tf = nn_impl.log_poisson_loss(z_np, x_tf, compute_full_loss=False) - y_tf_stirling = nn_impl.log_poisson_loss( - z_np, x_tf, compute_full_loss=True) - err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, - x_shape) - err_stirling = gradient_checker.compute_gradient_error( - x_tf, x_shape, y_tf_stirling, x_shape) - eps = 1e-6 - self.assertLess(err, eps) - self.assertLess(err_stirling, eps) - - -class LogSoftmaxTest(test_lib.TestCase, parameterized.TestCase): - - def _log_softmax(self, x): - assert len(x.shape) == 2 - m = x.max(1)[:, np.newaxis] - u = x - m - return u - np.log(np.sum(np.exp(u), 1, keepdims=True)) - - @test_util.run_in_graph_and_eager_modes - def testLogSoftmax(self): - x_shape = [5, 10] - x_np = np.random.randn(*x_shape).astype(np.float32) - y_np = 
self._log_softmax(x_np) - x_tf = constant_op.constant(x_np) - y_tf = nn_ops.log_softmax_v2(x_tf) - y_tf_np = self.evaluate(y_tf) - eps = 1e-3 - self.assertAllClose(y_tf_np, y_np, eps) - - def testLogSoftmaxAxes(self): - arr = np.linspace(0., 1, 12).reshape(3, 4) - x_neg_axis = nn_ops.log_softmax_v2(arr, axis=-2) - y_pos_axis = nn_ops.log_softmax_v2(arr, axis=0) - z_gt_axis = nn_ops.log_softmax_v2(arr, axis=0) - x_neg_axis_tf = self.evaluate(x_neg_axis) - y_pos_axis_tf = self.evaluate(y_pos_axis) - z_gt_axis_tf = self.evaluate(z_gt_axis) - eps = 1e-3 - self.assertAllClose(x_neg_axis_tf, y_pos_axis_tf, eps) - self.assertAllClose(y_pos_axis_tf, z_gt_axis_tf, eps) - - @parameterized.parameters(((5, 10),), ((2, 3, 4),)) - @test_util.run_deprecated_v1 - def testGradient(self, x_shape): - x_np = np.random.randn(*x_shape).astype(np.float64) - with self.cached_session(): - x_tf = constant_op.constant(x_np) - y_tf = nn_ops.log_softmax_v2(x_tf) - err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, - x_shape) - eps = 1e-7 - self.assertLess(err, eps) - - -class L2LossTest(test_lib.TestCase): - - @test_util.run_in_graph_and_eager_modes - def testL2Loss(self): - for dtype in [dtypes.float32, dtypes.float64]: - x = constant_op.constant( - [1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="x", dtype=dtype) - l2loss = nn_ops.l2_loss(x) - value = self.evaluate(l2loss) - self.assertAllClose(7.0, value) - - @test_util.run_deprecated_v1 - def testGradient(self): - x_shape = [20, 7, 3] - np.random.seed(1) # Make it reproducible. - x_val = np.random.random_sample(x_shape).astype(np.float64) - with self.cached_session(): - x = constant_op.constant(x_val, name="x") - output = nn_ops.l2_loss(x) - err = gradient_checker.compute_gradient_error(x, x_shape, output, [1]) - print("L2Loss gradient err = %g " % err) - err_tolerance = 1e-10 - self.assertLess(err, err_tolerance) +# class ZeroFractionTest(test_lib.TestCase): + +# def _ZeroFraction(self, x): +# assert x.shape +# total_elements = np.prod(x.shape) +# nonzeros = np.count_nonzero(x.flatten()) +# return 1.0 - nonzeros / total_elements + +# @test_util.run_deprecated_v1 +# def testZeroFraction(self): +# x_shape = [5, 17] +# x_np = np.random.randint(0, 2, size=x_shape).astype(np.float32) +# y_np = self._ZeroFraction(x_np) + +# x_tf = constant_op.constant(x_np) +# x_tf.set_shape(x_shape) +# y_tf = nn_impl.zero_fraction(x_tf) +# y_tf_np = self.evaluate(y_tf) + +# eps = 1e-8 +# self.assertAllClose(y_tf_np, y_np, eps) + +# @test_util.run_deprecated_v1 +# def testZeroFractionEmpty(self): +# x = np.zeros(0) +# y = self.evaluate(nn_impl.zero_fraction(x)) +# self.assertTrue(np.isnan(y)) + +# @test_util.run_deprecated_v1 +# def testZeroFraction2_27Zeros(self): +# sparsity = nn_impl.zero_fraction( +# array_ops.zeros([int(2**27 * 1.01)], dtype=dtypes.int8)) +# self.assertAllClose(1.0, self.evaluate(sparsity)) + +# @test_util.run_deprecated_v1 +# def testZeroFraction2_27Ones(self): +# sparsity = nn_impl.zero_fraction( +# array_ops.ones([int(2**27 * 1.01)], dtype=dtypes.int8)) +# self.assertAllClose(0.0, self.evaluate(sparsity)) + +# @test_util.run_deprecated_v1 +# def testUnknownSize(self): +# value = array_ops.placeholder(dtype=dtypes.float32) +# sparsity = nn_impl.zero_fraction(value) +# with self.cached_session() as sess: +# self.assertAllClose( +# 0.25, +# sess.run(sparsity, {value: [[0., 1.], [0.3, 2.]]})) + + +# class SoftmaxTest(test_lib.TestCase, parameterized.TestCase): + +# def _softmax(self, x): +# assert len(x.shape) == 2 +# m = x.max(1)[:, np.newaxis] +# u = 
np.exp(x - m) +# z = u.sum(1)[:, np.newaxis] +# return u / z + +# @test_util.run_in_graph_and_eager_modes +# def testSoftmax(self): +# x_shape = [5, 10] +# x_np = np.random.randn(*x_shape).astype(np.float32) +# y_np = self._softmax(x_np) +# x_tf = constant_op.constant(x_np) +# y_tf = nn_ops.softmax_v2(x_tf) +# y_tf_last_dim = nn_ops.softmax_v2(x_tf, 1) +# y_tf_np = self.evaluate(y_tf) +# y_tf_last_dim_np = self.evaluate(y_tf_last_dim) +# eps = 1e-3 +# self.assertAllClose(y_tf_np, y_np, eps) +# self.assertAllClose(y_tf_last_dim_np, y_np, eps) + +# def testSoftmaxAxes(self): +# arr = np.linspace(0., 1, 12).reshape(3, 4) +# x_neg_axis = nn_ops.softmax_v2(arr, axis=-2) +# y_pos_axis = nn_ops.softmax_v2(arr, axis=0) +# z_gt_axis = nn_ops.softmax_v2(arr, axis=0) +# x_neg_axis_tf = self.evaluate(x_neg_axis) +# y_pos_axis_tf = self.evaluate(y_pos_axis) +# z_gt_axis_tf = self.evaluate(z_gt_axis) +# eps = 1e-3 +# self.assertAllClose(x_neg_axis_tf, y_pos_axis_tf, eps) +# self.assertAllClose(y_pos_axis_tf, z_gt_axis_tf, eps) + +# def testSoftmaxExtendType(self): +# x_shape = [5, 10] +# x_np = np.random.randn(*x_shape).astype(np.float32) + +# x_f32_tf = constant_op.constant(x_np) +# x_bf16_tf = math_ops.cast(x_f32_tf, dtypes.bfloat16) +# y_f32_tf = self.evaluate(nn_ops.softmax(x_f32_tf)) +# y_bf16_tf = self.evaluate(nn_ops.softmax(x_bf16_tf)) +# expected = math_ops.cast(y_f32_tf, dtypes.bfloat16) +# tol = x_shape[1] * 1e-3 +# self.assertAllClose(y_bf16_tf, expected, rtol=tol, atol=tol) + +# @parameterized.parameters(((5, 10),), ((2, 3, 4),)) +# @test_util.run_deprecated_v1 +# def testGradient(self, x_shape): +# x_np = np.random.randn(*x_shape).astype(np.float64) +# with self.cached_session(): +# x_tf = constant_op.constant(x_np) +# y_tf = nn_ops.softmax_v2(x_tf) +# err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, +# x_shape) +# eps = 2e-8 +# self.assertLess(err, eps) + + +# class LogPoissonLossTest(test_lib.TestCase): + +# def _log_poisson_loss(self, x, z, compute_full_loss=False): +# lpl = np.exp(x) - z * x +# if compute_full_loss: +# stirling_approx = z * np.log(z) - z + 0.5 * np.log(2. * np.pi * z) +# lpl += np.ma.masked_array(stirling_approx, mask=(z <= 1)).filled(0.) 
+# return lpl + +# @test_util.run_in_graph_and_eager_modes +# def testLogPoissonLoss(self): +# x_shape = [5, 10] +# x_np = np.random.randn(*x_shape).astype(np.float32) +# z_np = np.random.randint(0, 5, size=x_shape).astype(np.float32) +# y_np = self._log_poisson_loss(x_np, z_np, compute_full_loss=False) +# y_np_stirling = self._log_poisson_loss(x_np, z_np, compute_full_loss=True) +# y_tf = nn_impl.log_poisson_loss(z_np, x_np, compute_full_loss=False) +# y_tf_stirling = nn_impl.log_poisson_loss(z_np, x_np, compute_full_loss=True) +# y_tf_np = self.evaluate(y_tf) +# y_tf_np_stirling = self.evaluate(y_tf_stirling) +# eps = 1e-3 +# self.assertAllClose(y_tf_np, y_np, eps) +# self.assertAllClose(y_tf_np_stirling, y_np_stirling, eps) + +# @test_util.run_deprecated_v1 +# def testGradient(self): +# x_shape = [5, 10] +# x_np = np.random.randn(*x_shape).astype(np.float64) +# z_np = np.random.randint(0, 5, size=x_shape).astype(np.float64) +# with self.cached_session(): +# x_tf = constant_op.constant(x_np) +# y_tf = nn_impl.log_poisson_loss(z_np, x_tf, compute_full_loss=False) +# y_tf_stirling = nn_impl.log_poisson_loss( +# z_np, x_tf, compute_full_loss=True) +# err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, +# x_shape) +# err_stirling = gradient_checker.compute_gradient_error( +# x_tf, x_shape, y_tf_stirling, x_shape) +# eps = 1e-6 +# self.assertLess(err, eps) +# self.assertLess(err_stirling, eps) + + +# class LogSoftmaxTest(test_lib.TestCase, parameterized.TestCase): + +# def _log_softmax(self, x): +# assert len(x.shape) == 2 +# m = x.max(1)[:, np.newaxis] +# u = x - m +# return u - np.log(np.sum(np.exp(u), 1, keepdims=True)) + +# @test_util.run_in_graph_and_eager_modes +# def testLogSoftmax(self): +# x_shape = [5, 10] +# x_np = np.random.randn(*x_shape).astype(np.float32) +# y_np = self._log_softmax(x_np) +# x_tf = constant_op.constant(x_np) +# y_tf = nn_ops.log_softmax_v2(x_tf) +# y_tf_np = self.evaluate(y_tf) +# eps = 1e-3 +# self.assertAllClose(y_tf_np, y_np, eps) + +# def testLogSoftmaxAxes(self): +# arr = np.linspace(0., 1, 12).reshape(3, 4) +# x_neg_axis = nn_ops.log_softmax_v2(arr, axis=-2) +# y_pos_axis = nn_ops.log_softmax_v2(arr, axis=0) +# z_gt_axis = nn_ops.log_softmax_v2(arr, axis=0) +# x_neg_axis_tf = self.evaluate(x_neg_axis) +# y_pos_axis_tf = self.evaluate(y_pos_axis) +# z_gt_axis_tf = self.evaluate(z_gt_axis) +# eps = 1e-3 +# self.assertAllClose(x_neg_axis_tf, y_pos_axis_tf, eps) +# self.assertAllClose(y_pos_axis_tf, z_gt_axis_tf, eps) + +# @parameterized.parameters(((5, 10),), ((2, 3, 4),)) +# @test_util.run_deprecated_v1 +# def testGradient(self, x_shape): +# x_np = np.random.randn(*x_shape).astype(np.float64) +# with self.cached_session(): +# x_tf = constant_op.constant(x_np) +# y_tf = nn_ops.log_softmax_v2(x_tf) +# err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, +# x_shape) +# eps = 1e-7 +# self.assertLess(err, eps) + + +# class L2LossTest(test_lib.TestCase): + +# @test_util.run_in_graph_and_eager_modes +# def testL2Loss(self): +# for dtype in [dtypes.float32, dtypes.float64]: +# x = constant_op.constant( +# [1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="x", dtype=dtype) +# l2loss = nn_ops.l2_loss(x) +# value = self.evaluate(l2loss) +# self.assertAllClose(7.0, value) + +# @test_util.run_deprecated_v1 +# def testGradient(self): +# x_shape = [20, 7, 3] +# np.random.seed(1) # Make it reproducible. 
+# x_val = np.random.random_sample(x_shape).astype(np.float64) +# with self.cached_session(): +# x = constant_op.constant(x_val, name="x") +# output = nn_ops.l2_loss(x) +# err = gradient_checker.compute_gradient_error(x, x_shape, output, [1]) +# print("L2Loss gradient err = %g " % err) +# err_tolerance = 1e-10 +# self.assertLess(err, err_tolerance) class L2NormalizeTest(test_lib.TestCase): @@ -278,1402 +279,1427 @@ def _l2Normalize(self, x, dim): norm = np.apply_along_axis(np.linalg.norm, dim, x) return x / np.expand_dims(norm, dim) - @test_util.run_in_graph_and_eager_modes - def testL2Normalize(self): + # @test_util.run_in_graph_and_eager_modes + # def testL2Normalize(self): + # x_shape = [20, 7, 3] + # np.random.seed(1) + # x_np = np.random.random_sample(x_shape).astype(np.float32) + # for dim in range(len(x_shape)): + # y_np = self._l2Normalize(x_np, dim) + # x_tf = constant_op.constant(x_np, name="x") + # y_tf = nn_impl.l2_normalize_v2(x_tf, dim) + # self.assertAllClose(y_np, self.evaluate(y_tf)) + + # @test_util.run_in_graph_and_eager_modes + # def testL2NormalizeDimArray(self): + # x_shape = [20, 7, 3] + # np.random.seed(1) + # x_np = np.random.random_sample(x_shape).astype(np.float32) + # dim = [1, 2] + # y_np = self._l2Normalize(x_np, dim) + # x_tf = constant_op.constant(x_np, name="x") + # y_tf = nn_impl.l2_normalize_v2(x_tf, dim) + # self.assertAllClose(y_np, self.evaluate(y_tf)) + + # @test_util.run_deprecated_v1 + # def testL2NormalizeGradient(self): + # x_shape = [20, 7, 3] + # np.random.seed(1) + # x_np = np.random.random_sample(x_shape).astype(np.float64) + # for dim in range(len(x_shape)): + # with self.cached_session(): + # x_tf = constant_op.constant(x_np, name="x") + # y_tf = nn_impl.l2_normalize_v2(x_tf, dim) + # err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, + # x_shape) + # print("L2Normalize gradient err = %g " % err) + # self.assertLess(err, 1e-4) + + @test_util.run_deprecated_v1 + def testFusedL2Normalize(self): x_shape = [20, 7, 3] np.random.seed(1) x_np = np.random.random_sample(x_shape).astype(np.float32) - for dim in range(len(x_shape)): + for dim in [0]: y_np = self._l2Normalize(x_np, dim) x_tf = constant_op.constant(x_np, name="x") - y_tf = nn_impl.l2_normalize_v2(x_tf, dim) + y_tf = nn_impl.fused_l2_normalize(x_tf, dim) self.assertAllClose(y_np, self.evaluate(y_tf)) - @test_util.run_in_graph_and_eager_modes - def testL2NormalizeDimArray(self): - x_shape = [20, 7, 3] - np.random.seed(1) - x_np = np.random.random_sample(x_shape).astype(np.float32) - dim = [1, 2] - y_np = self._l2Normalize(x_np, dim) - x_tf = constant_op.constant(x_np, name="x") - y_tf = nn_impl.l2_normalize_v2(x_tf, dim) - self.assertAllClose(y_np, self.evaluate(y_tf)) - - @test_util.run_deprecated_v1 - def testL2NormalizeGradient(self): - x_shape = [20, 7, 3] - np.random.seed(1) - x_np = np.random.random_sample(x_shape).astype(np.float64) - for dim in range(len(x_shape)): - with self.cached_session(): - x_tf = constant_op.constant(x_np, name="x") - y_tf = nn_impl.l2_normalize_v2(x_tf, dim) - err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, - x_shape) - print("L2Normalize gradient err = %g " % err) - self.assertLess(err, 1e-4) - - -class DropoutTest(test_lib.TestCase): - - def testDropout(self): - # Runs dropout with 0-1 tensor 10 times, sum the number of ones and validate - # that it is producing approximately the right number of ones over a large - # number of samples, based on the keep probability. 
- x_dim = 40 - y_dim = 30 - num_iter = 10 - for keep_prob in [0.1, 0.5, 0.8]: - t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) - dropout = nn_ops.dropout(t, rate=(1 - keep_prob)) - final_count = 0 - self.assertEqual([x_dim, y_dim], dropout.get_shape()) - for _ in xrange(0, num_iter): - value = self.evaluate(dropout) - final_count += np.count_nonzero(value) - # Verifies that there are only two values: 0 and 1/keep_prob. - sorted_value = np.unique(np.sort(value)) - self.assertEqual(0, sorted_value[0]) - self.assertAllClose(1 / keep_prob, sorted_value[1]) - - # Check that we are in the 15% error range - expected_count = x_dim * y_dim * keep_prob * num_iter - rel_error = math.fabs(final_count - expected_count) / expected_count - print(rel_error) - self.assertTrue(rel_error < 0.15) - - def testShapedDropout(self): - # Runs dropout with 0-1 tensor 10 times, sum the number of ones and validate - # that it is producing approximately the right number of ones over a large - # number of samples, based on the keep probability. This time with shaped - # noise. - x_dim = 40 * 30 - y_dim = 3 - num_iter = 10 - for keep_prob in [0.1, 0.5, 0.8]: - t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) - dropout = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[x_dim, 1]) - self.assertEqual([x_dim, y_dim], dropout.get_shape()) - final_count = 0 - for _ in xrange(0, num_iter): - value = self.evaluate(dropout) - final_count += np.count_nonzero(value) - # Verifies that there are only two values: 0 and 1/keep_prob. - sorted_value = np.unique(np.sort(value)) - self.assertEqual(0, sorted_value[0]) - self.assertAllClose(1 / keep_prob, sorted_value[1]) - - # Check that we are in the 15% error range - expected_count = x_dim * y_dim * keep_prob * num_iter - rel_error = math.fabs(final_count - expected_count) / expected_count - print(rel_error) - self.assertTrue(rel_error < 0.15) - - def testShapedDropoutCorrelation(self): - # Runs a shaped dropout and tests that the correlations are correct. - x_dim = 40 - y_dim = 30 - num_iter = 10 - for keep_prob in [0.1, 0.5, 0.8]: - t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) - dropout = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[x_dim, 1]) - self.assertEqual([x_dim, y_dim], dropout.get_shape()) - for _ in xrange(0, num_iter): - value = self.evaluate(dropout) - # Verifies that each y column as only one type of activation. - for i in xrange(x_dim): - sorted_value = np.unique(np.sort(value[i, :])) - self.assertEqual(sorted_value.size, 1) - - @test_util.run_deprecated_v1 - def testDropoutPlaceholderKeepProb(self): - # Runs dropout with 0-1 tensor 10 times, sum the number of ones and validate - # that it is producing approximately the right number of ones over a large - # number of samples, based on the keep probability. - x_dim = 40 - y_dim = 30 - num_iter = 10 - for keep_prob in [0.1, 0.5, 0.8]: - with self.cached_session(): - t = constant_op.constant( - 1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) - keep_prob_placeholder = array_ops.placeholder(dtypes.float32) - dropout = nn_ops.dropout(t, keep_prob_placeholder) - final_count = 0 - self.assertEqual([x_dim, y_dim], dropout.get_shape()) - for _ in xrange(0, num_iter): - value = dropout.eval(feed_dict={keep_prob_placeholder: keep_prob}) - final_count += np.count_nonzero(value) - # Verifies that there are only two values: 0 and 1/keep_prob. 
- sorted_value = np.unique(np.sort(value)) - self.assertEqual(0, sorted_value[0]) - self.assertAllClose(1 / keep_prob, sorted_value[1]) - # Check that we are in the 15% error range - expected_count = x_dim * y_dim * keep_prob * num_iter - rel_error = math.fabs(final_count - expected_count) / expected_count - print(rel_error) - self.assertTrue(rel_error < 0.15) - - @test_util.run_deprecated_v1 - def testShapedDropoutUnknownShape(self): - x_dim = 40 - y_dim = 30 - keep_prob = 0.5 - x = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) - dropout_x = nn_ops.dropout( - x, - rate=(1 - keep_prob), - noise_shape=array_ops.placeholder(dtypes.int32)) - self.assertEqual(x.get_shape(), dropout_x.get_shape()) - - def testPartialShapedDropout(self): - x_dim = 40 * 30 - y_dim = 3 - num_iter = 10 - for keep_prob in [0.1, 0.5, 0.8]: - t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) - # Set noise_shape=[None, 1] which means [x_dim, 1]. - dropout = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[None, 1]) - self.assertEqual([x_dim, y_dim], dropout.get_shape()) - final_count = 0 - for _ in xrange(0, num_iter): - value = self.evaluate(dropout) - final_count += np.count_nonzero(value) - # Verifies that there are only two values: 0 and 1/keep_prob. - sorted_value = np.unique(np.sort(value)) - self.assertEqual(0, sorted_value[0]) - self.assertAllClose(1 / keep_prob, sorted_value[1]) - - # Check that we are in the 15% error range - expected_count = x_dim * y_dim * keep_prob * num_iter - rel_error = math.fabs(final_count - expected_count) / expected_count - print(rel_error) - self.assertTrue(rel_error < 0.15) - - @test_util.run_deprecated_v1 - def testInvalidKeepProb(self): - x_dim = 40 - y_dim = 30 - t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) - with self.assertRaises(ValueError): - nn_ops.dropout(t, -1.0) - with self.assertRaises(ValueError): - nn_ops.dropout(t, 1.1) - with self.assertRaises(ValueError): - nn_ops.dropout(t, [0.0, 1.0]) - with self.assertRaises(ValueError): - nn_ops.dropout(t, array_ops.placeholder(dtypes.float64)) - with self.assertRaises(ValueError): - nn_ops.dropout(t, array_ops.placeholder(dtypes.float32, shape=[2])) - - @test_util.run_deprecated_v1 - def testInvalidRate(self): - x_dim = 40 - y_dim = 30 - t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) - with self.assertRaises(ValueError): - nn_ops.dropout_v2(t, -1.0) - with self.assertRaises(ValueError): - nn_ops.dropout_v2(t, 1.1) - with self.assertRaises(ValueError): - nn_ops.dropout_v2(t, [0.0, 1.0]) - - def testLargeRate(self): - x_dim = 40 - y_dim = 30 - t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) - _ = nn_ops.dropout_v2(t, 0.9) - - @test_util.run_deprecated_v1 - def testShapedDropoutShapeError(self): - # Runs shaped dropout and verifies an error is thrown on misshapen noise. 
- x_dim = 40 - y_dim = 30 - keep_prob = 0.5 - t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) - with self.assertRaises(ValueError): - _ = nn_ops.dropout( - t, rate=(1 - keep_prob), noise_shape=[x_dim, y_dim + 10]) - with self.assertRaises(ValueError): - _ = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[x_dim, y_dim, 5]) - with self.assertRaises(ValueError): - _ = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[x_dim + 3]) - with self.assertRaises(ValueError): - _ = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[x_dim]) - # test that broadcasting proceeds - _ = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[y_dim]) - _ = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[1, y_dim]) - _ = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[x_dim, 1]) - _ = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[1, 1]) - - def testNoDropoutFast(self): - x = array_ops.zeros((5,)) - y = nn_ops.dropout(x, rate=0) - self.assertTrue(x is y) - - y = nn_ops.dropout_v2(x, rate=0) - self.assertTrue(x is y) - - def testDropoutWithIntegerInputs(self): - x = constant_op.constant([1, 1, 1, 1, 1]) - with self.assertRaises(ValueError): - _ = nn_ops.dropout(x, 0.5) - - -class ComputeSampledLogitsTest(test_lib.TestCase): - - def setUp(self): - self._eps = 1e-3 - - def _GenerateTestData(self, num_classes, dim, batch_size, num_true, labels, - sampled, subtract_log_q): - """Randomly generates input/output data for a single test case. - - This function returns numpy constants for use in a test case. - - Args: - num_classes: An int. The number of embedding classes in the test case. - dim: An int. The dimension of the embedding. - batch_size: An int. The batch size. - num_true: An int. The number of target classes per training example. - labels: A list of batch_size * num_true ints. The target classes. - sampled: A list of indices in [0, num_classes). - subtract_log_q: A bool corresponding to the parameter in - _compute_sampled_logits(). - - Returns: - weights: Embedding weights to use as test input. It is a numpy array - of shape [num_classes, dim] - biases: Embedding biases to use as test input. It is a numpy array - of shape [num_classes]. - hidden_acts: Forward activations of the network to use as test input. - It is a numpy array of shape [batch_size, dim]. - sampled_vals: A tuple based on `sampled` to use as test input in the - format returned by a *_candidate_sampler function. - exp_logits: The output logits expected from _compute_sampled_logits(). - It is a numpy array of shape [batch_size, num_true + len(sampled)]. - exp_labels: The output labels expected from _compute_sampled_logits(). - It is a numpy array of shape [batch_size, num_true + len(sampled)]. 
- """ - weights = np.random.randn(num_classes, dim).astype(np.float32) - biases = np.random.randn(num_classes).astype(np.float32) - hidden_acts = np.random.randn(batch_size, dim).astype(np.float32) - - true_exp = np.full([batch_size, 1], fill_value=0.5, dtype=np.float32) - sampled_exp = np.full([len(sampled)], fill_value=0.5, dtype=np.float32) - sampled_vals = (sampled, true_exp, sampled_exp) - - sampled_w, sampled_b = weights[sampled], biases[sampled] - true_w, true_b = weights[labels], biases[labels] - - true_logits = np.sum( - hidden_acts.reshape((batch_size, 1, dim)) * true_w.reshape( - (batch_size, num_true, dim)), - axis=2) - true_b = true_b.reshape((batch_size, num_true)) - true_logits += true_b - sampled_logits = np.dot(hidden_acts, sampled_w.T) + sampled_b - - if subtract_log_q: - true_logits -= np.log(true_exp) - sampled_logits -= np.log(sampled_exp[np.newaxis, :]) - - exp_logits = np.concatenate([true_logits, sampled_logits], axis=1) - exp_labels = np.hstack((np.ones_like(true_logits) / num_true, - np.zeros_like(sampled_logits))) - - return weights, biases, hidden_acts, sampled_vals, exp_logits, exp_labels - - def _ShardTestEmbeddings(self, weights, biases, num_shards): - """Shards the weights and biases returned by _GenerateTestData. - - Args: - weights: The weights returned by _GenerateTestData. - biases: The biases returned by _GenerateTestData. - num_shards: The number of shards to create. - - Returns: - sharded_weights: A list of size `num_shards` containing all the weights. - sharded_biases: A list of size `num_shards` containing all the biases. - """ - with ops.Graph().as_default() as g: - sharded_weights = variable_scope.get_variable( - "w", - partitioner=partitioned_variables.fixed_size_partitioner(num_shards), - initializer=constant_op.constant(weights)) - sharded_biases = variable_scope.get_variable( - "b", - partitioner=partitioned_variables.fixed_size_partitioner(num_shards), - initializer=constant_op.constant(biases)) - with self.session(graph=g) as sess: - variables.global_variables_initializer().run() - return self.evaluate([list(sharded_weights), list(sharded_biases)]) - - def testShapes(self): - np.random.seed(0) - num_classes = 5 - batch_size = 3 - - for num_true in range(1, 5): - labels = np.random.randint( - low=0, high=num_classes, size=batch_size * num_true) - (weights, biases, hidden_acts, sampled_vals, exp_logits, - exp_labels) = self._GenerateTestData( - num_classes=num_classes, - dim=10, - batch_size=batch_size, - num_true=num_true, - labels=labels, - sampled=[1, 0, 2, 3], - subtract_log_q=False) - logits_tensor, labels_tensor = _compute_sampled_logits( - weights=constant_op.constant(weights), - biases=constant_op.constant(biases), - labels=constant_op.constant( - labels, dtype=dtypes.int64, shape=(batch_size, num_true)), - inputs=constant_op.constant(hidden_acts), - num_sampled=4, - num_classes=num_classes, - num_true=num_true, - sampled_values=sampled_vals, - subtract_log_q=False, - remove_accidental_hits=False, - partition_strategy="div", - name="sampled_logits_basic_num_true_%d" % num_true) - got_logits, got_labels = self.evaluate([logits_tensor, labels_tensor]) - self.assertEqual(exp_logits.shape, got_logits.shape, self._eps) - self.assertEqual(exp_labels.shape, got_labels.shape, self._eps) - - def testBasic(self): - """Without accidental hit removal or subtract_log_q.""" - np.random.seed(0) - num_classes = 5 - batch_size = 3 - - for num_true in range(1, 5): - labels = np.random.randint( - low=0, high=num_classes, size=batch_size * num_true) - 
(weights, biases, hidden_acts, sampled_vals, exp_logits, - exp_labels) = self._GenerateTestData( - num_classes=num_classes, - dim=10, - batch_size=batch_size, - num_true=num_true, - labels=labels, - sampled=[1, 0, 2, 3], - subtract_log_q=False) - logits_tensor, labels_tensor = _compute_sampled_logits( - weights=constant_op.constant(weights), - biases=constant_op.constant(biases), - labels=constant_op.constant( - labels, dtype=dtypes.int64, shape=(batch_size, num_true)), - inputs=constant_op.constant(hidden_acts), - num_sampled=4, - num_classes=num_classes, - num_true=num_true, - sampled_values=sampled_vals, - subtract_log_q=False, - remove_accidental_hits=False, - partition_strategy="div", - name="sampled_logits_basic_num_true_%d" % num_true) - got_logits, got_labels = self.evaluate([logits_tensor, labels_tensor]) - self.assertAllClose(exp_logits, got_logits, self._eps) - self.assertAllClose(exp_labels, got_labels, self._eps) - - def testAccidentalHitRemoval(self): - """With accidental hit removal, no subtract_log_q.""" - np.random.seed(0) - num_classes = 5 - batch_size = 3 - sampled = [1, 0, 2, 3] - - for num_true in range(1, 5): - labels = np.random.randint( - low=0, high=num_classes, size=batch_size * num_true) - (weights, biases, hidden_acts, sampled_vals, _, - _) = self._GenerateTestData( - num_classes=num_classes, - dim=10, - batch_size=batch_size, - num_true=num_true, - labels=labels, - sampled=sampled, - subtract_log_q=False) - logits_tensor, _ = _compute_sampled_logits( - weights=constant_op.constant(weights), - biases=constant_op.constant(biases), - labels=constant_op.constant( - labels, dtype=dtypes.int64, shape=(batch_size, num_true)), - inputs=constant_op.constant(hidden_acts), - num_sampled=len(sampled), - num_classes=num_classes, - num_true=num_true, - sampled_values=sampled_vals, - subtract_log_q=False, - remove_accidental_hits=True, - partition_strategy="div", - name="sampled_logits_accidental_hit_removal_num_true_%d" % num_true) - # Test that the exponentiated logits of accidental hits are near 0. 
- # First we need to find the hits in this random test run: - labels_reshape = labels.reshape((batch_size, num_true)) - got_logits = self.evaluate(logits_tensor) - for row in xrange(batch_size): - row_labels = labels_reshape[row, :] - for col in xrange(len(sampled)): - if sampled[col] in row_labels: - # We need to add the num_true_test offset into logits_* - self.assertNear( - np.exp(got_logits[row, col + num_true]), 0., self._eps) - - def testSubtractLogQ(self): - """With subtract_log_q, no accidental hit removal.""" - np.random.seed(0) - num_classes = 5 - batch_size = 3 - - for num_true in range(1, 5): - labels = np.random.randint( - low=0, high=num_classes, size=batch_size * num_true) - (weights, biases, hidden_acts, sampled_vals, exp_logits, - exp_labels) = self._GenerateTestData( - num_classes=num_classes, - dim=10, - batch_size=batch_size, - num_true=num_true, - labels=labels, - sampled=[1, 0, 2, 3], - subtract_log_q=True) - logits_tensor, labels_tensor = _compute_sampled_logits( - weights=constant_op.constant(weights), - biases=constant_op.constant(biases), - labels=constant_op.constant( - labels, dtype=dtypes.int64, shape=(batch_size, num_true)), - inputs=constant_op.constant(hidden_acts), - num_sampled=4, - num_classes=num_classes, - num_true=num_true, - sampled_values=sampled_vals, - subtract_log_q=True, - remove_accidental_hits=False, - partition_strategy="div", - name="sampled_logits_subtract_log_q_num_true_%d" % num_true) - got_logits, got_labels = self.evaluate([logits_tensor, labels_tensor]) - self.assertAllClose(exp_logits, got_logits, self._eps) - self.assertAllClose(exp_labels, got_labels, self._eps) - - def testSharded(self): - """With sharded weights and sharded biases.""" - np.random.seed(0) - num_classes = 5 - batch_size = 3 - - for num_true in range(1, 5): - labels = np.random.randint( - low=0, high=num_classes, size=batch_size * num_true) - (weights, biases, hidden_acts, sampled_vals, exp_logits, - exp_labels) = self._GenerateTestData( - num_classes=num_classes, - dim=10, - batch_size=batch_size, - num_true=num_true, - labels=labels, - sampled=[1, 0, 2, 3], - subtract_log_q=False) - weight_shards, bias_shards = self._ShardTestEmbeddings( - weights, biases, num_shards=3) - logits_tensor, labels_tensor = _compute_sampled_logits( - weights=[constant_op.constant(shard) for shard in weight_shards], - biases=[constant_op.constant(shard) for shard in bias_shards], - labels=constant_op.constant( - labels, dtype=dtypes.int64, shape=(batch_size, num_true)), - inputs=constant_op.constant(hidden_acts), - num_sampled=4, - num_classes=num_classes, - num_true=num_true, - sampled_values=sampled_vals, - subtract_log_q=False, - remove_accidental_hits=False, - partition_strategy="div", - name="sampled_logits_sharded_num_true_%d" % num_true) - got_logits, got_labels = self.evaluate([logits_tensor, labels_tensor]) - self.assertAllClose(exp_logits, got_logits, self._eps) - self.assertAllClose(exp_labels, got_labels, self._eps) - - def testNCELoss(self): - # A simple test to verify the numerics. - - def _SigmoidCrossEntropyWithLogits(logits, targets): - # logits, targets: float arrays of the same shape. - assert logits.shape == targets.shape - pred = 1. / (1. + np.exp(-logits)) - eps = 0.0001 - pred = np.minimum(np.maximum(pred, eps), 1 - eps) - return -targets * np.log(pred) - (1. - targets) * np.log(1. 
- pred) - - np.random.seed(0) - num_classes = 5 - batch_size = 3 - labels = [0, 1, 2] - (weights, biases, hidden_acts, sampled_vals, exp_logits, - exp_labels) = self._GenerateTestData( - num_classes=num_classes, - dim=10, - batch_size=batch_size, - num_true=1, - labels=labels, - sampled=[1, 0, 2, 3], - subtract_log_q=True) - exp_nce_loss = np.sum( - _SigmoidCrossEntropyWithLogits(exp_logits, exp_labels), 1) - - got_nce_loss = nn_impl.nce_loss_v2( - weights=constant_op.constant(weights), - biases=constant_op.constant(biases), - labels=constant_op.constant(labels, shape=(batch_size, 1)), - inputs=constant_op.constant(hidden_acts), - num_sampled=4, - num_classes=num_classes, - num_true=1, - sampled_values=sampled_vals) - - self.assertAllClose(exp_nce_loss, self.evaluate(got_nce_loss), 1e-4) - - # Test with sharded weights and sharded biases. - weight_shards, bias_shards = self._ShardTestEmbeddings( - weights, biases, num_shards=3) - got_nce_loss = nn_impl.nce_loss_v2( - weights=[constant_op.constant(shard) for shard in weight_shards], - biases=[constant_op.constant(shard) for shard in bias_shards], - labels=constant_op.constant(labels, shape=(batch_size, 1)), - inputs=constant_op.constant(hidden_acts), - num_sampled=4, - num_classes=num_classes, - num_true=1, - sampled_values=sampled_vals) - - self.assertAllClose(exp_nce_loss, self.evaluate(got_nce_loss), 1e-4) - - def testSampledSoftmaxLoss(self): - # A simple test to verify the numerics. - - def _SoftmaxCrossEntropyWithLogits(logits, targets): - # logits, targets: float arrays of the same shape. - assert logits.shape == targets.shape - stable_exp_logits = np.exp( - logits - np.amax(logits, axis=1, keepdims=True)) - pred = stable_exp_logits / np.sum(stable_exp_logits, 1, keepdims=True) - return -np.sum(targets * np.log(pred + 1.0e-20), axis=1) - - np.random.seed(0) - num_classes = 5 - batch_size = 3 - labels = [0, 1, 2] - (weights, biases, hidden_acts, sampled_vals, exp_logits, - exp_labels) = self._GenerateTestData( - num_classes=num_classes, - dim=10, - batch_size=batch_size, - num_true=1, - labels=labels, - sampled=[1, 0, 2, 3], - subtract_log_q=True) - exp_sampled_softmax_loss = _SoftmaxCrossEntropyWithLogits( - exp_logits, exp_labels) - - got_sampled_softmax_loss = nn_impl.sampled_softmax_loss_v2( - weights=constant_op.constant(weights), - biases=constant_op.constant(biases), - labels=constant_op.constant(labels, shape=(batch_size, 1)), - inputs=constant_op.constant(hidden_acts), - num_sampled=4, - num_classes=num_classes, - num_true=1, - sampled_values=sampled_vals, - remove_accidental_hits=False) - - self.assertAllClose(exp_sampled_softmax_loss, - self.evaluate(got_sampled_softmax_loss), 1e-4) - - # Test with sharded weights and sharded biases. - weight_shards, bias_shards = self._ShardTestEmbeddings( - weights, biases, num_shards=3) - got_sampled_softmax_loss = nn_impl.sampled_softmax_loss_v2( - weights=[constant_op.constant(shard) for shard in weight_shards], - biases=[constant_op.constant(shard) for shard in bias_shards], - labels=constant_op.constant(labels, shape=(batch_size, 1)), - inputs=constant_op.constant(hidden_acts), - num_sampled=4, - num_classes=num_classes, - num_true=1, - sampled_values=sampled_vals, - remove_accidental_hits=False) - - self.assertAllClose(exp_sampled_softmax_loss, - self.evaluate(got_sampled_softmax_loss), 1e-4) - - def testSampledSoftmaxLossBf16(self): - # A simple test to verify the numerics for bfloat16. 
- def _SoftmaxCrossEntropyWithLogits(logits, targets): - # logits, targets: float arrays of the same shape. - assert logits.shape == targets.shape - stable_exp_logits = np.exp( - logits - np.amax(logits, axis=1, keepdims=True)) - pred = stable_exp_logits / np.sum(stable_exp_logits, 1, keepdims=True) - return -np.sum(targets * np.log(pred + 1.0e-20), axis=1) - - np.random.seed(0) - num_classes = 5 - batch_size = 3 - labels = [0, 1, 2] - sampled = [1, 0, 2, 3] - (weights, biases, hidden_acts, _, exp_logits, - exp_labels) = self._GenerateTestData( - num_classes=num_classes, - dim=10, - batch_size=batch_size, - num_true=1, - labels=labels, - sampled=sampled, - subtract_log_q=True) - exp_sampled_softmax_loss = _SoftmaxCrossEntropyWithLogits( - exp_logits, exp_labels) - - true_exp_bf16 = np.full([batch_size, 1], - fill_value=0.5, - dtype=dtypes.bfloat16.as_numpy_dtype) - sampled_exp_bf16 = np.full([len(sampled)], - fill_value=0.5, - dtype=dtypes.bfloat16.as_numpy_dtype) - sampled_vals_bf16 = (sampled, true_exp_bf16, sampled_exp_bf16) - - got_sampled_softmax_loss = math_ops.cast( - nn_impl.sampled_softmax_loss_v2( - weights=constant_op.constant(weights, dtype=dtypes.bfloat16), - biases=constant_op.constant(biases, dtype=dtypes.bfloat16), - labels=constant_op.constant( - labels, shape=(batch_size, 1), dtype=dtypes.bfloat16), - inputs=constant_op.constant(hidden_acts, dtype=dtypes.bfloat16), - num_sampled=4, - num_classes=num_classes, - num_true=1, - sampled_values=sampled_vals_bf16, - remove_accidental_hits=False), dtypes.float32) - - self.assertAllClose(exp_sampled_softmax_loss, - self.evaluate(got_sampled_softmax_loss), 1e-1) - - -class GeluTest(test_lib.TestCase): - - def test(self): - - def gelu(x, approximate=False): - if approximate: - return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * - (x + 0.044715 * np.power(x, 3)))) - else: - from scipy.stats import norm # pylint: disable=g-import-not-at-top - return x * norm.cdf(x) - - np.random.seed(1) # Make it reproducible. - x = np.random.randn(3, 4).astype(np.float32) - y = gelu(x) - z = self.evaluate(nn_ops.gelu(constant_op.constant(x))) - self.assertAllClose(y, z) - - y = gelu(x, True) - z = self.evaluate(nn_ops.gelu(constant_op.constant(x), True)) - self.assertAllClose(y, z) - - -class CReluTest(test_lib.TestCase): - - def test(self): - np.random.seed(1) # Make it reproducible. - x = np.random.randn(3, 4).astype(np.float32) - y = np.concatenate([x * (x > 0), -x * (x < 0)], axis=1) - - z = self.evaluate(nn_ops.crelu(constant_op.constant(x))) - self.assertAllClose(y, z, 1e-4) - - -class ReluTest(test_lib.TestCase): - - def test(self): - np.random.seed(1) # Make it reproducible. - x = np.random.randn(3, 4).astype(np.float32) - y = np.maximum(x, 0.0) - - z = self.evaluate(nn_ops.relu(constant_op.constant(x))) - self.assertAllEqual(y, z) - - @test_util.run_deprecated_v1 - def testNaNs(self): - # Test that relu(nan) = nan for various sizes. - for i in range(18): - x = np.zeros(i) + np.nan - with self.cached_session(): - z = nn_ops.relu(constant_op.constant(x)).eval() - self.assertTrue(np.isnan(z).all()) - - -class LeakyReluTest(test_lib.TestCase): - - def testRange(self): - batch_size = 3 - height, width = 4, 4 - np.random.seed(1) # Make it reproducible. 
- inputs = np.random.uniform(size=(batch_size, height, width, 3)).astype( - np.float32) - inputs = constant_op.constant(inputs) - - outputs = nn_ops.leaky_relu(inputs) - self.assertEquals(inputs.shape, outputs.shape) - - inputs, outputs = self.evaluate([inputs, outputs]) - - self.assertGreaterEqual(outputs.min(), 0.0) - self.assertLessEqual(outputs.max(), 1.0) - self.assertAllClose(inputs, outputs) - - @test_util.run_deprecated_v1 - def testValues(self): - for dtype in [np.int32, np.int64, np.float16, np.float32, np.float64]: - np_values = np.array([-2, -1, 0, 1, 2], dtype=dtype) - outputs = nn_ops.leaky_relu(constant_op.constant(np_values)) - - outputs = self.evaluate(outputs) - - tol = 2e-3 if dtype == np.float16 else 1e-6 - self.assertAllClose( - outputs, [-0.4, -0.2, 0.0, 1.0, 2.0], rtol=tol, atol=tol) - - @test_util.run_deprecated_v1 - def testName(self): - np_values = np.array([-2, -1, 0, 1, 2], dtype=np.float64) - outputs_with_name_set = nn_ops.leaky_relu( - constant_op.constant(np_values), - name='test_relu_op') - self.assertEqual(outputs_with_name_set.name, 'test_relu_op:0') - outputs_without_name_set = nn_ops.leaky_relu( - constant_op.constant(np_values)) - self.assertEqual(outputs_without_name_set.name, 'LeakyRelu:0') - - -class SwishTest(test_lib.TestCase): - - @test_util.run_deprecated_v1 - def testValues(self): - np_values = np.array( - [np.linspace(-7.0, 0.0, 100), - np.linspace(0.0, 7.0, 100)], - dtype=np.float32) - tf_values = constant_op.constant(np_values) - actual_tf_outputs = nn_impl.swish(tf_values) - expected_tf_outputs = tf_values * math_ops.sigmoid(tf_values) - - actual_outputs, expected_outputs = self.evaluate( - [actual_tf_outputs, expected_tf_outputs]) - - self.assertAllClose(actual_outputs, expected_outputs) - - @test_util.run_deprecated_v1 - def testGradients(self): - shape = [5, 3, 4] - sigma = 5 - input_values = np.random.randn(*shape) * sigma - x_tf = constant_op.constant(input_values) - y_tf = nn_impl.swish(x_tf) - with self.cached_session(): - err = gradient_checker.compute_gradient_error(x_tf, shape, y_tf, shape) - self.assertLess(err, 1e-4) - - -class MomentsTest(test_lib.TestCase): - - def doOutputTest(self, - input_shape, - moments_axes, - tol=1e-4, - check_gradients=False): - for mu in [0.0, 1.0, 1e3]: - for sigma in [1.0, 0.1]: - for keep_dims in [True, False]: - input_values = np.random.rand(*input_shape) * sigma + mu - expected_mean = np.mean( - input_values, axis=moments_axes, keepdims=keep_dims) - expected_var = np.var( - input_values, axis=moments_axes, keepdims=keep_dims) - with ops.Graph().as_default() as g: - with self.session(graph=g) as sess: - inputs = constant_op.constant( - input_values, shape=input_shape, dtype=dtypes.float32) - mean, variance = nn_impl.moments_v2( - inputs, moments_axes, keepdims=keep_dims) - - if check_gradients: - err = gradient_checker.compute_gradient_error( - inputs, input_shape, mean, mean.shape.as_list()) - self.assertLess(err, 1e-3) - err = gradient_checker.compute_gradient_error( - inputs, input_shape, variance, variance.shape.as_list()) - self.assertLess(err, 1e-3) - - # Evaluate. 
- [mean, variance] = self.evaluate([mean, variance]) - # Make sure that there are no NaNs - self.assertFalse(np.isnan(mean).any()) - self.assertFalse(np.isnan(variance).any()) - self.assertAllClose(mean, expected_mean, rtol=tol, atol=tol) - self.assertAllClose(variance, expected_var, rtol=tol, atol=tol) - - def testOutputAndGradient2DInput0(self): - self.doOutputTest((10, 10), (0,), check_gradients=True) - - def testOutputAndGradient2DInput01(self): - self.doOutputTest((10, 10), (0, 1), check_gradients=True) - - def testOutput2DInput0(self): - self.doOutputTest((10, 300), (0,)) - - def testOutput2DInput1(self): - self.doOutputTest((10, 300), (1,)) - - def testOutput2DInput01(self): - self.doOutputTest((10, 300), (0, 1)) - - def testOutput4DInput0(self): - self.doOutputTest((10, 10, 10, 30), (0,)) - - def testOutput4DInput1(self): - self.doOutputTest((10, 10, 10, 30), (1,)) - - def testOutput4DInput3(self): - self.doOutputTest((10, 10, 10, 30), (3,)) - - def testOutput4DInput012(self): - self.doOutputTest((10, 10, 10, 30), (0, 1, 2)) - - def testOutput4DInput123(self): - self.doOutputTest((10, 10, 10, 30), (1, 2, 3)) - - -class DataFormatDimMapTest(test_lib.TestCase): - - def _test(self, x_val, y_val_expected): - x = constant_op.constant(x_val) - y = nn_ops.data_format_dim_map(x) - - y_val = self.evaluate(y) - self.assertAllEqual(y_val, y_val_expected) - - def test(self): - self._test(0, 0) - self._test(1, 2) - self._test(2, 3) - self._test(3, 1) - self._test(-1, 1) - self._test(-2, 3) - self._test(-3, 2) - self._test(-4, 0) - self._test([1, 3], [2, 1]) - self._test([1, 3, -2], [2, 1, 3]) - self._test([1, -3, -2], [2, 2, 3]) - self._test([[1, -3], [1, -1]], [[2, 2], [2, 1]]) - - def testNHWCtoNCHW(self): - x_val = [1, -3, -2] - y_val_expected = [2, 2, 3] - x = constant_op.constant(x_val) - y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="NCHW") - with test_util.use_gpu(): - y_val = self.evaluate(y) - self.assertAllEqual(y_val, y_val_expected) - - def testNHWCtoHWNC(self): - x_val = [-4, -3, -2, -1, 0, 1, 2, 3] - y_val_expected = [2, 0, 1, 3, 2, 0, 1, 3] - x = constant_op.constant(x_val) - y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="HWNC") - with test_util.use_gpu(): - y_val = self.evaluate(y) - self.assertAllEqual(y_val, y_val_expected) - - def testNHWCtoWHCN(self): - x_val = [-4, -3, -2, -1, 0, 1, 2, 3] - y_val_expected = [3, 1, 0, 2, 3, 1, 0, 2] - x = constant_op.constant(x_val) - y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="WHCN") - with test_util.use_gpu(): - y_val = self.evaluate(y) - self.assertAllEqual(y_val, y_val_expected) - - def testNDHWCtoNCDHW(self): - x_val = [1, -4, -3, -2] - y_val_expected = [2, 2, 3, 4] - x = constant_op.constant(x_val) - y = nn_ops.data_format_dim_map(x, src_format="NDHWC", dst_format="NCDHW") - with test_util.use_gpu(): - y_val = self.evaluate(y) - self.assertAllEqual(y_val, y_val_expected) - - def testNDHWCtoDHWNC(self): - x_val = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4] - y_val_expected = [3, 0, 1, 2, 4, 3, 0, 1, 2, 4] - x = constant_op.constant(x_val) - y = nn_ops.data_format_dim_map(x, src_format="NDHWC", dst_format="DHWNC") - with test_util.use_gpu(): - y_val = self.evaluate(y) - self.assertAllEqual(y_val, y_val_expected) - - def testDNHWCtoWHDCN(self): - x_val = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4] - y_val_expected = [4, 2, 1, 0, 3, 4, 2, 1, 0, 3] - x = constant_op.constant(x_val) - y = nn_ops.data_format_dim_map(x, src_format="NDHWC", dst_format="WHDCN") - with test_util.use_gpu(): - y_val = 
self.evaluate(y) - self.assertAllEqual(y_val, y_val_expected) - - def testArbitraryASCII(self): - x_val = [-4, -3, -2, -1, 0, 1, 2, 3] - y_val_expected = [3, 2, 1, 0, 3, 2, 1, 0] - x = constant_op.constant(x_val) - y = nn_ops.data_format_dim_map(x, src_format="qwer", dst_format="rewq") - with test_util.use_gpu(): - y_val = self.evaluate(y) - self.assertAllEqual(y_val, y_val_expected) - - @test_util.disable_xla("XLA catches the error and rethrows as different one") - def testInvalidLength(self): - x = [-4, -3, -2, -1, 0, 1, 2, 3] - with self.assertRaisesRegex(errors.InvalidArgumentError, - "Source format must be of length 4 or 5"): - op = nn_ops.data_format_dim_map( - x, src_format="12345678", dst_format="87654321") - with test_util.use_gpu(): - self.evaluate(op) - - @test_util.disable_xla("XLA catches the error and rethrows as different one") - def testDuplicateSrc(self): - x = [-4, -3, -2, -1, 0, 1, 2, 3] - with self.assertRaisesRegex( - errors.InvalidArgumentError, - "Destination and source format must determine a permutation"): - op = nn_ops.data_format_dim_map(x, src_format="1233", dst_format="4321") - with test_util.use_gpu(): - self.evaluate(op) - - @test_util.disable_xla("XLA catches the error and rethrows as different one") - def testDuplicateDst(self): - x = [-4, -3, -2, -1, 0, 1, 2, 3] - with self.assertRaisesRegex( - errors.InvalidArgumentError, - "Destination and source format must determine a permutation"): - op = nn_ops.data_format_dim_map(x, src_format="1234", dst_format="3321") - with test_util.use_gpu(): - self.evaluate(op) - - @test_util.disable_xla("XLA catches the error and rethrows as different one") - def testExtraSpecifiers(self): - x = [-4, -3, -2, -1, 0, 1, 2, 3] - with self.assertRaisesRegex( - errors.InvalidArgumentError, - "Destination and source format must determine a permutation"): - op = nn_ops.data_format_dim_map(x, src_format="1234", dst_format="5321") - with test_util.use_gpu(): - self.evaluate(op) - - -class DataFormatVectorPermuteTest(test_lib.TestCase): - - def testNHWCToNCHW(self): - x_val = [7, 4, 9, 3] - x = constant_op.constant(x_val) - y = nn_ops.data_format_vec_permute(x) - with test_util.use_gpu(): - y_val = self.evaluate(y) - self.assertAllEqual(y_val, [7, 3, 4, 9]) - - def testNCHWToNHWC(self): - x_val = [7, 4, 9, 3] - x = constant_op.constant(x_val) - y = nn_ops.data_format_vec_permute(x, src_format="NCHW", dst_format="NHWC") - with test_util.use_gpu(): - y_val = self.evaluate(y) - self.assertAllEqual(y_val, [7, 9, 3, 4]) - - def testNHWCToHWNC(self): - x_val = [7, 4, 9, 3] - x = constant_op.constant(x_val) - y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="HWNC") - with test_util.use_gpu(): - y_val = self.evaluate(y) - self.assertAllEqual(y_val, [4, 9, 7, 3]) - - def testHWNCToNHWC(self): - x_val = [7, 4, 9, 3] - x = constant_op.constant(x_val) - y = nn_ops.data_format_vec_permute(x, src_format="HWNC", dst_format="NHWC") - with test_util.use_gpu(): - y_val = self.evaluate(y) - self.assertAllEqual(y_val, [9, 7, 4, 3]) - - def testNHWCToNCHW2D(self): - x_val = [[7, 4], [9, 3], [4, 5], [5, 1]] - x = constant_op.constant(x_val) - y = nn_ops.data_format_vec_permute(x) - with test_util.use_gpu(): - y_val = self.evaluate(y) - self.assertAllEqual(y_val, [[7, 4], [5, 1], [9, 3], [4, 5]]) - - def testNHWCToHWNC2D(self): - x_val = [[7, 4], [9, 3], [4, 5], [5, 1]] - x = constant_op.constant(x_val) - y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="HWNC") - with test_util.use_gpu(): - y_val = self.evaluate(y) - 
self.assertAllEqual(y_val, [[9, 3], [4, 5], [7, 4], [5, 1]]) - - def testHWNCToNHWC2D(self): - x_val = [[7, 4], [9, 3], [4, 5], [5, 1]] - x = constant_op.constant(x_val) - y = nn_ops.data_format_vec_permute(x, src_format="HWNC", dst_format="NHWC") - with test_util.use_gpu(): - y_val = self.evaluate(y) - self.assertAllEqual(y_val, [[4, 5], [7, 4], [9, 3], [5, 1]]) - - def testNCHWToNHWC2D(self): - x_val = [[7, 4], [9, 3], [4, 5], [5, 1]] - x = constant_op.constant(x_val) - y = nn_ops.data_format_vec_permute(x, src_format="NCHW", dst_format="NHWC") - with test_util.use_gpu(): - y_val = self.evaluate(y) - self.assertAllEqual(y_val, [[7, 4], [4, 5], [5, 1], [9, 3]]) - - @test_util.disable_xla("XLA catches the error and rethrows as different one") - def testInvalidLength(self): - x = [0, 1, 2, 3] - with self.assertRaisesRegex(errors.InvalidArgumentError, - "Source format must be of length 4 or 5"): - op = nn_ops.data_format_vec_permute( - x, src_format="12345678", dst_format="87654321") - with test_util.use_gpu(): - self.evaluate(op) - - @test_util.disable_xla("XLA catches the error and rethrows as different one") - def testDuplicateSrc(self): - x = [0, 1, 2, 3] - with self.assertRaisesRegex( - errors.InvalidArgumentError, - "Destination and source format must determine a permutation"): - op = nn_ops.data_format_vec_permute( - x, src_format="1233", dst_format="4321") - with test_util.use_gpu(): - self.evaluate(op) - - @test_util.disable_xla("XLA catches the error and rethrows as different one") - def testDuplicateDst(self): - x = [0, 1, 2, 3] - with self.assertRaisesRegex( - errors.InvalidArgumentError, - "Destination and source format must determine a permutation"): - op = nn_ops.data_format_vec_permute( - x, src_format="1234", dst_format="3321") - with test_util.use_gpu(): - self.evaluate(op) - - @test_util.disable_xla("XLA catches the error and rethrows as different one") - def testExtraSpecifiers(self): - x = [0, 1, 2, 3] - with self.assertRaisesRegex( - errors.InvalidArgumentError, - "Destination and source format must determine a permutation"): - op = nn_ops.data_format_vec_permute( - x, src_format="1234", dst_format="5321") - with test_util.use_gpu(): - self.evaluate(op) - - -@test_util.run_all_in_graph_and_eager_modes -class AvgPoolTest(test_lib.TestCase): - - def test1DTensor(self): - x = array_ops.ones([3, 6, 5]) - ksize = 2 - strides = 2 - - y1 = nn_ops.avg_pool_v2(x, ksize, strides, "SAME") - y2 = nn_ops.avg_pool1d(x, ksize, strides, "SAME") - - self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) - - def test1DNumpy(self): - # explicilty use float32 for ROCm, as MIOpen does not yet support float64 - # np.ones defaults to using float64 when dtype is not explicitly specified - dtype = np.float32 if test_lib.is_built_with_rocm() else np.float64 - x = np.ones([3, 6, 5], dtype=dtype) - ksize = 2 - strides = 2 - - y1 = nn_ops.avg_pool_v2(x, ksize, strides, "SAME") - y2 = nn_ops.avg_pool1d(x, ksize, strides, "SAME") - - self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) - - def test1DNumpyWithGolden(self): - dtype = np.float32 if test_lib.is_built_with_rocm() else np.float64 - x = np.array([[[3], [6], [5]], - [[1], [0], [1]]], dtype=dtype) - ksize = 2 - strides = 1 - y = nn_ops.avg_pool1d(x, ksize, strides, "SAME") - expected_y = np.array([[[4.5], [5.5], [5.0]], - [[0.5], [0.5], [1.0]]], dtype=dtype) - self.assertAllEqual(self.evaluate(y), expected_y) - - def test2DTensor(self): - x = array_ops.ones([3, 6, 6, 5]) - ksize = 2 - strides = 2 - - y1 = nn_ops.avg_pool_v2(x, ksize, 
strides, "SAME") - y2 = nn_ops.avg_pool(x, ksize, strides, "SAME") - - self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) - - def test2DNumpy(self): - # explicilty use float32 for ROCm, as MIOpen does not yet support float64 - # np.ones defaults to using float64 when dtype is not explicitly specified - dtype = np.float32 if test_lib.is_built_with_rocm() else np.float64 - x = np.ones([3, 6, 6, 5], dtype=dtype) - ksize = 2 - strides = 2 - - y1 = nn_ops.avg_pool_v2(x, ksize, strides, "SAME") - y2 = nn_ops.avg_pool(x, ksize, strides, "SAME") - - self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) - - def test3DTensor(self): - if test_lib.is_built_with_rocm(): - self.skipTest("Pooling with 3D tensors is not supported in ROCm") - x = array_ops.ones([3, 7, 6, 6, 5]) - ksize = 2 - strides = 2 - - y1 = nn_ops.avg_pool_v2(x, ksize, strides, "SAME") - y2 = nn_ops.avg_pool3d(x, ksize, strides, "SAME") - - self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) - - def test3DNumpy(self): - if test_lib.is_built_with_rocm(): - self.skipTest("Pooling with 3D tensors is not supported in ROCm") - x = np.ones([3, 7, 6, 6, 5], dtype=np.float32) - ksize = 2 - strides = 2 - - y1 = nn_ops.avg_pool_v2(x, ksize, strides, "SAME") - y2 = nn_ops.avg_pool3d(x, ksize, strides, "SAME") - - self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) - - -@test_util.run_all_in_graph_and_eager_modes -class MaxPoolTest(test_lib.TestCase): - - def test1DTensor(self): - x = array_ops.ones([3, 6, 5]) - ksize = 2 - strides = 2 - - y1 = nn_ops.max_pool_v2(x, ksize, strides, "SAME") - y2 = nn_ops.max_pool1d(x, ksize, strides, "SAME") - - self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) + # @test_util.run_deprecated_v1 + # def testFusedL2NormalizeGradient(self): + # x_shape = [20, 7, 3] + # np.random.seed(1) + # x_np = np.random.random_sample(x_shape).astype(np.float64) + # for dim in range(len(x_shape)): + # with self.cached_session(): + # x_tf = constant_op.constant(x_np, name="x") + # y_tf = nn_impl.l2_normalize_v2(x_tf, dim) + # err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, + # x_shape) + # print("L2Normalize gradient err = %g " % err) + # self.assertLess(err, 1e-4) + + +# class DropoutTest(test_lib.TestCase): + +# def testDropout(self): +# # Runs dropout with 0-1 tensor 10 times, sum the number of ones and validate +# # that it is producing approximately the right number of ones over a large +# # number of samples, based on the keep probability. +# x_dim = 40 +# y_dim = 30 +# num_iter = 10 +# for keep_prob in [0.1, 0.5, 0.8]: +# t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) +# dropout = nn_ops.dropout(t, rate=(1 - keep_prob)) +# final_count = 0 +# self.assertEqual([x_dim, y_dim], dropout.get_shape()) +# for _ in xrange(0, num_iter): +# value = self.evaluate(dropout) +# final_count += np.count_nonzero(value) +# # Verifies that there are only two values: 0 and 1/keep_prob. 
+# sorted_value = np.unique(np.sort(value)) +# self.assertEqual(0, sorted_value[0]) +# self.assertAllClose(1 / keep_prob, sorted_value[1]) + +# # Check that we are in the 15% error range +# expected_count = x_dim * y_dim * keep_prob * num_iter +# rel_error = math.fabs(final_count - expected_count) / expected_count +# print(rel_error) +# self.assertTrue(rel_error < 0.15) + +# def testShapedDropout(self): +# # Runs dropout with 0-1 tensor 10 times, sum the number of ones and validate +# # that it is producing approximately the right number of ones over a large +# # number of samples, based on the keep probability. This time with shaped +# # noise. +# x_dim = 40 * 30 +# y_dim = 3 +# num_iter = 10 +# for keep_prob in [0.1, 0.5, 0.8]: +# t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) +# dropout = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[x_dim, 1]) +# self.assertEqual([x_dim, y_dim], dropout.get_shape()) +# final_count = 0 +# for _ in xrange(0, num_iter): +# value = self.evaluate(dropout) +# final_count += np.count_nonzero(value) +# # Verifies that there are only two values: 0 and 1/keep_prob. +# sorted_value = np.unique(np.sort(value)) +# self.assertEqual(0, sorted_value[0]) +# self.assertAllClose(1 / keep_prob, sorted_value[1]) + +# # Check that we are in the 15% error range +# expected_count = x_dim * y_dim * keep_prob * num_iter +# rel_error = math.fabs(final_count - expected_count) / expected_count +# print(rel_error) +# self.assertTrue(rel_error < 0.15) + +# def testShapedDropoutCorrelation(self): +# # Runs a shaped dropout and tests that the correlations are correct. +# x_dim = 40 +# y_dim = 30 +# num_iter = 10 +# for keep_prob in [0.1, 0.5, 0.8]: +# t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) +# dropout = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[x_dim, 1]) +# self.assertEqual([x_dim, y_dim], dropout.get_shape()) +# for _ in xrange(0, num_iter): +# value = self.evaluate(dropout) +# # Verifies that each y column as only one type of activation. +# for i in xrange(x_dim): +# sorted_value = np.unique(np.sort(value[i, :])) +# self.assertEqual(sorted_value.size, 1) + +# @test_util.run_deprecated_v1 +# def testDropoutPlaceholderKeepProb(self): +# # Runs dropout with 0-1 tensor 10 times, sum the number of ones and validate +# # that it is producing approximately the right number of ones over a large +# # number of samples, based on the keep probability. +# x_dim = 40 +# y_dim = 30 +# num_iter = 10 +# for keep_prob in [0.1, 0.5, 0.8]: +# with self.cached_session(): +# t = constant_op.constant( +# 1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) +# keep_prob_placeholder = array_ops.placeholder(dtypes.float32) +# dropout = nn_ops.dropout(t, keep_prob_placeholder) +# final_count = 0 +# self.assertEqual([x_dim, y_dim], dropout.get_shape()) +# for _ in xrange(0, num_iter): +# value = dropout.eval(feed_dict={keep_prob_placeholder: keep_prob}) +# final_count += np.count_nonzero(value) +# # Verifies that there are only two values: 0 and 1/keep_prob. 
+# sorted_value = np.unique(np.sort(value)) +# self.assertEqual(0, sorted_value[0]) +# self.assertAllClose(1 / keep_prob, sorted_value[1]) +# # Check that we are in the 15% error range +# expected_count = x_dim * y_dim * keep_prob * num_iter +# rel_error = math.fabs(final_count - expected_count) / expected_count +# print(rel_error) +# self.assertTrue(rel_error < 0.15) + +# @test_util.run_deprecated_v1 +# def testShapedDropoutUnknownShape(self): +# x_dim = 40 +# y_dim = 30 +# keep_prob = 0.5 +# x = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) +# dropout_x = nn_ops.dropout( +# x, +# rate=(1 - keep_prob), +# noise_shape=array_ops.placeholder(dtypes.int32)) +# self.assertEqual(x.get_shape(), dropout_x.get_shape()) + +# def testPartialShapedDropout(self): +# x_dim = 40 * 30 +# y_dim = 3 +# num_iter = 10 +# for keep_prob in [0.1, 0.5, 0.8]: +# t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) +# # Set noise_shape=[None, 1] which means [x_dim, 1]. +# dropout = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[None, 1]) +# self.assertEqual([x_dim, y_dim], dropout.get_shape()) +# final_count = 0 +# for _ in xrange(0, num_iter): +# value = self.evaluate(dropout) +# final_count += np.count_nonzero(value) +# # Verifies that there are only two values: 0 and 1/keep_prob. +# sorted_value = np.unique(np.sort(value)) +# self.assertEqual(0, sorted_value[0]) +# self.assertAllClose(1 / keep_prob, sorted_value[1]) + +# # Check that we are in the 15% error range +# expected_count = x_dim * y_dim * keep_prob * num_iter +# rel_error = math.fabs(final_count - expected_count) / expected_count +# print(rel_error) +# self.assertTrue(rel_error < 0.15) + +# @test_util.run_deprecated_v1 +# def testInvalidKeepProb(self): +# x_dim = 40 +# y_dim = 30 +# t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) +# with self.assertRaises(ValueError): +# nn_ops.dropout(t, -1.0) +# with self.assertRaises(ValueError): +# nn_ops.dropout(t, 1.1) +# with self.assertRaises(ValueError): +# nn_ops.dropout(t, [0.0, 1.0]) +# with self.assertRaises(ValueError): +# nn_ops.dropout(t, array_ops.placeholder(dtypes.float64)) +# with self.assertRaises(ValueError): +# nn_ops.dropout(t, array_ops.placeholder(dtypes.float32, shape=[2])) + +# @test_util.run_deprecated_v1 +# def testInvalidRate(self): +# x_dim = 40 +# y_dim = 30 +# t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) +# with self.assertRaises(ValueError): +# nn_ops.dropout_v2(t, -1.0) +# with self.assertRaises(ValueError): +# nn_ops.dropout_v2(t, 1.1) +# with self.assertRaises(ValueError): +# nn_ops.dropout_v2(t, [0.0, 1.0]) + +# def testLargeRate(self): +# x_dim = 40 +# y_dim = 30 +# t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) +# _ = nn_ops.dropout_v2(t, 0.9) + +# @test_util.run_deprecated_v1 +# def testShapedDropoutShapeError(self): +# # Runs shaped dropout and verifies an error is thrown on misshapen noise. 
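+# # Illustrative sketch (not part of the original test; assumes numpy >= 1.20
+# # for np.broadcast_shapes): the noise_shape values rejected below are exactly
+# # the ones that do not broadcast against the [x_dim, y_dim] = [40, 30] input
+# # under numpy-style broadcasting, while the accepted ones do.
+# for _sketch_ok in ((30,), (1, 30), (40, 1), (1, 1)):
+#   assert np.broadcast_shapes((40, 30), _sketch_ok) == (40, 30)
+# for _sketch_bad in ((40, 40), (40, 30, 5), (43,), (40,)):
+#   try:
+#     np.broadcast_shapes((40, 30), _sketch_bad)
+#     raise AssertionError("unexpectedly broadcastable: %r" % (_sketch_bad,))
+#   except ValueError:
+#     pass  # incompatible, matching the ValueError the test expects below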
+# x_dim = 40 +# y_dim = 30 +# keep_prob = 0.5 +# t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) +# with self.assertRaises(ValueError): +# _ = nn_ops.dropout( +# t, rate=(1 - keep_prob), noise_shape=[x_dim, y_dim + 10]) +# with self.assertRaises(ValueError): +# _ = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[x_dim, y_dim, 5]) +# with self.assertRaises(ValueError): +# _ = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[x_dim + 3]) +# with self.assertRaises(ValueError): +# _ = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[x_dim]) +# # test that broadcasting proceeds +# _ = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[y_dim]) +# _ = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[1, y_dim]) +# _ = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[x_dim, 1]) +# _ = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[1, 1]) + +# def testNoDropoutFast(self): +# x = array_ops.zeros((5,)) +# y = nn_ops.dropout(x, rate=0) +# self.assertTrue(x is y) + +# y = nn_ops.dropout_v2(x, rate=0) +# self.assertTrue(x is y) + +# def testDropoutWithIntegerInputs(self): +# x = constant_op.constant([1, 1, 1, 1, 1]) +# with self.assertRaises(ValueError): +# _ = nn_ops.dropout(x, 0.5) + + +# class ComputeSampledLogitsTest(test_lib.TestCase): + +# def setUp(self): +# self._eps = 1e-3 + +# def _GenerateTestData(self, num_classes, dim, batch_size, num_true, labels, +# sampled, subtract_log_q): +# """Randomly generates input/output data for a single test case. + +# This function returns numpy constants for use in a test case. + +# Args: +# num_classes: An int. The number of embedding classes in the test case. +# dim: An int. The dimension of the embedding. +# batch_size: An int. The batch size. +# num_true: An int. The number of target classes per training example. +# labels: A list of batch_size * num_true ints. The target classes. +# sampled: A list of indices in [0, num_classes). +# subtract_log_q: A bool corresponding to the parameter in +# _compute_sampled_logits(). + +# Returns: +# weights: Embedding weights to use as test input. It is a numpy array +# of shape [num_classes, dim] +# biases: Embedding biases to use as test input. It is a numpy array +# of shape [num_classes]. +# hidden_acts: Forward activations of the network to use as test input. +# It is a numpy array of shape [batch_size, dim]. +# sampled_vals: A tuple based on `sampled` to use as test input in the +# format returned by a *_candidate_sampler function. +# exp_logits: The output logits expected from _compute_sampled_logits(). +# It is a numpy array of shape [batch_size, num_true + len(sampled)]. +# exp_labels: The output labels expected from _compute_sampled_logits(). +# It is a numpy array of shape [batch_size, num_true + len(sampled)]. 
+# """ +# weights = np.random.randn(num_classes, dim).astype(np.float32) +# biases = np.random.randn(num_classes).astype(np.float32) +# hidden_acts = np.random.randn(batch_size, dim).astype(np.float32) + +# true_exp = np.full([batch_size, 1], fill_value=0.5, dtype=np.float32) +# sampled_exp = np.full([len(sampled)], fill_value=0.5, dtype=np.float32) +# sampled_vals = (sampled, true_exp, sampled_exp) + +# sampled_w, sampled_b = weights[sampled], biases[sampled] +# true_w, true_b = weights[labels], biases[labels] + +# true_logits = np.sum( +# hidden_acts.reshape((batch_size, 1, dim)) * true_w.reshape( +# (batch_size, num_true, dim)), +# axis=2) +# true_b = true_b.reshape((batch_size, num_true)) +# true_logits += true_b +# sampled_logits = np.dot(hidden_acts, sampled_w.T) + sampled_b + +# if subtract_log_q: +# true_logits -= np.log(true_exp) +# sampled_logits -= np.log(sampled_exp[np.newaxis, :]) + +# exp_logits = np.concatenate([true_logits, sampled_logits], axis=1) +# exp_labels = np.hstack((np.ones_like(true_logits) / num_true, +# np.zeros_like(sampled_logits))) + +# return weights, biases, hidden_acts, sampled_vals, exp_logits, exp_labels + +# def _ShardTestEmbeddings(self, weights, biases, num_shards): +# """Shards the weights and biases returned by _GenerateTestData. + +# Args: +# weights: The weights returned by _GenerateTestData. +# biases: The biases returned by _GenerateTestData. +# num_shards: The number of shards to create. + +# Returns: +# sharded_weights: A list of size `num_shards` containing all the weights. +# sharded_biases: A list of size `num_shards` containing all the biases. +# """ +# with ops.Graph().as_default() as g: +# sharded_weights = variable_scope.get_variable( +# "w", +# partitioner=partitioned_variables.fixed_size_partitioner(num_shards), +# initializer=constant_op.constant(weights)) +# sharded_biases = variable_scope.get_variable( +# "b", +# partitioner=partitioned_variables.fixed_size_partitioner(num_shards), +# initializer=constant_op.constant(biases)) +# with self.session(graph=g) as sess: +# variables.global_variables_initializer().run() +# return self.evaluate([list(sharded_weights), list(sharded_biases)]) + +# def testShapes(self): +# np.random.seed(0) +# num_classes = 5 +# batch_size = 3 + +# for num_true in range(1, 5): +# labels = np.random.randint( +# low=0, high=num_classes, size=batch_size * num_true) +# (weights, biases, hidden_acts, sampled_vals, exp_logits, +# exp_labels) = self._GenerateTestData( +# num_classes=num_classes, +# dim=10, +# batch_size=batch_size, +# num_true=num_true, +# labels=labels, +# sampled=[1, 0, 2, 3], +# subtract_log_q=False) +# logits_tensor, labels_tensor = _compute_sampled_logits( +# weights=constant_op.constant(weights), +# biases=constant_op.constant(biases), +# labels=constant_op.constant( +# labels, dtype=dtypes.int64, shape=(batch_size, num_true)), +# inputs=constant_op.constant(hidden_acts), +# num_sampled=4, +# num_classes=num_classes, +# num_true=num_true, +# sampled_values=sampled_vals, +# subtract_log_q=False, +# remove_accidental_hits=False, +# partition_strategy="div", +# name="sampled_logits_basic_num_true_%d" % num_true) +# got_logits, got_labels = self.evaluate([logits_tensor, labels_tensor]) +# self.assertEqual(exp_logits.shape, got_logits.shape, self._eps) +# self.assertEqual(exp_labels.shape, got_labels.shape, self._eps) + +# def testBasic(self): +# """Without accidental hit removal or subtract_log_q.""" +# np.random.seed(0) +# num_classes = 5 +# batch_size = 3 + +# for num_true in range(1, 5): +# 
labels = np.random.randint( +# low=0, high=num_classes, size=batch_size * num_true) +# (weights, biases, hidden_acts, sampled_vals, exp_logits, +# exp_labels) = self._GenerateTestData( +# num_classes=num_classes, +# dim=10, +# batch_size=batch_size, +# num_true=num_true, +# labels=labels, +# sampled=[1, 0, 2, 3], +# subtract_log_q=False) +# logits_tensor, labels_tensor = _compute_sampled_logits( +# weights=constant_op.constant(weights), +# biases=constant_op.constant(biases), +# labels=constant_op.constant( +# labels, dtype=dtypes.int64, shape=(batch_size, num_true)), +# inputs=constant_op.constant(hidden_acts), +# num_sampled=4, +# num_classes=num_classes, +# num_true=num_true, +# sampled_values=sampled_vals, +# subtract_log_q=False, +# remove_accidental_hits=False, +# partition_strategy="div", +# name="sampled_logits_basic_num_true_%d" % num_true) +# got_logits, got_labels = self.evaluate([logits_tensor, labels_tensor]) +# self.assertAllClose(exp_logits, got_logits, self._eps) +# self.assertAllClose(exp_labels, got_labels, self._eps) + +# def testAccidentalHitRemoval(self): +# """With accidental hit removal, no subtract_log_q.""" +# np.random.seed(0) +# num_classes = 5 +# batch_size = 3 +# sampled = [1, 0, 2, 3] + +# for num_true in range(1, 5): +# labels = np.random.randint( +# low=0, high=num_classes, size=batch_size * num_true) +# (weights, biases, hidden_acts, sampled_vals, _, +# _) = self._GenerateTestData( +# num_classes=num_classes, +# dim=10, +# batch_size=batch_size, +# num_true=num_true, +# labels=labels, +# sampled=sampled, +# subtract_log_q=False) +# logits_tensor, _ = _compute_sampled_logits( +# weights=constant_op.constant(weights), +# biases=constant_op.constant(biases), +# labels=constant_op.constant( +# labels, dtype=dtypes.int64, shape=(batch_size, num_true)), +# inputs=constant_op.constant(hidden_acts), +# num_sampled=len(sampled), +# num_classes=num_classes, +# num_true=num_true, +# sampled_values=sampled_vals, +# subtract_log_q=False, +# remove_accidental_hits=True, +# partition_strategy="div", +# name="sampled_logits_accidental_hit_removal_num_true_%d" % num_true) +# # Test that the exponentiated logits of accidental hits are near 0. 
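+# # (Added note, hedged: remove_accidental_hits appears to work by adding a very
+# # large negative weight to the logit of any sampled class that collides with a
+# # true label, so its contribution to the loss vanishes; exp() of such a logit
+# # underflows to 0, which is what the check below verifies within self._eps.
+# # The -1e9 constant here is only an illustrative stand-in for that weight.)
+# _sketch_masked_logit = -1e9
+# assert np.exp(_sketch_masked_logit) == 0.0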
+# # First we need to find the hits in this random test run: +# labels_reshape = labels.reshape((batch_size, num_true)) +# got_logits = self.evaluate(logits_tensor) +# for row in xrange(batch_size): +# row_labels = labels_reshape[row, :] +# for col in xrange(len(sampled)): +# if sampled[col] in row_labels: +# # We need to add the num_true_test offset into logits_* +# self.assertNear( +# np.exp(got_logits[row, col + num_true]), 0., self._eps) + +# def testSubtractLogQ(self): +# """With subtract_log_q, no accidental hit removal.""" +# np.random.seed(0) +# num_classes = 5 +# batch_size = 3 + +# for num_true in range(1, 5): +# labels = np.random.randint( +# low=0, high=num_classes, size=batch_size * num_true) +# (weights, biases, hidden_acts, sampled_vals, exp_logits, +# exp_labels) = self._GenerateTestData( +# num_classes=num_classes, +# dim=10, +# batch_size=batch_size, +# num_true=num_true, +# labels=labels, +# sampled=[1, 0, 2, 3], +# subtract_log_q=True) +# logits_tensor, labels_tensor = _compute_sampled_logits( +# weights=constant_op.constant(weights), +# biases=constant_op.constant(biases), +# labels=constant_op.constant( +# labels, dtype=dtypes.int64, shape=(batch_size, num_true)), +# inputs=constant_op.constant(hidden_acts), +# num_sampled=4, +# num_classes=num_classes, +# num_true=num_true, +# sampled_values=sampled_vals, +# subtract_log_q=True, +# remove_accidental_hits=False, +# partition_strategy="div", +# name="sampled_logits_subtract_log_q_num_true_%d" % num_true) +# got_logits, got_labels = self.evaluate([logits_tensor, labels_tensor]) +# self.assertAllClose(exp_logits, got_logits, self._eps) +# self.assertAllClose(exp_labels, got_labels, self._eps) + +# def testSharded(self): +# """With sharded weights and sharded biases.""" +# np.random.seed(0) +# num_classes = 5 +# batch_size = 3 + +# for num_true in range(1, 5): +# labels = np.random.randint( +# low=0, high=num_classes, size=batch_size * num_true) +# (weights, biases, hidden_acts, sampled_vals, exp_logits, +# exp_labels) = self._GenerateTestData( +# num_classes=num_classes, +# dim=10, +# batch_size=batch_size, +# num_true=num_true, +# labels=labels, +# sampled=[1, 0, 2, 3], +# subtract_log_q=False) +# weight_shards, bias_shards = self._ShardTestEmbeddings( +# weights, biases, num_shards=3) +# logits_tensor, labels_tensor = _compute_sampled_logits( +# weights=[constant_op.constant(shard) for shard in weight_shards], +# biases=[constant_op.constant(shard) for shard in bias_shards], +# labels=constant_op.constant( +# labels, dtype=dtypes.int64, shape=(batch_size, num_true)), +# inputs=constant_op.constant(hidden_acts), +# num_sampled=4, +# num_classes=num_classes, +# num_true=num_true, +# sampled_values=sampled_vals, +# subtract_log_q=False, +# remove_accidental_hits=False, +# partition_strategy="div", +# name="sampled_logits_sharded_num_true_%d" % num_true) +# got_logits, got_labels = self.evaluate([logits_tensor, labels_tensor]) +# self.assertAllClose(exp_logits, got_logits, self._eps) +# self.assertAllClose(exp_labels, got_labels, self._eps) + +# def testNCELoss(self): +# # A simple test to verify the numerics. + +# def _SigmoidCrossEntropyWithLogits(logits, targets): +# # logits, targets: float arrays of the same shape. +# assert logits.shape == targets.shape +# pred = 1. / (1. + np.exp(-logits)) +# eps = 0.0001 +# pred = np.minimum(np.maximum(pred, eps), 1 - eps) +# return -targets * np.log(pred) - (1. - targets) * np.log(1. 
- pred) + +# np.random.seed(0) +# num_classes = 5 +# batch_size = 3 +# labels = [0, 1, 2] +# (weights, biases, hidden_acts, sampled_vals, exp_logits, +# exp_labels) = self._GenerateTestData( +# num_classes=num_classes, +# dim=10, +# batch_size=batch_size, +# num_true=1, +# labels=labels, +# sampled=[1, 0, 2, 3], +# subtract_log_q=True) +# exp_nce_loss = np.sum( +# _SigmoidCrossEntropyWithLogits(exp_logits, exp_labels), 1) + +# got_nce_loss = nn_impl.nce_loss_v2( +# weights=constant_op.constant(weights), +# biases=constant_op.constant(biases), +# labels=constant_op.constant(labels, shape=(batch_size, 1)), +# inputs=constant_op.constant(hidden_acts), +# num_sampled=4, +# num_classes=num_classes, +# num_true=1, +# sampled_values=sampled_vals) + +# self.assertAllClose(exp_nce_loss, self.evaluate(got_nce_loss), 1e-4) + +# # Test with sharded weights and sharded biases. +# weight_shards, bias_shards = self._ShardTestEmbeddings( +# weights, biases, num_shards=3) +# got_nce_loss = nn_impl.nce_loss_v2( +# weights=[constant_op.constant(shard) for shard in weight_shards], +# biases=[constant_op.constant(shard) for shard in bias_shards], +# labels=constant_op.constant(labels, shape=(batch_size, 1)), +# inputs=constant_op.constant(hidden_acts), +# num_sampled=4, +# num_classes=num_classes, +# num_true=1, +# sampled_values=sampled_vals) + +# self.assertAllClose(exp_nce_loss, self.evaluate(got_nce_loss), 1e-4) + +# def testSampledSoftmaxLoss(self): +# # A simple test to verify the numerics. + +# def _SoftmaxCrossEntropyWithLogits(logits, targets): +# # logits, targets: float arrays of the same shape. +# assert logits.shape == targets.shape +# stable_exp_logits = np.exp( +# logits - np.amax(logits, axis=1, keepdims=True)) +# pred = stable_exp_logits / np.sum(stable_exp_logits, 1, keepdims=True) +# return -np.sum(targets * np.log(pred + 1.0e-20), axis=1) + +# np.random.seed(0) +# num_classes = 5 +# batch_size = 3 +# labels = [0, 1, 2] +# (weights, biases, hidden_acts, sampled_vals, exp_logits, +# exp_labels) = self._GenerateTestData( +# num_classes=num_classes, +# dim=10, +# batch_size=batch_size, +# num_true=1, +# labels=labels, +# sampled=[1, 0, 2, 3], +# subtract_log_q=True) +# exp_sampled_softmax_loss = _SoftmaxCrossEntropyWithLogits( +# exp_logits, exp_labels) + +# got_sampled_softmax_loss = nn_impl.sampled_softmax_loss_v2( +# weights=constant_op.constant(weights), +# biases=constant_op.constant(biases), +# labels=constant_op.constant(labels, shape=(batch_size, 1)), +# inputs=constant_op.constant(hidden_acts), +# num_sampled=4, +# num_classes=num_classes, +# num_true=1, +# sampled_values=sampled_vals, +# remove_accidental_hits=False) + +# self.assertAllClose(exp_sampled_softmax_loss, +# self.evaluate(got_sampled_softmax_loss), 1e-4) + +# # Test with sharded weights and sharded biases. +# weight_shards, bias_shards = self._ShardTestEmbeddings( +# weights, biases, num_shards=3) +# got_sampled_softmax_loss = nn_impl.sampled_softmax_loss_v2( +# weights=[constant_op.constant(shard) for shard in weight_shards], +# biases=[constant_op.constant(shard) for shard in bias_shards], +# labels=constant_op.constant(labels, shape=(batch_size, 1)), +# inputs=constant_op.constant(hidden_acts), +# num_sampled=4, +# num_classes=num_classes, +# num_true=1, +# sampled_values=sampled_vals, +# remove_accidental_hits=False) + +# self.assertAllClose(exp_sampled_softmax_loss, +# self.evaluate(got_sampled_softmax_loss), 1e-4) + +# def testSampledSoftmaxLossBf16(self): +# # A simple test to verify the numerics for bfloat16. 
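+# # (Illustrative sketch, not part of the original test: bfloat16 stores only 7
+# # mantissa bits, roughly 2-3 significant decimal digits, which is why the
+# # tolerance below is 1e-1 instead of the 1e-4 used in the float32 tests.
+# # Assumes dtypes.bfloat16.as_numpy_dtype can be called directly as a numpy
+# # scalar type, which may differ across TF versions.)
+# _sketch_bf16 = dtypes.bfloat16.as_numpy_dtype
+# # 1.2345678 rounds to the nearest representable bfloat16 value, 1.234375:
+# # everything past the third significant digit is lost.
+# assert np.float32(_sketch_bf16(1.2345678)) == np.float32(1.234375)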
+# def _SoftmaxCrossEntropyWithLogits(logits, targets): +# # logits, targets: float arrays of the same shape. +# assert logits.shape == targets.shape +# stable_exp_logits = np.exp( +# logits - np.amax(logits, axis=1, keepdims=True)) +# pred = stable_exp_logits / np.sum(stable_exp_logits, 1, keepdims=True) +# return -np.sum(targets * np.log(pred + 1.0e-20), axis=1) + +# np.random.seed(0) +# num_classes = 5 +# batch_size = 3 +# labels = [0, 1, 2] +# sampled = [1, 0, 2, 3] +# (weights, biases, hidden_acts, _, exp_logits, +# exp_labels) = self._GenerateTestData( +# num_classes=num_classes, +# dim=10, +# batch_size=batch_size, +# num_true=1, +# labels=labels, +# sampled=sampled, +# subtract_log_q=True) +# exp_sampled_softmax_loss = _SoftmaxCrossEntropyWithLogits( +# exp_logits, exp_labels) + +# true_exp_bf16 = np.full([batch_size, 1], +# fill_value=0.5, +# dtype=dtypes.bfloat16.as_numpy_dtype) +# sampled_exp_bf16 = np.full([len(sampled)], +# fill_value=0.5, +# dtype=dtypes.bfloat16.as_numpy_dtype) +# sampled_vals_bf16 = (sampled, true_exp_bf16, sampled_exp_bf16) + +# got_sampled_softmax_loss = math_ops.cast( +# nn_impl.sampled_softmax_loss_v2( +# weights=constant_op.constant(weights, dtype=dtypes.bfloat16), +# biases=constant_op.constant(biases, dtype=dtypes.bfloat16), +# labels=constant_op.constant( +# labels, shape=(batch_size, 1), dtype=dtypes.bfloat16), +# inputs=constant_op.constant(hidden_acts, dtype=dtypes.bfloat16), +# num_sampled=4, +# num_classes=num_classes, +# num_true=1, +# sampled_values=sampled_vals_bf16, +# remove_accidental_hits=False), dtypes.float32) + +# self.assertAllClose(exp_sampled_softmax_loss, +# self.evaluate(got_sampled_softmax_loss), 1e-1) + + +# class GeluTest(test_lib.TestCase): + +# def test(self): + +# def gelu(x, approximate=False): +# if approximate: +# return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * +# (x + 0.044715 * np.power(x, 3)))) +# else: +# from scipy.stats import norm # pylint: disable=g-import-not-at-top +# return x * norm.cdf(x) + +# np.random.seed(1) # Make it reproducible. +# x = np.random.randn(3, 4).astype(np.float32) +# y = gelu(x) +# z = self.evaluate(nn_ops.gelu(constant_op.constant(x))) +# self.assertAllClose(y, z) + +# y = gelu(x, True) +# z = self.evaluate(nn_ops.gelu(constant_op.constant(x), True)) +# self.assertAllClose(y, z) + + +# class CReluTest(test_lib.TestCase): + +# def test(self): +# np.random.seed(1) # Make it reproducible. +# x = np.random.randn(3, 4).astype(np.float32) +# y = np.concatenate([x * (x > 0), -x * (x < 0)], axis=1) + +# z = self.evaluate(nn_ops.crelu(constant_op.constant(x))) +# self.assertAllClose(y, z, 1e-4) + + +# class ReluTest(test_lib.TestCase): + +# def test(self): +# np.random.seed(1) # Make it reproducible. +# x = np.random.randn(3, 4).astype(np.float32) +# y = np.maximum(x, 0.0) + +# z = self.evaluate(nn_ops.relu(constant_op.constant(x))) +# self.assertAllEqual(y, z) + +# @test_util.run_deprecated_v1 +# def testNaNs(self): +# # Test that relu(nan) = nan for various sizes. +# for i in range(18): +# x = np.zeros(i) + np.nan +# with self.cached_session(): +# z = nn_ops.relu(constant_op.constant(x)).eval() +# self.assertTrue(np.isnan(z).all()) + + +# class LeakyReluTest(test_lib.TestCase): + +# def testRange(self): +# batch_size = 3 +# height, width = 4, 4 +# np.random.seed(1) # Make it reproducible. 
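+# # (Added note: the inputs drawn below are uniform in [0, 1), hence all
+# # non-negative, so leaky_relu acts as the identity on them; that is why this
+# # test can assert that outputs equal inputs and stay within [0, 1]. Negative
+# # inputs are scaled by the default slope of 0.2, as the expected values in
+# # testValues further down show. A one-line numpy sketch of that slope:)
+# _sketch_neg = np.float32([-2.0, -1.0])
+# assert np.allclose(np.where(_sketch_neg > 0, _sketch_neg, 0.2 * _sketch_neg),
+#                    [-0.4, -0.2])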
+# inputs = np.random.uniform(size=(batch_size, height, width, 3)).astype( +# np.float32) +# inputs = constant_op.constant(inputs) + +# outputs = nn_ops.leaky_relu(inputs) +# self.assertEquals(inputs.shape, outputs.shape) + +# inputs, outputs = self.evaluate([inputs, outputs]) + +# self.assertGreaterEqual(outputs.min(), 0.0) +# self.assertLessEqual(outputs.max(), 1.0) +# self.assertAllClose(inputs, outputs) + +# @test_util.run_deprecated_v1 +# def testValues(self): +# for dtype in [np.int32, np.int64, np.float16, np.float32, np.float64]: +# np_values = np.array([-2, -1, 0, 1, 2], dtype=dtype) +# outputs = nn_ops.leaky_relu(constant_op.constant(np_values)) + +# outputs = self.evaluate(outputs) + +# tol = 2e-3 if dtype == np.float16 else 1e-6 +# self.assertAllClose( +# outputs, [-0.4, -0.2, 0.0, 1.0, 2.0], rtol=tol, atol=tol) + +# @test_util.run_deprecated_v1 +# def testName(self): +# np_values = np.array([-2, -1, 0, 1, 2], dtype=np.float64) +# outputs_with_name_set = nn_ops.leaky_relu( +# constant_op.constant(np_values), +# name='test_relu_op') +# self.assertEqual(outputs_with_name_set.name, 'test_relu_op:0') +# outputs_without_name_set = nn_ops.leaky_relu( +# constant_op.constant(np_values)) +# self.assertEqual(outputs_without_name_set.name, 'LeakyRelu:0') + + +# class SwishTest(test_lib.TestCase): + +# @test_util.run_deprecated_v1 +# def testValues(self): +# np_values = np.array( +# [np.linspace(-7.0, 0.0, 100), +# np.linspace(0.0, 7.0, 100)], +# dtype=np.float32) +# tf_values = constant_op.constant(np_values) +# actual_tf_outputs = nn_impl.swish(tf_values) +# expected_tf_outputs = tf_values * math_ops.sigmoid(tf_values) + +# actual_outputs, expected_outputs = self.evaluate( +# [actual_tf_outputs, expected_tf_outputs]) + +# self.assertAllClose(actual_outputs, expected_outputs) + +# @test_util.run_deprecated_v1 +# def testGradients(self): +# shape = [5, 3, 4] +# sigma = 5 +# input_values = np.random.randn(*shape) * sigma +# x_tf = constant_op.constant(input_values) +# y_tf = nn_impl.swish(x_tf) +# with self.cached_session(): +# err = gradient_checker.compute_gradient_error(x_tf, shape, y_tf, shape) +# self.assertLess(err, 1e-4) + + +# class MomentsTest(test_lib.TestCase): + +# def doOutputTest(self, +# input_shape, +# moments_axes, +# tol=1e-4, +# check_gradients=False): +# for mu in [0.0, 1.0, 1e3]: +# for sigma in [1.0, 0.1]: +# for keep_dims in [True, False]: +# input_values = np.random.rand(*input_shape) * sigma + mu +# expected_mean = np.mean( +# input_values, axis=moments_axes, keepdims=keep_dims) +# expected_var = np.var( +# input_values, axis=moments_axes, keepdims=keep_dims) +# with ops.Graph().as_default() as g: +# with self.session(graph=g) as sess: +# inputs = constant_op.constant( +# input_values, shape=input_shape, dtype=dtypes.float32) +# mean, variance = nn_impl.moments_v2( +# inputs, moments_axes, keepdims=keep_dims) + +# if check_gradients: +# err = gradient_checker.compute_gradient_error( +# inputs, input_shape, mean, mean.shape.as_list()) +# self.assertLess(err, 1e-3) +# err = gradient_checker.compute_gradient_error( +# inputs, input_shape, variance, variance.shape.as_list()) +# self.assertLess(err, 1e-3) + +# # Evaluate. 
+# [mean, variance] = self.evaluate([mean, variance]) +# # Make sure that there are no NaNs +# self.assertFalse(np.isnan(mean).any()) +# self.assertFalse(np.isnan(variance).any()) +# self.assertAllClose(mean, expected_mean, rtol=tol, atol=tol) +# self.assertAllClose(variance, expected_var, rtol=tol, atol=tol) + +# def testOutputAndGradient2DInput0(self): +# self.doOutputTest((10, 10), (0,), check_gradients=True) + +# def testOutputAndGradient2DInput01(self): +# self.doOutputTest((10, 10), (0, 1), check_gradients=True) + +# def testOutput2DInput0(self): +# self.doOutputTest((10, 300), (0,)) + +# def testOutput2DInput1(self): +# self.doOutputTest((10, 300), (1,)) + +# def testOutput2DInput01(self): +# self.doOutputTest((10, 300), (0, 1)) + +# def testOutput4DInput0(self): +# self.doOutputTest((10, 10, 10, 30), (0,)) + +# def testOutput4DInput1(self): +# self.doOutputTest((10, 10, 10, 30), (1,)) + +# def testOutput4DInput3(self): +# self.doOutputTest((10, 10, 10, 30), (3,)) + +# def testOutput4DInput012(self): +# self.doOutputTest((10, 10, 10, 30), (0, 1, 2)) + +# def testOutput4DInput123(self): +# self.doOutputTest((10, 10, 10, 30), (1, 2, 3)) + + +# class DataFormatDimMapTest(test_lib.TestCase): + +# def _test(self, x_val, y_val_expected): +# x = constant_op.constant(x_val) +# y = nn_ops.data_format_dim_map(x) + +# y_val = self.evaluate(y) +# self.assertAllEqual(y_val, y_val_expected) + +# def test(self): +# self._test(0, 0) +# self._test(1, 2) +# self._test(2, 3) +# self._test(3, 1) +# self._test(-1, 1) +# self._test(-2, 3) +# self._test(-3, 2) +# self._test(-4, 0) +# self._test([1, 3], [2, 1]) +# self._test([1, 3, -2], [2, 1, 3]) +# self._test([1, -3, -2], [2, 2, 3]) +# self._test([[1, -3], [1, -1]], [[2, 2], [2, 1]]) + +# def testNHWCtoNCHW(self): +# x_val = [1, -3, -2] +# y_val_expected = [2, 2, 3] +# x = constant_op.constant(x_val) +# y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="NCHW") +# with test_util.use_gpu(): +# y_val = self.evaluate(y) +# self.assertAllEqual(y_val, y_val_expected) + +# def testNHWCtoHWNC(self): +# x_val = [-4, -3, -2, -1, 0, 1, 2, 3] +# y_val_expected = [2, 0, 1, 3, 2, 0, 1, 3] +# x = constant_op.constant(x_val) +# y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="HWNC") +# with test_util.use_gpu(): +# y_val = self.evaluate(y) +# self.assertAllEqual(y_val, y_val_expected) + +# def testNHWCtoWHCN(self): +# x_val = [-4, -3, -2, -1, 0, 1, 2, 3] +# y_val_expected = [3, 1, 0, 2, 3, 1, 0, 2] +# x = constant_op.constant(x_val) +# y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="WHCN") +# with test_util.use_gpu(): +# y_val = self.evaluate(y) +# self.assertAllEqual(y_val, y_val_expected) + +# def testNDHWCtoNCDHW(self): +# x_val = [1, -4, -3, -2] +# y_val_expected = [2, 2, 3, 4] +# x = constant_op.constant(x_val) +# y = nn_ops.data_format_dim_map(x, src_format="NDHWC", dst_format="NCDHW") +# with test_util.use_gpu(): +# y_val = self.evaluate(y) +# self.assertAllEqual(y_val, y_val_expected) + +# def testNDHWCtoDHWNC(self): +# x_val = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4] +# y_val_expected = [3, 0, 1, 2, 4, 3, 0, 1, 2, 4] +# x = constant_op.constant(x_val) +# y = nn_ops.data_format_dim_map(x, src_format="NDHWC", dst_format="DHWNC") +# with test_util.use_gpu(): +# y_val = self.evaluate(y) +# self.assertAllEqual(y_val, y_val_expected) + +# def testDNHWCtoWHDCN(self): +# x_val = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4] +# y_val_expected = [4, 2, 1, 0, 3, 4, 2, 1, 0, 3] +# x = constant_op.constant(x_val) +# y = 
nn_ops.data_format_dim_map(x, src_format="NDHWC", dst_format="WHDCN") +# with test_util.use_gpu(): +# y_val = self.evaluate(y) +# self.assertAllEqual(y_val, y_val_expected) + +# def testArbitraryASCII(self): +# x_val = [-4, -3, -2, -1, 0, 1, 2, 3] +# y_val_expected = [3, 2, 1, 0, 3, 2, 1, 0] +# x = constant_op.constant(x_val) +# y = nn_ops.data_format_dim_map(x, src_format="qwer", dst_format="rewq") +# with test_util.use_gpu(): +# y_val = self.evaluate(y) +# self.assertAllEqual(y_val, y_val_expected) + +# @test_util.disable_xla("XLA catches the error and rethrows as different one") +# def testInvalidLength(self): +# x = [-4, -3, -2, -1, 0, 1, 2, 3] +# with self.assertRaisesRegex(errors.InvalidArgumentError, +# "Source format must be of length 4 or 5"): +# op = nn_ops.data_format_dim_map( +# x, src_format="12345678", dst_format="87654321") +# with test_util.use_gpu(): +# self.evaluate(op) + +# @test_util.disable_xla("XLA catches the error and rethrows as different one") +# def testDuplicateSrc(self): +# x = [-4, -3, -2, -1, 0, 1, 2, 3] +# with self.assertRaisesRegex( +# errors.InvalidArgumentError, +# "Destination and source format must determine a permutation"): +# op = nn_ops.data_format_dim_map(x, src_format="1233", dst_format="4321") +# with test_util.use_gpu(): +# self.evaluate(op) + +# @test_util.disable_xla("XLA catches the error and rethrows as different one") +# def testDuplicateDst(self): +# x = [-4, -3, -2, -1, 0, 1, 2, 3] +# with self.assertRaisesRegex( +# errors.InvalidArgumentError, +# "Destination and source format must determine a permutation"): +# op = nn_ops.data_format_dim_map(x, src_format="1234", dst_format="3321") +# with test_util.use_gpu(): +# self.evaluate(op) + +# @test_util.disable_xla("XLA catches the error and rethrows as different one") +# def testExtraSpecifiers(self): +# x = [-4, -3, -2, -1, 0, 1, 2, 3] +# with self.assertRaisesRegex( +# errors.InvalidArgumentError, +# "Destination and source format must determine a permutation"): +# op = nn_ops.data_format_dim_map(x, src_format="1234", dst_format="5321") +# with test_util.use_gpu(): +# self.evaluate(op) + + +# class DataFormatVectorPermuteTest(test_lib.TestCase): + +# def testNHWCToNCHW(self): +# x_val = [7, 4, 9, 3] +# x = constant_op.constant(x_val) +# y = nn_ops.data_format_vec_permute(x) +# with test_util.use_gpu(): +# y_val = self.evaluate(y) +# self.assertAllEqual(y_val, [7, 3, 4, 9]) + +# def testNCHWToNHWC(self): +# x_val = [7, 4, 9, 3] +# x = constant_op.constant(x_val) +# y = nn_ops.data_format_vec_permute(x, src_format="NCHW", dst_format="NHWC") +# with test_util.use_gpu(): +# y_val = self.evaluate(y) +# self.assertAllEqual(y_val, [7, 9, 3, 4]) + +# def testNHWCToHWNC(self): +# x_val = [7, 4, 9, 3] +# x = constant_op.constant(x_val) +# y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="HWNC") +# with test_util.use_gpu(): +# y_val = self.evaluate(y) +# self.assertAllEqual(y_val, [4, 9, 7, 3]) + +# def testHWNCToNHWC(self): +# x_val = [7, 4, 9, 3] +# x = constant_op.constant(x_val) +# y = nn_ops.data_format_vec_permute(x, src_format="HWNC", dst_format="NHWC") +# with test_util.use_gpu(): +# y_val = self.evaluate(y) +# self.assertAllEqual(y_val, [9, 7, 4, 3]) + +# def testNHWCToNCHW2D(self): +# x_val = [[7, 4], [9, 3], [4, 5], [5, 1]] +# x = constant_op.constant(x_val) +# y = nn_ops.data_format_vec_permute(x) +# with test_util.use_gpu(): +# y_val = self.evaluate(y) +# self.assertAllEqual(y_val, [[7, 4], [5, 1], [9, 3], [4, 5]]) + +# def testNHWCToHWNC2D(self): +# x_val = [[7, 4], 
[9, 3], [4, 5], [5, 1]] +# x = constant_op.constant(x_val) +# y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="HWNC") +# with test_util.use_gpu(): +# y_val = self.evaluate(y) +# self.assertAllEqual(y_val, [[9, 3], [4, 5], [7, 4], [5, 1]]) + +# def testHWNCToNHWC2D(self): +# x_val = [[7, 4], [9, 3], [4, 5], [5, 1]] +# x = constant_op.constant(x_val) +# y = nn_ops.data_format_vec_permute(x, src_format="HWNC", dst_format="NHWC") +# with test_util.use_gpu(): +# y_val = self.evaluate(y) +# self.assertAllEqual(y_val, [[4, 5], [7, 4], [9, 3], [5, 1]]) + +# def testNCHWToNHWC2D(self): +# x_val = [[7, 4], [9, 3], [4, 5], [5, 1]] +# x = constant_op.constant(x_val) +# y = nn_ops.data_format_vec_permute(x, src_format="NCHW", dst_format="NHWC") +# with test_util.use_gpu(): +# y_val = self.evaluate(y) +# self.assertAllEqual(y_val, [[7, 4], [4, 5], [5, 1], [9, 3]]) + +# @test_util.disable_xla("XLA catches the error and rethrows as different one") +# def testInvalidLength(self): +# x = [0, 1, 2, 3] +# with self.assertRaisesRegex(errors.InvalidArgumentError, +# "Source format must be of length 4 or 5"): +# op = nn_ops.data_format_vec_permute( +# x, src_format="12345678", dst_format="87654321") +# with test_util.use_gpu(): +# self.evaluate(op) + +# @test_util.disable_xla("XLA catches the error and rethrows as different one") +# def testDuplicateSrc(self): +# x = [0, 1, 2, 3] +# with self.assertRaisesRegex( +# errors.InvalidArgumentError, +# "Destination and source format must determine a permutation"): +# op = nn_ops.data_format_vec_permute( +# x, src_format="1233", dst_format="4321") +# with test_util.use_gpu(): +# self.evaluate(op) + +# @test_util.disable_xla("XLA catches the error and rethrows as different one") +# def testDuplicateDst(self): +# x = [0, 1, 2, 3] +# with self.assertRaisesRegex( +# errors.InvalidArgumentError, +# "Destination and source format must determine a permutation"): +# op = nn_ops.data_format_vec_permute( +# x, src_format="1234", dst_format="3321") +# with test_util.use_gpu(): +# self.evaluate(op) + +# @test_util.disable_xla("XLA catches the error and rethrows as different one") +# def testExtraSpecifiers(self): +# x = [0, 1, 2, 3] +# with self.assertRaisesRegex( +# errors.InvalidArgumentError, +# "Destination and source format must determine a permutation"): +# op = nn_ops.data_format_vec_permute( +# x, src_format="1234", dst_format="5321") +# with test_util.use_gpu(): +# self.evaluate(op) + + +# @test_util.run_all_in_graph_and_eager_modes +# class AvgPoolTest(test_lib.TestCase): + +# def test1DTensor(self): +# x = array_ops.ones([3, 6, 5]) +# ksize = 2 +# strides = 2 + +# y1 = nn_ops.avg_pool_v2(x, ksize, strides, "SAME") +# y2 = nn_ops.avg_pool1d(x, ksize, strides, "SAME") + +# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) + +# def test1DNumpy(self): +# # explicilty use float32 for ROCm, as MIOpen does not yet support float64 +# # np.ones defaults to using float64 when dtype is not explicitly specified +# dtype = np.float32 if test_lib.is_built_with_rocm() else np.float64 +# x = np.ones([3, 6, 5], dtype=dtype) +# ksize = 2 +# strides = 2 + +# y1 = nn_ops.avg_pool_v2(x, ksize, strides, "SAME") +# y2 = nn_ops.avg_pool1d(x, ksize, strides, "SAME") + +# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) + +# def test1DNumpyWithGolden(self): +# dtype = np.float32 if test_lib.is_built_with_rocm() else np.float64 +# x = np.array([[[3], [6], [5]], +# [[1], [0], [1]]], dtype=dtype) +# ksize = 2 +# strides = 1 +# y = nn_ops.avg_pool1d(x, ksize, 
strides, "SAME") +# expected_y = np.array([[[4.5], [5.5], [5.0]], +# [[0.5], [0.5], [1.0]]], dtype=dtype) +# self.assertAllEqual(self.evaluate(y), expected_y) + +# def test2DTensor(self): +# x = array_ops.ones([3, 6, 6, 5]) +# ksize = 2 +# strides = 2 + +# y1 = nn_ops.avg_pool_v2(x, ksize, strides, "SAME") +# y2 = nn_ops.avg_pool(x, ksize, strides, "SAME") + +# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) + +# def test2DNumpy(self): +# # explicilty use float32 for ROCm, as MIOpen does not yet support float64 +# # np.ones defaults to using float64 when dtype is not explicitly specified +# dtype = np.float32 if test_lib.is_built_with_rocm() else np.float64 +# x = np.ones([3, 6, 6, 5], dtype=dtype) +# ksize = 2 +# strides = 2 + +# y1 = nn_ops.avg_pool_v2(x, ksize, strides, "SAME") +# y2 = nn_ops.avg_pool(x, ksize, strides, "SAME") + +# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) + +# def test3DTensor(self): +# if test_lib.is_built_with_rocm(): +# self.skipTest("Pooling with 3D tensors is not supported in ROCm") +# x = array_ops.ones([3, 7, 6, 6, 5]) +# ksize = 2 +# strides = 2 + +# y1 = nn_ops.avg_pool_v2(x, ksize, strides, "SAME") +# y2 = nn_ops.avg_pool3d(x, ksize, strides, "SAME") + +# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) + +# def test3DNumpy(self): +# if test_lib.is_built_with_rocm(): +# self.skipTest("Pooling with 3D tensors is not supported in ROCm") +# x = np.ones([3, 7, 6, 6, 5], dtype=np.float32) +# ksize = 2 +# strides = 2 + +# y1 = nn_ops.avg_pool_v2(x, ksize, strides, "SAME") +# y2 = nn_ops.avg_pool3d(x, ksize, strides, "SAME") + +# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) + + +# @test_util.run_all_in_graph_and_eager_modes +# class MaxPoolTest(test_lib.TestCase): + +# def test1DTensor(self): +# x = array_ops.ones([3, 6, 5]) +# ksize = 2 +# strides = 2 + +# y1 = nn_ops.max_pool_v2(x, ksize, strides, "SAME") +# y2 = nn_ops.max_pool1d(x, ksize, strides, "SAME") + +# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) - def test1DNumpy(self): - # explicilty use float32 for ROCm, as MIOpen does not yet support float64 - # np.ones defaults to using float64 when dtype is not explicitly specified - dtype = np.float32 if test_lib.is_built_with_rocm() else np.float64 - x = np.ones([3, 6, 5], dtype=dtype) - ksize = 2 - strides = 2 +# def test1DNumpy(self): +# # explicilty use float32 for ROCm, as MIOpen does not yet support float64 +# # np.ones defaults to using float64 when dtype is not explicitly specified +# dtype = np.float32 if test_lib.is_built_with_rocm() else np.float64 +# x = np.ones([3, 6, 5], dtype=dtype) +# ksize = 2 +# strides = 2 - y1 = nn_ops.max_pool_v2(x, ksize, strides, "SAME") - y2 = nn_ops.max_pool1d(x, ksize, strides, "SAME") +# y1 = nn_ops.max_pool_v2(x, ksize, strides, "SAME") +# y2 = nn_ops.max_pool1d(x, ksize, strides, "SAME") - self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) +# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) - def test1DNumpyWithGolden(self): - dtype = np.float32 if test_lib.is_built_with_rocm() else np.float64 - x = np.array([[[3], [6], [5]], - [[1], [0], [1]]], dtype=dtype) - ksize = 2 - strides = 1 - y = nn_ops.max_pool1d(x, ksize, strides, "SAME") - expected_y = np.array([[[6], [6], [5]], - [[1], [1], [1]]], dtype=dtype) - self.assertAllEqual(self.evaluate(y), expected_y) +# def test1DNumpyWithGolden(self): +# dtype = np.float32 if test_lib.is_built_with_rocm() else np.float64 +# x = np.array([[[3], [6], [5]], +# [[1], [0], [1]]], dtype=dtype) +# ksize = 2 
+# strides = 1 +# y = nn_ops.max_pool1d(x, ksize, strides, "SAME") +# expected_y = np.array([[[6], [6], [5]], +# [[1], [1], [1]]], dtype=dtype) +# self.assertAllEqual(self.evaluate(y), expected_y) - def test2DTensor(self): - x = array_ops.ones([3, 6, 6, 5]) - ksize = 2 - strides = 2 +# def test2DTensor(self): +# x = array_ops.ones([3, 6, 6, 5]) +# ksize = 2 +# strides = 2 - y1 = nn_ops.max_pool_v2(x, ksize, strides, "SAME") - y2 = nn_ops.max_pool(x, ksize, strides, "SAME") +# y1 = nn_ops.max_pool_v2(x, ksize, strides, "SAME") +# y2 = nn_ops.max_pool(x, ksize, strides, "SAME") - self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) +# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) - def test2DNumpy(self): - # explicilty use float32 for ROCm, as MIOpen does not yet support float64 - # np.ones defaults to using float64 when dtype is not explicitly specified - dtype = np.float32 if test_lib.is_built_with_rocm() else np.float64 - x = np.ones([3, 6, 6, 5], dtype=dtype) - ksize = 2 - strides = 2 +# def test2DNumpy(self): +# # explicilty use float32 for ROCm, as MIOpen does not yet support float64 +# # np.ones defaults to using float64 when dtype is not explicitly specified +# dtype = np.float32 if test_lib.is_built_with_rocm() else np.float64 +# x = np.ones([3, 6, 6, 5], dtype=dtype) +# ksize = 2 +# strides = 2 - y1 = nn_ops.max_pool_v2(x, ksize, strides, "SAME") - y2 = nn_ops.max_pool(x, ksize, strides, "SAME") +# y1 = nn_ops.max_pool_v2(x, ksize, strides, "SAME") +# y2 = nn_ops.max_pool(x, ksize, strides, "SAME") - self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) +# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) - def test3DTensor(self): - if test_lib.is_built_with_rocm(): - self.skipTest("Pooling with 3D tensors is not supported in ROCm") - x = array_ops.ones([3, 7, 6, 6, 5]) - ksize = 2 - strides = 2 +# def test3DTensor(self): +# if test_lib.is_built_with_rocm(): +# self.skipTest("Pooling with 3D tensors is not supported in ROCm") +# x = array_ops.ones([3, 7, 6, 6, 5]) +# ksize = 2 +# strides = 2 - y1 = nn_ops.max_pool_v2(x, ksize, strides, "SAME") - y2 = nn_ops.max_pool3d(x, ksize, strides, "SAME") +# y1 = nn_ops.max_pool_v2(x, ksize, strides, "SAME") +# y2 = nn_ops.max_pool3d(x, ksize, strides, "SAME") - self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) +# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) - def test3DNumpy(self): - if test_lib.is_built_with_rocm(): - self.skipTest("Pooling with 3D tensors is not supported in ROCm") - x = np.ones([3, 7, 6, 6, 5], dtype=np.float32) - ksize = 2 - strides = 2 +# def test3DNumpy(self): +# if test_lib.is_built_with_rocm(): +# self.skipTest("Pooling with 3D tensors is not supported in ROCm") +# x = np.ones([3, 7, 6, 6, 5], dtype=np.float32) +# ksize = 2 +# strides = 2 - y1 = nn_ops.max_pool_v2(x, ksize, strides, "SAME") - y2 = nn_ops.max_pool3d(x, ksize, strides, "SAME") +# y1 = nn_ops.max_pool_v2(x, ksize, strides, "SAME") +# y2 = nn_ops.max_pool3d(x, ksize, strides, "SAME") - self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) +# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) - def testIncorrectSizeInputSmall(self): - x = array_ops.ones([3, 4]) - with self.assertRaisesRegex( - ValueError, "Input tensor must be of rank 3, 4 or 5 but was 2."): - nn_ops.max_pool_v2(x, 2, 2, "SAME") +# def testIncorrectSizeInputSmall(self): +# x = array_ops.ones([3, 4]) +# with self.assertRaisesRegex( +# ValueError, "Input tensor must be of rank 3, 4 or 5 but was 2."): +# nn_ops.max_pool_v2(x, 2, 2, 
"SAME") - def testIncorrectSizeInput(self): - x = array_ops.ones([3, 4, 1, 2, 1, 2]) - with self.assertRaisesRegex( - ValueError, "Input tensor must be of rank 3, 4 or 5 but was 6."): - nn_ops.max_pool_v2(x, 2, 2, "SAME") +# def testIncorrectSizeInput(self): +# x = array_ops.ones([3, 4, 1, 2, 1, 2]) +# with self.assertRaisesRegex( +# ValueError, "Input tensor must be of rank 3, 4 or 5 but was 6."): +# nn_ops.max_pool_v2(x, 2, 2, "SAME") -@test_util.run_all_in_graph_and_eager_modes -class ConvolutionTest(test_lib.TestCase): +# @test_util.run_all_in_graph_and_eager_modes +# class ConvolutionTest(test_lib.TestCase): - def testUnknownSize(self): - # explicilty use float32 for ROCm, as MIOpen does not yet support float64 - # np.ones defaults to using float64 when dtype is not explicitly specified - dtype = np.float32 if test_lib.is_built_with_rocm() else np.float64 - x = tensor_spec.TensorSpec(None, dtypes.float32, name="x") - k = np.ones([3, 6, 6, 5], dtype=dtype) +# def testUnknownSize(self): +# # explicilty use float32 for ROCm, as MIOpen does not yet support float64 +# # np.ones defaults to using float64 when dtype is not explicitly specified +# dtype = np.float32 if test_lib.is_built_with_rocm() else np.float64 +# x = tensor_spec.TensorSpec(None, dtypes.float32, name="x") +# k = np.ones([3, 6, 6, 5], dtype=dtype) - @def_function.function - def F(value): - return nn_ops.convolution(value, k, "SAME") +# @def_function.function +# def F(value): +# return nn_ops.convolution(value, k, "SAME") - F.get_concrete_function(x) +# F.get_concrete_function(x) -class ConvTransposeTest(test_lib.TestCase): +# class ConvTransposeTest(test_lib.TestCase): - def test1D(self): - t = array_ops.ones([2, 4, 3]) - v = array_ops.ones([2, 5, 3]) - strides = 2 +# def test1D(self): +# t = array_ops.ones([2, 4, 3]) +# v = array_ops.ones([2, 5, 3]) +# strides = 2 - y1 = nn_ops.conv1d_transpose(t, v, [2, 8, 5], strides) - y2 = nn_ops.conv_transpose(t, v, [2, 8, 5], strides) +# y1 = nn_ops.conv1d_transpose(t, v, [2, 8, 5], strides) +# y2 = nn_ops.conv_transpose(t, v, [2, 8, 5], strides) - self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) +# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) - def test1DTensor(self): - t = array_ops.ones([2, 4, 3]) - v = array_ops.ones([2, 5, 3]) - strides = 2 +# def test1DTensor(self): +# t = array_ops.ones([2, 4, 3]) +# v = array_ops.ones([2, 5, 3]) +# strides = 2 - y1 = nn_ops.conv1d_transpose(t, v, [2, 8, 5], strides) - y2 = nn_ops.conv_transpose(t, v, constant_op.constant([2, 8, 5]), strides) +# y1 = nn_ops.conv1d_transpose(t, v, [2, 8, 5], strides) +# y2 = nn_ops.conv_transpose(t, v, constant_op.constant([2, 8, 5]), strides) - self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) +# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) - def test2D(self): - t = array_ops.ones([2, 4, 4, 3]) - v = array_ops.ones([2, 2, 5, 3]) - strides = 2 +# def test2D(self): +# t = array_ops.ones([2, 4, 4, 3]) +# v = array_ops.ones([2, 2, 5, 3]) +# strides = 2 - y1 = nn_ops.conv2d_transpose_v2(t, v, [2, 8, 8, 5], strides) - y2 = nn_ops.conv_transpose(t, v, [2, 8, 8, 5], strides) +# y1 = nn_ops.conv2d_transpose_v2(t, v, [2, 8, 8, 5], strides) +# y2 = nn_ops.conv_transpose(t, v, [2, 8, 8, 5], strides) - self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) +# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) - def test2DTensor(self): - t = array_ops.ones([2, 4, 4, 3]) - v = array_ops.ones([2, 2, 5, 3]) - strides = 2 +# def test2DTensor(self): +# t = array_ops.ones([2, 
4, 4, 3]) +# v = array_ops.ones([2, 2, 5, 3]) +# strides = 2 - y1 = nn_ops.conv2d_transpose_v2(t, v, [2, 8, 8, 5], strides) - y2 = nn_ops.conv_transpose(t, v, constant_op.constant([2, 8, 8, 5]), - strides) +# y1 = nn_ops.conv2d_transpose_v2(t, v, [2, 8, 8, 5], strides) +# y2 = nn_ops.conv_transpose(t, v, constant_op.constant([2, 8, 8, 5]), +# strides) - self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) +# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) - def test3D(self): - t = array_ops.ones([2, 4, 4, 4, 3]) - v = array_ops.ones([2, 2, 2, 5, 3]) - strides = 2 +# def test3D(self): +# t = array_ops.ones([2, 4, 4, 4, 3]) +# v = array_ops.ones([2, 2, 2, 5, 3]) +# strides = 2 - y1 = nn_ops.conv3d_transpose_v2(t, v, [2, 8, 8, 8, 5], strides) - y2 = nn_ops.conv_transpose(t, v, [2, 8, 8, 8, 5], strides) +# y1 = nn_ops.conv3d_transpose_v2(t, v, [2, 8, 8, 8, 5], strides) +# y2 = nn_ops.conv_transpose(t, v, [2, 8, 8, 8, 5], strides) - self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) +# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) - def test3DTensor(self): - t = array_ops.ones([2, 4, 4, 4, 3]) - v = array_ops.ones([2, 2, 2, 5, 3]) - strides = 2 +# def test3DTensor(self): +# t = array_ops.ones([2, 4, 4, 4, 3]) +# v = array_ops.ones([2, 2, 2, 5, 3]) +# strides = 2 - y1 = nn_ops.conv3d_transpose_v2(t, v, [2, 8, 8, 8, 5], strides) - y2 = nn_ops.conv_transpose(t, v, constant_op.constant([2, 8, 8, 8, 5]), - strides) +# y1 = nn_ops.conv3d_transpose_v2(t, v, [2, 8, 8, 8, 5], strides) +# y2 = nn_ops.conv_transpose(t, v, constant_op.constant([2, 8, 8, 8, 5]), +# strides) - self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) +# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) - def testIncorrectSizeInputSmall(self): - with self.assertRaisesRegex( - ValueError, "output_shape must be of length 3, 4 or 5 but was 2."): - nn_ops.conv_transpose(None, 2, [2, 3], "SAME") +# def testIncorrectSizeInputSmall(self): +# with self.assertRaisesRegex( +# ValueError, "output_shape must be of length 3, 4 or 5 but was 2."): +# nn_ops.conv_transpose(None, 2, [2, 3], "SAME") - def testIncorrectSizeInput(self): - with self.assertRaisesRegex( - ValueError, "output_shape must be of length 3, 4 or 5 but was 6."): - nn_ops.conv_transpose(None, 2, [2, 3, 4, 2, 5, 1], "SAME") +# def testIncorrectSizeInput(self): +# with self.assertRaisesRegex( +# ValueError, "output_shape must be of length 3, 4 or 5 but was 6."): +# nn_ops.conv_transpose(None, 2, [2, 3, 4, 2, 5, 1], "SAME") - def testTensorsNoShape(self): - with self.assertRaisesRegex( - ValueError, - "output_shape must be a tensor or sized collection."): - nn_ops.conv_transpose(None, None, None, None) +# def testTensorsNoShape(self): +# with self.assertRaisesRegex( +# ValueError, +# "output_shape must be a tensor or sized collection."): +# nn_ops.conv_transpose(None, None, None, None) if __name__ == "__main__": From 0caf45fde7567355cff9c11471bd24c0f031c86d Mon Sep 17 00:00:00 2001 From: Duyi-Wang Date: Thu, 19 May 2022 14:10:23 +0800 Subject: [PATCH 04/17] [Op] Add handling of 128 remainders in fused l2 norm. 
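
The AVX-512 kernels accumulate the squared sum in 128-float blocks (8 x __m512),
so columns beyond the last full block were not included in the sum. This patch
peels that remainder in 64-, 32- and 16-wide steps and finishes the last few
lanes with a masked load, in both the forward and the gradient kernel. Whatever
path the tail takes, each row must still match plain row-wise L2 normalization;
a minimal NumPy sketch of that reference follows (the function name is
illustrative, not part of the patch, and it mirrors this kernel's choice of
adding epsilon to the squared sum rather than clamping with it):

    import numpy as np

    def l2_normalize_rows(x, epsilon=1e-12):
        # x: float32 array of shape [rows, cols]; normalize over the last
        # axis, as the per-row forward() computation does.
        row_sum = np.sum(np.square(x), axis=-1, keepdims=True) + epsilon
        return x * (1.0 / np.sqrt(row_sum))

The updated unit test uses cols = 252 (remainder 124 = 64 + 32 + 16 + 12),
which exercises every tail branch at least once.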
--- .../fused_l2_normalize_op.cc | 153 +++++++++++++++++- .../fused_l2_normalize_op_test.cc | 37 +++-- 2 files changed, 169 insertions(+), 21 deletions(-) diff --git a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc index fb050bd7707..cd45b203e14 100644 --- a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc +++ b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc @@ -103,8 +103,11 @@ class FusedL2NormalizeOp : public OpKernel { template void forward(const T* input, T* output, int64 begin_row, int64 end_row, int64 cols) { int64 avx3_block_num = cols >> 7; // cols / 128 - // printf("cols: %d, avx3_block_num: %d\n", cols, avx3_block_num); + // handle remainder of 128 + int64 remainder = cols - (avx3_block_num << 7); + // printf("cols: %d, avx3_block_num: %d, remainder %d\n", cols, avx3_block_num, remainder); for (int64 i = begin_row; i < end_row; ++i) { + int64 tmp_remainder = remainder; float row_sum = 0.0; for (int64 j = 0; j < avx3_block_num; ++j) { __m512 inputs[SUM_BLOCK_SIZE]; @@ -116,6 +119,43 @@ class FusedL2NormalizeOp : public OpKernel { __m512 block_sum = reduce_sum_block8_ps(inputs); row_sum += _mm512_reduce_add_ps(block_sum); } + if (tmp_remainder > 0) { + if (tmp_remainder >= 64) { + __m256 inputs[8]; + auto load_256 = [&](auto idx) { + inputs[idx] = _mm256_loadu_ps(input + cols * i + cols - tmp_remainder + 8 * idx); + inputs[idx] = _mm256_mul_ps(inputs[idx], inputs[idx]); + }; + functor::compile_time_for<8>::op(load_256); + __m256 block_sum_remainder = reduce_sum_block8_mm256_ps(inputs); + row_sum += _mm512_reduce_add_ps(_mm512_castps256_ps512(block_sum_remainder)); + tmp_remainder -= 64; + } + if (tmp_remainder > 32) { + __m256 inputs[4]; + auto load_256 = [&](auto idx) { + inputs[idx] = _mm256_loadu_ps(input + cols * i + cols - tmp_remainder + 8 * idx); + inputs[idx] = _mm256_mul_ps(inputs[idx], inputs[idx]); + }; + functor::compile_time_for<4>::op(load_256); + __m256 block_sum_remainder = reduce_sum_block4_mm256_ps(inputs); + row_sum += _mm512_reduce_add_ps(_mm512_castps256_ps512(block_sum_remainder)); + tmp_remainder -= 32; + } + if (tmp_remainder >= 16) { + __m512 inputs = _mm512_loadu_ps(input + cols * i + cols - tmp_remainder); + inputs = _mm512_mul_ps(inputs, inputs); + row_sum += _mm512_reduce_add_ps(inputs); + tmp_remainder -= 16; + } + if (tmp_remainder > 0) { + __mmask16 mask = 0xFFFF >> (16 - tmp_remainder); + __m512 inputs = _mm512_maskz_loadu_ps(mask, input + cols * i + cols - tmp_remainder); + inputs = _mm512_mul_ps(inputs, inputs); + row_sum += _mm512_reduce_add_ps(inputs); + } + } + row_sum += epsilon; row_sum = 1.0 / std::sqrt(row_sum); __m512 row_sums = _mm512_set1_ps(row_sum); @@ -124,6 +164,12 @@ class FusedL2NormalizeOp : public OpKernel { inputs = _mm512_mul_ps(inputs, row_sums); _mm512_storeu_ps(output + cols * i + j, inputs); } + if (remainder > 0){ + __mmask16 mask = 0xFFFF >> (16 - remainder); + __m512 inputs = _mm512_maskz_loadu_ps(mask, input + cols * i + cols - remainder); + inputs = _mm512_mul_ps(inputs, row_sums); + _mm512_mask_storeu_ps(output + cols * i + cols - remainder, mask, inputs); + } } } @@ -143,6 +189,22 @@ class FusedL2NormalizeOp : public OpKernel { block_sum = _mm512_add_ps(block_sum, v[7]); return block_sum; } + inline __m256 reduce_sum_block8_mm256_ps(const __m256 (&v)[8]) { + __m256 block_sum = _mm256_add_ps(v[0], v[1]); + block_sum = _mm256_add_ps(block_sum, v[2]); + block_sum = _mm256_add_ps(block_sum, 
v[3]); + block_sum = _mm256_add_ps(block_sum, v[4]); + block_sum = _mm256_add_ps(block_sum, v[5]); + block_sum = _mm256_add_ps(block_sum, v[6]); + block_sum = _mm256_add_ps(block_sum, v[7]); + return block_sum; + } + inline __m256 reduce_sum_block4_mm256_ps(const __m256 (&v)[4]) { + __m256 block_sum = _mm256_add_ps(v[0], v[1]); + block_sum = _mm256_add_ps(block_sum, v[2]); + block_sum = _mm256_add_ps(block_sum, v[3]); + return block_sum; + } private: float epsilon; @@ -255,10 +317,13 @@ class FusedL2NormalizeGradOp : public OpKernel { template void backward(const float *y_grad, const float *x, float *x_grad, int64 begin_row, int64 end_row, int64 cols) { int64 avx3_block_num = cols >> 7; // cols / 128 - // printf("backward cols: %d, avx3_block_num: %d\n", cols, avx3_block_num); + // handle remainder of 128 + int64 remainder = cols - (avx3_block_num << 7); + // printf("cols: %d, avx3_block_num: %d, remainder %d\n", cols, avx3_block_num, remainder); for (int64 i = begin_row; i < end_row; ++i) { T x_row_sum = 0.0; T y_grad_row_sum = 0.0; + int64 tmp_remainder = remainder; for (int64 j = 0; j < avx3_block_num; ++j) { __m512 xs[SUM_BLOCK_SIZE]; auto x_load = [&](auto idx) { @@ -279,6 +344,65 @@ class FusedL2NormalizeGradOp : public OpKernel { __m512 y_grad_block_sum = reduce_sum_block8_ps(y_grads); y_grad_row_sum += _mm512_reduce_add_ps(y_grad_block_sum); } + if (tmp_remainder > 0) { + if (tmp_remainder >= 64) { + __m256 xs[8]; + auto x_load_256 = [&](auto idx) { + xs[idx] = _mm256_loadu_ps(x + cols * i + cols - tmp_remainder + 8 * idx); + xs[idx] = _mm256_mul_ps(xs[idx], xs[idx]); + }; + functor::compile_time_for<8>::op(x_load_256); + __m256 block_sum_remainder = reduce_sum_block8_mm256_ps(xs); + x_row_sum += _mm512_reduce_add_ps(_mm512_castps256_ps512(block_sum_remainder)); + + __m256 y_grads[8]; + auto y_grad_load_256 = [&](auto idx) { + y_grads[idx] = _mm256_loadu_ps(y_grad + cols * i + cols - tmp_remainder + 8 * idx); + xs[idx] = _mm256_loadu_ps(x + cols * i + cols - tmp_remainder + 8 * idx); + y_grads[idx] = _mm256_mul_ps(y_grads[idx], xs[idx]); + }; + functor::compile_time_for<8>::op(y_grad_load_256); + __m256 y_grad_block_sum_remainder = reduce_sum_block8_mm256_ps(y_grads); + y_grad_row_sum += _mm512_reduce_add_ps(_mm512_castps256_ps512(y_grad_block_sum_remainder)); + tmp_remainder -= 64; + } + if (tmp_remainder > 32) { + __m256 xs[4]; + auto x_load_256 = [&](auto idx) { + xs[idx] = _mm256_loadu_ps(x + cols * i + cols - tmp_remainder + 8 * idx); + xs[idx] = _mm256_mul_ps(xs[idx], xs[idx]); + }; + functor::compile_time_for<4>::op(x_load_256); + __m256 block_sum_remainder = reduce_sum_block4_mm256_ps(xs); + x_row_sum += _mm512_reduce_add_ps(_mm512_castps256_ps512(block_sum_remainder)); + + __m256 y_grads[4]; + auto y_grad_load_256 = [&](auto idx) { + y_grads[idx] = _mm256_loadu_ps(y_grad + cols * i + cols - tmp_remainder + 8 * idx); + xs[idx] = _mm256_loadu_ps(x + cols * i + cols - tmp_remainder + 8 * idx); + y_grads[idx] = _mm256_mul_ps(y_grads[idx], xs[idx]); + }; + functor::compile_time_for<4>::op(y_grad_load_256); + __m256 y_grad_block_sum_remainder = reduce_sum_block4_mm256_ps(y_grads); + y_grad_row_sum += _mm512_reduce_add_ps(_mm512_castps256_ps512(y_grad_block_sum_remainder)); + tmp_remainder -= 32; + } + if (tmp_remainder >= 16) { + __m512 xs = _mm512_loadu_ps(x + cols * i + cols - tmp_remainder); + __m512 y_grads = _mm512_loadu_ps(y_grad + cols * i + cols - tmp_remainder); + x_row_sum += _mm512_reduce_add_ps(_mm512_mul_ps(xs, xs)); + y_grad_row_sum += 
_mm512_reduce_add_ps(_mm512_mul_ps(y_grads, xs)); + tmp_remainder -= 16; + } + if (tmp_remainder > 0) { + __mmask16 mask = 0xFFFF >> (16 - tmp_remainder); + __m512 xs = _mm512_maskz_loadu_ps(mask, x + cols * i + cols - tmp_remainder); + __m512 y_grads = _mm512_maskz_loadu_ps(mask, y_grad + cols * i + cols - tmp_remainder); + x_row_sum += _mm512_reduce_add_ps(_mm512_mul_ps(xs, xs)); + y_grad_row_sum += _mm512_reduce_add_ps(_mm512_mul_ps(y_grads, xs)); + } + } + x_row_sum += epsilon; x_row_sum = 1.0 / std::sqrt(x_row_sum); y_grad_row_sum = (y_grad_row_sum * x_row_sum) * (x_row_sum * x_row_sum); @@ -292,6 +416,15 @@ class FusedL2NormalizeGradOp : public OpKernel { y_grads = _mm512_sub_ps(y_grads, xs); _mm512_storeu_ps(x_grad + cols * i + j, y_grads); } + if (remainder > 0){ + __mmask16 mask = 0xFFFF >> (16 - remainder); + __m512 y_grads = _mm512_maskz_loadu_ps(mask, y_grad + cols * i + cols - remainder); + __m512 xs = _mm512_maskz_loadu_ps(mask, x + cols * i + cols - remainder); + y_grads = _mm512_mul_ps(y_grads, x_row_sums); + xs = _mm512_mul_ps(xs, y_grad_row_sums); + y_grads = _mm512_sub_ps(y_grads, xs); + _mm512_mask_storeu_ps(x_grad + cols * i + cols - remainder, mask, y_grads); + } } } @@ -305,6 +438,22 @@ class FusedL2NormalizeGradOp : public OpKernel { block_sum = _mm512_add_ps(block_sum, v[7]); return block_sum; } + inline __m256 reduce_sum_block8_mm256_ps(const __m256 (&v)[8]) { + __m256 block_sum = _mm256_add_ps(v[0], v[1]); + block_sum = _mm256_add_ps(block_sum, v[2]); + block_sum = _mm256_add_ps(block_sum, v[3]); + block_sum = _mm256_add_ps(block_sum, v[4]); + block_sum = _mm256_add_ps(block_sum, v[5]); + block_sum = _mm256_add_ps(block_sum, v[6]); + block_sum = _mm256_add_ps(block_sum, v[7]); + return block_sum; + } + inline __m256 reduce_sum_block4_mm256_ps(const __m256 (&v)[4]) { + __m256 block_sum = _mm256_add_ps(v[0], v[1]); + block_sum = _mm256_add_ps(block_sum, v[2]); + block_sum = _mm256_add_ps(block_sum, v[3]); + return block_sum; + } private: float epsilon; diff --git a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op_test.cc b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op_test.cc index f204866f32c..4a74340e5b5 100644 --- a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op_test.cc +++ b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op_test.cc @@ -31,20 +31,19 @@ class FusedL2NormalizeOpTest : public OpsTestBase { TEST_F(FusedL2NormalizeOpTest, 2Dims_Float) { const int rows = 4; - const int cols = 16; + const int cols = 252; //128+64+32+16+8=252 1008 MakeOpAndSetDevice(Device::CPU, DT_FLOAT, 0, 1e-12); // emb_shards - AddInputFromArray(TensorShape({rows, cols}), { - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}); + // AddInputFromArray(TensorShape({rows, cols}), { + // 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + // 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}); + float input_array[1008]; + for (int i = 0; i < sizeof(input_array) / sizeof(float); i++) { + input_array[i] = 1.0; + } + AddInputFromArray(TensorShape({rows, cols}), input_array); TF_ASSERT_OK(RunOpKernel()); TF_EXPECT_OK(device_->Sync()); @@ -52,15 +51,15 @@ TEST_F(FusedL2NormalizeOpTest, 2Dims_Float) { { Tensor expected_output(allocator(), DT_FLOAT, TensorShape({rows, 
cols})); - test::FillValues(&expected_output, { - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}); + // test::FillValues(&expected_output, { + // 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, + // 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5}); + float output_array[1008]; + float output_value = 1.0 / std::sqrt(cols); + for (int i = 0; i < sizeof(output_array) / sizeof(float); i++) { + output_array[i] = output_value; + } + test::FillValues(&expected_output, output_array); test::ExpectTensorNear(expected_output, *GetOutput(0), 1e-6); } } From 82d73a278848323780c72d678adb93deb75d556e Mon Sep 17 00:00:00 2001 From: Duyi-Wang Date: Tue, 24 May 2022 17:01:59 +0800 Subject: [PATCH 05/17] [Ops] Fix bug in store remainder output. --- .../fused_l2_normalize_op.cc | 36 ++++++++++++------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc index cd45b203e14..1b406b3f10a 100644 --- a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc +++ b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc @@ -25,7 +25,9 @@ class FusedL2NormalizeOp : public OpKernel { OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); } - ~FusedL2NormalizeOp() {} + ~FusedL2NormalizeOp() { + printf("RUN ~FusedL2NormalizeOp().\n"); + } void Compute(OpKernelContext* context) override { // Grab the input @@ -56,6 +58,7 @@ class FusedL2NormalizeOp : public OpKernel { auto &worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); thread::ThreadPool *thread_pool = worker_threads.workers; + printf("[LOG] Start to calculate l2 norm in parallel.\n"); thread_pool->ParallelFor(total_unit, unit_cost, [&input, &output, rows, cols, this](int64 begin_unit, int64 end_unit) { auto begin_row = begin_unit * BLOCK_SIZE; @@ -65,6 +68,7 @@ class FusedL2NormalizeOp : public OpKernel { } forward<8>(input, output, begin_row, end_row, cols); }); + printf("[LOG] Complete calculate l2 norm in parallel.\n"); } private: @@ -105,8 +109,10 @@ class FusedL2NormalizeOp : public OpKernel { int64 avx3_block_num = cols >> 7; // cols / 128 // handle remainder of 128 int64 remainder = cols - (avx3_block_num << 7); - // printf("cols: %d, avx3_block_num: %d, remainder %d\n", cols, avx3_block_num, remainder); + printf("[LOG] cols: %d, avx3_block_num: %d, remainder %d\n", cols, avx3_block_num, remainder); + printf("[LOG] begin_row: %d, end_row:%d\n", begin_row, end_row); for (int64 i = begin_row; i < end_row; ++i) { + printf("[LOG] \tstart to hande %d row\n", i); int64 tmp_remainder = remainder; float row_sum = 0.0; for (int64 j = 0; j < avx3_block_num; ++j) { @@ -120,7 +126,9 @@ class FusedL2NormalizeOp : public OpKernel { row_sum += _mm512_reduce_add_ps(block_sum); } if (tmp_remainder > 0) { + printf("[LOG] \tStart to handle remainder and remainder is %d\n", tmp_remainder); if (tmp_remainder >= 64) { + printf("[LOG] \tHandle 64 remainer and remainder is %d\n", tmp_remainder); __m256 inputs[8]; auto load_256 = [&](auto idx) { inputs[idx] = _mm256_loadu_ps(input + cols * i + cols - tmp_remainder + 8 * idx); @@ -132,6 +140,7 @@ class FusedL2NormalizeOp : public OpKernel { tmp_remainder -= 64; } if 
(tmp_remainder > 32) { + printf("[LOG] \tHandle 32 remainer and remainder is %d\n", tmp_remainder); __m256 inputs[4]; auto load_256 = [&](auto idx) { inputs[idx] = _mm256_loadu_ps(input + cols * i + cols - tmp_remainder + 8 * idx); @@ -143,12 +152,14 @@ class FusedL2NormalizeOp : public OpKernel { tmp_remainder -= 32; } if (tmp_remainder >= 16) { + printf("[LOG] \tHandle 16 remainer and remainder is %d\n", tmp_remainder); __m512 inputs = _mm512_loadu_ps(input + cols * i + cols - tmp_remainder); inputs = _mm512_mul_ps(inputs, inputs); row_sum += _mm512_reduce_add_ps(inputs); tmp_remainder -= 16; } if (tmp_remainder > 0) { + printf("[LOG] \tHandle 0 remainer and remainder is %d\n", tmp_remainder); __mmask16 mask = 0xFFFF >> (16 - tmp_remainder); __m512 inputs = _mm512_maskz_loadu_ps(mask, input + cols * i + cols - tmp_remainder); inputs = _mm512_mul_ps(inputs, inputs); @@ -164,13 +175,14 @@ class FusedL2NormalizeOp : public OpKernel { inputs = _mm512_mul_ps(inputs, row_sums); _mm512_storeu_ps(output + cols * i + j, inputs); } - if (remainder > 0){ - __mmask16 mask = 0xFFFF >> (16 - remainder); - __m512 inputs = _mm512_maskz_loadu_ps(mask, input + cols * i + cols - remainder); + if (tmp_remainder > 0){ + __mmask16 mask = 0xFFFF >> (16 - tmp_remainder); + __m512 inputs = _mm512_maskz_loadu_ps(mask, input + cols * i + cols - tmp_remainder); inputs = _mm512_mul_ps(inputs, row_sums); - _mm512_mask_storeu_ps(output + cols * i + cols - remainder, mask, inputs); + _mm512_mask_storeu_ps(output + cols * i + cols - tmp_remainder, mask, inputs); } } + printf("[LOG] Complete row %d~%d\n", begin_row, end_row); } // data type: FP32, 16 FP32 per __m512 @@ -319,7 +331,7 @@ class FusedL2NormalizeGradOp : public OpKernel { int64 avx3_block_num = cols >> 7; // cols / 128 // handle remainder of 128 int64 remainder = cols - (avx3_block_num << 7); - // printf("cols: %d, avx3_block_num: %d, remainder %d\n", cols, avx3_block_num, remainder); + // printf("[LOG] cols: %d, avx3_block_num: %d, remainder %d\n", cols, avx3_block_num, remainder); for (int64 i = begin_row; i < end_row; ++i) { T x_row_sum = 0.0; T y_grad_row_sum = 0.0; @@ -416,14 +428,14 @@ class FusedL2NormalizeGradOp : public OpKernel { y_grads = _mm512_sub_ps(y_grads, xs); _mm512_storeu_ps(x_grad + cols * i + j, y_grads); } - if (remainder > 0){ - __mmask16 mask = 0xFFFF >> (16 - remainder); - __m512 y_grads = _mm512_maskz_loadu_ps(mask, y_grad + cols * i + cols - remainder); - __m512 xs = _mm512_maskz_loadu_ps(mask, x + cols * i + cols - remainder); + if (tmp_remainder > 0){ + __mmask16 mask = 0xFFFF >> (16 - tmp_remainder); + __m512 y_grads = _mm512_maskz_loadu_ps(mask, y_grad + cols * i + cols - tmp_remainder); + __m512 xs = _mm512_maskz_loadu_ps(mask, x + cols * i + cols - tmp_remainder); y_grads = _mm512_mul_ps(y_grads, x_row_sums); xs = _mm512_mul_ps(xs, y_grad_row_sums); y_grads = _mm512_sub_ps(y_grads, xs); - _mm512_mask_storeu_ps(x_grad + cols * i + cols - remainder, mask, y_grads); + _mm512_mask_storeu_ps(x_grad + cols * i + cols - tmp_remainder, mask, y_grads); } } } From 9df47ff160e1743ad8458e2dbf70b8bbce8d9103 Mon Sep 17 00:00:00 2001 From: Duyi-Wang Date: Wed, 25 May 2022 14:06:40 +0800 Subject: [PATCH 06/17] [Op] Fix array out of bounds in store output. 
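
With a column count that is not a multiple of 16, the output store loop
"for (int64 j = 0; j < cols; j += 16)" issued a full 16-float
_mm512_storeu_ps on its final iteration and wrote past the end of the row;
on the last row that lands past the end of the output tensor. The fix stops
the full stores at j < cols - 15 and lets the masked _mm512_mask_storeu_ps
cover the leftover lanes, in both the forward and the gradient kernel, and
drops most of the printf tracing added for debugging in the previous patch.
A small Python sketch of the index arithmetic, using the test's cols = 252
(illustrative only, not part of the patch):

    cols = 252
    # Old bound (j < cols): the last full 16-wide store starts at 240 and
    # covers indices 240..255, i.e. four floats past the 252-element row.
    old_last = max(j for j in range(0, cols, 16))                    # 240
    # New bound (j < cols - 15): the last full store starts at 224 and covers
    # up to index 239; the final 12 floats (240..251) use the masked store.
    new_last = max(j for j in range(0, cols, 16) if j < cols - 15)   # 224
    print(old_last, new_last, cols - (new_last + 16))                # 240 224 12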
--- .../fused_l2_normalize_op.cc | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc index 1b406b3f10a..78f5cd9b733 100644 --- a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc +++ b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc @@ -58,7 +58,6 @@ class FusedL2NormalizeOp : public OpKernel { auto &worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); thread::ThreadPool *thread_pool = worker_threads.workers; - printf("[LOG] Start to calculate l2 norm in parallel.\n"); thread_pool->ParallelFor(total_unit, unit_cost, [&input, &output, rows, cols, this](int64 begin_unit, int64 end_unit) { auto begin_row = begin_unit * BLOCK_SIZE; @@ -68,7 +67,6 @@ class FusedL2NormalizeOp : public OpKernel { } forward<8>(input, output, begin_row, end_row, cols); }); - printf("[LOG] Complete calculate l2 norm in parallel.\n"); } private: @@ -109,10 +107,8 @@ class FusedL2NormalizeOp : public OpKernel { int64 avx3_block_num = cols >> 7; // cols / 128 // handle remainder of 128 int64 remainder = cols - (avx3_block_num << 7); - printf("[LOG] cols: %d, avx3_block_num: %d, remainder %d\n", cols, avx3_block_num, remainder); - printf("[LOG] begin_row: %d, end_row:%d\n", begin_row, end_row); + // printf("cols: %d, avx3_block_num: %d, remainder %d\n", cols, avx3_block_num, remainder); for (int64 i = begin_row; i < end_row; ++i) { - printf("[LOG] \tstart to hande %d row\n", i); int64 tmp_remainder = remainder; float row_sum = 0.0; for (int64 j = 0; j < avx3_block_num; ++j) { @@ -126,9 +122,7 @@ class FusedL2NormalizeOp : public OpKernel { row_sum += _mm512_reduce_add_ps(block_sum); } if (tmp_remainder > 0) { - printf("[LOG] \tStart to handle remainder and remainder is %d\n", tmp_remainder); if (tmp_remainder >= 64) { - printf("[LOG] \tHandle 64 remainer and remainder is %d\n", tmp_remainder); __m256 inputs[8]; auto load_256 = [&](auto idx) { inputs[idx] = _mm256_loadu_ps(input + cols * i + cols - tmp_remainder + 8 * idx); @@ -140,7 +134,6 @@ class FusedL2NormalizeOp : public OpKernel { tmp_remainder -= 64; } if (tmp_remainder > 32) { - printf("[LOG] \tHandle 32 remainer and remainder is %d\n", tmp_remainder); __m256 inputs[4]; auto load_256 = [&](auto idx) { inputs[idx] = _mm256_loadu_ps(input + cols * i + cols - tmp_remainder + 8 * idx); @@ -152,14 +145,12 @@ class FusedL2NormalizeOp : public OpKernel { tmp_remainder -= 32; } if (tmp_remainder >= 16) { - printf("[LOG] \tHandle 16 remainer and remainder is %d\n", tmp_remainder); __m512 inputs = _mm512_loadu_ps(input + cols * i + cols - tmp_remainder); inputs = _mm512_mul_ps(inputs, inputs); row_sum += _mm512_reduce_add_ps(inputs); tmp_remainder -= 16; } if (tmp_remainder > 0) { - printf("[LOG] \tHandle 0 remainer and remainder is %d\n", tmp_remainder); __mmask16 mask = 0xFFFF >> (16 - tmp_remainder); __m512 inputs = _mm512_maskz_loadu_ps(mask, input + cols * i + cols - tmp_remainder); inputs = _mm512_mul_ps(inputs, inputs); @@ -170,7 +161,7 @@ class FusedL2NormalizeOp : public OpKernel { row_sum += epsilon; row_sum = 1.0 / std::sqrt(row_sum); __m512 row_sums = _mm512_set1_ps(row_sum); - for (int64 j = 0; j < cols; j += 16) { + for (int64 j = 0; j < cols - 15; j += 16) { __m512 inputs = _mm512_loadu_ps(input + cols * i + j); inputs = _mm512_mul_ps(inputs, row_sums); _mm512_storeu_ps(output + cols * i + j, inputs); @@ -182,7 +173,6 @@ class 
FusedL2NormalizeOp : public OpKernel { _mm512_mask_storeu_ps(output + cols * i + cols - tmp_remainder, mask, inputs); } } - printf("[LOG] Complete row %d~%d\n", begin_row, end_row); } // data type: FP32, 16 FP32 per __m512 @@ -331,7 +321,7 @@ class FusedL2NormalizeGradOp : public OpKernel { int64 avx3_block_num = cols >> 7; // cols / 128 // handle remainder of 128 int64 remainder = cols - (avx3_block_num << 7); - // printf("[LOG] cols: %d, avx3_block_num: %d, remainder %d\n", cols, avx3_block_num, remainder); + // printf("cols: %d, avx3_block_num: %d, remainder %d\n", cols, avx3_block_num, remainder); for (int64 i = begin_row; i < end_row; ++i) { T x_row_sum = 0.0; T y_grad_row_sum = 0.0; @@ -420,7 +410,7 @@ class FusedL2NormalizeGradOp : public OpKernel { y_grad_row_sum = (y_grad_row_sum * x_row_sum) * (x_row_sum * x_row_sum); __m512 x_row_sums = _mm512_set1_ps(x_row_sum); __m512 y_grad_row_sums = _mm512_set1_ps(y_grad_row_sum); - for (int64 j = 0; j < cols; j += 16) { + for (int64 j = 0; j < cols - 15; j += 16) { __m512 y_grads = _mm512_loadu_ps(y_grad + cols * i + j); __m512 xs = _mm512_loadu_ps(x + cols * i + j); y_grads = _mm512_mul_ps(y_grads, x_row_sums); From 82bcbcbaca19cfb558193ad963db94c176a55439 Mon Sep 17 00:00:00 2001 From: Duyi-Wang Date: Wed, 25 May 2022 17:00:03 +0800 Subject: [PATCH 07/17] [Op] Enable fused l2 norm in tf.nn.l2_normalize. --- tensorflow/python/ops/nn_impl.py | 38 +++++++++++++++++--------------- tensorflow/python/ops/nn_test.py | 30 ++++++++++++------------- 2 files changed, 35 insertions(+), 33 deletions(-) diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index d4bfc676621..e8fb43483d2 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -591,7 +591,7 @@ def normalize(tensor, ord="euclidean", axis=None, name=None): @tf_export(v1=["math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize"]) @deprecated_args(None, "dim is deprecated, use axis instead", "dim") -def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None): +def l2_normalize(x, axis=None, epsilon=1e-12, do_fusion=False, name=None, dim=None): """Normalizes along dimension `axis` using an L2 norm. For a 1-D tensor with `axis = 0`, computes @@ -608,17 +608,18 @@ def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None): epsilon: A lower bound value for the norm. Will use `sqrt(epsilon)` as the divisor if `norm < sqrt(epsilon)`. name: A name for this operation (optional). + do_fusion: Whether fuse op when doing l2 norm on last axis. dim: Deprecated alias for axis. Returns: A `Tensor` with the same shape as `x`. """ axis = deprecated_argument_lookup("axis", axis, "dim", dim) - return l2_normalize_v2(x, axis, epsilon, name) + return l2_normalize_v2(x, axis, epsilon, do_fusion, name) @tf_export("math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize", v1=[]) -def l2_normalize_v2(x, axis=None, epsilon=1e-12, name=None): +def l2_normalize_v2(x, axis=None, epsilon=1e-12, do_fusion=False, name=None): """Normalizes along dimension `axis` using an L2 norm. For a 1-D tensor with `axis = 0`, computes @@ -634,33 +635,34 @@ def l2_normalize_v2(x, axis=None, epsilon=1e-12, name=None): integers. epsilon: A lower bound value for the norm. Will use `sqrt(epsilon)` as the divisor if `norm < sqrt(epsilon)`. + do_fusion: Whether fuse op when doing l2 norm on last axis. name: A name for this operation (optional). Returns: A `Tensor` with the same shape as `x`. 
""" - with ops.name_scope(name, "l2_normalize", [x]) as name: - x = ops.convert_to_tensor(x, name="x") - square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keepdims=True) - x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon)) - return math_ops.multiply(x, x_inv_norm, name=name) - + if do_fusion and x.dtype == dtypes.float32 and ( + axis is None or axis== x.shape.rank - 1): + return fused_l2_normalize(x, epsilon=epsilon, name=name) + else: + with ops.name_scope(name, "l2_normalize", [x]) as name: + x = ops.convert_to_tensor(x, name="x") + square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keepdims=True) + x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon)) + return math_ops.multiply(x, x_inv_norm, name=name) -@tf_export(v1=["math.fused_l2_normalize", "linalg.fused_l2_normalize", "nn.fused_l2_normalize"]) -def fused_l2_normalize(x, axis=None, epsilon=1e-12, name=None): - """Normalizes along dimension `axis` using an L2 norm. +def fused_l2_normalize(x, epsilon=1e-12, name=None): + """Normalizes along last dimension using an L2 norm. - For a 1-D tensor with `axis = 0`, computes + For a 1-D tensor, computes output = x / sqrt(max(sum(x**2), epsilon)) For `x` with more dimensions, independently normalizes each 1-D slice along - dimension `axis`. + lastdimension. Args: x: A `Tensor`. - axis: Dimension along which to normalize. A scalar or a vector of - integers. epsilon: A lower bound value for the norm. Will use `sqrt(epsilon)` as the divisor if `norm < sqrt(epsilon)`. name: A name for this operation (optional). @@ -670,8 +672,8 @@ def fused_l2_normalize(x, axis=None, epsilon=1e-12, name=None): """ with ops.name_scope(name, "fused_l2_normalize", [x]) as name: x = ops.convert_to_tensor(x, name="x") - return gen_fused_l2_normalize_ops.fused_l2_normalize( - x, axis=axis, epsilon=epsilon, name=name) + return gen_fused_l2_normalize_ops.fused_l2_normalize(x, + epsilon=epsilon, name=name) def _count_nonzero(input_tensor, dtype=dtypes.int64): diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index 6c8bfd77abe..8112dc6512b 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -320,25 +320,25 @@ def testFusedL2Normalize(self): x_shape = [20, 7, 3] np.random.seed(1) x_np = np.random.random_sample(x_shape).astype(np.float32) - for dim in [0]: + for dim in [2]: y_np = self._l2Normalize(x_np, dim) x_tf = constant_op.constant(x_np, name="x") - y_tf = nn_impl.fused_l2_normalize(x_tf, dim) + y_tf = nn_impl.fused_l2_normalize(x_tf) self.assertAllClose(y_np, self.evaluate(y_tf)) - # @test_util.run_deprecated_v1 - # def testFusedL2NormalizeGradient(self): - # x_shape = [20, 7, 3] - # np.random.seed(1) - # x_np = np.random.random_sample(x_shape).astype(np.float64) - # for dim in range(len(x_shape)): - # with self.cached_session(): - # x_tf = constant_op.constant(x_np, name="x") - # y_tf = nn_impl.l2_normalize_v2(x_tf, dim) - # err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, - # x_shape) - # print("L2Normalize gradient err = %g " % err) - # self.assertLess(err, 1e-4) + @test_util.run_deprecated_v1 + def testFusedL2NormalizeGradient(self): + x_shape = [20, 7, 3] + np.random.seed(1) + x_np = np.random.random_sample(x_shape).astype(np.float32) + for dim in [2]: + with self.cached_session(): + x_tf = constant_op.constant(x_np, name="x") + y_tf = nn_impl.fused_l2_normalize(x_tf) + err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, + x_shape) + print("L2Normalize gradient err = %g 
" % err) + self.assertLess(err, 1e-4) # class DropoutTest(test_lib.TestCase): From 9592612da32923619137db014a98c6d57aac1b49 Mon Sep 17 00:00:00 2001 From: Duyi-Wang Date: Fri, 27 May 2022 10:26:22 +0800 Subject: [PATCH 08/17] [Op] Add l2 norm grad test. --- tensorflow/core/kernels/BUILD | 3 +- .../fused_l2_normalize_grad_op_test.cc | 73 +++++++++++++++++++ .../fused_l2_normalize_op.cc | 41 +++++++++-- .../fused_l2_normalize_op_test.cc | 13 +--- tensorflow/python/ops/nn_test.py | 6 +- 5 files changed, 115 insertions(+), 21 deletions(-) create mode 100644 tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_grad_op_test.cc diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index a292b6f87f6..804cb641935 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -5415,7 +5415,8 @@ tf_kernel_library( tf_cc_test( name = "fused_l2_normalize_ops_test", size = "small", - srcs = ["fused_l2_normalize/fused_l2_normalize_op_test.cc"], + srcs = ["fused_l2_normalize/fused_l2_normalize_op_test.cc", + "fused_l2_normalize/fused_l2_normalize_grad_op_test.cc"], deps = [ ":fused_l2_normalize_ops", ":ops_testutil", diff --git a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_grad_op_test.cc b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_grad_op_test.cc new file mode 100644 index 00000000000..32086fada3f --- /dev/null +++ b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_grad_op_test.cc @@ -0,0 +1,73 @@ +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/conv_ops_gpu.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/cc/ops/standard_ops.h" + +namespace tensorflow { +namespace { + +enum class Device { CPU, GPU }; + +class FusedL2NormalizeGradOpTest : public OpsTestBase { + protected: + void MakeOpAndSetDevice(Device device, DataType dtype, int axis, float epsilon) { + TF_EXPECT_OK(NodeDefBuilder("fused_l2_normalize_grad", "FusedL2NormalizeGrad") + .Attr("T", dtype) + .Attr("T", dtype) + .Attr("axis", axis) + .Attr("epsilon", epsilon) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + } +}; + +TEST_F(FusedL2NormalizeGradOpTest, 2Dims_Float) { + const int rows = 4; + const int cols = 252; //128+64+32+16+8+4=252 1008 + + MakeOpAndSetDevice(Device::CPU, DT_FLOAT, 0, 1e-12); + + // y_grad + float y_grad_array[1008]; + for (int i = 0; i < sizeof(y_grad_array) / sizeof(float); i++) { + y_grad_array[i] = 1.0; + } + AddInputFromArray(TensorShape({rows, cols}), y_grad_array); + + // x + float x_array[1008]; + for (int i = 0; i < sizeof(x_array) / sizeof(float); i++) { + x_array[i] = 1.0; + } + AddInputFromArray(TensorShape({rows, cols}), x_array); + + TF_ASSERT_OK(RunOpKernel()); + TF_EXPECT_OK(device_->Sync()); + + { + Tensor expected_output(allocator(), DT_FLOAT, + TensorShape({rows, cols})); + float output_array[1008]; + for (int i = 0; i < sizeof(output_array) / sizeof(float); i++) { + output_array[i] = 0; + } + test::FillValues(&expected_output, output_array); + 
test::ExpectTensorNear(expected_output, *GetOutput(0), 1e-6); + } +} + +//----------------------------------------------------------------------------// +// Performance benchmarks // +//----------------------------------------------------------------------------// +} +} diff --git a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc index 78f5cd9b733..e96ce96d075 100644 --- a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc +++ b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc @@ -65,7 +65,11 @@ class FusedL2NormalizeOp : public OpKernel { if (end_row > rows) { end_row = rows; } +#ifdef __AVX512F__ + forward_avx512<8>(input, output, begin_row, end_row, cols); +#else forward<8>(input, output, begin_row, end_row, cols); +#endif }); } @@ -76,11 +80,12 @@ class FusedL2NormalizeOp : public OpKernel { // temp = tf.math.rsqrt(temp) // outputs = tf.math.multiply(temp, inputs) template - void ref_forward(const T* input, T* output, int64 begin_row, int64 end_row, int64 cols) { + void forward(const T* input, T* output, int64 begin_row, int64 end_row, int64 cols) { + int64 remainder = cols % SUM_BLOCK_SIZE; + // printf("Cols is %d, block size is %d, remainder is %d.\n", cols, SUM_BLOCK_SIZE, remainder); for (int64 i = begin_row; i < end_row; ++i) { T row_sum = 0; - // must be SUM_BLOCK_SIZE block !!! - for (int64 j = 0; j < cols; j += SUM_BLOCK_SIZE) { + for (int64 j = 0; j < cols - remainder; j += SUM_BLOCK_SIZE) { T data_0 = input[i * cols + j]; T data_1 = input[i * cols + j + 1]; T data_2 = input[i * cols + j + 2]; @@ -94,6 +99,10 @@ class FusedL2NormalizeOp : public OpKernel { + data_4 * data_4 + data_5 * data_5 + data_6 * data_6 + data_7 * data_7; } + for (int64 j = cols - remainder; j < cols; j++) { + T data_0 = input[i * cols + j]; + row_sum += data_0 * data_0; + } row_sum += epsilon; row_sum = 1.0 / std::sqrt(row_sum); for (int64 j = 0; j < cols; ++j) { @@ -102,8 +111,10 @@ class FusedL2NormalizeOp : public OpKernel { } } +#ifdef __AVX512F__ template - void forward(const T* input, T* output, int64 begin_row, int64 end_row, int64 cols) { + void forward_avx512(const T* input, T* output, int64 begin_row, int64 end_row, int64 cols) { + // printf("Fused L2 norm by AVX512."); int64 avx3_block_num = cols >> 7; // cols / 128 // handle remainder of 128 int64 remainder = cols - (avx3_block_num << 7); @@ -207,6 +218,7 @@ class FusedL2NormalizeOp : public OpKernel { block_sum = _mm256_add_ps(block_sum, v[3]); return block_sum; } +#endif private: float epsilon; @@ -265,7 +277,11 @@ class FusedL2NormalizeGradOp : public OpKernel { if (end_row > rows) { end_row = rows; } +#ifdef __AVX512F__ + backward_avx512<8>(y_grad, x, x_grad, begin_row, end_row, cols); +#else backward<8>(y_grad, x, x_grad, begin_row, end_row, cols); +#endif }); } @@ -274,12 +290,13 @@ class FusedL2NormalizeGradOp : public OpKernel { // sum = tf.math.reduce_sum(y_grad * x, reduction_indices=1, keepdims=True) // grad_x = y_grad * rvar - x * ((sum * rvar) * (rvar * rvar)) template - void ref_backward(const float *y_grad, const float *x, float *x_grad, int64 begin_row, int64 end_row, int64 cols) { + void backward(const float *y_grad, const float *x, float *x_grad, int64 begin_row, int64 end_row, int64 cols) { + int64 remainder = cols % SUM_BLOCK_SIZE; for (int64 i = begin_row; i < end_row; ++i) { int64 new_row = i - begin_row; T x_row_sum = 0.0; T y_grad_row_sum = 0.0; - for (int64 j = cols - 1; j > 0; j -= 
SUM_BLOCK_SIZE) { + for (int64 j = cols - 1; j > remainder; j -= SUM_BLOCK_SIZE) { T x_0 = x[i * cols + j]; T x_1 = x[i * cols + j - 1]; T x_2 = x[i * cols + j - 2]; @@ -306,6 +323,13 @@ class FusedL2NormalizeGradOp : public OpKernel { + x_4 * y_grad_4 + x_5 * y_grad_5 + x_6 * y_grad_6 + x_7 * y_grad_7; } + for (int64 j = remainder; j > 0; j--) { + T x_0 = x[i * cols + j]; + x_row_sum += x_0 * x_0; + + T y_grad_0 = y_grad[i * cols + j]; + y_grad_row_sum += x_0 * y_grad_0; + } x_row_sum += epsilon; x_row_sum = 1.0 / std::sqrt(x_row_sum); // rvar y_grad_row_sum = (y_grad_row_sum * x_row_sum) * (x_row_sum * x_row_sum); @@ -316,8 +340,10 @@ class FusedL2NormalizeGradOp : public OpKernel { } } +#ifdef __AVX512F__ template - void backward(const float *y_grad, const float *x, float *x_grad, int64 begin_row, int64 end_row, int64 cols) { + void backward_avx512(const float *y_grad, const float *x, float *x_grad, int64 begin_row, int64 end_row, int64 cols) { + // printf("Fused L2 norm grad by AVX512."); int64 avx3_block_num = cols >> 7; // cols / 128 // handle remainder of 128 int64 remainder = cols - (avx3_block_num << 7); @@ -456,6 +482,7 @@ class FusedL2NormalizeGradOp : public OpKernel { block_sum = _mm256_add_ps(block_sum, v[3]); return block_sum; } +#endif private: float epsilon; diff --git a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op_test.cc b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op_test.cc index 4a74340e5b5..e121ddbbb53 100644 --- a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op_test.cc +++ b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op_test.cc @@ -31,14 +31,11 @@ class FusedL2NormalizeOpTest : public OpsTestBase { TEST_F(FusedL2NormalizeOpTest, 2Dims_Float) { const int rows = 4; - const int cols = 252; //128+64+32+16+8=252 1008 + const int cols = 252; //128+64+32+16+8+4=252 1008 MakeOpAndSetDevice(Device::CPU, DT_FLOAT, 0, 1e-12); - // emb_shards - // AddInputFromArray(TensorShape({rows, cols}), { - // 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - // 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}); + // x float input_array[1008]; for (int i = 0; i < sizeof(input_array) / sizeof(float); i++) { input_array[i] = 1.0; @@ -51,13 +48,9 @@ TEST_F(FusedL2NormalizeOpTest, 2Dims_Float) { { Tensor expected_output(allocator(), DT_FLOAT, TensorShape({rows, cols})); - // test::FillValues(&expected_output, { - // 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, - // 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5}); float output_array[1008]; - float output_value = 1.0 / std::sqrt(cols); for (int i = 0; i < sizeof(output_array) / sizeof(float); i++) { - output_array[i] = output_value; + output_array[i] = 0.062994122505188; } test::FillValues(&expected_output, output_array); test::ExpectTensorNear(expected_output, *GetOutput(0), 1e-6); diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index 8112dc6512b..c3928770790 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -318,7 +318,7 @@ def _l2Normalize(self, x, dim): @test_util.run_deprecated_v1 def testFusedL2Normalize(self): x_shape = [20, 7, 3] - np.random.seed(1) + np.random.seed(0) x_np = np.random.random_sample(x_shape).astype(np.float32) for dim in [2]: y_np = self._l2Normalize(x_np, dim) @@ -329,7 +329,7 @@ def testFusedL2Normalize(self): @test_util.run_deprecated_v1 def testFusedL2NormalizeGradient(self): x_shape = [20, 7, 3] - np.random.seed(1) + np.random.seed(0) x_np = np.random.random_sample(x_shape).astype(np.float32) for dim 
in [2]: with self.cached_session(): @@ -337,7 +337,7 @@ def testFusedL2NormalizeGradient(self): y_tf = nn_impl.fused_l2_normalize(x_tf) err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, x_shape) - print("L2Normalize gradient err = %g " % err) + print("FusedL2Normalize gradient err = %g " % err) self.assertLess(err, 1e-4) From d9b0f8ae15782f51998b697172fb1ea69a41b62a Mon Sep 17 00:00:00 2001 From: Duyi-Wang Date: Mon, 30 May 2022 14:11:12 +0800 Subject: [PATCH 09/17] [Op] Add l2 norm called log info. --- .../kernels/fused_l2_normalize/fused_l2_normalize_op.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc index e96ce96d075..2b5bc5a5cd3 100644 --- a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc +++ b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc @@ -26,7 +26,6 @@ class FusedL2NormalizeOp : public OpKernel { } ~FusedL2NormalizeOp() { - printf("RUN ~FusedL2NormalizeOp().\n"); } void Compute(OpKernelContext* context) override { @@ -82,6 +81,7 @@ class FusedL2NormalizeOp : public OpKernel { template void forward(const T* input, T* output, int64 begin_row, int64 end_row, int64 cols) { int64 remainder = cols % SUM_BLOCK_SIZE; + printf("Fused l2 norm called.\n"); // printf("Cols is %d, block size is %d, remainder is %d.\n", cols, SUM_BLOCK_SIZE, remainder); for (int64 i = begin_row; i < end_row; ++i) { T row_sum = 0; @@ -114,7 +114,7 @@ class FusedL2NormalizeOp : public OpKernel { #ifdef __AVX512F__ template void forward_avx512(const T* input, T* output, int64 begin_row, int64 end_row, int64 cols) { - // printf("Fused L2 norm by AVX512."); + printf("AVX512 fused l2 norm called.\n"); int64 avx3_block_num = cols >> 7; // cols / 128 // handle remainder of 128 int64 remainder = cols - (avx3_block_num << 7); @@ -292,8 +292,8 @@ class FusedL2NormalizeGradOp : public OpKernel { template void backward(const float *y_grad, const float *x, float *x_grad, int64 begin_row, int64 end_row, int64 cols) { int64 remainder = cols % SUM_BLOCK_SIZE; + printf("Fused l2 norm grad called.\n"); for (int64 i = begin_row; i < end_row; ++i) { - int64 new_row = i - begin_row; T x_row_sum = 0.0; T y_grad_row_sum = 0.0; for (int64 j = cols - 1; j > remainder; j -= SUM_BLOCK_SIZE) { @@ -343,7 +343,7 @@ class FusedL2NormalizeGradOp : public OpKernel { #ifdef __AVX512F__ template void backward_avx512(const float *y_grad, const float *x, float *x_grad, int64 begin_row, int64 end_row, int64 cols) { - // printf("Fused L2 norm grad by AVX512."); + printf("AVX512 fused l2 norm grad called.\n"); int64 avx3_block_num = cols >> 7; // cols / 128 // handle remainder of 128 int64 remainder = cols - (avx3_block_num << 7); From e4408dec8e517b4d72725cf406c03a7c24319f59 Mon Sep 17 00:00:00 2001 From: Duyi-Wang Date: Thu, 2 Jun 2022 16:43:15 +0800 Subject: [PATCH 10/17] [Op] Add performance benchmarks. 
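For readers of the benchmark numbers below: both the scalar and the AVX512
paths compute, per row along the last axis, y = x * rsqrt(sum(x*x) + epsilon)
for the forward op, and dx = dy * rvar - x * (sum(dy*x) * rvar^3) with
rvar = rsqrt(sum(x*x) + epsilon) for the gradient. A minimal NumPy sketch of
that math is included here purely for reference; the helper names are
illustrative and are not part of this patch.

    import numpy as np

    def l2_normalize_ref(x, epsilon=1e-12):
        # Forward: normalize along the last axis.
        sq_sum = np.sum(x * x, axis=-1, keepdims=True)
        return x / np.sqrt(sq_sum + epsilon)

    def l2_normalize_grad_ref(dy, x, epsilon=1e-12):
        # Backward: dx = dy * rvar - x * (sum(dy * x) * rvar**3).
        rvar = 1.0 / np.sqrt(np.sum(x * x, axis=-1, keepdims=True) + epsilon)
        s = np.sum(dy * x, axis=-1, keepdims=True)
        return dy * rvar - x * (s * rvar * rvar * rvar)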
--- .../fused_l2_normalize_grad_op_test.cc | 43 + .../fused_l2_normalize_op.cc | 690 ++-- .../fused_l2_normalize_op_test.cc | 38 + tensorflow/python/ops/nn_test.py | 3136 ++++++++--------- 4 files changed, 2006 insertions(+), 1901 deletions(-) diff --git a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_grad_op_test.cc b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_grad_op_test.cc index 32086fada3f..b6762b35589 100644 --- a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_grad_op_test.cc +++ b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_grad_op_test.cc @@ -69,5 +69,48 @@ TEST_F(FusedL2NormalizeGradOpTest, 2Dims_Float) { //----------------------------------------------------------------------------// // Performance benchmarks // //----------------------------------------------------------------------------// +static Graph* FusedL2NormalizeGrad(int rows, int cols) { + Graph* g = new Graph(OpRegistry::Global()); + DataType dtype = DT_FLOAT; + + Tensor in1(dtype, TensorShape({rows, cols})); + in1.flat().setRandom(); + Tensor in2(dtype, TensorShape({rows, cols})); + in2.flat().setRandom(); + + Node* input_in1 = test::graph::Constant(g, in1); + Node* input_in2 = test::graph::Constant(g, in2); + auto nodeBuilder = NodeBuilder(g->NewName("n"), "FusedL2NormalizeGrad") + .Input(input_in1) + .Input(input_in2) + .Attr("T", dtype) + .Attr("axis", 0) + .Attr("epsilon", 1e-12); + TF_CHECK_OK(nodeBuilder.Finalize(g, nullptr)); + + return g; +} + +#define BM_FusedL2NormGrad(ROWS, COLS, NTH) \ + static void BM_FusedL2NormGrad##_##ROWS##_##COLS##_##NTH##_CPU( \ + int iters) { \ + testing::UseRealTime(); \ + testing::ItemsProcessed(static_cast(iters) * ROWS * COLS * 5); \ + SessionOptions opts; \ + opts.config.set_intra_op_parallelism_threads(NTH); \ + test::Benchmark("cpu", FusedL2NormalizeGrad(ROWS, COLS), &opts).Run(iters); \ + } \ + BENCHMARK(BM_FusedL2NormGrad##_##ROWS##_##COLS##_##NTH##_CPU); \ + +#define BM_FusedL2NormGrad_NTH(ROWS, COLS) \ + BM_FusedL2NormGrad(ROWS, COLS, 1); \ + BM_FusedL2NormGrad(ROWS, COLS, 4); \ + BM_FusedL2NormGrad(ROWS, COLS, 8); \ + +BM_FusedL2NormGrad_NTH(1024, 63); +BM_FusedL2NormGrad_NTH(1024, 255); +BM_FusedL2NormGrad_NTH(1024, 511); +BM_FusedL2NormGrad_NTH(1024, 1023); + } } diff --git a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc index 2b5bc5a5cd3..e9993db57a7 100644 --- a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc +++ b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc @@ -73,156 +73,162 @@ class FusedL2NormalizeOp : public OpKernel { } private: - // temp = tf.math.square(inputs) - // temp = tf.math.reduce_sum(temp, reduction_indices=axis, keepdims=True) - // temp = tf.math.maximum(temp, epsilon) - // temp = tf.math.rsqrt(temp) - // outputs = tf.math.multiply(temp, inputs) - template - void forward(const T* input, T* output, int64 begin_row, int64 end_row, int64 cols) { - int64 remainder = cols % SUM_BLOCK_SIZE; - printf("Fused l2 norm called.\n"); - // printf("Cols is %d, block size is %d, remainder is %d.\n", cols, SUM_BLOCK_SIZE, remainder); - for (int64 i = begin_row; i < end_row; ++i) { - T row_sum = 0; - for (int64 j = 0; j < cols - remainder; j += SUM_BLOCK_SIZE) { - T data_0 = input[i * cols + j]; - T data_1 = input[i * cols + j + 1]; - T data_2 = input[i * cols + j + 2]; - T data_3 = input[i * cols + j + 3]; - T data_4 = input[i * cols + j + 4]; - T data_5 = input[i 
* cols + j + 5]; - T data_6 = input[i * cols + j + 6]; - T data_7 = input[i * cols + j + 7]; - row_sum += data_0 * data_0 + data_1 * data_1 - + data_2 * data_2 + data_3 * data_3 - + data_4 * data_4 + data_5 * data_5 - + data_6 * data_6 + data_7 * data_7; - } - for (int64 j = cols - remainder; j < cols; j++) { - T data_0 = input[i * cols + j]; - row_sum += data_0 * data_0; - } - row_sum += epsilon; - row_sum = 1.0 / std::sqrt(row_sum); - for (int64 j = 0; j < cols; ++j) { - output[i * cols + j] = input[i * cols + j] * row_sum; - } - } + // temp = tf.math.square(inputs) + // temp = tf.math.reduce_sum(temp, reduction_indices=axis, keepdims=True) + // temp = tf.math.maximum(temp, epsilon) + // temp = tf.math.rsqrt(temp) + // outputs = tf.math.multiply(temp, inputs) + template + void forward(const T* input, T* output, int64 begin_row, int64 end_row, + int64 cols) { + int64 remainder = cols % SUM_BLOCK_SIZE; + for (int64 i = begin_row; i < end_row; ++i) { + T row_sum = 0; + for (int64 j = 0; j < cols - remainder; j += SUM_BLOCK_SIZE) { + T data_0 = input[i * cols + j]; + T data_1 = input[i * cols + j + 1]; + T data_2 = input[i * cols + j + 2]; + T data_3 = input[i * cols + j + 3]; + T data_4 = input[i * cols + j + 4]; + T data_5 = input[i * cols + j + 5]; + T data_6 = input[i * cols + j + 6]; + T data_7 = input[i * cols + j + 7]; + row_sum += data_0 * data_0 + data_1 * data_1 + data_2 * data_2 + + data_3 * data_3 + data_4 * data_4 + data_5 * data_5 + + data_6 * data_6 + data_7 * data_7; + } + for (int64 j = cols - remainder; j < cols; j++) { + T data_0 = input[i * cols + j]; + row_sum += data_0 * data_0; + } + row_sum += epsilon; + row_sum = 1.0 / std::sqrt(row_sum); + for (int64 j = 0; j < cols; ++j) { + output[i * cols + j] = input[i * cols + j] * row_sum; + } } + } #ifdef __AVX512F__ - template - void forward_avx512(const T* input, T* output, int64 begin_row, int64 end_row, int64 cols) { - printf("AVX512 fused l2 norm called.\n"); - int64 avx3_block_num = cols >> 7; // cols / 128 - // handle remainder of 128 - int64 remainder = cols - (avx3_block_num << 7); - // printf("cols: %d, avx3_block_num: %d, remainder %d\n", cols, avx3_block_num, remainder); - for (int64 i = begin_row; i < end_row; ++i) { - int64 tmp_remainder = remainder; - float row_sum = 0.0; - for (int64 j = 0; j < avx3_block_num; ++j) { - __m512 inputs[SUM_BLOCK_SIZE]; - auto load = [&](auto idx) { - inputs[idx] = _mm512_loadu_ps(input + cols * i + 16 * SUM_BLOCK_SIZE * j + 16 * idx); - inputs[idx] = _mm512_mul_ps(inputs[idx], inputs[idx]); - }; - functor::compile_time_for::op(load); - __m512 block_sum = reduce_sum_block8_ps(inputs); - row_sum += _mm512_reduce_add_ps(block_sum); - } - if (tmp_remainder > 0) { - if (tmp_remainder >= 64) { - __m256 inputs[8]; - auto load_256 = [&](auto idx) { - inputs[idx] = _mm256_loadu_ps(input + cols * i + cols - tmp_remainder + 8 * idx); - inputs[idx] = _mm256_mul_ps(inputs[idx], inputs[idx]); - }; - functor::compile_time_for<8>::op(load_256); - __m256 block_sum_remainder = reduce_sum_block8_mm256_ps(inputs); - row_sum += _mm512_reduce_add_ps(_mm512_castps256_ps512(block_sum_remainder)); - tmp_remainder -= 64; - } - if (tmp_remainder > 32) { - __m256 inputs[4]; - auto load_256 = [&](auto idx) { - inputs[idx] = _mm256_loadu_ps(input + cols * i + cols - tmp_remainder + 8 * idx); - inputs[idx] = _mm256_mul_ps(inputs[idx], inputs[idx]); - }; - functor::compile_time_for<4>::op(load_256); - __m256 block_sum_remainder = reduce_sum_block4_mm256_ps(inputs); - row_sum += 
_mm512_reduce_add_ps(_mm512_castps256_ps512(block_sum_remainder)); - tmp_remainder -= 32; - } - if (tmp_remainder >= 16) { - __m512 inputs = _mm512_loadu_ps(input + cols * i + cols - tmp_remainder); - inputs = _mm512_mul_ps(inputs, inputs); - row_sum += _mm512_reduce_add_ps(inputs); - tmp_remainder -= 16; - } - if (tmp_remainder > 0) { - __mmask16 mask = 0xFFFF >> (16 - tmp_remainder); - __m512 inputs = _mm512_maskz_loadu_ps(mask, input + cols * i + cols - tmp_remainder); - inputs = _mm512_mul_ps(inputs, inputs); - row_sum += _mm512_reduce_add_ps(inputs); - } - } - - row_sum += epsilon; - row_sum = 1.0 / std::sqrt(row_sum); - __m512 row_sums = _mm512_set1_ps(row_sum); - for (int64 j = 0; j < cols - 15; j += 16) { - __m512 inputs = _mm512_loadu_ps(input + cols * i + j); - inputs = _mm512_mul_ps(inputs, row_sums); - _mm512_storeu_ps(output + cols * i + j, inputs); - } - if (tmp_remainder > 0){ - __mmask16 mask = 0xFFFF >> (16 - tmp_remainder); - __m512 inputs = _mm512_maskz_loadu_ps(mask, input + cols * i + cols - tmp_remainder); - inputs = _mm512_mul_ps(inputs, row_sums); - _mm512_mask_storeu_ps(output + cols * i + cols - tmp_remainder, mask, inputs); - } + template + void forward_avx512(const T* input, T* output, int64 begin_row, int64 end_row, + int64 cols) { + int64 avx3_block_num = cols >> 7; // cols / 128 + // handle remainder of 128 + int64 remainder = cols - (avx3_block_num << 7); + for (int64 i = begin_row; i < end_row; ++i) { + int64 tmp_remainder = remainder; + float row_sum = 0.0; + for (int64 j = 0; j < avx3_block_num; ++j) { + __m512 inputs[SUM_BLOCK_SIZE]; + auto load = [&](auto idx) { + inputs[idx] = _mm512_loadu_ps(input + cols * i + + 16 * SUM_BLOCK_SIZE * j + 16 * idx); + inputs[idx] = _mm512_mul_ps(inputs[idx], inputs[idx]); + }; + functor::compile_time_for::op(load); + __m512 block_sum = reduce_sum_block8_ps(inputs); + row_sum += _mm512_reduce_add_ps(block_sum); + } + if (tmp_remainder > 0) { + if (tmp_remainder >= 64) { + __m256 inputs[8]; + auto load_256 = [&](auto idx) { + inputs[idx] = _mm256_loadu_ps(input + cols * i + cols - + tmp_remainder + 8 * idx); + inputs[idx] = _mm256_mul_ps(inputs[idx], inputs[idx]); + }; + functor::compile_time_for<8>::op(load_256); + __m256 block_sum_remainder = reduce_sum_block8_mm256_ps(inputs); + row_sum += + _mm512_reduce_add_ps(_mm512_castps256_ps512(block_sum_remainder)); + tmp_remainder -= 64; + } + if (tmp_remainder > 32) { + __m256 inputs[4]; + auto load_256 = [&](auto idx) { + inputs[idx] = _mm256_loadu_ps(input + cols * i + cols - + tmp_remainder + 8 * idx); + inputs[idx] = _mm256_mul_ps(inputs[idx], inputs[idx]); + }; + functor::compile_time_for<4>::op(load_256); + __m256 block_sum_remainder = reduce_sum_block4_mm256_ps(inputs); + row_sum += + _mm512_reduce_add_ps(_mm512_castps256_ps512(block_sum_remainder)); + tmp_remainder -= 32; } + if (tmp_remainder >= 16) { + __m512 inputs = + _mm512_loadu_ps(input + cols * i + cols - tmp_remainder); + inputs = _mm512_mul_ps(inputs, inputs); + row_sum += _mm512_reduce_add_ps(inputs); + tmp_remainder -= 16; + } + if (tmp_remainder > 0) { + __mmask16 mask = 0xFFFF >> (16 - tmp_remainder); + __m512 inputs = _mm512_maskz_loadu_ps( + mask, input + cols * i + cols - tmp_remainder); + inputs = _mm512_mul_ps(inputs, inputs); + row_sum += _mm512_reduce_add_ps(inputs); + } + } + + row_sum += epsilon; + row_sum = 1.0 / std::sqrt(row_sum); + __m512 row_sums = _mm512_set1_ps(row_sum); + for (int64 j = 0; j < cols - 15; j += 16) { + __m512 inputs = _mm512_loadu_ps(input + cols * i + j); + inputs = 
_mm512_mul_ps(inputs, row_sums); + _mm512_storeu_ps(output + cols * i + j, inputs); + } + if (tmp_remainder > 0) { + __mmask16 mask = 0xFFFF >> (16 - tmp_remainder); + __m512 inputs = _mm512_maskz_loadu_ps( + mask, input + cols * i + cols - tmp_remainder); + inputs = _mm512_mul_ps(inputs, row_sums); + _mm512_mask_storeu_ps(output + cols * i + cols - tmp_remainder, mask, + inputs); + } } + } - // data type: FP32, 16 FP32 per __m512 - // v0: v0_0, v0_1, ..., v0_15 - // v1: v1_0, v1_1, ..., v1_15 - // ... - // v7: v7_0, v7_1, ..., v7_15 - // sum: v_0, v_1, ..., v_15 - inline __m512 reduce_sum_block8_ps(const __m512 (&v)[8]) { - __m512 block_sum = _mm512_add_ps(v[0], v[1]); - block_sum = _mm512_add_ps(block_sum, v[2]); - block_sum = _mm512_add_ps(block_sum, v[3]); - block_sum = _mm512_add_ps(block_sum, v[4]); - block_sum = _mm512_add_ps(block_sum, v[5]); - block_sum = _mm512_add_ps(block_sum, v[6]); - block_sum = _mm512_add_ps(block_sum, v[7]); - return block_sum; - } - inline __m256 reduce_sum_block8_mm256_ps(const __m256 (&v)[8]) { - __m256 block_sum = _mm256_add_ps(v[0], v[1]); - block_sum = _mm256_add_ps(block_sum, v[2]); - block_sum = _mm256_add_ps(block_sum, v[3]); - block_sum = _mm256_add_ps(block_sum, v[4]); - block_sum = _mm256_add_ps(block_sum, v[5]); - block_sum = _mm256_add_ps(block_sum, v[6]); - block_sum = _mm256_add_ps(block_sum, v[7]); - return block_sum; - } - inline __m256 reduce_sum_block4_mm256_ps(const __m256 (&v)[4]) { - __m256 block_sum = _mm256_add_ps(v[0], v[1]); - block_sum = _mm256_add_ps(block_sum, v[2]); - block_sum = _mm256_add_ps(block_sum, v[3]); - return block_sum; - } + // data type: FP32, 16 FP32 per __m512 + // v0: v0_0, v0_1, ..., v0_15 + // v1: v1_0, v1_1, ..., v1_15 + // ... + // v7: v7_0, v7_1, ..., v7_15 + // sum: v_0, v_1, ..., v_15 + inline __m512 reduce_sum_block8_ps(const __m512 (&v)[8]) { + __m512 block_sum = _mm512_add_ps(v[0], v[1]); + block_sum = _mm512_add_ps(block_sum, v[2]); + block_sum = _mm512_add_ps(block_sum, v[3]); + block_sum = _mm512_add_ps(block_sum, v[4]); + block_sum = _mm512_add_ps(block_sum, v[5]); + block_sum = _mm512_add_ps(block_sum, v[6]); + block_sum = _mm512_add_ps(block_sum, v[7]); + return block_sum; + } + inline __m256 reduce_sum_block8_mm256_ps(const __m256 (&v)[8]) { + __m256 block_sum = _mm256_add_ps(v[0], v[1]); + block_sum = _mm256_add_ps(block_sum, v[2]); + block_sum = _mm256_add_ps(block_sum, v[3]); + block_sum = _mm256_add_ps(block_sum, v[4]); + block_sum = _mm256_add_ps(block_sum, v[5]); + block_sum = _mm256_add_ps(block_sum, v[6]); + block_sum = _mm256_add_ps(block_sum, v[7]); + return block_sum; + } + inline __m256 reduce_sum_block4_mm256_ps(const __m256 (&v)[4]) { + __m256 block_sum = _mm256_add_ps(v[0], v[1]); + block_sum = _mm256_add_ps(block_sum, v[2]); + block_sum = _mm256_add_ps(block_sum, v[3]); + return block_sum; + } #endif private: - float epsilon; - int32 axis; + float epsilon; + int32 axis; }; REGISTER_KERNEL_BUILDER(Name("FusedL2Normalize") \ @@ -286,207 +292,225 @@ class FusedL2NormalizeGradOp : public OpKernel { } private: - // rvar = tf.math.rsqrt(tf.math.reduce_sum(x * x, reduction_indices=1, keepdims=True) + 1e-12) # rsqrt quickly - // sum = tf.math.reduce_sum(y_grad * x, reduction_indices=1, keepdims=True) - // grad_x = y_grad * rvar - x * ((sum * rvar) * (rvar * rvar)) - template - void backward(const float *y_grad, const float *x, float *x_grad, int64 begin_row, int64 end_row, int64 cols) { - int64 remainder = cols % SUM_BLOCK_SIZE; - printf("Fused l2 norm grad called.\n"); - for (int64 i = 
begin_row; i < end_row; ++i) { - T x_row_sum = 0.0; - T y_grad_row_sum = 0.0; - for (int64 j = cols - 1; j > remainder; j -= SUM_BLOCK_SIZE) { - T x_0 = x[i * cols + j]; - T x_1 = x[i * cols + j - 1]; - T x_2 = x[i * cols + j - 2]; - T x_3 = x[i * cols + j - 3]; - T x_4 = x[i * cols + j - 4]; - T x_5 = x[i * cols + j - 5]; - T x_6 = x[i * cols + j - 6]; - T x_7 = x[i * cols + j - 7]; - x_row_sum += x_0 * x_0 + x_1 * x_1 - + x_2 * x_2 + x_3 * x_3 - + x_4 * x_4 + x_5 * x_5 - + x_6 * x_6 + x_7 * x_7; - - T y_grad_0 = y_grad[i * cols + j]; - T y_grad_1 = y_grad[i * cols + j - 1]; - T y_grad_2 = y_grad[i * cols + j - 2]; - T y_grad_3 = y_grad[i * cols + j - 3]; - T y_grad_4 = y_grad[i * cols + j - 4]; - T y_grad_5 = y_grad[i * cols + j - 5]; - T y_grad_6 = y_grad[i * cols + j - 6]; - T y_grad_7 = y_grad[i * cols + j - 7]; - y_grad_row_sum += x_0 * y_grad_0 + x_1 * y_grad_1 - + x_2 * y_grad_2 + x_3 * y_grad_3 - + x_4 * y_grad_4 + x_5 * y_grad_5 - + x_6 * y_grad_6 + x_7 * y_grad_7; - } - for (int64 j = remainder; j > 0; j--) { - T x_0 = x[i * cols + j]; - x_row_sum += x_0 * x_0; - - T y_grad_0 = y_grad[i * cols + j]; - y_grad_row_sum += x_0 * y_grad_0; - } - x_row_sum += epsilon; - x_row_sum = 1.0 / std::sqrt(x_row_sum); // rvar - y_grad_row_sum = (y_grad_row_sum * x_row_sum) * (x_row_sum * x_row_sum); - for (int64 j = 0; j < cols; ++j) { - x_grad[i * cols + j] = y_grad[i * cols + j] * x_row_sum - - x[i * cols + j] * y_grad_row_sum; - } - } + // rvar = tf.math.rsqrt(tf.math.reduce_sum(x * x, reduction_indices=1, + // keepdims=True) + 1e-12) # rsqrt quickly sum = tf.math.reduce_sum(y_grad * + // x, reduction_indices=1, keepdims=True) grad_x = y_grad * rvar - x * ((sum * + // rvar) * (rvar * rvar)) + template + void backward(const float* y_grad, const float* x, float* x_grad, + int64 begin_row, int64 end_row, int64 cols) { + int64 remainder = cols % SUM_BLOCK_SIZE; + for (int64 i = begin_row; i < end_row; ++i) { + T x_row_sum = 0.0; + T y_grad_row_sum = 0.0; + for (int64 j = cols - 1; j > remainder; j -= SUM_BLOCK_SIZE) { + T x_0 = x[i * cols + j]; + T x_1 = x[i * cols + j - 1]; + T x_2 = x[i * cols + j - 2]; + T x_3 = x[i * cols + j - 3]; + T x_4 = x[i * cols + j - 4]; + T x_5 = x[i * cols + j - 5]; + T x_6 = x[i * cols + j - 6]; + T x_7 = x[i * cols + j - 7]; + x_row_sum += x_0 * x_0 + x_1 * x_1 + x_2 * x_2 + x_3 * x_3 + x_4 * x_4 + + x_5 * x_5 + x_6 * x_6 + x_7 * x_7; + + T y_grad_0 = y_grad[i * cols + j]; + T y_grad_1 = y_grad[i * cols + j - 1]; + T y_grad_2 = y_grad[i * cols + j - 2]; + T y_grad_3 = y_grad[i * cols + j - 3]; + T y_grad_4 = y_grad[i * cols + j - 4]; + T y_grad_5 = y_grad[i * cols + j - 5]; + T y_grad_6 = y_grad[i * cols + j - 6]; + T y_grad_7 = y_grad[i * cols + j - 7]; + y_grad_row_sum += x_0 * y_grad_0 + x_1 * y_grad_1 + x_2 * y_grad_2 + + x_3 * y_grad_3 + x_4 * y_grad_4 + x_5 * y_grad_5 + + x_6 * y_grad_6 + x_7 * y_grad_7; + } + for (int64 j = remainder; j > 0; j--) { + T x_0 = x[i * cols + j]; + x_row_sum += x_0 * x_0; + + T y_grad_0 = y_grad[i * cols + j]; + y_grad_row_sum += x_0 * y_grad_0; + } + x_row_sum += epsilon; + x_row_sum = 1.0 / std::sqrt(x_row_sum); // rvar + y_grad_row_sum = (y_grad_row_sum * x_row_sum) * (x_row_sum * x_row_sum); + for (int64 j = 0; j < cols; ++j) { + x_grad[i * cols + j] = + y_grad[i * cols + j] * x_row_sum - x[i * cols + j] * y_grad_row_sum; + } } + } #ifdef __AVX512F__ - template - void backward_avx512(const float *y_grad, const float *x, float *x_grad, int64 begin_row, int64 end_row, int64 cols) { - printf("AVX512 fused l2 norm grad 
called.\n"); - int64 avx3_block_num = cols >> 7; // cols / 128 - // handle remainder of 128 - int64 remainder = cols - (avx3_block_num << 7); - // printf("cols: %d, avx3_block_num: %d, remainder %d\n", cols, avx3_block_num, remainder); - for (int64 i = begin_row; i < end_row; ++i) { - T x_row_sum = 0.0; - T y_grad_row_sum = 0.0; - int64 tmp_remainder = remainder; - for (int64 j = 0; j < avx3_block_num; ++j) { - __m512 xs[SUM_BLOCK_SIZE]; - auto x_load = [&](auto idx) { - xs[idx] = _mm512_loadu_ps(x + cols * i + 16 * SUM_BLOCK_SIZE * j + 16 * idx); - xs[idx] = _mm512_mul_ps(xs[idx], xs[idx]); - }; - functor::compile_time_for::op(x_load); - __m512 x_block_sum = reduce_sum_block8_ps(xs); - x_row_sum += _mm512_reduce_add_ps(x_block_sum); - - __m512 y_grads[SUM_BLOCK_SIZE]; - auto y_grad_load = [&](auto idx) { - y_grads[idx] = _mm512_loadu_ps(y_grad + cols * i + 16 * SUM_BLOCK_SIZE * j + 16 * idx); - xs[idx] = _mm512_loadu_ps(x + cols * i + 16 * SUM_BLOCK_SIZE * j + 16 * idx); - y_grads[idx] = _mm512_mul_ps(y_grads[idx], xs[idx]); - }; - functor::compile_time_for::op(y_grad_load); - __m512 y_grad_block_sum = reduce_sum_block8_ps(y_grads); - y_grad_row_sum += _mm512_reduce_add_ps(y_grad_block_sum); - } - if (tmp_remainder > 0) { - if (tmp_remainder >= 64) { - __m256 xs[8]; - auto x_load_256 = [&](auto idx) { - xs[idx] = _mm256_loadu_ps(x + cols * i + cols - tmp_remainder + 8 * idx); - xs[idx] = _mm256_mul_ps(xs[idx], xs[idx]); - }; - functor::compile_time_for<8>::op(x_load_256); - __m256 block_sum_remainder = reduce_sum_block8_mm256_ps(xs); - x_row_sum += _mm512_reduce_add_ps(_mm512_castps256_ps512(block_sum_remainder)); - - __m256 y_grads[8]; - auto y_grad_load_256 = [&](auto idx) { - y_grads[idx] = _mm256_loadu_ps(y_grad + cols * i + cols - tmp_remainder + 8 * idx); - xs[idx] = _mm256_loadu_ps(x + cols * i + cols - tmp_remainder + 8 * idx); - y_grads[idx] = _mm256_mul_ps(y_grads[idx], xs[idx]); - }; - functor::compile_time_for<8>::op(y_grad_load_256); - __m256 y_grad_block_sum_remainder = reduce_sum_block8_mm256_ps(y_grads); - y_grad_row_sum += _mm512_reduce_add_ps(_mm512_castps256_ps512(y_grad_block_sum_remainder)); - tmp_remainder -= 64; - } - if (tmp_remainder > 32) { - __m256 xs[4]; - auto x_load_256 = [&](auto idx) { - xs[idx] = _mm256_loadu_ps(x + cols * i + cols - tmp_remainder + 8 * idx); - xs[idx] = _mm256_mul_ps(xs[idx], xs[idx]); - }; - functor::compile_time_for<4>::op(x_load_256); - __m256 block_sum_remainder = reduce_sum_block4_mm256_ps(xs); - x_row_sum += _mm512_reduce_add_ps(_mm512_castps256_ps512(block_sum_remainder)); - - __m256 y_grads[4]; - auto y_grad_load_256 = [&](auto idx) { - y_grads[idx] = _mm256_loadu_ps(y_grad + cols * i + cols - tmp_remainder + 8 * idx); - xs[idx] = _mm256_loadu_ps(x + cols * i + cols - tmp_remainder + 8 * idx); - y_grads[idx] = _mm256_mul_ps(y_grads[idx], xs[idx]); - }; - functor::compile_time_for<4>::op(y_grad_load_256); - __m256 y_grad_block_sum_remainder = reduce_sum_block4_mm256_ps(y_grads); - y_grad_row_sum += _mm512_reduce_add_ps(_mm512_castps256_ps512(y_grad_block_sum_remainder)); - tmp_remainder -= 32; - } - if (tmp_remainder >= 16) { - __m512 xs = _mm512_loadu_ps(x + cols * i + cols - tmp_remainder); - __m512 y_grads = _mm512_loadu_ps(y_grad + cols * i + cols - tmp_remainder); - x_row_sum += _mm512_reduce_add_ps(_mm512_mul_ps(xs, xs)); - y_grad_row_sum += _mm512_reduce_add_ps(_mm512_mul_ps(y_grads, xs)); - tmp_remainder -= 16; - } - if (tmp_remainder > 0) { - __mmask16 mask = 0xFFFF >> (16 - tmp_remainder); - __m512 xs = 
_mm512_maskz_loadu_ps(mask, x + cols * i + cols - tmp_remainder); - __m512 y_grads = _mm512_maskz_loadu_ps(mask, y_grad + cols * i + cols - tmp_remainder); - x_row_sum += _mm512_reduce_add_ps(_mm512_mul_ps(xs, xs)); - y_grad_row_sum += _mm512_reduce_add_ps(_mm512_mul_ps(y_grads, xs)); - } - } - - x_row_sum += epsilon; - x_row_sum = 1.0 / std::sqrt(x_row_sum); - y_grad_row_sum = (y_grad_row_sum * x_row_sum) * (x_row_sum * x_row_sum); - __m512 x_row_sums = _mm512_set1_ps(x_row_sum); - __m512 y_grad_row_sums = _mm512_set1_ps(y_grad_row_sum); - for (int64 j = 0; j < cols - 15; j += 16) { - __m512 y_grads = _mm512_loadu_ps(y_grad + cols * i + j); - __m512 xs = _mm512_loadu_ps(x + cols * i + j); - y_grads = _mm512_mul_ps(y_grads, x_row_sums); - xs = _mm512_mul_ps(xs, y_grad_row_sums); - y_grads = _mm512_sub_ps(y_grads, xs); - _mm512_storeu_ps(x_grad + cols * i + j, y_grads); - } - if (tmp_remainder > 0){ - __mmask16 mask = 0xFFFF >> (16 - tmp_remainder); - __m512 y_grads = _mm512_maskz_loadu_ps(mask, y_grad + cols * i + cols - tmp_remainder); - __m512 xs = _mm512_maskz_loadu_ps(mask, x + cols * i + cols - tmp_remainder); - y_grads = _mm512_mul_ps(y_grads, x_row_sums); - xs = _mm512_mul_ps(xs, y_grad_row_sums); - y_grads = _mm512_sub_ps(y_grads, xs); - _mm512_mask_storeu_ps(x_grad + cols * i + cols - tmp_remainder, mask, y_grads); - } + template + void backward_avx512(const float* y_grad, const float* x, float* x_grad, + int64 begin_row, int64 end_row, int64 cols) { + int64 avx3_block_num = cols >> 7; // cols / 128 + // handle remainder of 128 + int64 remainder = cols - (avx3_block_num << 7); + for (int64 i = begin_row; i < end_row; ++i) { + T x_row_sum = 0.0; + T y_grad_row_sum = 0.0; + int64 tmp_remainder = remainder; + for (int64 j = 0; j < avx3_block_num; ++j) { + __m512 xs[SUM_BLOCK_SIZE]; + auto x_load = [&](auto idx) { + xs[idx] = _mm512_loadu_ps(x + cols * i + 16 * SUM_BLOCK_SIZE * j + + 16 * idx); + xs[idx] = _mm512_mul_ps(xs[idx], xs[idx]); + }; + functor::compile_time_for::op(x_load); + __m512 x_block_sum = reduce_sum_block8_ps(xs); + x_row_sum += _mm512_reduce_add_ps(x_block_sum); + + __m512 y_grads[SUM_BLOCK_SIZE]; + auto y_grad_load = [&](auto idx) { + y_grads[idx] = _mm512_loadu_ps(y_grad + cols * i + + 16 * SUM_BLOCK_SIZE * j + 16 * idx); + xs[idx] = _mm512_loadu_ps(x + cols * i + 16 * SUM_BLOCK_SIZE * j + + 16 * idx); + y_grads[idx] = _mm512_mul_ps(y_grads[idx], xs[idx]); + }; + functor::compile_time_for::op(y_grad_load); + __m512 y_grad_block_sum = reduce_sum_block8_ps(y_grads); + y_grad_row_sum += _mm512_reduce_add_ps(y_grad_block_sum); + } + if (tmp_remainder > 0) { + if (tmp_remainder >= 64) { + __m256 xs[8]; + auto x_load_256 = [&](auto idx) { + xs[idx] = + _mm256_loadu_ps(x + cols * i + cols - tmp_remainder + 8 * idx); + xs[idx] = _mm256_mul_ps(xs[idx], xs[idx]); + }; + functor::compile_time_for<8>::op(x_load_256); + __m256 block_sum_remainder = reduce_sum_block8_mm256_ps(xs); + x_row_sum += + _mm512_reduce_add_ps(_mm512_castps256_ps512(block_sum_remainder)); + + __m256 y_grads[8]; + auto y_grad_load_256 = [&](auto idx) { + y_grads[idx] = _mm256_loadu_ps(y_grad + cols * i + cols - + tmp_remainder + 8 * idx); + xs[idx] = + _mm256_loadu_ps(x + cols * i + cols - tmp_remainder + 8 * idx); + y_grads[idx] = _mm256_mul_ps(y_grads[idx], xs[idx]); + }; + functor::compile_time_for<8>::op(y_grad_load_256); + __m256 y_grad_block_sum_remainder = + reduce_sum_block8_mm256_ps(y_grads); + y_grad_row_sum += _mm512_reduce_add_ps( + _mm512_castps256_ps512(y_grad_block_sum_remainder)); + 
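+          // The 64-float tail chunk above has now been folded into both
+          // running sums; drop it from the remainder before the 32/16/<16
+          // branches that follow.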
tmp_remainder -= 64; + } + if (tmp_remainder > 32) { + __m256 xs[4]; + auto x_load_256 = [&](auto idx) { + xs[idx] = + _mm256_loadu_ps(x + cols * i + cols - tmp_remainder + 8 * idx); + xs[idx] = _mm256_mul_ps(xs[idx], xs[idx]); + }; + functor::compile_time_for<4>::op(x_load_256); + __m256 block_sum_remainder = reduce_sum_block4_mm256_ps(xs); + x_row_sum += + _mm512_reduce_add_ps(_mm512_castps256_ps512(block_sum_remainder)); + + __m256 y_grads[4]; + auto y_grad_load_256 = [&](auto idx) { + y_grads[idx] = _mm256_loadu_ps(y_grad + cols * i + cols - + tmp_remainder + 8 * idx); + xs[idx] = + _mm256_loadu_ps(x + cols * i + cols - tmp_remainder + 8 * idx); + y_grads[idx] = _mm256_mul_ps(y_grads[idx], xs[idx]); + }; + functor::compile_time_for<4>::op(y_grad_load_256); + __m256 y_grad_block_sum_remainder = + reduce_sum_block4_mm256_ps(y_grads); + y_grad_row_sum += _mm512_reduce_add_ps( + _mm512_castps256_ps512(y_grad_block_sum_remainder)); + tmp_remainder -= 32; } + if (tmp_remainder >= 16) { + __m512 xs = _mm512_loadu_ps(x + cols * i + cols - tmp_remainder); + __m512 y_grads = + _mm512_loadu_ps(y_grad + cols * i + cols - tmp_remainder); + x_row_sum += _mm512_reduce_add_ps(_mm512_mul_ps(xs, xs)); + y_grad_row_sum += _mm512_reduce_add_ps(_mm512_mul_ps(y_grads, xs)); + tmp_remainder -= 16; + } + if (tmp_remainder > 0) { + __mmask16 mask = 0xFFFF >> (16 - tmp_remainder); + __m512 xs = + _mm512_maskz_loadu_ps(mask, x + cols * i + cols - tmp_remainder); + __m512 y_grads = _mm512_maskz_loadu_ps( + mask, y_grad + cols * i + cols - tmp_remainder); + x_row_sum += _mm512_reduce_add_ps(_mm512_mul_ps(xs, xs)); + y_grad_row_sum += _mm512_reduce_add_ps(_mm512_mul_ps(y_grads, xs)); + } + } + + x_row_sum += epsilon; + x_row_sum = 1.0 / std::sqrt(x_row_sum); + y_grad_row_sum = (y_grad_row_sum * x_row_sum) * (x_row_sum * x_row_sum); + __m512 x_row_sums = _mm512_set1_ps(x_row_sum); + __m512 y_grad_row_sums = _mm512_set1_ps(y_grad_row_sum); + for (int64 j = 0; j < cols - 15; j += 16) { + __m512 y_grads = _mm512_loadu_ps(y_grad + cols * i + j); + __m512 xs = _mm512_loadu_ps(x + cols * i + j); + y_grads = _mm512_mul_ps(y_grads, x_row_sums); + xs = _mm512_mul_ps(xs, y_grad_row_sums); + y_grads = _mm512_sub_ps(y_grads, xs); + _mm512_storeu_ps(x_grad + cols * i + j, y_grads); + } + if (tmp_remainder > 0) { + __mmask16 mask = 0xFFFF >> (16 - tmp_remainder); + __m512 y_grads = _mm512_maskz_loadu_ps( + mask, y_grad + cols * i + cols - tmp_remainder); + __m512 xs = + _mm512_maskz_loadu_ps(mask, x + cols * i + cols - tmp_remainder); + y_grads = _mm512_mul_ps(y_grads, x_row_sums); + xs = _mm512_mul_ps(xs, y_grad_row_sums); + y_grads = _mm512_sub_ps(y_grads, xs); + _mm512_mask_storeu_ps(x_grad + cols * i + cols - tmp_remainder, mask, + y_grads); + } } + } - inline __m512 reduce_sum_block8_ps(const __m512 (&v)[8]) { - __m512 block_sum = _mm512_add_ps(v[0], v[1]); - block_sum = _mm512_add_ps(block_sum, v[2]); - block_sum = _mm512_add_ps(block_sum, v[3]); - block_sum = _mm512_add_ps(block_sum, v[4]); - block_sum = _mm512_add_ps(block_sum, v[5]); - block_sum = _mm512_add_ps(block_sum, v[6]); - block_sum = _mm512_add_ps(block_sum, v[7]); - return block_sum; - } - inline __m256 reduce_sum_block8_mm256_ps(const __m256 (&v)[8]) { - __m256 block_sum = _mm256_add_ps(v[0], v[1]); - block_sum = _mm256_add_ps(block_sum, v[2]); - block_sum = _mm256_add_ps(block_sum, v[3]); - block_sum = _mm256_add_ps(block_sum, v[4]); - block_sum = _mm256_add_ps(block_sum, v[5]); - block_sum = _mm256_add_ps(block_sum, v[6]); - block_sum = 
_mm256_add_ps(block_sum, v[7]); - return block_sum; - } - inline __m256 reduce_sum_block4_mm256_ps(const __m256 (&v)[4]) { - __m256 block_sum = _mm256_add_ps(v[0], v[1]); - block_sum = _mm256_add_ps(block_sum, v[2]); - block_sum = _mm256_add_ps(block_sum, v[3]); - return block_sum; - } + inline __m512 reduce_sum_block8_ps(const __m512 (&v)[8]) { + __m512 block_sum = _mm512_add_ps(v[0], v[1]); + block_sum = _mm512_add_ps(block_sum, v[2]); + block_sum = _mm512_add_ps(block_sum, v[3]); + block_sum = _mm512_add_ps(block_sum, v[4]); + block_sum = _mm512_add_ps(block_sum, v[5]); + block_sum = _mm512_add_ps(block_sum, v[6]); + block_sum = _mm512_add_ps(block_sum, v[7]); + return block_sum; + } + inline __m256 reduce_sum_block8_mm256_ps(const __m256 (&v)[8]) { + __m256 block_sum = _mm256_add_ps(v[0], v[1]); + block_sum = _mm256_add_ps(block_sum, v[2]); + block_sum = _mm256_add_ps(block_sum, v[3]); + block_sum = _mm256_add_ps(block_sum, v[4]); + block_sum = _mm256_add_ps(block_sum, v[5]); + block_sum = _mm256_add_ps(block_sum, v[6]); + block_sum = _mm256_add_ps(block_sum, v[7]); + return block_sum; + } + inline __m256 reduce_sum_block4_mm256_ps(const __m256 (&v)[4]) { + __m256 block_sum = _mm256_add_ps(v[0], v[1]); + block_sum = _mm256_add_ps(block_sum, v[2]); + block_sum = _mm256_add_ps(block_sum, v[3]); + return block_sum; + } #endif private: - float epsilon; - int32 axis; + float epsilon; + int32 axis; }; REGISTER_KERNEL_BUILDER(Name("FusedL2NormalizeGrad") \ diff --git a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op_test.cc b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op_test.cc index e121ddbbb53..2c4f5e988ec 100644 --- a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op_test.cc +++ b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op_test.cc @@ -60,5 +60,43 @@ TEST_F(FusedL2NormalizeOpTest, 2Dims_Float) { //----------------------------------------------------------------------------// // Performance benchmarks // //----------------------------------------------------------------------------// +static Graph* FusedL2Normalize(int rows, int cols) { + Graph* g = new Graph(OpRegistry::Global()); + DataType dtype = DT_FLOAT; + + Tensor in(dtype, TensorShape({rows, cols})); + in.flat().setRandom(); + + Node* input_in = test::graph::Constant(g, in); + auto nodeBuilder = NodeBuilder(g->NewName("n"), "FusedL2Normalize") + .Input(input_in) + .Attr("T", dtype) + .Attr("axis", 0) + .Attr("epsilon", 1e-12); + TF_CHECK_OK(nodeBuilder.Finalize(g, nullptr)); + + return g; +} + +#define BM_FusedL2Norm(ROWS, COLS, NTH) \ + static void BM_FusedL2Norm##_##ROWS##_##COLS##_##NTH##_CPU( \ + int iters) { \ + testing::UseRealTime(); \ + testing::ItemsProcessed(static_cast(iters) * ROWS * COLS * 3); \ + SessionOptions opts; \ + opts.config.set_intra_op_parallelism_threads(NTH); \ + test::Benchmark("cpu", FusedL2Normalize(ROWS, COLS), &opts).Run(iters); \ + } \ + BENCHMARK(BM_FusedL2Norm##_##ROWS##_##COLS##_##NTH##_CPU); \ + +#define BM_FusedL2Norm_NTH(ROWS, COLS) \ + BM_FusedL2Norm(ROWS, COLS, 1); \ + BM_FusedL2Norm(ROWS, COLS, 4); \ + BM_FusedL2Norm(ROWS, COLS, 8); \ + +BM_FusedL2Norm_NTH(1024, 63); +BM_FusedL2Norm_NTH(1024, 255); +BM_FusedL2Norm_NTH(1024, 511); +BM_FusedL2Norm_NTH(1024, 1023); } } diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index c3928770790..44604f883da 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -45,226 +45,226 @@ from tensorflow.python.platform import 
test as test_lib -# class ZeroFractionTest(test_lib.TestCase): - -# def _ZeroFraction(self, x): -# assert x.shape -# total_elements = np.prod(x.shape) -# nonzeros = np.count_nonzero(x.flatten()) -# return 1.0 - nonzeros / total_elements - -# @test_util.run_deprecated_v1 -# def testZeroFraction(self): -# x_shape = [5, 17] -# x_np = np.random.randint(0, 2, size=x_shape).astype(np.float32) -# y_np = self._ZeroFraction(x_np) - -# x_tf = constant_op.constant(x_np) -# x_tf.set_shape(x_shape) -# y_tf = nn_impl.zero_fraction(x_tf) -# y_tf_np = self.evaluate(y_tf) - -# eps = 1e-8 -# self.assertAllClose(y_tf_np, y_np, eps) - -# @test_util.run_deprecated_v1 -# def testZeroFractionEmpty(self): -# x = np.zeros(0) -# y = self.evaluate(nn_impl.zero_fraction(x)) -# self.assertTrue(np.isnan(y)) - -# @test_util.run_deprecated_v1 -# def testZeroFraction2_27Zeros(self): -# sparsity = nn_impl.zero_fraction( -# array_ops.zeros([int(2**27 * 1.01)], dtype=dtypes.int8)) -# self.assertAllClose(1.0, self.evaluate(sparsity)) - -# @test_util.run_deprecated_v1 -# def testZeroFraction2_27Ones(self): -# sparsity = nn_impl.zero_fraction( -# array_ops.ones([int(2**27 * 1.01)], dtype=dtypes.int8)) -# self.assertAllClose(0.0, self.evaluate(sparsity)) - -# @test_util.run_deprecated_v1 -# def testUnknownSize(self): -# value = array_ops.placeholder(dtype=dtypes.float32) -# sparsity = nn_impl.zero_fraction(value) -# with self.cached_session() as sess: -# self.assertAllClose( -# 0.25, -# sess.run(sparsity, {value: [[0., 1.], [0.3, 2.]]})) - - -# class SoftmaxTest(test_lib.TestCase, parameterized.TestCase): - -# def _softmax(self, x): -# assert len(x.shape) == 2 -# m = x.max(1)[:, np.newaxis] -# u = np.exp(x - m) -# z = u.sum(1)[:, np.newaxis] -# return u / z - -# @test_util.run_in_graph_and_eager_modes -# def testSoftmax(self): -# x_shape = [5, 10] -# x_np = np.random.randn(*x_shape).astype(np.float32) -# y_np = self._softmax(x_np) -# x_tf = constant_op.constant(x_np) -# y_tf = nn_ops.softmax_v2(x_tf) -# y_tf_last_dim = nn_ops.softmax_v2(x_tf, 1) -# y_tf_np = self.evaluate(y_tf) -# y_tf_last_dim_np = self.evaluate(y_tf_last_dim) -# eps = 1e-3 -# self.assertAllClose(y_tf_np, y_np, eps) -# self.assertAllClose(y_tf_last_dim_np, y_np, eps) - -# def testSoftmaxAxes(self): -# arr = np.linspace(0., 1, 12).reshape(3, 4) -# x_neg_axis = nn_ops.softmax_v2(arr, axis=-2) -# y_pos_axis = nn_ops.softmax_v2(arr, axis=0) -# z_gt_axis = nn_ops.softmax_v2(arr, axis=0) -# x_neg_axis_tf = self.evaluate(x_neg_axis) -# y_pos_axis_tf = self.evaluate(y_pos_axis) -# z_gt_axis_tf = self.evaluate(z_gt_axis) -# eps = 1e-3 -# self.assertAllClose(x_neg_axis_tf, y_pos_axis_tf, eps) -# self.assertAllClose(y_pos_axis_tf, z_gt_axis_tf, eps) - -# def testSoftmaxExtendType(self): -# x_shape = [5, 10] -# x_np = np.random.randn(*x_shape).astype(np.float32) - -# x_f32_tf = constant_op.constant(x_np) -# x_bf16_tf = math_ops.cast(x_f32_tf, dtypes.bfloat16) -# y_f32_tf = self.evaluate(nn_ops.softmax(x_f32_tf)) -# y_bf16_tf = self.evaluate(nn_ops.softmax(x_bf16_tf)) -# expected = math_ops.cast(y_f32_tf, dtypes.bfloat16) -# tol = x_shape[1] * 1e-3 -# self.assertAllClose(y_bf16_tf, expected, rtol=tol, atol=tol) - -# @parameterized.parameters(((5, 10),), ((2, 3, 4),)) -# @test_util.run_deprecated_v1 -# def testGradient(self, x_shape): -# x_np = np.random.randn(*x_shape).astype(np.float64) -# with self.cached_session(): -# x_tf = constant_op.constant(x_np) -# y_tf = nn_ops.softmax_v2(x_tf) -# err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, -# x_shape) -# 
eps = 2e-8 -# self.assertLess(err, eps) - - -# class LogPoissonLossTest(test_lib.TestCase): - -# def _log_poisson_loss(self, x, z, compute_full_loss=False): -# lpl = np.exp(x) - z * x -# if compute_full_loss: -# stirling_approx = z * np.log(z) - z + 0.5 * np.log(2. * np.pi * z) -# lpl += np.ma.masked_array(stirling_approx, mask=(z <= 1)).filled(0.) -# return lpl - -# @test_util.run_in_graph_and_eager_modes -# def testLogPoissonLoss(self): -# x_shape = [5, 10] -# x_np = np.random.randn(*x_shape).astype(np.float32) -# z_np = np.random.randint(0, 5, size=x_shape).astype(np.float32) -# y_np = self._log_poisson_loss(x_np, z_np, compute_full_loss=False) -# y_np_stirling = self._log_poisson_loss(x_np, z_np, compute_full_loss=True) -# y_tf = nn_impl.log_poisson_loss(z_np, x_np, compute_full_loss=False) -# y_tf_stirling = nn_impl.log_poisson_loss(z_np, x_np, compute_full_loss=True) -# y_tf_np = self.evaluate(y_tf) -# y_tf_np_stirling = self.evaluate(y_tf_stirling) -# eps = 1e-3 -# self.assertAllClose(y_tf_np, y_np, eps) -# self.assertAllClose(y_tf_np_stirling, y_np_stirling, eps) - -# @test_util.run_deprecated_v1 -# def testGradient(self): -# x_shape = [5, 10] -# x_np = np.random.randn(*x_shape).astype(np.float64) -# z_np = np.random.randint(0, 5, size=x_shape).astype(np.float64) -# with self.cached_session(): -# x_tf = constant_op.constant(x_np) -# y_tf = nn_impl.log_poisson_loss(z_np, x_tf, compute_full_loss=False) -# y_tf_stirling = nn_impl.log_poisson_loss( -# z_np, x_tf, compute_full_loss=True) -# err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, -# x_shape) -# err_stirling = gradient_checker.compute_gradient_error( -# x_tf, x_shape, y_tf_stirling, x_shape) -# eps = 1e-6 -# self.assertLess(err, eps) -# self.assertLess(err_stirling, eps) - - -# class LogSoftmaxTest(test_lib.TestCase, parameterized.TestCase): - -# def _log_softmax(self, x): -# assert len(x.shape) == 2 -# m = x.max(1)[:, np.newaxis] -# u = x - m -# return u - np.log(np.sum(np.exp(u), 1, keepdims=True)) - -# @test_util.run_in_graph_and_eager_modes -# def testLogSoftmax(self): -# x_shape = [5, 10] -# x_np = np.random.randn(*x_shape).astype(np.float32) -# y_np = self._log_softmax(x_np) -# x_tf = constant_op.constant(x_np) -# y_tf = nn_ops.log_softmax_v2(x_tf) -# y_tf_np = self.evaluate(y_tf) -# eps = 1e-3 -# self.assertAllClose(y_tf_np, y_np, eps) - -# def testLogSoftmaxAxes(self): -# arr = np.linspace(0., 1, 12).reshape(3, 4) -# x_neg_axis = nn_ops.log_softmax_v2(arr, axis=-2) -# y_pos_axis = nn_ops.log_softmax_v2(arr, axis=0) -# z_gt_axis = nn_ops.log_softmax_v2(arr, axis=0) -# x_neg_axis_tf = self.evaluate(x_neg_axis) -# y_pos_axis_tf = self.evaluate(y_pos_axis) -# z_gt_axis_tf = self.evaluate(z_gt_axis) -# eps = 1e-3 -# self.assertAllClose(x_neg_axis_tf, y_pos_axis_tf, eps) -# self.assertAllClose(y_pos_axis_tf, z_gt_axis_tf, eps) - -# @parameterized.parameters(((5, 10),), ((2, 3, 4),)) -# @test_util.run_deprecated_v1 -# def testGradient(self, x_shape): -# x_np = np.random.randn(*x_shape).astype(np.float64) -# with self.cached_session(): -# x_tf = constant_op.constant(x_np) -# y_tf = nn_ops.log_softmax_v2(x_tf) -# err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, -# x_shape) -# eps = 1e-7 -# self.assertLess(err, eps) - - -# class L2LossTest(test_lib.TestCase): - -# @test_util.run_in_graph_and_eager_modes -# def testL2Loss(self): -# for dtype in [dtypes.float32, dtypes.float64]: -# x = constant_op.constant( -# [1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="x", dtype=dtype) -# l2loss = nn_ops.l2_loss(x) 
-# value = self.evaluate(l2loss) -# self.assertAllClose(7.0, value) - -# @test_util.run_deprecated_v1 -# def testGradient(self): -# x_shape = [20, 7, 3] -# np.random.seed(1) # Make it reproducible. -# x_val = np.random.random_sample(x_shape).astype(np.float64) -# with self.cached_session(): -# x = constant_op.constant(x_val, name="x") -# output = nn_ops.l2_loss(x) -# err = gradient_checker.compute_gradient_error(x, x_shape, output, [1]) -# print("L2Loss gradient err = %g " % err) -# err_tolerance = 1e-10 -# self.assertLess(err, err_tolerance) +class ZeroFractionTest(test_lib.TestCase): + + def _ZeroFraction(self, x): + assert x.shape + total_elements = np.prod(x.shape) + nonzeros = np.count_nonzero(x.flatten()) + return 1.0 - nonzeros / total_elements + + @test_util.run_deprecated_v1 + def testZeroFraction(self): + x_shape = [5, 17] + x_np = np.random.randint(0, 2, size=x_shape).astype(np.float32) + y_np = self._ZeroFraction(x_np) + + x_tf = constant_op.constant(x_np) + x_tf.set_shape(x_shape) + y_tf = nn_impl.zero_fraction(x_tf) + y_tf_np = self.evaluate(y_tf) + + eps = 1e-8 + self.assertAllClose(y_tf_np, y_np, eps) + + @test_util.run_deprecated_v1 + def testZeroFractionEmpty(self): + x = np.zeros(0) + y = self.evaluate(nn_impl.zero_fraction(x)) + self.assertTrue(np.isnan(y)) + + @test_util.run_deprecated_v1 + def testZeroFraction2_27Zeros(self): + sparsity = nn_impl.zero_fraction( + array_ops.zeros([int(2**27 * 1.01)], dtype=dtypes.int8)) + self.assertAllClose(1.0, self.evaluate(sparsity)) + + @test_util.run_deprecated_v1 + def testZeroFraction2_27Ones(self): + sparsity = nn_impl.zero_fraction( + array_ops.ones([int(2**27 * 1.01)], dtype=dtypes.int8)) + self.assertAllClose(0.0, self.evaluate(sparsity)) + + @test_util.run_deprecated_v1 + def testUnknownSize(self): + value = array_ops.placeholder(dtype=dtypes.float32) + sparsity = nn_impl.zero_fraction(value) + with self.cached_session() as sess: + self.assertAllClose( + 0.25, + sess.run(sparsity, {value: [[0., 1.], [0.3, 2.]]})) + + +class SoftmaxTest(test_lib.TestCase, parameterized.TestCase): + + def _softmax(self, x): + assert len(x.shape) == 2 + m = x.max(1)[:, np.newaxis] + u = np.exp(x - m) + z = u.sum(1)[:, np.newaxis] + return u / z + + @test_util.run_in_graph_and_eager_modes + def testSoftmax(self): + x_shape = [5, 10] + x_np = np.random.randn(*x_shape).astype(np.float32) + y_np = self._softmax(x_np) + x_tf = constant_op.constant(x_np) + y_tf = nn_ops.softmax_v2(x_tf) + y_tf_last_dim = nn_ops.softmax_v2(x_tf, 1) + y_tf_np = self.evaluate(y_tf) + y_tf_last_dim_np = self.evaluate(y_tf_last_dim) + eps = 1e-3 + self.assertAllClose(y_tf_np, y_np, eps) + self.assertAllClose(y_tf_last_dim_np, y_np, eps) + + def testSoftmaxAxes(self): + arr = np.linspace(0., 1, 12).reshape(3, 4) + x_neg_axis = nn_ops.softmax_v2(arr, axis=-2) + y_pos_axis = nn_ops.softmax_v2(arr, axis=0) + z_gt_axis = nn_ops.softmax_v2(arr, axis=0) + x_neg_axis_tf = self.evaluate(x_neg_axis) + y_pos_axis_tf = self.evaluate(y_pos_axis) + z_gt_axis_tf = self.evaluate(z_gt_axis) + eps = 1e-3 + self.assertAllClose(x_neg_axis_tf, y_pos_axis_tf, eps) + self.assertAllClose(y_pos_axis_tf, z_gt_axis_tf, eps) + + def testSoftmaxExtendType(self): + x_shape = [5, 10] + x_np = np.random.randn(*x_shape).astype(np.float32) + + x_f32_tf = constant_op.constant(x_np) + x_bf16_tf = math_ops.cast(x_f32_tf, dtypes.bfloat16) + y_f32_tf = self.evaluate(nn_ops.softmax(x_f32_tf)) + y_bf16_tf = self.evaluate(nn_ops.softmax(x_bf16_tf)) + expected = math_ops.cast(y_f32_tf, dtypes.bfloat16) + tol = 
x_shape[1] * 1e-3 + self.assertAllClose(y_bf16_tf, expected, rtol=tol, atol=tol) + + @parameterized.parameters(((5, 10),), ((2, 3, 4),)) + @test_util.run_deprecated_v1 + def testGradient(self, x_shape): + x_np = np.random.randn(*x_shape).astype(np.float64) + with self.cached_session(): + x_tf = constant_op.constant(x_np) + y_tf = nn_ops.softmax_v2(x_tf) + err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, + x_shape) + eps = 2e-8 + self.assertLess(err, eps) + + +class LogPoissonLossTest(test_lib.TestCase): + + def _log_poisson_loss(self, x, z, compute_full_loss=False): + lpl = np.exp(x) - z * x + if compute_full_loss: + stirling_approx = z * np.log(z) - z + 0.5 * np.log(2. * np.pi * z) + lpl += np.ma.masked_array(stirling_approx, mask=(z <= 1)).filled(0.) + return lpl + + @test_util.run_in_graph_and_eager_modes + def testLogPoissonLoss(self): + x_shape = [5, 10] + x_np = np.random.randn(*x_shape).astype(np.float32) + z_np = np.random.randint(0, 5, size=x_shape).astype(np.float32) + y_np = self._log_poisson_loss(x_np, z_np, compute_full_loss=False) + y_np_stirling = self._log_poisson_loss(x_np, z_np, compute_full_loss=True) + y_tf = nn_impl.log_poisson_loss(z_np, x_np, compute_full_loss=False) + y_tf_stirling = nn_impl.log_poisson_loss(z_np, x_np, compute_full_loss=True) + y_tf_np = self.evaluate(y_tf) + y_tf_np_stirling = self.evaluate(y_tf_stirling) + eps = 1e-3 + self.assertAllClose(y_tf_np, y_np, eps) + self.assertAllClose(y_tf_np_stirling, y_np_stirling, eps) + + @test_util.run_deprecated_v1 + def testGradient(self): + x_shape = [5, 10] + x_np = np.random.randn(*x_shape).astype(np.float64) + z_np = np.random.randint(0, 5, size=x_shape).astype(np.float64) + with self.cached_session(): + x_tf = constant_op.constant(x_np) + y_tf = nn_impl.log_poisson_loss(z_np, x_tf, compute_full_loss=False) + y_tf_stirling = nn_impl.log_poisson_loss( + z_np, x_tf, compute_full_loss=True) + err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, + x_shape) + err_stirling = gradient_checker.compute_gradient_error( + x_tf, x_shape, y_tf_stirling, x_shape) + eps = 1e-6 + self.assertLess(err, eps) + self.assertLess(err_stirling, eps) + + +class LogSoftmaxTest(test_lib.TestCase, parameterized.TestCase): + + def _log_softmax(self, x): + assert len(x.shape) == 2 + m = x.max(1)[:, np.newaxis] + u = x - m + return u - np.log(np.sum(np.exp(u), 1, keepdims=True)) + + @test_util.run_in_graph_and_eager_modes + def testLogSoftmax(self): + x_shape = [5, 10] + x_np = np.random.randn(*x_shape).astype(np.float32) + y_np = self._log_softmax(x_np) + x_tf = constant_op.constant(x_np) + y_tf = nn_ops.log_softmax_v2(x_tf) + y_tf_np = self.evaluate(y_tf) + eps = 1e-3 + self.assertAllClose(y_tf_np, y_np, eps) + + def testLogSoftmaxAxes(self): + arr = np.linspace(0., 1, 12).reshape(3, 4) + x_neg_axis = nn_ops.log_softmax_v2(arr, axis=-2) + y_pos_axis = nn_ops.log_softmax_v2(arr, axis=0) + z_gt_axis = nn_ops.log_softmax_v2(arr, axis=0) + x_neg_axis_tf = self.evaluate(x_neg_axis) + y_pos_axis_tf = self.evaluate(y_pos_axis) + z_gt_axis_tf = self.evaluate(z_gt_axis) + eps = 1e-3 + self.assertAllClose(x_neg_axis_tf, y_pos_axis_tf, eps) + self.assertAllClose(y_pos_axis_tf, z_gt_axis_tf, eps) + + @parameterized.parameters(((5, 10),), ((2, 3, 4),)) + @test_util.run_deprecated_v1 + def testGradient(self, x_shape): + x_np = np.random.randn(*x_shape).astype(np.float64) + with self.cached_session(): + x_tf = constant_op.constant(x_np) + y_tf = nn_ops.log_softmax_v2(x_tf) + err = 
gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, + x_shape) + eps = 1e-7 + self.assertLess(err, eps) + + +class L2LossTest(test_lib.TestCase): + + @test_util.run_in_graph_and_eager_modes + def testL2Loss(self): + for dtype in [dtypes.float32, dtypes.float64]: + x = constant_op.constant( + [1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="x", dtype=dtype) + l2loss = nn_ops.l2_loss(x) + value = self.evaluate(l2loss) + self.assertAllClose(7.0, value) + + @test_util.run_deprecated_v1 + def testGradient(self): + x_shape = [20, 7, 3] + np.random.seed(1) # Make it reproducible. + x_val = np.random.random_sample(x_shape).astype(np.float64) + with self.cached_session(): + x = constant_op.constant(x_val, name="x") + output = nn_ops.l2_loss(x) + err = gradient_checker.compute_gradient_error(x, x_shape, output, [1]) + print("L2Loss gradient err = %g " % err) + err_tolerance = 1e-10 + self.assertLess(err, err_tolerance) class L2NormalizeTest(test_lib.TestCase): @@ -279,41 +279,41 @@ def _l2Normalize(self, x, dim): norm = np.apply_along_axis(np.linalg.norm, dim, x) return x / np.expand_dims(norm, dim) - # @test_util.run_in_graph_and_eager_modes - # def testL2Normalize(self): - # x_shape = [20, 7, 3] - # np.random.seed(1) - # x_np = np.random.random_sample(x_shape).astype(np.float32) - # for dim in range(len(x_shape)): - # y_np = self._l2Normalize(x_np, dim) - # x_tf = constant_op.constant(x_np, name="x") - # y_tf = nn_impl.l2_normalize_v2(x_tf, dim) - # self.assertAllClose(y_np, self.evaluate(y_tf)) - - # @test_util.run_in_graph_and_eager_modes - # def testL2NormalizeDimArray(self): - # x_shape = [20, 7, 3] - # np.random.seed(1) - # x_np = np.random.random_sample(x_shape).astype(np.float32) - # dim = [1, 2] - # y_np = self._l2Normalize(x_np, dim) - # x_tf = constant_op.constant(x_np, name="x") - # y_tf = nn_impl.l2_normalize_v2(x_tf, dim) - # self.assertAllClose(y_np, self.evaluate(y_tf)) - - # @test_util.run_deprecated_v1 - # def testL2NormalizeGradient(self): - # x_shape = [20, 7, 3] - # np.random.seed(1) - # x_np = np.random.random_sample(x_shape).astype(np.float64) - # for dim in range(len(x_shape)): - # with self.cached_session(): - # x_tf = constant_op.constant(x_np, name="x") - # y_tf = nn_impl.l2_normalize_v2(x_tf, dim) - # err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, - # x_shape) - # print("L2Normalize gradient err = %g " % err) - # self.assertLess(err, 1e-4) + @test_util.run_in_graph_and_eager_modes + def testL2Normalize(self): + x_shape = [20, 7, 3] + np.random.seed(1) + x_np = np.random.random_sample(x_shape).astype(np.float32) + for dim in range(len(x_shape)): + y_np = self._l2Normalize(x_np, dim) + x_tf = constant_op.constant(x_np, name="x") + y_tf = nn_impl.l2_normalize_v2(x_tf, dim) + self.assertAllClose(y_np, self.evaluate(y_tf)) + + @test_util.run_in_graph_and_eager_modes + def testL2NormalizeDimArray(self): + x_shape = [20, 7, 3] + np.random.seed(1) + x_np = np.random.random_sample(x_shape).astype(np.float32) + dim = [1, 2] + y_np = self._l2Normalize(x_np, dim) + x_tf = constant_op.constant(x_np, name="x") + y_tf = nn_impl.l2_normalize_v2(x_tf, dim) + self.assertAllClose(y_np, self.evaluate(y_tf)) + + @test_util.run_deprecated_v1 + def testL2NormalizeGradient(self): + x_shape = [20, 7, 3] + np.random.seed(1) + x_np = np.random.random_sample(x_shape).astype(np.float64) + for dim in range(len(x_shape)): + with self.cached_session(): + x_tf = constant_op.constant(x_np, name="x") + y_tf = nn_impl.l2_normalize_v2(x_tf, dim) + err = 
gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, + x_shape) + print("L2Normalize gradient err = %g " % err) + self.assertLess(err, 1e-4) @test_util.run_deprecated_v1 def testFusedL2Normalize(self): @@ -341,1365 +341,1365 @@ def testFusedL2NormalizeGradient(self): self.assertLess(err, 1e-4) -# class DropoutTest(test_lib.TestCase): - -# def testDropout(self): -# # Runs dropout with 0-1 tensor 10 times, sum the number of ones and validate -# # that it is producing approximately the right number of ones over a large -# # number of samples, based on the keep probability. -# x_dim = 40 -# y_dim = 30 -# num_iter = 10 -# for keep_prob in [0.1, 0.5, 0.8]: -# t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) -# dropout = nn_ops.dropout(t, rate=(1 - keep_prob)) -# final_count = 0 -# self.assertEqual([x_dim, y_dim], dropout.get_shape()) -# for _ in xrange(0, num_iter): -# value = self.evaluate(dropout) -# final_count += np.count_nonzero(value) -# # Verifies that there are only two values: 0 and 1/keep_prob. -# sorted_value = np.unique(np.sort(value)) -# self.assertEqual(0, sorted_value[0]) -# self.assertAllClose(1 / keep_prob, sorted_value[1]) - -# # Check that we are in the 15% error range -# expected_count = x_dim * y_dim * keep_prob * num_iter -# rel_error = math.fabs(final_count - expected_count) / expected_count -# print(rel_error) -# self.assertTrue(rel_error < 0.15) - -# def testShapedDropout(self): -# # Runs dropout with 0-1 tensor 10 times, sum the number of ones and validate -# # that it is producing approximately the right number of ones over a large -# # number of samples, based on the keep probability. This time with shaped -# # noise. -# x_dim = 40 * 30 -# y_dim = 3 -# num_iter = 10 -# for keep_prob in [0.1, 0.5, 0.8]: -# t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) -# dropout = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[x_dim, 1]) -# self.assertEqual([x_dim, y_dim], dropout.get_shape()) -# final_count = 0 -# for _ in xrange(0, num_iter): -# value = self.evaluate(dropout) -# final_count += np.count_nonzero(value) -# # Verifies that there are only two values: 0 and 1/keep_prob. -# sorted_value = np.unique(np.sort(value)) -# self.assertEqual(0, sorted_value[0]) -# self.assertAllClose(1 / keep_prob, sorted_value[1]) - -# # Check that we are in the 15% error range -# expected_count = x_dim * y_dim * keep_prob * num_iter -# rel_error = math.fabs(final_count - expected_count) / expected_count -# print(rel_error) -# self.assertTrue(rel_error < 0.15) - -# def testShapedDropoutCorrelation(self): -# # Runs a shaped dropout and tests that the correlations are correct. -# x_dim = 40 -# y_dim = 30 -# num_iter = 10 -# for keep_prob in [0.1, 0.5, 0.8]: -# t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) -# dropout = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[x_dim, 1]) -# self.assertEqual([x_dim, y_dim], dropout.get_shape()) -# for _ in xrange(0, num_iter): -# value = self.evaluate(dropout) -# # Verifies that each y column as only one type of activation. -# for i in xrange(x_dim): -# sorted_value = np.unique(np.sort(value[i, :])) -# self.assertEqual(sorted_value.size, 1) - -# @test_util.run_deprecated_v1 -# def testDropoutPlaceholderKeepProb(self): -# # Runs dropout with 0-1 tensor 10 times, sum the number of ones and validate -# # that it is producing approximately the right number of ones over a large -# # number of samples, based on the keep probability. 
-# x_dim = 40 -# y_dim = 30 -# num_iter = 10 -# for keep_prob in [0.1, 0.5, 0.8]: -# with self.cached_session(): -# t = constant_op.constant( -# 1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) -# keep_prob_placeholder = array_ops.placeholder(dtypes.float32) -# dropout = nn_ops.dropout(t, keep_prob_placeholder) -# final_count = 0 -# self.assertEqual([x_dim, y_dim], dropout.get_shape()) -# for _ in xrange(0, num_iter): -# value = dropout.eval(feed_dict={keep_prob_placeholder: keep_prob}) -# final_count += np.count_nonzero(value) -# # Verifies that there are only two values: 0 and 1/keep_prob. -# sorted_value = np.unique(np.sort(value)) -# self.assertEqual(0, sorted_value[0]) -# self.assertAllClose(1 / keep_prob, sorted_value[1]) -# # Check that we are in the 15% error range -# expected_count = x_dim * y_dim * keep_prob * num_iter -# rel_error = math.fabs(final_count - expected_count) / expected_count -# print(rel_error) -# self.assertTrue(rel_error < 0.15) - -# @test_util.run_deprecated_v1 -# def testShapedDropoutUnknownShape(self): -# x_dim = 40 -# y_dim = 30 -# keep_prob = 0.5 -# x = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) -# dropout_x = nn_ops.dropout( -# x, -# rate=(1 - keep_prob), -# noise_shape=array_ops.placeholder(dtypes.int32)) -# self.assertEqual(x.get_shape(), dropout_x.get_shape()) - -# def testPartialShapedDropout(self): -# x_dim = 40 * 30 -# y_dim = 3 -# num_iter = 10 -# for keep_prob in [0.1, 0.5, 0.8]: -# t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) -# # Set noise_shape=[None, 1] which means [x_dim, 1]. -# dropout = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[None, 1]) -# self.assertEqual([x_dim, y_dim], dropout.get_shape()) -# final_count = 0 -# for _ in xrange(0, num_iter): -# value = self.evaluate(dropout) -# final_count += np.count_nonzero(value) -# # Verifies that there are only two values: 0 and 1/keep_prob. 
-# sorted_value = np.unique(np.sort(value)) -# self.assertEqual(0, sorted_value[0]) -# self.assertAllClose(1 / keep_prob, sorted_value[1]) - -# # Check that we are in the 15% error range -# expected_count = x_dim * y_dim * keep_prob * num_iter -# rel_error = math.fabs(final_count - expected_count) / expected_count -# print(rel_error) -# self.assertTrue(rel_error < 0.15) - -# @test_util.run_deprecated_v1 -# def testInvalidKeepProb(self): -# x_dim = 40 -# y_dim = 30 -# t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) -# with self.assertRaises(ValueError): -# nn_ops.dropout(t, -1.0) -# with self.assertRaises(ValueError): -# nn_ops.dropout(t, 1.1) -# with self.assertRaises(ValueError): -# nn_ops.dropout(t, [0.0, 1.0]) -# with self.assertRaises(ValueError): -# nn_ops.dropout(t, array_ops.placeholder(dtypes.float64)) -# with self.assertRaises(ValueError): -# nn_ops.dropout(t, array_ops.placeholder(dtypes.float32, shape=[2])) - -# @test_util.run_deprecated_v1 -# def testInvalidRate(self): -# x_dim = 40 -# y_dim = 30 -# t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) -# with self.assertRaises(ValueError): -# nn_ops.dropout_v2(t, -1.0) -# with self.assertRaises(ValueError): -# nn_ops.dropout_v2(t, 1.1) -# with self.assertRaises(ValueError): -# nn_ops.dropout_v2(t, [0.0, 1.0]) - -# def testLargeRate(self): -# x_dim = 40 -# y_dim = 30 -# t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) -# _ = nn_ops.dropout_v2(t, 0.9) - -# @test_util.run_deprecated_v1 -# def testShapedDropoutShapeError(self): -# # Runs shaped dropout and verifies an error is thrown on misshapen noise. -# x_dim = 40 -# y_dim = 30 -# keep_prob = 0.5 -# t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) -# with self.assertRaises(ValueError): -# _ = nn_ops.dropout( -# t, rate=(1 - keep_prob), noise_shape=[x_dim, y_dim + 10]) -# with self.assertRaises(ValueError): -# _ = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[x_dim, y_dim, 5]) -# with self.assertRaises(ValueError): -# _ = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[x_dim + 3]) -# with self.assertRaises(ValueError): -# _ = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[x_dim]) -# # test that broadcasting proceeds -# _ = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[y_dim]) -# _ = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[1, y_dim]) -# _ = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[x_dim, 1]) -# _ = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[1, 1]) - -# def testNoDropoutFast(self): -# x = array_ops.zeros((5,)) -# y = nn_ops.dropout(x, rate=0) -# self.assertTrue(x is y) - -# y = nn_ops.dropout_v2(x, rate=0) -# self.assertTrue(x is y) - -# def testDropoutWithIntegerInputs(self): -# x = constant_op.constant([1, 1, 1, 1, 1]) -# with self.assertRaises(ValueError): -# _ = nn_ops.dropout(x, 0.5) - - -# class ComputeSampledLogitsTest(test_lib.TestCase): - -# def setUp(self): -# self._eps = 1e-3 - -# def _GenerateTestData(self, num_classes, dim, batch_size, num_true, labels, -# sampled, subtract_log_q): -# """Randomly generates input/output data for a single test case. - -# This function returns numpy constants for use in a test case. - -# Args: -# num_classes: An int. The number of embedding classes in the test case. -# dim: An int. The dimension of the embedding. -# batch_size: An int. The batch size. -# num_true: An int. The number of target classes per training example. -# labels: A list of batch_size * num_true ints. 
The target classes. -# sampled: A list of indices in [0, num_classes). -# subtract_log_q: A bool corresponding to the parameter in -# _compute_sampled_logits(). - -# Returns: -# weights: Embedding weights to use as test input. It is a numpy array -# of shape [num_classes, dim] -# biases: Embedding biases to use as test input. It is a numpy array -# of shape [num_classes]. -# hidden_acts: Forward activations of the network to use as test input. -# It is a numpy array of shape [batch_size, dim]. -# sampled_vals: A tuple based on `sampled` to use as test input in the -# format returned by a *_candidate_sampler function. -# exp_logits: The output logits expected from _compute_sampled_logits(). -# It is a numpy array of shape [batch_size, num_true + len(sampled)]. -# exp_labels: The output labels expected from _compute_sampled_logits(). -# It is a numpy array of shape [batch_size, num_true + len(sampled)]. -# """ -# weights = np.random.randn(num_classes, dim).astype(np.float32) -# biases = np.random.randn(num_classes).astype(np.float32) -# hidden_acts = np.random.randn(batch_size, dim).astype(np.float32) - -# true_exp = np.full([batch_size, 1], fill_value=0.5, dtype=np.float32) -# sampled_exp = np.full([len(sampled)], fill_value=0.5, dtype=np.float32) -# sampled_vals = (sampled, true_exp, sampled_exp) - -# sampled_w, sampled_b = weights[sampled], biases[sampled] -# true_w, true_b = weights[labels], biases[labels] - -# true_logits = np.sum( -# hidden_acts.reshape((batch_size, 1, dim)) * true_w.reshape( -# (batch_size, num_true, dim)), -# axis=2) -# true_b = true_b.reshape((batch_size, num_true)) -# true_logits += true_b -# sampled_logits = np.dot(hidden_acts, sampled_w.T) + sampled_b - -# if subtract_log_q: -# true_logits -= np.log(true_exp) -# sampled_logits -= np.log(sampled_exp[np.newaxis, :]) - -# exp_logits = np.concatenate([true_logits, sampled_logits], axis=1) -# exp_labels = np.hstack((np.ones_like(true_logits) / num_true, -# np.zeros_like(sampled_logits))) - -# return weights, biases, hidden_acts, sampled_vals, exp_logits, exp_labels - -# def _ShardTestEmbeddings(self, weights, biases, num_shards): -# """Shards the weights and biases returned by _GenerateTestData. - -# Args: -# weights: The weights returned by _GenerateTestData. -# biases: The biases returned by _GenerateTestData. -# num_shards: The number of shards to create. - -# Returns: -# sharded_weights: A list of size `num_shards` containing all the weights. -# sharded_biases: A list of size `num_shards` containing all the biases. 
-# """ -# with ops.Graph().as_default() as g: -# sharded_weights = variable_scope.get_variable( -# "w", -# partitioner=partitioned_variables.fixed_size_partitioner(num_shards), -# initializer=constant_op.constant(weights)) -# sharded_biases = variable_scope.get_variable( -# "b", -# partitioner=partitioned_variables.fixed_size_partitioner(num_shards), -# initializer=constant_op.constant(biases)) -# with self.session(graph=g) as sess: -# variables.global_variables_initializer().run() -# return self.evaluate([list(sharded_weights), list(sharded_biases)]) - -# def testShapes(self): -# np.random.seed(0) -# num_classes = 5 -# batch_size = 3 - -# for num_true in range(1, 5): -# labels = np.random.randint( -# low=0, high=num_classes, size=batch_size * num_true) -# (weights, biases, hidden_acts, sampled_vals, exp_logits, -# exp_labels) = self._GenerateTestData( -# num_classes=num_classes, -# dim=10, -# batch_size=batch_size, -# num_true=num_true, -# labels=labels, -# sampled=[1, 0, 2, 3], -# subtract_log_q=False) -# logits_tensor, labels_tensor = _compute_sampled_logits( -# weights=constant_op.constant(weights), -# biases=constant_op.constant(biases), -# labels=constant_op.constant( -# labels, dtype=dtypes.int64, shape=(batch_size, num_true)), -# inputs=constant_op.constant(hidden_acts), -# num_sampled=4, -# num_classes=num_classes, -# num_true=num_true, -# sampled_values=sampled_vals, -# subtract_log_q=False, -# remove_accidental_hits=False, -# partition_strategy="div", -# name="sampled_logits_basic_num_true_%d" % num_true) -# got_logits, got_labels = self.evaluate([logits_tensor, labels_tensor]) -# self.assertEqual(exp_logits.shape, got_logits.shape, self._eps) -# self.assertEqual(exp_labels.shape, got_labels.shape, self._eps) - -# def testBasic(self): -# """Without accidental hit removal or subtract_log_q.""" -# np.random.seed(0) -# num_classes = 5 -# batch_size = 3 - -# for num_true in range(1, 5): -# labels = np.random.randint( -# low=0, high=num_classes, size=batch_size * num_true) -# (weights, biases, hidden_acts, sampled_vals, exp_logits, -# exp_labels) = self._GenerateTestData( -# num_classes=num_classes, -# dim=10, -# batch_size=batch_size, -# num_true=num_true, -# labels=labels, -# sampled=[1, 0, 2, 3], -# subtract_log_q=False) -# logits_tensor, labels_tensor = _compute_sampled_logits( -# weights=constant_op.constant(weights), -# biases=constant_op.constant(biases), -# labels=constant_op.constant( -# labels, dtype=dtypes.int64, shape=(batch_size, num_true)), -# inputs=constant_op.constant(hidden_acts), -# num_sampled=4, -# num_classes=num_classes, -# num_true=num_true, -# sampled_values=sampled_vals, -# subtract_log_q=False, -# remove_accidental_hits=False, -# partition_strategy="div", -# name="sampled_logits_basic_num_true_%d" % num_true) -# got_logits, got_labels = self.evaluate([logits_tensor, labels_tensor]) -# self.assertAllClose(exp_logits, got_logits, self._eps) -# self.assertAllClose(exp_labels, got_labels, self._eps) - -# def testAccidentalHitRemoval(self): -# """With accidental hit removal, no subtract_log_q.""" -# np.random.seed(0) -# num_classes = 5 -# batch_size = 3 -# sampled = [1, 0, 2, 3] - -# for num_true in range(1, 5): -# labels = np.random.randint( -# low=0, high=num_classes, size=batch_size * num_true) -# (weights, biases, hidden_acts, sampled_vals, _, -# _) = self._GenerateTestData( -# num_classes=num_classes, -# dim=10, -# batch_size=batch_size, -# num_true=num_true, -# labels=labels, -# sampled=sampled, -# subtract_log_q=False) -# logits_tensor, _ = 
_compute_sampled_logits( -# weights=constant_op.constant(weights), -# biases=constant_op.constant(biases), -# labels=constant_op.constant( -# labels, dtype=dtypes.int64, shape=(batch_size, num_true)), -# inputs=constant_op.constant(hidden_acts), -# num_sampled=len(sampled), -# num_classes=num_classes, -# num_true=num_true, -# sampled_values=sampled_vals, -# subtract_log_q=False, -# remove_accidental_hits=True, -# partition_strategy="div", -# name="sampled_logits_accidental_hit_removal_num_true_%d" % num_true) -# # Test that the exponentiated logits of accidental hits are near 0. -# # First we need to find the hits in this random test run: -# labels_reshape = labels.reshape((batch_size, num_true)) -# got_logits = self.evaluate(logits_tensor) -# for row in xrange(batch_size): -# row_labels = labels_reshape[row, :] -# for col in xrange(len(sampled)): -# if sampled[col] in row_labels: -# # We need to add the num_true_test offset into logits_* -# self.assertNear( -# np.exp(got_logits[row, col + num_true]), 0., self._eps) - -# def testSubtractLogQ(self): -# """With subtract_log_q, no accidental hit removal.""" -# np.random.seed(0) -# num_classes = 5 -# batch_size = 3 - -# for num_true in range(1, 5): -# labels = np.random.randint( -# low=0, high=num_classes, size=batch_size * num_true) -# (weights, biases, hidden_acts, sampled_vals, exp_logits, -# exp_labels) = self._GenerateTestData( -# num_classes=num_classes, -# dim=10, -# batch_size=batch_size, -# num_true=num_true, -# labels=labels, -# sampled=[1, 0, 2, 3], -# subtract_log_q=True) -# logits_tensor, labels_tensor = _compute_sampled_logits( -# weights=constant_op.constant(weights), -# biases=constant_op.constant(biases), -# labels=constant_op.constant( -# labels, dtype=dtypes.int64, shape=(batch_size, num_true)), -# inputs=constant_op.constant(hidden_acts), -# num_sampled=4, -# num_classes=num_classes, -# num_true=num_true, -# sampled_values=sampled_vals, -# subtract_log_q=True, -# remove_accidental_hits=False, -# partition_strategy="div", -# name="sampled_logits_subtract_log_q_num_true_%d" % num_true) -# got_logits, got_labels = self.evaluate([logits_tensor, labels_tensor]) -# self.assertAllClose(exp_logits, got_logits, self._eps) -# self.assertAllClose(exp_labels, got_labels, self._eps) - -# def testSharded(self): -# """With sharded weights and sharded biases.""" -# np.random.seed(0) -# num_classes = 5 -# batch_size = 3 - -# for num_true in range(1, 5): -# labels = np.random.randint( -# low=0, high=num_classes, size=batch_size * num_true) -# (weights, biases, hidden_acts, sampled_vals, exp_logits, -# exp_labels) = self._GenerateTestData( -# num_classes=num_classes, -# dim=10, -# batch_size=batch_size, -# num_true=num_true, -# labels=labels, -# sampled=[1, 0, 2, 3], -# subtract_log_q=False) -# weight_shards, bias_shards = self._ShardTestEmbeddings( -# weights, biases, num_shards=3) -# logits_tensor, labels_tensor = _compute_sampled_logits( -# weights=[constant_op.constant(shard) for shard in weight_shards], -# biases=[constant_op.constant(shard) for shard in bias_shards], -# labels=constant_op.constant( -# labels, dtype=dtypes.int64, shape=(batch_size, num_true)), -# inputs=constant_op.constant(hidden_acts), -# num_sampled=4, -# num_classes=num_classes, -# num_true=num_true, -# sampled_values=sampled_vals, -# subtract_log_q=False, -# remove_accidental_hits=False, -# partition_strategy="div", -# name="sampled_logits_sharded_num_true_%d" % num_true) -# got_logits, got_labels = self.evaluate([logits_tensor, labels_tensor]) -# 
self.assertAllClose(exp_logits, got_logits, self._eps) -# self.assertAllClose(exp_labels, got_labels, self._eps) - -# def testNCELoss(self): -# # A simple test to verify the numerics. - -# def _SigmoidCrossEntropyWithLogits(logits, targets): -# # logits, targets: float arrays of the same shape. -# assert logits.shape == targets.shape -# pred = 1. / (1. + np.exp(-logits)) -# eps = 0.0001 -# pred = np.minimum(np.maximum(pred, eps), 1 - eps) -# return -targets * np.log(pred) - (1. - targets) * np.log(1. - pred) - -# np.random.seed(0) -# num_classes = 5 -# batch_size = 3 -# labels = [0, 1, 2] -# (weights, biases, hidden_acts, sampled_vals, exp_logits, -# exp_labels) = self._GenerateTestData( -# num_classes=num_classes, -# dim=10, -# batch_size=batch_size, -# num_true=1, -# labels=labels, -# sampled=[1, 0, 2, 3], -# subtract_log_q=True) -# exp_nce_loss = np.sum( -# _SigmoidCrossEntropyWithLogits(exp_logits, exp_labels), 1) - -# got_nce_loss = nn_impl.nce_loss_v2( -# weights=constant_op.constant(weights), -# biases=constant_op.constant(biases), -# labels=constant_op.constant(labels, shape=(batch_size, 1)), -# inputs=constant_op.constant(hidden_acts), -# num_sampled=4, -# num_classes=num_classes, -# num_true=1, -# sampled_values=sampled_vals) - -# self.assertAllClose(exp_nce_loss, self.evaluate(got_nce_loss), 1e-4) - -# # Test with sharded weights and sharded biases. -# weight_shards, bias_shards = self._ShardTestEmbeddings( -# weights, biases, num_shards=3) -# got_nce_loss = nn_impl.nce_loss_v2( -# weights=[constant_op.constant(shard) for shard in weight_shards], -# biases=[constant_op.constant(shard) for shard in bias_shards], -# labels=constant_op.constant(labels, shape=(batch_size, 1)), -# inputs=constant_op.constant(hidden_acts), -# num_sampled=4, -# num_classes=num_classes, -# num_true=1, -# sampled_values=sampled_vals) - -# self.assertAllClose(exp_nce_loss, self.evaluate(got_nce_loss), 1e-4) - -# def testSampledSoftmaxLoss(self): -# # A simple test to verify the numerics. - -# def _SoftmaxCrossEntropyWithLogits(logits, targets): -# # logits, targets: float arrays of the same shape. -# assert logits.shape == targets.shape -# stable_exp_logits = np.exp( -# logits - np.amax(logits, axis=1, keepdims=True)) -# pred = stable_exp_logits / np.sum(stable_exp_logits, 1, keepdims=True) -# return -np.sum(targets * np.log(pred + 1.0e-20), axis=1) - -# np.random.seed(0) -# num_classes = 5 -# batch_size = 3 -# labels = [0, 1, 2] -# (weights, biases, hidden_acts, sampled_vals, exp_logits, -# exp_labels) = self._GenerateTestData( -# num_classes=num_classes, -# dim=10, -# batch_size=batch_size, -# num_true=1, -# labels=labels, -# sampled=[1, 0, 2, 3], -# subtract_log_q=True) -# exp_sampled_softmax_loss = _SoftmaxCrossEntropyWithLogits( -# exp_logits, exp_labels) - -# got_sampled_softmax_loss = nn_impl.sampled_softmax_loss_v2( -# weights=constant_op.constant(weights), -# biases=constant_op.constant(biases), -# labels=constant_op.constant(labels, shape=(batch_size, 1)), -# inputs=constant_op.constant(hidden_acts), -# num_sampled=4, -# num_classes=num_classes, -# num_true=1, -# sampled_values=sampled_vals, -# remove_accidental_hits=False) - -# self.assertAllClose(exp_sampled_softmax_loss, -# self.evaluate(got_sampled_softmax_loss), 1e-4) - -# # Test with sharded weights and sharded biases. 
-# weight_shards, bias_shards = self._ShardTestEmbeddings( -# weights, biases, num_shards=3) -# got_sampled_softmax_loss = nn_impl.sampled_softmax_loss_v2( -# weights=[constant_op.constant(shard) for shard in weight_shards], -# biases=[constant_op.constant(shard) for shard in bias_shards], -# labels=constant_op.constant(labels, shape=(batch_size, 1)), -# inputs=constant_op.constant(hidden_acts), -# num_sampled=4, -# num_classes=num_classes, -# num_true=1, -# sampled_values=sampled_vals, -# remove_accidental_hits=False) - -# self.assertAllClose(exp_sampled_softmax_loss, -# self.evaluate(got_sampled_softmax_loss), 1e-4) - -# def testSampledSoftmaxLossBf16(self): -# # A simple test to verify the numerics for bfloat16. -# def _SoftmaxCrossEntropyWithLogits(logits, targets): -# # logits, targets: float arrays of the same shape. -# assert logits.shape == targets.shape -# stable_exp_logits = np.exp( -# logits - np.amax(logits, axis=1, keepdims=True)) -# pred = stable_exp_logits / np.sum(stable_exp_logits, 1, keepdims=True) -# return -np.sum(targets * np.log(pred + 1.0e-20), axis=1) - -# np.random.seed(0) -# num_classes = 5 -# batch_size = 3 -# labels = [0, 1, 2] -# sampled = [1, 0, 2, 3] -# (weights, biases, hidden_acts, _, exp_logits, -# exp_labels) = self._GenerateTestData( -# num_classes=num_classes, -# dim=10, -# batch_size=batch_size, -# num_true=1, -# labels=labels, -# sampled=sampled, -# subtract_log_q=True) -# exp_sampled_softmax_loss = _SoftmaxCrossEntropyWithLogits( -# exp_logits, exp_labels) - -# true_exp_bf16 = np.full([batch_size, 1], -# fill_value=0.5, -# dtype=dtypes.bfloat16.as_numpy_dtype) -# sampled_exp_bf16 = np.full([len(sampled)], -# fill_value=0.5, -# dtype=dtypes.bfloat16.as_numpy_dtype) -# sampled_vals_bf16 = (sampled, true_exp_bf16, sampled_exp_bf16) - -# got_sampled_softmax_loss = math_ops.cast( -# nn_impl.sampled_softmax_loss_v2( -# weights=constant_op.constant(weights, dtype=dtypes.bfloat16), -# biases=constant_op.constant(biases, dtype=dtypes.bfloat16), -# labels=constant_op.constant( -# labels, shape=(batch_size, 1), dtype=dtypes.bfloat16), -# inputs=constant_op.constant(hidden_acts, dtype=dtypes.bfloat16), -# num_sampled=4, -# num_classes=num_classes, -# num_true=1, -# sampled_values=sampled_vals_bf16, -# remove_accidental_hits=False), dtypes.float32) - -# self.assertAllClose(exp_sampled_softmax_loss, -# self.evaluate(got_sampled_softmax_loss), 1e-1) - - -# class GeluTest(test_lib.TestCase): - -# def test(self): - -# def gelu(x, approximate=False): -# if approximate: -# return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * -# (x + 0.044715 * np.power(x, 3)))) -# else: -# from scipy.stats import norm # pylint: disable=g-import-not-at-top -# return x * norm.cdf(x) - -# np.random.seed(1) # Make it reproducible. -# x = np.random.randn(3, 4).astype(np.float32) -# y = gelu(x) -# z = self.evaluate(nn_ops.gelu(constant_op.constant(x))) -# self.assertAllClose(y, z) - -# y = gelu(x, True) -# z = self.evaluate(nn_ops.gelu(constant_op.constant(x), True)) -# self.assertAllClose(y, z) - - -# class CReluTest(test_lib.TestCase): - -# def test(self): -# np.random.seed(1) # Make it reproducible. -# x = np.random.randn(3, 4).astype(np.float32) -# y = np.concatenate([x * (x > 0), -x * (x < 0)], axis=1) - -# z = self.evaluate(nn_ops.crelu(constant_op.constant(x))) -# self.assertAllClose(y, z, 1e-4) - - -# class ReluTest(test_lib.TestCase): - -# def test(self): -# np.random.seed(1) # Make it reproducible. 
-# x = np.random.randn(3, 4).astype(np.float32) -# y = np.maximum(x, 0.0) - -# z = self.evaluate(nn_ops.relu(constant_op.constant(x))) -# self.assertAllEqual(y, z) - -# @test_util.run_deprecated_v1 -# def testNaNs(self): -# # Test that relu(nan) = nan for various sizes. -# for i in range(18): -# x = np.zeros(i) + np.nan -# with self.cached_session(): -# z = nn_ops.relu(constant_op.constant(x)).eval() -# self.assertTrue(np.isnan(z).all()) - - -# class LeakyReluTest(test_lib.TestCase): - -# def testRange(self): -# batch_size = 3 -# height, width = 4, 4 -# np.random.seed(1) # Make it reproducible. -# inputs = np.random.uniform(size=(batch_size, height, width, 3)).astype( -# np.float32) -# inputs = constant_op.constant(inputs) - -# outputs = nn_ops.leaky_relu(inputs) -# self.assertEquals(inputs.shape, outputs.shape) - -# inputs, outputs = self.evaluate([inputs, outputs]) - -# self.assertGreaterEqual(outputs.min(), 0.0) -# self.assertLessEqual(outputs.max(), 1.0) -# self.assertAllClose(inputs, outputs) - -# @test_util.run_deprecated_v1 -# def testValues(self): -# for dtype in [np.int32, np.int64, np.float16, np.float32, np.float64]: -# np_values = np.array([-2, -1, 0, 1, 2], dtype=dtype) -# outputs = nn_ops.leaky_relu(constant_op.constant(np_values)) - -# outputs = self.evaluate(outputs) - -# tol = 2e-3 if dtype == np.float16 else 1e-6 -# self.assertAllClose( -# outputs, [-0.4, -0.2, 0.0, 1.0, 2.0], rtol=tol, atol=tol) - -# @test_util.run_deprecated_v1 -# def testName(self): -# np_values = np.array([-2, -1, 0, 1, 2], dtype=np.float64) -# outputs_with_name_set = nn_ops.leaky_relu( -# constant_op.constant(np_values), -# name='test_relu_op') -# self.assertEqual(outputs_with_name_set.name, 'test_relu_op:0') -# outputs_without_name_set = nn_ops.leaky_relu( -# constant_op.constant(np_values)) -# self.assertEqual(outputs_without_name_set.name, 'LeakyRelu:0') - - -# class SwishTest(test_lib.TestCase): - -# @test_util.run_deprecated_v1 -# def testValues(self): -# np_values = np.array( -# [np.linspace(-7.0, 0.0, 100), -# np.linspace(0.0, 7.0, 100)], -# dtype=np.float32) -# tf_values = constant_op.constant(np_values) -# actual_tf_outputs = nn_impl.swish(tf_values) -# expected_tf_outputs = tf_values * math_ops.sigmoid(tf_values) - -# actual_outputs, expected_outputs = self.evaluate( -# [actual_tf_outputs, expected_tf_outputs]) - -# self.assertAllClose(actual_outputs, expected_outputs) - -# @test_util.run_deprecated_v1 -# def testGradients(self): -# shape = [5, 3, 4] -# sigma = 5 -# input_values = np.random.randn(*shape) * sigma -# x_tf = constant_op.constant(input_values) -# y_tf = nn_impl.swish(x_tf) -# with self.cached_session(): -# err = gradient_checker.compute_gradient_error(x_tf, shape, y_tf, shape) -# self.assertLess(err, 1e-4) - - -# class MomentsTest(test_lib.TestCase): - -# def doOutputTest(self, -# input_shape, -# moments_axes, -# tol=1e-4, -# check_gradients=False): -# for mu in [0.0, 1.0, 1e3]: -# for sigma in [1.0, 0.1]: -# for keep_dims in [True, False]: -# input_values = np.random.rand(*input_shape) * sigma + mu -# expected_mean = np.mean( -# input_values, axis=moments_axes, keepdims=keep_dims) -# expected_var = np.var( -# input_values, axis=moments_axes, keepdims=keep_dims) -# with ops.Graph().as_default() as g: -# with self.session(graph=g) as sess: -# inputs = constant_op.constant( -# input_values, shape=input_shape, dtype=dtypes.float32) -# mean, variance = nn_impl.moments_v2( -# inputs, moments_axes, keepdims=keep_dims) - -# if check_gradients: -# err = 
gradient_checker.compute_gradient_error( -# inputs, input_shape, mean, mean.shape.as_list()) -# self.assertLess(err, 1e-3) -# err = gradient_checker.compute_gradient_error( -# inputs, input_shape, variance, variance.shape.as_list()) -# self.assertLess(err, 1e-3) - -# # Evaluate. -# [mean, variance] = self.evaluate([mean, variance]) -# # Make sure that there are no NaNs -# self.assertFalse(np.isnan(mean).any()) -# self.assertFalse(np.isnan(variance).any()) -# self.assertAllClose(mean, expected_mean, rtol=tol, atol=tol) -# self.assertAllClose(variance, expected_var, rtol=tol, atol=tol) - -# def testOutputAndGradient2DInput0(self): -# self.doOutputTest((10, 10), (0,), check_gradients=True) - -# def testOutputAndGradient2DInput01(self): -# self.doOutputTest((10, 10), (0, 1), check_gradients=True) - -# def testOutput2DInput0(self): -# self.doOutputTest((10, 300), (0,)) - -# def testOutput2DInput1(self): -# self.doOutputTest((10, 300), (1,)) - -# def testOutput2DInput01(self): -# self.doOutputTest((10, 300), (0, 1)) - -# def testOutput4DInput0(self): -# self.doOutputTest((10, 10, 10, 30), (0,)) - -# def testOutput4DInput1(self): -# self.doOutputTest((10, 10, 10, 30), (1,)) - -# def testOutput4DInput3(self): -# self.doOutputTest((10, 10, 10, 30), (3,)) - -# def testOutput4DInput012(self): -# self.doOutputTest((10, 10, 10, 30), (0, 1, 2)) - -# def testOutput4DInput123(self): -# self.doOutputTest((10, 10, 10, 30), (1, 2, 3)) - - -# class DataFormatDimMapTest(test_lib.TestCase): - -# def _test(self, x_val, y_val_expected): -# x = constant_op.constant(x_val) -# y = nn_ops.data_format_dim_map(x) - -# y_val = self.evaluate(y) -# self.assertAllEqual(y_val, y_val_expected) - -# def test(self): -# self._test(0, 0) -# self._test(1, 2) -# self._test(2, 3) -# self._test(3, 1) -# self._test(-1, 1) -# self._test(-2, 3) -# self._test(-3, 2) -# self._test(-4, 0) -# self._test([1, 3], [2, 1]) -# self._test([1, 3, -2], [2, 1, 3]) -# self._test([1, -3, -2], [2, 2, 3]) -# self._test([[1, -3], [1, -1]], [[2, 2], [2, 1]]) - -# def testNHWCtoNCHW(self): -# x_val = [1, -3, -2] -# y_val_expected = [2, 2, 3] -# x = constant_op.constant(x_val) -# y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="NCHW") -# with test_util.use_gpu(): -# y_val = self.evaluate(y) -# self.assertAllEqual(y_val, y_val_expected) - -# def testNHWCtoHWNC(self): -# x_val = [-4, -3, -2, -1, 0, 1, 2, 3] -# y_val_expected = [2, 0, 1, 3, 2, 0, 1, 3] -# x = constant_op.constant(x_val) -# y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="HWNC") -# with test_util.use_gpu(): -# y_val = self.evaluate(y) -# self.assertAllEqual(y_val, y_val_expected) - -# def testNHWCtoWHCN(self): -# x_val = [-4, -3, -2, -1, 0, 1, 2, 3] -# y_val_expected = [3, 1, 0, 2, 3, 1, 0, 2] -# x = constant_op.constant(x_val) -# y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="WHCN") -# with test_util.use_gpu(): -# y_val = self.evaluate(y) -# self.assertAllEqual(y_val, y_val_expected) - -# def testNDHWCtoNCDHW(self): -# x_val = [1, -4, -3, -2] -# y_val_expected = [2, 2, 3, 4] -# x = constant_op.constant(x_val) -# y = nn_ops.data_format_dim_map(x, src_format="NDHWC", dst_format="NCDHW") -# with test_util.use_gpu(): -# y_val = self.evaluate(y) -# self.assertAllEqual(y_val, y_val_expected) - -# def testNDHWCtoDHWNC(self): -# x_val = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4] -# y_val_expected = [3, 0, 1, 2, 4, 3, 0, 1, 2, 4] -# x = constant_op.constant(x_val) -# y = nn_ops.data_format_dim_map(x, src_format="NDHWC", dst_format="DHWNC") -# with 
test_util.use_gpu(): -# y_val = self.evaluate(y) -# self.assertAllEqual(y_val, y_val_expected) - -# def testDNHWCtoWHDCN(self): -# x_val = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4] -# y_val_expected = [4, 2, 1, 0, 3, 4, 2, 1, 0, 3] -# x = constant_op.constant(x_val) -# y = nn_ops.data_format_dim_map(x, src_format="NDHWC", dst_format="WHDCN") -# with test_util.use_gpu(): -# y_val = self.evaluate(y) -# self.assertAllEqual(y_val, y_val_expected) - -# def testArbitraryASCII(self): -# x_val = [-4, -3, -2, -1, 0, 1, 2, 3] -# y_val_expected = [3, 2, 1, 0, 3, 2, 1, 0] -# x = constant_op.constant(x_val) -# y = nn_ops.data_format_dim_map(x, src_format="qwer", dst_format="rewq") -# with test_util.use_gpu(): -# y_val = self.evaluate(y) -# self.assertAllEqual(y_val, y_val_expected) - -# @test_util.disable_xla("XLA catches the error and rethrows as different one") -# def testInvalidLength(self): -# x = [-4, -3, -2, -1, 0, 1, 2, 3] -# with self.assertRaisesRegex(errors.InvalidArgumentError, -# "Source format must be of length 4 or 5"): -# op = nn_ops.data_format_dim_map( -# x, src_format="12345678", dst_format="87654321") -# with test_util.use_gpu(): -# self.evaluate(op) - -# @test_util.disable_xla("XLA catches the error and rethrows as different one") -# def testDuplicateSrc(self): -# x = [-4, -3, -2, -1, 0, 1, 2, 3] -# with self.assertRaisesRegex( -# errors.InvalidArgumentError, -# "Destination and source format must determine a permutation"): -# op = nn_ops.data_format_dim_map(x, src_format="1233", dst_format="4321") -# with test_util.use_gpu(): -# self.evaluate(op) - -# @test_util.disable_xla("XLA catches the error and rethrows as different one") -# def testDuplicateDst(self): -# x = [-4, -3, -2, -1, 0, 1, 2, 3] -# with self.assertRaisesRegex( -# errors.InvalidArgumentError, -# "Destination and source format must determine a permutation"): -# op = nn_ops.data_format_dim_map(x, src_format="1234", dst_format="3321") -# with test_util.use_gpu(): -# self.evaluate(op) - -# @test_util.disable_xla("XLA catches the error and rethrows as different one") -# def testExtraSpecifiers(self): -# x = [-4, -3, -2, -1, 0, 1, 2, 3] -# with self.assertRaisesRegex( -# errors.InvalidArgumentError, -# "Destination and source format must determine a permutation"): -# op = nn_ops.data_format_dim_map(x, src_format="1234", dst_format="5321") -# with test_util.use_gpu(): -# self.evaluate(op) - - -# class DataFormatVectorPermuteTest(test_lib.TestCase): - -# def testNHWCToNCHW(self): -# x_val = [7, 4, 9, 3] -# x = constant_op.constant(x_val) -# y = nn_ops.data_format_vec_permute(x) -# with test_util.use_gpu(): -# y_val = self.evaluate(y) -# self.assertAllEqual(y_val, [7, 3, 4, 9]) - -# def testNCHWToNHWC(self): -# x_val = [7, 4, 9, 3] -# x = constant_op.constant(x_val) -# y = nn_ops.data_format_vec_permute(x, src_format="NCHW", dst_format="NHWC") -# with test_util.use_gpu(): -# y_val = self.evaluate(y) -# self.assertAllEqual(y_val, [7, 9, 3, 4]) - -# def testNHWCToHWNC(self): -# x_val = [7, 4, 9, 3] -# x = constant_op.constant(x_val) -# y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="HWNC") -# with test_util.use_gpu(): -# y_val = self.evaluate(y) -# self.assertAllEqual(y_val, [4, 9, 7, 3]) - -# def testHWNCToNHWC(self): -# x_val = [7, 4, 9, 3] -# x = constant_op.constant(x_val) -# y = nn_ops.data_format_vec_permute(x, src_format="HWNC", dst_format="NHWC") -# with test_util.use_gpu(): -# y_val = self.evaluate(y) -# self.assertAllEqual(y_val, [9, 7, 4, 3]) - -# def testNHWCToNCHW2D(self): -# x_val = [[7, 4], [9, 3], 
[4, 5], [5, 1]] -# x = constant_op.constant(x_val) -# y = nn_ops.data_format_vec_permute(x) -# with test_util.use_gpu(): -# y_val = self.evaluate(y) -# self.assertAllEqual(y_val, [[7, 4], [5, 1], [9, 3], [4, 5]]) - -# def testNHWCToHWNC2D(self): -# x_val = [[7, 4], [9, 3], [4, 5], [5, 1]] -# x = constant_op.constant(x_val) -# y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="HWNC") -# with test_util.use_gpu(): -# y_val = self.evaluate(y) -# self.assertAllEqual(y_val, [[9, 3], [4, 5], [7, 4], [5, 1]]) - -# def testHWNCToNHWC2D(self): -# x_val = [[7, 4], [9, 3], [4, 5], [5, 1]] -# x = constant_op.constant(x_val) -# y = nn_ops.data_format_vec_permute(x, src_format="HWNC", dst_format="NHWC") -# with test_util.use_gpu(): -# y_val = self.evaluate(y) -# self.assertAllEqual(y_val, [[4, 5], [7, 4], [9, 3], [5, 1]]) - -# def testNCHWToNHWC2D(self): -# x_val = [[7, 4], [9, 3], [4, 5], [5, 1]] -# x = constant_op.constant(x_val) -# y = nn_ops.data_format_vec_permute(x, src_format="NCHW", dst_format="NHWC") -# with test_util.use_gpu(): -# y_val = self.evaluate(y) -# self.assertAllEqual(y_val, [[7, 4], [4, 5], [5, 1], [9, 3]]) - -# @test_util.disable_xla("XLA catches the error and rethrows as different one") -# def testInvalidLength(self): -# x = [0, 1, 2, 3] -# with self.assertRaisesRegex(errors.InvalidArgumentError, -# "Source format must be of length 4 or 5"): -# op = nn_ops.data_format_vec_permute( -# x, src_format="12345678", dst_format="87654321") -# with test_util.use_gpu(): -# self.evaluate(op) - -# @test_util.disable_xla("XLA catches the error and rethrows as different one") -# def testDuplicateSrc(self): -# x = [0, 1, 2, 3] -# with self.assertRaisesRegex( -# errors.InvalidArgumentError, -# "Destination and source format must determine a permutation"): -# op = nn_ops.data_format_vec_permute( -# x, src_format="1233", dst_format="4321") -# with test_util.use_gpu(): -# self.evaluate(op) - -# @test_util.disable_xla("XLA catches the error and rethrows as different one") -# def testDuplicateDst(self): -# x = [0, 1, 2, 3] -# with self.assertRaisesRegex( -# errors.InvalidArgumentError, -# "Destination and source format must determine a permutation"): -# op = nn_ops.data_format_vec_permute( -# x, src_format="1234", dst_format="3321") -# with test_util.use_gpu(): -# self.evaluate(op) - -# @test_util.disable_xla("XLA catches the error and rethrows as different one") -# def testExtraSpecifiers(self): -# x = [0, 1, 2, 3] -# with self.assertRaisesRegex( -# errors.InvalidArgumentError, -# "Destination and source format must determine a permutation"): -# op = nn_ops.data_format_vec_permute( -# x, src_format="1234", dst_format="5321") -# with test_util.use_gpu(): -# self.evaluate(op) - - -# @test_util.run_all_in_graph_and_eager_modes -# class AvgPoolTest(test_lib.TestCase): - -# def test1DTensor(self): -# x = array_ops.ones([3, 6, 5]) -# ksize = 2 -# strides = 2 - -# y1 = nn_ops.avg_pool_v2(x, ksize, strides, "SAME") -# y2 = nn_ops.avg_pool1d(x, ksize, strides, "SAME") - -# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) - -# def test1DNumpy(self): -# # explicilty use float32 for ROCm, as MIOpen does not yet support float64 -# # np.ones defaults to using float64 when dtype is not explicitly specified -# dtype = np.float32 if test_lib.is_built_with_rocm() else np.float64 -# x = np.ones([3, 6, 5], dtype=dtype) -# ksize = 2 -# strides = 2 - -# y1 = nn_ops.avg_pool_v2(x, ksize, strides, "SAME") -# y2 = nn_ops.avg_pool1d(x, ksize, strides, "SAME") - -# 
self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) - -# def test1DNumpyWithGolden(self): -# dtype = np.float32 if test_lib.is_built_with_rocm() else np.float64 -# x = np.array([[[3], [6], [5]], -# [[1], [0], [1]]], dtype=dtype) -# ksize = 2 -# strides = 1 -# y = nn_ops.avg_pool1d(x, ksize, strides, "SAME") -# expected_y = np.array([[[4.5], [5.5], [5.0]], -# [[0.5], [0.5], [1.0]]], dtype=dtype) -# self.assertAllEqual(self.evaluate(y), expected_y) - -# def test2DTensor(self): -# x = array_ops.ones([3, 6, 6, 5]) -# ksize = 2 -# strides = 2 - -# y1 = nn_ops.avg_pool_v2(x, ksize, strides, "SAME") -# y2 = nn_ops.avg_pool(x, ksize, strides, "SAME") - -# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) - -# def test2DNumpy(self): -# # explicilty use float32 for ROCm, as MIOpen does not yet support float64 -# # np.ones defaults to using float64 when dtype is not explicitly specified -# dtype = np.float32 if test_lib.is_built_with_rocm() else np.float64 -# x = np.ones([3, 6, 6, 5], dtype=dtype) -# ksize = 2 -# strides = 2 - -# y1 = nn_ops.avg_pool_v2(x, ksize, strides, "SAME") -# y2 = nn_ops.avg_pool(x, ksize, strides, "SAME") - -# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) - -# def test3DTensor(self): -# if test_lib.is_built_with_rocm(): -# self.skipTest("Pooling with 3D tensors is not supported in ROCm") -# x = array_ops.ones([3, 7, 6, 6, 5]) -# ksize = 2 -# strides = 2 - -# y1 = nn_ops.avg_pool_v2(x, ksize, strides, "SAME") -# y2 = nn_ops.avg_pool3d(x, ksize, strides, "SAME") - -# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) - -# def test3DNumpy(self): -# if test_lib.is_built_with_rocm(): -# self.skipTest("Pooling with 3D tensors is not supported in ROCm") -# x = np.ones([3, 7, 6, 6, 5], dtype=np.float32) -# ksize = 2 -# strides = 2 - -# y1 = nn_ops.avg_pool_v2(x, ksize, strides, "SAME") -# y2 = nn_ops.avg_pool3d(x, ksize, strides, "SAME") - -# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) - - -# @test_util.run_all_in_graph_and_eager_modes -# class MaxPoolTest(test_lib.TestCase): - -# def test1DTensor(self): -# x = array_ops.ones([3, 6, 5]) -# ksize = 2 -# strides = 2 - -# y1 = nn_ops.max_pool_v2(x, ksize, strides, "SAME") -# y2 = nn_ops.max_pool1d(x, ksize, strides, "SAME") - -# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) +class DropoutTest(test_lib.TestCase): + + def testDropout(self): + # Runs dropout with 0-1 tensor 10 times, sum the number of ones and validate + # that it is producing approximately the right number of ones over a large + # number of samples, based on the keep probability. + x_dim = 40 + y_dim = 30 + num_iter = 10 + for keep_prob in [0.1, 0.5, 0.8]: + t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) + dropout = nn_ops.dropout(t, rate=(1 - keep_prob)) + final_count = 0 + self.assertEqual([x_dim, y_dim], dropout.get_shape()) + for _ in xrange(0, num_iter): + value = self.evaluate(dropout) + final_count += np.count_nonzero(value) + # Verifies that there are only two values: 0 and 1/keep_prob. 
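+        # For example, with keep_prob = 0.8 every surviving element is scaled to
+        # 1 / 0.8 = 1.25, so a row of `value` looks roughly like
+        # [0., 1.25, 0., 1.25, ...] and np.unique collapses it to exactly [0., 1.25].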
+ sorted_value = np.unique(np.sort(value)) + self.assertEqual(0, sorted_value[0]) + self.assertAllClose(1 / keep_prob, sorted_value[1]) + + # Check that we are in the 15% error range + expected_count = x_dim * y_dim * keep_prob * num_iter + rel_error = math.fabs(final_count - expected_count) / expected_count + print(rel_error) + self.assertTrue(rel_error < 0.15) + + def testShapedDropout(self): + # Runs dropout with 0-1 tensor 10 times, sum the number of ones and validate + # that it is producing approximately the right number of ones over a large + # number of samples, based on the keep probability. This time with shaped + # noise. + x_dim = 40 * 30 + y_dim = 3 + num_iter = 10 + for keep_prob in [0.1, 0.5, 0.8]: + t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) + dropout = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[x_dim, 1]) + self.assertEqual([x_dim, y_dim], dropout.get_shape()) + final_count = 0 + for _ in xrange(0, num_iter): + value = self.evaluate(dropout) + final_count += np.count_nonzero(value) + # Verifies that there are only two values: 0 and 1/keep_prob. + sorted_value = np.unique(np.sort(value)) + self.assertEqual(0, sorted_value[0]) + self.assertAllClose(1 / keep_prob, sorted_value[1]) + + # Check that we are in the 15% error range + expected_count = x_dim * y_dim * keep_prob * num_iter + rel_error = math.fabs(final_count - expected_count) / expected_count + print(rel_error) + self.assertTrue(rel_error < 0.15) + + def testShapedDropoutCorrelation(self): + # Runs a shaped dropout and tests that the correlations are correct. + x_dim = 40 + y_dim = 30 + num_iter = 10 + for keep_prob in [0.1, 0.5, 0.8]: + t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) + dropout = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[x_dim, 1]) + self.assertEqual([x_dim, y_dim], dropout.get_shape()) + for _ in xrange(0, num_iter): + value = self.evaluate(dropout) + # Verifies that each y column as only one type of activation. + for i in xrange(x_dim): + sorted_value = np.unique(np.sort(value[i, :])) + self.assertEqual(sorted_value.size, 1) + + @test_util.run_deprecated_v1 + def testDropoutPlaceholderKeepProb(self): + # Runs dropout with 0-1 tensor 10 times, sum the number of ones and validate + # that it is producing approximately the right number of ones over a large + # number of samples, based on the keep probability. + x_dim = 40 + y_dim = 30 + num_iter = 10 + for keep_prob in [0.1, 0.5, 0.8]: + with self.cached_session(): + t = constant_op.constant( + 1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) + keep_prob_placeholder = array_ops.placeholder(dtypes.float32) + dropout = nn_ops.dropout(t, keep_prob_placeholder) + final_count = 0 + self.assertEqual([x_dim, y_dim], dropout.get_shape()) + for _ in xrange(0, num_iter): + value = dropout.eval(feed_dict={keep_prob_placeholder: keep_prob}) + final_count += np.count_nonzero(value) + # Verifies that there are only two values: 0 and 1/keep_prob. 
+ sorted_value = np.unique(np.sort(value)) + self.assertEqual(0, sorted_value[0]) + self.assertAllClose(1 / keep_prob, sorted_value[1]) + # Check that we are in the 15% error range + expected_count = x_dim * y_dim * keep_prob * num_iter + rel_error = math.fabs(final_count - expected_count) / expected_count + print(rel_error) + self.assertTrue(rel_error < 0.15) + + @test_util.run_deprecated_v1 + def testShapedDropoutUnknownShape(self): + x_dim = 40 + y_dim = 30 + keep_prob = 0.5 + x = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) + dropout_x = nn_ops.dropout( + x, + rate=(1 - keep_prob), + noise_shape=array_ops.placeholder(dtypes.int32)) + self.assertEqual(x.get_shape(), dropout_x.get_shape()) + + def testPartialShapedDropout(self): + x_dim = 40 * 30 + y_dim = 3 + num_iter = 10 + for keep_prob in [0.1, 0.5, 0.8]: + t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) + # Set noise_shape=[None, 1] which means [x_dim, 1]. + dropout = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[None, 1]) + self.assertEqual([x_dim, y_dim], dropout.get_shape()) + final_count = 0 + for _ in xrange(0, num_iter): + value = self.evaluate(dropout) + final_count += np.count_nonzero(value) + # Verifies that there are only two values: 0 and 1/keep_prob. + sorted_value = np.unique(np.sort(value)) + self.assertEqual(0, sorted_value[0]) + self.assertAllClose(1 / keep_prob, sorted_value[1]) + + # Check that we are in the 15% error range + expected_count = x_dim * y_dim * keep_prob * num_iter + rel_error = math.fabs(final_count - expected_count) / expected_count + print(rel_error) + self.assertTrue(rel_error < 0.15) + + @test_util.run_deprecated_v1 + def testInvalidKeepProb(self): + x_dim = 40 + y_dim = 30 + t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) + with self.assertRaises(ValueError): + nn_ops.dropout(t, -1.0) + with self.assertRaises(ValueError): + nn_ops.dropout(t, 1.1) + with self.assertRaises(ValueError): + nn_ops.dropout(t, [0.0, 1.0]) + with self.assertRaises(ValueError): + nn_ops.dropout(t, array_ops.placeholder(dtypes.float64)) + with self.assertRaises(ValueError): + nn_ops.dropout(t, array_ops.placeholder(dtypes.float32, shape=[2])) + + @test_util.run_deprecated_v1 + def testInvalidRate(self): + x_dim = 40 + y_dim = 30 + t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) + with self.assertRaises(ValueError): + nn_ops.dropout_v2(t, -1.0) + with self.assertRaises(ValueError): + nn_ops.dropout_v2(t, 1.1) + with self.assertRaises(ValueError): + nn_ops.dropout_v2(t, [0.0, 1.0]) + + def testLargeRate(self): + x_dim = 40 + y_dim = 30 + t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) + _ = nn_ops.dropout_v2(t, 0.9) + + @test_util.run_deprecated_v1 + def testShapedDropoutShapeError(self): + # Runs shaped dropout and verifies an error is thrown on misshapen noise. 
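+    # noise_shape is broadcast against the input shape from the trailing
+    # dimensions, so each entry has to match the corresponding input dimension
+    # or be 1.  For the [40, 30] input below, [40, 1], [1, 30] and [30] broadcast
+    # cleanly, while [40, 40], [40, 30, 5], [43] and a bare [40] do not.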
+ x_dim = 40 + y_dim = 30 + keep_prob = 0.5 + t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) + with self.assertRaises(ValueError): + _ = nn_ops.dropout( + t, rate=(1 - keep_prob), noise_shape=[x_dim, y_dim + 10]) + with self.assertRaises(ValueError): + _ = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[x_dim, y_dim, 5]) + with self.assertRaises(ValueError): + _ = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[x_dim + 3]) + with self.assertRaises(ValueError): + _ = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[x_dim]) + # test that broadcasting proceeds + _ = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[y_dim]) + _ = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[1, y_dim]) + _ = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[x_dim, 1]) + _ = nn_ops.dropout(t, rate=(1 - keep_prob), noise_shape=[1, 1]) + + def testNoDropoutFast(self): + x = array_ops.zeros((5,)) + y = nn_ops.dropout(x, rate=0) + self.assertTrue(x is y) + + y = nn_ops.dropout_v2(x, rate=0) + self.assertTrue(x is y) + + def testDropoutWithIntegerInputs(self): + x = constant_op.constant([1, 1, 1, 1, 1]) + with self.assertRaises(ValueError): + _ = nn_ops.dropout(x, 0.5) + + +class ComputeSampledLogitsTest(test_lib.TestCase): + + def setUp(self): + self._eps = 1e-3 + + def _GenerateTestData(self, num_classes, dim, batch_size, num_true, labels, + sampled, subtract_log_q): + """Randomly generates input/output data for a single test case. + + This function returns numpy constants for use in a test case. + + Args: + num_classes: An int. The number of embedding classes in the test case. + dim: An int. The dimension of the embedding. + batch_size: An int. The batch size. + num_true: An int. The number of target classes per training example. + labels: A list of batch_size * num_true ints. The target classes. + sampled: A list of indices in [0, num_classes). + subtract_log_q: A bool corresponding to the parameter in + _compute_sampled_logits(). + + Returns: + weights: Embedding weights to use as test input. It is a numpy array + of shape [num_classes, dim] + biases: Embedding biases to use as test input. It is a numpy array + of shape [num_classes]. + hidden_acts: Forward activations of the network to use as test input. + It is a numpy array of shape [batch_size, dim]. + sampled_vals: A tuple based on `sampled` to use as test input in the + format returned by a *_candidate_sampler function. + exp_logits: The output logits expected from _compute_sampled_logits(). + It is a numpy array of shape [batch_size, num_true + len(sampled)]. + exp_labels: The output labels expected from _compute_sampled_logits(). + It is a numpy array of shape [batch_size, num_true + len(sampled)]. 
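+      For example, with batch_size=3, num_true=2 and sampled=[1, 0, 2, 3], both
+      exp_logits and exp_labels come back with shape [3, 6].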
+ """ + weights = np.random.randn(num_classes, dim).astype(np.float32) + biases = np.random.randn(num_classes).astype(np.float32) + hidden_acts = np.random.randn(batch_size, dim).astype(np.float32) + + true_exp = np.full([batch_size, 1], fill_value=0.5, dtype=np.float32) + sampled_exp = np.full([len(sampled)], fill_value=0.5, dtype=np.float32) + sampled_vals = (sampled, true_exp, sampled_exp) + + sampled_w, sampled_b = weights[sampled], biases[sampled] + true_w, true_b = weights[labels], biases[labels] + + true_logits = np.sum( + hidden_acts.reshape((batch_size, 1, dim)) * true_w.reshape( + (batch_size, num_true, dim)), + axis=2) + true_b = true_b.reshape((batch_size, num_true)) + true_logits += true_b + sampled_logits = np.dot(hidden_acts, sampled_w.T) + sampled_b + + if subtract_log_q: + true_logits -= np.log(true_exp) + sampled_logits -= np.log(sampled_exp[np.newaxis, :]) + + exp_logits = np.concatenate([true_logits, sampled_logits], axis=1) + exp_labels = np.hstack((np.ones_like(true_logits) / num_true, + np.zeros_like(sampled_logits))) + + return weights, biases, hidden_acts, sampled_vals, exp_logits, exp_labels + + def _ShardTestEmbeddings(self, weights, biases, num_shards): + """Shards the weights and biases returned by _GenerateTestData. + + Args: + weights: The weights returned by _GenerateTestData. + biases: The biases returned by _GenerateTestData. + num_shards: The number of shards to create. + + Returns: + sharded_weights: A list of size `num_shards` containing all the weights. + sharded_biases: A list of size `num_shards` containing all the biases. + """ + with ops.Graph().as_default() as g: + sharded_weights = variable_scope.get_variable( + "w", + partitioner=partitioned_variables.fixed_size_partitioner(num_shards), + initializer=constant_op.constant(weights)) + sharded_biases = variable_scope.get_variable( + "b", + partitioner=partitioned_variables.fixed_size_partitioner(num_shards), + initializer=constant_op.constant(biases)) + with self.session(graph=g) as sess: + variables.global_variables_initializer().run() + return self.evaluate([list(sharded_weights), list(sharded_biases)]) + + def testShapes(self): + np.random.seed(0) + num_classes = 5 + batch_size = 3 + + for num_true in range(1, 5): + labels = np.random.randint( + low=0, high=num_classes, size=batch_size * num_true) + (weights, biases, hidden_acts, sampled_vals, exp_logits, + exp_labels) = self._GenerateTestData( + num_classes=num_classes, + dim=10, + batch_size=batch_size, + num_true=num_true, + labels=labels, + sampled=[1, 0, 2, 3], + subtract_log_q=False) + logits_tensor, labels_tensor = _compute_sampled_logits( + weights=constant_op.constant(weights), + biases=constant_op.constant(biases), + labels=constant_op.constant( + labels, dtype=dtypes.int64, shape=(batch_size, num_true)), + inputs=constant_op.constant(hidden_acts), + num_sampled=4, + num_classes=num_classes, + num_true=num_true, + sampled_values=sampled_vals, + subtract_log_q=False, + remove_accidental_hits=False, + partition_strategy="div", + name="sampled_logits_basic_num_true_%d" % num_true) + got_logits, got_labels = self.evaluate([logits_tensor, labels_tensor]) + self.assertEqual(exp_logits.shape, got_logits.shape, self._eps) + self.assertEqual(exp_labels.shape, got_labels.shape, self._eps) + + def testBasic(self): + """Without accidental hit removal or subtract_log_q.""" + np.random.seed(0) + num_classes = 5 + batch_size = 3 + + for num_true in range(1, 5): + labels = np.random.randint( + low=0, high=num_classes, size=batch_size * num_true) + 
(weights, biases, hidden_acts, sampled_vals, exp_logits, + exp_labels) = self._GenerateTestData( + num_classes=num_classes, + dim=10, + batch_size=batch_size, + num_true=num_true, + labels=labels, + sampled=[1, 0, 2, 3], + subtract_log_q=False) + logits_tensor, labels_tensor = _compute_sampled_logits( + weights=constant_op.constant(weights), + biases=constant_op.constant(biases), + labels=constant_op.constant( + labels, dtype=dtypes.int64, shape=(batch_size, num_true)), + inputs=constant_op.constant(hidden_acts), + num_sampled=4, + num_classes=num_classes, + num_true=num_true, + sampled_values=sampled_vals, + subtract_log_q=False, + remove_accidental_hits=False, + partition_strategy="div", + name="sampled_logits_basic_num_true_%d" % num_true) + got_logits, got_labels = self.evaluate([logits_tensor, labels_tensor]) + self.assertAllClose(exp_logits, got_logits, self._eps) + self.assertAllClose(exp_labels, got_labels, self._eps) + + def testAccidentalHitRemoval(self): + """With accidental hit removal, no subtract_log_q.""" + np.random.seed(0) + num_classes = 5 + batch_size = 3 + sampled = [1, 0, 2, 3] + + for num_true in range(1, 5): + labels = np.random.randint( + low=0, high=num_classes, size=batch_size * num_true) + (weights, biases, hidden_acts, sampled_vals, _, + _) = self._GenerateTestData( + num_classes=num_classes, + dim=10, + batch_size=batch_size, + num_true=num_true, + labels=labels, + sampled=sampled, + subtract_log_q=False) + logits_tensor, _ = _compute_sampled_logits( + weights=constant_op.constant(weights), + biases=constant_op.constant(biases), + labels=constant_op.constant( + labels, dtype=dtypes.int64, shape=(batch_size, num_true)), + inputs=constant_op.constant(hidden_acts), + num_sampled=len(sampled), + num_classes=num_classes, + num_true=num_true, + sampled_values=sampled_vals, + subtract_log_q=False, + remove_accidental_hits=True, + partition_strategy="div", + name="sampled_logits_accidental_hit_removal_num_true_%d" % num_true) + # Test that the exponentiated logits of accidental hits are near 0. 
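+      # With remove_accidental_hits=True the colliding sampled logits are pushed
+      # to a very large negative value, so exponentiating them underflows to ~0,
+      # which is what the assertNear check below relies on.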
+ # First we need to find the hits in this random test run: + labels_reshape = labels.reshape((batch_size, num_true)) + got_logits = self.evaluate(logits_tensor) + for row in xrange(batch_size): + row_labels = labels_reshape[row, :] + for col in xrange(len(sampled)): + if sampled[col] in row_labels: + # We need to add the num_true_test offset into logits_* + self.assertNear( + np.exp(got_logits[row, col + num_true]), 0., self._eps) + + def testSubtractLogQ(self): + """With subtract_log_q, no accidental hit removal.""" + np.random.seed(0) + num_classes = 5 + batch_size = 3 + + for num_true in range(1, 5): + labels = np.random.randint( + low=0, high=num_classes, size=batch_size * num_true) + (weights, biases, hidden_acts, sampled_vals, exp_logits, + exp_labels) = self._GenerateTestData( + num_classes=num_classes, + dim=10, + batch_size=batch_size, + num_true=num_true, + labels=labels, + sampled=[1, 0, 2, 3], + subtract_log_q=True) + logits_tensor, labels_tensor = _compute_sampled_logits( + weights=constant_op.constant(weights), + biases=constant_op.constant(biases), + labels=constant_op.constant( + labels, dtype=dtypes.int64, shape=(batch_size, num_true)), + inputs=constant_op.constant(hidden_acts), + num_sampled=4, + num_classes=num_classes, + num_true=num_true, + sampled_values=sampled_vals, + subtract_log_q=True, + remove_accidental_hits=False, + partition_strategy="div", + name="sampled_logits_subtract_log_q_num_true_%d" % num_true) + got_logits, got_labels = self.evaluate([logits_tensor, labels_tensor]) + self.assertAllClose(exp_logits, got_logits, self._eps) + self.assertAllClose(exp_labels, got_labels, self._eps) + + def testSharded(self): + """With sharded weights and sharded biases.""" + np.random.seed(0) + num_classes = 5 + batch_size = 3 + + for num_true in range(1, 5): + labels = np.random.randint( + low=0, high=num_classes, size=batch_size * num_true) + (weights, biases, hidden_acts, sampled_vals, exp_logits, + exp_labels) = self._GenerateTestData( + num_classes=num_classes, + dim=10, + batch_size=batch_size, + num_true=num_true, + labels=labels, + sampled=[1, 0, 2, 3], + subtract_log_q=False) + weight_shards, bias_shards = self._ShardTestEmbeddings( + weights, biases, num_shards=3) + logits_tensor, labels_tensor = _compute_sampled_logits( + weights=[constant_op.constant(shard) for shard in weight_shards], + biases=[constant_op.constant(shard) for shard in bias_shards], + labels=constant_op.constant( + labels, dtype=dtypes.int64, shape=(batch_size, num_true)), + inputs=constant_op.constant(hidden_acts), + num_sampled=4, + num_classes=num_classes, + num_true=num_true, + sampled_values=sampled_vals, + subtract_log_q=False, + remove_accidental_hits=False, + partition_strategy="div", + name="sampled_logits_sharded_num_true_%d" % num_true) + got_logits, got_labels = self.evaluate([logits_tensor, labels_tensor]) + self.assertAllClose(exp_logits, got_logits, self._eps) + self.assertAllClose(exp_labels, got_labels, self._eps) + + def testNCELoss(self): + # A simple test to verify the numerics. + + def _SigmoidCrossEntropyWithLogits(logits, targets): + # logits, targets: float arrays of the same shape. + assert logits.shape == targets.shape + pred = 1. / (1. + np.exp(-logits)) + eps = 0.0001 + pred = np.minimum(np.maximum(pred, eps), 1 - eps) + return -targets * np.log(pred) - (1. - targets) * np.log(1. 
- pred) + + np.random.seed(0) + num_classes = 5 + batch_size = 3 + labels = [0, 1, 2] + (weights, biases, hidden_acts, sampled_vals, exp_logits, + exp_labels) = self._GenerateTestData( + num_classes=num_classes, + dim=10, + batch_size=batch_size, + num_true=1, + labels=labels, + sampled=[1, 0, 2, 3], + subtract_log_q=True) + exp_nce_loss = np.sum( + _SigmoidCrossEntropyWithLogits(exp_logits, exp_labels), 1) + + got_nce_loss = nn_impl.nce_loss_v2( + weights=constant_op.constant(weights), + biases=constant_op.constant(biases), + labels=constant_op.constant(labels, shape=(batch_size, 1)), + inputs=constant_op.constant(hidden_acts), + num_sampled=4, + num_classes=num_classes, + num_true=1, + sampled_values=sampled_vals) + + self.assertAllClose(exp_nce_loss, self.evaluate(got_nce_loss), 1e-4) + + # Test with sharded weights and sharded biases. + weight_shards, bias_shards = self._ShardTestEmbeddings( + weights, biases, num_shards=3) + got_nce_loss = nn_impl.nce_loss_v2( + weights=[constant_op.constant(shard) for shard in weight_shards], + biases=[constant_op.constant(shard) for shard in bias_shards], + labels=constant_op.constant(labels, shape=(batch_size, 1)), + inputs=constant_op.constant(hidden_acts), + num_sampled=4, + num_classes=num_classes, + num_true=1, + sampled_values=sampled_vals) + + self.assertAllClose(exp_nce_loss, self.evaluate(got_nce_loss), 1e-4) + + def testSampledSoftmaxLoss(self): + # A simple test to verify the numerics. + + def _SoftmaxCrossEntropyWithLogits(logits, targets): + # logits, targets: float arrays of the same shape. + assert logits.shape == targets.shape + stable_exp_logits = np.exp( + logits - np.amax(logits, axis=1, keepdims=True)) + pred = stable_exp_logits / np.sum(stable_exp_logits, 1, keepdims=True) + return -np.sum(targets * np.log(pred + 1.0e-20), axis=1) + + np.random.seed(0) + num_classes = 5 + batch_size = 3 + labels = [0, 1, 2] + (weights, biases, hidden_acts, sampled_vals, exp_logits, + exp_labels) = self._GenerateTestData( + num_classes=num_classes, + dim=10, + batch_size=batch_size, + num_true=1, + labels=labels, + sampled=[1, 0, 2, 3], + subtract_log_q=True) + exp_sampled_softmax_loss = _SoftmaxCrossEntropyWithLogits( + exp_logits, exp_labels) + + got_sampled_softmax_loss = nn_impl.sampled_softmax_loss_v2( + weights=constant_op.constant(weights), + biases=constant_op.constant(biases), + labels=constant_op.constant(labels, shape=(batch_size, 1)), + inputs=constant_op.constant(hidden_acts), + num_sampled=4, + num_classes=num_classes, + num_true=1, + sampled_values=sampled_vals, + remove_accidental_hits=False) + + self.assertAllClose(exp_sampled_softmax_loss, + self.evaluate(got_sampled_softmax_loss), 1e-4) + + # Test with sharded weights and sharded biases. + weight_shards, bias_shards = self._ShardTestEmbeddings( + weights, biases, num_shards=3) + got_sampled_softmax_loss = nn_impl.sampled_softmax_loss_v2( + weights=[constant_op.constant(shard) for shard in weight_shards], + biases=[constant_op.constant(shard) for shard in bias_shards], + labels=constant_op.constant(labels, shape=(batch_size, 1)), + inputs=constant_op.constant(hidden_acts), + num_sampled=4, + num_classes=num_classes, + num_true=1, + sampled_values=sampled_vals, + remove_accidental_hits=False) + + self.assertAllClose(exp_sampled_softmax_loss, + self.evaluate(got_sampled_softmax_loss), 1e-4) + + def testSampledSoftmaxLossBf16(self): + # A simple test to verify the numerics for bfloat16. 
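+      # bfloat16 keeps float32's 8-bit exponent but only a 7-bit mantissa (about
+      # 3 decimal digits of precision), which is why this test compares against
+      # the float32 reference with the loose 1e-1 tolerance instead of the 1e-4
+      # used by the float32 tests above.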
+ def _SoftmaxCrossEntropyWithLogits(logits, targets): + # logits, targets: float arrays of the same shape. + assert logits.shape == targets.shape + stable_exp_logits = np.exp( + logits - np.amax(logits, axis=1, keepdims=True)) + pred = stable_exp_logits / np.sum(stable_exp_logits, 1, keepdims=True) + return -np.sum(targets * np.log(pred + 1.0e-20), axis=1) + + np.random.seed(0) + num_classes = 5 + batch_size = 3 + labels = [0, 1, 2] + sampled = [1, 0, 2, 3] + (weights, biases, hidden_acts, _, exp_logits, + exp_labels) = self._GenerateTestData( + num_classes=num_classes, + dim=10, + batch_size=batch_size, + num_true=1, + labels=labels, + sampled=sampled, + subtract_log_q=True) + exp_sampled_softmax_loss = _SoftmaxCrossEntropyWithLogits( + exp_logits, exp_labels) + + true_exp_bf16 = np.full([batch_size, 1], + fill_value=0.5, + dtype=dtypes.bfloat16.as_numpy_dtype) + sampled_exp_bf16 = np.full([len(sampled)], + fill_value=0.5, + dtype=dtypes.bfloat16.as_numpy_dtype) + sampled_vals_bf16 = (sampled, true_exp_bf16, sampled_exp_bf16) + + got_sampled_softmax_loss = math_ops.cast( + nn_impl.sampled_softmax_loss_v2( + weights=constant_op.constant(weights, dtype=dtypes.bfloat16), + biases=constant_op.constant(biases, dtype=dtypes.bfloat16), + labels=constant_op.constant( + labels, shape=(batch_size, 1), dtype=dtypes.bfloat16), + inputs=constant_op.constant(hidden_acts, dtype=dtypes.bfloat16), + num_sampled=4, + num_classes=num_classes, + num_true=1, + sampled_values=sampled_vals_bf16, + remove_accidental_hits=False), dtypes.float32) + + self.assertAllClose(exp_sampled_softmax_loss, + self.evaluate(got_sampled_softmax_loss), 1e-1) + + +class GeluTest(test_lib.TestCase): + + def test(self): + + def gelu(x, approximate=False): + if approximate: + return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * + (x + 0.044715 * np.power(x, 3)))) + else: + from scipy.stats import norm # pylint: disable=g-import-not-at-top + return x * norm.cdf(x) + + np.random.seed(1) # Make it reproducible. + x = np.random.randn(3, 4).astype(np.float32) + y = gelu(x) + z = self.evaluate(nn_ops.gelu(constant_op.constant(x))) + self.assertAllClose(y, z) + + y = gelu(x, True) + z = self.evaluate(nn_ops.gelu(constant_op.constant(x), True)) + self.assertAllClose(y, z) + + +class CReluTest(test_lib.TestCase): + + def test(self): + np.random.seed(1) # Make it reproducible. + x = np.random.randn(3, 4).astype(np.float32) + y = np.concatenate([x * (x > 0), -x * (x < 0)], axis=1) + + z = self.evaluate(nn_ops.crelu(constant_op.constant(x))) + self.assertAllClose(y, z, 1e-4) + + +class ReluTest(test_lib.TestCase): + + def test(self): + np.random.seed(1) # Make it reproducible. + x = np.random.randn(3, 4).astype(np.float32) + y = np.maximum(x, 0.0) + + z = self.evaluate(nn_ops.relu(constant_op.constant(x))) + self.assertAllEqual(y, z) + + @test_util.run_deprecated_v1 + def testNaNs(self): + # Test that relu(nan) = nan for various sizes. + for i in range(18): + x = np.zeros(i) + np.nan + with self.cached_session(): + z = nn_ops.relu(constant_op.constant(x)).eval() + self.assertTrue(np.isnan(z).all()) + + +class LeakyReluTest(test_lib.TestCase): + + def testRange(self): + batch_size = 3 + height, width = 4, 4 + np.random.seed(1) # Make it reproducible. 
+ inputs = np.random.uniform(size=(batch_size, height, width, 3)).astype( + np.float32) + inputs = constant_op.constant(inputs) + + outputs = nn_ops.leaky_relu(inputs) + self.assertEquals(inputs.shape, outputs.shape) + + inputs, outputs = self.evaluate([inputs, outputs]) + + self.assertGreaterEqual(outputs.min(), 0.0) + self.assertLessEqual(outputs.max(), 1.0) + self.assertAllClose(inputs, outputs) + + @test_util.run_deprecated_v1 + def testValues(self): + for dtype in [np.int32, np.int64, np.float16, np.float32, np.float64]: + np_values = np.array([-2, -1, 0, 1, 2], dtype=dtype) + outputs = nn_ops.leaky_relu(constant_op.constant(np_values)) + + outputs = self.evaluate(outputs) + + tol = 2e-3 if dtype == np.float16 else 1e-6 + self.assertAllClose( + outputs, [-0.4, -0.2, 0.0, 1.0, 2.0], rtol=tol, atol=tol) + + @test_util.run_deprecated_v1 + def testName(self): + np_values = np.array([-2, -1, 0, 1, 2], dtype=np.float64) + outputs_with_name_set = nn_ops.leaky_relu( + constant_op.constant(np_values), + name='test_relu_op') + self.assertEqual(outputs_with_name_set.name, 'test_relu_op:0') + outputs_without_name_set = nn_ops.leaky_relu( + constant_op.constant(np_values)) + self.assertEqual(outputs_without_name_set.name, 'LeakyRelu:0') + + +class SwishTest(test_lib.TestCase): + + @test_util.run_deprecated_v1 + def testValues(self): + np_values = np.array( + [np.linspace(-7.0, 0.0, 100), + np.linspace(0.0, 7.0, 100)], + dtype=np.float32) + tf_values = constant_op.constant(np_values) + actual_tf_outputs = nn_impl.swish(tf_values) + expected_tf_outputs = tf_values * math_ops.sigmoid(tf_values) + + actual_outputs, expected_outputs = self.evaluate( + [actual_tf_outputs, expected_tf_outputs]) + + self.assertAllClose(actual_outputs, expected_outputs) + + @test_util.run_deprecated_v1 + def testGradients(self): + shape = [5, 3, 4] + sigma = 5 + input_values = np.random.randn(*shape) * sigma + x_tf = constant_op.constant(input_values) + y_tf = nn_impl.swish(x_tf) + with self.cached_session(): + err = gradient_checker.compute_gradient_error(x_tf, shape, y_tf, shape) + self.assertLess(err, 1e-4) + + +class MomentsTest(test_lib.TestCase): + + def doOutputTest(self, + input_shape, + moments_axes, + tol=1e-4, + check_gradients=False): + for mu in [0.0, 1.0, 1e3]: + for sigma in [1.0, 0.1]: + for keep_dims in [True, False]: + input_values = np.random.rand(*input_shape) * sigma + mu + expected_mean = np.mean( + input_values, axis=moments_axes, keepdims=keep_dims) + expected_var = np.var( + input_values, axis=moments_axes, keepdims=keep_dims) + with ops.Graph().as_default() as g: + with self.session(graph=g) as sess: + inputs = constant_op.constant( + input_values, shape=input_shape, dtype=dtypes.float32) + mean, variance = nn_impl.moments_v2( + inputs, moments_axes, keepdims=keep_dims) + + if check_gradients: + err = gradient_checker.compute_gradient_error( + inputs, input_shape, mean, mean.shape.as_list()) + self.assertLess(err, 1e-3) + err = gradient_checker.compute_gradient_error( + inputs, input_shape, variance, variance.shape.as_list()) + self.assertLess(err, 1e-3) + + # Evaluate. 
+ [mean, variance] = self.evaluate([mean, variance]) + # Make sure that there are no NaNs + self.assertFalse(np.isnan(mean).any()) + self.assertFalse(np.isnan(variance).any()) + self.assertAllClose(mean, expected_mean, rtol=tol, atol=tol) + self.assertAllClose(variance, expected_var, rtol=tol, atol=tol) + + def testOutputAndGradient2DInput0(self): + self.doOutputTest((10, 10), (0,), check_gradients=True) + + def testOutputAndGradient2DInput01(self): + self.doOutputTest((10, 10), (0, 1), check_gradients=True) + + def testOutput2DInput0(self): + self.doOutputTest((10, 300), (0,)) + + def testOutput2DInput1(self): + self.doOutputTest((10, 300), (1,)) + + def testOutput2DInput01(self): + self.doOutputTest((10, 300), (0, 1)) + + def testOutput4DInput0(self): + self.doOutputTest((10, 10, 10, 30), (0,)) + + def testOutput4DInput1(self): + self.doOutputTest((10, 10, 10, 30), (1,)) + + def testOutput4DInput3(self): + self.doOutputTest((10, 10, 10, 30), (3,)) + + def testOutput4DInput012(self): + self.doOutputTest((10, 10, 10, 30), (0, 1, 2)) + + def testOutput4DInput123(self): + self.doOutputTest((10, 10, 10, 30), (1, 2, 3)) + + +class DataFormatDimMapTest(test_lib.TestCase): + + def _test(self, x_val, y_val_expected): + x = constant_op.constant(x_val) + y = nn_ops.data_format_dim_map(x) + + y_val = self.evaluate(y) + self.assertAllEqual(y_val, y_val_expected) + + def test(self): + self._test(0, 0) + self._test(1, 2) + self._test(2, 3) + self._test(3, 1) + self._test(-1, 1) + self._test(-2, 3) + self._test(-3, 2) + self._test(-4, 0) + self._test([1, 3], [2, 1]) + self._test([1, 3, -2], [2, 1, 3]) + self._test([1, -3, -2], [2, 2, 3]) + self._test([[1, -3], [1, -1]], [[2, 2], [2, 1]]) + + def testNHWCtoNCHW(self): + x_val = [1, -3, -2] + y_val_expected = [2, 2, 3] + x = constant_op.constant(x_val) + y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="NCHW") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, y_val_expected) + + def testNHWCtoHWNC(self): + x_val = [-4, -3, -2, -1, 0, 1, 2, 3] + y_val_expected = [2, 0, 1, 3, 2, 0, 1, 3] + x = constant_op.constant(x_val) + y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="HWNC") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, y_val_expected) + + def testNHWCtoWHCN(self): + x_val = [-4, -3, -2, -1, 0, 1, 2, 3] + y_val_expected = [3, 1, 0, 2, 3, 1, 0, 2] + x = constant_op.constant(x_val) + y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="WHCN") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, y_val_expected) + + def testNDHWCtoNCDHW(self): + x_val = [1, -4, -3, -2] + y_val_expected = [2, 2, 3, 4] + x = constant_op.constant(x_val) + y = nn_ops.data_format_dim_map(x, src_format="NDHWC", dst_format="NCDHW") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, y_val_expected) + + def testNDHWCtoDHWNC(self): + x_val = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4] + y_val_expected = [3, 0, 1, 2, 4, 3, 0, 1, 2, 4] + x = constant_op.constant(x_val) + y = nn_ops.data_format_dim_map(x, src_format="NDHWC", dst_format="DHWNC") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, y_val_expected) + + def testDNHWCtoWHDCN(self): + x_val = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4] + y_val_expected = [4, 2, 1, 0, 3, 4, 2, 1, 0, 3] + x = constant_op.constant(x_val) + y = nn_ops.data_format_dim_map(x, src_format="NDHWC", dst_format="WHDCN") + with test_util.use_gpu(): + y_val = 
self.evaluate(y) + self.assertAllEqual(y_val, y_val_expected) + + def testArbitraryASCII(self): + x_val = [-4, -3, -2, -1, 0, 1, 2, 3] + y_val_expected = [3, 2, 1, 0, 3, 2, 1, 0] + x = constant_op.constant(x_val) + y = nn_ops.data_format_dim_map(x, src_format="qwer", dst_format="rewq") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, y_val_expected) + + @test_util.disable_xla("XLA catches the error and rethrows as different one") + def testInvalidLength(self): + x = [-4, -3, -2, -1, 0, 1, 2, 3] + with self.assertRaisesRegex(errors.InvalidArgumentError, + "Source format must be of length 4 or 5"): + op = nn_ops.data_format_dim_map( + x, src_format="12345678", dst_format="87654321") + with test_util.use_gpu(): + self.evaluate(op) + + @test_util.disable_xla("XLA catches the error and rethrows as different one") + def testDuplicateSrc(self): + x = [-4, -3, -2, -1, 0, 1, 2, 3] + with self.assertRaisesRegex( + errors.InvalidArgumentError, + "Destination and source format must determine a permutation"): + op = nn_ops.data_format_dim_map(x, src_format="1233", dst_format="4321") + with test_util.use_gpu(): + self.evaluate(op) + + @test_util.disable_xla("XLA catches the error and rethrows as different one") + def testDuplicateDst(self): + x = [-4, -3, -2, -1, 0, 1, 2, 3] + with self.assertRaisesRegex( + errors.InvalidArgumentError, + "Destination and source format must determine a permutation"): + op = nn_ops.data_format_dim_map(x, src_format="1234", dst_format="3321") + with test_util.use_gpu(): + self.evaluate(op) + + @test_util.disable_xla("XLA catches the error and rethrows as different one") + def testExtraSpecifiers(self): + x = [-4, -3, -2, -1, 0, 1, 2, 3] + with self.assertRaisesRegex( + errors.InvalidArgumentError, + "Destination and source format must determine a permutation"): + op = nn_ops.data_format_dim_map(x, src_format="1234", dst_format="5321") + with test_util.use_gpu(): + self.evaluate(op) + + +class DataFormatVectorPermuteTest(test_lib.TestCase): + + def testNHWCToNCHW(self): + x_val = [7, 4, 9, 3] + x = constant_op.constant(x_val) + y = nn_ops.data_format_vec_permute(x) + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, [7, 3, 4, 9]) + + def testNCHWToNHWC(self): + x_val = [7, 4, 9, 3] + x = constant_op.constant(x_val) + y = nn_ops.data_format_vec_permute(x, src_format="NCHW", dst_format="NHWC") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, [7, 9, 3, 4]) + + def testNHWCToHWNC(self): + x_val = [7, 4, 9, 3] + x = constant_op.constant(x_val) + y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="HWNC") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, [4, 9, 7, 3]) + + def testHWNCToNHWC(self): + x_val = [7, 4, 9, 3] + x = constant_op.constant(x_val) + y = nn_ops.data_format_vec_permute(x, src_format="HWNC", dst_format="NHWC") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, [9, 7, 4, 3]) + + def testNHWCToNCHW2D(self): + x_val = [[7, 4], [9, 3], [4, 5], [5, 1]] + x = constant_op.constant(x_val) + y = nn_ops.data_format_vec_permute(x) + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, [[7, 4], [5, 1], [9, 3], [4, 5]]) + + def testNHWCToHWNC2D(self): + x_val = [[7, 4], [9, 3], [4, 5], [5, 1]] + x = constant_op.constant(x_val) + y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="HWNC") + with test_util.use_gpu(): + y_val = self.evaluate(y) + 
self.assertAllEqual(y_val, [[9, 3], [4, 5], [7, 4], [5, 1]]) + + def testHWNCToNHWC2D(self): + x_val = [[7, 4], [9, 3], [4, 5], [5, 1]] + x = constant_op.constant(x_val) + y = nn_ops.data_format_vec_permute(x, src_format="HWNC", dst_format="NHWC") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, [[4, 5], [7, 4], [9, 3], [5, 1]]) + + def testNCHWToNHWC2D(self): + x_val = [[7, 4], [9, 3], [4, 5], [5, 1]] + x = constant_op.constant(x_val) + y = nn_ops.data_format_vec_permute(x, src_format="NCHW", dst_format="NHWC") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, [[7, 4], [4, 5], [5, 1], [9, 3]]) + + @test_util.disable_xla("XLA catches the error and rethrows as different one") + def testInvalidLength(self): + x = [0, 1, 2, 3] + with self.assertRaisesRegex(errors.InvalidArgumentError, + "Source format must be of length 4 or 5"): + op = nn_ops.data_format_vec_permute( + x, src_format="12345678", dst_format="87654321") + with test_util.use_gpu(): + self.evaluate(op) + + @test_util.disable_xla("XLA catches the error and rethrows as different one") + def testDuplicateSrc(self): + x = [0, 1, 2, 3] + with self.assertRaisesRegex( + errors.InvalidArgumentError, + "Destination and source format must determine a permutation"): + op = nn_ops.data_format_vec_permute( + x, src_format="1233", dst_format="4321") + with test_util.use_gpu(): + self.evaluate(op) + + @test_util.disable_xla("XLA catches the error and rethrows as different one") + def testDuplicateDst(self): + x = [0, 1, 2, 3] + with self.assertRaisesRegex( + errors.InvalidArgumentError, + "Destination and source format must determine a permutation"): + op = nn_ops.data_format_vec_permute( + x, src_format="1234", dst_format="3321") + with test_util.use_gpu(): + self.evaluate(op) + + @test_util.disable_xla("XLA catches the error and rethrows as different one") + def testExtraSpecifiers(self): + x = [0, 1, 2, 3] + with self.assertRaisesRegex( + errors.InvalidArgumentError, + "Destination and source format must determine a permutation"): + op = nn_ops.data_format_vec_permute( + x, src_format="1234", dst_format="5321") + with test_util.use_gpu(): + self.evaluate(op) + + +@test_util.run_all_in_graph_and_eager_modes +class AvgPoolTest(test_lib.TestCase): + + def test1DTensor(self): + x = array_ops.ones([3, 6, 5]) + ksize = 2 + strides = 2 + + y1 = nn_ops.avg_pool_v2(x, ksize, strides, "SAME") + y2 = nn_ops.avg_pool1d(x, ksize, strides, "SAME") + + self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) + + def test1DNumpy(self): + # explicilty use float32 for ROCm, as MIOpen does not yet support float64 + # np.ones defaults to using float64 when dtype is not explicitly specified + dtype = np.float32 if test_lib.is_built_with_rocm() else np.float64 + x = np.ones([3, 6, 5], dtype=dtype) + ksize = 2 + strides = 2 + + y1 = nn_ops.avg_pool_v2(x, ksize, strides, "SAME") + y2 = nn_ops.avg_pool1d(x, ksize, strides, "SAME") + + self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) + + def test1DNumpyWithGolden(self): + dtype = np.float32 if test_lib.is_built_with_rocm() else np.float64 + x = np.array([[[3], [6], [5]], + [[1], [0], [1]]], dtype=dtype) + ksize = 2 + strides = 1 + y = nn_ops.avg_pool1d(x, ksize, strides, "SAME") + expected_y = np.array([[[4.5], [5.5], [5.0]], + [[0.5], [0.5], [1.0]]], dtype=dtype) + self.assertAllEqual(self.evaluate(y), expected_y) + + def test2DTensor(self): + x = array_ops.ones([3, 6, 6, 5]) + ksize = 2 + strides = 2 + + y1 = nn_ops.avg_pool_v2(x, ksize, 
strides, "SAME") + y2 = nn_ops.avg_pool(x, ksize, strides, "SAME") + + self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) + + def test2DNumpy(self): + # explicilty use float32 for ROCm, as MIOpen does not yet support float64 + # np.ones defaults to using float64 when dtype is not explicitly specified + dtype = np.float32 if test_lib.is_built_with_rocm() else np.float64 + x = np.ones([3, 6, 6, 5], dtype=dtype) + ksize = 2 + strides = 2 + + y1 = nn_ops.avg_pool_v2(x, ksize, strides, "SAME") + y2 = nn_ops.avg_pool(x, ksize, strides, "SAME") + + self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) + + def test3DTensor(self): + if test_lib.is_built_with_rocm(): + self.skipTest("Pooling with 3D tensors is not supported in ROCm") + x = array_ops.ones([3, 7, 6, 6, 5]) + ksize = 2 + strides = 2 + + y1 = nn_ops.avg_pool_v2(x, ksize, strides, "SAME") + y2 = nn_ops.avg_pool3d(x, ksize, strides, "SAME") + + self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) + + def test3DNumpy(self): + if test_lib.is_built_with_rocm(): + self.skipTest("Pooling with 3D tensors is not supported in ROCm") + x = np.ones([3, 7, 6, 6, 5], dtype=np.float32) + ksize = 2 + strides = 2 + + y1 = nn_ops.avg_pool_v2(x, ksize, strides, "SAME") + y2 = nn_ops.avg_pool3d(x, ksize, strides, "SAME") + + self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) + + +@test_util.run_all_in_graph_and_eager_modes +class MaxPoolTest(test_lib.TestCase): + + def test1DTensor(self): + x = array_ops.ones([3, 6, 5]) + ksize = 2 + strides = 2 + + y1 = nn_ops.max_pool_v2(x, ksize, strides, "SAME") + y2 = nn_ops.max_pool1d(x, ksize, strides, "SAME") + + self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) -# def test1DNumpy(self): -# # explicilty use float32 for ROCm, as MIOpen does not yet support float64 -# # np.ones defaults to using float64 when dtype is not explicitly specified -# dtype = np.float32 if test_lib.is_built_with_rocm() else np.float64 -# x = np.ones([3, 6, 5], dtype=dtype) -# ksize = 2 -# strides = 2 + def test1DNumpy(self): + # explicilty use float32 for ROCm, as MIOpen does not yet support float64 + # np.ones defaults to using float64 when dtype is not explicitly specified + dtype = np.float32 if test_lib.is_built_with_rocm() else np.float64 + x = np.ones([3, 6, 5], dtype=dtype) + ksize = 2 + strides = 2 -# y1 = nn_ops.max_pool_v2(x, ksize, strides, "SAME") -# y2 = nn_ops.max_pool1d(x, ksize, strides, "SAME") + y1 = nn_ops.max_pool_v2(x, ksize, strides, "SAME") + y2 = nn_ops.max_pool1d(x, ksize, strides, "SAME") -# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) + self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) -# def test1DNumpyWithGolden(self): -# dtype = np.float32 if test_lib.is_built_with_rocm() else np.float64 -# x = np.array([[[3], [6], [5]], -# [[1], [0], [1]]], dtype=dtype) -# ksize = 2 -# strides = 1 -# y = nn_ops.max_pool1d(x, ksize, strides, "SAME") -# expected_y = np.array([[[6], [6], [5]], -# [[1], [1], [1]]], dtype=dtype) -# self.assertAllEqual(self.evaluate(y), expected_y) + def test1DNumpyWithGolden(self): + dtype = np.float32 if test_lib.is_built_with_rocm() else np.float64 + x = np.array([[[3], [6], [5]], + [[1], [0], [1]]], dtype=dtype) + ksize = 2 + strides = 1 + y = nn_ops.max_pool1d(x, ksize, strides, "SAME") + expected_y = np.array([[[6], [6], [5]], + [[1], [1], [1]]], dtype=dtype) + self.assertAllEqual(self.evaluate(y), expected_y) -# def test2DTensor(self): -# x = array_ops.ones([3, 6, 6, 5]) -# ksize = 2 -# strides = 2 + def test2DTensor(self): + x = 
array_ops.ones([3, 6, 6, 5]) + ksize = 2 + strides = 2 -# y1 = nn_ops.max_pool_v2(x, ksize, strides, "SAME") -# y2 = nn_ops.max_pool(x, ksize, strides, "SAME") + y1 = nn_ops.max_pool_v2(x, ksize, strides, "SAME") + y2 = nn_ops.max_pool(x, ksize, strides, "SAME") -# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) + self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) -# def test2DNumpy(self): -# # explicilty use float32 for ROCm, as MIOpen does not yet support float64 -# # np.ones defaults to using float64 when dtype is not explicitly specified -# dtype = np.float32 if test_lib.is_built_with_rocm() else np.float64 -# x = np.ones([3, 6, 6, 5], dtype=dtype) -# ksize = 2 -# strides = 2 + def test2DNumpy(self): + # explicilty use float32 for ROCm, as MIOpen does not yet support float64 + # np.ones defaults to using float64 when dtype is not explicitly specified + dtype = np.float32 if test_lib.is_built_with_rocm() else np.float64 + x = np.ones([3, 6, 6, 5], dtype=dtype) + ksize = 2 + strides = 2 -# y1 = nn_ops.max_pool_v2(x, ksize, strides, "SAME") -# y2 = nn_ops.max_pool(x, ksize, strides, "SAME") + y1 = nn_ops.max_pool_v2(x, ksize, strides, "SAME") + y2 = nn_ops.max_pool(x, ksize, strides, "SAME") -# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) + self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) -# def test3DTensor(self): -# if test_lib.is_built_with_rocm(): -# self.skipTest("Pooling with 3D tensors is not supported in ROCm") -# x = array_ops.ones([3, 7, 6, 6, 5]) -# ksize = 2 -# strides = 2 + def test3DTensor(self): + if test_lib.is_built_with_rocm(): + self.skipTest("Pooling with 3D tensors is not supported in ROCm") + x = array_ops.ones([3, 7, 6, 6, 5]) + ksize = 2 + strides = 2 -# y1 = nn_ops.max_pool_v2(x, ksize, strides, "SAME") -# y2 = nn_ops.max_pool3d(x, ksize, strides, "SAME") + y1 = nn_ops.max_pool_v2(x, ksize, strides, "SAME") + y2 = nn_ops.max_pool3d(x, ksize, strides, "SAME") -# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) + self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) -# def test3DNumpy(self): -# if test_lib.is_built_with_rocm(): -# self.skipTest("Pooling with 3D tensors is not supported in ROCm") -# x = np.ones([3, 7, 6, 6, 5], dtype=np.float32) -# ksize = 2 -# strides = 2 + def test3DNumpy(self): + if test_lib.is_built_with_rocm(): + self.skipTest("Pooling with 3D tensors is not supported in ROCm") + x = np.ones([3, 7, 6, 6, 5], dtype=np.float32) + ksize = 2 + strides = 2 -# y1 = nn_ops.max_pool_v2(x, ksize, strides, "SAME") -# y2 = nn_ops.max_pool3d(x, ksize, strides, "SAME") + y1 = nn_ops.max_pool_v2(x, ksize, strides, "SAME") + y2 = nn_ops.max_pool3d(x, ksize, strides, "SAME") -# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) + self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) -# def testIncorrectSizeInputSmall(self): -# x = array_ops.ones([3, 4]) -# with self.assertRaisesRegex( -# ValueError, "Input tensor must be of rank 3, 4 or 5 but was 2."): -# nn_ops.max_pool_v2(x, 2, 2, "SAME") + def testIncorrectSizeInputSmall(self): + x = array_ops.ones([3, 4]) + with self.assertRaisesRegex( + ValueError, "Input tensor must be of rank 3, 4 or 5 but was 2."): + nn_ops.max_pool_v2(x, 2, 2, "SAME") -# def testIncorrectSizeInput(self): -# x = array_ops.ones([3, 4, 1, 2, 1, 2]) -# with self.assertRaisesRegex( -# ValueError, "Input tensor must be of rank 3, 4 or 5 but was 6."): -# nn_ops.max_pool_v2(x, 2, 2, "SAME") + def testIncorrectSizeInput(self): + x = array_ops.ones([3, 4, 1, 2, 1, 2]) + with 
self.assertRaisesRegex( + ValueError, "Input tensor must be of rank 3, 4 or 5 but was 6."): + nn_ops.max_pool_v2(x, 2, 2, "SAME") -# @test_util.run_all_in_graph_and_eager_modes -# class ConvolutionTest(test_lib.TestCase): +@test_util.run_all_in_graph_and_eager_modes +class ConvolutionTest(test_lib.TestCase): -# def testUnknownSize(self): -# # explicilty use float32 for ROCm, as MIOpen does not yet support float64 -# # np.ones defaults to using float64 when dtype is not explicitly specified -# dtype = np.float32 if test_lib.is_built_with_rocm() else np.float64 -# x = tensor_spec.TensorSpec(None, dtypes.float32, name="x") -# k = np.ones([3, 6, 6, 5], dtype=dtype) + def testUnknownSize(self): + # explicilty use float32 for ROCm, as MIOpen does not yet support float64 + # np.ones defaults to using float64 when dtype is not explicitly specified + dtype = np.float32 if test_lib.is_built_with_rocm() else np.float64 + x = tensor_spec.TensorSpec(None, dtypes.float32, name="x") + k = np.ones([3, 6, 6, 5], dtype=dtype) -# @def_function.function -# def F(value): -# return nn_ops.convolution(value, k, "SAME") + @def_function.function + def F(value): + return nn_ops.convolution(value, k, "SAME") -# F.get_concrete_function(x) + F.get_concrete_function(x) -# class ConvTransposeTest(test_lib.TestCase): +class ConvTransposeTest(test_lib.TestCase): -# def test1D(self): -# t = array_ops.ones([2, 4, 3]) -# v = array_ops.ones([2, 5, 3]) -# strides = 2 + def test1D(self): + t = array_ops.ones([2, 4, 3]) + v = array_ops.ones([2, 5, 3]) + strides = 2 -# y1 = nn_ops.conv1d_transpose(t, v, [2, 8, 5], strides) -# y2 = nn_ops.conv_transpose(t, v, [2, 8, 5], strides) + y1 = nn_ops.conv1d_transpose(t, v, [2, 8, 5], strides) + y2 = nn_ops.conv_transpose(t, v, [2, 8, 5], strides) -# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) + self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) -# def test1DTensor(self): -# t = array_ops.ones([2, 4, 3]) -# v = array_ops.ones([2, 5, 3]) -# strides = 2 + def test1DTensor(self): + t = array_ops.ones([2, 4, 3]) + v = array_ops.ones([2, 5, 3]) + strides = 2 -# y1 = nn_ops.conv1d_transpose(t, v, [2, 8, 5], strides) -# y2 = nn_ops.conv_transpose(t, v, constant_op.constant([2, 8, 5]), strides) + y1 = nn_ops.conv1d_transpose(t, v, [2, 8, 5], strides) + y2 = nn_ops.conv_transpose(t, v, constant_op.constant([2, 8, 5]), strides) -# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) + self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) -# def test2D(self): -# t = array_ops.ones([2, 4, 4, 3]) -# v = array_ops.ones([2, 2, 5, 3]) -# strides = 2 + def test2D(self): + t = array_ops.ones([2, 4, 4, 3]) + v = array_ops.ones([2, 2, 5, 3]) + strides = 2 -# y1 = nn_ops.conv2d_transpose_v2(t, v, [2, 8, 8, 5], strides) -# y2 = nn_ops.conv_transpose(t, v, [2, 8, 8, 5], strides) + y1 = nn_ops.conv2d_transpose_v2(t, v, [2, 8, 8, 5], strides) + y2 = nn_ops.conv_transpose(t, v, [2, 8, 8, 5], strides) -# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) + self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) -# def test2DTensor(self): -# t = array_ops.ones([2, 4, 4, 3]) -# v = array_ops.ones([2, 2, 5, 3]) -# strides = 2 + def test2DTensor(self): + t = array_ops.ones([2, 4, 4, 3]) + v = array_ops.ones([2, 2, 5, 3]) + strides = 2 -# y1 = nn_ops.conv2d_transpose_v2(t, v, [2, 8, 8, 5], strides) -# y2 = nn_ops.conv_transpose(t, v, constant_op.constant([2, 8, 8, 5]), -# strides) + y1 = nn_ops.conv2d_transpose_v2(t, v, [2, 8, 8, 5], strides) + y2 = nn_ops.conv_transpose(t, v, 
constant_op.constant([2, 8, 8, 5]), + strides) -# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) + self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) -# def test3D(self): -# t = array_ops.ones([2, 4, 4, 4, 3]) -# v = array_ops.ones([2, 2, 2, 5, 3]) -# strides = 2 + def test3D(self): + t = array_ops.ones([2, 4, 4, 4, 3]) + v = array_ops.ones([2, 2, 2, 5, 3]) + strides = 2 -# y1 = nn_ops.conv3d_transpose_v2(t, v, [2, 8, 8, 8, 5], strides) -# y2 = nn_ops.conv_transpose(t, v, [2, 8, 8, 8, 5], strides) + y1 = nn_ops.conv3d_transpose_v2(t, v, [2, 8, 8, 8, 5], strides) + y2 = nn_ops.conv_transpose(t, v, [2, 8, 8, 8, 5], strides) -# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) + self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) -# def test3DTensor(self): -# t = array_ops.ones([2, 4, 4, 4, 3]) -# v = array_ops.ones([2, 2, 2, 5, 3]) -# strides = 2 + def test3DTensor(self): + t = array_ops.ones([2, 4, 4, 4, 3]) + v = array_ops.ones([2, 2, 2, 5, 3]) + strides = 2 -# y1 = nn_ops.conv3d_transpose_v2(t, v, [2, 8, 8, 8, 5], strides) -# y2 = nn_ops.conv_transpose(t, v, constant_op.constant([2, 8, 8, 8, 5]), -# strides) + y1 = nn_ops.conv3d_transpose_v2(t, v, [2, 8, 8, 8, 5], strides) + y2 = nn_ops.conv_transpose(t, v, constant_op.constant([2, 8, 8, 8, 5]), + strides) -# self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) + self.assertAllEqual(self.evaluate(y1), self.evaluate(y2)) -# def testIncorrectSizeInputSmall(self): -# with self.assertRaisesRegex( -# ValueError, "output_shape must be of length 3, 4 or 5 but was 2."): -# nn_ops.conv_transpose(None, 2, [2, 3], "SAME") + def testIncorrectSizeInputSmall(self): + with self.assertRaisesRegex( + ValueError, "output_shape must be of length 3, 4 or 5 but was 2."): + nn_ops.conv_transpose(None, 2, [2, 3], "SAME") -# def testIncorrectSizeInput(self): -# with self.assertRaisesRegex( -# ValueError, "output_shape must be of length 3, 4 or 5 but was 6."): -# nn_ops.conv_transpose(None, 2, [2, 3, 4, 2, 5, 1], "SAME") + def testIncorrectSizeInput(self): + with self.assertRaisesRegex( + ValueError, "output_shape must be of length 3, 4 or 5 but was 6."): + nn_ops.conv_transpose(None, 2, [2, 3, 4, 2, 5, 1], "SAME") -# def testTensorsNoShape(self): -# with self.assertRaisesRegex( -# ValueError, -# "output_shape must be a tensor or sized collection."): -# nn_ops.conv_transpose(None, None, None, None) + def testTensorsNoShape(self): + with self.assertRaisesRegex( + ValueError, + "output_shape must be a tensor or sized collection."): + nn_ops.conv_transpose(None, None, None, None) if __name__ == "__main__": From aa54aa0a29e18625b1cf84c4c5a5b436b1f2e054 Mon Sep 17 00:00:00 2001 From: Duyi-Wang Date: Mon, 6 Jun 2022 15:51:46 +0800 Subject: [PATCH 11/17] [Op] Set fused op to default enabled. --- tensorflow/python/ops/nn_impl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index e8fb43483d2..78ca733cfdb 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -591,7 +591,7 @@ def normalize(tensor, ord="euclidean", axis=None, name=None): @tf_export(v1=["math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize"]) @deprecated_args(None, "dim is deprecated, use axis instead", "dim") -def l2_normalize(x, axis=None, epsilon=1e-12, do_fusion=False, name=None, dim=None): +def l2_normalize(x, axis=None, epsilon=1e-12, do_fusion=True, name=None, dim=None): """Normalizes along dimension `axis` using an L2 norm. 
For a 1-D tensor with `axis = 0`, computes @@ -619,7 +619,7 @@ def l2_normalize(x, axis=None, epsilon=1e-12, do_fusion=False, name=None, dim=No @tf_export("math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize", v1=[]) -def l2_normalize_v2(x, axis=None, epsilon=1e-12, do_fusion=False, name=None): +def l2_normalize_v2(x, axis=None, epsilon=1e-12, do_fusion=True, name=None): """Normalizes along dimension `axis` using an L2 norm. For a 1-D tensor with `axis = 0`, computes From f1581db99804613dbd2a7549b980bca7bfd8d195 Mon Sep 17 00:00:00 2001 From: Duyi-Wang Date: Wed, 8 Jun 2022 16:59:51 +0800 Subject: [PATCH 12/17] [Op] Optimize AVX512 perf. --- .../fused_l2_normalize_grad_op_test.cc | 1 + .../fused_l2_normalize_op.cc | 347 ++++++++---------- .../fused_l2_normalize_op_test.cc | 1 + 3 files changed, 157 insertions(+), 192 deletions(-) diff --git a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_grad_op_test.cc b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_grad_op_test.cc index b6762b35589..aaa774fa13c 100644 --- a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_grad_op_test.cc +++ b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_grad_op_test.cc @@ -108,6 +108,7 @@ static Graph* FusedL2NormalizeGrad(int rows, int cols) { BM_FusedL2NormGrad(ROWS, COLS, 8); \ BM_FusedL2NormGrad_NTH(1024, 63); +BM_FusedL2NormGrad_NTH(1024, 127); BM_FusedL2NormGrad_NTH(1024, 255); BM_FusedL2NormGrad_NTH(1024, 511); BM_FusedL2NormGrad_NTH(1024, 1023); diff --git a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc index e9993db57a7..5b57f58d8ee 100644 --- a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc +++ b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc @@ -57,19 +57,34 @@ class FusedL2NormalizeOp : public OpKernel { auto &worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); thread::ThreadPool *thread_pool = worker_threads.workers; - thread_pool->ParallelFor(total_unit, unit_cost, - [&input, &output, rows, cols, this](int64 begin_unit, int64 end_unit) { +#ifdef __AVX512F__ + int64 block_num = cols >> 7; + int64 remainder_128 = cols & 0x7F; + int64 remainder_16 = remainder_128 & 0x0F; + int64 remainder_block_num = remainder_128 >> 4; + int64 remainder_block_num_total = remainder_block_num+ !!remainder_16; + + thread_pool->ParallelFor(total_unit, unit_cost, [&input, &output, rows, cols, block_num, remainder_block_num, + remainder_block_num_total, remainder_128, remainder_16, this](int64 begin_unit, int64 end_unit) { auto begin_row = begin_unit * BLOCK_SIZE; auto end_row = end_unit * BLOCK_SIZE; if (end_row > rows) { end_row = rows; } -#ifdef __AVX512F__ - forward_avx512<8>(input, output, begin_row, end_row, cols); + forward_avx512<8>(input, output, begin_row, end_row, cols, block_num, remainder_block_num, + remainder_block_num_total, remainder_128, remainder_16); + }); #else + thread_pool->ParallelFor(total_unit, unit_cost, + [&input, &output, rows, cols, this](int64 begin_unit, int64 end_unit) { + auto begin_row = begin_unit * BLOCK_SIZE; + auto end_row = end_unit * BLOCK_SIZE; + if (end_row > rows) { + end_row = rows; + } forward<8>(input, output, begin_row, end_row, cols); -#endif }); +#endif } private: @@ -111,15 +126,11 @@ class FusedL2NormalizeOp : public OpKernel { #ifdef __AVX512F__ template - void forward_avx512(const T* input, T* output, int64 begin_row, int64 end_row, - int64 cols) { - int64 
avx3_block_num = cols >> 7; // cols / 128 - // handle remainder of 128 - int64 remainder = cols - (avx3_block_num << 7); + void forward_avx512(const T* input, T* output, int64 begin_row, int64 end_row, int64 cols, int64 block_num, int64 remainder_block_num, + int64 remainder_block_num_total, int64 remainder_128, int64 remainder_16) { for (int64 i = begin_row; i < end_row; ++i) { - int64 tmp_remainder = remainder; float row_sum = 0.0; - for (int64 j = 0; j < avx3_block_num; ++j) { + for (int64 j = 0; j < block_num; ++j) { __m512 inputs[SUM_BLOCK_SIZE]; auto load = [&](auto idx) { inputs[idx] = _mm512_loadu_ps(input + cols * i + @@ -127,50 +138,26 @@ class FusedL2NormalizeOp : public OpKernel { inputs[idx] = _mm512_mul_ps(inputs[idx], inputs[idx]); }; functor::compile_time_for::op(load); - __m512 block_sum = reduce_sum_block8_ps(inputs); + __m512 block_sum = reduce_sum_block<8>(inputs); row_sum += _mm512_reduce_add_ps(block_sum); } - if (tmp_remainder > 0) { - if (tmp_remainder >= 64) { - __m256 inputs[8]; - auto load_256 = [&](auto idx) { - inputs[idx] = _mm256_loadu_ps(input + cols * i + cols - - tmp_remainder + 8 * idx); - inputs[idx] = _mm256_mul_ps(inputs[idx], inputs[idx]); - }; - functor::compile_time_for<8>::op(load_256); - __m256 block_sum_remainder = reduce_sum_block8_mm256_ps(inputs); - row_sum += - _mm512_reduce_add_ps(_mm512_castps256_ps512(block_sum_remainder)); - tmp_remainder -= 64; - } - if (tmp_remainder > 32) { - __m256 inputs[4]; - auto load_256 = [&](auto idx) { - inputs[idx] = _mm256_loadu_ps(input + cols * i + cols - - tmp_remainder + 8 * idx); - inputs[idx] = _mm256_mul_ps(inputs[idx], inputs[idx]); - }; - functor::compile_time_for<4>::op(load_256); - __m256 block_sum_remainder = reduce_sum_block4_mm256_ps(inputs); - row_sum += - _mm512_reduce_add_ps(_mm512_castps256_ps512(block_sum_remainder)); - tmp_remainder -= 32; - } - if (tmp_remainder >= 16) { - __m512 inputs = - _mm512_loadu_ps(input + cols * i + cols - tmp_remainder); - inputs = _mm512_mul_ps(inputs, inputs); - row_sum += _mm512_reduce_add_ps(inputs); - tmp_remainder -= 16; + if (remainder_block_num_total) { + __m512 inputs[remainder_block_num_total]; + + for (int64 idx = 0; idx < remainder_block_num; idx++){ + inputs[idx] = _mm512_loadu_ps(input + cols * i + cols - + remainder_128 + 16 * idx); + inputs[idx] = _mm512_mul_ps(inputs[idx], inputs[idx]); } - if (tmp_remainder > 0) { - __mmask16 mask = 0xFFFF >> (16 - tmp_remainder); - __m512 inputs = _mm512_maskz_loadu_ps( - mask, input + cols * i + cols - tmp_remainder); - inputs = _mm512_mul_ps(inputs, inputs); - row_sum += _mm512_reduce_add_ps(inputs); + if (remainder_16) { + __mmask16 mask = 0xFFFF >> (16 - remainder_16); + inputs[remainder_block_num] = _mm512_maskz_loadu_ps( + mask, input + cols * i + cols - remainder_16); + inputs[remainder_block_num] = _mm512_mul_ps(inputs[remainder_block_num], inputs[remainder_block_num]); } + + __m512 block_sum = reduce_sum_block_ps(inputs, remainder_block_num_total); + row_sum += _mm512_reduce_add_ps(block_sum); } row_sum += epsilon; @@ -181,12 +168,12 @@ class FusedL2NormalizeOp : public OpKernel { inputs = _mm512_mul_ps(inputs, row_sums); _mm512_storeu_ps(output + cols * i + j, inputs); } - if (tmp_remainder > 0) { - __mmask16 mask = 0xFFFF >> (16 - tmp_remainder); + if (remainder_16) { + __mmask16 mask = 0xFFFF >> (16 - remainder_16); __m512 inputs = _mm512_maskz_loadu_ps( - mask, input + cols * i + cols - tmp_remainder); + mask, input + cols * i + cols - remainder_16); inputs = _mm512_mul_ps(inputs, row_sums); - 
_mm512_mask_storeu_ps(output + cols * i + cols - tmp_remainder, mask, + _mm512_mask_storeu_ps(output + cols * i + cols - remainder_16, mask, inputs); } } @@ -198,31 +185,36 @@ class FusedL2NormalizeOp : public OpKernel { // ... // v7: v7_0, v7_1, ..., v7_15 // sum: v_0, v_1, ..., v_15 - inline __m512 reduce_sum_block8_ps(const __m512 (&v)[8]) { - __m512 block_sum = _mm512_add_ps(v[0], v[1]); - block_sum = _mm512_add_ps(block_sum, v[2]); - block_sum = _mm512_add_ps(block_sum, v[3]); - block_sum = _mm512_add_ps(block_sum, v[4]); - block_sum = _mm512_add_ps(block_sum, v[5]); - block_sum = _mm512_add_ps(block_sum, v[6]); - block_sum = _mm512_add_ps(block_sum, v[7]); + template + inline __m512 reduce_sum_block(const __m512* v) { + __m512 block_sum = _mm512_setzero_ps(); + auto reduce_sum = [&](auto idx) { + block_sum = _mm512_add_ps(block_sum, v[idx]); + }; + functor::compile_time_for::op(reduce_sum); return block_sum; } - inline __m256 reduce_sum_block8_mm256_ps(const __m256 (&v)[8]) { - __m256 block_sum = _mm256_add_ps(v[0], v[1]); - block_sum = _mm256_add_ps(block_sum, v[2]); - block_sum = _mm256_add_ps(block_sum, v[3]); - block_sum = _mm256_add_ps(block_sum, v[4]); - block_sum = _mm256_add_ps(block_sum, v[5]); - block_sum = _mm256_add_ps(block_sum, v[6]); - block_sum = _mm256_add_ps(block_sum, v[7]); - return block_sum; - } - inline __m256 reduce_sum_block4_mm256_ps(const __m256 (&v)[4]) { - __m256 block_sum = _mm256_add_ps(v[0], v[1]); - block_sum = _mm256_add_ps(block_sum, v[2]); - block_sum = _mm256_add_ps(block_sum, v[3]); - return block_sum; + + inline __m512 reduce_sum_block_ps(const __m512* v, int64 BLOCK_NUM) { + switch (BLOCK_NUM) + { + case 1: + return reduce_sum_block<1>(v); + case 2: + return reduce_sum_block<2>(v); + case 3: + return reduce_sum_block<3>(v); + case 4: + return reduce_sum_block<4>(v); + case 5: + return reduce_sum_block<5>(v); + case 6: + return reduce_sum_block<6>(v); + case 7: + return reduce_sum_block<7>(v); + case 8: + return reduce_sum_block<8>(v); + } } #endif @@ -276,19 +268,34 @@ class FusedL2NormalizeGradOp : public OpKernel { auto &worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); thread::ThreadPool *thread_pool = worker_threads.workers; - thread_pool->ParallelFor(total_unit, unit_cost, - [&y_grad, &x, &x_grad, rows, cols, this](int64 begin_unit, int64 end_unit) { +#ifdef __AVX512F__ + int64 block_num = cols >> 7; + int64 remainder_128 = cols & 0x7F; + int64 remainder_16 = remainder_128 & 0x0F; + int64 remainder_block_num = remainder_128 >> 4; + int64 remainder_block_num_total = remainder_block_num+ !!remainder_16; + + thread_pool->ParallelFor(total_unit, unit_cost, [&y_grad, &x, &x_grad, rows, cols, block_num, remainder_block_num, + remainder_block_num_total, remainder_128, remainder_16, this](int64 begin_unit, int64 end_unit) { auto begin_row = begin_unit * BLOCK_SIZE; auto end_row = end_unit * BLOCK_SIZE; if (end_row > rows) { end_row = rows; } -#ifdef __AVX512F__ - backward_avx512<8>(y_grad, x, x_grad, begin_row, end_row, cols); + backward_avx512<8>(y_grad, x, x_grad, begin_row, end_row, cols, block_num, + remainder_block_num, remainder_block_num_total, remainder_128, remainder_16); + }); #else + thread_pool->ParallelFor(total_unit, unit_cost, + [&y_grad, &x, &x_grad, rows, cols, this](int64 begin_unit, int64 end_unit) { + auto begin_row = begin_unit * BLOCK_SIZE; + auto end_row = end_unit * BLOCK_SIZE; + if (end_row > rows) { + end_row = rows; + } backward<8>(y_grad, x, x_grad, begin_row, end_row, cols); -#endif }); +#endif } 
private: @@ -344,18 +351,16 @@ class FusedL2NormalizeGradOp : public OpKernel { } } + #ifdef __AVX512F__ template - void backward_avx512(const float* y_grad, const float* x, float* x_grad, - int64 begin_row, int64 end_row, int64 cols) { - int64 avx3_block_num = cols >> 7; // cols / 128 - // handle remainder of 128 - int64 remainder = cols - (avx3_block_num << 7); + void backward_avx512(const float* y_grad, const float* x, float* x_grad, int64 begin_row, int64 end_row, int64 cols, + int64 block_num, int64 remainder_block_num, int64 remainder_block_num_total, int64 remainder_128, + int64 remainder_16) { for (int64 i = begin_row; i < end_row; ++i) { T x_row_sum = 0.0; T y_grad_row_sum = 0.0; - int64 tmp_remainder = remainder; - for (int64 j = 0; j < avx3_block_num; ++j) { + for (int64 j = 0; j < block_num; ++j) { __m512 xs[SUM_BLOCK_SIZE]; auto x_load = [&](auto idx) { xs[idx] = _mm512_loadu_ps(x + cols * i + 16 * SUM_BLOCK_SIZE * j + @@ -363,7 +368,7 @@ class FusedL2NormalizeGradOp : public OpKernel { xs[idx] = _mm512_mul_ps(xs[idx], xs[idx]); }; functor::compile_time_for::op(x_load); - __m512 x_block_sum = reduce_sum_block8_ps(xs); + __m512 x_block_sum = reduce_sum_block<8>(xs); x_row_sum += _mm512_reduce_add_ps(x_block_sum); __m512 y_grads[SUM_BLOCK_SIZE]; @@ -375,81 +380,37 @@ class FusedL2NormalizeGradOp : public OpKernel { y_grads[idx] = _mm512_mul_ps(y_grads[idx], xs[idx]); }; functor::compile_time_for::op(y_grad_load); - __m512 y_grad_block_sum = reduce_sum_block8_ps(y_grads); + __m512 y_grad_block_sum = reduce_sum_block<8>(y_grads); y_grad_row_sum += _mm512_reduce_add_ps(y_grad_block_sum); } - if (tmp_remainder > 0) { - if (tmp_remainder >= 64) { - __m256 xs[8]; - auto x_load_256 = [&](auto idx) { - xs[idx] = - _mm256_loadu_ps(x + cols * i + cols - tmp_remainder + 8 * idx); - xs[idx] = _mm256_mul_ps(xs[idx], xs[idx]); - }; - functor::compile_time_for<8>::op(x_load_256); - __m256 block_sum_remainder = reduce_sum_block8_mm256_ps(xs); - x_row_sum += - _mm512_reduce_add_ps(_mm512_castps256_ps512(block_sum_remainder)); - - __m256 y_grads[8]; - auto y_grad_load_256 = [&](auto idx) { - y_grads[idx] = _mm256_loadu_ps(y_grad + cols * i + cols - - tmp_remainder + 8 * idx); - xs[idx] = - _mm256_loadu_ps(x + cols * i + cols - tmp_remainder + 8 * idx); - y_grads[idx] = _mm256_mul_ps(y_grads[idx], xs[idx]); - }; - functor::compile_time_for<8>::op(y_grad_load_256); - __m256 y_grad_block_sum_remainder = - reduce_sum_block8_mm256_ps(y_grads); - y_grad_row_sum += _mm512_reduce_add_ps( - _mm512_castps256_ps512(y_grad_block_sum_remainder)); - tmp_remainder -= 64; + if (remainder_block_num_total) { + __m512 xs[remainder_block_num_total]; + for (int64 idx = 0; idx < remainder_block_num; idx++){ + xs[idx] = _mm512_loadu_ps(x + cols * i + cols - remainder_128 + 16 * idx); + xs[idx] = _mm512_mul_ps(xs[idx], xs[idx]); } - if (tmp_remainder > 32) { - __m256 xs[4]; - auto x_load_256 = [&](auto idx) { - xs[idx] = - _mm256_loadu_ps(x + cols * i + cols - tmp_remainder + 8 * idx); - xs[idx] = _mm256_mul_ps(xs[idx], xs[idx]); - }; - functor::compile_time_for<4>::op(x_load_256); - __m256 block_sum_remainder = reduce_sum_block4_mm256_ps(xs); - x_row_sum += - _mm512_reduce_add_ps(_mm512_castps256_ps512(block_sum_remainder)); - - __m256 y_grads[4]; - auto y_grad_load_256 = [&](auto idx) { - y_grads[idx] = _mm256_loadu_ps(y_grad + cols * i + cols - - tmp_remainder + 8 * idx); - xs[idx] = - _mm256_loadu_ps(x + cols * i + cols - tmp_remainder + 8 * idx); - y_grads[idx] = _mm256_mul_ps(y_grads[idx], xs[idx]); - }; - 
functor::compile_time_for<4>::op(y_grad_load_256); - __m256 y_grad_block_sum_remainder = - reduce_sum_block4_mm256_ps(y_grads); - y_grad_row_sum += _mm512_reduce_add_ps( - _mm512_castps256_ps512(y_grad_block_sum_remainder)); - tmp_remainder -= 32; + if (remainder_16) { + __mmask16 mask = 0xFFFF >> (16 - remainder_16); + xs[remainder_block_num] = _mm512_maskz_loadu_ps(mask, x + cols * i + cols - remainder_16); + xs[remainder_block_num] = _mm512_mul_ps(xs[remainder_block_num], xs[remainder_block_num]); } - if (tmp_remainder >= 16) { - __m512 xs = _mm512_loadu_ps(x + cols * i + cols - tmp_remainder); - __m512 y_grads = - _mm512_loadu_ps(y_grad + cols * i + cols - tmp_remainder); - x_row_sum += _mm512_reduce_add_ps(_mm512_mul_ps(xs, xs)); - y_grad_row_sum += _mm512_reduce_add_ps(_mm512_mul_ps(y_grads, xs)); - tmp_remainder -= 16; + __m512 x_block_sum = reduce_sum_block_ps(xs, remainder_block_num_total); + x_row_sum += _mm512_reduce_add_ps(x_block_sum); + + __m512 y_grads[remainder_block_num_total]; + for (int64 idx = 0; idx < remainder_block_num; idx++){ + y_grads[idx] = _mm512_loadu_ps(y_grad + cols * i + cols - remainder_128 + 16 * idx); + xs[idx] = _mm512_loadu_ps(x + cols * i + cols - remainder_128 + 16 * idx); + y_grads[idx] = _mm512_mul_ps(y_grads[idx], xs[idx]); } - if (tmp_remainder > 0) { - __mmask16 mask = 0xFFFF >> (16 - tmp_remainder); - __m512 xs = - _mm512_maskz_loadu_ps(mask, x + cols * i + cols - tmp_remainder); - __m512 y_grads = _mm512_maskz_loadu_ps( - mask, y_grad + cols * i + cols - tmp_remainder); - x_row_sum += _mm512_reduce_add_ps(_mm512_mul_ps(xs, xs)); - y_grad_row_sum += _mm512_reduce_add_ps(_mm512_mul_ps(y_grads, xs)); + if (remainder_16) { + __mmask16 mask = 0xFFFF >> (16 - remainder_16); + y_grads[remainder_block_num] = _mm512_maskz_loadu_ps(mask, y_grad + cols * i + cols - remainder_16); + xs[remainder_block_num] = _mm512_maskz_loadu_ps(mask, x + cols * i + cols - remainder_16); + y_grads[remainder_block_num] = _mm512_mul_ps(y_grads[remainder_block_num], xs[remainder_block_num]); } + __m512 y_grad_block_sum = reduce_sum_block_ps(xs, remainder_block_num_total); + y_grad_row_sum += _mm512_reduce_add_ps(y_grad_block_sum); } x_row_sum += epsilon; @@ -465,46 +426,48 @@ class FusedL2NormalizeGradOp : public OpKernel { y_grads = _mm512_sub_ps(y_grads, xs); _mm512_storeu_ps(x_grad + cols * i + j, y_grads); } - if (tmp_remainder > 0) { - __mmask16 mask = 0xFFFF >> (16 - tmp_remainder); - __m512 y_grads = _mm512_maskz_loadu_ps( - mask, y_grad + cols * i + cols - tmp_remainder); - __m512 xs = - _mm512_maskz_loadu_ps(mask, x + cols * i + cols - tmp_remainder); + if (remainder_16 > 0) { + __mmask16 mask = 0xFFFF >> (16 - remainder_16); + __m512 y_grads = _mm512_maskz_loadu_ps(mask, y_grad + cols * i + cols - remainder_16); + __m512 xs = _mm512_maskz_loadu_ps(mask, x + cols * i + cols - remainder_16); y_grads = _mm512_mul_ps(y_grads, x_row_sums); xs = _mm512_mul_ps(xs, y_grad_row_sums); y_grads = _mm512_sub_ps(y_grads, xs); - _mm512_mask_storeu_ps(x_grad + cols * i + cols - tmp_remainder, mask, - y_grads); + _mm512_mask_storeu_ps(x_grad + cols * i + cols - remainder_16, mask, y_grads); } } } - inline __m512 reduce_sum_block8_ps(const __m512 (&v)[8]) { - __m512 block_sum = _mm512_add_ps(v[0], v[1]); - block_sum = _mm512_add_ps(block_sum, v[2]); - block_sum = _mm512_add_ps(block_sum, v[3]); - block_sum = _mm512_add_ps(block_sum, v[4]); - block_sum = _mm512_add_ps(block_sum, v[5]); - block_sum = _mm512_add_ps(block_sum, v[6]); - block_sum = _mm512_add_ps(block_sum, v[7]); + 
template + inline __m512 reduce_sum_block(const __m512* v) { + __m512 block_sum = _mm512_setzero_ps(); + auto reduce_sum = [&](auto idx) { + block_sum = _mm512_add_ps(block_sum, v[idx]); + }; + functor::compile_time_for::op(reduce_sum); return block_sum; } - inline __m256 reduce_sum_block8_mm256_ps(const __m256 (&v)[8]) { - __m256 block_sum = _mm256_add_ps(v[0], v[1]); - block_sum = _mm256_add_ps(block_sum, v[2]); - block_sum = _mm256_add_ps(block_sum, v[3]); - block_sum = _mm256_add_ps(block_sum, v[4]); - block_sum = _mm256_add_ps(block_sum, v[5]); - block_sum = _mm256_add_ps(block_sum, v[6]); - block_sum = _mm256_add_ps(block_sum, v[7]); - return block_sum; - } - inline __m256 reduce_sum_block4_mm256_ps(const __m256 (&v)[4]) { - __m256 block_sum = _mm256_add_ps(v[0], v[1]); - block_sum = _mm256_add_ps(block_sum, v[2]); - block_sum = _mm256_add_ps(block_sum, v[3]); - return block_sum; + + inline __m512 reduce_sum_block_ps(const __m512* v, int64 BLOCK_NUM) { + switch (BLOCK_NUM) + { + case 1: + return reduce_sum_block<1>(v); + case 2: + return reduce_sum_block<2>(v); + case 3: + return reduce_sum_block<3>(v); + case 4: + return reduce_sum_block<4>(v); + case 5: + return reduce_sum_block<5>(v); + case 6: + return reduce_sum_block<6>(v); + case 7: + return reduce_sum_block<7>(v); + case 8: + return reduce_sum_block<8>(v); + } } #endif diff --git a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op_test.cc b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op_test.cc index 2c4f5e988ec..98f25942438 100644 --- a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op_test.cc +++ b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op_test.cc @@ -95,6 +95,7 @@ static Graph* FusedL2Normalize(int rows, int cols) { BM_FusedL2Norm(ROWS, COLS, 8); \ BM_FusedL2Norm_NTH(1024, 63); +BM_FusedL2Norm_NTH(1024, 127); BM_FusedL2Norm_NTH(1024, 255); BM_FusedL2Norm_NTH(1024, 511); BM_FusedL2Norm_NTH(1024, 1023); From 72ae2f814297b4ba7769466e5d0f1ba39633d7e3 Mon Sep 17 00:00:00 2001 From: marvinYu Date: Fri, 24 Jun 2022 14:14:08 +0800 Subject: [PATCH 13/17] fix api_test issue. 
--- .../base_api/api_def_FusedL2Normalize.pbtxt | 3 +++ .../base_api/api_def_FusedL2NormalizeGrad.pbtxt | 3 +++ tensorflow/core/ops/fused_l2_normalize_ops.cc | 16 ++++++++-------- 3 files changed, 14 insertions(+), 8 deletions(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_FusedL2Normalize.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_FusedL2NormalizeGrad.pbtxt diff --git a/tensorflow/core/api_def/base_api/api_def_FusedL2Normalize.pbtxt b/tensorflow/core/api_def/base_api/api_def_FusedL2Normalize.pbtxt new file mode 100644 index 00000000000..403e88193c1 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_FusedL2Normalize.pbtxt @@ -0,0 +1,3 @@ +op { + graph_op_name: "FusedL2Normalize" +} diff --git a/tensorflow/core/api_def/base_api/api_def_FusedL2NormalizeGrad.pbtxt b/tensorflow/core/api_def/base_api/api_def_FusedL2NormalizeGrad.pbtxt new file mode 100644 index 00000000000..86b652754b3 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_FusedL2NormalizeGrad.pbtxt @@ -0,0 +1,3 @@ +op { + graph_op_name: "FusedL2NormalizeGrad" +} diff --git a/tensorflow/core/ops/fused_l2_normalize_ops.cc b/tensorflow/core/ops/fused_l2_normalize_ops.cc index 82ca2406056..d2ca1d1414e 100644 --- a/tensorflow/core/ops/fused_l2_normalize_ops.cc +++ b/tensorflow/core/ops/fused_l2_normalize_ops.cc @@ -16,10 +16,10 @@ REGISTER_OP("FusedL2Normalize") .SetShapeFn([](::tensorflow::shape_inference::InferenceContext *c) { c->set_output(0, c->input(0)); return Status::OK(); - }) - .Doc(R"doc( -FusedL2Normalize ops. - )doc"); + }); +// .Doc(R"doc( +// FusedL2Normalize ops. +// )doc"); REGISTER_OP("FusedL2NormalizeGrad") .Input("y_grad: T") @@ -31,9 +31,9 @@ REGISTER_OP("FusedL2NormalizeGrad") .SetShapeFn([](::tensorflow::shape_inference::InferenceContext *c) { c->set_output(0, c->input(0)); return Status::OK(); - }) - .Doc(R"doc( -FusedL2NormalizeGrad ops. - )doc"); + }); +// .Doc(R"doc( +// FusedL2NormalizeGrad ops. +// )doc"); } // namespace tensorflow From abd8b8e434258e9fa7c9c19a59b07bbd1ea7ea37 Mon Sep 17 00:00:00 2001 From: Duyi-Wang Date: Fri, 1 Jul 2022 13:43:04 +0800 Subject: [PATCH 14/17] [Op] Add annotations in fused l2n. 
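The annotations added by this patch document two things. First, the AVX512 path
splits each row into 128-float (8 x 16) blocks plus a 16-float remainder; the
arithmetic is straightforward, e.g. cols = 127 gives block_num = 0,
remainder_128 = 127, remainder_block_num = 7, remainder_16 = 15 and therefore
remainder_block_num_total = 8. Second, the backward comment
`x_grad = y_grad * rvar - x * ((sum * rvar) * (rvar * rvar))` is the chain rule
written in the kernel's variable names; as a sketch of that algebra (writing g
for y_grad, r for rvar, and keeping epsilon inside the root as the kernel does):

    y_i = x_i r,                      r = (\sum_j x_j^2 + \epsilon)^{-1/2}
    \partial r / \partial x_i = -x_i r^3
    \partial L / \partial x_i = \sum_k g_k \, \partial y_k / \partial x_i
                              = g_i r - x_i ((\sum_k g_k x_k) r) r^2

In the code, x_row_sum holds r after the rsqrt step, and y_grad_row_sum ends up
holding (sum(g * x) * r) * r^2, which is exactly the coefficient applied to x.
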
--- .../fused_l2_normalize_op.cc | 39 +++++++++++++------ 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc index 5b57f58d8ee..0921318def4 100644 --- a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc +++ b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc @@ -58,11 +58,12 @@ class FusedL2NormalizeOp : public OpKernel { thread::ThreadPool *thread_pool = worker_threads.workers; #ifdef __AVX512F__ - int64 block_num = cols >> 7; - int64 remainder_128 = cols & 0x7F; - int64 remainder_16 = remainder_128 & 0x0F; - int64 remainder_block_num = remainder_128 >> 4; - int64 remainder_block_num_total = remainder_block_num+ !!remainder_16; + // Do forward in 128(8*16) block, avx512 handles 16 floats one time + int64 block_num = cols >> 7; // 128-size block nums + int64 remainder_128 = cols & 0x7F; // remainder of 128 + int64 remainder_16 = remainder_128 & 0x0F; // remainder of 16 + int64 remainder_block_num = remainder_128 >> 4; // 16-size block num in 128-remainder + int64 remainder_block_num_total = remainder_block_num+ !!remainder_16; // total 16-size block num in remainder thread_pool->ParallelFor(total_unit, unit_cost, [&input, &output, rows, cols, block_num, remainder_block_num, remainder_block_num_total, remainder_128, remainder_16, this](int64 begin_unit, int64 end_unit) { @@ -99,6 +100,7 @@ class FusedL2NormalizeOp : public OpKernel { int64 remainder = cols % SUM_BLOCK_SIZE; for (int64 i = begin_row; i < end_row; ++i) { T row_sum = 0; + // Sum of squares of the inputs for (int64 j = 0; j < cols - remainder; j += SUM_BLOCK_SIZE) { T data_0 = input[i * cols + j]; T data_1 = input[i * cols + j + 1]; @@ -116,8 +118,12 @@ class FusedL2NormalizeOp : public OpKernel { T data_0 = input[i * cols + j]; row_sum += data_0 * data_0; } + + // Square row_sum += epsilon; row_sum = 1.0 / std::sqrt(row_sum); + + // Mul for (int64 j = 0; j < cols; ++j) { output[i * cols + j] = input[i * cols + j] * row_sum; } @@ -130,6 +136,7 @@ class FusedL2NormalizeOp : public OpKernel { int64 remainder_block_num_total, int64 remainder_128, int64 remainder_16) { for (int64 i = begin_row; i < end_row; ++i) { float row_sum = 0.0; + // Sum of squares of the inputs for (int64 j = 0; j < block_num; ++j) { __m512 inputs[SUM_BLOCK_SIZE]; auto load = [&](auto idx) { @@ -160,8 +167,11 @@ class FusedL2NormalizeOp : public OpKernel { row_sum += _mm512_reduce_add_ps(block_sum); } + // Square root row_sum += epsilon; row_sum = 1.0 / std::sqrt(row_sum); + + // Mul & store __m512 row_sums = _mm512_set1_ps(row_sum); for (int64 j = 0; j < cols - 15; j += 16) { __m512 inputs = _mm512_loadu_ps(input + cols * i + j); @@ -268,12 +278,13 @@ class FusedL2NormalizeGradOp : public OpKernel { auto &worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); thread::ThreadPool *thread_pool = worker_threads.workers; -#ifdef __AVX512F__ - int64 block_num = cols >> 7; - int64 remainder_128 = cols & 0x7F; - int64 remainder_16 = remainder_128 & 0x0F; - int64 remainder_block_num = remainder_128 >> 4; - int64 remainder_block_num_total = remainder_block_num+ !!remainder_16; +#ifdef __AVX512F__ + // Do forward in 128(8*16) block, avx512 handles 16 floats one time + int64 block_num = cols >> 7; // 128-size block nums + int64 remainder_128 = cols & 0x7F; // remainder of 128 + int64 remainder_16 = remainder_128 & 0x0F; // remainder of 16 + int64 remainder_block_num = 
remainder_128 >> 4; // 16-size block num in 128-remainder + int64 remainder_block_num_total = remainder_block_num+ !!remainder_16; // total 16-size block num in remainder thread_pool->ParallelFor(total_unit, unit_cost, [&y_grad, &x, &x_grad, rows, cols, block_num, remainder_block_num, remainder_block_num_total, remainder_128, remainder_16, this](int64 begin_unit, int64 end_unit) { @@ -310,6 +321,7 @@ class FusedL2NormalizeGradOp : public OpKernel { for (int64 i = begin_row; i < end_row; ++i) { T x_row_sum = 0.0; T y_grad_row_sum = 0.0; + // sum of squares of x and sum of y_grad * x for (int64 j = cols - 1; j > remainder; j -= SUM_BLOCK_SIZE) { T x_0 = x[i * cols + j]; T x_1 = x[i * cols + j - 1]; @@ -344,6 +356,7 @@ class FusedL2NormalizeGradOp : public OpKernel { x_row_sum += epsilon; x_row_sum = 1.0 / std::sqrt(x_row_sum); // rvar y_grad_row_sum = (y_grad_row_sum * x_row_sum) * (x_row_sum * x_row_sum); + // Calculate x_grad = y_grad * rvar - x * ((sum * rvar) * (rvar * rvar)) for (int64 j = 0; j < cols; ++j) { x_grad[i * cols + j] = y_grad[i * cols + j] * x_row_sum - x[i * cols + j] * y_grad_row_sum; @@ -360,6 +373,7 @@ class FusedL2NormalizeGradOp : public OpKernel { for (int64 i = begin_row; i < end_row; ++i) { T x_row_sum = 0.0; T y_grad_row_sum = 0.0; + // sum of squares of x and sum of y_grad * x for (int64 j = 0; j < block_num; ++j) { __m512 xs[SUM_BLOCK_SIZE]; auto x_load = [&](auto idx) { @@ -414,8 +428,9 @@ class FusedL2NormalizeGradOp : public OpKernel { } x_row_sum += epsilon; - x_row_sum = 1.0 / std::sqrt(x_row_sum); + x_row_sum = 1.0 / std::sqrt(x_row_sum); // var y_grad_row_sum = (y_grad_row_sum * x_row_sum) * (x_row_sum * x_row_sum); + // Calculate x_grad = y_grad * rvar - x * ((sum * rvar) * (rvar * rvar)) __m512 x_row_sums = _mm512_set1_ps(x_row_sum); __m512 y_grad_row_sums = _mm512_set1_ps(y_grad_row_sum); for (int64 j = 0; j < cols - 15; j += 16) { From a6b741230ae246e5c30adbee76a7a9dcd1a58010 Mon Sep 17 00:00:00 2001 From: Duyi-Wang Date: Mon, 11 Jul 2022 13:09:39 +0800 Subject: [PATCH 15/17] [Op] Change l2n do_fusion parameter position. --- tensorflow/python/ops/nn_impl.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index 78ca733cfdb..b6f2bc561d5 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -591,7 +591,7 @@ def normalize(tensor, ord="euclidean", axis=None, name=None): @tf_export(v1=["math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize"]) @deprecated_args(None, "dim is deprecated, use axis instead", "dim") -def l2_normalize(x, axis=None, epsilon=1e-12, do_fusion=True, name=None, dim=None): +def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None, do_fusion=True): """Normalizes along dimension `axis` using an L2 norm. For a 1-D tensor with `axis = 0`, computes @@ -615,11 +615,11 @@ def l2_normalize(x, axis=None, epsilon=1e-12, do_fusion=True, name=None, dim=Non A `Tensor` with the same shape as `x`. """ axis = deprecated_argument_lookup("axis", axis, "dim", dim) - return l2_normalize_v2(x, axis, epsilon, do_fusion, name) + return l2_normalize_v2(x, axis, epsilon, name, do_fusion) @tf_export("math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize", v1=[]) -def l2_normalize_v2(x, axis=None, epsilon=1e-12, do_fusion=True, name=None): +def l2_normalize_v2(x, axis=None, epsilon=1e-12, name=None, do_fusion=True): """Normalizes along dimension `axis` using an L2 norm. 
   For a 1-D tensor with `axis = 0`, computes

From dd05512c9a65e55e21079c168da6e2d8eadde272 Mon Sep 17 00:00:00 2001
From: Duyi-Wang
Date: Mon, 11 Jul 2022 16:51:06 +0800
Subject: [PATCH 16/17] [Op] Disable fusion in l2n v1.

---
 tensorflow/python/ops/nn_impl.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index b6f2bc561d5..d3a15134832 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -591,7 +591,7 @@ def normalize(tensor, ord="euclidean", axis=None, name=None):
 
 @tf_export(v1=["math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize"])
 @deprecated_args(None, "dim is deprecated, use axis instead", "dim")
-def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None, do_fusion=True):
+def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None):
   """Normalizes along dimension `axis` using an L2 norm.
 
   For a 1-D tensor with `axis = 0`, computes
@@ -608,14 +608,13 @@ def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None, do_fusion=Tru
     epsilon: A lower bound value for the norm. Will use `sqrt(epsilon)` as the
       divisor if `norm < sqrt(epsilon)`.
     name: A name for this operation (optional).
-    do_fusion: Whether fuse op when doing l2 norm on last axis.
     dim: Deprecated alias for axis.
 
   Returns:
     A `Tensor` with the same shape as `x`.
   """
   axis = deprecated_argument_lookup("axis", axis, "dim", dim)
-  return l2_normalize_v2(x, axis, epsilon, name, do_fusion)
+  return l2_normalize_v2(x, axis, epsilon, name, False)
 
 
 @tf_export("math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize", v1=[])
@@ -635,7 +634,7 @@ def l2_normalize_v2(x, axis=None, epsilon=1e-12, name=None, do_fusion=True):
       integers.
     epsilon: A lower bound value for the norm. Will use `sqrt(epsilon)` as the
       divisor if `norm < sqrt(epsilon)`.
-    do_fusion: Whether fuse op when doing l2 norm on last axis.
+    do_fusion: Whether to fuse the op when computing the l2 norm on the last axis; enabled by default.
     name: A name for this operation (optional).
 
   Returns:

From dc882bf8817e821cfc7a7963575a2a36014793ab Mon Sep 17 00:00:00 2001
From: Duyi-Wang
Date: Tue, 12 Jul 2022 14:35:56 +0800
Subject: [PATCH 17/17] [Op] Fix nn_test failure.
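
Note on the new expected values: the reference backward pass computes
x_grad = y_grad * rvar - x * ((sum(y_grad * x) * rvar) * (rvar * rvar)),
where rvar = 1 / sqrt(sum(x * x) + epsilon) per row. For the test input
(each of the 4 rows is 252 ones, y_grad all ones except 2.0 in the last
column), most entries become -1/(252*sqrt(252)) and the last column becomes
251/(252*sqrt(252)). The NumPy sketch below is illustrative only (not part
of the patch) and assumes rows=4, cols=252 and a negligible epsilon:

    # Recompute the gradient values the test now expects.
    import numpy as np

    rows, cols, epsilon = 4, 252, 1e-12
    x = np.ones((rows, cols))
    y_grad = np.ones((rows, cols))
    y_grad[:, -1] = 2.0  # mirrors y_grad_array[251], [503], [755], [1007] = 2.0

    rvar = 1.0 / np.sqrt(np.sum(x * x, axis=1, keepdims=True) + epsilon)
    row_dot = np.sum(y_grad * x, axis=1, keepdims=True)  # sum(y_grad * x) per row
    x_grad = y_grad * rvar - x * ((row_dot * rvar) * (rvar * rvar))

    print(x_grad[0, 0])   # ~ -1.0 / (252 * np.sqrt(252))
    print(x_grad[0, -1])  # ~ 251.0 / (252 * np.sqrt(252))
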
---
 .../fused_l2_normalize_grad_op_test.cc        | 16 ++++++---
 .../fused_l2_normalize_op.cc                  | 36 +++++++++----------
 2 files changed, 30 insertions(+), 22 deletions(-)

diff --git a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_grad_op_test.cc b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_grad_op_test.cc
index aaa774fa13c..0af5a3c104a 100644
--- a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_grad_op_test.cc
+++ b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_grad_op_test.cc
@@ -39,14 +39,18 @@ TEST_F(FusedL2NormalizeGradOpTest, 2Dims_Float) {
 
   // y_grad
   float y_grad_array[1008];
-  for (int i = 0; i < sizeof(y_grad_array) / sizeof(float); i++) {
+  for (int i = 0; i < rows * cols; i++) {
     y_grad_array[i] = 1.0;
   }
+  y_grad_array[251] = 2.0;
+  y_grad_array[503] = 2.0;
+  y_grad_array[755] = 2.0;
+  y_grad_array[1007] = 2.0;
   AddInputFromArray<float>(TensorShape({rows, cols}), y_grad_array);
 
   // x
   float x_array[1008];
-  for (int i = 0; i < sizeof(x_array) / sizeof(float); i++) {
+  for (int i = 0; i < rows * cols; i++) {
     x_array[i] = 1.0;
   }
   AddInputFromArray<float>(TensorShape({rows, cols}), x_array);
@@ -58,9 +62,13 @@ TEST_F(FusedL2NormalizeGradOpTest, 2Dims_Float) {
   Tensor expected_output(allocator(), DT_FLOAT, TensorShape({rows, cols}));
   float output_array[1008];
-  for (int i = 0; i < sizeof(output_array) / sizeof(float); i++) {
-    output_array[i] = 0;
+  for (int i = 0; i < rows * cols; i++) {
+    output_array[i] = - 1.0 / (252 * std::sqrt(252));
   }
+  output_array[251] = 251.0 / (252 * std::sqrt(252));
+  output_array[503] = 251.0 / (252 * std::sqrt(252));
+  output_array[755] = 251.0 / (252 * std::sqrt(252));
+  output_array[1007] = 251.0 / (252 * std::sqrt(252));
 
   test::FillValues<float>(&expected_output, output_array);
   test::ExpectTensorNear<float>(expected_output, *GetOutput(0), 1e-6);
 }
diff --git a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc
index 0921318def4..f0013d7bfd6 100644
--- a/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc
+++ b/tensorflow/core/kernels/fused_l2_normalize/fused_l2_normalize_op.cc
@@ -322,31 +322,31 @@ class FusedL2NormalizeGradOp : public OpKernel {
       T x_row_sum = 0.0;
       T y_grad_row_sum = 0.0;
       // sum of squares of x and sum of y_grad * x
-      for (int64 j = cols - 1; j > remainder; j -= SUM_BLOCK_SIZE) {
+      for (int64 j = 0; j < cols - remainder; j += SUM_BLOCK_SIZE) {
         T x_0 = x[i * cols + j];
-        T x_1 = x[i * cols + j - 1];
-        T x_2 = x[i * cols + j - 2];
-        T x_3 = x[i * cols + j - 3];
-        T x_4 = x[i * cols + j - 4];
-        T x_5 = x[i * cols + j - 5];
-        T x_6 = x[i * cols + j - 6];
-        T x_7 = x[i * cols + j - 7];
+        T x_1 = x[i * cols + j + 1];
+        T x_2 = x[i * cols + j + 2];
+        T x_3 = x[i * cols + j + 3];
+        T x_4 = x[i * cols + j + 4];
+        T x_5 = x[i * cols + j + 5];
+        T x_6 = x[i * cols + j + 6];
+        T x_7 = x[i * cols + j + 7];
         x_row_sum += x_0 * x_0 + x_1 * x_1 + x_2 * x_2 + x_3 * x_3 +
                      x_4 * x_4 + x_5 * x_5 + x_6 * x_6 + x_7 * x_7;
 
         T y_grad_0 = y_grad[i * cols + j];
-        T y_grad_1 = y_grad[i * cols + j - 1];
-        T y_grad_2 = y_grad[i * cols + j - 2];
-        T y_grad_3 = y_grad[i * cols + j - 3];
-        T y_grad_4 = y_grad[i * cols + j - 4];
-        T y_grad_5 = y_grad[i * cols + j - 5];
-        T y_grad_6 = y_grad[i * cols + j - 6];
-        T y_grad_7 = y_grad[i * cols + j - 7];
+        T y_grad_1 = y_grad[i * cols + j + 1];
+        T y_grad_2 = y_grad[i * cols + j + 2];
+        T y_grad_3 = y_grad[i * cols + j + 3];
+        T y_grad_4 = y_grad[i * cols + j + 4];
+        T y_grad_5 = y_grad[i * cols + j + 5];
+        T y_grad_6 = y_grad[i * cols + j + 6];
+        T y_grad_7 = y_grad[i * cols + j + 7];
         y_grad_row_sum += x_0 * y_grad_0 + x_1 * y_grad_1 + x_2 * y_grad_2 +
                           x_3 * y_grad_3 + x_4 * y_grad_4 + x_5 * y_grad_5 +
                           x_6 * y_grad_6 + x_7 * y_grad_7;
       }
-      for (int64 j = remainder; j > 0; j--) {
+      for (int64 j = cols - remainder; j < cols; j++) {
         T x_0 = x[i * cols + j];
         x_row_sum += x_0 * x_0;
 
@@ -423,12 +423,12 @@ class FusedL2NormalizeGradOp : public OpKernel {
           xs[remainder_block_num] = _mm512_maskz_loadu_ps(mask, x + cols * i + cols - remainder_16);
          y_grads[remainder_block_num] = _mm512_mul_ps(y_grads[remainder_block_num], xs[remainder_block_num]);
         }
-        __m512 y_grad_block_sum = reduce_sum_block_ps(xs, remainder_block_num_total);
+        __m512 y_grad_block_sum = reduce_sum_block_ps(y_grads, remainder_block_num_total);
         y_grad_row_sum += _mm512_reduce_add_ps(y_grad_block_sum);
       }
 
       x_row_sum += epsilon;
-      x_row_sum = 1.0 / std::sqrt(x_row_sum); // var
+      x_row_sum = 1.0 / std::sqrt(x_row_sum); // rvar
       y_grad_row_sum = (y_grad_row_sum * x_row_sum) * (x_row_sum * x_row_sum);
       // Calculate x_grad = y_grad * rvar - x * ((sum * rvar) * (rvar * rvar))
       __m512 x_row_sums = _mm512_set1_ps(x_row_sum);