In [None]:
import pandas as pd
import seaborn as sns
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import os
from scipy.stats import norm
from scipy.stats import chi2_contingency
import scipy.stats as stats
from math import log10, log2
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore", category=FutureWarning) 

In [None]:
# Keep notebook output short: pandas renders at most 50 rows of a frame and
# NumPy summarises arrays with more than 50 elements instead of printing them.
pd.set_option('display.max_rows', 50)
np.set_printoptions(threshold=50)

In [None]:
# Input CSVs are read from a "Data" folder under the current working directory.
basedir = os.getcwd()
datadir = os.path.join(basedir,"Data")#basedir + r'\Data'
# Column names of the diagnosis groupings that every analysis below loops over.
diagkeys = ['DiagnosisName','Level3_Category','Level2_Category']

# Execute setupfunc.py inside this notebook's namespace (-i).
# NOTE(review): presumably this is where helpers used below but not defined in
# this notebook (ICD10_code_to_chapter, countPtsDiagnosis_Dict, sigTestCounts,
# ...) come from -- confirm.
%run -i setupfunc.py

# Import data

In [None]:
# Cohort sizes, used below as the denominators of the contingency tables.
# NOTE(review): these are hard-coded rather than derived from the demographics
# files loaded in the next cell -- confirm they still match the data.
total_alz = 8804 #Total alzheimer patients
total_con = 17608 #Total control patients

In [None]:
def _loadDiagnosisTable(filename):
    '''Read a diagnosis CSV from `datadir` and add the `ValueL` column.

    `ValueL` is whatever ICD10_code_to_chapter returns for the first three
    characters of the (stringified) code in the `Value` column.
    '''
    table = pd.read_csv(os.path.join(datadir, filename))
    chapters = table['Value'].apply(lambda code: ICD10_code_to_chapter(str(code)[0:3]))
    table['ValueL'] = chapters
    return table

alzdiag = _loadDiagnosisTable('ad_diagnosis.csv')
condiag = _loadDiagnosisTable('control_diagnosis.csv')

# demographics
alzdemo = pd.read_csv(os.path.join(datadir, 'ad_demographics.csv'))
condemo = pd.read_csv(os.path.join(datadir, 'control_demographics.csv'))

### count patients with diagnosis

In [None]:
# Per-cohort diagnosis counts. The result is indexed by the entries of
# `diagkeys` in the following cells; the second argument is the cohort size.
# NOTE(review): countPtsDiagnosis_Dict is defined outside this notebook --
# its exact output columns should be confirmed there.
alzdiagcount = countPtsDiagnosis_Dict(alzdiag, total_alz)
condiagcount = countPtsDiagnosis_Dict(condiag, total_con)
# Venn diagram built from the two count dictionaries (helper defined elsewhere).
vennDiagramFromDiagnosisDict(alzdiagcount, condiagcount, set_labels = ('Alzheimer','Control'),
                            title = 'Overlap between Alzheimer/Control Diagnosis')

#### Chi-square / Fisher exact tests

In [None]:
# One contingency-table frame per diagnosis grouping, keyed like `diagkeys`.
# An outer join keeps diagnoses that occur in only one cohort; those rows come
# back as NaN on the other side and are filled as "0 patients with the
# diagnosis, whole cohort without it".
jointype = 'outer'
alldiagcount = dict()

for n in diagkeys:
    merged = alzdiagcount[n].merge(condiagcount[n], how = jointype, on = n,
                                   suffixes = ('_alz','_con'))
    merged = merged.set_index(n)
    if jointype == 'outer': # replace nan's
        # Fill values are matched to the four columns by position.
        nanreplace = dict(zip(list(merged.columns), [0, total_alz, 0, total_con]))
        merged = merged.fillna(value = nanreplace)
    alldiagcount[n] = merged

In [None]:
# compute p-values.
# sigTestCountsDict is defined outside this notebook. The cells below read the
# columns 'pvalue', 'log2_oddsratio', '-log_pvalue' and 'ValueL' from each
# frame of the result, which is keyed like `diagkeys`.
chitest_count = sigTestCountsDict(alldiagcount, diagkeys, verbose = 1, diag = True)

In [None]:
# Categorize diagnosis based upon significance
for n in diagkeys:
    bc = .05/chitest_count[n].shape[0] # Bonferroni-corrected significance threshold
    print('bc:',bc)

    significant = (chitest_count[n]['pvalue'] < bc).to_numpy()
    log2or = chitest_count[n]['log2_oddsratio'].to_numpy()

    # np.select takes the first matching condition, so the two "enriched" labels
    # (significant AND |log2 OR| > 1) win over plain 'Significant'.
    # The previous np.full(..., fill_value='Not Significant') array had a fixed
    # 15-character string dtype and silently truncated the longer labels
    # ('Alzheimer Enric', 'Control Enriche'); np.select sizes the dtype to fit.
    sig = np.select(
        [significant & (log2or > 1), significant & (log2or < -1), significant],
        ['Alzheimer Enriched', 'Control Enriched', 'Significant'],
        default = 'Not Significant')

    chitest_count[n]['sig'] = sig
    display(chitest_count[n]['sig'].value_counts())

    print('\n')

In [None]:
# save
import pickle
# Use a context manager so the file is flushed and closed even if dump() fails.
with open('alzcon_diagnosis_stats.pickle', 'wb') as picklefile:
    pickle.dump(chitest_count, picklefile)

In [None]:
# Display names used as x tick labels in the category plots below (21 entries;
# "Special Codes" is deliberately left out, see the trailing comment).
# NOTE(review): the plots assign these labels by position, so the order must
# match the order in which the plotting code lays out the categories -- confirm.
codemap = np.array(["Infectious",  "Neoplasms", "Blood-Related Disorders",  
                   "Endocrine, Nutritional, and Metabolic Disorders", 
                   "Mental and Behavioral Disorders","Disease of Nervous System", 
                    "Disease of Eye and Adnexa","Disease of Ear and Mastoid", 
                     "Disease of Circulatory System", "Disease of Respiratory System", 
                    "Disease of Digestive System", "Disease of Skin and Subcutaneous Tissue",
                    "Musculoskeletal System Diseases", "Genitourinary System Diseases",
                    "Pregnancy and Childbirth", "Perinatal Diseases", "Congenital Diseases",
                    "Abnormal Clinical and Lab Findings", "Poisoning, Injury, and External Issues", 
                    "External Causes", "Factors Regarding Health Status and Services"]) #, "Special Codes"
# Fixed palette (32 hex colours) passed to the bar plots below.
rand_colors = ('#a7414a', '#282726', '#6a8a82', '#a37c27', '#563838', '#0584f2', '#f28a30', '#f05837',
                               '#6465a5', '#00743f', '#be9063', '#de8cf0', '#888c46', '#c0334d', '#270101', '#8d2f23',
                               '#ee6c81', '#65734b', '#14325c', '#704307', '#b5b3be', '#f67280', '#ffd082', '#ffd800',
                               '#ad62aa', '#21bf73', '#a0855b', '#5edfff', '#08ffc8', '#ca3e47', '#c9753d', '#6c5ce7')

In [None]:
# Plot volcano plots
# One figure per diagnosis grouping: log2 odds ratio on x, -log10 p-value on y,
# coloured by the 'sig' label assigned in the significance cell above.
dims = (6,3.5)
save = 1;
figtype = 'pdf';

for n in diagkeys:
    bc = .05/chitest_count[n].shape[0]
    # Side effect: adds an 'index' column to the shared chitest_count frame.
    chitest_count[n]['index'] = chitest_count[n].index
    df = chitest_count[n].copy()

    fig = plt.figure(figsize = dims);
    g = sns.scatterplot(data = df, x = 'log2_oddsratio', y = '-log_pvalue', 
                       hue = 'sig',
                       #palette = ["#EC719E","#51A0D5","#949494","#2CB97C"], 
                       edgecolor = None,
                       s = 12);
        
    # NOTE(review): this assumes the first legend text is the legend title; in
    # seaborn versions that set the title separately it would overwrite the
    # first category label instead -- confirm against the installed version.
    g.get_legend().texts[0].set_text('Significance');
    # Dashed horizontal line: Bonferroni threshold; dotted verticals: |log2 OR| = 1.
    plt.xlim([-4, 4]); plt.axhline(-np.log10(bc), linestyle = '--', color = '#555555', linewidth = 1);
    plt.axvline(1, color = '#555555', linestyle = ':'); plt.axvline(-1, color = '#555555',linestyle = ':'); 
    plt.xlabel(r'$\log_2$(Odds Ratio)'); plt.ylabel(r'$-\log_{10}$(p-value)');
    plt.title('Alzheimer vs. Control Diagnosis Volcano Plot: '+n);
    if save: 
        plt.savefig("Alz-Con_volcano_"+n+"."+figtype, format=figtype, bbox_inches='tight', dpi=300)

In [None]:
# plot manhattan plot
# One Manhattan-style plot per diagnosis grouping, with ICD-10 category ('ValueL')
# playing the role of the chromosome axis.
dims = (8,4)
save = False;
figtype = 'pdf';
chrexplode = True; # explode diagnosis in multiple diagnostic blocks
valuenames = True;

# plotting.py is run inside this namespace; `marker` (and `miami`, used further
# down) are presumably defined there -- they are not defined in this notebook.
%run -i plotting.py
for n in diagkeys:
    bc = .05/chitest_count[n].shape[0]
    # Side effect: adds an 'index' column to the shared chitest_count frame.
    chitest_count[n]['index'] = chitest_count[n].index
    
    if chrexplode:
        # A diagnosis whose 'ValueL' holds several categories becomes one row per category.
        chitest = chitest_count[n].explode('ValueL');
    else:
        chitest = chitest_count[n]
        
    fig, ax = marker.mhat(df=chitest, chr='ValueL',pv='pvalue', dim = dims, yskip = 50,
                         gwas_sign_line=True, plotlabelrotation = 45, xtickname = True,
                         gfont = 2, dotsize = 6, figtype = figtype, figname = "Alz-Con_man_"+n,
                         axxlabel = 'ICD10 Category', show = not save, ar = 45,
                         figtitle = 'Alzheimer vs. Control Chi Square Diagnosis Manhattan Plot: '+n)
    # NOTE(review): labels are assigned by position, so this requires exactly
    # len(codemap) ticks in codemap order -- confirm for every grouping.
    # NOTE(review): `bc` is computed but not passed to marker.mhat here, so the
    # significance line uses that function's own threshold -- confirm intent.
    plt.xticks(ax.get_xticks(), codemap)
    if save: 
        plt.savefig("Alz-Con_enccon_man_s_"+n+"."+figtype, format=figtype, bbox_inches='tight', dpi=300)

In [None]:
# plot percent significant in a diagnostic block
# For each diagnosis grouping: the share of diagnoses per ICD-10 category whose
# p-value is below the Bonferroni threshold, drawn as one bar per category.
chrexplode = True;
savefig = True;

for n in diagkeys:
    bc = .05/chitest_count[n].shape[0]
    # Category display order for the bars.
    # NOTE(review): astype(str) runs before explode(); if 'ValueL' holds lists,
    # they are stringified first and explode() then has nothing to split, so
    # these names may not match the ones produced after explode() below --
    # confirm the format of 'ValueL' returned by sigTestCountsDict.
    icd10order = list(map(ICDchapter_to_name,np.unique(chitest_count[n]['ValueL'].astype(str).explode().sort_values().values)))
    
    if chrexplode:
        chitest = chitest_count[n].explode('ValueL').copy();
    else:
        chitest = chitest_count[n].copy()
        
    chitest['ValueL'] = chitest['ValueL'].apply(ICDchapter_to_name)
    
    # Despite its name this is a list of [category, percent significant] pairs.
    icd10dict = list()
    for value, g in chitest.groupby('ValueL'):
        icd10dict.append([value, ((g['pvalue'] < bc).sum())*100/g.shape[0] ])
    # Column 0 (category) becomes the index, column 1 holds the percentage.
    icd10sig = pd.DataFrame(icd10dict).set_index(0)
    icd10sig = icd10sig.reindex(icd10order)
    
    with sns.axes_style("darkgrid"):
        plt.figure(figsize = (10,3))
        sns.barplot(x = icd10sig.index, y = icd10sig[1].values, palette = rand_colors)
        plt.xticks(rotation=50, va = 'top', ha = 'right')
        plt.ylabel('%')
        plt.title('% of Significant Diagnosis Per Category: '+n)
        if savefig:
            # `figtype` is not set in this cell; it is the value left by an earlier cell.
            plt.savefig("Alz-Con_bar_"+n+"."+figtype, format=figtype, bbox_inches='tight', dpi=300)
        plt.show()

# Diagnosis Differences Between Sex

In [None]:
sexes = ['Female', 'Male']
equalize_num = True  # Equalize number between males and females
randomstate = 40

## Stratify diagnosis dataframes by sex
# Alzheimer cohort: diagnosis rows per sex and number of distinct patients per sex
alzdiagsex = {sex: alzdiag[alzdiag['Sex'] == sex] for sex in sexes}
numsexalz = {sex: len(alzdiagsex[sex][['PatientID', 'Sex']].drop_duplicates())
             for sex in sexes}
print(numsexalz)

# Control cohort, same layout
condiagsex = {sex: condiag[condiag['Sex'] == sex] for sex in sexes}
numsexcon = {sex: len(condiagsex[sex][['PatientID', 'Sex']].drop_duplicates())
             for sex in sexes}
print(numsexcon)

# Size of the smaller Alzheimer sex group (the lower number is the number to sample)
kNumPatientsSampled = min(numsexalz.values())
print(kNumPatientsSampled)

In [None]:
# Count diagnosis
def getDiagnosisCountsStratify(stratvarname, stratvars, diagdf=None, diagdict=None, numptsvar = None, 
                               equalize_num = False, random_state = 40):
    '''
        Count diagnoses separately for each level of a stratifying variable.

        input:  stratvarname - string. name of the stratifying column (e.g. `Sex`)
                stratvars - list of levels to stratify on (e.g. ['Male','Female'])
                diagdf or diagdict - pass either the full diagnosis dataframe
                       (`diagdf`), which is split on `stratvarname` here, or an
                       already stratified dictionary {level: dataframe} (`diagdict`).
                       `diagdict` takes precedence when both are given.
                numptsvar - dictionary {level: number of patients}. Required with
                       `diagdict`; computed from `diagdf` otherwise.
                equalize_num - if True, every level with more patients than the
                       smallest one is randomly subsampled (by PatientID) down
                       to the smallest level's size before counting.
                random_state - seed for the subsampling. default: 40
        outputs: dictionary indexed by level holding whatever
                 countPtsDiagnosis_Dict returns for that level's rows.
        raises:  Exception if neither `diagdf` nor `diagdict` is given, or if
                 `diagdict` is given without `numptsvar`.
    '''
    diagdictvar = dict()
    if diagdict: 
        diagdictvar = diagdict
        if numptsvar is None:
            raise Exception('variable numptsvar is empty, pass in dictionary with number of patients in each category.')
    elif diagdf is not None:
        # `elif diagdf:` would raise "truth value of a DataFrame is ambiguous",
        # which made this branch unusable; test for presence explicitly.
        numptsvar = dict()
        for var in stratvars:
            diagdictvar[var] = diagdf[diagdf[stratvarname] == var]
            numptsvar[var] = diagdictvar[var][['PatientID',stratvarname]] \
                    .drop_duplicates().shape[0]
    else:
        raise Exception('did not pass in full dataframe or dictionary of stratified dataframes.')
        
    kMinPatientsSampled = min(numptsvar.values())
    
    stratdiagcount = dict()
    for var in stratvars:
        if numptsvar[var]>kMinPatientsSampled and equalize_num:
            # Subsample whole patients (not rows) so every kept patient retains
            # all of their diagnoses.
            subsampledPatientKeys = diagdictvar[var]['PatientID'].drop_duplicates()\
                        .sample(kMinPatientsSampled, random_state = random_state)
            diagdictvar_s = diagdictvar[var][diagdictvar[var]['PatientID'].isin(subsampledPatientKeys)]
            stratdiagcount[var] = countPtsDiagnosis_Dict(diagdictvar_s, kMinPatientsSampled)
        else: 
            stratdiagcount[var] = countPtsDiagnosis_Dict(diagdictvar[var], numptsvar[var])
    return stratdiagcount


# Per-sex diagnosis counts, computed with identical options for both cohorts.
_stratify_opts = dict(equalize_num = equalize_num, random_state = randomstate)

# Alzheimer cohort
alzdiagcountsex = getDiagnosisCountsStratify('Sex', sexes, diagdict = alzdiagsex,
                                             numptsvar = numsexalz, **_stratify_opts)
# Control cohort
condiagcountsex = getDiagnosisCountsStratify('Sex', sexes, diagdict = condiagsex,
                                             numptsvar = numsexcon, **_stratify_opts)

In [None]:
# create contingency table for various comparisons
comparisons = ['AlzSex','ConSex', 'AlzConFem','AlzConMal']
jointype = 'outer' # outer join keeps diagnoses seen in only one of the two groups
num_threshold = 0

def _effectiveSexTotals(numsex):
    '''Group sizes that the per-sex counts were actually computed against.

    With `equalize_num`, getDiagnosisCountsStratify subsamples the larger sex
    down to the smaller one, so both sexes were counted against the minimum.
    '''
    if equalize_num:
        smallest = min(numsex.values())
        return {sex: smallest for sex in numsex}
    return dict(numsex)

alztotals = _effectiveSexTotals(numsexalz)
contotals = _effectiveSexTotals(numsexcon)

# comparison -> ((left counts, left suffix, left total), (right counts, right suffix, right total))
_comparison_specs = {
    'AlzSex':    ((alzdiagcountsex['Female'], '_alz_Female', alztotals['Female']),
                  (alzdiagcountsex['Male'],   '_alz_Male',   alztotals['Male'])),
    'ConSex':    ((condiagcountsex['Female'], '_con_Female', contotals['Female']),
                  (condiagcountsex['Male'],   '_con_Male',   contotals['Male'])),
    'AlzConFem': ((alzdiagcountsex['Female'], '_alz_Female', alztotals['Female']),
                  (condiagcountsex['Female'], '_con_Female', contotals['Female'])),
    'AlzConMal': ((alzdiagcountsex['Male'],   '_alz_Male',   alztotals['Male']),
                  (condiagcountsex['Male'],   '_con_Male',   contotals['Male'])),
}

alldiagcountsex = {comp: dict() for comp in comparisons}
for comp in comparisons:
    (leftcounts, leftsuffix, lefttotal), (rightcounts, rightsuffix, righttotal) = _comparison_specs[comp]
    for n in diagkeys:
        table = leftcounts[n].merge(rightcounts[n], how = jointype, on = n,
                                    suffixes = (leftsuffix, rightsuffix)).set_index(n)
        # keep diagnosis above num_threshold (NaN cells never count as below it)
        table = table[(table < num_threshold).sum(axis=1) < 1]
        if jointype == 'outer':
            # A diagnosis missing on one side means 0 patients have it and the
            # whole group does not. The fill must use the same group size as the
            # counted rows; the previous code used the unsampled size even when
            # equalize_num had subsampled the group, giving rows with
            # inconsistent totals.
            nanreplace = dict(zip(list(table.columns), [0, lefttotal, 0, righttotal]))
            table = table.fillna(value = nanreplace)
        alldiagcountsex[comp][n] = table

In [None]:
# Statistical testing
verbose = False

def sigTestDiagCountSex(alldiagcountsex, comp, n): 
    '''
    Test every diagnosis of one comparison / diagnosis grouping for association.

    Each row of `alldiagcountsex[comp][n]` (four count columns) is reshaped
    row-wise into a 2x2 table.
      * Rows whose smallest cell is below 5 get Fisher's exact test only
        (`OddsRatio`, `pvalue`).
      * All other rows get a chi-square test (`chistat`, `pvalue`, `dof`,
        `expected`) plus Fisher's exact test for the odds ratio; the Fisher
        p-value of these rows is kept separately as `fpvalue`.
    Infinite odds ratios are replaced by (largest finite ratio + 1) and p-values
    of exactly 0 by half the smallest positive p-value so the log transforms
    stay finite.

    input:  alldiagcountsex - dict[comparison][grouping] -> contingency dataframe
            comp - comparison key
            n - diagnosis grouping key; also the column used to look up `ValueL`
    output: dataframe indexed by diagnosis with counts, test results,
            `log2_oddsratio`, `-log_pvalue` and a list-valued `ValueL` column.

    Reads the module-level `verbose`, `alzdiag` and `condiag`.
    '''
    # First, for Fisher exact test choose rows with less than 5 patients in a category
    print(comp, n,': Number Diagnosis: ', alldiagcountsex[comp][n].shape[0])
    
    temp_less5 = alldiagcountsex[comp][n][alldiagcountsex[comp][n].min(axis=1)<5] # take all with counts less than 5
    fisher1 = pd.DataFrame()
    if temp_less5.shape[0]>0:
        print('\t Fisher Exact for <5 pts in a category, num diagnosis:', temp_less5.shape[0])
        fisher =  temp_less5 \
            .apply(lambda x: stats.fisher_exact(np.array(x).reshape(2,2)), axis = 1) \
            .apply(pd.Series)
        fisher.columns = ['OddsRatio', 'pvalue']
        if verbose: print('fisher:',fisher.shape)

        maxratio = fisher['OddsRatio'][fisher['OddsRatio']<np.inf].max();
        minratio = fisher['OddsRatio'][fisher['OddsRatio']>0].min();
        fisher = fisher.replace(np.inf, maxratio+1) 
        # an odds ratio of 0 is plotted at half the smallest positive ratio
        fisher['log2_oddsratio'] = fisher['OddsRatio']\
            .apply(lambda x: log2(minratio/2) if (x==0) else log2(x))

        minpvalue = fisher['pvalue'][fisher['pvalue']>0].min();
        fisher['pvalue']=fisher['pvalue'].replace(0,minpvalue/2)
        fisher['-log_pvalue']=fisher['pvalue'].apply(lambda x: -log10(x))

        fisher1 = fisher.merge(temp_less5, how = 'right', left_index=True, right_index = True)
        if verbose: print('fisher1',fisher1.shape)

    # now take the rest of the diagnoses
    temp_more5 = alldiagcountsex[comp][n][alldiagcountsex[comp][n].min(axis=1)>=5]
    print('\t Chi square for >=5 pts in a category, num diagnosis:', temp_more5.shape[0])

    # Fisher is still run on these rows, but only to obtain the odds ratio.
    fisher =  temp_more5 \
        .apply(lambda x: stats.fisher_exact(np.array(x).reshape(2,2)), axis = 1) \
        .apply(pd.Series)
    fisher.columns = ['OddsRatio', 'fpvalue']
    
    maxratio = fisher['OddsRatio'][fisher['OddsRatio']<np.inf].max();
    minratio = fisher['OddsRatio'][fisher['OddsRatio']>0].min();
    fisher = fisher.replace(np.inf, maxratio+1) 
    fisher['log2_oddsratio'] = fisher['OddsRatio']\
        .apply(lambda x: log2(minratio) if (x==0) else log2(x))
    minpvalue = fisher['fpvalue'][fisher['fpvalue']>0].min();
    fisher['fpvalue']=fisher['fpvalue'].replace(0,minpvalue/2)
    fisher['-log_fpvalue']=fisher['fpvalue'].apply(lambda x: -log10(x))
    if verbose: print('fisher',fisher.shape)

    chisquare = temp_more5.apply(lambda x: \
                           chi2_contingency(np.array(x).reshape(2,2)), axis=1) \
                            .apply(pd.Series)
    chisquare.columns = ['chistat','pvalue','dof','expected']
    chisquare = chisquare.merge(temp_more5, how = 'right',left_index=True, right_index = True)
    minpvalue = chisquare['pvalue'][chisquare['pvalue']>0].min();
    chisquare['pvalue']=chisquare['pvalue'].replace(0,minpvalue/2)
    chisquare['-log_pvalue']=chisquare['pvalue'].apply(lambda x: -log10(x))
    if verbose: print('chisquare:',chisquare.shape)

    combined = chisquare.merge(fisher, left_index=True, right_index = True, how = 'left')
    # DataFrame.append was removed in pandas 2.0; pd.concat is the equivalent.
    combined = pd.concat([combined, fisher1])
    if verbose: print('combined 1:', combined.shape)
    
    # get mapping between diagnosis and category (one list of categories per diagnosis)
    temp = pd.concat([alzdiag[[n,'ValueL']], condiag[[n, 'ValueL']]]).drop_duplicates()
    temp = temp[temp['ValueL'] != 'NaN'].groupby(n)['ValueL'].apply(list)
    combined = combined.merge(temp, how = 'left', left_index=True, right_index=True, suffixes=(False, False))
    if verbose: print('combined 2:', combined.shape)

    print('\t Final diagnosis num: ', combined.shape[0])
    
    return combined;

# Test results for every (comparison, diagnosis grouping) pair, laid out as
# sigtestcountsex[comparison][grouping]; comparisons are the outer iteration.
sigtestcountsex = {
    comp: {n: sigTestDiagCountSex(alldiagcountsex, comp, n) for n in diagkeys}
    for comp in comparisons
}

In [None]:
# combine sex-stratified results
# Inner join on the diagnosis index: only diagnoses tested in both the female
# and the male comparison are kept; shared column names get _F / _M suffixes.
cols_keep = ['pvalue','OddsRatio','log2_oddsratio','-log_pvalue']
chitestMF = dict()
for n in diagkeys:
    females = sigtestcountsex['AlzConFem'][n][cols_keep + ['ValueL', 'Count_alz_Female', 'Count_con_Female']]
    males = sigtestcountsex['AlzConMal'][n][cols_keep + ['Count_alz_Male','Count_con_Male']]
    chitestMF[n] = females.merge(males, left_index=True, right_index = True, suffixes=('_F','_M'))

# encode significance
# 0 not significant for both, 1 significant for females, 2 significant for males, 3 significant for both
# (label for code 2 previously read 'Sigificant (Males)')
sigmap = {0: 'Not Significant',
          1: 'Significant (Females)',
          2: 'Significant (Males)',
          3: 'Significant (F and M)'}
for n in diagkeys:
    bc = .05/chitestMF[n].shape[0] #Bonferroni correction
    print(n, bc)

    # bool + 2*bool gives the integer code 0..3 described above
    sigcode = (chitestMF[n]['pvalue_F']<bc).to_numpy() + \
              (chitestMF[n]['pvalue_M']<bc).to_numpy()*2
    chitestMF[n]['sigsex'] = pd.Series(sigcode, index = chitestMF[n].index).map(sigmap)
    print(chitestMF[n]['sigsex'].value_counts())

In [None]:
import pickle
# Use a context manager so the file is flushed and closed even if dump() fails.
with open('alzcon_diagnosis_MF_stats.pickle', 'wb') as picklefile:
    pickle.dump(chitestMF, picklefile)

In [None]:
# Log Log Plots
# For each diagnosis grouping: female vs. male log2 odds ratio, one panel per
# significance class, with the identity line for reference.
save = False;

for n in diagkeys:
    chitestMF[n]['index'] = chitestMF[n].index
    panel_labels = list(sigmap.values())
    # relplot is figure-level and creates its own figure, so no plt.figure()
    # call is needed (the old one only produced an extra empty figure).
    with sns.axes_style("white"):    
        g = sns.relplot( x = 'log2_oddsratio_F', y = 'log2_oddsratio_M',
                     col="sigsex", col_order=panel_labels,linewidth = .5,
                     palette = ['grey','red','blue','black'], hue = 'sigsex',
                     kind="scatter", data=chitestMF[n], hue_order = panel_labels,
                     height = 3, aspect = 1, s = 10)
    
    # Panels are laid out in col_order, so pair each axis with its label directly.
    for ax, label in zip(g.axes[0], panel_labels):
        ax.autoscale(False) # keep the data limits when the long diagonal is added
        ax.plot([-40,40],[-40,40], 'k--', zorder = 0, linewidth = 1, alpha = .5)
        for axis in ['top','bottom','left','right']:
            ax.spines[axis].set_visible(True)
            ax.spines[axis].set_linewidth(1)
        ax.set_title(label)

    plt.suptitle('Sex Diagnosis Log-Log Plot: ' + n,  y = 1.1)
    plt.show()
     
    if save:
        g.savefig('ADCon_sexloglogOR_'+n+'.png')

In [None]:
# Miami plots
# One plot per diagnosis grouping with the female -log10 p-values as one series
# ('Females') and the male ones as the other ('Males'), grouped by ICD-10
# category. `miami` is not defined in this notebook -- presumably it comes from
# plotting.py, run in an earlier cell.
dims = (10,5)
save = False
for n in diagkeys:
    # Bonferroni threshold, handed to the plot as its significance line (gwasp).
    bc = .05/chitestMF[n].shape[0]
    fig, ax = miami(df = chitestMF[n].explode('ValueL'), logp1 = '-log_pvalue_F', 
                    logp2 = '-log_pvalue_M', chromo = 'ValueL', 
                    dim = dims, yskip =50, gwas_sign_line=True, markernames=False, 
                    markeridcol='index', plotlabelrotation = 60, show = not save,
                    axlabelfontsize = 12, gfont = 6, dotsize = 10, axtickfontsize = 10, 
                    label1 = 'Females', label2 = 'Males', gwasp=bc, 
                    figtype = 'pdf', figname = "Alz-Con_miami_"+n, axxlabel = 'ICD10 Category', 
                    figtitle = 'Male/Female Alzheimer vs.Control Chi Square Diagnosis Miami Plot: '+n)


In [None]:
# percent significant in a category
figtype = 'pdf'

def _percentSignificantPerBlock(exploded, pcol, threshold, order):
    '''Percent of rows per ICD-10 block with `pcol` below `threshold`, reindexed to `order`.

    Returns a frame indexed by block name whose column 1 holds the percentage.
    '''
    rows = [[block, ((rowsofblock[pcol] < threshold).sum())*100/rowsofblock.shape[0]]
            for block, rowsofblock in exploded.groupby('ValueL')]
    return pd.DataFrame(rows).set_index(0).reindex(order)

def _plotPercentSignificant(percent_by_block, title):
    '''Bar plot of one percentage per ICD-10 block.'''
    with sns.axes_style("darkgrid"):
        plt.figure(figsize = (10,3))
        sns.barplot(x = percent_by_block.index, y = percent_by_block[1].values, palette = rand_colors)
        plt.xticks(rotation=50, va = 'top', ha = 'right')
        plt.ylabel('%')
        plt.title(title)
        plt.xlabel('ICD10 Block')
        # NOTE(review): relabels 21 fixed tick positions with `codemap`
        # regardless of which blocks are present -- confirm this is intended.
        plt.xticks(np.arange(21)+.5, codemap)

for n in diagkeys:
    bc = .05/chitestMF[n].shape[0]
    # Block order is derived from the grouping being plotted. It used to be
    # computed once before the loop from whatever value `n` still held from an
    # earlier cell, which fails on a fresh kernel and ties every plot to one
    # grouping's blocks.
    icd10order = list(map(ICDchapter_to_name,chitestMF[n]['ValueL'].explode().sort_values().unique()))

    chitest = chitestMF[n].explode('ValueL').copy();
    chitest['ValueL'] = chitest['ValueL'].apply(ICDchapter_to_name)

    icd10sigF = _percentSignificantPerBlock(chitest, 'pvalue_F', bc, icd10order)
    _plotPercentSignificant(icd10sigF, '% of Significant Diagnosis Per Category: '+n +' Female')

    icd10sigM = _percentSignificantPerBlock(chitest, 'pvalue_M', bc, icd10order)
    _plotPercentSignificant(icd10sigM, '% of Significant Diagnosis Per Category: '+n +' Male')


# Medication Comparison

In [None]:
# Read in meds
def _readCohortMeds(filename, demo):
    '''Read a medication CSV from `datadir`, keeping only patients listed in `demo`.'''
    meds = pd.read_csv(os.path.join(datadir, filename))
    in_cohort = meds['PatientID'].isin(demo['PatientID'])
    return meds[in_cohort]

alzmeds = _readCohortMeds('ad_medications.csv', alzdemo)

# Only the control table is joined with its demographics.
# NOTE(review): later cells group alzmeds by 'Sex', so ad_medications.csv
# presumably already carries that column -- confirm.
conmeds = _readCohortMeds('control_medications.csv', condemo) \
            .merge(condemo, how = 'left', on = 'PatientID')

In [None]:
# Count medication prescriptions
# For every medication: number of distinct patients prescribed it in each cohort
# and the number of patients in that cohort who were not.
jointype = 'outer'
col='SimpleGenericName'

tempmeds = alzmeds[['PatientID',col]].drop_duplicates()
tempmeds = tempmeds.groupby(col)['PatientID'].nunique().reset_index()
tempmeds.columns = [col,'AlzCount']
tempmedscon = conmeds[['PatientID',col]].drop_duplicates()
tempmedscon = tempmedscon.groupby(col)['PatientID'].nunique().reset_index()
tempmedscon.columns = [col,'ConCount']

combinedmed = tempmeds.merge(tempmedscon[[col, 'ConCount']], how = jointype, on = col)
combinedmed = combinedmed[~(combinedmed[col].isna())]

if jointype == 'outer':
    # A medication seen in only one cohort has 0 patients in the other one.
    combinedmed = combinedmed.fillna(value = {'AlzCount': 0, 'ConCount': 0})

# Patients without the medication. The control remainder must use the control
# cohort size; it was previously computed from total_alz, which corrupted the
# control column of every contingency table.
combinedmed['AlzCount_r'] = total_alz - combinedmed['AlzCount']
combinedmed['ConCount_r'] = total_con - combinedmed['ConCount']
combinedmed = combinedmed.set_index(col)

combinedmedstat = sigTestCounts(combinedmed, col, verbose=1)

In [None]:
# Encode significance
bc = .05/combinedmedstat.shape[0] # Bonferroni-corrected significance threshold
print('bc:',bc)

significant = (combinedmedstat['pvalue'] < bc).to_numpy()
log2or = combinedmedstat['log2_oddsratio'].to_numpy()

# np.select takes the first matching condition, so the two "enriched" labels
# (significant AND |log2 OR| > 1) win over plain 'Significant'.
# The previous np.full(..., fill_value='Not Significant') array had a fixed
# 15-character string dtype and silently truncated the longer labels
# ('Alzheimer Enric', 'Control Enriche'); np.select sizes the dtype to fit.
sig = np.select(
    [significant & (log2or > 1), significant & (log2or < -1), significant],
    ['Alzheimer Enriched', 'Control Enriched', 'Significant'],
    default = 'Not Significant')
combinedmedstat['sig'] = sig
display(combinedmedstat['sig'].value_counts())

In [None]:
# Plot volcano plots
# Medication volcano plot: log2 odds ratio on x, -log10 p-value on y, coloured
# by the 'sig' label assigned in the previous cell.
dims = (6,4)
# NOTE(review): `save` is set but not consulted below -- the figure is always written.
save = 1;
figtype = 'pdf';
figname = "AlzCon_med"; #"Alz-Con "+n

bc = .05/combinedmedstat.shape[0]
# Side effect: adds an 'index' column to combinedmedstat.
combinedmedstat['index'] = combinedmedstat.index
fig = plt.figure(figsize = dims);
g = sns.scatterplot(data = combinedmedstat, x = 'log2_oddsratio', 
                    y = '-log_pvalue', hue = 'sig',
                #palette = ["#EC719E","#949494"], 
                edgecolor = None,
                s = 12);
# NOTE(review): this assumes the first legend text is the legend title; in
# seaborn versions that set the title separately it would overwrite the first
# category label instead -- confirm against the installed version.
g.get_legend().texts[0].set_text('Significance'); plt.ylim([-100, 270])
# Dashed horizontal line: Bonferroni threshold; dotted verticals: |log2 OR| = 1.
plt.xlim([-6, 10]); plt.axhline(-np.log10(bc), linestyle = '--', color = '#555555', linewidth = 1);
plt.axvline(1, color = '#555555', linestyle = ':'); plt.axvline(-1, color = '#555555',linestyle = ':'); 
plt.xlabel(r'$\log_2$(Odds Ratio)'); plt.ylabel(r'$-\log_{10}$(p-value)');
plt.title('Alzheimer vs. Control Chi Square Medication Volcano Plot');
plt.savefig(figname+"."+figtype, format=figtype, bbox_inches='tight', dpi=300)
        
    

### stratify medications by sex

In [None]:
# Stratify by sex
col = 'SimpleGenericName'
randomstate = 40
sexes = ['Female', 'Male']
equalize_num = True  # Equalize number between males and females?

#Alz: medication rows per sex, then distinct patients per sex
sexgroups = alzmeds.groupby('Sex')
alzmedsex = {sex: sexgroups.get_group(sex) for sex in sexes}
numsexalz = {sex: len(alzmedsex[sex][['PatientID', 'Sex']].drop_duplicates())
             for sex in sexes}
print(numsexalz)

#Con: same layout
sexgroups = conmeds.groupby('Sex')
conmedsex = {sex: sexgroups.get_group(sex) for sex in sexes}
numsexcon = {sex: len(conmedsex[sex][['PatientID', 'Sex']].drop_duplicates())
             for sex in sexes}
print(numsexcon)

# Smallest sex group across both cohorts
kNumPatientsSampled = min(min(numsexalz.values()), min(numsexcon.values()))
print(kNumPatientsSampled)

In [None]:
def getXCounts(in_df, totalpts, col): 
    '''
    Count distinct patients per value of `col`.

    input:  in_df - dataframe with a `PatientID` column and the column `col`
            totalpts - total number of patients in the group
            col - name of the column whose values are counted
    output: dataframe with columns [col, 'Count', 'Count_r'], where 'Count' is
            the number of distinct patients with that value and 'Count_r' is
            totalpts - Count.
    '''
    patient_value_pairs = in_df.loc[:, ['PatientID', col]].drop_duplicates()
    patients_per_value = patient_value_pairs.groupby(col)['PatientID'].nunique()
    counts = patients_per_value.reset_index()
    counts.columns = [col, 'Count']
    counts['Count_r'] = totalpts - counts['Count']
    return counts

def getXCountsStratify(stratvarname, stratvars, colCount, in_df = None, in_dict=None, numptsvar = None, 
                               equalize_num = False, random_state = 40):
    '''
    Count distinct patients per value of `colCount`, separately for each level
    of a stratifying variable.

    input:  stratvarname - name of the stratifying column (e.g. `Sex`)
            stratvars - list of levels to stratify on
            colCount - column whose values are counted (passed to getXCounts)
            in_df or in_dict - pass either the full dataframe (`in_df`), which is
                   split on `stratvarname` here, or an already stratified
                   dictionary {level: dataframe} (`in_dict`). `in_dict` takes
                   precedence when both are given.
            numptsvar - dictionary {level: number of patients}. Required with
                   `in_dict`; computed from `in_df` otherwise.
            equalize_num - if True, every level with more patients than the
                   smallest one is randomly subsampled (by PatientID) down to
                   the smallest level's size before counting.
            random_state - seed for the subsampling. default: 40
    output: dictionary indexed by level with the getXCounts result for that level.
    raises: Exception if neither `in_df` nor `in_dict` is given, or if `in_dict`
            is given without `numptsvar`.
    '''
    Xdictvar = dict()
    if in_dict: #check if dictionary of dataframes, or a full dataframe
        Xdictvar = in_dict
        if numptsvar is None:
            raise Exception('variable numptsvar is empty, pass in dictionary with number of patients in each category.')
    elif in_df is not None:
        # `elif in_df:` would raise "truth value of a DataFrame is ambiguous",
        # which made this branch unusable; test for presence explicitly.
        numptsvar = dict()
        for var in stratvars:
            Xdictvar[var] = in_df[in_df[stratvarname] == var]
            numptsvar[var] = Xdictvar[var][['PatientID',stratvarname]].drop_duplicates().shape[0]
    else:
        raise Exception('did not pass in full dataframe or dictionary of stratified dataframes.')
     
    kNumPatientsSampled = min(numptsvar.values())

    stratXcount = dict()
    for var in stratvars:
        if numptsvar[var]>kNumPatientsSampled and equalize_num: # subsample patients
            # Subsample whole patients (not rows) so each kept patient retains all rows.
            subsampledPatientKeys = Xdictvar[var]['PatientID'].drop_duplicates()\
                        .sample(kNumPatientsSampled, random_state = random_state)
            Xdictvar_s = Xdictvar[var][Xdictvar[var]['PatientID'].isin(subsampledPatientKeys)]
            stratXcount[var] = getXCounts(Xdictvar_s, kNumPatientsSampled, colCount)
        else: 
            stratXcount[var] = getXCounts(Xdictvar[var], numptsvar[var], colCount)
    return stratXcount

# Per-sex medication counts, computed with identical options for both cohorts.
_stratify_opts = dict(equalize_num = equalize_num, random_state = randomstate)

# Alzheimer cohort
alzcountsex = getXCountsStratify('Sex', sexes, colCount = col, in_dict = alzmedsex,
                                 numptsvar = numsexalz, **_stratify_opts)

# Control cohort
concountsex = getXCountsStratify('Sex', sexes, colCount = col, in_dict = conmedsex,
                                 numptsvar = numsexcon, **_stratify_opts)

In [None]:
# Create contingency table for med prescriptions
comparisons = ['AlzConFem','AlzConMal']
jointype = 'outer' # outer join keeps medications seen in only one of the two groups
num_threshold = 0

# Group sizes the per-sex counts were actually computed against: with
# equalize_num, getXCountsStratify subsamples the larger sex of each cohort
# down to that cohort's smaller sex.
if equalize_num:
    alztotals = {sex: min(numsexalz.values()) for sex in sexes}
    contotals = {sex: min(numsexcon.values()) for sex in sexes}
else:
    alztotals = dict(numsexalz)
    contotals = dict(numsexcon)

allXcountsex = dict()
## AlzConFem - Alzheimer females merged with control females
## AlzConMal - alzheimer male merged with control males
for comp, sex in zip(comparisons, ['Female', 'Male']):
    table = alzcountsex[sex] \
             .merge(concountsex[sex], how = jointype, on = col, suffixes=('_alz_'+sex,'_con_'+sex)) \
             .set_index(col)
    # keep medications above num_threshold (NaN cells never count as below it)
    table = table[(table < num_threshold).sum(axis=1) < 1]
    if jointype == 'outer':
        # A medication missing on one side means 0 patients have it and the whole
        # group does not. The fill must use the same group size as the counted
        # rows; the previous code used the unsampled size even when equalize_num
        # had subsampled the group, giving rows with inconsistent totals.
        nanreplace = dict(zip(list(table.columns), [0, alztotals[sex], 0, contotals[sex]]))
        table = table.fillna(value=nanreplace)
    allXcountsex[comp] = table

# Compute p values
# NOTE(review): sigTestCounts is called as (dataframe, col, verbose=...) in the
# medication-count cell above but as (dict, comp, col) here; it is defined
# outside this notebook, so confirm which call matches its signature.
sigtestcountsex = dict()      
for comp in comparisons:
    sigtestcountsex[comp] = sigTestCounts(allXcountsex, comp, col)

In [None]:
# combine male and female results
# Inner join on the medication index: only medications tested in both sexes are
# kept; shared column names get _F / _M suffixes.
cols_keep = ['pvalue','OddsRatio','log2_oddsratio','-log_pvalue']
chitestMF = sigtestcountsex['AlzConFem'][cols_keep + ['Count_alz_Female', 'Count_con_Female']].merge( 
                    sigtestcountsex['AlzConMal'][cols_keep + ['Count_alz_Male','Count_con_Male']],
                    left_index=True, right_index = True, suffixes=('_F','_M'))

# encode significance
bc = .05/chitestMF.shape[0] #Bonferroni correction
print(col, bc)
# 0 not significant for both, 1 significant for females, 2 significant for males, 3 significant for both
# (label for code 2 previously read 'Sigificant (Males)')
sigmap = {0:'Not Significant', 1:'Significant (Females)', 
          2:'Significant (Males)', 3: 'Significant (F and M)'}
# bool + 2*bool gives the integer code 0..3 described above
sigcode = (chitestMF['pvalue_F']<bc).to_numpy() + (chitestMF['pvalue_M']<bc).to_numpy()*2
chitestMF['sigsex'] = pd.Series(sigcode, index = chitestMF.index).map(sigmap)
print(chitestMF['sigsex'].value_counts())

In [None]:
# log log plots
# Female vs. male log2 odds ratio of every medication, one panel per
# significance class, with the identity line for reference.
chitestMF['index'] = chitestMF.index
panel_labels = list(sigmap.values())
# relplot is figure-level and creates its own figure, so no plt.figure() call is
# needed (the old one only produced an extra empty figure).
with sns.axes_style("white"):    
    g = sns.relplot( x = 'log2_oddsratio_F', y = 'log2_oddsratio_M',
                 col="sigsex", col_order=panel_labels,linewidth = .5,
                 palette = ['grey','red','blue','black'], hue = 'sigsex',
                 kind="scatter", data=chitestMF, hue_order = panel_labels,
                 height = 3, aspect = 1, s = 10)

# Panels are laid out in col_order, so pair each axis with its label directly.
for ax, label in zip(g.axes[0], panel_labels):
    ax.autoscale(False) # keep the data limits when the long diagonal is added
    ax.plot([-40,40],[-40,40], 'k--', zorder = 0, linewidth = 1, alpha = .5)
    for axis in ['top','bottom','left','right']:
        ax.spines[axis].set_visible(True)
        ax.spines[axis].set_linewidth(1)
    ax.set_title(label)

plt.suptitle('Sex Meds Log-Log Plot',  y = 1.1)
plt.show()

In [None]:
# combined loglog plot
# All medications in a single scatter of female vs. male log2 odds ratio,
# coloured by the 'sigsex' class; the figure is written to a PDF.
# NOTE(review): no hue_order is given, so which colour each class receives is
# decided by seaborn rather than fixed here, and this palette order differs
# from the one used for the per-class panels above -- confirm the intended
# colour-to-class mapping.
sns.scatterplot(x = 'log2_oddsratio_F', y = 'log2_oddsratio_M',data = chitestMF,
               hue = 'sigsex',linewidth =.1, alpha = .8, 
               palette = ['grey','black','blue','red'])
plt.savefig('sexmed_enccon_loglog.pdf',format = 'pdf', dpi=300 )

# Lab Result Comparisons

In [None]:
# read in labs, keeping only patients present in the demographics tables
alzlabs = pd.read_csv(os.path.join(datadir,'ad_labresults.csv'), low_memory = False)
alzlabs = alzlabs[alzlabs['PatientID'].isin(alzdemo['PatientID'])]
conlabs = pd.read_csv(os.path.join(datadir,'control_labresults.csv'), low_memory = False)
conlabs = conlabs[conlabs['PatientID'].isin(condemo['PatientID'])]

# combine
# DataFrame.append was removed in pandas 2.0; pd.concat is the equivalent
# (original row labels are kept, exactly as append did by default).
alllabsfull = pd.concat([alzlabs, conlabs])
# drop rows without a lab value
alllabsfull = alllabsfull[~alllabsfull['Value'].isna()]
# flag rows belonging to a patient of the Alzheimer lab table
alllabsfull['isalz'] = alllabsfull.PatientID.isin(alzlabs['PatientID'].drop_duplicates())

In [None]:
# Preview a few rows only; avoids rendering (and saving) the whole patient table.
alzlabs.head()

In [None]:
# Go through all lab tests and compute median value for each patient
def is_number(s):
    """Return True if ``float(s)`` succeeds, otherwise False.

    Anything ``float()`` accepts counts as a number, including numeric
    strings such as ``"3.5"`` and the special strings ``"nan"`` / ``"inf"``.
    Values that ``float()`` rejects with a ``TypeError`` (e.g. ``None`` or a
    list) are reported as non-numeric instead of propagating the error.
    """
    try:
        float(s)
    except (TypeError, ValueError):
        return False
    return True
    
# Build an (initially empty) patient x lab-test table and fill it with one
# aggregated value per patient and test.
# `pd.concat` replaces `Series.append`, which no longer exists in pandas >= 2.0.
all_test_names = pd.concat([alzlabs['TestName'], conlabs['TestName']]).drop_duplicates().values
alllabs = pd.DataFrame(columns=all_test_names,
                       index=alllabsfull.PatientID.unique())
# NOTE(review): this flag is derived from `alzdiag`, whereas
# `alllabsfull['isalz']` is derived from `alzlabs` -- confirm both agree.
alllabs['isalz'] = alllabs.index.isin(alzdiag['PatientID'])

aggType = 'median'
# Validate the aggregation choice once, up front, instead of inside the loop.
if aggType != 'median':
    raise ValueError("aggType invalid. Choices:'median','otherstobedevelopped...'")

for test, labs in tqdm(list(alllabsfull.groupby('TestName'))):  # for each lab test
    resultdict = dict()
    for pt, df in labs.groupby('PatientID'):  # for each patient with this test
        results = df['Value'].dropna()

        val = np.nan
        # Aggregate only when *every* result of this patient is numeric.
        # Results that are not plain numbers (the retired draft of this loop
        # targeted strings containing '>' or '<') are not parsed, so such a
        # patient/test pair stays NaN.
        if len(results) > 0 and results.apply(is_number).all():
            # Cast explicitly: `median()` on an object column of numeric
            # strings is rejected by recent pandas. For a single result the
            # median is that value itself (stored as float rather than as the
            # raw CSV token; the table is cast to float downstream anyway).
            val = results.astype(float).median()
        resultdict[pt] = val
    # Patients without this test are absent from `resultdict` and map to NaN.
    alllabs[test] = alllabs.index.map(resultdict)

In [None]:
# Drop lab tests (columns) that have no value for any patient, and patients
# (rows) whose only non-missing entry is the `isalz` flag.
missing = alllabs.isna()
emptycols = alllabs.columns[missing.sum() == alllabs.shape[0]]
emptyrows = alllabs.index[missing.sum(axis=1) == alllabs.shape[1] - 1]
alllabs = alllabs.drop(emptycols, axis=1)
alllabs = alllabs.drop(emptyrows, axis=0)

# Keep only labs with fewer than 95% missing values.
labanalysis = alllabs[alllabs.columns[alllabs.isna().sum() < alllabs.shape[0] * .95]]
labanalysis = labanalysis.astype('float')  # also turns the boolean `isalz` into 1.0 / 0.0
# Attach each patient's sex from the combined demographics tables.
# `pd.concat` replaces `DataFrame.append`, which no longer exists in pandas >= 2.0.
labanalysis['Sex'] = (pd.concat([alzdemo, condemo])
                      .set_index('PatientID')
                      .loc[list(labanalysis.index), 'Sex'])

isalz = labanalysis['isalz']
sexes = labanalysis['Sex']

In [None]:
# Identify significant lab tests: Mann-Whitney U test of Alzheimer vs. control
# values for every lab, on all patients and separately for females and males.

def _mannwhitney_or_insignificant(group_a, group_b):
    """Run ``stats.mannwhitneyu``; return the sentinel ``(100, 100)`` on failure.

    The sentinel p-value of 100 can never pass a significance threshold, so a
    lab whose test cannot be computed is simply treated as not significant.
    Catches ``Exception`` rather than using a bare ``except`` so that
    KeyboardInterrupt / SystemExit are no longer swallowed.
    NOTE(review): `alternative` is left at SciPy's default, which has differed
    between SciPy releases -- consider pinning it explicitly.
    """
    try:
        return stats.mannwhitneyu(group_a, group_b)
    except Exception:
        return (100, 100)


def _ranked_tests(results):
    """Turn a {lab: (statistic, pvalue)} dict into a frame sorted by p-value."""
    tests = pd.DataFrame(results).T.drop('isalz')  # `isalz` is the grouping flag, not a lab
    tests.columns = ['stats', 'pval']
    return tests.sort_values('pval').dropna()


alllabtest = dict()
Flabtest = dict()
Mlabtest = dict()

is_case = isalz == 1
is_control = isalz == 0
is_female = sexes == 'Female'
is_male = sexes == 'Male'

for col in labanalysis.columns.drop(['Sex']):
    values = labanalysis[col]
    alllabtest[col] = _mannwhitney_or_insignificant(
        values[is_case].dropna(), values[is_control].dropna())
    Flabtest[col] = _mannwhitney_or_insignificant(
        values[is_case & is_female].dropna(), values[is_control & is_female].dropna())
    Mlabtest[col] = _mannwhitney_or_insignificant(
        values[is_case & is_male].dropna(), values[is_control & is_male].dropna())

alllabtests = _ranked_tests(alllabtest)
Flabtests = _ranked_tests(Flabtest)
Mlabtests = _ranked_tests(Mlabtest)

bc = .05 / alllabtests.shape[0]  # Bonferroni-corrected significance threshold
print(bc)

Msiglab = Mlabtests[Mlabtests.pval < bc].index.unique()
Fsiglab = Flabtests[Flabtests.pval < bc].index.unique()
pooledsiglab = alllabtests[alllabtests.pval < bc].index.unique()
# Union of labs significant in the pooled, male-only or female-only comparison.
siglabs = pooledsiglab.append(Msiglab).append(Fsiglab).unique()

# Get sets of significant lab names.
allsig = np.unique(np.array(pooledsiglab.values))  # significant in pooled AD vs controls
Fsig = np.unique(np.array(Fsiglab.values))  # significant in female AD vs controls
Msig = np.unique(np.array(Msiglab.values))  # significant in male AD vs controls
Fonly = set(Fsig) - set(allsig)  # significant in females but not in the pooled test
Monly = set(Msig) - set(allsig)  # significant in males but not in the pooled test
# Fonly and Monly are already disjoint from allsig, so this equals set(allsig):
# every lab that is significant in the pooled comparison.
allshare = set(allsig) - Fonly - Monly

In [None]:
# Z-score the median lab values of each (isalz, Sex) group across groups.
from sklearn.preprocessing import StandardScaler
from sklearn.feature_selection import VarianceThreshold
groupmedians = labanalysis.groupby(['isalz', 'Sex']).median()
# Derive the row labels from the actual group keys instead of hard-coding
# them, so they cannot be mis-assigned if a group is missing or an additional
# Sex value is present. With the usual four groups this yields
# 'Female Control', 'Male Control', 'Female Alzheimer', 'Male Alzheimer'.
grouplabels = [f"{sex} {'Alzheimer' if alz else 'Control'}"
               for alz, sex in groupmedians.index]
medianlabs = groupmedians.reset_index(drop=True)[siglabs]
# StandardScaler standardises each column (lab) across the group rows.
medianlabs = pd.DataFrame(StandardScaler().fit_transform(medianlabs),
                          columns=medianlabs.columns, index=grouplabels)
medianlabs = medianlabs.dropna(axis='columns')  # drop labs with a missing group median

# Significance status per lab: 1 = in `allshare`, 2 = female only, 3 = male only.
# A lab in both `Fonly` and `Monly` ends up as 3 because that assignment runs last.
sigstatus = medianlabs.columns.isin(allshare).astype(int)
sigstatus[medianlabs.columns.isin(Fonly)] = 2  # significant in female only
sigstatus[medianlabs.columns.isin(Monly)] = 3  # significant in male only

In [None]:
# Plot clustermap of the z-scored group medians (labs as rows, groups as columns).
# Row colours encode the significance status: 1 -> orange, 2 -> green, 3 -> blue.
sns.clustermap(medianlabs.T, cmap="vlag", figsize=(8, 10), method='average',
               dendrogram_ratio=(.2, .05),
               cbar_pos=(.92, .87, .05, .1),
               row_colors=pd.Series(sigstatus).map({1: 'orange', 2: 'green', 3: 'blue'}).values)
# `format` is the savefig keyword for the output type; the previous `filetype`
# keyword is not a savefig parameter.
plt.savefig('cluster_4gsig_median.png', format='png')
plt.savefig('cluster_4gsig_median.pdf', format='pdf')