# Computational modeling : likelihood for various RL agents

Question : under very low amounts of evidence, how do human sample a complex action space ? Can we infer some form of structure in this exploration ? Can Active Inference provide some answers regarding the mechanistic processes behind it ?

In notebook 103, we derived a few proposal models to explain the behaviour of our subjects.  However, this is not enough as we aim at performing model inversion based on task data ! This means that we're going to need **likelihood functions** for each of these models !

Likelihood function describe the probability of these models generating the observed actions, given their hyperparameters $\theta$ and their previous experiences $o_{1:T,1:t},s_{1:T,1:t}$ : 
$$
\prod_T \prod_{t\in T} P(u_t|o_{1:T,1:t},u_{1:T,1:t-1},\theta)
$$

In this notebook, we modify the previous models to compute their likelihood in a jax environment.

In [10]:
# Import the needed packages 
# 
# 1/ the usual suspects
import sys,os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import jax
import jax.numpy as jnp
import jax.random as jr
from jax import vmap


from functools import partial

# 2/ The Active Inference package 
import actynf
from actynf.jaxtynf.jax_toolbox import _normalize,_jaxlog
from actynf.jaxtynf.layer_trial import compute_step_posteriors
from actynf.jaxtynf.layer_learn import learn_after_trial
from actynf.jaxtynf.layer_options import get_learning_options,get_planning_options
from actynf.jaxtynf.shape_tools import to_log_space,get_vectorized_novelty

from actynf.jaxtynf.layer_process import initial_state_and_obs,process_update
from actynf.jaxtynf.shape_tools import vectorize_weights


# 3/ Tools for : 
# a. Getting the raw data : 
from database_handling.database_extract import get_all_subject_data_from_internal_task_id
from utils import remove_by_indices
# b. Preprocessing the data :
from analysis_tools.preprocess import OPTIONS_PREPROCESS_DEFAULT,get_preprocessed_data



# The environment is statically defined by its HMM matrices : 
from hmm_weights import behavioural_process
# Weights for the active inference model : 
from hmm_weights import basic_latent_model






First, we grab the data corresponding to the experiment we're interested in (here, experiment 002). We also remove the subjects that either had technical issues or had very suspicious results. *(we should provide a clear rule on subject exclusion here, maybe based on action variance across all dimensions or reaction times ?).*

In [2]:

# Except subjects for predictors :
problematic_subjects_misc = ["5c9cb670b472d0001295f377"]
        # This subject has read the instructions with one submission and ran
        # the actual task with another, rendering statistics computed impossible to 
        # compare, this should be substracted from any statistical models based on
        # instructional data, but can be kept for raw performance plots.
# problematic_subjects_fraudulent =["6595ae358923ce48b037a0dc"]
        # This subject has very suspicious responses, including always putting both points in the same place
        # and acting as quickly as possible, to be removed from all analysis ?


# Import the data from the remote mongodb database & the imported prolific demographics :
INTERNAL_TASK_ID = "002"
TASK_RESULTS_ALL = get_all_subject_data_from_internal_task_id(INTERNAL_TASK_ID,None,
                                        autosave=True,override_save=False,autoload=True)
print("Loaded the task results for " + str(len(TASK_RESULTS_ALL)) + " subjects.")

# Each subject in task results has the following entries : 
# TASK_RESULT_FEATURES, TASK_RESULTS_EVENTS, TASK_RESULTS_DATA, TASK_RESULTS,RT_FB
remove_these_subjects = []
for index,entry in enumerate(TASK_RESULTS_ALL):
    subj_dict,_,_,_ = entry
    subj_name = subj_dict["subject_id"]
    if subj_name in problematic_subjects_misc:
        remove_these_subjects.append(index)

TASK_RESULTS = remove_by_indices(TASK_RESULTS_ALL,remove_these_subjects)
print(str(len(TASK_RESULTS)) + " subjects remaining after removing problematic subjects.")


LABELS = [entry[0] for entry in TASK_RESULTS]
EVENTS = [entry[1] for entry in TASK_RESULTS]
TRIAL_DATA = [entry[2] for entry in TASK_RESULTS]
RT_FBS = [entry[3] for entry in TASK_RESULTS]


Loaded the task results for 90 subjects.
89 subjects remaining after removing problematic subjects.


Once the "raw data" is loaded, we can use the preprocessing pipeline described briefly [here](./computational_modeling_101_preprocessing.ipynb) to generate a dictionnary with the observations and actions ready for use in our models. 

We can generate several dictionnaries depending on how we want the data to be formatted. This is driven by the option dictionnary :

In [3]:
preprocessing_options = {
    "actions":{
        "distance_bins" : np.array([0.0,0.05,0.2,0.5,jnp.sqrt(2) + 1e-10]),
        "angle_N_bins"  : 8,
        "position_N_bins_per_dim" : 3
    },
    "observations":{
        "N_bins" : 10,
        "observation_ends_at_point" : 2
    }
}

# We can modify these at will depending on the hypothesis we want to test
data = get_preprocessed_data(TRIAL_DATA,RT_FBS,preprocessing_options,
                            verbose=True,
                            autosave=True,autoload=True,override_save=True,
                            label="default")
print(data.keys())

Out of the 9790.0 actions performed by our subjects, 8233.0 were 'valid' (84.1 %)


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Out of the 9790.0 feedback sequences potentially observed by our subjects, 8803 were 'valid' (89.9 %)
dict_keys(['observations', 'actions'])


In [11]:
def random_agent(hyperparameters,constants):
    # a,b,c = hyperparameters
    num_actions, = constants
    
    
    # ____________________________________________________________________________________________
    # Each agent is a set of functions of the form :    
    def initial_params():
        return None # A function of the hyperparameters
    
    def initial_state(params):
        # Initial agent state (beginning of each trial)
        return None

    def actor_step(observation,state,params,rng_key):
        gauge_level,reward,trial_over,t = observation
        
        # OPTIONAL : Update states based on previous states, observations and parameters
        new_state = state
        
        # Compute action distribution using observation, states and parameters
        action_distribution,_ = _normalize(jnp.ones((num_actions,)))
        action_selected = jr.categorical(rng_key,_jaxlog(action_distribution))
        vect_action_selected = jax.nn.one_hot(action_selected,action_distribution.shape[0])
        
        return new_state,(action_distribution,action_selected,vect_action_selected)

    def update_params(trial_history,params):
        rewards,observations,states,actions = trial_history
        
        # Trial history is a list of trial rewards, observations and states, we may want to make them jnp arrays :
        # reward_array = jnp.stack(rewards)
        # observation_array = jnp.stack(observations)
        # states_array = jnp.stack(states)
        # action_array = jnp.stack(actions)
        
        # OPTIONAL :  Update parameters based on states, observations and actions history
        return None
    
    def predict(data_timestep,state,params):
        """Predict the next action given a set of observations,
        as well as the previous internal states and parameters of the agent.

        Args:
            observation (_type_): _description_
            state (_type_): _description_
            params (_type_): _description_
            true_action : the actual action that was performed (for state updating purposes !)

        Returns:
            new_state : the 
            predicted_action : $P(u_t|o_t,s_{t-1},\theta)$
        """
        gauge_level,obs_bool_filter,reward,true_action,t = data_timestep
        
        # OPTIONAL : Update states based on previous states, observations and parameters
        new_state = state
        
        # Compute action distribution using observation, states and parameters
        predicted_action,_ = _normalize(jnp.ones((num_actions,)))
        
        # Here are the data we may want to report during the training : 
        other_data = None
        
        return new_state,predicted_action,other_data
    # ____________________________________________________________________________________________
    
    return initial_params,initial_state,actor_step,update_params,predict


def compute_predicted_actions(data,agent_functions):
    """A function that uses vmap to compute the predicted agent action at time $t$ given $o_{1:t}$ and $u_{1:t-1}$. 
    This function should be differentiable w.r.t. the hyperparameters of the agent's model because we're going to perform
    gradient descent on it !

    Args:
        environment (_type_): _description_
        agent_functions (_type_): _description_
        seed (_type_): _description_
        Ntrials (_type_): _description_

    Returns:
        _type_: _description_
    """
    init_params,init_state,_,agent_learn,predict = agent_functions
    
    
    # Data should contain :
    # - all observations -> stimuli,reward (from the system)
    #       -> a list of stimuli for each modality
    #       -> a list of observation filters for each modality
    #       -> a Ntrials x Ntimesteps tensor array of scalar rewards (\in [0,1])
    # - all true actions 
    #       -> a Ntrials x (Ntimesteps-1) x Nu tensor array encoding the observed actions
    #       -> a Ntrials x (Ntimesteps-1) filter tensor indicating which actions were NOT observed
    
    initial_parameters = init_params()  
        # The initial parameters of the tested model are initialized once per training
    
    
    def _scan_trial(_carry,_data_trial):
        
        _agent_params = _carry
        _initial_state = init_state(_agent_params)
        
        _observations_trial,_observations_filter_trial,_rewards_trial,_actions_trial,_timestamps_trial = _data_trial
        
        # The same actions, with an extra one at the end for scan to work better !
        _expanded_actions_trial = jnp.concatenate([_actions_trial,jnp.zeros((1,_actions_trial.shape[-1]))])
        _expanded_data_trial = (_observations_trial,_observations_filter_trial,_rewards_trial,_expanded_actions_trial,_timestamps_trial)
        
        def __scan_timestep(__carry,__data_timestep):
            # __obs_vect,__obs_bool,__reward,__true_action_vect,__t = __data_timestep
            __agent_state = __carry
                    
            __new_state,__predicted_action,__other_data = predict(__data_timestep,__agent_state,_agent_params)        
            
            return __new_state,(__predicted_action,__new_state,__other_data)
        
        
        
        _,(_predicted_actions,_trial_states,_trial_other_data) = jax.lax.scan(__scan_timestep, (_initial_state),_expanded_data_trial)
          
        
        _new_params = agent_learn((_rewards_trial,_observations_trial,_trial_states,_actions_trial),_agent_params)
        
        return _new_params,(_predicted_actions[:-1,...],(_trial_states,_trial_other_data))

    final_parameters,(predicted_actions,(model_states,other_data)) = jax.lax.scan(_scan_trial,initial_parameters,data)

    return final_parameters,predicted_actions,(model_states,other_data)
            
from simulate.navigate_virtual_environment import TrainingEnvironment,run_loop

# The virtual environment  ENVIRONMENTAL CONSTANTS :
T = 11
N_FEEDBACK_OUTCOMES = 10
TRUE_FEEDBACK_STD = 0.15
GRID_SIZE = (7,7)
START_COORD = [[5,1],[5,2],[4,1]]
END_COORD = [0,6]
(a,b,c,d,e,u),fb_vals = behavioural_process(GRID_SIZE,START_COORD,END_COORD,N_FEEDBACK_OUTCOMES,TRUE_FEEDBACK_STD)
rngkey = jax.random.PRNGKey(np.random.randint(0,10))
ENVIRONMENT = TrainingEnvironment(rngkey,a,b,c,d,e,u,T)


# In : an agent based on some hyperparameters : 
SEED = 100
NTRIALS = 10
random_agent_hyperparameters = None
random_agent_constants = (9,)

# Synthetic data (here, generated randomly) :
params_final,training_hist = run_loop(ENVIRONMENT,random_agent(random_agent_hyperparameters,random_agent_constants),SEED,NTRIALS)

Trial 0
Trial 1
Trial 2
Trial 3
Trial 4
Trial 5
Trial 6
Trial 7
Trial 8
Trial 9


Let's generate some synthetic data :

In [12]:
# raw_stimuli = training_hist["stimuli"]
# # Stimuli to correct format : 
# formatted_stimuli = []
# for modality in range(len(raw_stimuli[0][0])) :
#     formatted_stimuli.append(jnp.array([[__o[modality] for __o in _o_t] for _o_t in raw_stimuli]))
# bool_stimuli = [jnp.ones_like(stim[...,0]) for stim in formatted_stimuli]
# rewards = jnp.array(training_hist["rewards"])
# tmtsp = jnp.array(training_hist["timestamps"])
# actions = jnp.array(training_hist["actions"])
# print(actions.shape)

# Parameter update (once every trial)
def _swaplist(_list):
    """ Put the various factors / modalities as the leading dimension for a 2D list of lists."""
    if _list is None :
        return None
    
    for el in _list :
        if (type(el) != list) and (type(el) != tuple):
            # There is a single factor here ! 
            return _list
    
    _swapped_list = []
    for factor in range(len(_list[0])):
        _swapped_list.append([_el[factor] for _el in _list])
    return _swapped_list
        
        
formatted_stimuli= [jnp.array(o) for o in _swaplist(training_hist["stimuli"])]
bool_stimuli = [jnp.ones_like(stim[...,0]) for stim in formatted_stimuli]
rewards = jnp.array(training_hist["rewards"])
actions = jnp.array(training_hist["actions"])
tmtsp = jnp.array(training_hist["timestamps"])
synthetic_data = (formatted_stimuli,bool_stimuli,rewards,actions,tmtsp)

In [13]:
# And compute the likelihood of each action given the random model : 
final_parameters,predicted_actions,model_states = compute_predicted_actions(synthetic_data,random_agent(random_agent_hyperparameters,random_agent_constants))
print(predicted_actions)
# print(model_states)

# Here's the average log-likelihood of what was observed given this model :
avg_ll = jnp.mean((actions * _jaxlog(predicted_actions)).sum(axis=-1))
print(avg_ll)


[[[0.11111111 0.11111111 0.11111111 0.11111111 0.11111111 0.11111111
   0.11111111 0.11111111 0.11111111]
  [0.11111111 0.11111111 0.11111111 0.11111111 0.11111111 0.11111111
   0.11111111 0.11111111 0.11111111]
  [0.11111111 0.11111111 0.11111111 0.11111111 0.11111111 0.11111111
   0.11111111 0.11111111 0.11111111]
  [0.11111111 0.11111111 0.11111111 0.11111111 0.11111111 0.11111111
   0.11111111 0.11111111 0.11111111]
  [0.11111111 0.11111111 0.11111111 0.11111111 0.11111111 0.11111111
   0.11111111 0.11111111 0.11111111]
  [0.11111111 0.11111111 0.11111111 0.11111111 0.11111111 0.11111111
   0.11111111 0.11111111 0.11111111]
  [0.11111111 0.11111111 0.11111111 0.11111111 0.11111111 0.11111111
   0.11111111 0.11111111 0.11111111]
  [0.11111111 0.11111111 0.11111111 0.11111111 0.11111111 0.11111111
   0.11111111 0.11111111 0.11111111]
  [0.11111111 0.11111111 0.11111111 0.11111111 0.11111111 0.11111111
   0.11111111 0.11111111 0.11111111]
  [0.11111111 0.11111111 0.11111111 0.11111111

In [14]:

def choice_kernel_agent(hyperparameters,constants):
    alpha,beta = hyperparameters
    num_actions, = constants

    # ____________________________________________________________________________________________
    # Each agent is a set of functions of the form :    
    def initial_params():
        # Parameters is the initial choice kernel :
        CK_initial = jnp.zeros((num_actions,))
        
        return CK_initial # A function of the hyperparameters
    
    def initial_state(params):
        # Initial agent state (beginning of each trial)
        # The initial state is the CK table and an initial action (easier integration with rw+ck model)
        return params,jnp.zeros((num_actions,))

    def actor_step(observation,state,params,rng_key):
        gauge_level,reward,trial_over,t = observation
        
        # The state of the agent stores the choice kernel and the last action performed : 
        ck,last_action = state
        
        # Update the choice kernel :
        was_a_last_action = jnp.sum(last_action)  # No update if there was no last action
        new_ck = ck + alpha*(last_action - ck)*was_a_last_action
        
        action_distribution = jax.nn.softmax(beta*ck)
        action_selected = jr.categorical(rng_key,_jaxlog(action_distribution))
        vect_action_selected = jax.nn.one_hot(action_selected,action_distribution.shape[0])        
        
        return (new_ck,vect_action_selected),(action_distribution,action_selected,vect_action_selected)

    def update_params(trial_history,params):
        rewards,observations,states,actions = trial_history
        
        # The params for the next step is the last choice kernel of the trial :
        # (the update already occured during the actor step !)
        cks,previous_actions = states
        
        ck_last = cks[-1]
        
        return ck_last
    
    def predict(data_timestep,state,params):
        """Predict the next action given a set of observations,
        as well as the previous internal states and parameters of the agent.

        Args:
            observation (_type_): _description_
            state (_type_): _description_
            params (_type_): _description_
            true_action : the actual action that was performed (for state updating purposes !)

        Returns:
            new_state : the 
            predicted_action : $P(u_t|o_t,s_{t-1},\theta)$
        """
        gauge_level,obs_bool_filter,reward,true_action,t = data_timestep        
        
        # The state of the agent stores the choice kernel and the last action performed : 
        ck,last_action = state
        
        # Update the choice kernel :
        was_a_last_action = jnp.sum(last_action)  # No update if there was no last action
        new_ck = ck + alpha*(last_action - ck)*was_a_last_action
        
        predicted_action = jax.nn.softmax(beta*ck) 
        
        # Here are the data we may want to report during the training : 
        other_data = None
                
        return (new_ck,true_action),predicted_action,other_data
    # ____________________________________________________________________________________________
    
    return initial_params,initial_state,actor_step,update_params,predict


ck_agent_hyperparameters = (0.5,1.0)   # [0,1] x [0, +oo]
ck_agent_constants = (9,)              # Nactions
final_parameters,predicted_actions,state_history = compute_predicted_actions(synthetic_data,choice_kernel_agent(ck_agent_hyperparameters,ck_agent_constants))
# print(predicted_actions)


# We can have an idea of what happened during training by looking at the inner states of the model
inner_states,_ = state_history
# Of course, these will vary from one model to the next :
ck_table,previous_action = inner_states
print(ck_table.shape)
print(ck_table[-1,-1,:])




(10, 11, 9)
[9.7656250e-03 1.0768622e-03 4.0283222e-03 2.5001526e-01 2.4414156e-04
 4.8834115e-04 5.0000000e-01 5.7220791e-06 2.3437572e-01]


In [17]:
def rescorla_wagner_agent(hyperparameters,constants):
    alpha,beta = hyperparameters
    num_actions, = constants

    # ____________________________________________________________________________________________
    # Each agent is a set of functions of the form :    
    def initial_params():
        # Parameters is the initial perceived reward :
        q_initial = jnp.zeros((num_actions,))
        
        return q_initial # A function of the hyperparameters
    
    def initial_state(params):
        # Initial agent state (beginning of each trial)
        
        # The initial state is the q_table, as well as an initial action selected (None)
        return params,jnp.zeros((num_actions,))

    def actor_step(observation,state,params,rng_key):
        gauge_level,reward,trial_over,t = observation
        
        q_t,previous_action = state
        
        # Update the table now that we have the new reward !
        q_tplus = q_t + alpha*(reward-q_t)*previous_action
        
        action_distribution = jax.nn.softmax(beta*q_tplus)
        action_selected = jr.categorical(rng_key,_jaxlog(action_distribution))
        vect_action_selected = jax.nn.one_hot(action_selected,action_distribution.shape[0])       
        
        return (q_tplus,vect_action_selected),(action_distribution,action_selected,vect_action_selected)

    def update_params(trial_history,params):
        rewards,observations,states,actions = trial_history
                
        # The params for the next step is the last choice kernel of the trial :
        # (the update already occured during the actor step !)
        qts,previous_actions = states
        
        q_t_last = qts[-1]
        
        
        # The params for the next step is the last choice kernel of the trial :
        # (the update already occured during the actor step !)
        return q_t_last
    
    
    def predict(data_timestep,state,params):
        """Predict the next action given a set of observations,
        as well as the previous internal states and parameters of the agent.

        Args:
            observation (_type_): _description_
            state (_type_): _description_
            params (_type_): _description_
            true_action : the actual action that was performed (for state updating purposes !)

        Returns:
            new_state : the 
            predicted_action : $P(u_t|o_t,s_{t-1},\theta)$
        """
        gauge_level,obs_bool_filter,reward,true_action,t = data_timestep        
        
        q_t,previous_action = state
        
        # Update the table now that we have the new reward !
        q_tplus = q_t + alpha*(reward-q_t)*previous_action
        
        predicted_action = jax.nn.softmax(beta*q_tplus)
        
        # Here are the data we may want to report during the training : 
        other_data = None
                
        return (q_tplus,true_action),predicted_action,other_data
    # ____________________________________________________________________________________________
    
    return initial_params,initial_state,actor_step,update_params,predict

rw_agent_hyperparameters = (0.5,1.0)
rw_agent_constants = (9,)
final_parameters,predicted_actions,state_history = compute_predicted_actions(synthetic_data,rescorla_wagner_agent(rw_agent_hyperparameters,rw_agent_constants))
print(predicted_actions)


# We can have an idea of what happened during training by looking at the inner states of the model
inner_states,_ = state_history
# Of course, these will vary from one model to the next :
q_table,previous_action = inner_states
print(q_table.shape)
print(q_table[-1,-1,:])

[[[0.11111111 0.11111111 0.11111111 0.11111111 0.11111111 0.11111111
   0.11111111 0.11111111 0.11111111]
  [0.11041028 0.11041028 0.11041028 0.11041028 0.11041028 0.11041028
   0.11041028 0.11671777 0.11041028]
  [0.11041028 0.11041028 0.11041028 0.11041028 0.11041028 0.11041028
   0.11041028 0.11671777 0.11041028]
  [0.10971824 0.10971824 0.1159862  0.10971824 0.10971824 0.10971824
   0.10971824 0.1159862  0.10971824]
  [0.11037266 0.11037266 0.11667801 0.11037266 0.11037266 0.11037266
   0.10440806 0.11667801 0.11037266]
  [0.11001111 0.11001111 0.11957152 0.11001111 0.11001111 0.11001111
   0.10406605 0.11629579 0.11001111]
  [0.10860699 0.10860699 0.11804537 0.10860699 0.10860699 0.10860699
   0.10273781 0.11481146 0.12137038]
  [0.10924818 0.10924818 0.11874229 0.10924818 0.10334436 0.10924818
   0.10334436 0.11548929 0.12208693]
  [0.10857059 0.11477298 0.11800582 0.10857059 0.10270339 0.10857059
   0.10270339 0.11477298 0.12132971]
  [0.10946952 0.11572326 0.11898286 0.10946952

In [18]:
def rw_ck_agent(hyperparameters,constants):
    alpha,beta,alpha_ck,beta_ck = hyperparameters
    num_actions, = constants

    # ____________________________________________________________________________________________
    # Each agent is a set of functions of the form :    
    def initial_params():
        # Parameters are the initial perceived reward :
        q_initial = jnp.zeros((num_actions,))
        # and the initial choice kernel :
        ck_initial = jnp.zeros((num_actions,))
        
        return q_initial,ck_initial 
    
    def initial_state(params):
        # Initial agent state (beginning of each trial)
        q,ck = params
        # The initial state is the q_table, as well as an initial action selected (None)
        return q,ck,jnp.zeros((num_actions,))

    def actor_step(observation,state,params,rng_key):
        gauge_level,reward,trial_over,t = observation
        
        q_t,ck,previous_action = state
        
        # Update the table now that we have the new reward !
        q_tplus = q_t + alpha*(reward-q_t)*previous_action
        
        # Update the choice kernel :
        was_a_last_action = jnp.sum(previous_action)  # No update if there was no last action
        new_ck = ck + alpha*(previous_action - ck)*was_a_last_action
        
        
        action_distribution = jax.nn.softmax(beta*q_tplus + beta_ck*ck)
        action_selected = jr.categorical(rng_key,_jaxlog(action_distribution))
        vect_action_selected = jax.nn.one_hot(action_selected,action_distribution.shape[0]) 
        
        return (q_tplus,new_ck,vect_action_selected),(action_distribution,action_selected,vect_action_selected)

    def update_params(trial_history,params):
        rewards,observations,states,actions = trial_history
        
        qts,cks,previous_actions = states
        
        
        q_t_last,ck_last = qts[-1],cks[-1]

        return q_t_last,ck_last
    # ____________________________________________________________________________________________
    
    def predict(data_timestep,state,params):
        """Predict the next action given a set of observations,
        as well as the previous internal states and parameters of the agent.

        Args:
            observation (_type_): _description_
            state (_type_): _description_
            params (_type_): _description_
            true_action : the actual action that was performed (for state updating purposes !)

        Returns:
            new_state : the 
            predicted_action : $P(u_t|o_t,s_{t-1},\theta)$
        """
        gauge_level,obs_bool_filter,reward,true_action,t = data_timestep        
        
        q_t,ck,previous_action = state
        
        # Update the table now that we have the new reward !
        q_tplus = q_t + alpha*(reward-q_t)*previous_action
        
        # Update the choice kernel :
        was_a_last_action = jnp.sum(previous_action)  # No update if there was no last action
        new_ck = ck + alpha*(previous_action - ck)*was_a_last_action
                
        predicted_action = jax.nn.softmax(beta*q_tplus + beta_ck*ck)
        
        # Here are the data we may want to report during the training : 
        other_data = None
                
        return (q_tplus,new_ck,true_action),predicted_action,other_data
    # ____________________________________________________________________________________________
    
    return initial_params,initial_state,actor_step,update_params,predict


rw_ck_agent_hyperparameters = (1.0,1.0,0.5,0.1)
rw_ck_agent_constants = (9,)
final_parameters,predicted_actions,state_history = compute_predicted_actions(synthetic_data,rw_ck_agent(rw_ck_agent_hyperparameters,rw_ck_agent_constants))

# We can have an idea of what happened during training by looking at the inner states of the model
inner_states,_ = state_history
# Of course, these will vary from one model to the next :
q_table,ck_table,previous_action = inner_states
print(q_table.shape)
print(q_table[-1,-1,:])
print(ck_table[-1,-1,:])


(10, 11, 9)
[ 0.          0.11111112  0.22222224  0.33333334 -0.22222225  0.33333334
 -0.33333337  0.         -0.22222224]
[0. 0. 0. 0. 0. 0. 1. 0. 0.]


In [19]:


def q_learning_agent(hyperparameters,constants):
    alpha_plus,alpha_minus,beta,alpha_ck,beta_ck = hyperparameters
    num_actions,num_states = constants

    # ____________________________________________________________________________________________
    # Each agent is a set of functions of the form :    
    def initial_params():
        # Parameters are the initial q-table. As opposed to a RW agent, the mappings now depend on the states 
        # This usually allows for better responsiveness to the environment, but in this situation it may make the training
        # harder !
        q_initial = jnp.zeros((num_actions,num_states))
        # and the initial choice kernel :
        ck_initial = jnp.zeros((num_actions,))
        
        return q_initial,ck_initial 
    
    def initial_state(params):
        # Initial agent state (beginning of each trial)
        q,ck = params
        # The initial state is the q_table, as well as an initial action selected (None) and the last gauge level (None)
        return q,ck,jnp.zeros((num_actions,)),[jnp.zeros((num_states,))]

    def actor_step(observation,state,params,rng_key):
        current_stimuli,reward,trial_over,t = observation
        current_gauge_level = current_stimuli[0]
        
        q_t,ck,previous_action,previous_stimuli = state
        previous_gauge_level = previous_stimuli[0]
        
        # Update the table now that we have the new reward !
        # This is "where" the reward was observed in the table :
        previous_action_state = jnp.einsum("i,j->ij",previous_action,previous_gauge_level)
        
        positive_reward = jnp.clip(reward,min=0.0)
        negative_reward = jnp.clip(reward,max=0.0)
        
        positive_reward_prediction_error = positive_reward - q_t
        negative_reward_prediction_error = negative_reward - q_t
        
        q_tplus = q_t + (alpha_plus*positive_reward_prediction_error + alpha_minus*negative_reward_prediction_error)*previous_action_state
        
        # Update the choice kernel :
        was_a_last_action = jnp.sum(previous_action)  # No update if there was no last action
        new_ck = ck + alpha_ck*(previous_action - ck)*was_a_last_action
        


        # Action selection :
        q_table_at_this_state = jnp.einsum("ij,j->i",q_tplus,current_gauge_level)
        
        action_distribution = jax.nn.softmax(beta*q_table_at_this_state + beta_ck*new_ck)
        action_selected = jr.categorical(rng_key,_jaxlog(action_distribution))
        vect_action_selected = jax.nn.one_hot(action_selected,action_distribution.shape[0])  
        
        return (q_tplus,new_ck,vect_action_selected,current_stimuli),(action_distribution,action_selected,vect_action_selected)

    def update_params(trial_history,params):
        rewards,observations,states,actions = trial_history
        
        qts,cks,previous_actions,previous_stimuli = states
        
        q_t_last,ck_last = qts[-1],cks[-1]
        
        # The params for the next step is the last choice kernel of the trial :
        # (the update already occured during the actor step !)
        return q_t_last,ck_last
    
    def predict(data_timestep,state,params):
        """Predict the next action given a set of observations,
        as well as the previous internal states and parameters of the agent.

        Args:
            observation (_type_): _description_
            state (_type_): _description_
            params (_type_): _description_
            true_action : the actual action that was performed (for state updating purposes !)

        Returns:
            new_state : the 
            predicted_action : $P(u_t|o_t,s_{t-1},\theta)$
        """
        current_stimuli,obs_bool_filter,reward,true_action,t = data_timestep      
        current_gauge_level = current_stimuli[0]  
        
        q_t,ck,previous_action,previous_stimuli = state
        previous_gauge_level = previous_stimuli[0]
        
        # Update the table now that we have the new reward !
        # This is "where" the reward was observed in the table : 
        previous_action_state = jnp.einsum("i,j->ij",previous_action,previous_gauge_level)
        
        positive_reward = jnp.clip(reward,min=0.0)
        negative_reward = jnp.clip(reward,max=0.0)
        
        positive_reward_prediction_error = positive_reward - q_t
        negative_reward_prediction_error = negative_reward - q_t
        
        q_tplus = q_t + (alpha_plus*positive_reward_prediction_error + alpha_minus*negative_reward_prediction_error)*previous_action_state
        
        # Update the choice kernel :
        was_a_last_action = jnp.sum(previous_action)  # No update if there was no last action
        new_ck = ck + alpha_ck*(previous_action - ck)*was_a_last_action
        

        # Action selection :
        q_table_at_this_state = jnp.einsum("ij,j->i",q_tplus,current_gauge_level)
        
        predicted_action = jax.nn.softmax(beta*q_table_at_this_state + beta_ck*new_ck)
        
        # Here are the data we may want to report during the training : 
        other_data = None
        
        return (q_tplus,new_ck,true_action,current_stimuli),predicted_action,other_data
            # ____________________________________________________________________________________________
    
    return initial_params,initial_state,actor_step,update_params,predict


ql_ck_agent_hyperparameters = (0.5,0.7,1.0,0.0,0.0)
ql_ck_agent_constants = (9,N_FEEDBACK_OUTCOMES)
final_parameters,predicted_actions,state_history = compute_predicted_actions(synthetic_data,q_learning_agent(ql_ck_agent_hyperparameters,ql_ck_agent_constants))

# We can have an idea of what happened during training by looking at the inner states of the model
inner_states,_ = state_history
# Of course, these will vary from one model to the next :
q_table,ck_table,previous_action,previous_stim = inner_states
print(q_table.shape)
print(q_table[-1,-1,:])
print(ck_table[-1,-1,:])

(10, 11, 9, 10)
[[ 0.          0.11111112 -0.08444444 -0.07777778  0.          0.
   0.         -0.15555555  0.          0.        ]
 [ 0.05555556  0.          0.07777777 -0.06035556 -0.12942222 -0.01422223
   0.          0.          0.          0.        ]
 [ 0.          0.04444444  0.02622222  0.11111112  0.03822223  0.
   0.          0.          0.05555555  0.        ]
 [ 0.22222222  0.14444445  0.          0.05555555  0.22222222  0.
   0.          0.          0.          0.        ]
 [ 0.11111111  0.05555556  0.         -0.17777781 -0.08337777  0.1111111
   0.          0.          0.          0.        ]
 [ 0.          0.15555556  0.05555556 -0.08888888 -0.09333333  0.
   0.          0.          0.          0.        ]
 [ 0.16666667 -0.07777778  0.02862221  0.         -0.23333333 -0.3888889
   0.          0.          0.          0.        ]
 [ 0.1         0.          0.05555556  0.03111111  0.07111111 -0.15555556
  -0.15555556  0.          0.         -0.15555555]
 [ 0.          0.1

In [21]:

def active_inference_basic_1D(hyperparameters,constants):
    a0,b0,c0,d0,e0,u = basic_latent_model({**constants, **hyperparameters})
    beta = hyperparameters["action_selection_temperature"]
    
    planning_options = get_planning_options(constants["Th"],"classic",a_novel=False,b_novel=False)
    learning_options = get_learning_options(learn_b=True,lr_b=hyperparameters["transition_learning_rate"],method="vanilla+backwards",
                                    state_generalize_function=lambda x : jnp.exp(-hyperparameters["state_interpolation_temperature"]),
                                    action_generalize_table=None,
                                    cross_action_extrapolation_coeff=None)
    # ____________________________________________________________________________________________
    # Each agent is a set of functions of the form :    
    def initial_params():
        # The initial parameters of the AIF agent are its model weights :
        return a0,b0,c0,d0,e0,u
    
    def initial_state(params):
        pa,pb,pc,pd,pe,u = params

        # The "states" of the active Inference agent are : 
        # 1. The vectorized parameters for this trial :
        trial_a,trial_b,trial_d = vectorize_weights(pa,pb,pd,u)
        trial_c,trial_e = to_log_space(pc,pe)
        trial_a_nov,trial_b_nov = get_vectorized_novelty(pa,pb,u,compute_a_novelty=True,compute_b_novelty=True)
        
        # 2. Its priors about the next state : (given by the d matrix parameter)
        prior = trial_d
        
        return prior,jnp.zeros_like(prior),(trial_a,trial_b,trial_c,trial_e,trial_a_nov,trial_b_nov) # We don't need trial_d anymore !

    def actor_step(observation,state,params,rng_key):
        emission,reward,trial_over,t = observation
        gauge_level = emission[0]
                
        state_prior,previous_posterior,timestep_weights = state
        a_norm,b_norm,c,e,a_novel,b_novel = timestep_weights
        
        end_of_trial_filter = jnp.ones((planning_options["horizon"]+2,))
        qs,F,raw_qpi,efe = compute_step_posteriors(t,state_prior,emission,a_norm,b_norm,c,e,a_novel,b_novel,
                                    end_of_trial_filter,
                                    rng_key,planning_options)       

        # Action selection :        
        action_distribution = jax.nn.softmax(beta*efe)
        action_selected = jr.categorical(rng_key,_jaxlog(action_distribution))
        vect_action_selected = jax.nn.one_hot(action_selected,action_distribution.shape[0])  
        
        # New state prior : 
        new_prior = jnp.einsum("iju,j,u->i",b_norm,qs,vect_action_selected)
        
        # OPTIONAL : ONLINE UPDATING OF PARAMETERS 
        
        return (new_prior,timestep_weights),(action_distribution,action_selected,vect_action_selected)


    def update_params(trial_history,params):
        pa,pb,pc,pd,pe,u = params
        rewards,observations,states,actions = trial_history
        
        priors_history,posteriors_history,_ = states   

        obs_vect_arr = [jnp.array(observations[0])]
        qs_arr = jnp.stack(posteriors_history)
        u_vect_arr = jnp.stack(actions)    
        
        # Then, we update the parameters of our HMM model at this level
        # We use the raw weights here !
        a_post,b_post,c_post,d_post,e_post,qs_post = learn_after_trial(obs_vect_arr,qs_arr,u_vect_arr,
                                                pa,pb,c,pd,e,u,
                                                method = learning_options["method"],
                                                learn_what = learning_options["bool"],
                                                learn_rates = learning_options["rates"],
                                                generalize_state_function=learning_options["state_generalize_function"],
                                                generalize_action_table=learning_options["action_generalize_table"],
                                                cross_action_extrapolation_coeff=learning_options["cross_action_extrapolation_coeff"],
                                                em_iter = learning_options["em_iterations"])
        
        # The params for the next step is the last choice kernel of the trial :
        # (the update already occured during the actor step !)
        return a_post,b_post,c_post,d_post,e_post,u
    # ____________________________________________________________________________________________
    
    def predict(data_timestep,state,params):
        """Predict the next action given a set of observations,
        as well as the previous internal states and parameters of the agent.

        Args:
            observation (_type_): _description_
            state (_type_): _description_
            params (_type_): _description_
            true_action : the actual action that was performed (for state updating purposes !)

        Returns:
            new_state : the 
            predicted_action : $P(u_t|o_t,s_{t-1},\theta)$
        """
        current_stimuli,obs_bool_filter,reward,true_action,t = data_timestep      
        current_gauge_level = current_stimuli[0]  
                
        state_prior,previous_posterior,timestep_weights = state
        a_norm,b_norm,c,e,a_novel,b_novel = timestep_weights
        
        end_of_trial_filter = jnp.ones((planning_options["horizon"]+2,))
        qs,F,raw_qpi,efe = compute_step_posteriors(t,state_prior,current_stimuli,a_norm,b_norm,c,e,a_novel,b_novel,
                                    end_of_trial_filter,
                                    None,planning_options)       

        # Action selection :        
        predicted_action = jax.nn.softmax(beta*efe)
        
        # New state prior : 
        new_prior = jnp.einsum("iju,j,u->i",b_norm,qs,true_action)
        
        # OPTIONAL : ONLINE UPDATING OF PARAMETERS 
                
        # Here are the data we may want to report during the training : 
        other_data = (qs,F)
        
        return (new_prior,qs,timestep_weights),predicted_action,other_data
        # ____________________________________________________________________________________________         
    return initial_params,initial_state,actor_step,update_params,predict







# We get a model weights by defining a "parameters" object :
aif_1d_constants = {
    # General environment : 
    "N_feedback_ticks":N_FEEDBACK_OUTCOMES,
    # Latent state space structure
    "Ns_latent":5,      # For 1D
    # Action discretization:
    "N_actions_distance" :3,
    "N_actions_position" :9,
    "N_actions_angle" :9,
    
    "Th" : 3
}

aif_1d_params = {    
    # ----------------------------------------------------------------------------------------------------
    # Model parameters : these should interact with the model components in a differentiable manner
    "transition_concentration": 1.0,
    "transition_stickiness": 1.0,
    "transition_learning_rate" : 1.0,
    "state_interpolation_temperature" : 1.0,
    
    "initial_state_concentration": 1.0,
    
    "feedback_expected_std" : 0.15,
    "emission_concentration" : 1.0,
    "emission_stickiness" : 100.0,
    
    "reward_seeking" : 10.0,
    
    "action_selection_temperature" : 1.0,
}
final_parameters,predicted_actions,state_history = compute_predicted_actions(synthetic_data,active_inference_basic_1D(aif_1d_params,aif_1d_constants))
# print(predicted_actions)


# We can have an idea of what happened during training by looking at the inner states of the model
inner_states,(infered_states,free_energies) = state_history
# Of course, these will vary from one model to the next :
priors,posteriors,weights = inner_states

# a_norm,b_norm,c,e,a_novel,b_novel = weights

print(free_energies)
print(free_energies.shape)

# print(q_table.shape)
# print(q_table[-1,-1,:])
# print(ck_table[-1,-1,:])

(10,)
(10, 5)
(11, 5)
[[[-2.1043572]
  [-1.9794528]
  [-1.9935575]
  [-2.222168 ]
  [-2.0510163]
  [-2.2056844]
  [-2.4584064]
  [-2.3192506]
  [-2.3524148]
  [-2.2984724]
  [-2.1926236]]

 [[-2.4443073]
  [-2.1478336]
  [-2.0198371]
  [-2.0043235]
  [-2.1626353]
  [-2.2722385]
  [-2.376525 ]
  [-2.301517 ]
  [-2.3021688]
  [-2.2983577]
  [-2.2135296]]

 [[-2.1938996]
  [-2.3829422]
  [-2.1498477]
  [-1.8449485]
  [-2.5881083]
  [-2.2958293]
  [-2.1254148]
  [-2.1712437]
  [-2.4299626]
  [-2.2780724]
  [-2.0255327]]

 [[-2.3424745]
  [-2.546113 ]
  [-2.2939634]
  [-2.2811902]
  [-2.1484866]
  [-1.8917875]
  [-2.2019103]
  [-2.4029791]
  [-2.2701359]
  [-2.2629182]
  [-2.348713 ]]

 [[-2.4802887]
  [-2.3544197]
  [-2.1298974]
  [-2.239411 ]
  [-2.3346245]
  [-2.2380896]
  [-2.1994786]
  [-2.3608015]
  [-2.2587957]
  [-2.3698757]
  [-2.3523414]]

 [[-2.4802887]
  [-2.2598128]
  [-2.2965775]
  [-2.2286606]
  [-2.29698  ]
  [-2.366916 ]
  [-2.2680147]
  [-2.347527 ]
  [-2.2112932]
  [-1.75

For the remainder of the study, we'll stock all of these predictive models in the [proposal models](./proposal_models.py) file to manipulate and change them more easily. Next up is inversion !