In [10]:
import torch
import torch.nn as nn
import torch.functional as F
import numpy as np
import random
import os,sys
from torchsummary import summary

In [4]:
# Make the project root (parent of the notebook's folder) importable so the
# local `tools` package can be found. The original
# dirname(abspath('<notebook>.ipynb')) was simply the current working
# directory; this assumes the kernel's cwd is the notebook's folder
# (Jupyter's usual default) — TODO confirm if launched elsewhere.
base_dir = os.path.abspath(os.getcwd())
work_dir = os.path.dirname(base_dir)
# Guard so re-running this cell does not keep appending duplicate entries.
if work_dir not in sys.path:
    sys.path.append(work_dir)

In [8]:
# Project-local helper (tools/common_tools.py); presumably seeds the RNGs used
# by the cells below — verify against its implementation.
from tools.common_tools import set_seed

SEED = 3407  # fixed seed so the printed activation stds are repeatable
set_seed(SEED)

In [30]:
class MLP(nn.Module):
    """A deep stack of equally sized ``nn.Linear`` + ReLU layers.

    It is used to watch how the activation standard deviation evolves with
    depth, with and without Kaiming initialisation.

    Args:
        neural_num: width of every layer (in_features == out_features).
        layers: number of Linear+ReLU layers to stack.
    """

    def __init__(self, neural_num, layers) -> None:
        super().__init__()
        # ModuleList registers each Linear as a submodule, so .parameters()
        # and .modules() (used by initialize) can see them.
        self.linears = nn.ModuleList(
            [nn.Linear(neural_num, neural_num) for _ in range(layers)]
        )
        self.neural_num = neural_num

    def forward(self, x):
        """Apply each Linear+ReLU layer in turn, printing the activation std.

        Stops early and returns the current activations as soon as the std
        becomes NaN. NaN means the activations are no longer numerically usable.
        """
        for i, linear in enumerate(self.linears):
            x = torch.relu(linear(x))

            # Compute the statistic once; it is both printed and checked.
            std = x.std()
            print('layers:{} , std:{}'.format(i, std))
            if torch.isnan(std):
                print('std is nan in layers:{}'.format(i))
                break

        return x

    def initialize(self):
        """Redraw every Linear weight with Kaiming (He) normal initialisation.

        ``nonlinearity='relu'`` matches the activation used in ``forward``.
        It yields the same gain (sqrt(2)) as the library default (leaky_relu
        with a=0), but states the intent explicitly. Biases are left unchanged.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # nn.init fills the tensor in place under no_grad, so going
                # through ``.data`` is unnecessary.
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')

In [31]:
# Experiment configuration: depth and width of the MLP, plus the batch size
# that the summary cells report.
layer_nums, neural_nums, batch_size = 10, 20, 16
net = MLP(layers=layer_nums, neural_num=neural_nums)

In [36]:
# Forward pass with the default (PyTorch) initialisation; MLP.forward prints
# the per-layer activation std. `net` was created on the CPU, and
# torchsummary's `summary` defaults to device="cuda", which fails on a GPU
# machine with a CPU model, so the device is pinned explicitly.
# Linear acts on the last dim, so the leading 10 is just an extra axis.
# NOTE(review): torchsummary seems to build its own dummy input, so
# `batch_size` may only affect the displayed shapes — confirm.
summary(net, input_size=(10, neural_nums), batch_size=batch_size, device="cpu")

layers:0 , std:0.18307073414325714
layers:1 , std:0.0909644365310669
layers:2 , std:0.0916116014122963
layers:3 , std:0.06463372707366943
layers:4 , std:0.07213423401117325
layers:5 , std:0.09453681111335754
layers:6 , std:0.0675988718867302
layers:7 , std:0.07824891805648804
layers:8 , std:0.05637740343809128
layers:9 , std:0.09160798788070679
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1               [16, 10, 20]             420
            Linear-2               [16, 10, 20]             420
            Linear-3               [16, 10, 20]             420
            Linear-4               [16, 10, 20]             420
            Linear-5               [16, 10, 20]             420
            Linear-6               [16, 10, 20]             420
            Linear-7               [16, 10, 20]             420
            Linear-8               [16, 10, 20]             420
            

In [19]:
# Redraw all Linear weights with Kaiming-normal init before re-running summary.
MLP.initialize(net)

In [20]:
# Same forward pass after Kaiming initialisation, so the per-layer stds
# printed by MLP.forward can be compared with the default-init run.
# `net` lives on the CPU and torchsummary's `summary` defaults to
# device="cuda" (breaks on a GPU machine), so the device is pinned explicitly.
# NOTE(review): torchsummary seems to build its own dummy input, so
# `batch_size` may only affect the displayed shapes — confirm.
summary(net, input_size=(10, neural_nums), batch_size=batch_size, device="cpu")

layers:0 , std:0.4215196669101715
layers:1 , std:0.4067298173904419
layers:2 , std:0.3231349289417267
layers:3 , std:0.27762874960899353
layers:4 , std:0.22679096460342407
layers:5 , std:0.24510356783866882
layers:6 , std:0.20481672883033752
layers:7 , std:0.2281114161014557
layers:8 , std:0.28698694705963135
layers:9 , std:0.3048616945743561
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1               [16, 10, 20]             420
            Linear-2               [16, 10, 20]             420
            Linear-3               [16, 10, 20]             420
            Linear-4               [16, 10, 20]             420
            Linear-5               [16, 10, 20]             420
            Linear-6               [16, 10, 20]             420
            Linear-7               [16, 10, 20]             420
            Linear-8               [16, 10, 20]             420
            Li

In [29]:
# Empirical tanh gain: ratio of input std to output std for standard-normal
# samples, compared with the value PyTorch recommends for tanh.
x = torch.randn(1000)
out = x.tanh()
gain = x.std().div(out.std())
print('gain:{}'.format(gain))

# calculate_gain('tanh') returns a fixed constant (5/3, see output below),
# not the empirical ratio above, so the two are close but not equal.
tanh_gain = nn.init.calculate_gain('tanh')
print('tanh_gain in pytorch:', tanh_gain)

gain:1.5884541273117065
tanh_gain in pytorch: 1.6666666666666667
