In [1]:
import math 
import copy 
import torch 
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data

In [21]:
class MultiheadAttention(nn.Module): 
    def __init__(self, num_heads, d_model): 
        super(MultiheadAttention, self).__init__()
        assert d_model % num_heads == 0 # d_model must be divisible by the number of heads so the dimension can be split evenly
        
        self.num_heads = num_heads
        self.d_model = d_model # dimension of the model
        self.d_k = d_model // num_heads # dimension of the key, query, and value. Each head will operate on a d_k dimensional space but the total dimension of the model is d_model
        
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def scaled_dot_product_attention(self, Q, K, V, mask=None): 
        # Matrix multiplication of Q and K which is then scaled by the square root of the dimensionality of the key vectors (d_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        if mask is not None: 
            scores = scores.masked_fill(mask == 0, -1e9) # Set the scores to -inf where the mask is 0 so that the softmax will ignore these values
            
        # Compute the attention weights
        attention = torch.softmax(scores, dim=-1)
        
        # Apply the attention weights to the value vectors to get the weighted sum
        return torch.matmul(attention, V)
    
    def split_heads(self, x): 
        
        batch_size, seq_len, d_model = x.size()
        x = x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2) # Split the last dimension into num_heads and transpose the dimensions to facilitate the parallel computation of the attention heads
        return x
    
    def combine_heads(self, x): 
        # After the attention heads have been computed, the results are concatenated and passed through a linear layer to get the final output
        batch_size, num_heads, seq_len, d_k = x.size()
        x = x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return x
    
    def forward(self, Q, K, V, mask=None):
        # Foward pass of the multihead attention layer, applies linear transformations to the input Q, K, and V to get the query, key, and value vectors
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        # Compute the attention heads
        x = self.scaled_dot_product_attention(Q, K, V, mask)
        # Results are combined and passed through a linear layer to get the final output
        x = self.combine_heads(x)
        return self.W_o(x)
        

In [4]:
class PositionWiseFeedForward(nn.Module): 
    def __init__(self, d_model, d_ff):
        # Dimension of the model and the dimension of the feed forward layer 
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2= nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

In [15]:
class PositionalEncoding(nn.Module): 
    def __init__(self, d_model, max_len): 
        super(PositionalEncoding, self).__init__()
        
        pe = torch.zeros(size=(max_len, d_model)) # Initialize the position encoding matrix
    
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        
        # Div term is used to scale the position encoding values, it is the denominator of the exponent in the sin and cos functions below. 
        # This uses a geometric progression to ensure that the model can generalize to sequences of different lengths
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 
        
        # This allows the model to learn to attend to the position of the tokens in the sequence
        pe[:, 0::2] = torch.sin(position * div_term) # Compute the sine values for the even indices
        pe[:, 1::2] = torch.cos(position * div_term) # Compute the cosine values for the odd indices
        
        # Register the position encoding as a buffer so that it is saved as part of the model state
        self.register_buffer('pe', pe.unsqueeze(0))
        
    
    def forward(self, x): 
        return x + self.pe[:, :x.size(1)]
        
        

In [31]:
# Bringing both together to create the encoder layer

class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.self_attention = MultiheadAttention(num_heads, d_model)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        # Layer normalization is applied to the output of the multihead attention and feed forward layers
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask):
        # Apply the multihead attention layer and add the residual connection
        x = self.norm1(x + self.dropout(self.self_attention(x, x, x, mask)))
        # Apply the feed forward layer and add the residual connection
        x = self.norm2(x + self.dropout(self.feed_forward(x)))
        return x

In [5]:
# Build Decoder Layer
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.self_attention = MultiheadAttention(num_heads, d_model)
        self.cross_attention = MultiheadAttention(num_heads, d_model)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, encoding_output, src_mask, tgt_mask):
        # Both masks a later generated in the transformer class
        # tgt_mask is used to prevent the decoder from attending to future tokens maintaining the autoregressive property
        attention_output = self.self_attention(x, x, x, tgt_mask) 
        x = self.norm1(x + self.dropout(attention_output))
        # src_mask is used to prevent the decoder from attending to padding tokens
        attention_output = self.cross_attention(x, encoding_output, encoding_output, src_mask) 
        x = self.norm2(x + self.dropout(attention_output))
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x 

In [26]:
# Build the Transformer
class Transformer(nn.Module): 
    def __init__(self, num_encoder_layers, num_decoder_layers, d_model, num_heads, d_ff, src_vocab_size, tgt_vocab_size, max_len, dropout=0.1):
        super(Transformer, self).__init__()
        
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model) # Embedding layer for the encoder with the source vocabulary size and the dimension of the model
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model) # Embedding layer for the decoder with the target vocabulary size and the dimension of the model
        self.positional_encoding = PositionalEncoding(d_model, max_len) # Positional encoding layer
        
        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_encoder_layers)]) # List of encoder layers
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_decoder_layers)]) # List of decoder layers
        
        self.fc = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)
        
    def generate_mask(self, source, target): 
        # Generate the source and target masks
        source_mask = (source != 0).unsqueeze(1).unsqueeze(2) # Create a mask to prevent the encoder from attending to padding tokens + unsqueeze to add the head dimension
        target_mask = (target != 0).unsqueeze(1).unsqueeze(2) # Create mask to allow the decoder to only attend to tokens that have been generated so far + unsqueeze to add the head dimension
        
        seq_len = target.size(1)
        # Create a mask to prevent the decoder from attending to future tokens to maintain the autoregressive property
        nopeak_mask = torch.triu(torch.ones(1, seq_len, seq_len), diagonal=1).bool()
        target_mask = target_mask & nopeak_mask
        
        return source_mask, target_mask
    
    def forward(self, source, target):
    
        source_mask, target_mask = self.generate_mask(source, target)
        
        source_embedding = self.dropout(self.positional_encoding(self.encoder_embedding(source)))
        target_embedding = self.dropout(self.positional_encoding(self.decoder_embedding(target)))
        
        encoding_output = source_embedding
        for enc_layer in self.encoder_layers: 
            encoding_output = enc_layer(encoding_output, source_mask)
            
        decoding_output = target_embedding
        for dec_layer in self.decoder_layers: 
            decoding_output = dec_layer(decoding_output, encoding_output, source_mask, target_mask)
            
        return self.fc(decoding_output)

In [32]:
src_vocab_size = 5000
tgt_vocab_size = 5000
d_model = 512
num_heads = 8
num_encoder_layers = 6
num_decoder_layers = 6

d_ff = 2048
max_len = 100
dropout = 0.1

transformer = Transformer(src_vocab_size=src_vocab_size, tgt_vocab_size=tgt_vocab_size, d_model=d_model, num_heads=num_heads, 
                          num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers, d_ff=d_ff, max_len=max_len, dropout=dropout)

src_data = torch.randint(1, src_vocab_size, (64, max_len))  # (batch_size, seq_length)
tgt_data = torch.randint(1, tgt_vocab_size, (64, max_len))  # (batch_size, seq_length)


In [35]:
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(transformer.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9)

transformer.train()

for epoch in range(100):
    optimizer.zero_grad()
    output = transformer(src_data, tgt_data[:, :-1])
    loss = criterion(output.contiguous().view(-1, tgt_vocab_size), tgt_data[:, 1:].contiguous().view(-1))
    loss.backward()
    optimizer.step()
    print(f"Epoch [{epoch +1 }], Loss: {loss.item()}")

Epoch [1], Loss: 8.069128036499023
Epoch [2], Loss: 8.157293319702148
Epoch [3], Loss: 8.14328670501709
Epoch [4], Loss: 8.11997127532959
Epoch [5], Loss: 8.09232234954834
Epoch [6], Loss: 8.062867164611816
Epoch [7], Loss: 7.974868297576904
Epoch [8], Loss: 7.8291120529174805
Epoch [9], Loss: 7.776881217956543
Epoch [10], Loss: 7.848037242889404
Epoch [11], Loss: 7.6584553718566895
Epoch [12], Loss: 7.199096202850342
Epoch [13], Loss: 6.92490816116333
Epoch [14], Loss: 7.036306858062744
Epoch [15], Loss: 6.701191425323486
Epoch [16], Loss: 6.306272983551025
Epoch [17], Loss: 6.134025573730469
Epoch [18], Loss: 5.979645729064941
Epoch [19], Loss: 5.762683391571045
Epoch [20], Loss: 5.278907775878906
Epoch [21], Loss: 5.5385260581970215
Epoch [22], Loss: 5.178043842315674
Epoch [23], Loss: 5.00003719329834
Epoch [24], Loss: 5.025485515594482
Epoch [25], Loss: 4.30435037612915
Epoch [26], Loss: 3.8983993530273438
Epoch [27], Loss: 3.771472692489624
Epoch [28], Loss: 4.011817932128906
Epo