**Vanishing Gradient**
- Gradients get smaller and smaller during the backward pass
- Earlier layers receive only small parameter updates
- The model stops learning

**Exploding Gradient**
- Gradients get bigger and bigger
- Parameter updates are too large
- Training diverges
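
Both failure modes come from the same mechanism: the gradient that reaches an early layer is roughly a product of one factor per layer it passes through. The toy sketch below illustrates this with a single constant factor per layer. The helper `backprop_scale` and the chosen factors are illustrative assumptions, not part of any library.

```python
# Toy model: the gradient reaching an early layer is (roughly) a product of
# one factor per layer it passes through on the way back.
def backprop_scale(per_layer_factor: float, num_layers: int) -> float:
    """Return the overall scaling after multiplying one factor per layer."""
    scale = 1.0
    for _ in range(num_layers):
        scale *= per_layer_factor
    return scale

# 0.25 is the maximum slope of the sigmoid, so this is a best case for a
# chain of sigmoids with unit weights (an illustrative assumption).
vanishing = backprop_scale(0.25, 10)
# A factor above 1 (hypothetical value) grows just as quickly.
exploding = backprop_scale(4.0, 10)

print(f"factor 0.25 over 10 layers: {vanishing:.2e}")
print(f"factor 4.0 over 10 layers: {exploding:.2e}")
```

Real networks have a different factor at every layer, but the same compounding applies, which is why depth makes both problems worse.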

**Solutions to unstable gradients**
- Proper weight initialization
- Suitable activation functions
- Batch normalization


In [2]:
import torch
import torch.nn as nn
import torch.nn.init as init

In [3]:
layer = nn.Linear(8,1)
print(layer.weight)

Parameter containing:
tensor([[ 0.0680,  0.0569, -0.3379, -0.2509, -0.0758, -0.0468, -0.1363, -0.1181]],
       requires_grad=True)


In [4]:
# Weight initialization
# He/Kaiming initialization is a good choice for layers followed by ReLU
init.kaiming_uniform_(layer.weight)
print(layer.weight)

Parameter containing:
tensor([[ 0.7066, -0.1769, -0.0929,  0.8220, -0.2240,  0.4938,  0.5241, -0.5136]],
       requires_grad=True)
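
To see where the new weight range comes from, the sketch below recomputes the sampling bound by hand. It assumes the formula given in the PyTorch documentation for `kaiming_uniform_` with default arguments (fan-in mode, `a=0`); the helper name `kaiming_uniform_bound` is made up for this note.

```python
import math

def kaiming_uniform_bound(fan_in: int, negative_slope: float = 0.0) -> float:
    """Bound b such that weights are drawn from U(-b, b).

    Assumed formula (from the PyTorch docs for kaiming_uniform_):
    gain = sqrt(2 / (1 + negative_slope**2)), bound = gain * sqrt(3 / fan_in).
    """
    gain = math.sqrt(2.0 / (1.0 + negative_slope ** 2))
    return gain * math.sqrt(3.0 / fan_in)

bound = kaiming_uniform_bound(fan_in=8)  # nn.Linear(8, 1) has fan_in = 8
print(round(bound, 4))

# Weights printed above after init.kaiming_uniform_(layer.weight)
printed = [0.7066, -0.1769, -0.0929, 0.8220, -0.2240, 0.4938, 0.5241, -0.5136]
print(all(abs(w) <= bound for w in printed))
```

The bound shrinks as `fan_in` grows, which keeps the variance of a layer's outputs roughly constant regardless of how many inputs feed into it.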


In [5]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(9,16)
        self.fc2 = nn.Linear(16,8)
        self.fc3 = nn.Linear(8,1)

        init.kaiming_uniform_(self.fc1.weight)
        init.kaiming_uniform_(self.fc2.weight)
        init.kaiming_uniform_(
            self.fc3.weight,
            nonlinearity="sigmoid",
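        )

    def forward(self, x):
        # Assumed forward pass (not in the original notes): ReLU after fc1 and
        # fc2, sigmoid on the output, matching the initializations above
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        return torch.sigmoid(self.fc3(x))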


In [7]:
# Activation functions
"""
ReLU: zero for negative inputs - suffers from dying neuron problem. 

ELU: Non-zero gradients for negative values - helps against dying neurons
Average output around zero - helps against vanishing gradients
"""

# nn.functional.relu()
# nn.functional.elu()
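
The difference between the two activations is easiest to see in their slopes for negative inputs. The following is a minimal plain-Python sketch, not the PyTorch implementation; `relu_grad` and `elu_grad` are hypothetical helper names, and `alpha=1.0` is assumed to match the default of `nn.functional.elu`.

```python
import math

def relu_grad(x: float) -> float:
    """Slope of ReLU: 1 for positive inputs, 0 otherwise."""
    return 1.0 if x > 0 else 0.0

def elu_grad(x: float, alpha: float = 1.0) -> float:
    """Slope of ELU: 1 for positive inputs, alpha * exp(x) otherwise."""
    return 1.0 if x > 0 else alpha * math.exp(x)

for x in (-2.0, -0.5, 1.0):
    print(f"x={x:+.1f}  relu'={relu_grad(x):.3f}  elu'={elu_grad(x):.3f}")
```

For negative inputs the ReLU slope is exactly zero, so no update flows back through that neuron, while the ELU slope stays small but positive.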


In [None]:
"""
Batch normalization:
Applied after a layer:
1. Normalize the layer's outputs over the batch by:
    a) Subtracting the mean
    b) Dividing by the standard deviation
2. Scale and shift the normalized outputs using learned parameters

Benefits:
    a) Faster loss decrease
    b) Helps against unstable gradients

"""
class Net2(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(9,16)
        self.bn1 = nn.BatchNorm1d(16)

    def forward(self,x):
        x = self.fc1(x)
        x = self.bn1(x)  # normalize fc1's outputs over the batch
        x = nn.functional.elu(x)
        return x
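
The two batch normalization steps described in the docstring can be sketched for a single feature in plain Python. This is an illustration under stated assumptions, not what `nn.BatchNorm1d` runs internally: `batch_norm_1d` is a hypothetical helper, `gamma` and `beta` stand in for the learned scale and shift, and `eps=1e-5` is assumed to match the PyTorch default.

```python
import math

def batch_norm_1d(values, gamma=1.0, beta=0.0, eps=1e-5):
    """Normalize one feature across a batch, then scale and shift."""
    mean = sum(values) / len(values)
    # Biased variance (divide by N), as used for normalization during training
    var = sum((v - mean) ** 2 for v in values) / len(values)
    # Step 1: subtract the mean, divide by the standard deviation
    normalized = [(v - mean) / math.sqrt(var + eps) for v in values]
    # Step 2: scale and shift with the (normally learned) parameters
    return [gamma * n + beta for n in normalized]

batch = [2.0, 4.0, 6.0, 8.0]  # hypothetical outputs of one unit for 4 samples
out = batch_norm_1d(batch)
print([round(o, 3) for o in out])
```

With `gamma=1.0` and `beta=0.0` the output has mean close to zero and variance close to one; during training the layer learns other values when a different scale or offset works better.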