In [None]:
import matplotlib.pyplot as plt
import numpy as np
import scipy

from importlib import reload
from matplotlib import rc
from neurobiases import (TriangularModel,
                         EMSolver,
                         em_utils,
                         solver_utils)
%matplotlib inline

In [None]:
# Typeset all figure text through LaTeX, using a serif font family.
plt.rcParams["text.usetex"] = True
plt.rcParams["font.family"] = "serif"

In [None]:
# Choose latent dimensionality. K is forwarded to the TriangularModel and to
# both EMSolver instances, and it also fixes how many identifiability
# perturbation directions (np.identity(K)) are probed further down.
K = 1

In [None]:
# Build the synthetic triangular model. The coupling and tuning parameter
# draws each receive an explicit integer seed.
tm = TriangularModel(
    # Model family and parameter layout
    model='linear',
    parameter_design='direct_response',
    # Problem sizes (K is the latent dimensionality chosen above)
    M=10, N=10, K=K,
    # Correlation settings
    corr_cluster=0.25, corr_back=0.1,
    # Coupling parameters
    coupling_distribution='gaussian',
    coupling_sparsity=0.5,
    coupling_loc=0,
    coupling_scale=0.25,
    coupling_rng=1307295399,
    # Tuning parameters
    tuning_distribution='gaussian',
    tuning_sparsity=0.5,
    tuning_loc=0,
    tuning_scale=0.25,
    tuning_rng=184955555,
    # Stimulus
    stim_distribution='uniform',
)

# Draw a seeded dataset of 500 samples from the model
X, Y, y = tm.generate_samples(
    n_samples=500,
    rng=2105532715,
)

In [None]:
# Set up the sparse solver twice with exactly the same configuration, so the
# two fit_em variants run below (numpy=True / numpy=False) start from
# identically constructed, independent EMSolver instances.
solver1, solver2 = (
    EMSolver(
        X, Y, y, K=K,
        solver='ow_lbfgs',
        max_iter=1000,
        tol=1e-7,
        c_coupling=100,
        c_tuning=0.7847599703514607,
        penalize_B=False,
        rng=4131998,
        initialization='fits',
        fa_rng=2332,
    )
    for _ in range(2)
)

In [None]:
# Fit the first solver through fit_em's `numpy=True` path, with progress output.
# NOTE(review): presumably `numpy=True` selects a NumPy-based implementation --
# confirm against EMSolver.fit_em.
solver1.fit_em(numpy=True, verbose=True, index=False, refit=True)

In [None]:
# Fit the second, identically configured solver through the `numpy=False` path
# so that its result can be compared with solver1's.
# NOTE(review): which backend `numpy=False` selects is not visible here -- confirm.
solver2.fit_em(numpy=False, verbose=True, index=False, refit=True)

In [None]:
# Show the flattened `a` and `b` parameters of solver2, one per line
print(solver2.a.ravel(), solver2.b.ravel(), sep='\n')

In [None]:
# Consistency check: after each EM fit, the public `f_df_em` and the private
# `_f_df_em` must return the same objective value and gradient when given
# identical arguments.
# NOTE(review): `solver` is not defined in any cell shown above (only `solver1`
# and `solver2` are) -- confirm which solver this check is meant to exercise.
for rep in range(10):
    print(rep)
    solver.fit_em(index=False, refit=True)
    mu, zz, sigma = solver.e_step()

    # Keyword arguments shared verbatim by both evaluations
    em_kwargs = dict(
        a_mask=solver.a_mask,
        b_mask=solver.b_mask,
        B_mask=solver.B_mask,
        train_B=solver.train_B,
        train_L_nt=solver.train_L_nt,
        train_L=solver.train_L,
        train_Psi_tr_nt=solver.train_Psi_tr_nt,
        train_Psi_tr=solver.train_Psi_tr,
        Psi_transform=solver.Psi_transform,
        mu=mu, zz=zz, sigma=sigma,
        tuning_to_coupling_ratio=1,
        penalize_B=False,
        wrt_Psi=True,
    )

    f1, grad1 = solver.f_df_em(solver.get_params(), X, Y, y, **em_kwargs)
    f, grad = solver._f_df_em(solver.get_params(), X, Y, y, **em_kwargs)

    np.testing.assert_allclose(f, f1)
    np.testing.assert_allclose(grad, grad1)

In [None]:
# Run EM on `solver` with refit=True, leaving every other fit_em option at its
# default.
# NOTE(review): `solver` is not created in any cell visible above -- confirm
# where it comes from before relying on Restart & Run All.
solver.fit_em(refit=True)

In [None]:
# Show the flattened `a` and `b` parameters of `solver`, one per line
print(solver.a.ravel(), solver.b.ravel(), sep='\n')

In [None]:
# Same display as the previous cell: flattened `a` and `b` of `solver`
print(solver.a.ravel(), solver.b.ravel(), sep='\n')

In [None]:
# One small perturbation per latent dimension: row i is 1/100 along axis i
deltas = np.eye(K) / 100.
dx_params = np.zeros((K, solver.get_params().size))

# For every perturbation, apply the identifiability transform to a copy of the
# solver and record how far the parameter vector moved relative to `solver`.
for idx in range(K):
    delta = deltas[idx]
    copy = solver.copy()
    copy.identifiability_transform(delta=delta)
    dx_params[idx] = (
        copy.get_params(return_Psi=True)
        - solver.get_params(return_Psi=True)
    )

# Replace the raw shifts with an orthonormal basis of their span (one basis
# vector per row)
dx_params = scipy.linalg.orth(dx_params.T).T

In [None]:
# Hessian returned by the solver's marginal_likelihood_hessian (wrt_Psi=True)
hessian_pre = solver.marginal_likelihood_hessian(wrt_Psi=True)
# eigh returns the eigenvalues in ascending order (u_pre) and the matching
# eigenvectors as the *columns* of v_pre, i.e. v_pre[:, i] pairs with u_pre[i]
u_pre, v_pre = np.linalg.eigh(hessian_pre)


In [None]:
# Magnitudes of the Hessian eigenvalues, in the (ascending) order eigh returned them
plt.plot(np.abs(u_pre), color='k', marker='o', markersize=4)

# One horizontal line per identifiability direction, at |d^T H d|.
# dx_params holds at most K rows, so iterate over whatever is present instead
# of indexing rows 0..3 explicitly (which raises IndexError whenever K < 4).
for direction in dx_params:
    curvature = np.dot(np.dot(hessian_pre, direction), direction)
    plt.axhline(np.abs(curvature))

plt.yscale('log')

In [None]:
# Signed quadratic form d^T H d along the second direction (row index 1);
# this needs dx_params to contain at least two rows.
hessian_pre.dot(dx_params[1]).dot(dx_params[1])

In [None]:
# |d^T H d| along the third direction (row index 2); this needs dx_params to
# contain at least three rows.
np.abs(hessian_pre.dot(dx_params[2]).dot(dx_params[2]))

In [None]:
# |d^T H d| along the second direction (row index 1); this needs dx_params to
# contain at least two rows.
np.abs(hessian_pre.dot(dx_params[1]).dot(dx_params[1]))

In [None]:
def _unit_identifiability_direction(reference, delta):
    """Return the unit-norm parameter shift caused by an identifiability transform.

    Parameters
    ----------
    reference : solver
        Solver providing the starting parameters. The transform is applied to
        `reference.copy()`, not to `reference` itself.
    delta : array-like
        Forwarded as `delta` to `identifiability_transform`.

    Returns
    -------
    The difference between the transformed copy's parameters and those of
    `reference` (both read with `get_params(return_Psi=True)`), divided by its
    Euclidean norm.
    """
    shifted = reference.copy()
    shifted.identifiability_transform(delta=delta)
    direction = (
        shifted.get_params(return_Psi=True)
        - reference.get_params(return_Psi=True)
    )
    return direction / np.linalg.norm(direction)


# Seed the perturbation draw: every other random stream in this notebook is
# seeded, and an unseeded draw here would make the figures below change from
# run to run.
BASE_DELTA_SEED = 20210607
base_delta = np.random.default_rng(BASE_DELTA_SEED).standard_normal(K) / 100.

# Hessian at initialization (i.e. before the EM fit performed in this cell)
hessian_pre = solver.marginal_likelihood_hessian(wrt_Psi=True)
u_pre, v_pre = np.linalg.eigh(hessian_pre)
# dParams initialization
dx_params_pre = _unit_identifiability_direction(solver, base_delta)

# Fit EM
solver.fit_em()

# Hessian after fitting
hessian_post = solver.marginal_likelihood_hessian(wrt_Psi=True, mask=False)
u_post, v_post = np.linalg.eigh(hessian_post)
# dX params Post: same perturbation, applied around the fitted parameters
dx_params_post = _unit_identifiability_direction(solver, base_delta)

# Hessian after fitting, without L
hessian_post_L = solver.marginal_likelihood_hessian(wrt_Psi=True, mask='L')
u_post_L, v_post_L = np.linalg.eigh(hessian_post_L)
# Candidate indices run from offset N + M + N*M + N + 1 to the end of the flat
# parameter vector; every (N + 1)-th candidate is then taken back out of the
# removal list, so those entries are kept.
# NOTE(review): this hard-codes the layout of get_params(); presumably it has
# to mirror what mask='L' removes from the Hessian -- confirm against the solver.
L_idx = np.arange(
    solver.N + solver.M + solver.N * solver.M + solver.N + 1,
    solver.get_params().shape[0])
L_idx = np.delete(L_idx, np.arange(0, L_idx.size, solver.N + 1))
dx_params_post_L = np.delete(dx_params_post, L_idx)

# Hessian after fitting, without L and sparse
hessian_post_all = solver.marginal_likelihood_hessian(wrt_Psi=True, mask=True)
u_post_all, v_post_all = np.linalg.eigh(hessian_post_all)
# Entries whose a/b mask is zero are dropped as well; the b indices are offset
# by N within the flat parameter vector.
a_idx = np.argwhere(solver.a_mask.ravel() == 0).ravel()
b_idx = solver.N + np.argwhere(solver.b_mask.ravel() == 0).ravel()
idx = np.concatenate((a_idx, b_idx, L_idx))
dx_params_post_all = np.delete(dx_params_post, idx)

In [None]:
fig, axes = plt.subplots(1, 4, figsize=(14, 3), sharey=True)

# One entry per panel:
# (eigenvalues, Hessian, unit direction, title, overlay negative eigenvalues?)
panels = [
    (u_pre, hessian_pre, dx_params_pre,
     r'\textbf{Initialization}', True),
    (u_post, hessian_post, dx_params_post,
     r'\textbf{Trained}', True),
    (u_post_L, hessian_post_L, dx_params_post_L,
     r'\textbf{Trained, no L}', True),
    (u_post_all, hessian_post_all, dx_params_post_all,
     r'\textbf{Trained, no identifiability}', False),
]

for ax, (eigs, hess, direction, title, mark_negative) in zip(axes, panels):
    # Black: |eigenvalue| for the whole spectrum
    ax.plot(np.abs(eigs), color='k', marker='o', markersize=4)
    if mark_negative:
        # Blue: magnitudes of the negative eigenvalues only
        ax.plot(np.abs(eigs[eigs < 0]), color='blue', marker='o', markersize=4)
    # Red: |d^T H d| along the identifiability direction
    ax.axhline(np.abs(np.dot(np.dot(hess, direction), direction)), color='red')

axes[0].set_ylabel(fr'$K={K}$' '\n' r'$|$\textbf{Eigenvalue}$|$', fontsize=16)

for ax, panel in zip(axes, panels):
    ax.set_title(panel[3], fontsize=15)

# Common axis formatting
for ax in axes:
    ax.set_yscale('log')
    ax.set_ylim([1e-2, 1e10])
    ax.tick_params(labelsize=14)
    ax.set_xlabel(r'\textbf{Eigenvalue Index}', fontsize=14)

plt.tight_layout()
# plt.savefig(f'hessian_K={K}_eigs.pdf', bbox_inches='tight')

In [None]:
fig, axes = plt.subplots(1, 4, figsize=(14, 3), sharey=True)

# Squared overlap between the identifiability direction d and each Hessian
# eigenvector. np.linalg.eigh returns the eigenvectors as the COLUMNS of v
# (v[:, i] pairs with eigenvalue i), so the projection onto eigenvector i is
# entry i of v.T @ d. Using v @ d would instead mix component i of every
# eigenvector, which is not an overlap indexed by eigenvalue.
axes[0].plot(np.dot(v_pre.T, dx_params_pre)**2, color='k')
axes[1].plot(np.dot(v_post.T, dx_params_post)**2, color='k')
axes[2].plot(np.dot(v_post_L.T, dx_params_post_L)**2, color='k')
axes[3].plot(np.dot(v_post_all.T, dx_params_post_all)**2, color='k')

axes[0].set_ylabel(fr'$K={K}$' '\n' r'\textbf{Eigenvector Overlap}', fontsize=16)

for ax in axes:
    #ax.set_yscale('log')
    # Leave a small margin below zero so that zero-overlap points stay visible
    ax.set_ylim(bottom=-0.05)
    ax.tick_params(labelsize=14)
    ax.set_xlabel(r'\textbf{Eigenvalue Index}', fontsize=14)

axes[0].set_title(r'\textbf{Initialization}', fontsize=15)
axes[1].set_title(r'\textbf{Trained}', fontsize=15)
axes[2].set_title(r'\textbf{Trained, no L}', fontsize=15)
axes[3].set_title(r'\textbf{Trained, no identifiability}', fontsize=15)

plt.tight_layout()
plt.savefig(f'hessian_K={K}_overlap.pdf', bbox_inches='tight')