# Weight Initialization

## Variance Propagation in the Forward Pass
Given a layer $y = Wx$ followed by an activation (non-linear function), what variance should the entries of $W$ have so that activations don’t blow up or vanish as we stack layers?

We model one neuron in a fully connected layer as

$$y = \sum_{i=1}^{n} w_i x_i$$

where
* $n = \text{fan\_in}$,
* $x_i$ are i.i.d., $\mathbb{E}[x_i] = 0, \operatorname{Var}(x_i) = v_x$,
* $w_i$ are i.i.d., $\mathbb{E}[w_i] = 0, \operatorname{Var}(w_i) = v_w$,
* the $w_i$ are independent of the $x_i$.

By the independence and zero-mean assumptions,

$$
\operatorname{Var}(y)
= \operatorname{Var}\Big(\sum_{i=1}^n w_i x_i\Big)
= \sum_{i=1}^n \operatorname{Var}(w_i x_i)
= \sum_{i=1}^n \operatorname{Var}(w_i)\operatorname{Var}(x_i)
= n \, v_w \, v_x.
$$

So

$$
\operatorname{Var}(y) = \text{fan\_in} \cdot v_w \cdot v_x
$$

If we want $\operatorname{Var}(y) = v_x$, i.e. the variance is preserved from one layer to the next, we need $v_w =\frac{1}{\text{fan\_in}}$.
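
A quick Monte Carlo sketch of this result, using only Python's standard library. Drawing $w_i$ and $x_i$ from Gaussians is an assumption made for illustration; the derivation itself only needs zero means and independence. The function samples many independent neurons and returns the empirical $\operatorname{Var}(y)$, which should be close to $\text{fan\_in}\cdot v_w\cdot v_x$:

```python
import random
import statistics


def forward_variance(fan_in: int, v_w: float, v_x: float = 1.0,
                     num_samples: int = 2000, seed: int = 0) -> float:
    """Empirical Var(y) for y = sum_i w_i x_i, with w_i ~ N(0, v_w), x_i ~ N(0, v_x)."""
    rng = random.Random(seed)
    sd_w, sd_x = v_w ** 0.5, v_x ** 0.5
    ys = []
    for _ in range(num_samples):
        # Fresh weights and inputs for every sample, so the ys are i.i.d. draws of y.
        ys.append(sum(rng.gauss(0.0, sd_w) * rng.gauss(0.0, sd_x)
                      for _ in range(fan_in)))
    return statistics.pvariance(ys)


fan_in = 128
preserved = forward_variance(fan_in, v_w=1.0 / fan_in)
exploded = forward_variance(fan_in, v_w=1.0)
print(preserved)  # close to v_x = 1
print(exploded)   # close to fan_in * v_x = 128
```

With $v_w = 1$ instead of $1/\text{fan\_in}$, the variance grows by a factor of $\text{fan\_in}$ per layer, which is exactly the blow-up the initialization is meant to avoid.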


If we draw the weights uniformly, $w_i\sim \mathcal{U}(-b,b)$, then

$$v_w=\operatorname{Var}(w_i)=\int_{-b}^{b} t^2\, \frac{1}{2b}\,dt=\frac{b^2}{3},$$

so the bound that gives variance $v_w$ is

$$b=\sqrt{3v_w}.$$
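
A short numeric check of $\operatorname{Var}(\mathcal{U}(-b,b)) = b^2/3$ and of the inverse formula $b=\sqrt{3v_w}$. The helper name `uniform_bound` is hypothetical, chosen for illustration:

```python
import math
import random
import statistics


def uniform_bound(v_w: float) -> float:
    """Half-width b such that U(-b, b) has variance v_w (from v_w = b**2 / 3)."""
    return math.sqrt(3.0 * v_w)


fan_in = 100
v_w = 1.0 / fan_in       # forward-pass target variance
b = uniform_bound(v_w)   # sqrt(0.03), about 0.1732

rng = random.Random(0)
samples = [rng.uniform(-b, b) for _ in range(200_000)]
empirical_var = statistics.fmean(w * w for w in samples)  # E[w] = 0, so Var(w) = E[w^2]
print(b, empirical_var)  # empirical variance close to v_w = 0.01
```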

## Variance Propagation in the Backward Pass (Gradients)

Given a layer $y = W x$, how should the variance of the weight entries $W$ be chosen so that the gradients $\frac{\partial L}{\partial x}$ do not explode or vanish during backpropagation through layers?


We model the output of a fully connected layer as

$$y_j = \sum_{i=1}^{n} w_{j,i} x_i, \quad j=1,2,\dots,m$$

where
* $n = \text{fan\_in}$,
* $m = \text{fan\_out}$,

Let $L$ denote the loss function of the neural network. We assume:
* $\frac{\partial L}{\partial y_j}$ are i.i.d., $\mathbb{E}[\frac{\partial L}{\partial y_j}] = 0, \operatorname{Var}(\frac{\partial L}{\partial y_j}) = v_{\partial y}$,
* $w_{j,i}$ are i.i.d., $\mathbb{E}[w_{j,i}] = 0, \operatorname{Var}(w_{j,i}) = v_w$,
* the $w_{j,i}$ are independent of the $\frac{\partial L}{\partial y_j}$.

**Note:** Strictly speaking, $\frac{\partial L}{\partial y_j}$ is not independent of $w_{j,i}$: it depends on $W$ through $y$, which feeds the higher layers. Nevertheless, we treat them as independent to make the variance analysis tractable. This approximation is standard in both the Glorot & Bengio (2010) and He et al. (2015) derivations.

By the chain rule, the gradient with respect to the input $x_\ell$ is
$$
\begin{align*}
\frac{\partial L}{\partial x_\ell} & =\sum_{j=1}^{m} \frac{\partial L}{\partial y_j}\frac{\partial y_j}{\partial x_\ell}=\sum_{j=1}^{m} w_{j,\ell} \frac{\partial L}{\partial y_j}.
\end{align*}
$$

By the independence and zero-mean assumptions,
$$
\begin{align*}
\operatorname{Var}\left(\frac{\partial L}{\partial x_\ell}\right)
&= \operatorname{Var}\left(\sum_{j=1}^{m} w_{j,\ell} \frac{\partial L}{\partial y_j}\right),\\
&= \sum_{j=1}^{m} \operatorname{Var}\left(w_{j,\ell} \frac{\partial L}{\partial y_j}\right),\\
&= \sum_{j=1}^{m} \operatorname{Var}\left(w_{j,\ell}\right)\operatorname{Var}\left( \frac{\partial L}{\partial y_j}\right),\\
&= m \, v_w \, v_{\partial y}.
\end{align*}
$$

So

$$
\operatorname{Var}\left(\frac{\partial L}{\partial x_\ell}\right) = \text{fan\_out} \cdot v_w \cdot v_{\partial y}
$$

Hence, to keep $\operatorname{Var}\left(\frac{\partial L}{\partial x_\ell}\right) = v_{\partial y}$, we need $v_w =\frac{1}{\text{fan\_out}}$.
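
The same kind of check for the backward pass, as a sketch: it forms $\frac{\partial L}{\partial x} = W^\top \frac{\partial L}{\partial y}$ for a random $W$ of shape $\text{fan\_out}\times\text{fan\_in}$ and i.i.d. $\mathcal N(0,1)$ upstream gradients (so $v_{\partial y}=1$; the Gaussian choice is an assumption). Because $\text{fan\_in}\neq\text{fan\_out}$ here, it also shows what happens if the forward-pass rule is used instead:

```python
import math
import random
import statistics


def backward_grad_variance(fan_in: int, fan_out: int, v_w: float,
                           trials: int = 30, seed: int = 0) -> float:
    """Empirical Var(dL/dx_ell) with dL/dy_j ~ N(0, 1) i.i.d. and w_{j,i} ~ N(0, v_w)."""
    rng = random.Random(seed)
    sd_w = math.sqrt(v_w)
    grads = []
    for _ in range(trials):
        # W has shape (fan_out, fan_in): row j holds the weights of output y_j.
        W = [[rng.gauss(0.0, sd_w) for _ in range(fan_in)] for _ in range(fan_out)]
        dy = [rng.gauss(0.0, 1.0) for _ in range(fan_out)]
        # dL/dx_ell = sum_j w_{j,ell} * dL/dy_j, i.e. the ell-th entry of W^T dy.
        for ell in range(fan_in):
            grads.append(sum(W[j][ell] * dy[j] for j in range(fan_out)))
    return statistics.pvariance(grads)


fan_in, fan_out = 64, 256
matched = backward_grad_variance(fan_in, fan_out, v_w=1.0 / fan_out)
forward_rule = backward_grad_variance(fan_in, fan_out, v_w=1.0 / fan_in)
print(matched)       # close to v_dy = 1
print(forward_rule)  # close to fan_out / fan_in = 4: gradients grow
```

Only the column length $\text{fan\_out}$ enters each gradient entry, which is why $\text{fan\_in}$ plays no role in the backward condition.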

## Xavier/Glorot Initialization

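Unless $\text{fan\_in}=\text{fan\_out}$, no single $v_w$ satisfies both the forward condition $v_w=\frac{1}{\text{fan\_in}}$ and the backward condition $v_w=\frac{1}{\text{fan\_out}}$. Xavier/Glorot initialization compromises by using the average of the two fans:
$$v_w=\frac{1}{(\text{fan\_in}+\text{fan\_out})/2}=\frac{2}{\text{fan\_in}+\text{fan\_out}}$$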

If we draw uniformly, $w_{j,i}\sim \mathcal{U}(-b,b)$, the previous derivation gives
$$b=\sqrt{3v_w}=\sqrt{\frac{6}{\text{fan\_in}+\text{fan\_out}}}$$

If we draw from a normal distribution, $w_{j,i} \sim \mathcal{N}\left(0,\frac{2}{\text{fan\_in}+\text{fan\_out}}\right)$.
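
A small sketch of the resulting parameters for a $3\times5$ weight ($\text{fan\_in}=5$, $\text{fan\_out}=3$). The optional `gain` multiplies the standard deviation, as in the `xavier_uniform_` and `xavier_normal_` functions of the Code section; `xavier_params` itself is a hypothetical helper. The last two numbers show the compromise: a forward pass scales the variance by $\text{fan\_in}\cdot v_w$, a backward pass by $\text{fan\_out}\cdot v_w$, and neither factor is exactly 1:

```python
import math


def xavier_params(fan_in: int, fan_out: int, gain: float = 1.0) -> tuple[float, float, float]:
    """Return (variance, normal std, uniform bound) for Xavier/Glorot initialization."""
    v_w = gain ** 2 * 2.0 / (fan_in + fan_out)
    return v_w, math.sqrt(v_w), math.sqrt(3.0 * v_w)


fan_in, fan_out = 5, 3
v_w, std, b = xavier_params(fan_in, fan_out)
forward_factor = fan_in * v_w    # Var(y) / v_x per layer
backward_factor = fan_out * v_w  # Var(dL/dx) / v_dy per layer
print(v_w, std, b)                      # 0.25 0.5 0.8660254037844386
print(forward_factor, backward_factor)  # 1.25 0.75
```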

## Kaiming (He)

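For convenience, we express the uniform bound through a gain factor, $b=\text{gain}\sqrt{\frac{3}{\text{fan\_in}}}$. Since $v_w = b^2/3$, this is the same as $v_w = \frac{\text{gain}^2}{\text{fan\_in}}$, and $\text{gain}=1$ recovers the forward-pass condition $v_w=\frac{1}{\text{fan\_in}}$.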
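
A quick numeric check of the gain convention, using $\operatorname{Var}(\mathcal U(-b,b)) = b^2/3$ from the forward-pass section. The helper name is hypothetical, and $\sqrt 2$ is the ReLU gain obtained in the LeakyReLU section with $\alpha=0$:

```python
import math


def kaiming_uniform_bound(fan_in: int, gain: float) -> float:
    """Uniform half-width b = gain * sqrt(3 / fan_in)."""
    return gain * math.sqrt(3.0 / fan_in)


fan_in = 256
for name, gain in [("linear", 1.0), ("relu", math.sqrt(2.0))]:
    b = kaiming_uniform_bound(fan_in, gain)
    v_w = b ** 2 / 3.0  # variance of U(-b, b)
    print(f"{name}: b={b:.4f}  v_w={v_w:.6f}  gain^2/fan_in={gain ** 2 / fan_in:.6f}")
```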

## LeakyReLU

We assume:
* $y$ is symmetric: $P(y \in A)=P(y \in -A)$ for every Borel set $A \subseteq \mathbb{R}$, where $-A:=\{-x: x \in A\}$ (in particular $\mathbb{E}[y]=0$, so $\mathbb{E}[y^2]=\operatorname{Var}(y)$);
* $P(y = 0) = 0$.


Let $z = \operatorname{LeakyReLU}(y)$ with negative slope $\alpha$. We want
$$\operatorname{Var}(z) = \operatorname{Var}(\operatorname{LeakyReLU}(y))= \operatorname{Var}(y\mathbf{1}_{\{y>0\}}+\alpha y\mathbf{1}_{\{y\leq 0\}}).$$

$$
\begin{align*}
\operatorname{Var}(z)&= \mathbb{E}[z^2]- \mathbb{E}[z]^2
\end{align*}
$$

From the assumptions on $y$, we obtain
$$
\begin{align*}
\mathbb{E}[y^2\mathbf{1}_{\{y>0\}}]&=\int_{(0,\infty)} \tau^2dp_y(\tau),\\
&=\frac{1}{2}\left(\int_{(0,\infty)} \tau^2dp_y(\tau)+\int_{(0,\infty)} \tau^2dp_y(-\tau)\right),\\
&=\frac{1}{2}\left(\int_{(0,\infty)} \tau^2dp_y(\tau)+\int_{(-\infty,0)} (-\tau)^2dp_y(\tau)\right),\\
&=\frac{1}{2}\left(\int_{-\infty}^\infty \tau^2dp_y(\tau)\right),\\
&=\frac{1}{2}\mathbb{E}[y^2],\\
&=\frac{1}{2}\operatorname{Var}(y).
\end{align*}
$$

Notice that from the third equality we also have 
$$\mathbb{E}[y^2\mathbf{1}_{\{y>0\}}]=\mathbb{E}[y^2\mathbf{1}_{\{y\leq0\}}]=\frac{1}{2}\operatorname{Var}(y)$$

Then we can deduce
$$
\begin{align*}
\mathbb{E}[z^2]&=\mathbb{E}[y^2\mathbf{1}_{\{y>0\}}+\alpha^2 y^2\mathbf{1}_{\{y\leq 0\}}],\\
&=\mathbb{E}[y^2\mathbf{1}_{\{y>0\}}]+\alpha^2\mathbb{E}[ y^2\mathbf{1}_{\{y\leq 0\}}],\\
&=\frac{1}{2}\operatorname{Var}(y)+\alpha^2\frac{1}{2}\operatorname{Var}(y),\\
&=\frac{1+\alpha^2}{2}\operatorname{Var}(y)
\end{align*}
$$
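
A Monte Carlo sketch of this identity. It draws $y$ from two different symmetric distributions (standard normal and $\mathcal U(-1,1)$, both chosen only for illustration) to show that nothing beyond symmetry and $P(y=0)=0$ is needed, and compares $\mathbb{E}[z^2]/\operatorname{Var}(y)$ with $(1+\alpha^2)/2$:

```python
import random
from collections.abc import Callable


def leaky_relu(y: float, alpha: float) -> float:
    return y if y > 0 else alpha * y


def draw_gaussian(rng: random.Random) -> float:
    return rng.gauss(0.0, 1.0)


def draw_uniform(rng: random.Random) -> float:
    return rng.uniform(-1.0, 1.0)


def second_moment_ratio(alpha: float, draw: Callable[[random.Random], float],
                        num_samples: int = 100_000, seed: int = 0) -> float:
    """Estimate E[z^2] / Var(y) for z = LeakyReLU(y) and a symmetric y."""
    rng = random.Random(seed)
    ys = [draw(rng) for _ in range(num_samples)]
    e_z2 = sum(leaky_relu(y, alpha) ** 2 for y in ys) / num_samples
    var_y = sum(y * y for y in ys) / num_samples  # E[y] = 0 by symmetry
    return e_z2 / var_y


ratios = {}
for alpha in (0.0, 0.2):
    for draw in (draw_gaussian, draw_uniform):
        ratios[(alpha, draw.__name__)] = second_moment_ratio(alpha, draw)
        print(alpha, draw.__name__, round(ratios[(alpha, draw.__name__)], 3),
              "expected", (1 + alpha ** 2) / 2)
```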


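We drop the second term $\mathbb{E}[z]^2$. It is not negligible in general: for Gaussian $y$ and $\text{ReLU}$, $\mathbb{E}[z]^2=\operatorname{Var}(y)/(2\pi)\approx 0.16\operatorname{Var}(y)$. Dropping it is nevertheless harmless for the next layer, which only sees the second moment of $z$: if its weights $w_i$ have zero mean and are independent of $z$, then $\operatorname{Var}(w_i z_i)=v_w\,\mathbb{E}[z_i^2]$ even when $\mathbb{E}[z_i]\neq 0$. This is the quantity He et al. (2015) propagate. Without any distributional assumption we can still bound the dropped term; for example, in the $\text{ReLU}$ case ($\alpha=0$):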
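$$
\begin{align*}
(\mathbb{E}[z])^2 &= \big(\mathbb{E}[y \mathbf{1}_{\{y>0\}}]\big)^2,\\
&= \big(\mathbb{E}[(y \mathbf{1}_{\{y>0\}})\,\mathbf{1}_{\{y>0\}}]\big)^2,\\
&\le \mathbb{E}[y^2\mathbf{1}_{\{y>0\}}] \cdot \mathbb{E}[\mathbf{1}^2_{\{y>0\}}],\\
&= \mathbb{E}[y^2\mathbf{1}_{\{y>0\}}] \cdot \mathbb{E}[\mathbf{1}_{\{y>0\}}],\\
&= \frac{1}{2} \operatorname{Var}(y) \cdot \frac{1}{2}\left(1 - P(y =0)\right),\\
&= \frac{1}{4} \operatorname{Var}(y),
\end{align*}
$$
where the second and fourth lines use $\mathbf{1}_{\{y>0\}}^2=\mathbf{1}_{\{y>0\}}$, the inequality is the Cauchy–Schwarz inequality applied to $y\mathbf{1}_{\{y>0\}}$ and $\mathbf{1}_{\{y>0\}}$, the fifth line uses $\mathbb{E}[y^2\mathbf{1}_{\{y>0\}}]=\frac12\operatorname{Var}(y)$ from above together with the symmetry of $y$ as follows, and the last line uses $P(y=0)=0$:
$$
\begin{align*}
\mathbb{E}[\mathbf{1}_{\{y>0\}}]&=\int_{(0,\infty)} dp_y(\tau)=P(y \in (0,\infty)),\\
&=\frac{1}{2}\left(P(y \in (-\infty,0)) + P(y \in (0,\infty))\right),\\
&=\frac{1}{2}\left(1 - P(y =0)\right).
\end{align*}
$$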

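Combining this bound with $\operatorname{Var}(z)=\mathbb{E}[z^2]-\mathbb{E}[z]^2$ and $\mathbb{E}[z^2]=\frac12\operatorname{Var}(y)$ (the case $\alpha=0$ above), we get
$$
\frac14 \operatorname{Var}(y) \;\le\; \operatorname{Var}(\operatorname{ReLU}(y)) \;\le\; \frac12 \operatorname{Var}(y)
$$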
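
For Gaussian $y$ the exact value is $\left(\frac12-\frac{1}{2\pi}\right)\operatorname{Var}(y)\approx 0.34\operatorname{Var}(y)$, strictly inside these bounds. A short Monte Carlo sketch (standard normal $y$ is an assumption made only for this check):

```python
import math
import random
import statistics

rng = random.Random(0)
ys = [rng.gauss(0.0, 1.0) for _ in range(200_000)]  # Var(y) = 1
zs = [max(0.0, y) for y in ys]                       # z = ReLU(y)

mean_z = statistics.fmean(zs)
var_z = statistics.fmean((z - mean_z) ** 2 for z in zs)
exact = 0.5 - 1.0 / (2.0 * math.pi)  # E[z^2] - E[z]^2 for standard normal y

print(round(var_z, 3), round(exact, 3))  # both should be near 0.34, inside [0.25, 0.5]
```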

The Kaiming initialization takes the approximation 
$$\operatorname{Var}[\max(0,y)]\approx \frac12\operatorname{Var}(y)$$

So we conclude

$$
\begin{align*}
\operatorname{Var}[z]\approx \frac{1+\alpha^2}{2}\operatorname{Var}(y) =\frac{1+\alpha^2}{2}\text{fan\_in} \cdot v_w \cdot v_x
\end{align*}
$$

To obtain $\operatorname{Var}[z]\approx v_x$, we therefore need
$v_w=\frac{2}{(1+\alpha^2)\text{fan\_in}}$, and the corresponding uniform bound is

$$b=\sqrt{3v_w}=\sqrt{\frac{6}{(1+\alpha^2)\text{fan\_in}}}$$

**Note:** In this case we have $\text{gain} = \sqrt{\frac{2}{1+\alpha^2}}$.
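
To see why the factor $(1+\alpha^2)/2$ matters, here is a small pure-Python sketch that stacks fully connected LeakyReLU layers with Gaussian weights. The width 64, depth 8, slope $\alpha=0.1$ and Gaussian sampling are arbitrary choices for illustration. It records the mean of $z^2$ after each layer, the second moment that the next layer actually propagates. With $v_w=\frac{2}{(1+\alpha^2)\text{fan\_in}}$ it stays of order one, while $v_w = \frac{1}{\text{fan\_in}}$ shrinks it by $(1+\alpha^2)/2$ per layer:

```python
import math
import random


def leaky_relu(v: float, alpha: float) -> float:
    return v if v > 0 else alpha * v


def mean_square_per_layer(v_w: float, alpha: float, width: int = 64, depth: int = 8,
                          batch: int = 16, seed: int = 0) -> list[float]:
    """Propagate N(0, 1) inputs through `depth` layers z = LeakyReLU(W x) and
    return the mean of z**2 after each layer."""
    rng = random.Random(seed)
    sd_w = math.sqrt(v_w)
    xs = [[rng.gauss(0.0, 1.0) for _ in range(width)] for _ in range(batch)]
    history = []
    for _ in range(depth):
        W = [[rng.gauss(0.0, sd_w) for _ in range(width)] for _ in range(width)]
        # One fully connected layer (fan_in = fan_out = width) followed by LeakyReLU.
        xs = [[leaky_relu(sum(w * x for w, x in zip(row, x_vec)), alpha) for row in W]
              for x_vec in xs]
        history.append(sum(x * x for x_vec in xs for x in x_vec) / (batch * width))
    return history


alpha, fan_in = 0.1, 64
kaiming = mean_square_per_layer(2.0 / ((1 + alpha ** 2) * fan_in), alpha)
naive = mean_square_per_layer(1.0 / fan_in, alpha)
print([round(m, 3) for m in kaiming])  # stays of order 1
print([round(m, 4) for m in naive])    # decays like ((1 + alpha**2) / 2) ** layer
```

Both runs use the same seed, so they see the same underlying random numbers; because LeakyReLU is positively homogeneous, the naive run is exactly the Kaiming run rescaled by $0.505$ per layer.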

### Full Gain Table

| Nonlinearity / Init | Gain / Formula | Origin | Type | Notes |
|----------------------|----------------|---------|------|--------|
| **relu** | √2 | He et al. (2015) | ✅ Theoretical | Derived for rectifiers; compensates for halved variance. |
| **leaky_relu(a)** | √(2 / (1 + a²)) | He et al. (2015), generalized | ✅ Theoretical | Generalization of He for nonzero negative slope. |
| **tanh** | 5/3 | Torch/PyTorch convention | ⚠️ Heuristic | Empirically tuned for stable gradients; no closed form. |
| **sigmoid** | 1.0 | Glorot & Bengio (2010) | ⚠️ Heuristic | Safe default; avoids saturation by keeping small weights. |
| **linear / conv\*** | 1.0 | Identity mapping | ✅ Trivial | No nonlinearity → no scaling required. |
| **elu(α)** | ≈√2 | Clevert et al. (2015) | ⚠️ Heuristic | Negative branch saturates; ReLU-like behavior empirically fits √2. |
| **selu** | LeCun Normal (σ² = 1 / fan_in) | Klambauer et al. (2017) | ✅ Theoretical | λ, α constants chosen for mean=0, var=1 fixed point; self-normalizing. |
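| **gelu** | ≈√2 | Hendrycks & Gimpel (2016) | ⚠️ Empirical | For y ~ N(0, 1), the second moment E[gelu(y)²] = 1/3 + 1/(2π√3) ≈ 0.425 (not the variance) → gain ≈ 1/√0.425 ≈ 1.53, somewhat above √2 ≈ 1.41; smooth ReLU-like. |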
| **xavier / glorot** | σ² = 2 / (fan_in + fan_out) | Glorot & Bengio (2010) | ✅ Theoretical | Balances forward & backward variance for tanh/sigmoid. |
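
The theoretical ReLU/LeakyReLU rows and the GELU estimate can be reproduced with a sketch under one explicit assumption: pre-activations are standard normal. Under that assumption a He-style gain is $1/\sqrt{\mathbb{E}[f(y)^2]}$. This is an illustration of the reasoning, not the definition used by any particular library; `estimate_gain` is a hypothetical helper:

```python
import math
import random
from collections.abc import Callable


def gelu(y: float) -> float:
    """Exact GELU: y * Phi(y), with Phi the standard normal CDF."""
    return y * 0.5 * (1.0 + math.erf(y / math.sqrt(2.0)))


def estimate_gain(f: Callable[[float], float], num_samples: int = 200_000,
                  seed: int = 0) -> float:
    """He-style gain 1 / sqrt(E[f(y)^2]), assuming pre-activations y ~ N(0, 1)."""
    rng = random.Random(seed)
    second_moment = sum(f(rng.gauss(0.0, 1.0)) ** 2 for _ in range(num_samples)) / num_samples
    return 1.0 / math.sqrt(second_moment)


a = 0.2  # LeakyReLU negative slope, arbitrary for this example
gains = {
    "relu": estimate_gain(lambda y: max(0.0, y)),
    "leaky_relu": estimate_gain(lambda y: y if y > 0 else a * y),
    "gelu": estimate_gain(gelu),
}
for name, g in gains.items():
    print(f"{name}: {g:.3f}")
```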

In [63]:
import torch
import torch.nn as nn
import math

## Code

In [64]:
tensor = torch.tensor([[[1,2,3,4],[4,5,6,2]]],dtype=torch.float32)
print(f'{tensor.dim()=}')
print(f'{tensor.size()=}')
print(f'{tensor.size(1)=}')
print(f'{tensor.numel()=}')

tensor.dim()=3
tensor.size()=torch.Size([1, 2, 4])
tensor.size(1)=2
tensor.numel()=8


In [65]:
def calculate_fan_in_and_fan_out(tensor: torch.Tensor):
    dimensions = tensor.dim()
    if dimensions < 2:
        raise ValueError("Fan in/out can not be computed for tensor with fewer than 2 dimensions")

    num_input_fmaps = tensor.size(1)
    num_output_fmaps = tensor.size(0)
    receptive_field_size = 1
    if tensor.dim() > 2:
        # convolutional weights e.g. [out_channels, in_channels, kH, kW, ...]
        receptive_field_size = tensor[0][0].numel()
    fan_in = num_input_fmaps * receptive_field_size
    fan_out = num_output_fmaps * receptive_field_size
    return fan_in, fan_out

def calculate_gain(nonlinearity: str, a: float = 0.0):
    nonlinearity = nonlinearity.lower()
    match nonlinearity:
        case 'sigmoid' | 'linear' | 'conv1d' | 'conv2d' | 'conv3d':
            return 1.0
        case 'tanh':
            return 5.0 / 3  # heuristic value, the same one PyTorch uses
        case 'relu':
            return math.sqrt(2.0)
        case 'leaky_relu':
            return math.sqrt(2.0 / (1 + a*a))
        case _:
            raise ValueError(f"Unsupported nonlinearity {nonlinearity}")
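
To make the fan convention concrete without PyTorch, here is a hypothetical helper `fans_from_shape` that applies the same rule as `calculate_fan_in_and_fan_out` to a bare shape tuple. For a convolution weight laid out as `(out_channels, in_channels, *kernel_size)`, both fans are multiplied by the number of kernel positions:

```python
import math


def fans_from_shape(shape: tuple[int, ...]) -> tuple[int, int]:
    """Same rule as calculate_fan_in_and_fan_out, applied to a shape tuple."""
    if len(shape) < 2:
        raise ValueError("Fan in/out can not be computed for fewer than 2 dimensions")
    receptive_field_size = math.prod(shape[2:])  # empty product = 1 for a 2-D weight
    return shape[1] * receptive_field_size, shape[0] * receptive_field_size


print(fans_from_shape((3, 5)))         # (5, 3): a Linear(5, 3) weight
print(fans_from_shape((16, 3, 3, 3)))  # (27, 144): Conv2d(3, 16, kernel_size=3) weight
```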

In [66]:
def xavier_uniform_(tensor: torch.Tensor,
                    gain: float = 1.0):
    """
    Fills the input `tensor` with values drawn from U(-bound, bound)
    according to Xavier/Glorot uniform initialization.

    Reference:
    Glorot & Bengio (2010), "Understanding the difficulty of training deep feedforward neural networks".
    """

    # 1) Compute fan_in and fan_out
    fan_in, fan_out = calculate_fan_in_and_fan_out(tensor)

    # 2) Compute the uniform bound b = gain * sqrt(6 / (fan_in + fan_out))
    bound = gain * math.sqrt(6/(fan_in + fan_out))

    # 3) Fill tensor
    with torch.no_grad():
        return tensor.uniform_(-bound, bound)

In [67]:
def xavier_normal_(tensor: torch.Tensor,
                   gain: float = 1.0):
    """
    Fills the input `tensor` with values drawn from N(0, std^2)
    according to Xavier/Glorot normal initialization.

    Reference:
    Glorot & Bengio (2010), "Understanding the difficulty of training deep feedforward neural networks".
    """

    # 1) Compute fan_in and fan_out
    fan_in, fan_out = calculate_fan_in_and_fan_out(tensor)

    # 2) Compute standard deviation
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))

    # 3) Fill tensor
    with torch.no_grad():
        return tensor.normal_(0.0, std)

In [68]:
def kaiming_uniform_(tensor: torch.Tensor,
                     a: float = 0.0,
                     mode: str = 'fan_in',
                     nonlinearity: str = 'leaky_relu'):
    """
    Fills the input `tensor` with values drawn from U(-bound, bound)
    according to Kaiming/He initialization.
    """

    # 1) Compute fan_in and fan_out
    fan_in, fan_out = calculate_fan_in_and_fan_out(tensor)

    # 2) Choose correct fan
    if mode == 'fan_in':
        fan = fan_in
    elif mode == 'fan_out':
        fan = fan_out
    else:
        raise ValueError("mode should be 'fan_in' or 'fan_out'")

    # 3) Compute gain and bound
    gain = calculate_gain(nonlinearity, a)
    bound = math.sqrt(3.0/fan) * gain

    # 4) Fill tensor
    with torch.no_grad():
        return tensor.uniform_(-bound, bound)

## Testing

In [72]:
gain = 3.0

#### Xavier Uniform

In [79]:
tensor = torch.empty(3, 5)
nn_tensor = torch.empty(3, 5)

torch.manual_seed(0)
nn.init.xavier_uniform_(tensor=nn_tensor, gain=gain)
print(f'After init (nn): {nn_tensor}')

torch.manual_seed(0)
xavier_uniform_(tensor=tensor, gain=gain)
print(f'After init (custom): {tensor}')

After init (nn): tensor([[-0.0195,  1.3937, -2.1383, -1.9120, -1.0007],
        [ 0.6967, -0.0515,  2.0600, -0.2306,  0.6875],
        [-0.7852, -0.5107, -2.4821, -1.7207, -1.0710]])
After init (custom): tensor([[-0.0195,  1.3937, -2.1383, -1.9120, -1.0007],
        [ 0.6967, -0.0515,  2.0600, -0.2306,  0.6875],
        [-0.7852, -0.5107, -2.4821, -1.7207, -1.0710]])


#### Xavier Normal

In [77]:
tensor = torch.empty(3, 5)
nn_tensor = torch.empty(3, 5)

torch.manual_seed(0)
nn.init.xavier_normal_(tensor=nn_tensor, gain=gain)
print(f'After init (nn): {nn_tensor}')

torch.manual_seed(0)
xavier_normal_(tensor=tensor, gain=gain)
print(f'After init (custom): {tensor}')

After init (nn): tensor([[ 2.3115, -0.4401, -3.2682,  0.8526, -1.6268],
        [-2.0979,  0.6050,  1.2570, -1.0789, -0.6050],
        [-0.8950,  0.2731, -1.2850,  1.6509, -1.6068]])
After init (custom): tensor([[ 2.3115, -0.4401, -3.2682,  0.8526, -1.6268],
        [-2.0979,  0.6050,  1.2570, -1.0789, -0.6050],
        [-0.8950,  0.2731, -1.2850,  1.6509, -1.6068]])


#### Kaiming

In [71]:
nonlinearity='sigmoid'
a=3

In [78]:
tensor = torch.empty(3, 5)
nn_tensor = torch.empty(3, 5)

torch.manual_seed(0)
nn.init.kaiming_uniform_(tensor=nn_tensor, a=a, mode='fan_out', nonlinearity=nonlinearity)
print(f'After init (nn): {nn_tensor}')

torch.manual_seed(0)
kaiming_uniform_(tensor=tensor, a=a, mode='fan_out', nonlinearity=nonlinearity)
print(f'After init (custom): {tensor}')


After init (nn): tensor([[-0.0075,  0.5364, -0.8230, -0.7359, -0.3852],
        [ 0.2682, -0.0198,  0.7929, -0.0887,  0.2646],
        [-0.3022, -0.1966, -0.9553, -0.6623, -0.4122]])
After init (custom): tensor([[-0.0075,  0.5364, -0.8230, -0.7359, -0.3852],
        [ 0.2682, -0.0198,  0.7929, -0.0887,  0.2646],
        [-0.3022, -0.1966, -0.9553, -0.6623, -0.4122]])
