In [None]:
import matplotlib.pyplot as plt
import numpy as np

from importlib import reload
from matplotlib import rc
from neurobiases import (TriangularModel,
                         EMSolver,
                         em_utils,
                         solver_utils)
%matplotlib inline

In [None]:
# Simulation setup: build a linear triangular model and sample a dataset from it.
# The coupling and tuning settings are grouped so the two parameter families can
# be compared side by side.
coupling_kwargs = dict(
    coupling_distribution='gaussian',
    coupling_sparsity=0.5,
    coupling_loc=0,
    coupling_scale=0.25,
    coupling_rng=2332,
)
tuning_kwargs = dict(
    tuning_distribution='gaussian',
    tuning_sparsity=0.5,
    tuning_loc=0,
    tuning_scale=0.25,
    tuning_rng=23456542,
)
tm = TriangularModel(
    model='linear',
    parameter_design='direct_response',
    M=10, N=10, K=1,
    corr_cluster=0.25,
    corr_back=0.1,
    stim_distribution='uniform',
    **coupling_kwargs,
    **tuning_kwargs,
)
# Draw 1000 samples with a fixed seed so every downstream fit sees the same data.
X, Y, y = tm.generate_samples(n_samples=1000, rng=2332)

In [None]:
# Sparse fit: EM with non-zero penalties on the coupling and tuning parameters.
sparse_config = dict(
    K=1,
    solver='ow_lbfgs',
    max_iter=1000,
    tol=1e-8,
    c_coupling=8e-2,
    c_tuning=1.3e-1,
    penalize_B=False,
    rng=948512,
)
solver_sparse = EMSolver(X, Y, y, **sparse_config)

# Snapshot the parameter vector before and after fitting; both endpoints are
# needed later to interpolate between them.
params_init_sparse = solver_sparse.get_params()
solver_sparse.fit_em(verbose=True, store_parameters=True, refit=False)
params_final_sparse = solver_sparse.get_params()

# After the final parameters are recorded, set the solver's masks to the
# non-zero entries of the fitted a and b.
a_support = solver_sparse.a.ravel() != 0
b_support = solver_sparse.b.ravel() != 0
solver_sparse.set_masks(a_mask=a_support, b_mask=b_support)

In [None]:
# Non-sparse fit: same solver settings as the sparse run, but with both
# penalties set to zero. This configuration suffers from identifiability.
nonsparse_config = dict(
    K=1,
    solver='ow_lbfgs',
    max_iter=1000,
    tol=1e-8,
    c_tuning=0,
    c_coupling=0,
    penalize_B=False,
    rng=948512,
)
solver_nonsparse = EMSolver(X, Y, y, **nonsparse_config)

# Snapshot the parameter vector before and after the EM fit.
params_init_nonsparse = solver_nonsparse.get_params()
solver_nonsparse.fit_em(verbose=True, store_parameters=True, refit=False)
params_final_nonsparse = solver_nonsparse.get_params()

In [None]:
# Oracle fit: penalties are zero, but the solver is handed masks built from the
# non-zero entries of the generating model's a and b (tm.a, tm.b).
oracle_a_mask = tm.a.ravel() != 0
oracle_b_mask = tm.b.ravel() != 0
solver_oracle = EMSolver(
    X, Y, y,
    K=1,
    solver='ow_lbfgs',
    max_iter=1000,
    tol=1e-8,
    a_mask=oracle_a_mask,
    b_mask=oracle_b_mask,
    c_tuning=0,
    c_coupling=0,
    penalize_B=False,
    rng=948512,
)

# Snapshot the parameter vector before and after the EM fit.
params_init_oracle = solver_oracle.get_params()
solver_oracle.fit_em(verbose=True, store_parameters=True, refit=False)
params_final_oracle = solver_oracle.get_params()

In [None]:
# Marginal log-likelihood along the straight line joining the initial and final
# parameter vectors of the sparse fit, using one point per entry of ll_path so
# the curve can be overlaid on the EM trajectory.
n_points_sparse = solver_sparse.ll_path.size
interp_sparse = np.linspace(
    start=params_init_sparse,
    stop=params_final_sparse,
    num=n_points_sparse,
)
mlls_interp_sparse = np.zeros(n_points_sparse)

for step in range(n_points_sparse):
    # Unpack the flat parameter vector into its named components.
    a, b, B, Psi_tr, L = solver_sparse.split_params(interp_sparse[step])
    Psi = solver_sparse.Psi_tr_to_Psi(Psi_tr)
    mlls_interp_sparse[step] = em_utils.marginal_log_likelihood_linear_tm(
        X=X, Y=Y, y=y, a=a, b=b, B=B, L=L, Psi=Psi
    )

In [None]:
# Marginal log-likelihood along the straight line joining the initial and final
# parameter vectors of the non-sparse fit, one point per entry of ll_path.
n_points_nonsparse = solver_nonsparse.ll_path.size
interp_nonsparse = np.linspace(
    start=params_init_nonsparse,
    stop=params_final_nonsparse,
    num=n_points_nonsparse,
)
mlls_interp_nonsparse = np.zeros(n_points_nonsparse)

for step in range(n_points_nonsparse):
    # Unpack the flat parameter vector into its named components.
    a, b, B, Psi_tr, L = solver_nonsparse.split_params(interp_nonsparse[step])
    Psi = solver_nonsparse.Psi_tr_to_Psi(Psi_tr)
    mlls_interp_nonsparse[step] = em_utils.marginal_log_likelihood_linear_tm(
        X=X, Y=Y, y=y, a=a, b=b, B=B, L=L, Psi=Psi
    )

In [None]:
# Marginal log-likelihood along the straight line joining the initial and final
# parameter vectors of the oracle fit, one point per entry of ll_path.
n_points_oracle = solver_oracle.ll_path.size
interp_oracle = np.linspace(
    start=params_init_oracle,
    stop=params_final_oracle,
    num=n_points_oracle,
)
mlls_interp_oracle = np.zeros(n_points_oracle)

for step in range(n_points_oracle):
    # Unpack the flat parameter vector into its named components.
    a, b, B, Psi_tr, L = solver_oracle.split_params(interp_oracle[step])
    Psi = solver_oracle.Psi_tr_to_Psi(Psi_tr)
    mlls_interp_oracle[step] = em_utils.marginal_log_likelihood_linear_tm(
        X=X, Y=Y, y=y, a=a, b=b, B=B, L=L, Psi=Psi
    )

In [None]:
# Figure: EM log-likelihood path (black) against the marginal log-likelihood
# along the straight init-to-final interpolation (red), one panel per solver.

# The labels below use \textbf{...}, which matplotlib only interprets when text
# is rendered through LaTeX. Nothing else in the notebook turns this on (the
# `rc` import was otherwise unused), so on a fresh kernel the labels would show
# the raw markup. rc settings are global, so this also applies to later figures.
# NOTE(review): requires a working LaTeX installation.
rc('text', usetex=True)

fig, axes = plt.subplots(1, 3, figsize=(12, 3))

for ax, solver, mlls in zip(
    axes,
    [solver_sparse, solver_nonsparse, solver_oracle],
    [mlls_interp_sparse, mlls_interp_nonsparse, mlls_interp_oracle],
):
    ax.plot(solver.ll_path, color='k', linewidth=2)
    ax.plot(mlls, color='red', linewidth=2)
    # Mark the cumulative step counts with full-height lines. axvline spans the
    # axes regardless of the data range, so no hard-coded y-extent is needed
    # and the markers do not stretch the autoscaled y-axis.
    # presumably solver.steps holds per-EM-iteration step counts — confirm.
    for boundary in np.cumsum(solver.steps):
        ax.axvline(boundary, color='gray', linewidth=0.5)
    ax.set_xscale('log')

    ax.tick_params(labelsize=13)
    ax.set_xlabel(r'\textbf{Iterations}', fontsize=13)
    ax.set_ylabel(r'\textbf{Marginal log-likelihood}', fontsize=13)

axes[0].set_title(r'\textbf{Sparse}', fontsize=15)
axes[1].set_title(r'\textbf{Non-Sparse}', fontsize=15)
axes[2].set_title(r'\textbf{Oracle}', fontsize=15)

plt.tight_layout()
plt.savefig('ll_path.pdf', bbox_inches='tight')

In [None]:
# Figure: trajectories of the coupling parameters (solver.a_path) over the
# course of fitting, one panel per solver.
fig, axes = plt.subplots(1, 3, figsize=(12, 3))

# Only the solvers are needed here; the interpolated log-likelihood arrays that
# were previously zipped into this loop were never used.
for ax, solver in zip(axes, [solver_sparse, solver_nonsparse, solver_oracle]):
    ax.plot(solver.a_path, linewidth=2)
    ax.set_xscale('log')

    ax.tick_params(labelsize=13)
    ax.set_xlabel(r'\textbf{Iterations}', fontsize=13)
    ax.set_ylabel(r'\textbf{Coupling Parameters}', fontsize=13)
    # Fixed y-range so the three panels are directly comparable.
    ax.set_ylim([0, 0.75])

axes[0].set_title(r'\textbf{Sparse}', fontsize=15)
axes[1].set_title(r'\textbf{Non-Sparse}', fontsize=15)
axes[2].set_title(r'\textbf{Oracle}', fontsize=15)

plt.tight_layout()
plt.savefig('coupling.pdf', bbox_inches='tight')

In [None]:
# Figure: trajectories of the tuning parameters (solver.b_path) over the course
# of fitting, one panel per solver.
fig, axes = plt.subplots(1, 3, figsize=(12, 3))

# Only the solvers are needed here; the interpolated log-likelihood arrays that
# were previously zipped into this loop were never used.
for ax, solver in zip(axes, [solver_sparse, solver_nonsparse, solver_oracle]):
    ax.plot(solver.b_path, linewidth=2)
    ax.set_xscale('log')

    ax.tick_params(labelsize=13)
    ax.set_xlabel(r'\textbf{Iterations}', fontsize=13)
    ax.set_ylabel(r'\textbf{Tuning Parameters}', fontsize=13)
    # Fixed y-range so the three panels are directly comparable.
    ax.set_ylim([0, 1.5])

axes[0].set_title(r'\textbf{Sparse}', fontsize=15)
axes[1].set_title(r'\textbf{Non-Sparse}', fontsize=15)
axes[2].set_title(r'\textbf{Oracle}', fontsize=15)

plt.tight_layout()
plt.savefig('tuning.pdf', bbox_inches='tight')

In [None]:
# Marginal-likelihood Hessians for each fitted solver. The sparse and oracle
# solvers are evaluated twice: once with default arguments and once with
# mask=True. The non-sparse solver is only evaluated with default arguments.
hessian1, hessian2 = (
    solver_sparse.marginal_likelihood_hessian(),
    solver_sparse.marginal_likelihood_hessian(mask=True),
)
hessian3 = solver_nonsparse.marginal_likelihood_hessian()
hessian4, hessian5 = (
    solver_oracle.marginal_likelihood_hessian(),
    solver_oracle.marginal_likelihood_hessian(mask=True),
)

# Eigenvalue spectrum of each Hessian, in the same order as above.
u1, u2, u3, u4, u5 = [
    np.linalg.eigvalsh(hessian)
    for hessian in (hessian1, hessian2, hessian3, hessian4, hessian5)
]

In [None]:
# Figure: eigenvalue magnitudes of the sparse solver's Hessian, without (u1) and
# with (u2) mask=True, on a log scale.
fig, ax = plt.subplots(1, 1, figsize=(6, 5))

for eigenvalues in (u1, u2):
    ax.plot(np.abs(eigenvalues), marker='o')

# Overlay the negative eigenvalues of u1 in black. eigvalsh returns eigenvalues
# in ascending order, so the negative ones occupy the leading indices.
negative = u1 < 0
ax.plot(np.abs(u1[negative]), marker='o', color='black', markersize=7)
ax.set_yscale('log')

ax.set_title(r'$K=1$', fontsize=20)
ax.set_xlabel(r'\textbf{Index}', fontsize=20)
ax.set_ylabel(r'$|$\textbf{Eigenvalue}$|$', fontsize=20)
ax.tick_params(labelsize=14)
plt.savefig('k=1_eigs.pdf', bbox_inches='tight')

In [None]:
# Sparse solver with K=3 and random initialization. The solver is only
# constructed in this cell; fit_em is not called on it here.
k3_config = dict(
    K=3,
    solver='ow_lbfgs',
    max_iter=1000,
    tol=1e-8,
    c_coupling=8e-2,
    c_tuning=1.3e-1,
    penalize_B=False,
    rng=948512,
    initialization='random',
)
test = EMSolver(X, Y, y, **k3_config)

In [None]:
# Hessian and eigenvalues for the K=3 solver.
# NOTE(review): this rebinds hessian1 / u1, which earlier held the K=1 sparse
# results; re-running the K=1 eigenvalue figure after this cell would plot the
# K=3 spectrum instead.
hessian1 = test.marginal_likelihood_hessian()
u1 = np.linalg.eigvalsh(hessian1, UPLO='L')  # 'L' is NumPy's default triangle

In [None]:
# Figure: eigenvalue magnitudes in u1, shown in full (left) and restricted to
# indices 20-40 (right). Both panels draw the same curves; only the x-limits of
# the second panel differ.
fig, axes = plt.subplots(1, 2, figsize=(10, 4))

negative = u1 < 0
panel_titles = [r'$K=3$', r'$K=3$ (Zoomed in)']

for ax, title in zip(axes, panel_titles):
    ax.plot(np.abs(u1), marker='o')
    # Negative eigenvalues are overlaid in black.
    ax.plot(np.abs(u1[negative]), marker='o', color='black', markersize=7)
    ax.set_yscale('log')

    ax.tick_params(labelsize=14)
    ax.set_xlabel(r'\textbf{Index}', fontsize=20)
    ax.set_ylabel(r'$|$\textbf{Eigenvalue}$|$', fontsize=20)
    ax.set_title(title, fontsize=20)

# Zoom the right-hand panel.
axes[1].set_xlim([20, 40])
plt.tight_layout()
plt.savefig('k=3_eigs.pdf', bbox_inches='tight')