In [1]:
import h5py
import matplotlib.pyplot as plt
import numpy as np

from neurobiases import (TriangularModel,
                         EMSolver,
                         TCSolver,
                         solver_utils,
                         plot)
from sklearn.model_selection import check_cv

%matplotlib inline

In [2]:
# Open the stored results file read-only. The handle is kept open on purpose:
# later cells index into `results` directly.
results = h5py.File('../em_oracle.h5', mode='r')

In [3]:
# Names of the top-level entries stored in the results file.
[name for name in results]

['B',
 'B_est',
 'L',
 'L_est',
 'Psi',
 'Psi_est',
 'a',
 'a_est',
 'b',
 'b_est',
 'bics',
 'coupling_lambdas',
 'coupling_locs',
 'coupling_rngs',
 'dataset_rngs',
 'mlls',
 'shape_key',
 'tuning_lambdas',
 'tuning_locs',
 'tuning_rngs']

In [4]:
# Entry names on one line, then the contents of `shape_key`
# (presumably the axis order of the stored arrays -- confirm against the writer).
print(list(results), results['shape_key'][:], sep='\n')

['B', 'B_est', 'L', 'L_est', 'Psi', 'Psi_est', 'a', 'a_est', 'b', 'b_est', 'bics', 'coupling_lambdas', 'coupling_locs', 'coupling_rngs', 'dataset_rngs', 'mlls', 'shape_key', 'tuning_lambdas', 'tuning_locs', 'tuning_rngs']
['tuning_loc' 'coupling_loc' 'model_idx' 'dataset_idx' 'split_idx'
 'coupling_lambda' 'tuning_lambda']


In [5]:
# Stored rng values (later passed as `tuning_rng`, `coupling_rng` and `rng`),
# printed one array per line.
print(
    *(results[name][:] for name in ('tuning_rngs', 'coupling_rngs', 'dataset_rngs')),
    sep='\n',
)

[1949932628 3437779209]
[  66316748 2930678936]
[1274196501 2786552199]


In [30]:
# Rebuild the generative model from the settings stored alongside the results.
# The index choices below (tuning_locs[0], coupling_locs[1], rngs[1]) line up
# with the [0, 1, 1, 0, 1] entry of the stored arrays read in later cells;
# presumably that is (tuning_loc, coupling_loc, model_idx, ...) per
# results['shape_key'] -- confirm against the code that wrote the file.
tm = TriangularModel(
    model='linear',
    parameter_design='direct_response',
    # File attributes whose names coincide with the constructor keywords are
    # forwarded unchanged.
    **{
        name: results.attrs[name]
        for name in (
            'M', 'N', 'K',
            'corr_cluster', 'corr_back',
            'coupling_distribution', 'coupling_sparsity', 'coupling_scale',
            'tuning_distribution', 'tuning_scale', 'tuning_sparsity',
        )
    },
    # Coupling parameters selected per experiment.
    coupling_loc=results['coupling_locs'][1],
    coupling_rng=results['coupling_rngs'][1],
    # Tuning parameters selected per experiment.
    tuning_loc=results['tuning_locs'][0],
    tuning_rng=results['tuning_rngs'][1],
)

In [34]:
# Read the whole `shape_key` dataset into memory for display.
results['shape_key'][...]

array(['tuning_loc', 'coupling_loc', 'model_idx', 'dataset_idx',
       'split_idx', 'coupling_lambda', 'tuning_lambda'], dtype=object)

In [31]:
# Flattened `a` and `b` parameters of the rebuilt model, one per line.
print(tm.a.ravel(), tm.b.ravel(), sep='\n')

[0.         0.         0.         3.25159523 4.13021651 2.14776041]
[ 0.          0.          1.2064731  -1.05602344 -0.44346509 -0.92659889
  0.          0.        ]


In [36]:
# Stored `a` and `b` at index [0, 1, 1, 0, 1], for comparison with the
# parameters of the rebuilt model.
print(*(results[name][0, 1, 1, 0, 1] for name in ('a', 'b')), sep='\n')

[0.         0.         0.         3.25159523 4.13021651 2.14776041]
[ 0.          0.          1.2064731  -1.05602344 -0.44346509 -0.92659889
  0.          0.        ]


In [37]:
# Draw the dataset with the stored sample count and the first stored dataset rng.
X, Y, y = tm.generate_samples(
    n_samples=results.attrs['D'],
    rng=results['dataset_rngs'][0],
)

In [38]:
# Two-fold cross-validation splitter.
cv = check_cv(cv=2)

In [39]:
# Take the second (index 1) split and slice every array with the same indices.
train_idx, test_idx = tuple(cv.split(X))[1]
X_train, Y_train, y_train = X[train_idx], Y[train_idx], y[train_idx]
X_test, Y_test, y_test = X[test_idx], Y[test_idx], y[test_idx]

In [40]:
# Fit the TC solver on the training split with zero coupling/tuning penalties.
# NOTE(review): `solver` is rebound by the EMSolver cell that follows before it
# is read anywhere in the visible notebook -- confirm whether this fit is still
# needed.
solver = (
    TCSolver(
        # Training data.
        X=X_train, Y=Y_train, y=y_train,
        # Penalty strengths.
        c_coupling=0, c_tuning=0,
        # Optimizer settings taken from the stored run.
        tol=results.attrs['tol'],
        max_iter=results.attrs['max_iter'],
        initialization='random',
        solver='cd',
    )
    .fit_lasso(refit=True)
)

In [42]:
# Fit the EM solver on the training split, with masks taken from the nonzero
# pattern of the stored parameters at index [0, 1, 1, 0, 1].
# NOTE(review): K is hard-coded to 1 here while the model cell reads
# results.attrs['K'] -- confirm the two agree.
solver = (
    EMSolver(
        # Training data.
        X=X_train, Y=Y_train, y=y_train,
        K=1,
        # Masks: True where the stored parameter is nonzero.
        a_mask=results['a'][0, 1, 1, 0, 1] != 0,
        b_mask=results['b'][0, 1, 1, 0, 1] != 0,
        B_mask=results['B'][0, 1, 1, 0, 1] != 0,
        # Penalty strengths.
        c_coupling=0, c_tuning=0,
        # Optimizer settings taken from the stored run.
        tol=results.attrs['tol'],
        max_iter=results.attrs['max_iter'],
        initialization='fits',
        solver='ow_lbfgs',
        rng=456,
    )
    .fit_em(refit=False)
)

In [43]:
# Flattened parameters of the fitted solver, one per line.
print(solver.a.ravel(), solver.b.ravel(), solver.B.ravel(), sep='\n')

[ 0.          0.          0.          4.98010851 10.52384809  2.50885544]
[ 0.          0.         -1.10369432 -2.05116481  0.86454204 -0.5368986
  0.          0.        ]
[ 1.46389437  0.71576336  0.          0.          0.          0.
 -0.13652049 -1.21438368 -0.14971124  0.          0.          0.
 -0.17992822  0.55721223  1.80435746  0.69166882  0.          0.
  0.          0.78413344  0.42563803  1.64540932 -0.40375277  0.
  0.          0.         -0.48606762 -0.09546449 -0.419451    0.
  0.          0.          0.         -0.42182721  0.71817558 -1.30579091
  0.          0.          0.          0.         -0.14552053  0.49797361
  0.          0.          0.          0.          0.         -3.88688013]


In [44]:
# Stored estimates at the same index, for side-by-side comparison with the
# refit above. Only B_est is flattened.
print(
    results['a_est'][0, 1, 1, 0, 1],
    results['b_est'][0, 1, 1, 0, 1],
    results['B_est'][0, 1, 1, 0, 1].ravel(),
    sep='\n',
)

[ 0.          0.          0.          4.98010851 10.52384809  2.50885544]
[ 0.          0.         -1.10369432 -2.05116481  0.86454204 -0.5368986
  0.          0.        ]
[ 1.46389437  0.71576336  0.          0.          0.          0.
 -0.13652049 -1.21438368 -0.14971124  0.          0.          0.
 -0.17992822  0.55721223  1.80435746  0.69166882  0.          0.
  0.          0.78413344  0.42563803  1.64540932 -0.40375277  0.
  0.          0.         -0.48606762 -0.09546449 -0.419451    0.
  0.          0.          0.         -0.42182721  0.71817558 -1.30579091
  0.          0.          0.          0.         -0.14552053  0.49797361
  0.          0.          0.          0.          0.         -3.88688013]


In [22]:
# Read the whole `shape_key` dataset into memory for display.
results['shape_key'][...]

array(['tuning_loc', 'coupling_loc', 'model_idx', 'dataset_idx',
       'split_idx', 'coupling_lambda', 'tuning_lambda'], dtype=object)