In [1]:
from pathlib import Path
import polars as pl
from cogmood_analysis.load import load_task, boxcoxmask, nanboxcox
from scipy.stats import boxcox
import seaborn as sns
from matplotlib import pyplot as plt
import numpy as np
from numpy.typing import ArrayLike, NDArray
from joblib import Parallel, delayed
from datetime import datetime

In [2]:
# Input locations.
data_dir = Path('../data/')
survey_dir = data_dir / 'survey'
task_data_dir = data_dir / 'task/upload'
task_db_dir = data_dir / 'task/db'
# Task archives whose modification time is earlier than this are skipped when loading.
start_time = datetime(2025, 7, 14)
tasks = ['flkr', 'cab', 'rdm', 'bart']
# Subject folder names under task_data_dir to ignore.
bad_dirs = []


# RT limits per task as (low, high); a trial's RT must fall strictly inside this window.
default_high_limit = 2
default_low_limit = 0.350
limits = {
    'cab': (default_low_limit, 2.5),
    'flkr': (default_low_limit, default_high_limit),
    'rdm': (0.2, 3),
    'bart': (0, 3)
}

# chance threshold
# have to press pump on first trial of at least 24 of 36 balloons
# for p = 0.032 < 0.05 responses are not at random
# (23 is 0.066)
bart_thresh = 24
# cab
# correct response on 57 out of 96 trials
# for p = 0.041 < 0.05 responses are not at random
# (56 is 0.0625)
cab_thresh = 57
cab_rt_thresh05 = 4.8
cab_rt_thresh10 = 9.6
# rdm
# correct response on 105 out of 186 nonrandom trials
# for p = 0.045 < 0.05 responses are not at random
# (104 is 0.062)
rdm_thresh = 105
rdm_rt_thresh05 = 222 * 0.05
rdm_rt_thresh10 = 222 * 0.1
# flkr
# correct response on 57 out of 96 trials
# for p = 0.041 < 0.05 responses are not at random
# (56 is 0.0625)
flkr_thresh = 57
flkr_rt_thresh05 = 4.8
flkr_rt_thresh10 = 9.6

# Minimum number of correct responses per task, built from the named constants above
# so the per-task values cannot drift from their documented derivations.
corr_threshes = {
    'bart': bart_thresh,
    'cab': cab_thresh,
    'rdm': rdm_thresh,
    'flkr': flkr_thresh,
}
# Invert the RT thresholds to get the minimum number of trials with a usable RT
# (at least 95% of trials must pass, i.e. at most 5% may be excluded).
rt_threshes = {
    'cab': 96 * 0.95,
    'rdm': 222 * 0.95,
    'flkr': 96 * 0.95
}


In [3]:
# Subject IDs are the names of the per-subject folders under the task upload directory
# (plain files such as stray OS metadata are ignored), minus any listed in bad_dirs.
all_subids = sorted(
    entry.name
    for entry in task_data_dir.iterdir()
    if entry.is_dir() and entry.name not in bad_dirs
)

In [4]:
# Queue one delayed `load_task` call per run archive (`<task>_<run>.zip` in each subject's
# folder). rdm has three runs (0-2); every other task has two (0-1). Archives whose
# modification time is before `start_time` are skipped.
task_jobs = {task_name: [] for task_name in tasks}
for subject in all_subids:
    sub_task_dir = task_data_dir / subject
    for task_name in tasks:
        n_runs = 3 if task_name == 'rdm' else 2
        for runnum in range(n_runs):
            zipped_path = sub_task_dir / f'{task_name}_{runnum}.zip'
            if not zipped_path.exists():
                continue
            file_date = datetime.fromtimestamp(zipped_path.stat().st_mtime)
            if file_date < start_time:
                continue
            task_jobs[task_name].append(delayed(load_task)(zipped_path, task_name, subject, runnum))

# Load each task's runs in parallel and stack them; 'diagonal_relaxed' fills columns
# missing from some runs with nulls and relaxes dtypes to a common supertype.
task_dat = {}
for task_name in tasks:
    if not task_jobs[task_name]:
        # pl.concat([]) would fail with an opaque message; say which task has no data.
        raise ValueError(
            f'no {task_name} run archives modified on or after {start_time} found under {task_data_dir}'
        )
    task_dat[task_name] = pl.concat(
        Parallel(n_jobs=8, verbose=10)(task_jobs[task_name]),
        how='diagonal_relaxed',
    )

[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done    2 out of 5469 | elapsed:    1.4s
[Parallel(n_jobs=8)]: Done    9 out of 5469 | elapsed:    1.5s
[Parallel(n_jobs=8)]: Done   16 out of 5469 | elapsed:    1.5s
[Parallel(n_jobs=8)]: Done   25 out of 5469 | elapsed:    1.5s
[Parallel(n_jobs=8)]: Batch computation too fast (0.1719489157851148s.) Setting batch_size=2.
[Parallel(n_jobs=8)]: Done   34 out of 5469 | elapsed:    1.5s
[Parallel(n_jobs=8)]: Done   45 out of 5469 | elapsed:    1.6s
[Parallel(n_jobs=8)]: Batch computation too fast (0.029504776000976562s.) Setting batch_size=4.
[Parallel(n_jobs=8)]: Done   64 out of 5469 | elapsed:    1.6s
[Parallel(n_jobs=8)]: Batch computation too fast (0.06101489067077637s.) Setting batch_size=8.
[Parallel(n_jobs=8)]: Done  100 out of 5469 | elapsed:    1.6s
[Parallel(n_jobs=8)]: Batch computation too fast (0.12724804878234863s.) Setting batch_size=16.
[Parallel(n_jobs=8)]: Done  160 out of 5

In [5]:

# Add per-trial RT quality columns to every task frame:
#   og_mask : RT strictly inside the task's (low, high) window from `limits`
#             (null RTs count as outside the window).
#   bc_mask : og_mask further filtered per subject by `boxcoxmask` (non-bart tasks only).
#   bc_rt   : per-subject `nanboxcox` transform of the RTs kept by bc_mask, NaN -> null.
#   bc_z_rt : per-subject z-score of raw RT on bc_mask trials, with the mean/std taken
#             over those kept trials only.
for task_name in tasks:
    low_limit, high_limit = limits[task_name]
    in_window = ((pl.col('rt') > low_limit) & (pl.col('rt') < high_limit)).fill_null(False)
    tdff = task_dat[task_name].with_columns(in_window.alias('og_mask'))
    if task_name != 'bart':
        # RT where bc_mask is set, null elsewhere, so per-subject aggregates skip excluded
        # trials. The expression is re-evaluated in each with_columns below, so it always
        # reads the bc_mask produced by the preceding step.
        kept_rt = pl.when(pl.col('bc_mask')).then(pl.col('rt'))
        tdff = tdff.with_columns(
            in_window.alias('bc_mask')
        ).with_columns(
            bc_mask=kept_rt.map_batches(boxcoxmask, return_dtype=pl.Boolean).over('sub_id'),
        ).with_columns(
            bc_rt=kept_rt.map_batches(nanboxcox, return_dtype=pl.Float64).fill_nan(None).over('sub_id'),
        ).with_columns(
            # Previously the mean/std were taken over all of a subject's trials, so RTs that
            # bc_mask rejects still skewed the z-scores; restrict them to kept trials.
            # NOTE(review): the `bc_` prefix may mean this should z-score `bc_rt` instead — confirm.
            bc_z_rt=pl.when(pl.col('bc_mask')).then(
                (kept_rt - kept_rt.mean()) / kept_rt.std()
            ).over('sub_id'),
        )
    task_dat[task_name] = tdff
    

In [6]:
# Sanity-check the RT masks (bart has no bc_mask, so it is skipped): bc_mask must be a
# subset of og_mask, and every kept RT must lie strictly inside the task's (low, high) window.
for task_name in tasks:
    if task_name == 'bart':
        continue
    tdff = task_dat[task_name]
    rt_low, rt_high = limits[task_name]
    assert tdff.filter(~pl.col('og_mask') & pl.col('bc_mask')).is_empty(), (
        f'{task_name}: bc_mask keeps trials that og_mask rejects'
    )
    kept_rt = tdff.filter(pl.col('bc_mask'))['rt']
    assert not kept_rt.is_empty(), f'{task_name}: no trials survive bc_mask'
    assert kept_rt.min() > rt_low, f'{task_name}: kept RT {kept_rt.min()} is not > {rt_low}'
    assert kept_rt.max() < rt_high, f'{task_name}: kept RT {kept_rt.max()} is not < {rt_high}'

In [60]:
# Per-subject QC summary. For each task, count usable RTs (og_mask / bc_mask) and correct
# responses, compare them against `rt_threshes` / `corr_threshes`, and flag the subject:
#   ogrt_ok_<task> / rt_ok_<task> : enough usable RTs (og_mask / bc_mask)
#   corr_ok_<task>                : enough correct responses
#   good_<task> / oggood_<task>   : RT check (bc / og) and accuracy check both pass
# bart has no RT checks, so good_bart is its accuracy check alone.
tkeeps = []
for task_name in tasks:
    if task_name == 'cab':
        # cab records accuracy as `resp_acc`; expose it as `correct` like the other tasks.
        # This writes back into task_dat, so later cells also see the new column.
        task_dat[task_name] = task_dat[task_name].with_columns(
            (pl.col('resp_acc') == True).alias('correct')
        )
    tdff = task_dat[task_name]

    if task_name == 'bart':
        # One row per balloon (first row of each sub_id/zrn/balloon_id group); a balloon is
        # correct when that row's key_pressed equals pump_button.
        tkeep = (
            tdff.group_by(['sub_id', 'zrn', 'balloon_id']).first()
            .with_columns(correct=pl.col('key_pressed') == pl.col('pump_button'))
            .group_by('sub_id').agg(pl.col('correct').sum())
            .with_columns((pl.col('correct') >= corr_threshes[task_name]).alias(f'corr_ok_{task_name}'))
            .with_columns(pl.col(f'corr_ok_{task_name}').alias(f'good_{task_name}'))
            .with_columns((~pl.col(f'good_{task_name}')).alias(f'bad_{task_name}'))
            .rename({'correct': f'n_correct_{task_name}'})
        )
        tkeeps.append(tkeep)
        continue

    if task_name == 'rdm':
        # Usable RTs are counted over all trials, but accuracy only over trials whose
        # left/right coherences differ (the "nonrandom" trials in the threshold comment).
        tkeep = tdff.group_by('sub_id').agg(pl.col('bc_mask').sum(), pl.col('og_mask').sum())
        tcorr = (
            tdff.filter((pl.col('left_coherence') - pl.col('right_coherence')).abs() > 0)
            .group_by('sub_id')
            .agg(pl.col('correct').sum())
        )
        tkeep = tkeep.join(tcorr, how='left', on='sub_id')
    else:
        tkeep = tdff.group_by('sub_id').agg(
            pl.col('bc_mask').sum(), pl.col('og_mask').sum(), pl.col('correct').sum()
        )

    tkeep = tkeep.with_columns(
        (pl.col('og_mask') >= rt_threshes[task_name]).alias(f'ogrt_ok_{task_name}'),
        (pl.col('bc_mask') >= rt_threshes[task_name]).alias(f'rt_ok_{task_name}'),
        (pl.col('correct') >= corr_threshes[task_name]).alias(f'corr_ok_{task_name}'),
    ).with_columns(
        (pl.col(f'rt_ok_{task_name}') & pl.col(f'corr_ok_{task_name}')).alias(f'good_{task_name}'),
        (pl.col(f'ogrt_ok_{task_name}') & pl.col(f'corr_ok_{task_name}')).alias(f'oggood_{task_name}')
    ).with_columns(
        (~pl.col(f'good_{task_name}')).alias(f'bad_{task_name}')
    ).rename({'og_mask': f'n_good_ogrts_{task_name}', 'bc_mask': f'n_good_rts_{task_name}', 'correct': f'n_correct_{task_name}'})
    tkeeps.append(tkeep)

# "align" outer-joins the per-task frames on sub_id, so a subject missing a whole task has
# nulls in that task's flags, and `&` would then leave good/oggood null (silently dropped
# from means). Count such subjects as not good instead.
tkeep = pl.concat(tkeeps, how="align").with_columns(
    good=(pl.col('good_flkr') & pl.col('good_bart') & pl.col('good_rdm') & pl.col('good_cab')).fill_null(False),
    # bart has no og/bc distinction, so its single good flag is used here too.
    oggood=(pl.col('oggood_flkr') & pl.col('good_bart') & pl.col('oggood_rdm') & pl.col('oggood_cab')).fill_null(False),
).with_columns(
    bad=~pl.col('good')
)

In [93]:
# Flag subjects that have every expected run of each task (rdm: 3 runs, others: 2),
# judged by the number of distinct `zrn` values per subject.
complete_parts = []
for task_name in tasks:
    expected_n = 3 if task_name == 'rdm' else 2
    complete_parts.append(
        task_dat[task_name]
        .group_by('sub_id')
        .agg(pl.col('zrn').n_unique().eq(expected_n).alias(f'has_all_{task_name}'))
    )
# "align" outer-joins on sub_id, so a subject missing a task entirely has a null flag for
# it; such a subject is incomplete, so has_all is False rather than null.
complete = pl.concat(complete_parts, how='align').with_columns(
    has_all=(
        pl.col('has_all_flkr') & pl.col('has_all_cab') & pl.col('has_all_rdm') & pl.col('has_all_bart')
    ).fill_null(False)
)

In [100]:
# Attach the per-subject completeness flag (inner join on sub_id). Any `has_all` already
# present is dropped first so re-running this cell does not add a `has_all_right` column.
tkeep = tkeep.select(pl.exclude('has_all')).join(
    complete.select(['sub_id', 'has_all']), on='sub_id'
)

In [103]:
# Among subjects with every expected run: how many and what fraction pass QC,
# using og_mask RT counts (oggood) versus bc_mask RT counts (good).
qc_overall = tkeep.filter(pl.col('has_all')).select(['oggood', 'good'])
qc_overall.sum(), qc_overall.mean()

(shape: (1, 2)
 ┌────────┬──────┐
 │ oggood ┆ good │
 │ ---    ┆ ---  │
 │ u32    ┆ u32  │
 ╞════════╪══════╡
 │ 1875   ┆ 1854 │
 └────────┴──────┘,
 shape: (1, 2)
 ┌──────────┬──────────┐
 │ oggood   ┆ good     │
 │ ---      ┆ ---      │
 │ f64      ┆ f64      │
 ╞══════════╪══════════╡
 │ 0.699627 ┆ 0.691791 │
 └──────────┴──────────┘)

In [110]:
# Among subjects with every expected run: fraction and count passing QC overall and per task.
good_cols = ['good', 'good_flkr', 'good_cab', 'good_bart', 'good_rdm']
qc_by_task = tkeep.filter(pl.col('has_all')).select(good_cols)
qc_by_task.mean(), qc_by_task.sum()

(shape: (1, 5)
 ┌──────────┬───────────┬──────────┬───────────┬──────────┐
 │ good     ┆ good_flkr ┆ good_cab ┆ good_bart ┆ good_rdm │
 │ ---      ┆ ---       ┆ ---      ┆ ---       ┆ ---      │
 │ f64      ┆ f64       ┆ f64      ┆ f64       ┆ f64      │
 ╞══════════╪═══════════╪══════════╪═══════════╪══════════╡
 │ 0.691791 ┆ 0.905597  ┆ 0.786194 ┆ 0.991791  ┆ 0.854851 │
 └──────────┴───────────┴──────────┴───────────┴──────────┘,
 shape: (1, 5)
 ┌──────┬───────────┬──────────┬───────────┬──────────┐
 │ good ┆ good_flkr ┆ good_cab ┆ good_bart ┆ good_rdm │
 │ ---  ┆ ---       ┆ ---      ┆ ---       ┆ ---      │
 │ u32  ┆ u32       ┆ u32      ┆ u32       ┆ u32      │
 ╞══════╪═══════════╪══════════╪═══════════╪══════════╡
 │ 1854 ┆ 2427      ┆ 2107     ┆ 2658      ┆ 2291     │
 └──────┴───────────┴──────────┴───────────┴──────────┘)