In [4]:
import numpy as np
import pandas as pd
from numba import njit
from scipy.stats import norm, halfnorm, uniform
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
from numba import njit

# Get rid of annoying tf warning
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import bayesflow as beef
import tensorflow as tf
from tensorflow.keras.layers import LSTM, Bidirectional
from tensorflow.keras.models import Sequential
from keras.utils import to_categorical
from sklearn.metrics import r2_score

import sys
# sys.path.append("../src/")
# from priors import sample_mrw_eta, sample_mixture_random_walk
# from likelihood import sample_softmax_rl
# from context import generate_context
# from configurator import configure_input
# from helpers import softmax

In [5]:
# Suppress scientific notation for floats
np.set_printoptions(suppress=True)
# Configure rng with a fixed seed so a fresh-kernel run reproduces the same draws
SEED = 42
RNG = np.random.default_rng(SEED)
# Later cells draw from NumPy's legacy global generator (np.random.randint / np.random.choice),
# which is independent of RNG, so it is seeded here as well.
# NOTE(review): np.random calls made inside @njit-compiled functions appear to use Numba's own
# generator state, which this call does not seed -- confirm against the Numba docs if the
# simulator output must be reproducible too.
np.random.seed(SEED)

In [6]:
# physical_devices = tf.config.list_physical_devices('GPU')
# tf.config.experimental.set_memory_growth(physical_devices[0], enable=True)
# print(tf.config.list_physical_devices('GPU'))

In [7]:
@njit
def softmax(x, tau):
    """
    Apply a numerically stable softmax to an array of values scaled by an inverse temperature.

    Computes ``exp(x * tau) / sum(exp(x * tau))``. The scaled values are shifted by their
    maximum before exponentiation, which leaves the result mathematically unchanged but keeps
    ``np.exp`` from overflowing when ``x * tau`` is large.

    Parameters
    ----------
    x : np.ndarray
        A 1D numpy array containing the input values over which the softmax function is to be applied.
    tau : float
        The inverse temperature. The inputs are multiplied by `tau`, so larger values give a
        sharper distribution and ``tau = 0`` gives a uniform one.

    Returns
    -------
    np.ndarray
        A numpy array of the same shape as `x` whose entries sum to 1 (an empty array is
        returned for empty input).

    Note
    ----
    This function is compiled with Numba's @njit decorator for faster execution.
    """
    scaled = x * tau
    # Nothing to normalise; also avoids taking the maximum of an empty array
    if scaled.size == 0:
        return np.exp(scaled)
    # Largest exponent becomes exp(0) = 1, so the exponentials cannot overflow
    e_x = np.exp(scaled - scaled.max())
    return e_x / e_x.sum()

@njit
def select_action(x, p):
    """
    Selects a random action based on a given probability distribution.

    A uniform number is drawn with ``np.random.random()`` and located in the cumulative sum
    of `p` (inverse-CDF sampling).

    Parameters
    ----------
    x : np.ndarray
        A 1D numpy array of actions or values to select from.
    p : np.ndarray
        A 1D numpy array of probabilities associated with each action in `x`. The sum of all probabilities should
        be 1.

    Returns
    -------
    Any
        A randomly selected action from `x`, chosen according to the probabilities specified in `p`.

    Note
    ----
    This function is compiled with Numba's @njit decorator for faster execution.
    """
    cdf = np.cumsum(p)
    idx = np.searchsorted(cdf, np.random.random(), side="right")
    # Floating-point rounding can leave cdf[-1] slightly below 1; a draw at or above it makes
    # searchsorted return len(p), one past the last valid index. Numba does not bounds-check
    # array indexing by default, so clamp to the last action instead of reading out of bounds.
    return x[min(idx, len(x) - 1)]

## Prior

In [None]:
# def sample_theta_0():
#     alpha = RNG.uniform(low=0, high=1)
#     tau = RNG.normal(loc=1, scale=30)
#     tau = np.log(1 + np.exp(tau))
#     return np.array([alpha, tau])

## Likelihood

$\alpha \rightarrow$ Learning rate [0, 1]

$\tau \rightarrow$ Inverse temperature $[0, \infty)$

$\phi \rightarrow$ Memory decay [0, 1]

$w \rightarrow$ Memory contribution $= p \cdot \min(1, \frac{C}{n_S})$

$p \rightarrow$ Initial memory weighting 

$C \rightarrow$ Memory capacity

$n_S \rightarrow$ Set size in current block

$\gamma \rightarrow$ Perseveration

#### Context

Column 1: Stimulus [0, 5]

Column 2: Correct response [0, 2]

Column 3: Block id

Column 4: Set size

In [53]:
# Example context for a single block: 5 stimuli, each presented 15 times (75 trials)
stim = np.arange(5).repeat(15).reshape(-1, 1)
# One correct response in {0, 1, 2} per stimulus, expanded to line up with `stim`
correct = np.random.randint(0, 3, 5)
correct = correct.repeat(15).reshape(-1, 1)
# Every trial belongs to block 0 and has set size 5
block_id = np.zeros((75, 1), dtype=int)
set_size = np.full((75, 1), 5)
# Columns: stimulus, correct response, block id, set size
context = np.hstack((stim, correct, block_id, set_size))
# Shuffle the trial order with a random permutation of the row indices
idx = np.random.choice(np.arange(75), 75, replace=False)
context = context[idx]

In [54]:
# Example parameter vector, listed in the order in which sample_rlwm unpacks it
theta = np.array([
    0.1,  # alpha: learning rate
    50,   # tau: inverse temperature
    0.9,  # phi: memory decay
    0.5,  # p: initial memory weighting
    1,    # c: memory capacity
    0.5,  # y: scales updates after negative prediction errors (presumably gamma, perseveration)
])

In [77]:
def sample_rlwm(theta, context):
    """
    Simulate responses and feedback from a reinforcement-learning / working-memory (RLWM) mixture.

    On every trial the choice policy is a weighted mixture of two softmax policies, one over
    the learned Q-values and one over the working-memory values of the current stimulus.
    Both value tables are re-initialised to 1/3 whenever the block id changes.

    Parameters
    ----------
    theta : array-like of length 6
        Model parameters in the order ``(alpha, tau, phi, p, c, y)``: learning rate, softmax
        inverse temperature, working-memory decay, initial working-memory weighting,
        working-memory capacity, and the factor `y` applied to updates that follow a
        negative prediction error.
    context : np.ndarray of shape (num_steps, 4)
        One row per trial with columns: stimulus index, correct response, block id, set size.
        The stimulus index selects a row of the per-block value tables, so it has to lie in
        ``[0, set_size)``.

    Returns
    -------
    np.ndarray of shape (num_steps, 2)
        Column 0 holds the simulated response (0, 1 or 2) and column 1 the feedback
        (1 if the response equals the correct response, otherwise 0).
    """
    alpha, tau, phi, p, c, y = theta
    num_actions = 3
    # Uninformed starting value, also the target that working memory decays towards
    init_value = 1 / num_actions
    # Loop-invariant set of available responses
    actions = np.arange(num_actions)

    num_steps = context.shape[0]
    sim_data = np.zeros((num_steps, 2))
    # None (rather than a numeric sentinel such as -1) can never collide with a real block id,
    # so the value tables are always initialised on the first trial
    current_block = None
    for t in range(num_steps):
        # reset subjective values at the start of every block
        if current_block is None or context[t, 2] != current_block:
            current_block = context[t, 2]
            set_size = int(context[t, 3])
            q_values = np.full((set_size, num_actions), init_value)
            w_values = np.full((set_size, num_actions), init_value)
            # working-memory weight shrinks once the set size exceeds the capacity c
            w = p * np.minimum(1, c/set_size)

        current_stim = int(context[t, 0])

        # choice selection: mixture of the working-memory and RL policies
        pi_rl = softmax(q_values[current_stim], tau)
        pi_wm = softmax(w_values[current_stim], tau)
        pi = w*pi_wm + (1 - w)*pi_rl
        sim_data[t, 0] = select_action(actions, pi)
        current_resp = int(sim_data[t, 0])

        # feedback: 1 for a correct response, 0 otherwise
        sim_data[t, 1] = 1 if sim_data[t, 0] == context[t, 1] else 0

        # update values
        pe = sim_data[t, 1] - q_values[current_stim, current_resp]
        if pe < 0:
            # negative prediction error: learning rate is scaled by y
            q_values[current_stim, current_resp] += (y*alpha) * pe
            # NOTE(review): for Q-values in [0, 1] the feedback is 0 whenever pe < 0, so this
            # stores 0 regardless of y -- confirm whether a y-scaled update such as
            # w += y * (feedback - w) was intended.
            w_values[current_stim, current_resp] = y*sim_data[t, 1]
        else:
            q_values[current_stim, current_resp] += alpha * pe
            # working memory stores the outcome directly
            w_values[current_stim, current_resp] = sim_data[t, 1]

        # memory decay: every working-memory entry relaxes towards the initial value
        w_values += phi * (init_value - w_values)

    return sim_data

In [78]:
# Simulate one data set for the example parameters and context; sample_rlwm returns one row
# per trial, with the chosen response in column 0 and the 0/1 feedback in column 1.
x = sample_rlwm(theta, context)