In [None]:
import math
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# Seed both NumPy's and PyTorch's global RNGs so the random draws below are repeatable.
np.random.seed(123)
torch.manual_seed(123)

In [None]:
# Experiment setup: a batch of n standard-normal inputs of width p, plus five
# square, bias-free linear layers that the sections below re-initialise.
n = 10
p = 1000
# Draw the inputs before constructing the layers so the RNG is consumed in the same order.
x = torch.randn((n, p))
fcs = [nn.Linear(in_features=p, out_features=p, bias=False) for _ in range(5)]
print(fcs)

# 固定方差初始化

In [None]:
# Fixed-std init, layer 1: every weight is drawn from N(0, sigma^2).
sigma = 0.02  # alternative value to experiment with: sigma = 0.1
nn.init.normal_(fcs[0].weight, std=sigma, mean=0.0)
a1 = fcs[0](x).tanh()
# Histogram of all layer-1 activations, flattened, shown as a density on [-1, 1].
plt.hist(a1.detach().numpy().ravel(), density=True, bins=50)
plt.xlim(-1, 1)

In [None]:
# Fixed-std init, layer 2: same sigma as before, fed with the layer-1 activations.
nn.init.normal_(fcs[1].weight, std=sigma, mean=0.0)
a2 = fcs[1](a1).tanh()
# Density histogram of the flattened layer-2 activations.
plt.hist(a2.detach().numpy().ravel(), density=True, bins=50)
plt.xlim(-1, 1)

In [None]:
# Fixed-std init, layer 3: same sigma, fed with the layer-2 activations.
nn.init.normal_(fcs[2].weight, std=sigma, mean=0.0)
a3 = fcs[2](a2).tanh()
# Density histogram of the flattened layer-3 activations.
plt.hist(a3.detach().numpy().ravel(), density=True, bins=50)
plt.xlim(-1, 1)

In [None]:
# Fixed-std init, layer 4: same sigma, fed with the layer-3 activations.
nn.init.normal_(fcs[3].weight, std=sigma, mean=0.0)
a4 = fcs[3](a3).tanh()
# Density histogram of the flattened layer-4 activations.
plt.hist(a4.detach().numpy().ravel(), density=True, bins=50)
plt.xlim(-1, 1)

In [None]:
# Fixed-std init, layer 5: same sigma, fed with the layer-4 activations.
nn.init.normal_(fcs[4].weight, std=sigma, mean=0.0)
a5 = fcs[4](a4).tanh()
# Density histogram of the flattened layer-5 activations.
plt.hist(a5.detach().numpy().ravel(), density=True, bins=50)
plt.xlim(-1, 1)

# Xavier 初始化

In [None]:
# Re-seed so the weight draws in this section are repeatable.
np.random.seed(123)
torch.manual_seed(123)

# Xavier-style init, layer 1: weights ~ N(0, 1/in_dim), where in_dim is the
# number of input features (second axis of the weight matrix).
in_dim = fcs[0].weight.size(1)
nn.init.normal_(fcs[0].weight, std=1.0 / math.sqrt(in_dim), mean=0.0)
a1 = fcs[0](x).tanh()
# Density histogram of the flattened layer-1 activations on [-1, 1].
plt.hist(a1.detach().numpy().ravel(), density=True, bins=50)
plt.xlim(-1, 1)

In [None]:
# Xavier-style init, layer 2: std = 1/sqrt(in_dim), fed with the layer-1 activations.
in_dim = fcs[1].weight.size(1)
nn.init.normal_(fcs[1].weight, std=1.0 / math.sqrt(in_dim), mean=0.0)
a2 = fcs[1](a1).tanh()
# Density histogram of the flattened layer-2 activations.
plt.hist(a2.detach().numpy().ravel(), density=True, bins=50)
plt.xlim(-1, 1)

In [None]:
# Xavier-style init, layer 3: std = 1/sqrt(in_dim), fed with the layer-2 activations.
in_dim = fcs[2].weight.size(1)
nn.init.normal_(fcs[2].weight, std=1.0 / math.sqrt(in_dim), mean=0.0)
a3 = fcs[2](a2).tanh()
# Density histogram of the flattened layer-3 activations.
plt.hist(a3.detach().numpy().ravel(), density=True, bins=50)
plt.xlim(-1, 1)

In [None]:
# Xavier-style init, layer 4: std = 1/sqrt(in_dim), fed with the layer-3 activations.
in_dim = fcs[3].weight.size(1)
nn.init.normal_(fcs[3].weight, std=1.0 / math.sqrt(in_dim), mean=0.0)
a4 = fcs[3](a3).tanh()
# Density histogram of the flattened layer-4 activations.
plt.hist(a4.detach().numpy().ravel(), density=True, bins=50)
plt.xlim(-1, 1)

In [None]:
# Xavier-style init, layer 5: std = 1/sqrt(in_dim), fed with the layer-4 activations.
in_dim = fcs[4].weight.size(1)
nn.init.normal_(fcs[4].weight, std=1.0 / math.sqrt(in_dim), mean=0.0)
a5 = fcs[4](a4).tanh()
# Density histogram of the flattened layer-5 activations.
plt.hist(a5.detach().numpy().ravel(), density=True, bins=50)
plt.xlim(-1, 1)

# 以 ReLU 为激活函数的 Xavier 初始化

In [None]:
# Re-seed so the weight draws in this section are repeatable.
np.random.seed(123)
torch.manual_seed(123)

# Xavier-style init (std = 1/sqrt(in_dim)) combined with a ReLU activation, layer 1.
in_dim = fcs[0].weight.size(1)
nn.init.normal_(fcs[0].weight, std=1.0 / math.sqrt(in_dim), mean=0.0)
a1 = fcs[0](x).relu()
# ReLU outputs are non-negative, so the histogram is viewed on [-1, 10].
plt.hist(a1.detach().numpy().ravel(), density=True, bins=50)
plt.xlim(-1, 10)

In [None]:
# Xavier-style init with ReLU, layer 2: fed with the layer-1 activations.
in_dim = fcs[1].weight.size(1)
nn.init.normal_(fcs[1].weight, std=1.0 / math.sqrt(in_dim), mean=0.0)
a2 = fcs[1](a1).relu()
# Density histogram of the flattened layer-2 activations.
plt.hist(a2.detach().numpy().ravel(), density=True, bins=50)
plt.xlim(-1, 10)

In [None]:
# Xavier-style init with ReLU, layer 3: fed with the layer-2 activations.
in_dim = fcs[2].weight.size(1)
nn.init.normal_(fcs[2].weight, std=1.0 / math.sqrt(in_dim), mean=0.0)
a3 = fcs[2](a2).relu()
# Density histogram of the flattened layer-3 activations.
plt.hist(a3.detach().numpy().ravel(), density=True, bins=50)
plt.xlim(-1, 10)

In [None]:
# Xavier-style init with ReLU, layer 4: fed with the layer-3 activations.
in_dim = fcs[3].weight.size(1)
nn.init.normal_(fcs[3].weight, std=1.0 / math.sqrt(in_dim), mean=0.0)
a4 = fcs[3](a3).relu()
# Density histogram of the flattened layer-4 activations.
plt.hist(a4.detach().numpy().ravel(), density=True, bins=50)
plt.xlim(-1, 10)

In [None]:
# Xavier-style init with ReLU, layer 5: fed with the layer-4 activations.
in_dim = fcs[4].weight.size(1)
nn.init.normal_(fcs[4].weight, std=1.0 / math.sqrt(in_dim), mean=0.0)
a5 = fcs[4](a4).relu()
# Density histogram of the flattened layer-5 activations.
plt.hist(a5.detach().numpy().ravel(), density=True, bins=50)
plt.xlim(-1, 10)

# Kaiming 初始化

In [None]:
# Re-seed so the weight draws in this section are repeatable.
np.random.seed(123)
torch.random.manual_seed(123)

# Kaiming (He) normal init for ReLU, layer 1: Var(w) = 2 / fan_in, so
# std = sqrt(2 / fan_in). Note this is sqrt(2)/sqrt(fan_in), NOT 2/sqrt(fan_in):
# the latter gives Var(w) = 4 / fan_in and doubles the activation scale per layer.
in_dim = fcs[0].weight.shape[1]
nn.init.normal_(fcs[0].weight, mean=0.0, std=math.sqrt(2.0 / in_dim))
a1 = torch.relu(fcs[0](x))
# Density histogram of the flattened layer-1 activations on [-1, 10].
plt.hist(a1.detach().numpy().reshape(-1), bins=50, density=True)
plt.xlim(-1, 10)

In [None]:
# Kaiming (He) normal init for ReLU, layer 2: std = sqrt(2 / fan_in)
# (i.e. sqrt(2)/sqrt(fan_in), not 2/sqrt(fan_in), which would give Var(w) = 4 / fan_in).
in_dim = fcs[1].weight.shape[1]
nn.init.normal_(fcs[1].weight, mean=0.0, std=math.sqrt(2.0 / in_dim))
a2 = torch.relu(fcs[1](a1))
# Density histogram of the flattened layer-2 activations on [-1, 10].
plt.hist(a2.detach().numpy().reshape(-1), bins=50, density=True)
plt.xlim(-1, 10)

In [None]:
# Kaiming (He) normal init for ReLU, layer 3: std = sqrt(2 / fan_in)
# (i.e. sqrt(2)/sqrt(fan_in), not 2/sqrt(fan_in), which would give Var(w) = 4 / fan_in).
in_dim = fcs[2].weight.shape[1]
nn.init.normal_(fcs[2].weight, mean=0.0, std=math.sqrt(2.0 / in_dim))
a3 = torch.relu(fcs[2](a2))
# Density histogram of the flattened layer-3 activations on [-1, 10].
plt.hist(a3.detach().numpy().reshape(-1), bins=50, density=True)
plt.xlim(-1, 10)

In [None]:
# Kaiming (He) normal init for ReLU, layer 4: std = sqrt(2 / fan_in)
# (i.e. sqrt(2)/sqrt(fan_in), not 2/sqrt(fan_in), which would give Var(w) = 4 / fan_in).
in_dim = fcs[3].weight.shape[1]
nn.init.normal_(fcs[3].weight, mean=0.0, std=math.sqrt(2.0 / in_dim))
a4 = torch.relu(fcs[3](a3))
# Density histogram of the flattened layer-4 activations on [-1, 10].
plt.hist(a4.detach().numpy().reshape(-1), bins=50, density=True)
plt.xlim(-1, 10)

In [None]:
# Kaiming (He) normal init for ReLU, layer 5: std = sqrt(2 / fan_in)
# (i.e. sqrt(2)/sqrt(fan_in), not 2/sqrt(fan_in), which would give Var(w) = 4 / fan_in).
in_dim = fcs[4].weight.shape[1]
nn.init.normal_(fcs[4].weight, mean=0.0, std=math.sqrt(2.0 / in_dim))
a5 = torch.relu(fcs[4](a4))
# Density histogram of the flattened layer-5 activations on [-1, 10].
plt.hist(a5.detach().numpy().reshape(-1), bins=50, density=True)
plt.xlim(-1, 10)