In [1]:
import collections
import os
import sys

import numpy as np
import tensorflow as tf

In [2]:
Py3 = sys.version_info[0] == 3

def _read_words(filename):
    """Read `filename` and return its whitespace-separated tokens.

    Every newline is turned into a standalone "<eos>" token, so line
    boundaries survive tokenization.

    Args:
        filename: path understood by TensorFlow's GFile.

    Returns:
        list of str tokens.
    """
    # tf.gfile was removed in TensorFlow 2.x; tf.io.gfile exists in newer 1.x
    # releases and in 2.x. Fall back to tf.gfile for older 1.x installs.
    if hasattr(tf, "io") and hasattr(tf.io, "gfile"):
        gfile = tf.io.gfile
    else:
        gfile = tf.gfile
    with gfile.GFile(filename, "r") as f:
        text = f.read()
    if not Py3:
        # Python 2 GFile returns bytes.
        text = text.decode("utf-8")
    # Pad the marker with spaces so it is always split off as its own token,
    # even when a line has no leading/trailing space around the newline
    # (a bare "<eos>" would otherwise be glued onto neighbouring words).
    return text.replace("\n", " <eos> ").split()

def _build_vocab(filename, min_count=4):
    """Build a word -> id vocabulary from the tokens of `filename`.

    The reserved key 'UNK' always maps to id 0. Words seen fewer than
    `min_count` times get no entry of their own; `_file_to_word_ids` maps
    any word without an entry to id 0. Retained words are numbered 1, 2, ...
    in ascending order of frequency (ties keep Counter's iteration order),
    so the most frequent words receive the largest ids.

    Args:
        filename: path passed to `_read_words`.
        min_count: minimum number of occurrences for a word to get its own
            id. The default of 4 reproduces the original `count > 3` rule.

    Returns:
        dict mapping word -> int id. Its length is the vocabulary size,
        including the 'UNK' entry.
    """
    data = _read_words(filename)
    word_to_id = {'UNK': 0}
    # Ascending sort by count (stable), kept so that ids of retained words
    # are unchanged from the original numbering.
    counts = sorted(collections.Counter(data).items(), key=lambda item: item[1])
    for word, count in counts:
        # Rare words are simply left out: they share the UNK id 0.
        # (Previously 'UNK' was incremented here, which moved UNK's id onto a
        # real word's id and left id 0 without any entry.)
        # A corpus word spelled 'UNK' must not overwrite the reserved id.
        if count >= min_count and word != 'UNK':
            word_to_id[word] = len(word_to_id)  # next free id, starting at 1
    return word_to_id

def _file_to_word_ids(filename, word_to_id):
    """Tokenize `filename` and translate each token into its vocabulary id.

    Tokens with no entry in `word_to_id` are encoded as 0, the unknown-word id.

    Args:
        filename: path passed to `_read_words`.
        word_to_id: dict mapping word -> int id.

    Returns:
        list of int ids, one per token, in file order.
    """
    ids = []
    for token in _read_words(filename):
        ids.append(word_to_id.get(token, 0))
    return ids

In [3]:
# Directory holding wiki.train.txt / wiki.valid.txt / wiki.test.txt,
# relative to the current working directory.
data_path = "simple-examples/data/"

In [4]:
train_file = os.path.join(data_path, "wiki.train.txt")
word_to_id = _build_vocab(train_file)
# Vocabulary size, counting the 'UNK' entry.
print(len(word_to_id))

27247


In [5]:
# Invert the vocabulary; if two words ever share an id, the later one wins.
id_to_word = {word_id: word for word, word_id in word_to_id.items()}

In [6]:
# Peek at the words assigned ids 1 through 5.
for word_id in range(1, 6):
    print(id_to_word[word_id])

Battlefield
Azure
replayed
unaltered
422


In [7]:
# Spot-check one id (in the run recorded here it is also one of the ids in `rounds`).
probe_id = 27092
id_to_word[probe_id]

'10'

In [8]:
def ptb_raw_data(data_path):
    """Load the train/valid/test splits under `data_path` as lists of word ids.

    Despite the name, this reads wiki.train.txt, wiki.valid.txt and
    wiki.test.txt. The vocabulary is built from the training split only and
    then used to encode all three splits.

    Args:
        data_path: directory containing the three text files.

    Returns:
        (train_ids, valid_ids, test_ids, vocab_size), where vocab_size is the
        number of entries in the vocabulary.
    """
    train_path, valid_path, test_path = (
        os.path.join(data_path, name)
        for name in ("wiki.train.txt", "wiki.valid.txt", "wiki.test.txt")
    )

    vocab = _build_vocab(train_path)
    encoded = [_file_to_word_ids(path, vocab) for path in (train_path, valid_path, test_path)]
    return encoded[0], encoded[1], encoded[2], len(vocab)

# Encode all three splits; `vocabulary` is the vocabulary size (an int).
(train_data,
 valid_data,
 test_data,
 vocabulary) = ptb_raw_data(data_path)

In [9]:
def generate_tokens(data_path, token_class, word_to_id=None):
    """Return the set of vocabulary ids for one class of numeric tokens.

    Token classes:
        'rounds': "10", "20", ..., "90" and "100", "200", ..., "900"
        'days':   "1" .. "31"
        'years':  "1000" .. "2020"

    Numbers that have no entry in the vocabulary are skipped. Previously they
    were mapped to id 0, the unknown-word id, which would make every unknown
    token in the data look like a member of the class.

    Args:
        data_path: directory containing wiki.train.txt; only read when
            `word_to_id` is not given.
        token_class: one of 'rounds', 'days', 'years'.
        word_to_id: optional prebuilt vocabulary (word -> id), to avoid
            re-reading and re-counting the training file on every call.

    Returns:
        set of int ids.

    Raises:
        ValueError: if `token_class` is not one of the names above.
    """
    numbers_by_class = {
        'rounds': list(range(10, 100, 10)) + list(range(100, 1000, 100)),
        'days': list(range(1, 32)),
        'years': list(range(1000, 2021)),
    }
    if token_class not in numbers_by_class:
        # Previously any unrecognized name silently returned the years set.
        raise ValueError(
            "unknown token_class %r; expected 'rounds', 'days' or 'years'" % (token_class,))

    if word_to_id is None:
        word_to_id = _build_vocab(os.path.join(data_path, "wiki.train.txt"))

    words = (str(n) for n in numbers_by_class[token_class])
    return {word_to_id[w] for w in words if w in word_to_id}

# Ids of the "round number" tokens (10..90, 100..900).
rounds = generate_tokens(data_path, token_class='rounds')
print(rounds)

{25793, 26979, 26756, 26532, 24971, 23403, 25262, 25968, 26929, 26106, 27092, 26997, 26068, 25687, 25017, 26266, 26525, 26591}


In [14]:
def generator(data, batch_size, num_steps, tokens, verbose=True):
    """Endlessly yield where `tokens` occur in a sliding window over batched data.

    `data` (a flat sequence of word ids) is truncated to a multiple of
    `batch_size` and reshaped to [batch_size, batch_len]. Each step inspects
    columns i .. i + num_steps - 1 of every row, starting at i = 1, and
    yields the hits. The window then slides one column to the right. After
    start position batch_len - num_steps - 1 it wraps back to 1, so column 0
    and the last column are never inspected.
    NOTE(review): the start at 1 presumably aligns with targets shifted by one;
    confirm. Also note the window advances by 1 column, not by num_steps as
    `epoch_size` would suggest.

    Args:
        data: sequence of int word ids.
        batch_size: number of rows to split `data` into.
        num_steps: window width in columns.
        tokens: collection of ids to look for.
        verbose: if True (the original behaviour), print epoch_size and the
            batched data shape on the first step.

    Yields:
        list of (row, offset_in_window, token_id) int tuples in row-major
        order; an empty list when the window contains no hit.

    Raises:
        ValueError: (on the first next()) if each row has fewer than
            num_steps + 1 columns, so not even one window fits.
    """
    data_len = len(data)
    batch_len = data_len // batch_size
    if batch_len < num_steps + 1:
        raise ValueError(
            "data too short: %d ids give %d columns per row, need at least %d"
            % (data_len, batch_len, num_steps + 1))
    data = np.reshape(data[0: batch_size * batch_len], [batch_size, batch_len])
    epoch_size = (batch_len - 1) // num_steps
    if verbose:
        print(epoch_size)
        print(data.shape)

    # np.isin needs an array-like; a Python set would be treated as a single object.
    token_ids = np.array(sorted(tokens))
    last_start = batch_len - num_steps - 1

    i = 1
    while True:
        window = data[:, i:i + num_steps]
        # np.nonzero returns indices in row-major order, matching the original
        # row-then-column scan.
        rows, cols = np.nonzero(np.isin(window, token_ids))
        positions = [(int(j), int(k), int(window[j, k])) for j, k in zip(rows, cols)]

        # `>=` rather than `==`: when last_start < 1 (very short rows) an
        # equality test never fires and the window would run off the array.
        if i >= last_start:
            i = 1
        else:
            i += 1

        yield positions

In [15]:
# Sliding-window scan of the training ids for round-number tokens.
gen = generator(train_data, batch_size=20, num_steps=20, tokens=rounds)

In [16]:
# Print the hits found in the first 20 windows.
for step in range(20):
    window_hits = next(gen)
    print(window_hits)

5221
(20, 104431)
[(14, 10, 27092)]
[(14, 9, 27092)]
[(14, 8, 27092)]
[(14, 7, 27092)]
[(14, 6, 27092)]
[(14, 5, 27092)]
[(14, 4, 27092)]
[(14, 3, 27092)]
[(14, 2, 27092)]
[(14, 1, 27092)]
[(14, 0, 27092)]
[]
[]
[]
[]
[]
[]
[]
[]
[]


In [17]:
# (Leftover scratch cell removed: it only looped over range(1) and printed 0.)

0
