In [1]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

import plotly as pltly
import plotly.express as px
import plotly.graph_objects as go

from jax import vmap
import jax.numpy as jnp

# + local functions : 
from database_handling.database_extract import get_all_subject_data_from_internal_task_id
from utils import remove_by_indices
from analysis_tools.preprocess import get_preprocessed_data_from_df

from simulate.compute_likelihood import fit_map_agent,compute_loglikelihood

from agents import choice_kernel_agent,random_agent,rescorla_wagner_agent,rw_ck_agent,q_learning_agent,active_inference_basic_1D


from functools import partial
import tensorflow_probability.substrates.jax.distributions as tfd
import jax.random as jr
from jax import vmap

# Subjects excluded from the predictor analyses :
problematic_subjects_misc = [
    "611d60c383f4f70ff4bc99fd",  # S2 : Did the task twice
    "66a74bdfdcaccdc0703894d5",  # Subjects with consent revoked
    "667d92f2ea5c1542f417285d",
    "6548f570022275786186ffbd",
]

# Import the data from the remote mongodb database & the imported prolific demographics :
INTERNAL_TASK_ID = "003"
# Study 2 : one Prolific study per std value (0.025, 0.1 and 0.175)
# NOTE(review): presumably the std of the feedback noise - confirm.
PROLIFIC_STUDY_IDs = [
    "6703ab18d345eaa4893587e0",
    "66f9aee8210357265a5958fc",
    "6703ab1a7ea30557549dc6da",
]

# Gather the results of every study into one flat list.
TASK_RESULTS_ALL = []
for prolific_study_id in PROLIFIC_STUDY_IDs:
    task_results = get_all_subject_data_from_internal_task_id(
        INTERNAL_TASK_ID,
        prolific_study_id,
        process_feedback_data_stream=True,
        override_save=False,
    )
    print(" - Loaded the task results for study {} \n    ({} subjects.)".format(prolific_study_id,len(task_results)))
    TASK_RESULTS_ALL.extend(task_results)
print("Total : {} subjects".format(len(TASK_RESULTS_ALL)))

# Each entry of the task results is unpacked below as a 4-tuple whose first
# element is the subject dictionary (it carries the "subject_id" key).
remove_these_subjects = [
    index
    for index, (subj_dict, _, _, _) in enumerate(TASK_RESULTS_ALL)
    if subj_dict["subject_id"] in problematic_subjects_misc
]

TASK_RESULTS = remove_by_indices(TASK_RESULTS_ALL,remove_these_subjects)
print(str(len(TASK_RESULTS)) + " subjects remaining after removing problematic subjects.")

  from .autonotebook import tqdm as notebook_tqdm


 - Loaded the task results for study 6703ab18d345eaa4893587e0 
    (49 subjects.)
 - Loaded the task results for study 66f9aee8210357265a5958fc 
    (50 subjects.)
 - Loaded the task results for study 6703ab1a7ea30557549dc6da 
    (50 subjects.)
Total : 149 subjects
145 subjects remaining after removing problematic subjects.


In [2]:
# The initial dataframe is built from the first element of each tuple in our task result list :
subjects_df = pd.DataFrame([entry[0] for entry in TASK_RESULTS])

# Avoid too many categories :
subjects_df['Sex'] = np.where(subjects_df['Sex'].isin(['Male','Female']), subjects_df['Sex'], 'Other')

# Nationalities seen fewer than `threshold` times are pooled into 'Other'.
# The counts are looked up with `.map` instead of `category_counts[x]` :
# value_counts() drops missing values, so a subject with a missing nationality
# would otherwise raise a KeyError. A missing nationality maps to NaN, fails the
# comparison and lands in 'Other', like missing 'Sex' entries above.
category_counts = subjects_df['Nationality'].value_counts()
threshold = 2
nationality_is_frequent = subjects_df['Nationality'].map(category_counts) >= threshold
subjects_df['Nationality_red'] = subjects_df['Nationality'].where(nationality_is_frequent, 'Other')

# There was a single noise term for the whole training for each subject :
# (only the first element of "noise_int" is kept)
subject_noise_parameters = [np.array(entry[2]["parameters"]["noise_int"])[0] for entry in TASK_RESULTS]

# We add it to the df :
subjects_df["feedback_noise_std"] = subject_noise_parameters

# # Create a pandas dataframe from the list of subject dictionaries :

# In this dataframe, we're interested in storing various kinds of data from the trials :
# 1/ Data from the instruction phase

# 2/ Data from the feedback gauge :
# Timestep values :
all_subject_scores = [subjdata[2]["scoring"] for subjdata in TASK_RESULTS]
subjects_df["raw_feedback_values"] = [subj_scores["feedback"] for subj_scores in all_subject_scores]
# Real time gauge values :
subjects_df["realtime_values"] = [subjdata[3][1] for subjdata in TASK_RESULTS] # Each element is a list of lists of arrays (with varying shape)

# 3/ Data from the hidden grid :
# The grid for a specific trial:
trial_grids = [entry[2]["process"]["grids"] for entry in TASK_RESULTS]
subjects_df["grid_layout"] = trial_grids
# Position value :
subject_positions = [entry[2]["process"]["positions"] for entry in TASK_RESULTS]
subjects_df["subject_positions"] = subject_positions

# Goal position : only index 0 of the second axis of "goal_pos" is kept.
goal_positions = [np.array(entry[2]["parameters"]["goal_pos"])[:,0,:] for entry in TASK_RESULTS]
subjects_df["goal_position"] = goal_positions

def euclidian_distance(position,goal):
    """Return the 2-norm of the offset between `position` and `goal`."""
    offset = position - goal
    return jnp.linalg.norm(offset, ord=2)
# Normalisation constant : norm of (grid shape - 1), i.e. the distance between
# index (0,0) and the last index along each of the two grid axes.
gs = trial_grids[0][0].shape
maximum_euclidian_dist = euclidian_distance(jnp.array(gs) - jnp.ones((2,)),jnp.zeros((2,)))

# The innermost map runs over the leading axis of the positions while holding the
# goal fixed; the two outer maps run over the leading axes shared by both inputs.
_distance_along_positions = vmap(euclidian_distance,in_axes=(0,None))
_distance_batched = vmap(vmap(_distance_along_positions))
all_euclidian_distances = _distance_batched(jnp.array(subject_positions),jnp.array(goal_positions))/maximum_euclidian_dist
subjects_df["norm_distance_to_goal"] = list(all_euclidian_distances)


# 4/ Data from the realized actions :

# Actions performed : this encompasses the points dropped
# But may also include temporal elements such as :
# - the time taken to perform an action (first point / second point)
# - when the action was performed with regard to the gauge
canvas_size = TASK_RESULTS[0][0]["canvas_size"] # Constant across all subjects + conditions
all_actions_data = np.stack([subjdata[2]["blanket"]["actions"] for subjdata in TASK_RESULTS]).astype(float)

Nsubj,Ntrials,Nactions,Npoints,Nfeatures = all_actions_data.shape

# Normalize the point data : feature 0 is divided by the first canvas dimension,
# feature 1 is divided by the second one and flipped (1 - value).
canvas_dim_0, canvas_dim_1 = canvas_size[0], canvas_size[1]
all_actions_data[...,0] /= canvas_dim_0
all_actions_data[...,1] = 1.0 - all_actions_data[...,1]/canvas_dim_1

# Mask of the points that WERE recorded (last feature equal to 1) ...
mask = all_actions_data[...,-1]==1
# ... an action only counts when both of its points were recorded.
both_points_only = np.logical_and(mask[...,0], mask[...,1])

# NOTE : `Nactions` is re-bound here, from the per-trial action count unpacked
# above to the total number of actions over all subjects and trials.
Nactions = both_points_only.size
Nmissed_actions = np.logical_not(both_points_only).sum()
print("A total of {}/{} actions were missed. ({:.2f} %)".format(Nmissed_actions,Nactions,100*Nmissed_actions/Nactions))

subjects_df["raw_points"] = list(all_actions_data)


# Encoded barycenters :
# Midpoint of the two points of each action, computed per coordinate.
barycenter_x = (all_actions_data[...,0,0]+all_actions_data[...,1,0])/2.0
barycenter_y = (all_actions_data[...,0,1]+all_actions_data[...,1,1])/2.0
barycenters = np.stack([barycenter_x,barycenter_y],axis=-1)
subjects_df["action_barycenters"] = list(barycenters)

# Encoded euclidian distance between points :
# Only the first two features (the normalized coordinates) enter the norm.
action_distances = np.linalg.norm(all_actions_data[...,0,:2]-all_actions_data[...,1,:2],axis=-1)
subjects_df["action_distances"] = list(action_distances)

# Encoded evolution of point angles :
# Angle of the vector going from the first point to the second point.
# `np.arctan2` is used rather than `np.atan2` : the latter is an alias that only
# exists from NumPy 2.0 onwards, while `arctan2` is available in every version.
angles = np.arctan2(all_actions_data[...,1,1]-all_actions_data[...,0,1],all_actions_data[...,1,0]-all_actions_data[...,0,0])
subjects_df["action_angles"] = list(angles)

# Encoded delays between stimuli, point1 and point2 :
# Feature 2 of the last point of each action; delays that are not above 10 are
# replaced by NaN.
all_action_delays = all_actions_data[...,-1,2]
# NOTE(review): `unfit_actions` is not used below and uses a strict `<10`, whereas
# the filter keeps only delays `>10` (a delay of exactly 10 is dropped but not
# flagged) - confirm which threshold is intended.
unfit_actions = (all_action_delays<10)
subjects_df["action_time_between_points"] = np.where(all_action_delays>10, all_action_delays, np.nan).tolist()

# Performance metric : we use the average distance to goal state across the last k_T trials and the last k_t timesteps : (ignoring the blind trial)
# NOTE(review): the slice `-last_k_trials:-1` stops before the final (blind) trial,
# so it covers last_k_trials - 1 trials (2 with the current setting) - confirm that
# this is the intended window.
last_k_trials,last_t_timesteps = 3,5
norm_distances = np.stack(subjects_df["norm_distance_to_goal"])
training_window = norm_distances[:,-last_k_trials:-1,-last_t_timesteps:]
all_distances_to_goal = training_window.mean(axis=(-1,-2))
subjects_df["final_performance"] = (1.0 - all_distances_to_goal).tolist()


# And for the blind trial (the last one), over the same final timesteps :
blind_window = norm_distances[:,-1,-last_t_timesteps:]
blind_trial_distances_to_goal = blind_window.mean(axis=-1)
subjects_df["blind_trial_performance"] = (1.0 - blind_trial_distances_to_goal).tolist()

A total of 33/15950 actions were missed. (0.21 %)


In [10]:
# Preprocess the data according to the wanted hyperparameters :
n_bins_feedback = 10
n_bins_action_angle = 8
n_bins_action_position_per_dim = 3
# Settings forwarded to fit_map_agent (num_steps, n_iter, initial_window) :
Nsteps = 250
Nheads = 50
head_init_window = [-20,20]

rngkey = jr.PRNGKey(20)

# If the points were too close, no angle was recorded :
# The limit was arbitrarily chosen at 7.5 :
# NOTE(review): 750 is presumably the canvas side in pixels - confirm against canvas_size.
min_dist_norm = 7.5/(np.sqrt(2)*750)
preprocessing_options = {
    "actions":{
        # The last edge uses np.sqrt (float64) : with jnp.sqrt the value is float32 by
        # default, the 1e-10 margin is rounded away and the edge ends up slightly
        # BELOW sqrt(2), defeating the purpose of the margin.
        "distance_bins" : np.array([0.0,min_dist_norm,0.2,0.5,np.sqrt(2) + 1e-10]),
        "angle_N_bins"  : n_bins_action_angle,
        "position_N_bins_per_dim" : n_bins_action_position_per_dim
    },
    "observations":{
        "N_bins" : n_bins_feedback,
        "observation_ends_at_point" : 2
    }
}
data = get_preprocessed_data_from_df(subjects_df,
                            preprocessing_options,
                            verbose=True,
                            autosave=True,autoload=True,override_save=True,
                            label="default")



# Get the data for one subject :
subj_id = 10


# The data to invert :
formatted_stimuli= [data["observations"]["vect"][1][subj_id]]
bool_stimuli = [jnp.ones_like(stim[...,0]) for stim in formatted_stimuli] # Everything was seen :)
rewards = jnp.array(data["observations"]["deltas"])[subj_id]
# NOTE : `Ntrials` is re-bound here from the shape of this subject's rewards.
Ntrials,Nobservations = rewards.shape
# One row of observation indices (0 .. Nobservations-1) per trial :
timesteps = jnp.broadcast_to(jnp.arange(Nobservations),(Ntrials,Nobservations))
vect_actions = {}
for key,val in data["actions"]["vect"].items():
    vect_actions[key] = jnp.array(val[subj_id])[:-1,:,:] # Ignore the last trial


def get_models_from_options(_Nactions,_Nbins_feedback):
    """Build the index of candidate models to compare.

    Args:
        _Nactions: size of the action space, bound into every agent's constants.
        _Nbins_feedback: number of feedback bins, bound into the q-learning and
            active inference agents' constants.

    Returns:
        A dict keyed by model name. Each value is a dict with :
        - "model"  : the agent constructor with its `constants` bound (functools.partial)
        - "priors" : None (random agent), a tuple of distributions, or a dict of
          named distributions (active inference agent).
        The order of the priors is kept exactly as declared, since it may be
        positional for the fitting code.
    """
    # Every call creates a fresh distribution object.
    def _unit_uniform():
        return tfd.Uniform(low=0.0,high=1.0)

    def _wide_uniform(upper=1000.0):
        return tfd.Uniform(low=-0.01,high=upper)

    def _with_action_constants(agent_constructor):
        return partial(agent_constructor,constants=(_Nactions,))

    aif_1D_static_params = dict(
        # General environment :
        N_feedback_ticks=_Nbins_feedback,
        # Latent state space structure : how many latent states in the subject model (1D)
        Ns_latent=5,
        # Action discretization :
        N_actions=_Nactions,
        # Temporal horizon, big values will cause ballooning compilation times
        Th=2,
        learn_e=True,
    )

    # Let's assume the following priors regarding the general models of the subjects :
    # Overall uniform, with the following (weakly) informed priors :
    # - high reward seeking
    # - high action selection (inverse) temperature
    # - low initial action concentration and stickiness (we focus only on the learning rate here)
    # - high initial perception concentration
    # Model parameters : these should interact with the model components in a differentiable manner
    aif_1D_priors = {
        "transition_concentration": tfd.Normal(1.0,0.5), # Initial concentrations should be rather low
        "transition_stickiness": tfd.Normal(1.0,0.5),
        "transition_learning_rate" : _wide_uniform(),
        "state_interpolation_temperature" : _wide_uniform(10000.0),

        "initial_state_concentration": _wide_uniform(),

        "feedback_expected_std" : tfd.Uniform(low=-0.0,high=1000.0),
        "emission_concentration" : _wide_uniform(),
        "emission_stickiness" :  tfd.Normal(100.0,10.0),

        "reward_seeking" :tfd.Normal(10.0,10.0),
        "habits_learning_rate" : _wide_uniform(),

        "action_selection_temperature" : tfd.Normal(10.0,1.0),
    }

    return {
        "random":{
            "model" : _with_action_constants(random_agent),
            "priors": None
        },
        "rw":{
            "model" : _with_action_constants(rescorla_wagner_agent),
            "priors" : (_unit_uniform(),tfd.Normal(10.0,5.0))
        },
        "ck":{
            "model" : _with_action_constants(choice_kernel_agent),
            "priors" : (_unit_uniform(),tfd.Normal(10.0,5.0))
        },
        "rw_ck":{
            "model": _with_action_constants(rw_ck_agent),
            "priors": (_unit_uniform(),tfd.Normal(5.0,10.0),_unit_uniform(),tfd.Normal(5.0,10.0))
        },
        "naive_qlearning":{
            "model": partial(q_learning_agent,constants=(_Nactions,_Nbins_feedback)),
            "priors": (_unit_uniform(),_unit_uniform(),tfd.Normal(10.0,5.0),_unit_uniform(),tfd.Normal(10.0,5.0))
        },
        "aif_1d":{
            "model": partial(active_inference_basic_1D,constants=aif_1D_static_params),
            "priors": aif_1D_priors
        }
    }

# Fit every candidate model on the selected subject, one action modality at a time.
subject_results_dict = {}
for action_modality in ["angle"]:
    modality_results_dict = {}

    subject_data = (formatted_stimuli,bool_stimuli,rewards,vect_actions[action_modality],timesteps)

    # The action space size is read from the last axis of the encoded actions.
    model_index_for_action_modality = get_models_from_options(vect_actions[action_modality].shape[-1],n_bins_feedback)

    for agent_name, agent_contents in model_index_for_action_modality.items():
        print("     -> Agent : {}".format(agent_name))

        agent,agent_priors = agent_contents["model"],agent_contents["priors"]

        # Agent functions
        _,_,_,_,_,encoder = agent(None)
        def fit_agent(_data_one_subject,_fit_rng_key):
            """Fit the current `agent` on one subject's data.

            Returns a tuple (loss history, optimized parameter vectors, log-likelihoods).
            The first two are None when the agent has no priors (nothing to fit).
            """
            if agent_priors is None:
                # No parameters : only the log-likelihood is needed for model comparison.
                return None,None,compute_loglikelihood(_data_one_subject,agent(None),'sum')

            # Multi-iteration based MAP : (we randomize the initial point and try to find minima)
            _opt_vectors,(_,_loss_history,_param_history),_encoding_function = fit_map_agent(
                _data_one_subject,agent,
                len(agent_priors),agent_priors,
                _fit_rng_key,
                true_hyperparams=None,verbose=True,
                num_steps=Nsteps,n_iter=Nheads,initial_window=head_init_window)

            # We then compute the log-likelihoods of each solution for model comparison purposes :
            def _loglik_of_solution(_vector):
                return compute_loglikelihood(_data_one_subject,agent(_encoding_function(_vector)),'sum')

            return _loss_history,_opt_vectors,vmap(_loglik_of_solution)(_opt_vectors)

        # A fresh sub-key per agent, so that fits do not share their randomness.
        rngkey,local_key = jr.split(rngkey)
        loss_histories,best_params,lls = fit_agent(subject_data,_fit_rng_key=local_key)

        modality_results_dict[agent_name] = dict(
            losses_hist=loss_histories,
            params=best_params,
            logliks=lls,
        )

        subject_results_dict[action_modality] = modality_results_dict

Out of the 15950.0 actions performed by our subjects, 15917.0 were 'valid' (99.8 %)
Out of the 15950.0 feedback sequences potentially observed by our subjects, 15950 were 'valid' (100.0 %)
     -> Agent : random
     -> Agent : rw
step 0, loss: Traced<ShapedArray(float32[])>with<BatchTrace(level=1/0)> with
  val = Array([197.32918, 212.4399 , 212.47246, 212.47235, 210.83157, 211.8431 ,
       198.66756, 212.3359 , 277.167  , 205.005  , 200.20853, 210.50194,
       212.123  , 212.36821, 212.47244, 212.47005, 212.47234, 212.47234,
       212.4724 , 199.39818, 211.75491, 277.3604 , 212.47237, 210.6917 ,
       212.47202, 210.63918, 212.45895, 212.24612, 211.23232, 212.47246,
       212.47246, 212.42912, 212.24966, 212.47246, 210.57492, 212.47246,
       206.48024, 211.1489 , 211.96234, 212.38484, 212.47202, 212.47174,
       206.35751, 210.49788, 211.5727 , 212.47244, 212.39227, 227.00682,
       254.67734, 211.65617], dtype=float32)
  batch_dim = 0
step 10, loss: Traced<ShapedArray(float

In [11]:
# Log-likelihoods of the fitted solutions, one printout per model.
for model_name in ("aif_1d","naive_qlearning","rw_ck","rw","ck"):
    print(subject_results_dict["angle"][model_name]["logliks"])

[-190.67297 -190.4463  -198.42824 -198.43571 -197.00725 -197.11844
 -200.03658 -199.439   -190.08711 -197.05237 -200.00706 -189.83437
 -198.5281  -189.93663 -202.4051  -200.00961 -190.6494  -198.38045
 -190.04105 -196.11288 -190.26349 -190.18652 -189.99075 -199.66237
 -190.7037  -200.0065  -190.62558 -199.04529 -197.21788 -190.68213
 -190.8552  -190.62968 -200.00638 -199.84148 -197.32658 -190.70375
 -196.25739 -197.59981 -204.78429 -190.30774 -198.54152 -192.90549
 -190.02933 -199.9868  -199.833   -192.03499 -190.36116 -199.22256
 -190.09683 -189.99873]
[-173.36174 -183.86942 -173.34084 -194.32878 -173.38957 -174.60852
 -173.28198 -186.46883 -173.36511 -173.3815  -173.36682 -173.37291
 -173.28008 -173.27531 -173.3555  -173.34781 -173.37268 -184.08997
 -173.36621 -173.37093 -173.28156 -173.3187  -173.30699 -184.08867
 -173.26859 -183.86926 -173.3474  -194.0538  -185.19997 -173.30447
 -173.27118 -173.27087 -173.2713  -173.32265 -173.26959 -173.36615
 -173.39299 -183.87285 -173.27669 -173

In [12]:
# Rebuild the model index with the same action-space size as the one used for the
# fit (last axis of the encoded "angle" actions) instead of a hard-coded 9.
md = get_models_from_options(vect_actions["angle"].shape[-1],n_bins_feedback)

# The last element returned by the agent constructor is the encoder, applied
# below to each fitted parameter vector.
_,_,_,_,_,encoder = md["aif_1d"]["model"](None)

aif_fitted_params = subject_results_dict["angle"]["aif_1d"]["params"]
print(aif_fitted_params.shape)
# Show at most the first 10 solutions, never indexing past the number of fitted vectors.
for i in range(min(10,aif_fitted_params.shape[0])):
    print(encoder(aif_fitted_params[i]))

(50, 11)
{'transition_concentration': Array(1.2071325, dtype=float32), 'transition_stickiness': Array(1.453517, dtype=float32), 'transition_learning_rate': Array(8.594912e-08, dtype=float32), 'state_interpolation_temperature': Array(5.1480413e-08, dtype=float32), 'initial_state_concentration': Array(0.0018086, dtype=float32), 'feedback_expected_std': Array(0.9601711, dtype=float32), 'emission_concentration': Array(3.4872937, dtype=float32), 'emission_stickiness': Array(38.809868, dtype=float32), 'reward_seeking': Array(10.001418, dtype=float32), 'habits_learning_rate': Array(3.4734664, dtype=float32), 'action_selection_temperature': Array(0.62299526, dtype=float32)}
{'transition_concentration': Array(3.0308037, dtype=float32), 'transition_stickiness': Array(1.0000764, dtype=float32), 'transition_learning_rate': Array(0.01039839, dtype=float32), 'state_interpolation_temperature': Array(0.03547411, dtype=float32), 'initial_state_concentration': Array(1.7511253e-08, dtype=float32), 'feedb

Improve the performance of the AIF agent:
- Implement habit learning (expect a big difference for subjects like n°0)
- Test other preprocessing options
- Implement temporary model updates during a trial

Overall, it matches our expectations that a more model-free approach fits simple learning paradigms well, with a big nudge towards recurring behaviour (if you have chosen a specific action, you will be more likely to pick it again next time). What would be interesting would be:
- a. Assuming some initial parameters, see how long each model takes to predict things correctly
- b. Simulate the behaviour of these fitted agents given a new end position at the middle of the training.