In [1]:
import math

In [2]:
import torch

## What is the role of the bias term in a neural network?

Let's build a very simple linear layer that can optionally take a bias term.

In [3]:
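def lin(x, w, b=None, debug=False):
    """Linear layer: x @ w, plus an optional bias b."""
    if debug:
        # b may be None, so guard the shape lookup
        print(x.shape, w.shape, None if b is None else b.shape)
    out = x @ w
    if b is not None:
        out = out + b
    return out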

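If we stack a few instances of this linear layer, the feature map after just two layers has a mean very far from zero and a nan std. The growth in magnitude is called activation explosion. The nan has a second cause: `w` is a 512-element vector, so the first layer reduces the 512×512 input to 512 values and the second reduces those to a single number, whose std is undefined.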

In [25]:
a = torch.randn(512,512)  # input image

for i in range(100): 
    w = torch.randn(512)  # init weights for layer i
    
    a = lin(a, w)
    print(a.mean(), a.std())
    if torch.isinf(a.std()) or torch.isnan(a.std()):
        print(f"Numerical instability at layer {i}")
        break

tensor(-0.6111) tensor(22.8972)
tensor(-719.8424) tensor(nan)
Numerical instability at layer 1
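
To make the shape effect concrete, here is a minimal sketch that traces the shapes through two bias-free layers. The names `h1` and `h2` and the seed are illustrative choices, not part of the original cells.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for repeatability

a = torch.randn(512, 512)  # same input shape as the cell above
w = torch.randn(512)       # a weight vector, as in the cell above

h1 = a @ w    # (512, 512) @ (512,) -> (512,)
h2 = h1 @ w   # (512,) @ (512,) -> 0-dim tensor (a single number)
print(h1.shape, h2.shape)

# With a single value there is nothing to measure spread against,
# so std() returns nan (PyTorch may also emit a warning here).
print(torch.isnan(h2.std()))
```

So the nan at layer 1 is what `std()` returns for a single value, while the jump in the mean reflects the growing magnitude.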


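With a bias term the loop no longer hits a nan std and runs through many more layers. Note why: each `a @ w` still collapses to a single number, and adding the 512-element `b` broadcasts that number back to 512 identical values, so the std is defined again (and is zero up to floating-point rounding).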

In [26]:
a = torch.randn(512,512)  # input image


for i in range(100): 
    w = torch.randn(512)
    b = torch.zeros(512)

    a = lin(a, w, b)
    print(a.mean(), a.std())
    if torch.isinf(a.std()) or torch.isnan(a.std()):
        print(f"Numerical instability at layer {i}")
        break

tensor(0.3053) tensor(21.7885)
tensor(78.3094) tensor(0.)
tensor(-587.9160) tensor(0.)
tensor(-20112.5020) tensor(0.)
tensor(51898.5625) tensor(0.)
tensor(-123345.1250) tensor(0.)
tensor(-1790769.7500) tensor(0.3754)
tensor(-35605408.) tensor(0.)
tensor(-96858832.) tensor(8.0078)
tensor(-2.3924e+09) tensor(512.5007)
tensor(6.5086e+09) tensor(0.)
tensor(2.3188e+09) tensor(0.)
tensor(-5.6051e+10) tensor(16400.0234)
tensor(9.9710e+11) tensor(65600.0938)
tensor(-5.0624e+11) tensor(0.)
tensor(1.3642e+13) tensor(0.)
tensor(3.6873e+14) tensor(0.)
tensor(-1.9596e+15) tensor(4.0305e+08)
tensor(-1.4032e+16) tensor(0.)
tensor(-1.7219e+16) tensor(1.0748e+09)
tensor(3.8877e+17) tensor(3.4393e+10)
tensor(-6.3708e+18) tensor(0.)
tensor(-2.6871e+20) tensor(0.)
tensor(-5.3975e+20) tensor(0.)
tensor(-5.6857e+21) tensor(0.)
tensor(-1.7724e+23) tensor(3.6064e+16)
tensor(6.9813e+24) tensor(5.7702e+17)
tensor(-1.6895e+25) tensor(3.4621e+18)
tensor(-5.8620e+26) tensor(1.4772e+20)
tensor(-4.6679e+26) tensor(0

Note, however, that the mean of the feature maps generated by every layer is still very large and keeps growing, while the std is often exactly zero. This network, despite being deeper, will not be able to learn much.
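
The zero std values in the log can be reproduced in isolation. The sketch below uses a hypothetical stand-in value `s` for a single-number layer output and shows how adding a zero bias broadcasts it.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for repeatability

s = torch.tensor(2.0)   # hypothetical stand-in for a 0-dim layer output
b = torch.zeros(512)    # zero bias, as in the cell above
w = torch.randn(512)

out = s + b             # 0-dim + (512,) broadcasts to (512,)
print(out.shape)        # torch.Size([512])
print(out.std())        # all elements are equal, so there is no spread

nxt = out @ w           # the next matmul collapses to a single number again
print(nxt.shape)        # torch.Size([])
```

Each layer therefore multiplies one number by `w.sum()` and copies it 512 times, which is consistent with a mean that keeps changing scale while the std stays at or near zero.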

## Is good init sufficient to fix the issue?

Improving the init helps to delay the activation explosion. With the scaled weights in the next cell, the means no longer blow up; they shrink toward zero instead.
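
A short derivation of why dividing by `math.sqrt(512)` is a reasonable scale, under the usual assumptions (not stated in the original) that inputs and weights are independent with zero mean. For one output unit of a layer with $n$ inputs:

$$
y_j = \sum_{i=1}^{n} x_i w_{ij}, \qquad \operatorname{Var}(y_j) = \sum_{i=1}^{n} \operatorname{Var}(x_i)\,\operatorname{Var}(w_{ij}) = n\,\sigma_x^2\,\sigma_w^2
$$

Choosing $\sigma_w = 1/\sqrt{n}$ gives $\operatorname{Var}(y_j) = \sigma_x^2$, so the scale of the activations is preserved from layer to layer; here $n = 512$. For a square layer this coincides with the Xavier (Glorot) normal std $\sqrt{2/(n_{in} + n_{out})}$.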

In [12]:
a = torch.randn(512,512)  # input image

for i in range(100): 
    # init params for every layer
    w = torch.randn(512) / math.sqrt(512)  # xavier init
    b = torch.zeros(512)

    a = lin(a, w, b)
    print(a.mean(), a.std())
    if torch.isinf(a.std()) or torch.isnan(a.std()):
        print(f"Numerical instability at layer {i}")
        break

tensor(0.0038) tensor(1.0265)
tensor(0.4211) tensor(1.4916e-07)
tensor(0.6980) tensor(0.)
tensor(0.0290) tensor(1.1187e-08)
tensor(-0.0189) tensor(1.8645e-09)
tensor(0.0026) tensor(6.9918e-10)
tensor(-0.0021) tensor(6.9918e-10)
tensor(-0.0001) tensor(0.)
tensor(-8.8191e-05) tensor(7.2831e-12)
tensor(8.8952e-06) tensor(2.7312e-12)
tensor(7.2592e-06) tensor(4.5519e-13)
tensor(2.9686e-06) tensor(2.2760e-13)
tensor(4.9048e-06) tensor(0.)
tensor(8.9878e-06) tensor(0.)
tensor(7.3327e-07) tensor(0.)
tensor(-4.6926e-07) tensor(8.5349e-14)
tensor(-1.0211e-06) tensor(0.)
tensor(4.9381e-07) tensor(0.)
tensor(1.6887e-07) tensor(0.)
tensor(3.1682e-08) tensor(3.5562e-15)
tensor(-2.0901e-08) tensor(1.7781e-15)
tensor(-6.5610e-09) tensor(1.7781e-15)
tensor(-5.0126e-09) tensor(0.)
tensor(-1.1444e-08) tensor(8.8905e-16)
tensor(2.0539e-08) tensor(3.5562e-15)
tensor(4.3722e-09) tensor(0.)
tensor(-7.8328e-11) tensor(0.)
tensor(-1.5745e-11) tensor(0.)
tensor(-7.1695e-12) tensor(1.3023e-18)
tensor(3.5638e-12

Let's do an ablation study and try Xavier init without using the bias term.

In [27]:
a = torch.randn(512,512)  # input image

for i in range(100): 
    w = torch.randn(512) / math.sqrt(512)  # xavier init

    a = lin(a, w)
    print(a.mean(), a.std())
    if torch.isinf(a.std()) or torch.isnan(a.std()):
        print(f"Numerical instability at layer {i}")
        break

tensor(-0.0595) tensor(1.0453)
tensor(-0.3266) tensor(nan)
Numerical instability at layer 1


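Removing the bias term brings the nan std back at layer 1, so a good init alone does not keep this loop running. In this experiment, however, the bias helps through shape rather than value: `b` is all zeros, and adding it only broadcasts the single-number layer output back to 512 elements so that its std is defined. The experiment therefore does not show that a bias term improves numerical stability in general; the weight scale is what controls how fast the activations grow or shrink.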
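
As a cross-check, the sketch below repeats the experiment with a full 512×512 weight matrix per layer, so the activation shape is preserved and no bias is needed. This is an assumption about the intended setup, not the original cells; `final_std` and its parameters are hypothetical names.

```python
import math
import torch

def final_std(scale, n_layers=100, n=512, seed=0):
    """Push a random input through n_layers bias-free linear layers."""
    g = torch.Generator().manual_seed(seed)
    a = torch.randn(n, n, generator=g)
    for _ in range(n_layers):
        W = torch.randn(n, n, generator=g) * scale  # full (n, n) weight matrix
        a = a @ W                                   # shape stays (n, n)
    return a.std()

print(final_std(scale=1.0))                 # unscaled weights
print(final_std(scale=1 / math.sqrt(512)))  # weights scaled by 1/sqrt(fan_in)
```

With unscaled weights the activations grow by roughly a factor of `math.sqrt(512)` per layer and overflow float32 well before layer 100, so the final std is not finite. With the scaled weights the std should stay on the order of 1, without any bias term.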