In [1]:
import torch
import torch.nn as nn

In [18]:
class RMSNORM(nn.Module):
    def __init__(self,emb_dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.emb_dim = emb_dim
        self.weight = nn.Parameter(torch.ones(emb_dim)).float()

    def forward(self, x):
        means = x.pow(2).mean(dim = -1, keepdim = True)
        x_normed = x / torch.sqrt(self.eps + means)
        return (x_normed * self.weight).to(dtype=x.dtype)


In [21]:
# Fix the RNG seed so the random example batch below is reproducible.
torch.manual_seed(123)

# Small random input; both normalisation layers are sized to its last axis.
example_batch = torch.randn(2, 3, 4)

# Hand-written layer and PyTorch's built-in layer, configured identically.
rms_norm = RMSNORM(emb_dim=example_batch.size(-1))
rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.size(-1), eps=1e-5)


In [22]:
# Output of the hand-written RMSNORM layer on the example batch.
rms_norm(example_batch)

tensor([[[ 0.8834, -0.4655, -0.7948, -1.5398],
         [ 0.8053,  1.5254, -0.5074, -0.8759],
         [-0.1619, -0.4049,  0.7016, -1.8214]],

        [[ 1.3348, -0.1232,  0.9638, -1.1288],
         [-0.8511, -1.5285, -0.4420,  0.8624],
         [ 0.9398,  0.2082,  1.7241, -0.3180]]], grad_fn=<MulBackward0>)

In [24]:
# Output of torch.nn.RMSNorm on the same batch, for comparison with the hand-written layer.
rmsnorm_pytorch(example_batch)

tensor([[[ 0.8834, -0.4655, -0.7948, -1.5398],
         [ 0.8053,  1.5254, -0.5074, -0.8759],
         [-0.1619, -0.4049,  0.7016, -1.8214]],

        [[ 1.3348, -0.1232,  0.9638, -1.1288],
         [-0.8511, -1.5285, -0.4420,  0.8624],
         [ 0.9398,  0.2082,  1.7241, -0.3180]]], grad_fn=<MulBackward0>)

In [36]:
class Silu(nn.Module):
    """Parameter-free activation module returning ``x * sigmoid(x)`` element-wise."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Scale every element of ``x`` by its own sigmoid."""
        gate = torch.sigmoid(x)
        return gate * x
        

In [37]:
# Instantiate the custom activation and apply it to the example batch.
fn = Silu()
value = fn(example_batch)

In [38]:
# Display the activation output computed from the example batch.
value

tensor([[[ 0.1969, -0.0810, -0.1289, -0.2100],
         [ 0.2044,  0.4354, -0.0978, -0.1541],
         [-0.0739, -0.1610,  0.4642, -0.2549]],

        [[ 1.6484, -0.0799,  1.0913, -0.2686],
         [-0.2459, -0.2767, -0.1628,  0.5480],
         [ 0.7297,  0.1228,  1.5791, -0.1407]]])

In [41]:
class FeedForward(nn.Module):
    """Gated feed-forward block built from three bias-free linear layers.

    Computes ``layer3(silu(layer1(x)) * layer2(x))``: ``layer1`` and ``layer2``
    both project ``emb_dim -> hidden_dim``, their outputs are combined
    element-wise, and ``layer3`` projects ``hidden_dim -> emb_dim``.

    Args:
        cfg: Configuration mapping. Reads ``cfg['emb_dim']``,
            ``cfg['hidden_dim']`` and ``cfg['dtype']``.
    """

    def __init__(self, cfg):
        super().__init__()
        emb_dim, hidden_dim, dtype = cfg['emb_dim'], cfg['hidden_dim'], cfg["dtype"]
        self.layer1 = nn.Linear(emb_dim, hidden_dim, bias=False, dtype=dtype)
        self.layer2 = nn.Linear(emb_dim, hidden_dim, bias=False, dtype=dtype)
        self.layer3 = nn.Linear(hidden_dim, emb_dim, bias=False, dtype=dtype)
        self.silu = Silu()

    def forward(self, x):
        """Apply the gated projection to ``x``; the last axis must have size ``emb_dim``."""
        gate = self.silu(self.layer1(x))
        up = self.layer2(x)
        # The gate is an element-wise product. torch.dot only accepts 1-D
        # tensors and reduces them to a scalar, so it raised on batched input
        # and could never feed layer3 a hidden_dim-sized vector.
        return self.layer3(gate * up)




In [61]:

def precompute_rope_params(head_dim, device, theta_base=10_000, context_length=4096):
    """Precompute the complex rotation factors used by rotary position embeddings.

    For each feature pair ``i = 1 .. head_dim/2`` the rotation frequency is
    ``theta_i = theta_base ** (-2 * (i - 1) / head_dim)``. Position ``m`` and
    pair ``i`` get the unit complex number ``exp(1j * m * theta_i)``.

    Args:
        head_dim: Per-head feature size; must be even.
        device: Device on which the tensors are created.
        theta_base: Base of the geometric frequency progression.
        context_length: Number of positions to precompute.

    Returns:
        Complex tensor of shape ``(context_length, head_dim // 2)`` whose
        entries all have magnitude 1.

    Raises:
        AssertionError: If ``head_dim`` is odd.
    """
    assert head_dim % 2 == 0, "head_dim must be even"

    # Exponents 0/d, 2/d, 4/d, ... one per feature pair.
    exponents = torch.arange(0, head_dim, 2, device=device).float() / head_dim
    theta = 1.0 / (theta_base ** exponents)

    positions = torch.arange(context_length, device=device).float()

    # Outer product: angles[m, i] = m * theta_i, shape (context_length, head_dim / 2).
    angles = torch.outer(positions, theta)

    # Unit-magnitude complex numbers cos(angle) + 1j * sin(angle).
    return torch.polar(torch.ones_like(angles), angles)


def apply_rotary_embeddings(token_to_be_applied, angle_in_polar_form, device):
    """Rotate adjacent feature pairs of a tensor by per-position rotary factors.

    Args:
        token_to_be_applied: Tensor whose last axis has even size. The
            broadcasting below treats the third-from-last axis as the sequence
            axis, e.g. ``(batch, seq_len, heads, head_dim)``.
        angle_in_polar_form: Rotation table of shape ``(positions, head_dim / 2)``
            with ``positions >= seq_len``. Normally complex unit numbers; a real
            tensor is interpreted as raw angles and converted first.
        device: Device the result is moved to.

    Returns:
        Tensor with the same shape and dtype as ``token_to_be_applied``.
    """
    # Pair up adjacent features: (..., head_dim) -> (..., head_dim/2, 2) -> complex (..., head_dim/2).
    x_complex = torch.view_as_complex(
        token_to_be_applied.float().reshape(*token_to_be_applied.shape[:-1], -1, 2)
    )

    rotations = angle_in_polar_form
    if not torch.is_complex(rotations):
        # Multiplying by real angles would only scale the values, not rotate them.
        rotations = torch.polar(torch.ones_like(rotations), rotations)

    # Use only as many positions as the input has, so a table precomputed for
    # the full context length broadcasts against shorter sequences.
    rotations = rotations[: token_to_be_applied.shape[-3]]

    # (Seq_Len, Head_Dim/2) --> (1, Seq_Len, 1, Head_Dim/2)
    rotations = rotations.unsqueeze(0).unsqueeze(2)

    # Complex multiplication rotates each (even, odd) feature pair.
    rotated = x_complex * rotations

    # Back to real pairs, then flatten to the original layout. Tensor.shape is
    # a torch.Size and cannot be called; reshape is the intended operation.
    x_out = torch.view_as_real(rotated).reshape(*token_to_be_applied.shape)

    return x_out.type_as(token_to_be_applied).to(device)

