In [1]:
import math
import torch

## What is the role of the bias term in a neural network?

Let's build a very simple linear layer that can optionally take a bias term.

In [2]:
def linear(x, w, b=None, debug=False):
    if b is not None:
        a = (x @ w) + b
    else:
        a = (x @ w)
    if debug:
        print(a.shape)
    return a
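As a sanity check, this hand-rolled layer should compute the same thing as PyTorch's built-in `torch.nn.functional.linear`. That function expects the weight in `(out_features, in_features)` layout and computes `x @ weight.T + bias`, so the notebook's `(in_features, out_features)` weight has to be transposed first. A minimal check, with `linear` redefined here (without the debug print) so the snippet runs on its own:

```python
import torch
import torch.nn.functional as F

def linear(x, w, b=None):
    # same computation as the notebook's `linear`, minus the debug print
    a = x @ w
    return a if b is None else a + b

torch.manual_seed(0)
x = torch.randn(4, 784)   # batch of 4 flattened 28x28 inputs
w = torch.randn(784, 50)  # (in_features, out_features), as in the notebook
b = torch.randn(50)       # one bias per output unit, broadcast over the batch

out = linear(x, w, b)
ref = F.linear(x, w.T, b)  # F.linear wants the weight as (out_features, in_features)
print(out.shape)           # torch.Size([4, 50])
print(torch.allclose(out, ref, rtol=1e-4, atol=1e-3))
```

The bias has shape `(nh,)` and is broadcast across the batch dimension, so every example in the batch gets the same per-unit offset.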

### Forward pass

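If we stack many instances of this linear layer (with no nonlinearity in between), the std of the activations grows by roughly a factor of 7 per layer. After approximately 44 layers the activations overflow float32, and their mean and std become `inf` or `nan`. This is called activation explosion.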
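The "approximately 44 layers" figure can be sanity-checked with back-of-the-envelope arithmetic. Assuming unit-variance inputs, `torch.randn` weights and no bias, each output of a layer is a sum of fan-in products with variance 1, so the std is multiplied by roughly `sqrt(fan_in)` per layer: about 28 for the first layer (fan-in `m = 784`) and about 7.07 for every hidden layer (fan-in `nh = 50`). Solving for when that std reaches the largest float32 value:

```python
import math

FLOAT32_MAX = 3.4028234663852886e38  # largest finite float32 value
m, nh = 28 * 28, 50

std_after_first = math.sqrt(m)    # ~28: layer 0 has fan-in m
growth_per_layer = math.sqrt(nh)  # ~7.07: every later layer has fan-in nh

# std at layer i is roughly std_after_first * growth_per_layer ** i;
# solve std_after_first * growth_per_layer ** i == FLOAT32_MAX for i
overflow_layer = math.log(FLOAT32_MAX / std_after_first) / math.log(growth_per_layer)
print(f"std reaches the float32 max around layer {overflow_layer:.1f}")
```

Individual activations sit a few std away from the mean, so some entries overflow slightly before the std itself does; either way this lands in the mid-40s, in line with the roughly 7x growth per row visible in the output below.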

In [3]:
bs, m = 1, 28 * 28
nh = 50

In [4]:
a = torch.randn(bs, m)  # input image

for i in range(100): 
    w = torch.randn(m, nh) if i == 0 else torch.randn(nh, nh)
    
    a = linear(a, w)
    print(a.mean(), a.std())
    
    if torch.isinf(a.std()) or torch.isnan(a.std()):
        print(f"Numerical instability at layer {i}")
        break

tensor(-0.6582) tensor(25.3606)
tensor(11.4425) tensor(156.2503)
tensor(185.5521) tensor(1141.8094)
tensor(1349.2837) tensor(6914.5078)
tensor(6187.2715) tensor(48314.0352)
tensor(620.2056) tensor(279270.5000)
tensor(330917.7500) tensor(1975426.7500)
tensor(1520116.5000) tensor(15572713.)
tensor(27431668.) tensor(1.3274e+08)
tensor(-1.8001e+08) tensor(1.0131e+09)
tensor(-33159700.) tensor(5.9802e+09)
tensor(7.1075e+08) tensor(4.1234e+10)
tensor(5.2392e+10) tensor(3.0525e+11)
tensor(-1.2567e+10) tensor(2.3668e+12)
tensor(-2.0088e+12) tensor(1.4824e+13)
tensor(-2.2484e+13) tensor(1.2909e+14)
tensor(-4.0457e+13) tensor(8.5373e+14)
tensor(-1.2641e+15) tensor(6.6426e+15)
tensor(3.6104e+14) tensor(4.7668e+16)
tensor(-1.0594e+16) tensor(3.7612e+17)
tensor(9.7311e+15) tensor(2.6672e+18)
tensor(-9.0080e+17) tensor(1.8061e+19)
tensor(2.2738e+19) tensor(1.1207e+20)
tensor(5.2459e+19) tensor(8.0268e+20)
tensor(-2.9992e+20) tensor(6.5404e+21)
tensor(6.4730e+21) tensor(3.8716e+22)
tensor(-4.9422e+22

Adding a bias term unfortunately does not keep the variance numerically stable either, so it does not help in training deeper networks.

In [5]:
a = torch.randn(bs, m)  # input image

for i in range(100): 
    w = torch.randn(m, nh) if i == 0 else torch.randn(nh, nh)
    b = torch.randn(nh)

    a = linear(a, w, b)
    print(a.mean(), a.std())
    if torch.isinf(a.std()) or torch.isnan(a.std()):
        print(f"Numerical instability at layer {i}")
        break

tensor(-4.5790) tensor(25.1622)
tensor(22.9265) tensor(167.6310)
tensor(-4.9820) tensor(1242.3936)
tensor(1681.5902) tensor(8887.5879)
tensor(-5290.4561) tensor(85838.1641)
tensor(117315.8594) tensor(716983.6250)
tensor(-460480.7500) tensor(5826601.5000)
tensor(665238.7500) tensor(38150776.)
tensor(-7813630.5000) tensor(2.8585e+08)
tensor(-1.9672e+08) tensor(2.0505e+09)
tensor(-2.6880e+09) tensor(1.5786e+10)
tensor(2.4715e+10) tensor(1.0402e+11)
tensor(8.2434e+10) tensor(8.1714e+11)
tensor(-1.0596e+12) tensor(5.2401e+12)
tensor(-5.7354e+12) tensor(3.2944e+13)
tensor(8.1868e+13) tensor(2.6975e+14)
tensor(-2.1753e+14) tensor(2.1331e+15)
tensor(-1.5150e+15) tensor(1.4056e+16)
tensor(2.8143e+16) tensor(1.0459e+17)
tensor(-1.8545e+16) tensor(7.7766e+17)
tensor(-1.6651e+17) tensor(5.5947e+18)
tensor(-4.1283e+18) tensor(4.2637e+19)
tensor(2.0435e+19) tensor(2.8666e+20)
tensor(-4.3687e+20) tensor(1.9931e+21)
tensor(8.7775e+20) tensor(1.3401e+22)
tensor(-4.0719e+21) tensor(8.9700e+22)
tensor(6.

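Note, however, how the mean and std of the feature maps generated by every layer are still very large and grow at about the same rate as without the bias. The bias adds a unit-variance term to activations whose variance is already in the hundreds or more, so it cannot counteract the growth. A network whose activations look like this will not be able to learn much.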
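A rough variance calculation shows why the bias barely matters here. Assuming independent, zero-mean, unit-variance weights and biases (as `torch.randn` gives), a hidden layer maps an activation variance `v` to about `nh * v + 1`, and the `+ 1` contributed by the bias is negligible once `v` is large. The helper `propagate_variance` below is hypothetical, written only for this sketch:

```python
m, nh = 28 * 28, 50

def propagate_variance(n_layers, with_bias):
    """Idealized activation variance after each layer, assuming unit-variance
    inputs, N(0, 1) weights and, optionally, N(0, 1) biases."""
    bias_var = 1.0 if with_bias else 0.0
    v = m * 1.0 + bias_var  # layer 0: fan-in m
    variances = [v]
    for _ in range(n_layers - 1):
        v = nh * v + bias_var  # hidden layers: fan-in nh
        variances.append(v)
    return variances

no_bias = propagate_variance(20, with_bias=False)
with_bias = propagate_variance(20, with_bias=True)

# ratio of the predicted stds with and without bias, layer by layer
ratios = [(vb / vn) ** 0.5 for vb, vn in zip(with_bias, no_bias)]
print(f"largest std ratio (bias / no bias): {max(ratios):.5f}")
```

The ratio stays within about 0.1% of 1 at every layer: in this idealized model the bias changes the std of an exploding stack by a negligible amount.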

### Is a good init sufficient to fix the issue?

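Improving the init fixes the activation explosion. The cell below divides the weights by `math.sqrt(m)`, following the 1/sqrt(fan-in) rule that is often called Xavier init (strictly, Glorot/Xavier init scales by both fan-in and fan-out; the fan-in-only variant is known as LeCun init).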
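The reasoning behind this rule, sketched for a single output unit and assuming independent, zero-mean inputs $x_j$ and weights $w_j$ with no bias:

$$
\operatorname{Var}\Big(\sum_{j=1}^{n_{\text{in}}} x_j w_j\Big) = \sum_{j=1}^{n_{\text{in}}} \operatorname{Var}(x_j)\,\operatorname{Var}(w_j) = n_{\text{in}}\,\operatorname{Var}(x)\,\operatorname{Var}(w)
$$

Choosing $\operatorname{Var}(w) = 1/n_{\text{in}}$, i.e. dividing `torch.randn` weights by $\sqrt{n_{\text{in}}}$, keeps the output variance equal to the input variance. Here $n_{\text{in}}$ is `m` for the first layer and `nh` for every hidden layer.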

In [6]:
a = torch.randn(bs, m)  # input image

for i in range(100): 
    # note: the hidden (nh x nh) layers are also divided by sqrt(m), not by their own fan-in sqrt(nh)
    w = torch.randn(m, nh) / math.sqrt(m) if i == 0 else torch.randn(nh, nh) / math.sqrt(m)  # xavier init
    b = torch.randn(nh)

    a = linear(a, w, b)
    print(a.mean(), a.std())
    if torch.isinf(a.std()) or torch.isnan(a.std()):
        print(f"Numerical instability at layer {i}")
        break

tensor(-0.0405) tensor(1.4605)
tensor(-0.1080) tensor(1.0736)
tensor(-0.2777) tensor(1.1102)
tensor(-0.0034) tensor(1.0211)
tensor(-0.1301) tensor(1.1686)
tensor(-0.0411) tensor(1.0003)
tensor(-0.0158) tensor(0.9847)
tensor(-0.0138) tensor(0.9824)
tensor(0.0689) tensor(1.0759)
tensor(-0.0387) tensor(1.0492)
tensor(-0.1385) tensor(1.0854)
tensor(-0.0245) tensor(1.0515)
tensor(-0.0124) tensor(1.1625)
tensor(0.1718) tensor(1.1672)
tensor(0.0484) tensor(0.8471)
tensor(-0.0323) tensor(0.8489)
tensor(-0.1236) tensor(0.9071)
tensor(0.0276) tensor(0.9205)
tensor(0.0013) tensor(1.0657)
tensor(0.1014) tensor(1.0310)
tensor(-0.2713) tensor(1.0612)
tensor(0.0652) tensor(0.9682)
tensor(0.1583) tensor(0.8461)
tensor(0.1868) tensor(0.9837)
tensor(-0.0121) tensor(0.9486)
tensor(0.2276) tensor(1.0550)
tensor(0.1102) tensor(1.0480)
tensor(0.0410) tensor(1.0528)
tensor(-0.1705) tensor(0.9793)
tensor(-0.2110) tensor(1.0383)
tensor(-0.0044) tensor(1.0173)
tensor(0.0357) tensor(0.9955)
tensor(-0.0852) tenso

Let's do an ablation study and try Xavier init without using the bias term.

In [7]:
a = torch.randn(bs, m)  # input image

for i in range(100): 
    w = torch.randn(m, nh) / math.sqrt(m) if i == 0 else torch.randn(nh, nh) / math.sqrt(m)

    a = linear(a, w)
    print(a.mean(), a.std())
    if torch.isinf(a.std()) or torch.isnan(a.std()):
        print(f"Numerical instability at layer {i}")
        break

tensor(0.0287) tensor(0.9324)
tensor(-0.0291) tensor(0.2501)
tensor(0.0012) tensor(0.0751)
tensor(-9.2146e-05) tensor(0.0215)
tensor(-0.0005) tensor(0.0053)
tensor(0.0002) tensor(0.0015)
tensor(0.0001) tensor(0.0004)
tensor(-1.4655e-05) tensor(9.0919e-05)
tensor(7.6926e-08) tensor(2.1277e-05)
tensor(5.4413e-07) tensor(5.6696e-06)
tensor(-2.5233e-08) tensor(1.5630e-06)
tensor(-4.9312e-08) tensor(3.6550e-07)
tensor(-1.0862e-08) tensor(9.0942e-08)
tensor(4.8076e-09) tensor(2.1733e-08)
tensor(6.8163e-10) tensor(5.7634e-09)
tensor(1.2298e-10) tensor(1.3088e-09)
tensor(-4.0085e-11) tensor(3.2561e-10)
tensor(-9.4507e-12) tensor(1.0271e-10)
tensor(4.5394e-12) tensor(2.3110e-11)
tensor(-6.9884e-13) tensor(6.0443e-12)
tensor(2.3810e-14) tensor(1.6117e-12)
tensor(-8.4391e-14) tensor(3.6230e-13)
tensor(8.3668e-15) tensor(1.0140e-13)
tensor(2.4415e-15) tensor(2.9122e-14)
tensor(-1.3079e-15) tensor(7.3537e-15)
tensor(-2.1614e-16) tensor(1.8758e-15)
tensor(1.6323e-17) tensor(4.8513e-16)
tensor(-3.009

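Removing the bias term shows that this init on its own is not enough: the activations no longer explode, they vanish. The cause is the scaling of the hidden layers. Each `nh x nh` weight matrix is divided by `math.sqrt(m)` = 28 rather than by its actual fan-in `math.sqrt(nh)` ≈ 7.07, so every hidden layer multiplies the std by roughly sqrt(nh / m) = sqrt(50 / 784) ≈ 0.25, which matches the ratio between consecutive rows of the output above. With a bias, the N(0, 1) bias added at every layer keeps the std near 1, but it does so by injecting fresh noise rather than by preserving the input signal, so in this experiment the bias masks the init problem instead of solving it.

In each layer of the forward pass, we multiply the previous layer's activations by weights that are too small for that layer's fan-in. Repeating this step many times shrinks the std geometrically until the activations underflow to 0 in float32. They do not become `nan`, which is why the `isinf`/`isnan` check does not catch this failure.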
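A minimal sketch of the same bias-free stack with each weight matrix scaled by its own fan-in (`m` for the first layer, `nh` for the hidden ones). This is a modified version of the ablation cell, not the notebook's original code, and `run_stack` with its `fix_fan_in` flag is a hypothetical helper for this comparison. Under per-layer fan-in scaling the std should stay of order 1 across all 100 layers without any bias, while the `sqrt(m)`-everywhere version decays to 0:

```python
import math
import torch

def run_stack(n_layers, m=28 * 28, nh=50, fix_fan_in=True, seed=0):
    """Push a random input through n_layers bias-free linear layers and
    return the std of the activations after each layer."""
    torch.manual_seed(seed)
    a = torch.randn(1, m)  # stand-in for one flattened 28x28 image
    stds = []
    for i in range(n_layers):
        fan_in = m if i == 0 else nh
        # divide by this layer's own fan-in, or by sqrt(m) everywhere as in the ablation cell
        scale = math.sqrt(fan_in) if fix_fan_in else math.sqrt(m)
        w = torch.randn(fan_in, nh) / scale
        a = a @ w
        stds.append(a.std().item())
    return stds

fixed = run_stack(100, fix_fan_in=True)
original = run_stack(100, fix_fan_in=False)
print(f"per-layer fan-in scaling, last std: {fixed[-1]:.3g}")
print(f"sqrt(m) everywhere, last std:       {original[-1]:.3g}")
```

With correct scaling the std still drifts randomly from layer to layer (a product of random matrices is not exactly norm-preserving), but it stays many orders of magnitude away from both overflow and underflow.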

In [8]:
1 * 0.01, 1 * 0.01 * 0.01, 1 * 0.01 * 0.01 * 0.01

(0.01, 0.0001, 1.0000000000000002e-06)
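Continuing this arithmetic far enough shows what actually happens at the end: repeated multiplication by a small factor underflows to exactly `0.0`; it does not turn into `nan`. A quick check with plain Python floats (float64, whose smallest positive subnormal is about 4.9e-324; float32 underflows much sooner, below about 1.4e-45):

```python
import math

x = 1.0
for step in range(200):
    x *= 0.01  # 200 factors of 0.01 would be 1e-400, far below the float64 range

print(x, math.isnan(x), math.isinf(x))  # prints: 0.0 False False
```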

### TODO

- [ ] Does Xavier init make any assumption about the bias term?
- [ ] Plot activations for each layer
- [ ] Backward pass
- [ ] BatchNorm
- [ ] ReLU
- [ ] Conv layer