Weight Initialization in Neural Networks: A Journey From the Basics to Kaiming
https://towardsdatascience.com/weight-initialization-in-neural-networks-a-journey-from-the-basics-to-kaiming-954fb9b47c79

For a quick-and-dirty example of why weight initialization matters, let’s pretend that we have a vector x that contains some network inputs. It’s standard practice when training neural networks to scale the inputs so that they have a mean of 0 and a standard deviation of 1. Here we simply draw x from a standard normal distribution, which has exactly those properties.
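Below is a minimal sketch of that standardization step in plain Python with the standard library. The raw feature values are made up for illustration. With a PyTorch tensor the equivalent would be `(x - x.mean()) / x.std()`. That version uses the sample rather than the population standard deviation, a negligible difference for 512 values.

```python
import random
import statistics

random.seed(0)

# Hypothetical raw feature values: centered around 35, not unit-scaled.
raw = [random.uniform(10.0, 60.0) for _ in range(512)]

# Standardize: subtract the mean, then divide by the standard deviation.
mu = statistics.fmean(raw)
sigma = statistics.pstdev(raw)
x = [(v - mu) / sigma for v in raw]

print(statistics.fmean(x), statistics.pstdev(x))
```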

import torch

x = torch.randn(512)

x[:10]

tensor([ 0.7684,  0.5141,  1.1430,  1.1872,  0.3943, -0.8984,  0.3526,  0.5783,
        -0.8831,  1.0569])

Let’s also pretend that we have a simple 100-layer network with no activation functions, and that each layer has a matrix a that contains the layer’s weights. To complete a single forward pass, we’ll have to multiply each layer’s weight matrix by that layer’s input at every one of the hundred layers, for a grand total of 100 consecutive matrix multiplications.

It turns out that initializing the values of layer weights from the same standard normal distribution to which we scaled our inputs is never a good idea. To see why, we can simulate a forward pass through our hypothetical network.

for i in range(100):
    a = torch.randn(512, 512)
    x = a @ x
x.mean(), x.std()

(tensor(nan), tensor(nan))

Whoa! Somewhere during those 100 multiplications, the layer outputs grew so large that they overflowed the range of 32-bit floating-point numbers. The resulting infinities turned their mean and standard deviation into nan (“not a number”). We can see exactly how long that took to happen.

x = torch.randn(512)
for i in range(100):
    a = torch.randn(512, 512)
    x = a @ x
    # Stop at the first layer whose output statistics are no longer numbers
    if torch.isnan(x.std()):
        break
i

28

The layer outputs exploded within 29 of our network’s layers (the loop counter i starts at 0, so i = 28 is the 29th layer). We clearly initialized our weights to be too large.
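A back-of-the-envelope estimate agrees with that count. The sketch below makes two assumptions: each layer multiplies the standard deviation by exactly √512 ≈ 22.6 (the factor measured a little further on), and x is float32, which is torch.randn’s default dtype. Under those assumptions the standard deviation alone passes the largest finite float32 value after roughly 28.4 layers. That is in line with the break at i = 28 above. The constant `FLOAT32_MAX` is the IEEE 754 single-precision maximum, hard-coded here rather than read from a library.

```python
import math

# Assumption: every layer multiplies the std of its output by sqrt(512),
# the factor for a 512-input layer with standard-normal weights.
growth_per_layer = math.sqrt(512)

# Largest finite IEEE 754 single-precision (float32) value.
FLOAT32_MAX = 3.4028234663852886e38

# Layers needed for the std, starting at 1, to exceed FLOAT32_MAX.
layers_to_overflow = math.log(FLOAT32_MAX) / math.log(growth_per_layer)
print(f"{layers_to_overflow:.1f}")
```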

Unfortunately, we also have to worry about preventing layer outputs from vanishing. To see what happens when we initialize network weights to be too small, we’ll scale our weight values so that, while they still follow a normal distribution with a mean of 0, they have a standard deviation of 0.01.

x = torch.randn(512)
for i in range(100):
    a = torch.randn(512, 512) * 0.01
    x = a @ x
x.mean(), x.std()

(tensor(0.), tensor(0.))

During the course of the above hypothetical forward pass, the layer outputs shrank until they became exactly zero: they completely vanished.
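The same kind of arithmetic explains the zeros. With weights scaled by 0.01, each layer multiplies the standard deviation by about 0.01 × √512 ≈ 0.226. After 100 layers the expected value is around 3e-65, far below the smallest positive float32 (about 1.4e-45, a subnormal). The estimate crosses that threshold at around layer 70. This is a minimal sketch that assumes the per-layer factor is exact.

```python
import math

# Assumption: each layer scales the std by exactly 0.01 * sqrt(512).
shrink_per_layer = 0.01 * math.sqrt(512)

# Smallest positive float32 value (a subnormal); anything smaller rounds to 0.
SMALLEST_FLOAT32 = 1.401298464324817e-45

std_after_100 = shrink_per_layer ** 100
layers_to_underflow = math.log(SMALLEST_FLOAT32) / math.log(shrink_per_layer)

print(shrink_per_layer, std_after_100, layers_to_underflow)
```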

To sum it up, if weights are initialized too large, the network won’t learn well. The same happens when weights are initialized too small.

We can demonstrate that at a given layer, the matrix product of our inputs x and weight matrix a that we initialized from a standard normal distribution will, on average, have a standard deviation very close to the square root of the number of input connections, which in our example is √512.
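Here is a short derivation of that √n. It assumes the entries of a and x are independent with zero mean, and uses σ_a² and σ_x² (symbols introduced here for illustration) for their variances. Because every term has zero mean, the variance of each product is the product of the second moments:

```latex
y_i = \sum_{j=1}^{n} a_{ij} x_j, \qquad
\operatorname{Var}(y_i) = \sum_{j=1}^{n} \operatorname{Var}(a_{ij} x_j)
= \sum_{j=1}^{n} \mathbb{E}[a_{ij}^2]\,\mathbb{E}[x_j^2]
= n\,\sigma_a^2\,\sigma_x^2

\sigma_a = \sigma_x = 1,\; n = 512:\quad \sqrt{\operatorname{Var}(y_i)} = \sqrt{512} \approx 22.63

\sigma_a = 1/\sqrt{n}:\quad \operatorname{Var}(y_i) = \sigma_x^2
```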

import math

mean, var = 0., 0.
for i in range(10000):
    x = torch.randn(512)
    a = torch.randn(512, 512)
    y = a @ x
    mean += y.mean().item()
    # Mean of squares; since y has mean ~0, this estimates its variance
    var += y.pow(2).mean().item()
mean/10000, math.sqrt(var/10000)

(0.0027618063405156134, 22.639051156329657)

math.sqrt(512)

22.627416997969522

Now let’s re-run our quick-and-dirty 100-layer network. As before, we draw layer weights at random from a standard normal distribution, but this time we scale those weights by 1/√n. Here n is the number of input connections to a layer, which is 512 in our example.

x = torch.randn(512)
for i in range(100):
    a = torch.randn(512, 512) / math.sqrt(512) # or * math.sqrt(1./512)
    x = a @ x
x.mean(), x.std()

(tensor(0.0527), tensor(0.7191))

For the sake of simplicity, we have omitted activation functions so far, but we’d never do this in real life. Let’s add a tanh activation after every layer.

def tanh(x): return torch.tanh(x)

x = torch.randn(512)
for i in range(100):
    a = torch.randn(512, 512) / math.sqrt(512) # or * math.sqrt(1./512)
    x = tanh(a @ x)
x.mean(), x.std()

(tensor(-0.0034), tensor(0.0731))

The standard deviation of the activation outputs of the 100th layer is down to about 0.07. This is definitely on the small side, but at least activations haven’t totally vanished!

When Xavier Glorot and Yoshua Bengio published their landmark paper titled “Understanding the difficulty of training deep feedforward neural networks”, the “commonly used heuristic” to which they compared their experiments was to initialize weights from a uniform distribution in [-1,1] and then scale them by 1/√n.

It turns out this “standard” approach doesn’t actually work that well for our deep tanh network.

x = torch.randn(512)
for i in range(100):
    a = torch.empty(512, 512).uniform_(-1, 1) / math.sqrt(512) # or * math.sqrt(1./512)
    x = tanh(a @ x)
x.mean(), x.std()

(tensor(2.0694e-26), tensor(1.3261e-24))

torch.randn -> Returns a tensor filled with random numbers from a normal distribution with mean 0 and variance 1 (also called the standard normal distribution).

Tensor.uniform_(from, to) -> Fills the tensor in place with numbers sampled from the continuous uniform distribution between from and to (here, -1 and 1).
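Why does the uniform heuristic shrink so much faster than the normal version? The sketch below works it out. The variance of U(-1, 1) is 1/3, not 1. Dividing by √n therefore gives weights with variance 1/(3n), so each layer scales the standard deviation by 1/√3 ≈ 0.577. Over 100 layers that compounds to 3^-50 ≈ 1.4e-24, the same order of magnitude as the 1.3e-24 measured above. The sketch ignores tanh. Since |tanh(z)| ≤ |z|, tanh can only push the real value a little lower, so the estimate is roughly an upper bound.

```python
import math
import random

# Variance of a uniform distribution on [lo, hi] is (hi - lo)**2 / 12.
lo, hi = -1.0, 1.0
uniform_var = (hi - lo) ** 2 / 12   # 1/3 for U(-1, 1)

# Quick Monte Carlo check of that formula (the mean is ~0, so E[s**2] ~ Var).
random.seed(0)
samples = [random.uniform(lo, hi) for _ in range(200_000)]
mc_var = sum(s * s for s in samples) / len(samples)

# Weights U(-1, 1) / sqrt(n) have variance 1 / (3n). Summing n unit-variance
# inputs leaves each layer's std scaled by sqrt(1/3), ignoring tanh.
per_layer = math.sqrt(uniform_var)
after_100 = per_layer ** 100        # approximately 3 ** -50

print(uniform_var, mc_var, per_layer, after_100)
```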

Xavier initialization

def xavier(m, h):
    # Uniform in [-1, 1], scaled to the Glorot bound sqrt(6 / (fan_in + fan_out))
    return torch.empty(m, h).uniform_(-1, 1) * math.sqrt(6./(m+h))

x = torch.randn(512)
for i in range(100):
    a = xavier(512,512)
    x = tanh(a @ x)
x.mean(), x.std()

(tensor(0.0032), tensor(0.0753))

In our experimental network, Xavier initialization performs almost identically to the home-grown method that we derived earlier. In that method we sampled values from a standard normal distribution and divided them by the square root of the number of incoming connections, n.
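The near-identical results are not a coincidence. The variance of U(-a, a) is a²/3. With the Glorot bound a = √(6/(fan_in + fan_out)), the weight variance becomes 2/(fan_in + fan_out). For a square 512 × 512 layer that is exactly 1/512, the same variance as `randn / sqrt(512)`. The sketch below checks this numerically. The helper name `xavier_bound` is hypothetical; it just mirrors the scale factor in the `xavier()` function above.

```python
import math

def xavier_bound(fan_in: int, fan_out: int) -> float:
    # Half-width a of the uniform range U(-a, a) used by the xavier() helper
    return math.sqrt(6.0 / (fan_in + fan_out))

n = 512
a = xavier_bound(n, n)
xavier_var = a ** 2 / 3      # variance of U(-a, a) is a**2 / 3
home_grown_var = 1.0 / n     # variance of randn(...) / sqrt(n)

print(xavier_var, home_grown_var, 2.0 / (n + n))
```

When fan-in and fan-out differ, the two schemes diverge. Xavier averages them, while the home-grown method only looks at fan-in.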

But what if we’re using ReLU activation functions? Would it still make sense to want to scale random initial weight values in the same way?



# Non-in-place ReLU: returns a new tensor with negative values clamped to 0
def relu(x): return x.clamp_min(0.)

import math

mean, var = 0., 0.
for i in range(10000):
    x = torch.randn(512)
    a = torch.randn(512, 512)
    y = relu(a @ x)
    mean += y.mean().item()
    var += y.pow(2).mean().item()
mean/10000, math.sqrt(var/10000)

(9.029190416193009, 16.01297031356908)

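It turns out that when using a ReLU activation, a single layer’s outputs will, on average, have a root mean square very close to the square root of the number of input connections divided by the square root of two, or √512/√2 in our example. The root mean square is the square root of the mean of the squared outputs, which is what the code above computes. Because ReLU outputs are never negative, their mean is no longer close to 0, so this is not the same as their standard deviation.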
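The factor of 1/√2 can be sketched the same way as the √n derivation. Assume the pre-activation z_i is symmetric about zero. That holds here because the weights are zero-mean normal and independent of x. ReLU then keeps exactly half of the expected value of z_i². Here σ_a is the weight standard deviation and E[x_j²] the second moment of the layer’s inputs (symbols introduced for illustration):

```latex
z_i = \sum_{j=1}^{n} a_{ij} x_j, \qquad y_i = \max(0, z_i)

\mathbb{E}[z_i^2] = n\,\sigma_a^2\,\mathbb{E}[x_j^2]

\mathbb{E}[y_i^2] = \mathbb{E}\big[z_i^2\,\mathbf{1}\{z_i > 0\}\big] = \tfrac{1}{2}\,\mathbb{E}[z_i^2] = \tfrac{1}{2}\,n\,\sigma_a^2\,\mathbb{E}[x_j^2]

\sigma_a = 1,\; \mathbb{E}[x_j^2] = 1,\; n = 512:\quad \sqrt{\mathbb{E}[y_i^2]} = \sqrt{512/2} = 16

\sigma_a^2 = 2/n:\quad \mathbb{E}[y_i^2] = \mathbb{E}[x_j^2]
```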

math.sqrt(512/2)

16.0

import math

mean, var = 0., 0.
for i in range(10000):
    x = torch.randn(512)
    # Kaiming scaling: standard normal weights times sqrt(2 / fan_in)
    a = torch.randn(512, 512) * math.sqrt(2/512)
    y = relu(a @ x)
    mean += y.mean().item()
    var += y.pow(2).mean().item()
mean/10000, math.sqrt(var/10000)

(0.5642064209282398, 1.0003509295779327)
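The means printed in the two ReLU experiments above (about 9.03 and 0.564) can be predicted too. For z ~ N(0, σ²), a standard result is E[max(0, z)] = σ/√(2π), and the root mean square of max(0, z) is σ/√2. The sketch below plugs in the pre-activation standard deviation for both weight scalings. It assumes unit-variance inputs, and the helper names are illustrative only.

```python
import math

def relu_mean(sigma: float) -> float:
    # E[max(0, z)] for z ~ N(0, sigma**2)
    return sigma / math.sqrt(2 * math.pi)

def relu_rms(sigma: float) -> float:
    # sqrt(E[max(0, z)**2]) for z ~ N(0, sigma**2): half of E[z**2] survives
    return sigma / math.sqrt(2)

n = 512
for label, weight_std in [("standard normal", 1.0), ("Kaiming", math.sqrt(2 / n))]:
    sigma = math.sqrt(n) * weight_std  # std of the pre-activation a @ x
    print(label, round(relu_mean(sigma), 4), round(relu_rms(sigma), 4))
```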

In their 2015 paper, He et al. demonstrated that deep networks (e.g. a 22-layer CNN) converge much earlier when the following weight initialization strategy is employed:

1. Create a tensor with the dimensions appropriate for a weight matrix at a given layer, and populate it with numbers randomly chosen from a standard normal distribution.
2. Multiply each randomly chosen number by √2/√n, where n is the number of connections coming into a given layer from the previous layer’s output (also known as the “fan-in”).
3. Initialize bias tensors to zero.
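The steps above can be sketched in PyTorch and compared against the library’s built-in `torch.nn.init.kaiming_normal_`. With `mode="fan_in"` and `nonlinearity="relu"`, it draws from a normal distribution with standard deviation √2/√fan_in. For a 2-D weight, PyTorch takes fan-in to be the number of columns, which matches the `a @ x` convention used throughout this article. Both versions should land near √(2/512) = 0.0625. The seed and layer sizes are arbitrary choices for illustration.

```python
import math
import torch

torch.manual_seed(0)
fan_out, fan_in = 512, 512

# Steps 1 and 2: standard normal weights, scaled by sqrt(2 / fan_in).
w_manual = torch.randn(fan_out, fan_in) * math.sqrt(2.0 / fan_in)
# Step 3: biases start at zero.
b = torch.zeros(fan_out)

# Built-in equivalent; for a 2-D tensor, fan_in is tensor.size(1).
w_builtin = torch.empty(fan_out, fan_in)
torch.nn.init.kaiming_normal_(w_builtin, mode="fan_in", nonlinearity="relu")

print(w_manual.std().item(), w_builtin.std().item(), math.sqrt(2.0 / fan_in))
```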

def kaiming(m, h):
    # Weight matrix for y = a @ x with h inputs and m outputs, so fan-in is h
    return torch.randn(m, h) * math.sqrt(2./h)

x = torch.randn(512)
for i in range(100):
    a = kaiming(512,512)
    x = relu(a @ x)
x.mean(), x.std()

(tensor(0.3497), tensor(0.5191))

As a final comparison, here’s what would happen if we were to use Xavier initialization, instead.

x = torch.randn(512)
for i in range(100):
    a = xavier(512,512)
    x = relu(a @ x)
x.mean(), x.std()

(tensor(2.3099e-16), tensor(3.3960e-16))

Ouch! When using Xavier to initialize weights, activation outputs have almost completely vanished by the 100th layer!
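The per-layer bookkeeping shows why. Assume independent zero-mean weights. A 512-input layer multiplies the expected squared activation by n · Var(w), and ReLU then halves it. With Xavier’s variance of 1/512 the net factor is 1/2 per layer, so after 100 layers the root mean square is about (1/2)^50 = 2^-50 ≈ 8.9e-16. That is the same order of magnitude as the mean and standard deviation printed above. Kaiming’s variance of 2/512 makes the factor exactly 1 in expectation. The helper below is illustrative only.

```python
n = 512
xavier_var = 6.0 / (n + n) / 3   # variance of xavier(512, 512) weights: 1/512
kaiming_var = 2.0 / n            # variance of kaiming(512, 512) weights: 2/512

def second_moment_factor(weight_var: float) -> float:
    # Per-layer multiplier on E[x**2]: n * Var(w) from the matmul, then 1/2 from ReLU
    return n * weight_var / 2

xavier_factor = second_moment_factor(xavier_var)    # 0.5
kaiming_factor = second_moment_factor(kaiming_var)  # 1.0

# RMS after 100 layers = sqrt(factor ** 100) = factor ** 50
print(xavier_factor ** 50, kaiming_factor ** 50)
```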

The moral of the story for us is that any network we train from scratch, especially for computer vision applications, will almost certainly contain ReLU activation functions and be several layers deep. In such cases, Kaiming should be our go-to weight init strategy.