In [2]:
%load_ext autoreload
%autoreload 2
import os
import sys
import numpy as np
import pandas as pd
import pingouin as pg
import plotly.graph_objects as go
from os.path import join as pjoin
import plotly.express as px

sys.path.append("../..")
import circletrack_behavior as ctb
import plotting_functions as pf

  return warn(


In [3]:
## Excluded mice ms18, ms22, ms28 due to failing to learn
## Set mouse lists and directory
# Fixed-schedule group: ms01-ms08 plus ms35-ms40 (zero-padded two-digit IDs)
fixed = [f'ms{num:02d}' for num in range(1, 9)] + [f'ms{num}' for num in range(35, 41)]
criteria = ['ms09', 'ms11', 'ms12', 'ms17', 'ms19', 'ms21', 
            'ms26', 'ms27', 'ms30', 'ms31', 'ms33', 'ms34']
cr_5 = ['ms10', 'ms13', 'ms14', 'ms15', 'ms16', 'ms20', 
        'ms23', 'ms24', 'ms25', 'ms29', 'ms32']
# Mice skipped in the probe-day summaries. TODO: check ms23 later
no_probe = ['ms27', 'ms31', 'ms33', 'ms34', 'ms14', 'ms23', 'ms24', 'ms09', 'ms38', 'ms29', 'ms05', 'ms06', 'ms35', 'ms01', 'ms16', 'ms11', 'ms40', 'ms03', 'ms19']
orthogonal = ['ms02', 'ms05', 'ms06', 'ms15', 'ms17', 'ms23', 'ms26', 'ms30', 'ms34', 'ms37']
behav_path = '../../../MultiCon_Behavior/BehaviorCohort1/output/behav/'
fig_path = '../../../MultiCon_Behavior/BehaviorCohort1/intermediate_figures/'
# makedirs also creates missing parent folders and is a no-op if the folder already exists
os.makedirs(fig_path, exist_ok=True)

## Group colors to change plots (keys match the 'group' labels used in the data frames)
group_colors = {'criteria': 'darkorchid', 'cr_5': 'darkgrey', 'fixed': 'lightblue', 'orthog': 'turquoise', 'nonorthog': 'darkgrey'}

In [None]:
## Check whether trial times are different between groups
# Sessions whose trial column is NaN; they are skipped when computing trial lengths
nan_trial_sessions = {'ms35_3.feat', 'ms23_50.feat'}
trial_times_dict = {'mouse': [], 'forward_trials': [], 'reverse_trials': [], 'group': [], 'session': []}
for mouse in fixed + criteria + cr_5:
    mouse_path = pjoin(behav_path, mouse)
    group = 'fixed' if mouse in fixed else 'criteria' if mouse in criteria else 'cr_5'
    # os.listdir order is arbitrary (and lexicographic at best, putting _10 before _2), so sort
    # '<mouse>_<N>.feat' files by N: rows per mouse are then chronological, which later
    # cells rely on when they number days by row position.
    session_files = sorted(os.listdir(mouse_path),
                           key=lambda fname: int(os.path.splitext(fname)[0].rsplit('_', 1)[-1]))
    for session in session_files:
        if session in nan_trial_sessions:
            continue
        behav_data = pd.read_feather(pjoin(mouse_path, session))
        time_diff_forward, time_diff_reverse = ctb.calculate_trial_length(behav_data, forward_reverse=True, recalc_trials=False)
        trial_times_dict['mouse'].append(mouse)
        trial_times_dict['forward_trials'].append(time_diff_forward)
        trial_times_dict['reverse_trials'].append(time_diff_reverse)
        trial_times_dict['group'].append(group)
        trial_times_dict['session'].append(np.unique(behav_data['session'])[0])
## Convert to dataframe
trial_times = pd.DataFrame(trial_times_dict)

## Pool every forward-trial duration of each group into one flat list
# (flattening works for list or array trial containers and avoids quadratic list concatenation)
fixed_trials, criteria_trials, cr5_trials = (
    [duration
     for trials in trial_times.loc[trial_times['group'] == group_name, 'forward_trials']
     for duration in trials]
    for group_name in ('fixed', 'criteria', 'cr_5')
)

In [None]:
## Per-mouse, per-session summary of forward-trial durations
trial_dict_session = {'mouse': [], 'day': [], 'session': [], 'avg_time': [], 'std': [], 'sem': [], 'group': []}
for mouse in np.unique(trial_times['mouse']):
    mouse_trials = trial_times[trial_times['mouse'] == mouse].reset_index()
    mouse_group = np.unique(mouse_trials['group'])[0]
    for idx, trials in enumerate(mouse_trials['forward_trials']):
        trial_dict_session['mouse'].append(mouse)
        # Day = 1-based position of the session among this mouse's rows
        # (chronological only if trial_times rows are in session order)
        trial_dict_session['day'].append(idx + 1)
        trial_dict_session['session'].append(mouse_trials.loc[idx, 'session'])
        trial_dict_session['avg_time'].append(np.mean(trials))
        trial_dict_session['std'].append(np.std(trials))
        # Standard error of the mean: sample SD / sqrt(n) (ddof=1 matches pandas' .sem())
        trial_dict_session['sem'].append(np.std(trials, ddof=1) / np.sqrt(len(trials)))
        trial_dict_session['group'].append(mouse_group)
mouse_trial_session_df = pd.DataFrame(trial_dict_session)

## Group x session mean and SEM of the per-session average trial time
avg_session = mouse_trial_session_df.groupby(['group', 'session'], as_index=False).agg({'avg_time': ['mean', 'sem']})
# Flatten the MultiIndex columns into: group, session, mean, sem
avg_session_df = avg_session['avg_time']
avg_session_df.insert(0, 'group', avg_session['group'])
avg_session_df.insert(1, 'session', avg_session['session'])

In [None]:
## Bar chart of average trial time per context
context_order = ['A', 'B', 'C', 'D', 'AP']
# Group means/SEMs: fixed and criteria drop 'ND'; Criteria+5 keeps only the main contexts.
# Every mask is built from the frame it indexes, so no boolean re-alignment is involved.
fixed_df = avg_session_df[(avg_session_df['group'] == 'fixed') & (avg_session_df['session'] != 'ND')]
criteria_df = avg_session_df[(avg_session_df['group'] == 'criteria') & (avg_session_df['session'] != 'ND')]
cr5_df = avg_session_df[(avg_session_df['group'] == 'cr_5') & avg_session_df['session'].isin(context_order)]
subset_trials = mouse_trial_session_df[mouse_trial_session_df['session'].isin(context_order)]
# One point per mouse per context: mean over that mouse's days in the context
mtt = subset_trials.groupby(['mouse', 'session', 'group'], as_index=False).agg({'avg_time': 'mean'})
# Colors are mapped by group name, not by the order groups first appear in `mtt`
fig = px.strip(mtt, x='session', y='avg_time', color='group', hover_name='mouse',
               color_discrete_map=group_colors).update_traces(showlegend=False, opacity=0.8,
                                                              marker_line_width=1)
# Group mean ± SEM bars
for bar_df, group_key, label in [(fixed_df, 'fixed', 'Fixed'),
                                 (criteria_df, 'criteria', 'Criteria'),
                                 (cr5_df, 'cr_5', 'Criteria+5')]:
    fig.add_trace(go.Bar(x=bar_df['session'], y=bar_df['mean'],
                         error_y=dict(type='data', array=bar_df['sem'], thickness=2.5, width=10),
                         marker_color=group_colors[group_key], marker_line_color='black',
                         marker_line_width=2, name=label, opacity=0.8))
fig.update_layout(height=500, width=600, template='simple_white', legend_title_text='', yaxis_title='Average Trial Time (s)',
                  xaxis_title='', font=dict(size=16))
fig.update_xaxes(categoryorder='array', categoryarray=context_order)
fig.show()
fig.write_image(pjoin(fig_path, 'average_trial_time.png'))

In [None]:
## Mixed ANOVA on per-mouse mean trial time (session = within, group = between)
pg.mixed_anova(data=mtt, dv='avg_time', within='session', subject='mouse', between='group')

In [None]:
## Number of trials
# One row per session with the count of forward trials; rows are ordered by mouse
# (stable sort keeps each mouse's sessions in their trial_times order).
num_trials_df = (
    trial_times
    .assign(num_trials=trial_times['forward_trials'].map(len))
    .sort_values('mouse', kind='stable')
    .loc[:, ['mouse', 'num_trials', 'group', 'session']]
    .reset_index(drop=True)
)
## Group x session mean and SEM of the trial count
result = num_trials_df.groupby(['group', 'session'], as_index=False).agg({'num_trials': ['mean', 'sem']})
# Flatten the MultiIndex columns into: group, session, mean, sem
avg_trial_times = result['num_trials']
avg_trial_times.insert(0, 'group', result['group'])
avg_trial_times.insert(1, 'session', result['session'])

In [None]:
## Plot avg number of trials per mouse per context
context_order = ['A', 'B', 'C', 'D', 'AP']
# Group means/SEMs: fixed and criteria drop 'ND'; Criteria+5 keeps only the main contexts
fixed_df = avg_trial_times[(avg_trial_times['group'] == 'fixed') & (avg_trial_times['session'] != 'ND')]
criteria_df = avg_trial_times[(avg_trial_times['group'] == 'criteria') & (avg_trial_times['session'] != 'ND')]
cr5_df = avg_trial_times[(avg_trial_times['group'] == 'cr_5') & avg_trial_times['session'].isin(context_order)]
# Filter num_trials_df with its own session column (not a mask built from another frame)
subset_trials = num_trials_df[num_trials_df['session'].isin(context_order)]
# One point per mouse per context: mean trial count over that mouse's days in the context
ntrials_df = subset_trials.groupby(['mouse', 'session', 'group'], as_index=False).agg({'num_trials': 'mean'})
# Colors are mapped by group name, not by the order groups first appear in `ntrials_df`
fig = px.strip(ntrials_df, x='session', y='num_trials', color='group', hover_name='mouse',
               color_discrete_map=group_colors).update_traces(showlegend=False, opacity=0.8,
                                                              marker_line_width=1)
# Group mean ± SEM bars
for bar_df, group_key, label in [(fixed_df, 'fixed', 'Fixed'),
                                 (criteria_df, 'criteria', 'Criteria'),
                                 (cr5_df, 'cr_5', 'Criteria+5')]:
    fig.add_trace(go.Bar(x=bar_df['session'], y=bar_df['mean'],
                         error_y=dict(type='data', array=bar_df['sem'], thickness=2.5, width=10),
                         marker_color=group_colors[group_key], marker_line_color='black',
                         marker_line_width=2, name=label, opacity=0.8))
fig.update_layout(height=500, width=600, template='simple_white', legend_title_text='', yaxis_title='Average Number of Trials',
                  xaxis_title='', font=dict(size=16))
fig.update_xaxes(categoryorder='array', categoryarray=context_order)
fig.show()
fig.write_image(pjoin(fig_path, 'average_number_of_trials.png'))

In [None]:
## Mixed ANOVA on per-mouse mean trial count (session = within, group = between)
pg.mixed_anova(data=ntrials_df, dv='num_trials', within='session', subject='mouse', between='group')

In [10]:
## Days to Criteria
# Count how many sessions each criteria / Criteria+5 mouse spent in each of contexts A-D
days_to_records = []
for mouse in criteria + cr_5:
    mouse_path = pjoin(behav_path, mouse)
    group = 'fixed' if mouse in fixed else 'criteria' if mouse in criteria else 'cr_5'
    for session_file in os.listdir(mouse_path):
        behav_data = pd.read_feather(pjoin(mouse_path, session_file))
        days_to_records.append({'mouse': mouse, 'group': group,
                                'session': np.unique(behav_data['session'])[0]})
## Convert to dataframe and keep only the training contexts
days_to_df = pd.DataFrame(days_to_records, columns=['mouse', 'group', 'session'])
days_to_df = days_to_df[days_to_df['session'].isin(['A', 'B', 'C', 'D'])]
days_to_df = days_to_df.assign(session_count=days_to_df['session'])
# Sessions per mouse per context, then group mean / SEM per context
days_to_criteria = days_to_df.groupby(['mouse', 'group', 'session'], as_index=False).agg({'session_count': 'count'})
result = days_to_criteria.groupby(['group', 'session'], as_index=False).agg({'session_count': ['mean', 'sem']})

In [11]:
## Plot days to criteria
fig = pf.custom_graph_template(x_title='', y_title='Days to Criteria')
# Group mean ± SEM of sessions per context, drawn as connected lines
for group in np.unique(result['group']):
    plot_data = result[result['group'] == group]
    counts = plot_data['session_count']
    gname = 'Criteria+5' if group == 'cr_5' else 'Criteria'
    fig.add_trace(go.Scatter(x=plot_data['session'], y=counts['mean'], mode='lines+markers',
                             error_y=dict(type='data', array=counts['sem'], thickness=1.5, width=8),
                             name=gname, legendgroup=group, line_color=group_colors[group]))
# Individual mice as unconnected markers in their group's color
for mouse in criteria + cr_5:
    group = 'criteria' if mouse in criteria else 'cr_5'
    mouse_data = days_to_criteria[days_to_criteria['mouse'] == mouse]
    fig.add_trace(go.Scatter(x=mouse_data['session'], y=mouse_data['session_count'], mode='markers',
                             marker_color=group_colors[group], legendgroup=group, opacity=0.8,
                             showlegend=False, marker_line_width=1, name=mouse))
fig.show()
fig.write_image(pjoin(fig_path, 'days_to_criteria.png'))
## Mixed ANOVA: sessions needed per context (within) by group (between)
pg.mixed_anova(data=days_to_criteria, dv='session_count', within='session', subject='mouse', between='group')

Unnamed: 0,Source,SS,DF1,DF2,MS,F,p-unc,np2,eps
0,group,0.222661,1,21,0.222661,0.017901,0.8948397,0.000852,
1,session,1106.206522,3,63,368.735507,27.550804,1.672721e-11,0.567463,0.941527
2,Interaction,0.86166,3,63,0.28722,0.02146,0.9956918,0.001021,


In [76]:
## Quick look at one raw behavior file (ms16, session file 1)
mouse = 'ms16'
raw_behav_file = pjoin(behav_path, mouse, f'{mouse}_1.feat')
pd.read_feather(raw_behav_file)

Unnamed: 0,unix,frame,t,x,y,a_pos,lick_port,water,animal,session,cohort,trials,lin_position,reward_one,reward_two,probe,maze
0,1.653586e+09,1,-0.384278,334.0,449.0,272.0,5,False,ms16,A,bc1,0.0,4.747296,2,5,True,maze3
1,1.653586e+09,4,0.321171,333.0,450.0,272.0,-1,False,ms16,A,bc1,0.0,4.747296,2,5,False,maze3
2,1.653586e+09,5,0.366048,334.0,450.0,272.0,-1,False,ms16,A,bc1,0.0,4.747296,2,5,False,maze3
3,1.653586e+09,6,0.411928,334.0,450.0,272.0,-1,False,ms16,A,bc1,0.0,4.747296,2,5,False,maze3
4,1.653586e+09,7,0.458798,333.0,450.0,272.0,-1,False,ms16,A,bc1,0.0,4.747296,2,5,False,maze3
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
23299,1.653588e+09,23518,1199.788649,465.0,391.0,314.0,-1,False,ms16,A,bc1,11.0,5.480334,2,5,False,maze3
23300,1.653588e+09,23520,1199.831569,466.0,391.0,315.0,4,False,ms16,A,bc1,11.0,5.497787,2,5,False,maze3
23301,1.653588e+09,23521,1199.889559,466.0,391.0,315.0,-1,False,ms16,A,bc1,11.0,5.497787,2,5,False,maze3
23302,1.653588e+09,23523,1199.925528,466.0,391.0,315.0,-1,False,ms16,A,bc1,11.0,5.497787,2,5,False,maze3


In [88]:
## Probe Accuracy
# One row per session that contains probe frames: lick accuracy during the probe, accuracy on
# the non-probe part of the session, probe duration (s, rounded) and number of probe licks.
lick_dict_probe = {'percent_correct': [], 'mouse': [], 'day': [], 'group': [], 'session': [], 'probe_length': [], 'num_licks': [], 'session_pc': []}
for mouse in fixed + criteria + cr_5:
    mouse_path = pjoin(behav_path, mouse)
    group = 'fixed' if mouse in fixed else 'criteria' if mouse in criteria else 'cr_5'
    # os.listdir order is arbitrary: sort '<mouse>_<N>.feat' files by N so `day` is
    # chronological and the "last day" picked below really is the last session.
    session_files = sorted(os.listdir(mouse_path),
                           key=lambda fname: int(os.path.splitext(fname)[0].rsplit('_', 1)[-1]))
    for idx, session in enumerate(session_files):
        behav = pd.read_feather(pjoin(mouse_path, session))
        if not any(behav['probe']):
            continue
        behav_probe = behav[behav['probe']]
        behav_no_probe = behav[~behav['probe']]
        reward_one, reward_two = np.unique(behav['reward_one'])[0], np.unique(behav['reward_two'])[0]
        # Hard-coded fallback reward ports for ms02 / ms03 files whose stored port is NaN
        if (mouse == 'ms02') and pd.isna(reward_one):
            reward_one, reward_two = 1, 5
        elif (mouse == 'ms03') and pd.isna(reward_two):
            reward_one, reward_two = 1, 6
        ## Percent correct licking: probe segment vs the rest of the session
        pc = ctb.lick_accuracy(behav_probe, port_list=[reward_one, reward_two], lick_threshold=2, by_trials=False)
        session_pc = ctb.lick_accuracy(behav_no_probe, port_list=[reward_one, reward_two], lick_threshold=2, by_trials=False)
        probe_t = behav_probe['t'].to_numpy()
        lick_dict_probe['percent_correct'].append(pc)
        lick_dict_probe['mouse'].append(mouse)
        lick_dict_probe['day'].append(idx + 1)
        lick_dict_probe['group'].append(group)
        lick_dict_probe['session'].append(np.unique(behav['session'])[0])
        lick_dict_probe['probe_length'].append(round(probe_t[-1] - probe_t[0]))
        lick_dict_probe['num_licks'].append(len(behav_probe[behav_probe['lick_port'] != -1]))
        lick_dict_probe['session_pc'].append(session_pc)
## Convert to dataframe
probe_df = pd.DataFrame(lick_dict_probe)

## Last day in each context (A, B, C, D, AP) per mouse, skipping mice in no_probe
context_order = ['A', 'B', 'C', 'D', 'AP']
last_frames = []
for mouse in fixed + criteria + cr_5:
    if mouse in no_probe:
        continue
    mouse_data = probe_df[probe_df['mouse'] == mouse].reset_index(drop=True)
    if mouse in cr_5:
        # Criteria+5 mice: keep only the main contexts before picking days
        mouse_data = mouse_data[mouse_data['session'].isin(context_order)].reset_index(drop=True)
    index_list = ctb.pick_context_day(mouse_data, day_index=-1, contexts=context_order)
    last_frames.append(mouse_data.loc[index_list, :])
# Concatenate once rather than growing a DataFrame inside the loop
last = pd.concat(last_frames)

# Zero-length probes would give inf licks/s; treat them as missing instead
last['licks_sec'] = last['num_licks'] / last['probe_length'].replace(0, np.nan)
avg_last = last.groupby(['session', 'group'], as_index=False).agg({'percent_correct': ['mean', 'sem']})
# Flatten the MultiIndex columns into: group, session, mean, sem
avg_last_df = avg_last['percent_correct']
avg_last_df.insert(0, 'group', avg_last['group'])
avg_last_df.insert(1, 'session', avg_last['session'])

In [107]:
## Probe accuracy on the first and last day in A and the first A-probe (AP, recall) day
first_last_frames = []
for mouse in fixed + criteria + cr_5:
    if mouse in no_probe:
        continue
    mouse_data = probe_df[probe_df['mouse'] == mouse].reset_index(drop=True)
    if mouse in cr_5:
        # Criteria+5 mice: keep only A / AP sessions before picking days
        mouse_data = mouse_data[mouse_data['session'].isin(['A', 'AP'])].reset_index(drop=True)
    ## First day in A and in AP -> day_type 1 ('A1') and 6 ('AR').
    ## NOTE(review): assumes pick_context_day returns exactly these two rows, A first — confirm.
    # .copy() so the column assignments act on an independent frame (no SettingWithCopyWarning)
    first_days = mouse_data.loc[ctb.pick_context_day(mouse_data, day_index=0, contexts=['A', 'AP']), :].copy()
    first_days['day_type'] = [1, 6]
    first_days['session'] = ['A1', 'AR']
    ## Last day in A -> day_type 5 ('A5')
    last_a_day = mouse_data.loc[ctb.pick_context_day(mouse_data, day_index=-1, contexts=['A']), :].copy()
    last_a_day['day_type'] = 5
    last_a_day['session'] = 'A5'
    first_last_frames.extend([first_days, last_a_day])
# Concatenate once rather than growing a DataFrame inside the loop
first_last = pd.concat(first_last_frames)

In [112]:
## Mean ± SEM probe accuracy across mice for A1, A5 and recall, plotted at day_type 1, 5, 6
avg_probe = first_last.groupby(['session', 'day_type'], as_index=False).agg({'percent_correct': ['mean', 'sem']})
accuracy_stats = avg_probe['percent_correct']

fig = pf.custom_graph_template(x_title='', y_title='', titles=[''],
                               shared_x=True, shared_y=True, width=600)
fig.add_trace(go.Scatter(x=avg_probe['day_type'], y=accuracy_stats['mean'], mode='markers',
                         error_y=dict(type='data', array=accuracy_stats['sem']), marker_color='midnightblue'))
# Dashed guides: horizontal at 25% accuracy (presumably chance — confirm) and vertical
# between the last training day (x=5) and recall (x=6)
guide_style = dict(line_width=1, line_dash='dash', line_color='darkgrey', opacity=1)
fig.add_hline(y=25, **guide_style)
fig.add_vline(x=5.5, **guide_style)
fig.update_yaxes(title='Probe Accuracy (%)', col=1, range=[0, 100])
fig.update_xaxes(ticktext=['A1', 'A5', 'R'], tickvals=[1, 5, 6])
fig.show()
fig.write_image(pjoin(fig_path, 'probe_accuracy_with_recall.png'))

In [None]:
## Bar chart of probe accuracy per mouse on the last day within a context
context_order = ['A', 'B', 'C', 'D', 'AP']
# Group means/SEMs: fixed and criteria drop 'ND'; Criteria+5 keeps only the main contexts.
# Every mask is built from the frame it indexes, so no boolean re-alignment is involved.
fixed_df = avg_last_df[(avg_last_df['group'] == 'fixed') & (avg_last_df['session'] != 'ND')]
criteria_df = avg_last_df[(avg_last_df['group'] == 'criteria') & (avg_last_df['session'] != 'ND')]
cr5_df = avg_last_df[(avg_last_df['group'] == 'cr_5') & avg_last_df['session'].isin(context_order)]
subset_trials = last[last['session'].isin(context_order)]
# Colors are mapped by group name, not by the order groups first appear in `subset_trials`
fig = px.strip(subset_trials, x='session', y='percent_correct', color='group', hover_name='mouse',
               color_discrete_map=group_colors).update_traces(showlegend=False, opacity=0.8,
                                                              marker_line_width=1)
# Group mean ± SEM bars
for bar_df, group_key, label in [(fixed_df, 'fixed', 'Fixed'),
                                 (criteria_df, 'criteria', 'Criteria'),
                                 (cr5_df, 'cr_5', 'Criteria+5')]:
    fig.add_trace(go.Bar(x=bar_df['session'], y=bar_df['mean'],
                         error_y=dict(type='data', array=bar_df['sem'], thickness=2.5, width=10),
                         marker_color=group_colors[group_key], marker_line_color='black',
                         marker_line_width=2, name=label, opacity=0.8))
fig.update_layout(height=500, width=600, template='simple_white', legend_title_text='', yaxis_title='Probe Accuracy (%)',
                  xaxis_title='', font=dict(size=16))
fig.update_xaxes(categoryorder='array', categoryarray=context_order)
fig.show()
fig.write_image(pjoin(fig_path, 'probe_accuracy.png'))

In [None]:
## Mixed ANOVA on last-day probe accuracy (session = within, group = between)
pg.mixed_anova(data=subset_trials, dv='percent_correct', within='session', subject='mouse', between='group')

In [None]:
## Number of licks during the probe for each context (last day in each context)
avg_num_licks = last.groupby(['group', 'session'], as_index=False).agg({'num_licks': ['mean', 'sem']})
fixed_df = avg_num_licks[avg_num_licks['group'] == 'fixed']
criteria_df = avg_num_licks[avg_num_licks['group'] == 'criteria']
cr5_df = avg_num_licks[avg_num_licks['group'] == 'cr_5']
# Per-mouse points; colors mapped by group name rather than by order of first appearance in `last`
fig = px.strip(last, x='session', y='num_licks', color='group', hover_name='mouse',
               color_discrete_map=group_colors).update_traces(showlegend=False, opacity=0.8,
                                                              marker_line_width=1)
# Group mean ± SEM bars
for bar_df, group_key, label in [(fixed_df, 'fixed', 'Fixed'),
                                 (criteria_df, 'criteria', 'Criteria'),
                                 (cr5_df, 'cr_5', 'Criteria+5')]:
    fig.add_trace(go.Bar(x=bar_df['session'], y=bar_df['num_licks']['mean'],
                         error_y=dict(type='data', array=bar_df['num_licks']['sem'], thickness=2.5, width=10),
                         marker_color=group_colors[group_key], marker_line_color='black',
                         marker_line_width=2, name=label, opacity=0.8))
fig.update_layout(height=500, width=600, template='simple_white', legend_title_text='', yaxis_title='Number of Licks During Probe',
                  xaxis_title='', font=dict(size=16))
fig.update_xaxes(categoryorder='array', categoryarray=['A', 'B', 'C', 'D', 'AP'])
fig.show()
fig.write_image(pjoin(fig_path, 'num_licks_probe.png'))

In [None]:
## Mixed ANOVA on probe lick counts (session = within, group = between)
pg.mixed_anova(data=last, dv='num_licks', within='session', subject='mouse', between='group')

In [None]:
## Licks/s during the probe for each context on the last day
licks_s = last.groupby(['group', 'session'], as_index=False).agg({'licks_sec': ['mean', 'sem']})
fixed_df = licks_s[licks_s['group'] == 'fixed']
criteria_df = licks_s[licks_s['group'] == 'criteria']
cr5_df = licks_s[licks_s['group'] == 'cr_5']
# Per-mouse points; colors mapped by group name rather than by order of first appearance in `last`
fig = px.strip(last, x='session', y='licks_sec', color='group', hover_name='mouse',
               color_discrete_map=group_colors).update_traces(showlegend=False, opacity=0.8,
                                                              marker_line_width=1)
# Group mean ± SEM bars
for bar_df, group_key, label in [(fixed_df, 'fixed', 'Fixed'),
                                 (criteria_df, 'criteria', 'Criteria'),
                                 (cr5_df, 'cr_5', 'Criteria+5')]:
    fig.add_trace(go.Bar(x=bar_df['session'], y=bar_df['licks_sec']['mean'],
                         error_y=dict(type='data', array=bar_df['licks_sec']['sem'], thickness=2.5, width=10),
                         marker_color=group_colors[group_key], marker_line_color='black',
                         marker_line_width=2, name=label, opacity=0.8))
fig.update_layout(height=500, width=600, template='simple_white', legend_title_text='', yaxis_title='Licks/sec During Probe',
                  xaxis_title='', font=dict(size=16))
fig.update_xaxes(categoryorder='array', categoryarray=['A', 'B', 'C', 'D', 'AP'])
fig.show()
fig.write_image(pjoin(fig_path, 'licks_sec_probe.png'))

In [None]:
## Mixed ANOVA on probe lick rate (session = within, group = between)
pg.mixed_anova(data=last, dv='licks_sec', within='session', subject='mouse', between='group')

In [None]:
## Scatterplot of all days within a context probe accuracy vs probe length
# Drop sessions whose probe has zero length before fitting
probe_df = probe_df[probe_df['probe_length'] != 0]
session_list = ['A', 'B', 'C', 'D', 'AP']
fig = pf.custom_graph_template(x_title='Probe Length (s)', y_title='', rows=1, columns=5, 
                               width=1500, height=500, titles=['A', 'B', 'C', 'D', 'AP'], shared_x=True, shared_y=True, font_size=17)
for col, session in enumerate(session_list, start=1):
    in_session = probe_df['session'] == session
    sub_f = probe_df[in_session & (probe_df['group'] == 'fixed')]
    sub_c = probe_df[in_session & (probe_df['group'] == 'criteria')]
    sub_c5 = probe_df[in_session & (probe_df['group'] == 'cr_5')]

    ## Linear regression of accuracy on probe length, one fit per group
    lm_f, lm_c, lm_c5 = (pg.linear_regression(X=sub['probe_length'], y=sub['percent_correct'],
                                              as_dataframe=False, remove_na=True)
                         for sub in (sub_f, sub_c, sub_c5))
    # Column 1 of each fit's design matrix is probe length (column 0 is the intercept)
    lm_f_x, lm_c_x, lm_c5_x = ([lm['X'][n][1] for n in np.arange(len(lm['X']))]
                               for lm in (lm_f, lm_c, lm_c5))

    panel = [(sub_f, lm_f, lm_f_x, 'fixed', 'Fixed'),
             (sub_c, lm_c, lm_c_x, 'criteria', 'Criteria'),
             (sub_c5, lm_c5, lm_c5_x, 'cr_5', 'Criteria+5')]
    # Raw points first, then fitted lines: the legend toggles below rely on this trace order
    for sub, _, _, key, label in panel:
        fig.add_trace(go.Scatter(x=sub['probe_length'], y=sub['percent_correct'], mode='markers',
                                 showlegend=False, name=label, marker_line_width=1, legendgroup=label,
                                 marker_color=group_colors[key]), row=1, col=col)
    for _, lm, lm_x, key, label in panel:
        fig.add_trace(go.Scatter(x=lm_x, y=lm['pred'], mode='lines', line_color=group_colors[key],
                                 legendgroup=label, showlegend=False), row=1, col=col)
fig.update_yaxes(title='Probe Accuracy (%)', row=1, col=1)
# Show one legend entry per group: the first panel's three scatter traces
for trace in fig.data[:3]:
    trace.showlegend = True
fig.show()
fig.write_image(pjoin(fig_path, 'probe_accuracy_vs_length.png'))

In [None]:
## Probe Accuracy vs Probe Length for last day in each context
session_list = ['A', 'B', 'C', 'D', 'AP']
fig = pf.custom_graph_template(x_title='Probe Length (s)', y_title='', rows=1, columns=5, 
                               width=1500, height=500, titles=['A', 'B', 'C', 'D', 'AP'], shared_x=True, shared_y=True, font_size=17)
for col, session in enumerate(session_list, start=1):
    in_session = last['session'] == session
    sub_f = last[in_session & (last['group'] == 'fixed')]
    sub_c = last[in_session & (last['group'] == 'criteria')]
    sub_c5 = last[in_session & (last['group'] == 'cr_5')]

    ## Linear regression of accuracy on probe length, one fit per group
    lm_f, lm_c, lm_c5 = (pg.linear_regression(X=sub['probe_length'], y=sub['percent_correct'],
                                              as_dataframe=False, remove_na=True)
                         for sub in (sub_f, sub_c, sub_c5))
    # Column 1 of each fit's design matrix is probe length (column 0 is the intercept)
    lm_f_x, lm_c_x, lm_c5_x = ([lm['X'][n][1] for n in np.arange(len(lm['X']))]
                               for lm in (lm_f, lm_c, lm_c5))

    panel = [(sub_f, lm_f, lm_f_x, 'fixed', 'Fixed'),
             (sub_c, lm_c, lm_c_x, 'criteria', 'Criteria'),
             (sub_c5, lm_c5, lm_c5_x, 'cr_5', 'Criteria+5')]
    # Raw points first, then fitted lines: the legend toggles below rely on this trace order
    for sub, _, _, key, label in panel:
        fig.add_trace(go.Scatter(x=sub['probe_length'], y=sub['percent_correct'], mode='markers',
                                 showlegend=False, name=label, marker_line_width=1, legendgroup=label,
                                 marker_color=group_colors[key]), row=1, col=col)
    for _, lm, lm_x, key, label in panel:
        fig.add_trace(go.Scatter(x=lm_x, y=lm['pred'], mode='lines', line_color=group_colors[key],
                                 legendgroup=label, showlegend=False), row=1, col=col)
fig.update_yaxes(title='Probe Accuracy (%)', row=1, col=1)
# Show one legend entry per group: the first panel's three scatter traces
for trace in fig.data[:3]:
    trace.showlegend = True
fig.show()
fig.write_image(pjoin(fig_path, 'probe_accuracy_vs_length_last_day.png'))

In [None]:
## Probe Accuracy vs Session Accuracy for the last day in each context
## One panel per context; each panel shows every mouse as a point plus one OLS fit line per group.
session_list = ['A', 'B', 'C', 'D', 'AP']
fig = pf.custom_graph_template(x_title='Probe Accuracy (%)', y_title='', rows=1, columns=5, 
                               width=1500, height=500, titles=['A', 'B', 'C', 'D', 'AP'], shared_x=True, shared_y=True, font_size=17)
for idx, session in enumerate(session_list):
    in_session = last['session'] == session
    sub_f = last[in_session & (last['group'] == 'fixed')]
    sub_c = last[in_session & (last['group'] == 'criteria')]
    sub_c5 = last[in_session & (last['group'] == 'cr_5')]

    ## Linear regression of session_pc on percent_correct (probe), fitted separately for each group
    lm_f, lm_c, lm_c5 = [pg.linear_regression(X=sub['percent_correct'], y=sub['session_pc'],
                                              as_dataframe=False, remove_na=True)
                         for sub in (sub_f, sub_c, sub_c5)]
    ## Column 1 of each returned design matrix holds the predictor (pingouin puts the intercept in column 0)
    lm_f_x, lm_c_x, lm_c5_x = [[row[1] for row in lm['X']] for lm in (lm_f, lm_c, lm_c5)]

    panel_groups = [(sub_f, lm_f, lm_f_x, 'Fixed', 'fixed'),
                    (sub_c, lm_c, lm_c_x, 'Criteria', 'criteria'),
                    (sub_c5, lm_c5, lm_c5_x, 'Criteria+5', 'cr_5')]
    ## Points first, then fit lines, so figure traces 0-2 are the first panel's group scatters
    for sub, lm, lm_x, label, color_key in panel_groups:
        fig.add_trace(go.Scatter(x=sub['percent_correct'], y=sub['session_pc'], mode='markers',
                                 showlegend=False, name=label, marker_line_width=1, legendgroup=label,
                                 marker_color=group_colors[color_key]), row=1, col=idx+1)
    for sub, lm, lm_x, label, color_key in panel_groups:
        fig.add_trace(go.Scatter(x=lm_x, y=lm['pred'], mode='lines', line_color=group_colors[color_key], 
                                 legendgroup=label, showlegend=False), row=1, col=idx+1)
    # fig.add_annotation(x=73, y=92, text=f"R<sup>2</sup> = {np.round(lm_f['r2'], decimals=3)}", row=1, col=idx+1, showarrow=False)
    # fig.add_annotation(x=73, y=85, text=f"R<sup>2</sup> = {np.round(lm_c['r2'], decimals=3)}", row=1, col=idx+1, showarrow=False)
fig.update_yaxes(title='Session Accuracy (%)', row=1, col=1)
## One legend entry per group, taken from the first panel's scatter traces
for trace_idx in range(3):
    fig['data'][trace_idx]['showlegend'] = True
fig.show()
fig.write_image(pjoin(fig_path, 'session_accuracy_vs_probe_accuracy.png'))

In [3]:
## Number of rewards and percent correct
## Session files are read in chronological order, sorted by the integer day in '<mouse>_<day>.feat'.
## os.listdir order is arbitrary (and plain lexicographic order puts day 10 before day 2).
## Row order matters downstream: days_from is assigned from row order, and
## pick_context_day(day_index=-1) presumably selects by position within each context.
result_dict = {'mouse': [], 'rewards': [], 'percent_correct': [], 'group': [], 'session': []}
for mouse in fixed + criteria + cr_5:
    mouse_path = pjoin(behav_path, mouse)
    group = 'fixed' if mouse in fixed else 'criteria' if mouse in criteria else 'cr_5'
    session_files = sorted(os.listdir(mouse_path),
                           key=lambda fname: int(os.path.splitext(fname)[0].rsplit('_', 1)[-1]))
    for session in session_files:
        behav_data = pd.read_feather(pjoin(mouse_path, session))
        ## Score only the non-probe part of the session
        behav_data = behav_data[~behav_data['probe']]
        reward_one, reward_two = np.unique(behav_data['reward_one'])[0], np.unique(behav_data['reward_two'])[0]
        pc = ctb.lick_accuracy(behav_data, port_one=reward_one, port_two=reward_two, by_trials=False)
        result_dict['mouse'].append(mouse)
        ## Rewards = number of rows flagged as water delivery
        result_dict['rewards'].append(behav_data[behav_data['water']].shape[0])
        result_dict['percent_correct'].append(pc)
        result_dict['group'].append(group)
        result_dict['session'].append(np.unique(behav_data['session'])[0])
rewards = pd.DataFrame(result_dict)

## Keep, for each mouse, the row pick_context_day selects with day_index=-1 in each main context
main_contexts = ['A', 'B', 'C', 'D', 'AP']
last_reward_parts = []
for mouse in fixed + criteria + cr_5:
    mouse_data = rewards[rewards['mouse'] == mouse].reset_index(drop=True)
    mouse_data = mouse_data[mouse_data['session'].isin(main_contexts)].reset_index(drop=True)
    index_list = ctb.pick_context_day(mouse_data, day_index=-1, contexts=['A', 'B', 'C', 'D', 'AP'])
    last_reward_parts.append(mouse_data.loc[index_list, :])
last_rewards = pd.concat(last_reward_parts)

  percent_correct = ((count_licks['first_licks'][(count_licks['lick_port'] == port_one) | (count_licks['lick_port'] == port_two)].dropna().sum()) /
  percent_correct = ((count_licks['first_licks'][(count_licks['lick_port'] == port_one) | (count_licks['lick_port'] == port_two)].dropna().sum()) /
  percent_correct = ((count_licks['first_licks'][(count_licks['lick_port'] == port_one) | (count_licks['lick_port'] == port_two)].dropna().sum()) /
  percent_correct = ((count_licks['first_licks'][(count_licks['lick_port'] == port_one) | (count_licks['lick_port'] == port_two)].dropna().sum()) /
  percent_correct = ((count_licks['first_licks'][(count_licks['lick_port'] == port_one) | (count_licks['lick_port'] == port_two)].dropna().sum()) /
  percent_correct = ((count_licks['first_licks'][(count_licks['lick_port'] == port_one) | (count_licks['lick_port'] == port_two)].dropna().sum()) /
  percent_correct = ((count_licks['first_licks'][(count_licks['lick_port'] == port_one) | (count_licks['lick_por

In [None]:
## Plot rewards by session on last day
## Bars are group mean ± SEM of total rewards on the last day in each context; points are individual mice.
reward_df = last_rewards.groupby(['group', 'session'], as_index=False).agg({'rewards': ['mean', 'sem']})
pc_df = last_rewards.groupby(['group', 'session'], as_index=False).agg({'percent_correct': ['mean', 'sem']})
rewards_fixed = reward_df[reward_df['group'] == 'fixed']
rewards_criteria = reward_df[reward_df['group'] == 'criteria']
rewards_cr5 = reward_df[reward_df['group'] == 'cr_5']
## Restrict to the main contexts using this frame's own session column.
## The previous version masked with `mouse_data`, a leftover loop variable from another cell,
## whose index does not align with rewards_cr5.
rewards_cr5 = rewards_cr5[rewards_cr5['session'].isin(['A', 'B', 'C', 'D', 'AP'])]
fig = px.strip(last_rewards, x='session', y='rewards', color='group', hover_name='mouse',
               color_discrete_sequence=[group_colors['fixed'], group_colors['criteria'], group_colors['cr_5']]).update_traces(showlegend=False, opacity=0.8,
                                                                                                                              marker_line_width=1)
fig.add_trace(go.Bar(x=rewards_fixed['session'], y=rewards_fixed['rewards']['mean'],
                     error_y=dict(type='data', array=rewards_fixed['rewards']['sem'], thickness=2.5, width=10),
                     marker_color=group_colors['fixed'], marker_line_color='black', 
                     marker_line_width=2, name='Fixed', opacity=0.8))
fig.add_trace(go.Bar(x=rewards_criteria['session'], y=rewards_criteria['rewards']['mean'],
                     error_y=dict(type='data', array=rewards_criteria['rewards']['sem'], thickness=2.5, width=10),
                     marker_color=group_colors['criteria'], marker_line_color='black', 
                     marker_line_width=2, name='Criteria', opacity=0.8))
fig.add_trace(go.Bar(x=rewards_cr5['session'], y=rewards_cr5['rewards']['mean'],
                     error_y=dict(type='data', array=rewards_cr5['rewards']['sem'], thickness=2.5, width=10),
                     marker_color=group_colors['cr_5'], marker_line_color='black', 
                     marker_line_width=2, name='Criteria+5', opacity=0.8))
fig.update_layout(height=500, width=600, template='simple_white', legend_title_text='', yaxis_title='Total Rewards',
                  xaxis_title='', font=dict(size=16))
fig.show()
fig.write_image(pjoin(fig_path, 'total_rewards.png'))

In [None]:
## Mixed ANOVA on last-day rewards: session (within mouse) x group (between)
pg.mixed_anova(data=last_rewards, dv='rewards', within='session', subject='mouse', between='group')

In [4]:
## Looking at accuracy 4 days before criteria across mice, across sessions
## Each mouse's rows in a context get a countdown 'days_from' (..., 2, 1, 0) following row order,
## so 0 is the context's last row.
days_from_parts = []
for mouse in fixed + criteria + cr_5:
    for session in ['A', 'B', 'C', 'D']:
        subset_data = rewards[(rewards['mouse'] == mouse) & (rewards['session'] == session)]
        countdown = np.arange(0, subset_data.shape[0])[::-1]
        subset_data.insert(0, 'days_from', list(countdown))
        days_from_parts.append(subset_data)
days_from_df = pd.concat(days_from_parts)
avg_days_from = days_from_df.groupby(['group', 'session', 'days_from'], as_index=False).agg({'percent_correct': ['mean', 'sem']})

In [None]:
## Lick accuracy over the final days in each context (0 = last day), all groups, mean ± SEM
days_of_interest = 8
group_names = {'fixed': 'Fixed', 'criteria': 'Criteria', 'cr_5': 'Criteria+5'}
fig = pf.custom_graph_template(x_title='Days from Last', y_title='', rows=1, columns=4, titles=['A', 'B', 'C', 'D'], width=1000,
                               shared_x=True, shared_y=True)
legend_shown = set()
for idx, session in enumerate(['A', 'B', 'C', 'D']):
    plot_data = avg_days_from[(avg_days_from['session'] == session) & (avg_days_from['days_from'] <= days_of_interest)]
    
    for group in np.unique(plot_data['group']):
        group_data = plot_data[plot_data['group'] == group]
        ## Show a legend entry only for the first trace drawn for each group
        show_legend = group not in legend_shown
        legend_shown.add(group)
        fig.add_trace(go.Scatter(x=group_data['days_from'], y=group_data['percent_correct']['mean'], mode='lines+markers',
                                 line_color=group_colors[group], legendgroup=group, showlegend=show_legend,
                                 name=group_names.get(group, 'Criteria+5'),
                                 error_y=dict(type='data', array=group_data['percent_correct']['sem'])), row=1, col=idx+1)
## Reverse the x-axis in every panel; update_layout(xaxis=...) only reached the first subplot's axis
fig.update_xaxes(autorange='reversed')
fig.update_yaxes(title='Lick Accuracy (%)', col=1)
fig.show()
fig.write_image(pjoin(fig_path, 'days_from_last.png'))

In [9]:
## Look at days from last without fixed mice (criteria and Criteria+5 only, mean ± SEM)
days_of_interest = 4
group_names = {'criteria': 'Criteria', 'cr_5': 'Criteria+5'}
fig = pf.custom_graph_template(x_title='Days from Last', y_title='', rows=1, columns=4, titles=['A', 'B', 'C', 'D'], width=1000,
                               shared_x=True, shared_y=True)
legend_shown = set()
for idx, session in enumerate(['A', 'B', 'C', 'D']):
    plot_data = avg_days_from[(avg_days_from['session'] == session) & (avg_days_from['days_from'] <= days_of_interest)]
    
    for group in np.unique(plot_data['group']):
        if group == 'fixed':
            continue
        group_data = plot_data[plot_data['group'] == group]
        ## Show a legend entry only for the first trace drawn for each group
        show_legend = group not in legend_shown
        legend_shown.add(group)
        fig.add_trace(go.Scatter(x=group_data['days_from'], y=group_data['percent_correct']['mean'], mode='lines+markers',
                                 line_color=group_colors[group], legendgroup=group, showlegend=show_legend,
                                 name=group_names.get(group, 'Criteria+5'),
                                 error_y=dict(type='data', array=group_data['percent_correct']['sem'])), row=1, col=idx+1)
## Reverse the x-axis in every panel; update_layout(xaxis=...) only reached the first subplot's axis
fig.update_xaxes(autorange='reversed')
fig.update_yaxes(title='Lick Accuracy (%)', col=1)
fig.show()
fig.write_image(pjoin(fig_path, 'days_from_last_no_fixed.png'))

In [None]:
## Look at probe accuracy on the first and second day
## For each mouse with probe data, pick_context_day(day_index=1 / 0) selects rows of probe_df per
## context (presumably the second / first day in that context). ms17 only contributes A-C on the second day.
second_day_parts = []
for mouse in fixed + criteria + cr_5:
    if mouse in no_probe:
        continue
    context_list = ['A', 'B', 'C'] if mouse == 'ms17' else ['A', 'B', 'C', 'D']
    mouse_data = probe_df[probe_df['mouse'] == mouse]
    index_list = ctb.pick_context_day(mouse_data, day_index=1, contexts=context_list)
    second_day_parts.append(mouse_data.loc[index_list, :])
second_day = pd.concat(second_day_parts)

first_day_parts = []
context_list = ['A', 'B', 'C', 'D']
for mouse in fixed + criteria + cr_5:
    if mouse in no_probe:
        continue
    mouse_data = probe_df[probe_df['mouse'] == mouse]
    index_list = ctb.pick_context_day(mouse_data, day_index=0, contexts=context_list)
    first_day_parts.append(mouse_data.loc[index_list, :])
first_day = pd.concat(first_day_parts)

## Tag each half, then stack them for the first-vs-second comparison
first_day.insert(0, 'day_type', 'first')
second_day.insert(0, 'day_type', 'second')
combined_df = pd.concat([first_day, second_day])

In [None]:
## Plot first and second day percent correct across different contexts (mean ± SEM per group)
avg_combined = combined_df.groupby(['day_type', 'group', 'session'], as_index=False).agg({'percent_correct': ['mean', 'sem']})
context_list = ['A', 'B', 'C', 'D']
group_names = {'fixed': 'Fixed', 'criteria': 'Criteria'}
fig = pf.custom_graph_template(x_title='', y_title='', width=1200, rows=1, columns=4, 
                               titles=['A', 'B', 'C', 'D'], shared_x=True, shared_y=True)
for idx, session in enumerate(context_list):
    for group in np.unique(avg_combined['group']):
        gname = group_names.get(group, 'Criteria+5')
        plot_data = avg_combined[(avg_combined['session'] == session) & (avg_combined['group'] == group)]
        fig.add_trace(go.Scatter(x=plot_data['day_type'], y=plot_data['percent_correct']['mean'],
                                 error_y=dict(type='data', array=plot_data['percent_correct']['sem'], thickness=1.5, width=8), 
                                 line_color=group_colors[group], name=gname, legendgroup=group, showlegend=False), row=1, col=idx+1)
## 25% reference line drawn once. add_hline without row/col applies to every subplot, so calling it
## inside the loops stacked one duplicate line per (context, group) iteration.
fig.add_hline(y=25, line_width=1, line_dash='dash', line_color='darkgrey', opacity=1)
fig.update_yaxes(title='Probe Accuracy (%)', col=1)
## Traces 0-2 are the first panel's groups (every group is drawn in every panel)
fig['data'][0]['showlegend'] = True
fig['data'][1]['showlegend'] = True
fig['data'][2]['showlegend'] = True
fig.show()
fig.write_image(pjoin(fig_path, 'probe_accuracy_firstsecond.png'))

In [None]:
## Rotated probe analysis
## For each day picked by pick_context_day(day_index=-1) in the first three contexts, the reward
## ports of that day are rotated from its maze into the next day's maze (ctb.rotate_ports). The
## next day's probe licks are then scored against those rotated ports.
result_dict = {'mouse': [], 'comparison': [], 'accuracy': [], 'group': []}
comparison_list = ['A to B', 'B to C', 'C to D']
for mouse in fixed + criteria + cr_5:
    if mouse in no_probe:
        continue
    mouse_path = pjoin(behav_path, mouse)
    group = 'fixed' if mouse in fixed else 'criteria' if mouse in criteria else 'cr_5'

    mouse_data = probe_df[probe_df['mouse'] == mouse]
    ## Criteria+5 mice have their own context labels; fixed and criteria mice share A/B/C
    context_list = ['A5', 'B5', 'C5'] if mouse in cr_5 else ['A', 'B', 'C']
    index_list = ctb.pick_context_day(mouse_data, day_index=-1, contexts=context_list)
    day_list = mouse_data.loc[index_list, 'day']

    for idx, day in enumerate(day_list):
        sess_one = pd.read_feather(pjoin(mouse_path, f'{mouse}_{day}.feat'))
        sess_two = pd.read_feather(pjoin(mouse_path, f'{mouse}_{day+1}.feat'))
        ## Map day-one reward ports onto the day-two maze
        rot_one, rot_two = ctb.rotate_ports(input_maze=np.unique(sess_one['maze'])[0], 
                                            output_maze=np.unique(sess_two['maze'])[0], 
                                            reward_one=np.unique(sess_one['reward_one'])[0], 
                                            reward_two=np.unique(sess_one['reward_two'])[0])
        probe_rows = sess_two[sess_two['probe']]
        lick_acc = ctb.lick_accuracy(probe_rows, port_one=rot_one, port_two=rot_two)
        result_dict['mouse'].append(mouse)
        result_dict['comparison'].append(comparison_list[idx])
        result_dict['accuracy'].append(lick_acc)
        result_dict['group'].append(group)
probe_metric_df = pd.DataFrame(result_dict)
avg_probe = probe_metric_df.groupby(['group', 'comparison'], as_index=False).agg({'accuracy': ['mean', 'sem']})

In [None]:
## Plot rotated probe results
## One panel per group in np.unique order (cr_5, criteria, fixed -> the titles below).
## Each panel has a mean ± SEM bar per comparison, with each mouse overlaid as markers.
fig = pf.custom_graph_template(x_title='', y_title='', titles=['Criteria+5', 'Criteria', 'Fixed'], 
                               font_size=17, rows=1, columns=3, width=900, shared_y=True)
for col_num, group in enumerate(np.unique(probe_metric_df['group']), start=1):
    group_data = probe_metric_df[probe_metric_df['group'] == group]
    avg_group = avg_probe[avg_probe['group'] == group]
    group_color = group_colors[group]
    fig.add_trace(go.Bar(x=avg_group['comparison'], y=avg_group['accuracy']['mean'],
                         error_y=dict(type='data', array=avg_group['accuracy']['sem'], thickness=2.5, width=10),
                         marker_color=group_color, marker_line_color='black', 
                         marker_line_width=2, opacity=0.8, showlegend=False), row=1, col=col_num)
    for mouse, sub_data in group_data.groupby('mouse'):
        fig.add_trace(go.Scatter(x=sub_data['comparison'], y=sub_data['accuracy'],
                                 mode='markers', marker=dict(color=group_color, line=dict(width=1)),
                                 name=mouse, showlegend=False), row=1, col=col_num)
    ## Dashed reference line at 25%
    fig.add_hline(y=25, line_width=1, line_dash='dash', line_color='darkgrey', opacity=1, row=1, col=col_num)
fig.update_yaxes(title='Lick Accuracy (%)', row=1, col=1)
fig.show()
fig.write_image(pjoin(fig_path, 'rotated_probe_accuracy.png'))

In [None]:
## Mixed ANOVA on rotated-probe accuracy: comparison (within mouse) x group (between)
pg.mixed_anova(data=probe_metric_df, dv='accuracy', within='comparison', subject='mouse', between='group')

In [None]:
## Look at accuracy across days for orthogonal vs non-orthogonal mice
## Label every row with its mouse's port layout in one step. This avoids writing strings via
## .loc[:, col] into a NaN (float64) placeholder column, which pandas flags as an
## incompatible-dtype assignment.
df = days_from_df.copy()
df.insert(0, 'port_group', np.where(df['mouse'].isin(orthogonal), 'orthog', 'nonorthog'))
## Rows stay grouped by mouse in the fixed + criteria + cr_5 order inherited from days_from_df
orthog_data = df[df['mouse'].isin(fixed + criteria + cr_5)].reset_index(drop=True)
avg_orthog = orthog_data.groupby(['port_group', 'days_from', 'session'], as_index=False).agg({'percent_correct': ['mean', 'sem']})

In [None]:
## Lick accuracy over the final days in each context, orthogonal vs non-orthogonal port layouts
days_of_interest = 8
fig = pf.custom_graph_template(x_title='Days from Last', y_title='', rows=1, columns=4, titles=['A', 'B', 'C', 'D'], width=1000,
                               shared_x=True, shared_y=True)
legend_shown = set()
for idx, session in enumerate(['A', 'B', 'C', 'D']):
    plot_data = avg_orthog[(avg_orthog['session'] == session) & (avg_orthog['days_from'] <= days_of_interest)]
    
    for group in np.unique(plot_data['port_group']):
        group_data = plot_data[plot_data['port_group'] == group]
        name = 'Orthogonal' if group == 'orthog' else 'Non-Orthogonal'
        ## Show a legend entry only for the first trace drawn for each port group
        show_legend = group not in legend_shown
        legend_shown.add(group)
        fig.add_trace(go.Scatter(x=group_data['days_from'], y=group_data['percent_correct']['mean'], mode='lines+markers',
                                 line_color=group_colors[group], legendgroup=group, showlegend=show_legend, name=name,
                                 error_y=dict(type='data', array=group_data['percent_correct']['sem'])), row=1, col=idx+1)
## Reverse the x-axis in every panel; update_layout(xaxis=...) only reached the first subplot's axis
fig.update_xaxes(autorange='reversed')
fig.update_yaxes(title='Lick Accuracy (%)', col=1)
fig.show()
fig.write_image(pjoin(fig_path, 'days_from_last_orthogvsnon.png'))

In [None]:
## Mixed ANOVA on lick accuracy: session (within mouse) x port_group (between)
pg.mixed_anova(data=orthog_data, dv='percent_correct', within='session', subject='mouse', between='port_group')

In [None]:
## Looking at learning curves based on port distance from cue
## Session files are read in chronological order, sorted by the integer day in '<mouse>_<day>.feat'.
## os.listdir order is arbitrary, and days_from below is assigned from row order.
angle_dict = {'mouse': [], 'group': [], 'port_group': [], 'session': [], 'percent_correct': [], 'min_angle': []}
for mouse in criteria + cr_5:
    mouse_path = pjoin(behav_path, mouse)
    group = 'fixed' if mouse in fixed else 'criteria' if mouse in criteria else 'cr_5'
    port_group = 'orthog' if mouse in orthogonal else 'nonorthog'
    session_files = sorted(os.listdir(mouse_path),
                           key=lambda fname: int(os.path.splitext(fname)[0].rsplit('_', 1)[-1]))
    for session in session_files:
        behav_data = pd.read_feather(pjoin(mouse_path, session))
        behav_data = behav_data[~behav_data['probe']]
        reward_one, reward_two = np.unique(behav_data['reward_one'])[0], np.unique(behav_data['reward_two'])[0]
        ## No reward port recorded at all: nothing to measure
        if pd.isna(reward_one) and pd.isna(reward_two):
            continue
        maze = np.unique(behav_data['maze'])[0]
        pc = ctb.lick_accuracy(behav_data, port_one=reward_one, port_two=reward_two, by_trials=False)
        ## Single-port session: measure the angle of its one port. The previous version reassigned
        ## reward_two here but then left the if/elif without recording the session.
        if pd.isna(reward_two):
            reward_two = reward_one
        min_angle = ctb.relative_port_distance(reward_one, reward_two, maze)
        angle_dict['mouse'].append(mouse)
        angle_dict['group'].append(group)
        angle_dict['port_group'].append(port_group)
        angle_dict['session'].append(np.unique(behav_data['session'])[0])
        angle_dict['percent_correct'].append(pc)
        angle_dict['min_angle'].append(min_angle)
angle_df = pd.DataFrame(angle_dict)

## Countdown to the last row in each context per mouse (0 = last day), as for days_from_df
days_from_parts = []
for mouse in criteria + cr_5:
    for session in ['A', 'B', 'C', 'D']:
        subset_data = angle_df[(angle_df['mouse'] == mouse) & (angle_df['session'] == session)].copy()
        subset_data.insert(0, 'days_from', np.arange(0, subset_data.shape[0])[::-1])
        days_from_parts.append(subset_data)
days_from_new = pd.concat(days_from_parts)
avg_days_group = days_from_new.groupby(['group', 'session', 'days_from', 'min_angle'], as_index=False).agg({'percent_correct': ['mean', 'sem']})
avg_days_orthog = days_from_new.groupby(['port_group', 'session', 'days_from', 'min_angle'], as_index=False).agg({'percent_correct': ['mean', 'sem']})
avg_days = days_from_new.groupby(['session', 'days_from', 'min_angle'], as_index=False).agg({'percent_correct': ['mean', 'sem']})

In [None]:
## Lick accuracy over the final days in each context, one line per reward-port angle (all mice)
days_of_interest = 8
color_list = ['turquoise', 'darkgrey', 'darkorchid', 'green']
fig = pf.custom_graph_template(x_title='Days from Last', y_title='', rows=1, columns=4, 
                               width=1000, titles=['A', 'B', 'C', 'D'], shared_x=True, shared_y=True)
for idx, session in enumerate(['A', 'B', 'C', 'D']):
    for index, group in enumerate(np.unique(avg_days['min_angle'])):
        plot_data = avg_days[(avg_days['min_angle'] == group) & (avg_days['session'] == session) & (avg_days['days_from'] <= days_of_interest)]
        ## Error bars come from the same rows as the plotted means; they previously used the whole avg_days_orthog frame
        fig.add_trace(go.Scatter(x=plot_data['days_from'], y=plot_data['percent_correct']['mean'], mode='lines+markers',
                                 error_y=dict(type='data', array=plot_data['percent_correct']['sem']),
                                 name=f'{str(group)} Degrees', showlegend=(idx == 0), legendgroup=str(group), 
                                 line_color=color_list[index % len(color_list)]), row=1, col=idx+1)
## Reverse the x-axis in every panel; update_layout(xaxis=...) only reached the first subplot's axis
fig.update_xaxes(autorange='reversed')
fig.update_yaxes(title='Lick Accuracy (%)', col=1)
fig.show()
fig.write_image(pjoin(fig_path, 'days_from_last_port_angles_all.png'))

In [None]:
## Same as the all-mice angle plot, restricted to orthogonal-layout mice
## (angles enumerated from avg_days so colours match the all-mice figure)
days_of_interest = 8
color_list = ['turquoise', 'darkgrey', 'darkorchid', 'green']
fig = pf.custom_graph_template(x_title='Days from Last', y_title='', rows=1, columns=4, 
                               width=1000, titles=['A', 'B', 'C', 'D'], shared_x=True, shared_y=True)
sub_df = avg_days_orthog[avg_days_orthog['port_group'] == 'orthog']
for idx, session in enumerate(['A', 'B', 'C', 'D']):
    for index, group in enumerate(np.unique(avg_days['min_angle'])):
        plot_data = sub_df[(sub_df['min_angle'] == group) & (sub_df['session'] == session) & (sub_df['days_from'] <= days_of_interest)]
        ## Error bars come from the same rows as the plotted means; they previously used the whole avg_days_orthog frame
        fig.add_trace(go.Scatter(x=plot_data['days_from'], y=plot_data['percent_correct']['mean'], mode='lines+markers',
                                 error_y=dict(type='data', array=plot_data['percent_correct']['sem']),
                                 name=f'{str(group)} Degrees', showlegend=(idx == 0), legendgroup=str(group), 
                                 line_color=color_list[index % len(color_list)]), row=1, col=idx+1)
## Reverse the x-axis in every panel; update_layout(xaxis=...) only reached the first subplot's axis
fig.update_xaxes(autorange='reversed')
fig.update_yaxes(title='Lick Accuracy (%)', col=1)
fig.show()
fig.write_image(pjoin(fig_path, 'days_from_last_port_angles_orthog.png'))

In [None]:
## Days to Criteria for orthog vs nonorthog
days_from_new.insert(0, 'session_count', days_to_df['session'])
days_to_criteria_orthog = days_from_new.groupby(['mouse', 'port_group', 'session'], as_index=False).agg({'session_count': 'count'})
days_to_criteria_angles = days_from_new.groupby(['mouse', 'min_angle', 'session'], as_index=False).agg({'session_count': 'count'})
result_portgroup = days_to_criteria_orthog.groupby(['port_group', 'session'], as_index=False).agg({'session_count': ['mean', 'sem']})
result_angles = days_to_criteria_angles.groupby(['min_angle', 'session'], as_index=False).agg({'session_count': ['mean', 'sem']})

In [None]:
## Plot days to criteria for orthogonal vs nonorthogonal (group mean ± SEM plus individual mice)
fig = pf.custom_graph_template(x_title='', y_title='Days to Criteria')
for group in np.unique(result_portgroup['port_group']):
    plot_data = result_portgroup[result_portgroup['port_group'] == group]
    gname = 'Non-Orthogonal' if group == 'nonorthog' else 'Orthogonal'
    fig.add_trace(go.Scatter(x=plot_data['session'], y=plot_data['session_count']['mean'], mode='lines+markers',
                             error_y=dict(type='data', array=plot_data['session_count']['sem'], thickness=1.5, width=8),
                             name=gname, legendgroup=group, line_color=group_colors[group]))
for mouse in fixed + criteria + cr_5:
    mouse_data = days_to_criteria_orthog[days_to_criteria_orthog['mouse'] == mouse]
    ## Mice absent from days_to_criteria_orthog (e.g. fixed mice, which the angle analysis
    ## does not include) are skipped rather than drawn as empty traces
    if mouse_data.empty:
        continue
    group = 'orthog' if mouse in orthogonal else 'nonorthog'
    fig.add_trace(go.Scatter(x=mouse_data['session'], y=mouse_data['session_count'], mode='markers',
                             marker_color=group_colors[group], legendgroup=group, opacity=0.8, 
                             showlegend=False, marker_line_width=1, name=mouse))
fig.show()
fig.write_image(pjoin(fig_path, 'days_to_criteria_orthog.png'))

In [None]:
## Mixed ANOVA on days to criteria: session (within mouse) x port_group (between)
pg.mixed_anova(data=days_to_criteria_orthog, dv='session_count', within='session', subject='mouse', between='port_group')

In [None]:
## Remove A and test main effect of session (one-way repeated-measures ANOVA, mouse as subject)
not_context_a = days_to_criteria['session'] != 'A'
anova_data = days_to_criteria[not_context_a]
pg.rm_anova(data=anova_data, dv='session_count', within='session', subject='mouse')

In [None]:
## Plot days to criteria for different port angles from cue (one mean ± SEM line per angle;
## colours come from color_list defined in the angle learning-curve cells)
fig = pf.custom_graph_template(x_title='', y_title='Days to Criteria')
for idx, angle in enumerate(np.unique(result_angles['min_angle'])):
    angle_label = str(angle)
    plot_data = result_angles[result_angles['min_angle'] == angle]
    angle_sem = plot_data['session_count']['sem']
    fig.add_trace(go.Scatter(x=plot_data['session'], y=plot_data['session_count']['mean'], mode='lines+markers',
                             error_y=dict(type='data', array=angle_sem, thickness=1.5, width=8),
                             name=f'{angle_label} Degrees', legendgroup=angle_label, line_color=color_list[idx]))
fig.show()
fig.write_image(pjoin(fig_path, 'days_to_criteria_angles.png'))

In [6]:
## Lick accuracy across trials on day 8 for fixed, to criteria, and criteria+5
## Only the analysed mouse lists are iterated. Walking os.listdir(behav_path) also visited the
## excluded mice (ms18, ms22, ms28), which the group fallback labelled 'cr_5'. The day-8 file is
## opened by exact name, because a substring test for '_8' also matches days 80-89.
bin_size = 4
chance_color = 'darkgrey'
group_cols = {'fixed': 1, 'criteria': 2, 'cr_5': 3}
fig = pf.custom_graph_template(x_title='Trial', y_title='', height=500, width=1000,
                               shared_y=True, shared_x=True, rows=1, columns=3,
                               titles=['Fixed', 'Criteria', 'Criteria+5'])
for mouse in fixed + criteria + cr_5:
    group = 'fixed' if mouse in fixed else 'criteria' if mouse in criteria else 'cr_5'
    session_file = pjoin(behav_path, mouse, f'{mouse}_8.feat')
    if not os.path.exists(session_file):
        continue
    behav_data = pd.read_feather(session_file)
    reward_one, reward_two = np.unique(behav_data['reward_one'])[0], np.unique(behav_data['reward_two'])[0]
    pc = ctb.lick_accuracy(behav_data, reward_one, reward_two, by_trials=True)
    binned_pc = ctb.bin_data(pc, bin_size=bin_size)
    x_data = np.arange(1, len(binned_pc)+1) * bin_size
    fig.add_trace(go.Scatter(x=x_data, y=binned_pc, mode='lines', opacity=0.8,
                             line_color=group_colors[group], showlegend=False, name=mouse), row=1, col=group_cols[group])
fig.add_hline(y=75, line_width=1, line_dash='dash', line_color=chance_color, opacity=1)
fig.add_hline(y=25, line_width=1, line_dash='dash', line_color=chance_color, opacity=1)
fig.update_yaxes(title='Lick Accuracy (%)', col=1)
fig.show()


invalid value encountered in scalar divide


invalid value encountered in scalar divide


invalid value encountered in scalar divide


invalid value encountered in scalar divide


invalid value encountered in scalar divide


invalid value encountered in scalar divide


invalid value encountered in scalar divide


invalid value encountered in scalar divide


invalid value encountered in scalar divide


invalid value encountered in scalar divide


invalid value encountered in scalar divide


invalid value encountered in scalar divide


invalid value encountered in scalar divide


invalid value encountered in scalar divide


invalid value encountered in scalar divide


invalid value encountered in scalar divide


invalid value encountered in scalar divide


invalid value encountered in scalar divide


invalid value encountered in scalar divide


invalid value encountered in scalar divide


invalid value encountered in scalar divide


invalid value encountered in scalar divide


invalid v