# Initialisation (`torch.nn.init`)

https://pytorch.org/docs/stable/nn.init.html

In [1]:
import torch
from torch import nn

In [3]:
# Example layers used throughout this notebook: a 2-D convolution
# (1 input channel, 3 output channels, 3x3 kernel) and a fully connected
# layer (10 input features, 1 output feature).
conv = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3)
linear = nn.Linear(in_features=10, out_features=1)

In [4]:
isinstance(conv,nn.Conv2d) # Check whether conv is an instance of nn.Conv2d

True

In [5]:

isinstance(linear,nn.Conv2d) # Check whether linear is an instance of nn.Conv2d

False

In [8]:
# Inspect the randomly initialised parameters (weights) of conv
conv.weight.data

tensor([[[[-0.1601,  0.0511, -0.3157],
          [ 0.1375, -0.3138, -0.1144],
          [-0.2575, -0.2136,  0.0251]]],


        [[[-0.2027, -0.3327,  0.3181],
          [-0.0728,  0.0657,  0.2611],
          [ 0.3107, -0.1458, -0.1220]]],


        [[[ 0.0473,  0.1609,  0.0596],
          [ 0.1150,  0.2460, -0.2638],
          [ 0.0539, -0.0785, -0.2971]]]])

In [9]:
# Inspect the initial weights of linear
linear.weight.data

tensor([[ 0.0568,  0.2411, -0.2641,  0.1228, -0.1983, -0.0790,  0.2232,  0.0200,
          0.1867,  0.1824]])

In [10]:
# Overwrite the conv weights in place using Kaiming (He) normal
# initialisation, then display the new values.
nn.init.kaiming_normal_(conv.weight.data)
conv.weight.data

tensor([[[[ 0.3782, -0.5009,  0.4688],
          [ 0.1973,  0.0392, -0.5790],
          [ 0.4397, -0.1050,  0.0361]]],


        [[[ 1.0495,  0.0351,  0.0784],
          [-1.0659, -0.5772,  0.0593],
          [-0.7584,  0.0585,  0.7883]]],


        [[[-0.4651,  0.0076, -0.3955],
          [-0.3594, -0.1236, -0.1103],
          [-0.0549,  0.2011,  0.3440]]]])

In [11]:
# Overwrite every linear weight in place with the constant 0.3, then
# display the new values.
nn.init.constant_(linear.weight.data, 0.3)
linear.weight.data

tensor([[0.3000, 0.3000, 0.3000, 0.3000, 0.3000, 0.3000, 0.3000, 0.3000, 0.3000,
         0.3000]])

## Wrapping the initialisation in a function

In [None]:
def initialise_weights(self):
    """Re-initialise, in place, the parameters of the supported layers.

    Iterates over ``self.modules()`` and applies a fixed scheme per layer
    type; modules of any other type are left untouched:

    * ``nn.Linear``      -- weight drawn from a normal distribution with
      ``mean=0.1`` (``std`` keeps the ``nn.init.normal_`` default), bias
      set to zeros.
    * ``nn.Conv2d``      -- weight from Xavier (Glorot) normal
      initialisation, bias set to the constant ``0.3``.
    * ``nn.BatchNorm2d`` -- weight set to ones, bias set to zeros.

    Args:
        self: Any object exposing ``modules()`` as ``nn.Module`` does;
            the function can be attached as a method or called directly
            with a model.

    Returns:
        None. Parameters are modified in place.
    """
    for m in self.modules():
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.1)
            # bias is None when the layer was built with bias=False.
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Conv2d):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, val=0.3)
        elif isinstance(m, nn.BatchNorm2d):
            # weight and bias are both None when affine=False, so guard
            # them the same way as the optional biases above.
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                # The original called the non-existent Tensor.zeros_();
                # nn.init.zeros_ is the in-place zero fill.
                nn.init.zeros_(m.bias)