In [None]:
def calc_performance_and_median_latencies_in_group(group):
    '''
    Summarise one group of trials (intended for ``groupby(...).apply``).

    Rows containing any NaN are dropped first. Three summary columns are then
    added to the remaining rows:

    - ``performance``: percentage of rows whose ``outcome`` equals 1;
    - ``median_latency_to_cp``: median of ``latency_to_cp_entry``;
    - ``median_time_in_cp``: median of ``time_in_cp``.

    arg1, group, Pandas DataFrame - Trials of one group; must contain the
          'outcome', 'latency_to_cp_entry' and 'time_in_cp' columns.
    returns, Pandas DataFrame - The first remaining row with the summary
          columns attached (one row per group, so the repeated summary values
          are not duplicated), or an empty frame carrying those columns when
          every row contained a NaN.
    '''
    # Copy so the column assignments below never write into (or warn about)
    # a frame that pandas considers a slice of the caller's data.
    group = group.dropna().copy()
    summary_columns = ('performance', 'median_latency_to_cp', 'median_time_in_cp')

    n_total_trials = len(group)
    if n_total_trials == 0:
        # Nothing left to summarise: avoid dividing by zero and return an
        # empty frame that still has the summary columns.
        for column in summary_columns:
            group[column] = np.nan
        return group.head(1)

    n_correct_trials = int((group['outcome'] == 1).sum())
    group['performance'] = (n_correct_trials / n_total_trials) * 100
    group['median_latency_to_cp'] = np.median(group['latency_to_cp_entry'])
    group['median_time_in_cp'] = np.median(group['time_in_cp'])

    # The summary values are identical on every row; keep only one.
    return group.head(1)

In [2]:
def plot_iqr_and_median_across_sessions(df):
    """
    Plot the median and the interquartile band of a metric across sessions.

    One panel per (stim_condition, group) pair: stim_condition on the rows,
    group on the columns (NpHR- = 'CTRL' first, NpHR+ = 'NPHR' second). The
    '50%' column is drawn as a line and the area between the '25%' and '75%'
    columns is shaded in the colour of the stim condition.

    arg1, df, Pandas DataFrame - Per-session quantiles with the columns
          'group', 'stim_condition', 'session_nr', '25%', '50%' and '75%'.
          Assumes one row per session in each panel (presumably the output of
          a groupby(...).describe()) — TODO confirm against the caller.
    returns, seaborn FacetGrid - The grid that was drawn.
    """
    # relplot creates its own figure, so no plt.figure() beforehand: that
    # would only leave an extra, empty figure open.
    sns.set(context='talk', style='white')
    g = sns.relplot(
        kind='line', data=df,
        col='group', col_order=['CTRL', 'NPHR'], row='stim_condition',
        hue='stim_condition',
        y='50%', x='session_nr', ci=None, height=4, aspect=1.5,
        palette=['black', 'royalblue', 'orangered'], legend=False)

    g.set(xlim=[1, 15], ylim=(1.5, 5.5), xlabel='session', ylabel='I.Q.R')
    colors = {0: 'black', 1: 'royalblue', 3: 'orangered'}

    # axes_dict keys are (row value, column value) = (stim_condition, group).
    for (stim_condition, group), ax in g.axes_dict.items():
        panel = df[(df['group'] == group) & (df['stim_condition'] == stim_condition)]
        # Use the panel's own session numbers (sorted) as x so the band lines
        # up with the median line whatever the number of sessions.
        panel = panel.sort_values('session_nr')
        ax.fill_between(panel['session_nr'], panel['25%'], panel['75%'],
                        alpha=.4, color=colors[stim_condition])

    axes = g.axes.flatten()
    axes[0].set(title='NpHR-')
    axes[1].set(title='NpHR+')
    # Only the top row keeps a title.
    for ax in axes[2:]:
        ax.set(title='')
    return g

In [3]:
def plot_individual_metrics_across_sessions(df, var):
    '''
    Plot a given metric per rat across sessions, with a median overlay.

    One panel per (group, stim_condition) pair: group on the rows,
    stim_condition on the columns. Each rat is drawn as a light-grey line of
    its per-session mean, with a +/- SD band. On top, the per-session median
    of `var` over all of the panel's rows is drawn in the stim condition's
    colour.

    arg1, df, Pandas DataFrame - Contains the data to plot; needs 'group',
          'stim_condition', 'session_nr', 'rat' and `var` columns.
    arg2, var, str - Column name of the variable to plot.
    returns, seaborn FacetGrid - The grid that was drawn.
    '''
    # relplot builds its own figure; a preceding plt.figure() would only
    # leave an empty one open.
    sns.set(style='white', context='talk')

    # One identical grey per rat, sized from the data instead of a fixed count.
    n_rats = max(df['rat'].nunique(), 1)
    g = sns.relplot(
        kind='line', data=df, row='group', col='stim_condition',
        x='session_nr', y=var, hue='rat', ci='sd', estimator=np.mean,
        legend=False, height=4, aspect=1.3, linewidth=1,
        palette=['lightgray'] * n_rats)

    colors = {0: 'black', 1: 'royalblue', 3: 'orangered'}
    # axes_dict keys are (row value, column value) = (group, stim_condition).
    for (group, stim_condition), ax in g.axes_dict.items():
        data = df[(df['stim_condition'] == stim_condition) & (df['group'] == group)]
        sns.lineplot(ax=ax, data=data, x='session_nr', y=var, estimator=np.median,
                     ci=None, color=colors[stim_condition], legend=False)

    sns.despine()
    return g

In [None]:
def plot_latency_distributions_within_group(df, plot_type, var, colors):
    '''
    Plot the distributions of latencies within each experimental group.

    One panel per group (NpHR- = 'CTRL', NpHR+ = 'NPHR'), with one
    distribution per stim condition (0 -> None, 1 -> Sample, 3 -> Test).

    arg1, df, Pandas DataFrame - contains the data; needs 'group',
          'stim_condition' and `var` columns;
    arg2, plot_type, str - Type of distribution plot, 'hist' or 'kde';
    arg3, var, str - The column name of the variable to be plotted;
    arg4, colors, list - List of colors to use in hue.
    returns, seaborn FacetGrid - The grid that was drawn.
    raises, ValueError - if plot_type is neither 'hist' nor 'kde'.
    '''
    if plot_type not in ('hist', 'kde'):
        raise ValueError(f"plot_type must be 'hist' or 'kde', got {plot_type!r}")

    # displot builds its own figure; a preceding plt.figure() would only
    # leave an empty one open.
    sns.set(style='white', context='talk')

    shared_kws = dict(
        data=df, col='group', x=var,
        hue=df['stim_condition'].map({0: 'None', 1: 'Sample', 3: 'Test'}),
        col_order=['CTRL', 'NPHR'], palette=colors,
        height=5, linewidth=2, aspect=1.3)

    if plot_type == 'hist':
        # Step histograms with 0.1 s bins (10 bins per second) over 0-15 s.
        g = sns.displot(kind='hist', stat='probability', element='step',
                        bins=np.arange(0, 15, .1), alpha=0, **shared_kws)
    else:
        g = sns.displot(kind='kde', alpha=.5, **shared_kws)

    axes = g.axes.flatten()
    axes[0].set(title="NpHR-")
    axes[1].set(title="NpHR+")
    g._legend.set(bbox_to_anchor=(.85, .7), title='Condition')
    g.tight_layout()
    sns.despine()
    return g

In [None]:
def plot_latency_distributions_within_condition(df, plot_type, var):
    '''
    Plot the distributions of latencies within each stim condition.

    One panel per stim condition (0 -> None, 1 -> Sample, 3 -> Test), with
    one distribution per group ('CTRL' -> NpHR-, 'NPHR' -> NpHR+).

    arg1, df, Pandas DataFrame - contains the data; needs 'group',
          'stim_condition' and `var` columns;
    arg2, plot_type, str - Type of distribution plot, 'hist' or 'kde';
    arg3, var, str - The column name of the variable to be plotted.
    returns, seaborn FacetGrid - The grid that was drawn.
    raises, ValueError - if plot_type is neither 'hist' nor 'kde'.
    '''
    if plot_type not in ('hist', 'kde'):
        raise ValueError(f"plot_type must be 'hist' or 'kde', got {plot_type!r}")

    # displot builds its own figure; a preceding plt.figure() would only
    # leave an empty one open.
    sns.set(style='white', context='poster')

    shared_kws = dict(
        data=df,
        col='stim_condition', col_order=[0, 1, 3],
        x=var,
        hue=df['group'].map({'CTRL': 'NpHR-', 'NPHR': 'NpHR+'}),
        palette=['orange', 'lightseagreen'],
        alpha=.1, linewidth=2,
        height=6, aspect=1.2)

    if plot_type == 'hist':
        # Step histograms with 0.125 s bins (8 bins per second) over 0-15 s.
        g = sns.displot(kind='hist', stat='probability', element='step',
                        bins=np.arange(0, 15, .125), **shared_kws)
    else:
        g = sns.displot(kind='kde', **shared_kws)

    # Panel titles follow col_order above.
    for ax, title in zip(g.axes.flatten(), ("None", "Sample", "Test")):
        ax.set(title=title)

    g._legend.set(bbox_to_anchor=(.95, .5), title='Group')
    g.tight_layout()
    sns.despine()
    return g

In [None]:
def plot_cumulative_distributions_within_group(df, var):
    '''
    Plot the cumulative distributions (ECDF) of a variable within each group.

    One panel per group (NpHR- = 'CTRL', NpHR+ = 'NPHR'), with one ECDF per
    stim condition (0 -> None, 1 -> Sample, 3 -> Test).

    arg1, df, Pandas DataFrame - contains the data; needs 'group',
          'stim_condition' and `var` columns;
    arg2, var, str - The column name of the variable to be plotted.
    returns, seaborn FacetGrid - The grid that was drawn.
    '''
    # displot builds its own figure; a preceding plt.figure() would only
    # leave an empty one open.
    sns.set(style='white', context='poster')

    g = sns.displot(
        kind='ecdf', data=df, col='group',
        x=var, stat='proportion',
        hue=df['stim_condition'].map({0: 'None', 1: 'Sample', 3: 'Test'}),
        col_order=['CTRL', 'NPHR'], hue_order=['Test', 'Sample', 'None'],
        palette=['orangered', 'royalblue', 'black'],
        linewidth=2, alpha=.8, height=6, aspect=1.1
    )

    # Panel titles follow col_order above.
    axes = g.axes.flatten()
    axes[0].set(title="NpHR-")
    axes[1].set(title="NpHR+")

    g._legend.set(bbox_to_anchor=(.98, .7), title='Condition')
    g.tight_layout()
    sns.despine()
    return g

In [None]:

def plot_cumulative_distributions_within_condition(df, var):
    '''
    Plot the cumulative distributions (ECDF) of a variable within each condition.

    One panel per stim condition (0 -> None, 1 -> Sample, 3 -> Test), with one
    ECDF per group ('CTRL' in orange, 'NPHR' in light sea green). A dashed
    line at 0.5 marks where each ECDF crosses its median. No legend is drawn.

    arg1, df, Pandas DataFrame - contains the data; needs 'group',
          'stim_condition' and `var` columns;
    arg2, var, str - The column name of the variable to be plotted.
    returns, seaborn FacetGrid - The grid that was drawn.
    '''
    # displot builds its own figure; a preceding plt.figure() would only
    # leave an empty one open.
    sns.set(style='white', context='poster')

    g = sns.displot(
        kind='ecdf', data=df,
        col='stim_condition',
        col_order=[0, 1, 3],
        x=var, stat='proportion',
        hue=df['group'].map({'CTRL': 'NpHR- ', 'NPHR': 'NpHR+'}),
        palette=['orange', 'lightseagreen'],
        linewidth=1.5, legend=False
    )

    # Panel titles follow col_order above; each panel gets the 0.5 reference line.
    for ax, title in zip(g.axes.flatten(), ("None", "Sample", "Test")):
        ax.set(title=title)
        ax.axhline(0.5, linestyle='dashed', c='black', linewidth=1.5)

    g.tight_layout()
    sns.despine()
    return g

In [1]:
def plot_distribution_of_session_medians(df):
    '''
    Plot the distribution of the '50%' column per group and stim condition.

    A point plot marks the median of '50%' for each (group, condition) pair,
    and a strip plot overlays the individual rows.

    arg1, df, Pandas DataFrame - needs 'group', 'stim_condition' and '50%'
          columns. The x tick labels are set to NpHR-/NpHR+, so 'group' is
          presumably 'CTRL'/'NPHR' in that order — TODO confirm.
    returns, matplotlib Axes - The axes both plots were drawn on.
    '''
    sns.set(style='white', context='talk')
    # pointplot and stripplot are axes-level and draw into the current
    # figure, so a dedicated high-dpi figure is created first.
    plt.figure(dpi=300)

    # The hue must come from `df` itself (not a global frame) so that it
    # lines up with the plotted rows.
    g = sns.pointplot(
        data=df,
        x='group',
        y='50%',
        hue=df['stim_condition'].map({0: 'None', 1: 'Sample', 3: 'Test'}),
        dodge=.5, palette=['black', 'royalblue', 'orangered'],
        hue_order=['None', 'Sample', 'Test'], estimator=np.median, ci=None,
        linewidth=1, join=False, legend=False)

    sns.stripplot(data=df,
                  x='group',
                  y='50%', hue='stim_condition', hue_order=[0, 1, 3],
                  dodge=True, palette=['black', 'royalblue', 'orangered'],
                  s=4, alpha=.4)

    g.set(ylim=[1, 5], ylabel='Median Time (s)', xlabel='Group',
          xticks=[0, 1], xticklabels=['NpHR-', 'NpHR+'])
    plt.legend(labels=['None', 'Sample', 'Test'], frameon=False,
               title='Condition', bbox_to_anchor=[.97, .9])
    sns.despine()

    return g