In [None]:
import random
import string

import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Prepare Data

In [None]:
# Column groups of the preprocessed table: text columns and numeric columns.
str_col = [
    "reference",
    "translation",
]
num_col = [
    "ref_tox",
    "trn_tox",
    "similarity",
    "lenght_diff",
]

# Tab-separated file; the first column is used as the DataFrame index.
data = pd.read_csv(
    "data/interim/preprocessed_filtered.tsv",
    sep="\t",
    index_col=0,
)

In [None]:
# Print a summary of the loaded DataFrame (columns, dtypes, non-null counts).
data.info()

# Prepare Dataloader

In [None]:
from src.data.utils import prepareData, get_dataloader, tensorFromSentence

# Architecture of seq2seq model

In [None]:
from src.models.seq2seq import Seq2Seq, Encoder, Decoder

# Train model

In [None]:
from src.models.train_model import train

In [None]:
# Hyperparameters for the training run.
epochs = 10
embed_size = 128
hidden_size = 128
batch_size = 32
lr = 1e-3
# Length limit handed to prepareData, get_dataloader and the encoder/decoder.
MAX_LENGTH = 11

# Seed the Python and PyTorch random number generators so that repeated
# "Restart & Run All" executions draw the same random numbers.
# NOTE(review): if the project helpers use another RNG (e.g. numpy), seed it too.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Build the source/target vocabularies and the sentence pairs, then wrap them
# into training and validation dataloaders.
vocab_tox, vocab_detox, pairs = prepareData(data, MAX_LENGTH)
train_dataloader, val_dataloader = get_dataloader(
    batch_size,
    vocab_tox,
    vocab_detox,
    pairs,
    MAX_LENGTH,
    device=device,
)

In [None]:
# Encoder reads the toxic-side vocabulary, decoder emits the detoxified-side vocabulary.
encoder = Encoder(vocab_tox.n_words, embed_size, hidden_size, vocab_tox, device=device, max_length=MAX_LENGTH)
decoder = Decoder(embed_size, hidden_size, vocab_detox.n_words, vocab_detox, device=device, max_length=MAX_LENGTH)
seq2seq_model = Seq2Seq(encoder, decoder).to(device)

optimizer = optim.Adam(seq2seq_model.parameters(), lr=lr)
# The decoder is sized by vocab_detox.n_words, so the loss targets are ids of
# the detox vocabulary; the padding id to ignore must therefore be looked up
# in vocab_detox rather than in the source vocabulary.
criterion = nn.CrossEntropyLoss(ignore_index=vocab_detox.word2index['<pad>'])

# train() returns the training and validation loss histories that are plotted below.
loss_train, loss_val = train(
    seq2seq_model,
    train_dataloader,
    val_dataloader,
    optimizer=optimizer,
    criterion=criterion,
    epochs=epochs,
    lr=lr,
    model_path='seq2seq.pt',
)

# Plot loss

In [None]:
# Draw both loss histories against the epoch numbers 1..epochs on one set of axes.
epoch_axis = range(1, epochs + 1)

fig, ax = plt.subplots()
for series, label in ((loss_train, 'Training loss'), (loss_val, 'Validation loss')):
    ax.plot(epoch_axis, series, label=label)

ax.set_title('Training and Validation Loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend(loc='best')

plt.show()

# Check on random samples from the dataset

In [None]:
def evaluate(model, sentence, vocab_tox, vocab_detox):
    """Run ``model`` on one sentence and return the predicted words.

    The sentence is converted to a tensor with ``tensorFromSentence`` using the
    source vocabulary, the model is called on it, and the highest-scoring index
    along the last dimension of the output is taken for every position. The
    indices are mapped to words through ``vocab_detox.index2word`` until the
    ``<eos>`` token is reached.

    Args:
        model: Callable model; it is switched to evaluation mode.
        sentence: Source sentence to translate.
        vocab_tox: Source vocabulary, used to encode ``sentence``.
        vocab_detox: Target vocabulary, used to decode the predicted ids.

    Returns:
        List of predicted words, without the ``<eos>`` token.
    """
    model.eval()
    # Predicted ids belong to the target vocabulary, so the stop token id has
    # to be looked up there as well (not in the source vocabulary).
    eos_id = vocab_detox.word2index['<eos>']

    with torch.no_grad():
        input_tensor = tensorFromSentence(vocab_tox, sentence, device=device)
        outputs = model(input_tensor)
        _, topi = outputs.topk(1)
        # Flatten instead of squeeze(): a one-token output would otherwise
        # collapse to a 0-d tensor, which cannot be iterated.
        # Assumes a single sentence per call — TODO confirm for batched input.
        ids = topi.reshape(-1).tolist()

    words = []
    for idx in ids:
        if idx == eos_id:
            break
        words.append(vocab_detox.index2word[idx])
    return words

In [None]:
def _join_tokens(tokens):
    """Glue predicted tokens back into one sentence string.

    Tokens starting with ``'`` or ``n'`` (contraction pieces such as ``'s`` or
    ``n't``) and tokens contained in ``string.punctuation`` are attached to the
    preceding token; every other token is preceded by a single space.
    """
    pieces = []
    for token in tokens:
        attaches = token.startswith("'") or token.startswith("n'") or token in string.punctuation
        pieces.append(token if attaches else " " + token)
    return "".join(pieces).strip()


def evaluateRandomly(model, vocab_tox, vocab_detox, n=10):
    """Print the model prediction for ``n`` randomly chosen sentence pairs.

    Pairs are drawn (with replacement) from the notebook-level ``pairs`` list.
    For each one the source sentence, the reference translation and the model
    output are printed, followed by an empty line.

    Requires the ``random`` and ``string`` standard-library modules to be
    imported at the top of the notebook.

    Args:
        model: Model passed on to ``evaluate``.
        vocab_tox: Source vocabulary passed on to ``evaluate``.
        vocab_detox: Target vocabulary passed on to ``evaluate``.
        n: Number of samples to print.
    """
    for _ in range(n):
        pair = random.choice(pairs)
        print('origin:     ', pair[0])
        print('translated: ', pair[1])
        output_words = evaluate(model, pair[0], vocab_tox, vocab_detox)
        output_sentence = _join_tokens(output_words)
        print('predicted:  ', output_sentence)
        print('')

In [None]:
# Reload the checkpoint written by the training cell and map its tensors onto
# the currently selected device, so a checkpoint saved on a GPU can also be
# opened on a CPU-only machine.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# trust. Recent PyTorch releases default to weights_only=True, which rejects
# fully pickled modules — confirm against the installed version and against
# what train() actually saves.
load_seq2seq_model = torch.load("seq2seq.pt", map_location=device)

load_seq2seq_model.eval()
evaluateRandomly(load_seq2seq_model, vocab_tox, vocab_detox)