In [None]:
import matplotlib.pyplot as plt
import mpl_lego as mplego
import numpy as np

from mpl_lego import style
from mpl_toolkits.mplot3d import Axes3D
from mpl_lego.labels import bold_text
from neurobiases.utils import rotation
from neurobiases import EMSolver, TriangularModel
from sklearn.decomposition import PCA

In [None]:
# Apply mpl_lego's LaTeX plotting style to every figure drawn below.
# NOTE(review): presumably this enables LaTeX text rendering, which the
# r'\textbf{...}' axis labels in the figure cells rely on — confirm in mpl_lego.
style.use_latex_style()

In [None]:
# Ground-truth linear triangular model. Coupling and tuning parameters are
# Gaussian (loc 0, scale 0.25, sparsity 0.5) and stimuli are uniform. Every
# random component has a fixed seed so the figure is reproducible.
tm_kwargs = dict(
    model='linear',
    parameter_design='direct_response',
    M=10,
    N=10,
    K=1,
    corr_cluster=0.05,
    corr_back=0.,
    coupling_distribution='gaussian',
    coupling_sparsity=0.5,
    coupling_loc=0,
    coupling_scale=0.25,
    coupling_rng=2332,
    tuning_distribution='gaussian',
    tuning_sparsity=0.5,
    tuning_loc=0,
    tuning_scale=0.25,
    tuning_rng=23456542,
    stim_distribution='uniform',
)
tm = TriangularModel(**tm_kwargs)

# Draw 2000 samples.
# NOTE(review): X, Y, y are presumably stimuli, non-target activity and target
# activity respectively — confirm against TriangularModel.generate_samples.
X, Y, y = tm.generate_samples(n_samples=2000, rng=2332)

In [None]:
# Build (but do not fit) the baseline EM solver: random initialization, no
# penalty on B, no intercept, and Psi optimized directly (Psi_transform=None).
bound = EMSolver(
    X, Y, y, K=1,
    solver='scipy_lbfgs',
    max_iter=1000,
    tol=1e-7,
    penalize_B=False,
    rng=948512,
    fit_intercept=False,
    Psi_transform=None,
    initialization='random')


def _transformed_copy(solver, transform):
    """Copy `solver` and switch it to optimizing Psi through `transform`.

    The copy's Psi_tr is recomputed from the source solver's Psi_tr with the
    copy's own Psi_to_Psi_tr, so both solvers start from the same point.
    NOTE(review): this assumes the source solver's Psi_tr equals Psi (true if
    Psi_transform=None means identity) — confirm in EMSolver.
    """
    new_solver = solver.copy()
    new_solver.Psi_transform = transform
    new_solver.Psi_tr = new_solver.Psi_to_Psi_tr(solver.Psi_tr)
    return new_solver


# Same initialization as `bound`, with Psi reparameterized via softplus / exp.
softplus = _transformed_copy(bound, 'softplus')
exp = _transformed_copy(bound, 'exp')

In [None]:
n_deltas = 20
# Sweep of identifiability offsets. 0 is added so the un-shifted solver is one
# of the points. np.linspace with an even count never hits 0 exactly, so this
# adds one new value and n_deltas becomes 21.
deltas_init = np.sort(np.append(np.linspace(-1.1, 1.1, n_deltas), 0.))
n_deltas = deltas_init.size

# Parameters of the exp-transform solver after each identifiability shift.
# NOTE(review): init_params is not used by any later cell in this notebook.
init_params = np.zeros((n_deltas, bound.get_params().size))
for row, delta in zip(init_params, deltas_init):
    shifted = exp.copy()
    shifted.identifiability_transform(delta)
    row[:] = shifted.get_params(return_Psi=True)

In [None]:
# Trace the identifiability family of one fitted solution: fit once from
# `bound`'s initialization, then apply a sweep of (small) offsets to the fit.
deltas_fit = np.linspace(-0.18, 0.18, n_deltas)
fit_ident_params = np.zeros((n_deltas, bound.get_params().size))

# NOTE(review): relies on fit_em() returning the fitted solver — confirm.
fit_solver = bound.copy().fit_em()

for row, delta in zip(fit_ident_params, deltas_fit):
    shifted = fit_solver.copy()
    shifted.identifiability_transform(delta)
    # Diagnostic: first entry of Psi after the shift (presumably to check it
    # stays valid across the narrow delta range).
    print(shifted.Psi_tr_to_Psi()[0])
    row[:] = shifted.get_params(return_Psi=True)

In [None]:
# Fit a 2-component PCA on the identifiability family. The same fitted `pca`
# is reused later to project the refitted solutions into these coordinates.
pca = PCA(n_components=2)
fit_transformed = pca.fit_transform(fit_ident_params)

In [None]:
# Quick look: the identifiability family drawn in principal-component space.
fig, ax = plt.subplots(figsize=(6, 6))

pc1 = fit_transformed[:, 0]
pc2 = fit_transformed[:, 1]
ax.plot(
    pc1, pc2,
    marker='o', color='black', linewidth=3, alpha=0.2,
    label='Identifiability Family')

In [None]:
# Refit the untransformed ("bound") solver from each identifiability-shifted
# initialization and record the fitted parameters (Psi included).
fit_params_bound = np.zeros((n_deltas, bound.get_params().size))

for step, (row, delta) in enumerate(zip(fit_params_bound, deltas_init)):
    print(step)  # progress: each fit_em call can be slow
    start = bound.copy()
    start.identifiability_transform(delta)
    start.fit_em()
    row[:] = start.get_params(return_Psi=True)

In [None]:
# Project the refitted solutions with the PCA basis fitted on the
# identifiability family, so both sets share coordinates in panel d.
fit_bound_transformed = pca.transform(fit_params_bound)

In [None]:
# Figure 7, four panels:
#   a) toy surface z = (p1 - p2)^2 whose minimum is the whole line p1 == p2;
#   b) schematic identifiability subspace above the Psi_t = 0 plane;
#   c) schematic of initializations vs. fitted solutions;
#   d) PCA projection of the fitted identifiability family and of the
#      solutions refit from shifted initializations.
# Plot-only names (`exp_ax`, `grid_x`, `grid_y`, ...) are deliberately distinct
# from the solver `exp` and the data `y` created in earlier cells. Running this
# cell therefore does not clobber them, and earlier cells can still be re-run.


def _parabola_arc(half_width, peak, angle, offset=(0., 0.), n_pts=21):
    """Return a rotated, downward-opening parabola for the schematic panels.

    Parameters
    ----------
    half_width : float
        The arc spans [-half_width, half_width] along its own axis.
    peak : float
        Height (Psi_t value) at the centre of the arc.
    angle : float
        Angle (radians) passed to `rotation` to orient the arc in the
        (P_0, P_1) plane. `rotation` is assumed to return a 2x2 matrix.
    offset : pair of float
        Translation applied after the rotation.
    n_pts : int
        Number of points along the arc.

    Returns
    -------
    xy : ndarray, shape (2, n_pts)
        Rotated and translated (P_0, P_1) coordinates.
    z : ndarray, shape (n_pts,)
        Heights -t**2 / 4 + peak along the arc.
    """
    t = np.linspace(-half_width, half_width, n_pts)
    z = -t**2 / 4. + peak
    xy = np.stack([np.zeros(n_pts), t], axis=0)
    xy = rotation(angle).dot(xy) + np.asarray(offset, dtype=float)[:, np.newaxis]
    return xy, z


def _draw_psi_plane(ax, plane_x, plane_y, **surface_kwargs):
    """Draw the translucent Psi_t = 0 plane on `ax` and label its axes.

    Extra keyword arguments (e.g. `label`, `zorder`) go to `plot_surface`.
    Returns the surface artist.
    """
    surf = ax.plot_surface(
        plane_x, plane_y, np.zeros_like(plane_x), alpha=0.5, **surface_kwargs)
    # Legend workaround: copy the 3D colours onto the 2D colour attributes the
    # legend handler reads. These are private matplotlib attributes and may
    # change between versions.
    surf._edgecolors2d = surf._edgecolor3d
    surf._facecolors2d = surf._facecolor3d
    ax.set_zlim(-.1, 1.1)
    ax.set_xlabel(r'$P_0[a,b_t,l_t]$', size=20, labelpad=10)
    ax.set_ylabel(r'$P_1[a,b_t,l_t]$', size=20, labelpad=10)
    ax.zaxis.set_rotate_label(False)
    ax.set_zlabel(r'$\Psi_t$', size=20, rotation=0)
    return surf


fig_width = 16
fig_height = 8.5
# Panels a-c share one axes size (inches), converted to figure fractions.
ident_width = 6
ident_height = 6
panel_w = ident_width / fig_width
panel_h = ident_height / fig_height

fig = plt.figure(figsize=(fig_width, fig_height))

# --- Panel a: toy surface with a line of equally good minima -------------
ident = fig.add_axes([0, 0.75, panel_w, panel_h], projection='3d')
ident_span = np.linspace(-1, 1, 1000)
ident_x, ident_y = np.meshgrid(ident_span, ident_span)
ident.plot_surface(ident_x, ident_y, (ident_x - ident_y)**2, color='gray')
# Red level sets: the minimum valley p1 == p2 (z = 0), and the two lines
# |p1 - p2| = 0.5 (z = 0.25).
ident.plot(ident_span, ident_span, np.zeros_like(ident_span),
           lw=3, color='red', zorder=1000)
line1 = np.linspace(-1, 0.5, 1000)
ident.plot(line1, 0.5 + line1, 0.5**2, lw=3, color='red', zorder=1000)
line2 = np.linspace(-0.5, 1, 1000)
ident.plot(line2, line2 - 0.5, 0.5**2, lw=3, color='red', zorder=1000)

ident.set_xlabel(r'\textbf{Parameter 1}', fontsize=18, labelpad=10)
ident.set_ylabel(r'\textbf{Parameter 2}', fontsize=18, labelpad=10)
ident.zaxis.set_rotate_label(False)
# NOTE(review): the surface is minimal on the red line, so this axis reads more
# like a negative log-likelihood / loss than a log-likelihood — confirm label.
ident.set_zlabel(r'\textbf{Log-likelihood}', fontsize=18, labelpad=10, rotation=90)
ident.set_zticks([0., 1, 2, 3, 4])
ident.view_init(30, 30)

# Shared grid for the Psi_t = 0 plane in panels b and c.
plane_span = np.linspace(-5, 5, 100)
grid_x, grid_y = np.meshgrid(plane_span, plane_span)

# --- Panel b: identifiability subspace (schematic) -----------------------
triang = fig.add_axes([panel_w + 0.05, 0.75, panel_w, panel_h], projection='3d')
_draw_psi_plane(triang, grid_x, grid_y, label=r'$\Psi_t=0$')
subspace_xy, subspace_z = _parabola_arc(2, 1, np.pi/3.)
triang.plot(subspace_xy[0], subspace_xy[1], subspace_z,
            label=r'\textbf{Identifiability}''\n'r'\textbf{subspace}',
            c='black', lw=4)
triang.legend(loc='best', prop={'size': 16})

# --- Panel c: initializations vs. fitted solutions (schematic) -----------
exp_ax = fig.add_axes([0, 0, panel_w, panel_h], projection='3d')
_draw_psi_plane(exp_ax, grid_x, grid_y, zorder=-1)

init_xy, init_z = _parabola_arc(2, 1, np.pi/4., offset=(2, 2))
exp_ax.plot(init_xy[0], init_xy[1], init_z, label=r'\textbf{Initialization}', lw=4)
exp_ax.scatter(init_xy[0, ::4], init_xy[1, ::4], init_z[::4],
               marker='x', c='k', s=60, zorder=1000)

fit_xy, fit_z = _parabola_arc(np.sqrt(2), .5, np.pi/4., offset=(-2, -2))
exp_ax.plot(fit_xy[0], fit_xy[1], fit_z, label=r'\textbf{Fitted}', lw=4)
# Jittered markers: fitted points land near, not exactly on, the curve.
rng = np.random.RandomState(23)
fit_xy_noisy = fit_xy + rng.randn(*fit_xy.shape) / 20.
# NOTE(review): z jitter uses `rand` (uniform on [0, 1)), so markers only move
# upward; `randn` may have been intended.
fit_z_noisy = fit_z + rng.rand(*fit_z.shape) / 20.
exp_ax.scatter(fit_xy_noisy[0, 4:-4:2], fit_xy_noisy[1, 4:-4:2], fit_z_noisy[4:-4:2],
               marker='x', c='k', s=50)
exp_ax.legend(loc='best', prop={'size': 15})

# --- Panel d: PCA view of the real fits -----------------------------------
pcs = fig.add_axes([panel_w + 0.12, 0.05, 0.80 * panel_w, 0.80 * panel_h])
# PC2 is negated for display only (the sign of a principal component is arbitrary).
pcs.plot(
    fit_transformed[:, 0],
    -fit_transformed[:, 1],
    color='green',
    lw=3,
    alpha=0.5,
    label='Identifiability Family')
pcs.scatter(
    fit_bound_transformed[:, 0],
    -fit_bound_transformed[:, 1],
    marker='x',
    color='black',
    s=100,
    alpha=1,
    label='No Transform')
pcs.tick_params(labelsize=15)
pcs.set_xlabel(r'\textbf{Principal Component 1}', fontsize=15)
pcs.set_ylabel(r'\textbf{Principal Component 2}', fontsize=15)

# Panel letters, placed in figure coordinates.
fig.text(-0.01, 1.4, s=bold_text('a'), fontsize=27)
fig.text(0.43, 1.4, s=bold_text('b'), fontsize=27)
fig.text(-0.01, 0.7, s=bold_text('c'), fontsize=27)
fig.text(0.43, 0.7, s=bold_text('d'), fontsize=27)

# Axes extend past the nominal figure box; bbox_inches='tight' keeps them all.
fig.savefig('figure7.pdf', bbox_inches='tight')