# Computational modeling : RL algorithms in a virtual environment

Question: with very little evidence available, how do humans sample a complex action space? Can we infer some form of structure in this exploration? Can Active Inference provide some answers regarding the mechanistic processes behind it?



First, we grab the data corresponding to the experiment we're interested in (here, experiment 002). We also remove the subjects that either had technical issues or produced very suspicious results. *(We should provide a clear subject exclusion rule here, perhaps based on action variance across all dimensions or on reaction times.)*
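As a starting point for such an exclusion rule, here is a minimal sketch that flags a subject when their actions barely vary on every dimension or when their median reaction time is implausibly short. The function name, the input layout (one array per subject) and both thresholds are hypothetical placeholders, not values taken from the actual pipeline.

```python
import numpy as np

def flag_suspicious_subjects(actions, reaction_times,
                             min_action_std=0.02, min_median_rt=0.3):
    """Return the indices of subjects flagged by either rule.

    actions        : list of (n_actions, n_dims) arrays, one per subject
    reaction_times : list of 1D arrays (seconds), one per subject
    Both thresholds are illustrative assumptions.
    """
    flagged = []
    for idx, (acts, rts) in enumerate(zip(actions, reaction_times)):
        acts = np.asarray(acts, dtype=float)
        rts = np.asarray(rts, dtype=float)
        # Rule 1: almost no variability on any action dimension
        low_variance = bool(np.all(np.std(acts, axis=0) < min_action_std))
        # Rule 2: responses given faster than a plausible deliberate action
        too_fast = bool(np.median(rts) < min_median_rt)
        if low_variance or too_fast:
            flagged.append(idx)
    return flagged

# Synthetic demonstration: one ordinary subject, one "stuck and fast" subject
rng = np.random.default_rng(0)
normal_actions = rng.uniform(0.0, 1.0, size=(50, 4))
normal_rts = rng.uniform(0.8, 2.0, size=50)
stuck_actions = np.full((50, 4), 0.5)
stuck_rts = rng.uniform(0.05, 0.15, size=50)

flagged = flag_suspicious_subjects([normal_actions, stuck_actions],
                                   [normal_rts, stuck_rts])
print(flagged)  # -> [1]
```

Whatever rule is chosen, fixing its thresholds before looking at model fits would keep the exclusion step independent of the results.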

In [1]:
# Import the needed packages 
# 
# 1/ the usual suspects
import sys,os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import jax
import jax.numpy as jnp
from jax import vmap

from functools import partial

# 2/ The Active Inference package 
import actynf

# 3/ Tools for : 
# a. Getting the raw data : 
from database_handling.database_extract import get_all_subject_data_from_internal_task_id
from utils import remove_by_indices
# b. Preprocessing the data :
from analysis_tools.preprocess import OPTIONS_PREPROCESS_DEFAULT,get_preprocessed_data




# Subjects to exclude from predictor-based analyses:
problematic_subjects_misc = ["5c9cb670b472d0001295f377"]
        # This subject read the instructions in one submission and ran
        # the actual task in another, so statistics computed across both cannot
        # be compared. Exclude this subject from any statistical model based on
        # instructional data; they can be kept for raw performance plots.
# problematic_subjects_fraudulent =["6595ae358923ce48b037a0dc"]
        # This subject has very suspicious responses, including always putting both points in the same place
        # and acting as quickly as possible. To be removed from all analyses?


# Import the data from the remote mongodb database & the imported prolific demographics :
INTERNAL_TASK_ID = "002"
TASK_RESULTS_ALL = get_all_subject_data_from_internal_task_id(INTERNAL_TASK_ID,None,
                                        autosave=True,override_save=False,autoload=True)
print("Loaded the task results for " + str(len(TASK_RESULTS_ALL)) + " subjects.")

# Each entry in TASK_RESULTS_ALL is a 4-tuple: 
# (subject features, events, trial data, reaction times / feedback)
remove_these_subjects = []
for index,entry in enumerate(TASK_RESULTS_ALL):
    subj_dict,_,_,_ = entry
    subj_name = subj_dict["subject_id"]
    if subj_name in problematic_subjects_misc:
        remove_these_subjects.append(index)

TASK_RESULTS = remove_by_indices(TASK_RESULTS_ALL,remove_these_subjects)
print(str(len(TASK_RESULTS)) + " subjects remaining after removing problematic subjects.")


LABELS = [entry[0] for entry in TASK_RESULTS]
EVENTS = [entry[1] for entry in TASK_RESULTS]
TRIAL_DATA = [entry[2] for entry in TASK_RESULTS]
RT_FBS = [entry[3] for entry in TASK_RESULTS]

Loaded the task results for 90 subjects.
89 subjects remaining after removing problematic subjects.


Once the "raw data" is loaded, we can use the preprocessing pipeline described briefly [here](./computational_modeling_101_preprocessing.ipynb) to generate a dictionary with the observations and actions ready for use in our models. 

We can generate several dictionaries depending on how we want the data to be formatted. This is driven by the options dictionary:

In [2]:
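preprocessing_options = {
    "actions":{
        # np.sqrt keeps the upper edge in float64; in float32 (the JAX default) the 1e-10 margin would be lost
        "distance_bins" : np.array([0.0,0.05,0.2,0.5,np.sqrt(2) + 1e-10]),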
        "angle_N_bins"  : 8,
        "position_N_bins_per_dim" : 3
    },
    "observations":{
        "N_bins" : 10,
        "observation_ends_at_point" : 2
    }
}
    # We can modify these at will

data = get_preprocessed_data(TRIAL_DATA,RT_FBS,preprocessing_options,
                            verbose=True,
                            autosave=True,autoload=True,override_save=True,
                            label="default")
print(data.keys())

Out of the 9790.0 actions performed by our subjects, 8233.0 were 'valid' (84.1 %)




Out of the 9790.0 feedback sequences potentially observed by our subjects, 8803 were 'valid' (89.9 %)
dict_keys(['observations', 'actions'])
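To make the `distance_bins` option more concrete, the sketch below shows how a set of bin edges like the one above maps continuous distances to categorical indices. It assumes the preprocessing uses `np.digitize`-like semantics (left-closed, right-open bins), which is a guess about `get_preprocessed_data` rather than a description of it. The last edge sits just above $\sqrt{2}$, the diagonal of the unit square, presumably so that the largest possible distance still lands in the final bin.

```python
import numpy as np

# Same edges as in the options dictionary above
distance_bins = np.array([0.0, 0.05, 0.2, 0.5, np.sqrt(2) + 1e-10])

# A few example distances between two points of the unit square
distances = np.array([0.0, 0.03, 0.1, 0.3, 1.0, np.sqrt(2)])

# np.digitize returns i such that bins[i-1] <= x < bins[i];
# subtracting 1 gives a 0-based bin index
distance_idx = np.digitize(distances, distance_bins) - 1
print(distance_idx)  # -> [0 0 1 2 3 3]
```

With these edges, the first bin (distances below 0.05) only captures near-identical points, while the last one covers everything from 0.5 up to the diagonal.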


In [11]:
# Now to the models ! 
# Let us grab a model of the training environment itself : 

# ENVIRONMENTAL CONSTANTS :
GRID_SIZE = (7,7)
START_COORD = [[5,1],[5,2],[4,1]]
END_COORD = [0,6]





# We define the environment as a state machine that outputs a feedback 
# every time an action is given to it : 
from actynf.jaxtynf.layer_process import initial_state_and_obs,process_update
from actynf.jaxtynf.shape_tools import vectorize_weights

class TrainingEnvironment :
    def __init__(self,rng_key,a,b,c,d,e,u,T):
        # Environment parameters
        self.a = a
        self.b = b
        self.c = c 
        self.d = d
        self.e = e
        self.u = u
        
        # Timing
        self.Ntimesteps = T
        self.t = 0
        self.rng_key = rng_key
        
        # Inner state
        self.current_state = None
        
        self.update_vectorized_weights()
    
    def update_vectorized_weights(self):
        # Use the instance attribute rather than a global `u`
        self.vec_a,self.vec_b,self.vec_d = vectorize_weights(self.a,self.b,self.d,self.u)
    
    def reinit_trial(self):
        self.t = 0
        
        self.rng_key,init_tmstp_key = jax.random.split(self.rng_key)
        [s_d,s_idx,s_vect],[o_d,o_idx,o_vect] = initial_state_and_obs(init_tmstp_key,self.vec_a,self.vec_d)
        
        self.current_state = s_vect
        
        return o_vect,True        
    
    def step(self,action_chosen):
        
        if self.t == self.Ntimesteps:
            print("New trial ! The action has not been used here.")
            return self.reinit_trial()
        
        self.rng_key,timestep_rngkey = jax.random.split(self.rng_key)
        [s_d,s_idx,s_vect],[o_d,o_idx,o_vect] = process_update(timestep_rngkey,self.current_state,self.vec_a,self.vec_b,action_chosen)
        
        # Store the new hidden state so that the next call starts from it
        self.current_state = s_vect
        
        self.t = self.t + 1   
         
        return o_vect,False
    




In [14]:

# The environment is statically defined by its HMM matrices : 
from models import behavioural_process

T = 10
(a,b,c,d,e,u),fb_vals = behavioural_process(GRID_SIZE,START_COORD,END_COORD,preprocessing_options["observations"]["N_bins"],0.15)

rngkey = jax.random.PRNGKey(0)
env = TrainingEnvironment(rngkey,a,b,c,d,e,u,T)
print(env)

for k in range(100):
    o,new_trial = env.reinit_trial()
    print(o)



<__main__.TrainingEnvironment object at 0x00000213F2039B40>
[Array([0., 1., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)]
[Array([0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], dtype=float32)]
[Array([0., 1., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)]
[Array([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.], dtype=float32)]
[Array([1., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)]
[Array([0., 0., 1., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)]
[Array([0., 0., 1., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)]
[Array([1., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)]
[Array([1., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)]
[Array([0., 0., 1., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)]
[Array([0., 0., 0., 0., 1., 0., 0., 0., 0., 0.], dtype=float32)]
[Array([0., 1., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)]
[Array([0., 0., 1., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)]
[Array([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.], dtype=float32)]
[Array([0., 0., 1., 0., 0., 0. ... (output truncated)

We are now interested in how several kinds of more complex RL agents might perform in this environment (their performance may depend on hyperparameters).
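Before comparing agents, it helps to have a baseline interaction loop. The sketch below is a stand-in only: `ToyGridEnvironment` mimics the `reinit_trial` / `step` interface of `TrainingEnvironment` but does not use `actynf`, and both its action set (`MOVES`) and its feedback rule (feedback level grows as the agent gets closer to `END_COORD`) are assumptions made for illustration.

```python
import numpy as np

GRID_SIZE = (7, 7)
START_COORD = [[5, 1], [5, 2], [4, 1]]
END_COORD = [0, 6]
N_FEEDBACK_BINS = 10

# Hypothetical action set: stay, up, down, left, right
MOVES = np.array([[0, 0], [-1, 0], [1, 0], [0, -1], [0, 1]])

class ToyGridEnvironment:
    """Stand-in exposing the same reinit_trial / step interface."""

    def __init__(self, rng, T):
        self.rng = rng
        self.Ntimesteps = T
        self.t = 0
        self.position = None

    def _feedback(self):
        # Assumed rule: feedback level increases as the distance to the goal shrinks
        dist = np.linalg.norm(self.position - np.array(END_COORD))
        max_dist = np.linalg.norm(np.array(GRID_SIZE) - 1)
        level = 1.0 - dist / max_dist
        return int(round(level * (N_FEEDBACK_BINS - 1)))

    def reinit_trial(self):
        self.t = 0
        start = START_COORD[self.rng.integers(len(START_COORD))]
        self.position = np.array(start)
        return self._feedback(), True

    def step(self, action):
        if self.t == self.Ntimesteps:
            return self.reinit_trial()
        upper = np.array(GRID_SIZE) - 1
        self.position = np.clip(self.position + MOVES[action], 0, upper)
        self.t += 1
        return self._feedback(), False

def run_random_agent(n_trials, T, seed=0):
    """Uniformly random policy: a floor that any learning agent should beat."""
    rng = np.random.default_rng(seed)
    env = ToyGridEnvironment(rng, T)
    final_feedback = []
    for _ in range(n_trials):
        fb, _ = env.reinit_trial()
        for _ in range(T):
            fb, _ = env.step(rng.integers(len(MOVES)))
        final_feedback.append(fb)
    return np.array(final_feedback)

final_feedback = run_random_agent(n_trials=50, T=10)
print("Mean final feedback level:", final_feedback.mean())
```

Any candidate agent can be dropped into the same loop by replacing the random action choice, which makes the comparison across algorithms straightforward.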

Here is a repository with a lot of candidate algorithms :
https://github.com/udacity/deep-reinforcement-learning/blob/master/cross-entropy/CEM.ipynb

Below is an overview, generated with GPT-4, of reinforcement learning algorithms that do not rely on neural networks. Because our system is quite simple, we don't necessarily need amortization components and may focus on more traditional (and easier to fit) methods!

### 1. **Tabular Methods**
   These methods explicitly store and update values for each state or state-action pair in a table. They are suitable for environments with a small number of states and actions.

   - **Q-learning (Tabular)**: 
     - A widely used off-policy method where a Q-value table is maintained, and values are updated based on the Bellman equation.
     - No neural network is involved; the Q-values are stored explicitly in a table.
   
   - **SARSA (Tabular)**: 
     - Similar to Q-learning but it is an on-policy method, updating the action-value function based on the action chosen by the current policy.

   - **Monte Carlo Methods**: 
     - These methods estimate the value of a policy by averaging returns following episodes of interaction with the environment.
     - For every state (or state-action pair), returns are recorded and averaged to compute values.

   - **Dynamic Programming (DP)**: 
     - Methods like **policy iteration** and **value iteration** use a model of the environment's dynamics to compute optimal policies.
     - These rely on a table to store value functions or policies.
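The difference between the two temporal-difference methods above comes down to the bootstrap term of the update. The sketch below trains both on a hypothetical 5-state chain (not the task environment), with illustrative hyperparameters: Q-learning bootstraps from the best next action, SARSA from the next action actually selected by the epsilon-greedy policy.

```python
import numpy as np

N_STATES, N_ACTIONS = 5, 2      # toy chain; actions: 0 = left, 1 = right
GOAL = N_STATES - 1

def chain_step(state, action):
    """Deterministic chain: reward 1 when the right-most state is reached."""
    next_state = min(state + 1, GOAL) if action == 1 else max(state - 1, 0)
    done = next_state == GOAL
    return next_state, (1.0 if done else 0.0), done

def epsilon_greedy(q_row, rng, epsilon):
    if rng.random() < epsilon:
        return int(rng.integers(N_ACTIONS))
    best = np.flatnonzero(q_row == q_row.max())   # break ties at random
    return int(rng.choice(best))

def train(method, n_episodes=500, alpha=0.1, gamma=0.9, epsilon=0.1, seed=0):
    rng = np.random.default_rng(seed)
    Q = np.zeros((N_STATES, N_ACTIONS))
    for _ in range(n_episodes):
        state, done = 0, False
        action = epsilon_greedy(Q[state], rng, epsilon)
        while not done:
            next_state, reward, done = chain_step(state, action)
            next_action = epsilon_greedy(Q[next_state], rng, epsilon)
            if method == "q_learning":
                bootstrap = Q[next_state].max()          # off-policy target
            else:
                bootstrap = Q[next_state, next_action]   # on-policy (SARSA) target
            target = reward + (0.0 if done else gamma * bootstrap)
            Q[state, action] += alpha * (target - Q[state, action])
            state, action = next_state, next_action
    return Q

Q_qlearning = train("q_learning")
Q_sarsa = train("sarsa")
print(np.round(Q_qlearning, 2))
print(np.round(Q_sarsa, 2))
```

In this deterministic chain, the Q-learning values for moving right should approach the discounted distance to the goal, while SARSA values also reflect the cost of the occasional exploratory step.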

### 2. **Function Approximation (without Neural Networks)**
   When state spaces are too large for tabular methods, simpler function approximation techniques are used to represent value functions or policies.

   - **Linear Function Approximation**: 
     - The value of a state (or state-action pair) is approximated as a linear combination of features.
     - For example, a value function can be represented as $V(s) = w^\top \phi(s)$, where $w$ is a vector of weights and $\phi(s)$ is a feature vector of the state.

   - **Tile Coding**: 
     - A method for discretizing a continuous state space into tiles or grids, where values are maintained for each tile.
     - It is often used with linear approximations to represent values for continuous environments.

   - **Polynomial or Radial Basis Function (RBF) Approximations**: 
     - These methods represent the value function using predefined basis functions, such as polynomials or radial basis functions.
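As an illustration of the linear form $V(s) = w^\top \phi(s)$ with radial basis features, the sketch below runs semi-gradient TD(0) on a hypothetical one-dimensional task: the state drifts to the right along $[0, 1]$ and a reward of 1 is given when it crosses 1. The task, the RBF centres and width, and the learning rate are all illustrative assumptions.

```python
import numpy as np

CENTERS = np.linspace(0.0, 1.0, 6)   # RBF centres (arbitrary choice)
WIDTH = 0.2                          # RBF width (arbitrary choice)

def rbf_features(state):
    """Normalised Gaussian bumps: phi(s) sums to 1."""
    phi = np.exp(-0.5 * ((state - CENTERS) / WIDTH) ** 2)
    return phi / phi.sum()

def value(weights, state):
    return float(weights @ rbf_features(state))

def td0_linear(n_episodes=300, alpha=0.1, gamma=0.9, seed=0):
    rng = np.random.default_rng(seed)
    w = np.zeros(len(CENTERS))
    for _ in range(n_episodes):
        state = rng.uniform(0.0, 0.2)
        done = False
        while not done:
            next_state = state + rng.uniform(0.05, 0.15)  # noisy drift to the right
            done = bool(next_state >= 1.0)
            reward = 1.0 if done else 0.0
            target = reward + (0.0 if done else gamma * value(w, next_state))
            td_error = target - value(w, state)
            # Semi-gradient step: the gradient of V(s) w.r.t. w is phi(s)
            w += alpha * td_error * rbf_features(state)
            state = next_state
    return w

w = td0_linear()
for s in (0.1, 0.5, 0.9):
    print(f"V({s}) ~ {value(w, s):.2f}")
```

Six weights are enough here to represent a value function over a continuous state, which is the main appeal of this family when a table would be too large.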

### 3. **Policy Gradient Methods without Neural Networks**
   Policy gradient methods can be implemented without neural networks by using simpler parameterizations of the policy.

   - **Linear Policies**: 
     - The policy is represented as a linear combination of features, and the parameters of the linear policy are optimized using policy gradient techniques.

   - **Softmax Policies with Tabular Representation**: 
     - For discrete action spaces, a policy can be represented as a probability distribution (e.g., softmax) over actions for each state, and these probabilities are updated directly.
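A tabular softmax policy can be trained with a REINFORCE-style update, sketched below on a hypothetical one-step task with two states and three actions. The reward probabilities, the learning rate and the running-average baseline are illustrative choices, not part of the original experiment.

```python
import numpy as np

# Hypothetical reward probability for each (state, action) pair
REWARD_PROB = np.array([[0.1, 0.9, 0.2],
                        [0.9, 0.1, 0.2]])

def softmax(x):
    z = np.exp(x - x.max())   # subtract the max for numerical stability
    return z / z.sum()

def reinforce(n_steps=5000, alpha=0.1, baseline_rate=0.05, seed=0):
    rng = np.random.default_rng(seed)
    n_states, n_actions = REWARD_PROB.shape
    theta = np.zeros((n_states, n_actions))   # action preferences per state
    baseline = np.zeros(n_states)             # running average reward per state
    for _ in range(n_steps):
        s = int(rng.integers(n_states))
        pi = softmax(theta[s])
        a = int(rng.choice(n_actions, p=pi))
        r = float(rng.random() < REWARD_PROB[s, a])
        # Gradient of log pi(a|s) w.r.t. theta[s]: one_hot(a) - pi
        grad_log_pi = -pi
        grad_log_pi[a] += 1.0
        theta[s] += alpha * (r - baseline[s]) * grad_log_pi
        baseline[s] += baseline_rate * (r - baseline[s])
    return theta

theta = reinforce()
policy = np.array([softmax(row) for row in theta])
print(np.round(policy, 2))
```

The baseline does not change the expected gradient; it only reduces its variance, which matters when, as in our setting, very few samples are available.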

### 4. **Model-Based Methods**
   These methods involve learning or using a model of the environment's dynamics to help plan and optimize policies.

   - **Value Iteration (with Model)**: 
     - Given a model of the environment (a transition function and reward function), value iteration updates the value of each state using Bellman's equation.
   
   - **Policy Iteration (with Model)**: 
     - Alternates between policy evaluation and policy improvement to converge on the optimal policy using the model of the environment.
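When the model is known, value iteration reduces to repeated Bellman backups over the state table. The sketch below applies it to a deterministic 7x7 grid with a goal at `END_COORD`, under assumptions that are not taken from the task itself: four movement actions, a reward of 1 for entering the goal, and a discount factor of 0.9.

```python
import numpy as np

GRID_SIZE = (7, 7)
END_COORD = (0, 6)
MOVES = [(-1, 0), (1, 0), (0, -1), (0, 1)]   # hypothetical action set

def value_iteration(gamma=0.9, tol=1e-8):
    V = np.zeros(GRID_SIZE)
    while True:
        V_new = np.zeros(GRID_SIZE)
        for r in range(GRID_SIZE[0]):
            for c in range(GRID_SIZE[1]):
                if (r, c) == END_COORD:
                    continue   # terminal state keeps a value of 0
                q_values = []
                for dr, dc in MOVES:
                    # Moves that would leave the grid keep the agent in place
                    nr = min(max(r + dr, 0), GRID_SIZE[0] - 1)
                    nc = min(max(c + dc, 0), GRID_SIZE[1] - 1)
                    reward = 1.0 if (nr, nc) == END_COORD else 0.0
                    q_values.append(reward + gamma * V[nr, nc])
                V_new[r, c] = max(q_values)   # Bellman optimality backup
        if np.abs(V_new - V).max() < tol:
            return V_new
        V = V_new

V = value_iteration()
print(np.round(V, 3))
```

Under these assumptions the value of a cell is $\gamma^{d-1}$, with $d$ its Manhattan distance to the goal, so the optimal policy simply follows the value gradient. Policy iteration would reach the same values by alternating evaluation and greedy improvement.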