In [1]:
import math
import torch

## What is the role of the bias term in a neural network?

Let's build a very simple linear layer that can optionally take a bias term.

In [2]:
def linear(x, w, b=None, debug=False):
    """Compute x @ w, adding the bias b when one is given."""
    a = x @ w
    if b is not None:
        a = a + b
    if debug:
        print(a.shape)  # shape of the output activations
    return a

### Forward pass

If we stack a few instances of this linear layer, with weights drawn from a standard normal, the activations grow at every layer until they overflow: after approximately 44 layers, the mean and std of the activations become `nan`. This is called activation explosion.

In [3]:
bs, m = 1, 28 * 28
nh = 50

In [4]:
a = torch.randn(bs, m)  # input image

for i in range(100): 
    w = torch.randn(m, nh) if i == 0 else torch.randn(nh, nh)
    
    a = linear(a, w)
    print(a.mean(), a.std())
    
    if torch.isinf(a.std()) or torch.isnan(a.std()):
        print(f"Numerical instability at layer {i}")
        break

tensor(2.0243) tensor(26.1706)
tensor(-9.6323) tensor(188.7720)
tensor(-575.9617) tensor(1162.1053)
tensor(-1039.7336) tensor(10019.8301)
tensor(-10346.8691) tensor(63094.7305)
tensor(-8526.0703) tensor(455544.7500)
tensor(-1102988.1250) tensor(3361917.5000)
tensor(227481.6875) tensor(25106340.)
tensor(32292486.) tensor(1.8682e+08)
tensor(2.2850e+08) tensor(1.3495e+09)
tensor(-1.2376e+09) tensor(9.3639e+09)
tensor(7.4606e+09) tensor(6.4331e+10)
tensor(-5.6087e+10) tensor(4.3354e+11)
tensor(-5.9011e+11) tensor(2.6942e+12)
tensor(-2.8687e+12) tensor(1.8670e+13)
tensor(1.6036e+13) tensor(1.3257e+14)
tensor(-9.9115e+12) tensor(9.5434e+14)
tensor(-1.4960e+15) tensor(6.2041e+15)
tensor(-4.5607e+15) tensor(3.8465e+16)
tensor(-5.3774e+16) tensor(1.9924e+17)
tensor(9.8150e+16) tensor(1.4602e+18)
tensor(3.1619e+17) tensor(1.0735e+19)
tensor(-1.4530e+19) tensor(6.6974e+19)
tensor(-3.5854e+19) tensor(4.8939e+20)
tensor(1.5195e+20) tensor(3.6402e+21)
tensor(-2.6961e+21) tensor(2.4501e+22)
...
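The growth rate in the output above can be estimated from the weight shapes alone. Assuming zero-mean, independent activations and weights, each output unit sums `fan_in` products, so a layer multiplies the std by about `sqrt(fan_in)` times the weight std. The sketch below applies that rule to guess the depth at which float32 overflows; the helper `layers_until_overflow` and the constant `FLOAT32_MAX` are illustrative names, not part of the notebook.

```python
import math

FLOAT32_MAX = 3.4028235e38  # largest finite float32 value

def layers_until_overflow(m, nh, input_std=1.0, weight_std=1.0):
    """Estimate the depth at which the activation std exceeds float32 range.

    Assumes zero-mean, independent activations and weights, so each layer
    multiplies the std by sqrt(fan_in) * weight_std.
    Returns None if the std is not predicted to grow.
    """
    growth = math.sqrt(nh) * weight_std              # hidden layers: fan_in = nh
    if growth <= 1.0:
        return None
    std = input_std * math.sqrt(m) * weight_std      # first layer: fan_in = m
    layers = 1
    while std <= FLOAT32_MAX:
        std *= growth
        layers += 1
    return layers

m, nh = 28 * 28, 50
print(math.sqrt(m), math.sqrt(nh))     # predicted per-layer growth factors
print(layers_until_overflow(m, nh))    # estimated depth at which the std overflows
```

The predicted factors (28 for the first layer, about 7 afterwards) line up with the printed stds, and the estimated depth lands in the mid-forties, in line with the roughly 44 layers mentioned above. A single random run will deviate somewhat from this expectation.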

Unfortunately, adding a bias term does not keep the variance of the activations numerically stable, and therefore does not seem to help in training deeper networks.

In [5]:
a = torch.randn(bs, m)  # input image

for i in range(100): 
    w = torch.randn(m, nh) if i == 0 else torch.randn(nh, nh)
    b = torch.randn(nh)

    a = linear(a, w, b)
    print(a.mean(), a.std())
    if torch.isinf(a.std()) or torch.isnan(a.std()):
        print(f"Numerical instability at layer {i}")
        break

tensor(-4.1558) tensor(22.5135)
tensor(-12.7045) tensor(157.3415)
tensor(116.9325) tensor(1014.1555)
tensor(165.3409) tensor(6892.7690)
tensor(2628.7520) tensor(42421.7539)
tensor(48158.7344) tensor(311417.0938)
tensor(-378727.4375) tensor(2563695.5000)
tensor(3641287.7500) tensor(17689588.)
tensor(-13041480.) tensor(1.3575e+08)
tensor(1.2193e+08) tensor(1.0693e+09)
tensor(-1.2086e+09) tensor(7.3240e+09)
tensor(1.0753e+10) tensor(4.7974e+10)
tensor(-1.1516e+11) tensor(3.2171e+11)
tensor(8.1565e+10) tensor(2.3479e+12)
tensor(2.8885e+12) tensor(1.6541e+13)
tensor(1.0997e+13) tensor(1.2849e+14)
tensor(7.0010e+13) tensor(9.0753e+14)
tensor(-2.8257e+14) tensor(5.6709e+15)
tensor(1.5141e+15) tensor(3.5614e+16)
tensor(4.9815e+15) tensor(2.2148e+17)
tensor(1.4019e+17) tensor(1.4494e+18)
tensor(9.9039e+17) tensor(9.9376e+18)
tensor(-2.6426e+19) tensor(6.0428e+19)
tensor(-3.5399e+18) tensor(4.4832e+20)
tensor(-2.8904e+20) tensor(3.0153e+21)
tensor(3.1508e+21) tensor(2.2223e+22)
...
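A rough way to see why the bias makes so little difference here: under the same independence assumptions, the variance follows `var_out = fan_in * weight_std**2 * var_in + bias_std**2`, so a unit-variance bias adds 1 to a number that is already in the hundreds after the first layer. The function `predicted_std` below is a hypothetical helper that just iterates this recursion; it does not run the network.

```python
import math

def predicted_std(fan_ins, weight_std=1.0, bias_std=0.0, input_std=1.0):
    """Return the predicted activation std after each layer.

    Assumes zero-mean, independent inputs, weights and biases, so that
    var_out = fan_in * weight_std**2 * var_in + bias_std**2.
    """
    var = input_std ** 2
    stds = []
    for fan_in in fan_ins:
        var = fan_in * weight_std ** 2 * var + bias_std ** 2
        stds.append(math.sqrt(var))
    return stds

m, nh = 28 * 28, 50
fan_ins = [m] + [nh] * 9                          # first ten layers
no_bias = predicted_std(fan_ins)
with_bias = predicted_std(fan_ins, bias_std=1.0)
for s0, s1 in zip(no_bias, with_bias):
    # the ratio stays very close to 1: the bias barely changes the std
    print(f"{s0:.4e}  {s1:.4e}  ratio={s1 / s0:.6f}")
```

With standard-normal weights the two columns are nearly identical, which is consistent with the two experiments above exploding at about the same rate.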

Note how the mean and std of the activations are still very large and keep growing by roughly the same factor at every layer. With or without a bias term, a network initialized this way will not be able to learn much.

### Is a good init sufficient to fix the issue?

Improving the init completely fixes the activation explosion issue.

In [6]:
a = torch.randn(bs, m)  # input image

for i in range(100): 
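    # Xavier-style init: every layer is scaled by 1/sqrt(m), the fan-in of the first layer (hidden layers have fan-in nh)
    w = torch.randn(m, nh) / math.sqrt(m) if i == 0 else torch.randn(nh, nh) / math.sqrt(m)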
    b = torch.randn(nh)

    a = linear(a, w, b)
    print(a.mean(), a.std())
    if torch.isinf(a.std()) or torch.isnan(a.std()):
        print(f"Numerical instability at layer {i}")
        break

tensor(0.3566) tensor(1.3213)
tensor(-0.1118) tensor(0.9710)
tensor(-0.0071) tensor(1.0783)
tensor(-0.0314) tensor(1.1460)
tensor(0.1210) tensor(0.9311)
tensor(0.1114) tensor(1.1002)
tensor(-0.2043) tensor(1.0311)
tensor(-0.0843) tensor(1.0027)
tensor(-0.0686) tensor(1.1469)
tensor(0.1537) tensor(1.0446)
tensor(-0.1908) tensor(1.0336)
tensor(0.0104) tensor(1.1801)
tensor(0.0464) tensor(1.0580)
tensor(0.3212) tensor(1.1358)
tensor(0.2729) tensor(1.0590)
tensor(-0.1149) tensor(0.8153)
tensor(-0.0500) tensor(0.9828)
tensor(-0.0392) tensor(0.9944)
tensor(0.0926) tensor(0.9233)
tensor(-0.2271) tensor(1.0482)
tensor(-0.2560) tensor(1.0281)
tensor(0.2026) tensor(1.2260)
tensor(0.0484) tensor(1.1781)
tensor(-0.1618) tensor(1.1089)
tensor(0.0450) tensor(1.2104)
tensor(-0.2285) tensor(1.0585)
tensor(-0.1079) tensor(1.0125)
tensor(0.0939) tensor(0.9085)
tensor(0.1066) tensor(0.9092)
tensor(-0.3181) tensor(0.9785)
tensor(0.3183) tensor(0.9863)
tensor(0.2722) tensor(0.8722)
...
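To get a feel for why the std settles near 1 in this run, here is a small sketch of the expected variance under the scaling used in the cell above (every weight matrix divided by `math.sqrt(m)`, bias drawn from a standard normal). It assumes zero-mean, independent activations, weights and biases; the variable names are illustrative.

```python
import math

m, nh = 28 * 28, 50
weight_var = 1.0 / m     # variance of torch.randn(...) / math.sqrt(m)
bias_var = 1.0           # variance of torch.randn(nh)

var = 1.0                                   # unit-variance input
var = m * weight_var * var + bias_var       # first layer, fan_in = m
history = [math.sqrt(var)]
for _ in range(99):                         # hidden layers, fan_in = nh
    var = nh * weight_var * var + bias_var
    history.append(math.sqrt(var))

# Fixed point of var = (nh / m) * var + bias_var
fixed_point = math.sqrt(bias_var / (1.0 - nh * weight_var))
print(history[0], history[-1], fixed_point)
```

Under these assumptions the recursion converges to a std slightly above 1, and most of that variance comes from the bias term rather than from the previous layer's activations.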

Let's do an ablation study and try Xavier init without using the bias term.

In [7]:
a = torch.randn(bs, m)  # input image

for i in range(100): 
    w = torch.randn(m, nh) / math.sqrt(m) if i == 0 else torch.randn(nh, nh) / math.sqrt(m)

    a = linear(a, w)
    print(a.mean(), a.std())
    if torch.isinf(a.std()) or torch.isnan(a.std()):
        print(f"Numerical instability at layer {i}")
        break

tensor(-0.2671) tensor(0.9502)
tensor(-0.0076) tensor(0.2757)
tensor(0.0101) tensor(0.0723)
tensor(0.0038) tensor(0.0186)
tensor(0.0007) tensor(0.0047)
tensor(5.4310e-05) tensor(0.0011)
tensor(1.4506e-05) tensor(0.0003)
tensor(1.5759e-05) tensor(7.7497e-05)
tensor(8.5075e-07) tensor(1.9835e-05)
tensor(-9.8762e-08) tensor(5.0685e-06)
tensor(6.1725e-08) tensor(1.2090e-06)
tensor(4.8311e-08) tensor(2.6407e-07)
tensor(-1.4101e-09) tensor(6.7252e-08)
tensor(-1.8381e-09) tensor(1.4054e-08)
tensor(-4.8925e-10) tensor(3.3937e-09)
tensor(-5.0788e-11) tensor(7.8351e-10)
tensor(1.5469e-11) tensor(1.8582e-10)
tensor(-1.8344e-12) tensor(5.6959e-11)
tensor(1.2314e-12) tensor(1.6890e-11)
tensor(1.0745e-13) tensor(3.9216e-12)
tensor(-1.0743e-14) tensor(8.8020e-13)
tensor(7.0873e-14) tensor(2.2054e-13)
tensor(-3.0059e-15) tensor(6.0513e-14)
tensor(4.1854e-15) tensor(1.4695e-14)
tensor(-1.1109e-15) tensor(3.8098e-15)
tensor(-4.4225e-17) tensor(9.7716e-16)
tensor(4.2663e-17) tensor(2.5967e-16)
...

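Removing the bias term shows that this init alone is not enough here: without the bias, the std shrinks by a factor of about 4 at every layer. In the previous experiment, the unit-variance bias re-injects variance at every layer, which is what keeps the std close to 1 in the forward pass. Note that the hidden-layer weights are divided by `math.sqrt(m)` rather than by the square root of their own fan-in, `math.sqrt(nh)`, so each hidden layer scales the std by roughly `sqrt(nh / m) ≈ 0.25`, which matches the output above. With this scaling, it would be harder to train a deep architecture without a bias term.

In each layer of the forward pass, we multiply the previous layer's output activations by a set of weights close to zero. Repeating this step many times leads to output activations with a std so small that it eventually underflows to zero. This is the vanishing counterpart of the explosion seen earlier, and it produces zeros rather than `nan`.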

In [8]:
1 * 0.01, 1 * 0.01 * 0.01, 1 * 0.01 * 0.01 * 0.01

(0.01, 0.0001, 1.0000000000000002e-06)
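The same kind of back-of-the-envelope estimate can be made for the shrinking case. The sketch below assumes zero-mean, independent activations and weights, no bias, and hidden weights drawn as `torch.randn(nh, nh) / math.sqrt(scale_fan_in)`; `layers_until_underflow`, `scale_fan_in` and `FLOAT32_TINY` are hypothetical names introduced for illustration.

```python
import math

FLOAT32_TINY = 1.401298464324817e-45  # smallest positive (subnormal) float32 value

def layers_until_underflow(m, nh, scale_fan_in):
    """Estimate the depth at which the activation std drops below float32 range.

    Assumes zero-mean, independent activations and weights, no bias, and
    hidden weights drawn as randn(nh, nh) / sqrt(scale_fan_in).
    Returns None if the std is not predicted to shrink.
    """
    shrink = math.sqrt(nh / scale_fan_in)   # per-layer std factor in hidden layers
    if shrink >= 1.0:
        return None
    std, layers = 1.0, 1                    # first layer: sqrt(m) / sqrt(m) = 1
    while std >= FLOAT32_TINY:
        std *= shrink
        layers += 1
    return layers

m, nh = 28 * 28, 50
print(math.sqrt(nh / m))                                # per-layer factor with 1/sqrt(m) scaling
print(layers_until_underflow(m, nh, scale_fan_in=m))    # scaling used in the notebook
print(layers_until_underflow(m, nh, scale_fan_in=nh))   # scaling by each layer's own fan-in
```

With the notebook's scaling, the predicted factor is about 0.25 per layer, and the std would be expected to reach zero well before layer 100. A std of zero is neither `inf` nor `nan`, so the check in the loop would not be expected to fire. Scaling each hidden layer by its own fan-in (`nh`) gives a predicted factor of 1, so no shrinkage is predicted in expectation even without a bias.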

### TODO

- [ ] Does Xavier init make any assumption about the bias term?
- [ ] Plot activations for each layer
- [ ] Backward pass
- [ ] BatchNorm
- [ ] ReLU
- [ ] Conv layer