In [None]:
import h5py
import matplotlib.pyplot as plt
import numpy as np
import os

from neurobiases import TriangularModel, EMSolver
from matplotlib.ticker import FormatStrFormatter
from mpl_lego import colorbar
from sklearn.model_selection import check_cv

%matplotlib inline

In [None]:
# Typeset all figure text through LaTeX in a serif face; with text.usetex enabled,
# matplotlib needs a working LaTeX installation to render figures.
plt.rcParams['text.usetex'] = True
plt.rcParams['font.family'] = 'serif'

In [None]:
# Directory holding the fit outputs. os.path.expanduser("~") resolves to $HOME on
# POSIX, so the path is unchanged there, but it also works where HOME is not set
# (e.g. Windows), where os.environ['HOME'] would raise KeyError. Joining components
# instead of a '/'-separated string keeps the separators platform-correct.
FIT_DIR = os.path.join(os.path.expanduser("~"), "fits", "neurobiases", "exp18")
# HDF5 file analysed by the rest of this notebook.
path = os.path.join(FIT_DIR, "exp18_em_aic_large_0_0_0_0.h5")

In [None]:
# Read-only handle on the fit file. It is deliberately left open because later
# cells keep indexing into `results` (estimates, criteria, lambda grids).
results = h5py.File(path, mode="r")

In [None]:
# Open up experiment settings: scalar configuration lives in the file's HDF5
# attributes, array-valued settings (Ks and the lambda grids) in datasets.
with h5py.File(path, 'r') as params:
    # Triangular model hyperparameters
    N = params.attrs['N']
    M = params.attrs['M']
    K = params.attrs['K']
    D = params.attrs['D']
    corr_cluster = params.attrs['corr_cluster']
    corr_back = params.attrs['corr_back']
    coupling_distribution = params.attrs['coupling_distribution']
    coupling_sparsity = params.attrs['coupling_sparsity']
    coupling_loc = params.attrs['coupling_loc']
    coupling_scale = params.attrs['coupling_scale']
    tuning_distribution = params.attrs['tuning_distribution']
    tuning_sparsity = params.attrs['tuning_sparsity']
    tuning_loc = params.attrs['tuning_loc']
    tuning_scale = params.attrs['tuning_scale']
    # Random seeds (each attribute is read exactly once)
    coupling_rng = params.attrs['coupling_rng']
    tuning_rng = params.attrs['tuning_rng']
    dataset_rng = params.attrs['dataset_rng']
    fitter_rng = params.attrs['fitter_rng']
    # Training hyperparameters; `[:]` copies each dataset into memory so the
    # arrays stay usable after the `with` block closes the file.
    Ks = params['Ks'][:]
    coupling_lambdas = params['coupling_lambdas'][:]
    n_coupling_lambdas = coupling_lambdas.size
    tuning_lambdas = params['tuning_lambdas'][:]
    n_tuning_lambdas = tuning_lambdas.size
    # Training settings
    criterion = params.attrs['criterion']
    cv = params.attrs['cv']
    fine_sweep_frac = params.attrs['fine_sweep_frac']
    solver = params.attrs['solver']
    initialization = params.attrs['initialization']
    max_iter = params.attrs['max_iter']
    tol = params.attrs['tol']

In [None]:
# Ground-truth tuning parameters, read fully into memory for display.
results["b_true"][()]

In [None]:
# Estimated tuning parameters (`b_est`), shown for comparison with `b_true`.
results["b_est"][()]

In [None]:
# Reconstruct the triangular model from the stored hyperparameters and seeds.
# Model type, parameter design and stimulus distribution are fixed here rather
# than read from the file.
model_kwargs = dict(
    model='linear',
    parameter_design='direct_response',
    M=M,
    N=N,
    K=K,
    corr_cluster=corr_cluster,
    corr_back=corr_back,
    coupling_distribution=coupling_distribution,
    coupling_sparsity=coupling_sparsity,
    coupling_loc=coupling_loc,
    coupling_scale=coupling_scale,
    coupling_rng=coupling_rng,
    tuning_distribution=tuning_distribution,
    tuning_sparsity=tuning_sparsity,
    tuning_loc=tuning_loc,
    tuning_scale=tuning_scale,
    tuning_rng=tuning_rng,
    stim_distribution='uniform',
)
tm = TriangularModel(**model_kwargs)

# Regenerate the D samples with the stored dataset seed (cast to a Python int).
X, Y, y = tm.generate_samples(n_samples=D, rng=int(dataset_rng))

In [None]:
# Pull out the indices for the fold analysed in this notebook.
FOLD_IDX = 0

# Keep `cv` as the stored setting and give the splitter its own name, so the
# cell is idempotent and the two kinds of value are not conflated.
cv_splitter = check_cv(cv)
train_idx, test_idx = list(cv_splitter.split(X))[FOLD_IDX]
# Sanity check: the fold's train and test indices should together cover every sample.
assert len(train_idx) + len(test_idx) == len(X)

X_train = X[train_idx]
Y_train = Y[train_idx]
y_train = y[train_idx]
X_test = X[test_idx]
Y_test = Y[test_idx]
y_test = y[test_idx]

In [None]:
# Top-level names stored in the fit file.
list(results.keys())

In [None]:
def _load_squeezed(key):
    """Read dataset ``key`` from the open fit file into memory, dropping singleton axes."""
    return np.squeeze(results[key][:])


def _summarize_criterion(key, maximize=False):
    """Load a model-selection criterion and locate its best lambda grid point.

    Parameters
    ----------
    key : str
        Dataset name in ``results``.
    maximize : bool, default False
        If True the best point has the largest median (scores); otherwise the
        smallest (information criteria such as AIC/BIC).

    Returns
    -------
    values : ndarray
        The squeezed criterion values.
    median : ndarray
        Median of ``values`` over its last axis (presumably replicates — confirm
        against the fitting script).
    best : tuple of int
        Index of the best entry of ``median``. NaN medians are skipped: plain
        argmin/argmax would instead return the position of the first NaN.
    """
    values = _load_squeezed(key)
    median = np.median(values, axis=-1)
    pick = np.nanargmax if maximize else np.nanargmin
    best = np.unravel_index(pick(median), median.shape)
    return values, median, best


# Coupling (a) and tuning (b) estimates on both lambda grids, plus ground truth.
a_est_coarse = _load_squeezed("a_est_coarse")
a_est_fine = _load_squeezed("a_est_fine")
a_est = _load_squeezed("a_est")
a_true = results["a_true"][:]

b_est_coarse = _load_squeezed("b_est_coarse")
b_est_fine = _load_squeezed("b_est_fine")
b_est = _load_squeezed("b_est")
b_true = results["b_true"][:]

# Information criteria: lower is better.
aics_coarse, aics_coarse_med, best_aic_coarse = _summarize_criterion("aics_coarse")
aics_fine, aics_fine_med, best_aic_fine = _summarize_criterion("aics_fine")
bics_coarse, bics_coarse_med, best_bic_coarse = _summarize_criterion("bics_coarse")
bics_fine, bics_fine_med, best_bic_fine = _summarize_criterion("bics_fine")

# Scores: higher is better.
mlls_coarse, mlls_coarse_med, best_mll_coarse = _summarize_criterion("scores_coarse", maximize=True)
mlls_fine, mlls_fine_med, best_mll_fine = _summarize_criterion("scores_fine", maximize=True)

# Lambda grids used for the coarse and fine sweeps.
coupling_lambdas_coarse = results['coupling_lambdas'][:]
coupling_lambdas_fine = results['coupling_lambdas_fine'][:]
tuning_lambdas_coarse = results['tuning_lambdas'][:]
tuning_lambdas_fine = results['tuning_lambdas_fine'][:]

In [None]:
def _median_selection_ratio(est, true):
    """Median fraction of nonzero estimated entries at each lambda grid point.

    Nonzeros are counted along the last axis of ``est`` and divided by the total
    number of entries in ``true``. The median is then taken over the last axis
    of that count array (i.e. the second-to-last axis of ``est``).
    """
    nonzero_frac = np.count_nonzero(est, axis=-1) / true.size
    return np.median(nonzero_frac, axis=-1)


a_sr_coarse = _median_selection_ratio(a_est_coarse, a_true)
b_sr_coarse = _median_selection_ratio(b_est_coarse, b_true)
a_sr_fine = _median_selection_ratio(a_est_fine, a_true)
b_sr_fine = _median_selection_ratio(b_est_fine, b_true)

In [None]:
# Median selection ratios on the coarse lambda grid (left: coupling, right:
# tuning), with the grid point chosen by each criterion overlaid.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
fig.subplots_adjust(wspace=0.3)

img1 = axes[0].imshow(a_sr_coarse, origin='lower', vmin=0, vmax=1, cmap='Greys_r')
img2 = axes[1].imshow(b_sr_coarse, origin='lower', vmin=0, vmax=1, cmap='Greys_r')
# Both panels share the fixed [0, 1] colour range, so one colorbar covers both.
colorbar.append_colorbar_to_axis(axes[1], img2)

# Tick every 4th grid index plus the last one, derived from the grid sizes rather
# than assuming 20 lambdas (which raises IndexError on a smaller grid and leaves
# the largest lambdas unlabelled on a bigger one).
x_ticks = np.unique(np.append(np.arange(0, tuning_lambdas_coarse.size, 4), tuning_lambdas_coarse.size - 1))
y_ticks = np.unique(np.append(np.arange(0, coupling_lambdas_coarse.size, 4), coupling_lambdas_coarse.size - 1))

for ax, title in zip(axes, (r'\textbf{Coupling}', r'\textbf{Tuning}')):
    # Image rows index coupling lambdas and columns tuning lambdas, hence (col, row).
    ax.scatter(best_mll_coarse[1], best_mll_coarse[0], color='red', marker='x', s=50, label='MLL')
    ax.scatter(best_aic_coarse[1], best_aic_coarse[0], color='green', marker='o', s=60, label='AIC')
    ax.scatter(best_bic_coarse[1], best_bic_coarse[0], color='blue', marker='^', s=50, label='BIC')

    ax.set_title(title, fontsize=20)
    ax.tick_params(labelsize=15)
    ax.set_xticks(x_ticks)
    ax.set_xticklabels(np.round(tuning_lambdas_coarse[x_ticks], decimals=3))
    ax.set_yticks(y_ticks)
    ax.set_yticklabels(np.round(coupling_lambdas_coarse[y_ticks], decimals=3))
    ax.set_xlabel(r'\textbf{Tuning Lambda}', fontsize=15)
    ax.set_ylabel(r'\textbf{Coupling Lambda}', fontsize=15)

# Identify which marker belongs to which selection criterion.
axes[0].legend(loc='upper right', fontsize=12)

plt.show()

In [None]:
# Median selection ratios on the fine lambda grid (left: coupling, right:
# tuning), with the grid point chosen by each criterion overlaid.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
fig.subplots_adjust(wspace=0.3)

img1 = axes[0].imshow(a_sr_fine, origin='lower', vmin=0, vmax=1, cmap='Greys_r')
img2 = axes[1].imshow(b_sr_fine, origin='lower', vmin=0, vmax=1, cmap='Greys_r')
# Both panels share the fixed [0, 1] colour range, so one colorbar covers both.
colorbar.append_colorbar_to_axis(axes[1], img2)

# Tick every 4th grid index plus the last one, derived from the grid sizes rather
# than assuming 20 lambdas (which raises IndexError on a smaller grid and leaves
# the largest lambdas unlabelled on a bigger one).
x_ticks = np.unique(np.append(np.arange(0, tuning_lambdas_fine.size, 4), tuning_lambdas_fine.size - 1))
y_ticks = np.unique(np.append(np.arange(0, coupling_lambdas_fine.size, 4), coupling_lambdas_fine.size - 1))

for ax, title in zip(axes, (r'\textbf{Coupling}', r'\textbf{Tuning}')):
    # Image rows index coupling lambdas and columns tuning lambdas, hence (col, row).
    ax.scatter(best_mll_fine[1], best_mll_fine[0], color='red', marker='x', s=50, label='MLL')
    ax.scatter(best_aic_fine[1], best_aic_fine[0], color='green', marker='o', s=60, label='AIC')
    ax.scatter(best_bic_fine[1], best_bic_fine[0], color='blue', marker='^', s=50, label='BIC')

    ax.set_title(title, fontsize=20)
    ax.tick_params(labelsize=15)
    ax.set_xticks(x_ticks)
    ax.set_xticklabels(np.round(tuning_lambdas_fine[x_ticks], decimals=3))
    ax.set_yticks(y_ticks)
    ax.set_yticklabels(np.round(coupling_lambdas_fine[y_ticks], decimals=3))
    ax.set_xlabel(r'\textbf{Tuning Lambda}', fontsize=15)
    ax.set_ylabel(r'\textbf{Coupling Lambda}', fontsize=15)

# Identify which marker belongs to which selection criterion.
axes[0].legend(loc='upper right', fontsize=12)

plt.show()

In [None]:
# Median AIC over the coarse (left) and fine (right) lambda grids, with the
# minimum marked. The colour scale is capped at the 25th percentile so the
# low-AIC region stays resolved.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
fig.subplots_adjust(wspace=0.5)

aic_panels = (
    (axes[0], aics_coarse_med, best_aic_coarse, coupling_lambdas_coarse, tuning_lambdas_coarse, r'\textbf{AIC (coarse)}'),
    (axes[1], aics_fine_med, best_aic_fine, coupling_lambdas_fine, tuning_lambdas_fine, r'\textbf{AIC (fine)}'),
)
for ax, grid, best, c_lambdas, t_lambdas, title in aic_panels:
    img = ax.imshow(
        grid,
        origin='lower',
        vmax=np.percentile(grid.ravel(), q=25),
        cmap='Greys_r')
    # Image rows index coupling lambdas and columns tuning lambdas, hence (col, row).
    ax.scatter(best[1], best[0], color='cyan', marker='x')
    colorbar.append_colorbar_to_axis(ax, img)

    # Label ticks with lambda values (as the selection-ratio figures do), with
    # positions derived from the grid size instead of assuming 20 lambdas.
    x_ticks = np.unique(np.append(np.arange(0, t_lambdas.size, 4), t_lambdas.size - 1))
    y_ticks = np.unique(np.append(np.arange(0, c_lambdas.size, 4), c_lambdas.size - 1))
    ax.tick_params(labelsize=15)
    ax.set_xticks(x_ticks)
    ax.set_xticklabels(np.round(t_lambdas[x_ticks], decimals=3))
    ax.set_yticks(y_ticks)
    ax.set_yticklabels(np.round(c_lambdas[y_ticks], decimals=3))
    ax.set_title(title, fontsize=20)
    ax.set_xlabel(r'\textbf{Tuning Lambda}', fontsize=15)
    ax.set_ylabel(r'\textbf{Coupling Lambda}', fontsize=15)
plt.show()

In [None]:
# Median BIC over the coarse (left) and fine (right) lambda grids, with the
# minimum marked. The colour scale is capped at the 25th percentile so the
# low-BIC region stays resolved.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
fig.subplots_adjust(wspace=0.5)

bic_panels = (
    (axes[0], bics_coarse_med, best_bic_coarse, coupling_lambdas_coarse, tuning_lambdas_coarse, r'\textbf{BIC (coarse)}'),
    (axes[1], bics_fine_med, best_bic_fine, coupling_lambdas_fine, tuning_lambdas_fine, r'\textbf{BIC (fine)}'),
)
for ax, grid, best, c_lambdas, t_lambdas, title in bic_panels:
    img = ax.imshow(
        grid,
        origin='lower',
        vmax=np.percentile(grid.ravel(), q=25),
        cmap='Greys_r')
    # Image rows index coupling lambdas and columns tuning lambdas, hence (col, row).
    ax.scatter(best[1], best[0], color='cyan', marker='x')
    colorbar.append_colorbar_to_axis(ax, img)

    # Label ticks with lambda values (as the selection-ratio figures do), with
    # positions derived from the grid size instead of assuming 20 lambdas.
    x_ticks = np.unique(np.append(np.arange(0, t_lambdas.size, 4), t_lambdas.size - 1))
    y_ticks = np.unique(np.append(np.arange(0, c_lambdas.size, 4), c_lambdas.size - 1))
    ax.tick_params(labelsize=15)
    ax.set_xticks(x_ticks)
    ax.set_xticklabels(np.round(t_lambdas[x_ticks], decimals=3))
    ax.set_yticks(y_ticks)
    ax.set_yticklabels(np.round(c_lambdas[y_ticks], decimals=3))
    ax.set_title(title, fontsize=20)
    ax.set_xlabel(r'\textbf{Tuning Lambda}', fontsize=15)
    ax.set_ylabel(r'\textbf{Coupling Lambda}', fontsize=15)
plt.show()

In [None]:
# Median score (MLL) over the coarse (left) and fine (right) lambda grids, with
# the maximum marked. Both colour scales are floored at the median so the
# high-score region stays resolved (matching the clipping used for AIC/BIC).
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
fig.subplots_adjust(wspace=0.5)

mll_panels = (
    (axes[0], mlls_coarse_med, best_mll_coarse, coupling_lambdas_coarse, tuning_lambdas_coarse, r'\textbf{MLL (coarse)}'),
    (axes[1], mlls_fine_med, best_mll_fine, coupling_lambdas_fine, tuning_lambdas_fine, r'\textbf{MLL (fine)}'),
)
for ax, grid, best, c_lambdas, t_lambdas, title in mll_panels:
    img = ax.imshow(
        grid,
        origin='lower',
        vmin=np.percentile(grid.ravel(), q=50),
        cmap='Greys_r')
    # Image rows index coupling lambdas and columns tuning lambdas, hence (col, row).
    ax.scatter(best[1], best[0], color='blue', marker='x')
    colorbar.append_colorbar_to_axis(ax, img)

    # Label ticks with lambda values (as the selection-ratio figures do), with
    # positions derived from the grid size instead of assuming 20 lambdas.
    x_ticks = np.unique(np.append(np.arange(0, t_lambdas.size, 4), t_lambdas.size - 1))
    y_ticks = np.unique(np.append(np.arange(0, c_lambdas.size, 4), c_lambdas.size - 1))
    ax.tick_params(labelsize=15)
    ax.set_xticks(x_ticks)
    ax.set_xticklabels(np.round(t_lambdas[x_ticks], decimals=3))
    ax.set_yticks(y_ticks)
    ax.set_yticklabels(np.round(c_lambdas[y_ticks], decimals=3))
    ax.set_title(title, fontsize=20)
    ax.set_xlabel(r'\textbf{Tuning Lambda}', fontsize=15)
    ax.set_ylabel(r'\textbf{Coupling Lambda}', fontsize=15)
plt.show()