# Import Modules

In [53]:
import torch
import torch.nn as nn
import numpy as np
import math

# Naive Transformer Version 1

In [62]:
class InputEmbeddings(nn.Module):
    def __init__(self, vocab_size, embed_size):
        super().__init__()
        self.embed_size = embed_size
        self.vocab_size = vocab_size
        self.embedding  = nn.Embedding(vocab_size, embed_size)
    
    def forward(self, x):
        return self.embedding(x) * math.sqrt(self.embed_size)

class PositionalEncoding(nn.Module):

    def __init__(self, embed_size:int, seq_len:int, dropout: float):
        super().__init__()
        self.embed_size = embed_size
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        # Matrix shape (embed_size, seq_len)
        pe = torch.zeros(seq_len, embed_size)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1) # shape: (seq_len, 1)
        div_term = torch.exp(torch.arange(0, embed_size, 2).float() * -(math.log((1000.0))) / embed_size)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0) # Shape: (1, seq_len, embed_size)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False)
        x = (self.pe[:, :x.size(1)].detach()).requires_grad_(False)
        return self.dropout(x)
    
class LayerNorm(nn.Module):
    
    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(1)) ## Multiplied
        self.bias = nn.Parameter(torch.zeros(1)) ## Added
    
    def forward(self, x):
        mean = x.mean(dim = -1, keepdim=True)
        std = x.std(dim = -1, keepdim=True)
        x = (x - mean) / (std + self.eps)
        return self.alpha * (x - mean) / (std + self.eps) + self.bias
    
class FeedForwardBlock(nn.Module):

    def __init__(self, embed_size: int, dff: int, dropout: float) -> None:
        super().__init__()
        self.linear1 = nn.Linear(embed_size, dff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dff, embed_size)
        self.relu = nn.ReLU()
    
    def forward (self, x):
        # (Batch, Seq_len, Embed_size) -> (Batch, Seq_len, embed_size * expansion) --> (Batch, Seq_len, Embed_size)
        return self.linear2(self.dropout(self.relu(self.linear1(x))))

class AttentionBlock(nn.Module):

    def __init__(self, embed_size: int, heads: int, dropout: float) -> None:
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads

        assert embed_size % heads == 0, "Embedding size must be divisible by heads"


        self.d_k = embed_size // heads # Dimension of vector seen by each head
        self.w_q = nn.Linear(embed_size, embed_size, bias=False) # Wq
        self.w_k = nn.Linear(embed_size, embed_size, bias=False) # Wk
        self.w_v = nn.Linear(embed_size, embed_size, bias=False) # Wv
        self.w_o = nn.Linear(embed_size, embed_size, bias=False) # Wo

        self.dropout = nn.Dropout(dropout)
    
    @staticmethod
    def scaled_dot_product_attention(query, key, value, mask, dropout: nn.Dropout):
        # (Batch, Heads, sequence_length, embed_size)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.shape[-1])
        # scores = torch.einsum("bhqd,bhkd->bhqk", [query, key]) / math.sqrt(query.shape[-1])
        if mask is not None: 
            scores = scores.masked_fill(mask == 0, float('-1e20'))
        if dropout is not None:
            scores = dropout(scores)

        # (batch, heads, sequence_len, d_k) --> (batch, heads, sequence_len, sequence_len)
        attention_score = torch.softmax(scores, dim=-1) # (batch, h, seq_len, seq_len)
        out = torch.matmul(attention_score, value)
        # out = torch.einsum("bhqk, bhkd -> bhqd", [attention_score, value])
        return out, attention_score
    
    def forward(self, q, k, v, mask):
        query = self.w_q(q) 
        key = self.w_k(k) 
        value = self.w_v(v) 

        # (batch, seq_len, embed_size) --> (batch, seq_len, h, d_k) --> (batch, h, seq_len, d_k)
        query = query.view(query.shape[0], query.shape[1], self.heads, self.d_k).transpose(1, 2)
        key = key.view(key.shape[0], key.shape[1], self.heads, self.d_k).transpose(1, 2)
        value = value.view(value.shape[0], value.shape[1], self.heads, self.d_k).transpose(1, 2)

        # Calculate attention
        x, self.attention_scores = self.scaled_dot_product_attention(query, key, value, mask, self.dropout)
        
        # (batch, h, seq_len, d_k) --> (batch, seq_len, h, d_k) --> (batch, seq_len, embed_size)
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.heads * self.d_k)

        # (batch, seq_len, embed_size) --> (batch, seq_len, embed_size)  
        return self.w_o(x)
    
class ResidualConnection(nn.Module):

    def __init__(self, features: int, dropout: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNorm(features)
    
    def forward(self, x, sublayer):

        residual = x
        x = self.norm(x)
        x = sublayer(x)
        x = residual + self.dropout(x)

        # return x + self.dropout(sublayer(self.norm(x)))
        return x
    
class EncoderBlock(nn.Module):
    def __init__(self,
                 features: int,
                 attention_block: AttentionBlock,
                 feed_forward_block: FeedForwardBlock,
                 dropout: float,
                 ):
        super().__init__()
        self.attention_block = attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(2)])
    
    def forward(self, x, source_mask):
        x = self.residual_connections[0](x, lambda x: self.attention_block(x, x, x, source_mask))
        x = self.residual_connections[1](x, self.feed_forward_block)
        return x
    
class Encoder(nn.Module):

    def __init__(self,
                 features: int,
                 layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNorm(features)

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

class DecoderBlock(nn.Module):

    def __init__(
        self,
        features: int,
        attention_block: AttentionBlock,
        cross_attention_block: AttentionBlock,
        feed_forward_block: FeedForwardBlock,
        dropout: float,
    ) -> None:
        super().__init__()
        self.attention_block = attention_block
        self.feed_forward_block = feed_forward_block
        self.cross_attention_block = cross_attention_block
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(3)]
    )
    
    def forward(self, x, enc_out, source_mask, target_mask):
        x = self.residual_connections[0](x, lambda x: self.attention_block(x,x,x,target_mask))
        x = self.residual_connections[1](x, lambda x: self.cross_attention_block(x,enc_out,enc_out,source_mask))
        x = self.residual_connections[2](x, self.feed_forward_block)
        return x
    
class Decoder(nn.Module):

    def __init__(self,
                 features: int,
                 layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNorm(features)
    
    def forward(self, x, enc_out, source_mask, target_mask):
        for layer in self.layers:
            x = layer(x, enc_out, source_mask, target_mask)
        return self.norm(x)

class FullyConnected(nn.Module):
    def __init__(self, embed_size: int, target_vocab_size: int) -> None:
        super().__init__()
        self.fc = nn.Linear(embed_size, target_vocab_size)
    
    def forward(self, x):
        x = self.fc(x)
        x = torch.log_softmax(x, dim=-1)

        # return torch.log_softmax(self.fc(x), dim=-1)
        return x
    
class Transformer(nn.Module):
    def __init__(self, encoder, decoder, input_embedding, target_embedding,
                 positional_encoding, target_positional_encoding, fc):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.input_embedding = input_embedding
        self.target_embedding = target_embedding
        self.positional_encoding = positional_encoding
        self.target_positional_encoding = target_positional_encoding
        self.fc = fc

    def encode(self, source, source_mask):
        source = self.input_embedding(source)
        source = self.positional_encoding(source)
        return self.encoder(source, source_mask)

    def decode(self, target, enc_out, source_mask, target_mask):
        target = self.target_embedding(target)
        target = self.target_positional_encoding(target)
        return self.decoder(target, enc_out, source_mask, target_mask)

    def forward(self, source, target, source_mask, target_mask):
        enc_out = self.encode(source, source_mask)
        dec_out = self.decode(target, enc_out, source_mask, target_mask)
        return self.fc(dec_out)


def my_transformer(source_vocab_size: int, 
                   target_vocab_size: int, 
                   source_seq_len: int, 
                   target_seq_len: int,
                   embed_size: int = 512, 
                   Nx: int = 6, 
                   heads: int = 8, 
                   dff: int = 2048, 
                   dropout: float = 0.1):
    
    source_embed = InputEmbeddings(source_vocab_size, embed_size)
    target_embed = InputEmbeddings(target_vocab_size, embed_size)
    source_pe = PositionalEncoding(embed_size, source_seq_len, dropout)
    target_pe = PositionalEncoding(embed_size, target_seq_len, dropout)

    encoder_blocks = nn.ModuleList([
        EncoderBlock(embed_size,
                     AttentionBlock(embed_size, heads, dropout),
                     FeedForwardBlock(embed_size, dff, dropout),
                     dropout)
        for _ in range(Nx)
    ])

    decoder_blocks = nn.ModuleList([
        DecoderBlock(embed_size,
                     AttentionBlock(embed_size, heads, dropout),
                     AttentionBlock(embed_size, heads, dropout),
                     FeedForwardBlock(embed_size, dff, dropout),
                     dropout)
        for _ in range(Nx)
    ])

    encoder = Encoder(embed_size, encoder_blocks)
    decoder = Decoder(embed_size, decoder_blocks)
    fc = FullyConnected(embed_size, target_vocab_size)

    model = Transformer(encoder, decoder, source_embed, target_embed, source_pe, target_pe, fc)

    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return model

## Testing on a Dummy Data

In [63]:
# ==== TEST ====
# Dummy test
batch_size = 2
src_vocab_size = 10000
tgt_vocab_size = 10000
src_seq_len = 10
tgt_seq_len = 12

model = my_transformer(src_vocab_size, tgt_vocab_size, src_seq_len, tgt_seq_len)

source = torch.randint(0, src_vocab_size, (batch_size, src_seq_len))
target = torch.randint(0, tgt_vocab_size, (batch_size, tgt_seq_len))
source_mask = target_mask = None

output = model(source, target, source_mask, target_mask)
# print("Output shape:", output.shape)
print(output)  # Expected: (batch_size, tgt_seq_len, tgt_vocab_size)

tensor([[[-9.1968, -9.2240, -9.2042,  ..., -9.1960, -9.1761, -9.2444],
         [-9.1968, -9.2240, -9.2042,  ..., -9.1960, -9.1761, -9.2444],
         [-9.1968, -9.2240, -9.2042,  ..., -9.1960, -9.1761, -9.2444],
         ...,
         [-9.1967, -9.2241, -9.2042,  ..., -9.1960, -9.1762, -9.2444],
         [-9.1966, -9.2241, -9.2042,  ..., -9.1959, -9.1762, -9.2443],
         [-9.1967, -9.2241, -9.2042,  ..., -9.1960, -9.1762, -9.2444]]],
       grad_fn=<LogSoftmaxBackward0>)
