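# RMSNorm implementation

RMSNorm rescales each input vector by the reciprocal of its root mean square over the last (feature) dimension and then multiplies by a learnable per-feature gain $g$. Unlike LayerNorm, it subtracts no mean and adds no bias:

$$\mathrm{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{n}\sum_{i=1}^{n} x_i^2 + \epsilon}} \odot g$$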

In [4]:
# imports

import torch 
from torch import nn
import numpy as np
import math
from typing import Optional, Tuple, Dict, List

In [22]:
# basic test: reducing x**2 over dim 0 collapses that axis, (1, 2) -> (2,)
x = torch.randn(1, 2)
print(x.shape)
x = x.pow(2).sum(dim=0)
print(x.shape)

torch.Size([1, 2])
torch.Size([2])


In [None]:
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Normalizes each vector along the last dimension by its root mean square
    and applies a learnable per-feature gain:

        y = x / sqrt(mean(x**2, dim=-1) + eps) * g

    No mean is subtracted and there is no bias term.

    Args:
        out_dim: size of the input's last dimension; also the length of the
            learnable gain ``g``.
        eps: small constant added to the mean square before the square root
            so that an all-zero input does not divide by zero.
    """

    def __init__(self, out_dim, eps=1e-6):
        super().__init__()
        self.eps = eps  # numerical-stability term inside the square root
        self.g = nn.Parameter(torch.ones(out_dim))  # learnable gain, starts as identity

    def forward(self, x):
        """Normalize ``x`` over its last dimension and scale by ``g``.

        Args:
            x: tensor whose last dimension has size ``out_dim``; any leading
                dimensions are treated independently (broadcast over).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # RMS over the feature axis; keepdim=True so it broadcasts back over x.
        # eps goes inside the sqrt (not added to x) so negative inputs stay valid.
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x / rms) * self.g

In [34]:
def test_rmsnorm():
    """Smoke test: RMSNorm output has the same (1, features) shape as its input."""
    features = 2
    sample = torch.randn(1, features)
    layer = RMSNorm(out_dim=features)
    normed = layer(sample)

    assert normed.shape == (1, features)
    print(normed.shape)
    print("RMSNorm Shape test passed")


test_rmsnorm()

torch.Size([1, 2])
RMSNorm Shape test passed
