In [None]:
import matplotlib.pyplot as plt
import mpl_lego as mplego
import numpy as np
import h5py

from mpl_lego import style
from mpl_toolkits.mplot3d import Axes3D
from mpl_lego.labels import bold_text
from neurobiases.utils import rotation
from neurobiases import EMSolver, TriangularModel
from sklearn.decomposition import PCA

In [None]:
# Switch matplotlib to mpl_lego's LaTeX style; presumably required so that the
# bold_text(...) and $...$ labels in the figure cell render through LaTeX — confirm.
style.use_latex_style()

In [None]:
# Build the linear triangular model used for the solver/figure section
# (tuning parameters drawn around tuning_loc=0) and draw 2000 samples from it
# with a fixed seed.
tm = TriangularModel(
    # overall design
    model='linear',
    parameter_design='direct_response',
    M=10, N=10, K=1,
    # corr_* settings
    corr_cluster=0.05,
    corr_back=0.,
    # coupling_* settings
    coupling_distribution='gaussian',
    coupling_loc=0,
    coupling_scale=0.25,
    coupling_sparsity=0.5,
    coupling_rng=2332,
    # tuning_* settings
    tuning_distribution='gaussian',
    tuning_loc=0,
    tuning_scale=0.25,
    tuning_sparsity=0.5,
    tuning_rng=23456542,
    # stimulus
    stim_distribution='uniform',
)
X, Y, y = tm.generate_samples(n_samples=2000, rng=2332)

# Run only once: initialization experiment (many EM fits; caches the resulting biases in `init_exp.h5`)

In [None]:
# Initialization experiment: fit the EM solver from many random
# initializations on a fixed set of datasets drawn from `example`, so the
# init-averaged estimates can be compared against the true parameters.
example = TriangularModel(
    model='linear',
    parameter_design='direct_response',
    M=10,
    N=10,
    K=1,
    corr_cluster=0.05,
    corr_back=0.,
    coupling_distribution='gaussian',
    coupling_sparsity=0.5,
    coupling_loc=0,
    coupling_scale=0.25,
    coupling_rng=2332,
    tuning_distribution='gaussian',
    tuning_sparsity=0.5,
    tuning_loc=1.5,
    tuning_scale=0.25,
    tuning_rng=23456542,
    stim_distribution='uniform')

# Flattened ground truth and masks of its non-zero entries.
b_true = example.b.ravel()
b_mask = b_true != 0
a_true = example.a.ravel()
a_mask = a_true != 0

rng = np.random.default_rng(23332)
n_datasets = 10
n_inits = 30

# Draw the datasets once, up front, so index jj refers to the SAME data for
# every initialization (axis 0 of a_est/b_est then varies only the init).
# Distinct names (X_sim, Y_sim, y_sim) keep this cell from overwriting the
# module-level X, Y, y drawn from `tm`, which later cells fit.
datasets = [example.generate_samples(n_samples=2000, rng=rng)
            for _ in range(n_datasets)]

# Estimates indexed as [initialization, dataset, parameter].
a_est = np.zeros((n_inits, n_datasets, a_true.size))
b_est = np.zeros((n_inits, n_datasets, b_true.size))

for ii in range(n_inits):
    print(f'initialization {ii + 1}/{n_inits}')
    # One seed per initialization, reused for every dataset (presumably this
    # is what seeds initialization='random' inside EMSolver).
    seed = rng.integers(100000)
    for jj, (X_sim, Y_sim, y_sim) in enumerate(datasets):
        solver = EMSolver(
            X_sim, Y_sim, y_sim, K=1,
            solver='scipy_lbfgs',
            max_iter=1000,
            tol=1e-7,
            penalize_B=False,
            rng=seed,
            fit_intercept=False,
            Psi_transform=None,
            initialization='random').fit_em()
        a_est[ii, jj] = solver.a.ravel()
        b_est[ii, jj] = solver.b.ravel()

In [None]:
# Per-dataset bias: average the estimates over initializations (axis 0),
# keep only entries that are non-zero in the ground truth, subtract the truth,
# and summarize each dataset by the median error across those entries.
a_bias, b_bias = (
    np.median(est.mean(axis=0)[:, mask] - true[mask], axis=1)
    for est, true, mask in ((a_est, a_true, a_mask),
                            (b_est, b_true, b_mask))
)

In [None]:
# Cache the per-dataset biases so the figure section can reload them without
# re-running the experiment ('w' truncates any existing file).
with h5py.File('init_exp.h5', 'w') as bias_file:
    for dataset_name, dataset_values in {'a_bias': a_bias,
                                         'b_bias': b_bias}.items():
        bias_file[dataset_name] = dataset_values

# Figure generation

In [None]:
# Reload the cached per-dataset biases written by the experiment section.
# Opened read-only: nothing is written here, and 'r+' (read/write) would fail
# on a read-only copy of the file.
with h5py.File('init_exp.h5', 'r') as bias_file:
    a_bias = bias_file['a_bias'][:]
    b_bias = bias_file['b_bias'][:]

In [None]:
# Unfitted EM solver on X, Y, y with a seeded random initialization and no
# Psi transform (penalize_B=False). Nothing is fit in this cell.
bound = EMSolver(
    X, Y, y,
    K=1,
    solver='scipy_lbfgs',
    max_iter=1000,
    tol=1e-7,
    penalize_B=False,
    rng=948512,
    fit_intercept=False,
    Psi_transform=None,
    initialization='random',
)


def _copy_with_Psi_transform(source, transform):
    """Copy `source`, switch its Psi_transform, and re-express Psi_tr.

    The new Psi_tr is computed by the copy's Psi_to_Psi_tr from the
    original solver's Psi_tr.
    """
    transformed = source.copy()
    transformed.Psi_transform = transform
    transformed.Psi_tr = transformed.Psi_to_Psi_tr(source.Psi_tr)
    return transformed


softplus = _copy_with_Psi_transform(bound, 'softplus')
exp = _copy_with_Psi_transform(bound, 'exp')

In [None]:
# Points along the identifiability family of the unfitted `exp` solver: each
# row holds the parameters (including Psi) of a copy of `exp` after
# identifiability_transform(delta), for delta on a grid over [-1.1, 1.1]
# with 0 inserted.
n_deltas = 20
deltas_init = np.sort(np.insert(np.linspace(-1.1, 1.1, n_deltas), 0, 0))
n_deltas = deltas_init.size  # grid points plus the inserted delta = 0

# Size the rows with the same get_params call that fills them, so the
# allocation cannot disagree with the stored parameter vectors.
init_params = np.zeros(
    (deltas_init.size, exp.get_params(return_Psi=True).size))
for idx, delta in enumerate(deltas_init):
    shifted = exp.copy()
    shifted.identifiability_transform(delta)
    init_params[idx] = shifted.get_params(return_Psi=True)

In [None]:
# Identifiability family around the fitted solution: fit a copy of `bound`
# once, then record its parameters (including Psi) after
# identifiability_transform(delta) for a symmetric grid of deltas.
deltas_fit = np.linspace(-0.18, 0.18, n_deltas)
fit_solver = bound.copy().fit_em()

# Rows are sized with the same get_params call used to fill them.
fit_ident_params = np.zeros(
    (deltas_fit.size, fit_solver.get_params(return_Psi=True).size))
for idx, delta in enumerate(deltas_fit):
    shifted = fit_solver.copy()
    shifted.identifiability_transform(delta)
    fit_ident_params[idx] = shifted.get_params(return_Psi=True)

In [None]:
# Two-component PCA basis learned from the fitted-solution identifiability
# family; the projected family is kept for panel e.
pca = PCA(n_components=2)
fit_transformed = pca.fit_transform(fit_ident_params)

In [None]:
# Refit `bound` starting from each shifted point (delta in deltas_init) and
# record the fitted parameters (including Psi). Rows are sized from
# deltas_init itself and from the same get_params call used to fill them.
fit_params_bound = np.zeros(
    (deltas_init.size, bound.get_params(return_Psi=True).size))

for idx, delta in enumerate(deltas_init):
    print(f'refit {idx + 1}/{deltas_init.size} (delta = {delta:+.3f})')
    bound_copy = bound.copy()
    bound_copy.identifiability_transform(delta)
    bound_copy.fit_em()
    fit_params_bound[idx] = bound_copy.get_params(return_Psi=True)

In [None]:
# Project the refitted parameter vectors onto the PCA basis fitted on the
# identifiability family, so both appear in the same 2-D coordinates.
fit_bound_transformed = pca.transform(fit_params_bound)

In [None]:
"""
Figure 8
"""
fig_width = 16
fig_height = 8.5
fig = plt.figure(figsize=(fig_width, fig_height))

bins = 4
tuning_color = 'C0'
coupling_color = 'C4'
parabola_color = 'black'

parabola_linewidth = 4
ident_curve_lw = 3

legend_size = 18

axis_label_size = 25
axis_tick_size = 23
axis_label_pad = 15
subplot_label_size = 30

x_base = 0
y_base = 0
base_offset = 1.25

"""
Figure 8a: 
Identifiability Surface (Toy)
"""
toy_surface_x = x_base
toy_surface_y = y_base
toy_surface_width = 6
toy_surface_height = 6
toy_surface_ax = fig.add_axes([toy_surface_x / fig_width,
                               toy_surface_y / fig_height,
                               toy_surface_width / fig_width,
                               toy_surface_height / fig_height],
                              projection='3d')
# Create surface
span = np.linspace(-1, 1, 1000)
x, y = np.meshgrid(span, span)
z = (x - y)**2
# Plot surface
surf = toy_surface_ax.plot_surface(x, y, z, cmap='viridis')

# Plot line at bottom
toy_surface_ax.plot(
    span,
    span,
    np.zeros(1000),
    lw=ident_curve_lw,
    color='red',
    zorder=1000)
# Plot line slightly above
line1 = np.linspace(-1, -0.25, 1000)
toy_surface_ax.plot(
    line1,
    1.25 + line1,
    1.25**2,
    lw=ident_curve_lw,
    color='red',
    zorder=1000)
# Plot parallel line slightly above
line2 = np.linspace(0.25, 1, 1000)
toy_surface_ax.plot(
    line2,
    line2 - 1.25,
    1.25**2,
    lw=ident_curve_lw,
    color='red',
    zorder=1000)

# Set axis labels
toy_surface_ax.set_xlabel(
    bold_text('Parameter 1'),
    fontsize=axis_label_size,
    labelpad=axis_label_pad + 5)
toy_surface_ax.set_ylabel(
    bold_text('Parameter 2'),
    fontsize=axis_label_size,
    labelpad=axis_label_pad + 5)
toy_surface_ax.zaxis.set_rotate_label(False)
toy_surface_ax.set_zlabel(
    bold_text('Log-likelihood'),
    fontsize=axis_label_size,
    labelpad=axis_label_pad,
    rotation=90)
toy_surface_ax.tick_params(labelsize=axis_tick_size)
toy_surface_ax.set_zticks([0., 1, 2, 3, 4])
toy_surface_ax.view_init(30, 30)


"""
Figure 8b:
"Bias" from lack of selection
"""
ident_bias_gap = 2
ident_bias_y_offset = base_offset
ident_bias_x = x_base + toy_surface_width + ident_bias_gap
ident_bias_y = y_base + ident_bias_y_offset
ident_bias_width = 6
ident_bias_height = 4

ident_bias_ax = fig.add_axes([ident_bias_x / fig_width,
                              ident_bias_y / fig_height,
                              ident_bias_width / fig_width,
                              ident_bias_height / fig_height])
ident_bias_ax.hist(a_bias,
                   bins=bins,
                   color=coupling_color,
                   label=bold_text('Coupling'))
ident_bias_ax.hist(b_bias,
                   bins=bins,
                   color=tuning_color,
                   label=bold_text('Tuning'))
ident_bias_ax.axvline(0, color='gray', linestyle='--')

ident_bias_ax.legend(loc=0,
                     prop={'size': legend_size})

ident_bias_ax.set_xlim([-0.75, 0.75])
ident_bias_ax.set_xticks([-0.5, -0.25, 0, 0.25, 0.5])
ident_bias_ax.set_yticks([0, 1, 2, 3, 4])
ident_bias_ax.tick_params(labelsize=axis_tick_size)
ident_bias_ax.set_xlabel(
    bold_text('Bias'),
    fontsize=axis_label_size)
ident_bias_ax.set_ylabel(
    bold_text('Frequency'),
    fontsize=axis_label_size,
    labelpad=axis_label_pad)


"""
Figure 8c:
CoTuLa Identifiability Curve
"""
ident_cotula_x_offset = -3.5
ident_cotula_y_offset = -6.5
ident_cotula_x = x_base + ident_cotula_x_offset
ident_cotula_y = y_base + ident_cotula_y_offset
ident_cotula_width = 6
ident_cotula_height = 6

ident_cotula_ax = fig.add_axes([ident_cotula_x / fig_width,
                                ident_cotula_y / fig_height,
                                ident_cotula_width / fig_width,
                                ident_cotula_height / fig_height],
                               projection='3d')
# Plot flat surface
rng = np.random.RandomState(23)
span = np.linspace(-5, 5, 100)
x, y = np.meshgrid(span, span)
z = np.zeros_like(x)
surf = ident_cotula_ax.plot_surface(
    x, y, z,
    alpha=0.5,
    label=r'$\Psi_t=0$')
surf._edgecolors2d = surf._edgecolor3d
surf._facecolors2d = surf._facecolor3d

ident_cotula_ax.set_zlim(-0.1, 1.1)
ident_cotula_ax.set_xlabel(r'$P_0[a,b_t,l_t]$',
                           size=axis_label_size,
                           labelpad=axis_label_pad + 10)
ident_cotula_ax.set_ylabel(r'$P_1[a,b_t,l_t]$',
                           size=axis_label_size,
                           labelpad=axis_label_pad + 10)
ident_cotula_ax.zaxis.set_rotate_label(False)
ident_cotula_ax.set_zlabel(r'$\Psi_t$',
                           size=axis_label_size,
                           labelpad=axis_label_pad + 10,
                           rotation=0)

n_points = 21
parabola = np.stack([np.zeros(n_points), np.linspace(-2, 2, n_points)], axis=0)
z = -parabola[1]**2 / 4.0 + 1
parabola = rotation(np.pi / 3.).dot(parabola)
ident_cotula_ax.plot(
    parabola[0], parabola[1], z,
    label=bold_text('Identifiability\nSubspace'),
    color=parabola_color,
    linewidth=parabola_linewidth)
ident_cotula_ax.tick_params(labelsize=axis_tick_size, pad=10)
ident_cotula_ax.legend(loc='center right',
                       bbox_to_anchor=(0.95, 0.82),
                       prop={'size': legend_size})

"""
Figure 8d:
Experiment mock-up
"""
ident_exp_x = ident_cotula_x + ident_cotula_width + base_offset
ident_exp_y = ident_cotula_y
ident_exp_width = 6
ident_exp_height = 6

ident_exp_ax = fig.add_axes([ident_exp_x / fig_width,
                             ident_exp_y / fig_height,
                             ident_exp_width / fig_width,
                             ident_exp_height / fig_height],
                            projection='3d')

rng = np.random.RandomState(23)
span = np.linspace(-5, 5, 100)
x, y = np.meshgrid(span, span)
z = np.zeros_like(x)
surf = ident_exp_ax.plot_surface(
    x, y, z,
    alpha=0.5,
    label=r'$\Psi_t=0$')
surf._edgecolors2d = surf._edgecolor3d
surf._facecolors2d = surf._facecolor3d

ident_exp_ax.set_zlim(-0.1, 1.1)
ident_exp_ax.set_xlabel(r'$P_0[a,b_t,l_t]$',
                        size=axis_label_size,
                        labelpad=axis_label_pad + 10)
ident_exp_ax.set_ylabel(r'$P_1[a,b_t,l_t]$',
                        size=axis_label_size,
                        labelpad=axis_label_pad + 10)
ident_exp_ax.zaxis.set_rotate_label(False)
ident_exp_ax.set_zlabel(r'$\Psi_t$',
                        size=axis_label_size,
                        labelpad=axis_label_pad + 10,
                        rotation=0)

pts = 21
pl = np.stack([np.zeros(pts), np.linspace(-2, 2, pts)], axis=0)
z = -pl[1]**2 / 4. + 1
pl = rotation(np.pi/4.).dot(pl) + np.array([2, 2])[:, np.newaxis]
ident_exp_ax.plot(
    pl[0],
    pl[1],
    z,
    label=bold_text('Initialization'),
    linewidth=parabola_linewidth)
ident_exp_ax.scatter(
    pl[0, ::4],
    pl[1, ::4],
    z[::4],
    marker='x',
    color=parabola_color,
    s=60,
    zorder=1000)

pl = np.stack([np.zeros(pts), np.linspace(-np.sqrt(2), np.sqrt(2), pts)], axis=0)
z = -pl[1]**2 / 4. + .5
pl = rotation(np.pi/4.).dot(pl) + np.array([-2, -2])[:, np.newaxis]
ident_exp_ax.plot(
    pl[0],
    pl[1],
    z,
    label=bold_text('Fitted'),
    linewidth=parabola_linewidth)
pln = pl + rng.randn(*pl.shape) / 20.
zn = z + rng.rand(*z.shape) / 20.
ident_exp_ax.scatter(
    pln[0, 4:-4:2],
    pln[1, 4:-4:2],
    zn[4:-4:2],
    marker='x',
    c=parabola_color,
    s=50)
ident_exp_ax.legend(loc='center right',
                    bbox_to_anchor=(0.95, 0.82),
                    prop={'size': legend_size})
ident_exp_ax.tick_params(labelsize=axis_tick_size, pad=10)

"""
Figure 8e:
Empirical results
"""
pc_x = ident_exp_x + ident_exp_width + base_offset + 1.45
pc_y = ident_exp_y + 1
pc_width = 4.25
pc_height = 4.25
pc_ax = fig.add_axes([pc_x / fig_width,
                      pc_y / fig_height,
                      pc_width / fig_width,
                      pc_height / fig_height])
pc_ax.plot(
    fit_transformed[:, 0],
    -fit_transformed[:, 1],
    color='green',
    lw=3,
    alpha=0.5,
    label='Identifiability Family')
pc_ax.scatter(
    fit_bound_transformed[:, 0],
    -fit_bound_transformed[:, 1],
    marker='x',
    color='black',
    s=100,
    alpha=1,
    label='No Transform')
pc_ax.tick_params(labelsize=15)
pc_ax.set_xlabel(bold_text('Principal Component 1'), fontsize=axis_label_size)
pc_ax.set_ylabel(bold_text('Principal Component 2'), fontsize=axis_label_size)
pc_ax.tick_params(labelsize=axis_tick_size)


# All panel axes of Figure 8 (not used further in this cell).
axes_list = [ident_bias_ax, toy_surface_ax, ident_cotula_ax, ident_exp_ax, pc_ax]

# Panel letters, positioned in inches and converted to figure fractions.
# Letters a and b share a height taken from panel b's top edge; c, d and e
# share a height taken from panel c's top edge.
top_row_label_y = ident_bias_y + ident_bias_height + 0.25
bottom_row_label_y = ident_cotula_y + ident_cotula_height - 0.50
panel_letters = (
    ('b', ident_bias_x - 0.50, top_row_label_y),
    ('a', toy_surface_x - 0.50, top_row_label_y),
    ('c', ident_cotula_x + 0.25, bottom_row_label_y),
    ('d', ident_exp_x + 0.25, bottom_row_label_y),
    ('e', pc_x - 0.50, bottom_row_label_y),
)
for letter, letter_x, letter_y in panel_letters:
    fig.text(
        letter_x / fig_width,
        letter_y / fig_height,
        s=bold_text(letter),
        fontsize=subplot_label_size)

plt.savefig('figure8.pdf', bbox_inches='tight')