In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer

In [None]:
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps):
    """Normalize ``X`` with the given statistics, then scale and shift.

    Computes, elementwise (with broadcasting)::

        Y = gamma * (X - moving_mean) / sqrt(moving_var + eps) + beta

    Args:
        X: Input tensor; its dtype determines the dtype of the computation.
        gamma: Scale factor applied after normalization.
        beta: Offset added after scaling.
        moving_mean: Mean subtracted from ``X``.
        moving_var: Variance used for normalization.
        eps: Small constant added to the variance before the square root
            so the division never hits zero.

    Returns:
        The normalized, scaled and shifted tensor.
    """
    # Cast every operand to X's dtype so the arithmetic below does not fail
    # when the statistics/parameters are stored in a different precision
    # than the activations. For matching dtypes the casts are no-ops.
    mean = tf.cast(moving_mean, X.dtype)
    # Standard deviation (NOT its reciprocal): X is divided by this value.
    std = tf.cast(tf.math.sqrt(moving_var + eps), X.dtype)
    scale = tf.cast(gamma, X.dtype)
    offset = tf.cast(beta, X.dtype)
    # Scale and shift | y = gamma * x_hat + beta
    return scale * ((X - mean) / std) + offset

In [None]:
class BatchNorm(tf.keras.layers.Layer):
    """Batch normalization over every axis except the last (feature) axis.

    In training mode the layer normalizes with the statistics of the current
    batch and folds them into running estimates; otherwise it normalizes
    with the running estimates.

    Args:
        momentum: Weight given to the *new* batch statistic in the running
            average: ``new = (1 - momentum) * old + momentum * batch``.
        epsilon: Small constant added to the variance for numerical
            stability.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, momentum=0.1, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.momentum = momentum
        self.epsilon = epsilon

    def build(self, input_shape):
        """Create one scale, offset, mean and variance entry per feature."""
        weight_shape = [input_shape[-1]]

        # Learnable scale (starts at 1) and offset (starts at 0).
        self.gamma = self.add_weight(
            name='gamma',
            shape=weight_shape,
            initializer=tf.initializers.ones,
            trainable=True,
        )
        self.beta = self.add_weight(
            name='beta',
            shape=weight_shape,
            initializer=tf.initializers.zeros,
            trainable=True,
        )
        # Running statistics: updated in call(), never by the optimizer.
        self.moving_mean = self.add_weight(
            name='moving_mean',
            shape=weight_shape,
            initializer=tf.initializers.zeros,
            trainable=False,
        )
        self.moving_variance = self.add_weight(
            name='moving_variance',
            shape=weight_shape,
            initializer=tf.initializers.ones,
            trainable=False,
        )

        super().build(input_shape)

    def assign_moving_average(self, variable, value):
        """Blend ``value`` into ``variable`` and return the assign result."""
        updated = (1.0 - self.momentum) * variable + self.momentum * value
        return variable.assign(updated)

    @tf.function
    def call(self, inputs, training=None):
        """Normalize ``inputs``.

        Args:
            inputs: Tensor whose last axis holds the features.
            training: Truthy to use (and record) batch statistics; falsy or
                ``None`` to use the running statistics.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        if training:
            # Reduce over every axis but the last one (the feature axis).
            axes = list(range(len(inputs.shape) - 1))
            batch_mean = tf.reduce_mean(inputs, axes, keepdims=True)
            # The mean is held constant inside the squared difference.
            batch_variance = tf.reduce_mean(tf.math.squared_difference(
                inputs, tf.stop_gradient(batch_mean)), axes, keepdims=True)
            # Drop the reduced axes so the statistics match the weight shape.
            batch_mean = tf.squeeze(batch_mean, axes)
            batch_variance = tf.squeeze(batch_variance, axes)
            mean_update = self.assign_moving_average(
                self.moving_mean, batch_mean)
            variance_update = self.assign_moving_average(
                self.moving_variance, batch_variance)
            # NOTE(review): add_update may be deprecated or unavailable in
            # newer Keras versions - confirm against the targeted TF release.
            self.add_update(mean_update)
            self.add_update(variance_update)
            mean, variance = batch_mean, batch_variance
        else:
            mean, variance = self.moving_mean, self.moving_variance
        output = batch_norm(inputs, moving_mean=mean, moving_var=variance,
                            beta=self.beta, gamma=self.gamma,
                            eps=self.epsilon)
        return output

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({'momentum': self.momentum, 'epsilon': self.epsilon})
        return config
    
     

# Example usage. The input must exist before any layer is applied to it.
inputs = tf.keras.Input(shape=(64,))       # data: 64 features

# Creating the layer first and then calling it ...
bn_layer = BatchNorm()
x = bn_layer(inputs)
# ... is equivalent to the one-line form: x = BatchNorm()(inputs)

# Typical placement: normalize the output of a fully connected layer.
h = tf.keras.layers.Dense(128)(inputs)     # fully connected layer output
x = BatchNorm()(h)

# Input shape: (batch_size, n_features); statistics are computed per feature, i.e. over every axis except the last.