# Attention
In **Transformers 1 - Before Attention**, two LSTMs (an encoder and a decoder) were used, with the cell state and hidden state carried from the encoder to the decoder:
$$dc_{1} = f \odot ec_{t} + i \odot g$$
$$dh_{1} = o \odot \tanh(ec_t)$$

Attention introduces the following modification:
$$ ah_{t} = \sum_{i=1}^t a_i \cdot dh_{i} $$

When the attention layer's output is passed to the decoder, the question becomes: what do we do with the cell state?

Do we pass both the cell state and the attention-scaled hidden state to the decoder LSTM?

Just the cell state

In [None]:
# Standard Lib
import os
import math
import random
from time import time
from pathlib import Path

# Visualization
import matplotlib.pyplot as plt

# Tokenization
import spacy 

# Loading Bar
from tqdm import tqdm

# Torch Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Dataloader Custom Module
from sample_dataloader import get_dataloaders

In [None]:
# Datasets are expected in a "Datasets/" folder three directory levels above
# the current working directory.
data_root = os.path.join(Path(os.getcwd()).parent.parent.parent, "Datasets/")
# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs on machines without a GPU (the name `gpu` is kept for the cells below).
gpu = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# IPython shell escape: download the spaCy `en_core_web_sm` (English) pipeline.
!python -m spacy download en_core_web_sm

In [None]:
# IPython shell escape: download the spaCy `de_core_news_sm` (German) pipeline.
!python -m spacy download de_core_news_sm

### Datasets

In [None]:
# Build the train/validation/test iterators and the German and English vocabularies.
# Batches of 128 are requested on `gpu`, reading data from `data_root`.
trainset, validset, testset, de_vocab, en_vocab = get_dataloaders(batch_size=128, device=gpu, data_root=data_root)

In [None]:
# Special-token indices, looked up in the German vocabulary. The original note
# states they are identical in the English vocabulary (not verified in this
# notebook -- TODO confirm against get_dataloaders).
PAD_IDX = de_vocab['<pad>']
SOS_IDX = de_vocab['<sos>']
EOS_IDX = de_vocab['<eos>']

## Implementation Details

In [None]:
class Encoder(nn.Module):
    """Bidirectional LSTM encoder that also builds the decoder's initial hidden state.

    Args:
        input_dim: Number of rows in the source embedding table (vocabulary size).
        emb_dim: Width of each token embedding.
        enc_hid_dim: Hidden size of each LSTM direction.
        dec_hid_dim: Hidden size expected by the decoder.
        dropout: Dropout probability applied to the embedded tokens.
    """

    def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()

        self.embedding = nn.Embedding(input_dim, emb_dim)

        # The LSTM must be bidirectional: `fc` below, and the Attention and
        # Decoder layers in this file, all size their inputs as enc_hid_dim * 2.
        self.rnn = nn.LSTM(emb_dim, enc_hid_dim, bidirectional=True)

        self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """Encode a batch of source token indices.

        Args:
            src: Integer tensor of token indices, [src len, batch size].

        Returns:
            A tuple ``(outputs, hidden)`` where ``outputs`` is
            [src len, batch size, enc hid dim * 2] and ``hidden`` is
            [batch size, dec hid dim].
        """
        embedded = self.dropout(self.embedding(src))

        # nn.LSTM returns (outputs, (h_n, c_n)). Only h_n is used: the decoder
        # in this file is a GRU and has no cell state to initialise.
        outputs, (hidden, _cell) = self.rnn(embedded)

        # hidden is stacked [forward_1, backward_1, forward_2, backward_2, ...]
        # outputs are always from the last layer
        #
        # hidden[-2, :, :] is the last of the forwards RNN
        # hidden[-1, :, :] is the last of the backwards RNN
        last_forward = hidden[-2, :, :]
        last_backward = hidden[-1, :, :]

        # initial decoder hidden is the final hidden state of the forwards and
        # backwards encoder RNNs fed through a linear layer
        hidden = torch.tanh(self.fc(torch.cat((last_forward, last_backward), dim=1)))

        # outputs = [src len, batch size, enc hid dim * 2]
        # hidden = [batch size, dec hid dim]
        return outputs, hidden

In [None]:
class Attention(nn.Module):
    """Additive attention scoring each encoder position against the decoder state.

    Args:
        enc_hid_dim: Hidden size of one encoder direction (encoder outputs are
            twice this wide).
        dec_hid_dim: Hidden size of the decoder state.
    """

    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()

        concat_dim = (enc_hid_dim * 2) + dec_hid_dim
        self.attn = nn.Linear(concat_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        """Return attention weights of shape [batch size, src len].

        Args:
            hidden: Decoder state, [batch size, dec hid dim].
            encoder_outputs: Encoder outputs, [src len, batch size, enc hid dim * 2].
        """
        src_len = encoder_outputs.size(0)

        # Copy the decoder state once per source position:
        # [batch size, dec hid dim] -> [batch size, src len, dec hid dim]
        tiled_hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)

        # Put the batch dimension first so both tensors line up for the concat.
        batch_first_outputs = encoder_outputs.permute(1, 0, 2)

        features = torch.cat((tiled_hidden, batch_first_outputs), dim=2)
        energy = torch.tanh(self.attn(features))

        # Collapse each energy vector to one score per source position.
        scores = self.v(energy).squeeze(2)

        # Normalise over the source positions.
        return F.softmax(scores, dim=1)

In [None]:
class Decoder(nn.Module):
    """Single-step GRU decoder that attends over the encoder outputs.

    Args:
        output_dim: Target vocabulary size (also the width of the prediction).
        emb_dim: Width of each target token embedding.
        enc_hid_dim: Hidden size of one encoder direction.
        dec_hid_dim: Hidden size of the decoder GRU.
        dropout: Dropout probability applied to the embedded input token.
        attention: Callable mapping ``(hidden, encoder_outputs)`` to attention
            weights of shape [batch size, src len].
    """

    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
        super().__init__()

        self.output_dim = output_dim
        self.attention = attention

        context_dim = enc_hid_dim * 2

        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.GRU(context_dim + emb_dim, dec_hid_dim)
        self.fc_out = nn.Linear(context_dim + dec_hid_dim + emb_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, encoder_outputs):
        """Decode one target position.

        Args:
            input: Token indices for the current step, [batch size].
            hidden: Previous decoder state, [batch size, dec hid dim].
            encoder_outputs: Encoder outputs, [src len, batch size, enc hid dim * 2].

        Returns:
            A tuple ``(prediction, hidden)``: vocabulary scores of shape
            [batch size, output dim] and the new decoder state of shape
            [batch size, dec hid dim].
        """
        # Add a length-1 sequence dimension for the embedding and the GRU.
        step_tokens = input.unsqueeze(0)
        embedded = self.dropout(self.embedding(step_tokens))

        # Alignment scores over the encoder positions, shaped for bmm:
        # [batch size, 1, src len]
        attn_weights = self.attention(hidden, encoder_outputs).unsqueeze(1)

        # Attention-weighted sum of the encoder outputs (batch matrix multiply),
        # then back to sequence-first layout: [1, batch size, enc hid dim * 2]
        context = torch.bmm(attn_weights, encoder_outputs.permute(1, 0, 2))
        context = context.permute(1, 0, 2)

        rnn_input = torch.cat((embedded, context), dim=2)
        output, new_hidden = self.rnn(rnn_input, hidden.unsqueeze(0))

        # seq len, n layers and n directions are all 1 in this decoder, so
        # output and new_hidden are both [1, batch size, dec hid dim] and equal.
        features = torch.cat(
            (output.squeeze(0), context.squeeze(0), embedded.squeeze(0)),
            dim=1,
        )
        prediction = self.fc_out(features)

        return prediction, new_hidden.squeeze(0)

In [None]:
class Seq2Seq(nn.Module):
    """Runs an encoder once, then steps a decoder across the target length.

    Args:
        encoder: Callable mapping ``src`` to ``(encoder_outputs, hidden)``.
        decoder: Callable mapping ``(input, hidden, encoder_outputs)`` to
            ``(output, hidden)`` for one target position.
        device: Device on which the output buffer is placed.
        trg_vocab_size: Size of the last dimension of the output buffer.
    """

    def __init__(self, encoder, decoder, device, trg_vocab_size):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        self.trg_vocab_size = trg_vocab_size

    def forward(self, src, trg, teacher_forcing_ratio=.5):
        """Return decoder outputs of shape [trg len, batch size, trg vocab size].

        Row 0 of the result is left as zeros: decoding starts at position 1,
        with ``trg[0]`` used as the first decoder input.

        Args:
            src: Source batch; its second dimension is taken as the batch size.
            trg: Target batch; its first dimension is taken as the target length.
            teacher_forcing_ratio: Probability, drawn per step, of feeding the
                ground-truth token instead of the decoder's own top prediction.
        """
        trg_len, batch_size = trg.shape[0], src.shape[1]

        step_outputs = torch.zeros(trg_len, batch_size, self.trg_vocab_size).to(self.device)

        encoder_outputs, hidden = self.encoder(src)

        next_input = trg[0, :]
        for t in range(1, trg_len):
            step_output, hidden = self.decoder(next_input, hidden, encoder_outputs)
            step_outputs[t] = step_output

            # One random draw per step decides between teacher forcing and
            # feeding back the highest-scoring token.
            if random.random() < teacher_forcing_ratio:
                next_input = trg[t]
            else:
                next_input = step_output.argmax(1)

        return step_outputs

### Initialization + Number of Params

In [None]:
# Vocabulary sizes: the German vocab feeds the encoder, the English vocab
# sizes the decoder output.
INPUT_DIM = len(de_vocab)
OUTPUT_DIM = len(en_vocab)
# Embedding widths for the encoder and decoder.
ENC_EMB_DIM = 128
DEC_EMB_DIM = 128
# Recurrent hidden sizes for the encoder and decoder.
ENC_HID_DIM = 256
DEC_HID_DIM = 256
# NOTE(review): N_LAYERS is not passed to any constructor in this cell.
N_LAYERS = 2
# Dropout probabilities for the encoder and decoder.
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5

# The attention module is built first because the decoder takes it as an argument.
attn = Attention(ENC_HID_DIM, DEC_HID_DIM)
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn)

# Assemble the full model and move its parameters to the selected device.
model = Seq2Seq(enc, dec, gpu, OUTPUT_DIM).to(gpu)

In [None]:
# initialize model weights
def init_weights(m):
    """Overwrite every parameter reachable from ``m`` with U(-0.08, 0.08) samples.

    Args:
        m: Module whose parameters (including those of its children) are
            re-initialised in place.
    """
    for weight in m.parameters():
        nn.init.uniform_(weight.data, -0.08, 0.08)
        
# Module.apply calls init_weights on each submodule of `model` and on `model` itself.
model.apply(init_weights)

In [None]:
def count_parameters(model):
    """Return the total element count of the parameters that require gradients.

    Args:
        model: Module whose parameters are counted.
    """
    total = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad False) are left out of the count.
        if param.requires_grad:
            total += param.numel()
    return total

# Report the trainable-parameter count, formatted with thousands separators.
print(f'The model has {count_parameters(model):,} trainable parameters')

### Training

In [None]:
def train(model, iterator, optimizer, criterion, clip=1, num_epochs=10):
    """Train ``model`` on ``iterator`` and record the loss of every batch.

    Args:
        model: Module called as ``model(src, trg)``; its result is indexed as
            [trg len, batch size, output dim].
        iterator: Iterable of batches exposing ``.src`` and ``.trg`` attributes.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(output, trg)`` on the flattened
            predictions and targets.
        clip: Maximum gradient norm passed to ``clip_grad_norm_``.
        num_epochs: Number of full passes over ``iterator``.

    Returns:
        A tuple ``(model, losses)`` where ``losses`` holds one float per batch,
        in the order the batches were processed.
    """
    model.train()
    losses = []
    for _ in range(num_epochs):
        # Wrap the iterator itself rather than enumerate(iterator): tqdm can
        # then read the iterator's length (when it has one) and show a total.
        for batch in tqdm(iterator, desc="iteration"):
            src = batch.src
            trg = batch.trg

            optimizer.zero_grad()

            output = model(src, trg)
            output_dim = output.shape[-1]

            # Drop position 0 (the start token on the target side) and flatten
            # time and batch into one dimension for the loss. reshape also
            # handles inputs that are not contiguous in memory.
            output = output[1:].reshape(-1, output_dim)
            trg = trg[1:].reshape(-1)

            loss = criterion(output, trg)
            losses.append(loss.item())
            loss.backward()

            # clip the gradients to prevent them from exploding (a common issue in RNNs)
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

            optimizer.step()

    return model, losses

In [None]:
# Adam with its default hyperparameters over all model parameters.
optimizer = optim.Adam(model.parameters())
# Targets equal to PAD_IDX are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
# Train with train()'s default clip and num_epochs; keep the per-batch losses for plotting.
model, losses = train(model, trainset, optimizer, criterion)

### Training Loss

In [None]:
# Scatter the recorded losses against their position in the `losses` list.
plt.scatter(x=list(range(len(losses))), y=losses)

### Testing

In [None]:
def tensor_2_str(tensor, vocab=de_vocab.itos):
    """Join the tokens of ``tensor`` into a space-separated string.

    Each element is converted with ``int`` and used to index ``vocab``.
    The words '<eos>', '<pad>' and '.' are left out of the result.

    Args:
        tensor: Iterable of token indices.
        vocab: Index-to-word lookup; defaults to the German vocabulary's ``itos``.
    """
    skipped = ['<eos>', '<pad>', '.']
    words = (vocab[int(token)] for token in tensor)
    return " ".join([word for word in words if word not in skipped])

In [None]:
# Translate the first sentence of one batch and print it next to the reference.
# Switch to eval mode so dropout is disabled while predicting.
model.eval()
with torch.no_grad():
    # NOTE(review): the sample comes from trainset; use testset to inspect held-out data.
    sample = next(iter(trainset))
    src, trg = sample.src, sample.trg
    # With teacher forcing off, the decoder is fed its own predictions; trg
    # only supplies the start token and the number of decoding steps.
    logits = model(src, trg, teacher_forcing_ratio=0)
    # Position 0 of the model output is never written (decoding starts at
    # t=1), so skip it on both sides to keep prediction and reference aligned.
    output_tensor = logits.argmax(2)[1:, 0]
    target_tensor = trg[1:, 0]

    output = tensor_2_str(output_tensor, en_vocab.itos)
    expected = tensor_2_str(target_tensor, en_vocab.itos)
    # Banner width: the longer sentence plus room for the longest label.
    N = max(len(output), len(expected)) + len("Expected: ")

    print("="*N)
    print("Output: {}".format(output).center(N))
    print("="*N)

    print("="*N)
    print("Expected: {}".format(expected).center(N))
    print("="*N)