In [None]:
import matplotlib.pyplot as plt
import numpy as np
import scipy

from importlib import reload
from matplotlib import rc
from neurobiases import EMSolver, TriangularModel
%matplotlib inline

In [None]:
K: int = 1  # dimensionality passed as K to both TriangularModel and EMSolver below

In [None]:
# Ground-truth triangular model: linear, 'direct_response' design, M=10, N=10,
# Gaussian coupling/tuning draws with fixed seeds and a uniform stimulus
# distribution. A seeded dataset (X, Y, y) is then drawn from it.
triangular_kwargs = dict(
    model='linear',
    parameter_design='direct_response',
    M=10,
    N=10,
    K=K,
    # corr_* settings
    corr_cluster=0.25,
    corr_back=0.1,
    # coupling_* settings
    coupling_distribution='gaussian',
    coupling_sparsity=0.5,
    coupling_loc=0,
    coupling_scale=0.25,
    coupling_rng=2332,
    # tuning_* settings
    tuning_distribution='gaussian',
    tuning_sparsity=0.5,
    tuning_loc=0,
    tuning_scale=0.25,
    tuning_rng=23456542,
    # stimulus
    stim_distribution='uniform',
)
tm = TriangularModel(**triangular_kwargs)

X, Y, y = tm.generate_samples(rng=2332, n_samples=2000)

## Curvature along identifiability family (sparse, ground-truth support masks)

In [None]:
# Sparse fit with the 'ow_lbfgs' solver. a_mask / b_mask mark the entries that
# are nonzero in the ground-truth tm.a / tm.b (presumably restricting the fit
# to that support — confirm against EMSolver). Random initialization, seeded.
solver = EMSolver(
    X, Y, y,
    K=K,
    a_mask=(tm.a.ravel() != 0),
    b_mask=(tm.b.ravel() != 0),
    solver='ow_lbfgs',
    initialization='random',
    rng=948512,
    max_iter=1000,
    tol=1e-7,
    penalize_B=False,
    fit_intercept=False,
).fit_em()

In [None]:
# Reference curvature at the fitted solver: marginal-likelihood Hessian
# (mask='L', including Psi) and its eigenvalues in ascending order.
hessian = solver.marginal_likelihood_hessian(wrt_Psi=True, mask='L')
eigs = np.linalg.eigvalsh(hessian, UPLO='L')

In [None]:
# Hessian eigenvalue spectra along the identifiability family of the
# support-masked fit. Each delta is applied to a copy made with solver.copy(),
# and the resulting spectra are collected in eigs_tr (one array per delta).
eigs_tr = []
deltas = np.linspace(-0.31, 0.20, 20)

for delta in deltas:
    print(delta)
    copy = solver.copy()
    copy.identifiability_transform(delta)
    # wrt_Psi=True matches the reference Hessian computed for the untransformed
    # solver, so every spectrum is taken over the same parameter set and the
    # median/spectrum comparisons below are like-for-like.
    hessian_copy = copy.marginal_likelihood_hessian(mask='L', wrt_Psi=True)
    eigs_tr.append(np.linalg.eigvalsh(hessian_copy))

In [None]:
# Median Hessian eigenvalue at each point of the identifiability family; the
# black vertical line marks the median eigenvalue of the untransformed fit.
fig, ax = plt.subplots()
ax.hist([np.median(eig) for eig in eigs_tr])
ax.axvline(np.median(eigs), color='k', label='untransformed fit')
ax.set(
    xlabel='median Hessian eigenvalue',
    ylabel='number of deltas',
    title='Median curvature along identifiability family (support-masked fit)')
ax.legend();

In [None]:
# Ascending eigenvalue spectra of each transformed Hessian (thin lines) against
# the spectrum of the untransformed fit (thick black line).
fig, ax = plt.subplots()
for eig in eigs_tr:
    ax.plot(eig)
ax.plot(eigs, color='k', linewidth=3, label='untransformed fit')
# NOTE(review): a log axis cannot display non-positive eigenvalues; confirm
# none are expected for this Hessian.
ax.set_yscale('log')
ax.set(
    xlabel='eigenvalue index (ascending)',
    ylabel='eigenvalue',
    title='Hessian spectra along identifiability family (support-masked fit)')
ax.legend();

## Curvature along identifiability family (sparse, no support masks)

In [None]:
# Sparse fit with the 'ow_lbfgs' solver on the same data, but without
# a_mask / b_mask and with a different rng seed for the random initialization.
solver = EMSolver(
    X, Y, y,
    K=K,
    solver='ow_lbfgs',
    initialization='random',
    rng=94822,
    max_iter=1000,
    tol=1e-7,
    penalize_B=False,
    fit_intercept=False,
).fit_em()

In [None]:
# Reference curvature at the unmasked fit: marginal-likelihood Hessian
# (mask=True, including Psi), its ascending eigenvalues, and a zero buffer
# holding one Hessian-sized matrix per step of a 20-point sweep.
hessian = solver.marginal_likelihood_hessian(wrt_Psi=True, mask=True)
eigs = np.linalg.eigvalsh(hessian, UPLO='L')
eig_vecs = np.zeros((20,) + hessian.shape[:2])

In [None]:
# Hessian spectra and eigenvectors along the identifiability family of the
# unmasked fit. Each delta is applied to a copy made with solver.copy().
eigs_tr = []
eig_vecs_tr = []
deltas = np.linspace(-0.25, 0.25, 20)

for delta in deltas:
    print(delta)
    copy = solver.copy()
    copy.identifiability_transform(delta)
    # wrt_Psi=True matches the reference Hessian computed for the untransformed
    # solver, so every spectrum/eigenbasis covers the same parameter set.
    hessian_copy = copy.marginal_likelihood_hessian(mask=True, wrt_Psi=True)
    # One eigh call yields both the ascending eigenvalues and the eigenvectors.
    eigvals_copy, eigvecs_copy = np.linalg.eigh(hessian_copy)
    eigs_tr.append(eigvals_copy)
    eig_vecs_tr.append(eigvecs_copy)

# One eigenvector matrix per delta, sized from the sweep itself rather than a
# hard-coded count defined in another cell.
eig_vecs = np.stack(eig_vecs_tr)

In [None]:
# Median Hessian eigenvalue at each point of the identifiability family; the
# black vertical line marks the median eigenvalue of the untransformed fit.
fig, ax = plt.subplots()
ax.hist([np.median(eig) for eig in eigs_tr])
ax.axvline(np.median(eigs), color='k', label='untransformed fit')
ax.set(
    xlabel='median Hessian eigenvalue',
    ylabel='number of deltas',
    title='Median curvature along identifiability family (no support masks)')
ax.legend();

In [None]:
# Ascending eigenvalue spectra of each transformed Hessian (thin lines) against
# the spectrum of the untransformed fit (thick black line).
fig, ax = plt.subplots()
for eig in eigs_tr:
    ax.plot(eig)
ax.plot(eigs, color='k', linewidth=3, label='untransformed fit')
# NOTE(review): a log axis cannot display non-positive eigenvalues; confirm
# none are expected for this Hessian.
ax.set_yscale('log')
ax.set(
    xlabel='eigenvalue index (ascending)',
    ylabel='eigenvalue',
    title='Hessian spectra along identifiability family (no support masks)')
ax.legend();