In [2]:
import torch
import torch.nn as nn

# Fix the global RNG so the parameter initialisation and the random input
# below are reproducible on Restart & Run All.
torch.manual_seed(0)

# A small MLP: 4 input features -> 8 hidden units (ReLU) -> 1 output.
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

# A batch of 2 examples with 4 features each, drawn uniformly from [0, 1).
X = torch.rand(size=[2, 4])
print(net(X))

tensor([[0.2839],
        [0.4507]], grad_fn=<AddmmBackward0>)


In [5]:
# Parameter access: inspect the output layer, net[2].
#  1. state_dict() -> OrderedDict mapping parameter name -> tensor
#  2. the class of the bias attribute (an nn.Parameter)
#  3. the Parameter itself (its repr shows requires_grad)
#  4. .data -> the plain tensor holding the values
print(
    net[2].state_dict(),
    type(net[2].bias),
    net[2].bias,
    net[2].bias.data,
    sep='\n',
)

OrderedDict([('weight', tensor([[ 0.2886,  0.0419,  0.1251,  0.2836,  0.2142, -0.0825,  0.0351, -0.3389]])), ('bias', tensor([0.0422]))])
<class 'torch.nn.parameter.Parameter'>
Parameter containing:
tensor([0.0422], requires_grad=True)
tensor([0.0422])


In [8]:
# (name, shape) pairs for the output layer alone, then for the whole network,
# where each name is prefixed with the index of the layer that owns it.
print(*((pname, p.shape) for pname, p in net[2].named_parameters()))
print(*((pname, p.shape) for pname, p in net.named_parameters()))

('weight', torch.Size([1, 8])) ('bias', torch.Size([1]))
('0.weight', torch.Size([8, 4])) ('0.bias', torch.Size([8])) ('2.weight', torch.Size([1, 8])) ('2.bias', torch.Size([1]))


In [10]:
# Look the output-layer bias up by its state_dict key. state_dict() already
# returns detached tensors (keep_vars=False by default), so no `.data` needed.
print(net.state_dict()['2.bias'])

tensor([0.0422])


In [16]:
# Collecting parameters from nested blocks
def block1():
    """Return a fresh 4 -> 4 feature MLP block.

    Layout: Linear(4, 8) -> ReLU -> Linear(8, 4) -> ReLU.
    """
    layers = [nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4), nn.ReLU()]
    return nn.Sequential(*layers)

def block2(num_blocks=4, make_block=None):
    """Stack ``num_blocks`` sub-blocks into one ``nn.Sequential``.

    Each sub-block is registered under the name ``'block {i}'``, so it shows
    up as a named child in ``print(net)`` and in ``named_parameters()`` keys.

    Args:
        num_blocks: how many sub-blocks to stack (default 4).
        make_block: zero-argument factory returning each sub-block module;
            ``None`` (the default) means ``block1``.

    Returns:
        An ``nn.Sequential`` containing the stacked sub-blocks in order.
    """
    if make_block is None:
        make_block = block1
    container = nn.Sequential()
    for i in range(num_blocks):
        # add_module is the public registration API (rather than writing to
        # the private ``_modules`` dict directly).
        container.add_module(f'block {i}', make_block())
    return container
        
net = nn.Sequential(
    block2(),         # nested container of block1() sub-networks
    nn.Linear(4, 1),  # output head: 4 features -> 1 value
)
print(net)

Sequential(
  (0): Sequential(
    (block 0): Sequential(
      (0): Linear(in_features=4, out_features=8, bias=True)
      (1): ReLU()
      (2): Linear(in_features=8, out_features=4, bias=True)
      (3): ReLU()
    )
    (block 1): Sequential(
      (0): Linear(in_features=4, out_features=8, bias=True)
      (1): ReLU()
      (2): Linear(in_features=8, out_features=4, bias=True)
      (3): ReLU()
    )
    (block 2): Sequential(
      (0): Linear(in_features=4, out_features=8, bias=True)
      (1): ReLU()
      (2): Linear(in_features=8, out_features=4, bias=True)
      (3): ReLU()
    )
    (block 3): Sequential(
      (0): Linear(in_features=4, out_features=8, bias=True)
      (1): ReLU()
      (2): Linear(in_features=8, out_features=4, bias=True)
      (3): ReLU()
    )
  )
  (1): Linear(in_features=4, out_features=1, bias=True)
)


In [17]:
# Fetch the same bias by its qualified parameter name:
# '0' = the block2() container, 'block 1' = its second sub-block,
# '0' = that sub-block's first Linear(4, 8).
print(dict(net.named_parameters())['0.block 1.0.bias'].data)

tensor([ 0.4291, -0.3281,  0.2791, -0.2385, -0.0843,  0.4811, -0.0657, -0.2932])


In [14]:
# Parameter tying: the same `shared` Linear(8, 8) instance is inserted twice,
# so both positions in `net` use one weight tensor and one bias tensor.
shared = nn.Linear(8, 8)
# nn.init.constant_ fills the Parameter in place without recording autograd
# history; it is the public replacement for writing through `.data`.
nn.init.constant_(shared.weight, 0.1)
nn.init.constant_(shared.bias, 0)

net = nn.Sequential(nn.Linear(4, 8),
                    shared,
                    shared,
                    nn.Linear(8, 1))
# Both slots hold the very same module object, hence the same Parameters.
assert net[1] is net[2] and net[1].weight is net[2].weight

# Synthetic regression data: labels = features @ true_w,
# with true_w = [0.1, 0.2, 0.3, 0.4].
n_train = 100
true_w = torch.tensor([0.1 * i for i in range(1, 5)])
features = torch.normal(size=(n_train, 4), mean=0, std=0.01)
labels = torch.mm(features, true_w.view(-1, 1))

batch_size = 10
train_data = torch.utils.data.TensorDataset(features, labels)
train_iter = torch.utils.data.DataLoader(train_data, batch_size=batch_size)

loss = nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)

n_epochs = 10
for epoch in range(n_epochs):
    for X, y in train_iter:
        optimizer.zero_grad()
        y_hat = net(X)
        batch_loss = loss(y_hat, y.view(y_hat.shape))
        batch_loss.backward()
        optimizer.step()

    print(f'epoch {epoch}：')
    # named_parameters() skips modules/Parameters it has already yielded, so
    # the tied layer is listed once (as '1.*') even though it sits at index 1
    # and 2; its .grad accumulates the contributions from both uses.
    for name, param in net.named_parameters():
        print(name, 'data:', param.data)
        print(name, 'grad:', param.grad)

epoch 0：
0.weight data: tensor([[-0.2823,  0.0223, -0.3545, -0.4966],
        [-0.4306, -0.3385, -0.0948,  0.2568],
        [ 0.0421,  0.2541,  0.1574, -0.0056],
        [ 0.3775, -0.2353,  0.1913,  0.4906],
        [-0.4121, -0.4571,  0.1292,  0.3607],
        [-0.4476, -0.2497, -0.1919, -0.1957],
        [ 0.4516, -0.1589, -0.4960, -0.3435],
        [-0.1545,  0.0650, -0.3726,  0.4620]])
0.weight grad: tensor([[ 3.0421e-05,  5.6648e-05, -1.7939e-06, -1.8932e-05],
        [ 3.0447e-05,  5.6697e-05, -1.7955e-06, -1.8948e-05],
        [ 3.0568e-05,  5.6923e-05, -1.8026e-06, -1.9024e-05],
        [ 3.0427e-05,  5.6660e-05, -1.7943e-06, -1.8936e-05],
        [ 3.0276e-05,  5.6379e-05, -1.7854e-06, -1.8842e-05],
        [ 3.0459e-05,  5.6719e-05, -1.7962e-06, -1.8955e-05],
        [ 3.0431e-05,  5.6668e-05, -1.7946e-06, -1.8939e-05],
        [ 3.0533e-05,  5.6858e-05, -1.8006e-06, -1.9002e-05]])
0.bias data: tensor([ 0.0470, -0.0183, -0.3293,  0.0361,  0.4184, -0.0554,  0.0146, -0.2501])
0