# Computational modeling : RL algorithms in a virtual environment

Question: under very low amounts of evidence, how do humans sample a complex action space? Can we infer some form of structure in this exploration? Can Active Inference provide some answers regarding the mechanistic processes behind it?



First, we grab the data corresponding to the experiment we're interested in (here, experiment 002). We also remove the subjects that either had technical issues or had very suspicious results. *(we should provide a clear rule on subject exclusion here, maybe based on action variance across all dimensions or reaction times ?).*

In [3]:
!pip list

Package                   Version
------------------------- -----------
absl-py                   2.1.0
active_pynference         0.1.8
arviz                     0.19.0
asttokens                 2.4.1
attrs                     24.2.0
beautifulsoup4            4.12.3
certifi                   2024.7.4
chardet                   3.0.4
charset-normalizer        3.3.2
chex                      0.1.86
cloudpickle               3.0.0
colorama                  0.4.6
comm                      0.2.2
contourpy                 1.2.1
corner                    2.2.2
cycler                    0.12.1
debugpy                   1.8.2
decorator                 5.1.1
deep-translator           1.11.4
dm-tree                   0.1.8
dnspython                 2.6.1
etils                     1.9.4
exceptiongroup            1.2.2
executing                 2.0.1
fastjsonschema            2.20.0
fastprogress              1.0.3
fonttools                 4.53.1
gast                      0.6.0
google               

In [4]:
import sys
# Print the interpreter's installation prefix, i.e. which Python environment this kernel is running from.
print(sys.prefix)
# Import the needed packages 
# 
# 1/ the usual suspects
import numpy as np
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp

from jax.tree_util import tree_map
import jax.random as jr
import tensorflow_probability.substrates.jax.distributions as tfd 

from functools import partial

# 2/ Useful functions from our package :
from actynf.jaxtynf.jax_toolbox import _normalize,_jaxlog
from actynf.jaxtynf.jax_toolbox import random_split_like_tree

# To make nice plots : 
from simulate.plot_trajectory import plot_training

# To import model files with poorly written names  ¯\_(ツ)_/¯
import importlib  

# to make pretty tables : 
from tabulate import tabulate

# The simulated environment :
from simulate.hmm_weights import behavioural_process # The environment is statically defined by its HMM matrices
from simulate.generate_observations_full_actions import TrainingEnvironment,run_loop,generate_synthetic_data

# The methods to predict actions, compute the log-likelihoods and fit the models :
from simulate.compute_likelihood_full_actions import compute_predicted_actions,compute_loglikelihood
from simulate.compute_likelihood_full_actions import fit_mle_agent,fit_map_agent
from simulate.invert_model import invert_data_for_library_of_models


c:\Users\annic\OneDrive\Bureau\MainPhD\code\behavioural_exp_code\exploit_results_env


  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# ---------------------------------------------------------------------------
# Static configuration of the simulated task and of the agents' dimensions.
# ---------------------------------------------------------------------------

# Task layout
NTRIALS = 10
T = 11  # also handed to TrainingEnvironment below

# Agent side : number of latent states used by the models
N_LATENT_STATES = 5

# Environment side
N_FEEDBACK_OUTCOMES = 6
TRUE_FEEDBACK_STD = 0.175  # an earlier value (0.025) was left in a trailing comment
GRID_SIZE = (7, 7)
START_COORD = [[5, 1], [5, 2], [4, 1]]
END_COORD = [0, 6]

# HMM weights describing the environment, plus the feedback values.
(a, b, c, d, e, u), fb_vals = behavioural_process(
    GRID_SIZE,
    START_COORD,
    END_COORD,
    N_FEEDBACK_OUTCOMES,
    TRUE_FEEDBACK_STD,
)

# NOTE(review): the seed is drawn from NumPy's global (unseeded) RNG among the
# integers 0..9, so the environment key differs from run to run — confirm that
# this is intended rather than a fixed seed.
rngkey = jax.random.PRNGKey(np.random.randint(0, 10))
ENVIRONMENT = TrainingEnvironment(rngkey, a, b, c, d, e, u, T)

# Static dimensions of the problem, one sub-dictionary per action dimension.
No = N_FEEDBACK_OUTCOMES
Ns = N_LATENT_STATES
_N_ACTIONS_PER_DIMENSION = {"position": 9, "angle": 9, "distance": 4}
MODEL_CONSTANTS = {
    _dimension: {"N_actions": _n_actions, "N_outcomes": No, "N_states": Ns}
    for _dimension, _n_actions in _N_ACTIONS_PER_DIMENSION.items()
}

Behaviour of one static-bias agent with a uniform bias:


Let's initialize a set of parameter ranges and priors for a plethora of models !

In [25]:

# ---------------------------------------------------------------------------------------------------------------------
# Model library : for every candidate model, the agent constructor, the parameter "ranges", the MAP priors and a
# list of descriptive tags.
#
# Every entry is built from the tables below instead of hand-written nested dictionaries, so that a bound or a prior
# is declared exactly once. The resulting dictionaries have the same keys (in the same order), the same array values
# and the same tags as the former literal definitions.
# ---------------------------------------------------------------------------------------------------------------------

# Predefine the priors for MAP approaches
zero_one_uni = tfd.Uniform(low=-1e-5,high=1.0+1e-5)  # Bounds may be a bit finnicky and cause loglikelihood overflows
zero_big_uni = tfd.Uniform(low=-1e-5,high=1e3+1e-5)

# One prior object per beta_* parameter, so that each of them can be changed independently.
beta_biais_prior = tfd.Normal(10,5)
beta_Q_prior = tfd.Normal(10,5)
beta_omega_prior = tfd.Normal(10,5)

# Action dimensions, in the order in which they appear in every "ranges" / "priors" dictionary.
_ACTION_DIMENSIONS = ("angle","position","distance")

# First two elements of the "ranges" array of each parameter.
_PARAMETER_BOUNDS = {
    "biais" : (-10,10),
    "initial_q" : (-10,10),
    "initial_omega" : (-10,10),
    "beta_biais" : (-3,3),
    "alpha_Q" : (-10,10),
    "alpha_Q+" : (-10,10),
    "alpha_Q-" : (-10,10),
    "beta_Q" : (-3,3),
    "alpha_omega" : (-10,10),
    "beta_omega" : (-3,3),
    "transition_alpha" : (-10,10),
    "perception_sigma" : (-3,3),
    "gamma_generalize" : (-3,3),
}

# Prior attached to each parameter.
# "beta_biais" now always uses beta_biais_prior : the independent models used to be wired to beta_Q_prior, which is
# numerically identical today (both are Normal(10,5)) but would have ignored any later change to beta_biais_prior.
# NOTE(review): the former literals carried the comment "No priors on the individual biaises !" next to
# "biais" / "initial_q", although a Uniform prior is assigned to them — confirm the intent.
_PARAMETER_PRIORS = {
    "biais" : zero_one_uni,
    "initial_q" : zero_one_uni,
    "initial_omega" : zero_one_uni,
    "beta_biais" : beta_biais_prior,
    "alpha_Q" : zero_one_uni,
    "alpha_Q+" : zero_one_uni,
    "alpha_Q-" : zero_one_uni,
    "beta_Q" : beta_Q_prior,
    "alpha_omega" : zero_one_uni,
    "beta_omega" : beta_omega_prior,
    "transition_alpha" : zero_one_uni,
    "perception_sigma" : zero_big_uni,
    "gamma_generalize" : zero_big_uni,
}

# Parameters whose "ranges" array carries a third element equal to the number of actions of the dimension
# (9, 9 and 4 in the former literals, i.e. MODEL_CONSTANTS[dimension]["N_actions"]).
# Presumably this third element is the length of the parameter vector — confirm against get_random_parameter_set.
_PER_ACTION_PARAMETERS = ("biais","initial_q")

# Third element of the "initial_omega" range.
# Presumably one value per action dimension — TODO confirm.
_N_INITIAL_OMEGA = 3

# Parameter groups shared by several models (order matters : it is the key order of the dictionaries).
_Q_LEARNING = ("alpha_Q","beta_Q","transition_alpha","perception_sigma","gamma_generalize")
_ASYMMETRIC_Q_LEARNING = ("alpha_Q+","alpha_Q-","beta_Q","transition_alpha","perception_sigma","gamma_generalize")
_OMEGA_LEARNING = ("alpha_omega","beta_omega")


def _parameter_range(name, dimension=None):
    """Return the "ranges" array of parameter `name`.

    `dimension` is only needed for the per-action parameters, whose third element is read from
    MODEL_CONSTANTS[dimension]["N_actions"].
    """
    low, high = _PARAMETER_BOUNDS[name]
    if name in _PER_ACTION_PARAMETERS:
        return jnp.array([low, high, MODEL_CONSTANTS[dimension]["N_actions"]])
    if name == "initial_omega":
        return jnp.array([low, high, _N_INITIAL_OMEGA])
    return jnp.array([low, high])


def _ranges_and_priors(shared, per_dimension):
    """Build the "ranges" and "priors" dictionaries of one model.

    `shared` lists the parameters stored at the top level of the dictionaries, `per_dimension` the parameters
    repeated under each of the "angle" / "position" / "distance" keys. Fresh dictionaries and arrays are created
    on every call, so no library entry shares a mutable container with another one.
    """
    ranges = {name : _parameter_range(name) for name in shared}
    priors = {name : _PARAMETER_PRIORS[name] for name in shared}
    for dimension in _ACTION_DIMENSIONS:
        ranges[dimension] = {name : _parameter_range(name, dimension) for name in per_dimension}
        priors[dimension] = {name : _PARAMETER_PRIORS[name] for name in per_dimension}
    return ranges, priors


def _register_model(name, module_name, shared, per_dimension, tags, **agent_kwargs):
    """Import `module_name`, add its `agent` to MODEL_LIBRARY under `name` and return the imported module.

    Raises ValueError when `name` is already registered : assigning to an existing dictionary key would otherwise
    silently replace the earlier model.
    """
    if name in MODEL_LIBRARY:
        raise ValueError("Model '{}' is already registered in MODEL_LIBRARY.".format(name))
    # importlib is needed because the module file names contain characters ("-", "+", "&") that a regular
    # import statement cannot express.
    module = importlib.import_module(module_name)
    ranges, priors = _ranges_and_priors(shared, per_dimension)
    MODEL_LIBRARY[name] = {
        "model" : partial(module.agent,constants=MODEL_CONSTANTS,**agent_kwargs),
        "ranges" : ranges,
        "priors" : priors,
        "tags" : list(tags),
    }
    return module


MODEL_LIBRARY = {}


# Random agent -------------------------------------------------------------------------------------------------------
# Random agent : selects actions randomly
from agents.agent_random import agent as random_agent
MODEL_LIBRARY["random"] ={
    "model" : partial(random_agent,constants=MODEL_CONSTANTS),
    "ranges" : {"angle":None,"position":None,"distance":None},
    "priors": {"angle":None,"position":None,"distance":None},
    "bypass" : True,  # Don't try to invert this one, the log likelihood is fixed !
    "tags": ["random"]
}


# latent states Q-Learning agents -------------------------------------------------------------------------------------
# Independent model weights : every parameter is repeated under each action dimension.
# (library key, module to import, per-dimension parameters, tags)
_INDEPENDENT_MODELS = [
    ("i_latQL+b", "agents.agent_i-latQL+b", ("biais","beta_biais") + _Q_LEARNING,
        ["latql","static_biais"]),
    ("i_latQL&b", "agents.agent_i-latQL&b", ("initial_q",) + _Q_LEARNING,
        ["latql","initial_biais"]),
    ("i_latQLa+b", "agents.agent_i-latQLa+b", ("biais","beta_biais") + _ASYMMETRIC_Q_LEARNING,
        ["latql","static_biais","assymetric"]),
    ("i_latQLa&b", "agents.agent_i-latQLa&b", ("initial_q",) + _ASYMMETRIC_Q_LEARNING,
        ["latql","initial_biais","assymetric"]),
]
for _key, _module_name, _per_dimension, _tags in _INDEPENDENT_MODELS:
    # model_module is kept as a module-level name, as in the former literal definitions.
    model_module = _register_model(_key, _module_name, (), _per_dimension, _tags)


# And the mixed models : learning parameters are shared across dimensions, only the bias-like vector is repeated
# under each action dimension. Each family is registered twice : with focused_learning=True (key suffix "+fl",
# extra tag "focused_learning") and then with focused_learning=False.
# (library key, module to import, shared parameters, per-dimension parameter, leading tags)
_MIXED_SHARED = _OMEGA_LEARNING + _Q_LEARNING
_MIXED_FAMILIES = [
    ("m_latQL+b_omegac", "agents.agent_m-latQL+b_omegac",
        _MIXED_SHARED + ("beta_biais",), "biais",
        ["latql","mixed","static_biais"]),
    ("m_latQL+b_omegac&b", "agents.agent_m-latQL+b_omegac&b",
        ("initial_omega",) + _MIXED_SHARED + ("beta_biais",), "biais",
        ["latql","mixed","static_biais","initial_omega_biais"]),
    ("m_latQL&b_omegac", "agents.agent_m-latQL&b_omegac",
        _MIXED_SHARED, "initial_q",
        ["latql","mixed","initial_biais"]),
    ("m_latQL&b_omegac&b", "agents.agent_m-latQL&b_omegac&b",
        ("initial_omega",) + _MIXED_SHARED, "initial_q",
        ["latql","mixed","initial_biais","initial_omega_biais"]),
]
for _key, _module_name, _shared, _bias_parameter, _tags in _MIXED_FAMILIES:
    for _focused_learning in (True, False):
        model_module = _register_model(
            _key + ("+fl" if _focused_learning else ""),
            _module_name,
            _shared,
            (_bias_parameter,),
            _tags
            + ["focused_action_selection"]
            + (["focused_learning"] if _focused_learning else [])
            + ["omegac"],
            focused_learning=_focused_learning,
        )


In [26]:
# Check for duplicates :
# NOTE(review): the keys of a dict are unique by construction, so this comparison always prints True ; a model
# registered twice under the same name would already have overwritten the first one.
namelist = list(MODEL_LIBRARY)
print(len(set(namelist)) == len(namelist))

from simulate.compute_likelihood_full_actions import get_random_parameter_set


# Count the parameters of each model on one sampled hyperparameter set :
#  - "scalar values" sums the .size of every leaf of the parameter tree,
#  - "variables" sums the first dimension (.shape[0]) of every leaf.
parameter_count_list = []
for model_name, model_contents in MODEL_LIBRARY.items():
    parameter_set_tree = get_random_parameter_set(model_contents["ranges"], jr.PRNGKey(0))

    # Flatten once ; None sub-trees (e.g. the random agent) contribute no leaf at all.
    leaves, _ = jax.tree.flatten(parameter_set_tree)
    vls = [leaf.size for leaf in leaves]
    nbr = [leaf.shape[0] for leaf in leaves]

    parameter_count_list.append([model_name, sum(vls), sum(nbr)])


print(
    tabulate(
        parameter_count_list,
        headers=['model name', '# of parameters \n(scalar values)', '# of parameters \n(variables)'],
        tablefmt='orgtbl',
    )
)

True
| model name            |   # of parameters  |   # of parameters  |
|                       |    (scalar values) |        (variables) |
|-----------------------+--------------------+--------------------|
| random                |                  0 |                  0 |
| i_latQL+b             |                 40 |                 21 |
| i_latQL&b             |                 37 |                 18 |
| i_latQLa+b            |                 43 |                 24 |
| i_latQLa&b            |                 40 |                 21 |
| m_latQL+b_omegac+fl   |                 30 |                 11 |
| m_latQL+b_omegac      |                 30 |                 11 |
| m_latQL+b_omegac&b+fl |                 33 |                 12 |
| m_latQL+b_omegac&b    |                 33 |                 12 |
| m_latQL&b_omegac+fl   |                 29 |                 10 |
| m_latQL&b_omegac      |                 29 |                 10 |
| m_latQL&b_omegac&b+fl |                 3