In [1]:
import torch

# Why you need a good init
Take a vector `x` and a matrix `a`, both initialized randomly, then multiply `x` by `a` 100 times (as if it passed through 100 layers).

In [2]:
# Random input vector and random 512x512 weight matrix, entries drawn from N(0, 1).
x = torch.randn((512,))
a = torch.randn((512, 512))

In [3]:
# Apply the same linear map 100 times, as if x went through 100 layers.
for i in range(100):
    x = torch.matmul(a, x)

In [5]:
x.mean(), x.std()
# activation explosion: the values blew up, so both the mean and the std come out as nan

(tensor(nan), tensor(nan))

In [6]:
# Start again from a fresh random vector and a fresh N(0, 1) weight matrix.
x = torch.randn((512,))
a = torch.randn((512, 512))

In [7]:
# Repeat the multiplication, but stop at the first step where the standard
# deviation of x is nan; i is left holding the 0-based index of that step.
for i in range(100):
    x = a @ x
    # torch.isnan states the intent directly (the original relied on nan != nan).
    if torch.isnan(x.std()):
        break

In [8]:
# 0-based loop index at which the nan check triggered the break
i

28

In [9]:
# it only takes about 29 multiplications (the loop broke at 0-based index i = 28)

In [11]:
# Retry with the weights shrunk by a factor of 100.
x = torch.randn((512,))
a = 0.01 * torch.randn((512, 512))

In [13]:
# Same 100-layer pass with the down-scaled weights, then inspect the result.
for i in range(100):
    x = torch.matmul(a, x)
x.mean(), x.std()

(tensor(0.), tensor(0.))

With the smaller scale, every activation vanished to 0.

There are several strategies to initialize weight matrices, such as:
- use a standard deviation that makes sure x and Ax have exactly the same scale
- use an orthogonal matrix to initialize the weights (orthogonal matrices have the special property that they preserve the L2 norm, so x and Ax would have the same sum of squares)
- use spectral normalization on the matrix A: the spectral norm of A is the least possible number M such that torch.norm(A@x) <= M*torch.norm(x), so dividing A by M ensures the activations don't overflow (but they can still vanish)

# The first method is Xavier initialization
Use a scale equal to 1 / math.sqrt(n_in),
where n_in is the number of inputs.


In [14]:
import math

In [18]:
# N(0, 1) weights divided by sqrt(n_in), with n_in = 512.
x = torch.randn((512,))
a = torch.randn((512, 512)) / math.sqrt(512)

In [19]:
# Run the 100-layer pass again, this time with the rescaled weights.
for i in range(100):
    x = torch.matmul(a, x)

In [21]:
# it works: with the 1/sqrt(512) scaling the statistics stay finite and non-zero
x.mean(), x.std()

(tensor(-0.0474), tensor(1.4116))

In [22]:
# numeric value of the 1/sqrt(n_in) scale for n_in = 512
1 / math.sqrt(512)

0.044194173824159216

In [23]:
# reference point: statistics of a freshly drawn standard-normal vector
x = torch.randn(512)
x.mean(), x.std()

(tensor(-0.0072), tensor(1.0000))

In [43]:
# If the weights of A have a mean of 0,
# we can easily compute the standard deviation of y = A @ x.

In [35]:
# Average, over 100 independent random draws, the mean and the mean square
# of the entries of y = a @ x (unscaled N(0, 1) weights).
mean = 0.
sqr = 0.
for i in range(100):
    x = torch.randn((512,))
    a = torch.randn((512, 512))
    y = torch.matmul(a, x)  # matrix multiplication
    mean = mean + y.mean().item()
    sqr = sqr + (y ** 2).mean().item()
mean/100, sqr/100

(0.0738837055489421, 517.2750936889648)

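The mean of the squares is very close to 512 (= n_in), so the std of y is about sqrt(512). For a single product a*x, as long as a and x are independent and each has mean 0 and std 1, the product also has mean 0 and std 1; a sum of 512 such products therefore has variance 512. This can also be seen experimentally: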

In [41]:
y.mean()  # mean of the most recently computed y only (a single draw, so it is noisy)

tensor(0.7541)

In [42]:
# Scalar version of the experiment: average a*x and (a*x)^2 over 10000
# independent pairs of standard-normal numbers.
mean = 0.
sqr = 0.
for i in range(10000):
    x = torch.randn((1,))
    a = torch.randn((1,))
    y = a * x  # multiplying
    mean = mean + y.item()
    sqr = sqr + (y ** 2).item()
mean/10000, sqr/10000

(-0.0060097617130701565, 0.9454599582348421)

# Adding ReLU

In [44]:
# Scalar experiment again, now with a ReLU applied to each product a*x:
# average relu(a*x) and relu(a*x)^2 over 10000 independent pairs.
mean, sqr = 0., 0.
for i in range(10000):
    x = torch.randn(1)
    a = torch.randn(1)
    # clamp(min=0) is the ReLU, matching the matrix cells; .item() makes y a
    # plain Python float in every case (the original mixed int 0 and float).
    y = (a*x).clamp(min=0).item()
    mean += y
    sqr += y**2
mean/10000, sqr/10000

(0.316620453225357, 0.5099332958430325)

# Check on matrix product

In [45]:
# Matrix version with ReLU: average the mean and mean square of relu(a @ x)
# over 100 independent draws, using unscaled N(0, 1) weights.
mean = 0.
sqr = 0.
for i in range(100):
    x = torch.randn((512,))
    a = torch.randn((512, 512))
    y = torch.matmul(a, x).clamp(min=0)  # linear layer followed by ReLU
    mean = mean + y.mean().item()
    sqr = sqr + (y ** 2).mean().item()
mean/100, sqr/100

(9.074674410820007, 258.8311618041992)

In [46]:
# Same ReLU experiment, with the weights multiplied by sqrt(2 / 512).
mean = 0.
sqr = 0.
for i in range(100):
    x = torch.randn((512,))
    a = torch.randn((512, 512)) * math.sqrt(2 / 512)
    y = torch.matmul(a, x).clamp(min=0)  # linear layer followed by ReLU
    mean = mean + y.mean().item()
    sqr = sqr + (y ** 2).mean().item()
mean/100, sqr/100

(0.5642086517810821, 0.9965019994974136)