In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torch.utils.data import DataLoader
import math


In [3]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention.

    Q, K and V are each passed through their own linear projection, split into
    ``num_heads`` heads of width ``d_k = d_model // num_heads``, attended
    independently, merged back to width ``d_model`` and passed through ``W_o``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        # Every head must receive an equal slice of the model dimension.
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head width of queries, keys and values

        # Role-specific input projections, then the output projection applied after merging heads.
        # (Registration order is kept stable so seeded initialisation is reproducible.)
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """
        Compute softmax(Q @ K^T / sqrt(d_k)) @ V over the last two dimensions.

        Where ``mask == 0`` the score is replaced by -1e9 before the softmax, so those
        positions receive (near) zero weight. ``mask`` must broadcast against the scores.
        """
        scale = math.sqrt(self.d_k)
        scores = Q.matmul(K.transpose(-2, -1)) / scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = scores.softmax(dim=-1)
        return weights.matmul(V)

    def split_heads(self, x):
        """Reshape (batch, seq, d_model) into (batch, num_heads, seq, d_k)."""
        batch_size, seq_length, _ = x.size()
        per_head = x.view(batch_size, seq_length, self.num_heads, self.d_k)
        return per_head.transpose(1, 2)

    def combine_heads(self, x):
        """Inverse of ``split_heads``: (batch, num_heads, seq, d_k) -> (batch, seq, d_model)."""
        batch_size, _, seq_length, _ = x.size()
        merged = x.transpose(1, 2).contiguous()
        return merged.view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        """
        Project, split into heads, attend, merge heads and project the result.

        Q, K, V: assumed (batch, seq, d_model) tensors (split_heads requires 3-D input).
        mask: optional mask broadcastable to (batch, num_heads, q_len, k_len); zeros are masked out.
        """
        q_heads, k_heads, v_heads = (
            self.split_heads(projection(tensor))
            for projection, tensor in ((self.W_q, Q), (self.W_k, K), (self.W_v, V))
        )
        context = self.scaled_dot_product_attention(q_heads, k_heads, v_heads, mask)
        return self.W_o(self.combine_heads(context))

In [4]:
class PositionWiseFeedForward(nn.Module):
    """
    Position-wise feed-forward block: Linear(d_model -> d_ff), ReLU, Linear(d_ff -> d_model).

    ``nn.Linear`` acts on the last dimension only, so every position (token) is transformed
    independently with the same weights; mixing information across tokens is left to the
    attention layers.
    """
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)  # expand to the hidden width
        self.fc2 = nn.Linear(d_ff, d_model)  # project back to the model width
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

In [5]:
class PositionalEncoding(nn.Module):
    """
    Fixed sinusoidal positional encoding that is added to the token embeddings so the
    model can tell positions apart:

        pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: embedding width. May be odd; the last (even) column then has no cosine partner.
        max_seq_length: number of positions precomputed in the table.
    """
    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()

        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)  # (max_seq_length, 1)
        # 1 / 10000^(2i / d_model) for each even column index 2i; ceil(d_model / 2) entries.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        angles = position * div_term  # (max_seq_length, ceil(d_model / 2))

        pe = torch.zeros(max_seq_length, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # There are only floor(d_model / 2) odd columns; slicing keeps odd d_model from
        # raising a shape mismatch (and is a no-op when d_model is even).
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # A buffer: saved in state_dict and moved with .to(device), but not a trainable parameter.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """
        Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        x: assumed (batch, seq_len, d_model) — TODO confirm with callers.

        Raises:
            ValueError: if seq_len exceeds the ``max_seq_length`` the table was built for.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_length {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len]

In [6]:
class EncoderLayer(nn.Module):
    """
    One post-norm Transformer encoder layer:

        x = LayerNorm(x + Dropout(SelfAttention(x, x, x, mask)))
        x = LayerNorm(x + Dropout(FeedForward(x)))

    It maps a sequence of token representations to contextualised representations.
    """
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)     # one self-attention block with num_heads heads
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)  # per-position MLP applied after attention
        self.norm1 = nn.LayerNorm(d_model)  # normalises the attention residual sum
        self.norm2 = nn.LayerNorm(d_model)  # normalises the feed-forward residual sum
        self.dropout = nn.Dropout(dropout)  # shared by both residual branches

    def forward(self, x, mask):
        """Apply the self-attention and feed-forward sub-blocks, each with residual + LayerNorm."""
        residual = x
        attended = self.dropout(self.self_attn(x, x, x, mask))
        x = self.norm1(residual + attended)

        residual = x
        transformed = self.dropout(self.feed_forward(x))
        return self.norm2(residual + transformed)

In [7]:
class DecoderLayer(nn.Module):
    """
    One post-norm Transformer decoder layer. Three sub-blocks run in sequence, each
    wrapped as ``LayerNorm(x + Dropout(sub_block(...)))``:

        1. self-attention over the target sequence (masked by ``tgt_mask``),
        2. cross-attention whose queries come from the decoder and whose keys/values
           come from the encoder output (masked by ``src_mask``),
        3. a position-wise feed-forward network.
    """
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)     # target-side self-attention
        self.cross_attn = MultiHeadAttention(d_model, num_heads)    # attends to the encoder output
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)  # after self-attention
        self.norm2 = nn.LayerNorm(d_model)  # after cross-attention
        self.norm3 = nn.LayerNorm(d_model)  # after feed-forward
        self.dropout = nn.Dropout(dropout)  # shared by all three residual branches

    def forward(self, x, enc_output, src_mask, tgt_mask):
        """
        Run the decoder layer.

        Args:
            x: decoder input (target-side hidden states).
            enc_output: encoder output; supplies the keys and values for cross-attention.
            src_mask: mask applied to the cross-attention scores (over encoder positions),
                e.g. to ignore source padding.
            tgt_mask: mask applied to the self-attention scores, e.g. target padding plus
                a causal mask that hides future tokens.

        Returns:
            The layer output; each sub-block is added residually to its input and then
            normalised.
        """
        self_attended = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(self_attended))

        cross_attended = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(cross_attended))

        return self.norm3(x + self.dropout(self.feed_forward(x)))

In [8]:
class Transformer(nn.Module):
    """
    Encoder-decoder Transformer that can also run as an encoder-only classifier.

    * Generation (``classification=False``): ``forward(src, tgt)`` encodes ``src``,
      decodes ``tgt`` against the encoder output and returns logits over the target
      vocabulary for every target position.
    * Classification (``classification=True``): ``forward(src)`` encodes ``src`` and
      returns ``num_classes`` logits computed from the encoder output at position 0.

    Token id 0 is treated as padding in both the source and target masks.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout, num_classes=None, classification=False):
        """
        src_vocab_size: Source vocabulary size.
        tgt_vocab_size: Target vocabulary size.
        d_model: The dimensionality of the model's embeddings.
        num_heads: Number of attention heads in the multi-head attention mechanism.
        num_layers: Number of layers for both the encoder and the decoder.
        d_ff: Dimensionality of the inner layer in the feed-forward network.
        max_seq_length: Maximum sequence length for positional encoding.
        dropout: Dropout rate for regularization.
        num_classes: Number of output classes; required when classification=True.
        classification: If True, the model is used as an encoder-only classifier;
            otherwise it is used for sequence generation.

        Raises:
            ValueError: if classification is True but num_classes is None.
        """
        super(Transformer, self).__init__()
        if classification and num_classes is None:
            raise ValueError("num_classes must be provided when classification=True")

        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)  # Embedding layer for the source sequence.
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)  # Embedding layer for the target sequence.
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)  # Shared by encoder and decoder inputs.

        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        # The decoder stack is built even in classification mode, where it goes unused.
        # Building it keeps state_dict keys identical across modes.
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.classification = classification
        out_dim = num_classes if classification else tgt_vocab_size
        self.fc = nn.Linear(d_model, out_dim)  # Final projection: class logits or target-vocabulary logits.
        self.dropout = nn.Dropout(dropout)     # Applied to the (embedding + positional encoding) inputs.

    def generate_mask(self, src, tgt=None):
        """
        Build boolean attention masks (True = may attend).

        Args:
            src: source token ids, assumed (batch, src_len); id 0 is padding.
            tgt: optional target token ids, assumed (batch, tgt_len); id 0 is padding.

        Returns:
            (src_mask, tgt_mask):
              src_mask has shape (batch, 1, 1, src_len) and hides padded source keys.
              tgt_mask has shape (batch, 1, tgt_len, tgt_len) and combines the target
              padding mask (on query rows) with a causal mask that hides future positions.
              It is None when tgt is None.
        """
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)

        if tgt is None:
            return src_mask, None

        seq_length = tgt.size(1)
        tgt_pad_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
        # Lower-triangular causal mask. It is created on tgt's device so it can be
        # combined with GPU inputs. tril(ones) == 1 - triu(ones, diagonal=1).
        nopeak_mask = torch.tril(
            torch.ones(1, seq_length, seq_length, dtype=torch.bool, device=tgt.device)
        )
        return src_mask, tgt_pad_mask & nopeak_mask

    def forward(self, src, tgt=None):
        """
        Run the model.

        Args:
            src: source token ids, assumed (batch, src_len).
            tgt: target token ids, assumed (batch, tgt_len). Required in generation mode;
                ignored in classification mode.

        Returns:
            Classification mode: logits of shape (batch, num_classes), taken from the encoder
                output at position 0. This presumes src[:, 0] is a summary/CLS-style token;
                confirm with the data pipeline.
            Generation mode: logits over the target vocabulary for every target position.

        Raises:
            ValueError: if tgt is None in generation mode.
        """
        src_mask, tgt_mask = self.generate_mask(src, tgt)

        enc_output = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)

        if self.classification:
            return self.fc(enc_output[:, 0, :])

        # `if tgt:` would call bool() on a multi-element tensor and raise, so test for None explicitly.
        if tgt is None:
            raise ValueError("tgt must be provided when the model is not in classification mode")

        dec_output = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

        return self.fc(dec_output)