# In [None]:
import numpy as np
import tensorflow as tf
import helpers
from spacy.en import English

# spaCy English pipeline; its vocabulary is the source of the word vectors
# used throughout this script.
parser = English()

from numpy import dot
from numpy.linalg import norm

# you can access known words from the parser's vocabulary
# nasa = parser.vocab['NASA']

def cosine(v1, v2):
    """Return the cosine similarity of vectors *v1* and *v2*.

    Computed as ``dot(v1, v2) / (norm(v1) * norm(v2))``.

    When either vector has zero norm the similarity is undefined; 0.0 is
    returned in that case instead of the NaN (plus runtime warning) that a
    division by zero would produce, so the result stays usable as a sort key.
    """
    denominator = norm(v1) * norm(v2)
    if denominator == 0:
        return 0.0
    return dot(v1, v2) / denominator

# Every vocabulary entry that has a vector and whose text passes
# str.islower(), de-duplicated through a set before becoming a list.
allWords = list(set(
    lexeme for lexeme in parser.vocab
    if lexeme.has_vector and lexeme.orth_.islower()
))

# sort by similarity to NASA


# Start from an empty default graph and open an interactive session on it.
tf.reset_default_graph()
sess = tf.InteractiveSession()
#
# Reserved token ids.
# NOTE(review): no active code in this file reads PAD or EOS; they only
# appear in commented-out lines.
PAD = 0
EOS = 1

# Width of the output projection (columns of W, length of b).
vocab_size = 300
# NOTE(review): only referenced by the commented-out embedding matrix.
input_embedding_size = 200 # character length

encoder_hidden_units = 20
# Twice the encoder width: the decoder starts from the concatenated
# forward + backward encoder final states.
decoder_hidden_units = encoder_hidden_units * 2
#
# NOTE(review): this value is not read before batch_size is rebound further
# down (first to a tensor, later to 100).
batch_size = 10
#encoder_inputs = tf.Variable(tf.zeros(shape=[8, batch_size], dtype=tf.int32), dtype=tf.int32, name='encoder_inputs')
# NOTE(review): declared rank-2 here, yet passed to
# tf.nn.bidirectional_dynamic_rnn with time_major=True, which is documented
# to take [max_time, batch_size, depth] inputs — confirm the intended shape.
encoder_inputs = tf.placeholder(shape=[None,None], dtype=tf.float32, name='encoder_inputs')
# Passed to the encoder as sequence_length; also the base for decoder_lengths.
encoder_inputs_length = tf.placeholder(shape=None, dtype=tf.int32, name='encoder_inputs_length')

# Training targets; next_feed() supplies the same array it feeds to
# encoder_inputs.
decoder_targets = tf.placeholder(shape=[None, None], dtype=tf.float32, name='decoder_targets')
#

## Embeddings

# embeddings = tf.Variable(tf.random_uniform([vocab_size, input_embedding_size], -1.0, 1.0),
#                         dtype=tf.float32, name='embeddings')

# encoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, encoder_inputs)
#

## Encoder
from tensorflow.contrib.rnn import LSTMCell, LSTMStateTuple
#

# One LSTM cell per direction, each constructed inside its own variable scope.
with tf.variable_scope('forward'):
    encoder_cell_fw = LSTMCell(encoder_hidden_units)

with tf.variable_scope('backward'):
    encoder_cell_bw = LSTMCell(encoder_hidden_units)
#

# Run both cells over encoder_inputs. The call's result is unpacked in one
# assignment as ((fw_outputs, bw_outputs), (fw_final_state, bw_final_state)).
# time_major=True declares the inputs as laid out time-first.
((encoder_fw_outputs,
  encoder_bw_outputs),
(encoder_fw_final_state,
 encoder_bw_final_state)) = (
     tf.nn.bidirectional_dynamic_rnn(cell_fw=encoder_cell_fw,
                                     cell_bw=encoder_cell_bw,
                                     inputs=encoder_inputs,
                                     sequence_length=encoder_inputs_length,
                                     dtype=tf.float32,
                                     time_major=True)
 )

#

## concatenate the tensors along one dimension.

# Join the forward and backward halves of the encoder: outputs along axis 2,
# the c and h parts of the final state along axis 1.
encoder_outputs = tf.concat(
    [encoder_fw_outputs, encoder_bw_outputs],
    axis=2,
    name='concat_encoder_outputs',
)

encoder_final_state_c = tf.concat(
    [encoder_fw_final_state.c, encoder_bw_final_state.c],
    axis=1,
    name='concat_final_state_c',
)

encoder_final_state_h = tf.concat(
    [encoder_fw_final_state.h, encoder_bw_final_state.h],
    axis=1,
    name='concat_final_state_h',
)

# Repackage the concatenated parts as the (c, h) tuple LSTM cells use for
# their state.
encoder_final_state = LSTMStateTuple(encoder_final_state_c, encoder_final_state_h)
#


## Decoder

# Decoder LSTM; decoder_hidden_units is twice the encoder width so the
# concatenated encoder final state can serve as its initial state.
decoder_cell = LSTMCell(decoder_hidden_units)

#

# Split the dynamic shape of encoder_inputs into its two dimensions.
# Note that this rebinds batch_size (previously the int 10) to a scalar tensor.
encoder_max_time, batch_size = tf.unstack(tf.shape(encoder_inputs), name='unstack')

# The decoder runs three steps longer than each input sequence.
# NOTE(review): the +3 presumably leaves room for an EOS marker plus
# padding — confirm.
decoder_lengths = encoder_inputs_length + 3
#

## Output Projection

# Projection weights mapping a decoder output (decoder_hidden_units wide) to
# vocab_size scores; initialised uniformly in [-1, 1).
W = tf.Variable(tf.random_uniform([decoder_hidden_units, vocab_size], -1, 1),
                dtype=tf.float32, name='weight')

# Projection bias, initialised to zeros.
b = tf.Variable(tf.zeros([vocab_size]), dtype= tf.float32)
#

## Decoder via tf.nn.raw_rnn

# assert EOS == 1 and PAD == 0

# eos_time_slice = tf.ones([batch_size], dtype=tf.int32, name='EOS')
# pad_time_slice = tf.zeros([batch_size], dtype=tf.int32, name='PAD')

# # retrive rows from parm tensors.
# eos_step_embedded = tf.nn.embedding_lookup(embeddings, eos_time_slice)
# pad_step_embedded = tf.nn.embedding_lookup(embeddings, pad_time_slice)
# #


## Manually specifying loop function through time

def loop_fn_initial():
    """Build the time-zero return values for the ``tf.nn.raw_rnn`` loop function.

    Returns:
        A 5-tuple ``(elements_finished, next_input, next_cell_state,
        emit_output, next_loop_state)``.  The cell state is the encoder's
        final state; ``emit_output`` and ``next_loop_state`` are ``None``.
    """
    # A sequence counts as finished before the first step only when its
    # decoder length is <= 0.
    initial_elements_finished = (0 >= decoder_lengths)

    # raw_rnn needs an actual tensor as the first input (None cannot be fed
    # to the cell).  The width is vocab_size so that it matches the projected
    # logits (previous_output @ W + b) used as input on later steps.
    # NOTE(review): an all-zero vector stands in for a start-of-sequence
    # input here — replace with a dedicated start/EOS vector if one is
    # intended.
    current_batch_size = tf.shape(encoder_final_state.c)[0]
    initial_input = tf.zeros([current_batch_size, vocab_size], dtype=tf.float32)

    # The decoder starts from the encoder's final state.
    initial_cell_state = encoder_final_state

    initial_cell_output = None
    initial_loop_state = None

    return (initial_elements_finished,
            initial_input,
            initial_cell_state,
            initial_cell_output,
            initial_loop_state)
#

## Attention mechanism - choose which previously generated token to pass as input in the next time steps

def loop_fn_transition(time,
                       previous_output,
                       previous_state,
                       previous_loop_state):
    """Build the ``tf.nn.raw_rnn`` loop-function return values for a step after the first.

    Args:
        time: current step, compared against ``decoder_lengths``.
        previous_output: cell output of the previous step; emitted unchanged
            and projected through ``W``/``b`` to form the next input.
        previous_state: cell state of the previous step; passed on unchanged.
        previous_loop_state: unused.

    Returns:
        A 5-tuple ``(elements_finished, next_input, next_cell_state,
        emit_output, next_loop_state)``.
    """
    def get_next_input():
        # Project the previous output: previous_output @ W + b.
        return tf.add(tf.matmul(previous_output, W), b)

    def get_pad_input():
        # Same [batch, vocab_size] shape as the projected logits so both
        # tf.cond branches agree.
        # NOTE(review): zeros are used as the padding input — confirm.
        return tf.zeros([tf.shape(previous_output)[0], vocab_size],
                        dtype=tf.float32)

    # Boolean tensor marking which sequences have reached their length.
    elements_finished = (time >= decoder_lengths)

    # True only once every sequence in the batch is finished.
    finished = tf.reduce_all(elements_finished)

    # The second slot of the returned tuple must be the next cell input, not
    # the scalar `finished` flag: feed padding once everything is finished,
    # otherwise feed the projected previous output.
    next_input = tf.cond(finished, get_pad_input, get_next_input)

    # State and output are carried over unchanged.
    state = previous_state
    output = previous_output
    loop_state = None

    return (elements_finished,
            next_input,
            state,
            output,
            loop_state)

#


def loop_fn(time, previous_output, previous_state, previous_loop_state):
    """Dispatch one ``tf.nn.raw_rnn`` loop call to the matching handler.

    A ``previous_state`` of ``None`` marks the first call, which is answered
    by ``loop_fn_initial()``; every other call goes to
    ``loop_fn_transition`` with the arguments passed through unchanged.
    """
    is_first_call = previous_state is None
    if not is_first_call:
        return loop_fn_transition(time, previous_output, previous_state, previous_loop_state)

    # On the first call there must be no previous output either.
    assert previous_output is None and previous_state is None
    return loop_fn_initial()

# Run the decoder with the custom loop function. The result is unpacked as
# (emitted outputs, final state, ignored third value); the emitted outputs
# are stacked into a single tensor.
decoder_ouputs_ta, decoder_final_state, _ = tf.nn.raw_rnn(decoder_cell, loop_fn)
decoder_ouputs = decoder_ouputs_ta.stack()

#

# Dynamic dimensions of the stacked outputs, unpacked as
# (steps, batch, output width).
decoder_max_steps, decoder_batch_size, decoder_dim = tf.unstack(tf.shape(decoder_ouputs))

# Flatten to 2-D so the projection can be applied with a single matmul.
decoder_outputs_flat = tf.reshape(decoder_ouputs, (-1, decoder_dim))

decoder_logits_flat = tf.add(tf.matmul(decoder_outputs_flat, W), b)

# Restore the (steps, batch, vocab_size) layout.
decoder_logits = tf.reshape(decoder_logits_flat, (decoder_max_steps, decoder_batch_size, vocab_size))

#

# Index of the highest-scoring logit along the last axis.
decoder_prediction = tf.argmax(decoder_logits, 2)

## optimizer

# NOTE(review): decoder_targets is declared rank-2 while decoder_logits is
# reshaped to rank-3 above, and next_feed() supplies word vectors rather
# than probability distributions as labels — confirm this loss is intended.
stepwise_cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
    labels=decoder_targets,
    logits=decoder_logits)

loss = tf.reduce_mean(stepwise_cross_entropy)

# Adam with its default settings.
train_op = tf.train.AdamOptimizer().minimize(loss)

# Initialise every variable created above.
sess.run(tf.global_variables_initializer())

#
## Training on the toy task
# Rebinds batch_size once more, now to the Python int used for data generation.
batch_size = 100
# NOTE(review): presumably an endless generator of batches of integer
# sequences — verify against helpers.random_sequences. next_feed() uses the
# generated integers as indices into datalist, so datalist presumably needs
# to cover the whole vocab_lower..vocab_upper range — confirm.
batches = helpers.random_sequences(batch_size=batch_size, length_from=3,
                                   length_to=8, vocab_lower=0, vocab_upper=1000)

print('head of the batch:')
# Preview the first sequence; note that this consumes one batch from `batches`.
for seq in next(batches)[:1]:
    print(seq)

#
# Word list: one entry per line of the data file, with surrounding
# whitespace stripped.
datalist = []
with open('./uk_data.txt') as dataf:
    datalist.extend(str(raw_line).strip() for raw_line in dataf)

def next_feed():
    """Assemble the feed dict for one training step.

    Draws the next batch from ``batches``, passes it through
    ``helpers.batch`` to obtain indices and sequence lengths, and looks up
    the spaCy vector of the ``datalist`` word at each index.  The same
    vectors are fed as encoder input and as decoder target.

    Returns:
        dict mapping the ``encoder_inputs``, ``encoder_inputs_length`` and
        ``decoder_targets`` placeholders to their values for this step.
    """
    batch = next(batches)
    inputs_, encoder_inputs_lengths_ = helpers.batch(batch)

    # np.array builds an array from the collected vectors; np.ndarray(...)
    # would interpret its argument as a shape.
    # NOTE(review): assumes iterating inputs_ yields scalar indices into
    # datalist — confirm against helpers.batch.
    encoder_inputs_ = np.array(
        [parser.vocab[datalist[i]].vector for i in inputs_]
    )

    return {
        encoder_inputs: encoder_inputs_,
        encoder_inputs_length: encoder_inputs_lengths_,
        decoder_targets: encoder_inputs_,
    }

#

# Loss of every training step, in order.
loss_track = []

max_batches = 30001
batches_in_epoch = 500

try:
    for batch in range(max_batches):
        fd = next_feed()
        _, l = sess.run([train_op, loss], fd)
        loss_track.append(l)

        # Report progress every batches_in_epoch steps (this includes step 0).
        if batch % batches_in_epoch == 0:
            print('batch {}'.format(batch))
            # Loss re-evaluated on the same feed after the update step.
            print('    minibatch loss: {}'.format(sess.run(loss, fd)))
            predict_ = sess.run(decoder_prediction, fd)

            # Show at most the first three (input, prediction) pairs.
            for i, (inp, pred) in enumerate(zip(fd[encoder_inputs], predict_)):
                # inp is a plain NumPy array (it has no .vector attribute),
                # so it is compared to the word vectors directly.  max()
                # finds the most similar word without sorting the whole
                # vocabulary.
                nearest_word = max(allWords, key=lambda w: cosine(w.vector, inp))
                print('    sample {}'.format(i + 1))
                # Print the word text rather than the lexeme object.
                print('        input        > {}'.format(nearest_word.orth_))
                # TODO(review): decoder_prediction is an argmax over the
                # logits, so pred holds indices rather than a word vector; a
                # nearest-word lookup for the prediction needs a
                # vector-valued fetch — confirm what should be shown here.
                print('        predicted    > {}'.format(pred))
                if i >= 2:
                    break

            print()

except KeyboardInterrupt:
    print('training Interrupted')
