In [19]:
import numpy as np

In [20]:
# Notebook-wide mode switch read by batch_norm when its `training` argument is
# left as None: False -> inference mode, True -> training mode.
grad_enabled = False


def batch_norm(X, gamma, beta, moving_mean, moving_var, eps=1e-5, momentum=0.9, training=None):
    """Apply batch normalization to channels-last input.

    In inference mode the input is normalized with the supplied moving
    statistics. In training mode it is normalized with the statistics of the
    current batch, and the moving statistics are updated as an exponential
    moving average.

    Parameters
    ----------
    X : array_like
        Input data with the channel/feature axis last, e.g. (N, H, W, C) or
        (N, C). Statistics are reduced over every axis except the last.
    gamma : array_like
        Scale parameter, broadcastable against X.
    beta : array_like
        Shift parameter, broadcastable against X.
    moving_mean : array_like
        Moving average of the mean.
    moving_var : array_like
        Moving average of the variance.
    eps : float, optional
        Small value added to the variance to prevent division by zero.
    momentum : float, optional
        Weight given to the current batch statistic in the moving-average
        update: ``new = (1 - momentum) * old + momentum * batch``. Note that
        with the default of 0.9 the newest batch dominates; pass a small value
        (e.g. 0.1) for slowly changing averages.
    training : bool or None, optional
        True selects training mode, False selects inference mode. None (the
        default) falls back to the module-level ``grad_enabled`` flag.

    Returns
    -------
    tuple
        ``(Y, moving_mean, moving_var)``: the normalized, scaled and shifted
        input, followed by the moving statistics (updated in training mode,
        returned unchanged in inference mode).

    Raises
    ------
    ValueError
        In training mode, if X has fewer than two dimensions (there is no
        batch axis to reduce over).
    """
    if training is None:
        # Preserve the original behaviour: the global flag decides the mode.
        training = grad_enabled

    if not training:
        # Inference mode: normalize with the moving averages of mean and variance.
        X_hat = (X - moving_mean) / np.sqrt(moving_var + eps)
    else:
        ndim = np.ndim(X)
        if ndim < 2:
            raise ValueError(
                "batch_norm expects input with at least 2 dimensions in training mode, got %d" % ndim
            )
        # Reduce over every axis except the trailing channel axis; for
        # 4-D input this is (0, 1, 2).
        reduce_axes = tuple(range(ndim - 1))

        mean = np.mean(X, axis=reduce_axes, keepdims=True)
        # Variance is the mean of the *squared* deviations from the mean.
        var = np.mean((X - mean) ** 2, axis=reduce_axes, keepdims=True)

        X_hat = (X - mean) / np.sqrt(var + eps)

        # Exponential moving averages: both updates are weighted sums of the
        # previous value and the current batch statistic.
        moving_mean = (1.0 - momentum) * moving_mean + momentum * mean
        moving_var = (1.0 - momentum) * moving_var + momentum * var

    # Apply scale (gamma) and shift (beta) to the normalized input.
    Y = gamma * X_hat + beta
    return Y, moving_mean, moving_var

In [21]:
np.random.seed(0)  # fixed seed so the generated samples are reproducible
N = 3

# Draw N standard-normal samples of shape (32, 32, 3), one after another, and
# stack them along a new leading axis.
X = np.stack([np.random.randn(32, 32, 3) for _ in range(N)])
X.shape


(3, 32, 32, 3)

In [22]:
# Run batch_norm on X with identity scale/shift (gamma = 1, beta = 0) and
# initial moving statistics of mean 0 and variance 1. Every parameter array
# is created with the same shape as X.
out = batch_norm(
    X=X,
    gamma=np.ones(X.shape),
    beta=np.zeros(X.shape),
    moving_mean=np.zeros(X.shape),
    moving_var=np.ones(X.shape),
    eps=1e-5,
    momentum=0.9,
)
# The result is a 3-tuple: normalized output, then the moving mean and variance.
bn_out, moving_mean, moving_var = out
bn_out


array([[[[ 1.76404353,  0.40015521,  0.97873309],
         [ 2.24088199,  1.86754865, -0.97727299],
         [ 0.95008367, -0.15135645, -0.10321834],
         ...,
         [-0.17992394, -1.07074727,  1.05444645],
         [-0.40317493,  1.22243896,  0.20827394],
         [ 0.97663415,  0.35636462,  0.70656964]],

        [[ 0.01049997,  1.78586156,  0.12691146],
         [ 0.40198735,  1.88314128, -1.34775232],
         [-1.27047865,  0.96939186, -1.17311754],
         ...,
         [-2.22339204,  0.62522832, -1.60204965],
         [-1.10437782,  0.05216482, -0.7395593 ],
         [ 1.54300688, -1.29285045,  0.26704953]],

        [[-0.03928262, -1.16808766,  0.52327404],
         [-0.17154547,  0.77178669,  0.82350004],
         [ 2.16322513,  1.33652127, -0.36917999],
         ...,
         [ 2.06448254, -0.1105401 ,  1.02016761],
         [-0.69204639,  1.53636937,  0.28634226],
         [ 0.60884079, -1.04524814,  1.21113923]],

        ...,

        [[-0.74501949,  1.01225545, -1