In [8]:
import polars as pl
import pandas as pd
import polars.selectors as cs 
import numpy as np
from scipy.stats import sem, t
import seaborn as sns
from inspect import signature
from typing import Callable
import matplotlib.pyplot as plt 

In [163]:
def list_eval_ref(
    listcol: pl.Expr | str,
    op: Callable[..., pl.Expr],
    *ref_cols: str | pl.Expr,
):
    if len(ref_cols)==0:
        ref_cols = tuple([x for x in signature(op).parameters.keys()][1:])
    
    args_to_op = [pl.element().struct[0].explode()] + [
        pl.element().struct[i + 1] for i in range(len(ref_cols))
    ]
    return pl.concat_list(pl.struct(listcol, *ref_cols)).list.eval(op(*args_to_op))

def processing(df: pl.DataFrame):
    df = df.sort(pl.col('subjectID', 'trial'))
    match = cs.matches(r'^valArm(?:[1-9]|1[0-9]|20)$')
    noise_arm_match = cs.matches(r'^noiseArm(?:[1-9]|1[0-9]|20)$')
    rand_arm_match = cs.matches(r'^randArm(?:[1-9]|1[0-9]|20)$')
    train_df = df.filter(pl.col('phase') == 'training')

    train_df = train_df.with_columns(
        all_arm_vals = pl.concat_list(match),
        all_noise_vals = pl.concat_list(noise_arm_match),
        rand_arm_vals = pl.concat_list(rand_arm_match)
    )

    train_df = train_df.with_columns(
        true_arm_vals = pl.struct('all_arm_vals', 'all_noise_vals')
        .map_batches(
            lambda x: np.array(x.struct.field('all_arm_vals')) - np.array(x.struct.field('all_noise_vals'))
        )
    )
    # get rank of the chosen arm 
    # the true rank is defined by the pre noise value ranking 

    train_df = train_df.with_columns(
        true_selected_arm_val = pl.struct('chosenArm', 'rand_arm_vals')
        .map_batches(
            lambda x: x.struct.field('rand_arm_vals').list.get(x.struct.field('chosenArm') - 1)
        )
    )

    return train_df
    print(train_df)
    train_df = train_df.with_columns(   
        all_arm_vals_shifted=list_eval_ref("true_arm_vals", lambda x, all_noise_vals: x-all_noise_vals)
        .list
        .sort(descending=True)
    )
    train_df = train_df.with_columns(pl.col('all_arm_vals_shifted').list.eval((pl.element() == 0).cast(int).arg_max()).list.first().alias('chosenRank'))
    grouped_chosen_ranks = train_df.group_by('trial', 'expCond').agg(pl.col('chosenRank').mean())
    return train_df, grouped_chosen_ranks


In [131]:
df = pl.read_csv('/Users/jeremiahetiosaomeike/research_projects/lbd/data/exp1_banditData.csv', null_values='NA')
df = df.sort(pl.col('subjectID', 'trial'))

In [147]:
df

browser,platform,subjectID,expID,expCond,totalTime,phase,trial,keyPress,chosenArm,choiceRT,rewardObtained,rewardMax,regret,correct,switch,runningTotal,valArm1,randArm1,valArm2,randArm2,valArm3,randArm3,valArm4,randArm4,valArm5,randArm5,valArm6,randArm6,valArm7,randArm7,valArm8,randArm8,valArm9,randArm9,valArm10,randArm10,…,valArm12feat2,noiseArm12,nameArm12,valArm13feat1,valArm13feat2,noiseArm13,nameArm13,valArm14feat1,valArm14feat2,noiseArm14,nameArm14,valArm15feat1,valArm15feat2,noiseArm15,nameArm15,valArm16feat1,valArm16feat2,noiseArm16,nameArm16,valArm17feat1,valArm17feat2,noiseArm17,nameArm17,valArm18feat1,valArm18feat2,noiseArm18,nameArm18,valArm19feat1,valArm19feat2,noiseArm19,nameArm19,valArm20feat1,valArm20feat2,noiseArm20,nameArm20,weight1,weight2
str,str,i64,str,str,f64,str,i64,i64,i64,i64,f64,f64,f64,i64,i64,f64,f64,i64,f64,i64,f64,i64,f64,i64,f64,i64,f64,i64,f64,i64,f64,i64,f64,i64,f64,i64,…,f64,f64,str,f64,f64,f64,str,f64,f64,f64,str,f64,f64,f64,str,f64,f64,f64,str,f64,f64,f64,str,f64,f64,f64,str,f64,f64,f64,str,f64,f64,f64,str,i64,i64
"""chrome""","""windows""",2,"""exp1highReward""","""CMAB_Lin_NoIns""",10.816433,"""training""",1,19,20,9712,2.721505,3.561307,0.839802,0,0,22.721505,1.327617,2,1.718571,9,1.876731,6,2.88524,10,2.065993,15,1.109209,8,3.464238,12,1.913467,11,1.214639,1,2.352129,4,…,0.217513,-0.158846,"""inter""",0.72645,0.867909,1.099039,"""inter""",0.216667,0.821464,-0.494095,"""inter""",0.127428,0.167216,-1.310707,"""inter""",0.283032,0.525127,-0.138439,"""inter""",0.711361,0.112186,-0.495821,"""inter""",0.753527,0.607805,1.323672,"""inter""",0.677082,0.632203,0.998222,"""inter""",0.803621,0.732499,0.452885,"""inter""",1,2
"""chrome""","""windows""",2,"""exp1highReward""","""CMAB_Lin_NoIns""",10.816433,"""test""",1,1,2,5477,1.418656,1.418656,0.0,1,0,182.213264,0.449713,2,1.418656,1,-0.152609,3,,,,,,,,,,,,,,,…,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1,2
"""chrome""","""windows""",2,"""exp1highReward""","""CMAB_Lin_NoIns""",10.816433,"""training""",2,18,19,5833,1.082311,3.991859,2.909548,0,1,23.803816,1.207088,2,2.114623,9,2.573887,6,2.285075,10,1.24303,15,1.525026,8,1.987652,12,-0.849392,11,2.314265,1,3.315799,4,…,0.217513,0.562236,"""inter""",0.72645,0.867909,-0.09522,"""inter""",0.216667,0.821464,-1.032388,"""inter""",0.127428,0.167216,0.999521,"""inter""",0.283032,0.525127,2.658572,"""inter""",0.711361,0.112186,0.230814,"""inter""",0.753527,0.607805,0.142077,"""inter""",0.677082,0.632203,-0.859178,"""inter""",0.803621,0.732499,-0.466765,"""inter""",1,2
"""chrome""","""windows""",2,"""exp1highReward""","""CMAB_Lin_NoIns""",10.816433,"""test""",2,2,3,1723,2.45246,2.45246,0.0,1,1,184.665724,1.089181,3,0.988944,2,2.45246,1,,,,,,,,,,,,,,,…,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1,2
"""chrome""","""windows""",2,"""exp1highReward""","""CMAB_Lin_NoIns""",10.816433,"""training""",3,0,1,2766,0.619955,3.496261,2.876306,0,1,24.423771,0.619955,2,-1.048898,9,1.276719,6,2.189622,10,3.052373,15,1.861023,8,1.53738,12,2.245228,11,1.266903,1,2.433701,4,…,0.217513,0.19402,"""inter""",0.72645,0.867909,1.033992,"""inter""",0.216667,0.821464,-1.014614,"""inter""",0.127428,0.167216,-0.023868,"""inter""",0.283032,0.525127,0.898417,"""inter""",0.711361,0.112186,0.085191,"""inter""",0.753527,0.607805,0.779204,"""inter""",0.677082,0.632203,-0.147048,"""inter""",0.803621,0.732499,-0.956391,"""inter""",1,2
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
"""firefox""","""windows""",357,"""exp1highReward""","""MAB_Lin""",5.661583,"""training""",96,16,17,264,2.934323,3.056532,0.122209,0,1,200.422783,1.992445,4,2.409771,13,1.149279,2,0.934874,6,0.319049,19,-0.370369,12,0.159502,5,0.704094,1,-0.536397,17,0.419553,11,…,0.627812,-0.349723,"""inter""",0.416415,0.165412,-1.028168,"""inter""",0.11768,0.883581,-0.567994,"""inter""",0.451982,0.624603,-1.290522,"""inter""",0.636182,0.801538,0.012748,"""inter""",0.892148,0.568857,0.581169,"""inter""",0.489443,0.636078,0.428402,"""inter""",0.811643,0.426754,-1.049746,"""inter""",0.749234,0.445612,1.112453,"""inter""",2,1
"""firefox""","""windows""",357,"""exp1highReward""","""MAB_Lin""",5.661583,"""training""",97,16,17,326,1.951474,3.095261,1.143787,0,0,202.374257,0.922237,4,0.611632,13,-1.372232,2,0.871147,6,0.255031,19,-0.003087,12,2.084547,5,2.472242,1,1.92802,17,-0.035103,11,…,0.627812,-1.004059,"""inter""",0.416415,0.165412,0.164106,"""inter""",0.11768,0.883581,-0.115076,"""inter""",0.451982,0.624603,0.766353,"""inter""",0.636182,0.801538,1.021359,"""inter""",0.892148,0.568857,-0.401679,"""inter""",0.489443,0.636078,0.065846,"""inter""",0.811643,0.426754,-0.518802,"""inter""",0.749234,0.445612,-0.193789,"""inter""",2,1
"""firefox""","""windows""",357,"""exp1highReward""","""MAB_Lin""",5.661583,"""training""",98,16,17,459,1.939091,4.328614,2.389522,0,0,204.313349,1.062252,4,4.226407,13,-0.095945,2,2.836889,6,0.814876,19,-0.351499,12,1.972048,5,1.017589,1,1.375276,17,0.753573,11,…,0.627812,2.080083,"""inter""",0.416415,0.165412,-0.685837,"""inter""",0.11768,0.883581,1.08956,"""inter""",0.451982,0.624603,1.32574,"""inter""",0.636182,0.801538,1.533621,"""inter""",0.892148,0.568857,-0.414062,"""inter""",0.489443,0.636078,-0.658163,"""inter""",0.811643,0.426754,0.66962,"""inter""",0.749234,0.445612,0.190015,"""inter""",2,1
"""firefox""","""windows""",357,"""exp1highReward""","""MAB_Lin""",5.661583,"""training""",99,18,19,680,1.826546,2.978731,1.152185,0,1,206.139895,1.96475,4,1.037819,13,0.579412,2,2.822571,6,1.758046,19,1.460204,12,1.761822,5,2.552483,1,0.823805,17,0.199735,11,…,0.627812,-0.262451,"""inter""",0.416415,0.165412,-0.088036,"""inter""",0.11768,0.883581,0.738748,"""inter""",0.451982,0.624603,-1.001483,"""inter""",0.636182,0.801538,-0.396266,"""inter""",0.892148,0.568857,0.030941,"""inter""",0.489443,0.636078,0.302163,"""inter""",0.811643,0.426754,-0.223494,"""inter""",0.749234,0.445612,0.343347,"""inter""",2,1


In [164]:
train_df, grouped_chosen_ranks = processing(df)

grouped_chosen_ranks_pd = grouped_chosen_ranks.to_pandas()
g = sns.lmplot(data=grouped_chosen_ranks, x='trial', y='chosenRank', hue='expCond', markers=['o', 's', '^'])
ticks = [0, 25, 50, 75, 100]
g.set(xticks=ticks)
# g.set_ticks([0, 25, 50, 75, 100])
# g.set_xticklabels([0, 25, 50, 75, 100])

shape: (19_300, 146)
┌─────────┬──────────┬───────────┬────────────┬───┬────────────┬───────────┬───────────┬───────────┐
│ browser ┆ platform ┆ subjectID ┆ expID      ┆ … ┆ all_noise_ ┆ rand_arm_ ┆ true_arm_ ┆ true_sele │
│ ---     ┆ ---      ┆ ---       ┆ ---        ┆   ┆ vals       ┆ vals      ┆ vals      ┆ cted_arm_ │
│ str     ┆ str      ┆ i64       ┆ str        ┆   ┆ ---        ┆ ---       ┆ ---       ┆ val       │
│         ┆          ┆           ┆            ┆   ┆ list[f64]  ┆ list[i64] ┆ object    ┆ ---       │
│         ┆          ┆           ┆            ┆   ┆            ┆           ┆           ┆ i64       │
╞═════════╪══════════╪═══════════╪════════════╪═══╪════════════╪═══════════╪═══════════╪═══════════╡
│ chrome  ┆ windows  ┆ 2         ┆ exp1highRe ┆ … ┆ [0.203984, ┆ [2, 9, …  ┆ [1.123633 ┆ 20        │
│         ┆          ┆           ┆ ward       ┆   ┆ 0.557114,  ┆ 5]        ┆ 68 1.1614 ┆           │
│         ┆          ┆           ┆            ┆   ┆ … 0.45288… ┆      


thread '<unnamed>' panicked at crates/polars-core/src/series/ops/null.rs:60:80:
called `Result::unwrap()` on an `Err` value: InvalidOperation(ErrString("nested objects are not allowed"))
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace


PanicException: called `Result::unwrap()` on an `Err` value: InvalidOperation(ErrString("nested objects are not allowed"))

In [106]:
train_df

browser,platform,subjectID,expID,expCond,totalTime,phase,trial,keyPress,chosenArm,choiceRT,rewardObtained,rewardMax,regret,correct,switch,runningTotal,valArm1,randArm1,valArm2,randArm2,valArm3,randArm3,valArm4,randArm4,valArm5,randArm5,valArm6,randArm6,valArm7,randArm7,valArm8,randArm8,valArm9,randArm9,valArm10,randArm10,…,valArm13feat1,valArm13feat2,noiseArm13,nameArm13,valArm14feat1,valArm14feat2,noiseArm14,nameArm14,valArm15feat1,valArm15feat2,noiseArm15,nameArm15,valArm16feat1,valArm16feat2,noiseArm16,nameArm16,valArm17feat1,valArm17feat2,noiseArm17,nameArm17,valArm18feat1,valArm18feat2,noiseArm18,nameArm18,valArm19feat1,valArm19feat2,noiseArm19,nameArm19,valArm20feat1,valArm20feat2,noiseArm20,nameArm20,weight1,weight2,all_arm_vals,all_arm_vals_shifted,chosenRank
str,str,i64,str,str,f64,str,i64,i64,i64,i64,f64,f64,f64,i64,i64,f64,f64,i64,f64,i64,f64,i64,f64,i64,f64,i64,f64,i64,f64,i64,f64,i64,f64,i64,f64,i64,…,f64,f64,f64,str,f64,f64,f64,str,f64,f64,f64,str,f64,f64,f64,str,f64,f64,f64,str,f64,f64,f64,str,f64,f64,f64,str,f64,f64,f64,str,i64,i64,list[f64],list[f64],u32
"""chrome""","""windows""",2,"""exp1highReward""","""CMAB_Lin_NoIns""",10.816433,"""training""",1,19,20,9712,2.721505,3.561307,0.839802,0,0,22.721505,1.327617,2,1.718571,9,1.876731,6,2.88524,10,2.065993,15,1.109209,8,3.464238,12,1.913467,11,1.214639,1,2.352129,4,…,0.72645,0.867909,1.099039,"""inter""",0.216667,0.821464,-0.494095,"""inter""",0.127428,0.167216,-1.310707,"""inter""",0.283032,0.525127,-0.138439,"""inter""",0.711361,0.112186,-0.495821,"""inter""",0.753527,0.607805,1.323672,"""inter""",0.677082,0.632203,0.998222,"""inter""",0.803621,0.732499,0.452885,"""inter""",1,2,"[1.327617, 1.718571, … 2.721505]","[0.839802, 0.742733, … -3.570352]",5
"""chrome""","""windows""",2,"""exp1highReward""","""CMAB_Lin_NoIns""",10.816433,"""training""",2,18,19,5833,1.082311,3.991859,2.909548,0,1,23.803816,1.207088,2,2.114623,9,2.573887,6,2.285075,10,1.24303,15,1.525026,8,1.987652,12,-0.849392,11,2.314265,1,3.315799,4,…,0.72645,0.867909,-0.09522,"""inter""",0.216667,0.821464,-1.032388,"""inter""",0.127428,0.167216,0.999521,"""inter""",0.283032,0.525127,2.658572,"""inter""",0.711361,0.112186,0.230814,"""inter""",0.753527,0.607805,0.142077,"""inter""",0.677082,0.632203,-0.859178,"""inter""",0.803621,0.732499,-0.466765,"""inter""",1,2,"[1.207088, 2.114623, … 1.801855]","[2.909548, 2.233488, … -1.931703]",16
"""chrome""","""windows""",2,"""exp1highReward""","""CMAB_Lin_NoIns""",10.816433,"""training""",3,0,1,2766,0.619955,3.496261,2.876306,0,1,24.423771,0.619955,2,-1.048898,9,1.276719,6,2.189622,10,3.052373,15,1.861023,8,1.53738,12,2.245228,11,1.266903,1,2.433701,4,…,0.72645,0.867909,1.033992,"""inter""",0.216667,0.821464,-1.014614,"""inter""",0.127428,0.167216,-0.023868,"""inter""",0.283032,0.525127,0.898417,"""inter""",0.711361,0.112186,0.085191,"""inter""",0.753527,0.607805,0.779204,"""inter""",0.677082,0.632203,-0.147048,"""inter""",0.803621,0.732499,-0.956391,"""inter""",1,2,"[0.619955, -1.048898, … 1.312229]","[2.876306, 2.432417, … -1.668853]",17
"""chrome""","""windows""",2,"""exp1highReward""","""CMAB_Lin_NoIns""",10.816433,"""training""",4,17,18,1000,3.218271,3.437353,0.219082,0,1,27.642042,1.744009,2,2.22625,9,0.881657,6,1.730071,10,0.50605,15,3.275484,8,0.37383,12,2.974462,11,1.766405,1,1.816545,4,…,0.72645,0.867909,-0.454141,"""inter""",0.216667,0.821464,1.577759,"""inter""",0.127428,0.167216,-1.915397,"""inter""",0.283032,0.525127,0.372876,"""inter""",0.711361,0.112186,1.924093,"""inter""",0.753527,0.607805,1.249134,"""inter""",0.677082,0.632203,-0.119374,"""inter""",0.803621,0.732499,-0.911789,"""inter""",1,2,"[1.744009, 2.22625, … 1.35683]","[0.219082, 0.057213, … -4.671809]",2
"""chrome""","""windows""",2,"""exp1highReward""","""CMAB_Lin_NoIns""",10.816433,"""training""",5,16,17,1058,0.861198,4.974281,4.113083,0,1,28.50324,4.974281,2,2.21165,9,0.388026,6,1.220695,10,1.035577,15,0.634234,8,2.679628,12,1.841328,11,2.083442,1,4.21385,4,…,0.72645,0.867909,-0.11686,"""inter""",0.216667,0.821464,0.947277,"""inter""",0.127428,0.167216,0.185042,"""inter""",0.283032,0.525127,-0.984962,"""inter""",0.711361,0.112186,-0.074534,"""inter""",0.753527,0.607805,0.526791,"""inter""",0.677082,0.632203,-0.961692,"""inter""",0.803621,0.732499,-0.132406,"""inter""",1,2,"[4.974281, 2.21165, … 2.136214]","[4.113083, 3.352652, … -2.744176]",13
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
"""firefox""","""windows""",357,"""exp1highReward""","""MAB_Lin""",5.661583,"""training""",96,16,17,264,2.934323,3.056532,0.122209,0,1,200.422783,1.992445,4,2.409771,13,1.149279,2,0.934874,6,0.319049,19,-0.370369,12,0.159502,5,0.704094,1,-0.536397,17,0.419553,11,…,0.416415,0.165412,-1.028168,"""inter""",0.11768,0.883581,-0.567994,"""inter""",0.451982,0.624603,-1.290522,"""inter""",0.636182,0.801538,0.012748,"""inter""",0.892148,0.568857,0.581169,"""inter""",0.489443,0.636078,0.428402,"""inter""",0.811643,0.426754,-1.049746,"""inter""",0.749234,0.445612,1.112453,"""inter""",2,1,"[1.992445, 2.409771, … 3.056532]","[0.122209, 0.0, … -3.47072]",1
"""firefox""","""windows""",357,"""exp1highReward""","""MAB_Lin""",5.661583,"""training""",97,16,17,326,1.951474,3.095261,1.143787,0,0,202.374257,0.922237,4,0.611632,13,-1.372232,2,0.871147,6,0.255031,19,-0.003087,12,2.084547,5,2.472242,1,1.92802,17,-0.035103,11,…,0.416415,0.165412,0.164106,"""inter""",0.11768,0.883581,-0.115076,"""inter""",0.451982,0.624603,0.766353,"""inter""",0.636182,0.801538,1.021359,"""inter""",0.892148,0.568857,-0.401679,"""inter""",0.489443,0.636078,0.065846,"""inter""",0.811643,0.426754,-0.518802,"""inter""",0.749234,0.445612,-0.193789,"""inter""",2,1,"[0.922237, 0.611632, … 1.75029]","[1.143787, 0.520768, … -3.323706]",4
"""firefox""","""windows""",357,"""exp1highReward""","""MAB_Lin""",5.661583,"""training""",98,16,17,459,1.939091,4.328614,2.389522,0,0,204.313349,1.062252,4,4.226407,13,-0.095945,2,2.836889,6,0.814876,19,-0.351499,12,1.972048,5,1.017589,1,1.375276,17,0.753573,11,…,0.416415,0.165412,-0.685837,"""inter""",0.11768,0.883581,1.08956,"""inter""",0.451982,0.624603,1.32574,"""inter""",0.636182,0.801538,1.533621,"""inter""",0.892148,0.568857,-0.414062,"""inter""",0.489443,0.636078,-0.658163,"""inter""",0.811643,0.426754,0.66962,"""inter""",0.749234,0.445612,0.190015,"""inter""",2,1,"[1.062252, 4.226407, … 2.134094]","[2.389522, 2.287316, … -2.290591]",9
"""firefox""","""windows""",357,"""exp1highReward""","""MAB_Lin""",5.661583,"""training""",99,18,19,680,1.826546,2.978731,1.152185,0,1,206.139895,1.96475,4,1.037819,13,0.579412,2,2.822571,6,1.758046,19,1.460204,12,1.761822,5,2.552483,1,0.823805,17,0.199735,11,…,0.416415,0.165412,-0.088036,"""inter""",0.11768,0.883581,0.738748,"""inter""",0.451982,0.624603,-1.001483,"""inter""",0.636182,0.801538,-0.396266,"""inter""",0.892148,0.568857,0.030941,"""inter""",0.489443,0.636078,0.302163,"""inter""",0.811643,0.426754,-0.223494,"""inter""",0.749234,0.445612,0.343347,"""inter""",2,1,"[1.96475, 1.037819, … 2.287426]","[1.152185, 0.996025, … -1.626811]",9
