In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [4]:
class AddNorm(nn.Module):
    """Residual connection followed by layer normalization.

    Computes ``LayerNorm(x + Dropout(sublayer_output))``.

    Args:
        d_model: Size of the normalized (last) dimension, passed to ``nn.LayerNorm``.
        dropout_rate: Dropout probability applied to the sublayer output.
    """

    def __init__(self, d_model, dropout_rate=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, sublayer_output):
        """Add the dropped-out ``sublayer_output`` to ``x`` and layer-normalize the sum."""
        return self.norm(x + self.dropout(sublayer_output))
    






In [5]:
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden feature size between the two linear layers.
        dropout_rate: Dropout probability applied after the ReLU.
    """

    def __init__(self, d_model, d_ff, dropout_rate=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Project ``x`` to ``d_ff``, apply ReLU and dropout, then project back to ``d_model``."""
        hidden = self.dropout(F.relu(self.w_1(x)))
        return self.w_2(hidden)
    

In [6]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are linearly projected, split into ``num_heads``
    heads of size ``d_k = d_model // num_heads``, attended per head, then
    concatenated and passed through an output projection.

    Args:
        d_model: Feature size of the inputs and of the output.
        num_heads: Number of attention heads; must divide ``d_model``.
        dropout_rate: Dropout probability applied to the attention weights.

    Raises:
        AssertionError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads, dropout_rate=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_k = d_model // num_heads
        # Store the constructor arguments. Reading ``self.num_heads`` or
        # ``self.d_model`` before they are set raises AttributeError.
        self.num_heads = num_heads
        self.d_model = d_model

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout_rate)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Compute ``softmax(Q K^T / sqrt(d_k)) V`` with optional masking.

        Args:
            Q, K, V: Per-head tensors whose last dimension is ``d_k``.
            mask: Optional tensor broadcastable to the score tensor. Positions
                where ``mask == 0`` are excluded from attention.

        Returns:
            Tuple ``(attn_output, attn_weights)``. The weights are returned
            after dropout has been applied.
        """
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            # Positions where mask == 0 get a score of -inf, so they receive
            # zero weight after the softmax.
            # NOTE: a row whose keys are ALL masked becomes NaN after softmax;
            # callers must leave at least one visible key per query.
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        attn_output = torch.matmul(attn_weights, V)
        return attn_output, attn_weights

    def forward(self, query_input, key_input, value_input, mask=None):
        """Run multi-head attention.

        Args:
            query_input: Tensor of shape ``(batch, q_len, d_model)``.
            key_input: Tensor of shape ``(batch, k_len, d_model)``.
            value_input: Tensor of shape ``(batch, k_len, d_model)``.
            mask: Optional mask broadcastable to
                ``(batch, num_heads, q_len, k_len)``; zeros mark masked positions.

        Returns:
            Tuple ``(output, attn_weights)`` with shapes
            ``(batch, q_len, d_model)`` and ``(batch, num_heads, q_len, k_len)``.
        """
        batch_size = query_input.size(0)

        Q = self.W_q(query_input)
        K = self.W_k(key_input)
        V = self.W_v(value_input)

        # (batch, seq, d_model) -> (batch, num_heads, seq, d_k)
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        attn_output, attn_weights = self.scaled_dot_product_attention(Q, K, V, mask)

        # Merge the heads back: (batch, num_heads, seq, d_k) -> (batch, seq, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.W_o(attn_output)

        return output, attn_weights

In [7]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal positional encodings to an input tensor.

    A table of shape ``(1, max_len, d_model)`` is precomputed and stored as the
    non-trainable buffer ``pe``. Even feature columns hold sines and odd feature
    columns hold cosines of ``position * div_term``.

    Args:
        d_model: Feature size of the encoding (may be even or odd).
        max_len: Number of positions precomputed in the table.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so trim the angles to fit. For even d_model this slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        ``x`` is indexed as ``(batch, seq_len, d_model)``; the table is sliced
        along dimension 1 to ``x.size(1)`` and broadcast over the batch.
        """
        x = x + self.pe[:, :x.size(1), :]
        return x

In [None]:
class EncoderLayer(nn.Module):
    """One encoder layer: self-attention, then a feed-forward network.

    Each sublayer is wrapped by an ``AddNorm`` (residual + layer norm).

    Args:
        d_model: Feature size of the layer input and output.
        num_heads: Number of attention heads.
        d_ff: Hidden size of the feed-forward network.
        dropout_rate: Dropout probability passed to every submodule.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout_rate=0.1):
        super().__init__()
        # Sublayer 1: multi-head self-attention.
        self.self_attention = MultiHeadAttention(d_model, num_heads, dropout_rate)
        self.addnorm_1 = AddNorm(d_model, dropout_rate)

        # Sublayer 2: position-wise feed-forward network.
        self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout_rate)
        self.addnorm_2 = AddNorm(d_model, dropout_rate)

    def forward(self, x, src_mask):
        """Apply both sublayers to ``x``.

        Args:
            x: Layer input, used as query, key and value for self-attention.
            src_mask: Mask forwarded to the self-attention module.

        Returns:
            Tuple ``(output, self_attn_weights)``.
        """
        attended, attn_weights = self.self_attention(x, x, x, mask=src_mask)
        hidden = self.addnorm_1(x, attended)

        output = self.addnorm_2(hidden, self.ffn(hidden))
        return output, attn_weights
    
    

In [None]:
class Encoder(nn.Module):
    """Stack of ``num_layers`` ``EncoderLayer`` modules applied in sequence.

    Args:
        num_layers: Number of encoder layers.
        d_model: Feature size of each layer.
        num_heads: Number of attention heads per layer.
        d_ff: Hidden size of each layer's feed-forward network.
        dropout_rate: Dropout probability passed to every layer.
    """

    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout_rate=0.1):
        super().__init__()
        self.num_layers = num_layers
        self.layers = nn.ModuleList(
            EncoderLayer(d_model, num_heads, d_ff, dropout_rate)
            for _ in range(num_layers)
        )

    def forward(self, x, src_mask):
        """Run ``x`` through every layer in order.

        Returns:
            Tuple ``(output, attn_weights_per_layer)`` where the second item is
            a list holding each layer's self-attention weights.
        """
        attn_weights_per_layer = []
        for layer in self.layers:
            x, layer_weights = layer(x, src_mask)
            attn_weights_per_layer.append(layer_weights)
        return x, attn_weights_per_layer


In [None]:
class DecoderLayer(nn.Module):
    """One decoder layer: masked self-attention, encoder-decoder attention, feed-forward.

    Each of the three sublayers is wrapped by an ``AddNorm`` (residual + layer norm).

    Args:
        d_model: Feature size of the layer input and output.
        num_heads: Number of attention heads in both attention modules.
        d_ff: Hidden size of the feed-forward network.
        dropout_rate: Dropout probability passed to every submodule.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout_rate=0.1):
        super().__init__()

        # Sublayer 1: self-attention over the decoder input.
        self.masked_self_attention = MultiHeadAttention(d_model, num_heads, dropout_rate)
        self.addnorm_1 = AddNorm(d_model, dropout_rate)

        # Sublayer 2: attention from the decoder state onto the encoder output.
        self.encoder_decoder_attention = MultiHeadAttention(d_model, num_heads, dropout_rate)
        self.addnorm_2 = AddNorm(d_model, dropout_rate)

        # Sublayer 3: position-wise feed-forward network.
        self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout_rate)
        self.addnorm_3 = AddNorm(d_model, dropout_rate)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        """Apply the three sublayers to ``x``.

        Args:
            x: Decoder-side input.
            encoder_output: Encoder output, used as key and value in the
                encoder-decoder attention.
            src_mask: Mask passed to the encoder-decoder attention.
            tgt_mask: Mask passed to the decoder self-attention.

        Returns:
            Tuple ``(output, self_attn_weights, enc_dec_attn_weights)``.
        """
        self_attended, self_weights = self.masked_self_attention(x, x, x, mask=tgt_mask)
        state = self.addnorm_1(x, self_attended)

        cross_attended, cross_weights = self.encoder_decoder_attention(
            query_input=state,
            key_input=encoder_output,
            value_input=encoder_output,
            mask=src_mask,
        )
        state = self.addnorm_2(state, cross_attended)

        state = self.addnorm_3(state, self.ffn(state))

        return state, self_weights, cross_weights

In [None]:
class Decoder(nn.Module):
    """Stack of ``num_layers`` ``DecoderLayer`` modules applied in sequence.

    Args:
        num_layers: Number of decoder layers.
        d_model: Feature size of each layer.
        num_heads: Number of attention heads per layer.
        d_ff: Hidden size of each layer's feed-forward network.
        dropout_rate: Dropout probability passed to every layer.
    """

    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout_rate=0.1):
        super().__init__()
        self.num_layers = num_layers
        self.layers = nn.ModuleList(
            DecoderLayer(d_model, num_heads, d_ff, dropout_rate)
            for _ in range(num_layers)
        )

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        """Run ``x`` through every layer in order.

        Returns:
            Tuple ``(output, self_attn_per_layer, enc_dec_attn_per_layer)``;
            the last two items are lists with one attention-weight tensor per layer.
        """
        self_attn_per_layer = []
        enc_dec_attn_per_layer = []
        for layer in self.layers:
            x, self_weights, cross_weights = layer(x, encoder_output, src_mask, tgt_mask)
            self_attn_per_layer.append(self_weights)
            enc_dec_attn_per_layer.append(cross_weights)
        return x, self_attn_per_layer, enc_dec_attn_per_layer

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer that maps token ids to target-vocabulary logits.

    Args:
        num_layers: Number of layers in both the encoder and the decoder.
        d_model: Embedding and hidden feature size.
        num_heads: Number of attention heads per attention module.
        d_ff: Hidden size of the feed-forward networks.
        input_vocab_size: Number of rows in the source embedding table.
        target_vocab_size: Number of rows in the target embedding table and
            the size of the output logits.
        max_seq_len: Number of positions precomputed by the positional encoding.
        dropout_rate: Dropout probability used throughout the model.

    NOTE(review): embeddings are used unscaled here; the reference Transformer
    multiplies them by sqrt(d_model) - confirm the omission is intentional.
    """

    def __init__(self, num_layers, d_model, num_heads, d_ff,
                 input_vocab_size, target_vocab_size, max_seq_len, dropout_rate=0.1):
        super().__init__()

        # Token embeddings for the source and target vocabularies.
        self.src_embedding = nn.Embedding(input_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(target_vocab_size, d_model)

        # A single positional encoding shared by both sides.
        self.positional_encoding = PositionalEncoding(d_model, max_len=max_seq_len)

        # Encoder and decoder stacks.
        self.encoder = Encoder(num_layers, d_model, num_heads, d_ff, dropout_rate)
        self.decoder = Decoder(num_layers, d_model, num_heads, d_ff, dropout_rate)

        # Projects decoder output onto the target vocabulary.
        self.final_linear = nn.Linear(d_model, target_vocab_size)

        self.dropout = nn.Dropout(dropout_rate)

    def _embed(self, tokens, embedding):
        """Embed ``tokens``, add positional encodings, then apply dropout."""
        return self.dropout(self.positional_encoding(embedding(tokens)))

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Compute target-vocabulary logits for ``tgt`` conditioned on ``src``.

        Args:
            src: Source token ids.
            tgt: Target token ids.
            src_mask: Mask for encoder self-attention and encoder-decoder attention.
            tgt_mask: Mask for decoder self-attention.

        Returns:
            Tuple ``(logits, (enc_self_attn, dec_self_attn, enc_dec_attn))``
            where each attention item is a per-layer list of weight tensors.
        """
        # Source is embedded before target, so dropout draws happen in that order.
        src_embedded = self._embed(src, self.src_embedding)
        tgt_embedded = self._embed(tgt, self.tgt_embedding)

        memory, enc_self_attn = self.encoder(src_embedded, src_mask)

        decoded, dec_self_attn, enc_dec_attn = self.decoder(
            tgt_embedded, memory, src_mask, tgt_mask
        )

        logits = self.final_linear(decoded)

        return logits, (enc_self_attn, dec_self_attn, enc_dec_attn)



def create_padding_mask(seq, pad_idx):
    """Build a boolean mask that is True where ``seq`` differs from ``pad_idx``.

    Two singleton dimensions are inserted after the first dimension, so a
    ``(batch, seq_len)`` input yields a ``(batch, 1, 1, seq_len)`` mask.
    """
    not_pad = seq.ne(pad_idx)
    return not_pad[:, None, None]

def create_look_ahead_mask(size):
    """Build a ``(1, 1, size, size)`` boolean mask of the strict upper triangle.

    Entries are True where the column index is greater than the row index
    (i.e. "future" positions) and False on and below the diagonal.
    """
    strictly_upper = torch.ones(size, size).triu(diagonal=1)
    return strictly_upper.bool()[None, None]

Testing the transformer

In [None]:
print("Testing Full Transformer Model...")

# Seed the RNG so the dummy inputs, weight initialization and dropout masks
# are identical on every Restart & Run All.
torch.manual_seed(0)

# Model hyperparameters for the smoke test.
num_layers = 2
d_model = 256
num_heads = 4
d_ff = 1024
input_vocab_size = 1000
target_vocab_size = 800
max_seq_len = 200
dropout_rate = 0.1
PAD_IDX = 0

# Dummy batch dimensions.
batch_size = 2
src_seq_len_test = 50
tgt_seq_len_test = 40

# Random token ids start at 1 so PAD_IDX (0) appears only where set below.
dummy_src = torch.randint(1, input_vocab_size, (batch_size, src_seq_len_test))
dummy_tgt = torch.randint(1, target_vocab_size, (batch_size, tgt_seq_len_test))

# Simulate trailing padding in one source and one target sequence.
dummy_src[0, 45:] = PAD_IDX
dummy_tgt[1, 30:] = PAD_IDX

# Create masks
src_padding_mask = create_padding_mask(dummy_src, PAD_IDX)
tgt_padding_mask = create_padding_mask(dummy_tgt, PAD_IDX)
look_ahead_mask = create_look_ahead_mask(tgt_seq_len_test)

# The look-ahead mask is True at future positions, so invert it before
# combining: a target position is visible only if it is not padding AND not
# in the future. Both operands are boolean, hence a logical AND.
tgt_mask = tgt_padding_mask & (look_ahead_mask == 0)

transformer_model = Transformer(
    num_layers=num_layers,
    d_model=d_model,
    num_heads=num_heads,
    d_ff=d_ff,
    input_vocab_size=input_vocab_size,
    target_vocab_size=target_vocab_size,
    max_seq_len=max_seq_len,
    dropout_rate=dropout_rate
)

# Performing forward pass
output_logits, (enc_self_attn, dec_self_attn, enc_dec_attn) = transformer_model(
    dummy_src, dummy_tgt, src_padding_mask, tgt_mask
)

# Printing shapes to verify
print(f"\nSource input shape: {dummy_src.shape}")
print(f"Target input shape: {dummy_tgt.shape}")
print(f"Output logits shape: {output_logits.shape}")

assert output_logits.shape == (batch_size, tgt_seq_len_test, target_vocab_size)
# A fully masked attention row would surface as NaN here.
assert torch.isfinite(output_logits).all(), "logits contain NaN or Inf"

print(f"\nEncoder Self-Attention weights (first layer) shape: {enc_self_attn[0].shape}")
print(f"Decoder Self-Attention weights (first layer) shape: {dec_self_attn[0].shape}")
print(f"Encoder-Decoder Attention weights (first layer) shape: {enc_dec_attn[0].shape}")

assert enc_self_attn[0].shape == (batch_size, num_heads, src_seq_len_test, src_seq_len_test)
assert dec_self_attn[0].shape == (batch_size, num_heads, tgt_seq_len_test, tgt_seq_len_test)
assert enc_dec_attn[0].shape == (batch_size, num_heads, tgt_seq_len_test, src_seq_len_test)

print("\nFull Transformer Model test passed successfully!")