In [1]:
from pathlib import Path
import polars as pl
from cogmood_analysis.load import load_task, boxcoxmask, nanboxcox, load_survey, proc_survey
import cogmood_analysis.survey_helpers as sh
from scipy.stats import boxcox
import seaborn as sns
from matplotlib import pyplot as plt
import numpy as np
from numpy.typing import ArrayLike, NDArray
from joblib import Parallel, delayed
from datetime import datetime
import json

import tempfile
from zipfile import ZipFile
from cogmood_analysis import log
import pickle
pl.Config(tbl_rows=300)

<polars.config.Config at 0x124e556a0>

In [2]:
data_dir = Path('/Users/nielsond/code/cogmood/data/20250611_pilot/good_task')
out_dir = Path('/Users/nielsond/code/cogmood/data/20250611_pilot/to_model')
out_dir.mkdir(exist_ok=True)
task_data_dir = data_dir
start_time = pl.Series(['2025-06-09 00:00:00']).str.to_datetime()[0]
tasks = ['flkr', 'cab', 'rdm', 'bart']
bad_dirs = ['.DS_Store']
from_scratch = True

#timing limits
limits = {
    'cab': (0.2, 2.5),
    'flkr': (0.1, 3),
    'rdm': (0.1, 3),
    'bart': (0, 100)
}

# chance threshold
# have to press pump on first trial of at least 24 of 36 balloons 
# for p = 0.032 < 0.05 responses are not at random
# (23 is 0.066)
bart_thresh = 24
# cab
# correct response on 57 out of 96 trials
# for p = 0.041 < 0.05 responses are not at random
# (56 is 0.0625)
cab_thresh = 57
# cab_rt_thresh05 = 4.8
# cab_rt_thresh10 = 9.6
#rdm
# correct response on 105 out of 186 nonrandom trials
# for p = 0.045 < 0.05 responses are not at random
# (104 is 0.062)
rdm_thresh = 105
# rdm_rt_thresh05 = 222 * 0.05
# rdm_rt_thresh10 = 222 * 0.1
# flkr
# correct response on 57 out of 96 trials
# for p = 0.041 < 0.05 responses are not at random
# (56 is 0.0625)
flkr_thresh = 57
# flkr_rt_thresh05 = 4.8
# flkr_rt_thresh10 = 9.6

corr_threshes = {
    'bart': bart_thresh,
    'cab': cab_thresh,
    'rdm': rdm_thresh,
    'flkr': flkr_thresh
}
# invert thresholds to get mimimum number passing
rt05_threshes = {
    'cab': 96 * 0.95,
    'rdm': 222 * 0.95,
    'flkr': 96 * 0.95
}

rt10_threshes = {
    'cab': 96 * 0.9,
    'rdm': 222 * 0.9,
    'flkr': 96 * 0.9
} 


# Load data

In [3]:

if not from_scratch:
    task_dat = {}
    for task_name in tasks:
        task_dat[task_name] = pl.read_parquet(data_dir/f'{task_name}.parquet')

    complete = []
    for task_name in tasks:
        tdf = task_dat[task_name]
        expected_n = 2
        if task_name== 'rdm':
            expected_n = 3
        tmp = (tdf.group_by('sub_id').n_unique()
            .select(['sub_id', 'zrn']).with_columns(
                (pl.col('zrn')==expected_n).alias(f'has_all_{task_name}')
                ).select(['sub_id', f'has_all_{task_name}']))
        complete.append(tmp)
    complete = pl.concat(complete, how='align').with_columns(
        has_all=(pl.col('has_all_flkr') & pl.col('has_all_cab') & pl.col('has_all_rdm') & pl.col('has_all_bart'))
    )

    complete_subs = complete.filter('has_all').get_column('sub_id').to_list()
else:
    complete_subs = []

In [4]:
all_subids = sorted([dd.parts[-1] for dd in (data_dir).glob('*') if (dd.parts[-1] not in bad_dirs) and ('parquet' not in dd.parts[-1])])
needed_subids = [sid for sid in all_subids if sid not in complete_subs]
# go ahead and just drop all the needed_subids
if not from_scratch:
    for task_name in tasks:
        tdf = task_dat[task_name]
        task_dat[task_name] = tdf.filter(~pl.col('sub_id').is_in(needed_subids))

In [5]:

task_jobs = {task_name:[] for task_name in tasks}
breakout=False
for subject in needed_subids:
    sub_task_dir = task_data_dir / subject
    for task_name in tasks:
        for runnum in [0,1,2]:
            if runnum == 2 and task_name != 'rdm':
                continue
            zipped_path = sub_task_dir / f'{task_name}_{runnum}.zip'
            if zipped_path.exists():
                file_date = datetime.fromtimestamp(zipped_path.stat().st_mtime)
                if file_date < start_time:
                    continue
                loddf = delayed(load_task)(zipped_path, task_name, subject, runnum)
                task_jobs[task_name].append(loddf)
            else:
                continue

addl_task_dat = {task_name:[] for task_name in tasks}
for task_name in tasks:
    addl_task_dat[task_name] = pl.concat(Parallel(n_jobs=8, verbose=10)(task_jobs[task_name]), how='diagonal_relaxed')

[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  2 out of 56 | elapsed:    2.4s
[Parallel(n_jobs=8)]: Done  9 out of 56 | elapsed:    2.5s
[Parallel(n_jobs=8)]: Done 16 out of 56 | elapsed:    2.5s
[Parallel(n_jobs=8)]: Done 25 out of 56 | elapsed:    2.5s
[Parallel(n_jobs=8)]: Batch computation too fast (0.18306021607825182s.) Setting batch_size=2.
[Parallel(n_jobs=8)]: Done 34 out of 56 | elapsed:    2.5s
[Parallel(n_jobs=8)]: Done 47 out of 56 | elapsed:    2.5s remaining:    0.5s
[Parallel(n_jobs=8)]: Done 53 out of 56 | elapsed:    2.5s remaining:    0.1s
[Parallel(n_jobs=8)]: Done 56 out of 56 | elapsed:    2.5s finished
[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Batch computation too fast (0.009424924850463867s.) Setting batch_size=2.
[Parallel(n_jobs=8)]: Done  2 out of 56 | elapsed:    0.0s
[Parallel(n_jobs=8)]: Done  9 out of 56 | elapsed:    0.0s
[Parallel(n_jobs=8)]: 

## process exclusion boxcox outlier exclusion criteria

In [6]:


for task_name in tasks:
    low_limit, high_limit = limits[task_name]
    tdf = addl_task_dat[task_name]
    if task_name == 'bart':
        # For bart, just drop trials with weird out of range RTs
        tdff = tdf.with_columns(
            pl.when((pl.col('rt') > 0) & (pl.col('rt') < high_limit)).then(True).otherwise(False).alias('rt_inbounds')
        ).with_columns(
            pl.col('rt_inbounds').alias('keep')
        )
    else:
        # For non-bart, mark nonresponse trials, don't inlcude them in the initial bc_mask
        # also don't include trials with weird out of bounds rts in the initial bc_mask
        # Keep non-response trials along with trials passing box-cox filtering
        tdff = tdf.with_columns(
            pl.when((pl.col('rt') > 0) & (pl.col('rt') < high_limit)).then(True).otherwise(False).alias('rt_inbounds'),
            pl.when(pl.col('rt').is_null()).then(True).otherwise(False).alias('nonresponse'),
        ).with_columns(
            pl.when((~pl.col('rt_inbounds')) | pl.col('nonresponse')).then(False).otherwise(True).alias('bc_mask')
        ).with_columns(
            pl.col('bc_mask').alias('bc_mask25')
        ).with_columns(
            bc_mask=pl.when(pl.col('bc_mask')).then(pl.col('rt')).map_batches(lambda x: boxcoxmask(x), return_dtype=pl.Boolean).over(pl.col('sub_id')),
            bc_mask25=pl.when(pl.col('bc_mask25')).then(pl.col('rt')).map_batches(lambda x: boxcoxmask(x, thresh=2.5), return_dtype=pl.Boolean).over(pl.col('sub_id')),
        ).with_columns(
            bc_rt=pl.when(pl.col('bc_mask')).then(pl.col('rt')).map_batches(lambda x: nanboxcox(x), return_dtype=pl.Float64).fill_nan(None).over(pl.col('sub_id')),
        ).with_columns(
            bc_z_rt=pl.when(pl.col('bc_mask')).then((pl.col('rt') - pl.col('rt').mean()) / pl.col('rt').std()).over(pl.col('sub_id'))
        ).with_columns(
            pl.when(pl.col('nonresponse') | pl.col('bc_mask')).then(True).otherwise(False).alias('keep')
        )
    addl_task_dat[task_name] = tdff
    

In [7]:
if not from_scratch:
    for task_name in tasks: 
        tdf = task_dat[task_name]
        atdf = addl_task_dat[task_name]
        ctdf = pl.concat([tdf, atdf], how='diagonal_relaxed')
        task_dat[task_name] = ctdf
else:
    task_dat = {task_name:[] for task_name in tasks}
    for task_name in tasks:
        task_dat[task_name] = addl_task_dat[task_name]

# identify subjects that have the expected number of blocks per task
complete = []
for task_name in tasks:
    tdf = task_dat[task_name]
    expected_n = 2
    if task_name== 'rdm':
        expected_n = 3
    tmp = (tdf.group_by('sub_id').n_unique()
        .select(['sub_id', 'zrn']).with_columns(
            (pl.col('zrn')==expected_n).alias(f'has_all_{task_name}')
            ).select(['sub_id', f'has_all_{task_name}']))
    complete.append(tmp)
complete = pl.concat(complete, how='align').with_columns(
    has_all=(pl.col('has_all_flkr') & pl.col('has_all_cab') & pl.col('has_all_rdm') & pl.col('has_all_bart'))
)

In [8]:
# double check that exclusion criteria are correct
for task_name in tasks:
    if task_name != 'bart':
        tdff = task_dat[task_name]
        # check that we're correctly picking trials to include
        assert (tdff.filter(pl.col('bc_mask') | pl.col('nonresponse'))).select('keep').to_series().all()


In [9]:
for task_name in tasks:
    old_task_dat = pl.read_parquet(data_dir/f'{task_name}.parquet')
    tdff = task_dat[task_name]
    if tdff.equals(old_task_dat):
        continue
    else:
        old_task_dat.write_parquet(data_dir/f'{task_name}_old.parquet')
        tdff.write_parquet(data_dir/f'{task_name}.parquet')

In [10]:
# process subject level exclusion
# subjects are excluded if:
# 1) they have below chance performance
# 2) > 5 or 10 % of trials are non-response or exluded by box-cox
# bc_mask is true if the trial is within RT response bounds, has a response, and passed box-cox outlier exclusion
tkeeps = []
for task_name in tasks:
    if task_name == 'cab':
        task_dat[task_name] = task_dat[task_name].with_columns(
            (pl.col('resp_acc') == True).alias('correct')
        )
    tdff = task_dat[task_name]
    
    if task_name == 'rdm':
        tkeep = tdff.group_by(pl.col('sub_id')).agg(
            pl.sum('bc_mask').alias('bc_mask'),
            pl.sum('bc_mask25').alias('bc_mask25'),
            pl.last('date').alias('date')
        )
        tdff = tdff.with_columns(coh_dif=(pl.col('left_coherence') - pl.col('right_coherence')).abs())
        tcorr = tdff.filter(pl.col('coh_dif') > 0).group_by(pl.col('sub_id')).sum().select(['sub_id','correct'])
        tkeep = tkeep.join(tcorr, how='left', on='sub_id')
    elif task_name == 'bart':
        tkeep = tdff.group_by(['sub_id', 'zrn', 'balloon_id', 'date']).first().with_columns(
            correct=pl.col('key_pressed') == pl.col('pump_button')
        ).group_by(pl.col('sub_id')).agg(
            pl.sum('correct').alias('correct'),
            pl.last('date').alias('date')
        ).with_columns(
            (pl.col('correct') >= corr_threshes[task_name]).alias(f'corr_ok_{task_name}')
        ).with_columns(
            pl.col(f'corr_ok_{task_name}').alias(f'good_{task_name}')
        ).with_columns(
            (~pl.col(f'good_{task_name}')).alias(f'bad_{task_name}')
        ).rename({'correct': f'n_correct_{task_name}', 'date':f'date_{task_name}'})
    else:
        tkeep =  tdff.group_by(pl.col('sub_id')).agg(
            pl.sum('bc_mask').alias('bc_mask'),
            pl.sum('bc_mask25').alias('bc_mask25'),
            pl.sum('correct').alias('correct'),
            pl.last('date').alias('date')
        )
        
    if task_name != 'bart':
        tkeep = tkeep.with_columns(
            (pl.col('bc_mask') >= rt05_threshes[task_name]).alias(f'resp05_ok_{task_name}'),
            (pl.col('bc_mask') >= rt10_threshes[task_name]).alias(f'resp10_ok_{task_name}'),
            (pl.col('bc_mask25') >= rt05_threshes[task_name]).alias(f'bc25_resp05_ok_{task_name}'),
            (pl.col('bc_mask25') >= rt10_threshes[task_name]).alias(f'bc25_resp10_ok_{task_name}'),
            (pl.col('correct') >= corr_threshes[task_name]).alias(f'corr_ok_{task_name}'),
        ).with_columns(
            (pl.col(f'resp05_ok_{task_name}') & pl.col(f'corr_ok_{task_name}')).alias(f'good05_{task_name}'),
            (pl.col(f'resp10_ok_{task_name}') & pl.col(f'corr_ok_{task_name}')).alias(f'good10_{task_name}'),
            (pl.col(f'bc25_resp05_ok_{task_name}') & pl.col(f'corr_ok_{task_name}')).alias(f'bc25_good05_{task_name}'),
            (pl.col(f'bc25_resp10_ok_{task_name}') & pl.col(f'corr_ok_{task_name}')).alias(f'bc25_good10_{task_name}'),
        ).with_columns(
            (~pl.col(f'good05_{task_name}')).alias(f'bad05_{task_name}'),
            (~pl.col(f'good10_{task_name}')).alias(f'bad10_{task_name}'),
            (~pl.col(f'bc25_good05_{task_name}')).alias(f'bc25_bad05_{task_name}'),
            (~pl.col(f'bc25_good10_{task_name}')).alias(f'bc25_bad10_{task_name}')
        ).rename({
            'bc_mask':f'n_good_resps_{task_name}', 'bc_mask25':f'bc25_n_good_resps_{task_name}', 'correct': f'n_correct_{task_name}', 'date':f'date_{task_name}'})

    tkeeps.append(tkeep)

tkeep = pl.concat(tkeeps, how="align").with_columns(
    good05=(pl.col('good05_flkr') & pl.col('good_bart') & pl.col('good05_rdm') & pl.col('good05_cab')),
    good10=(pl.col('good10_flkr') & pl.col('good_bart') & pl.col('good10_rdm') & pl.col('good10_cab')),
    bc25_good05=(pl.col('bc25_good05_flkr') & pl.col('good_bart') & pl.col('bc25_good05_rdm') & pl.col('bc25_good05_cab')),
    bc25_good10=(pl.col('bc25_good10_flkr') & pl.col('good_bart') & pl.col('bc25_good10_rdm') & pl.col('bc25_good10_cab')),
).with_columns(
    bad05=~pl.col('good05'),
    bad10=~pl.col('good10')
)

In [11]:
tkeep = tkeep.join(complete.select(['sub_id', 'has_all']), on='sub_id')

In [12]:
cab_misfits = [
    'd4hsof73ftqmz1sbm3vc82f6',
    '60pixcark57tgonq4abwctvs',
    'sz1p1qr5v5saov60ia90oqlh',
]
rdm_misfits = [
    'h3q7g3g6za07rl9qnhd87hoq',
    'ycbg09io1hcmzyiosg50hgl8',
    '1ts935dhccck7dgtssvxv9nd',
    'reqevxyh9eqa3jyc8wvucdi2'
]

In [13]:
tkeep.filter(pl.col('sub_id').is_in(cab_misfits)).select(['sub_id', 'n_good_resps_cab', 'bc25_n_good_resps_cab', 'resp10_ok_cab', 'bc25_resp10_ok_cab']), tkeep.filter(pl.col('sub_id').is_in(rdm_misfits)).select(['sub_id', 'n_good_resps_rdm', 'bc25_n_good_resps_rdm', 'resp10_ok_rdm', 'bc25_resp10_ok_rdm'])

(shape: (3, 5)
 ┌─────────────────────┬──────────────────┬────────────────────┬───────────────┬────────────────────┐
 │ sub_id              ┆ n_good_resps_cab ┆ bc25_n_good_resps_ ┆ resp10_ok_cab ┆ bc25_resp10_ok_cab │
 │ ---                 ┆ ---              ┆ cab                ┆ ---           ┆ ---                │
 │ str                 ┆ u32              ┆ ---                ┆ bool          ┆ bool               │
 │                     ┆                  ┆ u32                ┆               ┆                    │
 ╞═════════════════════╪══════════════════╪════════════════════╪═══════════════╪════════════════════╡
 │ 60pixcark57tgonq4ab ┆ 95               ┆ 94                 ┆ true          ┆ true               │
 │ wctvs               ┆                  ┆                    ┆               ┆                    │
 │ d4hsof73ftqmz1sbm3v ┆ 93               ┆ 92                 ┆ true          ┆ true               │
 │ c82f6               ┆                  ┆                    ┆   

In [12]:
# Write out a data quality CSV for participant filtering
tkeep.write_csv(out_dir / 'data_quality.csv')
# write out data for all subjects
for task_name in tasks:
    tdf = task_dat[task_name]
    subjects = tdf.select('sub_id').unique().to_numpy().flatten()

    for sub_id in subjects:
        sdf = tdf.filter(pl.col('sub_id') == sub_id)
        sdf.write_csv(out_dir / f'{task_name}-{sub_id}.csv')
    

In [15]:
tkeep.sum()

sub_id,n_good_resps_flkr,bc25_n_good_resps_flkr,n_correct_flkr,date_flkr,resp05_ok_flkr,resp10_ok_flkr,bc25_resp05_ok_flkr,bc25_resp10_ok_flkr,corr_ok_flkr,good05_flkr,good10_flkr,bc25_good05_flkr,bc25_good10_flkr,bad05_flkr,bad10_flkr,bc25_bad05_flkr,bc25_bad10_flkr,n_good_resps_cab,bc25_n_good_resps_cab,n_correct_cab,date_cab,resp05_ok_cab,resp10_ok_cab,bc25_resp05_ok_cab,bc25_resp10_ok_cab,corr_ok_cab,good05_cab,good10_cab,bc25_good05_cab,bc25_good10_cab,bad05_cab,bad10_cab,bc25_bad05_cab,bc25_bad10_cab,n_good_resps_rdm,bc25_n_good_resps_rdm,date_rdm,n_correct_rdm,resp05_ok_rdm,resp10_ok_rdm,bc25_resp05_ok_rdm,bc25_resp10_ok_rdm,corr_ok_rdm,good05_rdm,good10_rdm,bc25_good05_rdm,bc25_good10_rdm,bad05_rdm,bad10_rdm,bc25_bad05_rdm,bc25_bad10_rdm,n_correct_bart,date_bart,corr_ok_bart,good_bart,bad_bart,good05,good10,bc25_good05,bc25_good10,bad05,bad10,has_all
str,u32,u32,u32,datetime[μs],u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,datetime[μs],u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,datetime[μs],u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,datetime[μs],u32,u32,u32,u32,u32,u32,u32,u32,u32,u32
,2652,2624,2494,,26,27,24,27,26,25,25,23,25,3,3,5,3,2602,2569,2077,,24,25,21,25,28,24,25,21,25,4,3,7,3,6105,6034,,4066,25,28,23,28,25,22,25,21,25,6,3,7,3,992,,28,28,0,19,22,16,22,9,6,28


In [16]:
tkeep.sum()

sub_id,n_good_resps_flkr,bc25_n_good_resps_flkr,n_correct_flkr,date_flkr,resp05_ok_flkr,resp10_ok_flkr,bc25_resp05_ok_flkr,bc25_resp10_ok_flkr,corr_ok_flkr,good05_flkr,good10_flkr,bc25_good05_flkr,bc25_good10_flkr,bad05_flkr,bad10_flkr,bc25_bad05_flkr,bc25_bad10_flkr,n_good_resps_cab,bc25_n_good_resps_cab,n_correct_cab,date_cab,resp05_ok_cab,resp10_ok_cab,bc25_resp05_ok_cab,bc25_resp10_ok_cab,corr_ok_cab,good05_cab,good10_cab,bc25_good05_cab,bc25_good10_cab,bad05_cab,bad10_cab,bc25_bad05_cab,bc25_bad10_cab,n_good_resps_rdm,bc25_n_good_resps_rdm,date_rdm,n_correct_rdm,resp05_ok_rdm,resp10_ok_rdm,bc25_resp05_ok_rdm,bc25_resp10_ok_rdm,corr_ok_rdm,good05_rdm,good10_rdm,bc25_good05_rdm,bc25_good10_rdm,bad05_rdm,bad10_rdm,bc25_bad05_rdm,bc25_bad10_rdm,n_correct_bart,date_bart,corr_ok_bart,good_bart,bad_bart,good05,good10,bc25_good05,bc25_good10,bad05,bad10,has_all
str,u32,u32,u32,datetime[μs],u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,datetime[μs],u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,datetime[μs],u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,datetime[μs],u32,u32,u32,u32,u32,u32,u32,u32,u32,u32
,2652,2624,2494,,26,27,24,27,26,25,25,23,25,3,3,5,3,2602,2569,2077,,24,25,21,25,28,24,25,21,25,4,3,7,3,6105,6034,,4066,25,28,23,28,25,22,25,21,25,6,3,7,3,992,,28,28,0,19,22,16,22,9,6,28


In [17]:
tkeep.shape

(28, 64)

# Extract bart pickles

In [13]:
missing_bart_pickles = []
for subject in needed_subids:
    sub_task_dir = task_data_dir / subject
    for task_name in ['bart']:
        for runnum in [0,1,2]:
            if runnum == 2 and task_name != 'rdm':
                continue
            zipped_path = sub_task_dir / f'{task_name}_{runnum}.zip'
            try:
                with tempfile.TemporaryDirectory() as tmpdir:
                    bart_pickle = Path(ZipFile(zipped_path).extract(f"obart_pickles/bags_session_{runnum}.p", path=tmpdir))
                    bart_pd = pickle.loads(bart_pickle.read_bytes())
            except KeyError:
                missing_bart_pickles.append(subject)
                continue
            bart_pickle_dir = out_dir / f"bart-{subject}_pickles"
            bart_pickle_dir.mkdir(exist_ok=True)
            bart_pickle_file = bart_pickle_dir / f"bags_session_{runnum}.pickle"
            bart_pickle_file.write_bytes(pickle.dumps(bart_pd))

