# Prep

## Imports & paths

In [None]:
import json
import numpy as np
import os
import pandas as pd
import plotnine as gg
# Figure size that cells changing gg.options.figure_size restore afterwards.
default_figure_size = (6.4, 4.8)
# Use plotnine's classic theme for every plot in this notebook.
gg.theme_set(gg.theme_classic)

In [None]:
# Dataset to analyze; swap the commented line in to use the other dataset.
# data_dir = 'C:/Users/maria/MEGAsync/Berkeley/CHaRLy/data/mTurk1'
data_dir = 'C:/Users/maria/MEGAsync/Berkeley/CHaRLy/data/RPP34'
# Figures live in a 'figures' subfolder of the dataset directory.
# os.path.join inserts the path separator; plain string concatenation
# would produce a sibling directory named '...RPP34figures'.
plot_dir = os.path.join(data_dir, 'figures')
# exist_ok makes this a no-op when the folder already exists.
os.makedirs(plot_dir, exist_ok=True)

## Get all_data

In [None]:
rule_csv_path = os.path.join(data_dir, 'rule_data.csv')
rule_data = pd.read_csv(rule_csv_path, index_col=0)
# Preview the first 30 rows belonging to the 'high' phase.
rule_data[rule_data['phase'] == 'high'][:30]

In [None]:
all_data = pd.read_csv(os.path.join(data_dir, 'all_data.csv'), index_col=0)
# Drop every row where any of the three exclusion flags is set.
excluded = all_data['inattentives'] | all_data['psych_disorder'] | all_data['head_trauma']
all_data = all_data.loc[~excluded]
all_data

# Results

In [None]:
subj_dat = (
    all_data
    .groupby(['sid', 'phase', 'trial_type', 'subtrial'])
    .aggregate('mean')
    .reset_index()
)

# Each subject as dodged points joined by a dotted line, with the
# across-subject stat_summary overlaid in black (points plus a line).
g = gg.ggplot(subj_dat, gg.aes('subtrial', 'z_rt', color='factor(sid)'))
g += gg.geom_point(position=gg.position_dodge(width=0.2))
g += gg.geom_line(linetype='dotted', position=gg.position_dodge(width=0.2))
g += gg.stat_summary(gg.aes(group=1), color='black')
g += gg.stat_summary(gg.aes(group=1), color='black', geom='line')
g += gg.facet_grid('phase ~ trial_type')
print(g)
g.save(os.path.join(plot_dir, '101_zrtOverSubtrialWithIndividuals.png'))

In [None]:
subj_dat = (
    all_data
    .groupby(['sid', 'subtrial', 'trial_type', 'phase'])
    .aggregate('mean')
    .reset_index()
)
# Summary points and line of rt per subtrial, one color per phase,
# one panel per trial type.
g = gg.ggplot(subj_dat, gg.aes('subtrial', 'rt', color='phase'))
for component in (gg.stat_summary(),
                  gg.stat_summary(geom='line'),
                  gg.facet_grid(' ~ trial_type')):
    g += component
g.save(os.path.join(plot_dir, '101_RtOverSubtrial.png'))
print(g)

In [None]:
# Same plot as the previous cell, with the z_rt column on the y axis.
g = g + gg.aes(y='z_rt')
g

In [None]:
subj_dat = (
    all_data
    .groupby(['sid', 'trial_type', 'phase', 'block'])
    .aggregate('mean')
    .reset_index()
)
# Block-by-block course of rt_zz_low, one colored series per phase.
g = gg.ggplot(subj_dat, gg.aes('block', 'rt_zz_low', color='phase'))
g += gg.stat_summary()
g += gg.stat_summary(geom='line')
g += gg.facet_grid(' ~ trial_type')
print(g)
g.save(os.path.join(plot_dir, '101_RtzzlowOverBlock.png'))

In [None]:
# Redraw the previous plot with rt_zz_high on the y axis.
g = g + gg.aes(y='rt_zz_high')
print(g)
g.save(os.path.join(plot_dir, '101_RtzzhighOverBlock.png'))

In [None]:
# Compare late learning blocks (block > 9) with early transfer blocks (block < 2).
is_learning = all_data['trial_type'] == 'learning'
is_transfer = all_data['trial_type'] == 'transfer'
learn_dat = all_data.loc[is_learning & (all_data['block'] > 9)]
trans_dat = all_data.loc[is_transfer & (all_data['block'] < 2)]

group_keys = ['sid', 'trial_type', 'phase']
subj_learn_dat = learn_dat.groupby(group_keys).aggregate('mean').reset_index()
subj_trans_dat = trans_dat.groupby(group_keys).aggregate('mean').reset_index()

# Sanity check: both per-subject summaries must have the same shape.
assert subj_learn_dat.shape == subj_trans_dat.shape

# Pair learning and transfer values per subject and phase, then take the difference.
id_cols = ['sid', 'phase']
subj_dat = subj_learn_dat[id_cols + ['rt_zz_low']].merge(
    subj_trans_dat[id_cols + ['rt_zz_low']],
    on=id_cols,
    suffixes=['_learn', '_trans'],
)
subj_dat['rt_zz_low_trans_minus_learn'] = subj_dat['rt_zz_low_trans'].sub(subj_dat['rt_zz_low_learn'])
subj_dat

In [None]:
# Small bar chart (2x2) of the transfer-minus-learning difference per phase;
# the global figure size is set back to (5, 5) afterwards.
gg.options.figure_size = (2, 2)
g = gg.ggplot(subj_dat, gg.aes('phase', 'rt_zz_low_trans_minus_learn'))
g += gg.stat_summary(geom='bar')
g += gg.stat_summary()
g += gg.labs(x='')
print(g)
g.save(os.path.join(plot_dir, '101_RtzigzagTransMinusLearn.png'))
gg.options.figure_size = (5, 5)

In [None]:
id_cols = ['sid', 'trial_type', 'block', 'phase']
interest_cols = ['rt_zz_low', 'rt_zz_high', 'acc']
dat = all_data.groupby(id_cols).aggregate('mean').reset_index()[id_cols + interest_cols]

# Put low- and high-phase values side by side, matched on sid, trial_type, block.
low_phase = dat.loc[dat['phase'] == 'low'].drop(columns=['phase'])
high_phase = dat.loc[dat['phase'] == 'high'].drop(columns=['phase'])
dat2 = low_phase.merge(
    high_phase,
    on=['sid', 'trial_type', 'block'],
    suffixes=['_low', '_high'],
)
dat2['rt_zzlow_low_minus_high'] = dat2['rt_zz_low_low'].sub(dat2['rt_zz_low_high'])
dat2['rt_zzhigh_low_minus_high'] = dat2['rt_zz_high_low'].sub(dat2['rt_zz_high_high'])
dat2

In [None]:
# Low-minus-high difference in rt_zz_low across blocks; the dotted line marks zero.
g = gg.ggplot(dat2, gg.aes('block', 'rt_zzlow_low_minus_high', color='trial_type'))
g += gg.stat_summary()
g += gg.stat_summary(geom='line')
g += gg.geom_hline(yintercept=0, linetype='dotted')
g

In [None]:
# Same plot for the rt_zz_high difference.
g = g + gg.aes(y='rt_zzhigh_low_minus_high')
g

In [None]:
# Per-subject low-minus-high phase difference in rt_zz_low, using the mean
# over all of that subject's rows, separately for each trial type.
id_cols = ['sid', 'trial_type', 'phase']
sum_dat = all_data.groupby(id_cols).aggregate('mean').reset_index()[id_cols + ['rt_zz_low']]
subj_dat = pd.merge(
    sum_dat.loc[sum_dat['phase'] == 'low'].drop(columns=['phase']),
    sum_dat.loc[sum_dat['phase'] == 'high'].drop(columns=['phase']),
    on=id_cols[:-1],  # ['sid', 'trial_type']
    suffixes=['_low', '_high']
)
# Use subj_dat's own columns. The block-level `dat2` table has different
# rows, and pandas index alignment would silently pair unrelated rows.
subj_dat['rt_zz_low_low_minus_high'] = subj_dat['rt_zz_low_low'] - subj_dat['rt_zz_low_high']
subj_dat

In [None]:
# Small (2x2) bar chart of the per-subject low-minus-high difference;
# the global figure size is set back to (5, 5) afterwards.
gg.options.figure_size = (2, 2)
g = gg.ggplot(subj_dat, gg.aes('trial_type', 'rt_zz_low_low_minus_high'))
for part in (gg.stat_summary(geom='bar'), gg.stat_summary(), gg.labs(x='')):
    g += part
print(g)
g.save(os.path.join(plot_dir, '101_RtzigzagLowMinusHigh.png'))
gg.options.figure_size = (5, 5)

In [None]:
subj_dat = (
    all_data
    .groupby(['sid', 'subtrial', 'trial_type', 'phase', 'star_iteration'])
    .aggregate('mean')
    .reset_index()
)
# rt_zz_low as a function of star_iteration, one color per phase.
g = gg.ggplot(subj_dat, gg.aes('star_iteration', 'rt_zz_low', color='phase'))
g += gg.stat_summary()
g += gg.stat_summary(geom='line')
g += gg.facet_grid(' ~ trial_type')
g

In [None]:
# Same plot with rt_zz_high on the y axis.
g = g + gg.aes(y='rt_zz_high')
g

In [None]:
subj_dat = (
    all_data
    .groupby(['sid', 'subtrial', 'trial_type', 'phase', 'block', 'goal_star'])
    .aggregate('mean')
    .reset_index()
)
gg.options.figure_size = (7, 4)
# rt_zz_low against accuracy, colored by subject and shaped by goal star;
# the legend is hidden.
g = gg.ggplot(subj_dat, gg.aes('acc', 'rt_zz_low', color='factor(sid)', shape='factor(goal_star)'))
g += gg.geom_point()
g += gg.geom_line(alpha=0.5)
# (A pooled trend could be overlaid with gg.geom_smooth(gg.aes(group=1)).)
g += gg.theme(legend_position='none')
g += gg.facet_grid('phase ~ trial_type + goal_star')
g.save(os.path.join(plot_dir, '101_RtzigzagOverAcc.png'))
print(g)

# When do 2-key sequences get activated?

In [None]:
# Flag whether a middle item / star was recorded (-1 is treated as "none").
all_data['middle_item_bool'] = all_data['middle_item'] != -1
all_data['star_bool'] = all_data['unlocked_star'] != -1

id_cols = ['sid', 'phase', 'trial_type', 'block', 'trial']
interest_cols = ['middle_item_bool', 'rt_zz_low', 'rt_zz_high', 'rt', 'log_rt']
# Sum subtrials 1 and 3 within each trial; the summed flag counts how many
# middle items were obtained in that trial.
subtrials_1_3 = all_data.loc[all_data['subtrial'].isin([1, 3])]
sum_dat = subtrials_1_3.groupby(id_cols).aggregate('sum').reset_index()[id_cols + interest_cols]
sum_dat

In [None]:
subj_dat = (
    sum_dat
    .groupby(['sid', 'phase', 'trial_type', 'middle_item_bool'])
    .aggregate('mean')
    .reset_index()
)

gg.options.figure_size = (4, 3.5)
# Summary of rt_zz_low per number of middle items obtained (0, 1, 2), joined
# by a single line (group=1). Per-subject dodged points/lines were tried here
# before and are currently left out.
g = gg.ggplot(subj_dat, gg.aes('factor(middle_item_bool)', 'rt_zz_low', group=1))
g += gg.stat_summary()
g += gg.stat_summary(geom='line')
g += gg.labs(x='N middle items obtained')
g += gg.facet_grid('phase ~ trial_type')
print(g)
g.save(os.path.join(plot_dir, '101_RtzigzagOverItemSuccess.png'))

In [None]:
# Same plot with rt_zz_high on the y axis.
g = g + gg.aes(y='rt_zz_high')
g

In [None]:
# Same plot with raw rt on the y axis; afterwards the global size goes back to (5, 5).
g = g + gg.aes(y='rt')
print(g)
g.save(os.path.join(plot_dir, '101_RtsOverItemSuccess.png'))
gg.options.figure_size = (5, 5)

## Effect of intermediate item on RTs

In [None]:
# A row "has a middle item" when middle_item is present (not NaN) and is not
# the -1 placeholder. Combining with `&` guarantees a boolean mask; `*` on
# booleans relies on dtype promotion, and pandas recommends `&` instead.
# notna() also works on non-float dtypes, where np.isnan would raise.
all_data['bool_middle_item'] = all_data['middle_item'].notna() & (all_data['middle_item'] != -1)

# Trial identifiers of subtrial-1 rows, split by whether a middle item was obtained.
got_middle_sub_dat = all_data.loc[all_data['bool_middle_item'] & (all_data['subtrial'] == 1)]
got_middle_idxs = got_middle_sub_dat[['sid', 'phase', 'trial_type', 'block', 'trial']]

no_middle_sub_dat = all_data.loc[~all_data['bool_middle_item'] & (all_data['subtrial'] == 1)]
no_middle_idxs = no_middle_sub_dat[['sid', 'phase', 'trial_type', 'block', 'trial']]

In [None]:
all_data.loc[:, ['sid', 'trial', 'subtrial', 'middle_item', 'bool_middle_item']]

In [None]:
# Left-merge the trial identifiers back onto all_data (on their shared
# columns) to recover every subtrial row of those trials.
got_middle = got_middle_idxs.merge(all_data, how='left')
no_middle = no_middle_idxs.merge(all_data, how='left')

# Sanity check: on average (rounded) each trial expands to 4 rows.
for expanded, idxs in ((got_middle, got_middle_idxs), (no_middle, no_middle_idxs)):
    assert np.round(expanded.shape[0] / idxs.shape[0]) == 4

got_middle['got_middle'] = True
no_middle['got_middle'] = False

In [None]:
# Stack both groups into one long table.
mid_data = pd.concat((got_middle, no_middle))
mid_data

In [None]:
subtrial2_rows = mid_data.loc[mid_data['subtrial'] == 2]
subj_dat = (
    subtrial2_rows
    .groupby(['sid', 'trial_type', 'phase', 'trial', 'got_middle'])
    .aggregate('mean')
    .reset_index()
)

# RT on subtrial 2 across trials, split by whether the middle item was obtained.
g = gg.ggplot(subj_dat, gg.aes('trial', 'rt', color='got_middle'))
g += gg.stat_summary()
g += gg.stat_summary(geom='line')
g += gg.labs(y='RT on item 2 (after potential item)')
g += gg.facet_grid('phase ~ trial_type')
print(g)
g.save(os.path.join(plot_dir, '101_item3RtsOverTrial.png'))

In [None]:
subj_dat = (
    mid_data[mid_data['subtrial'] == 2]
    .groupby(['sid', 'trial_type', 'phase', 'block', 'got_middle'])
    .aggregate('mean')
    .reset_index()
)

# Reuse the previous plot with block on the x axis and block-level data.
g = g + gg.aes(x='block')
g.data = subj_dat
print(g)
g.save(os.path.join(plot_dir, '101_item3RtsOverBlock.png'))

## RT zigzag 4-item stars vs 3-item stars

In [None]:
# Goal stars 0 and 1 are flagged as the 4-item stars.
all_data['4_item_star'] = all_data['goal_star'].isin([0, 1])
subj_dat = (
    all_data
    .groupby(['sid', 'goal_star', 'phase', 'trial_type'])
    .aggregate('mean')
    .reset_index()
)

g = gg.ggplot(subj_dat, gg.aes('factor(goal_star)', 'rt_zz_low', color='4_item_star'))
g += gg.stat_summary()
g += gg.facet_grid('phase ~ trial_type')
print(g)
g.save(os.path.join(plot_dir, '101_rtzzlowOverStar.png'))

In [None]:
all_data.loc[:, ['sid', 'phase', 'trial_type', 'middle_item_bool']]

In [None]:
all_data[['middle_item_lowRules']]

In [None]:
all_data.loc[all_data['subtrial'].eq(3) & all_data['middle_item'].gt(-1)]

In [None]:
subj_dat.loc[:, ['subtrial', 'middle_item']]

In [None]:
# Subtrials 1 and 3 only. Take an explicit copy so that adding a column below
# modifies an independent frame rather than a slice of all_data (avoids
# pandas' SettingWithCopyWarning / chained-assignment ambiguity).
subj_dat = all_data.loc[all_data['subtrial'].isin([1, 3])].copy()
# NOTE: unlike all_data['bool_middle_item'], this flag is True for NaN middle_item.
subj_dat['bool_middle_item'] = subj_dat['middle_item'] != -1

g = (gg.ggplot(subj_dat, gg.aes('bool_middle_item', 'rt_zz_low', color='phase'))
     + gg.stat_summary(position=gg.position_dodge(width=0.5))
     + gg.facet_grid('subtrial ~ trial_type')
    )
print(g)
g.save(os.path.join(plot_dir, '101_RtzigzagOverItemAchieved.png'))

## RT pattern and behavior at the single-block level

In [None]:
# NOTE(review): `rt_wide_block` (with its 'rt_zigzag' / 'chance_performer'
# columns) is not created anywhere in this notebook as shown, so this cell
# raises NameError unless it is defined elsewhere. It is presumably a
# leftover from an earlier version; confirm.
# Large (10x7) per-block plot: point size encodes acc. The default figure
# size is restored at the end.
gg.options.figure_size = (10, 7)
g = gg.ggplot(rt_wide_block, gg.aes('block', 'rt_zigzag',
                                    color='factor(sid)',
                                    shape='factor(goal_star)',
                                    linetype='chance_performer'))
g += gg.geom_point(gg.aes(size='acc'), position=gg.position_dodge(width=0.5))
g += gg.geom_line(gg.aes(group='factor(sid)'), position=gg.position_dodge(width=0.5))
g += gg.facet_grid('phase ~ trial_type')
g.save(os.path.join(plot_dir, '101_RtzigzagOverBlocks.png'))
print(g)
gg.options.figure_size = default_figure_size

In [None]:
import copy

# NOTE(review): `rt_wide_block` is not created anywhere in this notebook as
# shown. This cell raises NameError unless it is defined elsewhere; confirm.
g = (gg.ggplot(rt_wide_block, gg.aes('acc', 'rt_zigzag', color='factor(sid)', shape='factor(goal_star)'))
     + gg.geom_point()
     + gg.geom_line(gg.aes(group='factor(sid)'), alpha=0.2)
     # Map group=1 as an aesthetic, as the stat_summary layers earlier do.
     # A plain `group=1` parameter does not change how the stat groups rows,
     # so the smooth would still be fit per subject instead of pooled.
     + gg.geom_smooth(gg.aes(group=1), color='black')
     + gg.facet_grid('phase ~ trial_type')
    )
print(g)
g.save(os.path.join(plot_dir, '01_RtzigzagAcc_0.png'))

# Copy before swapping the data. `g2 = g` would alias the same plot object,
# so setting g2.data would also replace g's data, which later cells reuse.
g2 = copy.deepcopy(g)
g2.data = rt_wide_block.loc[rt_wide_block.acc > 0]
print(g2)
g2.save(os.path.join(plot_dir, '01_RtzigzagAcc.png'))

In [None]:
# The RT zigzag supports learning the middle-layer items: no zigzag = no
# middle-layer items; lots of zigzag = intermediate performance (learning);
# no zigzag = perfect performance (no discrimination).
x_mapping = gg.aes(x='bool_middle_item_lowRulesLearnOnly')
gll = g + x_mapping
out_path = os.path.join(plot_dir, '01_RtzigzagLoWRulesLearn.png')
gll.save(out_path)
gll

In [None]:
# Interpretation of this plot is still open (originally marked "???").
x_mapping = gg.aes(x='bool_middle_item_lowRulesTransferOnly')
glt = g + x_mapping
out_path = os.path.join(plot_dir, '01_RtzigzagLoWRulesTransfer.png')
glt.save(out_path)
glt

In [None]:
# Good star performance comes AFTER the middle-layer items are learned,
# i.e. once the RT zigzag is gone completely.
x_mapping = gg.aes(x='bool_unlocked_star_highRulesLearnOnly')
ghl = g + x_mapping
out_path = os.path.join(plot_dir, '01_RtzigzagHighRulesLearn.png')
ghl.save(out_path)
ghl

In [None]:
# High-level transfer does NOT require relearning the middle-layer sequences:
# the RT zigzag appears early, while different middle-layer sequences are
# tried out, and dies out as they are learned.
x_mapping = gg.aes(x='bool_unlocked_star_highRulesTransferOnly')
ght = g + x_mapping
out_path = os.path.join(plot_dir, '01_RtzigzagHighRulesTransfer.png')
ght.save(out_path)
ght