In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
from torchtune.modules import RMSNorm



In [6]:
@dataclass
class ModelArgs:
    """Hyperparameters for the model and its training loop.

    Every field has a default, so values can be read directly off the class
    (``ModelArgs.block_size``), as the other cells do. A run can also override
    them per instance, e.g. ``ModelArgs(block_size=256)``. The type annotations
    are what make ``@dataclass`` turn these attributes into real fields.
    """

    # Model shape
    block_size: int = 512              # sequence length / number of positions RoPE covers
    embeddings_dims: int = 2048        # model width
    no_of_heads: int = 32              # TODO: tune; must divide embeddings_dims (head_size = embeddings_dims // no_of_heads)
    no_of_decoder_layers: int = 32     # TODO: tune

    # Regularisation
    attn_dropout: float = 0.1
    dropout: float = 0.1

    # Training loop / optimiser
    batch_size: int = 64
    epochs: int = 100
    max_lr: float = 3e-4
    weight_decay_optim: float = 0.1
    beta_1: float = 0.9
    beta_2: float = 0.95
    clip: float = 1.0

    # Fall back to CPU so the notebook still runs on machines without a GPU.
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'

In [5]:
class RMENorm(nn.Module):
    """Thin ``nn.Module`` wrapper around torchtune's ``RMSNorm``.

    Args:
        embeddings_dims: Size of the last dimension that is normalised.
    """

    def __init__(
        self,
        embeddings_dims: int = ModelArgs.embeddings_dims
    ):
        # nn.Module.__init__ must run before any submodule is assigned as an attribute.
        super().__init__()
        self.rmsnorm_layer = RMSNorm(dim=embeddings_dims)

    def forward(self, x):
        """Return ``x`` passed through the wrapped RMSNorm layer."""
        return self.rmsnorm_layer(x)
        

In [8]:
class RotaryEmbeddings(nn.Module):
    """Rotary positional embeddings (RoPE) applied to the last dimension of a tensor.

    At position ``pos``, each consecutive feature pair ``(2i, 2i+1)`` is rotated
    by the angle ``pos * theta_i``, where ``theta_i = base ** (-2i / embeddings_dims)``:

        out[2i]   = x[2i] * cos(pos * theta_i) - x[2i+1] * sin(pos * theta_i)
        out[2i+1] = x[2i] * sin(pos * theta_i) + x[2i+1] * cos(pos * theta_i)

    The per-position rotation matrix is block-diagonal, so storing a dense
    ``embeddings_dims x embeddings_dims`` matrix per position is unnecessary.
    A dense version would be ~8 GB for the defaults. Only the per-position
    cosines and sines are cached, and the rotation is applied element-wise.

    Args:
        embeddings_dims: Size of the last dimension of the inputs; must be even.
        block_size: Number of positions to precompute (maximum sequence length).
        base: Frequency base of the rotation angles.

    Raises:
        ValueError: If ``embeddings_dims`` is odd.
    """

    def __init__(
        self,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        block_size: int = ModelArgs.block_size,
        base: float = 10000.0,
    ):
        super().__init__()
        if embeddings_dims % 2 != 0:
            raise ValueError(f"embeddings_dims must be even for rotary embeddings, got {embeddings_dims}")
        # theta_i for each feature pair i = 0 .. embeddings_dims // 2 - 1
        pair_index = torch.arange(0, embeddings_dims, 2, dtype=torch.float32, device=ModelArgs.device)
        inv_freq = base ** (-pair_index / embeddings_dims)
        positions = torch.arange(block_size, dtype=torch.float32, device=ModelArgs.device)
        angles = torch.outer(positions, inv_freq)  # (block_size, embeddings_dims // 2)
        # Buffers follow .to()/.cuda(); they are derived data, so keep them out of state_dict.
        self.register_buffer("cos_cached", angles.cos(), persistent=False)
        self.register_buffer("sin_cached", angles.sin(), persistent=False)

    def forward(self, x):
        """Rotate ``x`` according to each row's position along dim -2.

        Args:
            x: Tensor of shape ``(..., seq_len, embeddings_dims)`` with
                ``seq_len <= block_size``.

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: If ``seq_len`` exceeds the number of cached positions.
        """
        seq_len = x.shape[-2]
        if seq_len > self.cos_cached.shape[0]:
            raise ValueError(
                f"sequence length {seq_len} exceeds precomputed block_size {self.cos_cached.shape[0]}"
            )
        cos = self.cos_cached[:seq_len].to(x.dtype)
        sin = self.sin_cached[:seq_len].to(x.dtype)
        x_even = x[..., 0::2]
        x_odd = x[..., 1::2]
        rotated_even = x_even * cos - x_odd * sin
        rotated_odd = x_even * sin + x_odd * cos
        # Re-interleave so pair i lands back in slots (2i, 2i+1).
        return torch.stack((rotated_even, rotated_odd), dim=-1).flatten(-2)

In [9]:
class RotaryAttentionHead(nn.Module):
    """Single causal self-attention head with rotary embeddings on queries and keys.

    Args:
        embeddings_dims: Feature size of the input ``x``.
        no_of_heads: Number of heads the model width is split into. This head
            projects to ``embeddings_dims // no_of_heads`` features.
        block_size: Maximum sequence length supported by the rotary embeddings.

    Raises:
        ValueError: If ``embeddings_dims`` is not divisible by ``no_of_heads``.
    """

    def __init__(
        self,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        no_of_heads: int = ModelArgs.no_of_heads,
        block_size: int = ModelArgs.block_size,
    ):
        super().__init__()
        if embeddings_dims % no_of_heads != 0:
            raise ValueError(
                f"embeddings_dims ({embeddings_dims}) must be divisible by no_of_heads ({no_of_heads})"
            )
        # Integer division: nn.Linear requires an int out_features.
        self.head_size = embeddings_dims // no_of_heads
        self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=ModelArgs.device, bias=False)
        self.key = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=ModelArgs.device, bias=False)
        self.value = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=ModelArgs.device, bias=False)
        # RoPE rotates the projected q/k, so it is sized to head_size, not the model width.
        self.rotary_matrix = RotaryEmbeddings(embeddings_dims=self.head_size, block_size=block_size)

    def forward(self, x):
        """Compute causal scaled dot-product attention for this head.

        Args:
            x: Tensor of shape ``(batch, seq_len, embeddings_dims)``.

        Returns:
            Tensor of shape ``(batch, seq_len, head_size)``.
        """
        _, seq_len, _ = x.shape
        query = self.rotary_matrix(self.query(x))
        key = self.rotary_matrix(self.key(x))
        values = self.value(x)

        weights = query @ key.transpose(-2, -1) / math.sqrt(self.head_size)
        # Lower-triangular mask: position t may only attend to positions <= t.
        causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device))
        weights = weights.masked_fill(~causal_mask, float('-inf'))
        weights = F.softmax(weights, dim=-1)
        return weights @ values
    
    
    
    
    
    
    
    