In [None]:
from datetime import date
from functools import reduce, partial
from itertools import repeat
import sys
import multiprocessing
import pickle

import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF
from tqdm.notebook import tqdm
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

sys.path.append("..")

In [None]:
# home-grown modules
import py_utils.utils as utils
import py_utils.plotting as plot_utils

In [None]:
# Enable IPython's autoreload extension in mode 2: all imported modules are
# reloaded before each cell executes, so edits to the py_utils sources take
# effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Matplotlib 3.6 renamed its bundled seaborn style sheets to "seaborn-v0_8-*"
# and later releases dropped the old names, so use whichever name this
# installation actually provides (falling back to the original name).
whitegrid_style = next(
    (
        name
        for name in ("seaborn-v0_8-whitegrid", "seaborn-whitegrid")
        if name in plt.style.available
    ),
    "seaborn-whitegrid",
)
plt.style.use(whitegrid_style)
sns.set_style("whitegrid")

In [None]:
# Apply the "presentation.mplstyle" style sheet.
# NOTE(review): presumably a file next to this notebook, resolved relative to
# the working directory — confirm it is checked in.
plt.style.use("presentation.mplstyle")

In [None]:
# Grid of simulation variables: every key maps to the list of levels that
# variable can take (single-element lists are held constant).
dict_variables = dict(
    condition=["smooth", "rough"],
    prior_sd=[0.1, 1],
    sampling=["metropolis-hastings", "improvement"],
    constrain_space=[True, False],
    space_edge_min=[0],
    space_edge_max=[12],
    n_features=[2],
    # a quarter of the 12 x 12 grid
    n_training=[12 ** 2 // 4],
    n_runs=[1000],
)

In [None]:
# Build the simulation conditions from the variable grid. The helper returns a
# pair, unpacked here as df_info and l_info.
# NOTE(review): presumably a DataFrame overview plus a list with one entry per
# condition — confirm in py_utils.utils.simulation_conditions.
df_info, l_info = utils.simulation_conditions(dict_variables)

# Visualize Conditions

In [None]:
# Draw the heatmap overview for the conditions in l_info
# (implementation lives in py_utils.plotting.plot_heatmaps).
plot_utils.plot_heatmaps(l_info)

In [None]:
# Re-apply the presentation style sheet before plotting, then draw the 1-D
# wave plots for the conditions in l_info
# (implementation lives in py_utils.plotting.plot_1d_waves).
plt.style.use("presentation.mplstyle")
plot_utils.plot_1d_waves(l_info)

# Sample Training Stimuli

In [None]:
# One stimulus table per condition, in the same order as l_info.
l_df_xy = [utils.make_stimuli(info) for info in l_info]

In [None]:
# Inspect the stimuli generated for the first condition.
l_df_xy[0]

In [None]:
# Display the condition overview returned by utils.simulation_conditions.
df_info

In [None]:
# 2 x 2 figure comparing predictive uncertainty on held-out data for two
# conditions: index 0 (marked "smooth" below) in the top row and index 8
# (marked "rough" below) in the bottom row.
fig, axes = plt.subplots(nrows=2,ncols=2,figsize=(20, 16), sharex="col", sharey="col",
                  gridspec_kw={"height_ratios":[1, 1]})
axes_flat = axes.flatten()
# Scale the independent variables of both stimulus tables; the scaled frames
# are written back into l_df_xy. l_ivs and scaler are rebound by the second
# call, so only the values from index 8 survive.
# NOTE(review): scale_ivs presumably adds the "{iv}_z" columns that are
# dropped again at the end of this cell — confirm in py_utils.utils.
l_df_xy[0], l_ivs, scaler = utils.scale_ivs(l_df_xy[0])
l_df_xy[8], l_ivs, scaler = utils.scale_ivs(l_df_xy[8])
df_train1, df_test1 = utils.split_train_test(l_info[0], l_df_xy[0])
df_train2, df_test2 = utils.split_train_test(l_info[8], l_df_xy[8])
# Per-panel colorbars are disabled because a single shared one is added below.
# NOTE(review): the "y_pred_sd" column read below is presumably added to the
# test frames by these calls — confirm in py_utils.plotting.
plot_utils.uncertainty_on_test_data(df_train1, df_test1, l_ivs, axes_flat[0:2], show_colorbar=False) # smooth
plot_utils.uncertainty_on_test_data(df_train2, df_test2, l_ivs, axes_flat[2:4], show_colorbar=False) # rough
# Common colour scale spanning the y_pred_sd range of both test sets, so the
# two rows are directly comparable.
norm = plt.Normalize(
    min(df_test1["y_pred_sd"].min(), df_test2["y_pred_sd"].min()), 
    max(df_test1["y_pred_sd"].max(), df_test2["y_pred_sd"].max())
)
# Stand-alone mappable (no plotted data attached) that only drives the colorbar.
sm = plt.cm.ScalarMappable(cmap="viridis", norm=norm)
sm.set_array([])
# Colorbar axes anchored to the last panel: full panel width, 5% of its
# height, lower centre; the negative borderpad offsets it from the panel edge
# (presumably placing it below the panel).
axins = inset_axes(axes_flat[3],
                    width="100%",  
                    height="5%",
                    loc='lower center',
                    borderpad=-7
                   )
_ = fig.colorbar(sm, cax=axins, orientation="horizontal")
# Remove the "{iv}_z" columns in place so the two l_df_xy entries no longer
# carry them when later cells use these frames.
l_df_xy[0].drop([f"""{iv}_z""" for iv in l_ivs], axis=1, inplace=True)
l_df_xy[8].drop([f"""{iv}_z""" for iv in l_ivs], axis=1, inplace=True)

*To consider*
- proportion of datapoints shown to participants during training (currently .25)

Metropolis-Hastings Implementation
1. compute a ratio comparing the deviation on the current trial with the distribution of expected deviations implied by y_pred_sd
2. uniformly sample a value between 0 and 1
3. accept the sample if that value is smaller than the ratio from step 1

In [None]:
# Use at most one worker per condition and leave two cores free. Clamp to at
# least one worker: Pool raises ValueError for a count below 1, which
# cpu_count() - 2 would produce on a machine with one or two cores.
n_cpus = max(1, min(len(l_info), multiprocessing.cpu_count() - 2))
# starmap unpacks each (info, df) pair into the positional arguments of
# utils.run_perception and blocks until every condition has finished, so the
# tqdm wrapper only iterates over already-computed results.
# The context manager shuts the worker processes down afterwards instead of
# leaving them alive for the rest of the kernel session.
with multiprocessing.Pool(n_cpus) as p:
    list_dfs_new = list(
        tqdm(p.starmap(utils.run_perception, zip(l_info, l_df_xy)))
    )

In [None]:
# Display the condition overview again, for reference when choosing which
# entries to plot below.
df_info

In [None]:
stim_id = 133 # but then 132 shows opposite pattern
# Entries to plot (passed to plot_gp_deviations together with list_dfs_new and
# df_info); [2, 6, 10, 14] is the alternative selection tried previously.
l_idxs = [0, 4, 8, 12]
n_plots = len(l_idxs)
# The grid has a fixed number of rows and as many columns as needed; the
# figure width grows with the number of columns while the height stays fixed.
n_rows = 2
n_cols = int(np.ceil(n_plots / n_rows))
f, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols*4, 8), sharex="col", sharey="col")
plot_utils.plot_gp_deviations(axes, l_idxs, list_dfs_new, df_info) # , stim_id
f.suptitle("Model Deviation from True Points after Testing Period", size=20)
f.tight_layout()

In [None]:
# Sanity check: smallest value in the "x_1_sample" column of the first result frame.
list_dfs_new[0]["x_1_sample"].min()

# Save Simulated Data

In [None]:
# Prefix the output file with today's date (ISO format, e.g. 2021-11-12) so
# successive runs do not overwrite each other.
str_today = str(date.today())

# Interpolate the date into the path; without the f-string the file would
# literally be named "str_today-refit-gp.pickle".
with open(f"../data/{str_today}-refit-gp.pickle", "wb") as f:
    pickle.dump(list_dfs_new, f)

In [None]:
# Load a previously saved simulation run. This rebinds list_dfs_new, replacing
# whatever results were computed earlier in this session.
# NOTE(review): pickle.load can execute arbitrary code — only load files from
# a trusted source.
with open("../data/2021-11-12-refit-gp-1000-samples.pickle", "rb") as f:
    list_dfs_new = pickle.load(f)