In [None]:
import h5py
import matplotlib.pyplot as plt
import mpl_lego as mplego
import neuropacks
import numpy as np
import os

from neurobiases.utils import selection_accuracy

%matplotlib inline

In [None]:
# Apply mpl_lego's LaTeX plotting style. Presumably this enables LaTeX text rendering,
# which the r'\textbf{...}' labels and titles in the figures below need to render as
# bold text (TODO confirm against mpl_lego).
mplego.style.use_latex_style()

In [None]:
# Experiment id and the directory holding its per-fit result files
# (exp{exp}_{coupling idx}_{tuning idx}_{model idx}_{dataset idx}.h5).
# The fits root can be overridden with the NEUROBIASES_FITS_DIR environment variable so
# the notebook runs on other machines without editing this cell; the default is the
# original hard-coded location.
exp = 22
fits_root = os.environ.get('NEUROBIASES_FITS_DIR', '/storage/fits/neurobiases')
base_path = os.path.join(fits_root, f'exp{exp}')

In [None]:
# Sweep grid of the simulations. Result files are indexed
# (coupling-mean index, tuning-mean index, model index, dataset index).

# Hyperparameter 1: coupling mean, swept evenly over [coupling_loc_min, coupling_loc_max]
coupling_loc_min, coupling_loc_max = -1, 1
n_coupling_locs = 5
coupling_locs = np.linspace(coupling_loc_min, coupling_loc_max, num=n_coupling_locs)

# Hyperparameter 2: tuning mean, swept evenly over [tuning_loc_min, tuning_loc_max]
tuning_loc_min, tuning_loc_max = -1, 1
n_tuning_locs = 5
tuning_locs = np.linspace(tuning_loc_min, tuning_loc_max, num=n_tuning_locs)

# Hyperparameters 3 and 4: number of models and of datasets per grid cell
n_models, n_datasets = 3, 10

In [None]:
# Problem dimensions.
#   N: length of the coupling-parameter vector a (trailing axis of the a_* arrays)
#   M: length of the tuning-parameter vector b (trailing axis of the b_* arrays)
# K, D and n_folds are not used in the cells below; presumably the latent dimension,
# the number of samples and the number of CV folds of the simulation (TODO confirm).
N, M = 10, 10
K, D = 1, 2000
n_folds = 3

In [None]:
# Peek at one result file to list the fit groups it contains (e.g. the ones read in the
# loading cells below). The original cell left this handle open for the whole session
# and hard-coded 'exp22'; use the configured experiment id and close the file as soon
# as the group names are collected.
sample_path = os.path.join(base_path, f'exp{exp}_0_0_0_0.h5')
with h5py.File(sample_path, 'r') as f:
    sample_groups = sorted(f.keys())
sample_groups

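# Oracle Selection

Bias (estimate minus ground truth, over the parameters that are nonzero in the ground truth) of the triangular model (`tm_oracle`) and tuning + coupling model (`tc_ols_oracle`) fits.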

In [None]:
# Load the ground truth and the oracle-selection fits for every cell of the sweep.
# Arrays are indexed (coupling idx, tuning idx, model idx, dataset idx, cv, parameter);
# the trailing axis has length N for coupling parameters a and M for tuning parameters b.
shape_tuple = (n_coupling_locs, n_tuning_locs, n_models, n_datasets)
n_cvs = 3
a_true, a_est_tm, a_est_tc = (np.zeros(shape_tuple + (n_cvs, N)) for _ in range(3))
b_true, b_est_tm, b_est_tc = (np.zeros(shape_tuple + (n_cvs, M)) for _ in range(3))

# np.ndindex walks the grid in row-major order, i.e. the same order as four nested loops.
for ci, ti, mi, di in np.ndindex(*shape_tuple):
    fname = f"exp{exp}_{ci}_{ti}_{mi}_{di}.h5"
    cell = (ci, ti, mi, di)
    with h5py.File(os.path.join(base_path, fname), 'r') as results:
        # Coupling parameters: truth, triangular-model fit, tuning + coupling fit
        a_true[cell] = results['a_true'][:]
        a_est_tm[cell] = results['tm_oracle/a_est'][:]
        a_est_tc[cell] = results['tc_ols_oracle/a_est'][:]
        # Tuning parameters: truth, triangular-model fit, tuning + coupling fit
        b_true[cell] = results['b_true'][:]
        b_est_tm[cell] = results['tm_oracle/b_est'][:]
        b_est_tc[cell] = results['tc_ols_oracle/b_est'][:]

In [None]:
# Restrict truth and estimates to the ground-truth support (entries where the true
# parameter is nonzero). Boolean masking flattens the selection, so each result is
# reshaped back to (coupling, tuning, model, dataset, cv, n_nonzero).
a_support = a_true != 0
b_support = b_true != 0

# That reshape is only valid if every trailing row has the same number of nonzeros.
# Otherwise it either fails or -- when the total happens to divide evenly -- silently
# regroups entries across rows, so check explicitly and fail with a clear message.
# An explicit count (instead of -1) also keeps an all-zero support reshapeable.
nz_shapes = {}
for label, support in (('a_true', a_support), ('b_true', b_support)):
    nz_counts = support.sum(axis=-1)
    if nz_counts.min() != nz_counts.max():
        raise ValueError(
            f'{label} has between {nz_counts.min()} and {nz_counts.max()} nonzero '
            'entries per row; the support cannot be stacked into a regular array.')
    nz_shapes[label] = support.shape[:-1] + (int(nz_counts.max()),)

a_true_nz = a_true[a_support].reshape(nz_shapes['a_true'])
a_est_tm_nz = a_est_tm[a_support].reshape(nz_shapes['a_true'])
a_est_tc_nz = a_est_tc[a_support].reshape(nz_shapes['a_true'])
b_true_nz = b_true[b_support].reshape(nz_shapes['b_true'])
b_est_tm_nz = b_est_tm[b_support].reshape(nz_shapes['b_true'])
b_est_tc_nz = b_est_tc[b_support].reshape(nz_shapes['b_true'])

In [None]:
# Signed error (estimate - truth) on the true support, averaged over datasets (axis 3).
# Resulting axes: (coupling, tuning, model, cv, parameter).
a_bias_tm_all, a_bias_tc_all = (
    (est - a_true_nz).mean(axis=3) for est in (a_est_tm_nz, a_est_tc_nz))
b_bias_tm_all, b_bias_tc_all = (
    (est - b_true_nz).mean(axis=3) for est in (b_est_tm_nz, b_est_tc_nz))

In [None]:
# Per-entry bias of the triangular model (y) against the tuning + coupling model (x);
# each point is one (grid cell, model, cv, parameter) entry, averaged over datasets.
# Left panel: coupling parameters a. Right panel: tuning parameters b.
scatter_panels = (
    (a_bias_tc_all, a_bias_tm_all, [-1, 1], [-1, -0.5, 0, 0.5, 1.],
     r'\textbf{Coupling Parameters}'),
    (b_bias_tc_all, b_bias_tm_all, [-3, 3], [-3, -2, -1, 0, 1, 2, 3.],
     r'\textbf{Tuning Parameters}'),
)

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for ax, (x_bias, y_bias, lim, ticks, title) in zip(axes, scatter_panels):
    ax.scatter(x_bias.ravel(), y_bias.ravel(), color='black', alpha=0.1)
    # Dashed zero lines mark an unbiased estimate for each model.
    ax.axhline(0, color='gray', linestyle='--')
    ax.axvline(0, color='gray', linestyle='--')
    ax.set_xlabel(r'\textbf{Tuning + Coupling Model}', fontsize=15)
    ax.set_ylabel(r'\textbf{Triangular Model}', fontsize=15)
    mplego.scatter.tighten_scatter_plot(ax, lim=lim, color='gray')
    ax.set_xticks(ticks)
    ax.set_yticks(ax.get_xticks())
    ax.set_title(title, fontsize=18)

plt.show()

In [None]:
# Collapse each error array from (coupling, tuning, model, dataset, cv, parameter)
# down to a (coupling, tuning) grid:
#   median over cv (axis 4) -> median over parameters (now axis 4)
#   -> mean over datasets (axis 3) -> median over models (axis 2).
error_pairs = (
    (a_est_tm_nz, a_true_nz),
    (a_est_tc_nz, a_true_nz),
    (b_est_tm_nz, b_true_nz),
    (b_est_tc_nz, b_true_nz),
)
a_bias_agg_tm, a_bias_agg_tc, b_bias_agg_tm, b_bias_agg_tc = [
    np.median(
        np.median(np.median(est - truth, axis=4), axis=4).mean(axis=3),
        axis=2)
    for est, truth in error_pairs
]

In [None]:
# Heatmaps of the aggregated coupling-parameter bias on the (coupling mean, tuning mean)
# grid. Array rows follow coupling_locs (y-axis, origin at the bottom) and columns
# follow tuning_locs (x-axis).
a_vmin, a_vmax = 0, 0.75
# Both panels share one colour scale, so they share one set of colorbar ticks, and the
# ticks must lie inside [vmin, vmax] -- ticks outside that range are never drawn.
a_cbar_ticks = np.linspace(a_vmin, a_vmax, 4)  # 0, 0.25, 0.5, 0.75
# NOTE(review): with vmin=0 on a sequential map, any negative bias is drawn the same
# colour as zero; confirm the coupling bias is expected to be non-negative.

heatmap_panels = (
    (a_bias_agg_tm, r'\textbf{Triangular Model}'),
    (a_bias_agg_tc, r'\textbf{Tuning + Coupling Model}'),
)

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
plt.subplots_adjust(wspace=0.7)

for ax, (bias_grid, title) in zip(axes, heatmap_panels):
    img = ax.imshow(bias_grid, origin='lower', vmin=a_vmin, vmax=a_vmax, cmap='Greys')
    cb, cax = mplego.colorbar.append_colorbar_to_axis(ax, img)
    cb.set_ticks(a_cbar_ticks)
    cax.tick_params(labelsize=12)
    cax.set_ylabel(r'\textbf{Bias}', fontsize=15, rotation=270, labelpad=15)

    ax.set_title(title, fontsize=16)
    ax.tick_params(labelsize=15)
    # Label the cells with the actual sweep values rather than a hard-coded linspace.
    ax.set_xticks(np.arange(n_tuning_locs))
    ax.set_yticks(np.arange(n_coupling_locs))
    ax.set_xticklabels(tuning_locs)
    ax.set_yticklabels(coupling_locs)
    ax.set_xlabel(r'\textbf{Tuning Mean}', fontsize=15)
    ax.set_ylabel(r'\textbf{Coupling Mean}', fontsize=15)

plt.show()

In [None]:
# Heatmaps of the aggregated tuning-parameter bias on the (coupling mean, tuning mean)
# grid. Array rows follow coupling_locs (y-axis, origin at the bottom) and columns
# follow tuning_locs (x-axis).
b_vmin, b_vmax = -2.5, 2.5
# Both panels share one symmetric colour scale, so they share one set of colorbar
# ticks spanning [vmin, vmax]; ticks outside that range are never drawn.
b_cbar_ticks = np.linspace(b_vmin, b_vmax, 5)  # -2.5, -1.25, 0, 1.25, 2.5

heatmap_panels = (
    (b_bias_agg_tm, r'\textbf{Triangular Model}'),
    (b_bias_agg_tc, r'\textbf{Tuning + Coupling Model}'),
)

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
plt.subplots_adjust(wspace=0.7)

for ax, (bias_grid, title) in zip(axes, heatmap_panels):
    img = ax.imshow(bias_grid, origin='lower', vmin=b_vmin, vmax=b_vmax, cmap='RdGy')
    cb, cax = mplego.colorbar.append_colorbar_to_axis(ax, img)
    cb.set_ticks(b_cbar_ticks)
    cax.tick_params(labelsize=12)
    cax.set_ylabel(r'\textbf{Bias}', fontsize=15, rotation=270, labelpad=15)

    ax.set_title(title, fontsize=16)
    ax.tick_params(labelsize=15)
    # Label the cells with the actual sweep values rather than a hard-coded linspace.
    ax.set_xticks(np.arange(n_tuning_locs))
    ax.set_yticks(np.arange(n_coupling_locs))
    ax.set_xticklabels(tuning_locs)
    ax.set_yticklabels(coupling_locs)
    ax.set_xlabel(r'\textbf{Tuning Mean}', fontsize=15)
    ax.set_ylabel(r'\textbf{Coupling Mean}', fontsize=15)

plt.show()

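# UoI T/C Selection

Same analysis as the oracle section, using the triangular model (`tm_t_c_uoi`) and tuning + coupling model (`tc_ols_t_c_uoi`) fits.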

In [None]:
# Load the ground truth and the UoI tuning/coupling-selection fits for every cell of the
# sweep. Arrays are indexed (coupling idx, tuning idx, model idx, dataset idx, cv,
# parameter); the trailing axis has length N for a and M for b.
shape_tuple = (n_coupling_locs, n_tuning_locs, n_models, n_datasets)
n_cvs = 3
a_true, a_est_tm, a_est_tc = (np.zeros(shape_tuple + (n_cvs, N)) for _ in range(3))
b_true, b_est_tm, b_est_tc = (np.zeros(shape_tuple + (n_cvs, M)) for _ in range(3))

# np.ndindex walks the grid in row-major order, i.e. the same order as four nested loops.
for ci, ti, mi, di in np.ndindex(*shape_tuple):
    fname = f"exp{exp}_{ci}_{ti}_{mi}_{di}.h5"
    cell = (ci, ti, mi, di)
    with h5py.File(os.path.join(base_path, fname), 'r') as results:
        # Coupling parameters: truth, triangular-model fit, tuning + coupling fit
        a_true[cell] = results['a_true'][:]
        a_est_tm[cell] = results['tm_t_c_uoi/a_est'][:]
        a_est_tc[cell] = results['tc_ols_t_c_uoi/a_est'][:]
        # Tuning parameters: truth, triangular-model fit, tuning + coupling fit
        b_true[cell] = results['b_true'][:]
        b_est_tm[cell] = results['tm_t_c_uoi/b_est'][:]
        b_est_tc[cell] = results['tc_ols_t_c_uoi/b_est'][:]

In [None]:
# Restrict truth and estimates to the ground-truth support (entries where the true
# parameter is nonzero). Boolean masking flattens the selection, so each result is
# reshaped back to (coupling, tuning, model, dataset, cv, n_nonzero).
a_support = a_true != 0
b_support = b_true != 0

# That reshape is only valid if every trailing row has the same number of nonzeros.
# Otherwise it either fails or -- when the total happens to divide evenly -- silently
# regroups entries across rows, so check explicitly and fail with a clear message.
# An explicit count (instead of -1) also keeps an all-zero support reshapeable.
nz_shapes = {}
for label, support in (('a_true', a_support), ('b_true', b_support)):
    nz_counts = support.sum(axis=-1)
    if nz_counts.min() != nz_counts.max():
        raise ValueError(
            f'{label} has between {nz_counts.min()} and {nz_counts.max()} nonzero '
            'entries per row; the support cannot be stacked into a regular array.')
    nz_shapes[label] = support.shape[:-1] + (int(nz_counts.max()),)

a_true_nz = a_true[a_support].reshape(nz_shapes['a_true'])
a_est_tm_nz = a_est_tm[a_support].reshape(nz_shapes['a_true'])
a_est_tc_nz = a_est_tc[a_support].reshape(nz_shapes['a_true'])
b_true_nz = b_true[b_support].reshape(nz_shapes['b_true'])
b_est_tm_nz = b_est_tm[b_support].reshape(nz_shapes['b_true'])
b_est_tc_nz = b_est_tc[b_support].reshape(nz_shapes['b_true'])

In [None]:
# Signed error (estimate - truth) on the true support, averaged over datasets (axis 3).
# Resulting axes: (coupling, tuning, model, cv, parameter).
a_bias_tm_all, a_bias_tc_all = (
    (est - a_true_nz).mean(axis=3) for est in (a_est_tm_nz, a_est_tc_nz))
b_bias_tm_all, b_bias_tc_all = (
    (est - b_true_nz).mean(axis=3) for est in (b_est_tm_nz, b_est_tc_nz))

In [None]:
# Per-entry bias of the triangular model (y) against the tuning + coupling model (x);
# each point is one (grid cell, model, cv, parameter) entry, averaged over datasets.
# Left panel: coupling parameters a. Right panel: tuning parameters b.
scatter_panels = (
    (a_bias_tc_all, a_bias_tm_all, [-1, 1], [-1, -0.5, 0, 0.5, 1.],
     r'\textbf{Coupling Parameters}'),
    (b_bias_tc_all, b_bias_tm_all, [-3, 3], [-3, -2, -1, 0, 1, 2, 3.],
     r'\textbf{Tuning Parameters}'),
)

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for ax, (x_bias, y_bias, lim, ticks, title) in zip(axes, scatter_panels):
    ax.scatter(x_bias.ravel(), y_bias.ravel(), color='black', alpha=0.1)
    # Dashed zero lines mark an unbiased estimate for each model.
    ax.axhline(0, color='gray', linestyle='--')
    ax.axvline(0, color='gray', linestyle='--')
    ax.set_xlabel(r'\textbf{Tuning + Coupling Model}', fontsize=15)
    ax.set_ylabel(r'\textbf{Triangular Model}', fontsize=15)
    mplego.scatter.tighten_scatter_plot(ax, lim=lim, color='gray')
    ax.set_xticks(ticks)
    ax.set_yticks(ax.get_xticks())
    ax.set_title(title, fontsize=18)

plt.show()

In [None]:
# Collapse each error array from (coupling, tuning, model, dataset, cv, parameter)
# down to a (coupling, tuning) grid:
#   median over cv (axis 4) -> median over parameters (now axis 4)
#   -> mean over datasets (axis 3) -> median over models (axis 2).
error_pairs = (
    (a_est_tm_nz, a_true_nz),
    (a_est_tc_nz, a_true_nz),
    (b_est_tm_nz, b_true_nz),
    (b_est_tc_nz, b_true_nz),
)
a_bias_agg_tm, a_bias_agg_tc, b_bias_agg_tm, b_bias_agg_tc = [
    np.median(
        np.median(np.median(est - truth, axis=4), axis=4).mean(axis=3),
        axis=2)
    for est, truth in error_pairs
]

In [None]:
# Heatmaps of the aggregated coupling-parameter bias on the (coupling mean, tuning mean)
# grid. Array rows follow coupling_locs (y-axis, origin at the bottom) and columns
# follow tuning_locs (x-axis).
a_vmin, a_vmax = 0, 0.75
# Both panels share one colour scale, so they share one set of colorbar ticks, and the
# ticks must lie inside [vmin, vmax] -- ticks outside that range are never drawn.
a_cbar_ticks = np.linspace(a_vmin, a_vmax, 4)  # 0, 0.25, 0.5, 0.75
# NOTE(review): with vmin=0 on a sequential map, any negative bias is drawn the same
# colour as zero; confirm the coupling bias is expected to be non-negative.

heatmap_panels = (
    (a_bias_agg_tm, r'\textbf{Triangular Model}'),
    (a_bias_agg_tc, r'\textbf{Tuning + Coupling Model}'),
)

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
plt.subplots_adjust(wspace=0.7)

for ax, (bias_grid, title) in zip(axes, heatmap_panels):
    img = ax.imshow(bias_grid, origin='lower', vmin=a_vmin, vmax=a_vmax, cmap='Greys')
    cb, cax = mplego.colorbar.append_colorbar_to_axis(ax, img)
    cb.set_ticks(a_cbar_ticks)
    cax.tick_params(labelsize=12)
    cax.set_ylabel(r'\textbf{Bias}', fontsize=15, rotation=270, labelpad=15)

    ax.set_title(title, fontsize=16)
    ax.tick_params(labelsize=15)
    # Label the cells with the actual sweep values rather than a hard-coded linspace.
    ax.set_xticks(np.arange(n_tuning_locs))
    ax.set_yticks(np.arange(n_coupling_locs))
    ax.set_xticklabels(tuning_locs)
    ax.set_yticklabels(coupling_locs)
    ax.set_xlabel(r'\textbf{Tuning Mean}', fontsize=15)
    ax.set_ylabel(r'\textbf{Coupling Mean}', fontsize=15)

plt.show()

In [None]:
# Heatmaps of the aggregated tuning-parameter bias on the (coupling mean, tuning mean)
# grid. Array rows follow coupling_locs (y-axis, origin at the bottom) and columns
# follow tuning_locs (x-axis).
b_vmin, b_vmax = -2.5, 2.5
# Both panels share one symmetric colour scale, so they share one set of colorbar
# ticks spanning [vmin, vmax]; ticks outside that range are never drawn.
b_cbar_ticks = np.linspace(b_vmin, b_vmax, 5)  # -2.5, -1.25, 0, 1.25, 2.5

heatmap_panels = (
    (b_bias_agg_tm, r'\textbf{Triangular Model}'),
    (b_bias_agg_tc, r'\textbf{Tuning + Coupling Model}'),
)

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
plt.subplots_adjust(wspace=0.7)

for ax, (bias_grid, title) in zip(axes, heatmap_panels):
    img = ax.imshow(bias_grid, origin='lower', vmin=b_vmin, vmax=b_vmax, cmap='RdGy')
    cb, cax = mplego.colorbar.append_colorbar_to_axis(ax, img)
    cb.set_ticks(b_cbar_ticks)
    cax.tick_params(labelsize=12)
    cax.set_ylabel(r'\textbf{Bias}', fontsize=15, rotation=270, labelpad=15)

    ax.set_title(title, fontsize=16)
    ax.tick_params(labelsize=15)
    # Label the cells with the actual sweep values rather than a hard-coded linspace.
    ax.set_xticks(np.arange(n_tuning_locs))
    ax.set_yticks(np.arange(n_coupling_locs))
    ax.set_xticklabels(tuning_locs)
    ax.set_yticklabels(coupling_locs)
    ax.set_xlabel(r'\textbf{Tuning Mean}', fontsize=15)
    ax.set_ylabel(r'\textbf{Coupling Mean}', fontsize=15)

plt.show()