In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys
import numpy as np
import pandas as pd
import xarray as xr
import plotly.graph_objects as go
from os.path import join as pjoin
from natsort import natsorted

sys.path.append("../../")
import circletrack_behavior as ctb
import plotting_functions as pf

In [None]:
## Settings
# Data locations, relative to this notebook's directory.
parent_dir = 'MultiCon_Imaging'
experiment_dir = 'MultiCon_Imaging7'
lin_path = f'../../../{parent_dir}/{experiment_dir}/output/lin_behav/'
circle_path = f'../../../{parent_dir}/{experiment_dir}/output/behav/'
fig_path = f'../../../{parent_dir}/{experiment_dir}/intermediate_figures'

# Subject metadata. Mice not in male_mice are labelled 'Female'; mice not in
# control_mice are labelled 'Multi-context' by the loading cells below.
male_mice = ['mc64', 'mc65']
control_mice = ['mc61', 'mc64']

# Colors.
chance_color = 'darkgrey'
avg_color = 'midnightblue'
subject_color = 'darkgrey'
mf_colors = ['darkorchid', 'midnightblue']
mf_colors_dict = {'Male': mf_colors[1], 'Female': mf_colors[0]}
ce_colors = ['midnightblue', '#287347']
ce_colors_dict = {'Two-context': ce_colors[0], 'Multi-context': ce_colors[1]}
# Semi-transparent (alpha 0.4) fills for SEM bands, keyed by group.
error_color = {'Multi-context': 'rgba(40, 115, 71, 0.4)', 'Two-context': 'rgba(25,25,112,0.4)'}

# Session labels, used as per-day subplot titles (in order): 5 sessions per context.
experimental_sessions = [f'{context}{day}' for context in 'ABCD' for day in range(1, 6)]
control_sessions = [f'A{day}' for day in range(1, 16)] + [f'B{day}' for day in range(1, 6)]

# exist_ok avoids the check-then-create race of `if not exists: makedirs`.
os.makedirs(fig_path, exist_ok=True)

### Linear track - rewards across days.

In [None]:
# One row per (mouse, linear-track session file) with the total water delivered.
# NOTE(review): mc64's day numbers are shifted by +1 — presumably it is missing its
# first linear-track session on disk; confirm against the experiment log.
lin_day_offset = {'mc64': 1}
lin_dict = {'mouse': [], 'day': [], 'sex': [], 'rewards': []}
# natsorted: os.listdir order is arbitrary, so sort for a reproducible row order.
for mouse in natsorted(os.listdir(lin_path)):
    mouse_path = pjoin(lin_path, mouse)
    if not os.path.isdir(mouse_path):
        # Skip stray files (e.g. .DS_Store) next to the per-mouse folders.
        continue
    sex = 'Male' if mouse in male_mice else 'Female'
    offset = lin_day_offset.get(mouse, 0)
    # Natural sort so that day numbers follow session order (e.g. 2 before 10).
    for idx, session in enumerate(natsorted(os.listdir(mouse_path))):
        behav = pd.read_feather(pjoin(mouse_path, session))
        lin_dict['mouse'].append(mouse)
        lin_dict['day'].append(idx + 1 + offset)
        lin_dict['sex'].append(sex)
        lin_dict['rewards'].append(np.sum(behav['water']))
lin_df = pd.DataFrame(lin_dict)

In [None]:
## Plot total linear-track rewards across days (all mice, grouped by day) and save as PNG
fig = pf.plot_behavior_across_days(
    lin_df,
    x_var='day',
    y_var='rewards',
    groupby_var=['day'],
    x_title='Day',
    y_title='Rewards',
    titles=['Linear Track'],
    marker_color=subject_color,
    avg_color=avg_color,
    plot_transitions=None,
    expert_line=False,
    chance=False,
    height=500,
    width=500,
)
fig.show()
fig.write_image(pjoin(fig_path, 'linear_track_rewards.png'))

In [None]:
## Plot linear-track rewards across days, split by sex (not saved to disk)
# NOTE(review): assumes pf assigns mf_colors to the sex groups in sorted order
# (Female, Male), matching mf_colors_dict — confirm against plotting_functions.
fig = pf.plot_behavior_across_days(
    lin_df,
    x_var='day',
    y_var='rewards',
    groupby_var=['day', 'sex'],
    x_title='Day',
    y_title='Rewards',
    titles=['Linear Track'],
    marker_color=mf_colors,
    avg_color=avg_color,
    symbols=['circle', 'circle'],
    plot_transitions=None,
    expert_line=False,
    chance=False,
    height=500,
    width=500,
)
fig.show()

### Circle track - lick accuracy and rewards.

In [None]:
# One row per (mouse, circle-track session file): lick accuracy over the two rewarded
# ports and total water delivered, computed with probe rows excluded.
circletrack_results = {'mouse': [], 'sex': [], 'group': [], 'day': [], 'session': [], 'lick_accuracy': [], 'rewards': []}
for mouse in natsorted(os.listdir(circle_path)):
    mouse_path = pjoin(circle_path, mouse)
    if not os.path.isdir(mouse_path):
        # Skip stray files (e.g. .DS_Store) next to the per-mouse folders.
        continue
    sex = 'Male' if mouse in male_mice else 'Female'
    group = 'Two-context' if mouse in control_mice else 'Multi-context'
    # os.listdir order is arbitrary; natural-sort the session files (as the linear-track
    # cell does) so that `day` reflects session order rather than filesystem order.
    for idx, session in enumerate(natsorted(os.listdir(mouse_path))):
        behav = pd.read_feather(pjoin(mouse_path, session))
        behav = behav[~behav['probe']]
        reward_one, reward_two = np.unique(behav['reward_one'])[0], np.unique(behav['reward_two'])[0]
        pc = ctb.lick_accuracy(behav, port_list=[reward_one, reward_two], lick_threshold=5, by_trials=False)
        circletrack_results['mouse'].append(mouse)
        circletrack_results['sex'].append(sex)
        circletrack_results['group'].append(group)
        circletrack_results['day'].append(idx + 1)
        circletrack_results['session'].append(np.unique(behav['session'])[0])
        circletrack_results['lick_accuracy'].append(pc)
        circletrack_results['rewards'].append(np.sum(behav['water']))
ct_df = pd.DataFrame(circletrack_results)

In [None]:
## Plot session lick accuracy across days: Two-context (control) vs Multi-context (experimental)
# Transition markers at x = 5.5 / 10.5 / 15.5 fall between consecutive blocks of 5 days.
fig = pf.plot_behavior_across_days(
    ct_df,
    x_var='day',
    y_var='lick_accuracy',
    groupby_var=['day', 'group'],
    x_title='Day',
    y_title='Lick Accuracy (%)',
    titles=[''],
    marker_color=ce_colors,
    avg_color=avg_color,
    symbols=['circle', 'circle'],
    plot_transitions=[5.5, 10.5, 15.5],
    transition_color=['darkgrey'] * 3,
    plot_datapoints=True,
    datapoint_type='lines',
    expert_line=False,
    chance=True,
    height=500,
    width=500,
)
fig.update_yaxes(range=[0, 100])
fig.show()

In [None]:
## Plot total rewards per session across days: Two-context (control) vs Multi-context (experimental)
# Transition markers at x = 5.5 / 10.5 / 15.5 fall between consecutive blocks of 5 days.
fig = pf.plot_behavior_across_days(
    ct_df,
    x_var='day',
    y_var='rewards',
    groupby_var=['day', 'group'],
    x_title='Day',
    y_title='Rewards',
    titles=[''],
    marker_color=ce_colors,
    avg_color=avg_color,
    symbols=['circle', 'circle'],
    plot_transitions=[5.5, 10.5, 15.5],
    transition_color=['darkgrey'] * 3,
    plot_datapoints=True,
    datapoint_type='lines',
    expert_line=False,
    chance=False,
    height=500,
    width=500,
)
fig.show()

### Lick accuracy across trials.

In [None]:
# Trial-by-trial lick accuracy per session, optionally binned with ctb.bin_data
# (presumably averaging consecutive groups of `bin_size` trials — confirm).
# bin_size <= 0 disables binning. One row per (mouse, session, trial/bin).
bin_size = 4
trial_res = {'mouse': [], 'sex': [], 'group': [], 'day': [], 'session': [], 'session_two': [], 'trial': [], 'lick_acc': []}
for mouse in natsorted(os.listdir(circle_path)):
    mouse_path = pjoin(circle_path, mouse)
    if not os.path.isdir(mouse_path):
        # Skip stray files (e.g. .DS_Store) next to the per-mouse folders.
        continue
    sex = 'Male' if mouse in male_mice else 'Female'
    group = 'Two-context' if mouse in control_mice else 'Multi-context'
    # Natural-sort session files so `day` follows session order, not filesystem order.
    for idx, session in enumerate(natsorted(os.listdir(mouse_path))):
        behav = pd.read_feather(pjoin(mouse_path, session))
        # NOTE(review): unlike the session-level cell, probe rows are NOT filtered out
        # here (no `behav[~behav['probe']]`) — confirm this is intended.
        reward_one, reward_two = behav['reward_one'].unique()[0], behav['reward_two'].unique()[0]
        trial_acc = ctb.lick_accuracy(behav, port_list=[reward_one, reward_two], lick_threshold=5, by_trials=True)
        binned_acc = ctb.bin_data(trial_acc, bin_size) if bin_size > 0 else trial_acc

        # Session labels are looked up once per file rather than once per trial/bin.
        session_id = behav['session'].unique()[0]
        session_two = behav['session_two'].unique()[0]
        for trial, val in enumerate(binned_acc, start=1):
            trial_res['mouse'].append(mouse)
            trial_res['sex'].append(sex)
            trial_res['group'].append(group)
            trial_res['day'].append(idx + 1)
            trial_res['session'].append(session_id)
            trial_res['session_two'].append(session_two)
            trial_res['trial'].append(trial)
            trial_res['lick_acc'].append(val)
trial_df = pd.DataFrame(trial_res)
# Mean and SEM across rows (mice) per trial/bin; columns become a MultiIndex
# with ('lick_acc', 'mean') and ('lick_acc', 'sem').
avg_acc = trial_df.groupby(['group', 'day', 'session', 'trial'], as_index=False).agg({'lick_acc': ['mean', 'sem']})
avg_acc_mf = trial_df.groupby(['group', 'day', 'session', 'sex', 'trial'], as_index=False).agg({'lick_acc': ['mean', 'sem']})

In [None]:
## Plot lick accuracy across trials for each day: Multi-context (experimental) mice
# One panel per day on a 4 x 5 grid titled with experimental_sessions (in order);
# line = mean across mice, shaded band = mean +/- SEM.
n_rows, n_cols = 4, 5
fig = pf.custom_graph_template(x_title='', y_title='', rows=n_rows, columns=n_cols, titles=experimental_sessions,
                               height=1000, width=1000, shared_y=True, shared_x=True)
group = 'Multi-context'
gdata = avg_acc[avg_acc['group'] == group]
# Raw trials per plotted point; bin_size <= 0 means trials were not binned.
trial_step = bin_size if bin_size > 0 else 1
for idx, day in enumerate(gdata['day'].unique()):
    day_data = gdata[gdata['day'] == day]
    # Fill the grid row by row: panels 0-4 -> row 1, 5-9 -> row 2, ...
    row, col = idx // n_cols + 1, idx % n_cols + 1

    # x = first raw trial of each bin (1, 1 + b, 1 + 2b, ...), derived from the bin index
    # itself so there is exactly one x per plotted point. The previous
    # np.arange(1, last * b, b) dropped the last point for b == 1 and raised for b == 0.
    x_data = (day_data['trial'].to_numpy() - 1) * trial_step + 1
    mean_acc = day_data['lick_acc']['mean']
    sem_acc = day_data['lick_acc']['sem']
    upper = mean_acc + sem_acc
    lower = mean_acc - sem_acc

    fig.add_trace(go.Scatter(x=x_data, y=mean_acc, mode='lines', line_color=ce_colors_dict[group], showlegend=False), row=row, col=col)
    # Invisible upper bound followed by the lower bound filled 'tonexty' draws the SEM band.
    fig.add_trace(go.Scatter(x=x_data, y=upper, mode='lines', marker=dict(color=error_color[group]),
                             name='Upper Bound', line=dict(width=0), showlegend=False), row=row, col=col)
    fig.add_trace(go.Scatter(x=x_data, y=lower, mode='lines', marker=dict(color=error_color[group]),
                             name='Lower Bound', line=dict(width=0), showlegend=False, fillcolor=error_color[group], fill='tonexty'), row=row, col=col)

# Dashed reference line at 25% accuracy, drawn in chance_color.
fig.add_hline(y=25, line_width=1, line_dash='dash', line_color=chance_color, opacity=1)
fig.update_yaxes(title='Lick Accuracy (%)', col=1)
fig.update_xaxes(title='Trial', row=n_rows)
fig.update_yaxes(range=[0, 100])
fig.show()

In [None]:
## Plot lick accuracy across trials for each day: Two-context (control) mice
# One panel per day on a 4 x 5 grid titled with control_sessions (in order);
# line = mean across mice, shaded band = mean +/- SEM.
n_rows, n_cols = 4, 5
fig = pf.custom_graph_template(x_title='', y_title='', rows=n_rows, columns=n_cols, titles=control_sessions,
                               height=1000, width=1000, shared_y=True, shared_x=True)
group = 'Two-context'
gdata = avg_acc[avg_acc['group'] == group]
# Raw trials per plotted point; bin_size <= 0 means trials were not binned.
trial_step = bin_size if bin_size > 0 else 1
for idx, day in enumerate(gdata['day'].unique()):
    day_data = gdata[gdata['day'] == day]
    # Fill the grid row by row: panels 0-4 -> row 1, 5-9 -> row 2, ...
    row, col = idx // n_cols + 1, idx % n_cols + 1

    # x = first raw trial of each bin (1, 1 + b, 1 + 2b, ...), derived from the bin index
    # itself so there is exactly one x per plotted point. The previous
    # np.arange(1, last * b, b) dropped the last point for b == 1 and raised for b == 0.
    x_data = (day_data['trial'].to_numpy() - 1) * trial_step + 1
    mean_acc = day_data['lick_acc']['mean']
    sem_acc = day_data['lick_acc']['sem']
    upper = mean_acc + sem_acc
    lower = mean_acc - sem_acc

    fig.add_trace(go.Scatter(x=x_data, y=mean_acc, mode='lines', line_color=ce_colors_dict[group], showlegend=False), row=row, col=col)
    # Invisible upper bound followed by the lower bound filled 'tonexty' draws the SEM band.
    fig.add_trace(go.Scatter(x=x_data, y=upper, mode='lines', marker=dict(color=error_color[group]),
                             name='Upper Bound', line=dict(width=0), showlegend=False), row=row, col=col)
    fig.add_trace(go.Scatter(x=x_data, y=lower, mode='lines', marker=dict(color=error_color[group]),
                             name='Lower Bound', line=dict(width=0), showlegend=False, fillcolor=error_color[group], fill='tonexty'), row=row, col=col)

# Dashed reference line at 25% accuracy, drawn in chance_color.
fig.add_hline(y=25, line_width=1, line_dash='dash', line_color=chance_color, opacity=1)
fig.update_yaxes(title='Lick Accuracy (%)', col=1)
fig.update_xaxes(title='Trial', row=n_rows)
fig.update_yaxes(range=[0, 100])
fig.show()