## Graph creator for lab

In [None]:
import os.path
import numpy as np
import pandas as pd
from pandas.api.types import is_numeric_dtype
from scipy.stats import spearmanr, pearsonr, mannwhitneyu, boxcox

import seaborn as sns
np.random.seed(sum(map(ord, "aesthetics")))

import matplotlib.pyplot as plt
%matplotlib inline

from ipywidgets import interact, interactive, Output, HBox, VBox, Dropdown, SelectMultiple, Checkbox, Text
import ipywidgets as widgets

#===========================

def get_filename():
    """Prompt until the user names an existing file, then return that name.

    The entry is resolved against the current working directory; an absolute
    path is accepted as-is. Directories and missing files are rejected and
    the prompt is repeated.

    Returns:
        str: The text exactly as the user typed it (not the resolved path).
    """
    while True:
        value = input('Enter filename: ')
        # os.path.join keeps absolute entries intact and handles separators;
        # plain string concatenation would prepend the cwd to them.
        candidate = os.path.join(os.getcwd(), value)
        if os.path.isfile(candidate):
            return value
        print('Sorry, File not present in folder')
    
#===========================

def plot_hist(data, plot2, fname, save):
    """Draw one distribution plot per selected column, side by side.

    Args:
        data: DataFrame holding the columns to plot.
        plot2: Sequence of column names; one subplot is drawn for each.
        fname: Output path prefix; '+histogram.pdf' is appended when saving.
        save: When truthy, write the figure to disk.

    Does nothing when ``plot2`` is empty.
    """
    ncols = len(plot2)
    if ncols == 0:
        # No columns selected: there is nothing to draw, and the aspect
        # computation below would divide by zero.
        return

    # One row of square-ish panels: height/width ratio of 1/ncols.
    figsize = plt.figaspect(1.0 / ncols)
    # squeeze=False always yields a 2-D axes array, so the single-column case
    # needs no special handling and every plot gets an explicit target axes.
    fig, axes = plt.subplots(1, ncols, figsize=figsize, squeeze=False)
    for column, ax in zip(plot2, axes[0]):
        sns.distplot(data[column], ax=ax)
    fig.tight_layout()
    if save:
        fig.savefig(fname + '+histogram.pdf')
    
#===========================
    
def plot_graph(data, x, kind, fname, hue_by, save):
    """Draw a faceted categorical plot with one column per 'Measure'.

    Args:
        data: Long-format DataFrame with 'Measure' and 'value' columns.
        x: Column shown on the x-axis.
        kind: Plot kind forwarded to ``sns.factorplot``.
        fname: Output path prefix; '.pdf' is appended when saving.
        hue_by: Column used for the hue grouping.
        save: When truthy, write the figure to disk.
    """
    n_hue_levels = len(data[hue_by].unique())

    # Violin plots with exactly two hue levels are drawn as split violins.
    extra_kwargs = {}
    if kind == 'violin' and n_hue_levels == 2:
        extra_kwargs['split'] = True

    grid = sns.factorplot(x=x, y='value', hue=hue_by, col='Measure',
                          size=5, aspect=1, kind=kind, data=data,
                          sharey=False, **extra_kwargs)
    if save:
        grid.savefig(fname + '.pdf')

#===========================

def create_graph(plot_type, group_by, group, hue_by, plot2, stat_fn=None, save=False):
    """Draw the requested plot type for the selected measures and groups.

    Reads the module-level ``data`` DataFrame, drops rows with missing values
    in the selected measures, optionally restricts the rows to the selected
    groups, then draws histograms followed by the chosen plot.

    Args:
        plot_type: One of 'swarmplot', 'boxplot', 'violinplot', 'pairplot',
            'linearModel'. Any other value prints a message and draws nothing.
        group_by: Categorical column used for the x-axis / facet columns.
        group: Selected values of ``group_by``. Rows are filtered only when
            more than one value is selected.
        hue_by: Categorical column used for the hue grouping.
        plot2: Names of the numeric columns (measures) to plot.
        stat_fn: Label passed to the joint plot annotation for 'linearModel'.
        save: When truthy, figures are written as PDFs to the
            'Juplotter_figures' folder under the current working directory.
    """
    measures = list(plot2)
    n = len(measures)
    grp = len(group)

    if n == 0:
        # Nothing to plot; the histogram and plotting calls need a column.
        print('\nSelect at least one variable to plot')
        return

    data2 = data.dropna(subset=measures)
    num_subj = data2.shape[0]
    print('Number of subjects without NaN: ', num_subj)

    # Make figures folder in current directory of the script to save figures
    fig_dir = os.path.join(os.getcwd(), 'Juplotter_figures')
    os.makedirs(fig_dir, exist_ok=True)

    # Output file stem: '+'-separated description of what is being plotted.
    name_parts = [plot_type, 'N' + str(num_subj)]
    name_parts.extend(str(measure) for measure in measures)
    name_parts.append('hueBy-' + str(hue_by))

    # select data for the chosen groups
    if grp > 1:
        # Build the mask from data2 itself so that its index matches the
        # frame being filtered (rows with NaN were already dropped).
        data_chosen_group = data2.loc[data2[group_by].isin(group)]
        print('Number of subjects after group selection: ', data_chosen_group.shape[0])
        name_parts.append('_'.join(map(str, group)))
    else:
        # NOTE(review): a single selected group is not filtered; all rows are
        # used, as with no selection. Confirm this is intended.
        data_chosen_group = data2

    num_subj_counts_per_grp = dict(data_chosen_group[group_by].value_counts())
    print('Number of subjects per group:\n', num_subj_counts_per_grp)
    if grp > 1:
        print('Hue by: ', data_chosen_group[hue_by].unique())

    filename = os.path.join(fig_dir, '+'.join(name_parts))

    # plot_type -> seaborn categorical plot kind
    categorical_kinds = {'swarmplot': 'swarm', 'boxplot': 'box', 'violinplot': 'violin'}

    #------------------------------

    if plot_type in categorical_kinds:
        print('Grouping by: ', group_by, '| Plotting: ', plot2)
        # Long format: one row per (subject, measure) for faceting by measure.
        id_var = data2.select_dtypes(include=['category', 'object']).columns
        data_rearranged = pd.melt(data_chosen_group, id_vars=id_var,
                                  value_vars=measures, var_name='Measure')
        plot_hist(data_chosen_group, measures, filename, save)
        plot_graph(data_rearranged, group_by, categorical_kinds[plot_type],
                   filename, hue_by, save)

    #------------------------------

    elif plot_type == 'pairplot':
        print('Grouping by: ', group_by, '| Plotting: ', plot2)
        if n == 1:
            print('\nPairplot makes sense with at least 2 variables')
        else:
            plot_hist(data_chosen_group, measures, filename, save)
            g = sns.pairplot(data_chosen_group, vars=measures, hue=hue_by,
                             diag_kind='kde', size=4, aspect=1)
            if save:
                g.savefig(filename + '.pdf')

    #------------------------------

    elif plot_type == 'linearModel':
        if n == 1:
            print('\nlinear model requires at least 2 continuous variables')
        else:
            # Only the first two selected measures are used.
            x_var, y_var = measures[0], measures[1]
            print('Grouping by: ', group_by, '| Plotting: ', x_var, ',', y_var)
            plot_hist(data_chosen_group, measures, filename, save)
            # NOTE(review): stat_fn is passed only as the annotation label
            # here; presumably the statistic itself is still whatever
            # jointplot computes by default. Verify against the seaborn
            # version in use.
            lm1 = sns.jointplot(x=x_var, y=y_var, data=data_chosen_group,
                                kind='reg', annot_kws=dict(stat=stat_fn), size=6)
            lm2 = sns.pairplot(data_chosen_group, x_vars=x_var, y_vars=y_var,
                               hue=hue_by, kind='reg', size=6, aspect=1)
            lm3 = sns.lmplot(x=x_var, y=y_var, data=data_chosen_group,
                             col=group_by, hue=hue_by, size=6)

            if save:
                # str() keeps the default stat_fn=None from raising here.
                lm1.savefig(filename + '+stat-' + str(stat_fn) + '+LM1.pdf')
                lm2.savefig(filename + '+LM2.pdf')
                lm3.savefig(filename + '+LM3.pdf')

    #------------------------------

    else:
        print(f"Plot type {plot_type} is not an option")


In [None]:
# main code starts from here

# Input file name and sheet name: edit the two assignments below, or
# uncomment the prompts to ask the user instead.
input_data_file = 'Example_input_data.xlsx'
sheet = 'Sheet1'
#input_data_file = get_filename()
#sheet = input('Enter excel sheet name: ')
data = pd.read_excel(input_data_file, sheet_name=sheet, header=0)
print('Input data shape: ', data.shape)
#data = data.dropna() # missing data rows will be dropped

# A numeric Dx column gets a companion 'Dx_names' column holding readable
# diagnosis labels supplied by the user.
if is_numeric_dtype(data['Dx']):
    dx = data['Dx'].unique()
    print('Current Dx column values: ', set(dx))
    names = input('Enter diagnosis names in ascending order of Dx without spaces\n Eg: For{0, 1, 2} type DX1,Dx2,Dx3 : ' )

    dx_categories = data['Dx'].astype('category')
    # Tolerate stray spaces around the commas.
    dx_labels = [label.strip() for label in names.split(',')]
    n_expected = len(dx_categories.cat.categories)
    if len(dx_labels) != n_expected:
        raise ValueError(
            'Expected {} diagnosis names (one per Dx value) but got {}: {}'.format(
                n_expected, len(dx_labels), dx_labels))
    # Labels are matched to the categories by position.
    # rename_categories is used because assigning to .cat.categories is not
    # supported by current pandas.
    data['Dx_names'] = dx_categories.cat.rename_categories(dx_labels).astype('object')

# Dx is treated as a categorical (object) column from here on.
data['Dx'] = data['Dx'].astype('object')
# NOTE(review): no rows are dropped above (the dropna call is commented out),
# so the wording of this message is inaccurate; the text is left unchanged.
print('Input data shape after dropping missing value rows and adding Dx_names column: ', data.shape)
#data.Dx_names
#data.dtypes

### Interactive Plot

In [None]:
# style to have long descriptions display properly
style = {'description_width': 'initial'}

# Only categorical / text columns can be used for grouping or hue.
categorical_cols = list(data.select_dtypes(include=['category', 'object']).columns)
# Default to the last column of the frame when it is categorical; otherwise
# fall back to the last categorical column, since a Dropdown rejects a value
# that is not one of its options.
if data.columns[-1] in categorical_cols:
    default_cat_col = data.columns[-1]
else:
    default_cat_col = categorical_cols[-1]

# choose what to plot: x-axis
gb = widgets.Dropdown(options=categorical_cols,
                      value=default_cat_col,
                      description='Group by:',
                      disabled=False)

# -------------------------------------------------------------------

# choose what to distinguish: Hue by
hb = widgets.Dropdown(options=categorical_cols,
                      value=default_cat_col,
                      description='Hue by:',
                      disabled=False)

# -------------------------------------------------------------------

# choose plot type
pt = widgets.ToggleButtons(options=['boxplot', 'swarmplot', 'violinplot', 'pairplot', 'linearModel'],
                           value='boxplot',
                           description='Plot type:',
                           disabled=False)

# -------------------------------------------------------------------

# When the group-by column changes, refresh the selectable group values
def on_change(change):
    """Reload the 'Group:' options from the newly chosen group-by column."""
    is_value_change = change['type'] == 'change' and change['name'] == 'value'
    if not is_value_change:
        return
    gb_val.options = data[change['new']].unique()

gb.observe(on_change)

# choose group (initially nothing is selected)
#value=tuple(data[gb.value].unique()),
gb_val = widgets.SelectMultiple(
    options=data[gb.value].unique(),
    description='Group:',
    disabled=False,
)

# -------------------------------------------------------------------

# choose what to plot: y-axis columns
# show only float and int value columns
cols = data.select_dtypes(exclude=['category','object']).columns
# Preselect the last two numeric columns (last one first). Slicing instead of
# indexing keeps this valid when fewer than two numeric columns exist.
p = widgets.SelectMultiple(options=cols,
                            value=list(cols[-2:][::-1]),
                            description='Plot:',
                            disabled=False)

# -------------------------------------------------------------------

# save figures only on button click
sf = widgets.ToggleButton(value=False, description="Save figure")

# NOTE: this rebinds the name `on_change`; the handler already registered on
# `gb` keeps its own reference to the earlier function.
def on_change(change):
    """Set the save toggle back to False whenever a change on it is observed."""
    # presumably makes the toggle act as a one-shot trigger rather than a
    # persistent "always save" switch -- TODO confirm
    sf.value = False

# Registered without a `names` filter, so it is called for every observed
# change on the button, not only changes of `value`.
sf.observe(on_change)

# -------------------------------------------------------------------

# choose the stats function name passed to create_graph for the linear model
stat_fn = widgets.RadioButtons(
    description='Stats function for linear model:',
    options=['pearsonr', 'spearmanr', 'mannwhitneyu', 'boxcox'],
    value='pearsonr',
    style=style,
    disabled=False,
)

# -------------------------------------------------------------------

# Wire every widget to the matching create_graph argument; the plot is
# redrawn whenever one of them changes.
interactive_plot = interactive(
    create_graph,
    plot_type=pt,
    group_by=gb,
    group=gb_val,
    hue_by=hb,
    plot2=p,
    save=sf,
    stat_fn=stat_fn,
)

# avoid flickering: give the output area (last child) a fixed height
# NOTE(review): 'dpi' is not a CSS length unit, so this height is presumably
# ignored by the browser -- confirm the intended value (e.g. pixels).
output = interactive_plot.children[-1]
output.layout.height = '300dpi'
interactive_plot