In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from os.path import join as pjoin
from natsort import natsorted
import statsmodels.api as sm
from statsmodels.formula.api import ols
from statsmodels.stats.power import TTestIndPower

sys.path.append("../")
import circletrack_behavior as ctb
import plotting_functions as pf

def set_group(mouse, control_mice):
    """Return the context-group label for ``mouse``.

    Mice listed in ``control_mice`` are labeled 'Two-context'; every other mouse is 'Multi-context'.
    """
    is_control = mouse in control_mice
    if is_control:
        return 'Two-context'
    return 'Multi-context'

In [None]:
## Path settings
project_folder = 'MultiCon_Imaging'
## Cohort folders inside dpath that are analyzed; the loading loops skip any other folder found in dpath.
experiment_folders = ['MultiCon_Imaging5', 'MultiCon_Imaging6', 'MultiCon_Imaging7']
## Data root; per-mouse files are read from <dpath>/<experiment>/output/<data type>/<mouse>/.
dpath = f'../../{project_folder}'
## Figure output folder; the intermediate_data/ CSVs are also read from inside it.
fig_path = f'../../../Manuscripts/MultiCon/intermediate_plots'

## Plot settings
chance_color = '#7d7d7d'
avg_color = '#378616'
subject_color = '#7d7d7d'
## Group colors as a list: ce_colors[0] is the Multi-context color and ce_colors[1] the Two-context color in
## ce_colors_dict (presumably matched to alphabetically sorted group names by the plotting helper — confirm).
ce_colors = ['#7A22BC', '#378616']
ce_colors_dict = {'Two-context': '#378616', 'Multi-context': '#7A22BC'}
symbol_dict = {'Two-context': 'x', 'Multi-context': 'circle'}
symbols_list = ['x', 'circle']
## Line color per context; 'A1-5', 'A5-10', 'A10-15' label the three 5-day blocks of context A for Two-context mice.
context_colors = {'A': '#00802d', 'B': '#006c79', 'C': '#004da4', 'D': '#430073',
                  'A1-5': '#00802d', 'A5-10': '#006c79', 'A10-15': '#004da4'}
mouse_colors = ['midnightblue', 'darkred', 'darkorchid', 'darkturquoise']
mf_colors = ['darkorchid', 'midnightblue']
error_color = {'Two-context': 'rgba(55, 134, 22, 0.6)', 'Multi-context': 'rgba(122, 34, 188, 0.6)'}
## Mice not listed here are labeled 'Female'.
male_mice = ['mc44', 'mc46', 'mc54', 'mc55', 'mc57', 'mc64', 'mc65']
## Two-context (control) mice; set_group labels every other mouse 'Multi-context'.
control_mice = ['mc46', 'mc49', 'mc52', 'mc54', 'mc57', 'mc59', 'mc60', 'mc61', 'mc64']
## Mice skipped by the circle-track loading loop. mc48: the reward port did not deliver in B, which changed its behavior.
## NOTE(review): the reason mc57 is excluded is not recorded. mc61 was noted as having fewer than 90 cells but is
## NOT in this list — presumably that exclusion applies only to imaging analyses; confirm.
excluded_mice = ['mc48', 'mc57']
## Session labels of the 20-day schedules: Multi-context = A1-5, B1-5, C1-5, D1-5; Two-context = A1-15, B1-5.
experimental_sessions = [f'A{x}' for x in np.arange(1, 6)] + [f'B{x}' for x in np.arange(1, 6)] + [f'C{x}' for x in np.arange(1, 6)] + [f'D{x}' for x in np.arange(1, 6)]
control_sessions = [f'A{x}' for x in np.arange(1, 16)] + [f'B{x}' for x in np.arange(1, 6)]
day_names = [f'Day {x}' for x in np.arange(1, 21)]
## Passed as lick_threshold to ctb.lick_accuracy (presumably the number of licks scored, cf. "5th lick accuracy"; confirm).
lick_thresh = 5

## Create the figure folder if needed; exist_ok makes this idempotent and avoids the check-then-create race.
os.makedirs(fig_path, exist_ok=True)

### Plot weight percent change across days of the experiment.

In [None]:
## Weight change from baseline: wide CSV ('Mouse' plus one column per day '1'..'27') reshaped to long format.
weight_data = pd.read_csv(pjoin(fig_path, 'intermediate_data/weight_percent_change.csv'))
column_list = [f'{x}' for x in np.arange(1, 28)]
weight_melt = weight_data.melt(id_vars=['Mouse'], value_vars=column_list, var_name='day', value_name='percent_change')
weight_melt['day'] = weight_melt['day'].astype(int)
## Mean +/- SEM across mice per day. groupby sorts the integer days, so x, y and error bars share one row order;
## x is taken straight from the aggregated frame instead of being re-sorted independently of y.
avg_weight = weight_melt.groupby(['day'], as_index=False).agg({'percent_change': ['mean', 'sem']})

fig = pf.custom_graph_template(x_title='Day', y_title='Weight Change from Baseline (%)')
fig.add_trace(go.Scatter(x=avg_weight['day'], y=avg_weight['percent_change']['mean'], mode='lines+markers', line_color=avg_color,
                         error_y=dict(type='data', array=avg_weight['percent_change']['sem']), showlegend=False))

fig.update_yaxes(range=[75, 101])
## Dashed reference lines between days 7 and 8 and at 80% of baseline (their meaning is not documented here — TODO).
fig.add_vline(x=7.5, line_width=1, opacity=1, line_dash='dash', line_color=chance_color)
fig.add_hline(y=80, line_width=1, opacity=1, line_dash='dash', line_color=chance_color)
fig.show()
fig.write_image(pjoin(fig_path, 'weight_from_baseline.png'), height=500, width=500)

### Plot fluorescence intensity across days of the experiment.

In [None]:
## Mean fluorescence per day: group mean +/- SEM plus one thin line per mouse.
from natsort import index_natsorted

fluo_data = pd.read_csv(pjoin(fig_path, 'intermediate_data/fluorescence_data.csv'))
avg_fluo = fluo_data.groupby(['day'], as_index=False).agg({'mean_fluorescence': ['mean', 'sem']})
## Reorder whole rows (not just x) in natural order so x, y and error bars stay aligned even if 'day' is a
## string label that groupby sorted lexicographically.
avg_fluo = avg_fluo.iloc[index_natsorted(avg_fluo['day'])]

fig = pf.custom_graph_template(x_title='Day', y_title='Mean Fluorescence (a.u.)')
fig.add_trace(go.Scatter(x=avg_fluo['day'], y=avg_fluo['mean_fluorescence']['mean'], mode='lines+markers', line_color=avg_color,
                         error_y=dict(type='data', array=avg_fluo['mean_fluorescence']['sem']), showlegend=False))
for mouse in fluo_data['mouse'].unique():
    mdata = fluo_data[fluo_data['mouse'] == mouse]
    ## Same natural day order for each mouse's line so it is not drawn in CSV row order.
    mdata = mdata.iloc[index_natsorted(mdata['day'])]
    fig.add_trace(go.Scatter(x=mdata['day'], y=mdata['mean_fluorescence'], mode='lines', line_color=subject_color, 
                             name=mouse, opacity=0.7, line=dict(width=1), showlegend=False))
fig.show()

### Linear track rewards across days.

In [None]:
## Linear track behavior: total rewards per mouse per session (no mice excluded here).
data_of_interest = 'lin_behav' ## one of behav, aligned_minian, lin_behav
## NOTE: 'group' stores sex and 'group_two' stores the context group (the male/female plot groups by 'group').
lin_dict = {'mouse': [], 'experiment': [], 'group': [], 'group_two': [], 'day': [], 'rewards': []}
for experiment in os.listdir(dpath):
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')

    for mouse in os.listdir(exp_path):
        mpath = pjoin(exp_path, mouse)
        sex = 'Male' if mouse in male_mice else 'Female'
        group = set_group(mouse, control_mice)

        ## os.listdir order is arbitrary; natural-sort the session files so 'day' (1-based position) is
        ## chronological (assumes file names sort in session order — TODO confirm).
        for idx, session in enumerate(natsorted(os.listdir(mpath))):
            behav = pd.read_feather(pjoin(mpath, session))
            lin_dict['mouse'].append(mouse)
            lin_dict['experiment'].append(behav['cohort'].unique()[0])
            lin_dict['group'].append(sex)
            lin_dict['group_two'].append(group)
            lin_dict['day'].append(idx + 1)
            lin_dict['rewards'].append(np.sum(behav['water']))
lin_df = pd.DataFrame(lin_dict)

In [None]:
## Linear-track rewards per day, averaged over all mice (individual mice as markers).
fig = pf.plot_behavior_across_days(
    lin_df,
    x_var='day',
    y_var='rewards',
    groupby_var=['day'],
    plot_transitions=None,
    marker_color=subject_color,
    avg_color=avg_color,
    expert_line=False,
    chance=False,
    x_title='Day',
    y_title='Rewards',
    titles=['Linear Track'],
    height=500,
    width=500,
)
fig.show()
fig.write_image(pjoin(fig_path, 'linear_track_rewards.png'))

In [None]:
## Plot rewards across days on linear track between male and female
## In lin_df the 'group' column holds sex ('Male'/'Female'), so groupby_var=['day', 'group'] splits by sex.
## NOTE(review): marker_color uses the context-group palette (ce_colors), while the other male/female plots use mf_colors.
fig = pf.plot_behavior_across_days(lin_df, x_var='day', y_var='rewards', groupby_var=['day', 'group'], plot_transitions=None,
                                   marker_color=ce_colors, avg_color=avg_color, expert_line=False, chance=False,
                                   x_title='Day', y_title='Rewards', titles=['Linear Track'], height=500, width=500)
fig.show()
fig.write_image(pjoin(fig_path, 'linear_track_rewards_mf.png'))
## Mixed ANOVA on rewards: day is within-subject, sex ('group') is between-subject, mouse is the subject.
## NOTE(review): DataFrame.mixed_anova is not a pandas method. It presumably comes from pingouin's pandas
## extension, which this notebook never imports directly. This line only works if an imported project module
## (e.g. circletrack_behavior or plotting_functions) imports pingouin. Confirm, or import pingouin explicitly.
lin_df.mixed_anova(dv='rewards', within='day', between='group', subject='mouse')

### Circle track lick accuracy and rewards across days.

In [None]:
## Circle track behavior: per-session lick accuracy (probe rows removed) and total rewards for each included mouse.
## lick_thresh is taken from the configuration cell (not re-assigned here), so changing it there takes effect.
data_of_interest = 'behav' ## one of behav, aligned_minian, lin_behav
circ_dict = {'mouse': [], 'experiment': [], 'sex': [], 'group': [], 'day': [], 'session': [], 'rewards': [], 'percent_correct': []}
for experiment in os.listdir(dpath):
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
    for mouse in os.listdir(exp_path):
        if mouse in excluded_mice:
            continue
        mpath = pjoin(exp_path, mouse)
        sex = 'Male' if mouse in male_mice else 'Female'
        group = set_group(mouse, control_mice)
        ## Natural sort so 'day' (1-based position) follows session order.
        for idx, session in enumerate(natsorted(os.listdir(mpath))):
            behav = pd.read_feather(pjoin(mpath, session))
            behav = behav[~behav['probe']] ## exclude probe
            ## Rewarded ports: only the first unique value of each column is used (presumably constant per session).
            reward_one, reward_two = np.unique(behav['reward_one'])[0], np.unique(behav['reward_two'])[0]
            pc = ctb.lick_accuracy(behav, port_list=[reward_one, reward_two], lick_threshold=lick_thresh, by_trials=False)
            circ_dict['mouse'].append(mouse)
            circ_dict['experiment'].append(behav['cohort'].unique()[0])
            circ_dict['sex'].append(sex)
            circ_dict['group'].append(group)
            circ_dict['day'].append(idx + 1)
            circ_dict['session'].append(behav['session'].unique()[0])
            circ_dict['rewards'].append(np.sum(behav['water']))
            circ_dict['percent_correct'].append(pc)
ct_df = pd.DataFrame(circ_dict)

In [None]:
## Lick accuracy across days for Multi-context mice only (group average; individual points hidden).
fig = pf.plot_behavior_across_days(
    ct_df.loc[ct_df['group'].eq('Multi-context')],
    x_var='day', y_var='percent_correct', groupby_var=['day'],
    plot_transitions=[5.5, 10.5, 15.5], transition_color=['darkgrey'] * 3,
    marker_color=subject_color, avg_color=avg_color,
    expert_line=False, chance=True, plot_datapoints=False,
    x_title='Day', y_title='Lick Accuracy (%)', titles=['Circle Track'], height=500, width=500,
)
fig.update_yaxes(range=[0, 101])
fig.show()
# fig.write_image(pjoin(fig_path, 'lick_accuracy_experimental_only.png'))

In [None]:
## Lick accuracy per day (days 1-20), Two-context vs Multi-context.
fig = pf.plot_behavior_across_days(
    ct_df.loc[ct_df['day'].lt(21)],
    x_var='day', y_var='percent_correct', groupby_var=['day', 'group'],
    plot_transitions=[5.5, 10.5, 15.5], transition_color=['darkgrey'] * 3,
    marker_color=ce_colors, avg_color=avg_color, symbols=symbols_list,
    expert_line=False, chance=True, plot_datapoints=False,
    x_title='Day', y_title='Lick Accuracy (%)', titles=[''], height=500, width=600,
)
fig.update_yaxes(range=[0, 101])
fig.show()
fig.write_image(pjoin(fig_path, 'lick_accuracy_two_four_context.png'), width=600, height=500)

## One-way ANOVA (type III SS) on day-16 lick accuracy, group as the only factor.
## Day 16 is the first day of the final context in both schedules (D1 for Multi-context, B1 for Two-context).
data = ct_df.loc[ct_df['day'].eq(16)].reset_index(drop=True)
model = ols('percent_correct ~ C(group)', data=data).fit()
anova_table = sm.stats.anova_lm(model, typ=3)
print(anova_table)

In [None]:
## Power analysis for the day-16 lick-accuracy difference between groups: Cohen's d from the observed means and
## the pooled SD, then the per-group n a two-sided independent t-test needs at the chosen alpha and power.
## The day-16 subset is rebuilt here so this cell does not depend on whichever cell last assigned `data`.
data = ct_df[ct_df['day'] == 16].reset_index(drop=True)
acc_multi = data.loc[data['group'] == 'Multi-context', 'percent_correct']
acc_two = data.loc[data['group'] == 'Two-context', 'percent_correct']
## Observed group sizes (non-missing values) instead of a hard-coded 8 and 8.
n1, n2 = acc_multi.count(), acc_two.count()
s1, s2 = acc_multi.var(ddof=1), acc_two.var(ddof=1)  ## sample variances
s = np.sqrt(((n1 - 1) * s1 + (n2 - 1) * s2) / (n1 + n2 - 2))  ## pooled standard deviation
u1, u2 = acc_multi.mean(), acc_two.mean()
d = (u1 - u2) / s
print(f'Effect size: {d}')
alpha = 0.05
power = 0.8

obj = TTestIndPower()
## Solves for nobs1; ratio=1 means the second group has the same size.
n = obj.solve_power(effect_size=d, alpha=alpha, power=power, ratio=1, alternative='two-sided')
print(f'Sample size: {n}')

In [None]:
## Total rewards per day (days 1-20), Two-context vs Multi-context.
fig = pf.plot_behavior_across_days(
    ct_df.loc[ct_df['day'].lt(21)],
    x_var='day', y_var='rewards', groupby_var=['day', 'group'],
    plot_transitions=[5.5, 10.5, 15.5], transition_color=['darkgrey'] * 3,
    marker_color=ce_colors, avg_color=avg_color, symbols=symbols_list,
    expert_line=False, chance=False, plot_datapoints=False,
    x_title='Day', y_title='Rewards', titles=[''], height=500, width=600,
)
fig.show()
fig.write_image(pjoin(fig_path, 'rewards_two_four_context.png'), height=500, width=600)

## One-way ANOVA (type III SS) on day-16 rewards, group as the only factor.
data = ct_df.loc[ct_df['day'].eq(16)].reset_index(drop=True)
model = ols('rewards ~ C(group)', data=data).fit()
anova_table = sm.stats.anova_lm(model, typ=3)
print(anova_table)

In [None]:
## Lick accuracy for all sessions (including the post-day-20 session(s), presumably the reversal), both groups.
fig = pf.plot_behavior_across_days(
    ct_df,
    x_var='day', y_var='percent_correct', groupby_var=['day', 'group'],
    plot_transitions=[5.5, 10.5, 15.5, 20.5], transition_color=['darkgrey'] * 4,
    marker_color=ce_colors, avg_color=avg_color, symbols=symbols_list,
    expert_line=False, chance=True, plot_datapoints=False,
    x_title='Day', y_title='Lick Accuracy (%)', titles=[''], height=500, width=600,
)
fig.update_yaxes(range=[0, 101])
fig.show()

## One-way ANOVA (type III SS) on day-21 lick accuracy, group as the only factor.
data = ct_df.loc[ct_df['day'].eq(21)].reset_index(drop=True)
model = ols('percent_correct ~ C(group)', data=data).fit()
anova_table = sm.stats.anova_lm(model, typ=3)
print(anova_table)

In [None]:
## Total rewards for all sessions (including the post-day-20 session(s), presumably the reversal), both groups.
fig = pf.plot_behavior_across_days(
    ct_df,
    x_var='day', y_var='rewards', groupby_var=['day', 'group'],
    plot_transitions=[5.5, 10.5, 15.5, 20.5], transition_color=['darkgrey'] * 4,
    marker_color=ce_colors, avg_color=avg_color, symbols=symbols_list,
    expert_line=False, chance=False, plot_datapoints=False,
    x_title='Day', y_title='Rewards', titles=[''], height=500, width=600,
)
fig.show()

## One-way ANOVA (type III SS) on day-21 rewards, group as the only factor.
data = ct_df.loc[ct_df['day'].eq(21)].reset_index(drop=True)
model = ols('rewards ~ C(group)', data=data).fit()
anova_table = sm.stats.anova_lm(model, typ=3)
print(anova_table)

In [None]:
## Two-context mice, days 1-20: lick accuracy split by sex (group mean per day).
fig = pf.plot_behavior_across_days(
    ct_df.loc[ct_df['group'].eq('Two-context') & ct_df['day'].lt(21)],
    x_var='day', y_var='percent_correct', groupby_var=['day', 'sex'],
    plot_transitions=[5.5, 10.5, 15.5], transition_color=['darkgrey'] * 3,
    marker_color=mf_colors, avg_color=avg_color, symbols=symbols_list,
    expert_line=False, chance=True, plot_datapoints=False,
    x_title='Day', y_title='Lick Accuracy (%)', titles=[''], height=500, width=500,
)
fig.update_yaxes(range=[0, 101])
fig.show()
fig.write_image(pjoin(fig_path, 'lick_accuracy_two-context_mf.png'), width=600, height=500)

In [None]:
## Plot lick accuracy between male and female mice in multi-context group (days 1-20).
fig = pf.plot_behavior_across_days(ct_df[(ct_df['group'] == 'Multi-context') & (ct_df['day'] < 21)], x_var='day', y_var='percent_correct', groupby_var=['day', 'sex'],
                                   marker_color=mf_colors, avg_color=avg_color, expert_line=False, chance=True, plot_transitions=[5.5, 10.5, 15.5],
                                   transition_color=['darkgrey', 'darkgrey', 'darkgrey'], symbols=symbols_list,
                                   plot_datapoints=False, x_title='Day', y_title='Lick Accuracy (%)', titles=[''], height=500, width=500)
fig.update_yaxes(range=[0, 101])
fig.show()
fig.write_image(pjoin(fig_path, 'lick_accuracy_multi-context_mf.png'), width=600, height=500)

## One-way ANOVA (type III SS) on day 16 within Multi-context mice, with sex as the factor.
## The plotted metric (percent_correct) is tested first; rewards is also reported for reference.
data = ct_df[(ct_df['day'] == 16) & (ct_df['group'] == 'Multi-context')].reset_index(drop=True)
for dv in ['percent_correct', 'rewards']:
    model = ols(f'{dv} ~ C(sex)', data=data).fit()
    anova_table = sm.stats.anova_lm(model, typ=3)
    print(dv)
    print(anova_table)

In [None]:
## Lick accuracy by day within context for Multi-context mice, overlaying contexts A and D.
exp_df = ct_df[(ct_df['group'] == 'Multi-context') & (ct_df['session'] != 'AR')]
fig = pf.custom_graph_template(x_title='Day in Context', y_title='Lick Accuracy (%)')
for session in ['A', 'D']:
    sdata = exp_df[exp_df['session'] == session].reset_index(drop=True)
    ## Number each mouse's sessions 1, 2, ... within this context. Assumes rows are in per-mouse chronological
    ## order (as ct_df is built); unlike tiling [1..5], this does not require every mouse to have exactly 5 rows.
    sdata['context_day'] = sdata.groupby('mouse').cumcount() + 1
    avg_sdata = sdata.groupby(['context_day'], as_index=False).agg({'percent_correct': ['mean', 'sem']})
    fig.add_trace(go.Scatter(x=avg_sdata['context_day'], y=avg_sdata['percent_correct']['mean'], name=session,
                             error_y=dict(type='data', array=avg_sdata['percent_correct']['sem']), line_color=context_colors[session]))
## Chance-level reference at 25% accuracy.
fig.add_hline(y=25, line_dash='dash', line_color=chance_color, line_width=1, opacity=1)
fig.update_yaxes(range=[0, 101])
fig.show()
fig.write_image(pjoin(fig_path, 'multi_context_AD.png'), width=500, height=500)

In [None]:
## One-way ANOVA on Multi-context lick accuracy for the listed days (6 and 11 are B1 and C1 in the Multi-context
## schedule), with context ('session') as the factor. The subset is rebuilt here so the test does not depend on
## whichever `exp_df` a previously executed cell left behind.
days = [6, 11]
multi_df = ct_df[(ct_df['group'] == 'Multi-context') & (ct_df['session'] != 'AR')]
data = multi_df[multi_df['day'].isin(days)].reset_index(drop=True)
model = ols('percent_correct ~ C(session)', data=data).fit()
anova_table = sm.stats.anova_lm(model, typ=3)
print(anova_table)

In [None]:
## Lick accuracy by day within context for Two-context mice: first A block (days 1-5) vs B (days 16-20).
## B is relabeled 'D' so it lines up with the final context of the Multi-context group.
tc_df = ct_df[(ct_df['group'] == 'Two-context') & (ct_df['session'] != 'AR')]
tc_df = tc_df[((tc_df['day'] < 6) & (tc_df['session'] == 'A')) | ((tc_df['day'] > 15) & (tc_df['session'] == 'B'))]
tc_df = tc_df.assign(session=tc_df['session'].replace({'B': 'D'}))

fig = pf.custom_graph_template(x_title='Day in Context', y_title='Lick Accuracy (%)')
## Iterate over the labels that exist after relabeling ('B' would now select no rows).
for session in ['A', 'D']:
    gdata = tc_df[tc_df['session'] == session].reset_index(drop=True)
    ## 1-based session number within the context for each mouse (rows are in per-mouse chronological order).
    gdata['context_day'] = gdata.groupby('mouse').cumcount() + 1
    avg_gdata = gdata.groupby(['context_day'], as_index=False).agg({'percent_correct': ['mean', 'sem']})
    fig.add_trace(go.Scatter(x=avg_gdata['context_day'], y=avg_gdata['percent_correct']['mean'], name=session,
                             error_y=dict(type='data', array=avg_gdata['percent_correct']['sem']), line_color=context_colors[session]))
## Chance-level reference at 25% accuracy.
fig.add_hline(y=25, line_dash='dash', line_color=chance_color, line_width=1, opacity=1)
fig.update_yaxes(range=[0, 100])
fig.show()
fig.write_image(pjoin(fig_path, 'two_context_AD.png'), width=500, height=500)

In [None]:
## Lick accuracy by day within block for Two-context mice: context A split into three 5-day blocks
## (days 1-5, 6-10, 11-15) plus B on days 16-20, relabeled 'D' to match the Multi-context final context.
two_cont_list = ['A1-5', 'A1-5', 'A1-5', 'A1-5', 'A1-5', 'A5-10', 'A5-10', 'A5-10', 'A5-10', 'A5-10', 
                 'A10-15', 'A10-15', 'A10-15', 'A10-15', 'A10-15', 'D', 'D', 'D', 'D', 'D']
## set_group labels control mice 'Two-context' (lower-case c); 'Two-Context' would select no rows.
tc_blocks_df = ct_df[(ct_df['group'] == 'Two-context') & (ct_df['session'] != 'AR')]
## two_cont_list[k] is the block label of day k + 1. Label each row by its day instead of tiling the list once per
## mouse, which required every mouse to have exactly 20 rows in order. Days outside 1-20 map to NaN and are not plotted.
tc_blocks_df = tc_blocks_df.assign(
    session=tc_blocks_df['session'].replace({'B': 'D'}),
    session_two=tc_blocks_df['day'].map(dict(enumerate(two_cont_list, start=1))),
)

fig = pf.custom_graph_template(x_title='Day in Context', y_title='Lick Accuracy (%)')
for session in ['A1-5', 'A5-10', 'A10-15', 'D']:
    gdata = tc_blocks_df[tc_blocks_df['session_two'] == session].reset_index(drop=True)
    ## 1-based day within the block for each mouse (rows are in per-mouse chronological order).
    gdata['context_day'] = gdata.groupby('mouse').cumcount() + 1
    avg_gdata = gdata.groupby(['context_day'], as_index=False).agg({'percent_correct': ['mean', 'sem']})
    fig.add_trace(go.Scatter(x=avg_gdata['context_day'], y=avg_gdata['percent_correct']['mean'], name=session,
                             error_y=dict(type='data', array=avg_gdata['percent_correct']['sem']), line_color=context_colors[session]))
## Chance-level reference at 25% accuracy.
fig.add_hline(y=25, line_dash='dash', line_color=chance_color, line_width=1, opacity=1)
fig.update_yaxes(range=[0, 100])
fig.show()
fig.write_image(pjoin(fig_path, 'two_context_A5A10A15D.png'))

In [None]:
## Lick accuracy by day within the final context (days > 15, reversal 'AR' excluded) for both groups.
## Two-context mice's B is relabeled 'D' so both groups share the label.
exp_df = ct_df[ct_df['session'] != 'AR']
exp_df = exp_df[exp_df['day'] > 15]
exp_df = exp_df.assign(session=exp_df['session'].replace({'B': 'D'}))

fig = pf.custom_graph_template(x_title='Day in Context', y_title='Lick Accuracy (%)', titles=['Context D'])
## Labels must match set_group / ce_colors_dict exactly ('-context', lower-case c).
for group in ['Two-context', 'Multi-context']:
    gdata = exp_df[exp_df['group'] == group].reset_index(drop=True)
    ## 1-based day within context D for each mouse (rows are in per-mouse chronological order).
    gdata['context_day'] = gdata.groupby('mouse').cumcount() + 1
    avg_gdata = gdata.groupby(['context_day'], as_index=False).agg({'percent_correct': ['mean', 'sem']})
    fig.add_trace(go.Scatter(x=avg_gdata['context_day'], y=avg_gdata['percent_correct']['mean'], name=group, mode='lines+markers',
                             line_color=ce_colors_dict[group], error_y=dict(type='data', array=avg_gdata['percent_correct']['sem'])))
fig.update_yaxes(range=[0, 100])
## Chance-level reference at 25% accuracy.
fig.add_hline(y=25, line_width=1, opacity=1, line_dash='dash', line_color=chance_color)
fig.show()
fig.write_image(pjoin(fig_path, 'lick_accuracy_D_multi_two.png'))

### Example mouse linear position, lick accuracy across days.

In [None]:
## Example mouse: position ('a_pos') over time for its first five circle-track sessions (context A), one row per
## session, with rewards (red) and licks (black) overlaid.
mouse = 'mc44'
position_color = 'darkgrey'
data_of_interest = 'behav' 
fig = pf.custom_graph_template(x_title='', y_title='', rows=5, columns=1, shared_y=True, shared_x=True, 
                            titles=['Context A'], width=800, height=1000)

exp_path = pjoin(dpath, f'{experiment_folders[0]}/output/{data_of_interest}/')
mpath = pjoin(exp_path, mouse)
## os.listdir order is arbitrary; natural-sort so rows 1-5 are the first five sessions in order.
for idx, session in enumerate(natsorted(os.listdir(mpath))[0:5]):
    data_out = pd.read_feather(pjoin(mpath, session))
    lin_pos = data_out['a_pos']
    ## Show the Rewards/Licks legend entries once instead of once per row.
    first_row = idx == 0
    fig.add_trace(go.Scattergl(x=data_out['t'], y=lin_pos, mode='lines', marker_color=position_color, marker_size=2, showlegend=False), row=idx + 1, col=1)
    fig.add_trace(go.Scattergl(x=data_out['t'][data_out['water']], y=lin_pos[data_out['water']], 
                            mode='markers', marker_color='red', marker_size=6, name='Rewards', showlegend=first_row), row=idx + 1, col=1)
    ## Rows with lick_port != -1 are plotted as licks (-1 presumably means "no lick").
    fig.add_trace(go.Scattergl(x=data_out['t'][data_out['lick_port'] != -1], y=lin_pos[data_out['lick_port'] != -1],
                            mode='markers', marker_color='black', marker_size=4, opacity=0.3, name='Licks', showlegend=first_row), row=idx + 1, col=1)
fig.update_yaxes(title='Position (deg)', col=1)
fig.update_xaxes(title='Time (s)', row=5)
fig.show()
# fig.write_image(pjoin(fig_path, 'lin_pos_with_reward.png'))

In [None]:
## Example mouse: position ('a_pos') on its 1st and 5th circle-track sessions side by side, with rewards and
## licks overlaid.
mouse = 'mc44'
position_color = 'darkgrey'
data_of_interest = 'behav' 
fig = pf.custom_graph_template(x_title='Time (s)', y_title='', rows=1, columns=2, shared_y=True, shared_x=True, 
                               titles=['Day 1', 'Day 5'], width=1200, height=500)

exp_path = pjoin(dpath, f'{experiment_folders[0]}/output/{data_of_interest}/')
mpath = pjoin(exp_path, mouse)
count = 0
## Natural-sort so positions 0 and 4 really are the 1st and 5th sessions (os.listdir order is arbitrary).
for idx, session in enumerate(natsorted(os.listdir(mpath))[0:5]):
    if idx not in (0, 4):
        continue
    count += 1
    data_out = pd.read_feather(pjoin(mpath, session))
    lin_pos = data_out['a_pos']
    fig.add_trace(go.Scattergl(x=data_out['t'], y=lin_pos, mode='lines', marker_color=position_color, marker_size=2, showlegend=False), row=1, col=count)
    fig.add_trace(go.Scattergl(x=data_out['t'][data_out['water']], y=lin_pos[data_out['water']], 
                            mode='markers', marker_color='red', marker_size=10, name='Rewards', showlegend=False), row=1, col=count)
    ## Rows with lick_port != -1 are plotted as licks (-1 presumably means "no lick").
    fig.add_trace(go.Scattergl(x=data_out['t'][data_out['lick_port'] != -1], y=lin_pos[data_out['lick_port'] != -1],
                            mode='markers', marker_color='black', marker_size=6, opacity=0.6, name='Licks', showlegend=False), row=1, col=count)
fig.update_yaxes(title='Position (deg)', col=1)
## Traces are added as (position, rewards, licks) per panel; show the legend only for day 1's rewards and licks.
fig['data'][1]['showlegend'] = True 
fig['data'][2]['showlegend'] = True 
fig.show()
fig.write_image(pjoin(fig_path, 'lin_pos_with_reward_day1_day5.png'))

In [None]:
## Lick accuracy on days 1-5 for the example mouse (the `mouse` variable last set above).
fig = pf.custom_graph_template(x_title='Day', y_title='Lick Accuracy (%)')
mdata = ct_df.loc[ct_df['mouse'].eq(mouse) & ct_df['day'].lt(6)]
fig.add_trace(go.Scatter(x=mdata['day'], y=mdata['percent_correct'], mode='lines', line_color=subject_color))
## Chance-level reference at 25% accuracy.
fig.add_hline(y=25, line_dash='dash', line_width=1, line_color=chance_color, opacity=1)
fig.update_yaxes(range=[0, 101])
fig.show()
fig.write_image(pjoin(fig_path, f'{mouse}_example_lick_acc.png'))

In [None]:
## Plot lick accuracy across trials across days for an example mouse. Must run code below to get dataframe
## NOTE(review): `trial_df` and `bin_size` are built by code further down, so this cell raises NameError on a
## fresh Restart & Run All. Move it after the cell that creates trial_df.
## Reads trial_df columns 'mouse', 'day', 'trial' and 'lick_acc'; `mouse` is the example mouse set above.
mdata = trial_df[(trial_df['mouse'] == mouse) & (trial_df['day'] < 6)]
## One panel per day; `col=day` below relies on the day values being 1..5.
fig = pf.custom_graph_template(x_title='', y_title='', rows=1, columns=5, shared_y=True, shared_x=True,
                                titles=['A1', 'A2', 'A3', 'A4', 'A5'], width=1000, height=300)

for day in mdata['day'].unique():
    tdata = mdata[mdata['day'] == day]
    ## x = 1, 1 + bin_size, ... (values below last 'trial' * bin_size), presumably the first trial of each bin.
    ## NOTE(review): x must have one value per row of tdata. That holds if 'trial' runs 1..len(tdata) and
    ## bin_size >= 2, but with bin_size == 1 it is one value short. Confirm.
    x_data = np.arange(1, np.array(tdata['trial'])[-1]*bin_size, bin_size)
    fig.add_trace(go.Scatter(x=x_data, y=tdata['lick_acc'], mode='lines', line_color=subject_color, showlegend=False), row=1, col=day)
fig.add_hline(y=25, line_dash='dash', line_width=1, line_color=chance_color, opacity=1)
fig.update_yaxes(range=[0, 101])
fig.update_yaxes(title='Lick Accuracy (%)', col=1)
fig.update_xaxes(title='Trial')
fig.show()
fig.write_image(pjoin(fig_path, f'{mouse}_acc_across_trials.png'))

### Plot probe accuracies separately for control and experimental mice.

In [None]:
## Probe sessions: for every session that contains probe rows, compute lick accuracy during the probe and during
## the rest of the session, plus the number of probe-period rows with a lick port (lick_port != -1).
data_of_interest = 'behav' ## one of behav, aligned_minian, lin_behav
lick_dict_probe = {'mouse': [], 'experiment': [], 'session': [], 'sex': [], 'group': [],
                   'day': [], 'num_licks': [], 'probe_acc': [], 'session_acc': []}
for experiment in os.listdir(dpath):
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
    for mouse in os.listdir(exp_path):
        ## Use the shared exclusion list from the configuration cell rather than hard-coding mouse IDs here.
        if mouse in excluded_mice:
            continue
        mpath = pjoin(exp_path, mouse)
        sex = 'Male' if mouse in male_mice else 'Female'
        group = set_group(mouse, control_mice)
        ## Natural-sort the session files so 'day' (1-based position among ALL sessions, probe or not) is
        ## chronological; os.listdir order is arbitrary.
        for idx, session in enumerate(natsorted(os.listdir(mpath))):
            behav = pd.read_feather(pjoin(mpath, session))
            if not any(behav['probe']):
                continue
            behav_probe = behav[behav['probe']]
            behav_no_probe = behav[~behav['probe']]
            reward_one, reward_two = np.unique(behav['reward_one'])[0], np.unique(behav['reward_two'])[0]
            percent_correct = ctb.lick_accuracy(behav_probe, port_list=[reward_one, reward_two], lick_threshold=lick_thresh, by_trials=False)
            session_pc = ctb.lick_accuracy(behav_no_probe, port_list=[reward_one, reward_two], lick_threshold=lick_thresh, by_trials=False)
            lick_dict_probe['mouse'].append(mouse)
            lick_dict_probe['experiment'].append(behav['cohort'].unique()[0])
            lick_dict_probe['sex'].append(sex)
            lick_dict_probe['group'].append(group)
            lick_dict_probe['day'].append(idx + 1)
            lick_dict_probe['session'].append(np.unique(behav['session'])[0])
            lick_dict_probe['num_licks'].append(len(behav_probe[behav_probe['lick_port'] != -1]))
            lick_dict_probe['probe_acc'].append(percent_correct)
            lick_dict_probe['session_acc'].append(session_pc)
probe_df = pd.DataFrame(lick_dict_probe)

In [None]:
## Probe accuracy on the first vs last probe of each context (A-D) for Multi-context mice (mean +/- SEM).
plot_mice = False  ## set True to overlay each mouse's first/last probe accuracy
first_last = pd.DataFrame()
context_list = ['A', 'B', 'C', 'D']
## set_group labels non-control mice 'Multi-context'; no mouse is ever labeled 'Experimental'.
df = probe_df[probe_df['group'] == 'Multi-context']
for mouse in df['mouse'].unique():
    mouse_data = df[df['mouse'] == mouse]
    index_list = ctb.pick_context_day(mouse_data, day_index=0, contexts=context_list)
    index_list_two = ctb.pick_context_day(mouse_data, day_index=-1, contexts=context_list)
    sub_data = mouse_data.loc[index_list, :]
    sub_data.insert(0, 'day_type', 'First')
    sub_data_two = mouse_data.loc[index_list_two, :]
    sub_data_two.insert(0, 'day_type', 'Last')
    first_last = pd.concat([first_last, sub_data, sub_data_two])
avg_combined = first_last.groupby(['day_type', 'session'], as_index=False).agg({'probe_acc': ['mean', 'sem']})
## Display labels for the x axis: first probe -> '1', last probe -> '5'.
avg_combined = avg_combined.replace({'First': '1', 'Last': '5'})

fig = pf.custom_graph_template(x_title='Day', y_title='', width=1200, rows=1, columns=len(context_list), 
                               titles=['A', 'B', 'C', 'D'], shared_x=True, shared_y=True)
for idx, session in enumerate(context_list):
    plot_data = avg_combined[avg_combined['session'] == session]
    fig.add_trace(go.Scatter(x=plot_data['day_type'], y=plot_data['probe_acc']['mean'], mode='markers',
                                error_y=dict(type='data', array=plot_data['probe_acc']['sem'], thickness=1.5, width=8), 
                                line_color=avg_color, showlegend=False), row=1, col=idx+1)
if plot_mice:
    for mouse in first_last['mouse'].unique():
        mdata = first_last[first_last['mouse'] == mouse]
        for context in mdata['session'].unique():
            if context not in context_list:
                continue
            pdata = mdata[mdata['session'] == context]
            ## Same '1'/'5' x labels as the averages so the per-mouse lines land on the same categories,
            ## and a fixed panel per context rather than one based on this mouse's session order.
            fig.add_trace(go.Scatter(x=pdata['day_type'].replace({'First': '1', 'Last': '5'}), y=pdata['probe_acc'],
                                     mode='lines', line_color=chance_color, line_width=1, opacity=0.7, showlegend=False,
                                     name=mouse), row=1, col=context_list.index(context) + 1)

## Chance-level reference at 25% accuracy.
fig.add_hline(y=25, line_width=1, line_dash='dash', line_color='darkgrey', opacity=1)
fig.update_yaxes(title='Probe Accuracy (%)', range=[0, 101], col=1)
fig.show()
fig.write_image(pjoin(fig_path, 'probe_acc_experimental.png'))

In [None]:
## Probe accuracy on the first vs last probe in context A, pooled across all mice of both groups (mean +/- SEM).
first_last = pd.DataFrame()
context_list = ['A']
df = probe_df.copy()
## Collect each mouse's first and last context-A probe rows, tagged by 'day_type', then concatenate once.
frames = [first_last]
for mouse in df['mouse'].unique():
    mouse_data = df[df['mouse'] == mouse]
    for day_type, day_index in (('First', 0), ('Last', -1)):
        picked_rows = ctb.pick_context_day(mouse_data, day_index=day_index, contexts=context_list)
        picked = mouse_data.loc[picked_rows, :]
        picked.insert(0, 'day_type', day_type)
        frames.append(picked)
first_last = pd.concat(frames)
avg_combined = first_last.groupby(['day_type', 'session'], as_index=False).agg({'probe_acc': ['mean', 'sem']})
## x-axis display labels: first probe -> '1', last probe -> '5'.
avg_combined = avg_combined.replace({'First': '1', 'Last': '5'})

fig = pf.custom_graph_template(x_title='Day', y_title='', width=500, rows=1, columns=len(context_list), 
                               titles=['Probe Accuracy'], shared_x=True, shared_y=True)
for panel, session in enumerate(context_list, start=1):
    plot_data = avg_combined[avg_combined['session'] == session]
    fig.add_trace(go.Scatter(x=plot_data['day_type'], y=plot_data['probe_acc']['mean'], mode='markers',
                             error_y=dict(type='data', array=plot_data['probe_acc']['sem'], thickness=1.5, width=8), 
                             line_color=avg_color, showlegend=False), row=1, col=panel)

## Chance-level reference at 25% accuracy.
fig.add_hline(y=25, line_width=1, line_dash='dash', line_color='darkgrey', opacity=1)
fig.update_yaxes(title='Lick Accuracy (%)', range=[0, 101], col=1)
fig.show()
fig.write_image(pjoin(fig_path, 'probe_acc_just_A.png'))

In [None]:
## Probe accuracy across days for Two-context mice: one panel per context (A, B), group mean +/- SEM plus one
## thin line per mouse.
panel_contexts = ['A', 'B']
## probe_df has no 'group_two' column; set_group labels control mice 'Two-context' in 'group'.
df = probe_df[probe_df['group'] == 'Two-context']
avg = df.groupby(['day', 'session'], as_index=False).agg({'probe_acc': ['mean', 'sem']})

fig = pf.custom_graph_template(x_title='Day', y_title='', width=800, rows=1, columns=2, 
                               titles=['A', 'B'], shared_x=True, shared_y=True)
## Fixed panel per context; sessions without a panel (if any) are skipped instead of addressing a missing column.
for col, session in enumerate(panel_contexts, start=1):
    plot_data = avg[avg['session'] == session]
    ## Two-context group color from ce_colors_dict (ce_colors[0] is the Multi-context color).
    fig.add_trace(go.Scatter(x=plot_data['day'], y=plot_data['probe_acc']['mean'],
                             error_y=dict(type='data', array=plot_data['probe_acc']['sem'], thickness=1.5, width=8), 
                             line_color=ce_colors_dict['Two-context'], showlegend=False), row=1, col=col)
for mouse in df['mouse'].unique():
    mdata = df[df['mouse'] == mouse]
    for col, context in enumerate(panel_contexts, start=1):
        pdata = mdata[mdata['session'] == context]
        fig.add_trace(go.Scatter(x=pdata['day'], y=pdata['probe_acc'], mode='lines', line_color=chance_color,
                                 line_width=1, opacity=0.7, showlegend=False, name=mouse), row=1, col=col)
## Chance-level reference at 25% accuracy.
fig.add_hline(y=25, line_width=1, line_dash='dash', line_color='darkgrey', opacity=1)
fig.update_yaxes(title='Probe Accuracy (%)', range=[0, 100], col=1)
fig.show()
fig.write_image(pjoin(fig_path, 'probe_acc_control.png'))

### Plot probe accuracy for both groups on the same plot

In [None]:
## First vs last probe accuracy per context for both groups; `avg` feeds the plotting cell that follows.
## probe_df is rebuilt from the raw dict so this cell does not depend on earlier filtering.
probe_df = pd.DataFrame(lick_dict_probe)
## Drop the day-15 probe for Two-context mice only. set_group never assigns 'Experimental', so testing
## `group != 'Experimental'` would drop day 15 for every mouse.
idx_probe = (probe_df['day'] == 15) & (probe_df['group'] == 'Two-context')
probe_df = probe_df[~idx_probe]
first_last = pd.DataFrame()

## Multi-context mice: contexts A-D.
context_list = ['A', 'B', 'C', 'D']
df = probe_df[probe_df['group'] == 'Multi-context']
for mouse in df['mouse'].unique():
    mouse_data = df[df['mouse'] == mouse]
    index_list = ctb.pick_context_day(mouse_data, day_index=0, contexts=context_list)
    index_list_two = ctb.pick_context_day(mouse_data, day_index=-1, contexts=context_list)
    sub_data = mouse_data.loc[index_list, :]
    sub_data.insert(0, 'day_type', 'First')
    sub_data_two = mouse_data.loc[index_list_two, :]
    sub_data_two.insert(0, 'day_type', 'Last')
    first_last = pd.concat([first_last, sub_data, sub_data_two])

## Two-context mice: contexts A and B.
df_c = probe_df[probe_df['group'] == 'Two-context']
context_list = ['A', 'B']
for mouse in df_c['mouse'].unique():
    mouse_data = df_c[df_c['mouse'] == mouse]
    index_list = ctb.pick_context_day(mouse_data, day_index=0, contexts=context_list)
    index_list_two = ctb.pick_context_day(mouse_data, day_index=-1, contexts=context_list)
    sub_data = mouse_data.loc[index_list, :]
    sub_data.insert(0, 'day_type', 'First')
    sub_data_two = mouse_data.loc[index_list_two, :]
    sub_data_two.insert(0, 'day_type', 'Last')
    first_last = pd.concat([first_last, sub_data, sub_data_two])

## Mean +/- SEM per (first/last, context, group, day). Because 'day' is a key, mice whose first/last probe of a
## context fell on different days end up in separate rows.
avg = first_last.groupby(['day_type', 'session', 'group', 'day'], as_index=False).agg({'probe_acc': ['mean', 'sem']})
avg = avg.replace({'First': '1', 'Last': '5'})
## NOTE(review): the next plotting cell compares `group == 'Control'` to offset the Two-context 'B' panel; with the
## set_group labels used here that comparison must be `group == 'Two-context'`.

In [None]:
## Plot first vs. last probe day in each context for both groups (group mean ± SEM).
## Two-context mice only experience A and B. Their B days are drawn in the D panel
## (column 4), presumably because they fall on the same days (16-20) as Multi-context D.
plot_mice = False
panel_order = ['A', 'B', 'C', 'D']


def _probe_panel_col(group, session):
    """Return the 1-based subplot column in which `group`'s `session` context is drawn."""
    if group == 'Two-context' and session == 'B':
        return panel_order.index('D') + 1
    return panel_order.index(session) + 1


fig = pf.custom_graph_template(x_title='Day in Context', y_title='', width=1000, rows=1, columns=4, 
                               titles=['A', 'B', 'C', 'D'], shared_x=True, shared_y=True)
for group in avg['group'].unique():
    gdata = avg[avg['group'] == group]
    for s_idx, session in enumerate(gdata['session'].unique()):
        plot_data = gdata[gdata['session'] == session]
        ## Only the first trace of each group gets a legend entry
        fig.add_trace(go.Scatter(x=plot_data['day_type'], y=plot_data['probe_acc']['mean'], name=group,
                                 error_y=dict(type='data', array=plot_data['probe_acc']['sem'], thickness=1.5, width=8),
                                 marker_symbol=symbol_dict[group], line_color=ce_colors_dict[group],
                                 showlegend=(s_idx == 0), legendgroup=group),
                      row=1, col=_probe_panel_col(group, session))
if plot_mice:
    for mouse in first_last['mouse'].unique():
        mdata = first_last[first_last['mouse'] == mouse]
        mouse_group = mdata['group'].unique()[0]
        for context in mdata['session'].unique():
            pdata = mdata[mdata['session'] == context]
            ## avg relabels day_type 'First'/'Last' as '1'/'5'; match it so points share the x categories
            fig.add_trace(go.Scatter(x=pdata['day_type'].replace({'First': '1', 'Last': '5'}), y=pdata['probe_acc'],
                                     mode='markers', line_color=ce_colors_dict[mouse_group], line_width=1, opacity=0.7,
                                     showlegend=False, name=mouse, legendgroup=mouse_group),
                          row=1, col=_probe_panel_col(mouse_group, context))

fig.add_hline(y=25, line_width=1, line_dash='dash', line_color='darkgrey', opacity=1)
fig.update_yaxes(title='Probe Accuracy (%)', range=[0, 101], col=1)
fig.show()
fig.write_image(pjoin(fig_path, 'probe_acc_control_and_experimental.png'))

### Compute lick accuracy across trials for the probe

In [None]:
## Probe lick accuracy by trial: one row per mouse per session that contains a probe.
## Sessions are natsorted so that `day` (enumerate index + 1) follows the session number in
## the file name ('mc44_2.feat' before 'mc44_10.feat'); os.listdir order is arbitrary.
lick_dict = {'mouse': [], 'experiment': [], 'session': [], 'day': [], 'num_licks': [], 'probe_acc': []}
for experiment in os.listdir(dpath):
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
    for mouse in os.listdir(exp_path):
        mpath = pjoin(exp_path, mouse)
        for idx, session in enumerate(natsorted(os.listdir(mpath))):
            behav = pd.read_feather(pjoin(mpath, session))
            if not behav['probe'].any():
                continue
            behav_probe = behav[behav['probe']]
            reward_one, reward_two = np.unique(behav['reward_one'])[0], np.unique(behav['reward_two'])[0]
            # Same ctb.lick_accuracy signature (port_list=[...]) as every other call in this notebook
            percent_correct = ctb.lick_accuracy(behav_probe, port_list=[reward_one, reward_two],
                                                lick_threshold=lick_thresh, by_trials=True)
            lick_dict['mouse'].append(mouse)
            lick_dict['experiment'].append(behav['cohort'].unique()[0])
            lick_dict['day'].append(idx+1)
            lick_dict['session'].append(np.unique(behav['session'])[0])
            lick_dict['num_licks'].append(len(behav_probe[behav_probe['lick_port'] != -1]))
            lick_dict['probe_acc'].append(percent_correct)
probe_trial_df = pd.DataFrame(lick_dict)

# first_last_trial = pd.DataFrame()
# context_list = ['A', 'B', 'C', 'D']
# for mouse in probe_trial_df['mouse'].unique():
#     mouse_data = probe_trial_df[probe_trial_df['mouse'] == mouse]
#     index_list = ctb.pick_context_day(mouse_data, day_index=0, contexts=context_list)
#     index_list_two = ctb.pick_context_day(mouse_data, day_index=-1, contexts=context_list)
#     sub_data = mouse_data.loc[index_list, :]
#     sub_data.insert(0, 'day_type', 'First')
#     sub_data_two = mouse_data.loc[index_list_two, :]
#     sub_data_two.insert(0, 'day_type', 'Last')
#     first_last_trial = pd.concat([first_last_trial, sub_data, sub_data_two])

### Plot lick accuracy across trials for each session.

In [None]:
## Binned per-trial lick accuracy and reward counts for every included mouse and session.
## Sessions are natsorted so `day` (enumerate index + 1) follows the session number.
bin_size = 5
data_of_interest = 'behav'
trial_res = {key: [] for key in ('mouse', 'sex', 'group', 'day', 'session', 'session_two', 'trial', 'lick_acc', 'rewards')}
for experiment in os.listdir(dpath):
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
    for mouse in os.listdir(exp_path):
        if mouse in excluded_mice:
            continue
        mpath = pjoin(exp_path, mouse)
        sex = 'Male' if mouse in male_mice else 'Female'
        group = set_group(mouse, control_mice)
        for idx, session in enumerate(natsorted(os.listdir(mpath))):
            behav = pd.read_feather(pjoin(mpath, f'{session}'))
            reward_one, reward_two = behav['reward_one'].unique()[0], behav['reward_two'].unique()[0]
            trial_acc = ctb.lick_accuracy(behav, port_list=[reward_one, reward_two], lick_threshold=lick_thresh, by_trials=True)
            num_rewards = ctb.rewards_across_trials(behav)

            # A non-positive bin_size disables binning
            binned_acc = ctb.bin_data(trial_acc, bin_size) if bin_size > 0 else trial_acc
            binned_rewards = ctb.bin_data(num_rewards, bin_size) if bin_size > 0 else num_rewards

            for trial, val in enumerate(binned_acc):
                # Values listed in the same order as the keys of trial_res
                row_values = (mouse, sex, group, idx + 1,
                              behav['session'].unique()[0], behav['session_two'].unique()[0],
                              trial + 1, val, binned_rewards[trial])
                for key, value in zip(trial_res, row_values):
                    trial_res[key].append(value)
trial_df = pd.DataFrame(trial_res)
avg_acc = trial_df.groupby(['group', 'day', 'session_two', 'trial'], as_index=False).agg({'lick_acc': ['mean', 'sem']})
avg_acc_mf = trial_df.groupby(['group', 'day', 'session_two', 'sex', 'trial'], as_index=False).agg({'lick_acc': ['mean', 'sem']})

In [None]:
## Plot binned lick accuracy across trials for each day (mean line with a ±SEM band), one color per group.
fig = pf.custom_graph_template(x_title='', y_title='', rows=4, columns=5, titles=experimental_sessions, height=1000, width=1000,
                               shared_y=True, shared_x=True)
for group in avg_acc['group'].unique():
    gdata = avg_acc[avg_acc['group'] == group]
    for idx, day in enumerate(gdata['day'].unique()):
        if idx == 20:
            continue  # the 4 x 5 grid has no panel for a 21st day
        day_data = gdata[gdata['day'] == day]

        # Five panels per row; the 16th day onward goes in the last row
        row = min(idx // 5, 3) + 1
        col = idx - 5 * (row - 1) + 1

        # x positions: first trial of each bin
        last_trial = np.array(day_data['trial'])[-1]
        x_data = np.arange(1, last_trial * bin_size, bin_size)
        acc_mean = day_data['lick_acc']['mean']
        acc_sem = day_data['lick_acc']['sem']
        upper = acc_mean + acc_sem
        lower = acc_mean - acc_sem

        fig.add_trace(go.Scatter(x=x_data, y=acc_mean, mode='lines', line_color=ce_colors_dict[group], showlegend=False), row=row, col=col)
        # Invisible upper bound, then a lower bound filled up to it ('tonexty') to draw the SEM band
        fig.add_trace(go.Scatter(x=x_data, y=upper, mode='lines', marker=dict(color=error_color[group]),
                                 name='Upper Bound', line=dict(width=0), showlegend=False), row=row, col=col)
        fig.add_trace(go.Scatter(x=x_data, y=lower, mode='lines', marker=dict(color=error_color[group]),
                                 name='Lower Bound', line=dict(width=0), showlegend=False, fillcolor=error_color[group], fill='tonexty'), row=row, col=col)

fig.add_hline(y=25, line_width=1, line_dash='dash', line_color=chance_color, opacity=1)
fig.update_yaxes(title='Lick Accuracy (%)', col=1)
fig.update_xaxes(title='Trial', row=4)
fig.update_yaxes(range=[0, 100])
fig.show()
fig.write_image(pjoin(fig_path, 'acc_trials_experimental_all_days.png'))

In [None]:
## Plot binned lick accuracy across trials for each day, Two-context (control) mice only.
## Group labels come from set_group ('Two-context' / 'Multi-context'), and both
## `ce_colors_dict` and `error_color` are keyed by those labels.
group = 'Two-context'
fig = pf.custom_graph_template(x_title='', y_title='', rows=4, columns=5, titles=control_sessions, height=1000, width=1000,
                               shared_y=True, shared_x=True)
gdata = avg_acc[avg_acc['group'] == group]
for idx, day in enumerate(gdata['day'].unique()):
    if idx == 20:
        continue  # the 4 x 5 grid has no panel for a 21st day
    day_data = gdata[gdata['day'] == day]

    # Five panels per row; the 16th day onward goes in the last row
    row = min(idx // 5, 3) + 1
    col = idx - 5 * (row - 1) + 1

    # x positions: first trial of each bin
    x_data = np.arange(1, np.array(day_data['trial'])[-1]*bin_size, bin_size)
    upper = day_data['lick_acc']['mean'] + day_data['lick_acc']['sem']
    lower = day_data['lick_acc']['mean'] - day_data['lick_acc']['sem']

    fig.add_trace(go.Scatter(x=x_data, y=day_data['lick_acc']['mean'], mode='lines', line_color=ce_colors_dict[group], showlegend=False), row=row, col=col)
    # Invisible upper bound, then a lower bound filled up to it ('tonexty') to draw the SEM band
    fig.add_trace(go.Scatter(x=x_data, y=upper, mode='lines', marker=dict(color=error_color[group]),
                             name='Upper Bound', line=dict(width=0), showlegend=False), row=row, col=col)
    fig.add_trace(go.Scatter(x=x_data, y=lower, mode='lines', marker=dict(color=error_color[group]),
                             name='Lower Bound', line=dict(width=0), showlegend=False, fillcolor=error_color[group], fill='tonexty'), row=row, col=col)

fig.add_hline(y=25, line_width=1, line_dash='dash', line_color=chance_color, opacity=1)
fig.update_yaxes(title='Lick Accuracy (%)', col=1)
fig.update_xaxes(title='Trial', row=4)
fig.update_yaxes(range=[0, 100])
fig.show()
fig.write_image(pjoin(fig_path, 'acc_trials_control_all_days.png'))

### Signal detection metrics.

In [None]:
## Signal-detection metrics (including the correct rejection rate, 'CR') per session, probe excluded.
## Sessions are natsorted so `day` (enumerate index + 1) follows the session number;
## os.listdir order is arbitrary.
data_of_interest = 'behav' ## one of behav, aligned_minian, lin_behav
signal_frames = []
for experiment in os.listdir(dpath):
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
    for mouse in os.listdir(exp_path):
        mpath = pjoin(exp_path, mouse)
        for idx, session in enumerate(natsorted(os.listdir(mpath))):
            behav = pd.read_feather(pjoin(mpath, session))
            behav = behav[~behav['probe']] ## exclude probe
            reward_one, reward_two = np.unique(behav['reward_one'])[0], np.unique(behav['reward_two'])[0]
            signal = pd.DataFrame(ctb.dprime_metrics(behav, mouse, day=idx+1, reward_ports=[reward_one, reward_two], forward_reverse='forward'))
            signal['experiment'] = behav['cohort'].unique()[0]
            signal_frames.append(signal)
## Concatenate once instead of growing a DataFrame inside the loop
sig_df = pd.concat(signal_frames, ignore_index=True) if signal_frames else pd.DataFrame()

In [None]:
## Plot the mean correct rejection (CR) rate per mouse per day.
## Transition markers at 5.5 / 10.5 / 15.5 fall between the five-day context blocks.
corr_rej = (sig_df
            .groupby(['day', 'mouse'], as_index=False)
            .agg({'CR': 'mean'}))
plot_kwargs = dict(
    x_var='day', y_var='CR', groupby_var=['day'],
    plot_transitions=[5.5, 10.5, 15.5], transition_color=['darkgrey'] * 3,
    marker_color=subject_color, avg_color=avg_color,
    expert_line=False, chance=False,
    x_title='Day', y_title='Correct Rejection Rate', titles=['Circle Track'],
    height=500, width=500,
)
fig = pf.plot_behavior_across_days(corr_rej, **plot_kwargs)
fig.update_yaxes(range=[0, 1])
fig.show()

### Determine where mice are licking across each session.

In [None]:
## Where do mice lick?
## For each session (probe excluded) and each lick threshold N, take every lick that is exactly
## the N-th consecutive lick on the same port. Then report the percentage of those licks on the
## reward ports, the front ports, the back ports, and all remaining ports ('final_ports').
## Front/back port sets come from ctb.front_back_ports.
## Note: the 'group' column holds sex and 'group_two' holds the set_group label.
## Sessions are natsorted so `day` follows the session number; os.listdir order is arbitrary.
data_of_interest = 'behav'
threshold_list = [1, 2, 3, 4, 5]


def _nth_consecutive_lick(ports, lick_threshold):
    """Boolean mask, True where a lick is exactly the `lick_threshold`-th consecutive lick on one port.

    A run restarts whenever the port differs from the previous lick's port.
    """
    ports = np.asarray(ports)
    if ports.size == 0:
        return np.zeros(0, dtype=bool)
    new_run = np.concatenate(([True], ports[1:] != ports[:-1]))
    run_starts = np.flatnonzero(new_run)
    run_id = np.cumsum(new_run) - 1
    position_in_run = np.arange(ports.size) - run_starts[run_id] + 1
    return position_in_run == lick_threshold


def _port_lick_percentages(ports, reward_ports, front_ports, back_ports):
    """Percentages of `ports` that are reward / front / back / other ports.

    Categories are exclusive, with priority reward > front > back > other.
    Returns a 4-tuple; all NaN when `ports` is empty (nothing to normalise by).
    """
    ports = np.asarray(ports)
    if ports.size == 0:
        return (np.nan,) * 4
    is_reward = np.isin(ports, list(reward_ports))
    is_front = np.isin(ports, list(front_ports)) & ~is_reward
    is_back = np.isin(ports, list(back_ports)) & ~is_reward & ~is_front
    is_other = ~(is_reward | is_front | is_back)
    return tuple(mask.sum() / ports.size * 100 for mask in (is_reward, is_front, is_back, is_other))


lick_dict = {'mouse': [], 'day': [], 'group': [], 'group_two': [], 'session_two': [], 'lick_thresh': [], 'reward_ports': [],
             'front_ports': [], 'back_ports': [], 'final_ports': []}
for experiment in os.listdir(dpath):
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
    for mouse in os.listdir(exp_path):
        mpath = pjoin(exp_path, mouse)
        sex = 'Male' if mouse in male_mice else 'Female'
        group = set_group(mouse, control_mice)
        for idx, session in enumerate(natsorted(os.listdir(mpath))):
            day = idx + 1
            behav = pd.read_feather(pjoin(mpath, f'{session}'))
            behav = behav[~behav['probe']]
            reward_one, reward_two = behav['reward_one'].unique()[0], behav['reward_two'].unique()[0]
            front_ports, back_ports = ctb.front_back_ports(reward_list=[reward_one, reward_two])

            # Port of every lick, in recorded order (-1 marks "no lick")
            lick_ports = behav.loc[behav['lick_port'] != -1, 'lick_port'].to_numpy()
            if lick_ports.size == 0:
                continue  # no licks this session: nothing to tally at any threshold

            for lick_threshold in threshold_list:
                reached_ports = lick_ports[_nth_consecutive_lick(lick_ports, lick_threshold)]
                reward_pct, front_pct, back_pct, final_pct = _port_lick_percentages(
                    reached_ports, [reward_one, reward_two], front_ports, back_ports)

                lick_dict['mouse'].append(mouse)
                lick_dict['day'].append(day)
                lick_dict['group'].append(sex)
                lick_dict['group_two'].append(group)
                lick_dict['session_two'].append(behav['session_two'].unique()[0])
                lick_dict['lick_thresh'].append(lick_threshold)
                lick_dict['reward_ports'].append(reward_pct)
                lick_dict['front_ports'].append(front_pct)
                lick_dict['back_ports'].append(back_pct)
                lick_dict['final_ports'].append(final_pct)
lick_df = pd.DataFrame(lick_dict)

In [None]:
## Lick locations for Multi-context (experimental) mice at a single lick threshold:
## bars show the group mean ± SEM, grey markers show individual mice.
## RP/FP/BP/LP = reward / front / back / leftover ('final') ports.
lick_thresh = 5
x_axis = ['RP', 'FP', 'BP', 'LP']
port_cols = ['reward_ports', 'front_ports', 'back_ports', 'final_ports']
plot_group = 'Multi-context'  # set_group label; ce_colors_dict is keyed by these labels
fig = pf.custom_graph_template(x_title='', y_title='', height=1000, width=1000, 
                               shared_y=True, rows=4, columns=5, titles=experimental_sessions)
exp_licks = lick_df[lick_df['group_two'] == plot_group]
avg_licks = exp_licks.groupby(['session_two', 'lick_thresh'], as_index=False).agg({c: ['mean', 'sem'] for c in port_cols})
for idx, session in enumerate(natsorted(avg_licks['session_two'].unique())):
    pdata = avg_licks[(avg_licks['session_two'] == session) & (avg_licks['lick_thresh'] == lick_thresh)]

    # Five panels per row; the 16th session onward goes in the last row
    row = min(idx // 5, 3) + 1
    col = idx - 5 * (row - 1) + 1

    y_data = [pdata[c]['mean'].values[0] for c in port_cols]
    y_sem = [pdata[c]['sem'].values[0] for c in port_cols]
    fig.add_trace(go.Bar(x=x_axis, y=y_data, showlegend=False, marker_color=ce_colors_dict[plot_group],
                         error_y=dict(type='data', array=y_sem), marker_line_color='black', marker_line_width=2), row=row, col=col)

    for mouse in exp_licks['mouse'].unique():
        data = exp_licks[(exp_licks['mouse'] == mouse) & (exp_licks['session_two'] == session) & (exp_licks['lick_thresh'] == lick_thresh)]
        if data.empty:
            continue  # no row for this mouse/session (e.g. a session without licks)
        fig.add_trace(go.Scatter(x=x_axis, y=data.loc[:, port_cols].to_numpy()[0],
                                 mode='markers', marker_color='darkgrey', marker=dict(line=dict(width=1)),
                                 name=mouse, showlegend=False, opacity=0.6), row=row, col=col)
fig.update_yaxes(title='5th Licks (%)', col=1, range=[0, 100])
fig.show()
fig.write_image(pjoin(fig_path, 'front_back_leftover_ports_experimental.png'))

In [None]:
## Lick locations for Two-context (control) mice at a single lick threshold:
## bars show the group mean ± SEM, grey markers show individual mice.
## RP/FP/BP/LP = reward / front / back / leftover ('final') ports.
from natsort import natsort_keygen  # not imported in the notebook's top-level imports cell

lick_thresh = 5
x_axis = ['RP', 'FP', 'BP', 'LP']
port_cols = ['reward_ports', 'front_ports', 'back_ports', 'final_ports']
plot_group = 'Two-context'  # set_group label; ce_colors_dict is keyed by these labels
# Natural sort so A2 precedes A10 (control sessions run A1..A15, then B1..B5)
natsort_key = natsort_keygen()
fig = pf.custom_graph_template(x_title='', y_title='', height=1000, width=1000, 
                               shared_y=True, rows=4, columns=5, titles=control_sessions)
exp_licks = lick_df[lick_df['group_two'] == plot_group]
avg_licks = exp_licks.groupby(['session_two', 'lick_thresh'], as_index=False).agg({c: ['mean', 'sem'] for c in port_cols})
for idx, session in enumerate(sorted(avg_licks['session_two'].unique(), key=natsort_key)):
    pdata = avg_licks[(avg_licks['session_two'] == session) & (avg_licks['lick_thresh'] == lick_thresh)]

    # Five panels per row; the 16th session onward goes in the last row
    row = min(idx // 5, 3) + 1
    col = idx - 5 * (row - 1) + 1

    y_data = [pdata[c]['mean'].values[0] for c in port_cols]
    y_sem = [pdata[c]['sem'].values[0] for c in port_cols]
    fig.add_trace(go.Bar(x=x_axis, y=y_data, showlegend=False, marker_color=ce_colors_dict[plot_group],
                         error_y=dict(type='data', array=y_sem), marker_line_color='black', marker_line_width=2), row=row, col=col)

    for mouse in exp_licks['mouse'].unique():
        data = exp_licks[(exp_licks['mouse'] == mouse) & (exp_licks['session_two'] == session) & (exp_licks['lick_thresh'] == lick_thresh)]
        if data.empty:
            continue  # no row for this mouse/session (e.g. a session without licks)
        fig.add_trace(go.Scatter(x=x_axis, y=data.loc[:, port_cols].to_numpy()[0],
                                 mode='markers', marker_color='darkgrey', marker=dict(line=dict(width=1)),
                                 name=mouse, showlegend=False, opacity=0.6), row=row, col=col)
fig.update_yaxes(title='5th Licks (%)', col=1, range=[0, 100])
fig.show()
fig.write_image(pjoin(fig_path, 'front_back_leftover_ports_control.png'))

### During probe on first day in new context, calculate lick accuracy based on rotationally equivalent ports.

In [None]:
## During the probe on the first day in each new context (sessions 6, 11, 16), score lick accuracy
## against every rotation of the previous context's reward ports (read from sessions 5, 10, 15).
## Ports are numbered 1-8. Rotation `shift` maps port p to np.roll(reward_list, shift)[p - 1],
## so shift 8 is the identity (the old reward ports themselves).
## Sessions are natsorted so session 5 is read before 6 (etc.) and `day` = idx + 1 matches the
## session number; os.listdir order is arbitrary.
data_of_interest = 'behav' ## one of behav, aligned_minian, lin_behav
# Separate name so the earlier `lick_dict_probe` (used for probe_df above) is not overwritten
shifted_probe_dict = {'mouse': [], 'experiment': [], 'session': [], 'sex': [], 'group': [], 'day': [], 'num_licks': [], 'shift': [], 'probe_acc': []}
reward_list = [x for x in np.arange(1, 9)]
pre_switch_sessions = (5, 10, 15)  # last day in the previous context
probe_sessions = (6, 11, 16)       # first day in the new context

for experiment in os.listdir(dpath):
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
    for mouse in os.listdir(exp_path):
        if mouse in excluded_mice:
            continue
        mpath = pjoin(exp_path, mouse)
        sex = 'Male' if mouse in male_mice else 'Female'
        group = set_group(mouse, control_mice)
        pre_switch_files = [f'{mouse}_{n}.feat' for n in pre_switch_sessions]
        probe_files = [f'{mouse}_{n}.feat' for n in probe_sessions]

        reward_one = reward_two = None  # reset per mouse so reward ports never carry over between mice
        for idx, session in enumerate(natsorted(os.listdir(mpath))):
            if session in pre_switch_files:
                behav = pd.read_feather(pjoin(mpath, session))
                reward_one, reward_two = np.unique(behav['reward_one'])[0], np.unique(behav['reward_two'])[0]
            elif session in probe_files:
                if reward_one is None:
                    print(f'{mouse}: no preceding-context session before {session}; skipped')
                    continue
                behav = pd.read_feather(pjoin(mpath, session))
                behav_probe = behav[behav['probe']]
                for shift in np.arange(1, 9):
                    shifted_rewards = np.roll(reward_list, shift)
                    percent_correct = ctb.lick_accuracy(behav_probe, port_list=[shifted_rewards[reward_one-1], shifted_rewards[reward_two-1]], 
                                                        lick_threshold=lick_thresh, by_trials=False, to_percent=True)
                    shifted_probe_dict['mouse'].append(mouse)
                    shifted_probe_dict['experiment'].append(behav['cohort'].unique()[0])
                    shifted_probe_dict['sex'].append(sex)
                    shifted_probe_dict['group'].append(group)
                    shifted_probe_dict['day'].append(idx+1)
                    shifted_probe_dict['session'].append(np.unique(behav['session'])[0])
                    shifted_probe_dict['num_licks'].append(len(behav_probe[behav_probe['lick_port'] != -1]))
                    shifted_probe_dict['shift'].append(shift)
                    shifted_probe_dict['probe_acc'].append(percent_correct)
shifted_df = pd.DataFrame(shifted_probe_dict)
shifted_df = shifted_df[shifted_df['group'] == 'Multi-context'] ## only the Multi-context group changes context every 5 days
shifted_avg = shifted_df.groupby(['day', 'shift'], as_index=False).agg({'probe_acc': ['mean', 'sem']})
mouse_avg = shifted_df.groupby(['mouse', 'shift'], as_index=False).agg({'probe_acc': ['mean', 'sem']})

In [None]:
## Mean ± SEM probe accuracy for each port rotation, one trace per probe day
color_dict = {6: 'darkgrey', 11: 'midnightblue', 16: 'darkorchid'}
fig = pf.custom_graph_template(x_title='Rotation Number', y_title='Lick Accuracy (%)')
for day in shifted_avg['day'].unique():
    day_avg = shifted_avg.loc[shifted_avg['day'] == day]
    acc = day_avg['probe_acc']
    fig.add_trace(go.Scatter(x=day_avg['shift'], y=acc['mean'], name=f'Day {day}', mode='markers',
                             error_y=dict(type='data', array=acc['sem']), line_color=color_dict[day]))

## Dashed line at 25% (chance level)
fig.add_hline(y=25, line_width=1, opacity=1, line_color='black', line_dash='dash')
fig.update_yaxes(range=[0, 101])
fig.show()
fig.write_image(pjoin(fig_path, 'rotated_probe_accuracy_all_rotations.png'))

In [None]:
## Violin plot of probe accuracy for every rotation, grouped by probe day
fig = pf.custom_graph_template(x_title='Rotation Number', y_title='Lick Accuracy (%)', width=1300)
for day in shifted_df['day'].unique():
    day_rows = shifted_df[shifted_df['day'] == day]
    day_key = str(day)
    fig.add_trace(go.Violin(x=day_rows['shift'], y=day_rows['probe_acc'], name=f'Day {day}',
                            legendgroup=day_key, scalegroup=day_key, line_color=color_dict[day],
                            box_visible=True, meanline_visible=True))
## Show every point (jittered) and place the per-day violins side by side
fig.update_traces(points='all', jitter=0.4)
fig.update_layout(violinmode='group')
fig.add_hline(y=25, line_width=1, opacity=1, line_color='black', line_dash='dash')
fig.show()

In [None]:
## Accuracy for each rotation, with the three probe days collapsed into one value per mouse.
## ce_colors_dict is keyed by the set_group labels ('Two-context' / 'Multi-context');
## this rotation analysis concerns the Multi-context group.
fig = pf.custom_graph_template(x_title='Rotation Number', y_title='Lick Accuracy (%)', width=1000)
fig.add_trace(go.Violin(x=mouse_avg['shift'], y=mouse_avg['probe_acc']['mean'], box_visible=True, meanline_visible=True,
                        line_color=ce_colors_dict['Multi-context']))
fig.update_traces(points='all', jitter=0.4)
fig.add_hline(y=25, line_width=1, opacity=1, line_color='black', line_dash='dash')
fig.update_yaxes(range=[-20, 101])
fig.show()

### Rotated ports, restricted to the single rotation that maps the previous maze onto the new maze.

In [None]:
## Probe accuracy scored against one specific rotation: the previous context's reward ports mapped
## onto the new maze by ctb.rotate_ports (old maze -> new maze). Multi-context mice only.
## Sessions 5/10/15 supply the old reward ports and maze; sessions 6/11/16 are the scored probes.
## Sessions are natsorted so each old session is read before its probe and `day` = idx + 1
## matches the session number; os.listdir order is arbitrary.
data_of_interest = 'behav' ## one of behav, aligned_minian, lin_behav
lick_dict = {'mouse': [], 'experiment': [], 'session': [], 'sex': [], 'group': [], 'day': [], 'num_licks': [], 'probe_acc': []}

for experiment in os.listdir(dpath):
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
    for mouse in os.listdir(exp_path):
        if mouse in excluded_mice:
            continue
        group = set_group(mouse, control_mice)
        if group == 'Two-context':
            continue
        mpath = pjoin(exp_path, mouse)
        sex = 'Male' if mouse in male_mice else 'Female'
        pre_switch_files = [f'{mouse}_5.feat', f'{mouse}_10.feat', f'{mouse}_15.feat']
        probe_files = [f'{mouse}_6.feat', f'{mouse}_11.feat', f'{mouse}_16.feat']

        reward_one = reward_two = old_maze = None  # reset per mouse so nothing carries over between mice
        for idx, session in enumerate(natsorted(os.listdir(mpath))):
            if session in pre_switch_files:
                behav = pd.read_feather(pjoin(mpath, session))
                reward_one, reward_two = np.unique(behav['reward_one'])[0], np.unique(behav['reward_two'])[0]
                old_maze = behav['maze'].unique()[0]
            elif session in probe_files:
                if old_maze is None:
                    print(f'{mouse}: no preceding-context session before {session}; skipped')
                    continue
                behav = pd.read_feather(pjoin(mpath, session))
                behav_probe = behav[behav['probe']]

                ## Rotate the old reward ports onto the new maze
                rot_one, rot_two = ctb.rotate_ports(input_maze=old_maze, 
                                                    output_maze=np.unique(behav_probe['maze'])[0], 
                                                    reward_one=reward_one, 
                                                    reward_two=reward_two)
                percent_correct = ctb.lick_accuracy(behav_probe, port_list=[rot_one, rot_two], 
                                                    lick_threshold=5, by_trials=False, to_percent=True)
                lick_dict['mouse'].append(mouse)
                lick_dict['experiment'].append(behav['cohort'].unique()[0])
                lick_dict['sex'].append(sex)
                lick_dict['group'].append(group)
                lick_dict['day'].append(idx+1)
                lick_dict['session'].append(np.unique(behav['session'])[0])
                lick_dict['num_licks'].append(len(behav_probe[behav_probe['lick_port'] != -1]))
                lick_dict['probe_acc'].append(percent_correct)
rotation_df = pd.DataFrame(lick_dict)

In [None]:
## Check whether mice navigate using a rotation of the previous environment.
## Relabel probe days by the context transition they follow (idempotent: already-relabelled
## strings are left unchanged on re-run).
rotation_df['day'] = rotation_df['day'].replace({6: 'A to B', 11: 'B to C', 16: 'C to D'})
fig = pf.custom_graph_template(x_title='', y_title='Lick Accuracy (%)')
## rotation_df contains Multi-context mice only (Two-context mice are skipped when it is built)
fig.add_trace(go.Violin(x=rotation_df['day'], y=rotation_df['probe_acc'], box_visible=True, meanline_visible=True,
                        line_color=ce_colors_dict['Multi-context']))
fig.update_traces(points='all', jitter=0.4)
fig.add_hline(y=25, line_width=1, opacity=1, line_color='black', line_dash='dash')
fig.show()
fig.write_image(pjoin(fig_path, 'rotated_probe_oldmaze_newmaze.png'))

### Check where mice lick first in each session

In [None]:
## Circle track behavior: did the first lick of each session land on a reward port?
## All licks in the session are considered (no probe filtering). lick_bool is 1 for a reward
## port, 0 for any other port, and NaN when the session contains no licks.
data_of_interest = 'behav' ## one of behav, aligned_minian, lin_behav
first_dict = {'mouse': [], 'experiment': [], 'sex': [], 'group': [], 'day': [], 'session': [], 'lick_bool': []}
for experiment in os.listdir(dpath):
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
    for mouse in os.listdir(exp_path):
        if mouse in excluded_mice:
            continue
        mpath = pjoin(exp_path, mouse)
        sex = 'Male' if mouse in male_mice else 'Female'
        group = set_group(mouse, control_mice)
        for idx, session in enumerate(natsorted(os.listdir(mpath))):
            behav = pd.read_feather(pjoin(mpath, session))
            reward_one, reward_two = np.unique(behav['reward_one'])[0], np.unique(behav['reward_two'])[0]
            session_licks = behav.loc[behav['lick_port'] != -1, 'lick_port']
            if session_licks.empty:
                lick_bool = np.nan  # nothing to classify; avoids indexing an empty Series
            else:
                lick_bool = int(session_licks.iloc[0] in (reward_one, reward_two))
            first_dict['mouse'].append(mouse)
            first_dict['experiment'].append(behav['cohort'].unique()[0])
            first_dict['sex'].append(sex)
            first_dict['group'].append(group)
            first_dict['day'].append(idx+1)
            first_dict['session'].append(behav['session'].unique()[0])
            first_dict['lick_bool'].append(lick_bool)
first_lick_df = pd.DataFrame(first_dict)

In [None]:
## Per-session first-lick table: one row per mouse and session; lick_bool is 1 when that session's first lick was at reward_one or reward_two.
first_lick_df

### Hysteresis plot of rolling lick accuracy vs. rolling rewards on D1 (day 16), using each trial as a bin.

In [None]:
## Per-trial lick accuracy and rewards for D1 (session file `{mouse}_16.feat`, i.e. day 16),
## with a centered rolling average across trials and a binned version truncated at the median trial count.
wnd = 5 ## window size (trials) for the centered moving average
bin_size = 5 ## for number of trials within a bin
lick_thresh = 5 
data_of_interest = 'behav' 

trial_list = [] ## number of scored trials per mouse on D1
hyst_dict = {'mouse': [], 'sex': [], 'group': [], 'experiment': [], 'session': [], 'day': [], 'trial': [], 'lick_acc': [], 
             'rewards': [], 'rolling_lick_acc': [], 'rolling_rewards': []}
for experiment in os.listdir(dpath):
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
    for mouse in os.listdir(exp_path):
        ## NOTE(review): excluded_mice are not dropped here (unlike the first-lick cell) — confirm this is intended.
        mpath = pjoin(exp_path, mouse)
        sex = 'Male' if mouse in male_mice else 'Female'
        group = set_group(mouse, control_mice)
        for idx, session in enumerate(natsorted(os.listdir(mpath))):
            if session != f'{mouse}_16.feat':
                continue
            behav = pd.read_feather(pjoin(mpath, session))
            reward_one, reward_two = np.unique(behav['reward_one'])[0], np.unique(behav['reward_two'])[0]
            percent_correct = ctb.lick_accuracy(behav, port_list=[reward_one, reward_two], lick_threshold=lick_thresh, by_trials=True)
            num_rewards = ctb.rewards_across_trials(behav)
            ## Centered moving average. pandas averages only the trials that exist at the session edges;
            ## np.convolve(mode='same') zero-padded them, pulling the first/last wnd // 2 trials toward 0.
            rolling_lick_acc = pd.Series(percent_correct).rolling(wnd, center=True, min_periods=1).mean().to_numpy()
            rolling_rewards = pd.Series(num_rewards).rolling(wnd, center=True, min_periods=1).mean().to_numpy()
            trial_list.append(len(percent_correct))
            cohort = behav['cohort'].unique()[0]
            session_name = np.unique(behav['session'])[0]
            ## Assumes trial labels are 0-based positions into the per-trial arrays — TODO confirm against ctb.
            for trial in behav['trials'].unique():
                trial_idx = int(trial)
                hyst_dict['mouse'].append(mouse)
                hyst_dict['sex'].append(sex)
                hyst_dict['group'].append(group)
                hyst_dict['experiment'].append(cohort)
                hyst_dict['day'].append(idx+1)
                hyst_dict['session'].append(session_name)
                hyst_dict['trial'].append(trial+1) ## to move from zero indexing
                hyst_dict['lick_acc'].append(percent_correct[trial_idx])
                hyst_dict['rewards'].append(num_rewards[trial_idx])
                hyst_dict['rolling_lick_acc'].append(rolling_lick_acc[trial_idx])
                hyst_dict['rolling_rewards'].append(rolling_rewards[trial_idx])
hyst_df = pd.DataFrame(hyst_dict)

## Get the median number of trials to plot lick accuracy and rewards only up to that value
median_trial = np.median(trial_list)
med_df = hyst_df[hyst_df['trial'] <= median_trial]
bin_dict = {'mouse': [], 'sex': [], 'group': [], 'trial': [], 'lick_acc': [], 'rewards': []}
for mouse in med_df['mouse'].unique():
    mdata = med_df[med_df['mouse'] == mouse].reset_index(drop=True)
    accuracy = ctb.bin_data(mdata['lick_acc'].to_numpy(), bin_size)
    rewards = ctb.bin_data(mdata['rewards'].to_numpy(), bin_size)
    for idx, val in enumerate(accuracy):
        bin_dict['mouse'].append(mouse)
        bin_dict['sex'].append(mdata['sex'].unique()[0])
        bin_dict['group'].append(mdata['group'].unique()[0])
        bin_dict['trial'].append(idx * bin_size) ## x position of each bin: 0, bin_size, 2 * bin_size, ...
        bin_dict['lick_acc'].append(val)
        bin_dict['rewards'].append(rewards[idx])
bin_df = pd.DataFrame(bin_dict)
avg_df = bin_df.groupby(['group', 'trial'], as_index=False).agg({'lick_acc': ['mean', 'sem'], 'rewards': ['mean', 'sem']})

In [None]:
## Hysteresis plot: group-mean rolling rewards (x) against group-mean rolling lick accuracy (y), one line per group.
## `avg` holds mean and SEM of each per-trial measure by group and trial; the trial plots below reuse it.
xval = 'rolling_rewards' 
yval = 'rolling_lick_acc'
summary_cols = ['lick_acc', 'rewards', 'rolling_lick_acc', 'rolling_rewards']
avg = hyst_df.groupby(['group', 'trial'], as_index=False).agg({col: ['mean', 'sem'] for col in summary_cols})

fig = pf.custom_graph_template(x_title='Rewards', y_title='Lick Accuracy', width=600)
for group in ['Multi-context', 'Two-context']:
    group_avg = avg.loc[avg['group'] == group]
    trace = go.Scattergl(x=group_avg[xval]['mean'], y=group_avg[yval]['mean'], mode='lines',
                         name=group, line_color=ce_colors_dict[group])
    fig.add_trace(trace)
fig.show()

In [None]:
## Rolling lick accuracy across trials for D1: group mean (line) ± SEM (shaded band).
yvar = 'rolling_lick_acc'
fig = pf.custom_graph_template(x_title='Trial', y_title='Lick Accuracy (%)', width=600)
for group in ['Multi-context', 'Two-context']:
    gdata = avg[avg['group'] == group]
    ## Use the trial numbers actually present for this group so x and y always have the same length;
    ## np.arange(1, max + 1) silently assumed every trial number from 1 to max occurs.
    x_data = gdata['trial'].to_numpy()
    upper = gdata[yvar]['mean'] + gdata[yvar]['sem']
    lower = gdata[yvar]['mean'] - gdata[yvar]['sem']
    fig.add_trace(go.Scatter(x=x_data, y=gdata[yvar]['mean'], mode='lines', 
                             name=group, line_color=ce_colors_dict[group]))
    ## Invisible upper edge, then the lower edge filled up to it ('tonexty') to draw the SEM band.
    fig.add_trace(go.Scatter(x=x_data, y=upper, mode='lines', marker=dict(color=error_color[group]),
                             name='Upper Bound', line=dict(width=0), showlegend=False))
    fig.add_trace(go.Scatter(x=x_data, y=lower, mode='lines', marker=dict(color=error_color[group]),
                             name='Lower Bound', line=dict(width=0), showlegend=False, fillcolor=error_color[group], fill='tonexty'))
fig.show()
fig.write_image(pjoin(fig_path, 'lick_acc_trials_D1.png'), width=600, height=500)

In [None]:
## Rolling rewards across all trials for D1: group mean (line) ± SEM (shaded band).
yvar = 'rolling_rewards'
fig = pf.custom_graph_template(x_title='Trial', y_title='Rewards', width=600)
for group in ['Multi-context', 'Two-context']:
    gdata = avg[avg['group'] == group]
    ## Use the trial numbers actually present for this group so x and y always have the same length;
    ## np.arange(1, max + 1) silently assumed every trial number from 1 to max occurs.
    x_data = gdata['trial'].to_numpy()
    upper = gdata[yvar]['mean'] + gdata[yvar]['sem']
    lower = gdata[yvar]['mean'] - gdata[yvar]['sem']
    fig.add_trace(go.Scatter(x=x_data, y=gdata[yvar]['mean'], mode='lines', 
                             name=group, line_color=ce_colors_dict[group]))
    ## Invisible upper edge, then the lower edge filled up to it ('tonexty') to draw the SEM band.
    fig.add_trace(go.Scatter(x=x_data, y=upper, mode='lines', marker=dict(color=error_color[group]),
                             name='Upper Bound', line=dict(width=0), showlegend=False))
    fig.add_trace(go.Scatter(x=x_data, y=lower, mode='lines', marker=dict(color=error_color[group]),
                             name='Lower Bound', line=dict(width=0), showlegend=False, fillcolor=error_color[group], fill='tonexty'))
fig.show()
fig.write_image(pjoin(fig_path, 'rewards_across_trials_D1.png'), width=600, height=500)

In [None]:
## Binned lick accuracy for D1, truncated at the median trial count: group mean ± SEM with a 25% reference line.
## Pass one x value per bin, matching how x_data is given to the same helper for the time-binned
## correct-direction plot. The full avg_df['trial'] column repeats the bins once per group, which is twice
## as long as each group's y values (assuming the helper reuses x_data for every group — confirm in pf).
fig = pf.plot_groups_shaded_error(df=avg_df, yvar='lick_acc', x_data=avg_df['trial'].unique(), group_list=['Multi-context', 'Two-context'],
                                  colors_dict=ce_colors_dict, error_color=error_color, x_title='Trial', y_title='Lick Accuracy (%)', width=600)
fig.add_hline(y=25, line_color=chance_color, opacity=1, line_width=1, line_dash='dash')
fig.show()
fig.write_image(pjoin(fig_path, 'median_trials_lick_acc_D1.png'), width=600, height=500)

In [None]:
## Binned rewards for D1, truncated at the median trial count: group mean ± SEM.
## Pass one x value per bin (the full avg_df['trial'] column repeats the bins once per group;
## assumes pf.plot_groups_shaded_error reuses x_data for every group — confirm in pf).
fig = pf.plot_groups_shaded_error(df=avg_df, yvar='rewards', x_data=avg_df['trial'].unique(), group_list=['Multi-context', 'Two-context'],
                                  colors_dict=ce_colors_dict, error_color=error_color, x_title='Trial', y_title='Rewards', width=600)
fig.show()
## Export at the same size as the other D1 figures (this one previously used the default export size).
fig.write_image(pjoin(fig_path, 'median_trials_rewards_D1.png'), width=600, height=500)

### Hysteresis plot of lick accuracy vs. rewards on D1 (day 16), using 30- or 60-s time bins.

In [None]:
## Time-binned lick accuracy and rewards for D1 (session file `{mouse}_16.feat`, i.e. day 16),
## plus per-mouse rolling and cumulative versions, then group mean/SEM per bin in `avg`.
wnd = 2 ## window size (bins) for the moving average
lick_thresh = 5
bin_in_seconds = 30
data_of_interest = 'behav'
time_bins = np.arange(0, 900 + bin_in_seconds, bin_in_seconds) ## bin edges covering 0-900 s

hyst_dict = {'mouse': [], 'sex': [], 'group': [], 'experiment': [], 'session': [], 'day': [], 'bin': [], 'lick_acc': [], 'rewards': []}
for experiment in os.listdir(dpath):
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
    for mouse in os.listdir(exp_path):
        ## NOTE(review): excluded_mice are not dropped here (unlike the first-lick cell) — confirm this is intended.
        mpath = pjoin(exp_path, mouse)
        sex = 'Male' if mouse in male_mice else 'Female'
        group = set_group(mouse, control_mice)
        for idx, session in enumerate(natsorted(os.listdir(mpath))):
            if session != f'{mouse}_16.feat':
                continue
            behav = pd.read_feather(pjoin(mpath, session))
            reward_one, reward_two = np.unique(behav['reward_one'])[0], np.unique(behav['reward_two'])[0]
            cohort = behav['cohort'].unique()[0]
            session_name = np.unique(behav['session'])[0]
            for tbin in time_bins[1:]:
                ## Half-open bin [tbin - bin_in_seconds, tbin). The previous strict lower bound dropped
                ## samples lying exactly on a bin edge (e.g. t == 0).
                in_bin = (behav['t'] >= tbin - bin_in_seconds) & (behav['t'] < tbin)
                sub = behav[in_bin].reset_index(drop=True)
                percent_correct = ctb.lick_accuracy(sub, port_list=[reward_one, reward_two], lick_threshold=lick_thresh, by_trials=False)
                num_rewards = np.sum(sub['water'])

                hyst_dict['mouse'].append(mouse)
                hyst_dict['sex'].append(sex)
                hyst_dict['group'].append(group)
                hyst_dict['experiment'].append(cohort)
                hyst_dict['day'].append(idx+1)
                hyst_dict['session'].append(session_name)
                hyst_dict['bin'].append(tbin) ## right edge of the bin
                hyst_dict['lick_acc'].append(percent_correct)
                hyst_dict['rewards'].append(num_rewards)
df = pd.DataFrame(hyst_dict)

## Per-mouse smoothing and running totals across bins. rolling(center=True) uses the same window alignment
## as np.convolve(mode='same') (bins i-1 and i for wnd=2), but averages only existing bins instead of
## zero-padding, which had halved the first bin's rolling value.
per_mouse = []
for mouse in df['mouse'].unique():
    mdata = df[df['mouse'] == mouse].reset_index(drop=True)
    mdata['rolling_lick_acc'] = mdata['lick_acc'].rolling(wnd, center=True, min_periods=1).mean()
    mdata['rolling_rewards'] = mdata['rewards'].rolling(wnd, center=True, min_periods=1).mean()
    mdata['cumulative_lick_acc'] = np.cumsum(mdata['lick_acc'])
    mdata['cumulative_rewards'] = np.cumsum(mdata['rewards'])
    per_mouse.append(mdata)
## Concatenate once rather than growing a DataFrame inside the loop.
hyst_time_df = pd.concat(per_mouse, ignore_index=True) if per_mouse else pd.DataFrame()

avg = hyst_time_df.groupby(['group', 'bin'], as_index=False).agg({'lick_acc': ['mean', 'sem'], 'rewards': ['mean', 'sem'],
                                                                 'rolling_lick_acc': ['mean', 'sem'], 'rolling_rewards': ['mean', 'sem'],
                                                                 'cumulative_lick_acc': ['mean', 'sem'], 'cumulative_rewards': ['mean', 'sem']})

In [None]:
## Rolling lick accuracy across the D1 session in time bins: group mean (line) ± SEM (shaded band).
yvar = 'rolling_lick_acc'
bin_ends = time_bins[1:] ## right edge of each time bin
fig = pf.custom_graph_template(x_title='Time (s)', y_title='Lick Accuracy (%)', width=600)
for group in ['Multi-context', 'Two-context']:
    grp_stats = avg[avg['group'] == group][yvar]
    band = error_color[group]
    fig.add_trace(go.Scatter(x=bin_ends, y=grp_stats['mean'], mode='lines',
                             name=group, line_color=ce_colors_dict[group]))
    ## Invisible upper edge, then the lower edge filled up to it ('tonexty') to draw the SEM band.
    fig.add_trace(go.Scatter(x=bin_ends, y=grp_stats['mean'] + grp_stats['sem'], mode='lines', marker=dict(color=band),
                             name='Upper Bound', line=dict(width=0), showlegend=False))
    fig.add_trace(go.Scatter(x=bin_ends, y=grp_stats['mean'] - grp_stats['sem'], mode='lines', marker=dict(color=band),
                             name='Lower Bound', line=dict(width=0), showlegend=False, fillcolor=band, fill='tonexty'))
## Dashed reference line at 25% (presumably chance level — confirm).
fig.add_hline(y=25, line_dash='dash', line_color=chance_color, line_width=1, opacity=1)
fig.show()
fig.write_image(pjoin(fig_path, 'lick_acc_across_time_D1.png'), width=600, height=500)

In [None]:
## Rolling rewards across the D1 session in time bins: group mean (line) ± SEM (shaded band).
yvar = 'rolling_rewards'
bin_ends = time_bins[1:] ## right edge of each time bin
fig = pf.custom_graph_template(x_title='Time (s)', y_title='Rewards', width=600)
for group in ['Multi-context', 'Two-context']:
    grp_stats = avg[avg['group'] == group][yvar]
    band = error_color[group]
    mean_line = grp_stats['mean']
    fig.add_trace(go.Scatter(x=bin_ends, y=mean_line, mode='lines',
                             name=group, line_color=ce_colors_dict[group]))
    ## Invisible upper edge, then the lower edge filled up to it ('tonexty') to draw the SEM band.
    fig.add_trace(go.Scatter(x=bin_ends, y=mean_line + grp_stats['sem'], mode='lines', marker=dict(color=band),
                             name='Upper Bound', line=dict(width=0), showlegend=False))
    fig.add_trace(go.Scatter(x=bin_ends, y=mean_line - grp_stats['sem'], mode='lines', marker=dict(color=band),
                             name='Lower Bound', line=dict(width=0), showlegend=False, fillcolor=band, fill='tonexty'))
fig.show()
fig.write_image(pjoin(fig_path, 'rewards_across_time_D1.png'), width=600, height=500)

In [None]:
## Hysteresis plot for the time-binned D1 data: group-mean rolling rewards (x) against rolling lick accuracy (y).
xval, yval = 'rolling_rewards', 'rolling_lick_acc'
fig = pf.custom_graph_template(x_title='Rewards', y_title='Lick Accuracy', width=600)
for group in ['Multi-context', 'Two-context']:
    group_means = avg[avg['group'] == group]
    fig.add_trace(go.Scattergl(x=group_means[xval]['mean'], y=group_means[yval]['mean'], mode='lines',
                               name=group, line_color=ce_colors_dict[group]))
## Dashed reference line at 25% (presumably chance level — confirm).
fig.add_hline(y=25, line_dash='dash', line_color=chance_color, line_width=1, opacity=1)
fig.show()
fig.write_image(pjoin(fig_path, 'D1_hysteresis_rewards_lick_acc.png'), width=600, height=500)

In [None]:
## Cumulative (running-sum) lick accuracy across the D1 session in time bins: group mean ± SEM band.
yvar = 'cumulative_lick_acc'
bin_ends = time_bins[1:] ## right edge of each time bin
fig = pf.custom_graph_template(x_title='Time (s)', y_title='Cumulative Lick Accuracy (a.u.)', width=600)
for group in ['Multi-context', 'Two-context']:
    grp_stats = avg[avg['group'] == group][yvar]
    mean_line, sem = grp_stats['mean'], grp_stats['sem']
    band = error_color[group]
    fig.add_trace(go.Scatter(x=bin_ends, y=mean_line, mode='lines',
                             name=group, line_color=ce_colors_dict[group]))
    ## Invisible upper edge, then the lower edge filled up to it ('tonexty') to draw the SEM band.
    fig.add_trace(go.Scatter(x=bin_ends, y=mean_line + sem, mode='lines', marker=dict(color=band),
                             name='Upper Bound', line=dict(width=0), showlegend=False))
    fig.add_trace(go.Scatter(x=bin_ends, y=mean_line - sem, mode='lines', marker=dict(color=band),
                             name='Lower Bound', line=dict(width=0), showlegend=False, fillcolor=band, fill='tonexty'))
fig.show()
fig.write_image(pjoin(fig_path, 'cumulative_lick_acc_D1.png'), width=600, height=500)

In [None]:
## Cumulative rewards across the D1 session in time bins: group mean ± SEM band.
yvar = 'cumulative_rewards'
bin_ends = time_bins[1:] ## right edge of each time bin
fig = pf.custom_graph_template(x_title='Time (s)', y_title='Cumulative Rewards', width=600)
for group in ['Multi-context', 'Two-context']:
    grp_stats = avg[avg['group'] == group][yvar]
    mean_line, sem = grp_stats['mean'], grp_stats['sem']
    band = error_color[group]
    fig.add_trace(go.Scatter(x=bin_ends, y=mean_line, mode='lines',
                             name=group, line_color=ce_colors_dict[group]))
    ## Invisible upper edge, then the lower edge filled up to it ('tonexty') to draw the SEM band.
    fig.add_trace(go.Scatter(x=bin_ends, y=mean_line + sem, mode='lines', marker=dict(color=band),
                             name='Upper Bound', line=dict(width=0), showlegend=False))
    fig.add_trace(go.Scatter(x=bin_ends, y=mean_line - sem, mode='lines', marker=dict(color=band),
                             name='Lower Bound', line=dict(width=0), showlegend=False, fillcolor=band, fill='tonexty'))
fig.show()
fig.write_image(pjoin(fig_path, 'cumulative_rewards_D1.png'), width=600, height=500)

### Histograms of the number of trials per session, one panel per day.

In [None]:
## Per-session trial counts for every mouse. data_of_interest comes from an earlier cell.
trial_dict = {'mouse': [], 'sex': [], 'group': [], 'experiment': [], 'day': [], 'session': [], 'num_trials': []}
for experiment in os.listdir(dpath):
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
    for mouse in os.listdir(exp_path):
        mpath = pjoin(exp_path, mouse)
        sex = 'Male' if mouse in male_mice else 'Female'
        group = set_group(mouse, control_mice)
        for idx, session in enumerate(natsorted(os.listdir(mpath))):
            behav = pd.read_feather(pjoin(mpath, session))
            ## NOTE(review): num_trials is the last trial label, not a count. The D1 per-trial cell treats
            ## trial labels as 0-based (it adds 1), so this may be one less than the number of trials — confirm.
            record = {'mouse': mouse, 'sex': sex, 'group': group,
                      'experiment': behav['cohort'].unique()[0], 'day': idx + 1,
                      'session': behav['session'].unique()[0], 'num_trials': behav['trials'].unique()[-1]}
            for key, value in record.items():
                trial_dict[key].append(value)
num_trials_df = pd.DataFrame(trial_dict)

In [None]:
## Histogram of trials per session for each group, one panel per day on a 4 x 5 grid, with dashed group means.
## day_names (panel titles) is defined elsewhere in the notebook.
fig = pf.custom_graph_template(x_title='', y_title='', height=1000, width=1000, shared_x=True, shared_y=True,
                               rows=4, columns=5, titles=day_names)

for idx, day in enumerate(np.arange(1, 21)):
    ## Fill the grid row by row: idx 0-4 -> row 1, 5-9 -> row 2, 10-14 -> row 3, 15-19 -> row 4.
    row_offset, col_offset = divmod(idx, 5)
    row, col = row_offset + 1, col_offset + 1

    for group in ['Multi-context', 'Two-context']:
        day_mask = (num_trials_df['group'] == group) & (num_trials_df['day'] == day)
        gdata = num_trials_df[day_mask].reset_index(drop=True)
        group_color = ce_colors_dict[group]
        fig.add_trace(go.Histogram(x=gdata['num_trials'], marker_color=group_color, showlegend=False,
                                   marker_line_width=1), row=row, col=col)
        fig.add_vline(x=np.mean(gdata['num_trials']), line_dash='dash', line_color=group_color,
                      line_width=1, opacity=1, row=row, col=col)
fig.update_yaxes(title='Count', col=1)
fig.update_xaxes(title='Trials', row=4)
fig.show()
fig.write_image(pjoin(fig_path, 'num_trials_both_groups_histograms.png'), width=1000, height=1000)

### Percentage of time spent running in the correct direction in each session.

In [None]:
## Per-session percentage of behavior samples flagged as running in the correct direction.
## data_of_interest comes from an earlier cell.
correct_dict = {'mouse': [], 'sex': [], 'group': [], 'experiment': [], 'day': [], 'session': [], 'correct_dir': []}
for experiment in os.listdir(dpath):
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
    for mouse in os.listdir(exp_path):
        mpath = pjoin(exp_path, mouse)
        sex = 'Male' if mouse in male_mice else 'Female'
        group = set_group(mouse, control_mice)
        for idx, session in enumerate(natsorted(os.listdir(mpath))):
            behav = pd.read_feather(pjoin(mpath, session))
            record = {'mouse': mouse, 'sex': sex, 'group': group,
                      'experiment': behav['cohort'].unique()[0], 'day': idx + 1,
                      'session': behav['session'].unique()[0],
                      'correct_dir': (np.sum(behav['correct_dir']) / behav.shape[0]) * 100}
            for key, value in record.items():
                correct_dict[key].append(value)
correct_dir_df = pd.DataFrame(correct_dict)

In [None]:
## Correct-direction percentage across all days (group averages, transition markers between the 5-day blocks),
## followed by a one-way OLS/ANOVA comparing groups on day 16.
fig = pf.plot_behavior_across_days(correct_dir_df, x_var='day', y_var='correct_dir', groupby_var=['day', 'group'],
                                   marker_color=ce_colors, avg_color=avg_color, expert_line=False, chance=False,
                                   plot_transitions=[5.5, 10.5, 15.5], transition_color=['darkgrey'] * 3,
                                   symbols=symbols_list, plot_datapoints=False, x_title='Day',
                                   y_title='Correct Direction (%)', titles=[''], height=500, width=500)
fig.update_yaxes(range=[0, 100])
fig.show()
fig.write_image(pjoin(fig_path, 'correct_dir_all_days.png'), width=500, height=500)

## Linear model with group as a categorical factor, restricted to day 16.
data = correct_dir_df.loc[correct_dir_df['day'] == 16].reset_index(drop=True)
model = ols('correct_dir ~ C(group)', data=data).fit()
anova_table = sm.stats.anova_lm(model, typ=3)
print(anova_table)

In [None]:
## Bin each session by time and compute the percentage of samples spent running in the correct direction.
bin_in_seconds = 30
data_of_interest = 'behav'
time_bins = np.arange(0, 900 + bin_in_seconds, bin_in_seconds) ## bin edges covering 0-900 s

correct_dict = {'mouse': [], 'sex': [], 'group': [], 'experiment': [], 'day': [], 'session': [], 'tbin': [], 'correct_dir': []}
for experiment in os.listdir(dpath):
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
    for mouse in os.listdir(exp_path):
        mpath = pjoin(exp_path, mouse)
        sex = 'Male' if mouse in male_mice else 'Female'
        group = set_group(mouse, control_mice)
        for idx, session in enumerate(natsorted(os.listdir(mpath))):
            behav = pd.read_feather(pjoin(mpath, session))
            cohort = behav['cohort'].unique()[0]
            session_name = behav['session'].unique()[0]
            for tbin in time_bins[1:]:
                ## Half-open bin [tbin - bin_in_seconds, tbin). The previous strict lower bound dropped
                ## samples lying exactly on a bin edge (e.g. t == 0).
                in_bin = (behav['t'] >= tbin - bin_in_seconds) & (behav['t'] < tbin)
                sub = behav[in_bin].reset_index(drop=True)
                ## An empty bin (e.g. if the session ends before the last bin) has no defined percentage:
                ## record NaN instead of dividing by zero.
                n_samples = sub.shape[0]
                pct_correct = (np.sum(sub['correct_dir']) / n_samples) * 100 if n_samples > 0 else np.nan
                correct_dict['mouse'].append(mouse)
                correct_dict['sex'].append(sex)
                correct_dict['group'].append(group)
                correct_dict['experiment'].append(cohort)
                correct_dict['day'].append(idx+1)
                correct_dict['session'].append(session_name)
                correct_dict['tbin'].append(tbin) ## right edge of the bin
                correct_dict['correct_dir'].append(pct_correct)
## NOTE: this replaces the per-session correct_dir_df from the earlier cell with a time-binned version.
correct_dir_df = pd.DataFrame(correct_dict)
avg_correct = correct_dir_df.groupby(['day', 'tbin', 'group'], as_index=False).agg({'correct_dir': ['mean', 'sem']})

In [None]:
## D1 (day 16): percentage of samples in the correct direction per time bin, group mean ± SEM.
sub_df = avg_correct[avg_correct['day'] == 16].reset_index(drop=True)
bin_ends = sub_df['tbin'].unique() ## one x value per time bin (right edge)
fig = pf.plot_groups_shaded_error(df=sub_df, yvar='correct_dir', x_data=bin_ends,
                                  group_list=['Multi-context', 'Two-context'], colors_dict=ce_colors_dict,
                                  error_color=error_color, x_title='Time (s)', y_title='Correct Direction (%)', width=600)
fig.update_yaxes(range=[0, 100])
fig.show()
fig.write_image(pjoin(fig_path, 'correct_dir_by_time_D1.png'), width=600, height=500)