TODO:
- Figure out dimensions
- Determine if architecture will work properly

In [2761]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

To start with we need to create some input and output data, taking into account the fact that this data might have different lengths. Let us randomly generate some inputs and outputs. These will represent sentences which have already been processed so that words are represented by integer word ids. 

In [2762]:
inputs = [np.random.randint(1,100,n) for n in  np.random.randint(5,11,10)]
outputs = [np.random.randint(1,100,n) for n in  np.random.randint(5,11,10)]

print(inputs)
print(outputs)

[array([63, 97, 85, 31, 55, 20, 47, 67, 79]), array([23, 88, 88, 37, 82, 83, 68]), array([28, 29, 45,  7, 19,  8, 13, 12]), array([67, 10, 59, 67, 72, 67]), array([ 1, 46, 19, 62, 40, 86, 47, 18,  9]), array([38, 49, 85, 93, 95, 80, 53, 98]), array([83, 19, 15, 14, 63, 59]), array([16, 19, 28, 31, 90]), array([47, 99, 81, 30, 54, 51, 23, 86]), array([ 4, 48, 24,  1,  7, 93, 65, 18, 13, 20])]
[array([85, 79, 44, 81,  5, 44, 13, 33]), array([61, 35, 47, 52, 67, 37]), array([94, 95, 17, 61, 26, 19,  4]), array([13, 14, 58, 62, 82]), array([33, 66, 62, 47, 86, 14, 94]), array([97, 17, 82, 99, 24]), array([44, 75, 76, 22, 16, 62, 45, 85]), array([31, 70, 27, 87, 97]), array([ 7, 42, 59, 31, 32, 57, 66]), array([17, 92, 17, 76, 16, 48,  2, 38, 27, 19])]


For use in the RNN, let us calculate sequence lengths

In [2763]:
input_lens = list(map(len,inputs))
output_lens = list(map(len,outputs))

print(input_lens, output_lens)

[9, 7, 8, 6, 9, 8, 6, 5, 8, 10] [8, 6, 7, 5, 7, 5, 8, 5, 7, 10]


The first step is to run a bidirectional RNN and get all the states. We will use tensorflow's birectional_dynamic_RNN for this purpose. 

First let us define placeholder for the input

In [2764]:
tf.reset_default_graph()

In [2765]:
batch_size = 2

In [2766]:
input_ids = tf.placeholder(dtype=tf.int32,shape=[batch_size,None])
seq_lens = tf.placeholder(dtype=tf.int32,shape=[batch_size])

In [2767]:
mask = tf.placeholder(dtype=tf.float32,shape=[batch_size,None])

Now look up embeddings for X

In [2768]:
source_vocab_size = 100
target_vocab_size = 100
embed_size = 20
hidden_size = 25

In [2769]:
source_embeddings = tf.get_variable('source_embedding_matrix',
                            [source_vocab_size+1, embed_size])
encoder_inputs = tf.nn.embedding_lookup(source_embeddings, input_ids)

We will use GRU cells for both directions of the RNN and use tensorflow's bidirectional_dynamic_rnn dynamically unroll the network.

In [2770]:
enc_fw_cell = tf.contrib.rnn.GRUCell(hidden_size)
enc_bw_cell = tf.contrib.rnn.GRUCell(hidden_size)

In [2771]:
out, states = \
tf.nn.bidirectional_dynamic_rnn(cell_fw = enc_fw_cell, 
                                         cell_bw = enc_bw_cell,
                                         inputs = encoder_inputs,
                                         sequence_length = seq_lens,
                                         dtype = tf.float32)

In [2772]:
concat_outputs = tf.concat(out, 2) #F in the pseudo-code

So far this is not much different from running an ordinary RNN as you might do for a language modelling task. However now we need to implement a decoder with attention. Whilst tensorflow has functions that can simplify this process, let us go through it step-by-step.

First we need to calculate an initial state for the decoder

Dimensions of the different matrices.

- X - batch_size x max_sequence_length
- rnn_inputs - batch_size x max_sequence_length x source_embed_size
- Each output - batch_size x max_sequence_length x encoder_state_size
- F = concat_outputs - batch_size x (max_sequence_length\*2) x encoder_state_size
- U - encoder_state_size x decoder_state_size
- bw_output1 - batch_size x encoder_state_size

- s_0 = decoder_state = tf.matmul(U,tf.transpose(bw_output1))

- s_0 = decoder_state - batch_size x decoder_state_size

- W - [dim] x (max_sequence_length\*2)

- X - batch_size x [dim] x encoder_state_size

- V - [dim] x 


- c_t - batch_size x target_embed_size
- P - vocab_size x decoder_state_size
- b - vocab_size

In [2773]:
out[-1].get_shape()

TensorShape([Dimension(2), Dimension(None), Dimension(25)])

In [2774]:
bw_output1 = out[-1][:,1,:]

In [2775]:
bw_output1.get_shape().as_list()

[2, 25]

In [2776]:
U = tf.get_variable('U',[hidden_size,hidden_size])

In [2777]:
padded_inputs = np.array([np.concatenate((i,
            [0 for j in range(10-len(i))])) for i in inputs])

In [2778]:
padded_inputs

array([[ 63.,  97.,  85.,  31.,  55.,  20.,  47.,  67.,  79.,   0.],
       [ 23.,  88.,  88.,  37.,  82.,  83.,  68.,   0.,   0.,   0.],
       [ 28.,  29.,  45.,   7.,  19.,   8.,  13.,  12.,   0.,   0.],
       [ 67.,  10.,  59.,  67.,  72.,  67.,   0.,   0.,   0.,   0.],
       [  1.,  46.,  19.,  62.,  40.,  86.,  47.,  18.,   9.,   0.],
       [ 38.,  49.,  85.,  93.,  95.,  80.,  53.,  98.,   0.,   0.],
       [ 83.,  19.,  15.,  14.,  63.,  59.,   0.,   0.,   0.,   0.],
       [ 16.,  19.,  28.,  31.,  90.,   0.,   0.,   0.,   0.,   0.],
       [ 47.,  99.,  81.,  30.,  54.,  51.,  23.,  86.,   0.,   0.],
       [  4.,  48.,  24.,   1.,   7.,  93.,  65.,  18.,  13.,  20.]])

In [2779]:
mask_array = 1.0*(padded_inputs[:2]>0)

In [2780]:
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    U_,bw_ = sess.run([U,bw_output1],
                      feed_dict={input_ids: padded_inputs[:2],
                                 mask: mask_array,
                                 seq_lens:input_lens[:2]})

In [2781]:
U_.shape, bw_.shape

((25, 25), (2, 25))

Now we will unroll the decoder network. We will do so for 1 plus the maximum sequence length of the outputs. The reason for the 1 plus the maximum length is that each input will have a special ``<GO>`` symbol at the start. So the target embedding matrix has an additional embedding for this symbol. The reason that the embedding matrix has target_vocab_size + 2 is to account for the fact we zero-pad inputs. Thus we need to ensure that there is a different id for ``<GO>`` so that it gets a different embedding.

In [2782]:
decoder_state = tf.matmul(bw_output1,U)

In [2783]:
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    d_ = sess.run(decoder_state,feed_dict={input_ids: padded_inputs[:2], seq_lens:input_lens[:2]})

In [2784]:
d_.shape

(2, 25)

In [2785]:
target_embeddings = tf.get_variable('target_embedding_matrix',
                            [target_vocab_size+2, embed_size])

We also define a placeholder for the inputs. This will consist of a batch of embeddings for the 

In [2786]:
target_input = tf.placeholder(tf.int32,[batch_size])

In [2787]:
target_output = tf.placeholder(tf.int32,[batch_size,None])

In [2788]:
decoder_length = max(output_lens)
print(decoder_length)

10


In [2789]:
align_dim = 15

In [2790]:
W = tf.get_variable('W',[1,2*hidden_size,align_dim],dtype=tf.float32)

In [2791]:
W_rep = tf.tile(W,[2,1,1])

In [2792]:
X = tf.matmul(concat_outputs,W_rep)

In [2793]:
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    c_,W_,Wr_,X_ = sess.run([concat_outputs,W,W_rep,X],feed_dict={input_ids: padded_inputs[:2], seq_lens:input_lens[:2]})

In [2794]:
c_.shape, W_.shape,Wr_.shape, X_.shape

((2, 10, 50), (1, 50, 15), (2, 50, 15), (2, 10, 15))

In [2795]:
V = tf.get_variable('V',[hidden_size,align_dim])

In [2796]:
r_t = tf.matmul(decoder_state,V)

In [2797]:
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    r_ = sess.run(r_t,feed_dict={input_ids: padded_inputs[:2], seq_lens:input_lens[:2]})

In [2798]:
r_.shape

(2, 15)

In [2799]:
tanh_input = X + tf.expand_dims(r_t,axis=1)

In [2800]:
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    X_,r_,ti = sess.run([X,r_t,tanh_input],
                        feed_dict={input_ids: padded_inputs[:2], seq_lens:input_lens[:2]})

We do the calculation for $WF$ in advance as it does not depend on the output dimensions.

In [2801]:
ti[0][0]-r_[0],X_[0][0]

(array([ 0.01032581,  0.05872528,  0.03408724,  0.05432902,  0.10617474,
        -0.05815943,  0.0832762 ,  0.00407253,  0.02143284,  0.10079458,
         0.00548378, -0.01343164,  0.01918724, -0.06038353,  0.00208974], dtype=float32),
 array([ 0.01032581,  0.05872528,  0.03408724,  0.05432902,  0.10617474,
        -0.05815942,  0.0832762 ,  0.00407253,  0.02143284,  0.10079458,
         0.00548378, -0.01343164,  0.01918724, -0.06038353,  0.00208974], dtype=float32))

In [2802]:
ti.shape,r_.shape,X_.shape

((2, 10, 15), (2, 15), (2, 10, 15))

In [2803]:
v = tf.get_variable('v',[1,align_dim,1])

In [2804]:
v_rep = tf.tile(v,[2,1,1])

In [2805]:
u_t = tf.matmul(tanh_input,v_rep)

In [2806]:
u_t = tf.squeeze(u_t,axis=2)

In [2807]:
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    u_,v_,vr_ = sess.run([u_t,v,v_rep],
                        feed_dict={input_ids: padded_inputs[:2], seq_lens:input_lens[:2]})

In [2808]:
u_.shape,v_.shape,vr_.shape

((2, 10), (1, 15, 1), (2, 15, 1))

In [2809]:
exp_u_t = tf.exp(u_t)

In [2810]:
softmax_denom = tf.reduce_sum(exp_u_t*mask,axis=1,keep_dims=True)

In [2811]:
a_t = exp_u_t/softmax_denom

In [2812]:
a_t = a_t*mask

In [2813]:
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    a_ = sess.run(a_t,
                        feed_dict={input_ids: padded_inputs[:2], 
                                   mask:mask_array,
                                   seq_lens:input_lens[:2]})

In [2814]:
a_.shape

(2, 10)

In [2815]:
c_.shape

(2, 10, 50)

In [2816]:
d_.shape

(2, 25)

In [2817]:
a_expn = tf.expand_dims(a_t,axis=1)

In [2818]:
c_t = tf.matmul(a_expn,concat_outputs)

In [2819]:
c_t = tf.squeeze(c_t,axis=1)

In [2820]:
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    ax_,ct_ = sess.run([a_expn,c_t],
                        feed_dict={input_ids: padded_inputs[:2], 
                                   mask:mask_array,
                                   seq_lens:input_lens[:2]})

In [2821]:
ax_.shape,ct_.shape

((2, 1, 10), (2, 50))

Let us create a cell for the decoder

In [2822]:
decoder_cell = tf.contrib.rnn.GRUCell(hidden_size)

In [2823]:
decoder_embed = tf.nn.embedding_lookup(target_embeddings, tf.zeros([batch_size],tf.int32))

In [2824]:
padded_outputs = np.array([np.concatenate((i,
            [0 for j in range(10-len(i))])) for i in outputs])

In [2825]:
seq_lens_out = tf.placeholder(dtype=tf.int32,shape=[batch_size])

In [2826]:
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    dec_em = sess.run(decoder_embed,
                        feed_dict={input_ids: padded_inputs[:2], 
                                   mask:mask_array,
                                   target_input:padded_outputs[:2][:,1], 
                                   seq_lens:input_lens[:2],
                                   seq_lens_out:output_lens[:2],
                                  })

In [2827]:
dec_em.shape

(2, 20)

In [2828]:
decoder_input = tf.concat([decoder_embed,c_t],axis=1)

In [2829]:
with tf.variable_scope("RNN"):
    decoder_state_t, _ = decoder_cell(inputs=decoder_input,
                                  state=decoder_state) 

In [2830]:
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    di_,dst_ = sess.run([decoder_input,decoder_state_t],
                        feed_dict={input_ids: padded_inputs[:2], 
                                   mask:mask_array,
                                   target_input:padded_outputs[:2], 
                                   seq_lens:input_lens[:2],
                                   target_input:padded_outputs[:2][:,0],
                                   seq_lens_out:output_lens[:2],
                                  })

In [2831]:
di_.shape, dst_.shape

((2, 70), (2, 25))

In [2832]:
P = tf.get_variable('P',[hidden_size,target_vocab_size+2])
b = tf.get_variable('b',[target_vocab_size+2])

We also define a mask so that outputs at points beyond the maximum sequence length are not taken into consideration when calculating loss.

In [2833]:
losses = []

In [2834]:
y_t = tf.nn.softmax(tf.matmul(decoder_state_t,P)+b)

In [2835]:
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    y_ = sess.run(y_t,
                        feed_dict={input_ids: padded_inputs[:2], 
                                   mask:mask_array,
                                   target_input:padded_outputs[:2], 
                                   seq_lens:input_lens[:2],
                                   target_input:padded_outputs[:2][:,0],
                                   seq_lens_out:output_lens[:2],
                                  })

In [2836]:
y_t.shape

TensorShape([Dimension(2), Dimension(102)])

In [2837]:
onehot = tf.one_hot(target_output[:,0],depth=target_vocab_size+2)

In [2838]:
embed_t = tf.argmax(y_t, 1)

In [2839]:
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    em_t,oh_,to_ = sess.run([embed_t,onehot,target_output],
                        feed_dict={input_ids: padded_inputs[:2], 
                                   mask:mask_array,
                                   target_input:padded_outputs[:2], 
                                   seq_lens:input_lens[:2],
                                   target_input:padded_outputs[:2][:,0],
                                   seq_lens_out:output_lens[:2],
                                   target_output:np.concatenate((np.zeros((1,9)),
                                                               np.ones((1,9))))
                                  })

In [2840]:
em_t, oh_,to_[:,0]

(array([52, 52]),
 array([[ 1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0., 

In [2841]:
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
                        logits =tf.matmul(decoder_state_t,P)+b,
                        labels = onehot
                    )

In [2842]:
losses.append(cross_entropy)

In [2843]:
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    xe_= sess.run(cross_entropy,
                        feed_dict={input_ids: padded_inputs[:2], 
                                   mask:mask_array,
                                   target_input:padded_outputs[:2], 
                                   seq_lens:input_lens[:2],
                                   target_input:padded_outputs[:2][:,0],
                                   seq_lens_out:output_lens[:2],
                                   target_output:np.concatenate((np.ones((1,9)),
                                                               np.zeros((1,9))))
                                  })

In [2844]:
xe_

array([ 4.72140074,  4.75428104], dtype=float32)

In [2845]:
r_t2 = tf.matmul(decoder_state,V)

u_t2 = tf.matmul(tanh_input,v_rep)
u_t2 = tf.squeeze(u_t2,axis=2)

exp_u_t2 = tf.exp(u_t2)
softmax_denom = tf.reduce_sum(exp_u_t2*mask,axis=1,keep_dims=True)
a_t2 = exp_u_t2/softmax_denom
a_t2 = a_t2*mask

a_expn2 = tf.expand_dims(a_t2,axis=1)
c_t2 = tf.matmul(a_expn2,concat_outputs)
c_t2 = tf.squeeze(c_t2,axis=1)

decoder_input2 = tf.concat([tf.nn.embedding_lookup(target_embeddings,embed_t),
                            c_t2],axis=1)

In [2846]:
with tf.variable_scope("RNN", reuse=True):
    decoder_state_t2, _ = decoder_cell(inputs=  decoder_input2,
                                      state=decoder_state_t) 

In [2847]:
y_t2 = tf.nn.softmax(tf.matmul(decoder_state_t2,P)+b)

In [2848]:
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    y2_ = sess.run(y_t2,
                        feed_dict={input_ids: padded_inputs[:2], 
                                   mask:mask_array,
                                   target_input:padded_outputs[:2], 
                                   seq_lens:input_lens[:2],
                                   target_input:padded_outputs[:2][:,0],
                                   seq_lens_out:output_lens[:2],
                                   target_output:padded_outputs[:2][:,1:]
                                  })

In [2849]:
y2_.shape

(2, 102)

In [2850]:
onehot2 = tf.one_hot(target_output[:,1],depth=target_vocab_size+2)

In [2851]:
cross_entropy2 = tf.nn.softmax_cross_entropy_with_logits(
                        logits = tf.matmul(decoder_state_t2,P)+b,
                        labels = onehot
                    )

In [2852]:
losses.append(cross_entropy2)

In [2853]:
target_mask = tf.placeholder(dtype=tf.float32,shape=[batch_size,None])

In [2854]:
losses = [tf.expand_dims(loss,axis=1) for loss in losses]

In [2857]:
loss_mat = tf.concat(losses,axis=1)

In [2858]:
avg_loss = tf.reduce_sum(loss_mat*target_mask)/tf.reduce_sum(target_mask)

In [2861]:
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    oh2_,xe2_,av_,lm_ = sess.run([onehot2,cross_entropy2,avg_loss,loss_mat],
                        feed_dict={input_ids: padded_inputs[:2], 
                                   mask:mask_array,
                                   target_input:padded_outputs[:2], 
                                   seq_lens:input_lens[:2],
                                   target_input:padded_outputs[:2][:,0],
                                   seq_lens_out:output_lens[:2],
                                   target_output:np.concatenate((np.zeros((1,9)),
                                                               np.ones((1,9)))),
                                   target_mask:np.array([[1,0],[1,1]])
                                  })

In [2862]:
 oh2_,xe_,xe2_,av_,lm_

(array([[ 1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.

In [2759]:
np.sum(np.vstack((xe_,xe2_))*np.array([[1,0],[1,1]]))/3

ValueError: operands could not be broadcast together with shapes (4,) (2,2) 

In [None]:
mask = tf.placeholder(tf.float32, [])

In [None]:
for t in range(decoder_length):
    decoder_embed = tf.nn.embedding_lookup(target_embeddings, target_input)
    logits = []
    if t > 0:
        tf.get_variable_scope().reuse_variables()
        r_t = tf.matmul(V, decoder_state)
        u_t = tf.matmul(v, tf.tanh(X + r_t))
        a_t = tf.nn.softmax(u_t)
        c_t = tf.matmul(concat_outputs,a_t)
        decoder_input = tf.concat([decoder_embed,c_t])
        decoder_state, _ = decoder_cell(decoder_input,decoder_state) 
        #tf's RNNCell classes return two outputs out and state 
        #but for the GRUCell these are the same
        
        logit = tf.nn.softmax(tf.nn.xw_plus_b(decoder_state,P,b))
        logits.append(tf.argmax(logit, 1))
        onehot = tf.one_hot(decoder_output,depth=target_vocab_size+2)
        cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
                        logits = logit,
                        label = onehot
                    )
        loss = tf.reduce_sum(cross_entropy)

        total_loss += loss
    total_loss = total_loss / tf.reduce_sum(mask[:, 1:])

In [None]:
F = encode_as_matrix(f)
embed = '<START>' #e_0
U = tf.get_variable('U',[hid_dim,inp_dim])
V = tf.get_variable('V',[])
W = tf.get_variable('W',[])
P = tf.get_variable('P',[])

X = tf.matmul(W,F)
state = tf.xw_plus_b(U,h_1)
RNN = tf.contrib.rnn.RNNCell
while embed != '<END>': #e_t
    t = t + 1
    r_t = tf.matmul(V,state)
    u_t = tf.matmul(v,tf.tanh(X + r_t))
    a_t = tf.nn.softmax(u_t)
    c_t = tf.matmul(F,a_t)
    hidden = RNN(hidden,tf.concat([embed,c_t],axis=0)) #s_t
    y_t = tf.nn.softmax(tf.matmul(P,hidden)+b)
    distr = tf.contrib.distributions.Categorical(p=y_t)
    e_t = distr.sample()

In [263]:
enc_fw_cell.state_size

20