In [1]:
from pathlib import Path
import polars as pl
from cogmood_analysis.load import load_task, boxcoxmask, nanboxcox, load_survey, proc_survey
import cogmood_analysis.survey_helpers as sh
from scipy.stats import boxcox
import seaborn as sns
from matplotlib import pyplot as plt
import numpy as np
from numpy.typing import ArrayLike, NDArray
from joblib import Parallel, delayed
from datetime import datetime
import json
pl.Config(tbl_rows=300)

<polars.config.Config at 0x129ad56a0>

In [2]:
!ls ~/code/cogmood/data/20250611_pilot/good_task

[34m11nuj5ty67ojohm39cmzbt23[m[m [34mhvann18ezp9i2kq8bvqivehs[m[m [34mreqevxyh9eqa3jyc8wvucdi2[m[m
[34m1ts935dhccck7dgtssvxv9nd[m[m [34ml8eyqget2wsecwew6bwabn1h[m[m [34ms7szmm610ygsyon54c67mlvj[m[m
[34m2upuqdbw3wdpk3q43x89zysp[m[m [34mlpbs1m834j0r6sbezpcotnei[m[m [34msz1p1qr5v5saov60ia90oqlh[m[m
[34m48juqsgxp4m2o7797zvjxln9[m[m [34mmglomvxjfi6gya3jmrt7o09w[m[m [34mu5sdtc6ljckmewbo4acifldp[m[m
[34m60pixcark57tgonq4abwctvs[m[m [34mmjff7puqxr95bh6d945ru7z2[m[m [34mv0pbbk0rbplhgr4tsqqrlvzm[m[m
[34m81987885tpc29718g2d8evdm[m[m [34mol8u2qd7k4zi7cd0p7idogr4[m[m [34mww2y0qu0joyyxy7lsxirxvm4[m[m
bart.parquet             [34mp5t6r1pmhqfffx93vv5h8vxc[m[m [34mwxabwlnirq69pf95h6573i6y[m[m
cab.parquet              [34mq86m1zrqk9q16o5e3dfvb8yx[m[m [34my6d1crfpg0rh1wpohfgi959w[m[m
[34md4hsof73ftqmz1sbm3vc82f6[m[m [34mqjvaxvijfmumaq0czvg4m55x[m[m [34mycbg09io1hcmzyiosg50hgl8[m[m
flkr.parquet             [34mqx2vl559ytpgxjwh26fr

In [3]:
data_dir = Path('/Users/nielsond/code/cogmood/data/20250611_pilot/good_task')
out_dir = Path('/Users/nielsond/code/cogmood/data/20250611_pilot/to_model')
out_dir.mkdir(exist_ok=True)
task_data_dir = data_dir
start_time = pl.Series(['2025-06-09 00:00:00']).str.to_datetime()[0]
tasks = ['flkr', 'cab', 'rdm', 'bart']
bad_dirs = ['.DS_Store']
from_scratch = True

#timing limits
default_high_limit = 2
default_low_limit = 0.350
limits = {
    'cab': (0.2, 2.5),
    'flkr': (default_low_limit, default_high_limit),
    'rdm': (0.2, 3),
    'bart': (0, 3)
}

# chance threshold
# have to press pump on first trial of at least 24 of 36 balloons 
# for p = 0.032 < 0.05 responses are not at random
# (23 is 0.066)
bart_thresh = 24
# cab
# correct response on 57 out of 96 trials
# for p = 0.041 < 0.05 responses are not at random
# (56 is 0.0625)
cab_thresh = 57
cab_rt_thresh05 = 4.8
cab_rt_thresh10 = 9.6
#rdm
# correct response on 105 out of 186 nonrandom trials
# for p = 0.045 < 0.05 responses are not at random
# (104 is 0.062)
rdm_thresh = 105
rdm_rt_thresh05 = 222 * 0.05
rdm_rt_thresh10 = 222 * 0.1
# flkr
# correct response on 57 out of 96 trials
# for p = 0.041 < 0.05 responses are not at random
# (56 is 0.0625)
flkr_thresh = 57
flkr_rt_thresh05 = 4.8
flkr_rt_thresh10 = 9.6

corr_threshes = {
    'bart': 24,
    'cab': 57,
    'rdm': 105,
    'flkr': 57
}
# invert thresholds to get mimimum number passing
rt_threshes = {
    'cab': 96 * 0.95,
    'rdm': 222 * 0.95,
    'flkr': 96 * 0.95
}    


# Load data

In [4]:

if not from_scratch:
    task_dat = {}
    for task_name in tasks:
        task_dat[task_name] = pl.read_parquet(data_dir/f'{task_name}.parquet')

    complete = []
    for task_name in tasks:
        tdf = task_dat[task_name]
        expected_n = 2
        if task_name== 'rdm':
            expected_n = 3
        tmp = (tdf.group_by('sub_id').n_unique()
            .select(['sub_id', 'zrn']).with_columns(
                (pl.col('zrn')==expected_n).alias(f'has_all_{task_name}')
                ).select(['sub_id', f'has_all_{task_name}']))
        complete.append(tmp)
    complete = pl.concat(complete, how='align').with_columns(
        has_all=(pl.col('has_all_flkr') & pl.col('has_all_cab') & pl.col('has_all_rdm') & pl.col('has_all_bart'))
    )

    complete_subs = complete.filter('has_all').get_column('sub_id').to_list()
else:
    complete_subs = []

In [5]:
all_subids = sorted([dd.parts[-1] for dd in (data_dir).glob('*') if (dd.parts[-1] not in bad_dirs)])
needed_subids = [sid for sid in all_subids if sid not in complete_subs]
# go ahead and just drop all the needed_subids
if not from_scratch:
    for task_name in tasks:
        tdf = task_dat[task_name]
        task_dat[task_name] = tdf.filter(~pl.col('sub_id').is_in(needed_subids))

In [6]:

task_jobs = {task_name:[] for task_name in tasks}
breakout=False
for subject in needed_subids:
    sub_task_dir = task_data_dir / subject
    for task_name in tasks:
        for runnum in [0,1,2]:
            if runnum == 2 and task_name != 'rdm':
                continue
            zipped_path = sub_task_dir / f'{task_name}_{runnum}.zip'
            if zipped_path.exists():
                file_date = datetime.fromtimestamp(zipped_path.stat().st_mtime)
                if file_date < start_time:
                    continue
                loddf = delayed(load_task)(zipped_path, task_name, subject, runnum)
                task_jobs[task_name].append(loddf)
            else:
                print(zipped_path)
                continue

addl_task_dat = {task_name:[] for task_name in tasks}
for task_name in tasks:
    addl_task_dat[task_name] = pl.concat(Parallel(n_jobs=8, verbose=10)(task_jobs[task_name]), how='diagonal_relaxed')

/Users/nielsond/code/cogmood/data/20250611_pilot/good_task/bart.parquet/flkr_0.zip
/Users/nielsond/code/cogmood/data/20250611_pilot/good_task/bart.parquet/flkr_1.zip
/Users/nielsond/code/cogmood/data/20250611_pilot/good_task/bart.parquet/cab_0.zip
/Users/nielsond/code/cogmood/data/20250611_pilot/good_task/bart.parquet/cab_1.zip
/Users/nielsond/code/cogmood/data/20250611_pilot/good_task/bart.parquet/rdm_0.zip
/Users/nielsond/code/cogmood/data/20250611_pilot/good_task/bart.parquet/rdm_1.zip
/Users/nielsond/code/cogmood/data/20250611_pilot/good_task/bart.parquet/rdm_2.zip
/Users/nielsond/code/cogmood/data/20250611_pilot/good_task/bart.parquet/bart_0.zip
/Users/nielsond/code/cogmood/data/20250611_pilot/good_task/bart.parquet/bart_1.zip
/Users/nielsond/code/cogmood/data/20250611_pilot/good_task/cab.parquet/flkr_0.zip
/Users/nielsond/code/cogmood/data/20250611_pilot/good_task/cab.parquet/flkr_1.zip
/Users/nielsond/code/cogmood/data/20250611_pilot/good_task/cab.parquet/cab_0.zip
/Users/nielso

[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  2 out of 56 | elapsed:    3.5s
[Parallel(n_jobs=8)]: Done  9 out of 56 | elapsed:    3.6s
[Parallel(n_jobs=8)]: Done 16 out of 56 | elapsed:    3.6s
[Parallel(n_jobs=8)]: Done 25 out of 56 | elapsed:    3.6s
[Parallel(n_jobs=8)]: Batch computation too fast (0.17433088683246437s.) Setting batch_size=2.
[Parallel(n_jobs=8)]: Done 34 out of 56 | elapsed:    3.7s
[Parallel(n_jobs=8)]: Done 47 out of 56 | elapsed:    3.7s remaining:    0.7s
[Parallel(n_jobs=8)]: Done 53 out of 56 | elapsed:    3.7s remaining:    0.2s
[Parallel(n_jobs=8)]: Done 56 out of 56 | elapsed:    3.7s finished
[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Batch computation too fast (0.03125619888305664s.) Setting batch_size=2.
[Parallel(n_jobs=8)]: Done  2 out of 56 | elapsed:    0.0s
[Parallel(n_jobs=8)]: Done  9 out of 56 | elapsed:    0.0s
[Parallel(n_jobs=8)]: D

## process exclusion boxcox outlier exclusion criteria

In [7]:

for task_name in tasks:
    low_limit, high_limit = limits[task_name]
    tdf = addl_task_dat[task_name]
    if task_name == 'bart':
        tdff = tdf.with_columns(
            pl.when((pl.col('rt') > low_limit) & (pl.col('rt') < high_limit)).then(True).otherwise(False).alias('og_mask')
        )
    else:
        tdff = tdf.with_columns(
            pl.when((pl.col('rt') > low_limit) & (pl.col('rt') < high_limit)).then(True).otherwise(False).alias('og_mask'),
            pl.when((pl.col('rt') > low_limit) & (pl.col('rt') < high_limit)).then(True).otherwise(False).alias('bc_mask')
        ).with_columns(
            bc_mask=pl.when(pl.col('bc_mask')).then(pl.col('rt')).map_batches(lambda x: boxcoxmask(x), return_dtype=pl.Boolean).over(pl.col('sub_id')),
        ).with_columns(
            bc_rt=pl.when(pl.col('bc_mask')).then(pl.col('rt')).map_batches(lambda x: nanboxcox(x), return_dtype=pl.Float64).fill_nan(None).over(pl.col('sub_id')),
        ).with_columns(
            bc_z_rt=pl.when(pl.col('bc_mask')).then((pl.col('rt') - pl.col('rt').mean()) / pl.col('rt').std()).over(pl.col('sub_id'))
        )
    addl_task_dat[task_name] = tdff
    

In [8]:
if not from_scratch:
    for task_name in tasks: 
        tdf = task_dat[task_name]
        atdf = addl_task_dat[task_name]
        ctdf = pl.concat([tdf, atdf], how='diagonal_relaxed')
        task_dat[task_name] = ctdf
else:
    task_dat = {task_name:[] for task_name in tasks}
    for task_name in tasks:
        task_dat[task_name] = addl_task_dat[task_name]

complete = []
for task_name in tasks:
    tdf = task_dat[task_name]
    expected_n = 2
    if task_name== 'rdm':
        expected_n = 3
    tmp = (tdf.group_by('sub_id').n_unique()
        .select(['sub_id', 'zrn']).with_columns(
            (pl.col('zrn')==expected_n).alias(f'has_all_{task_name}')
            ).select(['sub_id', f'has_all_{task_name}']))
    complete.append(tmp)
complete = pl.concat(complete, how='align').with_columns(
    has_all=(pl.col('has_all_flkr') & pl.col('has_all_cab') & pl.col('has_all_rdm') & pl.col('has_all_bart'))
)

In [9]:
# double check that exclusion criteria are correct
for task_name in tasks:
    if task_name != 'bart':
        tdff = task_dat[task_name]
        # check that time limits are respected
        assert (tdff.filter(~pl.col('og_mask') & pl.col('bc_mask'))).is_empty()
        assert (tdff.filter(pl.col('bc_mask')).min().select('rt') > limits[task_name][0])[0, 'rt']
        assert (tdff.filter(pl.col('bc_mask')).max().select('rt') < limits[task_name][1])[0, 'rt']

In [10]:
data_dir

PosixPath('/Users/nielsond/code/cogmood/data/20250611_pilot/good_task')

In [11]:
for task_name in tasks:
    tdff = task_dat[task_name]
    tdff.write_parquet(data_dir/f'{task_name}.parquet')

In [53]:
tdf = task_dat['flkr']

In [12]:
tkeeps = []
for task_name in tasks:
    if task_name == 'cab':
        task_dat[task_name] = task_dat[task_name].with_columns(
            (pl.col('resp_acc') == True).alias('correct')
        )
    tdff = task_dat[task_name]
    
    if task_name == 'rdm':
        tkeep = tdff.group_by(pl.col('sub_id')).agg(
            pl.sum('bc_mask').alias('bc_mask'),
            pl.sum('og_mask').alias('og_mask'),
            pl.last('date').alias('date')
        )
        tdff = tdff.with_columns(coh_dif=(pl.col('left_coherence') - pl.col('right_coherence')).abs())
        tcorr = tdff.filter(pl.col('coh_dif') > 0).group_by(pl.col('sub_id')).sum().select(['sub_id','correct'])
        tkeep = tkeep.join(tcorr, how='left', on='sub_id')
    elif task_name == 'bart':
        tkeep = tdff.group_by(['sub_id', 'zrn', 'balloon_id', 'date']).first().with_columns(
            correct=pl.col('key_pressed') == pl.col('pump_button')
        ).group_by(pl.col('sub_id')).agg(
            pl.sum('correct').alias('correct'),
            pl.last('date').alias('date')
        ).with_columns(
            (pl.col('correct') >= corr_threshes[task_name]).alias(f'corr_ok_{task_name}')
        ).with_columns(
            pl.col(f'corr_ok_{task_name}').alias(f'good_{task_name}')
        ).with_columns(
            (~pl.col(f'good_{task_name}')).alias(f'bad_{task_name}')
        ).rename({'correct': f'n_correct_{task_name}', 'date':f'date_{task_name}'})
    else:
        tkeep =  tdff.group_by(pl.col('sub_id')).agg(
            pl.sum('bc_mask').alias('bc_mask'),
            pl.sum('og_mask').alias('og_mask'),
            pl.sum('correct').alias('correct'),
            pl.last('date').alias('date')
        )
        
    if task_name != 'bart':
        tkeep = tkeep.with_columns(
            (pl.col('og_mask') >= rt_threshes[task_name]).alias(f'ogrt_ok_{task_name}'),
            (pl.col('bc_mask') >= rt_threshes[task_name]).alias(f'rt_ok_{task_name}'),
            (pl.col('correct') >= corr_threshes[task_name]).alias(f'corr_ok_{task_name}'),
        ).with_columns(
            (pl.col(f'rt_ok_{task_name}') & pl.col(f'corr_ok_{task_name}')).alias(f'good_{task_name}'),
            (pl.col(f'ogrt_ok_{task_name}') & pl.col(f'corr_ok_{task_name}')).alias(f'oggood_{task_name}')
        ).with_columns(
            (~pl.col(f'good_{task_name}')).alias(f'bad_{task_name}')
        ).rename({'og_mask':f'n_good_ogrts_{task_name}',
                   'bc_mask':f'n_good_rts_{task_name}', 'correct': f'n_correct_{task_name}', 'date':f'date_{task_name}'})

    tkeeps.append(tkeep)

tkeep = pl.concat(tkeeps, how="align").with_columns(
    good=(pl.col('good_flkr') & pl.col('good_bart') & pl.col('good_rdm') & pl.col('good_cab')),
    oggood=(pl.col('oggood_flkr') & pl.col('good_bart') & pl.col('oggood_rdm') & pl.col('oggood_cab')),
).with_columns(
    bad=~pl.col('good')
)

In [13]:
tkeep = tkeep.join(complete.select(['sub_id', 'has_all']), on='sub_id')

In [14]:
keep_sub_ids = tkeep.filter(pl.col('good_flkr')).select('sub_id').to_numpy().flatten()

In [16]:
for task_name in tasks:
    if task_name == 'bart':
        continue
    tdf = task_dat[task_name]
    keep_df = tdf.filter(pl.col('sub_id').is_in(keep_sub_ids)).filter(pl.col('bc_mask'))
    for sub_id in keep_sub_ids:
        sdf = keep_df.filter(pl.col('sub_id') == sub_id)
        sdf.write_csv(out_dir / f'{task_name}-{sub_id}.csv')