In [1]:
import tensorflow as tf
import numpy as np
import modules.mediawiki as mw
from tqdm import tqdm

In [8]:
class HyperParameters():
    # adam learning rate
    learning_rate = 1e-3
    
    # maximum number of symbols in an input sequence
    max_sequence_length = 40

    # number of symbols in vocabulary
    # (symbols are expected to be in range(vocab_size))
    vocab_size = 10000

    # number of dimensions in input embeddings
    embedding_size = 256
    
    # number of dimensions in hidden state
    rnn_hidden_state_size = 256
    
    # feed-forward hidden state size
    ff_hidden_state_size = 512
    
    # dropout rate
    dropout_rate = 0.1
    
    # number of sequences per batch
    pipeline_batch_size = 256
    
    # number of parsing threads in data pipeline
    pipeline_num_parallel_calls = 4
    
    # size of prefetch in data pipeline
    pipeline_prefetch_size = pipeline_batch_size * 16
    
    # shuffle buffer size
    pipeline_shuffle_size = 10000

hp = HyperParameters()

In [3]:
sess = None

In [9]:
class RnnModel(mw.BaseModel):
    def __init__(self, hp):
        super().__init__(hp)
    
    def _build_prediction_model_internal(self):
        # Embeddings
        # ----------
        
        input_sequence_embeddings = tf.get_variable('input_sequence_embeddings', 
                                                    (self._hp.vocab_size, self._hp.embedding_size))
        input_sequences_embedded = tf.nn.embedding_lookup(input_sequence_embeddings, 
                                                          self._input_sequences,
                                                          name = 'input_sequences_embedded')

        input_position_embeddings = tf.get_variable('input_position_embeddings', 
                                                    (self._hp.max_sequence_length, self._hp.embedding_size))
        input_positions_embedded = tf.nn.embedding_lookup(input_position_embeddings, self._input_positions)

        input_word_ending_embeddings = tf.get_variable('input_word_ending_embeddings',
                                                       (2, self._hp.embedding_size))
        input_word_endings_embedded = tf.nn.embedding_lookup(input_word_ending_embeddings, 
                                                             self._input_word_endings,
                                                             name = 'input_word_endings_embedded')

        input_combined_embedded = tf.add_n([input_sequences_embedded, 
                                            input_positions_embedded, 
                                            input_word_endings_embedded])
        input_combined_embedded = tf.layers.dropout(input_combined_embedded,
                                                    rate = self._hp.dropout_rate,
                                                    training = self._is_training)
        
        # RNNs
        # ----
        
        dropout_keep_prob = tf.cond(self._is_training,
                                    lambda: tf.constant(1.0),
                                    lambda: tf.constant(1.0 - self._hp.dropout_rate))

        fw_rnn_cell = tf.nn.rnn_cell.GRUCell(self._hp.rnn_hidden_state_size)
        fw_rnn_cell = tf.nn.rnn_cell.DropoutWrapper(fw_rnn_cell,
                                                    input_keep_prob = dropout_keep_prob,
                                                    output_keep_prob = dropout_keep_prob)

        bw_rnn_cell = tf.nn.rnn_cell.GRUCell(self._hp.rnn_hidden_state_size)
        bw_rnn_cell = tf.nn.rnn_cell.DropoutWrapper(bw_rnn_cell,
                                                    input_keep_prob = dropout_keep_prob,
                                                    output_keep_prob = dropout_keep_prob)

        rnn_outputs, _ = tf.nn.bidirectional_dynamic_rnn(fw_rnn_cell,
                                                         bw_rnn_cell,
                                                         input_combined_embedded,
                                                         sequence_length = self._input_lengths,
                                                         dtype = tf.float32)
        rnn_outputs = tf.concat(rnn_outputs, 2)
        
        # Softmax
        # -------
        
        # TODO: more layers here?
        feed_forward = tf.layers.dense(rnn_outputs,
                                       self._hp.ff_hidden_state_size,
                                       activation = tf.nn.relu,
                                       name = 'feed_forward')
        feed_forward = tf.layers.dropout(feed_forward,
                                         rate = self._hp.dropout_rate,
                                         training = self._is_training)

        output_logits = tf.layers.dense(feed_forward, 2, name = 'softmax')
        
        return output_logits

In [10]:
sess = mw.reset_tf(sess)
model = RnnModel(hp)
model.dump_statistics()

parameters for "input_sequence_embeddings:0": 2560000
parameters for "input_position_embeddings:0": 10240
parameters for "input_word_ending_embeddings:0": 512
parameters for "bidirectional_rnn/fw/gru_cell/gates/kernel:0": 262144
parameters for "bidirectional_rnn/fw/gru_cell/gates/bias:0": 512
parameters for "bidirectional_rnn/fw/gru_cell/candidate/kernel:0": 131072
parameters for "bidirectional_rnn/fw/gru_cell/candidate/bias:0": 256
parameters for "bidirectional_rnn/bw/gru_cell/gates/kernel:0": 262144
parameters for "bidirectional_rnn/bw/gru_cell/gates/bias:0": 512
parameters for "bidirectional_rnn/bw/gru_cell/candidate/kernel:0": 131072
parameters for "bidirectional_rnn/bw/gru_cell/candidate/bias:0": 256
parameters for "feed_forward/kernel:0": 262144
parameters for "feed_forward/bias:0": 512
parameters for "softmax/kernel:0": 1024
parameters for "softmax/bias:0": 2
total parameters: 3622402


In [11]:
sess.run(tf.global_variables_initializer())

In [13]:
num_epochs = 100

for epoch in range(num_epochs):
    model.evaluate_dataset(sess,
                           '../data/simplewiki/simplewiki-20171103.entity_recognition.train.tfrecords',
                           header='train %d' % epoch,
                           train=True,
                           show_progress=True)
    model.evaluate_dataset(sess,
                           '../data/simplewiki/simplewiki-20171103.entity_recognition.dev.tfrecords',
                           header='dev %d' % epoch,
                           train=False,
                           show_progress=False)

23718333it [08:03, 49049.94it/s]


train 0: loss=0.165133, precision=0.705067, recall=0.341897, F1=0.460494


0it [00:00, ?it/s]

dev 0: loss=0.158745, precision=0.734976, recall=0.376528, F1=0.497954


132901it [00:02, 58260.65it/s]

KeyboardInterrupt: 