In [None]:
import os
import sys

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats
import seaborn as sns
import umap
from matplotlib import cm
from matplotlib.gridspec import GridSpec
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.cluster import hierarchy
from scipy.signal import savgol_filter
from scipy.stats import tukey_hsd
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.preprocessing import StandardScaler

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
!ls

In [None]:


np.seterr(divide='ignore')

# Set figure export properties
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['font.sans-serif'] = "Arial"
matplotlib.rcParams['font.family'] = "sans-serif"

# Set which plots to generate
generate_level_plots = True
generate_size_hist_plots = True
generate_GABA_content_plots = True
generate_scatter_plots = True
generate_umap = True

# Save plots as PDFs. Alternatively leave them open for display
save_plots = False

# Filenames to grab data
fdir = '' #r'Z:\Data\AT_Image_Data\KDMSYN201228_F002_IF'
fname = 'KDMSYN201228_F002_IF_exc_intens_dil_0.csv'

cmap = plt.get_cmap('cool') # Colormap for scatter plots
num_size_bins = 7   # Number of bins to split into for size dependence plots
num_clusts = 4 # Number of clusters to cut (selected from Silhouette analysis)
sum_or_avg = 'avg' # Use sum or average intensity values for synapses

# Set exclustion criteria for outliers based on size or intensities
excl_syn_size = 250 # Exclude synapses that are absurdly large assuming excessive merging
q1_trim = 0.05
q3_trim = 0.95
quartile_mult = 1.5

# Set parameters for scatter plots
num_hist_bins = 100 # Number of histogram bins for scatter plots
smooth_len = 10     # Smoothing of lines for scatter plots
alpha_val = 0.5 # Transparency value for Scatter Plots
size_scale = 5 # Size scaling of points for Scatter Plots
gaba_cutoff = 0.98  # Cutoff to identify synapses onto GABAergic targets

# Key for resetting cluster numbers (colors) to match conjugate dataset
clust_conj = [0,1,2,3]
clust_if   = [3,2,1,0]


use_stains = ['num_pixels', 
             'PSD95_' + sum_or_avg,
             'Synapsin12_' + sum_or_avg,
             'GluA1_' + sum_or_avg, 
             'GluA2_' + sum_or_avg, 
             'GluA3_' + sum_or_avg, 
             'GluA4_' + sum_or_avg, 
             'GluN1_' + sum_or_avg, 
             'GluN2B_' + sum_or_avg, 
             'GABA_dil2_' + sum_or_avg, 
             'Gephyrin_' + sum_or_avg]

# List of channels to plot against one another for scatter plots
plt_list = [['GluA2_' + sum_or_avg, 'GluA1_' + sum_or_avg],
            ['GluA2_' + sum_or_avg, 'GluA3_' + sum_or_avg],
            ['GluA2_' + sum_or_avg, 'GluA4_' + sum_or_avg],
            ['GluA1_' + sum_or_avg, 'GluA3_' + sum_or_avg],
            ['GluA2_' + sum_or_avg, 'GluN1_' + sum_or_avg],
            ['GluN1_' + sum_or_avg, 'GluN2B_' + sum_or_avg],
            ['GluA2_' + sum_or_avg, 'GluN2B_' + sum_or_avg]]

# Plot each of these channel intensities against: (1) synapse size distribution; (2) top 10% GABA
size_hist_plots = ['GluA1_' + sum_or_avg, 'GluA2_'+ sum_or_avg, 'GluA3_'+ sum_or_avg, 'GluA4_'+ sum_or_avg, 
                   'num_pixels', 'GluN1_'+ sum_or_avg, 'GluN2B_'+ sum_or_avg, 
                   'PSD95_'+ sum_or_avg, 'GABA_dil2_'+ sum_or_avg, 'anratio_' + sum_or_avg]

# List of channels to include in UMAP plot
umap_list = ['num_pixels',
             'PSD95_' + sum_or_avg, 
             'GluA1_' + sum_or_avg, 
             'GluA2_' + sum_or_avg, 
             'GluA3_' + sum_or_avg, 
             'GluA4_' + sum_or_avg, 
             'GluN1_' + sum_or_avg, 
             'GluN2B_' + sum_or_avg]

# Channels to group together for AMPA and NMDA combined
ampa_comb_list = ['GluA1','GluA2','GluA3','GluA4']
nmda_comb_list = ['GluN1','GluN2B']
scatt_size_label = 'num_pixels'
num_pix_label = 'num_pixels'
gaba_cut_label = 'GABA_dil2_'+ sum_or_avg

# Read in CSV values
csv_vals = pd.read_csv(fdir + fname)
csv_vals = csv_vals[csv_vals[num_pix_label] <= excl_syn_size]

# Grab the appropriate list of shuffled data channels
shuffle_list = []
for stain_ct, stain_val in enumerate(use_stains):
    if '_avg' in stain_val:
        shuffle_list.append(stain_val.replace('avg','shuffleavg'))
    if '_sum' in stain_val:
        shuffle_list.append(stain_val.replace('sum','shufflesum'))
use_stains = use_stains + shuffle_list    

# Function for trimming outliers
def trim_quartile(x_vals, q1_trim, q3_trim, quartile_mult):
    x_vals = pd.DataFrame(x_vals).copy()
    q1 = x_vals.quantile(q1_trim)
    q3 = x_vals.quantile(q3_trim)
    iqr = q3 - q1
    lb = q1 - (quartile_mult * iqr)
    ub = q3 + (quartile_mult * iqr)
    bool_vals = np.array((x_vals >= lb) & (x_vals <= ub))[:,0].copy()
    x_vals = x_vals[bool_vals].copy()
    return x_vals, bool_vals

###############################################################################
# Combine values for AMPA and NMDA receptors to calculate AMPA/NMDA ratio     #
###############################################################################

# Initialize arrays for combined AMPA and NMDA ratios
csv_vals['ampa_comb_avg'] = np.zeros(len(csv_vals))
csv_vals['ampa_comb_sum'] = np.zeros(len(csv_vals))
csv_vals['ampa_comb_shuffleavg'] = np.zeros(len(csv_vals))
csv_vals['ampa_comb_shufflesum'] = np.zeros(len(csv_vals))
csv_vals['nmda_comb_avg'] = np.zeros(len(csv_vals))
csv_vals['nmda_comb_sum'] = np.zeros(len(csv_vals))
csv_vals['nmda_comb_shuffleavg'] = np.zeros(len(csv_vals))
csv_vals['nmda_comb_shufflesum'] = np.zeros(len(csv_vals))
# Calculate combined AMPA and NMDA ratios
for this_ct,this_chan in enumerate(ampa_comb_list):
    csv_vals['ampa_comb_avg'] = csv_vals['ampa_comb_avg'] + (csv_vals[this_chan + '_avg'] / len(ampa_comb_list))
    csv_vals['ampa_comb_sum'] = csv_vals['ampa_comb_sum']  + (csv_vals[this_chan + '_sum'] / len(ampa_comb_list))
    csv_vals['ampa_comb_shuffleavg'] = csv_vals['ampa_comb_shuffleavg'] + (csv_vals[this_chan + '_shuffleavg'] / len(ampa_comb_list))
    csv_vals['ampa_comb_shufflesum'] = csv_vals['ampa_comb_shufflesum'] + (csv_vals[this_chan + '_shufflesum'] / len(ampa_comb_list))
for this_ct,this_chan in enumerate(nmda_comb_list):
    csv_vals['nmda_comb_avg'] = csv_vals['nmda_comb_avg'] + (csv_vals[this_chan + '_avg'] / len(nmda_comb_list))
    csv_vals['nmda_comb_sum'] = csv_vals['nmda_comb_sum']  + (csv_vals[this_chan + '_sum'] / len(nmda_comb_list))
    csv_vals['nmda_comb_shuffleavg'] = csv_vals['nmda_comb_shuffleavg'] + (csv_vals[this_chan + '_shuffleavg'] / len(nmda_comb_list))
    csv_vals['nmda_comb_shufflesum'] = csv_vals['nmda_comb_shufflesum'] + (csv_vals[this_chan + '_shufflesum'] / len(nmda_comb_list))
# Story combined AMPA and NMDA ratios in the larger dataset
csv_vals['anratio_avg'] = csv_vals['ampa_comb_avg'] / csv_vals['nmda_comb_avg']
csv_vals['anratio_sum'] = csv_vals['ampa_comb_sum'] / csv_vals['nmda_comb_sum']
csv_vals['anratio_shuffleavg'] = csv_vals['ampa_comb_shuffleavg'] / csv_vals['nmda_comb_shuffleavg']
csv_vals['anratio_shufflesum'] = csv_vals['ampa_comb_shufflesum'] / csv_vals['nmda_comb_shufflesum']
csv_vals[np.isnan(csv_vals)] = 0

# Drop all columns not in use
csv_vals_trunc = csv_vals.copy()[use_stains]

###############################################################################
#      Plot the average values for each channel in box (or violin) plot       #
###############################################################################

# Create a DataFrame for plotting average values
syn_avg_df = pd.DataFrame()
stain_list = [[] for _ in range(len(csv_vals_trunc))]
for stain_ct, stain_val in enumerate(csv_vals_trunc.columns):
    if ('_avg' in stain_val) or ('_sum' in stain_val):
        stain_val_shuffle = stain_val.replace('_avg','_shuffleavg').replace('_sum','_shufflesum')
        stain_val_norm = stain_val.replace('_avg','_normavg').replace('_sum','_normsum')
        for syn_ct, syn_val in enumerate(stain_list):
            stain_list[syn_ct] = stain_val
        syn_avg_df = pd.concat([syn_avg_df,pd.DataFrame({'Stain':stain_list,
                                                         'Fluor':csv_vals_trunc[stain_val],
                                                         'Shuff':csv_vals_trunc[stain_val_shuffle],
                                                         'Norm':csv_vals_trunc[stain_val] - np.mean(csv_vals_trunc[stain_val_shuffle])})],
                               ignore_index=True)

if generate_level_plots:
    # Box plot average intensities
    plt_fig = plt.figure(figsize=[12,6])
    bar_ax = plt_fig.add_subplot(121)
    box_ax = plt_fig.add_subplot(122)
    plt_fig_nsub = plt.figure(figsize=[12,6])
    bar_ax_nsub = plt_fig_nsub.add_subplot(121)
    box_ax_nsub = plt_fig_nsub.add_subplot(122)
    box_df      = pd.DataFrame()
    box_df_trim = pd.DataFrame()
    sub_vals      = np.zeros(len(np.unique(syn_avg_df['Stain'])))
    sub_vals_trim = np.zeros(len(np.unique(syn_avg_df['Stain'])))
    for stain_ct, stain_val in enumerate(np.unique(syn_avg_df['Stain'])):
        # Build list for y-axis values and labels
        y_vals = syn_avg_df['Fluor'][syn_avg_df['Stain']==stain_val]
        if len(y_vals) > 0:
            # Identify shuffled intensities
            y_vals_shuff = syn_avg_df['Shuff'][syn_avg_df['Stain']==stain_val]
            # Trim outliers
            y_vals_trim,       trim_bool       = trim_quartile(y_vals,       q1_trim, q3_trim, quartile_mult)
            y_vals_shuff_trim, shuff_trim_bool = trim_quartile(y_vals_shuff, q1_trim, q3_trim, quartile_mult)
            
            # Subtract background (shuffled) intensities
            sub_vals[stain_ct]      = np.mean(y_vals_shuff)
            sub_vals_trim[stain_ct] = np.mean(y_vals_shuff_trim)
            y_vals_sub      = y_vals      - sub_vals[stain_ct]
            y_vals_sub_trim = y_vals_trim - sub_vals_trim[stain_ct]
            
            # Build list for y-axis values and labels
            y_labels      = []
            y_labels_trim = []
            for y_ct, y_val in enumerate(y_vals):
                y_labels.append(stain_val)
            for y_ct, y_val in enumerate(y_vals[trim_bool]):
                y_labels_trim.append(stain_val)
            # Append values to DataFrame
            box_df = pd.concat([box_df, pd.DataFrame(data={'Labels':y_labels,
                                                           'Fluor':y_vals,
                                                           'Shuff':y_vals_shuff,
                                                           'Fluor_sub':y_vals_sub})],ignore_index=True)
            box_df_trim = pd.concat([box_df_trim, pd.DataFrame(data={'Labels':    y_labels_trim,
                                                                     'Fluor':     y_vals_trim.squeeze(),
                                                                     'Shuff':     y_vals[trim_bool].squeeze(),
                                                                     'Fluor_sub': y_vals_sub_trim.squeeze()})],
                                    ignore_index=True)
            
            # Plot with errorbars
            bar_ax.bar(stain_ct, np.mean(y_vals_sub_trim,axis=0),edgecolor='k',facecolor='none')
            bar_ax.errorbar(stain_ct, np.mean(y_vals_sub_trim,axis=0),scipy.stats.sem(y_vals_sub_trim),
                            marker='.',markersize=10,markerfacecolor='none',markeredgecolor='none',
                            linestyle='none',color='k')
            
            bar_ax_nsub.bar(stain_ct, np.mean(y_vals_trim,axis=0),edgecolor='k',facecolor='none')
            bar_ax_nsub.errorbar(stain_ct, np.mean(y_vals_trim,axis=0),scipy.stats.sem(y_vals_trim),
                                 marker='.',markersize=10,markerfacecolor='none',markeredgecolor='none',
                                 linestyle='none',color='k')
    
    # Generate boxplot
    sns.stripplot(data=box_df_trim,x='Labels',y='Fluor_sub',color='k',edgecolor='none',size=.5,ax=box_ax)
    sns.boxplot(data=box_df_trim,x='Labels',y='Fluor_sub',fliersize=0,whis=0,ax=box_ax,color='w',showmeans=True,
                meanprops={'marker':'+','markerfacecolor':'w','markeredgecolor':'r','markersize':20})
    
    
    sns.stripplot(data=box_df_trim,x='Labels',y='Fluor',color='k',edgecolor='none',size=.5,ax=box_ax_nsub)
    sns.boxplot(data=box_df_trim,x='Labels',y='Fluor',fliersize=0,whis=0,ax=box_ax_nsub,color='w',showmeans=True,
                meanprops={'marker':'+','markerfacecolor':'w','markeredgecolor':'r','markersize':20})
    
    # Set axis properties
    bar_ax.set(ylabel='Fluorescence (BL Sub)',xlabel='Stain',xticks=np.arange(0,len(np.unique(syn_avg_df['Stain']))),
               xticklabels=np.unique(syn_avg_df['Stain']),title='Baseline Subtracted')
    bar_ax.set_xticklabels(bar_ax.get_xticklabels(), rotation=45, horizontalalignment='right')
    box_ax.set(ylabel='Fluorescence (BL Sub)',yscale='linear',title='Baseline Subtracted',
               xlabel='Stain',xticks=np.arange(0,len(np.unique(syn_avg_df['Stain']))),
               xticklabels=np.unique(syn_avg_df['Stain']))
    box_ax.set_xticklabels(box_ax.get_xticklabels(), rotation=45, horizontalalignment='right')
    bar_ax_nsub.set(ylabel='Fluorescence',xlabel='Stain',xticks=np.arange(0,len(np.unique(syn_avg_df['Stain']))),
                    xticklabels=np.unique(syn_avg_df['Stain']),title='Raw Values')
    bar_ax_nsub.set_xticklabels(bar_ax.get_xticklabels(), rotation=45, horizontalalignment='right')
    box_ax_nsub.set(ylabel='Fluorescence',yscale='linear',title='Raw Values',
               xlabel='Stain',xticks=np.arange(0,len(np.unique(syn_avg_df['Stain']))),
               xticklabels=np.unique(syn_avg_df['Stain']))
    box_ax_nsub.set_xticklabels(box_ax.get_xticklabels(), rotation=45, horizontalalignment='right')
    
    plt.tight_layout()
    
    if save_plots:
        plt_fig_nsub.savefig(fdir + '/if_label_nsublevels.pdf')
        plt.close(plt_fig_nsub)
        plt_fig.savefig(fdir + '/if_label_levels.pdf')
        plt.close(plt_fig)

###############################################################################
#      Plot the values binned by synapse size                                 #
###############################################################################

if generate_size_hist_plots:
    # Initialize histogram bins
    x_vals = np.log(np.array(csv_vals[num_pix_label]))
    hist_bins = np.linspace(np.min(x_vals),np.max(x_vals),num_size_bins + 1)
        
    # Generate figure
    bar_fig = plt.figure(figsize=[19,8])
    
    # Cycle through each plot to generate
    for plt_ct, plt_label in enumerate(size_hist_plots):
        # Select axes in each plot
        bar_ax = bar_fig.add_subplot(2,5,plt_ct + 1)
        
        # Grab values for this channel and remove outliers
        y_vals = np.array(csv_vals[plt_label])
        y_vals_trim, trim_bool = trim_quartile(y_vals, q1_trim, q3_trim, quartile_mult)
        x_vals_trim = x_vals[trim_bool]
        
        # Initialize bin values
        bin_nums = np.ones(len(x_vals_trim)) * -1
        num_syn = np.zeros(len(hist_bins) - 1)
        bin_labels = []
        
        # Calculate histogram values
        for bin_ct, this_bin in enumerate(hist_bins[0:-1]):
            # Identify synapses falling within this histogram bin
            these_syn = (x_vals_trim >= hist_bins[bin_ct]) & (x_vals_trim < hist_bins[bin_ct + 1])
            num_syn[bin_ct] = np.sum(these_syn)
            bin_nums[these_syn] = bin_ct
            bin_labels.append('{:.0f}'.format(np.exp(hist_bins[bin_ct])) + ' - ' + '{:.0f}'.format(np.exp(hist_bins[bin_ct + 1])))
            
            # Build list of bin labels
            bar_ax.bar(bin_ct, np.mean(y_vals_trim[bin_nums==bin_ct],axis=0),edgecolor='k',facecolor='none')
            bar_ax.errorbar(bin_ct, np.mean(y_vals_trim[bin_nums==bin_ct],axis=0), 
                            scipy.stats.sem(y_vals_trim[bin_nums==bin_ct]),
                            marker='.',markersize=10,markerfacecolor='none',markeredgecolor='none',
                            linestyle='none',color='k')
        
        bar_ax.set(title=plt_label,xlabel='Synapse Size',ylabel=plt_label,
                   xticks=np.arange(0,len(bin_labels)),xticklabels=bin_labels)
        bar_ax.set_xticklabels(bar_ax.get_xticklabels(), rotation=45, horizontalalignment='right')
    
    plt.tight_layout()
    
    # Generate figure for average intensity correlations
    ecdf_fig = plt.figure(figsize=[19,8])
    
    # Grab values for histogram
    x_vals = np.log(np.array(csv_vals[num_pix_label]))
    
    # Set up histogram bins
    hist_bins = np.linspace(np.min(x_vals), np.max(x_vals), num_size_bins)
    
    # Pick a color scheme for plotting
    color_vals = np.linspace(0,1,num_size_bins)
    color_data = cmap(color_vals)
    
    # Initialize DataFrame for cumulative distributions
    ecdf_df = pd.DataFrame()
    
    # Cycle through each plot to generate
    for plt_ct, plt_label in enumerate(size_hist_plots):
        # Generate axis
        ecdf_ax = ecdf_fig.add_subplot(2,5,plt_ct + 1)
        
        # Convert values to an array
        y_vals = np.array(csv_vals[plt_label])
        
        # Trim outliers
        y_vals_trim, trim_bool = trim_quartile(y_vals, q1_trim, q3_trim, quartile_mult)
        x_vals_trim = x_vals[trim_bool]
        
        # Initialize arrays
        bin_nums = np.ones(len(x_vals_trim)) * -1
        num_syn = np.zeros(len(hist_bins) - 1)
        bin_labels = []
        
        # Cycle through histogram bins to identify & count synapses
        for bin_ct, this_bin in enumerate(hist_bins[0:-1]):
            # Fins synapses belonging to this bin
            these_syn = (x_vals_trim >= hist_bins[bin_ct]) & (x_vals_trim < hist_bins[bin_ct + 1])
            num_syn[bin_ct] = np.sum(these_syn) # Store in array
            bin_nums[these_syn] = bin_ct
            bin_labels.append('{:.0f}'.format(np.exp(hist_bins[bin_ct])) + ' - ' + '{:.0f}'.format(np.exp(hist_bins[bin_ct + 1])))
            
            # Generate plot
            sns.ecdfplot(data=np.array(y_vals_trim[bin_nums == bin_ct]).ravel(),color=color_data[bin_ct,0:3],
                         label='Bin ' + str(bin_ct),ax=ecdf_ax)

        ecdf_ax.set(title=plt_label,xlabel='Synapse Size',ylabel=plt_label,xscale='log') # Set axis title
    
    plt.tight_layout()
    
    # Plot number of synapses per bin
    plt_fig = plt.figure()
    plt_ax = plt_fig.add_subplot(111)
    plt_ax.bar(np.arange(0,len(bin_labels)), num_syn,facecolor='w',edgecolor='k')
    plt_ax.set(title='Number of Synapses',xlabel='Synapse size bin',ylabel='Number of Synapses',
               xticks=np.arange(0,len(bin_labels)),xticklabels=bin_labels)
    plt_ax.set_xticklabels(plt_ax.get_xticklabels(), rotation=45, horizontalalignment='right')
    plt.tight_layout()
    
    if save_plots:
        save_fname = fdir + '/if_synsize_values.pdf'
        bar_fig.savefig(save_fname)
        plt.close(bar_fig)
        save_fname = fdir + '/if_synsize_cdf.pdf'
        ecdf_fig.savefig(save_fname)
        plt.close(ecdf_fig)
        save_fname = fdir + '/if_synsize_numsyns.pdf'
        plt_fig.savefig(save_fname)
        plt.close(plt_fig)
    
    

###############################################################################
#      Plot the values binned by GABA content                                 #
###############################################################################

if generate_GABA_content_plots:
    x_vals = np.array(csv_vals[gaba_cut_label])
    x_vals_trim, trim_bool_x = trim_quartile(x_vals, q1_trim, q3_trim, quartile_mult)
    gaba_thresh = np.sort(np.array(x_vals_trim[list(x_vals_trim)[0]]))[int(gaba_cutoff * len(x_vals_trim))]
    plt_fig = plt.figure(figsize=[18,7])
    for plt_ct, plt_label in enumerate(size_hist_plots):
        
        plt_ax = plt_fig.add_subplot(2,5,plt_ct + 1)
        y_vals = csv_vals[plt_label]
        y_vals_trim, trim_bool_y = trim_quartile(y_vals, q1_trim, q3_trim, quartile_mult)
        y_vals_bot = np.array(y_vals[(x_vals <  gaba_thresh) & trim_bool_x & trim_bool_y])
        y_vals_top = np.array(y_vals[(x_vals >= gaba_thresh) & trim_bool_x & trim_bool_y])
        
        sns.ecdfplot(data=y_vals_bot,ax=plt_ax,color='b',label='Low GABA')
        sns.ecdfplot(data=y_vals_top,ax=plt_ax,color='r',label='High GABA')
        res = tukey_hsd(y_vals_bot,y_vals_top)
        plt_ax.set(xscale='log',title=plt_label + ' -- P=' + '{:.2e}'.format(res.pvalue[0,1]))
        plt_ax.legend()
    plt.tight_layout()
    
    if save_plots:
        save_fname = fdir + '/if_gaba_cut' + str(gaba_cutoff) + '.pdf'
        plt_fig.savefig(save_fname)
        plt.close(plt_fig)



###############################################################################
#        Plot correlations across channels at single-synapse level            #
###############################################################################

if generate_scatter_plots:
    # Gerate figure for average intensity correlations
    color_vals = np.log(csv_vals[scatt_size_label])
    size_vals = csv_vals[scatt_size_label]
    norm = plt.Normalize(np.min(color_vals), np.max(color_vals))
    color_data = cmap(norm(color_vals))
    
    # Loop through each plot to generate
    for plt_ct in range(0,len(plt_list)):
        # Create the figure
        scat_fig = plt.figure(figsize=[11,8])
        gs = GridSpec(5,6,figure=scat_fig)
        scat_ax = scat_fig.add_subplot(gs[0:4,0:4])
        hist_ax0 = scat_fig.add_subplot(gs[4,0:4])
        hist_ax1 = scat_fig.add_subplot(gs[0:4,4])
        cax = scat_fig.add_subplot(gs[0:4,5])
        
        # Calculate values separately for channels 
        x_vals = np.log(csv_vals[plt_list[plt_ct][0]])
        x_vals_shuffle = np.log(csv_vals[plt_list[plt_ct][0].replace('_avg','_shuffleavg').replace('_sum','_shufflesum')])
        y_vals = np.log(csv_vals[plt_list[plt_ct][1]])
        y_vals_shuffle = np.log(csv_vals[plt_list[plt_ct][1].replace('_avg','_shuffleavg').replace('_sum','_shufflesum')])
        
        trim_vals_x, trim_bool_x = trim_quartile(x_vals, q1_trim, q3_trim, quartile_mult)
        trim_vals_y, trim_bool_y = trim_quartile(y_vals, q1_trim, q3_trim, quartile_mult)
        plot_pts = (trim_bool_x) & (trim_bool_y) & (~np.isinf(x_vals) & (~np.isinf(y_vals)))
        
        # Set axis limits depending on the scaling of the data
        xlv = [np.median(x_vals[plot_pts])-(2*np.std(x_vals[plot_pts])),
               np.median(x_vals[plot_pts])+(2.5*np.std(x_vals[plot_pts]))]
        ylv = [np.median(y_vals[plot_pts])-(2*np.std(y_vals[plot_pts])),
               np.median(y_vals[plot_pts])+(2.5*np.std(y_vals[plot_pts]))]
            
        # Scatterplot the data
        scat_im = scat_ax.scatter(x_vals[plot_pts], y_vals[plot_pts], c=np.array(color_data[plot_pts,:]), 
                                  alpha = alpha_val,
                                  edgecolors='none',marker='o', s=np.array(size_vals[plot_pts])/size_scale)
        scat_im.set(cmap=cmap,clim=[np.min(color_vals, axis=0),np.max(color_vals, axis=0)])
        scat_fig.colorbar(scat_im, cax=cax, orientation='vertical')
        scat_ax.set(xlim=xlv,ylim=ylv)
        
        # Generate histogram for x-axis
        hist_vals0, hist_bins0 = np.histogram(x_vals[x_vals > 0], bins=num_hist_bins)
        hist_vals0_shuffle, hist_bins0_shuffle = np.histogram(x_vals_shuffle[x_vals_shuffle > 0], bins=num_hist_bins)
        
        # Generate histogram for y-axis
        hist_vals1, hist_bins1 = np.histogram(y_vals[y_vals > 0], bins=num_hist_bins)
        hist_vals1_shuffle, hist_bins1_shuffle = np.histogram(y_vals_shuffle[y_vals_shuffle > 0], bins=num_hist_bins)
    
        # Plot histogram for x-axis
        hist_ax0.bar(hist_bins0[1:] - (np.mean(np.diff(hist_bins0)) / 2), hist_vals0,
                      width=(hist_bins0[-1] - hist_bins0[0]) / num_hist_bins, facecolor='silver',edgecolor='none')
        hist_ax0.plot(hist_bins0[1:] - (np.mean(np.diff(hist_bins0)) / 2),savgol_filter(hist_vals0,smooth_len, 2),
                      color='k',linewidth=1)
        hist_ax0.plot(hist_bins0_shuffle[1:] - (np.mean(np.diff(hist_bins0_shuffle)) / 2),
                      savgol_filter(hist_vals0_shuffle,smooth_len, 2),color=[0.5,0.5,1],linewidth=1,linestyle='--')
        hist_ax0.set(xlim=scat_ax.get_xlim(),xlabel=plt_list[plt_ct][0],
                     ylim=[np.max(np.concatenate([hist_vals0,hist_vals0_shuffle])),0])
    
        # Plot histogram for y-axis
        hist_ax1.barh(hist_bins1[1:] - (np.mean(np.diff(hist_bins1)) / 2), hist_vals1,
                      height=(hist_bins1[-1] - hist_bins1[0]) / num_hist_bins,facecolor='silver',edgecolor='none')
        hist_ax1.plot(savgol_filter(hist_vals1,smooth_len, 2),hist_bins1[1:] - (np.mean(np.diff(hist_bins1)) / 2),
                      color='k',linewidth=1)
        hist_ax1.plot(savgol_filter(hist_vals1_shuffle,smooth_len, 2),
                      hist_bins1_shuffle[1:] - (np.mean(np.diff(hist_bins1_shuffle)) / 2),
                      color=[0.5,0.5,1],linewidth=1,linestyle='--')
        hist_ax1.set(ylim=scat_ax.get_ylim(),ylabel=plt_list[plt_ct][1],
                     xlim=[0,np.max(np.concatenate([hist_vals1,hist_vals1_shuffle]))])
        hist_ax1.yaxis.set_label_position("right")
        hist_ax1.yaxis.tick_right()
    
        # Calculate the cutoff based on the mean and standard deviation of the shuffled ata
        x_cut = np.mean(x_vals_shuffle[x_vals_shuffle > 0]) + 1.0 * np.std(x_vals_shuffle[x_vals_shuffle > 0])
        y_cut = np.mean(y_vals_shuffle[y_vals_shuffle > 0]) + 1.0 * np.std(y_vals_shuffle[y_vals_shuffle > 0])
        
        # Add the cutoff line to the plot
        scat_ax.plot([x_cut,x_cut],ylv,color='b',linewidth=1,linestyle='--')
        scat_ax.plot(xlv,[y_cut,y_cut],color='b',linewidth=1,linestyle='--')
        
        # Set plot parameters
        scat_ax.set(title=sum_or_avg + ' - ' + plt_list[plt_ct][0] + ' vs. ' + plt_list[plt_ct][1])
        plt.setp(scat_ax.get_xticklabels(), visible=False)
        plt.setp(scat_ax.get_yticklabels(), visible=False)
        plt.tight_layout()
        
        if save_plots:
            save_fname = fdir + '/if_scatter_' + plt_list[plt_ct][0] + '_' + plt_list[plt_ct][1] + '.pdf'
            scat_fig.savefig(save_fname)
            plt.close(scat_fig)
        
        
###############################################################################
#             GENERATE THE UMAP PLOT BY COLOR                                 #
###############################################################################


if generate_umap:
    # https://www.datacamp.com/tutorial/introduction-hierarchical-clustering-python
    sys.stdout.write("\rGenerating UMAP                                          ")
    
    # Generate copy of data to work on
    clust_vals = csv_vals[umap_list].copy()
    
    # Log transform cluster values
    for col_name in umap_list:
        if col_name == 'num_pixels':
            clust_vals[col_name] = np.sqrt(clust_vals[col_name])
        else:
            clust_vals[col_name] = np.log(clust_vals[col_name])
            
        clust_vals[col_name][clust_vals[col_name] == -np.inf] = 0
    
    # Trim outliers
    for col_name in umap_list:
        trim_vals, trim_bool = trim_quartile(clust_vals[col_name], q1_trim, q3_trim, quartile_mult)
        clust_vals = clust_vals[trim_bool]
        
    # Apply Standard Scaler for unbiased clustering
    clust_vals_scaled = StandardScaler().fit_transform(clust_vals)
    clust_vals = pd.DataFrame(data=clust_vals_scaled,columns=list(clust_vals))
        
    # Generate a separate copy of the data for UMAP analysis
    umap_vals  = clust_vals[umap_list].copy()
        
    # Perform clustering analysis
    cluster_grid = sns.clustermap(data=(clust_vals),method='ward',metric='euclidean',figsize=[8,8],cmap='vlag')
    plt_fig = plt.gcf()  # Close the figure that was generated by clustering analysis
    plt.close(plt_fig)
    
    # Grab the re-ordered index values from the cluster object
    index_list     = cluster_grid.dendrogram_row.reordered_ind
    index_list_col = cluster_grid.dendrogram_col.reordered_ind
    Z = cluster_grid.dendrogram_row.linkage # get the linkage matrix
    X = cluster_grid.dendrogram_col.linkage # get the linkage matrix
    cut_tree_val = scipy.cluster.hierarchy.cut_tree(Z) # cut the tree; leftmost: leaves (each gene is in its own cluster)
    # rightmost: top node of tree (each gene is in the same parent cluster)

    # Run Silhouette analysis to determine appropriate number of clusters
    range_n_clusters = np.arange(2,20)
    silhouette_avg = []
    for n_clust in np.arange(2,20):
        clusterer = KMeans(n_clusters=n_clust, random_state=10)
        cluster_labels = clusterer.fit_predict(clust_vals)
        silhouette_avg.append(silhouette_score(clust_vals, cluster_labels))
        
    # Generate a plot for the Silhouette analysis
    i = num_clusts # choose what level down to cut tree
    sil_fig = plt.figure(figsize=[6,6])
    sil_ax = sil_fig.add_subplot(111)
    sil_ax.bar(range_n_clusters,silhouette_avg,facecolor='none',edgecolor='k')
    sil_ax.set(xticks=range_n_clusters,xlabel='N Clusters',ylabel='Silhouette Average Score',
               title='Setting ' + str(i) + ' clusters')
    
    # Cut the cluster tree at the appropriate point for set number of clusters
    clusters = cut_tree_val[index_list,-i]
    
    clusters_old = clusters.copy()
    clusters_new = clusters.copy()
    for clust_ind, clust_num in enumerate(clusters):
        clusters_new[clust_ind] = clust_conj[clust_if.index(clust_num)]
    clusters = clusters_new
    
    clust_vals.index = np.arange(0,len(clust_vals.index))
    clust_vals_sort = clust_vals.copy()
    
    # Sort columns
    clust_vals_sort = clust_vals_sort[np.array(list(clust_vals_sort))[index_list_col]] 
    
    # Reorganize the spine list IDs to correspond to index values in clust_vals
    clusters_unsorted = np.zeros(len(clust_vals_sort.index))
    for ind_ct, ind_val in enumerate(index_list):
        clusters_unsorted[ind_val] = clusters[ind_ct]
        clust_vals_sort.loc[ind_ct,:] = clust_vals.loc[ind_val,:]
        
    # Generate a figure to show clustered data
    clust_fig = plt.figure(figsize=[11,9])
    gs = GridSpec(5,8,figure=clust_fig)
    dend_ax2 = clust_fig.add_subplot(gs[0,3:8])
    plt_ax   = clust_fig.add_subplot(gs[1:4,3:8])
    clust_ax = clust_fig.add_subplot(gs[1:4,2])
    dend_ax1 = clust_fig.add_subplot(gs[1:4,0:2])
    cbar_ax  = clust_fig.add_subplot(gs[4,3:8])
    
    # Plot the dendrograms showing cluster heirarchies
    hierarchy.set_link_color_palette(['silver','k'])
    hierarchy.dendrogram(Z,ax=dend_ax1,orientation='left',color_threshold=12,#p=p,truncate_mode='level',
                          show_leaf_counts=False,no_labels=True,above_threshold_color='k')
    hierarchy.dendrogram(X,ax=dend_ax2,orientation='top',color_threshold=2,#p=5,truncate_mode='level',
                          show_leaf_counts=False,no_labels=True,above_threshold_color='k')
    dend_ax1.invert_yaxis()
    
    # Plot the heatmap for all cluster data
    clust_plot = sns.heatmap(data=clust_vals_sort,ax=plt_ax,cmap='vlag',vmin=-3,vmax=3,yticklabels=[],cbar=False)
    clust_im = clust_plot.get_children()[0]
    sns.heatmap(data=pd.DataFrame(data=clusters),ax=clust_ax,cmap='tab10',vmin=0,vmax=9,cbar=False,
                xticklabels=[],yticklabels=[])
    clust_im.set(cmap='vlag',clim=[-3,3])
    clust_fig.colorbar(clust_im, cax=cbar_ax, orientation='horizontal')
    plt.tight_layout()
    
    # Re-draw the clustered heatmap with the columns in a fixed order that matches
    # the conjugate dataset instead of the column-dendrogram order.
    # NOTE(review): the '_avg' suffixes are hard-coded; this presumably only works
    # when sum_or_avg == 'avg' -- confirm against how the columns are named.
    clust_vals_sort = clust_vals_sort[['GluN1_avg','GluN2B_avg','PSD95_avg','GluA3_avg',
                                       'num_pixels','GluA2_avg','GluA1_avg','GluA4_avg']]

    # Same 5x8 layout as the first cluster figure. The GridSpec must be bound to
    # clust_fig_sorted (the figure the axes are added to), not to clust_fig.
    clust_fig_sorted = plt.figure(figsize=[11,9])
    gs = GridSpec(5,8,figure=clust_fig_sorted)
    # Top slot is created to keep the layout identical, but nothing is drawn in it
    # below: with a fixed column order no column dendrogram is plotted here.
    dend_ax2 = clust_fig_sorted.add_subplot(gs[0,3:8])
    plt_ax   = clust_fig_sorted.add_subplot(gs[1:4,3:8])
    clust_ax = clust_fig_sorted.add_subplot(gs[1:4,2])
    dend_ax1 = clust_fig_sorted.add_subplot(gs[1:4,0:2])
    cbar_ax  = clust_fig_sorted.add_subplot(gs[4,3:8])

    # Plot only the row dendrogram (rows keep their clustered order)
    hierarchy.set_link_color_palette(['silver','k'])
    hierarchy.dendrogram(Z,ax=dend_ax1,orientation='left',color_threshold=12,
                         show_leaf_counts=False,no_labels=True,above_threshold_color='k')
    # Flip so the dendrogram leaves run in the same direction as the heatmap rows
    dend_ax1.invert_yaxis()

    # Heatmap of the re-ordered columns plus the per-row cluster-label strip
    clust_plot = sns.heatmap(data=clust_vals_sort,ax=plt_ax,cmap='vlag',vmin=-3,vmax=3,yticklabels=[],cbar=False)
    clust_im = clust_plot.get_children()[0]
    sns.heatmap(data=pd.DataFrame(data=clusters),ax=clust_ax,cmap='tab10',vmin=0,vmax=9,cbar=False,
                xticklabels=[],yticklabels=[])
    # Use the heatmap mesh as the mappable for the stand-alone colorbar axis
    clust_im.set(cmap='vlag',clim=[-3,3])
    clust_fig_sorted.colorbar(clust_im, cax=cbar_ax, orientation='horizontal')
    clust_fig_sorted.tight_layout()
    
    # Number of discrete grey levels for the per-channel UMAP panels, and the
    # matching RGBA lookup table sampled from the 'binary' colormap
    col_range = 256
    umap_cmap = cm.binary(range(col_range))

    # Fit a 2-D UMAP embedding of umap_vals (fixed random_state so reruns match)
    reducer = umap.UMAP(
        n_components=2,
        n_neighbors=200,
        min_dist=0.25,
        metric='euclidean',
        random_state=123,
    )
    embedding = reducer.fit_transform(umap_vals)
    
    # Generate the UMAP figure (3x3 panels); the first panel colors every embedded
    # point by its cluster number
    plt_fig = plt.figure(figsize=[14,12])
    umap_col_inds = []  # not used in this section; kept in case other cells read it
    plt_ax_list = []
    plt_ax_list.append(plt_fig.add_subplot(3,3,1))

    # Look up all ten tab10 colors so that every label the cluster-strip heatmaps
    # can show (vmin=0, vmax=9) also has a color here; taking only the first nine
    # would raise an IndexError for label 9.
    plt_cols = cm.tab10(np.arange(10))

    # Map each point's cluster label (in original row order) to its RGBA color in a
    # single vectorized lookup; labels are truncated to int exactly like int(label)
    embed_cols_sortedclusts = np.zeros([len(umap_vals),4])
    cluster_label_inds = np.asarray(clusters_unsorted).astype(int)
    embed_cols_sortedclusts[:len(cluster_label_inds),:] = plt_cols[cluster_label_inds,:]

    plt_ax_list[0].scatter(embedding[:,0],embedding[:,1],marker='o',color=embed_cols_sortedclusts,s=5,edgecolor='none')
    plt_ax_list[0].set(title='Cluster Number')
    plt_ax_list[0].set_aspect('equal','box')
    # One panel per entry of umap_list: shade every embedded point by that column's value
    for umap_col_ct, umap_col_name in enumerate(umap_list):
        # Pull this column's values and get the outlier-trimmed subset
        col_vals = np.array(umap_vals[umap_col_name].copy())
        col_vals_trim, trim_bool = trim_quartile(col_vals, q1_trim, q3_trim, quartile_mult)

        # The trimmed extremes define the color limits.
        # NOTE(review): trim_quartile's return layout is not visible here; the [0]
        # assumes the axis-0 reduction still yields an indexable array -- confirm.
        trim_floor = np.min(col_vals_trim, axis=0)[0]
        trim_ceil = np.max(col_vals_trim, axis=0)[0]

        # Rescale all (untrimmed) values onto lookup-table indices, truncate to int,
        # then clamp so values outside the trimmed range saturate at either end
        scaled_vals = ((col_vals - trim_floor) / (trim_ceil - trim_floor)) * col_range
        embed_col_inds = np.clip(np.array(scaled_vals, dtype=int), 0, col_range - 1)

        # Panel 1 holds the cluster plot, so channel panels start at position 2
        umap_ax = plt_fig.add_subplot(3,3,umap_col_ct + 2)
        plt_ax_list.append(umap_ax)

        # Points get explicit RGBA colors from the lookup table; cmap/clim are then
        # set on the collection so the colorbar shows the same range
        scat_im = umap_ax.scatter(embedding[:,0], embedding[:,1], marker='o',
                                  color=umap_cmap[embed_col_inds,:], s=5, edgecolor='none')
        scat_im.set(cmap=cm.binary, clim=[trim_floor, trim_ceil])

        # Carve a thin strip off the bottom of the panel to hold its colorbar
        divider = make_axes_locatable(umap_ax)
        cax = divider.append_axes('bottom', size='5%', pad=0.35)
        umap_ax.set(title=umap_col_name)
        umap_ax.set_aspect('equal','box')
        plt_fig.colorbar(scat_im, cax=cax, orientation='horizontal')
    
    
    if save_plots:
        # Build the output prefix so that an empty fdir saves relative to the current
        # working directory. Plain fdir + '/name.pdf' turns into '/name.pdf' (the
        # filesystem root) when fdir == ''. Non-empty fdir gives the same paths as before.
        save_prefix = fdir + '/' if fdir else ''

        # Save each figure as a PDF, then close it so it is not left open for display
        figs_to_save = (
            (sil_fig,          'if_silhouette.pdf'),
            (clust_fig,        'if_synclust.pdf'),
            (clust_fig_sorted, 'if_synclust_sort.pdf'),
            (plt_fig,          'if_avgintens_umap.pdf'),
        )
        for fig_handle, pdf_name in figs_to_save:
            fig_handle.savefig(save_prefix + pdf_name)
            plt.close(fig_handle)