# 4.4 自定义层
## 4.4.1 不含模型参数的自定义层

In [35]:
import torch
from torch import nn

# Report the installed PyTorch version.
print(torch.__version__)

1.3.1


In [36]:
class CenteredLayer(nn.Module):
    """Layer without learnable parameters that re-centres its input.

    The mean is computed over every element of the input tensor, so the
    output keeps the input's shape and has an overall mean of roughly zero.
    """

    def __init__(self, **kwargs):
        # Any keyword arguments are handed straight to ``nn.Module``.
        super().__init__(**kwargs)

    def forward(self, x):
        """Return ``x`` with its global mean subtracted from every element."""
        center = x.mean()
        return x - center

In [37]:
# Instantiate the layer and apply it to a 1-D float tensor; each element
# comes back with the tensor's mean (3.0 here) subtracted.
layer = CenteredLayer()
layer(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float))

tensor([-2., -1.,  0.,  1.,  2.])

In [38]:
# Build a larger model: a linear layer (8 inputs -> 128 outputs) followed by
# the parameter-free centering layer.
net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())

In [39]:
# Push a random (4, 8) batch through the network. Because the last layer
# subtracts the global mean, the mean of the output should be zero up to
# floating-point error.
x = torch.rand(4, 8)
print(x)
y = net(x)
y.mean().item()

tensor([[0.0210, 0.4608, 0.9278, 0.4115, 0.8493, 0.6683, 0.7941, 0.1787],
        [0.0072, 0.2039, 0.7495, 0.2236, 0.8541, 0.7214, 0.8545, 0.2787],
        [0.2536, 0.4536, 0.6467, 0.4538, 0.1826, 0.0869, 0.2458, 0.2922],
        [0.2975, 0.1326, 0.5745, 0.6827, 0.4935, 0.6262, 0.4081, 0.3360]])


0.0

## 4.4.2 含模型参数的自定义层

In [40]:
class MyListDense(nn.Module):
    """Chain of bias-free dense layers whose weights live in a ``ParameterList``.

    The list holds three 4x4 weight matrices followed by one 4x1 matrix, all
    initialised from a standard normal distribution. Storing them in an
    ``nn.ParameterList`` registers each one as a parameter of the module.
    """

    def __init__(self):
        super().__init__()
        self.params = nn.ParameterList(
            [nn.Parameter(torch.randn(4, 4)) for _ in range(3)]
        )
        # Appending after construction also registers the new parameter.
        self.params.append(nn.Parameter(torch.randn(4, 1)))

    def forward(self, x):
        """Multiply ``x`` by every stored weight matrix in order.

        Args:
            x: 2-D tensor with 4 columns (``torch.mm`` requires matrices).

        Returns:
            Tensor of shape ``(x.shape[0], 1)``.
        """
        # Iterate the parameters directly instead of indexing by position.
        for weight in self.params:
            x = torch.mm(x, weight)
        return x
# Instantiate the list-backed layer and print its structure; the repr lists
# every entry registered in the ParameterList.
net = MyListDense()
print(net)

MyListDense(
  (params): ParameterList(
      (0): Parameter containing: [torch.FloatTensor of size 4x4]
      (1): Parameter containing: [torch.FloatTensor of size 4x4]
      (2): Parameter containing: [torch.FloatTensor of size 4x4]
      (3): Parameter containing: [torch.FloatTensor of size 4x1]
  )
)


In [41]:
class MyDictDense(nn.Module):
    """Bias-free dense layer that picks one of several named weight matrices.

    Weights are kept in an ``nn.ParameterDict`` under the keys ``'linear1'``
    (4x4), ``'linear2'`` (4x1) and ``'linear3'`` (4x2), each drawn from a
    standard normal distribution.
    """

    def __init__(self):
        super().__init__()
        first = nn.Parameter(torch.randn(4, 4))
        second = nn.Parameter(torch.randn(4, 1))
        self.params = nn.ParameterDict({'linear1': first, 'linear2': second})
        # Add a further entry after construction.
        third = nn.Parameter(torch.randn(4, 2))
        self.params.update({'linear3': third})

    def forward(self, x, choice='linear1'):
        """Return ``x`` matrix-multiplied by the weight stored under ``choice``.

        Args:
            x: 2-D tensor with 4 columns.
            choice: Key of the weight matrix to use; defaults to ``'linear1'``.
        """
        weight = self.params[choice]
        return torch.mm(x, weight)

# Instantiate the dict-backed layer and print its structure; the repr lists
# every key registered in the ParameterDict.
net = MyDictDense()
print(net)

MyDictDense(
  (params): ParameterDict(
      (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
      (linear2): Parameter containing: [torch.FloatTensor of size 4x1]
      (linear3): Parameter containing: [torch.FloatTensor of size 4x2]
  )
)


In [42]:
# Run the same (1, 4) input of ones through each named weight matrix; the
# output width follows the chosen matrix (4, 1 and 2 columns respectively).
x = torch.ones(1, 4)
print(net(x, 'linear1'))
print(net(x, 'linear2'))
print(net(x, 'linear3'))

tensor([[ 0.3321, -0.4752, -1.8057, -1.7255]], grad_fn=<MmBackward>)
tensor([[4.1323]], grad_fn=<MmBackward>)
tensor([[ 0.2826, -4.1887]], grad_fn=<MmBackward>)


In [43]:
# Custom layers compose like built-in ones. Inside nn.Sequential,
# MyDictDense is called with only the input, so it uses its default choice
# ('linear1') and yields a (1, 4) tensor, which MyListDense reduces to (1, 1).
net = nn.Sequential(
    MyDictDense(),
    MyListDense(),
)
print(net)
print(net(x))

Sequential(
  (0): MyDictDense(
    (params): ParameterDict(
        (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
        (linear2): Parameter containing: [torch.FloatTensor of size 4x1]
        (linear3): Parameter containing: [torch.FloatTensor of size 4x2]
    )
  )
  (1): MyListDense(
    (params): ParameterList(
        (0): Parameter containing: [torch.FloatTensor of size 4x4]
        (1): Parameter containing: [torch.FloatTensor of size 4x4]
        (2): Parameter containing: [torch.FloatTensor of size 4x4]
        (3): Parameter containing: [torch.FloatTensor of size 4x1]
    )
  )
)
tensor([[8.8466]], grad_fn=<MmBackward>)
