# Parameter Initialization

Now that we know how to access the parameters,
let's look at how to initialize them properly.
We discussed the need for proper initialization in :numref:`sec_numerical_stability`.
The deep learning framework provides default random initializations to its layers.
However, we often want to initialize our weights
according to various other protocols. The framework provides the most commonly
used protocols, and also allows us to create a custom initializer.


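By default, Flax's `nn.Dense` initializes its weight matrix (called `kernel` in Flax)
with the LeCun normal initializer, `nn.initializers.lecun_normal()`, whose variance is scaled by the layer's input dimension,
and initializes its bias to zero.
The `flax.linen.initializers` module, which re-exports `jax.nn.initializers`, provides a variety
of preset initialization methods.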


```python
import jax
from jax import numpy as jnp, random
import flax.linen as nn

net = nn.Sequential([nn.Dense(8), nn.relu, nn.Dense(1)])
X = random.uniform(random.PRNGKey(0), (2, 4))
output, params = net.init_with_output(random.PRNGKey(1), X)
output.shape
```

## [**Built-in Initialization**]

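Let's begin by calling on built-in initializers.
The code below initializes all weight parameters
as Gaussian random variables
with standard deviation 0.01, while bias parameters are cleared to zero.
We use `functools.partial` to fix these initializers once,
so that `DenseInit` can be used wherever `nn.Dense` would be.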


```python
from functools import partial

DenseInit = partial(nn.Dense,
                    kernel_init=nn.initializers.normal(0.01),
                    bias_init=nn.initializers.zeros)

net = nn.Sequential([DenseInit(8), nn.relu, DenseInit(1)])
output, params = net.init_with_output(random.PRNGKey(2), X)
params['params']['layers_0']['kernel'][0], params['params']['layers_0']['bias'][0]
```

```
(DeviceArray([-6.9597606e-03, -7.1280561e-03, -5.9801121e-05,
               1.4234671e-02,  5.1514809e-03,  2.1050077e-02,
              -4.5213206e-03, -4.1506551e-03], dtype=float32),
 DeviceArray(0., dtype=float32))
```

We can also initialize all weight parameters
to a given constant value (say, 1), while the biases are again set to zero.


```python
DenseConstant = partial(nn.Dense,
                        kernel_init=nn.initializers.constant(1),
                        bias_init=nn.initializers.zeros)

net = nn.Sequential([DenseConstant(8), nn.relu, DenseConstant(1)])
output, params = net.init_with_output(random.PRNGKey(3), X)
params['params']['layers_0']['kernel'][0], params['params']['layers_0']['bias'][0]
```

```
(DeviceArray([1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32),
 DeviceArray(0., dtype=float32))
```

[**We can also apply different initializers for certain blocks.**]
For example, below we initialize the first dense layer
with the Xavier initializer
and initialize the weights of the second dense layer
to a constant value of 42.


```python
net = nn.Sequential([nn.Dense(8, kernel_init=nn.initializers.xavier_uniform()),
                     nn.relu,
                     nn.Dense(8, kernel_init=nn.initializers.constant(42))])

params = net.init(random.PRNGKey(4), X)

print(params['params']['layers_0']['kernel'][0])
print(params['params']['layers_2']['kernel'][0])
```

```
[ 0.6330256   0.31407022 -0.6485304   0.6649149  -0.2651401   0.05750887
  0.3933842   0.5046952 ]
[42. 42. 42. 42. 42. 42. 42. 42.]
```


### [**Custom Initialization**]

Sometimes, the initialization methods we need
are not provided by the deep learning framework.
In the example below, we define an initializer
for any weight parameter $w$ using the following strange distribution:

$$
\begin{aligned}
    w \sim \begin{cases}
        U(5, 10) & \text{ with probability } \frac{1}{4} \\
            0    & \text{ with probability } \frac{1}{2} \\
        U(-10, -5) & \text{ with probability } \frac{1}{4}
    \end{cases}
\end{aligned}
$$
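The `my_init` function in this section does not sample this mixture branch by branch. Instead it draws $u \sim U(-10, 10)$ and zeroes every entry with $|u| < 5$. The following calculation shows why this yields the distribution above:

$$
\begin{aligned}
w &= u \cdot \mathbb{1}[\,|u| \ge 5\,], \quad u \sim U(-10, 10), \\
P(w \in [5, 10]) &= P(u \in [5, 10]) = \tfrac{10 - 5}{20} = \tfrac{1}{4}, \\
P(w = 0) &= P(|u| < 5) = \tfrac{5 - (-5)}{20} = \tfrac{1}{2}, \\
P(w \in [-10, -5]) &= P(u \in [-10, -5]) = \tfrac{1}{4}.
\end{aligned}
$$

Conditioned on $u \in [5, 10]$, $u$ is uniform on that interval, so the positive nonzero entries follow $U(5, 10)$; the negative branch is symmetric.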


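We implement a `my_init` function and pass it to `nn.Dense` as `kernel_init`. Flax calls an initializer with a PRNG key, the parameter shape, and the parameter dtype, so `my_init` must accept those three arguments.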


```python
def my_init(key, shape, dtype):
    # Flax calls initializers as init(key, shape, dtype).
    print('Init', shape, dtype)
    # Draw u ~ U(-10, 10), then zero out every entry with |u| < 5.
    data = random.uniform(key, shape, dtype, minval=-10, maxval=10)
    factor = (jnp.abs(data) >= 5).astype(dtype)
    return data * factor

net = nn.Sequential([nn.Dense(4, kernel_init=my_init), nn.relu, nn.Dense(1)])
params = net.init(random.PRNGKey(5), X)
params['params']['layers_0']['kernel'][:2]
```

```
Init (4, 4) <class 'jax.numpy.float32'>

DeviceArray([[-0.       , -0.       , -9.521742 , -7.9924726],
             [ 8.898151 , -0.       , -0.       ,  9.311113 ]],            dtype=float32)
```
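A handful of printed entries cannot confirm the mixture probabilities. The sketch below reuses the same masking idea on a larger, arbitrarily chosen shape and measures how often each branch occurs. It drops the `print` and gives `dtype` a default value for convenience; both are changes made here for illustration.

```python
import jax
import jax.numpy as jnp

def my_init(key, shape, dtype=jnp.float32):
    # Same masking trick as above: keep only entries with |u| >= 5.
    data = jax.random.uniform(key, shape, dtype, minval=-10, maxval=10)
    return data * (jnp.abs(data) >= 5).astype(dtype)

w = my_init(jax.random.PRNGKey(5), (1000, 1000))

frac_zero = float(jnp.mean(w == 0))
frac_pos = float(jnp.mean(w >= 5))
frac_neg = float(jnp.mean(w <= -5))
print(frac_zero, frac_pos, frac_neg)  # roughly 1/2, 1/4, 1/4
```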

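Note that we always have the option
of setting parameters directly.
Because JAX arrays are immutable, this means building a new array
with the functional `.at[...]` update syntax
and writing it back into a mutable copy of the parameter dictionary.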


```python
from flax.core import unfreeze, freeze

# Make a mutable copy of the parameter tree.
params = unfreeze(params)
kernel = params['params']['layers_0']['kernel']
# JAX arrays are immutable: `.at[...]` returns an updated copy.
kernel = kernel.at[:].add(1)      # add 1 to every entry
kernel = kernel.at[0, 0].set(42)  # overwrite a single entry

# Write the new array back and freeze the tree again.
params['params']['layers_0']['kernel'] = kernel
params = freeze(params)

params['params']['layers_0']['kernel']
```

```
DeviceArray([[42.       ,  1.       , -8.521742 , -6.9924726],
             [ 9.898151 ,  1.       ,  1.       , 10.311113 ],
             [ 1.       ,  1.       ,  1.       ,  1.       ],
             [-5.4702964,  1.       , 10.669337 ,  8.386413 ]],            dtype=float32)
```

## Summary

We can initialize parameters using built-in and custom initializers.

## Exercises

Look up the online documentation for more built-in initializers.


[Discussions](https://discuss.d2l.ai/t/8090)
