In [None]:
import torch
from torch import nn
from d2l import torch as d2l

In [None]:
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    """Apply batch normalization to ``X`` and update the running statistics.

    Training vs. inference is decided by ``torch.is_grad_enabled()``:
    with gradients enabled the current batch statistics are used and the
    running statistics are updated; under ``torch.no_grad()`` the running
    statistics are used unchanged. Note that ``module.eval()`` alone does
    NOT switch this function to inference mode.

    Args:
        X: input of shape ``(N, C)`` (fully connected) or ``(N, C, H, W)``
            (convolutional).
        gamma, beta: learnable scale and shift, broadcastable to ``X``.
        moving_mean, moving_var: running mean / variance, broadcastable to ``X``.
        eps: small constant added to the variance for numerical stability.
        momentum: weight kept on the previous running statistic when it is
            updated; the current batch statistic contributes ``1 - momentum``.

    Returns:
        ``(Y, moving_mean, moving_var)``: the normalized, scaled and shifted
        output, plus the (possibly updated) running statistics detached from
        the autograd graph.
    """
    if not torch.is_grad_enabled():
        # Inference: normalize with the accumulated running statistics.
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # Fully connected layer: one mean/variance per feature.
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # Convolutional layer: one mean/variance per channel (dim 1).
            # keepdim=True yields shape (1, C, 1, 1) so the statistics
            # broadcast along the channel axis rather than the width axis.
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # Training: normalize with the current batch statistics.
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # Exponential moving average: `momentum` is the weight on history.
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    # Scale and shift.
    Y = gamma * X_hat + beta
    return Y, moving_mean.detach(), moving_var.detach()

In [None]:
class BatchNorm(nn.Module):
    """Batch normalization layer built on the from-scratch ``batch_norm``.

    Args:
        num_features: number of outputs of a fully connected layer, or the
            number of output channels of a convolutional layer.
        num_dims: 2 for a fully connected layer (parameters shaped
            ``(1, num_features)``); any other value (4 for a convolutional
            layer) gives parameters shaped ``(1, num_features, 1, 1)``.
    """

    def __init__(self, num_features, num_dims):
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # Learnable scale and shift, initialized to the identity transform.
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # Running statistics for inference, starting at mean 0 / variance 1.
        # Registered as buffers so they follow net.to(device) and are saved
        # in state_dict(), without being treated as trainable parameters.
        self.register_buffer('moving_mean', torch.zeros(shape))
        self.register_buffer('moving_var', torch.ones(shape))

    def forward(self, X):
        # Keep the running statistics on the same device as the input.
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        # Store the updated running statistics. gamma/beta are Parameters
        # updated by the optimizer and must not be overwritten here.
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean, self.moving_var,
            eps=1e-5, momentum=0.9)
        return Y

In [None]:
# LeNet-style network with the from-scratch BatchNorm inserted after every
# convolutional and fully connected layer, before its sigmoid activation.
# The flatten size 16 * 4 * 4 implies 28x28 inputs for square images:
# 28 -conv5-> 24 -pool2-> 12 -conv5-> 8 -pool2-> 4.
net = nn.Sequential(
    # Convolutional block 1: 1 -> 6 channels
    nn.Conv2d(1, 6, kernel_size=5),
    BatchNorm(6, num_dims=4),
    nn.Sigmoid(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # Convolutional block 2: 6 -> 16 channels
    nn.Conv2d(6, 16, kernel_size=5),
    BatchNorm(16, num_dims=4),
    nn.Sigmoid(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # Fully connected classifier head: 256 -> 120 -> 84 -> 10
    nn.Flatten(),
    nn.Linear(16 * 4 * 4, 120),
    BatchNorm(120, num_dims=2),
    nn.Sigmoid(),
    nn.Linear(120, 84),
    BatchNorm(84, num_dims=2),
    nn.Sigmoid(),
    nn.Linear(84, 10),
)