<a href="https://colab.research.google.com/github/TheLemonPig/RL-SSM/blob/main/PyTensor_for_RL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##PyTensor for RL

In [None]:
import pytensor.tensor as pt
from pytensor import function, scan
import numpy as np

Reinforcement learning is inherently sequential: each trial's update depends on the result of the previous trial, so we need something that behaves like a `for` loop.

In PyTensor, this can be achieved with the `scan` function, which supports a restricted form of looping.

###2a. Using scan

For more, see: https://pytensor.readthedocs.io/en/latest/library/scan.html
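To make the calling convention concrete, here is a minimal pure-Python sketch of the recurrence that `scan` carries out. `python_scan` and `running_total` are hypothetical names used only for illustration; this is not PyTensor's implementation, and it leaves out everything `scan` does symbolically.

```python
import numpy as np

def python_scan(fn, sequences, outputs_info, non_sequences):
    """Hypothetical stand-in for scan: loop over the sequences, carrying a state."""
    state = outputs_info
    outputs = []
    for elements in zip(*sequences):
        # Argument order: sequence elements, then the previous output, then the constants
        state = fn(*elements, state, *non_sequences)
        outputs.append(state)
    return np.array(outputs)

def running_total(x, total, scale):
    # x comes from the sequence, total is the previous output, scale is a non-sequence
    return total + scale * x

totals = python_scan(
    running_total,
    sequences=[[1.0, 2.0, 3.0]],
    outputs_info=0.0,
    non_sequences=[10.0],
)
print(totals)  # [10. 30. 60.]
```

The step function receives its arguments in the order sequences, previous output, non-sequences, and the initial value itself is not part of the returned outputs.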

For simplicity, we will start with a numerical example, before moving on to a matrix example.

####2ai. Using scan with numerical inputs/outputs

Suppose we want to calculate the first $n$ terms of the series using the recursion formula: $s_{n+1} = \frac{s_{n}}{n+1} + m$ for $s_0=1$ and any $m$
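
As a quick check of the recursion, assuming $m = 1$ (an illustrative value), the first three steps work out to:

$$
\begin{aligned}
s_1 &= \frac{s_0}{1} + 1 = 2 \\
s_2 &= \frac{s_1}{2} + 1 = 2 \\
s_3 &= \frac{s_2}{3} + 1 = \frac{5}{3} \approx 1.667
\end{aligned}
$$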

A Python function for this could look like:

In [None]:
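def formula(sn,n,m):
    # one step of the recursion: s_{n+1} = s_n / (n + 1) + m
    return sn/(n+1) + m

def get_term(n,m):
    sn = 1  # s_0
    series = np.zeros(n)
    for i in range(n):
        sn = formula(sn,i,m)
        series[i] = sn  # series[i] holds s_{i+1}; s_0 itself is not stored
    return series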

This is how we would write it in PyTensor:

In [None]:
# We have to slightly adapt our function because the order of its arguments matters to scan
# scan passes arguments to `fn` in this order: sequences, outputs_info, non_sequences
def formula(n,sn,m):
    return sn/(n+1) + m

def get_term_pytensor(n,m):
    # Be explicit about dtypes: the value returned by `formula` must have the same dtype
    # as the initial state passed to `outputs_info`, and PyTensor may otherwise infer a different dtype than you expect
    sn = pt.constant(1.0, dtype='float64')
    ns = pt.arange(n, dtype='float64')
    # scan returns (outputs, updates); only the outputs are needed here
    term, _ = scan(fn=formula, sequences=ns, non_sequences=m, outputs_info=sn)
    return term

In [None]:
n_compile = pt.scalar("n", dtype='int64')  # note n must be an integer (you can't have the 3.5th term of a series)
m_compile = pt.scalar("m", dtype='float64')

output = get_term_pytensor(n_compile, m_compile)
# Note the LACK of list brackets around the outputs. You need brackets for multiple outputs, but with a single output they make the compiled function wrap its result in a list
my_func = function(inputs=[n_compile,m_compile],outputs=output)

It is always important to compare the outputs of your PyTensor function with the original Python function.
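
One way to make such a comparison systematic is `np.testing.assert_allclose`, which raises an error if two arrays differ beyond a tolerance. The sketch below is self-contained, so it compares two plain-Python implementations of the recursion; `series_loop` and `series_accumulate` are hypothetical names, and the second merely stands in for a compiled PyTensor function.

```python
import numpy as np
from itertools import accumulate

def series_loop(n, m):
    # Reference implementation: explicit for loop
    sn, out = 1.0, np.zeros(n)
    for i in range(n):
        sn = sn / (i + 1) + m
        out[i] = sn
    return out

def series_accumulate(n, m):
    # Second implementation: accumulate carries the state from step to step
    terms = accumulate(range(n), lambda sn, i: sn / (i + 1) + m, initial=1.0)
    return np.array(list(terms)[1:])  # drop the initial value s_0

reference = series_loop(10, 1.0)
candidate = series_accumulate(10, 1.0)

# Raises AssertionError if any element differs beyond the tolerance
np.testing.assert_allclose(candidate, reference, rtol=1e-10)
```

A tolerance-based check is preferable to `==` because two correct implementations can differ in the last few floating-point digits.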

In [None]:
get_term(10,1)

array([2.        , 2.        , 1.66666667, 1.41666667, 1.28333333,
       1.21388889, 1.1734127 , 1.14667659, 1.12740851, 1.11274085])

In [None]:
my_func(10,1)

array([2.        , 2.        , 1.66666667, 1.41666667, 1.28333333,
       1.21388889, 1.1734127 , 1.14667659, 1.12740851, 1.11274085])

###2b. Using scan for RL

To avoid indexing PyTensor tensors with other tensors, we one-hot encode the data so that no indexing is needed. This saves us from scanning over trials just to extract the relevant Q-values for the likelihood. (If you find another way, much of the one-hot code from here on becomes unnecessary.)

**Note:** The first dimension of each tensor you pass to `sequences` should be the dimension that is missing from the tensor you pass to `outputs_info`. `scan` iterates over the first dimension of each sequence, passing one subtensor per step, and stacks the step outputs along a new first dimension. `outputs_info` supplies the initial value: it is used to calculate the first output, the first output is used to calculate the second, and so on for each iteration of the scan.
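
The shape bookkeeping in this note can be illustrated with a plain NumPy loop. This is only an analogy under simple assumptions (a running sum as the step, made-up sizes), not what `scan` executes internally.

```python
import numpy as np

n_trials, n_choices = 4, 2
sequence = np.arange(n_trials * n_choices, dtype=float).reshape(n_trials, n_choices)
initial_state = np.zeros(n_choices)   # lacks the trials dimension

state = initial_state
outputs = []
for row in sequence:                  # iterate over the first dimension
    state = state + row               # each step maps (choices,) -> (choices,)
    outputs.append(state)
outputs = np.stack(outputs)           # the trials dimension is added back

print(sequence.shape, initial_state.shape, outputs.shape)
# (4, 2) (2,) (4, 2)
```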

####2bi. Using scan for single participant RL

For simplicity, we will first go over what this looks like for an individual participant, and then extend the code to multiple participants.

All of these functions include a dimension for choice. This is the one-hot dimension: taking choice 0 is represented as [1, 0], and taking choice 1 is represented as [0, 1]. Placing a 1 at the index of the chosen option requires one entry per option, so this extra dimension has a size equal to the number of choices.
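
A small NumPy sketch of this encoding, using made-up choices and Q-values, shows that masking with the one-hot array picks out the same values as integer indexing would:

```python
import numpy as np

choices = np.array([0, 1, 1, 0])   # integer choice on each trial (illustrative data)
n_choices = 2

# Row i of the identity matrix is the one-hot vector for choice i
one_hot = np.eye(n_choices, dtype=np.int32)[choices]
print(one_hot)
# [[1 0]
#  [0 1]
#  [0 1]
#  [1 0]]

q_values = np.array([[0.5, 0.6], [0.5, 0.7], [0.4, 0.7], [0.4, 0.8]])

# Multiply by the mask and sum over the choice dimension ...
picked_by_mask = (q_values * one_hot).sum(axis=1)
# ... which matches selecting one entry per row by index
picked_by_index = q_values[np.arange(len(choices)), choices]
assert np.allclose(picked_by_mask, picked_by_index)
```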

**Simulating RL in Python**

In [None]:
from functools import partial
from typing import Callable, List

In [None]:
def softmax(qs, beta):
  return np.exp(qs*beta) / np.exp(qs*beta).sum()

class Distribution:

  def __init__(self, func: Callable, kwargs):
    self.func: Callable = partial(func, **kwargs)

  def __call__(self):
    return self.func.__call__()

class SimpleRL:

  def __init__(self, n_trials: int, distributions: List[Distribution]):
    self.n_choices: int = len(distributions)
    self.n_trials: int = n_trials
    self.distributions: List[Distribution] = distributions
    self.qs: np.ndarray = np.ones((self.n_choices,)) * 0.5
    self.q_trace: np.ndarray = np.ones((self.n_trials,self.n_choices))
    self.rewards: np.ndarray = np.zeros((self.n_trials))
    self.choices: np.ndarray = np.zeros((self.n_trials),dtype=np.int32)

  def simulate(self, alpha, beta):
    for i in range(self.n_trials):
      # softmax decision function
      ps = softmax(self.qs,beta)
      # choice sampled according to the softmax probabilities of the Q-values
      choice = np.random.choice(a=self.n_choices,size=1,p=ps)[0]
      # choice is recorded to trace
      self.choices[i] = choice
      # reward sampled from the distribution of the chosen option
      dist = self.distributions[choice]  # supply a list of distributions to choose from
      reward = dist()  # sample from distribution by calling it
      self.rewards[i] = reward
      # Q-value of the chosen option is updated
      self.qs[choice] = self.qs[choice] + alpha * (reward - self.qs[choice])
      # updated Q-values are recorded to trace
      self.q_trace[i] = self.qs
    # the main data to be returned (basis for fits) are the rewards and choices per trial
    # the Q-values trace is returned so that it can be checked against PyTensor
    return self.rewards, self.choices, self.q_trace

In [None]:
seed = 0
np.random.seed(seed)
mean_rewards = [-1.0,1.0]
dists = [Distribution(np.random.normal,{"loc":mn, "scale":1.0}) for mn in mean_rewards]
n_trials = 100

alpha, beta = 0.1, 1.0

simple_rl = SimpleRL(n_trials, dists)
Rs, Cs, Qs = simple_rl.simulate(alpha, beta)

**Recovering Qs in PyTensor**

In [None]:
def rl_step(C, R, Q, A):
    """
    rl_step: function for a single RL step

    C: Choice (one-hot) -- vector (choices)
    R: Reward (repeated across choices) -- vector (choices)
    Q: Q-Values at the previous time step -- vector (choices)
    A: Alpha parameter -- vector (choices)
    """
    return Q + A * (R - Q) * C
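
To see what a single step does, here is the same update applied to concrete NumPy arrays. The values are made up, and `rl_step_numpy` is a hypothetical NumPy twin of the step function:

```python
import numpy as np

def rl_step_numpy(C, R, Q, A):
    # The one-hot C zeroes the update for every option that was not chosen
    return Q + A * (R - Q) * C

Q = np.array([0.5, 0.5])   # Q-values before the trial
C = np.array([0, 1])       # option 1 was chosen
R = np.array([1.2, 1.2])   # the reward, repeated across choices
A = 0.1                    # learning rate

new_Q = rl_step_numpy(C, R, Q, A)
print(new_Q)  # approximately [0.5, 0.57]
```

Only the chosen option moves toward the reward, by a fraction `A` of the prediction error; the unchosen option keeps its previous value.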

In [None]:
def rl_scan(Cs,Rs,A):
    """
    rl_scan: scan function over RL steps

    Cs: Choices -- matrix (trials, choices)
    Rs: Rewards -- matrix (trials, choices)
    A: Alpha parameter -- vector (choices)

    Qs: Q-values -- matrix (trials, choices)
    """
    Q0 = pt.ones(Cs.shape[1]) * 0.5  # initial Q-values: vector (choices)
    # scan stacks the step outputs, so the result is a matrix (trials, choices)
    Qs, _ = scan(fn=rl_step, sequences=[Cs,Rs], non_sequences=[A], outputs_info=Q0)
    return Qs

In [None]:
# Defining the inputs of the to-be-compiled function
Cs = pt.imatrix("Cs")
Rs = pt.dmatrix("Rs")
A = pt.dvector("A")

# Compiling function
output = rl_scan(Cs,Rs,A)
rl_func = function(inputs=[Cs,Rs,A], outputs=output)

In [None]:
seed = 0
np.random.seed(seed)
mean_rewards = [-1.0,1.0]
dists = [Distribution(np.random.normal,{"loc":mn, "scale":1.0}) for mn in mean_rewards]
n_trials = 100

alpha, beta = 0.1, 1.0

simple_rl = SimpleRL(n_trials, dists)
Python_Rs, Python_Cs, Python_Qs = simple_rl.simulate(alpha, beta)
n_choices = len(mean_rewards)

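# Testing compiled function
# Repeat each reward across the choice dimension; the one-hot mask in rl_step selects the chosen option
Rs = np.repeat(Python_Rs.reshape(-1,1), repeats = n_choices, axis=1)
# One-hot encode the choices: Cs[t, n] is 1 if option n was chosen on trial t
Cs = np.zeros((Python_Cs.shape[0], n_choices))
for n in range(n_choices):
  Cs[:, n] = (Python_Cs == n)
Cs = np.array(Cs, dtype=np.int32)
A = np.ones(n_choices) * alpha
PyTensor_Qs = rl_func(Cs,Rs,A)
PyTensor_Qs[:10]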

array([[0.5       , 0.62415917],
       [0.5       , 0.81703463],
       [0.5       , 0.75105876],
       [0.5       , 0.97294533],
       [0.5       , 1.10226265],
       [0.5       , 1.04144873],
       [0.5       , 1.08573507],
       [0.5       , 1.13507561],
       [0.5       , 1.10340979],
       [0.5       , 1.23408928]])

In [None]:
Python_Qs[:10]

array([[0.5       , 0.62415917],
       [0.5       , 0.81703463],
       [0.5       , 0.75105876],
       [0.5       , 0.97294533],
       [0.5       , 1.10226265],
       [0.5       , 1.04144873],
       [0.5       , 1.08573507],
       [0.5       , 1.13507561],
       [0.5       , 1.10340979],
       [0.5       , 1.23408928]])

####2bii. Using scan over multiple participants

**Simulating multi-participant RL in Python**

In [None]:
class MultiRL:

  def __init__(self, n_trials: int, n_participants: int, distributions: List[Distribution]):
    self.participants: List[SimpleRL] = [
        SimpleRL(n_trials, distributions) for _ in range(n_participants)
    ]
    self.n_trials = n_trials
    self.n_participants = n_participants
    self.n_choices = len(distributions)
    self.alphas = np.zeros(n_participants)
    self.betas = np.zeros(n_participants)

  def simulate(self, alpha_a, alpha_b, beta_a, beta_b):
    # trials come first in every array, because scan iterates over the first dimension
    group_rewards = np.zeros((self.n_trials,self.n_participants))
    group_choices = np.zeros((self.n_trials,self.n_participants))
    group_Qs = np.zeros((self.n_trials,self.n_participants,self.n_choices))
    for idx, participant_model in enumerate(self.participants):
        # sample participant parameters
        alpha = np.random.beta(alpha_a, alpha_b)
        beta = np.random.beta(beta_a, beta_b)
        self.alphas[idx] = alpha
        self.betas[idx] = beta
        # run RL for participant
        participant_rewards, participant_choices, participant_Qs = participant_model.simulate(alpha, beta)
        group_rewards[:, idx] = participant_rewards
        group_choices[:, idx] = participant_choices
        group_Qs[:, idx] = participant_Qs
    return group_rewards, group_choices, group_Qs

  def get_params(self):
    return self.alphas, self.betas


**Recovering multi-participant Qs in PyTensor**

In [None]:
# Notice that the update rule is written in the same way as for a single participant

def rl_multi_step(C, R, A, Q):  # NOTE: A is now passed as a sequence, so it comes before Q. This will give us more flexibility in the future
    """
    rl_multi_step: function for a single RL step across participants

    C: Choice (one-hot) -- matrix (participants, choices)
    R: Reward -- matrix (participants, choices)
    A: Alpha parameter -- matrix (participants, choices)
    Q: Q-Values at the previous time step -- matrix (participants, choices)
    """
    return Q + A * (R - Q) * C
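
Because the update is elementwise, the same expression handles every participant at once. The NumPy sketch below uses made-up values for three participants; `rl_multi_step_numpy` is a hypothetical NumPy twin of the step function.

```python
import numpy as np

def rl_multi_step_numpy(C, R, A, Q):
    # All arrays have shape (participants, choices); the update is elementwise
    return Q + A * (R - Q) * C

n_participants, n_choices = 3, 2
Q = np.full((n_participants, n_choices), 0.5)
C = np.array([[1, 0], [0, 1], [0, 1]])               # one-hot choice per participant
R = np.array([[1.0, 1.0], [0.0, 0.0], [2.0, 2.0]])   # reward repeated across choices
A = np.array([[0.1, 0.1], [0.5, 0.5], [1.0, 1.0]])   # one learning rate per participant

new_Q = rl_multi_step_numpy(C, R, A, Q)
print(new_Q)
# approximately:
# [[0.55 0.5 ]
#  [0.5  0.25]
#  [0.5  2.  ]]
```

Each row is updated with its own learning rate, and only the chosen entry in each row changes.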

In [None]:
# Notice that this function is written in ALMOST the same way as for a single participant
# A tensor3 is PyTensor's type for 3-dimensional arrays

def rl_multi_scan(CMs,RMs,AM):
    """
    rl_multi_scan: scan function over RL steps

    CMs: Choices -- tensor3 (trials, participants, choices)
    RMs: Rewards -- tensor3 (trials, participants, choices)
    AM: Alpha parameter -- tensor3 (trials, participants, choices), passed as a sequence

    QTs: Q-values -- tensor3 (trials, participants, choices)
    """
    # initial Q-values: matrix (participants, choices); scan stacks the step outputs into a tensor3 (trials, participants, choices)
    QMs = pt.ones((CMs.shape[1], CMs.shape[2])) * 0.5
    # note that the single initial state is passed to outputs_info WITHOUT list brackets
    QTs, _ = scan(fn=rl_multi_step, sequences=[CMs,RMs,AM], non_sequences=[], outputs_info=QMs)
    return QTs

In [None]:
# Defining the inputs of the to-be-compiled function
CMs = pt.itensor3("CMs")
RMs = pt.dtensor3("RMs")
AM = pt.dtensor3("AM")

# Compiling function
output = rl_multi_scan(CMs,RMs,AM)
rl_func = function(inputs=[CMs,RMs,AM], outputs=output)

In [None]:
seed = 0
np.random.seed(seed)
mean_rewards = [-1.0,1.0]
dists = [Distribution(np.random.normal,{"loc":mn, "scale":1.0}) for mn in mean_rewards]
n_trials = 100
n_participants = 3
n_choices = len(mean_rewards)

# For MultiRL, each participant's alpha and beta are sampled from Beta distributions with these shape parameters
alpha_a, alpha_b, beta_a, beta_b = 2, 2, 2, 2

multi_rl = MultiRL(n_trials, n_participants, dists)
Python_Rs, Python_Cs, Python_Qs = multi_rl.simulate(alpha_a, alpha_b, beta_a, beta_b)

alphas, betas = multi_rl.get_params()

# Testing compiled function
Rs = np.repeat(Python_Rs.reshape(n_trials,n_participants,1), n_choices, axis=2)
Cs = np.zeros((n_trials, n_participants, n_choices))
for n in range(n_choices):
    Cs[:, :, n] = (Python_Cs == n)
Cs = np.array(Cs, dtype=np.int32)
A = np.tile(alphas.reshape(1,-1,1),[n_trials,1,n_choices])
PyTensor_Qs = rl_func(Cs,Rs,A)
PyTensor_Qs[:5]

array([[[ 0.5       ,  1.73022911],
        [ 0.21225973,  0.5       ],
        [ 0.07119759,  0.5       ]],

       [[ 0.5       ,  0.86919225],
        [-0.03987525,  0.5       ],
        [-0.52974563,  0.5       ]],

       [[ 0.5       ,  1.29766769],
        [-0.03987525,  0.75788475],
        [-0.52974563,  0.61445943]],

       [[ 0.5       ,  1.49373382],
        [-0.26587443,  0.75788475],
        [-0.52974563, -0.30926341]],

       [[ 0.5       ,  1.02332718],
        [-0.4027839 ,  0.75788475],
        [-2.22131283, -0.30926341]]])

In [None]:
Python_Qs[:5]

array([[[ 0.5       ,  1.73022911],
        [ 0.21225973,  0.5       ],
        [ 0.07119759,  0.5       ]],

       [[ 0.5       ,  0.86919225],
        [-0.03987525,  0.5       ],
        [-0.52974563,  0.5       ]],

       [[ 0.5       ,  1.29766769],
        [-0.03987525,  0.75788475],
        [-0.52974563,  0.61445943]],

       [[ 0.5       ,  1.49373382],
        [-0.26587443,  0.75788475],
        [-0.52974563, -0.30926341]],

       [[ 0.5       ,  1.02332718],
        [-0.4027839 ,  0.75788475],
        [-2.22131283, -0.30926341]]])

###2c. Getting RL Likelihoods in PyTensor

Now that we can generate synthetic RL data and reconstruct the internal state of an RL agent from its choices and rewards, we can calculate the log-likelihood of the observed choices given the reconstructed Q-values.

This is necessary because for human subjects we do not know what their "parameters" are. We therefore need a metric to assess how well different parameter values replicate their behavior. The ultimate goal is to locate the parameters which best replicate the behavior of human subjects, and we use the log-likelihood as a relative measure of how well we are doing. A likelihood has no absolute meaning: it is only meaningful relative to the likelihood of other parameter values on the same data. Because choice probabilities are at most 1, these log-likelihoods are never positive, and less negative values indicate a better fit.
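
As a toy illustration of using the log-likelihood as a relative measure, the NumPy sketch below scores the same made-up choices under two candidate values of the inverse temperature. `total_ll`, the fixed Q-values, and the choices are all assumptions made for this example.

```python
import numpy as np

def total_ll(beta, qs, choices):
    # Log of the softmax probabilities for fixed Q-values
    logits = beta * qs
    log_ps = logits - np.log(np.exp(logits).sum())
    # Sum the log-probability of each observed choice
    return log_ps[choices].sum()

qs = np.array([0.0, 1.0])                      # option 1 has the higher Q-value
choices = np.array([1, 1, 1, 0, 1, 1, 1, 1])   # option 1 is chosen on 7 of 8 trials

ll_flat = total_ll(0.0, qs, choices)    # beta = 0: every choice has probability 0.5
ll_sharp = total_ll(2.0, qs, choices)   # beta = 2: choices favor the higher Q-value

print(ll_flat, ll_sharp)  # both negative; ll_sharp is the less negative of the two
assert ll_sharp > ll_flat
```

Neither number means much alone; the comparison shows that `beta = 2` explains these choices better than `beta = 0`.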

**Note:** This is the step of the process that requires one-hot data. In NumPy we would index the Q-values with the array of choices. Here we instead build a mask over the Q-values from the one-hot choices, which efficiently selects the Q-value of the chosen option on every trial.
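
The masking idea can be sketched in NumPy as follows. This is a hedged illustration rather than the notebook's implementation: `masked_log_likelihood` is a hypothetical name, and the tiny arrays are chosen so that the expected probabilities are easy to verify by hand.

```python
import numpy as np

def masked_log_likelihood(Qs, B, Cs):
    """Log-probability of the chosen option; inputs are (trials, participants, choices)."""
    tempered = Qs * B
    # Subtract the maximum over choices for numerical stability
    tempered = tempered - tempered.max(axis=-1, keepdims=True)
    log_ps = tempered - np.log(np.exp(tempered).sum(axis=-1, keepdims=True))
    # The one-hot mask keeps only the log-probability of the chosen option
    return (log_ps * Cs).sum(axis=-1)

# Two trials, one participant, two choices
Qs = np.array([[[0.0, 0.0]], [[0.0, np.log(3.0)]]])
B = np.ones_like(Qs)
Cs = np.array([[[1, 0]], [[0, 1]]])

ll = masked_log_likelihood(Qs, B, Cs)
# Trial 0: equal Q-values, so p = 1/2. Trial 1: softmax gives [1/4, 3/4], and option 1 was chosen
assert np.allclose(ll.ravel(), np.log([0.5, 0.75]))
```

Without the mask, summing the softmax over the choice dimension would give 1 on every trial and a log-likelihood of 0.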

In [None]:
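def pytensor_likelihood(Qs, B):
    """
    pytensor_likelihood: calculate log-likelihoods using a tempered softmax over Q-Values

    Qs: Q-Values (data) -- tensor3 (trials, participants, choices)
    B: Betas (parameter) -- tensor3 (trials, participants, choices)
    """
    shape = Qs.shape
    tempered_qs = pt.mul(Qs,B)
    # subtract the maximum over choices before exponentiating, for numerical stability
    qs_max = pt.max(tempered_qs,axis=2)
    qs_max = pt.repeat(qs_max.reshape((shape[0], shape[1], 1)), shape[2], axis=2)
    numerator = pt.exp(tempered_qs - qs_max)
    denominator = pt.sum(numerator, axis=2)
    denominator = pt.repeat(denominator.reshape((shape[0], shape[1], 1)), shape[2], axis=2)
    # NOTE: this sums the softmax probabilities over ALL choices, which is 1 on every trial, so ll is 0 up to rounding
    # To score the observed choices, multiply by the one-hot choices before summing
    Ps = (numerator / denominator).sum(axis=2)
    ll = pt.log(Ps)
    return ll.flatten()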

In [None]:
Qs_compile = pt.dtensor3('Qs_compile')
B_compile = pt.dtensor3('B_compile')

output = pytensor_likelihood(Qs_compile, B_compile)
likelihood_func = function(inputs=[Qs_compile, B_compile], outputs=output)

In [None]:
B = np.tile(betas.reshape(1,-1,1),[n_trials,1,n_choices])

likelihood_func(PyTensor_Qs, B)

array([ 0.00000000e+00, -1.11022302e-16,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00, -1.11022302e-16,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00, -1.11022302e-16,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00, -1.11022302e-16,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00, -1.11022302e-16,  0.00000000e+00,
        2.22044605e-16,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  

###2d. Making your differentiable blackbox likelihood function

In [None]:
def pytensor_blackbox_rl_likelihood(CMs,RMs,AM,BM):
    Qs = rl_multi_scan(CMs,RMs,AM)
    ll = pytensor_likelihood(Qs,BM)
    return ll
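
A plain NumPy reference can be useful for checking a composed function like this one. The sketch below is a stand-in built on stated assumptions rather than a translation of the code above: `numpy_rl_loglik` is a hypothetical name, it scores each choice with the Q-values held before that trial's update, it masks the log-probabilities with the one-hot choices, and it returns a single summed log-likelihood.

```python
import numpy as np

def numpy_rl_loglik(Cs, Rs, A, B, q0=0.5):
    """Summed log-likelihood of the choices; inputs are (trials, participants, choices)."""
    Q = np.full(Cs.shape[1:], q0)
    total = 0.0
    for t in range(Cs.shape[0]):
        # Softmax log-probabilities from the current (pre-update) Q-values
        logits = B[t] * Q
        logits = logits - logits.max(axis=-1, keepdims=True)
        log_ps = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
        # Keep only the chosen option's log-probability
        total += (log_ps * Cs[t]).sum()
        # Same update rule as the scan step
        Q = Q + A[t] * (Rs[t] - Q) * Cs[t]
    return total

# Two trials, one participant, two choices; option 0 is chosen both times
Cs = np.array([[[1, 0]], [[1, 0]]])
Rs = np.full((2, 1, 2), np.log(3.0))
A = np.ones((2, 1, 2))   # learning rate 1: Q jumps straight to the reward
B = np.ones((2, 1, 2))

ll = numpy_rl_loglik(Cs, Rs, A, B, q0=0.0)
# Trial 0: Q = [0, 0], so p = 1/2. Trial 1: Q = [ln 3, 0], so p = 3/4
assert np.isclose(ll, np.log(0.5) + np.log(0.75))
```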

In [None]:
C_compile = pt.itensor3("C_compile")
R_compile = pt.dtensor3("R_compile")
A_compile = pt.dtensor3("A_compile")
B_compile = pt.dtensor3('B_compile')

output = pytensor_blackbox_rl_likelihood(C_compile, R_compile, A_compile, B_compile)
pt_bb_rl_ll_func = function(inputs=[C_compile, R_compile, A_compile, B_compile], outputs=output)

In [None]:
alphas, betas = multi_rl.get_params()

# Testing compiled function
Rs = np.repeat(Python_Rs.reshape(n_trials,n_participants,1), n_choices, axis=2)
Cs = np.zeros((n_trials, n_participants, n_choices))
for n in range(n_choices):
    Cs[:, :, n] = (Python_Cs == n)
Cs = np.array(Cs, dtype=np.int32)

# A and B are the tiled alphas and betas built in the cells above
pt_bb_rl_ll_func(Cs, Rs, A, B)

array([ 0.00000000e+00, -1.11022302e-16,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00, -1.11022302e-16,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00, -1.11022302e-16,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00, -1.11022302e-16,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00, -1.11022302e-16,  0.00000000e+00,
        2.22044605e-16,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  