# 1. 不含模型参数的自定义层

In [5]:
import torch
from torch import nn

In [6]:
class CenteredLayer(nn.Module):
    """Parameter-free layer that shifts its input to have zero mean.

    The mean is computed over *all* elements of the input (a single scalar),
    not per sample or per feature. Keyword arguments are forwarded to
    ``nn.Module``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x):
        """Return ``x`` minus the scalar mean of all its elements."""
        global_mean = torch.mean(x)
        return x.sub(global_mean)

In [7]:
layer = CenteredLayer()
# Floats 1..5 have mean 3, so centering should yield -2..2.
x = torch.arange(1, 6, dtype=torch.float)
layer(x)

tensor([-2., -1.,  0.,  1.,  2.])

In [8]:
# Put the centering layer after a Linear layer: whatever the Linear layer
# outputs, the network's final output should have (numerically) zero mean.
net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())

y = net(torch.rand(4, 8))
float(y.mean())

1.1641532182693481e-10

# 2. 含模型参数的自定义层
如果一个Tensor是`nn.Parameter`实例，并且被赋值为模型（`nn.Module`）的属性，那么它会被自动添加到模型的参数列表中。所以在自定义含模型参数的层时，可以把参数定义为`nn.Parameter`实例。此外，还可以使用`nn.ParameterList`和`nn.ParameterDict`来管理多个参数。

In [28]:
# ParameterList
class MyDense(nn.Module):
    """Chain of matrix multiplications whose weights live in an ``nn.ParameterList``.

    Each ``nn.Parameter`` stored in the list is registered on the module, so
    it appears in ``parameters()`` and in the module repr.

    Args:
        in_features: width of the input and of every square weight (default 4).
        out_features: width of the output (default 1).
        num_square: how many ``in_features x in_features`` weights are applied
            before the final ``in_features x out_features`` weight (default 3).
    """

    def __init__(self, in_features=4, out_features=1, num_square=3):
        super().__init__()
        self.params = nn.ParameterList(
            [
                nn.Parameter(torch.randn(in_features, in_features))
                for _ in range(num_square)
            ]
        )
        # ParameterList supports append like a Python list; the appended
        # parameter is registered too.
        self.params.append(nn.Parameter(torch.randn(in_features, out_features)))

    def forward(self, x):
        """Return ``x @ W0 @ W1 @ ... @ Wn`` over the stored weights, in order.

        ``torch.mm`` requires 2-D operands, so ``x`` must be 2-D with
        ``in_features`` columns; the result has ``out_features`` columns.
        """
        for param in self.params:
            x = torch.mm(x, param)
        return x

In [29]:
net = MyDense()
# Bare last expression: the notebook displays the module repr, which lists
# every parameter registered through the ParameterList.
net

MyDense(
  (params): ParameterList(
      (0): Parameter containing: [torch.FloatTensor of size 4x4]
      (1): Parameter containing: [torch.FloatTensor of size 4x4]
      (2): Parameter containing: [torch.FloatTensor of size 4x4]
      (3): Parameter containing: [torch.FloatTensor of size 4x1]
  )
)

In [30]:
# Random 4x4 input; the last weight in MyDense is 4x1, so the output is 4x1.
x = torch.rand(4,4)
net(x)

tensor([[-2.5360],
        [-0.8346],
        [-1.7269],
        [-1.4515]], grad_fn=<MmBackward>)

In [32]:
# ParameterDict implementation
class MydictDense(nn.Module):
    """Layer holding several named weight matrices in an ``nn.ParameterDict``.

    ``forward`` multiplies its input by the single weight chosen by name, so
    one module can expose several alternative projections.
    """

    def __init__(self):
        super().__init__()
        # Entries are inserted in the order linear1, linear2, linear3. Keys
        # set on the ParameterDict are registered as module parameters.
        self.params = nn.ParameterDict()
        self.params['linear1'] = nn.Parameter(torch.randn(4, 4))
        self.params['linear2'] = nn.Parameter(torch.randn(4, 1))
        self.params['linear3'] = nn.Parameter(torch.randn(4, 2))

    def forward(self, x, choice='linear1'):
        """Multiply 2-D ``x`` by the weight stored under the key ``choice``."""
        weight = self.params[choice]
        return x.mm(weight)

net = MydictDense()
# The printed repr lists linear1..linear3, including linear3, which was
# added after the ParameterDict was constructed.
print(net)

MydictDense(
  (params): ParameterDict(
      (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
      (linear2): Parameter containing: [torch.FloatTensor of size 4x1]
      (linear3): Parameter containing: [torch.FloatTensor of size 4x2]
  )
)


In [34]:
x = torch.ones(1, 4)
# Send the same all-ones row through each named weight in turn; with an
# all-ones input, each result equals the column sums of the chosen weight.
for name in ('linear1', 'linear2', 'linear3'):
    print(net(x, name))

tensor([[ 0.0505,  3.6799,  1.3257, -0.7924]], grad_fn=<MmBackward>)
tensor([[1.0167]], grad_fn=<MmBackward>)
tensor([[ 0.0413, -2.7925]], grad_fn=<MmBackward>)
