# Purpose of this notebook:
This notebook plots aggregated burst data from fed and unfed STNS preparations during saline, unfed-hemolymph, and fed-hemolymph incubations following DPON stimulation.

Note: this is an initial demo to check that the interactive features and hosting still work with the upgraded environment.


# Load Required Packages

In [1]:
%matplotlib inline
import seaborn as sns; sns.set()
import pandas as pd
import numpy as np
import matplotlib as mpl
from matplotlib.lines import Line2D
from matplotlib import pyplot as plt
from collections import OrderedDict
from IPython.display import display
from glob import glob
# Widget stuff
import ipywidgets as widgets
from ipywidgets import interact, interact_manual, fixed, interactive

# Logan Stuff
#from stns import (max_burst_number, condition_plot, paired_plots)

In [2]:
%matplotlib inline
# Loading Data
data = pd.read_csv('STNS_Feeding_State_Hemo_all_bursts.csv')
data_apples = pd.read_csv('STNS_Feeding_State_Hemo_apples2apples_bursts.csv')
# 'Unnamed: 0' is the name pandas gives a column with an empty header -- presumably an
# index saved by an earlier to_csv call -- so it is dropped from both tables.
data = data.drop(columns='Unnamed: 0')
data_apples = data_apples.drop(columns='Unnamed: 0')
dataframe = data

print('Summary Statistics:Median')
# numeric_only=True: since pandas 2.0 groupby reductions raise on non-numeric columns
# instead of silently skipping them.
display(dataframe.groupby(['Date', 'Neuron', 'Condition'], as_index=True)
        .median(numeric_only=True))

# Create dictionaries for plotting: one marker and one colour per recording date.
dates = list(dataframe['Date'].unique())
Conditions = list(dataframe['Condition'].unique())
sorted_dates = sorted(dates)
# Size the style lists to the number of dates (never fewer than the original 5). A fixed
# slice of 5 left every later date without an entry, so marker_d[date] / color_d[date]
# raised KeyError inside the plotting cells.
n_styles = max(5, len(sorted_dates))
markers = tuple(Line2D.filled_markers[i % len(Line2D.filled_markers)]
                for i in range(n_styles))
cmap = sns.color_palette('colorblind', n_styles)
marker_d = dict(zip(sorted_dates, markers))
color_d = dict(zip(sorted_dates, cmap))

cond_options = dataframe.Condition.unique()
y_val_options = ['Burst Duration (sec)', '# of Spikes',
               'Spike Frequency (Hz)', 'Instantaneous Period (sec)',
               'Instantaneous Frequency (Hz)','Duty Cycle']
neuron_options = ['LG', 'DG', 'GM', 'MG', 'MCN1R', 'MCN1L']
kind = ['violin', 'line', 'scatter', 'paired', 'isi', 'strip']

Summary Statistics:Median


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Burst Order,Burst Duration (s),# of Spikes,Spike Frequency,Cycle Period,Duty Cycle,Start of Burst (s),Burst#,Burst Duration (sec),Spike Frequency (Hz),Instantaneous Period (sec),Instantaneous Frequency (Hz),Interburst Duration (s)
Date,Neuron,Condition,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
1-13-22,DG,Fed STNS\n Fed Hemo,4.0,3.83643,12.0,3.344318,8.309925,0.421805,746.35656,4.0,3.83643,3.344318,8.309925,0.120394,4.954635
1-13-22,DG,Fed STNS\n Saline 0,22.0,4.58622,53.0,12.001327,12.858885,0.363879,863.98881,22.0,4.58622,12.001327,12.858885,0.077794,8.543205
1-13-22,DG,Fed STNS\n Saline 1,20.5,6.57288,51.5,8.247112,16.78266,0.398594,1057.40358,20.5,6.57288,8.247112,16.78266,0.059585,10.07397
1-13-22,LG,Fed STNS\n Fed Hemo,2.5,3.37014,8.5,2.318663,7.76439,0.385423,736.184145,2.5,3.37014,2.318663,7.76439,0.128793,4.81734
1-13-22,LG,Fed STNS\n Saline 0,20.5,6.859485,64.5,9.665517,14.15934,0.503864,847.755165,20.5,6.859485,9.665517,14.15934,0.070625,6.68655
1-13-22,LG,Fed STNS\n Saline 1,18.5,6.864075,55.5,8.133319,17.88912,0.392082,1032.023955,18.5,6.864075,8.133319,17.88912,0.0559,10.85967
1-13-22,MCN1R,Fed STNS\n Fed Hemo,26.5,0.02808,1.0,35.615828,1.03437,0.027397,755.190645,26.5,0.02808,35.615828,1.03437,0.966772,1.01169
1-13-22,MCN1R,Fed STNS\n Saline 0,97.0,0.73278,9.0,13.141468,1.1988,0.650103,1048.31115,97.0,0.73278,13.141468,1.1988,0.834211,0.43281
1-13-22,MCN1R,Fed STNS\n Saline 1,210.5,0.564975,6.0,11.902636,1.12941,0.515258,1228.453665,210.5,0.564975,11.902636,1.12941,0.885418,0.53892
11-18-21,DG,Unfed STNS\n Fed Hemo,11.5,4.0581,19.0,3.80133,8.63811,0.463516,8638.806075,11.5,4.0581,3.80133,8.63811,0.115766,5.06439


# Line Plot 

In [3]:
@widgets.interact(
    y_val=y_val_options, 
    burst_num=(1, 20),
    neuron = neuron_options, 
    Condition = widgets.SelectMultiple(
        options=cond_options,
        value=tuple(cond_options),
        rows = len(cond_options),
        description='Condition',
        disabled=False),
    dates=widgets.SelectMultiple(
        options=dates,
        value=dates,
        rows=len(dates),
        description='dates',
        disabled=False
        ))


def plot(y_val='Burst Duration (sec)', burst_num=20, plot_lines=True, 
         title='Plot title', neuron='LG',dates=dates, Condition='Pre',
        ):
    """Scatter ``y_val`` against burst number for one neuron, one colour/marker per date.

    Rows come from the global ``dataframe``. They are restricted to ``Burst# < burst_num``,
    ``Neuron == neuron``, and the selected ``Condition`` and ``dates`` values. Markers and
    colours are looked up by date in the global ``marker_d`` / ``color_d`` dictionaries.

    :param y_val: str, column plotted on the y axis
    :param burst_num: int, exclusive upper bound on the ``Burst#`` column
    :param plot_lines: bool, join successive bursts of each date with dashed lines
    :param title: str, axes title
    :param neuron: str, value of the ``Neuron`` column to keep
    :param dates: str or iterable of str, ``Date`` values to keep
    :param Condition: str or iterable of str, ``Condition`` values to keep
    """
    # Series.isin raises on a bare string (e.g. the default Condition='Pre'), so a single
    # value is wrapped in a list; widget selections already arrive as tuples.
    conditions = [Condition] if isinstance(Condition, str) else Condition
    selected_dates = [dates] if isinstance(dates, str) else dates
    _dataframe = dataframe[(dataframe['Burst#'] < burst_num)
                           & (dataframe['Neuron'] == neuron)
                           & (dataframe['Condition'].isin(conditions))
                           & (dataframe['Date'].isin(selected_dates))]
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(12, 5))
    for date in _dataframe['Date'].unique():
        date_df = _dataframe[_dataframe['Date'] == date]
        ax.scatter(x=date_df['Burst#'],
                   y=date_df[y_val],
                   label='{} {}'.format(date, neuron),
                   alpha=.5,
                   marker=marker_d[date],
                   color=color_d[date])
        if not plot_lines:
            continue
        # Plain numpy arrays: np.split on a pandas Series relies on Series.swapaxes,
        # which newer pandas deprecates.
        bursts = date_df['Burst#'].to_numpy()
        values = date_df[y_val].to_numpy()
        if (bursts == 1).sum() > 1:
            # Several recordings share this date and each restarts at burst 1. Break the
            # line wherever Burst# does not step by exactly 1, so separate recordings
            # are not joined to each other.
            breaks = np.where(np.diff(bursts) != 1)[0] + 1
            segments = zip(np.split(bursts, breaks), np.split(values, breaks))
        else:
            # A single recording is drawn as one line, even when burst 1 is absent from
            # the selection (previously no line was drawn at all in that case).
            segments = [(bursts, values)]
        for x, y in segments:
            ax.plot(x, y, color=color_d[date], ls='--', zorder=3, label='')

    ax.set_ylabel(y_val, fontsize=16)
    ax.set_xlabel('Burst #', fontsize=16)
    ax.set_title(title, fontsize=20)
    ax.legend(bbox_to_anchor=(1, 1))
    plt.tight_layout()
    

interactive(children=(Dropdown(description='y_val', options=('Burst Duration (sec)', '# of Spikes', 'Spike Fre…

# Utility functions (use where needed)

In [4]:
def modified_z_score(array, axis=-1):
    """Apply the modified z-score to ``array`` along ``axis``.

    z = 0.6745 * (x - median) / MAD, where the median and the median absolute deviation
    (MAD) are computed with NaNs ignored, along ``axis``, in the same way that
    ``scipy.stats.zscore`` handles its axis argument.

    :param array: array-like, values to score
    :param axis: int or None, axis over which median and MAD are computed (None = all values)
    :return: np.ndarray with the same shape as ``array``. Where the MAD is 0 the division
        yields inf/nan (numpy emits a RuntimeWarning).
    """
    a = np.asanyarray(array)
    # keepdims=True keeps the reduced axis as length 1. The statistics then broadcast back
    # against ``a`` along the correct axis for any rank and any ``axis`` (0, -1, None...).
    # Previously the un-expanded median was subtracted along the last axis, which was
    # wrong for N-D input reduced over a non-first axis.
    median = np.nanmedian(a, axis=axis, keepdims=True)
    median_absolute_deviation = np.nanmedian(np.abs(a - median), axis=axis, keepdims=True)
    # 0.6745 ~ the 0.75 quantile of the standard normal; it rescales the MAD so the score
    # is comparable to an ordinary z-score for normally distributed data.
    return (0.6745 * (a - median)) / median_absolute_deviation


def outliers_modified_z_score(array, threshold=3.5, axis=-1):
    """Locate the outliers of ``array`` by their modified z-score.

    A point counts as an outlier when the absolute value of its modified z-score, computed
    by ``modified_z_score`` along ``axis``, is strictly greater than ``threshold``.
    NaN scores compare False, so NaN entries are never reported.

    :param array: array-like, values to test
    :param threshold: float, cut-off applied to |modified z-score|
    :param axis: int or None, axis along which the median and MAD are computed
    :return: tuple of index arrays, one per dimension (as returned by ``np.where``), giving
        the positions of the outliers. This is not a boolean mask; use
        ``np.asarray(array)[result]`` to extract the outlying values.
    """
    modified_z_scores = modified_z_score(array, axis=axis)
    return np.where(np.abs(modified_z_scores) > threshold)


def legend_without_duplicate_labels(ax, **kwargs):
    """Rebuild the legend of ``ax`` so that each label appears only once.

    For repeated labels the first handle encountered is kept; later ones are skipped.

    :param ax: matplotlib Axes whose labelled artists feed the legend
    :param kwargs: forwarded unchanged to ``ax.legend`` (e.g. ``loc``, ``bbox_to_anchor``)
    :return: None; the legend is replaced on ``ax``
    """
    handles, labels = ax.get_legend_handles_labels()
    seen_labels = set()
    first_entries = []
    for handle, label in zip(handles, labels):
        if label in seen_labels:
            continue
        seen_labels.add(label)
        first_entries.append((handle, label))
    ax.legend(*zip(*first_entries), **kwargs)
    

def reorder_df(df, list_of_order, by='Condition'):
    """Return the rows of ``df`` grouped by column ``by`` in the order given.

    Rows are concatenated group by group following ``list_of_order``. Within a group the
    original row order and index labels are kept.

    :param df: pd.DataFrame to reorder
    :param list_of_order: sequence naming every (non-null) value of ``df[by]`` exactly once
    :param by: str, column whose values define the groups
    :return: pd.DataFrame with the reordered rows
    :raises AssertionError: if ``list_of_order`` is not exactly the set of values in ``df[by]``
    """
    present = set(df[by].dropna().unique())
    requested = set(list_of_order)
    # Comparing only the counts (as before) let a misspelt label through and silently
    # dropped the group it was meant to match; require the exact labels, each listed once.
    assert len(list_of_order) == len(present) and requested == present, (
        'list_of_order must name each {!r} value exactly once; missing: {}, unknown: {}'
        .format(by, sorted(present - requested, key=str), sorted(requested - present, key=str)))
    if not list_of_order:
        # pd.concat raises on an empty list; an empty selection yields an empty frame.
        return df.iloc[:0].copy()
    return pd.concat([df[df[by] == cond] for cond in list_of_order])
        
# Condition labels in the order they should appear on categorical axes (presumably for
# use with reorder_df). Unfed/Fed STNS alternate within each incubation stage.
order = [
    'Unfed STNS\n Saline 0',
    'Fed STNS\n Saline 0',
    'Unfed STNS\n Fed Hemo',
    # NOTE(review): casing differs from the other labels -- confirm it matches the CSV.
    'Fed STNS\n unfed hemo',
    'Fed STNS\n Fed Hemo',
    'Unfed STNS\n Saline 1',
    'Fed STNS\n Saline 1',
]

sns.set_context('notebook')

In [5]:
@widgets.interact(
    y_val=y_val_options, 
    burst_num=(1, 50),
    neuron = widgets.Select(
        options=neuron_options, 
        value='LG',
        rows=1, 
        description='Neuron', 
        disabled=False),
    kind = widgets.Select(
        options=kind, 
        value='violin',
        rows=1, 
        description='Plot Kind', 
        disabled=False),
    Condition = widgets.SelectMultiple(
        options=cond_options,
        value=tuple(cond_options),
        rows = len(cond_options),
        description='Condition',
        disabled=False),
    dates=widgets.SelectMultiple(
        options=dates,
        value=dates,
        rows=len(dates),
        description='dates',
        disabled=False),
    y_lim = widgets.FloatRangeSlider(value=(0, 100), min=0, max=100)
)


def plot(y_val='Burst Duration (sec)', burst_num=20, kind='violin', 
         title='{} {}', neuron='LG',dates=dates, Condition='Saline 0',
         apples=False, palette='colorblind', y_lim=(0,50)
        ):
    """Summarise ``y_val`` across conditions and dates for the selected bursts.

    :param y_val: str, column plotted on the y axis
    :param burst_num: int, exclusive upper bound on the ``Burst#`` column
    :param kind: str, one of 'violin', 'strip', 'line' (point plot) or 'scatter';
        other values only print a notice
    :param title: str, format string filled with ``(y_val, neuron)``
    :param neuron: str or iterable of str, ``Neuron`` values to keep
    :param dates: str or iterable of str, ``Date`` values to keep
    :param Condition: str or iterable of str, ``Condition`` values to keep
    :param apples: bool, read from the global ``data_apples`` instead of ``dataframe``
    :param palette: seaborn palette name, set globally via ``sns.set_palette``
    :param y_lim: (low, high) limits for the y axis
    """
    # Widgets pass strings/tuples, and Series.isin rejects bare strings (e.g. the default
    # Condition='Saline 0'), so every selection is normalised to a list before filtering.
    neurons = [neuron] if isinstance(neuron, str) else list(neuron)
    conditions = [Condition] if isinstance(Condition, str) else list(Condition)
    selected_dates = [dates] if isinstance(dates, str) else list(dates)

    if apples:
        print('Running Apples to Apples')
        source = data_apples
    else:
        source = dataframe
    # The same selections now apply in both modes (apples mode used to ignore the
    # Condition/date widgets). NOTE(review): the widget options are built from
    # ``dataframe``, so this assumes ``data_apples`` uses the same labels.
    _dataframe = source[(source['Burst#'] < burst_num)
                        & (source['Neuron'].isin(neurons))
                        & (source['Condition'].isin(conditions))
                        & (source['Date'].isin(selected_dates))]
    if _dataframe.empty:
        print('No bursts match the current selection.')
        return

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 10))
    sns.set_palette(palette)

    if kind == 'violin':
        # NOTE: seaborn 0.13 deprecates scale/scale_hue/bw (-> density_norm/common_norm/
        # bw_method); they are kept so the cell still runs on older seaborn releases.
        sns.violinplot(x='Condition',
                       y=y_val,
                       data=_dataframe,
                       hue='Date',
                       scale_hue=False,
                       inner='box',
                       scale='width',
                       split=False,
                       ax=ax,
                       bw='scott',
                       cut=0,
                       )
    elif kind == 'strip':
        sns.stripplot(x='Condition',
                      y=y_val,
                      hue='Date',
                      data=_dataframe,
                      dodge=True,
                      size=5,
                      edgecolor='gray',
                      linewidth=0,
                      ax=ax)
    elif kind == 'line':
        # ``kind`` belongs to catplot, not pointplot; seaborn >= 0.13 forwards extra
        # keywords to matplotlib's Line2D, which has no such property, so it is not passed.
        # NOTE: ``ci`` is deprecated in favour of ``errorbar`` since seaborn 0.12.
        sns.pointplot(x='Condition',
                      y=y_val,
                      hue='Date',
                      capsize=.2,
                      ci='sd',
                      data=_dataframe,
                      ax=ax)
    elif kind == 'scatter':
        # Uses the filtered selection (previously the unfiltered global ``data``). The
        # constant ``size=2`` semantic was dropped: in scatterplot ``size`` maps a
        # variable to marker size rather than setting one fixed size.
        sns.scatterplot(x='Burst#',
                        y=y_val,
                        hue='Date',
                        style='Condition',
                        data=_dataframe,
                        edgecolor='black',
                        linewidth=0,
                        alpha=.5,
                        ax=ax)
    else:
        print("Plot kind '{}' is not implemented yet.".format(kind))

    ax.set_ylabel(y_val, fontsize=16)
    # Only the scatter uses burst number on x; the categorical kinds plot Condition.
    ax.set_xlabel('Burst #' if kind == 'scatter' else 'Condition', fontsize=16)
    # Placement is passed through here: a separate ax.legend() call afterwards would
    # rebuild the legend from all handles and bring the duplicates back.
    legend_without_duplicate_labels(ax=ax, bbox_to_anchor=(1.1, 1.05))
    ax.set_title(title.format(y_val, neuron), fontsize=25)
    ax.set_ylim(y_lim)
    plt.tight_layout()

interactive(children=(Dropdown(description='y_val', options=('Burst Duration (sec)', '# of Spikes', 'Spike Fre…