# Figure S1 - Data collection timeline

In [None]:
# Imports
import os
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

from RT import fileHandler, tools
from spl_ppr.utils import plotting_defaults

# Applies the plotting defaults helper from spl_ppr, requesting the Arial font
# (presumably configures global matplotlib styling - confirm in spl_ppr.utils)
plotting_defaults(font='Arial')
# Renders matplotlib figures inline in the notebook output
%matplotlib inline

In [None]:
# Directory that the figures produced by this notebook are written to
fig_dir = 'saved_figures'

# Identifiers of the stored analysis result used for this figure
result_label   = 'spelling_paper_signal_analyses'
sub_result_num = 60

# Resolves the path of the HDF5 file holding the analysis result
analysis_result_file_path = fileHandler.getSubResultFilePath(
    sub_dir_key='analysis',
    extension='.h5',
    result_label=result_label,
    sub_result_num=sub_result_num,
)

# Reads the object stored under the 'data' key; the store is opened read-only
# and closed again as soon as the object has been loaded
with pd.HDFStore(analysis_result_file_path, mode='r') as result_store:
    data_timeline = result_store['data']

In [None]:
# Stacked bar chart of the number of trials collected per week, split by data subset.
# Reads `data_timeline` (columns used: weeks_post_implantation, block_type, num_trials)
# and produces `fig` / `ax`, which the save cell below relies on.
colors = sns.color_palette('Set2', 8)

# Parameters
x_ticks       = np.linspace(20, 130, 12)
y_ticks       = np.linspace(0, 1250, 6)
bar_width     = 0.2
# Bar styling per subset; the dictionary order is also the bottom-to-top stacking order
subset_params = {
    'alphabet1_1_mimed'        : dict(color=colors[7], linewidth=0., hatch=None),
    'alphabet1_2_mimed'        : dict(color=colors[1], linewidth=0., hatch=None),
    'alphabet1_2_overt'        : dict(color=colors[2], linewidth=0., hatch=None),
    'copy_typing_optimization' : dict(color=colors[3], linewidth=0., hatch=None),
    'copy_typing_testing'      : dict(color=colors[5], linewidth=0., hatch=None),
    'conversational_testing'   : dict(color=colors[0], linewidth=0., hatch=None),
}
# Human-readable legend text for each subset
legend_label_updates = {
    'alphabet1_1_mimed'        : 'Isolated-target; English letters',
    'alphabet1_2_mimed'        : 'Isolated-target; NATO code words',
    'alphabet1_2_overt'        : 'Isolated-target; NATO code words (overt)',
    'copy_typing_optimization' : 'Real-time spelling; Copy-typing; Optimization',
    'copy_typing_testing'      : 'Real-time spelling; Copy-typing; Evaluation',
    'conversational_testing'   : 'Real-time spelling; Conversational; Evaluation',
}
# Order in which the subsets are listed in the legend (independent of stacking order)
new_label_order   = [
    'alphabet1_2_mimed', 'alphabet1_2_overt', 'alphabet1_1_mimed',
    'copy_typing_optimization', 'copy_typing_testing', 'conversational_testing'
]
tick_label_params = dict(fontdict=dict(fontsize='xx-large'))
axis_label_params = dict(fontdict=dict(fontsize='xx-large'))
grid_params = {
    "color"     : [0.6, 0.6, 0.6],
    "linestyle" : "-",
    "axis"      : "y",
    "linewidth" : 0.4,
    "zorder"    : -3.,
    "clip_on"   : False
}
figure_dpi = 300.0

# Checks that every recorded week falls inside the plotted x-range
x_inds = sorted(set(data_timeline['weeks_post_implantation']))
assert x_ticks[0] <= np.min(x_inds), 'Data starts before the first x tick'
assert np.max(x_inds) <= x_ticks[-1], 'Data extends past the last x tick'

# Maps each week to its position in `x_inds` (constant-time lookup instead of a
# linear `list.index` search for every row of the timeline)
week_positions = {cur_week: cur_pos for cur_pos, cur_week in enumerate(x_inds)}

# Creates the dictionary containing the trial counts for each subset. A wide signed
# integer type is used so that accumulated counts cannot silently wrap around.
all_counts = {cur_subset: np.zeros(len(x_inds), dtype=np.int64) for cur_subset in subset_params}

# Computes the trial counts for each data subset across the duration of recording
for cur_row in data_timeline.itertuples():
    cur_pos = week_positions[cur_row.weeks_post_implantation]
    all_counts[cur_row.block_type][cur_pos] += cur_row.num_trials

# Tracks the total heights of the stacked bars currently on the plot. Starts from
# an explicit zero baseline rather than a copy of one of the count arrays.
cur_stacked_counts = np.zeros(len(x_inds), dtype=np.int64)

fig, ax = plt.subplots(figsize=(15, 8), dpi=figure_dpi)

# Iterates through the subsets and stacks each one on top of the bars drawn so far,
# keeping the returned bar containers for the legend
bar_handles = {}
for cur_subset, cur_bar_params in subset_params.items():
    cur_counts = all_counts[cur_subset]
    bar_handles[cur_subset] = ax.bar(
        x=x_inds, height=cur_counts, width=bar_width, bottom=cur_stacked_counts,
        label=cur_subset, **cur_bar_params
    )
    cur_stacked_counts += cur_counts

# Plot formatting
ax.set(xlim=(x_ticks[0] - 1, x_ticks[-1] + 1), ylim=(y_ticks[0], y_ticks[-1]),
       xticks=x_ticks, yticks=y_ticks)
ax.set_xticklabels([f'{t:0.0f}' for t in x_ticks], **tick_label_params)
ax.set_yticklabels([f'{t:0.0f}' for t in y_ticks], **tick_label_params)
ax.set_xlabel('Weeks post-implantation', labelpad=15.0, **axis_label_params)
ax.set_ylabel('Number of trials collected', labelpad=15.0, **axis_label_params)

ax.grid(**grid_params)

# Plot legend: entries follow `new_label_order` and use the readable labels. A subset
# missing from either dictionary raises a KeyError naming it instead of being dropped.
ordered_handles = [bar_handles[cur_subset] for cur_subset in new_label_order]
ordered_labels  = [legend_label_updates[cur_subset] for cur_subset in new_label_order]
legend = ax.legend(ordered_handles, ordered_labels, loc='upper left', ncol=2, fontsize='large')
legend.get_title().set_fontsize('x-large');

## Save figure

In [None]:
# Creates the output directory if needed so saving also works on a fresh checkout
os.makedirs(fig_dir, exist_ok=True)

# Saves two PDF versions of the figure: one with a transparent background and one
# with the background kept
fig.savefig(os.path.join(fig_dir, 'suppfig_data_collection.pdf'), transparent=True, bbox_inches='tight', dpi=figure_dpi)
fig.savefig(os.path.join(fig_dir, 'suppfig_data_collection_white.pdf'), transparent=False, bbox_inches='tight', dpi=figure_dpi)

## Print collected data amounts

In [None]:
# Displays additional information about the collected data: per block type, the
# number of blocks, trials, recording hours and distinct session dates, plus the
# average recording time per session date across all block types.
task_summary_data = {
    k : {
        'num_blocks'     : 0,
        'num_trials'     : 0,
        'duration_hours' : 0.,
        'session_dates'  : set(),
        'num_dates'      : 0,
    }
    for k in set(data_timeline['block_type'])
}

# Distinct session dates across every block type
all_session_dates = set()

for cur_block_data in data_timeline.itertuples():
    cur_block_type_data = task_summary_data[cur_block_data.block_type]
    cur_block_type_data['num_blocks']     += 1
    cur_block_type_data['num_trials']     += cur_block_data.num_trials
    cur_block_type_data['duration_hours'] += (cur_block_data.duration_sec / 3600.)
    cur_block_type_data['session_dates'].add(cur_block_data.date)
    all_session_dates.add(cur_block_data.date)

# The per-type date counts only need to be computed once all rows have been seen
for cur_block_type_data in task_summary_data.values():
    cur_block_type_data['num_dates'] = len(cur_block_type_data['session_dates'])

# Average time per session. The denominator is the number of distinct dates overall;
# summing the per-type `num_dates` would count a date once for every block type
# recorded on it and therefore underestimate the average.
total_duration_hours = sum(v['duration_hours'] for v in task_summary_data.values())
if all_session_dates:
    average_duration_per_session_minutes = (total_duration_hours / len(all_session_dates)) * 60.
else:
    # No recorded blocks: the average is undefined
    average_duration_per_session_minutes = float('nan')

task_summary_data['average_duration_per_session_minutes'] = average_duration_per_session_minutes

tools.displayDictionaryStructure(
    task_summary_data, display_repr=True, cur_name='Task summary information:', keys_to_skip=['session_dates']
)