In [None]:
from pathlib import Path
import polars as pl
from cogmood_analysis.load import load_task, boxcoxmask, nanboxcox, load_survey, proc_survey
import cogmood_analysis.survey_helpers as sh
from scipy.stats import boxcox
import seaborn as sns
from matplotlib import pyplot as plt
import numpy as np
from numpy.typing import ArrayLike, NDArray
from joblib import Parallel, delayed
from datetime import datetime
import json
pl.Config(tbl_rows=300)
pl.Config(tbl_cols=300)

In [None]:
# --- Paths ---------------------------------------------------------------
data_dir = Path('../data/')
survey_dir = data_dir / 'survey'
complete_dir = survey_dir / 'complete'
task_data_dir = data_dir / 'task/upload'
task_db_dir = data_dir / 'task/db'
task_tomodel_dir = data_dir / 'task/to_model'
# parents=True so a checkout where data/task/ does not exist yet does not fail here
task_tomodel_dir.mkdir(parents=True, exist_ok=True)

# --- Run options ---------------------------------------------------------
# Task archives whose file modification time is earlier than this are skipped when loading
start_time = pl.Series(['2025-07-11 00:00:00']).str.to_datetime()[0]
tasks = ['flkr', 'cab', 'rdm', 'bart']
# Names inside the upload directory that should not be treated as subjects
bad_dirs = []
# False: start from the cached <task>.parquet files and only (re)load incomplete subjects.
# True: ignore the cache and load every subject from the raw uploads.
from_scratch = False

# --- Timing limits -------------------------------------------------------
# (presumably seconds -- TODO confirm against load_task)
default_high_limit = 2
default_low_limit = 0.350
# per-task (low, high) RT limits; the exclusion cell currently applies only the high limit
limits = {
    'cab': (0.2, 2.5),
    'flkr': (0.1, 3),
    'rdm': (0.1, 3),
    'bart': (0, 100)
}

# --- Chance thresholds ---------------------------------------------------
# bart
# have to press pump on first trial of at least 24 of 36 balloons
# for p = 0.032 < 0.05 responses are not at random
# (23 is 0.066)
bart_thresh = 24
# cab
# correct response on 57 out of 96 trials
# for p = 0.041 < 0.05 responses are not at random
# (56 is 0.0625)
cab_thresh = 57
cab_rt_thresh05 = 4.8
cab_rt_thresh10 = 9.6
# rdm
# correct response on 105 out of 186 nonrandom trials
# for p = 0.045 < 0.05 responses are not at random
# (104 is 0.062)
rdm_thresh = 105
rdm_rt_thresh05 = 222 * 0.05
rdm_rt_thresh10 = 222 * 0.1
# flkr
# correct response on 57 out of 96 trials
# for p = 0.041 < 0.05 responses are not at random
# (56 is 0.0625)
flkr_thresh = 57
flkr_rt_thresh05 = 4.8
flkr_rt_thresh10 = 9.6

# Minimum number of correct responses per task; built from the named constants
# above so the two cannot drift apart.
corr_threshes = {
    'bart': bart_thresh,
    'cab': cab_thresh,
    'rdm': rdm_thresh,
    'flkr': flkr_thresh
}
# invert thresholds to get minimum number passing
# (at most 5% of trials may be lost -> at least 95% must be usable)
rt05_threshes = {
    'cab': 96 * 0.95,
    'rdm': 222 * 0.95,
    'flkr': 96 * 0.95
}
# same with a 10% allowance
rt10_threshes = {
    'cab': 96 * 0.9,
    'rdm': 222 * 0.9,
    'flkr': 96 * 0.9
}

# Load data

In [None]:

if from_scratch:
    # Nothing is reused from the cache: every subject found on disk will be loaded.
    complete_subs = []
else:
    # Start from the per-task parquet files written by a previous run.
    task_dat = {task_name: pl.read_parquet(data_dir/f'{task_name}.parquet') for task_name in tasks}

    # Per task, flag subjects whose number of distinct `zrn` values matches the
    # expected count (3 for rdm, 2 for every other task).
    complete = []
    for task_name, tdf in task_dat.items():
        expected_n = 3 if task_name == 'rdm' else 2
        n_unique_per_sub = tdf.group_by('sub_id').n_unique().select(['sub_id', 'zrn'])
        complete.append(
            n_unique_per_sub.select(
                'sub_id',
                (pl.col('zrn') == expected_n).alias(f'has_all_{task_name}'),
            )
        )
    # how='align' lines the per-task frames up on sub_id; has_all requires all four flags
    complete = pl.concat(complete, how='align').with_columns(
        has_all=(pl.col('has_all_flkr') & pl.col('has_all_cab') & pl.col('has_all_rdm') & pl.col('has_all_bart'))
    )

    complete_subs = complete.filter(pl.col('has_all')).get_column('sub_id').to_list()

In [None]:
# Every directory under the upload folder is a candidate subject. Plain files
# (e.g. OS metadata) are skipped so they are not mistaken for subject ids.
all_subids = sorted(
    dd.name for dd in task_data_dir.glob('*')
    if dd.is_dir() and dd.name not in bad_dirs
)
# Set lookup keeps this linear in the number of subjects
complete_sub_set = set(complete_subs)
needed_subids = [sid for sid in all_subids if sid not in complete_sub_set]
# go ahead and just drop all the needed_subids
# (their rows are reloaded from the raw uploads and appended again later)
if not from_scratch:
    for task_name in tasks:
        task_dat[task_name] = task_dat[task_name].filter(~pl.col('sub_id').is_in(needed_subids))

In [None]:

# Queue one deferred load_task call per uploaded run archive of every subject
# that still needs loading.
task_jobs = {task_name: [] for task_name in tasks}
for subject in needed_subids:
    sub_task_dir = task_data_dir / subject
    for task_name in tasks:
        # rdm has runs 0-2, every other task has runs 0-1
        n_runs = 3 if task_name == 'rdm' else 2
        for runnum in range(n_runs):
            zipped_path = sub_task_dir / f'{task_name}_{runnum}.zip'
            if not zipped_path.exists():
                continue
            # skip archives last modified before the study start
            file_date = datetime.fromtimestamp(zipped_path.stat().st_mtime)
            if file_date < start_time:
                continue
            task_jobs[task_name].append(delayed(load_task)(zipped_path, task_name, subject, runnum))

# Run the queued loads in parallel and stack the per-run frames for each task.
# NOTE(review): pl.concat raises on an empty list, so this cell fails when a
# task has no archives to load (e.g. every subject is already complete).
addl_task_dat = {}
for task_name in tasks:
    loaded_runs = Parallel(n_jobs=8, verbose=10)(task_jobs[task_name])
    addl_task_dat[task_name] = pl.concat(loaded_runs, how='diagonal_relaxed')

## Apply Box-Cox outlier exclusion criteria

In [None]:


for task_name in tasks:
    # Only high_limit is applied below; the lower bound actually used is rt > 0.
    low_limit, high_limit = limits[task_name]
    tdf = addl_task_dat[task_name]

    rt_col = pl.col('rt')
    by_subject = pl.col('sub_id')
    # False for out-of-range RTs and for missing RTs (a null comparison falls to otherwise)
    rt_inbounds = pl.when((rt_col > 0) & (rt_col < high_limit)).then(True).otherwise(False).alias('rt_inbounds')

    if task_name != 'bart':
        # Non-bart tasks:
        #  * non-response trials (null rt) are flagged and left out of the Box-Cox input,
        #  * so are trials with out-of-bounds RTs,
        #  * `keep` is True for non-responses and for trials that pass Box-Cox filtering.
        nonresponse = pl.when(rt_col.is_null()).then(True).otherwise(False).alias('nonresponse')
        tdff = tdf.with_columns(
            rt_inbounds,
            nonresponse,
        ).with_columns(
            # candidates for the Box-Cox step: in bounds and actually responded
            (pl.col('rt_inbounds') & ~pl.col('nonresponse')).alias('bc_mask')
        ).with_columns(
            # per subject, boxcoxmask decides which candidate RTs survive
            bc_mask=pl.when(pl.col('bc_mask')).then(rt_col).map_batches(boxcoxmask, return_dtype=pl.Boolean).over(by_subject),
        ).with_columns(
            # per subject, transformed RT of the surviving trials (NaN -> null)
            bc_rt=pl.when(pl.col('bc_mask')).then(rt_col).map_batches(nanboxcox, return_dtype=pl.Float64).fill_nan(None).over(by_subject),
        ).with_columns(
            # NOTE(review): this z-scores the raw rt using the mean/std of all of the
            # subject's rts, not bc_rt -- confirm that is what bc_z_rt should hold.
            bc_z_rt=pl.when(pl.col('bc_mask')).then((rt_col - rt_col.mean()) / rt_col.std()).over(by_subject)
        ).with_columns(
            pl.when(pl.col('nonresponse') | pl.col('bc_mask')).then(True).otherwise(False).alias('keep')
        )
    else:
        # bart: trials are kept or dropped on the RT range check alone
        tdff = tdf.with_columns(rt_inbounds).with_columns(
            pl.col('rt_inbounds').alias('keep')
        )
    addl_task_dat[task_name] = tdff
    
    

In [None]:
# Combine the cached rows with the freshly processed ones.
# NOTE: with from_scratch False this cell is not idempotent -- running it twice
# appends the new rows a second time.
if from_scratch:
    task_dat = {task_name: addl_task_dat[task_name] for task_name in tasks}
else:
    for task_name in tasks:
        task_dat[task_name] = pl.concat(
            [task_dat[task_name], addl_task_dat[task_name]],
            how='diagonal_relaxed',
        )

# identify subjects that have the expected number of blocks per task
# (distinct `zrn` values: 3 for rdm, 2 for the other tasks)
complete = []
for task_name in tasks:
    expected_n = 3 if task_name == 'rdm' else 2
    n_unique_per_sub = task_dat[task_name].group_by('sub_id').n_unique().select(['sub_id', 'zrn'])
    complete.append(
        n_unique_per_sub.select(
            'sub_id',
            (pl.col('zrn') == expected_n).alias(f'has_all_{task_name}'),
        )
    )
complete = pl.concat(complete, how='align').with_columns(
    has_all=(pl.col('has_all_flkr') & pl.col('has_all_cab') & pl.col('has_all_rdm') & pl.col('has_all_bart'))
)

In [None]:
# Sanity check on the exclusion flags: every trial that is a non-response or
# passed the Box-Cox mask must be marked keep.
for task_name in tasks:
    if task_name == 'bart':
        # bart trials are flagged by the RT range only in the exclusion cell
        continue
    tdff = task_dat[task_name]
    kept_flags = tdff.filter(pl.col('bc_mask') | pl.col('nonresponse')).get_column('keep')
    assert kept_flags.all()


In [None]:
# sub_ids corresponding to test participants.
# This cell only reports how many rows they contribute per task; it does not drop them.
bad_subids = [
    'no0z2yzyloa58hcsb5cyxxwz',
    'jvj53cg6gm44jattfxws849e',
    'fqebjziam9e7e9vnpzqghiv9',
    'lwk7rgfebcajlfttz1f3euzs',
    'b1c6cj5oy3wv9sh4qyj379s9',
    'in6dp60i65swuwbnbyjz8m6i',
    's7qczd3ccbwvvv54rkg2s3xh',
    '3y3tn37wv2libdutqxbcat3d',
    'p1h1eval1q08k2beesprnfwq'
]
for task_name in tasks:
    print(task_name)
    tdff = task_dat[task_name]
    n_test_rows = tdff.filter(pl.col('sub_id').is_in(bad_subids)).height
    # a count is printed only when test-participant rows are present
    if n_test_rows:
        print(n_test_rows)

In [None]:

# for task_name in tasks:
#     low_limit, high_limit = limits[task_name]
#     tdf = addl_task_dat[task_name]
#     if task_name == 'bart':
#         tdff = tdf.with_columns(
#             pl.when((pl.col('rt') > low_limit) & (pl.col('rt') < high_limit)).then(True).otherwise(False).alias('og_mask')
#         )
#     else:
#         tdff = tdf.with_columns(
#             pl.when((pl.col('rt') > low_limit) & (pl.col('rt') < high_limit)).then(True).otherwise(False).alias('og_mask'),
#             pl.when((pl.col('rt') > low_limit) & (pl.col('rt') < high_limit)).then(True).otherwise(False).alias('bc_mask')
#         ).with_columns(
#             bc_mask=pl.when(pl.col('bc_mask')).then(pl.col('rt')).map_batches(lambda x: boxcoxmask(x), return_dtype=pl.Boolean).over(pl.col('sub_id')),
#         ).with_columns(
#             bc_rt=pl.when(pl.col('bc_mask')).then(pl.col('rt')).map_batches(lambda x: nanboxcox(x), return_dtype=pl.Float64).fill_nan(None).over(pl.col('sub_id')),
#         ).with_columns(
#             bc_z_rt=pl.when(pl.col('bc_mask')).then((pl.col('rt') - pl.col('rt').mean()) / pl.col('rt').std()).over(pl.col('sub_id'))
#         )
#     addl_task_dat[task_name] = tdff
    

In [None]:
# if not from_scratch:
#     for task_name in tasks: 
#         tdf = task_dat[task_name]
#         atdf = addl_task_dat[task_name]
#         ctdf = pl.concat([tdf, atdf], how='diagonal_relaxed')
#         task_dat[task_name] = ctdf
# else:
#     for task_name in tasks:
#         task_dat[task_name] = addl_task_dat[task_name]

# complete = []
# for task_name in tasks:
#     tdf = task_dat[task_name]
#     expected_n = 2
#     if task_name== 'rdm':
#         expected_n = 3
#     tmp = (tdf.group_by('sub_id').n_unique()
#         .select(['sub_id', 'zrn']).with_columns(
#             (pl.col('zrn')==expected_n).alias(f'has_all_{task_name}')
#             ).select(['sub_id', f'has_all_{task_name}']))
#     complete.append(tmp)
# complete = pl.concat(complete, how='align').with_columns(
#     has_all=(pl.col('has_all_flkr') & pl.col('has_all_cab') & pl.col('has_all_rdm') & pl.col('has_all_bart'))
# )

In [None]:
# # double check that exclusion criteria are correct
# for task_name in tasks:
#     if task_name != 'bart':
#         tdff = task_dat[task_name]
#         # check that time limits are respected
#         assert (tdff.filter(~pl.col('og_mask') & pl.col('bc_mask'))).is_empty()
#         assert (tdff.filter(pl.col('bc_mask')).min().select('rt') > limits[task_name][0])[0, 'rt']
#         assert (tdff.filter(pl.col('bc_mask')).max().select('rt') < limits[task_name][1])[0, 'rt']

In [None]:
# Persist the processed trial data. When a previous file exists and differs,
# it is first copied to <task>_old.parquet; identical data is not rewritten.
for task_name in tasks:
    task_path = data_dir / f'{task_name}.parquet'
    tdff = task_dat[task_name]
    # On a first (from-scratch) run there is no earlier file to compare against
    if task_path.exists():
        old_task_dat = pl.read_parquet(task_path)
        if tdff.equals(old_task_dat):
            continue
        old_task_dat.write_parquet(data_dir / f'{task_name}_old.parquet')
    tdff.write_parquet(task_path)

In [None]:
# process subject level exclusion
# subjects are excluded if:
# 1) they have below chance performance
# 2) > 5 or 10 % of trials are non-response or excluded by box-cox
# bc_mask is true if the trial is within RT response bounds, has a response, and passed box-cox outlier exclusion
tkeeps = []
for task_name in tasks:
    if task_name == 'cab':
        # cab keeps accuracy in `resp_acc`; expose it as `correct` like the other tasks
        task_dat[task_name] = task_dat[task_name].with_columns(
            (pl.col('resp_acc') == True).alias('correct')
        )
    tdff = task_dat[task_name]

    if task_name == 'rdm':
        tkeep = tdff.group_by(pl.col('sub_id')).agg(
            pl.sum('bc_mask').alias('bc_mask'),
            pl.last('date').alias('date')
        )
        # accuracy is only counted on trials where the two coherences differ
        tdff = tdff.with_columns(coh_dif=(pl.col('left_coherence') - pl.col('right_coherence')).abs())
        # aggregate only the column that is needed instead of summing every column
        tcorr = tdff.filter(pl.col('coh_dif') > 0).group_by(pl.col('sub_id')).agg(pl.sum('correct'))
        tkeep = tkeep.join(tcorr, how='left', on='sub_id')
    elif task_name == 'bart':
        # First row of every balloon; "correct" means that first key press was the pump button.
        # maintain_order=True keeps the balloons in input order so that pl.last('date')
        # below is deterministic (group order is otherwise arbitrary).
        tkeep = tdff.group_by(['sub_id', 'zrn', 'balloon_id', 'date'], maintain_order=True).first().with_columns(
            correct=pl.col('key_pressed') == pl.col('pump_button')
        ).group_by(pl.col('sub_id')).agg(
            pl.sum('correct').alias('correct'),
            pl.last('date').alias('date')
        ).with_columns(
            (pl.col('correct') >= corr_threshes[task_name]).alias(f'corr_ok_{task_name}')
        ).with_columns(
            # bart has no response-rate criterion, so good == above-chance
            pl.col(f'corr_ok_{task_name}').alias(f'good_{task_name}')
        ).with_columns(
            (~pl.col(f'good_{task_name}')).alias(f'bad_{task_name}')
        ).rename({'correct': f'n_correct_{task_name}', 'date':f'date_{task_name}'})
    else:
        tkeep =  tdff.group_by(pl.col('sub_id')).agg(
            pl.sum('bc_mask').alias('bc_mask'),
            pl.sum('correct').alias('correct'),
            pl.last('date').alias('date')
        )

    if task_name != 'bart':
        tkeep = tkeep.with_columns(
            # enough usable responses under the 5% / 10% allowance, and above-chance accuracy
            (pl.col('bc_mask') >= rt05_threshes[task_name]).alias(f'resp05_ok_{task_name}'),
            (pl.col('bc_mask') >= rt10_threshes[task_name]).alias(f'resp10_ok_{task_name}'),
            (pl.col('correct') >= corr_threshes[task_name]).alias(f'corr_ok_{task_name}'),
        ).with_columns(
            (pl.col(f'resp05_ok_{task_name}') & pl.col(f'corr_ok_{task_name}')).alias(f'good05_{task_name}'),
            (pl.col(f'resp10_ok_{task_name}') & pl.col(f'corr_ok_{task_name}')).alias(f'good10_{task_name}'),
        ).with_columns(
            (~pl.col(f'good05_{task_name}')).alias(f'bad05_{task_name}'),
            (~pl.col(f'good10_{task_name}')).alias(f'bad10_{task_name}')
        ).rename({
            'bc_mask':f'n_good_resps_{task_name}', 'correct': f'n_correct_{task_name}', 'date':f'date_{task_name}'})

    tkeeps.append(tkeep)

# One row per subject: align the per-task tables on sub_id and combine the task flags
tkeep = pl.concat(tkeeps, how="align").with_columns(
    good05=(pl.col('good05_flkr') & pl.col('good_bart') & pl.col('good05_rdm') & pl.col('good05_cab')),
    good10=(pl.col('good10_flkr') & pl.col('good_bart') & pl.col('good10_rdm') & pl.col('good10_cab')),
).with_columns(
    bad05=~pl.col('good05'),
    bad10=~pl.col('good10')
)

In [None]:
# tkeeps = []
# for task_name in tasks:
#     if task_name == 'cab':
#         task_dat[task_name] = task_dat[task_name].with_columns(
#             (pl.col('resp_acc') == True).alias('correct')
#         )
#     tdff = task_dat[task_name]
    
#     if task_name == 'rdm':
#         tkeep = tdff.group_by(pl.col('sub_id')).agg(
#             pl.sum('bc_mask').alias('bc_mask'),
#             pl.sum('og_mask').alias('og_mask'),
#             pl.last('date').alias('date')
#         )
#         tdff = tdff.with_columns(coh_dif=(pl.col('left_coherence') - pl.col('right_coherence')).abs())
#         tcorr = tdff.filter(pl.col('coh_dif') > 0).group_by(pl.col('sub_id')).sum().select(['sub_id','correct'])
#         tkeep = tkeep.join(tcorr, how='left', on='sub_id')
#     elif task_name == 'bart':
#         tkeep = tdff.group_by(['sub_id', 'zrn', 'balloon_id', 'date']).first().with_columns(
#             correct=pl.col('key_pressed') == pl.col('pump_button')
#         ).group_by(pl.col('sub_id')).agg(
#             pl.sum('correct').alias('correct'),
#             pl.last('date').alias('date')
#         ).with_columns(
#             (pl.col('correct') >= corr_threshes[task_name]).alias(f'corr_ok_{task_name}')
#         ).with_columns(
#             pl.col(f'corr_ok_{task_name}').alias(f'good_{task_name}')
#         ).with_columns(
#             (~pl.col(f'good_{task_name}')).alias(f'bad_{task_name}')
#         ).rename({'correct': f'n_correct_{task_name}', 'date':f'date_{task_name}'})
#     else:
#         tkeep =  tdff.group_by(pl.col('sub_id')).agg(
#             pl.sum('bc_mask').alias('bc_mask'),
#             pl.sum('og_mask').alias('og_mask'),
#             pl.sum('correct').alias('correct'),
#             pl.last('date').alias('date')
#         )
        
#     if task_name != 'bart':
#         tkeep = tkeep.with_columns(
#             (pl.col('og_mask') >= rt_threshes[task_name]).alias(f'ogrt_ok_{task_name}'),
#             (pl.col('bc_mask') >= rt_threshes[task_name]).alias(f'rt_ok_{task_name}'),
#             (pl.col('correct') >= corr_threshes[task_name]).alias(f'corr_ok_{task_name}'),
#         ).with_columns(
#             (pl.col(f'rt_ok_{task_name}') & pl.col(f'corr_ok_{task_name}')).alias(f'good_{task_name}'),
#             (pl.col(f'ogrt_ok_{task_name}') & pl.col(f'corr_ok_{task_name}')).alias(f'oggood_{task_name}')
#         ).with_columns(
#             (~pl.col(f'good_{task_name}')).alias(f'bad_{task_name}')
#         ).rename({'og_mask':f'n_good_ogrts_{task_name}',
#                    'bc_mask':f'n_good_rts_{task_name}', 'correct': f'n_correct_{task_name}', 'date':f'date_{task_name}'})

#     tkeeps.append(tkeep)

# tkeep = pl.concat(tkeeps, how="align").with_columns(
#     good=(pl.col('good_flkr') & pl.col('good_bart') & pl.col('good_rdm') & pl.col('good_cab')),
#     oggood=(pl.col('oggood_flkr') & pl.col('good_bart') & pl.col('oggood_rdm') & pl.col('oggood_cab')),
# ).with_columns(
#     bad=~pl.col('good')
# )

In [None]:
# Attach the completeness flag; the inner join keeps only subjects present in both tables.
# NOTE: re-running this cell joins has_all a second time.
tkeep = tkeep.join(
    complete.select('sub_id', 'has_all'),
    on='sub_id',
    how='inner',
)

In [None]:
# Write out a data quality CSV for participant filtering
# (written before the good_any05 / good_any10 columns are added further down)
quality_csv_path = task_tomodel_dir / 'data_quality.csv'
tkeep.write_csv(quality_csv_path)
# write out data for all subjects
# no need to do this until we've finished preregistration
# for task_name in tasks:
#     tdf = task_dat[task_name]
#     subjects = tdf.select('sub_id').unique().to_numpy().flatten()

#     for sub_id in subjects:
#         sdf = tdf.filter(pl.col('sub_id') == sub_id)
#         sdf.write_csv(out_dir / f'{task_name}-{sub_id}.csv')
    

In [None]:
# good_any<level>: the subject passes at least one task under the 5% / 10% allowance
# (bart has a single criterion, so good_bart is used for both levels)
any_good_exprs = {
    f'good_any{level}': (
        pl.col(f'good{level}_flkr') | pl.col(f'good{level}_cab') | pl.col(f'good{level}_rdm') | pl.col('good_bart')
    )
    for level in ('05', '10')
}
tkeep = tkeep.with_columns(**any_good_exprs)


# Count completions

In [None]:
# Earliest and latest flanker date across subjects
flkr_dates = tkeep.select('date_flkr')
flkr_dates.min(), flkr_dates.max(),

In [None]:
# Number of subjects passing at least one task, at the 5% and 10% allowance
(
    tkeep.select('good_any05').sum(),
    tkeep.select('good_any10').sum(),
)

In [None]:
# NOTE(review): hard-coded counts, presumably copied from an earlier output of
# this notebook -- they will go stale as data is added; confirm or derive from tkeep.
3659 / 4069

In [None]:
# Among complete subjects: count and fraction passing all tasks at each allowance
complete_rows = tkeep.filter(pl.col('has_all'))
overall_flags = ['good05', 'good10']
complete_rows.sum().select(overall_flags), complete_rows.mean().select(overall_flags)

In [None]:
# Among complete subjects: count and fraction passing all tasks at each allowance.
# (A stray trailing `and` made this cell a syntax error.)
tkeep.filter(pl.col('has_all')).sum().select(['good05', 'good10']), tkeep.filter(pl.col('has_all')).mean().select(['good05', 'good10'])

In [None]:
# Complete subjects only: pass rate (mean) and pass count (sum) overall and per task,
# first for the 5% allowance and then for the 10% allowance.
complete_rows = tkeep.filter(pl.col('has_all'))
flags05 = ['good05', 'good05_flkr', 'good05_cab', 'good_bart', 'good05_rdm']
flags10 = ['good10', 'good10_flkr', 'good10_cab', 'good_bart', 'good10_rdm']
(
    complete_rows.select(flags05).mean(),
    complete_rows.select(flags05 + ['has_all']).sum(),
    complete_rows.select(flags10).mean(),
    complete_rows.select(flags10 + ['has_all']).sum(),
)

In [None]:
# Same four tables as the previous cell: per-task pass rates and counts among
# complete subjects, 5% allowance first, then 10%.
has_all_rows = tkeep.filter(pl.col('has_all'))
summary_tables = []
for level in ('05', '10'):
    level_flags = [f'good{level}', f'good{level}_flkr', f'good{level}_cab', 'good_bart', f'good{level}_rdm']
    summary_tables.append(has_all_rows.select(level_flags).mean())
    summary_tables.append(has_all_rows.select(level_flags + ['has_all']).sum())
tuple(summary_tables)

In [None]:
# Complete subjects only, 10% allowance: pass rate and pass count overall and per task
complete_rows = tkeep.filter(pl.col('has_all'))
flags10 = ['good10', 'good10_flkr', 'good10_cab', 'good_bart', 'good10_rdm']
complete_rows.select(flags10).mean(), complete_rows.select(flags10 + ['has_all']).sum()

In [None]:
# This cell referenced `good` / `good_<task>` columns that tkeep does not have
# (only the good05_* / good10_* variants and good_bart are created above).
# It now uses the 10% allowance, the criterion used by the group analyses below.
# NOTE(review): confirm 10% (rather than 5%) is the intended criterion here.
tkeep.filter(pl.col('has_all')).select(['good10', 'good10_flkr', 'good10_cab', 'good_bart', 'good10_rdm']).mean(), tkeep.filter(pl.col('has_all')).select(['good10', 'good10_flkr', 'good10_cab', 'good_bart', 'good10_rdm', 'has_all']).sum()

In [None]:
# This cell referenced `good` / `good_<task>` columns that tkeep does not have
# (only the good05_* / good10_* variants and good_bart are created above).
# It now uses the 10% allowance, the criterion used by the group analyses below.
# NOTE(review): confirm 10% (rather than 5%) is the intended criterion here.
tkeep.filter(pl.col('has_all')).select(['good10', 'good10_flkr', 'good10_cab', 'good_bart', 'good10_rdm']).mean(), tkeep.filter(pl.col('has_all')).select(['good10', 'good10_flkr', 'good10_cab', 'good_bart', 'good10_rdm', 'has_all']).sum()

In [None]:
# Number of subjects with the expected number of blocks in every task
tkeep.select(pl.col('has_all').sum())

# load survey responses

In [None]:
# Read every completed survey JSON, then let proc_survey build the survey table
survey_resps = []
for sr_path in complete_dir.glob("*.json"):
    survey_resps.append(load_survey(sr_path))
srdf = proc_survey(survey_resps)

In [None]:
# Survey columns that are carried over into the task-quality table when the
# survey responses are joined to tkeep.
screen_cols = [
    'experience_depression', 'experience_anxiety', 'have_adhd', 'mentalhealth_daily_impact',
    'screen_group',
    'mood_pro_diagnosis', 'anxiety_pro_diagnosis', 'attention_pro_diagnosis',
]

In [None]:
# Keep only subjects that have both task-quality rows and a survey response
survey_cols = ['sub_id', 'survey_date', *screen_cols]
stkeep = tkeep.join(srdf.select(survey_cols), on='sub_id', how='inner')

In [None]:
# Restrict to subjects with the expected number of blocks in every task
cstkeep = stkeep.filter(pl.col('has_all'))

In [None]:
# final_block_date: latest of the four per-task dates for each subject.
# st_lag: whole minutes from the survey date to that final block.
task_date_cols = ['date_flkr', 'date_bart', 'date_rdm', 'date_cab']
cstkeep = cstkeep.with_columns(
    final_block_date=pl.max_horizontal(*task_date_cols)
).with_columns(
    st_lag=(pl.col('final_block_date') - pl.col('survey_date')).dt.total_minutes()
)

In [None]:
# Distribution of the survey-to-task lag, with extra upper-tail percentiles
lag_percentiles = [0.25, 0.5, 0.75, 0.9, 0.95, 0.99, 0.999]
cstkeep.select('st_lag').describe(percentiles=lag_percentiles)

In [None]:
# Display order of the screening groups and a label -> position lookup used for sorting
group_order = [
    'hv',
    'dep', 'anx', 'atn',
    'dep_anx', 'dep_atn', 'anx_atn',
    'dep_anx_atn',
    'othermh',
]
order_mapping = dict(zip(group_order, range(len(group_order))))

In [None]:
# Per screening group (10% allowance): passing count, complete count, pass fraction
good_by_group = (
    cstkeep.group_by('screen_group')
    .agg(
        n_good=pl.sum('good10'),
        n_complete=pl.sum('has_all'),
        frac_good=pl.mean('good10').round(2),
    )
    .sort(pl.col("screen_group").replace(order_mapping))
)
good_by_group, 'foo'

In [None]:
# Per screening group: passing count, complete count, pass fraction.
# Fixed: a stray trailing `and` (syntax error) and a reference to a `good`
# column that does not exist; the 10% criterion (`good10`) is used instead.
# NOTE(review): confirm 10% is the intended criterion for this table.
cstkeep.group_by('screen_group').agg([
    pl.sum('good10').alias('n_good'),
    pl.sum('has_all').alias('n_complete'),
    pl.mean('good10').alias('frac_good').round(2),
    ]).sort(pl.col("screen_group").replace(order_mapping)), 'foo'

In [None]:
import scipy

In [None]:
# Chi-square test of independence between screening group and passing the
# 10% data-quality criterion.
# The contingency table needs mutually exclusive cells, so the columns are
# passing / not passing. n_complete contains n_good, so using it directly
# would count the passing subjects twice.
group_counts = (
    cstkeep.group_by('screen_group').agg([
        pl.sum('good10').alias('n_good'),
        pl.sum('has_all').alias('n_complete')])
    .sort(pl.col("screen_group").replace(order_mapping))
    .with_columns(n_not_good=pl.col('n_complete') - pl.col('n_good'))
)
scipy.stats.chi2_contingency(group_counts.select(['n_good', 'n_not_good']).to_numpy())

In [None]:
# Daily and cumulative completion / pass counts (10% allowance), bucketed by the
# calendar day of each subject's final block.
daily_counts = (
    cstkeep.sort(pl.col('final_block_date'))
    .group_by_dynamic('final_block_date', every='1d')
    .agg(
        daily_n_good=pl.col('good10').sum(),
        daily_n_complete=pl.col('has_all').sum(),
        daily_good_rate=pl.col('good10').mean(),
    )
)
running_n_good = pl.col('daily_n_good').cum_sum()
running_n_complete = pl.col('daily_n_complete').cum_sum()
overall_good = daily_counts.with_columns(
    # whole days elapsed since start_time
    pl.col('final_block_date').sub(start_time).dt.total_days().alias('study_day'),
    running_n_good.alias('total_n_good'),
    running_n_complete.alias('total_n_complete'),
    (running_n_good / running_n_complete).alias('total_good_rate'),
).with_columns(
    # Days from 14 on are shifted back by 11.
    # NOTE(review): presumably closes a gap in data collection -- confirm the 14 / 11 values.
    pl.when(pl.col('study_day')>=14).then(pl.col('study_day')-11).otherwise('study_day').alias('study_day')
)



In [None]:
# Four colours from the 'Paired' palette, used by the completion plot
c1, c2, c3, c4 = sns.color_palette('Paired', n_colors=4)

In [None]:
# Cumulative totals as lines (left axis) over daily counts as bars (right axis)
fig, lax = plt.subplots(1)
for total_col, line_colour, line_label in (
    ('total_n_complete', c2, 'n_complete'),
    ('total_n_good', c4, 'n_good'),
):
    lax.plot(overall_good['study_day'], overall_good[total_col], color=line_colour, label=line_label)

rax = lax.twinx()
# NOTE(review): seaborn barplot places bars at categorical positions unless
# native_scale is set, while the lines use numeric study_day values -- confirm they align.
for daily_col, bar_colour in (('daily_n_complete', c1), ('daily_n_good', c3)):
    sns.barplot(data=overall_good, x='study_day', y=daily_col, color=bar_colour, ax=rax)

# draw the line axis on top of the bar axis, with a transparent background
rax.set_zorder(0)
lax.set_zorder(100)
lax.patch.set_alpha(0)

lax.set_ylabel('Total Counts (lines)')
rax.set_ylabel('Daily Counts (bars)')
lax.set_xlabel('Study Day')
fig.legend()

In [None]:
# Daily passing / complete counts per screening group.
# Uses `good10`: cstkeep has no `good` column, and the overall daily table is
# built from the 10% criterion as well.
cstkeep.sort(pl.col('final_block_date')).group_by_dynamic('final_block_date', every='1d', group_by='screen_group').agg(
    pl.col('good10').sum().alias("n_good"),
    pl.col('has_all').sum().alias("n_complete"),
    pl.col('good10').mean().alias("good_rate")
)

# add in prolific info

In [None]:
# Prolific demographics, one row per sub_id (the last row wins when a sub_id repeats).
# Rows whose status is the literal 'Status' are dropped -- presumably repeated
# header rows from concatenated exports (TODO confirm).
yes_no_cols = ['depression', 'anxiety', 'attention', 'mental_health_ongoing']
pdf = (
    pl.read_csv(data_dir / 'sid_prolific_demo.tsv', separator='\t')
    .filter((pl.col('status') != 'Status') & pl.col('sub_id').is_not_null())
    .group_by('sub_id').last()
    .with_columns([
        # "Yes" -> True, "No" -> False, anything else -> null
        pl.when(pl.col(name) == "Yes").then(True).when(pl.col(name) == "No").then(False).alias(name)
        for name in yes_no_cols
    ])
)

In [None]:
# Right join: every complete subject is kept, with Prolific fields where available
pkeep = pdf.join(cstkeep, on='sub_id', how='right')

In [None]:
# Distinct values of the Prolific depression flag after the join (null = no usable answer)
pkeep.get_column('depression').unique()

In [None]:
# Derive a screening group from the Prolific flags.
# Rules are tried in order; a subject matching none of them (e.g. because a
# flag is null) gets a null p_screen_group.
p_mho = pl.col("mental_health_ongoing")
p_dep = pl.col("depression")
p_anx = pl.col("anxiety")
p_atn_flag = pl.col("attention")
p_no_condition = ~p_dep & ~p_anx & ~p_atn_flag

p_group_rules = [
    ("hv", ~p_mho & p_no_condition),
    ("othermh", p_mho & p_no_condition),
    ("dep", p_dep & ~p_anx & ~p_atn_flag),
    ("anx", ~p_dep & p_anx & ~p_atn_flag),
    ("atn", ~p_dep & ~p_anx & p_atn_flag),
    ("dep_anx", p_dep & p_anx & ~p_atn_flag),
    ("dep_atn", p_dep & ~p_anx & p_atn_flag),
    ("anx_atn", ~p_dep & p_anx & p_atn_flag),
    ("dep_anx_atn", p_dep & p_anx & p_atn_flag),
]

(first_label, first_rule), *remaining_rules = p_group_rules
p_group_expr = pl.when(first_rule).then(pl.lit(first_label))
for rule_label, rule in remaining_rules:
    p_group_expr = p_group_expr.when(rule).then(pl.lit(rule_label))

pkeep = pkeep.with_columns(p_screen_group=p_group_expr)

In [None]:
# Cross-tab of Prolific-derived vs. survey-derived group, restricted to subjects
# whose Prolific flags put them in 'atn'.
# Uses `good10` because pkeep's task-quality columns come from tkeep, which has
# no `good` column.
# NOTE(review): confirm 10% is the intended criterion.
p_atn = pkeep.group_by(['p_screen_group', 'screen_group']).agg([
    pl.sum('good10').alias('n_good'),
    pl.sum('has_all').alias('n_complete'),
    pl.mean('good10').alias('frac_good').round(2),
    ]).sort(pl.col("p_screen_group").replace(order_mapping)).filter(pl.col("p_screen_group") == 'atn')

In [None]:
# Approved Prolific 'atn' subjects, with flags for whether the survey also
# placed them in 'atn' and whether they additionally pass data quality.
# Uses `good10`; the `good` column referenced before does not exist.
p_atn = pkeep.filter((pl.col('p_screen_group')=='atn') & (pl.col('status')=="APPROVED")).with_columns(
    is_atn=pl.col('screen_group') == 'atn',
    is_good_atn = (pl.col('screen_group') == 'atn') & pl.col('good10')
)
p_atn.mean()

In [None]:
# (rows, columns) of the Prolific 'atn' subset
(p_atn.height, p_atn.width)

In [None]:
# Survey-derived group breakdown of the Prolific 'atn' subset.
# Uses `good10`; the `good` column referenced before does not exist.
# NOTE(review): the divisor 35 is hard-coded -- presumably the row count of
# p_atn when this was written, which makes n_good a fraction; confirm or
# replace it with p_atn.height.
p_atn.group_by('screen_group').agg([
    pl.sum('good10').alias('n_good') / 35,
    pl.sum('has_all').alias('n_complete'),
    pl.mean('good10').alias('frac_good').round(2),
    ]).sort(pl.col("screen_group").replace(order_mapping))

In [None]:
# Prolific-derived group breakdown of the subjects the survey placed in 'atn'.
# Uses `good10`; the `good` column referenced before does not exist.
pkeep.group_by(['screen_group', 'p_screen_group']).agg([
    pl.sum('good10').alias('n_good'),
    pl.sum('has_all').alias('n_complete'),
    pl.mean('good10').alias('frac_good').round(2),
    ]).sort(pl.col("screen_group").replace(order_mapping)).filter(pl.col("screen_group") == 'atn'), 'foo'