# 动手学深度学习 章节7.5 批量规范化 (batch normalization) 代码实现

In [4]:
import torch
from torch import nn

定义batch normalization层

In [5]:
class BatchNorm(nn.Module):
	"""Batch normalization layer written from scratch.

	Supports 2-D inputs (outputs of fully connected layers) and 4-D inputs
	(outputs of convolutional layers). Batch statistics are used whenever
	autograd is enabled; the running (moving-average) statistics are used
	inside ``torch.no_grad()``.
	"""

	def __init__(self, num_features, input_dim, **kwargs) -> None:
		"""
		Args:
			num_features: Number of features (2-D input) or channels (4-D input).
			input_dim: Rank of the expected input; 2 selects the fully
				connected layout, any other value selects the 4-D layout.
			**kwargs: Forwarded to ``nn.Module.__init__``.
		"""
		super().__init__(**kwargs)

		if input_dim == 2:
			shape = (1, num_features)
		else:
			shape = (1, num_features, 1, 1)

		# Running statistics are registered as buffers so that they are saved
		# in ``state_dict()`` and follow the module through ``.to()`` and
		# dtype casts. The variance starts at one so that inference before
		# any training step is close to the identity instead of dividing by
		# sqrt(eps).
		self.register_buffer("moving_mean", torch.zeros(shape))
		self.register_buffer("moving_var", torch.ones(shape))

		# Learnable per-feature scale and shift.
		self.gamma = nn.Parameter(torch.ones(shape))
		self.beta = nn.Parameter(torch.zeros(shape))

	def forward(self, X):
		"""Normalize ``X`` and refresh the running statistics.

		Returns:
			A tensor with the same shape as ``X``.
		"""
		# Assigning to a registered buffer name keeps it registered.
		Y, self.moving_mean, self.moving_var = self.batch_norm(
			X, self.gamma, self.beta, self.moving_mean, self.moving_var,
			eps=1e-5, momentum=0.9
		)

		return Y

	@staticmethod
	def batch_norm(X: torch.Tensor, gamma, beta, moving_mean, moving_var, eps, momentum):
		"""Apply batch normalization to ``X``.

		Args:
			X: Input tensor, either 2-D or 4-D with features/channels on dim 1.
			gamma: Learnable scale, broadcastable against ``X``.
			beta: Learnable shift, broadcastable against ``X``.
			moving_mean: Running estimate of the mean.
			moving_var: Running estimate of the variance.
			eps: Small constant added to the variance for numerical stability.
			momentum: Weight kept on the previous running value in the
				exponential moving average (the new batch statistic gets
				``1 - momentum``).

		Returns:
			A tuple ``(Y, moving_mean, moving_var)``: the normalized, scaled
			and shifted tensor, followed by the updated running statistics.

		Raises:
			ValueError: If autograd is enabled and ``X`` is neither 2-D nor 4-D.
		"""
		if not torch.is_grad_enabled():
			# Inference: normalize with the running statistics.
			X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
		else:
			if X.dim() not in (2, 4):
				raise ValueError(
					f"BatchNorm expects a 2D or 4D input, got {X.dim()}D"
				)
			# Reduce over the batch axis (2-D), or over the batch and spatial
			# axes (4-D), keeping dims so the result broadcasts against X.
			reduce_dims = (0,) if X.dim() == 2 else (0, 2, 3)
			mean = X.mean(dim=reduce_dims, keepdim=True)
			var = ((X - mean) ** 2).mean(dim=reduce_dims, keepdim=True)

			X_hat = (X - mean) / torch.sqrt(var + eps)

			# Update the running statistics with an exponential moving
			# average. The batch statistics are detached so the stored values
			# do not keep the autograd graph of every past step alive.
			moving_mean = momentum * moving_mean + (1.0 - momentum) * mean.detach()
			moving_var = momentum * moving_var + (1.0 - momentum) * var.detach()

		# Scale and shift (re-assign a learnable mean and variance).
		Y = gamma * X_hat + beta

		return Y, moving_mean, moving_var