## Further Reading

##### Further information on batch normalization and how it works: https://towardsdatascience.com/batch-norm-explained-visually-how-it-works-and-why-neural-networks-need-it-b18919692739

In [None]:
from typing import Optional

import numpy as np
import tensorflow as tf

## Experimentation with tf.nn.batch_normalization

##### The input tensor can be of any shape. For context, let's assume it is a 4-D feature map of shape (batch, height, width, channels), such as the output of a conv2d layer

In [None]:
# Random 4-D input of shape (1, 6, 6, 2), stored as a float32 TensorFlow variable.
x = tf.Variable(np.random.randn(1, 6, 6, 2), dtype=tf.float32)

##### Get mean and variance from tensor

In [None]:
# tf.nn.moments returns (mean, variance); note that `std_x` therefore holds
# the variance, not the standard deviation, despite its name.
mean_x, std_x = tf.nn.moments(
    x,
    axes=2,
    keepdims=True,
)
(mean_x, std_x)

##### Initialize offset and scale with the same shape as the input tensor; both are trainable and would be adjusted by backpropagation during model training

In [None]:
# Trainable additive term (beta), one value per element of `x`.
offset = tf.Variable(
    initial_value=tf.random.normal(shape=x.shape, stddev=0.1),
    trainable=True,
    dtype=tf.float32,
)
offset

In [None]:
# Trainable multiplicative term (gamma), one value per element of `x`.
scale = tf.Variable(
    initial_value=tf.random.normal(shape=x.shape, stddev=0.1),
    trainable=True,
    dtype=tf.float32,
)
scale

##### Pass the above parameters to `tf.nn.batch_normalization`

In [None]:
# Normalize `x` with the moments computed above, then apply scale and offset.
# Positional order: x, mean, variance, offset, scale.
batch_layer = tf.nn.batch_normalization(
    x, mean_x, std_x, offset, scale, variance_epsilon=1e-12
)
batch_layer

## Build Custom Batch Normalization Layer

### Implement the layer class

In [None]:
class BatchNormalization(tf.Module):
    """Batch-normalization layer with lazily created, trainable offset and scale.

    On the first call the layer creates two trainable variables, ``offset``
    and ``scale``, shaped like that first input and initialised from a normal
    distribution with standard deviation 0.1. Every call normalises the input
    with the mean and variance computed over axis 2 (``keepdims=True``) and
    then applies ``scale`` and ``offset`` via ``tf.nn.batch_normalization``.

    Because the variables take the full shape of the first input, inputs
    passed on later calls are expected to have a compatible shape.
    """

    def __init__(self,
                 name: Optional[str] = None):
        """Create an unbuilt layer.

        Args:
            name: Optional module name forwarded to ``tf.Module``.
        """
        super().__init__(name=name)

        # Created on the first call, once the input shape is known.
        self.offset: Optional[tf.Variable] = None
        self.scale: Optional[tf.Variable] = None

        self.is_built: bool = False

    def _build(self, input_shape) -> None:
        """Create the trainable offset and scale variables for ``input_shape``."""
        self.offset = tf.Variable(
            tf.random.normal(input_shape, stddev=0.1),
            trainable=True,
            dtype=tf.float32,
            name="offset_layer_weights"
        )
        self.scale = tf.Variable(
            tf.random.normal(input_shape, stddev=0.1),
            trainable=True,
            dtype=tf.float32,
            name="scale_layer_weights"
        )
        self.is_built = True

    def __call__(self, x_in):
        """Normalise ``x_in`` and apply the learned scale and offset.

        Args:
            x_in: Input tensor; it must have an axis 2 to compute moments over.

        Returns:
            The result of ``tf.nn.batch_normalization`` applied to ``x_in``.
        """
        if not self.is_built:
            # Size the weights from the actual input rather than from a
            # notebook-level global, so the layer works on any tensor.
            self._build(x_in.shape)

        # tf.nn.moments returns (mean, variance), not (mean, std).
        mean_x, variance_x = tf.nn.moments(x_in, axes=2, keepdims=True)

        return tf.nn.batch_normalization(x=x_in,
                                         mean=mean_x,
                                         variance=variance_x,
                                         offset=self.offset,
                                         scale=self.scale,
                                         variance_epsilon=1e-12
                                         )

In [None]:
# Instantiate the custom layer; its weights are created on the first call.
# NOTE: the variable name is misspelled ("normalizatio") but is kept as-is
# because the following cell refers to it by this exact name.
batch_normalizatio_layer = BatchNormalization()

In [None]:
# Run the custom layer on the sample tensor; the first call builds its weights.
batch_normalizatio_layer(x)