In [1]:
# Reload imported modules automatically before executing each cell.
%load_ext autoreload
%autoreload 2

In [2]:
import torch.nn.functional as F
from torch import nn
from torch.nn import init
from torchvision import transforms
from torchvision.datasets import MNIST

### Where does the default parameter `a=math.sqrt(5)` for `kaiming_uniform_` come from?

The default weight init for `Conv2d` layers in PyTorch is Kaiming uniform, called with a negative slope of `a=math.sqrt(5)`.

In [3]:
# `??` displays the source of the conv layers' parameter init (`_ConvNd.reset_parameters`).
nn.modules.conv._ConvNd.reset_parameters??

Signature: nn.modules.conv._ConvNd.reset_parameters(self) -> None
Docstring: <no docstring>
Source:   
    def reset_parameters(self) -> None:
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(k), 1/sqrt(k)), where k = weight.size(1) * prod(*kernel_size)
        # For more details see: https://github.com/pytorch/pytorch/issues/15314#issuecomment-477448573
        init.kaiming_uniform_(self.weight, a=math.sqrt(

In [4]:
# `??` displays the signature and source of `kaiming_uniform_` (note its own default is `a=0`).
nn.init.kaiming_uniform_??

Signature:
nn.init.kaiming_uniform_(
    tensor: torch.Tensor,
    a: float = 0,
    mode: str = 'fan_in',
    nonlinearity: str = 'leaky_relu',
)
Source:   
def kaiming_uniform_(
    tensor: Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'

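Interestingly, the default value assumed for the negative slope of the rectifier, `a=math.sqrt(5)`, seems arbitrary. Where does it come from? The comment in `reset_parameters` hints at the answer: it reproduces an older init scheme. `kaiming_uniform_` samples from $\mathcal{U}(-b, b)$ with $b = \text{gain} \cdot \sqrt{3 / \text{fan\_in}}$, where $\text{gain} = \sqrt{2 / (1 + a^2)}$. With $a = \sqrt{5}$ the gain is $\sqrt{2/6} = 1/\sqrt{3}$, so $b = 1/\sqrt{\text{fan\_in}}$. That is exactly the `uniform(-1/sqrt(k), 1/sqrt(k))` mentioned in the comment. Does it work well in practice? Let's find out...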

First, let's prepare a dataset we can use to test our assumptions. In this case, we are going to use the popular MNIST dataset.

In [5]:
# MNIST training split (downloaded on first run). The ToTensor transform only
# applies when indexing `dataset[i]`; the cells below read the raw `dataset.data`
# tensor directly and rescale it themselves.
data_root = '/workspace/data/'  # NOTE(review): absolute path; consider making it configurable
dataset = MNIST(data_root, train=True, download=True, transform=transforms.ToTensor())
dataset

Dataset MNIST
    Number of datapoints: 60000
    Root location: /workspace/data/
    Split: Train
    StandardTransform
Transform: ToTensor()

In [6]:
# Keep the first n_train images for training and hold out the rest for validation.
n_train = 50_000
n_valid = dataset.data.size(0) - n_train

# Flatten each image into a single row and rescale raw pixel values (0-255) to [0, 1].
x_train = dataset.data[:n_train].reshape(n_train, -1) / 255
x_valid = dataset.data[n_train:].reshape(n_valid, -1) / 255
y_train = dataset.targets[:n_train]
y_valid = dataset.targets[n_train:]

In [7]:
# Global (all-pixel) statistics of the training set, reused to normalize both splits.
train_mean = x_train.mean()
train_std = x_train.std()
train_mean, train_std

(tensor(0.1310), tensor(0.3085))

In [8]:
def normalize(x, m, s):
    """Standardize ``x`` with mean ``m`` and standard deviation ``s``.

    Computes ``(x - m) / s`` element-wise (broadcasting applies, so ``m`` and
    ``s`` may be scalars or tensors) and returns the result without modifying
    ``x``.
    """
    centered = x - m
    return centered / s

In [9]:
# Standardize both splits with the *training* statistics so no validation information leaks in.
# NOTE: this rebinds x_train/x_valid, so re-running this cell alone would normalize them twice.
x_train, x_valid = [normalize(split, train_mean, train_std) for split in (x_train, x_valid)]

Since we are going to use an architecture with convolutional layers, we need the data in `b,c,h,w` (batch, channels, height, width) format, which is the layout PyTorch's `Conv2d` expects. Images in the MNIST dataset are 28x28 grayscale images, so each one has a single channel. Thus, the training set should have a shape of `(50_000, 1, 28, 28)`.

In [10]:
# Restore the image layout expected by Conv2d: (batch, channels, height, width).
img_shape = (1, 28, 28)
x_train = x_train.view(-1, *img_shape)
x_valid = x_valid.view(-1, *img_shape)
x_train.shape, x_valid.shape

(torch.Size([50000, 1, 28, 28]), torch.Size([10000, 1, 28, 28]))

In [11]:
n = x_train.shape[0]   # number of training images
c = y_train.max() + 1  # number of classes (assumes labels are 0..c-1)
nh = 32                # number of conv filters (output channels)
n, c

(50000, tensor(10))

Let's create a `Conv2d` layer, and inspect its statistics.

In [12]:
# Single input channel (grayscale) -> nh output channels, 5x5 kernels, default PyTorch init.
l1 = nn.Conv2d(in_channels=1, out_channels=nh, kernel_size=5)

In [13]:
# Weight shape is (out_channels, in_channels, kH, kW); the bias has one entry per output channel.
l1.weight.shape, l1.bias.shape

(torch.Size([32, 1, 5, 5]), torch.Size([32]))

The `Conv2d` layer we just created has 32 single-channel 5x5 filters, and we have one bias unit for every filter.

In [14]:
def stats(x):
    """Return ``(mean, std)`` computed over all elements of ``x``.

    For a tensor input both values are 0-d tensors; ``std`` uses the
    tensor's default ``std()`` (unbiased) estimator.
    """
    mean = x.mean()
    std = x.std()
    return mean, std

In [15]:
# (mean, std) of the default-initialized weight, then of the bias.
tuple(stats(param) for param in (l1.weight, l1.bias))

((tensor(-0.0039, grad_fn=<MeanBackward0>),
  tensor(0.1187, grad_fn=<StdBackward0>)),
 (tensor(-0.0105, grad_fn=<MeanBackward0>),
  tensor(0.1206, grad_fn=<StdBackward0>)))

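The randomly initialized weights and biases are centered at zero, with a std of around 0.12. This is what the default init predicts: with $\text{fan\_in} = 1 \cdot 5 \cdot 5 = 25$, both are drawn from $\mathcal{U}(-1/\sqrt{25}, 1/\sqrt{25}) = \mathcal{U}(-0.2, 0.2)$, whose std is $0.2/\sqrt{3} \approx 0.115$.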

What about the output activation? Let's try to pass the first 100 images of the validation set through the `Conv2d` layer...

In [16]:
# Probe batch used to measure activation statistics: the first 100 validation images.
n_probe = 100
x = x_valid[:n_probe]
x.shape

torch.Size([100, 1, 28, 28])

In [17]:
# Conv output (no activation) for the probe batch, using the default init.
t = l1(x)

In [18]:
# (mean, std) of the conv output; ideally close to (0, 1).
stats(t)

(tensor(-0.0201, grad_fn=<MeanBackward0>),
 tensor(0.6267, grad_fn=<StdBackward0>))

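The mean is close to zero, but the std (about 0.63) is far from 1: the layer shrinks the scale of its normalized input. In a deep network, repeatedly shrinking the activations like this makes them vanish towards zero, which is not ideal.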

### Kaiming normal

Let's check if Kaiming normal works better. Remember that Kaiming init is designed for layers whose output is fed into a ReLU or Leaky ReLU activation.

In [19]:
# `??` displays the signature and source of `kaiming_normal_`.
init.kaiming_normal_??

Signature:
init.kaiming_normal_(
    tensor: torch.Tensor,
    a: float = 0,
    mode: str = 'fan_in',
    nonlinearity: str = 'leaky_relu',
)
Source:   
def kaiming_normal_(
    tensor: Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'
)

What is `a`? The docstring says it is "the negative slope of the rectifier used after this layer". In other words, it is the slope of the activation function for negative inputs. For a linear (identity) function, this slope is 1. For a ReLU, it is 0. For a Leaky ReLU, it is typically 0.01.

![what is a](../imgs/what_is_a.png)

In this case, the convolution is not followed by any non-linearity (the activation is effectively the identity), so a negative slope of `a=1` is appropriate.

In [20]:
# With a=1 the leaky-ReLU gain sqrt(2 / (1 + a**2)) equals 1, i.e. the gain of a purely linear layer.
init.kaiming_normal_(l1.weight, a=1.0, mode='fan_in', nonlinearity='leaky_relu')
stats(l1(x))

(tensor(0.0120, grad_fn=<MeanBackward0>),
 tensor(1.0885, grad_fn=<StdBackward0>))

The std (and hence the variance) of the output activation is now much closer to 1. This is better than what we get with the default init method.

One thing to notice is that `kaiming_uniform_` itself defaults to `a=0`, which assumes a `ReLU` activation, while the default `Conv2d` init calls it with `a=math.sqrt(5)`. Neither value matches our setup, which has no non-linearity. Let's set `a=1` for Kaiming uniform too. The resulting statistics for the output activation are much better than with the default init, and comparable to those obtained with Kaiming normal.

In [21]:
# Same gain-1 (linear) setting as above, but sampling from a uniform distribution.
init.kaiming_uniform_(l1.weight, a=1.0, mode='fan_in', nonlinearity='leaky_relu')
stats(l1(x))

(tensor(0.0069, grad_fn=<MeanBackward0>),
 tensor(1.0962, grad_fn=<StdBackward0>))

In [22]:
def f1(x, a=0, layer=None):
    """Apply ``layer`` to ``x`` and pass the result through a leaky ReLU.

    Args:
        x: input batch forwarded to ``layer``.
        a: negative slope of the leaky ReLU; ``a=0`` gives a plain ReLU.
        layer: module applied before the activation. Defaults to the global
            ``l1``, looked up at call time, so re-binding ``l1`` in a later
            cell changes which layer is used.

    Returns:
        ``F.leaky_relu(layer(x), a)``.
    """
    if layer is None:
        # Resolve the global lazily so later re-assignments of `l1` are picked up.
        layer = l1
    return F.leaky_relu(layer(x), a)

In [23]:
# a=0 gives gain sqrt(2): the standard Kaiming setting for a layer followed by a ReLU.
init.kaiming_normal_(l1.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
stats(f1(x))

(tensor(0.5376, grad_fn=<MeanBackward0>),
 tensor(1.1020, grad_fn=<StdBackward0>))

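The first thing to notice is that the mean is now around 0.5 instead of 0. This is because the ReLU activation converts every negative value to 0, so its output is no longer centered at zero. For a zero-mean Gaussian pre-activation with std $\sigma$, the ReLU output has mean $\sigma / \sqrt{2\pi} \approx 0.4\,\sigma$. Kaiming normal with `a=0` uses a gain of $\sqrt{2}$, so $\sigma \approx 1.4$ and the expected mean is roughly $0.56$.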

In [24]:
# Fresh layer with PyTorch's default init, followed by a ReLU (f1's default a=0).
l1 = nn.Conv2d(in_channels=1, out_channels=nh, kernel_size=5)
stats(f1(x))

(tensor(0.2195, grad_fn=<MeanBackward0>),
 tensor(0.3717, grad_fn=<StdBackward0>))

Remember that the default init of every `Conv2d` layer uses `a=math.sqrt(5)`. As you can see, this value is highly suboptimal: it leads to a mean far from 0.5 ― remember, we are using a ReLU, so the mean is shifted towards 0.5 ― and a std far from 1.

In [25]:
# Fresh layer whose weight is re-initialized with Kaiming uniform for a ReLU (a=0);
# the bias keeps its default init.
l1 = nn.Conv2d(in_channels=1, out_channels=nh, kernel_size=5)
init.kaiming_uniform_(l1.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
stats(f1(x))

(tensor(0.4529, grad_fn=<MeanBackward0>),
 tensor(0.8546, grad_fn=<StdBackward0>))

As in the linear case, Kaiming uniform performs similarly to Kaiming normal once it is initialized with the correct value of `a`.