In [None]:
import h5py
import matplotlib.pyplot as plt
import mpl_lego as mplego
import neuropacks
import numpy as np
import os

from mpl_lego.colorbar import append_colorbar_to_axis
from mpl_lego.labels import bold_text
from pyprojroot import here

%matplotlib inline

In [None]:
# Apply mpl_lego's LaTeX plotting style before any figure is created.
# NOTE(review): presumably this enables LaTeX text rendering, which the figure
# cell relies on for its `$...$` tick labels and `bold_text` strings — confirm
# against the mpl_lego documentation.
mplego.style.use_latex_style()

In [None]:
# Experiment identifier; it is also embedded in every fit file name
# (`exp{exp}_{ii}_{jj}_{kk}_{ll}.h5`) read by the loading cells.
exp = 22
# Directory holding this experiment's fit files, resolved with pyprojroot's `here`.
# Built from `exp` so the folder and the file names can never point at
# different experiments when the identifier is changed.
base_path = here(f'fits/exp{exp}')

In [None]:
# Hyperparameter 1: grid of coupling locations (the "Coupling Mean" axis of Figure 3)
coupling_loc_min, coupling_loc_max = -1, 1
n_coupling_locs = 5
coupling_locs = np.linspace(
    start=coupling_loc_min,
    stop=coupling_loc_max,
    num=n_coupling_locs)

# Hyperparameter 2: grid of tuning locations (the "Tuning Mean" axis of Figure 3)
tuning_loc_min, tuning_loc_max = -1, 1
n_tuning_locs = 5
tuning_locs = np.linspace(
    start=tuning_loc_min,
    stop=tuning_loc_max,
    num=n_tuning_locs)

# Hyperparameter 3: number of models per grid point (third index in the fit file names)
n_models = 3

# Hyperparameter 4: number of datasets per model (fourth index in the fit file names)
n_datasets = 10

In [None]:
# Model settings
# N and M size the last axis of the coupling (a_*) and tuning (b_*) arrays
# allocated by the loading cells.
N, M = 10, 10
# K, D and n_folds are not referenced by any other cell visible in this notebook.
# NOTE(review): presumably they record the settings the fits were generated with — confirm.
K, D = 1, 2000
n_folds = 3

In [None]:
# Gather ground truth and oracle-selection estimates from every fit file into
# dense arrays with axes (coupling loc, tuning loc, model, dataset, n_cvs, parameter).
shape_tuple = (n_coupling_locs, n_tuning_locs, n_models, n_datasets)
n_cvs = 3
coupling_shape = shape_tuple + (n_cvs, N)
tuning_shape = shape_tuple + (n_cvs, M)

# Coupling parameters (last axis has length N)
a_true = np.zeros(coupling_shape)
a_est_cotula = np.zeros(coupling_shape)
a_est_cotu = np.zeros(coupling_shape)
# Tuning parameters (last axis has length M)
b_true = np.zeros(tuning_shape)
b_est_cotula = np.zeros(tuning_shape)
b_est_cotu = np.zeros(tuning_shape)

# HDF5 dataset name -> destination array, listed in read order
oracle_sources = (
    ('a_true', a_true),
    ('tm_oracle/a_est', a_est_cotula),
    ('tc_ols_oracle/a_est', a_est_cotu),
    ('b_true', b_true),
    ('tm_oracle/b_est', b_est_cotula),
    ('tc_ols_oracle/b_est', b_est_cotu),
)

# One fit file per (coupling loc, tuning loc, model, dataset) index combination;
# np.ndindex walks them in the same order as four nested loops would.
for ii, jj, kk, ll in np.ndindex(*shape_tuple):
    file = f"exp{exp}_{ii}_{jj}_{kk}_{ll}.h5"
    path = os.path.join(base_path, file)

    with h5py.File(path, 'r') as results:
        for key, dest in oracle_sources:
            dest[ii, jj, kk, ll] = results[key][:]

In [None]:
# Restrict truth and estimates to the entries whose ground-truth value is non-zero.
# Boolean masking flattens the array, so each result is reshaped back to its
# leading axes with the surviving parameters on the last axis. This relies on
# every slice along the last axis having the same number of non-zero true values.
a_support = a_true != 0
b_support = b_true != 0

a_true_nz, a_est_cotula_nz, a_est_cotu_nz = (
    arr[a_support].reshape(arr.shape[:-1] + (-1,))
    for arr in (a_true, a_est_cotula, a_est_cotu)
)
b_true_nz, b_est_cotula_nz, b_est_cotu_nz = (
    arr[b_support].reshape(arr.shape[:-1] + (-1,))
    for arr in (b_true, b_est_cotula, b_est_cotu)
)

In [None]:
# Per-parameter bias under oracle selection: signed error (estimate - truth)
# averaged over axis 3, the dataset axis of `shape_tuple`.
dataset_axis = 3
a_bias_cotula_all = (a_est_cotula_nz - a_true_nz).mean(axis=dataset_axis)
a_bias_cotu_all = (a_est_cotu_nz - a_true_nz).mean(axis=dataset_axis)
b_bias_cotula_all = (b_est_cotula_nz - b_true_nz).mean(axis=dataset_axis)
b_bias_cotu_all = (b_est_cotu_nz - b_true_nz).mean(axis=dataset_axis)

In [None]:
# Collapse the signed errors (oracle selection) to a single value per
# (coupling loc, tuning loc) grid cell.
error_pairs = (
    (a_est_cotula_nz, a_true_nz),
    (a_est_cotu_nz, a_true_nz),
    (b_est_cotula_nz, b_true_nz),
    (b_est_cotu_nz, b_true_nz),
)
aggregated = []
for est_nz, true_nz in error_pairs:
    err = est_nz - true_nz
    err = np.median(err, axis=4)  # median over the n_cvs axis
    err = np.median(err, axis=4)  # median over the non-zero parameters (now axis 4)
    err = np.mean(err, axis=3)  # mean over datasets
    aggregated.append(np.median(err, axis=2))  # median over models

a_bias_agg_cotula, a_bias_agg_cotu, b_bias_agg_cotula, b_bias_agg_cotu = aggregated

In [None]:
# Gather ground truth and inferred-selection estimates from every fit file into
# dense arrays with axes (coupling loc, tuning loc, model, dataset, n_cvs, parameter).
# These arrays reuse (and overwrite) the names filled by the oracle-selection cell.
shape_tuple = (n_coupling_locs, n_tuning_locs, n_models, n_datasets)
n_cvs = 3
coupling_shape = shape_tuple + (n_cvs, N)
tuning_shape = shape_tuple + (n_cvs, M)

# Coupling parameters (last axis has length N)
a_true = np.zeros(coupling_shape)
a_est_cotula = np.zeros(coupling_shape)
a_est_cotu = np.zeros(coupling_shape)
# Tuning parameters (last axis has length M)
b_true = np.zeros(tuning_shape)
b_est_cotula = np.zeros(tuning_shape)
b_est_cotu = np.zeros(tuning_shape)

# HDF5 dataset name -> destination array, listed in read order
inferred_sources = (
    ('a_true', a_true),
    ('tm_t_c_uoi/a_est', a_est_cotula),
    ('tc_ols_t_c_uoi/a_est', a_est_cotu),
    ('b_true', b_true),
    ('tm_t_c_uoi/b_est', b_est_cotula),
    ('tc_ols_t_c_uoi/b_est', b_est_cotu),
)

# One fit file per (coupling loc, tuning loc, model, dataset) index combination;
# np.ndindex walks them in the same order as four nested loops would.
for ii, jj, kk, ll in np.ndindex(*shape_tuple):
    file = f"exp{exp}_{ii}_{jj}_{kk}_{ll}.h5"
    path = os.path.join(base_path, file)

    with h5py.File(path, 'r') as results:
        for key, dest in inferred_sources:
            dest[ii, jj, kk, ll] = results[key][:]

In [None]:
# Restrict truth and estimates to the entries whose ground-truth value is non-zero.
# Boolean masking flattens the array, so each result is reshaped back to its
# leading axes with the surviving parameters on the last axis. This relies on
# every slice along the last axis having the same number of non-zero true values.
a_support = a_true != 0
b_support = b_true != 0

a_true_nz, a_est_cotula_nz, a_est_cotu_nz = (
    arr[a_support].reshape(arr.shape[:-1] + (-1,))
    for arr in (a_true, a_est_cotula, a_est_cotu)
)
b_true_nz, b_est_cotula_nz, b_est_cotu_nz = (
    arr[b_support].reshape(arr.shape[:-1] + (-1,))
    for arr in (b_true, b_est_cotula, b_est_cotu)
)

In [None]:
# Per-parameter bias under inferred selection: signed error (estimate - truth)
# averaged over axis 3, the dataset axis of `shape_tuple`.
dataset_axis = 3
a_bias_cotula_all_sel = (a_est_cotula_nz - a_true_nz).mean(axis=dataset_axis)
a_bias_cotu_all_sel = (a_est_cotu_nz - a_true_nz).mean(axis=dataset_axis)
b_bias_cotula_all_sel = (b_est_cotula_nz - b_true_nz).mean(axis=dataset_axis)
b_bias_cotu_all_sel = (b_est_cotu_nz - b_true_nz).mean(axis=dataset_axis)

In [None]:
# Collapse the signed errors (inferred selection) to a single value per
# (coupling loc, tuning loc) grid cell.
error_pairs = (
    (a_est_cotula_nz, a_true_nz),
    (a_est_cotu_nz, a_true_nz),
    (b_est_cotula_nz, b_true_nz),
    (b_est_cotu_nz, b_true_nz),
)
aggregated = []
for est_nz, true_nz in error_pairs:
    err = est_nz - true_nz
    err = np.median(err, axis=4)  # median over the n_cvs axis
    err = np.median(err, axis=4)  # median over the non-zero parameters (now axis 4)
    err = np.mean(err, axis=3)  # mean over datasets
    aggregated.append(np.median(err, axis=2))  # median over models

(a_bias_agg_sel_cotula,
 a_bias_agg_sel_cotu,
 b_bias_agg_sel_cotula,
 b_bias_agg_sel_cotu) = aggregated

In [None]:
# Figure 3: heatmaps of the aggregated bias over the (tuning mean, coupling mean) grid.
#   rows    -> coupling parameters (top), tuning parameters (bottom)
#   columns -> oracle selection (left pair), inferred selection (right pair),
#              separated by a narrow blank spacer column
fig, axes = plt.subplots(
    nrows=2,
    ncols=5,
    figsize=(24, 10),
    gridspec_kw={'width_ratios': [1, 1, 0.05, 1, 1]})
plt.subplots_adjust(hspace=0.5, wspace=0.3)

# Figure settings
imshow_origin = 'lower'
coupling_vmin = 0
coupling_vmax = 0.75
coupling_cmap = 'Greys'
tuning_vmin = -3
tuning_vmax = 3
tuning_cmap = 'RdGy'
cb_tick_label_size = 15
# Axis settings
ax_label_size = 17
ax_tick_label_size = 15
subplot_label_bold = True
subplot_label_size = 19
title_size = 20
ax_sup_label_size = 22
sup_title_size = 24

# The middle column only provides horizontal space between the two halves
for spacer_ax in axes[:, 2]:
    spacer_ax.axis('off')

# Left side: oracle fits
oracle_cotu_coupling_ax = axes[0, 0]
oracle_cotula_coupling_ax = axes[0, 1]
oracle_cotu_tuning_ax = axes[1, 0]
oracle_cotula_tuning_ax = axes[1, 1]
# Right side: inferred fits
sel_cotu_coupling_ax = axes[0, 3]
sel_cotula_coupling_ax = axes[0, 4]
sel_cotu_tuning_ax = axes[1, 3]
sel_cotula_tuning_ax = axes[1, 4]

# Colour scaling shared by every panel showing the same parameter type
imshow_kwargs = {
    'coupling': {'vmin': coupling_vmin, 'vmax': coupling_vmax, 'cmap': coupling_cmap},
    'tuning': {'vmin': tuning_vmin, 'vmax': tuning_vmax, 'cmap': tuning_cmap},
}

# Panels in subplot-label order (a-h):
# (axis, aggregated bias, parameter type, model title)
panels = [
    (oracle_cotu_coupling_ax, a_bias_agg_cotu, 'coupling', 'CoTu Model'),          # 3a
    (oracle_cotula_coupling_ax, a_bias_agg_cotula, 'coupling', 'CoTuLa Model'),    # 3b
    (oracle_cotu_tuning_ax, b_bias_agg_cotu, 'tuning', 'CoTu Model'),              # 3c
    (oracle_cotula_tuning_ax, b_bias_agg_cotula, 'tuning', 'CoTuLa Model'),        # 3d
    (sel_cotu_coupling_ax, a_bias_agg_sel_cotu, 'coupling', 'CoTu Model'),         # 3e
    (sel_cotula_coupling_ax, a_bias_agg_sel_cotula, 'coupling', 'CoTuLa Model'),   # 3f
    (sel_cotu_tuning_ax, b_bias_agg_sel_cotu, 'tuning', 'CoTu Model'),             # 3g
    (sel_cotula_tuning_ax, b_bias_agg_sel_cotula, 'tuning', 'CoTuLa Model'),       # 3h
]
panel_axes = [panel_ax for panel_ax, _, _, _ in panels]

# Draw each heatmap with its own colorbar
for ax, bias, param_type, _ in panels:
    img = ax.imshow(bias, origin=imshow_origin, **imshow_kwargs[param_type])
    cb, cax = append_colorbar_to_axis(ax, img)
    cb.ax.tick_params(labelsize=cb_tick_label_size)

# Tick labels come from the hyperparameter grids rather than hard-coded strings.
# The bias arrays are indexed (coupling loc, tuning loc), so imshow puts the
# tuning grid on x (columns) and the coupling grid on y (rows).
tuning_tick_labels = [f'${loc:.1f}$' for loc in tuning_locs]
coupling_tick_labels = [f'${loc:.1f}$' for loc in coupling_locs]

# Set ticks, tick labels, tick label sizes, axis labels and titles (data panels only)
for ax, _, _, model_title in panels:
    ax.set_xticks(np.arange(len(tuning_tick_labels)))
    ax.set_yticks(np.arange(len(coupling_tick_labels)))
    ax.set_xticklabels(tuning_tick_labels)
    ax.set_yticklabels(coupling_tick_labels)
    ax.set_xlabel(bold_text('Tuning Mean'), fontsize=ax_label_size)
    ax.set_ylabel(bold_text('Coupling Mean'), fontsize=ax_label_size)
    ax.tick_params(labelsize=ax_tick_label_size)
    ax.set_title(bold_text(model_title), fontsize=title_size)

# Set subplot labels
mplego.labels.apply_subplot_labels(
    axes=panel_axes,
    bold=subplot_label_bold,
    size=subplot_label_size)

# Set axis sup-labels: one rotated label per row, anchored left of the first column
for row_ax, row_label in [(axes[0, 0], 'Coupling Parameters'),
                          (axes[1, 0], 'Tuning Parameters')]:
    row_ax.text(
        x=-0.4,
        y=0.5,
        va='center',
        ha='center',
        s=bold_text(row_label),
        rotation=90,
        fontsize=ax_sup_label_size,
        transform=row_ax.transAxes)

# Set suptitles: one heading above each half of the figure (figure coordinates)
for x_center, heading in [(0.28, 'Oracle Selection'),
                          (0.72, 'Inferred Selection')]:
    fig.text(
        x=x_center,
        y=0.95,
        va='center',
        ha='center',
        s=bold_text(heading),
        fontsize=sup_title_size)

# Save figure
plt.savefig('figure3.pdf', bbox_inches='tight')