In [92]:
%matplotlib notebook
%load_ext snakeviz
%load_ext autoreload
%autoreload 2
from supplementary.simple_choice_model import hits_gen as hits
from supplementary.simple_choice_model import sim_tools
import ipywidgets as wid
import loc_utils as lut

import vis_utils as vut
import numpy as np
import pandas as pd
import scipy as sp
import contextlib
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from IPython.display import display
from itertools import combinations
from tqdm import tqdm_notebook

colors = ['#43799d', '#cc5b46', '#ffbb00', '#71bc78', '#43799d', '#cc5b46', '#ffbb00', '#71bc78']

@contextlib.contextmanager
def temp_seed(seed):
    state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        np.random.set_state(state)

The snakeviz extension is already loaded. To reload it, use:
  %reload_ext snakeviz
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Estimating population parameters from individual-level data (via negative log of trial-wise composite likelihood)

In [116]:
np.random.seed(1)

def neg_log_likelihood(params, *args):
    coeffs = np.array(params[:-1])
    inps = np.stack(args[:-1], axis=0)
    U = (coeffs[:, None, None] * inps).sum(axis=0)
    exponent = np.exp(U * params[-1])
    P = (exponent.T / np.sum(exponent, axis=1)).T
    logP = np.log(P[args[-1].astype(bool)])
    logL = np.sum(logP, axis=0)
    return -logL

# Estimate the params
df = lut.unpickle('supplementary/simple_choice_model/data/fit_data.pkl')
df = df.loc[df.ntm != 0, :]
data_dict = {'sid': [], 'grp': [], 'ntm': [], 'loss': [], 
             'alpha': [], 'beta': [], 'gamma': [], 'theta': [], 'tau': []}
alpha_bounds = [-1, 1]
beta_bounds = [-1, 1]
gamma_bounds = [0, 1]
theta_bounds = [-1, 0]
tau_bounds = [1, 10]

bounds = (alpha_bounds, beta_bounds, gamma_bounds, theta_bounds, tau_bounds)
init_guess = sim_tools.rand_params(bounds)

for i, sdf in df.groupby('sid'):
    sid, grp, ntm = sdf.sid.values[0], sdf.grp.values[0], sdf.ntm.values[0]
    lps = sdf.loc[:, 'lp1':'lp4'].values[1:, :]
    pcs = sdf.loc[:, 'pc1':'pc4'].values[1:, :]
    ins = sdf.loc[:, 'in1':'in4'].values[1:, :]
    chs = sdf.loc[:, 'ch1':'ch4'].values[1:, :]
    time_alloc = (sdf.loc[:, 'ch1':'ch4'].values[1:, :].cumsum(axis=0) + 15)
    trs = (time_alloc.T / time_alloc.sum(axis=1)).T

    data = (lps, pcs, ins, trs, chs)
    x, f, d = sp.optimize.fmin_l_bfgs_b(func=neg_log_likelihood, x0=init_guess, args=data,
                                    approx_grad=True, disp=False, bounds=bounds)

    # Store params   
    data_dict['sid'].append(sid)
    data_dict['grp'].append(grp)
    data_dict['ntm'].append(ntm)
    data_dict['loss'].append(f)
    data_dict['alpha'].append(x[0])
    data_dict['beta'].append(x[1])
    data_dict['gamma'].append(x[2])
    data_dict['theta'].append(x[3])
    data_dict['tau'].append(x[4])

# Calculate parameter stats
fdf = pd.DataFrame(data_dict)
gfdf = fdf.groupby(['grp','ntm']).mean().drop(columns=['sid'])
display(gfdf)

Unnamed: 0_level_0,Unnamed: 1_level_0,loss,alpha,beta,gamma,theta,tau
grp,ntm,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
0,1,26.95532,0.1588,0.29659,0.69673,-0.4115,7.94423
0,2,30.33708,-0.02072,0.09978,0.72372,-0.58023,7.94853
0,3,35.11903,0.13616,-0.1894,0.74029,-0.54902,7.44947
1,1,24.75883,0.11222,-0.2473,0.78917,-0.29117,7.49193
1,2,35.51072,0.06113,-0.30924,0.74885,-0.51468,7.24357
1,3,33.6737,0.12082,-0.41099,0.77443,-0.57813,6.99315


## Model comparison

In [106]:
def neg_log_likelihood(params, *args):
    coeffs = np.array(params[:-1])
    inps = np.stack(args[:-1], axis=0)
    U = (coeffs[:, None, None] * inps).sum(axis=0)
    exponent = np.exp(U * params[-1])
    P = (exponent.T / np.sum(exponent, axis=1)).T
    logP = np.log(P[args[-1].astype(bool)])
    logL = np.sum(logP, axis=0)
    return -logL

# Estimate the params
df = lut.unpickle('supplementary/simple_choice_model/data/fit_data.pkl')
df = df.loc[df.ntm != 0, :]
data_dict = {'form': [], 'sid': [], 'grp': [], 'ntm': [], 'loss': [], 'aic': [], 'bic': [],
             'alpha': [], 'beta': [], 'gamma': [], 'theta': [], 'tau': []}
varnames = np.array(('LP', 'PC', 'I', 'TR'))
alpha_bounds = [-1, 1]
beta_bounds = [-1, 1]
gamma_bounds = [0, 1]
theta_bounds = [-1, 0]
tau_bounds = [1, 50]

bounds = (alpha_bounds, beta_bounds, gamma_bounds, theta_bounds, tau_bounds)
models, inds = [], [0,1,2,3]
for s in range(1,len(inds)+1):
    models += combinations(inds, s)

grouped = df.groupby('sid')
with tqdm_notebook(total=len(grouped)) as progbar:
    for i, sdf in grouped:
        sid, grp, ntm = sdf.sid.values[0], sdf.grp.values[0], sdf.ntm.values[0]
        lps = sdf.loc[:, 'lp1':'lp4'].values[1:, :]
        pcs = sdf.loc[:, 'pc1':'pc4'].values[1:, :]
        ins = sdf.loc[:, 'in1':'in4'].values[1:, :]
        chs = sdf.loc[:, 'ch1':'ch4'].values[1:, :]
        time_alloc = (sdf.loc[:, 'ch1':'ch4'].values[1:, :].cumsum(axis=0) + 15)
        trs = (time_alloc.T / time_alloc.sum(axis=1)).T

        data = (lps, pcs, ins, trs, chs)
        init_guess = sim_tools.rand_params(bounds).tolist()

        for model in models:
            subdata = [data[mi] for mi in model] + [data[-1]]
            subguess = [init_guess[mi] for mi in model] + [init_guess[-1]]
            subbounds = [bounds[mi] for mi in model] + [bounds[-1]]

            x, f, d = sp.optimize.fmin_l_bfgs_b(func=neg_log_likelihood, x0=subguess, args=tuple(subdata),
                                            approx_grad=True, disp=False, bounds=subbounds)

            # Store params
            vec = np.full(5, np.nan)
            vec[tuple([model])] = x[:-1]
            vec[-1] = x[-1]
            data_dict['form'].append(' + '.join(varnames[tuple([model])]))
            data_dict['sid'].append(sid)
            data_dict['grp'].append(grp)
            data_dict['ntm'].append(ntm)
            data_dict['loss'].append(f)
            data_dict['aic'].append(2*f + 2*len(models))
            data_dict['bic'].append(2*f + np.log(chs.shape[0])*len(models))
            data_dict['alpha'].append(vec[0])
            data_dict['beta'].append(vec[1])
            data_dict['gamma'].append(vec[2])
            data_dict['theta'].append(vec[3])
            data_dict['tau'].append(vec[4])
        progbar.update()

# Calculate parameter stats
fdf = pd.DataFrame(data_dict)
gfdf = fdf.groupby(['form']).mean()[['aic', 'bic']]
display(gfdf.sort_values(by='bic'))

HBox(children=(IntProgress(value=0, max=320), HTML(value='')))




Unnamed: 0_level_0,aic,bic
form,Unnamed: 1_level_1,Unnamed: 2_level_1
LP + PC + I + TR,93.87194,146.57337
PC + I + TR,95.75397,148.4554
LP + PC + I,96.69262,149.39405
LP + I + TR,97.78605,150.48748
PC + I,98.62696,151.32839
LP + I,99.75707,152.4585
I + TR,100.39001,153.09144
I,101.76421,154.46565
LP + PC + TR,576.22747,628.9289
LP + PC,576.87434,629.57578


## Visualize

In [None]:
# Visualize empirical data
plt.close()
fig = plt.figure('Fitting individuals', figsize=[9,4])
ax_ = vut.pretty(fig.add_subplot(2,3,3))
group_choices = df.groupby(['grp', 'trial']).mean()
for grp in [0, 1]:
    ax = vut.pretty(fig.add_subplot(2,3,grp+1))
    ax.set_title('% selection across time ({})'.format('FS'[grp]))
    ax.set_ylim(0,.7)
    if not grp: ax.set_ylabel('Empirical data')
        
    for tid in [1, 2, 3, 4]:   
        ax.plot(group_choices.loc[(grp, slice(None)), 'ch{}'.format(tid)].values, color = colors[tid-1])
    
    ax_.set_title('Total time allocation')
    ax_.plot(group_choices.loc[(grp, slice(None)), 'ch1':'ch4'].values.mean(axis=0), 
            color=['#008fd5', '#fc4f30'][grp], label='FS'[grp], lw=2)
    ax_.set_xticks([0,1,2,3])
    ax_.set_xticklabels(['1D', 'I1D', '2D', 'R'])
    ax_.set_ylim(0,.7)
ax_.legend()
    
# Visualize data generated by simulation with estimated parameters
df = lut.unpickle('supplementary/simple_choice_model/data/fit_data.pkl')
df = df.loc[df.ntm != 0, :]
N_trials = 250
group_sizes = (df.groupby(['grp','ntm','sid']).count()/249).reset_index().groupby(['grp','ntm']).count()
N_runs = 5

runs_data = {0: [], 1: []}
for run in range(N_runs):
    for grp in [0, 1]:
        grp_simdata = []
        for ntm in [1, 2, 3]:
            sids = df.loc[(df.grp==grp) & (df.ntm==ntm), 'sid'].unique()
            N_sim = group_sizes.loc[(grp, ntm), 'sid']
            sids = np.random.choice(sids, size=N_sim)
            hits_params = hits.get_parametric(grp=grp, ntm=ntm)
            trials = np.arange(N_trials) + 1
            probs = np.stack([1 / (1 + np.exp(-(hits_params[tid][0] + hits_params[tid][1]*trials))) for tid in [1,2,3,4]], axis=1)
            simhits = (np.random.rand(N_sim, N_trials, 4) <= probs).astype(int)
            init_data = sim_tools.get_multiple_sids(sids)

            simdata = []
            for i, sid in enumerate(sids):
                alpha, beta, gamma, theta, tau = fdf.set_index('sid').loc[sid, 'alpha':'tau'].values
                choices = sim_tools.simple_simulation(init_state=init_data[i, :, :], 
                                                      win1=10, win2=9, N=N_trials, 
                                                      hits = simhits[i, :, :], 
                                                      alpha=alpha, beta=beta, 
                                                      gamma=gamma, theta=theta, tau=tau,
                                                      inverse_temp=True)
                simdata.append(np.eye(4)[choices.astype(int)])
            grp_simdata.append(np.stack(simdata, axis=0).mean(axis=0))
        runs_data[grp].append(np.stack(grp_simdata, axis=0).mean(axis=0))

ax_ = vut.pretty(fig.add_subplot(2,3,6))
for grp in [0, 1]:
    mean_runs_data = np.stack(runs_data[grp], axis=0).mean(axis=0)
    se_runs_data = sp.stats.sem(np.stack(runs_data[grp], axis=0), axis=0)
    
    # Plot simulated percent selection across time
    ax = vut.pretty(fig.add_subplot(2,3,grp+4))
    ax.set_ylim(0,.7)
    if not grp: ax.set_ylabel('Simulated data')
    for tid in [1, 2, 3, 4]:   
        ax.plot(mean_runs_data[:, tid-1], color = colors[tid-1])
        ax.fill_between(np.arange(mean_runs_data.shape[0]), 
                        mean_runs_data[:, tid-1]+se_runs_data[:, tid-1], 
                        mean_runs_data[:, tid-1]-se_runs_data[:, tid-1], 
                        color = colors[tid-1], alpha=.3)
        
    # Plot simulated TIME ALLOCATION
    ax_.plot(mean_runs_data.mean(axis=0), color=['#008fd5', '#fc4f30'][grp], lw=2)
    ax_.fill_between([0,1,2,3], 
                        mean_runs_data.mean(axis=0)+se_runs_data.mean(axis=0), 
                        mean_runs_data.mean(axis=0)-se_runs_data.mean(axis=0), 
                        color = ['#008fd5', '#fc4f30'][grp], alpha=.3)
    ax_.set_xticks([0,1,2,3])
    ax_.set_xticklabels(['1D', 'I1D', '2D', 'R'])
    ax_.set_ylim(0,.7)
    
# fig.savefig('fitted_runs_LP_PC_I_TR_temp.png')

In [None]:
### Qualitative evaluation of individual fits

In [42]:
def neg_log_likelihood(params, *args):
    a, b, c, d, t = params
    LP, PC, I, TR, choices = args
    U = a*LP + b*PC + c*I + d*TR
    P = (np.exp(U * t).T / np.sum(np.exp(U * t), axis=1)).T
    logP = np.log(P[choices.astype(bool)])
    logL = np.sum(logP, axis=0)
    return -logL

# Estimate the params
df = lut.unpickle('supplementary/simple_choice_model/data/fit_data.pkl')
df = df.loc[df.ntm != 0, :]
data_dict = {'sid': [], 'grp': [], 'ntm': [], 'loss': [], 
             'alpha': [], 'beta': [], 'gamma': [], 'theta': [], 'tau': []}
alpha_bounds = [-1, 1]
beta_bounds = [-1, 1]
gamma_bounds = [0, 1]
theta_bounds = [-1, 0]
tau_bounds = [1, 100]

bounds = (alpha_bounds, beta_bounds, gamma_bounds, theta_bounds, tau_bounds)

gfdfs = []
for run in range(50):
    for i, sdf in df.groupby('sid'):
        loss_hist = []
        sid, grp, ntm = sdf.sid.values[0], sdf.grp.values[0], sdf.ntm.values[0]
        lps = sdf.loc[:, 'lp1':'lp4'].values[1:, :]
        pcs = sdf.loc[:, 'pc1':'pc4'].values[1:, :]
        ins = sdf.loc[:, 'in1':'in4'].values[1:, :]
        chs = sdf.loc[:, 'ch1':'ch4'].values[1:, :]
        time_alloc = (sdf.loc[:, 'ch1':'ch4'].values[1:, :].cumsum(axis=0) + 15)
        trs = (time_alloc.T / time_alloc.sum(axis=1)).T

        data = (lps, pcs, ins, trs, chs)
        init_guess = sim_tools.rand_params(bounds)

        x, f, d = sp.optimize.fmin_l_bfgs_b(func=neg_log_likelihood, x0=init_guess, args=data,
                                        approx_grad=True, disp=False, bounds=bounds,
                                        callback=lambda xk: loss_hist.append([xk, neg_log_likelihood(xk, *data)]))
        # Store params   
        data_dict['sid'].append(sid)
        data_dict['grp'].append(grp)
        data_dict['ntm'].append(ntm)
        data_dict['loss'].append(f)
        data_dict['alpha'].append(x[0])
        data_dict['beta'].append(x[1])
        data_dict['gamma'].append(x[2])
        data_dict['theta'].append(x[3])
        data_dict['tau'].append(x[4])

    # Calculate parameter stats
    fdf = pd.DataFrame(data_dict)
    gfdf = fdf.groupby(['grp','ntm']).mean().drop(columns=['sid'])
    gfdfs.append(gfdf.reset_index().values)

Unnamed: 0_level_0,Unnamed: 1_level_0,loss,alpha,beta,gamma,theta,tau
grp,ntm,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
0.0,1.0,26.41085,0.09644,0.1423,0.26314,-0.22289,47.64495
0.0,2.0,29.97449,0.00084,0.05753,0.28611,-0.27455,46.03194
0.0,3.0,35.03755,0.07498,-0.05362,0.2668,-0.22383,43.17138
1.0,1.0,24.6904,0.03072,-0.07529,0.29928,-0.15761,47.36352
1.0,2.0,35.31411,0.05033,-0.10388,0.29073,-0.20764,43.97381
1.0,3.0,33.67059,0.03429,-0.13606,0.26228,-0.22331,42.98296


In [43]:
mean = np.stack(gfdfs, axis=0).mean(axis=0)
std = np.stack(gfdfs, axis=0).std(axis=0)
mean_gfdf = gfdf.reset_index()
mean_gfdf.loc[:, :] = mean

std_gfdf = gfdf.reset_index()
std_gfdf.loc[:, :] = std

display(mean_gfdf.set_index(['grp','ntm']))


Unnamed: 0_level_0,Unnamed: 1_level_0,loss,alpha,beta,gamma,theta,tau
grp,ntm,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
0.0,1.0,26.41085,0.09644,0.1423,0.26314,-0.22289,47.64495
0.0,2.0,29.97449,0.00084,0.05753,0.28611,-0.27455,46.03194
0.0,3.0,35.03755,0.07498,-0.05362,0.2668,-0.22383,43.17138
1.0,1.0,24.6904,0.03072,-0.07529,0.29928,-0.15761,47.36352
1.0,2.0,35.31411,0.05033,-0.10388,0.29073,-0.20764,43.97381
1.0,3.0,33.67059,0.03429,-0.13606,0.26228,-0.22331,42.98296


Unnamed: 0_level_0,Unnamed: 1_level_0,loss,alpha,beta,gamma,theta,tau
grp,ntm,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
0.0,0.0,0.00541,0.00483,0.00712,0.00786,0.0046,0.9574
0.0,0.0,0.00206,0.00214,0.00676,0.00419,0.0035,0.39201
0.0,0.0,0.00556,0.0092,0.00283,0.00635,0.00494,0.33874
0.0,0.0,0.04239,0.0046,0.00974,0.01614,0.0041,1.45029
0.0,0.0,0.02544,0.00383,0.00549,0.01364,0.01095,0.93639
0.0,0.0,0.07887,0.00141,0.00155,0.00304,0.00281,0.46543


## Estimating parameters from group-level data (via KL divergence)

### TODO:
- Make sure the D_KL loss is computed appropriately (test on various vector pairs)
- Try an exhaustive grid search of the sample space to see which regions correspond to minimal losses
    - See if these "good" parameter space regions make sense
- Build an interactive visualization (alpha and beta on the x, y axis, with grids colored according to loss value). TAU should be represented as a slider that should change the appearence of the 2d-alpha-beta grid.
    - Try to visualize the data-simulation comparison histograms by selecting data by click_on event on the imshow object.

In [None]:
# %%snakeviz
def DKL_loss(params, *args):
    a, b, t = params
    sids, data_beg, data_mid, data_end, init_data, simhits = args
    
    choices_beg = []
    choices_mid = []
    choices_end = []
    for i, sid in enumerate(sids):
        choices = sim_tools.simple_simulation(init_state=init_data[i, :, :], 
                                              win1=10, win2=9, N=250, 
                                              hits = simhits[i, :, :], 
                                              alpha=a, beta=b, 
                                              gamma=0, tau=t)
        choices_beg.append(np.eye(4)[choices[:10].astype(int)])
        choices_mid.append(np.eye(4)[choices[:250//2].astype(int)])
        choices_end.append(np.eye(4)[choices.astype(int)])
    choices_beg = np.stack(choices_beg).mean(axis=0).mean(axis=0)
    choices_mid = np.stack(choices_mid).mean(axis=0).mean(axis=0)
    choices_end = np.stack(choices_end).mean(axis=0).mean(axis=0)
    DKL_beg = sp.special.kl_div(data_beg, choices_beg).sum()
    DKL_mid = sp.special.kl_div(data_mid, choices_mid).sum()
    DKL_end = sp.special.kl_div(data_end, choices_end).sum()
    
    loss = DKL_beg + DKL_mid + DKL_end
    return loss


df = lut.unpickle('supplementary/simple_choice_model/data/fit_data.pkl')
df = df.loc[df.ntm != 0, :]
data_dict = {'grp': [], 'ntm': [], 'loss': [], 'alpha': [], 'beta': [], 'gamma': [], 'tau': []}
N_sim = 50
N_trials = 250
bounds = ([-1,1],[-1,1],[0,50])

grp, ntm = 0, 3
sids = df.loc[(df.grp==grp), 'sid'].unique()
print('Estimating params for GRP-{}, NTM-{} ({}/{}) ...'.format(grp, ntm, N_sim, sids.size))
sids = np.random.choice(sids, size=N_sim)

data_beg = df.loc[(df.grp==grp) & (df.ntm==ntm) & (df.trial<=10), 'ch1':'ch4'].values.mean(axis=0)
data_mid = df.loc[(df.grp==grp) & (df.ntm==ntm) & (df.trial<=249//2), 'ch1':'ch4'].values.mean(axis=0)
data_end = df.loc[(df.grp==grp) & (df.ntm==ntm) & (df.trial>=1), 'ch1':'ch4'].values.mean(axis=0)

hits_params = hits.get_parametric(grp=grp)
trials = np.arange(N_trials) + 1
probs = np.stack([1 / (1 + np.exp(-(hits_params[tid][0] + hits_params[tid][1]*trials))) for tid in [1,2,3,4]], axis=1)
simhits = (np.random.rand(N_sim, N_trials, 4) <= probs).astype(int)
init_data = sim_tools.get_multiple_sids(sids)
data = (sids, data_beg, data_mid, data_end, init_data, simhits)

# params = (0.14954,-0.11199,0.47016,11.34070)
# params = (0.10445,-0.21977,0.47538,10.04687)
init_guess = sim_tools.rand_params(bounds)
# init_guess = gfdf.loc[(grp, ntm),'alpha':'tau'].values
# DKL_loss(params, *data)

opt_res = sp.optimize.minimize(fun=DKL_loss, 
                               x0=init_guess,
                               args=data,
                               bounds=bounds)
print(init_guess)
print(opt_res.x)
# for grp in [0, 1]:
#     for ntm in [1, 2, 3]: 
#         data_dict['grp'].append(grp)
#         data_dict['ntm'].append(ntm)
#         data_dict['loss'].append(opt_res.fun)
#         data_dict['alpha'].append(opt_res.x[0])
#         data_dict['beta'].append(opt_res.x[1])
#         data_dict['gamma'].append(opt_res.x[2])
#         data_dict['tau'].append(opt_res.x[3])

In [None]:
def DKL_loss(params, *args):
    a, b, t = params
    sids, data_beg, data_mid, data_end, init_data, simhits = args
    
    choices_beg = []
    choices_mid = []
    choices_end = []
    for i, sid in enumerate(sids):
        choices = sim_tools.simple_simulation(init_state=init_data[i, :, :], 
                                              win1=10, win2=9, N=250, 
                                              hits = simhits[i, :, :], 
                                              alpha=a, beta=b, 
                                              gamma=0, tau=t)
        choices_beg.append(np.eye(4)[choices[:10].astype(int)])
        choices_mid.append(np.eye(4)[choices[:250//2].astype(int)])
        choices_end.append(np.eye(4)[choices.astype(int)])
    choices_beg = np.stack(choices_beg).mean(axis=0).mean(axis=0)
    choices_mid = np.stack(choices_mid).mean(axis=0).mean(axis=0)
    choices_end = np.stack(choices_end).mean(axis=0).mean(axis=0)
    DKL_beg = sp.special.kl_div(data_beg, choices_beg).sum()
    DKL_mid = sp.special.kl_div(data_mid, choices_mid).sum()
    DKL_end = sp.special.kl_div(data_end, choices_end).sum()
#     print(data_beg)
#     print(choices_beg)
#     print(data_mid)
#     print(choices_mid)
#     print(data_end)
#     print(choices_end)
    loss = DKL_beg + DKL_mid + DKL_end
    return loss, [data_beg, data_mid, data_end], [choices_beg, choices_mid, choices_end]


loss, _data, preds = DKL_loss(opt_res.x, *data)

plt.figure('new', figsize=[9, 3])
for i in range(3):
    plt.subplot(131+i)
    plt.plot([1,2,3,4], _data[i], c='k', ls='-')
    plt.plot([1,2,3,4], preds[i], c='k', ls='--')
    plt.ylim(0,.6)