# Imports

In [1]:
# Standard Lib
import os
import math
import random
from time import time
from pathlib import Path

# Visualization
import matplotlib.pyplot as plt

# Tokenization
import spacy 

# Loading Bar
from tqdm import tqdm

# Torch Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Dataloader Custom Module
from sample_dataloader import get_dataloaders

In [None]:
# The datasets directory sits three levels above the current working directory.
data_root = os.path.join(Path(os.getcwd()).parent.parent.parent, "Datasets/")
# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs on machines without a GPU. The name `gpu` is kept because later cells
# refer to it.
gpu = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Shell command: download spaCy's small English pipeline
# (presumably needed by the dataloader's tokenizer -- confirm in sample_dataloader).
!python -m spacy download en_core_web_sm

In [None]:
# Shell command: download spaCy's small German pipeline
# (presumably needed by the dataloader's tokenizer -- confirm in sample_dataloader).
!python -m spacy download de_core_news_sm

# Datasets

In [None]:
# Load the three data splits together with the German (`de_vocab`) and
# English (`en_vocab`) vocabularies. Batches are created on `gpu`.
trainset, validset, testset, de_vocab, en_vocab = get_dataloaders(
    batch_size=128,
    device=gpu,
    data_root=data_root,
)

In [None]:
# Indices of the special tokens. They are read from the German vocabulary and
# are assumed to be identical in the English one.
PAD_IDX, SOS_IDX, EOS_IDX = (de_vocab[token] for token in ('<pad>', '<sos>', '<eos>'))

# Attention
In **Transformers 1 - Before Attention**, each forward pass looked like this:
```python
s = c = encoder(x)
y = '<s>'
for _ in range(N):
    decoder_input = torch.cat((s, c), dim=1)
    y, s = decoder(y, decoder_input)
```

Now we introduce **Attention**
```python
# Encoding Stage (inp_vec (x) -> hidden_states)
hidden_states, s = rnn(x)


# Decoding Stage(s)

# 1) computing the attention weights
alignment_scores = []
for h_i in hidden_states:
    e_i = f_att(s, h_i)
    alignment_scores.append(e_i)
    
attention_weights = F.softmax(torch.stack(alignment_scores), dim=0)

# 2) computing the context vector
c = torch.zeros(1, hidden_dim)
for a_i, h_i in zip(attention_weights, hidden_states):
    c += a_i * h_i
    
# 3) decoding
y = '<s>'
y, s = decoder(y, c)
```


## Implementation Details

In [None]:
class Encoder(nn.Module):
    """Embed a German source sentence and encode it with a single-layer LSTM.

    Args:
        input_dim: size of the source (German) vocabulary.
        emb_dim: dimensionality of the token embeddings.
        enc_hid_dim: hidden size of the encoder LSTM.
        dec_hid_dim: hidden size expected by the decoder that consumes the
            final state. When it differs from ``enc_hid_dim``, the final
            hidden and cell states are linearly projected to this size;
            when it is equal, they are passed through unchanged.
        dropout: dropout probability applied to the embeddings.
    """

    def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()
        # German Embeddings
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.dropout = nn.Dropout(dropout)

        # Encoder (sequence-first: input is (src_len, batch, emb_dim))
        self.lstm = nn.LSTM(emb_dim, enc_hid_dim)

        # Bridge from the encoder state size to the decoder state size.
        if enc_hid_dim == dec_hid_dim:
            self.bridge_hidden = nn.Identity()
            self.bridge_cell = nn.Identity()
        else:
            self.bridge_hidden = nn.Linear(enc_hid_dim, dec_hid_dim)
            self.bridge_cell = nn.Linear(enc_hid_dim, dec_hid_dim)

    def forward(self, x, return_outputs=False):
        """
        x: an encoded german sentence, token indices of shape (src_len, batch)
        return_outputs: when True, also return the per-position LSTM outputs
            of shape (src_len, batch, enc_hid_dim), which an attention
            mechanism needs; the default keeps the (hidden, cell) return.

        Returns (hidden, cell), each of shape (1, batch, dec_hid_dim), or
        (outputs, hidden, cell) when ``return_outputs`` is True.
        """
        embedding = self.dropout(self.embedding(x))
        outputs, (hidden, cell) = self.lstm(embedding)
        hidden = self.bridge_hidden(hidden)
        cell = self.bridge_cell(cell)
        if return_outputs:
            return outputs, hidden, cell
        return hidden, cell

In [None]:
class Attention(nn.Module):
    """Multiplicative attention over the encoder hidden states.

    The alignment score of source position ``i`` is
    ``e_i = d_hidden . (W h_i)``, where ``h_i`` is the encoder output at that
    position. The scores are normalised with a softmax over the source
    positions.

    Args:
        enc_hid_dim: size of each encoder output vector.
        dec_hid_dim: size of the decoder hidden state.
    """

    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        # No bias: it would add the same offset (d_hidden . b) to every score
        # of a sentence, which the softmax cancels out.
        self.alignment = nn.Linear(enc_hid_dim, dec_hid_dim, bias=False)
        # Normalise across source positions (dim 1 of the (batch, src_len) scores).
        self.softmax = nn.Softmax(dim=1)

    def forward(self, d_hidden, encoder_outputs):
        """
        d_hidden: the current decoder hidden state (dh_t), shape
            (batch, dec_hid_dim); a (num_layers, batch, dec_hid_dim) tensor is
            also accepted, in which case the last layer is used
        encoder_outputs: all encoder hidden states (eh_i), shape
            (src_len, batch, enc_hid_dim)

        Returns the attention weights, shape (batch, src_len); each row sums to 1.
        """
        if d_hidden.dim() == 3:
            d_hidden = d_hidden[-1]

        # (src_len, batch, enc_hid) -> (batch, src_len, dec_hid)
        keys = self.alignment(encoder_outputs).permute(1, 0, 2)
        # Batched dot product of every projected encoder state with dh_t.
        scores = torch.bmm(keys, d_hidden.unsqueeze(2)).squeeze(2)
        return self.softmax(scores)

In [None]:
class Decoder(nn.Module):
    """Single-step LSTM decoder that attends over the encoder outputs.

    Args:
        output_dim: size of the target (English) vocabulary.
        emb_dim: dimensionality of the token embeddings.
        enc_hid_dim: size of each encoder output vector (and of the context).
        dec_hid_dim: hidden size of the decoder LSTM.
        dropout: dropout probability applied to the embeddings.
        attention: module called as ``attention(d_hidden, encoder_outputs)``
            with ``d_hidden`` of shape (batch, dec_hid_dim) and expected to
            return weights of shape (batch, src_len).
    """

    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
        super().__init__()
        # Remembered so a zero context can be built when no encoder outputs are given.
        self.enc_hid_dim = enc_hid_dim

        # English Embeddings
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.dropout = nn.Dropout(dropout)

        # attention block
        self.attention = attention

        # Decoder: the LSTM input is the token embedding concatenated with the context.
        self.lstm = nn.LSTM(emb_dim + enc_hid_dim, dec_hid_dim)
        self.fc_out = nn.Linear(dec_hid_dim, output_dim)

    def forward(self, x, hidden, cell, encoder_outputs=None):
        """
        x: the previous token, indices of shape (batch,)
        hidden: the previous hidden state, shape (1, batch, dec_hid_dim)
        cell: the previous cell state, shape (1, batch, dec_hid_dim)
        encoder_outputs: all encoder hidden states, shape
            (src_len, batch, enc_hid_dim). When omitted, a zero context
            vector is used, i.e. the step runs without attention.

        Returns (prediction, hidden, cell), where prediction holds the
        unnormalised vocabulary scores of shape (batch, output_dim).
        """
        # (batch,) -> (1, batch, emb_dim): one time step for the LSTM.
        embedding = self.dropout(self.embedding(x)).unsqueeze(0)

        # attention
        if encoder_outputs is None:
            context = embedding.new_zeros(1, embedding.shape[1], self.enc_hid_dim)
        else:
            weights = self.attention(hidden[-1], encoder_outputs)
            # (batch, 1, src_len) x (batch, src_len, enc_hid) -> (batch, 1, enc_hid)
            context = torch.bmm(weights.unsqueeze(1), encoder_outputs.permute(1, 0, 2))
            context = context.permute(1, 0, 2)

        # decoding
        lstm_input = torch.cat((embedding, context), dim=2)
        output, (hidden, cell) = self.lstm(lstm_input, (hidden, cell))

        prediction = self.fc_out(output.squeeze(0))
        return prediction, hidden, cell

### Initialization + Number of Params

In [None]:
# Vocabulary sizes of the source (German) and target (English) languages.
INPUT_DIM, OUTPUT_DIM = len(de_vocab), len(en_vocab)

# Encoder and decoder use the same embedding size, hidden size and dropout.
ENC_EMB_DIM = DEC_EMB_DIM = 128
ENC_HID_DIM = DEC_HID_DIM = 256
ENC_DROPOUT = DEC_DROPOUT = 0.5
# NOTE(review): N_LAYERS is defined but not passed to any constructor in this cell.
N_LAYERS = 2

attn = Attention(ENC_HID_DIM, DEC_HID_DIM)
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn)

# NOTE(review): Seq2Seq is not defined in the visible part of the notebook --
# confirm where it comes from before running this cell.
model = Seq2Seq(enc, dec, gpu, OUTPUT_DIM)
model = model.to(gpu)

In [None]:
# initialize model weights
def init_weights(m):
    """Re-draw every parameter of ``m`` (submodules included) from U(-0.08, 0.08)."""
    low, high = -0.08, 0.08
    for param in m.parameters():
        nn.init.uniform_(param.data, low, high)
        
# Run the initializer on the model and each of its submodules.
model.apply(init_weights)

In [None]:
def count_parameters(model):
    """Return the number of scalar elements in the trainable parameters of ``model``.

    Parameters whose ``requires_grad`` flag is False are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Report the trainable-parameter count with thousands separators.
print(f'The model has {count_parameters(model):,} trainable parameters')

### Training

In [None]:
def train(model, iterator, optimizer, criterion, clip=1, num_epochs=10):
    """Train ``model`` on ``iterator`` and record the loss of every batch.

    Args:
        model: module called as ``model(src, trg)``; its output is indexed as
            (step, ..., vocabulary score), with the scores on the last axis.
        iterator: iterable of batches exposing ``.src`` and ``.trg``.
        optimizer: optimizer over the model's parameters.
        criterion: loss called as ``criterion(output, trg)`` on the flattened
            predictions and targets.
        clip: maximum gradient norm passed to ``clip_grad_norm_``.
        num_epochs: number of full passes over ``iterator``.

    Returns:
        ``(model, losses)``, where ``losses`` is a list with one float per
        processed batch, in order.
    """
    model.train()
    losses = []
    for _ in range(num_epochs):
        # Wrap the iterator itself (not enumerate(...)) so tqdm can read its
        # length and show real progress.
        for batch in tqdm(iterator, desc="iteration"):
            src = batch.src
            trg = batch.trg

            optimizer.zero_grad()

            output = model(src, trg)

            output_dim = output.shape[-1]

            # Drop the first step of both tensors before flattening them for
            # the loss. reshape (rather than view) also accepts
            # non-contiguous tensors.
            output = output[1:].reshape(-1, output_dim)
            trg = trg[1:].reshape(-1)

            loss = criterion(output, trg)
            losses.append(loss.item())
            loss.backward()

            # clip the gradients to prevent them from exploding (a common issue in RNNs)
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

            optimizer.step()

    return model, losses

In [None]:
# Padding positions are excluded from the loss via ignore_index.
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
# Adam with its default hyper-parameters.
optimizer = optim.Adam(model.parameters())
model, losses = train(
    model,
    trainset,
    optimizer=optimizer,
    criterion=criterion,
)

### Training Loss

In [None]:
# Scatter the loss recorded for each training batch against its position in the list.
plt.scatter(x=list(range(len(losses))), y=losses)

### Testing

In [None]:
def tensor_2_str(tensor, vocab=de_vocab.itos):
    """Turn a sequence of token indices into a space-separated string.

    Each element of ``tensor`` is converted with ``int`` and looked up in
    ``vocab`` (the German index-to-string table by default). The tokens
    '<eos>', '<pad>' and '.' are left out of the result.
    """
    skipped = ['<eos>', '<pad>', '.']
    words = []
    for token in tensor:
        word = vocab[int(token)]
        if word in skipped:
            continue
        words.append(word)
    return " ".join(words)

In [None]:
# Switch to evaluation mode so dropout is disabled while translating.
model.eval()
with torch.no_grad():
    # Use a held-out batch from the test split rather than the training data.
    sample = next(iter(testset))
    src, trg = sample.src, sample.trg
    output = model(src, trg)
    # Take the first sentence of the batch and skip position 0, mirroring the
    # `[1:]` slicing used on outputs and targets in `train`.
    output_tensor = output.argmax(2)[1:, 0]
    target_tensor = trg[1:, 0]

    output = tensor_2_str(output_tensor, en_vocab.itos)
    expected = tensor_2_str(target_tensor, en_vocab.itos)
    # Banner width: the longer sentence plus room for the longest label.
    N = max(len(output), len(expected)) + len("Expected: ")

    print("="*N)
    print("Output: {}".format(output).center(N))
    print("="*N)

    print("="*N)
    print("Expected: {}".format(expected).center(N))
    print("="*N)