In [1]:
from argparse import Namespace
from collections import Counter
import json
import os
import re
import string
import bz2
import nltk

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm_notebook

In [2]:
class Vocabulary(object):
    """Class to process text and extract vocabulary for mapping.

    Maintains a bidirectional token <-> integer index mapping. New tokens get
    the next free index (the current vocabulary size).
    """

    def __init__(self, token_to_idx=None):
        """
        Args:
            token_to_idx (dict, optional): pre-existing token -> index mapping.
                It is copied, so tokens added later never mutate the caller's
                dict (e.g. the dict inside serialized contents).
        """
        self._token_to_idx = dict(token_to_idx) if token_to_idx is not None else {}
        self._idx_to_token = {idx: token for token, idx in self._token_to_idx.items()}

    def to_serializable(self):
        """ Returns a dictionary that can be serialized (holds a copy of the mapping) """
        return {'token_to_idx': dict(self._token_to_idx)}

    @classmethod
    def from_serializable(cls, contents):
        """ Instantiates the Vocabulary from a serialized dictionary """
        return cls(**contents)

    def add_token(self, token):
        """ Update mapping dicts based on the token

        Args:
            token (str): the token to add
        Returns:
            index (int): the existing index if the token is known, else the new one
        """
        if token in self._token_to_idx:
            index = self._token_to_idx[token]
        else:
            index = len(self._token_to_idx)
            self._token_to_idx[token] = index
            self._idx_to_token[index] = token
        return index

    def add_many(self, tokens):
        """ Add a list of tokens into the Vocabulary; returns their indices """
        return [self.add_token(token) for token in tokens]

    def lookup_token(self, token):
        """ Retrieve the index associated with the token (KeyError if absent) """
        return self._token_to_idx[token]

    def lookup_index(self, index):
        """ Return the token associated with the index

        Raises:
            KeyError: if the index is not in the Vocabulary
        """
        if index not in self._idx_to_token:
            raise KeyError("the index (%d) is not in the Vocabulary" % index)
        return self._idx_to_token[index]

    def __str__(self):
        return "<Vocabulary(size=%d)>" % len(self)

    def __len__(self):
        return len(self._token_to_idx)

In [3]:
class SequenceVocabulary(Vocabulary):
    """Vocabulary that always contains the <MASK>, <UNK>, <BEGIN> and <END>
    special tokens and maps unknown tokens to <UNK> on lookup."""

    def __init__(self, token_to_idx=None, unk_token="<UNK>", mask_token="<MASK>", begin_seq_token="<BEGIN>", end_seq_token="<END>"):
        super(SequenceVocabulary, self).__init__(token_to_idx)

        # Keep the special-token strings so to_serializable can round-trip them.
        self._unk_token = unk_token
        self._mask_token = mask_token
        self._begin_seq_token = begin_seq_token
        self._end_seq_token = end_seq_token

        # Insertion order matters: in an empty vocabulary this yields
        # mask=0, unk=1, begin=2, end=3. Existing tokens keep their index.
        self.mask_index = self.add_token(mask_token)
        self.unk_index = self.add_token(unk_token)
        self.begin_seq_index = self.add_token(begin_seq_token)
        self.end_seq_index = self.add_token(end_seq_token)

    def to_serializable(self):
        """Serialize the token mapping together with the special-token strings."""
        serialized = super(SequenceVocabulary, self).to_serializable()
        serialized['unk_token'] = self._unk_token
        serialized['mask_token'] = self._mask_token
        serialized['begin_seq_token'] = self._begin_seq_token
        serialized['end_seq_token'] = self._end_seq_token
        return serialized

    def lookup_token(self, token):
        """Return the index of `token`, or `unk_index` when it is unknown.

        Args:
            token (str): the token to look up
        Returns:
            index (int): the index corresponding to the token
        Notes:
            The UNK fallback is only used when `unk_index` >= 0; otherwise a
            missing token raises KeyError.
        """
        if self.unk_index < 0:
            return self._token_to_idx[token]
        return self._token_to_idx.get(token, self.unk_index)

In [4]:
class WikiDataset(Dataset):
    """PyTorch dataset of tokenized sentences paired with per-word link labels.

    Raw text marks links as ``|anchor_text|link_target|`` tokens. ``pre_process``
    turns each sentence into a list of words and a parallel list of labels: the
    link target for words inside an anchor, ``default_no_link`` otherwise.
    """

    # Fractions of sentences reserved for the test and validation splits.
    test_slice = 0.15
    val_slice = 0.15
    # Separator between words inside an anchor text.
    inner_sep = '_'
    # Delimiter wrapping link tokens: |anchor|target|.
    outer_sep = '|'
    # Minimum number of occurrences for a link target to be kept as a label.
    link_cutoff = 1
    # Label for words that are not part of any (valid) link.
    default_no_link = "<NONE>"

    def __init__(self, dataset, vectorizer, strip_punctuation):
        """
        Args:
            dataset (dict): split name ('train'/'val'/'test') -> dict with
                'source_sentences' and 'target_labels' lists (see read_dataset)
            vectorizer (Vectorizer): turns word/label lists into index vectors
            strip_punctuation (bool): stored on the instance; not used by this class
        """
        self._dataset = dataset
        self._vectorizer = vectorizer
        self._strip_punctuation = strip_punctuation

        self.train_ds = self._dataset['train']
        self.train_size = len(self.train_ds['source_sentences'])

        self.val_ds = self._dataset['val']
        self.val_size = len(self.val_ds['source_sentences'])

        self.test_ds = self._dataset['test']
        self.test_size = len(self.test_ds['source_sentences'])

        self._lookup_dict = {'train': (self.train_ds, self.train_size),
                             'val': (self.val_ds, self.val_size),
                             'test': (self.test_ds, self.test_size)}

        self.set_split('train')

    @classmethod
    def is_a_link(cls, word):
        """True if `word` starts and ends with `outer_sep` (and has length >= 2)."""
        return len(word) >= 2 and word[0] == cls.outer_sep and word[-1] == cls.outer_sep

    @classmethod
    def pre_process(cls, text):
        """Split raw text into sentences of words and parallel per-word link labels.

        Args:
            text (str): raw corpus text with links written as |anchor|target| tokens
        Returns:
            (sentences, labels): lists of lists where labels[i][j] is the label of
            sentences[i][j]; either the link target (underscores replaced by
            spaces) or ``default_no_link``.
        """
        # Tokenize into lowercased sentences and collect every link target seen.
        source_sentences = []
        link_targets = []
        for raw_sentence in nltk.sent_tokenize(text):
            tokens = [token.lower() for token in nltk.word_tokenize(raw_sentence)]
            for token in tokens:
                if cls.is_a_link(token):
                    link_targets.append(token.split(cls.outer_sep)[-2])
            source_sentences.append(tokens)

        # Keep only link targets that occur at least `link_cutoff` times.
        valid_links = {link for link, frequency in Counter(link_targets).items()
                       if frequency >= cls.link_cutoff}

        # Form input and label sequences.
        sentences = []
        labels = []
        for source_sentence in source_sentences:
            sentence = []
            label = []
            for word in source_sentence:
                if cls.is_a_link(word):
                    parts = [part for part in word.split(cls.outer_sep) if part]
                    if len(parts) == 2:
                        anchor_text, link = parts
                        link = link.replace("_", " ") if link in valid_links else cls.default_no_link
                        # Every word of the anchor text receives the same link label.
                        for sub_word in anchor_text.split(cls.inner_sep):
                            if sub_word:
                                label.append(link)
                                sentence.append(sub_word)
                    else:
                        # Malformed link token: keep it as a plain, unlabeled word.
                        label.append(cls.default_no_link)
                        sentence.append(word.replace(cls.outer_sep, ''))
                else:
                    label.append(cls.default_no_link)
                    sentence.append(word)
            labels.append(label)
            sentences.append(sentence)
        return sentences, labels

    @classmethod
    def read_dataset(cls, ds_path):
        """Read a bz2-compressed UTF-8 corpus and split it into train/test/val.

        Splits are contiguous (no shuffling): the first
        (1 - test_slice - val_slice) of sentences are train, the next
        test_slice are test, and the remainder is val.
        """
        with bz2.BZ2File(ds_path) as compressed:
            text = compressed.read().decode('utf-8')
        sentences, labels = cls.pre_process(text)
        train_size = int(len(sentences) * (1 - cls.test_slice - cls.val_slice))
        test_size = int(len(sentences) * cls.test_slice)
        return {
            'train': {'source_sentences': sentences[:train_size], 'target_labels' : labels[:train_size]},
            'test': {'source_sentences': sentences[train_size:train_size+test_size], 'target_labels' : labels[train_size:train_size+test_size]},
            'val': {'source_sentences' : sentences[train_size+test_size:], 'target_labels' : labels[train_size+test_size:]}
        }

    @classmethod
    def load_dataset_and_make_vectorizer(cls, ds_path, strip_punctuation=True):
        """ Load dataset and make a new vectorizer from scratch.

        Requires the Vectorizer class (defined in a later cell) to exist at call time.
        """
        ds = cls.read_dataset(ds_path)
        return cls(ds, Vectorizer.from_dataframe(ds), strip_punctuation)

    def get_vectorizer(self):
        """ returns the vectorizer """
        return self._vectorizer

    def set_split(self, split="train"):
        """ Selects the splits in the dataset
        Args:
            split (str): one of "train", "val", or "test"
        """
        self._target_split = split
        self._target_ds, self._target_size = self._lookup_dict[split]

    def __len__(self):
        return self._target_size

    def __getitem__(self, index):
        """the primary entry point method for PyTorch datasets

        Args:
            index (int): the index to the data point
        Returns:
            a dictionary with 'x_source' (padded word indices), 'y_target'
            (padded label indices) and 'x_source_length' (unpadded length)
        """
        sentence = self._target_ds['source_sentences'][index]
        links = self._target_ds['target_labels'][index]
        # Pad every item to a common length (longest sentence + <BEGIN> + <END>)
        # so DataLoader's default collate can stack items into batch tensors.
        # max() keeps single-item access working if a sentence is longer than
        # the vectorizer's recorded maximum.
        vector_length = max(self._vectorizer.max_source_length, len(sentence)) + 2
        vector_dict = self._vectorizer.vectorize(sentence, links, vector_length=vector_length)

        return {'x_source': vector_dict['source_vector'], 'y_target': vector_dict['target_vector'], 'x_source_length' : vector_dict['source_length']}

    def get_num_batches(self, batch_size):
        """Given a batch size, return the number of full batches in the current split"""
        return len(self) // batch_size

    def generate_batches(dataset, batch_size, shuffle=True, drop_last=True, device="cpu"):
        """
        A generator function which wraps the PyTorch DataLoader. It ensures each
        tensor is on the right device, and reorders the rows of every batch by
        'x_source_length' in descending order (the order expected by
        torch.nn.utils.rnn.pack_padded_sequence with enforce_sorted=True).

        The first parameter is named `dataset` instead of `self`; this works both
        as ``ds.generate_batches(batch_size)`` and
        ``WikiDataset.generate_batches(ds, batch_size)``.
        """
        dataloader = DataLoader(dataset=dataset, batch_size=batch_size,
                                shuffle=shuffle, drop_last=drop_last)

        for data_dict in dataloader:
            lengths = data_dict['x_source_length'].numpy()
            sorted_length_indices = lengths.argsort()[::-1].tolist()

            out_data_dict = dict()
            for name, tensor in data_dict.items():
                out_data_dict[name] = tensor[sorted_length_indices].to(device)
            yield out_data_dict

In [26]:
class Vectorizer(object):
    """ The Vectorizer which coordinates the Vocabularies and puts them to use"""        
    def __init__(self, source_vocab, target_vocab, max_source_length):
        """
        Args:
            source_vocab (SequenceVocabulary): maps source words to integers
            target_vocab (SequenceVocabulary): maps target words to integers
            max_source_length (int): the longest sequence in the source dataset
        """
        self.source_vocab = source_vocab
        self.target_vocab = target_vocab
        self.max_source_length = max_source_length
        

    def _vectorize(self, indices, vector_length=-1, mask_index=0):
        """Vectorize the provided indices
        
        Args:
            indices (list): a list of integers that represent a sequence
            vector_length (int): an argument for forcing the length of index vector
            mask_index (int): the mask_index to use; almost always 0
        """
        if vector_length < 0:
            vector_length = len(indices)
        
        vector = np.zeros(vector_length, dtype=np.int64)
        vector[:len(indices)] = indices
        vector[len(indices):] = mask_index
        return vector

    def _get_indices(self, tokens):
        indices = [self.source_vocab.begin_seq_index]
        indices.extend(self.source_vocab.lookup_token(token) for token in tokens)
        indices.append(self.source_vocab.end_seq_index)
        return indices

        
    def vectorize(self, source_words, target_links, vector_length=-1):
        """Return the vectorized source and target vectors
        Args:
            source_words (list): text tokens from the source vocabulary
            target_links (list): link tokens from the target vocabulary
            vector_length (int): an argument for forcing the length of index vector
        Returns:
            A tuple: (source_vector, target_vector)
        """

        source_indices = self._get_indices(source_words)
        target_indices = self._get_indices(target_links)

        source_vector = self._vectorize(source_indices, vector_length=vector_length, mask_index=self.source_vocab.mask_index)
        target_vector = self._vectorize(target_indices, vector_length=vector_length, mask_index=self.target_vocab.mask_index)

        return {'source_vector': source_vector, 
                'target_vector': target_vector, 
                'source_length': len(source_indices)}
        
    @classmethod
    def from_dataframe(cls, ds):
        """Instantiate the vectorizer from the dataset dataframe
        
        Args:
            bitext_df (pandas.DataFrame): the parallel text dataset
        Returns:
            an instance of the NMTVectorizer
        """
        source_vocab = SequenceVocabulary()
        target_vocab = SequenceVocabulary()
        max_source_length = 0

        for _, split in ds.items():
            for source_sequence in split['source_sentences']:
                max_source_length = max(max_source_length, len(source_sequence))
                for token in source_sequence:
                    source_vocab.add_token(token)

            for target_sequence in split['target_labels']:
                for token in target_sequence:
                    target_vocab.add_token(token)
            
        return cls(source_vocab, target_vocab, max_source_length)

    @classmethod
    def from_serializable(cls, contents):
        source_vocab = SequenceVocabulary.from_serializable(contents["source_vocab"])
        target_vocab = SequenceVocabulary.from_serializable(contents["target_vocab"])
        return cls(source_vocab=source_vocab, 
                   target_vocab=target_vocab, 
                   max_source_length=contents["max_source_length"])

    def to_serializable(self):
        return {"source_vocab": self.source_vocab.to_serializable(), 
                "target_vocab": self.target_vocab.to_serializable(), 
                "max_source_length": self.max_source_length}

In [28]:
# bz2-compressed corpus, path relative to this notebook.
WIKI_CORPUS_PATH = "../input_data/wiki.txt.bz2"
dataset = WikiDataset.load_dataset_and_make_vectorizer(WIKI_CORPUS_PATH)

In [29]:
# Sanity check on a single made-up example: unknown words map to the <UNK> index.
sample_vectors = dataset.get_vectorizer().vectorize(["yasa"], ["yyy"])
sample_vectors

{'source_vector': array([2, 1, 3]),
 'target_vector': array([     2, 111494,      3]),
 'source_length': 3}