In [1]:
import numpy as np
import pandas as pd
import altair as alt
from bokeh.palettes import Category20c, Category20, Category20b, Pastel2, Purples
from sklearn.feature_extraction.text import CountVectorizer

## helper 

In [2]:
def bootstrap_ci(array, ci=95, n_iterations = 10000, seed=12):
    """Percentile bootstrap confidence interval of the mean.

    Resamples ``array`` with replacement ``n_iterations`` times, takes the mean
    of every resample and returns the ``(100-ci)/2`` and ``(100+ci)/2``
    percentiles of those means.

    Parameters
    ----------
    array : 1-D array-like
        Sample values. NaNs are not dropped.
    ci : float, default 95
        Confidence level in percent.
    n_iterations : int, default 10000
        Number of bootstrap resamples.
    seed : int, default 12
        Seed of a private ``np.random.RandomState``. The draws are identical to
        the former ``np.random.seed(12)`` + ``np.random.choice`` version. The
        global NumPy RNG is no longer reseeded as a side effect of every call,
        e.g. from inside ``groupby().agg``.

    Returns
    -------
    (lower, upper) : tuple of float
    """
    rng = np.random.RandomState(seed)
    samples = rng.choice(array, (n_iterations, len(array)), replace=True)
    stats = samples.mean(axis=1)
    lower, upper = np.percentile(stats, [(100 - ci) / 2, (100 + ci) / 2])
    return lower, upper

def bootstrap_ci_lower(array):
    """Lower bound of the default (95%) bootstrap CI of the mean; usable in ``.agg``."""
    lower, _ = bootstrap_ci(array)
    return lower

def bootstrap_ci_upper(array):
    """Upper bound of the default (95%) bootstrap CI of the mean; usable in ``.agg``."""
    _, upper = bootstrap_ci(array)
    return upper

def num2seg(num):
    """Map an event number to its segment number.

    The hundreds give the segment. A tens digit of 5-9 marks the gap half
    segment that follows it, e.g. 953 -> 9.5 and 920 -> 9.

    ``num`` may be anything ``int()`` accepts (int or a numeric string such as
    '953'). Returns an int for main-segment events and a float (x.5) for gap
    events.
    """
    value = int(num)
    segment = value // 100
    tens_digit = value // 10 % 10  # always 0..9 (Python floor modulo)
    return segment + 0.5 if tens_digit >= 5 else segment

def set_color_old(df, step_col):
    """Condition-based variant of ``set_color`` (older version).

    Writes ``df['color']`` in place from the global ``cmap``. The palette is
    chosen by the row's condition rather than by the sign of ``step_col``:
      * r0, r, r_p, r_recall: half step -> 'r_gap', whole step -> 'r_main',
        step -1 -> 'r_3'
      * p0, p, p_r, p_recall: half step -> 'p_gap', whole step -> 'p_main',
        step 1 -> 'p_3'
    """
    step = df[step_col]
    palettes = (
        (['r0', 'r', 'r_p', 'r_recall'], -1, 'r_gap', 'r_main', 'r_3'),
        (['p0', 'p', 'p_r', 'p_recall'], 1, 'p_gap', 'p_main', 'p_3'),
    )
    for conditions, anchor, gap_key, main_key, anchor_key in palettes:
        in_group = df['condition'].isin(conditions)
        df.loc[in_group & (step % 1 == 0.5), 'color'] = cmap[gap_key]
        df.loc[in_group & (step % 1 == 0), 'color'] = cmap[main_key]
        # The anchor step is also whole, so its colour must be written last.
        df.loc[in_group & (step == anchor), 'color'] = cmap[anchor_key]

def set_color(df, step_col):
    """Colour rows by the sign and half/whole value of a lag column.

    Writes ``df['color']`` in place (creating the column if needed) from the
    global ``cmap``:
      * negative half lag -> 'r_gap', negative whole lag -> 'r_main', lag -1 -> 'r_3'
      * positive half lag -> 'p_gap', positive whole lag -> 'p_main', lag 1 -> 'p_3'
      * lag 0 -> 'recall'
    Rows matching no rule (e.g. NaN lags) are left untouched.
    Later rules override earlier ones, which is why -1 / 1 come after the
    whole-lag rules.
    """
    step = df[step_col]
    is_half = step % 1 == 0.5
    is_whole = step % 1 == 0
    rules = [
        (is_half & (step < 0), 'r_gap'),
        (is_whole & (step < 0), 'r_main'),
        (step == -1, 'r_3'),
        (is_half & (step > 0), 'p_gap'),
        (is_whole & (step > 0), 'p_main'),
        (step == 1, 'p_3'),
        (step == 0, 'recall'),
    ]
    for mask, key in rules:
        df.loc[mask, 'color'] = cmap[key]

    

## constants

In [6]:
# Condition sets; each later set extends an earlier one
# (r = retrodiction, p = prediction, suffix 0 = uncued).
cond_pr = ['r0', 'p0', 'r', 'p']
cond_3 = cond_pr + ['r_recall', 'p_recall']
cond_pr_recall = cond_pr + ['r_p', 'p_r', 'r_recall', 'p_recall']
cond_full = cond_pr_recall + ['r_p_recall', 'p_r_recall', 'recall_f', 'recall_b']
# Event categories counted per response.
e_cond = ['e_target', 'e_gap', 'e_ahead', 'e_unmatch']
e_cond_bin = ['e_match', 'e_unmatch']

# Older colour scheme built from bokeh palettes (set_color / set_color_old use `cmap`).
cmap_old = dict(
    p_1=Category20[8][1], r_1=Category20[8][3],
    p_2=Category20[8][0], r_2=Category20[8][2],
    p_3=Category20b[8][1], r_3=Category20[8][6],          # recall
    p_main=Category20c[8][1], r_main=Category20c[8][5],
    p_gap=Category20c[8][3], r_gap=Category20c[8][7],
)

# Tableau "Color Blind" palette: p_* keys for prediction, r_* keys for
# retrodiction, 'recall' for lag 0 in set_color.
cmap = dict(
    p_1='#a3cce9', r_1='#ffbc79',
    p_2='#5fa2ce', r_2='#fc7d0b',
    p_3='#1170aa', r_3='#c85200',          # recall
    p_main='#5fa2ce', r_main='#fc7d0b',    # same hues as p_2 / r_2
    p_gap='#c7e0f1', r_gap='#ffd6ae',      # a3cce9 ffbc79
    recall='#59a14f',
)

# Event-category colours. The key order of e_cond_cmap must match
# e_cond_long_map: the stacked-bar colour scale pairs their .values() by position.
e_cond_cmap_old = dict(
    e_unmatch=Category20c[20][18],
    e_ahead=Category20b[20][5],
    e_gap=Category20b[20][6],
    e_target=Category20b[20][4],
)

e_cond_cmap = dict(
    e_unmatch=Category20c[20][18],
    e_ahead='#C69C6D',
    e_target='#8C6238',
    e_gap='#C7B39A',
)

# Legend labels per event category (key order matches e_cond_cmap).
e_cond_long_map = dict(
    e_unmatch='Unmatched events',
    e_ahead='Events with lag < -1 or lag > 1',
    e_target='Target events with lag = -1 or 1',
    e_gap='Events with lag = -0.5 or 0.5',
)

# Short condition labels (axes, facet headers).
cond_short_map = {'r0': 'u-R', 'p0': 'u-P', 'r': 'c-R', 'p': 'c-P', 'r_p': 'c-RP', 'p_r': 'c-PR',
                  'r_recall': 're(R)', 'p_recall': 're(P)'}

# Long condition labels (legends).
cond_long_map = {'r0': 'u-R: uncued retrodiction', 'p0': 'u-P: uncued prediction',
                 'r': 'c-R: character-cued retrodiction', 'p': 'c-P: character-cued prediction',
                 'r_p': 'c-RP: updated retrodiction (after watching one segment earlier)',
                 'p_r': 'c-PR: updated prediction (after watching one segment later)',
                 'r_recall': 're(R): retrodiction-matched recall',
                 'p_recall': 're(P): prediction-matched recall'}  # typo fix: was 'prediction-mached'

# Fill colour per condition.
cond_color_map = {'r0': cmap['r_1'], 'p0': cmap['p_1'],
                  'r': cmap['r_2'], 'p': cmap['p_2'],
                  'r_p': cmap['r_2'], 'p_r': cmap['p_2'],
                  'r_recall': cmap['r_3'], 'p_recall': cmap['p_3']}

# Stroke colour per condition: same as the fill, except that the updated
# conditions r_p / p_r are outlined in the colour of the opposite direction.
# (Overriding existing keys keeps the original key order.)
color_stroke_color_map = {**cond_color_map,
                          'r_p': cond_color_map['p'], 'p_r': cond_color_map['r']}

# Restrictions of the maps above to cond_3. Deriving them keeps labels and
# colours in sync and in cond_3 order. The Altair scales build domain and range
# from .values(), so that order matters.
cond_long_3_map = {cond: cond_long_map[cond] for cond in cond_3}
cond_color_3_map = {cond: cond_color_map[cond] for cond in cond_3}
color_stroke_color_3_map = {cond: color_stroke_color_map[cond] for cond in cond_3}

cond_pr_short = [cond_short_map[cond] for cond in cond_pr]
cond_3_short = [cond_short_map[cond] for cond in cond_3]
cond_pr_recall_short = [cond_short_map[cond] for cond in cond_pr_recall]

## annotation file

In [36]:
# Annotation sheet: rows keyed by event number `num`, with story, segment and
# half-point rule columns.
df_annot = pd.read_excel('../data/main/WhyWomenKill.xlsx', sheet_name='annotation')
# Empty rule cells become '' (not NaN) so the rules can be split into event names later.
for rule_col in ('half_point_exclude', 'half_point_add'):
    df_annot[rule_col] = df_annot[rule_col].fillna('').astype(str)
df_annot.shape

(217, 14)

In [37]:
# df_annot['main_or_gap'] = np.where(df_annot['segment']%1 == 0, 'main', 'gap')
# df_annot['full_or_half'] = np.where((df_annot['num']//10%10).isin([0,5]), 'full', 'half')
# Frequency of each (partial, on_off) annotation combination.
df_annot.value_counts(subset=['partial', 'on_off'])

partial  on_off 
ori      on         117
         off         65
sum      on          16
ori      off_add      9
partial  on           7
ori      on_add       1
partial  off          1
sum      off          1
dtype: int64

## data files

In [5]:
# Load the experiment data, indexed by the 'no' column.
df = pd.read_csv(filepath_or_buffer='../data/main/exp_embed_use_large.csv', index_col='no')

## add events columns
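Event columns range from `'-1'` to `'1201'`; `'-1'` is the placeholder token assigned to rows without scenes (recall data).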

In [15]:
# Bag-of-events matrix: one column per space-separated event token in `scenes`,
# counting its occurrences in each row.
vectorizer = CountVectorizer(binary=False, min_df=1, token_pattern='[^ ]+')
counts = vectorizer.fit_transform(df['scenes'].fillna('-1'))  # assign "-1" for recall data
# CountVectorizer.get_feature_names() was removed in scikit-learn 1.2; prefer
# get_feature_names_out() and fall back for older versions.
if hasattr(vectorizer, 'get_feature_names_out'):
    feature_names = vectorizer.get_feature_names_out()
else:
    feature_names = vectorizer.get_feature_names()
df_features = pd.DataFrame(counts.toarray(), columns=feature_names, index=df.index)

# Annotated events that no row hit still get an all-zero column.
cols_addi = [str(num) for num in df_annot['num'] if str(num) not in df_features.columns]
for col in cols_addi:
    df_features[col] = 0

# Order columns numerically ('-1', ..., '51', ..., '1201') instead of lexically.
cols_unsorted = df_features.columns.to_list()
cols_int = sorted(int(float(e)) for e in cols_unsorted)
cols_sorted = [str(e) for e in cols_int]
df_features = df_features[cols_sorted]

# Correct multiple matches: an event from '51' onwards counts at most once per row.
# The number of cells > 1 is printed before and after as a sanity check.
print(df_features[df_features > 1].count().sum())
df_features.loc[:, '51':] = df_features.loc[:, '51':].clip(upper=1)
print(df_features[df_features > 1].count().sum())

# NOTE: not idempotent -- re-running this cell appends the event columns a second time.
df = pd.concat([df, df_features], axis=1)



In [16]:
# correct half points (1 to 0.5)

def get_half_elems(story, mode='add', annot=None):
    """Return a story's half-point events and the rule events attached to each.

    Half-point events are annotation numbers whose tens digit is 3, 4, 8 or 9
    (e.g. 930 or 985). They are scored 0.5 instead of 1.

    Parameters
    ----------
    story : int
        Story id matched against the annotation ``story`` column.
    mode : {'add', 'exclude'}
        Selects the rule column ``half_point_{mode}``.
    annot : pandas.DataFrame, optional
        Annotation table with ``num``, ``story`` and ``half_point_{mode}``
        columns (rule cells already converted to str). Defaults to the global
        ``df_annot``.

    Returns
    -------
    half_elems_str : list of str
        Half-point event numbers as column names, in table order.
    rule_elems : list of list of str
        Space-split rule cell of each half-point event (``['']`` when empty),
        aligned with ``half_elems_str``.
    """
    if annot is None:
        annot = df_annot
    # Column-wise masks keep numbers and rules aligned by label. The former
    # `df_annot.loc[i, 'story']` with `i` from enumerate() assumed a RangeIndex.
    is_half = (annot['num'] // 10 % 10).isin([3, 4, 8, 9])
    rows = annot[is_half & (annot['story'] == story)]
    half_elems_str = [str(num) for num in rows['num']]
    rule_elems = [cell.split(' ') for cell in rows[f'half_point_{mode}']]
    return half_elems_str, rule_elems

def correct_half_point_e(row, mode='add'):
    '''Rescale the half-point event columns of one row of `df`.

    Intended for ``df.apply(correct_half_point_e, axis=1)``. Each half-point
    event of the row's story is paired with its rule events via
    ``get_half_elems``.

    mode:
        add: set addi elements as 0.5 point, the total points are fixed.
            The half event scores 0.5 if it was hit or if any of its rule
            events was hit; otherwise it is left unchanged.
        exclude: total points are not fixed (deprecated).
            A hit half event scores 0.5, or 0 when one of its rule events was
            also hit.

    Returns the modified row.
    '''
    half_cols, rule_cols_list = get_half_elems(row['story'], mode)
    for half_col, rule_cols in zip(half_cols, rule_cols_list):
        no_rule = rule_cols == ['']
        if mode == 'exclude':
            if row[half_col] != 1:
                continue
            if no_rule or row[rule_cols].sum() == 0:
                row[half_col] = 0.5
            else:
                row[half_col] = 0
        elif mode == 'add':
            if not no_rule and row[rule_cols].sum() > 0:
                row[half_col] = 1
            if row[half_col] == 1:
                row[half_col] = 0.5
    return row

# Apply the half-point correction row by row (default mode='add').
df = df.apply(correct_half_point_e, axis='columns')


In [17]:
# Columns between '-1' and '51' are unmatched events; matched events start at '51'.
first_event = '51'
first_event_loc = list(cols_sorted).index(first_event)

# get_elem_list_alt slow
def get_elem_list_alt(seg_num, mode=None, cond_direction=None, full_list=None):
    """Slower variant of ``get_elem_list`` that classifies events with ``num2seg``.

    Parameters
    ----------
    seg_num : float
        Segment number (x.0 main segment, x.5 gap segment).
    mode : {None, 'gap', 'ahead', 'ahead_main', 'ahead_gap'}
        None: events of ``seg_num`` itself.
        'gap': events of ``seg_num - 0.5`` (forward) or ``seg_num + 0.5`` (backward).
        'ahead': events past ``seg_num`` in the condition's direction.
        'ahead_main' / 'ahead_gap': as 'ahead', restricted to events 101-1149
        that fall in main / gap segments.
    cond_direction : {'f', 'b'}
        Required for every mode except None.
    full_list : list of str, optional
        Candidate event columns. Defaults to ``cols_sorted[first_event_loc:]``,
        looked up at call time.

    Returns
    -------
    list of str

    Raises
    ------
    ValueError
        For an unknown mode or direction (previously returned None).
    """
    if full_list is None:
        full_list = cols_sorted[first_event_loc:]
    if mode is None:
        return [e for e in full_list if num2seg(e) == seg_num]
    if cond_direction not in ('f', 'b'):
        raise ValueError(f"cond_direction must be 'f' or 'b', got {cond_direction!r}")
    forward = cond_direction == 'f'

    if mode == 'gap':
        # Stay within this implementation and honour the caller's full_list
        # (previously delegated to get_elem_list with its default list).
        gap_seg = seg_num - 0.5 if forward else seg_num + 0.5
        return get_elem_list_alt(gap_seg, full_list=full_list)

    def beyond(e):
        return num2seg(e) > seg_num if forward else num2seg(e) < seg_num

    if mode == 'ahead':
        return [e for e in full_list if beyond(e)]
    if mode in ('ahead_main', 'ahead_gap'):
        wanted_mod = 0 if mode == 'ahead_main' else 0.5
        inner = [e for e in full_list if 100 < int(e) < 1150]  # exclude 51.. and 1201
        return [e for e in inner if beyond(e) and num2seg(e) % 1 == wanted_mod]
    raise ValueError(f"unknown mode {mode!r}")

def get_elem_list(seg_num, mode=None, cond_direction=None, full_list=None):
    """Return event columns of a segment, or those lying beyond it.

    Events are selected by numeric range. Segment ``s`` selects events in
    ``(s*100, s*100+49]``, so a whole segment covers x01-x49 and a gap
    segment x.5 covers x51-x99.

    Parameters
    ----------
    seg_num : float
        Segment number (x.0 main segment, x.5 gap segment).
    mode : {None, 'gap', 'ahead', 'ahead_main', 'ahead_gap'}
        None: events of ``seg_num`` itself.
        'gap': events of ``seg_num - 0.5`` (forward) or ``seg_num + 0.5`` (backward).
        'ahead': forward, events > ``seg_num*100 + 50``; backward, events
        <= ``seg_num*100 - 1``.
        'ahead_main' / 'ahead_gap': as 'ahead', restricted to events 101-1149
        whose tens digit is 0-4 / 5-9.
    cond_direction : {'f', 'b'}
        Required for every mode except None.
    full_list : list of str, optional
        Candidate event columns. Defaults to ``cols_sorted[first_event_loc:]``,
        looked up at call time rather than frozen when the function is defined.

    Returns
    -------
    list of str
        Matching event columns, in ``full_list`` order.

    Raises
    ------
    ValueError
        For an unknown mode or direction (previously returned None).
    """
    if full_list is None:
        full_list = cols_sorted[first_event_loc:]
    if mode is None:
        lo = seg_num * 100
        return [e for e in full_list if lo < int(e) <= lo + 49]
    if cond_direction not in ('f', 'b'):
        raise ValueError(f"cond_direction must be 'f' or 'b', got {cond_direction!r}")
    forward = cond_direction == 'f'

    if mode == 'gap':
        # Pass full_list through; the gap lookup used to fall back to the default list.
        gap_seg = seg_num - 0.5 if forward else seg_num + 0.5
        return get_elem_list(gap_seg, full_list=full_list)

    def beyond(n):
        # forward: after the target segment (its following gap onwards);
        # backward: before the target segment (its preceding gap included)
        return n > seg_num * 100 + 50 if forward else n <= seg_num * 100 - 1

    if mode == 'ahead':
        return [e for e in full_list if beyond(int(e))]
    if mode in ('ahead_main', 'ahead_gap'):
        tens = (0, 1, 2, 3, 4) if mode == 'ahead_main' else (5, 6, 7, 8, 9)
        inner = [e for e in full_list if 100 < int(e) < 1150]  # exclude 51.. and 1201
        return [e for e in inner if beyond(int(e)) and int(e) // 10 % 10 in tens]
    raise ValueError(f"unknown mode {mode!r}")
                
def get_seg_elems(story, seg_num, version, annot=None):
    """Return the annotated event numbers of one segment of one story.

    Parameters
    ----------
    story : int
    seg_num : float
        Value of the annotation ``segment`` column (e.g. 9 or 9.5).
    version : {'full', 'main', 'add'}
        'full': every event of the segment;
        'main': events whose tens digit is 0 or 5 (worth 1 point);
        'add': events whose tens digit is 3, 4, 8 or 9 (worth 0.5 point).
    annot : pandas.DataFrame, optional
        Annotation table with ``story``, ``segment`` and ``num`` columns.
        Defaults to the global ``df_annot``.

    Returns
    -------
    list
        Event numbers in table order.

    Raises
    ------
    ValueError
        For an unknown ``version`` (previously returned None).
    """
    if annot is None:
        annot = df_annot
    tens_digits = {'full': None, 'main': (0, 5), 'add': (3, 4, 8, 9)}
    if version not in tens_digits:
        raise ValueError(f"version must be 'full', 'main' or 'add', got {version!r}")
    # One boolean-mask selection instead of a query string formatted from the values.
    in_segment = (annot['story'] == story) & (annot['segment'] == seg_num)
    nums = annot.loc[in_segment, 'num'].tolist()
    wanted = tens_digits[version]
    if wanted is None:
        return nums
    return [num for num in nums if num // 10 % 10 in wanted]


def get_seg_points(story, seg_num, annot=None):
    """Maximum score of a segment: 1 point per main event, 0.5 per half-point event."""
    n_main = len(get_seg_elems(story, seg_num, 'main', annot=annot))
    n_add = len(get_seg_elems(story, seg_num, 'add', annot=annot))
    return n_main + 0.5 * n_add

In [None]:
# calculate e_total e_match e_unmatch e_target e_gap e_ahead e_ahead_main e_ahead_gap
# e_0.5 e_1 e_1.5 ... e_11 e_11.5 e_12
summary_cols = ['e_total', 'e_match', 'e_unmatch', 'e_target', 'e_gap',
                'e_ahead', 'e_ahead_main', 'e_ahead_gap']
# Whole segments are written without a trailing '.0' (e_1, not e_1.0).
seg_cols = ['e_' + (str(n // 2) if n % 2 == 0 else str(n / 2)) for n in range(1, 25)]
for col in summary_cols + seg_cols:
    df[col] = np.nan

def add_e_count_columns(row):
    """Fill the event-count summary columns of one row of `df`.

    Intended for ``df.apply(add_e_count_columns, axis=1)``. Reads the row's
    ``segment_num`` and ``cond_direction`` and writes:
      e_total    -- every event column after the first ('-1') one
      e_match    -- matched events, ``cols_sorted[first_event_loc:]``
      e_unmatch  -- events between '-1' and the first matched event
      e_target, e_gap, e_ahead, e_ahead_main, e_ahead_gap -- via ``get_elem_list``
      e_0.5 ... e_12 -- events of each (half) segment
    Returns the modified row.
    """
    target_seg = row['segment_num']
    direction = row['cond_direction']

    row['e_total'] = row[cols_sorted[1:]].sum()
    row['e_match'] = row[cols_sorted[first_event_loc:]].sum()
    row['e_unmatch'] = row[cols_sorted[1:first_event_loc]].sum()
    row['e_target'] = row[get_elem_list(target_seg)].sum()
    row['e_gap'] = row[get_elem_list(target_seg, 'gap', direction)].sum()

    for col in ('e_ahead', 'e_ahead_main', 'e_ahead_gap'):
        ahead_cols = get_elem_list(target_seg, col[2:], direction)
        # an empty selection is stored as a plain 0
        row[col] = 0 if ahead_cols == [] else row[ahead_cols].sum()

    for n in range(1, 25):
        seg = 0.5 * n
        col = 'e_' + str(seg) if seg % 1 != 0 else 'e_' + str(int(seg))
        row[col] = row[get_elem_list(seg)].sum()

    return row

# Fill the e_* count columns row by row.
df = df.apply(add_e_count_columns, axis='columns')

In [None]:
# df_seg_info: one row per (story, segment number) with its maximum score.

# Collect the rows first and build the frame once. DataFrame.append was removed
# in pandas 2.0, and appending row by row copied the frame on every iteration.
seg_records = [
    {'story': story, 'segment_num': seg_num, 'target_points': get_seg_points(story, seg_num)}
    for story in [1, 2]
    for seg_num in np.arange(0.5, 12.5, 0.5)
]
df_seg_info = pd.DataFrame(seg_records, columns=['story', 'segment_num', 'target_points'])

df_seg_info['story'] = df_seg_info['story'].astype('int')
# Copies under the names used when this table is merged onto referred segments.
df_seg_info['other_refer_seg_num'] = df_seg_info['segment_num']
df_seg_info['other_refer_seg_points'] = df_seg_info['target_points']

# add target_points (plus the copies above) to df
df = pd.merge(df, df_seg_info, on=['story', 'segment_num'], how='left')
df['target_hit_rate'] = df['e_target'] / df['target_points']

df_seg_info.to_csv("../data/main/df_seg_info.csv", index=False)

In [22]:
# Cache the enriched data; the analysis below restarts from this file.
df.to_csv(path_or_buf="../data/main/exp_with_scenes_use_large.csv", index=False)

## generating dfs

In [4]:
df = pd.read_csv("../data/main/exp_with_scenes_use_large.csv")
df_seg_info = pd.read_csv("../data/main/df_seg_info.csv")

# now subset
print(df.shape)

# Keep only the conditions listed in cond_full. .copy() gives an independent
# frame, so the column assignments below do not raise SettingWithCopyWarning.
df = df.query(f"condition=={cond_full}").copy()
print(df.shape, 'after discard truc conditions')

# Cast to str so NaN keys become 'nan' and these rows are not dropped by groupby.
df['base_segment'] = df['base_segment'].astype(str)
df['segment_pair'] = df['segment_pair'].astype(str)
df['segment_count_demean'] = df['segment_count'] - 5.5
df['segment_num_demean'] = df['segment_num'] - 6

# seg_mean: mean of each measure per segment of each condition
seg_keys = ['condition', 'cond_direction', 'cond_amount', 'story', 'segment_num', 'target_points',
            'segment', 'segment_count', 'base_segment', 'base_seg_num', 'segment_pair']
seg_measures = (['res_1_simi_other_z', 'res_1_simi_self_z', 'res_1_MAD_z', 'res_1_MD_z',
                 'res_1_MAD_oppo_sub_z', 'res_1_MD_oppo_sub_z']
                + df.loc[:, 'res_1_simi_otherseg_01':'res_1_simi_otherseg_11'].columns.tolist()
                + ['hit', 'target_hit_rate']
                + df.loc[:, '-1':'e_12'].columns.tolist())
seg_mean = df.groupby(seg_keys)[seg_measures].mean().reset_index()

# cond_mean: mean over segments per condition. seg_mean still holds string key
# columns (e.g. 'segment', 'base_segment'). Older pandas silently dropped them;
# pandas >= 2.0 raises TypeError, so restrict to numeric columns explicitly.
cond_mean = seg_mean.groupby(['condition']).mean(numeric_only=True).reset_index()

(2152, 245)
(2084, 245) after discard truc conditions


In [23]:
# Mean target-event hit rate per condition.
cond_mean.loc[:, ['condition', 'target_hit_rate']]

Unnamed: 0,condition,target_hit_rate
0,p,0.15939
1,p0,0.104483
2,p_r,0.264368
3,p_r_recall,0.707585
4,p_recall,0.717539
5,r,0.198212
6,r0,0.117832
7,r_p,0.271185
8,r_p_recall,0.735689
9,r_recall,0.733117


## for R

In [24]:
# df_pr_R and df_pr_recall_R: row-level tables exported for the R models.
cols = ['sub', 'story', 'segment', 'segment_count', 'segment_num', 'base_segment',
        'segment_pair', 'condition', 'cond_direction', 'cond_amount', 'target_points',
        'e_total', 'e_target', 'res_1_simi_self_z', 'res_1_simi_other_z',
        'res_1_MAD_z', 'res_1_MAD_sub_z', 'res_1_MD_z', 'res_1_MD_sub_z',
        'segment_count_demean', 'segment_num_demean']

pr_conds = ['p0', 'r0', 'p', 'r', 'p_r', 'r_p']
df_pr_R = df.loc[df['condition'].isin(pr_conds), cols]
df_pr_recall_R = df.loc[df['condition'].isin(pr_conds + ['r_recall', 'p_recall']), cols]

df_pr_R.to_csv("../data/main/R/df_pr_for_R.csv", index=False)
print(df_pr_R.shape)

df_pr_recall_R.to_csv("../data/main/R/df_pr_recall_for_R.csv", index=False)
df_pr_recall_R.shape

(981, 21)


(1360, 21)

In [7]:
# df_pr_econd_bin_R (for figure 2 stats): long format. Each df row appears once
# per category in e_cond_bin (e_match / e_unmatch), with its count in n_hits.
cols = ['sub','story','segment','segment_count','segment_num','base_segment',
        'segment_pair','condition','cond_direction','cond_amount']

df_pr_econd_R_bin = (
    pd.melt(df, id_vars=cols, value_vars=e_cond_bin, var_name='e_cond_bin', value_name='n_hits')
    .query("condition==['p0','r0','p','r','p_r','r_p']")
    # assign() returns a new frame: no SettingWithCopyWarning on the query result
    .assign(segment_count_demean=lambda d: d['segment_count'] - 5.5,
            segment_num_demean=lambda d: d['segment_num'] - 6)
)

df_pr_econd_R_bin.to_csv("../data/main/R/df_pr_econd_bin_for_R.csv", index=False)
df_pr_econd_R_bin.shape

(1962, 14)

In [25]:
# df_pr_econd_R: long format. Each df row appears once per event category in
# e_cond, with its count in n_hits.
cols = ['sub','story','segment','segment_count','segment_num','base_segment',
        'segment_pair','condition','cond_direction','cond_amount']

df_pr_econd_R = (
    pd.melt(df, id_vars=cols, value_vars=e_cond, var_name='e_cond', value_name='n_hits')
    .query("condition==['p0','r0','p','r','p_r','r_p']")
    # assign() returns a new frame: no SettingWithCopyWarning on the query result
    .assign(segment_count_demean=lambda d: d['segment_count'] - 5.5,
            segment_num_demean=lambda d: d['segment_num'] - 6)
)

df_pr_econd_R.to_csv("../data/main/R/df_pr_econd_for_R.csv", index=False)
df_pr_econd_R.shape

(3924, 14)

# segment count plot: figure S2

In [27]:
# map cond names (short labels for axes, long labels for legends)
for label_col, label_map in (('cond_short', cond_short_map), ('cond_long', cond_long_map)):
    seg_mean[label_col] = seg_mean['condition'].map(label_map)

In [28]:
def reg_plot(index, scale_max, title, label):
    """Faceted scatter plot with a linear fit of a per-segment measure vs segment count.

    Uses ``seg_mean`` rows of the ``cond_pr`` conditions, one facet column per
    condition (ordered by ``cond_pr_short``), with points coloured by storyline.

    Parameters
    ----------
    index : str
        ``seg_mean`` column plotted on the y axis.
    scale_max : float
        Upper bound of the y scale (the lower bound is 0).
    title : str
        Panel title.
    label : str
        Y-axis title.

    Returns
    -------
    altair FacetChart
    """
    x_enc = alt.X('segment_count:O', axis=alt.Axis(title='Segment count', labelAngle=0))
    y_enc = alt.Y(index, axis=alt.Axis(title=label), scale=alt.Scale(domain=(0, scale_max)))
    story_color = alt.Color(
        field='story', type='nominal',
        scale=alt.Scale(range=['grey', '#ff7f0e', '#2ca02c']),
        legend=alt.Legend(title='Storyline', values=[1, 2]),
    )
    scatter = (
        alt.Chart()
        .mark_circle(opacity=1)
        .encode(x_enc, y_enc, color=story_color, tooltip='segment')
        .properties(width=200, height=200)
    )
    fit_line = scatter.transform_regression('segment_count', index).mark_line()

    facet_header = alt.Header(labelOrient='top', titleFontSize=14, labelFontSize=14,
                              titleFontWeight='normal', titlePadding=0)
    condition_column = alt.Column('cond_short:N', sort=cond_pr_short, title='Condition',
                                  header=facet_header)
    plotted = seg_mean.query(f"condition=={cond_pr}")
    return alt.layer(scatter, fit_line, data=plotted).facet(title=title, column=condition_column)

reg_hit_rate = reg_plot('target_hit_rate', 0.6, 'A', 'Target events hit rate')
reg_pre = reg_plot('res_1_simi_other_z', 1.4, 'B', 'Precision')
reg_conv = reg_plot('res_1_MD_z', 1.4, 'C', 'Convergence')

# Figure S2: the three panels stacked vertically.
reg_panels = reg_hit_rate & reg_pre & reg_conv
(
    reg_panels
    .configure_axis(titleFontWeight='normal', titleFontSize=14, labelFontSize=14)
    .configure_concat(spacing=40)
    .configure_legend(symbolSize=100, labelFontSize=14, titleFontWeight='normal', titleFontSize=16)
    .configure_title(fontSize=20)
)

# figure 4B

In [29]:
# cond_mean_e_long from cond_mean: one row per (condition, event category)
# with the mean number of events hit, restricted to cond_pr.
cond_mean_e_long = (
    pd.melt(cond_mean, id_vars=['condition'], value_vars=e_cond,
            var_name='e_cond', value_name='scenes')
    .query(f"condition=={cond_pr}")
    # assign() builds a new frame: no SettingWithCopyWarning on the query result
    .assign(
        # stacking order of the event categories in the bar chart
        sort=lambda d: d['e_cond'].map({'e_target': 2, 'e_gap': 1, 'e_ahead': 3, 'e_unmatch': 4}),
        cond_2=lambda d: d['condition'].map({'r0': 1, 'p0': 1, 'p': 2, 'r': 2, 'r_p': 3, 'p_r': 3}),
        cond_short=lambda d: d['condition'].map(cond_short_map),
        e_cond_long=lambda d: d['e_cond'].map(e_cond_long_map),
    )
)
cond_mean_e_long.head(1)

Unnamed: 0,condition,e_cond,scenes,sort,cond_2,cond_short,e_cond_long
0,p,e_target,0.908869,2,2,c-P,Target events with lag = -1 or 1


In [30]:
def make_stack_bars(n_conds, legendX, width, title):
    """Stacked bars of mean events hit per event category and condition.

    Reads the global ``cond_mean_e_long``. Segments are stacked in ``sort``
    order and labelled with their value (2 decimals).

    Parameters
    ----------
    n_conds : int
        2 keeps only the uncued conditions r0 / p0. Any other value keeps
        every condition present in ``cond_mean_e_long``.
    legendX : int
        Horizontal legend position in pixels (the legend uses orient='none').
    width : int
        Chart width in pixels.
    title : str
        Panel title.
    """
    x_sort = cond_pr_short
    event_scale = alt.Scale(domain=list(e_cond_long_map.values()),
                            range=list(e_cond_cmap.values()))
    event_legend = alt.Legend(title='Event type', orient='none', legendY=200, legendX=legendX)

    bars = (
        alt.Chart(cond_mean_e_long)
        .mark_bar(size=42, strokeWidth=2, stroke='white')
        .encode(
            y=alt.Y('scenes:Q', axis=alt.Axis(title='Number of events hit', titleFontWeight='normal')),
            x=alt.X('cond_short:N', sort=x_sort,
                    axis=alt.Axis(title='Condition', titleFontWeight='normal',
                                  labelFontWeight='normal', labelAngle=0)),
            color=alt.Color('e_cond_long', scale=event_scale, legend=event_legend),
            order=alt.Order("sort"),
        )
        .properties(width=width, height=350, title=title)
    )
    labels = (
        alt.Chart(cond_mean_e_long)
        .mark_text(fontSize=12, dx=0, dy=8, color='white')
        .encode(
            y=alt.Y('scenes:Q', stack='zero'),
            x=alt.X('cond_short:N', sort=x_sort),
            detail='e_cond_long:N',
            order=alt.Order("sort"),
            text=alt.Text('scenes:Q', format='.2f'),
        )
    )

    if n_conds == 2:
        uncued_only = alt.FieldOneOfPredicate(field='condition', oneOf=['r0', 'p0'])
        bars = bars.transform_filter(uncued_only)
        labels = labels.transform_filter(uncued_only)

    return bars + labels

stack_bars2 = make_stack_bars(n_conds=2, legendX=140, width=120, title='B')
# n_conds != 2 keeps every condition in cond_mean_e_long (the cond_pr conditions)
stack_bars4 = make_stack_bars(n_conds=6, legendX=250, width=240, title='')

# Figure 4B: uncued conditions only.
p = (
    stack_bars2
    .configure_legend(titleFontWeight='normal', symbolSize=400, titleFontSize=20,
                      labelFontSize=16, labelLimit=1000, titleLimit=1000,
                      titlePadding=10, rowPadding=5)
    .configure_axis(titleFontSize=16, labelFontSize=16, grid=False)
    .configure_view(strokeWidth=0)
    .configure_title(fontSize=28, anchor='start', offset=10)
)
# p.save('../figs/core/events_stacked_bar.svg')
p

# figure 4C

In [31]:
# df_refer_long from df (constrained to seg 1-11): one row per df row and
# referred segment e_1 ... e_11 (half segments included), with its events hit.
df_refer_long = pd.melt(
    df,
    # melt expects column labels; pass them explicitly instead of DataFrame slices
    id_vars=df.loc[:, :'scenes'].columns.tolist(),
    value_vars=df.loc[:, 'e_1':'e_11'].columns.tolist(),
    var_name='refer_seg_num', value_name='refer_seg_e',
)

df_refer_long['refer_seg_num'] = df_refer_long['refer_seg_num'].str[2:].astype(float)  # 'e_1.5' -> 1.5
# df_refer_long['refer_step'] = df_refer_long['refer_seg_num'] - df_refer_long['segment_num']
df_refer_long['refer_step_base'] = df_refer_long['refer_seg_num'] - df_refer_long['base_seg_num']

In [32]:
# add fields to df_refer_long

# Expose the segment table under the refer_* names used by this merge.
df_seg_info['refer_seg_num'] = df_seg_info['other_refer_seg_num']
df_seg_info['refer_seg_points'] = df_seg_info['other_refer_seg_points']

# add the referred segment's maximum score (refer_seg_points)
refer_seg_points_table = df_seg_info[['story', 'refer_seg_num', 'refer_seg_points']]
df_refer_long = df_refer_long.merge(refer_seg_points_table, on=['story', 'refer_seg_num'], how='left')

# main (whole) vs gap (x.5) referred segment
df_refer_long['mod1'] = df_refer_long['refer_seg_num'] % 1
df_refer_long['main_or_gap'] = df_refer_long['mod1'].map({0: 'main', 0.5: 'gap'})

# referred-segment id '<story>_<segment>', e.g. '1_3.5'
story_str = df_refer_long['story'].astype(str)
df_refer_long['refer_seg'] = story_str + '_' + df_refer_long['refer_seg_num'].astype(str)

In [33]:
# df_refer_long_p0r0 from df_refer_long

# Keep referred segments on the side each uncued condition points to:
# at/after the base segment for p0, at/before it for r0.
df_refer_long_p0r0 = df_refer_long.query(
    "condition=='p0' & refer_seg_num>=base_seg_num | condition=='r0' & refer_seg_num<=base_seg_num").copy()

# Same selection without the base segment itself (strict inequalities), plus
# the covariates used by the R models. assign() already returns a new frame.
df_refer_long_p0r0_for_R = df_refer_long.query(
    "condition=='p0' & refer_seg_num>base_seg_num | condition=='r0' & refer_seg_num<base_seg_num"
).assign(
    segment_count_demean=lambda d: d['segment_count'] - 5.5,
    segment_num_demean=lambda d: d['segment_num'] - 6,
    step_abs=lambda d: abs(d['refer_step_base']),
    step_abs_demean=lambda d: d['step_abs'] - 3.25,
    base_refer_seg_pair=lambda d: d['base_segment'] + '_' + d['refer_seg'],
)

# |smallest lag| among rows with at least one hit (8 in the current data).
# NOTE(review): only the most negative lag is considered -- confirm this is the intended bound.
hit_rows = df_refer_long_p0r0['refer_seg_e'] >= 1
refer_step_bound = abs(df_refer_long_p0r0.loc[hit_rows, 'refer_step_base'].min()) #8

In [34]:
# Columns exported for the R models of the uncued (p0 / r0) referred-segment analysis.
cols = ['condition', 'story', 'sub', 'segment_count', 'segment_count_demean', 'main_or_gap',
        'base_segment', 'step_abs', 'step_abs_demean', 'refer_seg', 'base_refer_seg_pair',
        'refer_seg_e', 'refer_seg_points']

df_refer_long_p0r0_for_R = df_refer_long_p0r0_for_R.loc[:, cols]
df_refer_long_p0r0_for_R.to_csv('../data/main/R/df_refer_long_p0r0_for_R.csv', index=False)

In [35]:
# df_refer from seg_mean (equivalent to seg_seg_mean): one row per segment mean
# and referred segment e_1 ... e_11, with the mean events hit there.
df_refer = pd.melt(
    seg_mean,
    # melt expects column labels; pass them explicitly instead of DataFrame slices
    id_vars=seg_mean.loc[:, :'base_seg_num'].columns.tolist(),
    value_vars=seg_mean.loc[:, 'e_1':'e_11'].columns.tolist(),
    var_name='other_refer_seg_num', value_name='refer_seg_num_e',
)

df_refer['other_refer_seg_num'] = df_refer['other_refer_seg_num'].str[2:].astype(float)  # 'e_1.5' -> 1.5
df_refer['other_refer_step'] = df_refer['other_refer_seg_num'] - df_refer['segment_num']
df_refer['other_refer_step_base'] = df_refer['other_refer_seg_num'] - df_refer['base_seg_num']

# All directions are kept here; filtering to the condition's direction at this
# point caused an Altair plotting problem, so the plotting cells filter instead.

# replace the target segment's points with those of the referred segment
df_refer = df_refer.drop(columns=['target_points'])
df_refer = df_refer.merge(df_seg_info[['story', 'other_refer_seg_num', 'other_refer_seg_points']],
                          on=['story', 'other_refer_seg_num'], how='left')
df_refer['other_refer_seg_hit_rate'] = df_refer['refer_seg_num_e'] / df_refer['other_refer_seg_points']

# add main or gap
df_refer['mod1'] = df_refer['other_refer_seg_num'] % 1
df_refer['main_or_gap'] = df_refer['mod1'].map({0: 'main', 0.5: 'gap'})

set_color(df_refer, 'other_refer_step_base')

In [37]:
# fig p0r0_step_hit_rate: events hit per lag (referred minus base segment) for
# u-R / u-P, keeping only lags on the side each condition points to.
p0r0_refer = df_refer.query(
    "condition==['p0','r0'] & (cond_direction=='f' & other_refer_seg_num>=base_seg_num | cond_direction=='b' & other_refer_seg_num<=base_seg_num)")

lag_axis = alt.Axis(title='Lag', labelAngle=0, values=list(range(-10, 11)))
base = (
    alt.Chart(p0r0_refer)
    .encode(alt.X('other_refer_step_base:O', axis=lag_axis))
    .properties(width=550, height=350, title={"text": ["C"]})
)

# mean events hit per lag, coloured with the precomputed 'color' column (set_color)
bar = base.mark_bar(size=10).encode(
    alt.Y('mean(refer_seg_num_e):Q', axis=alt.Axis(title=['Number of events hit'])),
    alt.Color('color', scale=None),
)

# bootstrapped 95% CI of that mean (Vega-Lite ci0 / ci1 aggregates)
error_bar = base.mark_rule(size=1.6, color='black', opacity=0.4).encode(
    y='ci0(refer_seg_num_e)',
    y2='ci1(refer_seg_num_e)',
)

step_scenes_hits = (
    (bar + error_bar)
    .configure_axis(titleFontSize=16, labelFontSize=16, titleFontWeight='normal')
    .configure_title(fontSize=28, subtitleFontSize=18, anchor='start', offset=10)
)

step_scenes_hits

In [38]:
# Figure 4: event-type stacked bars (panel B) next to events hit per lag (panel C).
(
    alt.hconcat(stack_bars2, bar + error_bar)
    .configure_legend(titleFontWeight='normal', symbolSize=300, titleFontSize=20,
                      labelFontSize=16, labelLimit=1000, titleLimit=1000,
                      titlePadding=10, rowPadding=2)
    .configure_axis(titleFontWeight='normal', titleFontSize=16, labelFontSize=16)
    .configure_view(strokeWidth=0)
    .configure_concat(spacing=350)
    .configure_title(fontSize=28, subtitleFontSize=20, anchor='start', offset=10)
)

# figure 3B, C, D

In [39]:
# cond_mean_ci (for plotting): per-condition mean and 95% bootstrap CI bounds of each measure.
ci_measures = ['hit', 'e_target', 'target_hit_rate',
               'res_1_simi_other_z', 'res_1_simi_self_z', 'res_1_MD_z']
ci_agg = {}
for measure in ci_measures:
    ci_agg[measure] = (measure, 'mean')
    ci_agg[measure + '_ci_lower'] = (measure, bootstrap_ci_lower)
    ci_agg[measure + '_ci_upper'] = (measure, bootstrap_ci_upper)
cond_mean_ci = seg_mean.groupby(['condition']).agg(**ci_agg).reset_index()

# replace e_target in recall to np.nan for now
# (a mean of exactly 0 becomes NaN; the e_target CI columns are left as they are)
cond_mean_ci['e_target'] = cond_mean_ci['e_target'].replace(0, np.nan)

# transform back to r: tanh is the inverse Fisher z-transform.
# NOTE(review): starting at 'target_hit_rate' (the former positional index 7) also
# creates *_r columns for target_hit_rate and its CI bounds, which are not
# z-scores. Confirm whether only the res_1_* columns were meant.
to_r_cols = cond_mean_ci.columns[cond_mean_ci.columns.get_loc('target_hit_rate'):]
for col in to_r_cols:
    cond_mean_ci[col + '_r'] = np.tanh(cond_mean_ci[col])

# add recall as upper limit: broadcast each recall condition's precision to every row
for recall_cond in ('r_recall', 'p_recall'):
    is_recall = cond_mean_ci['condition'] == recall_cond
    recall_precision = cond_mean_ci.loc[is_recall, 'res_1_simi_other_z_r'].values[0]
    cond_mean_ci[recall_cond + '_res_1_simi_other_z_r'] = recall_precision

# map cond names
cond_mean_ci['cond_short'] = cond_mean_ci['condition'].map(cond_short_map)
cond_mean_ci['cond_long'] = cond_mean_ci['condition'].map(cond_long_map)
cond_mean_ci['res_1_simi_self_z_r_range'] = (cond_mean_ci['res_1_simi_self_z_ci_upper_r']
                                             - cond_mean_ci['res_1_simi_self_z_ci_lower_r'])

In [40]:
# Figure 3 data: one row per condition in cond_3.
data = cond_mean_ci[cond_mean_ci['condition'].isin(cond_3)]

# Shared x encoding: condition, ordered by cond_pr_recall_short.
condition_x = alt.X(field='cond_short', type='nominal', sort=cond_pr_recall_short,
                    axis=alt.Axis(title='Condition', labelAngle=0))
base = alt.Chart(data).encode(condition_x)

width_pre = 200
width_tar = width_pre * 3 / 2

# # hit rate (binary)
# hit_point = prec_point.encode(
#     alt.Y('hit:Q', axis=alt.Axis(title='Hit rate'), scale=alt.Scale(domain=(0, 1))), 
#     tooltip='hit',
# )

# hit_error = prec_error_bar.encode(
#     y='ci0(hit_ci_lower)',
#     y2='ci1(hit_ci_upper)',  
# )

# hit_plot = (hit_error + hit_point).properties(width=250, height=350, title='A')

# # scenes hit
# scenes_point = prec_point.encode(
#     alt.Y('e_target:Q', axis=alt.Axis(title='Number of target events hit'), 
# #           scale=alt.Scale(domain=(0, 1))
#          ), 
#     tooltip='e_target',
# )

# scenes_error = prec_error_bar.encode(
#     y='ci0(e_target_ci_lower)',
#     y2='ci1(e_target_ci_upper)',  
# )
# scenes_plot = (scenes_error + scenes_point).properties(width=250, height=350, title='A')



# precision
# Panel C: one filled circle per condition at the back-transformed (r) mean of
# res_1_simi_other_z. Fill and stroke colours are keyed on cond_long using the
# cond_3 colour maps. prec_point / prec_error_bar also serve as templates for
# panels B and D below (they are re-encoded with different y fields).
prec_point = base.mark_circle(opacity=1, size=160, filled=True, strokeWidth=3.5).encode(
    alt.Y('res_1_simi_other_z_r:Q', axis=alt.Axis(title='Precision'), scale=alt.Scale(domain=(0, 1))), 
#     strokeWidth = alt.StrokeWidth('strokeWidth', legend=None),
    stroke = alt.Color('cond_long', scale=alt.Scale(domain=list(cond_long_3_map.values()), 
                                                    range=list(color_stroke_color_3_map.values()))),
    tooltip='res_1_simi_other_z_r',
    fill=alt.Color('cond_long', scale=alt.Scale(domain=list(cond_long_3_map.values()), 
                                                range=list(cond_color_3_map.values())),
                   legend=alt.Legend(direction='vertical', title="Condition"))
)

# Vertical whisker between the back-transformed bootstrap CI bounds.
# NOTE(review): presumably one row per condition in `data`, so ci0()/ci1() just
# pass the precomputed bound through — confirm.
prec_error_bar = base.mark_rule(size=2.6, color='black', opacity=0.4).encode(
    y='ci0(res_1_simi_other_z_ci_lower_r)',
    y2='ci1(res_1_simi_other_z_ci_upper_r)',  
)

# Horizontal reference lines at the recall conditions' precision (the
# "upper limit" columns added to cond_mean_ci).
rule_r = alt.Chart(data).mark_rule(color=cond_color_map['r_recall'], size=4, opacity=0.3).encode(
    y='r_recall_res_1_simi_other_z_r'
)

rule_p = alt.Chart(data).mark_rule(color=cond_color_map['p_recall'], size=4, opacity=0.3).encode(
    y='p_recall_res_1_simi_other_z_r'
)

# Panel C keeps only the conditions listed in cond_pr.
prec_plot = (prec_error_bar + prec_point + rule_r + rule_p).properties(width=width_pre, height=350, title='C').\
    transform_filter(alt.FieldOneOfPredicate(field='condition', oneOf=cond_pr)
)

# target hit rate
# Panel B: same marks/colours as panel C, y switched to target_hit_rate.
target_hit_rate = prec_point.encode(
    alt.Y('target_hit_rate:Q', axis=alt.Axis(title='Target events hit rate'), 
          scale=alt.Scale(domain=(0, 1))
         ), 
    tooltip='target_hit_rate',
)

target_hit_rate_error = prec_error_bar.encode(
    y='ci0(target_hit_rate_ci_lower)',
    y2='ci1(target_hit_rate_ci_upper)',  
)
# The cond_pr filter used for panel C is commented out for this panel.
target_plot = (target_hit_rate_error + target_hit_rate).properties(width=width_tar, height=350, title='B') # .\
    # transform_filter(alt.FieldOneOfPredicate(field='condition', oneOf=cond_pr)
#)

# convergence
# Panel D: back-transformed mean of res_1_MD_z with its CI whisker.
conv_point = prec_point.encode(
    alt.Y('res_1_MD_z_r:Q', axis=alt.Axis(title='Convergence'), 
          scale=alt.Scale(domain=(0, 1))), 
    tooltip='res_1_MD_z_r',
)

conv_error = prec_error_bar.encode(
    y='ci0(res_1_MD_z_ci_lower_r)',
    y2='ci1(res_1_MD_z_ci_upper_r)',  
)

conv_plot = (conv_error + conv_point).properties(width=width_tar, height=350, title='D')

# Significance brackets (c-R vs c-P): y heights for panel B (h1) and panel C (h2).
h1 = 0.35  # earlier value: 0.28
h2 = 0.85  # earlier value: 0.76


def _sig_frames(height, stars):
    """Return (label_df, line_df) for a c-R vs c-P bracket drawn at `height`."""
    label_df = pd.DataFrame({'cond_short': ['c-R'], 'x2': ['c-P'], 'y': [height], 'text': [stars]})
    line_df = pd.DataFrame({'cond_short': ['c-R', 'c-P'], 'y': [height] * 2})
    return label_df, line_df


def _sig_layers(line_df, label_df, text_dx):
    """Return (line_chart, text_chart): the black bracket line and its star label."""
    line_chart = alt.Chart(line_df).mark_line(strokeWidth=2, color='black').encode(
        alt.X(field='cond_short', type='nominal', sort=cond_pr_recall_short),
        y='y')
    text_chart = alt.Chart(label_df).mark_text(dx=text_dx, fontSize=20).encode(
        alt.X(field='cond_short', type='nominal', sort=cond_pr_recall_short),
        y='y', text='text')
    return line_chart, text_chart


target_sig, target_line = _sig_frames(h1, '*')
prec_sig, prec_line = _sig_frames(h2, '**')

target_sig_line, target_sig_text = _sig_layers(target_line, target_sig, text_dx=24)  # earlier dx: 21
prec_sig_line, prec_sig_text = _sig_layers(prec_line, prec_sig, text_dx=25)  # earlier dx: 21



# Panels B | C | D side by side (B and C overlaid with their significance brackets),
# with a shared legend placed below the panels and common axis / title styling.
p = alt.hconcat(
    target_plot + target_sig_line + target_sig_text,
    prec_plot + prec_sig_line + prec_sig_text,
    conv_plot,
).configure_legend(
    columns=3,
    columnPadding=40,
    symbolSize=100,
    orient='none',
    legendX=0,
    legendY=450,
    symbolStrokeWidth=3,
    rowPadding=8,
    labelFontSize=20,
    titleFontWeight='normal',
    titleFontSize=20,
    titlePadding=12,
    labelLimit=1000,
).configure_axis(
    titleFontSize=20,
    labelFontSize=20,
    titleFontWeight='normal',
    labelFontWeight='normal',
).configure_concat(
    spacing=80,
).configure_title(
    fontSize=28,
    anchor='start',
)

p

# references

## set refer type using seg_sc_pr

In [39]:
# seg_mean_pr, seg_sc_pr <- seg_mean
seg_mean_pr = seg_mean.query(f"condition in {cond_pr}")

# Long format: one row per (segment-level mean, scene column '51'..'1201').
# value_vars gets the column labels explicitly; pd.melt documents it as a list-like
# of labels (passing the DataFrame itself only worked because iterating a DataFrame
# yields its column labels).
seg_sc_pr = pd.melt(seg_mean_pr, 
        id_vars=['story','segment_num','segment','cond_direction','condition','base_seg_num'], 
        value_vars=df.loc[:, '51':'1201'].columns.tolist(), var_name='ref_sc_num', 
        value_name='ref_sc_hit')

# correct mismatched story-sc pairs: keep only scenes annotated for the row's story
seg_sc_pr['ref_sc_num'] = seg_sc_pr['ref_sc_num'].astype('int')
story1_sc_nums = df_annot.query("story==1")['num'].to_list()
story2_sc_nums = df_annot.query("story==2")['num'].to_list()
seg_sc_pr = seg_sc_pr[
    ((seg_sc_pr['story'] == 1) & seg_sc_pr['ref_sc_num'].isin(story1_sc_nums))
    | ((seg_sc_pr['story'] == 2) & seg_sc_pr['ref_sc_num'].isin(story2_sc_nums))
].copy()  # explicit copy: the columns added below would otherwise trigger SettingWithCopyWarning

seg_sc_pr['ref_seg_num'] = seg_sc_pr['ref_sc_num'].apply(num2seg)
seg_sc_pr['ref_step_base'] = seg_sc_pr['ref_seg_num'] - seg_sc_pr['base_seg_num']

# Keep scenes on the condition's side of the base segment, base segment included
# (forward 'f': at/after base; backward 'b': at/before base).
_forward = (seg_sc_pr['cond_direction'] == 'f') & (seg_sc_pr['ref_seg_num'] >= seg_sc_pr['base_seg_num'])
_backward = (seg_sc_pr['cond_direction'] == 'b') & (seg_sc_pr['ref_seg_num'] <= seg_sc_pr['base_seg_num'])
seg_sc_pr = seg_sc_pr[_forward | _backward].reset_index(drop=True)

In [40]:
# add 'ref_seg_type' in seg_sc_pr

def add_ref_seg_type(row):
    """Label a row by its lag from the base segment (``row['ref_step_base']``).

    0 -> 'self'; |lag| 0.5 -> 'gap'; |lag| 1 -> 'target';
    |lag| > 1 with a .5 fraction -> 'ahead_gap'; |lag| > 1 whole -> 'ahead_main'.
    A lag matching none of these leaves ``row`` untouched. Returns ``row``.
    """
    lag = row['ref_step_base']
    distance = abs(lag)
    fraction = lag % 1

    if lag == 0:
        label = 'self'
    elif distance == 0.5:
        label = 'gap'
    elif distance == 1:
        label = 'target'
    elif distance > 1 and fraction == 0.5:
        label = 'ahead_gap'
    elif distance > 1 and fraction == 0:
        label = 'ahead_main'
    else:
        return row

    row['ref_seg_type'] = label
    return row

# Start every row at NaN (rows matching no lag rule keep it), then label row by row.
seg_sc_pr = seg_sc_pr.assign(ref_seg_type=np.nan).apply(add_ref_seg_type, axis=1)

In [8]:
# df_refer_raw_bounded <- df_refer_raw
df_refer_raw = pd.read_excel(
    '../data/main/WhyWomenKill.xlsx',
    sheet_name='references',
)
# Keep only references flagged for inclusion (an earlier filter on seg_2 != 0.5 is superseded).
df_refer_raw_bounded = df_refer_raw[df_refer_raw['is_included'] == 1]

# Number of non-amended references per direction.
df_refer_raw_bounded[df_refer_raw_bounded['amended'] != 1].groupby('direction')['seg_1'].count()

direction
b    52
f    30
Name: seg_1, dtype: int64

## figure S6B

In [43]:
# Histogram of annotated reference lags ('step') among the included references.
df_refer_dist = (
    df_refer_raw_bounded['step']
    .value_counts()
    .rename_axis('step')
    .reset_index(name='counts')
)
# Positive lags are labelled 'p', the rest 'r'; set_color picks bar colours from the lag.
df_refer_dist['condition'] = np.where(df_refer_dist['step'] > 0, 'p', 'r')
set_color(df_refer_dist, 'step')
w, h = 635, 180

p = alt.Chart(df_refer_dist).mark_bar(size=12).encode(
    alt.X('step:Q', scale=alt.Scale(domain=(-9, 9)),
          axis=alt.Axis(title='Lag', labelAngle=0, grid=False, tickMinStep=1)),
    alt.Y('counts:Q', scale=alt.Scale(domain=(0, 30)),
          axis=alt.Axis(title='Count')),
    alt.Color('color', scale=None),
).properties(
    width=w,
    height=h,
).configure_axis(
    titleFontSize=14,
    labelFontSize=14,
    titleFontWeight='normal',
)
p

In [44]:
# add 'refer_type', 'refer_boost_type', 'refer_boost_type_tri' to seg_sc_pr

def add_refer_type(row, df=df_refer_raw_bounded):
    """Classify the referenced scene of one seg_sc_pr row against the reference table.

    Depending on ``row['condition']``, selects from ``df`` the references that
    count as "referred" (``refer_segments``), "referred from the base segment
    only" (``refer_segments_base_only``) and "referring" (``referring_segments``),
    then labels ``row['ref_sc_num']`` accordingly.

    Parameters
    ----------
    row : pd.Series
        One row of seg_sc_pr; reads 'base_seg_num', 'segment_num',
        'ref_seg_num', 'ref_sc_num', 'story' and 'condition'.
    df : pd.DataFrame, default df_refer_raw_bounded
        Reference table; the queries below use its 'story', 'seg_1', 'seg_2',
        'main_or_gap', 'type', 'refer_num' and 'source_num' columns.

    Returns
    -------
    pd.Series
        ``row`` with main_or_gap, refer_type, refer_type_base_only,
        referring_type, rr_type, refer_list, referring_list,
        refer_boost_step_left/right, refer_boost_step, refer_boost_type and
        refer_boost_type_tri set. Rows on the excluded side of the base segment
        (or the recalled segment itself) get 'NA' labels.

    Raises
    ------
    ValueError
        If ``row['condition']`` is not one of r0, r, r_p, p0, p, p_r,
        r_recall, p_recall.
    """
    base_seg_num = row['base_seg_num']
    segment_num = row['segment_num']
    ref_seg_num = row['ref_seg_num']
    ref_sc_num = row['ref_sc_num']
    condition = row['condition']
    story = row['story']

    # main or gap: whole-numbered segments are main, x.5 segments are gaps
    if ref_seg_num % 1 == 0:
        row['main_or_gap'] = 'main'
    elif ref_seg_num % 1 == 0.5:
        row['main_or_gap'] = 'gap'

    # get refer segments
    if condition in ['r0', 'r']:
        refer_segments = df.query(
            f"story=={story} & seg_1>={base_seg_num} & seg_2<{base_seg_num}")
        # only from base segment
        refer_segments_base_only = df.query(
            f"story=={story} & seg_1=={base_seg_num} & seg_2<{base_seg_num}")
        # swap seg_1 and seg_2, referenced main events only
        referring_segments = df.query(
            f"story=={story} & main_or_gap=='main' & seg_2>={base_seg_num} & seg_1<{base_seg_num}")

    elif condition in ['r_recall', 'p_recall']:
        refer_segments = df.query(f"story=={story} & seg_1=={segment_num}")
        refer_segments_base_only = refer_segments  # same selection as refer_segments
        referring_segments = df.query(f"story=={story} & seg_2=={segment_num}")

    elif condition == 'r_p':  # only for target segments
        refer_segments = df.query(
            f"story=={story} & (seg_1>={base_seg_num} & seg_2<{base_seg_num}"
            f" | seg_1=={base_seg_num-2} & seg_2=={base_seg_num-1})")
        # Base-only set defined like the 'r' case (references made from the base
        # segment itself); it must exist for the refer_type_base_only lookup below.
        # NOTE(review): confirm this is the intended base-only definition for r_p.
        refer_segments_base_only = df.query(
            f"story=={story} & seg_1=={base_seg_num} & seg_2<{base_seg_num}")
        referring_segments = df.query(
            f"story=={story} & main_or_gap=='main' & (seg_2>={base_seg_num} & seg_1<{base_seg_num}"
            f" | seg_2=={base_seg_num-2} & seg_1=={base_seg_num-1})")

    elif condition in ['p0', 'p']:
        refer_segments = df.query(
            f"story=={story} & seg_1<={base_seg_num} & seg_2>{base_seg_num}")
        refer_segments_base_only = df.query(
            f"story=={story} & seg_1=={base_seg_num} & seg_2>{base_seg_num}")
        referring_segments = df.query(
            f"story=={story} & main_or_gap=='main' & seg_2<={base_seg_num} & seg_1>{base_seg_num}")

    elif condition == 'p_r':
        refer_segments = df.query(
            f"story=={story} & (seg_1<={base_seg_num} & seg_2>{base_seg_num}"
            f" | seg_1=={base_seg_num+2} & seg_2=={base_seg_num+1})")
        # Base-only set defined like the 'p' case (references made from the base
        # segment itself); it must exist for the refer_type_base_only lookup below.
        # NOTE(review): confirm this is the intended base-only definition for p_r.
        refer_segments_base_only = df.query(
            f"story=={story} & seg_1=={base_seg_num} & seg_2>{base_seg_num}")
        referring_segments = df.query(
            f"story=={story} & main_or_gap=='main' & (seg_2<={base_seg_num} & seg_1>{base_seg_num}"
            f" | seg_2=={base_seg_num+2} & seg_1=={base_seg_num+1})")

    else:
        raise ValueError(f"add_refer_type: unhandled condition {condition!r}")

    refer_list = refer_segments['refer_num'].to_numpy()
    refer_list_base_only = refer_segments_base_only['refer_num'].to_numpy()
    refer_list_d = refer_segments.query("type=='d'")['refer_num'].to_numpy()
    refer_list_i = refer_segments.query("type=='i'")['refer_num'].to_numpy()
    referring_list = referring_segments['source_num'].to_numpy()

    # refer_type: direct ('d') takes priority over indirect ('i')
    if ref_sc_num in refer_list_d:
        row['refer_type'] = 'direct'
    elif ref_sc_num in refer_list_i:
        row['refer_type'] = 'indirect'
    else:
        row['refer_type'] = 'none'

    # refer_type_base_only
    row['refer_type_base_only'] = 'referred' if ref_sc_num in refer_list_base_only else 'none'

    # referring_type
    row['referring_type'] = 'referring' if ref_sc_num in referring_list else 'none'

    # refer_referring_type (referred has priority)
    if ref_sc_num in refer_list:
        row['rr_type'] = 'referred'
    elif ref_sc_num in referring_list:
        row['rr_type'] = 'referring'
    else:
        row['rr_type'] = 'none'

    row['refer_list'] = refer_list
    row['referring_list'] = referring_list

    # Distance from ref_seg_num to the nearest referred segment (seg_2, direct and
    # indirect) on each side; 99 marks "no referred segment on that side".
    refer_seg_list = refer_segments['seg_2'].to_numpy()
    refer_seg_list_left = refer_seg_list[refer_seg_list <= ref_seg_num]
    refer_seg_list_right = refer_seg_list[refer_seg_list >= ref_seg_num]

    if len(refer_seg_list_left) == 0:
        row['refer_boost_step_left'] = 99
    else:
        row['refer_boost_step_left'] = ref_seg_num - refer_seg_list_left.max()

    if len(refer_seg_list_right) == 0:
        row['refer_boost_step_right'] = 99
    else:
        row['refer_boost_step_right'] = refer_seg_list_right.min() - ref_seg_num

    row['refer_boost_step'] = np.nanmin([row['refer_boost_step_left'], row['refer_boost_step_right']])

    # refer_boost_type: 'neighbor' when the nearest referred segment lies within
    # 1 (main segments) or 0.5 (gap segments); 0 means the same segment.
    if row['main_or_gap'] in ('main', 'gap'):
        neighbor_max = 1 if row['main_or_gap'] == 'main' else 0.5
        if ref_sc_num in refer_list:
            boost_type, boost_type_tri = 'referred', 'referred'
        elif row['refer_boost_step'] == 0:
            boost_type, boost_type_tri = 'same_seg', 'neighbor'
        elif row['refer_boost_step'] <= neighbor_max:
            boost_type, boost_type_tri = 'neighbor', 'neighbor'
        else:
            boost_type, boost_type_tri = 'none', 'none'
        row['refer_boost_type'] = boost_type
        row['refer_boost_type_tri'] = boost_type_tri

    # correct NA refer type: scenes on the excluded side of the base segment
    # (or the recalled segment itself) are not classified
    if (condition in ['p0', 'p', 'p_r'] and ref_seg_num <= base_seg_num) or \
       (condition in ['r0', 'r', 'r_p'] and ref_seg_num >= base_seg_num) or \
       (condition in ['p_recall', 'r_recall'] and ref_seg_num == segment_num):

        row['refer_type'] = 'NA'
        row['refer_type_base_only'] = 'NA'
        row['referring_type'] = 'NA'
        row['rr_type'] = 'NA'
        row['refer_boost_type'] = 'NA'
        row['refer_boost_type_tri'] = 'NA'

    return row

In [45]:
# Label every (segment, referenced scene) row with its reference-type fields.
seg_sc_pr = seg_sc_pr.apply(add_refer_type, axis='columns')

In [47]:
# add more fields to seg_sc_pr
# new column -> (source column, value mapping); values missing from a mapping (e.g. 'NA') become NaN.
_type_maps = {
    'refer_type_merge': ('refer_type', {'direct':'referred', 'indirect':'referred', 'none':'none'}),
    'refer_type_d': ('refer_type', {'direct':'referred', 'indirect':'none', 'none':'none'}),
    'refer_type_point': ('refer_type', {'direct':1, 'indirect':0.5, 'none':0}),  # deprecated
    'refer_type_point_merge': ('refer_type', {'direct':1, 'indirect':1, 'none':0}),
    'referring_type_point': ('referring_type', {'referring':1, 'none':0}),
}
for _new_col, (_src_col, _mapping) in _type_maps.items():
    seg_sc_pr[_new_col] = seg_sc_pr[_src_col].map(_mapping)

# full point (1) when the tens digit of the scene number is 0 or 5, otherwise half point (0.5)
seg_sc_pr['sc_point'] = np.where(seg_sc_pr['ref_sc_num'] // 10 % 10 % 5 == 0, 1, 0.5)

# point-weighted reference scores:
#   refer_point_c: 0 / 0.25 / 0.5 / 1 (deprecated); refer_point_c_merge: 0 / 0.5 / 1
for _point_col, _type_col in [('refer_point_c', 'refer_type_point'),
                              ('refer_point_c_merge', 'refer_type_point_merge'),
                              ('referring_point_c', 'referring_type_point')]:
    seg_sc_pr[_point_col] = seg_sc_pr[_type_col] * seg_sc_pr['sc_point']

seg_sc_pr['ref_step_base_abs'] = seg_sc_pr['ref_step_base'].abs()

# set color
set_color(seg_sc_pr, 'ref_step_base')

# seg_sc_pr_bounded: drop the first/last scene columns
seg_sc_pr_bounded = seg_sc_pr.query("100<ref_sc_num<1199")
seg_sc_pr_bounded.to_csv("../data/main/seg_sc_pr_bounded.csv", index=False)
seg_sc_pr_bounded.shape

(4782, 34)

## Melt df and merge columns from seg_sc_pr to get single-trial data

In [48]:
# df_pr, df_seg_sc_pr <- df, seg_sc_pr (single sub data)
df_pr = df.query(f"condition in {cond_pr}")

# Long format: one row per (trial, scene column '51'..'1201').
# value_vars gets the column labels explicitly; pd.melt documents it as a list-like
# of labels (passing the DataFrame itself only worked because iterating a DataFrame
# yields its column labels).
df_seg_sc_pr = pd.melt(df_pr, id_vars=['sub','story','segment_num','segment','segment_count',
                        'cond_direction','cond_amount','condition','base_seg_num','base_segment'], 
        value_vars=df.loc[:, '51':'1201'].columns.tolist(), var_name='ref_sc_num', 
                        value_name='ref_sc_hit')

# correct mismatched story-sc pairs: keep only scenes annotated for the row's story
df_seg_sc_pr['ref_sc_num'] = df_seg_sc_pr['ref_sc_num'].astype('int')
story1_sc_nums = df_annot.query("story==1")['num'].to_list()
story2_sc_nums = df_annot.query("story==2")['num'].to_list()
df_seg_sc_pr = df_seg_sc_pr[
    ((df_seg_sc_pr['story'] == 1) & df_seg_sc_pr['ref_sc_num'].isin(story1_sc_nums))
    | ((df_seg_sc_pr['story'] == 2) & df_seg_sc_pr['ref_sc_num'].isin(story2_sc_nums))
].copy()  # explicit copy: the columns added below would otherwise trigger SettingWithCopyWarning

df_seg_sc_pr['ref_seg_num'] = df_seg_sc_pr['ref_sc_num'].apply(num2seg)
df_seg_sc_pr['ref_step'] = df_seg_sc_pr['ref_seg_num'] - df_seg_sc_pr['segment_num']
df_seg_sc_pr['ref_step_base'] = df_seg_sc_pr['ref_seg_num'] - df_seg_sc_pr['base_seg_num']
df_seg_sc_pr['ref_sc_num_full'] = df_seg_sc_pr['story'].astype(str)+'_'+df_seg_sc_pr['ref_sc_num'].astype(str)

# Keep scenes strictly on the condition's side of the base segment
# (forward 'f': after base; backward 'b': before base).
_forward = (df_seg_sc_pr['cond_direction'] == 'f') & (df_seg_sc_pr['ref_seg_num'] > df_seg_sc_pr['base_seg_num'])
_backward = (df_seg_sc_pr['cond_direction'] == 'b') & (df_seg_sc_pr['ref_seg_num'] < df_seg_sc_pr['base_seg_num'])
df_seg_sc_pr = df_seg_sc_pr[_forward | _backward].reset_index(drop=True)

# merge cols from seg_sc_pr
cols_for_merge = ['condition','segment','ref_sc_num','ref_seg_type', 'main_or_gap', 
       'refer_type','referring_type','rr_type','refer_list','referring_list','refer_boost_step_left', 'refer_boost_step_right',
       'refer_boost_step', 'refer_boost_type','refer_boost_type_tri', 'refer_type_merge',
       'refer_type_d', 'refer_type_point', 'refer_type_point_merge',
       'sc_point', 'refer_point_c', 'refer_point_c_merge', 'ref_step_base_abs',
       'color']

df_seg_sc_pr = pd.merge(df_seg_sc_pr, seg_sc_pr[cols_for_merge],
                on=['condition','segment','ref_sc_num'], how='left')

df_seg_sc_pr_target = df_seg_sc_pr.query("ref_seg_type=='target'")
# df_seg_sc_pr_target.to_csv("../data/main/R/df_seg_sc_pr_target.csv", index=False)
df_seg_sc_pr_target.shape

(4289, 37)

In [49]:
# Baseline conditions (p0 / r0) only, then keep scenes with 100 < ref_sc_num < 1199.
df_seg_sc_p0r0 = df_seg_sc_pr[df_seg_sc_pr['condition'].isin(['p0', 'r0'])]
df_seg_sc_p0r0_bounded = df_seg_sc_p0r0.query("100<ref_sc_num<1199")
df_seg_sc_p0r0_bounded.shape

(18432, 37)

In [50]:
# df_seg_seg_type_p0r0_for_R <- df_seg_sc_p0r0_bounded
# Per-subject totals for each (trial segment, referenced segment, refer_type_merge) cell,
# exported for analysis in R. The -5.5 / -3.25 offsets are fixed centring constants
# (presumably the predictors' means) — TODO confirm.
df_seg_seg_type_p0r0_for_R = (
    df_seg_sc_p0r0_bounded
    .groupby(['sub', 'condition', 'story', 'segment', 'segment_count', 'base_seg_num',
              'base_segment', 'ref_seg_num', 'ref_step_base', 'ref_step_base_abs',
              'ref_seg_type', 'main_or_gap', 'refer_type_merge'])
    .agg(total_count=('condition', 'count'),
         total_points=('sc_point', 'sum'),
         total_hits=('ref_sc_hit', 'sum'))
    .reset_index()
    .assign(
        base_refer_seg_pair=lambda d: d['base_segment'] + '_' + d['ref_seg_num'].astype('str'),
        segment_count_demean=lambda d: d['segment_count'] - 5.5,
        ref_step_base_abs_demean=lambda d: d['ref_step_base_abs'] - 3.25,
    )
)

df_seg_seg_type_p0r0_for_R.to_csv('../data/main/R/df_seg_seg_type_p0r0_for_R.csv', index=False)

In [52]:
# df_seg_seg_rr_type_p0r0_for_R <- df_seg_sc_p0r0_bounded
# Same per-subject totals split by rr_type (referred / referring / none); only main
# segments are exported for analysis in R. The -5.5 / -3.25 offsets are fixed
# centring constants (presumably the predictors' means) — TODO confirm.
df_seg_seg_rr_type_p0r0_for_R = (
    df_seg_sc_p0r0_bounded
    .groupby(['sub', 'condition', 'story', 'segment', 'segment_count', 'base_seg_num',
              'base_segment', 'ref_seg_num', 'ref_step_base', 'ref_step_base_abs',
              'ref_seg_type', 'main_or_gap', 'rr_type'])
    .agg(total_count=('condition', 'count'),
         total_points=('sc_point', 'sum'),
         total_hits=('ref_sc_hit', 'sum'))
    .reset_index()
    .assign(
        base_refer_seg_pair=lambda d: d['base_segment'] + '_' + d['ref_seg_num'].astype('str'),
        segment_count_demean=lambda d: d['segment_count'] - 5.5,
        ref_step_base_abs_demean=lambda d: d['ref_step_base_abs'] - 3.25,
    )
)

df_seg_seg_rr_type_p0r0_main_for_R = df_seg_seg_rr_type_p0r0_for_R.query("main_or_gap=='main'")

df_seg_seg_rr_type_p0r0_main_for_R.to_csv('../data/main/R/df_seg_seg_rr_type_p0r0_main_for_R.csv', index=False)
df_seg_seg_rr_type_p0r0_main_for_R.shape

(2496, 19)

In [53]:
# df_seg_seg_boost_type_p0r0_for_R <- df_seg_sc_p0r0_bounded
# Same per-subject totals split by refer_boost_type / refer_boost_type_tri, exported for
# analysis in R. The -5.5 / -3.25 offsets are fixed centring constants (presumably the
# predictors' means) — TODO confirm.
df_seg_seg_boost_type_p0r0_for_R = (
    df_seg_sc_p0r0_bounded
    .groupby(['sub', 'condition', 'story', 'segment', 'segment_count', 'base_segment',
              'base_seg_num', 'ref_seg_num', 'ref_step_base', 'ref_step_base_abs',
              'ref_seg_type', 'main_or_gap', 'refer_boost_type', 'refer_boost_type_tri'])
    .agg(total_count=('condition', 'count'),
         total_points=('sc_point', 'sum'),
         total_hits=('ref_sc_hit', 'sum'))
    .reset_index()
    .assign(
        base_refer_seg_pair=lambda d: d['base_segment'] + '_' + d['ref_seg_num'].astype('str'),
        segment_count_demean=lambda d: d['segment_count'] - 5.5,
        ref_step_base_abs_demean=lambda d: d['ref_step_base_abs'] - 3.25,
    )
)

df_seg_seg_boost_type_p0r0_for_R.to_csv('../data/main/R/df_seg_seg_boost_type_p0r0_for_R.csv', index=False)

## more dfs

In [54]:
# seg_seg_pr (average scenes in same segments in same trials) <- seg_sc_pr_bounded
# Shared grouping keys; each aggregate adds its own split column followed by 'color'.
_seg_keys = ['cond_direction', 'condition', 'story', 'segment', 'base_seg_num', 'ref_seg_num',
             'ref_step_base', 'ref_step_base_abs', 'ref_seg_type', 'main_or_gap']

seg_seg_pr = seg_sc_pr_bounded.groupby(_seg_keys + ['color']).agg(
    total_count=('condition', 'count'),
    total_points=('sc_point', 'sum'),
    total_hits=('ref_sc_hit', 'sum'),
    refer_points=('refer_point_c', 'sum'),
    refer_points_merge=('refer_point_c_merge', 'sum'),
    referring_points=('referring_point_c', 'sum'),
).reset_index()

# seg_seg_type_pr: split by refer_type_merge, with corrected hit rate = hits / points.
# Rows whose refer_type_merge is NaN (refer_type 'NA') are dropped by groupby's default dropna.
seg_seg_type_pr = (
    seg_sc_pr_bounded
    .groupby(_seg_keys + ['refer_type_merge', 'color'])
    .agg(total_count=('condition', 'count'),
         total_points=('sc_point', 'sum'),
         total_hits=('ref_sc_hit', 'sum'))
    .reset_index()
    .assign(sc_hit_rate_c=lambda d: d['total_hits'] / d['total_points'])
)

# seg_seg_boost_type_pr: split by refer_boost_type_tri, with corrected hit rate.
seg_seg_boost_type_pr = (
    seg_sc_pr_bounded
    .groupby(_seg_keys + ['refer_boost_type_tri', 'color'])
    .agg(total_count=('condition', 'count'),
         total_points=('sc_point', 'sum'),
         total_hits=('ref_sc_hit', 'sum'))
    .reset_index()
    .assign(sc_hit_rate_c=lambda d: d['total_hits'] / d['total_points'])
)


def add_addi_fields(row, df1=seg_seg_type_pr, df2=seg_seg_boost_type_pr):
    """Attach lookups from the type-split aggregates to one seg_seg_pr row.

    For the row's (condition, segment, ref_step_base) cell it adds, using the first
    match when several exist:
      - none_hits: total_hits of unreferenced scenes (df1, refer_type_merge == 'none'),
        NaN if absent;
      - remain_hit_rate: sc_hit_rate_c of scenes neither referred nor neighbouring
        (df2, refer_boost_type_tri == 'none'), NaN if absent;
      - neighbor_points: total_points of neighbouring scenes
        (df2, refer_boost_type_tri == 'neighbor'), 0 if absent.
    Returns ``row``.
    """
    def _first_value(frame, type_col, type_value, value_col, default):
        matches = frame.loc[(frame['condition'] == row['condition'])
                            & (frame['segment'] == row['segment'])
                            & (frame['ref_step_base'] == row['ref_step_base'])
                            & (frame[type_col] == type_value), value_col].values
        return matches[0] if len(matches) > 0 else default

    row['none_hits'] = _first_value(df1, 'refer_type_merge', 'none', 'total_hits', np.nan)
    row['remain_hit_rate'] = _first_value(df2, 'refer_boost_type_tri', 'none', 'sc_hit_rate_c', np.nan)
    # for neighboring rate
    row['neighbor_points'] = _first_value(df2, 'refer_boost_type_tri', 'neighbor', 'total_points', 0)
    return row

# Per trial-segment rates, all relative to the summed scene points.
seg_seg_pr = seg_seg_pr.apply(add_addi_fields, axis=1).assign(
    refer_rate=lambda d: d['refer_points_merge'] / d['total_points'],
    referring_rate=lambda d: d['referring_points'] / d['total_points'],
    none_points=lambda d: d['total_points'] - d['refer_points_merge'],
    neighbor_rate=lambda d: d['neighbor_points'] / d['none_points'],
)
_is_self = seg_seg_pr['ref_step_base'] == 0
seg_seg_pr.loc[_is_self, 'neighbor_rate'] = 0
seg_seg_pr['hit_rate_c'] = seg_seg_pr['total_hits'] / seg_seg_pr['total_points']
# hit rate among unreferenced scenes (same quantity as in seg_seg_type_pr)
seg_seg_pr['none_hit_rate'] = seg_seg_pr['none_hits'] / seg_seg_pr['none_points']

# seg_seg_step_pr (average over step): segment-level means per condition x lag, plus differences.
seg_seg_step_pr = (
    seg_seg_pr
    .groupby(['cond_direction', 'condition', 'ref_step_base', 'ref_step_base_abs',
              'ref_seg_type', 'main_or_gap', 'color'])
    .agg(remain_hit_rate=('remain_hit_rate', 'mean'),
         none_hit_rate=('none_hit_rate', 'mean'),
         hit_rate_c=('hit_rate_c', 'mean'),
         none_points=('none_points', 'mean'),
         neighbor_points=('neighbor_points', 'mean'))
    .reset_index()
    .assign(
        neighbor_rate=lambda d: d['neighbor_points'] / d['none_points'],
        refer_bonus_hit_rate=lambda d: d['hit_rate_c'] - d['none_hit_rate'],
        neighbor_bonus_hit_rate=lambda d: d['none_hit_rate'] - d['remain_hit_rate'],
        refer_combo_bonus_hit_rate=lambda d: d['hit_rate_c'] - d['remain_hit_rate'],
    )
)
seg_seg_step_pr.loc[seg_seg_step_pr['ref_step_base'] == 0, 'neighbor_rate'] = 0

# for plotting points in step 0: referred points everywhere, all points at lag 0
seg_seg_pr['points'] = seg_seg_pr['refer_points_merge']
seg_seg_pr.loc[_is_self, 'points'] = seg_seg_pr.loc[_is_self, 'total_points']

## figure 7C

In [55]:
# referring rate for main segments (seg_seg_pr: mean(referring_rate)) — figure 7C
# Bound to `referring_rate_plot`, not `ref_rate_plot`: figure 5B reuses that name and the
# stacked 5B-D figure concatenates it, and a chart that already carries config (this one
# is configured) cannot be placed inside a concat.

w, h = 400, 180

base = alt.Chart(seg_seg_pr.query("condition==['p0','r0'] and main_or_gap=='main'")).encode(
    alt.X('ref_step_base:O', axis=alt.Axis(title='Lag', labelAngle=0))
).properties(width=w, height=h, title={
      "text": [""], 
      "subtitle": ["Referring rate"]
    })  # cannot be interactive

bar = base.mark_bar(size=12).encode(
    alt.Y('mean(referring_rate):Q', axis=alt.Axis(title='Proportion'), 
#           scale=alt.Scale(domain=(0, 0.7))
         ), 
    alt.Color('color', scale=None),
)

# CI whiskers for the plotted quantity (referring_rate); not layered into the figure.
error_bar = base.mark_rule(size=1.6, color='black', opacity=0.4).encode(
    y='ci0(referring_rate)',
    y2='ci1(referring_rate)',
)

referring_rate_plot = (bar
).configure_axis(
    titleFontSize=16,
    labelFontSize=16,
    titleFontWeight='normal'
).configure_title(
    subtitleFontSize=18,
    anchor='middle',
    offset=0
)
referring_rate_plot

## figure 5B

In [56]:
# 1-1 refer rate (seg_seg_pr: mean(refer_rate)) — figure 5B
w, h = 400, 180
# wider alternative: w, h = 600, 180

base = (
    alt.Chart(seg_seg_pr.query("condition==['p0','r0']"))
    .encode(alt.X('ref_step_base:O',
                  axis=alt.Axis(title='Lag', labelAngle=0, values=list(range(-10, 11)))))
    .properties(width=w, height=h,
                title={"text": [""], "subtitle": ["Reference rate"]})  # cannot be interactive
)

bar = base.mark_bar(size=8).encode(
    alt.Y('mean(refer_rate):Q', axis=alt.Axis(title='Proportion')),
    alt.Color('color', scale=None),
)

# CI whiskers (defined but not layered into the figure)
error_bar = base.mark_rule(size=1.6, color='black', opacity=0.4).encode(
    y='ci0(refer_rate)',
    y2='ci1(refer_rate)',
)

ref_rate_plot = bar
ref_rate_plot

## figure 5C

In [60]:
# 1-2 refer bonus hit rate (seg_seg_step_pr: refer_bonus_hit_rate) — figure 5C
seg_seg_step_pr = seg_seg_step_pr.fillna(0)

refer_bonus_hit_rate_plot = (
    alt.Chart(seg_seg_step_pr.query("condition==['r0','p0']"))
    .mark_bar(size=8)
    .encode(
        alt.X('ref_step_base:O',
              axis=alt.Axis(title='Lag', labelAngle=0, values=list(range(-10, 11)))),
        alt.Y('refer_bonus_hit_rate:Q', axis=alt.Axis(title=['Hit rate']),
              scale=alt.Scale(domain=(-0.01, 0.06))),
        alt.Color('color', scale=None),
    )
    .properties(width=w, height=h, title={
        "text": [""],
        "subtitle": ["Hit rate difference:", "all events - unreferenced events"],
    })
    .interactive()
)

refer_bonus_hit_rate_plot

## figure 5D

In [62]:
# 1-3 none hit rate (seg_seg_pr: mean(none_hit_rate)) — figure 5D
# NOTE(review): fillna(0) turns a missing none_hit_rate (no unreferenced-event row for
# that trial segment, e.g. lag 0) into 0, which enters mean(none_hit_rate) below,
# whereas seg_seg_step_pr (figure 5C) averaged with NaN skipped. It also changes
# seg_seg_pr for the later figure cells. Confirm this is intended.
seg_seg_pr = seg_seg_pr.fillna(0)

base = (
    alt.Chart(seg_seg_pr.query("condition==['p0','r0']"))
    .encode(alt.X('ref_step_base:O',
                  axis=alt.Axis(title='Lag', labelAngle=0, values=list(range(-10, 11)))))
    .properties(width=w, height=h,
                title={"text": [""], "subtitle": ["Unreferenced events hit rate"]})
)

bar = base.mark_bar(size=8).encode(
    alt.Y('mean(none_hit_rate):Q', axis=alt.Axis(title='Hit rate')),
    alt.Color('color', scale=None),
)

error_bar = base.mark_rule(size=1.6, color='black', opacity=0.4).encode(
    y='ci0(none_hit_rate)',
    y2='ci1(none_hit_rate)',
)

none_hit_rate_plot = bar + error_bar
none_hit_rate_plot

## figure 6B

In [63]:
# 2-1 neighbor rate (seg_seg_pr: mean(neighbor_rate)) — figure 6B

base = alt.Chart(seg_seg_pr.query("condition==['p0','r0']")).encode(
    alt.X('ref_step_base:O', axis=alt.Axis(title='Lag', labelAngle=0, values=list(range(-10, 11)))),
).properties(width=w, height=h, title={
      "text": [""], 
      "subtitle": ["Adjacent-reference rate"]
    })

bar = base.mark_bar(size=8).encode(
    alt.Y('mean(neighbor_rate):Q', axis=alt.Axis(title='Proportion'), 
          scale=alt.Scale(domain=(0, 1))
         ), 
    alt.Color('color', scale=None),
)

# CI whiskers (not layered into the figure). A rule needs both y and y2; with only
# y2 the layer is not a valid range rule.
error_bar = base.mark_rule(size=1.6, color='black', opacity=0.4).encode(
    y='ci0(neighbor_rate)',
    y2='ci1(neighbor_rate)',
)

neighbor_rate_plot = bar
neighbor_rate_plot

## figure 6C

In [64]:
# 2-2 neighbor bonus hit rate (seg_seg_step_pr: neighbor_bonus_hit_rate) — figure 6C

neighbor_bonus_hit_rate_plot = (
    alt.Chart(seg_seg_step_pr.query("condition==['r0','p0']"))
    .mark_bar(size=8)
    .encode(
        alt.X('ref_step_base:O',
              axis=alt.Axis(title='Lag', labelAngle=0, values=list(range(-10, 11)))),
        alt.Y('neighbor_bonus_hit_rate:Q', axis=alt.Axis(title='Hit rate'),
              scale=alt.Scale(domain=(-0.10, 0.10))),
        alt.Color('color', scale=None),
    )
    .properties(width=w, height=h, title={
        "text": [""],
        "subtitle": ["Hit rate difference:", "Unreferenced events - remaining events"],
    })
)

neighbor_bonus_hit_rate_plot

## figure 6D

In [66]:
# 2-3 remaining events hit rate (seg_seg_pr: mean(remain_hit_rate)) — figure 6D

base = (
    alt.Chart(seg_seg_pr.query("condition==['p0','r0']"))
    .encode(alt.X('ref_step_base:O',
                  axis=alt.Axis(title='Lag', labelAngle=0, values=list(range(-10, 11)))))
    .properties(width=w, height=h,
                title={"text": [""], "subtitle": ["Remaining events hit rate"]})
)

bar = base.mark_bar(size=8).encode(
    alt.Y('mean(remain_hit_rate):Q', axis=alt.Axis(title='Hit rate'),
          scale=alt.Scale(domain=(0, 0.1))),
    alt.Color('color', scale=None),
)

error_bar = base.mark_rule(size=1.6, color='black', opacity=0.4).encode(
    y='ci0(remain_hit_rate)',
    y2='ci1(remain_hit_rate)',
)

remain_hit_rate_plot = bar + error_bar
remain_hit_rate_plot

In [67]:
# Stack the reference-rate, reference-bonus and unreferenced-hit-rate panels
# vertically and apply shared axis/title/spacing styling.
hit_rate_concat_plot = (
    (ref_rate_plot & refer_bonus_hit_rate_plot & none_hit_rate_plot)
    .configure_axis(titleFontWeight='normal', titleFontSize=16, labelFontSize=16)
    .configure_title(anchor='middle', offset=0, subtitleFontSize=18)
    .configure_concat(spacing=15)
)

# hit_rate_concat_plot.save('../figs/core/hit_rate_concat_plot.svg')
hit_rate_concat_plot

In [68]:
# Stack the adjacent-reference rate, neighbor-bonus and remaining-hit-rate panels.
# Bound to its own name (and, if saved, its own file) so it does not overwrite
# `hit_rate_concat_plot` / hit_rate_concat_plot.svg from the previous cell.
neighbor_hit_rate_concat_plot = (neighbor_rate_plot & neighbor_bonus_hit_rate_plot & remain_hit_rate_plot
).configure_axis(
    titleFontSize=16,
    labelFontSize=16,
    titleFontWeight='normal'
).configure_title(
    subtitleFontSize=18,
    anchor='middle',
    offset=0
).configure_concat(
    spacing=15
)

# neighbor_hit_rate_concat_plot.save('../figs/core/neighbor_hit_rate_concat_plot.svg')
neighbor_hit_rate_concat_plot

## More per-step summary dataframes

In [69]:
# step_pr <- seg_sc_pr_bounded: per-step totals plus reference and hit rates
# corrected by the summed `sc_point` weight (points_count).
step_pr = (
    seg_sc_pr_bounded
    .groupby(['cond_direction', 'condition', 'ref_step_base', 'ref_step_base_abs',
              'ref_seg_type', 'main_or_gap', 'color'])
    .agg(total_count=('condition', 'count'),
         points_count=('sc_point', 'sum'),
         total_points=('ref_sc_hit', 'sum'),
         total_refer_points=('refer_point_c', 'sum'),
         total_refer_points_merge=('refer_point_c_merge', 'sum'),
         total_referring_points=('referring_point_c', 'sum'))
    .reset_index()
)
step_pr = step_pr.assign(
    refer_rate_c=step_pr['total_refer_points'] / step_pr['points_count'],
    refer_rate_c_merge=step_pr['total_refer_points_merge'] / step_pr['points_count'],
    sc_hit_rate_c=step_pr['total_points'] / step_pr['points_count'],
)

step_pr.head()

Unnamed: 0,cond_direction,condition,ref_step_base,ref_step_base_abs,ref_seg_type,main_or_gap,color,total_count,points_count,total_points,total_refer_points,total_refer_points_merge,total_referring_points,refer_rate_c,refer_rate_c_merge,sc_hit_rate_c
0,b,r,-10.0,10.0,ahead_main,main,#fc7d0b,8,7.5,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,b,r,-9.5,9.5,ahead_gap,gap,#ffd6ae,13,13.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,b,r,-9.0,9.0,ahead_main,main,#fc7d0b,21,19.5,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,b,r,-8.5,8.5,ahead_gap,gap,#ffd6ae,18,18.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,b,r,-8.0,8.0,ahead_main,main,#fc7d0b,31,29.5,0.125,0.5,1.0,0.0,0.016949,0.033898,0.004237


In [71]:
# Per-step hit counts of seg_sc_pr_bounded (lag 0 excluded), each split by a
# different event-type column:
#   step_type_pr           <- refer_type_merge
#   step_rr_type_pr        <- rr_type (categories: referred, referring, none)
#   step_type3_pr          <- refer_type (direct, indirect, none)
#   step_boost_type_pr     <- refer_boost_type (referred, same_seg, neighbor, none)
#   step_boost_type_tri_pr <- refer_boost_type_tri (referred, neighbor, none)

def _step_counts_by_type(type_col):
    """Aggregate seg_sc_pr_bounded (ref_step_base != 0) per step and `type_col`.

    Returns one row per (cond_direction, condition, ref_step_base,
    ref_step_base_abs, ref_seg_type, main_or_gap, type_col) group with
    total_count (row count), points_count (sum of sc_point), total_points
    (sum of ref_sc_hit) and the corrected hit rate
    sc_hit_rate_c = total_points / points_count.
    """
    counts = (
        seg_sc_pr_bounded.query("ref_step_base !=0 ")
        .groupby(['cond_direction', 'condition', 'ref_step_base', 'ref_step_base_abs',
                  'ref_seg_type', 'main_or_gap', type_col])
        .agg(total_count=('condition', 'count'),
             points_count=('sc_point', 'sum'),
             total_points=('ref_sc_hit', 'sum'))
        .reset_index()
    )
    # corrected hit rate
    counts['sc_hit_rate_c'] = counts['total_points'] / counts['points_count']
    return counts


step_type_pr = _step_counts_by_type('refer_type_merge')

step_rr_type_pr = _step_counts_by_type('rr_type')
step_rr_type_pr['rr_type'] = (step_rr_type_pr['rr_type'].astype('category')
                              .cat.set_categories(['referred', 'referring', 'none']))

step_type3_pr = _step_counts_by_type('refer_type')

# `cat.reorder_categories(..., inplace=True)` is deprecated (removed in pandas 2.0);
# assign the reordered column back instead. The category order drives the sort.
step_boost_type_pr = _step_counts_by_type('refer_boost_type')
step_boost_type_pr['refer_boost_type'] = (
    step_boost_type_pr['refer_boost_type'].astype('category')
    .cat.reorder_categories(['referred', 'same_seg', 'neighbor', 'none']))
step_boost_type_pr = step_boost_type_pr.sort_values(
    by=['condition', 'ref_step_base_abs', 'refer_boost_type'])

step_boost_type_tri_pr = _step_counts_by_type('refer_boost_type_tri')
step_boost_type_tri_pr['refer_boost_type_tri'] = (
    step_boost_type_tri_pr['refer_boost_type_tri'].astype('category')
    .cat.reorder_categories(['referred', 'neighbor', 'none']))
step_boost_type_tri_pr = step_boost_type_tri_pr.sort_values(
    by=['condition', 'ref_step_base_abs', 'refer_boost_type_tri'])

  step_boost_type_pr['refer_boost_type'].cat.reorder_categories(['referred','same_seg','neighbor','none'], inplace=True)
  step_boost_type_tri_pr['refer_boost_type_tri'].cat.reorder_categories(['referred','neighbor','none'], inplace=True)


In [73]:
# add more fields for plotting

def add_cols_for_plotting(df, type_col, value_col, count_col, doubled=True):
    """Add the rectangle coordinates used by `area_plot` to a per-step summary frame.

    Rows with ``ref_step_base == 0`` are dropped. The remaining rows are sorted by
    condition, absolute lag and ``type_col``. Each row then gets a horizontal span
    ``[x0_c, x1_c]`` of width ``count_col`` that restarts at 0 for every
    (condition, ref_step_base) pair, and a vertical span ``[y0, y1] = [0, value_col]``.

    Parameters
    ----------
    df : pd.DataFrame
        Needs columns condition, ref_step_base, ref_step_base_abs, ``type_col``,
        ``value_col`` and ``count_col``. It is not modified.
    type_col : str
        Event-type column. ``'refer_type_merge'`` is sorted descending (so
        'referred' precedes 'none'); any other type column is sorted ascending
        (for categoricals, in category order).
    value_col : str
        Rate drawn as the height of each rectangle (e.g. a hit rate).
    count_col : str
        Column whose values give each rectangle's width.
    doubled : bool, default True
        If True, append a copy of every row spanning ``[value_col, 1]`` and add a
        ``f'{type_col}_plus'`` label column (``<type>_hit`` for the original rows,
        ``<type>_miss`` for the copies). The result is re-indexed 0..n-1.

    Returns
    -------
    pd.DataFrame
        New frame with added columns x0, x1, x0_c, x1_c, y0, y1 (plus the
        ``_plus`` label column when ``doubled``). Categorical columns are
        returned with plain object dtype.
    """
    # discard step=0; copy so the column assignments below never act on a view
    df = df.query("ref_step_base != 0").copy()

    type_ascending = type_col != 'refer_type_merge'
    df = df.sort_values(by=['condition', 'ref_step_base_abs', type_col],
                        ascending=[True, True, type_ascending])

    # Lay the rectangles end to end along x, in sorted order.
    x1 = df[count_col].cumsum()
    df['x0'] = x1 - df[count_col]
    df['x1'] = x1

    # Align x coordinates so each (condition, step) panel starts at 0: subtract
    # the x0 of the panel's first row in sorted order.
    panel_start = (df.groupby(['condition', 'ref_step_base'], sort=False)['x0']
                     .transform('first'))
    df['x0_c'] = df['x0'] - panel_start
    df['x1_c'] = df['x1'] - panel_start

    # Categorical order was needed for the sort above; convert to plain object
    # dtype now because the '_hit'/'_miss' string concatenation below is not
    # supported on categoricals.
    cat_cols = df.select_dtypes(include='category').columns
    df = df.astype({col: object for col in cat_cols})

    df['y0'] = 0
    df['y1'] = df[value_col]

    if doubled:
        miss = df.copy()
        df[f'{type_col}_plus'] = df[type_col] + '_hit'
        miss[f'{type_col}_plus'] = miss[type_col] + '_miss'

        # The miss band fills the rest of the bar, from the rate up to 1.
        miss['y0'] = miss[value_col]
        miss['y1'] = 1

        df = pd.concat([df, miss], ignore_index=True)
    return df

In [74]:
# Rectangle coordinates for the area plots. Each `_doubled` frame holds a hit
# row and a miss row per group. The single-layer calls (doubled=False) rebind
# the input summary names, so each `_doubled` frame is computed first.
step_type_pr_doubled = add_cols_for_plotting(
    step_type_pr, 'refer_type_merge', 'sc_hit_rate_c', 'points_count')
step_type_pr = add_cols_for_plotting(
    step_type_pr, 'refer_type_merge', 'sc_hit_rate_c', 'points_count', doubled=False)

step_rr_type_pr_doubled = add_cols_for_plotting(
    step_rr_type_pr, 'rr_type', 'sc_hit_rate_c', 'points_count')
step_rr_type_pr = add_cols_for_plotting(
    step_rr_type_pr, 'rr_type', 'sc_hit_rate_c', 'points_count', doubled=False)

step_boost_type_pr_doubled = add_cols_for_plotting(
    step_boost_type_pr, 'refer_boost_type', 'sc_hit_rate_c', 'points_count')
step_boost_type_pr = add_cols_for_plotting(
    step_boost_type_pr, 'refer_boost_type', 'sc_hit_rate_c', 'points_count', doubled=False)

step_boost_type_tri_pr_doubled = add_cols_for_plotting(
    step_boost_type_tri_pr, 'refer_boost_type_tri', 'sc_hit_rate_c', 'points_count')

In [75]:
# Legend labels and colours for the area plots. area_plot pairs each label
# map with its colour map, so both must cover the same keys.

# --- refer type (referenced vs unreferenced) ---
type_long_map = dict(
    referred_miss='Referenced: Miss',
    referred_hit='Referenced: Hit',
    none_miss='Unreferenced: Miss',
    none_hit='Unreferenced: Hit',
)
type_color_map_double = dict(
    referred_miss=Category20b[20][15],
    referred_hit='#CC0033',
    none_miss=Category20c[20][18],
    none_hit=Category20c[20][16],
)
step_type_pr_doubled['refer_type_merge_plus_long'] = (
    step_type_pr_doubled['refer_type_merge_plus'].map(type_long_map))

# --- rr type (referenced / referring / other) ---
rr_type_long_map = dict(
    referred_miss='Referenced: Miss',
    referred_hit='Referenced: Hit',
    referring_miss='Referring: Miss',
    referring_hit='Referring: Hit',
    none_miss='Other: Miss',
    none_hit='Other: Hit',
)
rr_type_color_map_double = dict(
    referred_miss=Category20b[20][15],
    referred_hit='#CC0033',
    referring_miss=Category20c[20][11],
    referring_hit=Category20c[20][8],
    none_miss=Category20c[20][18],
    none_hit=Category20c[20][16],
)
step_rr_type_pr_doubled['rr_type_plus_long'] = (
    step_rr_type_pr_doubled['rr_type_plus'].map(rr_type_long_map))

# --- refer boost type (4 classes) and its 3-class variant ---
boost_type_long_map = dict(
    referred_miss='Referenced: Miss',
    referred_hit='Referenced: Hit',
    same_seg_miss='Same-segment event referenced: Miss',
    same_seg_hit='Same-segment event referenced: Hit',
    neighbor_miss='Neighbor-segment event referenced: Miss',
    neighbor_hit='Neighbor-segment event referenced: Hit',
    none_miss='Other: Miss',
    none_hit='Other: Hit',
)
boost_type_tri_long_map = dict(
    referred_miss='Referenced: Miss',
    referred_hit='Referenced: Hit',
    neighbor_miss='Reference-adjacent: Miss',
    neighbor_hit='Reference-adjacent: Hit',
    none_miss='Remaining: Miss',
    none_hit='Remaining: Hit',
)
boost_type_color_map_double = dict(
    referred_miss=Category20b[20][15],
    referred_hit=Category20b[20][12],
    same_seg_miss=Category20c[20][7],
    same_seg_hit=Category20c[20][4],
    neighbor_miss=Category20c[20][3],
    neighbor_hit=Category20c[20][0],
    none_miss=Category20c[20][19],
    none_hit=Category20c[20][16],
)
boost_type_tri_color_map_double = dict(
    referred_miss=Category20b[20][15],
    referred_hit='#CC0033',
    neighbor_miss=Purples[9][5],
    neighbor_hit=Category20[20][8],
    none_miss=Category20c[20][19],
    none_hit=Category20c[20][17],
)
# Single-colour (fill) and stroke variants keyed by the base type.
boost_type_tri_color_map = dict(
    referred=Category20c[20][4],
    neighbor=Category20c[20][0],
    none=Category20c[20][16],
)
boost_type_tri_color_map_stroke = dict(
    referred=Pastel2[8][1],
    neighbor=Category20c[20][3],
    none=Category20c[20][19],
)

step_boost_type_pr_doubled['refer_boost_type_plus_long'] = (
    step_boost_type_pr_doubled['refer_boost_type_plus'].map(boost_type_long_map))
step_boost_type_tri_pr_doubled['refer_boost_type_tri_plus_long'] = (
    step_boost_type_tri_pr_doubled['refer_boost_type_tri_plus'].map(boost_type_tri_long_map))

## figure 5E

In [76]:
# function
def area_plot(df, color, scale, value, width=40.1, height=230, max_lag=8):
    """Faceted hit/miss area plot for the uncued conditions.

    Draws one rectangle per row of ``df`` using the coordinates produced by
    ``add_cols_for_plotting`` (x0_c/x1_c horizontally, y0/y1 vertically),
    faceted into one column per lag (``ref_step_base``). The uncued
    retrodiction panel ('r0', lags >= -max_lag, facets in descending lag
    order) is stacked above the uncued prediction panel ('p0', lags <= max_lag).

    Parameters
    ----------
    df : pd.DataFrame
        Plot frame with condition, ref_step_base, x0_c, x1_c, y0, y1 and the
        ``color`` column.
    color : str
        Column holding the legend labels (the values of ``scale``).
    scale : dict
        Maps type keys (e.g. 'referred_hit') to legend labels; its order sets
        the legend order.
    value : dict
        Maps the same type keys to colours. Every key of ``scale`` must be present.
    width : float, default 40.1
        Width of each lag facet.
    height : int, default 230
        Height of each facet.
    max_lag : int, default 8
        Largest absolute lag shown.

    Returns
    -------
    alt.VConcatChart
        The configured chart (axis, view, concat, legend and title styling applied).

    Raises
    ------
    KeyError
        If ``value`` lacks a key that ``scale`` has.
    """
    # Pair each legend label with its colour by key, so the colour a label gets
    # does not depend on `scale` and `value` listing their keys in the same order.
    labels = list(scale.values())
    colors = [value[key] for key in scale]
    # Two legend rows: e.g. 4 entries -> 2 columns, 6 entries -> 3 columns.
    legend_cols = max(1, len(scale) // 2)

    def _panel(query, subtitle, x_axis, color_kwargs=None, column_kwargs=None):
        # One row of lag facets for the rows of `df` selected by `query`.
        return alt.Chart(df.query(query)).mark_rect().encode(
            x=alt.X('x0_c:Q', axis=x_axis),
            x2='x1_c',
            y=alt.Y('y0:Q', axis=alt.Axis(title='Hit rate')),
            y2='y1',
            color=alt.Color(color, scale=alt.Scale(domain=list(labels), range=list(colors)),
                            **(color_kwargs or {})),
            column=alt.Column('ref_step_base:N', title='Lag',
                              header=alt.Header(labelOrient='top', titleFontSize=16,
                                                titleFontWeight='normal', titlePadding=0,
                                                labelFontSize=16),
                              **(column_kwargs or {})),
        ).properties(
            width=width, height=height, title={
                "text": [""],
                "subtitle": [subtitle]
            })

    # Uncued prediction; this panel carries the (explicitly positioned) legend.
    p0 = _panel(f"condition=='p0' and ref_step_base <= {max_lag}",
                'Uncued prediction (u-P)',
                x_axis=alt.Axis(title='', titleFontSize=7),
                color_kwargs={'legend': alt.Legend(direction='vertical', orient='none',
                                                   title="Event type: Response type")})

    # Uncued retrodiction; facets ordered by descending lag.
    r0 = _panel(f"condition=='r0' and ref_step_base >= {-max_lag}",
                'Uncued retrodiction (u-R)',
                x_axis=alt.Axis(title=''),
                column_kwargs={'sort': 'descending'})

    type_hit_plot = alt.vconcat(r0, p0, title=''
    ).configure_axis(
        grid=False, 
        titleFontWeight='normal',
        titleFontSize=16,
        labelFontSize=16,
    ).configure_view(
        strokeWidth=0
    ).configure_concat(
        spacing=50
    ).configure_legend( 
        legendX=0,
        legendY=695,
        rowPadding=4,
        columnPadding=20,
        titlePadding=10,
        symbolSize=200,  
        labelFontSize=16, 
        titleFontWeight='normal',
        titleFontSize=20,
        labelLimit=1000,
        titleLimit=1000,
        columns=legend_cols
    ).configure_title(
        subtitleFontSize=20,
        anchor='start',
        offset=0
    )

    return type_hit_plot

# Figure 5E: referenced vs unreferenced events (2 event types).
# Other groupings drawn with the same function:
#   4 types: area_plot(step_boost_type_pr_doubled, 'refer_boost_type_plus_long',
#                      boost_type_long_map, boost_type_color_map_double)
#   3 types: area_plot(step_boost_type_tri_pr_doubled, 'refer_boost_type_tri_plus_long',
#                      boost_type_tri_long_map, boost_type_tri_color_map_double)
a = area_plot(df=step_type_pr_doubled,
              color='refer_type_merge_plus_long',
              scale=type_long_map,
              value=type_color_map_double)
# a.save('../figs/core/a1.svg')
a


## figure 6E

In [77]:
# Figure 6E: referenced / reference-adjacent / remaining events.
area_plot(df=step_boost_type_tri_pr_doubled,
          color='refer_boost_type_tri_plus_long',
          scale=boost_type_tri_long_map,
          value=boost_type_tri_color_map_double)

## figure 7D

In [78]:
# Figure 7D: referenced / referring / other events, main segments only, wider facets.
area_plot(df=step_rr_type_pr_doubled.query("main_or_gap=='main'"),
          color='rr_type_plus_long',
          scale=rr_type_long_map,
          value=rr_type_color_map_double,
          width=100)