In [None]:
import import_ipynb
import numpy as np
import tensorflow as tf
from tensorflow import layers
# from tensorflow.python.util import nest
from tensorflow.python.ops import array_ops
from tensorflow.contrib import seq2seq
from tensorflow.contrib.seq2seq import BahdanauAttention
from tensorflow.contrib.seq2seq import LuongAttention
from tensorflow.contrib.seq2seq import AttentionWrapper
from tensorflow.contrib.seq2seq import BeamSearchDecoder
from tensorflow.contrib.rnn import LSTMCell
from tensorflow.contrib.rnn import GRUCell
from tensorflow.contrib.rnn import MultiRNNCell
from tensorflow.contrib.rnn import DropoutWrapper
from tensorflow.contrib.rnn import ResidualWrapper

from word_sequence import WordSequence
from data_utilis import get_embed_device

In [None]:
class Sequence2Sequence(object):
    def __init__(self,
                 input_vocab_size,
                 target_vocab_size,
                 batch_size=32,
                 embedding_size=300,
                 mode='train',
                 hidden_units=256,
                 depth=1,
                 beam_width=0,
                 cell_type='lstm',
                 dropout=0.2,
                 use_dropout=False,
                 use_residual=False,
                 optimizer='adam',
                 learning_rate=1e-3,
                 min_learning_rate=1e-6,
                 decay_steps=500000,
                 max_gradient_norm=5.0,
                 max_decode_step=None,
                 attention_type='Bahdanau',
                 bidirectional=False,
                 time_major=False,
                 seed=0,
                 parallel_iterations=None,
                 share_embedding=False,
                 pretrained_embedding=False):
        """Record the hyper-parameters, validate them, then build the model.

        Every argument is stored on the instance under the same name unless
        noted otherwise below.

        Args:
            input_vocab_size: Size of the input vocabulary.
            target_vocab_size: Size of the target vocabulary. Must equal
                ``input_vocab_size`` when ``share_embedding`` is true and
                must be strictly greater than ``beam_width``.
            batch_size: Batch size; also the fallback value for
                ``parallel_iterations``.
            embedding_size: Stored for later use (presumably the embedding
                dimension -- confirm against ``build_model``).
            mode: Exactly ``'train'`` or ``'decode'`` (checked
                case-sensitively).
            hidden_units: Stored for later use when the network is built.
            depth: Stored for later use when the network is built.
            beam_width: A value greater than 0 sets
                ``self.use_beamsearch_decode`` to ``True``.
            cell_type: ``'gru'`` or ``'lstm'`` in any letter case; stored
                lower-cased.
            dropout: Must satisfy ``0 <= dropout < 1``. Stored as
                ``self.keep_prob = 1.0 - dropout``.
            use_dropout: Stored flag.
            use_residual: Stored flag.
            optimizer: One of ``adadelta``, ``adam``, ``rmsprop``,
                ``momentum`` or ``sgd`` in any letter case; stored as given.
            learning_rate: Stored as given.
            min_learning_rate: Stored as given.
            decay_steps: Stored as given.
            max_gradient_norm: Stored as given.
            max_decode_step: Stored as given.
            attention_type: ``'bahdanau'`` or ``'luong'`` in any letter case;
                stored as given (not lower-cased).
            bidirectional: Stored flag.
            time_major: Stored flag.
            seed: Stored as given.
            parallel_iterations: Used when it is an ``int``; any other value
                (including ``None``) falls back to ``batch_size``.
            share_embedding: Stored flag; requires equal vocabulary sizes.
            pretrained_embedding: Stored flag.

        Raises:
            AssertionError: If any of the constraints above is violated.

        Side effects:
            Creates a scalar float32 placeholder named ``keep_prob`` and a
            non-trainable variable named ``global_step`` initialised to 0,
            then calls ``self.build_model()``.
        """
        # --- plain hyper-parameter bookkeeping ---------------------------
        self.input_vocab_size = input_vocab_size
        self.target_vocab_size = target_vocab_size
        self.batch_size = batch_size
        self.embedding_size = embedding_size
        self.hidden_units = hidden_units
        self.depth = depth
        self.cell_type = cell_type.lower()
        self.use_dropout = use_dropout
        self.use_residual = use_residual
        self.attention_type = attention_type
        self.mode = mode
        self.optimizer = optimizer
        self.learning_rate = learning_rate
        self.min_learning_rate = min_learning_rate
        self.decay_steps = decay_steps
        self.max_gradient_norm = max_gradient_norm
        self.keep_prob = 1.0 - dropout
        self.bidirectional = bidirectional
        self.seed = seed
        self.pretrained_embedding = pretrained_embedding
        # Anything that is not an int (typically None) falls back to batch_size.
        if isinstance(parallel_iterations, int):
            self.parallel_iterations = parallel_iterations
        else:
            self.parallel_iterations = batch_size
        self.time_major = time_major
        self.share_embedding = share_embedding
        self.beam_width = beam_width
        # Beam search is switched on purely by a positive beam width.
        self.use_beamsearch_decode = bool(self.beam_width > 0)
        self.max_decode_step = max_decode_step

        self.initializer = tf.random_uniform_initializer(
            -0.05, 0.05, dtype=tf.float32
        )
        # self.initializer = None

        # --- validation ---------------------------------------------------
        # All checks run before any graph node is created. TensorFlow
        # uniquifies duplicate node names, so a constructor that failed after
        # creating `global_step` would leave a retry in the same graph with
        # `global_step_1` (and a stray `keep_prob` placeholder).
        assert self.cell_type in ('gru', 'lstm'), 'cell_type should be gru or lstm'

        if share_embedding:
            assert input_vocab_size == target_vocab_size, 'the two vocb_size must be the same if share_embedding is true'

        assert mode in ('train', 'decode'), 'mode must be train or decode'

        assert dropout >= 0.0 and dropout < 1.0, '0 <= dropout < 1'

        assert attention_type.lower() in ('bahdanau', 'luong'), "attention_type must be bahdanau or luong"

        assert beam_width < target_vocab_size, "beam_width should not bigger than target_vocab_size"

        assert self.optimizer.lower() in ('adadelta', 'adam', 'rmsprop', 'momentum', 'sgd'), 'optimizer should be one of adadelta, adam, rmsprop, momentum and sgd'

        # --- graph-level state --------------------------------------------
        self.keep_prob_placeholder = tf.placeholder(
            tf.float32,
            shape=[],
            name='keep_prob'
        )

        self.global_step = tf.Variable(
            0, trainable=False, name='global_step'
        )

        self.build_model()
        
    def build_model(self):
        """Build the model graph; called as the last step of ``__init__``.

        This method is still a stub. It raises instead of silently doing
        nothing so that constructing the model cannot appear to succeed
        while no network has actually been built.

        Raises:
            NotImplementedError: Always, until the graph construction is
                written.
        """
        # A `def` with no statements is a syntax error in Python, so the stub
        # needs a real body for the cell to be executable at all.
        raise NotImplementedError('build_model is not implemented yet')

    def init_placeholders(self):
        