# Understanding Transformers using PyTorch
https://www.geeksforgeeks.org/deep-learning/transformer-using-pytorch/

<img src="https://media.geeksforgeeks.org/wp-content/uploads/20250325174552667398/transformer.png">

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy

## MultiHeadAttention

In [2]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model # embedding size (e.g. 512)
        self.num_heads = num_heads # number of attention heads (e.g. 8)

        # Ensuring each head gets equal dimensions
        self.d_k = d_model // num_heads # (e.g. 512 // 8 = 64)
        print(f"Number of dimensions each head gets: {self.d_k}")

        # Creating Q(Query), K(Key), V(Value) from the input embeddings.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        # Output Layer: This combines all attention heads back into one vector.
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        '''
            This is the core attention computation.
            Compute attention scores by taking the dot product of Q and K, scaling the result and applying softmax to normalise.
            - Measures similarity between Q and K
            - Division by √d_k prevents extremely large values → stabilizes training
            - Mask used for:
                - Padding mask
                - Causal (future-token) masking
            - Apply softmax = Converts scores into probabilities.
            - Softmax example: 
                tensor([-0.8058, -0.9375,  1.2299,  0.2358, -1.0952,  0.0997,  0.8335,  2.3506, -0.3834,  0.1132]) ----> 
                tensor([0.0207, 0.0182, 0.1587, 0.0587, 0.0155, 0.0512, 0.1067, 0.4867, 0.0316, 0.0519])
            - Multiply attn_probs with V
        '''
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k) # (batch, heads, seq_len, seq_len) seq_len = -1 and -2

        # Applying mask if needed,
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        # Converts scores into probabilities.
        attn = torch.softmax(scores, dim=-1)

        # Multiply attn_probs with V
        output = torch.matmul(attn, V)
        return output

    def split_heads(self, x):
        '''
            - Input Shape: (batch_size, seq_length, d_model)
            - Transform to: (batch_size, num_heads, seq_length, d_k)
            - ✔ Allows parallel attention across heads.
        '''
        batch_size, seq_length, d_model = x.size() # example: (32, 512, 512)
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2) # example: (32, 8, 512, 64)
        
    def combine_heads(self, x):
        '''
            - Input shape: (batch, heads, seq_len, d_k)
            - Output shape: (batch, seq_len, d_model)
            - ✔ Merges all heads into a single vector.
        '''
        batch_size, _, seq_len, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)   # Returns a contiguous in-memory tensor containing the same data as self tensor. If self tensor is already in the specified memory format, this function returns the self tensor.

    def forward(self, Q, K, V, mask=None):
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        attn_output = self.scaled_dot_product_attention(Q, K, V, mask=None)

        output = self.W_o(self.combine_heads(attn_output))
        return output        

### Shape Summary

| Step          | Shape                          |
| ------------- | ------------------------------ |
| Input         | `(batch, seq_len, d_model)`    |
| Split heads   | `(batch, heads, seq_len, d_k)` |
| Attention     | `(batch, heads, seq_len, d_k)` |
| Combine heads | `(batch, seq_len, d_model)`    |
| Final output  | `(batch, seq_len, d_model)`    |

### Why Multi-Head Attention is Powerful

- ✔ Captures multiple relationships
- ✔ Works in parallel (faster than RNNs)
- ✔ Handles long-range dependencies
- ✔ Foundation of Transformers

## Position-Wise Feed Forward

In a Transformer, after attention mixes information across tokens, the FFN:
- Processes each token independently
- Adds non-linearity
- Expands and compresses feature space

Think of it as a small neural network applied to every token separately.

In [3]:
class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        '''
            d_model: embedding size
            d_ff: hidden size of FFN
        '''
        super(PositionWiseFeedForward, self).__init__()

        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

        self.relu = nn.ReLU() # introduces non-linearity

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

### Positional Encoding
This block defines the Positional Encoding class, which adds positional information to the token embeddings, allowing the model to retain information about word positions in the input sequence.

This class implements Sinusoidal Positional Encoding, which is essential for Transformers because attention alone has no sense of order.

Transformers:
- Do not use RNNs or CNNs
- Process all tokens in parallel
- Have no inherent notion of sequence order

So we inject position information into embeddings.

This class creates fixed (non-learned) sinusoidal positional encodings.

In [4]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()

        # Create Empty Positional Encoding Matrix
        pe = torch.zeros(max_seq_length, d_model)

        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        '''
            creates:
                [[0],
                 [1],
                 [2],
                 ...
                 [max_seq_length-1]]
            shape:
                (max_seq_length, 1)
        '''

        # This generates different wavelengths for different dimensions.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        )

        # Apply Sine to Even Indices # Encodes position smoothly
        pe[:, 0::2] = torch.sin(position * div_term)
        # Apply Cosine to Odd Indices # Paired with sine for phase-shift encoding
        pe[:, 1::2] = torch.cos(position * div_term)

        # register_buffer so that during training it doesn't get updated 
        self.register_buffer('pe', pe.unsqueeze(0)) # (1, max_seq_length, d_model)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

## Encoder Layer

<img src="https://media.datacamp.com/legacy/v1691083306/Figure_2_The_Encoder_part_of_the_transformer_network_Source_image_from_the_original_paper_b0e3ac40fa.png">

This block defines the Encoder Layer class, which contains the multi-head attention mechanism and the position-wise feed-forward network, with layer normalisation and dropout applied.

This class is one complete Transformer Encoder block.

It combines everything you’ve learned so far:
- Multi-Head Self-Attention
- Position-wise Feed-Forward Network
- Residual connections
- Layer Normalisation
- Dropout

#### What Is an Encoder Layer?
In a Transformer encoder, each layer does two things:
- Self-attention → tokens look at each other
- Feed-forward network → each token is refined individually

This block is stacked N times (e.g., 6 or 12 layers).

In [5]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        '''
            - d_model: embedding size (e.g. 512)
            - num_heads: attention heads (e.g. 8)
            - d_ff: FFN hidden size (e.g. 2048)
            - dropout: regularisation probability
        '''
        
        super(EncoderLayer, self).__init__()

        self.self_attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model=d_model, d_ff=d_ff)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x

In [6]:
encoder = EncoderLayer(512, 8, 2048, 0.2)

Number of dimensions each head gets: 64


In [7]:
encoder

EncoderLayer(
  (self_attn): MultiHeadAttention(
    (W_q): Linear(in_features=512, out_features=512, bias=True)
    (W_k): Linear(in_features=512, out_features=512, bias=True)
    (W_v): Linear(in_features=512, out_features=512, bias=True)
    (W_o): Linear(in_features=512, out_features=512, bias=True)
  )
  (feed_forward): PositionWiseFeedForward(
    (fc1): Linear(in_features=512, out_features=2048, bias=True)
    (fc2): Linear(in_features=2048, out_features=512, bias=True)
    (relu): ReLU()
  )
  (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (dropout): Dropout(p=0.2, inplace=False)
)

In [8]:

# Hyperparameters
batch_size = 2
seq_len = 5
d_model = 8
num_heads = 2
d_ff = 32
dropout = 0.1

# Random input
x = torch.randn(batch_size, seq_len, d_model)
x

tensor([[[ 0.2160, -1.4414,  0.9248,  0.8221, -0.7134, -0.5704,  0.9568,
           0.5216],
         [ 0.6904,  0.5042, -0.3047, -1.8446, -0.9328, -0.6943, -0.1510,
          -1.1770],
         [-0.6040,  2.2395, -0.7380,  1.1772,  0.5163,  1.3221,  1.1290,
           0.4016],
         [-1.3940, -0.1856,  1.7577, -1.0043, -0.1822,  0.2720,  0.2024,
           2.0691],
         [-1.1455, -0.2866,  0.0954, -0.3493, -1.8010,  0.1306, -1.1826,
           0.1104]],

        [[ 1.0447,  1.0268,  0.0237, -1.0587,  2.6600, -1.0581, -1.0108,
           2.4103],
         [-0.2807,  0.1621,  0.3188, -0.3377,  1.0698, -0.2189,  0.0544,
           0.4956],
         [ 1.2000,  0.4260,  0.0424,  0.9311, -0.1950, -0.9808, -0.2860,
          -0.5460],
         [ 0.1338,  0.7567,  1.0144,  0.3682, -0.3540,  0.5332, -0.1951,
           0.9310],
         [ 0.6857, -0.1839,  2.0093, -0.5572,  0.5208,  1.3920,  0.7627,
          -0.0888]]])

In [9]:
# Optional mask (1 = keep, 0 = mask)
mask = torch.ones(batch_size, 1, 1, seq_len)

# Encoder
encoder = EncoderLayer(d_model, num_heads, d_ff, dropout)

# Forward pass
output = encoder(x, mask)

# Results
print("Input shape :", x.shape)
print("Output shape:", output.shape)
print("\nSample output tensor:\n", output)

Number of dimensions each head gets: 4
Input shape : torch.Size([2, 5, 8])
Output shape: torch.Size([2, 5, 8])

Sample output tensor:
 tensor([[[ 0.1160, -1.9058,  1.2555,  0.4634, -0.0704, -1.2425,  0.8607,
           0.5231],
         [ 1.4181,  1.1662,  0.5641, -1.6535,  0.3372, -0.7753, -0.1390,
          -0.9179],
         [-0.6727,  1.8312, -1.7930,  0.4063, -0.4296,  0.6819,  0.2747,
          -0.2988],
         [-1.2074, -0.4310,  1.5404, -1.2539,  0.5126, -0.3134, -0.2561,
           1.4088],
         [-0.5458,  0.4601,  1.6484, -0.3289, -1.2459,  0.3304, -1.3742,
           1.0559]],

        [[ 0.3426,  0.1328, -0.3007, -0.6316,  2.0421, -1.2133, -1.0981,
           0.7263],
         [-0.5920, -0.0828,  0.1886, -0.5443,  2.3220, -1.1768, -0.5710,
           0.4563],
         [ 1.6055,  0.5703, -0.2807,  1.1716, -0.0110, -1.6115, -0.6693,
          -0.7750],
         [-0.4545,  1.0567,  1.6231, -0.4368, -0.9167, -0.6329, -1.2557,
           1.0167],
         [ 0.5090, -0.9316

## Decoder Layer 

<img src="https://media.datacamp.com/legacy/v1691083444/Figure_3_The_Decoder_part_of_the_Transformer_network_Souce_Image_from_the_original_paper_b90d9e7f66.png">

This block defines the Decoder Layer class, which is similar to the encoder layer but also includes a cross-attention mechanism to attend to the encoder’s output.

It is more complex than the encoder because the decoder must:
- Look at past generated tokens only (masked self-attention)
- Look at the encoder output (cross-attention)
- Refine features with a feed-forward network

The decoder block has 3 sub-blocks
1) Masked Self-Attention
2) Encoder–Decoder (Cross) Attention
3) Feed-Forward Network


In [13]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(DecoderLayer, self).__init__()

        # Masked Self-Attention
        self.self_attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        
        # Encoder-Decoder (Cross) Attention
        self.cross_attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads)

        # Feed-Forward Network
        self.feed_forward = PositionWiseFeedForward(d_model=d_model, d_ff=d_ff)

        # Normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        # Dropout
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        '''
            x: decoder input
            (batch, tgt_seq_len, d_model)
            
            enc_output: encoder output
            (batch, src_seq_len, d_model)
            
            src_mask: mask for encoder tokens (padding)
            
            tgt_mask: mask for decoder tokens (causal + padding)
        '''

        # Masked Self-Attention (Decoder → Decoder)
        # Q = K = V = x | tgt_mask prevents attending to future tokens | Enables autoregressive generation
        attn_output = self.self_attn(x, x, x, tgt_mask)

        x = self.norm1(x + self.dropout(attn_output)) # Stabilizes gradients and preserves information.

        # Cross-Attention (Decoder → Encoder)
        # Query = decoder states (x) | Key & Value = encoder output (enc_output) | src_mask hides encoder padding tokens
        attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)

        x = self.norm2(x + self.dropout(attn_output))

        # Feed Forward Network
        ff_output = self.feed_forward(x)

        x = self.norm3(x + self.dropout(ff_output))

        return x  # Output shape: (batch, tgt_seq_len, d_model)

## Transformer Model

<img src="https://media.datacamp.com/legacy/v1691083566/Figure_4_The_Transformer_Network_Source_Image_from_the_original_paper_120e177956.png">

This block defines the main Transformer class which combines the encoder and decoder layers. It also includes the embedding layers and the final output layer.

This model is designed for sequence-to-sequence tasks, such as:
- Machine Translation
- Text Summarization
- Question Answering
- Code Generation

```
Source sentence (src)
   ↓
Encoder (understands input)
   ↓
Decoder (generates output)
   ↓
Vocabulary logits
```

In [34]:
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        '''
            - src_vocab_size: vocabulary size of source language
            - tgt_vocab_size: vocabulary size of target language
            - d_model: embedding dimension
            - num_heads: attention heads
            - num_layers: encoder/decoder layers
            - d_ff: feed-forward hidden dimension
            - max_seq_length: maximum sequence length
            - dropout: regularization rate
        '''

        super(Transformer, self).__init__()

        # Convert token IDs → dense vectors
        self.endoder_embedding = nn.Embedding(num_embeddings=src_vocab_size, embedding_dim=d_model)
        self.decoder_embedding = nn.Embedding(num_embeddings=tgt_vocab_size, embedding_dim=d_model)

        # Add positional information so order matters
        self.positional_encoding = PositionalEncoding(d_model=d_model, max_seq_length=max_seq_length)

        '''
            Creates a stack of encoder layers.
            Each encoder layer:
                Self-attention
                Feed-forward
                Add & Norm
        '''
        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        ) # Holds submodules in a list.

        '''
            Creates a stack of decoder layers.
            Each decoder layer:
                Masked self-attention
                Cross-attention (encoder–decoder)
                Feed-forward
        '''
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        ) # Holds submodules in a list.

        # Converts decoder outputs into logits over target vocabulary.
        self.fc = nn.Linear(d_model, tgt_vocab_size)
        
        self.dropout = nn.Dropout(dropout)

    def generate_mask(self, src, tgt):
        ''' This is crucial for correct Transformer behavior. '''

        device = src.device
        
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2) # Shape: (batch, 1, 1, src_len) # Masks out <PAD> tokens in the encoder
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3) # Shape: (batch, 1, tgt_len, 1) # Masks <PAD> tokens in decoder input.

        seq_length = tgt.size(1) 

        # No-Peek (Causal) Mask: Creates a lower-triangular matrix:
        nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length, device=device), diagonal=1)).bool() # Prevents decoder from seeing future tokens
        
        tgt_mask = tgt_mask & nopeak_mask
        return src_mask, tgt_mask

    def forward(self, src, tgt):
        '''
            src: source token IDs (batch, src_len)
            tgt: target token IDs (batch, tgt_len)
        '''
        src_mask, tgt_mask = self.generate_mask(src, tgt)

        src_embedded = self.dropout(self.positional_encoding(self.endoder_embedding(src)))
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)

        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

        output = self.fc(dec_output)
        return output

## Training and Testing the Transformer using the dummy dataset

In [35]:
import torch
from torch.utils.data import Dataset, DataLoader

class DummyTranslationDataset(Dataset):
    def __init__(self, num_samples=1000, seq_len=10, vocab_size=50):
        self.data = []
        for _ in range(num_samples):
            src = torch.randint(3, vocab_size, (seq_len,))
            tgt = torch.randint(3, vocab_size, (seq_len,))
            self.data.append((src, tgt))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


In [36]:
dataset = DummyTranslationDataset()
loader = DataLoader(dataset, batch_size=32, shuffle=True)


In [37]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

print(device)

model = Transformer(
    src_vocab_size=50,
    tgt_vocab_size=50,
    d_model=128,
    num_heads=4,
    num_layers=4,
    d_ff=512,
    max_seq_length=50,
    dropout=0.1
).to(device)


mps
Number of dimensions each head gets: 32
Number of dimensions each head gets: 32
Number of dimensions each head gets: 32
Number of dimensions each head gets: 32
Number of dimensions each head gets: 32
Number of dimensions each head gets: 32
Number of dimensions each head gets: 32
Number of dimensions each head gets: 32
Number of dimensions each head gets: 32
Number of dimensions each head gets: 32
Number of dimensions each head gets: 32
Number of dimensions each head gets: 32


In [38]:
criterion = torch.nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=3e-4,
    betas=(0.9, 0.98),
    eps=1e-9
)


In [39]:
epochs = 10

for epoch in range(epochs):
    model.train()
    total_loss = 0

    for src, tgt in loader:
        src = src.to(device)
        tgt = tgt.to(device)

        # Teacher forcing
        tgt_input = tgt[:, :-1]
        tgt_output = tgt[:, 1:]

        optimizer.zero_grad()

        # Forward pass
        logits = model(src, tgt_input)

        # Reshape for loss
        logits = logits.reshape(-1, logits.size(-1))
        tgt_output = tgt_output.reshape(-1)

        loss = criterion(logits, tgt_output)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1} | Loss: {total_loss / len(loader):.4f}")


Epoch 1 | Loss: 3.8924
Epoch 2 | Loss: 3.7792
Epoch 3 | Loss: 3.4974
Epoch 4 | Loss: 3.1723
Epoch 5 | Loss: 2.7903
Epoch 6 | Loss: 2.2669
Epoch 7 | Loss: 1.6176
Epoch 8 | Loss: 1.1770
Epoch 9 | Loss: 0.9573
Epoch 10 | Loss: 0.8430


In [42]:
model.eval()

def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    total_tokens = 0
    correct_tokens = 0

    with torch.no_grad():
        for src, tgt in dataloader:
            src = src.to(device)
            tgt = tgt.to(device)

            # Teacher forcing
            tgt_input = tgt[:, :-1]
            tgt_output = tgt[:, 1:]

            logits = model(src, tgt_input)

            # Loss
            logits_flat = logits.reshape(-1, logits.size(-1))
            tgt_output_flat = tgt_output.reshape(-1)

            loss = criterion(logits_flat, tgt_output_flat)
            total_loss += loss.item()

            # Accuracy (ignore PAD tokens)
            predictions = torch.argmax(logits, dim=-1)
            mask = tgt_output != 0  # PAD = 0

            correct_tokens += (predictions == tgt_output).masked_select(mask).sum().item()
            total_tokens += mask.sum().item()

    avg_loss = total_loss / len(dataloader)
    accuracy = correct_tokens / total_tokens if total_tokens > 0 else 0

    return avg_loss, accuracy

eval_loss, eval_acc = evaluate(model, loader, criterion, device)

print(f"Evaluation Loss: {eval_loss:.4f}")
print(f"Token Accuracy : {eval_acc * 100:.2f}%")


Evaluation Loss: 0.4719
Token Accuracy : 89.46%
