In [2]:
import torch
import torch.nn as nn

# A small MLP: 4 input features -> 8 hidden units (ReLU) -> 1 output.
net = nn.Sequential(
    nn.Linear(4, 8),
    nn.ReLU(),
    nn.Linear(8, 1),
)

# Push a random batch of 2 samples through the network and show the result.
X = torch.rand(2, 4)
print(net(X))

tensor([[0.2839],
        [0.4507]], grad_fn=<AddmmBackward0>)


In [5]:
# Parameter access: inspect the layer stored at index 2 of the Sequential.
output_layer = net[2]
print(output_layer.state_dict())

# The bias is an nn.Parameter; `.data` exposes the underlying tensor.
output_bias = output_layer.bias
print(type(output_bias))
print(output_bias)
print(output_bias.data)

OrderedDict([('weight', tensor([[ 0.2886,  0.0419,  0.1251,  0.2836,  0.2142, -0.0825,  0.0351, -0.3389]])), ('bias', tensor([0.0422]))])
<class 'torch.nn.parameter.Parameter'>
Parameter containing:
tensor([0.0422], requires_grad=True)
tensor([0.0422])


In [8]:
# Print (name, shape) pairs: first for the layer at index 2 alone, then for the whole network.
for module in (net[2], net):
    print(*((name, param.shape) for name, param in module.named_parameters()))

('weight', torch.Size([1, 8])) ('bias', torch.Size([1]))
('0.weight', torch.Size([8, 4])) ('0.bias', torch.Size([8])) ('2.weight', torch.Size([1, 8])) ('2.bias', torch.Size([1]))


In [10]:
# Look the same bias up by its dotted key in the network-level state_dict.
params_by_name = net.state_dict()
print(params_by_name['2.bias'].data)

tensor([0.0422])


In [16]:
# Collecting parameters from nested blocks
def block1():
    """Return a fresh ``Linear(4, 8) -> ReLU -> Linear(8, 4) -> ReLU`` stack.

    Input and output both have 4 features, so several of these blocks can be
    chained one after another.
    """
    layers = [
        nn.Linear(4, 8),
        nn.ReLU(),
        nn.Linear(8, 4),
        nn.ReLU(),
    ]
    return nn.Sequential(*layers)

def block2():
    """Chain four independent ``block1()`` instances inside one ``nn.Sequential``.

    The children are registered under the explicit names 'block 0' .. 'block 3'.
    """
    container = nn.Sequential()
    for idx in range(4):
        child_name = f'block {idx}'
        # add_module registers the child under the given name; writing to
        # container._modules[child_name] directly is the lower-level alternative.
        container.add_module(child_name, block1())
    return container
        
# Full model: the nested stack of blocks followed by a 4 -> 1 linear head.
net = nn.Sequential(
    block2(),
    nn.Linear(4, 1),
)
print(net)

Sequential(
  (0): Sequential(
    (block 0): Sequential(
      (0): Linear(in_features=4, out_features=8, bias=True)
      (1): ReLU()
      (2): Linear(in_features=8, out_features=4, bias=True)
      (3): ReLU()
    )
    (block 1): Sequential(
      (0): Linear(in_features=4, out_features=8, bias=True)
      (1): ReLU()
      (2): Linear(in_features=8, out_features=4, bias=True)
      (3): ReLU()
    )
    (block 2): Sequential(
      (0): Linear(in_features=4, out_features=8, bias=True)
      (1): ReLU()
      (2): Linear(in_features=8, out_features=4, bias=True)
      (3): ReLU()
    )
    (block 3): Sequential(
      (0): Linear(in_features=4, out_features=8, bias=True)
      (1): ReLU()
      (2): Linear(in_features=8, out_features=4, bias=True)
      (3): ReLU()
    )
  )
  (1): Linear(in_features=4, out_features=1, bias=True)
)


In [17]:
# Drill into the nested containers: child 0 of net -> its child at index 1 -> that child's layer 0.
inner_linear = net[0][1][0]
print(inner_linear.bias.data)

tensor([ 0.4291, -0.3281,  0.2791, -0.2385, -0.0843,  0.4811, -0.0657, -0.2932])


In [12]:
# Parameter tying: the same layer object is placed at two positions of the
# network, so both positions use (and update) the very same weight and bias.
shared = nn.Linear(8, 8)
nn.init.constant_(shared.weight, 0.1)
nn.init.constant_(shared.bias, 0)

net = nn.Sequential(
    nn.Linear(4, 8),
    shared,
    shared,
    nn.Linear(8, 1),
)

# Synthetic linear-regression data: labels are features multiplied by true_w.
n_train = 100
true_w = torch.tensor([0.1 * k for k in range(1, 5)])
features = torch.normal(mean=0, std=0.01, size=(n_train, 4))
labels = features.mm(true_w.reshape(-1, 1))

# Mini-batch iterator over the (features, labels) pairs.
batch_size = 10
train_data = torch.utils.data.TensorDataset(features, labels)
train_iter = torch.utils.data.DataLoader(train_data, batch_size=batch_size)

loss = nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)

n_epochs = 10
for epoch in range(n_epochs):
    for X, y in train_iter:
        y_hat = net(X)
        l = loss(y_hat, y.reshape(y_hat.shape))
        # Clear gradients left over from the previous step before back-propagating.
        optimizer.zero_grad()
        l.backward()
        optimizer.step()

    # After each epoch, dump every parameter's value and its most recent gradient.
    print(f'epoch {epoch}：')
    for name, param in net.named_parameters():
        for tag, value in (('data:', param.data), ('grad:', param.grad)):
            print(name, tag, value)

epoch 0：
0.weight data: tensor([[-0.3811, -0.0793,  0.2477, -0.4919],
        [-0.4294, -0.0051,  0.0947, -0.1152],
        [-0.2401, -0.0871, -0.3302,  0.1950],
        [-0.4377, -0.4623, -0.1647,  0.1332],
        [-0.1423,  0.0505, -0.1142, -0.2791],
        [-0.4233, -0.2119,  0.4646,  0.3153],
        [ 0.0065, -0.1471,  0.0376,  0.1088],
        [-0.1480, -0.3733, -0.2106, -0.3045]])
0.weight grad: tensor([[-9.7476e-05, -5.8327e-05,  5.2640e-05,  1.0381e-05],
        [-9.6206e-05, -5.7567e-05,  5.1954e-05,  1.0246e-05],
        [-9.6702e-05, -5.7864e-05,  5.2221e-05,  1.0298e-05],
        [-9.6558e-05, -5.7778e-05,  5.2144e-05,  1.0283e-05],
        [-9.5922e-05, -5.7398e-05,  5.1800e-05,  1.0215e-05],
        [-9.6056e-05, -5.7477e-05,  5.1873e-05,  1.0230e-05],
        [-9.6103e-05, -5.7505e-05,  5.1898e-05,  1.0235e-05],
        [-9.6042e-05, -5.7469e-05,  5.1865e-05,  1.0228e-05]])
0.bias data: tensor([ 0.4666, -0.1522,  0.0950,  0.0188, -0.2894, -0.2218, -0.2006, -0.2326])
0