# Normalization Layers

Normalization layers help stabilize training and speed up model convergence.

## Layer Normalization (LN)

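- Core idea: normalize within a single sample. For each sample, compute the mean and variance over all of its feature dimensions, then use them to normalize that sample.
- Formula: $\text{LN}(x) = \gamma \odot \frac{x-\mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$, where $\mu$ and $\sigma^2$ are the mean and (biased) variance over the normalized dimensions, $\epsilon$ is a small constant that prevents division by zero, and $\gamma$, $\beta$ are learnable element-wise scale and shift parameters
- Usage: the `normalized_shape` argument determines which dimensions are normalized, namely the last `len(normalized_shape)` dimensions of the input
  - NLP models: the input is `(batch_size, seq_len, embedding_dim)` and `normalized_shape` is `(embedding_dim,)`, so normalization happens over the last dimension
  - Vision models: the input is `(batch_size, channels, height, width)` and `normalized_shape` can be set to `(channels, height, width)`, in which case each sample is normalized over all of its feature dimensions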
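To make the formula concrete, here is a minimal dependency-free sketch that applies it to a single feature vector, i.e. the NLP case where `normalized_shape` is `(embedding_dim,)` and every token is normalized on its own. The helper name `layer_norm_1d` is hypothetical, and $\gamma$, $\beta$ are passed as plain lists instead of learnable parameters.

```python
import math

def layer_norm_1d(x, gamma=None, beta=None, eps=1e-5):
    """Apply LN(x) = gamma * (x - mu) / sqrt(var + eps) + beta to one feature vector."""
    d = len(x)
    gamma = gamma if gamma is not None else [1.0] * d  # default scale: all ones
    beta = beta if beta is not None else [0.0] * d     # default shift: all zeros
    mu = sum(x) / d                                    # mean over the features
    var = sum((v - mu) ** 2 for v in x) / d            # biased variance (divide by d, not d - 1)
    inv_std = 1.0 / math.sqrt(var + eps)
    return [g * (v - mu) * inv_std + b for v, g, b in zip(x, gamma, beta)]

token = [1.0, 2.0, 3.0, 4.0]
y = layer_norm_1d(token)
mean_y = sum(y) / len(y)
var_y = sum((v - mean_y) ** 2 for v in y) / len(y)
print([round(v, 4) for v in y])
print(mean_y, var_y)  # mean is ~0, variance is just under 1 because of eps

# With gamma = 2 and beta = 1 the output is simply rescaled and shifted
y_affine = layer_norm_1d(token, gamma=[2.0] * 4, beta=[1.0] * 4)
```

Before the affine step, the output always has (approximately) zero mean and unit variance over the normalized dimensions; $\gamma$ and $\beta$ then let the model undo or adjust that standardization.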

In [None]:
# Layer Normalization
import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
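        '''
        Args:
            normalized_shape: shape of the trailing dimensions to normalize over (an int or a tuple)
            eps: small constant added to the variance to prevent division by zero
            elementwise_affine: whether to use learnable element-wise scale (gamma) and shift (beta)
        '''
        super().__init__()
        # Accept an int (e.g. 8) as shorthand for a 1-element shape (8,), like nn.LayerNorm does
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.elementwise_affine = elementwise_affine
        self.eps = eps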

        if self.elementwise_affine:
            self.gamma = nn.Parameter(torch.ones(self.normalized_shape))  # learnable scale, initialized to 1
            self.beta = nn.Parameter(torch.zeros(self.normalized_shape))  # learnable shift, initialized to 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
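        # Check that the trailing dimensions of x match normalized_shape
        assert tuple(x.shape[-len(self.normalized_shape):]) == tuple(self.normalized_shape), (
            f"expected trailing dims {tuple(self.normalized_shape)}, got {tuple(x.shape)}"
        )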

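        # Indices of the dimensions to normalize over (the last len(normalized_shape) dims)
        # e.g. NLP (batch, seq_len, embedding_dim), normalized_shape (embedding_dim,) -> dims = (-1,)
        # e.g. vision (batch, C, H, W), normalized_shape (C, H, W) -> dims = (-3, -2, -1)
        dims = tuple(range(-len(self.normalized_shape), 0))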

        mean = x.mean(dim=dims, keepdim=True)
        var = x.var(dim=dims, keepdim=True, unbiased=False)  # unbiased=False: population variance (divide by N, not N-1)

        x_norm = (x - mean) / torch.sqrt(var + self.eps)

        if self.elementwise_affine:
            return self.gamma * x_norm + self.beta
        
        return x_norm

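The `dims` line in `forward` turns `normalized_shape` into negative indices of the trailing dimensions. This small standalone sketch (the function name `trailing_dims` is hypothetical) reproduces that expression so the mapping can be checked for both input types above:

```python
def trailing_dims(normalized_shape):
    """Negative indices of the last len(normalized_shape) dimensions, as computed in LayerNorm.forward."""
    return tuple(range(-len(normalized_shape), 0))

print(trailing_dims((8,)))       # NLP: (embedding_dim,)
print(trailing_dims((3, 4, 4)))  # vision: (channels, height, width)
```

Because the indices are negative, the same tuple works no matter how many leading (batch, sequence) dimensions the input has.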


In [3]:
# Use Layer Normalization
batch_size, seq_len, d_model = 2, 4, 8

x_nlp = torch.randn(batch_size, seq_len, d_model)

ln_nlp = LayerNorm(normalized_shape=(d_model, ))

output_nlp = ln_nlp(x_nlp)

print("=== NLP input ===")
print("input shape: ", x_nlp.shape)
print("normalized shape: ", ln_nlp.normalized_shape)
print("output shape: ", output_nlp.shape)
print("output: ", output_nlp)

batch_size, channels, height, width = 2, 3, 4, 4

x_vision = torch.randn(batch_size, channels, height, width)

# Normalize over the last three dimensions, i.e. all features (channels, height, width)

ln_vision = LayerNorm(normalized_shape=(channels, height, width))

output_vision = ln_vision(x_vision)

print("=== Vision input ===")
print("input shape: ", x_vision.shape)
print("normalized shape: ", ln_vision.normalized_shape)
print("output shape: ", output_vision.shape)
print("output: ", output_vision)

=== NLP input ===
input shape:  torch.Size([2, 4, 8])
normalized shape:  (8,)
output shape:  torch.Size([2, 4, 8])
output:  tensor([[[-1.4625, -0.6196,  1.5962,  0.6424, -0.2388, -0.7670, -0.4443,
           1.2937],
         [-0.4792,  1.0489,  0.9811,  0.4844, -0.5933, -1.5518, -1.1085,
           1.2183],
         [-0.1809, -1.2149, -0.4875,  0.9943,  0.2714, -1.6061,  1.3757,
           0.8480],
         [-1.0473,  0.7146,  2.0784, -0.3647,  0.1784, -0.1274, -0.0579,
          -1.3741]],

        [[ 1.3167,  1.1966,  0.6393, -0.6592, -0.2899, -1.1827,  0.4865,
          -1.5072],
         [ 0.4824,  1.1537,  1.1923, -0.8763, -0.6078,  0.8967, -1.6507,
          -0.5904],
         [-0.6461, -0.0729, -0.3627, -0.8746,  1.2042,  1.1846, -1.5826,
           1.1500],
         [-1.1200,  1.8385,  0.1356,  0.0567, -1.0411,  0.8595,  0.4275,
          -1.1568]]], grad_fn=<AddBackward0>)
=== Vision input ===
input shape:  torch.Size([2, 3, 4, 4])
normalized shape:  (3, 4, 4)
output shape:  

## RMSNorm (Root Mean Square Layer Normalization)

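- Core idea: a simplified version of Layer Normalization, adopted by models such as Llama. It drops the mean-centering step (subtracting the mean $\mu$) and only rescales the input by its root mean square, which makes it cheaper to compute
- Computation: for an input $X$, an $n \times d$ matrix where $n$ is the batch size and $d$ is the feature dimension, RMSNorm is applied to each row (sample) $x = (x_1, \dots, x_d)$ as follows:
    1. Compute the mean of the squared features of the sample:
        $$
        \text{MS}(x) = \frac{1}{d}\sum_{i = 1}^d x_i^2
        $$
    2. Compute the reciprocal of the root mean square, adding a small constant $\epsilon$ to prevent division by zero:
        $$
        r = \frac{1}{\sqrt{\text{MS}(x) + \epsilon}}
        $$
    3. Normalize the input with $r$ and multiply it element-wise by the learnable weight $\omega$:
        $$
        y = r \, x \odot \omega
        $$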
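The three steps can be traced on a single vector with the dependency-free sketch below. The helper names `rms_norm_1d` and `layer_norm_no_affine` are hypothetical, and $\omega$ is passed as a plain list. It also illustrates the relationship to LN: when a sample already has zero mean, its variance equals its mean of squares, so RMSNorm and LN (with $\gamma = 1$, $\beta = 0$ and the same $\epsilon$) agree; after a constant shift, LN gives the same output as before while RMSNorm does not.

```python
import math

def rms_norm_1d(x, weight=None, eps=1e-6):
    """RMSNorm on one feature vector, following steps 1-3."""
    d = len(x)
    weight = weight if weight is not None else [1.0] * d
    ms = sum(v * v for v in x) / d                 # step 1: mean of squares
    r = 1.0 / math.sqrt(ms + eps)                  # step 2: reciprocal of the RMS
    return [v * r * w for v, w in zip(x, weight)]  # step 3: rescale, then apply the weight

def layer_norm_no_affine(x, eps=1e-6):
    """LN without gamma/beta, for comparison."""
    d = len(x)
    mu = sum(x) / d
    var = sum((v - mu) ** 2 for v in x) / d
    return [(v - mu) / math.sqrt(var + eps) for v in x]

zero_mean = [-3.0, -1.0, 1.0, 3.0]
shifted = [v + 10.0 for v in zero_mean]

print([round(v, 4) for v in rms_norm_1d(zero_mean)])
print([round(v, 4) for v in layer_norm_no_affine(zero_mean)])  # same as RMSNorm: mean is already 0
print([round(v, 4) for v in rms_norm_1d(shifted)])             # differs: no mean-centering
print([round(v, 4) for v in layer_norm_no_affine(shifted)])    # unchanged by the shift
```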

In [None]:
import torch
import torch.nn as nn

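class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-6):
        '''
        Args:
            d_model: size of the last (feature) dimension
            eps: small constant added to the mean of squares to prevent division by zero
        '''
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))  # learnable per-feature scale ω, initialized to 1

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # 1. Mean of the squares over the last (feature) dimension
        mean_of_square = x.pow(2).mean(-1, keepdim=True)

        # 2. Reciprocal of the root mean square: rsqrt(a) = 1 / sqrt(a)
        rrms = torch.rsqrt(mean_of_square + self.eps)

        # 3. Rescale the input (the weight is applied in forward)
        return rrms * x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute the normalization in float32 for precision, then cast back to the input dtype
        output = self._norm(x.float()).type_as(x)
        return output * self.weight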
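The `x.float()` upcast in `forward` matters because squaring can overflow low-precision formats: the largest finite float16 value is 65504, so an activation of 300 already squares to 90000. The dependency-free sketch below only *emulates* float16 by rounding every intermediate result through Python's `struct` `'e'` (half-precision) format. This is an assumption for illustration: real kernels may accumulate the mean in higher precision, but here the overflow already happens in the squaring step. The helper names `to_fp16` and `rms_norm_row` are hypothetical.

```python
import math
import struct

def to_fp16(v):
    """Round a Python float to the nearest float16 value; overflow becomes +/-inf."""
    try:
        return struct.unpack('<e', struct.pack('<e', v))[0]
    except OverflowError:
        return math.copysign(math.inf, v)

def rms_norm_row(x, eps=1e-6, fp16=False):
    """RMSNorm of one row without weight; if fp16=True, every intermediate is rounded to float16."""
    cast = to_fp16 if fp16 else float
    squares = [cast(v * v) for v in x]
    ms = cast(sum(squares) / len(x))
    r = cast(1.0 / math.sqrt(ms + eps))
    return [cast(v * r) for v in x]

row = [300.0, -300.0, 300.0, -300.0]  # 300 is exactly representable in float16
print(rms_norm_row(row, fp16=True))   # squares overflow to inf, so r and the output collapse to 0
print(rms_norm_row(row))              # in higher precision the outputs are close to +/-1
```

Computing the statistics in float32 and casting back with `.type_as(x)` only afterwards avoids this failure mode while keeping the module's input and output dtype unchanged.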