Source:
https://towardsdatascience.com/build-your-own-transformer-from-scratch-using-pytorch-84c850470dcb

# Algorithm Summary

## Input Processing:

1. Source and target sequences are converted to embeddings
2. Positional encodings are added to provide position information
3. Masks are generated to handle padding and prevent looking at future tokens


## Encoder Operation:

The input passes through a stack of identical encoder layers. Each encoder layer has two sub-layers:
1. Multi-head self-attention
2. Position-wise feed-forward network

A residual connection followed by layer normalization is applied after each sub-layer.


## Multi-Head Attention Mechanism:

Splits the projected input into multiple heads that are processed in parallel:
1. Project the input into Query (Q), Key (K), and Value (V) matrices and split each into heads
2. For each head, compute attention scores: (Q × K^T) / √d_k
3. Apply softmax to get attention weights
4. Multiply the weights with the Values (V)
5. Concatenate the heads and apply a final linear transformation
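
Steps 2 to 4 above can be sketched for a single head on a toy tensor. This is a minimal illustration rather than the notebook's implementation; the helper name `toy_attention` and the tensor sizes are assumptions chosen for the example.

```python
import math
import torch

def toy_attention(Q, K, V, mask=None):
    """Scaled dot-product attention for one head (illustrative helper)."""
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)   # step 2: scaled scores
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)    # block disallowed positions
    weights = torch.softmax(scores, dim=-1)             # step 3: attention weights
    return weights @ V, weights                         # step 4: weighted sum of values

torch.manual_seed(0)
seq_len, d_k = 4, 8                                     # assumed toy sizes
Q, K, V = (torch.randn(seq_len, d_k) for _ in range(3))

# Lower-triangular mask: position i may attend only to positions <= i
causal = torch.tril(torch.ones(seq_len, seq_len))
out, weights = toy_attention(Q, K, V, causal)

print(out.shape)            # torch.Size([4, 8])
print(weights.sum(dim=-1))  # every row of weights sums to 1
print(weights[0])           # the first position attends only to itself
```

With the causal mask, everything above the diagonal of `weights` is zero, which is the behavior the decoder's masked self-attention relies on.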


## Decoder Operation:

The target sequence passes through a stack of identical decoder layers.
Each decoder layer has three sub-layers:
1. Masked multi-head self-attention (prevents attending to future tokens)
2. Multi-head cross-attention (attends to encoder output)
3. Position-wise feed-forward network

A residual connection followed by layer normalization is applied after each sub-layer.


## Final Output Generation:

The decoder output is projected to the target vocabulary size, giving one logit per vocabulary token at each position.
These logits can be used to predict the next token in the sequence.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy

# Multi-Head Attention
<img src="https://miro.medium.com/v2/resize:fit:1400/format:webp/0*--TCGWYxwASbv2ra.png" width="500"/>

The Multi-Head Attention mechanism computes the attention between each pair of positions in a sequence. It consists of multiple “attention heads” that capture different aspects of the input sequence.

In [2]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        """
        Initialize multi-head attention module
        Args:
            d_model: dimension of model (embedding dimension)
            num_heads: number of attention heads
        """
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model  # Model's dimension
        self.num_heads = num_heads  # Number of parallel attention heads
        self.d_k = d_model // num_heads  # Dimension of each head's key/query/value
        
        # Linear transformations for Query, Key, Value, and Output
        self.W_q = nn.Linear(d_model, d_model)  # Query transformation
        self.W_k = nn.Linear(d_model, d_model)  # Key transformation
        self.W_v = nn.Linear(d_model, d_model)  # Value transformation
        self.W_o = nn.Linear(d_model, d_model)  # Output transformation
        
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """
        Compute scaled dot-product attention
        Args:
            Q: Query matrix
            K: Key matrix
            V: Value matrix
            mask: Optional mask to prevent attention to certain positions
        """
        # Compute attention scores (Q × K^T)
        attn_scores = torch.matmul(Q, K.transpose(-2, -1))
        # Scale by sqrt(d_k) so large dot products do not saturate the softmax (which would shrink gradients)
        attn_scores = attn_scores / math.sqrt(self.d_k)
        
        # Apply mask if provided (for padding or future tokens)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        
        # Apply softmax to get attention probabilities
        attn_probs = torch.softmax(attn_scores, dim=-1)
        
        # Compute weighted sum of values
        output = torch.matmul(attn_probs, V)
        return output
        
    def split_heads(self, x):
        """Split the last dimension into (num_heads, d_k)"""
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
        
    def combine_heads(self, x):
        """Combine the split heads back together"""
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
        
    def forward(self, Q, K, V, mask=None):
        """
        Forward pass of multi-head attention
        Args:
            Q: Query input
            K: Key input
            V: Value input
            mask: Optional attention mask
        """
        # Transform and split heads for Query, Key, and Value
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        
        # Compute attention output
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        
        # Combine heads and apply output transformation
        output = self.W_o(self.combine_heads(attn_output))
        return output
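
The head splitting and recombination can be checked in isolation. The sketch below re-implements the two reshaping steps as free functions (the standalone form and the toy sizes are assumptions for illustration) and confirms that combining undoes splitting.

```python
import torch

def split_heads(x, num_heads):
    # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
    batch, seq_len, d_model = x.size()
    return x.view(batch, seq_len, num_heads, d_model // num_heads).transpose(1, 2)

def combine_heads(x):
    # (batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)
    batch, num_heads, seq_len, d_k = x.size()
    return x.transpose(1, 2).contiguous().view(batch, seq_len, num_heads * d_k)

x = torch.randn(2, 5, 16)              # assumed toy sizes: batch=2, seq_len=5, d_model=16
heads = split_heads(x, num_heads=4)
restored = combine_heads(heads)

print(heads.shape)                     # torch.Size([2, 4, 5, 4])
print(torch.equal(restored, x))        # True
# Head 0 holds the first d_k features of every position
print(torch.equal(heads[:, 0], x[..., :4]))   # True
```

Each head therefore works on a contiguous slice of `d_k` features, and no information is lost by the reshape.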


# Position-wise Feed-Forward Networks
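This network applies the same two-layer feed-forward transformation to every position of the sequence independently. "Position-wise" refers to this per-position application; position information itself comes from the positional encoding.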

In [3]:
class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        """
        Initialize feed-forward network
        Args:
            d_model: Input/output dimension
            d_ff: Hidden layer dimension
        """
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)  # First linear transformation
        self.fc2 = nn.Linear(d_ff, d_model)  # Second linear transformation
        self.relu = nn.ReLU()  # ReLU activation

    def forward(self, x):
        """Apply two linear transformations with ReLU activation"""
        return self.fc2(self.relu(self.fc1(x)))
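
One way to see what "position-wise" means is to reorder the positions and observe that the outputs are reordered in the same way. The sketch below uses an `nn.Sequential` stand-in with the same linear, ReLU, linear structure; the sizes and the permutation are assumptions for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff = 8, 32                      # assumed toy sizes
ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))

x = torch.randn(1, 5, d_model)             # (batch, seq_len, d_model)
perm = torch.tensor([4, 2, 0, 3, 1])       # arbitrary reordering of the 5 positions

# Permuting before or after the network gives the same result, because each
# position is transformed on its own, without looking at its neighbors.
same = torch.allclose(ffn(x[:, perm]), ffn(x)[:, perm], atol=1e-5)
print(same)   # True
```

Mixing information across positions is left entirely to the attention sub-layers.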


# Positional Encoding
Positional Encoding is used to inject the position information of each token in the input sequence.
It uses sine and cosine functions of different frequencies to generate the positional encoding.
The forward method adds the stored positional encoding values to the input tensor, sliced to the input's sequence length.
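
For reference, the sinusoidal scheme from the original Transformer paper can be written as follows, where $pos$ is the token position and $i$ indexes a pair of embedding dimensions (this notation is assumed here and does not appear in the notebook):

$$
PE_{(pos,\,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad
PE_{(pos,\,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)
$$

The `div_term` in the code corresponds to $10000^{-2i/d_{\text{model}}}$, computed through `exp` and `log`.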

In [4]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        """
        Initialize positional encoding
        Args:
            d_model: Embedding dimension
            max_seq_length: Maximum sequence length
        """
        super(PositionalEncoding, self).__init__()
        
        # Create positional encoding matrix
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        
        # Apply sine to even indices
        pe[:, 0::2] = torch.sin(position * div_term)
        # Apply cosine to odd indices
        pe[:, 1::2] = torch.cos(position * div_term)
        
        # Register as buffer (won't be updated during training)
        self.register_buffer('pe', pe.unsqueeze(0))
        
    def forward(self, x):
        """Add positional encoding to input embeddings"""
        return x + self.pe[:, :x.size(1)]
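
A quick sanity check of the encoding table is sketched below. It rebuilds the table as a free function (an assumption made so the snippet runs on its own) and compares one entry against the closed form sin(pos / 10000^(2i/d_model)).

```python
import math
import torch

def sinusoidal_table(max_len, d_model):
    """Build the (max_len, d_model) sinusoidal table (illustrative helper)."""
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)   # even dimensions
    pe[:, 1::2] = torch.cos(position * div_term)   # odd dimensions
    return pe

d_model = 8                                        # assumed toy size
pe = sinusoidal_table(max_len=50, d_model=d_model)

print(pe[0])                       # position 0: sin(0)=0 on even dims, cos(0)=1 on odd dims
print(bool(pe.abs().max() <= 1))   # True: every entry lies in [-1, 1]

# Spot-check one entry against the closed-form expression
pos, i = 7, 2                      # dimension pair (2i, 2i+1) = (4, 5)
expected = math.sin(pos / 10000 ** (2 * i / d_model))
print(abs(pe[pos, 2 * i].item() - expected) < 1e-5)   # True
```

Because the values are bounded and fixed, the table can be stored as a buffer and simply added to the embeddings.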


# Encoder Layer
<img src="https://miro.medium.com/v2/resize:fit:552/format:webp/0*bPKV4ekQr9ZjYkWJ.png" height=500>

An Encoder layer consists of a Multi-Head Attention layer, a Position-wise Feed-Forward layer, and two Layer Normalization layers.

The forward method computes the encoder layer output by applying self-attention, adding the attention output to the input tensor, and normalizing the result.
Then, it computes the position-wise feed-forward output, combines it with the normalized self-attention output, and normalizes the final result before returning the processed tensor.

In [5]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        """
        Initialize encoder layer
        Args:
            d_model: Model dimension
            num_heads: Number of attention heads
            d_ff: Feed-forward network dimension
            dropout: Dropout rate
        """
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)  # First layer normalization
        self.norm2 = nn.LayerNorm(d_model)  # Second layer normalization
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask):
        """
        Forward pass of encoder layer
        Args:
            x: Input tensor
            mask: Attention mask
        """
        # Self attention block
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))  # Add & Norm
        
        # Feed forward block
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))  # Add & Norm
        return x
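
The "Add & Norm" pattern used twice in the layer can be looked at on its own. In this sketch a plain `nn.Linear` stands in for the attention or feed-forward sub-layer, and the helper name `add_and_norm` and the sizes are assumptions for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 16                                   # assumed toy size
norm = nn.LayerNorm(d_model)
dropout = nn.Dropout(p=0.1)
sublayer = nn.Linear(d_model, d_model)         # stand-in for attention or the FFN

def add_and_norm(x):
    # Residual add first, then layer normalization (the ordering used in the notebook)
    return norm(x + dropout(sublayer(x)))

x = torch.randn(2, 5, d_model)
out = add_and_norm(x)

print(out.shape)                               # torch.Size([2, 5, 16])
print(out.mean(dim=-1).abs().max().item())     # close to 0: each position is normalized over d_model
```

The shape is unchanged, which is what lets identical layers be stacked, and each position's features are re-centered after every sub-layer.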

# Decoder Layer
<img src="https://miro.medium.com/v2/resize:fit:552/format:webp/0*SPZgT4k8GQi37H__.png" height=500>

A Decoder layer consists of two Multi-Head Attention layers, a Position-wise Feed-Forward layer, and three Layer Normalization layers.

The forward method computes the decoder layer output by performing the following steps:

1. Calculate the masked self-attention output and add it to the input tensor, followed by dropout and layer normalization.
2. Compute the cross-attention output between the decoder and encoder outputs, and add it to the normalized masked self-attention output, followed by dropout and layer normalization.
3. Calculate the position-wise feed-forward output and combine it with the normalized cross-attention output, followed by dropout and layer normalization.
4. Return the processed tensor.

These operations enable the decoder to generate target sequences based on the input and the encoder output.

In [6]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        """
        Initialize decoder layer
        Args:
            d_model: Model dimension
            num_heads: Number of attention heads
            d_ff: Feed-forward network dimension
            dropout: Dropout rate
        """
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_output, src_mask, tgt_mask):
        """
        Forward pass of decoder layer
        Args:
            x: Input tensor
            enc_output: Encoder output
            src_mask: Source sequence mask
            tgt_mask: Target sequence mask
        """
        # Self attention block
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        
        # Cross attention block
        attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        
        # Feed forward block
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x
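
In the cross-attention block the queries come from the decoder while the keys and values come from the encoder, so the two sequence lengths can differ. The sketch below traces the shapes with random tensors that are already split into heads; all sizes and the padding pattern are assumptions for illustration.

```python
import math
import torch

torch.manual_seed(0)
batch, heads, d_k = 2, 4, 8                    # assumed toy sizes
tgt_len, src_len = 3, 5

Q = torch.randn(batch, heads, tgt_len, d_k)    # projected from the decoder state
K = torch.randn(batch, heads, src_len, d_k)    # projected from the encoder output
V = torch.randn(batch, heads, src_len, d_k)    # projected from the encoder output

# Source padding mask shaped (batch, 1, 1, src_len); here the last two
# source tokens of the second example are treated as padding.
src_mask = torch.ones(batch, 1, 1, src_len, dtype=torch.bool)
src_mask[1, :, :, -2:] = False

scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)   # (batch, heads, tgt_len, src_len)
scores = scores.masked_fill(src_mask == 0, -1e9)    # mask broadcasts over heads and tgt_len
weights = torch.softmax(scores, dim=-1)
out = weights @ V                                   # (batch, heads, tgt_len, d_k)

print(scores.shape)                  # torch.Size([2, 4, 3, 5])
print(out.shape)                     # torch.Size([2, 4, 3, 8])
print(weights[1, 0, :, -2:].sum())   # tensor(0.): padded source positions get no weight
```

The output keeps the target length, so it can be added back to the decoder state in the residual connection.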

# Transformer Model (Encoder + Decoder)
<img src="https://miro.medium.com/v2/resize:fit:720/format:webp/0*ljYs7oOlKC71SzSr.png" width=500>

The generate_mask method creates binary masks for source and target sequences to ignore padding tokens and prevent the decoder from attending to future tokens. The forward method computes the Transformer model’s output through the following steps:

1. Generate source and target masks using the generate_mask method.
2. Compute source and target embeddings, and apply positional encoding and dropout.
3. Process the source sequence through encoder layers, updating the enc_output tensor.
4. Process the target sequence through decoder layers, using enc_output and masks, and updating the dec_output tensor.
5. Apply the linear projection layer to the decoder output, obtaining output logits.


These steps enable the Transformer model to process input sequences and generate output sequences based on the combined functionality of its components.


## How is a full Transformer (encoder + decoder) different than a decoder-only architecture?

The main differences are:

1. Information Access Pattern
- Full Transformer: The encoder processes the entire input sequence in parallel, creating a rich contextual representation that the decoder can access at every step. The decoder can attend to any encoded input state at any time through cross-attention.
- Decoder-only: Each position has access only to previous tokens in the same sequence (through causal/masked self-attention). During generation it produces one token at a time and cannot look at future tokens.

2. Architecture Components
- Full Transformer:
  - Encoder: Has self-attention layers that can see all input tokens
  - Decoder: Has both masked self-attention (for the target sequence) and cross-attention (to access encoder representations)
- Decoder-only: Only has masked self-attention layers

3. Typical Use Cases
- Full Transformer: Well suited to tasks that transform one sequence into another
  - Machine translation (input: source language, output: target language)
  - Summarization (input: long text, output: summary)
- Decoder-only: Well suited to generative tasks and completing sequences
  - Language modeling
  - Text generation
  - Code completion

4. Memory & Computation
- Full Transformer:
  - Processes the input once through the encoder
  - Encoder representations are computed once and reused by the decoder at every step
  - Has more components (two layer stacks plus cross-attention)
- Decoder-only:
  - Simpler architecture, but input and output must be encoded in the same sequence
  - Can be more memory-efficient but might need longer sequences

A simplified view of the two data flows:

For Full Transformer:
$$
\text{Input} \xrightarrow{\text{Encoder}} \text{Hidden States} \xrightarrow{\text{Cross-Attention}} \text{Decoder} \xrightarrow{} \text{Output}
$$

For Decoder-only:
$$
\text{Input} \xrightarrow{\text{Masked Self-Attention}} \text{Output}
$$


In [7]:
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        """
        Initialize transformer model
        Args:
            src_vocab_size: Source vocabulary size
            tgt_vocab_size: Target vocabulary size
            d_model: Model dimension
            num_heads: Number of attention heads
            num_layers: Number of encoder/decoder layers
            d_ff: Feed-forward network dimension
            max_seq_length: Maximum sequence length
            dropout: Dropout rate
        """
        super(Transformer, self).__init__()
        
        # Embedding layers
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)

        # Create encoder and decoder layers
        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])

        # Output projection
        self.fc = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def generate_mask(self, src, tgt):
        """
        Generate masks for source and target sequences
        Args:
            src: Source sequence
            tgt: Target sequence
        """
        # Create mask for source sequence (1 for tokens, 0 for padding)
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        
        # Create mask for target sequence
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
        
        # Create mask to prevent attention to future tokens
        seq_length = tgt.size(1)
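        # Build the mask on the same device as tgt so it also works when the model runs on a GPU
        nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length, device=tgt.device), diagonal=1)).bool()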
        tgt_mask = tgt_mask & nopeak_mask
        
        return src_mask, tgt_mask

    def forward(self, src, tgt):
        """
        Forward pass of transformer
        Args:
            src: Source sequence
            tgt: Target sequence
        """
        # Generate masks
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        
        # Embed and apply positional encoding to source sequence
        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        
        # Embed and apply positional encoding to target sequence
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

        # Pass through encoder layers
        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)

        # Pass through decoder layers
        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

        # Project to vocabulary size
        output = self.fc(dec_output)
        return output
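
The mask logic is easier to follow on a tiny padded batch. The sketch below restates it as a free function (an assumption made so it runs on its own), using `torch.tril`, which yields the same lower-triangular pattern as `1 - triu(..., diagonal=1)`. The token ids are made up, with 0 as the padding id as in the notebook.

```python
import torch

def generate_mask(src, tgt, pad_id=0):
    """Standalone version of the mask logic (pad_id=0 follows the notebook's convention)."""
    src_mask = (src != pad_id).unsqueeze(1).unsqueeze(2)     # (batch, 1, 1, src_len)
    tgt_pad = (tgt != pad_id).unsqueeze(1).unsqueeze(3)      # (batch, 1, tgt_len, 1)
    tgt_len = tgt.size(1)
    causal = torch.tril(torch.ones(1, tgt_len, tgt_len, device=tgt.device)).bool()
    return src_mask, tgt_pad & causal                        # (batch, 1, tgt_len, tgt_len)

src = torch.tensor([[5, 7, 9, 0]])     # one source sentence, last token is padding
tgt = torch.tensor([[3, 4, 0]])        # one target sentence, last token is padding
src_mask, tgt_mask = generate_mask(src, tgt)

print(src_mask.shape, tgt_mask.shape)  # torch.Size([1, 1, 1, 4]) torch.Size([1, 1, 3, 3])
print(tgt_mask[0, 0].int())
# tensor([[1, 0, 0],
#         [1, 1, 0],
#         [0, 0, 0]], dtype=torch.int32)
```

The last row is entirely zero because that query position is padding. With the `-1e9` fill, softmax then produces uniform weights for that row rather than NaN, and the loss skips it through `ignore_index=0`.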

# Testing

In [8]:
# Sample data

src_vocab_size = 5000
tgt_vocab_size = 5000
d_model = 512
num_heads = 8
num_layers = 6
d_ff = 2048
max_seq_length = 100
dropout = 0.1

transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)

# Generate random sample data
src_data = torch.randint(1, src_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)
tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)
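
The training cell that follows feeds `tgt_data[:, :-1]` to the decoder and scores the output against `tgt_data[:, 1:]`. The sketch below shows this one-token shift on a made-up sequence (the token ids are hypothetical) and computes the loss a uniform guess over the vocabulary would give, as a rough baseline for the first logged value.

```python
import math
import torch

tgt = torch.tensor([[11, 12, 13, 14, 15]])   # hypothetical token ids for one target sequence

decoder_input = tgt[:, :-1]   # what the decoder sees
labels = tgt[:, 1:]           # what it must predict at each position

print(decoder_input.tolist())  # [[11, 12, 13, 14]]
print(labels.tolist())         # [[12, 13, 14, 15]]

# Cross-entropy of a uniform guess over the vocabulary
vocab_size = 5000
baseline = math.log(vocab_size)
print(round(baseline, 3))      # 8.517
```

At position t the decoder has seen tokens up to t and is trained to output token t+1. An untrained model typically starts somewhere near the uniform baseline.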

In [9]:
# Training

criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

transformer.train()

for epoch in range(100):
    optimizer.zero_grad()
    output = transformer(src_data, tgt_data[:, :-1])
    loss = criterion(output.contiguous().view(-1, tgt_vocab_size), tgt_data[:, 1:].contiguous().view(-1))
    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

Epoch: 1, Loss: 8.685790061950684
Epoch: 2, Loss: 8.556265830993652
Epoch: 3, Loss: 8.483915328979492
Epoch: 4, Loss: 8.432638168334961
Epoch: 5, Loss: 8.376571655273438
Epoch: 6, Loss: 8.31143856048584
Epoch: 7, Loss: 8.231369018554688
Epoch: 8, Loss: 8.148785591125488
Epoch: 9, Loss: 8.068408012390137
Epoch: 10, Loss: 7.988169193267822
Epoch: 11, Loss: 7.907508850097656
Epoch: 12, Loss: 7.820796966552734
Epoch: 13, Loss: 7.740450382232666
Epoch: 14, Loss: 7.650052547454834
Epoch: 15, Loss: 7.574429035186768
Epoch: 16, Loss: 7.4861321449279785
Epoch: 17, Loss: 7.406088352203369
Epoch: 18, Loss: 7.326938629150391
Epoch: 19, Loss: 7.251533508300781
Epoch: 20, Loss: 7.166291236877441
Epoch: 21, Loss: 7.081549644470215
Epoch: 22, Loss: 6.999995231628418
Epoch: 23, Loss: 6.930159091949463
Epoch: 24, Loss: 6.8478007316589355
Epoch: 25, Loss: 6.7815070152282715
Epoch: 26, Loss: 6.703404426574707
Epoch: 27, Loss: 6.626461029052734
Epoch: 28, Loss: 6.5552897453308105
Epoch: 29, Loss: 6.4825544