In [None]:
import pandas as pd
import numpy as np
import pathlib as pl


from tqdm.notebook import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from statannotations.Annotator import Annotator
from statsmodels.stats.multitest import multipletests

from scipy.stats import mannwhitneyu, fisher_exact, pearsonr, kruskal

In [None]:
import sys
import os
sys.path.append("../../FinalCode/")
import download.download as dwnl
import utils.plotting as plting
import adVMP.adVMP_discovery as discov
import adVMP.adVMP_plots as advmpplt

In [None]:
# For figures
colors = sns.color_palette("muted")
fig_dir = pl.Path("/add/path/here")

In [None]:
# Placeholder locations: replace each "/add/path/here" with a real path before running.
# base_dir is passed to the loader call made with EPIC4=False, base_dir4 to the one with EPIC4=True.
base_dir = pl.Path("/add/path/here")
base_dir4 = pl.Path("/add/path/here")

# Root folder holding the "auxiliary", "clinical" and "adVMP" sub-directories used below.
data_dir = pl.Path("/add/path/here")

# Probe list read from CSV and flattened to a 1-D array; it is handed to the loader as `bad_probes`.
bad_probes = pd.read_csv(data_dir / "auxiliary" / "sketchy_probe_list_epic.csv",index_col=0).values.ravel()
sample_origin_path = pl.Path(data_dir / "clinical" / "sample_origin_wbatch.csv")

# Clinical table and targets file, both forwarded unchanged to the loader.
clinical_path = pl.Path(data_dir / "clinical" / "cleaned_clinical_reduced_diet.csv")
target_path = pl.Path(data_dir / "clinical" / "targets.csv")

In [None]:
# Load the datasets unpacked as EPIC2_* and EPIC3_*; with EPIC4=False the loader returns eight objects.
# Later cells index *_b by sample (rows) and CpG (columns) and take clinical columns from *_clin.
EPIC2_b, EPIC2_clin, EPIC2_samples, EPIC2_phenotypes, EPIC3_b, EPIC3_clin, EPIC3_samples, EPIC3_phenotypes = dwnl.download_EPIC(sample_origin_path=sample_origin_path, 
                     base_dir=base_dir, clinical_path=clinical_path, target_path=target_path,
                  bad_probes=bad_probes, EPIC4=False) 

In [None]:
# Load the dataset unpacked as EPIC4_*; with EPIC4=True the loader returns four objects and reads from base_dir4.
EPIC4_b, EPIC4_clin, EPIC4_samples, EPIC4_phenotypes = dwnl.download_EPIC(sample_origin_path=sample_origin_path, 
                     base_dir=base_dir4, clinical_path=clinical_path, target_path=target_path, 
                  bad_probes=bad_probes, EPIC4=True) 

In [None]:
# CpG identifiers to analyse, read from the adVMP union list and flattened to a 1-D array.
union_cpgs = pd.read_csv(data_dir / "adVMP" / "union_cpgs.csv", index_col=0).values.ravel()

# Link adVMPs with clinical and lifestyle factors

In [None]:
import statsmodels.api as sm

In [None]:
# define the list of clinical features we are interested in
cols_clin = ["Age at visit",
            "BMI",
            "Metabolic syndrome",
            "Analgesic >=2 years (overall)",
            "Ever smoked cigarettes",
            "Pack years",
            'inflammatory_n',
            'anti-inflammatory_n', 
            'western_n', 
            'prudent_n'
            ]

In [None]:
from typing import List, Optional
def get_dmps(EPIC_b: pd.DataFrame, union_cpgs: np.ndarray, std_clin: pd.DataFrame) -> dict:
    """Regress each CpG on the clinical features and collect per-feature statistics.

    For every CpG in ``union_cpgs`` a Gamma-family GLM (statsmodels' default link
    for that family) is fitted with the CpG values as response and the columns of
    ``std_clin`` plus an intercept as regressors.

    Parameters
    ----------
    EPIC_b
        Values per sample (rows) and CpG (columns). Rows are used positionally,
        so they must be in the same order as the rows of ``std_clin``.
    union_cpgs
        CpG column names of ``EPIC_b`` to test.
    std_clin
        Regressors, one column per clinical feature.

    Returns
    -------
    dict
        Maps each column name of ``std_clin`` to a DataFrame indexed by
        ``union_cpgs`` that holds that feature's row of the GLM coefficient
        table for every CpG, plus an ``"FDR q"`` column with Benjamini-Hochberg
        adjusted p-values computed across CpGs.
    """
    # The design matrix is the same for every CpG, so build it once.
    exog_df = sm.add_constant(std_clin)
    all_results = {charac: [] for charac in std_clin.columns}

    # iterate over all CpG sites
    for cg in tqdm(union_cpgs):
        # to_numpy() replaces the deprecated Series.ravel().
        endog = EPIC_b[cg].to_numpy()
        glm_results = sm.GLM(endog, exog_df, family=sm.families.Gamma()).fit()
        coef_table = glm_results.summary2().tables[1]
        for charac, rows in all_results.items():
            rows.append(coef_table.loc[charac])

    for charac in all_results:
        # stack the per-CpG rows so that each row of the table corresponds to one CpG
        df = pd.concat(all_results[charac], axis=1).T
        df.index = union_cpgs
        # correct for multiple testing across CpGs, separately for each feature
        df["FDR q"] = multipletests(df["P>|z|"], method="fdr_bh")[1]
        all_results[charac] = df
    return all_results

In [None]:
# Root output directory; each cohort's association tables go in its own sub-folder.
res_dir = data_dir / "adVMP_link_clinical"

# SWEPIC1

In [None]:
# standardize clinical factors so the beta values are interpretable
std_clin2 = EPIC2_clin[cols_clin]
std_clin2 = (std_clin2 - std_clin2.mean())/std_clin2.std()
std_clin2 = std_clin2.dropna()

In [None]:
# Fit the per-CpG GLMs on the samples that have complete clinical data (rows kept in std_clin2).
all_results = get_dmps(EPIC_b=EPIC2_b.loc[std_clin2.index], union_cpgs=union_cpgs, std_clin=std_clin2)

In [None]:
# Write one CSV per clinical feature for SWEPIC1.
# The output folder is created once, outside the loop.
swepic1_dir = res_dir / "SWEPIC1"
os.makedirs(swepic1_dir, exist_ok=True)
for charac, link_df in all_results.items():
    link_df.to_csv(swepic1_dir / f"{charac}_link.csv")

# SWEPIC2

In [None]:
# standardize clinical factors so the beta values are interpretable
std_clin3 = EPIC3_clin[cols_clin]
std_clin3 = (std_clin3 - std_clin3.mean())/std_clin3.std()
std_clin3 = std_clin3.dropna()

In [None]:
# Fit the per-CpG GLMs on the samples that have complete clinical data (rows kept in std_clin3).
all_results = get_dmps(EPIC_b=EPIC3_b.loc[std_clin3.index], union_cpgs=union_cpgs, std_clin=std_clin3)

In [None]:
# Write one CSV per clinical feature for SWEPIC2.
# The output folder is created once, outside the loop.
swepic2_dir = res_dir / "SWEPIC2"
os.makedirs(swepic2_dir, exist_ok=True)
for charac, link_df in all_results.items():
    link_df.to_csv(swepic2_dir / f"{charac}_link.csv")

# SWEPIC3

In [None]:
# standardize clinical factors so the beta values are interpretable
std_clin4 = EPIC4_clin[cols_clin]
std_clin4 = (std_clin4 - std_clin4.mean())/std_clin4.std()
std_clin4 = std_clin4.dropna()

In [None]:
# Fit the per-CpG GLMs on the samples that have complete clinical data (rows kept in std_clin4).
all_results = get_dmps(EPIC_b=EPIC4_b.loc[std_clin4.index], union_cpgs=union_cpgs, std_clin=std_clin4)

In [None]:
# Write one CSV per clinical feature for SWEPIC3.
# The output folder is created once, outside the loop.
swepic3_dir = res_dir / "SWEPIC3"
os.makedirs(swepic3_dir, exist_ok=True)
for charac, link_df in all_results.items():
    link_df.to_csv(swepic3_dir / f"{charac}_link.csv")

# Compare all SWEPIC cohorts

In [None]:
# Reload the saved association tables as {cohort: {clinical feature: DataFrame}}.
all_results = {
    cohort: {
        charac: pd.read_csv(res_dir / cohort / f"{charac}_link.csv", index_col=0)
        for charac in cols_clin
    }
    for cohort in ("SWEPIC1", "SWEPIC2", "SWEPIC3")
}

In [None]:
# 10-colour "husl" palette passed to the bar plots below.
pltte = sns.color_palette("husl", 10)

In [None]:
from typing import Dict
def get_hist_plot_df(cols_clin: np.ndarray, union_cpgs: np.ndarray, all_results: Dict, lim_sum: int=0,
                     q_threshold: float=0.1) -> pd.DataFrame:
    """Percentage of CpGs associated with each clinical feature across cohorts.

    Parameters
    ----------
    cols_clin
        Names of the clinical features to summarise.
    union_cpgs
        All tested CpGs; only its length is used, as the denominator.
    all_results
        Nested mapping ``{cohort: {feature: DataFrame}}`` with the keys
        "SWEPIC1", "SWEPIC2" and "SWEPIC3"; each DataFrame needs an
        ``"FDR q"`` column indexed by CpG.
    lim_sum
        A CpG is counted when it is significant in strictly more than
        ``lim_sum`` cohorts (0 means at least one cohort, 1 means at least two).
    q_threshold
        A CpG is significant in a cohort when its q-value is strictly below
        this value.

    Returns
    -------
    pd.DataFrame
        A single row labelled "Proportion" with one column per feature,
        expressed in percent and sorted in decreasing order.
    """
    cohorts = ["SWEPIC1", "SWEPIC2", "SWEPIC3"]
    n_hits = {}
    for charac in cols_clin:
        # q-values of this feature, one column per cohort, aligned on the CpG index
        charac_df = pd.concat([all_results[cohort][charac]["FDR q"] for cohort in cohorts], axis=1)
        charac_df.columns = cohorts
        # number of cohorts in which each CpG is significant
        n_sig_cohorts = (charac_df < q_threshold).sum(axis=1)
        n_hits[charac] = [(n_sig_cohorts > lim_sum).sum()]

    # Plain arithmetic replaces the deprecated DataFrame.applymap.
    hist_plot_df = pd.DataFrame.from_dict(n_hits) / len(union_cpgs) * 100
    hist_plot_df.index = ["Proportion"]
    hist_plot_df = hist_plot_df.sort_values(by="Proportion", ascending=False, axis=1)
    return hist_plot_df

In [None]:
# hist_plot_df1: CpGs significant in at least one cohort (default lim_sum=0).
# hist_plot_df2: CpGs significant in at least two cohorts (lim_sum=1).
hist_plot_df1 = get_hist_plot_df(cols_clin=cols_clin, union_cpgs=union_cpgs, all_results=all_results)
hist_plot_df2 = get_hist_plot_df(cols_clin=cols_clin, union_cpgs=union_cpgs, all_results=all_results, lim_sum=1)

In [None]:
# Percentage of adVMPs associated with each feature in at least one cohort.
ax = sns.barplot(data=hist_plot_df1, palette=pltte)
plting.transform_plot_ax(ax, legend_title="")
ax.set_ylim([None, 100])
# Pin the tick positions before assigning labels; otherwise the hard-coded labels
# are attached to whatever ticks the automatic locator happens to choose.
ax.set_yticks([0, 20, 40, 60, 80, 100])
ax.set_yticklabels(['0','20','40','60','80','100'])
ax.set_ylabel("% adVMP")
ax.set_xticklabels(ax.get_xticklabels(), rotation=45,
                   verticalalignment="top",horizontalalignment="right")
ax.bar_label(ax.containers[0], fmt='%.1f', fontsize=15)
ax.figure.savefig(fig_dir / "barplot_link_clinical_lifestyle_adVMP_1ds.svg", bbox_inches="tight")

In [None]:
# Percentage of adVMPs associated with each feature in at least two cohorts.
ax = sns.barplot(data=hist_plot_df2, palette=pltte)
plting.transform_plot_ax(ax, legend_title="")
ax.set_ylim([None, 50])
# Pin the tick positions before assigning labels; otherwise the hard-coded labels
# are attached to whatever ticks the automatic locator happens to choose.
ax.set_yticks([0, 10, 20, 30, 40, 50])
ax.set_yticklabels(['0','10','20','30','40','50'])
ax.set_ylabel("% adVMP")
ax.set_xticklabels(ax.get_xticklabels(), rotation=45,
                   verticalalignment="top",horizontalalignment="right")
ax.bar_label(ax.containers[0], fmt='%.1f', fontsize=15)
ax.figure.savefig(fig_dir / "barplot_link_clinical_lifestyle_adVMP_2ds.svg", bbox_inches="tight")

In [None]:
def get_heatmap_df(all_results_swepic: Dict[str, pd.DataFrame]) -> pd.DataFrame:
    """Build a CpG x feature matrix of -log10(FDR q) values for one cohort.

    Parameters
    ----------
    all_results_swepic
        Mapping from clinical feature name to a DataFrame that has an
        ``"FDR q"`` column.

    Returns
    -------
    pd.DataFrame
        One column per key of ``all_results_swepic`` (in iteration order),
        holding ``-log10`` of that feature's q-values.
    """
    q_values = pd.concat(
        [all_results_swepic[charac]["FDR q"] for charac in all_results_swepic],
        axis=1,
    )
    q_values.columns = list(all_results_swepic.keys())
    # The vectorised ufunc replaces the deprecated DataFrame.applymap.
    return -np.log10(q_values)

def get_heatmap_sign_probes(heatmap_df: pd.DataFrame, ax: plt.Axes, cbar: bool=False) -> None:
    """Draw ``heatmap_df`` on ``ax``, hiding every cell below ``-log10(0.1)``.

    The values are compared against ``-log10(0.1)``, so the input is presumably
    on the -log10(q) scale, as returned by ``get_heatmap_df``. Row tick marks
    and labels are removed and column labels are rotated by 45 degrees.
    ``cbar`` toggles the colour bar.
    """
    significance_cutoff = -np.log10(0.1)
    not_significant = heatmap_df < significance_cutoff
    sns.heatmap(
        heatmap_df,
        mask=not_significant,
        vmax=5,
        cmap="vlag",
        center=0,
        ax=ax,
        cbar=cbar,
    )

    # Remove the row labels and tick marks.
    ax.set_yticklabels([])
    ax.set_yticks([])

    column_labels = ax.get_xticklabels()
    ax.set_xticklabels(
        column_labels,
        rotation=45,
        verticalalignment="top",
        horizontalalignment="right",
    )

In [None]:
# Per-feature result tables of the SWEPIC1 cohort (keys are the entries of cols_clin).
all_results_swepic1 = all_results["SWEPIC1"]

In [None]:
# -log10(FDR q) matrix for SWEPIC1: one row per CpG, one column per clinical feature.
heatmap_df1 = get_heatmap_df(all_results_swepic=all_results_swepic1)

In [None]:
# Per-feature result tables of the SWEPIC2 cohort (keys are the entries of cols_clin).
all_results_swepic2 = all_results["SWEPIC2"]

In [None]:
# -log10(FDR q) matrix for SWEPIC2: one row per CpG, one column per clinical feature.
heatmap_df2 = get_heatmap_df(all_results_swepic=all_results_swepic2)

In [None]:
# Per-feature result tables of the SWEPIC3 cohort (keys are the entries of cols_clin).
all_results_swepic3 = all_results["SWEPIC3"]

In [None]:
# -log10(FDR q) matrix for SWEPIC3: one row per CpG, one column per clinical feature.
heatmap_df3 = get_heatmap_df(all_results_swepic=all_results_swepic3)

In [None]:
# Number of CpGs for which no clinical feature reaches q < 0.1, shown per cohort
# in the order (SWEPIC1, SWEPIC2, SWEPIC3).
((heatmap_df1>-np.log10(0.1)).sum(axis=1)==0).sum(),((heatmap_df2>-np.log10(0.1)).sum(axis=1)==0).sum(),((heatmap_df3>-np.log10(0.1)).sum(axis=1)==0).sum()

In [None]:
# Cohort matrices in plotting order (left to right in the combined heatmap figure).
heatmap_dfs = [heatmap_df1,heatmap_df2,heatmap_df3]

In [None]:
# Three side-by-side heatmaps, one per cohort. Only the last panel gets a colour bar;
# it is also the widest panel, presumably to leave room for that bar.
fig, ax = plt.subplots(1,3,figsize=(9,10), gridspec_kw={'width_ratios': [4, 4, 5]})
cbar_ind = [panel_idx == len(ax) - 1 for panel_idx in range(len(ax))]
for panel, cohort_df, show_cbar in zip(ax, heatmap_dfs, cbar_ind):
    get_heatmap_sign_probes(cohort_df, ax=panel, cbar=show_cbar)
fig.savefig(fig_dir / "heatmap_link_clinical_lifestyle_adVMP.svg", bbox_inches="tight")

In [None]:
def get_heatmap_histogram(heatmap_df: pd.DataFrame, name: str) -> plt.Axes:
    """Histogram of how many columns exceed ``-log10(0.1)`` in each row.

    Parameters
    ----------
    heatmap_df
        Matrix compared element-wise against ``-log10(0.1)``; presumably
        -log10(q) values as returned by ``get_heatmap_df``.
    name
        Title of the plot.

    Returns
    -------
    plt.Axes
        The axes the histogram was drawn on (seaborn's current axes).
    """
    n_params = heatmap_df.shape[1]
    data = (heatmap_df>-np.log10(0.1)).sum(axis=1).to_frame()
    # One unit-wide bin per possible count 0..n_params. The last histogram bin is
    # closed on both sides, so the edges must run to n_params + 1; otherwise rows
    # associated with all parameters are merged with those associated with all but one.
    ax = sns.histplot(data=data, bins=np.arange(n_params + 2), legend=False)
    ax.bar_label(ax.containers[0], fmt='%.0f', fontsize=15)
    ax.spines[['right', 'top']].set_visible(False)
    ax.spines[["bottom", "left"]].set_linewidth(4)
    ax.set_xlabel("Number of parameters associated", fontsize=15)
    ax.set_ylabel("Number of adVMPs", fontsize=15)
    ax.set_xlim([0, n_params + 1])
    # Resize the tick labels without freezing their text; re-setting labels on
    # automatically located ticks is discouraged by matplotlib and raises a warning.
    ax.tick_params(axis="both", labelsize=15)
    ax.set_title(name,fontsize=15)
    return ax

In [None]:
# Histogram of the number of associated parameters per adVMP in SWEPIC1, saved as SVG.
ax1 = get_heatmap_histogram(heatmap_df=heatmap_df1, name="SWEPIC1")
ax1.figure.savefig(fig_dir / "SWEPIC1_hist_clinlifestyle_assoc.svg", bbox_inches="tight")

In [None]:
# Histogram of the number of associated parameters per adVMP in SWEPIC2, saved as SVG.
ax2 = get_heatmap_histogram(heatmap_df=heatmap_df2, name="SWEPIC2")
ax2.figure.savefig(fig_dir / "SWEPIC2_hist_clinlifestyle_assoc.svg", bbox_inches="tight")

In [None]:
# Histogram of the number of associated parameters per adVMP in SWEPIC3, saved as SVG.
ax3 = get_heatmap_histogram(heatmap_df=heatmap_df3, name="SWEPIC3")
ax3.figure.savefig(fig_dir / "SWEPIC3_hist_clinlifestyle_assoc.svg", bbox_inches="tight")