# In [1]:
import tensorflow as tf

# In [2]:
class NNModel:
    """Attention-based sequence-to-sequence model built as a TensorFlow 1.x graph.

    The constructor builds the whole graph immediately: a bidirectional GRU
    encoder, a GRU decoder with Bahdanau attention trained with teacher
    forcing (``TrainingHelper``), a masked sequence cross-entropy loss and an
    Adadelta training op. Every tensor/op is stored in ``self.model_objects``
    under a string key (e.g. ``"encoder_inputs"``, ``"loss"``,
    ``"optimizer"``, ``"global_step"``).

    Args:
        vocab_size: Number of token ids; rows of the embedding table and
            width of the decoder's output projection.
        batch_size: Static batch dimension used by every placeholder.
        learning_rate: Learning rate passed to ``AdadeltaOptimizer``.
        state_size: Number of units in each GRU cell.
        num_layers: Stored on the instance; not used by the graph built here.
        embedding_size: Width of the token embeddings.
        keep_prob: ``output_keep_prob`` for the encoder/decoder dropout
            wrappers.
        emb_init: Initializer for the embedding variable.
        epsilon: ``epsilon`` argument of ``AdadeltaOptimizer``.
        dtype: dtype used for the encoder RNN and the decoder's zero state.
    """

    def __init__(
        self, vocab_size, batch_size, learning_rate, state_size,
        num_layers, embedding_size, keep_prob, emb_init, epsilon,
        dtype
    ):
        self.vocab_size = vocab_size
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.state_size = state_size
        self.num_layers = num_layers
        self.embedding_size = embedding_size
        self.keep_prob = keep_prob
        self.emb_init = emb_init
        self.epsilon = epsilon
        self.dtype = dtype

        # Name -> tensor/op registry filled in by the _build steps.
        self.model_objects = {}

        self._build()

    def _init_placeholders(self):
        """Create the training-step counter and all input placeholders.

        Token placeholders are int32 ``[batch_size, None]`` (dynamic time
        dimension); length placeholders are int32 ``[batch_size]``.
        """
        # Incremented by the training op built in _seq2seq_optimize.
        self.model_objects["global_step"] = tf.Variable(
            0, trainable=False, name="global_step"
        )

        token_shape = [self.batch_size, None]
        self.model_objects["encoder_inputs"] = tf.placeholder(tf.int32, shape=token_shape)
        self.model_objects["decoder_inputs"] = tf.placeholder(tf.int32, shape=token_shape)
        self.model_objects["decoder_targets"] = tf.placeholder(tf.int32, shape=token_shape)

        self.model_objects["encoder_len"] = tf.placeholder(tf.int32, shape=[self.batch_size])
        self.model_objects["decoder_len"] = tf.placeholder(tf.int32, shape=[self.batch_size])

        # NOTE(review): beam_tok and prev_att are not consumed by the graph
        # built in this class; presumably they are fed by decoding code
        # elsewhere -- TODO confirm before removing.
        self.model_objects["beam_tok"] = tf.placeholder(tf.int32, shape=[self.batch_size])
        self.model_objects["prev_att"] = tf.placeholder(
            tf.float32, shape=[self.batch_size, self.state_size * 2]
        )

    def _build_cells(self):
        """Create the dropout-wrapped GRU cells for the encoder and decoder."""
        # TODO: Change to LSTM
        encoder_fw_cell = tf.contrib.rnn.GRUCell(self.state_size)
        encoder_bw_cell = tf.contrib.rnn.GRUCell(self.state_size)
        decoder_cell = tf.contrib.rnn.GRUCell(self.state_size)

        # Forward/backward cells for the bidirectional encoder.
        self.model_objects["encoder_fw_cell"] = tf.contrib.rnn.DropoutWrapper(
            encoder_fw_cell, output_keep_prob=self.keep_prob
        )
        self.model_objects["encoder_bw_cell"] = tf.contrib.rnn.DropoutWrapper(
            encoder_bw_cell, output_keep_prob=self.keep_prob
        )
        # Further wrapped with attention and an output projection later.
        self.model_objects["decoder_cell"] = tf.contrib.rnn.DropoutWrapper(
            decoder_cell, output_keep_prob=self.keep_prob
        )

    def _seq2seq_embedding(self):
        """Create the embedding table shared by the encoder and decoder."""
        self.model_objects["embedding"] = tf.get_variable(
            "embedding", [self.vocab_size, self.embedding_size], initializer=self.emb_init
        )

    def _seq2seq_encoder(self):
        """Embed the encoder tokens and run the bidirectional GRU encoder.

        Stores ``encoder_outputs`` (fw, bw) and ``encoder_states`` (fw, bw)
        exactly as returned by ``bidirectional_dynamic_rnn``.
        """
        with tf.variable_scope("encoder"):
            encoder_inputs_emb = tf.nn.embedding_lookup(
                self.model_objects["embedding"], self.model_objects["encoder_inputs"]
            )

            outputs, states = tf.nn.bidirectional_dynamic_rnn(
                self.model_objects["encoder_fw_cell"],
                self.model_objects["encoder_bw_cell"],
                encoder_inputs_emb,
                sequence_length=self.model_objects["encoder_len"],
                dtype=self.dtype
            )
            self.model_objects["encoder_outputs"] = outputs
            self.model_objects["encoder_states"] = states

    def _seq2seq_init_state(self):
        """Derive the decoder initial state and the attention memory.

        ``init_state`` projects the concatenated final fw/bw encoder states
        to ``state_size``; ``att_states`` concatenates the fw/bw encoder
        outputs along the feature axis (width ``2 * state_size``).
        """
        with tf.variable_scope("init_state"):
            # NOTE: fully_connected applies its default activation (ReLU).
            init_state = tf.contrib.layers.fully_connected(
                tf.concat(self.model_objects["encoder_states"], axis=1),
                self.state_size
            )

            # bidirectional_dynamic_rnn leaves the batch dimension as None,
            # so restore the static shapes explicitly.
            self.model_objects["init_state"] = init_state
            self.model_objects["init_state"].set_shape([self.batch_size, self.state_size])

            self.model_objects["att_states"] = tf.concat(
                self.model_objects["encoder_outputs"], axis=2
            )
            self.model_objects["att_states"].set_shape(
                [self.batch_size, None, self.state_size * 2]
            )

    def _seq2seq_attention(self):
        """Wrap the decoder cell with Bahdanau attention over ``att_states``.

        Also builds ``wrapper_state``: the wrapper's zero state with its cell
        state replaced by the projected encoder state.
        """
        with tf.variable_scope("attention"):
            attention = tf.contrib.seq2seq.BahdanauAttention(
                self.state_size,
                self.model_objects["att_states"],
                memory_sequence_length=self.model_objects["encoder_len"]
            )

            self.model_objects["decoder_cell"] = tf.contrib.seq2seq.AttentionWrapper(
                self.model_objects["decoder_cell"],
                attention,
                attention_layer_size=self.state_size * 2
            )

            # TODO: possible problem here (original author's note, translated
            # from Portuguese: "maybe there is a problem here").
            self.model_objects["wrapper_state"] = self.model_objects["decoder_cell"].zero_state(
                self.batch_size, self.dtype
            ).clone(cell_state=self.model_objects["init_state"])

    def _seq2seq_decoder(self):
        """Run the attention decoder with teacher forcing and store the logits.

        ``outputs`` holds the vocabulary-sized logits from the output
        projection.
        """
        with tf.variable_scope("decoder"):
            decoder_inputs_emb = tf.nn.embedding_lookup(
                self.model_objects["embedding"], self.model_objects["decoder_inputs"]
            )

            self.model_objects["decoder_cell"] = tf.contrib.rnn.OutputProjectionWrapper(
                self.model_objects["decoder_cell"], self.vocab_size
            )

            # TrainingHelper feeds the ground-truth decoder inputs at each step.
            helper = tf.contrib.seq2seq.TrainingHelper(
                decoder_inputs_emb, self.model_objects["decoder_len"]
            )
            decoder = tf.contrib.seq2seq.BasicDecoder(
                self.model_objects["decoder_cell"], helper,
                self.model_objects["wrapper_state"]
            )

            outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder)

            self.model_objects["outputs"] = outputs.rnn_output

    def _seq2seq_loss(self):
        """Masked cross-entropy loss, summed over tokens and averaged over batch."""
        with tf.variable_scope("loss"):
            # Zero weight for padding positions beyond each decoder_len.
            weights = tf.sequence_mask(
                self.model_objects["decoder_len"], dtype=tf.float32
            )

            # NOTE(review): sequence_loss needs logits and targets to share
            # the time dimension; this assumes decoder_targets is padded to
            # exactly max(decoder_len) -- TODO confirm with the data feeder.
            loss_t = tf.contrib.seq2seq.sequence_loss(
                self.model_objects["outputs"], self.model_objects["decoder_targets"], weights,
                average_across_timesteps=False,
                average_across_batch=False
            )

            self.model_objects["loss"] = tf.reduce_sum(loss_t) / self.batch_size

            tf.summary.scalar('loss', self.model_objects["loss"])

    def _seq2seq_optimize(self):
        """Create the Adadelta training op, which also advances ``global_step``."""
        with tf.variable_scope("optimizer"):
            optimizer = tf.train.AdadeltaOptimizer(
                self.learning_rate, epsilon=self.epsilon
            )
            # Passing global_step makes each training run increment it;
            # without it the counter created in _init_placeholders stays 0.
            self.model_objects["optimizer"] = optimizer.minimize(
                self.model_objects["loss"],
                global_step=self.model_objects["global_step"]
            )

    def _build(self):
        """Build the full graph; each step reads keys written by earlier ones."""
        self._init_placeholders()
        self._build_cells()
        self._seq2seq_embedding()
        self._seq2seq_encoder()
        self._seq2seq_init_state()
        self._seq2seq_attention()
        self._seq2seq_decoder()
        self._seq2seq_loss()
        self._seq2seq_optimize()