In [5]:
import torch
import torch.nn as nn

class LlamaRMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Each vector along the last axis is divided by its root mean square.
    That value is computed in float32, with ``eps`` added to the mean square
    for numerical stability. The result is then scaled element-wise by a
    learnable ``weight``. No mean is subtracted and there is no bias term.

    Args:
        hidden_size: Size of the last dimension; length of ``weight``.
        eps: Constant added to the mean square before taking the
            reciprocal square root.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Learnable per-feature gain, initialised to ones (identity scaling).
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalise ``hidden_states`` over its last dimension.

        The statistics are computed in float32. The normalised tensor is cast
        back to the input dtype before it is multiplied by ``weight``. The
        returned dtype therefore follows PyTorch type promotion between
        ``weight`` and the input dtype: for example, it is float32 when
        ``weight`` is float32 and the input is float16.
        """
        original_dtype = hidden_states.dtype
        as_fp32 = hidden_states.float()
        # Mean of squares along the feature axis; keepdim so it broadcasts back.
        mean_square = as_fp32.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = (mean_square + self.variance_epsilon).rsqrt()
        normalized = (as_fp32 * inv_rms).to(original_dtype)
        return self.weight * normalized

# Demo: normalise a random 4x8 batch and show it before and after.
rms_norm = LlamaRMSNorm(8)
x = torch.rand(4, 8)  # uniform samples in [0, 1); not seeded here, so values vary per run
print(x, rms_norm(x), sep="\n")

tensor([[0.5853, 0.7699, 0.8091, 0.0658, 0.1708, 0.4693, 0.7407, 0.7925],
        [0.3276, 0.5998, 0.0996, 0.6390, 0.9192, 0.4677, 0.3422, 0.4757],
        [0.9624, 0.5708, 0.6921, 0.0715, 0.7837, 0.4021, 0.3529, 0.2157],
        [0.4720, 0.0607, 0.3729, 0.3018, 0.2765, 0.0527, 0.1086, 0.9097]])
tensor([[0.9526, 1.2530, 1.3167, 0.1071, 0.2780, 0.7637, 1.2054, 1.2898],
        [0.6120, 1.1204, 0.1861, 1.1938, 1.7172, 0.8736, 0.6392, 0.8887],
        [1.6619, 0.9856, 1.1950, 0.1235, 1.3532, 0.6944, 0.6093, 0.3725],
        [1.1383, 0.1464, 0.8994, 0.7279, 0.6670, 0.1272, 0.2619, 2.1942]],
       grad_fn=<MulBackward0>)
