In [None]:
# pandas is imported here as well so that this cell also runs on a fresh kernel;
# the main imports cell only appears further down in the notebook.
import pandas as pd

# NOTE(review): the imports cell sets 'display.max_rows' again (to 100), so on
# "Run All" that later value is the one that ends up in effect.
pd.set_option('display.max_rows', 500)

# The Neurogenomics Database: temporal profiles
Author:  Nienke Mekkes <br>
Date: 11-10-2022. <br>
Correspondence: n.j.mekkes@umcg.nl <br>

## Script: temporal profiles of clinical disease trajectories

### Input files:
- pickle/excel file with (filtered) predictions <br>
- general information
- attribute groupings file 


### PATHS

In [None]:
# Pickle file with the (filtered) predictions; opened with pickle.load below.
path_to_predictions = ""

# Excel workbook with the general donor information (read from sheet "Sheet1").
general_information = ""

# Root folder for everything this notebook writes; created if it does not exist.
output_path = ""


### IMPORTS

In [None]:
import statsmodels.stats.multitest as smt
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import numpy as np
import scipy
import os
from scipy.stats import mannwhitneyu, normaltest
import pickle
import pandas
import numpy
import seaborn as sns
import xlsxwriter
import matplotlib
import matplotlib.patheffects as PathEffects
from itertools import chain
from helper_functions import table_selector
pandas.set_option('display.max_rows', 100)

#### Load data

In [None]:
# Make sure the root output folder exists before any figure is written.
output_folder_missing = not os.path.exists(output_path)
if output_folder_missing:
    print('Creating output folder....')
    os.makedirs(output_path)

In [None]:
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point path_to_predictions at files from a trusted source.
with open(path_to_predictions, "rb") as predictions_file:
    predictions_pickle = pickle.load(predictions_file)

## Build one DataFrame per donor and stack them into a single long table.
## The outer dictionary key becomes 'DonorID'; the keys of the inner dictionary
## become the row index, which is copied into the 'Age' column.
donor_frames = []
for donor_id, donor_predictions in predictions_pickle.items():
    donor_frame = pd.DataFrame.from_dict(donor_predictions, orient="index")
    donor_frame["DonorID"] = donor_id
    donor_frame['Age'] = donor_frame.index
    donor_frames.append(donor_frame)
predictions_df = pd.concat(donor_frames, ignore_index=True)

print(f"there are {len(predictions_df['DonorID'].unique())} unique donor IDs")
print(predictions_df.shape)
display(predictions_df.head())

#### Change to the 'paper diagnosis' using the general information, and remove the donors marked for exclusion
The paper names certain diagnoses differently from the source data. We also remove some donors (e.g. those younger than 21 years) and change the diagnosis of some TRANS donors.

In [None]:
general_information_df = pd.read_excel(general_information, engine='openpyxl', sheet_name="Sheet1")
# display(general_information_df)

# Lookup table DonorID -> 'paper diagnosis', built once and reused below.
paper_diagnosis_by_donor = general_information_df.set_index('DonorID')['paper diagnosis']

# Donors whose paper diagnosis is the literal string 'exclude' are dropped.
donors_to_remove = list(general_information_df[general_information_df['paper diagnosis']=='exclude'].DonorID)
# .copy() makes the filtered table an independent frame, so the column
# assignment below does not trigger pandas' SettingWithCopyWarning.
predictions_df = predictions_df[~predictions_df['DonorID'].isin(donors_to_remove)].copy()
print(f"there are {len(predictions_df['DonorID'].unique())} unique donor IDs")
print(len(donors_to_remove))

predictions_df['neuropathological_diagnosis'] = predictions_df['DonorID'].map(paper_diagnosis_by_donor)
display(predictions_df.head())

# Donors missing from the general information end up with NaN here; report
# them explicitly instead of letting sorted() fail on mixed str/float values.
without_diagnosis = predictions_df['neuropathological_diagnosis'].isna()
if without_diagnosis.any():
    print(f"warning: {predictions_df.loc[without_diagnosis, 'DonorID'].nunique()} donors have no 'paper diagnosis'")
print(sorted(predictions_df['neuropathological_diagnosis'].dropna().unique()))

In [None]:
## Every column that is not listed here as donor metadata is treated as an attribute.
non_attribute_columns = [
    'DonorID',
    'Year',
    'age_at_death',
    'sex',
    'neuropathological_diagnosis',
    'Age',
]  # previously also considered: 'birthyear', 'death_year', 'year_before_death'
attributes = list(filter(lambda column: column not in non_attribute_columns,
                         predictions_df.columns))
# display(attributes)
print(f"there are {len(predictions_df)} rows and {len(attributes)} attributes")
print(f"there are {len(predictions_df['DonorID'].unique())} unique donor IDs")


## general overview (sup1a)

In [None]:
table_of_choice = 'tableall_p'
general_info_selected, table_all_diagnoses = table_selector(table_of_choice, predictions_df)
# One row per donor: the predictions table repeats the donor metadata on every row.
general_info_selected = general_info_selected[['DonorID','age_at_death','sex','neuropathological_diagnosis' ]].drop_duplicates()

## Donor count per diagnosis, in plotting order, for the numbers printed under the violins.
## fill_value=0 keeps the count an integer for a diagnosis without donors; a NaN
## there would make int(number) fail in the plotting cell.
foo = (general_info_selected['neuropathological_diagnosis']
       .value_counts()
       .reindex(table_all_diagnoses, fill_value=0)
       .reset_index())
foo.columns = ['neuropathological_diagnosis','count']

In [None]:
# Violin colours per sex code; the legend entries are renamed to 'Female' and 'Male' below.
# NOTE(review): later plotting cells rebind this same name to a diagnosis -> colour mapping.
table_diagnosis_colors_dic= {} 
table_diagnosis_colors_dic['F'] = 'steelblue'
table_diagnosis_colors_dic['M'] = 'darkorange'

## Seaborn and inline-backend settings (IPython magics: only valid inside a notebook)
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
sns.set(style="ticks", font_scale=2.5)
fig, ax = plt.subplots(figsize=(35, 6))
# sns.set(rc={'figure.figsize':(25,6)})
# sns.set_style("white")
# Split violins of age at death per diagnosis, one half per sex, in table order.
# No ax= is passed, so seaborn presumably draws on the axes created above - TODO confirm.
ll = sns.violinplot(data = general_info_selected, x='neuropathological_diagnosis',y='age_at_death',
                    hue='sex', scale = 'count',
                    order=table_all_diagnoses, split=True,palette = table_diagnosis_colors_dic)

# One tick per diagnosis, labelled with the diagnosis name (rotated to fit).
ll.set_xticks(np.arange(len(table_all_diagnoses)))
ll.set_xticklabels(table_all_diagnoses, fontsize=10, minor=False,rotation=90)

# Legend above the plot; the first two entries are relabelled.
# NOTE(review): this assumes the legend order is F then M - confirm against the data.
L = plt.legend(bbox_to_anchor=(0.4, 1.2), loc="upper center", ncol=4, borderaxespad=0)
L.get_texts()[0].set_text('Female')
L.get_texts()[1].set_text('Male')

# NOTE(review): x_labels is not used anywhere in the visible notebook.
x_labels = list(ax.get_xticklabels())
# Write each diagnosis' donor count at y = -16, below its violin. foo was reindexed
# to table_all_diagnoses and reset, so its row index is used as the x position.
for index, row in foo.iterrows():
    number = row['count']
    ll.text(index, -16, '{}'.format(int(number)), color='black', ha='center',rotation=0,fontsize=17)

# The y-range starts below zero to leave room for the counts written at y = -16.
plt.ylim(-20, 120)    
ax.spines["right"].set_color("none")
ax.spines["top"].set_color("none")
# ax.spines['left'].set_position(('data', 0))
# sns.despine(offset=10, trim=False)
plt.yticks(np.arange(0, 140, 20))
plt.title('Age at death', fontsize=30)
ax.set(xlabel=None)

# Figures for this panel go into <output_path>/sup1a, as both PNG and PDF.
sup1a_path = output_path + '/sup1a'
if not os.path.exists(sup1a_path):
    print('Creating output folder....')
    os.makedirs(sup1a_path)

plt.savefig(sup1a_path + '/' +  'Age at death' + '_hue_sex.png', bbox_inches='tight')
plt.savefig(sup1a_path + '/' +  'Age at death' + '_hue_sex.pdf', bbox_inches='tight')
plt.show()
plt.close()

## temporal, count, survival plots

In [None]:
## The temporal and survival plots need a known age, so rows whose 'Age' is the
## placeholder value -9 are left out.
predictions_df_temporal = predictions_df.loc[predictions_df['Age'] != -9]

In [None]:
def heatmappable(duos,fdrs, order):
    """Turn pairwise FDR values into symmetric matrices for the significance heatmap.

    Parameters
    ----------
    duos : sequence of 2-item sequences
        The compared label pairs, e.g. ``[('AD', 'PD'), ...]``. Each unordered
        pair may appear only once.
    fdrs : sequence of float
        The FDR-corrected p-value of each pair, in the same order as ``duos``.
    order : list
        Labels in the order in which rows and columns should appear.

    Returns
    -------
    inds, cols : list
        Row and column labels of the matrices (both follow ``order``).
    star_matrix : pandas.DataFrame
        Significance stars per cell (``''`` when not significant, NaN where no
        comparison exists, e.g. on the diagonal).
    hm_array : numpy.ndarray
        ``hm_matrix`` as a plain array, ready for ``imshow``.
    hm_matrix : pandas.DataFrame
        Rounded ``-log10(FDR)`` per cell, NaN where no comparison exists.
    """
    # Checked from strictest to loosest; the first cutoff that holds wins.
    star_cutoffs = [(1e-10, '*****'), (1e-8, '****'), (1e-6, '***'), (1e-4, '**'), (1e-2, '*')]

    hm_rows = []
    for duo, fdr in zip(duos, fdrs):
        tempstar = next((stars for cutoff, stars in star_cutoffs if fdr <= cutoff), '')
        hm_rows.append([duo[0], duo[1], fdr, tempstar])
    hm_df = pd.DataFrame(hm_rows, columns =['A','B','C','D'])
    hm_df['C'] = (-1 * np.log10(hm_df['C'].astype(float))).round()

    # Each pair is listed once, so add the swapped copy to make the pivot symmetric.
    # pd.concat replaces DataFrame.append, which no longer exists in pandas >= 2.0.
    mirror = hm_df.rename(columns={'A': 'B', 'B': 'A'})[['A','B','C','D']]
    hm_df = pd.concat([hm_df, mirror], ignore_index=True)

    # Reindexing both axes orders the matrix and adds an all-NaN row/column for
    # any label in `order` that took part in no comparison.
    hm_matrix = hm_df.pivot(index='A', columns='B', values='C').reindex(index=order, columns=order)
    star_matrix = hm_df.pivot(index='A', columns='B', values='D').reindex(index=order, columns=order)
    star_matrix.columns = star_matrix.columns.to_flat_index()

    inds = list(star_matrix.index)
    cols = list(star_matrix.columns)
    hm_array = np.asarray(hm_matrix)

    return inds, cols, star_matrix, hm_array, hm_matrix

In [None]:
def obtain_P_values(diag_order, att, prs_df,heatmap=False):
    """Compare every pair of diagnoses with a Mann-Whitney U test and FDR-correct the p-values.

    Parameters
    ----------
    diag_order : sequence of str
        Diagnoses to compare; each must be a column of ``prs_df``.
    att : str
        Attribute name, only used in the message printed for empty groups.
    prs_df : pandas.DataFrame
        One column of observations per diagnosis (NaN entries are ignored).
    heatmap : bool, default False
        When False only pairs with a corrected p-value below 0.01 are returned;
        when True every pair is returned (the heatmap needs all of them).

    Returns
    -------
    pairs_selected : list of tuple
        ``(diagnosis_1, diagnosis_2)`` pairs, in the order of ``diag_order``.
    FRD_Pvalues_selected : list
        Benjamini-Hochberg corrected p-value of each returned pair.
    """
    FRD_Pvalue_cutoff = 0.01

    pairs = []
    P_value_list = []

    # Walk every unordered pair exactly once. Duplicate labels are dropped
    # first while preserving their order.
    unique_diagnoses = list(dict.fromkeys(diag_order))
    for position, diagnosis_1 in enumerate(unique_diagnoses):
        years_with_trait_1 = prs_df[diagnosis_1].dropna().values

        for diagnosis_2 in unique_diagnoses[position + 1:]:
            years_with_trait_2 = prs_df[diagnosis_2].dropna().values

            if len(years_with_trait_1) > 0 and len(years_with_trait_2) > 0:
                Pvalue = scipy.stats.mannwhitneyu(years_with_trait_1, years_with_trait_2)[1]
            else:
                print(f"diagnosis {diagnosis_1} or diagnosis {diagnosis_2} has no instances of attribute {att}")
                # No test is possible: record the largest valid p-value. The
                # pair still counts as a test in the correction, as before.
                Pvalue = 1.0

            pairs.append((diagnosis_1, diagnosis_2))
            P_value_list.append(Pvalue)

    # Fewer than two diagnoses means nothing to compare; multipletests cannot
    # handle an empty list.
    if not P_value_list:
        return [], []

    P_value_list = smt.multipletests(P_value_list, method='fdr_bh')[1]

    FRD_Pvalues_selected = []
    pairs_selected = []
    for pair, Pvalue in zip(pairs, P_value_list):
        if heatmap or Pvalue < FRD_Pvalue_cutoff:
            FRD_Pvalues_selected.append(Pvalue)
            pairs_selected.append(pair)

    return pairs_selected, FRD_Pvalues_selected



In [None]:
def plot_itself(prs_df, mean_order, pairs, FDR_pvalues,output_path,tdc_dic,xrange,figid,plot_type,j,w,h,heatmap=True):
    """Draw the violin plot for one attribute, optionally next to a significance heatmap.

    Parameters
    ----------
    prs_df : pandas.DataFrame
        Wide table with one column of values per diagnosis.
    mean_order : list
        Diagnoses in plotting order.
    pairs, FDR_pvalues : list
        Compared diagnosis pairs and their corrected p-values.
    output_path : str
        Root output folder; figures are written to ``<output_path>/temporal/<figid>``
        for every plot type.
    tdc_dic : dict
        Diagnosis -> colour mapping used as the violin palette.
    xrange : number
        Upper x-limit of the violin axis (only applied when ``heatmap`` is true).
    figid : str
        Name of the sub-folder for this figure.
    plot_type : {'temporal', 'counts', 'survival'}
        Selects axis title, orientation and axis limits.
    j : str
        Attribute name; used as plot title and in the file names.
    w, h : number
        Figure width and height.
    heatmap : bool, default True
        True: violins plus a heatmap of rounded -log10(FDR) values.
        False: violins with star annotations for the given pairs.

    Raises
    ------
    ValueError
        If ``plot_type`` is not one of the three supported values.
    """
    # Per plot type: x-axis title, violin orientation, lower x-limit, lower y-limit.
    if plot_type == 'temporal':
        xtitle, orientation, xlim, ylim = 'Age', 'h', 0, None
    elif plot_type == 'counts':
        xtitle, orientation, xlim, ylim = 'Count', ('h' if heatmap else 'v'), None, 0
    elif plot_type == 'survival':
        xtitle, orientation, xlim, ylim = "Survival in Year", 'h', 0, None
    else:
        raise ValueError(f"unknown plot_type {plot_type!r}; expected 'temporal', 'counts' or 'survival'")

    if not heatmap:
        fig, ax1 = plt.subplots(1, 1, figsize=(w, h), dpi=500)
        sns.violinplot(data=prs_df,
                       orient=orientation,
                       ax=ax1,
                       linewidth=1.5,
                       palette=tdc_dic,
                       scale='count',
                       order=mean_order)

        Pvalue_format_star = [(1e-2, "*"), (1e-4, "**"), (1e-6, "***"), (1e-8, "****"), (1e-10, "*****")]
        if len(pairs) > 0:
            # NOTE(review): Annotator is not imported in the visible part of the
            # notebook (presumably statannotations) - confirm and add the import.
            annotator = Annotator(ax1, pairs, data=prs_df, order=mean_order, orient=orientation)
            annotator.set_pvalues(FDR_pvalues)
            annotator.configure(text_format='star', loc='inside', line_height=0.01,
                                text_offset=-5, pvalue_thresholds=Pvalue_format_star)
            annotator.annotate(line_offset_to_group=0.2, line_offset=0)
        ax1.set_xlim(xlim, None)
        ax1.set_ylim(ylim, None)
    else:
        fig, (ax1, ax2) = plt.subplots(1, 2, gridspec_kw={'width_ratios': [1, 1.5]}, figsize=(w, h), dpi=500)
        sns.violinplot(data=prs_df,
                       orient=orientation,
                       ax=ax1,
                       linewidth=1,
                       palette=tdc_dic,
                       scale='count',
                       order=mean_order)
        indices, cols, star_matrix, hm_array, hm_matrix = heatmappable(pairs, FDR_pvalues, mean_order)
        # The unused matplotlib.cm.get_cmap lookup was dropped: its result was
        # never read and the function no longer exists in Matplotlib >= 3.9.
        im = ax2.imshow(hm_array, cmap='Blues', vmin=0, vmax=10, aspect="auto")
        fig.colorbar(im, ax=ax2, shrink=0.5)
        ax2.axes.get_yaxis().set_visible(False)
        ax2.set_xticks(np.arange(len(mean_order)))
        ax2.set_xticklabels(mean_order, fontsize=8, minor=False, rotation=90)

        # Write the rounded -log10(FDR) into every cell: white 'nan' where no
        # comparison exists, orange with a thin outline where significant,
        # black otherwise.
        significant_marks = {'*', '**', '***', '****', '*****'}
        for i in range(len(indices)):
            for z in range(len(cols)):
                stars = star_matrix.iloc[i, z]
                cell_value = hm_matrix.iloc[i, z]
                if str(stars) == 'nan':
                    ax2.text(z, i, str(cell_value),
                             ha="center", va="center", color="w", fontsize=8)
                elif stars in significant_marks:
                    text = ax2.text(z, i, round(cell_value),
                                    ha="center", va="center", color="darkorange", fontsize=8)
                    text.set_path_effects([PathEffects.withStroke(linewidth=0.5, foreground='black')])
                else:
                    ax2.text(z, i, round(cell_value),
                             ha="center", va="center", color="black", fontsize=8)

        ax2.set_title("significance map")
        ax2.xaxis.tick_top()
        ax1.set_xlim(0, xrange)

    ax1.spines["right"].set_color("none")
    ax1.spines["top"].set_color("none")
    ax1.set_xlabel(xtitle, fontsize=12)
    ax1.set_title(j, fontsize=12)

    fig.tight_layout()
    plt.subplots_adjust(wspace=0)
    plt.tick_params(bottom=True, left=True)

    fig_path = f"{output_path}/temporal/{figid}"
    if not os.path.exists(fig_path):
        print('Creating output folder....')
        os.makedirs(fig_path)
        print(fig_path)
    plt.savefig(f"{fig_path}/{j}_{plot_type}.png", bbox_inches='tight')
    plt.savefig(f"{fig_path}/{j}_{plot_type}.pdf", bbox_inches='tight')

    plt.show()
    plt.close()

In [None]:
import random
def sexy_sampler(selected_donors, diagnoses):
    """Balance the sexes within each diagnosis by subsampling the larger sex.

    For every diagnosis in ``diagnoses`` that has donors of exactly two sexes,
    all rows of the smaller sex are kept and as many donors of the other sex
    are drawn without replacement. Diagnoses that do not have exactly two sexes
    are skipped with a printed message.

    Parameters
    ----------
    selected_donors : pandas.DataFrame
        Must contain 'neuropathological_diagnosis', 'sex' and 'DonorID'.
    diagnoses : iterable
        Diagnoses to process, in output order.

    Returns
    -------
    pandas.DataFrame
        Rows of the balanced donors. Empty, with the columns of
        ``selected_donors``, when no diagnosis could be balanced.

    Notes
    -----
    ``random.seed(0)`` is called for every diagnosis. That makes the draw
    reproducible, but it also resets the global ``random`` state that later
    code may rely on.
    """
    tussen = selected_donors[['neuropathological_diagnosis','sex','DonorID']].drop_duplicates()

    # Donors per (diagnosis, sex), and for each diagnosis the row of its smaller sex.
    overview_sex1 = tussen.groupby(["neuropathological_diagnosis", "sex"]).size().reset_index(name="observations")
    overview_sex = overview_sex1.loc[overview_sex1.groupby('neuropathological_diagnosis').observations.idxmin()]

    appended_data = []
    for i in diagnoses:
        # Exact match on the diagnosis name. Building a regex from the name
        # breaks on names containing characters such as '+', '.' or '('.
        sexes_for_diagnosis = overview_sex1[overview_sex1['neuropathological_diagnosis'] == i]
        if len(sexes_for_diagnosis) != 2:
            print(f"After subsampling, {i} no longer has both male and female donors")
            continue

        df = selected_donors[selected_donors['neuropathological_diagnosis'] == i]
        sampler = overview_sex[overview_sex['neuropathological_diagnosis'] == i]

        # 'current' is the sex with the fewest donors, 'opposite' the other one.
        # Both come from the data, so labels other than 'F'/'M' also work.
        current = sampler['sex'].item()
        opposite = sexes_for_diagnosis.loc[sexes_for_diagnosis['sex'] != current, 'sex'].item()
        n = sampler['observations'].item()

        df_opposite = df[df['sex'] == opposite]
        opposite_list = list(df_opposite['DonorID'].unique())
        random.seed(0)
        sampled_list = random.sample(opposite_list, n)

        # Sampled donors of the larger sex first, then every row of the smaller sex.
        df_opposite = df_opposite[df_opposite['DonorID'].isin(sampled_list)]
        df_current = df[df['sex'] == current]
        appended_data.append(pd.concat([df_opposite, df_current], ignore_index=True))
        print('\n')

    if not appended_data:
        # pd.concat([]) raises; return an empty frame with the same columns instead.
        return selected_donors.iloc[0:0].copy()

    return pd.concat(appended_data)


In [None]:
def kaplanplot(prs_df,diag_list,j,output_path,figid,tdc_dic,w,h,xrange):
    """Plot one Kaplan-Meier curve per diagnosis and save it as PNG and PDF.

    Parameters
    ----------
    prs_df : pandas.DataFrame
        Wide table indexed by donor, one column of durations per diagnosis.
        It is not modified.
    diag_list : list-like
        Diagnosis columns of ``prs_df`` to include.
    j : str
        Attribute name; used in the title and the file names.
    output_path : str
        Root output folder; figures go to ``<output_path>/km/<figid>``.
    figid : str
        Name of the sub-folder for this figure.
    tdc_dic : dict
        Diagnosis -> line colour.
    w, h : number
        Figure width and height.
    xrange : number
        Upper limit of the time axis.
    """
    plt.figure(figsize = (w,h),dpi=500)
    duration_column = f"years_after_first_{j}"

    # Work on a copy: adding the 'DonorID' column to prs_df itself would leak
    # into the caller, which keeps using the table after this function returns.
    wide_df = prs_df.copy()
    wide_df['DonorID'] = wide_df.index

    # Long format: one row per (donor, diagnosis) that has a duration.
    long_df = pd.melt(wide_df, id_vars='DonorID', value_vars=diag_list,
                      var_name='diagnosis', value_name=duration_column)
    # Every remaining row is marked as an observed event (no censoring).
    long_df['event'] = True
    long_df = long_df.dropna()

    # NOTE(review): kaplan_meier_estimator is not imported in the visible part
    # of the notebook (presumably from scikit-survival) - confirm.
    for value in long_df["diagnosis"].unique():
        mask = long_df["diagnosis"] == value
        time_s, survival_prob = kaplan_meier_estimator(long_df["event"][mask],
                                                       long_df[duration_column][mask])
        plt.step(time_s, survival_prob, where="post",
                 label="%s (n = %d)" % (value, mask.sum()), color=tdc_dic[value])

    # Raw string: '\h' is not a valid escape sequence in a normal string literal.
    plt.ylabel(r"est. probability of survival $\hat{S}(t)$")
    plt.xlabel("time in years")
    plt.legend(bbox_to_anchor=(1.3, 1), loc='upper right', ncol=1)
    plt.xlim([0, xrange])
    plt.title("Survival after - %s" % j)

    fig_path2 = f"{output_path}/km/{figid}"
    if not os.path.exists(fig_path2):
        print('Creating output folder....')
        os.makedirs(fig_path2)
        print(fig_path2)
    plt.savefig(f"{fig_path2}/{j}_km.png",bbox_inches='tight')
    plt.savefig(f"{fig_path2}/{j}_km.pdf",bbox_inches='tight')
    plt.show()
    plt.close()

In [None]:
## TO DO; subsampling
def density_plots(prs_df,tdc_dic,j,output_path,figid,overview,w,h,xrange):
    """Plot a KDE of the observation ages per diagnosis, normalised to 100 donors.

    For every diagnosis the observed ages are resampled with replacement so
    that the number of drawn values equals the number of observations expected
    for 100 donors. Each curve is labelled with the diagnosis name at the
    median of its resampled ages.

    Parameters
    ----------
    prs_df : pandas.DataFrame
        One column of observation ages per diagnosis (NaN entries are ignored).
    tdc_dic : dict
        Diagnosis -> colour, used for both the curves and the text labels.
    j : str
        Attribute name; used as title and in the file names.
    output_path : str
        Root output folder; figures go to ``<output_path>/line/<figid>``.
    figid : str
        Name of the sub-folder for this figure.
    overview : pandas.DataFrame
        Must contain 'neuropathological_diagnosis' and 'donors'; donors are
        summed per diagnosis.
    w, h : number
        Figure width and height.
    xrange : number
        Upper limit of the age axis.

    Notes
    -----
    Sampling uses the global ``random`` state and sets no seed itself, so the
    result depends on whatever seeded that state before the call.
    """
    sns.set(style="white", font_scale=1.5)
    donor_subset = overview.groupby(['neuropathological_diagnosis'])['donors'].sum()
    ## Normalised: random sample per 100 donors
    sampling_nr = 100

    subsampled_dict = {}
    for i in prs_df.columns:
        # A plain list is used so that random.choices indexes by position; a
        # pandas Series would be indexed by label.
        observed_ages = prs_df[i].dropna().tolist()
        ## Number of observations divided by the number of donors with this
        ## diagnosis, times 100. E.g. 50 observations from 10 donors is 5 per
        ## donor, so 500 are expected for 100 donors.
        nr_observations_for_100_donors = round(len(observed_ages)/donor_subset[i] * sampling_nr)
        subsampled_dict[i] = random.choices(observed_ages, k=nr_observations_for_100_donors)
        display(f"for {i}, for {sampling_nr} donors we random select {nr_observations_for_100_donors} observations")

    subsampled_df = pd.DataFrame(dict([ (k,pd.Series(v)) for k,v in subsampled_dict.items() ]))
    p = sns.displot(subsampled_df,
                kind='kde',
                height = h,
                aspect=w/h,
                warn_singular=False,
                legend = False,
                palette=tdc_dic)

    ## Text labels: x is the median resampled age; y grows with the number of
    ## resampled observations, relative to the largest group.
    Y_axis_max = plt.axis()[3]
    inflation_factor = 1.1

    median_age_per_diagnosis = {diagnosis: numpy.median(values) for diagnosis, values in subsampled_dict.items()}
    nr_observations_dic = {diagnosis: len(values) for diagnosis, values in subsampled_dict.items()}
    max_observations = max(nr_observations_dic.values(), default=1)

    for diagnosis, median_age in median_age_per_diagnosis.items():
        if numpy.isnan(median_age):
            # Nothing was sampled for this diagnosis, so there is no curve to label.
            continue
        # The colour comes from the tdc_dic argument, not from a notebook-level
        # global. 'size' was dropped from the call because 'fontsize' is its
        # alias and was passed as well.
        plt.text(median_age, Y_axis_max/(max_observations * inflation_factor) *  nr_observations_dic[diagnosis],
                 diagnosis, horizontalalignment='center', color=tdc_dic[diagnosis], fontsize=20)

    p.fig.set_dpi(500)
    ## Plot settings
    plt.tick_params(bottom=True, left=True)
    plt.title(j, fontsize=20)
    plt.xlabel("Age")
    plt.xlim(0, xrange)
    fig_path2 = f"{output_path}/line/{figid}"
    if not os.path.exists(fig_path2):
        print('Creating output folder....')
        os.makedirs(fig_path2)
        print(fig_path2)

    plt.savefig(f"{fig_path2}/{j}_line.png",bbox_inches='tight')
    plt.savefig(f"{fig_path2}/{j}_line.pdf",bbox_inches='tight')
    plt.show()
    plt.close()

In [None]:
def tiny_print(sel_don2, j):
    """Summarise attribute ``j`` per diagnosis and sex.

    The values of column ``j`` are first summed per donor, then aggregated per
    (diagnosis, sex) into the total number of observations and the number of
    donors. Both aggregates are merged into one table with the columns
    ``observations_<j>`` and ``donors``.
    """
    group_keys = ["neuropathological_diagnosis", "sex"]
    observation_column = f"observations_{j}"

    # One row per donor with that donor's summed value of j.
    donor_totals = (
        sel_don2[[j, 'DonorID', 'sex', 'neuropathological_diagnosis', 'Age']]
        .groupby(['DonorID', 'sex', 'neuropathological_diagnosis'])[j]
        .sum()
        .reset_index(name=observation_column)
        .drop_duplicates()
    )

    grouped = donor_totals.groupby(group_keys)
    observations_per_group = grouped[observation_column].sum()
    donors_per_group = grouped.size().reset_index(name="donors")

    return pd.merge(observations_per_group, donors_per_group,
                    left_on=group_keys, right_on=group_keys)
    
        

In [None]:
def make_labeled_violinplots_horizontal(ta_dic,
                                        sel_don,
                                        tdc_dic,
                                        xrange,
                                        pt,
                                        figid,
                                        heatmap=False,
                                        printvalues=False,
                                        fixsex=False,
                                        remove_zero_donors=False,
                                        w=10,
                                        h=5):
    """Build and save the plots for every attribute in ``ta_dic``.

    Per attribute the donors are restricted to the listed diagnoses, optionally
    filtered and sex-balanced, reshaped into one column per diagnosis, tested
    pairwise and plotted.

    Parameters
    ----------
    ta_dic : dict
        Attribute name -> list of diagnoses to compare for that attribute.
    sel_don : pandas.DataFrame
        Long predictions table; needs 'DonorID', 'sex', 'Age',
        'neuropathological_diagnosis' and one column per attribute.
    tdc_dic : dict
        Diagnosis -> colour.
    xrange : number
        Upper x-limit passed on to the plotting functions.
    pt : {'temporal', 'counts', 'survival'}
        'temporal': every age at which the attribute was observed;
        'counts': summed attribute value per donor;
        'survival': absolute difference between the first age with the
        attribute and the donor's age in ``general_information_df``.
    figid : str
        Sub-folder name for the figures.
    heatmap : bool, default False
        Passed to ``obtain_P_values`` and ``plot_itself``. When False, the
        extra Kaplan-Meier (survival) or density (temporal) plot is drawn too.
    printvalues : bool, default False
        Display the per-diagnosis/sex overview after each filtering step.
    fixsex : bool, default False
        Balance the sexes per diagnosis with ``sexy_sampler``.
    remove_zero_donors : bool, default False
        Drop donors that never show the attribute.
    w, h : number
        Figure width and height.

    Raises
    ------
    ValueError
        If ``pt`` is not one of the three supported plot types.

    Notes
    -----
    The function reads the notebook-level names ``output_path`` and, for
    survival plots, ``general_information_df``.
    """
    plot_type = pt
    if plot_type not in ('temporal', 'counts', 'survival'):
        raise ValueError(f"unknown plot type {plot_type!r}; expected 'temporal', 'counts' or 'survival'")
    print(f"the plot type is {plot_type}")

    def _overview(stage, donors):
        # Summarise the current donor selection and display it on request.
        summary = tiny_print(donors, j)
        if printvalues:
            print(stage)
            display(summary)
        return summary

    ## loop over attributes of interest
    for j in list(ta_dic.keys()):
        print(j)
        diag_list1 = ta_dic[j]
        first_event = {}
        temporal = {}
        counts = {}

        sel_don2 = sel_don[sel_don['neuropathological_diagnosis'].isin(diag_list1)].copy()

        ## Not every donor has attribute j (e.g. dementia): find the donors
        ## with at least one observation of it.
        observation_column = f"observations_{j}"
        per_donor = (sel_don2[[j, 'DonorID', 'sex', 'neuropathological_diagnosis', 'Age']]
                     .groupby(['DonorID', 'sex', 'neuropathological_diagnosis'])[j]
                     .sum()
                     .reset_index(name=observation_column))
        donors_with_j = list(per_donor[per_donor[observation_column] > 0]['DonorID'])

        overview = _overview('before removing zero donors', sel_don2)

        if remove_zero_donors:
            sel_don2 = sel_don2[sel_don2['DonorID'].isin(donors_with_j)]
        overview = _overview('after removing zero donors', sel_don2)

        if fixsex:
            sel_don2 = sexy_sampler(sel_don2, diag_list1)
        overview = _overview('after subsampling sexes', sel_don2)

        diag_list = sel_don2['neuropathological_diagnosis'].unique()
        ## loop over diagnoses you want to plot
        for k in diag_list:
            df = sel_don2[sel_don2['neuropathological_diagnosis'] == k]

            if plot_type == 'counts':
                # Sum only the attribute column. A bare groupby().sum() also
                # "sums" (concatenates) the string columns on pandas >= 2.0,
                # which produces duplicate, non-numeric columns further down.
                counts[k] = df.groupby(['DonorID'])[[j]].sum()
                continue

            ## temporal / survival: loop over donors in that diagnosis group
            all_ages = []
            first_age_per_donor = {}
            for donor in list(df.DonorID.unique()):
                donor_rows = df[df['DonorID'] == donor]
                ages_with_attribute = list(donor_rows.Age[donor_rows[j] == 1])
                if ages_with_attribute:
                    all_ages.extend(ages_with_attribute)
                    # The first listed age is taken as the first event.
                    # NOTE(review): this assumes rows are ordered by age - confirm.
                    first_age_per_donor[donor] = ages_with_attribute[0]

            if plot_type == 'temporal':
                temporal[k] = all_ages
            first_event[k] = first_age_per_donor

        if plot_type == 'temporal':
            # Columns differ in length; pd.Series pads the shorter ones with NaN.
            prs_df = pd.DataFrame(dict([(k, pd.Series(v)) for k, v in temporal.items()]))
        elif plot_type == 'counts':
            prs_df = pd.concat(counts, axis=1)
            prs_df.columns = prs_df.columns.droplevel(1)
        else:  # survival
            prs_df = pd.DataFrame.from_dict(first_event)
            prs_df = prs_df.dropna(subset=diag_list, how='all')
            prs_df['DonorID'] = prs_df.index
            # NOTE(review): the age comes from the notebook-level
            # general_information_df ('Age' column), not from sel_don - confirm
            # that this is intended.
            prs_df['age_at_death'] = prs_df['DonorID'].map(general_information_df.set_index('DonorID')['Age'])
            prs_df[diag_list] = prs_df[diag_list].sub(prs_df['age_at_death'], axis=0).abs()
            prs_df = prs_df[diag_list]

        # Order the diagnoses by their median value.
        prs_df = prs_df.reindex(prs_df.median().sort_values().index, axis=1)
        mean_order = list(prs_df.columns)
        sns.set(rc={'figure.figsize':(8, len(diag_list)/1.2)})
        sns.set_style("ticks")

        if plot_type == 'survival' and not heatmap:
            kaplanplot(prs_df, diag_list, j, output_path, figid, tdc_dic, w, h, xrange)

        if plot_type == 'temporal' and not heatmap:
            density_plots(prs_df, tdc_dic, j, output_path, figid, overview, w, h, xrange)

        pairs, FDR_pvalues = obtain_P_values(mean_order, j, prs_df, heatmap=heatmap)

        plot_itself(prs_df, mean_order, pairs, FDR_pvalues, output_path, tdc_dic, xrange,
                    figid, plot_type, j, w, h, heatmap=heatmap)
        
        

    


### plotting

In [None]:
## Select the prediction rows of the donors of interest (known ages only)
table_of_choice = 'table5_p'
selected_donors, diagnoses = table_selector(table_of_choice, predictions_df_temporal)
print(f"After selecting for {selected_donors['neuropathological_diagnosis'].unique()}, we have {selected_donors['DonorID'].nunique()}  donors")

## Give every diagnosis its own colour, taken in order from the tab20b palette
color_palette = sns.color_palette("tab20b")[0:21]
table_diagnosis_colors_dic = {
    diagnosis_name: color_palette[position]
    for position, diagnosis_name in enumerate(diagnoses)
}

## The same diagnoses are compared for each of the three attributes
sup4b_diagnoses = ['FTD', 'AD_CA', 'AD_DLB', 'AD_VE', 'AD', 'DEM_SICC_AGD', 'DLB_SICC', 'DLB',
                   'DEM_SICC', 'DEM_VE', 'PD_AD', 'PDD', 'PD', 'VD']
trait_attribute_dictionary = {
    'Dementia': list(sup4b_diagnoses),
    'Disorientation': list(sup4b_diagnoses),
    'Memory_impairment': list(sup4b_diagnoses),
}

## Survival plots for sup4b.
## Temporal variant (currently disabled): the same call with xrange=110 and pt='temporal'.
make_labeled_violinplots_horizontal(
    ta_dic=trait_attribute_dictionary,
    sel_don=selected_donors,
    tdc_dic=table_diagnosis_colors_dic,
    xrange=50,
    pt='survival',
    figid='sup4b',
    heatmap=True,
    # printvalues=True,
    fixsex=True,
    remove_zero_donors=False,
    w=7,
    h=5,
)

## Same table, but including the predictions with unknown age
selected_donors, diagnoses = table_selector(table_of_choice, predictions_df)

sup4b_count_diagnoses = ['FTD', 'AD_CA', 'AD_DLB', 'AD_VE', 'AD', 'DEM_SICC_AGD', 'DLB_SICC', 'DLB',
                         'DEM_SICC', 'DEM_VE', 'PD_AD', 'PDD', 'PD', 'VD']
trait_attribute_dictionary = {
    'Dementia': list(sup4b_count_diagnoses),
    'Disorientation': list(sup4b_count_diagnoses),
    'Memory_impairment': list(sup4b_count_diagnoses),
}

## Counts variant for sup4b (currently disabled):
# make_labeled_violinplots_horizontal(trait_attribute_dictionary,
#                                     selected_donors,
#                                     table_diagnosis_colors_dic,
#                                     15,
#                                     'counts',
#                                     'sup4b',
#                                     heatmap=True,
#                                     # printvalues=True,
#                                     fixsex=True,
#                                     remove_zero_donors=False,
#                                     w=7,
#                                     h=5)


In [None]:
table_of_choice = 'table3_with_con_p'
selected_donors, diagnoses = table_selector(table_of_choice, predictions_df)

## Diagnoses compared per attribute ('FTD-FUS' is deliberately left out).
trait_attribute_dictionary = {
    'Dementia': ['CON', 'CBD', 'FTD-TAU', 'FTD-TDP-A', 'FTD-TDP-B', 'FTD-TDP-C',
                 'FTD_MND', 'FTD-TDP', 'PID', 'PSP'],
    'Memory_impairment': ['CON', 'FTD_MND', 'PID', 'PSP', 'CBD', 'FTD-TDP-A',
                          'FTD-TAU', 'FTD-TDP', 'FTD-TDP-B', 'FTD-TDP-C'],
    'Compulsive_behavior': ['CON', 'FTD_MND', 'PID', 'PSP', 'CBD', 'FTD-TDP-A',
                            'FTD-TAU', 'FTD-TDP', 'FTD-TDP-B', 'FTD-TDP-C'],
}

## Give every diagnosis its own colour, taken in order from the tab20b palette
color_palette = sns.color_palette("tab20b")[0:21]
table_diagnosis_colors_dic = {
    diagnosis_name: color_palette[position]
    for position, diagnosis_name in enumerate(diagnoses)
}

## Counts, main figure 4b
make_labeled_violinplots_horizontal(
    ta_dic=trait_attribute_dictionary,
    sel_don=selected_donors,
    tdc_dic=table_diagnosis_colors_dic,
    xrange=14,
    pt='counts',
    figid='main4b',
    heatmap=True,
    printvalues=True,
    fixsex=True,
    remove_zero_donors=False,
    w=10,
    h=5,
)

# --- main4b, temporal / survival input ---
# Both plot calls below are disabled, but the selection still runs and
# overwrites table_of_choice, selected_donors, diagnoses and the trait mapping.
table_of_choice = 'table3_p'
selected_donors, diagnoses = table_selector(table_of_choice, predictions_df_temporal)

# Same diagnosis panel for every trait (no 'CON' here); each trait gets its
# own list object. 'FTD-FUS' is excluded everywhere and 'ALS' is not included.
trait_attribute_dictionary = {
    trait: ['CBD', 'FTD-TAU', 'FTD-TDP-A', 'FTD-TDP-B', 'FTD-TDP-C',
            'FTD_MND', 'FTD-TDP', 'PID', 'PSP']
    for trait in ('Dementia', 'Memory_impairment', 'Compulsive_behavior')
}

# Alternative panels that do include 'FTD-FUS' / 'ALS':
# trait_attribute_dictionary['Dementia'] =   ['CBD','FTD-FUS',  'FTD-TAU', 'FTD-TDP-A', 'FTD-TDP-B', 'FTD-TDP-C','FTD_MND', 'FTD-TDP', 'PID', 'PSP']
# trait_attribute_dictionary['Memory_impairment'] = [ 'CBD','FTD-FUS','ALS',  'FTD-TAU', 'FTD-TDP-A', 'FTD-TDP-B', 'FTD-TDP-C','FTD_MND', 'FTD-TDP', 'PID', 'PSP']
# trait_attribute_dictionary['Compulsive_behavior'] = [ 'CBD','FTD-FUS','ALS', 'FTD-TAU', 'FTD-TDP-A', 'FTD-TDP-B', 'FTD-TDP-C','FTD_MND', 'FTD-TDP', 'PID', 'PSP']

# Disabled: temporal plot for 'main4b'.
# make_labeled_violinplots_horizontal(
#     trait_attribute_dictionary, selected_donors, table_diagnosis_colors_dic,
#     110, 'temporal', 'main4b',
#     heatmap=True,
#     # printvalues=True,
#     fixsex=True, remove_zero_donors=False, w=10, h=5,
# )

# Disabled: survival plot for 'main4b'.
# make_labeled_violinplots_horizontal(
#     trait_attribute_dictionary, selected_donors, table_diagnosis_colors_dic,
#     50, 'survival', 'main4b',
#     heatmap=True,
#     # printvalues=True,
#     fixsex=True, remove_zero_donors=False, w=6, h=5,
# )


In [None]:

##ASSIGN COLORS TO THE DIAGNOSIS
table_diagnosis_colors_dic = {} 
table_diagnosis_colors_dic['CON'] = (0.5490196078431373, 0.42745098039215684, 0.19215686274509805)
table_diagnosis_colors_dic['AD']  = (0.38823529411764707, 0.4745098039215686, 0.2235294117647059)
table_diagnosis_colors_dic['VD']  = (0.9058823529411765, 0.796078431372549, 0.5803921568627451)
table_diagnosis_colors_dic['FTD'] = (0.7098039215686275, 0.8117647058823529, 0.4196078431372549)
table_diagnosis_colors_dic['MND'] = (0.807843137254902, 0.8588235294117647, 0.611764705882353)
table_diagnosis_colors_dic['PD']  = (0.2235294117647059, 0.23137254901960785, 0.4745098039215686)
table_diagnosis_colors_dic['PDD'] = (0.4196078431372549, 0.43137254901960786, 0.8117647058823529)
table_diagnosis_colors_dic['DLB']  = (0.611764705882353, 0.6196078431372549, 0.8705882352941177)
table_diagnosis_colors_dic['PSP']  = (0.5176470588235295, 0.23529411764705882, 0.2235294117647059)
table_diagnosis_colors_dic['ATAXIA']  = (0.6784313725490196, 0.28627450980392155, 0.2901960784313726)
table_diagnosis_colors_dic['MS']  = (0.8392156862745098, 0.3803921568627451, 0.4196078431372549)
table_diagnosis_colors_dic['MSA']  = (0.9058823529411765, 0.5882352941176471, 0.611764705882353)
table_diagnosis_colors_dic['MD']  = (0.4823529411764706, 0.2549019607843137, 0.45098039215686275)
table_diagnosis_colors_dic['BP']  = (0.6470588235294118, 0.3176470588235294, 0.5803921568627451)
table_diagnosis_colors_dic['SCZ'] = (0.8705882352941177, 0.6196078431372549, 0.8392156862745098)

# Table used by every selection in the remainder of this cell.
table_of_choice = 'table1_p'

# --- sup4b, temporal ---
selected_donors, diagnoses = table_selector(table_of_choice, predictions_df_temporal)

# Diagnoses plotted per trait, in the order given.
# Alternatives that were tried and are currently switched off:
#   'Muscular_Weakness': ['MS', 'CON']
#   'Fatigue': ['MS', 'CON']
trait_attribute_dictionary = {
    'Mobility_problems': ['MS', 'MND', 'VD', 'PDD', 'PD', 'DLB', 'MSA', 'PSP'],
    'Muscular_Weakness': ['MS', 'MND', 'VD', 'PDD', 'PD', 'DLB', 'MSA', 'PSP'],
    'Positive_sensory_symptoms': ['MS', 'CON'],
}

# NOTE(review): the meaning of the positional value 110 is defined by
# make_labeled_violinplots_horizontal, which is not visible in this cell.
make_labeled_violinplots_horizontal(
    trait_attribute_dictionary,
    selected_donors,
    table_diagnosis_colors_dic,
    110,
    'temporal',
    'sup4b',
    heatmap=False,
    # printvalues=True,
    fixsex=True,
    remove_zero_donors=False,
    w=7,
    h=5,
)

# Disabled: survival plot for 'sup4b'.
# make_labeled_violinplots_horizontal(
#     trait_attribute_dictionary, selected_donors, table_diagnosis_colors_dic,
#     80, 'survival', 'sup4b',
#     heatmap=True, printvalues=False,
#     fixsex=True, remove_zero_donors=False, w=7, h=5,
# )

# --- sup4b, counts ---
# The plot call is disabled; the selection still runs and replaces
# selected_donors, diagnoses and the trait mapping.
selected_donors, diagnoses = table_selector(table_of_choice, predictions_df)

# Same panels as the temporal variant, with 'CON' added in front.
trait_attribute_dictionary = {
    'Mobility_problems': ['CON', 'MS', 'MND', 'VD', 'PDD', 'PD', 'DLB', 'MSA', 'PSP'],
    'Muscular_Weakness': ['CON', 'MS', 'MND', 'VD', 'PDD', 'PD', 'DLB', 'MSA', 'PSP'],
}

# Disabled: counts plot for 'sup4b'.
# make_labeled_violinplots_horizontal(
#     trait_attribute_dictionary, selected_donors, table_diagnosis_colors_dic,
#     20, 'counts', 'sup4b',
#     heatmap=True,
#     # printvalues=False,
#     fixsex=True, remove_zero_donors=False, w=7, h=5,
# )


## --- figure 3, temporal ---
# Both plot calls below are disabled; the selection and the summary print
# still run.
selected_donors, diagnoses = table_selector(table_of_choice, predictions_df_temporal)

# Report which diagnoses were selected and how many distinct donors remain.
print(f"After selecting for {selected_donors['neuropathological_diagnosis'].unique()}, we have {selected_donors['DonorID'].nunique()}  donors")

# Debug aid (disabled): keep only the last 20 donor IDs.
# selected_donors = selected_donors[selected_donors['DonorID'].isin(selected_donors['DonorID'].unique()[-20:])]

# Diagnoses plotted per trait, without 'CON'.
# Variant with controls (disabled):
#   'Dementia': ['CON','FTD','PDD','DLB','AD','VD','PD']
#   'Bradykinesia': ['CON','PD','PDD','DLB','PSP','MSA']
trait_attribute_dictionary = {
    'Dementia': ['FTD', 'PDD', 'DLB', 'AD', 'VD', 'PD'],
    'Bradykinesia': ['PD', 'PDD', 'DLB', 'PSP', 'MSA'],
}

# Disabled: temporal plot for 'main3bc'.
# make_labeled_violinplots_horizontal(
#     trait_attribute_dictionary, selected_donors, table_diagnosis_colors_dic,
#     110, 'temporal', 'main3bc',
#     heatmap=False, printvalues=False,
#     remove_zero_donors=False, fixsex=True, w=14, h=5,
# )

# Disabled: survival plot for 'main3bc'.
# make_labeled_violinplots_horizontal(
#     trait_attribute_dictionary, selected_donors, table_diagnosis_colors_dic,
#     40, 'survival', 'main3bc',
#     heatmap=False, printvalues=False,
#     remove_zero_donors=False, fixsex=True, w=14, h=5,
# )
## --- figure 3, counts ---
# The plot call is disabled; the selection still runs and leaves
# selected_donors, diagnoses and the trait mapping in this state.
selected_donors, diagnoses = table_selector(table_of_choice, predictions_df)

# Diagnoses plotted per trait, with 'CON' first.
# Variant without controls (disabled):
#   'Dementia': ['FTD','PDD','DLB','AD','VD','PD']
#   'Bradykinesia': ['PD','PDD','DLB','PSP','MSA']
trait_attribute_dictionary = {
    'Dementia': ['CON', 'FTD', 'PDD', 'DLB', 'AD', 'VD', 'PD'],
    'Bradykinesia': ['CON', 'PD', 'PDD', 'DLB', 'PSP', 'MSA'],
}

# Disabled: counts plot for 'main3bc'.
# make_labeled_violinplots_horizontal(
#     trait_attribute_dictionary, selected_donors, table_diagnosis_colors_dic,
#     20, 'counts', 'main3bc',
#     heatmap=False, printvalues=False,
#     fixsex=True, remove_zero_donors=False, w=5, h=8,
# )


In [None]:
# --- main4b, counts: FTD-spectrum diagnoses plus controls (incl. FTD-FUS / ALS) ---
table_of_choice = 'table3_with_con_p'
selected_donors, diagnoses = table_selector(table_of_choice, predictions_df)

# Diagnoses plotted per trait, in the order given.
# 'ALS' appears only in the memory and compulsive-behaviour panels.
trait_attribute_dictionary = {
    'Dementia': ['CON', 'CBD', 'FTD-FUS', 'FTD-TAU', 'FTD-TDP-A', 'FTD-TDP-B',
                 'FTD-TDP-C', 'FTD_MND', 'FTD-TDP', 'PID', 'PSP'],
    'Memory_impairment': ['CON', 'ALS', 'FTD_MND', 'FTD-FUS', 'PID', 'PSP', 'CBD',
                          'FTD-TDP-A', 'FTD-TAU', 'FTD-TDP', 'FTD-TDP-B', 'FTD-TDP-C'],
    'Compulsive_behavior': ['CON', 'ALS', 'FTD_MND', 'FTD-FUS', 'PID', 'PSP', 'CBD',
                            'FTD-TDP-A', 'FTD-TAU', 'FTD-TDP', 'FTD-TDP-B', 'FTD-TDP-C'],
}

# tab20b provides 20 colours; take them all as a plain list.
color_palette = list(sns.color_palette("tab20b"))

# Give every diagnosis a palette colour in the order table_selector returned
# them (deterministic, not random). The index wraps around so that a table
# with more diagnoses than palette entries reuses colours instead of raising
# IndexError.
table_diagnosis_colors_dic = {
    diagnosis: color_palette[position % len(color_palette)]
    for position, diagnosis in enumerate(diagnoses)
}

# Disabled: counts plot for 'main4b'.
# make_labeled_violinplots_horizontal(
#     trait_attribute_dictionary, selected_donors, table_diagnosis_colors_dic,
#     15, 'counts', 'main4b',
# )

# --- main4b, temporal ---
table_of_choice = 'table3_p'
selected_donors, diagnoses = table_selector(table_of_choice, predictions_df_temporal)

# Only the dementia panel is plotted here. The other two panels are disabled:
#   'Memory_impairment': ['ALS', 'CBD', 'FTD-FUS', 'FTD-TAU', 'FTD-TDP-A', 'FTD-TDP-B', 'FTD-TDP-C','FTD_MND', 'FTD-TDP', 'PID', 'PSP']
#   'Compulsive_behavior': ['ALS', 'CBD', 'FTD-FUS', 'FTD-TAU', 'FTD-TDP-A', 'FTD-TDP-B', 'FTD-TDP-C','FTD_MND', 'FTD-TDP', 'PID', 'PSP']
trait_attribute_dictionary = {
    'Dementia': ['CBD', 'FTD-FUS', 'FTD-TAU', 'FTD-TDP-A', 'FTD-TDP-B',
                 'FTD-TDP-C', 'FTD_MND', 'FTD-TDP', 'PID', 'PSP'],
}

# NOTE(review): the meaning of the positional value 110 is defined by
# make_labeled_violinplots_horizontal, which is not visible in this cell.
make_labeled_violinplots_horizontal(
    trait_attribute_dictionary,
    selected_donors,
    table_diagnosis_colors_dic,
    110,
    'temporal',
    'main4b',
    heatmap=False,
)

# Disabled: survival plot for 'main4b'.
# make_labeled_violinplots_horizontal(
#     trait_attribute_dictionary, selected_donors, table_diagnosis_colors_dic,
#     50, 'survival', 'main4b',
# )
