# Two-step-task as described in "Prefrontal cortex as a meta-reinforcement learning system"

This IPython notebook implements the two-step task described here: [biorxiv link](https://www.biorxiv.org/content/early/2018/04/13/295964).

Unlike our first attempt (in `biorxiv-first-try.ipynb`), in this notebook **the A2C LSTM receives only the actions taken at stage 1 and the rewards received after leaving stage 2.** The LSTM therefore sees "two steps in one", as implemented by the `trial` method of the `two_step_task` class.

Note: the biorxiv paper does not name any state at the first stage and refers to the second-stage states as S_1 and S_2. This notebook follows the terminology of the authors' [previous work on arxiv](https://arxiv.org/abs/1611.05763), in which the first stage has a single state S_1 and the second stage has two states, S_2 and S_3.

For this final step, the goal was to reproduce the plot from the [biorxiv pre-print](https://www.biorxiv.org/content/early/2018/04/13/295964) (Simulation 4, Figure b). To that end, we launched n=8 training runs with different seeds but the same hyperparameters as the paper, and compared the results with those obtained by Wang et al.

For each seed, training consisted of 20k episodes of 100 trials (instead of the 10k episodes of 100 trials used in the paper). We chose this number of episodes because, in our case, learning appeared to converge at around 20k episodes for most seeds, without any significant gap in reward before ~15k episodes.

 ![reward curve](results/biorxiv/final/reward_curve.png)
 
After training, we tested the 8 models for 300 further episodes (as in the paper), with the LSTM weights held fixed.

Here is the side-by-side comparison of our results (on the left) with the results from the paper (on the right):

![side by side](results/biorxiv/final/side_by_side.png)

Running the cells below reproduces those tests. They generate 8 plots, one per seed, of the probability of repeating an action after a common or uncommon transition, depending on whether the last action was rewarded or unrewarded. The plots are then averaged into a final plot that reproduces Figure b) of Simulation 4 in the biorxiv paper. Each seed's datapoint is shown as a black dot.
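
The stay-probability computation described above can be sketched in a few lines of NumPy. This is only an illustration, not the notebook's plotting code: the counts below are made up, and `stay_probabilities` is a hypothetical helper. The index convention `[rewarded/unrewarded, common/uncommon, repeated/switched]` follows `transition_count` in the `two_step_task` class.

```python
import numpy as np

def stay_probabilities(transition_count):
    """Return P(repeat last action) for each (reward, transition) condition.

    transition_count[r, c, a] counts trials where the previous trial was
    rewarded (r=0) or unrewarded (r=1), its transition was common (c=0) or
    uncommon (c=1), and the first-stage action was repeated (a=0) or not (a=1).
    """
    counts = np.asarray(transition_count, dtype=float)
    totals = counts.sum(axis=-1)
    # Conditions that were never observed are reported as NaN instead of dividing by zero.
    stay = np.full(totals.shape, np.nan)
    np.divide(counts[..., 0], totals, out=stay, where=totals > 0)
    return stay

# Hypothetical counts for two seeds (made-up numbers, for illustration only).
seed_counts = np.array([
    [[[90, 10], [30, 70]], [[20, 80], [60, 40]]],
    [[[80, 20], [40, 60]], [[30, 70], [70, 30]]],
])

per_seed = np.array([stay_probabilities(c) for c in seed_counts])  # one 2x2 table per seed
mean_stay = per_seed.mean(axis=0)  # average over seeds, as in the final plot
print(mean_stay)
```

Each entry of `per_seed` corresponds to one black dot in the final plot, and `mean_stay` to the height of the matching bar.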


In [1]:
import threading
import multiprocessing

import keras
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from helper import *

from random import choice
from time import sleep
from time import time



In [2]:
import tensorflow as tf
import tf_slim as slim
import scipy.signal
from PIL import Image
from PIL import ImageDraw 
from PIL import ImageFont
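tf.compat.v1.disable_eager_execution()
# Workaround suggested by TensorFlow's own error message when training a Keras LSTM in graph mode
# (assumption: it resolves the "Connecting to invalid output" error raised by model.fit at the end of this notebook).
tf.compat.v1.experimental.output_all_intermediates(True)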

### Initialization

In [3]:
# Directories/book-keeping
from datetime import datetime

dir_name = "train_" + datetime.now().strftime("%m%d-%H%M%S")

In [4]:
# integer encoding of the task states (S_1: first stage; S_2 and S_3: second stage)
S_1 = 0
S_2 = 1
S_3 = 2
nb_states = 3

### Definition of our environment: the two-step task

In [5]:
class two_step_task():
    def __init__(self):
        # start in S_1
        self.state = S_1
        
        # second-stage state with the highest expected reward; chosen at random initially
        self.highest_reward_second_stage = np.random.choice([S_2,S_3])
        
        self.num_actions = 2
        self.reset()
        
        # transition probabilities (rows: action, columns: S_2/S_3) and stay-probability counts
        common_prob = 0.8
        self.transitions = np.array([
            [common_prob, 1-common_prob],
            [1-common_prob, common_prob]
        ])
        self.transition_count = np.zeros((2,2,2))
        
        self.last_action = None
        self.last_state = None
    
    def get_state(self):
        one_hot_array = np.zeros(nb_states)
        one_hot_array[self.state] = 1
        return one_hot_array

    def possible_switch(self):
        if (np.random.uniform() < 0.025):
            # switches which of S_2 or S_3 has expected reward of 0.9
            self.highest_reward_second_stage = S_2 if (self.highest_reward_second_stage == S_3) else S_3
            
    def get_rprobs(self):
        """
        probability of reward of states S_2 and S_3, in the form [[p, 1-p], [1-p, p]]
        """
        if (self.highest_reward_second_stage == S_2):
            r_prob = 0.9
        else:
            r_prob = 0.1
        
        rewards = np.array([
            [r_prob, 1-r_prob],
            [1-r_prob, r_prob]
        ])
        return rewards
            
    def isCommon(self,action,state):
        # a transition is "common" if its probability is at least 1/2
        return self.transitions[action][state] >= 1/2
        
    def updateStateProb(self,action):
        if self.last_is_rewarded: #R
            if self.last_is_common: #C
                if self.last_action == action: #Rep
                    self.transition_count[0,0,0] += 1
                else: #URep
                    self.transition_count[0,0,1] += 1
            else: #UC
                if self.last_action == action: #Rep
                    self.transition_count[0,1,0] += 1
                else: #URep
                    self.transition_count[0,1,1] += 1
        else: #UR
            if self.last_is_common:
                if self.last_action == action:
                    self.transition_count[1,0,0] += 1
                else:
                    self.transition_count[1,0,1] += 1
            else:
                if self.last_action == action:
                    self.transition_count[1,1,0] += 1
                else:
                    self.transition_count[1,1,1] += 1
                    
        
    def stayProb(self):
        print(self.transition_count)
        row_sums = self.transition_count.sum(axis=-1)
        stay_prob = self.transition_count / row_sums[:,:,np.newaxis] 
       
        return stay_prob

    def reset(self):
        self.timestep = 0
        
        # for the two-step task plots
        self.last_is_common = None
        self.last_is_rewarded = None
        self.last_action = None
        self.last_state = None
        
        # come back to S_1 at the end of an episode
        self.state = S_1
        
        return self.get_state()
        
    def step(self,action):
        self.timestep += 1
        self.last_state = self.state
        
        # get next stage
        if (self.state == S_1):
            # get reward
            reward = 0
            # update stage
            self.state = S_2 if (np.random.uniform() < self.transitions[action][0]) else S_3
            # keep track of stay probability after first action
            if self.last_action is not None:
                self.updateStateProb(action)
            self.last_action = action
            # book-keeping for plotting
            self.last_is_common = self.isCommon(action,self.state-1)
            
        else:# case S_2 or S_3
            # get probability of reward in stage
            r_prob = 0.9 if (self.highest_reward_second_stage == self.state) else 0.1
            # get reward
            reward = 1 if np.random.uniform() < r_prob else 0
            # update stage
            self.state = S_1
            # book-keeping for plotting
            self.last_is_rewarded = reward

        # new state after the decision
        new_state = self.get_state()
        # an episode ends after 200 steps, i.e. 100 two-step trials
        done = self.timestep >= 200
        return new_state,reward,done,self.timestep
    
    def trial(self,action):
        # do one action in S_1, and keep track of the perceptually distinguishable state you arrive in
        observation,_,_,_ = self.step(action)
        # do the same action in the resulting state (S_2 or S_3). The action doesn't matter, the reward does
        _,reward,done,_ = self.step(action)
        return observation,reward,done,self.timestep
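
As a sanity check on the environment's parameters, the expected reward of each first-stage action can be worked out directly from the transition and reward probabilities used in the class (0.8/0.2 transitions, 0.9/0.1 reward probabilities). The sketch below assumes that S_2 is currently the high-reward state; the variable names are illustrative. The expected interval between reward switches additionally assumes that `possible_switch` is called once per trial, which is an assumption since its call site is not shown in this notebook.

```python
import numpy as np

common_prob = 0.8
# transitions[action] = [P(S_2 | action), P(S_3 | action)], as in two_step_task
transitions = np.array([
    [common_prob, 1 - common_prob],
    [1 - common_prob, common_prob],
])

# Reward probabilities of [S_2, S_3], assuming S_2 is currently the high-reward state.
reward_prob = np.array([0.9, 0.1])

# Expected reward per first-stage action:
#   action 0: 0.8 * 0.9 + 0.2 * 0.1
#   action 1: 0.2 * 0.9 + 0.8 * 0.1
expected_reward = transitions @ reward_prob
print(expected_reward)

# Mean of a geometric distribution with switch probability 0.025 per call
# (assumes one call to possible_switch per trial).
switch_prob = 0.025
expected_calls_between_switches = 1 / switch_prob
print(expected_calls_between_switches)
```

Under these assumptions, an agent that always picks the action commonly leading to the high-reward state earns 0.74 reward per trial on average, against 0.26 for the other action.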

### Actor-Critic Network

### Worker Agent

In [16]:
def loss_a2c(y_true, y_pred):
    squared_difference = tf.square(y_true - y_pred)

    # value_loss = 0.5 * tf.reduce_sum(input_tensor=tf.square(target_v - tf.reshape(value,[-1])))
    # entropy = - tf.reduce_sum(input_tensor=policy * tf.math.log(policy + 1e-7))
    # policy_loss = -tf.reduce_sum(input_tensor=tf.math.log(responsible_outputs + 1e-7)*advantages)
    # loss = 0.05 * value_loss + policy_loss - entropy * 0.05

    # #Get gradients from local network using local losses
    # local_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope)
    # gradients = tf.gradients(ys=loss,xs=local_vars)
    # var_norms = tf.linalg.global_norm(local_vars)
    # grads,grad_norms = tf.clip_by_global_norm(gradients,999.0)

    # #Apply local gradients to global network
    # global_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, 'global')
    # apply_grads = trainer.apply_gradients(zip(grads,global_vars))

    return tf.reduce_mean(squared_difference, axis=-1)
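
The commented-out lines in `loss_a2c` describe the intended A2C objective: a value loss, a policy loss and an entropy bonus, combined as `0.05 * value_loss + policy_loss - 0.05 * entropy`. A minimal NumPy sketch of that formula is given below to make the arithmetic explicit. It is not the TensorFlow loss used for training; the function name `a2c_loss`, its argument names and the example inputs are assumptions made for illustration.

```python
import numpy as np

def a2c_loss(policy, values, actions, target_v, advantages,
             value_coef=0.05, entropy_coef=0.05):
    """NumPy version of the loss written in the comments of loss_a2c.

    policy:     (T, n_actions) action probabilities
    values:     (T,) value estimates
    actions:    (T,) indices of the actions actually taken
    target_v:   (T,) discounted return targets
    advantages: (T,) advantage estimates
    """
    policy = np.asarray(policy, dtype=float)
    values = np.asarray(values, dtype=float)
    actions = np.asarray(actions, dtype=int)
    target_v = np.asarray(target_v, dtype=float)
    advantages = np.asarray(advantages, dtype=float)

    # probability the policy assigned to each action that was taken
    responsible_outputs = policy[np.arange(len(actions)), actions]

    value_loss = 0.5 * np.sum((target_v - values) ** 2)
    entropy = -np.sum(policy * np.log(policy + 1e-7))
    policy_loss = -np.sum(np.log(responsible_outputs + 1e-7) * advantages)
    return value_coef * value_loss + policy_loss - entropy_coef * entropy

# One-step example with a uniform policy over the two actions (made-up inputs).
loss = a2c_loss(policy=[[0.5, 0.5]], values=[0.0], actions=[0],
                target_v=[1.0], advantages=[1.0])
print(loss)
```

The `1e-7` offsets mirror the commented code and only guard against `log(0)`.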



In [19]:
#model for two-step task
model = keras.models.Sequential()
model.add(keras.layers.LSTM(48, input_shape=(None,3)))
model.add(keras.layers.Dense(2,activation='softmax'))

model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy']
)
model.summary()


Model: "sequential_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 lstm_3 (LSTM)               (None, 48)                9984      
                                                                 
 dense_3 (Dense)             (None, 2)                 98        
                                                                 
Total params: 10,082
Trainable params: 10,082
Non-trainable params: 0
_________________________________________________________________


## Training

In [20]:
# Hyperparameters for training/testing
gamma = .9
a_size = 2 
n_seeds = 8
num_episode_train = 20000
num_episode_test = 300
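
For reference, here is a minimal sketch of how a discount factor such as `gamma = .9` is typically applied to a sequence of rewards. It assumes that `gamma` plays its usual role as the discount factor for returns; `discount` is a hypothetical helper and the reward sequence is made up.

```python
import numpy as np

def discount(rewards, gamma=0.9):
    """Discounted return G_t = r_t + gamma * G_{t+1}, computed backwards in time."""
    returns = np.zeros(len(rewards))
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

# Made-up reward sequence over four trials.
returns = discount([0, 1, 0, 1], gamma=0.9)
print(returns)
```

With these rewards the last return is simply the last reward, and each earlier return adds 0.9 times the one that follows it.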

In [21]:
game = two_step_task()
done = False
reward = 0
action = 0
t = 0
state = game.reset()
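# add batch and time dimensions: (batch=1, timesteps=1, features=3)
lstm_input = state.reshape(1,1,3)
trained = model.predict(lstm_input)
print(trained)
output = trained.reshape(1,2)

# smoke test: fit the model on its own prediction for one epoch
model.fit(lstm_input,output,epochs=1)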

# for i in range(10):
#     print(model.predict(state.reshape(1,1,3)))
#     # action = model.__call__(state.reshape(1,1,3))
#     # game.step(action)


InvalidArgumentError: Graph execution error:

Node 'training/Adam/gradients/gradients/lstm/while_grad/lstm/while_grad': Connecting to invalid output 45 of source node lstm/while which has 45 outputs. Try using tf.compat.v1.experimental.output_all_intermediates(True).