In [None]:
import matplotlib.pyplot as plt
import numpy as np

from importlib import reload
from matplotlib import rc
from neurobiases import (TriangularModel,
                         EMSolver,
                         TCSolver,
                         solver_utils,
                         plot)

%matplotlib inline

In [None]:
# Typeset all figure text through LaTeX (the figure below uses \textbf{...}
# in its labels and titles) with a serif font family.
for group, settings in (('text', {'usetex': True}),
                        ('font', {'family': 'serif'})):
    rc(group, **settings)

In [None]:
# Generate the keyword-argument dictionaries that configure the
# TriangularModel built in the next cell. Fixed random states make the
# tuning and coupling draws reproducible.
model_spec = {
    'parameter_design': 'direct_response',
    'M': 10,
    'N': 15,
    'K': 1,
    'corr_cluster': 0.4,
    'corr_back': 0.1,
    'tuning_sparsity': 0.5,
    'coupling_sparsity': 0.5,
    'tuning_random_state': 2332,
    'coupling_random_state': 2332,
}
(tuning_kwargs, coupling_kwargs,
 noise_kwargs, stim_kwargs) = TriangularModel.generate_kwargs(**model_spec)

In [None]:
# Build a linear triangular model from the generated settings and plot the
# tuning curves of all neurons.
component_kwargs = dict(
    tuning_kwargs=tuning_kwargs,
    coupling_kwargs=coupling_kwargs,
    noise_kwargs=noise_kwargs,
    stim_kwargs=stim_kwargs,
)
tm = TriangularModel(model='linear', parameter_design='direct_response',
                     **component_kwargs)
tm.plot_tuning_curves(neuron='all')
plt.show()

In [None]:
# Draw a reproducible dataset (fixed random_state) of 1000 samples.
# NOTE(review): presumably X is the stimulus, Y the non-target activity and
# y the target neuron's response -- confirm against generate_samples.
sample_kwargs = dict(n_samples=1000, random_state=2332)
X, Y, y = tm.generate_samples(**sample_kwargs)

# Examine parameter sparsity after a short EM optimization across a grid of tuning and coupling penalties

In [None]:
# Penalty grids: an unpenalized (zero) entry followed by log-spaced values.
n_log_lambdas = 30
tuning_lambdas = np.concatenate(([0.], np.logspace(-3, -1, n_log_lambdas)))
coupling_lambdas = np.concatenate(([0.], np.logspace(-2, 2, n_log_lambdas)))
# Full length of each grid, counting the prepended zero.
n_lambdas = n_log_lambdas + 1

In [None]:
# Selection-ratio grids filled by the sweep below: rows follow
# tuning_lambdas and columns follow coupling_lambdas. The shape is derived
# from the grids themselves rather than assuming both have the same length.
grid_shape = (len(tuning_lambdas), len(coupling_lambdas))
a_srs = np.zeros(grid_shape)
b_srs = np.zeros(grid_shape)
B_srs = np.zeros(grid_shape)

In [None]:
# One short EM fit with a strong tuning penalty and no coupling penalty.
# Its `solver` is overwritten by the penalty sweep in the next cell.
single_fit_kwargs = dict(
    solver='fista',
    max_iter=20,
    tol=0,
    c_coupling=0,
    c_tuning=10,
)
solver = EMSolver(X, Y, y, K=1, **single_fit_kwargs).fit_em()

In [None]:
%%time
for idx1, tuning_lambda in enumerate(tuning_lambdas):
    for idx2, coupling_lambda in enumerate(coupling_lambdas):
        print(tuning_lambda, coupling_lambda)
        solver = EMSolver(
            X, Y, y, K=1,
            solver='fista',
            max_iter=20,
            tol=0,
            c_coupling=coupling_lambda,
            c_tuning=tuning_lambda
        ).fit_em()
        a_srs[idx1, idx2] = np.count_nonzero(solver.a) / tm.N
        b_srs[idx1, idx2] = np.count_nonzero(solver.b) / tm.M
        B_srs[idx1, idx2] = np.count_nonzero(solver.B) / (tm.N * tm.M)

In [None]:
def _decade_ticks(lambdas, min_gap=3):
    """Return (positions, labels) for an imshow axis indexed by `lambdas`.

    Assumes lambdas[0] == 0 and lambdas[1:] is log-spaced, as produced by
    ``np.insert(np.logspace(lo, hi, n), 0, 0)``. Index 0 is labelled ``0``.
    Each integer power of ten inside lambdas[1:] is placed at its
    (fractional) index by interpolating in log10 space. Decades closer than
    `min_gap` cells to the zero tick are dropped so labels do not overlap.
    """
    log_lo = np.log10(lambdas[1])
    log_hi = np.log10(lambdas[-1])
    span = (log_hi - log_lo) or 1.0  # guard a single log-spaced value
    n_steps = len(lambdas) - 2  # index distance from lambdas[1] to lambdas[-1]
    positions, labels = [0], [r'$0$']
    # Small tolerance so endpoints such as log10(0.1) are not lost to rounding.
    first = int(np.ceil(log_lo - 1e-9))
    last = int(np.floor(log_hi + 1e-9))
    for decade in range(first, last + 1):
        position = 1 + (decade - log_lo) / span * n_steps
        if position >= min_gap:
            positions.append(position)
            labels.append(r'$10^{%d}$' % decade)
    return positions, labels


fig, axes = plt.subplots(1, 3, figsize=(15, 5))
fig.subplots_adjust(wspace=0.4)

# Rows follow tuning_lambdas, columns follow coupling_lambdas; origin='lower'
# puts the zero-penalty row at the bottom.
axes[0].imshow(a_srs, origin='lower', vmin=0, vmax=1)
axes[1].imshow(b_srs, origin='lower', vmin=0, vmax=1)
img = axes[2].imshow(B_srs, origin='lower', vmin=0, vmax=1)

# Colorbar in its own axes just right of the last panel, matching its height.
cax = fig.add_axes([axes[2].get_position().x1 + 0.01,
                    axes[2].get_position().y0,
                    0.015,
                    axes[2].get_position().height])
cb = fig.colorbar(img, cax=cax)
cb.ax.tick_params(labelsize=18)
cb.set_label(r'\textbf{Selection Ratio}', fontsize=18,
             labelpad=22,
             rotation=270)

# Tick positions/labels are derived from the actual penalty grids instead of
# hard-coded indices, so they stay correct if the grids change.
x_ticks, x_labels = _decade_ticks(coupling_lambdas)
y_ticks, y_labels = _decade_ticks(tuning_lambdas)
for ax in axes:
    ax.set_xticks(x_ticks)
    ax.set_yticks(y_ticks)
    ax.set_xticklabels(x_labels)
    ax.set_yticklabels(y_labels)
    ax.set_xlabel(r'\textbf{Coupling Penalty}', fontsize=20)
    ax.set_ylabel(r'\textbf{Tuning Penalty}', fontsize=20)
    ax.tick_params(labelsize=18)

axes[0].set_title(r'\textbf{Coupling}', fontsize=22)
axes[1].set_title(r'\textbf{Target Tuning}', fontsize=22)
axes[2].set_title(r'\textbf{Non-target Tuning}', fontsize=22)

#plt.savefig('em_sparsity_on_coupling2.pdf', bbox_inches='tight')
plt.show()