# Imports and setup

In [None]:
import json
import os
import re

import numpy as np
import pandas as pd
import plotnine as gg
import statsmodels.formula.api as smf

from Functions import *
# Make plotnine's theme_classic the default theme for every plot below.
# NOTE(review): this passes the theme class rather than an instance
# (gg.theme_classic()); confirm the installed plotnine version accepts a class.
gg.theme_set(gg.theme_classic)

In [None]:
# Folder containing all_data.csv and rule_data.csv. Set the RPP34_DATA_DIR
# environment variable to run the notebook on another machine.
data_dir = os.environ.get('RPP34_DATA_DIR', 'C:/Users/maria/MEGAsync/Berkeley/CHaRLy/data/RPP34')
# Figures go in a 'figures' subfolder of data_dir. Plain string concatenation
# (data_dir + 'figures') had no path separator and produced '.../RPP34figures'.
plot_dir = os.path.join(data_dir, 'figures')
os.makedirs(plot_dir, exist_ok=True)

# Get data

In [None]:
# Trial-level data: the first CSV column becomes the index; subtrial is
# stored as an integer column.
all_data_path = os.path.join(data_dir, 'all_data.csv')
all_data = pd.read_csv(all_data_path, index_col=0).astype({'subtrial': int})
all_data

In [None]:
# Per-subject, per-phase rule table; show the 'low' phase rules of sid 13885.
rule_data = pd.read_csv(os.path.join(data_dir, 'rule_data.csv'), index_col=0)
is_low_13885 = rule_data['phase'].eq('low') & rule_data['sid'].eq(13885)
rule_data[is_low_13885]

In [None]:
rule_data[rule_data['phase'].eq('low') & rule_data['sid'].eq(39508)]  # 'low' phase rules of sid 39508

In [None]:
# Reminder: rule_data rules map ACTION_ID -> MIDDLE_ITEM_NAME and
# MIDDLE_ITEM_NAME -> UNLOCKED_STAR_NAME.
sid_39508_low_learning = (
    (all_data['sid'] == 39508)
    & (all_data['phase'] == 'low')
    & (all_data['trial_type'] == 'learning')
)
inspect_cols = ['sid', 'block', 'trial', 'subtrial', 'action_id', 'middle_item_name', 'unlocked_star_name']
# Rows 228-259 (by position) of that subject's low-phase learning trials.
all_data.loc[sid_39508_low_learning, inspect_cols].iloc[228:260]

# Results

## First discovery of a new star: based on known or unknown two-key sequences?

In [None]:
# Reminder: rule_data rules map ACTION_ID -> MIDDLE_ITEM_NAME and
# MIDDLE_ITEM_NAME -> UNLOCKED_STAR_NAME.
overview_cols = [
    'sid', 'phase', 'trial_type', 'trial', 'subtrial', 'acc',
    'action_id', 'middle_item_name', 'unlocked_star_name',
]
all_data.loc[:, overview_cols]

In [None]:
# One boolean column per id: does this row's middle item (count_item<k>) or
# unlocked star (count_star<k>) equal k? These are turned into running
# counts in the next cell.
indicator_sources = {'count_item': 'middle_item_name', 'count_star': 'unlocked_star_name'}
for prefix, source_col in indicator_sources.items():
    for idx in range(4):
        all_data[f'{prefix}{idx}'] = all_data[source_col] == idx

all_data

In [None]:
id_cols = ['sid', 'phase', 'trial_type', 'block']
count_cols = ['count_item{}'.format(i) for i in range(4)] + ['count_star{}'.format(s) for s in range(4)]

# Turn the boolean indicators into running counts within each
# (sid, phase, trial_type, block) group. groupby().cumsum() accumulates in
# row order, so sort chronologically (stable sort) first instead of relying
# on the CSV row order. join() realigns the result on all_data's index.
chronological = all_data.sort_values(['trial', 'subtrial'], kind='mergesort')
new_columns = chronological[id_cols + count_cols].groupby(id_cols).cumsum()
all_data = all_data.drop(columns=count_cols).join(new_columns)
all_data

In [None]:
def get_sid_phase_star_rules(all_data_row, star_id, rules=None):
    """Return the item ids listed in the rule for ``star_id``.

    Finds the rule-table entry whose ``phase`` and ``sid`` match
    ``all_data_row`` and whose ``goal_id`` equals ``star_id``. Its
    ``highRules`` string is parsed into a list of ints: each comma-separated
    token contributes its first run of digits. So "[0, 1]", "[0,1]" and
    multi-digit ids such as "[10, 11]" all parse.

    Args:
        all_data_row: A row of ``all_data``; only ``'phase'`` and ``'sid'`` are read.
        star_id: Star id to look up (matched against the ``goal_id`` column).
        rules: Optional rule table with ``phase``, ``sid``, ``goal_id`` and
            ``highRules`` columns. Defaults to the notebook-level ``rule_data``.

    Returns:
        list[int]: Parsed item ids, in the order they appear in ``highRules``.

    Raises:
        IndexError: If no rule matches (same exception type as before).
        ValueError: If a ``highRules`` token contains no digits.
    """
    if rules is None:
        rules = rule_data  # notebook global loaded from rule_data.csv

    matches = rules.loc[
        (rules['phase'] == all_data_row['phase'])
        & (rules['sid'] == all_data_row['sid'])
        & (rules['goal_id'] == star_id),
        'highRules',
    ]
    if matches.empty:
        raise IndexError('No rule for sid={}, phase={}, star_id={}'.format(
            all_data_row['sid'], all_data_row['phase'], star_id))
    rule_string = matches.iloc[0]  # first match, as before

    # The old parse took the character at position 1 of each token. It crashed
    # on "[0,1]" ("1]"[1] == "]") and truncated multi-digit ids.
    rule_int = []
    for token in rule_string.split(','):
        digits = re.search(r'\d+', token)
        if digits is None:
            raise ValueError('Cannot parse an item id from {!r} in highRules {!r}'.format(token, rule_string))
        rule_int.append(int(digits.group()))

    return rule_int

# Example use: rules for stars 0 and 1 of the first row. printed, because only
# a cell's last expression is displayed automatically.
print((get_sid_phase_star_rules(all_data.loc[0], star_id=0), get_sid_phase_star_rules(all_data.loc[0], star_id=1)))

# Take an explicit copy: assigning a new column to a slice of all_data
# triggers pandas' SettingWithCopyWarning.
test_dat = all_data.iloc[:30].copy()
test_dat['star_rule'] = test_dat.apply(get_sid_phase_star_rules, star_id=1, axis=1)
test_dat

In [None]:
def get_star_inplace(all_data_row):
    """Check whether both items of a star's rule have already appeared.

    Reads the first two item ids from ``all_data_row['star_rule']``. For each
    id ``k`` it tests ``all_data_row['count_item<k>'] > 0``; in this notebook
    those columns hold counts accumulated within the block.

    Returns:
        The element-wise ``&`` of the two checks (a bool or numpy bool).
    """
    first_item, second_item = all_data_row['star_rule'][0], all_data_row['star_rule'][1]
    first_seen, second_seen = (
        all_data_row[f'count_item{item}'] > 0 for item in (first_item, second_item)
    )
    return first_seen & second_seen

# Example use: evaluate the check on every row of the 30-row sample.
test_dat.apply(get_star_inplace, axis='columns')

In [None]:
# For each star: attach that star's rule to every row ('star_rule' is
# overwritten each pass), then record whether both rule items have been seen.
for star_id in range(4):
    print(f"Star {star_id}")
    all_data['star_rule'] = all_data.apply(get_sid_phase_star_rules, axis=1, star_id=star_id)
    inplace_col = f'star{star_id}_inplace'
    all_data[inplace_col] = all_data.apply(get_star_inplace, axis=1)

all_data

In [None]:
id_cols = ['sid', 'phase', 'trial_type', 'block', 'trial', 'subtrial']
inplace_cols = [f'star{s}_inplace' for s in range(4)]

# Sanity check: the first 30 rows in sorted (chronological) order.
preview_cols = [*id_cols, 'action_id', 'middle_item_name', *count_cols, *inplace_cols]
all_data.sort_values(by=id_cols).loc[:, preview_cols].head(30)

In [None]:
id_cols = ['sid', 'phase', 'trial_type', 'block']

# Running count, per (sid, phase, trial_type, block), of rows on which each
# star's items were already in place. groupby().cumsum() accumulates in row
# order, so sort chronologically (stable sort) first rather than trusting the
# CSV order. The result keeps all_data's index, so a later join realigns it.
new_cols = (
    all_data.sort_values(['trial', 'subtrial'], kind='mergesort')[id_cols + inplace_cols]
    .groupby(id_cols)
    .cumsum()
)
new_cols

In [None]:
# Replace the per-row inplace flags with their running counts (index-aligned).
all_data = all_data.drop(inplace_cols, axis='columns').join(new_cols)
all_data

In [None]:
id_cols = ['sid', 'phase', 'trial_type', 'block', 'trial', 'subtrial']
# Long format: one row per (row, star) with that star's running inplace count.
all_data_m = all_data[id_cols + inplace_cols].melt(
    id_vars=id_cols,
    var_name='inplace_star',
    value_name='trials_inplace',
)
# 'star2_inplace' -> 'star2' -> '2'
all_data_m['inplace_star'] = all_data_m['inplace_star'].str.split('_').str[0].str[-1]
all_data_m

In [None]:
# Summary (points + connecting line) of 'star' against the running inplace
# count, one colour per star, faceted by phase x trial_type.
# NOTE(review): all_data_m is built from id_cols + inplace_cols only, so it
# has no 'star' column and this mapping will likely fail. Confirm which
# outcome variable was meant and merge it into all_data_m first.
g = (gg.ggplot(all_data_m, gg.aes('trials_inplace', 'star', color='inplace_star'))
     + gg.stat_summary()
     + gg.stat_summary(geom='line')
     + gg.facet_grid('phase ~ trial_type')
    )
g

## Time to discover each star

In [None]:
id_cols = ['sid', 'trial_type', 'phase', 'block', 'goal_star']
# A row unlocked a star when unlocked_star holds a real id (> -1). NaN already
# fails the comparison; notna() just makes that explicit, and unlike np.isnan
# it also works on non-float dtypes.
all_data['bool_unlocked_star'] = all_data['unlocked_star'].notna() & (all_data['unlocked_star'] > -1)
# Earliest trial with an unlock per group. Aggregate only 'trial': the old
# aggregate('min') over every column was discarded work and can fail on
# object columns.
first_dat_high = (
    all_data.loc[all_data['bool_unlocked_star']]
    .groupby(id_cols, as_index=False)['trial']
    .min()
)
first_dat_high['n_unique_items'] = get_n_unique_items(first_dat_high)
first_dat_high

In [None]:
subj_dat = first_dat_high

# Trial of first unlock per block, one dodged series per goal star.
gg.options.figure_size = (5, 5)
dodge = gg.position_dodge(width=0.1)
g = (gg.ggplot(subj_dat, gg.aes('block', 'trial', color='factor(goal_star)', group='factor(goal_star)'))
     + gg.stat_summary(position=dodge)
     + gg.stat_summary(position=dodge, geom='line')
     + gg.facet_grid('phase ~ trial_type')
    )
print(g)
# Save into the figures folder created at the top of the notebook. A
# single-argument os.path.join wrote into the current working directory.
g.save(os.path.join(plot_dir, '105_trialtofirstOverBlockForStars.png'))

In [None]:
# Re-colour AND re-group by the number of unique items. Overriding only
# `color` keeps the base plot's group='factor(goal_star)', so the summaries
# and lines would still be split by goal star.
g += gg.aes(color='factor(n_unique_items)', group='factor(n_unique_items)')
print(g)
g.save(os.path.join(plot_dir, '105_trialtofirstOverBlockForNuniqueitems.png'))