<a href="https://colab.research.google.com/github/TheLemonPig/RL-SSM/blob/main/PyTensor_for_HSSM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# PyTensor for HSSM

**TODO: Split the Sections into their own slimmer notebooks e.g. one for Introduction to PyTensor**

##1. Introduction to PyTensor

More can be found at: https://pytensor.readthedocs.io/en/latest/tutorial/index.html

In [4]:
# This is the conventional way to import PyTensor
import pytensor.tensor as pt
from pytensor import function

import numpy as np

More can be found at: https://pytensor.readthedocs.io/en/latest/tutorial/adding.html

PyTensor is a library built for optimization. If you are familiar with NumPy and SymPy, it is roughly a combination of the two. This tutorial does not assume you are familiar with SymPy, but it does assume you are familiar with NumPy.

The goal in PyTensor is to create optimized, differentiable versions of functions you have written elsewhere. It achieves this by building a symbolic graph from your Python code and compiling that graph into faster code (for example, C). The details of this are not important. But the consequences are:
1. PyTensor requires you to rewrite your functions in terms of a more limited set of functions
2. PyTensor-based functions must be passed to a PyTensor compiler before they will work in Python
3. The PyTensor compiler requires additional information about your input variables in order to work
4. There can be downstream effects if you are not able to rewrite your function exactly (this is common with more complex functions)

In reality this creates two main additional stages to your workflow:
1. Rewriting functions
2. Compiling functions

In addition, one part of your workflow changes drastically: debugging. A PyTensor expression is symbolic until it is compiled, so you can't step through it with your debugger or drop in print statements whenever you want to inspect a value. Instead, you will need to create a unit test, which will require compiling the relevant PyTensor code into its own function.

Here I will give a simple example of what that would all look like, including debugging.

Suppose you need to write a function to find the distance between two locations, using the function $f(a,b,x,y) = \sqrt{(x-a)^2 + (y-b)^2}$

In [5]:
# Python Pythagoras function
# Note: You would not need to rewrite much of this function
# ...but we will do so to show you all the steps
def python_distance(a, b, x, y):
    return np.sqrt((x-a)**2 + (y-b)**2)

###1a. Rewrite the function

*You will see in the documentation that you do not need to write functions as formally as we will. But since we will always be doing this, please take this as your best starting point.*

In [7]:
# Distance function (pretending the bug is not easy to see)
def pytensor_distance(a, b, x, y):
    # Note that arithmetic operations do not need to be rewritten
    # Also note that PyTensor typically has analogous functions to NumPy, often with the same name
    return pt.sqrt((x-a) ** 2 + (y-b) * 2)

###1b. Compile the function



In [9]:
# Initialize Inputs - this will be explained more later
a = pt.dscalar('a')
b = pt.dscalar('b')
x = pt.dscalar('x')
y = pt.dscalar('y')

# Call your function on the inputs
# Collect the output
pytensor_output = pytensor_distance(a,b,x,y)

# Compile your PyTensor based function
pytensor_distance_compiled = function([a, b, x, y], pytensor_output)

###1c. Test the function

In [16]:
def pytensor_test(a,b,x,y):
    try:
        # Note: tiny floating-point differences (usually <1e-10) can arise between the two functions, so consider rounding or comparing within a tolerance, especially in more complex functions
        assert pytensor_distance_compiled(a,b,x,y) == python_distance(a,b,x,y), f"{pytensor_distance_compiled(a,b,x,y)} != {python_distance(a,b,x,y)}"
    except AssertionError as err:
        raise AssertionError(f"Test failed: {err}")

a, b, x, y = 0.5, 0.5, 1.0, 1.0
pytensor_test(a,b,x,y)

AssertionError: Test failed: 0.75 != 0.7071067811865476
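The note inside the test mentions minute floating-point differences. As an alternative to rounding, here is a minimal NumPy-only sketch of a tolerance-based comparison. The values `reference` and `candidate` are made up for illustration, and the tolerance of `1e-10` simply mirrors the figure quoted in the note.

```python
import numpy as np

reference = np.sqrt(0.5)       # stands in for the plain Python result
candidate = reference + 1e-12  # stands in for a compiled result with tiny drift

# Exact equality is too strict for floating-point results
print(candidate == reference)

# Compare within an absolute tolerance instead
print(np.isclose(candidate, reference, rtol=0.0, atol=1e-10))

# In a test, assert_allclose raises an informative AssertionError on failure
np.testing.assert_allclose(candidate, reference, rtol=0.0, atol=1e-10)
```

A genuine bug, like the one being hunted here, still fails this check because the difference is far larger than the tolerance.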

###1d. Debug the function

In [18]:
# unit functions
def pytensor_distance_a(a, x):
    return (x-a) ** 2

def pytensor_distance_b(b, y):
    return (y-b) * 2

# Ideally write your overall function in terms of the unit functions and test it too
# Typos are pretty common and one way you will catch them is if this version...
# ...of your pytensor function works while your original one didn't.
def pytensor_distance_c(a,b,x,y):
    return pt.sqrt(pytensor_distance_a(a,x) + pytensor_distance_b(b,y))

# Initialize Inputs - this will be explained more later
a = pt.dscalar('a')
b = pt.dscalar('b')
x = pt.dscalar('x')
y = pt.dscalar('y')

# Call your functions on the inputs
# Collect the output
pytensor_output_a = pytensor_distance_a(a,x)
pytensor_output_b = pytensor_distance_b(b,y)
pytensor_output_c = pytensor_distance_c(a,b,x,y)

# Compile your PyTensor-based functions
pytensor_distance_compiled_a = function([a, x], pytensor_output_a)
pytensor_distance_compiled_b = function([b, y], pytensor_output_b)
pytensor_distance_compiled_c = function([a, b, x, y], pytensor_output_c)

In [21]:
# Take your original function and do the same
def python_distance_a(a, x):
    return (x-a) ** 2

def python_distance_b(b, y):
    return (y-b) ** 2

def python_distance_c(a,b,x,y):
    return np.sqrt(python_distance_a(a,x) + python_distance_b(b,y))

In [24]:
# Run test
def pytensor_unit_tests(a,b,x,y):
    try:
        assert pytensor_distance_compiled_a(a, x) == python_distance_a(a, x), f"Test 1 - {pytensor_distance_compiled_a(a, x)} != {python_distance_a(a, x)}"
        assert pytensor_distance_compiled_b(b, y) == python_distance_b(b, y), f"Test 2 - {pytensor_distance_compiled_b(b, y)} != {python_distance_b(b, y)}"
        assert pytensor_distance_compiled_c(a,b,x,y) == python_distance(a,b,x,y), f"Test 3 - {pytensor_distance_compiled_c(a,b,x,y)} != {python_distance(a,b,x,y)}"

    except AssertionError as err:
        raise AssertionError(f"Test failed: {err}")

a, b, x, y = 0.5, 0.5, 1.0, 1.0
pytensor_unit_tests(a,b,x,y)

AssertionError: Test failed: Test 2 - 1.0 != 0.25

Hopefully, through this process you will home in on your typo and fix your original function.

Note: It is possible to print statements in PyTensor. I have not investigated it much, but it looks like it could be useful in some situations: https://pytensor.readthedocs.io/en/latest/library/printing.html

**Note:** If you have a bug you cannot make sense of, it is always worth restarting the runtime if you are working within a notebook. This will genuinely fix half of your confusing bugs! This is **especially** true if you have jumped around in your notebook. You will also sometimes find that skipping sections helps. That is a sign that you are reusing the same name for different variables/functions (in the cells above, for example, `a` is first a symbolic scalar and later a Python float), so a later cell silently picks up the wrong object.

##2. PyTensor for RL

In [1]:
import pytensor.tensor as pt
from pytensor import function, scan
import numpy as np

Reinforcement Learning is inherently sequential, meaning that we will need to make use of `for` loops.

In PyTensor, this can be achieved in a very narrow sense using the `scan` function.

###2a. Using scan

For more, see: https://pytensor.readthedocs.io/en/latest/library/scan.html

For simplicity, we will start with a numerical example, before moving on to a matrix example.

####2ai. Using scan with numerical inputs/outputs

Suppose we want to calculate the first $n$ terms of a series using the recursion formula $s_{n+1} = \frac{s_{n}}{n+1} + m$, with $s_0=1$ and any $m$.

A python function for this could look like:

In [107]:
def formula(sn,n,m):
    return sn/(n+1) + m

def get_term(n,m):
    sn = 1  # s_0
    series = np.zeros(n)
    for i in range(n):
        # series[i] holds s_{i+1}; s_0 itself is not stored
        sn = formula(sn,i,m)
        series[i] = sn
    return series

This is how we would write it in PyTensor:

In [99]:
# We have to slightly adapt our function because scan is strict about the order of the step function's arguments
# They should be organized like this: sequences, outputs_info, non_sequences
def formula(n,sn,m):
    return sn/(n+1) + m

def get_term_pytensor(n,m):
    # You will sometimes need to set the dtype explicitly like this. Be very careful about types:
    # the default float type follows pytensor.config.floatX, and a dtype mismatch between the
    # initial value and the step output is a common source of scan errors
    sn = pt.constant(1.0, dtype='float64')
    ns = pt.arange(n, dtype='float64')
    # formula receives its arguments in the order: element of ns, previous output (starting at sn), then m
    term, _ = scan(fn=formula, sequences=ns, non_sequences=m, outputs_info=sn)
    return term

In [100]:
n_compile = pt.scalar("n", dtype='int64')  # note n must be an integer (you can't have the 3.5th term of a series)
m_compile = pt.scalar("m", dtype='float64')

output = get_term_pytensor(n_compile, m_compile)
my_func = function(inputs=[n_compile,m_compile],outputs=output)  # note the LACK of square brackets around the outputs. You will need them if you have multiple outputs, but otherwise you do not want your function unknowingly wrapping your output in a list

It is always important to compare the outputs of your PyTensor function with the original Python function.

In [108]:
get_term(10,1)

array([2.        , 2.        , 1.66666667, 1.41666667, 1.28333333,
       1.21388889, 1.1734127 , 1.14667659, 1.12740851, 1.11274085])

In [102]:
my_func(10,1)

array([2.        , 2.        , 1.66666667, 1.41666667, 1.28333333,
       1.21388889, 1.1734127 , 1.14667659, 1.12740851, 1.11274085])

###2b. Using scan for RL

To avoid having to index tensors of Q-values by the chosen action, we one-hot encode the data so that indexing is not necessary. This saves us having to scan over trials to extract the relevant Q-values for the likelihood. (PyTensor does support some forms of integer-array indexing, so if you can find another way, a lot of the code dealing with one-hotting from here on isn't necessary.)

**Note:** The first dimension of each tensor that you place in `sequences` should be the dimension that is missing from the tensor you pass to `outputs_info`. `scan` iterates over subtensors of the sequence tensor by indexing along the first dimension, and this dimension is added to the output of your scan function. `outputs_info` supplies the initial value: it is used to calculate the first output, which is then fed back in to calculate the next one, and this repeats over each iteration of the scan function.
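The shape bookkeeping in this note can be mimicked in plain NumPy. The helper `scan_like` below is a hypothetical stand-in written for illustration. It is not how `scan` is implemented, only a sketch of the looping pattern it follows.

```python
import numpy as np

def scan_like(step, sequence, initial):
    """Plain-Python stand-in for the looping pattern described above."""
    state = initial
    outputs = []
    for row in sequence:          # iterate along the first dimension
        state = step(row, state)  # previous output feeds the next step
        outputs.append(state)
    return np.stack(outputs)      # the iterated dimension is added back in front

sequence = np.arange(8.0).reshape(4, 2)   # shape (trials, choices)
initial = np.zeros(2)                     # shape (choices,)
running_sum = scan_like(lambda row, state: state + row, sequence, initial)
print(running_sum.shape)
```

The initial value has shape `(choices,)` and the result has shape `(trials, choices)`, which is the same relationship between `outputs_info` and the scan output described in the note.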

####2bi. Using scan for single participant RL

For simplicity, we will first go over what it looks like for an individual participant, and then extend the code to multiple participants.

In all of these functions you will notice that we include the dimension for choice. This is the dimension which is being one-hotted: taking the choice 0 is represented as [1, 0], and taking the choice 1 is represented as [0, 1]. We need as many numbers as there are choices if we want to place a 1 at the index of the value of the choice. This is why we need to create an extra dimension which is the size of the number of choices.
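A compact way to see this encoding is the NumPy-only sketch below, using made-up choices and Q-values. It builds the one-hot matrix by indexing an identity matrix, which is one possible approach and is assumed here for illustration, and then shows that multiplying by the mask and summing selects the same Q-values as direct indexing.

```python
import numpy as np

n_choices = 2
choices = np.array([0, 1, 1, 0])          # chosen option on each trial
q_values = np.array([[0.5, 0.6],
                     [0.4, 0.7],
                     [0.3, 0.8],
                     [0.2, 0.9]])          # shape (trials, choices)

# Row i of the identity matrix is the one-hot code for choice i
one_hot = np.eye(n_choices, dtype=np.int32)[choices]

# Two routes to the Q-value of the chosen option on each trial
picked_by_index = q_values[np.arange(len(choices)), choices]
picked_by_mask = (q_values * one_hot).sum(axis=1)
print(one_hot)
```

Because the mask route uses only multiplication and a sum, it carries over to the symbolic setting without any indexing.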

**Simulating RL in Python**

In [2]:
from functools import partial
from typing import Callable, List

In [3]:
def softmax(qs, beta):
  return np.exp(qs*beta) / np.exp(qs*beta).sum()

class Distribution:

  def __init__(self, func: Callable, kwargs):
    self.func: Callable = partial(func, **kwargs)

  def __call__(self):
    return self.func.__call__()

class SimpleRL:

  def __init__(self, n_trials: int, distributions: List[Distribution]):
    self.n_choices: int = len(distributions)
    self.n_trials: int = n_trials
    self.distributions: List[Distribution] = distributions
    self.qs: np.array = np.ones((self.n_choices,)) * 0.5
    self.q_trace: np.array = np.ones((self.n_trials,self.n_choices))
    self.rewards: np.array = np.zeros((self.n_trials))
    self.choices: np.array = np.zeros((self.n_trials),dtype=np.int32)

  def simulate(self, alpha, beta):
    for i in range(self.n_trials):
      # Q-values are recorded to trace
      self.q_trace[i] = self.qs
      # softmax decision function
      ps = softmax(self.qs,beta)
      # choice made based on weighted probabilities of Q-values
      choice = np.random.choice(a=self.n_choices,size=1,p=ps)[0]
      # choice is recorded to trace
      self.choices[i] = choice
      # reward calculated
      dist = self.distributions[choice]  # supply a list of distributions to choose from
      reward = dist()  # sample from distribution by calling it
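      # reward is recorded to trace
      self.rewards[i] = reward
      # Q-value of the chosen option is updated
      self.qs[choice] = self.qs[choice] + alpha * (reward - self.qs[choice])
      # updated Q-values overwrite the pre-update values recorded above
      self.q_trace[i] = self.qs
    # The main data returned (the basis for fits) are the rewards and choices per trial;
    # the Q-value trace is returned as well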
    return self.rewards, self.choices, self.q_trace

In [4]:
seed = 0
np.random.seed(seed)
mean_rewards = [-1.0,1.0]
dists = [Distribution(np.random.normal,{"loc":mn, "scale":1.0}) for mn in mean_rewards]
n_trials = 100

alpha, beta = 0.1, 1.0

simple_rl = SimpleRL(n_trials, dists)
Rs, Cs, Qs = simple_rl.simulate(alpha, beta)

**Recovering Qs in PyTensor**

In [11]:
def rl_step(C, R, Q, A):
    """
    rl_step: function for a single RL step

    C: Choice -- vector (choices)
    R: Reward -- vector (choices)
    Q: Q-Values at the previous time step -- vector (choices)
    A: Alpha parameter -- vector (choices)
    """
    return Q + A * (R - Q) * C
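A single update can be worked through by hand with NumPy arrays. The numbers below are made up for illustration. Only the chosen option moves, because the one-hot vector `C` zeroes the update for the other option.

```python
import numpy as np

Q = np.array([0.5, 0.5])   # current Q-values, one per choice
C = np.array([0, 1])       # one-hot choice: option 1 was chosen
R = np.array([1.0, 1.0])   # the trial's reward, repeated for every choice
A = np.array([0.1, 0.1])   # learning rate, repeated for every choice

# Same arithmetic as the step function above
Q_next = Q + A * (R - Q) * C
print(Q_next)
```

The chosen option moves from 0.5 towards the reward by `A * (R - Q)`, which is 0.05 here, while the unchosen option keeps its previous value.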

In [14]:
def rl_scan(Cs,Rs,A):
    """
    rl_scan: scan function over RL steps

    Cs: Choices -- matrix (trials, choices)
    Rs: Rewards -- matrix (trials, choices)
    A: Alpha parameter -- vector (choices)

    Qs: Q-values -- matrix (trials, choices)
    """
    Qs = pt.ones(Cs.shape[1]) * 0.5 # note Qs shape changes over scan:  vector (choices) --> matrix (trials, choices)
    Qs, _ = scan(fn=rl_step, sequences=[Cs,Rs], non_sequences = [A], outputs_info=Qs)
    return Qs

In [17]:
# Defining the inputs of the to-be-compiled function
Cs = pt.imatrix("Cs")
Rs = pt.dmatrix("Rs")
A = pt.dvector("A")

# Compiling function
output = rl_scan(Cs,Rs,A)
rl_func = function(inputs=[Cs,Rs,A], outputs=output)

In [18]:
seed = 0
np.random.seed(seed)
mean_rewards = [-1.0,1.0]
dists = [Distribution(np.random.normal,{"loc":mn, "scale":1.0}) for mn in mean_rewards]
n_trials = 100

alpha, beta = 0.1, 1.0

simple_rl = SimpleRL(n_trials, dists)
Python_Rs, Python_Cs, Python_Qs = simple_rl.simulate(alpha, beta)
n_choices = len(mean_rewards)

# Testing compiled function
Rs = np.repeat(Python_Rs.reshape(-1,1), repeats = n_choices, axis=1)
Cs = np.zeros((Python_Cs.shape[0], n_choices))
for n in range(n_choices):
  Cs[:, n] = (Python_Cs == n)
Cs = np.array(Cs, dtype=np.int32)
A = np.ones(n_choices) * alpha
PyTensor_Qs = rl_func(Cs,Rs,A)
PyTensor_Qs[:10]

array([[0.5       , 0.62415917],
       [0.5       , 0.81703463],
       [0.5       , 0.75105876],
       [0.5       , 0.97294533],
       [0.5       , 1.10226265],
       [0.5       , 1.04144873],
       [0.5       , 1.08573507],
       [0.5       , 1.13507561],
       [0.5       , 1.10340979],
       [0.5       , 1.23408928]])

In [19]:
Python_Qs[:10]

array([[0.5       , 0.62415917],
       [0.5       , 0.81703463],
       [0.5       , 0.75105876],
       [0.5       , 0.97294533],
       [0.5       , 1.10226265],
       [0.5       , 1.04144873],
       [0.5       , 1.08573507],
       [0.5       , 1.13507561],
       [0.5       , 1.10340979],
       [0.5       , 1.23408928]])

####2bii. Using scan over multiple participants

**Simulating multi-participant RL in Python**

In [19]:
class MultiRL:

  def __init__(self, n_trials: int, n_participants: int, distributions: List[Distribution]):
    self.participants: List[SimpleRL] = [
        SimpleRL(n_trials, distributions) for _ in range(n_participants)
    ]
    self.n_trials = n_trials
    self.n_participants = n_participants
    self.n_choices = len(distributions)
    self.alphas = np.zeros(n_participants)
    self.betas = np.zeros(n_participants)

  def simulate(self, alpha_a, alpha_b, beta_a, beta_b):
    group_rewards = np.zeros((self.n_trials,self.n_participants))
    group_choices = np.zeros((self.n_trials,self.n_participants))
    group_Qs = np.zeros((self.n_trials,self.n_participants,self.n_choices))
    for idx, participant_model in enumerate(self.participants):
        # sample participant parameters
        alpha = np.random.beta(alpha_a, alpha_b)
        beta = np.random.beta(beta_a, beta_b)
        self.alphas[idx] = alpha
        self.betas[idx] = beta
        # run RL for participant
        participant_rewards, participant_choices, participant_Qs = participant_model.simulate(alpha, beta)
        group_rewards[:, idx] = participant_rewards
        group_choices[:, idx] = participant_choices
        group_Qs[:, idx] = participant_Qs
    return group_rewards, group_choices, group_Qs

  def get_params(self):
    return self.alphas, self.betas


**Recovering multi-participant Qs in PyTensor**

In [6]:
# Notice that this function is written in the same way as for a single participant

def rl_multi_step(C, R, Q, A):
    """
    rl_multi_step: function for a single RL step across all participants

    C: Choice -- matrix (participants, choices)
    R: Reward -- matrix (participants, choices)
    Q: Q-Values at the previous time step -- matrix (participants, choices)
    A: Alpha parameter -- matrix (participants, choices)
    """
    return Q + A * (R - Q) * C
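The step function can stay unchanged because the arithmetic is element-wise, so adding a participant dimension changes nothing in the formula. The NumPy sketch below, with made-up values and a hypothetical helper `rl_update`, checks that one batched update matches updating each participant separately.

```python
import numpy as np

def rl_update(C, R, Q, A):
    # Same arithmetic as the step function above, applied to NumPy arrays
    return Q + A * (R - Q) * C

# Hypothetical values for 3 participants and 2 choices
Q = np.full((3, 2), 0.5)
C = np.array([[1, 0], [0, 1], [0, 1]])
R = np.array([[1.0, 1.0], [-1.0, -1.0], [2.0, 2.0]])
A = np.array([[0.1, 0.1], [0.2, 0.2], [0.3, 0.3]])

# One update over the whole (participants, choices) matrix
batched = rl_update(C, R, Q, A)

# The same update applied one participant at a time
per_participant = np.stack([rl_update(C[p], R[p], Q[p], A[p]) for p in range(3)])
print(np.allclose(batched, per_participant))
```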

In [13]:
# Notice that this function is written in ALMOST the same way as for a single participant
# Tensor3 is how PyTensor defines 3-Dimensional arrays

def rl_multi_scan(CMs,RMs,AM):
    """
    rl_multi_scan: scan function over RL steps for multiple participants

    CMs: Choices -- tensor3 (trials, participants, choices)
    RMs: Rewards -- tensor3 (trials, participants, choices)
    AM: Alpha parameter -- matrix (participants, choices)

    QTs: Q-values -- tensor3 (trials, participants, choices)
    """
    QMs = pt.ones((CMs.shape[1], CMs.shape[2])) * 0.5 # note Qs shape changes over scan:  matrix (participants, choices) --> tensor3 (trials, participants, choices)
    QTs, _ = scan(fn=rl_multi_step, sequences=[CMs,RMs], non_sequences = [AM], outputs_info=QMs)  # NOTICE THAT WE DO NOT PLACE BRACKETS AROUND THE OUTPUT!!
    return QTs

In [14]:
# Defining the inputs of the to-be-compiled function
CMs = pt.itensor3("CMs")
RMs = pt.dtensor3("RMs")
AM = pt.dmatrix("AM")

# Compiling function
output = rl_multi_scan(CMs,RMs,AM)
rl_func = function(inputs=[CMs,RMs,AM], outputs=output)

In [26]:
seed = 0
np.random.seed(seed)
mean_rewards = [-1.0,1.0]
dists = [Distribution(np.random.normal,{"loc":mn, "scale":1.0}) for mn in mean_rewards]
n_trials = 100
n_participants = 3
n_choices = len(mean_rewards)

alpha_a, alpha_b, beta_a, beta_b = 2, 2, 2, 2
# alpha, beta = 0.1, 1.0 -- For MultiRL the parameters vary with a distribution

multi_rl = MultiRL(n_trials, n_participants, dists)
Python_Rs, Python_Cs, Python_Qs = multi_rl.simulate(alpha_a, alpha_b, beta_a, beta_b)

alphas, _ = multi_rl.get_params()

# Testing compiled function
Rs = np.repeat(Python_Rs.reshape(n_trials,n_participants,1), n_choices, axis=2)
Cs = np.zeros((n_trials, n_participants, n_choices))
for n in range(n_choices):
    Cs[:, :, n] = (Python_Cs == n)
Cs = np.array(Cs, dtype=np.int32)
A = np.tile(alphas.reshape(-1,1),[1,n_choices])
PyTensor_Qs = rl_func(Cs,Rs,A)
PyTensor_Qs[:5]

array([[[ 0.5       ,  1.73022911],
        [ 0.21225973,  0.5       ],
        [ 0.07119759,  0.5       ]],

       [[ 0.5       ,  0.86919225],
        [-0.03987525,  0.5       ],
        [-0.52974563,  0.5       ]],

       [[ 0.5       ,  1.29766769],
        [-0.03987525,  0.75788475],
        [-0.52974563,  0.61445943]],

       [[ 0.5       ,  1.49373382],
        [-0.26587443,  0.75788475],
        [-0.52974563, -0.30926341]],

       [[ 0.5       ,  1.02332718],
        [-0.4027839 ,  0.75788475],
        [-2.22131283, -0.30926341]]])

In [27]:
Python_Qs[:5]

array([[[ 0.5       ,  1.73022911],
        [ 0.21225973,  0.5       ],
        [ 0.07119759,  0.5       ]],

       [[ 0.5       ,  0.86919225],
        [-0.03987525,  0.5       ],
        [-0.52974563,  0.5       ]],

       [[ 0.5       ,  1.29766769],
        [-0.03987525,  0.75788475],
        [-0.52974563,  0.61445943]],

       [[ 0.5       ,  1.49373382],
        [-0.26587443,  0.75788475],
        [-0.52974563, -0.30926341]],

       [[ 0.5       ,  1.02332718],
        [-0.4027839 ,  0.75788475],
        [-2.22131283, -0.30926341]]])
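The input preparation used for the multi-participant test can be followed more easily on a tiny example. The sketch below uses made-up data for 2 trials, 3 participants, and 2 choices. The broadcasting comparison used for the one-hot step is an alternative to the explicit loop and is an assumption of this sketch, not the notebook's own method.

```python
import numpy as np

n_trials, n_participants, n_choices = 2, 3, 2

# Hypothetical simulated data with shape (trials, participants)
rewards = np.array([[1.0, -1.0, 0.5],
                    [0.0,  2.0, 1.5]])
choices = np.array([[0, 1, 1],
                    [1, 0, 1]])
alphas = np.array([0.1, 0.2, 0.3])  # one learning rate per participant

# Repeat each reward across a new trailing choice axis
rewards_3d = np.repeat(rewards[:, :, None], n_choices, axis=2)

# One-hot the choices by comparing against every choice index at once
choices_3d = (choices[:, :, None] == np.arange(n_choices)).astype(np.int32)

# Repeat each participant's learning rate across the choice axis
alpha_matrix = np.tile(alphas.reshape(-1, 1), [1, n_choices])

print(rewards_3d.shape, choices_3d.shape, alpha_matrix.shape)
```

The resulting shapes are `(trials, participants, choices)` for rewards and choices and `(participants, choices)` for the learning rates, matching the docstring of the scan function.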