In [2]:
%matplotlib qt5
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import colors, cm
from utils.mean_trace_utils import make_dates_pretty

In [32]:
def get_initiation_times(behavioural_stats_for_day, num_trials):
    trial_start_times = []
    for trial_num in range(num_trials):
        if 'Port2In' in behavioural_stats_for_day.iloc[trial_num]['TrialEvents']:
            initiation_time = behavioural_stats_for_day.iloc[trial_num]['TrialEvents']['Port2In']
            if type(initiation_time) == np.ndarray:
                initiation_time = initiation_time[0]
            trial_start_times.append(initiation_time)
    return(trial_start_times)

In [57]:
def get_session_stats_saturation(behavioural_stats, date):
    date = make_dates_pretty(date)[0]
    trials = behavioural_stats[behavioural_stats['SessionTime'].str.contains(date)]
    duration = trials.iloc[-1]['TrialStartTimestamp'] - trials.iloc[0]['TrialStartTimestamp']
    number_of_trials = trials.shape[0]
    trial_rate = number_of_trials/ (duration/60/60)
    if trials.iloc[0]['Contingency'] == 2:
        trial_side_right = 1
        trial_side_left = 2
    else: 
        trial_side_right = 2
        trial_side_left = 1
    rightward_choices = trials.loc[(trials['FirstPoke']==2)]
    rightward_persistance = rightward_choices.loc[(rightward_choices['TrialSide']==trial_side_right)].shape[0]/ trials.loc[(trials['TrialSide']==trial_side_right)].shape[0]
    leftward_choices = trials.loc[(trials['FirstPoke']==1)]
    leftward_persistance = leftward_choices.loc[(leftward_choices['TrialSide']==trial_side_left)].shape[0]/ trials.loc[(trials['TrialSide']==trial_side_left)].shape[0]
    habit_persistance = np.mean([leftward_persistance, rightward_persistance])
    response_times = trials['ResponseTime'].values
    initiation_times = get_initiation_times(trials, number_of_trials)
    return(trial_rate, habit_persistance, response_times, initiation_times)

In [124]:
def all_metrics_one_mouse_and_plot(mouse, saturation_day, day_before):
    saturation_date = [saturation_day]
    day_before_date = [day_before]
    BpodProtocol = '/Two_Alternative_Choice/'
    GeneralDirectory = 'W:/photometry_2AC/bpod_data/'
    DFfile = GeneralDirectory + mouse + BpodProtocol + 'Data_Analysis/' + mouse + '_dataframe.pkl'
    behavioural_stats = pd.read_pickle(DFfile)
    
    (saturation_trial_rate, saturation_rightward_choices, saturation_response_times, saturation_initiation_times)=get_session_stats_saturation(behavioural_stats, saturation_date) 
    (day_before_trial_rate, day_before_rightward_choices, day_before_response_times, day_before_initiation_times) = get_session_stats_saturation(behavioural_stats, day_before_date)
    
    fig, axs = plt.subplots(2, ncols=2, figsize=(10, 8))
    fig.subplots_adjust(hspace=0.5, wspace=0.2)
    lower_limit = min(min(saturation_initiation_times), min(day_before_initiation_times))
    upper_limit = max(max(saturation_initiation_times), max(day_before_initiation_times))
    axs[0, 0].set_title('Time to initiate trial')
    saturation_initiation_histogram = axs[0, 0].hist(saturation_initiation_times, 200, range=(lower_limit, upper_limit), alpha=0.5,density=True)
    day_before_initiation_histogram = axs[0, 0].hist(day_before_initiation_times, 200,range=(lower_limit, upper_limit), alpha=0.5,density=True)
    
    lower_limit = min(min(saturation_response_times), min(day_before_response_times))
    upper_limit = max(max(saturation_response_times), max(day_before_response_times))
    axs[0, 1].set_title('Response time')
    saturation_response_histogram = axs[0, 1].hist(saturation_response_times, 100, range=(lower_limit, upper_limit), alpha=0.4,density=True)
    day_before_response_histogram = axs[0, 1].hist(day_before_response_times, 100, range=(lower_limit, upper_limit), alpha=0.4, density=True)
    
    axs[1, 0].set_title('Trials per hour')
    axs[1, 0].bar([0,1], [day_before_trial_rate, saturation_trial_rate])
    
    axs[1, 1].set_title('Proportion of choices to original contingency')
    axs[1, 1].bar([0,1], [day_before_rightward_choices, saturation_rightward_choices])
    
    saturation_data = {'Trial rate': saturation_trial_rate, 'Median initiation time': np.median(saturation_initiation_times),
                      'Median reaction time': np.median(saturation_initiation_times), 'Rightward choices': saturation_rightward_choices}
    day_before_data = {'Trial rate': day_before_trial_rate, 'Median initiation time': np.median(day_before_initiation_times),
                      'Median reaction time': np.median(day_before_initiation_times), 'Rightward choices': day_before_rightward_choices}

    return(saturation_data, day_before_data)

In [70]:
def get_number_of_trials_before_session(behavioural_stats, date):
    date = make_dates_pretty(date)[0]
    trials = behavioural_stats[behavioural_stats['SessionTime'].str.contains(date)]
    trials_before_session = trials.index[0] -1
    return(trials_before_session)

In [69]:
def all_metrics_one_mouse(mouse, saturation_day, day_before):
    saturation_date = [saturation_day]
    day_before_date = [day_before]
    BpodProtocol = '/Two_Alternative_Choice/'
    GeneralDirectory = 'W:/photometry_2AC/bpod_data/'
    DFfile = GeneralDirectory + mouse + BpodProtocol + 'Data_Analysis/' + mouse + '_dataframe.pkl'
    behavioural_stats = pd.read_pickle(DFfile)
    
    (saturation_trial_rate, saturation_rightward_choices, saturation_response_times, saturation_initiation_times)=get_session_stats_saturation(behavioural_stats, saturation_date) 
    (day_before_trial_rate, day_before_rightward_choices, day_before_response_times, day_before_initiation_times) = get_session_stats_saturation(behavioural_stats, day_before_date)
    number_of_trials_pre_saturation = get_number_of_trials_before_session(behavioural_stats, saturation_date)
    
    saturation_data = {'Trial rate': saturation_trial_rate, 'Median initiation time': np.median(saturation_initiation_times),
                      'Median reaction time': np.median(saturation_response_times), 'Rightward choices': saturation_rightward_choices, 'Trials done before session': number_of_trials_pre_saturation}
    day_before_data = {'Trial rate': day_before_trial_rate, 'Median initiation time': np.median(day_before_initiation_times),
                      'Median reaction time': np.median(day_before_response_times), 'Rightward choices': day_before_rightward_choices}

    return(saturation_data, day_before_data)

In [52]:
class behaviouralStats(object):
    def __init__(self, mouse_id, saturation_day, day_before_saturation, contingency_switch_day, day_before_contingency_switch):
        self.mouse = mouse_id
        self.saturation_day = saturation_day
        self.day_before_saturation = day_before_saturation
        self.contingency_switch_day = contingency_switch_day
        self.day_before_contingency_switch = day_before_contingency_switch
    def addSaturationStats(self):
        self.saturation_stats, self.day_before_saturation_stats = all_metrics_one_mouse(self.mouse, self.saturation_day, self.day_before_saturation)
    def addContingencyStats(self):
        self.contingency_stats, self.day_before_contingency_stats = all_metrics_one_mouse(self.mouse, self.contingency_switch_day, self.day_before_contingency_switch)
        

In [58]:
mouse_days = {'SNL_photo16':{'saturation_day': '20200318', 'day_before_saturation': '20200317', 
                             'contingency_switch_day': '20200320',  'day_before_contingency_switch': '20200317'},
              'SNL_photo17':{'saturation_day': '20200311', 'day_before_saturation': '20200310', 
                             'contingency_switch_day': '20200313',  'day_before_contingency_switch': '20200310'},
              'SNL_photo18':{'saturation_day': '20200318', 'day_before_saturation': '20200317', 
                             'contingency_switch_day': '20200320',  'day_before_contingency_switch': '20200317'},
              'SNL_photo15':{'saturation_day': '20200312', 'day_before_saturation': '20200311', 
                             'contingency_switch_day': '20200316',  'day_before_contingency_switch': '20200315'},
              'SNL_photo12':{'saturation_day': '20200207', 'day_before_saturation': '20200205', 
                             'contingency_switch_day': '20200307',  'day_before_contingency_switch': '20200311'},
              'SNL_photo19':{'saturation_day': '20200318', 'day_before_saturation': '20200317', 
                             'contingency_switch_day': '20200320',  'day_before_contingency_switch': '20200317'},
              'SNL_photo20':{'saturation_day': '20200318', 'day_before_saturation': '20200317', 
                             'contingency_switch_day': '20200320',  'day_before_contingency_switch': '20200317'}}

In [38]:
mouse_days = {'SNL_photo12':{'saturation_day': '20200207', 'day_before_saturation': '20200205', 
                             'contingency_switch_day': '20200307',  'day_before_contingency_switch': '20200311'}}

In [74]:
all_mouse_stats = []
for mouse, days in mouse_days.items():
    mouse_stats = behaviouralStats(mouse, days['saturation_day'], days['day_before_saturation'],
                                   days['contingency_switch_day'], days['day_before_contingency_switch'])
    mouse_stats.addSaturationStats()
    mouse_stats.addContingencyStats()
    all_mouse_stats.append(mouse_stats)

Plotting


In [116]:
num_types = len(all_mouse_stats)
colours = cm.Set2(np.linspace(0, 0.8, num_types))
fig, axs = plt.subplots(2, ncols=4, figsize=(12, 10))
fig.subplots_adjust(hspace=0.6, wspace=0.5)
keys = ['Median reaction time', 'Median initiation time', 'Trial rate', 'Rightward choices']
legends = ['Median reaction time (s)', 'Median initiation time (s)', 'Trials per hour', 'Proportion of turns to original contingency']
fig.suptitle('Change in behaviour caused by saturation (top), and contingency switch (bottom)')

for mouse_num, mouse_stats in enumerate(all_mouse_stats):
    for key_num, ax in enumerate(axs[0, :]): 
        change_in_metric = mouse_stats.day_before_saturation_stats[keys[key_num]] - mouse_stats.saturation_stats[keys[key_num]]
        ax.scatter(mouse_stats.saturation_stats['Trials done before session'], change_in_metric, color=colours[mouse_num], alpha=0.8, label=mouse_stats.mouse)
        ax.axhline([0], color='gray')
        ax.set_xlabel('Trials done before session')
        ax.set_ylabel('Absolute change')
        plt.setp(ax.xaxis.get_majorticklabels(), rotation=20)
        ax.set_title(legends[key_num])
        
    for key_num, ax in enumerate(axs[1, :]): 
        change_in_metric = mouse_stats.day_before_contingency_stats[keys[key_num]] - mouse_stats.contingency_stats[keys[key_num]]
        ax.scatter(mouse_stats.contingency_stats['Trials done before session'], change_in_metric, color=colours[mouse_num], alpha=0.8, label=mouse_stats.mouse)    
        ax.axhline([0], color='gray')
        ax.set_xlabel('Trials done before session')
        ax.set_ylabel('Absolute change')
        plt.setp(ax.xaxis.get_majorticklabels(), rotation=20)
        ax.set_title(legends[key_num])
        
handles, labels = axs[0,0].get_legend_handles_labels()
fig.legend(handles, labels, loc='lower center', ncol=len(labels))
plt.show() 

In [118]:
num_types = len(all_mouse_stats)
colours = cm.Set2(np.linspace(0, 0.8, num_types))
fig, axs = plt.subplots(2, ncols=4, figsize=(12, 10))
fig.subplots_adjust(hspace=0.6, wspace=0.5)
keys = ['Median reaction time', 'Median initiation time', 'Trial rate', 'Rightward choices']
legends = ['Median reaction time (s)', 'Median initiation time (s)', 'Trials per hour', 'Proportion of turns to original contingency']
fig.suptitle('Percentage change in behaviour caused by saturation (top), and contingency switch (bottom)')

for mouse_num, mouse_stats in enumerate(all_mouse_stats):
    for key_num, ax in enumerate(axs[0, :]): 
        change_in_metric = (mouse_stats.day_before_saturation_stats[keys[key_num]] - mouse_stats.saturation_stats[keys[key_num]])/mouse_stats.day_before_saturation_stats[keys[key_num]] *-100
        ax.scatter(mouse_stats.saturation_stats['Trials done before session'], change_in_metric, color=colours[mouse_num], alpha=0.8, label=mouse_stats.mouse)
        ax.axhline([0], color='gray')
        ax.set_xlabel('Trials done before session')
        ax.set_ylabel('Percentage change')
        plt.setp(ax.xaxis.get_majorticklabels(), rotation=20)
        ax.set_title(legends[key_num])

        
    for key_num, ax in enumerate(axs[1, :]): 
        change_in_metric = (mouse_stats.day_before_contingency_stats[keys[key_num]] - mouse_stats.contingency_stats[keys[key_num]])/mouse_stats.day_before_contingency_stats[keys[key_num]] * -100
        ax.scatter(mouse_stats.contingency_stats['Trials done before session'], change_in_metric, color=colours[mouse_num], alpha=0.8, label=mouse_stats.mouse)    
        ax.axhline([0], color='gray')
        ax.set_xlabel('Trials done before session')
        ax.set_ylabel('Percentage change')
        plt.setp(ax.xaxis.get_majorticklabels(), rotation=20)
        ax.set_title(legends[key_num])

handles, labels = axs[0,0].get_legend_handles_labels()
fig.legend(handles, labels, loc='lower center', ncol=len(labels))
plt.show() 

In [119]:
num_types = len(all_mouse_stats)
colours = cm.Set2(np.linspace(0, 0.8, num_types))
fig, axs = plt.subplots(2, ncols=4, figsize=(12, 10))
fig.subplots_adjust(hspace=0.6, wspace=0.5)
fig.suptitle('Behaviour during and before saturation (top), and contingency switch (bottom) sessions')
for mouse_num, mouse_stats in enumerate(all_mouse_stats):
    
    axs[0, 0].set_title('Response time')
    axs[0, 0].scatter(1, mouse_stats.saturation_stats['Median reaction time'], alpha=0.8, color=colours[mouse_num], label = mouse_stats.mouse)
    axs[0, 0].scatter(0, mouse_stats.day_before_saturation_stats['Median reaction time'], alpha=0.8, color=colours[mouse_num])
    axs[0, 0].plot([0, 1], [mouse_stats.day_before_saturation_stats['Median reaction time'], mouse_stats.saturation_stats['Median reaction time']], color=colours[mouse_num])
    axs[0, 0].set_xticks([0, 1])
    axs[0, 0].set_xticklabels(['normal', 'saturation'], rotation=20)
    axs[0, 0].set_xlim([-0.3, 1.3])
    axs[0, 0].set_ylabel('Median response time (s)')
    
    axs[0, 1].set_title('Initiation time')
    axs[0, 1].scatter(1, mouse_stats.saturation_stats['Median initiation time'], alpha=0.8, color=colours[mouse_num])
    axs[0, 1].scatter(0, mouse_stats.day_before_saturation_stats['Median initiation time'], alpha=0.8, color=colours[mouse_num])
    axs[0, 1].plot([0, 1], [mouse_stats.day_before_saturation_stats['Median initiation time'], mouse_stats.saturation_stats['Median initiation time']], color=colours[mouse_num])
    axs[0, 1].set_xticks([0, 1])
    axs[0, 1].set_xticklabels(['normal', 'saturation'], rotation=20)
    axs[0, 1].set_xlim([-0.3, 1.3])
    axs[0, 1].set_ylabel('Median time to initiate trial (s)')
    
    axs[0, 2].set_title('Trials per hour')
    axs[0, 2].scatter(1, mouse_stats.saturation_stats['Trial rate'], alpha=0.8, color=colours[mouse_num])
    axs[0, 2].scatter(0, mouse_stats.day_before_saturation_stats['Trial rate'], alpha=0.8, color=colours[mouse_num])
    axs[0, 2].plot([0, 1], [mouse_stats.day_before_saturation_stats['Trial rate'], mouse_stats.saturation_stats['Trial rate']], color=colours[mouse_num])
    axs[0, 2].set_xticks([0, 1])
    axs[0, 2].set_xticklabels(['normal', 'saturation'], rotation=20)
    axs[0, 2].set_xlim([-0.3, 1.3])
    axs[0, 2].set_ylabel('Trials per hour')
    
    axs[0, 3].set_title('Habit persistence')
    axs[0, 3].scatter(1, mouse_stats.saturation_stats['Rightward choices'], alpha=0.8, color=colours[mouse_num])
    axs[0, 3].scatter(0, mouse_stats.day_before_saturation_stats['Rightward choices'], alpha=0.8, color=colours[mouse_num])
    axs[0, 3].plot([0, 1], [mouse_stats.day_before_saturation_stats['Rightward choices'], mouse_stats.saturation_stats['Rightward choices']], color=colours[mouse_num])
    axs[0, 3].set_xticks([0, 1])
    axs[0, 3].set_xticklabels(['normal', 'saturation'], rotation=20)
    axs[0, 3].set_xlim([-0.3, 1.3])
    axs[0, 3].set_ylabel('Proportion of turns to original contingency')
    
    axs[1, 0].scatter(1, mouse_stats.contingency_stats['Median reaction time'], alpha=0.8, color=colours[mouse_num])
    axs[1, 0].scatter(0, mouse_stats.day_before_contingency_stats['Median reaction time'], alpha=0.8, color=colours[mouse_num])
    axs[1, 0].plot([0, 1], [mouse_stats.day_before_contingency_stats['Median reaction time'], mouse_stats.contingency_stats['Median reaction time']], color=colours[mouse_num])
    axs[1, 0].set_xticks([0, 1])
    axs[1, 0].set_xticklabels(['normal', 'contingency switch'], rotation=20)
    axs[1, 0].set_xlim([-0.3, 1.3])
    axs[1, 0].set_ylabel('Median response time (s)')
    
    axs[1, 1].scatter(1, mouse_stats.contingency_stats['Median initiation time'], alpha=0.8, color=colours[mouse_num])
    axs[1, 1].scatter(0, mouse_stats.day_before_contingency_stats['Median initiation time'], alpha=0.8, color=colours[mouse_num])
    axs[1, 1].plot([0, 1], [mouse_stats.day_before_contingency_stats['Median initiation time'], mouse_stats.contingency_stats['Median initiation time']], color=colours[mouse_num])
    axs[1, 1].set_xticks([0, 1])
    axs[1, 1].set_xticklabels(['normal', 'contingency switch'], rotation=20)
    axs[1, 1].set_xlim([-0.3, 1.3])
    axs[1, 1].set_ylabel('Median time to initiate trial (s)')
    
    axs[1, 2].scatter(1, mouse_stats.contingency_stats['Trial rate'], alpha=0.8, color=colours[mouse_num])
    axs[1, 2].scatter(0, mouse_stats.day_before_contingency_stats['Trial rate'], alpha=0.8, color=colours[mouse_num])
    axs[1, 2].plot([0, 1], [mouse_stats.day_before_contingency_stats['Trial rate'], mouse_stats.contingency_stats['Trial rate']], color=colours[mouse_num])
    axs[1, 2].set_xticks([0, 1])
    axs[1, 2].set_xticklabels(['normal', 'contingency switch'], rotation=20)
    axs[1, 2].set_xlim([-0.3, 1.3])
    axs[1, 2].set_ylabel('Trials per hour')
    
    axs[1, 3].scatter(1, mouse_stats.contingency_stats['Rightward choices'], alpha=0.8, color=colours[mouse_num])
    axs[1, 3].scatter(0, mouse_stats.day_before_contingency_stats['Rightward choices'], alpha=0.8, color=colours[mouse_num])
    axs[1, 3].plot([0, 1], [mouse_stats.day_before_contingency_stats['Rightward choices'], mouse_stats.contingency_stats['Rightward choices']], color=colours[mouse_num])
    axs[1, 3].set_xticks([0, 1])
    axs[1, 3].set_xticklabels(['normal', 'contingency switch'], rotation=20)
    axs[1, 3].set_xlim([-0.3, 1.3])
    axs[1, 3].set_ylabel('Proportion of turns to original contingency')
    
handles, labels = axs[0,0].get_legend_handles_labels()
fig.legend(handles, labels, loc='lower center', ncol=len(labels))
plt.show()

In [125]:
all_metrics_one_mouse_and_plot('SNL_photo12', '20200207', '20200205')

({'Trial rate': 28.650801666591263,
  'Median initiation time': 57.78425,
  'Median reaction time': 57.78425,
  'Rightward choices': 0.6235294117647059},
 {'Trial rate': 835.2243832004373,
  'Median initiation time': 1.866,
  'Median reaction time': 1.866,
  'Rightward choices': 0.941226073024707})

In [14]:
mouse = 'SNL_photo12'
BpodProtocol = '/Two_Alternative_Choice/'
GeneralDirectory = 'W:/photometry_2AC/bpod_data/'
DFfile = GeneralDirectory + mouse + BpodProtocol + 'Data_Analysis/' + mouse + '_dataframe.pkl'
behavioural_stats = pd.read_pickle(DFfile)

In [40]:
date = make_dates_pretty(['20200207'])[0]
trials = behavioural_stats[behavioural_stats['SessionTime'].str.contains(date)]

In [43]:
trials['ResponseTime'].values

array([ 0.4752,  7.9246,  0.5839,  1.1635,  0.9405,  0.6166, 10.    ,
        1.0532,  1.6364, 10.    ,  0.7407,  0.8036,  0.6444,  0.6819,
        0.5042,  4.1901,  0.4982,  0.4571,  0.3623, 10.    ,  0.494 ,
        0.447 ,  0.8267,  0.6866,  0.6393,  0.7034, 10.    ,  0.5431,
       10.    ,  2.1687, 10.    ,  0.7792])

In [68]:
trials.index

Int64Index([8413, 8414, 8415, 8416, 8417, 8418, 8419, 8420, 8421, 8422, 8423,
            8424, 8425, 8426, 8427, 8428, 8429, 8430, 8431, 8432, 8433, 8434,
            8435, 8436, 8437, 8438, 8439, 8440, 8441, 8442, 8443, 8444],
           dtype='int64')

In [261]:
behavioural_stats.loc[(behavioural_stats['FirstPoke']==1)].shape[0]/behavioural_stats.loc[(behavioural_stats['TrialSide']==1)].shape[0]

1.131448617416426

In [None]:
trial_data.loc[(trial_data['State type'] == params.state)]