In [1]:
import numpy as np
import tensorflow as tf
import random
import matplotlib.pyplot as plt
%matplotlib inline



In [2]:
# Shared experiment configuration.
batch_size = 4    # examples per minibatch; fixes the leading dim of the placeholders below
num_batches = 3   # number of minibatches generated per experiment

# Seed the Python and NumPy RNGs so the randomly generated inputs are reproducible
# on Restart & Run All.
# NOTE(review): TensorFlow weight initialisation is not covered by these seeds; it would
# need a graph-level tf.set_random_seed in each cell that builds a graph.
SEED = 0
np.random.seed(SEED)
random.seed(SEED)

In [3]:
# Define a feedforward NN representing the sensory-motor system.
# Start from a fresh default graph so re-running this cell does not pile up duplicate
# nodes (re-created placeholders would otherwise be silently renamed 'input_vis_1', ...).
tf.reset_default_graph()

input_vis = tf.placeholder(shape=[batch_size, 2], dtype=tf.float32, name='input_vis')
input_aud = tf.placeholder(shape=[batch_size, 2], dtype=tf.float32, name='input_aud')
input_total = tf.concat([input_vis, input_aud], axis=-1)   # [batch_size, 4]

num_hidden_sensorymotor = 10
hidden_sensorymotor = tf.contrib.layers.fully_connected(input_total, num_hidden_sensorymotor, activation_fn=tf.nn.relu)
num_out_sensorymotor = 2
# NOTE(review): a ReLU on the output layer clamps negative predictions to 0 and blocks their
# gradient; the training targets used below are non-negative, but a linear output
# (activation_fn=None) may train more reliably — consider.
out_sensorymotor = tf.contrib.layers.fully_connected(hidden_sensorymotor, num_out_sensorymotor, activation_fn=tf.nn.relu)

In [4]:
# Sanity check: train the sensory-motor system to ignore audition, i.e. to reproduce the
# visual input whatever the auditory input is. Disabled by default.
run_sensorymotor_test = False

if run_sensorymotor_test:
    # Mean over the batch of the per-example Euclidean error. Without `axis`, tf.norm takes a
    # single norm over the whole [batch_size, 2] tensor, which made reduce_mean a no-op.
    loss = tf.reduce_mean(tf.norm(out_sensorymotor - input_vis, axis=1))
    learning_rate = 0.01
    optimizer = tf.train.AdamOptimizer(learning_rate)
    train_op = optimizer.minimize(loss)

    sess = tf.Session()
    # Initialise after building train_op so any optimizer variables it created are included.
    # (initialize_all_variables is the deprecated name for this op.)
    sess.run(tf.global_variables_initializer())

    # Uniform [0, 1) inputs, one [batch_size, 2] slice per batch.
    in_vis_list = np.random.rand(num_batches, batch_size, 2)
    in_aud_list = np.random.rand(num_batches, batch_size, 2)

    losses = []
    for i in range(num_batches):
        batch_loss, _ = sess.run([loss, train_op],
                                 feed_dict={input_vis: in_vis_list[i], input_aud: in_aud_list[i]})
        losses.append(batch_loss)

    plt.figure()
    plt.plot(losses)
    plt.title("Loss function versus number of batches")
    plt.xlabel("Batch")
    plt.ylabel("Loss")
    plt.show()

In [5]:
# Define a feedforward NN representing the sensory-motor system with an ancillary rule label input.
# Reset the default graph first: otherwise these placeholders collide with the ones built
# earlier and TF silently renames them (e.g. 'input_vis_1'), and stale layers linger.
tf.reset_default_graph()

input_vis = tf.placeholder(shape=[batch_size, 2], dtype=tf.float32, name='input_vis')
input_aud = tf.placeholder(shape=[batch_size, 2], dtype=tf.float32, name='input_aud')
input_cue = tf.placeholder(shape=[batch_size, 1], dtype=tf.float32, name='input_cue')
input_total = tf.concat([input_vis, input_aud, input_cue], axis=-1)   # [batch_size, 5]

num_hidden_sensorymotor = 10
hidden_sensorymotor = tf.contrib.layers.fully_connected(input_total, num_hidden_sensorymotor, activation_fn=tf.nn.relu)
num_out_sensorymotor = 2
out_sensorymotor = tf.contrib.layers.fully_connected(hidden_sensorymotor, num_out_sensorymotor, activation_fn=tf.nn.relu)

In [6]:
# Train the sensory-motor system to follow the given rule:
# cue > 0.5 -> reproduce the visual input, otherwise reproduce the auditory input.
run_sensorymotor_rule_test = False

if run_sensorymotor_rule_test:

    # Uniform [0, 1) inputs and cues, one batch-sized slice per batch.
    in_vis_list = np.random.rand(num_batches, batch_size, 2)
    in_aud_list = np.random.rand(num_batches, batch_size, 2)
    in_cue_list = np.random.rand(num_batches, batch_size, 1)

    # Build the rule-dependent loss ONCE, in-graph. The cue picks each example's target:
    # use_vis is [batch_size, 1] of 0/1 and broadcasts against the [batch_size, 2] inputs.
    # The loss is the sum over the batch of per-example Euclidean errors (same objective as
    # the former per-row tf.gather loop, which rebuilt the graph on every batch).
    use_vis = tf.cast(input_cue > 0.5, tf.float32)
    target = use_vis * input_vis + (1.0 - use_vis) * input_aud
    loss = tf.reduce_sum(tf.norm(out_sensorymotor - target, axis=1))

    learning_rate = 0.01
    optimizer = tf.train.AdamOptimizer(learning_rate)
    train_op = optimizer.minimize(loss)

    sess = tf.Session()
    # Initialise exactly once, after train_op exists. Re-initialising inside the loop
    # (as before) re-randomised the weights every batch, so nothing was ever learned.
    sess.run(tf.global_variables_initializer())

    losses = []
    for i in range(num_batches):
        feed = {input_vis: in_vis_list[i], input_aud: in_aud_list[i], input_cue: in_cue_list[i]}
        batch_loss, _ = sess.run([loss, train_op], feed_dict=feed)
        losses.append(batch_loss)
        print(batch_loss)

    plt.figure()
    plt.plot(losses)
    plt.title("Loss function versus number of batches")
    plt.xlabel("Batch")
    plt.ylabel("Loss")
    plt.show()

In [7]:
# Next: a mapping from cues to rules.
# The rule tied to each cue drifts over time, so a recurrent network must learn and retain
# the recent mapping, without clinging to it once the mapping has changed.

# --- PFC: a small LSTM that integrates the cue sequence ---
num_units_PFC = 5

PFC_cell = tf.contrib.rnn.LSTMBlockCell(num_units_PFC)
# All-zero starting state for each sequence in the batch; replaced step by step as the cell is unrolled.
PFC_state_previous = PFC_cell.zero_state(batch_size, dtype=tf.float32)

# This does one cycle of the RNN
def PFC_step(input_data, network_state):
    """Advance the PFC LSTM by a single timestep.

    Args:
        input_data: input tensor for this step (the unrolling loop passes one
            cue slice per call).
        network_state: cell state carried over from the previous step.

    Returns:
        The result of one PFC_cell call; callers index [0] as the step output
        and [1] as the next state.
    """
    # NOTE(review): reuse=False still works across the repeated unrolled calls, presumably
    # because the cell object keeps its own weights after the first call — confirm.
    pfc_scope = tf.variable_scope("PFC", reuse=False)
    with pfc_scope:
        step_result = PFC_cell(input_data, network_state)
    return step_result
    
# Cue inputs: one scalar cue per timestep for each sequence in the batch.

num_timesteps = 5
cue_timeseries = tf.placeholder(shape=[batch_size, num_timesteps, 1], dtype=tf.float32, name='cues_timeseries')

# Unroll the PFC over time. The readout layer is created at t == 0 under a fixed scope and
# reused at every later step, so all timesteps share one set of readout weights. Previously
# a fresh, independently initialised readout layer was created at every step.
for t in range(num_timesteps):

    current_cue = cue_timeseries[:, t]   # [batch_size, 1]

    # PFC_state is the (output, next_state) pair from one cell step.
    PFC_state = PFC_step(input_data=current_cue, network_state=PFC_state_previous)
    # The output from the PFC into the sensory-motor system is what we previously called
    # the cue variable; we now call it the rule.
    PFC_output = tf.contrib.layers.fully_connected(PFC_state[0], 1, activation_fn=tf.nn.relu,
                                                   scope='PFC_readout', reuse=(t > 0))

    PFC_state_previous = PFC_state[1]

sess = tf.Session()

# Random cue sequences in [0, 1): one [batch_size, num_timesteps, 1] block per batch.
ct = np.random.rand(num_batches, batch_size, num_timesteps, 1)

# initialize_all_variables is deprecated; global_variables_initializer is the replacement.
sess.run(tf.global_variables_initializer())

for b in range(num_batches):
    ct_in = ct[b]
    # Fetch the readout after the final timestep (a one-element list, as printed below).
    o = sess.run([PFC_output], feed_dict={cue_timeseries: ct_in})
    print(o)   # print(...) form works under both Python 2 and 3

[[[ 0.91189957]
  [ 0.19208222]
  [ 0.0549464 ]
  [ 0.0544428 ]
  [ 0.93110164]]

 [[ 0.38644636]
  [ 0.35380668]
  [ 0.22335302]
  [ 0.76741863]
  [ 0.17695578]]

 [[ 0.08195311]
  [ 0.74056103]
  [ 0.10641894]
  [ 0.30737707]
  [ 0.77039321]]

 [[ 0.09094108]
  [ 0.50670226]
  [ 0.28319575]
  [ 0.65714104]
  [ 0.26613835]]]
[array([[ 0.14098595],
       [ 0.11913574],
       [ 0.1442949 ],
       [ 0.12013169]], dtype=float32)]
[[[ 0.17535359]
  [ 0.82616926]
  [ 0.53789599]
  [ 0.1898887 ]
  [ 0.99074654]]

 [[ 0.46410337]
  [ 0.97355378]
  [ 0.44175653]
  [ 0.66903237]
  [ 0.34107548]]

 [[ 0.33992638]
  [ 0.53098467]
  [ 0.44418646]
  [ 0.31725705]
  [ 0.98292983]]

 [[ 0.16214362]
  [ 0.88608922]
  [ 0.24742835]
  [ 0.89696517]
  [ 0.08556624]]]
[array([[ 0.18814668],
       [ 0.17052776],
       [ 0.18522708],
       [ 0.13535105]], dtype=float32)]
[[[ 0.22805152]
  [ 0.95912999]
  [ 0.34717236]
  [ 0.00925994]
  [ 0.90364302]]

 [[ 0.64765425]
  [ 0.67439162]
  [ 0.05964206]
  