#### <font color='crimson'> Inference - Normality testing with the Shapiro-Wilk test</font>

In [2]:
from scipy.stats import shapiro
def normality_testing_shapiro(df, groupby_cols, var):
    """
    Run the Shapiro-Wilk normality test on one column, separately for each group.

    arg1, df <Pandas DataFrame>: the dataframe holding the data to test
    arg2, groupby_cols <list>: columns used to split df into groups
    arg3, var <str>: name of the column in df whose distribution is tested

    Returns a Pandas Series built from a {group key: (statistic, p-value)}
    dictionary, i.e. one (statistic, p-value) tuple per group.
    """
    grouped_values = df.groupby(groupby_cols)[var]

    def _shapiro_pair(values):
        # Unpack explicitly so a plain tuple is stored for each group.
        statistic, p_value = shapiro(values)
        return (statistic, p_value)

    return pd.Series(
        {key: _shapiro_pair(values) for key, values in grouped_values}
    )


#### <font color='crimson'> Inference - Testing with the Mann-Whitney U test</font>

In [4]:
from scipy.stats import mannwhitneyu
def inference_testing_mannwhitneyu(a, b, var):
    """
    Pairwise two-sided Mann-Whitney U tests between matched DataFrames.

    The i-th DataFrame of a is compared with the i-th DataFrame of b on the
    column var. The 'group', 'stim_condition' and 'outcome' labels reported for
    each side are read from the first row of the respective DataFrame.

    arg1, a <iterable of Pandas DataFrames>: first sample of each comparison
    arg2, b <iterable of Pandas DataFrames>: second sample of each comparison
    arg3, var <str>: name of the column to compare

    Returns a Pandas DataFrame with one row per comparison and the columns
    a_group, b_group, a_condition, b_condition, a_outcome, b_outcome, u, p.
    """
    # (column for the a side, column for the b side, source column to read)
    labelled_fields = (
        ('a_group', 'b_group', 'group'),
        ('a_condition', 'b_condition', 'stim_condition'),
        ('a_outcome', 'b_outcome', 'outcome'),
    )

    results = pd.DataFrame()
    # zip pairs the inputs positionally and stops at the shorter of a and b.
    for row, (left, right) in enumerate(zip(a, b)):
        for left_col, right_col, source_col in labelled_fields:
            results.loc[row, left_col] = left[source_col].iloc[0]
            results.loc[row, right_col] = right[source_col].iloc[0]

        u_stat, p_value = mannwhitneyu(
            x=left[var], y=right[var], alternative='two-sided'
        )
        results.loc[row, 'u'] = u_stat
        results.loc[row, 'p'] = p_value

    return results

#### <font color='crimson'> Inference - Testing within group with the Kruskal-Wallis test</font>

In [None]:
 from scipy.stats import kruskal
def inference_testing_kruskal_within_group(df, var):
    """
    Kruskal-Wallis H-test across stimulation conditions, run separately for the
    'CTRL' and the 'NPHR' group.

    Within each group, the values of var under stim_condition 0, 1 and 3 are
    compared. The null hypothesis is that the population medians of the three
    conditions are equal, i.e. that the illumination condition has no effect
    on var within that experimental group.

    arg1, df, Pandas DataFrame - DataFrame containing the data
    arg2, var, str - Column name of the dependent variable

    Returns a Pandas DataFrame with one row per group and the columns
    'group', 'h' (test statistic) and 'p' (p-value).
    """
    stim_conditions = (0, 1, 3)
    results = pd.DataFrame()

    for row, group_label in enumerate(('CTRL', 'NPHR')):
        subset = df[df['group'] == group_label]
        samples = [
            subset.loc[subset['stim_condition'] == condition, var]
            for condition in stim_conditions
        ]

        # The label is read back from the data, so an absent group raises here.
        results.loc[row, 'group'] = subset['group'].iloc[0]
        statistic, p_value = kruskal(*samples)
        results.loc[row, 'h'] = statistic
        results.loc[row, 'p'] = p_value

    return results

#### <font color='skyblue'> Visualization - Plot IQR and median across sessions</font>

In [3]:
def plot_iqr_and_median_across_sessions(df):
    """
    Draw the 25th percentile, the median (dotted) and the 75th percentile
    across sessions, with one facet per group (row) and stim_condition (column).

    arg1, df, Pandas DataFrame - Contains the data to plot. The columns 'group',
    'stim_condition', 'session_nr', 'iqr_25', 'median' and 'iqr_75' are read.

    Returns the grid object created by sns.relplot.
    """
    condition_colors = {0: 'black', 1: 'royalblue', 3: 'orangered'}

    sns.set(style='white', context='talk')
    # The grid itself draws the lower quartile; the other two lines are added per facet.
    g = sns.relplot(
        kind='line', data=df,
        row='group', col='stim_condition',
        x='session_nr', y='iqr_25',
        hue='stim_condition',
        palette=['black', 'royalblue', 'orangered'],
        ci=None, legend=False,
    )

    # axes_dict keys are (row value, column value) = (group, stim_condition).
    for (group_name, condition), ax in g.axes_dict.items():
        facet_color = condition_colors[condition]
        in_facet = (df['group'] == group_name) & (df['stim_condition'] == condition)
        facet_data = df[in_facet]

        sns.lineplot(
            ax=ax, data=facet_data, x='session_nr', y='median',
            linestyle=':', color=facet_color, legend=False, ci=None,
        )
        sns.lineplot(
            ax=ax, data=facet_data, x='session_nr', y='iqr_75',
            color=facet_color, legend=False, ci=None,
        )

    plt.ylim([1, 5.5])
    sns.despine()
    return g

#### <font color='skyblue'> Visualization - Plot individual medians and IQRs across sessions</font>

In [None]:
def plot_individual_metrics_across_sessions(df, var):
    '''
    Plots a given metric per rat across sessions (light gray lines), overlaid
    with the across-rat median per session in the color of the stimulation
    condition. One facet per group (row) and stim_condition (column).

    arg1, df, Pandas DataFrame - Contains the data to plot. The columns 'group',
    'stim_condition', 'session_nr', 'rat' and var are read.
    arg2, var, str - Column name of the variable to plot

    Returns the grid object created by sns.relplot.
    '''
    plt.figure(dpi=300)
    sns.set(style='white', context='talk')

    # One gray entry per rat actually present, so the palette length always
    # matches the number of hue levels instead of assuming exactly 17 rats.
    # NOTE(review): assumes 'rat' is not a categorical with unused categories.
    n_rats = df['rat'].nunique()
    g = sns.relplot(kind='line', data=df, row='group', col='stim_condition',
                    x='session_nr', y=var, hue='rat', ci=None, legend=False,
                    height=4, aspect=1.3, linewidth=1,
                    palette=['lightgray'] * n_rats)

    colors = {0: 'black', 1: 'royalblue', 3: 'orangered'}
    # axes_dict keys are (row value, column value) = (group, stim_condition).
    for (group_name, condition), ax in g.axes_dict.items():
        data = df[(df['stim_condition'] == condition) & (df['group'] == group_name)]
        sns.lineplot(ax=ax, data=data, x='session_nr', y=var, estimator=np.median,
                     ci=None, color=colors[condition], legend=False)

    sns.despine()
    return g

#### <font color='skyblue'> Visualization - Plot distributions</font>

In [4]:
def plot_latency_distributions_within_group(df, plot_type, var, colors):
    '''
    Plot the distribution of var within each experimental group: one column
    per group ('CTRL', 'NPHR') and one hue level per stim_condition.

    arg1, df, Pandas DataFrame - contains the data;
    arg2, plot_type, str - Type of distribution plot, either 'hist' or 'kde';
    arg3, var, str - The column name of the variable to be plotted;
    arg4, colors, list - List of colors to use in hue

    Returns the grid object created by sns.displot.
    Raises ValueError if plot_type is neither 'hist' nor 'kde'.
    '''
    # Fail early with a clear message; an unsupported plot_type would otherwise
    # only surface further down as an unbound local `g`.
    if plot_type not in ('hist', 'kde'):
        raise ValueError(
            "plot_type must be 'hist' or 'kde', got {!r}".format(plot_type)
        )

    sns.set(style='white', context='talk')
    plt.figure(dpi=300)

    if plot_type == 'hist':
        # Bin edges every 0.1 from 0 to 15.
        # NOTE(review): the earlier comment said 8 bins per second, which would
        # be a 0.125 step - confirm the intended bin width.
        g = sns.displot(kind=plot_type, data=df, col='group',
                        x=var, stat='probability',
                        hue='stim_condition',
                        col_order=['CTRL', 'NPHR'], element='step',
                        palette=colors,
                        bins=np.arange(0, 15, .1), alpha=0, height=5, linewidth=2, aspect=1.3)
    else:
        g = sns.displot(kind=plot_type, data=df, col='group',
                        x=var,
                        hue='stim_condition',
                        col_order=['CTRL', 'NPHR'],
                        palette=colors, alpha=.5,
                        height=5, linewidth=2, aspect=1.3)

    g._legend.set(bbox_to_anchor=(.9, .7), title='Condition')
    g.tight_layout()
    sns.despine()
    return g

def plot_latency_distributions_within_condition(df, plot_type, var):
    '''
    Plot the distribution of var within each stimulation condition: one column
    per stim_condition (0, 1, 3, titled "None", "Sample", "Test") and one hue
    level per experimental group.

    arg1, df, Pandas DataFrame - contains the data; the columns 'group',
    'stim_condition' and var are read;
    arg2, plot_type, str - Type of distribution plot, either 'hist' or 'kde';
    arg3, var, str - The column name of the variable to be plotted.

    Returns the grid object created by sns.displot.
    Raises ValueError if plot_type is neither 'hist' nor 'kde'.
    '''
    # Fail early with a clear message; an unsupported plot_type would otherwise
    # only surface further down as an unbound local `g`.
    if plot_type not in ('hist', 'kde'):
        raise ValueError(
            "plot_type must be 'hist' or 'kde', got {!r}".format(plot_type)
        )

    sns.set(style='white', context='poster')
    plt.figure(dpi=300)

    # Hue labels come from the DataFrame that was passed in, not from a
    # notebook-level variable, so they always describe the rows being plotted.
    # 'NpHR- ' uses the same casing as 'NpHR+' and as the ECDF plot's labels.
    group_labels = df['group'].map({'CTRL': 'NpHR- ', 'NPHR': 'NpHR+'})

    if plot_type == 'hist':
        # Bin edges every 0.125 from 0 to 15 (8 bins per unit of var).
        g = sns.displot(
            kind=plot_type, data=df,
            col='stim_condition', col_order=[0, 1, 3],
            x=var, stat='probability',
            hue=group_labels,
            palette=['orange', 'lightseagreen'],
            alpha=.1, linewidth=2,
            element='step', bins=np.arange(0, 15, .125),
            height=6, aspect=1.2
        )
    else:
        g = sns.displot(
            kind=plot_type, data=df,
            col='stim_condition', col_order=[0, 1, 3],
            x=var,
            hue=group_labels,
            palette=['orange', 'lightseagreen'],
            alpha=.1, linewidth=2,
            height=6, aspect=1.2
        )

    # Subplot titles, in col_order [0, 1, 3]
    axes = g.axes.flatten()
    axes[0].set(title="None")
    axes[1].set(title="Sample")
    axes[2].set(title="Test")

    # Legend
    g._legend.set(bbox_to_anchor=(.95, .5), title='Group')
    g.tight_layout()
    sns.despine()

    return g

#### <font color='skyblue'> Visualization - Plot cumulative time distributions</font>

In [1]:
def plot_cumulative_distributions_within_group(df, var):
    '''
    Plot the empirical cumulative distribution of var within each experimental
    group: one column per group ('CTRL', 'NPHR') and one hue level per
    stimulation condition (0 -> 'None ', 1 -> 'Sample', 3 -> 'Test').

    arg1, df, Pandas DataFrame - contains the data; the columns 'group',
    'stim_condition' and var are read;
    arg2, var, str - The column name of the variable to be plotted.

    Returns the grid object created by sns.displot.
    '''
    sns.set(style='white', context='poster')
    plt.figure(dpi=300)

    # NOTE(review): the palette is a plain list, so which condition receives which
    # color depends on the order seaborn derives for the hue labels - confirm it
    # matches the black / royalblue / orangered convention for conditions 0 / 1 / 3
    # used by the other plots in this file.
    g = sns.displot(
        kind='ecdf', data=df, col='group',
        x=var, stat='proportion',
        hue=df['stim_condition'].map({0: 'None ', 1: 'Sample', 3: 'Test'}),
        col_order=['CTRL', 'NPHR'], palette=['orangered', 'royalblue', 'black'],
        linewidth=2, alpha=.8, height=6, aspect=1.1
    )
    # Legend
    g._legend.set(bbox_to_anchor=(.98, .7), title='Condition')
    g.tight_layout()
    # despine has to be called; referencing it without parentheses does nothing.
    sns.despine()

    return g
    
def plot_cumulative_distributions_within_condition(df, var):
    '''
    Plot the empirical cumulative distribution of var within each stimulation
    condition: one column per stim_condition (0, 1, 3, titled "None", "Sample",
    "Test"), one hue level per experimental group, and a dashed horizontal
    reference line at a proportion of 0.5 in every panel. No legend is drawn.

    arg1, df, Pandas DataFrame - contains the data; the columns 'group',
    'stim_condition' and var are read;
    arg2, var, str - The column name of the variable to be plotted.

    Returns the grid object created by sns.displot.
    '''
    sns.set(style='white', context='poster')
    plt.figure(dpi=300)

    group_labels = df['group'].map({'CTRL': 'NpHR- ', 'NPHR': 'NpHR+'})
    g = sns.displot(
        data=df, kind='ecdf',
        x=var, stat='proportion',
        col='stim_condition', col_order=[0, 1, 3],
        hue=group_labels,
        palette=['orange', 'lightseagreen'],
        linewidth=1.5, legend=False
    )

    # One title per column, in col_order, plus the 0.5 reference line.
    panel_titles = ("None", "Sample", "Test")
    for panel, panel_title in zip(g.axes.flatten(), panel_titles):
        panel.set(title=panel_title)
        panel.axhline(0.5, linestyle='dashed', c='black', linewidth=1.5)

    g.tight_layout()
    sns.despine()

    return g

## QQ plot

In [5]:
def qqplot_within_condition(q1, q2, color):
    '''
    Scatter q2 (y axis) against q1 (x axis) in a single color on a new
    4x4 figure - presumably the matching quantiles of two samples.

    arg1, q1 - values for the x axis;
    arg2, q2 - values for the y axis;
    arg3, color - color of the markers.

    Returns the object returned by sns.scatterplot.
    '''
    sns.set(style='white', context='talk')
    plt.figure(dpi=300, figsize=(4,4))

    marker_style = dict(s=50, alpha=.6)
    ax = sns.scatterplot(x=q1, y=q2, color=color, **marker_style)

    sns.despine()
    return ax

In [7]:
def qqplot_within_group(q1, q2, q3, colors):
    '''
    Scatter q2 and q3 (y axis) against the same q1 (x axis) on a new
    3x3 figure - presumably the quantiles of two samples against a shared
    reference sample.

    arg1, q1 - values for the x axis, shared by both series;
    arg2, q2 - y values of the first series, drawn in colors[0];
    arg3, q3 - y values of the second series, drawn in colors[1];
    arg4, colors - sequence holding at least two colors.

    Returns the object returned by the first sns.scatterplot call.
    '''
    sns.set(style='white', context='talk')
    plt.figure(dpi=300, figsize=(3,3))

    marker_style = dict(s=50, alpha=.6)
    first_color, second_color = colors[0], colors[1]
    ax = sns.scatterplot(x=q1, y=q2, color=first_color, **marker_style)
    sns.scatterplot(x=q1, y=q3, color=second_color, **marker_style)

    sns.despine()
    return ax

## Calculate session performance differences (Illumin. vs. No Illumin.)

In [None]:
def calc_session_performance_diffs(df):
    """
    Compute per-session performance differences between the illuminated
    conditions (stim_condition 1 and 3) and the non-illuminated one (0).

    Each condition's rows are sorted by 'session_nr' and re-indexed from 0, and
    the subtraction pairs rows by that position.
    NOTE(review): this presumably expects one row per session and the same
    sessions in every condition - confirm against the caller.

    arg1, df, Pandas DataFrame - reads the columns 'stim_condition',
    'session_nr', 'performance' and 'rat'.

    Returns the stim_condition 3 rows without 'performance' and
    'stim_condition', plus the columns 'test_none_diff' (3 minus 0) and
    'sample_none_diff' (1 minus 0; NaN when the first 'rat' value is 'CTRL2').
    """
    def _condition_by_session(condition):
        # Rows of one condition, ordered by session, with a fresh 0..n-1 index.
        subset = df[df['stim_condition'] == condition]
        return subset.sort_values(by='session_nr').reset_index(drop=True)

    baseline = _condition_by_session(0)
    sample = _condition_by_session(1)
    diffs = _condition_by_session(3)

    diffs['test_none_diff'] = diffs['performance'] - baseline['performance']

    is_ctrl2 = df['rat'].iloc[0] == 'CTRL2'
    diffs['sample_none_diff'] = (
        np.nan if is_ctrl2 else sample['performance'] - baseline['performance']
    )

    return diffs.drop(columns=['performance', 'stim_condition'])

####  <font color='violet'> Add session number</font>

In [1]:
def add_session_nr(group):
    """
    Number the sessions in group consecutively, starting at 1, following the
    sorted order of the unique 'session' values.

    The number is written to the 'session_nr' column of group itself (the
    DataFrame is modified in place) and the same object is returned.

    arg1, group, Pandas DataFrame - contains the data to add the session number to.
    """
    ordered_sessions = np.sort(group['session'].unique())
    for number, session in enumerate(ordered_sessions, start=1):
        belongs_to_session = group['session'] == session
        group.loc[belongs_to_session, 'session_nr'] = number
    return group

####  <font color='violet'> Linear regression slopes</font>

In [None]:
from sklearn.linear_model import LinearRegression

def get_linear_regression_slopes(df, targetcol):
    """
    Fit a linear regression of targetcol on 'session_nr' and return its slope.

    arg1, df, Pandas DataFrame - contains the data on which to perform the regression
    arg2, targetcol, str - Name of the target column

    Returns the single fitted coefficient (the slope).
    """
    # Both sides are passed as one-column DataFrames, so coef_ is 2-D.
    features = df[['session_nr']]
    target = df[[targetcol]]

    model = LinearRegression()
    model.fit(features, target)

    slope = model.coef_[0][0]
    return slope

####  <font color='violet'> Calculate the IQR values</font>

In [1]:
def calculate_iqr(df, var):
    """
    Calculates the IQR limits, the median and the IQR.

    The 25th, 50th and 75th percentiles of df[var] are computed once over all
    rows and repeated on every row of the result.

    arg1, df, Pandas DataFrame - DataFrame that contains the data;
    arg2, var, str - Name of the column from which to calculate the metrics.

    Returns a new DataFrame with the columns 'group', 'rat', 'outcome',
    'stim_condition', 'session_nr', 'iqr_25', 'median', 'iqr_75' and 'iqr'.
    The input DataFrame is left unmodified.
    """
    q25, q50, q75 = np.percentile(df[var], [25, 50, 75])

    id_cols = ['group', 'rat', 'outcome', 'stim_condition', 'session_nr']
    # Work on a copy so the metric columns are not written into the caller's
    # DataFrame (or into a group slice handed over by groupby.apply).
    summary = df[id_cols].copy()
    summary['iqr_25'] = q25
    summary['median'] = q50
    summary['iqr_75'] = q75
    summary['iqr'] = q75 - q25
    return summary
    