In [1]:
from argparse import Namespace
from collections import Counter
import json
import os
import re
import string
import bz2
import nltk

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm_notebook

In [2]:
def pre_processa(text, outer_sep, inner_sep, link_cutoff, default_no_link):
    """Turn link-annotated text into parallel token and label sequences.

    The text is split into sentences and lower-cased word tokens with NLTK.
    A token that starts and ends with ``outer_sep`` is treated as a link
    annotation; when it splits on ``outer_sep`` into exactly two non-empty
    parts, these are read as ``(anchor text, link target)``.

    Args:
        text (str): the raw annotated text.
        outer_sep (str): delimiter wrapping a link token and separating its parts.
        inner_sep (str): separator between the words of the anchor text.
        link_cutoff (int): minimum number of occurrences a link target needs
            in order to be used as a label.
        default_no_link (str): label given to every word without a kept link.

    Returns:
        tuple: ``(sentences, labels)`` - two lists with one entry per sentence,
        each entry being a list of words and the list of their labels.
    """
    # Pass 1: tokenize, lower-case and count how often each link target occurs.
    tokenized_sentences = []
    target_counts = Counter()
    for raw_sentence in nltk.sent_tokenize(text):
        lowered = [token.lower() for token in nltk.word_tokenize(raw_sentence)]
        for token in lowered:
            if is_a_link(token, outer_sep):
                # The link target is the last field before the closing separator.
                target_counts[token.split(outer_sep)[-2]] += 1
        tokenized_sentences.append(lowered)

    # Only targets that reach the frequency threshold are kept as labels.
    frequent_targets = {
        target for target, count in target_counts.items() if count >= link_cutoff
    }

    # Pass 2: expand every token into (word, label) pairs.
    sentences, labels = [], []
    for tokens in tokenized_sentences:
        words, tags = [], []
        for token in tokens:
            if not is_a_link(token, outer_sep):
                tags.append(default_no_link)
                words.append(token)
                continue

            parts = [part for part in token.split(outer_sep) if part]
            if len(parts) != 2:
                # Not an "anchor + target" pair: keep the bare text, unlabelled.
                tags.append(default_no_link)
                words.append(token.replace(outer_sep, ''))
                continue

            anchor, target = parts
            anchor_words = anchor.split(inner_sep)
            if target in frequent_targets:
                tag = target.replace("_", " ")
            else:
                tag = default_no_link
            # Every word of the anchor text receives the same label.
            for anchor_word in anchor_words:
                if anchor_word:
                    tags.append(tag)
                    words.append(anchor_word)
        labels.append(tags)
        sentences.append(words)
    return sentences, labels

def is_a_link(word, outer_sep):
    """Tell whether ``word`` is a link token.

    A link token has at least two characters and both its first and its last
    character are equal to ``outer_sep``.
    """
    if len(word) < 2:
        return False
    first, last = word[0], word[-1]
    return first == outer_sep and last == outer_sep

In [3]:
class WikiDataset(Dataset):
    """Sequence-labelling dataset built from link-annotated text.

    Each example is a sentence (a list of lower-cased tokens) together with a
    parallel list of labels: the link target attached to the token, or
    ``default_no_link`` when the token carries no kept link.
    """

    # Fractions used by read_dataset: the first 1 - test_slice - val_slice of
    # the sentences form the training split, the next test_slice the test
    # split, and whatever remains the validation split.
    test_slice = 0.15
    val_slice = 0.15
    # Separator between the words of a link's anchor text.
    inner_sep = '_'
    # Delimiter wrapping a link token and separating anchor text from target.
    outer_sep = '|'
    # Minimum number of occurrences for a link target to be kept as a label.
    link_cutoff = 1
    # Label given to tokens without a kept link.
    default_no_link = "<NONE>"

    def __init__(self, dataset, vectorizer=None):
        """
        Args:
            dataset (dict): maps 'train', 'val' and 'test' to a dict holding
                the parallel lists 'source_sentences' and 'target_labels'
                (the structure returned by ``read_dataset``).
            vectorizer: optional vectorizer object, stored and handed back by
                ``get_vectorizer``; ``None`` when no vectorizer is available.
        """
        self._dataset = dataset
        # Always defined so that get_vectorizer never raises AttributeError.
        self._vectorizer = vectorizer

        self.train_ds = self._dataset['train']
        self.train_size = len(self.train_ds['source_sentences'])

        self.val_ds = self._dataset['val']
        self.val_size = len(self.val_ds['source_sentences'])

        self.test_ds = self._dataset['test']
        self.test_size = len(self.test_ds['source_sentences'])

        self._lookup_dict = {'train': (self.train_ds, self.train_size),
                             'val': (self.val_ds, self.val_size),
                             'test': (self.test_ds, self.test_size)}

        # Select a split up front; __len__ and __getitem__ rely on it.
        self.set_split('train')

    @classmethod
    def is_a_link(cls, word):
        """Return True when ``word`` has at least two characters and both
        starts and ends with ``cls.outer_sep``."""
        return len(word) >= 2 and word[0] == cls.outer_sep and word[-1] == cls.outer_sep

    @classmethod
    def read_dataset(cls, ds_path):
        """Read a bz2-compressed UTF-8 text file and split it into train/test/val.

        The sentences are split in file order (no shuffling).

        Args:
            ds_path (str): path to the bz2-compressed text file.
        Returns:
            dict: 'train', 'test' and 'val' entries, each a dict with the
            parallel lists 'source_sentences' and 'target_labels'.
        """
        # The context manager closes the file handle once the text is read.
        with bz2.BZ2File(ds_path) as archive:
            text = archive.read().decode('utf-8')

        sentences, labels = cls.pre_process(text)
        train_size = int(len(sentences) * (1 - cls.test_slice - cls.val_slice))
        test_size = int(len(sentences) * cls.test_slice)
        test_end = train_size + test_size
        return {
            'train': {'source_sentences': sentences[:train_size],
                      'target_labels': labels[:train_size]},
            'test': {'source_sentences': sentences[train_size:test_end],
                     'target_labels': labels[train_size:test_end]},
            'val': {'source_sentences': sentences[test_end:],
                    'target_labels': labels[test_end:]}
        }

    @classmethod
    def pre_process(cls, text):
        """Tokenize annotated text into parallel token and label sequences.

        A token that starts and ends with ``cls.outer_sep`` is a link
        annotation; when it splits on ``cls.outer_sep`` into exactly two
        non-empty parts these are read as ``(anchor text, link target)``.

        Args:
            text (str): the raw annotated text.
        Returns:
            tuple: ``(sentences, labels)`` - two lists with one entry per
            sentence, each entry being a list of words and the list of their
            labels.
        """
        # Pass 1: tokenize, lower-case and count every link target.
        source_sentences = []
        link_counts = Counter()
        for source_sentence in nltk.sent_tokenize(text):
            tokens = [token.lower() for token in nltk.word_tokenize(source_sentence)]
            for token in tokens:
                if cls.is_a_link(token):
                    # The target is the last field before the closing separator.
                    link_counts[token.split(cls.outer_sep)[-2]] += 1
            source_sentences.append(tokens)

        # Valid links are the targets that occur at least link_cutoff times.
        valid_links = {link for link, frequency in link_counts.items()
                       if frequency >= cls.link_cutoff}

        # Pass 2: form the input and label sequences.
        sentences = []
        labels = []
        for source_sentence in source_sentences:
            sentence = []
            label = []
            for word in source_sentence:
                if not cls.is_a_link(word):
                    label.append(cls.default_no_link)
                    sentence.append(word)
                    continue

                parts = [part for part in word.split(cls.outer_sep) if part]
                if len(parts) != 2:
                    # Not an "anchor + target" pair: keep the bare text, unlabelled.
                    label.append(cls.default_no_link)
                    sentence.append(word.replace(cls.outer_sep, ''))
                    continue

                anchor, link = parts
                sub_links = [piece for piece in anchor.split(cls.inner_sep) if piece]
                # Underscores in a kept target are turned into spaces.
                link = link.replace("_", " ") if link in valid_links else cls.default_no_link
                # Every word of the anchor text receives the same label.
                for sub_link in sub_links:
                    label.append(link)
                    sentence.append(sub_link)
            labels.append(label)
            sentences.append(sentence)
        return sentences, labels

    @classmethod
    def load_dataset_and_make_vectorizer(cls, ds_path, strip_punctuation=True):
        """Load the dataset from ``ds_path`` and wrap it in a ``WikiDataset``.

        Args:
            ds_path (str): path to the bz2-compressed text file.
            strip_punctuation (bool): currently unused; kept so existing
                callers keep working.
        Returns:
            WikiDataset: a dataset without a vectorizer (none is built yet).
        """
        ds = cls.read_dataset(ds_path)
        return cls(ds)

    def get_vectorizer(self):
        """Return the vectorizer given at construction time, or None."""
        return self._vectorizer

    def set_split(self, split="train"):
        """Select which split ``__len__`` and ``__getitem__`` operate on.

        Args:
            split (str): one of "train", "val", or "test"
        """
        self._target_split = split
        self._target_df, self._target_size = self._lookup_dict[split]

    def __len__(self):
        """Number of sentences in the currently selected split."""
        return self._target_size

    def __getitem__(self, index):
        """the primary entry point method for PyTorch datasets

        Args:
            index (int): the index to the data point
        Returns:
            a dictionary holding the data point's tokens (x_data) and the
            parallel list of labels (y_target), both un-vectorized
        """
        # NOTE(review): the values are plain lists of strings; generate_batches
        # calls .to(device) on every value, so a numeric encoding step is
        # still needed before these items can be batched - TODO.
        split = self._target_df
        return {'x_data': split['source_sentences'][index],
                'y_target': split['target_labels'][index]}

    def get_num_batches(self, batch_size):
        """Given a batch size, return the number of batches in the dataset

        Args:
            batch_size (int)
        Returns:
            number of batches in the dataset
        """
        return len(self) // batch_size

    def generate_batches(dataset, batch_size, shuffle=True,
                        drop_last=True, device="cpu"):
        """
        A generator function which wraps the PyTorch DataLoader. It will
        ensure each tensor is on the right device location.

        Defined inside the class, so ``dataset`` is bound to the instance
        when this is called as ``instance.generate_batches(batch_size)``.
        """
        dataloader = DataLoader(dataset=dataset, batch_size=batch_size,
                                shuffle=shuffle, drop_last=drop_last)

        for data_dict in dataloader:
            yield {name: tensor.to(device) for name, tensor in data_dict.items()}

In [4]:
# Read the compressed corpus and build the train/val/test splits.
dataset = WikiDataset.load_dataset_and_make_vectorizer(
    ds_path="../input_data/wiki.txt.bz2",
)

In [22]:
# Show one training sentence as (token, label) pairs.
y = 671
[
    (token, label)
    for token, label in zip(
        dataset.train_ds["source_sentences"][y],
        dataset.train_ds["target_labels"][y],
    )
]

[('they', '<NONE>'),
 ('are', '<NONE>'),
 ('generally', '<NONE>'),
 ('believed', '<NONE>'),
 ('to', '<NONE>'),
 ('have', '<NONE>'),
 ('been', '<NONE>'),
 ('a', '<NONE>'),
 ('germanic', 'germanic peoples'),
 ('tribe', 'germanic peoples'),
 ('originating', '<NONE>'),
 ('in', '<NONE>'),
 ('jutland', 'jutland'),
 (',', '<NONE>'),
 ('but', '<NONE>'),
 ('celtic', 'celts'),
 ('influences', '<NONE>'),
 ('have', '<NONE>'),
 ('also', '<NONE>'),
 ('been', '<NONE>'),
 ('suggested', '<NONE>'),
 ('.', '<NONE>')]