# Imports

In [1]:
import math
import torch

# Why you need a good init

To understand why initialization is important in a neural net, we'll focus on the basic operation you have there: matrix multiplications. So let's just take a vector `x` and a matrix `a`, both initialized randomly, then multiply `x` by `a` 100 times (as if we had 100 layers).

[Jump_to lesson 9 video](https://course.fast.ai/videos/?lesson=9&t=1132)

In [2]:
x = torch.randn(512)
a = torch.randn(512, 512)

In [3]:
for i in range(100): 
    x = a @ x
    print(x.mean())

tensor(-1.0124)
tensor(15.4200)
tensor(-119.5077)
tensor(323.9070)
tensor(-249560.4531)
tensor(-7246693.)
tensor(79500952.)
tensor(-2.1179e+09)
tensor(1.7629e+11)
tensor(5.9349e+11)
tensor(1.2413e+13)
tensor(-1.1178e+15)
tensor(-3.0797e+15)
tensor(5.5381e+17)
tensor(-5.4721e+18)
tensor(-2.9697e+20)
tensor(-2.5529e+19)
tensor(2.1651e+23)
tensor(-2.7939e+24)
tensor(-5.2402e+25)
tensor(2.9972e+26)
tensor(5.0566e+27)
tensor(7.9556e+29)
tensor(-1.6590e+31)
tensor(-3.5855e+32)
tensor(7.5000e+32)
tensor(2.2197e+35)
tensor(nan)
tensor(nan)
tensor(nan)
...

In [4]:
x.mean(),x.std()

(tensor(nan), tensor(nan))

The problem here is activation explosion: very quickly, the activations go to `nan`. We can even ask the loop to break when that first happens:

In [10]:
x = torch.randn(512)
a = torch.randn(512, 512)

In [11]:
for i in range(100): 
    x = a @ x
    if x.std() != x.std(): break  # NaN is the only value that is not equal to itself

In [12]:
i

27

It only takes 28 multiplications (the loop breaks at `i = 27`, and `i` starts at 0)! On the other hand, if you initialize your weights with a scale that is too low, you get the opposite problem:

In [13]:
x = torch.randn(512)
a = torch.randn(512,512) * 0.01

In [14]:
for i in range(100): 
    x = a @ x
    print(x.mean())

tensor(0.0042)
tensor(0.0011)
tensor(-0.0004)
tensor(1.2775e-05)
tensor(-1.3455e-05)
tensor(-5.1568e-06)
tensor(-1.2435e-06)
tensor(-3.0753e-07)
tensor(9.6334e-08)
tensor(-8.8804e-09)
tensor(-5.6029e-10)
tensor(1.1405e-09)
tensor(3.5639e-11)
tensor(-4.7661e-11)
tensor(2.5637e-11)
tensor(-2.5184e-12)
tensor(-6.3484e-13)
tensor(-1.4026e-13)
tensor(-5.1399e-14)
tensor(4.8537e-15)
tensor(-1.2752e-15)
tensor(2.4888e-16)
tensor(-3.4923e-18)
tensor(1.1579e-17)
tensor(-3.4070e-18)
tensor(7.6530e-19)
tensor(-1.5070e-19)
tensor(1.1509e-19)
tensor(-7.4889e-21)
tensor(-4.3464e-22)
tensor(6.5315e-22)
tensor(-1.0211e-22)
tensor(-1.0815e-23)
tensor(3.8557e-24)
tensor(1.1635e-24)
tensor(-1.5965e-25)
tensor(6.5635e-26)
tensor(-3.0135e-26)
tensor(2.1686e-27)
tensor(-1.8187e-27)
tensor(3.4111e-28)
tensor(-2.5042e-29)
tensor(-9.3455e-30)
tensor(2.7856e-32)
tensor(-6.1475e-31)
tensor(-3.0552e-32)
tensor(3.2340e-32)
tensor(1.0695e-32)
tensor(-2.2935e-33)
tensor(1.1988e-33)
tensor(-1.2915e-34)
...

In [15]:
x.mean(),x.std()

(tensor(0.), tensor(0.))

In [18]:
x = torch.randn(512)
a = torch.randn(512,512) * 0.01

In [19]:
for i in range(100): 
    x = a @ x
    if x.std().item() == 0: break

In [20]:
i

69

Here, every activation vanished to 0 (the loop stopped at `i = 69`, after 70 multiplications). To avoid that problem, people have come up with several strategies to initialize their weight matrices, such as:
- use a standard deviation that makes x and Ax have the same scale
- use an orthogonal matrix to initialize the weights (orthogonal matrices have the special property that they preserve the L2 norm, so x and Ax have the same sum of squares in that case)
- use [spectral normalization](https://arxiv.org/pdf/1802.05957.pdf) on the matrix A (the spectral norm of A is the smallest number M such that `torch.norm(A@x) <= M*torch.norm(x)` for every x, so dividing A by this M ensures you don't overflow; the activations can still vanish with this)
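A minimal sketch of the last two strategies, assuming PyTorch 1.9 or later for `torch.linalg.qr` and `torch.linalg.matrix_norm`. The orthogonal matrix is built here as the Q factor of a QR decomposition of a random Gaussian matrix (one common construction; `torch.nn.init.orthogonal_` is a built-in alternative), and the spectral norm is computed exactly from the matrix rather than with the power iteration used in the spectral normalization paper. `norm_ratio` is a hypothetical helper written for this illustration.

```python
import torch

torch.manual_seed(0)
n = 512
x = torch.randn(n)

# Orthogonal init: Q from the QR decomposition of a random Gaussian matrix
q, _ = torch.linalg.qr(torch.randn(n, n))

# Spectral normalization: divide by the largest singular value (the spectral norm M)
a = torch.randn(n, n)
a_sn = a / torch.linalg.matrix_norm(a, ord=2)

def norm_ratio(m, v, layers):
    """Hypothetical helper: ||m applied `layers` times to v|| / ||v||."""
    out = v
    for _ in range(layers):
        out = m @ out
    return (torch.linalg.norm(out) / torch.linalg.norm(v)).item()

print(norm_ratio(q, x, 100))     # stays at 1 up to float error: the norm is preserved
print(norm_ratio(a_sn, x, 1))    # at most 1: no overflow
print(norm_ratio(a_sn, x, 100))  # tiny: the activations still vanish
```

The orthogonal matrix keeps the norm intact over 100 layers, while the spectrally normalized one can only shrink it, which is exactly the "you can still vanish" caveat.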

# The magic number for scaling

Here we will focus on the first one, which is the idea behind Xavier (Glorot) initialization. In its simplest form, which only considers the forward pass, it tells us that we should use a scale equal to `1/math.sqrt(n_in)`, where `n_in` is the number of inputs of our matrix.

In [21]:
x = torch.randn(512)
a = torch.randn(512, 512) / math.sqrt(512)

In [22]:
for i in range(100):
    x = a @ x

In [23]:
x.mean(),x.std()

(tensor(-0.0187), tensor(2.0513))

And indeed it works: even after 100 multiplications, the activations have neither exploded nor vanished. Note that this magic number isn't very far from the 0.01 we had earlier.

In [24]:
1/ math.sqrt(512)

0.044194173824159216
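Although 0.01 and 0.044 look close, the gap compounds from one layer to the next. Here is a back-of-the-envelope sketch, assuming (as derived below) that each multiplication scales the standard deviation of the activations by `scale * math.sqrt(512)`, and that trouble starts once that scale leaves the float32 range (largest finite value about 3.4e38, smallest positive value about 1.4e-45). The helper names are made up for illustration.

```python
import math

n = 512
FLOAT32_MAX = 3.4028234663852886e38   # largest finite float32 value
FLOAT32_TINY = 1.401298464324817e-45  # smallest positive (subnormal) float32 value

def per_layer_factor(scale, n_in=n):
    # Each output sums n_in products weight * activation, so its std is
    # scale * sqrt(n_in) times the std of the input
    return scale * math.sqrt(n_in)

def layers_to_reach(limit, factor):
    # Solve factor ** k == limit for k, starting from activations of scale 1
    return math.log(limit) / math.log(factor)

print(per_layer_factor(1.0))               # about 22.6: explodes
print(per_layer_factor(0.01))              # about 0.226: vanishes
print(per_layer_factor(1 / math.sqrt(n)))  # about 1: stable

print(layers_to_reach(FLOAT32_MAX, per_layer_factor(1.0)))    # about 28.4 layers
print(layers_to_reach(FLOAT32_TINY, per_layer_factor(0.01)))  # about 69.5 layers
```

These rough estimates line up well with the loops above, which stopped at `i = 27` and `i = 69`.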

But where does it come from? It's not that mysterious if you remember the definition of matrix multiplication. When we do `y = a @ x`, the coefficients of `y` are defined by

$$y_{i} = a_{i,0} x_{0} + a_{i,1} x_{1} + \cdots + a_{i,n-1} x_{n-1} = \sum_{k=0}^{n-1} a_{i,k} x_{k}$$

or in code:
```python
y[i] = sum([c*d for c,d in zip(a[i], x)])
```
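To convince ourselves that this definition is exactly what `@` computes, we can compare the two on the first few rows (a minimal check; the pure-Python loop works on all 512 rows too, just slowly).

```python
import torch

torch.manual_seed(0)
x = torch.randn(512)
a = torch.randn(512, 512)

# Matrix-vector product computed by PyTorch
y = a @ x

# The same coefficients built from the definition y[i] = sum_k a[i, k] * x[k],
# only for the first few rows because the pure-Python loop is slow
rows = 8
y_loop = torch.stack([sum([c * d for c, d in zip(a[i], x)]) for i in range(rows)])

print(torch.allclose(y[:rows], y_loop, atol=1e-3))  # True
```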

Now at the very beginning, our `x` vector has a mean of roughly 0 and a standard deviation of roughly 1 (since we picked it that way).

In [25]:
x = torch.randn(512)
x.mean(), x.std()

(tensor(0.0049), tensor(0.9616))

NB: This is why it's extremely important to normalize your inputs in deep learning: the initialization rules have been designed for inputs that have a mean of 0 and a standard deviation of 1.
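A minimal sketch of that normalization, assuming hypothetical raw inputs stored as a `(samples, features)` tensor with an arbitrary offset and scale. In a real pipeline, the mean and standard deviation would typically be computed on the training set and reused for validation and test data.

```python
import torch

torch.manual_seed(0)
# Hypothetical raw inputs: 1000 samples of 512 features, offset by 3 and scaled by 5
raw = torch.randn(1000, 512) * 5.0 + 3.0

# Standardize each feature with its own mean and standard deviation
feat_mean = raw.mean(dim=0)
feat_std = raw.std(dim=0)
inputs = (raw - feat_mean) / feat_std

print(raw.mean().item(), raw.std().item())        # about 3 and 5
print(inputs.mean().item(), inputs.std().item())  # about 0 and 1
```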

If you need a refresher from your statistics course, the mean is the sum of all the elements divided by the number of elements (a basic average). The standard deviation measures how spread out the data is: it is small when the values stay close to the mean and large when they often land far away from it. It's computed by the following formula:

$$\sigma = \sqrt{\frac{1}{n}\left[(x_{0}-m)^{2} + (x_{1}-m)^{2} + \cdots + (x_{n-1}-m)^{2}\right]}$$

where $m$ is the mean and $\sigma$ (the Greek letter sigma) is the standard deviation. Here we have a mean of 0, so it's just the square root of the mean of x squared.
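The formula translates directly to PyTorch. One detail worth knowing: `torch.std` divides by `n - 1` by default (Bessel's correction), while the formula above divides by `n`, so the two differ by a factor of `math.sqrt((n - 1) / n)`, which is negligible for `n = 512`.

```python
import math
import torch

torch.manual_seed(0)
x = torch.randn(512)
n = x.numel()

m = x.mean()
# The formula above: square root of the average squared distance to the mean
sigma = ((x - m) ** 2).mean().sqrt()

# With a mean close to 0, this is almost the root mean square of x
rms = x.pow(2).mean().sqrt()

# torch.std divides by n - 1 by default, so rescale it to match the formula
sigma_torch = x.std() * math.sqrt((n - 1) / n)

print(sigma.item(), rms.item(), sigma_torch.item())
```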

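If we go back to `y = a @ x` and assume that we chose weights for `a` that also have a mean of 0, we can compute the standard deviation of `y` quite easily. Since the mean is 0, we track the mean of the squares, which is the square of the standard deviation (the variance). And since the result is random and a single draw may be unrepresentative, we repeat the operation 100 times and average.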

In [18]:
mean, sqr = 0.,0.
for i in range(100):
    x = torch.randn(512)
    a = torch.randn(512, 512)
    y = a @ x
    mean += y.mean().item()
    sqr  += y.pow(2).mean().item()
mean / 100, sqr / 100

(0.002380470633506775, 511.89030822753904)

That is very close to 512, the dimension of our matrix, and it's no coincidence! When you compute y, you sum 512 products of one element of `a` by one element of `x`. So what are the mean and the standard deviation of such a product? We can show mathematically that as long as the elements in `a` and the elements in `x` are independent, and both have a mean of 0 and a standard deviation of 1, the product has a mean of 0 and a standard deviation of 1. This can also be seen experimentally:

In [26]:
mean, sqr = 0.,0.
for i in range(10000):
    x = torch.randn(1)
    a = torch.randn(1)
    y = a * x
    mean += y.item()
    sqr  += y.pow(2).item()
mean/10000, sqr/10000

(-0.0045432607773422205, 0.9695331371464219)
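The same experiment runs much faster vectorized: instead of a Python loop over 10,000 scalars, we can draw a million independent pairs at once (the sample size is an arbitrary choice).

```python
import torch

torch.manual_seed(0)
n_samples = 1_000_000
# Independent standard normal draws for the weights and the inputs
a = torch.randn(n_samples)
x = torch.randn(n_samples)
y = a * x

print(y.mean().item(), y.pow(2).mean().item())  # close to 0 and 1
```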

Then we sum 512 of those things that have a mean of 0 and a mean of squares of 1. Since the 512 products are independent, their variances add up, so we get something that has a mean of 0 and a mean of squares of 512, hence `math.sqrt(512)` being our magic number. If we divide the weights of the matrix `a` by this `math.sqrt(512)`, we get a `y` of scale 1, and repeating the product as many times as we want won't systematically overflow or vanish.
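We can check both claims directly on matrix products, this time with 200 input vectors stacked as the columns of `xs` (an arbitrary choice to get many samples at once): the mean of squares of `a @ xs` sits near 512, and dividing `a` by `math.sqrt(512)` brings it back near 1.

```python
import math
import torch

torch.manual_seed(0)
n = 512
a = torch.randn(n, n)
xs = torch.randn(n, 200)  # 200 independent input vectors, one per column

y = a @ xs
print(y.pow(2).mean().item())  # close to 512: the variances of the 512 products add up

y_scaled = (a / math.sqrt(n)) @ xs
print(y_scaled.pow(2).mean().item())  # close to 1
```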

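Assuming that $a_i$ and $x_i$ are independent:
$$\mathrm{Var}(a_i x_i) = \mathbb{E}[a_i]^2\,\mathrm{Var}(x_i) + \mathbb{E}[x_i]^2\,\mathrm{Var}(a_i) + \mathrm{Var}(a_i)\,\mathrm{Var}(x_i)$$
Since we normalized $x_i$ to have a mean of 0 and a std of 1, and chose $a_i$ to also have a mean of 0 and a std of 1, the first two terms vanish, therefore:
$$\mathrm{Var}(a_i x_i) = \mathrm{Var}(a_i)\,\mathrm{Var}(x_i) = 1$$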
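To see why the zero means matter, here is a quick numerical check of the formula with inputs that were not normalized. Purely for illustration, we assume inputs with a mean of 1 and a std of 1, and standard normal weights. The formula then predicts $0^2 \cdot 1 + 1^2 \cdot 1 + 1 \cdot 1 = 2$, so each product has a variance of 2 instead of 1.

```python
import torch

torch.manual_seed(0)
n_samples = 1_000_000
a = torch.randn(n_samples)        # weights: mean 0, std 1
x = torch.randn(n_samples) + 1.0  # assumed un-normalized inputs: mean 1, std 1

mean_a, var_a = 0.0, 1.0
mean_x, var_x = 1.0, 1.0
# Var(a x) = E[a]^2 Var(x) + E[x]^2 Var(a) + Var(a) Var(x)
predicted = mean_a**2 * var_x + mean_x**2 * var_a + var_a * var_x

empirical = (a * x).var().item()
print(predicted, empirical)  # both close to 2
```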

# Adding ReLU in the mix

We can reproduce the previous experiment with a ReLU to see that, this time, the mean shifts away from 0 and the mean of the squares drops to about 0.5 (the ReLU zeroes out the negative half of the products). This time the magic number will be `math.sqrt(2/512)` to properly scale the weights of the matrix.

In [27]:
mean, sqr = 0., 0.
for i in range(10000):
    x = torch.randn(1)
    a = torch.randn(1)
    y = a * x
    y = 0 if y < 0 else y.item()
    mean += y
    sqr += y**2
mean / 10000, sqr / 10000

(0.3313892057748288, 0.5305700915405362)
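A vectorized version of this experiment, with the theoretical values for comparison. Because the product `a * x` is symmetric around 0, the ReLU keeps exactly half of its mean of squares, giving 1/2. The mean is half of the mean absolute value: E|a| E|x| / 2 = (2/π) / 2 = 1/π ≈ 0.318, using E|z| = sqrt(2/π) for a standard normal z.

```python
import math
import torch

torch.manual_seed(0)
n_samples = 1_000_000
a = torch.randn(n_samples)
x = torch.randn(n_samples)
y = (a * x).clamp(min=0)  # ReLU

print(y.mean().item(), 1 / math.pi)  # both about 0.318
print(y.pow(2).mean().item(), 0.5)   # both about 0.5
```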

We can double check by running the experiment on the whole matrix product.

In [28]:
mean, sqr = 0., 0.
for i in range(100):
    x = torch.randn(512)
    a = torch.randn(512, 512)
    y = a @ x
    y = y.clamp(min=0)
    mean += y.mean().item()
    sqr += y.pow(2).mean().item()
mean / 100, sqr / 100

(9.100272226333619, 259.10854370117187)

As expected, the mean of squares is close to 512/2 = 256. We can also check that scaling the weights by the magic number brings it back to about 1.

In [29]:
mean, sqr = 0., 0.
for i in range(100):
    x = torch.randn(512)
    a = torch.randn(512, 512) * math.sqrt(2 / 512)
    y = a @ x
    y = y.clamp(min=0)
    mean += y.mean().item()
    sqr  += y.pow(2).mean().item()
mean / 100, sqr / 100

(0.5667782863974571, 1.0099698793888092)
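PyTorch has this scaling built in as `torch.nn.init.kaiming_normal_`. With `nonlinearity='relu'` and the default `mode='fan_in'`, it draws weights from a normal distribution with std `math.sqrt(2 / fan_in)`, where `fan_in` is the second dimension of an `(out_features, in_features)` weight. The sketch below checks that, then stacks 100 ReLU layers with fresh random weights per layer (an assumption closer to a real network than reusing one matrix) to show that the activations stay on the order of 1.

```python
import math
import torch

torch.manual_seed(0)
n = 512

# Built-in Kaiming (He) normal init: std = sqrt(2) / sqrt(fan_in) for ReLU
w = torch.empty(n, n)
torch.nn.init.kaiming_normal_(w, mode='fan_in', nonlinearity='relu')
print(w.std().item(), math.sqrt(2 / n))  # both about 0.0625

# 100 ReLU layers, each with fresh weights scaled by the magic number
x = torch.randn(n)
for _ in range(100):
    a = torch.randn(n, n) * math.sqrt(2 / n)
    x = (a @ x).clamp(min=0)
print(x.pow(2).mean().item())  # stays on the order of 1
```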

The math behind this is a tiny bit more complex, and you can find everything in the [Kaiming](https://arxiv.org/abs/1502.01852) and [Xavier](http://proceedings.mlr.press/v9/glorot10a.html) papers, but this gives the intuition behind those results.