# Batch Normalization in Deep Neural Networks

Training deep neural networks can be unstable. One proposed cause is **internal covariate shift**: the distribution of each layer's inputs keeps changing during training as the parameters of the preceding layers are updated.

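**Batch Normalization (BatchNorm)** helps by normalizing each layer's activations over the current mini-batch (subtracting the batch mean and dividing by the batch standard deviation), then applying a learned scale and shift. In practice this often leads to:
- Faster convergence
- More stable training
- A mild regularizing effect that can reduce overfitting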
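
The normalization step described above can be sketched in plain Python so that the arithmetic is visible. This is a minimal illustration of the training-time forward pass only, not how Keras implements the layer: the function name `batch_norm_forward`, the toy batch, and the fixed `gamma`/`beta` values are assumptions made for this example (in a real layer, `gamma` and `beta` are learned).

```python
import math

def batch_norm_forward(batch, gamma, beta, eps=1e-5):
    """Normalize each feature over the batch, then scale and shift.

    batch: list of samples, each a list of feature values.
    gamma, beta: per-feature scale and shift (fixed here, learned in a real layer).
    eps: small constant that keeps the division stable when variance is near 0.
    """
    n = len(batch)
    num_features = len(batch[0])
    out = [[0.0] * num_features for _ in range(n)]
    for j in range(num_features):
        column = [sample[j] for sample in batch]
        mean = sum(column) / n
        var = sum((v - mean) ** 2 for v in column) / n  # biased batch variance
        inv_std = 1.0 / math.sqrt(var + eps)
        for i in range(n):
            out[i][j] = gamma[j] * (column[i] - mean) * inv_std + beta[j]
    return out

# Toy batch (hypothetical values): two features on very different scales.
batch = [[1.0, 200.0], [2.0, 400.0], [3.0, 600.0]]
normalized = batch_norm_forward(batch, gamma=[1.0, 1.0], beta=[0.0, 0.0])
print(normalized)
```

With `gamma = 1` and `beta = 0`, each feature column of `normalized` has a mean of roughly 0 and a variance of roughly 1, regardless of the original scale of that feature.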

In [None]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, BatchNormalization
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt

print("TensorFlow version:", tf.__version__)

## Load and preprocess dataset

In [None]:
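# Load MNIST: 28x28 grayscale images of handwritten digits with integer labels 0-9
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Scale pixel values from [0, 255] to [0, 1]
x_train, x_test = x_train / 255.0, x_test / 255.0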

print("Training data:", x_train.shape)
print("Test data:", x_test.shape)

## Model without Batch Normalization

In [None]:
model_no_bn = Sequential([
    Flatten(input_shape=(28,28)),
    Dense(256, activation='relu'),
    Dense(128, activation='relu'),
    Dense(10, activation='softmax')
])

model_no_bn.compile(optimizer='adam',
                    loss='sparse_categorical_crossentropy',
                    metrics=['accuracy'])

history_no_bn = model_no_bn.fit(x_train, y_train, epochs=5, batch_size=32,
                                validation_data=(x_test, y_test), verbose=0)

## Model with Batch Normalization

In [None]:
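# Same architecture as model_no_bn, with a BatchNormalization layer after each
# hidden Dense layer. Here BatchNorm is applied after the ReLU activation;
# applying it before the activation is a common alternative.
model_bn = Sequential([
    Flatten(input_shape=(28,28)),
    Dense(256, activation='relu'),
    BatchNormalization(),
    Dense(128, activation='relu'),
    BatchNormalization(),
    Dense(10, activation='softmax')
])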

model_bn.compile(optimizer='adam',
                 loss='sparse_categorical_crossentropy',
                 metrics=['accuracy'])

history_bn = model_bn.fit(x_train, y_train, epochs=5, batch_size=32,
                          validation_data=(x_test, y_test), verbose=0)

## Compare Validation Accuracy

In [None]:
plt.plot(history_no_bn.history['val_accuracy'], label='Without BatchNorm')
plt.plot(history_bn.history['val_accuracy'], label='With BatchNorm')
plt.title('Effect of Batch Normalization')
plt.xlabel('Epochs')
plt.ylabel('Validation Accuracy')
plt.legend()
plt.show()
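
The curves can also be compared numerically. The helper below is a small sketch, assuming a dictionary shaped like `history.history` with a `'val_accuracy'` list; the name `summarize_history` and the placeholder numbers are hypothetical and are not results from this notebook.

```python
def summarize_history(name, history_dict):
    """Return (name, final val accuracy, best val accuracy, best epoch)."""
    val_acc = history_dict["val_accuracy"]
    # Index of the epoch with the highest validation accuracy (0-based).
    best_index = max(range(len(val_acc)), key=lambda i: val_acc[i])
    return name, val_acc[-1], val_acc[best_index], best_index + 1

# Placeholder values for illustration only, NOT measured results.
# In the notebook, pass history_no_bn.history and history_bn.history instead.
placeholder = {"val_accuracy": [0.90, 0.93, 0.92]}
name, final_acc, best_acc, best_epoch = summarize_history("example", placeholder)
print(f"{name}: final={final_acc:.4f}, best={best_acc:.4f} (epoch {best_epoch})")
```

Because the weights are initialized randomly and no seed is set, a single 5-epoch run is a noisy comparison; repeating the runs and comparing these summaries gives a more reliable picture.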

## Key Takeaways
- BatchNorm often stabilizes and speeds up training.
- The noise in per-batch statistics acts as a slight regularizer, which can help reduce overfitting.
- Often allows use of higher learning rates.
- Widely used in CNNs and feed-forward networks; RNNs and Transformers more commonly use Layer Normalization.