# Import necessary libraries

In [None]:
from __future__ import unicode_literals, print_function, division

import numpy as np
import pandas as pd
import nltk
from tqdm.autonotebook import tqdm

from io import open
import unicodedata
import re
import string
import random

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader, RandomSampler

# Run tensors on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Download the NLTK 'punkt' tokenizer data; this file tokenizes text with
# nltk.word_tokenize.
nltk.download('punkt')

# Read Data

In [None]:
def read_corpus(filename):
    """Read a tab-separated text file into a list of rows.

    Args:
        filename: Path of a UTF-8 encoded file holding one record per line,
            with fields separated by tab characters.

    Returns:
        A list with one entry per line of the file. Each entry is the list
        of strings obtained by splitting that line on tabs, with the line
        terminator removed from the last field.
    """
    data = []
    # The context manager closes the handle even if reading fails part-way.
    with open(filename, encoding='utf-8') as corpus:
        for line in corpus:
            # Strip only a real trailing newline: the final line of a file
            # may not have one, and cutting a character blindly would
            # truncate its last field.
            if line.endswith('\n'):
                line = line[:-1]
            data.append(line.split('\t'))
    return data

# Prepare Data

In [None]:
# Column groups of the table: free-text columns and numeric score columns.
str_col = ['reference', 'translation']
num_col = ['ref_tox', 'trn_tox', 'similarity', 'lenght_diff']

data = read_corpus("data/raw/filtered.tsv")
# The first row is used as the column names below; label its first cell 'id'.
data[0][0] = 'id'

# For every record after the header row: when the second-to-last field is
# numerically greater than the last one, swap those two fields and also swap
# fields 1 and 2, so the last field always ends up the larger of the pair.
# NOTE(review): assuming the last two fields are ref_tox / trn_tox (TODO
# confirm against the TSV header), this leaves 'reference' as the
# lower-toxicity side, while prepareData builds 'tox-vocab' from
# 'reference' -- confirm the comparison direction is intended.
for row in data[1:]:
    if float(row[-2]) > float(row[-1]):
        row[-2], row[-1] = row[-1], row[-2]
        row[1], row[2] = row[2], row[1]

Data = pd.DataFrame(data[1:], columns=data[0])
# Every field was read as text; convert the score columns to numbers.
for num in num_col:
    Data[num] = pd.to_numeric(Data[num])

# Index the frame by the numeric value of the 'id' column.
Data.index = pd.to_numeric(Data['id']).values
Data.info()

# Prepare Dataloader

In [None]:
# Ids reserved for the special tokens of every vocabulary.
SOS_token = 0
EOS_token = 1
PAD_token = 2


class Vocabulary:
    """Bidirectional word/index mapping with per-word occurrence counts.

    Attributes are plain dictionaries:
      * ``word2index`` -- word -> integer id (special tokens are not listed)
      * ``word2count`` -- word -> number of times it was added
      * ``index2word`` -- integer id -> word, pre-filled with the special
        tokens ``<sos>``, ``<eos>`` and ``<pad>``
      * ``n_words``    -- next free id (equals the size of ``index2word``)
    """

    def __init__(self, name):
        """Create an empty vocabulary labelled *name*."""
        self.name = name
        self.word2index = {}
        self.word2count = {}
        # Special tokens occupy the first ids; real words are numbered after.
        self.index2word = {
            SOS_token: "<sos>",
            EOS_token: "<eos>",
            PAD_token: "<pad>",
        }
        self.n_words = len(self.index2word)

    def addSentence(self, sentence):
        """Tokenize *sentence* with nltk.word_tokenize and add every token."""
        for token in nltk.word_tokenize(sentence):
            self.addWord(token)

    def addWord(self, word):
        """Register one occurrence of *word*, assigning a new id if unseen."""
        if word in self.word2index:
            # Known word: only its counter changes.
            self.word2count[word] += 1
            return
        new_id = self.n_words
        self.word2index[word] = new_id
        self.word2count[word] = 1
        self.index2word[new_id] = word
        self.n_words = new_id + 1

In [None]:
MAX_LENGTH = 10 # Token limit per sentence: filterPair keeps a pair only if both sides have fewer than MAX_LENGTH tokens

def normalizeString(s):
    """Tokenize the string *s* with ``nltk.word_tokenize``.

    No lowercasing, trimming or character filtering is applied here; the
    result is whatever the tokenizer returns for *s*.
    """
    tokens = nltk.word_tokenize(s)
    return tokens


def filterPair(p):
    """Tell whether both sentences of the pair *p* are short enough.

    *p* is indexed as ``p[0]`` and ``p[1]``; each is tokenized with
    nltk.word_tokenize. Returns True only when both token counts are
    strictly below MAX_LENGTH.
    """
    # Reject early so the second sentence is tokenized only when needed.
    if len(nltk.word_tokenize(p[0])) >= MAX_LENGTH:
        return False
    return len(nltk.word_tokenize(p[1])) < MAX_LENGTH


def filter(norm_ref, norm_trs):
    """Keep only the aligned sentence pairs accepted by filterPair.

    The two inputs are walked in lockstep with ``zip``. Returns two lists
    of equal length: the surviving first-side sentences and the surviving
    second-side sentences, in their original order.

    NOTE: inside this module the name shadows the builtin ``filter``.
    """
    kept = [pair for pair in zip(norm_ref, norm_trs) if filterPair(pair)]
    filter_ref = [ref for ref, _ in kept]
    filter_trs = [trs for _, trs in kept]
    return filter_ref, filter_trs


def prepareData(data):
    """Build both vocabularies and the list of usable sentence pairs.

    Args:
        data: Object indexable by the column names 'reference' and
            'translation' (presumably a pandas DataFrame -- confirm against
            the caller); iterating a column must yield sentence strings.

    Returns:
        A tuple ``(vocab_tox, vocab_detox, pairs)`` where ``vocab_tox`` is
        built from the kept 'reference' sentences, ``vocab_detox`` from the
        kept 'translation' sentences, and ``pairs`` is the list of
        ``(reference, translation)`` tuples that passed the length filter.
    """
    references = list(data['reference'])
    translations = list(data['translation'])

    # Drop pairs rejected by filterPair. `filter` here is this module's
    # pair filter, not the builtin. The lists built above must be passed;
    # the previous code referenced undefined names and raised NameError.
    filt_ref, filt_trs = filter(references, translations)

    # Make Vocabulary instances
    vocab_tox = Vocabulary('tox-vocab')
    vocab_detox = Vocabulary('detox-vocab')
    pairs = list(zip(filt_ref, filt_trs))

    for sentence in filt_ref:
        vocab_tox.addSentence(sentence)

    for sentence in filt_trs:
        vocab_detox.addSentence(sentence)

    print("Counted words:")
    print(vocab_tox.name, vocab_tox.n_words)
    print(vocab_detox.name, vocab_detox.n_words)

    return vocab_tox, vocab_detox, pairs

In [None]:
def indexesFromSentence(vocab, sentence):
    """Map *sentence* to the list of word ids stored in ``vocab.word2index``.

    The sentence is tokenized with nltk.word_tokenize. A token missing from
    ``vocab.word2index`` raises KeyError.
    """
    tokens = nltk.word_tokenize(sentence)
    lookup = vocab.word2index
    return [lookup[token] for token in tokens]

def tensorFromSentence(vocab, sentence):
    """Encode *sentence* as a ``(1, length)`` long tensor ending in EOS_token.

    The word ids come from indexesFromSentence; the tensor is created on the
    module-level ``device``.
    """
    token_ids = indexesFromSentence(vocab, sentence) + [EOS_token]
    flat = torch.tensor(token_ids, dtype=torch.long, device=device)
    # Reshape to a single row holding all ids.
    return flat.view(1, -1)

def tensorsFromPair(pair):
    """Encode ``pair[0]`` and ``pair[1]`` as an (input, target) tensor tuple.

    Reads the module-level names ``vocab_tox`` and ``vocab_detox``, which
    must be bound before this function is called.
    """
    return (
        tensorFromSentence(vocab_tox, pair[0]),
        tensorFromSentence(vocab_detox, pair[1]),
    )

def get_dataloader(batch_size, vocab_tox, vocab_detox, pairs, p=0.9):
    """Encode *pairs* and wrap them in train / validation DataLoaders.

    Args:
        batch_size: Batch size used by both loaders.
        vocab_tox: Vocabulary used to encode the first sentence of a pair.
        vocab_detox: Vocabulary used to encode the second sentence of a pair.
        pairs: Sequence of ``(input_sentence, target_sentence)`` items.
        p: Value passed to ``train_test_split`` as ``train_size``.

    Returns:
        ``(train_dataloader, val_dataloader)``. Each batch yields two long
        tensors with ``MAX_LENGTH + 1`` columns: word ids, then EOS_token,
        then PAD_token up to the row width. Only the training loader
        shuffles.
    """
    n = len(pairs)
    width = MAX_LENGTH + 1
    # Start from all-PAD matrices so each row only needs its real ids written.
    input_ids = np.full((n, width), PAD_token, dtype=np.int32)
    target_ids = np.full((n, width), PAD_token, dtype=np.int32)

    for row, (inp, tgt) in enumerate(pairs):
        inp_ids = indexesFromSentence(vocab_tox, inp) + [EOS_token]
        tgt_ids = indexesFromSentence(vocab_detox, tgt) + [EOS_token]
        input_ids[row, :len(inp_ids)] = inp_ids
        target_ids[row, :len(tgt_ids)] = tgt_ids

    # Fixed random_state keeps the train/validation split reproducible.
    train_idx, val_idx = train_test_split(
        list(range(n)), train_size=p, random_state=420
    )

    def as_dataset(rows):
        # Select the given rows from both matrices and move them to `device`.
        return TensorDataset(torch.LongTensor(input_ids[rows]).to(device),
                             torch.LongTensor(target_ids[rows]).to(device))

    train_dataloader = DataLoader(as_dataset(train_idx),
                                  batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(as_dataset(val_idx),
                                batch_size=batch_size, shuffle=False)
    return train_dataloader, val_dataloader