In [1]:
import tensorflow as tf
from tensorflow.contrib import layers

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  from ._conv import register_converters as _register_converters
  import pandas.util.testing as tm


In [2]:
import os
# Presumably meant to quiet TensorFlow's native (C++) log output.
# NOTE(review): `tensorflow` is already imported in the first cell, so this
# assignment may come too late to have any effect -- confirm, and if so move it
# above the TensorFlow import.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import logging
from IPython import embed

In [3]:
class Seq2seq:
    """Attention-based LSTM encoder/decoder packaged as a ``tf.estimator`` model_fn.

    Token id conventions used by the graph below:
      * id 0 is fed to the decoder as the start token,
      * id 1 is the end token and is also what sequence lengths and loss
        weights treat as "not a real token" (i.e. padding).

    Args:
        vocab_size: size of the (shared) source/target embedding table and of
            the decoder's output projection.
        residual: when true, the encoder LSTM cell is wrapped in a
            ``ResidualWrapper``.
    """

    def __init__(self, vocab_size, residual=True):
        self.residual = residual
        self.vocab_size = vocab_size

    def make_graph(self, mode, features, labels):
        """Build the training and greedy-decoding graphs and return an EstimatorSpec.

        Args:
            mode: estimator mode; only forwarded to the returned EstimatorSpec.
            features: dict with int64 tensors under ``'input'`` and ``'output'``.
                # assumes both are [batch, time] id matrices padded with 1 -- TODO confirm
            labels: unused.

        Returns:
            ``tf.estimator.EstimatorSpec`` whose predictions are the greedy
            decoder's sample ids.

        NOTE(review): the loss and train_op are built regardless of ``mode``,
        so ``features['output']`` must be fed even for prediction.
        """
        embed_dim = 256
        num_units = 256
        end_token = 1          # also used as the padding id
        max_decode_steps = 30  # hard cap on decoder iterations

        source_ids, target_ids = features['input'], features['output']
        batch_size = tf.shape(source_ids)[0]
        start_tokens = tf.zeros([batch_size], dtype=tf.int64)
        # Decoder input is the target shifted right by one start token.
        train_output = tf.concat([tf.expand_dims(start_tokens, 1), target_ids], 1)
        # Lengths are the number of entries that differ from the end/pad id.
        input_lengths = tf.reduce_sum(tf.to_int32(tf.not_equal(source_ids, end_token)), 1)
        output_lengths = tf.reduce_sum(tf.to_int32(tf.not_equal(train_output, end_token)), 1)
        input_embed = layers.embed_sequence(
            source_ids, vocab_size=self.vocab_size, embed_dim=embed_dim, scope='embed')
        output_embed = layers.embed_sequence(
            train_output, vocab_size=self.vocab_size, embed_dim=embed_dim, scope='embed', reuse=True)
        # Fetch the shared embedding matrix so the greedy helper can look ids up in it.
        with tf.variable_scope('embed', reuse=True):
            embeddings = tf.get_variable('embeddings')
        encoder_cell = tf.contrib.rnn.LSTMCell(num_units=num_units)
        if self.residual:
            encoder_cell = tf.contrib.rnn.ResidualWrapper(encoder_cell)
        # The final encoder state is not used: the decoder starts from a zero
        # state and sees the encoder only through attention over its outputs.
        encoder_outputs, encoder_final_state = tf.nn.dynamic_rnn(
            encoder_cell, input_embed, dtype=tf.float32)

        def decode(helper, scope, reuse=None):
            """Run an attention decoder driven by ``helper`` inside ``scope``."""
            # Decoder is partially based on @ilblackdragon//tf_example/seq2seq.py
            with tf.variable_scope(scope, reuse=reuse):
                attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
                    num_units=num_units, memory=encoder_outputs,
                    memory_sequence_length=input_lengths)
                decoder_cell = tf.contrib.rnn.LSTMCell(num_units=num_units)
                # Floor division keeps the layer size an int; `/` yields the
                # float 128.0 under Python 3.
                attn_cell = tf.contrib.seq2seq.AttentionWrapper(
                    decoder_cell, attention_mechanism, attention_layer_size=num_units // 2)
                out_cell = tf.contrib.rnn.OutputProjectionWrapper(attn_cell, self.vocab_size, reuse=reuse)
                decoder = tf.contrib.seq2seq.BasicDecoder(
                    cell=out_cell, helper=helper,
                    initial_state=out_cell.zero_state(
                        dtype=tf.float32, batch_size=batch_size))
                outputs = tf.contrib.seq2seq.dynamic_decode(
                    decoder=decoder, output_time_major=False,
                    impute_finished=True, maximum_iterations=max_decode_steps)
                return outputs[0]

        train_helper = tf.contrib.seq2seq.TrainingHelper(output_embed, output_lengths)
        pred_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
            embeddings, start_tokens=tf.to_int32(start_tokens), end_token=end_token)
        # Both decoders live in the same 'decode' scope; the second call reuses
        # the variables created by the first.
        train_outputs = decode(train_helper, 'decode')
        pred_outputs = decode(pred_helper, 'decode', reuse=True)

        tf.identity(train_outputs.sample_id[0], name='train_pred')
        # A target position is weighted 1 when the decoder input at that step
        # is not the end/pad id, 0 otherwise.
        weights = tf.to_float(tf.not_equal(train_output[:, :-1], end_token))
        loss = tf.contrib.seq2seq.sequence_loss(train_outputs.rnn_output, target_ids, weights=weights)
        train_op = layers.optimize_loss(
            loss, tf.train.get_global_step(),
            optimizer='Adam',
            learning_rate=0.001,
            summaries=['loss', 'learning_rate'])

        tf.identity(pred_outputs.sample_id[0], name='predict')
        return tf.estimator.EstimatorSpec(
            mode=mode, predictions=pred_outputs.sample_id, loss=loss, train_op=train_op)

In [11]:
class Data:
    """Vocabulary lookup plus placeholder/feed helpers for the seq2seq model.

    The vocabulary file holds one token per line; a token's id is its
    zero-based line number. Ids 1 and 2 are used as the end-of-sequence /
    padding id and the unknown-token id respectively.

    Args:
        input_filename: path of the source-side corpus, one sentence per line.
        output_filename: path of the target-side corpus, line-aligned with
            ``input_filename``.
        vocab_filename: path of the vocabulary file.

    NOTE(review): ``vocab_size`` is ``len(self.vocab)``; if the vocabulary file
    contains duplicate lines this is smaller than the largest id + 1 -- confirm
    the file is deduplicated.
    """

    def __init__(self, input_filename, output_filename, vocab_filename):
        self.input_filename = input_filename
        self.output_filename = output_filename
        self.vocab_filename = vocab_filename

        # create vocab and reverse vocab maps
        self.vocab = {}
        self.rev_vocab = {}
        self.END_TOKEN = 1
        self.UNK_TOKEN = 2
        # Toggled by the feed_fn returned from single() so that it yields
        # exactly one batch and then raises StopIteration.
        self.FLIP = False
        with open(vocab_filename) as f:
            for idx, line in enumerate(f):
                token = line.strip()
                self.vocab[token] = idx
                self.rev_vocab[idx] = token
        self.vocab_size = len(self.vocab)

    def tokenize_and_map(self, line):
        """Split ``line`` on whitespace and map each token to its vocabulary id.

        Splitting on any whitespace (rather than on a single space) keeps the
        trailing newline of lines read from the corpus files from sticking to
        the last token, which would otherwise always be mapped to UNK_TOKEN.
        Tokens missing from the vocabulary map to ``self.UNK_TOKEN``.
        """
        return [self.vocab.get(token, self.UNK_TOKEN) for token in line.split()]

    def prepare(self, text):
        """Return ``text`` as a batch of one id sequence (no end token is appended)."""
        return [self.tokenize_and_map(text)]

    @staticmethod
    def _placeholder_features():
        """Create the 'input'/'output' placeholders shared by every input_fn.

        Returns the ``(features, labels)`` pair expected from an input_fn;
        labels are always ``None``.
        """
        inp = tf.placeholder(tf.int64, shape=[None, None], name='input')
        output = tf.placeholder(tf.int64, shape=[None, None], name='output')
        # Named views of the first row of each batch.
        tf.identity(inp[0], 'source')
        tf.identity(output[0], 'target')
        return {'input': inp, 'output': output}, None

    def single(self, sentence):
        """Build ``(input_fn, feed_fn)`` that feed one sentence exactly once.

        The first call to ``feed_fn`` returns the feed dict; the next call
        raises ``StopIteration`` (driven by ``self.FLIP``).
        """
        tokens = self.tokenize_and_map(sentence)

        def input_fn():
            return self._placeholder_features()

        def feed_fn():
            source = [tokens]
            # The same batch is fed as 'output' only so that the placeholder
            # has a value during prediction; it is not a real target.
            self.FLIP = not self.FLIP
            if not self.FLIP:
                raise StopIteration

            return {'input:0': source, 'output:0': source}
        return input_fn, feed_fn

    def make_input_fn(self, batch_size=32, max_length=30):
        """Build ``(input_fn, feed_fn)`` that cycle through the training corpus.

        Args:
            batch_size: number of sentence pairs per feed dict.
            max_length: maximum length of each sequence including the
                appended end token. Seq2seq.make_graph caps decoding at 30
                steps, so values above 30 are presumably not useful.

        Returns:
            ``input_fn`` creating the placeholders and ``feed_fn`` producing a
            feed dict whose rows are right-padded with ``END_TOKEN`` to the
            longest row in the batch.

        Raises:
            ValueError: from ``feed_fn`` if the corpus files yield no line
                pairs (previously this looped forever).
        """
        def input_fn():
            return self._placeholder_features()

        def sampler():
            while True:
                produced = False
                with open(self.input_filename) as finput, open(self.output_filename) as foutput:
                    for source, target in zip(finput, foutput):
                        produced = True
                        yield {
                            'input': self.tokenize_and_map(source)[:max_length - 1] + [self.END_TOKEN],
                            'output': self.tokenize_and_map(target)[:max_length - 1] + [self.END_TOKEN]}
                if not produced:
                    raise ValueError(
                        'no sentence pairs found in %r / %r'
                        % (self.input_filename, self.output_filename))

        data_feed = sampler()

        def feed_fn():
            source, target = [], []
            for _ in range(batch_size):
                rec = next(data_feed)
                source.append(rec['input'])
                target.append(rec['output'])
            input_length = max((len(row) for row in source), default=0)
            output_length = max((len(row) for row in target), default=0)
            for row in source:
                row += [self.END_TOKEN] * (input_length - len(row))
            for row in target:
                row += [self.END_TOKEN] * (output_length - len(row))
            return {'input:0': source, 'output:0': target}
        return input_fn, feed_fn

    def get_formatter(self, keys):
        """Return a function rendering ``values[key]`` id sequences as text lines.

        Ids missing from the reverse vocabulary are shown as ``<UNK>``; the
        ``<S>`` and ``</S>`` markers are removed from the rendered text.
        """
        def to_str(sequence):
            return ' '.join(self.rev_vocab.get(x, "<UNK>") for x in sequence)

        def format_values(values):
            res = []
            for key in keys:
                text = to_str(values[key]).replace('</S>', '').replace('<S>', '')
                res.append("****%s == %s" % (key, text))
            return '\n' + '\n'.join(res)
        return format_values

In [16]:
class Predict:
    """Load a trained Seq2seq estimator from a checkpoint directory and decode sentences.

    Args:
        checkpoint: model directory handed to ``tf.estimator.Estimator``.
    """

    def __init__(self, checkpoint='checkpoint'):
        self.data = Data('train_source.txt', 'train_target.txt', 'train_vocab.txt')
        seq2seq = Seq2seq(self.data.vocab_size)
        estimator = tf.estimator.Estimator(model_fn=seq2seq.make_graph, model_dir=checkpoint)

        def serving_input_fn():
            """Expose the 'input' and 'output' placeholders as the serving inputs."""
            source_ph = tf.placeholder(tf.int64, shape=[None, None], name='input')
            target_ph = tf.placeholder(tf.int64, shape=[None, None], name='output')
            tf.identity(source_ph[0], 'source')
            tf.identity(target_ph[0], 'target')
            tensors = {'input': source_ph, 'output': target_ph}
            # The same dict serves as both the features and the receiver tensors.
            return tf.estimator.export.ServingInputReceiver(tensors, tensors)

        self.predictor = tf.contrib.predictor.from_estimator(estimator, serving_input_fn)

    def infer(self, sentence):
        """Decode ``sentence`` and return the predicted words joined by spaces.

        The raw prediction dict is printed as a side effect. Predicted ids
        that are not greater than 2 are left out of the result, and ids
        missing from the reverse vocabulary are rendered as ``<UNK>``.
        """
        batch = self.data.prepare(sentence)
        # The source batch doubles as the value for the 'output' placeholder.
        predictor_prediction = self.predictor({"input": batch, "output": batch})
        print(predictor_prediction)
        words = []
        for token_id in predictor_prediction['output'][0]:
            if token_id > 2:
                words.append(self.data.rev_vocab.get(token_id, '<UNK>'))
        return ' '.join(words)


In [17]:
# Build the predictor; with no argument Predict() uses the 'checkpoint' model directory.
P = Predict()

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': 'checkpoint', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x00000209B1773160>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_ini

In [18]:
# Decode one sentence; infer() also prints the raw prediction dict shown below.
res = P.infer('what be the symbol of magnesium sulphate')

{'output': array([[  533,  4897,  4897,  4897,  4897,  4897,  4897,  4897,  4897,
         4897,  4897,  1130,  1130,  1130, 16004, 16004, 16004,  9812,
        12323, 17747, 17747, 18364,  6247,  6247,  6247,  6247, 15452,
         6247, 15452,  6247]])}


In [15]:
# Display the decoded sentence.
# NOTE(review): this cell's execution count (15) is lower than the previous
# cell's (18), so the output shown may come from an earlier run -- re-run the
# notebook top to bottom to confirm.
res

'weighing erect erect erect squirrl squirrl squirrl squirrl squirrl squirrl squirrl squirrl squirrl squirrl squirrl squirrl squirrl squirrl hooking hooking hooking hooking hooking eye squirrl squirrl weighing "troop "troop furinture'