In [9]:
import scipy.stats

import polars as pl


UNIT_DRIFT_PATH = "//allen/programs/mindscope/workgroups/dynamicrouting/ben/unit_drift.parquet"
SPIKE_COUNTS_PATH = "//allen/programs/mindscope/workgroups/dynamicrouting/ben/spike_counts.parquet"
TRIALS_PATH = "s3://aind-scratch-data/dynamic-routing/cache/nwb_components/v0.0.261/consolidated/trials.parquet"

# Units with a drift rating. unit_ids ending in '_ks4' are excluded.
# session_id is the first two '_'-separated fields of unit_id (e.g. "620263_2022-07-26").
units_lf = (
    pl.scan_parquet(UNIT_DRIFT_PATH)
    .select('unit_id', 'drift_rating')
    .filter(~pl.col('unit_id').str.ends_with('_ks4'))
    .with_columns(
        session_id=pl.col('unit_id').str.split('_').list.slice(0, 2).list.join('_'),
    )
    .drop_nulls('drift_rating')
)

# Per-unit, per-trial spike counts, joined on unit_id.
# Per the displayed output, this supplies trial_index, baseline and response.
spike_counts_lf = pl.scan_parquet(SPIKE_COUNTS_PATH)

# Trial metadata, keyed by (session_id, trial_index).
trials_lf = pl.scan_parquet(TRIALS_PATH).select(
    'session_id', 'trial_index', 'block_index', 'context_name', 'start_time', 'stim_name',
)

# Units with no spike counts get a null trial_index from the left join.
# The inner join on (session_id, trial_index) then drops them, because polars
# does not match null keys by default.
df = (
    units_lf
    .join(spike_counts_lf, on='unit_id', how='left')
    .join(trials_lf, on=['session_id', 'trial_index'], how='inner')
    .collect()
)

In [3]:
# Preview the joined unit x trial table.
# NOTE(review): the saved output below appears to predate adding `stim_name`
# to the trials select (it shows no stim_name column). Re-run to refresh it.
df

unit_id,drift_rating,session_id,trial_index,baseline,response,block_index,context_name,start_time
str,i32,str,i64,i64,i64,i64,str,f64
"""620263_2022-07-26_F-225""",0,"""620263_2022-07-26""",0,6,46,0,"""vis""",111.21154
"""620263_2022-07-26_F-225""",0,"""620263_2022-07-26""",1,7,48,0,"""vis""",116.09859
"""620263_2022-07-26_F-225""",0,"""620263_2022-07-26""",2,10,40,0,"""vis""",120.78591
"""620263_2022-07-26_F-225""",0,"""620263_2022-07-26""",3,0,43,0,"""vis""",126.04027
"""620263_2022-07-26_F-225""",0,"""620263_2022-07-26""",4,0,79,0,"""vis""",130.66081
…,…,…,…,…,…,…,…,…
"""742903_2024-10-24_C-153""",1,"""742903_2024-10-24""",528,1,0,5,"""vis""",6019.56103
"""742903_2024-10-24_C-153""",1,"""742903_2024-10-24""",529,0,1,5,"""vis""",6025.69953
"""742903_2024-10-24_C-153""",1,"""742903_2024-10-24""",530,0,0,5,"""vis""",6033.33927
"""742903_2024-10-24_C-153""",1,"""742903_2024-10-24""",531,0,0,5,"""vis""",6039.0774


In [None]:
import tqdm
import scipy.stats

class NullResult:
    """Placeholder used instead of a scipy test result when a test is skipped.

    It exposes the same ``statistic`` and ``pvalue`` attributes as the scipy
    result, both ``None``, so code that builds the results table can read them
    unconditionally.
    """

    statistic, pvalue = None, None


# Shared sentinel instance.
null_result = NullResult()

# How anderson_ksamp computes p-values.
# - None: scipy interpolates tabulated critical values. The p-value is then
#   floored/capped at 0.001 / 0.25, which is why those values appear in the results.
# - PermutationMethod: resampled p-values instead (slower, and not seeded here).
# pvalue_method = scipy.stats.PermutationMethod(n_resamples=1_000)
pvalue_method = None


def _ad_ksamp_by_block(unit_df, expr):
    """Run a k-sample Anderson-Darling test of ``expr`` across the blocks of one group.

    Each distinct ``block_index`` in ``unit_df`` contributes one sample: the
    values of ``expr`` on that block's rows. The k-sample statistic is a sum over
    samples, so the (unspecified) group order from ``group_by`` does not matter.

    Returns ``null_result`` when the pooled values have fewer than two distinct
    observations, because scipy.stats.anderson_ksamp raises ValueError in that
    case. An all-zero check is not enough: constant non-zero counts also raise.
    """
    if unit_df.select(expr.n_unique()).item() < 2:
        return null_result
    samples = (
        unit_df
        .group_by('block_index')
        .agg(expr.alias('samples'))
        .get_column('samples')
    )
    return scipy.stats.anderson_ksamp(samples, method=pvalue_method)


results = []
# Materialise the groups up front so tqdm can show a total.
iterable = tuple(df.drop_nulls(['baseline', 'response']).group_by('unit_id', 'context_name', 'stim_name'))
for (unit_id, context_name, stim_name, *_), unit_df in tqdm.tqdm(iterable):
    if unit_df.n_unique('block_index') < 2:
        # for Templeton we need spike counts for segments of time;
        # anderson_ksamp also needs at least two samples (one per block)
        continue
    baseline = _ad_ksamp_by_block(unit_df, pl.col('baseline'))
    response = _ad_ksamp_by_block(unit_df, pl.col('response'))
    # Whole-trial count = baseline + response. Checked on its own, since the sum
    # can be constant even when its parts are not.
    trial = _ad_ksamp_by_block(unit_df, pl.col('baseline') + pl.col('response'))
    result = dict(
        unit_id=unit_id,
        context_name=context_name,
        stim_name=stim_name,
        ad_stat_baseline=baseline.statistic,
        ad_stat_response=response.statistic,
        ad_stat_trial=trial.statistic,
        ad_p_baseline=baseline.pvalue,
        ad_p_response=response.pvalue,
        ad_p_trial=trial.pvalue,
    )
    results.append(result)
assert results, 'no (unit, context, stim) group had >= 2 blocks to test'
# One row per tested (unit_id, context_name, stim_name) group.
# infer_schema_length=None makes polars infer dtypes from every row. Otherwise a
# column whose first 100 rows are all None (skipped tests) could be typed as Null.
result_df = pl.DataFrame(results, infer_schema_length=None)

AD_TEST_OUTPUT_PATH = '//allen/programs/mindscope/workgroups/dynamicrouting/ben/ad_test.parquet'

# Collapse each unit's (context, stim) rows into one row. For each of the
# baseline, response and trial (baseline + response) counts, keep the largest
# A-D statistic and the smallest p-value. Nulls from skipped tests are ignored.
max_min_df = (
    result_df
    .group_by('unit_id')
    .agg(
        ad_stat_max_baseline=pl.col('ad_stat_baseline').max(),
        ad_stat_max_response=pl.col('ad_stat_response').max(),
        ad_stat_max_trial=pl.col('ad_stat_trial').max(),
        ad_p_min_baseline=pl.col('ad_p_baseline').min(),
        ad_p_min_response=pl.col('ad_p_response').min(),
        ad_p_min_trial=pl.col('ad_p_trial').min(),
    )
)
max_min_df.write_parquet(AD_TEST_OUTPUT_PATH)
max_min_df

In [None]:
# Rows per tested (unit, context, stim) group vs. rows per unit after aggregation.
# NOTE(review): this never-run cell duplicates the next one; one of them can be removed.
len(result_df), len(max_min_df)

In [17]:
# Number of tested (unit, context, stim) groups vs. number of units
result_df.height, max_min_df.height

(14547, 7290)

In [11]:
# All per-(context, stim) A-D results for one example unit
result_df.filter(pl.col('unit_id').eq('741137_2024-10-09_E-223'))

unit_id,context_name,stim_name,ad_stat_baseline,ad_stat_response,ad_stat_trial,ad_p_baseline,ad_p_response,ad_p_trial
str,str,str,f64,f64,f64,f64,f64,f64
"""741137_2024-10-09_E-223""","""aud""","""sound2""",0.639378,3.154024,0.906961,0.204525,0.013418,0.152958
"""741137_2024-10-09_E-223""","""aud""","""catch""",1.087416,2.064874,1.093305,0.125753,0.043587,0.124952
"""741137_2024-10-09_E-223""","""aud""","""vis2""",-0.9903,1.889781,-0.535186,0.25,0.052689,0.25
"""741137_2024-10-09_E-223""","""vis""","""catch""",1.814996,5.037134,4.209285,0.057136,0.001761,0.004296
"""741137_2024-10-09_E-223""","""vis""","""vis2""",8.953638,8.717206,14.320727,0.001,0.001,0.001
"""741137_2024-10-09_E-223""","""vis""","""sound1""",10.020753,7.594551,15.535587,0.001,0.001,0.001
"""741137_2024-10-09_E-223""","""vis""","""vis1""",2.599881,7.867481,5.783227,0.024428,0.001,0.001
"""741137_2024-10-09_E-223""","""aud""","""vis1""",1.599188,6.490163,7.933298,0.07219,0.001,0.001
"""741137_2024-10-09_E-223""","""aud""","""sound1""",0.836414,25.777405,24.543309,0.165132,0.001,0.001
"""741137_2024-10-09_E-223""","""vis""","""sound2""",7.481738,5.377495,10.243103,0.001,0.001221,0.001


In [13]:
# Per-unit max statistic / min p-value for example unit 741137_2024-10-09_E-223
max_min_df.filter(pl.col('unit_id').eq('741137_2024-10-09_E-223'))

unit_id,ad_stat_max_baseline,ad_stat_max_response,ad_stat_max_trial,ad_p_min_baseline,ad_p_min_response,ad_p_min_trial
str,f64,f64,f64,f64,f64,f64
"""741137_2024-10-09_E-223""",10.020753,25.777405,24.543309,0.001,0.001,0.001
