In [1]:
import tensorflow as tf
import numpy as np
import modules.mediawiki as mw
from tqdm import tqdm

In [2]:
class HyperParameters:
    """Fixed hyper-parameter bundle for the entity-recognition notebook.

    Every setting is a plain class attribute, so it can be read from the
    class or from an instance (``hp.vocab_size``). Edit values here rather
    than scattering magic numbers through later cells.
    """

    # ---- optimisation ----------------------------------------------------
    learning_rate = 1e-3              # Adam learning rate

    # ---- input representation --------------------------------------------
    max_sequence_length = 40          # max number of symbols in one input sequence
    vocab_size = 10000                # symbols are expected to lie in range(vocab_size)
    embedding_size = 256              # dimensionality of each input embedding

    # ---- feed-forward stack ----------------------------------------------
    ff_hidden_state_size = 512        # width of every hidden feed-forward layer
    ff_layers = 4                     # how many hidden feed-forward layers
    dropout_rate = 0.1                # dropout applied after embeddings and each layer

    # ---- data pipeline ---------------------------------------------------
    pipeline_batch_size = 1024        # sequences per batch
    pipeline_num_parallel_calls = 4   # parsing threads in the input pipeline
    # Prefetch buffer sized to 16 batches' worth (units presumably match
    # pipeline_batch_size, i.e. sequences -- confirm against modules.mediawiki).
    pipeline_prefetch_size = 16 * pipeline_batch_size
    pipeline_shuffle_size = 10000     # shuffle buffer size


hp = HyperParameters()

In [3]:
# Session handle used by the later cells; mw.reset_tf(sess) receives the
# previous value each time the model is rebuilt.
# If this cell is re-run after a session already exists, close it first:
# merely rebinding the name to None would drop the only reference to the
# live session without releasing its resources.
if globals().get('sess') is not None:
    sess.close()
sess = None

In [6]:
class TrivialModel(mw.BaseModel):
    """Context-free baseline: scores each input position from its own embedding.

    Pipeline per position: embedding lookup -> dropout ->
    ``hp.ff_layers`` x (ReLU dense of width ``hp.ff_hidden_state_size`` -> dropout)
    -> dense projection to 2 logits. Embedding lookup, dropout and the dense
    layers all act on the last (feature) axis only, so positions are processed
    independently of each other.
    """

    def __init__(self, hp):
        super().__init__(hp)

    def _build_prediction_model_internal(self):
        """Construct the prediction graph and return the 2-way output logits.

        Reads ``self._hp``, ``self._input_sequences`` and ``self._is_training``,
        which are presumably set up by ``mw.BaseModel`` (not visible here).
        """
        params = self._hp

        embedding_table = tf.get_variable(
            'input_sequence_embeddings',
            (params.vocab_size, params.embedding_size))

        hidden = tf.nn.embedding_lookup(
            embedding_table,
            self._input_sequences,
            name='input_sequences_embedded')
        hidden = tf.layers.dropout(
            hidden, rate=params.dropout_rate, training=self._is_training)

        for layer_index in range(params.ff_layers):
            hidden = tf.layers.dense(
                hidden,
                params.ff_hidden_state_size,
                activation=tf.nn.relu,
                name='ff_%d' % layer_index)
            hidden = tf.layers.dropout(
                hidden, rate=params.dropout_rate, training=self._is_training)

        # These are raw (un-normalised) logits despite the scope name; the name
        # is kept so variable names ("softmax/kernel", "softmax/bias") and any
        # saved checkpoints stay compatible. Softmax is presumably applied in
        # BaseModel's loss -- confirm.
        return tf.layers.dense(hidden, 2, name='softmax')


In [7]:
# Rebuild from scratch: reset_tf is handed the previous session (or None on
# first run) and returns the session used below -- presumably it also resets
# the default graph; see modules.mediawiki. Must run before the model is built.
sess = mw.reset_tf(sess)
# Build the graph for the baseline model using the hyper-parameters above.
model = TrivialModel(hp)
# Print per-variable and total parameter counts (output shown below).
model.dump_statistics()

parameters for "input_sequence_embeddings:0": 2560000
parameters for "ff_0/kernel:0": 131072
parameters for "ff_0/bias:0": 512
parameters for "ff_1/kernel:0": 262144
parameters for "ff_1/bias:0": 512
parameters for "ff_2/kernel:0": 262144
parameters for "ff_2/bias:0": 512
parameters for "ff_3/kernel:0": 262144
parameters for "ff_3/bias:0": 512
parameters for "softmax/kernel:0": 1024
parameters for "softmax/bias:0": 2
total parameters: 3480578


In [8]:
# Initialise every global variable that exists in the graph at this point, so
# this cell must run after the model cell. NOTE: re-running it resets all
# weights and discards any training done so far.
sess.run(tf.global_variables_initializer())

In [9]:
# TFRecord splits for the Simple-English-Wikipedia entity-recognition task.
TRAIN_TFRECORDS = '../data/simplewiki/simplewiki-20171103.entity_recognition.train.tfrecords'
DEV_TFRECORDS = '../data/simplewiki/simplewiki-20171103.entity_recognition.dev.tfrecords'

# In the recorded run one training pass took ~2m20s, so 100 epochs is roughly
# 4 hours; dev metrics were already flat after the first epoch.
num_epochs = 100

# Per epoch: one training pass (with progress bar), then a dev evaluation.
# Columns: header format, data path, train?, show progress bar?
epoch_schedule = (
    ('train %d', TRAIN_TFRECORDS, True, True),
    ('dev %d', DEV_TFRECORDS, False, False),
)

for epoch in range(num_epochs):
    for header_fmt, tfrecords_path, do_train, progress in epoch_schedule:
        model.evaluate_dataset(sess,
                               tfrecords_path,
                               header=header_fmt % epoch,
                               train=do_train,
                               show_progress=progress)

23718333it [02:20, 175559.54it/s]


train 0: loss=0.210792, precision=0.605264, recall=0.0965164, F1=0.166485


0it [00:00, ?it/s]

dev 0: loss=0.212034, precision=0.624766, recall=0.122003, F1=0.204142


23718333it [02:16, 174536.23it/s]


train 1: loss=0.206981, precision=0.623669, recall=0.115057, F1=0.194274


0it [00:00, ?it/s]

dev 1: loss=0.21143, precision=0.638382, recall=0.112477, F1=0.191256


23718333it [02:17, 174460.12it/s]


train 2: loss=0.206416, precision=0.626484, recall=0.113882, F1=0.19273


0it [00:00, ?it/s]

dev 2: loss=0.211461, precision=0.633917, recall=0.115084, F1=0.194803


6327629it [00:36, 172911.42it/s]

KeyboardInterrupt: 

6327629it [00:50, 172911.42it/s]