In [None]:
import torch , math
import torch.nn as nn
import torch.nn.functional as F

# ffn
class FFN(nn.Module):
    def __init__(self , hid_dim , embed , dropout : float):
        super().__init__()        # we project in high dims
        #we will use manual instead of linear projection , we will be using transpose to match architecture of eco-system
        self.ln1 = nn.Parameter(torch.randn(hid_dim , embed)) #(hid_dim , embed)
        self.ln1B = nn.Parameter(torch.randn(hid_dim))

        self.ln2 = nn.Parameter(torch.randn(embed , hid_dim))
        self.ln2B = nn.Parameter(torch.randn(embed))

        nn.init.xavier_uniform_(self.ln1)
        nn.init.zeros_(self.ln1B)

        nn.init.xavier_uniform_(self.ln2)
        nn.init.zeros_(self.ln2B)

        self.dropout = nn.Dropout(dropout)
        self.act = nn.ReLU()

    def forward(self , x):
        # x.shape = [batch , seq , embed]
        hidden = self.act(x @ self.ln1.T + self.ln1B)# [batch , seq , hid]
        hidden = self.dropout(hidden)

        output = hidden @ self.ln2.T + self.ln2B

        return output
    
class LayerNormalizaton(nn.Module):
    def __init__(self , features , eps = 1e-4):
        super().__init__()
        self.alpha = nn.Parameter(torch.randn(features))
        self.bias = nn.Parameter(torch.randn(features))
        self.eps = eps

    def forward(self , x):
        mean = x.mean(dim = -1 , keepdim = True)
        std = x.std(dim = -1 , keepdim = True)

        return (x - mean) / (std + self.eps) * self.alpha + self.bias
    
class Residual(nn.Module):
    def __init__(self , features , dropout : float):
        super().__init__()
        self.layer1 = LayerNormalizaton(features)
        self.dropout = nn.Dropout(dropout)

    def forward(self , x , sublayer):
        return x + self.dropout(sublayer(self.layer1(x)))
    

class Embedding(nn.Module):
    def __init__(self , embed , vocab):
        super().__init__()
        self.embed = nn.Embedding(vocab , embed)
        self.embed_dim = embed
    
    def forward(self , x):
        output = self.embed(x) * math.sqrt(self.embed_dim)
        return output

class PositionalEmbedding(nn.Module):
    def __init__(self , embed , seq_len , dropout : float):
        super().__init__()
        self.embed = embed
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)
        pos = torch.arange(0 , seq_len , dtype = torch.float32).unsqueeze(1)
        term = torch.exp(torch.arange(0 , embed , 2).float() * -math.log(10000) / embed)
        pe = torch.zeros(seq_len , embed)

        pe[: , 0::2] = torch.sin(pos * term)
        pe[: , 1::2] = torch.cos(pos * term)

        pe = pe.unsqueeze(0)
        self.register_buffer('pe' , pe)

    def forward(self , x):
        return self.dropout(x + (self.pe[: , :x.shape[1] , :].requires_grad_(False)))
    
class MulticlassAttention(nn.Module):
    def __init__(self , embed , num_head , dropout : float):
        super().__init__()
        self.embed = embed
        self.num_head = num_head
        assert (embed % num_head == 0) , "dusra try kr" 

        self.dk = embed // num_head

        self.q = nn.Parameter(torch.randn(embed , embed))
        self.q_bias = nn.Parameter(torch.randn(embed))

        self.k = nn.Parameter(torch.randn(embed , embed))
        self.k_bias = nn.Parameter(torch.randn(embed))

        self.v = nn.Parameter(torch.randn(embed , embed))
        self.v_bias = nn.Parameter(torch.randn(embed))

        self.o = nn.Parameter(torch.randn(embed , embed))
        self.o_bias = nn.Parameter(torch.randn(embed))

        self.dropout = nn.Dropout(dropout)

        for name in [self.q , self.k , self.v , self.o]:
            nn.init.xavier_uniform_(name)
        for name in [self.q_bias , self.k_bias , self.v_bias , self.o_bias]:
            nn.init.zeros_(name)
    
    @staticmethod
    def attention(q , k , v , mask , dropout , pastlayer):
        dk = q.size(-1)

        if pastlayer is not None:
            k_ , v_ = pastlayer
            k = torch.cat([k_ , k] , dim = -2)
            v = torch.cat([v_ , v] , dim = -2)
        
        present = (k , v)
        
        scores = (q @ k.transpose(-2 , -1)) / math.sqrt(dk)
        if mask is not None:
            print(f"  original mask shape: {mask.shape}")
            
            current_kv_len = k.size(-2)
            current_q_len = q.size(-2)
            
            if pastlayer is not None:
                causal_mask = torch.tril(torch.ones(current_kv_len, current_kv_len, device=scores.device))

                causal_mask = causal_mask[-current_q_len:, :].unsqueeze(0).unsqueeze(0)
                if mask.size(-1) >= current_kv_len:
                    
                    padding_mask = mask[..., -current_kv_len:]
                else:
                    padding_mask = torch.ones_like(mask[..., :current_kv_len])
                causal_mask = causal_mask.expand(mask.size(0), -1, -1, -1)
                
                # Combine masks
                mask = causal_mask * padding_mask
                print(f"  combined mask shape: {mask.shape}")
                
            else:
                if mask.size(-1) != current_kv_len:
                    mask = mask[..., :current_kv_len]
                
                if mask.dim() == 3:
                    mask = mask.unsqueeze(1)
            
            scores = scores.masked_fill(mask == 0, -1e9)
        
        max_ = torch.max(scores , dim = -1 , keepdim = True)[0]
        sc = torch.exp(scores - max_)
        scores = sc / torch.sum(sc , dim = -1 , keepdim = True)

        if dropout is not None:
            scores = dropout(scores)
        
        return (scores @ v), present
        

    def forward(self , q , k , v , mask = None , pastlayer = None): #batch , seq , embed
        q = q @ self.q + self.q_bias
        k = k @ self.k + self.k_bias
        v = v @ self.v + self.v_bias

        batch = q.size(0)
        seq = q.size(1)
        k_seq = k.size(1)

        query = q.view(batch , seq , self.num_head , self.dk).permute(0 , 2 , 1 , 3)
        key = k.view(batch , k_seq , self.num_head , self.dk).permute(0 , 2 , 1 , 3)
        value = v.view(batch , k_seq , self.num_head , self.dk).permute(0 , 2 , 1 , 3)

        attn , present = MulticlassAttention.attention(query , key , value , mask , self.dropout , pastlayer)

        attn = attn.permute(0, 2, 1, 3).contiguous().view(batch , seq , self.embed)

        return attn @ self.o + self.o_bias , present


class EncoderBlock(nn.Module):
    def __init__(self , attn , ffn , feat , dropout : float):
        super().__init__()
        self.attn = attn
        self.ffn = ffn
        self.res1 = Residual(feat , dropout)
        self.res2 = Residual(feat , dropout)

    def forward(self , x , mask):
        selfattn , _ = self.attn(x , x , x , mask , pastlayer = None)
        x = self.res1(x , lambda _: selfattn)
        x = self.res2(x , lambda a: self.ffn(a))
        return x
    
class Encoder(nn.Module):
    def __init__(self , feat , layers):
        super().__init__()
        self.norm = LayerNormalizaton(feat)
        self.layers = layers
    
    def forward(self , x , mask):
        for layer in self.layers:
            x = layer(x , mask)

        return self.norm(x)

class DecoderBlock(nn.Module):
    def __init__(self , selfattn , crossattn , ffn , feat , dropout : float):
        super().__init__()
        self.selfattn = selfattn
        self.crossattn = crossattn
        self.ffn = ffn
        self.res1 = Residual(feat , dropout)
        self.res2 = Residual(feat , dropout)
        self.res3 = Residual(feat , dropout)
    
    def forward(self , x , encoderout , selfmask , crossmask , pastlayer):
        attnpresent , _ = pastlayer if pastlayer else (None , None)

        selfattn , selfpresent = self.selfattn(x , x , x , selfmask , attnpresent)
        crossattn , _ = self.crossattn(x , encoderout , encoderout , crossmask , pastlayer = None)

        x = self.res1(x , lambda _ : selfattn)
        x = self.res2(x , lambda _ : crossattn)
        x = self.res3(x , lambda x : self.ffn(x))

        return x , (selfpresent , None)
    
class Decoder(nn.Module):
    def __init__(self , feat , layers):
        super().__init__()
        self.norm = LayerNormalizaton(feat)
        self.layers = layers
    
    def forward(self , x , enc , tgt_mask , src_mask , pastvalues):
        new = []

        for i , layer in enumerate(self.layers):
            past = pastvalues[i] if pastvalues else None
            x , layerpast = layer(x , enc , tgt_mask , src_mask , past)

            new.append(layerpast)
        
        return self.norm(x) , new

class Projection(nn.Module):
    def __init__(self , embed , vocab):
        super().__init__()
        self.linear = nn.Parameter(torch.randn(vocab , embed))
        self.bias = nn.Parameter(torch.randn(vocab))

        nn.init.xavier_uniform_(self.linear)
        nn.init.zeros_(self.bias)

    def forward(self , x):
        out = x @ self.linear.T + self.bias
        return out
    
class Transformer(nn.Module):
    def __init__(self , encoder : Encoder , decoder : Decoder , src_emb , tgt_emb , src_pos , tgt_pos , proj_layer):
        super().__init__()
        self.encoder_ = encoder
        self.decoder_ = decoder
        self.src_emb = src_emb
        self.src_pos = src_pos
        self.tgt_emb = tgt_emb
        self.tgt_pos = tgt_pos
        self.proj_layer = proj_layer 
    

    def encode(self, src, src_mask):
        src = self.src_emb(src)
        src = self.src_pos(src)
        return self.encoder_(src, src_mask)
    
    def decode(self, tgt, encoder_out, tgt_mask, src_mask, pastvalues=None):
        tgt = self.tgt_emb(tgt)
        tgt = self.tgt_pos(tgt)
        return self.decoder_(tgt, encoder_out, tgt_mask, src_mask, pastvalues)
    
    def projection(self, x):
        return self.proj_layer(x)

def build(src_vocab , tgt_vocab , src_seq , tgt_seq , embed = 512 , num_head = 8 , num_layer = 6 , dropout =  float(0.1) , hid_dim = 2048):
    src_embed = Embedding(embed = embed , vocab = src_vocab)
    tgt_embed = Embedding(embed = embed , vocab = tgt_vocab)

    src_pos = PositionalEmbedding(embed = embed , seq_len = src_seq , dropout = dropout)
    tgt_pos = PositionalEmbedding(embed = embed , seq_len = tgt_seq , dropout = dropout)

    enc_blocks = []

    for _ in range(num_layer):
        enc_attn = MulticlassAttention(embed , num_head , dropout )
        ffn = FFN(hid_dim , embed , dropout)
        enc_block = EncoderBlock(attn = enc_attn , ffn = ffn , feat = embed , dropout = dropout)
        enc_blocks.append(enc_block)

    dec_blocks = []

    for _ in range(num_layer):
        dec_selfAttn = MulticlassAttention(embed , num_head , dropout)
        dec_crossAttn = MulticlassAttention(embed , num_head , dropout)
        ffn = FFN(hid_dim , embed , dropout)
        dec_block = DecoderBlock(selfattn = dec_selfAttn , crossattn = dec_crossAttn , ffn = ffn , feat = embed , dropout = dropout)
        dec_blocks.append(dec_block)

    encoder = Encoder(feat = embed , layers = nn.ModuleList(enc_blocks))
    decoder = Decoder(feat = embed , layers = nn.ModuleList(dec_blocks))

    proj_layer = Projection(embed = embed , vocab = tgt_vocab)

    transformer = Transformer(encoder = encoder , decoder = decoder , src_emb = src_embed , tgt_emb = tgt_embed , src_pos = src_pos , tgt_pos = tgt_pos , proj_layer = proj_layer)

    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return transformer
        
transformer = build(src_vocab = 50 , tgt_vocab = 50 , src_seq = 8 , tgt_seq = 8 )

In [None]:
import math

def kaiming_initializer(self):
    for name, param in self.named_parameters():
        if 'weight' in name.lower() or 'wq' in name or 'wk' in name or 'wv' in name or 'wo' in name:
            if param.dim() > 1:  
                fan_in = param.size(1)  
                std = math.sqrt(2.0 / fan_in)
                param.data.normal_(0, std)
            else: 
                pass
        elif 'bias' in name.lower() or name.endswith('b'):
            if param.dim() == 1:  
                param.data.fill_(0.0)

def xavier_initializer(self):
    for name, param in self.named_parameters():
        if 'weight' in name.lower() or 'wq' in name or 'wk' in name or 'wv' in name or 'wo' in name:
            if param.dim() > 1:  
                fan_in = param.size(1)
                fan_out = param.size(0)
                std = math.sqrt(2.0 / (fan_in + fan_out))
                param.data.normal_(0, std)
            else:
                pass
        elif 'bias' in name.lower() or 'qb' in name or 'kb' in name or 'vb' in name:  
            if param.dim() == 1:
                param.data.fill_(0.0)

all tests are generated by AI to test 

In [2]:
import torch
import torch.nn as nn
import math

# Assuming your model is already built and initialized
batch = 1
seq_len = 5

# Create the model first
model = build(
    src_vocab=30,
    tgt_vocab=30,
    src_seq=seq_len,
    tgt_seq=seq_len,
    embed=512,
    num_head=8,
    num_layer=6,
    dropout=0.1,
    hid_dim=2048
)

# Create necessary masks
src_mask = torch.ones(batch, 1, 1, seq_len)  # Encoder mask (no padding)
tgt_mask_full = torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len)

print("="*60)
print("STEP 1: PREFILL (Processing the prompt)")
print("="*60)

# --- STEP 1: PREFILL (The Prompt) ---
input_5 = torch.randint(0, 30, (batch, seq_len))
print(f"Input shape: {input_5.shape}")

enc_out = model.encode(input_5, src_mask)
print(f"Encoder output shape: {enc_out.shape}")

output_5, cache_5 = model.decode(
    input_5, 
    enc_out, 
    tgt_mask_full, 
    src_mask, 
    pastvalues=None
)

print(f"Step 1 Output shape: {output_5.shape}")
print(f"Cache size: {cache_5[0][0][0].shape[2]} tokens")  # Should be 5

print("\n" + "="*60)
print("STEP 2: GENERATION (The 'Aha!' Moment)")
print("="*60)

# --- STEP 2: GENERATION ---
# Take the last token from output_5 OR a random token
logits = model.projection(output_5[:, -1:, :]) # Shape: (1, 1, 30)

# 2. Now pick the token from the 30 possibilities
input_1 = logits.argmax(dim=-1) # Shape: (1, 1)  # Use model's prediction
# OR for testing: input_1 = torch.randint(0, 30, (batch, 1))

print(f"Step 2 Input shape: {input_1.shape}")

# Mask for single token - allows attending to ALL past tokens (5 previous + current = 6)
# Shape: [batch, 1, 1, current_total_length]
current_total_length = cache_5[0][0][0].shape[2] + 1  # 5 + 1 = 6
tgt_mask_inference = torch.ones(batch, 1, 1, current_total_length)
print(f"Inference mask shape: {tgt_mask_inference.shape}")

# Decode ONLY the new token with cache
output_1, cache_6 = model.decode(
    input_1,                # Single token [1, 1]
    enc_out,                # Same encoder output
    tgt_mask_inference,      # Mask for attending to all 6 tokens
    src_mask,                # Same source mask
    pastvalues=cache_5       # Cache from step 1
)

print(f"\n--- TEST RESULTS ---")
print(f"Step 2 Output Shape: {output_1.shape}")     # Should be (1, 1, 512)
print(f"Cache before: {cache_5[0][0][0].shape[2]} tokens")  # 5
print(f"Cache after: {cache_6[0][0][0].shape[2]} tokens")   # Should be 6

# Verify cache grew correctly
if cache_6[0][0][0].shape[2] == cache_5[0][0][0].shape[2] + 1:
    print("\n✅ SUCCESS: Cache grew from 5 to 6 tokens!")
else:
    print("\n❌ FAILURE: Cache didn't grow correctly")


STEP 1: PREFILL (Processing the prompt)
Input shape: torch.Size([1, 5])
  original mask shape: torch.Size([1, 1, 1, 5])
  original mask shape: torch.Size([1, 1, 1, 5])
  original mask shape: torch.Size([1, 1, 1, 5])
  original mask shape: torch.Size([1, 1, 1, 5])
  original mask shape: torch.Size([1, 1, 1, 5])
  original mask shape: torch.Size([1, 1, 1, 5])
Encoder output shape: torch.Size([1, 5, 512])
  original mask shape: torch.Size([1, 1, 5, 5])
  original mask shape: torch.Size([1, 1, 1, 5])
  original mask shape: torch.Size([1, 1, 5, 5])
  original mask shape: torch.Size([1, 1, 1, 5])
  original mask shape: torch.Size([1, 1, 5, 5])
  original mask shape: torch.Size([1, 1, 1, 5])
  original mask shape: torch.Size([1, 1, 5, 5])
  original mask shape: torch.Size([1, 1, 1, 5])
  original mask shape: torch.Size([1, 1, 5, 5])
  original mask shape: torch.Size([1, 1, 1, 5])
  original mask shape: torch.Size([1, 1, 5, 5])
  original mask shape: torch.Size([1, 1, 1, 5])
Step 1 Output shap

In [3]:
# Starting prompt (already processed in Step 1)
current_input = input_1 
current_cache = cache_6
generated_tokens = []

for _ in range(10):
    # 1. Get the current total length for the mask
    total_len = current_cache[0][0][0].shape[2] + 1
    mask = torch.ones(1, 1, 1, total_len)
    
    # 2. Forward pass
    out, next_cache = model.decode(current_input, enc_out, mask, src_mask, current_cache)
    
    # 3. Project to vocab and pick next token
    logits = model.projection(out)
    next_token = logits.argmax(dim=-1)
    
    # 4. Update for next iteration
    generated_tokens.append(next_token.item())
    current_input = next_token
    current_cache = next_cache

print(f"Generated Token IDs: {generated_tokens}")

  original mask shape: torch.Size([1, 1, 1, 7])
  combined mask shape: torch.Size([1, 1, 1, 7])
  original mask shape: torch.Size([1, 1, 1, 5])
  original mask shape: torch.Size([1, 1, 1, 7])
  combined mask shape: torch.Size([1, 1, 1, 7])
  original mask shape: torch.Size([1, 1, 1, 5])
  original mask shape: torch.Size([1, 1, 1, 7])
  combined mask shape: torch.Size([1, 1, 1, 7])
  original mask shape: torch.Size([1, 1, 1, 5])
  original mask shape: torch.Size([1, 1, 1, 7])
  combined mask shape: torch.Size([1, 1, 1, 7])
  original mask shape: torch.Size([1, 1, 1, 5])
  original mask shape: torch.Size([1, 1, 1, 7])
  combined mask shape: torch.Size([1, 1, 1, 7])
  original mask shape: torch.Size([1, 1, 1, 5])
  original mask shape: torch.Size([1, 1, 1, 7])
  combined mask shape: torch.Size([1, 1, 1, 7])
  original mask shape: torch.Size([1, 1, 1, 5])
  original mask shape: torch.Size([1, 1, 1, 8])
  combined mask shape: torch.Size([1, 1, 1, 8])
  original mask shape: torch.Size([1, 1,

In [129]:
import math

def kaiming_initializer(self):
    for name, param in self.named_parameters():
        if 'weight' in name.lower() or 'wq' in name or 'wk' in name or 'wv' in name or 'wo' in name:
            if param.dim() > 1:  
                fan_in = param.size(1)  
                std = math.sqrt(2.0 / fan_in)
                param.data.normal_(0, std)
            else: 
                pass
        elif 'bias' in name.lower() or name.endswith('b'):
            if param.dim() == 1:  
                param.data.fill_(0.0)

def xavier_initializer(self):
    for name, param in self.named_parameters():
        if 'weight' in name.lower() or 'wq' in name or 'wk' in name or 'wv' in name or 'wo' in name:
            if param.dim() > 1:  
                fan_in = param.size(1)
                fan_out = param.size(0)
                std = math.sqrt(2.0 / (fan_in + fan_out))
                param.data.normal_(0, std)
            else:
                pass
        elif 'bias' in name.lower() or 'qb' in name or 'kb' in name or 'vb' in name:  
            if param.dim() == 1:
                param.data.fill_(0.0)