# Imports etc.

In [None]:
import json
import numpy as np
import os
import pandas as pd
import plotnine as gg
# Make plotnine's classic theme the default for the plots in this notebook.
gg.theme_set(gg.theme_classic)
# Figure size that cells restore after temporarily changing
# `gg.options.figure_size` for a single plot.
default_figure_size = (6.4, 4.8)

In [None]:
# Root folder of the dataset to analyse; swap the active line to switch datasets.
# data_dir = 'C:/Users/maria/MEGAsync/Berkeley/CHaRLy/data/mTurk1'
data_dir = 'C:/Users/maria/MEGAsync/Berkeley/CHaRLy/data/RPP34'
# Figures live in a 'figures' sub-folder of the data directory. Joining the
# path components (instead of concatenating strings) keeps the separator, so
# the folder is '<data_dir>/figures' rather than a sibling '<data_dir>figures'.
plot_dir = os.path.join(data_dir, 'figures')
# exist_ok avoids the separate existence check and is safe to re-run.
os.makedirs(plot_dir, exist_ok=True)

## Get all_data

In [None]:
# Load the rule table (the first CSV column becomes the row index) and
# preview the first 30 rows whose phase is 'high'.
rule_path = os.path.join(data_dir, 'rule_data.csv')
rule_data = pd.read_csv(rule_path, index_col=0)
rule_data[rule_data['phase'] == 'high'].head(30)

In [None]:
# Load the trial-level table (the first CSV column becomes the row index) and
# keep only the rows that are not flagged in the 'inattentives' column.
all_data_path = os.path.join(data_dir, 'all_data.csv')
all_data = pd.read_csv(all_data_path, index_col=0)
is_inattentive = all_data['inattentives']
all_data = all_data.loc[~is_inattentive]
all_data

# Results

## Data overview

In [None]:
# Columns shown in the overview tables below.
# Currently left out on purpose: 'sid', 'key_press',
# 'middle_item_lowTransferRules', 'middle_item_lowRules', 'bool_middle_item_lowRules',
# 'unlocked_star_highTransferRules', 'unlocked_star_highRules', 'bool_unlocked_star_highRules'.
interesting_cols = [
    'trial_type',
    'phase',
    'block',
    'trial',
    'trial_',
    'subtrial',
    'points',
    'rt',
    'action_id',
    'action_name',
    'middle_item',
    'middle_item_name',
    'goal_star',
    'goal_star_name',
    'correct',
    'unlocked_star',
    'unlocked_star_name',
    'timeout',
    'star_iteration',
    'chance_performer',
]

In [None]:
id_cols = ['sid', 'trial_type', 'phase']
# `interesting_cols` already contains 'trial_type' and 'phase'; de-duplicate
# (order-preserving) so the selection has no repeated column labels.
overview_cols = list(dict.fromkeys(interesting_cols + id_cols))
# Mean of every numeric column per participant ('sid'). numeric_only=True
# skips the string columns, which would otherwise make mean() raise on
# pandas >= 2.0.
all_data[overview_cols].groupby(id_cols[0]).mean(numeric_only=True)

In [None]:
# Summary statistics for the overview columns.
all_data.loc[:, interesting_cols].describe()

In [None]:
# Preview the first 30 rows. Rows were filtered out of `all_data` above, so a
# label-based `.loc[:30]` depends on which index labels survived; a positional
# head() always shows the top of the table.
all_data[interesting_cols].head(30)

## Task Duration

In [None]:
# One row per (sid, chance_performer) pair holding the mean of each numeric
# column. numeric_only=True keeps the aggregation from raising on string
# columns under pandas >= 2.0.
dur_dat = (
    all_data
    .groupby(['sid', 'chance_performer'])
    .mean(numeric_only=True)
    .reset_index()  # group keys back into regular columns
    .reset_index()  # adds a running 'index' column, used as the x position below
)
durations = dur_dat['duration']
mean_duration = np.mean(durations)
print("Number of participants: {}".format(dur_dat.shape[0]))
print("Mean duration: {} minutes (min: {}; max: {}; sd: {})".format(
    mean_duration.round(), np.min(durations).round(), np.max(durations).round(), np.std(durations).round(1))
     )

# One bar per row of dur_dat; the dotted line marks the mean duration and the
# bar outline colour distinguishes the two `chance_performer` values.
g = (gg.ggplot(dur_dat, gg.aes('index', 'duration', fill='factor(sid)', color='chance_performer'))
     + gg.geom_hline(yintercept=mean_duration, linetype='dotted')
     + gg.scale_color_manual(values=('white', 'red'))
     + gg.geom_bar(stat='identity')
#      + gg.theme(legend_position='none')
    )
g.save(os.path.join(plot_dir, '0_TaskDuration_all.png'))
print(g)

## Raw button presses

In [None]:
# Get data: 'high'-phase rows with trial_ below 40.
# (Optional extra filter, currently disabled: drop rows where key_press is NaN.)
# .copy() makes sub_dat an independent frame, so adding a column below does
# not write into a slice of all_data (SettingWithCopyWarning).
sub_dat = all_data.loc[
    (all_data.phase == 'high') &
    (all_data.trial_ < 40)
].copy()
# Marker shape: 0 when 'acc' is missing or zero, 1 otherwise.
sub_dat['shape'] = sub_dat['acc'].fillna(0).ne(0).astype(int)

# Plot with a temporarily enlarged figure; the default size is restored even
# if saving or drawing fails.
gg.options.figure_size = (20, 10)
try:
    g = (gg.ggplot(sub_dat, gg.aes('subtrial', 'trial_', color='factor(key_press)', shape='factor(shape)'))
         + gg.geom_point()
         + gg.facet_grid('trial_type ~ sid', scales='free_x')
        )
    g.save(os.path.join(plot_dir, '0_RawKeyPresses.png'))
    print(g)
finally:
    gg.options.figure_size = default_figure_size

## Points won over time

In [None]:
def plot_PointsOverTrials(dat, suf=''):
    """Plot points against trial number, one line per participant, and save it.

    Parameters
    ----------
    dat : pandas.DataFrame
        Must provide the columns 'trial_', 'points', 'sid',
        'chance_performer', 'phase', 'phaseNum' and 'trial_type'.
    suf : str, optional
        Suffix inserted before the '.png' extension of the saved file.

    Returns
    -------
    The plotnine plot object. A copy is written to
    ``plot_dir`` as '0_PointsOverTrials<suf>.png'.
    """
    # The figure size is a global plotnine option: widen it for this plot only
    # and always restore the default, even when building or saving fails.
    gg.options.figure_size = (8, 4)
    try:
        g = (gg.ggplot(dat, gg.aes('trial_', 'points', color='factor(sid)', linetype='chance_performer'))
             + gg.geom_line()
             + gg.facet_grid('phase ~ phaseNum + trial_type')
            )
        g.save(os.path.join(plot_dir, '0_PointsOverTrials{}.png'.format(suf)))
    finally:
        gg.options.figure_size = default_figure_size

    return g

# Draw and save the plot for all participants.
# Alternative input: plot_PointsOverTrials(incl_data)
plot_PointsOverTrials(all_data, suf='_all')

## Performance for each star

In [None]:
def plot_PerformanceByStar(dat, suf=''):
    """Plot mean accuracy per goal star and save the figure.

    Accuracy ('acc') is first averaged within each combination of 'sid',
    'phase', 'phaseNum', 'trial_type' and 'goal_star'; the plot then shows a
    summary bar with error range per goal star, the individual per-participant
    means as jittered points, and a dotted chance-level line.

    Parameters
    ----------
    dat : pandas.DataFrame
        Must provide the columns 'sid', 'phase', 'phaseNum', 'trial_type',
        'goal_star' and 'acc'.
    suf : str, optional
        Suffix inserted before the '.png' extension of the saved file.

    Returns
    -------
    The plotnine plot object. A copy is written to
    ``plot_dir`` as '3_PerformanceByStar<suf>.png'.
    """
    id_cols = ['sid', 'phase', 'phaseNum', 'trial_type', 'goal_star']
    # Aggregate only 'acc': averaging every column would raise on the string
    # columns under pandas >= 2.0, and nothing else is used below.
    sum_dat = dat.groupby(id_cols)['acc'].mean().reset_index()

    # 1/24 -- presumably the probability of guessing an ordered sequence of
    # choices out of 4, 3 and 2 remaining options; TODO confirm with task design.
    chance_perf = 1 / (4 * 3 * 2)
    g = (gg.ggplot(sum_dat, gg.aes('goal_star', 'acc'))
         + gg.stat_summary(geom='bar')
         + gg.stat_summary()
         + gg.geom_hline(yintercept=chance_perf, linetype='dotted')
         + gg.geom_point(gg.aes(color='factor(sid)'), position='jitter')
         + gg.facet_grid('phase + phaseNum ~ trial_type')
        )
    g.save(os.path.join(plot_dir, '3_PerformanceByStar{}.png'.format(suf)))

    return g

# Draw and save the plot for all participants.
plot_PerformanceByStar(all_data, suf='_all')