Same architecture as in hats_riddle_keras.ipynb, implemented directly in TensorFlow

In [1]:
import tensorflow as tf

Using the notation from the article, we first declare $a^k$ and $s^k$ (the inputs for the actions and the hats observed by each agent) and $(m, n)$ (the index of each agent and the total number of agents).

In [2]:
#a^k
actions_input      = tf.placeholder(tf.float64, shape=(1, 1))
#s^k
observations_input = tf.placeholder(tf.float64, shape=(1, 1))

#(m, n)
index_input        = tf.placeholder(tf.float64, shape=(1, 2))
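To make the placeholder shapes concrete, here is a minimal sketch of how the values for a single agent could be packed before being fed to the graph. The helper name `encode_agent_inputs` and the choice of encoding an action or a hat colour as a single number are assumptions made for illustration; they do not come from the article.

```python
def encode_agent_inputs(action, observation, m, n):
    """Return feed values shaped like the three placeholders above.

    Hypothetical helper: encoding an action or a hat colour as one number
    (for example 0 or 1) is an assumption made for this sketch.
    """
    actions = [[float(action)]]            # shape (1, 1), matches a^k
    observations = [[float(observation)]]  # shape (1, 1), matches s^k
    index = [[float(m), float(n)]]         # shape (1, 2), matches (m, n)
    return actions, observations, index


actions, observations, index = encode_agent_inputs(action=1, observation=0, m=3, n=10)
print(actions, observations, index)  # [[1.0]] [[0.0]] [[3.0, 10.0]]
```

The three nested lists can be passed as the values of a `feed_dict` keyed by the corresponding placeholders.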

Now we define $z^k_a$ and $z^k_s$, the outputs of the MLPs, which will be fed to the LSTM networks.

In [3]:
#z_a^k
actions_mlp        = tf.layers.dense(actions_input, 64)
actions_index_mlp  = tf.layers.dense(index_input, 64)

actions_lstm_input = tf.add(actions_mlp, actions_index_mlp)

#z_s^k
observations_mlp        = tf.layers.dense(observations_input, 64)
observations_index_mlp  = tf.layers.dense(index_input, 64)

observations_lstm_input = tf.add(observations_mlp, observations_index_mlp)
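The cell above builds each LSTM input as the sum of two linear embeddings: one of the scalar input and one of the index pair $(m, n)$. The dependency-free sketch below reproduces that computation with plain Python lists so the shapes are easy to follow. The weight initialisation, the seed and the sample values are assumptions chosen only for illustration.

```python
import random


def dense(inputs, weights, bias):
    """Affine layer without activation: (1 x d_in) inputs times (d_in x d_out) weights plus bias."""
    row = inputs[0]
    return [[sum(x * w for x, w in zip(row, column)) + b
             for column, b in zip(zip(*weights), bias)]]


def random_weights(d_in, d_out, rng):
    """Small random weight matrix with d_in rows and d_out columns (illustrative only)."""
    return [[rng.uniform(-0.1, 0.1) for _ in range(d_out)] for _ in range(d_in)]


rng = random.Random(0)
units = 64
w_value, b_value = random_weights(1, units, rng), [0.0] * units
w_index, b_index = random_weights(2, units, rng), [0.0] * units

value = [[1.0]]        # stands for a^k or s^k
index = [[3.0, 10.0]]  # stands for (m, n)

value_embedding = dense(value, w_value, b_value)
index_embedding = dense(index, w_index, b_index)

# element-wise sum of the two embeddings, as tf.add does in the cell above
z = [[v + i for v, i in zip(value_embedding[0], index_embedding[0])]]
print(len(z), len(z[0]))  # 1 64
```

Because neither layer has an activation, the sum is equivalent to a single linear layer applied to the concatenated vector $(a^k, m, n)$ with the two biases added together.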

Now we define the LSTMs, whose outputs $y_a^k, h_a^k$ and $y_s^k, h_s^k$ will be passed to another MLP to predict the final decision.

In [4]:
#actions lstm construction

#reshape the output of the dense layer into [batch, time, features],
#i.e. a sequence of length 1 with 64 features
actions_lstm_input      = tf.reshape(actions_lstm_input, [1, 1, -1])
actions_lstm_cell       = tf.contrib.rnn.BasicLSTMCell(64)
actions_lstm_outputs, _ = tf.nn.dynamic_rnn(actions_lstm_cell,
                                            actions_lstm_input,
                                            dtype=tf.float64)
# _ is actions_lstm_final_state

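#same approach for the observations lstm
#note: with reuse=True and the default "rnn" variable scope, this cell appears to
#share its weights with the actions LSTM above
observations_lstm_input      = tf.reshape(observations_lstm_input, [1, 1, -1])
observations_lstm_cell       = tf.contrib.rnn.BasicLSTMCell(64, reuse=True)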
observations_lstm_outputs, _ = tf.nn.dynamic_rnn(observations_lstm_cell,\
                                                 observations_lstm_input,\
                                                 dtype=tf.float64)
# _ is observations_lstm_final_state
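To illustrate what one LSTM produces for a sequence of length 1, here is a small sketch of a single step using the standard LSTM equations and only the Python standard library. The sizes are reduced to 3 inputs and 4 units for readability (the notebook uses 64 and 64), the weights are random, and the gate ordering is an assumption of this sketch; TensorFlow's cell may differ in details such as the forget-gate bias.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def lstm_step(x, h, c, kernel, bias):
    """One LSTM step for a single example.

    x: input vector; h, c: previous hidden and cell state (length `units`);
    kernel: (len(x) + units) rows by 4 * units columns; bias: 4 * units values.
    """
    units = len(h)
    xh = x + h  # concatenate the input with the previous hidden state
    gates = [sum(v * row[j] for v, row in zip(xh, kernel)) + bias[j]
             for j in range(4 * units)]
    # split into input gate, candidate, forget gate, output gate
    i, g, f, o = (gates[k * units:(k + 1) * units] for k in range(4))
    new_c = [sigmoid(fk) * ck + sigmoid(ik) * math.tanh(gk)
             for fk, ck, ik, gk in zip(f, c, i, g)]
    new_h = [sigmoid(ok) * math.tanh(ck) for ok, ck in zip(o, new_c)]
    return new_h, new_c


rng = random.Random(0)
input_size, units = 3, 4
kernel = [[rng.uniform(-0.5, 0.5) for _ in range(4 * units)]
          for _ in range(input_size + units)]
bias = [0.0] * (4 * units)

sequence = [[0.2, -0.1, 0.4]]  # a single time step, like the reshape to [1, 1, -1]
h, c = [0.0] * units, [0.0] * units  # zero initial state
outputs = []
for x in sequence:
    h, c = lstm_step(x, h, c, kernel, bias)
    outputs.append(h)

print(len(outputs), len(outputs[-1]))  # 1 4
```

With a single time step, the last output and the final hidden state are the same vector, which is why only the outputs are kept in the cell above.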

Now we create the last part of the model. We feed the output of the LSTMs to an MLP in order to get the predictions $Q^m, m \in [1, n]$.

In [5]:
prediction_inputs  = tf.concat([observations_lstm_outputs[-1], actions_lstm_outputs[-1]], axis=1)

#each layer consumes the output of the previous one
prediction_layer_1 = tf.layers.dense(prediction_inputs, 64, activation=tf.nn.relu)
prediction_layer_2 = tf.layers.dense(prediction_layer_1, 64, activation=tf.nn.relu)
prediction_output  = tf.layers.dense(prediction_layer_2, 1, activation=tf.nn.relu)
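As a shape check for the prediction head, the sketch below traces a vector through a three-layer MLP in plain Python, assuming each layer consumes the output of the previous one. The two 64-dimensional LSTM outputs are concatenated into 128 features and reduced to a single value. The random stand-ins for the LSTM outputs and the weight initialisation are assumptions made for illustration.

```python
import random


def dense_relu(row, weights, bias):
    """Fully connected layer followed by ReLU, applied to a single input row."""
    return [max(0.0, sum(x * w for x, w in zip(row, column)) + b)
            for column, b in zip(zip(*weights), bias)]


def init_layer(d_in, d_out, rng):
    """Random weights (d_in rows, d_out columns) and a zero bias, for illustration only."""
    weights = [[rng.uniform(-0.1, 0.1) for _ in range(d_out)] for _ in range(d_in)]
    return weights, [0.0] * d_out


rng = random.Random(0)
y_s = [rng.uniform(-1, 1) for _ in range(64)]  # stand-in for the observations LSTM output
y_a = [rng.uniform(-1, 1) for _ in range(64)]  # stand-in for the actions LSTM output

features = y_s + y_a  # concatenation along the feature axis
activations = features
for d_in, d_out in [(128, 64), (64, 64), (64, 1)]:
    weights, bias = init_layer(d_in, d_out, rng)
    activations = dense_relu(activations, weights, bias)

q_value = activations
print(len(features), len(q_value))  # 128 1
```

Note that a ReLU on the last layer restricts the prediction to non-negative values.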

Let's make a mock prediction:

In [6]:
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    output = sess.run(prediction_output,\
                      feed_dict={actions_input: [[1]], observations_input: [[1]], index_input: [[1, 1]]})

In [7]:
output

array([[ 0.00161385]])

Test that everything behaves as expected (i.e., that the tensor shapes are correct).

In [45]:
import unittest

class TestArchitecture(unittest.TestCase):

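    def test_lstm_input_shapes(self):
        # after the reshape, the LSTM inputs are [batch, time, features]
        self.assertEqual(actions_lstm_input.shape, (1, 1, 64))
        self.assertEqual(observations_lstm_input.shape, (1, 1, 64))

    def test_lstm_output_shapes(self):
        self.assertEqual(actions_lstm_outputs[-1].shape, (1, 64))
        self.assertEqual(observations_lstm_outputs[-1].shape, (1, 64))

# run the test case explicitly; defining the class alone does not execute it
suite  = unittest.TestLoader().loadTestsFromTestCase(TestArchitecture)
result = unittest.TextTestRunner().run(suite)

if result.wasSuccessful():
    print("Everything is fine until now!")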

Everything is fine until now!
