In [None]:
import torch
def casual_mask(seq , pad_idx = 0):
    """Build a combined padding + causal (look-ahead) attention mask.

    The mask uses the convention consumed by ``scores.masked_fill(mask == 0, ...)``
    elsewhere in this notebook: a truthy entry means "may attend".

    Args:
        seq: tensor of token ids, indexed here as (batch, seq_len).
        pad_idx: token id that marks padding positions.

    Returns:
        Boolean tensor of shape (batch, 1, seq_len, seq_len). Entry
        [b, 0, i, j] is True when key position j is not padding and j <= i.
    """
    seq_len = seq.size(1)

    # (batch, 1, 1, seq_len): True for real tokens, False for padding keys.
    not_padding = (seq != pad_idx).unsqueeze(1).unsqueeze(1)

    # (seq_len, seq_len) lower triangle: position i may look at positions <= i.
    # torch.tril needs a tensor (the original passed a tuple, which raises).
    lower_triangle = torch.tril(
        torch.ones(seq_len , seq_len , dtype = torch.bool , device = seq.device)
    )

    return not_padding & lower_triangle

In [None]:
import torch , math
import torch.nn as nn
import torch.nn.functional as F

# ffn
class FFN(nn.Module):
    """Position-wise feed-forward network written with raw parameters.

    Two affine maps (embed -> hid_dim -> embed) with ReLU and dropout in
    between. Weights are stored as (out_features, in_features) and transposed
    at use, instead of using ``nn.Linear``.
    """

    def __init__(self , hid_dim , embed , dropout : float):
        super().__init__()
        # Up-projection into the wider hidden space.
        self.ln1 = nn.Parameter(torch.randn(hid_dim , embed))
        self.ln1B = nn.Parameter(torch.randn(hid_dim))

        # Down-projection back to the embedding width.
        self.ln2 = nn.Parameter(torch.randn(embed , hid_dim))
        self.ln2B = nn.Parameter(torch.randn(embed))

        # Xavier-uniform weights, zero biases.
        for weight , bias in ((self.ln1 , self.ln1B) , (self.ln2 , self.ln2B)):
            nn.init.xavier_uniform_(weight)
            nn.init.zeros_(bias)

        self.dropout = nn.Dropout(dropout)
        self.act = nn.ReLU()

    def forward(self , x):
        """Apply the two-layer network to ``x`` (last dim = embed)."""
        expanded = torch.matmul(x , self.ln1.t()) + self.ln1B
        activated = self.dropout(self.act(expanded))
        return torch.matmul(activated , self.ln2.t()) + self.ln2B
    
class LayerNormalizaton(nn.Module):
    """Layer normalisation over the last dimension with a learnable scale and shift.

    Args:
        features: size of the last dimension of the inputs.
        eps: constant added to the standard deviation to avoid division by zero.

    Notes:
        ``alpha`` starts at ones and ``bias`` at zeros so a freshly built layer
        is a plain normalisation. (The previous random initialisation rescaled
        and shifted every feature arbitrarily, and ``build`` never re-initialises
        1-D parameters.)
        ``x.std`` is used with its default estimator and ``eps`` is added to the
        standard deviation, not the variance.
    """

    def __init__(self , features , eps = 1e-4):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self , x):
        """Normalise ``x`` along its last dimension, then scale and shift."""
        mean = x.mean(dim = -1 , keepdim = True)
        std = x.std(dim = -1 , keepdim = True)

        return (x - mean) / (std + self.eps) * self.alpha + self.bias
    
class Residual(nn.Module):
    """Pre-norm residual wrapper: ``x + dropout(sublayer(norm(x)))``."""

    def __init__(self , features , dropout : float):
        super().__init__()
        self.layer1 = LayerNormalizaton(features)
        self.dropout = nn.Dropout(dropout)

    def forward(self , x , sublayer):
        """Normalise ``x``, run ``sublayer`` on it, and add the result back to ``x``."""
        normed = self.layer1(x)
        branch = self.dropout(sublayer(normed))
        return x + branch
    

class Embedding(nn.Module):
    """Token embedding lookup scaled by ``sqrt(embed)``."""

    def __init__(self , embed , vocab):
        super().__init__()
        self.embed = nn.Embedding(vocab , embed)
        self.embed_dim = embed

    def forward(self , x):
        """Look up the token ids in ``x`` and multiply the vectors by sqrt(embed_dim)."""
        scale = math.sqrt(self.embed_dim)
        return self.embed(x) * scale

class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding added to the input, followed by dropout.

    Args:
        embed: width of the vectors the encoding is added to.
        seq_len: number of positions to precompute.
        dropout: dropout probability applied after the addition.

    The table is stored in the ``pe`` buffer with shape (1, seq_len, embed).
    Even columns hold sines and odd columns hold cosines.
    """

    def __init__(self , embed , seq_len , dropout : float):
        super().__init__()
        self.embed = embed
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)
        pos = torch.arange(0 , seq_len , dtype = torch.float32).unsqueeze(1)
        term = torch.exp(torch.arange(0 , embed , 2).float() * -math.log(10000) / embed)
        pe = torch.zeros(seq_len , embed)

        pe[: , 0::2] = torch.sin(pos * term)
        # For an odd `embed` there is one fewer odd column than even columns,
        # so trim the frequencies; for an even `embed` the slice is a no-op.
        pe[: , 1::2] = torch.cos(pos * term[: embed // 2])

        pe = pe.unsqueeze(0)
        self.register_buffer('pe' , pe)

    def forward(self , x):
        """Add the encoding for the first ``x.shape[1]`` positions, then apply dropout."""
        # A buffer never requires grad, so no explicit requires_grad_ call is needed.
        return self.dropout(x + self.pe[: , :x.shape[1] , :])
    
class MulticlassAttention(nn.Module):
    """Multi-head scaled dot-product attention with an optional key/value cache.

    Projections are raw (embed, embed) parameters applied as ``x @ W + b``.

    Args:
        embed: model width; must be divisible by ``num_head``.
        num_head: number of attention heads.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self , embed , num_head , dropout : float):
        super().__init__()
        self.embed = embed
        self.num_head = num_head
        assert (embed % num_head == 0) , "dusra try kr"

        self.dk = embed // num_head

        self.q = nn.Parameter(torch.randn(embed , embed))
        self.q_bias = nn.Parameter(torch.randn(embed))

        self.k = nn.Parameter(torch.randn(embed , embed))
        self.k_bias = nn.Parameter(torch.randn(embed))

        self.v = nn.Parameter(torch.randn(embed , embed))
        self.v_bias = nn.Parameter(torch.randn(embed))

        self.o = nn.Parameter(torch.randn(embed , embed))
        self.o_bias = nn.Parameter(torch.randn(embed))

        self.dropout = nn.Dropout(dropout)

        for name in [self.q , self.k , self.v , self.o]:
            nn.init.xavier_uniform_(name)
        for name in [self.q_bias , self.k_bias , self.v_bias , self.o_bias]:
            nn.init.zeros_(name)

    @staticmethod
    def _resolve_mask(mask , q_len , kv_len , use_cache , device):
        """Return a mask broadcastable to (batch, heads, q_len, kv_len); 0 means "blocked"."""
        if mask.dim() == 3:
            mask = mask.unsqueeze(1)  # add the head axis

        if not use_cache:
            if mask.size(-1) != kv_len:
                mask = mask[... , :kv_len]
            return mask

        # With a cache the queries are the LAST q_len positions of the kv_len
        # positions seen so far, so take the bottom rows of the causal triangle.
        causal = torch.tril(torch.ones(kv_len , kv_len , device = device))
        causal = causal[-q_len: , :].unsqueeze(0).unsqueeze(0)
        causal = causal.expand(mask.size(0) , -1 , -1 , -1)

        if mask.size(-1) < kv_len:
            # The supplied mask does not cover the cached keys; fall back to the
            # causal mask alone (a shorter all-ones tensor cannot broadcast).
            return causal

        # Key positions are indexed from 0, so keep the leading kv_len columns.
        padding = mask[... , :kv_len]
        if padding.size(-2) not in (1 , q_len):
            # Full (T, T) mask: keep the rows that belong to the current queries.
            padding = padding[... , kv_len - q_len:kv_len , :]
        return causal * padding

    @staticmethod
    def attention(q , k , v , mask , dropout , pastlayer):
        """Scaled dot-product attention.

        Args:
            q, k, v: tensors shaped (batch, heads, seq, dk) as produced by ``forward``.
            mask: optional mask; positions equal to 0 are filled with -1e9.
            dropout: optional callable applied to the attention weights.
            pastlayer: optional (key, value) pair prepended along the sequence axis.

        Returns:
            (attention output, (keys, values)) where the pair includes any cached entries.
        """
        dk = q.size(-1)

        if pastlayer is not None:
            k_ , v_ = pastlayer
            k = torch.cat([k_ , k] , dim = -2)
            v = torch.cat([v_ , v] , dim = -2)

        present = (k , v)

        scores = (q @ k.transpose(-2 , -1)) / math.sqrt(dk)
        if mask is not None:
            mask = MulticlassAttention._resolve_mask(
                mask , q.size(-2) , k.size(-2) , pastlayer is not None , scores.device
            )
            scores = scores.masked_fill(mask == 0 , -1e9)

        # Hand-written softmax; subtracting the row max keeps exp() from overflowing.
        max_ = torch.max(scores , dim = -1 , keepdim = True)[0]
        sc = torch.exp(scores - max_)
        scores = sc / torch.sum(sc , dim = -1 , keepdim = True)

        if dropout is not None:
            scores = dropout(scores)

        return (scores @ v) , present

    def forward(self , q , k , v , mask = None , pastlayer = None): #batch , seq , embed
        """Project the inputs, split into heads, attend, merge heads and project out.

        Returns:
            (output of shape (batch, q_seq, embed), (keys, values) cache pair).
        """
        q = q @ self.q + self.q_bias
        k = k @ self.k + self.k_bias
        v = v @ self.v + self.v_bias

        batch = q.size(0)
        seq = q.size(1)
        k_seq = k.size(1)

        # (batch, seq, embed) -> (batch, heads, seq, dk)
        query = q.view(batch , seq , self.num_head , self.dk).permute(0 , 2 , 1 , 3)
        key = k.view(batch , k_seq , self.num_head , self.dk).permute(0 , 2 , 1 , 3)
        value = v.view(batch , k_seq , self.num_head , self.dk).permute(0 , 2 , 1 , 3)

        attn , present = MulticlassAttention.attention(query , key , value , mask , self.dropout , pastlayer)

        # (batch, heads, seq, dk) -> (batch, seq, embed)
        attn = attn.permute(0 , 2 , 1 , 3).contiguous().view(batch , seq , self.embed)

        return attn @ self.o + self.o_bias , present


class EncoderBlock(nn.Module):
    """One encoder layer: self-attention then feed-forward, each inside a residual wrapper.

    Args:
        attn: attention module returning ``(output, cache)``.
        ffn: feed-forward module.
        feat: feature width passed to the residual wrappers.
        dropout: dropout probability for the residual wrappers.
    """

    def __init__(self , attn , ffn , feat , dropout : float):
        super().__init__()
        self.attn = attn
        self.ffn = ffn
        self.res1 = Residual(feat , dropout)
        self.res2 = Residual(feat , dropout)

    def forward(self , x , mask):
        """Run the block on ``x`` with the given attention mask."""
        # Attention must run on the tensor handed over by the residual wrapper.
        # Computing it up front from the raw `x` and returning it from a lambda
        # that ignores its argument threw away the wrapper's normalisation.
        x = self.res1(x , lambda normed: self.attn(normed , normed , normed , mask , pastlayer = None)[0])
        x = self.res2(x , self.ffn)
        return x
    
class Encoder(nn.Module):
    """Stack of encoder layers followed by a final layer normalisation."""

    def __init__(self , feat , layers):
        super().__init__()
        self.norm = LayerNormalizaton(feat)
        self.layers = layers

    def forward(self , x , mask):
        """Feed ``x`` through every layer in order, then normalise the result."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden , mask)
        return self.norm(hidden)

class DecoderBlock(nn.Module):
    """One decoder layer: self-attention, cross-attention and feed-forward.

    Each sub-layer is applied inside its own residual wrapper, in sequence.

    Args:
        selfattn: attention module for the target sequence (supports a cache).
        crossattn: attention module that attends over the encoder output.
        ffn: feed-forward module.
        feat: feature width passed to the residual wrappers.
        dropout: dropout probability for the residual wrappers.
    """

    def __init__(self , selfattn , crossattn , ffn , feat , dropout : float):
        super().__init__()
        self.selfattn = selfattn
        self.crossattn = crossattn
        self.ffn = ffn
        self.res1 = Residual(feat , dropout)
        self.res2 = Residual(feat , dropout)
        self.res3 = Residual(feat , dropout)

    def forward(self , x , encoderout , selfmask , crossmask , pastlayer):
        """Run the block.

        Args:
            x: decoder input.
            encoderout: encoder output used as keys/values for cross-attention.
            selfmask: mask for self-attention.
            crossmask: mask for cross-attention.
            pastlayer: optional ``(self_attention_cache, _)`` pair from a previous step.

        Returns:
            ``(x, (self_attention_cache, None))``.
        """
        attnpast = pastlayer[0] if pastlayer else None
        cache = {}

        def self_attend(normed):
            # Keep the key/value cache produced while attending.
            out , cache["self"] = self.selfattn(normed , normed , normed , selfmask , attnpast)
            return out

        def cross_attend(normed):
            out , _ = self.crossattn(normed , encoderout , encoderout , crossmask , pastlayer = None)
            return out

        # Each sub-layer consumes the tensor supplied by its residual wrapper, so
        # cross-attention queries see the result of self-attention. Previously
        # both attentions were computed from the raw input and the wrappers'
        # normalisation was discarded.
        x = self.res1(x , self_attend)
        x = self.res2(x , cross_attend)
        x = self.res3(x , self.ffn)

        return x , (cache["self"] , None)
    
class Decoder(nn.Module):
    """Stack of decoder layers followed by a final layer normalisation."""

    def __init__(self , feat , layers):
        super().__init__()
        self.norm = LayerNormalizaton(feat)
        self.layers = layers

    def forward(self , x , enc , tgt_mask , src_mask , pastvalues):
        """Run every layer, giving layer ``i`` its cache ``pastvalues[i]`` when provided.

        Returns:
            ``(normalised output, list of per-layer caches)``.
        """
        caches = []

        for idx , block in enumerate(self.layers):
            layer_past = None if not pastvalues else pastvalues[idx]
            x , layer_cache = block(x , enc , tgt_mask , src_mask , layer_past)
            caches.append(layer_cache)

        return self.norm(x) , caches

class Projection(nn.Module):
    """Affine map from the model width to vocabulary logits, using raw parameters."""

    def __init__(self , embed , vocab):
        super().__init__()
        # Stored as (vocab, embed) and transposed at use.
        self.linear = nn.Parameter(torch.randn(vocab , embed))
        self.bias = nn.Parameter(torch.randn(vocab))

        nn.init.xavier_uniform_(self.linear)
        nn.init.zeros_(self.bias)

    def forward(self , x):
        """Return ``x @ linear.T + bias``."""
        logits = torch.matmul(x , self.linear.t())
        return logits + self.bias
    
class Transformer(nn.Module):
    """Container that wires embeddings, positional encodings, encoder, decoder and output projection."""

    def __init__(self , encoder : Encoder , decoder : Decoder , src_emb , tgt_emb , src_pos , tgt_pos , proj_layer):
        super().__init__()
        self.encoder_ = encoder
        self.decoder_ = decoder
        self.src_emb = src_emb
        self.src_pos = src_pos
        self.tgt_emb = tgt_emb
        self.tgt_pos = tgt_pos
        self.proj_layer = proj_layer

    def encode(self, src, src_mask):
        """Embed ``src``, add positions, and run the encoder."""
        embedded = self.src_pos(self.src_emb(src))
        return self.encoder_(embedded, src_mask)

    def decode(self, tgt, encoder_out, tgt_mask, src_mask, pastvalues=None):
        """Embed ``tgt``, add positions, and run the decoder against ``encoder_out``."""
        embedded = self.tgt_pos(self.tgt_emb(tgt))
        return self.decoder_(embedded, encoder_out, tgt_mask, src_mask, pastvalues)

    def projection(self, x):
        """Apply the output projection layer to ``x``."""
        return self.proj_layer(x)

def build(src_vocab , tgt_vocab , src_seq , tgt_seq , embed = 512 , num_head = 8 , num_layer = 6 , dropout =  float(0.1) , hid_dim = 2048):
    """Assemble an encoder-decoder Transformer from the building blocks in this cell.

    Args:
        src_vocab, tgt_vocab: vocabulary sizes for the source and target embeddings.
        src_seq, tgt_seq: number of positions precomputed by the positional encodings.
        embed: model width.
        num_head: attention heads per attention module.
        num_layer: number of encoder layers and of decoder layers.
        dropout: dropout probability used throughout.
        hid_dim: hidden width of the feed-forward networks.

    Returns:
        The assembled ``Transformer``; every parameter with more than one
        dimension is re-initialised with Xavier-uniform.
    """
    def new_attention():
        return MulticlassAttention(embed , num_head , dropout)

    def new_ffn():
        return FFN(hid_dim , embed , dropout)

    src_embed = Embedding(embed = embed , vocab = src_vocab)
    tgt_embed = Embedding(embed = embed , vocab = tgt_vocab)

    src_pos = PositionalEmbedding(embed = embed , seq_len = src_seq , dropout = dropout)
    tgt_pos = PositionalEmbedding(embed = embed , seq_len = tgt_seq , dropout = dropout)

    enc_blocks = [
        EncoderBlock(attn = new_attention() , ffn = new_ffn() , feat = embed , dropout = dropout)
        for _ in range(num_layer)
    ]

    dec_blocks = [
        DecoderBlock(selfattn = new_attention() , crossattn = new_attention() , ffn = new_ffn() , feat = embed , dropout = dropout)
        for _ in range(num_layer)
    ]

    encoder = Encoder(feat = embed , layers = nn.ModuleList(enc_blocks))
    decoder = Decoder(feat = embed , layers = nn.ModuleList(dec_blocks))
    proj_layer = Projection(embed = embed , vocab = tgt_vocab)

    transformer = Transformer(
        encoder = encoder ,
        decoder = decoder ,
        src_emb = src_embed ,
        tgt_emb = tgt_embed ,
        src_pos = src_pos ,
        tgt_pos = tgt_pos ,
        proj_layer = proj_layer ,
    )

    # Only matrices are re-initialised; 1-D parameters (biases, norm scales) are left alone.
    matrices = (param for param in transformer.parameters() if param.dim() > 1)
    for weight in matrices:
        nn.init.xavier_uniform_(weight)

    return transformer
        
# Smoke test: build a small model (vocabulary of 50, 8 positions) with the default width and depth.
transformer = build(src_vocab = 50 , tgt_vocab = 50 , src_seq = 8 , tgt_seq = 8 )

In [None]:
import torch
import torch.nn as nn
import math


class FFN(nn.Module):
    """Position-wise feed-forward network written with raw parameters.

    The input is projected to the wider ``hid_dim`` (more features to learn
    from), passed through ReLU and dropout, then projected back to
    ``embed_dim``. Weights are stored as (out, in) and transposed at use.
    """

    def __init__(self , embed_dim , hid_dim , dropout):
        super().__init__()
        self.linear1 = nn.Parameter(torch.randn(hid_dim , embed_dim))  # (hid , embed)
        self.l1bias = nn.Parameter(torch.randn(hid_dim))

        self.dropout = nn.Dropout(dropout)

        self.linear2 = nn.Parameter(torch.randn(embed_dim , hid_dim))  # (embed , hid)
        self.l2bias = nn.Parameter(torch.randn(embed_dim))

        # Xavier-uniform weights, zero biases.
        for weight in (self.linear1 , self.linear2):
            nn.init.xavier_uniform_(weight)
        for bias in (self.l1bias , self.l2bias):
            nn.init.zeros_(bias)

        self.act = nn.ReLU()

    def forward(self , x):
        """Map ``x`` (last dim = embed_dim) through the hidden layer and back."""
        expanded = torch.matmul(x , self.linear1.t()) + self.l1bias
        activated = self.dropout(self.act(expanded))
        return torch.matmul(activated , self.linear2.t()) + self.l2bias

class LayerNormalization(nn.Module):
    """Layer normalisation over the last dimension with a learnable scale (``alpha``) and shift (``bias``)."""

    def __init__(self , features , eps = 10 **-6):
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features))

    def forward(self , x):
        """Centre ``x``, divide by ``std + eps`` along the last dim, then scale and shift."""
        centered = x - x.mean(dim = -1 , keepdim = True)
        spread = x.std(dim = -1 , keepdim = True) + self.eps
        scaled = self.alpha * centered
        return scaled / spread + self.bias
    
class Residual(nn.Module):
    """Pre-norm residual wrapper: ``x + dropout(sublayer(norm(x)))``."""

    def __init__(self , features , dropout):
        super().__init__()
        self.norm = LayerNormalization(features)
        self.dropout = nn.Dropout(dropout)

    def forward(self , x , sublayer):
        """Apply ``sublayer`` to the normalised input and add the skip connection."""
        normed = self.norm(x)
        branch = self.dropout(sublayer(normed))
        return x + branch


class Embedding(nn.Module):
    """Token embedding lookup scaled by ``sqrt(embed_dim)``."""

    def __init__(self , embed_dim , vocab_size):
        super().__init__()
        self.embed_dim = embed_dim
        self.vocab_size = vocab_size
        # One row per vocabulary entry, one column per embedding feature.
        self.embed = nn.Embedding(vocab_size , embed_dim)

    def forward(self , x):
        """Look up the token ids in ``x`` and multiply the vectors by sqrt(embed_dim)."""
        scale = math.sqrt(self.embed_dim)
        return self.embed(x) * scale

class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding added to the input, followed by dropout.

    Args:
        embed_dim: width of the vectors the encoding is added to.
        seq_len: number of positions to precompute.
        dropout: dropout probability applied after the addition.

    The table is stored in the ``pe`` buffer with shape (1, seq_len, embed_dim).
    Even columns hold ``sin(position * term)`` and odd columns hold
    ``cos(position * term)``, where ``term = 1 / 10000 ** (2i / embed_dim)``.
    """

    def __init__(self , embed_dim , seq_len , dropout : float):
        super().__init__()
        self.embed_dim = embed_dim
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)
        position = torch.arange(0 , seq_len , dtype = torch.float).unsqueeze(1)
        pe = torch.zeros(seq_len , embed_dim)

        term = torch.exp(torch.arange(0 , embed_dim , 2).float() * -math.log(10000.0) / embed_dim)
        pe[: , 0::2] = torch.sin(position * term)
        # For an odd `embed_dim` there is one fewer odd column than even
        # columns, so trim the frequencies; for an even width this is a no-op.
        pe[: , 1::2] = torch.cos(position * term[: embed_dim // 2])

        pe = pe.unsqueeze(0)
        self.register_buffer('pe' , pe)

    def forward(self , z):
        """Add the encoding for the first ``z.shape[1]`` positions, then apply dropout."""
        # A buffer never requires grad, so no explicit requires_grad_ call is needed.
        return self.dropout(z + self.pe[: , :z.shape[1] , :])

def RoPE(x, sin, cos, start_pos=0):
    """Apply rotary position embedding to ``x``.

    Consecutive feature pairs ``(x[2j], x[2j+1])`` are rotated by one angle per
    pair, so vector norms are preserved and dot products between rotated
    queries and keys depend only on their relative position.

    Args:
        x: tensor indexed as (batch, heads, seq, head_dim).
        sin, cos: tables of shape (positions, table_dim) from ``RoPE_Embed``.
            ``table_dim`` is either ``head_dim`` (one table shared by every
            head) or a multiple of it (one ``head_dim``-wide slice per head).
        start_pos: absolute position of the first row of ``x``.

    Returns:
        Tensor with the same shape as ``x``.
    """
    seq = x.size(-2)
    head_dim = x.size(-1)
    num_head_x = x.size(1)  # may be 1 for a shared (multi-query) key

    table_dim = sin.size(-1)
    total_heads = table_dim // head_dim

    sin_sliced = sin[start_pos : start_pos + seq].view(seq, total_heads, head_dim)
    cos_sliced = cos[start_pos : start_pos + seq].view(seq, total_heads, head_dim)

    if total_heads != num_head_x:
        sin_sliced = sin_sliced[:, :num_head_x, :]
        cos_sliced = cos_sliced[:, :num_head_x, :]

    sin_sliced = sin_sliced.permute(1, 0, 2).unsqueeze(0)  # [1, heads, seq, head_dim]
    cos_sliced = cos_sliced.permute(1, 0, 2).unsqueeze(0)  # [1, heads, seq, head_dim]

    # RoPE_Embed lays angles out as [theta_0..theta_{d/2-1}, theta_0..theta_{d/2-1}],
    # so the first half holds exactly one angle per feature pair. Both members
    # of a pair must use the SAME angle; striding the table by 2 (as before)
    # gave the even and odd outputs different angles, which is not a rotation.
    half = head_dim // 2
    sin_pair = sin_sliced[..., :half]
    cos_pair = cos_sliced[..., :half]

    x_even = x[..., 0::2]
    x_odd = x[..., 1::2]

    rotated = torch.empty_like(x)
    rotated[..., 0::2] = x_even * cos_pair - x_odd * sin_pair
    rotated[..., 1::2] = x_even * sin_pair + x_odd * cos_pair

    return rotated

def RoPE_Embed(dim , seq):
    """Return ``(sin, cos)`` tables of shape (seq, dim) for rotary embeddings.

    The angle for position ``p`` and frequency index ``i`` is
    ``p / 10000 ** (2i / dim)``; the frequency block is repeated twice along
    the last dimension.
    """
    term = 1.0 / (10000 ** (torch.arange(0 , dim , 2).float() / dim))
    seq_ = torch.arange(seq).float()

    emb = torch.outer(seq_ , term)
    emb = torch.cat([emb , emb] , dim = -1)

    return emb.sin() , emb.cos()

class MulticlassAttention(nn.Module):
    """Multi-query attention with rotary position embeddings and a key/value cache.

    Queries are split into ``num_head`` heads; a single key/value head of width
    ``embed // num_head`` is shared by all query heads.

    Args:
        embed: model width; must be divisible by ``num_head``.
        num_head: number of query heads.
        dropout: dropout probability applied to the attention weights.
        max_seq_len: stored on the module; the rotary tables are grown on demand.
    """

    def __init__(self , embed , num_head , dropout , max_seq_len = 5000):
        super().__init__()
        self.embed_dim = embed
        self.num_head = num_head

        assert embed % num_head == 0, "d_model must be divisible by num_heads"

        self.dim = embed // num_head
        self.dropout = nn.Dropout(dropout)

        self.wq = nn.Parameter(torch.randn(embed , embed))
        self.qb = nn.Parameter(torch.randn(embed))

        # Single shared key/value head.
        self.wk = nn.Parameter(torch.randn(embed , self.dim))
        self.kb = nn.Parameter(torch.randn(self.dim))

        self.wv = nn.Parameter(torch.randn(embed , self.dim))
        self.vb = nn.Parameter(torch.randn(self.dim))

        self.wo = nn.Parameter(torch.randn(embed , embed))

        for param in [self.wq, self.wk, self.wv, self.wo]:
            nn.init.xavier_uniform_(param)
        for bias in [self.qb, self.kb, self.vb]:
            nn.init.zeros_(bias)

        # The rotary tables are a derived cache, so keep them out of state_dict:
        # a checkpoint saved after a forward pass would otherwise carry
        # `sin`/`cos` keys that a freshly built module does not expect.
        self.register_buffer('sin', None, persistent=False)
        self.register_buffer('cos', None, persistent=False)
        self.max_seq_len = max_seq_len

    def _precompute_rope(self, seq_len):
        """Make sure the rotary tables cover at least ``seq_len`` positions."""
        if self.sin is None or self.sin.size(0) < seq_len:
            # Tables are built per head (width `self.dim`) so every query head
            # and the shared key head rotate with the same frequencies.
            sin, cos = RoPE_Embed(self.dim, seq_len)
            device = self.wq.device  # keep the tables next to the weights
            self.sin = sin.to(device)
            self.cos = cos.to(device)

    @staticmethod
    def _resolve_mask(mask, q_len, kv_len, use_cache, device):
        """Return a mask broadcastable to (batch, heads, q_len, kv_len); 0 means "blocked"."""
        if mask.dim() == 3:
            mask = mask.unsqueeze(1)  # add the head axis

        if not use_cache:
            if mask.size(-1) != kv_len:
                mask = mask[..., :kv_len]
            return mask

        # With a cache the queries are the LAST q_len positions of the kv_len
        # positions seen so far, so take the bottom rows of the causal triangle.
        causal = torch.tril(torch.ones(kv_len, kv_len, device=device))
        causal = causal[-q_len:, :].unsqueeze(0).unsqueeze(0)
        causal = causal.expand(mask.size(0), -1, -1, -1)

        if mask.size(-1) < kv_len:
            # The supplied mask does not cover the cached keys; fall back to the
            # causal mask alone (a shorter all-ones tensor cannot broadcast).
            return causal

        # Key positions are indexed from 0, so keep the leading kv_len columns.
        padding = mask[..., :kv_len]
        if padding.size(-2) not in (1, q_len):
            # Full (T, T) mask: keep the rows that belong to the current queries.
            padding = padding[..., kv_len - q_len:kv_len, :]
        return causal * padding

    @staticmethod
    def attention(query, key, value, mask, dropout, pastlayer=None):
        """Scaled dot-product attention.

        Args:
            query: (batch, heads, q_len, dim).
            key, value: (batch, 1, k_len, dim), broadcast over the query heads.
            mask: optional mask; positions equal to 0 are filled with -1e9.
            dropout: optional callable applied to the attention weights.
            pastlayer: optional (key, value) pair prepended along the sequence axis.

        Returns:
            (attention output, (keys, values)) where the pair includes any cached entries.
        """
        d_k = query.shape[-1]

        if pastlayer is not None:
            past_k, past_v = pastlayer
            # Past first, then current
            key = torch.cat([past_k, key], dim=-2)
            value = torch.cat([past_v, value], dim=-2)

        present = (key, value)

        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            mask = MulticlassAttention._resolve_mask(
                mask, query.size(-2), key.size(-2), pastlayer is not None, scores.device
            )
            scores = scores.masked_fill(mask == 0, -1e9)

        # Hand-written softmax; subtracting the row max keeps exp() from overflowing.
        max_ = torch.max(scores , dim = -1 , keepdim = True)[0]
        sc = torch.exp(scores - max_)
        scores = sc / torch.sum(sc , dim = -1 , keepdim = True)

        if dropout is not None:
            scores = dropout(scores)

        return (scores @ value), present

    def forward(self , q , k , v , mask = None , pastlayer = None , start_pos = 0):
        """Project, rotate, attend and project out.

        Args:
            q, k, v: (batch, seq, embed) inputs.
            mask: optional attention mask (see ``attention``).
            pastlayer: optional cached (key, value) pair from earlier steps.
            start_pos: absolute position of the first query. When left at 0 and
                a cache is given, the cache length is used instead.

        Returns:
            (output of shape (batch, q_seq, embed), (keys, values) cache pair).
        """
        batch = q.size(0)

        query = q @ self.wq + self.qb   # (batch , seq , embed)
        key = k @ self.wk + self.kb     # (batch , seq , dim)
        value = v @ self.wv + self.vb

        # (batch , seq , embed) -> (batch , num_head , seq , dim)
        query = query.reshape(batch , q.size(1) , self.num_head , self.dim).permute(0 , 2 , 1 , 3)
        # Shared key/value head: (batch , 1 , seq , dim)
        key = key.unsqueeze(1)
        value = value.unsqueeze(1)

        # New keys continue where the cache ends. Rotating them from position 0
        # on every step would erase all relative-position information.
        past_len = pastlayer[0].size(2) if pastlayer is not None else 0
        key_start = past_len
        query_start = start_pos if start_pos else past_len

        total_seq = max(q.size(1) + query_start , k.size(1) + key_start)
        self._precompute_rope(total_seq)
        query = RoPE(query, self.sin, self.cos, query_start)
        key = RoPE(key, self.sin, self.cos, key_start)

        out , pastlayer_ = MulticlassAttention.attention(query , key , value , mask , self.dropout , pastlayer)
        # (batch , num_head , seq , dim) -> (batch , seq , embed)
        out = out.permute(0, 2, 1, 3).contiguous().view(batch , q.size(1) , self.embed_dim)

        return out @ self.wo , pastlayer_


class EncoderBlock(nn.Module):
    """One encoder layer: self-attention then feed-forward, each inside a residual wrapper.

    Args:
        attention: attention module returning ``(output, cache)``.
        ffn: feed-forward module.
        feat: feature width passed to the residual wrappers.
        dropout: dropout probability for the residual wrappers.
    """

    def __init__(self , attention , ffn , feat , dropout):
        super().__init__()
        self.attention = attention
        self.ffn = ffn
        self.residual1 = Residual(feat , dropout)
        self.residual2 = Residual(feat , dropout)

    def forward(self , x , mask):
        """Run the block on ``x`` with the given attention mask."""
        # Attention must run on the tensor handed over by the residual wrapper.
        # Computing it up front from the raw `x` and returning it from a lambda
        # that ignores its argument threw away the wrapper's normalisation.
        x = self.residual1(x , lambda normed: self.attention(normed , normed , normed , mask)[0])
        x = self.residual2(x , self.ffn)
        return x

class Encoder(nn.Module):
    """Stack of encoder layers followed by a final layer normalisation."""

    def __init__(self , features , layers):
        super().__init__()
        self.norm = LayerNormalization(features)
        self.layer = layers

    def forward(self , x , mask):
        """Feed ``x`` through every layer in order, then normalise the result."""
        hidden = x
        for block in self.layer:
            hidden = block(hidden , mask)
        return self.norm(hidden)
    


class DecoderBlock(nn.Module):
    """One decoder layer: self-attention, cross-attention and feed-forward.

    Each sub-layer is applied inside its own residual wrapper, in sequence.

    Args:
        selfAttention: attention module for the target sequence (supports a cache).
        crossAttention: attention module that attends over the encoder output.
        ffn: feed-forward module.
        feat: feature width passed to the residual wrappers.
        dropout: dropout probability for the residual wrappers.
    """

    def __init__(self , selfAttention , crossAttention , ffn , feat , dropout):
        super().__init__()
        self.attention = selfAttention
        self.ffn = ffn
        self.cross = crossAttention
        self.residual1 = Residual(feat , dropout)
        self.residual2 = Residual(feat , dropout)
        self.residual3 = Residual(feat , dropout)

    def forward(self , x , enoderOutput , selfmask , crossmask , pastlayergroup):
        """Run the block.

        Args:
            x: decoder input.
            enoderOutput: encoder output used as keys/values for cross-attention.
            selfmask: mask for self-attention.
            crossmask: mask for cross-attention.
            pastlayergroup: optional ``(self_attention_cache, _)`` pair from a previous step.

        Returns:
            ``(x, (self_attention_cache, None))``.
        """
        selfpast = pastlayergroup[0] if pastlayergroup is not None else None
        cache = {}

        def self_attend(normed):
            # Keep the key/value cache produced while attending.
            out , cache["self"] = self.attention(normed , normed , normed , selfmask , pastlayer = selfpast)
            return out

        def cross_attend(normed):
            out , _ = self.cross(normed , enoderOutput , enoderOutput , crossmask , pastlayer = None)
            return out

        # Each sub-layer consumes the tensor supplied by its residual wrapper, so
        # cross-attention queries see the result of self-attention. Previously
        # both attentions were computed from the raw input and the wrappers'
        # normalisation was discarded.
        x = self.residual1(x , self_attend)
        x = self.residual2(x , cross_attend)
        x = self.residual3(x , self.ffn)
        return x , (cache["self"] , None)
    
class Decoder(nn.Module):
    """Stack of decoder layers followed by a final layer normalisation."""

    def __init__(self , features , layers):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self , x , enc_out , tgt_mask , src_mask , pastvalues):
        """Run every layer, giving layer ``i`` its cache ``pastvalues[i]`` when provided.

        Returns:
            ``(normalised output, list of per-layer caches)``.
        """
        caches = []
        for idx , block in enumerate(self.layers):
            layer_past = None if not pastvalues else pastvalues[idx]
            x , layer_cache = block(x , enc_out , tgt_mask , src_mask , pastlayergroup = layer_past)
            caches.append(layer_cache)

        return self.norm(x) , caches

class Projection(nn.Module):
    """Affine map from the model width to vocabulary logits, using raw parameters."""

    def __init__(self , embed_dim , vocab_size):
        super().__init__()
        # Stored as (vocab_size, embed_dim) and transposed at use.
        self.project = nn.Parameter(torch.randn(vocab_size , embed_dim))
        self.bias = nn.Parameter(torch.randn(vocab_size))
        nn.init.xavier_uniform_(self.project)
        nn.init.zeros_(self.bias)

    def forward(self , x):
        """Return ``x @ project.T + bias`` for ``x`` with last dim = embed_dim."""
        logits = torch.matmul(x , self.project.t())
        return logits + self.bias

class Transformer(nn.Module):
    """Container that wires embeddings, positional encodings, encoder, decoder and output projection."""

    def __init__(self , encoder : Encoder , decoder : Decoder , src_emb , tgt_emb , src_pos , tgt_pos , proj_layer):
        super().__init__()
        self.encoder_ = encoder
        self.decoder_ = decoder
        self.src_emb = src_emb
        self.src_pos = src_pos
        self.tgt_emb = tgt_emb
        self.tgt_pos = tgt_pos
        self.proj_layer = proj_layer

    def encode(self, src, src_mask):
        """Embed ``src``, add positions, and run the encoder."""
        embedded = self.src_pos(self.src_emb(src))
        return self.encoder_(embedded, src_mask)

    def decode(self, tgt, encoder_out, tgt_mask, src_mask, pastvalues=None):
        """Embed ``tgt``, add positions, and run the decoder against ``encoder_out``."""
        embedded = self.tgt_pos(self.tgt_emb(tgt))
        return self.decoder_(embedded, encoder_out, tgt_mask, src_mask, pastvalues)

    def projection(self, x):
        """Apply the output projection layer to ``x``."""
        return self.proj_layer(x)

def build(src_vocab , tgt_vocab , src_seq , tgt_seq , embed_dim = 512 , num_head = 8 , num_layer = 6 , dropout =  float(0.1) , hid_dim = 2048):
    """Assemble an encoder-decoder Transformer from the building blocks in this cell.

    Args:
        src_vocab, tgt_vocab: vocabulary sizes for the source and target embeddings.
        src_seq, tgt_seq: number of positions precomputed by the positional encodings.
        embed_dim: model width.
        num_head: attention heads per attention module.
        num_layer: number of encoder layers and of decoder layers.
        dropout: dropout probability used throughout.
        hid_dim: hidden width of the feed-forward networks.

    Returns:
        The assembled ``Transformer``; every parameter with more than one
        dimension is re-initialised with Xavier-uniform.
    """
    def new_attention():
        return MulticlassAttention(embed_dim , num_head , dropout)

    def new_ffn():
        return FFN(embed_dim , hid_dim , dropout)

    src_embed = Embedding(embed_dim = embed_dim , vocab_size = src_vocab)
    tgt_embed = Embedding(embed_dim = embed_dim , vocab_size = tgt_vocab)

    src_pos = PositionalEmbedding(embed_dim = embed_dim , seq_len = src_seq , dropout = dropout)
    tgt_pos = PositionalEmbedding(embed_dim = embed_dim , seq_len = tgt_seq , dropout = dropout)

    enc_blocks = [
        EncoderBlock(attention = new_attention() , ffn = new_ffn() , feat = embed_dim , dropout = dropout)
        for _ in range(num_layer)
    ]

    dec_blocks = [
        DecoderBlock(selfAttention = new_attention() , crossAttention = new_attention() , ffn = new_ffn() , feat = embed_dim , dropout = dropout)
        for _ in range(num_layer)
    ]

    encoder = Encoder(features = embed_dim , layers = nn.ModuleList(enc_blocks))
    decoder = Decoder(features = embed_dim , layers = nn.ModuleList(dec_blocks))
    proj_layer = Projection(embed_dim = embed_dim , vocab_size = tgt_vocab)

    transformer = Transformer(
        encoder = encoder ,
        decoder = decoder ,
        src_emb = src_embed ,
        tgt_emb = tgt_embed ,
        src_pos = src_pos ,
        tgt_pos = tgt_pos ,
        proj_layer = proj_layer ,
    )

    # Only matrices are re-initialised; 1-D parameters (biases, norm scales) are left alone.
    matrices = (param for param in transformer.parameters() if param.dim() > 1)
    for weight in matrices:
        nn.init.xavier_uniform_(weight)

    return transformer
        
# Smoke test: build a small model (vocabulary of 50, 8 positions) with the default width and depth.
transformer = build(src_seq = 8 , tgt_seq = 8 , src_vocab = 50 , tgt_vocab = 50)

All tests in this notebook were generated by AI for testing purposes.

In [129]:
import math

def kaiming_initializer(self):
    """Re-initialise a module's parameters in place with Kaiming (He) normal weights.

    Intended to be called with a module as ``self`` (anything that provides
    ``named_parameters()``).

    * Parameters whose name contains ``weight`` (case-insensitive) or one of
      ``wq``/``wk``/``wv``/``wo``, and that have more than one dimension, are
      drawn from ``N(0, sqrt(2 / fan_in))``.
    * Otherwise, 1-D parameters whose name contains ``bias`` or ends with
      ``b`` are set to zero.

    ``fan_in`` is the number of inputs feeding each output unit. Parameters
    named ``weight`` follow the (out, in) layout, so it is ``size(1)``. The
    ``wq``/``wk``/``wv``/``wo`` matrices in this notebook are applied as
    ``x @ W`` and stored as (in, out), so for them it is ``size(0)``. Reading
    ``size(1)`` there gave the narrow key/value projections far too large a
    standard deviation.
    """
    projection_tags = ('wq', 'wk', 'wv', 'wo')

    for name, param in self.named_parameters():
        lowered = name.lower()
        is_named_weight = 'weight' in lowered
        is_projection = any(tag in name for tag in projection_tags)

        if is_named_weight or is_projection:
            if param.dim() > 1:
                fan_in = param.size(1) if is_named_weight else param.size(0)
                std = math.sqrt(2.0 / fan_in)
                param.data.normal_(0, std)
            # 1-D parameters that match a weight name are deliberately left untouched.
        elif 'bias' in lowered or name.endswith('b'):
            if param.dim() == 1:
                param.data.fill_(0.0)

def xavier_initializer(self):
    """Re-initialise a module's parameters in place with Xavier (Glorot) normal weights.

    * Parameters whose name contains ``weight`` (case-insensitive) or one of
      ``wq``/``wk``/``wv``/``wo``, and that have more than one dimension, are
      drawn from ``N(0, sqrt(2 / (fan_in + fan_out)))``.
    * Otherwise, 1-D parameters whose name contains ``bias`` (case-insensitive)
      or one of ``qb``/``kb``/``vb`` are set to zero.
    """
    weight_tags = ('wq', 'wk', 'wv', 'wo')
    bias_tags = ('qb', 'kb', 'vb')

    for name, param in self.named_parameters():
        lowered = name.lower()

        if 'weight' in lowered or any(tag in name for tag in weight_tags):
            # A weight-named parameter is never treated as a bias, even if it is 1-D.
            if param.dim() <= 1:
                continue
            fan_out, fan_in = param.size(0), param.size(1)
            param.data.normal_(0, math.sqrt(2.0 / (fan_in + fan_out)))
            continue

        looks_like_bias = 'bias' in lowered or any(tag in name for tag in bias_tags)
        if looks_like_bias and param.dim() == 1:
            param.data.fill_(0.0)

In [46]:
import torch
import torch.nn as nn
import math


class FFN(nn.Module):
    """Position-wise feed-forward network written with raw parameters.

    The input is projected to the wider ``hid_dim`` (more features to learn
    from), passed through ReLU and dropout, then projected back to
    ``embed_dim``. Weights are stored as (out, in) and transposed at use.
    """

    def __init__(self , embed_dim , hid_dim , dropout):
        super().__init__()
        self.linear1 = nn.Parameter(torch.randn(hid_dim , embed_dim))  # (hid , embed)
        self.l1bias = nn.Parameter(torch.randn(hid_dim))

        self.dropout = nn.Dropout(dropout)

        self.linear2 = nn.Parameter(torch.randn(embed_dim , hid_dim))  # (embed , hid)
        self.l2bias = nn.Parameter(torch.randn(embed_dim))

        # Xavier-uniform weights, zero biases.
        for weight in (self.linear1 , self.linear2):
            nn.init.xavier_uniform_(weight)
        for bias in (self.l1bias , self.l2bias):
            nn.init.zeros_(bias)

        self.act = nn.ReLU()

    def forward(self , x):
        """Map ``x`` (last dim = embed_dim) through the hidden layer and back."""
        expanded = torch.matmul(x , self.linear1.t()) + self.l1bias
        activated = self.dropout(self.act(expanded))
        return torch.matmul(activated , self.linear2.t()) + self.l2bias

class LayerNormalization(nn.Module):
    """Layer normalisation over the last dimension with a learnable scale (``alpha``) and shift (``bias``)."""

    def __init__(self , features , eps = 10 **-6):
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features))

    def forward(self , x):
        """Centre ``x``, divide by ``std + eps`` along the last dim, then scale and shift."""
        centered = x - x.mean(dim = -1 , keepdim = True)
        spread = x.std(dim = -1 , keepdim = True) + self.eps
        scaled = self.alpha * centered
        return scaled / spread + self.bias
    
class Residual(nn.Module):
    """Pre-norm residual wrapper: ``x + dropout(sublayer(norm(x)))``."""

    def __init__(self , features , dropout):
        super().__init__()
        self.norm = LayerNormalization(features)
        self.dropout = nn.Dropout(dropout)

    def forward(self , x , sublayer):
        """Apply ``sublayer`` to the normalised input and add the skip connection."""
        normed = self.norm(x)
        branch = self.dropout(sublayer(normed))
        return x + branch


class Embedding(nn.Module):
    """Token embedding lookup scaled by ``sqrt(embed_dim)``."""

    def __init__(self , embed_dim , vocab_size):
        super().__init__()
        self.embed_dim = embed_dim
        self.vocab_size = vocab_size
        # One row per vocabulary entry, one column per embedding feature.
        self.embed = nn.Embedding(vocab_size , embed_dim)

    def forward(self , x):
        """Look up the token ids in ``x`` and multiply the vectors by sqrt(embed_dim)."""
        scale = math.sqrt(self.embed_dim)
        return self.embed(x) * scale

class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding added to the input, followed by dropout.

    Args:
        embed_dim: width of the vectors the encoding is added to.
        seq_len: number of positions to precompute.
        dropout: dropout probability applied after the addition.

    The table is stored in the ``pe`` buffer with shape (1, seq_len, embed_dim).
    Even columns hold ``sin(position * term)`` and odd columns hold
    ``cos(position * term)``, where ``term = 1 / 10000 ** (2i / embed_dim)``.
    """

    def __init__(self , embed_dim , seq_len , dropout : float):
        super().__init__()
        self.embed_dim = embed_dim
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)
        position = torch.arange(0 , seq_len , dtype = torch.float).unsqueeze(1)
        pe = torch.zeros(seq_len , embed_dim)

        term = torch.exp(torch.arange(0 , embed_dim , 2).float() * -math.log(10000.0) / embed_dim)
        pe[: , 0::2] = torch.sin(position * term)
        # For an odd `embed_dim` there is one fewer odd column than even
        # columns, so trim the frequencies; for an even width this is a no-op.
        pe[: , 1::2] = torch.cos(position * term[: embed_dim // 2])

        pe = pe.unsqueeze(0)
        self.register_buffer('pe' , pe)

    def forward(self , z):
        """Add the encoding for the first ``z.shape[1]`` positions, then apply dropout."""
        # A buffer never requires grad, so no explicit requires_grad_ call is needed.
        return self.dropout(z + self.pe[: , :z.shape[1] , :])

def RoPE(x, sin, cos, start_pos=0):
    """Rotate consecutive feature pairs of ``x`` by a position-dependent angle.

    The pair ``(x[..., 2j], x[..., 2j + 1])`` at sequence index ``t`` is rotated
    by the angle stored for position ``start_pos + t`` and frequency ``j``. The
    same angles are used for every head, so the dot product between a rotated
    query and a rotated key depends only on the distance between their
    positions, even when the key has a single head shared by all query heads.

    Args:
        x: Tensor whose last two dimensions are ``(seq, head_dim)``; any leading
            dimensions (batch, heads) are broadcast over.
        sin, cos: Tables of shape ``(positions, width)`` laid out as produced by
            ``RoPE_Embed``: column ``c`` in the first half holds the angle
            ``pos * 10000 ** (-2c / width)`` and the second half repeats the
            first. ``width`` must be a multiple of ``head_dim``.
        start_pos: Absolute position of the first row of ``x``.

    Returns:
        A new tensor with the same shape and dtype as ``x``.

    Raises:
        ValueError: If ``head_dim`` is odd, ``width`` is not a multiple of
            ``head_dim``, or the tables do not cover the requested positions.
    """
    seq = x.size(-2)
    head_dim = x.size(-1)
    width = sin.size(-1)

    if head_dim % 2 != 0 or width % head_dim != 0:
        raise ValueError(
            f"RoPE needs an even head_dim that divides the table width, got head_dim={head_dim}, width={width}"
        )

    # Taking every `stride`-th column of the first half yields the angles
    # pos * 10000 ** (-2j / head_dim) for j = 0 .. head_dim/2 - 1: one angle per
    # feature pair, identical for every head.
    stride = width // head_dim
    half_width = (head_dim // 2) * stride
    sin_pos = sin[start_pos : start_pos + seq, 0:half_width:stride]  # [seq, head_dim / 2]
    cos_pos = cos[start_pos : start_pos + seq, 0:half_width:stride]  # [seq, head_dim / 2]

    if sin_pos.size(0) != seq:
        raise ValueError(
            f"RoPE table has {sin.size(0)} positions but positions up to {start_pos + seq} were requested"
        )

    x_even = x[..., 0::2]
    x_odd = x[..., 1::2]

    # Standard 2-D rotation; both outputs of a pair must use the same angle.
    rotated = torch.empty_like(x)
    rotated[..., 0::2] = x_even * cos_pos - x_odd * sin_pos
    rotated[..., 1::2] = x_even * sin_pos + x_odd * cos_pos

    return rotated

def RoPE_Embed(dim , seq):
    """Build sine and cosine tables of rotary angles.

    The angle for position ``p`` and frequency index ``c`` is
    ``p * 10000 ** (-2c / dim)`` for ``c`` in ``range(dim / 2)``. The angle
    columns are concatenated with themselves before taking sin and cos.

    Args:
        dim: Denominator of the frequency exponent; ``arange(0, dim, 2)`` sets the
            number of distinct frequencies.
        seq: Number of positions (rows).

    Returns:
        ``(sin, cos)``, two tensors with ``seq`` rows each.
    """
    exponents = torch.arange(0 , dim , 2).float() / dim
    inv_freq = 1.0 / (10000 ** exponents)
    positions = torch.arange(seq).float()

    angles = torch.outer(positions , inv_freq)
    table = torch.cat([angles , angles] , dim = -1)

    return table.sin() , table.cos()

class MulticlassAttention(nn.Module):
    """Multi-query attention with rotary positions and an optional key/value cache.

    Queries are projected to ``num_head`` heads of width ``embed // num_head``.
    Keys and values are projected to a single head of that width, which every
    query head attends to. Projections are stored as raw parameters and applied
    as ``x @ W + b``.
    """

    def __init__(self , embed , num_head , dropout , max_seq_len = 5000):
        """Create the projection parameters.

        Args:
            embed: Model width; must be divisible by ``num_head``.
            num_head: Number of query heads.
            dropout: Dropout probability applied to the attention weights.
            max_seq_len: Stored on the instance; the rotary tables are sized from
                the sequences actually seen in ``forward``.
        """
        super().__init__()
        self.embed_dim = embed
        self.num_head = num_head

        assert embed % num_head == 0, "d_model must be divisible by num_heads"

        self.dim = embed // num_head
        self.dropout = nn.Dropout(dropout)

        self.wq = nn.Parameter(torch.randn(embed , embed))
        self.qb = nn.Parameter(torch.randn(embed))

        # Keys and values use a single shared head, hence the narrower projections.
        self.wk = nn.Parameter(torch.randn(embed , self.dim))
        self.kb = nn.Parameter(torch.randn(self.dim))

        self.wv = nn.Parameter(torch.randn(embed , self.dim))
        self.vb = nn.Parameter(torch.randn(self.dim))

        self.wo = nn.Parameter(torch.randn(embed , embed))

        for param in [self.wq, self.wk, self.wv, self.wo]:
            nn.init.xavier_uniform_(param)
        for bias in [self.qb, self.kb, self.vb]:
            nn.init.zeros_(bias)

        # The rotary tables are derived data whose size depends on the inputs
        # seen so far. Keeping them out of the state dict means a checkpoint
        # saved after a forward pass still loads into a freshly built model.
        self.register_buffer('sin', None, persistent=False)
        self.register_buffer('cos', None, persistent=False)
        self.max_seq_len = max_seq_len

    def _precompute_rope(self, seq_len, device=None):
        """Ensure the rotary tables cover ``seq_len`` positions on ``device``.

        The tables are rebuilt when they are missing, too short, or (when
        ``device`` is given) stored on a different device.
        """
        stale = (
            self.sin is None
            or self.sin.size(0) < seq_len
            or (device is not None and self.sin.device != device)
        )
        if stale:
            sin, cos = RoPE_Embed(self.embed_dim, seq_len)
            if device is not None:
                sin, cos = sin.to(device), cos.to(device)
            self.register_buffer('sin', sin, persistent=False)
            self.register_buffer('cos', cos, persistent=False)

    @staticmethod
    def attention(query, key, value, mask, dropout, pastlayer=None):
        """Scaled dot-product attention with optional cache and mask.

        Args:
            query: ``(batch, num_head, q_len, dim)``.
            key, value: ``(batch, 1, k_len, dim)``; broadcast over the query heads.
            mask: Optional tensor; entries equal to 0 are excluded from attention.
                Without a cache it is cropped to the key length and a 3-D mask
                gets a head dimension. With a cache a causal mask for the last
                ``q_len`` positions is built and multiplied with the trailing
                ``kv_len`` columns of ``mask`` when ``mask`` is at least that wide.
            dropout: Optional module applied to the attention weights.
            pastlayer: Optional ``(past_key, past_value)``; the current key and
                value are appended after them along the sequence dimension.

        Returns:
            ``(output, present)`` where ``present`` is the ``(key, value)`` pair
            including any cached prefix, to be passed back as ``pastlayer``.
        """
        d_k = query.shape[-1]

        if pastlayer is not None:
            past_k, past_v = pastlayer
            # Past first, then current.
            key = torch.cat([past_k, key], dim=-2)
            value = torch.cat([past_v, value], dim=-2)

        present = (key, value)

        # (batch, num_head, q_len, dim) @ (batch, 1, dim, kv_len)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            kv_len = key.size(-2)
            q_len = query.size(-2)

            if pastlayer is not None:
                # The queries are the last q_len positions of the cached
                # sequence, so keep only the last q_len rows of the causal mask.
                causal = torch.tril(torch.ones(kv_len, kv_len, device=scores.device))
                causal = causal[-q_len:, :].unsqueeze(0).unsqueeze(0)
                causal = causal.expand(mask.size(0), -1, -1, -1)
                if mask.size(-1) >= kv_len:
                    mask = causal * mask[..., -kv_len:]
                else:
                    # The supplied mask is narrower than the cache and carries no
                    # information about the cached keys; use the causal mask alone.
                    mask = causal
            else:
                if mask.size(-1) != kv_len:
                    mask = mask[..., :kv_len]
                if mask.dim() == 3:
                    mask = mask.unsqueeze(1)

            scores = scores.masked_fill(mask == 0, -1e9)

        scores = scores.softmax(dim=-1)

        if dropout is not None:
            scores = dropout(scores)

        return (scores @ value), present

    def forward(self , q , k , v , mask = None , pastlayer = None , start_pos = 0):
        """Project the inputs, apply rotary positions and attend.

        Args:
            q, k, v: ``(batch, seq, embed)`` inputs for queries, keys and values.
            mask: Optional mask forwarded to ``attention``.
            pastlayer: Optional ``(past_key, past_value)`` cache returned by an
                earlier call.
            start_pos: Absolute position of the first query token. When left at
                0 and a cache is given, the cache length is used, because the
                new tokens continue the cached sequence.

        Returns:
            ``(output, present)``: output is ``(batch, seq, embed)`` and
            ``present`` is the updated cache.
        """
        batch = q.size(0)

        query = q @ self.wq + self.qb  # (batch, seq, embed)
        key = k @ self.wk + self.kb    # (batch, seq, dim)
        value = v @ self.wv + self.vb  # (batch, seq, dim)

        # (batch, seq, embed) -> (batch, seq, num_head, dim) -> (batch, num_head, seq, dim)
        query = query.reshape(batch , q.size(1) , self.num_head , self.dim).permute(0 , 2 , 1 , 3)
        # Single shared key/value head: (batch, 1, seq, dim)
        key = key.unsqueeze(1)
        value = value.unsqueeze(1)

        # New keys are appended after the cached ones, so they sit at positions
        # past_len, past_len + 1, ...; rotating them at position 0 would make
        # cached decoding disagree with a full forward pass.
        past_len = pastlayer[0].size(-2) if pastlayer is not None else 0
        query_pos = start_pos if start_pos else past_len

        total_seq = max(q.size(1) + query_pos , k.size(1) + past_len)
        self._precompute_rope(total_seq , q.device)
        query = RoPE(query, self.sin, self.cos, query_pos)
        key = RoPE(key, self.sin, self.cos, past_len)

        out , present = MulticlassAttention.attention(query , key , value , mask , self.dropout , pastlayer)
        # (batch, num_head, seq, dim) -> (batch, seq, num_head, dim) -> (batch, seq, embed)
        out = out.permute(0, 2, 1, 3).contiguous().view(batch , q.size(1) , self.embed_dim)

        return out @ self.wo , present


class EncoderBlock(nn.Module):
    """One encoder layer: self-attention, then a feed-forward network.

    Each sublayer is wrapped in a ``Residual``, which normalizes its input, runs
    the sublayer on the normalized tensor and adds the result back.
    """

    def __init__(self , attention , ffn , feat , dropout):
        """Store the sublayers and create one ``Residual`` wrapper for each.

        Args:
            attention: Module called as ``attention(q, k, v, mask)`` that returns
                ``(output, cache)``.
            ffn: Feed-forward module applied position-wise.
            feat: Feature width handed to the residual wrappers.
            dropout: Dropout probability handed to the residual wrappers.
        """
        super().__init__()
        self.attention = attention
        self.ffn = ffn
        self.residual1 = Residual(feat , dropout)
        self.residual2 = Residual(feat , dropout)

    def forward(self , x , mask):
        """Run self-attention and the feed-forward network, each with a skip connection."""
        # Attention must run on the tensor the residual wrapper hands in.
        # Computing it up front on the raw ``x`` and ignoring the wrapper's
        # argument would throw away the wrapper's layer norm.
        x = self.residual1(x , lambda normed: self.attention(normed , normed , normed , mask)[0])
        x = self.residual2(x , self.ffn)
        return x

class Encoder(nn.Module):
    """Stack of encoder layers followed by a final layer normalization."""

    def __init__(self , features , layers):
        """Keep the iterable of layers and build the closing norm over ``features``."""
        super().__init__()
        self.norm = LayerNormalization(features)
        self.layer = layers

    def forward(self , x , mask):
        """Feed ``x`` through every layer in order with the same ``mask``, then normalize."""
        hidden = x
        for block in self.layer:
            hidden = block(hidden , mask)

        return self.norm(hidden)
    


class DecoderBlock(nn.Module):
    """One decoder layer: self-attention, cross-attention, then a feed-forward network.

    The three sublayers run in sequence. Each is wrapped in a ``Residual``,
    which normalizes its input, runs the sublayer on the normalized tensor and
    adds the result back, so every sublayer sees the output of the previous one.
    """

    def __init__(self , selfAttention , crossAttention , ffn , feat , dropout):
        """Store the sublayers and create one ``Residual`` wrapper for each.

        Args:
            selfAttention: Module called as ``attn(q, k, v, mask, pastlayer=...)``
                that returns ``(output, cache)``.
            crossAttention: Module with the same call signature, used with the
                encoder output as keys and values.
            ffn: Feed-forward module applied position-wise.
            feat: Feature width handed to the residual wrappers.
            dropout: Dropout probability handed to the residual wrappers.
        """
        super().__init__()
        self.attention = selfAttention
        self.ffn = ffn
        self.cross = crossAttention
        self.residual1 = Residual(feat , dropout)
        self.residual2 = Residual(feat , dropout)
        self.residual3 = Residual(feat , dropout)

    def forward(self , x , enoderOutput , selfmask , crossmask , pastlayergroup = None):
        """Run the three sublayers and return the new self-attention cache.

        Args:
            x: Decoder input for this layer.
            enoderOutput: Encoder output used as keys and values for cross-attention.
            selfmask: Mask for self-attention.
            crossmask: Mask for cross-attention.
            pastlayergroup: Optional ``(self_cache, _)`` pair from an earlier step;
                only the first element is used.

        Returns:
            ``(x, (self_cache, None))``.
        """
        selfpast , _ = pastlayergroup if pastlayergroup is not None else (None , None)
        selfpresent = None

        def _self_attention(normed):
            # Capture the cache as a side effect; the residual wrapper only
            # needs the attention output.
            nonlocal selfpresent
            out , selfpresent = self.attention(normed , normed , normed , selfmask , pastlayer = selfpast)
            return out

        def _cross_attention(normed):
            # Queries come from the output of the self-attention sublayer, not
            # from the raw block input.
            out , _ = self.cross(normed , enoderOutput , enoderOutput , crossmask , pastlayer = None)
            return out

        x = self.residual1(x , _self_attention)
        x = self.residual2(x , _cross_attention)
        x = self.residual3(x , self.ffn)
        return x , (selfpresent , None)
    
class Decoder(nn.Module):
    """Stack of decoder layers followed by a final layer normalization."""

    def __init__(self , features , layers):
        """Keep the iterable of layers and build the closing norm over ``features``."""
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self , x , enc_out , tgt_mask , src_mask , pastvalues):
        """Run every layer in order and collect the cache entry each one returns.

        Args:
            x: Decoder input.
            enc_out: Encoder output passed to every layer.
            tgt_mask, src_mask: Masks passed to every layer.
            pastvalues: Per-layer caches indexed by layer position, or a falsy
                value when there is no cache.

        Returns:
            ``(normalized_output, caches)`` with one cache entry per layer.
        """
        caches = []
        hidden = x
        for index , block in enumerate(self.layers):
            if pastvalues:
                block_past = pastvalues[index]
            else:
                block_past = None
            hidden , block_cache = block(hidden , enc_out , tgt_mask , src_mask , pastlayergroup = block_past)
            caches.append(block_cache)

        return self.norm(hidden) , caches

class Projection(nn.Module):
    """Linear map from the model width to vocabulary logits, stored as raw parameters."""

    def __init__(self , embed_dim , vocab_size):
        """Create a ``(vocab_size, embed_dim)`` weight (Xavier-uniform) and a zero bias."""
        super().__init__()
        self.project = nn.Parameter(torch.randn(vocab_size , embed_dim))
        self.bias = nn.Parameter(torch.randn(vocab_size))
        # Both tensors are overwritten right away; the randn calls above only
        # allocate them.
        nn.init.xavier_uniform_(self.project)
        nn.init.zeros_(self.bias)

    def forward(self , x):
        """Map ``(batch, seq, embed)`` to ``(batch, seq, vocab_size)``."""
        logits = x @ self.project.T
        return logits + self.bias

class Transformer(nn.Module):
    """Container that wires embeddings, positional encodings, encoder, decoder and output projection."""

    def __init__(self , encoder : Encoder , decoder : Decoder , src_emb , tgt_emb , src_pos , tgt_pos , proj_layer):
        """Store the submodules; nothing is constructed here."""
        super().__init__()
        # Assignment order fixes the order in which parameters are iterated.
        self.encoder_ = encoder
        self.decoder_ = decoder
        self.src_emb = src_emb
        self.src_pos = src_pos
        self.tgt_emb = tgt_emb
        self.tgt_pos = tgt_pos
        self.proj_layer = proj_layer

    def encode(self, src, src_mask):
        """Embed ``src``, add positions and run the encoder with ``src_mask``."""
        embedded = self.src_pos(self.src_emb(src))
        return self.encoder_(embedded, src_mask)

    def decode(self, tgt, encoder_out, tgt_mask, src_mask, pastvalues=None):
        """Embed ``tgt``, add positions and run the decoder.

        Returns whatever the decoder returns; ``pastvalues`` is forwarded unchanged.
        """
        embedded = self.tgt_pos(self.tgt_emb(tgt))
        return self.decoder_(embedded, encoder_out, tgt_mask, src_mask, pastvalues)

    def projection(self, x):
        """Apply the output projection layer to ``x``."""
        return self.proj_layer(x)

def build(src_vocab , tgt_vocab , src_seq , tgt_seq , embed_dim = 512 , num_head = 8 , num_layer = 6 , dropout =  float(0.1) , hid_dim = 2048):
    """Assemble an encoder-decoder ``Transformer`` and Xavier-initialize its weights.

    Args:
        src_vocab, tgt_vocab: Vocabulary sizes for the source and target embeddings.
        src_seq, tgt_seq: Lengths of the source and target positional tables.
        embed_dim: Model width.
        num_head: Number of attention heads.
        num_layer: Number of encoder layers and of decoder layers.
        dropout: Dropout probability used throughout.
        hid_dim: Hidden width passed to the feed-forward network.

    Returns:
        The assembled ``Transformer``. Every parameter with more than one
        dimension is re-initialized with Xavier-uniform before returning.
    """
    src_embed = Embedding(embed_dim = embed_dim , vocab_size = src_vocab)
    tgt_embed = Embedding(embed_dim = embed_dim , vocab_size = tgt_vocab)

    src_pos = PositionalEmbedding(embed_dim = embed_dim , seq_len = src_seq , dropout = dropout)
    tgt_pos = PositionalEmbedding(embed_dim = embed_dim , seq_len = tgt_seq , dropout = dropout)

    # NOTE(review): FFN is called positionally as (embed_dim, hid_dim, dropout);
    # confirm this matches the parameter order of the FFN class in scope.
    enc_blocks = [
        EncoderBlock(
            attention = MulticlassAttention(embed_dim , num_head , dropout),
            ffn = FFN(embed_dim , hid_dim , dropout),
            feat = embed_dim,
            dropout = dropout,
        )
        for _ in range(num_layer)
    ]

    dec_blocks = [
        DecoderBlock(
            selfAttention = MulticlassAttention(embed_dim , num_head , dropout),
            crossAttention = MulticlassAttention(embed_dim , num_head , dropout),
            ffn = FFN(embed_dim , hid_dim , dropout),
            feat = embed_dim,
            dropout = dropout,
        )
        for _ in range(num_layer)
    ]

    encoder = Encoder(features = embed_dim , layers = nn.ModuleList(enc_blocks))
    decoder = Decoder(features = embed_dim , layers = nn.ModuleList(dec_blocks))

    proj_layer = Projection(embed_dim = embed_dim , vocab_size = tgt_vocab)

    transformer = Transformer(
        encoder = encoder,
        decoder = decoder,
        src_emb = src_embed,
        tgt_emb = tgt_embed,
        src_pos = src_pos,
        tgt_pos = tgt_pos,
        proj_layer = proj_layer,
    )

    # Matrices get Xavier-uniform; 1-D parameters (biases, norm scales) are left alone.
    for weight in transformer.parameters():
        if weight.dim() <= 1:
            continue
        nn.init.xavier_uniform_(weight)

    return transformer
        
# Small demo model: 50-token vocabularies and positional tables of length 8;
# every other setting uses build()'s defaults.
transformer = build(src_vocab = 50 , tgt_vocab = 50 , src_seq = 8 , tgt_seq = 8)

In [50]:
def RoPE(x, sin, cos, start_pos=0):
    """Rotate consecutive feature pairs of ``x`` by a position-dependent angle.

    The pair ``(x[..., 2j], x[..., 2j + 1])`` at sequence index ``t`` is rotated
    by the angle stored for position ``start_pos + t`` and frequency ``j``. The
    same angles are used for every head, so the dot product between a rotated
    query and a rotated key depends only on the distance between their
    positions, even when the key has a single head shared by all query heads.

    Args:
        x: Tensor whose last two dimensions are ``(seq, head_dim)``; any leading
            dimensions (batch, heads) are broadcast over.
        sin, cos: Tables of shape ``(positions, width)`` laid out as produced by
            ``RoPE_Embed``: column ``c`` in the first half holds the angle
            ``pos * 10000 ** (-2c / width)`` and the second half repeats the
            first. ``width`` must be a multiple of ``head_dim``.
        start_pos: Absolute position of the first row of ``x``.

    Returns:
        A new tensor with the same shape and dtype as ``x``.

    Raises:
        ValueError: If ``head_dim`` is odd, ``width`` is not a multiple of
            ``head_dim``, or the tables do not cover the requested positions.
    """
    seq = x.size(-2)
    head_dim = x.size(-1)
    width = sin.size(-1)

    if head_dim % 2 != 0 or width % head_dim != 0:
        raise ValueError(
            f"RoPE needs an even head_dim that divides the table width, got head_dim={head_dim}, width={width}"
        )

    # Taking every `stride`-th column of the first half yields the angles
    # pos * 10000 ** (-2j / head_dim) for j = 0 .. head_dim/2 - 1: one angle per
    # feature pair, identical for every head.
    stride = width // head_dim
    half_width = (head_dim // 2) * stride
    sin_pos = sin[start_pos : start_pos + seq, 0:half_width:stride]  # [seq, head_dim / 2]
    cos_pos = cos[start_pos : start_pos + seq, 0:half_width:stride]  # [seq, head_dim / 2]

    if sin_pos.size(0) != seq:
        raise ValueError(
            f"RoPE table has {sin.size(0)} positions but positions up to {start_pos + seq} were requested"
        )

    x_even = x[..., 0::2]
    x_odd = x[..., 1::2]

    # Standard 2-D rotation; both outputs of a pair must use the same angle.
    rotated = torch.empty_like(x)
    rotated[..., 0::2] = x_even * cos_pos - x_odd * sin_pos
    rotated[..., 1::2] = x_even * sin_pos + x_odd * cos_pos

    return rotated

def RoPE_Embed(dim, seq):
    """Return ``(sin, cos)`` tables of rotary angles with ``seq`` rows.

    Row ``p`` holds the angles ``p * 10000 ** (-2c / dim)`` for
    ``c`` in ``range(dim / 2)``, concatenated with a second copy of themselves.
    """
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(seq).float()

    # One row per position, one column per frequency.
    angles = torch.outer(positions, inv_freq)
    doubled = torch.cat([angles, angles], dim=-1)

    sin_table = doubled.sin()
    cos_table = doubled.cos()
    return sin_table, cos_table