In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys
from os.path import join as pjoin

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import xarray as xr
from natsort import natsorted
from scipy import stats
# import pingouin as pg

sys.path.append("../../")
import circletrack_behavior as ctb
import plotting_functions as pf

In [None]:
## Settings
## Experiment location: every input and output folder hangs off the same experiment root.
parent_dir = 'CircleTrack_Recall'
experiment_dir = 'Recall1'
experiment_root = f'../../../{parent_dir}/{experiment_dir}'
lin_path = f'{experiment_root}/output/lin_behav/'
circle_path = f'{experiment_root}/output/behav/'
fig_path = f'{experiment_root}/intermediate_figures'

## Plot colors
subject_color = '#7d7d7d'
avg_color = 'blue'
chance_color = 'darkgrey'
group_colors_dict = {'5 day': '#7d7d7d', '10 day': 'blue', '15 day': 'darkturquoise'}
three_color_plots = list(group_colors_dict.values())  ## same three colors, in group order

## Cohort membership; a mouse in neither list is assigned to the 15 day group by the cells below.
five_day = ['mcr02', 'mcr03', 'mcr07', 'mcr11', 'mcr13']
ten_day = ['mcr01', 'mcr05', 'mcr06', 'mcr10', 'mcr14']
excluded_mice = ['mcr09', 'mcr16'] ## didn't increase rewards across days, and didn't improve during probe

### Linear track: rewards across days.

In [None]:
## Collect one row per linear-track session: mouse, day index, summed `water` column and maze label.
## os.listdir returns entries in arbitrary order, so both listings are natural-sorted; the `day`
## value is the position of the session file inside the mouse folder and is only meaningful
## when the files are visited in a deterministic, numerically-aware order.
lin_dict = {'mouse': [], 'day': [], 'rewards': [], 'group': []}
for mouse in natsorted(os.listdir(lin_path)):
    mouse_path = pjoin(lin_path, mouse)
    for idx, session in enumerate(natsorted(os.listdir(mouse_path))):
        lin_behav = pd.read_feather(pjoin(mouse_path, session))
        lin_dict['mouse'].append(mouse)
        lin_dict['day'].append(idx + 1)  ## days are 1-based
        lin_dict['rewards'].append(np.sum(lin_behav['water']))
        lin_dict['group'].append(lin_behav['maze'].unique()[0])  ## first maze label found in the session
lin_df = pd.DataFrame(lin_dict)

In [None]:
## Example linear-track session: x position over time, with the samples flagged in `water` overlaid.
mouse = 'mcr01'
session = '4'
lin_behav = pd.read_feather(pjoin(lin_path, f'{mouse}/{mouse}_{session}.feat'))
water_mask = lin_behav['water']  ## indexer for the rewarded samples

fig = pf.custom_graph_template(x_title='Time (s)', y_title='X Position (pixel)')
## Full trajectory
trajectory = go.Scatter(x=lin_behav['t'], y=lin_behav['x'], mode='lines',
                        line=dict(color='darkgrey'), showlegend=False)
fig.add_trace(trajectory)
## Rewarded samples
rewards = go.Scatter(x=lin_behav['t'][water_mask], y=lin_behav['x'][water_mask],
                     mode='markers', marker=dict(color='red'), showlegend=False)
fig.add_trace(rewards)
fig.show()
fig.write_image(pjoin(fig_path, f'{mouse}_lineartrack_position.png'))

In [None]:
## Plot rewards across days on linear track
## All mice pooled: grouped by day only, no expert line and no chance line.
lin_rewards_kwargs = dict(
    x_var='day', y_var='rewards', groupby_var=['day'], plot_transitions=None,
    marker_color=subject_color, avg_color=avg_color,
    expert_line=False, chance=False,
    x_title='Day', y_title='Rewards', titles=['Linear Track'],
    height=500, width=500,
)
fig = pf.plot_behavior_across_days(lin_df, **lin_rewards_kwargs)
fig.update_yaxes(range=[0, 100])
fig.show()
# fig.write_image(pjoin(fig_path, 'linear_track_rewards.png'))

In [None]:
## Rewards across days on the linear track, split by the `group` column (maze label) of lin_df.
maze_colors = ['midnightblue', 'darkgrey']
lin_by_maze_kwargs = dict(
    x_var='day', y_var='rewards', groupby_var=['day', 'group'], plot_transitions=None,
    marker_color=maze_colors, symbols=['circle', 'circle'], avg_color='darkgrey',
    expert_line=False, chance=False, plot_datapoints=True,
    x_title='Day', y_title='Rewards', titles=['Linear Track'],
    height=500, width=500,
)
fig = pf.plot_behavior_across_days(lin_df, **lin_by_maze_kwargs)
fig.update_yaxes(range=[0, 100])
fig.show()

### Circle track lick accuracy and rewards across days.

In [None]:
## Per-session circle-track summary with probe samples removed: lick accuracy (lick_threshold=5)
## and summed `water` column.
## os.listdir returns entries in arbitrary order, so both listings are natural-sorted; `day` is the
## position of the session file inside the mouse folder and must follow the true session order.
circletrack_results = {'mouse': [], 'group': [], 'day': [], 'session': [], 'lick_acc': [], 'rewards': []}
for mouse in natsorted(os.listdir(circle_path)):
    ## A mouse in neither list falls into the fifteen_day group
    group = 'five_day' if mouse in five_day else 'ten_day' if mouse in ten_day else 'fifteen_day'
    mouse_path = pjoin(circle_path, mouse)
    for idx, session in enumerate(natsorted(os.listdir(mouse_path))):
        behav = pd.read_feather(pjoin(mouse_path, session))
        behav = behav[~behav['probe']]  ## keep only non-probe samples
        ## The rewarded ports are read from the session itself (first unique value of each column)
        reward_one, reward_two = np.unique(behav['reward_one'])[0], np.unique(behav['reward_two'])[0]
        pc_thresh5 = ctb.lick_accuracy(behav, port_list=[reward_one, reward_two], lick_threshold=5, by_trials=False)
        circletrack_results['mouse'].append(mouse)
        circletrack_results['group'].append(group)
        circletrack_results['day'].append(idx + 1)  ## days are 1-based
        circletrack_results['session'].append(np.unique(behav['session'])[0])
        circletrack_results['lick_acc'].append(pc_thresh5)
        circletrack_results['rewards'].append(np.sum(behav['water']))
ct_df = pd.DataFrame(circletrack_results)

In [None]:
## Example circle-track trajectory (every `downsample_factor`-th sample) with rewarded samples overlaid.
## NOTE: `behav` and `mouse` are not loaded here; they are whatever the last loop over the
## circle-track sessions left behind (its final mouse and session). The figure is therefore saved
## under that mouse's ID rather than a hard-coded one, so the filename always matches the data.
downsample_factor = 20
fig = pf.custom_graph_template(x_title='X Position (pixels)', y_title='Y Position (pixels)')
fig.add_trace(go.Scatter(x=behav['x'][::downsample_factor], y=behav['y'][::downsample_factor],
                         mode='markers', marker_color='darkgrey', showlegend=False))
fig.add_trace(go.Scatter(x=behav['x'][behav['water']], y=behav['y'][behav['water']],
                         mode='markers', marker_color='midnightblue', showlegend=False))
fig.show()
fig.write_image(pjoin(fig_path, f'{mouse}_circleposition.png'))

In [None]:
## Plot 5th lick accuracy across days
## Restricted to days 1-5 of ct_df (all mice); `chance=True` is passed for the accuracy plot.
ct_training = ct_df[ct_df['day'] < 6]
ct_acc_kwargs = dict(
    x_var='day', y_var='lick_acc', groupby_var=['day'], plot_transitions=None,
    marker_color=subject_color, avg_color=avg_color,
    expert_line=False, chance=True, plot_datapoints=False,
    x_title='Day', y_title='Lick Accuracy (%)', titles=['Circle Track'],
    height=500, width=500,
)
fig = pf.plot_behavior_across_days(ct_training, **ct_acc_kwargs)
fig.update_yaxes(range=[0, 100])
fig.show()
fig.write_image(pjoin(fig_path, 'lick_acc_across_days_ct.png'))

In [None]:
## Plot rewards across days
## Restricted to days 1-5 of ct_df (all mice); no chance line on the rewards plot.
ct_training = ct_df[ct_df['day'] < 6]
ct_rewards_kwargs = dict(
    x_var='day', y_var='rewards', groupby_var=['day'], plot_transitions=None,
    marker_color=subject_color, avg_color=avg_color,
    expert_line=False, chance=False, plot_datapoints=False,
    x_title='Day', y_title='Rewards', titles=['Circle Track'],
    height=500, width=500,
)
fig = pf.plot_behavior_across_days(ct_training, **ct_rewards_kwargs)
fig.show()
fig.write_image(pjoin(fig_path, 'rewards_across_days.png'))

### Probe performance.

In [None]:
## Probe vs. rest-of-session lick accuracy for every session that contains probe samples (all mice).
## os.listdir returns entries in arbitrary order, so both listings are natural-sorted; `day` is the
## position of the session file inside the mouse folder and must follow the true session order.
lick_dict_probe = {'mouse': [], 'experiment': [], 'session': [], 'day': [], 'num_licks': [], 'probe_acc': [], 'session_acc': []}
for mouse in natsorted(os.listdir(circle_path)):
    mpath = pjoin(circle_path, mouse)
    for idx, session in enumerate(natsorted(os.listdir(mpath))):
        behav = pd.read_feather(pjoin(mpath, session))
        if not any(behav['probe']):
            continue  ## sessions without probe samples contribute no row
        behav_probe = behav[behav['probe']]
        behav_no_probe = behav[~behav['probe']]
        reward_one, reward_two = np.unique(behav['reward_one'])[0], np.unique(behav['reward_two'])[0]
        percent_correct = ctb.lick_accuracy(behav_probe, port_list=[reward_one, reward_two], lick_threshold=5, by_trials=False)
        session_pc = ctb.lick_accuracy(behav_no_probe, port_list=[reward_one, reward_two], lick_threshold=5, by_trials=False)
        lick_dict_probe['mouse'].append(mouse)
        lick_dict_probe['experiment'].append(behav['cohort'].unique()[0])
        lick_dict_probe['day'].append(idx + 1)  ## days are 1-based
        lick_dict_probe['session'].append(np.unique(behav['session'])[0])
        ## Rows with lick_port == -1 are not counted as licks
        lick_dict_probe['num_licks'].append(len(behav_probe[behav_probe['lick_port'] != -1]))
        lick_dict_probe['probe_acc'].append(percent_correct)
        lick_dict_probe['session_acc'].append(session_pc)
probe_df = pd.DataFrame(lick_dict_probe)

## For each mouse, tag the rows picked by ctb.pick_context_day with day_index=0 as 'First' and
## with day_index=-1 as 'Last'. The pieces are collected in a list and concatenated once instead of
## growing a DataFrame inside the loop (repeated concat onto an initially empty frame).
context_list = ['A']
first_last_parts = []
for mouse in probe_df['mouse'].unique():
    mouse_data = probe_df[probe_df['mouse'] == mouse]
    index_list = ctb.pick_context_day(mouse_data, day_index=0, contexts=context_list)
    index_list_two = ctb.pick_context_day(mouse_data, day_index=-1, contexts=context_list)
    sub_data = mouse_data.loc[index_list, :]
    sub_data.insert(0, 'day_type', 'First')
    sub_data_two = mouse_data.loc[index_list_two, :]
    sub_data_two.insert(0, 'day_type', 'Last')
    first_last_parts.extend([sub_data, sub_data_two])
## pd.concat raises on an empty list, so fall back to an empty frame when there were no probe rows
first_last = pd.concat(first_last_parts) if first_last_parts else pd.DataFrame()
avg_combined = first_last.groupby(['day_type', 'session'], as_index=False).agg({'probe_acc': ['mean', 'sem']})

In [None]:
## Plot probe performance
fig = pf.custom_graph_template(x_title='', y_title='', titles=['A'], shared_x=True, shared_y=True)

## Mean +/- SEM across mice, one subplot column per entry of context_list
for idx, session in enumerate(context_list):
    plot_data = avg_combined[avg_combined['session'] == session]
    sem_bars = dict(type='data', array=plot_data['probe_acc']['sem'], thickness=1.5, width=8)
    mean_trace = go.Scatter(x=plot_data['day_type'], y=plot_data['probe_acc']['mean'],
                            error_y=sem_bars, line=dict(color=avg_color), showlegend=False)
    fig.add_trace(mean_trace, row=1, col=idx + 1)

## One thin line per mouse; the column is the position of the session within that mouse's own rows
for mouse in first_last['mouse'].unique():
    mdata = first_last[first_last['mouse'] == mouse]
    for idx, context in enumerate(mdata['session'].unique()):
        pdata = mdata[mdata['session'] == context]
        mouse_trace = go.Scatter(x=pdata['day_type'], y=pdata['probe_acc'], mode='lines',
                                 line=dict(color=chance_color, width=1),
                                 opacity=0.7, showlegend=False, name=mouse)
        fig.add_trace(mouse_trace, row=1, col=idx + 1)

## Dashed reference line at 25 %
fig.add_hline(y=25, line_width=1, line_dash='dash', line_color='darkgrey', opacity=1)
fig.update_yaxes(title='Probe Accuracy (%)', col=1)
fig.show()

### Analyses excluding non-learners.

In [None]:
## Exclude mice that did not learn
## A single boolean mask replaces the per-mouse loop that grew a DataFrame with pd.concat starting
## from an empty frame. ct_df is built one mouse at a time, so the row order is the same as before,
## and the columns are preserved even if every mouse ends up excluded.
subset_df = ct_df[~ct_df['mouse'].isin(excluded_mice)].reset_index(drop=True)

In [None]:
## Lick accuracy across days
## Days 1-5 of subset_df (excluded mice removed); `chance=True` is passed for the accuracy plot.
subset_training = subset_df[subset_df['day'] < 6]
subset_acc_kwargs = dict(
    x_var='day', y_var='lick_acc', groupby_var=['day'], plot_transitions=None,
    marker_color=subject_color, avg_color=avg_color,
    expert_line=False, chance=True, plot_datapoints=False,
    x_title='Day', y_title='Lick Accuracy (%)', titles=['Circle Track'],
    height=500, width=500,
)
fig = pf.plot_behavior_across_days(subset_training, **subset_acc_kwargs)
fig.update_yaxes(range=[0, 100])
fig.show()
## NOTE(review): same filename as the all-mice lick accuracy figure, so that file is overwritten.
fig.write_image(pjoin(fig_path, 'lick_acc_across_days_ct.png'))

In [None]:
## Rewards across days
## Days 1-5 of subset_df (excluded mice removed).
## `chance=False`: every other rewards plot in this notebook is drawn without a chance line; the
## chance option is only used for the lick accuracy (%) plots, so `chance=True` here looked like a
## copy-paste leftover from the accuracy cell.
## NOTE(review): confirm against plotting_functions.plot_behavior_across_days.
subset_training = subset_df[subset_df['day'] < 6]
fig = pf.plot_behavior_across_days(subset_training, x_var='day', y_var='rewards', groupby_var=['day'], plot_transitions=None,
                                   marker_color=subject_color, avg_color=avg_color, expert_line=False, chance=False, plot_datapoints=False,
                                   x_title='Day', y_title='Rewards', titles=['Circle Track'], height=500, width=500)
fig.show()
## NOTE(review): same filename as the all-mice rewards figure, so that file is overwritten.
fig.write_image(pjoin(fig_path, 'rewards_across_days.png'))

In [None]:
## Separate by the three groups
## Relabel the groups to their display names (a no-op if they were already relabelled).
group_labels = {'five_day': '5 day', 'ten_day': '10 day', 'fifteen_day': '15 day'}
subset_df = subset_df.replace(group_labels)
avg = subset_df.groupby(['group', 'day'], as_index=False).agg({'lick_acc': ['mean', 'sem']})

fig = pf.custom_graph_template(x_title='Day', y_title='Lick Accuracy (%)', titles=['Training'], font_size=26)
## One mean +/- SEM line per group, days 1-5 only
for group in ['5 day', '10 day', '15 day']:
    is_training_day = avg['day'] < 6
    gdata = avg[is_training_day & (avg['group'] == group)]
    sem_bars = dict(type='data', array=gdata['lick_acc']['sem'])
    fig.add_trace(go.Scatter(x=gdata['day'], y=gdata['lick_acc']['mean'], mode='lines+markers',
                             line=dict(color=group_colors_dict[group]), name=group, error_y=sem_bars))
fig.update_yaxes(range=[0, 100])
## Dashed reference line at 25 %
fig.add_hline(y=25, line_width=1, opacity=1, line_dash='dash', line_color='darkgrey')
fig.show()
fig.write_image(pjoin(fig_path, 'three_group_lick_accuracy.png'))

In [None]:
## Separate by the three groups
## Relabel the groups to their display names (a no-op if they were already relabelled).
group_labels = {'five_day': '5 day', 'ten_day': '10 day', 'fifteen_day': '15 day'}
subset_df = subset_df.replace(group_labels)
avg = subset_df.groupby(['group', 'day'], as_index=False).agg({'rewards': ['mean', 'sem']})

fig = pf.custom_graph_template(x_title='Day', y_title='Rewards', titles=['Training'], font_size=26)
## One mean +/- SEM line per group, days 1-5 only
for group in ['5 day', '10 day', '15 day']:
    is_training_day = avg['day'] < 6
    gdata = avg[is_training_day & (avg['group'] == group)]
    sem_bars = dict(type='data', array=gdata['rewards']['sem'])
    fig.add_trace(go.Scatter(x=gdata['day'], y=gdata['rewards']['mean'], mode='lines+markers',
                             line=dict(color=group_colors_dict[group]), name=group, error_y=sem_bars))
fig.show()
fig.write_image(pjoin(fig_path, 'rewards_across_days_3group.png'))

### Probe accuracy for the three groups

In [None]:
## Excluding mice
## Probe vs. rest-of-session lick accuracy for every session that contains probe samples, skipping
## `excluded_mice` and labelling each mouse with its group.
## os.listdir returns entries in arbitrary order, so both listings are natural-sorted; `day` is the
## position of the session file inside the mouse folder and must follow the true session order.
lick_dict_probe = {'mouse': [], 'group': [], 'session': [], 'day': [], 'num_licks': [], 'probe_acc': [], 'session_acc': []}
for mouse in natsorted(os.listdir(circle_path)):
    if mouse in excluded_mice:
        continue
    mpath = pjoin(circle_path, mouse)
    ## A mouse in neither list falls into the 15 day group
    group = '5 day' if mouse in five_day else '10 day' if mouse in ten_day else '15 day'
    for idx, session in enumerate(natsorted(os.listdir(mpath))):
        behav = pd.read_feather(pjoin(mpath, session))
        if not any(behav['probe']):
            continue  ## sessions without probe samples contribute no row
        behav_probe = behav[behav['probe']]
        behav_no_probe = behav[~behav['probe']]
        reward_one, reward_two = np.unique(behav['reward_one'])[0], np.unique(behav['reward_two'])[0]
        percent_correct = ctb.lick_accuracy(behav_probe, port_list=[reward_one, reward_two], lick_threshold=5, by_trials=False)
        session_pc = ctb.lick_accuracy(behav_no_probe, port_list=[reward_one, reward_two], lick_threshold=5, by_trials=False)
        lick_dict_probe['mouse'].append(mouse)
        lick_dict_probe['group'].append(group)
        lick_dict_probe['day'].append(idx + 1)  ## days are 1-based
        lick_dict_probe['session'].append(np.unique(behav['session'])[0])
        ## Rows with lick_port == -1 are not counted as licks
        lick_dict_probe['num_licks'].append(len(behav_probe[behav_probe['lick_port'] != -1]))
        lick_dict_probe['probe_acc'].append(percent_correct)
        lick_dict_probe['session_acc'].append(session_pc)
probe_df = pd.DataFrame(lick_dict_probe)

## For each mouse, tag the rows picked by ctb.pick_context_day with day_index=0 as 'First' and
## with day_index=-1 as 'Last'. The pieces are collected in a list and concatenated once instead of
## growing a DataFrame inside the loop (repeated concat onto an initially empty frame).
context_list = ['A']
first_last_parts = []
for mouse in probe_df['mouse'].unique():
    mouse_data = probe_df[probe_df['mouse'] == mouse]
    index_list = ctb.pick_context_day(mouse_data, day_index=0, contexts=context_list)
    index_list_two = ctb.pick_context_day(mouse_data, day_index=-1, contexts=context_list)
    sub_data = mouse_data.loc[index_list, :]
    sub_data.insert(0, 'day_type', 'First')
    sub_data_two = mouse_data.loc[index_list_two, :]
    sub_data_two.insert(0, 'day_type', 'Last')
    first_last_parts.extend([sub_data, sub_data_two])
## pd.concat raises on an empty list, so fall back to an empty frame when there were no probe rows
first_last = pd.concat(first_last_parts) if first_last_parts else pd.DataFrame()
avg_combined = first_last.groupby(['day_type', 'session'], as_index=False).agg({'probe_acc': ['mean', 'sem']})
avg_group_probe = probe_df.groupby(['group', 'day'], as_index=False).agg({'probe_acc': ['mean', 'sem']})

In [None]:
## Plot probe performance for the three created groups
fig = pf.custom_graph_template(x_title='Day', y_title='', titles=['Probe Accuracy'],
                               shared_x=True, shared_y=True, height=500, width=600, font_size=26)
## One marker series (mean +/- SEM per probe day) for each group
for group in ['5 day', '10 day', '15 day']:
    gdata = avg_group_probe[avg_group_probe['group'] == group]
    sem_bars = dict(type='data', array=gdata['probe_acc']['sem'], thickness=1.5, width=8)
    group_trace = go.Scatter(x=gdata['day'], y=gdata['probe_acc']['mean'], mode='markers',
                             error_y=sem_bars, line=dict(color=group_colors_dict[group]),
                             name=group, showlegend=True, legendgroup=group)
    fig.add_trace(group_trace)

## Dashed guides: horizontal at 25 %, vertical at x=5.5 (between the day-5 tick and the 'Recall' tick)
fig.add_hline(y=25, line_width=1, line_dash='dash', line_color='darkgrey', opacity=1)
fig.add_vline(x=5.5, line_width=1, line_dash='dash', line_color='darkgrey', opacity=1)
fig.update_yaxes(title='Lick Accuracy (%)', col=1, range=[0, 100])
## Only days 1, 5 and 6 get a tick; day 6 is labelled 'Recall'
fig.update_xaxes(ticktext=['1', '5', 'Recall'], tickvals=[1, 5, 6])
fig.show()
fig.write_image(pjoin(fig_path, 'probe_accuracy_recall.png'))

### Analyze lick accuracy trial by trial within the probe.

In [None]:
## Per-trial probe lick accuracy (by_trials=True) for every session that contains probe samples,
## skipping `excluded_mice`.
## os.listdir returns entries in arbitrary order, so both listings are natural-sorted; `day` is the
## position of the session file inside the mouse folder and must follow the true session order.
lick_dict = {'mouse': [], 'experiment': [], 'group': [], 'session': [], 'day': [], 'num_licks': [], 'probe_acc': []}
for mouse in natsorted(os.listdir(circle_path)):
    if mouse in excluded_mice:
        continue
    mpath = pjoin(circle_path, mouse)
    ## Group depends only on the mouse, so it is resolved once per mouse rather than once per session
    group = 'five_day' if mouse in five_day else 'ten_day' if mouse in ten_day else 'fifteen_day'
    for idx, session in enumerate(natsorted(os.listdir(mpath))):
        behav = pd.read_feather(pjoin(mpath, session))
        if not any(behav['probe']):
            continue  ## sessions without probe samples contribute no row
        behav_probe = behav[behav['probe']]
        behav_no_probe = behav[~behav['probe']]
        reward_one, reward_two = np.unique(behav['reward_one'])[0], np.unique(behav['reward_two'])[0]
        percent_correct = ctb.lick_accuracy(behav_probe, port_list=[reward_one, reward_two], lick_threshold=5, by_trials=True)
        lick_dict['mouse'].append(mouse)
        lick_dict['experiment'].append(behav['cohort'].unique()[0])
        lick_dict['day'].append(idx + 1)  ## days are 1-based
        lick_dict['group'].append(group)
        lick_dict['session'].append(np.unique(behav['session'])[0])
        ## Rows with lick_port == -1 are not counted as licks
        lick_dict['num_licks'].append(len(behav_probe[behav_probe['lick_port'] != -1]))
        lick_dict['probe_acc'].append(percent_correct)  ## one entry per trial
probe_trial_df = pd.DataFrame(lick_dict)

In [None]:
## One subplot per mouse on a 4 x 4 grid: per-trial probe accuracy, one line per probe day.
## The dash style encodes the day and the color encodes the group.
n_cols = 4
## Days other than 1, 5 and 6 fall back to 'dot' instead of leaving the dash style undefined
## (or silently reusing the style of the previous line).
day_dash = {1: 'dash', 5: 'solid', 6: 'dashdot'}
## probe_trial_df['group'] holds 'five_day'-style labels, while group_colors_dict is keyed by
## '5 day'-style labels; translate before the color lookup so it does not raise KeyError.
## Labels that are already in display form pass through unchanged.
group_display = {'five_day': '5 day', 'ten_day': '10 day', 'fifteen_day': '15 day'}

fig = pf.custom_graph_template(x_title='', y_title='', height=1100, width=1000, rows=4, columns=n_cols,
                               titles=probe_trial_df['mouse'].unique(), shared_y=True, shared_x=True)
for idx, mouse in enumerate(probe_trial_df['mouse'].unique()):
    if mouse in excluded_mice:
        continue
    ## Fill the grid row by row, in the same order as the subplot titles
    row, col = idx // n_cols + 1, idx % n_cols + 1
    mdata = probe_trial_df[probe_trial_df['mouse'] == mouse]
    for day in mdata['day']:
        day_data = mdata[mdata['day'] == day]
        trial_acc = day_data['probe_acc'].values[0]  ## per-trial accuracies for this probe
        group_key = day_data['group'].values[0]
        line_color = group_colors_dict[group_display.get(group_key, group_key)]
        x_data = np.arange(1, len(trial_acc) + 1)  ## trials are numbered from 1
        fig.add_trace(go.Scatter(x=x_data, y=trial_acc, mode='lines', showlegend=False,
                                 line=dict(dash=day_dash.get(day, 'dot'), color=line_color),
                                 name=f'Day {day}' if day < 6 else 'Recall'),
                      row=row, col=col)
fig.update_yaxes(title='Probe Accuracy (%)', col=1)
fig.update_xaxes(title='Trial', row=4)
## Dashed reference line at 25 %
fig.add_hline(y=25, line_width=1, line_dash='dash', line_color=chance_color, opacity=1)
## Show legend entries for the first three traces only (slicing tolerates fewer than three traces)
for trace in fig.data[:3]:
    trace.showlegend = True
fig.show()

### Correlate number of cfos-positive cells with probe accuracy during recall.

In [None]:
## Count cells per ROI from the saved DataArrays (files with 'dapi' in the name are skipped), sum
## them per mouse, and attach each mouse's probe accuracy from its 'AR' session in probe_df.
cell_dict = {'mouse': [], 'roi': [], 'num_cells': []}
cfos_path = f'../../../{parent_dir}/{experiment_dir}/brain_slices'
for mouse in natsorted(os.listdir(cfos_path)):
    output_path = pjoin(cfos_path, f'{mouse}/output/Austin')
    for file in natsorted(os.listdir(output_path)):
        if 'dapi' in file:
            continue
        ## The context manager closes the underlying file once name and attrs have been read
        with xr.open_dataarray(pjoin(output_path, file)) as data:
            cell_dict['mouse'].append(mouse)
            cell_dict['roi'].append(data.name)
            cell_dict['num_cells'].append(data.attrs['num_cells'])
cell_df = pd.DataFrame(cell_dict)
cell_totals = cell_df.groupby(['mouse'], as_index=False).agg({'num_cells': 'sum'})

## First 'AR' row per mouse, matching the previous `.values[0]` lookup. The inner merge keeps only
## mice present in both tables, so a mouse with cell counts but no 'AR' row in probe_df (for
## example one left out of probe_df) is dropped and reported instead of raising IndexError.
recall_acc = probe_df.loc[probe_df['session'] == 'AR', ['mouse', 'probe_acc']].drop_duplicates(subset='mouse')
total_cells = cell_totals.merge(recall_acc, on='mouse', how='inner')
mice_without_recall = sorted(set(cell_totals['mouse']) - set(total_cells['mouse']))
if mice_without_recall:
    print(f'No AR probe accuracy for {mice_without_recall}; left out of the correlation.')

In [None]:
## Scatter of total cell count against probe accuracy, one marker per mouse.
fig = pf.custom_graph_template(x_title='Number of cfos+ Cells', y_title='Probe Accuracy (%)')
fig.add_trace(go.Scatter(x=total_cells['num_cells'], y=total_cells['probe_acc'],
                         mode='markers', marker_color=subject_color))
## Dashed reference line at 25 %
fig.add_hline(y=25, line_width=1, line_dash='dash', line_color=chance_color, opacity=1)
fig.update_yaxes(range=[0, 100])
fig.show()
## `pg` (pingouin) is never imported in this notebook (its import is commented out), so the former
## `pg.linear_regression(...)` call raised NameError. scipy.stats.linregress fits the same simple
## least-squares line and returns slope, intercept, rvalue, pvalue and the standard errors.
stats.linregress(total_cells['num_cells'], total_cells['probe_acc'])