In [53]:
import torch
import torch.nn as nn
import math

### Scaled Dot‑Product Attention

In [54]:
class Scaled_Dot_Product(nn.Module):
    """Scaled dot-product attention: ``softmax(Q @ K^T / sqrt(d_k)) @ V``.

    The module owns no parameters; it only combines the tensors passed to
    ``forward``.
    """

    def __init__(self) :
        super().__init__()

    def forward(self , Q, K, V, d_k, mask = None) :
        """Compute the attention output and the attention weights.

        Args:
            Q: Query tensor. Its last dimension must match the last dimension of K.
            K: Key tensor; its last two dimensions are swapped before the product.
            V: Value tensor, multiplied by the attention weights.
            d_k: Scaling dimension; scores are divided by ``sqrt(d_k)``.
            mask: Optional tensor that broadcasts against the score tensor.
                Positions where ``mask == 0`` are excluded from attention.
                A 2-D mask gets singleton axes inserted at positions 1 and 2
                (presumably a ``(batch, key_len)`` padding mask - confirm with callers).

        Returns:
            Tuple ``(output, attn_weights)``. Query rows for which every key
            is masked get all-zero weights (and therefore a zero output row)
            instead of NaN.
        """
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is None:
            attn_weights = torch.softmax(scores, dim=-1)
        else:
            if mask.dim() == 2:
                mask = mask.unsqueeze(1).unsqueeze(2)

            blocked = mask == 0
            scores = scores.masked_fill(blocked, float('-inf'))
            attn_weights = torch.softmax(scores, dim=-1)

            # A row that is entirely -inf makes softmax return NaN, which would
            # then spread through every later layer. Zero those rows explicitly;
            # rows with at least one visible key are left untouched.
            fully_blocked = blocked.all(dim=-1, keepdim=True)
            attn_weights = attn_weights.masked_fill(fully_blocked, 0.0)

        finale_output = torch.matmul(attn_weights, V)

        return finale_output, attn_weights

### Multi‑Head Attention

In [55]:
class Multi_head(nn.Module) :
    """Multi-head attention: project, split into heads, attend, merge, project.

    Queries, keys and values are each passed through their own
    ``d_model -> d_model`` linear layer, reshaped into ``num_heads`` heads of
    width ``d_model // num_heads``, run through ``Scaled_Dot_Product`` and
    finally merged and passed through the output projection ``W_O``.
    """

    def __init__(self, num_heads, d_model):
        """
        Args:
            num_heads: Number of attention heads.
            d_model: Width of the inputs and outputs; must be divisible by
                ``num_heads``.

        Raises:
            ValueError: If ``d_model`` is not a multiple of ``num_heads``.
        """
        super().__init__()

        # Without this check the mismatch only surfaces later as an opaque
        # reshape error inside forward().
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )

        self.num_heads = num_heads
        self.d_model = d_model
        self.d_k = d_model // num_heads

        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)

        self.spd = Scaled_Dot_Product()

    def _split_heads(self, projected, batch_size):
        """Reshape ``(batch, length, d_model)`` to ``(batch, num_heads, length, d_k)``."""
        # -1 lets each tensor keep its own length, so keys/values may be
        # longer or shorter than the queries (cross-attention).
        return projected.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, Q, K, V, mask = None) :
        """Run multi-head attention.

        Args:
            Q: Query tensor, viewed as ``(batch, query_len, d_model)``.
            K: Key tensor, viewed as ``(batch, key_len, d_model)``.
            V: Value tensor, viewed as ``(batch, key_len, d_model)``.
            mask: Optional mask forwarded unchanged to the attention module.

        Returns:
            Tensor of shape ``(batch, query_len, d_model)``.
        """
        batch_size = Q.size(0)

        Q_proj = self._split_heads(self.W_Q(Q), batch_size)
        K_proj = self._split_heads(self.W_K(K), batch_size)
        V_proj = self._split_heads(self.W_V(V), batch_size)

        attn_output, _ = self.spd(Q_proj, K_proj, V_proj, self.d_k, mask)

        # Undo the head split: (batch, heads, query_len, d_k) -> (batch, query_len, d_model).
        output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.W_O(output)

### Feed-Forward Network

In [56]:
class PositionwiseFFN(nn.Module):
    """Two-layer feed-forward block: Linear -> ReLU -> Dropout -> Linear.

    The first linear layer maps ``d_model`` features to ``d_ff``; the second
    maps them back to ``d_model``.
    """

    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        """
        Args:
            d_model: Input and output feature width.
            d_ff: Width of the hidden layer.
            dropout: Dropout probability applied after the ReLU.
        """
        super().__init__()
        expand = nn.Linear(d_model, d_ff)
        project = nn.Linear(d_ff, d_model)
        # Layer positions inside the Sequential (0..3) determine the
        # state_dict keys, so the order here matters.
        self.ff = nn.Sequential(expand, nn.ReLU(), nn.Dropout(dropout), project)

    def forward(self, x):
        """Apply the feed-forward stack to ``x`` and return the result."""
        transformed = self.ff(x)
        return transformed


### Add & Norm

In [57]:
class AddNorm(nn.Module) :
    """Residual connection followed by layer normalisation.

    Computes ``LayerNorm(raw_input + Dropout(x))``.
    """

    def __init__(self, d_model, dropout = 0.2):
        """
        Args:
            d_model: Size of the last dimension that ``LayerNorm`` normalises.
            dropout: Dropout probability applied to the sub-layer output
                before it is added to the residual. Defaults to 0.2, the
                value that used to be hard-coded here.
        """
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, raw_input, x) :
        """Return ``norm(raw_input + dropout(x))``.

        Args:
            raw_input: The residual (the sub-layer's input).
            x: The sub-layer's output; must broadcast against ``raw_input``.
        """
        return self.norm(raw_input + self.dropout(x))

### Positional Encoding

In [58]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to the input.

    A table of shape ``(1, max_len, d_model)`` is precomputed and stored as
    the buffer ``pe``: even feature columns hold ``sin(pos * w_i)``, odd
    columns hold ``cos(pos * w_i)``, with
    ``w_i = exp(-ln(10000) * 2i / d_model)``.
    """

    def __init__(self, d_model, max_len):
        """
        Args:
            d_model: Feature width of the encoding (odd values are supported).
            max_len: Number of positions to precompute.
        """
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term

        pe[:, 0::2] = torch.sin(angles)
        # When d_model is odd there is one fewer odd column than even columns,
        # so trim the angle table to fit instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        pe = pe.unsqueeze(0)
        # A buffer follows the module across devices and is saved in the
        # state_dict, but is not a trainable parameter.
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Add the encodings for the first ``x.size(1)`` positions to ``x``.

        Args:
            x: Tensor whose dimension 1 indexes positions and whose last
                dimension is ``d_model``.

        Returns:
            ``x`` plus the matching slice of the encoding table.
        """
        return x + self.pe[:, :x.size(1), :]


### Encoder

In [59]:
class Encoder(nn.Module) :
    """Single encoder layer: positional encoding, self-attention, feed-forward.

    Each sub-layer is wrapped in an ``AddNorm`` residual block. ``dropout`` is
    forwarded only to the feed-forward block; the ``AddNorm`` layers are built
    with their own default rate.
    """

    def __init__(self, num_heads, d_model, d_ff, dropout, max_len):
        """
        Args:
            num_heads: Number of attention heads.
            d_model: Feature width of the layer.
            d_ff: Hidden width of the feed-forward block.
            dropout: Dropout probability used inside the feed-forward block.
            max_len: Number of positions in the positional-encoding table.
        """
        super().__init__()
        # Attention sub-layer.
        self.multihead = Multi_head(num_heads, d_model)
        # One residual + norm wrapper per sub-layer.
        self.adn1 = AddNorm(d_model)
        self.adn2 = AddNorm(d_model)
        # Feed-forward sub-layer.
        self.ffn = PositionwiseFFN(d_model, d_ff, dropout)
        # Positional information is injected inside this layer.
        self.pe = PositionalEncoding(d_model, max_len)

    def forward(self, x, mask = None) :
        """Encode ``x``.

        Args:
            x: Input tensor handed to the positional encoding.
            mask: Optional attention mask forwarded to the self-attention.

        Returns:
            Output of the second residual block.
        """
        encoded = self.pe(x)

        # Self-attention: queries, keys and values are all the same tensor.
        attended = self.multihead(encoded, encoded, encoded, mask)
        hidden = self.adn1(encoded, attended)

        transformed = self.ffn(hidden)
        return self.adn2(hidden, transformed)

### Decoder

In [60]:
class Decoder(nn.Module) :
    """Single decoder layer: masked self-attention, cross-attention, feed-forward.

    Each sub-layer is wrapped in an ``AddNorm`` residual block. ``dropout`` is
    forwarded only to the feed-forward block; the ``AddNorm`` layers are built
    with their own default rate.
    """

    def __init__(self, num_heads, d_model, d_ff, dropout, max_len):
        """
        Args:
            num_heads: Number of attention heads for both attention blocks.
            d_model: Feature width of the layer.
            d_ff: Hidden width of the feed-forward block.
            dropout: Dropout probability used inside the feed-forward block.
            max_len: Number of positions in the positional-encoding table.
        """
        super().__init__()
        self.mask_multi_head = Multi_head(num_heads, d_model)
        self.multi_head = Multi_head(num_heads, d_model)
        self.adn1 = AddNorm(d_model)
        self.adn2 = AddNorm(d_model)
        self.adn3 = AddNorm(d_model)
        self.ffn = PositionwiseFFN(d_model, d_ff, dropout)
        self.pe = PositionalEncoding(d_model, max_len)

    def forward(self, x, encoder_output, mask = None, src_mask = None) :
        """Decode ``x`` while attending to ``encoder_output``.

        Args:
            x: Target-side input handed to the positional encoding.
            encoder_output: Encoder result used as keys and values of the
                cross-attention.
            mask: Optional mask for the target self-attention
                (presumably a causal mask - confirm with callers).
            src_mask: Optional mask forwarded to the cross-attention, e.g. to
                hide padded source positions. Defaults to no masking.

        Returns:
            Output of the third residual block; it has the same sequence
            length as ``x``.
        """
        positional_encoder = self.pe(x)
        x = self.adn1(positional_encoder, self.mask_multi_head(positional_encoder, positional_encoder, positional_encoder, mask))
        # Cross-attention: queries come from the decoder state, keys and values
        # from the encoder, so the result keeps the target length and can be
        # added back to ``x`` by the residual block.
        x = self.adn2(x, self.multi_head(x, encoder_output, encoder_output, src_mask))
        x = self.adn3(x, self.ffn(x))
        return x

### Transformer

In [61]:
class Transformer(nn.Module):
    """Encoder-decoder model over a shared, pretrained token embedding.

    Source and target ids are embedded with the same table, passed through one
    ``Encoder`` and one ``Decoder``, and projected to vocabulary logits.
    """

    def __init__(self, embedding_matrix, num_heads, d_ff, dropout, max_len):
        """
        Args:
            embedding_matrix: 2-D tensor ``(vocab_size, d_model)`` used to
                initialise the trainable embedding.
            num_heads: Number of attention heads.
            d_ff: Hidden width of the feed-forward blocks.
            dropout: Dropout probability handed to the encoder and decoder.
            max_len: Number of positions in the positional-encoding tables.
        """
        super().__init__()
        vocab_size, d_model = embedding_matrix.size()
        self.embedding = nn.Embedding.from_pretrained(embedding_matrix, freeze=False)
        # No longer used in forward(); kept so the module's state_dict keys
        # (the ``pos_encoding.pe`` buffer) stay the same for saved checkpoints.
        self.pos_encoding = PositionalEncoding(d_model, max_len)

        self.encoder = Encoder(num_heads, d_model, d_ff, dropout, max_len)
        self.decoder = Decoder(num_heads, d_model, d_ff, dropout, max_len)

        self.output_layer = nn.Linear(d_model, vocab_size)

    @staticmethod
    def _causal_mask(seq_len, device):
        """Lower-triangular boolean mask of shape ``(1, 1, seq_len, seq_len)``."""
        # Built directly on the target device to avoid a host-to-device copy.
        lower = torch.tril(torch.ones(seq_len, seq_len, device=device))
        return lower.unsqueeze(0).unsqueeze(1).to(torch.bool)

    def forward(self, src_input_ids, tgt_input_ids, src_mask=None, tgt_mask=None):
        """Compute vocabulary logits for every target position.

        Args:
            src_input_ids: Source token ids; dimension 1 is the sequence axis.
            tgt_input_ids: Target token ids; dimension 1 is the sequence axis.
            src_mask: Optional mask passed to the encoder.
                NOTE(review): it is not forwarded to the decoder call below.
            tgt_mask: Optional mask for the decoder self-attention. When
                omitted, a lower-triangular (causal) mask is generated.

        Returns:
            Output of ``output_layer`` applied to the decoder result.
        """
        if tgt_mask is None:
            tgt_mask = self._causal_mask(tgt_input_ids.size(1), tgt_input_ids.device)

        # Encoder and Decoder each add positional encodings to their own input,
        # so the raw embeddings are handed over here; applying
        # ``self.pos_encoding`` as well would add the position signal twice.
        src_emb = self.embedding(src_input_ids)
        enc_output = self.encoder(src_emb, src_mask)

        tgt_emb = self.embedding(tgt_input_ids)
        dec_output = self.decoder(tgt_emb, enc_output, tgt_mask)

        logits = self.output_layer(dec_output)
        return logits
