In [36]:
import matplotlib.pyplot as plt
import numpy as np

from importlib import reload
from matplotlib import rc
from neurobiases import (TriangularModel,
                         EMSolver,
                         TCSolver,
                         solver_utils,
                         plot)
from sklearn.model_selection import check_cv

%matplotlib inline

In [2]:
results = np.load('../test.npz')

In [3]:
list(results)

['coupling_random_states',
 'tuning_random_states',
 'dataset_random_states',
 'mses',
 'bics',
 'a_true',
 'a_est',
 'b_true',
 'b_est',
 'B',
 'Psi',
 'L',
 'coupling_locs',
 'tuning_locs',
 'coupling_lambdas',
 'tuning_lambdas',
 'N',
 'M',
 'K',
 'D',
 'coupling_distribution',
 'coupling_sparsity',
 'coupling_scale',
 'tuning_distribution',
 'tuning_sparsity',
 'tuning_scale',
 'corr_cluster',
 'corr_back',
 'n_datasets',
 'n_models',
 'shape_key']

In [4]:
results['shape_key']

array(['tuning_loc', 'coupling_loc', 'model_idx', 'dataset_idx',
       'split_idx', 'coupling_lambda', 'tuning_lambda'], dtype='<U15')

In [5]:
results['a_true'].shape

(2, 2, 2, 2, 2, 2, 2, 10)

In [6]:
coupling_rng = np.random.default_rng(12092020)
tuning_rng = np.random.default_rng(1231993)
dataset_rng = np.random.default_rng(23332)

In [7]:
print(coupling_rng.integers(low=0, high=2**32 - 1, size=2))
print(results['coupling_random_states'])

[3551129612 3893514943]
[3551129612 3893514943]


In [8]:
print(tuning_rng.integers(low=0, high=2**32 - 1, size=2))
print(results['tuning_random_states'])

[2576946247 1582465175]
[2576946247 1582465175]


In [14]:
results['corr_cluster'].item()

0.25

In [32]:
tm = TriangularModel(
    model='linear',
    parameter_design='direct_response',
    M=results['M'],
    N=results['N'],
    K=results['K'],
    corr_cluster=results['corr_cluster'].item(),
    corr_back=results['corr_back'].item(),
    coupling_distribution=str(results['coupling_distribution']),
    coupling_sparsity=results['coupling_sparsity'].item(),
    coupling_loc=results['coupling_locs'][0],
    coupling_scale=results['coupling_scale'].item(),
    coupling_sum=None,
    coupling_random_state=results['coupling_random_states'][0],
    tuning_distribution=str(results['tuning_distribution']),
    tuning_loc=results['tuning_locs'][1],
    tuning_scale=results['tuning_scale'].item(),
    tuning_sparsity=results['tuning_sparsity'].item(),
    tuning_random_state=results['tuning_random_states'][0]
)

In [33]:
print(tm.a.ravel())
print(tm.b.ravel())

[ 0.         -1.80070927 -1.4239516  -2.25000536 -2.91930531 -2.86834318
  0.          0.          0.          0.        ]
[0.         0.         2.6186867  2.57596026 4.64939162 2.07196026
 2.42118227 0.         0.         0.        ]


In [34]:
print(results['a_true'][1, 0, 0, 0, 0, 0, 0])
print(results['b_true'][1, 0, 0, 0, 0, 0, 0])

[ 0.         -1.80070927 -1.4239516  -2.25000536 -2.91930531 -2.86834318
  0.          0.          0.          0.        ]
[0.         0.         2.6186867  2.57596026 4.64939162 2.07196026
 2.42118227 0.         0.         0.        ]


In [35]:
X, Y, y = tm.generate_samples(n_samples=results['D'].item(),
                              random_state=results['dataset_random_states'][0])

In [37]:
cv = check_cv(2)

In [98]:
train_idx, test_idx = list(cv.split(X))[0]
X_train = X[train_idx]
Y_train = Y[train_idx]
y_train = y[train_idx]
X_test = X[test_idx]
Y_test = Y[test_idx]
y_test = y[test_idx]

In [116]:
solver = TCSolver(
    X=X_train,
    Y=Y_train,
    y=y_train,
    c_coupling=0.1,
    c_tuning=results['tuning_lambdas'][0],
    tol=1e-2,
    max_iter=10000,
    initialization='random',
    solver='cd').fit_lasso(refit=True)

In [117]:
print(solver.a.ravel())
print(solver.b.ravel())

[ 1.13085924 -0.58069764 -1.39592524 -1.86943724 -2.5578471  -2.37845829
  0.64239794  1.57579009  0.58078552  0.31223051]
[ -7.16097661  -4.67339156 -10.56641833  -7.25169031  -5.16373909
  -9.33391295 -12.13216904  -7.17084384  -8.43993944  -7.60431761]


In [45]:
print(results['a_est'][1, 0, 0, 0, 1, 0, 0])
print(results['b_est'][1, 0, 0, 0, 1, 0, 0])

[ 0.51461096 -0.93911426 -1.79102391 -1.97371346 -1.76061942 -1.90530908
  1.60415288  0.49371094  1.1979827   0.89633542]
[ -2.0272286   -2.79858611  -6.35695577  -7.65301251  -3.59144671
 -12.65703412 -18.17689371 -14.18729573  -8.88213244  -8.32358116]
