In [None]:
import h5py
import matplotlib.pyplot as plt
import numpy as np
import os

from pratplot import colorbar

%matplotlib inline

In [None]:
# Directory holding the exp9 result files. Set the NEUROBIASES_FITS_PATH
# environment variable to point at another location; otherwise the original
# mounted volume is used.
base_path = os.environ.get('NEUROBIASES_FITS_PATH', '/Volumes/pss/fits/neurobiases')
em_path = os.path.join(base_path, 'exp9_em_oracle.h5')  # EM results file
tc_path = os.path.join(base_path, 'exp9_tc_oracle.h5')  # TC results file

In [None]:
# Open both result files read-only (EM first, then TC). The handles are left
# open on purpose: later cells keep reading datasets from them.
em_results, tc_results = (h5py.File(path, 'r') for path in (em_path, tc_path))

In [None]:
# Ground-truth coupling (a) and tuning (b) parameters from the TC results file,
# kept as h5py dataset handles (nothing is read into memory here).
# NOTE(review): these names are not used by the cells below, which re-read the
# '*_true' datasets directly.
a_true, b_true = tc_results['a_true'], tc_results['b_true']

In [None]:
# Bias of every TC estimate: estimate minus ground truth, elementwise.
# NOTE(review): from the axis comments below, the arrays appear to be laid out
# as (tuning grid, coupling grid, models, datasets, folds, parameters) —
# confirm against the script that wrote the file.
a_est_tc = tc_results['a_est'][:]
b_est_tc = tc_results['b_est'][:]
a_bias_tc = a_est_tc - tc_results['a_true'][:]
b_bias_tc = b_est_tc - tc_results['b_true'][:]
# A zero *estimate* marks a parameter that was not estimated, so exclude those
# entries. Masking on the estimate (instead of on a zero bias) matches the
# variance cells and never drops an estimate that happens to equal the truth.
a_bias_tc[a_est_tc == 0] = np.nan
b_bias_tc[b_est_tc == 0] = np.nan

# Collapse to one value per grid point:
#   mean over datasets (axis 3), mean over folds (axis 3 once datasets are
#   gone), NaN-median over parameters (axis 3 again), median over models (axis 2).
# np.mean propagates NaN, so a parameter unestimated in any dataset or fold is
# dropped entirely by the NaN-median over parameters.
a_bias_tc_agg = np.median(  # Median across models
    np.nanmedian(  # NaN median across parameters
        np.mean(  # Mean across folds
            np.mean(a_bias_tc, axis=3),  # Mean across datasets
            axis=3),
        axis=3),
    axis=2)
b_bias_tc_agg = np.median(  # Median across models
    np.nanmedian(  # NaN median across parameters
        np.mean(  # Mean across folds
            np.mean(b_bias_tc, axis=3),  # Mean across datasets
            axis=3),
        axis=3),
    axis=2)

In [None]:
# Bias of every EM estimate: estimate minus ground truth, elementwise.
# NOTE(review): from the axis comments below, the arrays appear to be laid out
# as (tuning grid, coupling grid, models, datasets, folds, parameters) —
# confirm against the script that wrote the file.
a_est_em = em_results['a_est'][:]
b_est_em = em_results['b_est'][:]
a_bias_em = a_est_em - em_results['a_true'][:]
b_bias_em = b_est_em - em_results['b_true'][:]
# A zero *estimate* marks a parameter that was not estimated, so exclude those
# entries. Masking on the estimate (instead of on a zero bias) matches the
# variance cells and never drops an estimate that happens to equal the truth.
a_bias_em[a_est_em == 0] = np.nan
b_bias_em[b_est_em == 0] = np.nan

# Collapse to one value per grid point:
#   mean over datasets (axis 3), mean over folds (axis 3 once datasets are
#   gone), NaN-median over parameters (axis 3 again), median over models (axis 2).
# np.mean propagates NaN, so a parameter unestimated in any dataset or fold is
# dropped entirely by the NaN-median over parameters.
a_bias_em_agg = np.median(  # Median across models
    np.nanmedian(  # NaN median across parameters
        np.mean(  # Mean across folds
            np.mean(a_bias_em, axis=3),  # Mean across datasets
            axis=3),
        axis=3),
    axis=2)
b_bias_em_agg = np.median(  # Median across models
    np.nanmedian(  # NaN median across parameters
        np.mean(  # Mean across folds
            np.mean(b_bias_em, axis=3),  # Mean across datasets
            axis=3),
        axis=3),
    axis=2)

In [None]:
# Aggregated TC bias over the (tuning mean, coupling mean) grid:
# coupling parameters (a) on the left, tuning parameters (b) on the right.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
fig.subplots_adjust(wspace=0.75)

# interpolation='none' draws one flat cell per grid point; interpolation=None
# would instead fall back to rcParams['image.interpolation'].
img1 = axes[0].imshow(
    a_bias_tc_agg,
    vmin=0,
    vmax=0.75,
    cmap='Greys',
    interpolation='none',
    origin='lower')
img2 = axes[1].imshow(
    b_bias_tc_agg,
    vmin=-9,
    vmax=9,
    cmap='RdGy',
    interpolation='none',
    origin='lower')

cb1, _ = colorbar.append_colorbar_to_axis(axes[0], img1)
cb2, _ = colorbar.append_colorbar_to_axis(axes[1], img2)

# Label every third grid point. Tick positions come from the length of each
# location array, so ticks and labels stay aligned for any grid size.
# NOTE(review): grid values are read from the EM file; this assumes the TC and
# EM fits share the same tuning/coupling grid — confirm.
tuning_grid = em_results['tuning_locs'][:]
coupling_grid = em_results['coupling_locs'][:]
y_ticks = np.arange(0, tuning_grid.size, 3)
x_ticks = np.arange(0, coupling_grid.size, 3)
# NOTE(review): \textbf{} only renders as bold when rcParams['text.usetex'] is
# enabled, which no visible cell sets — confirm it is configured elsewhere.
for ax in axes:
    ax.tick_params(labelsize=15)
    ax.set_xticks(x_ticks)
    ax.set_yticks(y_ticks)
    ax.set_yticklabels(np.round(tuning_grid[y_ticks], decimals=1))
    ax.set_xticklabels(np.round(coupling_grid[x_ticks], decimals=1))
    ax.set_ylabel(r'\textbf{Tuning Mean}', fontsize=17)
    ax.set_xlabel(r'\textbf{Coupling Mean}', fontsize=17)

cb1.ax.tick_params(labelsize=15)
cb1.set_label(r'\textbf{Coupling Bias}', fontsize=15, rotation=270, labelpad=20)
cb2.ax.tick_params(labelsize=15)
cb2.set_label(r'\textbf{Tuning Bias}', fontsize=15, rotation=270, labelpad=20)
fig.savefig('exp9_oracle_tc_bias.pdf', bbox_inches='tight')

In [None]:
# Aggregated EM bias over the (tuning mean, coupling mean) grid:
# coupling parameters (a) on the left, tuning parameters (b) on the right.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
fig.subplots_adjust(wspace=0.75)

# interpolation='none' draws one flat cell per grid point; interpolation=None
# would instead fall back to rcParams['image.interpolation'].
img1 = axes[0].imshow(
    a_bias_em_agg,
    vmin=0,
    vmax=0.05,
    cmap='Greys',
    interpolation='none',
    origin='lower')
img2 = axes[1].imshow(
    b_bias_em_agg,
    vmin=-0.2,
    vmax=0.2,
    cmap='RdGy',
    interpolation='none',
    origin='lower')

cb1, _ = colorbar.append_colorbar_to_axis(axes[0], img1)
cb2, _ = colorbar.append_colorbar_to_axis(axes[1], img2)

# Label every third grid point. Tick positions come from the length of each
# location array, so ticks and labels stay aligned for any grid size.
tuning_grid = em_results['tuning_locs'][:]
coupling_grid = em_results['coupling_locs'][:]
y_ticks = np.arange(0, tuning_grid.size, 3)
x_ticks = np.arange(0, coupling_grid.size, 3)
# NOTE(review): \textbf{} only renders as bold when rcParams['text.usetex'] is
# enabled, which no visible cell sets — confirm it is configured elsewhere.
for ax in axes:
    ax.tick_params(labelsize=15)
    ax.set_xticks(x_ticks)
    ax.set_yticks(y_ticks)
    ax.set_yticklabels(np.round(tuning_grid[y_ticks], decimals=1))
    ax.set_xticklabels(np.round(coupling_grid[x_ticks], decimals=1))
    ax.set_ylabel(r'\textbf{Tuning Mean}', fontsize=17)
    ax.set_xlabel(r'\textbf{Coupling Mean}', fontsize=17)

cb1.ax.tick_params(labelsize=15)
cb1.set_label(r'\textbf{Coupling Bias}', fontsize=15, rotation=270, labelpad=20)
cb2.ax.tick_params(labelsize=15)
cb2.set_label(r'\textbf{Tuning Bias}', fontsize=15, rotation=270, labelpad=20)
fig.savefig('exp9_oracle_em_bias.pdf', bbox_inches='tight')

# Variance

In [None]:
# In-memory copies of the TC estimates in which zero entries (parameters that
# were not estimated) are replaced by NaN so NaN-aware reductions skip them.
a_est_tc_nan = tc_results['a_est'][:]
b_est_tc_nan = tc_results['b_est'][:]
for est_nan in (a_est_tc_nan, b_est_tc_nan):
    est_nan[est_nan == 0] = np.nan
del est_nan

In [None]:
# In-memory copies of the EM estimates in which zero entries (parameters that
# were not estimated) are replaced by NaN so NaN-aware reductions skip them.
a_est_em_nan = em_results['a_est'][:]
b_est_em_nan = em_results['b_est'][:]
for est_nan in (a_est_em_nan, b_est_em_nan):
    est_nan[est_nan == 0] = np.nan
del est_nan

In [None]:
# Aggregate the across-dataset *variance* of the TC estimates (np.var default,
# ddof=0) over folds, parameters, and models, leaving one value per grid point
# of the 2-D map plotted below.
# np.var and np.mean propagate NaN, so a parameter unestimated in any dataset
# or fold is dropped entirely by the NaN-median over parameters.
a_var_tc_agg = np.median(  # Median across models
    np.nanmedian(  # NaN median across parameters
        np.mean(  # Mean across folds
            np.var(a_est_tc_nan, axis=3),  # Variance across datasets
            axis=3),
        axis=3),
    axis=2)
b_var_tc_agg = np.median(  # Median across models
    np.nanmedian(  # NaN median across parameters
        np.mean(  # Mean across folds
            np.var(b_est_tc_nan, axis=3),  # Variance across datasets
            axis=3),
        axis=3),
    axis=2)

In [None]:
# Aggregate the across-dataset *variance* of the EM estimates (np.var default,
# ddof=0) over folds, parameters, and models, leaving one value per grid point
# of the 2-D map plotted below.
# np.var and np.mean propagate NaN, so a parameter unestimated in any dataset
# or fold is dropped entirely by the NaN-median over parameters.
a_var_em_agg = np.median(  # Median across models
    np.nanmedian(  # NaN median across parameters
        np.mean(  # Mean across folds
            np.var(a_est_em_nan, axis=3),  # Variance across datasets
            axis=3),
        axis=3),
    axis=2)
b_var_em_agg = np.median(  # Median across models
    np.nanmedian(  # NaN median across parameters
        np.mean(  # Mean across folds
            np.var(b_est_em_nan, axis=3),  # Variance across datasets
            axis=3),
        axis=3),
    axis=2)

In [None]:
# Largest aggregated EM tuning variance (presumably inspected to choose vmax for
# the variance plots). np.nanmax skips grid points that aggregated to NaN;
# ndarray.max() would return nan as soon as a single such point exists.
np.nanmax(b_var_em_agg)

In [None]:
# Aggregated TC estimator variance over the (tuning mean, coupling mean) grid:
# coupling parameters (a) on the left, tuning parameters (b) on the right.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
fig.subplots_adjust(wspace=0.75)

# Both panels use origin='lower' so the Tuning Mean axis runs bottom-to-top in
# each, matching the shared tick labels. interpolation='none' draws one flat
# cell per grid point (None would fall back to rcParams).
img1 = axes[0].imshow(
    a_var_tc_agg,
    vmin=0,
    vmax=0.12,
    cmap='Greys',
    interpolation='none',
    origin='lower')
img2 = axes[1].imshow(
    b_var_tc_agg,
    vmin=0,
    vmax=2.5,
    cmap='Greys',
    interpolation='none',
    origin='lower')

cb1, _ = colorbar.append_colorbar_to_axis(axes[0], img1)
cb2, _ = colorbar.append_colorbar_to_axis(axes[1], img2)

# Label every third grid point. Tick positions come from the length of each
# location array, so ticks and labels stay aligned for any grid size.
# NOTE(review): grid values are read from the EM file; this assumes the TC and
# EM fits share the same tuning/coupling grid — confirm.
tuning_grid = em_results['tuning_locs'][:]
coupling_grid = em_results['coupling_locs'][:]
y_ticks = np.arange(0, tuning_grid.size, 3)
x_ticks = np.arange(0, coupling_grid.size, 3)
# NOTE(review): \textbf{} only renders as bold when rcParams['text.usetex'] is
# enabled, which no visible cell sets — confirm it is configured elsewhere.
for ax in axes:
    ax.tick_params(labelsize=15)
    ax.set_xticks(x_ticks)
    ax.set_yticks(y_ticks)
    ax.set_yticklabels(np.round(tuning_grid[y_ticks], decimals=1))
    ax.set_xticklabels(np.round(coupling_grid[x_ticks], decimals=1))
    ax.set_ylabel(r'\textbf{Tuning Mean}', fontsize=17)
    ax.set_xlabel(r'\textbf{Coupling Mean}', fontsize=17)

cb1.ax.tick_params(labelsize=15)
cb1.set_label(r'\textbf{Coupling Variance}', fontsize=15, rotation=270, labelpad=20)
cb2.ax.tick_params(labelsize=15)
cb2.set_label(r'\textbf{Tuning Variance}', fontsize=15, rotation=270, labelpad=20)
fig.savefig('exp9_oracle_tc_var.pdf', bbox_inches='tight')

In [None]:
# Aggregated EM estimator variance over the (tuning mean, coupling mean) grid:
# coupling parameters (a) on the left, tuning parameters (b) on the right.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
fig.subplots_adjust(wspace=0.75)

# Both panels use origin='lower' so the Tuning Mean axis runs bottom-to-top in
# each, matching the shared tick labels. interpolation='none' draws one flat
# cell per grid point (None would fall back to rcParams).
img1 = axes[0].imshow(
    a_var_em_agg,
    vmin=0,
    vmax=0.12,
    cmap='Greys',
    interpolation='none',
    origin='lower')
img2 = axes[1].imshow(
    b_var_em_agg,
    vmin=0,
    vmax=2.5,
    cmap='Greys',
    interpolation='none',
    origin='lower')

cb1, _ = colorbar.append_colorbar_to_axis(axes[0], img1)
cb2, _ = colorbar.append_colorbar_to_axis(axes[1], img2)

# Label every third grid point. Tick positions come from the length of each
# location array, so ticks and labels stay aligned for any grid size.
tuning_grid = em_results['tuning_locs'][:]
coupling_grid = em_results['coupling_locs'][:]
y_ticks = np.arange(0, tuning_grid.size, 3)
x_ticks = np.arange(0, coupling_grid.size, 3)
# NOTE(review): \textbf{} only renders as bold when rcParams['text.usetex'] is
# enabled, which no visible cell sets — confirm it is configured elsewhere.
for ax in axes:
    ax.tick_params(labelsize=15)
    ax.set_xticks(x_ticks)
    ax.set_yticks(y_ticks)
    ax.set_yticklabels(np.round(tuning_grid[y_ticks], decimals=1))
    ax.set_xticklabels(np.round(coupling_grid[x_ticks], decimals=1))
    ax.set_ylabel(r'\textbf{Tuning Mean}', fontsize=17)
    ax.set_xlabel(r'\textbf{Coupling Mean}', fontsize=17)

cb1.ax.tick_params(labelsize=15)
cb1.set_label(r'\textbf{Coupling Variance}', fontsize=15, rotation=270, labelpad=20)
cb2.ax.tick_params(labelsize=15)
cb2.set_label(r'\textbf{Tuning Variance}', fontsize=15, rotation=270, labelpad=20)
fig.savefig('exp9_oracle_em_var.pdf', bbox_inches='tight')