# <font color='steelblue'> Functions for optogenetics position and performance analysis - data preparation </font>

## Functions used in latency calculation

In [2]:
def collect_pruned_position_and_cp_rois_all_rats_in_path(path):

    '''
    Gather the pruned position tables and the converted CP ROI tables found under `path`.

    Every directory below `path` is visited with os.walk. In each directory the
    file lists are handed to get_data_given_string_to_match twice: once looking
    for a file name containing 'pruned' (timestamped position data) and once
    looking for a file name containing 'cp_rois_converted' (CP ROI measures).

    Returns a tuple (data, rois):
        data - concatenation of every position table that was found
        rois - concatenation of every CP ROI table that was found
    '''

    position_frames = []
    roi_frames = []

    for dirpath, dirnames, filenames in os.walk(path):
        position_frames = get_data_given_string_to_match(
            dirpath, dirnames, filenames, 'pruned', position_frames
        )
        roi_frames = get_data_given_string_to_match(
            dirpath, dirnames, filenames, 'cp_rois_converted', roi_frames
        )

    # pd.concat raises ValueError if nothing was collected for either list
    return pd.concat(position_frames), pd.concat(roi_frames)

def get_data_given_string_to_match(a, b, c, string_to_match, data_list):

    '''
    Load the first file of a directory whose name contains `string_to_match`.

    a, b, c are one (dirpath, dirnames, filenames) triple as yielded by os.walk;
    `b` is accepted but not used. When a matching file exists it is read as a
    CSV (first row as header, first column as index), labelled with two extra
    columns and appended to `data_list`:
        'rat'   - text captured by the pattern "_(\\w+\\d+)" in the directory path `a`
        'group' - first run of capital letters inside that rat code

    Returns `data_list` (the same list object, extended in place when a file matched).
    '''

    matches = [name for name in c if string_to_match in name]
    if not matches:
        return data_list

    # Only the first matching file in the directory is loaded
    csv_path = os.path.join(a, matches[0])
    table = pd.read_csv(csv_path, header=0, index_col=0)

    rat_label = re.search(r"_(\w+\d+)", a).group(1)
    group_label = re.search(r"([A-Z]+)", rat_label).group(1)

    table['rat'] = rat_label
    table['group'] = group_label
    data_list.append(table)

    return data_list

def drop_after_cp_entry(group, roi_session):

    '''
    Keep the samples of `group` that lie before the ROI x limit and inside its y band.

    The limits are read from the first row of `roi_session` (columns 'x',
    'ylim1', 'ylim2'). A sample is kept when x is strictly below the x limit
    and y is between ylim1 and ylim2 (both inclusive).
    '''

    x_limit, y_low, y_high = (
        roi_session[column].iloc[0] for column in ('x', 'ylim1', 'ylim2')
    )

    before_limit = group['x'].lt(x_limit)
    inside_band = group['y'].between(y_low, y_high)

    return group[before_limit & inside_band]

def drop_outside_cp(group, roi_session):

    '''
    Keep the samples of `group` that fall inside the rectangle given by `roi_session`.

    The rectangle is read from the first row of `roi_session` (columns 'xlim1',
    'xlim2', 'ylim1', 'ylim2'); both x and y limits are inclusive.
    '''

    limits = {
        column: roi_session[column].iloc[0]
        for column in ('xlim1', 'xlim2', 'ylim1', 'ylim2')
    }

    x_inside = group['x'].between(limits['xlim1'], limits['xlim2'])
    y_inside = group['y'].between(limits['ylim1'], limits['ylim2'])

    return group[x_inside & y_inside]
    
def discard_data_given_roi(group, roi, to_keep):

    '''
    Filter the position samples of one session/rat using the matching ROI rows.

    The session and rat are read from the first row of `group`. ROI rows are
    selected when their 'session' contains the first 15 characters of the
    session name (plain substring match) and their 'rat' equals the rat.

    to_keep selects the filter applied to every run ('run_nr'):
        'before cp entry' - drop_after_cp_entry
        'inside cp'       - drop_outside_cp
    Any other value returns `group` unchanged.
    '''

    session_name = group['session'].iloc[0]
    rat_name = group['rat'].iloc[0]

    same_session = roi['session'].str.contains(session_name[0:15], regex=False)
    same_rat = roi['rat'] == rat_name
    roi_session = roi[same_session & same_rat]

    if to_keep == 'before cp entry':
        return group.groupby(['run_nr']).apply(drop_after_cp_entry, roi_session)
    if to_keep == 'inside cp':
        return group.groupby(['run_nr']).apply(drop_outside_cp, roi_session)

    return group

def calculate_quantiles(df, cols_to_group, col_to_calc):

    '''
    Compute the deciles (10% to 90%) of `col_to_calc` within each group.

    Parameters
        df            - dataframe holding the grouping columns and `col_to_calc`
        cols_to_group - a single column name or a list of column names to group by
        col_to_calc   - name of the column whose quantiles are computed

    Returns a long-format dataframe with the grouping column(s), a 'quantile'
    column holding the percentile labels produced by describe (e.g. '10%') and
    a 'time' column holding the corresponding values.
    '''

    deciles = [.1, .2, .3, .4, .5, .6, .7, .8, .9]
    quantiles = df.groupby(cols_to_group)[col_to_calc].describe(percentiles=deciles)

    # Keep only the quantiles
    quantiles = quantiles.drop(['count', 'mean', 'std', 'min', 'max'], axis=1)

    # Name the column axis before stacking so the stacked level is called
    # 'quantile' whatever the number of grouping columns. Building the name as
    # 'level_' + len(cols_to_group) breaks when a single column name (a string)
    # is passed, because len() then counts characters.
    quantiles = quantiles.rename_axis(columns='quantile')

    # Rearrange quantiles into long format
    return quantiles.stack().rename('time').reset_index()


## Get crossing timestamps

In [13]:
def get_roi_crossing_timestamps_for_runs_in_df(df, cp_rois):
    
    '''
    Build a table with the start, CP-entry and CP-exit timestamps of each run.

    The samples in df are filtered twice per (session, rat) group with
    discard_data_given_roi: once with the ROI returned by get_roi_before_cp
    ('before cp entry') and once with the square returned by
    get_cp_square_limits_from_rois ('inside cp'). Within each
    (session, rat, run_nr) group, nth(0) of the first selection gives the
    start, last() of the first selection gives the CP entry and last() of the
    second selection gives the CP exit.

    df must hold at least the columns referenced below: 'session', 'rat',
    'run_nr', 'timestamp', 'x', 'y', 'x_diff', 'stim_condition', 'run_type',
    'outcome' and 'group'.

    Returns a dataframe derived from the start rows (without 'x', 'y',
    'x_diff'), with 'timestamp' renamed to 'start_timestamp' and the columns
    'cp_entry_timestamp' and 'cp_exit_timestamp' added. Only runs that also
    have samples inside the CP square are returned (inner merge).
    '''

    roi_before_cp = get_roi_before_cp(cp_rois)
    cp_square = get_cp_square_limits_from_rois(cp_rois)    
    
    # Keep only data until entering CP 
    runs_before_cp_entry = df.groupby(['session', 'rat']).apply(
        discard_data_given_roi, 
        roi_before_cp, 
        'before cp entry'
    )
    # Keep only data located inside the CP square
    runs_inside_cp = df.groupby(['session', 'rat']).apply(
        discard_data_given_roi, 
        cp_square, 
        'inside cp'
    ) 
  
    # Rearrange the dataframes (discard the index levels created by the nested groupby/apply)
    runs_before_cp_entry = runs_before_cp_entry.reset_index(drop=True)    
    runs_inside_cp = runs_inside_cp.reset_index(drop=True)  
    
    # Collect crossing points in maze for each run
    start = runs_before_cp_entry.groupby(['session', 'rat', 'run_nr']).nth(0).reset_index()
    cp_entry = runs_before_cp_entry.groupby(['session', 'rat', 'run_nr']).last().reset_index()
    cp_exit = runs_inside_cp.groupby(['session','rat','run_nr']).last().reset_index()
    
    #Create a new df with the timestamps of each crossing
    test_runs = start.rename(columns={'timestamp':'start_timestamp'}).drop(['x', 'y', 'x_diff'], axis=1)
    # NOTE(review): this assignment aligns on the row index, so it assumes `start`
    # and `cp_entry` list the runs in the same order after reset_index. Whether
    # nth(0) returns rows keyed by group or in original order depends on the
    # pandas version - TODO confirm against the version in use.
    test_runs['cp_entry_timestamp'] = cp_entry['timestamp']
    cp_exit = cp_exit.rename(columns={'timestamp':'cp_exit_timestamp'})
    cp_exit = cp_exit.drop(['x', 'y', 'x_diff', 'stim_condition', 'run_type','rat', 'outcome', 'group'], axis=1)

    # NOTE(review): 'rat' is dropped above and the merge keys are only
    # ('session', 'run_nr'); presumably session names are unique per rat - verify.
    test_runs = test_runs.merge(cp_exit, how='inner', on=['session', 'run_nr'])
    
    # --- Disabled alternative way of assembling test_runs ---
    #test_runs = pd.concat([test_runs, cp_entry, cp_exit]).reset_index()
    #test_runs['cp_entry_timestamp'] = cp_entry['timestamp']
    #test_runs['cp_exit_timestamp'] = cp_exit['timestamp']
       
    # --- Disabled sanity-check plot of one randomly chosen session ---
    #Plot data until CP (green) over all data (in orange)
    #plt.Figure(figsize=(8,4))
    #sns.set(style='white', context='talk')
    
    #import random
    #sample_session=random.choice(df['session'].unique())

    #sample_raw = df[df['session']==sample_session]
    #sample_to_cp = runs_before_cp_entry[runs_before_cp_entry['session']==sample_session]
    #sample_inside_cp = runs_inside_cp[runs_inside_cp['session']==sample_session]
    #sample_cp_entry = cp_entry[cp_entry['session']==sample_session]
    #sample_cp_exit = cp_exit[cp_exit['session']==sample_session]

    #sns.scatterplot(data=sample_raw, x='x', y='y', color='gray')
    #sns.scatterplot(data=sample_to_cp, x='x', y='y', color='green')
    #sns.scatterplot(data=sample_inside_cp, x='x', y='y', color='orange')
     
    #sns.scatterplot(data=sample_cp_entry, x='x', y='y', color='red')
    #sns.scatterplot(data=sample_cp_exit, x='x', y='y', color='blue' )
    #sns.despine()      
    
    return test_runs

## Functions used in performance calculations

In [16]:
'''def remove_single_samples_and_tests(df):
    
    
    Remove trials with only a single sample or single test. They cannot be used to computed performances
    
    
    removed_samples = 0
    removed_tests = 0
    df = df.reset_index(drop=True)
    sample_indices = df.index[df['run_type']=='S'].tolist()
    test_indices = df.index[df['run_type'] == 'T'].tolist()
    
    for i in sample_indices:
        prev_i = i-1
        if prev_i >=0 and df.loc[prev_i, 'run_type'] =='S':
            df.loc[prev_i, 'run_type'] = np.NaN
            removed_samples +=1
        else:
            pass
    
    for t in test_indices:
        next_t = t+1
        if next_t <len(df) and df.loc[next_t, 'run_type'] == 'T':
            df.loc[next_t, 'run_type'] = np.NaN
            removed_tests +=1
            
    print('removed unpaired samples:'+str(removed_samples), 'removed unpaired tests: '+str(removed_tests))
    return df    '''  

def calc_performance_in_group(group):

    '''
    Calculate the performance (percentage of correct trials) of one group of trials.

    Parameters
        group - dataframe with one row per trial and an 'outcome' column
                (outcome == 1 marks a correct trial), e.g. one group passed
                in by groupby(...).apply

    Rows containing any missing value are discarded before counting.

    Returns the percentage of correct trials as a float, or NaN when no trial
    is left after discarding incomplete rows (instead of dividing by zero).
    '''

    valid_trials = group.dropna()

    n_correct_trials = int((valid_trials['outcome'] == 1).sum())
    n_total_trials = len(valid_trials)

    # A group made only of incomplete rows has no defined performance
    if n_total_trials == 0:
        return float('nan')

    return (n_correct_trials / n_total_trials) * 100

def add_condition_trial_nr(group):

    '''
    Number the rows of `group` consecutively in a new 'cond_trial_nr' column.

    The numbering starts at 1 and follows the current row order. The column is
    written on `group` itself, which is also returned.
    '''

    n_rows = len(group)
    group['cond_trial_nr'] = range(1, n_rows + 1)

    return group

def calculate_performance_given_ntrials(group, group_label, N):

    '''
    Calculate the performance of one group in consecutive blocks of N trials.

    Parameters
        group       - GroupBy object from which the trials are taken
        group_label - key of the group to analyse; its first three elements
                      are used as the group, rat and stimulation labels
        N           - number of trials per block

    The trials of the selected group are renumbered from 0 and split into
    blocks of N rows (row number // N). For each block the sum of 'outcome' is
    divided by the count of 'outcome' values and multiplied by 100.

    Returns a dataframe indexed by block number with the columns 'outcome'
    (block performance in percent), 'group', 'rat' and 'stim'.
    '''

    trials = group.get_group(group_label).reset_index(drop=True)

    # Consecutive rows share a block id: 0..N-1 -> 0, N..2N-1 -> 1, ...
    block_id = trials.index // N
    outcome_by_block = trials['outcome'].groupby(block_id)

    # Correct trials over total trials in each block, as a percentage
    block_performance = (outcome_by_block.sum() / outcome_by_block.count()) * 100

    # Attribute the group, rat and condition label to each block performance
    par_perf = block_performance.to_frame().assign(
        group=group_label[0],
        rat=group_label[1],
        stim=group_label[2],
    )

    return par_perf

## Plot quantile curves 

In [3]:
def plot_quantile_curves_within_condition(df):

    '''
    Plot the quantile curves of each group, one panel per stimulation condition.

    Parameters
        df - long-format dataframe with the columns 'quantile', 'time',
             'group' and 'stim_condition'

    The panel titles ("No Illumination", "Illumination") and the legend
    entries ('NpHR-', 'NpHR+') are hard-coded.
    NOTE(review): they presumably match the order in which seaborn lays out
    the stim_condition and group values of df - confirm against the data.
    '''

    sns.set(style='white', context='poster')
    g = sns.relplot(
        data=df, kind='line',
        x='quantile', y='time', hue='group',
        markers=['o', 'o'], style='group',
        palette=['navy', 'green'],
        col='stim_condition', height=5, aspect=1.1,
        legend=False
    )
    # Take the tick labels from the plotted dataframe, not from a `quantiles`
    # variable that only exists if the caller happens to define it globally.
    g.set_xticklabels(labels=df['quantile'].unique(), rotation=45)

    axes = g.axes.flatten()
    axes[0].set_title("No Illumination")
    axes[1].set_title("Illumination")
    plt.legend(['NpHR-', 'NpHR+'], frameon=False, bbox_to_anchor=(1.5,1.2))
    sns.despine()
    
def plot_quantile_curves_within_group(df):

    '''
    Plot the quantile curves of each stimulation condition, one panel per group.

    Parameters
        df - long-format dataframe with the columns 'quantile', 'time',
             'group' and 'stim_condition'

    The panel titles ("NpHR- group", "NpHR+ group") and the legend entries
    ('No illumination', 'Illumination') are hard-coded.
    NOTE(review): they presumably match the order in which seaborn lays out
    the group and stim_condition values of df - confirm against the data.
    '''

    sns.set(style='white', context='poster')
    g = sns.relplot(
        data=df, kind='line', x='quantile', y='time', hue='stim_condition',
        col='group', col_wrap=2, height=5, aspect=1.1,
        markers=['o', 'o'], style='stim_condition', palette=['black', 'orangered'],
        legend=False
    )
    # Take the tick labels from the plotted dataframe, not from a `quantiles`
    # variable that only exists if the caller happens to define it globally.
    g.set_xticklabels(labels=df['quantile'].unique(), rotation=45)

    axes = g.axes.flatten()
    axes[0].set_title("NpHR- group")
    axes[1].set_title("NpHR+ group")
    plt.legend(['No illumination', 'Illumination'],
               frameon=False, bbox_to_anchor=(1.1,1))
    sns.despine()

## Time distributions plots

In [16]:
def plot_latency_distributions(test_runs, xlim):

    '''
    Plot the latency-to-CP-entry distributions of all groups and conditions in one figure.

    Top panel: strip plot of 'latency_to_cp_entry' per group, split by
    stimulation condition ('0' in black, '3' in orangered).
    Bottom panel: kernel density curves per condition, dotted lines for the
    'CTRL' group and solid lines for the 'NPHR' group.

    Parameters
        test_runs - dataframe with the columns 'latency_to_cp_entry', 'group'
                    and 'stim_condition'
        xlim      - x-axis limits of the (shared) latency axis
    '''

    fig, (strip_ax, kde_ax) = plt.subplots(2, 1, figsize=(5, 8), sharex=True, dpi=300)
    sns.set(style='white', context='talk')
    plt.suptitle('Latency to choice point entry distributions')

    condition_order = ['0', '3']
    condition_colors = ['black', 'orangered']

    # Stripplots
    strip = sns.stripplot(
        ax=strip_ax, data=test_runs,
        x='latency_to_cp_entry', y='group', hue='stim_condition',
        hue_order=condition_order, palette=condition_colors,
        s=1.5, alpha=.5, jitter=.25, dodge=True,
    )
    strip.set(yticklabels=['NpHR+', 'NpHR-'], ylabel='Group', xlabel='')
    strip.legend_.remove()

    # Kernel density plots: one pass per group, distinguished by line style
    for group_code, line_style in (('CTRL', 'dotted'), ('NPHR', 'solid')):
        sns.kdeplot(
            ax=kde_ax,
            data=test_runs[test_runs['group'] == group_code],
            x='latency_to_cp_entry',
            hue='stim_condition', hue_order=condition_order,
            palette=condition_colors, linestyle=line_style,
        )

    kde_ax.set(xlim=xlim, ylim=[0, .1], xlabel='seconds')
    sns.despine()


In [11]:
    
def subplot_latency_distributions_within_conditions(test_runs, xlim, ylim):

    '''
    Create a 2x2 plot grid comparing the latency distributions of the groups
    within each illumination condition (no illumination vs. illumination).

    Top row: strip plot of 'latency_to_cp_entry' per group.
    Bottom row: kernel density curve of the same latencies for each group
    ('NPHR' in green, 'CTRL' in navy).
    Each column contains the data of one illumination condition.

    Parameters
        test_runs - dataframe with the columns 'latency_to_cp_entry', 'group'
                    and 'stim_condition'
        xlim      - limits of the latency axis (shared by all panels)
        ylim      - y-axis limits of the density panels

    NOTE(review): conditions are taken in the order returned by
    test_runs['stim_condition'].unique() and paired with the titles
    'No illumination', 'Illumination' - presumably the no-illumination
    condition comes first in the data; confirm.
    '''

    dfs = [
        test_runs[test_runs['stim_condition'] == condition]
        for condition in test_runs['stim_condition'].unique()
    ]

    fig, axes = plt.subplots(2, 2, figsize=(10, 8), sharex=True, dpi=300)
    sns.set(style='white', context='talk')
    plt.suptitle('Latency to choice point entry distributions')

    colors = ['green', 'navy']
    groups = ['NPHR', 'CTRL']
    titles = ['No illumination', 'Illumination']

    # zip stops after two panels even if more conditions are present
    for j, (df, title) in enumerate(zip(dfs, titles)):

        #Stripplots
        sns.stripplot(ax=axes[0, j], data=df, x='latency_to_cp_entry', y='group',
                      palette=colors, s=1.4,
                      jitter=.25)
        #Kde plots
        for color, group in zip(colors, groups):
            sns.kdeplot(ax=axes[1, j], data=df[df['group'] == group],
                        x='latency_to_cp_entry', c=color)

        # The x axis is shared (sharex=True), so every panel ends up with
        # `xlim`; use it here too instead of a fixed range that gets overridden.
        axes[1, j].set(xlabel='Latency (s)', xlim=xlim, ylim=ylim)
        axes[0, j].set(title=title, xlim=xlim)

    # Axis formatting that does not depend on the column: done once, after
    # both strip plots have been drawn.
    axes[0, 1].set(ylabel='', yticklabels='', xlabel='', xlim=xlim)
    axes[0, 0].set(ylabel='Group', yticklabels=['NpHR+', 'NpHR-'], xlabel=' ', xlim=xlim)

    sns.despine()
      
    
def subplot_latency_distributions_within_groups(test_runs, conditions, xlim, ylim):

    '''
    Create a 2x2 plot grid comparing the latency distributions of the
    stimulation conditions within each group.

    Left column: 'CTRL' runs (titled 'NpHR-'); right column: 'NPHR' runs
    (titled 'NpHR+'). Top row: strip plot of 'latency_to_cp_entry' per
    stimulation condition. Bottom row: one kernel density curve per entry of
    `conditions` (first in black, second in orangered).

    Parameters
        test_runs  - dataframe with the columns 'latency_to_cp_entry', 'group'
                     and 'stim_condition'
        conditions - str list of the conditions to plot (no vs. test or no vs. sample)
        xlim       - limits of the latency axis
        ylim       - y-axis limits of the density panels
    '''

    panels = [
        ('NpHR-', test_runs[test_runs['group'] == 'CTRL']),
        ('NpHR+', test_runs[test_runs['group'] == 'NPHR']),
    ]

    fig, axes = plt.subplots(2, 2, figsize=(10, 8), sharex=True, dpi=300)
    sns.set(style='white', context='talk')
    plt.suptitle('Latency to choice point entry distributions')

    condition_colors = ['black', 'orangered']

    for column, (panel_title, group_runs) in enumerate(panels):
        strip_ax = axes[0, column]
        kde_ax = axes[1, column]

        # Stripplot
        sns.stripplot(ax=strip_ax, data=group_runs,
                      x='latency_to_cp_entry', y='stim_condition',
                      palette=condition_colors, s=1.4, jitter=.25)

        # Kde plots
        for line_color, condition in zip(condition_colors, conditions):
            condition_runs = group_runs[group_runs['stim_condition'] == condition]
            sns.kdeplot(ax=kde_ax, data=condition_runs,
                        x='latency_to_cp_entry', c=line_color)

        strip_ax.set(title=panel_title, xlabel='', ylabel='', yticklabels='', xlim=xlim)
        # The left strip plot keeps its condition labels (re-applied after the
        # generic formatting above has blanked them)
        axes[0, 0].set(ylabel='Condition', yticklabels=['No illumination', 'Illumination'], xlim=xlim)
        kde_ax.set(xlabel='Latency (s)', xlim=xlim, ylim=ylim)

    sns.despine()
    

## Get maze limits

In [20]:
def get_cp_square_limits_from_rois(df):
    '''
    Create the CP square limits using the real video CP limits.

    A copy of `df` is returned in which 'x' and 'y' are renamed to 'xlim1' and
    'ylim1', and 'xlim2' / 'ylim2' are added as origin + 'width' / 'height'.
    The square is then widened by a margin of 10 on every side: the position
    tracking light is attached to the patch and not the rat, so light is
    detected outside the real limits. The margin (10 cm according to the
    original author) was chosen upon validation of the position data and limits.
    '''
    margin = 10

    # rename returns a new dataframe, so the input is left untouched
    square = df.rename(columns={'y': 'ylim1', 'x': 'xlim1'})

    square['ylim2'] = square['ylim1'] + square['height'] + margin
    square['ylim1'] = square['ylim1'] - margin

    square['xlim2'] = square['xlim1'] + square['width'] + margin
    square['xlim1'] = square['xlim1'] - margin

    return square

In [11]:
def get_roi_before_cp(df):

    '''
    Get the limits of the maze before reaching the CP.

    Returns a copy of `df` with:
        'ylim1' - 'y' minus a margin of 10
        'ylim2' - 'y' plus 'height' plus the same margin
        'x'     - shifted down by the margin
    The input dataframe is not modified.
    '''
    margin = 10

    return df.assign(
        ylim1=df['y'] - margin,
        ylim2=df['y'] + df['height'] + margin,
        x=df['x'] - margin,
    )