In [2]:
import os
import time
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.python.layers.core import Dense

In [3]:
class Seq2Seq():
    """Builds a TensorFlow 1.x encoder-decoder (seq2seq) graph.

    Intended build order: ``_init_placeholders()``, then
    ``encoding(embeddings)``, then ``decoding(embeddings)``.  ``encoding``
    and ``decoding`` create their RNN cells themselves, so calling
    ``_create_encoder`` / ``_create_decoder`` separately is not required.

    After ``decoding`` the instance exposes ``training_logits`` and
    ``inference_logits`` (the first element returned by
    ``tf.contrib.seq2seq.dynamic_decode`` for the teacher-forced and the
    greedy decoder respectively).
    """

    # Values accepted for the corresponding constructor arguments.
    _SUPPORTED_CELLS = ('LSTM', 'GRU', 'GLSTM', 'RNN')
    _SUPPORTED_ATTENTION = (None, 'bahdanau', 'luong')
    _SUPPORTED_DECODERS = ('Greedy',)

    def __init__(self, enc_units, num_enc_layer, num_dec_layer, vocab_size, embed_dim, batch_size, word2int, beam_width, bidirectional=True, cell_type='LSTM', attention='bahdanau', decode_mech='Greedy', coverage=False):
        """Store the model hyperparameters; no graph nodes are created here.

        Args:
            enc_units: Number of units in every encoder cell.
            num_enc_layer: Number of stacked encoder layers.
            num_dec_layer: Number of stacked decoder layers.
            vocab_size: Size of the output projection (number of logits).
            embed_dim: Embedding dimension (stored only).
            batch_size: Fixed batch size baked into the decoder graph.
            word2int: Mapping that must contain '<GO>' and '<EOS>'.
            beam_width: Stored only; beam search is not implemented.
            bidirectional: Use a stacked bidirectional encoder.  The decoder
                then uses ``2 * enc_units`` units.
            cell_type: One of 'LSTM', 'GRU', 'GLSTM', 'RNN'.
            attention: 'bahdanau', 'luong' or None.
            decode_mech: Inference strategy; only 'Greedy' is implemented.
            coverage: Stored only; not used by the graph.
        """
        self.encoder_units = enc_units
        self.encoder_layers = num_enc_layer
        self.vocab_size = vocab_size
        self.embedding_dim = embed_dim
        self.bidirectional = bidirectional
        self.cell_type = cell_type
        # Forward and backward encoder states are concatenated, so the
        # decoder must be twice as wide in the bidirectional case.
        self.decoder_units = self.encoder_units*2 if bidirectional else self.encoder_units
        self.decoder_layers = num_dec_layer
        self.attention = attention
        self.decode_mechanism = decode_mech
        self.beam_width = beam_width
        self.use_coverage = coverage
        self.batch_size = batch_size
        self.vocab_to_int = word2int

    def _init_placeholders(self):
        """Create the input placeholders and the derived max decoder length."""
        with tf.variable_scope("Placeholders", reuse=tf.AUTO_REUSE):
            self.encoder_inputs = tf.placeholder(dtype=tf.int32, shape=(None, None), name='encoder_inputs')
            self.decoder_targets = tf.placeholder(dtype=tf.int32, shape=(None, None), name='decoder_targets')
            self.encoder_lengths = tf.placeholder(dtype=tf.int32, shape=(None,), name='encoder_lengths')
            self.decoder_lengths = tf.placeholder(dtype=tf.int32, shape=(None,), name='decoder_lengths')
            self.max_dec_length = tf.reduce_max(self.decoder_lengths, name='max_dec_len')
            self.lr = tf.placeholder(tf.float32, name='learning_rate')
            # Feedable keep probability for decoder input dropout.  The
            # default of 0.8 reproduces the previously hard-coded value when
            # nothing is fed; feed 1.0 to disable dropout at inference.
            self.keep_prob = tf.placeholder_with_default(0.8, shape=(), name='keep_prob')

    def _decoder_input_processing(self):
        """Build ``decoder_inputs``: targets without their last column, prefixed with '<GO>'."""
        with tf.variable_scope("Processing", reuse=tf.AUTO_REUSE):
            go_id = self.vocab_to_int['<GO>']
            without_last = tf.strided_slice(self.decoder_targets,
                                            begin=[0, 0],
                                            end=[self.batch_size, -1],
                                            strides=[1, 1])
            go_column = tf.fill([self.batch_size, 1], go_id)
            self.decoder_inputs = tf.concat([go_column, without_last], 1)

    def _validate_cell_type(self):
        """Raise ValueError if ``cell_type`` is not a supported cell name."""
        if self.cell_type not in self._SUPPORTED_CELLS:
            raise ValueError("Unsupported cell_type {!r}; expected one of {}".format(
                self.cell_type, self._SUPPORTED_CELLS))

    def _make_cell(self, num_units):
        """Return one RNN cell of the configured ``cell_type``.

        Args:
            num_units: Number of units of the cell.

        Raises:
            ValueError: If ``cell_type`` is not supported.
        """
        self._validate_cell_type()
        init = tf.random_uniform_initializer(-0.1, 0.1, seed=2)
        if self.cell_type == 'LSTM':
            return tf.contrib.rnn.LSTMCell(num_units, initializer=init)
        if self.cell_type == 'GRU':
            # GRUCell has no `initializer` argument; the weight initializer
            # is passed as `kernel_initializer`.
            return tf.contrib.rnn.GRUCell(num_units, kernel_initializer=init)
        if self.cell_type == 'GLSTM':
            return tf.contrib.rnn.GLSTMCell(num_units, initializer=init)
        # 'RNN': tf.contrib.rnn.RNNCell is the abstract base class, the
        # concrete vanilla cell is BasicRNNCell (which takes no initializer).
        return tf.contrib.rnn.BasicRNNCell(num_units)

    def _create_encoder(self):
        """Create ``cells_fw`` (and ``cells_bw`` when bidirectional), one cell per encoder layer.

        Raises:
            ValueError: If ``cell_type`` is not supported.
        """
        self._validate_cell_type()
        with tf.variable_scope("Encoder_Layer", reuse=tf.AUTO_REUSE):
            self.cells_fw = [self._make_cell(self.encoder_units) for _ in range(self.encoder_layers)]
            if self.bidirectional:
                self.cells_bw = [self._make_cell(self.encoder_units) for _ in range(self.encoder_layers)]

    def _create_decoder(self):
        """Create the decoder cell, output projection and initial decoder state.

        Reads ``enc_output``, ``enc_state``, ``encoder_lengths`` and
        ``keep_prob``, so ``_init_placeholders`` and ``encoding`` must have
        run first.  Sets ``output_layer``, ``dec_cells`` and
        ``initial_state``.

        Raises:
            ValueError: If ``cell_type`` or ``attention`` is not supported.
        """
        self._validate_cell_type()
        if self.attention not in self._SUPPORTED_ATTENTION:
            raise ValueError("Unsupported attention {!r}; expected one of {}".format(
                self.attention, self._SUPPORTED_ATTENTION))

        with tf.variable_scope("Decoder_Layer", reuse=tf.AUTO_REUSE):
            stacked = tf.contrib.rnn.MultiRNNCell(
                [self._make_cell(self.decoder_units) for _ in range(self.decoder_layers)])
            # Use the feedable keep probability instead of a constant so
            # dropout can be switched off at inference time.
            stacked = tf.contrib.rnn.DropoutWrapper(stacked, input_keep_prob=self.keep_prob)
            self.output_layer = Dense(self.vocab_size,
                                      kernel_initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.1),
                                      name='Output')

            if self.attention is None:
                self.dec_cells = stacked
                self.initial_state = self.enc_state
                return

            if self.attention == 'bahdanau':
                attn_mech = tf.contrib.seq2seq.BahdanauAttention(self.decoder_units,
                                                                 self.enc_output,
                                                                 self.encoder_lengths,
                                                                 normalize=False,
                                                                 name='BahdanauAttention')
            else:  # 'luong'
                attn_mech = tf.contrib.seq2seq.LuongAttention(self.decoder_units,
                                                              self.enc_output,
                                                              self.encoder_lengths,
                                                              name='LuongAttention')
            self.dec_cells = tf.contrib.seq2seq.AttentionWrapper(stacked,
                                                                 attn_mech,
                                                                 attention_layer_size=self.decoder_units)
            # Start from the wrapper's zero state, but seed the wrapped
            # cell with the encoder's final state.
            zero_state = self.dec_cells.zero_state(batch_size=self.batch_size, dtype=tf.float32)
            self.initial_state = zero_state.clone(cell_state=self.enc_state)

    def _merge_bidirectional_states(self, fw_states, bw_states):
        """Combine per-layer forward/backward final states into one state.

        For every layer the forward and backward states are concatenated
        along axis 1; with more than one layer the per-layer results are
        averaged.  LSTM states are merged separately for ``c`` and ``h``.

        Args:
            fw_states: Sequence with one final forward state per layer.
            bw_states: Sequence with one final backward state per layer.

        Returns:
            An ``LSTMStateTuple`` for LSTM-style states, otherwise a tensor.
        """
        layer_pairs = list(zip(fw_states, bw_states))

        def merge(select, name):
            if len(layer_pairs) == 1:
                fw, bw = layer_pairs[0]
                return tf.concat((select(fw), select(bw)), 1, name=name)
            per_layer = [tf.concat((select(fw), select(bw)), 1) for fw, bw in layer_pairs]
            return tf.reduce_mean(per_layer, axis=0, name=name)

        if isinstance(fw_states[0], tf.contrib.rnn.LSTMStateTuple):
            merged_c = merge(lambda state: state.c, 'bidirectional_concat_c')
            merged_h = merge(lambda state: state.h, 'bidirectional_concat_h')
            return tf.contrib.rnn.LSTMStateTuple(c=merged_c, h=merged_h)
        return merge(lambda state: state, 'bidirectional_concat')

    def encoding(self, embeddings):
        """Build the encoder and set ``enc_output`` and ``enc_state``.

        Args:
            embeddings: Embedding matrix passed to ``tf.nn.embedding_lookup``.

        Raises:
            ValueError: If ``cell_type`` is not supported.
        """
        self._validate_cell_type()
        with tf.variable_scope("Encoding", reuse=tf.AUTO_REUSE):
            self._create_encoder()
            enc_embed_input = tf.nn.embedding_lookup(embeddings, self.encoder_inputs)

            if self.bidirectional:
                # The outputs come back already concatenated over both
                # directions; the states are tuples with one entry per
                # layer, even when there is only a single layer.
                enc_output, enc_fw_state, enc_bw_state = \
                tf.contrib.rnn.stack_bidirectional_dynamic_rnn(self.cells_fw,
                                                               self.cells_bw,
                                                               enc_embed_input,
                                                               sequence_length=self.encoder_lengths,
                                                               dtype=tf.float32)
                self.enc_output = enc_output
                merged_state = self._merge_bidirectional_states(enc_fw_state, enc_bw_state)
                # The decoder is a MultiRNNCell and expects one state per
                # decoder layer; every layer starts from the merged state.
                self.enc_state = (merged_state,) * self.decoder_layers
            else:
                # dynamic_rnn needs a single cell, not a list of cells.
                stacked = tf.contrib.rnn.MultiRNNCell(self.cells_fw)
                # NOTE: the per-layer encoder states are handed to the
                # decoder unchanged, so this mode presumably requires
                # num_enc_layer == num_dec_layer.
                self.enc_output, self.enc_state = tf.nn.dynamic_rnn(stacked,
                                                                    enc_embed_input,
                                                                    sequence_length=self.encoder_lengths,
                                                                    dtype=tf.float32)

    def decoding(self, embeddings):
        """Build the training and inference decoders and run ``dynamic_decode``.

        Sets ``train_decoder``, ``inference_decoder``, ``training_logits``
        and ``inference_logits``.  Both decoders are limited to
        ``max_dec_length`` steps, so ``decoder_lengths`` has to be fed for
        inference as well.

        Args:
            embeddings: Embedding matrix used for the decoder inputs and by
                the greedy inference helper.

        Raises:
            ValueError: If ``decode_mech``, ``cell_type`` or ``attention``
                is not supported.
        """
        # Fail before any graph nodes are added instead of hitting an
        # unbound `inference_helper` further down.
        if self.decode_mechanism not in self._SUPPORTED_DECODERS:
            raise ValueError("Unsupported decode_mech {!r}; expected one of {}".format(
                self.decode_mechanism, self._SUPPORTED_DECODERS))

        with tf.variable_scope("Decoding", reuse=tf.AUTO_REUSE):
            self._create_decoder()
            self._decoder_input_processing()
            dec_embed_input = tf.nn.embedding_lookup(embeddings, self.decoder_inputs)

            train_helper = tf.contrib.seq2seq.TrainingHelper(dec_embed_input,
                                                             sequence_length=self.decoder_lengths,
                                                             name='train_helper')
            self.train_decoder = tf.contrib.seq2seq.BasicDecoder(cell=self.dec_cells,
                                                                 helper=train_helper,
                                                                 initial_state=self.initial_state,
                                                                 output_layer=self.output_layer)

            start_tokens = tf.tile(tf.constant([self.vocab_to_int['<GO>']], dtype=tf.int32),
                                   [self.batch_size],
                                   name='start_tokens')
            inference_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(embeddings,
                                                                        start_tokens,
                                                                        self.vocab_to_int['<EOS>'])
            self.inference_decoder = tf.contrib.seq2seq.BasicDecoder(self.dec_cells,
                                                                     inference_helper,
                                                                     self.initial_state,
                                                                     self.output_layer)

            self.training_logits, training_state, training_lengths = \
            tf.contrib.seq2seq.dynamic_decode(self.train_decoder,
                                              impute_finished=True,
                                              maximum_iterations=self.max_dec_length)

            self.inference_logits, inference_state, inference_lengths = \
            tf.contrib.seq2seq.dynamic_decode(self.inference_decoder,
                                              impute_finished=True,
                                              maximum_iterations=self.max_dec_length)

In [6]:
def model_gen(enc_units=128, num_enc_layer=3, num_dec_layer=1, vocab_size=1000,
              embed_dim=300, batch_size=32, word2int=None, bidirectional=True,
              cell_type='LSTM', attention='bahdanau', beam_width=10,
              decode_mech='Greedy', coverage=False, embedding_matrix=None):
    """Build a ``Seq2Seq`` model in the current default graph and return it.

    Called without arguments this uses the same hyperparameters that were
    previously hard-coded in the function body.

    Args:
        enc_units: Units per encoder cell.
        num_enc_layer: Number of encoder layers.
        num_dec_layer: Number of decoder layers.
        vocab_size: Vocabulary size (rows of the embedding matrix, logits).
        embed_dim: Embedding dimension (columns of the embedding matrix).
        batch_size: Fixed batch size the decoder graph is built for.
        word2int: Token-to-id mapping; defaults to a minimal mapping with
            '<GO>', '<EOS>' and '<PAD>'.
        bidirectional: Whether the encoder is bidirectional.
        cell_type: RNN cell type name forwarded to ``Seq2Seq``.
        attention: Attention mechanism name forwarded to ``Seq2Seq``.
        beam_width: Forwarded to ``Seq2Seq``.
        decode_mech: Inference decoding mechanism forwarded to ``Seq2Seq``.
        coverage: Forwarded to ``Seq2Seq``.
        embedding_matrix: Embedding matrix to use; defaults to an all-ones
            float32 array of shape ``(vocab_size, embed_dim)``.

    Returns:
        The ``Seq2Seq`` instance after placeholders, encoder and decoder
        have been added to the graph.
    """
    if word2int is None:
        word2int = {'<GO>':1,'<EOS>':2,'<PAD>':3}
    if embedding_matrix is None:
        embedding_matrix = np.ones(shape=(vocab_size,embed_dim),dtype=np.float32)

    seq2seq_ob = Seq2Seq(enc_units=enc_units,
                         num_enc_layer=num_enc_layer,
                         num_dec_layer=num_dec_layer,
                         beam_width=beam_width,
                         vocab_size=vocab_size,
                         embed_dim=embed_dim,
                         word2int=word2int,
                         batch_size=batch_size,
                         bidirectional=bidirectional,
                         cell_type=cell_type,
                         attention=attention,
                         decode_mech=decode_mech,
                         coverage=coverage)
    seq2seq_ob._init_placeholders()
    # encoding() and decoding() call _create_encoder() / _create_decoder()
    # themselves.  Calling them here as well only built a second, unused
    # set of cells plus an orphaned attention mechanism and output layer.
    seq2seq_ob.encoding(embedding_matrix)
    seq2seq_ob.decoding(embedding_matrix)

    return seq2seq_ob

In [7]:
# Initial learning rate. The value is fed through the `lr` placeholder on
# every training step, so changing it between steps takes effect.
learning_rate = 0.005
# Build the graph
train_graph = tf.Graph()
# Set the graph to default to ensure that it is ready for training
with train_graph.as_default():

    model_ob = model_gen()

    # Short names for the model's placeholders; these are the feed_dict keys
    # used by the training and inference loops.
    input_data = model_ob.encoder_inputs
    targets = model_ob.decoder_targets
    text_length = model_ob.encoder_lengths
    summary_length = model_ob.decoder_lengths
    lr = model_ob.lr
    keep_prob = model_ob.keep_prob

    # Create tensors for the training logits and inference logits
    training_logits = tf.identity(model_ob.training_logits.rnn_output, 'logits')
    inference_logits = tf.identity(model_ob.inference_logits.sample_id, name='predictions')

    # Create the weights for sequence_loss
    masks = tf.sequence_mask(model_ob.decoder_lengths,model_ob.max_dec_length,dtype=tf.float32,name='masks')

    with tf.name_scope("optimization"):
        # Loss function
        cost = tf.contrib.seq2seq.sequence_loss(training_logits,model_ob.decoder_targets,masks)

        # Optimizer. Use the learning-rate placeholder rather than the Python
        # float: a float is baked into the graph as a constant, which made
        # the per-epoch learning-rate decay a no-op.
        optimizer = tf.train.AdamOptimizer(lr)

        # Gradient Clipping
        gradients = optimizer.compute_gradients(cost)
        capped_gradients = [(tf.clip_by_value(grad, -5., 5.), var) for grad, var in gradients if grad is not None]
        train_op = optimizer.apply_gradients(capped_gradients)
print("Graph is built.")

Graph is built.


In [None]:
learning_rate_decay = 0.95
min_learning_rate = 0.0005
display_step = 20 # Check training loss after every 20 batches
stop_early = 0 # Number of consecutive update checks without improvement
stop = 10 # If the update loss does not decrease in 10 consecutive update checks, stop training
per_epoch = 3 # Make 3 update checks per epoch
# Clamp to at least 1: for small datasets the expression is 0 or negative,
# which would divide by zero (or trigger a check on every batch) below.
update_check = max(1, (len(sorted_texts)//batch_size//per_epoch)-1)

update_loss = 0
batch_loss = 0
summary_update_loss = [] # Record the update losses for saving improvements in the model

checkpoint = "./best_model.ckpt"


def _collect_predictions(sess, summaries, texts):
    """Run greedy inference over all batches and return the predicted ids.

    The summary lengths are still fed because the inference decoder is
    limited to the longest summary length in the batch.  Dropout is
    disabled by feeding a keep probability of 1.0.
    """
    predictions = []
    for summaries_batch, texts_batch, summaries_lengths, texts_lengths in get_batches(
            summaries, texts, batch_size):
        inf_logits = sess.run(inference_logits,
                              {input_data: texts_batch,
                               summary_length: summaries_lengths,
                               text_length: texts_lengths,
                               keep_prob: 1.0})
        predictions.extend(inf_logits.tolist())
    return predictions


with tf.Session(graph=train_graph) as sess:
    sess.run(tf.global_variables_initializer())
    # Create the saver once; building a new Saver on every improvement adds
    # another set of save/restore ops to the graph each time.
    saver = tf.train.Saver()

    # If we want to continue training a previous session
    #loader = tf.train.import_meta_graph("./" + checkpoint + '.meta')
    #loader.restore(sess, checkpoint)

    for epoch_i in range(1, epochs+1):
        update_loss = 0
        batch_loss = 0
        for batch_i, (summaries_batch, texts_batch, summaries_lengths, texts_lengths) in enumerate(
                get_batches(train_summaries, train_texts, batch_size)):
            start_time = time.time()
            _, loss = sess.run(
                [train_op, cost],
                {input_data: texts_batch,
                 targets: summaries_batch,
                 lr: learning_rate,
                 summary_length: summaries_lengths,
                 text_length: texts_lengths,
                 keep_prob: keep_probability})

            batch_loss += loss
            update_loss += loss
            end_time = time.time()
            batch_time = end_time - start_time

            if batch_i % display_step == 0 and batch_i > 0:
                print('Epoch {:>3}/{} Batch {:>4}/{} - Loss: {:>6.3f}, Seconds: {:>4.2f}'
                      .format(epoch_i,
                              epochs,
                              batch_i,
                              len(sorted_texts) // batch_size,
                              batch_loss / display_step,
                              batch_time*display_step))
                batch_loss = 0

            if batch_i % update_check == 0 and batch_i > 0:
                print("Average loss for this update:", round(update_loss/update_check,3))
                summary_update_loss.append(update_loss)

                # If the update loss is at a new minimum, save the model
                if update_loss <= min(summary_update_loss):
                    print('New Record!')
                    stop_early = 0
                    saver.save(sess, checkpoint)

                else:
                    print("No Improvement.")
                    stop_early += 1
                    if stop_early == stop:
                        break
                update_loss = 0

        # Reduce learning rate, but not below its minimum value
        learning_rate = max(learning_rate * learning_rate_decay, min_learning_rate)

        if stop_early == stop:
            print("Stopping Training.")
            break

    # Predictions on the test set. The texts are the encoder input; the
    # previous code passed the summaries for both arguments.
    # NOTE(review): `test_texts` is assumed to exist alongside
    # `test_summaries` (mirroring train_texts / train_summaries) - confirm.
    logits = _collect_predictions(sess, test_summaries, test_texts)

    # Predictions on the training set.
    logits_tr = _collect_predictions(sess, train_summaries, train_texts)