In [None]:
import h5py
import matplotlib.pyplot as plt
import neuropacks
import numpy as np
import os
import mpl_lego.colorbar as colorbar

from neurobiases.utils import selection_accuracy, selection_accuracy_single
%matplotlib inline

In [None]:
# Render all figure text through LaTeX in a serif face; titles and labels below use \textbf{...}.
plt.rcParams['text.usetex'] = True
plt.rcParams['font.family'] = 'serif'

In [None]:
# Directory holding the exp15 fits. Defaults to the original hard-coded location; set the
# NEUROBIASES_EXP15_DIR environment variable to run the notebook against another copy.
base_path = os.environ.get('NEUROBIASES_EXP15_DIR', '/storage/fits/neurobiases/exp15')
exp15_path = os.path.join(base_path, "exp15.h5")

In [None]:
# Load ground truth and estimates from the HDF5 results file. The true arrays get a
# length-1 axis inserted at position 4 so they can be broadcast against the estimates.
with h5py.File(exp15_path, "r") as results:
    a_true, b_true = (
        np.expand_dims(results[key][:], axis=4) for key in ("a_true", "b_true")
    )
    a_est, b_est = (results[key][:] for key in ("a_est", "b_est"))

In [None]:
# Read-only broadcast views of the ground truth, shaped exactly like the matching estimates.
a_true_bcast, b_true_bcast = (
    np.broadcast_to(truth, estimate.shape)
    for truth, estimate in ((a_true, a_est), (b_true, b_est))
)

In [None]:
# Median of the estimates along axis 4, keeping that axis (length 1) so the result still broadcasts.
a_est_med, b_est_med = (
    np.median(estimate, axis=4, keepdims=True) for estimate in (a_est, b_est)
)

In [None]:
# Keep only the entries whose ground-truth value is non-zero, regrouped along the last axis.
# The reshape assumes every vector (last axis) has the same number of non-zero true entries;
# if not, it either raises or silently misaligns entries.
a_true_nz = a_true[a_true != 0].reshape(a_true.shape[:-1] + (-1,))
# Each estimate must be masked with its *own* ground truth: a_true for the coupling
# parameters (a), b_true for the tuning parameters (b).
a_est_nz = a_est[a_true_bcast != 0].reshape(a_est.shape[:-1] + (-1,))
b_true_nz = b_true[b_true != 0].reshape(b_true.shape[:-1] + (-1,))
b_est_nz = b_est[b_true_bcast != 0].reshape(b_est.shape[:-1] + (-1,))

# The true and estimated non-zero sets must line up entry-for-entry for the bias below.
assert a_est_nz.shape[-1] == a_true_nz.shape[-1]
assert b_est_nz.shape[-1] == b_true_nz.shape[-1]

# Selection Accuracy

In [None]:
# Selection accuracy of each estimate against its ground truth (a: coupling, b: tuning).
a_sel_acc, b_sel_acc = (
    selection_accuracy(truth, estimate)
    for truth, estimate in ((a_true, a_est), (b_true, b_est))
)

In [None]:
# Collapse axes 4, 3 and 2 one at a time with successive medians. This is a median of
# medians, which is not the same as a single joint median over those axes.
a_sel_acc_agg, b_sel_acc_agg = (
    np.median(np.median(np.median(acc, axis=4), axis=3), axis=2)
    for acc in (a_sel_acc, b_sel_acc)
)

In [None]:
# Selection-accuracy heatmaps: coupling parameters (a) on the left, tuning parameters (b) on the right.
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
plt.subplots_adjust(wspace=0.55)

# Both panels share a fixed [0, 1] grey scale.
img1, img2 = (
    panel.imshow(grid, origin='lower', vmin=0, vmax=1, cmap='Greys')
    for panel, grid in zip(axes, (a_sel_acc_agg, b_sel_acc_agg))
)
(cb1, cax1), (cb2, cax2) = (
    colorbar.append_colorbar_to_axis(panel, image)
    for panel, image in zip(axes, (img1, img2))
)

axes[0].set_title(r'\textbf{Coupling Parameters}', fontsize=16)
axes[1].set_title(r'\textbf{Tuning Parameters}', fontsize=16)

# Five ticks per axis, labelled -1, -0.5, 0, 0.5, 1; the y labels reuse the x labels.
for ax in axes:
    ax.tick_params(labelsize=15)
    ax.set_xticks(np.arange(5))
    ax.set_xticklabels(np.linspace(-1, 1, 5))
    ax.set_yticks(np.arange(5))
    ax.set_yticklabels(ax.get_xticklabels())
    ax.set_xlabel(r'\textbf{Tuning Mean}', fontsize=15)
    ax.set_ylabel(r'\textbf{Coupling Mean}', fontsize=15)

for cax in (cax1, cax2):
    cax.tick_params(labelsize=12)
    cax.set_ylabel(r'\textbf{Selection Accuracy}', fontsize=15, rotation=270, labelpad=15)

plt.show()

# Bias

In [None]:
# Signed error (estimate minus truth), over all entries and over the truly non-zero entries only.
a_bias, a_bias_nz, b_bias, b_bias_nz = (
    estimate - truth
    for estimate, truth in (
        (a_est, a_true),
        (a_est_nz, a_true_nz),
        (b_est, b_true),
        (b_est_nz, b_true_nz),
    )
)

In [None]:
# Aggregate bias: median over axis 4, then mean over axis 3, then median over axis 2.
a_bias_agg, b_bias_agg = (
    np.median(np.mean(np.median(bias, axis=4), axis=3), axis=2)
    for bias in (a_bias, b_bias)
)

In [None]:
# Same reduction as for the full bias (median axis 4, mean axis 3, median axis 2), followed by
# a final median over the last axis, i.e. across the non-zero entries.
a_bias_nz_agg, b_bias_nz_agg = (
    np.median(np.median(np.mean(np.median(bias, axis=4), axis=3), axis=2), axis=-1)
    for bias in (a_bias_nz, b_bias_nz)
)

In [None]:
# Bias heatmaps over the non-zero entries: coupling (a) on the left, tuning (b) on the right.
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
plt.subplots_adjust(wspace=0.7)

# Each panel uses a colour range symmetric about zero: +/-0.5 for coupling, +/-1 for tuning.
# NOTE(review): a diverging colormap might make the sign of the bias easier to read than 'Greys'.
img1, img2 = (
    panel.imshow(grid, origin='lower', vmin=-limit, vmax=limit, cmap='Greys')
    for panel, grid, limit in zip(axes, (a_bias_nz_agg, b_bias_nz_agg), (0.5, 1))
)
(cb1, cax1), (cb2, cax2) = (
    colorbar.append_colorbar_to_axis(panel, image)
    for panel, image in zip(axes, (img1, img2))
)

axes[0].set_title(r'\textbf{Coupling Parameters}', fontsize=16)
axes[1].set_title(r'\textbf{Tuning Parameters}', fontsize=16)

# Five ticks per axis, labelled -1, -0.5, 0, 0.5, 1; the y labels reuse the x labels.
for ax in axes:
    ax.tick_params(labelsize=15)
    ax.set_xticks(np.arange(5))
    ax.set_xticklabels(np.linspace(-1, 1, 5))
    ax.set_yticks(np.arange(5))
    ax.set_yticklabels(ax.get_xticklabels())
    ax.set_xlabel(r'\textbf{Tuning Mean}', fontsize=15)
    ax.set_ylabel(r'\textbf{Coupling Mean}', fontsize=15)

# Colorbar ticks at the ends, quarter points and centre of each panel's range.
cb1.set_ticks([-0.5, -0.25, 0., 0.25, 0.5])
cb2.set_ticks([-1, -0.5, 0., 0.5, 1])

for cax in (cax1, cax2):
    cax.tick_params(labelsize=12)
    cax.set_ylabel(r'\textbf{Bias}', fontsize=15, rotation=270, labelpad=15)

plt.show()

# Variance

In [None]:
# Variance of the estimates along axis 3.
# NOTE(review): the true arrays got their extra (estimate-only) axis at position 4, and the
# median cells reduce axis 4 first; confirm axis 3 rather than 4 is the intended replicate axis.
a_var, b_var = (np.var(estimate, axis=3) for estimate in (a_est, b_est))

In [None]:
# Successive medians over axis 3, then axis 2 twice. Each median removes an axis, so the
# second axis=2 collapses the axis that shifted into position 2 after the first one.
a_var_agg, b_var_agg = (
    np.median(np.median(np.median(var, axis=3), axis=2), axis=2)
    for var in (a_var, b_var)
)

In [None]:
# Variance heatmaps: coupling parameters (a) on the left, tuning parameters (b) on the right.
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
plt.subplots_adjust(wspace=0.7)
img1 = axes[0].imshow(
    a_var_agg,
    origin='lower',
    vmin=0,
    vmax=0.05,
    cmap='Greys')
cb1, cax1 = colorbar.append_colorbar_to_axis(axes[0], img1)

# The tuning panel must show the tuning-parameter *variance* to match the "Variance" colorbar label
# (it previously plotted the bias array b_bias_nz_agg).
# NOTE(review): vmax=0.5 may need re-tuning now that variance (not bias) is shown here.
img2 = axes[1].imshow(
    b_var_agg,
    origin='lower',
    vmin=0,
    vmax=0.5,
    cmap='Greys')
cb2, cax2 = colorbar.append_colorbar_to_axis(axes[1], img2)

axes[0].set_title(r'\textbf{Coupling Parameters}', fontsize=16)
axes[1].set_title(r'\textbf{Tuning Parameters}', fontsize=16)

# Five ticks per axis, labelled -1, -0.5, 0, 0.5, 1; the y labels reuse the x labels.
for ax in axes:
    ax.tick_params(labelsize=15)
    ax.set_xticks(np.arange(5))
    ax.set_yticks(np.arange(5))
    ax.set_xticklabels(np.linspace(-1, 1, 5))
    ax.set_yticklabels(ax.get_xticklabels())
    ax.set_xlabel(r'\textbf{Tuning Mean}', fontsize=15)
    ax.set_ylabel(r'\textbf{Coupling Mean}', fontsize=15)

for cax in [cax1, cax2]:
    cax.tick_params(labelsize=12)
    cax.set_ylabel(r'\textbf{Variance}', fontsize=15, rotation=270, labelpad=15)

plt.show()