# Transformers in PyTorch

The idea of this notebook is to explain how transformers are coded in PyTorch. We will take as references the original [Attention is all you need](https://arxiv.org/abs/1706.03762) paper and this [video](https://www.youtube.com/watch?v=ISNdQcPhsts). The transformer we are going to build is rather simple: it translates sentences from English to Spanish.

First we will build the transformer component by component and then we will work on the training loop and inference.

## The Transformer

In order to build the transformer, we will have to build all the inner components first. The base components of the transformer are:
- Input Embeddings
- Positional Encoding
- Layer Normalization
- Feed Forward Block
- Multi-Head Attention Block

Then we have the encoder and the decoder, each composed of a stack of encoder or decoder blocks, respectively. Finally, a projection layer maps the decoder output to scores over the target vocabulary.

![Transformer Architecture](assets/transformer-network.png)


In [1]:
import torch
import torch.nn as nn
import math

### Input Embeddings
This layer maps each token of the input sequence to a vector. These vectors are learned during training and represent the "meaning" of the token (or word).

PyTorch already provides a module with this functionality (`nn.Embedding`), but we wrap it in our own module so we can refer to it by name in the architecture and apply the $\sqrt{d_{model}}$ scaling used in the original paper.

In [2]:
class InputEmbedding (nn.Module) :
    
    def __init__(self, d_model: int, vocab_size: int) -> None :
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        
        self.embedding = nn.Embedding(vocab_size, d_model)
        
    def forward (self, x) :
        return self.embedding(x) * math.sqrt(self.d_model) # as specified in the original paper
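The following quick shape check is written directly against `nn.Embedding`, so it runs without the class above. The token ids are made up, and `d_model = 8` is an arbitrary toy size rather than the paper's 512. The check shows two things. First, a `(batch, seq_len)` tensor of integer ids becomes a `(batch, seq_len, d_model)` tensor of vectors. Second, each output vector is simply the looked-up row of the embedding matrix scaled by $\sqrt{d_{model}}$.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)

d_model, vocab_size = 8, 100  # toy sizes, not the paper's values
embedding = nn.Embedding(vocab_size, d_model)

# Two "sentences" of five token ids each: shape (batch=2, seq_len=5)
token_ids = torch.tensor([[5, 17, 42, 3, 0],
                          [8, 8, 99, 1, 2]])

# Same computation as InputEmbedding.forward
out = embedding(token_ids) * math.sqrt(d_model)
print(out.shape)  # torch.Size([2, 5, 8])

# Each output vector is the matching row of the embedding matrix, scaled
expected = embedding.weight[token_ids] * math.sqrt(d_model)
print(torch.allclose(out, expected))  # True
```

Because the lookup is by id, repeated tokens (the two `8`s in the second sentence) get identical vectors at this stage. Only the positional encoding and attention make them differ.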

### Positional Encoding

The positional encoding adds a vector to each embedding to encode the position of the token in the sentence (e.g. first, second, ...). There are many ways to achieve this. Here we use the sinusoidal vectors proposed in the original paper, calculated with the following functions:

![Positional Encodings Functions](assets/positional-encoding-functions.png)

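where $pos$ is the position of the token in the sentence and $i$ indexes a pair of dimensions: even dimensions $2i$ use the sine, odd dimensions $2i+1$ use the cosine, and both share the same frequency.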
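For reference, here are the functions from the image as given in the paper. Below them is a log-space rewrite of the frequency term, which is the form the `div_term` computation relies on:

$$
PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad
PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{model}}}\right)
$$

$$
\frac{1}{10000^{2i/d_{model}}} = \exp\!\left(-\frac{2i}{d_{model}}\,\ln 10000\right)
$$

The values of $2i$ are exactly the sequence produced by `torch.arange(0, d_model, 2)`.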

In [3]:
class PositionalEncoding(nn.Module) :
    
    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None :
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)
        
        pe = torch.zeros(seq_len, d_model) # (seq_len, d_model)
        
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1) # (seq_len, 1)
        # 1 / 10000^(2i / d_model), computed in log space for numerical stability
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) # (d_model / 2)
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        pe = pe.unsqueeze(0) # (1, seq_len, d_model)
        
        self.register_buffer("pe", pe) # Save it to the state file, but not as a parameter
        
    def forward (self, x) :
        x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False)
        return self.dropout(x)
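This is a pure-Python check, with no tensors, that the log-space frequency gives the same values as evaluating the paper's formulas directly. It assumes the paper's base of 10000. The function names are hypothetical and the sizes are toy values.

```python
import math

def pe_direct(pos: int, k: int, d_model: int) -> float:
    """Encoding value at position `pos`, dimension `k`, straight from the paper's formulas."""
    i = k // 2  # dimensions 2i and 2i+1 share the same frequency
    angle = pos / (10000 ** (2 * i / d_model))
    return math.sin(angle) if k % 2 == 0 else math.cos(angle)

def pe_log_space(pos: int, k: int, d_model: int) -> float:
    """Same value, with the frequency computed as exp(2i * (-ln 10000 / d_model)), like div_term."""
    i = k // 2
    div_term = math.exp(2 * i * (-math.log(10000.0) / d_model))
    angle = pos * div_term
    return math.sin(angle) if k % 2 == 0 else math.cos(angle)

d_model, seq_len = 16, 10  # toy sizes
direct = [[pe_direct(p, k, d_model) for k in range(d_model)] for p in range(seq_len)]
log_space = [[pe_log_space(p, k, d_model) for k in range(d_model)] for p in range(seq_len)]

max_diff = max(abs(a - b)
               for row_a, row_b in zip(direct, log_space)
               for a, b in zip(row_a, row_b))
print(max_diff < 1e-9)  # True: both formulations agree
print(direct[0][:4])    # [0.0, 1.0, 0.0, 1.0]: position 0 is sin(0)=0, cos(0)=1 everywhere
```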

### Layer Normalization

This component normalizes each input (each vector corresponding to a token) so its values have mean 0 and variance 1. Then it scales the values with a learnable parameter $\alpha$ and shifts them with a learnable parameter $\beta$.

The purpose of this block is to stabilize and speed up training. The inputs of the next block stay in a similar range regardless of the scale of the previous layer's outputs.

In [4]:
class LayerNormalization(nn.Module) :
    
    def __init__(self, eps: float = 10**-6) -> None :
        super().__init__()
        self.eps = eps # numerical stability
        
        self.alpha = nn.Parameter(torch.ones(1)) # Multiplied (scale), starts at 1
        self.beta = nn.Parameter(torch.zeros(1)) # Added (shift), starts at 0
        
    def forward (self, x) :
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.alpha * (x - mean) / (std + self.eps) + self.beta
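The notebook's `LayerNormalization` differs from PyTorch's built-in version in three details:

- It divides by `x.std(...)`, which in PyTorch defaults to the unbiased estimator (dividing by $d_{model} - 1$).
- It adds `eps` to the standard deviation.
- It learns a single scalar $\alpha$ and $\beta$.

`nn.LayerNorm` instead uses the biased variance with `eps` inside the square root, and learns one scale and shift per feature. The sketch below (with $\alpha = 1$, $\beta = 0$ and toy sizes) reproduces `F.layer_norm` by hand and shows that the `std`-based version is close but not identical.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
eps = 1e-6
x = torch.randn(2, 3, 16)  # (batch, seq_len, d_model), toy sizes

mean = x.mean(dim=-1, keepdim=True)

# Biased variance (divide by d_model) with eps inside the square root,
# which is the formulation nn.LayerNorm / F.layer_norm use
var = x.var(dim=-1, keepdim=True, unbiased=False)
manual = (x - mean) / torch.sqrt(var + eps)
reference = F.layer_norm(x, normalized_shape=(16,), eps=eps)
print(torch.allclose(manual, reference, atol=1e-5))  # True

# x.std() defaults to the unbiased estimator (divide by d_model - 1), so this version
# is smaller than the reference by a factor of about sqrt(15/16)
notebook_style = (x - mean) / (x.std(dim=-1, keepdim=True) + eps)
print(torch.allclose(notebook_style, reference, atol=1e-5))  # False
```

The difference is small and does not prevent training, but it matters if you compare against `nn.LayerNorm` outputs.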

### Feed Forward Block

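This block is a simple two-layer fully connected network with a ReLU in between. It is applied to each position (token) independently and identically. The inner dimension $d_{ff}$ is larger than $d_{model}$ (2048 vs. 512 in the paper).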

In [5]:
class FeedForwardBlock (nn.Module) :
    
    def __init__ (self, d_model: int, d_ff: int, dropout: float) -> None :
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)
        
    def forward (self, x) :
        # (batch, seq_len, d_model) --> (batch, seq_len, d_ff) --> (batch, seq_len, d_model)
        return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))
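The following sketch rebuilds the same layer sequence with `nn.Sequential` (toy sizes, not the paper's). It illustrates what "position-wise" means: since `nn.Linear` acts on the last dimension only, transforming the whole sequence gives the same result for a token as transforming that token on its own.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff = 8, 32  # toy sizes (the paper uses 512 and 2048)

# Same layer order as FeedForwardBlock: Linear -> ReLU -> Dropout -> Linear
ffn = nn.Sequential(
    nn.Linear(d_model, d_ff),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(d_ff, d_model),
)
ffn.eval()  # disable dropout so the comparison below is deterministic

x = torch.randn(2, 5, d_model)  # (batch, seq_len, d_model)
out = ffn(x)
print(out.shape)  # torch.Size([2, 5, 8])

# Token 3 is transformed independently of its neighbours
print(torch.allclose(out[:, 3], ffn(x[:, 3]), atol=1e-6))  # True
```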

### Multi-Head Attention Block

This block is the game changer in transformers. Its objective is to update each embedding with context from the other tokens in the sentence. We use multiple heads so that each head works on a different slice of the embedding, allowing the model to capture different traits and aspects of each word in parallel.

This implementation covers self-attention, masked self-attention and cross-attention. The three differ only in which inputs are passed as query, key and value, and in the mask used, so a single module handles all of them.
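This standalone sketch shows the three uses of one attention function. The `attention` function below mirrors the notebook's static method without dropout. It also returns the weights so they can be inspected, which is an addition of this sketch. The sizes are toy values. The sketch covers self-attention (query, key and value from the same sequence), cross-attention (queries and keys of different lengths) and masking (positions where `mask == 0` receive no weight).

```python
import math
import torch

def attention(query, key, value, mask=None):
    """Scaled dot-product attention on (batch, h, seq_len, d_k) tensors.
    Positions where mask == 0 are excluded from the softmax."""
    d_k = query.shape[-1]
    scores = query @ key.transpose(-2, -1) / math.sqrt(d_k)  # (batch, h, q_len, k_len)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    weights = scores.softmax(dim=-1)
    return weights @ value, weights

torch.manual_seed(0)
batch, h, d_k = 2, 4, 16

# Self-attention: queries, keys and values all come from the same sequence
x = torch.randn(batch, h, 6, d_k)
out_self, w_self = attention(x, x, x)
print(out_self.shape)  # torch.Size([2, 4, 6, 16])

# Cross-attention: 3 decoder queries attend over 6 encoder positions
dec = torch.randn(batch, h, 3, d_k)
out_cross, w_cross = attention(dec, x, x)
print(out_cross.shape, w_cross.shape)  # torch.Size([2, 4, 3, 16]) torch.Size([2, 4, 3, 6])

# Masked attention: hide the last two key positions (e.g. padding)
mask = torch.tensor([1, 1, 1, 1, 0, 0]).view(1, 1, 1, 6)  # broadcasts over batch, heads, queries
out_masked, w_masked = attention(dec, x, x, mask)
print(w_masked[..., 4:].max().item() < 1e-6)  # True: masked keys get no weight
```

Note that the output always has one row per query, while the weight matrix is `(q_len, k_len)`. This is why the same code serves both self-attention and cross-attention.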

In [6]:
class MultiHeadAttentionBlock(nn.Module) :
    
    def __init__(self, d_model: int, h: int, dropout: float) -> None :
        super().__init__()
        self.d_model = d_model
        self.h = h
        assert d_model % h == 0, "d_model is not divisible by h" 
        
        self.d_k = d_model // h
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        
        self.w_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        
    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout) :
        d_k = query.shape[-1]
        
        # (batch, h, seq_len, d_k) --> (batch, h, seq_len, seq_len)
        attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            attention_scores.masked_fill_(mask == 0, -1e9)
            
        attention_scores = attention_scores.softmax(dim=-1) # (batch, h, seq_len, seq_len)
        
        if dropout is not None:
            attention_scores = dropout(attention_scores)
            
        return (attention_scores @ value)
    
    def forward (self, q, k, v, mask) :
        query = self.w_q(q) 
        key = self.w_k(k)
        value = self.w_v(v)
        
        # (batch, seq_len, d_model) --> (batch, seq_len, h, d_k) --> (batch, h, seq_len, d_k)
        query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
        key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
        value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)
        
        x = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)
        
        # (batch, h, seq_len, d_k) --> (batch, seq_len, h, d_k) --> (batch, seq_len, d_model)
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)
        
        return self.w_o(x)
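This is a standalone check of the reshaping in `forward`, with toy sizes ($d_{model} = 12$, $h = 3$). It shows three things:

- Each head receives a contiguous slice of $d_k$ features from every token.
- The `transpose` / `contiguous()` / `view()` sequence used to merge the heads exactly undoes the split.
- `.contiguous()` is required for that merge.

Here `attn_out` is a stand-in for the attention output.

```python
import torch

torch.manual_seed(0)
batch, seq_len, d_model, h = 2, 5, 12, 3
d_k = d_model // h  # 4 features per head

x = torch.randn(batch, seq_len, d_model)

# Split: (batch, seq_len, d_model) -> (batch, seq_len, h, d_k) -> (batch, h, seq_len, d_k)
heads = x.view(batch, seq_len, h, d_k).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 3, 5, 4])

# Head 1 sees features 4..7 of every token
print(torch.equal(heads[:, 1], x[:, :, 4:8]))  # True

# Stand-in for the attention output: a fresh, contiguous (batch, h, seq_len, d_k) tensor
attn_out = heads.contiguous()

# Merge: back to (batch, seq_len, h, d_k), copy into contiguous memory, then flatten the heads
merged = attn_out.transpose(1, 2).contiguous().view(batch, -1, h * d_k)
print(torch.equal(merged, x))  # True: merging exactly undoes the split

# Without .contiguous(), view() cannot flatten (h, d_k): they are no longer adjacent in memory
try:
    attn_out.transpose(1, 2).view(batch, -1, h * d_k)
    view_failed = False
except RuntimeError:
    view_failed = True
print(view_failed)  # True
```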

### Residual Connection

This block handles the residual connections that appear in the diagram ("Add & Norm"), including the layer normalization. Wrapping it in its own module keeps the forward methods of the encoder and decoder blocks compact.

In [7]:
class ResidualConnection(nn.Module) :
    def __init__(self, dropout: float) -> None :
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization()
        
    def forward(self, x, sublayer) :
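        # Post-norm "Add & Norm" as in the paper: dropout on the sublayer output, add the residual, then normalize
        return self.norm(x + self.dropout(sublayer(x)))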
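There are two common ways to arrange the residual connection:

- **Post-norm (the paper's arrangement).** Dropout is applied to each sub-layer's output before it is added to the sub-layer input and normalized, i.e. $\text{LayerNorm}(x + \text{Dropout}(\text{Sublayer}(x)))$.
- **Pre-norm.** Many later implementations normalize the sub-layer input instead and leave the residual path untouched.

The sketch below shows both variants using `nn.LayerNorm` in place of this notebook's `LayerNormalization`. The class names are hypothetical, and a single `nn.Linear` stands in for the attention or feed-forward sublayer.

```python
import torch
import torch.nn as nn

class PostNormResidual(nn.Module):
    """Post-norm "Add & Norm": LayerNorm(x + Dropout(sublayer(x)))."""
    def __init__(self, d_model: int, dropout: float) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        return self.norm(x + self.dropout(sublayer(x)))

class PreNormResidual(nn.Module):
    """Pre-norm variant: x + Dropout(sublayer(LayerNorm(x)))."""
    def __init__(self, d_model: int, dropout: float) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))

torch.manual_seed(0)
d_model = 8
x = torch.randn(2, 5, d_model)
sublayer = nn.Linear(d_model, d_model)  # stand-in for attention or the feed-forward block

post = PostNormResidual(d_model, 0.1).eval()  # eval() disables dropout
pre = PreNormResidual(d_model, 0.1).eval()

y_post = post(x, sublayer)
y_pre = pre(x, sublayer)
print(y_post.shape, y_pre.shape)  # torch.Size([2, 5, 8]) torch.Size([2, 5, 8])
print(y_post.mean(dim=-1).abs().max().item() < 1e-5)  # True: post-norm outputs are normalized per token
```

With post-norm, every block's output is normalized. With pre-norm, the raw residual stream is passed through, and implementations usually add one final normalization after the last block.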

### Encoder Block

An encoder block processes the encoder inputs and generates new embeddings that will be later processed by either another encoder block or the decoder.

The encoder usually stacks several of these blocks ($N = 6$ in the original paper).

In [None]:
class EncoderBlock (nn.Module) :
    
    def __init__(self, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None :
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ ResidualConnection(dropout) for _ in range(2) ])
        
    def forward(self, x, src_mask) :
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, src_mask))
        x = self.residual_connections[1](x, self.feed_forward_block)
        return x
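The `src_mask` passed to the encoder typically hides padding tokens. This is a sketch of one way to build it, assuming the convention used by `attention` (1 = visible, 0 = hidden) and a hypothetical padding id of 0. The singleton dimensions let one mask per sentence broadcast over all heads and all query positions.

```python
import torch

pad_id = 0  # hypothetical id of the [PAD] token
# Two source sentences padded to length 6: shape (batch, seq_len)
src = torch.tensor([[12, 7, 33, 4, 0, 0],
                    [5, 9, 21, 8, 17, 2]])

# (batch, seq_len) -> (batch, 1, 1, seq_len): 1 on real tokens, 0 on padding
src_mask = (src != pad_id).unsqueeze(1).unsqueeze(2).int()
print(src_mask.shape)              # torch.Size([2, 1, 1, 6])
print(src_mask[0, 0, 0].tolist())  # [1, 1, 1, 1, 0, 0]

# It broadcasts against attention scores of shape (batch, h, seq_len, seq_len)
scores = torch.zeros(2, 8, 6, 6)
masked = scores.masked_fill(src_mask == 0, -1e9)
weights = masked.softmax(dim=-1)[0, 0, 0]
print(weights.tolist())  # weight is spread over the 4 real tokens only
```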

### Encoder

Now the encoder module, which passes its input through a list of encoder blocks in sequence.

In [8]:
class Encoder (nn.Module) :
    def __init__ (self, layers: nn.ModuleList) -> None :
        super().__init__()
        self.layers = layers
        
        
    def forward(self, x, mask) :
        for layer in self.layers :
            x = layer(x, mask)
            
        return x
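The encoder's `nn.ModuleList` must hold separate block objects so that every layer learns its own weights. `build_transformer` below does this by creating new blocks inside its loop. This sketch uses `nn.Linear` as a stand-in for an encoder block. It contrasts that with repeating one object, which silently shares the weights.

```python
import torch.nn as nn

def count_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

N, d_model = 6, 16

# One fresh block per layer: N independent sets of weights
independent = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(N)])

# Repeating the same object N times: every "layer" shares one set of weights
block = nn.Linear(d_model, d_model)
shared = nn.ModuleList([block] * N)

print(count_params(independent))  # 1632 = 6 * (16*16 + 16)
print(count_params(shared))       # 272: parameters() yields each shared tensor only once
print(independent[0] is independent[1], shared[0] is shared[1])  # False True
```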

### Decoder Block

The decoder block generates new embeddings for the decoder input tokens. It relies on both masked self-attention and cross-attention with the encoder output (i.e. attention over the tokens of the encoder input).

As with the encoder, the decoder usually stacks several of these blocks.

In [9]:
class DecoderBlock (nn.Module) :
    
    def __init__(self, self_attention_block: MultiHeadAttentionBlock, cross_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None :
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ ResidualConnection(dropout) for _ in range(3) ])
        
    def forward (self, x, encoder_output, src_mask, tgt_mask) :
        x = self.residual_connections[0](x, lambda x : self.self_attention_block(x, x, x, tgt_mask))
        x = self.residual_connections[1](x, lambda x: self.cross_attention_block(x, encoder_output, encoder_output, src_mask))
        x = self.residual_connections[2](x, self.feed_forward_block)
        
        return x
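The `tgt_mask` used by the decoder's masked self-attention combines two conditions:

- **Causal:** position $i$ may only attend to positions $0, \dots, i$, so it cannot peek at later tokens.
- **Padding:** padded positions are hidden.

This is a sketch of one way to build it, assuming the same 1 = visible, 0 = hidden convention as `attention` and a hypothetical padding id of 0.

```python
import torch

pad_id = 0  # hypothetical [PAD] id
tgt = torch.tensor([[1, 14, 27, 9, 0],
                    [1, 6, 11, 3, 8]])  # (batch, seq_len) decoder input
seq_len = tgt.shape[1]

# Causal mask: lower triangle of ones, so row i only sees columns 0..i
causal = torch.tril(torch.ones(seq_len, seq_len)).bool()  # (seq_len, seq_len)

# Padding mask over the keys, shaped to broadcast against (batch, seq_len, seq_len)
padding = (tgt != pad_id).unsqueeze(1)  # (batch, 1, seq_len)

# A key is visible only if it is not in the future AND not padding
tgt_mask = (padding & causal).unsqueeze(1).int()  # (batch, 1, seq_len, seq_len)
print(tgt_mask.shape)  # torch.Size([2, 1, 5, 5])
print(tgt_mask[0, 0].tolist())
# [[1, 0, 0, 0, 0], [1, 1, 0, 0, 0], [1, 1, 1, 0, 0], [1, 1, 1, 1, 0], [1, 1, 1, 1, 0]]
```

The cross-attention block, by contrast, only needs the source padding mask, because every decoder position may look at the whole source sentence.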

### Decoder

Now the decoder module that contains all the decoder blocks.

In [10]:
class Decoder (nn.Module) :
    
    def __init__ (self, layers: nn.ModuleList) -> None :
        super().__init__()
        self.layers = layers
        
    def forward (self, x, encoder_output, src_mask, tgt_mask) :
        for layer in self.layers :
            x = layer(x, encoder_output, src_mask, tgt_mask)
            
        return x

### Projection Layer

This layer projects the decoder output embeddings onto the target vocabulary. At each position it produces log-probabilities of picking each token. It consists of a linear layer followed by a log-softmax.

In [11]:
class ProjectionLayer (nn.Module) :
    
    def __init__(self, d_model: int, vocab_size: int) -> None :
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)
        
    def forward(self, x) :
        # (batch, seq_len, d_model) --> (batch, seq_len, vocab_size)
        return torch.log_softmax(self.proj(x), dim=-1)
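Because `ProjectionLayer` returns log-probabilities rather than raw scores, the matching loss is the negative log-likelihood. The toy-sized sketch below checks two things. First, the exponentiated outputs form a distribution at every position. Second, `F.nll_loss` on the log-softmax gives the same value as `F.cross_entropy` on the raw logits. The random "targets" are made up for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, d_model, vocab_size = 2, 4, 8, 10  # toy sizes
proj = nn.Linear(d_model, vocab_size)

x = torch.randn(batch, seq_len, d_model)                   # stand-in decoder output
targets = torch.randint(0, vocab_size, (batch, seq_len))   # made-up target token ids

logits = proj(x)                                 # (batch, seq_len, vocab_size)
log_probs = torch.log_softmax(logits, dim=-1)    # what ProjectionLayer returns

# exp() of the log-probabilities sums to 1 over the vocabulary at every position
print(torch.allclose(log_probs.exp().sum(dim=-1), torch.ones(batch, seq_len)))  # True

# Both losses expect (N, vocab_size) inputs, so flatten the batch and sequence dims
nll = F.nll_loss(log_probs.view(-1, vocab_size), targets.view(-1))
ce = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
print(torch.allclose(nll, ce))  # True: NLL on log-softmax equals cross-entropy on logits
```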

### Transformer
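Finally, everything comes together in the transformer module. We won't write a `forward` method because we want to run each part (encoder, decoder, projection layer) separately. For instance, at inference time the encoder output can be computed once and reused at every decoding step.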

In [None]:
class Transformer (nn.Module) :
    
    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbedding, tgt_embed: InputEmbedding, src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer) -> None :
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer
        
    def encode (self, src, src_mask) :
        src = self.src_embed(src)
        src = self.src_pos(src)
        return self.encoder(src, src_mask)
    
    def decode (self, encoder_output, src_mask, tgt, tgt_mask) :
        tgt = self.tgt_embed(tgt)
        tgt = self.tgt_pos(tgt)
        return self.decoder(tgt, encoder_output, src_mask, tgt_mask)

In [12]:
def build_transformer (src_vocab_size: int, tgt_vocab_size: int,
                       src_seq_len: int, tgt_seq_len: int,
                       d_model: int = 512, N: int = 6, h: int = 8, d_ff: int = 2048,
                       dropout: float = 0.1) -> Transformer :
    
    # Create the embedding layers
    src_embed = InputEmbedding(d_model, src_vocab_size)
    tgt_embed = InputEmbedding(d_model, tgt_vocab_size)
    
    # Create the positional encoding layers (one of length max(src_seq_len, tgt_seq_len) would suffice, as they hold no parameters)
    src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
    tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)
    
    # Create the encoder blocks
    encoder_blocks = []
    for _ in range(N) :
        encoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        encoder_block = EncoderBlock(encoder_self_attention_block, feed_forward_block, dropout)
        encoder_blocks.append(encoder_block)
        
    # Create the decoder blocks
    decoder_blocks = []
    for _ in range(N) :
        decoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        decoder_cross_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        decoder_block = DecoderBlock(decoder_self_attention_block, decoder_cross_attention_block, feed_forward_block, dropout)
        decoder_blocks.append(decoder_block)
        
    # Create the encoder and the decoder
    encoder = Encoder(nn.ModuleList(encoder_blocks))
    decoder = Decoder(nn.ModuleList(decoder_blocks))
    
    # Create the projection layer
    projection_layer = ProjectionLayer(d_model, tgt_vocab_size)
    
    # Create the transformer
    transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)
    
    # Xavier (Glorot) uniform initialization for every weight matrix; biases and other 1-D parameters keep their defaults
    for p in transformer.parameters() :
        if p.dim() > 1 :
            nn.init.xavier_uniform_(p)
            
    return transformer
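As a worked example of the model's size, the arithmetic below counts the learnable parameters of the model that `build_transformer` creates. It assumes the modules exactly as defined in this notebook:

- bias-free query, key and value projections and a biased output projection;
- two scalar parameters per `LayerNormalization`;
- positional encodings stored as buffers.

The vocabulary sizes of 30,000 are hypothetical, and the function names are illustrative.

```python
def mha_params(d_model: int) -> int:
    # w_q, w_k, w_v: d_model x d_model without bias; w_o: d_model x d_model with bias
    return 3 * d_model * d_model + (d_model * d_model + d_model)

def ffn_params(d_model: int, d_ff: int) -> int:
    # linear_1: d_model -> d_ff and linear_2: d_ff -> d_model, both with bias
    return (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)

LAYER_NORM_PARAMS = 2  # this notebook's LayerNormalization learns one scalar alpha and one scalar beta

def transformer_params(src_vocab: int, tgt_vocab: int,
                       d_model: int = 512, N: int = 6, d_ff: int = 2048) -> int:
    # Encoder block: self-attention + feed-forward + 2 residual connections (one norm each)
    encoder_block = mha_params(d_model) + ffn_params(d_model, d_ff) + 2 * LAYER_NORM_PARAMS
    # Decoder block: self-attention + cross-attention + feed-forward + 3 residual connections
    decoder_block = 2 * mha_params(d_model) + ffn_params(d_model, d_ff) + 3 * LAYER_NORM_PARAMS
    embeddings = (src_vocab + tgt_vocab) * d_model
    projection = d_model * tgt_vocab + tgt_vocab
    # Positional encodings are a registered buffer, not a parameter, so they add nothing
    return N * (encoder_block + decoder_block) + embeddings + projection

print(mha_params(512))                   # 1049088
print(ffn_params(512, 2048))             # 2099712
print(transformer_params(30000, 30000))  # 90190188
```

For a model built with the same vocabulary sizes, the total should agree with `sum(p.numel() for p in model.parameters())`. The sequence lengths do not change it, since they only affect the positional-encoding buffers.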