In [None]:
from argparse import Namespace
from collections import Counter
import json
import os
import re
import string
import bz2
import nltk

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [None]:
class WikiDataset(Dataset):
    """Token-level link-labelling dataset built from text with inline link markup.

    Links are written in the raw text as ``|anchor_text|target_link|``: ``outer_sep``
    delimits the fields and ``inner_sep`` joins words inside a field. The code expects
    each link to survive nltk word tokenization as a single token. Every sentence
    becomes a token list (``source_sentences``) and a parallel label list
    (``target_labels``).

    The instance behaves as a ``torch.utils.data.Dataset`` over whichever split was
    last selected with :meth:`set_split` ('train' right after construction).
    """

    # Fractions of sentences (taken in document order) assigned to test and val;
    # the remainder becomes the training split.
    test_slice = 0.15
    val_slice = 0.15
    inner_sep = '_'         # joins words inside an anchor text or link target
    outer_sep = '|'         # delimits the fields of a link token
    link_cutoff = 3         # minimum occurrences for a link target to be kept as a label
    _text_token = "<TEXT>"  # label for tokens that are not (kept) link anchors

    def __init__(self, dataset, max_seq_len, strip_punctuation):
        """
        Args:
            dataset: dict with 'train', 'val' and 'test' entries, each a dict holding
                parallel lists under 'source_sentences' and 'target_labels'.
            max_seq_len: length of the longest sentence (as computed by pre_process).
            strip_punctuation: stored on the instance but not used by this class.
        """
        self._dataset = dataset
        self._max_seq_len = max_seq_len
        self._strip_punctuation = strip_punctuation

        self.train_ds = self._dataset['train']
        self.train_size = len(self.train_ds['source_sentences'])

        self.val_ds = self._dataset['val']
        self.val_size = len(self.val_ds['source_sentences'])

        self.test_ds = self._dataset['test']
        self.test_size = len(self.test_ds['source_sentences'])

        self._lookup_dict = {'train': (self.train_ds, self.train_size),
                             'val': (self.val_ds, self.val_size),
                             'test': (self.test_ds, self.test_size)}

        self.set_split('train')

    @classmethod
    def is_a_link(cls, word):
        """True if ``word`` is at least two characters long and both starts and ends with ``outer_sep``."""
        return len(word) >= 2 and word[0] == cls.outer_sep and word[-1] == cls.outer_sep

    @classmethod
    def pre_process(cls, text):
        """Tokenize raw text into parallel (tokens, labels) sentences.

        Every token is lowercased. A link token ``|anchor|target|`` expands into one
        token per non-empty ``inner_sep``-separated word of ``anchor``. Each of those
        tokens is labelled ``target`` (with ``inner_sep`` replaced by spaces) if that
        target appears in at least ``link_cutoff`` link tokens of ``text``, and
        ``_text_token`` otherwise.

        A link-shaped token that does not have exactly two non-empty fields is kept as
        one plain token: ``outer_sep`` is removed and ``inner_sep`` becomes a space. All
        other tokens are labelled ``_text_token``.

        Returns:
            (sentences, labels, max_seq_len): token lists, parallel label lists, and
            the length of the longest sentence.
        """
        # Pass 1: tokenize and count how often each link target occurs.
        tokenized = []
        target_counts = Counter()
        for raw_sentence in nltk.sent_tokenize(text):
            tokens = [token.lower() for token in nltk.word_tokenize(raw_sentence)]
            target_counts.update(
                token.split(cls.outer_sep)[-2] for token in tokens if cls.is_a_link(token)
            )
            tokenized.append(tokens)

        # Only targets seen at least `link_cutoff` times become labels.
        valid_links = {link for link, count in target_counts.items() if count >= cls.link_cutoff}

        # Pass 2: build the token and label sequences.
        sentences, labels = [], []
        max_seq_len = 0
        for tokens in tokenized:
            sentence, label = [], []
            for word in tokens:
                if not cls.is_a_link(word):
                    sentence.append(word)
                    label.append(cls._text_token)
                    continue
                fields = [field for field in word.split(cls.outer_sep) if field]
                if len(fields) == 2:
                    anchor, target = fields
                    if target in valid_links:
                        target = target.replace(cls.inner_sep, ' ')
                    else:
                        target = cls._text_token
                    for piece in anchor.split(cls.inner_sep):
                        if piece:
                            sentence.append(piece)
                            label.append(target)
                else:
                    sentence.append(word.replace(cls.outer_sep, '').replace(cls.inner_sep, ' '))
                    label.append(cls._text_token)
            sentences.append(sentence)
            labels.append(label)
            max_seq_len = max(max_seq_len, len(sentence))

        return sentences, labels, max_seq_len

    @classmethod
    def read_dataset(cls, ds_path):
        """Read a bz2-compressed utf-8 text file and split it into train/test/val.

        Sentences are assigned in document order: the first
        ``1 - test_slice - val_slice`` fraction goes to train, the next ``test_slice``
        fraction to test, and the rest to val.

        Returns:
            (splits, max_seq_len) where ``splits`` maps 'train'/'test'/'val' to dicts
            with 'source_sentences' and 'target_labels'.
        """
        # Context manager so the compressed file handle is closed after reading.
        with bz2.BZ2File(ds_path) as handle:
            text = handle.read().decode('utf-8')
        sentences, labels, max_seq_len = cls.pre_process(text)
        train_size = int(len(sentences) * (1 - cls.test_slice - cls.val_slice))
        test_size = int(len(sentences) * cls.test_slice)
        test_end = train_size + test_size
        return ({
            'train': {'source_sentences': sentences[:train_size], 'target_labels': labels[:train_size]},
            'test': {'source_sentences': sentences[train_size:test_end], 'target_labels': labels[train_size:test_end]},
            'val': {'source_sentences': sentences[test_end:], 'target_labels': labels[test_end:]}
        }, max_seq_len)

    @classmethod
    def load_dataset(cls, ds_path, strip_punctuation=True):
        """Read and split the dataset at ``ds_path`` and wrap it in a new instance."""
        ds, max_seq_len = cls.read_dataset(ds_path)
        return cls(ds, max_seq_len, strip_punctuation)

    def encode_from(self, vocabulary):
        """Replace every token and label in all three splits by its vocabulary index.

        Tokens go through ``vocabulary.lookup_word`` and labels through
        ``vocabulary.lookup_link``. The existing lists are mutated in place, so calling
        this twice would look up the already-encoded integers.
        """
        fields = (('source_sentences', vocabulary.lookup_word),
                  ('target_labels', vocabulary.lookup_link))
        for split in (self.train_ds, self.val_ds, self.test_ds):
            for field, lookup in fields:
                for sequence in split[field]:
                    # Slice assignment keeps the same list objects (in-place update).
                    sequence[:] = [lookup(token) for token in sequence]

    def set_split(self, split="train"):
        """ Selects the splits in the dataset, from 'train', 'val' or 'test' """
        self._target_split = split
        self._target_ds, self._target_size = self._lookup_dict[split]

    def __len__(self):
        return self._target_size

    def __getitem__(self, index):
        """Return one sample of the current split as a dict with tokens, labels and length."""
        sentence = self._target_ds['source_sentences'][index]
        links = self._target_ds['target_labels'][index]
        return {'x_source': sentence, 'y_target': links, 'x_source_length': len(sentence)}

    def get_num_batches(self, batch_size):
        """Given a batch size, return the number of full batches in the current split"""
        return len(self) // batch_size

In [None]:
class Vocabulary(object):
    """Bidirectional token <-> index maps for source words and target links.

    :meth:`of` builds maps in which source index 0 is ``<PAD>`` and 1 is ``<UNK>``,
    and target index 0 is ``<PAD>`` and 1 is the dataset's text token.
    """

    _padding_token = '<PAD>'
    _unknown_token = '<UNK>'

    def __init__(self, word_to_index, link_to_index, text_token):
        """
        Args:
            word_to_index: dict mapping source tokens to integer indices.
            link_to_index: dict mapping target labels to integer indices.
            text_token: label returned by lookup_link_index for unknown indices.
        """
        self._word_to_index = word_to_index
        self._link_to_index = link_to_index
        # Invert from the stored indices rather than from dict iteration order, so a
        # mapping whose insertion order differs from its index order still round-trips.
        self._index_to_word = {index: word for word, index in word_to_index.items()}
        self._index_to_link = {index: link for link, index in link_to_index.items()}
        self.source_size = len(word_to_index)
        self.target_size = len(link_to_index)
        self._text_token = text_token

    def lookup_word(self, word):
        """Index of ``word``; unknown words map to 1 (the ``<UNK>`` slot created by :meth:`of`)."""
        return self._word_to_index.get(word, 1)

    def lookup_link(self, link):
        """Index of ``link``; unknown links map to 0 (the ``<PAD>`` slot created by :meth:`of`)."""
        # NOTE(review): unseen links share the padding index, so any loss/metric that
        # masks index 0 silently ignores them — confirm this is intended.
        return self._link_to_index.get(link, 0)

    def lookup_word_index(self, index):
        """Token for ``index``, or ``<UNK>`` if the index is unknown."""
        return self._index_to_word.get(index, self._unknown_token)

    def lookup_link_index(self, index):
        """Label for ``index``, or the text token if the index is unknown."""
        return self._index_to_link.get(index, self._text_token)

    @classmethod
    def of(cls, ds):
        """Build a vocabulary from the training split of ``ds``.

        Indices are assigned in first-seen order, after the reserved entries.
        ``ds`` must expose ``_text_token`` and ``train_ds`` with 'source_sentences'
        and 'target_labels' lists of token lists.
        """
        text_token = ds._text_token
        source_vocab = {cls._padding_token: 0, cls._unknown_token: 1}
        target_vocab = {cls._padding_token: 0, text_token: 1}

        for source_sequence in ds.train_ds['source_sentences']:
            for token in source_sequence:
                # len() is evaluated before insertion, so a new token gets the next free index.
                source_vocab.setdefault(token, len(source_vocab))

        for target_sequence in ds.train_ds['target_labels']:
            for token in target_sequence:
                target_vocab.setdefault(token, len(target_vocab))

        return cls(source_vocab, target_vocab, text_token)

In [None]:
def _vectorize(indices, padding_index, vector_length):
    vector = np.zeros(vector_length, dtype=np.int)
    vector[:len(indices)] = indices
    vector[len(indices):] = padding_index
    return vector.tolist()

def vectorize(input_sequence, padding_index=0, vector_length=-1):
    """Pad one sample's source and target sequences to a common length.

    Args:
        input_sequence: dict with 'x_source', 'y_target' and 'x_source_length'.
        padding_index: value used to fill positions past the sequence end.
        vector_length: target length; a negative value means "use x_source_length".

    Returns:
        New dict with padded 'x_source' / 'y_target' lists and the unchanged
        'x_source_length'.
    """
    original_length = input_sequence['x_source_length']
    pad_to = original_length if vector_length < 0 else vector_length

    padded = {
        key: _vectorize(input_sequence[key], padding_index, pad_to)
        for key in ('x_source', 'y_target')
    }
    padded['x_source_length'] = original_length
    return padded

def collate_fn(batch):
    """DataLoader collate function: sort, pad and tensorize a list of samples.

    The list is sorted in place by descending 'x_source_length', then every sample
    is padded to the longest length in the batch. Returns a dict of LongTensors
    under 'x_source', 'y_target' and 'x_source_length'.
    """
    # Descending lengths, as required by pack_padded_sequence's default enforce_sorted=True.
    batch.sort(key=lambda sample: sample['x_source_length'], reverse=True)
    longest = batch[0]['x_source_length']
    padded_samples = [vectorize(sample, vector_length=longest) for sample in batch]

    return {
        key: torch.LongTensor([sample[key] for sample in padded_samples])
        for key in ('x_source', 'y_target', 'x_source_length')
    }

def compute_accuracy(y_hat, y, mask_index=0):
    """Token accuracy, in percent, ignoring positions whose target equals ``mask_index``.

    Args:
        y_hat: scores of shape (num_tokens, num_classes); the prediction is the argmax over dim 1.
        y: target indices with num_tokens elements in total (any shape; flattened here).
        mask_index: target value to exclude from the count (the padding index by default).

    Returns:
        Percentage of unmasked positions predicted correctly, as a float. If no
        position is unmasked, returns 0.0 instead of raising ZeroDivisionError.
    """
    predictions = torch.argmax(y_hat, dim=1)
    # reshape (not view) so non-contiguous target tensors are accepted as well.
    targets = y.reshape(-1)
    valid = targets.ne(mask_index)
    n_valid = valid.sum().item()
    if n_valid == 0:
        return 0.0
    n_correct = (predictions.eq(targets) & valid).sum().item()
    return n_correct / n_valid * 100

def make_state():
    """Create a fresh training-history dict.

    Per-epoch 'train_*' / 'val_*' entries start as empty lists; 'test_loss' and
    'test_acc' start at -1 to mean "not evaluated yet".
    """
    history = {name: [] for name in ('train_loss', 'train_acc', 'val_loss', 'val_acc')}
    history.update(test_loss=-1, test_acc=-1)
    return history

In [None]:
class BiLSTM(nn.Module):
    """Per-token classifier: embedding -> LSTM (bidirectional by default) -> linear layer.

    ``forward`` returns unnormalised class scores of shape
    ``(batch_size * seq_len, target_size)``, one row per (sentence, position) pair in
    row-major order. The layout matches targets flattened with ``.view(-1)``, as
    expected by ``nn.CrossEntropyLoss``.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, target_size, batch_size, num_layers=1, num_directions=2, padding_idx=0):
        """
        Args:
            vocab_size: number of source tokens (embedding rows).
            embedding_dim: size of each token embedding.
            hidden_dim: LSTM hidden size per direction.
            target_size: number of output classes.
            batch_size: default batch size used by :meth:`init_hidden` when called
                without an argument; ``forward`` sizes its state from the input.
            num_layers: number of stacked LSTM layers.
            num_directions: 2 for a bidirectional LSTM, 1 for a unidirectional one.
            padding_idx: passed to ``nn.Embedding`` as its ``padding_idx``.

        Raises:
            ValueError: if ``num_directions`` is not 1 or 2.
        """
        super(BiLSTM, self).__init__()
        if num_directions not in (1, 2):
            raise ValueError("num_directions must be 1 or 2, got %r" % (num_directions,))
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.target_size = target_size
        self.batch_size = batch_size
        self.num_layers = num_layers
        self.num_directions = num_directions

        self.word_embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        # Directionality follows num_directions, keeping it consistent with fc's input size.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=num_layers,
                            bidirectional=(num_directions == 2), batch_first=True)
        self.fc = nn.Linear(hidden_dim * num_directions, self.target_size)
        self.hidden = self.init_hidden()

    def init_hidden(self, batch_size=None):
        """Return a zero (h_0, c_0) pair for a batch of ``batch_size`` (default: self.batch_size).

        The tensors have shape (num_layers * num_directions, batch_size, hidden_dim) and
        share the device and dtype of the embedding weights. Zeros (rather than random
        values) keep forward passes deterministic.
        """
        if batch_size is None:
            batch_size = self.batch_size
        weight = self.word_embedding.weight
        shape = (self.num_layers * self.num_directions, batch_size, self.hidden_dim)
        return (weight.new_zeros(shape), weight.new_zeros(shape))

    def forward(self, sequences, lengths):
        """Score every position of a right-padded batch.

        Args:
            sequences: LongTensor (batch_size, seq_len) of token indices.
            lengths: true length of each sequence (1-D tensor or list); need not be sorted.

        Returns:
            Tensor (batch_size * seq_len, target_size). Rows for positions past a
            sequence's length are fc applied to zero features, i.e. fc's bias.
        """
        batch_size, seq_len = sequences.size(0), sequences.size(1)
        # Fresh state per batch so the LSTM does not treat a new batch as a continuation.
        self.hidden = self.init_hidden(batch_size)

        # (batch_size, seq_len) -> (batch_size, seq_len, embedding_dim)
        embeds = self.word_embedding(sequences)
        # pack_padded_sequence expects lengths as a 1-D int64 CPU tensor.
        lengths = torch.as_tensor(lengths, dtype=torch.int64).cpu()
        packed = nn.utils.rnn.pack_padded_sequence(embeds, lengths, batch_first=True, enforce_sorted=False)

        packed_out, self.hidden = self.lstm(packed, self.hidden)
        # total_length keeps the output aligned with the input even if it is over-padded.
        lstm_out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True, total_length=seq_len)

        # (batch_size, seq_len, hidden_dim * num_directions) -> (batch_size * seq_len, hidden_dim * num_directions)
        lstm_out = lstm_out.contiguous().view(batch_size * seq_len, lstm_out.size(2))
        return self.fc(lstm_out)

In [None]:
# bz2-compressed text whose links are written inline as |anchor_text|target_link|.
DATA_PATH = "../input_data/wiki.txt.bz2"
dataset = WikiDataset.load_dataset(DATA_PATH)

# The vocabulary is built from the training split only; encode_from then replaces
# every token and label of all three splits with its integer index, in place.
vocabulary = Vocabulary.of(dataset)
dataset.encode_from(vocabulary)

In [None]:
# Hyperparameters and run configuration.
embedding_dim = 64
hidden_dim = 64
batch_size = 8
epochs = 1
device = 'cpu'  # NOTE(review): not referenced by the training cell yet
learning_rate = 5e-4

# Seed torch so weight initialisation and DataLoader shuffling are reproducible.
SEED = 0
torch.manual_seed(SEED)

# Full training batches per epoch (the training DataLoader uses drop_last=True).
dataset.get_num_batches(batch_size)

In [None]:
# Build the model from the config-cell hyperparameters and train it.
model = BiLSTM(vocab_size=vocabulary.source_size, embedding_dim=embedding_dim, hidden_dim=hidden_dim,
               target_size=vocabulary.target_size, batch_size=batch_size)
# Use the learning rate from the config cell (it was previously defined but ignored).
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.0001)
# Target index 0 is <PAD>; padded positions must not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)

state = make_state()

for epoch in range(epochs):
    dataset.set_split('train')
    # NOTE(review): num_workers > 0 needs collate_fn to be picklable by worker processes;
    # under the "spawn" start method a notebook-defined function may fail to load.
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=True,
                            collate_fn=collate_fn, num_workers=4)

    running_loss = 0.0
    running_acc = 0.0
    model.train()

    for batch_index, batch in enumerate(dataloader):
        optimizer.zero_grad()

        # Explicit keys instead of relying on the dict's value order.
        x = batch['x_source']
        y = batch['y_target']
        x_len = batch['x_source_length']
        y_hat = model(x, x_len)

        loss = criterion(y_hat, y.view(-1))
        loss.backward()
        optimizer.step()

        # Incremental running means over the batches seen so far this epoch.
        running_loss += (loss.item() - running_loss) / (batch_index + 1)
        acc_t = compute_accuracy(y_hat.detach(), y)
        running_acc += (acc_t - running_acc) / (batch_index + 1)

    # `state` is the history dict created above (the old code referenced an undefined `train_state`).
    state['train_loss'].append(running_loss)
    state['train_acc'].append(running_acc)

In [None]:
# Disabled scratch snippet: one forward/backward pass on a single batch, printing
# output/target shapes. It relies on `batch` and `vocabulary` from earlier cells.
# TODO(review): delete, or turn it into an assertion-based shape check.
#model = BiLSTM(vocab_size=vocabulary.source_size, embedding_dim=embedding_dim, hidden_dim=hidden_dim, target_size=vocabulary.target_size, batch_size=batch_size)
#x, y, x_len = batch.values()
#y_hat = model(x, x_len)
#criterion = nn.CrossEntropyLoss(ignore_index=0)
#print(y_hat.shape, y.shape, x_len)
#loss = criterion(y_hat, y.view(-1))
#loss.backward()
