In [1]:
import math

In [2]:
import torch

## What is the role of the bias term in a neural network?

Let's build a very simple linear layer that can optionally take a bias term.

In [3]:
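def lin(x, w, b=None, debug=False):
    # linear layer: x @ w, plus an optional bias b
    if debug:
        # b may be None, so only read .shape when it exists
        print(x.shape, w.shape, None if b is None else b.shape)
    if b is not None:
        return (x @ w) + b
    return x @ w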
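Because `lin` is built on `@`, the rank of `w` decides the output shape, which starts to matter once layers are stacked. The sketch below assumes a recent PyTorch and repeats `lin` (without the debug flag) so it runs on its own. It compares a 1-D weight vector, as used in the cells that follow, with a (512, 512) weight matrix, the shape a 512-to-512 linear layer would normally use.

```python
import torch


def lin(x, w, b=None):
    # same computation as the notebook's lin, without the debug flag
    return x @ w if b is None else (x @ w) + b


x = torch.randn(512, 512)  # 512 rows of 512 features

# 1-D weights: each matmul contracts away one dimension
w_vec = torch.randn(512)
h1 = lin(x, w_vec)   # (512, 512) @ (512,) -> (512,)
h2 = lin(h1, w_vec)  # (512,) @ (512,) -> 0-d scalar (a dot product)
print(h1.shape, h2.shape)

# a (512,) bias broadcasts a 0-d result back to shape (512,)
h2_b = lin(h1, w_vec, torch.zeros(512))
print(h2_b.shape)

# 2-D weights: the shape is preserved layer after layer
w_mat = torch.randn(512, 512)
g2 = lin(lin(x, w_mat), w_mat)
print(g2.shape)
```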

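If we stack a few instances of this linear layer, the feature map computed after just two layers already has a mean far from zero and a NaN std. The growth in magnitude is what is usually called activation explosion: with unit-variance inputs and weights, each output entry is a sum of 512 products, so its std is about sqrt(512), roughly 22.6 (compare the 21.4 printed after the first layer). The NaN, however, has a more mundane cause in the cell below: `w` is a 1-D vector, so the first `@` reduces the (512, 512) input to a (512,) vector and the second reduces that to a single scalar, whose std (computed with Bessel's correction) is undefined.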

In [4]:
a = torch.randn(512,512)  # input image

for i in range(100): 
    w = torch.randn(512)  # init weights for layer i (a 1-D vector, so each @ drops a dimension)
    
    a = lin(a, w)
    print(a.mean(), a.std())
    if torch.isinf(a.std()) or torch.isnan(a.std()):
        print(f"Numerical instability at layer {i}")
        break

tensor(-0.0630) tensor(21.4497)
tensor(-169.5533) tensor(nan)
Numerical instability at layer 1
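The NaN printed at layer 1 can be reproduced directly. After two `@` with a 1-D `w`, `a` is a 0-d tensor, and `Tensor.std()` applies Bessel's correction by default (dividing by n - 1), so the std of a single value is NaN rather than 0. The second half of this sketch is a variant that is not in the original notebook: the same no-bias stack with a (512, 512) weight matrix. There the explosion is genuine, with the std growing by about sqrt(512) per layer until float32 overflows. The seed is arbitrary.

```python
import math

import torch

torch.manual_seed(0)  # arbitrary seed, for reproducibility only

# std of a single value: n - 1 = 0 degrees of freedom -> NaN (PyTorch also warns)
single = torch.tensor(-169.5533)
print(single.std())

# hypothetical variant: no-bias stack with 2-D weights
a = torch.randn(512, 512)
first_std = None
unstable_layer = None
for i in range(100):
    w = torch.randn(512, 512)  # unscaled N(0, 1) weights
    a = a @ w
    s = a.std()
    if i == 0:
        first_std = s.item()  # expected to be close to sqrt(512)
    if torch.isinf(s) or torch.isnan(s):
        unstable_layer = i
        break

print(f"std after layer 0: {first_std:.1f} (sqrt(512) = {math.sqrt(512):.1f})")
print(f"first inf/NaN std at layer {unstable_layer}")
```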


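Adding a bias term lets the same loop run much further before the std becomes inf or NaN. Note that the bias is all zeros, so it does not change any value; what it changes is the shape. From the second layer on, `x @ w` is a scalar, and adding the (512,) bias broadcasts it back to a vector of 512 identical entries. The std of identical values is zero, which is why the output below shows either `0.` or values that are tiny relative to the mean (floating-point rounding) instead of NaN.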

In [5]:
a = torch.randn(512,512)  # input image


for i in range(100): 
    w = torch.randn(512)
    b = torch.zeros(512)

    a = lin(a, w, b)
    print(a.mean(), a.std())
    if torch.isinf(a.std()) or torch.isnan(a.std()):
        print(f"Numerical instability at layer {i}")
        break

tensor(1.1425) tensor(22.2082)
tensor(7.0174) tensor(0.)
tensor(-67.3624) tensor(2.2911e-05)
tensor(110.9664) tensor(2.2911e-05)
tensor(1554.6863) tensor(0.)
tensor(-21929.1367) tensor(0.)
tensor(130204.3828) tensor(0.0235)
tensor(-811657.5625) tensor(0.0626)
tensor(21114270.) tensor(4.0039)
tensor(2.1871e+08) tensor(16.0156)
tensor(6.2728e+09) tensor(512.5007)
tensor(-1.4847e+11) tensor(0.)
tensor(1.2078e+11) tensor(0.)
tensor(1.7717e+12) tensor(524800.7500)
tensor(-2.8269e+13) tensor(8396812.)
tensor(-5.0848e+14) tensor(2.0152e+08)
tensor(-1.9954e+15) tensor(4.0305e+08)
tensor(1.2028e+16) tensor(0.)
tensor(3.7500e+17) tensor(0.)
tensor(2.9286e+18) tensor(2.7515e+11)
tensor(3.0962e+19) tensor(8.8047e+12)
tensor(4.1758e+20) tensor(3.5219e+13)
tensor(8.0624e+21) tensor(5.6350e+14)
tensor(-7.9024e+22) tensor(0.)
tensor(1.3301e+24) tensor(0.)
tensor(-3.1645e+24) tensor(8.6554e+17)
tensor(-7.6284e+25) tensor(2.7697e+19)
tensor(5.1303e+26) tensor(1.4772e+20)
tensor(-1.5459e+28) tensor(3.545

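Note, however, that the mean of the feature map grows by roughly an order of magnitude per layer, while the std is either exactly 0 (all activations identical) or a rounding residue of that huge mean. A network whose activations behave like this will not be able to learn much, even though the loop now runs deeper.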
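The zero std values in the output above can be checked directly. In this sketch (same 1-D weight and zero-bias setup as that cell, with an arbitrary seed), every entry after the second layer is the same number. Any nonzero std PyTorch reports for such a tensor therefore comes from rounding while averaging, not from real variation across activations.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, for reproducibility only

a = torch.randn(512, 512)
for i in range(2):
    w = torch.randn(512)   # 1-D weights, as in the cell above
    b = torch.zeros(512)
    a = (a @ w) + b        # at i == 1: 0-d scalar + (512,) zeros -> broadcast

# all 512 activations hold the same value
all_equal = bool((a == a[0]).all())
print(a.shape, all_equal, a.std())
```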

## Is good init sufficient to fix the issue?

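Improving the init helps delay the activation explosion. Scaling the weights by 1/sqrt(512), which is what Xavier init gives when fan-in and fan-out are both 512, makes each output entry have roughly the same variance as its inputs instead of 512 times more.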

In [6]:
a = torch.randn(512,512)  # input image

for i in range(100): 
    # init params for every layer
    w = torch.randn(512) / math.sqrt(512)  # xavier init
    b = torch.zeros(512)

    a = lin(a, w, b)
    print(a.mean(), a.std())
    if torch.isinf(a.std()) or torch.isnan(a.std()):
        print(f"Numerical instability at layer {i}")
        break

tensor(0.0601) tensor(0.9640)
tensor(1.2635) tensor(0.)
tensor(-1.6608) tensor(2.3865e-07)
tensor(-1.7729) tensor(1.1933e-07)
tensor(4.8575) tensor(9.5461e-07)
tensor(1.2564) tensor(3.5798e-07)
tensor(-1.3141) tensor(1.1933e-07)
tensor(-0.8528) tensor(1.7899e-07)
tensor(-0.5903) tensor(1.7899e-07)
tensor(-0.3502) tensor(8.9494e-08)
tensor(0.1527) tensor(0.)
tensor(0.1190) tensor(2.2374e-08)
tensor(-0.0290) tensor(0.)
tensor(-0.0472) tensor(1.1187e-08)
tensor(-0.0199) tensor(1.8645e-09)
tensor(-0.0005) tensor(0.)
tensor(0.0005) tensor(8.7397e-11)
tensor(-0.0006) tensor(1.7479e-10)
tensor(0.0002) tensor(0.)
tensor(-0.0002) tensor(1.4566e-11)
tensor(0.0002) tensor(1.4566e-11)
tensor(0.0002) tensor(8.7397e-11)
tensor(0.0006) tensor(1.7479e-10)
tensor(5.2908e-05) tensor(0.)
tensor(-6.2868e-05) tensor(2.1849e-11)
tensor(2.4735e-05) tensor(1.8208e-12)
tensor(-6.6639e-06) tensor(1.8208e-12)
tensor(-6.3730e-06) tensor(4.5519e-13)
tensor(-9.1494e-06) tensor(0.)
tensor(-4.9633e-06) tensor(4.5519e
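For comparison, here is a sketch of the Xavier-init cell with a (512, 512) weight matrix instead of a 1-D vector. This is a variant, not what the cell above runs; the seed and the bounds used to check it are arbitrary. Because each weight has variance 1/512, every output entry has about the same variance as its inputs, so the std stays close to 1 for all 100 layers and no inf or NaN appears.

```python
import math

import torch

torch.manual_seed(0)  # arbitrary seed, for reproducibility only

a = torch.randn(512, 512)
stds = []
for i in range(100):
    # Xavier-style scaling: std = sqrt(2 / (fan_in + fan_out)) = 1 / sqrt(512) here
    w = torch.randn(512, 512) / math.sqrt(512)
    b = torch.zeros(512)
    a = (a @ w) + b
    stds.append(a.std().item())

print(f"min std {min(stds):.3f}, max std {max(stds):.3f}")
```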

Let's do an ablation study and try Xavier init without using the bias term.

In [7]:
a = torch.randn(512,512)  # input image

for i in range(100): 
    w = torch.randn(512) / math.sqrt(512)  # xavier init

    a = lin(a, w)
    print(a.mean(), a.std())
    if torch.isinf(a.std()) or torch.isnan(a.std()):
        print(f"Numerical instability at layer {i}")
        break

tensor(0.0239) tensor(0.9825)
tensor(-1.1993) tensor(nan)
Numerical instability at layer 1


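Removing the bias term brings the NaN back at layer 1, so in this setup good init alone is not enough. The mean at that point is only -1.2, though, so this NaN is not an overflow: with a 1-D `w`, the second `@` leaves a single scalar, and the std of one value is undefined. The zero bias in the previous cell only masked this by broadcasting that scalar back to 512 identical values. With a (512, 512) weight matrix, a zero-initialized bias changes no activation at all, so stability in the forward pass at init comes from the weight init. The role of the bias term is a different one: it gives each unit a learnable offset, which a layer without bias cannot represent, since it always maps a zero input to zero.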
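Two properties of the bias term are easy to check directly. This sketch assumes a (512, 512) weight matrix rather than the 1-D `w` used in the cells above. First, adding a zero bias, as those cells do, leaves every activation exactly unchanged. Second, a layer without bias maps an all-zero input to zero whatever its weights, so only a bias lets the layer shift its output by an offset. The offset value 3.0 is a hypothetical stand-in for a learned bias.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, for reproducibility only


def lin(x, w, b=None):
    # same computation as the notebook's lin, without the debug flag
    return x @ w if b is None else (x @ w) + b


w = torch.randn(512, 512) / 512 ** 0.5  # hypothetical 2-D Xavier-style weights
a = torch.randn(512, 512)

# 1) a zero bias changes nothing: this is exactly what lin adds after the matmul
y = lin(a, w)
same_with_zero_bias = torch.equal(y + torch.zeros(512), y)

# 2) without a bias, a zero input always gives a zero output
z = torch.zeros(1, 512)
out_no_bias = lin(z, w)

# with a bias, the same zero input is shifted by b
b = torch.full((512,), 3.0)  # hypothetical learned offset
out_with_bias = lin(z, w, b)

print(same_with_zero_bias, out_no_bias.abs().max().item(), out_with_bias[0, :3])
```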