In [5]:
import numpy as np

In [6]:
# def batch_norm_2d(x, gamma, beta, eps):
#     N, C, H, W = x.shape

#     # Compute mean and variance
#     mu = np.mean(x, axis=(0, 2, 3)).reshape(1, C, 1, 1)
#     var = np.var(x, axis=(0, 2, 3)).reshape(1, C, 1, 1)

#     # Normalize and scale
#     x_norm = (x - mu) / np.sqrt(var + eps)
#     out = gamma.reshape(1, C, 1, 1) * x_norm + beta.reshape(1, C, 1, 1)

#     # Cache values needed for backward pass
#     cache = (x, x_norm, mu, var, gamma, beta, eps)

#     return out, cache


def np_batch_norm_2d_backward(dout, x, gamma, eps=1e-5):
    """Backward pass of 2-D batch normalization, using batch statistics.

    Nothing is read from a forward-pass cache: the per-channel mean and
    (biased) variance are recomputed from ``x`` over axes ``(0, 2, 3)``, i.e.
    the forward pass is taken to be
    ``gamma * (x - mean) / sqrt(var + eps) + beta``.

    Parameters
    ----------
    dout : np.ndarray
        Upstream gradient with respect to the layer output; expected to have
        the same shape as ``x``.
    x : np.ndarray, shape (N, C, H, W)
        Input of the forward pass.
    gamma : array_like
        Per-channel scale. Either already broadcastable against ``x`` (e.g.
        shape ``(1, C, 1, 1)``) or a flat vector of shape ``(C,)``, which is
        reshaped to ``(1, C, 1, 1)`` so that it is applied along the channel
        axis rather than broadcast along ``W``.
    eps : float, optional
        Constant added to the variance for numerical stability.

    Returns
    -------
    dx : np.ndarray
        Gradient with respect to ``x``.
    dgamma : np.ndarray, shape (1, C, 1, 1)
        Gradient with respect to the scale.
    dbeta : np.ndarray, shape (1, C, 1, 1)
        Gradient with respect to the shift.

    Raises
    ------
    ValueError
        If ``x`` is not 4-dimensional (raised by the shape unpacking).
    """
    N, C, H, W = x.shape
    reduce_axes = (0, 2, 3)
    # Number of elements each channel statistic is averaged over.
    m = N * H * W

    # A flat (C,) gamma would otherwise broadcast against the trailing W axis:
    # a shape error when W != C and a silently wrong result when W == C.
    gamma = np.asarray(gamma)
    if gamma.ndim == 1 and gamma.size == C:
        gamma = gamma.reshape(1, C, 1, 1)

    # Recompute the forward-pass statistics (biased variance, as in np.var).
    mean = np.mean(x, axis=reduce_axes, keepdims=True)
    x_centered = x - mean
    var = np.mean(x_centered ** 2, axis=reduce_axes, keepdims=True)

    std = np.sqrt(var + eps)
    istd = 1.0 / std
    x_norm = x_centered / std

    # Parameter gradients: reduce over everything except the channel axis.
    dgamma = np.sum(dout * x_norm, axis=reduce_axes, keepdims=True)
    dbeta = np.sum(dout, axis=reduce_axes, keepdims=True)

    # Chain rule through x_norm = (x - mean) * istd.
    dx_norm = dout * gamma
    dvar = np.sum(dx_norm * x_centered * (-0.5) * istd ** 3,
                  axis=reduce_axes, keepdims=True)
    # The second term is analytically zero (x_centered averages to 0 per
    # channel); it is kept so the expression mirrors the full chain rule.
    dmean = (
        np.sum(dx_norm * (-istd), axis=reduce_axes, keepdims=True)
        + dvar * np.mean(-2.0 * x_centered, axis=reduce_axes, keepdims=True)
    )
    dx = dx_norm * istd + dvar * 2.0 * x_centered / m + dmean / m

    return dx, dgamma, dbeta



In [8]:
# Smoke test for the backward pass on small random tensors.
# A seeded generator keeps the cell reproducible under Restart & Run All.
N, C, H, W = 2, 3, 4, 5
X_shape = (N, C, H, W)
rng = np.random.default_rng(0)
X = rng.standard_normal(X_shape)
# `out` plays the role of the upstream gradient (dout) fed to the backward pass.
out = rng.standard_normal(X_shape)
gamma = rng.standard_normal((1, C, 1, 1))
eps = 1e-5

dx, dgamma, dbeta = np_batch_norm_2d_backward(out, X, gamma, eps)

# Inline invariants: the input gradient matches the input, and the parameter
# gradients keep one entry per channel.
assert dx.shape == X_shape
assert dgamma.shape == (1, C, 1, 1)
assert dbeta.shape == (1, C, 1, 1)

In [10]:
# The backward pass reduces with keepdims=True, so dgamma keeps singleton axes: (1, C, 1, 1).
dgamma.shape

(1, 3, 1, 1)