# STATIC GRAPH

In [1]:
import tensorflow as tf
from tensorflow.contrib.layers import xavier_initializer as xinit

## DATA

In [2]:
import data_utils
metadata, idx_q, idx_a = data_utils.load_data('../data/')

In [3]:
# add special symbol
i2w = metadata['idx2w'] + ['GO']
w2i = metadata['w2idx']
w2i['GO'] = len(i2w)-1

## Parameters

In [4]:
B = 1000
L = len(idx_q[0])
vocab_size = len(i2w)
enc_hdim = 250
dec_hdim = enc_hdim

## Graph

In [5]:
tf.reset_default_graph()

### Placeholders

In [6]:
inputs = tf.placeholder(tf.int32, shape=[B,L], name='inputs')
targets = tf.placeholder(tf.int32, shape=[B,L], name='targets')
decoder_inputs = tf.concat(
    values=[tf.constant(w2i['GO'], dtype=tf.int32, shape=[B,1]), targets[:, 1:]],
    axis=1)

In [7]:
decoder_inputs, targets, inputs

(<tf.Tensor 'concat:0' shape=(1000, 20) dtype=int32>,
 <tf.Tensor 'targets:0' shape=(1000, 20) dtype=int32>,
 <tf.Tensor 'inputs:0' shape=(1000, 20) dtype=int32>)

## Embedding

In [8]:
emb_mat = tf.get_variable('emb', shape=[vocab_size, enc_hdim], dtype=tf.float32, 
                         initializer=xinit())
emb_enc_inputs = tf.nn.embedding_lookup(emb_mat, inputs)
emb_dec_inputs = tf.nn.embedding_lookup(emb_mat, decoder_inputs)

## Encoder

In [9]:
with tf.variable_scope('encoder'):
    encoder_cell = tf.contrib.rnn.LSTMCell(num_units=enc_hdim)
    encoder_outputs, context = tf.nn.dynamic_rnn(cell=encoder_cell,
                          inputs=emb_enc_inputs, 
                          initial_state =
                                encoder_cell.zero_state(batch_size=B, dtype=tf.float32)
                         )

## Decoder

In [10]:
with tf.variable_scope('decoder'):
    decoder_cell = tf.contrib.rnn.LSTMCell(num_units=dec_hdim)
    decoder_outputs, _ = tf.nn.dynamic_rnn(cell=decoder_cell,
                          inputs=emb_dec_inputs, 
                          initial_state= context
                         )

## Logits and Probabilities

In [11]:
Wo = tf.get_variable('Wo', shape=[dec_hdim, vocab_size], dtype=tf.float32, 
                         initializer=xinit())
bo = tf.get_variable('bo', shape=[vocab_size], dtype=tf.float32, 
                         initializer=xinit())
proj_outputs = tf.matmul(tf.reshape(decoder_outputs, [B*L, dec_hdim]), Wo) + bo
logits = tf.reshape(proj_outputs, [B, L, vocab_size])
probs = tf.nn.softmax(logits)

## Loss

In [12]:
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits = logits,
    labels = targets)
loss = tf.reduce_mean(cross_entropy)

## Optimizatin

In [13]:
optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
train_op = optimizer.minimize(loss)

## Inference

In [14]:
prediction = tf.argmax(probs, axis=-1)

# TRAINING

In [15]:
config = tf.ConfigProto(allow_soft_placement = True)
sess = tf.InteractiveSession(config = config)
sess.run(tf.global_variables_initializer())

### Training parameters

In [16]:
num_epochs = 10

## Begin

In [17]:
for i in range(num_epochs):
    avg_loss = 0.
    for j in range(len(idx_q)//B):
        _, loss_v = sess.run([train_op, loss], feed_dict = {
            inputs : idx_q[j*B:(j+1)*B],
            targets : idx_a[j*B:(j+1)*B]
        })
        avg_loss += loss_v
        if j and j%20==0:
            print('{}.{} : {}'.format(i,j,avg_loss/20))
            avg_loss = 0.

0.20 : 4.374236321449279
0.40 : 2.9668627977371216
0.60 : 2.504377770423889
0.80 : 1.9982671082019805
0.100 : 1.444276547431946
0.120 : 1.019878402352333
1.20 : 0.7292188316583633
1.40 : 0.5727434158325195
1.60 : 0.49078361988067626
1.80 : 0.4375475347042084
1.100 : 0.4037304982542992
1.120 : 0.3748603016138077
2.20 : 0.35517664849758146
2.40 : 0.32739430814981463
2.60 : 0.3153625696897507
2.80 : 0.3068651154637337
2.100 : 0.30491813123226164
2.120 : 0.3002783715724945
3.20 : 0.3121029630303383
3.40 : 0.2952332764863968
3.60 : 0.29264847487211226
3.80 : 0.2903579115867615
3.100 : 0.2914220854640007
3.120 : 0.29028379172086716
4.20 : 0.3058256059885025
4.40 : 0.2903655216097832
4.60 : 0.28862053602933885
4.80 : 0.2873972564935684
4.100 : 0.2889208883047104
4.120 : 0.2882304400205612
5.20 : 0.3042555972933769
5.40 : 0.2887479826807976
5.60 : 0.2871220126748085
5.80 : 0.2859190031886101
5.100 : 0.28765359073877333
5.120 : 0.28710009157657623
6.20 : 0.3032954305410385
6.40 : 0.287793479859

In [19]:
j = 7 
pred_v = sess.run(prediction, feed_dict = {
            inputs : idx_q[j*B:(j+1)*B],
            targets : idx_a[j*B:(j+1)*B]
        })

In [20]:
pred_v[0], idx_a[j*B:(j+1)*B][0]

(array([  4,  45, 754,  11, 694,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0]),
 array([1989,   45,  754,   11,  694,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0], dtype=int32))

In [21]:
def arr2sent(arr):
    return ' '.join([i2w[item] for item in arr])

In [26]:
arr2sent(pred_v[11])

'i as if her supporters just accept shes a crook _ _ _ _ _ _ _ _ _ _'

In [27]:
arr2sent( idx_a[j*B:(j+1)*B][11])

'its as if her supporters just accept shes a crook _ _ _ _ _ _ _ _ _ _'