In [1]:
import torch
import torch.nn as nn

In [2]:
# Small MLP: 4 input features -> 8 hidden units (ReLU) -> 1 output.
net = nn.Sequential(
    nn.Linear(4, 8),
    nn.ReLU(),
    nn.Linear(8, 1),
)
# Random mini-batch of two examples with four features each.
X = torch.rand(2, 4)
# Forward pass; as the last expression, its result is what the cell displays.
net(X)

tensor([[0.0361],
        [0.0292]], grad_fn=<AddmmBackward>)

In [3]:
# Parameters of the output layer (index 2 of the Sequential), keyed by name.
net[2].state_dict()

OrderedDict([('weight',
              tensor([[ 0.1857, -0.0757, -0.1531, -0.2865, -0.3501, -0.1025,  0.1997,  0.3516]])),
             ('bias', tensor([-0.0145]))])

In [5]:
# Parameters of the first Linear layer (index 0 of the Sequential), keyed by name.
net[0].state_dict()

OrderedDict([('weight',
              tensor([[ 0.2901, -0.2606,  0.4850, -0.1477],
                      [-0.2328, -0.1921,  0.1630, -0.1146],
                      [-0.0381,  0.1395,  0.2300,  0.1365],
                      [ 0.4929, -0.2967, -0.2600,  0.3102],
                      [ 0.4331,  0.3494, -0.0419,  0.1940],
                      [-0.1983,  0.3970, -0.2116, -0.4477],
                      [-0.2450, -0.3881, -0.2604, -0.1739],
                      [-0.1197,  0.4936,  0.2643, -0.0748]])),
             ('bias',
              tensor([ 0.2267,  0.2286, -0.3402, -0.2923, -0.4393,  0.1260, -0.4783, -0.3480]))])

In [7]:
# Underlying value tensor of the output layer's bias parameter.
net[2].bias.data

tensor([-0.0145])

In [11]:
# No backward pass has been run on this network yet, so the gradient is unset.
# Use an identity test: `== None` would dispatch to the tensor's overloaded
# `__eq__` once `.grad` holds a tensor, whereas `is None` always asks exactly
# "is the gradient missing?".
net[2].weight.grad is None

True

In [12]:
# List every parameter of the network as (qualified name, shape).
[(key, tensor.shape) for key, tensor in net.named_parameters()]

[('0.weight', torch.Size([8, 4])),
 ('0.bias', torch.Size([8])),
 ('2.weight', torch.Size([1, 8])),
 ('2.bias', torch.Size([1]))]

In [17]:
# Full parameter dictionary of the network; each key is prefixed with the
# layer's index inside the Sequential (e.g. '0.weight', '2.bias').
net.state_dict()

OrderedDict([('0.weight',
              tensor([[ 0.2901, -0.2606,  0.4850, -0.1477],
                      [-0.2328, -0.1921,  0.1630, -0.1146],
                      [-0.0381,  0.1395,  0.2300,  0.1365],
                      [ 0.4929, -0.2967, -0.2600,  0.3102],
                      [ 0.4331,  0.3494, -0.0419,  0.1940],
                      [-0.1983,  0.3970, -0.2116, -0.4477],
                      [-0.2450, -0.3881, -0.2604, -0.1739],
                      [-0.1197,  0.4936,  0.2643, -0.0748]])),
             ('0.bias',
              tensor([ 0.2267,  0.2286, -0.3402, -0.2923, -0.4393,  0.1260, -0.4783, -0.3480])),
             ('2.weight',
              tensor([[ 0.1857, -0.0757, -0.1531, -0.2865, -0.3501, -0.1025,  0.1997,  0.3516]])),
             ('2.bias', tensor([-0.0145]))])

In [26]:
def init_normal(m):
    """Initialise a linear layer with small Gaussian weights and a zero bias.

    Intended to be passed to ``nn.Module.apply``. Modules whose type is not
    exactly ``nn.Linear`` are left untouched.

    Args:
        m: The module currently being visited.
    """
    if type(m) is not nn.Linear:
        return
    nn.init.normal_(m.weight, mean=0, std=0.01)
    # A layer built with ``bias=False`` exposes ``bias`` as None; skip it
    # rather than failing inside ``nn.init.zeros_``.
    if m.bias is not None:
        nn.init.zeros_(m.bias)

# Fresh 4 -> 8 -> 2 network, then run the Gaussian initialiser over its modules.
net = nn.Sequential(
    nn.Linear(4, 8),
    nn.ReLU(),
    nn.Linear(8, 2),
)
net.apply(init_normal)

Sequential(
  (0): Linear(in_features=4, out_features=8, bias=True)
  (1): ReLU()
  (2): Linear(in_features=8, out_features=2, bias=True)
)

In [27]:
# Inspect the result of the initialisation: weights and biases of both Linear layers.
net.state_dict()

OrderedDict([('0.weight',
              tensor([[-0.0026,  0.0110, -0.0062,  0.0009],
                      [ 0.0033,  0.0094,  0.0045,  0.0001],
                      [-0.0008, -0.0104, -0.0125,  0.0002],
                      [ 0.0118, -0.0145,  0.0041,  0.0035],
                      [ 0.0064, -0.0013, -0.0065, -0.0061],
                      [-0.0140,  0.0231, -0.0037,  0.0081],
                      [ 0.0004,  0.0038,  0.0092, -0.0028],
                      [-0.0192, -0.0010,  0.0065, -0.0023]])),
             ('0.bias', tensor([0., 0., 0., 0., 0., 0., 0., 0.])),
             ('2.weight',
              tensor([[ 0.0060,  0.0019, -0.0040,  0.0039,  0.0061,  0.0004,  0.0011,  0.0069],
                      [ 0.0150, -0.0031,  0.0206,  0.0118, -0.0058, -0.0021, -0.0091,  0.0030]])),
             ('2.bias', tensor([0., 0.]))])

In [31]:
def init_constant(m):
    """Initialise a linear layer with all weights set to 1 and a zero bias.

    Intended to be passed to ``nn.Module.apply``. Modules whose type is not
    exactly ``nn.Linear`` are left untouched.

    Args:
        m: The module currently being visited.
    """
    if type(m) is not nn.Linear:
        return
    nn.init.constant_(m.weight, 1)
    # A layer built with ``bias=False`` exposes ``bias`` as None; skip it
    # rather than failing inside ``nn.init.zeros_``.
    if m.bias is not None:
        nn.init.zeros_(m.bias)
# Fresh 4 -> 8 -> 2 network, then run the constant initialiser over its modules.
net = nn.Sequential(
    nn.Linear(4, 8),
    nn.ReLU(),
    nn.Linear(8, 2),
)
net.apply(init_constant)

Sequential(
  (0): Linear(in_features=4, out_features=8, bias=True)
  (1): ReLU()
  (2): Linear(in_features=8, out_features=2, bias=True)
)

In [32]:
# Inspect the result of the constant initialisation on both Linear layers.
net.state_dict()

OrderedDict([('0.weight',
              tensor([[1., 1., 1., 1.],
                      [1., 1., 1., 1.],
                      [1., 1., 1., 1.],
                      [1., 1., 1., 1.],
                      [1., 1., 1., 1.],
                      [1., 1., 1., 1.],
                      [1., 1., 1., 1.],
                      [1., 1., 1., 1.]])),
             ('0.bias', tensor([0., 0., 0., 0., 0., 0., 0., 0.])),
             ('2.weight',
              tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
                      [1., 1., 1., 1., 1., 1., 1., 1.]])),
             ('2.bias', tensor([0., 0.]))])

In [33]:
def xavior(m):
    """Apply Xavier (Glorot) uniform initialisation to a linear layer's weight.

    Intended to be passed to ``nn.Module.apply``. Only modules whose type is
    exactly ``nn.Linear`` are modified, and only their weight; the bias is
    not touched.

    Args:
        m: The module currently being visited.
    """
    if type(m) != nn.Linear:
        return
    nn.init.xavier_uniform_(m.weight)
        
# `apply` can also be called on an individual layer, so each layer can get its
# own initialiser: Xavier-uniform for the first Linear, constant for the last.
# The cell displays the return value of the final call.
net[0].apply(xavior)
net[2].apply(init_constant)

Linear(in_features=8, out_features=2, bias=True)

In [34]:
# Inspect the per-layer initialisation: compare '0.weight' against '2.weight'.
net.state_dict()

OrderedDict([('0.weight',
              tensor([[ 0.6658,  0.0008, -0.1367,  0.3593],
                      [ 0.0882,  0.2098,  0.4736, -0.6813],
                      [-0.4519, -0.3348, -0.6562, -0.0252],
                      [ 0.2563,  0.5157,  0.1477, -0.5922],
                      [ 0.4333, -0.5923, -0.6029, -0.2108],
                      [-0.3912, -0.6810, -0.0799,  0.5464],
                      [ 0.5599,  0.0175, -0.0358,  0.2486],
                      [ 0.5628, -0.6520,  0.3760,  0.5864]])),
             ('0.bias', tensor([0., 0., 0., 0., 0., 0., 0., 0.])),
             ('2.weight',
              tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
                      [1., 1., 1., 1., 1., 1., 1., 1.]])),
             ('2.bias', tensor([0., 0.]))])

In [44]:
def my_init(m):
    """Custom initialiser: uniform weights in [-10, 10], small magnitudes zeroed.

    Intended to be passed to ``nn.Module.apply``. For modules whose type is
    exactly ``nn.Linear`` it prints the name and shape of the layer's first
    parameter, draws the weight uniformly from [-10, 10], and then multiplies
    by the mask ``|w| >= 5`` so that every entry with magnitude below 5
    becomes zero. The bias is not modified. Other modules are ignored.

    Args:
        m: The module currently being visited.
    """
    if type(m) is not nn.Linear:
        return
    # Only the first registered parameter is reported, so take it directly
    # instead of materialising a list of all parameters.
    name, param = next(iter(m.named_parameters()))
    print('Init ', name, param.shape)
    nn.init.uniform_(m.weight, -10, 10)
    # Mutate the parameter itself under no_grad rather than through `.data`,
    # so the in-place change stays visible to autograd's version tracking.
    with torch.no_grad():
        m.weight.mul_(m.weight.abs() >= 5)
        
# Re-initialise the network with the custom scheme; `my_init` prints one
# line for each Linear layer it initialises.
net.apply(my_init)

Init  weight torch.Size([8, 4])
Init  weight torch.Size([2, 8])


Sequential(
  (0): Linear(in_features=4, out_features=8, bias=True)
  (1): ReLU()
  (2): Linear(in_features=8, out_features=2, bias=True)
)

In [40]:
# Inspect the custom initialisation: weights are either zero or have magnitude >= 5.
# NOTE(review): this cell's execution count (40) is lower than the preceding
# cell's (44), so the stored output may predate the last run of `my_init`;
# re-run the notebook top to bottom to refresh it.
net.state_dict()

OrderedDict([('0.weight',
              tensor([[ 0.0000,  9.6739, -0.0000,  0.0000],
                      [-9.7339, -0.0000, -0.0000, -5.7298],
                      [ 7.2538, -0.0000,  9.6311,  0.0000],
                      [ 9.6488, -7.9599,  7.3452, -8.9995],
                      [ 9.4969, -0.0000, -9.3159, -0.0000],
                      [ 0.0000, -0.0000, -0.0000, -0.0000],
                      [-0.0000,  0.0000,  9.0828, -5.4764],
                      [-0.0000,  6.4324,  9.2892,  0.0000]])),
             ('0.bias', tensor([0., 0., 0., 0., 0., 0., 0., 0.])),
             ('2.weight',
              tensor([[ 6.0596,  6.2055,  5.0768,  6.1540,  6.1925,  8.2721, -9.0119,  6.2420],
                      [ 8.8428, -0.0000, -5.5997, -0.0000,  0.0000, -0.0000,  0.0000, -0.0000]])),
             ('2.bias', tensor([0., 0.]))])