In [None]:
import pandas as pd
import subprocess
import sys
import numpy as np
import os
import shutil
import joblib
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA, KernelPCA
from sklearn.cluster import MeanShift, DBSCAN
from sklearn.metrics import roc_auc_score, r2_score, mean_squared_error, accuracy_score
from sklearn.linear_model import LogisticRegression, LinearRegression 
from sklearn.ensemble import RandomForestRegressor
from umap import UMAP
import plotly.express as px
import plotly.io as pio
import matplotlib.pyplot as plt
import seaborn as sns
import statsmodels.api as sm
from statsmodels.formula.api import logit, ols, glm
from statsmodels.stats.proportion import proportions_ztest
from scipy.stats import chi2_contingency, chisquare, ttest_1samp
import umap
import numba
import sklearn
import shap
import math

In [None]:
# Print the versions of the key numerical libraries and the interpreter in use,
# so the environment that produced the results is recorded with the notebook.
print(umap.__version__)
print(np.__version__)
print(sklearn.__version__)
print(numba.__version__)
print(sys.executable)
print(sys.version)
print(sys.version_info)

In [None]:
def shell_do(command, log=False, return_log=False):
    '''
    Run a command line (without a shell) and optionally print/return its stdout.

    Parameters:
    command (string): full command line. It is tokenised with shell-style
        quoting rules, so a quoted argument containing spaces stays one argument.
    log (bool): when True, print the captured stdout.
    return_log (bool): when True, return the captured stdout as a string.

    Returns:
    The decoded stdout when return_log is True, otherwise None.
    '''
    import shlex  # function-scope import: not part of the notebook's import cell

    # Same progress line as before (whitespace-normalised command) on stderr
    print(f'Executing: {(" ").join(command.split())}', file=sys.stderr)

    # shlex.split honours quotes; plain str.split would tear "a b" into two tokens
    res = subprocess.run(shlex.split(command), stdout=subprocess.PIPE)

    # A failing command used to go unnoticed; report it without raising so
    # existing best-effort callers keep working
    if res.returncode != 0:
        print(f'Command exited with status {res.returncode}', file=sys.stderr)

    output = res.stdout.decode('utf-8')

    if log:
        print(output)
    if return_log:
        return output

In [None]:
def plot_3d(labeled_df, color, fname, symbol=None, x='PC1', y='PC2', z='PC3', title=None, x_range=None, y_range=None, z_range=None):
    '''
    Draw an interactive 3-D scatterplot, display it, and save it as HTML.

    Parameters:
    labeled_df (Pandas dataframe): data to plot, one row per point
    color (string): name of the column used to colour the points
    fname (string): path the interactive HTML figure is written to
    symbol (string, optional): name of the column used to pick marker symbols. default: None
    x (string): column name of x-dimension
    y (string): column name of y-dimension
    z (string): column name of z-dimension
    title (string, optional): title of output scatterplot
    x_range (list of floats [min, max], optional): range for x-axis
    y_range (list of floats [min, max], optional): range for y-axis
    z_range (list of floats [min, max], optional): range for z-axis

    Returns:
    None. The figure is shown and written to fname.
    '''
    scatter_options = {
        'x': x,
        'y': y,
        'z': z,
        'color': color,
        'symbol': symbol,
        'title': title,
        'color_discrete_sequence': px.colors.qualitative.Bold,
        'range_x': x_range,
        'range_y': y_range,
        'range_z': z_range,
    }
    figure = px.scatter_3d(labeled_df, **scatter_options)

    # small markers keep dense point clouds readable
    figure.update_traces(marker={'size': 3})

    figure.show()
    figure.write_html(fname)

In [None]:
def plot(transform, labels, name, n_comps, fname):
    '''
    Plot a low-dimensional embedding coloured by label.

    Parameters:
    transform (array-like): embedded coordinates, one row per sample
    labels (array-like): label for each row of transform
    name (string): title used for the 3-D plot
    n_comps (int): number of embedding columns
    fname (string): HTML output path (used for the 3-D plot only)

    Returns:
    None. Fewer than three components gives a 2-D seaborn scatter;
    otherwise the first three components go to plot_3d.
    '''
    # wrap the embedding in a dataframe with COMP1..COMPn column names
    embedding = pd.DataFrame(transform)
    embedding.columns = [f'COMP{idx}' for idx in range(1, n_comps + 1)]

    # attach the labels used for colouring
    embedding.loc[:, 'label'] = labels

    # three or more components: interactive 3-D scatter
    if n_comps >= 3:
        plot_3d(labeled_df=embedding, color='label', fname=fname, title=name, x='COMP1', y='COMP2', z='COMP3')
        return

    # otherwise a static 2-D scatter without a regression line
    sns.lmplot(x='COMP1', y='COMP2', hue='label', data=embedding, fit_reg=False)
    plt.show()

In [None]:
def plot_heatmap(data, title, x_lab, y_lab):
    '''
    Show an annotated heatmap of a table of regression statistics.

    Parameters:
    data (2-D table): values to draw, one cell per entry
    title (string): plot title
    x_lab (string): x-axis label
    y_lab (string): y-axis label
    '''
    heat_ax = sns.heatmap(data, annot=True)
    heat_ax.set_title(title)
    heat_ax.set_xlabel(x_lab)
    heat_ax.set_ylabel(y_lab)
    plt.show()

In [None]:
def set_up(cases, disease=None):
    '''
    Split a cases table into features and the PHENO target.

    Parameters:
    cases (Pandas dataframe): table containing a 'PHENO' column
    disease (optional): when truthy, keep only rows whose PHENO equals it

    Returns:
    (X, y): X is cases without the PHENO column, y is the PHENO column.
    Shapes and the PHENO value counts are printed along the way.
    '''
    # optionally narrow down to one disease
    selected = cases[cases['PHENO'] == disease] if disease else cases

    # features: everything except the phenotype
    features = selected.drop(columns=['PHENO'])
    print(features.shape)

    # target: the phenotype itself
    target = selected['PHENO']
    print(target.shape)
    print(target.value_counts())

    return features, target

In [None]:
def umap_transform(X_train, X_test, y_cases, a, b, seed=None, cohorts=None, plot_transform=True):
    '''
    Fit a 3-component UMAP on the training data and project the test data.

    Parameters:
    X_train: training features the embedding is fitted on
    X_test: test features projected with the fitted embedding
    y_cases: labels for the stacked train+test rows, used to colour the plot
    a, b: UMAP curve parameters passed straight to UMAP
    seed (optional): random_state for UMAP
    cohorts (optional): second set of labels; when given, a second plot is made
    plot_transform (bool): whether to plot the stacked embedding

    Returns:
    (X_train_umap, X_test_umap): the embedded train and test arrays
    '''
    wd = 'insert_path'

    # fit on train only, then project test with the same model
    reducer = UMAP(n_components=3, a=a, b=b, random_state=seed)
    X_train_umap = reducer.fit_transform(X_train)
    X_test_umap = reducer.transform(X_test)

    if plot_transform:
        # train rows first, then test rows
        stacked = np.concatenate((X_train_umap, X_test_umap), axis=0)
        n_dims = stacked.shape[1]
        plot(stacked, y_cases, 'UMAP', n_dims, f'{wd}/analysis/figures/umap.html')
        if cohorts is not None:
            plot(stacked, cohorts, 'UMAP', n_dims, f'{wd}/analysis/figures/umap_cohorts.html')

    return X_train_umap, X_test_umap

In [None]:
def get_permuted_clusters(cluster_membership_gen_path, y_ids, disease, n_splits=15, iterations=100):
    '''
    Assign each sample a consensus cluster from repeated clustering runs.

    For every train/test split a whitespace-separated file is read whose columns
    are 'ID', 'pheno' and one cluster label per clustering iteration. Within a
    split a sample is "always cluster 0" if all of its labels are 0 and
    "never cluster 0" if none are. Across splits the samples are then labelled:
        0 - always cluster 0 in at least `threshold` splits
        2 - never cluster 0 in at least `threshold` splits
        1 - everything else
    where threshold is n_splits, or n_splits - 3 when `disease` is truthy.

    Parameters:
    cluster_membership_gen_path (string): path template; '*' is replaced by the split number (1..n_splits)
    y_ids (iterable): sample IDs to classify
    disease: when truthy, the across-split threshold is relaxed by 3
    n_splits (int): number of split files to read. default: 15
    iterations (int): number of clustering iterations (label columns) per file. default: 100

    Returns:
    Pandas dataframe with columns 'ID' and 'cluster'; cluster-0 rows first, then cluster 2, then cluster 1.
    '''
    # per-split sets of IDs (sets make the membership checks below O(1))
    always_c0_by_split = []
    never_c0_by_split = []

    for split in range(1, n_splits + 1):
        split_path = cluster_membership_gen_path.replace('*', str(split))
        membership_df = pd.read_csv(split_path, sep=r'\s+')

        # fraction of iterations each sample spent in each cluster
        labels = membership_df.drop(columns=['ID', 'pheno'])
        label_counts = labels.apply(pd.Series.value_counts, axis=1).fillna(0)
        label_fraction = label_counts / iterations

        # cluster 0 may never occur in a split; that means nobody was ever in it
        if 0 in label_fraction.columns:
            c0_fraction = label_fraction[0]
        else:
            c0_fraction = pd.Series(0.0, index=label_fraction.index)

        always_c0_by_split.append(set(membership_df.loc[c0_fraction == 1.0, 'ID']))
        never_c0_by_split.append(set(membership_df.loc[c0_fraction == 0.0, 'ID']))

    # number of splits that must agree (relaxed by 3 for single-disease runs)
    threshold = n_splits - 3 if disease else n_splits

    all_c0_overlap = []
    not_c0_overlap = []
    sometimes_c0_overlap = []

    for iid in y_ids:
        n_always = sum(iid in ids for ids in always_c0_by_split)
        n_never = sum(iid in ids for ids in never_c0_by_split)

        is_always = n_always >= threshold
        is_never = n_never >= threshold

        # always cluster 0 (c0)
        if is_always:
            all_c0_overlap.append(iid)
        # never cluster 0 (c2)
        if is_never:
            not_c0_overlap.append(iid)
        # the rest are only sometimes in cluster 0 (c1)
        if not (is_always or is_never):
            sometimes_c0_overlap.append(iid)

    print(f'C0:{len(all_c0_overlap)}')
    print(f'C1:{len(sometimes_c0_overlap)}')
    print(f'C2:{len(not_c0_overlap)}')

    # one frame, built in a single step: c0 rows, then c2, then c1
    full = pd.DataFrame({
        'ID': all_c0_overlap + not_c0_overlap + sometimes_c0_overlap,
        'cluster': [0] * len(all_c0_overlap) + [2] * len(not_c0_overlap) + [1] * len(sometimes_c0_overlap),
    })
    print(full.head())
    print(full.shape)

    return full

In [None]:
def expand_cluster_data(cluster_data):
    '''
    Prepare clustered data for regression.

    Drops rows whose 'cluster' is -1, appends one-hot columns for 'pheno'
    (named after the phenotype values) and for 'cluster' (prefixed 'cluster_'),
    and renames integer column labels k to 'COMP<k+1>'.

    Parameters:
    cluster_data (Pandas dataframe): must contain 'pheno' and 'cluster' columns

    Returns:
    The expanded dataframe (its shape is printed).
    '''
    # -1 marks samples that were not assigned to any cluster
    clustered = cluster_data[cluster_data['cluster'] != -1]

    # indicator columns for phenotype and cluster membership
    pheno_dummies = pd.get_dummies(clustered['pheno'])
    cluster_dummies = pd.get_dummies(clustered['cluster'], prefix='cluster')
    expanded = pd.concat([clustered, pheno_dummies, cluster_dummies], axis=1)

    # integer-labelled columns are components: 0 -> COMP1, 1 -> COMP2, ...
    comp_names = {
        label: 'COMP' + str(label + 1)
        for label in expanded.columns
        if type(label) is int
    }
    expanded = expanded.rename(columns=comp_names)
    print(expanded.shape)

    return expanded

In [None]:
def disease_regression(train_data, test_data, data_type, standardize=False, simul=False):
    '''
    Logistic regressions of each disease indicator on clusters or components.

    For each disease in ['ad','pd','als','lbd','ftd'] a statsmodels logit is
    fitted on the training data and scored (ROC AUC) on the test data.

    Parameters:
    train_data (Pandas dataframe): training rows; needs one indicator column per disease
        plus 'cluster_*' and/or 'COMP*' columns
    test_data (Pandas dataframe): test rows with the same columns
    data_type (string): 'clusters' uses the 'cluster_*' columns, anything else the 'COMP*' columns
    standardize (bool): for components only, scale them with a StandardScaler
        fitted on the training data and applied unchanged to the test data
    simul (bool): True fits one model per disease with all predictors at once
        (for clusters the second-to-last cluster column is left out of the formula);
        False fits one model per disease/predictor pair

    Returns:
    Pandas dataframe of p-values indexed by disease. OR, beta, SE and p-value
    tables and the mean test AUC are printed.
    '''
    from io import StringIO  # function-scope import: not part of the notebook's import cell

    # work on copies so the caller's frames are untouched
    train_data_copy = train_data.copy(deep=True)
    test_data_copy = test_data.copy(deep=True)

    diseases = ['ad','pd','als','lbd','ftd']

    cluster_cols = [column for column in train_data_copy.columns if 'cluster_' in column]
    comp_cols = [column for column in train_data_copy.columns if 'COMP' in column]

    if data_type == 'clusters':
        cols = cluster_cols
    else:
        cols = comp_cols
        if standardize:
            scaler = StandardScaler()
            train_data_copy[comp_cols] = scaler.fit_transform(train_data_copy[comp_cols])
            # reuse the training mean/scale; refitting on the test set would put
            # train and test on different scales
            test_data_copy[comp_cols] = scaler.transform(test_data_copy[comp_cols])

    def _coef_table(fitted):
        # parse the coefficient table of the model summary into a dataframe;
        # wrapped in StringIO because passing literal HTML to read_html is deprecated
        html = fitted.summary().tables[1].as_html()
        return pd.read_html(StringIO(html), header=0, index_col=0)[0]

    coef_data = []
    or_data = []
    p_data = []
    se_data = []
    auc_data = []

    for disease in diseases:
        # Disease n ~ Comp/Cluster 1 + ... + Comp/Cluster n regression
        if simul:
            if data_type == 'clusters' and len(cols) > 1:
                # the second-to-last cluster column is omitted from the formula
                predictors = [col for col in cols if col != cols[-2]]
            else:
                predictors = list(cols)

            formula = f'{disease} ~ {" + ".join(str(col) for col in predictors)}'
            model = logit(formula=formula, data=train_data_copy).fit(disp=0)
            pred = model.predict(test_data_copy[cols])
            auc_data.append(roc_auc_score(test_data_copy[disease], pred))

            results = _coef_table(model)
            coef_data.append(results['coef'])
            or_data.append(np.exp(results['coef']))
            p_data.append(results['P>|z|'])
            se_data.append(results['std err'])

        # Disease n ~ Comp/Cluster n regression
        else:
            coef_row_data = []
            or_row_data = []
            p_row_data = []
            se_row_data = []

            for col in cols:
                formula = f'{disease} ~ {col}'
                model = logit(formula=formula, data=train_data_copy).fit(disp=0)
                pred = model.predict(test_data_copy[col])
                auc_data.append(roc_auc_score(test_data_copy[disease], pred))

                # row 0 of the table is the intercept, row 1 the predictor
                predictor_row = _coef_table(model).iloc[1]
                coef_row_data.append(predictor_row['coef'])
                or_row_data.append(np.exp(predictor_row['coef']))
                p_row_data.append(model.pvalues.loc[col])
                se_row_data.append(predictor_row['std err'])

            coef_data.append(coef_row_data)
            or_data.append(or_row_data)
            p_data.append(p_row_data)
            se_data.append(se_row_data)

    # one row per disease
    coef_data = pd.DataFrame(coef_data, index=diseases)
    or_data = pd.DataFrame(or_data, index=diseases)
    p_data = pd.DataFrame(p_data, index=diseases)
    se_data = pd.DataFrame(se_data, index=diseases)

    print('OR')
    print(or_data)
    print('BETA')
    print(coef_data)
    print('SE')
    print(se_data)
    print('P')
    print(p_data)

    # check out average AUC
    print('Average AUC')
    print(np.mean(auc_data))

    return p_data

In [None]:
def get_full_data(X_train, X_test, train_clusters, test_clusters):
    '''
    Put the cluster columns next to the feature columns for train and test.

    Parameters:
    X_train, X_test (Pandas dataframes): features; their index is reset to 0..n-1 first
    train_clusters, test_clusters (Pandas dataframes): cluster data joined column-wise on the index

    Returns:
    (train_full, test_full)
    '''
    def _attach(features, clusters):
        # renumber the feature rows, then join the cluster columns on the index
        return pd.concat([features.reset_index(drop=True), clusters], axis=1)

    return _attach(X_train, train_clusters), _attach(X_test, test_clusters)

In [None]:
def change_col_names(train_full, test_full, disease_snp_assoc_path, gene_list_path, geno_path):
    '''
    Rename SNP columns to '<SNP>_<disease(s)>_<gene(s)>' in train and test.

    Parameters:
    train_full (Pandas dataframe): training data; columns whose name contains 'rs' are treated as SNPs
    test_full (Pandas dataframe): test data; train columns it lacks are added as all-zero columns
    disease_snp_assoc_path (string): comma-separated file with a header row and four columns,
        relabelled here as D1, D2, SNP, GENE
    gene_list_path (string): comma-separated file with GENE, CHR, START, STOP columns
    geno_path (string): PLINK prefix; '<geno_path>.bim' is read

    Returns:
    (train_full, test_full, col_names_full): both frames with the new column names,
    and the Series of names (renamed SNPs first, then the non-SNP columns).

    Raises:
    ValueError: if the number of generated names differs from the number of columns
        (e.g. a SNP is missing from the .bim file or listed twice in the association file).

    Notes:
    - train_full is renamed in place.
    - The new names are assigned positionally, SNP names first, so the SNP columns
      are assumed to precede all other columns in train_full - TODO confirm against callers.
    '''
    # identify SNP and non-SNP columns
    snp_names = [column for column in train_full.columns if 'rs' in column]
    non_snp_cols = [column for column in train_full.columns if 'rs' not in column]

    # for when test set does not have any cases that belong to cluster_x:
    # add every missing column first, and only then reorder. Reordering after the
    # first addition fails as soon as two or more columns are missing.
    missing_in_test = [column for column in train_full.columns if column not in test_full.columns]
    if missing_in_test:
        for column in missing_in_test:
            test_full[column] = 0
        test_full = test_full[train_full.columns]

    snp_cols = pd.DataFrame({'SNP': snp_names})

    # read in disease snp assoc file
    disease_snp_assoc = pd.read_csv(f'{disease_snp_assoc_path}', sep=',')
    disease_snp_assoc.columns = ['D1','D2','SNP','GENE']

    # read in gene positions file, autosomes only
    gene_list = pd.read_csv(gene_list_path, sep=',')
    gene_list = gene_list[~gene_list['CHR'].isin(['X', 'Y', 'XY'])].copy()
    gene_list['CHR'] = gene_list['CHR'].astype(int)

    # read in bim file
    bim = pd.read_csv(f'{geno_path}.bim', sep=r'\s+', header=None)
    bim.columns = ['chr','SNP','pos','bp','alt','ref']

    # SNPs absent from the association file are labelled 'AD' with no second disease or gene
    known_snps = set(disease_snp_assoc['SNP'])
    unlisted = [snp for snp in dict.fromkeys(snp_names) if snp not in known_snps]
    if unlisted:
        filler = pd.DataFrame({'D1': 'AD', 'D2': np.nan, 'SNP': unlisted})
        disease_snp_assoc = pd.concat([disease_snp_assoc, filler], axis=0, ignore_index=True)

    # attach disease labels and positions, keeping the SNP column order
    disease_snp_guide = snp_cols.merge(disease_snp_assoc, how='inner', on=['SNP'])
    disease_snp_guide_merge = disease_snp_guide.merge(bim, how='inner', on=['SNP'])

    col_name_strs = []

    # rename snps based on disease assoc
    for _, row in disease_snp_guide_merge.iterrows():
        if pd.isna(row['D2']):
            col_name_str = f'{row["SNP"]}_{row["D1"]}'
        else:
            col_name_str = f'{row["SNP"]}_{row["D1"]}_{row["D2"]}'

        if not pd.isna(row['GENE']):
            gene_label = row['GENE'].replace('-','_')
            col_name_str += f'_{gene_label}'
        else:
            # no annotated gene: look at genes lying fully inside a +/- 400 kb window
            lower = row['bp'] - 400000
            upper = row['bp'] + 400000
            on_chrom = gene_list[gene_list['CHR'] == row['chr']]
            window = on_chrom[(on_chrom['START'] > lower) & (on_chrom['STOP'] < upper)].copy()
            # NOTE(review): these distances are measured to the window edges, not to
            # the SNP, so the minimum selects the gene closest to an edge of the
            # window. Kept as in the original; confirm this is the intended choice.
            window['low_dist'] = window['START'] - lower
            window['high_dist'] = upper - window['STOP']
            window['min_dist'] = window[['low_dist','high_dist']].min(axis=1)
            selected = window[window['min_dist'] == window['min_dist'].min()]
            for gene in selected['GENE']:
                col_name_str += f'_{gene}'

        col_name_strs.append(col_name_str)

    disease_snp_guide_merge['col_name'] = col_name_strs

    # get and set full column names
    new_col_names = disease_snp_guide_merge['col_name']
    col_names_full = pd.concat([new_col_names, pd.Series(non_snp_cols, dtype=object)])
    if len(col_names_full) != train_full.shape[1]:
        raise ValueError(
            f'Built {len(col_names_full)} column names for {train_full.shape[1]} columns; '
            'check that every SNP column appears exactly once in the .bim and association files'
        )
    train_full.columns = col_names_full
    test_full.columns = col_names_full

    return train_full, test_full, col_names_full

In [None]:
def snp_regression(train_data, test_data, data_type, standardize=False, plot_shap=True, fig_dir='/data/CARD/projects/AD_Cluster/analysis/figures'):
    '''
    Model each cluster/component from the SNP columns and rank SNPs by SHAP.

    For every target column a model is fitted on the training SNPs
    (LogisticRegression for clusters, LinearRegression for components), a test
    metric is printed (ROC AUC or mean squared error), and SHAP values are
    computed for the training SNPs.

    Parameters:
    train_data (Pandas dataframe): training rows with SNP columns (name contains 'rs')
        and 'cluster_*' and/or 'COMP*' columns
    test_data (Pandas dataframe): test rows with the same columns
    data_type (string): 'clusters' models the 'cluster_*' columns, anything else the 'COMP*' columns
    standardize (bool): for components only, scale them with a StandardScaler fitted on train
    plot_shap (bool): draw the SHAP summary and bar plots; the bar plot is also saved as PNG
    fig_dir (string): directory for 'shap_summary_<column>.png'. The default is the
        location that was previously hard-coded.

    Returns:
    dict mapping each target column to its 20 SNP columns with the largest mean |SHAP|.
    '''
    # work on copies so the caller's frames are untouched
    train_data_copy = train_data.copy(deep=True)
    test_data_copy = test_data.copy(deep=True)

    cluster_cols = [column for column in train_data_copy.columns if 'cluster_' in column]
    comp_cols = [column for column in train_data_copy.columns if 'COMP' in column]
    snp_cols = [column for column in train_data_copy.columns if 'rs' in column]

    if data_type == 'clusters':
        cols = cluster_cols
    else:
        cols = comp_cols
        if standardize:
            scaler = StandardScaler()
            train_data_copy[comp_cols] = scaler.fit_transform(train_data_copy[comp_cols])
            test_data_copy[comp_cols] = scaler.transform(test_data_copy[comp_cols])

    top_snps = {}

    # loop through comps/clusters
    for col in cols:
        print(col)
        # SNPs ~ Cluster n (log reg)
        if data_type == 'clusters':
            model = LogisticRegression(max_iter=200)
            model.fit(train_data_copy[snp_cols], train_data_copy[col])
            pred = model.predict_proba(test_data_copy[snp_cols])[:,1]
            print(roc_auc_score(test_data_copy[col], pred))
        # SNPs ~ Comp n (lin reg)
        else:
            model = LinearRegression()
            model.fit(train_data_copy[snp_cols], train_data_copy[col])
            pred = model.predict(test_data_copy[snp_cols])
            print(mean_squared_error(test_data_copy[col], pred))

        # Background data is the training set in both branches. The component
        # branch used the test set as background while explaining training rows,
        # which was inconsistent with the cluster branch.
        explainer = shap.Explainer(model, train_data_copy[snp_cols])
        shap_values = explainer(train_data_copy[snp_cols], max_evals=501)

        # rank features by mean absolute SHAP value (descending)
        feature_names = shap_values.feature_names
        mean_abs_shap = np.abs(shap_values.values).mean(0)
        shap_importance = pd.DataFrame(
            list(zip(feature_names, mean_abs_shap)),
            columns=['col_name', 'feature_importance_vals'],
        ).sort_values(by=['feature_importance_vals'], ascending=False)
        top_snps[col] = list(shap_importance.head(20)['col_name'])

        # shap plots - need to make bar plot bidirectional
        if plot_shap:
            shap.summary_plot(shap_values, train_data_copy[snp_cols], show=False)
            plt.show()
            plt.clf()
            shap.plots.bar(shap_values, max_display=20, show=False)
            f = plt.gcf()
            f.savefig(f'{fig_dir}/shap_summary_{col}.png', bbox_inches='tight', dpi=400)
            plt.show()

    return top_snps

In [None]:
def calculate_disease_specific_prs(adjusted, train_full, test_full, geno_path, assoc_path, out_path):
    '''
    Compute a standardized polygenic risk score per disease with plink2.

    Steps: extract the variants named by the columns of `adjusted` from the
    genotype files, split them into train and test PLINK filesets, then for each
    disease in ['ad','pd','als','lbd','ftd'] score both sets, standardize the
    scores with a scaler fitted on the training scores, and merge the result
    into train_full/test_full as 'PRS_STD_<disease>'. A histogram of the
    combined scores is shown per disease.

    Parameters:
    adjusted (Pandas dataframe): has an 'ID' column; every other column name is a variant ID
    train_full, test_full (Pandas dataframes): must contain an 'ID' column
    geno_path (string): PLINK bfile prefix of the input genotypes
    assoc_path (string): whitespace-separated, header-less file with ID, ALLELE, BETA, DISEASE columns
    out_path (string): prefix for every file written here

    Returns:
    (train_full, test_full) with one added 'PRS_STD_<disease>' column per disease.

    Notes:
    shell_do does not raise when plink2 fails, and each disease writes the same
    '<prefix>.sscore' file, so after a failed scoring run the file read back may
    be left over from an earlier disease.
    '''
    # disease list
    diseases = ['ad','pd','als','lbd','ftd']

    # drop ID column, write the variant IDs to txt
    adj = adjusted.drop(columns=['ID'])
    adj_cols = pd.Series(list(adj.columns))
    adj_cols.to_csv(f'{out_path}_variant_ids.txt', sep='\t', index=False, header=False)

    # extract variants from geno
    extract_cmd = f'plink2 --bfile {geno_path} --extract {out_path}_variant_ids.txt --make-bed --out {out_path}'
    shell_do(extract_cmd)

    # read assoc file
    assoc = pd.read_csv(assoc_path, sep=r'\s+', header=None)
    assoc.columns = ['ID','ALLELE','BETA','DISEASE']

    train_out_path = f'{out_path}_train'
    test_out_path = f'{out_path}_test'

    # isolate train and test plink data; the ID is written twice (FID and IID columns)
    for ids, prefix in ((train_full['ID'], train_out_path), (test_full['ID'], test_out_path)):
        pd.concat([ids, ids], axis=1).to_csv(f'{prefix}_ids.txt', sep='\t', index=False, header=False)
    for prefix in (train_out_path, test_out_path):
        keep_cmd = f'plink2 --bfile {out_path} --keep {prefix}_ids.txt --make-bed --out {prefix}'
        shell_do(keep_cmd)

    def _score(prefix, disease):
        # run plink2 --score for one fileset and load the resulting .sscore table
        score_cmd = f'plink2 --bfile {prefix} --score {out_path}_{disease}_assoc_ids.txt --out {prefix}'
        shell_do(score_cmd)
        prs = pd.read_csv(f'{prefix}.sscore', sep=r'\s+')
        return prs.rename({'#FID':'ID','SCORE1_AVG':f'SCORE1_AVG_{disease}'}, axis=1)

    # loop through diseases
    for disease in diseases:
        # isolate this disease's SNP weights and write them to txt
        assoc_disease = assoc[assoc['DISEASE'] == disease.upper()]
        assoc_disease[['ID','ALLELE','BETA']].to_csv(f'{out_path}_{disease}_assoc_ids.txt', sep='\t', index=False, header=False)

        # run train and test prs
        train_prs = _score(train_out_path, disease)
        test_prs = _score(test_out_path, disease)

        # standardize with the training mean/scale and apply the same to test
        scaler = StandardScaler()
        train_prs[f'PRS_STD_{disease}'] = scaler.fit_transform(np.asarray(train_prs[f'SCORE1_AVG_{disease}']).reshape(-1,1))
        test_prs[f'PRS_STD_{disease}'] = scaler.transform(np.asarray(test_prs[f'SCORE1_AVG_{disease}']).reshape(-1,1))

        full_prs = pd.concat([train_prs, test_prs], axis=0)

        train_full = train_full.merge(full_prs[['ID',f'PRS_STD_{disease}']], how='inner', on=['ID'])
        test_full = test_full.merge(full_prs[['ID',f'PRS_STD_{disease}']], how='inner', on=['ID'])

        # plot
        plt.hist(full_prs[f'PRS_STD_{disease}'])
        plt.title(f'PRS Distribution - {disease}')
        plt.xlabel('Standardized PRS')
        plt.ylabel('Frequency')
        plt.show()

    return train_full, test_full

In [None]:
def disease_specific_prs_regression(train_data, test_data, data_type, standardize=False):
    '''
    Regress each cluster/component on each disease-specific PRS.

    For every disease in ['ad','pd','als','lbd','ftd'] and every target column,
    fits '<column> ~ PRS_STD_<disease>' on the training data (statsmodels logit
    for clusters, ols for components). It also reports the mean (std) PRS of the
    train+test rows where the target column equals 1, marked with '*' when a
    one-sample t-test against 0 gives p < 0.05.

    Parameters:
    train_data (Pandas dataframe): training rows with 'PRS_STD_<disease>' columns
        and 'cluster_*' and/or 'COMP*' columns
    test_data (Pandas dataframe): test rows; only used for the PRS mean/std summary
    data_type (string): 'clusters' uses the 'cluster_*' columns, anything else the 'COMP*' columns
    standardize (bool): for components only, standardize them before fitting

    Returns:
    None. OR, beta, SE, p-value and mean-PRS tables are printed; 'beta (SE)' and
    p-values are written to CSV under '<wd>/analysis/results/'.
    '''
    from io import StringIO  # function-scope import: not part of the notebook's import cell

    wd = 'insert_path'

    # fit on a copy so the caller's frame is untouched
    train_data_copy = train_data.copy(deep=True)

    # PRS summaries use all rows, unstandardized
    full_data_copy = pd.concat([train_data, test_data], axis=0)

    # prs col prefix
    prs_score = 'PRS_STD'

    # disease list
    diseases = ['ad','pd','als','lbd','ftd']

    cluster_cols = [column for column in train_data_copy.columns if 'cluster_' in column]
    comp_cols = [column for column in train_data_copy.columns if 'COMP' in column]

    if data_type == 'clusters':
        cols = cluster_cols
        make_model = logit
    else:
        cols = comp_cols
        make_model = ols
        if standardize:
            train_data_copy[comp_cols] = StandardScaler().fit_transform(train_data_copy[comp_cols])

    # one inner list per disease, one entry per target column
    coef_data = []
    std_data = []
    or_data = []
    p_data = []
    prs_mean_data = []

    for disease in diseases:
        prs_col = f'{prs_score}_{disease}'

        coef_row_data = []
        std_row_data = []
        or_row_data = []
        p_row_data = []
        prs_mean_row_data = []

        for col in cols:
            formula = f'{col} ~ {prs_col}'
            model = make_model(formula=formula, data=train_data_copy).fit(disp=0)

            # StringIO: passing literal HTML to read_html is deprecated.
            # Row 0 of the table is the intercept, row 1 the PRS term.
            table_html = model.summary().tables[1].as_html()
            prs_row = pd.read_html(StringIO(table_html), header=0, index_col=0)[0].iloc[1]
            coef_row_data.append(prs_row['coef'])
            std_row_data.append(prs_row['std err'])
            or_row_data.append(np.exp(prs_row['coef']))
            p_row_data.append(model.pvalues.loc[prs_col])

            # PRS of the rows where the target column equals 1
            member_prs = full_data_copy[full_data_copy[col] == 1][prs_col]
            stat, pval = ttest_1samp(member_prs, 0)
            mean = member_prs.mean()
            std = member_prs.std()

            if pval < 0.05:
                prs_mean_row_data.append(f'{mean} ({std})*')
            else:
                prs_mean_row_data.append(f'{mean} ({std})')

        coef_data.append(coef_row_data)
        std_data.append(std_row_data)
        or_data.append(or_row_data)
        p_data.append(p_row_data)
        prs_mean_data.append(prs_mean_row_data)

    # transpose so rows are target columns and columns are diseases
    coef_data = pd.DataFrame(np.transpose(coef_data), index=cols, columns=diseases)
    std_data = pd.DataFrame(np.transpose(std_data), index=cols, columns=diseases)
    or_data = pd.DataFrame(np.transpose(or_data), index=cols, columns=diseases)
    p_data = pd.DataFrame(np.transpose(p_data), index=cols, columns=diseases)
    prs_mean_data = pd.DataFrame(np.transpose(prs_mean_data), index=cols, columns=diseases)

    print('OR')
    print(or_data)
    print('BETA')
    print(coef_data)
    print('SE')
    print(std_data)
    print('P')
    print(p_data)
    print('Mean PRS')
    print(prs_mean_data)

    # 'beta (SE)' strings go into an object-typed copy; writing strings into the
    # float columns themselves is rejected by newer pandas
    coef_with_se = coef_data.astype(object)
    for row in range(coef_data.shape[0]):
        for col in range(coef_data.shape[1]):
            coef_with_se.iloc[row,col] = f'{coef_data.iloc[row,col]} ({std_data.iloc[row,col]})'

    coef_with_se.to_csv(f'{wd}/analysis/results/disease_specific_regression_coef.csv', sep=',', index=True, header=True)
    p_data.to_csv(f'{wd}/analysis/results/disease_specific_regression_pval.csv', sep=',', index=True, header=True)
In [None]:
def get_cluster_counts(data, fname):
    '''
    Tabulate how the diseases are represented overall and within each cluster.

    Builds a table with one row per disease in ['ad','pd','als','lbd','ftd'],
    an 'Overall' column (share of all samples) and one 'C<k>' column per
    'cluster_<k>' indicator column (share of that cluster's members). Each
    cluster share is compared with the overall share by a two-sided z-test for
    proportions and marked with '*' when p < 0.05.

    Parameters:
    data (Pandas dataframe): needs a 'pheno' column and 0/1 'cluster_<k>' columns
    fname (string): file name of the CSV written under '<wd>/analysis/results/'

    Returns:
    None. The table and the test statistics are printed and the table is written to CSV.
    '''
    wd = 'insert_path'

    diseases = ['ad','pd','als','lbd','ftd']

    # num samples
    num_samples = data.shape[0]

    # cluster indicator columns, labelled from the column name itself so that
    # non-contiguous clusters (e.g. cluster_0, cluster_2) keep the right label
    cluster_cols = [col for col in data.columns if 'cluster_' in col]
    cluster_labels = [f"C{col.split('_')[1]}" for col in cluster_cols]

    # rows belonging to each group
    groups = {'Overall': data}
    for label, col in zip(cluster_labels, cluster_cols):
        groups[label] = data[data[col] == 1]

    # share of each disease within each group, rounded to 5 decimals
    disease_cluster_representation = pd.DataFrame({'Disease': diseases})
    for label, subset in groups.items():
        shares = subset['pheno'].value_counts(normalize=True).round(5)
        disease_cluster_representation[label] = [shares.get(disease, 0) for disease in diseases]

    print(disease_cluster_representation.shape)

    # cluster columns will receive strings such as '0.25*'
    for label in cluster_labels:
        disease_cluster_representation[label] = disease_cluster_representation[label].astype(object)

    # z-test for proportions: overall share vs. share within each cluster
    for row, disease in enumerate(diseases):
        print(disease)
        overall_count = int((data['pheno'] == disease).sum())
        for label in cluster_labels:
            subset = groups[label]
            if len(subset) == 0:
                print(f'Overall vs. {label}: cluster is empty, z-test skipped')
                continue
            cluster_count = int((subset['pheno'] == disease).sum())
            # each proportion is paired with its own sample size; the cluster
            # share is a share of the cluster's members, not of all samples
            membership = np.array([overall_count, cluster_count])
            samples = np.array([num_samples, len(subset)])
            stat, p_val = proportions_ztest(count=membership, nobs=samples, alternative='two-sided')
            if p_val < 0.05:
                disease_cluster_representation.loc[row, label] = f'{disease_cluster_representation.loc[row, label]}*'
            print(f'Overall vs. {label}: z_stat={round(stat, 3)}, p_value={p_val}')
        print()

    # print and write to file
    print(disease_cluster_representation)
    disease_cluster_representation.to_csv(f'{wd}/analysis/results/{fname}', sep=',', index=False)

In [None]:
def plot_prs_distributions_by_disease(full_data, diseases=('ad','pd','als','lbd','ftd'), out_path='figures/prs_distributions_per_disease.png'):
    '''
    Plot, for each disease, the per-cluster density of its PRS column and save the figure.

    One panel is drawn per disease, stacked top to bottom. Each panel is a
    kernel density plot of the 'PRS_STD_<disease>' column split by 'cluster',
    with every cluster's curve normalized on its own (common_norm=False) and a
    vertical black line at x=0.

    Parameters:
    full_data (Pandas dataframe): data with a 'cluster' column and one 'PRS_STD_<disease>' column per requested disease
    diseases (sequence of strings): disease codes to plot, one panel each. default: ad, pd, als, lbd, ftd
    out_path (string): file the figure is written to. default: 'figures/prs_distributions_per_disease.png'

    Returns:
    None

    Raises:
    ValueError: if diseases is empty
    '''
    diseases = list(diseases)
    if not diseases:
        raise ValueError('diseases must contain at least one disease code')

    # squeeze=False always returns a 2-D axes array, so a single panel is
    # indexed the same way as several
    fig, axs = plt.subplots(len(diseases), squeeze=False)
    fig.subplots_adjust(hspace=0.35)
    fig.set_figwidth(6)
    # 3.6 per panel reproduces the height of 18 used for the five default diseases
    fig.set_figheight(3.6 * len(diseases))

    for ax, disease in zip(axs[:, 0], diseases):
        sns.kdeplot(data=full_data, x=f'PRS_STD_{disease}', hue='cluster', palette='bright', common_norm=False, ax=ax)
        ax.set_xlabel(f'{disease.upper()} PRS')
        ax.set_title(f'{disease.upper()} PRS Distributions')
        ax.axvline(x=0, c='black')

    # make sure the target directory exists so the save does not fail on a fresh checkout
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # save through the figure object itself rather than pyplot's "current figure"
    fig.savefig(out_path)
    fig.show()

In [None]:
# Hand get_umap_params.swarm to the `swarm` command-line tool via shell_do.
# The command has no interpolated values, so a plain string literal is enough.
swarm_cmd = 'swarm -f get_umap_params.swarm -g 200 --time 10-00:00:00 --module python/3.9'
shell_do(swarm_cmd)

In [None]:
# Hand cluster_consistency.swarm to the `swarm` command-line tool via shell_do.
# The command has no interpolated values, so a plain string literal is enough.
swarm_cmd = 'swarm -f cluster_consistency.swarm -g 200 --time 10-00:00:00 --module python/3.9'
shell_do(swarm_cmd)

In [None]:
# Root directory that the f-string paths in the following cells are built from.
# NOTE(review): 'insert_path' looks like a placeholder to be replaced before running — confirm.
wd = 'insert_path'

In [None]:
# Input locations, all rooted at wd.
# adjusted_path: comma-separated table read into `adjusted` in the next cell.
# pheno_path: whitespace-delimited, header-less phenotype table read into `pheno` below.
# plink_path: not read anywhere in this section; presumably a PLINK file prefix — TODO confirm.
adjusted_path = f'{wd}/processing/adjustment/downsampled/downsampled_gwas5e08_ADJUSTED10PCs.csv'
pheno_path = f'{wd}/merged_genotypes/all_cohorts_phenotype.txt'
plink_path = f'{wd}/merged_genotypes/downsampled/gwas_common_snps_annovar_related_prune_eur_5e-08_downsampled'

In [None]:
# Load the comma-separated table at adjusted_path and report its (rows, columns) shape.
adjusted = pd.read_csv(adjusted_path, sep=',')
print(adjusted.shape)

In [None]:
# Print True/False for whether each of these two rsIDs is a column of `adjusted`.
print('rs7412' in adjusted.columns)
print('rs429358' in adjusted.columns)

In [None]:
# Load the phenotype table: whitespace-delimited, no header row, four columns.
# The separator is a raw string so that `\s` reaches the regex engine untouched
# instead of being parsed as an (invalid) Python string escape sequence.
pheno = pd.read_csv(pheno_path, sep=r'\s+', header=None)
pheno.columns = ['ID','IID','PHENO','COHORT']
# Keep one row per (ID, IID) pair and renumber the index from 0.
pheno = pheno.drop_duplicates(subset=['ID','IID'], ignore_index=True)
print(pheno.shape)
print(pheno['PHENO'].value_counts())
print(pheno['COHORT'].value_counts())

In [None]:
# Attach PHENO and COHORT to the genotype table by sample ID. The inner join keeps
# only IDs present in both tables; fully identical rows are then collapsed.
cases = (
    adjusted
    .merge(pheno[['ID','PHENO','COHORT']], how='inner', on=['ID'])
    .drop_duplicates(ignore_index=True)
)
print(cases.shape)
print(cases['PHENO'].value_counts())
print(cases['COHORT'].value_counts())

In [None]:
# disease is passed to set_up here and, further down, decides which bootstrap
# file pattern is used (None selects the pattern without a disease suffix).
disease = None
# set_up is defined outside this section; it returns the feature table X and labels y.
X, y = set_up(cases, disease)
col_names = X.columns

In [None]:
# Hold out 30% of the samples; the fixed random_state keeps the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=4444)

print(y_train.value_counts())
print(y_test.value_counts())

# Set the identifier and cohort columns aside before removing them from the features.
train_id, test_id = X_train['ID'], X_test['ID']
train_cohort, test_cohort = X_train['COHORT'], X_test['COHORT']

# Each ID list is written as two identical tab-separated columns, without a header.
pd.concat([train_id, train_id], axis=1).to_csv('train_id.txt', sep='\t', index=False, header=False)
pd.concat([test_id, test_id], axis=1).to_csv('test_id.txt', sep='\t', index=False, header=False)

# Everything except ID and COHORT stays in the feature tables.
X_train = X_train.drop(columns=['ID','COHORT'])
X_test = X_test.drop(columns=['ID','COHORT'])

# Labels, IDs and cohorts flattened into single arrays, train entries first.
y_cases = np.append(y_train, y_test)
y_ids = np.append(train_id, test_id)
y_cohorts = np.append(train_cohort, test_cohort)

print(train_id.head())

print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)

In [None]:
# Glob pattern for the bootstrap cluster-membership files; a disease-specific
# run carries the disease code as an extra suffix before the extension.
prefix = (
    f'{wd}/analysis/bootstrap/cluster_membership_df_100_rs_None_*_{disease}.txt'
    if disease
    else f'{wd}/analysis/bootstrap/cluster_membership_df_100_rs_None_*.txt'
)

full_clusters = get_permuted_clusters(prefix, y_ids, disease)

In [None]:
# Embed the train and test features with umap_transform (defined outside this section).
# NOTE(review): 2.75 and 0.75 are passed positionally; their meaning is set by umap_transform's signature — confirm.
X_train_umap, X_test_umap = umap_transform(X_train, X_test, y_cases, 2.75, 0.75, seed=4444, cohorts=None)
# Wrap both embeddings as three-column frames.
X_train_umap = pd.DataFrame(X_train_umap, columns=['UMAP1','UMAP2','UMAP3'])
X_test_umap = pd.DataFrame(X_test_umap, columns=['UMAP1','UMAP2','UMAP3'])
# The new frames have a fresh 0..n-1 index, so the ID and label Series are
# re-indexed the same way before being assigned as columns (assignment aligns on index).
train_id = train_id.reset_index(drop=True)
y_train = y_train.reset_index(drop=True)
X_train_umap['ID'] = train_id
X_train_umap['pheno'] = y_train
test_id = test_id.reset_index(drop=True)
y_test = y_test.reset_index(drop=True)
X_test_umap['ID'] = test_id
X_test_umap['pheno'] = y_test

In [None]:
# Join the UMAP coordinates to the cluster assignments by ID, separately for train and test.
# NOTE(review): this cell reads full_clusters and then overwrites it with the merged
# frame, so re-running it alone merges against the already-merged table — re-run the
# get_permuted_clusters cell first.
train_clusters = X_train_umap.merge(full_clusters, how='inner', on=['ID'])
test_clusters = X_test_umap.merge(full_clusters, how='inner', on=['ID'])

# Stack train and test rows, cast the cluster label to string, and plot in UMAP space.
full_clusters = pd.concat([train_clusters, test_clusters], axis=0, ignore_index=True)
full_clusters['cluster'] = full_clusters['cluster'].astype(str)
plot_3d(full_clusters, 'cluster', 'figures/cluster.html', x='UMAP1', y='UMAP2', z='UMAP3', title='UMAP Clusters')

In [None]:
# Run expand_cluster_data (defined outside this section) on the train and test cluster tables.
# NOTE(review): later cells read columns such as 'cluster_2' — presumably created here; confirm.
train_data = expand_cluster_data(train_clusters)
test_data = expand_cluster_data(test_clusters)

In [None]:
# Call disease_regression (defined outside this section) in 'clusters' mode and keep its return value.
p_vals = disease_regression(train_data, test_data, 'clusters')

In [None]:
# Paths handed to change_col_names, all rooted at wd.
disease_snp_assoc_path = f'{wd}/merged_genotypes/significance/gwas_common_snps_annovar_related_prune_eur_5e-08_snp_disease_genes.csv'
gene_list_path = f'{wd}/analysis/gene_list.csv'
extract_geno_path = f'{wd}/analysis/prs/prs_5e-08'
# Both helpers are defined outside this section. get_full_data combines the feature
# tables with the cluster tables; change_col_names returns updated train/test frames
# plus a third value stored as col_names_full.
train_full, test_full = get_full_data(X_train, X_test, train_data, test_data)
train_full, test_full, col_names_full = change_col_names(train_full, test_full, disease_snp_assoc_path, gene_list_path, extract_geno_path)


In [None]:
# Call snp_regression (defined outside this section) in 'clusters' mode and print what it returns.
top_snps_clusters = snp_regression(train_full, test_full, 'clusters')
print(top_snps_clusters)

In [None]:
# Paths handed to calculate_disease_specific_prs, all rooted at wd.
cases_only_geno_path = f'{wd}/merged_genotypes/significance/cases_only/gwas_common_snps_annovar_related_prune_eur_cases_5e-08'
assoc_path = f'{wd}/merged_genotypes/significance/gwas_common_snps_annovar_related_prune_eur_5e-08_assoc.txt'
prs_out_path = f'{wd}/analysis/prs/prs_5e-08'
# train_full and test_full are replaced by the frames the helper returns (helper defined outside this section).
# NOTE(review): a later cell plots 'PRS_STD_<disease>' columns — presumably added by this call; confirm.
train_full, test_full = calculate_disease_specific_prs(adjusted=adjusted, train_full=train_full, test_full=test_full, geno_path=cases_only_geno_path, assoc_path=assoc_path, out_path=prs_out_path)


In [None]:
# Call disease_specific_prs_regression (defined outside this section) in 'clusters' mode with standardize=False.
disease_specific_prs_regression(train_data=train_full, test_data=test_full, data_type='clusters', standardize=False)

In [None]:
# Stack the train and test frames into one table with a fresh 0..n-1 index.
# pd.concat keeps each column's dtype and lines columns up by name; going through
# np.append instead funnels both frames into a single NumPy array, which turns
# mixed-dtype data into `object` and stacks columns purely by position.
full_data = pd.concat([train_full, test_full], axis=0, ignore_index=True)
print(full_data.shape)

In [None]:
# Run get_cluster_counts on the combined table; the second argument is the output file name.
get_cluster_counts(full_data, 'disease_cluster_representation.csv')

In [None]:
# Phenotype counts among the rows flagged with cluster_2 == 1.
in_cluster_2 = full_data['cluster_2'] == 1
print(full_data.loc[in_cluster_2, 'pheno'].value_counts())

In [None]:
# Draw and save the per-disease PRS density panels for the combined table.
plot_prs_distributions_by_disease(full_data)

In [None]:
# Reload the genotype table; only its column names are needed for the SNP list.
adjusted_path = f'{wd}/processing/adjustment/downsampled/downsampled_gwas5e08_ADJUSTED10PCs.csv'
adjusted = pd.read_csv(adjusted_path, sep=',')

# One row per column name of `adjusted`, in a single 'SNP' column.
snp_set = pd.DataFrame({'SNP': adjusted.columns})

# Annotation table, with its four columns renamed positionally.
disease_snp_assoc_path = f'{wd}/merged_genotypes/significance/gwas_common_snps_annovar_related_prune_eur_5e-08_snp_disease_genes.csv'
disease_snp_assoc = pd.read_csv(disease_snp_assoc_path, sep=',')
disease_snp_assoc.columns = ['D1','D2','SNP','GENE']

# Keep only annotations whose SNP is also a column of `adjusted`, then write them out.
snps = snp_set.merge(disease_snp_assoc, how='inner', on=['SNP'])

snps.to_csv(f'{wd}/analysis/snp_annotations.csv', sep=',', index=None)

In [None]:
# Summarize the bootstrap cluster-membership files: how many clusters each run
# produced, and how many samples landed in cluster 0 on average per split.
n_splits = 15  # files are numbered 1..n_splits

num_clusters = []
avg_num_in_c0 = []

for i in range(n_splits):
    # raw string so `\s` is handed to the regex engine rather than parsed as a string escape
    bootstrap = pd.read_csv(f'{wd}/analysis/bootstrap/cluster_membership_df_100_rs_None_{i+1}.txt', sep=r'\s+')

    num_in_c0 = []
    # despite its name this holds the number of clusters per column for the current
    # split; it is only consumed by the optional debug prints below
    num_in_c0_this_split = []

    for col in bootstrap:
        # every column other than ID and pheno is one run's cluster labels
        if col in ('ID', 'pheno'):
            continue
        counts = bootstrap[col].value_counts()
        num_clusters.append(len(counts))
        # a run with no cluster labelled 0 contributes 0 instead of raising KeyError
        num_in_c0.append(counts.get(0, 0))
        num_in_c0_this_split.append(len(counts))

    avg_num_in_c0.append(np.mean(num_in_c0))
    # print(i+1)
    # print(pd.Series(num_in_c0_this_split).value_counts())


split_index = np.arange(1, n_splits + 1)

avg_num_in_c0_data = {'Split':split_index, 'Avg. in C0':avg_num_in_c0}
avg_num_in_c0_df = pd.DataFrame(avg_num_in_c0_data)
print(avg_num_in_c0_df)
print()
print(pd.Series(num_clusters).value_counts())