forked from wiibrew/DeepLearningCourseCodes
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbatch_norm.py
62 lines (50 loc) · 1.85 KB
/
batch_norm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""Batch Normalization for TensorFlow.
Parag K. Mital, Jan 2016.
"""
import tensorflow as tf
def batch_norm(x, phase_train, scope='bn', affine=True, decay=0.9, epsilon=1e-3):
    """Batch normalization over every axis of ``x`` except the last.

    Adapted from:
    https://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow
    and modified to infer the channel count from the input tensor ``x``.

    When ``phase_train`` is True, the current batch mean/variance are used to
    normalize and are folded into exponential moving averages. When it is
    False, the stored moving averages are used and are not updated.

    Parameters
    ----------
    x
        Tensor whose rank and last dimension are statically known, e.g. 4D
        BHWD convolutional maps or 2D (batch, features) activations. Mean and
        variance are computed over every axis except the last.
    phase_train
        Scalar boolean tensor (e.g. a ``tf.Variable`` or ``tf.placeholder``)
        used as the ``tf.cond`` predicate; True selects the training branch.
    scope
        string, variable scope
    affine
        If True, outputs are scaled by the trainable per-channel ``gamma``.
        If False, ``gamma`` is neither applied nor trainable. The per-channel
        ``beta`` offset is always added and is always trainable.
    decay
        Decay of the exponential moving averages of mean and variance.
    epsilon
        Small constant added to the variance for numerical stability.

    Returns
    -------
    normed
        Batch-normalized tensor with the same shape as ``x``.
    """
    with tf.variable_scope(scope):
        shape = x.get_shape().as_list()
        n_out = shape[-1]
        # Reduce over all axes but the channel axis: [0, 1, 2] for 4D input,
        # [0] for 2D input.
        axes = list(range(len(shape) - 1))

        beta = tf.Variable(tf.constant(0.0, shape=[n_out]),
                           name='beta', trainable=True)
        gamma = tf.Variable(tf.constant(1.0, shape=[n_out]),
                            name='gamma', trainable=affine)

        batch_mean, batch_var = tf.nn.moments(x, axes, name='moments')
        ema = tf.train.ExponentialMovingAverage(decay=decay)

        def mean_var_with_update():
            """Return the batch statistics, updating their moving averages.

            ``ema.apply`` is created inside this branch so the averages are
            only updated when the training branch is taken.
            """
            ema_apply_op = ema.apply([batch_mean, batch_var])
            with tf.control_dependencies([ema_apply_op]):
                # tf.identity creates new ops that carry the control
                # dependency, so the EMA update runs before they are read.
                return tf.identity(batch_mean), tf.identity(batch_var)

        def mean_var_from_ema():
            """Return the moving-average statistics without updating them.

            ``ema.average`` returns None until ``ema.apply`` has created the
            shadow variables, so it must be called lazily here rather than
            before ``tf.cond``: TF1 ``tf.cond`` builds the true branch (which
            calls ``ema.apply``) before this false branch.
            """
            return ema.average(batch_mean), ema.average(batch_var)

        mean, var = tf.cond(phase_train,
                            mean_var_with_update,
                            mean_var_from_ema)

        # Same computation as the deprecated
        # tf.nn.batch_norm_with_global_normalization(
        #     x, mean, var, beta, gamma, epsilon, affine).
        normed = tf.nn.batch_normalization(
            x, mean, var, beta, gamma if affine else None, epsilon)
    return normed