# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Positional Encoding
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch-first tensor.

    A table of shape ``(1, max_len, d_model)`` is precomputed once. Even
    feature indices hold ``sin(pos * w_k)`` and odd feature indices hold
    ``cos(pos * w_k)``, where ``w_k = exp(-2k * ln(10000) / d_model)``.

    Args:
        d_model: Size of the last (feature) dimension of the input.
        max_len: Longest sequence length the table can serve.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float)
            * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine
        # column, so trim the angles to the number of odd feature indices.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Registered as a buffer so the table follows the module through
        # .to()/.cuda(); non-persistent so state_dict keys stay unchanged.
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return ``x`` plus the positional table for its first ``x.size(1)`` positions.

        Args:
            x: Tensor whose dim 1 is the sequence axis and whose last dim is
                ``d_model`` (the table is sliced by ``x.size(1)``).

        Raises:
            ValueError: If the sequence is longer than ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(1)}"
            )
        # .to() is a no-op when the buffer already lives on x's device.
        return x + self.pe[:, :seq_len].to(x.device)

# Transformer Encoder Block
class TransformerEncoderLayer(nn.Module):
    """Post-norm encoder block: self-attention followed by a feed-forward net.

    Each sublayer is wrapped in a residual connection and a LayerNorm.
    Dropout is applied to the feed-forward output only.

    Args:
        d_model: Feature size of the input and output.
        nhead: Number of attention heads.
        dim_feedforward: Hidden size of the feed-forward sublayer.
        dropout: Dropout probability for the feed-forward output.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        # batch_first=True matches the (batch, seq, feature) layout that
        # PositionalEncoding and CausalSelfAttention in this file use.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, key_padding_mask=None, attn_mask=None):
        """Run the block on ``src``.

        Args:
            src: Input tensor of shape ``(batch, seq, d_model)``.
            key_padding_mask: Optional ``(batch, seq)`` mask of key positions
                to ignore, forwarded to ``nn.MultiheadAttention``.
            attn_mask: Optional attention mask, forwarded to
                ``nn.MultiheadAttention``.

        Returns:
            Tensor with the same shape as ``src``.
        """
        # Attention weights are not used, so skip computing them.
        attn_out = self.self_attn(
            src,
            src,
            src,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        src = self.norm1(src + attn_out)
        ff_out = self.linear2(F.relu(self.linear1(src)))
        return self.norm2(src + self.dropout(ff_out))




class CausalSelfAttention(nn.Module):
    """Batch-first multi-head self-attention restricted to past and current positions.

    Args:
        d_model: Feature size of the input and output.
        nhead: Number of attention heads.
    """

    def __init__(self, d_model, nhead):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)

    def forward(self, x):
        """Attend over ``x`` (sequence axis is dim 1) and return the attention output."""
        n = x.size(1)
        # Additive mask: -inf strictly above the diagonal (future positions),
        # 0 on and below it.
        blocked = torch.full((n, n), float('-inf'), device=x.device)
        causal_mask = torch.triu(blocked, diagonal=1)
        attn_out, _ = self.attn(x, x, x, attn_mask=causal_mask)
        return attn_out


class TransformerDecoderLayer(nn.Module):
    """Post-norm decoder block: causal self-attention followed by a feed-forward net.

    Each sublayer is wrapped in a residual connection and a LayerNorm.
    Dropout is applied to the feed-forward output only.

    Args:
        d_model: Feature size of the input and output.
        nhead: Number of attention heads.
        dim_feedforward: Hidden size of the feed-forward sublayer.
        dropout: Dropout probability for the feed-forward output.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.self_attn = CausalSelfAttention(d_model, nhead)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output=None):
        """Run the block on ``x``.

        Args:
            x: Input tensor; dim 1 is the sequence axis used for the causal mask.
            encoder_output: Accepted so existing call sites keep working.
                NOTE(review): this layer defines no cross-attention sublayer,
                so the value is currently not used. Add one if
                encoder-decoder conditioning is intended.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # CausalSelfAttention.forward takes only ``x`` and already returns the
        # attention output tensor, so no extra arguments and no ``[0]`` here
        # (indexing would select the first batch element).
        attn_out = self.self_attn(x)
        x = self.norm1(x + attn_out)
        ff_out = self.linear2(torch.relu(self.linear1(x)))
        return self.norm2(x + self.dropout(ff_out))

# Encoder and decoder stacks (embedding + positional encoding + layers + vocabulary projection)
class Encoder(nn.Module):
    """Token embedding, positional encoding, encoder layers and a vocabulary projection.

    Args:
        vocab_size: Number of embedding rows and size of the output projection.
        d_model: Embedding / hidden feature size.
        nhead: Number of attention heads in every layer.
        num_layers: Number of stacked ``TransformerEncoderLayer`` blocks.
    """

    def __init__(self, vocab_size, d_model, nhead, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        self.encoder_layers = nn.ModuleList(
            [TransformerEncoderLayer(d_model, nhead) for _ in range(num_layers)]
        )
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x, key_padding_mask=None, attn_mask=None):
        """Encode token ids and project every position to vocabulary logits.

        Args:
            x: Integer token ids; assumed ``(batch, seq)`` because
                ``PositionalEncoding`` treats dim 1 as the sequence axis.
            key_padding_mask: Optional padding mask passed to every layer.
            attn_mask: Optional attention mask passed to every layer.

        Returns:
            Tensor of shape ``x.shape + (vocab_size,)``.
        """
        x = self.embedding(x)
        x = self.positional_encoding(x)
        for layer in self.encoder_layers:
            # The layer's forward takes the padding mask as its second
            # argument; calling it with ``x`` alone raised a TypeError.
            x = layer(x, key_padding_mask, attn_mask=attn_mask)
        return self.fc_out(x)

class Decoder(nn.Module):
    """Token embedding, positional encoding, decoder layers and a vocabulary projection.

    Args:
        vocab_size: Number of embedding rows and size of the output projection.
        d_model: Embedding / hidden feature size.
        nhead: Number of attention heads in every layer.
        num_layers: Number of stacked ``TransformerDecoderLayer`` blocks.
    """

    def __init__(self, vocab_size, d_model, nhead, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        self.decoder_layers = nn.ModuleList(
            [TransformerDecoderLayer(d_model, nhead) for _ in range(num_layers)]
        )
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x, encoder_output=None):
        """Decode token ids and project every position to vocabulary logits.

        Args:
            x: Integer token ids; assumed ``(batch, seq)`` because
                ``PositionalEncoding`` treats dim 1 as the sequence axis.
            encoder_output: Optional encoder result handed to every decoder
                layer as its second argument.

        Returns:
            Tensor of shape ``x.shape + (vocab_size,)``.
        """
        x = self.embedding(x)
        x = self.positional_encoding(x)
        for layer in self.decoder_layers:
            # The layer's forward takes ``encoder_output`` as its second
            # argument; calling it with ``x`` alone raised a TypeError.
            x = layer(x, encoder_output)
        return self.fc_out(x)


# Model hyperparameters passed to the model constructor below.
vocab_size = 30000  # Urdu vocabulary size
d_model = 512  # embedding / hidden feature size
nhead = 8  # attention heads per layer
num_layers = 6  # number of stacked layers

# Instantiate model
# NOTE(review): `CustomTransformer` is not defined anywhere in the visible
# source; unless it is defined elsewhere this line raises NameError. `Encoder`
# and `Decoder` above take the same four arguments — confirm which is intended.
model = CustomTransformer(vocab_size, d_model, nhead, num_layers)

# Bare expression; presumably left so a notebook cell displays the module repr.
model
# Dummy input (disabled smoke test)
# NOTE(review): the shape comments below say (seq_length, batch_size), but
# PositionalEncoding slices positions by x.size(1), i.e. treats dim 1 as the
# sequence axis — confirm the intended layout before re-enabling.
# dummy_input = torch.randint(0, vocab_size, (10, 32))  # (seq_length, batch_size)
# output = model(dummy_input)

# print("Output shape:", output.shape)  # Expected: (seq_length, batch_size, vocab_size)
