In [1]:
import gymnasium as gym
import random
import numpy as np
import pandas as pd
from scipy import interpolate

# Reward-probability traces for the two options (option 1 = red, option 2 = blue).
# Each CSV provides a 'trial' column and a 'p' column; downstream code divides
# 'p' by 100, so it is treated as a percentage.
def _linear_trace(trace):
    """Return a linear interpolator of ``p`` over ``trial`` that extrapolates outside the data."""
    return interpolate.interp1d(trace['trial'], trace['p'], kind='linear', fill_value='extrapolate')


df_red = pd.read_csv('../path_trace_foraging_red.csv')
df_blue = pd.read_csv('../path_trace_foraging_blue.csv')

# Callable probability functions: p_red(t) / p_blue(t) give the percentage at trial t.
p_red = _linear_trace(df_red)
p_blue = _linear_trace(df_blue)


In [None]:
# Size of the simulation: number of trials and number of options (0 = red, 1 = blue).
max_steps = 300
k = 2

# Per-trial logs. Entries stay NaN until the corresponding trial has been simulated.
chosen_options = np.full(max_steps, np.nan)
rewards = np.full(max_steps, np.nan)

# Value estimates, one row per option and one column per trial.
# A conventional Q-table would be (n_states, k) and updated in place, but this task
# has no states and the whole history of estimates is kept so it can be plotted later.
q_table = np.zeros((k, max_steps))

# Starting estimates. The paper notes that participants were told at the beginning
# whether the decks were good, bad or mediocre, so the model could receive that
# information as a bias on the first value of each deck (e.g. 1 for red, 0 for blue).
# Interestingly the paper does not do that: both options start at 0.5, as done here.
q_table[:, 0] = 0.5

# v = np.zeros((n_decks,max_pulls))
def choose_option(time, beta, rng):
    """Pick red (0) or blue (1) with a logistic rule on the current value difference.

    Reads the module-level ``q_table``, ``p_red`` and ``p_blue``.

    Args:
        time: column of ``q_table`` to read; also passed as the trial to the
            probability traces.
        beta: inverse temperature of the logistic choice rule.
        rng: object exposing ``uniform(low, high)``.

    Returns:
        Tuple ``(chosen_index, reward_probability, unchosen_index)`` where the
        probability is the trace value divided by 100.
    """
    # NOTE(review): callers pass a 0-based step while the figures plot trials
    # 1..300 -- confirm whether the traces should be sampled at time + 1.
    value_gap = q_table[0, time] - q_table[1, time]
    p_choose_red = 1 / (1 + np.exp(-beta * value_gap))
    picked_red = rng.uniform(0, 1) < p_choose_red
    if not picked_red:
        return 1, p_blue(time) / 100.0, 0
    return 0, p_red(time) / 100.0, 1
    


In [8]:
def get_reward(probability, rng):
    """Draw a Bernoulli reward: 1 with the given probability, otherwise 0.

    Args:
        probability: chance of a reward, compared against one uniform draw.
        rng: object exposing ``uniform(low, high)``.

    Returns:
        ``1`` if the uniform draw is strictly below ``probability``, else ``0``.
    """
    draw = rng.uniform(0, 1)
    return int(draw < probability)


In [84]:
# Kept only as a reminder of how to return the stdlib `random` module to a
# non-fixed seed after calling random.seed(<constant>): reloading the module
# re-runs its import-time initialisation.
# From now on all random numbers in this notebook come from an explicit
# generator instead:  rng = np.random.default_rng()
# NOTE(review): random.seed() with no argument is the documented way to reseed
# from system entropy -- consider using it instead of a reload.

import importlib

importlib.reload(random)

<module 'random' from 'c:\\Users\\Alfredo\\.conda\\envs\\yt-RL-tutorial\\lib\\random.py'>

In [131]:
# Snapshot of the logs as they are before this run (NaN / initial values on a fresh kernel).
# NOTE(review): create_fig_of_RL_experiment is defined in a later cell of this
# notebook; on a fresh kernel that cell has to be executed before this one.
fig_before = create_fig_of_RL_experiment(chosen_options, q_table, p_red, p_blue)

# Model parameters: learning rate and inverse temperature of the choice rule.
alpha = 0.65
beta = 3
rng = np.random.default_rng()  # unseeded: every run produces a different trajectory

for step in range(max_steps):
    # Choose an option, sample its reward and log both for this trial.
    chosen_option, option_prob_value, unchosen_option = choose_option(step, beta, rng)
    reward = get_reward(option_prob_value, rng)
    rewards[step] = reward
    chosen_options[step] = chosen_option

    # The last trial has no following column to write into.
    if step + 1 < max_steps:
        # Delta-rule update for the chosen option ...
        q_table[chosen_option, step+1] = q_table[chosen_option, step] + alpha*(reward - q_table[chosen_option, step])
        # ... while the unchosen option simply carries its estimate forward.
        q_table[unchosen_option, step+1] = q_table[unchosen_option, step]

fig_after = create_fig_of_RL_experiment(chosen_options, q_table, p_red, p_blue)
fig_after.show()
fig_before.show()

In [44]:
import plotly.graph_objects as go
def create_fig_of_RL_experiment(chosen_options, q_table, p_red, p_blue):
    """Build a plotly figure of the two-option RL experiment.

    The figure overlays the true reward probabilities, the model's value
    estimates, and a row of markers above the curves showing which option was
    chosen on each trial.

    Args:
        chosen_options: 1-D array with 0 (red), 1 (blue) or NaN (trial not
            simulated yet) per trial.
        q_table: array of shape (2, n_trials) with the value estimate of each
            option on every trial.
        p_red, p_blue: callables mapping trial numbers to reward probability in
            percent (divided by 100 for plotting).

    Returns:
        The assembled ``plotly.graph_objects.Figure``.
    """
    chosen_options = np.asarray(chosen_options)
    q_table = np.asarray(q_table)

    # Trial axis derived from the data instead of a hard-coded 300, so the
    # figure stays correct if the number of simulated trials changes.
    n_trials = q_table.shape[1]
    trials = np.arange(1, n_trials + 1)

    fig = go.Figure()

    # True (generating) reward probabilities, drawn semi-transparent.
    fig.add_trace(go.Scatter(x=trials, y=p_red(trials)/100.0, name="Option 1 Ideal", mode='lines', opacity=0.5, line={"color":"#C37364"}))
    fig.add_trace(go.Scatter(x=trials, y=p_blue(trials)/100.0, name="Option 2 Ideal", mode='lines', opacity=0.5, line={"color":"#136EAC"}))

    # Value estimates of the simple RL model.
    fig.add_trace(go.Scatter(x=trials, y=q_table[0, :], name="Option 1 Perceived", mode='lines', opacity=1, line={"color":"#C37364"}))
    fig.add_trace(go.Scatter(x=trials, y=q_table[1, :], name="Option 2 Perceived", mode='lines', opacity=1, line={"color":"#136EAC"}))

    # Choice markers drawn slightly above 1 so they do not overlap the curves.
    # NaN entries compare False against both 0 and 1, so unsimulated trials get no marker.
    chosen_options_x = np.arange(1, len(chosen_options) + 1)
    chose_red = chosen_options == 0
    chose_blue = chosen_options == 1

    # Trials on which red (option 0) was chosen.
    fig.add_trace(go.Scatter(
        x=chosen_options_x[chose_red],
        y=np.full(int(chose_red.sum()), 1.075),
        mode='markers',
        marker=dict(color='#C37364', size=6),
        name='Chosen option: Red'
    ))

    # Trials on which blue (option 1) was chosen.
    fig.add_trace(go.Scatter(
        x=chosen_options_x[chose_blue],
        y=np.full(int(chose_blue.sum()), 1.05),
        mode='markers',
        marker=dict(color='#136EAC', size=6),
        name='Chosen option: Blue'
    ))

    fig.update_layout(
        xaxis_title="Trial",
        yaxis_title="Reward probability / estimated value",
        yaxis=dict(range=[0, 1.15]),  # leave head-room for the marker rows
    )
    return fig

# Foraging model

In [31]:
# Number of trials simulated by the foraging model.
max_steps = 300

# Per-trial logs. NaN marks a trial that has not been simulated yet, so the
# plotting code (which tests == 0 / == 1) draws no marker for it.
chosen_options_foraging = np.full(max_steps, np.nan)
rewards_foraging = np.full(max_steps, np.nan)

# Mode on each trial: 0 = explore, 1 = exploit. Initialised to NaN rather than 0
# so that a figure drawn before the simulation does not label every trial "Explore".
explore_or_exploit = np.full(max_steps, np.nan)

# Probability of leaving the current mode, recorded on every trial.
termination_probabilities = np.zeros((max_steps))

# History of the single running value estimate; the first entry is its starting value.
q_table_foraging = np.zeros((max_steps))
q_table_foraging[0] = 0.4

def get_p_exploit(last_value, beta, threshold):
    """Logistic function of how far ``last_value`` lies above ``threshold``.

    Returns 0.5 when the value equals the threshold and approaches 1 as the
    value rises above it; ``beta`` sets the steepness.
    """
    exponent = -beta*(last_value-threshold)
    return 1/(1+np.exp(exponent))

def explore(options, last_value, threshold, beta, alpha, time, rng):
    """Run one exploration trial: sample a random option and update the running value.

    Args:
        options: candidate option indices handed to ``rng.choice`` (0 = red, otherwise blue).
        last_value: running value estimate before this trial.
        threshold: value level used by the logistic switching rule.
        beta: steepness of the switching rule.
        alpha: learning rate of the value update.
        time: trial passed to the probability traces ``p_red`` / ``p_blue``.
        rng: generator exposing ``choice`` and ``uniform``.

    Returns:
        Tuple ``(termination_probability, chosen_option, new_value, reward)``. The
        termination probability is computed from the value *before* the update.
    """
    chosen_option = rng.choice(options)
    trace = p_red if chosen_option == 0 else p_blue
    probability = trace(time)/100.0
    reward = get_reward(probability, rng)

    # Chance of leaving exploration: the logistic of the pre-update value.
    termination_probability = get_p_exploit(last_value, beta, threshold)

    # One estimate is shared by both options; it moves towards the observed reward.
    new_value = last_value + alpha * (reward - last_value)
    return termination_probability, chosen_option, new_value, reward

def exploit(last_option, last_value, threshold, beta, alpha, time, rng):
    """Run one exploitation trial: repeat the previous option and update the running value.

    Args:
        last_option: option index to repeat (0 = red, otherwise blue).
        last_value: running value estimate before this trial.
        threshold: value level used by the logistic switching rule.
        beta: steepness of the switching rule.
        alpha: learning rate of the value update.
        time: trial passed to the probability traces ``p_red`` / ``p_blue``.
        rng: generator exposing ``uniform``.

    Returns:
        Tuple ``(termination_probability, new_value, reward)``. The termination
        probability is the complement of ``get_p_exploit`` evaluated on the value
        *before* the update.
    """
    trace = p_red if last_option == 0 else p_blue
    probability = trace(time)/100.0
    reward = get_reward(probability, rng)

    # Chance of abandoning exploitation: one minus the logistic of the pre-update value.
    termination_probability = 1 - get_p_exploit(last_value, beta, threshold)

    new_value = last_value + alpha * (reward - last_value)
    return termination_probability, new_value, reward

In [41]:
# Foraging-model parameters.
beta_foraging = 3
alpha_foraging = 0.3
rng_foraging = np.random.default_rng()  # unseeded: every run produces a different trajectory
state = "explore"  # the agent always starts by exploring
threshold = 0.4
last_value = q_table_foraging[0]
options = [0,1] #red, blue

# Snapshot of the logs as they are before this run.
# NOTE(review): create_fig_of_RL_foraging_experiment is defined in a later cell of
# this notebook; on a fresh kernel that cell has to be executed before this one.
fig_before = create_fig_of_RL_foraging_experiment(chosen_options_foraging, q_table_foraging, p_red, p_blue, threshold, termination_probabilities, explore_or_exploit)

for step in range(max_steps):
    # Act according to the current mode; each mode names the mode it may switch to.
    if state == "explore":
        termination_probability, last_option, last_value, reward = explore(options, last_value, threshold, beta_foraging, alpha_foraging, step, rng_foraging)
        explore_or_exploit[step] = 0
        next_state = "exploit"
    else:
        termination_probability, last_value, reward = exploit(last_option, last_value, threshold, beta_foraging, alpha_foraging, step, rng_foraging)
        explore_or_exploit[step] = 1
        next_state = "explore"

    # Bookkeeping shared by both modes.
    chosen_options_foraging[step] = last_option
    rewards_foraging[step] = reward
    termination_probabilities[step] = termination_probability
    if step+1 < max_steps:
        q_table_foraging[step+1] = last_value

    # Leave the current mode with the probability reported for this trial.
    if rng_foraging.uniform(0,1) < termination_probability:
        state = next_state

fig_after = create_fig_of_RL_foraging_experiment(chosen_options_foraging, q_table_foraging, p_red, p_blue, threshold, termination_probabilities, explore_or_exploit)
fig_after.show()
fig_before.show()

In [38]:
import plotly.graph_objects as go
def create_fig_of_RL_foraging_experiment(chosen_options, q_table, p_red, p_blue, threshold, termination_probabilities, explore_or_exploit):
    """Build a plotly figure of the foraging-model experiment.

    The figure overlays the true reward probabilities, the model's running
    value estimate, the switching threshold, and two rows of markers above the
    curves: which option was chosen and whether the trial was explore or exploit.

    Args:
        chosen_options: 1-D array with 0 (red), 1 (blue) or NaN per trial.
        q_table: 1-D array with the running value estimate on every trial.
        p_red, p_blue: callables mapping trial numbers to reward probability in
            percent (divided by 100 for plotting).
        threshold: value level drawn as a horizontal line.
        termination_probabilities: accepted for interface compatibility; not
            plotted at present.
        explore_or_exploit: 1-D array with 0 (explore), 1 (exploit) or NaN per trial.

    Returns:
        The assembled ``plotly.graph_objects.Figure``.
    """
    chosen_options = np.asarray(chosen_options)
    explore_or_exploit = np.asarray(explore_or_exploit)
    q_table = np.asarray(q_table)

    # Trial axis derived from the data instead of a hard-coded 300, so the
    # figure stays correct if the number of simulated trials changes.
    n_trials = len(q_table)
    trials = np.arange(1, n_trials + 1)

    fig = go.Figure()

    # True (generating) reward probabilities, drawn semi-transparent.
    fig.add_trace(go.Scatter(x=trials, y=p_red(trials)/100.0, name="Option 1 Ideal", mode='lines', opacity=0.5, line={"color":"#C37364"}))
    fig.add_trace(go.Scatter(x=trials, y=p_blue(trials)/100.0, name="Option 2 Ideal", mode='lines', opacity=0.5, line={"color":"#136EAC"}))

    # Running value estimate of the foraging model.
    fig.add_trace(go.Scatter(x=trials, y=q_table[:], name="Chosen Option Value", mode='lines', opacity=1, line={"color":"#9D516F"}))

    # Switching threshold. It previously reused the "Chosen Option Value" label,
    # which produced two identical legend entries.
    fig.add_trace(go.Scatter(x=trials, y=np.full(n_trials, threshold), name="Threshold", mode='lines', opacity=1, line={"color":"#399563"}))

    # Marker rows drawn above 1 so they do not overlap the curves. NaN entries
    # compare False against 0 and 1, so unsimulated trials get no marker.
    choice_x = np.arange(1, len(chosen_options) + 1)
    mode_x = np.arange(1, len(explore_or_exploit) + 1)
    marker_rows = [
        (choice_x, chosen_options == 0, 1.075, '#C37364', 'Chosen option: Red'),
        (choice_x, chosen_options == 1, 1.050, '#136EAC', 'Chosen option: Blue'),
        (mode_x, explore_or_exploit == 0, 1.175, "#0DAE25", 'Explore'),
        (mode_x, explore_or_exploit == 1, 1.150, "#9E0CA3", 'Exploit'),
    ]
    for marker_x, mask, height, color, label in marker_rows:
        fig.add_trace(go.Scatter(
            x=marker_x[mask],
            y=np.full(int(mask.sum()), height),
            mode='markers',
            marker=dict(color=color, size=6),
            name=label
        ))

    fig.update_layout(
        xaxis_title="Trial",
        yaxis_title="Reward probability / estimated value",
        yaxis=dict(range=[0, 1.25]),  # leave head-room for the marker rows
    )
    return fig