In [41]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
from torchtune.modules import RMSNorm

In [42]:
@dataclass
class ModelArgs:
    #Hyperparameters

    block_size = 128
    batch_size = 16
    embeddings_dims = 256
    attn_dropout = 0.1
    no_of_heads = 32 #IMP needs to be thoroughly calculated
    dropout = 0.1
    epochs = 100
    max_lr = 3e-4
    no_of_decoder_layers = 32 #IMP needs to be thoroughly calculated
    weight_decay_optim = 0.1
    beta_1 = 0.9
    beta_2 = 0.95
    clip = 1.0
    device = 'cpu'
    no_kv_heads = 2

In [43]:
class RMENorm(nn.Module):
    def __init__(
        self,
        embeddings_dims: int = ModelArgs.embeddings_dims
    ):
        self.rmsnorm_layer = RMSNorm(dim=embeddings_dims)
        
        super().__init__()
        
    def forward(x, self):
        
        x = self.rmsnorm_layer(x)
        return x
        

In [107]:
import numpy as np
class RotaryEmbeddings(nn.Module):
    def __init__(
        self,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        block_size: int = ModelArgs.block_size
    ):
        super().__init__()
        
        self.embeddings_dims = embeddings_dims
        self.block_size = block_size
        self.matrix = torch.zeros((self.block_size, self.embeddings_dims, self.embeddings_dims), device=ModelArgs.device, requires_grad=False)
        self.theta = 0  
        self.init_matrix(self.block_size)
        
        print("MATRXO: ", self.matrix)
        
    def init_matrix(self, seq_len):
            for pos in range(seq_len):
                for j in range(1, self.embeddings_dims // 2):
                    self.theta = 10000 ** (-2*(pos-1) / self.embeddings_dims)
                    self.matrix[pos, 2*j + 1, 2*j + 1] = np.cos((pos*self.theta))
                    self.matrix[pos, 2*j + 1, j + 1] = -np.sin((pos* self.theta))
                    self.matrix[pos, 2*j , 2*j ] = -np.cos((pos* self.theta))
                    self.matrix[pos, 2*j + 1, 2*j + 1] = np.sin((pos* self.theta))
                    
        
    def forward(self, x):
        # B,T,C = x.shape
        print("MATRIX:",x)
        if(x > self.block_size):
            return self.init_matrix(x)
            
        else:
            return self.init_matrix(self.block_size)


In [108]:
class RotaryAttentionHead(nn.Module):
    def __init__(
        self,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        no_of_heads: int = ModelArgs.no_of_heads,
        attn_dropout: int = ModelArgs.attn_dropout
    ):
        super().__init__()
        self.head_size = embeddings_dims // no_of_heads
        self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=ModelArgs.device, bias=False)
        self.key = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=ModelArgs.device, bias=False)
        self.value = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=ModelArgs.device, bias=False)
        self.rotary_matrix = RotaryEmbeddings(embeddings_dims=embeddings_dims)
        self.dropout = nn.Dropout(p = ModelArgs.attn_dropout)
        
    def forward(self,x):
        
        batch, block_size, embeddings_dims = x.shape
        query = self.query(x)
        key = self.key(x)
        values = self.value(x)
        matrix = self.rotary_matrix(block_size)
        
        masked = torch.tril(torch.ones((block_size, block_size), device=ModelArgs.device, requires_grad=False))
        rotary_query = matrix @ torch.transpose(query, dim0=1, dim1=0) # (B,T,C ) @ (B,T,C,C) -> (B,C,T)
        rotary_key = matrix @ torch.transpose(key, dim0=1, dim1=0) # (B,T,C ) @ (B,T,C,C) -> (B,C,T)
        weights = rotary_query @ (torch.transpose(rotary_key, dim0=-2, dim1=-1))
        weights_masked = weights.masked_fill(masked == 0, float('-inf'))
        scaled_weights = weights_masked / (torch.sqrt(self.head_size))
        scaled_weights = F.softmax(scaled_weights, dim=-1)
        value = scaled_weights @ values
        out = self.dropout(value)
        return out

In [109]:
class MQA(nn.Module):
    def __init__(
        self,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        block_size: int = ModelArgs.block_size,
        no_of_kv_heads: int = ModelArgs.no_of_heads,
        no_of_heads: int = ModelArgs.no_of_heads
    ):
        super().__init__()
        
        self.no_of_kv_heads = no_of_kv_heads
        self.no_of_q_heads = no_of_heads // no_of_kv_heads
        self.head_size = embeddings_dims // self.no_of_q_heads
        self.rotary_matrix = RotaryEmbeddings(embeddings_dims=embeddings_dims)
        # self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=ModelArgs.device, bias=False)
        self.key = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=ModelArgs.device, bias=False)
        self.value = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=ModelArgs.device, bias=False)
        self.dropout = nn.Dropout(p = ModelArgs.attn_dropout)
        self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, device=ModelArgs.device, bias=False)
        
        
        
    def scaled_dot_product(self, q, k, v, block_size):
            
            masked = torch.tril(torch.ones((block_size, block_size), device=ModelArgs.device, requires_grad=False))
            matrix = self.rotary_matrix(block_size)
            
            rotary_query = matrix @ torch.transpose(q, dim0=1, dim1=0)
            rotary_key = matrix @ torch.transpose(k, dim0=1, dim1=0)
            weights = rotary_query @ (torch.transpose(rotary_key, dim0=-2, dim1=-1))
            weights_masked = weights.masked_fill(masked == 0, float('-inf'))
            scaled_weights = weights_masked / (k.shape[-1] ** -0.5)
            scaled_weights = F.softmax(scaled_weights, dim=-1)
            value = scaled_weights @ v
            return value
    
    def forward(self,x):
        print("MQA: ", x.shape)
        batch, block_size, embeddings_dims = x.shape
        multi_query = nn.ModuleList([nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=ModelArgs.device, bias=False) for _ in range(self.no_of_q_heads)])
        # query = self.query(x)
            
        key = self.key(x)
        values = self.value(x)
        # matrix = self.rotary_matrix(block_size)
        # self.mqa = nn.ModuleList([
           
        # ])
        multi_query_concat = torch.cat([self.scaled_dot_product(query(x), key, values, block_size=block_size) for query in multi_query], dim=-1)
        # linear_layer_query = self.linear_layer(multi_query_concat)
        # masked = torch.tril(torch.ones((block_size, block_size), device=ModelArgs.device, requires_grad=False))
        # rotary_query = matrix @ torch.transpose(query, dim0=1, dim1=0) # (B,T,C ) @ (B,T,C,C) -> (B,C,T)
        # rotary_key = matrix @ torch.transpose(key, dim0=1, dim1=0) # (B,T,C ) @ (B,T,C,C) -> (B,C,T)
        # print(multi_query_concat.shape)
        # print(key.shape)
        # print(linear_layer_query.shape)
        # weights = linear_layer_query @ (torch.transpose(key, dim0=-2, dim1=-1))
        # weights_masked = weights.masked_fill(masked == 0, float('-inf'))
        # scaled_weights = weights_masked / (key.shape[-1] ** -0.5)
        # scaled_weights = F.softmax(scaled_weights, dim=-1)
        # value = scaled_weights @ values
        
        linear_layer= self.linear_layer(multi_query_concat)
        out = self.dropout(linear_layer)
        return out

In [110]:
class GQA(nn.Module):
    def __init__(
        self,
        embeddings_dims: int = ModelArgs.embeddings_dims,
        block_size: int = ModelArgs.block_size,
        no_of_q_heads: int = ModelArgs.no_of_heads,
        no_of_kv_heads: int = ModelArgs.no_kv_heads
    ):
        super().__init__()
        
        self.head_size = embeddings_dims // no_of_q_heads
        # self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=ModelArgs.device, bias=False)
        self.no_of_kv_heads = no_of_kv_heads
        self.no_of_q_heads = no_of_q_heads
        self.key = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=ModelArgs.device, bias=False)
        self.value = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=ModelArgs.device, bias=False)
        self.dropout = nn.Dropout(p = ModelArgs.attn_dropout)
        self.linear_layer = nn.Linear(in_features=embeddings_dims * self.no_of_kv_heads, out_features=self.head_size, device=ModelArgs.device, bias=False)
        
    # def scaled_dot_product(self, q, k, v, block_size):
            
    #         masked = torch.tril(torch.ones((block_size, block_size), device=ModelArgs.device, requires_grad=False))
    #         weights = q @ (torch.transpose(k, dim0=-2, dim1=-1))
    #         weights_masked = weights.masked_fill(masked == 0, float('-inf'))
    #         scaled_weights = weights_masked / (k.shape[-1] ** -0.5)
    #         scaled_weights = F.softmax(scaled_weights, dim=-1)
    #         value = scaled_weights @ v
    #         return value
        
        
    def forward(self,x):
        
        batch, block_size, embeddings_dims = x.shape
        mqa = nn.ModuleList([MQA(embeddings_dims=embeddings_dims, block_size=block_size) for _ in range(self.no_of_kv_heads)])
        # query = self.query(x)
        # key = self.key(x)
        # values = self.value(x)
        # matrix = self.rotary_matrix(block_size)
        grouped_query_concat = torch.cat([group(x) for group in mqa], dim=-1)
        # linear_layer_query = self.linear_layer(multi_query_concat)
        # masked = torch.tril(torch.ones((block_size, block_size), device=ModelArgs.device, requires_grad=False))
        # rotary_query = matrix @ torch.transpose(query, dim0=1, dim1=0) # (B,T,C ) @ (B,T,C,C) -> (B,C,T)
        # rotary_key = matrix @ torch.transpose(key, dim0=1, dim1=0) # (B,T,C ) @ (B,T,C,C) -> (B,C,T)
        # print(multi_query_concat.shape)
        # print(key.shape)
        # print(linear_layer_query.shape)
        print(grouped_query_concat.shape)
        linear_layer= self.linear_layer(grouped_query_concat)
        out = self.dropout(linear_layer)
        return out

In [111]:

random_data = torch.randn((ModelArgs.batch_size, ModelArgs.block_size, ModelArgs.embeddings_dims))
mqa = GQA()
# input_data = torch.tensor()
res = mqa(random_data)
res.shape

MATRXO:  tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000, -1.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000, -1.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.5403,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.8415,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000, -0.5403,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.8415]],

        [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,

TypeError: unsupported operand type(s) for @: 'NoneType' and 'Tensor'

In [18]:
masked = torch.tril(torch.ones((ModelArgs.block_size, ModelArgs.block_size), device=ModelArgs.device, requires_grad=False))
masked.shape

torch.Size([512, 512])

In [27]:
class KVCache:
    def __init__(
        self,
        embeddings_dims: int =  ModelArgs.embeddings_dims,
        block_size: int  = ModelArgs.block_size,
        no_of_decoder_layers: int =ModelArgs.no_of_decoder_layers
    ):
        super().__init__()
        self.head_size = embeddings_dims / no_of_decoder_layers
        self.k_cache = torch.ones((block_size, embeddings_dims, self.head_size), device=ModelArgs.device, requires_grad=False)
        self.v_cache = torch.ones((block_size, embeddings_dims, self.head_size), device=ModelArgs.device, requires_grad=False)
        self.block_size = block_size,
        self.embeddings_dims = embeddings_dims
    def update(
        self,
        k: torch.tensor,
        v: torch.tensor
    ):
        self.k_cache[:self.block_size, :self.block_size] = k
        self.v_cache = v
        
    def get(self):
        

SyntaxError: incomplete input (1877103268.py, line 23)