In [2]:
import time
import torch
from torch import nn, optim
import torch.nn.functional as F

import sys
sys.path.append('..')
import d2lzh_pytorch as d2l

# Run on the GPU when PyTorch reports CUDA as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [39]:
def batch_norm(is_training, X, gamma, beta, moving_mean, moving_var, eps, momentum):
    """Batch-normalize ``X`` and return the output with updated running statistics.

    Args:
        is_training: When true, normalize with the statistics of the current
            batch and update the running statistics; when false, normalize
            with ``moving_mean`` / ``moving_var`` as given.
        X: Input tensor. In training mode it must be 2-D (statistics are taken
            over axis 0) or 4-D (statistics are taken over axes 0, 2 and 3,
            i.e. one value per index of axis 1).
        gamma: Scale applied to the normalized input (must broadcast with ``X``).
        beta: Shift added after scaling (must broadcast with ``X``).
        moving_mean: Running mean used for inference.
        moving_var: Running variance used for inference.
        eps: Small constant added to the variance before the square root.
        momentum: Weight kept from the previous running statistics; the
            current batch statistics get weight ``1 - momentum``.

    Returns:
        Tuple ``(Y, moving_mean, moving_var)``. In inference mode the running
        statistics are returned unchanged. In training mode they are new
        tensors that are *not* part of the autograd graph.

    Raises:
        AssertionError: In training mode, if ``X`` is neither 2-D nor 4-D.
    """
    if not is_training:
        # Inference: use the running statistics supplied by the caller.
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4), \
            'batch_norm expects a 2-D or 4-D input, got %d-D' % len(X.shape)
        # Reduce over every axis except axis 1. keepdim=True keeps the result
        # broadcastable against X in both the 2-D and the 4-D case.
        reduce_dims = (0,) if len(X.shape) == 2 else (0, 2, 3)
        mean = X.mean(dim=reduce_dims, keepdim=True)
        var = ((X - mean) ** 2).mean(dim=reduce_dims, keepdim=True)
        # Training: normalize with the statistics of the current batch.
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # The running statistics are bookkeeping only. Updating them outside
        # autograd prevents every training step from chaining onto the
        # previous step's graph (which would grow memory without bound).
        with torch.no_grad():
            moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
            moving_var = momentum * moving_var + (1.0 - momentum) * var

    Y = gamma * X_hat + beta  # scale and shift

    return Y, moving_mean, moving_var
            

In [40]:
class BatchNorm(nn.Module):
    """Batch-normalization layer built on top of the ``batch_norm`` function.

    Args:
        num_features: Size of axis 1 of the inputs this layer will receive.
        num_dims: ``2`` creates parameters of shape ``(1, num_features)``;
            any other value creates shape ``(1, num_features, 1, 1)``.

    Attributes:
        gamma: Learnable scale, initialized to 1.
        beta: Learnable shift, initialized to 0.
        moving_mean: Running mean (buffer, not trained), initialized to 0.
        moving_var: Running variance (buffer, not trained), initialized to 1
            so that inference before any training step does not divide by
            ``sqrt(eps)``.
    """

    def __init__(self, num_features, num_dims):
        super(BatchNorm, self).__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # Scale and shift take part in gradient computation and optimization;
        # they are initialized to 1 and 0 respectively.
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # Running statistics are registered as buffers: they receive no
        # gradient, but they are saved in state_dict() and follow the module
        # through .to(device) / .cuda() / .cpu().
        self.register_buffer('moving_mean', torch.zeros(shape))
        self.register_buffer('moving_var', torch.ones(shape))

    def forward(self, X):
        """Normalize ``X``; in training mode also refresh the running statistics."""
        Y, new_mean, new_var = batch_norm(self.training, X, self.gamma, self.beta,
                                          self.moving_mean, self.moving_var,
                                          eps=1e-5, momentum=0.9)
        # Store the statistics detached so the layer never holds on to the
        # autograd graph of a previous forward pass.
        self.moving_mean = new_mean.detach()
        self.moving_var = new_var.detach()
        return Y

#### 使用批量归一化的LeNet

In [41]:
# LeNet variant using the hand-written BatchNorm layer: every convolution and
# every hidden linear layer is followed by BatchNorm and then a sigmoid.
net = nn.Sequential(
    # --- convolutional stage 1 ---
    nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),
    BatchNorm(num_features=6, num_dims=4),
    nn.Sigmoid(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # --- convolutional stage 2 ---
    nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
    BatchNorm(num_features=16, num_dims=4),
    nn.Sigmoid(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # --- classifier head ---
    d2l.FlattenLayer(),
    # NOTE(review): 16 * 4 * 4 presumably matches 16 channels of 4x4 feature
    # maps for 28x28 inputs -- confirm against the dataset loader.
    nn.Linear(in_features=16 * 4 * 4, out_features=120),
    BatchNorm(num_features=120, num_dims=2),
    nn.Sigmoid(),
    nn.Linear(in_features=120, out_features=84),
    BatchNorm(num_features=84, num_dims=2),
    nn.Sigmoid(),
    nn.Linear(in_features=84, out_features=10),
)

In [42]:
# Dataset location and mini-batch size passed to the d2l Fashion-MNIST loader.
data_dir = './Datasets/FashionMNIST'
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, root=data_dir)

In [43]:
# Hyperparameters for this training run.
lr = 0.001
num_epochs = 5
# Adam over all parameters of the current `net`.
optimizer = optim.Adam(net.parameters(), lr=lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

training on  cuda
epoch 1, loss 0.9671, train_acc 0.795, test acc 0.831, 12.0 sec
epoch 2, loss 0.2255, train_acc 0.864, test acc 0.800, 11.5 sec
epoch 3, loss 0.1222, train_acc 0.878, test acc 0.866, 11.5 sec
epoch 4, loss 0.0823, train_acc 0.887, test acc 0.843, 11.6 sec
epoch 5, loss 0.0622, train_acc 0.891, test acc 0.861, 11.6 sec


#### 使用PyTorch模块定义的批量归一化类

In [44]:
# Same LeNet layout, but using PyTorch's built-in batch-normalization layers
# (BatchNorm2d after convolutions, BatchNorm1d after hidden linear layers).
net = nn.Sequential(
    # --- convolutional stage 1 ---
    nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),
    nn.BatchNorm2d(num_features=6),
    nn.Sigmoid(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # --- convolutional stage 2 ---
    nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
    nn.BatchNorm2d(num_features=16),
    nn.Sigmoid(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # --- classifier head ---
    d2l.FlattenLayer(),
    nn.Linear(in_features=16 * 4 * 4, out_features=120),
    nn.BatchNorm1d(num_features=120),
    nn.Sigmoid(),
    nn.Linear(in_features=120, out_features=84),
    nn.BatchNorm1d(num_features=84),
    nn.Sigmoid(),
    nn.Linear(in_features=84, out_features=10),
)

In [45]:
# Same hyperparameters as the previous run, now applied to the rebuilt `net`.
lr = 0.001
num_epochs = 5
# A fresh optimizer is required because `net` now refers to new parameters.
optimizer = optim.Adam(net.parameters(), lr=lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

training on  cuda
epoch 1, loss 1.3945, train_acc 0.765, test acc 0.770, 7.9 sec
epoch 2, loss 0.3079, train_acc 0.854, test acc 0.851, 7.5 sec
epoch 3, loss 0.1398, train_acc 0.873, test acc 0.842, 7.5 sec
epoch 4, loss 0.0892, train_acc 0.882, test acc 0.817, 7.5 sec
epoch 5, loss 0.0656, train_acc 0.888, test acc 0.859, 7.5 sec
