In [78]:
# Layers and blocks
import torch
from torch import nn
from torch.nn import functional as F

In [79]:
# Multilayer perceptron assembled from stock layers:
# 20 input features -> 256 hidden units (ReLU) -> 10 outputs.
net = nn.Sequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10))

In [80]:
# Feed a random batch of 2 samples with 20 features each through the network.
x = torch.rand(2, 20)
net(x)

tensor([[ 0.0188,  0.1511, -0.1886, -0.1540, -0.0092, -0.1905,  0.0692,  0.0419,
         -0.1051,  0.1184],
        [ 0.0919,  0.0533, -0.2360, -0.2985,  0.0369, -0.2856, -0.0738, -0.0828,
         -0.2076,  0.1451]], grad_fn=<AddmmBackward0>)

In [81]:
class MLP(nn.Module):
    """Multilayer perceptron with one hidden layer, written as a custom module.

    Maps 20 input features to 10 outputs through a 256-unit hidden layer
    followed by a ReLU.
    """

    def __init__(self) -> None:
        # Initialise nn.Module before assigning any sub-layers.
        super().__init__()
        self.hidden = nn.Linear(20, 256)
        self.out = nn.Linear(256, 10)

    def forward(self, x):
        """Return the output-layer result for the input batch ``x``."""
        hidden_activation = F.relu(self.hidden(x))
        return self.out(hidden_activation)

In [82]:
# Run the Sequential model and the custom MLP on the same input;
# mlp(x) and mlp.forward(x) produce the same values.
x = torch.rand(2, 20)
mlp = MLP()
net(x), mlp(x), mlp.forward(x)

(tensor([[ 0.0926,  0.0705, -0.2660, -0.1389, -0.0780, -0.2122, -0.0457,  0.1269,
          -0.0186,  0.1293],
         [-0.0849,  0.1036, -0.1554, -0.1827, -0.0031, -0.1593, -0.0359,  0.0300,
           0.0347,  0.1464]], grad_fn=<AddmmBackward0>),
 tensor([[ 0.1987, -0.0509,  0.1977,  0.2436,  0.1256,  0.0195, -0.1403, -0.3043,
           0.3677,  0.0285],
         [ 0.0323, -0.0304,  0.0330,  0.1555,  0.2481,  0.1104,  0.0269, -0.2569,
           0.1228, -0.0268]], grad_fn=<AddmmBackward0>),
 tensor([[ 0.1987, -0.0509,  0.1977,  0.2436,  0.1256,  0.0195, -0.1403, -0.3043,
           0.3677,  0.0285],
         [ 0.0323, -0.0304,  0.0330,  0.1555,  0.2481,  0.1104,  0.0269, -0.2569,
           0.1228, -0.0268]], grad_fn=<AddmmBackward0>))

In [83]:
class MySequential(nn.Module):
    """Minimal re-implementation of ``nn.Sequential``.

    Sub-modules passed to the constructor are registered in order and applied
    one after another in ``forward``.
    """

    def __init__(self, *args) -> None:
        super().__init__()
        for idx, block in enumerate(args):
            # Register under a string key. nn.Module builds parameter names,
            # state_dict keys and its repr by concatenating these keys, so
            # using the Module object itself as the key breaks all of them
            # (and would also collapse a layer that is passed twice).
            self._modules[str(idx)] = block

    def forward(self, x):
        """Chain ``x`` through every registered sub-module in insertion order."""
        for block in self._modules.values():
            x = block(x)
        return x
# Build an MLP (20 -> 256 -> 20) with the hand-written container and run it on x.
net = MySequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 20))
net(x)

tensor([[ 0.1254, -0.0653, -0.0624,  0.1660,  0.1033, -0.0889,  0.2056,  0.0043,
          0.0832,  0.0389, -0.1497,  0.2734, -0.0480,  0.0406, -0.0152,  0.0106,
          0.1540, -0.1328, -0.1839, -0.1363],
        [ 0.0539,  0.0008,  0.0719,  0.0299,  0.0498, -0.0855,  0.1297,  0.0440,
         -0.0139,  0.0371, -0.1556,  0.1745, -0.2012,  0.0545, -0.0793,  0.0204,
          0.0206,  0.0042, -0.1819, -0.0934]], grad_fn=<AddmmBackward0>)

In [84]:
class NestMLP(nn.Module):
    """Module that mixes a nested ``nn.Sequential`` with a standalone layer.

    The Sequential part maps 20 -> 64 -> 32 features (ReLU after each linear
    layer); the final linear layer maps 32 -> 16.
    """

    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(20, 64), nn.ReLU(),
            nn.Linear(64, 32), nn.ReLU(),
        )
        self.linear = nn.Linear(32, 16)

    def forward(self, x):
        """Apply the nested Sequential first, then the final linear layer."""
        features = self.net(x)
        return self.linear(features)

# Compose a custom block with a built-in layer inside a Sequential container.
chimera = nn.Sequential(
    NestMLP(),
    nn.Linear(16, 20),
    # Modules can be nested very flexibly.
)
chimera(x)

tensor([[ 0.0257,  0.0657,  0.1901, -0.1300,  0.1691, -0.0320, -0.0717, -0.1572,
         -0.1928, -0.1286,  0.1095,  0.0991,  0.2782,  0.1584, -0.1364,  0.2207,
         -0.0403, -0.0764, -0.1334,  0.3239],
        [ 0.0130,  0.0713,  0.2061, -0.1274,  0.1729, -0.0142, -0.0856, -0.1506,
         -0.2004, -0.1227,  0.1151,  0.0928,  0.2610,  0.1746, -0.1339,  0.2353,
         -0.0464, -0.0812, -0.1155,  0.2954]], grad_fn=<AddmmBackward0>)

In [85]:
# Parameter Management
import torch
from torch import nn

# Small MLP used by the parameter-management examples: 4 -> 8 (ReLU) -> 1.
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
x = torch.rand(2, 4)  # batch of 2 samples with 4 features each
net(x)

tensor([[-0.2825],
        [-0.2501]], grad_fn=<AddmmBackward0>)

In [86]:
# Parameters (weight and bias) of the first linear layer.
net[0].state_dict()

OrderedDict([('weight',
              tensor([[-0.3477, -0.3291, -0.1199,  0.1273],
                      [-0.4685, -0.1462, -0.3261, -0.4817],
                      [ 0.2077, -0.3973,  0.0481, -0.4569],
                      [ 0.2406, -0.2041,  0.2241,  0.4448],
                      [-0.1539,  0.3636, -0.1811,  0.2011],
                      [-0.1051,  0.1996, -0.2217,  0.1891],
                      [-0.1465,  0.4027,  0.4486, -0.2097],
                      [-0.4343,  0.1832,  0.1637, -0.0583]])),
             ('bias',
              tensor([-0.0094, -0.2667, -0.4531, -0.0297, -0.4095, -0.3470, -0.1216,  0.1620]))])

In [87]:
# The ReLU layer holds no parameters, so its state dict is empty.
net[1].state_dict()

OrderedDict()

In [88]:
# Parameters (weight and bias) of the output linear layer.
net[2].state_dict()

OrderedDict([('weight',
              tensor([[-0.1475,  0.2821, -0.1007, -0.0862,  0.0621,  0.1009,  0.0137,  0.0990]])),
             ('bias', tensor([-0.2662]))])

In [89]:
# Inspect the output layer's bias: its type, the Parameter itself, the
# underlying tensor, and its gradient (None in the recorded output, as no
# backward pass has been run on this network).
print(type(net[2].bias))
print(net[2].bias)
print(net[2].bias.data)
print(net[2].bias.grad)

<class 'torch.nn.parameter.Parameter'>
Parameter containing:
tensor([-0.2662], requires_grad=True)
tensor([-0.2662])
None


In [90]:
# List (name, shape) pairs for the parameters of the first layer, then for
# the whole network (whose names are prefixed with the layer index).
print([(name, param.shape) for name, param in net[0].named_parameters()])
print([(name, param.shape) for name, param in net.named_parameters()])

[('weight', torch.Size([8, 4])), ('bias', torch.Size([8]))]
[('0.weight', torch.Size([8, 4])), ('0.bias', torch.Size([8])), ('2.weight', torch.Size([1, 8])), ('2.bias', torch.Size([1]))]


In [91]:
# NOTE: without parentheses this displays the bound method itself, not the
# parameter dictionary; the call form is net.state_dict().
net.state_dict

<bound method Module.state_dict of Sequential(
  (0): Linear(in_features=4, out_features=8, bias=True)
  (1): ReLU()
  (2): Linear(in_features=8, out_features=1, bias=True)
)>

In [92]:
# Full state dict of the network, keyed by "<layer index>.<parameter name>".
net.state_dict()

OrderedDict([('0.weight',
              tensor([[-0.3477, -0.3291, -0.1199,  0.1273],
                      [-0.4685, -0.1462, -0.3261, -0.4817],
                      [ 0.2077, -0.3973,  0.0481, -0.4569],
                      [ 0.2406, -0.2041,  0.2241,  0.4448],
                      [-0.1539,  0.3636, -0.1811,  0.2011],
                      [-0.1051,  0.1996, -0.2217,  0.1891],
                      [-0.1465,  0.4027,  0.4486, -0.2097],
                      [-0.4343,  0.1832,  0.1637, -0.0583]])),
             ('0.bias',
              tensor([-0.0094, -0.2667, -0.4531, -0.0297, -0.4095, -0.3470, -0.1216,  0.1620])),
             ('2.weight',
              tensor([[-0.1475,  0.2821, -0.1007, -0.0862,  0.0621,  0.1009,  0.0137,  0.0990]])),
             ('2.bias', tensor([-0.2662]))])

In [93]:
# Look up a single parameter tensor by its state-dict key.
net.state_dict()['2.weight']

tensor([[-0.1475,  0.2821, -0.1007, -0.0862,  0.0621,  0.1009,  0.0137,  0.0990]])

In [96]:
def block1():
    """Return a fresh 4 -> 8 -> 4 MLP block with a ReLU after each linear layer."""
    layers = (nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4), nn.ReLU())
    return nn.Sequential(*layers)

def block2():
    """Return a Sequential holding four ``block1()`` instances.

    The sub-blocks are registered under the names 'module 0' .. 'module 3'.
    """
    stacked = nn.Sequential()
    for idx in range(4):
        stacked.add_module(f'module {idx}', block1())
    return stacked


In [100]:
# Stack the nested blocks and finish with a 4 -> 1 linear layer.
rgnet = nn.Sequential(block2(), nn.Linear(4, 1))
rgnet(x)

tensor([[-0.3039],
        [-0.3039]], grad_fn=<AddmmBackward0>)

In [101]:
# Show the nested module hierarchy.
print(rgnet)

Sequential(
  (0): Sequential(
    (module 0): Sequential(
      (0): Linear(in_features=4, out_features=8, bias=True)
      (1): ReLU()
      (2): Linear(in_features=8, out_features=4, bias=True)
      (3): ReLU()
    )
    (module 1): Sequential(
      (0): Linear(in_features=4, out_features=8, bias=True)
      (1): ReLU()
      (2): Linear(in_features=8, out_features=4, bias=True)
      (3): ReLU()
    )
    (module 2): Sequential(
      (0): Linear(in_features=4, out_features=8, bias=True)
      (1): ReLU()
      (2): Linear(in_features=8, out_features=4, bias=True)
      (3): ReLU()
    )
    (module 3): Sequential(
      (0): Linear(in_features=4, out_features=8, bias=True)
      (1): ReLU()
      (2): Linear(in_features=8, out_features=4, bias=True)
      (3): ReLU()
    )
  )
  (1): Linear(in_features=4, out_features=1, bias=True)
)


In [105]:
# Nested modules are reachable by chained indexing: the outer block, its
# first sub-block, and that sub-block's first layer.
print(rgnet[0])
print(rgnet[0][0])
print(rgnet[0][0][0])

Sequential(
  (module 0): Sequential(
    (0): Linear(in_features=4, out_features=8, bias=True)
    (1): ReLU()
    (2): Linear(in_features=8, out_features=4, bias=True)
    (3): ReLU()
  )
  (module 1): Sequential(
    (0): Linear(in_features=4, out_features=8, bias=True)
    (1): ReLU()
    (2): Linear(in_features=8, out_features=4, bias=True)
    (3): ReLU()
  )
  (module 2): Sequential(
    (0): Linear(in_features=4, out_features=8, bias=True)
    (1): ReLU()
    (2): Linear(in_features=8, out_features=4, bias=True)
    (3): ReLU()
  )
  (module 3): Sequential(
    (0): Linear(in_features=4, out_features=8, bias=True)
    (1): ReLU()
    (2): Linear(in_features=8, out_features=4, bias=True)
    (3): ReLU()
  )
)
Sequential(
  (0): Linear(in_features=4, out_features=8, bias=True)
  (1): ReLU()
  (2): Linear(in_features=8, out_features=4, bias=True)
  (3): ReLU()
)
Linear(in_features=4, out_features=8, bias=True)


In [108]:
def init_normal(m):
    """Initialise an ``nn.Linear`` with N(0, 0.01^2) weights and a zero bias.

    Modules of any other exact type are left untouched.
    """
    if type(m) is not nn.Linear:
        return
    nn.init.normal_(m.weight, mean=0, std=0.01)
    nn.init.zeros_(m.bias)

# Apply the initializer across net, then inspect the first layer's parameters.
net.apply(init_normal)
net[0].weight.data, net[0].bias.data

(tensor([[-0.0064,  0.0070, -0.0032,  0.0026],
         [-0.0087, -0.0071,  0.0022, -0.0138],
         [-0.0129,  0.0016, -0.0145, -0.0087],
         [ 0.0100, -0.0131,  0.0007,  0.0015],
         [ 0.0224, -0.0061, -0.0101, -0.0188],
         [ 0.0062,  0.0058, -0.0091, -0.0037],
         [-0.0017, -0.0142,  0.0033,  0.0101],
         [ 0.0047,  0.0309,  0.0094,  0.0003]]),
 tensor([0., 0., 0., 0., 0., 0., 0., 0.]))

In [112]:
def init_constant(m):
    """Set every weight of an ``nn.Linear`` to 1 and every bias entry to 2.

    Modules of any other exact type are left untouched.
    """
    if type(m) is not nn.Linear:
        return
    nn.init.constant_(m.weight, 1)
    nn.init.constant_(m.bias, 2)

# Apply the constant initializer across net, then inspect the first layer.
net.apply(init_constant)
net[0].weight.data, net[0].bias.data

(tensor([[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]]),
 tensor([2., 2., 2., 2., 2., 2., 2., 2.]))

In [113]:
def xavier(m):
    """Apply Xavier (Glorot) uniform initialization to an ``nn.Linear`` weight.

    The bias is not modified, and modules of any other exact type are skipped.
    """
    if type(m) is not nn.Linear:
        return
    nn.init.xavier_uniform_(m.weight)

def init_42(m):
    """Set every bias entry of an ``nn.Linear`` to 42.

    The weight is not modified, and modules of any other exact type are skipped.
    """
    if type(m) is not nn.Linear:
        return
    nn.init.constant_(m.bias, 42)

In [115]:
# Different initializers can be applied to individual layers:
# Xavier for the first layer's weight, constant 42 for the last layer's bias.
net[0].apply(xavier)
net[2].apply(init_42)
print(net[0].weight.data)
print(net[2].bias.data)


tensor([[-0.6578, -0.1700,  0.0170,  0.6516],
        [-0.6848,  0.5857,  0.3264, -0.0347],
        [-0.6460,  0.5526,  0.6986, -0.4984],
        [ 0.0716, -0.5996,  0.3527, -0.5074],
        [ 0.0692, -0.2321, -0.3740, -0.5290],
        [ 0.4265, -0.5074,  0.4755,  0.3182],
        [-0.0331, -0.0484,  0.1610, -0.6345],
        [ 0.4860,  0.5008, -0.2605, -0.1428]])
tensor([42.])


In [127]:
def my_init(m):
    """Custom initializer for ``nn.Linear`` modules, with diagnostic printing.

    Weights are drawn from U(-10, 10); every entry whose absolute value is
    below 5 is then set to zero. Modules of any other exact type are skipped.
    """
    if type(m) is not nn.Linear:
        return
    # Announce the layer via the name and shape of its first parameter.
    first_name, first_param = next(iter(m.named_parameters()))
    print('Init', first_name, first_param.shape)
    nn.init.uniform_(m.weight, -10, 10)
    # Diagnostic only: this view keeps entries >= 5 (no abs), so large
    # negative weights print as 0 here even though the mask below keeps them.
    print(torch.where(m.weight.data >= 5, m.weight.data, 0))
    keep_mask = m.weight.data.abs() >= 5
    m.weight.data *= keep_mask
    print(m.weight.data)

# Run the custom initializer across net; the cell displays the module that
# apply() returns.
net.apply(my_init)
# net[0].weight

Init weight torch.Size([8, 4])
tensor([[0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 7.2166, 0.0000],
        [0.0000, 0.0000, 0.0000, 5.0048],
        [0.0000, 0.0000, 9.5269, 0.0000],
        [0.0000, 6.7699, 0.0000, 0.0000],
        [0.0000, 7.0299, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 9.1215, 6.1191]])
tensor([[-0.0000, -0.0000,  0.0000, -9.4322],
        [ 0.0000, -8.6026,  7.2166,  0.0000],
        [-0.0000, -0.0000, -9.0070,  5.0048],
        [-0.0000,  0.0000,  9.5269, -0.0000],
        [ 0.0000,  6.7699, -0.0000,  0.0000],
        [ 0.0000,  7.0299, -7.2648,  0.0000],
        [-0.0000, -0.0000, -0.0000, -0.0000],
        [ 0.0000,  0.0000,  9.1215,  6.1191]])
Init weight torch.Size([1, 8])
tensor([[0.0000, 5.3375, 5.3902, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
tensor([[-7.3344,  5.3375,  5.3902, -0.0000,  0.0000, -5.3590,  0.0000, -9.5614]])


Sequential(
  (0): Linear(in_features=4, out_features=8, bias=True)
  (1): ReLU()
  (2): Linear(in_features=8, out_features=1, bias=True)
)

In [128]:
# Parameter tying: the same layer object is inserted twice, so positions
# 2 and 4 of the network refer to one shared set of parameters.
share_layer = nn.Linear(8, 8)
net = nn.Sequential(
    nn.Linear(4, 8),
    nn.ReLU(),
    share_layer,
    nn.ReLU(),
    share_layer,
    nn.ReLU(),
    nn.Linear(8, 1)
)
net(x)

tensor([[-0.4059],
        [-0.4110]], grad_fn=<AddmmBackward0>)

In [132]:
# The weights at positions 2 and 4 compare equal element-wise ...
print(net[2].weight.data == net[4].weight.data)
# ... and still do after one of them is modified, because net[2] and net[4]
# are the same layer object rather than copies.
net[2].weight.data[0, 0] = 100
net[2].weight.data == net[4].weight.data

tensor([[True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True]])


tensor([[True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True]])