# Understanding the Disharmony between Dropout and Batch Normalization

Why do two of the most commonly used techniques perform worse when applied together?

* Dropout: 
    * Used to avoid overfitting 
    * Simple to implement
    * Widely adopted
* Batch Normalization:
    * Enables faster training / higher learning rates
    * Reduces the dependency on careful initialization
    * Widely adopted


As it turns out, applying dropout before batch normalization leads to a "variance shift" phenomenon, which is the key to why these two techniques should be combined with care. 
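This variance shift comes from dropout behaving differently during training and testing. During training, dropout changes the variance of the activations it feeds into BN, and BN accumulates its moving variance from these perturbed activations. At test time dropout is switched off, so the actual variance of BN's input no longer matches the stored moving variance, and BN normalizes with the wrong statistics.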

Main reference: https://arxiv.org/abs/1801.05134

**The notebook covers the following points:**

1. Dropout
2. Batch normalization
3. Dropout + batch normalization combined, and calculation of the variance shift



In [0]:
import tensorflow as tf
import keras
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns


In [0]:
"""
p is the dropout keep probability, i.e. p = 1 - dropout_rate.
There are two ways to handle the scaling of the activations with dropout:
- Standard dropout: keep the activations unscaled during training and scale
  them down by a factor of p at test time.
- Inverted dropout: scale the kept activations up by a factor of 1/p during
  training and leave them unchanged at test time. Keras uses this variant,
  so the scale factor printed below is 1/p = 1/(1 - dropout_rate).
"""

n = 5  # dimension of the features (k in the paper)
dropout_rate = 0.2
inputs = keras.Input(shape=(n,))
outputs = keras.layers.Dropout(rate=dropout_rate)(inputs)
dropout_model = keras.Model(inputs, outputs)
dropout_model.compile(loss='mse', optimizer='adam')
print(dropout_model.summary())  # summary() prints the table itself and returns None
print(1/(1-dropout_rate))  # inverted-dropout scale factor 1/p applied during training

Model: "model_6"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_6 (InputLayer)         (None, 5)                 0         
_________________________________________________________________
dropout_5 (Dropout)          (None, 5)                 0         
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________
None
1.25


In [0]:
x_test = np.ones((1,n))
pred = dropout_model.predict(x_test)
print(pred)

[[1. 1. 1. 1. 1.]]
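To see what happens in each phase, here is a minimal NumPy sketch of the two scaling conventions described in the docstring above. The functions `standard_dropout` and `inverted_dropout` are hypothetical helpers written for this illustration, not Keras APIs. Both keep the expected activation the same between training and testing; they only differ in where the factor p is applied.

```python
import numpy as np

def standard_dropout(x, keep_prob, training, rng):
    """Standard dropout: zero units during training, scale by p at test time."""
    if training:
        mask = rng.random(x.shape) < keep_prob  # keep each unit with probability p
        return x * mask
    return x * keep_prob  # test time: scale down by p

def inverted_dropout(x, keep_prob, training, rng):
    """Inverted dropout: scale kept units by 1/p in training, identity at test time."""
    if training:
        mask = rng.random(x.shape) < keep_prob
        return x * mask / keep_prob
    return x  # test time: no scaling

rng = np.random.default_rng(0)
keep_prob = 0.8  # p = 1 - dropout_rate
x = np.ones((100_000, 5))

# Average activation over many samples in each phase
train_std = standard_dropout(x, keep_prob, True, rng).mean()
test_std = standard_dropout(x, keep_prob, False, rng).mean()
train_inv = inverted_dropout(x, keep_prob, True, rng).mean()
test_inv = inverted_dropout(x, keep_prob, False, rng).mean()
print(train_std, test_std)  # both close to 0.8
print(train_inv, test_inv)  # both close to 1.0
```

With the inverted variant the test-time layer is the identity, which is consistent with the all-ones prediction above: `predict` runs in inference mode, so dropout does nothing.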




---
# **Batch Normalization case**




In [0]:
inputs = keras.Input(shape=(n,))
outputs = keras.layers.BatchNormalization()(inputs)
bn_model = keras.Model(inputs, outputs)
bn_model.compile(loss='mse', optimizer='adam')
bn_model.summary()


Model: "model_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_4 (InputLayer)         (None, 5)                 0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 5)                 20        
Total params: 20
Trainable params: 10
Non-trainable params: 10
_________________________________________________________________


In [0]:
bn_layer = bn_model.layers[1]
print(bn_layer.get_config()['name'])
print("[gamma(scale), beta(shift), running mean , running variance]\n")
gamma, beta, mean, var = bn_layer.get_weights()
print("gamma (scale) ", gamma)
print("beta (shift) ", beta)
print("moving mean ", mean)
print("moving_variance ", var)
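# At inference: x_bn = gamma * (x - moving_mean) / sqrt(moving_variance + epsilon) + beta
# Since the target equals the input, training should drive beta toward the
# input mean and gamma toward the input standard deviation.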

batch_normalization_1
[gamma(scale), beta(shift), running mean , running variance]

gamma (scale)  [1. 1. 1. 1. 1.]
beta (shift)  [0. 0. 0. 0. 0.]
moving mean  [0. 0. 0. 0. 0.]
moving_variance  [1. 1. 1. 1. 1.]
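The four arrays above are used differently in the two phases. Below is a minimal NumPy sketch of a BatchNormalization-style forward pass, assuming Keras' defaults `momentum=0.99` and `epsilon=0.001`. `batch_norm_forward` is a hypothetical helper, not the Keras implementation: in training it normalizes with the current batch statistics and nudges the moving averages toward them; at inference it uses only the moving averages.

```python
import numpy as np

def batch_norm_forward(x, gamma, beta, moving_mean, moving_var, training,
                       momentum=0.99, epsilon=1e-3):
    """Normalize a batch x of shape (batch, features).

    Returns the output and the (possibly updated) moving statistics.
    """
    if training:
        # Normalize with the statistics of the current batch
        mean = x.mean(axis=0)
        var = x.var(axis=0)  # biased batch variance, used here for simplicity
        # Exponential moving averages that will be used at inference time
        moving_mean = momentum * moving_mean + (1 - momentum) * mean
        moving_var = momentum * moving_var + (1 - momentum) * var
    else:
        # Inference: rely only on the accumulated moving statistics
        mean, var = moving_mean, moving_var
    y = gamma * (x - mean) / np.sqrt(var + epsilon) + beta
    return y, moving_mean, moving_var

rng = np.random.default_rng(0)
n = 5
gamma, beta = np.ones(n), np.zeros(n)
moving_mean, moving_var = np.zeros(n), np.ones(n)  # same initial values as printed above

# Feed many training batches drawn from N(5, 2^2)
for _ in range(2000):
    batch = rng.normal(5.0, 2.0, size=(32, n))
    _, moving_mean, moving_var = batch_norm_forward(
        batch, gamma, beta, moving_mean, moving_var, training=True)

print(moving_mean)  # approaches 5
print(moving_var)   # approaches about 4 (slightly below, since the batch variance is biased)
```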


In [0]:
mu = 5
sigma = 2
x_train = np.random.normal(mu,sigma,size=(10000,n))
print("mean: ", np.mean(x_train,axis=0))
print("std: ", np.std(x_train, axis=0))

mean:  [5.02197837 5.0263782  5.04208676 5.00473362 5.00503994]
std:  [2.0145153  1.99571036 1.99048178 1.98553206 1.98840974]


In [0]:
history = bn_model.fit(x_train,x_train, epochs=50, verbose=0)





In [0]:
gamma, beta, mean, var = bn_layer.get_weights()
print("gamma (scale) ", gamma)
print("beta (shift) ", beta)
print("moving mean ", mean)
print("moving_variance ", var)

gamma (scale)  [1.9716144 1.9501512 1.945204  1.9310975 1.9399271]
beta (shift)  [5.0193267 5.023732  5.041845  5.010377  5.0020723]
moving mean  [4.999528  5.0187097 5.0573626 5.0396338 4.978292 ]
moving_variance  [4.174094  3.9469488 4.035455  3.889659  3.873004 ]


In [0]:
pred = bn_model.predict(np.ones((1,n)))
print(pred)

[[1.1601241 1.0794395 1.1135061 1.0554851 1.0810254]]
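This prediction can be reproduced by hand from the weights printed above, using the inference formula `gamma * (x - moving_mean) / sqrt(moving_variance + epsilon) + beta` and assuming Keras' default `epsilon = 0.001`. A sketch for the first feature with input 1:

```python
import numpy as np

# Values copied from the printed weights (first feature only)
gamma, beta = 1.9716144, 5.0193267
moving_mean, moving_var = 4.999528, 4.174094
epsilon = 1e-3  # assumed Keras default

x = 1.0
y = gamma * (x - moving_mean) / np.sqrt(moving_var + epsilon) + beta
print(y)  # about 1.1601, matching the first entry of the prediction
```

The result is not exactly 1 because the learned gamma and beta only approximately match the standard deviation and mean that BN divides by and subtracts at inference.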







---
# **Dropout + Batch Normalization**


In [0]:
inputs = keras.Input(shape=(n,))
x = keras.layers.Dropout(dropout_rate)(inputs)  # dropout before BN creates the train/test variance discrepancy
outputs = keras.layers.BatchNormalization()(x)
model = keras.Model(inputs, outputs=outputs)
model.compile(loss='mse', optimizer='adam')
model.summary()

Model: "model_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_5 (InputLayer)         (None, 5)                 0         
_________________________________________________________________
dropout_4 (Dropout)          (None, 5)                 0         
_________________________________________________________________
batch_normalization_2 (Batch (None, 5)                 20        
Total params: 20
Trainable params: 10
Non-trainable params: 10
_________________________________________________________________


In [0]:
bn_layer = model.layers[2]
gamma, beta, mean, var = bn_layer.get_weights()
print("gamma (scale) ", gamma)
print("beta (shift) ", beta)
print("moving mean ", mean)
print("moving_variance ", var)

gamma (scale)  [1. 1. 1. 1. 1.]
beta (shift)  [0. 0. 0. 0. 0.]
moving mean  [0. 0. 0. 0. 0.]
moving_variance  [1. 1. 1. 1. 1.]


In [0]:
history = model.fit(x_train, x_train, epochs=100, verbose=0)

In [0]:
pred = model.predict(np.ones((1,n)))
print(pred)

[[3.541892  3.6039495 3.6400888 3.6278784 3.6181204]]


In [0]:
bn_layer = model.layers[2]
gamma, beta, mean, var = bn_layer.get_weights()
print("gamma (scale) ", gamma)
print("beta (shift) ", beta)
print("moving mean ", mean)
print("moving_variance ", var)

gamma (scale)  [1.208375  1.1724666 1.1604689 1.1533899 1.1569933]
beta (shift)  [5.029295  5.030495  5.0373297 5.0059257 5.0011787]
moving mean  [5.126892  5.070994  5.0210896 5.0042624 5.0066295]
moving_variance  [11.2396755 11.194161  11.15252   11.231312  11.233118 ]
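Compared with the BN-only model, the moving variance is now around 11.2 instead of around 4, even though the input data are the same. Plugging the first feature's printed weights into the same inference formula (again assuming `epsilon = 0.001`) reproduces the prediction of about 3.54 and shows that BN divides by a standard deviation of roughly 3.35 at test time, although the test-time input standard deviation is only about 2.

```python
import numpy as np

# First-feature weights copied from the printout of the Dropout -> BN model
gamma, beta = 1.208375, 5.029295
moving_mean, moving_var = 5.126892, 11.2396755
epsilon = 1e-3  # assumed Keras default

x = 1.0
y = gamma * (x - moving_mean) / np.sqrt(moving_var + epsilon) + beta
print(y)  # about 3.5419, matching the first entry of the prediction

# The moving variance was accumulated on dropout-perturbed inputs; at test time
# dropout is off and the input variance is only about sigma**2 = 4.
print(np.sqrt(moving_var + epsilon))  # about 3.35, the std BN divides by
```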


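Let $p$ be the dropout keep probability (`1 - dropout_rate`), and for each feature $x_k$ let

$c = E[x_k]$

$v = Var[x_k]$

The variance of the (inverted) dropout output that BN sees during training, and the variance it actually receives at test time (dropout disabled), are:

$Var^{Train}(x_k) = \frac{1}{p}(c^2 + v) - c^2$

$Var^{Test}(x_k) = v$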
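A short derivation of the training-time variance, assuming inverted dropout with a mask $a_k \sim \text{Bernoulli}(p)$ independent of $x_k$, so the layer outputs $\frac{a_k}{p} x_k$ and $a_k^2 = a_k$:

$$
E\left[\tfrac{a_k}{p} x_k\right] = \tfrac{E[a_k]}{p} E[x_k] = c
$$

$$
E\left[\left(\tfrac{a_k}{p} x_k\right)^2\right] = \tfrac{E[a_k^2]}{p^2} E[x_k^2] = \tfrac{p}{p^2}(c^2 + v) = \tfrac{1}{p}(c^2 + v)
$$

$$
Var^{Train}(x_k) - v = \tfrac{1}{p}(c^2 + v) - c^2 - v = \left(\tfrac{1}{p} - 1\right)(c^2 + v) > 0 \quad \text{for } p < 1
$$

So whenever dropout is active, the variance BN accumulates during training is larger than the variance it receives at test time; this gap is the variance shift.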



In [0]:
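c = np.mean(x_train, axis=0)  # c = E[x_k], per feature
v = np.var(x_train, axis=0)   # v = Var[x_k], per feature
p = 1 - dropout_rate          # keep probability
moving_var_train = 1/p * (c**2 + v) - c**2  # predicted Var^Train(x_k)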
print(moving_var_train)

[11.37790653 11.29469424 11.30818186 11.18976161 11.2048228 ]
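The formula can also be checked without Keras by simulating inverted dropout directly in NumPy on data drawn like `x_train` (mean 5, standard deviation 2) with keep probability p = 0.8. This is a minimal Monte Carlo sketch; the empirical training-time variance should land near the closed-form value of $(25 + 4)/0.8 - 25 = 11.25$.

```python
import numpy as np

rng = np.random.default_rng(0)
p = 0.8            # keep probability = 1 - dropout_rate
c, v = 5.0, 4.0    # mean and variance of the input feature
x = rng.normal(c, np.sqrt(v), size=1_000_000)

# Inverted dropout as applied during training: keep with probability p, scale by 1/p
mask = rng.random(x.shape) < p
x_train_mode = x * mask / p

empirical = x_train_mode.var()
closed_form = (c**2 + v) / p - c**2
print(empirical, closed_form)  # both about 11.25
print(x.var())                 # test-time variance, about 4
```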


In [0]:
"""
An interesting experiment is to try to undo this variance shift: BN already
learns (approximately) the correct mean, and we know the dropout keep
probability p, so the formula above can be inverted:
    v = p * (c^2 + Var_train) - c^2
However, this only works for the direct Dropout -> BN case, not for
Dropout -> Conv -> BN.
Additionally, BN learns a different scale parameter gamma (about 1.2 here,
versus about 1.95 without dropout) that would need to be corrected as well.
Even so, the correction below already reduces the shift.
"""

# Invert Var_train = (c^2 + v) / p - c^2, using the moving mean as an estimate of c
restored_var = (mean**2+var)*(1-dropout_rate) - mean**2
print(restored_var)
bn_layer.set_weights([gamma, beta, mean, restored_var])

[3.7347355 3.812334  3.8797493 3.9765224 3.9732265]


In [0]:
pred = model.predict(np.ones((1,n)))
print(pred)

[[2.4491937 2.5862288 2.6685781 2.690172  2.6758535]]
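The remaining gap comes largely from the scale parameter gamma mentioned above. As a rough hypothetical check (not something the notebook runs), the sketch below takes the first feature's printed values after the variance correction and compares the learned gamma (about 1.21) with a gamma equal to the corrected standard deviation, which is what an identity mapping would need. `bn_inference` is a hypothetical helper and `epsilon = 0.001` is the assumed Keras default.

```python
import numpy as np

# First-feature values: learned gamma/beta, moving mean, and the corrected variance
gamma, beta = 1.208375, 5.029295
moving_mean, restored_var = 5.126892, 3.7347355
epsilon = 1e-3  # assumed Keras default

def bn_inference(x, gamma, beta, mean, var, eps=epsilon):
    """BN at inference time: normalize with fixed statistics, then scale and shift."""
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

x = 1.0
y_learned = bn_inference(x, gamma, beta, moving_mean, restored_var)
print(y_learned)  # about 2.449, matching the prediction above

# Hypothetical: a gamma equal to the corrected std turns BN into x - mean + beta
gamma_hypothetical = np.sqrt(restored_var + epsilon)
y_hyp = bn_inference(x, gamma_hypothetical, beta, moving_mean, restored_var)
print(y_hyp)  # about 0.90, much closer to the target of 1
```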


Guidelines & Recommendations:

1. In modern CNNs, avoid placing dropout before batch norm layers, because of the variance shift phenomenon.
2. The severity of the variance shift depends on the dropout rate and on the feature dimension.
3. Apply dropout only after the last BN layer.
4. Adjusting BN's moving mean and variance by passing the training data through the network in test mode (dropout off) is not enough on its own.
5. A new form of dropout, UDrop, can help.

* The variance shift causes numerical instability at test time that is amplified through the subsequent layers and can lead to misclassification.
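To get a feel for guideline 2, here is a small sketch that tabulates the ratio $Var^{Train}/Var^{Test}$ from the formula above for the notebook's synthetic data (c = 5, v = 4) at several dropout rates. It covers only the direct Dropout -> BN case and does not model the feature-dimension effect mentioned in the guideline; `variance_shift_ratio` is a hypothetical helper.

```python
def variance_shift_ratio(dropout_rate, c, v):
    """Var_train / Var_test for inverted dropout applied directly before BN."""
    p = 1.0 - dropout_rate  # keep probability
    var_train = (c**2 + v) / p - c**2
    return var_train / v

c, v = 5.0, 4.0  # mean and variance of the notebook's synthetic features
ratios = {rate: variance_shift_ratio(rate, c, v) for rate in (0.0, 0.1, 0.2, 0.3, 0.5)}
for rate, ratio in ratios.items():
    print(f"dropout_rate={rate:.1f}  Var_train/Var_test={ratio:.2f}")
```

Because the shift scales with $c^2 + v$, features whose mean is large relative to their spread, as here, are hit hard: a dropout rate of 0.2 already makes the training-time variance about 2.8 times the test-time variance.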


