// File: fused_batch_norm_op.h (158 lines, 6.84 KB), taken from a GitHub fork
// of tensorflow/tensorflow.
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_FUSED_BATCH_NORM_OP_H_
#define TENSORFLOW_CORE_KERNELS_FUSED_BATCH_NORM_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
namespace functor {
// Activation mode for the FusedBatchNormEx op, which can fuse a side input
// and an activation into batch normalization in one of two shapes:
//   (1) batch_norm -> activation
//   (2) batch_norm -> add side input -> activation
enum class FusedBatchNormActivationMode {
  kIdentity,  // identity activation
  kRelu,      // ReLU activation
};

// Converts `activation_mode` into a string.
string ToString(FusedBatchNormActivationMode activation_mode);

// Parses an activation mode using the kernel construction `context` and
// stores it in `*activation_mode`, reporting the outcome through the returned
// Status. NOTE(review): the definition is not in this header; presumably it
// reads an op attribute — confirm against the .cc file.
Status ParseActivationMode(OpKernelConstruction* context,
                           FusedBatchNormActivationMode* activation_mode);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Hides a cuDNN v4/v5 behavior difference from TensorFlow users. The two
// versions disagree on how cudnnBatchNormalizationForwardInference applies
// its scaling factor:
//   cuDNN v4: y = bnScale * (x - mean) * variance + bnBias
//   cuDNN v5: y = bnScale * (x - mean) / sqrt(variance + epsilon) + bnBias
// When cuDNN v4 is in use, this functor corrects the scaling factor so that
// the inconsistency is not visible to callers.
//
// Only T = float is explicitly instantiated (in batch_norm_ops.cu.cc);
// instantiations for other element types must be added there as needed.
template <typename T>
struct VarianceToInvVariance {
  // Runs on GPU device `d` over `channels` entries, reading `variance` and
  // writing `inv_variance`. NOTE(review): the definition lives in the CUDA
  // source file, not in this header.
  void operator()(const Eigen::GpuDevice& d,
                  const T* variance,
                  double epsilon,
                  int channels,
                  T* inv_variance);
};
// Turns the inverted variance produced by cuDNN forward training back into a
// variance, which TensorFlow needs in order to compute the running variance.
//
// Only T = float is explicitly instantiated (in batch_norm_ops.cu.cc);
// instantiations for other element types must be added there as needed.
template <typename T>
struct InvVarianceToVariance {
  // Works on the single `variance` buffer of `channels` entries on GPU device
  // `d` (presumably converting it in place — confirm in the CUDA source).
  void operator()(const Eigen::GpuDevice& d,
                  double epsilon,
                  int sample_size,
                  int channels,
                  T* variance);
};
// Fills the flat GPU tensor view `out` with NaN values on device `d`.
template <typename T>
struct SetNanFunctor {
  void operator()(const Eigen::GpuDevice& d,
                  typename TTypes<T>::Flat out);
};
// Launches a custom CUDA kernel implementing FusedBatchNorm with a side input
// and an activation for the is_training=False case. Training does not use
// this functor; it relies on cuDNN instead.
//
// `in`, `side_input` and `out` are rank-4 tensors of element type T laid out
// according to `tensor_format`. `scale`, `offset`, `estimated_mean`,
// `estimated_variance` and `epsilon` use type U.
template <typename Device, typename T, typename U>
struct FusedBatchNormInferenceFunctor {
  void operator()(OpKernelContext* context,
                  TensorFormat tensor_format,
                  typename TTypes<T, 4>::ConstTensor in,
                  typename TTypes<U>::ConstVec scale,
                  typename TTypes<U>::ConstVec offset,
                  typename TTypes<U>::ConstVec estimated_mean,
                  typename TTypes<U>::ConstVec estimated_variance,
                  typename TTypes<T, 4>::ConstTensor side_input,
                  U epsilon,
                  FusedBatchNormActivationMode activation_mode,
                  typename TTypes<T, 4>::Tensor out);
};
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Functor used by FusedBatchNormGradOp to do the computations when
// is_training=False. Both CPU and GPU will use this functor.
template <typename Device, typename T, typename U>
struct FusedBatchNormFreezeGrad {
void operator()(const Device& d, const Tensor& y_backprop_input,
const Tensor& x_input, const Tensor& scale_input,
const Tensor& pop_mean_input,
const Tensor& pop_variance_input, U epsilon,
Tensor* x_backprop_output, Tensor* scale_backprop_output,
Tensor* offset_backprop_output,
typename TTypes<U>::Vec scratch1,
typename TTypes<U>::Vec scratch2) {
typename TTypes<T, 4>::ConstTensor y_backprop(
y_backprop_input.tensor<T, 4>());
typename TTypes<T, 4>::ConstTensor input(x_input.tensor<T, 4>());
typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
typename TTypes<U>::ConstVec pop_mean(pop_mean_input.vec<U>());
typename TTypes<U>::ConstVec pop_var(pop_variance_input.vec<U>());
typename TTypes<T, 4>::Tensor x_backprop(x_backprop_output->tensor<T, 4>());
typename TTypes<U>::Vec scale_backprop(scale_backprop_output->vec<U>());
typename TTypes<U>::Vec offset_backprop(offset_backprop_output->vec<U>());
const int depth = pop_mean.dimension(0);
const int rest_size = input.size() / depth;
Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
#if !defined(EIGEN_HAS_INDEX_LIST)
Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
Eigen::array<int, 1> reduction_axis{0};
Eigen::array<int, 2> rest_by_one({rest_size, 1});
#else
Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
one_by_depth.set(1, depth);
Eigen::IndexList<Eigen::type2index<0> > reduction_axis;
Eigen::IndexList<Eigen::Index, Eigen::type2index<1> > rest_by_one;
rest_by_one.set(0, rest_size);
#endif
// offset_backprop = sum(y_backprop)
// scale_backprop = y_backprop * ((x - pop_mean) * rsqrt(pop_var + epsilon))
// x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon))
auto y_backprop_rest_by_depth =
y_backprop.reshape(rest_by_depth).template cast<U>();
auto input_rest_by_depth = input.reshape(rest_by_depth).template cast<U>();
offset_backprop.device(d) = y_backprop_rest_by_depth.sum(reduction_axis);
// scratch1 = rsqrt(pop_var + epsilon)
scratch1.device(d) = (pop_var + pop_var.constant(epsilon)).rsqrt();
// scratch2 = sum(y_backprop * (x - mean))
scratch2.device(d) =
(y_backprop_rest_by_depth *
(input_rest_by_depth -
pop_mean.reshape(one_by_depth).broadcast(rest_by_one)))
.sum(reduction_axis);
x_backprop.reshape(rest_by_depth).device(d) =
(y_backprop_rest_by_depth * ((scratch1 * scale)
.eval()
.reshape(one_by_depth)
.broadcast(rest_by_one)))
.template cast<T>();
scale_backprop.device(d) = scratch2 * scratch1;
}
};
} // namespace functor
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_FUSED_BATCH_NORM_OP_H_