# 自定义层

In [1]:
import torch
import torch.nn as nn

## 不含模型参数的自定义层

In [2]:
class ReduceMeanLayer(nn.Module):
    """Parameter-free layer that centers its input.

    Subtracts the mean taken over *all* elements of the input tensor (one
    scalar, not a per-row or per-column mean), so the output's mean is 0 up
    to floating-point rounding.
    """

    def __init__(self, **kwargs):
        # No learnable state; any keyword arguments are forwarded to nn.Module.
        super().__init__(**kwargs)

    def forward(self, x):
        """Return ``x`` shifted by its global mean."""
        global_mean = x.mean()
        return x - global_mean

In [3]:
# Linear layer (4 -> 2 features) followed by the parameter-free centering layer.
net = nn.Sequential(
    nn.Linear(4, 2),
    ReduceMeanLayer(),
)
print(net)

Sequential(
  (0): Linear(in_features=4, out_features=2, bias=True)
  (1): ReduceMeanLayer()
)


In [4]:
# Feed four all-ones rows through the network. The last layer subtracts the
# global mean, so the printed mean should be 0 up to float rounding.
x = torch.ones((4, 4))
y = net(x)
print(y.mean().item())

0.0


## 包含模型参数的自定义层

### 使用 ParameterList

In [18]:
class CustomListDense(nn.Module):
    """Stack of weight matrices held in an ``nn.ParameterList``, plus one bias.

    ``n`` lists the feature sizes, e.g. ``[10, 32, 2]``. One weight matrix of
    shape ``(n[i], n[i + 1])`` is created for every consecutive pair in ``n``
    (standard-normal init), followed by a single bias vector of size
    ``n[-1]`` (zero init). The bias is stored as the last entry of the list.

    Note: there is no non-linearity and no bias between the matrix products,
    so the whole stack computes a single affine map.
    """

    def __init__(self, n):
        super().__init__()
        weights = [
            nn.Parameter(torch.randn(n_in, n_out))
            for n_in, n_out in zip(n[:-1], n[1:])
        ]
        self.params = nn.ParameterList(weights)
        self.params.append(nn.Parameter(torch.zeros(n[-1])))  # bias (last entry)

    def forward(self, x):
        """Multiply ``x`` through every weight matrix in order, then add the bias.

        When ``n`` has two or more entries, ``torch.mm`` requires ``x`` to be a
        2-D matrix whose second dimension is ``n[0]``.
        """
        # Every entry except the last is a weight matrix; the last is the bias.
        *weights, bias = self.params
        for weight in weights:
            x = torch.mm(x, weight)
        return x + bias

In [20]:
# Feature sizes 10 -> 32 -> 2: weight matrices of 10x32 and 32x2 plus a bias
# of size 2, which is what the printed module summary for this cell shows.
net = CustomListDense([10, 32, 2])
print(net)

CustomListDense(
  (params): ParameterList(
      (0): Parameter containing: [torch.FloatTensor of size 10x32]
      (1): Parameter containing: [torch.FloatTensor of size 32x2]
      (2): Parameter containing: [torch.FloatTensor of size 2]
  )
)


### 使用 ParameterDict

In [23]:
class CustomDictDense(nn.Module):
    """Two affine layers whose parameters live in ``nn.ParameterDict``s.

    ``n`` holds three feature sizes ``[n_in, n_hidden, n_out]``. Each layer
    owns a ``'weight'`` matrix (standard-normal init) and a ``'bias'`` vector
    (zero init). The first layer is stored as ``params1`` and the second as
    ``params2``. No activation is applied between them.
    """

    def __init__(self, n):
        super().__init__()

        def affine(fan_in, fan_out):
            # One layer's parameters: weight of shape (fan_in, fan_out), bias of fan_out.
            return nn.ParameterDict({
                'weight': nn.Parameter(torch.randn(fan_in, fan_out)),
                'bias': nn.Parameter(torch.zeros(fan_out)),
            })

        # params1 is built first, so its weight is drawn from the RNG before params2's.
        self.params1 = affine(n[0], n[1])
        self.params2 = affine(n[1], n[2])

    def forward(self, x):
        """Apply ``x @ weight + bias`` with ``params1`` and then with ``params2``."""
        for layer in (self.params1, self.params2):
            x = torch.mm(x, layer['weight']) + layer['bias']
        return x

In [24]:
# Two dict-backed affine layers: 10 -> 32 -> 2 features.
net = CustomDictDense(n=[10, 32, 2])
print(net)

CustomDictDense(
  (params1): ParameterDict(
      (bias): Parameter containing: [torch.FloatTensor of size 32]
      (weight): Parameter containing: [torch.FloatTensor of size 10x32]
  )
  (params2): ParameterDict(
      (bias): Parameter containing: [torch.FloatTensor of size 2]
      (weight): Parameter containing: [torch.FloatTensor of size 32x2]
  )
)


**组合自定义层**

In [27]:
# Chain both custom layers. CustomListDense maps 10 -> 10 features, which is
# exactly the input width that CustomDictDense([10, 16, 2]) expects.
net = nn.Sequential(CustomListDense([10, 32, 10]), CustomDictDense([10, 16, 2]))
print(net)

Sequential(
  (0): CustomListDense(
    (params): ParameterList(
        (0): Parameter containing: [torch.FloatTensor of size 10x32]
        (1): Parameter containing: [torch.FloatTensor of size 32x10]
        (2): Parameter containing: [torch.FloatTensor of size 10]
    )
  )
  (1): CustomDictDense(
    (params1): ParameterDict(
        (bias): Parameter containing: [torch.FloatTensor of size 16]
        (weight): Parameter containing: [torch.FloatTensor of size 10x16]
    )
    (params2): ParameterDict(
        (bias): Parameter containing: [torch.FloatTensor of size 2]
        (weight): Parameter containing: [torch.FloatTensor of size 16x2]
    )
  )
)


In [28]:
# A random batch of 20 rows with 10 features each; the combined net maps it
# to 20 rows of 2 outputs.
x = torch.randn((20, 10))
y = net(x)
print(y)

tensor([[-112.1467,   18.4507],
        [  76.1977,  200.5879],
        [ 142.6300, -108.2325],
        [ 150.3695, -208.1327],
        [  88.2661,  102.7652],
        [ -26.9573,  -38.4968],
        [ 119.7113,   83.5226],
        [-141.3333, -191.5485],
        [ 168.4373,   96.2810],
        [ -84.7361,  317.9265],
        [ -97.1944,   79.0032],
        [  55.8315,  217.4391],
        [  -8.5072,  208.8832],
        [ 226.3479,   41.3147],
        [ 167.2700, -135.3372],
        [  41.5516, -103.4281],
        [ -19.0259, -141.9671],
        [ -30.7916,  -64.6191],
        [-115.7281,  -82.0340],
        [  77.6595, -217.5599]], grad_fn=<AddBackward0>)
