In [None]:
import pandas as pd
import seaborn as sns
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import os
import re
import umap
from scipy import stats
pd.set_option('display.max_rows', 50)
np.set_printoptions(threshold=50)

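## If you see "ModuleNotFoundError" (e.g., ModuleNotFoundError: No module named 'umap'),
## then pip install the missing package. Note that the `umap` module is provided by the
## `umap-learn` package (pip install umap-learn); `pip install umap` installs an unrelated package.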

In [None]:
# Input CSVs are read from a 'Data' folder inside the current working directory.
basedir = os.getcwd()
datadir = os.path.join(basedir, 'Data')

# Prepare Data

In [None]:
# Cohort sizes (the control count is exactly twice the Alzheimer count).
total_alz, total_con = (
    8804,   # total Alzheimer patients
    17608,  # total control patients
)

In [None]:
# Read the demographics and diagnosis tables for both cohorts from `datadir`.
# Order of the file names matches the order of the target variables:
#   alzcohort / alzdiag  -> Alzheimer's Disease demographics / diagnoses
#   concohort / condiag  -> Control demographics / diagnoses
alzcohort, alzdiag, concohort, condiag = (
    pd.read_csv(os.path.join(datadir, csv_name))
    for csv_name in (
        'ad_demographics.csv',
        'ad_diagnosis.csv',
        'control_demographics.csv',
        'control_diagnosis.csv',
    )
)

### Make pivot tables

In [None]:
# Patient x diagnosis indicator table for the Alzheimer's cohort:
# each row is a patient, each column is a diagnosis, 1 if the patient has
# the diagnosis and 0 otherwise (missing combinations are filled with 0).
# Takes a few minutes.
n = 'DiagnosisName'

# Keep one record per (diagnosis, patient) pair before pivoting.
alz_pairs = alzdiag[[n, 'PatientID']].drop_duplicates()

alzdiag_pivot = pd.pivot_table(
    alz_pairs,
    values=[n],
    index='PatientID',
    columns=[n],
    aggfunc=lambda x: 1 if len(x) > 0 else 0,
    fill_value=0,
)

# Cohort label: every patient in this table is an Alzheimer's patient.
alzdiag_pivot['isAD'] = 1
alzdiag_pivot

In [None]:
# Patient x diagnosis indicator table for the control cohort, built the same
# way as the Alzheimer's table: 1 if the patient has the diagnosis, else 0.
n = 'DiagnosisName'

# Keep one record per (diagnosis, patient) pair before pivoting.
con_pairs = condiag[[n, 'PatientID']].drop_duplicates()

condiag_pivot = pd.pivot_table(
    con_pairs,
    values=[n],
    index='PatientID',
    columns=[n],
    aggfunc=lambda x: 1 if len(x) > 0 else 0,
    fill_value=0,
)

# Cohort label: every patient in this table is a control.
condiag_pivot['isAD'] = 0
condiag_pivot

In [None]:
# Stack both cohorts row-wise into one table. Diagnosis columns that exist in
# only one cohort come out as NaN for the other cohort's rows at this point.
alldiag_pivot = pd.concat((alzdiag_pivot, condiag_pivot), axis='index')

### Drop columns

In [None]:
# Select every column whose name contains "alzheimer" (case-insensitive).
# Presumably these are removed so the cohort-defining diagnosis is not used
# as a feature -- confirm with the analysis plan.
alz_name_mask = alldiag_pivot.columns.str.contains('alzheimer', flags=re.IGNORECASE)
colstodrop = alldiag_pivot.columns[alz_name_mask]
colstodrop

In [None]:
# Remove the Alzheimer-named diagnosis columns collected in `colstodrop`.
alldiag_pivot = alldiag_pivot.drop(columns=colstodrop)

### Add demographics

In [None]:
# Attach per-patient demographics (from both cohorts) to the diagnosis table,
# matching on PatientID, which is the index of `alldiag_pivot`.
demographic_cols = ['PatientID', 'Age', 'Sex', 'Race', 'Death_Status']

demographics = pd.concat(
    [alzcohort[demographic_cols], concohort[demographic_cols]]
).set_index('PatientID')

# Left join keeps every row of the diagnosis table; the trailing fillna(0)
# turns the NaNs left by the cohort concatenation into 0 ("no diagnosis").
# NOTE(review): fillna(0) is applied to the whole frame, so any missing
# demographic value (Age, Sex, Race, Death_Status) also becomes 0 -- confirm
# that this is intended.
alldiag_pivot = (
    alldiag_pivot
    .merge(demographics, how='left', left_index=True, right_index=True)
    .fillna(0)
)

# Dimensionality Reduction

In [None]:
# Split the combined table into:
#   y -- cohort label per patient ('Alzheimer' / 'Control')
#   z -- demographic columns, used later only to colour and stratify the plots
#   X -- diagnosis indicator columns, the features passed to UMAP
y = alldiag_pivot['isAD'].replace({1: 'Alzheimer', 0: 'Control'})

# Filter instead of list.remove(): remove() mutates the shared list in place
# and raises ValueError when this cell is run a second time.
demographic_cols = [c for c in demographic_cols if c != 'PatientID']

z = alldiag_pivot[demographic_cols]
X = alldiag_pivot.drop('isAD', axis=1).drop(demographic_cols, axis=1).astype('int32')

In [None]:
%%time
mapper = umap.UMAP(metric='cosine', random_state=42, low_memory=True, verbose = 1).fit(X)

In [None]:
import pickle

filename = 'AD_umap_model.pkl'

# Save the fitted model to disk. The context manager guarantees the file is
# flushed and closed even if pickling fails (the bare open() call previously
# used here left the handle open).
with open(filename, 'wb') as model_file:
    pickle.dump(mapper, model_file)

# To reuse a previously fitted model instead of refitting, load it back.
# WARNING: unpickling executes arbitrary code -- only load files you created.
# with open(filename, 'rb') as model_file:
#     mapper = pickle.load(model_file)

In [None]:
# Toggle for the plotting cells below: when True, each figure is also written to a PDF.
savefigs = False

# Project the diagnosis matrix into the fitted UMAP space; the plots below
# use column 0 and column 1 of the result as the two embedding axes.
X_embedded = mapper.transform(X)

In [None]:
# Scatter plot of the UMAP embedding coloured by cohort.
with sns.color_palette("Set1"):
    fig = plt.figure(figsize=(10, 8))

    # Draw points in a random order so neither cohort is systematically
    # painted on top of the other. A seeded generator keeps the figure
    # reproducible across runs. `reordered_indices` is reused by a later cell.
    reordered_indices = np.random.default_rng(42).permutation(X_embedded.shape[0])

    # Index the label *values* positionally: `y` is indexed by PatientID, so
    # y[reordered_indices] would be a label lookup (or a deprecated positional
    # fallback) rather than a positional reorder.
    sns.scatterplot(x=X_embedded[reordered_indices, 0],
                    y=X_embedded[reordered_indices, 1],
                    hue=y.values[reordered_indices],
                    s=5, linewidth=.0, alpha=.6,
                    hue_order=['Alzheimer', 'Control'])
    ax = plt.gca()
    ax.set(xticks=[], yticks=[], facecolor='white')
    plt.title('Diagnosis as Features - UMAP')

    if savefigs:
        # savefig's keyword for the output type is `format`; `filetype` is not
        # a valid argument.
        plt.savefig('full_AD_Control_UMAP.pdf', format='pdf', dpi=300, bbox_inches='tight')

In [None]:
from scipy.stats import mannwhitneyu

# Split the embedding by cohort and compare the two cohorts along each
# embedding axis with a Mann-Whitney U test.
ADvals = X_embedded[y.values == 'Alzheimer', :]
convals = X_embedded[y.values == 'Control', :]
for dim, label in enumerate(('Axis 1: ', 'Axis 2: ')):
    print(label, mannwhitneyu(ADvals[:, dim], convals[:, dim]))

In [None]:
# Violin plots of each embedding axis, split by cohort (one figure per axis).
# NOTE(review): the axes are UMAP dimensions; the 'PC1'/'PC2' labels are kept
# as in the original figures.
with sns.axes_style("darkgrid"):
    with sns.color_palette("Set1"):
        for dim, (axis_label, out_name) in enumerate([('PC1', 'AD-ConUMAPPC1.pdf'),
                                                      ('PC2', 'AD-ConUMAPPC2.pdf')]):
            plt.figure(figsize=(5, 3))
            # `y` already holds the 'Alzheimer'/'Control' strings, so it is
            # used directly (the former True/False replace matched nothing).
            sns.violinplot(x=X_embedded[:, dim], y=y.values, bw=.1)
            plt.xlabel(axis_label)
            if savefigs:
                # `format` is the savefig keyword; `filetype` is not valid.
                plt.savefig(out_name, format='pdf', dpi=300, bbox_inches='tight')


### Sex

In [None]:
col = 'Sex'

# UMAP embedding coloured by sex.
fig = plt.figure(figsize=(8, 6))
sns.scatterplot(x=X_embedded[:, 0], y=X_embedded[:, 1], hue=z[col].values, s=5, linewidth=0,
                alpha=.6, palette='hls')
ax = plt.gca()
ax.set(xticks=[], yticks=[], facecolor='white')
plt.text(9, 14.3, 'Sex', fontsize=12, fontweight='bold')
if savefigs:
    # `format` is the savefig keyword; `filetype` is not a valid argument.
    plt.savefig('AD-ConUMAP-sex.pdf', format='pdf', dpi=300, bbox_inches='tight')
plt.show()

with sns.color_palette("hls", 3):
    # Embedding rows per sex. Not used further in this cell; kept because a
    # later cell in the original notebook reads val1/val2.
    val1 = X_embedded[z[col].values == 'Female', :]
    val2 = X_embedded[z[col].values == 'Male', :]

    # One violin figure per embedding axis. Each plot gets its own figure and
    # an unconditional plt.show(): previously show() only ran when saving, so
    # with savefigs=False both violins were drawn onto the same axes.
    for dim, (axis_label, out_name) in enumerate([('PC1', 'AD-ConUMAP_PC1_SEX.pdf'),
                                                  ('PC2', 'AD-ConUMAP_PC2_SEX.pdf')]):
        plt.figure()
        sns.violinplot(x=X_embedded[:, dim], y=z[col], bw=.1, order=['Male', 'Female'])
        plt.xlabel(axis_label)
        if savefigs:
            plt.savefig(out_name, format='pdf', dpi=300, bbox_inches='tight')
        plt.show()


### Age

In [None]:
col = 'Age'

# UMAP embedding coloured by age (colour scale normalised to 65-95).
fig = plt.figure(figsize=(8, 6))
sns.scatterplot(x=X_embedded[:, 0], y=X_embedded[:, 1], hue=z[col].values, s=5, linewidth=0,
                palette='mako', hue_norm=(65, 95), alpha=.6)
ax = plt.gca()
ax.set(xticks=[], yticks=[], facecolor='white')
plt.text(9, 14.3, 'Age', fontsize=12, fontweight='bold')
if savefigs:
    # `format` is the savefig keyword; `filetype` is not a valid argument.
    plt.savefig('AD-ConUMAP-age.pdf', format='pdf', dpi=300, bbox_inches='tight')
plt.show()

# For each embedding axis: hexbin joint plot of age vs. the axis with a
# regression line overlaid, followed by the Pearson correlation.
with sns.color_palette("Dark2"):
    for dim, (axis_label, out_name) in enumerate([('PC1', 'AD-ConUMAP_PC1_age.pdf'),
                                                  ('PC2', 'AD-ConUMAP_PC2_age.pdf')]):
        g = sns.jointplot(y=X_embedded[:, dim], x=z[col], kind='hex', height=5,
                          ratio=10, joint_kws={'gridsize': 20}, color='#049372', alpha=.7)
        g.plot_joint(sns.regplot, x_jitter=.2, y_jitter=.4,
                     ci=95, scatter_kws={'s': 1, 'alpha': .3, 'zorder': 0},
                     line_kws={'color': '#1f7a1f'})
        # Label and save through the grid's own axes/figure rather than
        # pyplot's "current" axes, which is not guaranteed to be the joint
        # axes after plot_joint.
        g.ax_joint.set_ylabel(axis_label)
        g.fig.set_figwidth(6)
        if savefigs:
            g.fig.savefig(out_name, format='pdf', dpi=300, bbox_inches='tight')
        plt.show()
        print("pearson: r/pval {}".format(stats.pearsonr(z[col], X_embedded[:, dim])))

### Death Status

In [None]:
col = 'Death_Status'

# UMAP embedding coloured by death status. Points are drawn in the shuffled
# order (`reordered_indices`) created by the cohort scatter-plot cell.
fig = plt.figure(figsize=(8, 6))
sns.scatterplot(x=X_embedded[reordered_indices, 0], y=X_embedded[reordered_indices, 1],
                hue=z[col].values[reordered_indices], s=5, linewidth=0,
                alpha=.6, palette='Set2')
ax = plt.gca()
ax.set(xticks=[], yticks=[], facecolor='white')
plt.text(7.2, 14.3, 'Death Status', fontsize=12, fontweight='bold')
if savefigs:
    # `format` is the savefig keyword; `filetype` is not a valid argument.
    plt.savefig('AD-ConUMAP-status.pdf', format='pdf', dpi=300, bbox_inches='tight')
plt.show()

# Split the embedding by death status for the Mann-Whitney U tests below.
# The groups are derived from this column here; the test previously reused
# val1/val2 left over from the Sex cell, i.e. it compared Female vs Male.
status_values = z[col].values
status_groups = [X_embedded[status_values == level, :] for level in pd.unique(status_values)]

with sns.color_palette("Set2"):
    # One violin figure per embedding axis. Each plot gets its own figure and
    # an unconditional plt.show(): previously show() only ran when saving, so
    # with savefigs=False both violins were drawn onto the same axes.
    for dim, (axis_label, out_name) in enumerate([('PC1', 'AD-ConUMAP_PC1_death.pdf'),
                                                  ('PC2', 'AD-ConUMAP_PC2_death.pdf')]):
        plt.figure()
        sns.violinplot(x=X_embedded[:, dim], y=z[col], bw=.1)
        plt.xlabel(axis_label)
        # The test result is printed for both axes (the PC1 result used to be
        # written onto the figure at a fixed data coordinate).
        if len(status_groups) == 2:
            print(axis_label, stats.mannwhitneyu(status_groups[0][:, dim],
                                                 status_groups[1][:, dim]))
        else:
            # Mann-Whitney U compares exactly two samples.
            print('{}: Mann-Whitney U skipped; expected 2 {} groups, found {}'.format(
                axis_label, col, len(status_groups)))
        if savefigs:
            plt.savefig(out_name, format='pdf', dpi=300, bbox_inches='tight')
        plt.show()
