# Encoder-decoder network implementation
---
The third notebook in the series. It assumes you already know how to build both the encoder and the decoder networks.

The general idea is to create and combine two networks: an encoder and a decoder.
After passing the input through N encoder blocks, we take the output of the final block and inject it into every decoder block, where it is projected to the keys `K` and values `V`. In the decoder block's cross-attention, the query `Q` comes from the decoder, whereas `K` and `V` come from the encoder. We may interpret this as: how much attention should the word `Q` (from the translation) pay to each word (`K`) in the source sentence.
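A minimal single-head sketch of this cross-attention, assuming random tensors in place of real encoder and decoder states and omitting the learned projections for brevity. It only illustrates which side supplies `Q` and which supplies `K` and `V`, and what shapes come out:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)

d_k = 8
T_input, T_output = 5, 3  # e.g. 5 source words, 3 target words so far

enc_output = torch.randn(T_input, d_k)   # one vector per source word (stand-in for encoder output)
dec_hidden = torch.randn(T_output, d_k)  # one vector per target word (stand-in for decoder states)

# Queries come from the decoder; keys and values come from the encoder.
# Learned projections (W_q, W_k, W_v) are omitted in this sketch.
Q = dec_hidden
K = enc_output
V = enc_output

scores = Q @ K.T / math.sqrt(d_k)    # (T_output, T_input)
weights = F.softmax(scores, dim=-1)  # row j: how much target word j attends to each source word
context = weights @ V                # (T_output, d_k): one context vector per target word

print(weights.shape)  # torch.Size([3, 5])
print(context.shape)  # torch.Size([3, 8])
```

The number of rows follows the decoder side and the number of columns follows the encoder side, so the two sequences are free to have different lengths.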


 <img src="./images/encoder-decoder.png" alt="Encoder-decoder architecture" width="545" />
 
*Image from [Attention is All you need](https://arxiv.org/abs/1706.03762) paper*

In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np


### Multi-Head Attention block

The `MultiHeadAttention` block has to be more general this time. Specifically, `K` and `V` may come from a different source than `Q` and, as a result, have a different shape (a different number of words). This is natural: an English sentence and its Polish translation do not have to have the same number of words.
We denote the length of the input as $T_{input}$ and the length of the output as $T_{output}$. Consequently, for a single head and a single example, `K` and `V` have shape $T_{input} \times d_k$, while `Q` has shape $T_{output} \times d_k$.

In practice, we use $n_{heads}$ heads and a batch of $N$ examples, so the "full" shapes are:
- for `K` and `V`: $N \times n_{heads} \times T_{input} \times d_k$
- for `Q`: $N \times n_{heads} \times T_{output} \times d_k$

The code reshapes these tensors at several points, either to split and merge the heads or to make the batched matrix multiplications possible. The comments in the code describe each step.
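The reshaping can be checked in isolation with random tensors. This is a sketch with arbitrary sizes (`N=2`, `n_heads=4`, `d_k=16`, and different input and output lengths); it mirrors the `view`/`transpose` calls used in `MultiHeadAttention` and confirms that splitting and merging the heads loses nothing:

```python
import math
import torch

N, n_heads, d_k = 2, 4, 16
T_input, T_output = 7, 5

# Tensors as they leave the nn.Linear projections: (N, T, h*d_k)
q_flat = torch.randn(N, T_output, n_heads * d_k)
k_flat = torch.randn(N, T_input, n_heads * d_k)

# Split the last dimension into heads, then move the head axis forward:
# (N, T, h*d_k) -> (N, T, h, d_k) -> (N, h, T, d_k)
q = q_flat.view(N, T_output, n_heads, d_k).transpose(1, 2)
k = k_flat.view(N, T_input, n_heads, d_k).transpose(1, 2)

# Batched matmul over the leading (N, h) axes:
# (N, h, T_output, d_k) x (N, h, d_k, T_input) -> (N, h, T_output, T_input)
scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
print(q.shape)       # torch.Size([2, 4, 5, 16])
print(scores.shape)  # torch.Size([2, 4, 5, 7])

# Merge the heads back: (N, h, T, d_k) -> (N, T, h, d_k) -> (N, T, h*d_k)
# contiguous() is needed because view() requires contiguous memory after transpose()
merged = q.transpose(1, 2).contiguous().view(N, T_output, n_heads * d_k)
print(merged.shape)  # torch.Size([2, 5, 64])
```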

Additionally, the encoder-decoder structure uses two types of multi-head attention blocks: with and without a causal mask (the decoder's self-attention is causal; the encoder's self-attention and the cross-attention are not). The `causal` parameter takes care of this.
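To see what the causal mask does, here is a small sketch on a 4-token sequence with all-zero scores (an arbitrary choice that makes the result easy to read). The mask is built for a larger `max_len` and cropped, as in `MultiHeadAttention`; after masking, each row spreads its weight only over the current and earlier positions:

```python
import torch
import torch.nn.functional as F

max_len, T = 6, 4

# Lower-triangular matrix: row i has ones in columns 0..i
causal_mask = torch.tril(torch.ones(max_len, max_len))

scores = torch.zeros(T, T)  # uniform scores, so the mask alone shapes the weights
# Crop the mask to the actual sequence length and block future positions
scores = scores.masked_fill(causal_mask[:T, :T] == 0, float('-inf'))
weights = F.softmax(scores, dim=-1)

# Row 0 -> [1, 0, 0, 0], row 1 -> [0.5, 0.5, 0, 0], ..., row 3 -> [0.25, 0.25, 0.25, 0.25]
print(weights)
```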




In [2]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_k, d_model, n_heads, max_len, causal=False):
        super().__init__()
        
        self.d_model = d_model
        self.d_k = d_k
        self.n_heads = n_heads
        
        self.query = nn.Linear(d_model, d_k * n_heads)
        self.key = nn.Linear(d_model, d_k * n_heads)
        self.value = nn.Linear(d_model, d_k * n_heads)
        
        self.out = nn.Linear(d_k * n_heads, d_model)
        
        # Causal mask
        self.causal = causal
        if causal: 
            cm = torch.tril(torch.ones(max_len, max_len))
            self.register_buffer(
                'causal_mask',
                cm.view(1, 1, max_len, max_len)
            )
        
        
    def forward(self, q, k, v, pad_mask=None):
        
        # Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
        
        q = self.query(q) # N x T_output x (h*d_k) 
        k = self.key(k)   # N x T_input x (h*d_k)
        v = self.value(v) # N x T_input x (h*d_v) # d_v == d_k
        
        N = q.shape[0] # batch size
        T_output = q.shape[1] # Sequence length for q
        T_input = k.shape[1] # Sequence length for k and v
        
        # Changing shapes (required for matrix multiplication)
        # view: (N, T, h*d_k) -> (N, T, h, d_k)
        # transpose: (N, T, h, d_k) -> (N, h, T, d_k)
        q = q.view(N, T_output, self.n_heads, self.d_k).transpose(1, 2)
        k = k.view(N, T_input, self.n_heads, self.d_k).transpose(1, 2)
        v = v.view(N, T_input, self.n_heads, self.d_k).transpose(1, 2)
        
        # (N, h, T_output, d_k) x (N, h, d_k, T_input) -> (N, h, T_output, T_input)
        attention_scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)
        
        if pad_mask is not None:
            # The mask has shape (N, T_input), so we add two inner dimensions
            # (heads and query positions) to make it broadcast against the scores.
            # Positions where the mask is 0 are filled with -inf, so softmax gives them zero weight.
            attention_scores = attention_scores.masked_fill(
                pad_mask[:, None, None, :] == 0, float('-inf')
            )
            
        # We may also need a causal mask, so that a position cannot look into the future.
        # max_len is the length of the longest possible sequence, but we only need
        # the sequences in the current batch, so we crop the mask.
        # After dropping the first two dimensions (batch size and number of heads),
        # the mask's rows correspond to the Q sequence length and its columns to the
        # K sequence length, hence the crop to [:T_output, :T_input].
        # Note: in the decoder's cross-attention, Q comes from the decoder, whereas K and V
        # come from the encoder. We may interpret this as: how much attention should the
        # word Q (from the translation) pay to each word K in the source sentence.
        if self.causal:
            attention_scores = attention_scores.masked_fill(
                self.causal_mask[:, :, :T_output, :T_input] == 0, float('-inf')
            )
        
        attention_weights = F.softmax(attention_scores, dim=-1)
        
        # (N, h, T_output, T_input) x (N, h, T_input, d_k) -> (N, h, T_output, d_k)
        A = attention_weights @ v
        
        # Reshape (N, h, T_output, d_k) -> (N, T_output, h, d_k) -> (N, T_output, h*d_k)
        A = A.transpose(1, 2)
        A = A.contiguous().view(N, T_output, self.n_heads * self.d_k)
        
        return self.out(A)
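The broadcasting of the padding mask inside `forward` can be illustrated on its own. In this sketch (random scores, arbitrary sizes) the second sequence has two padded positions, and after `masked_fill` no query assigns any weight to them. Note that this masks keys only; if every key in a row were masked, softmax over a row of `-inf` would return NaN.

```python
import torch
import torch.nn.functional as F

N, h, T_output, T_input = 2, 3, 4, 5
scores = torch.randn(N, h, T_output, T_input)

# 1 = real token, 0 = padding; the second sequence ends with two padded positions
pad_mask = torch.tensor([[1, 1, 1, 1, 1],
                         [1, 1, 1, 0, 0]])

# (N, T_input) -> (N, 1, 1, T_input): broadcasts over heads and query positions
expanded = pad_mask[:, None, None, :]
scores = scores.masked_fill(expanded == 0, float('-inf'))
weights = F.softmax(scores, dim=-1)

print(expanded.shape)        # torch.Size([2, 1, 1, 5])
print(weights[1, :, :, 3:])  # all zeros: no query attends to the padded keys
```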
        
        
                

### Encoder block

In [3]:
class EncoderBlock(nn.Module):
    def __init__(self, d_k, d_model, n_heads, max_len, dropout=0.1):
        super().__init__()
                
        self.attention = MultiHeadAttention(d_k, d_model, n_heads, max_len, causal=False)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
        self.ff = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout)
        )
        
    def forward(self, x, pad_mask=None):
        x = self.norm1(x + self.attention(x, x, x, pad_mask))
        x = self.norm2(x + self.ff(x))
        x = self.dropout(x)
        return x

### Decoder block

In [4]:
class DecoderBlock(nn.Module):
    def __init__(self, d_k, d_model, n_heads, max_len, dropout=0.1):
        super().__init__()
                
        self.attention_1 = MultiHeadAttention(d_k, d_model, n_heads, max_len, causal=True)
        self.norm1 = nn.LayerNorm(d_model)

        self.attention_2 = MultiHeadAttention(d_k, d_model, n_heads, max_len, causal=False)
        self.norm2 = nn.LayerNorm(d_model)

        self.ff = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout)
        )

        self.norm3 = nn.LayerNorm(d_model)


        self.dropout = nn.Dropout(dropout)
        
        
    def forward(self, enc_output, dec_input, enc_mask=None, dec_mask=None):
        x = self.norm1(dec_input + self.attention_1(dec_input, dec_input, dec_input, dec_mask))
        
        x = self.norm2(x + self.attention_2(x, enc_output, enc_output, enc_mask))
        
        x = self.norm3(x + self.ff(x))
        return self.dropout(x)


In [5]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=2048, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        
        # [ [0], [1], [2], ..., [max_len-1] ]
        # 2d array of size max_len x 1
        position = torch.arange(max_len).unsqueeze(1)
        
        #[0, 2, 4, ...]
        exp_term = torch.arange(0, d_model, 2) 
        
        
        div_term = torch.exp(exp_term * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        # x.shape: N x T x D
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
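For reference, the `div_term` computed above is the sinusoidal frequency from the paper, written in log space using $\exp(a \ln b) = b^{a}$. With `exp_term` holding the even indices $2i$:

$$
\text{div\_term}_i = \exp\left(2i \cdot \frac{-\ln 10000}{d_{model}}\right) = 10000^{-2i/d_{model}}
$$

so the two assignments to `pe` implement

$$
PE_{(pos,\,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad
PE_{(pos,\,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)
$$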
        
        

### Encoder

In [6]:
class Encoder(nn.Module):
    def __init__(
        self, 
        vocab_size : int,
        max_len : int,
        d_k : int,
        d_model : int,
        n_heads : int,
        n_layers : int,
        dropout : float = 0.1,
    ):
    
        super().__init__()
        
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout=dropout)
        transformer_blocks = [
            EncoderBlock(d_k, d_model, n_heads, max_len, dropout=dropout)
            for _ in range(n_layers)
        ]
        
        self.transformer_blocks = nn.ModuleList(transformer_blocks)
        self.norm = nn.LayerNorm(d_model)
        
    def forward(self, x, pad_mask = None):
        x = self.embedding(x)
        x = self.pos_encoding(x)
        for block in self.transformer_blocks:
            x = block(x, pad_mask)
        
        x = self.norm(x)
        return x
        
        

### Decoder

In [7]:
class Decoder(nn.Module):
    def __init__(
        self,
        vocab_size,
        max_len,
        d_k,
        d_model,
        n_heads,
        n_layers,
        dropout,
    ):
        
        super().__init__()
        
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout=dropout)
        transformer_blocks = [
            DecoderBlock(d_k, d_model, n_heads, max_len, dropout=dropout)
            for _ in range(n_layers)
        ]

        self.transformer_blocks = nn.ModuleList(transformer_blocks)
        self.norm = nn.LayerNorm(d_model)
        self.out = nn.Linear(d_model, vocab_size)
        
    def forward(self, enc_output, dec_input, enc_mask = None, dec_mask = None):
        x = self.embedding(dec_input)
        x = self.pos_encoding(x)
        
        for block in self.transformer_blocks:
            x = block(enc_output, x, enc_mask, dec_mask)
                
        x = self.norm(x)
        return self.out(x)

In [8]:
class Transformer(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder


    def forward(self, enc_input, dec_input, enc_mask, dec_mask):
        enc_output = self.encode(enc_input, enc_mask)
        dec_output = self.decode(enc_output, dec_input, enc_mask, dec_mask)    
        return dec_output
    
    def encode(self, enc_input, enc_mask):
        return self.encoder(enc_input, enc_mask)
    
    def decode(self, enc_output, dec_input, enc_mask, dec_mask):
        return self.decoder(enc_output, dec_input, enc_mask, dec_mask)

### Quick test

In [14]:
encoder = Encoder(
    vocab_size=20_000,
    max_len = 1024,
    d_k = 16,
    d_model = 64,
    n_heads = 4,
    n_layers = 2,
    dropout = 0.1,
)

decoder = Decoder(
    vocab_size=10_000,
    max_len = 1024,
    d_k = 16,
    d_model = 64,
    n_heads = 4,
    n_layers = 2,
    dropout = 0.1,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transformer = Transformer(encoder, decoder)
transformer = transformer.to(device)


In [15]:
batch_size = 8
nr_words_enc = 512
nr_words_dec = 256

x = np.random.randint(0, 20_000, size=(batch_size, nr_words_enc))
enc_input = torch.tensor(x).to(device)

enc_mask = np.ones((batch_size, nr_words_enc))
enc_mask[:, nr_words_enc // 2:] = 0 # Treat the second half of each sequence as padding
enc_mask = torch.tensor(enc_mask).to(device)

x = np.random.randint(0, 10_000, size=(batch_size, nr_words_dec))
dec_input = torch.tensor(x).to(device)
dec_mask = np.ones((batch_size, nr_words_dec))
dec_mask[:, nr_words_dec // 2:] = 0 # Treat the second half of each sequence as padding
dec_mask = torch.tensor(dec_mask).to(device)



y = transformer(enc_input, dec_input, enc_mask, dec_mask)
print(y.shape)
print(y.argmax(dim=-1).shape)


torch.Size([8, 256, 10000])
torch.Size([8, 256])
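In the quick test above, the masks are built by hand. With real data, a padding mask is usually derived from the token ids. A minimal sketch, assuming a hypothetical padding id `PAD_ID = 0` (use whatever id your tokenizer reserves for padding):

```python
import torch

PAD_ID = 0  # hypothetical padding token id

# Two sequences, right-padded to the same length
token_ids = torch.tensor([[5, 12, 7, 9, PAD_ID, PAD_ID],
                          [3,  8, 4, 2,     11,      6]])

# 1 for real tokens, 0 for padding: the (N, T) format that MultiHeadAttention expects
pad_mask = (token_ids != PAD_ID).long()
print(pad_mask)
```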
