In [1]:
from lca_conf import SpotLCA
from RunDEMC.io import load_results
from RunDEMC import Model, Param, dists
from pathlib2 import Path
import pandas as pd
import numpy as np

In [2]:
# Bounds (in the same units as the rescaled RTs) used to normalise the speed
# term of the flanker score.
SCORE_MIN_RT = 0.35
SCORE_MAX_RT = 1.35


def flkr_score(correct, rts):
    """Combine accuracy and speed into a single flanker score.

    Parameters
    ----------
    correct : array-like
        Per-trial correctness; cast to bool before averaging.
    rts : array-like
        Per-trial response times, one per entry of ``correct``.

    Returns
    -------
    float
        ``speed * accuracy * 100`` where ``accuracy`` is the proportion
        correct rescaled so that 0.5 maps to 0 and 1.0 maps to 1, and
        ``speed`` is the mean log-scaled speed of the rescaled RTs.
        ``nan`` is returned when there are no trials to score.
    """
    correct = np.asarray(correct).astype(bool)
    # Force float so integer RTs are never floor-divided (Python 2 semantics).
    rts = np.asarray(rts, dtype=float)
    if correct.size == 0 or rts.size == 0:
        # Nothing to score (e.g. a condition with no trials left after filtering).
        return np.nan

    accuracy = (correct.mean() - 0.5) / 0.5

    # Shift the fastest trial onto SCORE_MIN_RT and scale by the slowest RT.
    # The offset used to be a literal 0.35 duplicating SCORE_MIN_RT.
    scaled = (rts - rts.min()) / rts.max() + SCORE_MIN_RT

    log_slowest = np.log(SCORE_MAX_RT + 1)
    log_fastest = np.log(SCORE_MIN_RT + 1)
    speeds = (log_slowest - np.log(scaled + 1)) / (log_slowest - log_fastest)
    speed = speeds.mean()

    return speed * accuracy * 100

In [3]:
# Directory holding the fitted-model archives (flanker_flkr-<sub_id>.tgz).
res_dir = Path('/cogmood/data/derivatives/model_res/')
# Directory holding the per-subject behavioural files (flkr-<sub_id>.csv).
csv_dir = Path('/cogmood/data/pilot/to_model')
# A single subject id; the per-subject loop later rebinds `subid`.
subid = '11nuj5ty67ojohm39cmzbt23'

In [4]:
# Collect one subject id per behavioural file, in sorted file order.
# The id is the text after the last '-' in the part of the file name that
# precedes the first '.'.
sub_ids = []
for fp in sorted(csv_dir.glob('flkr-*.csv')):
    base_name = fp.parts[-1].split('.')[0]
    sub_ids.append(base_name.split('-')[-1])

In [5]:
# --- Data handling -----------------------------------------------------------
log_shift = 0.05               # added to RTs before taking the log; also passed to SpotLCA
max_rt = 2                     # per-condition data keeps only trials with rt < max_rt
conditions = ['+', '=', '~']   # values of the `condition` column that get scored

# --- Posterior handling ------------------------------------------------------
burnin = 400                   # leading entries dropped from particles / weights
p_post_points = 1000           # posterior draws simulated per subject

# --- Simulation --------------------------------------------------------------
nsims = 10000                  # passed to SpotLCA as `nsims`

# Default keyword arguments for SpotLCA.simulate; each subject's fitted
# parameters are layered on top of a copy of this dict.
def_params = {
    'r': 0.1,
    'p': 1.0,
    'sd0': 2.0,
    'bin_lim': 0.5,
    'in_bin': 1.5,
    'out_bin': 2.5,
    'sd_min': 0.01,
    'K': 0.1,
    'L': 0.5,
    'U': 0.0,
    'eta': 1.,
    't0': 0.25,
    'thresh': 1.,
    'alpha': 0.,
    'max_time': 5.0,
    'truncate': True,
    'dt': 0.01,
    'tau': 0.1,
}



In [None]:
def _condition_bins(mod_params):
    """Build the per-condition bin specification handed to SpotLCA.

    Every condition uses the same five bin edges, derived from the
    ``out_bin``, ``in_bin`` and ``bin_lim`` entries of ``mod_params``; the
    conditions differ only in their ``bin_ind`` vector.
    """
    out_bin = mod_params['out_bin']
    in_bin = mod_params['in_bin']
    bin_lim = mod_params['bin_lim']
    bin_inds = {'+': [0, 0, 0, 0, 0],
                '=': [1, 0, 0, 0, 1],
                '~': [0, 1, 0, 1, 0]}
    dbin = {}
    for cond in ('+', '=', '~'):
        dbin[cond] = {
            # A fresh array per condition so no two conditions share memory.
            'bins': np.array([[-out_bin, -in_bin],
                              [-in_bin, -bin_lim],
                              [-bin_lim, bin_lim],
                              [bin_lim, in_bin],
                              [in_bin, out_bin]], dtype=np.float32),
            'bin_ind': np.array(bin_inds[cond], dtype=np.int32),
            'nbins': 5}
    return dbin


def _simulate_scores(params):
    """Simulate every condition with ``params`` and score the simulated trials.

    ``params`` (parameter name -> value) is layered over a copy of
    ``def_params``. Returns a dict with one flanker score per entry of
    ``conditions`` plus a ``'total'`` score over all simulated trials pooled.
    """
    mod_params = def_params.copy()
    mod_params.update(params)

    # Bins are rebuilt from *this* parameter set. Previously the bins built
    # from the best-weight parameters were reused for every posterior draw;
    # the result is identical whenever the bin parameters are not fitted.
    dbin = _condition_bins(mod_params)

    # NOTE(review): x_init's length follows the number of bin conditions (3),
    # not nitems (2) -- confirm this is what SpotLCA.simulate expects.
    mod_params['x_init'] = np.ones(len(dbin), dtype=np.float32)*(mod_params['thresh']*float(1/3.))

    scores = {}
    pooled_times = []
    pooled_correct = []
    for c in conditions:
        lca = SpotLCA(nitems=2, nbins=dbin[c]['nbins'],
                      nsims=nsims, log_shift=log_shift, nreps=1)
        mod_params['bins'] = dbin[c]['bins']
        mod_params['bin_ind'] = dbin[c]['bin_ind']
        out_time, x_ind, x_out, conf = lca.simulate(**mod_params)
        # Response index 1 is treated as an error, mirroring resp = ~correct
        # in the observed data.
        x_correct = ~((x_ind == 1).astype(bool))
        # NOTE(review): SpotLCA is given log_shift, so out_time may be on the
        # log(rt + log_shift) scale, while the observed data are scored on raw
        # RTs -- confirm the units match before comparing scores.
        pooled_times.extend(list(out_time))
        pooled_correct.extend(list(x_correct))
        scores[c] = flkr_score(x_correct, out_time)
    scores['total'] = flkr_score(np.array(pooled_correct), np.array(pooled_times))
    return scores


# Per-subject results: observed-data scores, scores simulated from the
# best-weight parameters, and scores simulated from random posterior draws.
all_dat_scores = []
all_map_scores = []
ppres = []
for subid in sub_ids:
    print(subid)
    s = csv_dir / ('flkr-' + subid + '.csv')
    dat = pd.read_csv(s)

    mdl_tgz = res_dir / ('flanker_flkr-' + subid + '.tgz')
    res = load_results(mdl_tgz.as_posix())

    # calculate scores on real data
    # NOTE(review): the total score uses every trial, whereas the
    # per-condition scores drop trials with rt >= max_rt -- confirm intended.
    dat_scores = {'sub_id': subid}
    dat_scores['total'] = flkr_score(dat.correct.astype(bool).values, dat.rt.values)
    ddat = {}  # per-condition log-RTs / responses; not read again in this cell
    for c in conditions:
        ind = (dat.condition == c) & (dat.rt < max_rt)
        cond_dat = dat[ind]
        # Cast before inverting so a 0/1 integer column is not bit-flipped.
        correct = cond_dat['correct'].astype(bool).values
        ddat[c] = {'rt': np.log(cond_dat.rt.values + log_shift),
                   'resp': (~correct).astype(int)}
        dat_scores[c] = flkr_score(correct, cond_dat.rt.values)
    all_dat_scores.append(dat_scores)

    # Flatten (draw, chain) into one axis once per subject; row order matches
    # the C-order ravel() of a (draw, chain) slice and the flattened argmax.
    particles = res['particles'][burnin:]
    n_draws, n_chains = particles.shape[0], particles.shape[1]
    n_total = n_draws * n_chains
    flat_particles = particles.reshape(n_total, particles.shape[-1])

    # calculate scores on map parameters (the particle with the largest weight)
    best_ind = res['weights'][burnin:].argmax()
    best_ps = dict(zip(res['param_names'], flat_particles[best_ind]))
    map_scores = {'sub_id': subid}
    map_scores.update(_simulate_scores(best_ps))
    all_map_scores.append(map_scores)

    # calculate scores on random posterior draws
    # NOTE(review): np.random is not seeded in this cell, so the draws (and
    # these scores) differ from run to run.
    for ppix in range(p_post_points):
        if ppix % 100 == 0:
            print(ppix)
        # choice(int) samples from arange(int) without building a list.
        part_id = np.random.choice(n_total)
        params = dict(zip(res['param_names'], flat_particles[part_id]))

        sim_scores = {'ppix': ppix, 'part_id': part_id, 'sub_id': subid}
        sim_scores.update(_simulate_scores(params))
        ppres.append(sim_scores)

11nuj5ty67ojohm39cmzbt23
0
100
200
300


In [8]:
# One row per simulated posterior draw (each entry of `ppres` is a dict of scores).
ppresdf = pd.DataFrame(ppres)

In [10]:
# Write the posterior-predictive scores as a tab-separated file into res_dir.
# `index` is a boolean flag, so pass False explicitly instead of None.
ppresdf.to_csv((res_dir / 'ppres.tsv').as_posix(), index=False, sep='\t')

In [11]:
# Observed-data scores: one row per subject, written tab-separated into res_dir.
# `index` is a boolean flag, so pass False explicitly instead of None.
all_dat_scores_df = pd.DataFrame(all_dat_scores)
all_dat_scores_df.to_csv((res_dir / 'scores.tsv').as_posix(), index=False, sep='\t')

In [12]:
# Scores simulated from each subject's best-weight parameters: one row per
# subject, written tab-separated into res_dir.
# `index` is a boolean flag, so pass False explicitly instead of None.
all_map_scores_df = pd.DataFrame(all_map_scores)
all_map_scores_df.to_csv((res_dir / 'map_scores.tsv').as_posix(), index=False, sep='\t')

In [14]:
# Display the per-subject scores computed from the observed data.
all_dat_scores_df

Unnamed: 0,+,=,sub_id,total,~
0,86.029404,81.053698,11nuj5ty67ojohm39cmzbt23,84.33439,74.895387
1,83.062016,80.189032,2upuqdbw3wdpk3q43x89zysp,71.486539,62.244573
2,82.563732,74.995706,48juqsgxp4m2o7797zvjxln9,77.491206,79.238249
3,83.555272,76.793314,60pixcark57tgonq4abwctvs,80.363024,77.28334
4,83.134938,73.925223,81987885tpc29718g2d8evdm,82.173027,81.046399
5,71.722008,84.069111,h3q7g3g6za07rl9qnhd87hoq,72.359139,62.07574
6,70.997722,81.16719,hvann18ezp9i2kq8bvqivehs,71.937518,69.008666
7,68.031666,69.369719,l8eyqget2wsecwew6bwabn1h,70.934597,69.85937
8,83.226578,84.668529,mglomvxjfi6gya3jmrt7o09w,80.497133,75.764692
9,80.627383,75.901583,mjff7puqxr95bh6d945ru7z2,75.912687,78.177303


In [15]:
# Display the per-subject scores simulated from the best-weight parameters.
# NOTE(review): every simulated score shown below is negative, while the
# observed-data scores are positive -- check the correctness coding and the
# time scale of the simulated output before interpreting these.
all_map_scores_df

Unnamed: 0,+,=,sub_id,total,~
0,-80.13382,-80.111053,11nuj5ty67ojohm39cmzbt23,-80.341945,-81.276336
1,-80.048184,-67.982495,2upuqdbw3wdpk3q43x89zysp,-65.234791,-48.421591
2,-83.812538,-71.091837,48juqsgxp4m2o7797zvjxln9,-68.352506,-49.685225
3,-84.303333,-71.625541,60pixcark57tgonq4abwctvs,-68.131536,-47.371972
4,-85.850124,-81.890894,81987885tpc29718g2d8evdm,-84.879379,-80.678356
5,-81.344937,-79.094544,h3q7g3g6za07rl9qnhd87hoq,-80.008351,-76.636216
6,-79.572386,-82.171287,hvann18ezp9i2kq8bvqivehs,-81.737007,-80.269166
7,-78.323314,-78.069698,l8eyqget2wsecwew6bwabn1h,-78.279231,-76.785248
8,-82.141417,-80.029386,mglomvxjfi6gya3jmrt7o09w,-80.07138,-79.783984
9,-75.644237,-80.877626,mjff7puqxr95bh6d945ru7z2,-73.413163,-79.712075


In [16]:
# Display the posterior-predictive scores (one row per simulated draw).
# NOTE(review): the rows shown below are all negative, unlike the
# observed-data scores -- confirm the simulated scoring before use.
ppresdf

Unnamed: 0,+,=,part_id,ppix,sub_id,total,~
0,-80.932950,-80.719811,88209,0,11nuj5ty67ojohm39cmzbt23,-80.889766,-80.703717
1,-80.810694,-81.235921,79257,1,11nuj5ty67ojohm39cmzbt23,-81.560254,-81.932507
2,-79.914865,-80.471182,4920,2,11nuj5ty67ojohm39cmzbt23,-80.678684,-79.482715
3,-82.191165,-78.793125,3081,3,11nuj5ty67ojohm39cmzbt23,-83.242672,-83.187998
4,-79.287178,-79.871562,124421,4,11nuj5ty67ojohm39cmzbt23,-80.437124,-78.454108
5,-62.470392,-61.099875,89679,5,11nuj5ty67ojohm39cmzbt23,-62.036642,-61.177383
6,-76.255470,-73.924159,52729,6,11nuj5ty67ojohm39cmzbt23,-75.244459,-74.450493
7,-79.239783,-80.138670,3570,7,11nuj5ty67ojohm39cmzbt23,-79.833658,-78.016226
8,-77.612960,-76.237626,4702,8,11nuj5ty67ojohm39cmzbt23,-76.872459,-76.089213
9,-79.321030,-82.002370,52001,9,11nuj5ty67ojohm39cmzbt23,-84.779300,-84.708190
