# In [None]:
import torch
import torch.nn as nn
import math


class FFN(nn.Module):
    """Position-wise feed-forward network: linear -> ReLU -> dropout -> linear.

    The input is projected up to ``hid_dim`` (a wider hidden space gives the
    layer more features to learn) and then back down to ``embed_dim``. The
    weights are raw ``nn.Parameter`` matrices stored in ``(out, in)`` layout.

    Args:
        embed_dim: Size of the last dimension of the input and of the output.
        hid_dim: Size of the hidden projection.
        dropout: Dropout probability applied to the hidden activations.
    """

    def __init__(self, embed_dim, hid_dim, dropout):
        super().__init__()
        # Up-projection, stored as (hid, embed).
        self.linear1 = nn.Parameter(torch.randn(hid_dim, embed_dim))
        self.l1bias = nn.Parameter(torch.randn(hid_dim))

        self.dropout = nn.Dropout(dropout)

        # Down-projection, stored as (embed, hid).
        self.linear2 = nn.Parameter(torch.randn(embed_dim, hid_dim))
        self.l2bias = nn.Parameter(torch.randn(embed_dim))

        # Xavier-uniform weights first, then zeroed biases.
        for weight in (self.linear1, self.linear2):
            nn.init.xavier_uniform_(weight)
        for bias in (self.l1bias, self.l2bias):
            nn.init.zeros_(bias)

        self.act = nn.ReLU()

    def forward(self, x):
        """Apply the network to ``x`` and return a tensor of the same shape.

        ``x`` is multiplied on its last axis, so that axis must have size
        ``embed_dim``; the original author used ``(batch, seq, embed)``.
        """
        # (..., embed) @ (embed, hid) -> (..., hid)
        up = x @ self.linear1.T + self.l1bias
        hidden = self.dropout(self.act(up))
        # (..., hid) @ (hid, embed) -> (..., embed)
        return hidden @ self.linear2.T + self.l2bias

class LayerNormalization(nn.Module):
    """Layer normalisation over the last axis with a learnable scale and shift.

    Computes ``alpha * (x - mean) / (std + eps) + bias`` where ``mean`` and
    ``std`` are taken over the last dimension. ``std`` comes from
    ``Tensor.std`` with its default settings, and ``eps`` is added to the
    standard deviation (not to the variance).

    Args:
        features: Size of the normalised (last) dimension.
        eps: Small constant added to the standard deviation.
    """

    def __init__(self, features, eps=10 ** -6):
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(features))   # scale
        self.bias = nn.Parameter(torch.zeros(features))   # shift

    def forward(self, x):
        """Return ``x`` normalised along its last dimension."""
        mu = x.mean(dim=-1, keepdim=True)
        sigma = x.std(dim=-1, keepdim=True)
        centred = x - mu
        return self.alpha * centred / (sigma + self.eps) + self.bias
    
class Residual(nn.Module):
    """Pre-norm residual connection: ``x + dropout(sublayer(norm(x)))``.

    Args:
        features: Size of the last dimension, passed to ``LayerNormalization``.
        dropout: Dropout probability applied to the sublayer output.
    """

    def __init__(self, features, dropout):
        super().__init__()
        self.norm = LayerNormalization(features)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Add the skip connection around ``sublayer``.

        Args:
            x: Input tensor; also the value carried by the skip path.
            sublayer: Callable that receives the normalised input and returns
                a tensor that can be added to ``x``.
        """
        normed = self.norm(x)
        branch = self.dropout(sublayer(normed))
        return x + branch


class Embedding(nn.Module):
    """Token embedding lookup scaled by ``sqrt(embed_dim)``.

    Args:
        embed_dim: Width of each embedding vector.
        vocab_size: Number of rows in the embedding table.
    """

    def __init__(self, embed_dim, vocab_size):
        super().__init__()
        self.embed_dim = embed_dim
        self.vocab_size = vocab_size
        # Table has vocab_size rows and embed_dim columns.
        self.embed = nn.Embedding(vocab_size, embed_dim)

    def forward(self, x):
        """Look up the indices in ``x`` and multiply by ``sqrt(embed_dim)``."""
        scale = math.sqrt(self.embed_dim)
        return self.embed(x) * scale

class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding followed by dropout.

    For position ``pos`` and frequency index ``i`` the table holds::

        pe[pos, 2i]     = sin(pos / 10000 ** (2i / embed_dim))
        pe[pos, 2i + 1] = cos(pos / 10000 ** (2i / embed_dim))

    The table is registered as the buffer ``pe`` with shape
    ``(1, seq_len, embed_dim)``.

    Args:
        embed_dim: Width of the encoding; both even and odd values work.
        seq_len: Number of positions precomputed in the table.
        dropout: Dropout probability applied after adding the encoding.
    """

    def __init__(self, embed_dim, seq_len, dropout: float):
        super().__init__()
        self.embed_dim = embed_dim
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)  # (seq_len, 1)
        # term[i] = 10000 ** (-2i / embed_dim), computed through exp/log.
        term = torch.exp(torch.arange(0, embed_dim, 2).float() * -math.log(10000.0) / embed_dim)
        angles = position * term  # (seq_len, ceil(embed_dim / 2))

        pe = torch.zeros(seq_len, embed_dim)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd embed_dim there is one fewer odd column than even
        # column, so the last frequency has no cosine slot; drop it instead
        # of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : embed_dim // 2])

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, z):
        """Add the encoding for the first ``z.shape[1]`` positions to ``z``.

        Args:
            z: Tensor of shape ``(batch, seq_len, embed_dim)``.
        """
        # `pe` is a buffer, not a parameter, so it never requires grad and
        # needs no explicit requires_grad_(False).
        return self.dropout(z + self.pe[:, : z.shape[1], :])

def RoPE(x, sin, cos, start_pos=0):
    """Rotate adjacent feature pairs of ``x`` using per-head sin/cos tables.

    Note: this file defines another function named ``RoPE`` further down; when
    the file is executed top to bottom that later definition replaces this one.

    Args:
        x: Tensor whose axis 1 is read as the head axis, axis -2 as the
            sequence axis and axis -1 as the per-head feature axis.
        sin: 2-D table indexed ``[position, feature]`` whose feature width is
            split into ``width // head_dim`` heads.
        cos: Table with the same layout as ``sin``.
        start_pos: Index of the first table row to use.

    Returns:
        A new tensor shaped like ``x``. For each pair ``(even, odd)``:
        ``out_even = even * cos - odd * sin`` and
        ``out_odd = even * sin + odd * cos``.
    """
    seq = x.size(-2)
    head_dim = x.size(-1)
    heads_in_x = x.size(1)
    heads_in_table = sin.size(-1) // head_dim

    def _per_head(table):
        # Take rows [start_pos, start_pos + seq) and split the feature axis
        # into heads: (seq, heads_in_table, head_dim).
        window = table[start_pos:start_pos + seq].view(seq, heads_in_table, head_dim)
        # If x has a different head count, use only the first heads_in_x.
        if heads_in_table != heads_in_x:
            window = window[:, :heads_in_x, :]
        # -> (1, heads, seq, head_dim)
        return window.permute(1, 0, 2).unsqueeze(0)

    sin_t = _per_head(sin)
    cos_t = _per_head(cos)

    evens, odds = x[..., 0::2], x[..., 1::2]

    rotated = torch.empty_like(x)
    rotated[..., 0::2] = evens * cos_t[..., 0::2] - odds * sin_t[..., 0::2]
    rotated[..., 1::2] = evens * sin_t[..., 1::2] + odds * cos_t[..., 1::2]
    return rotated

def RoPE_Embed(dim, seq):
    """Build sin/cos tables with the frequency block repeated side by side.

    Frequencies are ``1 / 10000 ** (2i / dim)`` for ``2i`` in
    ``range(0, dim, 2)``. The angle matrix ``position * frequency`` is
    concatenated with itself along the last axis, so the columns are laid out
    as ``[f0, f1, ..., f0, f1, ...]``.

    Args:
        dim: Step-2 range bound for the frequency exponents.
        seq: Number of positions (rows).

    Returns:
        ``(sin, cos)``, both taken element-wise from the same angle matrix.
    """
    exponent = torch.arange(0, dim, 2).float() / dim
    inv_freq = 1.0 / (10000 ** exponent)
    positions = torch.arange(seq).float()

    half = torch.outer(positions, inv_freq)
    angles = torch.cat((half, half), dim=-1)

    return angles.sin(), angles.cos()

import torch.nn as nn 
import torch , math

def Rope_embed(dim, seq):
    """Build sin/cos tables with each frequency duplicated in adjacent columns.

    Frequencies are ``1 / 10000 ** (2i / dim)`` for ``2i`` in
    ``range(0, dim, 2)``. Every column of the angle matrix
    ``position * frequency`` appears twice in a row, giving the interleaved
    layout ``[f0, f0, f1, f1, ...]``, so columns ``2i`` and ``2i + 1`` hold the
    same angle.

    Args:
        dim: Step-2 range bound for the frequency exponents.
        seq: Number of positions (rows).

    Returns:
        ``(sin, cos)``, both taken element-wise from the same angle matrix.
    """
    exponent = torch.arange(0, dim, 2).float() / dim
    inv_freq = 1.0 / (10000 ** exponent)
    positions = torch.arange(seq).float()

    # Duplicate every column in place: (seq, n) -> (seq, 2n), interleaved.
    angles = torch.outer(positions, inv_freq).repeat_interleave(2, dim=-1)

    return angles.sin(), angles.cos()


def RoPE(emb, sin, cos, start_pos=0):
    """Apply a rotary position rotation to adjacent feature pairs of ``emb``.

    Args:
        emb: Tensor of shape ``(batch, heads, seq, head_dim)``.
        sin: 2-D table indexed ``[position, feature]``; only the first
            ``head_dim`` columns are used.
        cos: Table with the same layout as ``sin``.
        start_pos: Index of the first table row (position of the first token).

    Returns:
        A new tensor shaped like ``emb``. For each pair ``(x1, x2)`` taken from
        the even/odd columns: ``out_even = x1 * cos - x2 * sin`` and
        ``out_odd = x1 * sin + x2 * cos``.
    """
    seq_len, head_dim = emb.size(-2), emb.size(-1)
    stop = start_pos + seq_len

    # Slice the needed positions/features and add two leading axes of size 1
    # so the tables broadcast against the first two axes of `emb`.
    sin_t = sin[start_pos:stop, :head_dim][None, None]
    cos_t = cos[start_pos:stop, :head_dim][None, None]

    evens, odds = emb[..., 0::2], emb[..., 1::2]

    rotated = torch.empty_like(emb)
    rotated[..., 0::2] = evens * cos_t[..., 0::2] - odds * sin_t[..., 0::2]
    rotated[..., 1::2] = evens * sin_t[..., 1::2] + odds * cos_t[..., 1::2]
    return rotated

def RoPE_backward(grad_output, sin, cos, start_pos=0):
    """Map a gradient on rotated features back to the un-rotated features.

    For each feature pair the result is
    ``g_even = gy1 * cos + gy2 * sin`` and ``g_odd = -gy1 * sin + gy2 * cos``,
    where ``gy1``/``gy2`` are the even/odd columns of ``grad_output``.

    Args:
        grad_output: Gradient tensor; axis -2 is read as the sequence axis and
            axis -1 as the per-head feature axis.
        sin: 2-D table indexed ``[position, feature]``; only the first
            ``head_dim`` columns are used.
        cos: Table with the same layout as ``sin``.
        start_pos: Index of the first table row to use.

    Returns:
        A new tensor shaped like ``grad_output``.
    """
    seq_len, head_dim = grad_output.size(-2), grad_output.size(-1)
    stop = start_pos + seq_len

    # Two leading axes of size 1 let the tables broadcast over the first two
    # axes of `grad_output`.
    sin_t = sin[start_pos:stop, :head_dim][None, None]
    cos_t = cos[start_pos:stop, :head_dim][None, None]

    g_even = grad_output[..., 0::2]
    g_odd = grad_output[..., 1::2]

    grad_input = torch.empty_like(grad_output)
    grad_input[..., 0::2] = g_even * cos_t[..., 0::2] + g_odd * sin_t[..., 1::2]
    grad_input[..., 1::2] = -g_even * sin_t[..., 0::2] + g_odd * cos_t[..., 1::2]
    return grad_input

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with rotary position encoding.

    The projections are raw ``nn.Parameter`` matrices applied as ``x @ W + b``
    (the output projection ``wo`` has no bias). ``forward`` stashes its
    intermediate tensors on ``self`` so that the hand-written ``backward`` can
    compute gradients without autograd.

    Args:
        embed_dim: Model width; must be divisible by ``numhead``.
        numhead: Number of attention heads.
        start_pos: Position offset used when rotating the queries (keys are
            always rotated starting from position 0).
    """

    def __init__(self, embed_dim, numhead, start_pos=0):
        super().__init__()
        self.start = start_pos
        self.embed_dim = embed_dim
        self.num_head = numhead

        assert (embed_dim % numhead == 0), "try a different number"

        self.d_k = embed_dim // numhead

        self.wq = nn.Parameter(torch.randn(embed_dim, embed_dim))
        self.wk = nn.Parameter(torch.randn(embed_dim, embed_dim))
        self.wv = nn.Parameter(torch.randn(embed_dim, embed_dim))

        self.wo = nn.Parameter(torch.randn(embed_dim, embed_dim))

        self.qb = nn.Parameter(torch.randn(embed_dim))
        self.kb = nn.Parameter(torch.randn(embed_dim))
        self.vb = nn.Parameter(torch.randn(embed_dim))

        # `wo` is initialised together with the other projections; left as
        # N(0, 1) it would scale the output by roughly sqrt(embed_dim).
        for weight in [self.wq, self.wk, self.wv, self.wo]:
            nn.init.kaiming_uniform_(weight)

        for bias in [self.qb, self.kb, self.vb]:
            nn.init.zeros_(bias)

        # The RoPE tables are a cache that can always be recomputed, so they
        # are kept out of the state_dict (non-persistent); otherwise the set
        # of state_dict keys would change after the first forward pass.
        self.register_buffer('sin', None, persistent=False)
        self.register_buffer('cos', None, persistent=False)

        # Mask applied by the most recent forward pass (None if unmasked).
        self.combined_mask = None

    def precompute_emb(self, seq):
        """Ensure the cached RoPE tables cover at least ``seq`` positions."""
        if self.sin is None or self.sin.size(0) < seq:
            sin, cos = Rope_embed(dim=self.embed_dim, seq=seq)
            # Build the tables on the same device as the weights so the
            # module keeps working after `.to(device)`.
            device = self.wq.device
            self.register_buffer('sin', sin.to(device), persistent=False)
            self.register_buffer('cos', cos.to(device), persistent=False)

    def forward(self, q, k, v, mask=None, pastlayer=None):
        """Compute attention and cache the intermediates for ``backward``.

        Args:
            q: Query input of shape ``(batch, q_len, embed)``.
            k: Key input of shape ``(batch, k_len, embed)``.
            v: Value input of shape ``(batch, k_len, embed)``. ``k_len`` may
                differ from ``q_len`` (cross-attention).
            mask: Optional tensor broadcastable to
                ``(batch, heads, q_len, k_len)``; positions equal to 0 are
                blocked. A 3-D mask gets a head axis inserted at dim 1.
            pastlayer: Optional tuple from a previous step. When given, a
                causal mask is combined with ``mask``.
                NOTE(review): the past keys/values are not concatenated to
                the current ones; this argument only affects the table size
                and the mask branch.

        Returns:
            ``(output, present)`` where ``output`` has shape
            ``(batch, q_len, embed)`` and ``present`` is the tuple of
            projected ``(k, v)`` before the head split.
        """
        self.batch = q.size(0)
        self.seq = q.size(1)

        # Raw inputs are needed for the weight gradients in backward().
        self.query = q
        self.key = k
        self.value = v

        q = q @ self.wq + self.qb
        k = k @ self.wk + self.kb
        v = v @ self.wv + self.vb

        present = (k, v)

        # Split the embedding axis into heads. Keys and values use their own
        # sequence length (-1), which may differ from the query length.
        q_split = q.view(self.batch, self.seq, self.num_head, self.d_k)
        k_split = k.view(self.batch, -1, self.num_head, self.d_k)
        v_split = v.view(self.batch, -1, self.num_head, self.d_k)

        # Bring the head axis forward: (batch, num_heads, seq, d_k).
        self.query_reshape = q_split.transpose(1, 2)
        self.key_reshape = k_split.transpose(1, 2)
        self.value_reshaped = v_split.transpose(1, 2)

        total_seq = max(q.size(1) + self.start, k.size(1), (pastlayer[0].size(2) if pastlayer else 0))
        self.precompute_emb(total_seq)
        self.query_reshaped = RoPE(self.query_reshape, self.sin, self.cos, self.start)
        self.key_reshaped = RoPE(self.key_reshape, self.sin, self.cos, 0)

        self.score_bf_sf = torch.matmul(self.query_reshaped, self.key_reshaped.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Reset first so an unmasked call never leaves a stale mask behind
        # for backward() to pick up.
        self.combined_mask = None
        if mask is not None:
            current_k = self.key_reshaped.shape[-2]
            current_q = self.query_reshaped.shape[-2]
            if pastlayer is not None:
                casual_mask = torch.tril(torch.ones(current_k, current_k, device=self.score_bf_sf.device))
                # Keep the LAST current_q rows (one per query position):
                # (1, 1, current_q, current_k).
                casual_mask_shaped = casual_mask[-current_q:, :].unsqueeze(0).unsqueeze(0)

                # Use the padding mask only if it covers every key column;
                # otherwise fall back to the causal mask alone.
                if mask.size(-1) >= current_k:
                    combinedmask = casual_mask_shaped * mask[..., -current_k:]
                else:
                    combinedmask = casual_mask_shaped
            else:
                # (batch, q_len, k_len) -> (batch, 1, q_len, k_len) so the
                # mask broadcasts over heads rather than over the batch.
                if mask.dim() == 3:
                    mask = mask.unsqueeze(1)
                # Keep only the last current_k key columns.
                if mask.size(-1) != current_k:
                    mask = mask[..., -current_k:]

                combinedmask = mask
            self.combined_mask = combinedmask
            self.score_bf_sf = self.score_bf_sf.masked_fill_(combinedmask == 0, -1e9)

        # Numerically stable softmax over the key axis.
        max_scores = torch.max(self.score_bf_sf, dim=-1, keepdim=True)[0]
        exp_scores = torch.exp(self.score_bf_sf - max_scores)
        sum_scores = torch.sum(exp_scores, dim=-1, keepdim=True)

        self.scores_sf = exp_scores / sum_scores

        # (batch, num_head, q_len, d_k)
        out = self.scores_sf @ self.value_reshaped

        # Merge the heads back: (batch, q_len, embed).
        self.out = out.permute(0, 2, 1, 3).contiguous().view(self.batch, self.seq, -1)

        self.context_merged = self.out
        return self.context_merged @ self.wo, present

    def backward(self, dout):
        """Manual backward pass for the most recent ``forward`` call.

        Stores the parameter gradients on ``self`` (``dwo``, ``dwq``, ``dwk``,
        ``dwv``, ``dqb``, ``dkb``, ``dvb``) and the per-input gradients as
        ``dx_q``, ``dx_k`` and ``dx_v``.

        Args:
            dout: Gradient of the loss w.r.t. the attention output, shape
                ``(batch, q_len, embed)``.

        Returns:
            ``dx_q + dx_k + dx_v`` when the three input gradients share one
            shape (the correct input gradient for self-attention, where q, k
            and v are the same tensor). When the key/value length differs
            from the query length the gradients cannot be summed and the
            tuple ``(dx_q, dx_k, dx_v)`` is returned instead.
        """
        batch, seq, embed = dout.shape

        # out = context @ wo, summed over batch and sequence -> (embed, embed).
        self.dwo = torch.einsum("bsi,bsj->ij", self.context_merged, dout)

        # Gradient w.r.t. the merged context: (batch, q_len, embed).
        d_context_merged = dout @ self.wo.transpose(-2, -1)

        # Back to the per-head layout: (batch, num_head, q_len, d_k).
        d_context = d_context_merged.view(batch, seq, self.num_head, self.d_k).permute(0, 2, 1, 3)

        # context = scores @ value
        dv = self.scores_sf.transpose(-2, -1) @ d_context            # (batch, num_head, k_len, d_k)
        d_score = torch.matmul(d_context, self.value_reshaped.transpose(-2, -1))  # (batch, num_head, q_len, k_len)

        combined_mask = getattr(self, 'combined_mask', None)
        if combined_mask is not None:
            d_score = d_score.masked_fill(combined_mask == 0, 0.0)

        # Softmax backward. The Jacobian is diag(S) - S S^T, so for each row
        # dZ_j = S_j * (dS_j - sum_i S_i * dS_i). The extra 1 / sqrt(d_k)
        # undoes the scaling applied to the raw scores in forward().
        sum_dp = torch.sum(d_score * self.scores_sf, dim=-1, keepdim=True)
        before_sf_score = (self.scores_sf * (d_score - sum_dp)) / math.sqrt(self.d_k)

        # scores = q_rot @ k_rot^T
        dq = torch.matmul(before_sf_score, self.key_reshaped)
        dk = torch.matmul(before_sf_score.transpose(-2, -1), self.query_reshaped)

        # Undo the rotary rotation with the same offsets used in forward().
        dq_unrotated = RoPE_backward(dq, self.sin, self.cos, self.start)
        dk_unrotated = RoPE_backward(dk, self.sin, self.cos, 0)

        # Merge heads. Keys/values keep their own sequence length (-1).
        q_reshaped = dq_unrotated.transpose(2, 1).contiguous().view(batch, seq, -1)
        k_reshaped = dk_unrotated.transpose(2, 1).contiguous().view(batch, -1, embed)
        v_reshaped = dv.transpose(2, 1).contiguous().view(batch, -1, embed)

        # proj = x @ W + b  =>  dW = x^T @ dproj (flattened over batch and
        # sequence) and db = dproj summed over batch and sequence.
        real_q = self.query.reshape(-1, embed)
        grad_q = q_reshaped.reshape(-1, embed)
        self.dwq = (real_q.T @ grad_q)
        self.dqb = q_reshaped.sum(dim=(0, 1))

        real_k = self.key.reshape(-1, embed)
        grad_k = k_reshaped.reshape(-1, embed)
        self.dwk = (real_k.T @ grad_k)
        self.dkb = k_reshaped.sum(dim=(0, 1))

        real_v = self.value.reshape(-1, embed)
        grad_v = v_reshaped.reshape(-1, embed)
        self.dwv = (real_v.T @ grad_v)
        self.dvb = v_reshaped.sum(dim=(0, 1))

        # Gradients w.r.t. the three raw inputs.
        self.dx_q = q_reshaped @ self.wq.T
        self.dx_k = k_reshaped @ self.wk.T
        self.dx_v = v_reshaped @ self.wv.T

        if self.dx_k.shape == self.dx_q.shape and self.dx_v.shape == self.dx_q.shape:
            return self.dx_q + self.dx_k + self.dx_v
        return self.dx_q, self.dx_k, self.dx_v


class EncoderBlock(nn.Module):
    """One encoder layer: self-attention then feed-forward, each in a residual.

    Args:
        attention: Attention module called as ``attention(q, k, v, mask)`` and
            returning ``(output, present)``.
        ffn: Feed-forward module called with a single tensor.
        feat: Feature size passed to both ``Residual`` wrappers.
        dropout: Dropout probability passed to both ``Residual`` wrappers.
    """

    def __init__(self, attention, ffn, feat, dropout):
        super().__init__()
        self.attention = attention
        self.ffn = ffn
        self.residual1 = Residual(feat, dropout)
        self.residual2 = Residual(feat, dropout)

    def forward(self, x, mask):
        """Run the layer on ``x`` using ``mask`` for the self-attention."""

        def self_attend(normed):
            # Attention must consume the normalised tensor that Residual
            # passes in. Computing it beforehand on the raw input would throw
            # away residual1's layer norm and leave its parameters unused.
            attended, _ = self.attention(normed, normed, normed, mask)
            return attended

        x = self.residual1(x, self_attend)
        x = self.residual2(x, self.ffn)
        return x

class Encoder(nn.Module):
    """Stack of encoder layers followed by a final layer normalisation.

    Args:
        features: Feature size for the final ``LayerNormalization``.
        layers: Iterable of layers, each called as ``layer(x, mask)``.
    """

    def __init__(self, features, layers):
        super().__init__()
        self.norm = LayerNormalization(features)
        self.layer = layers

    def forward(self, x, mask):
        """Feed ``x`` through every layer in order, then normalise."""
        hidden = x
        for block in self.layer:
            hidden = block(hidden, mask)
        return self.norm(hidden)
    


class DecoderBlock(nn.Module):
    """One decoder layer: self-attention, cross-attention, feed-forward.

    Each sub-layer is wrapped in its own ``Residual``.

    Args:
        selfAttention: Attention module used on the decoder stream.
        crossAttention: Attention module that attends to the encoder output.
        ffn: Feed-forward module called with a single tensor.
        feat: Feature size passed to the ``Residual`` wrappers.
        dropout: Dropout probability passed to the ``Residual`` wrappers.
    """

    def __init__(self, selfAttention, crossAttention, ffn, feat, dropout):
        super().__init__()
        self.attention = selfAttention
        self.ffn = ffn
        self.cross = crossAttention
        self.residual1 = Residual(feat, dropout)
        self.residual2 = Residual(feat, dropout)
        self.residual3 = Residual(feat, dropout)

    def forward(self, x, enoderOutput, selfmask, crossmask, pastlayergroup):
        """Run the layer.

        Args:
            x: Decoder input tensor.
            enoderOutput: Encoder output used as key and value for the
                cross-attention.
            selfmask: Mask for the self-attention.
            crossmask: Mask for the cross-attention.
            pastlayergroup: ``None`` or a 2-tuple whose first item is passed
                to the self-attention as ``pastlayer``.

        Returns:
            ``(x, (selfpresent, None))`` where ``selfpresent`` is the second
            value returned by the self-attention module.
        """
        selfpast, _ = pastlayergroup if pastlayergroup is not None else (None, None)

        selfpresent = None

        def self_attend(normed):
            # Runs inside residual1 so it sees the normalised input.
            nonlocal selfpresent
            attended, selfpresent = self.attention(normed, normed, normed, selfmask, pastlayer=selfpast)
            return attended

        def cross_attend(normed):
            # Runs inside residual2, i.e. AFTER the self-attention residual,
            # so the queries come from the self-attended decoder state and
            # not from the block's original input.
            attended, _ = self.cross(normed, enoderOutput, enoderOutput, crossmask, pastlayer=None)
            return attended

        x = self.residual1(x, self_attend)
        x = self.residual2(x, cross_attend)
        x = self.residual3(x, self.ffn)
        return x, (selfpresent, None)
    
class Decoder(nn.Module):
    """Stack of decoder layers followed by a final layer normalisation.

    Args:
        features: Feature size for the final ``LayerNormalization``.
        layers: Iterable of layers, each called as
            ``layer(x, enc_out, tgt_mask, src_mask, pastlayergroup=...)`` and
            returning ``(x, layer_group)``.
    """

    def __init__(self, features, layers):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, enc_out, tgt_mask, src_mask, pastvalues):
        """Run every layer in order.

        ``pastvalues`` is indexed per layer when it is truthy; otherwise each
        layer receives ``None`` as its past.

        Returns:
            ``(normalised output, list with one layer group per layer)``.
        """
        collected = []
        for idx, block in enumerate(self.layers):
            layer_past = None
            if pastvalues:
                layer_past = pastvalues[idx]
            x, group = block(x, enc_out, tgt_mask, src_mask, pastlayergroup=layer_past)
            collected.append(group)

        return self.norm(x), collected

class Projection(nn.Module):
    """Linear map from the model width to vocabulary-sized scores.

    The weight is a raw ``nn.Parameter`` stored as ``(vocab_size, embed_dim)``
    with Xavier-uniform initialisation; the bias starts at zero.

    Args:
        embed_dim: Size of the input's last dimension.
        vocab_size: Size of the output's last dimension.
    """

    def __init__(self, embed_dim, vocab_size):
        super().__init__()
        self.project = nn.Parameter(torch.randn(vocab_size, embed_dim))
        self.bias = nn.Parameter(torch.randn(vocab_size))
        nn.init.xavier_uniform_(self.project)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        """Return ``x @ project.T + bias``; ``x`` is ``(batch, seq, embed)``."""
        scores = x @ self.project.T
        return scores + self.bias

class Transformer(nn.Module):
    """Container that wires the embeddings, encoder, decoder and projection.

    Args:
        encoder: Module called as ``encoder(src, src_mask)``.
        decoder: Module called as
            ``decoder(tgt, encoder_out, tgt_mask, src_mask, pastvalues)``.
        src_emb: Source token embedding.
        tgt_emb: Target token embedding.
        src_pos: Positional module applied after ``src_emb``.
        tgt_pos: Positional module applied after ``tgt_emb``.
        proj_layer: Final projection applied by ``projection``.
    """

    def __init__(self, encoder: Encoder, decoder: Decoder, src_emb, tgt_emb, src_pos, tgt_pos, proj_layer):
        super().__init__()
        self.encoder_ = encoder
        self.decoder_ = decoder
        self.src_emb = src_emb
        self.src_pos = src_pos
        self.tgt_emb = tgt_emb
        self.tgt_pos = tgt_pos
        self.proj_layer = proj_layer

    def encode(self, src, src_mask):
        """Embed ``src``, add positions and run the encoder."""
        embedded = self.src_pos(self.src_emb(src))
        return self.encoder_(embedded, src_mask)

    def decode(self, tgt, encoder_out, tgt_mask, src_mask, pastvalues=None):
        """Embed ``tgt``, add positions and run the decoder."""
        embedded = self.tgt_pos(self.tgt_emb(tgt))
        return self.decoder_(embedded, encoder_out, tgt_mask, src_mask, pastvalues)

    def projection(self, x):
        """Apply the final projection layer to ``x``."""
        return self.proj_layer(x)

def build(src_vocab , tgt_vocab , src_seq , tgt_seq , embed_dim = 512 , num_head = 8 , num_layer = 6 , dropout =  float(0.1) , hid_dim = 2048):
    """Assemble a full encoder-decoder ``Transformer``.

    Args:
        src_vocab: Source vocabulary size.
        tgt_vocab: Target vocabulary size (also the projection output size).
        src_seq: Number of positions precomputed for the source side.
        tgt_seq: Number of positions precomputed for the target side.
        embed_dim: Model width.
        num_head: Number of attention heads per attention module.
        num_layer: Number of encoder layers and of decoder layers.
        dropout: Dropout probability used by the positional embeddings,
            feed-forward networks and residual wrappers.
        hid_dim: Hidden width of the feed-forward networks.

    Returns:
        The assembled ``Transformer`` with every parameter of more than one
        dimension re-initialised with Xavier-uniform.
    """
    src_embed = Embedding(embed_dim=embed_dim, vocab_size=src_vocab)
    tgt_embed = Embedding(embed_dim=embed_dim, vocab_size=tgt_vocab)

    src_pos = PositionalEmbedding(embed_dim=embed_dim, seq_len=src_seq, dropout=dropout)
    tgt_pos = PositionalEmbedding(embed_dim=embed_dim, seq_len=tgt_seq, dropout=dropout)

    # MultiHeadAttention's third parameter is `start_pos` (a position
    # offset), not a dropout rate. Passing `dropout` there made the offset
    # 0.1, which breaks the integer slicing of the RoPE tables, so the
    # attention modules are built with the default offset instead.
    enc_blocks = [
        EncoderBlock(
            attention=MultiHeadAttention(embed_dim, num_head),
            ffn=FFN(embed_dim, hid_dim, dropout),
            feat=embed_dim,
            dropout=dropout,
        )
        for _ in range(num_layer)
    ]

    dec_blocks = [
        DecoderBlock(
            selfAttention=MultiHeadAttention(embed_dim, num_head),
            crossAttention=MultiHeadAttention(embed_dim, num_head),
            ffn=FFN(embed_dim, hid_dim, dropout),
            feat=embed_dim,
            dropout=dropout,
        )
        for _ in range(num_layer)
    ]

    # nn.ModuleList registers the blocks so their parameters are visible to
    # transformer.parameters() below.
    encoder = Encoder(features=embed_dim, layers=nn.ModuleList(enc_blocks))
    decoder = Decoder(features=embed_dim, layers=nn.ModuleList(dec_blocks))

    proj_layer = Projection(embed_dim=embed_dim, vocab_size=tgt_vocab)

    transformer = Transformer(
        encoder=encoder,
        decoder=decoder,
        src_emb=src_embed,
        tgt_emb=tgt_embed,
        src_pos=src_pos,
        tgt_pos=tgt_pos,
        proj_layer=proj_layer,
    )

    # Re-initialise every weight matrix (biases and other 1-D parameters are
    # left as their modules initialised them).
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return transformer
        
# Build a model instance at module level: vocabularies of 50 tokens and
# 8 precomputed positions on each side, all other settings at their defaults.
transformer = build(src_vocab=50, tgt_vocab=50, src_seq=8, tgt_seq=8)