In [None]:
import h5py
import matplotlib.pyplot as plt
import mpl_lego as mplego
import neuropacks
import numpy as np
import os

%matplotlib inline

In [None]:
# Switch matplotlib to mpl_lego's LaTeX style. The figure cell uses raw LaTeX
# markup such as r'\textbf{...}' in axis labels, which presumably relies on this
# style enabling LaTeX text rendering (text.usetex) — confirm in mpl_lego.
mplego.style.use_latex_style()

In [None]:
# Experiment number. Result files for this experiment live in
# ~/fits/neurobiases/exp{exp}/ and are named
# exp{exp}_{coupling_idx}_{tuning_idx}_{model_idx}_{dataset_idx}.h5.
exp = 22
# os.path.expanduser('~') resolves the home directory on POSIX ($HOME or the
# password database) and on Windows (%USERPROFILE%), whereas
# os.environ['HOME'] raises KeyError when HOME is unset. Joining individual
# components avoids a hard-coded '/' separator inside os.path.join.
base_path = os.path.join(os.path.expanduser('~'), 'fits', 'neurobiases', f'exp{exp}')

In [None]:
# Hyperparameter grid swept by the experiment. Grid indices map onto result
# filenames as exp{exp}_{ii}_{jj}_{kk}_{ll}.h5, where ii indexes coupling_locs,
# jj indexes tuning_locs, kk the model and ll the dataset.

# Hyperparameter 1: coupling location, evenly spaced on [min, max]
# (shown on the figure's y-axis as "Coupling Mean").
coupling_loc_min, coupling_loc_max = -1, 1
n_coupling_locs = 5
coupling_locs = np.linspace(coupling_loc_min, coupling_loc_max, num=n_coupling_locs)

# Hyperparameter 2: tuning location, evenly spaced on [min, max]
# (shown on the figure's x-axis as "Tuning Mean").
tuning_loc_min, tuning_loc_max = -1, 1
n_tuning_locs = 5
tuning_locs = np.linspace(tuning_loc_min, tuning_loc_max, num=n_tuning_locs)

# Hyperparameter 3: number of models at each (coupling, tuning) grid point.
n_models = 3

# Hyperparameter 4: number of datasets for each model.
n_datasets = 10

In [None]:
# Problem dimensions.
#   N: length of each coupling-parameter vector (the a_* arrays).
#   M: length of each tuning-parameter vector (the b_* arrays).
N, M = 10, 10
# K, D and n_folds are not referenced in the cells shown here; presumably
# K = latent dimensionality, D = number of samples and n_folds = number of
# CV folds — TODO confirm against the fitting script.
K, D = 1, 2000
n_folds = 3

In [None]:
# Peek at the top-level layout of one result file (indices 0_0_0_0). Using a
# `with` block closes the HDF5 handle immediately instead of leaving it open
# for the rest of the kernel session. The filename is built from `exp` rather
# than a hard-coded experiment number.
example_path = os.path.join(base_path, f'exp{exp}_0_0_0_0.h5')
with h5py.File(example_path, 'r') as f:
    example_keys = list(f.keys())
example_keys

In [None]:
# Load ground truth and oracle-selection estimates from every result file.
# Arrays are indexed (coupling_loc, tuning_loc, model, dataset, cv, param).
# NOTE(review): if the stored datasets are 1-D ((N,) / (M,)), NumPy broadcasts
# them across the cv axis on assignment — confirm the stored shapes.
shape_tuple = (n_coupling_locs, n_tuning_locs, n_models, n_datasets)
n_cvs = 3
a_true, a_est_tm, a_est_tc = (np.zeros(shape_tuple + (n_cvs, N)) for _ in range(3))
b_true, b_est_tm, b_est_tc = (np.zeros(shape_tuple + (n_cvs, M)) for _ in range(3))

# (HDF5 dataset path, destination array), listed in the order they are read.
oracle_sources = (
    ('a_true', a_true),                  # coupling parameters
    ('tm_oracle/a_est', a_est_tm),
    ('tc_ols_oracle/a_est', a_est_tc),
    ('b_true', b_true),                  # tuning parameters
    ('tm_oracle/b_est', b_est_tm),
    ('tc_ols_oracle/b_est', b_est_tc),
)

# np.ndindex visits grid indices in C order, i.e. the same order as four
# nested range loops over coupling, tuning, model and dataset.
for ii, jj, kk, ll in np.ndindex(shape_tuple):
    file = f"exp{exp}_{ii}_{jj}_{kk}_{ll}.h5"
    path = os.path.join(base_path, file)
    with h5py.File(path, 'r') as results:
        for key, destination in oracle_sources:
            destination[ii, jj, kk, ll] = results[key][:]

In [None]:
# Restrict truth and estimates to the true support (entries where the ground
# truth is nonzero), keeping one row per (coupling, tuning, model, dataset, cv)
# slice.
#
# Boolean-mask indexing flattens the selection in C order. Reshaping it back to
# (..., n_nonzero) is only valid if every slice has the same number of nonzero
# true parameters. Otherwise the reshape either raises, or, when the total
# count happens to divide evenly, silently regroups parameters across slices
# and corrupts every downstream per-slice reduction. Check the invariant first.
a_nnz = np.count_nonzero(a_true, axis=-1)
b_nnz = np.count_nonzero(b_true, axis=-1)
assert np.all(a_nnz == a_nnz.flat[0]), \
    'a_true has a different number of nonzero entries across slices'
assert np.all(b_nnz == b_nnz.flat[0]), \
    'b_true has a different number of nonzero entries across slices'

a_support = a_true != 0
b_support = b_true != 0
a_true_nz = a_true[a_support].reshape(a_true.shape[:-1] + (-1,))
a_est_tm_nz = a_est_tm[a_support].reshape(a_est_tm.shape[:-1] + (-1,))
a_est_tc_nz = a_est_tc[a_support].reshape(a_est_tc.shape[:-1] + (-1,))
b_true_nz = b_true[b_support].reshape(b_true.shape[:-1] + (-1,))
b_est_tm_nz = b_est_tm[b_support].reshape(b_est_tm.shape[:-1] + (-1,))
b_est_tc_nz = b_est_tc[b_support].reshape(b_est_tc.shape[:-1] + (-1,))

In [None]:
# Signed estimation error on the true support, averaged over datasets
# (axis 3 of (coupling, tuning, model, dataset, cv, param)). Each result keeps
# the cv and param axes: shape (coupling, tuning, model, cv, n_nonzero).
a_bias_tm_all, a_bias_tc_all, b_bias_tm_all, b_bias_tc_all = (
    (estimate - truth).mean(axis=3)
    for estimate, truth in (
        (a_est_tm_nz, a_true_nz),
        (a_est_tc_nz, a_true_nz),
        (b_est_tm_nz, b_true_nz),
        (b_est_tc_nz, b_true_nz),
    )
)

In [None]:
# Collapse each signed-error array (coupling, tuning, model, dataset, cv, param)
# to a (coupling, tuning) grid for plotting:
#   1. median over the cv axis      (axis 4) -> (coupling, tuning, model, dataset, param)
#   2. median over nonzero params   (axis 4) -> (coupling, tuning, model, dataset)
#   3. mean over datasets           (axis 3) -> (coupling, tuning, model)
#   4. median over models           (axis 2) -> (coupling, tuning)
a_bias_agg_tm, a_bias_agg_tc, b_bias_agg_tm, b_bias_agg_tc = [
    np.median(
        np.mean(
            np.median(np.median(estimate - truth, axis=4), axis=4),
            axis=3),
        axis=2)
    for estimate, truth in (
        (a_est_tm_nz, a_true_nz),
        (a_est_tc_nz, a_true_nz),
        (b_est_tm_nz, b_true_nz),
        (b_est_tc_nz, b_true_nz),
    )
]

In [None]:
# Reload ground truth together with the inferred-selection (UoI) estimates.
# This overwrites the arrays filled by the oracle-loading cell; the oracle
# aggregates were already computed from them. Arrays are indexed
# (coupling_loc, tuning_loc, model, dataset, cv, param).
# NOTE(review): if the stored datasets are 1-D ((N,) / (M,)), NumPy broadcasts
# them across the cv axis on assignment — confirm the stored shapes.
shape_tuple = (n_coupling_locs, n_tuning_locs, n_models, n_datasets)
n_cvs = 3
a_true, a_est_tm, a_est_tc = (np.zeros(shape_tuple + (n_cvs, N)) for _ in range(3))
b_true, b_est_tm, b_est_tc = (np.zeros(shape_tuple + (n_cvs, M)) for _ in range(3))

# (HDF5 dataset path, destination array), listed in the order they are read.
selection_sources = (
    ('a_true', a_true),                    # coupling parameters
    ('tm_t_c_uoi/a_est', a_est_tm),
    ('tc_ols_t_c_uoi/a_est', a_est_tc),
    ('b_true', b_true),                    # tuning parameters
    ('tm_t_c_uoi/b_est', b_est_tm),
    ('tc_ols_t_c_uoi/b_est', b_est_tc),
)

# np.ndindex visits grid indices in C order, i.e. the same order as four
# nested range loops over coupling, tuning, model and dataset.
for ii, jj, kk, ll in np.ndindex(shape_tuple):
    file = f"exp{exp}_{ii}_{jj}_{kk}_{ll}.h5"
    path = os.path.join(base_path, file)
    with h5py.File(path, 'r') as results:
        for key, destination in selection_sources:
            destination[ii, jj, kk, ll] = results[key][:]

In [None]:
# Restrict truth and inferred-selection estimates to the true support (entries
# where the ground truth is nonzero), keeping one row per
# (coupling, tuning, model, dataset, cv) slice.
#
# Boolean-mask indexing flattens the selection in C order. Reshaping it back to
# (..., n_nonzero) is only valid if every slice has the same number of nonzero
# true parameters. Otherwise the reshape either raises, or, when the total
# count happens to divide evenly, silently regroups parameters across slices
# and corrupts every downstream per-slice reduction. Check the invariant first.
a_nnz = np.count_nonzero(a_true, axis=-1)
b_nnz = np.count_nonzero(b_true, axis=-1)
assert np.all(a_nnz == a_nnz.flat[0]), \
    'a_true has a different number of nonzero entries across slices'
assert np.all(b_nnz == b_nnz.flat[0]), \
    'b_true has a different number of nonzero entries across slices'

a_support = a_true != 0
b_support = b_true != 0
a_true_nz = a_true[a_support].reshape(a_true.shape[:-1] + (-1,))
a_est_tm_nz = a_est_tm[a_support].reshape(a_est_tm.shape[:-1] + (-1,))
a_est_tc_nz = a_est_tc[a_support].reshape(a_est_tc.shape[:-1] + (-1,))
b_true_nz = b_true[b_support].reshape(b_true.shape[:-1] + (-1,))
b_est_tm_nz = b_est_tm[b_support].reshape(b_est_tm.shape[:-1] + (-1,))
b_est_tc_nz = b_est_tc[b_support].reshape(b_est_tc.shape[:-1] + (-1,))

In [None]:
# Signed estimation error (inferred selection) on the true support, averaged
# over datasets (axis 3 of (coupling, tuning, model, dataset, cv, param)).
# Each result has shape (coupling, tuning, model, cv, n_nonzero).
a_bias_tm_all_sel, a_bias_tc_all_sel, b_bias_tm_all_sel, b_bias_tc_all_sel = (
    (estimate - truth).mean(axis=3)
    for estimate, truth in (
        (a_est_tm_nz, a_true_nz),
        (a_est_tc_nz, a_true_nz),
        (b_est_tm_nz, b_true_nz),
        (b_est_tc_nz, b_true_nz),
    )
)

In [None]:
# Collapse each inferred-selection signed-error array
# (coupling, tuning, model, dataset, cv, param) to a (coupling, tuning) grid:
#   1. median over the cv axis      (axis 4) -> (coupling, tuning, model, dataset, param)
#   2. median over nonzero params   (axis 4) -> (coupling, tuning, model, dataset)
#   3. mean over datasets           (axis 3) -> (coupling, tuning, model)
#   4. median over models           (axis 2) -> (coupling, tuning)
a_bias_agg_sel_tm, a_bias_agg_sel_tc, b_bias_agg_sel_tm, b_bias_agg_sel_tc = [
    np.median(
        np.mean(
            np.median(np.median(estimate - truth, axis=4), axis=4),
            axis=3),
        axis=2)
    for estimate, truth in (
        (a_est_tm_nz, a_true_nz),
        (a_est_tc_nz, a_true_nz),
        (b_est_tm_nz, b_true_nz),
        (b_est_tc_nz, b_true_nz),
    )
]

In [None]:
# Figure 3: aggregated bias of the TC and triangular models. The left two
# columns use oracle selection and the right two use inferred selection.
# Row 0 shows coupling parameters and row 1 tuning parameters. Each heatmap is
# a bias grid indexed [coupling_loc, tuning_loc].
fig, axes = plt.subplots(2, 5, figsize=(24, 10),
                         gridspec_kw={'width_ratios': [1, 1, 0.05, 1, 1]})
plt.subplots_adjust(hspace=0.5, wspace=0.3)

# Column 2 is a thin spacer between the oracle and inferred-selection halves.
axes[0, 2].axis('off')
axes[1, 2].axis('off')

oracle_tc_coupling_ax = axes[0, 0]
oracle_tm_coupling_ax = axes[0, 1]
oracle_tc_tuning_ax = axes[1, 0]
oracle_tm_tuning_ax = axes[1, 1]
sel_tc_coupling_ax = axes[0, 3]
sel_tm_coupling_ax = axes[0, 4]
sel_tc_tuning_ax = axes[1, 3]
sel_tm_tuning_ax = axes[1, 4]

# Shared color scales: coupling biases use a sequential map on [0, 0.75];
# tuning biases use a diverging map symmetric about zero.
coupling_color_kw = dict(vmin=0, vmax=0.75, cmap='Greys')
tuning_color_kw = dict(vmin=-3, vmax=3, cmap='RdGy')

# One entry per heatmap panel, in the order the panels receive subplot labels.
panels = [
    (oracle_tc_coupling_ax, a_bias_agg_tc, coupling_color_kw),
    (oracle_tm_coupling_ax, a_bias_agg_tm, coupling_color_kw),
    (oracle_tc_tuning_ax, b_bias_agg_tc, tuning_color_kw),
    (oracle_tm_tuning_ax, b_bias_agg_tm, tuning_color_kw),
    (sel_tc_coupling_ax, a_bias_agg_sel_tc, coupling_color_kw),
    (sel_tm_coupling_ax, a_bias_agg_sel_tm, coupling_color_kw),
    (sel_tc_tuning_ax, b_bias_agg_sel_tc, tuning_color_kw),
    (sel_tm_tuning_ax, b_bias_agg_sel_tm, tuning_color_kw),
]
panel_axes = [ax for ax, _, _ in panels]

# Derive tick labels from the hyperparameter grids rather than a hard-coded
# list, so they stay correct if a grid changes. imshow draws array rows on y
# and columns on x, so y ticks follow coupling_locs and x ticks follow
# tuning_locs. Both axes are set from explicit lists rather than by copying
# the x labels back onto y.
tuning_ticklabels = [f'${loc:.1f}$' for loc in tuning_locs]
coupling_ticklabels = [f'${loc:.1f}$' for loc in coupling_locs]

images, colorbars = [], []
for ax, bias, color_kw in panels:
    img = ax.imshow(bias, origin='lower', **color_kw)
    cb, _ = mplego.colorbar.append_colorbar_to_axis(ax, img)
    cb.ax.tick_params(labelsize=15)
    images.append(img)
    colorbars.append(cb)

    ax.tick_params(labelsize=15)
    ax.set_xticks(np.arange(n_tuning_locs))
    ax.set_yticks(np.arange(n_coupling_locs))
    ax.set_xticklabels(tuning_ticklabels)
    ax.set_yticklabels(coupling_ticklabels)
    ax.set_xlabel(r'\textbf{Tuning Mean}', fontsize=15)
    ax.set_ylabel(r'\textbf{Coupling Mean}', fontsize=15)

for ax in (oracle_tc_coupling_ax, oracle_tc_tuning_ax,
           sel_tc_coupling_ax, sel_tc_tuning_ax):
    ax.set_title(mplego.labels.bold_text('TC Model'), fontsize=20)

for ax in (oracle_tm_coupling_ax, oracle_tm_tuning_ax,
           sel_tm_coupling_ax, sel_tm_tuning_ax):
    ax.set_title(mplego.labels.bold_text('Triangular Model'), fontsize=20)

mplego.labels.apply_subplot_labels(panel_axes, bold=True, size=19)

# Row headers, drawn to the left of the first column in axes coordinates.
for row_ax, row_label in ((axes[0, 0], 'Coupling Parameters'),
                          (axes[1, 0], 'Tuning Parameters')):
    row_ax.text(x=-0.4, y=0.5, va='center', ha='center',
                s=mplego.labels.bold_text(row_label),
                rotation=90,
                fontsize=22,
                transform=row_ax.transAxes)

# Column-group headers, in figure coordinates.
for x_pos, group_label in ((0.28, 'Oracle Selection'),
                           (0.72, 'Inferred Selection')):
    fig.text(
        x=x_pos, y=0.95,
        va='center', ha='center',
        s=mplego.labels.bold_text(group_label),
        fontsize=24)

fig.savefig('figure3.pdf', bbox_inches='tight')