# STATIC GRAPH

In [93]:
import tensorflow as tf
from tensorflow.contrib.layers import xavier_initializer as xinit

## DATA

In [94]:
import data_utils
metadata, idx_q, idx_a = data_utils.load_data('../data/')

In [95]:
# add special symbol
i2w = metadata['idx2w'] + ['GO']
w2i = metadata['w2idx']
w2i['GO'] = len(i2w)-1

## Parameters

In [96]:
B = 500
L = len(idx_q[0])
vocab_size = len(i2w)
enc_hdim = 250
dec_hdim = enc_hdim

## Graph

In [97]:
tf.reset_default_graph()

### Placeholders

In [98]:
inputs = tf.placeholder(tf.int32, shape=[B,L], name='inputs')
targets = tf.placeholder(tf.int32, shape=[B,L], name='targets')
decoder_inputs = tf.concat(
    values=[tf.constant(w2i['GO'], dtype=tf.int32, shape=[B,1]), targets[:, 1:]],
    axis=1)
training = tf.placeholder(tf.bool, name='is_training')

In [99]:
decoder_inputs, targets, inputs

(<tf.Tensor 'concat:0' shape=(500, 20) dtype=int32>,
 <tf.Tensor 'targets:0' shape=(500, 20) dtype=int32>,
 <tf.Tensor 'inputs:0' shape=(500, 20) dtype=int32>)

## Embedding

In [100]:
emb_mat = tf.get_variable('emb', shape=[vocab_size, enc_hdim], dtype=tf.float32, 
                         initializer=xinit())
emb_enc_inputs = tf.nn.embedding_lookup(emb_mat, inputs)
emb_dec_inputs = tf.nn.embedding_lookup(emb_mat, decoder_inputs)

## Encoder

In [101]:
with tf.variable_scope('encoder'):
    encoder_cell = tf.contrib.rnn.LSTMCell(num_units=enc_hdim)
    encoder_outputs, context = tf.nn.dynamic_rnn(cell=encoder_cell,
                          inputs=emb_enc_inputs, 
                          initial_state =
                                encoder_cell.zero_state(batch_size=B, dtype=tf.float32)
                         )

## Decoder

In [102]:
# inference - custom rnn
def inference_decoder(cell, state):
    #tf.get_variable_scope().reuse_variables()
    with tf.variable_scope('decoder_1'):
        dec_outputs = []
        input_ = tf.constant(w2i['GO'], dtype=tf.int32, shape=[B,])
        input_ = tf.nn.embedding_lookup(emb_mat, input_)
        
        for i in range(L):
            decoder_output, state = cell(input_, state)
            dec_outputs.append(decoder_output)
            input_ = decoder_output

    return tf.stack(dec_outputs), None

In [103]:
with tf.variable_scope('decoder') as scope:
    decoder_cell = tf.contrib.rnn.LSTMCell(num_units=dec_hdim)
    
    decoder_outputs, _ = tf.nn.dynamic_rnn(cell=decoder_cell,
                          inputs=emb_dec_inputs, 
                          initial_state= context,
                          scope='decoder_1')
                         
    tf.get_variable_scope().reuse_variables()
    
    decoder_outputs_inf, _ = inference_decoder(decoder_cell, context)
    #decoder_outputs_inf, _ = inference_decoder(decoder_cell, context)
    '''
    decoder_outputs_inf, _ = tf.nn.dynamic_rnn(cell=decoder_cell,
                          inputs=emb_dec_inputs, 
                          initial_state= context
                         )
    '''

## Logits and Probabilities

In [104]:
Wo = tf.get_variable('Wo', shape=[dec_hdim, vocab_size], dtype=tf.float32, 
                         initializer=xinit())
bo = tf.get_variable('bo', shape=[vocab_size], dtype=tf.float32, 
                         initializer=xinit())
proj_outputs = tf.matmul(tf.reshape(decoder_outputs, [B*L, dec_hdim]), Wo) + bo
proj_outputs_inf = tf.matmul(tf.reshape(decoder_outputs_inf, [B*L, dec_hdim]), Wo) + bo

logits = tf.cond(tf.random_normal(shape=()) > 0.,
    lambda : tf.reshape(proj_outputs, [B, L, vocab_size]),
    lambda : tf.reshape(proj_outputs_inf, [B, L, vocab_size])
                )
probs = tf.nn.softmax(tf.reshape(proj_outputs_inf, [B, L, vocab_size]))

## Loss

In [105]:
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits =  logits,
    labels = targets)
loss = tf.reduce_mean(cross_entropy)

## Optimizatin

In [106]:
optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
train_op = optimizer.minimize(loss)

## Inference

In [107]:
prediction = tf.argmax(probs, axis=-1)

# TRAINING

In [108]:
config = tf.ConfigProto(allow_soft_placement = True)
sess = tf.InteractiveSession(config = config)
sess.run(tf.global_variables_initializer())

### Training parameters

In [109]:
num_epochs = 20

## Begin

In [110]:
for i in range(num_epochs):
    avg_loss = 0.
    for j in range(len(idx_q)//B):
        _, loss_v = sess.run([train_op, loss], feed_dict = {
            inputs : idx_q[j*B:(j+1)*B],
            targets : idx_a[j*B:(j+1)*B]
        })
        avg_loss += loss_v
        if j and j%30==0:
            print('{}.{} : {}'.format(i,j,avg_loss/30))
            avg_loss = 0.

0.30 : 4.433158040046692
0.60 : 3.6232658942540485
0.90 : 3.3398260196050007
0.120 : 3.3249510685602823
0.150 : 3.153806185722351
0.180 : 3.1631939729054768
0.210 : 2.77339795033137
0.240 : 2.5307992060979205
1.30 : 2.6548523704210916
1.60 : 2.3550113836924234
1.90 : 1.8782964547475178
1.120 : 2.1380068480968477
1.150 : 2.3423380653063455
1.180 : 1.7630279620488485
1.210 : 1.993270958463351
1.240 : 2.160435663660367
2.30 : 2.4909680277109145
2.60 : 2.4011552969614667
2.90 : 1.9493064304192862
2.120 : 2.0211808661619823
2.150 : 2.3212838311990103
2.180 : 2.312272184093793
2.210 : 1.8191819141308467
2.240 : 2.3361177682876586
3.30 : 2.3499578634897866
3.60 : 1.8851163039604824
3.90 : 2.213368421792984
3.120 : 2.430551919341087
3.150 : 1.9670845369497936
3.180 : 2.083962631225586
3.210 : 1.9262704720099768
3.240 : 1.9811260292927424
4.30 : 2.4107699235280355
4.60 : 1.5530174632867177
4.90 : 2.0975517590840655
4.120 : 1.3243701130151748
4.150 : 1.7403157333532968
4.180 : 1.8582387218872707

In [116]:
j = 117
pred_v = sess.run(prediction, feed_dict = {
            inputs : idx_q[j*B:(j+1)*B],
            #targets : idx_a[j*B:(j+1)*B]
        })

In [None]:
pred_v[0], idx_a[j*B:(j+1)*B][0]

In [117]:
def arr2sent(arr):
    return ' '.join([i2w[item] for item in arr])

In [118]:
for i in range(B):
    print(arr2sent(pred_v[i]))

_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _


In [22]:
arr2sent( idx_a[j*B:(j+1)*B][11])

'hey whats that in the background racist _ _ _ _ _ _ _ _ _ _ _ _ _'