In [1]:
import os
import re
import random
import spacy
import torch
import torch.nn as nn
import torch.utils.data as data
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import numpy as np
import torch.optim as optim
from tqdm.notebook import tqdm

# Data Preprocessing

In [3]:
with open('eng-fra.txt', encoding='utf-8') as f:
    content = f.readlines()

# Each useful line is "English<TAB>French". Only the first two tab-separated fields are kept
# (any further columns are ignored), and blank/malformed lines with fewer than two fields are
# skipped instead of producing a pair that later fails on pair[1].
eng_fre_pairs = []
for line in content:
    # drop narrow no-break spaces (U+202F)
    fields = line.replace('\u202f', '').rstrip('\n').split('\t')
    if len(fields) < 2:
        continue
    eng_fre_pairs.append(fields[:2])

FileNotFoundError: [Errno 2] No such file or directory: 'eng-fra.txt'

In [None]:
# Number of (English, French) sentence pairs parsed from eng-fra.txt
len(eng_fre_pairs)

In [None]:
# Deterministic interleaved split over the pair index:
#   index % 5 == 0            -> test
#   otherwise index % 17 == 0 -> val
#   everything else           -> train
train, val, test = [], [], []
for idx, pair in enumerate(eng_fre_pairs):
    if idx % 5 == 0:
        bucket = test
    elif idx % 17 == 0:
        bucket = val
    else:
        bucket = train
    bucket.append(pair)


## Tokenization

### spaCy for word tokenization

In [None]:
# Only the tokenizers of these pipelines are used (eng_lm.tokenizer / fre_lm.tokenizer), so
# blank language pipelines are sufficient: they provide each language's default tokenizer,
# need no model download, and avoid the 'en'/'fr' shortcut names that spaCy v3 removed.
eng_lm = spacy.blank('en')
fre_lm = spacy.blank('fr')

Create word(token)-to-index and index-to-word dictionaries for each split. They give fast lookups and let the token indices of the val and test splits be mapped onto the training vocabulary during loss calculation.

In [None]:
def tokenize_list(data):
    """Build the English and French token vocabularies of a split.

    Args:
        data: sequence of pairs; element [0] is English text, element [1] French text.

    Returns:
        (english_tokens, french_tokens): each a list starting with '<sos>', '<eos>', '<unk>'
        followed by every distinct spaCy token of that language, in first-occurrence order.
    """
    def _vocab(tokenizer, texts):
        # dict.fromkeys de-duplicates while keeping the first occurrence order
        distinct = dict.fromkeys(tok.text for tok in tokenizer(" ".join(texts)))
        return ['<sos>', '<eos>', '<unk>'] + list(distinct)

    tokens_eng1 = _vocab(eng_lm.tokenizer, (pair[0] for pair in data))
    tokens_fre1 = _vocab(fre_lm.tokenizer, (pair[1] for pair in data))
    return tokens_eng1, tokens_fre1

tokens_eng_train, tokens_fre_train = tokenize_list(train)
tokens_eng_val, tokens_fre_val = tokenize_list(val)
tokens_eng_test, tokens_fre_test = tokenize_list(test)


In [None]:
def token_index(token_list):
    """Return (token -> index, index -> token) dictionaries for ``token_list``.

    Indices are list positions; if a token occurs more than once, its last position wins.
    """
    w2i = {}
    for position, token in enumerate(token_list):
        w2i[token] = position
    i2w = {position: token for token, position in w2i.items()}
    return w2i, i2w

# Forward/reverse lookups for every (language, split) vocabulary.
w2i_eng_train, i2w_eng_train = token_index(tokens_eng_train)
w2i_fre_train, i2w_fre_train = token_index(tokens_fre_train)
w2i_eng_val, i2w_eng_val = token_index(tokens_eng_val)
w2i_fre_val, i2w_fre_val = token_index(tokens_fre_val)
w2i_eng_test, i2w_eng_test = token_index(tokens_eng_test)
w2i_fre_test, i2w_fre_test = token_index(tokens_fre_test)
    

### Data split 

In [None]:
print(f" Length : train {len(train)}, val {len(val)}, test {len(test)} ")
print(" Percentage split : train {}, val {}, test {} ".format(
    *(len(split)/len(eng_fre_pairs)*100 for split in (train, val, test))))

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

# Custom dataset and batch collation

In [None]:
class Translation(Dataset):
    """Map-style dataset turning (English, French) sentence pairs into index tensors.

    Each item is ``(input, label)``: two 1-D LongTensors laid out as
    ``<sos> token ... token <eos>``. Tokens missing from the supplied vocabularies map to
    ``<unk>``.
    """

    def __init__(self, eng_fre_data, eng_indeces, fre_indeces, eng_tokenizer=None, fre_tokenizer=None):
        """
        Args:
            eng_fre_data: sequence of pairs (train, val or test, depending on the loader);
                element [0] is English text, element [1] French text.
            eng_indeces: English token -> index dict containing '<sos>', '<eos>', '<unk>'.
            fre_indeces: French token -> index dict containing '<sos>', '<eos>', '<unk>'.
            eng_tokenizer: optional callable text -> iterable of tokens (each looked up via
                ``str(token)``). Defaults to the module-level ``eng_lm.tokenizer``.
            fre_tokenizer: same for French; defaults to ``fre_lm.tokenizer``.
        """
        self.data = eng_fre_data

        self.eng_indeces = eng_indeces
        self.fre_indeces = fre_indeces

        # None means "use the global spaCy tokenizer", resolved at access time.
        self.eng_tokenizer = eng_tokenizer
        self.fre_tokenizer = fre_tokenizer

    @staticmethod
    def _encode(tokens, vocab):
        """Map tokens to vocabulary indices (unknowns -> '<unk>') wrapped in '<sos>' ... '<eos>'."""
        unk = vocab['<unk>']
        return [vocab['<sos>'], *(vocab.get(str(tok), unk) for tok in tokens), vocab['<eos>']]

    def __getitem__(self, index):
        eng_tok = self.eng_tokenizer if self.eng_tokenizer is not None else eng_lm.tokenizer
        fre_tok = self.fre_tokenizer if self.fre_tokenizer is not None else fre_lm.tokenizer

        input_indeces = self._encode(eng_tok(self.data[index][0]), self.eng_indeces)
        label_indeces = self._encode(fre_tok(self.data[index][1]), self.fre_indeces)

        return torch.LongTensor(input_indeces), torch.LongTensor(label_indeces)

    def __len__(self):
        return len(self.data)
        

# Batch width adapts to the longest sequence in each batch (no global max length).
def collate_fn(data):
    """Zero-pad a list of (input, label) 1-D index tensors into two int64 batches.

    Returns:
        (inputs, labels): tensors of shape (batch, longest_input) and (batch, longest_label);
        positions past a sequence's end hold 0.
    """
    sources, targets = zip(*data)
    pad = torch.nn.utils.rnn.pad_sequence
    inputs = pad(list(sources), batch_first=True).to(torch.int64)
    labels = pad(list(targets), batch_first=True).to(torch.int64)
    return inputs, labels

# Encoder-Decoder Architecture (attention-based)

In [None]:
class Encoder(nn.Module):
    """Bidirectional GRU encoder that also derives the decoder's initial hidden state.

    Args:
        inp_vocab_dim: number of source-vocabulary entries (embedding rows).
        enc_hid_dim: hidden size of each GRU direction.
        embed_dim: source token embedding size.
        dec_hid_dim: size of the projected hidden state handed to the decoder.
        drop_prob: dropout probability applied to the embeddings.
    """

    def __init__(self, inp_vocab_dim, enc_hid_dim, embed_dim, dec_hid_dim, drop_prob):
        super().__init__()
        # Registration order (embed, gru, linear, dropout) fixes parameter / state_dict order.
        self.embed = nn.Embedding(inp_vocab_dim, embed_dim)
        self.gru = nn.GRU(embed_dim, enc_hid_dim, bidirectional = True)
        # Projects [forward final state ; backward final state] (2 * enc_hid_dim) to dec_hid_dim.
        self.linear = nn.Linear(enc_hid_dim*2, dec_hid_dim)
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, inp_sntc):
        """Encode a sequence-first batch of token indices.

        Args:
            inp_sntc: (src_len, batch) LongTensor; the GRU is built with batch_first=False.

        Returns:
            outputs: (src_len, batch, 2 * enc_hid_dim) per-step states of both directions.
            hidden: (batch, dec_hid_dim) tanh-projected concatenation of the final states.
        """
        embedded = self.dropout(self.embed(inp_sntc))
        outputs, final_states = self.gru(embedded)
        # final_states: (2, batch, enc_hid_dim) for this single-layer bidirectional GRU.
        summary = torch.cat([final_states[0], final_states[1]], dim = 1)
        return outputs, torch.tanh(self.linear(summary))
    
class Attention(nn.Module):
    """Additive attention over encoder states.

    score_j = v . tanh(W [s ; h_j]) for decoder state ``s`` and encoder state ``h_j``,
    softmax-normalised over source positions. Padded source positions are not masked.
    """

    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        # W: maps [decoder state ; bidirectional encoder state] to dec_hid_dim
        self.sim = nn.Linear((enc_hid_dim*2) + dec_hid_dim, dec_hid_dim)
        # v: learned scoring vector
        self.weight = nn.Parameter(torch.rand(dec_hid_dim))

    def forward(self, curr_dec_hid, enc_outputs):
        """
        Args:
            curr_dec_hid: (batch, dec_hid_dim) current decoder hidden state.
            enc_outputs: (src_len, batch, 2 * enc_hid_dim) encoder outputs.

        Returns:
            (batch, src_len) attention weights; each row sums to 1.
        """
        src_len, batch = enc_outputs.shape[0], enc_outputs.shape[1]

        query = curr_dec_hid.unsqueeze(1).repeat(1, src_len, 1)           # (batch, src_len, dec_hid)
        keys = enc_outputs.permute(1, 0, 2)                                # (batch, src_len, 2*enc_hid)
        energy = torch.tanh(self.sim(torch.cat((query, keys), dim = 2)))  # (batch, src_len, dec_hid)

        v = self.weight.repeat(batch, 1).unsqueeze(1)                      # (batch, 1, dec_hid)
        scores = torch.bmm(v, energy.permute(0, 2, 1)).squeeze(1)          # (batch, src_len)
        return torch.softmax(scores, dim = 1)
    
class Decoder(nn.Module):
    """Single-step GRU decoder conditioned on an attention-weighted encoder context.

    Args:
        vocab_dim: target vocabulary size (number of output logits).
        embed_dim: target token embedding size.
        attn: callable ``(dec_hidden, enc_outputs) -> (batch, src_len)`` weights
            (an ``Attention`` module in this notebook).
        enc_hid_dim: encoder hidden size per direction (encoder outputs are 2 * enc_hid_dim wide).
        dec_hid_dim: decoder GRU hidden size.
        drop_prob: dropout probability applied to the embeddings.
    """

    def __init__(self, vocab_dim, embed_dim, attn, enc_hid_dim, dec_hid_dim, drop_prob ):
        super().__init__()

        self.embed = nn.Embedding(vocab_dim, embed_dim)
        self.attn = attn

        self.vocab_dim = vocab_dim
        # GRU input = [token embedding ; attention context]
        self.gru = nn.GRU((enc_hid_dim*2) + embed_dim, dec_hid_dim)

        # Output layer sees [GRU output ; attention context ; token embedding]
        self.linear = nn.Linear(enc_hid_dim*2 + dec_hid_dim + embed_dim, vocab_dim)
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, inp, dec_hidden, enc_outputs):
        """Decode one time step.

        Args:
            inp: (batch,) indices of the current input token.
            dec_hidden: (batch, dec_hid_dim) previous decoder hidden state.
            enc_outputs: (src_len, batch, 2 * enc_hid_dim) encoder outputs.

        Returns:
            (logits, hidden): (batch, vocab_dim) scores for the next token and the new
            (batch, dec_hid_dim) hidden state.
        """
        embedding = self.dropout(self.embed(inp.unsqueeze(0)))            # (1, batch, embed_dim)

        attn_weights = self.attn(dec_hidden, enc_outputs).unsqueeze(1)   # (batch, 1, src_len)
        context = torch.bmm(attn_weights, enc_outputs.permute(1, 0, 2))  # (batch, 1, 2*enc_hid)
        context = context.permute(1, 0, 2)                               # (1, batch, 2*enc_hid)

        output, hidden = self.gru(torch.cat((embedding, context), dim = 2), dec_hidden.unsqueeze(0))
        # For this single-layer, single-step GRU `output` and `hidden` hold the same values.
        # No per-step equality check here: evaluating a tensor comparison as a Python bool
        # forces a host-device sync on every decoding step.

        logits = self.linear(torch.cat((output.squeeze(0), context.squeeze(0), embedding.squeeze(0)), dim = 1))
        return logits, hidden.squeeze(0)

class seq2seq(nn.Module):
    """Encoder-decoder wrapper that decodes the target one step at a time.

    Args:
        device: device on which the decoder-output buffer is allocated.
        encoder: callable ``(src) -> (enc_outputs, hidden)`` (an ``Encoder`` here).
        decoder: callable ``(word, hidden, enc_outputs) -> (logits, hidden)`` that exposes
            ``vocab_dim`` (a ``Decoder`` here).
    """

    def __init__(self, device, encoder, decoder):
        super().__init__()
        self.device = device
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, sntc_input, sntc_label, thresh = 0.5):
        """Run encoder and step-wise decoder with stochastic teacher forcing.

        Args:
            sntc_input: (batch, src_len) source indices.
            sntc_label: (batch, tgt_len) target indices; column 0 is the first decoder input.
            thresh: per-step probability of feeding the ground-truth token instead of the
                model's own argmax prediction (1.0 = always teacher-force, 0 = never).

        Returns:
            (tgt_len, batch, vocab_dim) logits; row 0 stays zero since nothing is predicted for it.
        """
        # Encoder and decoder operate sequence-first.
        sntc_input = sntc_input.permute(1, 0)
        sntc_label = sntc_label.permute(1, 0)

        enc_outputs, hidden = self.encoder(sntc_input)

        label_len = sntc_label.shape[0]
        batch_size = sntc_input.shape[1]
        vocab_dim = self.decoder.vocab_dim

        # Allocate directly on the target device instead of building on the CPU and copying.
        dec_outputs = torch.zeros(label_len, batch_size, vocab_dim, device = self.device)
        input_word = sntc_label[0, :]

        for i in range(1, label_len):
            output, hidden = self.decoder(input_word, hidden, enc_outputs)
            dec_outputs[i] = output
            # random.random() is in [0, 1): thresh=0 never teacher-forces, thresh=1 always does.
            if random.random() < thresh:
                input_word = sntc_label[i]
            else:
                input_word = output.argmax(1)

        return dec_outputs

In [None]:
# Vocabulary sizes come from the training-split token maps (the model only knows those).
inp_vocab_dim = len(w2i_eng_train)
label_vocab_dim = len(w2i_fre_train)

# Layer sizes and regularisation
embed_dim = 256
enc_hid_dim = 512
dec_hid_dim = 512
drop_prob = 0.5

enc = Encoder(inp_vocab_dim, enc_hid_dim, embed_dim, dec_hid_dim, drop_prob)
attn = Attention(enc_hid_dim, dec_hid_dim)
dec = Decoder(label_vocab_dim, embed_dim, attn, enc_hid_dim, dec_hid_dim, drop_prob)

def init_weights(m):
    """Re-initialise every parameter of ``m``, including those of nested sub-modules.

    Parameters whose name contains 'weight' are drawn from N(0, 0.01^2); all others are
    set to 0.
    """
    for param_name, tensor in m.named_parameters():
        if 'weight' not in param_name:
            nn.init.constant_(tensor.data, 0)
            continue
        nn.init.normal_(tensor.data, mean = 0, std = 0.01)
            
model = seq2seq(device, enc, dec).to(device)
# init_weights iterates m.named_parameters(), which already recurses into every sub-module,
# so one call on the root initialises each parameter exactly once. model.apply(init_weights)
# would re-initialise every parameter once per enclosing module.
init_weights(model)
model

In [None]:
def count_parameters(model):
    """Return the total number of elements across the trainable (requires_grad) parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

print('The model has {:,} trainable parameters'.format(count_parameters(model)))

In [None]:
# True when the model's first parameter lives on a CUDA device.
next(model.parameters()).device.type == 'cuda'

In [None]:
optimizer = optim.Adam(model.parameters())
# collate_fn pads every batch with index 0. The only real token with index 0 is '<sos>', which
# appears solely at position 0 of each target, and that position is sliced off (label[1:])
# before the loss. Ignoring index 0 therefore masks exactly the padding positions.
criterion = nn.CrossEntropyLoss(ignore_index = 0).to(device)
epochs = 2
batch_size = 100
best_loss = float('inf')

def val_test(model, data_loader, criterion, status):
        
        total_loss = 0
        model.eval()
        
        for inp, label in data_loader:

#             for token equivalence btw train and val or test
            if status == 'val' :
                for i in range(len(inp)):
                    inp[i, :] = torch.LongTensor( [w2i_eng_train[i2w_eng_val[w_i.item()]] if i2w_eng_val[w_i.item()] in w2i_eng_train  
                                     else w2i_eng_train['<unk>'] for w_i in list(inp[i, :])])

                for i in range(len(label)):
                    label[i, :] = torch.LongTensor([w2i_fre_train[i2w_fre_val[w_i.item()]] if i2w_fre_val[w_i.item()] in w2i_fre_train  
                                     else w2i_fre_train['<unk>'] for w_i in list(label[i, :])])

            if status == 'test':
                for i in range(len(inp)):
                    inp[i, :] = torch.LongTensor( [w2i_eng_train[i2w_eng_test[w_i.item()]] if i2w_eng_test[w_i.item()] in w2i_eng_train  
                                     else w2i_eng_train['<unk>'] for w_i in list(inp[i, :])])

                for i in range(len(label)):
                    label[i, :] = torch.LongTensor([w2i_fre_train[i2w_fre_test[w_i.item()]] if i2w_fre_test[w_i.item()] in w2i_fre_train  
                                     else w2i_fre_train['<unk>'] for w_i in list(label[i, :])])

            inp = inp.to(device)
            label = label.to(device)

            output = model(inp, label)
            output = output[1:].view(-1, output.shape[-1])

            label = label.permute(1, 0)
            label = label[1:].reshape(-1)

            loss = criterion(output, label)
            total_loss += loss.item()
        
        return total_loss

# Datasets and loaders do not change between epochs, so build them once.
trans_data = Translation(train, w2i_eng_train, w2i_fre_train)
trans_data_val = Translation(val, w2i_eng_val, w2i_fre_val)
trans_data_test = Translation(test, w2i_eng_test, w2i_fre_test)

# Shuffle the training batches each epoch; keep val/test order fixed.
data_loader_train = DataLoader(dataset = trans_data, batch_size = batch_size,
                               shuffle = True, collate_fn = collate_fn)
data_loader_val = DataLoader(dataset = trans_data_val, batch_size = batch_size,
                             shuffle = False, collate_fn = collate_fn)
data_loader_test = DataLoader(dataset = trans_data_test, batch_size = batch_size,
                              shuffle = False, collate_fn = collate_fn)

for epoch in tqdm(range(epochs)):

    train_loss = 0

    model.train()
    for inp, label in tqdm(data_loader_train):
        inp = inp.to(device)
        label = label.to(device)

        optimizer.zero_grad()
        output = model(inp, label)

        # Drop step 0 (never predicted) and flatten to (steps*batch, vocab) / (steps*batch,)
        output = output[1:].view(-1, output.shape[-1])
        label = label.permute(1, 0)[1:].reshape(-1)

        # Let shape/other errors propagate instead of back-propagating a stale loss.
        loss = criterion(output, label)
        loss.backward()

        optimizer.step()
        train_loss += loss.item()

    val_loss = val_test(model, data_loader_val, criterion, 'val')
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(), 'attn_seq2seq.pth')
    # The reported values are mean per-batch losses.
    print('Epoch : ', epoch)
    print('Train loss per-input : ', train_loss/len(data_loader_train))
    print('Val loss per-input : ', val_loss/len(data_loader_val))

In [None]:
# Evaluate the checkpoint with the best validation loss (written by the training loop to
# 'attn_seq2seq.pth'), not whatever weights the final epoch left in memory.
model.load_state_dict(torch.load('attn_seq2seq.pth', map_location = device))
test_loss = val_test(model, data_loader_test, criterion, 'test')
print('Test loss per-input : ', test_loss/len(data_loader_test))

#### Val loss - decreasing every epoch, which suggests the model is generalizing
#### Test loss - considerably low for only a few epochs of training (note: the training cell currently sets `epochs = 2`). (Further analysis below.)

# TODO
###  1) Translation outputs analysis
###  2) Attention effect
###  3) robustness check - well devised adversarial samples probably
###  4) Long-term dependency check - essentially, performance on long sentences
###  5) Sentence structure check - grammar etc. through metrics