1.不含模型参数的自定义层

In [1]:
#基础实现
import torch
from torch import nn

class CenteredLayer(nn.Module):
    """Parameter-free layer that shifts its input to zero mean.

    The mean is taken over every element of the input tensor at once
    (a single scalar, not per sample or per feature), so the output's
    overall mean is ~0.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x):
        """Return ``x`` minus the scalar mean of all its elements."""
        global_mean = x.mean()
        return x - global_mean
## Instantiate the layer and apply it directly
layer = CenteredLayer()
sample = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float)
result = layer(sample)
print(result)

# Embed the custom layer in a Sequential container
blocks = [
    nn.Linear(8, 128),   # layer with learnable weight/bias
    CenteredLayer(),     # parameter-free custom layer
]
net = nn.Sequential(*blocks)

# Forward pass: input is 4 samples with 8 features each
batch = torch.rand(4, 8)
y = net(batch)
print(f"输出均值: {y.mean().item()}")  # should be close to 0 because the mean was subtracted
# Output: 输出均值: -x.xxxxe-xx (a tiny value near 0)

tensor([-2., -1.,  0.,  1.,  2.])
输出均值: 2.1827872842550278e-09


2.含模型参数的自定义层

In [2]:
import torch
from torch import nn

class MyListDense(nn.Module):
    """Chain of plain matrix multiplications whose weights live in an nn.ParameterList.

    Holds three 4x4 weights followed by one 4x1 weight, so an ``(N, 4)``
    input is mapped to ``(N, 1)``. There are no biases or activations.
    """

    def __init__(self):
        super().__init__()
        # ParameterList registers each weight so it appears in parameters()/state_dict()
        self.params = nn.ParameterList(
            [nn.Parameter(torch.randn(4, 4)) for _ in range(3)]
        )
        self.params.append(nn.Parameter(torch.randn(4, 1)))

    def forward(self, x):
        """Multiply ``x`` by every stored weight in order and return the product."""
        for weight in self.params:
            x = torch.mm(x, weight)  # matrix multiplication
        return x

# Instantiate the list-based layer and display its parameter structure
net = MyListDense()
print(repr(net))

class MyDictDense(nn.Module):
    """Single matrix multiplication with a weight picked by name from an nn.ParameterDict.

    Stores ``linear1`` (4x4), ``linear2`` (4x1) and ``linear3`` (4x2); the
    ``choice`` argument of ``forward`` selects which one multiplies the input.
    """

    def __init__(self):
        super().__init__()
        # ParameterDict registers each named weight as a module parameter
        self.params = nn.ParameterDict({
            'linear1': nn.Parameter(torch.randn(4, 4)),
            'linear2': nn.Parameter(torch.randn(4, 1)),
        })
        # Entries can also be added after construction, like a regular dict
        self.params['linear3'] = nn.Parameter(torch.randn(4, 2))

    def forward(self, x, choice='linear1'):
        """Return ``x @ params[choice]``; an unknown ``choice`` raises KeyError."""
        selected = self.params[choice]
        return torch.mm(x, selected)

# Instantiate the dict-based layer and inspect its named parameters
net = MyDictDense()
print(repr(net))

# Probe each named weight with the same all-ones input of shape (1, 4)
x = torch.ones(1, 4)
for label, key in (("linear1 输出:", 'linear1'),
                   ("linear2 输出:", 'linear2'),
                   ("linear3 输出:", 'linear3')):
    print(label, net(x, key))

# Chain both custom layers: MyDictDense (default choice 'linear1', 4x4 weight)
# feeds MyListDense, whose last weight is 4x1
layers = (MyDictDense(), MyListDense())
net = nn.Sequential(*layers)
print(repr(net))
print("Sequential 输出:", net(x))

MyListDense(
  (params): ParameterList(
      (0): Parameter containing: [torch.float32 of size 4x4]
      (1): Parameter containing: [torch.float32 of size 4x4]
      (2): Parameter containing: [torch.float32 of size 4x4]
      (3): Parameter containing: [torch.float32 of size 4x1]
  )
)
MyDictDense(
  (params): ParameterDict(
      (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
      (linear2): Parameter containing: [torch.FloatTensor of size 4x1]
      (linear3): Parameter containing: [torch.FloatTensor of size 4x2]
  )
)
linear1 输出: tensor([[-1.9645, -1.6549, -2.6726, -3.7488]], grad_fn=<MmBackward0>)
linear2 输出: tensor([[1.6141]], grad_fn=<MmBackward0>)
linear3 输出: tensor([[1.5407, 2.6891]], grad_fn=<MmBackward0>)
Sequential(
  (0): MyDictDense(
    (params): ParameterDict(
        (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
        (linear2): Parameter containing: [torch.FloatTensor of size 4x1]
        (linear3): Parameter containing: [torc