# Kaiming Initialization

## What do we want?

$$
Var(Y_t) = Var(Y_{t-1})
$$

where $t$ is the layer index and $Y_t$ is the output of layer $t$.

## How do we get it?

We assume the network is a chain of fully connected layers with widths `l0 -> l1 -> l2 -> ...`.  
So for each layer we have: 

$$
Y_t = W_t X_t + b_t
$$

Because $b_t$ is a constant, we have: 

$$
Var(Y_t) = Var(W_t X_t + b_t) = Var(W_t X_t)
$$

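Here we let $w_t, x_t, y_t$ denote individual entries of $W_t, X_t, Y_t$, and $l_t$ the number of inputs to layer $t$ (its fan-in). Assuming the $l_t$ products $w_t^i x_t^i$ are independent and identically distributed, we get:  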

$$
Var(y_t) = Var(\sum_{i=1}^{l_t} w_t^i x_t^i) = l_t Var(w_t x_t)
$$

If we assume that $\mathbb{E}(w_t) = 0$, we have $\mathbb{E}(w_t x_t) = \mathbb{E}(w_t) \mathbb{E}(x_t) = 0$ (because $w_t$ and $x_t$ are independent). 

Then we can use the variance formula: 

$$
Var(y_t) = l_t Var(w_t x_t) = l_t [\mathbb{E}(w_t^2 x_t^2) - \mathbb{E}^2(w_t x_t)] = l_t \mathbb{E}(w_t^2 x_t^2)
$$

Since $w_t$ and $x_t$ are independent, so are $w_t^2$ and $x_t^2$. Also, $\mathbb{E}(w_t^2) = Var(w_t)$ because $\mathbb{E}(w_t) = 0$.

We get: 

$$
Var(y_t) = l_t \mathbb{E}(w_t^2) \mathbb{E}(x_t^2) = l_t Var(w_t) \mathbb{E}(x_t^2)
$$
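As a quick numerical sanity check of the identity above, here is a minimal sketch that compares the empirical variance of $y_t$ with $l_t Var(w_t) \mathbb{E}(x_t^2)$. The fan-in of 64, the weight standard deviation of 0.3, and the choice of inputs (ReLU of standard normal samples, so they are not zero-mean) are all illustrative assumptions, not values from the derivation.

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
k_w, k_x = random.split(key)  # independent keys for weights and inputs

fan_in = 64         # l_t, an assumed layer width
n_samples = 20000   # number of independent (w, x) draws
w_std = 0.3         # assumed weight std, so Var(w) = 0.09

# Zero-mean weights, drawn independently of the inputs.
w = w_std * random.normal(k_w, (n_samples, fan_in))
# Inputs with a non-zero mean, as they would be after a ReLU.
x = jnp.maximum(0.0, random.normal(k_x, (n_samples, fan_in)))

# One pre-activation y per sample: the sum over the fan_in products.
y = jnp.sum(w * x, axis=1)

empirical = float(jnp.var(y))
predicted = float(fan_in * w_std**2 * jnp.mean(x**2))
print(f"empirical Var(y)    = {empirical:.3f}")
print(f"l * Var(w) * E[x^2] = {predicted:.3f}")
```

The two printed numbers should agree up to sampling noise, even though $x_t$ itself has a non-zero mean; only the weights need to be zero-mean.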

Now our task is to express $\mathbb{E}(x_t^2)$ in terms of $Var(y_{t-1})$.

We know that:

$$
x_t = f(y_{t-1})
$$

Here $f(\cdot)$ is the activation function, which we take to be ReLU. We also assume that $\mathbb{E}(y_{t-1}) = 0$. Now we can compute:  

$$
\mathbb{E}(x_t^2) = \int_{-\infty}^{\infty} \operatorname{ReLU}^2(y_{t-1}) p(y_{t-1}) dy_{t-1} = \int_{0}^{\infty} y_{t-1}^2 p(y_{t-1}) dy_{t-1}
$$

If we further assume that the density $p(y_{t-1})$ is even (symmetric about zero), such as a zero-mean $\mathcal{N}$ or $\mathcal{D}$, we get: 

$$
\mathbb{E}(x_t^2) = {1\over 2} \int_{-\infty}^{\infty} (y_{t-1}-0)^2 p(y_{t-1}) dy_{t-1} = {1\over 2} Var(y_{t-1})
$$

$$
\Rightarrow Var(y_t) = l_t Var(w_t) {1\over 2} Var(y_{t-1})
$$
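The halving step can be checked numerically. The sketch below estimates $\mathbb{E}(\operatorname{ReLU}^2(y))$ and $Var(y) / 2$ from samples of two even densities. The specific distributions (a normal with standard deviation 1.7 and a uniform on $[-1, 1]$) and the sample size are assumptions chosen only for illustration.

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
k_normal, k_uniform = random.split(key)
n = 200_000

# Two even (symmetric about zero) densities; both choices are illustrative.
samples = {
    "normal": 1.7 * random.normal(k_normal, (n,)),
    "uniform": random.uniform(k_uniform, (n,), minval=-1.0, maxval=1.0),
}

results = {}
for name, y in samples.items():
    relu_second_moment = float(jnp.mean(jnp.maximum(0.0, y) ** 2))  # E[ReLU(y)^2]
    half_variance = float(0.5 * jnp.var(y))                          # Var(y) / 2
    results[name] = (relu_second_moment, half_variance)
    print(f"{name:8s} E[ReLU(y)^2] = {relu_second_moment:.4f}   Var(y)/2 = {half_variance:.4f}")
```

For both densities the two columns should match up to sampling noise, since ReLU keeps exactly the positive half of a symmetric distribution.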

So, to keep the variance constant across all layers (the goal stated at the top), we require: 

$$
Var(y_t) = Var(y_{t-1}) = \dots = Var(x_{0})
$$

So the factor multiplying $Var(y_{t-1})$ must equal 1: 

$$
l_t Var(w_t) {1\over 2} = 1 \Rightarrow Var(w_t) = {2 \over l_t}
$$

If we use Leaky ReLU with negative slope $a$, the conclusion becomes: 

$$
Var(w_t) = {2 \over (1 + a^2) l_t}
$$
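A short sketch of where the factor $(1 + a^2)$ comes from, assuming Leaky ReLU is defined as $f(y) = y$ for $y \ge 0$ and $f(y) = a y$ for $y < 0$, and keeping the same assumptions as in the ReLU case (even density, $\mathbb{E}(y_{t-1}) = 0$). The negative half of the integral no longer vanishes; it is scaled by $a^2$:

$$
\mathbb{E}(x_t^2) = \int_{0}^{\infty} y_{t-1}^2 p(y_{t-1}) dy_{t-1} + a^2 \int_{-\infty}^{0} y_{t-1}^2 p(y_{t-1}) dy_{t-1} = {1 + a^2 \over 2} Var(y_{t-1})
$$

$$
\Rightarrow Var(y_t) = l_t Var(w_t) {1 + a^2 \over 2} Var(y_{t-1})
$$

Setting the factor $l_t Var(w_t) {1 + a^2 \over 2}$ to 1 gives the variance stated above, and $a = 0$ recovers the ReLU result.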

Do not forget that we assumed: 
- The distribution of $y_t$ is even, and $x_{0}$ is zero-mean.
- $w_{t}$ is zero-mean and independent of $x_t$.  
- The activation function is ReLU or Leaky ReLU.

So in practice we usually draw the weights as follows (the second parameter is the variance):

$$
w_t \sim \mathcal{N}(0, {2 \over l_t})
$$
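The sampling rule can be wrapped in a small helper. This is a minimal sketch, not a library API: the function name `kaiming_normal`, its parameters, and the `(fan_in, fan_out)` weight layout (chosen so that `x @ w` works) are assumptions made for illustration.

```python
import jax.numpy as jnp
from jax import random

def kaiming_normal(key, fan_in, fan_out, negative_slope=0.0):
    """Hypothetical helper: draw a (fan_in, fan_out) weight matrix with
    Var(w) = 2 / ((1 + a^2) * fan_in), where a is the Leaky ReLU slope."""
    std = (2.0 / ((1.0 + negative_slope**2) * fan_in)) ** 0.5
    return std * random.normal(key, (fan_in, fan_out))

key = random.PRNGKey(0)
k_relu, k_leaky = random.split(key)

w = kaiming_normal(k_relu, fan_in=512, fan_out=256)
w_leaky = kaiming_normal(k_leaky, fan_in=512, fan_out=256, negative_slope=0.1)

# Empirical variance next to the target 2 / ((1 + a^2) * fan_in).
print(w.shape)
print(float(jnp.var(w)), 2 / 512)
print(float(jnp.var(w_leaky)), 2 / (1.01 * 512))
```

With `negative_slope=0.0` the helper reduces to the plain ReLU case $Var(w_t) = 2 / l_t$. JAX also ships a ready-made initializer, `jax.nn.initializers.he_normal`, which may be preferable in real code.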

# Reference 

[1] [Delving Deep into Rectifiers (He et al., 2015)](https://arxiv.org/abs/1502.01852)

In [1]:
import jax.numpy as jnp
import numpy as np
from jax import random, grad, vmap, jit

key = random.PRNGKey(0)

np.set_printoptions(suppress=True, formatter={'float_kind': '{:8.2f}'.format})

# number of layers
num_layer = 30

In [2]:
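# NOTE: every call reuses the same `key`, so all `num_layer` weight matrices
# are identical. Use random.split(key, num_layer) to get independent layers.
# Kaiming scaling: std = sqrt(2 / fan_in), with fan_in = 3.
ws = [random.normal(key, (3, 3)) * jnp.sqrt(2 / 3) for _ in range(num_layer)]
bs = [jnp.zeros(1) for _ in range(num_layer)]

X = random.normal(key, (1000, 3))
res = X
# Forward pass: linear layer followed by ReLU, repeated num_layer times.
for w, b in zip(ws, bs):
    res = res @ w + b
    res = jnp.maximum(0, res)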

print(f'original : \n{X}')

original : 
[[    1.62     2.03    -0.43]
 [   -0.08     0.18    -0.97]
 [   -0.50     0.49     0.66]
 ...
 [    0.22    -0.35    -0.29]
 [    0.48     0.46    -0.48]
 [   -0.05    -0.21    -0.42]]


In [3]:
print(f'use kaiming: \n{res}')

use kaiming: 
[[  839.02  1280.97     0.00]
 [  114.80   175.27     0.00]
 [    0.00     0.00     0.00]
 ...
 [  169.26   258.41     0.00]
 [  313.43   478.53     0.00]
 [   46.43    70.89     0.00]]


In [4]:
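# Baseline: the same network with weights from N(0, 1), i.e. without the
# sqrt(2 / 3) factor. The same `key` is reused, so these are exactly the
# unscaled versions of the weights above.
ws = [random.normal(key, (3, 3)) for _ in range(num_layer)]
res = X
for w, b in zip(ws, bs):
    res = res @ w + b
    res = jnp.maximum(0, res)
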
print(f'not use kaiming: \n{res}')

not use kaiming: 
[[367400.50 560927.81     0.00]
 [50271.13 76751.32     0.00]
 [    0.00     0.00     0.00]
 ...
 [74116.68 113157.45     0.00]
 [137249.80 209545.78     0.00]
 [20330.89 31040.14     0.00]]
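The cells above reuse a single `key` for every layer, so all 30 weight matrices are identical and the layers are only 3 units wide. As a complementary check, here is a sketch that gives each layer its own key via `random.split` and tracks the mean of $y_t^2$ layer by layer. The helper name `forward_second_moments`, the width of 256, the batch size, and the reduced depth of 10 are assumptions; the depth is kept small so that the unscaled $\mathcal{N}(0, 1)$ baseline stays within float32 range at this width.

```python
import jax.numpy as jnp
from jax import random

def forward_second_moments(key, weight_std, width=256, num_layer=10, batch=256):
    """Hypothetical helper: push a random batch through `num_layer`
    linear + ReLU layers and return the mean of y^2 at every layer."""
    k_x, k_w = random.split(key)
    layer_keys = random.split(k_w, num_layer)  # one independent key per layer
    x = random.normal(k_x, (batch, width))
    moments = []
    for k in layer_keys:
        w = weight_std * random.normal(k, (width, width))
        y = x @ w                      # bias is zero, as in the cells above
        moments.append(float(jnp.mean(y**2)))
        x = jnp.maximum(0.0, y)        # ReLU
    return moments

key = random.PRNGKey(0)
width = 256
kaiming = forward_second_moments(key, (2.0 / width) ** 0.5, width)
naive = forward_second_moments(key, 1.0, width)

for t, (k_m, n_m) in enumerate(zip(kaiming, naive), start=1):
    print(f"layer {t:2d}   kaiming: {k_m:8.3f}   N(0, 1): {n_m:.3e}")
```

With the Kaiming scale, the second moment should stay of order 1 across layers (with some drift from the finite width), while with unit-variance weights it should grow by roughly a factor of $l_t / 2$ per layer, as the derivation predicts.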
