###  Performance functions used in performance analysis notebooks

In [16]:
%matplotlib inline
%run "C:\Users\anasofiaccruz\Desktop\cingulate_silencing_repo\general_purpose_functions.ipynb"

In [8]:
# Import modules 
import pandas as pd
import numpy as np
import seaborn as sns
import os
import re
import matplotlib.pyplot as plt

### For pooled performance calculation and plots (A1)

In [9]:
def calc_performance_in_group(group):
    
    '''
    Calculate the percentage of correct trials in a group of trials.
    Rows with a missing value in any column are discarded first; the
    remaining rows whose 'outcome' equals 1 count as correct.
    Arg1, group, Pandas DataFrame: the trials of one group (e.g. the frame
          handed over by GroupBy.apply); must contain an 'outcome' column
    Return: float, percentage (0-100) of correct trials, or NaN when no
            complete trial is left in the group
    '''

    valid_trials = group.dropna()
    n_total_trials = len(valid_trials)

    # Dropping incomplete rows can leave the group empty. Report NaN instead
    # of raising ZeroDivisionError so that one empty group does not abort a
    # whole groupby().apply() run.
    if n_total_trials == 0:
        return float('nan')

    n_correct_trials = int((valid_trials['outcome'] == 1).sum())
    group_performance = (n_correct_trials / n_total_trials)*100
    
    return group_performance


def create_pooled_performances_plot(df):
    '''
    Draws one dot per row of df, split by experimental group and stimulation
    condition, and overlays the mean +/- sd of every group/condition pair
    Arg1 - df, Pandas DataFrame with 'group', 'performance' and
           'stim_condition' columns
    '''
    group_order = ['CTRL', 'NPHR']
    condition_palette = ['black', 'royalblue', 'orangered']
    # Readable condition names, used only as the hue of the dots
    condition_names = df['stim_condition'].map({0:'None', 1:'Sample', 3:'Test'})

    sns.set(style="white", context='talk')
    plt.figure(figsize=(5,3), dpi=300)

    # Individual values
    strip_ax = sns.stripplot(
        data=df, x='group', y='performance', order=group_order,
        hue=condition_names, hue_order=['None','Sample', 'Test'],
        palette=condition_palette,
        dodge=2, jitter=.15, alpha=.2, s=7,
    )

    # Means with sd error bars, drawn over the dots
    sns.pointplot(
        data=df, x='group', y='performance', hue='stim_condition',
        order=group_order, palette=condition_palette,
        estimator=np.mean, ci="sd", errwidth=1.4,
        dodge=.55, join=False, markers="o", scale=.8, legend=False,
    )

    # Axis labels, limits and the displayed group names
    axis_settings = dict(
        ylabel='Performance(%) +/- sd.', xlabel='Experimental group',
        ylim=(40, 100), xlim=(-0.6, 1.5),
        xticks=[0,1], xticklabels=['NpHR-', 'NpHR+'],
    )
    strip_ax.set(**axis_settings)

    strip_ax.legend_.remove()
    sns.despine()
    
    
def plot_pooled_performances(df):
    '''
    Computes the pooled performance of every rat under each stimulation
    condition and plots it
    Arg1 - df, Pandas DataFrame with 'group', 'rat', 'session' and
           'stim_condition' columns plus whatever calc_performance_in_group reads
    Return: pooled_perfs, Pandas DataFrame with one 'performance' value per
            group, rat and stim_condition
    '''

    # Performance of each rat in each session under each condition
    session_keys = ['group','rat','session', 'stim_condition']
    per_session = df.groupby(session_keys).apply(calc_performance_in_group)
    per_session = per_session.reset_index()
    per_session = per_session.rename(columns={0:'performance'})

    # Pool the sessions: mean of the session performances of each rat/condition
    rat_keys = ['group', 'rat', 'stim_condition']
    per_rat = per_session.groupby(rat_keys)['performance'].mean()
    pooled_perfs = per_rat.reset_index()

    # Plot the pooled values
    create_pooled_performances_plot(pooled_perfs)

    return pooled_perfs

### To calculate and plot performance variations within each group (A1)

In [15]:
def calculate_performance_differences(df, conditions):
    '''
    Calculates the performance differences between given conditions
    (performance under conditions[1] minus performance under conditions[0])
    for every rat.
    Arg1, df, Pandas DataFrame: with 'group', 'rat', 'stim_condition' and
          'performance' columns
    Arg2, conditions, list: the two stim_condition values to compare
    Return diffs, Pandas DataFrame: columns 'group', 'rat' and 'performance'
          (the difference); rats lacking either condition are left out
    '''
    
    # Use the frame that was passed in. The previous version read a notebook
    # global (pooled_perfs), so it failed or used stale data whenever that
    # variable was missing or different from df.
    columns = ['group', 'rat', 'performance']
    a = df.loc[df['stim_condition'] == conditions[0], columns]
    b = df.loc[df['stim_condition'] == conditions[1], columns]

    # Pair the two conditions on (group, rat) rather than on row position, so
    # the subtraction stays correct when the rows are ordered differently or
    # a rat is missing from one of the conditions.
    paired = a.merge(b, on=['group', 'rat'], how='inner', suffixes=('_a', '_b'))

    diffs = paired[['group', 'rat']].copy()
    diffs['performance'] = paired['performance_b'] - paired['performance_a']
    
    return diffs
def create_performance_variation_plot(df, color):
    '''
    Plot the performance variation of each experimental group: a translucent
    bar (mean with sd error bar) per group, one dot per row of df and a dashed
    line at zero variation.
    Arg1, df, Pandas DataFrame: with 'group' and 'performance' columns
    Arg2, color, str : color of the plot
    '''
    group_order = ['CTRL', 'NPHR']

    sns.set(style='white', context='talk')
    # plt.figure (lowercase) registers a new figure with pyplot so the plots
    # below are drawn on it; plt.Figure only built a detached object that was
    # thrown away, leaving the requested size without effect.
    plt.figure(figsize=(7,7))

    # Add bars (same explicit group order as the dots, so both layers share
    # the x positions that the tick labels below describe)
    sns.barplot(data=df, x='group', y='performance', order=group_order,
                ci='sd', dodge=True, errwidth=3, palette=[color, color], alpha=.3)

    # Add individual dots
    g = sns.stripplot(
        x='group', y='performance', hue='group',
        data=df, order=group_order,
        alpha=.7, s=10,
        palette=[color, color],
    )

    # Add horizontal line (= zero deviation) across the full axes width
    g.axhline(0, linestyle='dashed', color='black')

    # Add specs
    g.set(ylabel='Performance variation', xlabel='Experimental group',
         xticks=[0,1], xticklabels=['NpHR-', 'NpHR+'], xlim=[-.8,1.8], ylim=[-20, 10])

    # A legend is not always created (it can be absent when hue repeats the
    # x variable), so only remove it when it exists.
    if g.legend_ is not None:
        g.legend_.remove()
    sns.despine()
    
    
def plot_performance_variation_within_each_group(df, conditions):
    
    '''
    Computes and plots, for each group, the performance variation between
    two stimulation conditions
    Arg1 - df, pandas DataFrame : containing the pooled performances
    Arg2 - conditions, list : the two conditions to compare. Their order sets
    the sign of the variation (second minus first), so use 0 as first. The
    second one picks the plot color and must be 1 or 3.
    Return diffs, pandas DataFrame : the variations that were plotted
    '''

    # Variation of every rat between the two conditions
    diffs = calculate_performance_differences(df, conditions)

    # Plot color follows the stimulated condition
    condition_colors = {1:'royalblue', 3:'orangered'}
    plot_color = condition_colors[conditions[1]]

    create_performance_variation_plot(diffs, plot_color)

    return diffs
    

### Performance across sessions (A2)

In [3]:
def create_performance_across_sessions_plot(df):
    '''
    Plot the performance across sessions subdivided by group and stim condition.
    One facet is drawn per (group, stim_condition) pair: the thick line is the
    mean with sd error bars, the thin translucent lines are the individual rats.
    Arg1 - df, Pandas DataFrame with 'group', 'rat', 'session_nr',
           'stim_condition' and 'performance' columns
    '''
    condition_labels = {0:'None', 1:'Sample', 3:'Test'}
    color_dict = {0:'black', 1:'royalblue', 3:'orangered'}

    sns.set(style="white", context='talk')

    # hue_order and palette are both derived from the same two dicts, so the
    # mean line of a condition always gets the color used for its individual
    # rats below (previously Sample and Test could end up swapped because the
    # palette order was not tied to the condition order).
    g = sns.relplot(
        kind='line', data=df, col='stim_condition',
        row='group', x='session_nr', y='performance',
        hue=df['stim_condition'].map(condition_labels),
        hue_order=list(condition_labels.values()),
        palette=[color_dict[condition] for condition in condition_labels],
        ci='sd', linewidth=1.5,
        height=4, aspect=1.4, err_style='bars')

    # relplot builds its own figure, so the resolution is set on that figure;
    # calling plt.figure(dpi=300) beforehand only opened an extra empty one.
    g.fig.set_dpi(300)

    # Keys of axes_dict are (row value, column value) = (group, stim_condition)
    for (group_name, condition), ax in g.axes_dict.items():

        if group_name=='CTRL':
            ax.set(title='NpHR-')
        else:
            ax.set(title='NpHR+')

        data=df[(df['stim_condition']==condition)&(df['group']==group_name)]
        if data.empty:
            # Nothing to overlay for a group/condition pair without rows
            continue
        colors = [color_dict[condition]]*len(data['rat'].unique())

        # Individual rats, all in the color of the facet's condition
        sns.lineplot(ax=ax, data=data, x='session_nr', y='performance',
                     hue='rat', palette=colors, legend=False, alpha=.3, linewidth=1)
            
    g.set(
        xlim=(0,16), ylim=(0,110),
        ylabel='Avg.Performance(%)+/-sd', xlabel='Session',
        xticks=range(1,16,2), xticklabels=range(1,16,2))
    g._legend.set(bbox_to_anchor=[1.05,.8], title='Condition')
    sns.despine()

def plot_performances_across_sessions(df):
    '''
    Calculates the performance of every rat in each numbered session and
    stimulation condition, and plots them subdivided by group and condition
    Arg1, df, Pandas DataFrame: with 'group', 'rat', 'session_nr' and
          'stim_condition' columns plus whatever calc_performance_in_group reads
    Returns: Pandas DataFrame with one 'performance' value per group, rat,
             session_nr and stim_condition
    '''
    grouping_keys = ['group','rat','session_nr', 'stim_condition']

    # One performance value per rat, session and stimulation condition
    perf_series = df.groupby(grouping_keys).apply(calc_performance_in_group)
    perf_table = perf_series.reset_index().rename(columns={0:'performance'})

    # Rows whose performance is exactly 0 are left out of the plot and the result
    nonzero_rows = perf_table['performance'] != 0
    session_perfs = perf_table[nonzero_rows]

    # Plot the performances
    create_performance_across_sessions_plot(session_perfs)

    return session_perfs


### Performance variation across sessions

In [19]:
def plot_performance_variation_across_sessions(df, y_col, color):
    '''
    Plot a performance variation across sessions, one facet per group: thin
    translucent lines are the individual rats, the thick line is the mean
    with sd error bars, and a dashed line marks zero variation.
    Arg1, df, Pandas DataFrame: with 'group', 'rat', 'session_nr' columns and
          the column named by y_col
    Arg2, y_col, str: name of the column holding the variation to plot
    Arg3, color, str: color used for every line
    '''

    sns.set(style="white", context='talk')

    # One palette entry per rat actually present in df (the count used to be
    # hard-coded to 17, which only fits a data set with exactly 17 rats).
    n_rats = df['rat'].nunique()

    g = sns.relplot(
        data=df, 
        kind='line', col='group', hue='rat',
        x='session_nr', y=y_col,  err_style='bars',
        linewidth=1, col_order=['CTRL','NPHR'], alpha=.2,
        height=4, aspect=1.3, palette=[color]*n_rats, legend=False
    )

    # relplot builds its own figure, so the resolution is set on that figure;
    # calling plt.figure(dpi=300) beforehand only opened an extra empty one.
    g.fig.set_dpi(300)

    for group, ax in g.axes_dict.items():
        # Group mean computed from the frame passed in; the previous version
        # read a notebook global (performance_diff) instead of df.
        sns.lineplot(data=df[df['group']==group], 
                     x='session_nr', y=y_col,
                     estimator=np.mean, ci='sd', err_style='bars', 
                     color=color, linewidth=2, ax=ax)

        # Zero-variation reference line
        ax.axhline(linewidth=1, color='black', linestyle='dashed')

    g.set(xlabel='Session', ylabel='Performance variation (%)+/-sd', 
          xticks=range(1,16,2), xticklabels=range(1,16,2),
          xlim=[0,16], ylim=[-70,50])

    # Facets follow col_order: first CTRL, then NPHR
    facet_axes = g.axes.flatten()
    facet_axes[0].set(title='NpHR-')
    facet_axes[1].set(title='NpHR+')
    sns.despine()