In [1]:
import tensorflow as tf

_buffer_size = 20000
_thread_num = 16


def get_vocab(vocab_path, isTF=True):
    """Load a vocabulary file containing one token per line.

    The token on line ``i`` (counting from 0) is assigned id ``i``.

    Args:
        vocab_path: Path to the vocabulary file (read as UTF-8).
        isTF: If True (default), build a TensorFlow lookup table for use
            inside the graph; otherwise read the file into a Python dict.

    Returns:
        isTF=True:  a ``tf.contrib.lookup`` table mapping token string ->
            int64 id. Tokens not present in the file map to id 1 (no OOV
            buckets). Run ``tf.tables_initializer()`` before using it.
        isTF=False: a dict mapping id -> token (the *reverse* direction of
            the TF table), e.g. for turning ids back into text. Each line
            is ``str.strip()``-ed.
    """
    if isTF:
        vocab_path_tensor = tf.constant(vocab_path)
        # Registered as an asset file path, presumably so a SavedModel
        # export can bundle the vocab file -- confirm against export code.
        tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, vocab_path_tensor)
        return tf.contrib.lookup.index_table_from_file(
            vocabulary_file=vocab_path_tensor,
            num_oov_buckets=0,
            default_value=1)

    # Decode explicitly as UTF-8 so the result matches the TF table (which
    # reads the file's raw UTF-8 bytes) instead of depending on the
    # platform's default locale encoding.
    with open(vocab_path, "r", encoding="utf-8") as f:
        return {idx: line.strip() for idx, line in enumerate(f)}

In [2]:
# Training corpus (one "<text><TAB><label>" record per line) and the
# vocabulary file consumed by get_vocab.
corpus_path = "data/ratings_train.cleaned.txt"
vocab_path = "Word2vec.vocab"

# Batching parameters: every token sequence is truncated/padded to max_len.
batch_size = 4
max_len = 100

In [4]:
tf_vocab = get_vocab(vocab_path)
dataset = tf.data.TextLineDataset(corpus_path)

# Shuffle within a window of _buffer_size lines and repeat indefinitely
# (shuffle_and_repeat's count defaults to None).
dataset = dataset.apply(tf.data.experimental.shuffle_and_repeat(_buffer_size))


def _parse_line(line):
    """Parse one "<text><TAB><label>" corpus line into a feature dict.

    Returns a dict with:
      "input": int32 ids of the space-split text, truncated to the first
               ``max_len`` tokens (tokens missing from the vocab -> 1, the
               default_value used by get_vocab);
      "len":   number of ids in "input" before padding;
      "label": the first space-separated token of the label field, parsed
               as int32.

    NOTE(review): a line with no tab-separated label field (or an empty
    text field) leaves ``fields`` with a single element, so ``fields[1]``
    is out of range and aborts the pipeline. Consider a ``dataset.filter``
    if the cleaned corpus may contain such lines.
    """
    fields = tf.string_split([line], delimiter="\t").values
    tokens = tf.string_split([fields[0]]).values
    # Slicing past the end is a no-op, so this truncates long inputs and
    # leaves short ones alone -- no tf.cond needed.
    tokens = tokens[:max_len]
    ids = tf.cast(tf_vocab.lookup(tokens), tf.int32)
    label_tokens = tf.string_split([fields[1]]).values
    return {
        "input": ids,
        "len": tf.shape(ids)[0],
        "label": tf.strings.to_number(label_tokens[0], tf.int32),
    }


# A single parallel map replaces the four chained per-element maps; the
# produced elements are the same.
dataset = dataset.map(_parse_line, num_parallel_calls=_thread_num)

# Inputs were truncated to max_len above, so padding every "input" to a
# fixed length of max_len always succeeds.
# NOTE(review): padding id 0 is also the id of the vocab file's first line;
# presumably that line is a dedicated padding token -- confirm.
dataset = dataset.padded_batch(
    batch_size,
    padded_shapes={"input": [max_len], "len": [], "label": []},
    padding_values={"input": 0, "len": 0, "label": 0},
)

# Keep up to 3 batches prepared ahead of the consumer.
dataset = dataset.prefetch(3)

In [5]:
# Build the iterator from the dataset's structure (dtypes + shapes) rather
# than from the dataset itself; init_op is what actually binds it to
# `dataset`, and `features` is the batch dict each sess.run yields.
iters = tf.data.Iterator.from_structure(
    output_types=dataset.output_types,
    output_shapes=dataset.output_shapes,
)
features = iters.get_next()

init_op = iters.make_initializer(dataset)

In [6]:
# Session over the default graph, which holds the pipeline and iterator ops.
sess = tf.Session(graph=tf.get_default_graph())

In [7]:
# Initialise the vocabulary lookup table (created by get_vocab) before
# binding the iterator, so the table is guaranteed to be ready before the
# pipeline can produce any element.
sess.run(tf.tables_initializer())
sess.run(init_op)

In [8]:
# Fetch one padded batch: a dict with "input", "len" and "label" arrays.
sess.run(fetches=features)

{'input': array([[  294,    26,    20,     4,    76,   115,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0],
        [  511,   493,    91,     1,     4,   462,   292,    29,    49,
            17,   377,    58,     3,  3563,  1292,  1693,    96,    58,
            13,  3697,     2,   442,  

In [9]:
# Fetch the next padded batch from the same (shuffled, repeating) iterator.
sess.run(fetches=features)

{'input': array([[  22,   25,   12,   74,  253,   21, 2438,  253,   21,  720,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0],
        [  44,   14,   24,  496,  356,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
