In [2]:
import numpy as np
import pandas as pd
import altair as alt
from bokeh.palettes import Category20c, Category20, Category20b, Pastel2, Purples
from sklearn.feature_extraction.text import CountVectorizer

## helper 

In [3]:
def bootstrap_ci(array, ci=95, n_iterations = 10000):
    """Percentile-bootstrap confidence interval for the mean of `array`.

    Parameters
    ----------
    array : 1-D array-like of numbers to resample.
    ci : width of the interval in percent (95 -> 2.5th and 97.5th percentiles).
    n_iterations : number of bootstrap resamples drawn.

    Returns
    -------
    (lower, upper) : the two percentile bounds of the resampled means.
    """
    # A private, fixed-seed generator keeps the result reproducible without
    # resetting the process-wide numpy RNG on every call. RandomState(12)
    # draws the same stream as np.random.seed(12) + np.random.choice.
    rng = np.random.RandomState(12)
    samples = rng.choice(array, (n_iterations, len(array)), replace=True)
    stats = np.mean(samples, axis=1)
    lower, upper = np.percentile(stats, [(100-ci)/2, (100+ci)/2])
    return lower, upper

def bootstrap_ci_lower(array):
    """Lower bound of the default bootstrap confidence interval of `array`."""
    lower, _ = bootstrap_ci(array)
    return lower

def bootstrap_ci_upper(array):
    """Upper bound of the default bootstrap confidence interval of `array`."""
    _, upper = bootstrap_ci(array)
    return upper

def num2seg(num):
    """Convert an event number to its segment number, e.g. 953 -> 9.5.

    The hundreds part is the segment; a tens digit of 5-9 marks the
    half-step segment that follows it (returned as segment + 0.5).
    """
    num = int(num)
    segment = num // 100
    tens_digit = num // 10 % 10
    if tens_digit < 5:
        return segment
    return segment + 0.5

def set_color_old(df, step_col):
    """Fill df['color'] in place from the condition family and the step value.

    Retrodiction-family rows get the 'r_*' colours and prediction-family rows
    the 'p_*' colours: half steps -> '*_gap', whole steps -> '*_main', and the
    family's special step (-1 for r, +1 for p) -> '*_3'. The '*_3' rule is
    applied last so it overrides '*_main'. Colours come from the global `cmap`.
    """
    steps = df[step_col]
    is_half_step = steps % 1 == 0.5
    is_whole_step = steps % 1 == 0

    # (conditions in the family, gap key, main key, special key, special step)
    families = (
        (['r0','r','r_p','r_recall'], 'r_gap', 'r_main', 'r_3', -1),
        (['p0','p','p_r','p_recall'], 'p_gap', 'p_main', 'p_3', 1),
    )
    for conditions, gap_key, main_key, special_key, special_step in families:
        in_family = df['condition'].isin(conditions)
        df.loc[in_family & is_half_step, 'color'] = cmap[gap_key]
        df.loc[in_family & is_whole_step, 'color'] = cmap[main_key]
        df.loc[in_family & (steps == special_step), 'color'] = cmap[special_key]

def set_color(df, step_col):
    """Fill df['color'] in place from the sign and fraction of df[step_col].

    Negative steps use the 'r_*' colours and positive steps the 'p_*' colours
    (half steps -> '*_gap', whole steps -> '*_main', exactly -1 / +1 -> '*_3');
    step 0 uses the 'recall' colour. Rules are applied in order, so later rules
    override earlier ones. Colours come from the global `cmap`.
    """
    steps = df[step_col]
    fraction = steps % 1

    color_rules = [
        ((fraction == 0.5) & (steps < 0), 'r_gap'),
        ((fraction == 0) & (steps < 0), 'r_main'),
        (steps == -1, 'r_3'),
        ((fraction == 0.5) & (steps > 0), 'p_gap'),
        ((fraction == 0) & (steps > 0), 'p_main'),
        (steps == 1, 'p_3'),
        (steps == 0, 'recall'),
    ]
    for mask, color_key in color_rules:
        df.loc[mask, 'color'] = cmap[color_key]

    

## constants

In [4]:
# Condition codes, from the narrowest to the widest set.
cond_pr = ['r0', 'p0', 'r', 'p']
cond_pr_recall = cond_pr + ['r_p', 'p_r', 'r_recall', 'p_recall']
cond_full = cond_pr_recall + ['r_p_recall', 'p_r_recall', 'recall_f', 'recall_b']
cond_re = ['r_recall', 'p_recall', 'r_p_recall', 'p_r_recall', 'recall_f', 'recall_b']

# Event-count columns: lag-based, and the storyline ("_sl") variant.
e_cond = ['e_target', 'e_gap', 'e_ahead', 'e_unmatch']
e_cond_sl = ['e_target_sl', 'e_gap_sl', 'e_ahead_sl', 'e_across_sl', 'e_unmatch']

# Previous palette, built from bokeh's Category20 families (kept for reference).
cmap_old = {
    'p_1': Category20[8][1],
    'r_1': Category20[8][3],
    'p_2': Category20[8][0],
    'r_2': Category20[8][2],
    'p_3': Category20b[8][1],     # recall
    'r_3': Category20[8][6],      # recall
    'p_main': Category20c[8][1],
    'r_main': Category20c[8][5],
    'p_gap': Category20c[8][3],
    'r_gap': Category20c[8][7],
}

# Tableau Color Blind
cmap = {
    'p_1': '#a3cce9',
    'r_1': '#ffbc79',
    'p_2': '#5fa2ce',
    'r_2': '#fc7d0b',
    'p_3': '#1170aa',             # recall
    'r_3': '#c85200',             # recall
    'p_main': '#5fa2ce',
    'r_main': '#fc7d0b',
    'p_gap': '#c7e0f1',           # earlier choice: a3cce9
    'r_gap': '#ffd6ae',           # earlier choice: ffbc79
    'recall': '#59a14f',
}

# Event-type colours: previous palette.
e_cond_cmap_old = {
    'e_unmatch': Category20c[20][18],
    'e_ahead': Category20b[20][5],
    'e_gap': Category20b[20][6],
    'e_target': Category20b[20][4],
}

# Event-type colours for the lag-based categories.
e_cond_cmap = {
    'e_unmatch': Category20c[20][18],
    'e_ahead': '#C69C6D',
    'e_target': '#8C6238',
    'e_gap': '#C7B39A',
}

# Event-type colours for the storyline-based categories.
e_cond_sl_cmap = {
    'e_unmatch': Category20c[20][18],
    'e_across_sl': Category20c[20][3],
    'e_ahead_sl': '#C69C6D',
    'e_target_sl': '#8C6238',
    'e_gap_sl': '#C7B39A',
}

# Legend labels for the lag-based event categories. Key order is kept in step
# with the matching colour dictionary (both are turned into lists when plotting).
e_cond_long_map = dict([
    ('e_unmatch', 'Unmatched events'),
    ('e_ahead', 'Events with lag < -1 or lag > 1'),
    ('e_target', 'Target events with lag = -1 or 1'),
    ('e_gap', 'Events with lag = -0.5 or 0.5'),
])

# Legend labels for the storyline-based event categories.
e_cond_sl_long_map = dict([
    ('e_unmatch', 'Unmatched events'),
    ('e_across_sl', 'Across-storyline events'),
    ('e_ahead_sl', 'Within-storyline events with lag < -1 or lag > 1'),
    ('e_target_sl', 'Within-storyline events with lag = -1 or 1'),
    ('e_gap_sl', 'Within-storyline events with lag = -0.5 or 0.5'),
])

# Short axis/legend labels per condition code.
cond_short_map = {'r0':'u-R','p0':'u-P','r':'c-R','p':'c-P','r_p':'c-RP','p_r':'c-PR',
                  'r_recall':'re(R)','p_recall':'re(P)'
                 }

# Long legend labels per condition code (the single source for the subsets below).
cond_long_map = {'r0':'u-R: uncued retrodiction','p0':'u-P: uncued prediction',
                 'r':'c-R: character-cued retrodiction','p':'c-P: character-cued prediction',
                 'r_p':'c-RP: updated retrodiction (after watching one segment earlier)',
                 'p_r':'c-PR: updated prediction (after watching one segment later)',
                 'r_recall':'re(R): retrodiction-matched recall',
                 'p_recall':'re(P): prediction-matched recall'   # was misspelled "mached"
                 }

# Subsets are derived from cond_long_map so the label text cannot drift apart;
# key order follows the listed order, which matches the previous literals.
cond_long_3_map = {cond: cond_long_map[cond]
                   for cond in ['r0','p0','r','p','r_recall','p_recall']}

cond_long_pr_map = {cond: cond_long_map[cond]
                    for cond in ['r0','p0','r','p']}

# Fill colour per condition (all eight plotted conditions).
cond_color_map = {
    'r0': cmap['r_1'],
    'p0': cmap['p_1'],
    'r': cmap['r_2'],
    'p': cmap['p_2'],
    'r_p': cmap['r_2'],
    'p_r': cmap['p_2'],
    'r_recall': cmap['r_3'],
    'p_recall': cmap['p_3'],
}

# Same colours restricted to the six conditions without r_p / p_r ...
cond_color_3_map = {cond: cond_color_map[cond]
                    for cond in ('r0', 'p0', 'r', 'p', 'r_recall', 'p_recall')}

# ... and to the four prediction / retrodiction conditions.
cond_color_pr_map = {cond: cond_color_map[cond]
                     for cond in ('r0', 'p0', 'r', 'p')}

# Stroke colours equal the fill colours, except that r_p / p_r are outlined
# with the colour of the opposite direction ('p' / 'r').
color_stroke_color_map = dict(cond_color_map,
                              r_p=cond_color_map['p'],
                              p_r=cond_color_map['r'])

color_stroke_color_3_map = dict(cond_color_3_map)

color_stroke_color_pr_map = dict(cond_color_pr_map)

# Short labels in the plotting order of the condition lists.
cond_pr_short = [cond_short_map[code] for code in cond_pr]
cond_pr_recall_short = [cond_short_map[code] for code in cond_pr_recall]

## read annotation file

In [4]:
# Event annotations; 'half_point_add' holds space-separated event numbers and
# is normalised to plain strings ('' where empty).
df_annot = pd.read_excel('../data/rep/TheChair.xlsx', sheet_name='annotation')
df_annot['half_point_add'] = df_annot['half_point_add'].fillna('').astype(str)
df_annot.shape

(160, 11)

In [5]:
# Disabled derived columns, kept for reference:
# df_annot['main_or_gap'] = np.where(df_annot['segment']%1 == 0, 'main', 'gap')
# df_annot['full_or_half'] = np.where((df_annot['num']//10%10).isin([0,5]), 'full', 'half')
# Count the annotation rows for each (partial, on_off) combination.
df_annot[['partial','on_off']].value_counts()

partial  on_off 
ori      on         71
         off        51
         off_add    15
sum      on          9
partial  on          8
ori      on_add      3
partial  off         2
sum      off         1
dtype: int64

## read data files

In [5]:
# Load the response table; its 'no' column becomes the row index.
df = pd.read_csv('../data/rep/exp_embed_use_large.csv', index_col='no')

## add events columns

In [6]:
# Count how often each event number occurs in the space-separated
# 'scenes_final' string of every response.
vectorizer = CountVectorizer(binary=False, min_df=1, token_pattern='[^ ]+')
counts = vectorizer.fit_transform(df['scenes_final'].fillna('-1')) # assign "-1" for recall data

# get_feature_names() was deprecated in scikit-learn 1.0 and removed in 1.2;
# fall back to it only on versions that predate get_feature_names_out().
if hasattr(vectorizer, 'get_feature_names_out'):
    feature_names = vectorizer.get_feature_names_out()
else:
    feature_names = vectorizer.get_feature_names()
df_features = pd.DataFrame(counts.toarray(), columns=feature_names, index=df.index)

# Annotated events that no response mentioned still need a (zero) column.
cols_addi = [str(num) for num in df_annot['num'] if str(num) not in df_features.columns]  # for some elems no one hit
for col in cols_addi:
    df_features[col] = 0

# Order the columns numerically by event number (names stay strings).
cols_unsorted = df_features.columns.to_list()
cols_int = [int(float(e)) for e in cols_unsorted]
cols_int.sort()
cols_sorted = [str(e) for e in cols_int]
df_features = df_features[cols_sorted]

# correct multiple matches: cap counts at 1 for the columns from '51' onwards
# NOTE(review): the recorded output prints the same count (427) before and
# after this step - confirm whether the remaining >1 cells all sit in the
# columns before '51' or whether the cap is not being applied.
print(df_features[df_features>1].count().sum())
df_features[df_features.loc[:,'51':]>1] = 1
print(df_features[df_features>1].count().sum())

df = pd.concat([df, df_features], axis=1)

427
427




In [7]:
# correct half points (1 to 0.5)

def get_half_elems(story, mode='add'):
    """List the half-point events of `story` together with their rule columns.

    Half-point events are the df_annot rows of the given story whose number is
    above 100 and whose tens digit is 3, 4, 8 or 9.

    Parameters
    ----------
    story : value matched against df_annot['story'].
    mode : 'add' or 'exclude'; selects the 'half_point_<mode>' column.

    Returns
    -------
    (half_elems_str, rule_elems) : two lists of equal length, in annotation
        order. The first holds the event numbers as strings, the second the
        matching rule cell split on single spaces ([''] when the cell is empty).
    """
    nums = df_annot['num']
    # One mask drives both lists, so they stay aligned row for row and no
    # longer depend on df_annot having a default 0..n-1 index.
    is_half_elem = (
        (nums // 10 % 10).isin([3, 4, 8, 9])
        & (nums > 100)
        & (df_annot['story'] == story)
    )
    half_rows = df_annot.loc[is_half_elem]

    half_elems_str = [str(num) for num in half_rows['num']]
    rule_elems = [rule.split(' ') for rule in half_rows[f'half_point_{mode}'].values]

    return half_elems_str, rule_elems

def correct_half_point_e(row, mode='add'):
    # mode: add or exclude
    '''
    Rescale the half-point event columns of one response row and return the row.

    The half-point events and their rule columns come from
    get_half_elems(row['story'], mode); the row is modified in place.

    mode:
        add: set addi elements as 0.5 point, the total points are fixed.
             If the event has rule columns and any of them is non-zero in this
             row, the event is first marked as hit (1); a value of 1 is then
             turned into 0.5.
        exclude: total points are not fixed (deprecated).
             Only events currently equal to 1 are touched: they become 0.5 when
             there is no rule or no rule column is hit, otherwise 0.
    '''
    for half_e, rule_e in zip(*get_half_elems(row['story'], mode)):

        if mode == 'exclude':
            if row[half_e] == 1:
                # [''] is what an empty rule cell splits into
                if rule_e == ['']:
                    row[half_e] = 0.5
                elif row[rule_e].sum() == 0:
                    row[half_e] = 0.5
                else:
                    row[half_e] = 0

        elif mode == 'add':
            if rule_e == ['']:
                pass
            elif row[rule_e].sum() > 0:
                # a hit on any rule column counts as a hit on this event
                row[half_e] = 1
            # only an exact 1 is halved; other values are left unchanged
            if row[half_e] == 1:
                row[half_e] = 0.5
    return row

In [8]:
# Apply the half-point correction to every row (axis=1), using the default mode='add'.
df = df.apply(correct_half_point_e, axis=1)

In [9]:
# Name and position of the first event column in cols_sorted.
first_event = '51'
first_event_loc = cols_sorted.index(first_event)

def get_elem_list(seg_num, mode=None, cond_direction=None, full_list=cols_sorted[first_event_loc:]):
    """Return event column names (strings) selected relative to segment `seg_num`.

    mode=None     : events annotated on `seg_num` itself (looked up in df_annot).
    mode='gap'    : events of the neighbouring half-step segment,
                    seg_num-0.5 for direction 'f' and seg_num+0.5 for 'b'.
    mode='ahead'  : entries of `full_list` numbered above seg_num*100+50 ('f')
                    or at most seg_num*100-1 ('b').
    mode='ahead_main' / 'ahead_gap' : like 'ahead', but only numbers strictly
                    between 100 and 1150 whose tens digit is 0-4 / 5-9.

    `full_list` defaults to the event columns as they were when this function
    was defined. Returns None for an unrecognised mode or direction.
    """
    if mode is None:
        annotated = df_annot.query("seg_num==@seg_num")['num'].to_list()
        return [str(num) for num in annotated]

    if cond_direction not in ('f', 'b'):
        return None
    forward = cond_direction == 'f'

    if mode == 'gap':
        neighbour = seg_num - 0.5 if forward else seg_num + 0.5
        return get_elem_list(neighbour)

    if mode == 'ahead':
        tens_digits = None
        pool = full_list
    elif mode in ('ahead_main', 'ahead_gap'):
        tens_digits = [0,1,2,3,4] if mode == 'ahead_main' else [5,6,7,8,9]
        pool = [e for e in full_list if 100<int(e)<1150]  # exclude 51.. and 1201
    else:
        return None

    if forward:
        min_num = seg_num*100+50
        selected = [e for e in pool if min_num<int(e)]
    else:
        max_num = seg_num*100 - 1
        selected = [e for e in pool if int(e)<=max_num]

    if tens_digits is None:
        return selected
    return [e for e in selected if int(e)//10%10 in tens_digits]
                
def get_seg_elems(story, seg_num, version):
    """Event numbers annotated for one segment of one story.

    version='full' returns every event, 'main' those with tens digit
    0, 1, 2, 5, 6 or 7, and 'add' those with tens digit 3, 4, 8 or 9.
    Any other version returns None.
    """
    if version=='full':
        wanted_tens = None
    elif version=='main':
        wanted_tens = [0,1,2,5,6,7]
    elif version=='add':
        wanted_tens = [3,4,8,9]
    else:
        return None

    seg_nums = df_annot.query(f"story=={story} and seg_num=={seg_num}")['num']
    if wanted_tens is None:
        return list(seg_nums)
    return [num for num in seg_nums if num//10%10 in wanted_tens]
    
def get_seg_points(story, seg_num):
    """Maximum score of a segment: 1 point per 'main' event, 0.5 per 'add' event."""
    n_main = len(get_seg_elems(story, seg_num, 'main'))
    n_add = len(get_seg_elems(story, seg_num, 'add'))
    return n_main + 0.5*n_add

In [10]:
def get_step_storyline(base_seg_num, ref_sc_num):
    """Signed step from a base segment to a referenced event, counted within
    the event's storyline.

    Steps are counted over the storyline's 'on' segments: whole numbers for
    events on a main (whole-number) segment, and +/-0.5 offsets for events on
    a gap (x.5) segment.

    Parameters
    ----------
    base_seg_num : segment the response is anchored to.
    ref_sc_num : event number (df_annot['num']) being referenced.

    Returns
    -------
    The step, or np.nan when the event is an added scene ('off_add'/'on_add')
    or the base segment is not an 'on' segment of the event's storyline.
    """
    ref_rows = df_annot.query("num==@ref_sc_num")

    # discard added scenes
    if ref_rows['on_off'].to_list()[0] in ['off_add','on_add']:
        return np.nan

    storyline = ref_rows['storyline'].to_list()[0]
    on_seg_sl = list(df_annot.query("storyline==@storyline and on_off=='on'")['seg_num'].unique())
    on_seg_sl = [int(s) for s in on_seg_sl]
    ref_seg_num = num2seg(ref_sc_num)

    if base_seg_num not in on_seg_sl:  # across
        return np.nan

    base_pos = on_seg_sl.index(base_seg_num)

    if ref_seg_num % 1 == 0:  # main
        return on_seg_sl.index(ref_seg_num) - base_pos

    # gap: anchor on the next 'on' segment of the storyline when there is one ...
    on_segs = np.array(on_seg_sl)
    later_segs = on_segs[on_segs > ref_seg_num]
    if len(later_segs) > 0:
        next_main_seg_num = later_segs[0]
        return on_seg_sl.index(next_main_seg_num) - base_pos - 0.5

    # ... otherwise on the nearest preceding one. The previous fallback read
    # `next_main_seg_num`, which is unbound on this path, and picked the first
    # rather than the closest preceding 'on' segment.
    pre_main_seg_num = on_segs[on_segs < ref_seg_num][-1]
    return on_seg_sl.index(pre_main_seg_num) - base_pos + 0.5

In [11]:
def get_elem_list_storyline(base_seg_num, cond_direction):
    """Classify events by storyline relative to the base segment.

    For every storyline present on `base_seg_num`, the nearest 'on' segment in
    the given direction is the storyline's target. Events of that storyline
    are then split into target (on the target segment), gap (strictly between
    base and target) and ahead (beyond the target). When the storyline has no
    'on' segment in that direction, all of its events beyond the base segment
    count as ahead. Events of storylines 1-4 that are not present on the base
    segment and lie beyond it are the across list.

    Parameters
    ----------
    base_seg_num : segment the response is anchored to.
    cond_direction : 'f' (segments after the base) or 'b' (segments before it).

    Returns
    -------
    (within_target, within_gap, within_ahead, across) : four lists of event
        numbers as strings.

    Raises
    ------
    ValueError : if `cond_direction` is neither 'f' nor 'b' (this previously
        failed later with an UnboundLocalError).
    """
    if cond_direction not in ('f', 'b'):
        raise ValueError(f"cond_direction must be 'f' or 'b', got {cond_direction!r}")
    forward = cond_direction == 'f'
    # comparison operators that mean "further from the base" / "back towards it"
    beyond, towards = ('>', '<') if forward else ('<', '>')

    current_storylines = df_annot.query("seg_num==@base_seg_num")['storyline'].unique()

    within_target_elem_list = []
    within_gap_elem_list = []
    within_ahead_elem_list = []
    across_elem_list = []

    for storyline in current_storylines:
        storyline_rows = df_annot.query("storyline==@storyline")
        on_segs_beyond = storyline_rows.query(
            f"on_off=='on' and seg_num{beyond}@base_seg_num")['seg_num'].to_list()

        if on_segs_beyond:
            # nearest 'on' segment: first one after the base, or last one before it
            within_target_seg_num = on_segs_beyond[0] if forward else on_segs_beyond[-1]
            within_target_elem_list.extend(
                storyline_rows.query("seg_num==@within_target_seg_num")['num'].to_list())
            within_gap_elem_list.extend(storyline_rows.query(
                f"seg_num{beyond}@base_seg_num and seg_num{towards}@within_target_seg_num")['num'].to_list())
            within_ahead_elem_list.extend(storyline_rows.query(
                f"seg_num{beyond}@within_target_seg_num")['num'].to_list())
        else:
            # no target in this direction: everything beyond the base is "ahead"
            within_ahead_elem_list.extend(storyline_rows.query(
                f"seg_num{beyond}@base_seg_num")['num'].to_list())

    across_storylines = [storyline for storyline in [1,2,3,4] if storyline not in current_storylines]
    for storyline in across_storylines:
        across_elem_list.extend(df_annot.query(
            f"storyline==@storyline and seg_num{beyond}@base_seg_num")['num'].to_list())

    # convert to string
    within_target_elem_list = [str(e) for e in within_target_elem_list]
    within_gap_elem_list = [str(e) for e in within_gap_elem_list]
    within_ahead_elem_list = [str(e) for e in within_ahead_elem_list]
    across_elem_list = [str(e) for e in across_elem_list]

    return within_target_elem_list, within_gap_elem_list, within_ahead_elem_list, across_elem_list

In [12]:
# calculate e_total e_match e_unmatch e_target e_gap e_ahead e_ahead_main e_ahead_gap
# e_0.5 e_1 e_1.5 ... e_11 e_11.5 e_12

# Pre-create every count column (filled with NaN) so the column order is fixed
# before the row-wise apply below fills them in.
for i, col in enumerate(['e_total','e_match','e_unmatch','e_target','e_gap','e_ahead','e_ahead_main','e_ahead_gap','e_target_sl','e_gap_sl','e_ahead_sl','e_across_sl'] 
                         + ['e_'+str(0.5*n) if 0.5*n%1!=0 else 'e_'+str(int(0.5*n)) for n in range(0,28)]):
    df[col] = np.nan

def add_e_count_columns(row):
    """Fill the event-count columns of one response row and return the row.

    Reads row['segment_num'], row['base_seg_num'] and row['cond_direction'],
    then sums the per-event columns of the row:
      - e_total / e_match / e_unmatch: cols_sorted[1:], cols_sorted from
        first_event_loc onwards, and cols_sorted[1:first_event_loc];
      - e_target, e_gap, e_ahead, e_ahead_main, e_ahead_gap: the lists returned
        by get_elem_list for the row's segment;
      - e_target_sl, e_gap_sl, e_ahead_sl, e_across_sl: the four lists returned
        by get_elem_list_storyline for the row's base segment;
      - e_0, e_0.5, ..., e_13.5: the events annotated on each segment.
    """
    # print(row.name)
    seg_num = row['segment_num']
    base_seg_num = row['base_seg_num']
    cond_direction = row['cond_direction']
    
    # uncorrected
    # cols_sorted[1:] skips the first sorted column - presumably the "-1"
    # placeholder used for recall rows; TODO confirm.
    row['e_total'] = row[ cols_sorted[1:]].sum()
    row['e_match'] = row[ cols_sorted[first_event_loc:]].sum()
    row['e_unmatch'] = row[ cols_sorted[1:first_event_loc]].sum()
    row['e_target'] = row[ get_elem_list(seg_num) ].sum()
    row['e_gap'] = row[ get_elem_list(seg_num, 'gap', cond_direction) ].sum()

    # col[2:] strips the 'e_' prefix to get the get_elem_list mode name
    for col in ['e_ahead', 'e_ahead_main', 'e_ahead_gap']:
#         print(col)
        match_cols = get_elem_list(seg_num, col[2:], cond_direction)
        if match_cols == []:
            row[col] = 0
        else:   
            row[col] = row[ match_cols ].sum()

    # lag_corrected - within storyline
    within_target_elem_list, within_gap_elem_list, within_ahead_elem_list, across_elem_list = get_elem_list_storyline(base_seg_num, cond_direction)
    # print(within_target_elem_list)
    row['e_target_sl'] = row[ within_target_elem_list ].sum()
    row['e_gap_sl'] = row[ within_gap_elem_list ].sum()
    row['e_ahead_sl'] = row[ within_ahead_elem_list ].sum()
    row['e_across_sl'] = row[ across_elem_list ].sum()

    
    # Per-segment counts; note that seg_num is rebound here and no longer
    # holds the row's own segment after this loop.
    for col in ['e_'+str(0.5*n) if 0.5*n%1!=0 else 'e_'+str(int(0.5*n)) for n in range(0,28)]:  # 0 to 13.5
        seg_num = float(col.split('_')[1])
        row[col] = row[ get_elem_list(seg_num) ].sum()
        
    return row

# Row-wise apply: every row is passed through add_e_count_columns.
df = df.apply(add_e_count_columns, axis=1)

  df[col] = np.nan


In [13]:
# df_seg_info

# Per-segment maximum score ('target_points') for story 3, segments 0.5 ... 13.5.
# Rows are collected in a list and the frame is built once: DataFrame.append
# was removed in pandas 2.0 and re-copied the whole frame on every call.
seg_info_rows = []
for story in [3]:
    for seg_num in np.arange(0.5,14,0.5):
        target_points =  get_seg_points(story, seg_num)
        seg_info_rows.append({'story':story, 'segment_num':seg_num,
                              'target_points':target_points})

df_seg_info = pd.DataFrame(seg_info_rows, columns=['story','segment_num','target_points'])

df_seg_info['story'] = df_seg_info['story'].astype('int')
df_seg_info['other_refer_seg_num'] = df_seg_info['segment_num']
df_seg_info['other_refer_seg_points'] = df_seg_info['target_points']

# add total_points
df = pd.merge(df, df_seg_info, on=['story','segment_num'], how='left')
df['target_hit_rate'] = df['e_target'] / df['target_points']

# NOTE(review): a later cell reads "../data/rep/df_seg_info.csv", but the write
# below is commented out - that cell depends on a file saved by an earlier run.
# df_seg_info.to_csv("../data/rep/df_seg_info.csv", index=False)

  df['target_hit_rate'] = df['e_target'] / df['target_points']


In [14]:
# Save the frame with all event-count columns (row index not written).
df.to_csv("../data/rep/exp_with_scenes_use_large.csv", index=False) 

## generating dfs

In [17]:
# Reload the saved table; the commented paths are alternative embedding models.
df = pd.read_csv("../data/rep/exp_with_scenes_use_large.csv")
# df = pd.read_csv("../data/processed/exp_with_scenes_simcse.csv")
# df = pd.read_csv("../data/processed/exp_with_scenes_t5_3b.csv")
# df = pd.read_csv("../data/processed/exp_with_scenes_sbert.csv")

df_seg_info = pd.read_csv("../data/rep/df_seg_info.csv")

# now subset
print(df.shape)

df['base_segment'] = df['base_segment'].astype(str) # keep these columns from being dropped by groupby
df['segment_pair'] = df['segment_pair'].astype(str)
df['segment_count_demean'] = df['segment_count']-6.5
df['segment_num_demean'] = df['segment_num']-7

# seg_mean: average over responses within each condition x segment cell
seg_mean = df.groupby(['condition','cond_direction','cond_amount','story','segment_num','target_points',
                       'segment','segment_count','base_segment','base_seg_num','segment_pair'])[[
                       'res_1_simi_info_z','res_1_MD_z'] + ['target_hit_rate'] +
                        df.loc[:,'0':'e_13.5'].columns.tolist()].mean().reset_index()

# cond_mean: average of the segment means per condition.
# numeric_only=True is required from pandas 2.0 on, where mean() raises on the
# string columns of seg_mean instead of silently dropping them.
cond_mean = seg_mean.groupby(['condition']).mean(numeric_only=True).reset_index()

(878, 236)


## for R

In [18]:
# df_pr_R and df_pr_recall_R <- df (for figure 1 stats)
# Columns exported for the figure 1 statistics in R.
cols = ['sub','story','segment','segment_count','segment_num','base_segment',
        'segment_pair','condition','cond_direction','cond_amount','target_points',
        'e_total','e_target','res_1_simi_info_z','res_1_MD_sub_z','segment_count_demean','segment_num_demean']

# Work on an explicit copy: an inplace replace on a column of a sliced frame
# raises SettingWithCopyWarning and has no effect under pandas copy-on-write.
df_pr_R = df[cols].copy()

# replace inf values
# NOTE(review): -inf is also mapped to +arctanh(0.8) - confirm this is intended.
df_pr_R['res_1_MD_sub_z'] = df_pr_R['res_1_MD_sub_z'].replace(
    [np.inf, -np.inf], np.arctanh(0.8))

df_pr_R.to_csv("../data/rep/R/df_pr_for_R.csv", index=False)
print(df_pr_R.shape)

(878, 17)


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  return self._update_inplace(result)


In [29]:
# df_pr_econd_R (for figure 2 stats)
# Identifier columns repeated on every long-format row.
cols = ['sub','story','segment','segment_count','segment_num','base_segment','segment_count_demean','segment_num_demean',
        'segment_pair','condition','cond_direction','cond_amount']

# Wide -> long: one row per response and event category, hit count in 'n_hits'.
df_pr_econd_R = df.melt(
    id_vars=cols,
    value_vars=e_cond,
    var_name='e_cond',
    value_name='n_hits',
)

df_pr_econd_R.to_csv("../data/rep/R/df_pr_econd_for_R.csv", index=False)
df_pr_econd_R.shape

(3512, 14)

# segment count plot: figure S3

In [31]:
# map cond names
# Attach the short and long display labels of each row's condition.
seg_mean = seg_mean.assign(
    cond_short=seg_mean['condition'].map(cond_short_map),
    cond_long=seg_mean['condition'].map(cond_long_map),
)

In [32]:
def reg_plot(index, scale_max, title, label):
    """Scatter of segment means against segment number with a regression line,
    faceted into one column per condition in cond_pr.

    Parameters
    ----------
    index : name of the seg_mean column plotted on the y axis.
    scale_max : upper limit of the y axis (the lower limit is 0).
    title : title of the faceted chart (panel letter).
    label : y-axis title.
    """
    x_enc = alt.X('segment_num:O', axis=alt.Axis(title='Segment num', labelAngle=0))
    y_enc = alt.Y(index, axis=alt.Axis(title=label), scale=alt.Scale(domain=(0, scale_max)))

    points = alt.Chart().mark_circle(opacity=1).encode(
        x_enc,
        y_enc,
        tooltip='segment'
    ).properties(width=200, height=200)

    # the regression layer reuses the scatter's encodings
    trend_line = points.transform_regression('segment_num', index).mark_line()

    facet_header = alt.Header(labelOrient='top', titleFontSize=14, labelFontSize=14,
                              titleFontWeight='normal', titlePadding=0)
    facet_column = alt.Column('cond_short:N', sort=cond_pr_short, title='Condition',
                              header=facet_header)

    plot_data = seg_mean.query(f"condition=={cond_pr}")
    return alt.layer(points, trend_line, data=plot_data).facet(title=title, column=facet_column)

# One faceted regression panel per measure (panels A, B, C).
reg_hit_rate = reg_plot('target_hit_rate',0.4,'A','Target events hit rate')
reg_pre = reg_plot('res_1_simi_info_z',1,'B','Precision')
reg_conv = reg_plot('res_1_MD_z',1,'C','Convergence')

# Stack the three panels vertically (&) and apply shared styling; as the last
# expression of the cell, the chart is what gets displayed.
(reg_hit_rate & reg_pre & reg_conv).configure_axis(
    titleFontWeight='normal',
    titleFontSize=14,
    labelFontSize=14,
).configure_concat(
    spacing=40
).configure_legend( 
    symbolSize=100,  
    labelFontSize=14, 
    titleFontWeight='normal',
    titleFontSize=16,
).configure_title(
    fontSize=20,
)

# figure 4D

In [19]:
# cond_mean_e_long <- cond_mean

def _cond_mean_to_long(value_vars, sort_order, long_map):
    """Melt cond_mean into one row per (condition, event category) for the
    conditions in cond_pr, adding the stacking order and display labels."""
    long_df = pd.melt(cond_mean, id_vars=['condition'], value_vars=value_vars,
                      var_name='e_cond', value_name='scenes')
    # .copy() so the columns added below go onto an independent frame instead
    # of a query() result (avoids SettingWithCopyWarning).
    long_df = long_df.query(f"condition=={cond_pr}").copy()
    long_df['sort'] = long_df['e_cond'].map(sort_order)
    long_df['cond_short'] = long_df['condition'].map(cond_short_map)
    long_df['e_cond_long'] = long_df['e_cond'].map(long_map)
    return long_df

# cond_mean_e_long <- cond_mean
cond_mean_e_long = _cond_mean_to_long(
    e_cond,
    {'e_target':2, 'e_gap':1, 'e_ahead':3, 'e_unmatch':4},
    e_cond_long_map)

# cond_mean_e_sl_long <- cond_mean
cond_mean_e_sl_long = _cond_mean_to_long(
    e_cond_sl,
    {'e_target_sl':2, 'e_gap_sl':1, 'e_ahead_sl':3, 'e_across_sl':4, 'e_unmatch':5},
    e_cond_sl_long_map)

cond_mean_e_sl_long.head(1)

Unnamed: 0,condition,e_cond,scenes,sort,cond_short,e_cond_long
0,p,e_target_sl,0.296024,2,c-P,Within-storyline...


In [22]:
def make_stack_bars(df, cond_map, cmap, n_conds, legendX, width, title):
    """Build a stacked bar chart of events hit per condition, with a value label on each segment.

    Parameters
    ----------
    df : DataFrame with the columns 'scenes', 'cond_short', 'e_cond_long' and 'sort'
        (plus 'condition' when n_conds == 2).
    cond_map : dict whose values form the colour-scale domain (event-type labels).
    cmap : dict whose values form the colour-scale range.
    n_conds : when equal to 2 the chart is restricted to conditions 'r0' and 'p0';
        any other value leaves the data unfiltered.
    legendX : horizontal position of the legend.
    width : chart width.
    title : chart title.

    Returns
    -------
    The layered Altair chart (bars + text labels).

    The x-axis order comes from the notebook-level `cond_pr_short`.
    """
    # Pieces of the colour encoding, named for readability.
    event_scale = alt.Scale(domain=list(cond_map.values()),
                            range=list(cmap.values()))
    event_legend = alt.Legend(title='Event type', orient='none',
                              legendY=200, legendX=legendX)   # legendX=360 for 6 conds
    condition_axis = alt.Axis(title='Condition',
                              titleFontWeight='normal',
                              labelFontWeight='normal',
                              labelAngle=0)

    bars = alt.Chart(df).mark_bar(size=42, strokeWidth=2, stroke='white')
    bars = bars.encode(
        x=alt.X('cond_short:N', sort=cond_pr_short, axis=condition_axis),
        y=alt.Y('scenes:Q', axis=alt.Axis(title='Number of events hit', titleFontWeight='normal')),
        color=alt.Color('e_cond_long', scale=event_scale, legend=event_legend),
        order=alt.Order("sort"),
    )
    bars = bars.properties(width=width, height=350, title=title)  # width=350 for 6 conds

    # Value labels, stacked the same way as the bars so each label sits on its segment.
    labels = alt.Chart(df).mark_text(fontSize=8, dx=0, dy=5, color='black')
    labels = labels.encode(
        x=alt.X('cond_short:N', sort=cond_pr_short),
        y=alt.Y('scenes:Q', stack='zero'),
        detail='e_cond_long:N',
        order=alt.Order("sort"),
        text=alt.Text('scenes:Q', format='.2f'),
    )

    if n_conds == 2:
        two_conditions = alt.FieldOneOfPredicate(field='condition', oneOf=['r0','p0'])
        bars = bars.transform_filter(two_conditions)
        labels = labels.transform_filter(two_conditions)

    return bars + labels

# Two charts from the same data: the narrow two-condition version (titled 'D') and a wider,
# untitled version that is not filtered to r0/p0.
stack_bars2 = make_stack_bars(cond_mean_e_long, e_cond_long_map, e_cond_cmap, n_conds=2, legendX=140, width=120, title='D')
stack_bars4 = make_stack_bars(cond_mean_e_long, e_cond_long_map, e_cond_cmap, n_conds=6, legendX=250, width=240, title='')

# NOTE(review): bare expression in the middle of the cell — it is not the cell's last
# expression, so nothing is displayed and it has no effect.
cond_mean_e_sl_long

# Apply legend/axis/title styling to the two-condition chart and drop the view border.
p = (stack_bars2).configure_legend(
    titleFontWeight='normal',
    symbolSize=400,  
    titleFontSize=20,
    labelFontSize=16, 
    labelLimit=1000,
    titleLimit=1000,
    titlePadding=10,
    rowPadding=5
).configure_axis(
    titleFontSize=16,
    labelFontSize=16,
    grid=False
).configure_view(
    strokeWidth=0
).configure_title(
    fontSize=28,
    anchor='start',
    offset=10
)
# p.save('../figs/core/events_stacked_bar.svg')
p

# figure S5

In [21]:
# Display the chart built with n_conds=6 in the previous cell (no r0/p0 filter applied).
stack_bars4

# figure 4E

### for R

In [23]:
# df_refer_long <- df (constrained to seg 1-11) (now seg 0.5-13)

# Wide -> long: every column from 'e_0.5' through 'e_13' becomes one row, with its name in
# 'refer_seg_num' and its value in 'refer_seg_e'. All columns up to 'segment_count' are kept as ids.
# NOTE(review): id_vars/value_vars are DataFrame slices rather than label lists; this relies
# on melt iterating them as column labels.
df_refer_long = pd.melt(df, id_vars=df.loc[:,:'segment_count'], 
        value_vars=df.loc[:,'e_0.5':'e_13'], var_name='refer_seg_num', value_name='refer_seg_e')

# Strip the two-character 'e_' prefix and parse the rest as the referenced segment number.
df_refer_long['refer_seg_num'] = df_refer_long['refer_seg_num'].str[2:].astype(float)
# # df_refer_long['refer_step'] = df_refer_long['refer_seg_num'] - df_refer_long['segment_num']
# Signed distance from the base segment to the referenced segment.
df_refer_long['refer_step_base'] = df_refer_long['refer_seg_num'] - df_refer_long['base_seg_num']

In [24]:
# add fields to df_refer_long

# Alias the 'other_refer_*' columns of df_seg_info so they can be used as merge keys below.
# Note this adds two columns to df_seg_info itself.
df_seg_info['refer_seg_num'] = df_seg_info['other_refer_seg_num']
df_seg_info['refer_seg_points'] = df_seg_info['other_refer_seg_points']

# add total_points
# Left merge on (story, refer_seg_num): rows without a match get NaN in refer_seg_points.
df_refer_long = pd.merge(df_refer_long, df_seg_info[['story','refer_seg_num','refer_seg_points']], on=['story','refer_seg_num'], how='left')

# add main or gap
# Whole-numbered segments are 'main', half-numbered ones are 'gap'.
df_refer_long['mod1'] = df_refer_long['refer_seg_num'] % 1
df_refer_long['main_or_gap'] = df_refer_long['mod1'].map({0:'main', 0.5:'gap'})

# add refer_seg
# Identifier of the form '<story>_<refer_seg_num>'.
df_refer_long['refer_seg'] = df_refer_long['story'].astype(str) + '_' + df_refer_long['refer_seg_num'].astype(str)

In [25]:
# df_refer_long_p0r0 <- df_refer_long

# select segments in right direction
# p0 keeps referenced segments at or after the base segment, r0 those at or before it.
df_refer_long_p0r0 = df_refer_long.query(
    "condition=='p0' & refer_seg_num>=base_seg_num | condition=='r0' & refer_seg_num<=base_seg_num").copy()

# Same selection with strict inequalities, i.e. the base segment itself is excluded.
df_refer_long_p0r0_for_R = df_refer_long.query(
    "condition=='p0' & refer_seg_num>base_seg_num | condition=='r0' & refer_seg_num<base_seg_num").copy()

# Centre the predictors by subtracting fixed constants.
# NOTE(review): 6.5, 7, 4.625 and 5.625 are hard-coded — presumably the means of the
# respective variables; confirm, and recompute if the data change.
df_refer_long_p0r0_for_R['segment_count_demean'] = df_refer_long_p0r0_for_R['segment_count']-6.5
df_refer_long_p0r0_for_R['segment_num_demean'] = df_refer_long_p0r0_for_R['segment_num']-7
df_refer_long_p0r0_for_R['step_abs'] = abs(df_refer_long_p0r0_for_R['refer_step_base'])
df_refer_long_p0r0_for_R['step_abs_demean_old'] = df_refer_long_p0r0_for_R['step_abs']-4.625
df_refer_long_p0r0_for_R['step_abs_demean'] = df_refer_long_p0r0_for_R['step_abs']-5.625
# Identifier of the form '<base_segment>_<refer_seg>'.
df_refer_long_p0r0_for_R['base_refer_seg_pair'] = df_refer_long_p0r0_for_R['base_segment'] + '_' + df_refer_long_p0r0_for_R['refer_seg']

# Absolute value of the smallest refer_step_base among rows with refer_seg_e >= 1.
refer_step_bound = abs(df_refer_long_p0r0.query("refer_seg_e>=1")['refer_step_base'].min()) #12.5

In [26]:
# Columns exported for the R analysis. 'segment_num_demean', computed in the previous cell,
# is not in this list and is therefore dropped here.
cols = ['condition','story','sub','segment_count','segment_count_demean','main_or_gap','base_segment','step_abs','step_abs_demean','step_abs_demean_old','refer_seg','base_refer_seg_pair','refer_seg_e','refer_seg_points']

df_refer_long_p0r0_for_R = df_refer_long_p0r0_for_R[cols]

# Written without the index.
df_refer_long_p0r0_for_R.to_csv('../data/rep/R/df_refer_long_p0r0_for_R.csv', index=False)

In [27]:
# df_refer <- seg_mean (equivalent to seg_seg_mean)

# Wide -> long on the segment-level means: columns 'e_0.5' .. 'e_13' become rows keyed by
# 'other_refer_seg_num', with the value in 'refer_seg_num_e'.
df_refer = pd.melt(seg_mean, id_vars=seg_mean.loc[:,:'base_seg_num'], 
        value_vars=seg_mean.loc[:,'e_0.5':'e_13'], 
                   var_name='other_refer_seg_num', value_name='refer_seg_num_e')

# Strip the 'e_' prefix, then compute signed distances to the current and the base segment.
df_refer['other_refer_seg_num'] = df_refer['other_refer_seg_num'].str[2:].astype(float)
df_refer['other_refer_step'] = df_refer['other_refer_seg_num'] - df_refer['segment_num']
df_refer['other_refer_step_base'] = df_refer['other_refer_seg_num'] - df_refer['base_seg_num']

# select only res in correct direction (to avoid altair plotting problem)
# df_refer = df_refer.query(
#     "cond_direction=='f' & other_refer_seg_num>=base_seg_num | cond_direction=='b' & other_refer_seg_num<=base_seg_num")

# 'target_points' is dropped before the per-segment point totals are merged in from df_seg_info.
df_refer = df_refer.drop(columns=['target_points'])
df_refer = pd.merge(df_refer, df_seg_info[['story','other_refer_seg_num','other_refer_seg_points']],
                                           on=['story','other_refer_seg_num'], how='left')
# Hit rate = events hit in the referenced segment / points available in that segment.
df_refer['other_refer_seg_hit_rate'] = df_refer['refer_seg_num_e'] / df_refer['other_refer_seg_points']

# add main or gap
df_refer['mod1'] = df_refer['other_refer_seg_num'] % 1
df_refer['main_or_gap'] = df_refer['mod1'].map({0:'main', 0.5:'gap'})

# Adds a 'color' column to df_refer in place, keyed on the lag to the base segment.
set_color(df_refer, 'other_refer_step_base')

In [29]:
# fig p0r0_step_hit_rate

# Shared base: p0/r0 rows in the matching direction (forward: at/after the base segment,
# backward: at/before it), with the lag to the base segment on an ordinal x axis.
# NOTE(review): the query string is an f-string with no placeholders; the `f` prefix is redundant.
base = alt.Chart(df_refer.query(
    f"condition==['p0','r0'] & (cond_direction=='f' & other_refer_seg_num>=base_seg_num | cond_direction=='b' & other_refer_seg_num<=base_seg_num)")).encode(
    alt.X('other_refer_step_base:O', axis=alt.Axis(title='Lag', labelAngle=0, values=list(range(-12, 13))))
).properties(width=550, height=350, title={
      "text": ["E"], 
    #   "subtitle": ['Uncued retrodiction (u-R) and uncued prediction (u-P)']
    })

# Bars: mean number of events hit per lag. scale=None makes Altair use the 'color' column
# values directly as colours.
bar = base.mark_bar(size=10).encode(
    alt.Y('mean(refer_seg_num_e):Q', axis=alt.Axis(title=['Number of events hit']), 
        #   scale=alt.Scale(domain=(0, 0.7))
         ), 
    alt.Color('color', scale=None),
)

# error_bar = base.mark_errorbar(extent='ci', color='black', opacity=0.4).encode(
#     alt.Y('refer_seg_num_hit:Q', axis=alt.Axis(title='Hit rate'))
# )

# Error bars: a rule from the ci0 to the ci1 aggregate of the same field.
error_bar = base.mark_rule(size=1.6, color='black', opacity=0.4).encode(
    y='ci0(refer_seg_num_e)',
    y2='ci1(refer_seg_num_e)',
)

# y_pos = 0.75
# text = alt.Chart(pd.DataFrame({'x': [-4, 4], 'y':[y_pos, y_pos], 'text':['R0', 'P0']})).mark_text(
#     fontSize=20
# ).encode(
#     x='x:O', y='y:Q', text='text'
# )

# Layer bars and error bars, then apply axis/title styling.
step_scenes_hits = (bar + error_bar
).configure_axis(
    titleFontSize=16,
    labelFontSize=16,
    titleFontWeight='normal'
).configure_title(
    fontSize=28,
    subtitleFontSize=18,
    anchor='start',
    offset=10
)

# step_scenes_hits.save('../figs/core/step_events_hits.svg')
step_scenes_hits

In [33]:
# Place the stacked-bar chart and the lag chart side by side (`|`), using the unconfigured
# `stack_bars2`, `bar` and `error_bar` objects from earlier cells, then style the whole figure.
# The large concat spacing presumably leaves room for the legend between the panels — confirm visually.
((stack_bars2) | (bar + error_bar)
).configure_legend(
    titleFontWeight='normal',
    symbolSize=300,  
    titleFontSize=20,
    labelFontSize=16, 
    labelLimit=1000,
    titleLimit=1000,
    titlePadding=10,
    rowPadding=2
).configure_axis(
    titleFontWeight='normal',
    titleFontSize=16,
    labelFontSize=16,
).configure_view(
    strokeWidth=0
).configure_concat(
    spacing=350  # 350
).configure_title(
    fontSize=28,
    subtitleFontSize=20,
    anchor='start',
    offset=10
)

# figure 3E, F, G

In [34]:
# cond_mean_ci (for plotting)
# One row per condition: the mean of each measure plus the lower/upper bounds returned by
# the bootstrap_ci_lower / bootstrap_ci_upper helpers.
cond_mean_ci = seg_mean.groupby(['condition']).agg(
    # hit=('hit', 'mean'),
    # hit_ci_lower=('hit', bootstrap_ci_lower),
    # hit_ci_upper=('hit', bootstrap_ci_upper), 
    e_target=('e_target', 'mean'),
    e_target_ci_lower=('e_target', bootstrap_ci_lower),
    e_target_ci_upper=('e_target', bootstrap_ci_upper), 
    target_hit_rate=('target_hit_rate', 'mean'),
    target_hit_rate_ci_lower=('target_hit_rate', bootstrap_ci_lower),
    target_hit_rate_ci_upper=('target_hit_rate', bootstrap_ci_upper),
    
    res_1_simi_info_z=('res_1_simi_info_z', 'mean'),
    res_1_simi_info_z_ci_lower=('res_1_simi_info_z', bootstrap_ci_lower),
    res_1_simi_info_z_ci_upper=('res_1_simi_info_z', bootstrap_ci_upper),
    res_1_MD_z=('res_1_MD_z', 'mean'),
    res_1_MD_z_ci_lower=('res_1_MD_z', bootstrap_ci_lower),
    res_1_MD_z_ci_upper=('res_1_MD_z', bootstrap_ci_upper)
).reset_index()

# replace e_target in recall to np.nan for now
cond_mean_ci['e_target'] = cond_mean_ci['e_target'].replace(0, np.nan)

# transform back to r
# Select the z-scored columns by their 'res_1_' prefix instead of by position: a positional
# slice (columns[7:]) silently picks the wrong columns as soon as an aggregate above is
# added or un-commented. The list is built once, before the '_r' columns are appended.
z_cols = [col for col in cond_mean_ci.columns if col.startswith('res_1_')]
for col in z_cols:
    cond_mean_ci[col+'_r'] = np.tanh(cond_mean_ci[col])

# add recall as upper limit
# cond_mean_ci['r_recall_res_1_simi_other_z_r'] = cond_mean_ci.query("condition=='r_recall'")['res_1_simi_other_z_r'].values[0]
# cond_mean_ci['p_recall_res_1_simi_other_z_r'] = cond_mean_ci.query("condition=='p_recall'")['res_1_simi_other_z_r'].values[0]

# map cond names
cond_mean_ci['cond_short'] = cond_mean_ci['condition'].map(cond_short_map)
cond_mean_ci['cond_long'] = cond_mean_ci['condition'].map(cond_long_map)
# cond_mean_ci['res_1_simi_self_z_r_range'] = cond_mean_ci['res_1_simi_self_z_ci_upper_r'] - cond_mean_ci['res_1_simi_self_z_ci_lower_r'] 

In [35]:
# Summary figure with three panels (E: target hit rate, F: precision, G: convergence),
# one point per condition with an error bar.
data = cond_mean_ci.query(f"condition=={cond_pr}")

# Shared x encoding for all panels.
base = alt.Chart(data).encode(
    alt.X(field='cond_short', type='nominal', sort=cond_pr_recall_short, 
          axis=alt.Axis(title='Condition', labelAngle=0)),
)

width = 200; height= 300; point_size = 120

# precision
# Template point layer: the other panels reuse it and override only y and tooltip.
prec_point = base.mark_circle(opacity=1, size=point_size, filled=True, strokeWidth=3.5).encode(
    alt.Y('res_1_simi_info_z_r:Q', axis=alt.Axis(title='Precision'), scale=alt.Scale(domain=(0, 1))), 
#     strokeWidth = alt.StrokeWidth('strokeWidth', legend=None),
    stroke = alt.Color('cond_long', scale=alt.Scale(domain=list(cond_long_3_map.values()), 
                                                    range=list(color_stroke_color_3_map.values()))),
    tooltip='res_1_simi_info_z_r',
    fill=alt.Color('cond_long', scale=alt.Scale(domain=list(cond_long_3_map.values()), 
                                                range=list(cond_color_3_map.values())),
                   legend=alt.Legend(direction='vertical', title="Condition"))
)

# Template error-bar layer.
# NOTE(review): cond_mean_ci has one row per condition, so ci0()/ci1() are applied to a
# single precomputed bound per x value and presumably just pass it through — confirm;
# encoding the *_ci_lower_r / *_ci_upper_r fields directly would state the intent more clearly.
prec_error_bar = base.mark_rule(size=2.6, color='black', opacity=0.4).encode(
    y='ci0(res_1_simi_info_z_ci_lower_r)',
    y2='ci1(res_1_simi_info_z_ci_upper_r)',  
)


# `data` is already restricted to cond_pr above, so this filter repeats that selection.
prec_plot = (prec_error_bar + prec_point).properties(width=width, height=height, title='F').\
    transform_filter(alt.FieldOneOfPredicate(field='condition', oneOf=cond_pr)
)

# target hit rate
target_hit_rate = prec_point.encode(
    alt.Y('target_hit_rate:Q', axis=alt.Axis(title='Target events hit rate'), 
          scale=alt.Scale(domain=(0, 1))
         ), 
    tooltip='target_hit_rate',
)

target_hit_rate_error = prec_error_bar.encode(
    y='ci0(target_hit_rate_ci_lower)',
    y2='ci1(target_hit_rate_ci_upper)',  
)
target_plot = (target_hit_rate_error + target_hit_rate).properties(width=width, height=height, title='E') # .\
    # transform_filter(alt.FieldOneOfPredicate(field='condition', oneOf=cond_pr)
#)

# convergence
conv_point = prec_point.encode(
    alt.Y('res_1_MD_z_r:Q', axis=alt.Axis(title='Convergence'), 
          scale=alt.Scale(domain=(0, 1))), 
    tooltip='res_1_MD_z_r',
)

conv_error = prec_error_bar.encode(
    y='ci0(res_1_MD_z_ci_lower_r)',
    y2='ci1(res_1_MD_z_ci_upper_r)',  
)

conv_plot = (conv_error + conv_point).properties(width=width, height=height, title='G')

# sig
# Significance annotation for the precision panel: a horizontal line at y = h1 between
# 'c-R' and 'c-P', with a '*' label anchored at 'c-R' and shifted right by dx.
# NOTE(review): the position and the '*' are hard-coded, not derived from a test in this cell.
h1 = 0.65
prec_sig = pd.DataFrame({'cond_short': ['c-R'], 'x2': ['c-P'], 'y':[h1], 'text':['*']})
prec_line = pd.DataFrame({'cond_short': ['c-R','c-P'], 'y':[h1]*2})

prec_sig_line = alt.Chart(prec_line).mark_line(strokeWidth=2, color='black').encode(
    alt.X(field='cond_short', type='nominal', sort=cond_pr_recall_short),
    y='y')

prec_sig_text = alt.Chart(prec_sig).mark_text(dx=25, fontSize=20).encode(       #dx=21
    alt.X(field='cond_short', type='nominal', sort=cond_pr_recall_short),
    y='y', text='text')



# # for denoting sig ***
# overlay = pd.DataFrame({'cond_short': ['R','P'], 'sig':[0.2,0.4], 'text':['    p','dfsdfsdfsdfsdfr']})
# line_test = alt.Chart(overlay).mark_line(color='red', strokeWidth=3).encode(x='cond_short:O', y='sig')
# text_test = alt.Chart(overlay).mark_text(color='blue', dx=4).encode(x='cond_short:O', y='sig', text='text')

# Concatenate the panels horizontally; `+` binds tighter than `|`, so the significance
# layers are added to the precision panel only.
p = (target_plot | prec_plot+prec_sig_line+prec_sig_text | conv_plot

).configure_legend(
    columns=1,
    columnPadding=40,
    symbolSize=100, 
    orient='none',
    legendX=800,  # 0
    legendY=0,   # 400
    symbolStrokeWidth=3,
    rowPadding=8,
    labelFontSize=18, 
    titleFontWeight='normal',
    titleFontSize=20,
    titlePadding=12,
    labelLimit=1000
).configure_axis(
    titleFontSize=20,
    labelFontSize=20,
    titleFontWeight='normal',
    labelFontWeight='normal',
).configure_concat(
    spacing=30
).configure_title(
    fontSize=28,
    anchor='start',
#     offset=20
)

# p.save('../figs/core/summary_results.svg') # use manual save instead
p

# references

### set refer type using seg_sc_pr

In [36]:
# seg_sc_pr <- seg_mean_pr <- seg_mean

# Segment-level means restricted to the conditions listed in cond_pr.
seg_mean_pr = seg_mean.query(f"condition in {cond_pr}")

# Wide -> long: one row per (segment, scene). The scene columns run from '51' to '1354';
# the column labels are taken from `df`, so seg_mean must carry the same scene columns.
seg_sc_pr = pd.melt(seg_mean_pr,
        id_vars=['story','segment_num','segment','cond_amount','cond_direction','condition','base_seg_num'],
        value_vars=df.loc[:,'51':'1354'], var_name='ref_sc_num',
                        value_name='ref_sc_hit')

# correct mismatched story-sc pairs
# Only story 3 and its annotated scene numbers are kept. .copy() detaches the filtered
# rows so the column assignments below write to an independent DataFrame rather than to
# a slice (avoids SettingWithCopyWarning).
seg_sc_pr['ref_sc_num'] = seg_sc_pr['ref_sc_num'].astype('int')
story3_sc_nums = df_annot.query("story==3")['num'].to_list()
seg_sc_pr = seg_sc_pr.query(f"story==3 and ref_sc_num=={story3_sc_nums}").copy()

# Scene number -> segment number (e.g. 953 -> 9.5), then the signed lag to the base segment.
seg_sc_pr['ref_seg_num'] = seg_sc_pr['ref_sc_num'].apply(num2seg)
seg_sc_pr['ref_step_base'] = seg_sc_pr['ref_seg_num'] - seg_sc_pr['base_seg_num']


# discard wrong direction ref in guesses, preserve all in recall (also preserve base_seg for plotting)
seg_sc_pr = seg_sc_pr.query(
    "(cond_amount!='p_recall/r_recall' & (cond_direction=='f' & ref_seg_num>=base_seg_num | cond_direction=='b' & ref_seg_num<=base_seg_num)) | cond_amount=='p_recall/r_recall'").\
    reset_index(drop=True)


In [37]:
# Row-wise call of get_step_storyline(base_seg_num, ref_sc_num); that helper is defined
# outside this part of the notebook. Presumably a storyline-relative lag — confirm there.
seg_sc_pr['ref_step_base_sl'] = seg_sc_pr.apply(
    lambda row: get_step_storyline(row['base_seg_num'], row['ref_sc_num']), axis=1)

In [40]:
# add 'ref_seg_type' to seg_sc_pr

def add_ref_seg_type(row):
    """Label a row by where its referenced segment lies relative to the base segment.

    Reads row['ref_step_base'] (signed lag) and writes row['ref_seg_type']:
      0 -> 'self', +/-0.5 -> 'gap', +/-1 -> 'target',
      beyond +/-1 -> 'ahead_gap' for half steps, 'ahead_main' for whole steps.
    Any other lag (e.g. NaN) leaves 'ref_seg_type' untouched. Returns the same row.
    """
    step = row['ref_step_base']
    distance = abs(step)

    label = None
    if step == 0:
        label = 'self'
    elif distance == 0.5:
        label = 'gap'
    elif distance == 1:
        label = 'target'
    elif distance > 1:
        # Python's % keeps the fractional part non-negative, so -1.5 % 1 is 0.5 as well.
        label = {0.5: 'ahead_gap', 0: 'ahead_main'}.get(step % 1)

    if label is not None:
        row['ref_seg_type'] = label
    return row

# Pre-create the column as NaN so rows that match no category in add_ref_seg_type keep NaN,
# then label every row (row-wise apply returns a new DataFrame).
seg_sc_pr['ref_seg_type'] = np.nan
seg_sc_pr = seg_sc_pr.apply(add_ref_seg_type, axis=1)

### read reference files

In [7]:
# references count
# Reads the 'references_full' sheet of the workbook.
df_refer_full = pd.read_excel('../data/rep/TheChair.xlsx', sheet_name='references_full')

In [12]:
# Count the non-null 'time' entries per (episode, direction) and write the table to CSV without the index.
df_refer_full.groupby(['episode','direction'])['time'].count().reset_index().to_csv("../data/the_chair/the_chair_manual_reference_counts.csv", index=False)

In [53]:
# df_refer_raw_bounded <- df_refer_raw

# Load the reference table from the 'references' sheet.
df_refer_raw = pd.read_excel('../data/rep/TheChair.xlsx', sheet_name='references')

# add storyline
# Look up the storyline of the source scene, then of the referenced scene; each lookup
# is a left merge against df_annot followed by a rename of the merged-in column.
df_refer_raw = pd.merge(
    df_refer_raw, df_annot[['num','storyline']],
    left_on='source_num', right_on='num', how='left',
).rename(columns={'storyline':'source_storyline'})

df_refer_raw = pd.merge(
    df_refer_raw, df_annot[['num','storyline']],
    left_on='refer_num', right_on='num', how='left',
).rename(columns={'storyline':'refer_storyline'})

# add type c
# Social / non-social label of the referenced scene.
df_refer_raw = pd.merge(
    df_refer_raw, df_annot[['num','social/non-social']],
    left_on='refer_num', right_on='num', how='left',
)

# df_refer_raw.to_csv("refer_check.csv")

# Rows flagged as included, and the subset of those not marked as amended.
df_refer_raw_bounded = df_refer_raw.query("is_included==1")
df_refer_raw_bounded_ori = df_refer_raw_bounded.query("amended != 1")

# Number of references per direction (displayed as the cell output).
df_refer_raw_bounded_ori.groupby(['direction'])['seg_1'].count()

direction
b    46
f     7
Name: seg_1, dtype: int64

## feature analysis

### social

In [43]:
# Contingency table: reference direction (rows) by social/non-social label of the referenced scene (columns).
pd.crosstab(df_refer_raw_bounded_ori['direction'], df_refer_raw_bounded_ori['social/non-social'])

social/non-social,n,s
direction,Unnamed: 1_level_1,Unnamed: 2_level_1
b,17,29
f,0,7


In [44]:
# NOTE(review): this import sits mid-notebook; the notebook's other imports are in the first cell.
from scipy.stats import chi2_contingency
# NOTE(review): hand-typed 2x2 table; these counts do not match the crosstab shown above,
# and their source is not visible in this part of the notebook — document where they come from.
obs = np.array([[30,22], [15,15]])
chi2_contingency(obs)

(0.1970285670285673,
 0.6571300677827455,
 1,
 array([[28.53658537, 23.46341463],
        [16.46341463, 13.53658537]]))

In [45]:
# Chi-square test of independence between reference direction and social/non-social type.
# The contingency table is computed from the data rather than re-typed by hand from the
# crosstab output, so the test cannot drift out of sync when the reference sheet changes.
obs = pd.crosstab(df_refer_raw_bounded_ori['direction'],
                  df_refer_raw_bounded_ori['social/non-social']).to_numpy()
chi2_contingency(obs)

(2.3011876852190154,
 0.1292751153315131,
 1,
 array([[14.75471698, 31.24528302],
        [ 2.24528302,  4.75471698]]))

In [46]:
## refer step plot

# Frequency of each reference lag ('step') among the non-amended references, as a
# two-column frame (step, counts), with the direction implied by the sign of the lag
# ('p' for positive, 'r' otherwise).
df_refer_dist = (
    df_refer_raw_bounded
    .query("amended != 1")['step']
    .value_counts()
    .rename_axis('step')
    .reset_index(name='counts')
    .assign(condition=lambda d: np.where(d['step']>0, 'p','r'))
)
# Adds the 'color' column in place.
set_color(df_refer_dist, 'step')

## figure S7B

In [47]:
# Bar chart of reference counts per lag, coloured by the precomputed 'color' column
# (scale=None makes Altair use the column values as colours directly).
w=635; h=180

p = alt.Chart(df_refer_dist).mark_bar(size=10).encode(
    # Quantitative lag axis fixed to [-12, 12].
    alt.X('step:Q', axis=alt.Axis(title='Lag', labelAngle=0, grid=False, tickMinStep=1), scale=alt.Scale(domain=(-12, 12))),
    # Count axis fixed to [0, 20].
    alt.Y('counts:Q', axis=alt.Axis(title='Count'), 
          scale=alt.Scale(domain=(0, 20))
        ), 
    alt.Color('color', scale=None),
).properties(width=w, height=h, 
# title={
#       "text": [""], 
      # "subtitle": ["Reference rate"]}
).configure_axis(
    titleFontSize=14,
    labelFontSize=14,
    titleFontWeight='normal'
)
p

In [54]:
def get_seg_sl_index(segnum, on_seg_sl):
    """Return the position of `segnum` relative to the segments of one storyline.

    Parameters
    ----------
    segnum : segment number to locate (whole or half number).
    on_seg_sl : list of the storyline's segment numbers, in ascending order.

    Returns
    -------
    The list index of `segnum` if it is one of the storyline's segments; otherwise the
    index of the last listed segment smaller than `segnum`, plus 0.5; -0.5 when no
    listed segment is smaller (including an empty list).
    """
    if segnum in on_seg_sl:
        return on_seg_sl.index(segnum)

    # Explicit "nothing to the left" check instead of a bare `except:`, which would also
    # hide unrelated errors (bad argument types, KeyboardInterrupt, ...).
    earlier = [seg for seg in on_seg_sl if seg < segnum]
    if not earlier:
        return -0.5
    return on_seg_sl.index(earlier[-1]) + 0.5


In [55]:
# add 'refer_type', 'refer_boost_type', 'refer_boost_type_tri', 'refer_boost_type_sl' to seg_sc_pr

def add_refer_type(row, df=df_refer_raw_bounded):
    """Annotate one (base segment, referenced scene) row with reference-related labels.

    Intended for row-wise use via ``DataFrame.apply(..., axis=1)``.

    Parameters
    ----------
    row : row providing 'base_seg_num', 'segment_num', 'ref_seg_num', 'ref_sc_num',
        'condition' and 'story'.
    df : reference table with (at least) 'story', 'seg_1', 'seg_2', 'refer_num',
        'source_num', 'refer_storyline' and 'main_or_gap'. The default is bound to
        df_refer_raw_bounded as it exists when this cell is run.

    Returns
    -------
    The same row with these keys set: 'main_or_gap', 'refer_type', 'refer_type_base_only',
    'referring_type', 'rr_type', 'rrr_type', 'refer_list', 'referring_list',
    'refer_boost_step_left', 'refer_boost_step_right', 'refer_boost_step',
    'refer_boost_step_sl', 'refer_boost_type', 'refer_boost_type_tri', 'refer_boost_type_sl'.

    Notes
    -----
    - Also reads the notebook-level df_annot.
    - 99 is used as the sentinel for "no reference on that side / in that storyline".
    - NOTE(review): only conditions 'r0', 'r', 'p0' and 'p' define the reference sets; any
      other condition reaches `refer_segments[...]` with the name unbound and raises.
    - NOTE(review): presumably seg_1 is the referring (source) segment and seg_2 the
      referenced segment of each reference — confirm against the spreadsheet.
    """
    
    base_seg_num = row['base_seg_num']
    segment_num = row['segment_num']
    ref_seg_num = row['ref_seg_num']
    ref_sc_num = row['ref_sc_num']

    # storyline
    # Storyline of the referenced scene, and the (integer) segment numbers annotated 'on' for it.
    storyline = df_annot.query("num==@ref_sc_num")['storyline'].to_list()[0]
    on_seg_sl = list(df_annot.query("storyline==@storyline and on_off=='on'")['seg_num'].unique())
    on_seg_sl = [int(s) for s in on_seg_sl]

#     print(row.name, row['condition'], base_seg_num, ref_seg_num)
    
    # main or gap
    # Whole-numbered segments are 'main', half-numbered ones are 'gap'.
    if ref_seg_num % 1 == 0:
        row['main_or_gap'] = 'main'
    elif ref_seg_num % 1 == 0.5:
        row['main_or_gap'] = 'gap'
    
    # get refer segments
    # Backward conditions: references with seg_1 at/after the base segment and seg_2 before it.
    if row['condition'] in ['r0','r']:
        refer_segments = df.\
            query(f"story=={row['story']} & seg_1>={base_seg_num} & seg_2<{base_seg_num}")
        # NOTE(review): the storyline variant filters on refer_storyline only, not on story.
        refer_segments_sl = df.\
            query(f"refer_storyline=={storyline} & seg_1>={base_seg_num} & seg_2<{base_seg_num}")
        refer_segments_base_only = df.\
            query(f"story=={row['story']} & seg_1=={base_seg_num} & seg_2<{base_seg_num}")  # only from base segment
        referring_segments = df.\
            query(f"story=={row['story']} & main_or_gap=='main' & seg_2>={base_seg_num} & seg_1<{base_seg_num}")  # swap seg_1 and seg_2, referenced main events only
        
    # Forward conditions: the mirror image of the selections above.
    elif row['condition'] in ['p0','p']:
        refer_segments = df.\
            query(f"story=={row['story']} & seg_1<={base_seg_num} & seg_2>{base_seg_num}")
        refer_segments_sl = df.\
            query(f"refer_storyline=={storyline} & seg_1<={base_seg_num} & seg_2>{base_seg_num}")
        refer_segments_base_only = df.\
            query(f"story=={row['story']} & seg_1=={base_seg_num} & seg_2>{base_seg_num}")
        referring_segments = df.\
            query(f"story=={row['story']} & main_or_gap=='main' & seg_2<={base_seg_num} & seg_1>{base_seg_num}")

    # Scene numbers on the referenced side (refer_num) and on the referring side (source_num).
    refer_list = refer_segments['refer_num'].to_numpy()
    refer_list_base_only = refer_segments_base_only['refer_num'].to_numpy()   
    referring_list = referring_segments['source_num'].to_numpy()

    # storyline
    # Distance, in storyline-relative index units, from this scene's segment to the
    # nearest referenced segment of the same storyline.
    refer_list_sl = refer_segments_sl['refer_num'].to_list()
    refer_seg_list_sl = [num2seg(sc) for sc in refer_list_sl]

    refer_sc_sl_indexes = [get_seg_sl_index(seg, on_seg_sl) for seg in refer_seg_list_sl]
    ref_sc_sl_index = get_seg_sl_index(ref_seg_num, on_seg_sl)
    if len(refer_sc_sl_indexes) == 0:
        refer_boost_step_sl = 99
    else:
        refer_boost_step_sl = abs(np.array(refer_sc_sl_indexes) - np.array(ref_sc_sl_index)).min()

    # refer_type
    if ref_sc_num in refer_list:
        row['refer_type'] = 'referred'
    else:
        row['refer_type'] = 'none'
    
    # refer_type_base_only
    if ref_sc_num in refer_list_base_only:
        row['refer_type_base_only'] = 'referred'
    else:
        row['refer_type_base_only'] = 'none'

    # referring_type
    if ref_sc_num in referring_list:
        row['referring_type'] = 'referring' 
    else:
        row['referring_type'] = 'none'

    # refer_referring_type (referred has priority)
    if ref_sc_num in refer_list:
        row['rr_type'] = 'referred' 
    elif ref_sc_num in referring_list:
        row['rr_type'] = 'referring'
    else:
        row['rr_type'] = 'none'

    # Same combination with the opposite priority (referring wins over referred).
    if ref_sc_num in referring_list:
        row['rrr_type'] = 'referring' 
    elif ref_sc_num in refer_list:
        row['rrr_type'] = 'referred'
    else:
        row['rrr_type'] = 'none'
    
 
    # The scene-number arrays themselves are stored on the row.
    row['refer_list'] = refer_list
    row['referring_list'] = referring_list
    
    refer_seg_list = refer_segments['seg_2'].to_numpy() # both d and i
    
    # Referenced segments at or to the left / right of this scene's segment.
    refer_seg_list_left = refer_seg_list[refer_seg_list<=ref_seg_num]
    refer_seg_list_right = refer_seg_list[refer_seg_list>=ref_seg_num]
    
    # refer_boost_step
    # Distance to the nearest referenced segment on each side (99 if there is none).
    if len(refer_seg_list_left) == 0:
        row['refer_boost_step_left'] = 99
    else:
        row['refer_boost_step_left'] = ref_seg_num - refer_seg_list_left.max()
    
    if len(refer_seg_list_right) == 0:
        row['refer_boost_step_right'] = 99
    else:
        row['refer_boost_step_right'] = refer_seg_list_right.min() - ref_seg_num 
        
    # Nearest referenced segment on either side.
    row['refer_boost_step'] = np.nanmin([row['refer_boost_step_left'], row['refer_boost_step_right']])

    row['refer_boost_step_sl'] = refer_boost_step_sl
    
    
    # refer_boost_type 
    # The 'main' and 'gap' branches are identical except for the neighbor threshold
    # (<= 1 for main segments, <= 0.5 for gap segments).
    if row['main_or_gap'] == 'main':
        if ref_sc_num in refer_list:
            row['refer_boost_type'] = 'referred'
            row['refer_boost_type_tri'] = 'referred'
        elif row['refer_boost_step'] == 0:
            row['refer_boost_type'] = 'same_seg'
            row['refer_boost_type_tri'] = 'neighbor'
        elif row['refer_boost_step'] <= 1:
            row['refer_boost_type'] = 'neighbor'
            row['refer_boost_type_tri'] = 'neighbor'
        else:
            row['refer_boost_type'] = 'none'      
            row['refer_boost_type_tri'] = 'none'

        # storyline
        if ref_sc_num in refer_list_sl:
            row['refer_boost_type_sl'] = 'referred'
        elif row['refer_boost_step_sl'] <= 1:
            row['refer_boost_type_sl'] = 'neighbor'
        else:
            row['refer_boost_type_sl'] = 'none'      
    
    if row['main_or_gap'] == 'gap':
        if ref_sc_num in refer_list:
            row['refer_boost_type'] = 'referred'
            row['refer_boost_type_tri'] = 'referred'
        elif row['refer_boost_step'] == 0:
            row['refer_boost_type'] = 'same_seg'
            row['refer_boost_type_tri'] = 'neighbor'
        elif row['refer_boost_step'] <= 0.5:
            row['refer_boost_type'] = 'neighbor'
            row['refer_boost_type_tri'] = 'neighbor'
        else:
            row['refer_boost_type'] = 'none' 
            row['refer_boost_type_tri'] = 'none'

        # storyline
        if ref_sc_num in refer_list_sl:
            row['refer_boost_type_sl'] = 'referred'
        elif row['refer_boost_step_sl'] <= 0.5:
            row['refer_boost_type_sl'] = 'neighbor'
        else:
            row['refer_boost_type_sl'] = 'none'   
    
    # correct NA refer type
    # Rows whose referenced segment is the base segment itself or lies on the wrong side
    # of it get the string 'NA' for every label.
    # NOTE(review): 'rrr_type' is not reset here, unlike 'rr_type'.
    if (row['condition'] in ['p0','p','p_r'] and ref_seg_num <= base_seg_num) or \
       (row['condition'] in ['r0','r','r_p'] and ref_seg_num >= base_seg_num) or \
       (row['condition'] in ['p_recall','r_recall'] and ref_seg_num == segment_num):

        row['refer_type'] = 'NA'
        row['refer_type_base_only'] = 'NA'
        row['referring_type'] = 'NA'
        row['rr_type'] = 'NA'
        row['refer_boost_type'] = 'NA'
        row['refer_boost_type_tri'] = 'NA'
        row['refer_boost_type_sl'] = 'NA'

    return row

In [56]:
# Row-wise annotation; add_refer_type runs several DataFrame queries per row, so this cell is slow on large frames.
seg_sc_pr = seg_sc_pr.apply(add_refer_type, axis=1)

In [57]:
# add more fields to seg_sc_pr
# Identity mapping over the three known labels; any other value becomes NaN.
seg_sc_pr['refer_type_merge'] = seg_sc_pr['refer_type'].\
    map({'referred':'referred','none':'none','NA':'NA'})

# Numeric indicators (1 / 0). The 'NA' label is not in these mappings and therefore becomes NaN.
seg_sc_pr['refer_type_point_merge'] = seg_sc_pr['refer_type'].\
    map({'referred':1, 'none':0})

seg_sc_pr['refer_type_base_only_point'] = seg_sc_pr['refer_type_base_only'].\
    map({'referred':1, 'none':0})

seg_sc_pr['referring_type_point'] = seg_sc_pr['referring_type'].\
    map({'referring':1, 'none':0})

# full point or half point scene
# seg_sc_pr['sc_point'] = np.where(seg_sc_pr['ref_sc_num']//10%10%5 == 0, 1, 0.5)  # incorrect
# A scene is worth 1 point when df_annot marks it partial == 'ori', otherwise 0.5.
seg_sc_pr = pd.merge(seg_sc_pr, df_annot[['partial','num']], left_on='ref_sc_num', right_on=['num'], how='left')
seg_sc_pr['sc_point'] = np.where(seg_sc_pr['partial']=='ori', 1, 0.5)

# refer_point_c_merge: 0 0.5 1
seg_sc_pr['refer_point_c_merge'] = seg_sc_pr['refer_type_point_merge']*seg_sc_pr['sc_point']

# refer_point_base_only_c
seg_sc_pr['refer_point_base_only_c'] = seg_sc_pr['refer_type_base_only_point']*seg_sc_pr['sc_point']

# referring_point_c
seg_sc_pr['referring_point_c'] = seg_sc_pr['referring_type_point']*seg_sc_pr['sc_point']


# Unsigned lag to the base segment.
seg_sc_pr['ref_step_base_abs'] = abs(seg_sc_pr['ref_step_base'])

# set color
set_color(seg_sc_pr, 'ref_step_base')

# seg_sc_pr_bounded
# Keep only scenes whose df_annot on_off value is 'on' or 'off'.
bounded_segments = df_annot.query("on_off==['on','off']")['num'].to_list()
seg_sc_pr_bounded = seg_sc_pr.query("ref_sc_num==@bounded_segments")
# seg_sc_pr_bounded.to_csv("../data/processed/seg_sc_pr_bounded.csv", index=False)
seg_sc_pr_bounded.shape

(3790, 44)

### melt df and merge cols from seg_sc_pr for single trial data

In [58]:
# df_pr <- df
# df_seg_sc_pr <- seg_sc_pr, df_pr (single sub data)

# Single-trial rows restricted to the conditions listed in cond_pr.
df_pr = df.query(f"condition in {cond_pr}")

# Wide -> long: one row per (trial, scene); scene columns run from '51' to '1354'.
df_seg_sc_pr = pd.melt(df_pr, id_vars=['sub','story','segment_num','segment','segment_count',
                        'cond_direction','cond_amount','condition','base_seg_num','base_segment','res_corrected'],
        value_vars=df.loc[:,'51':'1354'], var_name='ref_sc_num',
                        value_name='ref_sc_hit')

# correct mismatched story-sc pairs
# Only story 3 and its annotated scene numbers are kept. .copy() detaches the filtered
# rows so the column assignments below write to an independent DataFrame rather than to
# a slice (avoids SettingWithCopyWarning).
df_seg_sc_pr['ref_sc_num'] = df_seg_sc_pr['ref_sc_num'].astype('int')
story3_sc_nums = df_annot.query("story==3")['num'].to_list()
df_seg_sc_pr = df_seg_sc_pr.query(f"story==3 and ref_sc_num=={story3_sc_nums}").copy()

# Scene number -> segment number, signed lags to the current and the base segment, and a
# '<story>_<scene>' identifier.
df_seg_sc_pr['ref_seg_num'] = df_seg_sc_pr['ref_sc_num'].apply(num2seg)
df_seg_sc_pr['ref_step'] = df_seg_sc_pr['ref_seg_num'] - df_seg_sc_pr['segment_num']
df_seg_sc_pr['ref_step_base'] = df_seg_sc_pr['ref_seg_num'] - df_seg_sc_pr['base_seg_num']
df_seg_sc_pr['ref_sc_num_full'] = df_seg_sc_pr['story'].astype(str)+'_'+df_seg_sc_pr['ref_sc_num'].astype(str)

# discard wrong direction ref in guesses, preserve all in recall
# (strict inequalities: unlike seg_sc_pr, the base segment itself is excluded here)
df_seg_sc_pr = df_seg_sc_pr.query(
    "(cond_amount!='p_recall/r_recall' & (cond_direction=='f' & ref_seg_num>base_seg_num | cond_direction=='b' & ref_seg_num<base_seg_num)) | cond_amount=='p_recall/r_recall'").\
    reset_index(drop=True)

# merge cols from seg_sc_pr
cols_for_merge = ['condition','segment','ref_sc_num','ref_seg_type', 'main_or_gap',
       'refer_type','refer_type_base_only','referring_type','rr_type','refer_list','referring_list','refer_boost_step_left', 'refer_boost_step_right',
       'refer_boost_step', 'refer_boost_type','refer_boost_type_tri','refer_boost_type_sl', 'refer_type_merge',
       'refer_type_point_merge',
       'sc_point', 'refer_point_c_merge', 'ref_step_base_abs','ref_step_base_sl',
       'color']

# Segment-level annotations are attached to every trial via (condition, segment, scene).
df_seg_sc_pr = pd.merge(df_seg_sc_pr, seg_sc_pr[cols_for_merge],
                on=['condition','segment','ref_sc_num'], how='left')

df_seg_sc_pr = pd.merge(df_seg_sc_pr, df_annot[['story','num']],  left_on=['story','ref_sc_num'],
                right_on=['story','num'], how='left')

# Hit rate = hit score divided by the points the scene is worth.
df_seg_sc_pr['hit_rate'] =  df_seg_sc_pr['ref_sc_hit'] / df_seg_sc_pr['sc_point']

# df_seg_sc_pr_target = df_seg_sc_pr.query("ref_seg_type=='target'")
# df_seg_sc_pr_target.to_csv("../data/R/df_seg_sc_pr_target.csv", index=False)
# df_seg_sc_pr_target.shape

In [59]:
# Keep only the p0 and r0 trials.
df_seg_sc_p0r0 = df_seg_sc_pr.loc[df_seg_sc_pr['condition'].isin(['p0','r0'])]
# Scene numbers whose df_annot on_off value is 'on' or 'off'.
bounded_segments = df_annot.query("on_off==['on','off']")['num'].to_list()
df_seg_sc_p0r0_bounded = df_seg_sc_p0r0.query("ref_sc_num==@bounded_segments")
df_seg_sc_p0r0_bounded.shape

(32196, 40)

In [68]:
# Writes the frame with its index stored under the column name 'no'.
# NOTE(review): this cell's execution count (68) is higher than that of the following cells,
# so the saved file may have been produced after the later merge — confirm which columns
# the CSV is meant to contain.
df_seg_sc_p0r0_bounded.to_csv("../data/rep/df_seg_sc_p0r0_bounded.csv", index_label='no')

### for R

#### feature type

In [60]:
# Attach the social/non-social label of each referenced scene.
# NOTE(review): the frame is overwritten in place, so this cell is not idempotent —
# re-running it merges 'num' and 'social/non-social' a second time.
df_seg_sc_p0r0_bounded = pd.merge(df_seg_sc_p0r0_bounded, df_annot[['num','social/non-social']], left_on='ref_sc_num', right_on='num', how='left')

In [61]:
# df_seg_seg_social_p0r0_for_R <- df_seg_sc_p0r0_bounded
# Aggregate scenes to one row per trial x referenced segment x social/non-social label
# (scene count, total points available, total hits), then add the pair identifier and
# the centred predictors.
df_seg_seg_social_p0r0_for_R = (
    df_seg_sc_p0r0_bounded
    .groupby(['sub','condition','story','segment',
              'segment_count','base_seg_num','base_segment','ref_seg_num','ref_step_base',
              'ref_step_base_abs','ref_seg_type','main_or_gap','social/non-social'])
    .agg(
        total_count=('condition', 'count'),
        total_points=('sc_point', 'sum'),
        total_hits=('ref_sc_hit', 'sum'),
    )
    .reset_index()
    .assign(
        base_refer_seg_pair=lambda d: d['base_segment'] + '_' + d['ref_seg_num'].astype('str'),
        segment_count_demean=lambda d: d['segment_count']-6.5,
        ref_step_base_abs_demean=lambda d: d['ref_step_base_abs']-5.625,
    )
)

# Only main (whole-numbered) segments are exported for R.
df_seg_seg_social_p0r0_main_for_R = df_seg_seg_social_p0r0_for_R.query("main_or_gap=='main'")

df_seg_seg_social_p0r0_main_for_R.to_csv('../data/rep/R/df_seg_seg_social_p0r0_main_for_R.csv', index=False)
df_seg_seg_social_p0r0_main_for_R.shape

(3768, 19)

#### refer type

In [62]:
# df_seg_seg_type_p0r0_for_R <- df_seg_sc_p0r0_bounded
# Aggregate scenes to one row per trial x referenced segment x refer_type_merge label
# (scene count, total points available, total hits), then add the pair identifier and
# the centred predictors. Main and gap segments are both exported.
df_seg_seg_type_p0r0_for_R = (
    df_seg_sc_p0r0_bounded
    .groupby(['sub','condition','story','segment',
              'segment_count','base_seg_num','base_segment','ref_seg_num','ref_step_base',
              'ref_step_base_abs','ref_seg_type','main_or_gap','refer_type_merge'])
    .agg(
        total_count=('condition', 'count'),
        total_points=('sc_point', 'sum'),
        total_hits=('ref_sc_hit', 'sum'),
    )
    .reset_index()
    .assign(
        base_refer_seg_pair=lambda d: d['base_segment'] + '_' + d['ref_seg_num'].astype('str'),
        segment_count_demean=lambda d: d['segment_count']-6.5,
        ref_step_base_abs_demean=lambda d: d['ref_step_base_abs']-5.625,
    )
)

df_seg_seg_type_p0r0_for_R.to_csv('../data/rep/R/df_seg_seg_type_p0r0_for_R.csv', index=False)

In [63]:
# df_seg_seg_rr_type_p0r0_for_R <- df_seg_sc_p0r0_bounded
# Same aggregation as the other "for R" tables, split by 'rr_type'.
df_seg_seg_rr_type_p0r0_for_R = (
    df_seg_sc_p0r0_bounded
    .groupby([
        'sub', 'condition', 'story', 'segment',
        'segment_count', 'base_seg_num', 'base_segment', 'ref_seg_num', 'ref_step_base',
        'ref_step_base_abs', 'ref_seg_type', 'main_or_gap', 'rr_type',
    ])
    .agg(
        total_count=('condition', 'count'),
        total_points=('sc_point', 'sum'),
        total_hits=('ref_sc_hit', 'sum'),
    )
    .reset_index()
    .assign(
        # label of the (base segment, referenced segment) pair
        base_refer_seg_pair=lambda d: d['base_segment'] + '_' + d['ref_seg_num'].astype('str'),
        # shifted predictors; the constants are presumably the predictor means — TODO confirm
        segment_count_demean=lambda d: d['segment_count'] - 6.5,
        ref_step_base_abs_demean=lambda d: d['ref_step_base_abs'] - 5.625,
    )
)

# Only the rows flagged main_or_gap == 'main' are exported for R.
df_seg_seg_rr_type_p0r0_main_for_R = df_seg_seg_rr_type_p0r0_for_R.query("main_or_gap=='main'")

df_seg_seg_rr_type_p0r0_main_for_R.to_csv('../data/rep/R/df_seg_seg_rr_type_p0r0_main_for_R.csv', index=False)
df_seg_seg_rr_type_p0r0_main_for_R.shape

(4073, 19)

In [64]:
# df_seg_seg_boost_type_p0r0_for_R <- df_seg_sc_p0r0_bounded
# Same aggregation as the other "for R" tables, split by both 'refer_boost_type'
# and 'refer_boost_type_tri'.
df_seg_seg_boost_type_p0r0_for_R = (
    df_seg_sc_p0r0_bounded
    .groupby([
        'sub', 'condition', 'story', 'segment',
        'segment_count', 'base_segment', 'base_seg_num', 'ref_seg_num', 'ref_step_base',
        'ref_step_base_abs', 'ref_seg_type', 'main_or_gap', 'refer_boost_type', 'refer_boost_type_tri',
    ])
    .agg(
        total_count=('condition', 'count'),
        total_points=('sc_point', 'sum'),
        total_hits=('ref_sc_hit', 'sum'),
    )
    .reset_index()
    .assign(
        # label of the (base segment, referenced segment) pair
        base_refer_seg_pair=lambda d: d['base_segment'] + '_' + d['ref_seg_num'].astype('str'),
        # shifted predictors; the constants are presumably the predictor means — TODO confirm
        segment_count_demean=lambda d: d['segment_count'] - 6.5,
        ref_step_base_abs_demean=lambda d: d['ref_step_base_abs'] - 5.625,
    )
)

df_seg_seg_boost_type_p0r0_for_R.to_csv('../data/rep/R/df_seg_seg_boost_type_p0r0_for_R.csv', index=False)

In [65]:
# df_seg_seg_boost_type_sl_p0r0_for_R <- df_seg_sc_p0r0_bounded
# Same aggregation as the other "for R" tables, split by 'refer_boost_type_sl'.
df_seg_seg_boost_type_sl_p0r0_for_R = df_seg_sc_p0r0_bounded.groupby(['sub','condition','story','segment',
            'segment_count','base_segment','base_seg_num','ref_seg_num','ref_step_base',
            'ref_step_base_abs','ref_seg_type','main_or_gap','refer_boost_type_sl']).agg(
                total_count=('condition', 'count'),
                total_points=('sc_point', 'sum'),
                total_hits=('ref_sc_hit', 'sum')
                ).reset_index()

# label of the (base segment, referenced segment) pair
df_seg_seg_boost_type_sl_p0r0_for_R['base_refer_seg_pair'] = df_seg_seg_boost_type_sl_p0r0_for_R['base_segment'] + \
    '_' + df_seg_seg_boost_type_sl_p0r0_for_R['ref_seg_num'].astype('str')

df_seg_seg_boost_type_sl_p0r0_for_R['segment_count_demean'] = df_seg_seg_boost_type_sl_p0r0_for_R['segment_count']-6.5
# Centre with 5.625, the constant every sibling "for R" table built from the same frame uses.
# This table previously subtracted 4.625, which shifted the predictor by one lag relative to
# the other exports. NOTE(review): confirm 5.625 is the intended centring value.
df_seg_seg_boost_type_sl_p0r0_for_R['ref_step_base_abs_demean'] = df_seg_seg_boost_type_sl_p0r0_for_R['ref_step_base_abs']-5.625

df_seg_seg_boost_type_sl_p0r0_for_R.to_csv('../data/rep/R/df_seg_seg_boost_type_sl_p0r0_for_R.csv', index=False)

### seg_seg 

In [66]:
# seg_seg_boost_type_sl_pr
# Aggregate seg_sc_pr_bounded per grouping-key combination (split by 'refer_boost_type_sl')
# and add the corrected hit rate = summed hits / summed points.
# NOTE(review): the same table is built again, identically, in the next cell.
seg_seg_boost_type_sl_pr = (
    seg_sc_pr_bounded
    .groupby([
        'cond_direction', 'condition', 'story', 'segment', 'base_seg_num', 'ref_seg_num',
        'ref_step_base', 'ref_step_base_abs', 'ref_seg_type', 'main_or_gap',
        'refer_boost_type_sl', 'color',
    ])
    .agg(
        total_count=('condition', 'count'),
        total_points=('sc_point', 'sum'),
        total_hits=('ref_sc_hit', 'sum'),
    )
    .reset_index()
    .assign(sc_hit_rate_c=lambda d: d['total_hits'] / d['total_points'])
)

In [67]:
# seg_seg_pr <- seg_sc_pr_bounded (average scenes in same segments in same trials)
seg_seg_pr = (
    seg_sc_pr_bounded
    .groupby([
        'cond_direction', 'condition', 'story', 'segment', 'base_seg_num', 'ref_seg_num',
        'ref_step_base', 'ref_step_base_abs', 'ref_seg_type', 'main_or_gap', 'color',
    ])
    .agg(
        total_count=('condition', 'count'),
        total_points=('sc_point', 'sum'),
        total_hits=('ref_sc_hit', 'sum'),
        # refer_points=('refer_point_c', 'sum'),
        refer_points_merge=('refer_point_c_merge', 'sum'),
        refer_points_base_only=('refer_point_base_only_c', 'sum'),
        referring_points=('referring_point_c', 'sum'),
    )
    .reset_index()
)

# The three tables below repeat the aggregation with one extra type column in the keys
# and carry a corrected hit rate (summed hits / summed points).

# seg_seg_type_pr
seg_seg_type_pr = (
    seg_sc_pr_bounded
    .groupby([
        'cond_direction', 'condition', 'story', 'segment', 'base_seg_num', 'ref_seg_num',
        'ref_step_base', 'ref_step_base_abs', 'ref_seg_type', 'main_or_gap',
        'refer_type_merge', 'color',
    ])
    .agg(
        total_count=('condition', 'count'),
        total_points=('sc_point', 'sum'),
        total_hits=('ref_sc_hit', 'sum'),
    )
    .reset_index()
    .assign(sc_hit_rate_c=lambda d: d['total_hits'] / d['total_points'])
)

# seg_seg_boost_type_pr
seg_seg_boost_type_pr = (
    seg_sc_pr_bounded
    .groupby([
        'cond_direction', 'condition', 'story', 'segment', 'base_seg_num', 'ref_seg_num',
        'ref_step_base', 'ref_step_base_abs', 'ref_seg_type', 'main_or_gap',
        'refer_boost_type_tri', 'color',
    ])
    .agg(
        total_count=('condition', 'count'),
        total_points=('sc_point', 'sum'),
        total_hits=('ref_sc_hit', 'sum'),
    )
    .reset_index()
    .assign(sc_hit_rate_c=lambda d: d['total_hits'] / d['total_points'])
)

# seg_seg_boost_type_sl_pr
seg_seg_boost_type_sl_pr = (
    seg_sc_pr_bounded
    .groupby([
        'cond_direction', 'condition', 'story', 'segment', 'base_seg_num', 'ref_seg_num',
        'ref_step_base', 'ref_step_base_abs', 'ref_seg_type', 'main_or_gap',
        'refer_boost_type_sl', 'color',
    ])
    .agg(
        total_count=('condition', 'count'),
        total_points=('sc_point', 'sum'),
        total_hits=('ref_sc_hit', 'sum'),
    )
    .reset_index()
    .assign(sc_hit_rate_c=lambda d: d['total_hits'] / d['total_points'])
)


def add_addi_fields(row, df1=seg_seg_type_pr, df2=seg_seg_boost_type_pr, df3=seg_seg_boost_type_sl_pr,
                    boost_type_col1='refer_boost_type_tri', boost_type_col2='refer_boost_type_sl'):
    """Add five lookup columns to one row; meant for ``DataFrame.apply(..., axis=1)``.

    For the row's (condition, segment, ref_step_base) the first matching record is read from
    the type-split tables and stored on the row:

    * ``none_hits``          - ``total_hits`` of the 'none' record in ``df1`` (NaN if absent)
    * ``remain_hit_rate``    - ``sc_hit_rate_c`` of the 'none' record in ``df2`` (NaN if absent)
    * ``remain_hit_rate_sl`` - ``sc_hit_rate_c`` of the 'none' record in ``df3`` (NaN if absent)
    * ``neighbor_points``    - ``total_points`` of the 'neighbor' record in ``df2`` (0 if absent)
    * ``neighbor_points_sl`` - ``total_points`` of the 'neighbor' record in ``df3`` (0 if absent)

    The default tables are bound when this ``def`` is executed, so the function must be
    re-defined after those tables are rebuilt.

    NOTE(review): records are matched on condition/segment/ref_step_base only (not story or
    base_seg_num) - confirm those three keys identify a single record per type.
    """
    def first_match(table, type_col, type_value, value_col, fallback):
        # value of `value_col` in the first record matching this row and the requested type
        found = table.loc[(table['condition'] == row['condition']) &
                          (table['segment'] == row['segment']) &
                          (table['ref_step_base'] == row['ref_step_base']) &
                          (table[type_col] == type_value), value_col].values
        return found[0] if len(found) > 0 else fallback

    row['none_hits'] = first_match(df1, 'refer_type_merge', 'none', 'total_hits', np.nan)

    row['remain_hit_rate'] = first_match(df2, boost_type_col1, 'none', 'sc_hit_rate_c', np.nan)
    row['remain_hit_rate_sl'] = first_match(df3, boost_type_col2, 'none', 'sc_hit_rate_c', np.nan)

    # for neighboring rate
    row['neighbor_points'] = first_match(df2, boost_type_col1, 'neighbor', 'total_points', 0)
    row['neighbor_points_sl'] = first_match(df3, boost_type_col2, 'neighbor', 'total_points', 0)

    return row

# ori
# Attach none_hits, remain_hit_rate[_sl] and neighbor_points[_sl] to every row.
seg_seg_pr = seg_seg_pr.apply(add_addi_fields, axis=1)

# Per-row rates derived from the summed counts.
seg_seg_pr['refer_rate'] = seg_seg_pr['refer_points_merge'] / seg_seg_pr['total_points']
seg_seg_pr['referring_rate'] = seg_seg_pr['referring_points'] / seg_seg_pr['total_points']
seg_seg_pr['refer_base_hit_rate'] = seg_seg_pr['total_hits'] / seg_seg_pr['refer_points_base_only']
seg_seg_pr['none_points'] = seg_seg_pr['total_points'] - seg_seg_pr['refer_points_merge']
seg_seg_pr['neighbor_rate'] = seg_seg_pr['neighbor_points'] / seg_seg_pr['none_points']
seg_seg_pr['neighbor_rate_sl'] = seg_seg_pr['neighbor_points_sl'] / seg_seg_pr['none_points']
# Force both neighbor rates to 0 at lag 0. Previously only 'neighbor_rate' was zeroed here,
# while the step-level table below zeroes both, so 'neighbor_rate_sl' was inconsistent.
seg_seg_pr.loc[seg_seg_pr['ref_step_base']==0, ['neighbor_rate', 'neighbor_rate_sl']] = 0
seg_seg_pr['hit_rate_c'] = seg_seg_pr['total_hits'] / seg_seg_pr['total_points']
seg_seg_pr['none_hit_rate'] = seg_seg_pr['none_hits'] / (seg_seg_pr['total_points'] - seg_seg_pr['refer_points_merge'])  # same from seg_seg_type_pr 

# seg_seg_step_pr (average over step)
seg_seg_step_pr = seg_seg_pr.groupby(['cond_direction','condition','ref_step_base','ref_step_base_abs','ref_seg_type','main_or_gap','color']).agg(
    remain_hit_rate=('remain_hit_rate', 'mean'),
    remain_hit_rate_sl=('remain_hit_rate_sl', 'mean'),
    none_hit_rate=('none_hit_rate', 'mean'),
    hit_rate_c=('hit_rate_c', 'mean'),
    none_points=('none_points', 'mean'),
    neighbor_points=('neighbor_points', 'mean'),
    neighbor_points_sl=('neighbor_points_sl', 'mean'),
    ).reset_index()

# Step-level rates are ratios of the averaged counts; the "bonus" columns are differences
# between the averaged hit rates.
seg_seg_step_pr['neighbor_rate'] = seg_seg_step_pr['neighbor_points'] / seg_seg_step_pr['none_points']
seg_seg_step_pr['neighbor_rate_sl'] = seg_seg_step_pr['neighbor_points_sl'] / seg_seg_step_pr['none_points']
seg_seg_step_pr['refer_bonus_hit_rate'] = seg_seg_step_pr['hit_rate_c'] - seg_seg_step_pr['none_hit_rate']
seg_seg_step_pr['neighbor_bonus_hit_rate'] = seg_seg_step_pr['none_hit_rate'] - seg_seg_step_pr['remain_hit_rate']
seg_seg_step_pr['neighbor_bonus_hit_rate_sl'] = seg_seg_step_pr['none_hit_rate'] - seg_seg_step_pr['remain_hit_rate_sl']
seg_seg_step_pr['refer_combo_bonus_hit_rate'] = seg_seg_step_pr['hit_rate_c'] - seg_seg_step_pr['remain_hit_rate']
seg_seg_step_pr['refer_combo_bonus_hit_rate_sl'] = seg_seg_step_pr['hit_rate_c'] - seg_seg_step_pr['remain_hit_rate_sl']
seg_seg_step_pr.loc[seg_seg_step_pr['ref_step_base']==0, 'neighbor_rate'] = 0
seg_seg_step_pr.loc[seg_seg_step_pr['ref_step_base']==0, 'neighbor_rate_sl'] = 0
seg_seg_pr['points'] = seg_seg_pr['refer_points_merge']  # for plotting points in step 0
seg_seg_pr.loc[seg_seg_pr['ref_step_base']==0, 'points'] = seg_seg_pr.loc[seg_seg_pr['ref_step_base']==0, 'total_points']

## figure 11A

In [68]:
bar_size = 7

In [71]:
# referring rate for main segments (seg_seg_pr: mean(referring_rate))

w=400 ; h=180

# Shared base: p0/r0 rows of main segments, lag on an ordinal x axis.
base = alt.Chart(seg_seg_pr.query("condition==['p0','r0'] and main_or_gap=='main'")).encode(
    alt.X('ref_step_base:O', axis=alt.Axis(title='Lag', labelAngle=0))
).properties(width=w, height=h, title={
      "text": [""], 
      "subtitle": ["Referring rate"]
    })  # cannot be interactive

bar = base.mark_bar(size=12).encode(
    alt.Y('mean(referring_rate):Q', axis=alt.Axis(title='Proportion'), 
          scale=alt.Scale(domain=(0, 0.6))
         ), 
    alt.Color('color', scale=None),
)

# CI rule for the same field the bars show. It previously used 'refer_rate', which did not
# match the plotted 'referring_rate'. The layer is built but not added to the figure below.
error_bar = base.mark_rule(size=1.6, color='black', opacity=0.4).encode(
    y='ci0(referring_rate)',
    y2='ci1(referring_rate)',
)

ref_rate_plot = (bar
).configure_axis(
    titleFontSize=12,
    labelFontSize=12,
    titleFontWeight='normal'
).configure_title(
    subtitleFontSize=18,
    anchor='middle',
    offset=0
)
ref_rate_plot

## figure S8A

In [72]:
# 1-1 refer rate (seg_seg_pr: mean(refer_rate))

w, h = 400, 180
# w, h = 600, 180

# Shared base: p0/r0 rows, lag on an ordinal x axis with ticks -12..12.
base = (
    alt.Chart(seg_seg_pr.query("condition==['p0','r0']"))
    .encode(
        x=alt.X('ref_step_base:O',
                axis=alt.Axis(title='Lag', labelAngle=0, values=list(range(-12, 13)))),
    )
    .properties(
        width=w,
        height=h,
        title={"text": [""], "subtitle": ["Reference rate"]},
    )
)  # cannot be interactive

bar = base.mark_bar(size=bar_size).encode(
    y=alt.Y('mean(refer_rate):Q',
            axis=alt.Axis(title='Proportion'),
            scale=alt.Scale(domain=(0, 0.7))),
    color=alt.Color('color', scale=None),
)

# CI rule layer; built but not added to the displayed chart.
error_bar = base.mark_rule(size=1.6, color='black', opacity=0.4).encode(
    y='ci0(refer_rate)',
    y2='ci1(refer_rate)',
)

ref_rate_plot = bar
ref_rate_plot

## figure S8B

In [73]:
# Replace every NaN in the step-level table with 0 before plotting.
# NOTE(review): rates that were undefined (presumably from divisions by zero when the table
# was built) are drawn as 0 after this — confirm that is the intended display.
seg_seg_step_pr = seg_seg_step_pr.fillna(0)

In [74]:
# 1-2 refer bonus hit rate (seg_seg_step_pr: refer_bonus_hit_rate)

# One bar per lag for the r0/p0 rows; the only chart in this group made interactive.
refer_bonus_hit_rate_plot = (
    alt.Chart(seg_seg_step_pr.query("condition==['r0','p0']"))
    .mark_bar(size=bar_size)
    .encode(
        x=alt.X('ref_step_base:O',
                axis=alt.Axis(title='Lag', labelAngle=0, values=list(range(-12, 13)))),
        y=alt.Y('refer_bonus_hit_rate:Q',
                axis=alt.Axis(title=['Hit rate']),
                scale=alt.Scale(domain=(-0.01, 0.06))),
        color=alt.Color('color', scale=None),
    )
    .properties(
        width=w,
        height=h,
        title={
            "text": [""],
            "subtitle": ["Hit rate difference:", "all events - unreferenced events"],
        },
    )
    .interactive()
)

refer_bonus_hit_rate_plot

## figure S8C

In [75]:
# Replace every NaN in seg_seg_pr with 0.
# NOTE(review): the charts that follow aggregate with mean(...), so rows whose rate was NaN now
# enter those means as 0. The figure 11A / S8A cells above read seg_seg_pr before this fill, so
# re-running them after this cell can give different values — confirm the intended order.
seg_seg_pr = seg_seg_pr.fillna(0)

In [76]:
# 1-3 none hit rate (seg_seg_pr: mean(none_hit_rate))

# Shared base: p0/r0 rows, lag on an ordinal x axis with ticks -12..12.
base = (
    alt.Chart(seg_seg_pr.query("condition==['p0','r0']"))
    .encode(
        x=alt.X('ref_step_base:O',
                axis=alt.Axis(title='Lag', labelAngle=0, values=list(range(-12, 13)))),
    )
    .properties(
        width=w,
        height=h,
        title={"text": [""], "subtitle": ["Unreferenced events hit rate"]},
    )
)

bar = base.mark_bar(size=bar_size).encode(
    y=alt.Y('mean(none_hit_rate):Q',
            axis=alt.Axis(title='Hit rate'),
            # scale=alt.Scale(domain=(0, 0.15))
            ),
    color=alt.Color('color', scale=None),
)

# CI rule drawn on top of the bars.
error_bar = base.mark_rule(size=1.6, color='black', opacity=0.4).encode(
    y='ci0(none_hit_rate)',
    y2='ci1(none_hit_rate)',
)

none_hit_rate_plot = bar + error_bar
none_hit_rate_plot

## figure S9A

In [77]:
# 2-1 neighbor rate (seg_seg_pr: mean(neighbor_rate))

# Shared base: p0/r0 rows, lag on an ordinal x axis with ticks -12..12.
base = (
    alt.Chart(seg_seg_pr.query("condition==['p0','r0']"))
    .encode(
        x=alt.X('ref_step_base:O',
                axis=alt.Axis(title='Lag', labelAngle=0, values=list(range(-12, 13)))),
    )
    .properties(
        width=w,
        height=h,
        title={"text": [""], "subtitle": ["Adjacent-reference rate"]},
    )
)

bar = base.mark_bar(size=bar_size).encode(
    y=alt.Y('mean(neighbor_rate):Q',
            axis=alt.Axis(title='Proportion'),
            scale=alt.Scale(domain=(0, 1))),
    color=alt.Color('color', scale=None),
)

# Rule layer with only the upper CI bound encoded; it is not added to the displayed chart.
error_bar = base.mark_rule(size=1.6, color='black', opacity=0.4).encode(
    # y='ci0(neighbor_rate)',
    y2='ci1(neighbor_rate)',
)

neighbor_rate_plot = bar
neighbor_rate_plot

## figure S10B

In [78]:
# 2-1* neighbor rate sl (seg_seg_pr: mean(neighbor_rate_sl))

# Shared base: p0/r0 rows, lag on an ordinal x axis with ticks -12..12.
base = (
    alt.Chart(seg_seg_pr.query("condition==['p0','r0']"))
    .encode(
        x=alt.X('ref_step_base:O',
                axis=alt.Axis(title='Lag', labelAngle=0, values=list(range(-12, 13)))),
    )
    .properties(
        width=w,
        height=h,
        title={"text": [""], "subtitle": ["Adjacent-reference (within storyline) rate"]},
    )
)

bar = base.mark_bar(size=bar_size).encode(
    y=alt.Y('mean(neighbor_rate_sl):Q',
            axis=alt.Axis(title='Proportion'),
            scale=alt.Scale(domain=(0, 1))),
    color=alt.Color('color', scale=None),
)

# Rule layer with only the upper CI bound encoded; it is not added to the displayed chart.
error_bar = base.mark_rule(size=1.6, color='black', opacity=0.4).encode(
    # y='ci0(neighbor_rate)',
    y2='ci1(neighbor_rate_sl)',
)

neighbor_rate_sl_plot = bar
neighbor_rate_sl_plot

## figure S9B

In [79]:
# 2-2 neighbor bonus hit rate (seg_seg_step_pr: neighbor_bonus_hit_rate)

# One bar per lag for the r0/p0 rows of the step-level table.
neighbor_bonus_hit_rate_plot = (
    alt.Chart(seg_seg_step_pr.query("condition==['r0','p0']"))
    .mark_bar(size=bar_size)
    .encode(
        x=alt.X('ref_step_base:O',
                axis=alt.Axis(title='Lag', labelAngle=0, values=list(range(-12, 13)))),
        y=alt.Y('neighbor_bonus_hit_rate:Q',
                axis=alt.Axis(title='Hit rate'),
                scale=alt.Scale(domain=(-0.10, 0.10))),
        color=alt.Color('color', scale=None),
    )
    .properties(
        width=w,
        height=h,
        title={
            "text": [""],
            "subtitle": ["Hit rate difference:", "Unreferenced events - remaining events"],
        },
    )
)

neighbor_bonus_hit_rate_plot

## figure S10C

In [80]:
# 2-2* neighbor bonus hit rate sl (seg_seg_step_pr: neighbor_bonus_hit_rate_sl)

# One bar per lag for the r0/p0 rows of the step-level table (within-storyline variant).
neighbor_bonus_hit_rate_sl_plot = (
    alt.Chart(seg_seg_step_pr.query("condition==['r0','p0']"))
    .mark_bar(size=bar_size)
    .encode(
        x=alt.X('ref_step_base:O',
                axis=alt.Axis(title='Lag', labelAngle=0, values=list(range(-12, 13)))),
        y=alt.Y('neighbor_bonus_hit_rate_sl:Q',
                axis=alt.Axis(title='Hit rate'),
                scale=alt.Scale(domain=(-0.13, 0.13))),
        color=alt.Color('color', scale=None),
    )
    .properties(
        width=w,
        height=h,
        title={
            "text": [""],
            "subtitle": ["Hit rate difference:", "Unreferenced events - remaining events"],
        },
    )
)

neighbor_bonus_hit_rate_sl_plot

## figure S9C

In [81]:
# 2-3 remaining events hit rate (seg_seg_pr: mean(remain_hit_rate))

# Shared base: p0/r0 rows, lag on an ordinal x axis with ticks -12..12.
base = (
    alt.Chart(seg_seg_pr.query("condition==['p0','r0']"))
    .encode(
        x=alt.X('ref_step_base:O',
                axis=alt.Axis(title='Lag', labelAngle=0, values=list(range(-12, 13)))),
    )
    .properties(
        width=w,
        height=h,
        title={"text": [""], "subtitle": ["Remaining events hit rate"]},
    )
)

bar = base.mark_bar(size=bar_size).encode(
    y=alt.Y('mean(remain_hit_rate):Q',
            axis=alt.Axis(title='Hit rate'),
            scale=alt.Scale(domain=(0, 0.1))),
    color=alt.Color('color', scale=None),
)

# CI rule drawn on top of the bars.
error_bar = base.mark_rule(size=1.6, color='black', opacity=0.4).encode(
    y='ci0(remain_hit_rate)',
    y2='ci1(remain_hit_rate)',
)

remain_hit_rate_plot = bar + error_bar
remain_hit_rate_plot

## figure S10D

In [82]:
# 2-3* remaining events hit rate sl (seg_seg_pr: mean(remain_hit_rate_sl))

# Shared base: p0/r0 rows, lag on an ordinal x axis with ticks -12..12.
base = (
    alt.Chart(seg_seg_pr.query("condition==['p0','r0']"))
    .encode(
        x=alt.X('ref_step_base:O',
                axis=alt.Axis(title='Lag', labelAngle=0, values=list(range(-12, 13)))),
    )
    .properties(
        width=w,
        height=h,
        title={"text": [""], "subtitle": ["Remaining events hit rate"]},
    )
)

bar = base.mark_bar(size=bar_size).encode(
    y=alt.Y('mean(remain_hit_rate_sl):Q',
            axis=alt.Axis(title='Hit rate'),
            scale=alt.Scale(domain=(0, 0.1))),
    color=alt.Color('color', scale=None),
)

# CI rule drawn on top of the bars.
error_bar = base.mark_rule(size=1.6, color='black', opacity=0.4).encode(
    y='ci0(remain_hit_rate_sl)',
    y2='ci1(remain_hit_rate_sl)',
)

remain_hit_rate_sl_plot = bar + error_bar
remain_hit_rate_sl_plot

In [83]:
# Combine the reference-rate, hit-rate-difference and unreferenced-hit-rate charts with `&`
# and apply shared axis / title / spacing configuration to the combined chart.
# NOTE(review): `ref_rate_plot` is assigned by two earlier cells (figure 11A and figure S8A);
# this cell uses whichever ran last — presumably the S8A version is intended; confirm.
hit_rate_concat_plot = (ref_rate_plot & refer_bonus_hit_rate_plot & none_hit_rate_plot
).configure_axis(
    titleFontSize=16,
    labelFontSize=12,
    titleFontWeight='normal'
).configure_title(
    subtitleFontSize=18,
    anchor='middle',
    offset=0
).configure_concat(
    spacing=23
)

# hit_rate_concat_plot.save('../figs/core/hit_rate_concat_plot.svg')
hit_rate_concat_plot

In [84]:
# Combine the neighbor-rate, neighbor-bonus and remaining-hit-rate charts with `&` and apply
# the same axis / title / spacing configuration as the previous combined figure.
# NOTE(review): this rebinds `hit_rate_concat_plot`, replacing the chart built in the previous
# cell, and the commented-out save below targets the same file name — confirm that is intended.
hit_rate_concat_plot = (neighbor_rate_plot & neighbor_bonus_hit_rate_plot & remain_hit_rate_plot
).configure_axis(
    titleFontSize=16,
    labelFontSize=12,
    titleFontWeight='normal'
).configure_title(
    subtitleFontSize=18,
    anchor='middle',
    offset=0
).configure_concat(
    spacing=23
)

# hit_rate_concat_plot.save('../figs/core/hit_rate_concat_plot.svg')
hit_rate_concat_plot

In [85]:
# Combine the three within-storyline ("sl") charts with `&` and apply the shared
# axis / title / spacing configuration.
hit_rate_concat_sl_plot = (neighbor_rate_sl_plot & neighbor_bonus_hit_rate_sl_plot & remain_hit_rate_sl_plot
).configure_axis(
    titleFontSize=16,
    labelFontSize=12,
    titleFontWeight='normal'
).configure_title(
    subtitleFontSize=18,
    anchor='middle',
    offset=0
).configure_concat(
    spacing=23
)

# NOTE(review): the commented-out save refers to `hit_rate_concat_plot`, not the
# `hit_rate_concat_sl_plot` built in this cell — looks like a copy-paste leftover.
# hit_rate_concat_plot.save('../figs/core/hit_rate_concat_plot.svg')
hit_rate_concat_sl_plot

## step

In [86]:
# average across subjects
# Mean ref_sc_hit per combination of the keys below; rows with NaN keys are kept (dropna=False).
seg_sc_p0r0 = (
    df_seg_sc_p0r0_bounded
    .groupby(
        [
            'cond_direction', 'cond_amount', 'condition', 'story', 'segment',
            'segment_count', 'segment_num', 'base_seg_num', 'refer_type_base_only',
            'ref_sc_num', 'ref_step_base', 'ref_step_base_abs', 'ref_seg_type', 'sc_point',
        ],
        dropna=False,
    )
    .agg(ref_sc_hit=('ref_sc_hit', 'mean'))
    .reset_index()
)



In [87]:
# step <- seg_sc_p0r0
# Per condition / lag / reference type: row count, summed sc_point and summed ref_sc_hit,
# plus the corrected hit rate (summed hits / summed points).
# NOTE(review): the saved output of this cell shows only a header row, i.e. the
# cond_amount filter matched nothing — confirm the 'p_recall/r_recall' label.
step = (
    seg_sc_p0r0
    .query("cond_amount=='p_recall/r_recall'")
    .groupby(['condition', 'ref_step_base', 'ref_step_base_abs', 'refer_type_base_only'])
    .agg(
        total_count=('condition', 'count'),
        points_count=('sc_point', 'sum'),
        total_points=('ref_sc_hit', 'sum'),
    )
    .reset_index()
    .assign(sc_hit_rate_c=lambda d: d['total_points'] / d['points_count'])
)

step.head(10)

Unnamed: 0,condition,ref_step_base,ref_step_base_abs,refer_type_base_only,total_count,points_count,total_points,sc_hit_rate_c


In [88]:
# step_type_pr step_type3_pr step_boost_type_pr <- seg_sc_pr_bounded

# step_type_pr
# Per condition / lag / refer_type_merge (lag 0 excluded): row count, summed points, summed hits.
step_type_pr = seg_sc_pr_bounded.query("ref_step_base !=0 ").groupby(['cond_direction','condition','ref_step_base',
       'ref_step_base_abs','ref_seg_type','main_or_gap','refer_type_merge']).agg(
    total_count=('condition', 'count'),
    points_count=('sc_point', 'sum'),
    total_points=('ref_sc_hit', 'sum'),
    ).reset_index()
# corrected hit rate
step_type_pr['sc_hit_rate_c'] = step_type_pr['total_points']/step_type_pr['points_count']
# add missing step
# DataFrame.append was removed in pandas 2.0; pd.concat with a one-row frame is the
# supported equivalent and yields the same appended row.
step_type_pr = pd.concat([step_type_pr, pd.DataFrame([{
    "cond_direction":"f", "condition":"p0", "ref_step_base":11.5, "ref_step_base_abs":11.5, "ref_seg_type":"ahead_gap",
    "main_or_gap":"gap", "refer_type_merge":"none", "total_count":0, "points_count":0, "total_points":0, "sc_hit_rate_c":0,
}])], ignore_index=True)

# step_rr_type_pr
# Per condition / lag / rr_type (lag 0 excluded): row count, summed points, summed hits and
# the corrected hit rate; rr_type becomes a categorical with a fixed category order.
step_rr_type_pr = (
    seg_sc_pr_bounded
    .query("ref_step_base !=0 ")
    .groupby([
        'cond_direction', 'condition', 'ref_step_base',
        'ref_step_base_abs', 'ref_seg_type', 'main_or_gap', 'rr_type',
    ])
    .agg(
        total_count=('condition', 'count'),
        points_count=('sc_point', 'sum'),
        total_points=('ref_sc_hit', 'sum'),
    )
    .reset_index()
    .assign(
        sc_hit_rate_c=lambda d: d['total_points'] / d['points_count'],
        rr_type=lambda d: d['rr_type'].astype("category").cat.set_categories(['referred','referring','none']),
    )
)

# step_boost_type_pr
# Per condition / lag / refer_boost_type (lag 0 excluded): row count, summed points, summed hits.
step_boost_type_pr = seg_sc_pr_bounded.query("ref_step_base !=0 ").groupby(['cond_direction','condition','ref_step_base',
       'ref_step_base_abs','ref_seg_type','main_or_gap','refer_boost_type']).agg(
    total_count=('condition', 'count'),
    points_count=('sc_point', 'sum'),
    total_points=('ref_sc_hit', 'sum'),
    ).reset_index()

# corrected hit rate
step_boost_type_pr['sc_hit_rate_c'] = step_boost_type_pr['total_points']/step_boost_type_pr['points_count']
# Fix the category order used by the sort below. The reordered column is assigned back:
# `reorder_categories(..., inplace=True)` is deprecated (see the warnings echoed under this
# cell) and no longer exists in pandas 2.x.
step_boost_type_pr['refer_boost_type'] = (
    step_boost_type_pr['refer_boost_type']
    .astype('category')
    .cat.reorder_categories(['referred','same_seg','neighbor','none'])
)
step_boost_type_pr = step_boost_type_pr.sort_values(by=['condition', 'ref_step_base_abs', 'refer_boost_type'])

# step_boost_type_tri_pr
# Per condition / lag / refer_boost_type_tri (lag 0 excluded): row count, summed points, summed hits.
step_boost_type_tri_pr = seg_sc_pr_bounded.query("ref_step_base !=0 ").groupby(['cond_direction','condition','ref_step_base',
       'ref_step_base_abs','ref_seg_type','main_or_gap','refer_boost_type_tri']).agg(
    total_count=('condition', 'count'),
    points_count=('sc_point', 'sum'),
    total_points=('ref_sc_hit', 'sum'),
    ).reset_index()

# corrected hit rate
step_boost_type_tri_pr['sc_hit_rate_c'] = step_boost_type_tri_pr['total_points']/step_boost_type_tri_pr['points_count']
# Add the missing step with an all-zero row. DataFrame.append was removed in pandas 2.0,
# so the row is added with pd.concat.
step_boost_type_tri_pr = pd.concat([step_boost_type_tri_pr, pd.DataFrame([{
    "cond_direction":"f", "condition":"p0", "ref_step_base":11.5, "ref_step_base_abs":11.5,
    "ref_seg_type":"ahead_gap", "main_or_gap":"gap", "refer_boost_type_tri":"none",
    "total_count":0, "points_count":0, "total_points":0, "sc_hit_rate_c":0,
}])], ignore_index=True)
# Fix the category order used by the sort below; assigned back because the
# `inplace=True` form of reorder_categories is deprecated and gone in pandas 2.x.
step_boost_type_tri_pr['refer_boost_type_tri'] = (
    step_boost_type_tri_pr['refer_boost_type_tri']
    .astype('category')
    .cat.reorder_categories(['referred','neighbor','none'])
)
step_boost_type_tri_pr = step_boost_type_tri_pr.sort_values(by=['condition', 'ref_step_base_abs', 'refer_boost_type_tri'])

# step_boost_type_sl_pr
# Per condition / lag / refer_boost_type_sl (lag 0 excluded): row count, summed points, summed hits.
step_boost_type_sl_pr = seg_sc_pr_bounded.query("ref_step_base !=0 ").groupby(['cond_direction','condition','ref_step_base',
       'ref_step_base_abs','ref_seg_type','main_or_gap','refer_boost_type_sl']).agg(
    total_count=('condition', 'count'),
    points_count=('sc_point', 'sum'),
    total_points=('ref_sc_hit', 'sum'),
    ).reset_index()

# corrected hit rate
step_boost_type_sl_pr['sc_hit_rate_c'] = step_boost_type_sl_pr['total_points']/step_boost_type_sl_pr['points_count']
# Add the missing step with an all-zero row. DataFrame.append was removed in pandas 2.0,
# so the row is added with pd.concat.
step_boost_type_sl_pr = pd.concat([step_boost_type_sl_pr, pd.DataFrame([{
    "cond_direction":"f", "condition":"p0", "ref_step_base":11.5, "ref_step_base_abs":11.5,
    "ref_seg_type":"ahead_gap", "main_or_gap":"gap", "refer_boost_type_sl":"none",
    "total_count":0, "points_count":0, "total_points":0, "sc_hit_rate_c":0,
}])], ignore_index=True)
# Fix the category order used by the sort below; assigned back because the
# `inplace=True` form of reorder_categories is deprecated and gone in pandas 2.x.
step_boost_type_sl_pr['refer_boost_type_sl'] = (
    step_boost_type_sl_pr['refer_boost_type_sl']
    .astype('category')
    .cat.reorder_categories(['referred','neighbor','none'])
)
step_boost_type_sl_pr = step_boost_type_sl_pr.sort_values(by=['condition', 'ref_step_base_abs', 'refer_boost_type_sl'])

  step_boost_type_pr['refer_boost_type'].cat.reorder_categories(['referred','same_seg','neighbor','none'], inplace=True)
  step_boost_type_tri_pr['refer_boost_type_tri'].cat.reorder_categories(['referred','neighbor','none'], inplace=True)
  step_boost_type_sl_pr['refer_boost_type_sl'].cat.reorder_categories(['referred','neighbor','none'], inplace=True)


In [89]:
# add more fields for plotting

def add_cols_for_plotting(df, type_col, value_col, count_col, doubled=True):
    """Return a copy of ``df`` extended with rectangle coordinates for plotting.

    Rows with ``ref_step_base == 0`` are dropped. The rest are sorted by condition, step
    and ``type_col``, and each row gets:

    * ``x0`` / ``x1``     - running-total interval of ``count_col`` over the sorted rows
    * ``x0_c`` / ``x1_c`` - the same interval shifted so that the first row of each
      (condition, ref_step_base) starts at 0
    * ``y0`` / ``y1``     - 0 and ``value_col``

    With ``doubled=True`` every row is emitted twice: once tagged ``<type>_hit`` spanning
    0..value and once tagged ``<type>_miss`` spanning value..1, in ``f'{type_col}_plus'``.

    Parameters
    ----------
    df : DataFrame with 'condition', 'ref_step_base', 'ref_step_base_abs' and the three
        columns named by the other arguments.
    type_col, value_col, count_col : column names.
    doubled : whether to append the ``_miss`` copies.
    """
    # The two referred/none typings are sorted descending; every other typing ascending.
    type_ascending = type_col not in ('refer_type_merge', 'refer_type_base_only')
    step_col = 'ref_step_base' if type_col == 'refer_type_base_only' else 'ref_step_base_abs'

    # discard step=0, then put the rows in plotting order
    ordered = df.query("ref_step_base != 0").sort_values(
        by=['condition', step_col, type_col],
        ascending=[True, True, type_ascending],
    )

    ordered['x1'] = ordered[count_col].cumsum()
    ordered['x0'] = ordered['x1'] - ordered[count_col]

    # aligned x coordinates start as plain copies and are shifted row by row below
    ordered['x0_c'] = ordered['x0']
    ordered['x1_c'] = ordered['x1']

    def shift_to_step_origin(row):
        # x0 of the first sorted row sharing this row's step and condition
        same_step = ((ordered['ref_step_base'] == row['ref_step_base'])
                     & (ordered['condition'] == row['condition']))
        origin = ordered.loc[same_step, 'x0'].values[0]
        row['x0_c'] = row['x0'] - origin
        row['x1_c'] = row['x1'] - origin
        return row

    # Row-wise apply is kept deliberately: the frame it rebuilds is what the string
    # concatenation below operates on.
    plotted = ordered.apply(shift_to_step_origin, axis=1)
    plotted['y0'] = 0
    plotted['y1'] = plotted[value_col]

    if not doubled:
        return plotted

    miss_half = plotted.copy()
    plus_col = f'{type_col}_plus'
    plotted[plus_col] = plotted[type_col] + '_hit'
    miss_half[plus_col] = miss_half[type_col] + '_miss'

    # the miss rectangle sits on top of the hit rectangle
    miss_half['y0'] = miss_half[value_col]
    miss_half['y1'] = 1

    return pd.concat([plotted, miss_half], ignore_index=True)

In [90]:
# Build the plotting tables. The *_doubled variants carry hit + miss rectangles; the plain
# variants (doubled=False) replace the source table with its single-rectangle version.
# Each *_doubled table is built before its source is rebound.
step_type_pr_doubled = add_cols_for_plotting(
    step_type_pr, type_col='refer_type_merge', value_col='sc_hit_rate_c', count_col='points_count')
step_type_pr = add_cols_for_plotting(
    step_type_pr, type_col='refer_type_merge', value_col='sc_hit_rate_c', count_col='points_count',
    doubled=False)

step_rr_type_pr_doubled = add_cols_for_plotting(
    step_rr_type_pr, type_col='rr_type', value_col='sc_hit_rate_c', count_col='points_count')
step_rr_type_pr = add_cols_for_plotting(
    step_rr_type_pr, type_col='rr_type', value_col='sc_hit_rate_c', count_col='points_count',
    doubled=False)

step_boost_type_pr_doubled = add_cols_for_plotting(
    step_boost_type_pr, type_col='refer_boost_type', value_col='sc_hit_rate_c', count_col='points_count')
step_boost_type_pr = add_cols_for_plotting(
    step_boost_type_pr, type_col='refer_boost_type', value_col='sc_hit_rate_c', count_col='points_count',
    doubled=False)

step_boost_type_tri_pr_doubled = add_cols_for_plotting(
    step_boost_type_tri_pr, type_col='refer_boost_type_tri', value_col='sc_hit_rate_c',
    count_col='points_count')

step_boost_type_sl_pr_doubled = add_cols_for_plotting(
    step_boost_type_sl_pr, type_col='refer_boost_type_sl', value_col='sc_hit_rate_c',
    count_col='points_count')

# step_refer_type_base_only_re_doubled = add_cols_for_plotting(step_re, 'refer_type_base_only',
#                                              'sc_hit_rate_c', 'points_count')

In [91]:
# Legend labels and colours for the "<type>_<hit|miss>" keys produced by
# add_cols_for_plotting. Key order is kept as declared.

# refer type
type_long_map = dict(
    referred_miss='Referenced: Miss',
    referred_hit='Referenced: Hit',
    none_miss='Unreferenced: Miss',
    none_hit='Unreferenced: Hit',
)

# alternatives tried for referred_miss / referred_hit:
#   Category20b[20][15] / '#CC0033', Category20b[20][15] / Category20b[20][12]
type_color_map_double = dict(
    referred_miss=Category20c[20][7],
    referred_hit=Category20c[20][4],
    none_miss=Category20c[20][18],
    none_hit=Category20c[20][16],
)

step_type_pr_doubled['refer_type_merge_plus_long'] = step_type_pr_doubled['refer_type_merge_plus'].map(type_long_map)
# step_refer_type_base_only_re_doubled['refer_type_base_only_plus_long'] = \
#   step_refer_type_base_only_re_doubled['refer_type_base_only_plus'].map(type_long_map)

# rr type
rr_type_long_map = dict(
    referred_miss='Referenced: Miss',
    referred_hit='Referenced: Hit',
    referring_miss='Referring: Miss',
    referring_hit='Referring: Hit',
    none_miss='Other: Miss',
    none_hit='Other: Hit',
)

# alternatives tried for referring_miss / referring_hit:
#   Category20c[20][3] / Category20c[20][0], Category20[20][19] / Category20[20][18]
rr_type_color_map_double = dict(
    referred_miss=Category20b[20][15],
    referred_hit='#CC0033',
    referring_miss=Category20c[20][11],
    referring_hit=Category20c[20][8],
    none_miss=Category20c[20][18],
    none_hit=Category20c[20][16],
)

step_rr_type_pr_doubled['rr_type_plus_long'] = step_rr_type_pr_doubled['rr_type_plus'].map(rr_type_long_map)

# refer boost type
boost_type_long_map = dict(
    referred_miss='Referenced: Miss',
    referred_hit='Referenced: Hit',
    same_seg_miss='Same-segment event referenced: Miss',
    same_seg_hit='Same-segment event referenced: Hit',
    neighbor_miss='Neighbor-segment event referenced: Miss',
    neighbor_hit='Neighbor-segment event referenced: Hit',
    none_miss='Other: Miss',
    none_hit='Other: Hit',
)

boost_type_tri_long_map = dict(
    referred_miss='Referenced: Miss',
    referred_hit='Referenced: Hit',
    neighbor_miss='Reference-adjacent: Miss',
    neighbor_hit='Reference-adjacent: Hit',
    none_miss='Remaining: Miss',
    none_hit='Remaining: Hit',
)

boost_type_sl_long_map = dict(
    referred_miss='Referenced: Miss',
    referred_hit='Referenced: Hit',
    neighbor_miss='Reference-adjacent (within storyline): Miss',
    neighbor_hit='Reference-adjacent (within storyline): Hit',
    none_miss='Remaining: Miss',
    none_hit='Remaining: Hit',
)

boost_type_color_map_double = dict(
    referred_miss=Category20b[20][15],
    referred_hit=Category20b[20][12],
    same_seg_miss=Category20c[20][7],
    same_seg_hit=Category20c[20][4],
    neighbor_miss=Category20c[20][3],
    neighbor_hit=Category20c[20][0],
    none_miss=Category20c[20][19],
    none_hit=Category20c[20][16],
)

# alternatives tried: referred_hit=Category20b[20][12];
#   referred_miss=Pastel2[8][1] / referred_hit=Category20c[20][4];
#   neighbor_miss=Pastel2[8][6] / neighbor_hit=Category20[20][2] (or Category20[20][9])
boost_type_tri_color_map_double = dict(
    referred_miss=Category20b[20][15],
    referred_hit='#CC0033',
    neighbor_miss=Purples[9][5],
    neighbor_hit=Category20[20][8],
    none_miss=Category20c[20][19],
    none_hit=Category20c[20][17],
)

boost_type_tri_color_map = dict(
    referred=Category20c[20][4],
    neighbor=Category20c[20][0],
    none=Category20c[20][16],
)

boost_type_tri_color_map_stroke = dict(
    referred=Pastel2[8][1],
    neighbor=Category20c[20][3],
    none=Category20c[20][19],
)

# Attach the long legend label to every doubled boost-type table.
step_boost_type_pr_doubled['refer_boost_type_plus_long'] = step_boost_type_pr_doubled['refer_boost_type_plus'].map(boost_type_long_map)
step_boost_type_tri_pr_doubled['refer_boost_type_tri_plus_long'] = step_boost_type_tri_pr_doubled['refer_boost_type_tri_plus'].map(boost_type_tri_long_map)
step_boost_type_sl_pr_doubled['refer_boost_type_sl_plus_long'] = step_boost_type_sl_pr_doubled['refer_boost_type_sl_plus'].map(boost_type_sl_long_map)

## figure S8D

In [92]:
# function
def area_plot(df, color, scale, value, width=40.1, height=230):
    """Draw hit-rate rectangles per lag for the uncued conditions, u-R above u-P.

    Two faceted ``mark_rect`` charts are built from ``df`` and stacked
    vertically: rows with ``condition == 'r0'`` and ``ref_step_base >= -12``
    (facets sorted descending) on top, and rows with ``condition == 'p0'``
    and ``ref_step_base <= 12`` below.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``condition``, ``ref_step_base``, the rectangle bounds
        ``x0_c``, ``x1_c``, ``y0``, ``y1`` and the column named by ``color``.
    color : str
        Field (Altair shorthand) used for the fill colour.
    scale : dict
        Its values, in insertion order, form the colour-scale domain.
    value : dict
        Its values, in insertion order, form the colour-scale range. They are
        paired with ``scale`` by position, not by key, so both dictionaries
        must list their entries in the same order.
    width : float, optional
        Width of a single lag facet.
    height : float, optional
        Height of a single lag facet. The legend is placed at a fixed
        ``legendY=695``; NOTE(review): that offset was presumably tuned for
        the default height, so confirm the layout when overriding it.

    Returns
    -------
    alt.VConcatChart
        The configured two-panel chart.
    """
    # Two legend rows: one column per pair of entries. Ceiling integer
    # division keeps the column count whole (true division produced a float,
    # which becomes fractional for an odd number of entries).
    legend_cols = (len(scale) + 1) // 2

    def _panel(row_filter, subtitle, x_axis, color_opts, column_opts):
        # Build one faceted panel; only the pieces passed in differ between panels.
        return alt.Chart(df.query(row_filter)).mark_rect().encode(
            x=alt.X('x0_c:Q', axis=x_axis),
            x2='x1_c',
            y=alt.Y('y0:Q', axis=alt.Axis(title='Hit rate')),
            y2='y1',
            color=alt.Color(
                color,
                scale=alt.Scale(domain=list(scale.values()), range=list(value.values())),
                **color_opts
            ),
            column=alt.Column(
                'ref_step_base:N',
                title='Lag',
                header=alt.Header(labelOrient='top', titleFontSize=16,
                                  titleFontWeight='normal', titlePadding=0, labelFontSize=16),
                **column_opts
            ),
        ).properties(
            width=width, height=height, title={
                "text": [""],
                "subtitle": [subtitle]
            })

    # Only the prediction panel specifies the legend explicitly.
    p0 = _panel(
        "condition=='p0' and ref_step_base <= 12",
        'Uncued prediction (u-P)',
        x_axis=alt.Axis(title='', titleFontSize=7),
        color_opts={'legend': alt.Legend(direction='vertical', orient='none',
                                         title="Event type: Response type")},
        column_opts={},
    )

    # Retrodiction lags are faceted in descending order.
    r0 = _panel(
        "condition=='r0' and ref_step_base >= -12",
        'Uncued retrodiction (u-R)',
        x_axis=alt.Axis(title=''),
        color_opts={},
        column_opts={'sort': 'descending'},
    )

    type_hit_plot = alt.vconcat(r0, p0, title=''
    ).configure_axis(
        grid=False,
        titleFontWeight='normal',
        titleFontSize=16,
        labelFontSize=16,
    ).configure_view(
        strokeWidth=0
    ).configure_concat(
        spacing=50
    ).configure_legend(
        legendX=0,
        legendY=695,
        rowPadding=4,
        columnPadding=20,
        titlePadding=10,
        symbolSize=200,
        labelFontSize=16,
        titleFontWeight='normal',
        titleFontSize=20,
        labelLimit=1000,
        titleLimit=1000,
        columns=legend_cols
    ).configure_title(
        subtitleFontSize=20,
        anchor='start',
        offset=0
    )

    return type_hit_plot

# Disabled input variants for this plot:
#   boost 4: area_plot(step_boost_type_pr_doubled, 'refer_boost_type_plus_long', boost_type_long_map, boost_type_color_map_double)
#   boost 3: area_plot(step_boost_type_tri_pr_doubled, 'refer_boost_type_tri_plus_long', boost_type_tri_long_map, boost_type_tri_color_map_double)

# boost 2 -- the variant rendered in this cell
a = area_plot(
    df=step_type_pr_doubled,
    color='refer_type_merge_plus_long',
    scale=type_long_map,
    value=type_color_map_double,
)
# To export the figure: a.save('../figs/core/a1.svg')
a


## figure S9D

In [105]:
# Three-way event split, coloured with the tri miss/hit palette.
area_plot(
    df=step_boost_type_tri_pr_doubled,
    color='refer_boost_type_tri_plus_long',
    scale=boost_type_tri_long_map,
    value=boost_type_tri_color_map_double,
)

## figure S10E

In [106]:
# "sl" event split; reuses the tri miss/hit palette for its colours.
area_plot(
    df=step_boost_type_sl_pr_doubled,
    color='refer_boost_type_sl_plus_long',
    scale=boost_type_sl_long_map,
    value=boost_type_tri_color_map_double,
)

## figure S11B

In [107]:
# Restrict to rows flagged main_or_gap == 'main' and use wider lag facets.
area_plot(
    df=step_rr_type_pr_doubled.query("main_or_gap=='main'"),
    color='rr_type_plus_long',
    scale=rr_type_long_map,
    value=rr_type_color_map_double,
    width=100,
)