# 自定义层

In [1]:
import torch
import torch.nn as nn

## 不含模型参数的自定义层

In [2]:
class ReduceMeanLayer(nn.Module):
    """Parameter-free layer that subtracts the overall mean from its input."""

    def __init__(self, **kwargs):
        # Keyword arguments are forwarded unchanged to nn.Module.
        super(ReduceMeanLayer, self).__init__(**kwargs)

    def forward(self, x):
        """Return ``x`` shifted by the mean taken over all of its elements."""
        overall_mean = x.mean()
        centered = x - overall_mean
        return centered

In [3]:
# Chain a 4 -> 2 linear layer with the parameter-free ReduceMeanLayer and
# print the resulting module structure.
net = nn.Sequential(nn.Linear(4, 2), 
                    ReduceMeanLayer())
print(net)

Sequential(
  (0): Linear(in_features=4, out_features=2, bias=True)
  (1): ReduceMeanLayer()
)


In [4]:
# ReduceMeanLayer subtracts the mean of its input, so the mean of the
# network output should be zero up to floating-point error.
x = torch.ones(4, 4)
y = net(x)
print(y.mean().item())

0.0


## 包含模型参数的自定义层

### 使用 ParameterList

In [18]:
class CustomListDense(nn.Module):
    """Chain of bias-free linear maps followed by a single output bias.

    All parameters live in one ``nn.ParameterList`` (``self.params``): one
    randomly initialised weight of shape ``(n[i], n[i + 1])`` for every
    consecutive pair of sizes in ``n``, followed by a zero-initialised bias
    of shape ``(n[-1],)`` as the last entry.

    Args:
        n: Sequence of layer widths, e.g. ``[10, 32, 10]``.
    """

    def __init__(self, n):
        super(CustomListDense, self).__init__()
        # Weights: one matrix per consecutive (n_in, n_out) pair.
        self.params = nn.ParameterList(
            [nn.Parameter(torch.randn(n_in, n_out))
             for n_in, n_out in zip(n[:-1], n[1:])]
        )
        # Bias: always the final entry, added once after the last product.
        self.params.append(nn.Parameter(torch.zeros(n[-1])))

    def forward(self, x):
        """Multiply ``x`` by every weight in order, then add the bias.

        Args:
            x: Tensor whose last dimension has size ``n[0]``. ``torch.matmul``
                is used instead of ``torch.mm`` so that 1-D inputs and inputs
                with extra leading batch dimensions are accepted as well;
                for 2-D inputs the result is the same as with ``torch.mm``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``n[-1]``.
        """
        entries = list(self.params)
        bias = entries[-1]
        for weight in entries[:-1]:
            x = torch.matmul(x, weight)
        return x + bias

In [20]:
# Build a CustomListDense with layer widths 10 -> 32 -> 10 and print the
# parameters registered in its ParameterList.
net = CustomListDense([10, 32, 10])
print(net)

CustomListDense(
  (params): ParameterList(
      (0): Parameter containing: [torch.FloatTensor of size 10x32]
      (1): Parameter containing: [torch.FloatTensor of size 32x10]
      (2): Parameter containing: [torch.FloatTensor of size 10]
  )
)


### 使用 ParameterDict

In [23]:
class CustomDictDense(nn.Module):
    """Two stacked affine layers whose parameters live in ParameterDicts.

    ``params1`` maps ``n[0]`` features to ``n[1]`` and ``params2`` maps
    ``n[1]`` features to ``n[2]``. Each dict holds a randomly initialised
    ``'weight'`` and a zero-initialised ``'bias'``.

    Args:
        n: Sequence of layer widths; only ``n[0]``, ``n[1]`` and ``n[2]``
            are read.
    """

    def __init__(self, n):
        super(CustomDictDense, self).__init__()
        n_in, n_hidden, n_out = n[0], n[1], n[2]
        self.params1 = nn.ParameterDict({
            'weight': nn.Parameter(torch.randn(n_in, n_hidden)),
            'bias': nn.Parameter(torch.zeros(n_hidden)),
        })
        self.params2 = nn.ParameterDict({
            'weight': nn.Parameter(torch.randn(n_hidden, n_out)),
            'bias': nn.Parameter(torch.zeros(n_out)),
        })

    def forward(self, x):
        """Apply both affine layers in sequence (no activation in between).

        Args:
            x: Tensor whose last dimension has size ``n[0]``. ``torch.matmul``
                is used instead of ``torch.mm`` so that 1-D inputs and inputs
                with extra leading batch dimensions are accepted as well;
                for 2-D inputs the result is the same as with ``torch.mm``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``n[2]``.
        """
        for layer in (self.params1, self.params2):
            x = torch.matmul(x, layer['weight']) + layer['bias']
        return x

In [24]:
# Build a CustomDictDense with layer widths 10 -> 32 -> 2 and print the
# parameters registered in its two ParameterDicts.
net = CustomDictDense([10, 32, 2])
print(net)

CustomDictDense(
  (params1): ParameterDict(
      (bias): Parameter containing: [torch.FloatTensor of size 32]
      (weight): Parameter containing: [torch.FloatTensor of size 10x32]
  )
  (params2): ParameterDict(
      (bias): Parameter containing: [torch.FloatTensor of size 2]
      (weight): Parameter containing: [torch.FloatTensor of size 32x2]
  )
)


### 组合自定义层

In [25]:
# Stack the two custom layers: the first maps 10 -> 32 -> 10 features and
# the second maps 10 -> 16 -> 2, so the first layer's output width matches
# the second layer's input width.
net = nn.Sequential(
    CustomListDense([10, 32, 10]),
    CustomDictDense([10, 16, 2])
)
print(net)

Sequential(
  (0): CustomListDense(
    (params): ParameterList(
        (0): Parameter containing: [torch.FloatTensor of size 10x32]
        (1): Parameter containing: [torch.FloatTensor of size 32x10]
        (2): Parameter containing: [torch.FloatTensor of size 10]
    )
  )
  (1): CustomDictDense(
    (params1): ParameterDict(
        (bias): Parameter containing: [torch.FloatTensor of size 16]
        (weight): Parameter containing: [torch.FloatTensor of size 10x16]
    )
    (params2): ParameterDict(
        (bias): Parameter containing: [torch.FloatTensor of size 2]
        (weight): Parameter containing: [torch.FloatTensor of size 16x2]
    )
  )
)


In [26]:
# Run a random batch of 20 rows with 10 features each through the stacked
# layers and print the result.
x = torch.randn(20, 10)
y = net(x)
print(y)

tensor([[-593.8731,  540.7744],
        [  64.8773,  -10.4726],
        [ -58.9383,  -10.2243],
        [-369.0108,  273.4879],
        [ -92.3642, -105.4754],
        [ 559.9776,    8.4848],
        [ 377.1218,  -41.2701],
        [ -45.4641,  490.8425],
        [ 330.3158, -199.8192],
        [ 224.3984, -465.1309],
        [ 842.5359, -429.8053],
        [-468.3011,  354.9525],
        [ 824.1006, -463.4549],
        [-125.5716,   11.0276],
        [-332.8697, -244.5576],
        [ 198.2209,  -23.7020],
        [-300.0633,  -20.0615],
        [ -36.0695, -113.7194],
        [-266.0639,  354.3293],
        [-167.6408,  586.8952]], grad_fn=<AddBackward0>)
