# Transformer Using PyTorch
Transformers have revolutionized Natural Language Processing (NLP) by using attention to capture dependencies within sequences. Let’s break the architecture down and implement it from scratch using PyTorch.


The implementation is based on the paper [*Attention Is All You Need*](https://arxiv.org/abs/1706.03762).

<img src="https://miro.medium.com/v2/resize:fit:828/format:webp/1*BHzGVskWGS_3jEcYYi6miQ.png" width="500"/>

In [9]:
import torch
import torch.nn as nn
import math

## Input Embedding (20)

The embedding layer converts each token of the original sentence into a vector of `d_model` dimensions.

In [None]:
class InputEmbeddings(nn.Module):
    def __init__(self, d_model: int, vocab_size: int) -> None:
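        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        # (batch, seq_len) --> (batch, seq_len, d_model), scaled by sqrt(d_model) as in the paper.
        # Note: x holds token ids here, so x.size(-1) would be seq_len, not d_model.
        return self.embedding(x) * math.sqrt(self.d_model)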
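
To make the scaling step concrete, here is a minimal plain-Python sketch (no PyTorch needed). The table `toy_table` and the helper `embed_and_scale` are hypothetical names used only for illustration; the point is that every looked-up vector is multiplied by $\sqrt{d_{model}}$, as the paper describes.

```python
import math

# Hypothetical toy embedding table: 3 tokens, d_model = 4.
d_model = 4
toy_table = [
    [0.1, 0.2, 0.3, 0.4],   # token id 0
    [0.5, 0.6, 0.7, 0.8],   # token id 1
    [0.9, 1.0, 1.1, 1.2],   # token id 2
]

def embed_and_scale(token_ids):
    """Look up each token id and multiply its vector by sqrt(d_model)."""
    scale = math.sqrt(d_model)
    return [[value * scale for value in toy_table[t]] for t in token_ids]

scaled = embed_and_scale([2, 0])
print(scaled)
```

With `d_model = 4` the scale factor is 2, so each entry is simply doubled; with `d_model = 512` it would be about 22.6.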

## PositionalEncoding Class (20)

Positional encoding is a crucial component in transformer models, which helps the model understand the position of each word in a sentence.

**Mathematical Formulation:**

For a given position *pos* and dimension-pair index *i*:

$PE_{(pos,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)$

$PE_{(pos,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)$

where:

- $PE_{(pos,2i)}$ is the value of the positional encoding at position *pos* for the even dimension $2i$.
- $PE_{(pos,2i+1)}$ is the value of the positional encoding at position *pos* for the odd dimension $2i+1$.
- $d_{model}$ is the dimension of the embedding (e.g. 512).

In [None]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        super().__init__()
        position = torch.arange(seq_len).unsqueeze(1)  # [seq_len, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(seq_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)  # even indices
        pe[:, 1::2] = torch.cos(position * div_term)  # odd indices

        self.register_buffer('pe', pe)

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        seq_len = x.size(1)
        x = x + self.pe[:seq_len, :].unsqueeze(0)  # broadcast to [batch_size, seq_len, d_model]
        x = self.dropout(x)
        return x
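
The `div_term` line above writes the denominator in an exp/log form rather than as a power. A quick plain-Python check, under the assumption that both forms are meant to be the same quantity, shows they agree; the two helper names are hypothetical.

```python
import math

def div_term_direct(d_model):
    """1 / 10000^(2i / d_model) for each even dimension index 2i."""
    return [1.0 / (10000 ** (two_i / d_model)) for two_i in range(0, d_model, 2)]

def div_term_exp_log(d_model):
    """The same quantity written as exp(2i * (-ln(10000) / d_model))."""
    return [math.exp(two_i * (-math.log(10000.0) / d_model))
            for two_i in range(0, d_model, 2)]

direct = div_term_direct(8)
via_exp = div_term_exp_log(8)
for a, b in zip(direct, via_exp):
    print(f"{a:.6f}  {b:.6f}")
```

For `d_model = 8` both columns run through 1, 0.1, 0.01, 0.001, so the frequencies form a geometric progression.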

## FeedForwardBlock Class

The feed-forward block is a position-wise fully connected network that the transformer uses in both the encoder and the decoder. It consists of two linear transformations with a ReLU activation in between, which adds non-linearity and lets the model learn more complex patterns.

In [12]:
class FeedForwardBlock(nn.Module):

    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff) # w1 and b1
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model) # w2 and b2

    def forward(self, x):
        # (batch, seq_len, d_model) --> (batch, seq_len, d_ff) --> (batch, seq_len, d_model)
        return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))
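
To see what the block computes for a single position, here is a tiny hand-checkable sketch in plain Python. It assumes evaluation mode (dropout omitted) and stores weights as (in_features, out_features) for readability, which differs from how `nn.Linear` lays out its weight; all names and numbers are illustrative.

```python
def relu(v):
    return [max(0.0, a) for a in v]

def linear(v, weight, bias):
    """Compute v @ weight + bias; weight has shape (len(v), len(bias))."""
    return [sum(v[i] * weight[i][j] for i in range(len(v))) + bias[j]
            for j in range(len(bias))]

def feed_forward(v, w1, b1, w2, b2):
    """FFN(v) = ReLU(v W1 + b1) W2 + b2, applied to one position's vector."""
    return linear(relu(linear(v, w1, b1)), w2, b2)

# Toy sizes: d_model = 2, d_ff = 3.
w1 = [[1.0, -1.0, 0.5],
      [0.0, 2.0, -1.0]]
b1 = [0.0, 0.0, 0.0]
w2 = [[1.0, 0.0],
      [0.0, 1.0],
      [1.0, 1.0]]
b2 = [0.5, -0.5]

out = feed_forward([1.0, 1.0], w1, b1, w2, b2)
print(out)  # hidden layer is [1, 1, -0.5] -> ReLU -> [1, 1, 0] -> [1.5, 0.5]
```

The same weights are applied independently at every position, which is why the block maps `(batch, seq_len, d_model)` back to the same shape.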

## MultiHeadAttentionBlock Class (50)

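Multi-head attention is a core component of the transformer architecture, enabling the model to focus on different parts of the input sequence simultaneously. Each head projects queries, keys, and values into a subspace of size `d_k = d_model / h` and applies scaled dot-product attention there; the heads' outputs are then concatenated and passed through a final linear projection.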
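
Before the batched PyTorch version, it may help to see scaled dot-product attention for a single head on nested lists. This is a minimal sketch, assuming every query can see at least one key (otherwise the softmax row would be undefined); the function names are hypothetical.

```python
import math

def softmax(row):
    peak = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(queries, keys, values, mask=None):
    """Single-head attention: softmax(Q K^T / sqrt(d_k)) V.

    queries: (len_q, d_k), keys: (len_k, d_k), values: (len_k, d_v).
    mask[i][j] == 0 hides key j from query i.
    """
    d_k = len(queries[0])
    outputs, weights = [], []
    for i, q in enumerate(queries):
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in keys]
        if mask is not None:
            scores = [s if mask[i][j] else float("-inf") for j, s in enumerate(scores)]
        w = softmax(scores)
        weights.append(w)
        outputs.append([sum(w[j] * values[j][c] for j in range(len(values)))
                        for c in range(len(values[0]))])
    return outputs, weights

q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
causal = [[1, 0], [1, 1]]  # query 0 may only look at key 0

out, w = attention(q, k, v, causal)
print(w[0])    # all weight on key 0
print(out[0])  # therefore equal to values[0]
```

Masked scores become negative infinity, so their softmax weight is exactly zero and the first query simply copies the first value vector.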

In [None]:
class MultiHeadAttentionBlock(nn.Module):

    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.h = h
        assert d_model % h == 0, "d_model must be divisible by number of heads h"

        self.d_k = d_model // h
        self.w_q = nn.Linear(d_model, d_model, bias=False) # Wq
        self.w_k = nn.Linear(d_model, d_model, bias=False) # Wk
        self.w_v = nn.Linear(d_model, d_model, bias=False) # Wv
        self.w_o = nn.Linear(d_model, d_model, bias=False) # Wo
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        
        # (batch, h, seq_len, d_k) x (batch, h, d_k, seq_len) = (batch, h, seq_len, seq_len)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn = torch.softmax(scores, dim=-1)  

        if dropout is not None:
            attn = dropout(attn)

        output = torch.matmul(attn, value)

        return output

    def forward(self, x_q, x_k, x_v, mask):
        Q = self.w_q(x_q)  
        K = self.w_k(x_k)  
        V = self.w_v(x_v)  

        batch_size = x_q.size(0)

        Q = Q.view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.h, self.d_k).transpose(1, 2) 
        V = V.view(batch_size, -1, self.h, self.d_k).transpose(1, 2) 

        
        out = self.attention(Q, K, V, mask, self.dropout)
        out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        out = self.w_o(out)

        return out
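
The `view(...).transpose(1, 2)` calls above split the last dimension into `h` contiguous chunks of size `d_k`. A plain-Python sketch for one sequence, with hypothetical helper names, shows the split and the inverse merge:

```python
def split_heads(x, h):
    """(seq_len, d_model) -> (h, seq_len, d_k) for one sequence, as nested lists."""
    d_model = len(x[0])
    assert d_model % h == 0, "d_model must be divisible by number of heads h"
    d_k = d_model // h
    return [[row[head * d_k:(head + 1) * d_k] for row in x] for head in range(h)]

def merge_heads(heads):
    """(h, seq_len, d_k) -> (seq_len, d_model): concatenate the heads per position."""
    seq_len = len(heads[0])
    return [[value for head in heads for value in head[pos]] for pos in range(seq_len)]

x = [[1, 2, 3, 4],
     [5, 6, 7, 8]]  # seq_len = 2, d_model = 4

heads = split_heads(x, h=2)
print(heads)               # head 0 gets features 0-1, head 1 gets features 2-3
print(merge_heads(heads))  # round-trips back to x
```

Each head therefore attends over the full sequence but sees only its own slice of the feature dimension.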


## ResidualConnection Class

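Residual connections, or skip connections, help train deep neural networks by allowing gradients to flow more easily through the network. The class below implements the "Add & Norm" step: it adds the sublayer's output (after dropout) to the sublayer's input and then applies layer normalization.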

In [14]:
class ResidualConnection(nn.Module):
    def __init__(self, d_model: int, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, sublayer_output):
        # Post-norm residual: LayerNorm(x + Dropout(sublayer(x)))
        return self.norm(x + self.dropout(sublayer_output))
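
As a rough illustration of the add-then-normalize step on a single vector, here is a plain-Python sketch. It assumes evaluation mode (no dropout) and a layer norm with its learned scale and shift left at their initial values of 1 and 0; the function names are made up for this example.

```python
import math

def layer_norm(v, eps=1e-5):
    """Normalize a vector to zero mean and unit variance (no learned scale/shift)."""
    mean = sum(v) / len(v)
    var = sum((a - mean) ** 2 for a in v) / len(v)  # biased variance
    return [(a - mean) / math.sqrt(var + eps) for a in v]

def add_and_norm(x, sublayer_out):
    """Post-norm residual: LayerNorm(x + sublayer_out)."""
    return layer_norm([a + b for a, b in zip(x, sublayer_out)])

y = add_and_norm([1.0, 2.0, 3.0, 4.0], [0.5, -0.5, 0.5, -0.5])
print(y)  # the sum is [1.5, 1.5, 3.5, 3.5], which normalizes to about [-1, -1, 1, 1]
```

The residual sum keeps the original signal available to later layers, while the normalization keeps its scale stable from block to block.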

## EncoderBlock Class (30)

Now we will create the encoder block, which contains one multi-head self-attention layer and one feed-forward layer, each followed by an Add & Norm step (`ResidualConnection`).

In [None]:
class EncoderBlock(nn.Module):
    def __init__(self, d_model: int, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention = self_attention_block
        self.residual1 = ResidualConnection(d_model, dropout)
        self.feed_forward = feed_forward_block
        self.residual2 = ResidualConnection(d_model, dropout)

    def forward(self, x, src_mask):
        attn_output = self.self_attention(x, x, x, src_mask)
        x = self.residual1(x, attn_output)
        ff_output = self.feed_forward(x)
        x = self.residual2(x, ff_output)
        return x

In [31]:
class Encoder(nn.Module):
    def __init__(self, d_model: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

## DecoderBlock Class (30)

The `DecoderBlock` class represents a single block of the Transformer decoder. Each decoder block contains a self-attention mechanism, a cross-attention mechanism (attending to the encoder's output), and a feed-forward network, all surrounded by residual connections and layer normalization.

In [None]:
class DecoderBlock(nn.Module):
    def __init__(self, d_model: int,
                 self_attention_block: MultiHeadAttentionBlock,
                 cross_attention_block: MultiHeadAttentionBlock,
                 feed_forward_block: FeedForwardBlock,
                 dropout: float) -> None:
        super().__init__()
        self.self_attention = self_attention_block
        self.cross_attention = cross_attention_block
        self.feed_forward = feed_forward_block
        self.residual1 = ResidualConnection(d_model, dropout)
        self.residual2 = ResidualConnection(d_model, dropout)
        self.residual3 = ResidualConnection(d_model, dropout)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        self_attn_output = self.self_attention(x, x, x, tgt_mask)  
        x = self.residual1(x, self_attn_output)   
        cross_attn_output = self.cross_attention(x, encoder_output, encoder_output, src_mask)
        x = self.residual2(x, cross_attn_output)   
        ff_output = self.feed_forward(x)   
        x = self.residual3(x, ff_output)  
        return x

In [32]:
class Decoder(nn.Module):
    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = nn.LayerNorm(features)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        return self.norm(x)


## Projection Layer Class

The `ProjectionLayer` class is used to convert the high-dimensional vectors (output of the decoder) into logits over the vocabulary. This projection is typically the last layer in the decoder of a transformer model.

In [19]:
class ProjectionLayer(nn.Module):

    def __init__(self, d_model, vocab_size) -> None:
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x) -> torch.Tensor:
        # (batch, seq_len, d_model) --> (batch, seq_len, vocab_size)
        return self.proj(x)
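
The logits produced by the projection are usually turned into (log-)probabilities and then into a token choice. The sketch below does this for one position in plain Python; the four-word vocabulary `toy_vocab` and the logit values are invented purely for illustration.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over one position's logits."""
    peak = max(logits)
    log_total = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return [v - log_total for v in logits]

toy_vocab = ["<pad>", "ich", "bin", "hier"]  # hypothetical vocabulary
logits = [0.1, 2.0, 0.3, -1.0]               # hypothetical projection output

log_probs = log_softmax(logits)
best = max(range(len(logits)), key=lambda j: log_probs[j])  # greedy choice
print(toy_vocab[best])
```

Greedy selection is only one decoding strategy; the same log-probabilities could feed beam search or sampling instead.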

## Transformer Class (50)

The `Transformer` class encapsulates the entire transformer model, integrating both the encoder and decoder components along with embedding layers and positional encodings.

In [None]:
class Transformer(nn.Module):
    def __init__(self,
                 encoder: Encoder,
                 decoder: Decoder,
                 src_embed: InputEmbeddings,
                 tgt_embed: InputEmbeddings,
                 src_pos: PositionalEncoding,
                 tgt_pos: PositionalEncoding,
                 projection_layer: ProjectionLayer) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

    def encode(self, src: torch.Tensor, src_mask: torch.Tensor) -> torch.Tensor:
        src_embeddings = self.src_embed(src)   

        
        src_embeddings = self.src_pos(src_embeddings)  

        encoder_output = self.encoder(src_embeddings, src_mask) 
        return encoder_output

    def decode(self,
               encoder_output: torch.Tensor,
               src_mask: torch.Tensor,
               tgt: torch.Tensor,
               tgt_mask: torch.Tensor) -> torch.Tensor:
        tgt_embeddings = self.tgt_embed(tgt)  
        tgt_embeddings = self.tgt_pos(tgt_embeddings)   
        decoder_output = self.decoder(tgt_embeddings, encoder_output, src_mask, tgt_mask)  
        return decoder_output

    def project(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq_len, vocab_size)
        return self.projection_layer(x)

## Build Transformer Function (50)

`build_transformer` constructs a full Transformer model by putting together its various components, such as embedding layers, positional encoding, encoder and decoder blocks, and a final projection layer.

In [None]:
def build_transformer(src_vocab_size: int,
                      tgt_vocab_size: int,
                      src_seq_len: int,
                      tgt_seq_len: int,
                      d_model: int=512,
                      N: int=6,
                      h: int=8,
                      dropout: float=0.1,
                      d_ff: int=2048,
                      device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')) -> Transformer:
   

    src_embed = InputEmbeddings(d_model=d_model, vocab_size=src_vocab_size)
    tgt_embed = InputEmbeddings(d_model=d_model, vocab_size=tgt_vocab_size)

     
    src_pos = PositionalEncoding(d_model=d_model, seq_len=src_seq_len, dropout=dropout)
    tgt_pos = PositionalEncoding(d_model=d_model, seq_len=tgt_seq_len, dropout=dropout)

     
    encoder_layers = nn.ModuleList([
        EncoderBlock(d_model=d_model,
                    self_attention_block=MultiHeadAttentionBlock(d_model=d_model, h=h, dropout=dropout),
                    feed_forward_block=FeedForwardBlock(d_model=d_model, d_ff=d_ff, dropout=dropout),
                    dropout=dropout)
        for _ in range(N)
    ])  

     
    decoder_layers = nn.ModuleList([
        DecoderBlock(d_model=d_model,
                    self_attention_block=MultiHeadAttentionBlock(d_model=d_model, h=h, dropout=dropout),
                    cross_attention_block=MultiHeadAttentionBlock(d_model=d_model, h=h, dropout=dropout),
                    feed_forward_block=FeedForwardBlock(d_model=d_model, d_ff=d_ff, dropout=dropout),
                    dropout=dropout)
        for _ in range(N)
    ])

    encoder = Encoder(d_model=d_model, layers=encoder_layers).to(device)
    decoder = Decoder(features=d_model, layers=decoder_layers).to(device)

    projection_layer = ProjectionLayer(d_model=d_model, vocab_size=tgt_vocab_size).to(device)

    transformer = Transformer(encoder=encoder,
                              decoder=decoder,
                              src_embed=src_embed,
                              tgt_embed=tgt_embed,
                              src_pos=src_pos,
                              tgt_pos=tgt_pos,
                              projection_layer=projection_layer).to(device)

    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)


    return transformer
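
To get a feel for the size of the model built with the default arguments, the parameter count of one block can be worked out by hand. This is a back-of-the-envelope sketch based on the classes defined above (attention projections without bias, one `LayerNorm` per `ResidualConnection`); it leaves out embeddings, the final norms, and the projection layer, and the helper names are hypothetical.

```python
def attention_params(d_model):
    # Four d_model x d_model projections (Wq, Wk, Wv, Wo), all without bias.
    return 4 * d_model * d_model

def feed_forward_params(d_model, d_ff):
    # Two linear layers, each with a bias vector.
    return (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)

def layer_norm_params(d_model):
    return 2 * d_model  # learned scale and shift

def encoder_block_params(d_model, d_ff):
    return (attention_params(d_model)
            + feed_forward_params(d_model, d_ff)
            + 2 * layer_norm_params(d_model))

def decoder_block_params(d_model, d_ff):
    return (2 * attention_params(d_model)
            + feed_forward_params(d_model, d_ff)
            + 3 * layer_norm_params(d_model))

enc = encoder_block_params(d_model=512, d_ff=2048)
dec = decoder_block_params(d_model=512, d_ff=2048)
print(f"encoder block: {enc:,}")  # 3,150,336
print(f"decoder block: {dec:,}")  # 4,199,936
```

With `N = 6` blocks on each side, the stacked blocks alone account for roughly 44 million parameters before the vocabulary-sized layers are added.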

## Testing the model

Here is a simple test to verify that the transformer is implemented correctly. Run the code below and check that both the training and validation losses decrease steadily.



In [None]:
!pip install datasets sentencepiece transformers

In [None]:
import torch
from datasets import load_dataset
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from transformers import BertTokenizer
from tqdm.notebook import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load WMT14 English-German Translation Dataset (test split is enough for our purpose)
dataset = load_dataset('wmt14', 'de-en', split='test')

# Initialize Tokenizer (use a pretrained tokenizer for simplicity)
src_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')  # English tokenizer
tgt_tokenizer = BertTokenizer.from_pretrained('bert-base-german-cased')  # German tokenizer

# Preprocess data (Tokenization and Padding)
def tokenize_data(batch):
    src = src_tokenizer(batch['translation']['en'], padding="max_length", truncation=True, max_length=32)
    tgt = tgt_tokenizer(batch['translation']['de'], padding="max_length", truncation=True, max_length=32)
    return {'src_input_ids': src['input_ids'], 'tgt_input_ids': tgt['input_ids']}

dataset = dataset.map(tokenize_data)

In [41]:
# Set vocab sizes
src_vocab_size = src_tokenizer.vocab_size
tgt_vocab_size = tgt_tokenizer.vocab_size

# Define model parameters
src_seq_len = 32  # Max length of source sequences
tgt_seq_len = 32  # Max length of target sequences
d_model = 512
N = 6  # Number of layers
h = 8  # Number of heads
dropout = 0.1
d_ff = 2048

# Build Transformer Model
transformer = build_transformer(src_vocab_size, tgt_vocab_size, src_seq_len, tgt_seq_len, d_model, N, h, dropout, d_ff).to(device)

# Loss function and optimizer
criterion = CrossEntropyLoss(ignore_index=0)  # Ignore padding index
optimizer = Adam(transformer.parameters(), lr=2e-5)

def create_src_mask(src_input, pad_idx=0):
    """Create a mask for the source to hide padding tokens."""
    src_mask = (src_input != pad_idx).unsqueeze(1).unsqueeze(2)  # (batch_size, 1, 1, seq_len)
    return src_mask

def create_tgt_mask(tgt_input, pad_idx=0):
    """Create a target mask to hide future tokens (causal mask) and padding tokens."""
    batch_size, tgt_len = tgt_input.shape
    # Causal mask to prevent looking ahead
    causal_mask = torch.tril(torch.ones(tgt_len, tgt_len)).bool().to(tgt_input.device).unsqueeze(0)  # (1, tgt_len, tgt_len)
    # Padding mask
    pad_mask = (tgt_input != pad_idx).unsqueeze(1).unsqueeze(2)  # (batch_size, 1, 1, tgt_len)

    # Combine the causal mask and padding mask
    tgt_mask = causal_mask & pad_mask.squeeze(1)  # (batch_size, tgt_len, tgt_len)
    return tgt_mask.unsqueeze(1)  # (batch_size, 1, tgt_len, tgt_len)
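
The effect of combining the causal mask with the padding mask is easiest to see on one short sequence. The following plain-Python sketch mirrors the logic of the two helpers above for a single example; the token ids are hypothetical, with 0 standing for padding.

```python
def source_mask(token_ids, pad_idx=0):
    """1 where the key is a real token, 0 where it is padding."""
    return [1 if t != pad_idx else 0 for t in token_ids]

def target_mask(token_ids, pad_idx=0):
    """mask[i][j] == 1 only if key j is not in the future (j <= i) and is not padding."""
    n = len(token_ids)
    return [[1 if j <= i and token_ids[j] != pad_idx else 0 for j in range(n)]
            for i in range(n)]

tgt = [101, 7, 8, 0]  # hypothetical ids; the last position is padding
mask = target_mask(tgt)
for row in mask:
    print(row)
# [1, 0, 0, 0]
# [1, 1, 0, 0]
# [1, 1, 1, 0]
# [1, 1, 1, 0]
```

The lower-triangular pattern blocks future positions, and the all-zero last column hides the padding token from every query.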

# Training Loop
for epoch in range(10):
    transformer.train()
    train_loss = 0
    val_loss = 0

    # Training
    for i in tqdm(range(0, 2000, 32)):
        end = min(i + 32, 2000)  # keep the last training batch out of the validation range
        src_input = torch.tensor(dataset[i:end]['src_input_ids']).to(device)
        tgt_input = torch.tensor(dataset[i:end]['tgt_input_ids']).to(device)

        # Create masks
        src_mask = create_src_mask(src_input).to(device)
        tgt_mask = create_tgt_mask(tgt_input[:, :-1]).to(device)  # Apply mask only on the decoder input sequence

        # Forward pass
        optimizer.zero_grad()
        encoder_output = transformer.encode(src_input, src_mask)
        decoder_output = transformer.decode(encoder_output, src_mask, tgt_input[:, :-1], tgt_mask)
        output = transformer.project(decoder_output)

        # Calculate loss
        loss = criterion(output.view(-1, tgt_vocab_size), tgt_input[:, 1:].reshape(-1))
        train_loss += loss.item()

        # Backpropagation
        loss.backward()
        optimizer.step()

    transformer.eval()
    # Evaluation
    for i in tqdm(range(2000, len(dataset), 32)):
        with torch.no_grad():
            src_input = torch.tensor(dataset[i:i+32]['src_input_ids']).to(device)
            tgt_input = torch.tensor(dataset[i:i+32]['tgt_input_ids']).to(device)

            # Create masks
            src_mask = create_src_mask(src_input).to(device)
            tgt_mask = create_tgt_mask(tgt_input[:, :-1]).to(device)  # Apply mask only on the decoder input sequence

            # Forward pass
            encoder_output = transformer.encode(src_input, src_mask)
            decoder_output = transformer.decode(encoder_output, src_mask, tgt_input[:, :-1], tgt_mask)
            output = transformer.project(decoder_output)

            # Calculate loss
            loss = criterion(output.view(-1, tgt_vocab_size), tgt_input[:, 1:].reshape(-1))
            val_loss += loss.item()

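    # Note: each loss.item() is already a per-batch mean, so dividing the summed losses by
    # the number of examples reports values roughly 1/32 of the mean per-batch loss.
    print(f'Epoch {epoch+1}, Train loss: {train_loss/2000}, Val loss: {val_loss/(len(dataset) - 2000)}')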


Epoch 1, Train loss: 0.31039235210418703, Val loss: 0.3036927996223731
Epoch 2, Train loss: 0.2900807647705078, Val loss: 0.28730585567972594
Epoch 3, Train loss: 0.27304628705978395, Val loss: 0.27424413637292466
Epoch 4, Train loss: 0.25954092693328856, Val loss: 0.2641999006033657
Epoch 5, Train loss: 0.24909723782539367, Val loss: 0.256743199090777
Epoch 6, Train loss: 0.2414636206626892, Val loss: 0.25176503532310784
Epoch 7, Train loss: 0.2363039960861206, Val loss: 0.2488588759098072
Epoch 8, Train loss: 0.2330855667591095, Val loss: 0.24725568472804244
Epoch 9, Train loss: 0.23094035577774047, Val loss: 0.2462076589331432
Epoch 10, Train loss: 0.22944898128509522, Val loss: 0.245560116449358
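
The training numbers printed above are a sum of per-batch losses divided by the number of examples (2000), so they look smaller than a typical cross-entropy value. As a rough sketch, assuming the default mean reduction of `CrossEntropyLoss` and the 63 batches produced by `range(0, 2000, 32)`, they can be rescaled to a mean per-batch loss; the helper name `per_batch_mean` is hypothetical.

```python
import math

num_train_examples = 2000
batch_size = 32
num_train_batches = math.ceil(num_train_examples / batch_size)  # 63 batches

reported_epoch_1 = 0.31039235210418703   # "Train loss" printed for epoch 1
reported_epoch_10 = 0.22944898128509522  # "Train loss" printed for epoch 10

def per_batch_mean(reported):
    # Undo the division by 2000 to recover the summed loss, then average over batches.
    return reported * num_train_examples / num_train_batches

print(round(per_batch_mean(reported_epoch_1), 3))   # about 9.854
print(round(per_batch_mean(reported_epoch_10), 3))  # about 7.284
```

Read this way, the mean per-batch training loss falls from roughly 9.85 to roughly 7.28 over the ten epochs, which matches the steady decrease the test is looking for.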
