## [Definition](https://superkogito.github.io/blog/2020/04/30/rms_normalization.html)

In [76]:
import torch
import torch.nn as nn

class RMSNorm_(nn.Module):
    """Root Mean Square Layer Normalization over the last dimension.

    The input is divided by (RMS + eps) and multiplied by a learnable
    per-feature ``scale``; a learnable ``offset`` is added when ``bias`` is set.
    In partial mode (``0 <= p <= 1``) the RMS is estimated from the first
    ``int(d * p)`` features only, but the whole input is still rescaled.
    """

    def __init__(self, d: int, p: float = -1., eps: float = 1e-8, bias: bool = False):
        """
        Root Mean Square Layer Normalization
        :param d: model size (length of the last dimension of the input),
        :param p: fraction of leading features used to estimate the RMS,
            valid range [0., 1.]; default -1. (disabled, all features used)
        :param eps: epsilon value added to the RMS, default 1e-8
        :param bias: whether a learnable offset is added, default False
        """
        super().__init__()

        self.d = d
        self.p = p
        self.eps = eps
        self.bias = bias

        # Assigning an nn.Parameter to a module attribute registers it, so no
        # separate register_parameter() call is needed.
        self.scale = nn.Parameter(torch.ones(d))
        if bias:
            self.offset = nn.Parameter(torch.zeros(d))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Normalize ``x`` along its last dimension.
        :param x: input tensor; in partial mode its last dimension must be
            exactly ``d`` (the split below enforces it)
        :return: tensor shaped like ``x``
        :raises AssertionError: if partial mode selects zero features
        """
        # ignore p when it is not useful
        if self.p < 0. or self.p > 1.:
            norm_x = x.norm(2, dim=-1, keepdim=True)
            dx = self.d
        else:
            partial_size = int(self.d*self.p)
            partial_x, _ = x.split([partial_size, self.d - partial_size], dim=-1)

            norm_x = partial_x.norm(2, dim=-1, keepdim=True)
            dx = partial_size

        # Explicit raise instead of `assert` so the check survives `python -O`;
        # the exception type and message are kept for backward compatibility.
        if dx <= 0:
            raise AssertionError("The partial d_x should be positive")

        # RMS = ||x||_2 / sqrt(n); eps keeps the division finite for zero input
        rms_x = norm_x*dx**(-1./2)
        x_normed = x/(rms_x + self.eps)
        if self.bias:
            return self.scale*x_normed + self.offset
        return self.scale*x_normed

In [26]:
"""
Official Implementation
"""
class RMSNorm(nn.Module):
    """Reference RMSNorm layer: rescales the input by its root mean square.

    Normalization runs over the last dimension. With ``p`` in [0, 1] only the
    first ``int(d * p)`` features contribute to the RMS estimate (partial
    RMSNorm); any other ``p`` uses all ``d`` features.
    """

    def __init__(self, d: int, p: float = -1., eps: float = 1e-8, bias: bool = False):
        """
            Root Mean Square Layer Normalization
        :param d: model size
        :param p: partial RMSNorm, valid value [0, 1], default -1.0 (disabled)
        :param eps:  epsilon value, default 1e-8
        :param bias: whether use bias term for RMSNorm, disabled by
            default because RMSNorm doesn't enforce re-centering invariance.
        """
        super().__init__()

        self.eps = eps
        self.d = d
        self.p = p
        self.bias = bias

        # Attribute assignment of an nn.Parameter already registers it with the
        # module; the extra register_parameter() calls were redundant.
        self.scale = nn.Parameter(torch.ones(d))

        if self.bias:
            self.offset = nn.Parameter(torch.zeros(d))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
            Apply RMS normalization along the last dimension of ``x``.
        :param x: input tensor; in partial mode the last dimension must be
            of size ``d`` because it is split into [partial, rest]
        :return: normalized tensor with the same shape as ``x``
        :raises AssertionError: when the partial size ``int(d * p)`` is zero
        """
        if self.p < 0. or self.p > 1.:
            # p outside [0, 1]: partial mode is off, use every feature
            norm_x = x.norm(2, dim=-1, keepdim=True)
            d_x = self.d
        else:
            partial_size = int(self.d * self.p)
            partial_x, _ = torch.split(x, [partial_size, self.d - partial_size], dim=-1)

            norm_x = partial_x.norm(2, dim=-1, keepdim=True)
            d_x = partial_size

        # Raised explicitly (not via `assert`) so the guard is not stripped by
        # `python -O`; type and message match the previous assert.
        if d_x <= 0:
            raise AssertionError("The partial d_x should be positive")

        # L2 norm divided by sqrt(number of features) gives the RMS
        rms_x = norm_x * d_x ** (-1. / 2)
        x_normed = x / (rms_x + self.eps)

        if self.bias:
            return self.scale * x_normed + self.offset

        return self.scale * x_normed

In [75]:
import random

# Cross-check: RMSNorm_ must reproduce the reference RMSNorm exactly on
# random partial-mode configurations.
for _ in range(10):
    d = random.randint(10, 20)
    # Choose the partial size k first and put p in the middle of its bucket so
    # that int(d * p) == k >= 1. Drawing p straight from random.random() gives
    # int(d * p) == 0 whenever p < 1/d, which trips the modules'
    # "d_x should be positive" check and made this cell fail intermittently.
    k = random.randint(1, d - 1)
    p = (k + 0.5) / d
    norm = RMSNorm(d, p)
    norm_ = RMSNorm_(d, p)
    x = torch.randn(d)
    x_1, x_2 = norm_(x), norm(x)
    # Both modules run the same operations in the same order, so the outputs
    # are expected to match bit-for-bit, not just approximately.
    assert torch.equal(x_1, x_2)

In [28]:
# Build the layer explicitly instead of reusing whatever `norm` an earlier cell
# left behind: that instance has d in 10..20 and cannot split an 8-element
# input. d=8 with p=0.5 estimates the RMS from the first four features
# (sqrt(3.5)), which matches the output shown below.
norm = RMSNorm(8, p=0.5)
x = torch.arange(8.0)
norm(x)

tensor([0.0000, 0.5345, 1.0690, 1.6036, 2.1381, 2.6726, 3.2071, 3.7417],
       grad_fn=<MulBackward0>)