In [91]:
import torch
from torch import nn

# Two-layer MLP: 4 input features -> 8 hidden units (ReLU) -> 1 output.
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

# Random batch of 2 samples with 4 features each.
X = torch.rand(2, 4)
net(X)


tensor([[0.0418],
        [0.0148]], grad_fn=<AddmmBackward0>)

In [92]:
# Parameter access: state_dict() of the last layer (index 2) maps each
# parameter name ('weight', 'bias') to its tensor.
print(net[2].state_dict())

OrderedDict([('weight', tensor([[-0.1899,  0.2770, -0.1330,  0.2035,  0.1784,  0.0103,  0.3378,  0.3189]])), ('bias', tensor([-0.3332]))])


In [93]:
# Three views of the output layer's bias, one per line: its Python type, the
# Parameter itself, and the plain tensor exposed through `.data`.
print(type(net[2].bias), net[2].bias, net[2].bias.data, sep='\n')

<class 'torch.nn.parameter.Parameter'>
Parameter containing:
tensor([-0.3332], requires_grad=True)
tensor([-0.3332])


In [94]:
# Check whether a gradient is stored on the output layer's weight.
# `is None` is the idiomatic identity test and does not go through the
# tensor's overloaded `==` operator.
net[2].weight.grad is None

True

In [95]:
# Access all parameters at once: (name, shape) pairs for the first layer,
# then for the whole network; each listing is followed by an empty line.
print(*((name, param.shape) for name, param in net[0].named_parameters()), end='\n\n')
print(*((name, param.shape) for name, param in net.named_parameters()), end='\n\n')
# state_dict() offers another way to reach a parameter, keyed by its qualified name.
print(net.state_dict()['2.bias'].data)

('weight', torch.Size([8, 4])) ('bias', torch.Size([8]))

('0.weight', torch.Size([8, 4])) ('0.bias', torch.Size([8])) ('2.weight', torch.Size([1, 8])) ('2.bias', torch.Size([1]))

tensor([-0.3332])


In [96]:
# Collecting parameters from nested blocks
def block1():
    """Build one small MLP block: Linear(4 -> 8), ReLU, Linear(8 -> 4), ReLU.

    Returns:
        A fresh ``nn.Sequential`` holding the four layers in that order.
    """
    layers = [
        nn.Linear(4, 8),
        nn.ReLU(),
        nn.Linear(8, 4),
        nn.ReLU(),
    ]
    return nn.Sequential(*layers)

def block2():
    """Stack four independent ``block1()`` modules in one ``nn.Sequential``.

    The children are registered under the names ``block0`` .. ``block3``.

    Returns:
        The ``nn.Sequential`` containing the four named blocks.
    """
    stack = nn.Sequential()
    for i in range(4):
        # Each call to block1() creates a new module with its own parameters.
        child = block1()
        stack.add_module(f'block{i}', child)
    return stack

# Nest the four-block stack inside an outer Sequential and finish with a
# 4 -> 1 linear layer.
rgnet = nn.Sequential(block2(), nn.Linear(4, 1))
# The extra '\n' argument leaves an empty line after the printed architecture.
print(rgnet, '\n')
rgnet(X)

Sequential(
  (0): Sequential(
    (block0): Sequential(
      (0): Linear(in_features=4, out_features=8, bias=True)
      (1): ReLU()
      (2): Linear(in_features=8, out_features=4, bias=True)
      (3): ReLU()
    )
    (block1): Sequential(
      (0): Linear(in_features=4, out_features=8, bias=True)
      (1): ReLU()
      (2): Linear(in_features=8, out_features=4, bias=True)
      (3): ReLU()
    )
    (block2): Sequential(
      (0): Linear(in_features=4, out_features=8, bias=True)
      (1): ReLU()
      (2): Linear(in_features=8, out_features=4, bias=True)
      (3): ReLU()
    )
    (block3): Sequential(
      (0): Linear(in_features=4, out_features=8, bias=True)
      (1): ReLU()
      (2): Linear(in_features=8, out_features=4, bias=True)
      (3): ReLU()
    )
  )
  (1): Linear(in_features=4, out_features=1, bias=True)
) 



tensor([[-0.2962],
        [-0.2962]], grad_fn=<AddmmBackward0>)

In [97]:
# Access a parameter inside the nested blocks: the bias of the first Linear
# layer in the child registered as 'block1' (position 1 of the inner stack).
rgnet[0].block1[0].bias.data

tensor([-0.0097, -0.0301,  0.0405, -0.3921,  0.2059,  0.4871, -0.4064, -0.3059])

In [98]:
# Built-in initialisation (normal distribution)
def init_normal(m):
    """Initialise an ``nn.Linear`` layer: weights ~ N(0, 0.01**2), bias = 0.

    Intended to be passed to ``Module.apply``. Modules whose type is not
    exactly ``nn.Linear`` are left untouched.

    Args:
        m: The module to (possibly) initialise.
    """
    if type(m) is not nn.Linear:
        return
    nn.init.normal_(m.weight, mean=0, std=0.01)
    # A layer built with bias=False has ``bias`` set to None; there is
    # nothing to zero in that case.
    if m.bias is not None:
        nn.init.zeros_(m.bias)

# Print the architecture, run init_normal over every submodule via apply(),
# then display the first weight row and first bias entry of layer 0.
print(net)
net.apply(init_normal)
net[0].weight.data[0], net[0].bias.data[0]

Sequential(
  (0): Linear(in_features=4, out_features=8, bias=True)
  (1): ReLU()
  (2): Linear(in_features=8, out_features=1, bias=True)
)


(tensor([ 0.0058,  0.0119, -0.0056, -0.0117]), tensor(0.))

In [99]:
# Built-in initialisation (constant)
def init_constant(m):
    """Initialise an ``nn.Linear`` layer: every weight = pi, bias = 0.

    Intended to be passed to ``Module.apply``. Modules whose type is not
    exactly ``nn.Linear`` are left untouched.

    Args:
        m: The module to (possibly) initialise.
    """
    if type(m) is not nn.Linear:
        return
    nn.init.constant_(m.weight, torch.pi)
    # A layer built with bias=False has ``bias`` set to None; there is
    # nothing to zero in that case.
    if m.bias is not None:
        nn.init.zeros_(m.bias)

# Re-initialise every submodule with init_constant, then display the first
# weight row and first bias entry of layer 0.
net.apply(init_constant)
net[0].weight.data[0], net[0].bias.data[0]

(tensor([3.1416, 3.1416, 3.1416, 3.1416]), tensor(0.))

In [100]:
# Apply different initialisation methods to different blocks
def init_xavier(m):
    """Xavier-uniform initialise the weight of an ``nn.Linear`` layer.

    The bias is not modified. Modules whose type is not exactly
    ``nn.Linear`` are ignored.
    """
    if type(m) is not nn.Linear:
        return
    # Xavier is an initialisation scheme provided by torch.nn.init.
    nn.init.xavier_uniform_(m.weight)

def init_constant42(m):
    """Set every weight of an ``nn.Linear`` layer to 42.

    The bias is not modified. Modules whose type is not exactly
    ``nn.Linear`` are ignored.
    """
    if type(m) is not nn.Linear:
        return
    nn.init.constant_(m.weight, 42)

# Show the network, then initialise its two Linear layers differently:
# layer 0 via init_xavier, layer 2 via init_constant42.
print(net, '\n')
net[0].apply(init_xavier)
net[2].apply(init_constant42)
# First weight row of layer 0, followed by the full weight matrix of layer 2.
print(net[0].weight.data[0], '\n', net[2].weight.data)


Sequential(
  (0): Linear(in_features=4, out_features=8, bias=True)
  (1): ReLU()
  (2): Linear(in_features=8, out_features=1, bias=True)
) 

tensor([-0.1512, -0.5121,  0.2876, -0.4021]) 
 tensor([[42., 42., 42., 42., 42., 42., 42., 42.]])


In [101]:
# Custom initialisation
def my_init(m):
    """Custom initialiser for ``nn.Linear`` layers, for use with ``Module.apply``.

    Prints the name and shape of the layer's first parameter, fills the weight
    with uniform samples between -10 and 10, then zeroes every entry whose
    magnitude is below 5. Modules of any other type are ignored.
    """
    if type(m) is not nn.Linear:
        return
    first_name, first_param = next(iter(m.named_parameters()))
    print('Init', first_name, first_param.shape)
    nn.init.uniform_(m.weight, -10, 10)  # uniform initialisation
    keep_mask = m.weight.data.abs() >= 5  # True where |w| >= 5
    m.weight.data *= keep_mask  # entries with |w| < 5 are multiplied by 0

# Run my_init over every submodule, then display layer 0's weight Parameter.
net.apply(my_init)
net[0].weight

Init weight torch.Size([8, 4])
Init weight torch.Size([1, 8])


Parameter containing:
tensor([[ 0.0000, -0.0000,  0.0000, -0.0000],
        [-9.6676, -6.5232,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.0000, -8.6212],
        [ 0.0000, -0.0000, -7.1565,  0.0000],
        [-0.0000, -8.1832, -0.0000,  6.8759],
        [ 0.0000,  0.0000, -8.6269, -0.0000],
        [-5.6632, -0.0000, -0.0000,  0.0000],
        [-9.4955,  0.0000,  5.5236, -0.0000]], requires_grad=True)

In [102]:
# Parameters can always be overwritten directly through `.data`:
# shift every weight of layer 0 by one, then pin the top-left entry to 42.
net[0].weight.data += 1
net[0].weight.data[0, 0] = 42
net[0].weight.data

tensor([[42.0000,  1.0000,  1.0000,  1.0000],
        [-8.6676, -5.5232,  1.0000,  1.0000],
        [ 1.0000,  1.0000,  1.0000, -7.6212],
        [ 1.0000,  1.0000, -6.1565,  1.0000],
        [ 1.0000, -7.1832,  1.0000,  7.8759],
        [ 1.0000,  1.0000, -7.6269,  1.0000],
        [-4.6632,  1.0000,  1.0000,  1.0000],
        [-8.4955,  1.0000,  6.5236,  1.0000]])

In [104]:
# Parameter tying

# Keep a handle on the shared layer so its parameters are easy to refer to.
shared = nn.Linear(8, 8)
net = nn.Sequential(
    nn.Linear(4, 8),
    nn.ReLU(),
    shared,
    nn.ReLU(),
    shared,
    nn.ReLU(),
    nn.Linear(8, 1),
)

net(X)
# Compare the weights at positions 2 and 4 element by element.
print(net[2].weight.data == net[4].weight.data)
net[2].weight.data[0, 0] = 100
# Compare again after changing one entry through position 2: if both positions
# hold the same layer object (not two layers with equal initial values), the
# comparison is still True everywhere.
print(net[2].weight.data == net[4].weight.data)



tensor([[True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True]])
tensor([[True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True]])
