In [2]:
import torch
from torch import nn

In [3]:
torch.manual_seed(0)  # fixed seed so the input tensor (and everything derived from it) reproduces on Restart & Run All
x = torch.randn(4, 32, 512)  # B, T, n_embd
x.shape

torch.Size([4, 32, 512])

In [4]:
# Built-in RMSNorm with no learnable scale, normalizing over the last (n_embd) dimension of x.
r = nn.RMSNorm(normalized_shape=x.size(-1), elementwise_affine=False)
r

RMSNorm((512,), eps=None, elementwise_affine=False)

In [5]:
# Built-in module output vs. the bare formula x * rsqrt(mean(x**2)) written inline (no eps).
torch.allclose(r(x), x * x.pow(2).mean(-1, keepdim=True).rsqrt())

True

In [6]:
# Built-in nn.RMSNorm, invoked through the module call (includes nn.Module.__call__ overhead).
%timeit r(x)

60.8 µs ± 5.44 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [7]:
# Same formula as a bare expression: no module call and no eps, so it is not a like-for-like comparison.
%timeit (x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True)))

43.8 µs ± 1.81 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [8]:
class TestRMSNorm(nn.Module):
    """Parameter-free RMSNorm over the last dimension, for comparison with ``nn.RMSNorm``.

    Computes ``x * rsqrt(mean(x**2, dim=-1) + eps)`` with no learnable scale,
    i.e. the same computation as ``nn.RMSNorm(dim, elementwise_affine=False)``.

    Args:
        eps: Value added to the mean square before the reciprocal square root.
            ``None`` (the default) uses ``torch.finfo(x.dtype).eps``, matching
            ``nn.RMSNorm``'s documented default. Pass ``0.0`` to add nothing.
    """

    def __init__(self, eps=None):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` by the root-mean-square of its last dimension.

        Returns a tensor with the same shape as ``x``.
        """
        # Resolve the default per call so it tracks the input dtype. Adding no
        # epsilon at all turns an all-zero row into 0 * rsqrt(0) = 0 * inf = NaN.
        eps = torch.finfo(x.dtype).eps if self.eps is None else self.eps
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + eps)

In [9]:
# Custom module with an explicit epsilon. NOTE: 1e-6 is larger than nn.RMSNorm's default
# (torch.finfo(float32).eps ~ 1.2e-7), so rt and r are not numerically identical.
rt = TestRMSNorm(eps=1e-6)

In [10]:
# Custom module vs. built-in module (default allclose tolerances).
rt(x).allclose(r(x))

True

In [11]:
# Custom module vs. the bare inline formula (no eps).
rt(x).allclose(x * x.pow(2).mean(-1, keepdim=True).rsqrt())

True

In [12]:
# Custom TestRMSNorm (eps=1e-6), also invoked through nn.Module.__call__.
%timeit rt(x)

59.5 µs ± 3.87 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [15]:
from model import Reference_RMSNorm

In [16]:
# Project reference implementation, sized to the last dimension of x.
t = Reference_RMSNorm(x.size(-1))

In [17]:
# Reference implementation vs. the bare inline formula (no eps).
torch.allclose(t(x), x * x.pow(2).mean(-1, keepdim=True).rsqrt())

True

In [18]:
# Project Reference_RMSNorm (presumably an nn.Module as well; confirm in model.py).
%timeit t(x)

70.6 µs ± 4.24 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


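These timing comparisons are probably not fair. The bare expression (In [7], 43.8 µs) skips the `nn.Module.__call__` overhead that every module version pays, and it adds no epsilon. That likely explains most of its lead. Among the module-based versions, PyTorch's built-in `nn.RMSNorm` (60.8 µs) is on par with `TestRMSNorm` (59.5 µs, within noise) and faster than `Reference_RMSNorm` (70.6 µs). So it is better to use the RMSNorm shipped with PyTorch: there is no speed penalty and no custom code to maintain.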