# Introduction

This is an implementation of [Attention is all you need](https://arxiv.org/abs/1706.03762), with references from [Pytorch Seq2Seq - Transformer](https://github.com/SethHWeidman/pytorch-seq2seq/blob/master/6%20-%20Attention%20is%20All%20You%20Need.ipynb), [Harvard's Annotated Transformer](http://nlp.seas.harvard.edu/annotated-transformer/) and the excellent blog post [Illustrated Transformer](https://jalammar.github.io/illustrated-transformer/).

It is implemented as an exercise to gain a deeper understanding of Transformer models by exploring their internal layers and applying them to a translation task. This notebook is a follow-up to the first three notebooks, the last of which implemented a CNN encoder-decoder model with attention.

Dataset used - [English - French Translations](https://www.kaggle.com/datasets/dhruvildave/en-fr-translation-dataset)

### Model architecture

The Transformer is an Encoder-Decoder Model ->

Basic Model flow = Input -> **Encoder** -> **Decoder** -> Output

    Encoder =
        Stack of 6 Encoders =
            Each Encoder =
                EncoderLayer(
                    Self-Attention -> Feed-Forward 
                    (each word embedding goes through the feed-forward part independently and in parallel; only self-attention depends on other words)
                )

The first encoder layer receives as input the word embeddings summed with the positional encodings. Every other encoder layer receives the output of the previous layer as input. We'll look at the embedding step in more detail below.

    Decoder =
        Stack of 6 Decoders =
            Each Decoder =
                DecoderLayer(
                    Self-Attention -> Encoder-Decoder-Attention -> Feed-Forward
               )  
               
Note - Every internal representation has dimension 512 (d_model in the paper).

## Encoder

#### Embedding layer in first encoder layer

* Embedding layer -> WordEmbedding
* Positional Encoding (formula given in the paper) -> PositionalEncoding (one per word position)

Input to Encoder = WordEmbedding + PositionalEncoding (each of size 512, summed element-wise) -> output (size 512 per word)
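The paper defines the positional encoding as `PE(pos, 2i) = sin(pos / 10000^(2i/512))` and `PE(pos, 2i+1) = cos(pos / 10000^(2i/512))`. Below is a minimal NumPy sketch of that formula and of the element-wise sum, assuming d_model = 512; the function name is hypothetical and random vectors stand in for learned word embeddings. Note that the `Encoder` and `Decoder` classes later in this notebook use a learned `nn.Embedding(1000, hid_dim)` for positions rather than this fixed formula.

```python
import numpy as np

def sinusoidal_positional_encoding(max_len: int, d_model: int = 512) -> np.ndarray:
    """Return a (max_len, d_model) matrix of sinusoidal positional encodings.

    Even columns use sin, odd columns use cos, following the formula in the paper.
    Assumes d_model is even.
    """
    positions = np.arange(max_len)[:, np.newaxis]                 # (max_len, 1)
    even_dims = np.arange(0, d_model, 2)[np.newaxis, :]            # (1, d_model / 2), values 2i
    angles = positions / np.power(10000.0, even_dims / d_model)    # (max_len, d_model / 2)

    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angles)   # PE(pos, 2i)
    pe[:, 1::2] = np.cos(angles)   # PE(pos, 2i+1)
    return pe

# Input to the encoder: word embeddings plus positional encodings, summed element-wise.
sent_len = 4
word_embeddings = np.random.default_rng(0).normal(size=(sent_len, 512))  # stand-in embeddings
encoder_input = word_embeddings + sinusoidal_positional_encoding(sent_len)
print(encoder_input.shape)  # (4, 512)
```

Because the encoding is a fixed function of the position, it needs no training and can be computed for any sentence length.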


### Encoder Layer

Note - Each encoder layer has the same architecture

Consider a sentence as a matrix of **sent_len X 512**, where sent_len is the number of words in the sentence and 512 is the size of each word's vector representation (e.g. in the first layer, each row is a word embedding plus its positional encoding).

We denote this input matrix by **inp_mat**.

Each encoder layer receives a matrix of this shape as input.

#### Self-Attention
* For calculating self-attention, 3 weight matrices are maintained: Query (Wq), Key (Wk) and Value (Wv), each of size **512 X 64** in the paper.
* **inp_mat** is multiplied by each weight matrix to create the respective Query (Qv), Key (Kv) and Value (Vv) matrices, each of size **sent_len X 64** (one 64-dim query, key and value vector per word).
* Calculate a self-attention score for each word against every word in the sentence (including itself).
* Self-attention score calculation:

      For each word i -> multiply its query (row i of Qv) with every word's key (the columns of KvT, the transpose of Kv)
      -> for all words at once: Sv = Qv X KvT, a score matrix of size sent_len X sent_len
      -> intuition: row i of Sv holds one score for word i w.r.t. every word in the sentence
      -> e.g. sentence -> Good morning
      -> Sv could be =            Good   morning
                       Good      [1.5,   2.7]
                       morning   [3.5,   0.3]


* Divide Sv by 8 (the square root of the key dimension, 64). The paper does this for more stable gradients: large dot products would otherwise push the softmax into regions with extremely small gradients.
* Apply softmax to each row of Sv so that each word's scores are positive and add up to 1.
* Multiply matrix **Sv (size - sent_len X sent_len)** by matrix **Vv (size - sent_len X 64)**. This produces the output of the self-attention layer, of size **sent_len X 64**.
* Intuition for the last step: each word's output is a weighted sum of all words' value vectors, weighted by that word's attention scores. Words with high scores dominate the sum, while words with tiny scores (e.g. 0.0001) contribute almost nothing.
* Let's call this output matrix Z (size - **sent_len X 64**)
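The steps above can be sketched for a single head in NumPy. This is a minimal illustration, not the notebook's implementation: random matrices stand in for the learned Wq, Wk and Wv, the helper names are hypothetical, and the two-row input plays the role of a two-word sentence such as "Good morning". Dividing by `np.sqrt(64)` is the division by 8.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along `axis`."""
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

def single_head_attention(inp_mat, w_q, w_k, w_v):
    """One self-attention head.

    inp_mat: (sent_len, 512); w_q, w_k, w_v: (512, 64).
    Returns Z (sent_len, 64) and the attention weights Sv (sent_len, sent_len).
    """
    q = inp_mat @ w_q                          # Qv: (sent_len, 64)
    k = inp_mat @ w_k                          # Kv: (sent_len, 64)
    v = inp_mat @ w_v                          # Vv: (sent_len, 64)
    scores = q @ k.T / np.sqrt(k.shape[-1])    # Sv / 8 when the key size is 64
    weights = softmax(scores, axis=-1)         # each row sums to 1
    return weights @ v, weights                # Z = Sv X Vv

rng = np.random.default_rng(0)
sent_len, d_model, d_k = 2, 512, 64            # e.g. the two words "Good morning"
inp_mat = rng.normal(size=(sent_len, d_model))
w_q, w_k, w_v = (rng.normal(scale=d_model ** -0.5, size=(d_model, d_k)) for _ in range(3))
z, sv = single_head_attention(inp_mat, w_q, w_k, w_v)
print(z.shape, sv.shape)  # (2, 64) (2, 2)
```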

##### Multi-headed attention
* Following from the previous step -> instead of a single set of weight matrices (Wq, Wk, Wv), consider multiple sets of these matrices (the paper uses 8 sets of query, key and value matrices).
* Each of these sets is used separately to run the self-attention flow listed above and produces its own Z matrix as output -> (in the paper Z1, Z2....Z8 -> 8 matrices).
* This is called multi-headed attention. It lets the model attend to several patterns in the sentence at once; for example, different heads may focus on different word relationships or on spans of different lengths.

##### Final Processing of Self-Attention
* Concatenate all Z matrices -> along the column -> matrix of size - sent_len X (64X8) = **sent_len X 512**
* Multiply by another weight matrix Wo (also learned along with the model) -> output O => size - **sent_len X 512**
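Multi-headed attention, including the final concatenation and the Wo projection, can be sketched in NumPy as below. It is a minimal sketch assuming the paper's sizes (8 heads of size 64, model size 512); random matrices stand in for learned weights and the function name is hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

def multi_head_attention(inp_mat, heads, w_o):
    """heads: list of (w_q, w_k, w_v) tuples, each matrix (512, 64); w_o: (512, 512)."""
    z_list = []
    for w_q, w_k, w_v in heads:
        q, k, v = inp_mat @ w_q, inp_mat @ w_k, inp_mat @ w_v
        weights = softmax(q @ k.T / np.sqrt(k.shape[-1]))
        z_list.append(weights @ v)                  # Zi: (sent_len, 64)
    z_concat = np.concatenate(z_list, axis=-1)      # (sent_len, 64 * 8) = (sent_len, 512)
    return z_concat @ w_o                           # O: (sent_len, 512)

rng = np.random.default_rng(0)
sent_len, d_model, n_heads = 5, 512, 8
d_k = d_model // n_heads                            # 64
inp_mat = rng.normal(size=(sent_len, d_model))
heads = [tuple(rng.normal(scale=d_model ** -0.5, size=(d_model, d_k)) for _ in range(3))
         for _ in range(n_heads)]
w_o = rng.normal(scale=d_model ** -0.5, size=(d_model, d_model))
out = multi_head_attention(inp_mat, heads, w_o)
print(out.shape)  # (5, 512)
```

Because 8 heads of size 64 concatenate back to 512 columns, the output has the same shape as the input, which is what allows the residual addition that follows.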

#### Feed-Forward
* The output **O from Self-Attention** is then passed to the feed-forward network. The same feed-forward network (same weights) is applied to each word's vector separately, so conceptually there are **sent_len independent applications**, which can run in parallel.
* **Feed-forward input 1 X 512** -> **Feed-forward output 1 X 512** (two linear layers with a ReLU in between; the inner layer has size 2048 in the paper)
* All word outputs together form the output matrix Fi of size **sent_len X 512**, which is the output of encoder layer i. It serves as the input of the next encoder layer and has the same dimension as the input to the encoder layer.
* The size is kept constant across layers in the Transformer.
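A small NumPy check of the "one network per word" view: applying the paper's feed-forward network `FFN(x) = max(0, x W1 + b1) W2 + b2` to each row separately gives the same result as applying it to the whole sent_len X 512 matrix at once, because every row uses the same weights. The random weights and the 2048 inner size are assumptions for illustration. (The `PositionwiseFeedforward` class later in the notebook uses `nn.Conv1d` with kernel size 1, which likewise applies one linear map at every position.)

```python
import numpy as np

def feed_forward(x, w1, b1, w2, b2):
    """FFN(x) = max(0, x W1 + b1) W2 + b2, applied to every row of x with the same weights."""
    return np.maximum(0.0, x @ w1 + b1) @ w2 + b2

rng = np.random.default_rng(0)
sent_len, d_model, d_ff = 6, 512, 2048
w1 = rng.normal(scale=d_model ** -0.5, size=(d_model, d_ff))
b1 = np.zeros(d_ff)
w2 = rng.normal(scale=d_ff ** -0.5, size=(d_ff, d_model))
b2 = np.zeros(d_model)
o = rng.normal(size=(sent_len, d_model))   # stand-in for the self-attention output O

# Word by word (the "one network per word" view) ...
per_word = np.stack([feed_forward(o[i], w1, b1, w2, b2) for i in range(sent_len)])
# ... gives the same result as one matrix multiply over all positions at once.
all_at_once = feed_forward(o, w1, b1, w2, b2)
print(np.allclose(per_word, all_at_once), all_at_once.shape)  # True (6, 512)
```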

#### Residuals (LayerNorm - Add & Normalize)
* The output of each sub-layer (Self-Attention and Feed-Forward) of an encoder layer is added element-wise to that sub-layer's input, and the sum is layer-normalized: LayerNorm(x + Sublayer(x)).
* For example, the self-attention output Oi is added to the embeddings (in the first encoder layer) or to the output of the previous encoder layer (in the other encoder layers), and the sum is normalized. The feed-forward output Fi is then added to the normalized self-attention result that fed into it, and that sum is normalized again.
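A minimal NumPy sketch of the Add & Normalize step, `LayerNorm(x + Sublayer(x))`. The learned scale and shift (`gamma`, `beta`) are set to ones and zeros here for illustration, `eps = 1e-5` is an assumed small constant (it matches the default of PyTorch's `nn.LayerNorm`), and random matrices stand in for the sub-layer input and output.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    """Normalize each row (each word vector) to zero mean and unit variance, then scale and shift."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

def add_and_norm(sublayer_input, sublayer_output, gamma, beta):
    """Residual connection followed by layer normalization: LayerNorm(x + Sublayer(x))."""
    return layer_norm(sublayer_input + sublayer_output, gamma, beta)

rng = np.random.default_rng(0)
d_model = 512
gamma, beta = np.ones(d_model), np.zeros(d_model)   # learned in practice; identity values here
x = rng.normal(size=(3, d_model))                    # input to the sub-layer
sublayer_out = rng.normal(size=(3, d_model))         # e.g. the self-attention output O
y = add_and_norm(x, sublayer_out, gamma, beta)
print(y.mean(axis=-1).round(6), y.std(axis=-1).round(3))
```

Each row of `y` ends up with mean close to 0 and standard deviation close to 1, independently of the other rows, so the normalization (like the feed-forward network) works word by word.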

Putting it together (ignoring the embedding step of the first layer, for a generic representation), a single encoder layer works as follows:

    Input inp (sent_len X 512) ->
        Self-Attention ->
            W = 8 sets of (Wq, Wk, Wv)
            Z = []
            for (Wq, Wk, Wv) in W:
                Qv = inp X Wq (sent_len X 64)
                Kv = inp X Wk (sent_len X 64)
                Vv = inp X Wv (sent_len X 64)

                Sv = Qv X KvT (sent_len X sent_len)
                Sv = Sv/8  (for stable gradients)
                Sv = Softmax(Sv) (row-wise, so each row sums to 1)
                Zi = Sv X Vv (sent_len X 64)
                Z.append(Zi)

            Zconcat = Z1.concat(Z2).concat(Z3)....concat(Z8)   (sent_len X 512)

            FI = Zconcat X Wo (sent_len X 512)

            FI = FI + inp (residual adding)
            FI = norm(FI) (layer norm)

        Feed-Forward ->
            FO = []
            for i in parallel_process(sent_len):
                FOi = FeedForward(FI[i]) (1 X 512, same weights for every i)
                FO.append(FOi)

            FO = stack(FO) (sent_len X 512)
            FO = FO + FI (residual summing)
            FO = norm(FO) (layer norm)

        Encoder layer output = FO
             



## Decoder

The decoder takes as input the target sentence tokens along with the output of the top encoder layer. At inference time, each run produces the vocabulary id of the next token, and the decoder keeps producing tokens until a special token, generally the EOS token, is generated. During training, the whole target sentence is fed in at once, and masking keeps each position from seeing later tokens.

#### Embedding layer in first decoder layer for target sentence

* Embedding layer -> WordEmbedding
* Positional Encoding (formula given in the paper) -> PositionalEncoding (one per word position)

Input to Decoder = WordEmbedding + PositionalEncoding (each of size 512, summed element-wise) -> output (size 512 per word)

In the decoder, the self-attention layer is only allowed to attend to earlier positions in the output sequence. **This is done by masking future positions (setting their scores to -inf)** before the softmax step in the self-attention calculation.
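The effect of the mask can be sketched in NumPy: scores above the diagonal (future positions) are replaced by -inf, so `exp(-inf) = 0` gives them zero attention weight after the softmax, while every row still sums to 1. Random scores stand in for a real 4-token target sentence, and the helper names are hypothetical. The `SelfAttention` class later in this notebook uses a large negative value (`-1e10`) instead of -inf, which has practically the same effect.

```python
import numpy as np

def causal_mask(trg_len: int) -> np.ndarray:
    """Boolean (trg_len, trg_len) matrix; True where position i may attend to position j (j <= i)."""
    return np.tril(np.ones((trg_len, trg_len), dtype=bool))

def masked_softmax(scores: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Set disallowed scores to -inf, then softmax each row; masked entries get weight 0."""
    scores = np.where(mask, scores, -np.inf)
    shifted = scores - scores.max(axis=-1, keepdims=True)   # the diagonal keeps every row finite
    exp = np.exp(shifted)                                    # exp(-inf) == 0
    return exp / exp.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
scores = rng.normal(size=(4, 4))      # stand-in for Sv / 8 of a 4-token target sentence
weights = masked_softmax(scores, causal_mask(4))
print(np.round(weights, 2))
```

The first word can only attend to itself (weight 1.0), the second to the first two words, and so on.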


### Decoder Layer

Note - Each decoder layer has the same architecture

After passing through the embedding layer, the target sentence is a matrix of **sent_len X 512** (where sent_len is the length of the target sentence and 512 is the vector representation size of each word).

We denote the target sentence matrix by **trg_mat**.

Each decoder layer receives a matrix of this shape as input.

#### Self-Attention
* Self-attention works the same way as in the encoder layer. The only difference is that in the decoder, the self-attention layer is only allowed to attend to earlier positions in the output sequence. This is done by masking future positions (setting their scores to -inf) before the softmax step in the self-attention calculation.

* The output of self-attention has size **sent_len X 512**.

#### Encoder-Decoder-Attention

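* Works the same way as self-attention, except that the Keys and Values are computed from the encoder output (source sent_len X 512), while the Queries are computed from the output of the decoder's self-attention sub-layer (target sent_len X 512). The score matrix is therefore target sent_len X source sent_len, and no future-position mask is needed here (only padding in the source is masked).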
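One head of encoder-decoder attention, sketched in NumPy to show where each input comes from and the resulting shapes. Random matrices stand in for the learned weights and for the encoder and decoder outputs, the target and source lengths (3 and 7) are arbitrary, the function name is hypothetical, and padding masks are omitted for brevity.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

def encoder_decoder_attention(dec_in, enc_out, w_q, w_k, w_v):
    """One head of encoder-decoder attention.

    dec_in: (trg_len, 512) output of the decoder's self-attention sub-layer.
    enc_out: (src_len, 512) output of the top encoder layer.
    """
    q = dec_in @ w_q                                   # (trg_len, 64): queries from the decoder
    k = enc_out @ w_k                                  # (src_len, 64): keys from the encoder
    v = enc_out @ w_v                                  # (src_len, 64): values from the encoder
    weights = softmax(q @ k.T / np.sqrt(k.shape[-1]))  # (trg_len, src_len); no causal mask
    return weights @ v, weights                        # (trg_len, 64)

rng = np.random.default_rng(0)
trg_len, src_len, d_model, d_k = 3, 7, 512, 64
dec_in = rng.normal(size=(trg_len, d_model))
enc_out = rng.normal(size=(src_len, d_model))
w_q, w_k, w_v = (rng.normal(scale=d_model ** -0.5, size=(d_model, d_k)) for _ in range(3))
z, weights = encoder_decoder_attention(dec_in, enc_out, w_q, w_k, w_v)
print(z.shape, weights.shape)  # (3, 64) (3, 7)
```

Each target position gets a distribution over all source words, which is how the decoder "looks at" the input sentence.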

### Final Linear Layer
* The output of the last decoder layer is passed through a linear layer that maps each 512-dimensional vector (float values) to a vector of vocabulary size (1 X target_vocab_size).
* Softmax converts these values into probabilities that add up to 1.
* The index with the highest probability is taken as the index of the predicted word in the vocabulary.
* At inference time, this word is appended to the previously generated words and fed back to the decoder for the next run.
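The generate-and-feed-back loop described above is greedy decoding. Below is a minimal sketch, assuming a hypothetical callable `next_token_logits` that stands in for running the encoder output and the tokens generated so far through the decoder and the final linear layer, and returns the scores for the next token. The toy version just follows a fixed script, and the ids used for BOS (2) and EOS (3) are arbitrary choices for the example.

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax over a 1-D score vector."""
    shifted = x - x.max()
    exp = np.exp(shifted)
    return exp / exp.sum()

def greedy_decode(next_token_logits, bos_idx, eos_idx, max_len=50):
    """Greedy decoding: repeatedly pick the most probable next token and feed it back in."""
    tokens = [bos_idx]
    for _ in range(max_len):
        probs = softmax(next_token_logits(tokens))   # (target_vocab_size,), sums to 1
        next_id = int(np.argmax(probs))              # index of the predicted word
        tokens.append(next_id)                       # fed back in on the next step
        if next_id == eos_idx:
            break
    return tokens

SCRIPT = [4, 5, 3]  # hypothetical: the toy "model" emits token 4, then 5, then EOS (3)

def toy_next_token_logits(tokens, vocab_size=6):
    """Toy stand-in for decoder + linear layer: favours the scripted next token."""
    logits = np.zeros(vocab_size)
    logits[SCRIPT[len(tokens) - 1]] = 5.0
    return logits

print(greedy_decode(toy_next_token_logits, bos_idx=2, eos_idx=3))  # [2, 4, 5, 3]
```

Since softmax preserves the ordering of the scores, taking the argmax of the raw logits would pick the same token; the softmax only matters when actual probabilities are needed. `max_len` guards against a model that never emits EOS.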

Putting it together (ignoring the embedding step of the first layer, for a generic representation), a single decoder layer works as follows:

    Input - trg_mat (trg_len X 512, future tokens will be masked during self-attention)
          - enc_out (src_len X 512, used in the encoder-decoder attention layer)
    ->
        Self-Attention ->
            W = 8 sets of (Wq, Wk, Wv)
            Z = []
            for (Wq, Wk, Wv) in W:
                Qv = trg_mat X Wq (trg_len X 64)
                Kv = trg_mat X Wk (trg_len X 64)
                Vv = trg_mat X Wv (trg_len X 64)

                Sv = Qv X KvT (trg_len X trg_len)
                Sv = Sv/8  (for stable gradients)
                Sv = mask(Sv) # scores of future tokens are set to -inf
                Sv = Softmax(Sv) (row-wise, so each row sums to 1)
                Zi = Sv X Vv (trg_len X 64)
                Z.append(Zi)

            Zconcat = Z1.concat(Z2).concat(Z3)....concat(Z8)   (trg_len X 512)

            EDI = Zconcat X Wo (trg_len X 512)

            EDI = EDI + trg_mat (residual adding)
            EDI = norm(EDI) (layer norm)

        Encoder-Decoder-Attention ->
            W = 8 sets of (Wq, Wk, Wv)   (separate from the self-attention weights)
            Z = []
            for (Wq, Wk, Wv) in W:
                Qv = EDI X Wq (trg_len X 64)
                Kv = enc_out X Wk (src_len X 64)
                Vv = enc_out X Wv (src_len X 64)

                Sv = Qv X KvT (trg_len X src_len)
                Sv = Sv/8  (for stable gradients)
                Sv = Softmax(Sv) (row-wise; no future-token mask here, only source padding is masked)
                Zi = Sv X Vv (trg_len X 64)
                Z.append(Zi)

            Zconcat = Z1.concat(Z2).concat(Z3)....concat(Z8)   (trg_len X 512)

            FI = Zconcat X Wo (trg_len X 512)

            FI = FI + EDI (residual adding)
            FI = norm(FI) (layer norm)

        Feed-Forward ->
            FO = []
            for i in parallel_process(trg_len):
                FOi = FeedForward(FI[i]) (1 X 512, same weights for every i)
                FO.append(FOi)

            FO = stack(FO) (trg_len X 512)
            FO = FO + FI (residual summing)
            FO = norm(FO) (layer norm)

        Decoder layer output = FO

In [18]:
import numpy as np
import pandas as pd
import spacy
from string import digits
import random
from torchtext.data.utils import get_tokenizer
import torch
import torchtext
from collections import Counter
from torchtext.vocab import vocab
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
#torch.cuda.empty_cache()

import math
import time

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

/kaggle/input/en-fr-translation-dataset/en-fr.csv


In [2]:
SEED = 97

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [3]:
!python -m spacy download en_core_web_sm
!python -m spacy download fr_core_news_sm

Collecting en-core-web-sm==3.5.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.5.0/en_core_web_sm-3.5.0-py3-none-any.whl (12.8 MB)
✔ Download and installation successful
You can now load the package via spacy.load('en_core_web_sm')
Collecting fr-core-news-sm==3.5.0
  Downloading https://github.com/explosion/spacy-models/releases/download/fr_core_news_sm-3.5.0/fr_core_news_sm-3.5.0-py3-none-any.whl (16.3 MB)
Installing collected packages: fr-core-news-sm
Successfully installed fr-core-news-sm-3.5.0
✔ Download and installation successful
You can now load the package via spacy.load('fr_core_news_sm')


In [4]:
MAX_LEN = 256
check_len = lambda x: len(x.split(' ')) > MAX_LEN

In [6]:
data = pd.read_csv('/kaggle/input/en-fr-translation-dataset/en-fr.csv', nrows=5000)
data = data.dropna().drop_duplicates()
data = data.drop(data[data.en.apply(check_len) | data.fr.apply(check_len)].index)
data.head(5)

|   | en | fr |
|---|----|----|
| 0 | Changing Lives \| Changing Society \| How It Wor... | Il a transformé notre vie \| Il a transformé la... |
| 1 | Site map | Plan du site |
| 2 | Feedback | Rétroaction |
| 3 | Credits | Crédits |
| 4 | Français | English |


In [7]:
len(data)

4998

In [8]:
fr_tokenizer = get_tokenizer('spacy', language='fr_core_news_sm')
en_tokenizer = get_tokenizer('spacy', language='en_core_web_sm')

In [9]:
val_frac = 0.1
test_frac = 0.05
val_split_idx = int(len(data)*val_frac)
test_split_idx = int(len(data)*(val_frac + test_frac))
data_idx = list(range(len(data)))
np.random.shuffle(data_idx)

val_idx, test_idx, train_idx = data_idx[:val_split_idx], data_idx[val_split_idx:test_split_idx], data_idx[test_split_idx:]
print('Length of train set: ', len(train_idx))
print('Length of val set: ', len(val_idx))
print('Length of test set: ', len(test_idx))

df_train = data.iloc[train_idx].reset_index().drop('index',axis=1)
df_test = data.iloc[test_idx].reset_index().drop('index',axis=1)
df_val = data.iloc[val_idx].reset_index().drop('index',axis=1)

Length of train set:  4249
Length of val set:  499
Length of test set:  250


In [10]:
def build_vocab(data, source_tokenizer, target_tokenizer):
    en_counter = Counter()
    fr_counter = Counter()
    translations = data.values.tolist()
    for translation in translations:
        en_counter.update(source_tokenizer(translation[0]))
        fr_counter.update(target_tokenizer(translation[1]))
    return vocab(en_counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'], min_freq=5), vocab(fr_counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'], min_freq=5)

In [11]:
en_vocab, fr_vocab = build_vocab(df_train, en_tokenizer, fr_tokenizer)
en_vocab.set_default_index(en_vocab['<unk>'])
fr_vocab.set_default_index(fr_vocab['<unk>'])

In [12]:
def data_process(data):
    translations = data.values.tolist()
    pairs = []
    for translation in translations:
        en_tensor = torch.tensor([en_vocab[token] for token in en_tokenizer(translation[0])],
                            dtype=torch.long)
        fr_tensor = torch.tensor([fr_vocab[token] for token in fr_tokenizer(translation[1])],
                            dtype=torch.long)
        pairs.append((en_tensor, fr_tensor))
    return pairs

In [13]:
train_data = data_process(df_train)
val_data = data_process(df_val)
test_data = data_process(df_test)

In [14]:
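BATCH_SIZE = 16
PAD_IDX = en_vocab['<pad>']
BOS_IDX = en_vocab['<bos>']
EOS_IDX = en_vocab['<eos>']

# device used by generate_batch and by the model classes below
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')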

In [15]:
def generate_batch(data_batch):
    en_batch, fr_batch = [], []
    for (en_item, fr_item) in data_batch:
        en_batch.append(torch.cat([torch.tensor([BOS_IDX]), en_item, torch.tensor([EOS_IDX])], dim=0).to(device))
        fr_batch.append(torch.cat([torch.tensor([BOS_IDX]), fr_item, torch.tensor([EOS_IDX])], dim=0).to(device))  
        
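    # pad_sequence defaults to batch_first=False, so both batches are [sent len, batch size];
    # the model classes below expect [batch size, sent len]
    en_batch = pad_sequence(en_batch, padding_value=PAD_IDX)
    fr_batch = pad_sequence(fr_batch, padding_value=PAD_IDX)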
    return en_batch, fr_batch

In [16]:
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch)
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch)

In [19]:
class SelfAttention(nn.Module):
    def __init__(self, 
                 hid_dim: int, 
                 n_heads: int, 
                 dropout: float, 
                 device: torch.device):
        super().__init__()
        
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        
        assert hid_dim % n_heads == 0
        
        self.w_q = nn.Linear(hid_dim, hid_dim)
        self.w_k = nn.Linear(hid_dim, hid_dim)
        self.w_v = nn.Linear(hid_dim, hid_dim)
        
        self.fc = nn.Linear(hid_dim, hid_dim)
        
        self.dropout = nn.Dropout(dropout)
        
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim // n_heads])).to(device)
        
    def forward(self, 
                query: Tensor, 
                key: Tensor, 
                value: Tensor, 
                mask: Tensor = None):
        
        bsz = query.shape[0]
        
        #query = key = value [batch size, sent len, hid dim]
                
        Q = self.w_q(query)
        K = self.w_k(key)
        V = self.w_v(value)
        
        #Q, K, V = [batch size, sent len, hid dim]
        
        Q = Q.view(bsz, -1, self.n_heads, self.hid_dim // self.n_heads).permute(0, 2, 1, 3)
        K = K.view(bsz, -1, self.n_heads, self.hid_dim // self.n_heads).permute(0, 2, 1, 3)
        V = V.view(bsz, -1, self.n_heads, self.hid_dim // self.n_heads).permute(0, 2, 1, 3)
        
        #Q, K, V = [batch size, n heads, sent len, hid dim // n heads]
        
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        
        #energy = [batch size, n heads, sent len, sent len]
        
        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)
        
        attention = self.dropout(F.softmax(energy, dim=-1))
        
        #attention = [batch size, n heads, sent len, sent len]
        
        x = torch.matmul(attention, V)
        
        #x = [batch size, n heads, sent len, hid dim // n heads]
        
        x = x.permute(0, 2, 1, 3).contiguous()
        
        #x = [batch size, sent len, n heads, hid dim // n heads]
        
        x = x.view(bsz, -1, self.n_heads * (self.hid_dim // self.n_heads))
        
        #x = [batch size, src sent len, hid dim]
        
        x = self.fc(x)
        
        #x = [batch size, sent len, hid dim]
        
        return x
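The `view`/`permute` calls in `forward` split each 512-dim vector into `n_heads` slices of size `hid_dim // n_heads`, one per head, and later merge them back. The NumPy sketch below mirrors those calls with `reshape`/`transpose` (NumPy's counterparts) on random data, to show that head h sees a contiguous 64-column slice of every word vector and that the merge exactly undoes the split. The shapes are illustrative assumptions.

```python
import numpy as np

bsz, sent_len, hid_dim, n_heads = 2, 5, 512, 8
head_dim = hid_dim // n_heads                      # 64
rng = np.random.default_rng(0)
Q = rng.normal(size=(bsz, sent_len, hid_dim))      # output of w_q: [batch, sent len, hid dim]

# Split: [batch, sent len, hid dim] -> [batch, sent len, n heads, head dim]
#        -> [batch, n heads, sent len, head dim]
Q_heads = Q.reshape(bsz, -1, n_heads, head_dim).transpose(0, 2, 1, 3)

# Head h sees columns h*64 .. (h+1)*64 of every word vector (checked here for h = 3).
assert np.array_equal(Q_heads[:, 3], Q[:, :, 3 * head_dim:4 * head_dim])

# Merge back, as forward does before self.fc: undo the transpose, then flatten the heads.
Q_merged = Q_heads.transpose(0, 2, 1, 3).reshape(bsz, -1, n_heads * head_dim)
print(Q_heads.shape, np.array_equal(Q_merged, Q))  # (2, 8, 5, 64) True
```

In PyTorch, `view` requires contiguous memory, which is why `forward` calls `.contiguous()` after `permute`; NumPy's `reshape` copies automatically when needed.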

In [20]:
class PositionwiseFeedforward(nn.Module):
    def __init__(self, 
                 hid_dim: int, 
                 pf_dim: int, 
                 dropout: float):
        super().__init__()
        
        self.hid_dim = hid_dim
        self.pf_dim = pf_dim
        
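        # a Conv1d with kernel size 1 applies the same linear map at every position,
        # i.e. a position-wise linear layer
        self.fc_1 = nn.Conv1d(hid_dim, pf_dim, 1)
        self.fc_2 = nn.Conv1d(pf_dim, hid_dim, 1)
        
        # named `dropout` to match its use in forward()
        self.dropout = nn.Dropout(dropout)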
        
    def forward(self, 
                x: Tensor):
        
        #x = [batch size, sent len, hid dim]
        
        x = x.permute(0, 2, 1)
        
        #x = [batch size, hid dim, sent len]
        
        x = self.dropout(F.relu(self.fc_1(x)))
        
        #x = [batch size, ff dim, sent len]
        
        x = self.fc_2(x)
        
        #x = [batch size, hid dim, sent len]
        
        x = x.permute(0, 2, 1)
        
        #x = [batch size, sent len, hid dim]
        
        return x


In [21]:
class EncoderLayer(nn.Module):
    def __init__(self, 
                 hid_dim: int, 
                 n_heads: int, 
                 pf_dim: int, 
                 self_attention: SelfAttention, 
                 positionwise_feedforward: PositionwiseFeedforward, 
                 dropout: float, 
                 device: torch.device):
        super().__init__()      
        self.layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = self_attention(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = positionwise_feedforward(hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, 
                src: Tensor, 
                src_mask: Tensor):
        
        #src = [batch size, src sent len, hid dim]
        #src_mask = [batch size, 1, 1, src sent len]
 
        #src = [batch size, src sent len, hid dim]
        src = self.layer_norm(
            src + self.dropout(self.self_attention(
                src, src, src, src_mask)))
        
        #src = [batch size, src sent len, hid dim]        
        src = self.layer_norm(
            src + self.dropout(
                self.positionwise_feedforward(src)))
        
        return src

In [22]:
class Encoder(nn.Module):
    def __init__(self, 
                 input_dim: int, 
                 hid_dim: int, 
                 n_layers: int, 
                 n_heads: int, 
                 pf_dim: int, 
                 encoder_layer: EncoderLayer, 
                 self_attention: SelfAttention, 
                 positionwise_feedforward: PositionwiseFeedforward, 
                 dropout: float, 
                 device: torch.device):
        super().__init__()

        self.input_dim = input_dim
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.pf_dim = pf_dim
        self.encoder_layer = encoder_layer
        self.self_attention = self_attention
        self.positionwise_feedforward = positionwise_feedforward
        self.dropout = dropout
        self.device = device
        
        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        self.pos_embedding = nn.Embedding(1000, hid_dim)
        
        self.layers = nn.ModuleList([encoder_layer(hid_dim, n_heads, pf_dim, self_attention, positionwise_feedforward, dropout, device) 
                                     for _ in range(n_layers)])
        
        self.dropout = nn.Dropout(dropout)
        
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)
        
    def forward(self, 
                src: Tensor, 
                src_mask: Tensor):      
        #src = [batch size, src sent len]
        #src_mask = [batch size, 1, 1, src sent len]
        
        pos = torch.arange(0, src.shape[1]).unsqueeze(0).repeat(src.shape[0], 1).to(self.device)
        #pos = [batch size, src sent len]
        
        src_embedded = self.dropout((self.tok_embedding(src) * self.scale) + self.pos_embedding(pos))
        
        #src = [batch size, src sent len, hid dim]
        
        # each layer is an "EncoderLayer"
        for layer in self.layers:
            src_embedded = layer(src_embedded, src_mask)
            
        return src_embedded

In [23]:
class DecoderLayer(nn.Module):
    def __init__(self, 
                 hid_dim: int, 
                 n_heads: int, 
                 pf_dim: int, 
                 self_attention: SelfAttention, 
                 positionwise_feedforward: PositionwiseFeedforward, 
                 dropout: float, 
                 device: torch.device):
        super().__init__()
        
        self.layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = self_attention(hid_dim, n_heads, dropout, device)
        self.encoder_attention = self_attention(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = positionwise_feedforward(hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, 
                trg: Tensor, 
                src: Tensor, 
                trg_mask: Tensor, 
                src_mask: Tensor):
        
        #trg = [batch size, trg sent len, hid dim]
        #src = [batch size, src sent len, hid dim]
        #trg_mask = [batch size, 1, trg sent len, trg sent len]
        #src_mask = [batch size, 1, 1, src sent len]
                
        trg = self.layer_norm(
            trg + self.dropout(
                self.self_attention(trg, trg, trg, trg_mask)))
                
        trg = self.layer_norm(
            trg + self.dropout(
                self.encoder_attention(trg, src, src, src_mask)))
        
        trg = self.layer_norm(
            trg + self.dropout(
                self.positionwise_feedforward(trg)))
        
        return trg

In [24]:
class Decoder(nn.Module):
    def __init__(self, 
                 output_dim: int, 
                 hid_dim: int, 
                 n_layers: int, 
                 n_heads: int, 
                 pf_dim: int, 
                 decoder_layer: DecoderLayer, 
                 self_attention: SelfAttention, 
                 positionwise_feedforward: PositionwiseFeedforward, 
                 dropout: float, 
                 device: torch.device):
        super().__init__()
        
        self.output_dim = output_dim
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.pf_dim = pf_dim
        self.decoder_layer = decoder_layer
        self.self_attention = self_attention
        self.positionwise_feedforward = positionwise_feedforward
        self.dropout = dropout
        self.device = device
        
        self.tok_embedding = nn.Embedding(output_dim, hid_dim)
        self.pos_embedding = nn.Embedding(1000, hid_dim)
        
        self.layers = nn.ModuleList([decoder_layer(hid_dim, n_heads, pf_dim, 
                                                   self_attention, 
                                                   positionwise_feedforward, 
                                                   dropout, device)
                                     for _ in range(n_layers)])
        
        self.fc = nn.Linear(hid_dim, output_dim)
        
        self.dropout = nn.Dropout(dropout)
        
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)
        
    def forward(self, 
                trg: Tensor, 
                src: Tensor, 
                trg_mask: Tensor, 
                src_mask: Tensor):            
        #trg = [batch_size, trg sent len]
        #src = [batch_size, src sent len, hid dim] (encoder output)
        #trg_mask = [batch size, 1, trg sent len, trg sent len]
        #src_mask = [batch size, 1, 1, src sent len]
        
        #pos = [batch_size, trg sent len]
        pos = torch.arange(0, trg.shape[1]).unsqueeze(0).repeat(trg.shape[0], 1).to(self.device)

        #trg = [batch_size, trg sent len]        
        trg_embedded = self.dropout((self.tok_embedding(trg) * self.scale) + self.pos_embedding(pos))
        
        #trg = [batch size, trg sent len, hid dim]
        
        #trg_mask = [batch size, 1, trg sent len, trg sent len]      
        
        for layer in self.layers:
            trg_embedded = layer(trg_embedded, src, trg_mask, src_mask)

        #trg = [batch size, trg sent len, hid dim]            

        return self.fc(trg_embedded)

In [25]:
class Seq2Seq(nn.Module):
    def __init__(self, 
                 encoder: Encoder, 
                 decoder: Decoder, 
                 pad_idx: int, 
                 device: torch.device):
        super().__init__()
        
        self.encoder = encoder
        self.decoder = decoder
        self.pad_idx = pad_idx
        self.device = device
        
    def make_masks(self, 
                   src: Tensor, 
                   trg: Tensor):
        
        #src = [batch size, src sent len]
        #trg = [batch size, trg sent len]
        
        src_mask = (src != self.pad_idx).unsqueeze(1).unsqueeze(2)
        
        trg_pad_mask = (trg != self.pad_idx).unsqueeze(1).unsqueeze(3)

        trg_len = trg.shape[1]
        
        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), dtype=torch.bool, device=self.device))
        
        trg_mask = trg_pad_mask & trg_sub_mask
        
        return src_mask, trg_mask
    
    def forward(self, 
                src: Tensor, 
                trg: Tensor):
        
        #src = [batch size, src sent len]
        #trg = [batch size, trg sent len]
                
        src_mask, trg_mask = self.make_masks(src, trg)
        
        enc_src = self.encoder(src, src_mask)
        
        #enc_src = [batch size, src sent len, hid dim]
                
        out = self.decoder(trg, enc_src, trg_mask, src_mask)
        
        #out = [batch size, trg sent len, output dim]
        
        return out
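To see what `make_masks` produces, here is a NumPy sketch that mirrors it on a toy batch. It assumes a pad index of 1 (the position of `'<pad>'` in the specials list passed to `build_vocab`) and arbitrary token ids; `[:, None, None, :]` plays the role of the two `unsqueeze` calls, and a boolean `np.tril` matrix stands in for `torch.tril`.

```python
import numpy as np

PAD = 1                                          # assumed '<pad>' index
src = np.array([[2, 7, 9, 3, 1, 1]])             # [batch size=1, src sent len=6], two pad tokens
trg = np.array([[2, 5, 3, 1]])                   # [batch size=1, trg sent len=4], one pad token

src_mask = (src != PAD)[:, None, None, :]        # [batch, 1, 1, src len]: hides padded keys
trg_pad_mask = (trg != PAD)[:, None, :, None]    # [batch, 1, trg len, 1]: marks padded queries
trg_len = trg.shape[1]
trg_sub_mask = np.tril(np.ones((trg_len, trg_len), dtype=bool))  # [trg len, trg len]
trg_mask = trg_pad_mask & trg_sub_mask           # broadcasts to [batch, 1, trg len, trg len]

print(src_mask.shape, trg_mask.shape)            # (1, 1, 1, 6) (1, 1, 4, 4)
print(trg_mask[0, 0].astype(int))
```

The last row of `trg_mask[0, 0]` is all zeros because the fourth target position is padding. Padded target keys need no separate mask for real tokens: padding only appears at the end of a sentence, where the causal (lower-triangular) part already hides it.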