# Optimizing experimental design

To test whether participants rely on state abstraction for forward planning, we need an experiment in which dynamic programming is (presumably) too costly, so that participants have to fall back on suboptimal solutions. At the same time, the experiment must let us discriminate between the abstracted and the ground MDP in participants' behaviour. We will adjust the experimental parameters of the study by Ott and colleagues. We have seen that, in that experiment, participants perform suboptimally and that we can discriminate quite well between different abstraction levels. However, we also found that the model without any abstraction still performed best. This might indicate either (1) that participants do not rely on state abstraction to solve complex MDPs, or (2) that the state space in that task was not large enough for participants to need state abstraction, so that they instead relied on the full MDP together with some sort of heuristic, as suggested in the original paper.

To arbitrate between these two options, we need a task with a larger state space. We will obtain one by increasing the number of levels along the dimensions of our experiment. While we could do so blindly and hope for the best, we will instead optimize the experimental design to maximize the distinction between abstraction levels. Specifically, we will (1) manipulate the task parameters, (2) compute the state-by-state distance matrix of the resulting MDP, (3) compute decision values at various $\epsilon$ abstraction levels, and (4) compare the decision values between abstraction levels and select the task that yields the largest difference. This should be the task that makes it easiest to identify the abstraction level that participants might be using.
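
To make the role of $\epsilon$ concrete, here is a minimal sketch on a made-up four-state distance matrix. It assumes that abstraction means merging states whose pairwise distances all stay within $\epsilon$, implemented here with complete-linkage clustering from SciPy. The grouping rule actually used by `distance_reduce_mdp` may differ, and the helper `toy_state_classes`, the distances, and the abstract values are all invented for illustration.

```python
import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage
from scipy.spatial.distance import squareform

# Hypothetical symmetric distance matrix for four ground states (not real task data).
toy_distances = np.array([
    [0.00, 0.01, 0.50, 0.50],
    [0.01, 0.00, 0.50, 0.50],
    [0.50, 0.50, 0.00, 0.02],
    [0.50, 0.50, 0.02, 0.00],
])

def toy_state_classes(distance_matrix, eps):
    """Assign each ground state to a class whose members are all within eps of each other."""
    condensed = squareform(distance_matrix)        # condensed form expected by linkage
    tree = linkage(condensed, method="complete")   # complete linkage bounds within-class distance
    return fcluster(tree, t=eps, criterion="distance") - 1  # zero-based class labels

fine = toy_state_classes(toy_distances, eps=0.015)   # merges only states 0 and 1
coarse = toy_state_classes(toy_distances, eps=0.05)  # also merges states 2 and 3

# Project hypothetical abstract values back to the ground states by indexing with the labels.
abstract_values = np.array([1.0, -1.0])
ground_values = abstract_values[coarse]
print(fine, coarse, ground_values)
```

With the smaller $\epsilon$ the toy MDP keeps three classes, while the larger one collapses it to two, so ground states in the same class inherit the same value.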

## 1. Generating tasks
First, we generate tasks with different parameters. We will primarily modulate the offer and cost levels to increase the size of the state space:

In [12]:
import numpy as np
import pandas as pd
from scipy.stats import pearsonr
from scipy.stats import zscore
from stabst.utils import plot_state_matrix, state_classes_from_lbl, avg_reduce_mdp, abstract2ground_value
from scipy.special import expit
from stabst.MarkovDecisionProcess import MDP
from stabst.TaskConfig import LimitedEnergyTask
import matplotlib.pyplot as plt
from sklearn.cluster import AgglomerativeClustering
import arviz as az
import pymc as pm

task_parameters = [
    {
        'name': 'original',
        'O': [1, 2, 3, 4],
        'p_offer': [1/4] * 4,
        'C': [1, 2]
    },
    {
        'name': '6 offers',
        'O': [1, 2, 3, 4, 5, 6],
        'p_offer': [1/6] * 6,
        'C': [1, 2]
    },
    {
        'name': '8 offers',
        'O': [1, 2, 3, 4, 5, 6, 7, 8],
        'p_offer': [1/8] * 8,
        'C': [1, 2]
    },
    {
        'name': '10 offers',
        'O': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        'p_offer': [1/10] * 10,
        'C': [1, 2]
    },
    {
        'name': '3 costs',
        'O': [1, 2, 3, 4],
        'p_offer': [1/4] * 4,
        'C': [1, 2, 3]
    },
    {
        'name': '4 costs',
        'O': [1, 2, 3, 4],
        'p_offer': [1/4] * 4,
        'C': [1, 2, 3, 4]
    },
    {
        'name': '6 costs',
        'O': [1, 2, 3, 4],
        'p_offer': [1/4] * 4,
        'C': [1, 2, 3, 4, 5, 6]
    },
    {
        'name': '6 costs & 10 offers',
        'O': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        'p_offer': [1/10] * 10,
        'C': [1, 2, 3, 4, 5, 6]
    },
]

tasks = {}
for tsk in task_parameters:
    tasks[tsk['name']] = LimitedEnergyTask(O=tsk['O'], p_offer=tsk['p_offer'], C=tsk['C'])
    tasks[tsk['name']].build()
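
Because the configurations are written by hand, a quick consistency check before building the tasks can catch mismatches between offers and probabilities. The sketch below assumes that each offer level needs exactly one probability and that the probabilities should sum to 1. The helper `check_task_config` and the two toy configurations are hypothetical and not part of `stabst`.

```python
import math

def check_task_config(cfg):
    """Return a list of problems found in one task configuration (empty if none)."""
    problems = []
    if len(cfg["O"]) != len(cfg["p_offer"]):
        problems.append("O and p_offer differ in length")
    if not math.isclose(sum(cfg["p_offer"]), 1.0):
        problems.append("p_offer does not sum to 1")
    if len(set(cfg["O"])) != len(cfg["O"]):
        problems.append("duplicate offer levels")
    return problems

# Hypothetical configurations for illustration only.
good_cfg = {"name": "toy ok", "O": [1, 2, 3], "p_offer": [1/3] * 3, "C": [1, 2]}
bad_cfg = {"name": "toy bad", "O": [1, 2], "p_offer": [1/3] * 3, "C": [1, 2]}

for cfg in (good_cfg, bad_cfg):
    print(cfg["name"], check_task_config(cfg))
```

The same function could be run over every entry of `task_parameters` before the loop that builds the tasks.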


## 2 & 3. Computing state-by-state distances and decision values

Now that we have generated the tasks, we can compute the distance matrix for each one, abstract the MDPs at each $\epsilon$ distance threshold, and compute the decision values. We then compare the decision values across abstraction levels using Pearson correlation:

In [None]:
# Set the abstraction levels to explore for each MDP:
abstraction_level = np.arange(0.001, 0.1, 0.001)
# Prepare a list to store one correlation matrix per task:
DV_correlations = []
# Dict of decision values keyed by abstraction level (overwritten for each task):
DV = {}
# Loop through each task:
for task_name, task in tasks.items():
    # Create MDP:
    task_mdp = MDP(task.states, task.tp, task.r, s2i=task.s2i)
    n_states = len(task.states)
    # Generate ground MDP decision values:
    _, Q_full = task_mdp.backward_induction()
    DV[0] = Q_full[:, 1] - Q_full[:, 0]
    # Compute distance matrix:
    if task_name == 'original':
        distances_matrix = np.load('../data/bids/limited_energy/derivatives/state_abstraction/bisimulation_distance_matrix.npy')
    else:
        distances_matrix = task_mdp.bisim_metric(gamma=0.99, tol=1e-3, njobs=-1, max_iters=1000)
    # Loop through each abstraction level:
    for eps in abstraction_level:
        # Reduce the MDP accordingly:
        abstract_mdp, state_classes, class_of_state = task_mdp.distance_reduce_mdp(eps, distance_matrix=distances_matrix)
        n_states = len(abstract_mdp.states)
        # Solve the MDP:
        V_R, Q_R = abstract_mdp.backward_induction()
        # Project back to Ground space:
        V_from_abstract, Q_from_abstract = abstract2ground_value(class_of_state, V_R, Q_R)
        # Compute the decision values:
        DV[f"{eps:.3f}"] =  Q_from_abstract[:, 1] - Q_from_abstract[:, 0]
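    # Compute the correlation between each pair of decision values.
    # Size the matrix by abstraction_level: DV also holds the ground entry (key 0),
    # which is not compared here and would otherwise leave an empty row and column.
    corr_mat = np.zeros((len(abstraction_level), len(abstraction_level)))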
    for i, eps1 in enumerate(abstraction_level):
        for ii, eps2 in enumerate(abstraction_level):
            corr_mat[i, ii] = pearsonr(DV[f"{eps1:.3f}"], DV[f"{eps2:.3f}"]).statistic
    DV_correlations.append(corr_mat)
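
As a side note, the pairwise Pearson correlations can also be obtained in a single call by stacking the decision-value vectors. The sketch below uses randomly generated stand-in vectors (not task data) and `np.corrcoef`, which treats each row as one variable. The names `toy_dv`, `dv_stack`, and `toy_corr` are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical decision values: 3 abstraction levels x 50 ground states (random, not task data).
base = rng.normal(size=50)
toy_dv = {
    "0.001": base,
    "0.002": base + 0.1 * rng.normal(size=50),
    "0.003": base + 1.0 * rng.normal(size=50),
}

# Stack one row per abstraction level; np.corrcoef treats rows as variables.
dv_stack = np.vstack([toy_dv[key] for key in sorted(toy_dv)])
toy_corr = np.corrcoef(dv_stack)  # shape: (n_levels, n_levels)
print(toy_corr.round(2))
```

Note that a decision-value vector that is constant across states has zero variance, so its correlations come out as NaN; this can happen at coarse abstraction levels if all states collapse to the same value.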



## 4. Visualization

Now that we have computed, for each task, the correlations between the decision values obtained at each abstraction level, we can visualize them to see whether one task separates the abstraction levels more strongly than the others. Specifically, we are after a task for which the correlation between decision values drops most rapidly as the $\epsilon$ values move apart:

In [None]:
for tsk_i, tsk_param in enumerate(task_parameters):
    # Plot the decision values:
    fig, ax = plt.subplots()
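    # origin='lower' keeps row 0 (smallest epsilon) at the bottom, matching the extent labels
    im = ax.imshow(DV_correlations[tsk_i],
                   extent=[abstraction_level[0], abstraction_level[-1],
                           abstraction_level[0], abstraction_level[-1]],
                   origin='lower',
                   cmap='RdYlBu_r')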
    ax.set_xlabel('Epsilon')
    ax.set_ylabel('Epsilon')
    ax.set_title(tsk_param['name'])
    fig.colorbar(im, ax=ax, label="r")
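
Beyond visual inspection, one possible way to rank the tasks is to reduce each correlation matrix to a single number. The sketch below uses the mean off-diagonal correlation as that summary and picks the task with the lowest value. This criterion is an assumption rather than an established choice, and the helper `mean_offdiag_correlation` and the two toy matrices are made up for illustration.

```python
import numpy as np

def mean_offdiag_correlation(corr_mat):
    """Average correlation over all pairs of distinct abstraction levels."""
    corr_mat = np.asarray(corr_mat, dtype=float)
    off_diag = ~np.eye(corr_mat.shape[0], dtype=bool)
    return float(np.nanmean(corr_mat[off_diag]))  # ignore NaNs from constant decision values

# Hypothetical correlation matrices for two made-up tasks (not real results).
toy_correlations = {
    "task A": np.array([[1.0, 0.9, 0.8], [0.9, 1.0, 0.9], [0.8, 0.9, 1.0]]),
    "task B": np.array([[1.0, 0.6, 0.2], [0.6, 1.0, 0.5], [0.2, 0.5, 1.0]]),
}

scores = {name: mean_offdiag_correlation(mat) for name, mat in toy_correlations.items()}
best_task = min(scores, key=scores.get)  # lowest mean correlation = most distinct levels
print(scores, best_task)
```

Applied to the notebook's results, the dictionary could be built by pairing each entry of `task_parameters` with the matching matrix in `DV_correlations`.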