In [1]:
import math

In [2]:
import torch

## What is the role of the bias term in a neural network?

Let's build a very simple linear layer that can optionally take a bias term.

In [15]:
def lin(x, w, b=None, debug=False):
    """Linear layer: x @ w, plus the bias b when one is given."""
    a = x @ w
    if b is not None:
        a = a + b
    if debug:
        print(a.shape)  # shape of the output activations
    return a
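
As a quick sanity check, here is a minimal sketch of calling the layer with and without a bias. The tensor sizes are arbitrary and chosen only for illustration; `lin` is redefined (without the `debug` flag) so the snippet runs on its own.

```python
import torch

def lin(x, w, b=None):
    # Same computation as the layer above: x @ w, plus b when given.
    a = x @ w
    return a if b is None else a + b

x = torch.randn(4, 3)   # batch of 4 inputs with 3 features (illustrative sizes)
w = torch.randn(3, 2)   # maps 3 features to 2 outputs
b = torch.randn(2)      # one bias per output unit, broadcast over the batch

out_no_bias = lin(x, w)
out_bias = lin(x, w, b)
print(out_no_bias.shape, out_bias.shape)
```

The bias only shifts each output column by a constant; it does not change the output shape.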

### Forward pass

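If we stack a few instances of this linear layer with weights drawn from a standard normal distribution, the activations blow up: after just two layers the mean is already far from zero, the std grows by a factor of roughly `math.sqrt(nh)` ≈ 7 per layer, and eventually the std becomes inf/nan. This is called activation explosion.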

In [16]:
bs, m = 1, 28*28
nh = 50

In [17]:
a = torch.randn(bs, m)  # input image

for i in range(100): 
    w = torch.randn(m, nh) if i == 0 else torch.randn(nh, nh)
    
    a = lin(a, w)
    print(a.mean(), a.std())
    if torch.isinf(a.std()) or torch.isnan(a.std()):
        print(f"Numerical instability at layer {i}")
        break

tensor(-1.6022) tensor(32.2359)
tensor(25.5023) tensor(194.1806)
tensor(133.4860) tensor(1383.7087)
tensor(1224.4739) tensor(8681.0977)
tensor(-17732.4043) tensor(67681.3516)
tensor(429.2725) tensor(521368.6875)
tensor(-34658.0195) tensor(3407623.)
tensor(8298728.5000) tensor(27529012.)
tensor(15308513.) tensor(1.9155e+08)
tensor(-4.4268e+08) tensor(1.5578e+09)
tensor(-6.2623e+08) tensor(9.4120e+09)
tensor(-2.0260e+09) tensor(6.2145e+10)
tensor(5.8645e+10) tensor(3.9950e+11)
tensor(-1.0380e+11) tensor(2.5730e+12)
tensor(2.1449e+12) tensor(1.7724e+13)
tensor(-3.0679e+12) tensor(1.0827e+14)
tensor(7.2502e+13) tensor(7.6339e+14)
tensor(4.9434e+13) tensor(5.3466e+15)
tensor(-2.7428e+15) tensor(3.8353e+16)
tensor(2.5686e+16) tensor(2.3944e+17)
tensor(7.3178e+16) tensor(1.8415e+18)
tensor(4.2566e+16) tensor(1.2829e+19)
tensor(-8.2446e+17) tensor(8.3547e+19)
tensor(1.3131e+19) tensor(6.9450e+20)
tensor(-4.0995e+19) tensor(5.1958e+21)
tensor(-5.8406e+21) tensor(3.6376e+22)
tensor(-2.1548e+22) 
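
The growth rate in the output above can be checked on a single hidden layer. The sketch below is an illustration, not part of the original experiment: it uses a larger batch (`bs = 1000`, an assumption made only to get a steadier estimate) and compares the measured std ratio with `math.sqrt(nh)`.

```python
import math
import torch

torch.manual_seed(0)
nh = 50
bs = 1000  # larger batch than in the cell above, only to reduce noise

a = torch.randn(bs, nh)   # activations with mean ~0 and std ~1
w = torch.randn(nh, nh)   # unscaled N(0, 1) weights, as in the cell above
out = a @ w

# Ratio of output std to input std for one layer.
growth = (out.std() / a.std()).item()
print(f"measured growth: {growth:.2f}, sqrt(nh): {math.sqrt(nh):.2f}")
```

Each output unit sums `nh` roughly independent terms of unit variance, so the std is multiplied by about `math.sqrt(nh)` at every hidden layer.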

The bias term is supposed to help keep the variance more numerically stable, which would allow us to train deeper networks. Let's add a bias drawn from a standard normal distribution to every layer.

In [6]:
a = torch.randn(bs, m)  # input image

for i in range(100): 
    w = torch.randn(m, nh) if i == 0 else torch.randn(nh, nh)
    b = torch.randn(nh)

    a = lin(a, w, b)
    print(a.mean(), a.std())
    if torch.isinf(a.std()) or torch.isnan(a.std()):
        print(f"Numerical instability at layer {i}")
        break

tensor(0.2948) tensor(25.2224)
tensor(11.3351) tensor(157.6096)
tensor(39.7886) tensor(888.9370)
tensor(305.9420) tensor(5561.0532)
tensor(-4947.6919) tensor(42374.2031)
tensor(-45642.4648) tensor(231599.9531)
tensor(-269654.5938) tensor(1517741.1250)
tensor(-1216053.5000) tensor(10389182.)
tensor(-696933.7500) tensor(77055200.)
tensor(-41839628.) tensor(5.1609e+08)
tensor(2.3685e+08) tensor(3.5630e+09)
tensor(5.2460e+09) tensor(2.5518e+10)
tensor(2.0887e+10) tensor(1.6205e+11)
tensor(1.2260e+11) tensor(1.0795e+12)
tensor(9.1355e+11) tensor(5.9297e+12)
tensor(-6.1938e+11) tensor(4.4245e+13)
tensor(7.6832e+12) tensor(3.0600e+14)
tensor(2.9639e+14) tensor(2.1948e+15)
tensor(-9.9300e+14) tensor(1.7176e+16)
tensor(-1.4316e+16) tensor(1.1445e+17)
tensor(1.2421e+16) tensor(8.1405e+17)
tensor(9.5092e+17) tensor(5.7035e+18)
tensor(-3.8658e+16) tensor(3.5691e+19)
tensor(3.0521e+19) tensor(2.1759e+20)
tensor(5.3957e+20) tensor(1.4464e+21)
tensor(1.4418e+21) tensor(1.1056e+22)
tensor(6.8730e+21) 

Note, however, that the mean and std of the feature maps generated by every layer are still very large: with unscaled weights, the bias alone does not prevent the explosion. This network will not be able to learn much.
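
A short derivation suggests why the bias makes so little difference here. It is a sketch under the usual simplifying assumptions (not stated in the original): the weights $w_{ij}$ and biases $b_j$ are zero-mean, independent of each other and of the incoming activations, with variances $\sigma_w^2$ and $\sigma_b^2$, and $n$ is the layer's fan-in. For $a_{l+1,j} = \sum_{i=1}^{n} a_{l,i}\, w_{ij} + b_j$:

$$
\operatorname{Var}(a_{l+1,j}) = n\,\sigma_w^2\,\mathbb{E}\!\left[a_{l,i}^2\right] + \sigma_b^2
$$

With $\sigma_w^2 = 1$ and $n = 50$ for the hidden layers, the first term multiplies the second moment by about 50 at every layer, while the bias adds only $\sigma_b^2 = 1$. Once the activations are large, the bias contribution is negligible, which is consistent with the two runs above growing at about the same rate.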

### Is a good init sufficient to fix the issue?

Improving the init helps further: scaling the weights by `1 / math.sqrt(m)` keeps the std of the activations close to 1 at every layer shown below.

In [11]:
a = torch.randn(bs, m)  # input image

for i in range(100): 
    w = torch.randn(m, nh) / math.sqrt(m) if i == 0 else torch.randn(nh, nh) / math.sqrt(m)  # xavier-style init; every layer is scaled by 1/sqrt(m)
    b = torch.randn(nh)

    a = lin(a, w, b)
    print(a.mean(), a.std())
    if torch.isinf(a.std()) or torch.isnan(a.std()):
        print(f"Numerical instability at layer {i}")
        break

tensor(0.0711) tensor(1.4957)
tensor(-0.0511) tensor(1.0095)
tensor(-0.4422) tensor(1.1059)
tensor(0.1277) tensor(1.0596)
tensor(-0.3985) tensor(1.3124)
tensor(0.1857) tensor(1.0094)
tensor(-0.0760) tensor(0.9465)
tensor(0.0490) tensor(1.1492)
tensor(0.1703) tensor(0.9531)
tensor(0.0941) tensor(0.9850)
tensor(0.2780) tensor(1.1059)
tensor(0.3568) tensor(1.1376)
tensor(-0.1717) tensor(1.1406)
tensor(0.1412) tensor(1.0808)
tensor(-0.0173) tensor(0.9722)
tensor(-0.0411) tensor(1.0436)
tensor(-0.1327) tensor(1.0183)
tensor(-0.0854) tensor(0.9669)
tensor(0.0919) tensor(1.0189)
tensor(0.0479) tensor(1.0539)
tensor(0.0948) tensor(0.9043)
tensor(-0.0322) tensor(0.8762)
tensor(-0.0547) tensor(1.0981)
tensor(0.0137) tensor(1.1735)
tensor(0.0523) tensor(0.9133)
tensor(-0.0798) tensor(0.9836)
tensor(-0.0092) tensor(1.0789)
tensor(0.0777) tensor(0.9722)
tensor(-0.2081) tensor(1.0408)
tensor(0.2621) tensor(1.0460)
tensor(-0.0367) tensor(1.1459)
tensor(0.1246) tensor(0.8488)
tensor(-0.0796) tensor(0.
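
To see where a std of about 1 comes from, the sketch below iterates a simple variance recursion for the hidden layers. It assumes zero-mean, independent weights and biases, and it is an approximation rather than a description of the exact run above: each hidden layer has fan-in `nh` and weights of variance `1 / m`, so it scales the second moment of its input by `nh / m` and then adds the bias variance.

```python
import math

m, nh = 28 * 28, 50
var_b = 1.0        # bias ~ N(0, 1)
shrink = nh / m    # fan_in * Var(w) for a hidden layer scaled by 1/sqrt(m)

v = 1.0            # second moment of the activations, starting near 1
for _ in range(100):
    v = shrink * v + var_b   # contribution from the previous layer + bias variance

predicted_std = math.sqrt(v)
print(f"predicted steady-state std: {predicted_std:.3f}")
```

Under these assumptions the recursion settles at `sqrt(m / (m - nh))`, slightly above 1, which is in line with the values printed above. It also suggests that most of that variance is contributed by the freshly sampled bias at each layer rather than carried over from the previous layer.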

Let's do an ablation study and try Xavier init without using the bias term.

In [8]:
a = torch.randn(bs, m)  # input image

for i in range(100): 
    w = torch.randn(m, nh) / math.sqrt(m) if i == 0 else torch.randn(nh, nh) / math.sqrt(m)

    a = lin(a, w)
    print(a.mean(), a.std())
    if torch.isinf(a.std()) or torch.isnan(a.std()):
        print(f"Numerical instability at layer {i}")
        break

tensor(0.1339) tensor(0.9245)
tensor(0.0319) tensor(0.2444)
tensor(-0.0045) tensor(0.0622)
tensor(0.0019) tensor(0.0140)
tensor(-0.0004) tensor(0.0032)
tensor(-4.7454e-05) tensor(0.0008)
tensor(2.6425e-06) tensor(0.0002)
tensor(1.3865e-06) tensor(5.3755e-05)
tensor(3.1186e-06) tensor(1.4076e-05)
tensor(-3.4377e-08) tensor(3.8901e-06)
tensor(-9.4254e-08) tensor(9.7954e-07)
tensor(1.7289e-08) tensor(2.7845e-07)
tensor(-4.1952e-09) tensor(6.8295e-08)
tensor(-8.0264e-10) tensor(1.6526e-08)
tensor(2.3847e-10) tensor(4.0860e-09)
tensor(-1.3845e-10) tensor(9.3302e-10)
tensor(1.5291e-11) tensor(2.4523e-10)
tensor(-7.7738e-12) tensor(6.5599e-11)
tensor(3.3116e-12) tensor(1.6463e-11)
tensor(7.6094e-13) tensor(3.7808e-12)
tensor(-2.5715e-14) tensor(8.1276e-13)
tensor(-4.0869e-14) tensor(2.1212e-13)
tensor(-5.1578e-16) tensor(6.2611e-14)
tensor(4.0379e-15) tensor(1.6626e-14)
tensor(9.6956e-16) tensor(3.9063e-15)
tensor(-1.5877e-16) tensor(9.1845e-16)
tensor(1.3143e-17) tensor(2.4905e-16)
tensor(-6

Removing the bias term shows that a good init alone is not enough in this setup. The purpose of the bias term here is to improve numerical stability in the forward pass: without it, it is harder to train deep architectures.

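In each layer of the forward pass, we multiply the previous layer's output activations by a set of weights close to zero. Here the hidden layers are divided by `math.sqrt(m)` although their fan-in is only `nh`, so each layer shrinks the std by a factor of about `math.sqrt(nh / m)` ≈ 0.25, which matches the output above. Repeating this step many times leads to output activations with a std so small that it eventually underflows to zero rather than becoming `nan`.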

In [9]:
1 * 0.01, 1 * 0.01 * 0.01, 1 * 0.01 * 0.01 * 0.01

(0.01, 0.0001, 1.0000000000000002e-06)
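
For comparison, here is a sketch of the same bias-free stack where each layer is scaled by its own fan-in (`m` for the first layer, `nh` for the rest) instead of always dividing by `math.sqrt(m)`. This variant is not part of the original experiment; the `fan_in` variable and the fixed seed are additions for illustration.

```python
import math
import torch

torch.manual_seed(0)
bs, m, nh = 1, 28 * 28, 50

a = torch.randn(bs, m)  # input image
stds = []
for i in range(100):
    fan_in = m if i == 0 else nh
    # Scale by this layer's own fan-in rather than by m everywhere.
    w = torch.randn(fan_in, nh) / math.sqrt(fan_in)
    a = a @ w  # no bias term
    stds.append(a.std().item())

print(f"std after layer 1: {stds[0]:.3f}, after layer 100: {stds[-1]:.3f}")
```

With only 50 units per layer the std still drifts from layer to layer, but it should not collapse by a factor of four per layer the way the `1 / math.sqrt(m)` version does. This suggests that part of the vanishing seen in the ablation comes from the scaling mismatch in the hidden layers, not only from the missing bias.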

### TODO

- [ ] Does Xavier init make any assumption about the bias term?
- [ ] Plot activations for each layer
- [ ] Backward pass
- [ ] Conv layer
- [ ] BatchNorm