<a href="https://colab.research.google.com/github/TheLemonPig/RL-SSM/blob/main/PyTensor_for_HSSM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# PyTensor for HSSM

**TODO: Split the Sections into their own slimmer notebooks e.g. one for Introduction to PyTensor**

##1. Introduction to PyTensor

More can be found at: https://pytensor.readthedocs.io/en/latest/tutorial/index.html

In [4]:
# This is the conventional way to import PyTensor
import pytensor.tensor as pt
from pytensor import function

import numpy as np

More can be found at: https://pytensor.readthedocs.io/en/latest/tutorial/adding.html

PyTensor is a Python library built for optimization. If you are familiar with NumPy and SymPy, it is roughly a combination of the two: you write array expressions much as you would in NumPy, but they are symbolic, as in SymPy. This tutorial does not assume you are familiar with SymPy, but it does assume you are familiar with NumPy.

The goal in PyTensor is to create optimized, differentiable versions of functions you have written elsewhere. It achieves this by converting your Python code into C in a differentiable format. The details of this are not important, but the consequences are:
1. PyTensor requires you to rewrite your functions in terms of a more limited set of functions
2. PyTensor-based functions must be passed to a PyTensor compiler before they will work in Python
3. The PyTensor compiler requires additional information about your input variables in order to work
4. There can be downstream effects if you are not able to rewrite your function exactly (this is common with more complex functions)

In reality this creates two main additional stages to your workflow:
1. Rewriting functions
2. Compiling functions

In addition, one part of your workflow will change drastically: how you debug your function. Since PyTensor-based functions do not produce values until they are compiled, you can't step through them with your debugger or drop in print statements whenever you want to debug your code. Instead, you will need to create a unit test, which will require compiling the relevant PyTensor code into its own function.

Here I will give a simple example of what that would all look like, including debugging.

Suppose you need to write a function to find the distance between two points $(a, b)$ and $(x, y)$, using the function $f(a,b,x,y) = \sqrt{(x-a)^2 + (y-b)^2}$.

In [5]:
# Python Pythagoras function
# Note: You would not need to rewrite much of this function
# ...but we will do so to show you all the steps
def python_distance(a, b, x, y):
    return np.sqrt((x-a)**2 + (y-b)**2)

###1a. Rewrite the function

*You will see in the documentation that you do not need to write functions as formally as we will. But since we will always be doing this, please take this as your best starting point.*

In [7]:
# Distance function (pretending the bug is not easy to see)
def pytensor_distance(a, b, x, y):
    # Note that arithmetic operations do not need to be rewritten
    # Also note that PyTensor typically has analogous functions to NumPy, often with the same name
    return pt.sqrt((x-a) ** 2 + (y-b) * 2)

###1b. Compile the function



In [9]:
# Initialize Inputs - this will be explained more later
a = pt.dscalar('a')
b = pt.dscalar('b')
x = pt.dscalar('x')
y = pt.dscalar('y')

# Call your function on the inputs
# Collect the output
pytensor_output = pytensor_distance(a,b,x,y)

# Compile your PyTensor based function
pytensor_distance_compiled = function([a, b, x, y], pytensor_output)

###1c. Test the function

In [16]:
def pytensor_test(a,b,x,y):
    try:
        # Note: It is possible for minute differences to be created, usually <1e-10, between the functions, so consider rounding, especially in more complex functions
        assert pytensor_distance_compiled(a,b,x,y) == python_distance(a,b,x,y), f"{pytensor_distance_compiled(a,b,x,y)} != {python_distance(a,b,x,y)}"
    except AssertionError as err:
        raise AssertionError(f"Test failed: {err}")

a, b, x, y = 0.5, 0.5, 1.0, 1.0
pytensor_test(a,b,x,y)

AssertionError: Test failed: 1.118033988749895 != 0.7071067811865476
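The comment in the test about minute differences can be handled with a tolerance-based comparison instead of `==`. The sketch below uses plain NumPy and an unrelated floating-point example; the tolerances are arbitrary choices for illustration, not recommended values.

```python
import numpy as np

expected = 0.3
computed = 0.1 + 0.2  # differs from 0.3 by floating-point rounding

# Exact equality fails even though the values agree to about 16 digits
exact_match = computed == expected

# A tolerance-based check treats them as equal
close_match = np.isclose(computed, expected, rtol=1e-9, atol=1e-12)

# assert_allclose raises an AssertionError with a readable report on mismatch
np.testing.assert_allclose(computed, expected, rtol=1e-9, atol=1e-12)

print(exact_match, close_match)
```

A real bug such as the one above produces a difference far larger than any sensible tolerance, so it would still be caught.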

###1d. Debug the function

In [18]:
# unit functions
def pytensor_distance_a(a, x):
    return (x-a) ** 2

def pytensor_distance_b(b, y):
    return (y-b) * 2

# Ideally write your overall function in terms of the unit functions and test it too
# Typos are pretty common and one way you will catch them is if this version...
# ...of your pytensor function works while your original one didn't.
def pytensor_distance_c(a,b,x,y):
    return pt.sqrt(pytensor_distance_a(a,x) + pytensor_distance_b(b,y))

# Initialize Inputs - this will be explained more later
a = pt.dscalar('a')
b = pt.dscalar('b')
x = pt.dscalar('x')
y = pt.dscalar('y')

# Call your functions on the inputs
# Collect the output
pytensor_output_a = pytensor_distance_a(a,x)
pytensor_output_b = pytensor_distance_b(b,y)
pytensor_output_c = pytensor_distance_c(a,b,x,y)

# Compile your PyTensor-based functions
pytensor_distance_compiled_a = function([a, x], pytensor_output_a)
pytensor_distance_compiled_b = function([b, y], pytensor_output_b)
pytensor_distance_compiled_c = function([a, b, x, y], pytensor_output_c)

In [21]:
# Take your original function and do the same
def python_distance_a(a, x):
    return (x-a) ** 2

def python_distance_b(b, y):
    return (y-b) ** 2

def python_distance_c(a,b,x,y):
    return np.sqrt(python_distance_a(a,x) + python_distance_b(b,y))

In [24]:
# Run test
def pytensor_unit_tests(a,b,x,y):
    try:
        # Note: test the compiled functions, not the uncompiled Python definitions
        assert pytensor_distance_compiled_a(a, x) == python_distance_a(a, x), f"Test 1 - {pytensor_distance_compiled_a(a, x)} != {python_distance_a(a, x)}"
        assert pytensor_distance_compiled_b(b, y) == python_distance_b(b, y), f"Test 2 - {pytensor_distance_compiled_b(b, y)} != {python_distance_b(b, y)}"
        assert pytensor_distance_compiled_c(a,b,x,y) == python_distance_c(a,b,x,y), f"Test 3 - {pytensor_distance_compiled_c(a,b,x,y)} != {python_distance_c(a,b,x,y)}"

    except AssertionError as err:
        raise AssertionError(f"Test failed: {err}")

a, b, x, y = 0.5, 0.5, 1.0, 1.0
pytensor_unit_tests(a,b,x,y)

AssertionError: Test failed: Test 2 - 1.0 != 0.25

Hopefully, through this process you will home in on your typo and fix your original function.

Note: It is possible to print values from inside a PyTensor graph. I have not investigated it much, but it looks like it could be useful in some situations: https://pytensor.readthedocs.io/en/latest/library/printing.html

##2. PyTensor for RL

In [1]:
import pytensor.tensor as pt
from pytensor import function, scan
import numpy as np

Reinforcement Learning is inherently sequential, meaning that we will need to make use of `for` loops.

In PyTensor, this can be achieved in a very narrow sense using the `scan` function.

###2a. Using scan

For more, see: https://pytensor.readthedocs.io/en/latest/library/scan.html

For simplicity, we will start with a numerical example, before moving on to a matrix example.

####2ai. Using scan with numerical inputs/outputs

Suppose we want to calculate the first $n$ terms of a series defined by the recursion formula $s_{n+1} = \frac{s_{n}}{n+1} + m$, for $s_0=1$ and any $m$.
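As a worked check of the recursion, taking $m = 1$ (an arbitrary value chosen for illustration), the first few terms are:

$$
\begin{aligned}
s_1 &= \frac{s_0}{1} + m = \frac{1}{1} + 1 = 2 \\
s_2 &= \frac{s_1}{2} + m = \frac{2}{2} + 1 = 2 \\
s_3 &= \frac{s_2}{3} + m = \frac{2}{3} + 1 = \frac{5}{3} \approx 1.667
\end{aligned}
$$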

A python function for this could look like:

In [107]:
def formula(sn,n,m):
    return sn/(n+1) + m

def get_term(n,m):
    sn = 1  # s_0
    series = np.zeros(n)
    for i in range(n):
        # series[i] holds s_{i+1}; s_0 itself is not stored
        sn = formula(sn,i,m)
        series[i] = sn
    return series

This is how we would write it in PyTensor:

In [99]:
# We have to slightly adapt our function because the order of its arguments is incredibly important for getting scan to work correctly
# They should be organized like this: sequences, outputs_info, non_sequences
# (renamed here so that it does not overwrite the Python `formula` above)
def formula_pytensor(n,sn,m):
    return sn/(n+1) + m

def get_term_pytensor(n,m):
    # You will sometimes need to define the type like this, and you will need to be very careful about type. PyTensor generally defaults to 32-bit types
    sn = pt.constant(1.0, dtype='float64')
    ns = pt.arange(n, dtype='float64')
    # scan returns (outputs, updates); the updates are not needed here
    term, _ = scan(fn=formula_pytensor, sequences=ns, non_sequences=m, outputs_info=sn)
    return term

In [100]:
n_compile = pt.scalar("n", dtype='int64')  # note n must be an integer (you can't have the 3.5th term of a series)
m_compile = pt.scalar("m", dtype='float64')

output = get_term_pytensor(n_compile, m_compile)
my_func = function(inputs=[n_compile,m_compile],outputs=output)  # note the LACK of square brackets around the outputs. You will need them if you have multiple outputs, but otherwise you do not want your function unknowingly wrapping your output in a list

It is always important to compare the outputs of your PyTensor function with the original Python function.

In [108]:
get_term(10,1)

array([2.        , 2.        , 1.66666667, 1.41666667, 1.28333333,
       1.21388889, 1.1734127 , 1.14667659, 1.12740851, 1.11274085])

In [102]:
my_func(10,1)

array([2.        , 2.        , 1.66666667, 1.41666667, 1.28333333,
       1.21388889, 1.1734127 , 1.14667659, 1.12740851, 1.11274085])

###2b. Using scan for RL

Indexing PyTensor arrays with other arrays can be awkward inside a likelihood, so here we one-hot encode the choice data so that indexing is not necessary. This saves us from having to scan over trials to extract the relevant Q-values for the likelihood. (If you can find another way, such as PyTensor's integer-array indexing, then a lot of the one-hot code from here on isn't necessary.)
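The one-hot idea can be sketched in plain NumPy. The Q-values and choices below are made-up numbers for illustration; the point is only that masking and summing gives the same result as indexing.

```python
import numpy as np

n_choices = 3
choices = np.array([0, 2, 1, 2])  # chosen option on each trial
q_values = np.array([[0.5, 0.1, 0.9],
                     [0.4, 0.2, 0.8],
                     [0.3, 0.6, 0.7],
                     [0.2, 0.5, 0.1]])  # hypothetical Q-values, shape (trials, choices)

# Row i of the identity matrix is the one-hot vector for choice i
one_hot = np.eye(n_choices, dtype=np.int32)[choices]

# Masking and summing picks out the chosen Q-value without indexing by choice
chosen_q = (one_hot * q_values).sum(axis=1)

# Same result as direct integer-array indexing
indexed_q = q_values[np.arange(len(choices)), choices]
print(chosen_q, indexed_q)
```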

####2bi. Using scan for single participant RL

For simplicity, we will go over what it looks like for an individual participant, and then extend the code to multiple participants.

In [2]:

def rl_step(C, R, Q_tm1, A):
    """
    function for a single RL step
    C: One-hot choice on this trial -- vector (choices)
    R: Reward on this trial -- vector (choices)
    Q_tm1: Q-Values at the previous time step -- vector (choices)
    A: Alpha parameter -- vector (choices)
    """
    return Q_tm1 + pt.mul(A,pt.mul(C,(R - Q_tm1)))

In [3]:
def rl_scan(Cs,Rs,A,n_choices):
    """
    scan function over RL steps
    Cs: One-hot choices -- matrix (trials,choices)
    Rs: Rewards -- matrix (trials,choices)
    A: Alpha parameter -- vector (choices)
    n_choices: number of choices -- scalar ()

    Qs: initial vector (choices) --> returned matrix (trials, choices)
    """
    Qs = pt.ones(n_choices) * 0.5
    Qs, _ = scan(fn=rl_step, sequences=[Cs,Rs], non_sequences = [A], outputs_info=Qs)
    return Qs

In [7]:
# Defining the inputs of the to-be-compiled function
Cs = pt.imatrix("Cs")
Rs = pt.dmatrix("Rs")
A = pt.dvector("A")
n_choices = pt.iscalar('n_choices')

# Compiling function
output = rl_scan(Cs,Rs,A,n_choices)
rl_func = function(inputs=[Cs,Rs,A,n_choices], outputs=output)

In [None]:
# Synthetic Metaparameters
n_trials = 100
n_choices = 2

# Creating synthetic data
# dtypes are chosen to match the symbolic inputs (imatrix, dmatrix, dvector)
Cs = np.zeros((n_trials,n_choices),dtype=np.int32)
Cs[:,0] = np.random.randint(2,size=(n_trials))
Cs[:,1] = 1 - Cs[:,0]
Rs = np.ones((n_trials, n_choices), dtype=np.float64)
A = np.full(n_choices, 0.1, dtype=np.float64)

# Testing compiled function
rl_func(Cs,Rs,A,n_choices)
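As with the series example, it is worth comparing the compiled function against a plain Python version. The sketch below is a hypothetical NumPy reference, `rl_reference`, which applies the same update rule as `rl_step` in an ordinary loop; it is not part of the original code.

```python
import numpy as np

def rl_reference(Cs, Rs, A, n_choices):
    """Plain NumPy loop with the same update rule as rl_step."""
    Qs = np.ones(n_choices) * 0.5  # same initial Q-values as rl_scan
    history = np.zeros(Cs.shape, dtype=np.float64)
    for t in range(Cs.shape[0]):
        # Only the chosen option (where Cs[t] == 1) is updated on each trial
        Qs = Qs + A * Cs[t] * (Rs[t] - Qs)
        history[t] = Qs
    return history

# Tiny hand-checkable example: three trials, two choices, reward always 1
Cs_small = np.array([[1, 0], [0, 1], [1, 0]], dtype=np.int32)
Rs_small = np.ones((3, 2), dtype=np.float64)
A_small = np.full(2, 0.1, dtype=np.float64)

reference = rl_reference(Cs_small, Rs_small, A_small, 2)
print(reference)
```

With the compiled function available, a check along the lines of `np.allclose(rl_func(Cs, Rs, A, n_choices), rl_reference(Cs, Rs, A, n_choices))` should then hold for the synthetic data above.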

####2bii. Using scan over multiple participants