In [2]:
import torch
import torch.nn as nn

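### Note: both LayerNorm and RMSNorm compute their statistics over the last
### (feature) dimension, independently for every position. For an input of shape
### [batch_size, seq_len, embedding_dim] this is equivalent to normalizing
### batch_size * seq_len separate samples.
### - LayerNorm centers and scales each sample to zero mean and unit variance
###   (before the learnable scale/shift is applied).
### - RMSNorm only rescales each sample to unit root-mean-square; it does not
###   subtract the mean, so the result is not zero-mean in general.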

In [13]:
# 实现LayerNorm
class CustomLayerNorm(nn.Module):
    """Layer normalization over the trailing ``normalized_shape`` dimensions.

    Each slice spanning the last ``len(normalized_shape)`` dimensions of the
    input is normalized with its own mean and biased variance, then scaled by
    ``gamma`` and shifted by ``beta``.

    Args:
        normalized_shape: Size of the trailing dimension(s) to normalize over.
            An ``int`` is treated as a one-element shape.
        eps: Constant added to the variance before the square root, to avoid
            division by zero.

    Attributes:
        gamma: Learnable scale, initialized to ones, shape ``normalized_shape``.
        beta: Learnable shift, initialized to zeros, shape ``normalized_shape``.
    """

    def __init__(self, normalized_shape, eps=1e-5):
        # Was `super().__init_()` (missing underscore), which raised
        # AttributeError before any parameter could be registered.
        super().__init__()
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)
        self.eps = eps

        # nn.Parameter registers the tensors with the module so they appear in
        # parameters()/state_dict() and are updated by the optimizer.
        self.gamma = nn.Parameter(torch.ones(self.normalized_shape))
        self.beta = nn.Parameter(torch.zeros(self.normalized_shape))

    def forward(self, x):
        """Normalize ``x`` over its trailing ``normalized_shape`` dimensions.

        Args:
            x: Tensor whose trailing dimensions equal ``normalized_shape``.

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: If the trailing dimensions of ``x`` do not match
                ``normalized_shape``.
        """
        n_norm_dims = len(self.normalized_shape)
        first_norm_dim = x.dim() - n_norm_dims
        if first_norm_dim < 0 or x.shape[first_norm_dim:] != self.normalized_shape:
            raise ValueError(
                f"Expected input with trailing shape {tuple(self.normalized_shape)}, "
                f"got input of shape {tuple(x.shape)}"
            )

        # Reduce over the normalized (trailing) dimensions only.
        dims = tuple(range(first_norm_dim, x.dim()))
        mean = x.mean(dims, keepdim=True)
        # Biased (population) variance, i.e. divide by N rather than N - 1.
        var = x.var(dims, keepdim=True, unbiased=False)
        x_normalized = (x - mean) / torch.sqrt(var + self.eps)
        # Learnable affine transform; gamma/beta broadcast over leading dims.
        return self.gamma * x_normalized + self.beta

In [18]:
# 实现RMSNorm
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Computes ``x / sqrt(mean(x^2) + eps) * gamma`` where the mean is taken
    over the last dimension. Unlike layer normalization, the mean is not
    subtracted and there is no shift parameter.

    Args:
        dim: Size of the last dimension; also the length of ``gamma``.
        eps: Constant added to the mean square before the reciprocal square
            root, to avoid division by zero.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # Learnable per-feature scale, initialized to ones.
        self.gamma = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Return ``x`` scaled by the reciprocal RMS of its last dimension."""
        mean_square = torch.mean(torch.pow(x, 2), dim=-1, keepdim=True)
        # rsqrt(v) is 1 / sqrt(v), so this multiplies instead of dividing.
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x):
        """Normalize ``x`` and apply the learnable scale ``gamma``."""
        return self._norm(x) * self.gamma

In [19]:
# Shape check: run a random [2, 3, 4] tensor through both normalization layers
# and print the shape of each result.
x = torch.randn((2, 3, 4))
embedding_dim = x.size(-1)

# Build both layers up front, sized to the last dimension of the input.
layer_norm = nn.LayerNorm(embedding_dim)
rms_norm = RMSNorm(embedding_dim)

output = layer_norm(x)
print(output.size())

# `output` is rebound here, so it ends the cell holding the RMSNorm result.
output = rms_norm(x)
print(output.size())

torch.Size([2, 3, 4])
torch.Size([2, 3, 4])
