In [None]:
import torch
import torch.nn as nn


class MyBatchNorm1d(nn.Module):
    """Hand-written batch normalisation that reduces over dimension 0.

    In training mode the input is normalised with the statistics of the
    current batch, and exponential moving averages of those statistics are
    stored in the ``running_mean`` / ``running_var`` buffers.  In eval mode
    the stored running statistics are used instead.

    Args:
        num_features: Length of the learnable ``weight`` / ``bias`` vectors
            and of the running-statistics buffers.
        eps: Constant added to the variance before the square root to avoid
            division by zero.
        momentum: Weight given to the current batch statistics when updating
            the running averages (the previous value is weighted by
            ``1 - momentum``).
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # Learnable affine transform applied after normalisation.
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        # Buffers are saved in the state dict and follow .to()/.cuda(),
        # but are not returned by parameters() and so are not optimised.
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        """Normalise ``x`` and apply the learnable scale and shift.

        Args:
            x: Input tensor; statistics are computed over dimension 0.

        Returns:
            ``weight * x_hat + bias`` where ``x_hat`` is the normalised input.
        """
        if self.training:
            # Statistics of the current batch (biased variance estimate).
            mean = x.mean(dim=0)
            var = x.var(dim=0, unbiased=False)
            # Update the running averages from *detached* statistics.
            # Without detach() the buffers would carry a grad_fn, keeping
            # every previous batch's autograd graph alive (memory growth)
            # and leaking non-leaf tensors into state_dict().
            keep = 1 - self.momentum
            self.running_mean = (
                keep * self.running_mean + self.momentum * mean.detach()
            )
            self.running_var = (
                keep * self.running_var + self.momentum * var.detach()
            )
        else:
            # Inference: rely on the statistics accumulated during training.
            mean = self.running_mean
            var = self.running_var
        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight * x_hat + self.bias

# Build a small MLP that uses the custom batch-normalisation layer:
# Linear(20 -> 50) -> MyBatchNorm1d(50) -> ReLU -> Linear(50 -> 2).
model = nn.Sequential(
    nn.Linear(20, 50),
    MyBatchNorm1d(50),
    nn.ReLU(),
    nn.Linear(50, 2),
)

# Show the architecture first, then every entry of the state dict
# (parameters and registered buffers), one after the other.
print(model, model.state_dict(), sep="\n")


Sequential(
  (0): Linear(in_features=20, out_features=50, bias=True)
  (1): MyBatchNorm1d()
  (2): ReLU()
  (3): Linear(in_features=50, out_features=2, bias=True)
)
OrderedDict({'0.weight': tensor([[ 1.0231e-01,  8.3599e-02,  1.1854e-01,  1.4539e-01, -1.2421e-01,
          1.5562e-01,  1.6308e-01,  2.0126e-01, -4.9486e-02, -1.4712e-01,
          1.4997e-01, -2.2033e-02, -6.8833e-02, -1.3629e-01,  3.6222e-02,
          7.9490e-02, -7.1304e-02,  1.9562e-01,  1.8270e-01,  1.9370e-01],
        [ 1.6584e-02,  1.7515e-01,  1.0991e-01, -6.7647e-02, -6.3368e-02,
         -2.1535e-01,  3.6940e-04,  1.9492e-01, -3.6339e-02, -6.1725e-02,
          2.1460e-01, -1.7542e-01, -7.4580e-02, -1.4578e-01, -2.0213e-01,
         -1.7447e-01, -1.7259e-01, -2.9973e-02, -2.0611e-01, -1.5593e-01],
        [-4.1025e-02, -2.0821e-01, -4.1230e-02,  2.0132e-01, -1.3415e-02,
          3.9165e-02, -5.1288e-02, -1.5958e-01,  4.5779e-02, -1.8148e-01,
         -3.4378e-02, -1.0283e-02,  5.4560e-02,  1.9194e-02,  1.244

AttributeError: 'Sequential' object has no attribute 'running_mean'