In [2]:
import torch
from torch import nn

net = nn.Sequential(
    nn.Linear(4, 8),
    nn.ReLU(),
    nn.Linear(8, 1),
)

X = torch.rand(size=(2,4))
net(X)


tensor([[0.5600],
        [0.7203]], grad_fn=<AddmmBackward0>)

## 参数访问

In [9]:
print(net[2].state_dict())

OrderedDict([('weight', tensor([[ 0.1112,  0.3345, -0.3205,  0.2873, -0.0705, -0.1037,  0.2488, -0.1382]])), ('bias', tensor([0.2541]))])


In [12]:
print(type(net[2].bias))
print(net[2].bias)
print(net[2].bias.data)

<class 'torch.nn.parameter.Parameter'>
Parameter containing:
tensor([0.2541], requires_grad=True)
tensor([0.2541])


In [15]:
net[2].weight.grad == None

True

In [20]:
# 一次性访问所有参数
print(*[(name, param.shape) for name, param in net[0].named_parameters()])
print()
print(*[(name, param.shape) for name, param in net.named_parameters()])
print()
# 提供了另一种访问网络参数的方式
print(net.state_dict()['2.bias'].data)

('weight', torch.Size([8, 4])) ('bias', torch.Size([8]))

('0.weight', torch.Size([8, 4])) ('0.bias', torch.Size([8])) ('2.weight', torch.Size([1, 8])) ('2.bias', torch.Size([1]))

tensor([0.2541])


In [22]:
# 嵌套块中的收集参数
def block1():
    return nn.Sequential(
        nn.Linear(4, 8), nn.ReLU(),
        nn.Linear(8, 4), nn.ReLU(),
    )

def block2():
    net = nn.Sequential()
    for i in range(4):
        net.add_module(f'block{i}', block1())
    return net

rgnet = nn.Sequential(
    block2(),
    nn.Linear(4, 1)
)
print(rgnet, '\n')
rgnet(X)

Sequential(
  (0): Sequential(
    (block0): Sequential(
      (0): Linear(in_features=4, out_features=8, bias=True)
      (1): ReLU()
      (2): Linear(in_features=8, out_features=4, bias=True)
      (3): ReLU()
    )
    (block1): Sequential(
      (0): Linear(in_features=4, out_features=8, bias=True)
      (1): ReLU()
      (2): Linear(in_features=8, out_features=4, bias=True)
      (3): ReLU()
    )
    (block2): Sequential(
      (0): Linear(in_features=4, out_features=8, bias=True)
      (1): ReLU()
      (2): Linear(in_features=8, out_features=4, bias=True)
      (3): ReLU()
    )
    (block3): Sequential(
      (0): Linear(in_features=4, out_features=8, bias=True)
      (1): ReLU()
      (2): Linear(in_features=8, out_features=4, bias=True)
      (3): ReLU()
    )
  )
  (1): Linear(in_features=4, out_features=1, bias=True)
) 



tensor([[0.2633],
        [0.2633]], grad_fn=<AddmmBackward0>)

In [26]:
# 访问嵌套块中的参数
rgnet[0][1][0].bias.data

tensor([-0.3652, -0.0021, -0.2909, -0.0807, -0.1586,  0.1459,  0.0528,  0.1851])

In [30]:
# 内置初始化（正态分布）
def init_normal(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, mean=0, std=0.01)
        nn.init.zeros_(m.bias)

print(net)
net.apply(init_normal)
net[0].weight.data[0], net[0].bias.data[0]

Sequential(
  (0): Linear(in_features=4, out_features=8, bias=True)
  (1): ReLU()
  (2): Linear(in_features=8, out_features=1, bias=True)
)


(tensor([-0.0080,  0.0112, -0.0010,  0.0011]), tensor(0.))

In [34]:
# 内置初始化（常数）
def init_constant(m):
    if type(m) == nn.Linear:
        nn.init.constant_(m.weight, torch.pi)
        nn.init.zeros_(m.bias)

net.apply(init_constant)
net[0].weight.data[0], net[0].bias.data[0]

(tensor([3.1416, 3.1416, 3.1416, 3.1416]), tensor(0.))

In [40]:
# 对不同块应用不同的初始化方法
def init_xavier(m):
    if type(m) == nn.Linear:
        nn.init.xavier_uniform_(m.weight) # xavier是一种初始化方法

def init_constant42(m):
    if type(m) == nn.Linear:
        nn.init.constant_(m.weight, 42)

print(net, '\n')
net[0].apply(init_xavier)
net[2].apply(init_constant42)
print(net[0].weight.data[0], '\n', net[2].weight.data)


Sequential(
  (0): Linear(in_features=4, out_features=8, bias=True)
  (1): ReLU()
  (2): Linear(in_features=8, out_features=1, bias=True)
) 

tensor([ 0.2367, -0.5904, -0.2505,  0.2385]) 
 tensor([[42., 42., 42., 42., 42., 42., 42., 42.]])


In [53]:
# 自定义初始化
def my_init(m):
    if type(m) == nn.Linear:
        print('Init', *[(name, param.shape) for name, param in m.named_parameters()])
        nn.init.uniform_(m.weight, -10, 10) # 均匀初始化
        m.weight.data *= m.weight.data.abs() >= 5 # >=5的数保存 其他的设为0

net.apply(my_init)
net[0].weight

Init ('weight', torch.Size([8, 4])) ('bias', torch.Size([8]))
Init ('weight', torch.Size([1, 8])) ('bias', torch.Size([1]))


Parameter containing:
tensor([[ 9.2807,  8.8804,  0.0000, -0.0000],
        [ 0.0000,  9.6369,  0.0000,  7.8594],
        [ 0.0000, -5.8587,  0.0000, -0.0000],
        [-8.9553, -0.0000,  8.6252, -0.0000],
        [-5.6556,  0.0000, -6.7826, -0.0000],
        [-8.0828, -0.0000, -0.0000, -0.0000],
        [ 8.8788,  7.7309,  0.0000,  0.0000],
        [-0.0000,  0.0000, -6.9658, -6.7726]], requires_grad=True)

In [57]:
# 始终可以直接设置参数
net[0].weight.data[:] += 1
net[0].weight.data[0, 0] = 42
net[0].weight.data

tensor([[42.0000, 12.8804,  4.0000,  4.0000],
        [ 4.0000, 13.6369,  4.0000, 11.8594],
        [ 4.0000, -1.8587,  4.0000,  4.0000],
        [-4.9553,  4.0000, 12.6252,  4.0000],
        [-1.6556,  4.0000, -2.7826,  4.0000],
        [-4.0828,  4.0000,  4.0000,  4.0000],
        [12.8788, 11.7309,  4.0000,  4.0000],
        [ 4.0000,  4.0000, -2.9658, -2.7726]])

In [58]:
# 参数绑定

# 给共享层一个名称，方便引用其参数
shared = nn.Linear(8, 8)
net = nn.Sequential(
    nn.Linear(4, 8), nn.ReLU(),
    shared, nn.ReLU(),
    shared, nn.ReLU(),
    nn.Linear(8, 1)
)

net(X)
# 检查参数是否相同
print(net[2].weight.data == net[4].weight.data)
net[2].weight.data[0, 0] = 100
# 确保两者实际上为同一个对象，而不是初始值相同的两个对象
print(net[2].weight.data == net[4].weight.data)



tensor([[True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True]])
tensor([[True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True]])
