In [1]:
# 自定义层
# 深度学习的一个魅力在于神经网络中各式各样的层，本节介绍如何继承 nn.Module 来自定义层，使其可以像内置层一样被重复调用

import torch
from torch import nn


In [4]:
class CenteredLayer(nn.Module):
  """Parameter-free layer that shifts its input so the result has zero mean.

  The mean is taken over *all* elements of the input (not per row or per
  feature), so ``forward(x)`` always has a global mean of ~0.
  """

  def __init__(self, **kwargs):
    # No learnable state; any kwargs are forwarded untouched to nn.Module.
    super().__init__(**kwargs)

  def forward(self, x):
    """Return ``x`` minus the scalar mean of every element of ``x``."""
    overall_mean = x.mean()
    return x.sub(overall_mean)


In [6]:
layer = CenteredLayer()
# Input is [1., 2., 3., 4., 5.] (mean 3), so the expected output is [-2, -1, 0, 1, 2].
layer(torch.arange(1, 6, dtype=torch.float))


tensor([-2., -1.,  0.,  1.,  2.])

In [7]:
# Seed so the Linear initialisation and the random input are reproducible on re-run.
torch.manual_seed(0)
net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())

y = net(torch.rand(4, 8))
# CenteredLayer subtracts the global mean, so this should be ~0; any tiny
# non-zero value (e.g. ~1e-9) is float32 rounding error, not a bug.
y.mean().item()

-9.313225746154785e-10

In [8]:
# Parameter 类是 Tensor 的子类：当一个 Parameter 被赋值为 Module 的属性时，它会自动被添加到模型的参数列表中
# ParameterList 接受一个Parameter实例的列表作为输入然后得到一个参数列表
# 也可以使用append和extend在列表后面新增参数

class MyDense(nn.Module):
  """Chain of bias-free matrix multiplications whose weights live in an nn.ParameterList.

  Demonstrates that Parameters held in a ParameterList are registered with the
  module (they show up in ``parameters()`` and in ``print(module)``).

  Args:
    in_features: width of the input and of every square hidden weight (default 4).
    num_hidden: number of ``(in_features, in_features)`` weights (default 3).
    out_features: width of the final weight, i.e. of the output (default 1).
  """

  def __init__(self, in_features=4, num_hidden=3, out_features=1):
    super().__init__()
    self.params = nn.ParameterList(
      [nn.Parameter(torch.randn(in_features, in_features)) for _ in range(num_hidden)]
    )
    # ParameterList.append behaves like list.append and registers the new Parameter too.
    self.params.append(nn.Parameter(torch.randn(in_features, out_features)))

  def forward(self, x):
    """Return ``x @ W0 @ W1 @ ... @ Wn`` using the weights in insertion order.

    ``torch.mm`` requires a 2-D ``x``; its second dimension must equal
    ``in_features``. The result has shape ``(x.shape[0], out_features)``.
    """
    for weight in self.params:
      x = torch.mm(x, weight)
    return x


In [9]:
# Display the module via the bare last expression (Jupyter renders its repr):
# the three 4x4 weights and the appended 4x1 weight are all listed, i.e.
# ParameterList registered every Parameter it holds.
net = MyDense()
net

MyDense(
  (params): ParameterList(
      (0): Parameter containing: [torch.float32 of size 4x4]
      (1): Parameter containing: [torch.float32 of size 4x4]
      (2): Parameter containing: [torch.float32 of size 4x4]
      (3): Parameter containing: [torch.float32 of size 4x1]
  )
)


In [10]:
# 而ParameterDict接收一个Parameter实例的字典作为输入然后得到一个参数字典，可以按照字典的规则使用
class MyDictDense(nn.Module):
  """Bias-free linear layer that picks one of several named weights at call time.

  The weights are held in an ``nn.ParameterDict``, so each one is registered
  as a parameter of the module and can be looked up by key like a dict.
  """

  def __init__(self):
    super().__init__()
    weights = {
      "linear1": nn.Parameter(torch.randn(4, 4)),
      "linear2": nn.Parameter(torch.randn(4, 2)),
    }
    self.params = nn.ParameterDict(weights)
    # ParameterDict supports dict-style insertion after construction.
    self.params["linear3"] = nn.Parameter(torch.randn(4, 1))

  def forward(self, x, choice="linear1"):
    """Multiply ``x`` by the weight stored under ``choice``.

    ``x`` must be 2-D with 4 columns (``torch.mm`` semantics). The output width
    is 4, 2 or 1 for "linear1", "linear2" or "linear3" respectively; ``choice``
    must be a key present in ``self.params``.
    """
    selected = self.params[choice]
    return x.mm(selected)



In [11]:
# Display the module via the bare last expression (Jupyter renders its repr):
# the ParameterDict lists each of its three named weights by key.
net = MyDictDense()
net

MyDictDense(
  (params): ParameterDict(
      (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
      (linear2): Parameter containing: [torch.FloatTensor of size 4x2]
      (linear3): Parameter containing: [torch.FloatTensor of size 4x1]
  )
)


In [12]:
x = torch.ones(1, 4)
# One input, three selectable weights -> output widths 4, 2 and 1.
for choice in ("linear1", "linear2", "linear3"):
  print(net(x, choice))

tensor([[-0.1947,  2.7465, -0.8175, -2.5497]], grad_fn=<MmBackward0>)
tensor([[ 1.1543, -0.7281]], grad_fn=<MmBackward0>)
tensor([[1.8090]], grad_fn=<MmBackward0>)


In [14]:
# MyDictDense (default choice "linear1") maps (1, 4) -> (1, 4), then MyDense
# maps (1, 4) -> (1, 1). `x` is the (1, 4) ones tensor from the previous cell.
net = nn.Sequential(MyDictDense(), MyDense())
print(net)
print(net(x))


Sequential(
  (0): MyDictDense(
    (params): ParameterDict(
        (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
        (linear2): Parameter containing: [torch.FloatTensor of size 4x2]
        (linear3): Parameter containing: [torch.FloatTensor of size 4x1]
    )
  )
  (1): MyDense(
    (params): ParameterList(
        (0): Parameter containing: [torch.float32 of size 4x4]
        (1): Parameter containing: [torch.float32 of size 4x4]
        (2): Parameter containing: [torch.float32 of size 4x4]
        (3): Parameter containing: [torch.float32 of size 4x1]
    )
  )
)
tensor([[3.1968]], grad_fn=<MmBackward0>)
