In [1]:
import torch
from torch import nn
from d2l import torch as d2l

def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    """Batch-normalize ``X`` and return updated running statistics.

    The mode is chosen from autograd state, not from ``Module.training``:

    * autograd disabled (``torch.is_grad_enabled()`` is False): treated as
      inference; ``X`` is normalized with the running (global) statistics,
      which are returned unchanged.
    * autograd enabled: treated as training; ``X`` is normalized with the
      current mini-batch statistics, which are also blended into the running
      statistics with an exponential moving average.

    :param X: input tensor. In training mode it must be 2-D
        ``(batch, features)`` (fully connected layer) or 4-D
        ``(batch, channels, height, width)`` (convolutional layer).
    :param gamma: learnable scale, broadcastable against ``X``.
    :param beta: learnable shift, broadcastable against ``X``.
    :param moving_mean: running (global) mean, used for inference.
    :param moving_var: running (global) variance, used for inference.
    :param eps: small constant added to the variance for numerical stability.
    :param momentum: weight given to the old running statistics in the
        update (commonly 0.9).
    :return: tuple ``(Y, moving_mean, moving_var)``; ``Y = gamma * X_hat + beta``
        and the running statistics are returned via ``.data`` (detached).
    :raises AssertionError: in training mode if ``X`` is neither 2-D nor 4-D.
    """
    if not torch.is_grad_enabled():
        # Gradients are off, so assume inference: use the running statistics.
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        # Training: use the statistics of the current mini-batch.
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # Fully connected: per-feature mean/variance over all samples
            # (i.e. over each column, dim 0).
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # Convolutional: per-channel mean/variance over batch (0),
            # height (2) and width (3). keepdim=True yields shape
            # (1, C, 1, 1) so the result broadcasts against X.
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # Exponential moving average of the batch statistics; these running
        # values are what the inference branch above normalizes with.
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta  # learnable scale and shift
    return Y, moving_mean.data, moving_var.data

In [2]:
class BatchNorm(nn.Module):
    """Batch-normalization layer that delegates the math to ``batch_norm``.

    Holds a learnable scale ``gamma`` and shift ``beta`` plus running
    mean/variance that ``batch_norm`` uses when autograd is disabled.

    :param num_features: number of features (fully connected) or channels
        (convolutional).
    :param num_dims: 2 for fully connected inputs ``(batch, features)``;
        any other value selects the convolutional layout
        ``(batch, channels, height, width)``.
    :param eps: numerical-stability constant passed to ``batch_norm``.
    :param momentum: running-statistics momentum passed to ``batch_norm``.
    """

    def __init__(self, num_features, num_dims, eps=1e-5, momentum=0.9):
        super().__init__()
        # Shapes chosen so the statistics broadcast per feature / per channel.
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        self.eps = eps
        self.momentum = momentum
        # Learnable scale (init 1) and shift (init 0): the layer starts as a
        # pure normalization.
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # Running statistics are state, not trainable parameters. Registering
        # them as buffers saves them in state_dict() and moves them with
        # Module.to(), so a reloaded model keeps its inference statistics.
        self.register_buffer('moving_mean', torch.zeros(shape))
        self.register_buffer('moving_var', torch.ones(shape))

    def forward(self, X):
        # Defensive: keep the statistics on X's device even if the module
        # itself was not moved.
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)

        # batch_norm returns the updated statistics; assigning them back
        # replaces the registered buffers in place.
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean, self.moving_var,
            eps=self.eps, momentum=self.momentum
        )
        return Y

In [3]:
# Apply batch normalization to LeNet: a BatchNorm layer follows every conv /
# linear layer (before its activation), except the final output layer.
# Flattened size assumes 28x28 inputs (Fashion-MNIST):
# 28 -conv5-> 24 -pool2-> 12 -conv5-> 8 -pool2-> 4, hence 16 * 4 * 4.
net = nn.Sequential(
    # Convolutional stage 1
    nn.Conv2d(1, 6, kernel_size=5),
    BatchNorm(6, num_dims=4),
    nn.Sigmoid(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # Convolutional stage 2
    nn.Conv2d(6, 16, kernel_size=5),
    BatchNorm(16, num_dims=4),
    nn.Sigmoid(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # Fully connected classifier head
    nn.Flatten(),
    nn.Linear(16 * 4 * 4, 120),
    BatchNorm(120, num_dims=2),
    nn.Sigmoid(),
    nn.Linear(120, 84),
    BatchNorm(84, num_dims=2),
    nn.Sigmoid(),
    nn.Linear(84, 10),
)

In [None]:
# Training hyperparameters. The large learning rate is workable because
# batch normalization keeps the layer inputs well scaled.
lr = 1.0
num_epochs = 10
batch_size = 256

train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

In [None]:
# Learned scale (gamma) and shift (beta) of the first BatchNorm layer, flattened to 1-D.
(net[1].gamma.reshape(-1), net[1].beta.reshape(-1))