In [1]:
import tensorflow as tf
import numpy as np
import modules.mediawiki_er as mw_er
from tqdm import tqdm

In [2]:
class HyperParameters:
    """Bundle of model and input-pipeline settings.

    Every setting is a plain class attribute, so it can be read either from the
    class itself or from any instance (e.g. ``hp.embedding_size``).
    """

    # -- optimisation --------------------------------------------------------
    learning_rate = 1e-3                  # Adam learning rate

    # -- inputs --------------------------------------------------------------
    max_sequence_length = 40              # max number of symbols in one input sequence
    vocab_size = 30000                    # symbol ids are expected to lie in range(vocab_size)
    embedding_size = 256                  # width of each input embedding vector

    # -- feed-forward stack --------------------------------------------------
    ff_hidden_state_size = 512            # units in each hidden dense layer
    ff_layers = 4                         # number of hidden dense layers

    # -- regularisation ------------------------------------------------------
    dropout_rate = 0.1                    # fraction of units dropped (tf.layers.dropout `rate`)

    # -- data pipeline -------------------------------------------------------
    pipeline_batch_size = 1024            # sequences per batch
    pipeline_num_parallel_calls = 4       # parsing threads in the data pipeline
    # prefetch buffer sized at 16x the batch size (units are whatever
    # mw_er's pipeline expects -- presumably elements; TODO confirm)
    pipeline_prefetch_size = pipeline_batch_size * 16
    pipeline_shuffle_size = 10000         # shuffle buffer size


hp = HyperParameters()

In [3]:
# TensorFlow session handle. It starts as None; a later cell passes it to
# mw_er.reset_tf and stores the returned session back into `sess`.
# NOTE(review): re-running this cell after a session already exists only drops
# the reference here -- it does not close that session.
sess = None

In [4]:
class TrivialModel(mw_er.BaseModel):
    """Baseline model: token embedding -> stacked ReLU dense layers -> 2 logits.

    Every dense layer acts only on the last (feature) axis, so each input
    position is scored independently of its neighbours; the model uses no
    sequence context. Training/evaluation machinery comes from
    ``mw_er.BaseModel`` (not visible here).
    """

    def __init__(self, hp):
        """Hand the hyper-parameters to ``mw_er.BaseModel``.

        Args:
            hp: hyper-parameter object. The graph builder below reads
                ``vocab_size``, ``embedding_size``, ``dropout_rate``,
                ``ff_layers`` and ``ff_hidden_state_size`` from it, via
                ``self._hp`` (presumably set by the base class -- confirm).
        """
        super().__init__(hp)

    def _build_prediction_model_internal(self):
        """Build the prediction graph and return its output logits.

        Reads ``self._input_sequences`` (integer symbol ids; presumably shaped
        ``[batch, max_sequence_length]`` -- TODO confirm against BaseModel) and
        ``self._is_training`` (turns dropout on or off).

        Returns:
            Tensor of un-normalised logits whose last dimension is 2.
        """
        cfg = self._hp

        def apply_dropout(tensor):
            # Dropout only takes effect while self._is_training is true.
            return tf.layers.dropout(tensor,
                                     rate=cfg.dropout_rate,
                                     training=self._is_training)

        embedding_table = tf.get_variable(
            'input_sequence_embeddings',
            (cfg.vocab_size, cfg.embedding_size))
        features = apply_dropout(
            tf.nn.embedding_lookup(embedding_table,
                                   self._input_sequences,
                                   name='input_sequences_embedded'))

        for layer_idx in range(cfg.ff_layers):
            features = apply_dropout(
                tf.layers.dense(features,
                                cfg.ff_hidden_state_size,
                                activation=tf.nn.relu,
                                name='ff_%d' % layer_idx))

        # Named 'softmax' but applies no activation: it returns raw logits.
        # The name is kept so variable names ('softmax/kernel', 'softmax/bias')
        # stay stable.
        return tf.layers.dense(features, 2, name='softmax')


In [5]:
# Rebuild from scratch. mw_er.reset_tf takes the previous session and returns
# the one to use from now on (presumably after clearing the default graph --
# TODO confirm). Then build the model graph and print per-variable parameter
# counts.
# NOTE(review): no graph-level random seed is set, so runs are not reproducible.
sess = mw_er.reset_tf(sess)
model = TrivialModel(hp=hp)
model.dump_statistics()

parameters for "input_sequence_embeddings:0": 7680000
parameters for "ff_0/kernel:0": 131072
parameters for "ff_0/bias:0": 512
parameters for "ff_1/kernel:0": 262144
parameters for "ff_1/bias:0": 512
parameters for "ff_2/kernel:0": 262144
parameters for "ff_2/bias:0": 512
parameters for "ff_3/kernel:0": 262144
parameters for "ff_3/bias:0": 512
parameters for "softmax/kernel:0": 1024
parameters for "softmax/bias:0": 2
total parameters: 8600578


In [6]:
# Assign initial values to every global variable created so far (the embedding
# table plus the ff_* and softmax layers) inside the current session.
variables_init_op = tf.global_variables_initializer()
sess.run(variables_init_op)

In [7]:
# Each epoch makes a pass over the train split (train=True, presumably applying
# optimiser updates -- train loss falls across epochs), then a dev pass with
# train=False.
# NOTE(review): the recorded run was stopped by hand (KeyboardInterrupt) after
# epoch 3; dev loss rose every epoch while train loss fell. Consider early
# stopping on a dev metric instead of a fixed epoch count.
num_epochs = 100

TRAIN_TFRECORDS = '../data/simplewiki/simplewiki-20171103.entity_recognition.train.tfrecords'
DEV_TFRECORDS = '../data/simplewiki/simplewiki-20171103.entity_recognition.dev.tfrecords'

# (tfrecords path, header template, train flag, show progress bar)
EPOCH_PASSES = (
    (TRAIN_TFRECORDS, 'train %d', True, True),
    (DEV_TFRECORDS, 'dev %d', False, False),
)

for epoch in range(num_epochs):
    for records_path, header_template, is_train, show_bar in EPOCH_PASSES:
        model.evaluate_dataset(sess,
                               records_path,
                               header=header_template % epoch,
                               train=is_train,
                               show_progress=show_bar)


train 0 (1294): loss=0.152654, precision=0.629828, recall=0.206087, F1=0.310556
dev 0 (1294): loss=0.149086, precision=0.61779, recall=0.247108, F1=0.353014



train 1 (2588): loss=0.146788, precision=0.639153, recall=0.239815, F1=0.348769
dev 1 (2588): loss=0.149134, precision=0.617171, recall=0.250461, F1=0.35632



train 2 (3882): loss=0.144691, precision=0.638446, recall=0.248977, F1=0.358247
dev 2 (3882): loss=0.151257, precision=0.610995, recall=0.25529, F1=0.360114



train 3 (5176): loss=0.142835, precision=0.640023, recall=0.266429, F1=0.376238
dev 3 (5176): loss=0.15183, precision=0.595462, recall=0.264914, F1=0.366691


KeyboardInterrupt: 