<a href="https://colab.research.google.com/github/TheLemonPig/RL-SSM/blob/main/RLWM_SSM_differentiable.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install 'pymc>=5.9'
!pip install numpyro
!pip install git+https://github.com/AlexanderFengler/ssm-simulators@main!
!pip install git+https://github.com/AlexanderFengler/LANfactory

In [1]:
import pymc as pm
import numpy as np
import random, pickle
import matplotlib.pyplot as plt
import pandas as pd
import pytensor
import pytensor.tensor as pt
import arviz as az
import numpy as np
import matplotlib.pyplot as plt
import pickle
import ssms
from ssms.basic_simulators.simulator import simulator
import lanfactory

passing 1
wandb not available
wandb not available


In [None]:
# Seed the stdlib RNG and NumPy's global legacy RNG with the same value so
# simulations in later cells are reproducible.
random.seed(0)
np.random.seed(0)

In [None]:
# Simulate a reference dataset from the ssm-simulators LBA-angle model.
ssms_model = 'lba_angle_3_v1'
model_config = ssms.config.model_config[ssms_model]

# One parameter vector (6 values); presumably ordered as in
# model_config['params'] — confirm against the ssms config.
param_theta = np.array([0.5, 0.3, 0.2, 0.5, 0.2, 0.0])

# 2000 samples, 1 ms time step, trajectories truncated at 5 time units.
res = simulator(param_theta, model=ssms_model, n_samples=2000, delta_t=0.001, max_t=5)

## PyTensor

#### Functions

In [None]:
def rlwm_step(dC, dR, pA, pG, pP, dq_RL, dq_WM):
    """
    rlwm_step: compute one trial's RL and WM Q-value update.

    All arguments are per-trial tensors of shape
    (n_participants, n_choices, n_blocks, n_stimuli).

    dC: Choices (data); multiplies the delta-rule update, so only entries
        where dC is non-zero are updated (presumably a one-hot mask — confirm)
    dR: Rewards (data)
    pA: Alphas, RL learning rate (parameter)
    pG: Gammas, scale applied to the learning rate when the RL prediction
        error is not positive (parameter)
    pP: Phis, WM decay rate toward uniform 1/n_choices (parameter)
    dq_RL: RL Qs before this trial
    dq_WM: WM Qs before this trial

    Returns [updated RL Qs, updated WM Qs].
    """
    # 1 where the RL prediction error is positive (Q < reward), else 0.
    # The same sign is used to scale both the RL and the WM update.
    is_positive = pt.switch(pt.lt(dq_RL, dR), 1, 0)
    rate_scale = is_positive + (1.0 - is_positive) * pG

    wm_alpha = 1.0  # WM learns in one shot (learning rate fixed at 1).
    new_q_rl = dq_RL + rate_scale * pA * (dR - dq_RL) * dC
    new_q_wm = dq_WM + rate_scale * wm_alpha * (dR - dq_WM) * dC

    # Every WM entry then decays toward the uniform value 1/n_choices.
    n_choices = dR.shape[1]
    new_q_wm = new_q_wm + pP * (1 / n_choices - new_q_wm)
    return [new_q_rl, new_q_wm]

def rlwm_scan(dC, dR, pA, pG, pP, dq_RL, dq_WM):
    """
    rlwm_scan: compute the RLWM Q-values in effect on every trial.

    Sequences scanned over axis 0 (trials), each of shape
    (n_trials, n_participants, n_choices, n_blocks, n_stimuli):
    dC: Choices (data)
    dR: Rewards (data)
    pA: Alphas (parameter)
    pG: Gammas (parameter)
    pP: Phis (parameter)
    Initial state, 4-D (the per-trial shape):
    dq_RL: initial RL Qs
    dq_WM: initial WM Qs

    Returns (dQ_RL, dQ_WM), each with a leading trial axis of length n_trials.
    Entry t holds the Q-values *before* trial t's update: entry 0 is the
    initial state and entry t is the output of ``rlwm_step`` for trial t-1.
    The Q-values after the final trial are not returned.
    """
    ([q_rl_after, q_wm_after], _) = pytensor.scan(
        rlwm_step,
        sequences=[dC, dR, pA, pG, pP],
        non_sequences=[],
        outputs_info=[dq_RL, dq_WM],
    )
    # Shift by one trial: prepend the initial state and drop the last
    # post-update state. Concatenation also handles n_trials == 1. The previous
    # set_subtensor on [-(n_trials - 1):] selected the whole axis when
    # n_trials == 1 ([-0:]), which caused a shape mismatch.
    dQ_RL = pt.concatenate([pt.shape_padleft(dq_RL), q_rl_after[:-1]], axis=0)
    dQ_WM = pt.concatenate([pt.shape_padleft(dq_WM), q_wm_after[:-1]], axis=0)
    return dQ_RL, dQ_WM


def rlwmssm_softmax(Qs, pB):
    """
    rlwmssm_softmax: tempered softmax over the choice axis (axis 2) of Q-values.

    Qs: Q-Values (data); axis 2 is the choice axis, e.g.
        (n_trials, n_participants, n_choices, n_blocks, n_stimuli)
    pB: Betas, inverse temperature multiplying Qs (parameter); must
        broadcast against Qs

    Returns probabilities with the broadcast shape of Qs * pB that sum to 1
    along axis 2. Works for any input rank >= 3.
    """
    tempered_qs = pt.mul(Qs, pB)
    # Subtract the per-slice maximum before exponentiating, for numerical
    # stability. keepdims=True lets the reduction broadcast back over axis 2
    # without a manual reshape/repeat tied to a fixed rank.
    shifted = tempered_qs - pt.max(tempered_qs, axis=2, keepdims=True)
    numerator = pt.exp(shifted)
    Ps = numerator / pt.sum(numerator, axis=2, keepdims=True)
    return Ps

def rlwm_policy(dC, dq_RL, dq_WM, pB, pC, pE, pR, set_sizes):
    """
    rlwm_policy: mix RL and WM softmax policies, then add epsilon-uniform noise.

    dC: Choices (data); only its shape is used, with axis 2 as the choice axis
    dq_RL: RL Qs (data)
    dq_WM: WM Qs (data)
    pB: Betas, softmax inverse temperature (parameter)
    pC: Working Memory Capacities (parameter)
    pE: Epsilons, weight of the uniform-random policy (parameter)
    pR: Rhos, maximum WM weight (parameter)
    set_sizes: set sizes for each participant block (data)

    Returns choice probabilities: WM weight = pR * clip(pC / set_sizes, 0, 1),
    then (1 - pE) * mixture + pE / n_choices.
    """
    wm_weight = pR * pt.clip(pC / set_sizes, 0, 1)
    # The softmax is defined in this notebook as `rlwmssm_softmax`; no
    # `rlwm_softmax` exists in this section, so calling it raised NameError.
    Ps_RL = rlwmssm_softmax(dq_RL, pB)
    Ps_WM = rlwmssm_softmax(dq_WM, pB)
    mixture = wm_weight * Ps_WM + (1.0 - wm_weight) * Ps_RL
    n_choices = dC.shape[2]
    pol_final = (1.0 - pE) * mixture + pE * 1.0 / n_choices
    return pol_final

def rlwmssm_likelihood(dC, dq_RL, dq_WM, pB, pC, pE, pR, set_sizes):
    """
    rlwmssm_likelihood: per-trial log-likelihood of the observed choices,
    given precomputed Q-values.

    dC: Choices (data), (n_trials, n_participants, n_choices, n_blocks, n_stimuli)
    dq_RL: Precomputed RL Qs (data)
    dq_WM: Precomputed WM Qs (data)
    pB: Betas (parameter)
    pC: Working Memory Capacities (parameter)
    pE: Epsilons (parameter)
    pR: Rhos (parameter)
    set_sizes: set sizes for each participant block (data)

    Returns log p(choice), shape (n_trials, n_participants, n_blocks).
    Entries where dC sums to 0 over the choice and stimulus axes contribute
    log(1) = 0.
    """
    policy = rlwm_policy(dC, dq_RL, dq_WM, pB, pC, pE, pR, set_sizes)
    # Sum over choices (axis 2) and stimuli (axis 4) to pick out the
    # probability of the recorded choice.
    chosen_prob = (policy * dC).sum(axis=[2, 4])
    # Padding: no recorded choice -> dC sums to 0 -> add 1 so that log gives 0.
    # With exactly one marked choice this adds 0.
    missing = 1.0 - dC.sum(axis=[2, 4])
    return pt.log(chosen_prob + missing)


def rlwmssm_recovery(dq_RL, dq_WM, dC, dR, pA, pB, pC, pE, pG, pP, pR, set_sizes):
    """
    rlwmssm_recovery: end-to-end RLWM log-likelihood from initial Q-values,
    choices, rewards, and set sizes.

    dq_RL: initial RL Qs (data), per-trial shape
    dq_WM: initial WM Qs (data), per-trial shape
    dC: Choices (data)
    dR: Rewards (data)
    pA: Alphas (parameter)
    pB: Betas (parameter)
    pC: Working Memory Capacities (parameter)
    pE: Epsilons (parameter)
    pG: Gammas (parameter)
    pP: Phis (parameter)
    pR: Rhos (parameter)
    set_sizes: set sizes for each participant block (data)

    Runs ``rlwm_scan`` to get the Q-values seen on every trial, then returns
    ``rlwmssm_likelihood`` evaluated on them.
    """
    q_rl_per_trial, q_wm_per_trial = rlwm_scan(dC, dR, pA, pG, pP, dq_RL, dq_WM)
    return rlwmssm_likelihood(
        dC, q_rl_per_trial, q_wm_per_trial, pB, pC, pE, pR, set_sizes
    )

#### Compilers

In [None]:
def rlwm_step_compile():
    """
    Compile ``rlwm_step`` into a callable PyTensor function.

    The returned function takes (dC, dR, pA, pG, pP, dq_RL, dq_WM), all 4-D
    float64 arrays, and returns [updated RL Qs, updated WM Qs].
    """
    dC3, dR3 = pt.dtensor4("dC3"), pt.dtensor4("dR3")
    dq_RL3, dq_WM3 = pt.dtensor4("dq_RL3"), pt.dtensor4("dq_WM3")
    pA3, pG3, pP3 = (pt.dtensor4(label) for label in ("pA3", "pG3", "pP3"))

    inputs = [dC3, dR3, pA3, pG3, pP3, dq_RL3, dq_WM3]
    new_q_rl, new_q_wm = rlwm_step(*inputs)
    return pytensor.function(inputs=inputs, outputs=[new_q_rl, new_q_wm])


def rlwm_scan_compile():
    """
    Compile ``rlwm_scan`` into a callable PyTensor function.

    The returned function takes the trial sequences (dC, dR, pA, pG, pP) as
    5-D float64 arrays and the initial (dq_RL, dq_WM) as 4-D float64 arrays,
    and returns the [RL Qs, WM Qs] sequences produced by ``rlwm_scan``.
    """
    seq_labels = ("dC4", "dR4", "pA4", "pG4", "pP4")
    dC4, dR4, pA4, pG4, pP4 = (pt.dtensor5(label) for label in seq_labels)
    dq_RL3, dq_WM3 = pt.dtensor4("dq_RL3"), pt.dtensor4("dq_WM3")

    inputs = [dC4, dR4, pA4, pG4, pP4, dq_RL3, dq_WM3]
    q_rl_seq, q_wm_seq = rlwm_scan(*inputs)
    scan_func = pytensor.function(inputs=inputs, outputs=[q_rl_seq, q_wm_seq])
    return scan_func


def rlwm_softmax_compile():
    """
    Compile the tempered softmax into a callable PyTensor function.

    The returned function takes (Qs, B), both 5-D float64 arrays, and returns
    the softmax probabilities over axis 2.
    """
    Qs = pt.dtensor5('Qs')
    B = pt.dtensor5('B')

    # The softmax in this notebook is named `rlwmssm_softmax`; no
    # `rlwm_softmax` is defined in this section, so calling it raised NameError.
    Ps = rlwmssm_softmax(Qs, B)
    Ps_func = pytensor.function(inputs=[Qs, B], outputs=Ps)

    return Ps_func


def rlwm_likelihood_compile():
    """
    Compile the choice log-likelihood into a callable PyTensor function.

    The returned function takes (dC, dq_RL, dq_WM, pB, pC, pE, pR, set_sizes),
    all 5-D float64 arrays, and returns the per-trial log-likelihood.
    """
    dC4 = pt.dtensor5("dC4")
    dq_RL4 = pt.dtensor5("dq_RL4")
    dq_WM4 = pt.dtensor5("dq_WM4")
    pB4 = pt.dtensor5("pB4")
    pC4 = pt.dtensor5("pC4")
    pE4 = pt.dtensor5("pE4")
    pR4 = pt.dtensor5("pR4")
    set_sizes = pt.dtensor5("set_sizes")

    # The likelihood in this notebook is named `rlwmssm_likelihood`; no
    # `rlwm_likelihood` is defined in this section, so calling it raised NameError.
    likelihood = rlwmssm_likelihood(dC4, dq_RL4, dq_WM4, pB4, pC4, pE4, pR4, set_sizes)
    rlwm_likelihood_func = pytensor.function(inputs=[dC4, dq_RL4, dq_WM4, pB4, pC4, pE4, pR4, set_sizes], outputs=likelihood)

    return rlwm_likelihood_func

def rlwm_Ps_compile():
    """
    Compile ``rlwm_Ps`` into a callable PyTensor function.

    Inputs: dq_RL and dq_WM as 4-D float64 arrays; dC, dR, pA, pB, pG, pP as
    5-D float64 arrays. Outputs: [Ps_RL, Ps_WM] as returned by ``rlwm_Ps``.

    NOTE(review): ``rlwm_Ps`` is not defined anywhere in this notebook
    section, so calling this function raises NameError unless a later cell
    defines it. Presumably it should scan the Q-values and apply the softmax
    to each system — confirm and define it before use.
    """

    dq_RL = pt.dtensor4("dq_RL")
    dq_WM = pt.dtensor4("dq_WM")
    dC = pt.dtensor5("dC")
    dR = pt.dtensor5("dR")
    pA = pt.dtensor5("pA")
    pB = pt.dtensor5("pB")
    pG = pt.dtensor5("pG")
    pP = pt.dtensor5("pP")

    Ps_RL, Ps_WM = rlwm_Ps(dq_RL, dq_WM, dC, dR, pA, pB, pG, pP)
    rlwm_Ps_func = pytensor.function(inputs=[dq_RL, dq_WM, dC, dR, pA, pB, pG, pP], outputs=[Ps_RL, Ps_WM])

    return rlwm_Ps_func

def rlwm_recovery_compile():
    """
    Compile the end-to-end RLWM log-likelihood into a callable PyTensor function.

    The returned function takes (dq_RL, dq_WM) as 4-D float64 arrays and
    (dC, dR, pA, pB, pC, pE, pG, pP, pR, set_sizes) as 5-D float64 arrays, and
    returns the per-trial log-likelihood.
    """
    dq_RL = pt.dtensor4("dq_RL")
    dq_WM = pt.dtensor4("dq_WM")
    dC = pt.dtensor5("dC")
    dR = pt.dtensor5("dR")
    pA = pt.dtensor5("pA")
    pB = pt.dtensor5("pB")
    pC = pt.dtensor5("pC")
    pE = pt.dtensor5("pE")
    pG = pt.dtensor5("pG")
    pP = pt.dtensor5("pP")
    pR = pt.dtensor5("pR")
    set_sizes = pt.dtensor5("set_sizes")

    # The end-to-end likelihood in this notebook is named `rlwmssm_recovery`;
    # no `rlwm_recovery` is defined in this section, so calling it raised NameError.
    likelihood = rlwmssm_recovery(dq_RL, dq_WM, dC, dR, pA, pB, pC, pE, pG, pP, pR, set_sizes)
    rlwm_recovery_func = pytensor.function(inputs=[dq_RL, dq_WM, dC, dR, pA, pB, pC, pE, pG, pP, pR, set_sizes], outputs=likelihood)

    return rlwm_recovery_func