In [1]:
import tensorflow as tf
import numpy as np

# Define parameters

In [2]:
parameters = {}
parameters['seq_length'] = 5
parameters['n_input'] = 3
parameters['n_output'] = 3
parameters['n_hidden'] = 4
parameters['init_stdev'] = 0.1
parameters['learning_rate'] = 0.01

# Define model

In [3]:
# Some auxilar functions
def _seq_length(sequence):
    used = tf.sign(tf.reduce_max(tf.abs(sequence), reduction_indices=2))
    length = tf.reduce_sum(used, reduction_indices=1)
    length = tf.cast(length, tf.int32)
    return length

def _last_relevant(output, length):
    batch_size = tf.shape(output)[0]
    max_length = tf.shape(output)[1]
    out_size = int(output.get_shape()[2])
    index = tf.range(0, batch_size) * max_length + (length - 1)
    flat = tf.reshape(output, [-1, out_size])
    relevant = tf.gather(flat, index)

    return relevant

In [4]:
# Define placeholders
x = tf.placeholder("float", [None, parameters['seq_length'], parameters['n_input']], name='x')
y = tf.placeholder("float", [None, parameters['n_output']], name='y')

# Define weights and bias - For now we will try with attention to hidden state 
weights = {
    'alphas': tf.Variable(tf.random_normal([parameters['n_hidden'], 1], stddev=parameters['init_stdev'])),
    'out': tf.Variable(tf.random_normal([parameters['n_input'], parameters['n_output']], stddev=parameters['init_stdev']), name='w_out')
        }
biases = {
    'out': tf.Variable(tf.random_normal([parameters['n_output']]), name='b_out'),
    'alphas': tf.Variable(tf.random_normal([1]), name='b_alphas')
}

# Define RNN
rnn_cell = tf.contrib.rnn.LSTMCell(parameters['n_hidden'])
outputs, states = tf.nn.dynamic_rnn(
    rnn_cell,
    x,
    dtype=tf.float32,
    sequence_length=_seq_length(x)
)

# Define attention weihts
outputs_reshaped = tf.reshape(outputs, [-1, int(outputs.get_shape()[2])])
ejs = tf.matmul(outputs_reshaped, weights['alphas']) + biases['alphas'] 
ejs_reshaped = tf.reshape(ejs, [-1, int(outputs.get_shape()[1])])
alphas = tf.nn.softmax(ejs_reshaped, name='attention_weights') 
reshaped_alphas = tf.reshape(alphas, [-1, 1])
# Define context
x_reshaped = tf.reshape(x, [-1, int(x.get_shape()[2])])
context = reshaped_alphas * x_reshaped
context_reshaped = tf.reshape(context, [-1, parameters['seq_length'], int(context.get_shape()[1])])
context_reduced = tf.reduce_sum(context_reshaped, axis= 1)
# Define logits and loss
logits = tf.matmul(context_reduced, weights['out']) + biases['out']
pred_prob = tf.nn.softmax(logits, name="predictions")
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y))

#Define optimizer
optimizer = tf.train.AdamOptimizer(learning_rate=parameters['learning_rate']).minimize(loss)

# Initialization
init = tf.global_variables_initializer()


Instructions for updating:

Future major versions of TensorFlow will allow gradients to flow
into the labels input on backprop by default.

See tf.nn.softmax_cross_entropy_with_logits_v2.



# Define dataset

In [5]:
# Sample 1 - The typical example, always the second element, which leads to the output of the second element to be one
x1 = np.array([[0,1,0], [0,1,0],[0,1,0],[0,1,0],[0,1,0]])
y1 = np.array([0,1,0])
# Samples 2-6 - If in some point the first element is 1, then the output will be one for the first element
x2 = np.array([[1,0,0], [0,1,0],[0,1,0],[0,1,0],[0,1,0]])
y2 = np.array([1,0,0])
x3 = np.array([[0,1,0], [1,0,0],[0,1,0],[0,1,0],[0,1,0]])
y3 = np.array([1,0,0])
x4 = np.array([[0,1,0], [0,1,0],[1,0,0],[0,1,0],[0,1,0]])
y4 = np.array([1,0,0])
x5 = np.array([[0,1,0], [0,1,0],[0,1,0],[1,0,0],[0,1,0]])
y5 = np.array([1,0,0])
x6 = np.array([[0,1,0], [0,1,0],[0,1,0],[0,1,0],[1,0,0]])
y6 = np.array([1,0,0])
# Samples 7 - 11 - If in some point the last element is 1, then the output will be one for the first element
x7 = np.array([[0,0,1], [0,1,0],[0,1,0],[0,1,0],[0,1,0]])
y7 = np.array([0,0,1])
x8 = np.array([[0,1,0], [0,0,1],[0,1,0],[0,1,0],[0,1,0]])
y8 = np.array([0,0,1])
x9 = np.array([[0,1,0], [0,1,0],[0,0,1],[0,1,0],[0,1,0]])
y9 = np.array([0,0,1])
x10 = np.array([[0,1,0], [0,1,0],[0,1,0],[0,0,1],[0,1,0]])
y10 = np.array([0,0,1])
x11 = np.array([[0,1,0], [0,1,0],[0,1,0],[0,1,0],[0,0,1]])
y11 = np.array([0,0,1])

X = [x1, np.copy(x1), np.copy(x1), np.copy(x1), np.copy(x1), np.copy(x1), np.copy(x1), np.copy(x1), 
    x2, x3, x4, x5, x6,
    x7, x8, x9, x10, x11]
Y = [y1, np.copy(y1), np.copy(y1), np.copy(y1), np.copy(y1), np.copy(y1), np.copy(y1), np.copy(y1), 
    y2, y3, y4, y5, y6,
    y7, y8, y9, y10, y11]

# Train and save model

In [6]:
# Start training
saver = tf.train.Saver()
with tf.Session() as sess:

    # Run the initializer
    sess.run(init)

    for step in range(1, 300):
        batch_x = np.array(X)
        batch_y = np.array(Y)
        # Run optimization op (backprop)
        a = sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
        if step % 20 == 0 or step == 1:
            # Calculate batch loss and accuracy
            train_loss = sess.run(loss, feed_dict={x: batch_x, y: batch_y})
            print("Step " + str(step) + ", Loss= {:.4f}".format(train_loss))

    print("Optimization Finished!")
   
    # Once trained - Get attention weights for the training samples
    attention_weights = sess.run(alphas, feed_dict={x: batch_x})

    # Saved Model Builder 
    export_path = "models/attentionRNN2/SavedModelBuilder/"
    builder = tf.saved_model.builder.SavedModelBuilder(export_path)
    builder.add_meta_graph_and_variables(
          sess, [tf.saved_model.tag_constants.SERVING])
    builder.save()
    
    
    


Step 1, Loss= 1.2369
Step 20, Loss= 1.0623
Step 40, Loss= 0.9454
Step 60, Loss= 0.6732
Step 80, Loss= 0.4003
Step 100, Loss= 0.2591
Step 120, Loss= 0.1862
Step 140, Loss= 0.1432
Step 160, Loss= 0.1151
Step 180, Loss= 0.0953
Step 200, Loss= 0.0807
Step 220, Loss= 0.0695
Step 240, Loss= 0.0607
Step 260, Loss= 0.0536
Step 280, Loss= 0.0478
Optimization Finished!
INFO:tensorflow:No assets to save.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: b'models/attentionRNN2/SavedModelBuilder/saved_model.pb'


# Check attention weights

In [7]:
attention_weights

array([[  3.52991700e-01,   1.88761383e-01,   1.55129313e-01,
          1.52590171e-01,   1.50527418e-01],
       [  3.52991700e-01,   1.88761383e-01,   1.55129313e-01,
          1.52590171e-01,   1.50527418e-01],
       [  3.52991700e-01,   1.88761383e-01,   1.55129313e-01,
          1.52590171e-01,   1.50527418e-01],
       [  3.52991700e-01,   1.88761383e-01,   1.55129313e-01,
          1.52590171e-01,   1.50527418e-01],
       [  3.52991700e-01,   1.88761383e-01,   1.55129313e-01,
          1.52590171e-01,   1.50527418e-01],
       [  3.52991700e-01,   1.88761383e-01,   1.55129313e-01,
          1.52590171e-01,   1.50527418e-01],
       [  3.52991700e-01,   1.88761383e-01,   1.55129313e-01,
          1.52590171e-01,   1.50527418e-01],
       [  3.52991700e-01,   1.88761383e-01,   1.55129313e-01,
          1.52590171e-01,   1.50527418e-01],
       [  9.94235814e-01,   2.82458030e-03,   1.18486839e-03,
          8.85288464e-04,   8.69476702e-04],
       [  1.17540313e-03,   9.9534660

In [12]:
print(X[0])
print(attention_weights[0])

[[0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]]
[ 0.3529917   0.18876138  0.15512931  0.15259017  0.15052742]


In [13]:
print(X[1])
print(attention_weights[1])

[[0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]]
[ 0.3529917   0.18876138  0.15512931  0.15259017  0.15052742]


In [14]:
print(X[8])
print(attention_weights[8])

[[1 0 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]]
[  9.94235814e-01   2.82458030e-03   1.18486839e-03   8.85288464e-04
   8.69476702e-04]


In [15]:
print(X[9])
print(attention_weights[9])

[[0 1 0]
 [1 0 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]]
[  1.17540313e-03   9.95346606e-01   2.22401810e-03   7.30713364e-04
   5.23235532e-04]


In [16]:
print(X[10])
print(attention_weights[10])

[[0 1 0]
 [0 1 0]
 [1 0 0]
 [0 1 0]
 [0 1 0]]
[  9.98247648e-04   5.33810526e-04   9.95851040e-01   1.98952458e-03
   6.27387140e-04]


In [17]:
print(X[-1])
print(attention_weights[-1])

[[0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 0 1]]
[  7.74430169e-04   4.14124486e-04   3.40338913e-04   3.34768440e-04
   9.98136282e-01]
