In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys
import numpy as np
import pandas as pd
import xarray as xr
# import pingouin as pg
import plotly.graph_objects as go
from os.path import join as pjoin
from natsort import natsorted

sys.path.append("../../")
import circletrack_behavior as ctb
import plotting_functions as pf

In [None]:
## Settings
# Data and figure locations for this experiment (relative to the notebook's working directory).
parent_dir = 'CircleTrack_Recall'
experiment_dir = 'Recall2'
lin_path = f'../../../{parent_dir}/{experiment_dir}/output/lin_behav/'
circle_path = f'../../../{parent_dir}/{experiment_dir}/output/behav/'
fig_path = f'../../../{parent_dir}/{experiment_dir}/intermediate_figures'

# Plot colors.
chance_color = 'darkgrey'
avg_color = 'midnightblue'
subject_color = 'darkgrey'
two_group_colors = ['midnightblue', 'darkgrey']
four_group_colors = ['darkorchid', 'midnightblue', 'violet', 'blue']

# Subject metadata (np.arange's stop is exclusive):
#   young mice: mcr17-mcr24; male mice: mcr21-mcr24 and mcr31-mcr35.
young_mice = [f'mcr{x}' for x in np.arange(17, 25)]
male_mice = [f'mcr{x}' for x in np.arange(21, 25)] + [f'mcr{x}' for x in np.arange(31, 36)]
# Mice left out of the "excluding non-learners" lick-accuracy plots.
excluded_mice = ['mcr24', 'mcr29']
# Color per group label; keys match the 'group' / 'group_two' labels built below.
group_colors_dict = {'Middle Aged Female': 'darkorchid', 'Middle Aged Male': 'midnightblue', 'Young Male': 'blue', 'Young Female': 'violet',
                     'Young Adult': 'darkgrey', 'Middle Aged': 'midnightblue'}

# exist_ok=True makes this a no-op when the directory already exists and avoids
# the race between a separate exists() check and the create.
os.makedirs(fig_path, exist_ok=True)

### Linear track rewards

In [None]:
# One row per (mouse, linear-track session):
#   day         -- 1-based position of the session file in naturally sorted order
#   rewards     -- sum of the session's 'water' column
#   group       -- age group; group_two -- age x sex; group_three -- the session's 'maze' value
# os.listdir() returns entries in arbitrary order, so both listings are sorted before
# enumerating; otherwise 'day' would depend on the filesystem, not on the session.
# NOTE(review): assumes session filenames sort chronologically under natsort -- TODO confirm.
lin_dict = {'mouse': [], 'day': [], 'rewards': [], 'group': [], 'group_two': [], 'group_three': []}
for mouse in natsorted(os.listdir(lin_path)):
    mouse_path = pjoin(lin_path, mouse)
    age = 'Young' if mouse in young_mice else 'Middle Aged'
    sex = 'Male' if mouse in male_mice else 'Female'
    group = 'Young Adult' if mouse in young_mice else 'Middle Aged'
    group_two = f'{age} {sex}'  # 'Young Male', 'Young Female', 'Middle Aged Male', 'Middle Aged Female'
    for idx, session in enumerate(natsorted(os.listdir(mouse_path))):
        lin_behav = pd.read_feather(pjoin(mouse_path, session))
        lin_dict['mouse'].append(mouse)
        lin_dict['day'].append(idx+1)
        lin_dict['rewards'].append(np.sum(lin_behav['water']))
        lin_dict['group'].append(group)
        lin_dict['group_two'].append(group_two)
        lin_dict['group_three'].append(lin_behav['maze'].unique()[0])
lin_df = pd.DataFrame(lin_dict)

In [None]:
## Linear track: rewards per day across all mice
fig = pf.plot_behavior_across_days(
    lin_df,
    # what to plot
    x_var='day',
    y_var='rewards',
    groupby_var=['day'],
    # overlays
    plot_transitions=None,
    expert_line=False,
    chance=False,
    # colors
    marker_color=subject_color,
    avg_color=avg_color,
    # labels and size
    x_title='Day',
    y_title='Rewards',
    titles=['Linear Track'],
    height=500,
    width=500,
)
fig.update_yaxes(range=[0, 150])
fig.show()
fig.write_image(pjoin(fig_path, 'linear_track_rewards.png'))

In [None]:
## Linear track: rewards per day split by age x sex ('group_two'), group averages only
fig = pf.plot_behavior_across_days(
    lin_df,
    # what to plot
    x_var='day',
    y_var='rewards',
    groupby_var=['day', 'group_two'],
    plot_datapoints=False,
    # overlays
    plot_transitions=None,
    expert_line=False,
    chance=False,
    # colors
    marker_color=four_group_colors,
    avg_color=avg_color,
    # labels and size
    x_title='Day',
    y_title='Rewards',
    titles=['Linear Track'],
    height=500,
    width=700,
)
fig.update_yaxes(range=[0, 150])
fig.show()
fig.write_image(pjoin(fig_path, 'linear_track_rewards_malefemale_age.png'))

In [None]:
## Linear track: rewards per day split by maze ('group_three' holds each session's maze)
fig = pf.plot_behavior_across_days(
    lin_df,
    # what to plot
    x_var='day',
    y_var='rewards',
    groupby_var=['day', 'group_three'],
    plot_datapoints=True,
    # overlays
    plot_transitions=None,
    expert_line=False,
    chance=False,
    # colors
    marker_color=two_group_colors,
    avg_color=avg_color,
    # labels and size
    x_title='Day',
    y_title='Rewards',
    titles=['Linear Track'],
    height=500,
    width=600,
)
fig.update_yaxes(range=[0, 150])
fig.show()
fig.write_image(pjoin(fig_path, 'linear_track_rewards_mazes.png'))

### Circle track accuracy and rewards

In [None]:
# One row per (mouse, circle-track session), computed on the non-probe rows only:
#   day                   -- 1-based position of the session file in naturally sorted order
#   session               -- the value of the session's 'session' column
#   lick_accuracy_thresh5 -- ctb.lick_accuracy on the two reward ports with lick_threshold=5
#   rewards               -- sum of the 'water' column
# os.listdir() returns entries in arbitrary order, so both listings are sorted before
# enumerating; otherwise 'day' would depend on the filesystem, not on the session.
# NOTE(review): assumes session filenames sort chronologically under natsort -- TODO confirm.
circletrack_results = {'mouse': [], 'day': [], 'group': [], 'group_two': [], 'session': [], 'lick_accuracy_thresh5': [], 'rewards': []}
for mouse in natsorted(os.listdir(circle_path)):
    mouse_path = pjoin(circle_path, mouse)
    age = 'Young' if mouse in young_mice else 'Middle Aged'
    sex = 'Male' if mouse in male_mice else 'Female'
    group = 'Young Adult' if mouse in young_mice else 'Middle Aged'
    group_two = f'{age} {sex}'  # 'Young Male', 'Young Female', 'Middle Aged Male', 'Middle Aged Female'
    for idx, session in enumerate(natsorted(os.listdir(mouse_path))):
        behav = pd.read_feather(pjoin(mouse_path, session))
        behav = behav[~behav['probe']]  # probe rows are analyzed separately in the probe section
        reward_one, reward_two = np.unique(behav['reward_one'])[0], np.unique(behav['reward_two'])[0]
        pc_thresh5 = ctb.lick_accuracy(behav, port_list=[reward_one, reward_two], lick_threshold=5, by_trials=False)
        circletrack_results['mouse'].append(mouse)
        circletrack_results['day'].append(idx+1)
        circletrack_results['group'].append(group)
        circletrack_results['group_two'].append(group_two)
        circletrack_results['session'].append(np.unique(behav['session'])[0])
        circletrack_results['lick_accuracy_thresh5'].append(pc_thresh5)
        circletrack_results['rewards'].append(np.sum(behav['water']))
ct_df = pd.DataFrame(circletrack_results)

In [None]:
## Circle track: 5th-lick accuracy on days 1-5 (day < 6), all mice pooled
fig = pf.plot_behavior_across_days(
    ct_df[ct_df['day'].lt(6)],
    # what to plot
    x_var='day',
    y_var='lick_accuracy_thresh5',
    groupby_var=['day'],
    # overlays
    plot_transitions=None,
    expert_line=False,
    chance=True,
    # colors
    marker_color=subject_color,
    avg_color=avg_color,
    # labels and size
    x_title='Day',
    y_title='5th Lick Accuracy',
    titles=['Circle Track'],
    height=500,
    width=500,
)
fig.update_yaxes(range=[0, 100])
fig.show()
fig.write_image(pjoin(fig_path, 'lick_acc_across_days_ct.png'))

In [None]:
## Circle track: 5th-lick accuracy on all days split by age x sex ('group_two'),
## with a transition marker at x=5.5 (between days 5 and 6)
fig = pf.plot_behavior_across_days(
    ct_df,
    # what to plot
    x_var='day',
    y_var='lick_accuracy_thresh5',
    groupby_var=['day', 'group_two'],
    # overlays
    plot_transitions=[5.5],
    transition_color=['darkgrey'],
    expert_line=False,
    chance=True,
    # colors
    marker_color=four_group_colors,
    avg_color=avg_color,
    # labels and size
    x_title='Day',
    y_title='5th Lick Accuracy',
    titles=['Circle Track'],
    height=500,
    width=600,
)
fig.update_yaxes(range=[0, 100])
fig.show()
fig.write_image(pjoin(fig_path, 'lick_acc_across_days_groups_ct.png'))

In [None]:
## Plot 5th lick accuracy across days by age x sex group, excluding the non-learners
## listed in `excluded_mice` (configured in the settings cell rather than hardcoded here)
fig = pf.plot_behavior_across_days(ct_df[~ct_df['mouse'].isin(excluded_mice)], x_var='day', y_var='lick_accuracy_thresh5', 
                                   groupby_var=['day', 'group_two'], plot_transitions=[5.5], transition_color=['darkgrey'],
                                   marker_color=four_group_colors, avg_color=avg_color, expert_line=False, chance=True,
                                   x_title='Day', y_title='5th Lick Accuracy', titles=['Circle Track'], height=500, width=600)
fig.update_yaxes(range=[0, 100])
fig.show()

In [None]:
## Plot 5th lick accuracy across days for young vs middle-aged mice, excluding the
## non-learners listed in `excluded_mice` (configured in the settings cell rather than hardcoded here)
fig = pf.plot_behavior_across_days(ct_df[~ct_df['mouse'].isin(excluded_mice)], x_var='day', y_var='lick_accuracy_thresh5', 
                                   groupby_var=['day', 'group'], plot_transitions=[5.5], transition_color=['darkgrey'],
                                   marker_color=two_group_colors, avg_color=avg_color, expert_line=False, chance=True,
                                   x_title='Day', y_title='Lick Accuracy (%)', titles=['Circle Track'], height=500, width=600)
fig.update_yaxes(range=[0, 100])
fig.show()
fig.write_image(pjoin(fig_path, 'lick_acc_across_days_y_ma.png'))

### Circle track rewards

In [None]:
## Circle track: rewards on days 1-5 (day < 6), all mice pooled
# NOTE(review): chance=True is set on a rewards axis, while the linear-track reward plots
# use chance=False -- confirm pf draws a meaningful chance line here.
fig = pf.plot_behavior_across_days(
    ct_df[ct_df['day'].lt(6)],
    # what to plot
    x_var='day',
    y_var='rewards',
    groupby_var=['day'],
    # overlays
    plot_transitions=None,
    expert_line=False,
    chance=True,
    # colors
    marker_color=subject_color,
    avg_color=avg_color,
    # labels and size
    x_title='Day',
    y_title='Rewards',
    titles=['Circle Track'],
    height=500,
    width=500,
)
fig.show()
fig.write_image(pjoin(fig_path, 'rewards_across_days_ct.png'))

In [None]:
## Circle track: rewards on days 1-5 (day < 6) split by age x sex ('group_two')
# NOTE(review): chance=True on a rewards axis -- confirm pf draws a meaningful chance line here.
fig = pf.plot_behavior_across_days(
    ct_df[ct_df['day'].lt(6)],
    # what to plot
    x_var='day',
    y_var='rewards',
    groupby_var=['day', 'group_two'],
    # overlays
    plot_transitions=None,
    expert_line=False,
    chance=True,
    # colors
    marker_color=four_group_colors,
    avg_color=avg_color,
    # labels and size
    x_title='Day',
    y_title='Rewards',
    titles=['Circle Track'],
    height=500,
    width=600,
)
fig.show()
fig.write_image(pjoin(fig_path, 'rewards_across_days_ct_mfages.png'))

### Probe accuracy

In [None]:
# One row per (mouse, session that contains probe rows):
#   day         -- 1-based position of the session file among ALL of the mouse's
#                  sessions (probe or not), in naturally sorted order
#   experiment  -- the session's 'cohort' value
#   num_licks   -- number of probe rows whose lick_port is not -1
#   probe_acc   -- ctb.lick_accuracy (lick_threshold=5) on the probe rows
#   session_acc -- the same metric on the non-probe rows of that session
# os.listdir() returns entries in arbitrary order, so both listings are sorted before
# enumerating; otherwise 'day' would depend on the filesystem, not on the session.
# NOTE(review): assumes session filenames sort chronologically under natsort -- TODO confirm.
lick_dict_probe = {'mouse': [], 'experiment': [], 'group': [], 'group_two': [], 'session': [], 'day': [], 'num_licks': [], 'probe_acc': [], 'session_acc': []}
for mouse in natsorted(os.listdir(circle_path)):
    mpath = pjoin(circle_path, mouse)
    age = 'Young' if mouse in young_mice else 'Middle Aged'
    sex = 'Male' if mouse in male_mice else 'Female'
    group = 'Young Adult' if mouse in young_mice else 'Middle Aged'
    group_two = f'{age} {sex}'  # 'Young Male', 'Young Female', 'Middle Aged Male', 'Middle Aged Female'
    for idx, session in enumerate(natsorted(os.listdir(mpath))):
        behav = pd.read_feather(pjoin(mpath, session))
        if not behav['probe'].any():
            continue  # sessions without probe rows still advance `idx`, but add no row
        behav_probe = behav[behav['probe']]
        behav_no_probe = behav[~behav['probe']]
        reward_one, reward_two = np.unique(behav['reward_one'])[0], np.unique(behav['reward_two'])[0]
        percent_correct = ctb.lick_accuracy(behav_probe, port_list=[reward_one, reward_two], lick_threshold=5, by_trials=False)
        session_pc = ctb.lick_accuracy(behav_no_probe, port_list=[reward_one, reward_two], lick_threshold=5, by_trials=False)
        lick_dict_probe['mouse'].append(mouse)
        lick_dict_probe['experiment'].append(behav['cohort'].unique()[0])
        lick_dict_probe['group'].append(group)
        lick_dict_probe['group_two'].append(group_two)
        lick_dict_probe['day'].append(idx+1)
        lick_dict_probe['session'].append(np.unique(behav['session'])[0])
        lick_dict_probe['num_licks'].append(len(behav_probe[behav_probe['lick_port'] != -1]))
        lick_dict_probe['probe_acc'].append(percent_correct)
        lick_dict_probe['session_acc'].append(session_pc)
probe_df = pd.DataFrame(lick_dict_probe)

# For each mouse, pick its first (day_index=0) and last (day_index=-1) probe session
# in each context of `context_list` via ctb.pick_context_day, and tag the rows with
# day_type 'First' / 'Last'. Frames are collected and concatenated once; calling
# pd.concat inside the loop re-copies the growing frame on every iteration.
context_list = ['A']
first_last_frames = []
for mouse in probe_df['mouse'].unique():
    mouse_data = probe_df[probe_df['mouse'] == mouse]
    for day_index, day_type in ((0, 'First'), (-1, 'Last')):
        picked_index = ctb.pick_context_day(mouse_data, day_index=day_index, contexts=context_list)
        picked = mouse_data.loc[picked_index, :].copy()  # explicit copy before inserting a column
        picked.insert(0, 'day_type', day_type)
        first_last_frames.append(picked)
# pd.concat([]) raises, so fall back to an empty frame when there are no probe rows.
first_last = pd.concat(first_last_frames) if first_last_frames else pd.DataFrame()

# Mean and SEM of probe accuracy. The agg produces two-level columns, e.g. ('probe_acc', 'mean').
avg_combined = first_last.groupby(['day_type', 'session'], as_index=False).agg({'probe_acc': ['mean', 'sem']})
# NOTE(review): the two group-level plotting cells below recompute avg_group_probe /
# avg_grouptwo_probe from probe_df, so these first/last versions are overwritten before use.
avg_group_probe = first_last.groupby(['day', 'group', 'session'], as_index=False).agg({'probe_acc': ['mean', 'sem']})
avg_grouptwo_probe = first_last.groupby(['day', 'group_two', 'session'], as_index=False).agg({'probe_acc': ['mean', 'sem']})

In [None]:
## Plot probe performance: first vs. last probe day, one subplot column per context in context_list
fig = pf.custom_graph_template(x_title='', y_title='', titles=list(context_list), shared_x=True, shared_y=True)
# Group mean +/- SEM per context.
for col_idx, context in enumerate(context_list, start=1):
    plot_data = avg_combined[avg_combined['session'] == context]
    fig.add_trace(go.Scatter(x=plot_data['day_type'], y=plot_data['probe_acc']['mean'],
                             error_y=dict(type='data', array=plot_data['probe_acc']['sem'], thickness=1.5, width=8),
                             line_color=avg_color, showlegend=False), row=1, col=col_idx)
# Individual mice. Each context is placed in the same column as its average, based on its
# position in context_list, rather than on the order in which the context appears in this mouse's rows.
for mouse in first_last['mouse'].unique():
    mdata = first_last[first_last['mouse'] == mouse]
    for context in mdata['session'].unique():
        if context not in context_list:
            continue  # no subplot column exists for contexts outside context_list
        pdata = mdata[mdata['session'] == context]
        fig.add_trace(go.Scatter(x=pdata['day_type'], y=pdata['probe_acc'], mode='lines', line_color=subject_color,
                                 line_width=1, opacity=0.7, showlegend=False, name=mouse),
                      row=1, col=context_list.index(context) + 1)

# Dashed reference line at 25% (presumably chance level -- confirm).
fig.add_hline(y=25, line_width=1, line_dash='dash', line_color='darkgrey', opacity=1)
fig.update_yaxes(title='Probe Accuracy (%)', col=1)
fig.show()

In [None]:
## Plot probe performance separated by age group ('group': Young Adult vs. Middle Aged)
plot_mice = False  # True overlays one line per mouse, colored by its group
fig = pf.custom_graph_template(x_title='', y_title='', titles=[''], 
                               shared_x=True, shared_y=True, width=600)
# Recomputed here from every probe session in probe_df (mean and SEM per day/group/session).
avg_group_probe = probe_df.groupby(['day', 'group', 'session'], as_index=False).agg({'probe_acc': ['mean', 'sem']})
for group in avg_group_probe['group'].unique():
    gdata = avg_group_probe[avg_group_probe['group'] == group]
    for session_idx, session in enumerate(gdata['session'].unique()):
        plot_data = gdata[gdata['session'] == session]
        # Only each group's first trace gets a legend entry, so every group appears exactly
        # once regardless of how many sessions it has.
        fig.add_trace(go.Scatter(x=plot_data['day'], y=plot_data['probe_acc']['mean'], mode='markers',
                                 error_y=dict(type='data', array=plot_data['probe_acc']['sem'], thickness=1.5, width=8),
                                 line_color=group_colors_dict[group], name=group,
                                 showlegend=session_idx == 0, legendgroup=group))
if plot_mice:
    for mouse in probe_df['mouse'].unique():
        mdata = probe_df[probe_df['mouse'] == mouse]
        fig.add_trace(go.Scatter(x=mdata['day'], y=mdata['probe_acc'], mode='lines', line_color=group_colors_dict[mdata['group'].unique()[0]],
                                 line_width=0.6, opacity=0.6, showlegend=False, name=mouse, legendgroup=mdata['group'].to_numpy()[0]))

# Dashed reference lines: y=25% (presumably chance -- confirm) and x=5.5 (between days 5 and 6).
fig.add_hline(y=25, line_width=1, line_dash='dash', line_color='darkgrey', opacity=1)
fig.add_vline(x=5.5, line_width=1, line_dash='dash', line_color='darkgrey', opacity=1)
fig.update_yaxes(title='Probe Accuracy (%)', col=1, range=[0, 100])
# NOTE(review): tick labels assume probe sessions fall on days 1, 5 and 6 -- TODO confirm.
fig.update_xaxes(
    ticktext=['A1', 'A5', 'R'],
    tickvals=[1, 5, 6],
)
fig.show()
fig.write_image(pjoin(fig_path, 'probe_accuracy_recall_young_middle.png'))

In [None]:
## Plot probe performance separated by age x sex ('group_two')
plot_mice = False  # True overlays one line per mouse, colored by its group_two
fig = pf.custom_graph_template(x_title='', y_title='', titles=[''], 
                               shared_x=True, shared_y=True, width=600)
# Recomputed here from every probe session in probe_df (mean and SEM per day/group_two/session).
avg_grouptwo_probe = probe_df.groupby(['day', 'group_two', 'session'], as_index=False).agg({'probe_acc': ['mean', 'sem']})
for group in avg_grouptwo_probe['group_two'].unique():
    gdata = avg_grouptwo_probe[avg_grouptwo_probe['group_two'] == group]
    for session_idx, session in enumerate(gdata['session'].unique()):
        plot_data = gdata[gdata['session'] == session]
        # Only each group's first trace gets a legend entry, so every group appears exactly
        # once regardless of how many sessions it has.
        fig.add_trace(go.Scatter(x=plot_data['day'], y=plot_data['probe_acc']['mean'], mode='markers',
                                 error_y=dict(type='data', array=plot_data['probe_acc']['sem'], thickness=1.5, width=8),
                                 line_color=group_colors_dict[group], name=group,
                                 showlegend=session_idx == 0, legendgroup=group))
if plot_mice:
    for mouse in probe_df['mouse'].unique():
        mdata = probe_df[probe_df['mouse'] == mouse]
        fig.add_trace(go.Scatter(x=mdata['day'], y=mdata['probe_acc'], mode='lines', line_color=group_colors_dict[mdata['group_two'].unique()[0]],
                                 line_width=0.6, opacity=0.6, showlegend=False, name=mouse, legendgroup=mdata['group_two'].to_numpy()[0]))

# Dashed reference lines: y=25% (presumably chance -- confirm) and x=5.5 (between days 5 and 6).
fig.add_hline(y=25, line_width=1, line_dash='dash', line_color='darkgrey', opacity=1)
fig.add_vline(x=5.5, line_width=1, line_dash='dash', line_color='darkgrey', opacity=1)
fig.update_yaxes(title='Probe Accuracy (%)', col=1, range=[0, 100])
# NOTE(review): tick labels assume probe sessions fall on days 1, 5 and 6 -- TODO confirm.
fig.update_xaxes(
    ticktext=['A1', 'A5', 'R'],
    tickvals=[1, 5, 6],
)
fig.show()
fig.write_image(pjoin(fig_path, 'probe_accuracy_recall_youngmiddle_malefemale.png'))