In [1]:
import bz2
import nltk
import numpy as np
import re
from collections import Counter
# Tokenizer models used by nltk.sent_tokenize / nltk.word_tokenize.
# NLTK >= 3.9 looks these up under "punkt_tab" rather than the pickled "punkt";
# requesting both keeps Restart & Run All working across NLTK versions.
# NOTE(review): on older NLTK releases "punkt_tab" may be missing from the index —
# confirm nltk.download only reports that (returns False) for the pinned version.
for _nltk_resource in ("punkt", "punkt_tab"):
    nltk.download(_nltk_resource, quiet=True)

[nltk_data] Downloading package punkt to /home/leo/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [26]:
def collect_links(input_file, valid_links, INNER_SEP, OUTER_SEP):
    """Collect every valid link target annotated in a tokenized corpus.

    An annotated token has the form ``|text|link|`` (with ``OUTER_SEP`` as the
    delimiter). Its link target is collected when the annotation has exactly
    one text part and one link part and the link is in ``valid_links``.

    Args:
        input_file: list of sentences, each a list of string tokens.
        valid_links: container of accepted link targets (only ``in`` is used).
        INNER_SEP: not used by this function.
        OUTER_SEP: delimiter surrounding and separating the parts of an annotation.

    Returns:
        ``(input_file, links)``: the corpus, returned unchanged, and a list with
        one entry per valid link occurrence (duplicates kept, in corpus order).
    """
    links = []
    for sentence in input_file:
        for word in sentence:
            # startswith/endswith (unlike word[0]/word[-1]) also handle empty tokens.
            if not (word.startswith(OUTER_SEP) and word.endswith(OUTER_SEP)):
                continue
            parts = [part for part in word.split(OUTER_SEP) if part]
            if len(parts) == 2 and parts[1] in valid_links:
                links.append(parts[1])
    return input_file, links

def encode_labels(text_corpus, link2idx, INNER_SEP, OUTER_SEP, default_no_link):
    """Turn annotated tokens into plain words with one label id per word.

    - Plain tokens keep their text and get ``default_no_link``.
    - Tokens that start and end with ``OUTER_SEP`` but do not split into exactly
      a text part and a link part lose their separators and get ``default_no_link``.
    - Well-formed annotations ``|text|link|`` are split on ``INNER_SEP``. Every
      non-empty piece becomes its own word labelled ``link2idx[link]``, or
      ``default_no_link`` when the link is not in ``link2idx``.

    Args:
        text_corpus: list of sentences, each a list of string tokens.
        link2idx: mapping from link target to label id.
        INNER_SEP: separator between the words inside an annotation's text.
        OUTER_SEP: delimiter surrounding and separating the parts of an annotation.
        default_no_link: label id used for words without a (known) link.

    Returns:
        ``(labels, sentences)``: parallel lists of lists, where ``labels[i][j]``
        is the label id of word ``sentences[i][j]``.
    """
    labels = []
    sentences = []
    for sentence in text_corpus:
        sentence_labels = []
        sentence_words = []
        for word in sentence:
            # startswith/endswith (unlike word[0]/word[-1]) also handle empty tokens.
            if not (word.startswith(OUTER_SEP) and word.endswith(OUTER_SEP)):
                sentence_labels.append(default_no_link)
                sentence_words.append(word)
                continue
            parts = [part for part in word.split(OUTER_SEP) if part]
            if len(parts) != 2:
                # Malformed annotation: keep its text, drop the separators.
                sentence_labels.append(default_no_link)
                sentence_words.append(word.replace(OUTER_SEP, ''))
                continue
            text, link = parts
            label = link2idx.get(link, default_no_link)
            for sub_word in text.split(INNER_SEP):
                if sub_word:  # skip empty pieces from repeated/edge INNER_SEP
                    sentence_labels.append(label)
                    sentence_words.append(sub_word)
        labels.append(sentence_labels)
        sentences.append(sentence_words)
    return labels, sentences

def pad_input(sentences, seq_len, pad_token, dtype=float):
    """Pad or trim every sentence of a batch to exactly ``seq_len`` entries.

    Args:
        sentences: sequence of sentences, each a sequence of numbers
            (e.g. token ids).
        seq_len: target length. Longer sentences keep their first ``seq_len``
            entries; shorter ones are filled with ``pad_token`` at the end.
        pad_token: numeric fill value for positions past a sentence's end.
        dtype: dtype of the returned array. The default is float (float64).
            Pass an integer dtype such as ``np.int64`` to get id arrays suitable
            for index lookups.

    Returns:
        np.ndarray of shape ``(len(sentences), seq_len)``.
    """
    features = np.full((len(sentences), seq_len), pad_token, dtype=dtype)
    for row, sentence in enumerate(sentences):
        trimmed = np.asarray(sentence)[:seq_len]
        features[row, :len(trimmed)] = trimmed
    return features

In [27]:
# Preprocess data ...
# Load the bz2-compressed train/test corpora, split them into sentences of
# lower-cased tokens, and find the link targets frequent enough in the
# training data to be kept as labels.

INNER_SEP = '_'  # separates the words inside an annotation's text
OUTER_SEP = '|'  # delimits an annotated token: |text|link|
INPUT_DIR = '../input_data'


def _load_tokenized(path):
    """Return a bz2-compressed UTF-8 text file as a list of sentences, each a
    list of lower-cased tokens from nltk.sent_tokenize / nltk.word_tokenize."""
    # The context manager closes the compressed file once it has been read.
    with bz2.open(path, 'rb') as fh:
        text = fh.read().decode('utf-8')
    return [[word.lower() for word in nltk.word_tokenize(sentence)]
            for sentence in nltk.sent_tokenize(text)]


def _parse_link(word):
    """Return the link of an annotated token ``|text|link|``, or None when
    ``word`` is not an annotation with exactly a text part and a link part.

    Mirrors the parsing in collect_links, so only well-formed annotations count."""
    if not (word.startswith(OUTER_SEP) and word.endswith(OUTER_SEP)):
        return None
    parts = [part for part in word.split(OUTER_SEP) if part]
    return parts[1] if len(parts) == 2 else None


train_file = _load_tokenized(f'{INPUT_DIR}/train.txt.bz2')
test_file = _load_tokenized(f'{INPUT_DIR}/test.txt.bz2')

# A link target is valid when it occurs at least `link_treshold` times in the
# training data; only the training data is consulted.
link_treshold = 1
link_counts = Counter(filter(None, (_parse_link(word) for sentence in train_file for word in sentence)))
valid_links = {link for link, frequency in link_counts.items() if frequency >= link_treshold}

# Gather the valid link occurrences of both splits (one entry per occurrence).
train_file, train_links = collect_links(train_file, valid_links, INNER_SEP, OUTER_SEP)
test_file, test_links = collect_links(test_file, valid_links, INNER_SEP, OUTER_SEP)

# Label vocabulary: each distinct link once (first-seen order), then the
# "no link" class last. De-duplicating is required. Without it, link2idx would
# point each link at its last occurrence (leaving unused ids), idx2link would
# hold duplicates, and the class count len(output) would be inflated.
output = list(dict.fromkeys(train_links + test_links)) + ["_TEXT"]
_NO_LINK = len(output) - 1
link2idx = {link: idx for idx, link in enumerate(output)}
idx2link = dict(enumerate(output))

# Split annotations into words and attach one label id per word.
train_labels, train_file = encode_labels(train_file, link2idx, INNER_SEP, OUTER_SEP, _NO_LINK)
test_labels, test_file = encode_labels(test_file, link2idx, INNER_SEP, OUTER_SEP, _NO_LINK)

# Word vocabulary, built from the training sentences only. The two special
# tokens come first (so _PAD -> 0 and _UNK -> 1), followed by words in
# descending frequency; ties keep their first-seen order.
PAD_TOKEN = "_PAD"
UNKOWN_TOKEN = "_UNK"
word_counts = Counter(word for sentence in train_file for word in sentence)
vocabulary = [PAD_TOKEN, UNKOWN_TOKEN] + [word for word, _count in word_counts.most_common()]

word2idx = {word: idx for idx, word in enumerate(vocabulary)}
idx2word = dict(enumerate(vocabulary))  # same lookup as indexing `vocabulary` directly

In [9]:
# Encode words as integers ...
# Every token becomes its vocabulary id; tokens absent from the training
# vocabulary map to the _UNK id.

UNK_IDX = word2idx[UNKOWN_TOKEN]
train_sentences = [[word2idx.get(word, UNK_IDX) for word in sentence] for sentence in train_file]
test_sentences = [[word2idx.get(word, UNK_IDX) for word in sentence] for sentence in test_file]

# Keep the sentence order: train_labels / test_labels are aligned with these
# lists by position, so sorting only the sentences would mis-pair them.

batch_size = 32

# Number of tokens in each encoded sentence.
train_lengths = [len(sentence) for sentence in train_sentences]
test_lengths = [len(sentence) for sentence in test_sentences]

In [18]:
# Quick sanity checks: token count of one training sentence, and the id of the
# padding token (PAD_TOKEN is the first vocabulary entry, so this prints 0).
print(len(train_sentences[10]))

print(word2idx[PAD_TOKEN])

3085
0
