<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Sample-Training-Stimuli" data-toc-modified-id="Sample-Training-Stimuli-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Sample Training Stimuli</a></span></li><li><span><a href="#Visualize-Conditions" data-toc-modified-id="Visualize-Conditions-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Visualize Conditions</a></span></li><li><span><a href="#Run-Function-Learning-Task" data-toc-modified-id="Run-Function-Learning-Task-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Run Function-Learning Task</a></span></li><li><span><a href="#Save-Simulated-Data" data-toc-modified-id="Save-Simulated-Data-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Save Simulated Data</a></span></li><li><span><a href="#Plot-Movements-in-Space" data-toc-modified-id="Plot-Movements-in-Space-5"><span class="toc-item-num">5&nbsp;&nbsp;</span>Plot Movements in Space</a></span></li><li><span><a href="#True-y-Value-vs.-Size-of-Change" data-toc-modified-id="True-y-Value-vs.-Size-of-Change-6"><span class="toc-item-num">6&nbsp;&nbsp;</span>True y Value vs. Size of Change</a></span></li><li><span><a href="#Local-Change-vs.-Size-of-Change" data-toc-modified-id="Local-Change-vs.-Size-of-Change-7"><span class="toc-item-num">7&nbsp;&nbsp;</span>Local Change vs. Size of Change</a></span></li><li><span><a href="#Prediction-Uncertainty-vs.-Size-of-Change" data-toc-modified-id="Prediction-Uncertainty-vs.-Size-of-Change-8"><span class="toc-item-num">8&nbsp;&nbsp;</span>Prediction Uncertainty vs. 
Size of Change</a></span></li><li><span><a href="#Correlations-Between-Movements-in-Different-Conditions" data-toc-modified-id="Correlations-Between-Movements-in-Different-Conditions-9"><span class="toc-item-num">9&nbsp;&nbsp;</span>Correlations Between Movements in Different Conditions</a></span></li><li><span><a href="#Plot-Changes-in-Uncertainty" data-toc-modified-id="Plot-Changes-in-Uncertainty-10"><span class="toc-item-num">10&nbsp;&nbsp;</span>Plot Changes in Uncertainty</a></span></li></ul></div>

In [None]:
import multiprocessing
import pickle
import sys
from datetime import date
from functools import reduce, partial
from itertools import repeat
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

from scipy.optimize import minimize

from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF
from tqdm.notebook import tqdm
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

sys.path.append("..")

In [None]:
# home-grown modules
import py_utils.utils as utils
import py_utils.plotting as plot_utils

In [None]:
# Automatically re-import edited modules (e.g. py_utils) so changes to the
# home-grown helpers are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-*" and
# later releases dropped the old names, so "seaborn-whitegrid" alone raises on
# current installs. Use whichever name this matplotlib provides.
style_whitegrid = (
    "seaborn-v0_8-whitegrid"
    if "seaborn-v0_8-whitegrid" in plt.style.available
    else "seaborn-whitegrid"
)
plt.style.use(style_whitegrid)
sns.set_style("whitegrid")

In [None]:
# Project-specific style sheet. It is given as a relative file path, so it is
# presumably resolved against the kernel's working directory (this notebook's
# folder) — confirm. Applied after the whitegrid style above, so its settings
# win where they overlap.
plt.style.use("presentation.mplstyle")

In [None]:
# Simulation parameter grid: each key maps to the list of values to simulate.
# Presumably every combination becomes one condition (see
# utils.simulation_conditions).
dict_variables = {
    "condition": ["smooth", "rough"],
    "sampling_strategy": ["stimulus", "trial"],
    "prior_sd": [.5],
    "sampling": ["improvement"],  # alternative: "metropolis-hastings"
    "constrain_space": [False],  # alternative: True
    "space_edge_min": [0],
    "space_edge_max": [12],
    "n_features": [2],
    # a quarter of 12**2 points (presumably the 12 x 12 stimulus grid)
    "n_training": [12**2 // 4],
    "n_samples_block": [200],
    "beta_softmax": [1, 10],
    "n_runs": [10],
}

In [None]:
# Expand the parameter grid into an overview table, one info mapping per
# condition and one plot title per condition (structure presumed from how they
# are used below — see utils.simulation_conditions).
(df_info, l_info, l_titles) = utils.simulation_conditions(dict_variables)

In [None]:
# Overview of the simulation conditions produced from dict_variables.
df_info

# Sample Training Stimuli

In [None]:
# Build the stimulus set (with reward mapping) for every simulation condition.
l_df_xy = [utils.make_stimuli(info, map_to_reward=True) for info in l_info]

# Visualize Conditions

In [None]:
# Heatmaps for every condition in l_info (presumably of the underlying reward
# functions — see plot_utils.plot_heatmaps).
plot_utils.plot_heatmaps(l_info, map_to_reward=True)

In [None]:
# Compare prediction uncertainty on held-out data for one smooth and one rough
# condition (indices presumably follow df_info's ordering — confirm).
idx0 = 0  # smooth
idx1 = 4  # rough
float_formatter = "{:.5f}".format
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(20, 16), sharex=False, sharey="col",
                         gridspec_kw={"height_ratios": [1, 1]})
axes_flat = axes.flatten()

# scale_ivs returns the frame with extra "<iv>_z" columns (presumably
# standardized IVs). They are dropped again at the end of this cell, so later
# cells see l_df_xy without them.
l_df_xy[idx0], l_ivs, scaler = utils.scale_ivs(l_df_xy[idx0])
l_df_xy[idx1], l_ivs, scaler = utils.scale_ivs(l_df_xy[idx1])
df_train1, df_test1 = utils.split_train_test(l_info[idx0], l_df_xy[idx0])
df_train2, df_test2 = utils.split_train_test(l_info[idx1], l_df_xy[idx1])
# smooth
df_test1 = plot_utils.uncertainty_on_test_data(df_train1, df_test1, l_info[idx0], l_ivs)
# rough
df_test2 = plot_utils.uncertainty_on_test_data(df_train2, df_test2, l_info[idx1], l_ivs)

# One shared colour range for both conditions, so the panels are comparable and
# the single colorbar below is valid for all of them.
min_val = min(df_test1["y_pred_sd"].min(), df_test2["y_pred_sd"].min())
max_val = max(df_test1["y_pred_sd"].max(), df_test2["y_pred_sd"].max())

plot_utils.plot_uncertainty_on_test_data(df_test1, axes_flat[0:2], show_colorbar=False, min_val=min_val, max_val=max_val)
plot_utils.plot_uncertainty_on_test_data(df_test2, axes_flat[2:4], show_colorbar=False, min_val=min_val, max_val=max_val)

# Horizontal colorbar under the bottom-right panel, built from the same range.
# NOTE(review): assumes plot_uncertainty_on_test_data also uses "viridis" — confirm.
norm = plt.Normalize(min_val, max_val)
sm = plt.cm.ScalarMappable(cmap="viridis", norm=norm)
sm.set_array([])
axins = inset_axes(axes_flat[3],
                   width="100%",
                   height="5%",
                   loc="lower center",
                   borderpad=-7)
_ = fig.colorbar(sm, cax=axins, orientation="horizontal")

l_df_xy[idx0].drop([f"{iv}_z" for iv in l_ivs], axis=1, inplace=True)
l_df_xy[idx1].drop([f"{iv}_z" for iv in l_ivs], axis=1, inplace=True)

# Fix the left-column x ticks to 5-decimal labels and rotate them for legibility.
for ax_left in (axes_flat[0], axes_flat[2]):
    tick_labels = [float_formatter(t) for t in ax_left.get_xticks()]
    ax_left.xaxis.set_ticks([float(label) for label in tick_labels])
    _ = ax_left.set_xticklabels(tick_labels, rotation=90)

plt.tight_layout()

# Run Function-Learning Task

In [None]:
# Leave two cores free for the rest of the machine, but never ask for fewer than
# one worker: Pool raises ValueError for processes < 1 (e.g. on a 2-core box).
n_cpus = max(1, min(len(l_info), multiprocessing.cpu_count() - 2))
# starmap unpacks each (info, df_xy) pair into run_perception_pairs' positional
# arguments and returns the results in input order. The context manager makes
# sure the worker processes are shut down once the results are collected.
with multiprocessing.Pool(n_cpus) as pool:
    list_perception = pool.starmap(utils.run_perception_pairs, zip(l_info, l_df_xy))

In [None]:
# Each simulation result is indexed as (new x positions, reward, trials).
list_df_new_xpos = [result[0].reset_index() for result in list_perception]
list_df_reward = [result[1] for result in list_perception]
list_df_trials = [result[2] for result in list_perception]

# Save Simulated Data

In [None]:
str_today = str(date.today())
# Date-stamped cache of the raw simulation results (one entry per condition).
path_results = Path("..") / "data" / f"{str_today}-reward-learning-dont-refit-gp.pickle"
# Create the output folder if needed, so a long simulation is not lost to a
# missing directory.
path_results.parent.mkdir(parents=True, exist_ok=True)

with open(path_results, "wb") as f:
    pickle.dump(list_perception, f)

# Round-trip check: reload what was just written.
# NOTE: pickle.load can execute arbitrary code; only load files this notebook wrote.
with open(path_results, "rb") as f:
    list_perception = pickle.load(f)

- **TODO:** consider keeping the prior points when plotting movements (prior & likelihood …), as I did in the category-learning part.

# Plot Movements in Space

In [None]:
# Movement plots for the first four conditions (the "smooth" ones, judging by
# the figure name below).
plot_moves = partial(plot_utils.plot_moves_one_condition, list_dfs_new=list_df_new_xpos, df_info=df_info)
idxs_panel = df_info.index[0:4].to_list()
fig_moves, axes_moves = plt.subplots(1, 4, figsize=(30, 8))
for idx_condition, title, ax_panel in zip(idxs_panel, l_titles[0:4], axes_moves.flatten()):
    plot_moves(idx_condition, title, ax_panel)
plt.tight_layout()
# plt.savefig(f"""../figures/{str_today}-func-learning-smooth-movements-refit.png""")

In [None]:
# Movement plots for conditions 4-7 (the "rough" ones, judging by the figure
# name below).
plot_moves = partial(plot_utils.plot_moves_one_condition, list_dfs_new=list_df_new_xpos, df_info=df_info)
idxs_panel = df_info.index[4:8].to_list()
fig_moves, axes_moves = plt.subplots(1, 4, figsize=(30, 8))
for idx_condition, title, ax_panel in zip(idxs_panel, l_titles[4:8], axes_moves.flatten()):
    plot_moves(idx_condition, title, ax_panel)
plt.tight_layout()
# plt.savefig(f"""../figures/{str_today}-func-learning-rough-movements-refit.png""")

# True y Value vs. Size of Change

In [None]:
# Pass every condition's movement table through utils.add_max_gradient.
# This rebinds list_df_new_xpos, so re-running the cell applies it a second time.
list_df_new_xpos = [utils.add_max_gradient(df_xpos) for df_xpos in tqdm(list_df_new_xpos)]

In [None]:
# One regression panel per condition: true y value vs. size of change.
fig, axes = plt.subplots(2, 4, figsize=(25, 12.5))
axes_flat = axes.flatten()
for df_xpos, title, ax_panel in zip(list_df_new_xpos, l_titles, axes_flat):
    plot_utils.regplot_y(df_xpos, title, ax_panel)
plt.tight_layout()

# Local Change vs. Size of Change

In [None]:
# One regression panel per condition: local change vs. size of change.
fig, axes = plt.subplots(2, 4, figsize=(25, 12.5))
axes_flat = axes.flatten()
for df_xpos, title, ax_panel in zip(list_df_new_xpos, l_titles, axes_flat):
    plot_utils.regplot_max_gradient(df_xpos, title, ax_panel)
plt.tight_layout()

# Prediction Uncertainty vs. Size of Change

In [None]:
# One regression panel per condition: prediction uncertainty vs. size of change.
# The y axes are not shared, hence the footnote.
fig, axes = plt.subplots(2, 4, figsize=(25, 12.5))
axes_flat = axes.flatten()
for df_xpos, title, ax_panel in zip(list_df_new_xpos, l_titles, axes_flat):
    plot_utils.regplot_start_uncertainty(df_xpos, title, ax_panel)
plt.tight_layout()
str_note = "Note. y axes differ between panels"
_ = plt.figtext(0.99, -0.02, str_note, horizontalalignment='right', size=25)

# Correlations Between Movements in Different Conditions

In [None]:
# Movement tables with movement angles added by utils.add_angle_of_movements.
l_df_movements = [utils.add_angle_of_movements(df_xpos) for df_xpos in list_df_new_xpos]

In [None]:
def correlation_between_movement_angles(idx1: int, idx2: int, l_df_movements: list) -> float:
    """Pearson correlation of the "angle" column between two conditions.

    Rows of the two movement tables are paired by index label (inner join), so
    only rows present in both conditions contribute.

    Parameters
    ----------
    idx1, idx2 : positions in ``l_df_movements`` of the two conditions.
    l_df_movements : list of DataFrames, each with an ``"angle"`` column.

    Returns
    -------
    The off-diagonal entry of ``np.corrcoef`` for the paired angles.
    """
    df_left = l_df_movements[idx1]
    df_right = l_df_movements[idx2]
    df_paired = df_left.merge(
        df_right, how="inner", left_index=True, right_index=True, suffixes=("_x", "_y")
    )
    corr_matrix = np.corrcoef(df_paired["angle_x"], df_paired["angle_y"])
    return corr_matrix[0, 1]

In [None]:
# Every ordered pair of conditions (including self-pairs) and the correlation of
# their movement angles.
n_conditions = len(l_df_movements)
s_condition_1 = pd.Series(range(n_conditions), name="Condition 1")
s_condition_2 = pd.Series(range(n_conditions), name="Condition 2")
df_cross = s_condition_1.to_frame().merge(s_condition_2.to_frame(), how="cross")
corr_with = partial(correlation_between_movement_angles, l_df_movements=l_df_movements)
df_cross["correlation"] = [
    corr_with(cond_a, cond_b)
    for cond_a, cond_b in zip(df_cross["Condition 1"].to_list(), df_cross["Condition 2"].to_list())
]

In [None]:
# Keep only pairs in which both conditions were simulated in unconstrained space.
is_unconstrained = df_info["constrain_space"].eq(False)
idxs_required = df_info.index[is_unconstrained.to_numpy()].to_list()
both_required = df_cross["Condition 1"].isin(idxs_required) & df_cross["Condition 2"].isin(idxs_required)
df_cross = df_cross.loc[both_required].copy()

In [None]:
df_cross["corr_above_0"] = df_cross["correlation"] > 0
# Proportion of positive correlations, excluding self-pairs.
df_off_diagonal = df_cross.query("`Condition 1` != `Condition 2`")
n_positive = df_off_diagonal["corr_above_0"].sum()
prop_above_0 = (n_positive / df_off_diagonal.shape[0]).round(2)
print("Proportion of Correlations > 0 between Movement Directions in Simulation Conditions: ", prop_above_0)

In [None]:
# Condition-by-condition matrix of movement-angle correlations.
fig_corr, ax = plt.subplots(1, 1, figsize=(10, 8))
corr_matrix = (
    df_cross
    .pivot(index="Condition 1", columns="Condition 2", values="correlation")
    .round(2)
)
sns.heatmap(
    corr_matrix,
    annot=True, annot_kws={"size": 15}, linewidths=1,
    cmap="vlag", vmin=-1, vmax=1, ax=ax,
)
_ = ax.set_title("Pairwise Correlations Between Movement Directions in Simulation Conditions")
plt.tight_layout()
# plt.savefig(f"""../figures/{str_today}-func-learning-pairwise-correlations-constrained-refit.png""")

# Plot Changes in Uncertainty

In [None]:
# Conditions (positions in list_df_new_xpos) to show, one panel each.
# Optional: a single stimulus could presumably be highlighted by also passing a
# stim_id to plot_gp_deviations (e.g. 133; note that 132 shows the opposite
# pattern) — confirm the helper's signature.
l_idxs = range(0, 8)
n_plots = len(l_idxs)
n_cols = 4
n_rows = int(np.ceil(n_plots / n_cols))
# NOTE(review): plt.subplots takes (nrows, ncols), so this grid is n_cols rows x
# n_rows columns (conditions stacked vertically, so sharex="col" aligns them).
# figsize, however, is computed as if the grid were n_rows x n_cols.
# Confirm which layout is intended before swapping the arguments.
f, axes = plt.subplots(n_cols, n_rows, figsize=(n_cols*4, n_rows*7), sharex="col", sharey="col")
plot_utils.plot_gp_deviations(axes, l_idxs, list_df_new_xpos, l_titles)
plt.suptitle("Model Deviation from True Points after Testing Period\n", size=20)
plt.tight_layout()

In [None]:
fig, axes = plt.subplots(2, 4, figsize=(25, 9), sharex=True, sharey=True)
# n_runs is taken from the first condition; dict_variables lists a single
# n_runs value, so it is the same for every condition.
plot_accepted = partial(plot_utils.plot_proportion_accepted_samples, n_runs=l_info[0]["n_runs"])
for df_xpos, title, ax_panel in zip(list_df_new_xpos, l_titles, axes.flatten()):
    plot_accepted(df_xpos, title, ax_panel)
_ = fig.suptitle("Proportion of Accepted Samples Over Course of Function-Learning Task\n")
plt.tight_layout()

In [None]:
# Add a per-trial squared residual column, "r_sq" = (y - y_pred_mn)**2.
# Modifies each trials DataFrame in place.
for df_trials in list_df_trials:
    residual = df_trials["y"] - df_trials["y_pred_mn"]
    df_trials["r_sq"] = residual ** 2

In [None]:
# "r_sq" holds the squared residual (y - y_pred_mn)**2, not the coefficient of
# determination, so the figure is titled accordingly.
# y axes are shared (sharey=True), so no "y axes differ" footnote is needed.
fig, axes = plt.subplots(2, 4, figsize=(25, 18), sharex=True, sharey=True)
plot_squared_error = partial(plot_utils.plot_var_over_bintime, var="r_sq")
for df_trials, title, ax_panel in zip(list_df_trials, l_titles, axes.flatten()):
    plot_squared_error(df_trials, title, ax_panel)
_ = fig.suptitle("Squared Prediction Error Against Test Trial\n")
plt.tight_layout()

In [None]:
# Reward summed per time bin for each condition.
# y axes are shared (sharey=True), so no "y axes differ" footnote is needed.
fig, axes = plt.subplots(2, 4, figsize=(25, 18), sharex=True, sharey=True)
plot_reward_sum = partial(plot_utils.plot_sum_over_bintime, var="reward")
for df_reward, title, ax_panel in zip(list_df_reward, l_titles, axes.flatten()):
    plot_reward_sum(df_reward, title, ax_panel)
_ = fig.suptitle("Reward Against Test Trial\n")
plt.tight_layout()

*TODOs*
- plot the movements on top of the original y function values (i.e., over the heatmap)
