In [None]:
import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F
from matplotlib import pyplot as plt
from torch.autograd import Variable

In [None]:
# Seed every RNG the notebook draws from so that "Restart Kernel -> Run All" is reproducible:
# NumPy drives the stimuli and the epsilon-greedy exploration, PyTorch drives the network weight init.
np.random.seed(0)
torch.manual_seed(0)

num_objects = 10  # each stimulus is an integer drawn from [0, num_objects)
num_stimuli = 4   # number of stimuli presented at every step of the task

class Shadlen:
    """Emulator of a multi-step binary decision task.

    The environment reads the module-level constants ``num_objects`` and
    ``num_stimuli``. At every step it shows ``num_stimuli`` random integers in
    ``[0, num_objects)``; the observable state is the one-hot encoding of
    ``[step] + stimuli``.

    An action is compared with a ground truth derived from the normalised sum
    ``sum(stimuli) / (num_objects * num_stimuli)``:

    * 1st and 3rd action: ground truth is ``float(normalised_sum > criterion)``.
    * 2nd action: the rule is inverted, ``float(normalised_sum < criterion)``.

    Reward schedule (as implemented in ``response``):

    * after the 1st action: reward 0, the episode continues;
    * after the 2nd action: if the 1st was correct, reward 0 and the episode
      continues; otherwise the episode ends with 1.2 (2nd correct) or -2 (2nd wrong);
    * after the 3rd action the episode ends with 3.5 / 0.5 when the 2nd action
      was correct (3rd correct / wrong) and 0 / -1 when it was wrong.
    """

    def __init__(self):
        self.actions_correctness = []  # one bool per action taken in the current episode
        self.step = 0                  # number of actions already taken in this episode
        self.criterion = 0.5           # threshold applied to the normalised stimulus sum
        self.stimuli = self.generate_stimuli(num_stimuli)
        self.state = self.to_one_hot([self.step] + self.stimuli)

    def to_one_hot(self, array, N=10):
        """Return a float32 matrix of shape ``(len(array), N)`` with a single 1 per row.

        Row ``i`` has its 1 in column ``array[i]``; every entry must therefore be
        an integer in ``[0, N)``.
        """
        array = np.array(array).astype(int)
        labels = np.zeros((array.shape[0], N), dtype=np.float32)
        labels[np.arange(array.shape[0]), array] = 1.
        return labels

    def generate_stimuli(self, N_s):
        """Draw ``N_s`` integers uniformly from ``[0, num_objects)`` and return them as a list."""
        return [np.random.randint(0, num_objects) for _ in range(N_s)]

    def reset(self):
        """Start a new episode (fresh stimuli, step 0) and return its initial state."""
        self.__init__()
        return self.state

    def response(self, action):
        """Score ``action`` against the current stimuli and advance the episode.

        Parameters
        ----------
        action : int
            The agent's choice; compared for equality with the ground truth (0.0 or 1.0).

        Returns
        -------
        tuple
            ``(reward, state, finished)`` where ``state`` is the one-hot encoding
            of the new step index and freshly drawn stimuli.

        Raises
        ------
        RuntimeError
            If three actions have already been taken in this episode. The check
            runs before any state is modified, so the environment is left intact.
        """
        if self.step >= 3:
            # Previously this case printed "StepErr!" and returned None, which made
            # callers fail later with an unrelated tuple-unpacking error.
            raise RuntimeError(
                "StepErr! response() called after the episode ended; call reset() first."
            )

        normalised_sum = np.sum(self.stimuli) / (num_objects * num_stimuli)
        if self.step == 1:
            # The second decision uses the inverted rule.
            GT = float(normalised_sum < self.criterion)
        else:
            GT = float(normalised_sum > self.criterion)

        self.actions_correctness.append(GT == action)
        self.step += 1

        finished = False
        if self.step == 1:
            reward = 0.
        elif self.step == 2:
            if self.actions_correctness[-2]:
                reward = 0.
            else:
                # A wrong first action ends the episode after the second one.
                finished = True
                reward = 1.2 if self.actions_correctness[-1] else -2.
        else:  # self.step == 3
            finished = True
            if self.actions_correctness[-2]:
                reward = 3.5 if self.actions_correctness[-1] else 0.5
            else:
                reward = 0. if self.actions_correctness[-1] else -1.

        # New stimuli are drawn after every action, including the final one.
        self.stimuli = self.generate_stimuli(num_stimuli)
        self.state = self.to_one_hot([self.step] + self.stimuli)

        return reward, self.state, finished


In [None]:
# We use the default LSTM. Each state is a 5x10 one-hot matrix, fed to the LSTM as a sequence of
# 5 vectors with 10 input features each. The number of hidden units is a free choice; for now we use 100.
# To compute Q we apply a linear layer to the final hidden state of the LSTM; it is implemented directly in the module.

# Q-network configuration.
# Width of one input vector; equals the default one-hot width N=10 of Shadlen.to_one_hot.
input_feature = 10
# Size of the LSTM hidden state.
number_of_hidden = 100
# One Q-value per action; the task has two actions (0 and 1).
output_feature = 2
# Number of stacked LSTM layers.
number_of_layers = 3

# Shared task environment; Monkey.play and the evaluation cell read this global.
emulator = Shadlen()

class MonkeyLSTM(nn.Module):
    """Stacked LSTM with a linear read-out producing Q-values.

    The whole observation is presented as one sequence with batch size 1; the
    final hidden state of the last LSTM layer is mapped to ``output_feature``
    Q-values by ``make_q``.
    """

    def __init__(self, input_feature, hidden_feature, output_feature, number_of_layers):
        super(MonkeyLSTM, self).__init__()
        self.input_feature = input_feature
        self.hidden_feature = hidden_feature
        self.output_feature = output_feature
        self.layers = number_of_layers
        # Creation order (linear first, then LSTM) is kept so seeded weight init is unchanged.
        self.make_q = nn.Linear(hidden_feature, output_feature)
        self.lstm = nn.LSTM(input_feature, hidden_feature, number_of_layers)
        self.reset_hiddens()

    def forward(self, input_data):
        """Return Q-values of shape ``(1, output_feature)`` for one observation.

        ``input_data`` may have any shape whose element count is a multiple of
        ``input_feature``; it is viewed as ``(seq_len, 1, input_feature)``.
        The sequence length is inferred instead of being fixed to 5, which gives
        the same result for the 5-row states used in this notebook.

        The stored ``h_t`` / ``c_t`` serve only as the initial state: the state
        returned by the LSTM is not written back to the module.
        """
        sequence = input_data.view(-1, 1, self.input_feature)  # (seq_len, batch=1, features)
        _, (h_t_next, _) = self.lstm(sequence, (self.h_t, self.c_t))
        # h_t_next has one entry per layer; the last one belongs to the top layer.
        return self.make_q(h_t_next[-1])

    def reset_hiddens(self):
        """Set the initial hidden and cell states to zeros for a batch of size 1."""
        self.h_t = torch.zeros(self.layers, 1, self.hidden_feature)
        self.c_t = torch.zeros(self.layers, 1, self.hidden_feature)

class Monkey:
    """Epsilon-greedy Q-learning agent for the two-action task.

    ``play`` relies on two notebook-level objects: ``monkey_lstm`` (the
    Q-network, exposing ``forward``, ``reset_hiddens`` and ``parameters``) and
    ``emulator`` (the environment, exposing ``reset`` and ``response``).
    """

    def __init__(self, gamma, epsilon_0):
        self.gamma = gamma          # discount factor used in the TD target
        self.epsilon_0 = epsilon_0  # exploration rate at episode 0
        self.terminal = False
        self.action = 0
        self.reward = 0

    def init_monkey(self, init_state):
        """Store ``init_state`` (a NumPy array) as a tensor and clear the terminal flag."""
        self.state = torch.from_numpy(init_state)  # we use the emulator initial state for the Monkey
        self.terminal = False

    def play(self, number_of_episodes, window, lr):
        """Train the global Q-network with one-step Q-learning and plot the progress.

        Parameters
        ----------
        number_of_episodes : int
            Number of training episodes.
        window : int
            Every ``window`` completed episodes, the mean episode reward over
            that window and the loss of the most recent terminal step are recorded.
        lr : float
            Learning rate of the Adam optimiser.
        """
        loss_list = []
        reward_list = []
        running_reward = 0.0
        loss_end = 0.0
        optimizer = optim.Adam(monkey_lstm.parameters(), lr)
        criterion = nn.MSELoss()

        for episode in range(number_of_episodes):
            state = torch.from_numpy(emulator.reset())
            # Use this instance, not the notebook-level `monkey` variable.
            self.lessGreedy(episode)
            finished = False

            while not finished:
                monkey_lstm.reset_hiddens()
                optimizer.zero_grad()
                q = monkey_lstm.forward(state)

                # Epsilon-greedy choice between the two actions (ties go to action 0).
                if np.random.rand() < self.epsilon:
                    action = np.random.randint(0, 2)
                else:
                    action = 1 if q[0, 0] < q[0, 1] else 0

                reward, next_state, finished = emulator.response(action)
                next_state = torch.from_numpy(next_state)

                # TD target: r at terminal steps, r + gamma * max_a Q(s', a) otherwise.
                # The bootstrap value is computed without autograd so that only
                # Q(s, a) is trained towards the target, not the target itself.
                y = torch.tensor(reward, dtype=q.dtype)
                if not finished:
                    with torch.no_grad():
                        monkey_lstm.reset_hiddens()
                        q_next = monkey_lstm.forward(next_state)
                    y = y + self.gamma * torch.max(q_next[0, 0], q_next[0, 1])

                loss = criterion(q[0, action], y)
                loss.backward()
                optimizer.step()

                state = next_state
                running_reward = running_reward + reward

                if finished:
                    # Keep a plain float: no autograd graph is retained and it can be plotted.
                    loss_end = loss.item()

            # Record once per *completed* window, so no placeholder point is logged
            # before training and the final window is not dropped.
            if (episode + 1) % window == 0:
                reward_list.append(running_reward / window)
                running_reward = 0.0
                loss_list.append(loss_end)

        plt.plot(reward_list)
        plt.title('Mean reward')
        plt.figure()
        plt.plot(loss_list)
        plt.title('Loss')

    def lessGreedy(self, episode):
        """Set the exploration rate to ``epsilon_0 * exp(-episode / 1000)``."""
        self.epsilon = (self.epsilon_0) * np.exp(-episode / 1000)

In [None]:
# Build the Q-network from the configuration constants of the earlier cell
# (input_feature, number_of_hidden, output_feature, number_of_layers) instead of repeating the literals.
monkey_lstm = MonkeyLSTM(
    input_feature=input_feature,
    hidden_feature=number_of_hidden,
    output_feature=output_feature,
    number_of_layers=number_of_layers,
)

# gamma is the discount factor, epsilon_0 the initial exploration rate.
monkey = Monkey(gamma=0.4, epsilon_0=0.6)

# Long-running cell: 20000 training episodes, statistics recorded every 200 episodes.
monkey.play(number_of_episodes=20000, window=200, lr=1e-3)

In [None]:
# Greedy evaluation of the trained network: 500 episodes without exploration or learning.
# (Monkey.play requires training arguments and never updates monkey.terminal / monkey.reward,
# so the rollout is done explicitly here.)
reward_test = list()

for test in range(500):
    init_state = emulator.reset()
    monkey.init_monkey(init_state)   # stores the state tensor and clears monkey.terminal
    monkey.reward = 0.0              # total reward collected in this episode

    while not monkey.terminal:
        monkey_lstm.reset_hiddens()
        with torch.no_grad():        # inference only, no gradients needed
            q = monkey_lstm.forward(monkey.state)

        # Greedy action; ties go to action 0, as in training.
        monkey.action = 1 if q[0, 0] < q[0, 1] else 0

        reward, next_state, monkey.terminal = emulator.response(monkey.action)
        monkey.state = torch.from_numpy(next_state)
        monkey.reward += reward

    reward_test.append(monkey.reward)

plt.plot(reward_test)
        

In [None]:
# Scratch plot of exp(i / 100) for i = 0..99.
x = [np.exp(i / 100) for i in range(100)]
plt.plot(x)
# Raw string: '\A' is not a valid escape sequence, so a plain literal triggers an
# invalid-escape warning. The text passed to matplotlib is unchanged.
plt.title(r'title with \Alpha = and %epsilon=0.5')