In [2]:
import torch
import torch.nn as nn

In [34]:
# Seed the RNG so this tensor -- and every output derived from it below
# (BatchNorm2d(x), the manual per-channel mean) -- is reproducible on re-run.
torch.manual_seed(0)
# Dims are (N, C, H, W); BatchNorm2d(2) below normalizes over the 2 channels (dim 1).
x = torch.randn(3, 2, 5, 5)
x

tensor([[[[ 0.1051, -0.6054,  2.2077,  0.5354, -0.0786],
          [ 1.3807, -1.1156,  0.8601, -1.0424, -0.8200],
          [-3.1862, -1.2547,  0.0289,  0.0565,  0.2643],
          [-0.5624, -0.6070, -0.1709, -1.2398, -0.1952],
          [ 1.8700,  0.4523,  0.4213,  0.3290,  1.1165]],

         [[-1.8195,  2.0410,  0.9786,  0.2387,  0.4347],
          [-0.2031,  1.6943,  2.3971,  1.0378, -1.3072],
          [-0.4515,  1.6716, -2.5988, -1.0309,  1.1623],
          [ 0.4303, -0.4173, -0.3568, -0.4327, -1.5390],
          [ 0.8908, -0.4774,  0.5983, -0.4060,  0.4000]]],


        [[[ 1.5962, -0.8092, -0.6239,  1.1259,  0.2702],
          [ 1.8706,  0.4997,  0.0581, -0.5668,  0.7331],
          [-0.3854,  2.4003, -0.2957, -1.2770,  0.5103],
          [-0.2095, -0.5492,  1.1720, -1.4889,  2.2958],
          [-0.3748,  0.8563,  0.8263,  0.5494,  0.0653]],

         [[ 0.0499,  0.9159, -0.3219, -0.1797,  0.4378],
          [-0.5280,  0.7332,  0.3471, -0.9778, -1.8930],
          [ 0.9208,  0.

In [36]:
# Reference result from PyTorch's built-in layer: one normalization per
# channel (dim 1 of x), for comparison with the hand-written version below.
BN = nn.BatchNorm2d(num_features=2)
BN(x)

tensor([[[[ 4.5405e-02, -6.5842e-01,  2.1280e+00,  4.7161e-01, -1.3659e-01],
          [ 1.3089e+00, -1.1638e+00,  7.9322e-01, -1.0912e+00, -8.7098e-01],
          [-3.2148e+00, -1.3016e+00, -3.0113e-02, -2.7728e-03,  2.0310e-01],
          [-6.1581e-01, -6.6001e-01, -2.2805e-01, -1.2868e+00, -2.5212e-01],
          [ 1.7935e+00,  3.8927e-01,  3.5856e-01,  2.6712e-01,  1.0472e+00]],

         [[-1.8457e+00,  2.0053e+00,  9.4556e-01,  2.0742e-01,  4.0299e-01],
          [-2.3329e-01,  1.6595e+00,  2.3605e+00,  1.0046e+00, -1.3346e+00],
          [-4.8102e-01,  1.6369e+00, -2.6231e+00, -1.0590e+00,  1.1288e+00],
          [ 3.9858e-01, -4.4690e-01, -3.8660e-01, -4.6225e-01, -1.5659e+00],
          [ 8.5797e-01, -5.0693e-01,  5.6620e-01, -4.3564e-01,  3.6835e-01]]],


        [[[ 1.5224e+00, -8.6029e-01, -6.7676e-01,  1.0565e+00,  2.0892e-01],
          [ 1.7942e+00,  4.3625e-01, -1.2221e-03, -6.2016e-01,  6.6742e-01],
          [-4.4046e-01,  2.3188e+00, -3.5164e-01, -1.3236e+00,  4.4673

In [37]:

# Per-channel mean: reduce over every dim except the channel dim (1).
# keepdim=True leaves shape (1, C, 1, 1) so it broadcasts back against x.
mean = torch.mean(x, dim=(0, 2, 3), keepdim=True)
mean

tensor([[[[0.0593]],

         [[0.0307]]]])

In [38]:
import torch
from torch import nn
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum, training=None):
    """Functional batch normalization with running-statistic bookkeeping.

    Args:
        X: input tensor. In training mode it must be rank 2 (fully connected:
            batch x features) or rank 4 (convolutional, channels on dim 1).
        gamma, beta: learnable scale and shift, broadcastable against X.
        moving_mean, moving_var: running estimates used in prediction mode.
        eps: small constant added to the variance for numerical stability.
        momentum: weight given to the OLD running value when updating,
            i.e. ``new = momentum * old + (1 - momentum) * batch_stat``.
            NOTE: this is the opposite convention from the ``momentum``
            argument of ``torch.nn.BatchNorm*``.
        training: True -> normalize with batch statistics and update the
            running estimates; False -> normalize with the running estimates.
            None (default) keeps the original behaviour of inferring the mode
            from ``torch.is_grad_enabled()``.

    Returns:
        ``(Y, moving_mean, moving_var)``; the running estimates are detached
        from the autograd graph.
    """
    if training is None:
        # Original heuristic: autograd on => training, autograd off => prediction.
        training = torch.is_grad_enabled()
    if not training:
        # Prediction mode: use the running statistics that were passed in.
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert X.dim() in (2, 4), (
            f"batch_norm expects a 2-D or 4-D input in training mode, got {X.dim()}-D")
        if X.dim() == 2:
            # Fully connected case: statistics per feature, over the batch dim.
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # Convolutional case: statistics per channel (dim 1), over batch
            # and spatial dims; keepdim so the result broadcasts against X.
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # Training mode: normalize with the current batch's statistics.
        # NOTE: var is the biased (population) variance; it is also what feeds
        # the running variance below.
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # Exponential moving averages of the batch statistics.
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta  # scale and shift
    # detach(): the running statistics are state, not part of the graph.
    return Y, moving_mean.detach(), moving_var.detach()


In [None]:
class BatchNorm(nn.Module):
    """Batch normalization layer that delegates training-mode work to ``batch_norm``.

    Args:
        num_features: number of outputs of a fully connected layer, or number
            of output channels of a convolutional layer.
        num_dims: 2 for a fully connected input; any other value (normally 4)
            uses the convolutional parameter shape ``(1, C, 1, 1)``.
        eps: constant added to the variance for numerical stability.
        momentum: weight of the OLD running value in the running-statistic
            update (``new = momentum * old + (1 - momentum) * batch_stat``).
    """

    def __init__(self, num_features, num_dims, eps=1e-5, momentum=0.9):
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        self.eps = eps
        self.momentum = momentum
        # Learnable scale and shift, initialized to 1 and 0.
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # Running statistics are not learnable but are model state. Registering
        # them as buffers makes .to()/.cuda() move them and state_dict() save them.
        self.register_buffer('moving_mean', torch.zeros(shape))
        self.register_buffer('moving_var', torch.ones(shape))

    def forward(self, X):
        # Safety net if the module was not moved with .to(): keep the running
        # statistics on the same device as the input.
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        if not self.training:
            # eval(): normalize with the running statistics and leave them
            # untouched, whether or not autograd is enabled.
            X_hat = (X - self.moving_mean) / torch.sqrt(self.moving_var + self.eps)
            return self.gamma * X_hat + self.beta
        # train(): batch_norm normalizes (with batch statistics when autograd is
        # on) and returns the updated running statistics, which are stored back.
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean,
            self.moving_var, eps=self.eps, momentum=self.momentum)
        return Y
